CN115019106A - 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置 - Google Patents

基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置 Download PDF

Info

Publication number
CN115019106A
CN115019106A CN202210733867.2A CN202210733867A CN115019106A CN 115019106 A CN115019106 A CN 115019106A CN 202210733867 A CN202210733867 A CN 202210733867A CN 115019106 A CN115019106 A CN 115019106A
Authority
CN
China
Prior art keywords
robust
domain
target domain
model
unsupervised
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210733867.2A
Other languages
English (en)
Inventor
肖遥
罗彬�
陈宇恒
魏朋旭
林倞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sun Yat Sen University
Original Assignee
Sun Yat Sen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sun Yat Sen University filed Critical Sun Yat Sen University
Priority to CN202210733867.2A priority Critical patent/CN115019106A/zh
Publication of CN115019106A publication Critical patent/CN115019106A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置,方法包括下述步骤:获取无监督目标域自然样本集;构建鲁棒无监督域自适应图像分类框架,包括非鲁棒目标域教师模型和鲁棒目标域学生模型;使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练;构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,输出图像分类结果。本方法将知识蒸馏和对抗训练结合起来,在源域数据完全缺失的情况下,只使用非鲁棒源域模型获得目标域上的鲁棒模型,在保持对目标域自然样本分类性能的同时,有效地提升了对目标域对抗样本的分类性能和模型的鲁棒性。

Description

基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置
技术领域
本发明属于计算机图像分类的技术领域,具体涉及一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置。
背景技术
无监督域自适应学习能将知识从标记的源域转移到未标记的目标域,在标签稀缺或注释繁琐的场景中推进模型转移。然而,由于数据隐私和安全问题,在域适应阶段可能无法访问源域数据,并且同时使用目标域数据与大规模源数据训练目标域上模型,在计算上也很棘手。因此,无源域数据无监督域自适应学习应运而生,如一种无源域数据的无监督域自适应学习模型(SHOT模型)。尽管现有研究取得了显著的进展,但大多数现有的无监督域自适应或无源域数据无监督域自适应方法都忽略了深度学习模型的鲁棒性,这些模型对输入图片中难以察觉的扰动很敏感并且在对抗样本面前表现十分脆弱;特别地,由于它们是在没有对目标域进行精确监督的情况下进行的乐观训练,因此(无源域数据的)无监督域自适应模型可能对这种扰动更加敏感,加剧了模型的脆弱性并对安全敏感的应用程序构成巨大威胁。
现有研究中,一方面通过从鲁棒的源域模型或鲁棒的预训练模型转移鲁棒性来训练鲁棒的无监督域自适应模型,以此提高鲁棒性;但在许多实际应用中,假设鲁棒源域模型或鲁棒预训练的可用性是不切实际的,因此很难直接应用到分类任务中。另一方面提高模型鲁棒性的方法是进行对抗训练,但是对抗训练会导致严重的过拟合现象,极大地影响了分类结果有效性。
发明内容
本发明的主要目的在于克服现有技术的缺点与不足,提供一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置,该方法将知识蒸馏和对抗训练结合起来,在源域数据完全缺失的情况下,只使用非鲁棒源域模型获得目标域上的鲁棒模型,在保持对目标域自然样本分类性能的同时,有效地提升了对目标域对抗样本的分类性能和模型的鲁棒性。
为了达到上述目的,本发明采用以下技术方案:
一方面,本发明提供了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,包括下述步骤:
获取无监督目标域自然样本集;
构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
作为优选的技术方案,所述鲁棒无监督域自适应图像分类框架的目标函数是基于间隔差异散度在无源域数据条件下进行推导而得,具体为:
根据间隔学习理论可得,对于任意一个得分函数f,都满足:
Figure BDA0003714892410000021
其中
Figure BDA0003714892410000022
是一个理想间隔损失,
Figure BDA0003714892410000023
是得分函数f在目标对抗域
Figure BDA0003714892410000024
上基于0-1损失的分类误差,
Figure BDA0003714892410000025
是得分函数f在源域
Figure BDA0003714892410000026
上以常数ρ为间隔的分类误差,
Figure BDA0003714892410000027
是以常数ρ为间隔的源域
Figure BDA0003714892410000028
和目标域
Figure BDA0003714892410000029
的间隔差异散度,
Figure BDA00037148924100000210
是以常数ρ为间隔的目标域
Figure BDA00037148924100000211
和目标对抗域
Figure BDA00037148924100000212
的间隔差异散度;
令在目标对抗域
Figure BDA00037148924100000213
上基于0-1损失的分类误差
Figure BDA00037148924100000214
达到最小的最优得分函数f,故根据(1)式的右端项得:
Figure BDA00037148924100000215
在源域数据完全缺失的条件下可知,得分函数f在源域
Figure BDA00037148924100000216
上以常数ρ为间隔的分类误差
Figure BDA00037148924100000217
是常数,故根据(2)式推导出鲁棒无监督域自适应图像分类框架的目标函数为:
Figure BDA00037148924100000218
其中,
Figure BDA00037148924100000219
为非鲁棒目标域教师模型的目标函数,
Figure BDA00037148924100000220
为鲁棒目标域学生模型的目标函数。
作为优选的技术方案,所述得到训练好的非鲁棒目标域教师模型,具体为:
采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
作为优选的技术方案,所述得到训练好的鲁棒目标域学生模型,具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
作为优选的技术方案,基于面向鲁棒性与准确性权衡的对抗训练方法TRADES进行对抗蒸馏训练;
所述对抗样本的生成公式为:
Figure BDA0003714892410000031
所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
Figure BDA0003714892410000032
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure BDA0003714892410000033
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
作为优选的技术方案,使用投影梯度下降的对抗训练方法PGD进行对抗蒸馏训练;
所述对抗样本的生成公式为:
Figure BDA0003714892410000034
所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
Figure BDA0003714892410000035
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure BDA0003714892410000036
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
另一方面,本发明还提供了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,其特征在于,应用于上述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,包括数据获取模块、分类框架构建模块、教师模型训练模块及学生模型训练模块;
所述数据获取模块用于获取无监督目标域自然样本集;
所述分类框架构建模块用于构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
所述教师模型训练模块用于使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
所述学生模型训练模块用于基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
作为优选的技术方案,所述教师模型训练模块具体为:
采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
作为优选的技术方案,所述学生模型训练模块具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
还一方面,本发明提供了一种计算机可读存储介质,存储有程序,其特征在于,所述程序被处理器执行时,实现上述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法。
本发明与现有技术相比,具有如下优点和有益效果:
1、相比一般的无监督域自适应图像分类方法,本方法在源域数据完全缺失的情况下,使用非鲁棒源域模型获得目标域上的鲁棒模型,通过知识蒸馏+对抗学习的方法训练模型,能够显著提升模型的鲁棒性。
2、本方法不需要使用源域数据或预训练的鲁棒模型,只使用非鲁棒源域模型仍然能获得目标域上的鲁棒模型,使得许多源域模型适用于现实世界中的鲁棒无监督域迁移自适应应用程序中,具有广泛的适用性。
3、本方法采用两种的无监督域自适应的对抗蒸馏方法来进行训练,减轻了过拟合的问题,能够提高模型的类别判别能力和图像分类性能,并且拥有更高的准确率。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例中基于对抗蒸馏的鲁棒无监督域自适应图像分类方法的流程图;
图2为本发明实施例中基于对抗蒸馏的鲁棒无监督域自适应图像分类方法的结构图;
图3为本发明实施例中基于对抗蒸馏的鲁棒无监督域自适应图像分类系统的结构图;
图4为本发明实施例中计算机可读存储介质的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
在本申请中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本申请所描述的实施例可以与其它实施例相结合。
知识蒸馏是模型压缩的一种常用的方法,不同于模型压缩中的剪枝和量化,知识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练这个小模型,以期达到更好的性能和精度。这个大模型称之为教师模型,小模型称之为学生模型。来自教师模型输出的监督信息称之为知识,而学生模型学习迁移来自教师模型的监督信息的过程称之为蒸馏。在迁移学习领域中,知识蒸馏是一种常见的迁移学习方法。
因此本发明将知识蒸馏和对抗训练结合起来,使用对抗蒸馏法来训练鲁棒目标域学生模型,能够在源域数据完全缺失的情况下,只使用非鲁棒源域模型仍然能获得目标域上的鲁棒模型,使得许多源域模型适用于现实世界中的鲁棒无监督域迁移自适应应用程序中。
如图1、图2所示,本实施例基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,包括下述步骤:
S1、获取无监督目标域自然样本集;
S2、构建鲁棒无监督域自适应图像分类框架,包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
鲁棒无监督域自适应图像分类框架构建后,使用间隔学习理论来估计该模型在无源域数据情况下,目标域对抗训练过程中学习误差的上界,启发式地设计在无源域数据情况下更合理有效的学习目标;本实施例中的鲁棒无监督自适应图像分类框架的目标函数,是基于间隔差异散度在无源域数据条件下进行推导而得,具体为:
根据间隔学习理论可得,对于任意一个得分函数f,都满足:
Figure BDA0003714892410000061
其中
Figure BDA0003714892410000062
是一个理想间隔损失,
Figure BDA0003714892410000063
是得分函数f在目标对抗域
Figure BDA0003714892410000064
上基于0-1损失的分类误差,
Figure BDA0003714892410000065
是得分函数f在源域
Figure BDA0003714892410000066
上以常数ρ为间隔的分类误差,
Figure BDA0003714892410000067
是以常数ρ为间隔的源域
Figure BDA0003714892410000068
和目标域
Figure BDA0003714892410000069
的间隔差异散度,
Figure BDA00037148924100000610
是以常数ρ为间隔的目标域
Figure BDA00037148924100000611
和目标对抗域
Figure BDA00037148924100000612
的间隔差异散度;
已知有如下引理,即对任意一个得分函数f,都满足
Figure BDA00037148924100000613
并且
Figure BDA00037148924100000614
满足三角不等式
Figure BDA00037148924100000615
理由如下:
Figure BDA00037148924100000616
则(1)式定理得证;
由于本发明的最终目标是通过对抗蒸馏训练来获得具有良好图像判别能力的鲁棒目标域学生模型,故需令在目标对抗域
Figure BDA0003714892410000071
上基于0-1损失的分类误差
Figure BDA0003714892410000072
达到最小的最优得分函数f,故根据(1)式的右端项得:
Figure BDA0003714892410000073
在源域数据完全缺失的条件下可知,得分函数f在源域
Figure BDA0003714892410000074
上以常数ρ为间隔的分类误差
Figure BDA0003714892410000075
是常数,故根据(2)式推导出鲁棒无监督域自适应图像分类框架的目标函数为:
Figure BDA0003714892410000076
其中,
Figure BDA0003714892410000077
为非鲁棒目标域教师模型的目标函数,
Figure BDA0003714892410000078
为鲁棒目标域学生模型的目标函数。
S3、使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
本发明鲁棒无监督域自适应图像分类框架的学习分为两部分,如图2所示,第一部分为基于无监督域自适应的迁移学习,第二部分为无监督域自适应的对抗蒸馏学习;
步骤S3即为本实施例中鲁棒无监督域自适应图像分类框架的第一部分,其中训练好的非鲁棒目标域教师模型是根据非鲁棒目标域教师模型的学习目标函数
Figure BDA0003714892410000079
基于无监督域自适应的迁移学习方法训练获得,用于在源域数据完全缺失的前提下提高目标域教师模型对目标域自然样本的类别判别性,具体为:
S301、由于数据隐私和安全问题,在领域适应阶段无法访问源域数据,故采用现有任意一种不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
S302、使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
S303、将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,训练损失函数表示为:
Figure BDA00037148924100000710
S4、基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
步骤S4为本实施例中鲁棒无监督域自适应图像分类框架的第二部分,其中训练好的鲁棒目标域学生模型是根据鲁棒目标域学生模型的学习目标函数
Figure BDA00037148924100000711
基于无监督域自适应的对抗蒸馏学习方法训练获得,用于提高鲁棒目标域学生模型对目标域上的对抗样本的类别判别性,具体为:
S401、采用和非鲁棒目标域教师模型相同的结构(包括骨干网络和分类器)构造鲁棒目标域学生模型;
S402、根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
S403、进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
本实施例中,无监督域自适应的对抗蒸馏学习方法有两种方式,分别为:
1、基于面向鲁棒性与准确性权衡的对抗训练方法TRADES进行对抗蒸馏训练,具体为:
采用和非鲁棒目标域教师模型相同的结构(包括骨干网络和分类器)构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本,生成公式为:
Figure BDA0003714892410000081
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果;
根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立对抗蒸馏损失函数,表示为:
Figure BDA0003714892410000082
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure BDA0003714892410000083
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
2、使用投影梯度下降的对抗训练方法PGD进行对抗蒸馏训练,具体为:
采用和非鲁棒目标域教师模型相同的结构(包括骨干网络和分类器)构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本,生成公式为:
Figure BDA0003714892410000091
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果;
根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立对抗蒸馏损失函数,表示为:
Figure BDA0003714892410000092
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure BDA0003714892410000093
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
本发明基于无源域数据的无监督领域自适应的非鲁棒目标域教师模型学习,在源域数据完全缺失的情况下使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,然后通过无监督领域自适应算法迭代训练目标域教师模型,提高对目标域自然样本的类别判别性;然后在目标域上进行无监督对抗蒸馏学习,固定非鲁棒目标域教师模型的参数,使用无监督目标域数据迭代训练一个鲁棒的目标域学生模型。在每轮训练迭代中,根据学生模型的模型参数信息对目标域的每个自然样本生成相应的对抗样本,然后根据教师模型对目标域的自然样本的输出和学生模型对相应的对抗样本的输出建立蒸馏损失函数,据此更新学生模型的参数,提高学生模型对目标域对抗样本的类别判别性和图像分类性能。
本实施例中,基于对抗蒸馏的鲁棒无监督域自适应图像分类方法的实验细节如下:
1、本发明在四个自然样本数据集上进行大量实验并得到有效验证,分别为:Office-31数据集,Office-Home数据集,PACS数据集和VisDA-C数据集;
2、在Office-31,Office-home和PACS数据集上,非鲁棒目标域教师模型和鲁棒目标域学生模型都采用ResNet-50作为骨干网络,而在VisDA-C数据集上非鲁棒目标域教师模型和鲁棒目标域学生模型都采用ResNet-101作为骨干网络;分类器均为单个全连接层;
3、使用随机梯度下降法,初始学习率为0.1并使用余弦退火方法逐渐降低;在VisDA-C数据集上训练的迭代轮数为50轮,在其余三个数据集上的训练迭代轮数为100轮;对抗样本的范围∈为8/255,范数p为∞;在TRADES方法中β设置为1;
4、本发明的实验结果如下表,其中SHOT是基准模型,TRADES-AD和PGD-AD是本发明的两种对抗蒸馏训练方法;分类精度是指得到的目标域模型对目标域自然测试样本的分类准确率;鲁棒性是指目标域模型对目标域对抗测试样本的分类准确率。
Figure BDA0003714892410000101
需要说明的是,对于前述的各方法实施例,为了简便描述,将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明并不受所描述的动作顺序的限制,因为依据本发明,某些步骤可以采用其它顺序或者同时进行。
基于与上述实施例中的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法相同的思想,本发明还提供基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,该系统可用于执行上述基于对抗蒸馏的鲁棒无监督域自适应图像分类方法。为了便于说明,基于对抗蒸馏的鲁棒无监督域自适应图像分类系统实施例的结构示意图中,仅仅示出了与本发明实施例相关的部分,本领域技术人员可以理解,图示结构并不构成对装置的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图3所示,本发明另一个实施例提供了一种基于对抗蒸馏的鲁棒无监督域自适应图像分类系统100,包括数据获取模块101、分类框架构建模块102、教师模型训练模块103及学生模型训练模块104;
数据获取模块101用于获取无监督目标域自然样本集;
分类框架构建模块102用于构建鲁棒无监督域自适应图像分类框架,包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
教师模型训练模块103用于使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
学生模型训练模块104用于基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
具体的,教师模型训练模块103具体为:
采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
具体的,学生模型训练模块104具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
需要说明的是,本发明的基于对抗蒸馏的鲁棒无监督域自适应图像分类系统与本发明的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法一一对应,在上述基于对抗蒸馏的鲁棒无监督域自适应图像分类方法的实施例阐述的技术特征及其有益效果均适用于基于对抗蒸馏的鲁棒无监督域自适应图像分类系统的实施例中,具体内容可参见本发明方法实施例中的叙述,此处不再赘述,特此声明。
此外,上述实施例的基于对抗蒸馏的鲁棒无监督域自适应图像分类系统的实施方式中,各程序模块的逻辑划分仅是举例说明,实际应用中可以根据需要,例如出于相应硬件的配置要求或者软件的实现的便利考虑,将上述功能分配由不同的程序模块完成,即将所述基于对抗蒸馏的鲁棒无监督域自适应图像分类系统的内部结构划分成不同的程序模块,以完成以上描述的全部或者部分功能。
如图4所示,在一个实施例中,提供了一种计算机可读存储介质200,存储有程序于存储器201中,所述程序被处理器202执行时,实现所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,具体为:
获取无监督目标域自然样本集;
构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类模型包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一非易失性计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
上述实施例为本发明较佳的实施方式,但本发明的实施方式并不受上述实施例的限制,其他的任何未背离本发明的精神实质与原理下所作的改变、修饰、替代、组合、简化,均应为等效的置换方式,都包含在本发明的保护范围之内。

Claims (10)

1.基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,包括下述步骤:
获取无监督目标域自然样本集;
构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
2.根据权利要求1所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述鲁棒无监督域自适应图像分类框架的目标函数是基于间隔差异散度在无源域数据条件下进行推导而得,具体为:
根据间隔学习理论可得,对于任意一个得分函数f,都满足:
Figure FDA0003714892400000011
其中
Figure FDA0003714892400000012
是一个理想间隔损失,
Figure FDA0003714892400000013
是得分函数f在目标对抗域
Figure FDA0003714892400000014
上基于0-1损失的分类误差,
Figure FDA0003714892400000015
是得分函数f在源域
Figure FDA0003714892400000016
上以常数ρ为间隔的分类误差,
Figure FDA0003714892400000017
是以常数ρ为间隔的源域
Figure FDA0003714892400000018
和目标域
Figure FDA0003714892400000019
的间隔差异散度,
Figure FDA00037148924000000110
是以常数ρ为间隔的目标域
Figure FDA00037148924000000111
和目标对抗域
Figure FDA00037148924000000112
的间隔差异散度;
令在目标对抗域
Figure FDA00037148924000000113
上基于0-1损失的分类误差
Figure FDA00037148924000000114
达到最小的最优得分函数f,故根据(1)式的右端项得:
Figure FDA00037148924000000115
在源域数据完全缺失的条件下可知,得分函数f在源域
Figure FDA00037148924000000116
上以常数ρ为间隔的分类误差
Figure FDA00037148924000000117
是常数,故根据(2)式推导出鲁棒无监督域自适应图像分类框架的目标函数为:
Figure FDA00037148924000000118
其中,
Figure FDA00037148924000000119
为非鲁棒目标域教师模型的目标函数,
Figure FDA00037148924000000120
为鲁棒目标域学生模型的目标函数。
3.根据权利要求2所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述得到训练好的非鲁棒目标域教师模型,具体为:
采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
4.根据权利要求3所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,所述得到训练好的鲁棒目标域学生模型,具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
5.根据权利要求4所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,基于面向鲁棒性与准确性权衡的对抗训练方法TRADES进行对抗蒸馏训练;
所述对抗样本的生成公式为:
Figure FDA0003714892400000021
所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
Figure FDA0003714892400000022
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure FDA0003714892400000023
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
6.根据权利要求4所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,其特征在于,使用投影梯度下降的对抗训练方法PGD进行对抗蒸馏训练;
所述对抗样本的生成公式为:
Figure FDA0003714892400000024
所述对抗蒸馏损失函数根据非鲁棒目标域教师模型的输出和鲁棒目标域学生模型的输出进行建立,表示为:
Figure FDA0003714892400000025
其中φ表示鲁棒目标域学生模型,φT表示非鲁棒目标域教师模型,x是无监督目标域自然样本集中的某一自然样本,x'是x对应生成的对抗样本,
Figure FDA0003714892400000031
是KL散度损失函数,β是常系数,p是p-范数,∈是常数范围。
7.基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,其特征在于,应用于权利要求1-6中任一项所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法,包括数据获取模块、分类框架构建模块、教师模型训练模块及学生模型训练模块;
所述数据获取模块用于获取无监督目标域自然样本集;
所述分类框架构建模块用于构建鲁棒无监督域自适应图像分类框架;所述鲁棒无监督域自适应图像分类框架包括非鲁棒目标域教师模型和鲁棒目标域学生模型;
所述教师模型训练模块用于使用预训练的非鲁棒源域模型对非鲁棒目标域教师模型的参数进行初始化,在无监督目标域自然样本集上进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型;
所述学生模型训练模块用于基于训练好的非鲁棒目标域教师模型构造鲁棒目标域学生模型,在无监督目标域自然样本集上进行对抗蒸馏训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
8.根据权利要求7中所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,其特征在于,所述教师模型训练模块具体为:
采用不使用源域数据的无监督域自适应学习模型进行标准的无监督域自适应学习,获得非鲁棒目标域教师模型;
使用预训练的非鲁棒源域模型的参数对非鲁棒目标域教师模型的参数进行初始化;
将无监督目标域自然样本集输入非鲁棒目标域教师模型中进行端到端的迭代训练,得到训练好的非鲁棒目标域教师模型。
9.根据权利要求8所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类系统,其特征在于,所述学生模型训练模块具体为:
采用和非鲁棒目标域教师模型相同的结构构造鲁棒目标域学生模型;
根据鲁棒目标域学生模型的参数信息,对无监督目标域自然样本集中每个自然样本生成对应的对抗样本;
进行对抗蒸馏训练,每次迭代训练过程中,将非鲁棒目标域教师模型的参数固定,对鲁棒目标域学生模型进行端到端的训练,得到训练好的鲁棒目标域学生模型并输出图像分类结果。
10.一种计算机可读存储介质,存储有程序,其特征在于,所述程序被处理器执行时,实现权利要求1-6任一项所述的基于对抗蒸馏的鲁棒无监督域自适应图像分类方法。
CN202210733867.2A 2022-06-27 2022-06-27 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置 Pending CN115019106A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210733867.2A CN115019106A (zh) 2022-06-27 2022-06-27 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210733867.2A CN115019106A (zh) 2022-06-27 2022-06-27 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置

Publications (1)

Publication Number Publication Date
CN115019106A true CN115019106A (zh) 2022-09-06

Family

ID=83076531

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210733867.2A Pending CN115019106A (zh) 2022-06-27 2022-06-27 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置

Country Status (1)

Country Link
CN (1) CN115019106A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115186773A (zh) * 2022-09-13 2022-10-14 杭州涿溪脑与智能研究所 一种无源的主动领域自适应模型训练方法及装置
CN116306868A (zh) * 2023-03-01 2023-06-23 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备
CN116543237A (zh) * 2023-06-27 2023-08-04 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 无源域无监督域适应的图像分类方法、系统、设备及介质
CN116758353A (zh) * 2023-06-20 2023-09-15 大连理工大学 基于域特定信息滤除的遥感图像目标分类方法
CN116935188A (zh) * 2023-09-15 2023-10-24 腾讯科技(深圳)有限公司 模型训练方法、图像识别方法、装置、设备及介质

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115186773A (zh) * 2022-09-13 2022-10-14 杭州涿溪脑与智能研究所 一种无源的主动领域自适应模型训练方法及装置
CN115186773B (zh) * 2022-09-13 2022-12-09 杭州涿溪脑与智能研究所 一种无源的主动领域自适应模型训练方法及装置
CN116306868A (zh) * 2023-03-01 2023-06-23 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备
CN116306868B (zh) * 2023-03-01 2024-01-05 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备
CN116758353A (zh) * 2023-06-20 2023-09-15 大连理工大学 基于域特定信息滤除的遥感图像目标分类方法
CN116758353B (zh) * 2023-06-20 2024-01-23 大连理工大学 基于域特定信息滤除的遥感图像目标分类方法
CN116543237A (zh) * 2023-06-27 2023-08-04 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 无源域无监督域适应的图像分类方法、系统、设备及介质
CN116543237B (zh) * 2023-06-27 2023-11-28 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 无源域无监督域适应的图像分类方法、系统、设备及介质
CN116935188A (zh) * 2023-09-15 2023-10-24 腾讯科技(深圳)有限公司 模型训练方法、图像识别方法、装置、设备及介质
CN116935188B (zh) * 2023-09-15 2023-12-26 腾讯科技(深圳)有限公司 模型训练方法、图像识别方法、装置、设备及介质

Similar Documents

Publication Publication Date Title
CN115019106A (zh) 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置
US20210295162A1 (en) Neural network model training method and apparatus, computer device, and storage medium
CN110321926B (zh) 一种基于深度残差修正网络的迁移方法及系统
US11348249B2 (en) Training method for image semantic segmentation model and server
CN112949837A (zh) 一种基于可信网络的目标识别联邦深度学习方法
CN109271958B (zh) 人脸年龄识别方法及装置
WO2022105123A1 (zh) 文本分类的方法、话题生成的方法、装置、设备及介质
Wu et al. Federated unlearning: Guarantee the right of clients to forget
Li et al. A hybrid imputation approach for microarray missing value estimation
Yu et al. Predicting protein complex in protein interaction network-a supervised learning based method
CN111079946A (zh) 模型训练方法、成员探测装置的训练方法及其系统
Reyes et al. Precision-weighted federated learning
CN111881737A (zh) 年龄预测模型的训练方法及装置、年龄预测方法及装置
CN113239168A (zh) 一种基于知识图谱嵌入预测模型的可解释性方法和系统
CN117201122A (zh) 基于视图级图对比学习的无监督属性网络异常检测方法及系统
CN113408706A (zh) 训练用户兴趣挖掘模型、用户兴趣挖掘的方法和装置
WO2023178793A1 (zh) 双视角图神经网络模型的训练方法、装置、设备及介质
Li et al. Feddkw–federated learning with dynamic kullback–leibler-divergence weight
Suri et al. Dissecting distribution inference
US20210073662A1 (en) Machine Learning Systems and Methods for Performing Entity Resolution Using a Flexible Minimum Weight Set Packing Framework
CN114386604A (zh) 基于多教师模型的模型蒸馏方法、装置、设备及存储介质
Malinowski Set-valued and fuzzy stochastic integral equations driven by semimartingales under Osgood condition
Sheth et al. Monte carlo structured svi for two-level non-conjugate models
Zhang et al. Cross-domain recommendation with multi-auxiliary domains via consistent and selective cluster-level knowledge transfer
Wang et al. Variance of the gradient also matters: Privacy leakage from gradients

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination