CN116562362A - 一种基于混合策略博弈的对抗训练微调方法 - Google Patents
一种基于混合策略博弈的对抗训练微调方法 Download PDFInfo
- Publication number
- CN116562362A CN116562362A CN202310500553.2A CN202310500553A CN116562362A CN 116562362 A CN116562362 A CN 116562362A CN 202310500553 A CN202310500553 A CN 202310500553A CN 116562362 A CN116562362 A CN 116562362A
- Authority
- CN
- China
- Prior art keywords
- training
- model
- countermeasure
- fine tuning
- game
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 122
- 238000000034 method Methods 0.000 title claims abstract description 57
- 238000005070 sampling Methods 0.000 claims abstract description 14
- 238000009826 distribution Methods 0.000 claims description 31
- 230000006870 function Effects 0.000 claims description 28
- 230000008569 process Effects 0.000 claims description 18
- 238000005457 optimization Methods 0.000 claims description 9
- 238000013461 design Methods 0.000 claims description 6
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000012804 iterative process Methods 0.000 claims description 3
- 238000010200 validation analysis Methods 0.000 claims description 2
- 238000003058 natural language processing Methods 0.000 description 4
- 238000011156 evaluation Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000012512 characterization method Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 241000701959 Escherichia virus Lambda Species 0.000 description 1
- 239000000427 antigen Substances 0.000 description 1
- 102000036639 antigens Human genes 0.000 description 1
- 108091007433 antigens Proteins 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
- G06N5/042—Backward inferencing
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Databases & Information Systems (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于混合策略博弈的对抗训练微调方法。本发明步骤如下:S1:确定预训练模型、目标数据集、训练任务;S2:微调预训练模型;S3:设计基于混合策略博弈的对抗训练微调目标函数;S4:求解混合策略博弈;S5:生成与更新对抗扰动;S6:更新模型参数;S7:训练与评估模型。本发明包括将混合策略博弈引入预训练模型进行微调的对抗训练中,用的博弈论方法‑熵镜下降法推导出纳什均衡来解决上述博弈。此外本发明还利用采样定理和随机梯度郎之万动力学采样法将该方法简化为一种性能优化的实用算法。通过本发明方法训练得到的模型在泛化性能和鲁棒性能上都可以得到提升。
Description
技术领域
本发明涉及机器学习领域,尤其涉及一种基于基于混合策略博弈的对大规模预训练模型进行对抗训练微调的方法及装置。
背景技术
基于大规模文本数据的预训练模型(如BERT、GPT、T5等)在几乎所有的自然语言处理任务中都取得了显著的进展。例如,BERT的基本技术突破是使用双向训练的Transformer和注意力模型来执行语言建模。与早期从左到右或双向训练相结合的文本序列的研究相比,双向训练的语言模型可以更好地理解语言上下文。BERT使用注意力机制以及学习单词之间上下文关系的Transformer。Transformer由两个独立的部分组成——编码器和解码器。编码器读取输入文本,解码器为任务生成预测。与顺序读取输入文本的传统定向模型(如LSTM等)相比,Transformer的编码器一次读取整个单词序列。
针对文本分类、Q&A等不同类型的下游任务,可以通过对与训练模型进行微调实现。虽然预训练模型的大量参数增强了它的能力,但也使模型的开发、训练和使用变得困难。主要问题是,在微调过程中,预训练模型可能会过拟合目标任务的训练数据,导致泛化性能较差。最近的一些研究表明,将微调与对抗训练相结合可以成功缓解上述问题,并提高模型在下游任务中的泛化能力。此外,微调阶段的对抗性训练主要用作正则化方法,以防止过拟合,而不是在计算机视觉中广泛应用的保护模型免受对抗性攻击。
发明内容
本申请主要解决的技术问题是提供一种基于混合策略博弈的对抗训练微调方法,能够改善自然语言处理的预训练模型微调过程中的过拟合问题,并提高模型的泛化能力和防御对抗攻击能力。本说明书中的一个或多个实施例描述了从预训练模型的检查点对大规模模型进行微调,已被证明对各种自然语言处理的任务是有效的。然而,由于目标的非凸性,以往的对抗训练方法容易收敛到局部最优。
针对现有技术方案的不足,本发明提供一种基于混合策略博弈的对抗训练微调方法。
本发明解决该技术问题所采用的技术方案包括以下步骤:
S1:确定预训练模型、目标数据集、训练任务;
S2:预训练模型微调;
S3:基于混合策略博弈的对抗训练微调目标函数设计;
S4:混合策略博弈求解;
S5:对抗扰动生成与更新;
S6:模型参数更新;
S7:模型训练与评估。
步骤S1和步骤S2执行的目标在于获得一个针对目标数据集和任务进行初步微调后的预训练模型。
步骤S1所述的确定预训练模型、目标数据集、训练任务,具体实现流程如下:
1-1.确定目标预训练模型fθ(·)和目标数据集D。其中,目标预训练模型fθ(·)已包含原始模型参数θ。目标数据集D={(x,y)},x表示数据样本,y表示对应的标签。
1-2.根据fθ(·)和D,确认训练任务与微调目标函数L(fθ(x),y)。同时,根据选定的目标任务修改fθ(·)的顶层模型结构。例如,目标模型为C类的多分类模型,可以将fθ(·)的顶层结构替换为输出维度为C的全连接层。
步骤S2所述的预训练模型微调,具体实现流程如下:
2-1.从目标数据集D中随机采样N个样本数据输入模型fθ(·)。
2-2.针对给定的微调目标函数L(fθ(x),y),计算损失值,并进行反向传播,更新模型参数θ。
2-3.重复2-1至2-2,直至模型收敛,得到微调后的模型fθ(·)。
步骤S3-S6执行的目标在于通过对抗训练微调对初步微调后的模型性能进行进一步优化。本发明将对抗性训练视为一种博弈,并利用博弈论中的混合策略对其进行改进。本发明主张对抗训练是模型和对抗扰动之间的两方完全信息博弈。现有的对抗性训练是符合纯策略博弈的,双方的策略都是具体确定的。相反,本发明将对抗训练扩展到策略为概率的混合策略博弈中,将现有的策略转化为概率性的。
步骤S3所述的基于混合策略博弈的对抗训练微调目标函数设计,具体实流程如下:
3-1.基于混合策略博弈的对抗训练微调目标函数设计。一般的预训练模型,如BERT,的模型参数θ都被认为是可确定的一般变量。在本发明中,参数的取值可称为模型的策略。由于参数值的连续性,该模型存在无限种策略。但无论使用哪种优化器,参数的每次更新都是确定的,这意味着模型每次只选择一种纯策略,因此这仍然是纯策略。为了提升对抗训练过程中模型的泛化性能,在本发明将模型从纯策略转换为混合策略,使用的方法是让模型参数θ服从概率分布,即θ转变成一个连续的随机变量。这样,在模型训练过程中,模型将更新一个其分布而不是选择一个确定的策略值。类似地,本发明也将对抗扰动δ从确定的扰动值转换为一个连续随机变量。因此,本发明将对抗训练的混合策略博弈重新定义为:两个博弈方分别为模型和对抗扰动,双方的策略分别为其参数的分布,而博弈的收益为目标函数的值。
针对初步微调后的模型fθ(·)和目标数据集D,对抗训练微调目标设定为
上述目标函数中,Θ表示θ服从的概率分布的集合,Δ表示δ服从的概率分布的集合。l(fθ(x+δ),fθ(x))用于描述生成的对抗样本与原始样本对于目标模型的相似度。对于给定的数据集D,随机采样一批数据B,在最小化模型在批训练数据B上的损失值的条件下,最大化对抗扰动与原始数据之间的差异。λ表示两者之间的调谐参数。进一步地,本发明用M(Θ)和M(Δ)表示Θ和Δ上的所有Borel概率度量的集合,将原对抗训练的目标函数转换为以下的Min-Max函数:
步骤S4所述的混合策略博弈求解,具体实现流程如下:
4-1.步骤S3中得到的Min-Max优化目标为纳什均衡博弈问题。本发明采用熵镜像下降算法(Entropic Mirror Descent,EMD)求解该博弈问题。将上述Min-Max优化目标记为P,则可以通过下式对
其中,μt和vt分别为第t轮的得到的扰动和参数分布。给定随机变量z,对应梯度h和学习率η,EMD中的MD迭代可表示为:
将该MD迭代过程扩展到无穷维,可得
由于无法获得μ的密度函数和v的密度函数,本发明采用一个常用的方法即用经验平均值代替对应分布的期望
其中,表示Kδ次采样的δ均值,/>表示Kθ次采样的θ均值。同时,MD迭代也可表示为一种更容易处理的形式:
综上,本发明通过T轮的MD迭代,可以求解上述纳什均衡博弈问题。
步骤S5所述的对抗扰动生成与更新,根据步骤S4中所述迭代求解过程,首先计算对抗样本经验平均,具体实流程如下:
5-1.给定每一轮中用于计算对抗样本经验平均的次数Kδ。
5-2.初始化第t(t<T)轮的对抗样本初始分布和经验平均/>输入大小为n的批训练数据B,利用随机梯度朗日万动力学采样,更新/>
其中,γt表示采样步长,ε是热噪声,ξ=N(0,1)是标准正态分布。
5-3.根据下式计算
其中,β为超参数,用于均衡历史均值和当前分布对经验均值的影响。
5-4.重复5-2至5-3Kδ次,得到第t轮的对抗样本经验平均
步骤S6所述模型参数更新,根据步骤S4中所述迭代求解过程,需要计算模型参数的经验平均,具体实现流程如下:
6-1.给定每一轮对抗训练中用于计算模型参数经验平均的次数Kθ。
6-2.初始化模型第t(t<T)轮模型参数和经验平均/>在输入数据大小为n的批训练数据B上,根据步骤S4中生成的对抗扰动经验平均/>利用随机梯度朗日万动力学采样,更新/>
6-3.根据下式计算
其中,β为超参数,用于均衡历史均值和当前分布对经验均值的影响。
6-4.重复6-2至6-3Kθ次,得到第t轮的模型参数经验平均并更新模型权重:
步骤S7所述的模型训练与评估,具体实现流程如下:
7-1.重复步骤S5至S6T次,得到对抗训练优化后的模型
7-2.使用测试数据进行模型性能评估。为了验证所提出方法的有效性,可将对抗训练微调后的模型与未经过对抗训练微调的模型fθ在测试数据上的性能进行比较。
本发明有益效果如下:
本发明从博弈论的混合策略角度对对抗训练进行了重新规划,并引入了完整的策略空间。在方法上,利用熵镜下降法推导了混合策略对抗训练的纳什均衡,发明了一种新的混合策略对抗训练算法。同时在数值上,验证了当本发明运用在BERT和RoBERTa等大规模预训练模型上时,本申请书实现的方法训练的模型在泛化和鲁棒性方面都优于现有的技术。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的一些附图仅仅是本申请的部分实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。其中:
图1是本发明的神经网络模型对于自然语言处理分类任务的句子输入示意图其中图1.A是两个句子的输入情况,图1.B是一个句子的输入情况;
图2是预训练模型微调的流程示意图;
图3是预训练模型对抗训练微调的流程示意图;
图4是本发明提出的混合策略对抗训练微调算法的流程示意图;
图5是利用熵镜像下降算法得到的训练算法伪代码;
图6是利用随机梯度郎之万动力学采样法得到的训练算法伪代码。
具体实施方式
为了使本技术领域的人员更好地理解本申请中的技术方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本申请保护的范围。
一种基于混合策略博弈的对抗训练微调方法,包括针对目标数据集的微调训练和基于混合博弈策略的对抗训练微调2个部分,具体实施方式如下:
S1:确定预训练模型、目标数据集、训练任务;
S2:预训练模型微调;
S3:基于混合策略博弈的对抗训练微调目标函数设计;
S4:混合策略博弈求解;
S5:对抗扰动生成与更新;
S6:模型参数更新;
S7:模型训练与评估。
步骤S1确定预训练模型、目标数据集、训练任务,具体过程如下:
1-1.本实施例以BERT模型作为目标预训练模型,并获取公开的BERT模型预训练权重θ。本实施例以文本分类作为目标任务,选取文本分类目标数据集(如AGNews),并对数据进行预处理。
对于下游任务的输入,可以参考图1所示,下游任务的输入有可能是一个句子也有可能是句子对,为了使BERT可以处理所有这些任务,BERT的输入既可以是一个句子,也可以是句子对。具体来讲:一个句子可以是一段连续的文字,不一定是真正语义上的一个句子。BERT的输入叫做一个序列,所谓序列,可以是一个句子,也可以是两个句子。可能存在单句和双句的情况,但BERT通常是将序列的第一个词永远是一个特殊的记号[CLS],代表是分类(Classification)。这个词的作用是,BERT希望它最后的输出代表是整个序列的信息,比如说对整个句子层面的信息。在图1中可以看到输入的组成:[CLS]和[SEP]。每个token进入BERT模型,得到token的Embedding表示。
1-2.将BERT的顶层模型修改为输出维度为4的3层全联接网络,目标任务的损失函数为cross-entropy。
步骤S2预训练模型微调,具体过程如下:
2-1.在步骤1-1处理后的文数据集上,随机采样batchsize为N的批数据,输入修改后的BERT模型;
2-2.使用Adam作为优化器,通过反向传播对修改后的BERT权重θ进行微调;
2-3.重复2-1至2-2直至模型收敛,并得到目标模型fθ,具体流程如图2所示。
步骤S3基于混合策略博弈的对抗训练微调目标函数设计,具体过程如下:
3-1.本发明将模型参数与对抗扰动视为博弈双方,并采用混合策略对该博弈过程进行求解。对于从目标数据集D中采样的批训练数据B,该博弈过程可以表示为以下Min-Max目标函数:
其中,μ表示θ服从的分布,v表示δ服从的分布,Θ表示所有可能的θ所属分布的集合,Δ表示所有可能的ξ所属分布的集合。M(Θ)和M(Δ)表示Θ和Δ上的所有Borel概率度量的集合。L(fθ(x),y)为S1中确定的目标任务函数,l(fθ(x+δ),fθ(x))为对抗训练损失函数。在本实施例的分类任务中,选择l(·)为KL-divergence值,即在回归任务中,l一般为L2距离损失,即l(a,b)=(a-b)2。λ是两个损失函数的调谐参数。
步骤S4混合策略博弈求解,具体过程如下:
4-1本发明采用熵镜像下降算法(EntropicMirrorDescent,EMD)步骤S3中设计的Min-Max混合策略博弈问题。将上述Min-Max优化目标记为F,则通过下式对δ和θ进行更新:
其中,μt和vt分别为第t轮的得到的扰动和参数分布。将MD迭代过程扩展至无限维,并通过T次对抗训练,可得:
本发明提出的对抗训练流程如图4所示,基于MD迭代求解过程的流程伪代码如图5所示。
步骤S5对抗扰动生成与更新,具体过程如下:
5-1.设定对抗样本经验平均次数Kδ=K。
5-2.初始化第t(t<T)轮的对抗样本初始分布为F维随机向量,F为文本在BERT模型中的嵌入维度。同时,初始化经验平均/>从处理过后的目标文本数据集中,随机采样n个文本样本{xi}n,输入微调后的BERT模型,依据步骤S4中提出混合策略博弈求解,利用随机梯度朗日万动力学采样,更新/>
其中,γt表示采样步长,ε是热噪声,ξ=N(0,1)是标准正态分布。
5-3.将当前训练轮次计算得到的文本表征向量扰动加入其经验平均,并对扰动的经验平均值进行更新:
其中,β为超参数,用于均衡历史均值和当前分布对经验均值的影响。
5-4.重复5-2至5-3Kδ次,得到第t轮文本表征向量扰动的经验平均
步骤S6模型参数更新,具体过程如下:
6-1.设定对抗训练中用于计算模型参数经验平均的次数Kθ=K。
6-2.使用第t-1轮获取的BERT模型权重初始化模型第t轮模型参数 并初始化第t轮模型参数经验平均/>在步骤5-2采样的大小为n的批训练数据B上,结合步骤S5中生成的对抗扰动经验平均/>依据步骤S4中提出混合策略博弈求解,利用随机梯度朗日万动力学采样,更新/>
6-3.将当前训练轮次计算得到的BERT模型参数加入其经验平均,并对扰动的经验平均值进行更新/>
其中,β为超参数,用于均衡历史均值和当前分布对经验均值的影响。
6-4.重复6-2至6-3Kθ次,得到第t轮的模型参数经验平均并更新BERT模型权重:
步骤S7模型训练与评估,具体过程如下:
7-1.重复步骤S5至S6T次,得到对抗训练优化后的用于文本分类的BERT模型
7-2.从选定的目标文本数据集中选取测试数据,分别输入对抗训练微调后的BERT模型和未经对抗训练微调的BERT模型fθ,并比较两者之间的性能,以说明本发明所提出的基于混合策略博弈的对抗训练方法的有效性。
综上所述,图4是本申请提出的混合策略对抗训练微调算法的流程示意图。一般地,通过图6所示的相关伪代码可以实现本发明的整个模型参数的更新流程。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (7)
1.一种基于混合策略博弈的对抗训练微调方法,其特征在于包括如下步骤:
S1:确定预训练模型、目标数据集、训练任务;
S2:微调预训练模型;
S3:设计基于混合策略博弈的对抗训练微调目标函数;
S4:求解混合策略博弈;
S5:生成与更新对抗扰动;
S6:更新模型参数;
S7:训练与评估模型。
2.根据权利要求1所述的一种基于混合策略博弈的对抗训练微调方法,其特征在于步骤S1确定预训练模型、目标数据集、训练任务,具体实现如下:
1-1.确定目标预训练模型fθ(·)和目标数据集D,其中,目标预训练模型fθ(·)已包含原始模型参数θ;目标数据集D={(x,y)},x表示数据样本,y表示对应的标签;
1-2.根据目标预训练模型fθ(·)和目标数据集D,确认训练任务与微调目标函数L(fθ(x),y);
1-3.根据选定的目标任务改进目标预训练模型fθ(·)的顶层结构。
3.根据权利要求2所述的一种基于混合策略博弈的对抗训练微调方法,其特征在于步骤S2所述的预训练模型微调,具体实现如下:
2-1.从目标数据集D中随机采样N个样本数据输入模型fθ(·);
2-2.针对给定的微调目标函数L(fθ(x),y),计算损失值,并进行反向传播,更新模型参数θ;
2-3.重复2-1至2-2,直至模型收敛,得到微调后的模型fθ(·)。
4.根据权利要求2所述的一种基于混合策略博弈的对抗训练微调方法,其特征在于步骤S3设计基于混合策略博弈的对抗训练微调目标函数,具体实现如下:
将模型从纯策略转换为混合策略,让模型参数θ服从概率分布,即θ转变成一个连续的随机变量;将对抗扰动δ从确定的扰动值转换为一个连续随机变量,服从概率分布;将对抗训练的混合策略博弈重新定义为:两个博弈方分别为模型和对抗扰动,双方的策略分别为其参数的分布,而博弈的收益为目标函数的值;
给定目标预训练模型fθ(·),下游目标数据集D,对抗训练目标表示为以下博弈过程:
其中,l(fθ(x+δ),fθ(x))表示对抗训练目标函数,λ为调谐参数;考虑Θ和Δ上的所有概率分布集合;如果用M(Θ)和M(Δ)表示Θ和Δ上的所有Borel概率度量的集合,则将原对抗训练的目标函数转换为以下的Min-Max函数:
5.根据权利要求4所述的一种基于混合策略博弈的对抗训练微调方法,其特征在步骤S4具体方法如下:
使用熵镜像下降算法对Min-Max博弈优化进行求解:
其中,P表示上述Min-Max博弈优化目标;μt和vt分别为第t轮的得到的扰动和参数分布;给定随机变量z,对应梯度h和学习率η,EMD中的MD迭代可表示为:
将该MD迭代过程扩展到无穷维,可得:
由于无法获得μ的密度函数和v的密度函数,采用经验平均值代替对应分布的期望E:
其中,表示Kδ次采样的δ均值,/>表示Kθ次采样的θ均值;同时,MD迭代表示为一种更容易处理的形式:
因此通过T轮的MD迭代能够求解上述纳什均衡博弈问题。
6.根据权利要求5所述的一种基于混合策略博弈的对抗训练微调方法,其特征在步骤S5具体方法如下:
5-1.给定每一轮中用于计算对抗样本经验平均的次数Kδ;
5-2.初始化第t轮的对抗样本初始分布和经验平均/>输入大小为n的批训练数据B,利用随机梯度朗日万动力学采样,得到扰动更新的迭代式,更新/>
其中,0≤k<Kδ-1,表示第t轮对抗训练优化过程中第k次采样的对抗扰动分布,γt表示采样步长,ε是热噪声,ξ=N(0,1)是标准正态分布;/>表示对l进行求关于δ的梯度;
5-3.根据下式计算
其中,β为超参数,用于均衡历史均值和当前分布对经验均值的影响;
5-4.重复步骤5-2至5-3Kδ次,得到第t轮的对抗样本经验平均
7.根据权利要求6所述的一种基于混合策略博弈的对抗训练微调方法,其特征在步骤S6具体方法如下:
6-1.给定每一轮对抗训练中用于计算模型参数经验平均的次数Kθ;
6-2.初始化模型第t轮模型参数和经验平均/>在输入数据大小为n的批训练数据B上,根据步骤S4中生成的对抗扰动经验平均/>利用随机梯度朗日万动力学采样,更新/>
6-3.根据下式计算
其中,0≤k<Kθ-1,β为超参数,用于均衡历史均值和当前分布对经验均值的影响;
6-4.重复步骤6-2至6-3Kθ次,表示第t轮对抗训练优化过程中第k次采样的模型参数分布,最终得到第t轮的模型参数经验平均/>并更新模型权重:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310500553.2A CN116562362A (zh) | 2023-05-06 | 2023-05-06 | 一种基于混合策略博弈的对抗训练微调方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310500553.2A CN116562362A (zh) | 2023-05-06 | 2023-05-06 | 一种基于混合策略博弈的对抗训练微调方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116562362A true CN116562362A (zh) | 2023-08-08 |
Family
ID=87499448
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310500553.2A Pending CN116562362A (zh) | 2023-05-06 | 2023-05-06 | 一种基于混合策略博弈的对抗训练微调方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116562362A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117236701A (zh) * | 2023-11-13 | 2023-12-15 | 清华大学 | 一种基于博弈分析的鲁棒风险识别方法 |
-
2023
- 2023-05-06 CN CN202310500553.2A patent/CN116562362A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117236701A (zh) * | 2023-11-13 | 2023-12-15 | 清华大学 | 一种基于博弈分析的鲁棒风险识别方法 |
CN117236701B (zh) * | 2023-11-13 | 2024-02-09 | 清华大学 | 一种基于博弈分析的鲁棒风险识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Murugan et al. | Regularization and optimization strategies in deep convolutional neural network | |
CN108197736B (zh) | 一种基于变分自编码器和极限学习机的空气质量预测方法 | |
CN109033095B (zh) | 基于注意力机制的目标变换方法 | |
CN111563706A (zh) | 一种基于lstm网络的多变量物流货运量预测方法 | |
CN110163433B (zh) | 一种船舶流量预测方法 | |
CN107578061A (zh) | 基于最小化损失学习的不平衡样本分类方法 | |
CN109345027B (zh) | 基于独立成分分析与支持向量机的微电网短期负荷预测方法 | |
CN111477247A (zh) | 基于gan的语音对抗样本生成方法 | |
CN116562362A (zh) | 一种基于混合策略博弈的对抗训练微调方法 | |
Wehenkel et al. | Diffusion priors in variational autoencoders | |
CN111477220A (zh) | 一种面向家居口语环境的神经网络语音识别方法及系统 | |
CN114548591A (zh) | 一种基于混合深度学习模型和Stacking的时序数据预测方法及系统 | |
CN115471665A (zh) | 基于三分图视觉Transformer语义信息解码器的抠图方法与装置 | |
CN113705724B (zh) | 基于自适应l-bfgs算法的深度神经网络的批量学习方法 | |
CN114691858A (zh) | 一种基于改进的unilm摘要生成方法 | |
CN115599918B (zh) | 一种基于图增强的互学习文本分类方法及系统 | |
Yang et al. | Towards stochastic neural network via feature distribution calibration | |
Mendonça et al. | Adversarial training with informed data selection | |
CN115510986A (zh) | 一种基于AdvGAN的对抗样本生成方法 | |
CN114743049A (zh) | 一种基于课程学习与对抗训练的图像分类方法 | |
CN108417204A (zh) | 基于大数据的信息安全处理方法 | |
Joo et al. | Towards more robust interpretation via local gradient alignment | |
CN108388942A (zh) | 基于大数据的信息智能处理方法 | |
Liu et al. | Unstoppable Attack: Label-Only Model Inversion via Conditional Diffusion Model | |
CN114492596A (zh) | 基于变分自编码器的成员推理攻击抵御方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |