CN114782776A - 基于MoCo模型的多模块知识蒸馏方法 - Google Patents

基于MoCo模型的多模块知识蒸馏方法 Download PDF

Info

Publication number
CN114782776A
CN114782776A CN202210412270.8A CN202210412270A CN114782776A CN 114782776 A CN114782776 A CN 114782776A CN 202210412270 A CN202210412270 A CN 202210412270A CN 114782776 A CN114782776 A CN 114782776A
Authority
CN
China
Prior art keywords
network
module
teacher
student network
similarity
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210412270.8A
Other languages
English (en)
Other versions
CN114782776B (zh
Inventor
王军
袁静波
刘新旺
李玉莲
李兵
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
China University of Mining and Technology CUMT
Original Assignee
China University of Mining and Technology CUMT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by China University of Mining and Technology CUMT filed Critical China University of Mining and Technology CUMT
Priority to CN202210412270.8A priority Critical patent/CN114782776B/zh
Publication of CN114782776A publication Critical patent/CN114782776A/zh
Application granted granted Critical
Publication of CN114782776B publication Critical patent/CN114782776B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/28Determining representative reference patterns, e.g. by averaging or distorting; Generating dictionaries
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于MoCo模型的多模块知识蒸馏方法,利用中间过程中生成的特征间具有相似度这一特点,将教师和学生网络各自分成对应的多个模块,通过MoCo模型提取到教师和学生网络的每个模块生成的特征计算相似度,利用相似度达到教师网络指导学生网络的目的。本发明可以在只有少量标签的基础上,自动地对样本特征进行动态更新,此方法的内存效率更高,解决了在有限内存的情况下训练大规模数据集的问题,使教师网络指导下的学生网络有鲁棒性的同时,兼具泛化性。

Description

基于MoCo模型的多模块知识蒸馏方法
技术领域
本发明属于模型轻量化技术,尤其涉及一种基于MoCo模型的多模块知识蒸馏方法。
背景技术
近年来,机器学习和深度学习在计算机视觉、自然语言处理、预测和音频处理等方面都有了卓越的进步,对于这些复杂的任务,训练后模型的规模很大,这使得在资源受限的设备上部署它很困难。在知识蒸馏中,在大数据集上训练的较大的繁琐网络(教师模型)可以很好地将学习到的知识转移到作为一个学生模型的更小更轻的网络中。
在基于瘦长网络的提示的研究中,引入了一种两阶段的策略来训练深度网络,但是没有明显的速度提升;深度相互学习提出了教师-学生网络相互学习,并且同时更新,但是难以提取学习更细节的信息,带来的误差更大;再生网络中,提出了利用学习到的学生网络指导下一级的学生网络,但是训练时间长且冗余过程较多。
发明内容
本发明的目的在于提供一种基于MoCo模型的多模块知识蒸馏方法,解决了在有限内存的情况下训练大规模数据集的问题,达到了减少运算量提高内存效率的效果。
实现本发明目的的技术解决方案为:一种基于MoCo模型的多模块知识蒸馏方法,包括以下步骤:
步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2K幅带标签的图像,构成教师网络训练集。
步骤S2、将教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络。
步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2N幅无标签的图像,构成教师-学生网络训练集。
步骤S4、构建MoCo模型:
所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100。
步骤S5、将教师-学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度。用学生网络中第n+1个模块生成的相似度学习预训练教师网络第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数,1≤n≤m。同时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络。
步骤S6、在Instagram中随机采集M幅带标签的图像,100<M<1000,对上述M幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2M幅图像,构成学生网络测试集。
步骤S7、将学生网络测试集输入MoCo模型中训练好的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
本发明与现有技术相比,其显著优点在于:
(1)首次将Moco模型学习到的相似度用于知识蒸馏方法中,可以在只有少量标签的基础上,自动地对样本特征进行动态更新,使内存效率更高,并且没有匹配提取到的特征的步骤,减少中间数据转换的误差,使教师网络指导下的学生网络有鲁棒性的同时,兼具泛化性。
(2)利用Moco模型自身的特性,预训练教师网络和学生网络都能通过相似度对网络参数进行自我更新,学生网络不仅可以学习各模块的工作方式还可以回顾复习未被学习到的特征,通过增加更新策略的方式,提高了学生网络的准确度。
(3)在Moco模型中加入了池化层,为前期的训练提供可靠的数据,加速数据收敛,并且利用移动平均值的策略,使网络更好的更新,既保留了原数据又平稳添加新的梯度。
附图说明
图1为基于MoCo模型的多模块知识蒸馏方法模型图。
具体实施方式
下面结合附图对本发明作进一步详细描述。
结合图1,本发明所述的一种基于MoCo模型的多模块知识蒸馏方法,步骤如下:
步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅带标签的图像逐张统一尺寸后进行数据增强,得到像素大小为h×w(h取值范围为0~256,w取值范围为0~256)的2K幅带标签的图像,构成带标签的教师网络训练集,转入步骤S2。
步骤S2、将带标签的教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络,转入步骤S3。
步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2N幅无标签的图像,构成无标签的教师-学生网络训练集,转入步骤S4。
步骤S4、构建MoCo模型:
所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100。
所述预训练教师网络和学生网络均无分支,包括但不局限于经典网络结构中的ResNet、VGGNet、Mobilenet等。预训练教师网络规模数据均大于学生网络,转入步骤S5。
步骤S5、将无标签的教师-学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度;用学生网络中第n+1个模块生成的相似度学习预训练教师网络中的第n+1个模块生成的相似度和第n模块(预训练教师网络中)生成的相似度,以此更新学生网络的网络参数,1≤n≤m;同时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络,具体如下:
编码器和动态编码器采用相同结构,编码器承担了生成查询特征的任务;动态编码器基于无监督学习的对比损失构建具有一致性的字典,字典是以队列的形式表现出来的:
当前的特征经过动态编码器编码后得到的匹配样本特征进入队列,最先进入的一组匹配样本特征被清理出队列。
当前有编码器生成的一个查询样本特征q和动态编码器生成的一组序列{k0,k1,k2,…},序列作为字典中的键,序列中存在一个与q匹配的键k+;利用点积度量相似性,提出对比损失函数Lq
Figure BDA0003604451330000041
其中,τ是一个温度超参数,ki为字典中的键;字典中的键包括一个正样本k+和K个负样本,1<K<100;当q与键k+相似,而与所有其他键不同时,Lq的值趋近于0。
查询样本特征q由编码器fq和池化层产生,即q=fq(xq)+poolq(xq),xq表示任意一个查询样本;键ki由动态编码器fk和池化层产生,即ki=fk(xki)+poolki(xki),xki是字典中的键(即字典所需样本)。
此外,提出了一种缓慢进行的动态编码器更新方式,其动态是基于编码器的移动平均值来实现的,并以此与编码器保持一致性,将fk的参数表示为θk,fq的参数表示为θq,更新θk的公式为:
θk→ε(θk-t+θk-t+1+……+θk)/t+(1-ε)θq
其中,ε∈[0,1)是一个动量系数,t为移动平均数个数,0<t<100,只有参数θq才会通过反向传播进行更新。
在MoCo模型中,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度,具体如下:
上述相似度中包含的信息用于指导学生网络进行优化。
字典中的键包括一个正样本k+和K个负样本ks;查询样本特征与正样本产生正样本相似度lpos
lpos=bmm(q,k+)
其中,bmm是分批矩阵乘法函数。
查询样本与剩下K个负样本ks产生负样本相似度lneg
lneg=mm(q,ks)
其中,mm是矩阵乘法函数。
将得到的lpos和lneg拼接起来得到样本相似度logits:
logits=cat(lpos,lneg)
其中,cat是矩阵拼接函数;得到预训练教师网络和学生网络中各模块对应生成的相似度,利用无监督样本自动生成的标签labels与样本相似度logits求出标签损失函数Llabel
Llabel=CrossEntropyLoss(logits/τ,labels)
其中,CrossEntropyLoss可求出交叉熵。
在MoCo模型中,存在三个更新策略:用学生网络中第n+1个模块生成的相似度学习预训练教师网络第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数,1≤n≤m。
用学生网络中第1个模块生成的相似度学习预训练教师网络第1个模块生成的相似度,用学生网络中第2个模块生成的相似度学习预训练教师网络第2个模块生成的相似度和第1模块生成的相似度,用学生网络中第3个模块生成的相似度学习预训练教师网络第3个模块生成的相似度和第2模块生成的相似度,以此更新学生网络的网络参数。
预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,对应着三个损失函数:标签损失函数Llabel、教师-学生网络损失函数Lst1、回顾损失函数Lst2
将学生网络中第n+1个模块生成的相似度,向预训练教师网络中的第n+1个模块生成的相似度进行学习,以此更新学生网络的网络参数,具体如下:
利用预训练教师网络指导学生网络,即用预训练教师网络中第n+1个模块生成的相似度
Figure BDA0003604451330000051
与对应的学生网络中第n+1个模块生成的相似度
Figure BDA0003604451330000052
求出教师-学生网络损失函数Lst1
Figure BDA0003604451330000061
将学生网络中第n+1个模块生成的相似度
Figure BDA0003604451330000062
向预训练教师网络中第n个模块生成的相似度
Figure BDA0003604451330000063
进行学习,以此更新学生网络的网络参数,将回顾损失函数定义为Lst2
Figure BDA0003604451330000064
预训练教师网络根据损失函数Llabel进行更新迭代,而学生网络的损失函数包括三个部分:标签损失函数Llabel、教师-学生网络损失函数Lst1、回顾损失函数Lst2,则学生网络的损失函数L为:
L=αLlabel+βLst1+γLst2
其中,α,β,γ为损失函数L中的平衡系数;将教师-学生网络训练集中的所有图像分批次重复以上操作后,最终获得训练好的学生网络。
转入步骤S6。
步骤S6、在Instagram中随机采集M幅带标签的图像,100<M<1000,对上述M幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2M幅图像,构成学生网络测试集,转入步骤S7。
步骤S7、将学生网络测试集输入MoCo模型中训练好的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
实施例1
本发明所述的基于MoCo模型的多模块知识蒸馏方法,步骤如下:
步骤S1、在Imagenet中随机采集5000幅带标签的图像,对这5000幅图像逐张进行尺寸的统一后进行数据增强,得到像素大小为256×256的10000幅图像,构成教师网络训练集。
步骤S2、将教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络。
步骤S3、在Instagram中随机采集50000幅图像,对这50000幅图像逐张进行尺寸的统一后进行数据增强,得到像素大小为256×256的100000幅图像,构成教师-学生网络训练集。
步骤S4、构建多模块知识蒸馏的MoCo模型:
所述MoCo模型包含了预训练教师网络和学生网络,将预训练教师网络和学生网络各自划分成一一对应的3个模块,提取每个模块生成的特征并输入编码器和动态编码器即可求出对应的相似度。在构建MoCo模型时,提取预训练教师网络和学生网络中每个模块生成的特征并输入编码器和动态编码器,其中,编码器和动态编码器可以被认为是为字典查找任务而训练的:编码器承担了生成查询特征的任务。动态编码器基于无监督学习的对比损失构建了大型且具有一致性的字典,其中,字典是以队列的形式表现出来的:当前的特征经过编码后得到的匹配样本特征进入队列,最先进入的一组匹配样本特征被清理出队列,在此处,字典可容纳500个匹配样本特征。
步骤S5、将教师-学生网络训练集以每个批次为128的数量输入多模块知识蒸馏的MoCo模型,得到预训练教师网络和学生网络各个模块对应生成的相似度,将学生网络各个模块生成的相似度根据该模块的对应关系,对预训练教师网络中对应模块生成的相似度和对应模块前一个模块生成的相似度进行学习,以此更新学生网络的网络参数。同时,预训练教师网络和学生网络都根据各个模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络。
在MoCo模型中,存在三个更新策略:将学生网络各个模块生成的相似度根据该模块的对应关系,对预训练教师网络中对应模块生成的相似度和对应模块前一个模块生成的相似度进行学习,以此更新学生网络的网络参数。同时,预训练教师网络和学生网络都根据各个模块生成的相似度各自对网络参数进行更新,对应着三个损失函数。
步骤S6、在Instagram中随机采集500幅带标签的图像,对这500幅图像逐张进行尺寸的统一后进行数据增强,得到像素大小为256×256的1000幅图像,构成学生网络测试集。
步骤S7、将学生网络测试集输入多模块知识蒸馏的MoCo模型中训练好的的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。

Claims (5)

1.一种基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤如下:
步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2K幅带标签的图像,构成教师网络训练集,转入步骤S2;
步骤S2、将教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络,转入步骤S3;
步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2N幅无标签的图像,构成教师-学生网络训练集,转入步骤S4;
步骤S4、构建MoCo模型:
所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100;
转入步骤S5;
步骤S5、将教师-学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度;用学生网络中第n+1个模块生成的相似度学习预训练教师网络第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数,1≤n≤m;同时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络,转入步骤S6;
步骤S6、在Instagram中随机采集M幅带标签的图像,100<M<1000,对上述M幅图像逐张统一尺寸后进行数据增强,得到像素大小为h×w的2M幅图像,构成学生网络测试集,转入步骤S7;
步骤S7、将学生网络测试集输入MoCo模型中训练好的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
2.根据权利要求1所述的基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤S5中,在MoCo模型中,提取预训练教师网络和学生网络中各模块生成的特征并输入编码器和动态编码器,其中,编码器和动态编码器采用相同结构,编码器承担了生成查询特征的任务;动态编码器基于无监督学习的对比损失构建具有一致性的字典,字典是以队列的形式表现出来的:
当前特征经过动态编码器编码后得到的匹配样本特征进入队列,最先进入的一组匹配样本特征被清理出队列;
当前有编码器生成的一个查询样本特征q和动态编码器生成的一组序列{k0,k1,k2,…},序列作为字典中的键,序列中存在一个与q匹配的键k+;利用点积度量相似性,提出对比损失函数Lq
Figure FDA0003604451320000021
其中,τ是一个温度超参数,ki为字典中的键;字典中的键包括一个正样本k+和K个负样本,1<K<100;当q与键k+相似,而与所有其他键不同时,Lq的值趋近于0;
查询样本特征q由编码器fq和池化层产生,即q=fq(xq)+poolq(xq),xq表示任意一个查询样本;键ki由动态编码器fk和池化层产生,即ki=fk(xki)+poolki(xki),xki是字典中的键;
此外,提出了一种缓慢进行的动态编码器更新方式,其动态是基于编码器的移动平均值来实现的,并以此与编码器保持一致性,将fk的参数表示为θk,fq的参数表示为θq,更新θk的公式为:
θk→ε(θk-tk-t+1+……+θk)/t+(1-ε)θq
其中,动量系数ε∈[0,1),t为移动平均数的个数,0<t<100,只有参数θq才会通过反向传播进行更新。
3.根据权利要求2所述的基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤S5中,在MoCo模型中,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度,具体如下:
上述相似度中包含的信息用于指导学生网络进行优化;
字典中的键包括一个正样本k+和K个负样本ks;查询样本特征与正样本产生正样本相似度lpos
lpos=bmm(q,k+)
其中,bmm是分批矩阵乘法函数;
查询样本与剩下K个负样本ks产生负样本相似度lneg
lneg=mm(q,ks)
其中,mm是矩阵乘法函数;
将得到的lpos和lneg拼接起来得到样本相似度logits:
logits=cat(lpos,lneg)
其中,cat是矩阵拼接函数;得到预训练教师网络和学生网络中各模块对应生成的相似度,利用无监督样本自动生成的标签labels与样本相似度logits求出标签损失函数Llabel
Llabel=CrossEntropyLoss(logits/τ,labels)
其中,CrossEntropyLoss求出交叉熵。
4.根据权利要求3所述的基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤S5中,在MoCo模型中,用学生网络中第n+1个模块生成的相似度学习预训练教师网络中的第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数;具体如下:
用学生网络中第n+1个模块生成的相似度学习预训练教师网络中的第n+1个模块生成的相似度:
利用预训练教师网络指导学生网络,即用预训练教师网络中第n+1个模块生成的相似度
Figure FDA0003604451320000031
与对应的学生网络中第n+1个模块生成的相似度
Figure FDA0003604451320000032
求出教师-学生网络损失函数Lst1
Figure FDA0003604451320000033
n表示模块序号;
将学生网络中第n+1个模块生成的相似度
Figure FDA0003604451320000034
向预训练教师网络中第n个模块生成的相似度
Figure FDA0003604451320000035
进行学习,以此更新学生网络的网络参数,将回顾损失函数定义为Lst2
Figure FDA0003604451320000041
预训练教师网络根据损失函数Llabel进行更新迭代,而学生网络的损失函数包括三个部分:标签损失函数Llabel、教师-学生网络损失函数Lst1、回顾损失函数Lst2,则学生网络的损失函数L为:
L=αLlabel+βLst1+γLst2
其中,α,β,γ为损失函数L中的平衡系数;将教师-学生网络训练集中的所有图像分批次重复以上操作后,最终获得训练好的学生网络。
5.根据权利要求4所述的基于MoCo模型的多模块知识蒸馏方法,其特征在于,预训练教师网络和学生网络均无分支;预训练教师网络规模数据均大于学生网络。
CN202210412270.8A 2022-04-19 2022-04-19 基于MoCo模型的多模块知识蒸馏方法 Active CN114782776B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210412270.8A CN114782776B (zh) 2022-04-19 2022-04-19 基于MoCo模型的多模块知识蒸馏方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210412270.8A CN114782776B (zh) 2022-04-19 2022-04-19 基于MoCo模型的多模块知识蒸馏方法

Publications (2)

Publication Number Publication Date
CN114782776A true CN114782776A (zh) 2022-07-22
CN114782776B CN114782776B (zh) 2022-12-13

Family

ID=82431791

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210412270.8A Active CN114782776B (zh) 2022-04-19 2022-04-19 基于MoCo模型的多模块知识蒸馏方法

Country Status (1)

Country Link
CN (1) CN114782776B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN117253123A (zh) * 2023-08-11 2023-12-19 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190346522A1 (en) * 2018-05-10 2019-11-14 Siemens Healthcare Gmbh Method of reconstructing magnetic resonance image data
CA3076424A1 (en) * 2019-03-22 2020-09-22 Royal Bank Of Canada System and method for knowledge distillation between neural networks
US20210319266A1 (en) * 2020-04-13 2021-10-14 Google Llc Systems and methods for contrastive learning of visual representations
CN113610173A (zh) * 2021-08-13 2021-11-05 天津大学 一种基于知识蒸馏的多跨域少样本分类方法
CN113850012A (zh) * 2021-06-11 2021-12-28 腾讯科技(深圳)有限公司 数据处理模型生成方法、装置、介质及电子设备
CN113870845A (zh) * 2021-09-26 2021-12-31 平安科技(深圳)有限公司 语音识别模型训练方法、装置、设备及介质
CN114022697A (zh) * 2021-09-18 2022-02-08 华侨大学 基于多任务学习与知识蒸馏的车辆再辨识方法及系统
CN114091572A (zh) * 2021-10-26 2022-02-25 上海瑾盛通信科技有限公司 模型训练的方法、装置、数据处理系统及服务器
CN114328834A (zh) * 2021-12-29 2022-04-12 成都晓多科技有限公司 一种模型蒸馏方法、系统以及文本检索方法

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190346522A1 (en) * 2018-05-10 2019-11-14 Siemens Healthcare Gmbh Method of reconstructing magnetic resonance image data
CA3076424A1 (en) * 2019-03-22 2020-09-22 Royal Bank Of Canada System and method for knowledge distillation between neural networks
US20210319266A1 (en) * 2020-04-13 2021-10-14 Google Llc Systems and methods for contrastive learning of visual representations
CN113850012A (zh) * 2021-06-11 2021-12-28 腾讯科技(深圳)有限公司 数据处理模型生成方法、装置、介质及电子设备
CN113610173A (zh) * 2021-08-13 2021-11-05 天津大学 一种基于知识蒸馏的多跨域少样本分类方法
CN114022697A (zh) * 2021-09-18 2022-02-08 华侨大学 基于多任务学习与知识蒸馏的车辆再辨识方法及系统
CN113870845A (zh) * 2021-09-26 2021-12-31 平安科技(深圳)有限公司 语音识别模型训练方法、装置、设备及介质
CN114091572A (zh) * 2021-10-26 2022-02-25 上海瑾盛通信科技有限公司 模型训练的方法、装置、数据处理系统及服务器
CN114328834A (zh) * 2021-12-29 2022-04-12 成都晓多科技有限公司 一种模型蒸馏方法、系统以及文本检索方法

Non-Patent Citations (5)

* Cited by examiner, † Cited by third party
Title
HAOHANG XU 等: "BAG OF INSTANCES AGGREGATION BOOSTS SELF-SUPERVISED DISTILLATION", 《ICLR 2022》 *
JIALI DUAN 等: "SLADE: A Self-Training Framework For Distance Metric Learning", 《CVF》 *
ZEMING LI 等: "Momentum Teacher: Momentum Teacher with Momentum Statistics for Self-Supervised Learning", 《ARXIV》 *
田春娜 等: "自监督视频表征学习综述", 《西安电子科技大学学报》 *
陶超 等: "遥感影像智能解译:从监督学习到自监督学习", 《测绘学报》 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116486285B (zh) * 2023-03-15 2024-03-19 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN117253123A (zh) * 2023-08-11 2023-12-19 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法
CN117253123B (zh) * 2023-08-11 2024-05-17 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法

Also Published As

Publication number Publication date
CN114782776B (zh) 2022-12-13

Similar Documents

Publication Publication Date Title
CN114782776B (zh) 基于MoCo模型的多模块知识蒸馏方法
US20240177047A1 (en) Knowledge grap pre-training method based on structural context infor
CN108717574B (zh) 一种基于连词标记和强化学习的自然语言推理方法
CN111666406B (zh) 基于自注意力的单词和标签联合的短文本分类预测方法
CN113627093B (zh) 一种基于改进Unet网络的水下机构跨尺度流场特征预测方法
CN113988449A (zh) 基于Transformer模型的风电功率预测方法
CN112668719A (zh) 基于工程能力提升的知识图谱构建方法
CN111832637B (zh) 基于交替方向乘子法admm的分布式深度学习分类方法
CN113204633A (zh) 一种语义匹配蒸馏方法及装置
CN115687638A (zh) 基于三元组森林的实体关系联合抽取方法及系统
CN114239574A (zh) 一种基于实体和关系联合学习的矿工违规行为知识抽取方法
CN116932722A (zh) 一种基于跨模态数据融合的医学视觉问答方法及系统
CN107766895A (zh) 一种诱导式非负投影半监督数据分类方法及系统
CN112905750A (zh) 一种优化模型的生成方法和设备
CN116521887A (zh) 一种基于深度学习的知识图谱复杂问答系统及方法
CN116151335A (zh) 一种适用于嵌入式设备的脉冲神经网络轻量化方法及系统
CN116958700A (zh) 一种基于提示工程和对比学习的图像分类方法
CN114880527B (zh) 一种基于多预测任务的多模态知识图谱表示方法
CN116306653A (zh) 一种正则化领域知识辅助的命名实体识别方法
CN109919200B (zh) 一种基于张量分解和域适应的图像分类方法
CN116030257B (zh) 一种基于NesT模型的语义分割方法
CN112364654A (zh) 一种面向教育领域的实体和关系联合抽取方法
CN112417869A (zh) 一种产品模型描述对比方法及系统
Zhang et al. S 5 Mars: Semi-Supervised Learning for Mars Semantic Segmentation
CN113627073B (zh) 一种基于改进的Unet++网络的水下航行器流场结果预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant