CN117494780A - 一种混合学习中知识蒸馏的学生网络训练方法 - Google Patents
一种混合学习中知识蒸馏的学生网络训练方法 Download PDFInfo
- Publication number
- CN117494780A CN117494780A CN202311105587.8A CN202311105587A CN117494780A CN 117494780 A CN117494780 A CN 117494780A CN 202311105587 A CN202311105587 A CN 202311105587A CN 117494780 A CN117494780 A CN 117494780A
- Authority
- CN
- China
- Prior art keywords
- network
- logic
- student
- teacher
- student network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 55
- 238000012549 training Methods 0.000 title claims abstract description 40
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 31
- 230000006870 function Effects 0.000 claims abstract description 35
- 238000012545 processing Methods 0.000 claims abstract description 10
- 230000008569 process Effects 0.000 claims description 12
- 238000004821 distillation Methods 0.000 claims description 11
- 238000002759 z-score normalization Methods 0.000 claims description 10
- 238000004590 computer program Methods 0.000 claims description 8
- 238000001514 detection method Methods 0.000 claims description 6
- 238000005457 optimization Methods 0.000 claims description 5
- 238000013135 deep learning Methods 0.000 abstract description 7
- 239000013598 vector Substances 0.000 description 13
- 238000010606 normalization Methods 0.000 description 10
- 238000007781 pre-processing Methods 0.000 description 10
- 238000013459 approach Methods 0.000 description 6
- 238000009795 derivation Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 6
- 230000006872 improvement Effects 0.000 description 4
- 101100153586 Caenorhabditis elegans top-1 gene Proteins 0.000 description 3
- 101100370075 Mus musculus Top1 gene Proteins 0.000 description 3
- 230000008901 benefit Effects 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000014509 gene expression Effects 0.000 description 3
- 238000010200 validation analysis Methods 0.000 description 3
- 238000002679 ablation Methods 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 101100153581 Bacillus anthracis topX gene Proteins 0.000 description 1
- 101150041570 TOP1 gene Proteins 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 230000003042 antagnostic effect Effects 0.000 description 1
- 230000007812 deficiency Effects 0.000 description 1
- 238000006073 displacement reaction Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 230000003278 mimic effect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000008092 positive effect Effects 0.000 description 1
- 238000002203 pretreatment Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 239000000758 substrate Substances 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000007704 transition Effects 0.000 description 1
- 238000012800 visualization Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种混合学习中知识蒸馏的学生网络训练方法,其步骤包括:1)在训练数据上选取目标领域的训练样本;2)将训练样本预处理后分别输入学生网络、教师网络,获得相应的学生网络logit、教师网络logit;3)将每个学生网络logit、教师网络logit分别进行Z‑score标准化处理;4)将Z‑score标准化后的教师网络logit、学生网络logit转化为概率形式;5)任选一教师网络logit对应的概率和学生网络logit对应的概率,并计算所选两概率之间的KL散度作为损失函数,进行梯度下降优化蒸馏学生网络。本发明解决了深度学习知识蒸馏算法中教师网络和学生网络之间能力鸿沟问题。
Description
技术领域
本发明属于计算机软件技术领域,涉及一种混合学习中知识蒸馏的学生网络训练方法。
背景技术
知识蒸馏是一种在不牺牲太多精度的情况下减小模型尺寸的新方法,它涉及到将预先训练好的重型模型(称为教师网络)的知识转移到小型目标模型(称为学生网络)。
Hinton等人首先提出通过最小化教师网络与学生网络之间预测的Kullback-Leibler(KL)散度,将教师网络的知识蒸馏给学生网络,他们在softmax函数中引入一个称为温度T的比例因子来缓和预测概率。此后,温度作为超参数预先全局设置,并在整个训练过程中固定,直到Li等人提出的CTKD采用对抗性学习模块来预测逐样本不同的温度,以更好地适应样本间差异。但是,对于温度的标度因子从何而来以及能否确定其值,仍缺乏理论分析。
知识蒸馏旨在将“暗”知识从重型的教师模型转移到轻量级的学生模型。通过学习教师的软标签,学生网络可以获得比只在硬标签上训练更好的成绩。传统的方法是通过最小化学生网络预测概率与教师网络预测概率之间的差异(即KL散度)来训练学生。预测概率通常用logit输出的softmax函数来近似。
在基于逻辑的方法中,引入温度来平坦化概率。一些工作探讨了它的性质和效果。他们得出了相同的结论,即温度控制着学生网络对那些比平均水平更负向的logit的关注程度。非常低的温度使学生网络忽略其他logit,而主要关注教师网络的最大logit。然而,他们没有讨论为什么教师网络和学生网络共享一个全局预定义的温度,且目前尚不清楚是否可以在样本级别上确定温度。Li等人提出的CTKD首先提出利用对抗学习来预测样本温度。然而,没有提出为什么样本温度有效的理论分析。此外,教师网络和学生网络是否会有不同的温度也未被发现。ATKD设计了一个锐度度量,并通过减小教师网络和学生网络之间的锐度差距来选择自适应温度。然而,他们忽略了对教师网络和学生网络不同温度的理论支持。此外,他们的零logit平均值假设依赖于数值近似,限制了其性能。
发明内容
针对现有技术中存在的问题,本发明的目的在于提供一种混合学习中知识蒸馏的学生网络训练方法。本专利旨在利用逻辑单元logit标准化预处理以信息论中熵最大化的角度解决深度学习知识蒸馏算法中教师网络和学生网络之间能力鸿沟问题。本专利证明教师网络和学生网络之间共享温度的假设会产生副作用,即轻量级学生网络被迫生成与教师网络相同范围和方差的逻辑单元logit(http://cjc.ict.ac.cn/online/onlinepaper/HZH315.pdf),这可能会限制学生网络的表现。为了解决这个问题,本发明建议将温度设置为logit的加权标准差,并在使用softmax函数之前执行Z-score预处理。在理想情况下,教师网络和学生网络的温度之比等于其logit的标准差之比。该预处理过程易于定义和实现,可以提高现有的涉及温度的基于逻辑的知识蒸馏方法的性能。采用本发明的预处理技术的vanilla知识蒸馏可以获得较好的性能,而其他蒸馏变体在本发明的预处理技术的帮助下可以获得可观的收益。
本专利首先完成了从信息论角度对深度学习图像分类任务中涉及到的softmax函数推导,分类中的softmax函数可以被证明是受概率归一化条件和信息论中状态期望约束的熵最大化的唯一解。
本专利在上述基础上完成了深度学习知识蒸馏中涉及到的softmax函数推导,得到了教师网络和学生网络可以拥有不同温度的结论,且对不同样本也允许采用不同的样本温度。
本专利基于以上证明,提出深度学习知识蒸馏任务中的logit标准化。具体操作过程如下:
1.首先在训练数据(如CIFAR-100、ImageNet等)上选取用于图像分类等识别任务的训练样本;
2.将训练样本预处理后通过学生网络获得相应的学生网络logit,同样通过教师网络获得相应的教师网络logit;
3.将每个logit进行Z-score标准化处理,即将logit向量减去其均值后除以其标准差,对教师网络logit和学生网络logit分别都进行此操作,此操作等同于将学生网络和教师网络的温度设置为其logit标准差的倍数;
4.最后使用softmax函数将标准化后的对教师网络logit和学生网络logit转化为概率形式,计算两个概率之间的KL散度作为损失函数,进行梯度下降,完成此步的学生网络蒸馏,相较于传统蒸馏算法本方法将学生网络和教师网络的温度自适应地设置为其logit标准差倍数,弥补了两个网络规模带来的logit规模差距,解决了学生教师网络能力差距过大引起的蒸馏效率下降的问题。
上述方案是以图像分类为例进行描述的,但其同样适用于如文本分类等分类任务、目标检测等检测任务。
本发明的技术方案为:
一种混合学习中知识蒸馏的学生网络训练方法,其步骤包括:
1)在训练数据上选取目标领域的训练样本;
2)将所述训练样本预处理后分别输入学生网络、教师网络,获得相应的学生网络logit、教师网络logit;
3)将每个学生网络logit、教师网络logit分别进行Z-score标准化处理,获得Z-score标准化后的学生网络logit、Z-score标准化后的教师网络logit;
4)将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式;
5)任选一教师网络logit对应的概率和学生网络logit对应的概率,并计算所选两概率之间的KL散度作为损失函数,进行梯度下降优化蒸馏所述学生网络。
进一步的,对学生网络logit、教师网络logit分别进行Z-score标准化处理时,设置学生网络的温度不等于教师网络的温度。
进一步的,对学生网络logit、教师网络logit分别进行Z-score标准化处理的具体方法为:对于输入的logit,计算其均值以及标准差,然后计算该logit与其均值的差值,然后以该差值除以该logit的标准差,得到Z-score标准化后的logit;其中,所述输入的logit为学生网络logit或教师网络logit;学生网络的温度bS与教师网络的温度bT之比等于学生网络logit的标准差σ(zn)与教师网络logit的标准差σ(vn)之比,即
进一步的,使用softmax函数将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式。
进一步的,所述目标领域为文本分类领域,所述学生网络为文本分类任务模型。
进一步的,所述目标领域为目标检测领域,所述学生网络为目标检测模型。
一种图像分类识别方法,其步骤包括:
1)在训练数据上选取用于图像分类的训练样本;
2)将所述训练样本预处理后分别输入学生网络、教师网络,获得相应的学生网络logit、教师网络logit;
3)将每个学生网络logit、教师网络logit分别进行Z-score标准化处理,获得Z-score标准化后的学生网络logit、Z-score标准化后的教师网络logit;
4)将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式;
5)任选一教师网络logit对应的概率和学生网络logit对应的概率,并计算所选两概率之间的KL散度作为损失函数,进行梯度下降优化蒸馏所述学生网络;
6)对于一待识别的图像数据,将其输入步骤5)训练所得学生网络,得到该图像数据的类别。
一种服务器,其特征在于,包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行上述方法中各步骤的指令。
一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现上述方法的步骤。
与现有技术相比,本发明的积极效果为:
本专利通过对学生网络logit和教师网络logit的Z-score标准化预处理,目的是消除教师网络和学生网络之间logit偏移和方差匹配,打破这两个限制学生网络预测logit的桎梏。同时,本专利使用的Z-score标准化至少有四个优点,即零均值、有限标准差、单调性和有界性。
图1是展示传统知识蒸馏与所提知识蒸馏算法的模式对比示意图,图1(a)中传统知识蒸馏迫使学生网络和教师网络之间的logit差距保持固定,易知,在分类任务中只有预测logit的大小顺序重要,而其logit的大小尺度并不重要,因而此固定大小的差距构成了一个不必要的副作用。考虑到学生网络和教师网络之间的能力差距,一个轻量级的学生网络很难产生与庞大的教师网络相同规模的logit。与之相比,图1(b)的基于Z-score的logit标准化预处理减轻了该副作用。学生网络的标准化logit与教师网络的logit量级相当,而其原始logit则可以位于任意量级。由于标准化操作的单调性,学生网络可以有效地从教师网络学习logit顺序,而不必被迫严格遵从其规模。
附图说明
图1为展示传统知识蒸馏与本发明所提知识蒸馏算法的模式对比示意图;
(a)常规知识蒸馏算法示意;(b)本专利提出的知识蒸馏算法示意。
图2为本专利提出的结合Z-score标准化logit预处理的深度学习知识蒸馏算法流程图。
图3为Z-score标准化预处理过程。
图4为所提知识蒸馏算法分析与推导过程的思维导图。
具体实施方式
下面结合附图对本发明进行进一步详细描述,所举实例只用于解释本发明,并非用于限定本发明的范围。
本专利提出的结合Z-score标准化logit预处理的深度学习知识蒸馏算法流程图如图2所示。首先在训练数据(如CIFAR-100、ImageNet等)上选取训练样本;将训练样本预处理后通过学生网络获得相应的学生网络logit,同样通过教师网络获得相应的教师网络logit;将每个logit进行Z-score标准化处理,即将logit向量减去其均值后除以其标准差,对教师网络logit和学生网络logit分别都进行此操作;最后使用softmax函数将标准化后的对教师网络logit和学生网络logit转化为概率形式,计算两个概率之间的KL散度作为损失函数,进行梯度下降,完成此步的学生网络训练。
Z-score标准化预处理过程如图3所示,已知一个logit输入向量,计算其均值以及标准差,先获得logit和其均值的差值,然后以差值除以标准差后获得Z-score标准化后的logit输出。
分类中softmax的推导:分类中的softmax函数可以证明是受概率归一化条件和信息论中状态期望约束的熵最大化的唯一解。该推导在置信度校准中也被利用来制定温度标度。涉及softmax函数的导出函数是玻尔兹曼分布。假设本发明对训练集和教师网络有如下的约束熵最大化优化,本发明有如下目标
其中,n为训练样本的索引,k为类别的索引,是第n个训练样本对应的教师网络输出的逻辑单元(logit)中第k个类所对应的元素,q(vn)(k)则是/>经过softmax函数之后概率分布中的对应数值。
第一个约束条件是离散概率密度的要求。第二个约束控制分布的范围,使教师网络能够准确地预测目标类。如果本发明将硬标签y看作一个极陡概率分布其值除目标指标/>外均为零,则第二个约束可以重写为/>这相当于使教师网络模拟硬概率/>
通过应用拉格朗日乘子α1,1,…,α1,N和α2,1,…,α2,N,它给出
对α1,N和α2,N求偏导,得到约束条件。相反,对q(vn)(k)求导得到
让导数等于0得到解
其中称为配分函数,以满足归一化条件。注意,α2,n没有明确的表达式,因此它的选择可以手工定义。若设α2,n=1/T,则公式7符合KD的常规表达式。当α2,n=1时,公式为分类中常用的传统softmax函数。
KD中softmax函数的推导:根据这一思想,本发明可以定义一个熵最大化问题来表述KD中的softmax函数。对于一个迁移集给定一个训练有素的教师网络及其预测q(vn),则预测学生网络的目标函数为
其中n为训练样本的索引,k为类别的索引,是第n个训练样本对应的学生网络输出的逻辑单元(logit)中第k个类所对应的元素,q(zn)(k)则是/>经过softmax函数之后概率分布中的对应数值。
通过应用拉格朗日乘子β1,n,β2,n和β3,n,
对q(zn)(k)求导得到
为简单起见,假设βn=β2,n+β3,n,得到
其中由于概率密度的归一化条件成立。
式10中的公式与式7结构相同。
1)不同的温度。注意,LT对α1,n和α2,n的偏导数分别指向公式5中的两个约束,它们与α1,n和α2,n无关。类似的情况也适用于公式8。因此,不能给出它们的显式表达式,而可以手动定义它们的值。如果设βn=α2,n=1/T,则式7和式10转化为涉及学生网络和教师网络共同温度的KD表达式。相比之下,可以选择βn≠α2,n,这意味着学生网络和教师网络可以拥有不同的温度。
2)每个样本温度。通常为所有样本定义一个全局温度。即对于任何n,α2,n和βn被定义为一个常量值。相反,由于缺乏对它们的限制,它们可能在不同的样本中有所不同。选择一个全局常数作为温度缺乏依据。因此,允许采用不同的采样温度。
Logit标准化:如何为教师网络和学生网络选择不同的样本温度仍然是未知的。本发明首先通过引入两个超参数aS和bS,bS可以看作1/βn,bT可以看作1/α2,n,将公式10中的softmax函数改写为一般公式,
其中aS可以消去并且不违反等式。当aS=0,bS=1/βn时,得到公式10中的特殊情况。通过引入aT和bT,可以得到教师网络情况下的类似方程。在KD任务中,本发明的目标是为学生网络和教师网络的超参数选择合适的值。
对于一个最终得到充分蒸馏的学生网络,本发明假设KL散度损失达到最小,并且其预测的概率密度与教师网络的概率密度相匹配,即q(zn;aS,bS)(k)=q(vn;aT,bT)(k)。则对于任意一对索引i,j∈[1,K],很容易得到
通过对j从1到K求和,本发明得到
其中,和/>分别为学生网络和教师网络logit向量的均值,即/>(类似,省略)。通过对公式11对i从1到K的平方求和,本发明可以得到
其中σ是输入向量标准差的函数。从公式11和12中,本发明可以根据logit移位和方差匹配来描述一个受过良好训练的学生网络的两个属性。
1)logit的偏移。由式11可知,在传统的共享温度(bS=bT)设置下,学生网络和教师网络在任意指标上的logit之间存在恒定的位移,即
其中,可以认为是第n个样本的常数。这意味着在传统的KD方法中,学生网络被迫严格模仿教师网络的转变逻辑。举一个极端的例子,/>学生网络对应的logit之差必须是/>考虑到模型大小和容量的差距,学生网络可能无法像教师网络那样制作出广泛的logit范围。相反,当学生网络的logit排名与教师网络匹配时,即给定对教师网络输出的logit进行排序的指标t1,…,tK∈[1,K]使得/>则/>成立。
logit顺序是一门既能使学生网络预测又能使教师网络预测的基本知识。因此,这种逻辑转换是传统KD的副作用,束缚学生网络生成困难但不必要的结果。
2)方差匹配。从公式12中,本发明得出结论,学生网络和教师网络的温度之比等于他们预测logit的标准差之比,即:
在vanilla KD的温度共享设置中,学生网络被迫预测logit,使σ(zn)=σ(vn)。
这是另一个限制学生网络预测logit标准差的桎梏。相反,由于超参数来自拉格朗日乘子且可灵活调谐,本发明定义和/>这样,公式14中的等式总是成立的。
因此,为了打破这两个束缚,本发明建议将超参数aS,bS,aT,bT分别设置为其logit的均值和加权标准差。
其中为简洁起见,形成学生网络输出的logit的加权Z-score标准化预处理。教师网络输出的logit的情况与此类似,略去。在教师模型和学生模型中引入并共享一个基础温度τ。
算法1总结了本发明提出的对logit进行Z-score标准化的知识蒸馏流程,以下再详尽描述其过程及含义:
1.输入训练样本xn至学生网络和教师网络,获得其对应的logit,其中fT和fS表示教师网络和学生网络。
vn=fT(xn),zn=fS(xn)
2.对logit计算其平均值,将学生网络logit的平均值赋值至aS,将教师网络logit的平均值赋值给aT,其中上横线表示平均值。
3.对logit计算其标准差,将学生网络logit的标准差赋值至bS,将教师网络logit的标准差赋值给bT,其中σ表示标准差。
4.通过softmax计算学生网络logit和教师网络logit对应的概率输出,其中τ为基础温度。
q(vn)=softmax[(vn-aT)/bT/τ],q(zn)=softmax[(zn-aS)/bS/τ]
5.通过计算知识蒸馏相关的损失函数完成训练,其中λKD为损失函数的权重。
Z-score标准化至少有四个优点,即零均值、有限标准差、单调性和有界性。
1)零均值。标准化向量的均值可以证明为零。Z-score函数本质上使平均值为零。
2)有限标准差。加权Z-score输出的标准差可以表示为1/τ。该属性使标准化的学生和教师logit映射到一个相同的高斯样分布,其平均值为零,标准差为确定值。映射是多对一的,这意味着它的反向是不确定的。因此,原始学生logit向量zn的方差和取值范围不受限制。
3)单调性。很容易证明Z-score是一个线性变换函数,因此属于单调函数。这个属性确保转换后的学生logit与原始logit保持相同的排名,即给定对原始logit排序的索引t1,…,tK∈[1,K],然后/>教师的隐性知识以logit顺序的形式被保留并传递给学生。
4)有界性。标准化的logit可以在范围内表示。令/> 对于任意指标t0,则有
此属性确保了标准化logit的值范围是有界的。与传统的KD相比,它可以控制logit的范围,避免指数值过大。为此,本发明定义了一个基本温度来控制范围。
本发明的效果分析如下:
本发明用三种基于logit的蒸馏方法来评估本发明的预处理过程。如表1所示,在应用本发明的Z-score标准化预处理后,vanilla KD达到了与最先进的基于特征的方法相当的性能。作为最先进的基于logit的方法,DKD也可以通过本发明提出的预处理进一步提高。CTKD是一种基于对抗学习确定样本温度的KD方法。本发明将它与本发明的标准化结合起来,利用它来预测算法1中的基本温度。如表2所示,CTKD蒸馏的学生模型受益于本发明的预处理。CTKD可以持续改进KD,并且在CTKD的基础上,本发明的方法进一步显著提高了性能。
不同方法在ImageNet上top1和top 5精度的比较结果如表3所示。本发明的预处理也可以在大规模数据集上对所有三种基于逻辑的方法实现一致的改进。
本发明在基底温度和KD损失的权重λKD的不同配置方面进行了广泛的消融研究。基准温度为2时的部分结果显示在表4中。本发明可以看到随着KD损失权重的增加,softmax函数将原始logit向量作为输入的vanilla KD没有较好的性能增益。相比之下,本发明对Z-score的预处理可以实现明显的提升。
表1.CIFAR-100验证集上不同知识蒸馏方法的Top-1准确度(%)。此表为教师网络和学生网络具有相同体系结构的情况。整理方法按类型排序,即基于特征和基于logit的方法。我们将logit标准化应用于现有的基于logit的方法,并使用Δ来展示其性能提升。
表2.CIFAR-100验证集上不同知识蒸馏方法的Top-1准确度(%)。此表为教师网络和学生网络具有不同体系结构的情况。整理方法按类型排序,即基于特征和基于logit的方法。我们将logit标准化应用于现有的基于logit的方法,并使用Δ来展示其性能提升。
表3.不同知识蒸馏方法在ImageNet验证集上的Top-1和Top-5准确率。
/>
表4.Z-score不同设置下的消融研究,基本温度τ设置为2。默认情况下λCE=0.1。为了简洁起见,表示教师logit的vn和表示学生logit的zn向量缩写为z。教师网络为ResNet32×4,学生网络为ResNet8×4。
延申
Logit范围:本发明计算了教师网络和学生网络之间平均logit的差异程度。如果不应用本发明的预处理,在目标标签指数(7.5v.s.12)下,学生未能产生与教师一样大的logit。教师和学生logit之间的平均距离也达到0.27。logit范围的限制阻碍了学生进行正确的预测。相比之下,本发明的预处理打破了这种限制并使学生能够生成适当范围的logit。其标准化后的有效logit输出却与教师的结果吻合得很好。标准化logit的平均距离也缩小到0.18,这意味着学生更好地模仿了教师。
logit方差:传统知识蒸馏算法迫使学生logit的方差接近教师(3.78v.s.3.10)。然而,本发明的预处理打破了束缚,学生可以有灵活的logit方差(0.48v.s.3.10),而其标准化logit具有与教师相同的方差(均为0.99)。
特征可视化:本发明在t-SNE可视化了教师月学生网络的深度特征表示后,本发明的预处理提高了包括KD、CTKD和DKD在内的所有方法的特征可分离性和可判别性。
改进大型教师蒸馏:在表5中,本发明的预处理一致地提高了不同规模和容量的各种教师的蒸馏性能。本发明还可通过双变量直方图的方式计算了学生的模仿优度。被传统知识蒸馏的学生预测的logit均值和标准差明显偏离教师。相比之下,本发明的预处理使学生在标准化的logit均值和标准差方面与教师完美匹配。
表5.在CIFAR-100上对各种教师模型进行蒸馏的结果。学生模型是WRN-16-2。
尽管为说明目的公开了本发明的具体实施例,其目的在于帮助理解本发明的内容并据以实施,本领域的技术人员可以理解:在不脱离本发明及所附的权利要求的精神和范围内,各种替换、变化和修改都是可能的。因此,本发明不应局限于最佳实施例所公开的内容,本发明要求保护的范围以权利要求书界定的范围为准。
Claims (9)
1.一种混合学习中知识蒸馏的学生网络训练方法,其步骤包括:
1)在训练数据上选取目标领域的训练样本;
2)将所述训练样本预处理后分别输入学生网络、教师网络,获得相应的学生网络logit、教师网络logit;
3)将每个学生网络logit、教师网络logit分别进行Z-score标准化处理,获得Z-score标准化后的学生网络logit、Z-score标准化后的教师网络logit;
4)将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式;
5)任选一教师网络logit对应的概率和学生网络logit对应的概率,并计算所选两概率之间的KL散度作为损失函数,进行梯度下降优化蒸馏所述学生网络。
2.根据权利要求1所述的方法,其特征在于,对学生网络logit、教师网络logit分别进行Z-score标准化处理时,设置学生网络的温度不等于教师网络的温度。
3.根据权利要求2所述的方法,其特征在于,对学生网络logit、教师网络logit分别进行Z-score标准化处理的具体方法为:对于输入的logit,计算其均值以及标准差,然后计算该logit与其均值的差值,然后以该差值除以该logit的标准差,得到Z-score标准化后的logit;其中,所述输入的logit为学生网络logit或教师网络logit;学生网络的温度bS与教师网络的温度bT之比等于学生网络logit的标准差σ(zn)与教师网络logit的标准差σ(vn)之比,即
4.根据权利要求1或2或3所述的方法,其特征在于,使用softmax函数将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式。
5.根据权利要求1或2或3所述的方法,其特征在于,所述目标领域为文本分类领域,所述学生网络为文本分类任务模型。
6.根据权利要求1或2或3所述的方法,其特征在于,所述目标领域为目标检测领域,所述学生网络为目标检测模型。
7.一种图像分类识别方法,其步骤包括:
1)在训练数据上选取用于图像分类的训练样本;
2)将所述训练样本预处理后分别输入学生网络、教师网络,获得相应的学生网络logit、教师网络logit;
3)将每个学生网络logit、教师网络logit分别进行Z-score标准化处理,获得Z-score标准化后的学生网络logit、Z-score标准化后的教师网络logit;
4)将Z-score标准化后的教师网络logit、Z-score标准化后的学生网络logit转化为概率形式;
5)任选一教师网络logit对应的概率和学生网络logit对应的概率,并计算所选两概率之间的KL散度作为损失函数,进行梯度下降优化蒸馏所述学生网络;
6)对于一待识别的图像数据,将其输入步骤5)训练所得学生网络,得到该图像数据的类别。
8.一种服务器,其特征在于,包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行权利要求1至7任一所述方法中各步骤的指令。
9.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7任一所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311105587.8A CN117494780A (zh) | 2023-08-30 | 2023-08-30 | 一种混合学习中知识蒸馏的学生网络训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311105587.8A CN117494780A (zh) | 2023-08-30 | 2023-08-30 | 一种混合学习中知识蒸馏的学生网络训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117494780A true CN117494780A (zh) | 2024-02-02 |
Family
ID=89678804
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311105587.8A Pending CN117494780A (zh) | 2023-08-30 | 2023-08-30 | 一种混合学习中知识蒸馏的学生网络训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117494780A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117892841A (zh) * | 2024-03-14 | 2024-04-16 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
-
2023
- 2023-08-30 CN CN202311105587.8A patent/CN117494780A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117892841A (zh) * | 2024-03-14 | 2024-04-16 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
CN117892841B (zh) * | 2024-03-14 | 2024-05-31 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US10248664B1 (en) | Zero-shot sketch-based image retrieval techniques using neural networks for sketch-image recognition and retrieval | |
CN108960330B (zh) | 基于快速区域卷积神经网络的遥感图像语义生成方法 | |
WO2021143396A1 (zh) | 利用文本分类模型进行分类预测的方法及装置 | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
CN116134454A (zh) | 用于使用知识蒸馏训练神经网络模型的方法和系统 | |
Schreiter et al. | Efficient sparsification for Gaussian process regression | |
CN117494780A (zh) | 一种混合学习中知识蒸馏的学生网络训练方法 | |
CN112015868A (zh) | 基于知识图谱补全的问答方法 | |
CN111309878B (zh) | 检索式问答方法、模型训练方法、服务器及存储介质 | |
CN114186084B (zh) | 在线多模态哈希检索方法、系统、存储介质及设备 | |
CN112380421A (zh) | 简历的搜索方法、装置、电子设备及计算机存储介质 | |
CN112434134B (zh) | 搜索模型训练方法、装置、终端设备及存储介质 | |
CN113343125A (zh) | 一种面向学术精准推荐的异质科研信息集成方法及系统 | |
CN111611395B (zh) | 一种实体关系的识别方法及装置 | |
CN111079011A (zh) | 一种基于深度学习的信息推荐方法 | |
US11403339B2 (en) | Techniques for identifying color profiles for textual queries | |
CN117237727A (zh) | 基于生成对抗网络原型修正的少样本图像分类方法及系统 | |
CN112445899A (zh) | 一种基于神经网络的知识库问答中的属性匹配方法 | |
CN113407664A (zh) | 语义匹配方法、装置和介质 | |
CN113626537A (zh) | 一种面向知识图谱构建的实体关系抽取方法及系统 | |
Liu et al. | Novel Uncertainty Quantification through Perturbation-Assisted Sample Synthesis | |
Chen et al. | Sparse subnetwork inference for neural network epistemic uncertainty estimation with improved Hessian approximation | |
CN118113815B (zh) | 内容搜索方法、相关装置和介质 | |
CN113343666B (zh) | 评分的置信度的确定方法、装置、设备及存储介质 | |
KR20190093753A (ko) | 방위각 추정 장치 및 방법 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |