CN115861765A - 基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 - Google Patents
基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 Download PDFInfo
- Publication number
- CN115861765A CN115861765A CN202211439778.3A CN202211439778A CN115861765A CN 115861765 A CN115861765 A CN 115861765A CN 202211439778 A CN202211439778 A CN 202211439778A CN 115861765 A CN115861765 A CN 115861765A
- Authority
- CN
- China
- Prior art keywords
- network
- student
- unsupervised
- distillation
- teacher
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法,涉及神经网络加速领域。针对现有技术中无监督训练方法面对大网络性能较好,在小网络上则不能保证训练的精度,除了预训练教师网络外,还会构造一个样本库来实现损失函数,限制了网络在边缘端的更新的问题,本发明提供了:基于无监督蒸馏网络的学生网络获取方法,包括:采集图像作为数据集;根据数据集,得到两个增广;两个增广分别通过教师网络和学生网络,得到教师网络的投影值和预测值和学生网络的投影值和预测值;根据教师网络的投影值和预测值和学生网络的投影值和预测值,更新教师网络和学生网络;输出当前学生网络,作为结果。适合应用于边缘计算场景。
Description
技术领域
涉及神经网络加速领域,具体涉及通过神经网络分类图像。
背景技术
边缘采集设备会生成大量的无标签数据,对这些无标签数据进行标注将会消耗大量人力物力。因此需要网络能够进行无监督训练。对比学习(contrast Learning,CL)的发展将众多研究者的注意力吸引到自监督学习(Self-Supervised Learning,SSL)领域。与自回归(AR)模型、基于流的模型和自动编码(AE)模型等生成方法不同,CL从正样本和负样本之间的比较中学习表示,而不是专注于重建像素空间中的误差。因此,CL可以专注于对类内聚集至关重要的更抽象的潜在因素。许多令人敬畏的CL方法已经被推导出来,它们通常可以分为两类:对比方法和非对称方法。对比方法将数据集中的每个图像都视为一个独立的类。学习过程是训练模型将同一幅图像的两幅增广图像识别为同一类别,而将其他图像识别为不同类别。这些实例判别方法取得了很好的效果。有些甚至缩小了基于监督的方法和基于自监督的方法之间的差距。然而,这些方法倾向于将每个嵌入分为不同的类别,这意味着在此过程中,属于同一类别的一些图像也被分离了。已经提出了许多方法来缓解这个问题,如SupCon,NNCLR,MMCL,CLD。与依赖正负对的对比方法不同,非对称方法采用了预测网络和停止梯度方法。通过预测网络将两个不同的嵌入投影到彼此接近的位置。具体来说,BYOL在在线分支之后引入了一个预测层,并使用动量编码器更新目标分支。进一步证明了停止梯度操作可以取代动量编码器。该方法仅使用一个网络。然而,对称或非对称方法在小型模型上的精度都会退化。这些工作专注于提高大型网络的性能,如ResNet-50,而边缘应用通常需要内存消耗和计算复杂度更少的小型模型。但是目前的无监督训练方法面对大网络性能较好,在小网络上则不能保证训练的精度,而边缘部署对网络的大小有着严格的要求。
小型模型通常在SSL中学习低层次表示。因此,参考监督学习,在框架中采用知识蒸馏(KD)来缓解特征提取能力的退化。知识蒸馏(KD)通过蒸馏集成模型中的知识来提高单个模型的性能。将教师模型推理生成的软目标与实际标签生成的硬目标相结合,引入到学生模型的训练中。结果相当令人印象深刻,许多研究人员将这种方法用于优化小型、轻量级网络。将KD应用于SSL时有几个问题需要考虑。首先,SSL避免了人工标注图像的使用,这意味着没有图像标签进行网络更新,损失函数需要重新设计。其次,KD需要一个两阶段的训练,首先更新教师模型以获得最佳性能,然后学生模型使用教师的logits。但是在边缘端,重新训练教师模型进行会相当消耗算力资源,同时SAR图像、遥感图像、瑕疵检测等等工业场景数据集又会限制网络的迁移性能。此外,一般的SSKD方法,除了预训练教师网络外,还会构造一个样本库来实现依赖于正负样本的损失函数,这进一步限制了网络在边缘端的更新。
发明内容
针对现有技术中,的无监督训练方法面对大网络性能较好,在小网络上则不能保证训练的精度,一般的SSKD方法,除了预训练教师网络外,还会构造一个样本库来实现依赖于正负样本的损失函数,这进一步限制了网络在边缘端的更新的问题,本发明提供的技术方案为:
基于无监督蒸馏网络的学生网络获取方法,应用于图像分类,其特征在于,所述方法包括:
步骤1:采集图像作为数据集;
步骤2:根据所述数据集,得到两个增广;
步骤3:所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
步骤4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
步骤5:输出当前学生网络,作为结果。
进一步,提供一个优选实施方式,所述教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值的获取方法具体为:根据MLP头获取。
进一步,提供一个优选实施方式,所述的步骤3中,MLP头根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值生成子项,通过所述子项求解预设的损失函数。
进一步,提供一个优选实施方式,所述MLP头包括学生网络的投影头和预测头以及教师网络的投影头和预测头,所述MLP头根据所述教师网络和学生网络生成分别对应两个所述增广的、教师网络的两个投影值和两个预测值,以及分别对应两个所述增广的、学生网络的两个投影值和两个预测值,四个投影值和四个预测值作为所述子项。
进一步,提供一个优选实施方式,所述步骤4中,更新所述教师网络和学生网络的具体方式为,根据教师网络输出的投影值和预测值,得到教师网络的自更新,根据教师网络输出的投影值和学生网络的预测值,得到蒸馏更新,根据教师网络输出的预测值和学生网络的投影值,得到对抗更新,根据所述自更新、蒸馏更新和对抗更新,更新所述教师网络和学生网络。
基于同一发明构思,本发明还提供了无监督蒸馏学生网络获取装置,所述装置包括:
模块1:用于采集图像作为数据集;
模块2:用于根据所述数据集,得到两个增广;
模块3:用于将所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
模块4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
模块5:用于输出当前学生网络,作为结果。
基于同一发明构思,本发明还提供了基于无监督蒸馏网络的图像分类模型获取方法,所述方法包括:
步骤6:采集待分类图像;
步骤7:将所述图像通过所述的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
步骤8:通过预设训练图像集,训练预设的分类网络;
步骤9:得到训练后的学生网络和分类网络,并输出作为结果。
基于同一发明构思,本发明还提供了基于无监督蒸馏网络的图像分类模型获取装置,所述装置包括:
模块6:用于采集图像;
模块7:用于将所述图像通过所述的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
模块8:用于通过预设训练图像集,训练预设的分类网络;
模块9:用于得到训练后的学生网络和分类网络。
基于同一发明构思,本发明还提供了基于无监督蒸馏网络的图像分类方法,所述方法包括:
步骤10:采集待分类的图像;
步骤11:将所述待分类的图像通过所述的基于无监督蒸馏网络的图像分类模型获取方法输出的结果进行分类。
基于同一发明构思,本发明还提供了基于无监督蒸馏网络的图像分类装置,所述装置包括:
模块10:用于采集待分类的图像;
模块11:用于将所述待分类的图像通过所述的基于无监督蒸馏网络的图像分类模型获取装置输出的结果进行分类。
基于同一发明构思,本发明还提供了计算机储存介质,用于储存计算机程序,当所述储存介质中储存的计算机程序被计算机的处理器读取时,所述计算机执行所述的基于无监督蒸馏网络的学生网络获取方法或所述的基于无监督蒸馏网络的图像分类模型获取方法或所述的基于无监督蒸馏网络的图像分类方法。
基于同一发明构思,本发明还提供了计算机,包括处理器和储存介质,所述储存介质用于储存计算机程序,当所述储存介质中储存的计算机程序被所述处理器读取时,所述计算机执行所述的基于无监督蒸馏网络的学生网络获取方法或所述的基于无监督蒸馏网络的图像分类模型获取方法或所述的基于无监督蒸馏网络的图像分类方法。
相对于现有技术,本发明的有益之处在于:
本发明提供的基于无监督蒸馏网络的学生网络获取方法,不需要预训练教师网络,通过教师网络在线更新的方式,同时移除了对正负样本库的依赖,降低分类模型训练成本,同时提高了小网络的自监督学习精度。
本发明提供的基于无监督蒸馏网络的图像分类模型获取方法,以自监督学习的形式实现知识迁移路径,实现了不需要任何额外结构的单阶段自监督知识蒸馏的功能。
本发明提供的基于无监督蒸馏网络的图像分类模型获取方法,通过设计单阶段的自我监督只是蒸馏框架,使教师网络与学生网络同步更新,实现对不同数据集的灵活性更强,降低方法迁移成本。
本发明提供的基于无监督蒸馏网络的图像分类模型获取方法,利用自监督学习的形式重构蒸馏的知识迁移路径,移除了蒸馏方法对图像标签的要求,同时结合自监督学习,提高了网络性能。
本发明提供的基于无监督蒸馏网络的图像分类模型获取方法,通过MLP头得到的子项,代入预设函数,将蒸馏的学习对象从教师网络的输出变为教师学生网络的特征提取能力的差距。
本发明提供的基于无监督蒸馏网络的图像分类模型获取方法,适合应用于边缘计算场景。
附图说明
图1为实施方式七提供的基于无监督蒸馏网络的图像分类模型获取方法的框架示意图;
图2为实施方式九提供的基于无监督蒸馏网络的图像分类方法的流程示意图。
具体实施方式
为使本发明提供的技术方案的优点和有益之处体现得更具体,现结合附图对本发明提供的技术方案进行进一步详细地描述,具体的:
实施方式一、本实施方式提供了基于无监督蒸馏网络的学生网络获取方法,应用于图像分类,所述方法包括:
步骤1:采集图像作为数据集;
步骤2:根据所述数据集,得到两个增广;
步骤3:所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
步骤4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
步骤5:输出当前学生网络,作为结果。
实施方式二、本实施方式是对实施方式一提供的基于无监督蒸馏网络的学生网络获取方法的进一步限定,所述教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值的获取方法具体为:根据MLP头获取。
实施方式三、本实施方式是对实施方式二提供的基于无监督蒸馏网络的学生网络获取方法的进一步限定,所述的步骤3中,MLP头根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值生成子项,通过所述子项求解预设的损失函数。
实施方式四、本实施方式是对实施方式三提供的基于无监督蒸馏网络的学生网络获取方法的进一步限定,所述MLP头包括学生网络的投影头和预测头以及教师网络的投影头和预测头,所述MLP头根据所述教师网络和学生网络生成分别对应两个所述增广的、教师网络的两个投影值和两个预测值,以及分别对应两个所述增广的、学生网络的两个投影值和两个预测值,四个投影值和四个预测值作为所述子项。
实施方式五、本实施方式是对实施方式一提供的基于无监督蒸馏网络的学生网络获取方法的进一步限定,所述步骤4中,更新所述教师网络和学生网络的具体方式为,根据教师网络输出的投影值和预测值,得到教师网络的自更新,根据教师网络输出的投影值和学生网络的预测值,得到蒸馏更新,根据教师网络输出的预测值和学生网络的投影值,得到对抗更新,根据所述自更新、蒸馏更新和对抗更新,更新所述教师网络和学生网络。
实施方式六、本实施方式提供了无监督蒸馏学生网络获取装置,所述装置包括:
模块1:用于采集图像作为数据集;
模块2:用于根据所述数据集,得到两个增广;
模块3:用于将所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
模块4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
模块5:用于输出当前学生网络,作为结果。
实施方式七、结合图1说明本实施方式,本实施方式提供了基于无监督蒸馏网络的图像分类模型获取方法,应用于图像分类,所述方法包括:
步骤6:采集待分类图像;
步骤7:将所述图像通过实施方式一至五任意一项提供的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
步骤8:通过预设训练图像集,训练预设的分类网络;
步骤9:得到训练后的学生网络和分类网络,并输出作为结果。
具体的:
步骤a:采集数据集T;
步骤b:根据T获得T1和T2;
步骤c:将T中图的两个增广分别输入给教师网络和学生网络,最后就会获得8个输出,分别是教师网络的输出经过投影头得到Zt1、Zt2,Zt1、Zt2进一步经过预测头得到Pt1、Pt2,类似的,学生网络的输出经过投影头得到Zs1、Zs2,Zs1、Zs2进一步经过预测头得到Ps1、Ps2
对应,每个Zs,Zt,Ps,Pt都有2个元素,这两个2元素其实就是对应于一张图的两个增广。
步骤4:Projection S就是学生网络的投影头,Prediction S就是学生网络的预测头;Projection T和Prediction T分别就是教师网络的投影头和预测头,在网络设计中他们与学生网络、教师网络是平等的关系,同为网络参数,没有因果承继关系,只有连接关系。所以这一步就是通过投影头和预测头分别生成8个用于损失函数构成的子项。预测头和投影头通过MLP头实现,共有4个,学生网络的投影头,学生网络的预测头,教师网络的投影头和预测头。
步骤5.1:学生网络的输出经过Projection S和Prediction S,获得Ps,教师网络的输出经过Projection T Prediction T,获得Pt;
步骤5.2:学生网络的输出经过Projection S,获得Zs,教师网络的输出经过Projection T获得Zt;
以上2个步骤就是在计算最终构成损失函数的8个子项,用于解释步骤4。
Zs、Ps、Zt、Pt的含义,它本身可以说成学生网络的输出经过学生网络的投影层获得的投影值(Zs),学生网络的输出经过学生网络的投影层和预测层获得的预测值(Ps),教师网络的输出经过教师网络的投影层获得的投影值(Zt),教师网络的输出经过教师网络的投影层和预测层获得的预测值(Pt)。
步骤6:停止梯度策略(Stop gradient strategy)指的是对于投影头的输出,也就是Zs、Zt,他们是没有反向传播路径的,它们用于损失函数计算,但是更新路径是从Ps、Pt开始沿着预测头—投影头—网络主干更新,不会由于投影头的输出再额外引入一个新的梯度路径。La就是由Zs、Pt构成的损失函数,Ld是由Zt、Ps构成的,而Lc是由Pt、Zt构成的。
步骤7:Ld,La,Lc三者组合得到L;其中Ld代表蒸馏关系,Lc代表教师网络的自更新,La代表教师网络对学生网络的远离关系,所以与学生网络更新相关的损失函数项是Ld,而Lc和La都是在更新教师网络。
步骤8,根据L更新Student和Online updated Teacher。
实施方式八、本实施方式提供了基于无监督蒸馏网络的图像分类模型获取装置,所述装置包括:
模块6:用于采集图像;
模块7:用于将所述图像通过实施方式六提供的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
模块8:用于通过预设训练图像集,训练预设的分类网络;
模块9:用于得到训练后的学生网络和分类网络。
实施方式九、结合图2说明本实施方式,本实施方式提供了基于无监督蒸馏网络的图像分类方法,所述方法包括:
步骤10:采集待分类的图像;
步骤11:将所述待分类的图像通过实施方式七提供的基于无监督蒸馏网络的图像分类模型获取方法输出的结果进行分类。
具体的,首先获得未标记的图像数据集X={x1,x2,x3,...,xn},小部分标记数据集T={t1,t2,t3,...,tm},未训练的教师网络FT,未训练的学生网络FS,投影头Fpro,FproS,预测头FpreT,FpreS;
对于未标记数据集中的每个图像,获得每张图像的2个增广,分别输入教师网络和学生网络,得到FproT(QT)=ZT,FproS(QS)=ZS,FpreT(ZT)=PT,FpreS(ZS)=PS;
其中,三个损失函数可以分别使用
更新学生网络和教师网络;
冻结已经训练完的学生网络,在标记数据集上训练分类网络(通过预设的标记好的数据集,训练学生网络和学生网络之后的分类网络;)
其中,可用的标记数据集有CIFAR-10、CIFAR-100和ImageNet-1K等。
可用的分类网络有ResNet-18和ResNet-50等。
结合训练完的学生网络和分类网络,获得完整的轻量网络F=Fc(FS(Input))用于边缘部署。
其中,Fc表示分类网络,Fs表示学生网络。
实施方式十、本实施方式提供了基于无监督蒸馏网络的图像分类装置,所述装置包括:
模块10:用于采集待分类的图像;
模块11:用于将所述待分类的图像通过实施方式八提供的基于无监督蒸馏网络的图像分类模型获取装置输出的结果进行分类。
实施方式十一、本实施方式提供了计算机储存介质,用于储存计算机程序,当所述储存介质中储存的计算机程序被计算机的处理器读取时,所述计算机执行实施方式一至五任意一项提供的基于无监督蒸馏网络的学生网络获取方法或实施方式七提供的基于无监督蒸馏网络的图像分类模型获取方法或实施方式九提供的基于无监督蒸馏网络的图像分类方法。
实施方式十二、本实施方式提供了计算机,包括处理器和储存介质,所述储存介质用于储存计算机程序,当所述储存介质中储存的计算机程序被所述处理器读取时,所述计算机执行实施方式一至五任意一项提供的基于无监督蒸馏网络的学生网络获取方法或实施方式七提供的基于无监督蒸馏网络的图像分类模型获取方法或实施方式九提供的基于无监督蒸馏网络的图像分类方法。
实施方式十三,本实施方式是对实施方式九提供的基于无监督蒸馏网络的图像分类方法提供具体的实验过程,用于验证其优点和有益之处,同时用于解释上述实施方式,具体的:
在CIFAR-10和CIFAR-100上训练ResNet-18,并用ResNet-50作为教师网络。使用SGD优化器和k=200最近邻(kNN)分类器训练模型500次。基础学习率为lr=0.2*batchsize/256。学习率遵循余弦衰减时间表。此外,使用SGD优化器每次迭代训练128幅图像,使用SGD优化器每次迭代训练128幅样本并进行100次微调的线性分类。CIFAR-100的初始学习率为30,CIFAR-10的初始学习率为3,并在第60和80次迭代中以0.1的比例因子递减。其中的CIFAR-10、CIFAR-100是数据集的名字。ResNet-18和ResNet-50是神经网络的名字。
实验一:在CIFAR-10/CIFAR-100数据集上进行测试
表1展示了实施方式九提供的基于无监督蒸馏网络的图像分类方法(简称ADCL)在不同批量大小下的KNN分类性能。我们按照提到的设置进行实验。ADCL在CIFAR-100上batchsize=256的实验是特别的,因为学习率被修改为lr=0.15*batchsize/256。其中,batchsize的意思是网络处理图像的批次大小,也就一次同时处理256张图像。
表1ADCL与其他SSL方法得性能对比
在kNN分类器下,ADCL表现出比SimCLR、DCL和SimSiam更好的性能。与CIFAR-10上的SimCLR相比,准确率提高了11.0%。虽然SimSiam缩小了差距,但仍有1.6%的提升。在CIFAR-100上,性能差距更加明显。与SimCLR相比,ADCL实现了14.1%的性能提升。与DCL和SimSiam相比,这个数字将降低到11%和3.8%,但精度提升仍然相当可观。
线性分类的实验如表2所示,ADCL在两个数据集上的表现都优于SimSiam基线。实验结果表明,CIFAR-10和CIFAR-100的性能分别得到了1.3%和3.1%的改善。
表2线性分类实验
实验二:在ImageNet数据集上进行测试
我们使用ADCL在ImageNet上训练ResNet-18,100个epoch(训练轮次),以ResNet-50为教师,基础lr=0.15*batchsize/256。学习率遵循余弦衰减时间表。此外,在训练中使用SGD优化器每次迭代使用256张图像,在线性分类中使用SGD优化器每次迭代使用256个样本,并进行100次微调。在线性分类中,初始学习率为30,并以余弦衰减的方式递减。结果如表3所示。为了进行公平的比较,离线训练的教师模型是ResNet-50,分类性能为67.4%。
表3ImageNet-1K结果
与采用离线教师进行蒸馏的方法相比,ADCL的性能明显优于ReKD,而ADCL没有使用在线数据队列产生正负样本。即使在ReKD中考虑到更深的教师网络,ADCL的性能也优于ReKD。此外,还考虑了基于离线教师模型的方法。显然,ADCL在所有条件下都比SEED表现出更好的精度。与DisCo相比,当在微调中只使用1%或10%的标签时,ADCL表现出显著的改进。尽管BINGO显示了最好的性能,但我们的方法的训练成本在所有方法中是最便宜的,因为没有维护数据队列,并且教师模型没有进行预训练。在边缘应用中,ADCL可以在一次训练中轻松地在新类型的数据上实现,并且内存开销不会成为训练的瓶颈。
实验三:对抗式损失项的消融实验
ADCL的损失函数由3个子项组成:表示学生模型的更新动态,和/>表示教师模型的更新动力源。毫无疑问,如果没有/>学生网络将是无意义的。因此,本实施方式中将研究/>和/>研究了在没有/>或/>的情况下训练模型的性能,而这两个损失子项是对抗性蒸馏学习的关键部分。结果如表4所示。
表4对抗式损失项的消融实验
如表4所示,其中的α和λ这是上面loss函数里的各个子项的缩放参数当λ=0时,只进行对比学习来优化教师网络,改进是明显的。在某些批量大小下,它甚至可以超过ADCL。此外,当α=0时,更新后的教师模型与学生模型不再相似,教师模型的表现也优于基线模型。尽管如此,在大多数情况下,与使用部分损失的方法相比,使用完整版损失函数的训练框架,即ADCL,实现了最好的性能。因此,对抗性损失项的维护有利于蒸馏训练。
实验四:学生对比损失的消融实验
由于ADCL引入了教师模式的对比学习,我们研究了是否可以将同样的方法应用于学生模型。因此,考虑包含ADCL损失和学生模型的自监督损失的组合损失。结果可以在表5中看到。很明显,加学生对比损失会对ADCL的性能造成负面影响。
表5加入学生对比损失的ADCL性能对比
以上通过几个具体实施方式对本发明提供的技术方案进行进一步详细地描述,是为了突出本发明的优点和有益之处,不过以上所述的几个具体实施方式仅仅用于解释本发明提供的技术方案,并不用于作为对本发明的限制,任何基于本发明的精神和原则内的,对本发明的合理更改和改进、实施方式的合理组合和替换等,均应当包含在本发明的保护范围之内。
Claims (12)
1.基于无监督蒸馏网络的学生网络获取方法,应用于图像分类,其特征在于,所述方法包括:
步骤1:采集图像作为数据集;
步骤2:根据所述数据集,得到两个增广;
步骤3:所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
步骤4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
步骤5:输出当前学生网络,作为结果。
2.根据权利要求1所述的基于无监督蒸馏网络的学生网络获取方法,其特征在于,所述教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值的获取方法具体为:根据MLP头获取。
3.根据权利要求2所述的基于无监督蒸馏网络的学生网络获取方法,其特征在于,所述的步骤3中,MLP头根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值生成子项,通过所述子项求解预设的损失函数。
4.根据权利要求3所述的无监督蒸馏学生网络获取方法,其特征在于,所述MLP头包括学生网络的投影头和预测头以及教师网络的投影头和预测头,所述MLP头根据所述教师网络和学生网络生成分别对应两个所述增广的、教师网络的两个投影值和两个预测值,以及分别对应两个所述增广的、学生网络的两个投影值和两个预测值,四个投影值和四个预测值作为所述子项。
5.根据权利要求1所述的无监督蒸馏学生网络获取方法,其特征在于,所述步骤4中,更新所述教师网络和学生网络的具体方式为,根据教师网络输出的投影值和预测值,得到教师网络的自更新,根据教师网络输出的投影值和学生网络的预测值,得到蒸馏更新,根据教师网络输出的预测值和学生网络的投影值,得到对抗更新,根据所述自更新、蒸馏更新和对抗更新,更新所述教师网络和学生网络。
6.无监督蒸馏学生网络获取装置,其特征在于,所述装置包括:
模块1:用于采集图像作为数据集;
模块2:用于根据所述数据集,得到两个增广;
模块3:用于将所述两个增广分别通过教师网络和学生网络,得到教师网络输出的投影值和预测值,以及学生网络输出的投影值和预测值。
模块4:根据教师网络输出的投影值和预测值以及学生网络输出的投影值和预测值,更新所述教师网络和学生网络;
模块5:用于输出当前学生网络,作为结果。
7.基于无监督蒸馏网络的图像分类模型获取方法,应用于图像分类,其特征在于,所述方法包括:
步骤6:采集待分类图像;
步骤7:将所述图像通过权利要求1-5任意一项所述的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
步骤8:通过预设训练图像集,训练预设的分类网络;
步骤9:得到训练后的学生网络和分类网络,并输出作为结果。
8.基于无监督蒸馏网络的图像分类模型获取装置,其特征在于,所述装置包括:
模块6:用于采集图像;
模块7:用于将所述图像通过权利要求6所述的基于无监督蒸馏网络的学生网络获取方法,得到更新后的学生网络;
模块8:用于通过预设训练图像集,训练预设的分类网络;
模块9:用于得到训练后的学生网络和分类网络。
9.基于无监督蒸馏网络的图像分类方法,其特征在于,所述方法包括:
步骤10:采集待分类的图像;
步骤11:将所述待分类的图像通过权利要求7所述的基于无监督蒸馏网络的图像分类模型获取方法输出的结果进行分类。
10.基于无监督蒸馏网络的图像分类装置,其特征在于,所述装置包括:
模块10:用于采集待分类的图像;
模块11:用于将所述待分类的图像通过权利要求8所述的基于无监督蒸馏网络的图像分类模型获取装置输出的结果进行分类。
11.计算机储存介质,用于储存计算机程序,其特征在于,当所述储存介质中储存的计算机程序被计算机的处理器读取时,所述计算机执行权利要求1-5任意一项所述的基于无监督蒸馏网络的学生网络获取方法或权利要求7所述的基于无监督蒸馏网络的图像分类模型获取方法或权利要求9所述的基于无监督蒸馏网络的图像分类方法。
12.计算机,包括处理器和储存介质,所述储存介质用于储存计算机程序,其特征在于,当所述储存介质中储存的计算机程序被所述处理器读取时,所述计算机执行权利要求1-5任意一项所述的基于无监督蒸馏网络的学生网络获取方法或权利要求7所述的基于无监督蒸馏网络的图像分类模型获取方法或权利要求9所述的基于无监督蒸馏网络的图像分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211439778.3A CN115861765A (zh) | 2022-11-17 | 2022-11-17 | 基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211439778.3A CN115861765A (zh) | 2022-11-17 | 2022-11-17 | 基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115861765A true CN115861765A (zh) | 2023-03-28 |
Family
ID=85663854
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211439778.3A Pending CN115861765A (zh) | 2022-11-17 | 2022-11-17 | 基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115861765A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117058556A (zh) * | 2023-07-04 | 2023-11-14 | 南京航空航天大学 | 基于自监督蒸馏的边缘引导sar图像舰船检测方法 |
-
2022
- 2022-11-17 CN CN202211439778.3A patent/CN115861765A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117058556A (zh) * | 2023-07-04 | 2023-11-14 | 南京航空航天大学 | 基于自监督蒸馏的边缘引导sar图像舰船检测方法 |
CN117058556B (zh) * | 2023-07-04 | 2024-03-22 | 南京航空航天大学 | 基于自监督蒸馏的边缘引导sar图像舰船检测方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Kim et al. | Attract, perturb, and explore: Learning a feature alignment network for semi-supervised domain adaptation | |
CN111552807B (zh) | 一种短文本多标签分类方法 | |
JP2022542639A (ja) | 生物学関連のデータを処理するための機械学習アルゴリズムをトレーニングするためのシステムおよび方法、顕微鏡ならびにトレーニングされた機械学習アルゴリズム | |
CN109977094B (zh) | 一种用于结构化数据的半监督学习的方法 | |
Shen et al. | A fast knowledge distillation framework for visual recognition | |
Yang et al. | Robust and non-negative collective matrix factorization for text-to-image transfer learning | |
Chen et al. | Two-stage label embedding via neural factorization machine for multi-label classification | |
CN108898181B (zh) | 一种图像分类模型的处理方法、装置及存储介质 | |
Atanov et al. | Semi-conditional normalizing flows for semi-supervised learning | |
Granger et al. | Joint progressive knowledge distillation and unsupervised domain adaptation | |
Zhou et al. | Binary Linear Compression for Multi-label Classification. | |
CN113159072B (zh) | 基于一致正则化的在线超限学习机目标识别方法及系统 | |
Yakar et al. | Bilevel Sparse Models for Polyphonic Music Transcription. | |
CN115861765A (zh) | 基于无监督蒸馏网络的学生网络获取方法、图像分类模型获取方法、图像分类方法 | |
Wu et al. | A coarse-to-fine framework for resource efficient video recognition | |
CN113705222B (zh) | 槽识别模型训练方法及装置和槽填充方法及装置 | |
Zhu et al. | Multiview latent space learning with progressively fine-tuned deep features for unsupervised domain adaptation | |
Ranjan et al. | A sub-sequence based approach to protein function prediction via multi-attention based multi-aspect network | |
JP2007115245A (ja) | データの大域的構造を考慮する学習機械 | |
EP1891543A2 (en) | Cross descriptor learning system, method and program product therefor | |
Wang et al. | Q-YOLO: Efficient inference for real-time object detection | |
SATI | A novel semisupervised classification method via membership and polyhedral conic functions | |
Nasfi et al. | A novel feature selection method using generalized inverted Dirichlet-based HMMs for image categorization | |
Duan et al. | Bayesian deep embedding topic meta-learner | |
Rana et al. | Multi-task semisupervised adversarial autoencoding for speech emotion |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |