CN116227582A - 掩码自编码器的知识蒸馏方法、装置、设备及存储介质 - Google Patents
掩码自编码器的知识蒸馏方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN116227582A CN116227582A CN202310118939.7A CN202310118939A CN116227582A CN 116227582 A CN116227582 A CN 116227582A CN 202310118939 A CN202310118939 A CN 202310118939A CN 116227582 A CN116227582 A CN 116227582A
- Authority
- CN
- China
- Prior art keywords
- model
- loss
- encoder
- training
- intermediate feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Feedback Control In General (AREA)
- Exposure Of Semiconductors, Excluding Electron Or Ion Beam Exposure (AREA)
Abstract
本发明公开了一种掩码自编码器的知识蒸馏方法、装置、设备及存储介质,该方法通过分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;对所述教师模型进行预训练;基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练,使学生模型从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征;基于下游任务对预训练好的所述学生模型进行微调训练,学生模型可部署在算力资源缺乏的电力边缘侧,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
Description
技术领域
本发明涉及掩码自编码器压缩技术领域,尤其涉及一种掩码自编码器的知识蒸馏方法、装置、设备及存储介质。
背景技术
目前大规模预训练模型在下游任务上有着非常好的表现,但是在电网具体应用场景中,例如输变电图像缺陷检测任务中,实时性要求较高,,需要在电网的边缘侧对电力图像进行分类,边缘侧计算能力和存储空间有限,大模型不适合直接部署到边缘侧,因此需要对这类模型进行压缩。模型压缩方法中广泛使用的是基于教师(Teacher)-学生(Student)框架的知识蒸馏。该方法首先预训练参数量大的Teacher模型,然后利用Teacher模型的中间或最后输出结果监督Student模型的训练,将Teacher知识蒸馏到Student模型,提升Student模型学习能力和泛化数据能力。
近年来,利用Transformer模型将图像分成图像块(patch)序列的模型—ViTs(包括多个视觉变换(Vision Transformer,ViT)模型)受到了很大的关注,预训练的掩码自编码器(Masked AutoEncoders,MAE,模型结构为ViTs)在下游任务表现良好,然而这类模型参数量大,且没有相应的压缩方法简化其模型结构、减少其模型参数量,使其难以部署在电网边缘侧,无法快速利用现场采集数据进行针对下游任务的微调(fine-tune)。
发明内容
有鉴于此,本发明实施例提供了一种掩码自编码器的知识蒸馏方法、装置、设备及存储介质,以解决掩码自编码器模型参数大的技术问题。
本发明提出的技术方案如下:
本发明实施例第一方面提供了一种掩码自编码器的知识蒸馏方法,包括:
分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;对所述教师模型进行预训练;基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练;基于下游任务对预训练好的所述学生模型进行微调训练。
可选地,所述对所述教师模型进行预训练,包括:根据第一预设掩码率对第一输入图像进行掩码;将掩码后的第一输入图像输入到所述教师模型的编码器得到第一中间特征;将所述第一中间特征输入到所述教师模型的解码器得到第一重构图像;根据所述第一输入图像和所述第一重构图像获取第一损失函数;基于所述第一损失函数对所述教师模型进行预训练。
可选地,所述基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练,包括:基于所述学生模型和预训练好的所述教师模型获取重构损失和蒸馏损失;基于所述重构损失和蒸馏损失确定第二损失函数;基于所述第二损失函数对所述学生模型进行知识蒸馏预训练。
可选地,所述基于所述学生模型和预训练好的所述教师模型获取重构损失和蒸馏损失,包括:根据第二预设掩码率对第二输入图像进行掩码;将掩码后的所述第二输入图像输入到所述学生模型的编码器得到第二中间特征,将所述第二中间特征输入到所述学生模型的解码器得到第二重构图像;将掩码后的所述第二输入图像输入到预训练好的所述教师模型的编码器得到第三中间特征,将所述第三中间特征输入到预训练好的所述教师模型的解码器得到第三重构图像;基于所述第二重构图像获取重构损失;基于所述第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
可选地,所述基于所述第二重构图像获取重构损失包括:获取所述第二输入图像和所述第二重构图像在掩蔽区域的像素点的距离,以及所述第二输入图像和所述第二重构图像在未掩蔽区域的像素点的距离;基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取所述重构损失。
可选地,所述基于所述第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失,包括:计算所述第二中间特征和所述第三中间特征的距离;基于预设的温度系数计算所述第二重构图像和所述第三重构图像在掩蔽区域的像素点的距离;基于所述第二中间特征和所述第三中间特征的距离,以及所述第二重构图像和所述第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
可选地,所述基于下游任务对预训练好的所述学生模型进行微调训练,包括:根据预训练好的所述学生模型的编码器获取下游任务的训练集的第四中间特征;将所述第四中间特征输入到下游任务的映射层进行微调训练。
本发明实施例第二方面提供一种掩码自编码器的知识蒸馏装置,包括:
模型建立模块,用于分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;教师模型训练模块,用于对所述教师模型进行预训练;知识蒸馏预训练模块,用于基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练;微调训练模块,用于基于下游任务对预训练好的所述学生模型进行微调训练。
可选地,所述教师模型训练模块包括:第一掩码模块,用于根据第一预设掩码率对第一输入图像进行掩码;第一输入模块,用于将掩码后的第一输入图像输入到教师模型的编码器得到第一中间特征;第一重构模块,用于将第一中间特征输入到教师模型的解码器得到第一重构图像;第一损失模块,用于根据第一输入图像和第一重构图像获取第一损失函数;第一预训练模块,用于基于第一损失函数对教师模型进行预训练。
可选地,所述知识蒸馏预训练模块包括:第一获取模块,用于基于学生模型和预训练好的教师模型获取重构损失和蒸馏损失;第二损失模块,用于基于重构损失和蒸馏损失确定第二损失函数;第二预训练模块,用于基于第二损失函数对学生模型进行知识蒸馏预训练。
可选地,所述第一获取模块包括:第二掩码模块,用于根据第二预设掩码率对第二输入图像进行掩码;第二重构模块,用于将掩码后的第二输入图像输入到学生模型的编码器得到第二中间特征,将第二中间特征输入到学生模型的解码器得到第二重构图像;第三重构模块,用于将掩码后的第二输入图像输入到预训练好的教师模型的编码器得到第三中间特征,将第三中间特征输入到预训练好的教师模型的解码器得到第三重构图像;第二损失模块,用于基于第二重构图像获取重构损失;第三损失模块,用于基于第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
可选地,所述第二损失模块包括:第二获取模块,用于获取第二输入图像和第二重构图像在掩蔽区域的像素点的距离,以及第二输入图像和第二重构图像在未掩蔽区域的像素点的距离;重构损失模块,用于基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取重构损失。
可选地,所述第三损失模块包括:计算模块,用于计算第二中间特征和第三中间特征的距离;距离模块,用于基于预设的温度系数计算第二重构图像和第三重构图像在掩蔽区域的像素点的距离;蒸馏损失模块,用于基于第二中间特征和第三中间特征的距离,以及第二重构图像和第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
可选地,所述微调训练模块包括:第三获取模块,用于根据预训练好的学生模型的编码器获取下游任务的训练集的第四中间特征;微调模块,用于将第四中间特征输入到下游任务的映射层进行微调训练。本发明实施例第三方面提供一种电子设备,包括:存储器和处理器,所述存储器和所述处理器之间互相通信连接,所述存储器存储有计算机指令,所述处理器通过执行所述计算机指令,从而执行如本发明实施例第一方面任一项所述的掩码自编码器的知识蒸馏方法。
本发明实施例第四方面提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使所述计算机执行如本发明实施例第一方面任一项所述的掩码自编码器的知识蒸馏方法。
从以上技术方案可以看出,本发明实施例具有以下优点:
本发明实施例提供的一种掩码自编码器的知识蒸馏方法、装置、设备及存储介质,通过分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;对所述教师模型进行预训练;基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练,使学生模型从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征;基于下游任务对预训练好的所述学生模型进行微调训练,学生模型可部署在算力资源缺乏的电力边缘侧,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
附图说明
为了更清楚地表达说明本发明实施例的技术方案,下面将对实施例描述所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例中掩码自编码器的知识蒸馏方法的流程图;
图2为本发明实施例中掩码自编码器的知识蒸馏方法的架构图;
图3为本发明实施例中掩码自编码器的知识蒸馏装置的结构示意图;
图4为本发明实施例中电子设备的结构示意图;
图5为本发明实施例中计算机可读存储介质的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例提供了一种掩码自编码器的知识蒸馏方法,如图1和图2所示,该方法包括:
步骤S100:分别建立掩码自编码器的教师(Teacher)模型和学生(Student)模型,其中,教师模型和学生模型均为视觉变换(Vision Transformer,ViT)模型,且教师模型的规模大于学生模型。具体地,教师模型和学生模型的网络结构均是ViT,包括编码器和解码器,其工作过程为:将掩码后的图像输入到编码器得到图像的中间特征,然后将中间特征输入到解码器重建图像,得到重构后的图像。教师模型和学生模型的区别在于教师模型的规模大于学生模型,即前者比后者的参数量多,网络结构复杂度高。
步骤S200:对教师模型进行预训练。通过对参数相对较多、网络结构较为复杂的教师模型进行预训练,得到预训练好的教师模型,通过预训练好的教师模型为学生模型的预训练提供基于知识蒸馏的监督参数。
步骤S300:基于预训练好的教师模型对学生模型进行知识蒸馏预训练。在对学生模型进行预训练时,将输入图像同步输入到教师模型,利用教师模型的信息对学生模型进行监督。通过教师模型的信息对学生模型进行监督,学生模型可从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征,在预训练阶段减少模型参数的同时保证了模型精度不下降,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
步骤S400:基于下游任务对预训练好的学生模型进行微调训练(fine-tune)。fine-tune就是在预训练的学生模型的基础上利用下游任务数据集再次训练,使其在下游任务数据集效果提升,例如在预训练的学生模型上,用电力图像数据集对图像分类任务进行fine-tune。
本发明实施例的一种掩码自编码器的知识蒸馏方法,通过分别建立掩码自编码器的教师模型和学生模型,其中,教师模型和学生模型均为视觉变换模型,且教师模型的规模大于学生模型;对教师模型进行预训练;基于预训练好的教师模型对学生模型进行知识蒸馏预训练;基于下游任务对预训练好的学生模型进行微调训练,使学生模型从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征,学生模型可部署在算力资源缺乏的电力边缘侧,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
在一实施例中,上述步骤S200,对教师模型进行预训练,包括:根据第一预设掩码率对第一输入图像进行掩码;将掩码后的第一输入图像输入到教师模型的编码器得到第一中间特征;将第一中间特征输入到教师模型的解码器得到第一重构图像;根据第一输入图像和第一重构图像获取第一损失函数;基于第一损失函数对教师模型进行预训练。
具体地,第一输入图像x为大规模数据集中的电力图像,首先将其划分为同等大小的N个图像块(Patch),使用预先确定的第一掩码率(Mask Ratio,MR)对图像随机进行掩码,被选中的patch对于教师模型不可见,将剩余可见的patch拼接成序列,输入到编码器后得到第一中间特征,解码器利用第一中间特征以及根据第一掩码率随机获取的掩码向量进行解码,得到第一重构图像x′,在像素空间中计算第一输入图像x以及第一重构图像x′之间的均方误差(MSE),将该均方误差作为第一损失函数,利用第一损失函数计算一个损失值,然后反向计算梯度,更新教师模型的参数,从而完成教师模型的预训练。
在一实施例中,上述步骤S300,基于预训练好的教师模型对学生模型进行知识蒸馏预训练,包括:
步骤S310:基于学生模型和预训练好的教师模型获取重构损失和蒸馏损失;
步骤S320:基于重构损失和蒸馏损失确定第二损失函数;
步骤S330:基于第二损失函数对学生模型进行知识蒸馏预训练。
具体地,在对学生模型进行预训练时,将输入图像同步输入到教师模型,利用教师模型的信息对学生模型进行监督,蒸馏损失根据教师模型的信息获取,重构损失为学生模型根据自身的输入图像和重构图像计算获取。第二损失函数由重构损失和蒸馏损失共同确定,两个损失对学生模型的影响度由超参数β决定,超参数β可以根据试验选取效果最好的值,学生模型预训练阶段的第二损失函数Ls如下式所示:
Ls=recons+Lkd
式中,Lrecons表示重构损失,Lkd表示蒸馏损失,利用第二损失函数Ls计算一个损失值,然后反向计算梯度,更新学生模型的参数,从而完成学生模型的预训练,学生模型可从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征,在预训练阶段减少模型参数的同时保证了模型精度不下降,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
在一实施例中,上述步骤S330,基于学生模型和预训练好的教师模型获取重构损失和蒸馏损失,包括:
步骤S331:根据第二预设掩码率对第二输入图像进行掩码;
步骤S332:将掩码后的第二输入图像输入到学生模型的编码器得到第二中间特征,将第二中间特征输入到学生模型的解码器得到第二重构图像;
步骤S333:将掩码后的第二输入图像输入到预训练好的教师模型的编码器得到第三中间特征,将第三中间特征输入到预训练好的教师模型的解码器得到第三重构图像;
步骤S334:基于第二重构图像获取重构损失;
步骤S335:基于第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
具体地,和教师模型的与训练过程类似,第二输入图像X为大规模数据集中的电力图像,首先将其划分为同等大小的N个图像块(Patch),使用预先确定的第二掩码率对图像随机进行掩码,在本实施例中,第一掩码率和第二掩码率相同,第一输入图像和第二输入图像选自同一个数据集。被选中的patch对于教师模型和学生模型不可见,将剩余可见的patch拼接成序列,同步输入到学生模型的编码器和预训练好的教师模型的编码器。学生模型的编码器基于掩码后的第二输入图像X得到第二中间特征fs,将第二中间特征fs和根据第二掩码率随机获取的掩码向量共同输入到学生模型的编码器得到第二重构图像X′。预训练好的教师模型的编码器基于掩码后的第二输入图像X得到第三中间特征ft,将第三中间特征ft和根据第二掩码率随机获取的掩码向量共同输入到预训练好的教师模型的编码器得到第三重构图像t′。
重构损失和蒸馏损失用于共同确定学生模型的第二损失函数,根据第二重构图像X′和第二输入图像X的像素点的方差确定重构损失,重构损失为学生模型自身的基础预训练。根据第二中间特征fs、第三中间特征ft和第三重构图像t′获取蒸馏损失,蒸馏损失为学生模型的预训练提供了中间特征和软标签监督。
本发明实施例的学生模型在预训练阶段通过对第二输入图像进行掩码,并基于预训练好的教师模型对学生模型进行监督,不需要对电网场景图像即包含第二输入图像的大规模数据集进行标注,可减少标注的人力成本,快速利用现场采集样本更新模型,减少模型迭代成本。同时,结合知识蒸馏技术,减少学生模型参数,提升学生模型数据泛化能力。
在一实施例中,步骤S334,基于第二重构图像获取重构损失包括:获取第二输入图像和第二重构图像在掩蔽区域的像素点的距离,以及第二输入图像和第二重构图像在未掩蔽区域的像素点的距离;基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取重构损失。
具体地,重构损失如下式所示:
式中,Nm、Nu分别表示掩蔽区域的像素点、未掩蔽区域的像素点的数量,对于掩蔽的像素点,由于未输入到模型中,模型对其重构的像素点与真实像素点的距离越近,表明模型对未知像素点的预测能力越好,因此降低未屏蔽像素点的重构损失权重,即上式中加号右边乘一个小于1的超参数c。
本发明实施例采用加权方法考虑了掩蔽区域和未掩蔽区域共同对图像还原的影响,不同于现有技术中只计算关注掩蔽区域的方法,能够增强学生对于未掩蔽区域进行还原的能力,解码器防止过度关注掩蔽区域而忽略未掩蔽区域的信息。
在一实施例中,步骤S335,基于第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失,包括:计算第二中间特征和第三中间特征的距离;基于预设的温度系数计算第二重构图像和第三重构图像在掩蔽区域的像素点的距离;基于第二中间特征和第三中间特征的距离,以及第二重构图像和第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
具体地,蒸馏损失如下式所示:
式中,函数用于计算第二中间特征fs和第三中间特征ft之间的距离,采用L1范式,即L1(,b)=||-b||1,等式右半部分计算学生模型的第二重构图像X′的被掩蔽像素点与教师模型的第三重构图像t′被掩蔽像素点的距离,其中T为预设的温度系数,可控制平滑度,/>是教师模型提供的软标签,T的值往往大于1,从而减轻预训练的教师模型预测结果的偏差。蒸馏损失针对掩码自编码器预训练时缺乏类别标签做监督的特点,利用教师模型编码后的中间特征以及重构的像素点预测结果为学生模型提供监督,即蒸馏损失为学生模型预训练提供了中间特征和软标签监督,使学生模型学习其表达数据的能力,提高泛化性。
在一实施例中,上述步骤S400,基于下游任务对预训练好的学生模型进行微调训练,包括:根据预训练好的学生模型的编码器获取下游任务的训练集的第四中间特征;将第四中间特征输入到下游任务的映射层进行微调训练。
具体地,微调(fine-tune)训练就是在预训练模型的基础上利用下游任务训练集再次训练,使其在下游任务的效果提升。
与预训练阶段不同的是,微调训练只使用学生模型的编码器部分,将下游任务的训练集进行标注,将下游任务的训练集中的图像输入到预训练好的学生模型的编码器获取第四中间特征,然后将第四中间特征输入到下游任务的映射层进行微调训练,映射层为根据具体的下游任务在预训练后的学生模型上增加的网络,用于执行具体的下游任务,例如下游任务为分类时,只需要将第四中间特征输入到一个全连接层以及一个softmax层,得到分类结果,计算分类损失,根据分类损失进行微调训练。在进行微调训练后即可使用学生模型进行推理,例如对于现场采集的一张图像样本,将其输入到学生模型,输出是分类结果,将其分类为某一种缺陷或不存在缺陷。
在实际应用场景中,只需要根据具体的下游任务对预训练好的学生模型进行微调训练即可快速更新模型,减少模型迭代成本。
本发明实施例还提供一种掩码自编码器的知识蒸馏装置,如图3所示,该装置包括:
模型建立模块310,用于分别建立掩码自编码器的教师模型和学生模型,其中,教师模型和学生模型均为视觉变换模型,且教师模型的规模大于学生模型;具体内容参见上述方法实施例对应部分,在此不再赘述。
教师模型训练模块320,用于对教师模型进行预训练;具体内容参见上述方法实施例对应部分,在此不再赘述。
知识蒸馏预训练模块330,用于基于预训练好的教师模型对学生模型进行知识蒸馏预训练;具体内容参见上述方法实施例对应部分,在此不再赘述。
微调训练模块340,用于基于下游任务对预训练好的学生模型进行微调训练。具体内容参见上述方法实施例对应部分,在此不再赘述。
本发明实施例的一种掩码自编码器的知识蒸馏装置,通过分别建立掩码自编码器的教师模型和学生模型,其中,教师模型和学生模型均为视觉变换模型,且教师模型的规模大于学生模型;对教师模型进行预训练;基于预训练好的教师模型对学生模型进行知识蒸馏预训练;基于下游任务对预训练好的学生模型进行微调训练,使学生模型从预训练好的教师模型中学习数据泛化能力,得到表征能力更好的图像特征,学生模型可部署在算力资源缺乏的电力边缘侧,在减少模型参数的同时保证了模型精度不下降,加速实时推理速度。
在一实施例中,教师模型训练模块320包括:
第一掩码模块,用于根据第一预设掩码率对第一输入图像进行掩码;
第一输入模块,用于将掩码后的第一输入图像输入到教师模型的编码器得到第一中间特征;
第一重构模块,用于将第一中间特征输入到教师模型的解码器得到第一重构图像;
第一损失模块,用于根据第一输入图像和第一重构图像获取第一损失函数;
第一预训练模块,用于基于第一损失函数对教师模型进行预训练。
在一实施例中,知识蒸馏预训练模块330包括:
第一获取模块,用于基于学生模型和预训练好的教师模型获取重构损失和蒸馏损失;
第二损失模块,用于基于重构损失和蒸馏损失确定第二损失函数;
第二预训练模块,用于基于第二损失函数对学生模型进行知识蒸馏预训练。
在一实施例中,第一获取模块包括:
第二掩码模块,用于根据第二预设掩码率对第二输入图像进行掩码;
第二重构模块,用于将掩码后的第二输入图像输入到学生模型的编码器得到第二中间特征,将第二中间特征输入到学生模型的解码器得到第二重构图像;
第三重构模块,用于将掩码后的第二输入图像输入到预训练好的教师模型的编码器得到第三中间特征,将第三中间特征输入到预训练好的教师模型的解码器得到第三重构图像;
第二损失模块,用于基于第二重构图像获取重构损失;
第三损失模块,用于基于第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
在一实施例中,第二损失模块包括:
第二获取模块,用于获取第二输入图像和第二重构图像在掩蔽区域的像素点的距离,以及第二输入图像和第二重构图像在未掩蔽区域的像素点的距离;
重构损失模块,用于基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取重构损失。
在一实施例中,第三损失模块,包括:
计算模块,用于计算第二中间特征和第三中间特征的距离;
距离模块,用于基于预设的温度系数计算第二重构图像和第三重构图像在掩蔽区域的像素点的距离;
蒸馏损失模块,用于基于第二中间特征和第三中间特征的距离,以及第二重构图像和第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
在一实施例中,微调训练模块340包括:
第三获取模块,用于根据预训练好的学生模型的编码器获取下游任务的训练集的第四中间特征;
微调模块,用于将第四中间特征输入到下游任务的映射层进行微调训练。
本发明实施例还提供了一种电子设备,如图4所示,包括:存储器420和处理器410,存储器420和处理器410之间互相通信连接,存储器420存储有计算机指令,处理器410通过执行计算机指令,从而执行如本发明上述实施例中的掩码自编码器的知识蒸馏方法。其中处理器410和存储器420可以通过总线或者其他方式连接。处理器410可以为中央处理器(CentralProcessingUnit,CPU)。处理器410还可以为其他通用处理器、数字信号处理器(DigitalSignalProcessor,DSP)、专用集成电路(ApplicationSpecificIntegratedCircuit,ASIC)、现场可编程门阵列(Field-ProgrammableGateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等芯片,或者上述各类芯片的组合。存储器420作为一种非暂态计算机存储介质,可用于存储非暂态软件程序、非暂态计算机可执行程序以及模块,如本发明实施例中的对应的程序指令/模块。处理器410通过运行存储在存储器420中的非暂态软件程序、指令以及模块,从而执行处理器410的各种功能应用以及数据处理,即实现上述方法实施例中的掩码自编码器的知识蒸馏方法。存储器420可以包括存储程序区和存储数据区,其中,存储程序区可存储操作装置、至少一个功能所需要的应用程序;存储数据区可存储处理器410所创建的数据等。此外,存储器420可以包括高速随机存取存储器420,还可以包括非暂态存储器420,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施例中,存储器420可选包括相对于处理器410远程设置的存储器420,这些远程存储器420可以通过网络连接至处理器410。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。一个或者多个模块存储在存储器420中,当被处理器410执行时,执行如上述方法实施例中的掩码自编码器的知识蒸馏方法。上述电子设备具体细节可以对应上述方法实施例中对应的相关描述和效果进行理解,此处不再赘述。
本发明实施例还提供一种计算机可读存储介质,如图5所示,其上存储有计算机程序510,该指令被处理器执行时实现上述实施例中掩码自编码器的知识蒸馏方法的步骤。该存储介质上还存储有音视频流数据,特征帧数据、交互请求信令、加密数据以及预设数据大小等。其中,存储介质可为磁碟、光盘、只读存储记忆体(Read-OnlyMemory,ROM)、随机存储记忆体(RandomAccessMemory,RAM)、快闪存储器(FlashMemory)、硬盘(HardDiskDrive,缩写:HDD)或固态硬盘(Solid-StateDrive,SSD)等;存储介质还可以包括上述种类的存储器的组合。本领域技术人员可以理解,实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,计算机程序13可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,存储介质可为磁碟、光盘、只读存储记忆体(Read-OnlyMemory,ROM)、随机存储记忆体(RandomAccessMemory,RAM)、快闪存储器(FlashMemory)、硬盘(HardDiskDrive,缩写:HDD)或固态硬盘(Solid-StateDrive,SSD)等;存储介质还可以包括上述种类的存储器的组合。
以上,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (16)
1.一种掩码自编码器的知识蒸馏方法,其特征在于,包括:
分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;
对所述教师模型进行预训练;
基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练;
基于下游任务对预训练好的所述学生模型进行微调训练。
2.根据权利要求1所述的掩码自编码器的知识蒸馏方法,其特征在于,所述对所述教师模型进行预训练,包括:
根据第一预设掩码率对第一输入图像进行掩码;
将掩码后的第一输入图像输入到所述教师模型的编码器得到第一中间特征;
将所述第一中间特征输入到所述教师模型的解码器得到第一重构图像;
根据所述第一输入图像和所述第一重构图像获取第一损失函数;
基于所述第一损失函数对所述教师模型进行预训练。
3.根据权利要求1所述的掩码自编码器的知识蒸馏方法,其特征在于,所述基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练,包括:
基于所述学生模型和预训练好的所述教师模型获取重构损失和蒸馏损失;
基于所述重构损失和蒸馏损失确定第二损失函数;
基于所述第二损失函数对所述学生模型进行知识蒸馏预训练。
4.根据权利要求3所述的掩码自编码器的知识蒸馏方法,其特征在于,所述基于所述学生模型和预训练好的所述教师模型获取重构损失和蒸馏损失,包括:
根据第二预设掩码率对第二输入图像进行掩码;
将掩码后的所述第二输入图像输入到所述学生模型的编码器得到第二中间特征,将所述第二中间特征输入到所述学生模型的解码器得到第二重构图像;
将掩码后的所述第二输入图像输入到预训练好的所述教师模型的编码器得到第三中间特征,将所述第三中间特征输入到预训练好的所述教师模型的解码器得到第三重构图像;
基于所述第二重构图像获取重构损失;
基于所述第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
5.根据权利要求4所述的掩码自编码器的知识蒸馏方法,其特征在于,所述基于所述第二重构图像获取重构损失包括:
获取所述第二输入图像和所述第二重构图像在掩蔽区域的像素点的距离,以及所述第二输入图像和所述第二重构图像在未掩蔽区域的像素点的距离;
基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取所述重构损失。
6.根据权利要求4所述的掩码自编码器的知识蒸馏方法,其特征在于,所述基于所述第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失,包括:
计算所述第二中间特征和所述第三中间特征的距离;
基于预设的温度系数计算所述第二重构图像和所述第三重构图像在掩蔽区域的像素点的距离;
基于所述第二中间特征和所述第三中间特征的距离,以及所述第二重构图像和所述第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
7.根据权利要求1所述的掩码自编码器的知识蒸馏方法,其特征在于,所述基于下游任务对预训练好的所述学生模型进行微调训练,包括:
根据预训练好的所述学生模型的编码器获取下游任务的训练集的第四中间特征;
将所述第四中间特征输入到下游任务的映射层进行微调训练。
8.一种掩码自编码器的知识蒸馏装置,其特征在于,包括:
模型建立模块,用于分别建立掩码自编码器的教师模型和学生模型,其中,所述教师模型和所述学生模型均为视觉变换模型,且所述教师模型的规模大于所述学生模型;
教师模型训练模块,用于对所述教师模型进行预训练;
知识蒸馏预训练模块,用于基于预训练好的所述教师模型对所述学生模型进行知识蒸馏预训练;
微调训练模块,用于基于下游任务对预训练好的所述学生模型进行微调训练。
9.根据权利要求8所述的掩码自编码器的知识蒸馏装置,其特征在于,所述教师模型训练模块包括:
第一掩码模块,用于根据第一预设掩码率对第一输入图像进行掩码;
第一输入模块,用于将掩码后的第一输入图像输入到教师模型的编码器得到第一中间特征;
第一重构模块,用于将第一中间特征输入到教师模型的解码器得到第一重构图像;
第一损失模块,用于根据第一输入图像和第一重构图像获取第一损失函数;
第一预训练模块,用于基于第一损失函数对教师模型进行预训练。
10.根据权利要求8所述的掩码自编码器的知识蒸馏装置,其特征在于,所述知识蒸馏预训练模块包括:
第一获取模块,用于基于学生模型和预训练好的教师模型获取重构损失和蒸馏损失;
第二损失模块,用于基于重构损失和蒸馏损失确定第二损失函数;
第二预训练模块,用于基于第二损失函数对学生模型进行知识蒸馏预训练。
11.根据权利要求10所述的掩码自编码器的知识蒸馏装置,其特征在于,所述第一获取模块包括:
第二掩码模块,用于根据第二预设掩码率对第二输入图像进行掩码;
第二重构模块,用于将掩码后的第二输入图像输入到学生模型的编码器得到第二中间特征,将第二中间特征输入到学生模型的解码器得到第二重构图像;
第三重构模块,用于将掩码后的第二输入图像输入到预训练好的教师模型的编码器得到第三中间特征,将第三中间特征输入到预训练好的教师模型的解码器得到第三重构图像;
第二损失模块,用于基于第二重构图像获取重构损失;
第三损失模块,用于基于第二中间特征、第三中间特征、第二重构图像和第三重构图像获取蒸馏损失。
12.根据权利要求11所述的掩码自编码器的知识蒸馏装置,其特征在于,所述第二损失模块包括:
第二获取模块,用于获取第二输入图像和第二重构图像在掩蔽区域的像素点的距离,以及第二输入图像和第二重构图像在未掩蔽区域的像素点的距离;
重构损失模块,用于基于在掩蔽区域的像素点的距离和在未掩蔽区域的像素点的距离获取重构损失。
13.根据权利要求11所述的掩码自编码器的知识蒸馏装置,其特征在于,所述第三损失模块包括:
计算模块,用于计算第二中间特征和第三中间特征的距离;
距离模块,用于基于预设的温度系数计算第二重构图像和第三重构图像在掩蔽区域的像素点的距离;
蒸馏损失模块,用于基于第二中间特征和第三中间特征的距离,以及第二重构图像和第三重构图像在掩蔽区域的像素点的距离获取蒸馏损失。
14.根据权利要求8所述的掩码自编码器的知识蒸馏装置,其特征在于,所述微调训练模块包括:
第三获取模块,用于根据预训练好的学生模型的编码器获取下游任务的训练集的第四中间特征;
微调模块,用于将第四中间特征输入到下游任务的映射层进行微调训练。
15.一种电子设备,其特征在于,包括:存储器和处理器,所述存储器和所述处理器之间互相通信连接,所述存储器存储有计算机指令,所述处理器通过执行所述计算机指令,从而执行如权利要求1至7任一项所述的掩码自编码器的知识蒸馏方法。
16.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使所述计算机执行如权利要求1至7任一项所述的掩码自编码器的知识蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310118939.7A CN116227582A (zh) | 2023-01-31 | 2023-01-31 | 掩码自编码器的知识蒸馏方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310118939.7A CN116227582A (zh) | 2023-01-31 | 2023-01-31 | 掩码自编码器的知识蒸馏方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116227582A true CN116227582A (zh) | 2023-06-06 |
Family
ID=86574526
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310118939.7A Pending CN116227582A (zh) | 2023-01-31 | 2023-01-31 | 掩码自编码器的知识蒸馏方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116227582A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116523917A (zh) * | 2023-07-04 | 2023-08-01 | 宁德时代新能源科技股份有限公司 | 缺陷检测方法、装置、计算机设备和存储介质 |
-
2023
- 2023-01-31 CN CN202310118939.7A patent/CN116227582A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116523917A (zh) * | 2023-07-04 | 2023-08-01 | 宁德时代新能源科技股份有限公司 | 缺陷检测方法、装置、计算机设备和存储介质 |
CN116523917B (zh) * | 2023-07-04 | 2023-10-13 | 宁德时代新能源科技股份有限公司 | 缺陷检测方法、装置、计算机设备和存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
AU2020103715A4 (en) | Method of monocular depth estimation based on joint self-attention mechanism | |
CN113034380B (zh) | 基于改进可变形卷积校正的视频时空超分辨率方法和装置 | |
KR101967089B1 (ko) | 컨볼루션 신경망 기반의 완전 기준 이미지 품질 평가 | |
CN109636721B (zh) | 基于对抗学习和注意力机制的视频超分辨率方法 | |
CN113065060B (zh) | 基于深度学习的教育平台课程推荐方法及系统 | |
CN106910192A (zh) | 一种基于卷积神经网络的图像融合效果评估方法 | |
CN112801047B (zh) | 缺陷检测方法、装置、电子设备及可读存储介质 | |
CN108111860B (zh) | 基于深度残差网络的视频序列丢失帧预测恢复方法 | |
CN110139046B (zh) | 一种基于张量的视频帧合成方法 | |
CN111598842A (zh) | 一种绝缘子缺陷样本生成模型的方法、系统及存储介质 | |
Jia et al. | Effective meta-attention dehazing networks for vision-based outdoor industrial systems | |
CN112132770A (zh) | 图像修复的方法、装置、计算机可读介质及电子设备 | |
CN109299170B (zh) | 一种针对带标签时间序列数据的补全方法 | |
CN116227582A (zh) | 掩码自编码器的知识蒸馏方法、装置、设备及存储介质 | |
CN107729885B (zh) | 一种基于多重残差学习的人脸增强方法 | |
CN112464718A (zh) | 一种基于YOLO-Terse网络的目标检测方法及存储介质 | |
Jiang et al. | Multi-level memory compensation network for rain removal via divide-and-conquer strategy | |
Yuan et al. | A simple self-supervised imu denoising method for inertial aided navigation | |
CN113763447B (zh) | 深度图的补全方法、电子设备及存储介质 | |
CN118229632A (zh) | 显示屏缺陷检测方法、模型训练方法、装置、设备及介质 | |
WO2020001046A1 (zh) | 一种基于自适应层次化运动建模的视频预测方法 | |
CN113436224A (zh) | 一种基于显式构图规则建模的智能图像裁剪方法及装置 | |
CN116152577B (zh) | 图像分类方法及装置 | |
CN117372810A (zh) | 遥感图像语义分割模型训练方法、分割方法及相关装置 | |
CN116129408A (zh) | 一种驾驶员疲劳检测方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |