CN113610173A - 一种基于知识蒸馏的多跨域少样本分类方法 - Google Patents

一种基于知识蒸馏的多跨域少样本分类方法 Download PDF

Info

Publication number
CN113610173A
CN113610173A CN202110931565.1A CN202110931565A CN113610173A CN 113610173 A CN113610173 A CN 113610173A CN 202110931565 A CN202110931565 A CN 202110931565A CN 113610173 A CN113610173 A CN 113610173A
Authority
CN
China
Prior art keywords
teacher
student
networks
feature
encoder
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110931565.1A
Other languages
English (en)
Other versions
CN113610173B (zh
Inventor
冀中
倪婧玮
刘西瑶
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tianjin University
Original Assignee
Tianjin University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tianjin University filed Critical Tianjin University
Priority to CN202110931565.1A priority Critical patent/CN113610173B/zh
Publication of CN113610173A publication Critical patent/CN113610173A/zh
Application granted granted Critical
Publication of CN113610173B publication Critical patent/CN113610173B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

一种基于知识蒸馏的多跨域少样本分类方法,利用知识蒸馏中师生网络的框架进行有效知识的迁移,从而使模型具有更好的泛化能力。本发明将元学习的训练策略引入知识蒸馏中,通过面向任务的知识蒸馏和多个教师网络之间的协作,不仅向学生网络提供了丰富且有效的知识,而且保证了学生网络对少样本任务的快速适应能力。通过引入多层次知识蒸馏,分别提取教师网络的输出预测和样本关系作为监督信息,从不同角度指导学生网络的训练,使得知识蒸馏的效率更高。由此,本发明能够将有效的知识更好地从多个源域迁移到目标域上,提高学生网络在目标少样本任务上的分类准确率。

Description

一种基于知识蒸馏的多跨域少样本分类方法
技术领域
本发明涉及一种少样本分类方法。特别是涉及一种基于知识蒸馏的多跨域少样本分类方法。
背景技术
目前,深度学习在计算机视觉领域已经取得了较大成功,例如物体分类、图像检索和动作识别等任务。深度学习的成功在很大程度上依赖于海量的数据和强大的计算资源。而许多认知学和心理学证据表明,人类往往可以从很少的例子中识别出新的视觉概念,这种快速学习的能力是现在的深度学习所不具备的。因此,如何通过有限的标记数据来学习识别新类别引起了人们的广泛关注,这也是少样本学习(Few-Shot Learning)所要解决的问题。近几年来,大量少样本学习的工作都采用了元学习(Meta Learning)的思想,其中基于度量的方法因其简单性和有效性而被广泛使用。这一类方法的模型结构主要包括两部分:特征编码器和度量函数。给定一个少样本任务,包含少量带标记的图像(支持集)和一些未标记的图像(查询集),特征编码器首先提取所有的图像特征,然后度量函数对标记图像和未标记图像的特征相似度进行计算,并预测查询图像的对应类别。
在基于元学习的少样本学习中,往往需要借助于一个包含大量标记数据的辅助数据集,这个数据集的类别与测试集相关但不相交。元学习的思想就是在辅助数据集上采样大量的少样本任务来训练模型,使其积累经验,从而能够快速适应新的少样本任务。然而,在某些现实场景中,例如医学、军事和金融等领域,存在数据获取困难和标记成本高等问题,无法获取包含相关类别的辅助数据集。这种情况下,只能使用来自其他领域的标记数据来训练模型以提供先验知识。由于辅助数据集和测试集的类别不相关,就不可避免地产生了域偏移问题,也严重损害了模型在测试集上的性能。因此,提高模型在不同领域之间的泛化能力成为少样本学习的一个重要挑战,也称为跨域少样本学习(Cross-Domain Few-ShotLearning)。具体来说,这个问题可以描述为在不使用目标域数据的情况下,借助其他域的辅助数据集训练模型,最终在目标域上很好的完成少样本分类任务。跨域问题也可以看做是一种知识迁移的问题,其目的是将有用的知识从辅助数据集传递到目标数据集上,保证模型在新的少样本任务上的分类性能。
由于深度学习模型的性能通常会随着可用数据的增加而提高,因此对跨域少样本学习来说,一种直观且简单的假设是,少样本学习模型应该充分利用不同任务和不同域中的数据来积累更多的经验。这一研究方向也被称为多跨域少样本学习(Multiple Cross-Domain Few-Shot Learning)。在这种设置下,辅助数据集包含了许多不同的领域,每个领域都由不同的源数据集表示。此时,另一个需要考虑的问题就是域之间的相关性对模型的影响。不同域数据的特征分布可能会有交叉重叠,也可能完全不相交。利用不相关的域可能对模型产生负面影响,带来知识干扰的问题。这就意味着,简单地混合多个域的标记数据来训练模型是不可行的,还必须探索如何利用或忽略从不同领域学到的知识,实现模型在目标域上的泛化和避免跨域干扰的问题。
发明内容
本发明所要解决的技术问题是,提供一种能够将有效的知识更好地从多个源域迁移到目标域上的基于知识蒸馏的多跨域少样本分类方法。
本发明所采用的技术方案是:一种基于知识蒸馏的多跨域少样本分类方法,其特征在于,包括如下步骤:
1)预训练阶段,分别利用N个不同源域的训练集{Z1,Z2,...,ZN}来训练N个不同的教师网络,每个教师网络包含一个教师特征编码器E和教师分类器C,初始化N个教师网络参数,利用交叉熵损失函数对每一个教师网络进行预训练,最终得到N个训练好的教师网络;
2)构建学生网络,学生网络是一种基于度量的少样本模型,包含一个学生特征编码器Es和一个度量函数d,初始化学生特征编码器Es的参数;
3)元训练阶段,从N个不同源域的训练集中随机选取一个训练集作为当前的元训练集Dtrain,根据元学习的思想,从当前的元训练集Dtrain中随机采样一定量的少样本任务,每个任务都包含一个支持集S和一个查询集Q,支持集中含有W个类别的数据,每个类别有K个样本;
4)依次将不同的少样本任务同时送到N个教师网络和学生网络中;
5)依次将支持集S中第k个样本图像xk输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure BDA0003211363720000021
Figure BDA0003211363720000022
Figure BDA0003211363720000023
其中xk为支持集S中第k个样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure BDA0003211363720000024
为第n个教师特征编码器En对xk编码后输出的视觉特征,
Figure BDA0003211363720000025
为学生特征编码器Es对xk编码后输出的视觉特征;
6)分别对支持集中属于同一类别的样本视觉特征取平均,得到每个类别的原型表示为:
Figure BDA0003211363720000026
Figure BDA0003211363720000027
其中K为第w个类别的样本总数,
Figure BDA0003211363720000028
为经过第n个教师特征编码器编码后的第w个类别的原型表示,
Figure BDA0003211363720000029
为经过学生特征编码器编码后的第w个类别的原型表示;
7)依次将查询集的样本图像xQ输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure BDA00032113637200000210
Figure BDA00032113637200000211
Figure BDA00032113637200000212
Figure BDA0003211363720000031
其中xQ为查询集Q中的样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure BDA0003211363720000032
为第n个教师特征编码器En对xQ编码后输出的视觉特征,
Figure BDA0003211363720000033
为学生特征编码器Es对xQ编码后输出的视觉特征;
8)根据经过学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,按照如下公式计算查询集样本图像xQ属于支持集中各个类别的概率:
Figure BDA0003211363720000034
其中ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数,函数d为欧氏距离的度量函数,exp为自然常数e为底的指数函数;
9)根据学生网络输出的查询集样本图像xQ的类别预测概率计算分类损失,设定学生网络的分类目标函数Lcls如下:
Figure BDA0003211363720000035
其中yQ为查询集中样本图像xQ的真实标签,ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数;
10)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用温度系数τ计算软化后的类别预测概率分布,从而在N个教师网络和学生网络之间进行基于软标签的知识蒸馏,得到学生网络的基于软标签的目标函数LKL
11)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用成对的特征计算相似度矩阵,从而在N个教师网络和学生网络之间进行基于相似度的知识蒸馏,得到学生网络的基于相似度的目标函数Lsim
12)根据如下学生网络的总目标函数公式,使用SGD算法训练学生特征编码器:
L=Lcls+LKL+Lsim (21)
其中,L为学生网络的总目标函数,LKL为学生网络的基于软标签的目标函数,Lsim为学生网络的基于相似度的目标函数;
13)重复步骤3-12,直至总目标函数值逐渐收敛且趋于不变时,得到训练好的学生网络;
14)测试阶段,给定一个不同于N个源域的数据集作为目标域,依次将来自目标域测试集的支持集和查询集的样本图像输入到训练好的学生特征编码器Es中,得到相应的视觉特征,按照公式(7)计算支持集中各个类别的原型表示,再按照公式(10)计算查询集样本图像属于各个类别的概率,将计算得到的概率中最大的概率所对应的类别,作为查询集样本图像的类别。
本发明的一种基于知识蒸馏的多跨域少样本分类方法,利用知识蒸馏中师生网络的框架进行有效知识的迁移,从而使模型具有更好的泛化能力。本发明将元学习的训练策略引入知识蒸馏中,通过面向任务的知识蒸馏和多个教师网络之间的协作,不仅向学生网络提供了丰富且有效的知识,而且保证了学生网络对少样本任务的快速适应能力。通过引入多层次知识蒸馏,分别提取教师网络的输出预测和样本关系作为监督信息,从不同角度指导学生网络的训练,使得知识蒸馏的效率更高。由此,本发明能够将有效的知识更好地从多个源域迁移到目标域上,提高学生网络在目标少样本任务上的分类准确率。
附图说明
图1是本发明的一种基于知识蒸馏的多跨域少样本分类方法的流程图。
具体实施方式
下面结合实施例和附图对本发明的一种基于知识蒸馏的多跨域少样本分类方法做出详细说明。
多跨域少样本学习利用多个源域的训练数据作为辅助数据集为模型提供先验知识,最终完成对目标域上测试样本类别的预测。假设在训练阶段给出了N个不同源域的训练集{Z1,Z2,…,ZN},每个源域都包含大量带标签的数据
Figure BDA0003211363720000041
其中z为第n个源域训练集的样本总数,
Figure BDA0003211363720000042
是该训练集中的第i个样本图像,
Figure BDA0003211363720000043
是该训练集中第i个样本对应的类别标签。在测试阶段,利用训练好的模型完成目标域上的少样本分类任务。每个少样本任务都包含一个支持集S和一个查询集Q。支持集中含有W个类别的数据,每个类别有K个样本。少样本的任务就是利用W*K个支持集的训练数据,对查询集中的样本所属类别进行预测。
图1描述了基于知识蒸馏的多跨域少样本分类方法模型的流程图。T表示一个少样本任务的所有图像,{E1,E2,…,EN}为N个教师特征编码器,Es为学生特征编码器,v表示视觉特征,D为距离度量模块,包含两部分d和
Figure BDA0003211363720000044
p表示输出的类别概率分布,M表示相似度矩阵。
如图1所示,本发明的一种基于知识蒸馏的多跨域少样本分类方法,包括如下步骤:
1)预训练阶段,分别利用N个不同源域的训练集{Z1,Z2,…,ZN}来训练N个不同的教师网络,每个教师网络包含一个教师特征编码器E和教师分类器C,初始化N个教师网络参数,利用交叉熵损失函数对每一个教师网络进行预训练,最终得到N个训练好的教师网络;所述的利用交叉熵损失函数对每一个教师网络进行预训练,包括:
(1)从第n个源域训练集Zn中随机选取一定量的数据
Figure BDA0003211363720000045
作为第n个教师特征编码器En的输入,经过编码得到第i个样本图像的视觉特征
Figure BDA0003211363720000046
Figure BDA0003211363720000047
其中
Figure BDA0003211363720000048
为第n个源域训练集中的第i个样本图像,
Figure BDA0003211363720000049
为第i个样本图像
Figure BDA00032113637200000410
的真实标签;
(2)将第i个样本图像的视觉特征
Figure BDA00032113637200000411
输入第n个教师分类器Cn,得到第n个源域训练集中的第i个样本图像的类别预测概率:
Figure BDA00032113637200000412
其中
Figure BDA0003211363720000051
为第i个样本图像的视觉特征,
Figure BDA0003211363720000052
为第i个样本图像
Figure BDA0003211363720000053
属于第r个类别的预测概率;
(3)设定教师网络的目标函数Ln公式如下:
Figure BDA0003211363720000054
其中
Figure BDA0003211363720000055
为第i个样本图像
Figure BDA0003211363720000056
的真实标签,R为第n个源域训练集中的样本类别数,
Figure BDA0003211363720000057
为第i个样本图像
Figure BDA0003211363720000058
属于第r个类别的预测概率;
(4)根据公式(3)训练第n个教师特征编码器En和第n个教师分类器Cn,保留使公式(3)的误差最小的第n个教师特征编码器En和第n个教师分类器Cn的参数;
(5)重复第(1)步~第(4)步,得到训练好的N个教师网络。
2)构建学生网络,学生网络是一种基于度量的少样本模型,包含一个学生特征编码器Es和一个度量函数d,初始化学生特征编码器Es的参数;为了确保教师网络和学生网络的输出保持一致,只保留教师网络的训练好的教师特征编码器部分,不再使用训练好的教师分类器进行分类,而采用基于度量的方法完成分类,并且教师网络的参数固定不变。
3)元训练阶段,从N个不同源域的训练集中随机选取一个训练集作为当前的元训练集Dtrain,根据元学习的思想,从当前的元训练集Dtrain中随机采样一定量的少样本任务,每个任务都包含一个支持集S和一个查询集Q,支持集中含有W个类别的数据,每个类别有K个样本;
4)依次将不同的少样本任务同时送到N个教师网络和学生网络中;
5)依次将支持集S中第k个样本图像xk输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure BDA0003211363720000059
Figure BDA00032113637200000510
Figure BDA00032113637200000511
其中xk为支持集S中第k个样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure BDA00032113637200000512
为第n个教师特征编码器En对xk编码后输出的视觉特征,
Figure BDA00032113637200000513
为学生特征编码器Es对xk编码后输出的视觉特征;
6)分别对支持集中属于同一类别的样本视觉特征取平均,得到每个类别的原型表示为:
Figure BDA00032113637200000514
Figure BDA00032113637200000515
其中K为第w个类别的样本总数,
Figure BDA00032113637200000516
为经过第n个教师特征编码器编码后的第w个类别的原型表示,
Figure BDA00032113637200000517
为经过学生特征编码器编码后的第w个类别的原型表示;
7)依次将查询集的样本图像xQ输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure BDA0003211363720000061
Figure BDA0003211363720000062
Figure BDA0003211363720000063
Figure BDA0003211363720000064
其中xQ为查询集Q中的样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure BDA0003211363720000065
为第n个教师特征编码器En对xQ编码后输出的视觉特征,
Figure BDA0003211363720000066
为学生特征编码器Es对xQ编码后输出的视觉特征;
8)根据经过学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,按照如下公式计算查询集样本图像xQ属于支持集中各个类别的概率:
Figure BDA0003211363720000067
其中ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数,函数d为欧氏距离的度量函数,exp为自然常数e为底的指数函数;
9)根据学生网络输出的查询集样本图像xQ的类别预测概率计算分类损失,设定学生网络的分类目标函数Lcls如下:
Figure BDA0003211363720000068
其中yQ为查询集中样本图像xQ的真实标签,ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数;
10)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用温度系数τ计算软化后的类别预测概率分布,从而在N个教师网络和学生网络之间进行基于软标签的知识蒸馏,得到学生网络的基于软标签的目标函数LKL;包括:
(1)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,通过度量函数d,再除以温度系数τ,最后做softmax变换,得到软化后的类别预测概率:
Figure BDA0003211363720000069
Figure BDA00032113637200000610
其中
Figure BDA00032113637200000611
为软化后第n个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure BDA00032113637200000612
为软化后学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure BDA00032113637200000613
为经过第n个教师特征编码器编码后的第w个类别的原型表示,
Figure BDA00032113637200000614
为经过学生特征编码器编码后的第w个类别的原型表示,
Figure BDA0003211363720000071
为第n个教师特征编码器En对xQ编码后输出的视觉特征,
Figure BDA0003211363720000072
为学生特征编码器Es对xQ编码后输出的视觉特征,τ为温度系数,W为支持集中类别的总数,函数d为欧氏距离的度量函数,exp为自然常数e为底的指数函数;
(2)将软化后的N个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率进行加权求和,作为训练学生网络的目标之一:
Figure BDA0003211363720000073
其中α12,…,αN分别为N个教师网络的权重系数,
Figure BDA0003211363720000074
分别为软化后N个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure BDA0003211363720000075
为软化后N个教师网络进行加权求和输出的查询集样本图像xQ属于第w个类别的预测概率;
(3)为了使学生网络与教师网络的输出一致,设定学生网络的基于软标签的目标函数LKL如下:
Figure BDA0003211363720000076
其中
Figure BDA0003211363720000077
为软化后N个教师网络进行加权求和输出的查询集样本图像xQ属于支持集各个类别的预测概率分布,
Figure BDA0003211363720000078
为软化后学生网络输出的查询集样本图像xQ属于支持集各个类别的预测概率分布,KLdiv为Kullback-Leibler散度,用来衡量两个概率分布
Figure BDA0003211363720000079
Figure BDA00032113637200000710
之间的差异,τ为温度系数。
11)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用成对的特征计算相似度矩阵,从而在N个教师网络和学生网络之间进行基于相似度的知识蒸馏,得到学生网络的基于相似度的目标函数Lsim;包括:
(1)定义特征集合F,包含支持集中所有类别的原型表示和查询集样本图像的视觉特征,表示为:
Figure BDA00032113637200000711
Figure BDA00032113637200000712
其中
Figure BDA00032113637200000713
为第n个教师网络的特征集合,Fs为学生网络的特征集合,
Figure BDA00032113637200000714
分别为经过第n个教师特征编码器编码后的各个类别的原型表示,
Figure BDA00032113637200000715
分别为经过学生特征编码器编码后的各个类别的原型表示,
Figure BDA00032113637200000716
分别为第n个教师特征编码器对查询集各个样本图像编码后输出的视觉特征,
Figure BDA00032113637200000717
分别为第n个教师特征编码器对查询集各个样本图像编码后输出的视觉特征,W为支持集中类别的总数,q为查询集中的样本图像总数;
(2)根据特征集合F计算相似度矩阵M:
Figure BDA00032113637200000718
其中Mij为相似度矩阵M中第i行第j列的元素,fi和fj分别为特征集合F中第i个和第j个元素,函数
Figure BDA00032113637200000719
为余弦距离的度量函数,m为特征集合F中的元素总数;
(3)根据N个教师网络和学生网络的特征集合,按照公式(19)得到N个教师网络和学生网络的相似度矩阵,将N个教师网络的相似度矩阵进行加权求和,作为训练学生网络的目标之二:
Figure BDA0003211363720000081
其中α12,…,αN分别为N个教师网络的权重系数,
Figure BDA0003211363720000082
分别为N个教师网络的相似度矩阵,Mt为N个教师网络的相似度矩阵进行加权求和的结果;
(4)为了使学生网络更好的学习样本之间的关系,进一步探索嵌入在样本相似度中的知识,使学生网络与教师网络的相似度矩阵尽可能相似,设定学生网络的基于相似度的目标函数Lsim如下:
Figure BDA0003211363720000083
其中Mt为N个教师网络的相似度矩阵进行加权求和的结果,Ms为学生网络的相似度矩阵,m为特征集合F中的元素总数,也就是相似度矩阵的维度。
12)根据如下学生网络的总目标函数公式,使用SGD算法训练学生特征编码器:
L=Lcls+LKL+Lsim (21)
其中,L为学生网络的总目标函数,Lcls为学生网络的分类目标函数,LKL为学生网络的基于软标签的目标函数,Lsim为学生网络的基于相似度的目标函数;
13)重复步骤3-12,直至总目标函数值逐渐收敛且趋于不变时,得到训练好的学生网络;
14)测试阶段,给定一个不同于N个源域的数据集作为目标域,依次将来自目标域测试集的支持集和查询集的样本图像输入到训练好的学生特征编码器Es中,得到相应的视觉特征,按照公式(7)计算支持集中各个类别的原型表示,再按照公式(10)计算查询集样本图像属于各个类别的概率,将计算得到的概率中最大的概率所对应的类别,作为查询集样本图像的类别。

Claims (4)

1.一种基于知识蒸馏的多跨域少样本分类方法,其特征在于,包括如下步骤:
1)预训练阶段,分别利用N个不同源域的训练集{Z1,Z2,…,ZN}来训练N个不同的教师网络,每个教师网络包含一个教师特征编码器E和教师分类器C,初始化N个教师网络参数,利用交叉熵损失函数对每一个教师网络进行预训练,最终得到N个训练好的教师网络;
2)构建学生网络,学生网络是一种基于度量的少样本模型,包含一个学生特征编码器Es和一个度量函数d,初始化学生特征编码器Es的参数;
3)元训练阶段,从N个不同源域的训练集中随机选取一个训练集作为当前的元训练集Dtrain,根据元学习的思想,从当前的元训练集Dtrain中随机采样一定量的少样本任务,每个任务都包含一个支持集S和一个查询集Q,支持集中含有W个类别的数据,每个类别有K个样本;
4)依次将不同的少样本任务同时送到N个教师网络和学生网络中;
5)依次将支持集S中第k个样本图像xk输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure FDA0003211363710000011
Figure FDA0003211363710000012
Figure FDA0003211363710000013
其中xk为支持集S中第k个样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure FDA0003211363710000014
为第n个教师特征编码器En对xk编码后输出的视觉特征,
Figure FDA0003211363710000015
为学生特征编码器Es对xk编码后输出的视觉特征;
6)分别对支持集中属于同一类别的样本视觉特征取平均,得到每个类别的原型表示为:
Figure FDA0003211363710000016
Figure FDA0003211363710000017
其中K为第w个类别的样本总数,
Figure FDA0003211363710000018
为经过第n个教师特征编码器编码后的第w个类别的原型表示,
Figure FDA0003211363710000019
为经过学生特征编码器编码后的第w个类别的原型表示;
7)依次将查询集的样本图像xQ输入到N个教师特征编码器和学生特征编码器中,分别得到相对应的视觉特征
Figure FDA00032113637100000110
Figure FDA00032113637100000111
Figure FDA00032113637100000112
Figure FDA00032113637100000113
其中xQ为查询集Q中的样本图像,En为第n个教师特征编码器,Es为学生特征编码器,
Figure FDA00032113637100000114
为第n个教师特征编码器En对xQ编码后输出的视觉特征,
Figure FDA00032113637100000115
为学生特征编码器Es对xQ编码后输出的视觉特征;
8)根据经过学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,按照如下公式计算查询集样本图像xQ属于支持集中各个类别的概率:
Figure FDA0003211363710000021
其中ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数,函数d为欧氏距离的度量函数,exp为自然常数e为底的指数函数;
9)根据学生网络输出的查询集样本图像xQ的类别预测概率计算分类损失,设定学生网络的分类目标函数Lcls如下:
Figure FDA0003211363710000022
其中yQ为查询集中样本图像xQ的真实标签,ps(y=w|xQ)为学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,W为支持集中类别的总数;
10)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用温度系数τ计算软化后的类别预测概率分布,从而在N个教师网络和学生网络之间进行基于软标签的知识蒸馏,得到学生网络的基于软标签的目标函数LKL
11)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,利用成对的特征计算相似度矩阵,从而在N个教师网络和学生网络之间进行基于相似度的知识蒸馏,得到学生网络的基于相似度的目标函数Lsim
12)根据如下学生网络的总目标函数公式,使用SGD算法训练学生特征编码器:
L=Lcls+LKL+Lsim (21)
其中,L为学生网络的总目标函数,LKL为学生网络的基于软标签的目标函数,Lsim为学生网络的基于相似度的目标函数;
13)重复步骤3-12,直至总目标函数值逐渐收敛且趋于不变时,得到训练好的学生网络;
14)测试阶段,给定一个不同于N个源域的数据集作为目标域,依次将来自目标域测试集的支持集和查询集的样本图像输入到训练好的学生特征编码器Es中,得到相应的视觉特征,按照公式(7)计算支持集中各个类别的原型表示,再按照公式(10)计算查询集样本图像属于各个类别的概率,将计算得到的概率中最大的概率所对应的类别,作为查询集样本图像的类别。
2.根据权利要求1所述的一种基于知识蒸馏的多跨域少样本分类方法,其特征在于,步骤1)所述的利用交叉熵损失函数对每一个教师网络进行预训练,包括:
(1)从第n个源域训练集Zn中随机选取一定量的数据
Figure FDA0003211363710000023
作为第n个教师特征编码器En的输入,经过编码得到第i个样本图像的视觉特征
Figure FDA0003211363710000031
Figure FDA0003211363710000032
其中
Figure FDA0003211363710000033
为第n个源域训练集中的第i个样本图像,
Figure FDA0003211363710000034
为第i个样本图像
Figure FDA0003211363710000035
的真实标签;
(2)将第i个样本图像的视觉特征
Figure FDA0003211363710000036
输入第n个教师分类器Cn,得到第n个源域训练集中的第i个样本图像的类别预测概率:
Figure FDA0003211363710000037
其中
Figure FDA0003211363710000038
为第i个样本图像的视觉特征,
Figure FDA0003211363710000039
为第i个样本图像
Figure FDA00032113637100000310
属于第r个类别的预测概率;
(3)设定教师网络的目标函数Ln公式如下:
Figure FDA00032113637100000311
其中
Figure FDA00032113637100000312
为第i个样本图像
Figure FDA00032113637100000313
的真实标签,R为第n个源域训练集中的样本类别数,
Figure FDA00032113637100000314
为第i个样本图像
Figure FDA00032113637100000315
属于第r个类别的预测概率;
(4)根据公式(3)训练第n个教师特征编码器En和第n个教师分类器Cn,保留使公式(3)的误差最小的第n个教师特征编码器En和第n个教师分类器Cn的参数;
(5)重复第(1)步~第(4)步,得到训练好的N个教师网络。
3.根据权利要求1所述的一种基于知识蒸馏的多跨域少样本分类方法,其特征在于,步骤10)包括:
(1)根据经过N个教师特征编码器和学生特征编码器编码后的原型表示和查询集样本图像的视觉特征,通过度量函数d,再除以温度系数τ,最后做softmax变换,得到软化后的类别预测概率:
Figure FDA00032113637100000316
Figure FDA00032113637100000317
其中
Figure FDA00032113637100000318
为软化后第n个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure FDA00032113637100000319
为软化后学生网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure FDA00032113637100000320
为经过第n个教师特征编码器编码后的第w个类别的原型表示,
Figure FDA00032113637100000321
为经过学生特征编码器编码后的第w个类别的原型表示,
Figure FDA00032113637100000322
为第n个教师特征编码器En对xQ编码后输出的视觉特征,
Figure FDA00032113637100000323
为学生特征编码器Es对xQ编码后输出的视觉特征,τ为温度系数,W为支持集中类别的总数,函数d为欧氏距离的度量函数,exp为自然常数e为底的指数函数;
(2)将软化后的N个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率进行加权求和,作为训练学生网络的目标之一:
Figure FDA0003211363710000041
其中α12,…,αN分别为N个教师网络的权重系数,
Figure FDA0003211363710000042
分别为软化后N个教师网络输出的查询集样本图像xQ属于第w个类别的预测概率,
Figure FDA0003211363710000043
为软化后N个教师网络进行加权求和输出的查询集样本图像xQ属于第w个类别的预测概率;
(3)为了使学生网络与教师网络的输出一致,设定学生网络的基于软标签的目标函数LKL如下:
Figure FDA0003211363710000044
其中
Figure FDA0003211363710000045
为软化后N个教师网络进行加权求和输出的查询集样本图像xQ属于支持集各个类别的预测概率分布,
Figure FDA0003211363710000046
为软化后学生网络输出的查询集样本图像xQ属于支持集各个类别的预测概率分布,KLdiv为Kullback-Leibler散度,用来衡量两个概率分布
Figure FDA0003211363710000047
Figure FDA0003211363710000048
之间的差异,τ为温度系数。
4.根据权利要求1所述的一种基于知识蒸馏的多跨域少样本分类方法,其特征在于,步骤11)包括:
(1)定义特征集合F,包含支持集中所有类别的原型表示和查询集样本图像的视觉特征,表示为:
Figure FDA0003211363710000049
Figure FDA00032113637100000410
其中
Figure FDA00032113637100000411
为第n个教师网络的特征集合,Fs为学生网络的特征集合,
Figure FDA00032113637100000412
分别为经过第n个教师特征编码器编码后的各个类别的原型表示,
Figure FDA00032113637100000413
分别为经过学生特征编码器编码后的各个类别的原型表示,
Figure FDA00032113637100000414
分别为第n个教师特征编码器对查询集各个样本图像编码后输出的视觉特征,
Figure FDA00032113637100000415
分别为第n个教师特征编码器对查询集各个样本图像编码后输出的视觉特征,W为支持集中类别的总数,q为查询集中的样本图像总数;
(2)根据特征集合F计算相似度矩阵M:
Figure FDA00032113637100000416
其中Mij为相似度矩阵M中第i行第j列的元素,fi和fj分别为特征集合F中第i个和第j个元素,函数
Figure FDA00032113637100000417
为余弦距离的度量函数,m为特征集合F中的元素总数;
(3)根据N个教师网络和学生网络的特征集合,按照公式(19)得到N个教师网络和学生网络的相似度矩阵,将N个教师网络的相似度矩阵进行加权求和,作为训练学生网络的目标之二:
Figure FDA0003211363710000051
其中α12,…,αN分别为N个教师网络的权重系数,
Figure FDA0003211363710000052
分别为N个教师网络的相似度矩阵,Mt为N个教师网络的相似度矩阵进行加权求和的结果;
(4)设定学生网络的基于相似度的目标函数Lsim如下:
Figure FDA0003211363710000053
其中Mt为N个教师网络的相似度矩阵进行加权求和的结果,Ms为学生网络的相似度矩阵,m为特征集合F中的元素总数,也就是相似度矩阵的维度。
CN202110931565.1A 2021-08-13 2021-08-13 一种基于知识蒸馏的多跨域少样本分类方法 Active CN113610173B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110931565.1A CN113610173B (zh) 2021-08-13 2021-08-13 一种基于知识蒸馏的多跨域少样本分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110931565.1A CN113610173B (zh) 2021-08-13 2021-08-13 一种基于知识蒸馏的多跨域少样本分类方法

Publications (2)

Publication Number Publication Date
CN113610173A true CN113610173A (zh) 2021-11-05
CN113610173B CN113610173B (zh) 2022-10-04

Family

ID=78340695

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110931565.1A Active CN113610173B (zh) 2021-08-13 2021-08-13 一种基于知识蒸馏的多跨域少样本分类方法

Country Status (1)

Country Link
CN (1) CN113610173B (zh)

Cited By (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113869462A (zh) * 2021-12-02 2021-12-31 之江实验室 一种基于双路结构对比嵌入学习的小样本对象分类方法
CN114266977A (zh) * 2021-12-27 2022-04-01 青岛澎湃海洋探索技术有限公司 基于超分辨可选择网络的多auv的水下目标识别方法
CN114782776A (zh) * 2022-04-19 2022-07-22 中国矿业大学 基于MoCo模型的多模块知识蒸馏方法
CN114972904A (zh) * 2022-04-18 2022-08-30 北京理工大学 一种基于对抗三元组损失的零样本知识蒸馏方法及系统
CN115100532A (zh) * 2022-08-02 2022-09-23 北京卫星信息工程研究所 小样本遥感图像目标检测方法和系统
CN115099988A (zh) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 模型训练方法、数据处理方法、设备及计算机介质
CN115908823A (zh) * 2023-03-09 2023-04-04 南京航空航天大学 一种基于难度蒸馏的语义分割方法
CN116204770A (zh) * 2022-12-12 2023-06-02 中国公路工程咨询集团有限公司 一种用于桥梁健康监测数据异常检测的训练方法及装置
CN116452794A (zh) * 2023-04-14 2023-07-18 中国矿业大学 一种基于半监督学习的有向目标检测方法
CN116958548A (zh) * 2023-07-21 2023-10-27 中国矿业大学 基于类别统计驱动的伪标签自蒸馏语义分割方法
WO2024032386A1 (en) * 2022-08-08 2024-02-15 Huawei Technologies Co., Ltd. Systems and methods for artificial-intelligence model training using unsupervised domain adaptation with multi-source meta-distillation
CN114972904B (zh) * 2022-04-18 2024-05-31 北京理工大学 一种基于对抗三元组损失的零样本知识蒸馏方法及系统

Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110097094A (zh) * 2019-04-15 2019-08-06 天津大学 一种面向人物交互的多重语义融合少样本分类方法
CN112183670A (zh) * 2020-11-05 2021-01-05 南开大学 一种基于知识蒸馏的少样本虚假新闻检测方法
CN112364894A (zh) * 2020-10-23 2021-02-12 天津大学 一种基于元学习的对抗网络的零样本图像分类方法
CN112418343A (zh) * 2020-12-08 2021-02-26 中山大学 多教师自适应联合知识蒸馏
CN112633406A (zh) * 2020-12-31 2021-04-09 天津大学 一种基于知识蒸馏的少样本目标检测方法
CN112784964A (zh) * 2021-01-27 2021-05-11 西安电子科技大学 基于桥接知识蒸馏卷积神经网络的图像分类方法
CN112801105A (zh) * 2021-01-22 2021-05-14 之江实验室 一种两阶段的零样本图像语义分割方法
CN112861936A (zh) * 2021-01-26 2021-05-28 北京邮电大学 一种基于图神经网络知识蒸馏的图节点分类方法及装置

Patent Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110097094A (zh) * 2019-04-15 2019-08-06 天津大学 一种面向人物交互的多重语义融合少样本分类方法
CN112364894A (zh) * 2020-10-23 2021-02-12 天津大学 一种基于元学习的对抗网络的零样本图像分类方法
CN112183670A (zh) * 2020-11-05 2021-01-05 南开大学 一种基于知识蒸馏的少样本虚假新闻检测方法
CN112418343A (zh) * 2020-12-08 2021-02-26 中山大学 多教师自适应联合知识蒸馏
CN112633406A (zh) * 2020-12-31 2021-04-09 天津大学 一种基于知识蒸馏的少样本目标检测方法
CN112801105A (zh) * 2021-01-22 2021-05-14 之江实验室 一种两阶段的零样本图像语义分割方法
CN112861936A (zh) * 2021-01-26 2021-05-28 北京邮电大学 一种基于图神经网络知识蒸馏的图节点分类方法及装置
CN112784964A (zh) * 2021-01-27 2021-05-11 西安电子科技大学 基于桥接知识蒸馏卷积神经网络的图像分类方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
冀中: "基于自注意力和自编码器的少样本学习", 《天津大学学报(自然科学与工程技术版)》 *
冀中: "零样本图像分类综述: 十年进展", 《中国科学》 *

Cited By (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113869462A (zh) * 2021-12-02 2021-12-31 之江实验室 一种基于双路结构对比嵌入学习的小样本对象分类方法
CN113869462B (zh) * 2021-12-02 2022-06-10 之江实验室 一种基于双路结构对比嵌入学习的小样本对象分类方法
CN114266977A (zh) * 2021-12-27 2022-04-01 青岛澎湃海洋探索技术有限公司 基于超分辨可选择网络的多auv的水下目标识别方法
CN114972904A (zh) * 2022-04-18 2022-08-30 北京理工大学 一种基于对抗三元组损失的零样本知识蒸馏方法及系统
CN114972904B (zh) * 2022-04-18 2024-05-31 北京理工大学 一种基于对抗三元组损失的零样本知识蒸馏方法及系统
CN114782776A (zh) * 2022-04-19 2022-07-22 中国矿业大学 基于MoCo模型的多模块知识蒸馏方法
CN115099988A (zh) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 模型训练方法、数据处理方法、设备及计算机介质
CN115100532B (zh) * 2022-08-02 2023-04-07 北京卫星信息工程研究所 小样本遥感图像目标检测方法和系统
CN115100532A (zh) * 2022-08-02 2022-09-23 北京卫星信息工程研究所 小样本遥感图像目标检测方法和系统
WO2024032386A1 (en) * 2022-08-08 2024-02-15 Huawei Technologies Co., Ltd. Systems and methods for artificial-intelligence model training using unsupervised domain adaptation with multi-source meta-distillation
CN116204770A (zh) * 2022-12-12 2023-06-02 中国公路工程咨询集团有限公司 一种用于桥梁健康监测数据异常检测的训练方法及装置
CN116204770B (zh) * 2022-12-12 2023-10-13 中国公路工程咨询集团有限公司 一种用于桥梁健康监测数据异常检测的训练方法及装置
CN115908823A (zh) * 2023-03-09 2023-04-04 南京航空航天大学 一种基于难度蒸馏的语义分割方法
CN116452794A (zh) * 2023-04-14 2023-07-18 中国矿业大学 一种基于半监督学习的有向目标检测方法
CN116452794B (zh) * 2023-04-14 2023-11-03 中国矿业大学 一种基于半监督学习的有向目标检测方法
CN116958548A (zh) * 2023-07-21 2023-10-27 中国矿业大学 基于类别统计驱动的伪标签自蒸馏语义分割方法

Also Published As

Publication number Publication date
CN113610173B (zh) 2022-10-04

Similar Documents

Publication Publication Date Title
CN113610173B (zh) 一种基于知识蒸馏的多跨域少样本分类方法
CN109710800B (zh) 模型生成方法、视频分类方法、装置、终端及存储介质
CN110298037B (zh) 基于增强注意力机制的卷积神经网络匹配的文本识别方法
CN109214452B (zh) 基于注意深度双向循环神经网络的hrrp目标识别方法
CN113792113A (zh) 视觉语言模型获得及任务处理方法、装置、设备及介质
Bochinski et al. Deep active learning for in situ plankton classification
CN113626589B (zh) 一种基于混合注意力机制的多标签文本分类方法
CN110188827A (zh) 一种基于卷积神经网络和递归自动编码器模型的场景识别方法
CN116450796A (zh) 一种智能问答模型构建方法及设备
CN116303977B (zh) 一种基于特征分类的问答方法及系统
CN115546196A (zh) 一种基于知识蒸馏的轻量级遥感影像变化检测方法
CN113822125A (zh) 唇语识别模型的处理方法、装置、计算机设备和存储介质
CN114528835A (zh) 基于区间判别的半监督专业术语抽取方法、介质及设备
CN115546840A (zh) 基于半监督知识蒸馏的行人重识别模型训练方法及装置
CN110990678B (zh) 基于增强型循环神经网络的轨迹相似性计算方法
CN111882042A (zh) 用于液体状态机的神经网络架构自动搜索方法、系统及介质
CN114299326A (zh) 一种基于转换网络与自监督的小样本分类方法
Selvam et al. A transformer-based framework for scene text recognition
CN114898136A (zh) 一种基于特征自适应的小样本图像分类方法
CN117390506A (zh) 一种基于网格编码与TextRCNN的船舶路径分类方法
CN116561314B (zh) 基于自适应阈值选择自注意力的文本分类方法
CN116611517A (zh) 融合图嵌入和注意力的知识追踪方法
CN115348551A (zh) 一种轻量化业务识别方法、装置、电子设备及存储介质
CN114139655A (zh) 一种蒸馏式竞争学习的目标分类系统和方法
CN111563413A (zh) 一种基于混合双模型的年龄预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant