CN117743858A - 一种基于知识增强的连续学习软标签构建方法 - Google Patents
一种基于知识增强的连续学习软标签构建方法 Download PDFInfo
- Publication number
- CN117743858A CN117743858A CN202410183536.5A CN202410183536A CN117743858A CN 117743858 A CN117743858 A CN 117743858A CN 202410183536 A CN202410183536 A CN 202410183536A CN 117743858 A CN117743858 A CN 117743858A
- Authority
- CN
- China
- Prior art keywords
- soft
- semantic
- knowledge
- gram matrix
- knowledge distillation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000010276 construction Methods 0.000 title claims abstract description 19
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 122
- 238000000034 method Methods 0.000 claims abstract description 107
- 239000011159 matrix material Substances 0.000 claims abstract description 96
- 230000006870 function Effects 0.000 claims abstract description 90
- 239000013598 vector Substances 0.000 claims abstract description 43
- 238000003062 neural network model Methods 0.000 claims abstract description 33
- 238000009499 grossing Methods 0.000 claims description 28
- 238000005457 optimization Methods 0.000 claims description 15
- 230000002787 reinforcement Effects 0.000 claims 9
- 230000000694 effects Effects 0.000 abstract description 4
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 230000009286 beneficial effect Effects 0.000 description 4
- 241000282472 Canis lupus familiaris Species 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 3
- 241000220225 Malus Species 0.000 description 3
- 235000021016 apples Nutrition 0.000 description 3
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000012423 maintenance Methods 0.000 description 2
- 241001465754 Metazoa Species 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 206010027175 memory impairment Diseases 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Image Analysis (AREA)
Abstract
本发明涉及人工智能技术领域,提供了一种基于知识增强的连续学习软标签构建方法,该方法包括:随机初始化语义软标签,计算语义Gram矩阵,通过语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;随机初始化知识蒸馏软标签,计算知识蒸馏Gram矩阵,通过知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;将上述两种损失函数结合,获得总损失函数;将所述总损失函数用于新任务的训练。本发明解决了在神经网络模型连续学习过程中缺乏旧任务数据的问题,避免了灾难性遗忘的效果。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于知识增强的连续学习软标签构建方法。
背景技术
尽管深度学习在分类和检测领域取得了显著成就,但大多数算法都是基于封闭环境中的固定类别。真实的场景是一个开放和动态的环境,总是有新的类别出现。当神经网络模型应用于实际任务时,需要在新的数据集上进行更新。如果我们直接对模型进行微调,那么前一项任务的准确性就会降低,这种情况被称为灾难性遗忘。直接联合训练将造成巨大的训练成本。持续学习就是为了解决这个问题。连续学习的目标是在不忘记旧任务的情况下学习新任务,它已被应用于许多领域。
目前,连续学习的方法已经取得了一些进展。但他们中的大多数人都专注于学习策略。在图像分类任务中,他们遵循多分类问题的默认配置,并使用一个基于softmax损失的热编码器。这些方法将神经网络模型输出与groundtruth的一次性编码相匹配,称为硬标签。但对于连续学习任务,多个任务按顺序出现,并且类别是逐步学习的。由于缺乏完整的先前数据,无法通过前一类和当前类之间的关联,而导致遗忘的问题。
发明内容
有鉴于此,本发明提供了一种基于知识增强的连续学习软标签构建方法,以解决现有技术中由于缺乏先前数据的完整性,而无法考虑前一类和当前类之间的关联,从而导致遗忘的技术问题。
本发明提供了一种基于知识增强的连续学习软标签构建方法,包括:
S1.随机初始化语义软标签,计算语义Gram矩阵,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;
以及随机初始化知识蒸馏软标签,计算知识蒸馏Gram矩阵,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;
S2.将所述优化后的语义软标签损失函数与所述优化后的知识蒸馏软标签损失函数结合,获得总损失函数;
S3.采用所述总损失函数进行新任务的训练。
进一步地,所述随机初始化语义软标签,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数,包括:
a1.随机初始化所述语义软标签,确定不同类别的语义软标签之间的相关性,得到所述语义Gram矩阵;
a2.采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,得到所述词向量Gram矩阵,其中,所述外部词向量模型为CLIP或Bert;
a3.计算语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离,获得中间过程语义软标签;
a4.采用softmax函数对所述中间过程语义软标签归一化,获得优化后的中间过程语义软标签,对于每个类别,采用该类别优化后的中间过程语义软标签平滑相应的原始硬标签,获得该类别平滑后的语义软标签;a5.基于所有类别平滑后的语义软标签,获得所述优化后的语义软标签损失函数。
进一步地,语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离的表达式如下:
其中,表示中间过程的语义损失函数,/>表示语义Gram矩阵,表示词向量Gram矩阵。
进一步地,优化后的中间过程语义软标签的表达式如下:
其中,表示所述优化后的中间过程语义软标签,k表示中间过程语义软标签,/>表示将中间过程语义软标签除以温度系数T,进行数学上的操作,q(x)表示硬标签。
进一步地,所述优化后的语义软标签损失函数的表达式如下:
其中,表示超参数,q(x)表示硬标签,p(x)表示相应类别神经网络模型的输出,表示硬标签和相应类别神经网络模型的输出的KL散度,表示优化后的中间过程语义软标签和相应类别神经网络模型输出的KL散度。
进一步地,所述随机初始化知识蒸馏软标签,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数,包括:
b1.随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,得到所述知识蒸馏Gram矩阵;
b2.将旧任务和新任务不同类别的聚类中心输入至旧神经网络模型,获取每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,得到所述嵌入Gram矩阵;
b3.计算所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离,获得中间过程知识蒸馏软标签;
b4.采用softmax函数对所述中间过程知识蒸馏软标签归一化,获得优化后的中间过程知识蒸馏软标签,对于每个类别,采用该类别优化后的中间过程知识蒸馏软标签平滑相应原始的硬标签,获得该类别平滑后的知识蒸馏软标签;
b5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
进一步地,所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离的表达式如下:
其中,表示中间过程的知识蒸馏损失函数,/>表示知识蒸馏Gram矩阵,/>表示嵌入Gram矩阵。
进一步地,所述优化后的中间过程知识蒸馏软标签的表达式如下:
其中,表示优化后的中间过程知识蒸馏软标签,/>表示中间过程知识蒸馏软标签,温度T被加到softmax中以缩放整体分布。
进一步地,优化后的知识蒸馏软标签损失函数的表达式如下:
其中,表示超参数,/>表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度,/>表示优化后的中间过程知识蒸馏软标签。
进一步地,所述总损失函数的表达式如下:
其中,表示总损失函数,/>表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度。
本发明与现有技术相比存在的有益效果是:
1、本发明的方法通过对标签进行平滑有助于提高神经网络模型对新任务中样本的泛化能力;
2、本发明采用知识嵌入的方法来反映类别相关性,有助于新任务的学习和对旧任务学习的关联信息;
3、本发明的方法获得的总损失函数,通过新任务的学习和对旧任务的学习之间的关联性,解决了由于缺乏先前数据的完整性,而无法考虑前一类和当前类之间的关系的问题,避免了灾难性遗忘。
附图说明
为了更清楚地说明本发明中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1是本发明实施例提供的一种基于知识增强的连续学习软标签构建方法的流程图;
图2是本发明实施例提供的类别之间关联的示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
下面将结合附图详细说明本发明的一种知识增强软标签的构建方法。
图1是本发明实施例提供的一种基于知识增强的连续学习软标签构建方法的流程图。
如图1所示,该基于知识增强的连续学习软标签构建方法包括:
S1.随机初始化语义软标签,计算语义Gram矩阵,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;
a1.随机初始化所述语义软标签,确定不同类别的语义软标签之间的相关性,得到所述语义Gram矩阵;
随机初始化语义软标签,其中,k表示中间过程语义软标签,对于每个类别,满足/>,且/>,n表示类别数量。在所述初始化语义软标签中,每个类别被表示非零监督信号,信号的分布代表不同类别之间的关系,每个实例的软标签最大值的位置应与硬标签的最大值位置一致。
在一个实施例中,基于随机初始化语义软标签,确定不同类别的语义软标签之间的相关性,由各不同类别的语义软标签之间的相关性,构成所述语义Gram矩阵。
a2.采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,得到所述词向量Gram矩阵,其中,所述外部词向量模型为CLIP或Bert;
在一个实施例中,采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,由各对应类别的词向量,构成所述词向量Gram矩阵。
a3.计算语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离,获得中间过程语义软标签;
在一个实施例中,由于语义Gram矩阵和词向量Gram之间的欧几里得距离等于所述中间过程的语义损失函数。
所述中间过程的语义损失函数的表达式如下:
(1)
其中,表示所述中间过程的语义损失函数,n表示类别数量正整数,表示语义Gram矩阵,/>表示词向量Gram矩阵。
通过上式(1)求出中间过程语义软标签,即求得随机初始化语义软标签k中各项值。
然而,考虑到标准化的交叉熵分布的argmax可能与硬标签不匹配,那么,直接采用这个归一化的分布作为软标签,会违反基本约束条件。
因此,经分析,优化后的中间过程语义软标签的表达式如下:
(2)
其中,表示优化后的中间过程语义软标签,k表示中间过程语义软标签,表示将所述优化后的中间过程语义软标签除以温度系数T,进行数学上的/>操作,温度T被加到softmax中以缩放整体分布,通过选择适当的温度系数T,从而起到更好地平滑作用,/>表示硬标签。
将随机初始化语义软标签中各k值代入式(2),得到各优化后的中间过程语义软标签。
a4.采用softmax函数对所述中间过程语义软标签归一化,获得优化后的中间过程语义软标签,对于每个类别,采用该类别优化后的中间过程语义软标签平滑相应的原始硬标签,获得该类别平滑后的语义软标签;
获得所述相应类别平滑后的语义软标签的表达式如下:
(3)
其中,表示相应类别平滑后的语义软标签,/>表示超参数,/>,用于控制两种类型的监督信号:硬标签和中间过程语义软标签;/>表示优化后的中间过程语义软标签。
将式(2)代入式(3),求得各类别平滑后的语义软标签。
将各优化后的中间过程语义软标签表达式(3),得到各类别平滑后的语义软标签。
本发明的方法通过对标签进行平滑有助于提高神经网络模型对新样本的泛化能力。
a5.基于所有类别平滑后的语义软标签,获得所述优化后的语义软标签损失函数。
在一个实施例中,求得所有类别的平滑后的语义软标签之后,基于各类别的平滑后的语义软标签与相应类别神经网络模型的输出的数学关系式,求得所述优化后的语义软标签损失函数。
由于优化后的语义软标签损失函数的表达式如下:
(4)
其中,表示优化后的语义软标签损失函数,/>表示相对应类别的神
经网络模型的输出,
将式(3)代入上式(4)得到:
(5)
由于交叉熵表达式=/>,
=/>,
则上式(5)可写成:
从密度估计的角度,最小化交叉熵等价于优化KL散度。
KL散度的表达式如下:
(6)
其中,表示KL散度,/>表示两者的交叉熵,/>表示固定分布的常量。
因此,所述优化后的语义软标签损失函数的表达式如下:
(7)
其中,表示硬标签和相应类别神经网络模型输出的KL散度,表示优化后的中间过程语义软标签和相应类别神经网络模型的输出的KL散度,所述优化后的语义软标签损失函数是各类别语义软标签之间的关联信息,如果超参数/>,则优化目标将退化为多分类交叉熵形式,如果将超参数设置为 1,则优化目标仅与平滑后的语义软标签相关。由上式(7)看出,本公式将类别语义软标签之间的关联充分考虑进神经网络模型的整体训练过程中。
随机初始化知识蒸馏软标签,计算知识蒸馏Gram矩阵,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;
b1.随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,得到所述知识蒸馏Gram矩阵;
随机初始化语义软标签,其中,f表示中间过程语义软标签,对于每个类别,满足/>,且/>,m表示类别数。在所述初始化语义软标签中,每个类别被表示非零监督信号,信号的分布代表不同类别之间的关系,每个实例的软标签最大值的位置应与硬标签的最大值位置一致。
在一个实施例中,基于随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,由各不同类别的知识蒸馏软标签之间的相关性,构成所述嵌入Gram矩阵。
b2.将旧任务和新任务不同类别的聚类中心输入至旧神经网络模型,获取每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,得到所述嵌入Gram矩阵;
在一个实施例中,将旧任务和新任务不同类别的聚类中心输入至所述神经网络模型,获取每个类别的嵌入特征,再根据每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,由不同类别嵌入特征之间的相关性,构成所述词向量Gram矩阵。
b3.计算所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离,获得中间过程知识蒸馏软标签;
在一个实施例中,由于知识蒸馏Gram矩阵和嵌入Gram之间的欧几里得距离等于所述中间过程的知识蒸馏损失函数,因此,
所述中间过程的知识蒸馏损失函数的表达式如下:
(8)
其中,表示中间过程的知识蒸馏损失函数,i,j都表示正整数,/>表示知识蒸馏Gram矩阵,/>表示嵌入Gram 矩阵。
通过上式(8)求出中间过程知识蒸馏软标签,即求得随机初始化语义软标签中各f,即f1,f2,f3...fm。
然而,考虑到标准化的交叉熵分布的argmax可能与硬标签不匹配,那么,直接采用这个归一化的分布作为软标签,会违反基本约束条件。
因此,经分析,优化后的中间过程知识蒸馏软标签的表达式如下:
(9)
其中,表示优化后的中间过程知识蒸馏软标签,f表示中间过程知识蒸馏软标签,进行数学上的softmax操作,温度T被加到softmax中以缩放整体分布,通过选择适当的温度系数T,从而起到更好地平滑作用。
将随机初始化语义软标签中各f值代入式(9),得到各优化后的中间过程知识蒸馏软标签。
b4.采用softmax函数对所述中间过程知识蒸馏软标签归一化,获得优化后的中间过程知识蒸馏软标签,对于每个类别,采用该类别优化后的中间过程知识蒸馏软标签平滑相应原始的硬标签,获得该类别平滑后的知识蒸馏软标签;
获得所述相应类别平滑后的语义软标签的表达式如下:
(10)
其中,表示平滑后的知识蒸馏软标签,/>表示超参数,/>,用于控制两种类型的监督信号:硬标签和中间过程知识蒸馏软标签;/>表示硬标签,/>表示优化后的中间过程知识蒸馏软标签。
将式(9)代入式(10),求得所述各类别平滑后的知识蒸馏软标签。
本发明通过采用知识嵌入的方法来反映类别相关性,有助于新任务的学习和对旧任务的保持。
b5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
在一个实施例中,求得所有类别的平滑后的知识蒸馏软标签之后,基于各类别的平滑后的知识蒸馏软标签与相应类别神经网络模型的输出的数学关系式,求得所述优化后的知识蒸馏软标签损失函数。
由于优化后的知识蒸馏软标签损失函数的表达式如下:
(11)
其中,表示优化后的语义软标签损失函数。
将式(10)代入式(11)得到:
(12)
其中,表示超参数,由于交叉熵表达式/>,
,
则上式(12)可写成:
从密度估计的角度,最小化交叉熵等价于优化KL散度。
KL散度的表达式如下:
(13)
其中,表示HL散度,/>表示两者的交叉熵,/>表示固定分布的常量。
因此,所述优化后的知识蒸馏软标签损失函数的表达式如下:
(14)
其中,表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度,所述优化后的知识蒸馏软标签损失函数是旧神经网络模型关于各类别之间的预测信息,如果超参数/>,则优化目标将退化为多分类交叉熵形式,如果将超参数设置为1,则优化目标仅与平滑后的蒸馏软标签相关。但事实上,由上式(14)看出,所述相应类别神经网络模型的输出不仅取决于当前类别,还取决于类别之间的关系。通过改变标签的分布,本发明的方法进一步模拟了类别之间的关系。
b5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
S2.将所述优化后的语义软标签损失函数与所述优化后的知识蒸馏软标签损失函数结合,获得总损失函数。
(15)
将所述优化后的中间过程知识蒸馏软标签,即式(7)和式(14)代入式(15)中,求得所述总损失函数。总损失函数表达式如下:
(16)
其中,表示总损失函数。
本发明获得的总损失函数,通过新任务的学习和对旧任务的学习之间的关联性,更好地体现了先前数据的完整性。
S3.采用所述总损失函数进行新任务的训练。
本发明的方法通过对标签进行平滑有助于提高模型对新任务中样本的泛化能力;通过采用知识嵌入的方法来反映类别相关性,有助于新任务的学习和对旧任务信息的保持;本发明获得的总损失函数,通过新任务的学习和对旧任务的学习之间的特征信息关联性,解决了由于缺乏先前数据的完整性,而无法考虑前一类和当前类之间的关系的问题,避免了灾难性遗忘的问题。
图2是本发明实施例提供的类别之间关联的示意图。
实施例1
步骤1.随机初始化语义软标签,计算语义Gram矩阵,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;
c1.随机初始化所述语义软标签,确定不同类别的语义软标签之间的相关性,得到所述语义Gram矩阵;
c2.采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,得到所述词向量Gram矩阵,其中,所述外部词向量模型为CLIP或Bert;
c3.计算语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离,获得中间过程语义软标签;
c4.采用softmax函数对所述中间过程语义软标签归一化,获得优化后的中间过程语义软标签,对于每个类别,采用该类别优化后的中间过程语义软标签平滑相应的原始硬标签,获得该类别平滑后的语义软标签;
考虑到标准化的交叉熵分布的argmax可能与硬标签不匹配,那么,直接采用这个归一化的分布作为软标签,会违反基本约束条件。
因此,经分析,优化后的中间过程语义软标签的表达式如下:
其中,所述ksen(x)表示优化后的中间过程语义软标签,k表示中间过程语义软标签,softmax(k/T)表示将所述优化后的中间过程语义软标签除以温度系数T,进行数学上的softmax操作,温度T被加到softmax中以缩放整体分布,通过选择适当的温度系数T,从而起到更好地平滑作用,表示硬标签。
c5.基于所有类别平滑后的语义软标签,获得所述优化后的语义软标签损失函数。
所述优化后的语义软标签损失函数的表达式如下:
其中,表示超参数,/>表示硬标签/>和模型输出/>的KL散度,q(x)表示硬标签,p(x)表示相应类别神经网络模型的输出,/>表示优化后的语义软标签和相应类别神经网络模型的输出的KL散度。
随机初始化知识蒸馏软标签,计算知识蒸馏Gram矩阵,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;
d1.随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,得到所述知识蒸馏Gram矩阵;
d2.将旧任务和新任务不同类别的聚类中心输入至旧神经网络模型,获取每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,得到所述嵌入Gram矩阵;
d3.计算所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离,获得中间过程知识蒸馏软标签;
d4.采用softmax函数对所述中间过程知识蒸馏软标签归一化,获得优化后的中间过程知识蒸馏软标签,对于每个类别,采用该类别优化后的中间过程知识蒸馏软标签平滑相应原始的硬标签,获得该类别平滑后的知识蒸馏软标签;
考虑到标准化的交叉熵分布的argmax可能与硬标签不匹配,那么,直接采用这个归一化的分布作为软标签,会违反基本约束条件。
因此,优化后的中间过程知识蒸馏软标签的表达式如下:
其中,表示优化后的中间过程知识蒸馏软标签,f表示中间过程知识蒸馏软标签,进行数学上的softmax操作,温度T被加到softmax中以缩放整体分布,通过选择适当的温度系数T,从而起到更好地平滑作用。
d5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
所述优化后的知识蒸馏软标签损失函数的表达式如下:
其中,表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度,所述优化后的知识蒸馏软标签损失函数是旧模型关于各类别之间的预测信息,如果超参数/>,则优化目标将退化为多分类交叉熵形式,如果将超参数设置为1,则优化目标仅与平滑后的蒸馏软标签相关。但事实上,由上述关系式看出,所述相应类别神经网络模型的输出不仅取决于当前类别,还取决于类别之间的关系。通过改变标签的分布,本发明的方法进一步模拟了类别之间的关系。
步骤2.将所述优化后的语义软标签损失函数与所述优化后的知识蒸馏软标签损失函数结合,获得总损失函数;
所述总损失函数的表达式如下:
其中,表示总损失函数,/>是硬标签和相应类别神经网络模型的输出的KL散度。
步骤3.采用所述总损失函数进行新任务的训练。
例如,新模型在学习猫,狗,苹果的分类任务,通过本发明上述方法得到的优化后的后的语义软标签和优化后的后的蒸馏软标签,获得三个类别之间的关联信息以及旧模型关于三者的预测信息,得到结论是猫和狗之间的类别较与苹果之间的类别更相近,因此,将猫和狗划分为一类,动物类,将苹果划分为另一类。
上述所有可选技术方案,可以采用任意结合形成本申请的可选实施例,在此不再一一赘述。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于知识增强的连续学习软标签构建方法,其特征在于,包括:
S1.随机初始化语义软标签,计算语义Gram矩阵,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;
以及随机初始化知识蒸馏软标签,计算知识蒸馏Gram矩阵,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;
S2.将所述优化后的语义软标签损失函数与所述优化后的知识蒸馏软标签损失函数结合,获得总损失函数;
S3.采用所述总损失函数进行新任务的训练。
2.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述随机初始化语义软标签,通过所述语义Gram矩阵、词向量Gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数包括:
a1.随机初始化所述语义软标签,确定不同类别的语义软标签之间的相关性,得到所述语义Gram矩阵;
a2.采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,得到所述词向量Gram矩阵,其中,所述外部词向量模型为CLIP或Bert;
a3.计算语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离,获得中间过程语义软标签;
a4.采用softmax函数对所述中间过程语义软标签归一化,获得优化后的中间过程语义软标签,对于每个类别,采用该类别优化后的中间过程语义软标签平滑相应的原始硬标签,获得该类别平滑后的语义软标签;
a5.基于所有类别平滑后的语义软标签,获得所述优化后的语义软标签损失函数。
3.根据权利要求2所述的基于知识增强的连续学习软标签构建方法,其特征在于,
语义Gram矩阵和词向量Gram矩阵之间的欧几里得距离的表达式如下:
其中,表示中间过程的语义损失函数,n表示类别数量正整数,/>表示语义Gram矩阵,/>表示词向量Gram矩阵。
4.根据权利要求2所述的基于知识增强的连续学习软标签构建方法,其特征在于,优化后的中间过程语义软标签的表达式如下:
其中,表示所述优化后的中间过程语义软标签,k表示中间过程语义软标签,表示将中间过程语义软标签除以温度系数T,进行数学上的操作,q(x)表示硬标签。
5.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述优化后的语义软标签损失函数的表达式如下:
其中,表示超参数,q(x)表示硬标签,p(x)表示相应类别神经网络模型的输出,表示硬标签和相应类别神经网络模型的输出的KL散度,表示优化后的中间过程语义软标签和相应类别神经网络模型输出的KL散度。
6.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述随机初始化知识蒸馏软标签,通过所述知识蒸馏Gram矩阵、嵌入Gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数,包括:
b1.随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,得到所述知识蒸馏Gram矩阵;
b2.将旧任务和新任务不同类别的聚类中心输入至旧神经网络模型,获取每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,得到所述嵌入Gram矩阵;
b3.计算所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离,获得中间过程知识蒸馏软标签;
b4.采用softmax函数对所述中间过程知识蒸馏软标签归一化,获得优化后的中间过程知识蒸馏软标签,对于每个类别,采用该类别优化后的中间过程知识蒸馏软标签平滑相应原始的硬标签,获得该类别平滑后的知识蒸馏软标签;
b5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
7.根据权利要求6所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述知识蒸馏Gram矩阵和嵌入Gram矩阵之间的欧几里得距离的表达式如下:
其中,表示中间过程的知识蒸馏损失函数,/>表示知识蒸馏Gram矩阵,表示嵌入Gram矩阵。
8.根据权利要求5所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述优化后的中间过程知识蒸馏软标签的表达式如下:
其中,表示优化后的中间过程知识蒸馏软标签,/>表示中间过程知识蒸馏软标签,温度T被加到softmax中以缩放整体分布。
9.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,优化后的知识蒸馏软标签损失函数的表达式如下:
其中,表示超参数,/>是优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度,/>表示优化后的中间过程知识蒸馏软标签。
10.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述总损失函数的表达式如下:
其中,表示总损失函数,/>表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的KL散度。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410183536.5A CN117743858B (zh) | 2024-02-19 | 2024-02-19 | 一种基于知识增强的连续学习软标签构建方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410183536.5A CN117743858B (zh) | 2024-02-19 | 2024-02-19 | 一种基于知识增强的连续学习软标签构建方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117743858A true CN117743858A (zh) | 2024-03-22 |
CN117743858B CN117743858B (zh) | 2024-07-19 |
Family
ID=90261238
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410183536.5A Active CN117743858B (zh) | 2024-02-19 | 2024-02-19 | 一种基于知识增强的连续学习软标签构建方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117743858B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112257864A (zh) * | 2020-10-22 | 2021-01-22 | 福州大学 | 一种用于解决灾难性遗忘问题的终生学习方法 |
CN113158902A (zh) * | 2021-04-23 | 2021-07-23 | 深圳龙岗智能视听研究院 | 一种基于知识蒸馏的自动化训练识别模型的方法 |
WO2021190451A1 (zh) * | 2020-03-24 | 2021-09-30 | 华为技术有限公司 | 训练图像处理模型的方法和装置 |
WO2022066133A1 (en) * | 2020-09-25 | 2022-03-31 | Aselsan Elektroni̇k Sanayi̇ Ve Ti̇c.A.Ş. | Meta tag generation method for learning from dirty tags |
WO2022227400A1 (zh) * | 2021-04-27 | 2022-11-03 | 商汤集团有限公司 | 神经网络训练方法和装置、设备,及计算机存储介质 |
-
2024
- 2024-02-19 CN CN202410183536.5A patent/CN117743858B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021190451A1 (zh) * | 2020-03-24 | 2021-09-30 | 华为技术有限公司 | 训练图像处理模型的方法和装置 |
WO2022066133A1 (en) * | 2020-09-25 | 2022-03-31 | Aselsan Elektroni̇k Sanayi̇ Ve Ti̇c.A.Ş. | Meta tag generation method for learning from dirty tags |
CN112257864A (zh) * | 2020-10-22 | 2021-01-22 | 福州大学 | 一种用于解决灾难性遗忘问题的终生学习方法 |
CN113158902A (zh) * | 2021-04-23 | 2021-07-23 | 深圳龙岗智能视听研究院 | 一种基于知识蒸馏的自动化训练识别模型的方法 |
WO2022227400A1 (zh) * | 2021-04-27 | 2022-11-03 | 商汤集团有限公司 | 神经网络训练方法和装置、设备,及计算机存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN117743858B (zh) | 2024-07-19 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113837370B (zh) | 用于训练基于对比学习的模型的方法和装置 | |
CN114037876B (zh) | 一种模型优化方法和装置 | |
CN111724083A (zh) | 金融风险识别模型的训练方法、装置、计算机设备及介质 | |
CN109816032A (zh) | 基于生成式对抗网络的无偏映射零样本分类方法和装置 | |
Angelov et al. | Toward anthropomorphic machine learning | |
CN113139664B (zh) | 一种跨模态的迁移学习方法 | |
CN113887580A (zh) | 一种考虑多粒度类相关性的对比式开放集识别方法及装置 | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN116663568B (zh) | 基于优先级的关键任务识别系统及其方法 | |
CN112906398B (zh) | 句子语义匹配方法、系统、存储介质和电子设备 | |
CN112380427B (zh) | 基于迭代图注意力网络的用户兴趣预测方法及电子装置 | |
CN115687610A (zh) | 文本意图分类模型训练方法、识别方法、装置、电子设备及存储介质 | |
CN116561591B (zh) | 科技文献语义特征提取模型训练方法、特征提取方法及装置 | |
Nguyen et al. | Semi-supervised adversarial discriminative domain adaptation | |
CN117743858B (zh) | 一种基于知识增强的连续学习软标签构建方法 | |
Shen et al. | On image classification: Correlation vs causality | |
CN111737591A (zh) | 一种基于异质重边信息网络翻译模型的产品推荐方法 | |
CN113297385B (zh) | 基于改进GraphRNN的多标签文本分类系统及分类方法 | |
Wang et al. | From machine learning to transfer learning | |
CN112686318B (zh) | 一种基于球面嵌入、球面对齐和球面校准的零样本学习机制 | |
CN114782791A (zh) | 基于transformer模型和类别关联的场景图生成方法 | |
Serrano et al. | Inter-task similarity measure for heterogeneous tasks | |
Lai et al. | Cross-domain sentiment classification using topic attention and dual-task adversarial training | |
Wu et al. | Applying a Probabilistic Network Method to Solve Business‐Related Few‐Shot Classification Problems | |
Cao et al. | A new skeleton-neural DAG learning approach |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |