CN115131599B - 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 - Google Patents
一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 Download PDFInfo
- Publication number
- CN115131599B CN115131599B CN202210437273.7A CN202210437273A CN115131599B CN 115131599 B CN115131599 B CN 115131599B CN 202210437273 A CN202210437273 A CN 202210437273A CN 115131599 B CN115131599 B CN 115131599B
- Authority
- CN
- China
- Prior art keywords
- model
- student model
- sample
- student
- confrontation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法,该方法用来解决图像分类领域内知识蒸馏方法出现学生模型对抗鲁棒性学习不足的问题。该方法使学生模型的自然样本输出与对抗样本输出均向教师模型学习,还规定模型自然样本输出与针对其本身的对抗样本输出之间的距离度量为对抗偏差,将教师模型的对抗偏差作为额外蒸馏项传递给学生模型,提高学生模型的泛化性。本发明实现了将教师模型的分类准确性与对抗鲁棒性传递给了学生模型,使学生模型在进行图像分类任务时可以保证较高识别准确率,并更加有效地抵御图像对抗攻击。相比于其他方法,本方法在多个常见的图像分类数据集上取得良好效果。
Description
技术领域
本发明属于计算机深度学习领域,尤其涉及一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法。
背景技术
知识蒸馏是图像分类领域内的一种模型压缩技术,它通过让学生模型的输出模仿大型教师模型输出的方法,将教师模型的知识提取给学生模型,达到比学生模型自己训练更好的效果。然而随着对抗攻击等人工智能安全问题的出现,常用的模型都需要具备一定的抵御对抗攻击的鲁棒性,但是传统的知识蒸馏技术出现了无法将教师模型的对抗鲁棒性传递给学生模型的问题。因此如何在知识蒸馏过程中既能够使学生模型学习到准确性又能够学习到鲁棒性是亟需解决的。
发明内容
本发明的目的是为了提高图像分类领域中知识蒸馏方法的性能,解决蒸馏过程中学生模型无法较好获得教师模型的对抗鲁棒性的问题,提供一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法。
本发明的目的是通过以下技术方案来实现的:一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法,包括以下步骤:
步骤一:在输入的图像数据集上,对于教师模型进行对抗训练的预训练环节;对学生模型进行参数初始化;教师模型和学生模型均为图像分类模型;
步骤二:在步骤一输入的图像数据集上,使用对抗攻击方法,针对于教师模型与学生模型分别生成其各自的对抗样本;
步骤三:使用对抗偏差学习方法进行对抗鲁棒性知识蒸馏,使学生模型学习教师模型的知识,在训练中优化学生模型;
步骤四:将待识别图像输入步骤三优化后的学生模型,预测得到图像类别。
进一步地,步骤二中,获得对抗样本的方法是,使用投影梯度下降方法针对于模型攻击生成。
进一步地,步骤三通过以下子步骤来实现:
(3.1)使学生模型自然样本输出模仿教师模型自然样本输出,在损失函数中添加学生模型自然样本输出与教师模型自然样本输出的相对熵;
(3.2)使学生模型对抗样本输出模仿教师模型自然样本输出,在损失函数中添加学生模型对抗样本输出与教师模型自然样本输出的相对熵;
(3.3)计算学生模型自然样本输出与其自身对抗样本输出的差值为学生模型的对抗偏差,计算教师模型自然样本输出与其自身对抗样本输出的差值为教师模型的对抗偏差;
(3.4)使学生模型对抗偏差模仿教师模型对抗偏差,在损失函数中添加学生模型对抗偏差与教师模型对抗偏差的相对熵;
(3.5)为损失函数中的三项相对熵分配指定权重,优化对抗鲁棒性蒸馏效果;
(3.6)对于学生模型进行鲁棒性知识蒸馏训练,进行优化。
进一步地,步骤(3.5)中,学生模型的损失函数如下:
其中,KL(·)为相对熵函数,α、β、γ为权重;表示教师模型的第i个图像样本对应的输出,表示学生模型的第i个图像样本对应的输出,表示教师模型的第i个对抗样本对应的输出,表示学生模型的第i个对抗样本对应的输出。
进一步地,步骤(3.6)中,学生模型的优化函数为:
其中,L(·)表示学生模型的损失函数,CE(·)表示为交叉熵损失函数,为第i个图像样本,为学生模型的第i个对抗样本,为教师模型的第i个对抗样本,yi表示第i 个图像样本的真实的类别标签;W为学生模型的参数,N表示图像数据集中图像样本的数量; ||||p表示p-范数;∈表示距离上限。
进一步地,优化学生模型时使用梯度下降法进行优化,第T次迭代时损失函数L关于W 的偏导数为:
其中,n为梯度更新时输入的图片数量,则第T次更新的梯度为:
使用梯度下降进行优化更新学生模型:
WT=WT-1-μgT
其中,μ为学习率。
本发明的有益效果是:本发明实现了将教师模型的分类准确性与对抗鲁棒性传递给了学生模型,使得学生模型在保证较高分类准确率的情况下,能够更加有效地抵御图像对抗攻击,具有更好的泛化性,且在多个常见的图像分类数据集上取得良好的分类效果,准确率高。
附图说明
图1是本发明基于对抗偏差与鲁棒性知识蒸馏的图像分类方法的流程图;
图2是基于对抗偏差学习的鲁棒性知识蒸馏方法的示意图;
图3是所有对比方法训练过程中学生模型对抗样本输出与自然样本输出相对熵变化情况的折线图。
具体实施方式
下面根据附图详细说明本发明。
如图1所示,本发明一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法,包括以下步骤:
步骤一:在输入的图像数据集(包含多种不同类别的图像)上,对教师模型进行对抗训练的预训练环节,使教师模型获得一定的分类准确性与对抗鲁棒性。教师模型为图像分类模型,对输入的图像数据集进行分类。对学生模型进行参数初始化,以便训练优化。学生模型也是图像分类模型。
步骤二:将步骤一输入的图像数据集作为自然样本,并在自然样本上使用对抗攻击方法针对于教师模型与学生模型,分别生成教师模型的对抗样本和学生模型的对抗样本。
具体地,获得对抗样本的方法是,使用投影梯度下降方法(PGD)针对于模型攻击生成。
步骤三:使用对抗偏差(Adversarial Deviation)学习方法,进行对抗鲁棒性知识蒸馏,使学生模型学习教师模型的知识,在训练中优化学生模型,使其获得分类准确性与对抗鲁棒性。
具体地,步骤三是本发明的核心,如图2所示,包括以下子步骤:
3.1)使学生模型自然样本输出,模仿教师模型自然样本输出;在学生模型的损失函数中,添加学生模型自然样本输出与教师模型自然样本输出的相对熵。
3.2)使学生模型对抗样本输出,模仿教师模型自然样本输出;在学生模型的损失函数中,添加学生模型对抗样本输出与教师模型自然样本输出的相对熵。
3.3)计算学生模型自然样本输出与学生模型对抗样本输出的差值,作为学生模型的对抗偏差;计算教师模型自然样本输出与教师模型的对抗样本输出的差值,作为教师模型的对抗偏差。
3.4)使学生模型对抗偏差,模仿教师模型对抗偏差;在学生模型的损失函数中,添加学生模型对抗偏差与教师模型对抗偏差的相对熵。
3.5)为学生模型的损失函数中的三项相对熵,分配指定权重,优化对抗鲁棒性蒸馏效果,学生模型的损失函数L如下:
其中,KL(·)为相对熵函数,α、β、γ为权重;表示教师模型的第i个图像样本对应的输出,表示学生模型的第i个图像样本对应的输出,表示教师模型的第i个对抗样本对应的输出,表示学生模型的第i个对抗样本对应的输出。
3.6)对学生模型进行鲁棒性蒸馏训练,进行优化。学生模型的优化函数为:
其中,W为学生模型的参数,N表示图像数据集中图像样本的数量,L(·)表示学生模型的损失函数;为第i个图像样本(自然样本),为学生模型的第i个对抗样本,为教师模型的第i个对抗样本;yi表示第i个图像样本的真实的类别标签;CE(·)表示为交叉熵损失函数;||||p表示p-范数;∈表示距离上限。
具体地,优化学生模型时,使用梯度下降法进行优化,W为学生模型的参数,第T次迭代时,损失函数L关于W的偏导数为:
其中,n为梯度更新时输入的图片数量;Lj表示第j个图片输入得到的损失。
则第T次更新的梯度为:
其中,WT-1表示第T-1次迭代后学生模型的参数。
使用梯度下降进行优化更新学生模型:
WT=WT-1-μgT
其中,μ为学习率,其值大于零。WT表示第T次迭代后学生模型的参数。
步骤四:将待识别图像输入步骤三优化后的学生模型,预测得到图像类别。
以下结合具体实验来说明本发明的有效性。使用CIFAR10以及CIFAR100作为图像分类数据集进行实验,如表1所示。
表1:CIFAR10与CIFAR100图像数据集的详细信息
信息 | CIFAR10 | CIFAR100 |
图像类别数 | 10 | 100 |
图像大小 | 32px*32px | 32px*32px |
训练集数量 | 10*5000 | 100*500 |
测试集数量 | 10*1000 | 100*100 |
实验使用的教师模型与学生模型分别为WideResNet和ResNet18,选取对比的方法有AT (对抗训练)、ARD(对抗性鲁棒蒸馏)、IAD(自省性对抗性蒸馏)、RSLAD(鲁棒软标签对抗性蒸馏)和本发明,使用到的对抗攻击测试方法有无攻击、FGSM(快速梯度标志攻击)、PGD(投影梯度下降)、CW(基于优化的攻击),得到的结果如表2所示。
表2:在CIFAR10、CIFAR100数据集上各个对比方法在不同对抗攻击下的分类准确率(%)
从表2可以看出,无论是十分类还是百分类的图像分类问题,本发明使得学生模型在保持较高分类准确率的情况下,依然有着较为出色的抵御多种对抗攻击的鲁棒性,获得的性能明显优于现有其余方法。
从图3可以看到,相较于其他方法,本发明在训练过程中使学生模型自然样本输出与学生模型对抗样本输出之间的相对熵保持在较低的水平,说明本发明使得学生模型在面对对抗攻击样本时输出的变化是在较小范围内的,充分体现出学生模型的泛化性有所提高。
如上所述,本发明提出的基于对抗偏差学习的鲁棒性知识蒸馏方法,使学生模型在图像分类任务里,更好地从教师模型中学习到分类准确率与对抗鲁棒性。
本发明并不限于上述实施方式,采用与本发明上述实施方式相同或近似的方式,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,均在本发明专利的保护范围之内。
Claims (5)
1.一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法,其特征在于,包括以下步骤:
步骤一:在输入的图像数据集上,对于教师模型进行对抗训练的预训练环节;对学生模型进行参数初始化;教师模型和学生模型均为图像分类模型;
步骤二:在步骤一输入的图像数据集上,使用对抗攻击方法,针对于教师模型与学生模型分别生成其各自的对抗样本;
步骤三:使用对抗偏差学习方法进行对抗鲁棒性知识蒸馏,使学生模型学习教师模型的知识,在训练中优化学生模型,包括以下子步骤:
(3.1)使学生模型自然样本输出模仿教师模型自然样本输出,在损失函数中添加学生模型自然样本输出与教师模型自然样本输出的相对熵;
(3.2)使学生模型对抗样本输出模仿教师模型自然样本输出,在损失函数中添加学生模型对抗样本输出与教师模型自然样本输出的相对熵;
(3.3)计算学生模型自然样本输出与其自身对抗样本输出的差值为学生模型的对抗偏差,计算教师模型自然样本输出与其自身对抗样本输出的差值为教师模型的对抗偏差;
(3.4)使学生模型对抗偏差模仿教师模型对抗偏差,在损失函数中添加学生模型对抗偏差与教师模型对抗偏差的相对熵;
(3.5)为损失函数中的三项相对熵分配指定权重,优化对抗鲁棒性蒸馏效果;
(3.6)对于学生模型进行鲁棒性知识蒸馏训练,进行优化;
步骤四:将待识别图像输入步骤三优化后的学生模型,预测得到图像类别。
2.根据权利要求1所述基于对抗偏差与鲁棒性知识蒸馏的图像分类方法,其特征在于,步骤二中,获得对抗样本的方法是,使用投影梯度下降方法针对于模型攻击生成。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210437273.7A CN115131599B (zh) | 2022-04-19 | 2022-04-19 | 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210437273.7A CN115131599B (zh) | 2022-04-19 | 2022-04-19 | 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115131599A CN115131599A (zh) | 2022-09-30 |
CN115131599B true CN115131599B (zh) | 2023-04-18 |
Family
ID=83376343
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210437273.7A Active CN115131599B (zh) | 2022-04-19 | 2022-04-19 | 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115131599B (zh) |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114049513A (zh) * | 2021-09-24 | 2022-02-15 | 中国科学院信息工程研究所 | 一种基于多学生讨论的知识蒸馏方法和系统 |
CN114170332A (zh) * | 2021-11-27 | 2022-03-11 | 北京工业大学 | 一种基于对抗蒸馏技术的图像识别模型压缩方法 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109063456B (zh) * | 2018-08-02 | 2021-10-08 | 浙江大学 | 图像型验证码的安全性检测方法及系统 |
CN111461226A (zh) * | 2020-04-01 | 2020-07-28 | 深圳前海微众银行股份有限公司 | 对抗样本生成方法、装置、终端及可读存储介质 |
EP3910479A1 (en) * | 2020-05-15 | 2021-11-17 | Deutsche Telekom AG | A method and a system for testing machine learning and deep learning models for robustness, and durability against adversarial bias and privacy attacks |
CN114219043A (zh) * | 2021-12-21 | 2022-03-22 | 哈尔滨工业大学(深圳) | 基于对抗样本的多教师知识蒸馏方法及装置 |
-
2022
- 2022-04-19 CN CN202210437273.7A patent/CN115131599B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114049513A (zh) * | 2021-09-24 | 2022-02-15 | 中国科学院信息工程研究所 | 一种基于多学生讨论的知识蒸馏方法和系统 |
CN114170332A (zh) * | 2021-11-27 | 2022-03-11 | 北京工业大学 | 一种基于对抗蒸馏技术的图像识别模型压缩方法 |
Also Published As
Publication number | Publication date |
---|---|
CN115131599A (zh) | 2022-09-30 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110377710B (zh) | 一种基于多模态融合的视觉问答融合增强方法 | |
CN110222770B (zh) | 一种基于组合关系注意力网络的视觉问答方法 | |
CN112446423B (zh) | 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法 | |
Sonkar et al. | qdkt: Question-centric deep knowledge tracing | |
CN108416370A (zh) | 基于半监督深度学习的图像分类方法、装置和存储介质 | |
CN109389166A (zh) | 基于局部结构保存的深度迁移嵌入聚类机器学习方法 | |
CN109308485A (zh) | 一种基于字典域适应的迁移稀疏编码图像分类方法 | |
CN111931814B (zh) | 一种基于类内结构紧致约束的无监督对抗域适应方法 | |
CN114170461B (zh) | 基于特征空间重整化的师生架构含噪声标签图像分类方法 | |
CN113361685B (zh) | 一种基于学习者知识状态演化表示的知识追踪方法及系统 | |
CN111401156B (zh) | 基于Gabor卷积神经网络的图像识别方法 | |
CN112115967B (zh) | 一种基于数据保护的图像增量学习方法 | |
Ding et al. | Why Deep Knowledge Tracing Has Less Depth than Anticipated. | |
CN111241933A (zh) | 一种基于通用对抗扰动的养猪场目标识别方法 | |
CN114385801A (zh) | 一种基于分层细化lstm网络的知识追踪方法及系统 | |
CN114528928A (zh) | 一种基于Transformer的二训练图像分类算法 | |
CN113344053A (zh) | 一种基于试题异构图表征与学习者嵌入的知识追踪方法 | |
CN111274424A (zh) | 一种零样本图像检索的语义增强哈希方法 | |
CN116824216A (zh) | 一种无源无监督域适应图像分类方法 | |
CN116935447A (zh) | 基于自适应师生结构的无监督域行人重识别方法及系统 | |
CN115131599B (zh) | 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法 | |
CN116433909A (zh) | 基于相似度加权多教师网络模型的半监督图像语义分割方法 | |
CN116431821A (zh) | 基于常识感知的知识图谱补全方法及问答系统 | |
CN113553402B (zh) | 一种基于图神经网络的考试阅读理解自动问答方法 | |
CN113379037B (zh) | 一种基于补标记协同训练的偏多标记学习方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |