CN117036869B - 一种基于多样性和随机策略的模型训练方法及装置 - Google Patents
一种基于多样性和随机策略的模型训练方法及装置 Download PDFInfo
- Publication number
- CN117036869B CN117036869B CN202311293176.6A CN202311293176A CN117036869B CN 117036869 B CN117036869 B CN 117036869B CN 202311293176 A CN202311293176 A CN 202311293176A CN 117036869 B CN117036869 B CN 117036869B
- Authority
- CN
- China
- Prior art keywords
- layer
- model
- classification
- switching block
- switching
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 115
- 238000000034 method Methods 0.000 title claims abstract description 74
- 238000013145 classification model Methods 0.000 claims abstract description 112
- 238000003860 storage Methods 0.000 claims description 18
- 238000004590 computer program Methods 0.000 claims description 15
- 238000002372 labelling Methods 0.000 claims description 8
- 238000012935 Averaging Methods 0.000 claims description 6
- 230000004044 response Effects 0.000 claims description 3
- 230000008569 process Effects 0.000 abstract description 15
- 238000009826 distribution Methods 0.000 abstract description 8
- 230000006870 function Effects 0.000 description 18
- 238000010586 diagram Methods 0.000 description 16
- 238000005457 optimization Methods 0.000 description 11
- 230000006872 improvement Effects 0.000 description 10
- 238000012545 processing Methods 0.000 description 7
- 238000012360 testing method Methods 0.000 description 7
- 238000013136 deep learning model Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 6
- 238000003062 neural network model Methods 0.000 description 6
- 241000282326 Felis catus Species 0.000 description 5
- 238000005516 engineering process Methods 0.000 description 5
- 241000283690 Bos taurus Species 0.000 description 4
- 238000009825 accumulation Methods 0.000 description 4
- 230000008859 change Effects 0.000 description 4
- 230000007123 defense Effects 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 3
- 238000006243 chemical reaction Methods 0.000 description 3
- 238000013508 migration Methods 0.000 description 3
- 230000005012 migration Effects 0.000 description 3
- 238000003058 natural language processing Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000004519 manufacturing process Methods 0.000 description 2
- 230000000873 masking effect Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 206010039203 Road traffic accident Diseases 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 230000006978 adaptation Effects 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 230000008485 antagonism Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000007796 conventional method Methods 0.000 description 1
- 238000007405 data analysis Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000005055 memory storage Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 229920001296 polysiloxane Polymers 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
- 239000010979 ruby Substances 0.000 description 1
- 229910001750 ruby Inorganic materials 0.000 description 1
- 229920006395 saturated elastomer Polymers 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本说明书公开了一种基于多样性和随机策略的模型训练方法及装置,待训练的分类模型包含多个基模型,该待训练的分类模型又分为多层切换块,按照前向传播方向依次训练各层切换块,根据各基模型输出的分类结果和标注,确定分类损失,根据各基模型输出的分类结果和样本图像,确定多样性正则损失。在训练过程中,该层切换块的输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像,输出为该层切换块中随机一子模块的输出。这样训练出的分类模型,既保证了模型输出的分类结果准确度,又使各基模型输出的分类结果中各类别对应的概率分布不同,模型的输出为随机一个基模型的输出,极大提高了该分类模型的鲁棒性。
Description
技术领域
本说明书涉及计算机技术领域,尤其涉及一种基于多样性和随机策略的模型训练方法及装置。
背景技术
如今,深度神经网络模型在许多应用领域都表现出了优异的性能,如,图像分类、汽车自动驾驶、语音识别、自然语言处理等。但是,深度学习模型很容易遭受由攻击者精心设计的对抗样本的攻击,即通过故意添加细微的、人类察觉不到的干扰形成的输入样本,导致模型以高置信度给出一个错误的输出。对抗样本的存在使深度神经网络模型的应用面临巨大安全隐患,因此,提高模型对于对抗样本的鲁棒性是一个重要课题。
目前,传统提高深度学习模型鲁棒性的方法包括对抗训练、输入转换、梯度掩码和随机网络防御等,这些方法虽然能够提高深度学习模型的鲁棒性,但都是以牺牲模型对于真实的输入样本的预测准确率为代价的。
因此,如何在提高深度学习模型鲁棒性的同时保证高的预测准确率是一个亟待解决的问题。
发明内容
本说明书提供一种基于多样性和随机策略的模型训练方法、装置、存储介质及电子设备,以至少部分地解决现有技术存在的上述问题。
本说明书采用下述技术方案:
本说明书提供了一种基于多样性和随机策略的模型训练方法,待训练的分类模型包含多个基模型,每个基模型划分为串联的多层子模块,各基模型中相同层的子模块组成切换块,所述待训练的分类模型由多层切换块组成,包括:
获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型;
按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已固定参数的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像;
根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,并基于所述样本图像的标注,确定分类损失;
根据预设的基准图像、所述样本图像以及各基模型输出的中间图像,确定各基模型的积分梯度,并根据各积分梯度确定多样性正则损失,所述多样性正则损失的值越小代表各积分梯度之间的差异越大;
根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块;
当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
可选地,针对每一层待训练的切换块,在已固定参数的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像,具体包括:
若该层切换块是第一层切换块,则该层切换块的输入为所述样本图像;
若该层切换块不是第一层切换块,则将所述样本图像通过已固定参数的各前层切换块后,该层切换块的上一层切换块中随机一子模块的输出,作为该层切换块的输入。
可选地,所述待训练的分类模型的最后一层切换块的各子模块中至少包括输出层;
根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,具体包括:
将该层切换块的输入作为该层切换块中各子模块的输入,确定该层切换块中各子模块的输出;
根据随机策略,将该层切换块中各子模块的输出输入之后未固定参数的各层切换块;
通过所述最后一层切换块中各子模块的输出层,得到最后一层切换块中各子模块输出的分类结果;
确定所述最后一层切换块中各子模块所属的基模型,将最后一层切换块中各子模块输出的分类结果,作为所述最后一层切换块中各子模块所属的基模型输出的分类结果。
可选地,并基于所述样本图像的标注,确定分类损失,具体包括:
根据各基模型输出的分类结果和所述样本图像的标注,确定各基模型的分类交叉熵损失;
对各基模型的分类交叉熵损失求均值,得到所述待训练的分类模型的分类损失。
可选地,根据预设的基准图像、所述样本图像以及各基模型输出的中间图像,确定各基模型的积分梯度,具体包括:
确定预设的基准图像和插值路径,沿插值路径从基准图像开始按照指定插值步数均匀缩放样本图像像素强度,得到每一个插值步数上对应的插值图像;
针对待训练的分类模型的每一个基模型,分别计算各插值图像与该基模型输出的中间图像之间的梯度,将各梯度进行累加求和,得到各插值图像与该基模型输出的中间图像之间的累加梯度;
将所述累加梯度相对于所述插值步数进行平均,得到各插值图像与该基模型输出的中间图像之间的初始积分梯度;
根据所述样本图像相对于所述基准图像的差异,对所述初始积分梯度进行调整,得到该基模型的积分梯度。
可选地,根据各积分梯度确定多样性正则损失,具体包括:
将所述待训练的分类模型的各基模型两两组合,确定各组合中两基模型的积分梯度之间的相似度;
根据确定出的各相似度,确定所述待训练的分类模型的多样性正则损失。
可选地,根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块,具体包括:
根据所述分类损失和所述多样性正则损失,确定所述待训练的分类模型的总损失;
以所述总损失最小为目标训练未固定模型参数的各层切换块,直至满足训练结束条件为止,对该层切换块的训练完成;
固定该层切换块的模型参数,并对未固定参数的各层切换块进行初始化,训练下一层切换块。
本说明书提供的一种基于多样性和随机策略的模型训练装置,所述装置包括:
获取模块,获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型;
模型训练模块,按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已训练好的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像;
分类损失确定模块,根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,并基于所述样本图像的标注,确定分类损失;
多样性正则损失确定模块,根据预设的基准图像、所述样本图像以及各基模型输出的中间图像,确定各基模型的积分梯度,并根据各积分梯度确定多样性正则损失,所述多样性正则损失的值越小代表各积分梯度之间的差异越大;
切换块训练模块,根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块;
分类任务响应模块,当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
本说明书提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述基于多样性和随机策略的模型训练方法。
本说明书提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述基于多样性和随机策略的模型训练方法。
本说明书采用的上述至少一个技术方案能够达到以下有益效果:
在本说明书提供的基于多样性和随机策略的模型训练方法中,待训练的分类模型包含多个基模型,该待训练的分类模型又分为多层切换块,按照前向传播方向依次训练各层切换块,根据各基模型输出的分类结果和标注,确定分类损失,根据各基模型输出的分类结果和样本图像,确定多样性正则损失。在训练过程中,该层切换块的输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像,输出为该层切换块中随机一子模块的输出。这样训练出的分类模型,既保证了模型输出的分类结果准确度,又使各基模型输出的分类结果中各类别对应的概率分布不同,模型的输出为随机一个基模型的输出,极大提高了该分类模型的鲁棒性。
附图说明
此处所说明的附图用来提供对本说明书的进一步理解,构成本说明书的一部分,本说明书的示意性实施例及其说明用于解释本说明书,并不构成对本说明书的不当限定。在附图中:
图1为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的流程示意图;
图2为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的一种模型结构示意图;
图3为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的随机策略示意图;
图4(a)、图4(b)和图4(c)为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的鲁棒性增强原理图;
图5为本说明书实施例提供的一种基于多样性和随机策略的模型训练装置的示意图;
图6为本说明书实施例提供的对应于图1的电子设备示意图。
具体实施方式
为使本说明书的目的、技术方案和优点更加清楚,下面将结合本说明书具体实施例及相应的附图对本说明书技术方案进行清楚、完整地描述。显然,所描述的实施例仅是本说明书一部分实施例,而不是全部的实施例。基于本说明书中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
深度学习是目前研究最活跃的计算机领域之一,作为一种重要的数据分析方法,深度神经网络模型在许多如生物特征识别、图像分类、汽车自动驾驶、语音识别、自然语言处理等应用领域表现出了优异的性能。但是,深度学习模型很容易遭受由攻击者精心设计的对抗样本的攻击。对抗样本指的是对原始样本加入微小的、人类察觉不到的扰动,导致深度学习模型以高置信度给出一个错误的输出。比如,针对面部识别模型的特点,通过对原始面部图片添加人为精心制造的微小扰动便可使面部识别模型做出错误的分类;针对自动汽车驾驶的恶意控制,通过对原始障碍物图片添加微小扰动,使自动汽车驾驶的识别模型对障碍物做出错误分类,造成交通事故等。对抗样本的存在给深度神经网络模型的应用带来巨大威胁,提高深度神经网络模型的鲁棒性显得尤为重要。
传统提高深度神经网络模型鲁棒性的方法包括对抗训练、输入转换、梯度掩码和随机网络防御等。对抗训练是指对原始样本添加一些微小扰动生成对抗样本,在模型每一轮训练过程中,将对抗样本加入到原始样本中共同训练,使深度神经网络适应这种改变,从而提高模型的鲁棒性,但是这种适应性改变通常会降低模型对于原始样本的预测准确率。输入转换是指在样本输入到模型之前,先对样本进行去噪,剔除其中扰动的信息,以提高模型的鲁棒性,但是,由于去噪不可能完全消除扰动,剩余的极微小扰动会在深度神将网络模型中被逐层放大,最终导致较大扰动,所以通过去噪后的样本训练出的模型预测准确率反而会降低。梯度掩码是指隐藏模型的原始梯度,能够防御以输入样本为起点,在损失函数梯度方向上修改输入样本而产生对抗样本的攻击,但对其他方法产生的抗样本几乎没有抵御效果。
传统的随机网络防御,如随机神经元激活剪枝或输入层随机扰动被证明能有效的提升防御能力,但也面临对真实的输入样本上预测准确率大幅下降的缺陷。分层随机交换(Hierarchical Random Switching,HRS)随机网络防御方法,其模型包含几个随机切换的通道块,在保证模型的预测准确率的前提下一定程度上提高了模型的鲁棒性,但是,如果通道块网络结构过于一致,基于相同训练数据训练的情况下,得到的通道块网络参数也比较相似,由于对抗攻击具有迁移特性,针对一个通道块生成的对抗样本对其他通道块也将有效,仍旧无法抵御迁移攻击,造成基于此方法的鲁棒性提升效果有限。如果能够增加各通道块之间的多样性,将提升模型对于迁移攻击的鲁棒性,进而极大提升模型的鲁棒性。
以下结合附图,详细说明本说明书各实施例提供的技术方案。
图1为本说明书中一种基于多样性和随机策略的模型训练方法的流程示意图,具体包括以下步骤:
S100:获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型。
在本说明书中,用于进行基于多样性和随机策略的模型训练的设备可以是服务器,也可以是诸如台式电脑、笔记本电脑等电子设备。为了便于描述,下面仅以服务器为执行主体,对本说明书提供的基于多样性和随机策略的模型训练方法进行说明。
服务器从图像数据集中获取用于训练的样本图像,以及各样本图像对应的标注。然后,将获取到的样本图像按照一定比例分为训练集和测试集,将训练集中的样本图像输入待训练的分类模型,进行模型的训练。
S102:按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已固定参数的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像。
本说明书中待训练的分类模型包含多个基模型,每个基模型划分为串联的多层子模块,各基模型中相同层的子模块组成切换块,所述待训练的分类模型由多层切换块组成,以切换块为单位进行模型的训练。图2为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的一种模型结构示意图,一个立方体表示深度神经网络的一层,箭头表示切换块的输入或输出。图2所示的分类模型,由三个基模型集成得到,每个基模型按照层次均分成N个子模块,各基模型中相同层的子模块组成切换块,该神将网络模型由N层切换块组成。
服务器按照前向传播方向,依次训练各层切换块,针对待训练的每层切换块,将样本图像依次通过已固定参数的各前层切换块,即通过已经训练完成的各前层切换块组成的网络,得到该层切换块的输入。其中,如果该层切换块是第一层切换块,则该层切换块的输入为样本图像,如果该层切换块不是第一层切换块,则将样本图像通过已固定参数的各前层切换块后,该层切换块的上一层切换块中随机一子模块的输出,作为该层切换块的输入。
上述步骤中的基模型,可选用任一可用于分类任务的模型,如ResNet网络、VGGNet网络、GoogleNet网络等,本说明书不做具体限定。
S104:根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,并基于所述样本图像的标注,确定分类损失。
待训练的分类模型的最后一层切换块的各子模块至少包括输出层,针对每一层待训练的切换块,服务器将该层切换块的输入作为该层切换块中各子模块的输入,得到该层切换块中各子模块的输出。根据随机策略,将该层切换块中各子模块的输出输入之后未固定参数的各层切换块,通过最后一层切换块中各子模块的输出层,得到最后一层切换块中各子模块输出的分类结果。
然后,服务器确定最后一层切换块中各子模块所属的基模型,将最后一层切换块中各子模块输出的分类结果,作为最后一层切换块中各子模块所属的基模型输出的分类结果。该分类结果是基于已经训练完成的所有前层切换块和未训练完成的所有后层切换块组成的网络对于样本图像做出的类别概率预测。
服务器对于模型在相邻两层切换块的数据传输采用随机策略。图3为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的随机策略原理图,表示第n层切换块中第j个子模块,/>表示第n层切换块的输出。如图3所示的分类模型中,第一层切换块的输入为输入模型的样本图像,第二层切换块的输入为第一层切换块的子模块中随机一个子模块的输出,该分类模型输出为最后一层切换块中随机一个子模块输出的分类结果。
当该层切换块的训练完成以后,随机将该层切换块中任一子模块的输出作为该层切换块的输出,该输出再作为下一层切换块的输入继续训练过程,直至损失函数收敛或达到最大训练轮数,对该层切换块的训练结束。
随机策略的应用,使得每个切换块的输出具有不确定性,如图3所示的分类模型网络结构中,分类模型由三个基模型集成得到,该分类模型分为了三个切换块,其中,每个基模型都是一个分类器,模型最终的分类结果为这三个基模型输出的分类结果之一。未应用随机策略时,该分类模型最终输出的分类结果只有3种可能,应用随机策略后,该分类模型最终输出的分类结果变为了种可能。每一个切换块中的每一个子模块和其他切换块中的每一个子模块随机组合,产生了/>种不同的网络,每一个网络的参数配置都不同,每一次模型输出的分类结果由在各层切换块中随机选择的子模块组成的一个网络决定,增强了模型输出的分类结果的多样性,从而增强了模型的鲁棒性。
最后,服务器根据各基模型输出的分类结果和样本图像对应的标注,确定各基模型的分类交叉熵损失。然后,对各基模型的分类交叉熵损失进行累加求和,得到所述待训练的分类模型的总交叉熵损失,并将所述总交叉熵损失相对于基模型数量进行平均,得到所述待训练的分类模型的分类损失。
具体的,可采用如下公式确定分类损失:
其中,M为基模型数量,为该层切换块中的第i个基模型对样本图像/>的分类预测结果,/>为样本图像标注的真实分类结果,/>为该切换块中第i个基模型的分类交叉熵损失。
上述公式表明,对各基模型的分类交叉熵损失进行求和,得到该层切换块的总交叉熵损失,再将该总交叉熵损失相对于基模型数量M进行平均,得到待训练的分类模型的分类损失。
根据分类损失对未固定参数的各层切换块中各子模块分别进行调参,可确保由每层切换块中随机一子模块组成的网络,对样本图像的分类预测都达到预设精度。
S106:根据预设的基准图像、所述样本图像以及各基模型输出的中间图像,确定各基模型的积分梯度,并根据各积分梯度确定多样性正则损失,所述多样性正则损失的值越小代表各积分梯度之间的差异越大。
本说明书中采用积分梯度作为衡量输入的样本图像中每个像素对于模型输出的预测结果中每个类别贡献度大小的指标。
样本图像根据随机策略,通过待训练的分类模型的各层切换块中随机选择的子模块组成的网络,得到最后一层切换块中各子模块输出的中间图像,针对最后一层切换块中的各子模块,该子模块的中间图像通过该子模块的输出层,得到该子模块输出的分类结果,并将该子模块对应的分类结果作为该子模块所属的基模型的分类预测结果。
首先,服务器需确定预设的基准图像和插值路径,沿插值路径从基准图像开始按照指定插值步数均匀缩放样本图像像素强度,得到每一个插值步数上对应的插值图像。针对待训练的分类模型的每一个基模型,分别计算各插值图像与该基模型输出的中间图像之间的梯度,将这些梯度进行累加求和得到各插值图像与该基模型输出的中间图像之间的累加梯度。将该累加梯度相对于插值步数进行平均,得到各插值图像与该基模型输出的中间图像之间的初始积分梯度。根据样本图像和基准图像之间的差异,对该初始积分梯度进行调整,得到该基模型的积分梯度。
对于基准图像,可以选用全黑图、噪声图片、光滑模糊图片、高斯随机图片等,本说明书不做具体限定。
具体的,取插值路径为,可采用下述公式确定样本图像相对于由当前正在训练的第n层切换块、已经固定参数的n-1层切换块和未固定参数的N-n层切换块组成的网络中第j个基模型的积分梯度:
其中,为0~1间的插值常数,N为待训练的分类模型包含的切换块层数,/>为样本图像,/>为基准图像,/>为插值步数,/>为各插值图像相对于由当前正在训练的第n层切换块、已经训练完成的n-1层切换块和未开始训练的N-n层切换块组成的网络对样本图像的分类预测结果的梯度。
由于插值图像是对于样本图像进行缩放得到的图像,不同的插值图像之间的像素值势必会存在尺度上的差异。为了避免模型在迭代过程中为抹除这种差异而增加大量迭代次数,导致收敛速度过慢,在计算出各插值图像与该基模型输出的中间图像之间的初始积分梯度后,服务器需要根据样本图像和基准图像之间的差异,对该初始积分梯度进行调整,以保证不同的插值图像相对于样本图像的积分梯度映射在同一尺度内。
然后,服务器将所述待训练的分类模型的各基模型两两组合,确定各组合中两基模型的积分梯度之间的相似度,将各相似度进行累加求和,得到待训练的分类模型输出的分类结果的全局相似度,通过对该全局相似度进行归一化操作,得到待训练的分类模型的多样性正则损失。具体的,可根据如下公式确定待训练的分类模型的多样性正则损失:
其中,M表示基模型数量,n表示当前正在训练的该层切换块层数,表示第j个基模型与第k个基模型的积分梯度之间的相似度。
积分梯度的方向代表了待训练的分类模型的迭代优化方向,两个基模型的积分梯度之间的相似度的值越大,代表两个基模型输出的分类结果中各类别对应的概率分布差异越大;两个基模型输出的积分梯度之间的相似度的值越小,代表两个基模型输出的分类结果中各类别对应的概率分布差异越小。沿着多样性正则损失减小的方向对该层切换块进行优化,由于模型的输出的分类结果为根据随即策略确定的最后一层切换块中随机一子模块的输出的分类结果,这样模型输出的分类结果就呈现出多样性。
另外,本说明书中沿积分梯度的方向来进行模型的迭代优化,避免了沿梯度方向进行模型迭代优化过程中出现的梯度饱和时,样本图像中像素值的改变无法得到模型输出的分类结果相应改变的现象。即当梯度饱和时,梯度将不能作为衡量输入的样本图像中每个像素值对于模型输出的预测结果中每个类别贡献度大小的指标。而积分梯度由于考虑了整条插值路径上各插值图像与该基模型输出的中间图像之间的梯度,不再受某一特定点梯度的制约,是比梯度更好的贡献度大小度量指标。
S108:根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块。
服务器将分类损失和多样性正则损失以指定权值相加,构建待训练的分类模型的损失函数。具体的,损失函数可根据下式确定:
其中,为第n层切换块的分类损失,/>为第n层切换块的多样性正则损失,/>为惩罚因子。/>的数值可根据需要设置,本说明书不做具体限定。
在每一层切换块的训练过程中,服务器根据此损失函数确定待训练的分类模型的总损失,以总损失最小为目标调整各层未固定参数的切换块的参数,对各层未固定参数的各层切换块进行优化,直至满足训练结束条件,对该层切换块的训练完成。
服务器固定训练完成的该层切换块的模型参数,并对未固定参数的各层切换块进行初始化,进行下一层切换块的训练。
因为损失函数是分类损失和多样性正则损失以指定权值相加而得,当分类损失和多样性正则损失的值都比较小时,损失函数的值才能比较小时。分类损失小代表模型的分类准确度高,多样性正则损失小代表了不同的基模型输出的分类结果中各类别对应的概率分布差异大。所以,以此方式构建的损失函数在保证分类模型预测准确率的同时增加了分类模型输出的分类结果的多样性。
S110:当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
当各层切换块的参数都固定时,对待训练的分类模型中各层切换块的训练全部完成,各层训练完成的切换块组成的网络即为训练完成的分类模型。
当该训练完成的分类模型响应于携带待分类图像的任务请求时,将待分类图像输入训练完成的分类模型,通过各层切换块中随机选择的子模块组成的网络,该分类模型最后一层切换块输出待分类图像的分类结果。具体的,待分类图像的分类结果可用下式表示:
其中,表示分类模型输出的图像分类结果,/>表示样本图像,N表示组成该分类模型的切换块数量,/>表示第N层切换块的输出为第N层切换块中随机一子模块的输出。
样本图像输入该分类模型后,将作为第一层切换块中随机一个子模块的输入,通过该子模块得到第一层切换块的输出,第一层切换块的输出继续作为第二层切换块中随机一个子模块的输入,继续此过程一直到最后一层切换块,将最后一层切换块中随机一个子模块输出的分类结果作为该分类模型输出的分类结果。
随机策略使该分类模型中每一层切换块的输入输出具有不确定性,当该分类模型包含M个基模型,划分为N个切换块时,该分类模型输出的分类结果由在各层切换块中随机选择的子模块组成的一个网络决定,每一个切换块中的每一个子模块和其他切换块中的每一个子模块随机组合,将产生了种不同的网络,该分类模型的输出为样本图像经过这种网络中的随机一个网络所产生的结果。
又因为组成这每一种网络的子模块之间的优化方向不具有相关性,每一个子模块沿不同的积分梯度方向优化,通过各网络输出的分类结果中各类别对应的概率分布将呈现多样性。当遭遇对抗样本的攻击时,由于攻击者无法确定该分类模型输出的分类结果是通过哪一种网络所产生的,如果攻击者想生成有效的对抗样本,就必须使对抗样本对于该分类模型中的每一个基模型输出的分类结果同时具有攻击作用。但是由于攻击者无法确定此网络的优化方向,同时满足多个基模型优化方向的样本又很难生成,所以该分类模型对于对抗攻击的鲁棒性增强。
由于在本说明书中待训练的分类模型的损失函数考虑了分类损失和多样性正则损失,把模型分为多个切换块按照前向传播方向依次训练,相邻两个切换块之间的优化方向不具有相关性。这样,在模型的训练过程中,不仅可以使模型的预测准确率提高,同时能使模型中各基模型的输出的分类结果中各类别对应的概率分布差异尽可能变大。由于模型输出结果的多样性,实现了模型的鲁棒性增强。
例如,通过本书明书的模型训练方法训练出的一个分类类别为“牛、熊、鸟、猫、狗”的五分类模型,该模型包含三个基模型。对一张标注为“猫”的样本图像进行分类预测,第一个基模型输出的分类结果为:“牛:0.00;熊:0.00;鸟:0.00;猫:0.90;狗:0.10”,第二个基模型输出的分类结果为“牛:0.00;熊:0.15;鸟:0.00;猫:0.90;狗:0.05”,第三个基模型输出的分类结果为“牛:0.00;熊:0.00;鸟:0.05;猫:0.95;狗:0.00”。模型在保证各基模型输出的分类结果的达到一定精度的前提下,使每个基模型输出的分类结果中各类别对应的概率分布呈现多样性。
对于集成模型来说,当模型遭遇对抗样本的攻击时,对抗样本需满足集成模型中各基模型的分类标准,才能使集成模型输出错误的分类结果。由于本说明书中的分类模型的每个切换块中子模块之间沿不同的积分梯度方向进行优化,最终模型的输出结果是由各层切换块中随机选择的子模块组成的网络确定。如果遭遇对抗样本的攻击,此对抗样本中的扰动添加方向需同时满足此模型中所有基模型的积分梯度优化方向,这通常是比较困难甚至不可能实现的,所以按照本说明书中的模型训练方法训练出的分类模型,具有很好的鲁棒性。
图4(a)、图4(b)和图4(c)为本说明书实施例提供的一种基于多样性和随机策略的模型训练方法的鲁棒性增强原理图,阴影部分表示对抗样本需满足的条件。如图4(a)所示,若待训练的分类模型只包含一个基模型,即为普通的单一分类器时,对抗样本只需满足这一个模型的分类标准,沿着这一个基模型的优化方向添加扰动信息,就可能得到使这个分类模型做出误判的对抗样本。如图4(b)所示,若待训练的分类模型包含两个基模型,此分类模型的分类结果在两个基模型的分类结果中随机选定,对抗样本需同时满足这两个基模型的分类标准,才可能使这个分类模型做出误判,对抗样本生成的难度增加。而当两个基模型的优化方向不同,尤其是两个基模型沿着相互正交的积分梯度方向进行优化时,根本不存在同时满足两个基模型分类标准的对抗样本。如图4(c)所示,待训练的分类模型包含三个基模型,对抗样本需同时满足这三个基模型的分类标准,难度就更大了。
所以,本说明书中提供的模型训练方法可以很好的提高模型的鲁棒性,本说明书中以分类模型为例对模型的训练方法进行说明,该训练方法也可用于其他任务的模型训练,如识别模型、自然语言处理等,本说明书不做具体限定。
上述步骤S108中,服务器从图像数据集中获取用于训练的样本图像之后,可以对样本图像进行预处理。预处理包括归一化、顺序打乱等。由于深度学习网络在进行训练时的参数一般较小,而像素值一般较大,用较小参数来拟合较大数值会导致模型训练的时间成本加大,所以通常在样本图像用于模型训练之前,对样本图像的像素值进行归一化处理。为了避免样本图像在数据集中按照一定规律存储,邻近部分的图像之间可能存在相关性,直接用于训练不利于模型学习这些图片的特征,所以需要对获取到的样本图像进行顺序打乱处理。
预处理之后,对样本图像按照一定比例进行训练集与测试集的划分,将训练集中的样本图像输入待训练的分类模型,进行模型的迭代优化,模型训练后之后,用测试集测试模型的分类准确度。
但是,模型训练所用的样本数据和模型在应用过程中真正需要进行分类的图像数据通常不是一个数据集,而真正需要分类预测的数据集中数据量又很少,不能直接用来进行模型的训练,所以可能导致模型对于训练集中图片的分类预测效果很好,而测试集中图片的分类预测效果不太好的情况。例如,用于模型训练的数据集可能是从网络上下载的大规模图片数据集A,而模型在应用过程中真正需要进行分类的图像数据集B是用户在终端设备上传的图片,数量较少,不能直接用于模型的训练,且这些图片相较于数据集A中的图片清晰度偏低,所以直接由数据集A训练出的模型对数据集B中这些图片的分类效果并不好。此时,可以将来自数据集A的图片作为训练集,进行模型的训练,选取数据集B的一部分作为验证集,对于训练好的模型进行微调以适应真实场景的分类任务,然后用数据集B的另外一部分作为测试集,测试模型的分类准确度。
上述步骤S108中,插值图像是从基准图像开始按照指定插值步数均匀缩放样本图像像素强度得到的,越靠近基准图像的步数对应的插值图像的像素值对于样本图像的特征显示就越不明显。所以,在计算积分梯度时,可以先确定一个插值区间,该插值区间为一段连续的插值步数,对应于计算积分梯度所需各插值图像。其中,一般取插值区间的起始插值步数为插值步数的一个中间步数,结束步数为最大插值步数,具体插值区间的选取规则本说明书不做具体限定。
在插值区间上进行积分梯度的计算,加快模型迭代速度的同时又能更好的反应输入图像中的每个像素值对于模型最终预测的分类结果的影响。具体的,服务器计算插值区间上各插值图像与该基模型输出的中间图像之间的梯度,将这些梯度进行累加求和得到各插值图像与该基模型输出的中间图像之间的累加梯度。将该累加梯度相对于该插值区间中包含的插值步数的进行平均,得到样本图像与该基模型输出的中间图像之间的初始积分梯度。根据样本图像和基准图像之间的差异,对该初始积分梯度进行调整,得到该基模型的积分梯度。
以上是本说明书提供的基于多样性和随机策略的模型训练方法,基于同样的思路,本说明书还提供了相应的样本生成装置,如图5所示。
图5为本说明书提供的一种基于多样性和随机策略的模型训练装置示意图,具体包括:
获取模块200,获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型;
模型训练模块202,按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已训练好的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像;
分类损失确定模块204,根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,并基于所述样本图像的标注,确定分类损失;
多样性正则损失确定模块206,根据预设的基准图像、所述样本图像以及各基模型输出的中间图像,确定各基模型的积分梯度,并根据各积分梯度确定多样性正则损失,所述多样性正则损失的值越小代表各积分梯度之间的差异越大;
切换块训练模块208,根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块;
分类任务响应模块210,当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
可选地,所述模型训练模块202,具体用于若该层切换块是第一层切换块,则该层切换块的输入为所述样本图像;若该层切换块不是第一层切换块,则将所述样本图像通过已固定参数的各前层切换块后,该层切换块的上一层切换块中随机一子模块的输出,作为该层切换块的输入。
可选地,所述分类损失确定模块204,所述待训练的分类模型的最后一层切换块的各子模块中至少包括输出层,具体用于将该层切换块的输入作为该层切换块中各子模块的输入,确定该层切换块中各子模块的输出;根据随机策略,将该层切换块中各子模块的输出输入之后未固定参数的各层切换块;通过所述最后一层切换块中各子模块的输出层,得到最后一层切换块中各子模块输出的分类结果;确定所述最后一层切换块中各子模块所属的基模型,将最后一层切换块中各子模块输出的分类结果,作为所述最后一层切换块中各子模块所属的基模型输出的分类结果。根据各基模型输出的分类结果和所述样本图像的标注,确定各基模型的分类交叉熵损失;对各基模型的分类交叉熵损失求均值,得到所述待训练的分类模型的分类损失。
可选地,所述多样性正则损失确定模块206,具体用于确定预设的基准图像和插值路径,沿插值路径从基准图像开始按照指定插值步数均匀缩放样本图像像素强度,得到每一个插值步数上对应的插值图像;针对待训练的分类模型的每一个基模型,分别计算各插值图像与该基模型输出的中间图像之间的梯度,将各梯度进行累加求和,得到各插值图像与该基模型输出的中间图像之间的累加梯度;将所述累加梯度相对于所述插值步数进行平均,得到各插值图像与该基模型输出的中间图像之间的初始积分梯度;根据所述样本图像相对于所述基准图像的差异,对所述初始积分梯度进行调整,得到该基模型的积分梯度。将所述待训练的分类模型的各基模型两两组合,确定各组合中两基模型的积分梯度之间的相似度;根据确定出的各相似度,确定所述待训练的分类模型的多样性正则损失。
可选地,所述切换块训练模块208,具体用于根据所述分类损失和所述多样性正则损失,确定所述待训练的分类模型的总损失;以所述总损失最小为目标训练未固定模型参数的各层切换块,直至满足训练结束条件为止,对该层切换块的训练完成;固定该层切换块的模型参数,并对未固定参数的各层切换块进行初始化,训练下一层切换块。
本说明书还提供了一种计算机可读存储介质,该存储介质存储有计算机程序,计算机程序可用于执行上述图1提供的基于多样性和随机策略的模型训练方法。
本说明书还提供了图6所示的电子设备的示意结构图。如图6所述,在硬件层面,该电子设备包括处理器、内部总线、网络接口、内存以及非易失性存储器,当然还可能包括其他业务所需要的硬件。处理器从非易失性存储器中读取对应的计算机程序到内存中然后运行,以实现上述图1所述的基于多样性和随机策略的模型训练方法。当然,除了软件实现方式之外,本说明书并不排除其他实现方式,比如逻辑器件抑或软硬件结合的方式等等,也就是说以下处理流程的执行主体并不限定于各个逻辑单元,也可以是硬件或逻辑器件。
对于一个技术的改进可以很明显地区分是硬件上的改进(例如,对二极管、晶体管、开关等电路结构的改进)还是软件上的改进(对于方法流程的改进)。然而,随着技术的发展,当今的很多方法流程的改进已经可以视为硬件电路结构的直接改进。设计人员几乎都通过将改进的方法流程编程到硬件电路中来得到相应的硬件电路结构。因此,不能说一个方法流程的改进就不能用硬件实体模块来实现。例如,可编程逻辑器件(ProgrammableLogic Device, PLD)(例如现场可编程门阵列(Field Programmable Gate Array,FPGA))就是这样一种集成电路,其逻辑功能由用户对器件编程来确定。由设计人员自行编程来把一个数字系统“集成”在一片PLD上,而不需要请芯片制造厂商来设计和制作专用的集成电路芯片。而且,如今,取代手工地制作集成电路芯片,这种编程也多半改用“逻辑编译器(logic compiler)”软件来实现,它与程序开发撰写时所用的软件编译器相类似,而要编译之前的原始代码也得用特定的编程语言来撰写,此称之为硬件描述语言(HardwareDescription Language,HDL),而HDL也并非仅有一种,而是有许多种,如ABEL(AdvancedBoolean Expression Language)、AHDL(Altera Hardware Description Language)、Confluence、CUPL(Cornell University Programming Language)、HDCal、JHDL(JavaHardware Description Language)、Lava、Lola、MyHDL、PALASM、RHDL(Ruby HardwareDescription Language)等,目前最普遍使用的是VHDL(Very-High-Speed IntegratedCircuit Hardware Description Language)与Verilog。本领域技术人员也应该清楚,只需要将方法流程用上述几种硬件描述语言稍作逻辑编程并编程到集成电路中,就可以很容易得到实现该逻辑方法流程的硬件电路。
控制器可以按任何适当的方式实现,例如,控制器可以采取例如微处理器或处理器以及存储可由该(微)处理器执行的计算机可读程序代码(例如软件或固件)的计算机可读介质、逻辑门、开关、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程逻辑控制器和嵌入微控制器的形式,控制器的例子包括但不限于以下微控制器:ARC 625D、Atmel AT91SAM、Microchip PIC18F26K20 以及Silicone Labs C8051F320,存储器控制器还可以被实现为存储器的控制逻辑的一部分。本领域技术人员也知道,除了以纯计算机可读程序代码方式实现控制器以外,完全可以通过将方法步骤进行逻辑编程来使得控制器以逻辑门、开关、专用集成电路、可编程逻辑控制器和嵌入微控制器等的形式来实现相同功能。因此这种控制器可以被认为是一种硬件部件,而对其内包括的用于实现各种功能的装置也可以视为硬件部件内的结构。或者甚至,可以将用于实现各种功能的装置视为既可以是实现方法的软件模块又可以是硬件部件内的结构。
上述实施例阐明的系统、装置、模块或单元,具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现。一种典型的实现设备为计算机。具体的,计算机例如可以为个人计算机、膝上型计算机、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
为了描述的方便,描述以上装置时以功能分为各种单元分别描述。当然,在实施本说明书时可以把各单元的功能在同一个或多个软件和/或硬件中实现。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
在一个典型的配置中,计算设备包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
内存可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。内存是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、商品或者设备中还存在另外的相同要素。
本领域技术人员应明白,本说明书的实施例可提供为方法、系统或计算机程序产品。因此,本说明书可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本说明书可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本说明书可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本说明书,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本说明书的实施例而已,并不用于限制本说明书。对于本领域技术人员来说,本说明书可以有各种更改和变化。凡在本说明书的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本申请的权利要求范围之内。
Claims (7)
1.一种基于多样性和随机策略的模型训练方法,其特征在于,待训练的分类模型包含多个基模型,每个基模型划分为串联的多层子模块,各基模型中相同层的子模块组成切换块,所述待训练的分类模型由多层切换块组成,所述方法包括:
获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型;
按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已固定参数的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像;
根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,根据各基模型输出的分类结果和所述样本图像的标注,确定各基模型的分类交叉熵损失,对所述各基模型的分类交叉熵损失求均值,得到所述待训练的分类模型的分类损失;
针对所述待训练的分类模型的每个基模型,确定预设的基准图像为,确定预设的插值路径为/>,根据公式
,确定所述样本图像相对于由当前正在训练的第n层切换块、已经固定参数的n-1层切换块和未固定参数的N-n层切换块组成的网络中第j个基模型的积分梯度/>,根据各基模型的积分梯度,以及公式
,确定所述待训练的分类模型的多样性正则损失/>,其中,/>为0~1间的插值常数,N为待训练的分类模型包含的切换块层数,该基模型为第j个基模型,该层切换块为第n层切换块,为所述插值路径的插值步数,/>表示第/>个插值步数,M表示基模型数量,/>为样本图像,/>为各插值步数对应的插值图像相对于由当前正在训练的第n层切换块、已经训练完成的n-1层切换块和未开始训练的N-n层切换块组成的网络对样本图像的分类预测结果的梯度,/>表示第j个基模型与第k个基模型的积分梯度之间的相似度,所述多样性正则损失的值越小代表各积分梯度之间的差异越大;
根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块;
当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
2.如权利要求1所述的方法,其特征在于,针对每一层待训练的切换块,在已固定参数的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像,具体包括:
若该层切换块是第一层切换块,则该层切换块的输入为所述样本图像;
若该层切换块不是第一层切换块,则将所述样本图像通过已固定参数的各前层切换块后,该层切换块的上一层切换块中随机一子模块的输出,作为该层切换块的输入。
3.如权利要求1所述的方法,其特征在于,所述待训练的分类模型的最后一层切换块的各子模块中至少包括输出层;
根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,具体包括:
将该层切换块的输入作为该层切换块中各子模块的输入,确定该层切换块中各子模块的输出;
根据随机策略,将该层切换块中各子模块的输出输入之后未固定参数的各层切换块;
通过所述最后一层切换块中各子模块的输出层,得到最后一层切换块中各子模块输出的分类结果;
确定所述最后一层切换块中各子模块所属的基模型,将最后一层切换块中各子模块输出的分类结果,作为所述最后一层切换块中各子模块所属的基模型输出的分类结果。
4.如权利要求1所述的方法,其特征在于,根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块,具体包括:
根据所述分类损失和所述多样性正则损失,确定所述待训练的分类模型的总损失;
以所述总损失最小为目标训练未固定模型参数的各层切换块,直至满足训练结束条件为止,对该层切换块的训练完成;
固定该层切换块的模型参数,并对未固定参数的各层切换块进行初始化,训练下一层切换块。
5.一种基于多样性和随机策略的模型训练装置,其特征在于,待训练的分类模型包含多个基模型,每个基模型划分为串联的多层子模块,各基模型中相同层的子模块组成切换块,所述待训练的分类模型由多层切换块组成,包括:
获取模块,获取待分类的样本图像以及所述样本图像的标注,并将所述样本图像输入所述待训练的分类模型;
模型训练模块,按照前向传播方向,依次训练各层切换块,针对每一层待训练的切换块,在已训练好的各前层切换块的基础上,确定该层切换块的输入,所述输入为该层切换块的上一层切换块中随机一子模块的输出或样本图像;
分类损失确定模块,根据该层切换块的输入,确定该层切换块中各子模块的输出,根据随机策略确定各基模型输出的分类结果,并基于所述样本图像的标注,根据各基模型输出的分类结果和所述样本图像的标注,确定各基模型的分类交叉熵损失,对所述各基模型的分类交叉熵损失求均值,得到所述待训练的分类模型的分类损失;
多样性正则损失确定模块,针对所述待训练的分类模型的每个基模型,确定预设的基准图像为,确定预设的插值路径为/>,根据公式
,确定所述样本图像相对于由当前正在训练的第n层切换块、已经固定参数的n-1层切换块和未固定参数的N-n层切换块组成的网络中第j个基模型的积分梯度/>,根据各基模型的积分梯度,以及公式
,确定所述待训练的分类模型的多样性正则损失/>,其中,/>为0~1间的插值常数,N为待训练的分类模型包含的切换块层数,该基模型为第j个基模型,该层切换块为第n层切换块,为所述插值路径的插值步数,/>表示第/>个插值步数,/>为样本图像,为各插值步数对应的插值图像相对于由当前正在训练的第n层切换块、已经训练完成的n-1层切换块和未开始训练的N-n层切换块组成的网络对样本图像的分类预测结果的梯度,/>表示第j个基模型与第k个基模型的积分梯度之间的相似度,M表示基模型数量,所述多样性正则损失的值越小代表各积分梯度之间的差异越大;
切换块训练模块,根据所述分类损失和所述多样性正则损失,训练未固定模型参数的各层切换块,直至满足训练结束条件为止,固定该层切换块的模型参数,并训练下一层切换块;
分类任务响应模块,当各层切换块的模型参数均固定时,得到训练完成的分类模型;响应于携带待分类图像的任务请求,将所述待分类图像输入所述训练完成的分类模型,通过在各层切换块中随机选择的子模块组成的网络,所述分类模型输出所述待分类图像的分类结果。
6.一种计算机可读存储介质,其特征在于,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述权利要求1~4任一项所述的方法。
7.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现上述权利要求1~4任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311293176.6A CN117036869B (zh) | 2023-10-08 | 2023-10-08 | 一种基于多样性和随机策略的模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311293176.6A CN117036869B (zh) | 2023-10-08 | 2023-10-08 | 一种基于多样性和随机策略的模型训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117036869A CN117036869A (zh) | 2023-11-10 |
CN117036869B true CN117036869B (zh) | 2024-01-09 |
Family
ID=88641546
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311293176.6A Active CN117036869B (zh) | 2023-10-08 | 2023-10-08 | 一种基于多样性和随机策略的模型训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117036869B (zh) |
Citations (21)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110413812A (zh) * | 2019-08-06 | 2019-11-05 | 北京字节跳动网络技术有限公司 | 神经网络模型的训练方法、装置、电子设备及存储介质 |
CN110705573A (zh) * | 2019-09-25 | 2020-01-17 | 苏州浪潮智能科技有限公司 | 一种目标检测模型的自动建模方法及装置 |
CN111860573A (zh) * | 2020-06-04 | 2020-10-30 | 北京迈格威科技有限公司 | 模型训练方法、图像类别检测方法、装置和电子设备 |
CN113032001A (zh) * | 2021-03-26 | 2021-06-25 | 中山大学 | 一种智能合约分类方法及装置 |
CN113228062A (zh) * | 2021-02-25 | 2021-08-06 | 东莞理工学院 | 基于特征多样性学习的深度集成模型训练方法 |
WO2021164306A1 (zh) * | 2020-09-17 | 2021-08-26 | 平安科技(深圳)有限公司 | 图像分类模型的训练方法、装置、计算机设备及存储介质 |
CN113723367A (zh) * | 2021-10-27 | 2021-11-30 | 北京世纪好未来教育科技有限公司 | 一种答案确定方法、判题方法及装置和电子设备 |
CN113850383A (zh) * | 2021-09-27 | 2021-12-28 | 平安科技(深圳)有限公司 | 文本匹配模型训练方法、装置、电子设备及存储介质 |
WO2022042123A1 (zh) * | 2020-08-25 | 2022-03-03 | 深圳思谋信息科技有限公司 | 图像识别模型生成方法、装置、计算机设备和存储介质 |
CN114494786A (zh) * | 2022-02-16 | 2022-05-13 | 重庆邮电大学 | 一种基于多层协调卷积神经网络的细粒度图像分类方法 |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
WO2022142122A1 (zh) * | 2020-12-31 | 2022-07-07 | 平安科技(深圳)有限公司 | 实体识别模型的训练方法、装置、设备和存储介质 |
CN114764865A (zh) * | 2021-01-04 | 2022-07-19 | 腾讯科技(深圳)有限公司 | 数据分类模型训练方法、数据分类方法和装置 |
WO2022213846A1 (zh) * | 2021-04-07 | 2022-10-13 | 北京三快在线科技有限公司 | 识别模型的训练 |
CN115828162A (zh) * | 2023-02-08 | 2023-03-21 | 支付宝(杭州)信息技术有限公司 | 一种分类模型训练的方法、装置、存储介质及电子设备 |
CN116030309A (zh) * | 2023-02-03 | 2023-04-28 | 之江实验室 | 一种对抗样本生成方法、装置、存储介质及电子设备 |
WO2023070696A1 (zh) * | 2021-10-25 | 2023-05-04 | 中国科学院自动化研究所 | 针对连续学习能力系统的基于特征操纵的攻击和防御方法 |
WO2023077603A1 (zh) * | 2021-11-03 | 2023-05-11 | 深圳先进技术研究院 | 一种异常脑连接预测系统、方法、装置及可读存储介质 |
CN116129185A (zh) * | 2023-01-19 | 2023-05-16 | 北京工业大学 | 一种基于数据和模型协同更新的中医舌象腐腻特征模糊分类方法 |
CN116302898A (zh) * | 2023-05-17 | 2023-06-23 | 之江实验室 | 一种任务治理方法、装置、存储介质及电子设备 |
CN116342888A (zh) * | 2023-05-25 | 2023-06-27 | 之江实验室 | 一种基于稀疏标注训练分割模型的方法及装置 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112669325B (zh) * | 2021-01-06 | 2022-10-14 | 大连理工大学 | 一种基于主动式学习的视频语义分割方法 |
US20230206114A1 (en) * | 2021-12-29 | 2023-06-29 | International Business Machines Corporation | Fair selective classification via a variational mutual information upper bound for imposing sufficiency |
-
2023
- 2023-10-08 CN CN202311293176.6A patent/CN117036869B/zh active Active
Patent Citations (21)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110413812A (zh) * | 2019-08-06 | 2019-11-05 | 北京字节跳动网络技术有限公司 | 神经网络模型的训练方法、装置、电子设备及存储介质 |
CN110705573A (zh) * | 2019-09-25 | 2020-01-17 | 苏州浪潮智能科技有限公司 | 一种目标检测模型的自动建模方法及装置 |
CN111860573A (zh) * | 2020-06-04 | 2020-10-30 | 北京迈格威科技有限公司 | 模型训练方法、图像类别检测方法、装置和电子设备 |
WO2022042123A1 (zh) * | 2020-08-25 | 2022-03-03 | 深圳思谋信息科技有限公司 | 图像识别模型生成方法、装置、计算机设备和存储介质 |
WO2021164306A1 (zh) * | 2020-09-17 | 2021-08-26 | 平安科技(深圳)有限公司 | 图像分类模型的训练方法、装置、计算机设备及存储介质 |
WO2022142122A1 (zh) * | 2020-12-31 | 2022-07-07 | 平安科技(深圳)有限公司 | 实体识别模型的训练方法、装置、设备和存储介质 |
CN114764865A (zh) * | 2021-01-04 | 2022-07-19 | 腾讯科技(深圳)有限公司 | 数据分类模型训练方法、数据分类方法和装置 |
CN113228062A (zh) * | 2021-02-25 | 2021-08-06 | 东莞理工学院 | 基于特征多样性学习的深度集成模型训练方法 |
CN113032001A (zh) * | 2021-03-26 | 2021-06-25 | 中山大学 | 一种智能合约分类方法及装置 |
WO2022213846A1 (zh) * | 2021-04-07 | 2022-10-13 | 北京三快在线科技有限公司 | 识别模型的训练 |
CN113850383A (zh) * | 2021-09-27 | 2021-12-28 | 平安科技(深圳)有限公司 | 文本匹配模型训练方法、装置、电子设备及存储介质 |
WO2023070696A1 (zh) * | 2021-10-25 | 2023-05-04 | 中国科学院自动化研究所 | 针对连续学习能力系统的基于特征操纵的攻击和防御方法 |
CN113723367A (zh) * | 2021-10-27 | 2021-11-30 | 北京世纪好未来教育科技有限公司 | 一种答案确定方法、判题方法及装置和电子设备 |
WO2023077603A1 (zh) * | 2021-11-03 | 2023-05-11 | 深圳先进技术研究院 | 一种异常脑连接预测系统、方法、装置及可读存储介质 |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
CN114494786A (zh) * | 2022-02-16 | 2022-05-13 | 重庆邮电大学 | 一种基于多层协调卷积神经网络的细粒度图像分类方法 |
CN116129185A (zh) * | 2023-01-19 | 2023-05-16 | 北京工业大学 | 一种基于数据和模型协同更新的中医舌象腐腻特征模糊分类方法 |
CN116030309A (zh) * | 2023-02-03 | 2023-04-28 | 之江实验室 | 一种对抗样本生成方法、装置、存储介质及电子设备 |
CN115828162A (zh) * | 2023-02-08 | 2023-03-21 | 支付宝(杭州)信息技术有限公司 | 一种分类模型训练的方法、装置、存储介质及电子设备 |
CN116302898A (zh) * | 2023-05-17 | 2023-06-23 | 之江实验室 | 一种任务治理方法、装置、存储介质及电子设备 |
CN116342888A (zh) * | 2023-05-25 | 2023-06-27 | 之江实验室 | 一种基于稀疏标注训练分割模型的方法及装置 |
Non-Patent Citations (1)
Title |
---|
基于条件生成式对抗网络的数据增强方法;陈文兵;管正雄;陈允杰;;计算机应用(第11期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN117036869A (zh) | 2023-11-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114821823B (zh) | 图像处理、人脸防伪模型的训练及活体检测方法和装置 | |
CN115618964B (zh) | 一种模型训练的方法、装置、存储介质及电子设备 | |
CN116304720A (zh) | 一种代价模型训练的方法、装置、存储介质及电子设备 | |
CN114943307A (zh) | 一种模型训练的方法、装置、存储介质以及电子设备 | |
CN116049761A (zh) | 数据处理方法、装置及设备 | |
CN114372566A (zh) | 图数据的增广、图神经网络训练方法、装置以及设备 | |
CN117036829A (zh) | 一种基于原型学习实现标签增强的叶片细粒度识别方法和系统 | |
CN117036869B (zh) | 一种基于多样性和随机策略的模型训练方法及装置 | |
CN117392694A (zh) | 数据处理方法、装置及设备 | |
CN116308738B (zh) | 一种模型训练的方法、业务风控的方法及装置 | |
CN116403097A (zh) | 一种目标检测方法、装置、存储介质及电子设备 | |
CN116805393A (zh) | 一种基于3DUnet光谱-空间信息融合的高光谱图像分类方法和系统 | |
CN116152933A (zh) | 一种异常检测模型的训练方法、装置、设备及存储介质 | |
CN115618748A (zh) | 一种模型优化的方法、装置、设备及存储介质 | |
CN117036870B (zh) | 一种基于积分梯度多样性的模型训练和图像识别方法 | |
CN112417275A (zh) | 一种信息提供的方法、装置存储介质及电子设备 | |
CN116991388B (zh) | 一种深度学习编译器的图优化序列生成方法及装置 | |
CN117392374B (zh) | 目标检测方法、装置、设备及存储介质 | |
CN115545938B (zh) | 一种执行风险识别业务的方法、装置、存储介质及设备 | |
CN112884478B (zh) | 一种数据处理方法、装置及设备 | |
CN116340852B (zh) | 一种模型训练、业务风控的方法及装置 | |
CN116363418A (zh) | 一种训练分类模型的方法、装置、存储介质及电子设备 | |
CN117576522B (zh) | 一种基于拟态结构动态防御的模型训练方法及装置 | |
CN117592056A (zh) | 一种模型的防盗取检测方法、装置、存储介质及电子设备 | |
WO2024113932A1 (zh) | 一种模型优化的方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |