CN114842257A - 一种基于多模型对抗蒸馏的鲁棒性图像分类方法 - Google Patents

一种基于多模型对抗蒸馏的鲁棒性图像分类方法 Download PDF

Info

Publication number
CN114842257A
CN114842257A CN202210488306.0A CN202210488306A CN114842257A CN 114842257 A CN114842257 A CN 114842257A CN 202210488306 A CN202210488306 A CN 202210488306A CN 114842257 A CN114842257 A CN 114842257A
Authority
CN
China
Prior art keywords
model
loss
training
distillation
student
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210488306.0A
Other languages
English (en)
Inventor
宣琦
王张伟
陈壮志
徐东伟
凌书扬
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University of Technology ZJUT
Original Assignee
Zhejiang University of Technology ZJUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University of Technology ZJUT filed Critical Zhejiang University of Technology ZJUT
Priority to CN202210488306.0A priority Critical patent/CN114842257A/zh
Publication of CN114842257A publication Critical patent/CN114842257A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明提供一种基于多模型对抗蒸馏的鲁棒性图像分类方法,包括以下步骤:S1:获取数据集,对复杂模型预训练得到模型T2;S2:根据训练数据集,通过对抗样本生成方法生成对应对抗样本数据;S3:将对抗样本输入复杂模型进行对抗训练,得到模型T2;S4:选择与复杂模型任务相同的轻量化模型作为学生模型S,通过多模型的知识蒸馏框架对学生模型进行蒸馏训练。本发明通过对抗训练和知识蒸馏的方法,实现了一种基于多模型对抗蒸馏的鲁棒性图像分类方法,实现了通过知识蒸馏使学生充分学习不同教师模型的特性,在提高对抗鲁棒性时兼顾精度。

Description

一种基于多模型对抗蒸馏的鲁棒性图像分类方法
技术领域
本发明涉及一种鲁棒性的图像分类方法,具体涉及基于深度神经网络的图像分类方法,属于深度学习、人工智能领域。
背景技术
近年来,随着深度学习技术的迅速发展,多层神经网络结构层出不穷,深度神经网络在图像分类任务中取得了前所未有的成绩。神经网络模型设计应该在模型性能以及模型的复杂程度之间达到良好的权衡。然而在实践的过程中,研究者难以确定恰到好处的平衡点,所以更倾向于选择过度参数化、有足够表达能力而且易于优化的神经网络模型。虽然神经网络随着深度增加以及结构更复杂能够拥有更好的表达能力,但是同时也面临着更多的资源消耗。如果考虑到实际落地应用将其移植到边缘端或者移动设备上,它将会受到占用内存大、高计算量、高能耗等诸多方面的约束。为了解决深度神经网络巨大的内存和计算需求严重阻碍了其在资源有限的设备中部署的问题,减小模型内存以及推理训练的压力,知识蒸馏(参考文献[1]:Geoffrey Hinton,Oriol Vinyals and Jeff Dean,Distilling theKnowledge in a Neural Network.NIPS Deep Learning Workshop,2014,即GeoffreyHinton,Oriol Vinyals and Jeff Dean,在神经网络中的知识蒸馏,NIPS Deep LearningWorkshop,2014.)作为将大模型压缩成小模型的有效方法脱颖而出。知识蒸馏与网络剪枝和参数量化同为模型轻量化的主流方法。
同时,随着深度神经网络模型的不断发展,其安全性也受到严重威胁,对抗攻击算法的出现对深度神经网络的应用造成了严重威胁。面向图像分类的对抗攻击通过在良性样本中加入人类视觉难以察觉的扰动从而使分类器做出错误判断,已成为生产中大规模部署深度学习模型的一个重大障碍。针对对抗样本的攻击,目前广泛应用的防御策略是对抗训练。在经过扰动的数据中进行训练,提升模型对恶意样本的鲁棒性。对抗训练的想法最早在2005年提出(参考文献[2]:Daniel Lowd and Christopher Meek.Adversarial learning,In Proceedings of the eleventh ACM SIGKDD international conference onKnowledge discovery in data mining,2005,即Daniel Lowd and ChristopherMeek.Adversarial learning,对抗学习,In Proceedings of the eleventh ACM SIGKDDinternational conference on Knowledge discovery in data mining,2005.),如今有多种形式的优化,例如,PGD对抗生成对抗样本参与训练提升模型对抗鲁棒性(参考文献[3]:Aleksander Madry,Aleksandar Makelov,Ludwig Schmidt,Dimitris Tsipras,andAdrian Vladu,Towards deep learning models resistant to adversarial attacks.InProc.of ICLR,2018,即Aleksander Madry,Aleksandar Makelov,Ludwig Schmidt,Dimitris Tsipras,and Adrian Vladu,对抗性攻击的深度学习模型,In Proc.of ICLR,2018.)
作为部署应用的需要,在图像分类任务中设计轻量化而鲁棒性高的深度神经网络模型具有重要的意义,因此蒸馏学习和对抗训练分别作为减少模型规模和提高模型鲁棒性的方法,为轻量化深度神经网络的鲁棒性应用提供了技术支持。而在当前的蒸馏学习框架中,通常以一个性能较好的大模型作为教师来蒸馏轻量化的小模型,知识蒸馏后的学生模型功能较为单一,难以在精度和防御性上同时拥有优异的性能;而对抗训练框架能有效地为模型带来强大的防御性,但会使模型损失模型的精度和降低模型的泛化性。因此,能够兼顾模型精度和安全性的训练方法是一个迫切的需要。
本发明结合模型蒸馏学习和对抗训练技术,旨在实现以下两个目标:(1)得到分类性能优异的轻量化模型;(2)在保证优异分类精度的同时兼顾模型的安全鲁棒性。
发明内容
本发明要克服现有技术中模型结构复杂,分类模型泛化能力不高的的缺陷,提供一种基于多模型对抗蒸馏的鲁棒性图像分类方法。
为了实现在轻量化深度神经网络训练中有效地兼顾模型的精度和防御性,本发明提出了一种基于多模型蒸馏结合对抗训练的深度神经网络训练方法,该方法通过知识蒸馏模型信息传递的渠道,将复杂模型中的分类能力以及防御能力传递给轻量化网络,得到一个泛化性和对抗性均优的图像分类器。
本发明的知识蒸馏方法是根据师生模型Logit输出差距作为损失函数项来更新学生模型参数;对抗训样本通过多步梯度下降法生成。最后通过多模型蒸馏框架,利用不同权值分配来区分教师模型知识的重要程度,进行蒸馏训练。
本发明实现上述发明目的所采用的技术方案如下:
S1:预训练复杂模型;
在给定训练数据集X下正常训练复杂模型,得到模型T1。复杂模型应根据先验知识选择分类能力优秀的复杂模型。
S2:产生对抗训练样本;
在给定数据集X下,根据损失函数的梯度通过多次迭代对每个样本添加扰动,得到对抗训练样本
Figure BDA0003630142360000031
对抗样本根据公式(1)生成。
Figure BDA0003630142360000032
其中xt为原始样本;α为扰动系数,确定每次迭代中的步长;sign(·)为符号函数,指定图像像素改变方向;J(x,y)为模型的损失函数;
Figure BDA0003630142360000033
是损失函数关于图像像素值的梯度。
S3:对抗训练复杂模型;
使用步骤S2中产生的扰动样本和正确标签作为对抗训练数据。选取一个复杂的模型框架,直接采取对抗训练,生成教师模型T2。步骤3中,对抗训练过程中每一次新的迭代都根据公式(1)重新生成一批新的扰动样本作为这一轮对抗训练的训练样本。
S4:知识蒸馏;
所述步骤S4的具体包括:
S4.1:选择轻量化模型结构作为学生模型S;
S4.2:将训练样本x(x∈X)输入学生模型S和教师模型T1中,分别得到两个模型的Logit输出。其中真实标签作为学生模型训练硬标签,教师模型的输出作为学生模型软标签。根据公式(2)计算得到lossT1,其中lamba为衡量损失函数重要性的权重,并根据交叉熵损失计算得到lossnat
lossT=(outs-outT)2*lamba (2)
S4.3:在知识蒸馏过程中利用公式(1)根据学生模型S生成对抗样本
Figure BDA0003630142360000034
与步骤S4.2操作相同,将对抗数据集
Figure BDA0003630142360000035
输入教师模型T2和学生模型S得到lossT2和lossadv
S4.4:根据公式(3)计算总损失值。
LOSS=lossnat+lossadv+lossT1+lossT2 (3)
在实验过程,使用者可以根据教师模型的不同以及所需生成轻量化模型功能的需求不同来更改lamba来进行调整。
具体来说,本发明所述的方法具有如下的有益效果:
由于较难直接找到分类精度和防御性都优异的模型作为教师模型,因此在知识蒸馏中难以同时提高学生模型的分类精度和防御性。本发明所述的训练策略采用对抗训练和多模型蒸馏的方法,在兼顾模型精度的同时大幅提高其对抗鲁棒性。避免了单一知识蒸馏无法兼顾精度和防御性以及对抗训练大幅削减精度和泛化性的问题,使训练得到的学生模型性能更加全面。
附图说明
图1为基本知识蒸馏框架。
图2为基于多模型知识蒸馏的鲁棒深度神经网络训练框架。
具体实施方式
下面结合附图并以CIFAR100图像数据为例对本发明的具体实施方式做进一步描述。
参照图2,由复杂模型预训练开始,具体步骤如下:
S1:预训练复杂模型;
复杂模型框架选取Densenet-121,训练数据集采用CIFAR100,数据集由100类32×32的3通道RGB彩色图片组成,其中训练集50000张,测试集10000张。训练得到模型T1
S2:产生对抗训练样本;
根据公式(1)采用PGD方法,设置扰动系数、每次迭代中的步长和迭代次数,使用步骤S1中相同数据集中的训练集来产生对抗训练样本。
Figure BDA0003630142360000041
S3:对抗训练复杂模型;
使用步骤2中产生的扰动图像和正确标签作为对抗训练数据。采用的模型为不经过预训练的Densenet-121,直接采取对抗训练,生成模型T2。步骤S3中,对抗训练过程中每一次新的迭代都根据公式(1)重新生成一批新的扰动样本作为这一轮的训练样本。
S4:知识蒸馏;
S4.1:选择ResNet-34作为轻量化学生模型结构,其网络深度、参数量都远小于教师模型Densenet-121。
S4.2:将训练样本x(x∈X)输入学生模型S和教师模型T1中,分别得到两个模型的Logit输出。其中真实标签作为学生模型训练硬标签,教师模型的输出作为学生模型软标签。根据公式(2)计算得到lossT1,其中lamba为衡量损失函数重要性的权重,并根据交叉熵损失计算得到lossnat
S4.3:在知识蒸馏过程中利用公式(1)根据学生模型S生成对抗样本
Figure BDA0003630142360000042
与步骤S4.2操作相同,将对抗数据集
Figure BDA0003630142360000043
输入教师模型T2和学生模型S得到lossT2和lossadv
S4.4:根据公式(3)计算总损失值。
lossT=(outs-outT)2*lamba (2)
LOSS=lossnat+lossadv+lossT1+lossT2 (3)
如上所述为本发明一种基于多模型对抗蒸馏的鲁棒性图像分类方法方法。相较已有的传统对抗训练方法以牺牲标准精度来提升模型对抗防御性,本发明通过结合知识蒸馏和对抗训练技术,在模型轻量化、分类精度、模型鲁棒性和泛化性四个常用指标来提升模型性能,在大幅提升模型对抗的防御性的同时能够维持甚至提升模型标准分类精度。我们的方法为轻量化深度学习图像分类网络的训练带来新的思路,并希望它能够帮助研究人员更好的提升模型性能,以在边缘端等硬件资源受限情况下部署安全可靠的应用深度学习模型。

Claims (3)

1.一种基于多模型对抗蒸馏的鲁棒性图像分类方法,其特征在于:所述提取方法包括以下步骤:
S1:预训练复杂模型;
在给定训练数据集X下正常训练复杂模型,得到模型T1;复杂模型应根据先验知识选择分类能力优秀的复杂模型;
S2:产生对抗训练样本;
在给定数据集X下,根据损失函数的梯度通过多次迭代对每个样本添加扰动,得到对抗训练样本
Figure FDA0003630142350000011
对抗样本根据公式(1)生成;
Figure FDA0003630142350000012
其中xt为原始样本;α为扰动系数,确定每次迭代中的步长;sign(·)为符号函数,指定图像像素改变方向;J(x,y)为模型的损失函数;
Figure FDA0003630142350000013
是损失函数关于图像像素值的梯度;
S3:对抗训练复杂模型;
使用步骤S2中产生的扰动样本和正确标签作为对抗训练数据;选取一个复杂的模型框架,直接采取对抗训练,生成教师模型T2;步骤S3中,对抗训练过程中每一次新的迭代都根据公式(1)重新生成一批新的扰动样本作为这一轮对抗训练的训练样本;
S4:使用多个教师模型进行知识蒸馏获得轻量化学生模型。
2.如权利要求1所述的基于多模型对抗蒸馏的鲁棒性图像分类方法,其特征在于:步骤S4具体包括:
S4.1:选择轻量化模型结构作为学生模型S;
S4.2:将训练样本x(x∈X)输入学生模型S和教师模型T1中,分别得到两个模型的Logit输出;其中真实标签作为学生模型训练硬标签,教师模型的输出作为学生模型软标签;根据公式(2)计算得到lossT1,其中lamba为衡量损失函数重要性的权重,并根据交叉熵损失计算得到lossnat
lossT=(outs-outT)2*lamba (2)
S4.3:在知识蒸馏过程中利用公式(1)根据学生模型S生成对抗样本XS,与步骤S4.2操作相同,将对抗数据集
Figure FDA0003630142350000021
输入教师模型T2和学生模型S得到lossT2和lossadv
S4.4:根据公式(3)计算总损失值;
LOSS=lossnat+lossadv+lossT1+lossT2 (3)
3.如权利要求2所述的基于多模型对抗蒸馏的鲁棒性图像分类方法,其特征在于:所述步骤S4.2中,使用者根据所需生成模型的功能来确定lamba权值的具体大小,使用者可通过权值分配的不同来生成强对抗鲁棒性模型或是高分类精度模型。
CN202210488306.0A 2022-05-06 2022-05-06 一种基于多模型对抗蒸馏的鲁棒性图像分类方法 Pending CN114842257A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210488306.0A CN114842257A (zh) 2022-05-06 2022-05-06 一种基于多模型对抗蒸馏的鲁棒性图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210488306.0A CN114842257A (zh) 2022-05-06 2022-05-06 一种基于多模型对抗蒸馏的鲁棒性图像分类方法

Publications (1)

Publication Number Publication Date
CN114842257A true CN114842257A (zh) 2022-08-02

Family

ID=82568296

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210488306.0A Pending CN114842257A (zh) 2022-05-06 2022-05-06 一种基于多模型对抗蒸馏的鲁棒性图像分类方法

Country Status (1)

Country Link
CN (1) CN114842257A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116030323A (zh) * 2023-03-27 2023-04-28 阿里巴巴(中国)有限公司 图像处理方法以及装置
CN117009534A (zh) * 2023-10-07 2023-11-07 之江实验室 文本分类方法、装置、计算机设备以及存储介质

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116030323A (zh) * 2023-03-27 2023-04-28 阿里巴巴(中国)有限公司 图像处理方法以及装置
CN116030323B (zh) * 2023-03-27 2023-08-29 阿里巴巴(中国)有限公司 图像处理方法以及装置
CN117009534A (zh) * 2023-10-07 2023-11-07 之江实验室 文本分类方法、装置、计算机设备以及存储介质
CN117009534B (zh) * 2023-10-07 2024-02-13 之江实验室 文本分类方法、装置、计算机设备以及存储介质

Similar Documents

Publication Publication Date Title
CN107563422B (zh) 一种基于半监督卷积神经网络的极化sar分类方法
CN114842257A (zh) 一种基于多模型对抗蒸馏的鲁棒性图像分类方法
CN109639710B (zh) 一种基于对抗训练的网络攻击防御方法
CN108764281A (zh) 一种基于半监督自步学习跨任务深度网络的图像分类方法
CN110991299A (zh) 一种物理域上针对人脸识别系统的对抗样本生成方法
CN111160533A (zh) 一种基于跨分辨率知识蒸馏的神经网络加速方法
CN110533570A (zh) 一种基于深度学习的通用隐写方法
CN110097178A (zh) 一种基于熵注意的神经网络模型压缩与加速方法
CN111429340A (zh) 一种基于自注意力机制的循环图像翻译方法
CN105787557A (zh) 一种计算机智能识别的深层神经网络结构设计方法
CN108038507A (zh) 基于粒子群优化的局部感受野极限学习机图像分类方法
CN109993100A (zh) 基于深层特征聚类的人脸表情识别的实现方法
CN106874879A (zh) 基于多特征融合和深度学习网络提取的手写数字识别方法
CN109960755B (zh) 一种基于动态迭代快速梯度的用户隐私保护方法
CN110175646A (zh) 基于图像变换的多通道对抗样本检测方法及装置
CN108256630A (zh) 一种基于低维流形正则化神经网络的过拟合解决方法
CN114676687A (zh) 基于增强语义句法信息的方面级情感分类方法
CN110287985A (zh) 一种基于带变异粒子群算法的可变拓扑结构的深度神经网络图像识别方法
CN103440352A (zh) 基于深度学习的对象间的关联分析方法及其装置
CN111428795A (zh) 一种改进的非凸鲁棒主成分分析方法
CN113901448A (zh) 基于卷积神经网络和轻量级梯度提升机的入侵检测方法
CN114399018B (zh) 一种基于轮转操控策略麻雀优化的EfficientNet陶瓷碎片分类方法
CN113806559B (zh) 一种基于关系路径与双层注意力的知识图谱嵌入方法
CN110795934A (zh) 语句分析模型的训练方法及装置、语句分析方法及装置
CN114329031A (zh) 一种基于图神经网络和深度哈希的细粒度鸟类图像检索方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination