CN112580732B - 模型训练方法、装置、设备、存储介质和程序产品 - Google Patents
模型训练方法、装置、设备、存储介质和程序产品 Download PDFInfo
- Publication number
- CN112580732B CN112580732B CN202011563834.5A CN202011563834A CN112580732B CN 112580732 B CN112580732 B CN 112580732B CN 202011563834 A CN202011563834 A CN 202011563834A CN 112580732 B CN112580732 B CN 112580732B
- Authority
- CN
- China
- Prior art keywords
- round
- model
- training
- sample set
- trained
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 227
- 238000000034 method Methods 0.000 title claims abstract description 61
- 230000004044 response Effects 0.000 claims abstract description 13
- 230000006399 behavior Effects 0.000 claims description 22
- 239000013598 vector Substances 0.000 claims description 22
- 238000002372 labelling Methods 0.000 claims description 21
- 238000004364 calculation method Methods 0.000 claims description 17
- 238000012545 processing Methods 0.000 claims description 11
- 238000011478 gradient descent method Methods 0.000 claims description 6
- 238000010606 normalization Methods 0.000 claims description 5
- 238000013473 artificial intelligence Methods 0.000 abstract description 3
- 238000013135 deep learning Methods 0.000 abstract description 3
- 230000008569 process Effects 0.000 description 13
- 238000010586 diagram Methods 0.000 description 11
- 238000004590 computer program Methods 0.000 description 10
- 238000004891 communication Methods 0.000 description 8
- 230000006870 function Effects 0.000 description 8
- 238000005516 engineering process Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000006872 improvement Effects 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000007667 floating Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种模型训练方法、装置、设备、存储介质和程序产品,涉及计算机技术领域,尤其涉及人工智能和深度学习技术领域。具体实现方案为:在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。本申请实施例的技术方案,提高了模型的泛化性能。
Description
技术领域
本申请涉及计算机技术领域,尤其涉及人工智能和深度学习技术,具体涉及一种模型训练方法、装置、设备、存储介质和程序产品。
背景技术
随着计算机技术的迅速发展,深度学习技术在图像分类识别、自然语言处理等技术领域得到广泛应用。
在深度神经网络模型使用过程中,经常因输入特征的细微变化,导致分类错误,因此,提高深度神经网络模型的泛化性能和鲁棒性显得十分重要。
发明内容
本申请提供了一种模型训练方法、装置、设备、存储介质和程序产品。
根据本申请的一方面,提供了一种模型训练方法,所述方法包括:
在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;
根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;
使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;
将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
根据本申请的另一方面,提供了一种模型训练装置,所述装置包括:
损失计算模块,用于在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;
对抗样本获取模块,用于根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;
模型训练模块,用于使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;
目标模型获取模块,用于将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
根据本申请的另一方面,提供了一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本申请实施例中任一项所述的模型训练方法。
根据本申请的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行本申请实施例中任一项所述的模型训练方法。
根据本申请的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本申请实施例中任一项所述的模型训练方法。
根据本申请的技术提高了模型的泛化性能。
应当理解,本部分所描述的内容并非旨在标识本申请的实施例的关键或重要特征,也不用于限制本申请的范围。本申请的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本申请的限定。其中:
图1是根据本申请实施例的一种模型训练方法的示意图;
图2a是根据本申请实施例的另一种模型训练方法的示意图;
图2b是根据本申请实施例的一种对抗样本生成示意图;
图3是根据本申请实施例的又一种模型训练方法的示意图;
图4根据本申请实施例的一种模型训练装置的结构示意图;
图5是用来实现本申请实施例的模型训练方法的电子设备的框图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
图1是本申请实施例中的一种模型训练方法的示意图,本申请实施例的技术方案适用于在模型训练过程中生成对抗样本的情况,该方法可以由模型训练装置执行,该装置可以通过软件,和/或硬件的方式实现,并一般可以集成在电子设备中,例如终端设备中,本申请实施例的方法具体包括以下:
S110、在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失。
其中,样本集合中包括多个样本,每个样本都可由输入特征,以及与输入特征对应的标注数据构成,示例性的,样本包含的输入特征是与用户行为对应的词向量,标注数据是用户画像中的某一特征,例如,用户性别。
训练损失用于表示模型输出结果与标注数据的偏差,训练损失越大,表示输出偏差越大,本轮训练损失是指针对当前输入的本轮样本集,得到的输出结果与标注数据的偏差,以分类问题为例,用于计算训练损失的损失函数一般设置为交叉熵。
本申请实施例中,根据预先设置的模型训练规则,将样本集合中的样本分为多个部分,并将各部分样本按照顺序依次输入至待训练模型进行多轮训练,在每一轮训练过程中,首先在样本集合中获取本轮样本集,并输入到待训练模型中,然后根据待训练模型的输出结果,通过预先设置的损失函数,计算本轮训练损失。
示例性的,样本集合中包含1000条样本,将1000条样本分为100份,在每一轮模型训练中,首先获取本轮学习的10条样本输入至待训练模型中,根据待训练模型的输出结果和与输入样本对应的标注数据,计算本轮训练损失,以找到待训练模型的薄弱环节,生成针对当前待训练模型薄弱环节的对抗样本。
S120、根据本轮训练损失,确定本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集。
其中,对抗样本指在数据集中通过添加细微的干扰所形成的输入样本,导致模型以高置信度给出一个错误的输出。
本实申请施例中,为了针对待训练模型的薄弱环节生成对抗样本,根据本轮训练损失,来计算本轮扰动项,然后将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集,以根据本轮样本集和本轮对抗样本集共同对待训练模型进行训练。在模型训练时,模型在每一轮训练过程中都会更新,针对每一轮的待训练模型生成的每一组对抗样本,均是针对当前待训练模型的,更加有针对性,有效提高模型泛化性能。
示例性的,基于本轮训练损失,对输入特征求梯度,根据梯度值,确定针对当前待训练模型的本轮扰动项,最终将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集。
S130、使用本轮样本集和本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型。
本实申请施例中,在得到本轮对抗样本集后,使用本轮样本集和本轮对抗样本集共同进行模型训练,具体的,将本轮样本集和本轮对抗样本集一起输入至待训练模型中,通过预先设定的损失函数,根据模型输出结果和标注数据,计算总损失,根据总损失来调节模型参数,由于模型训练过程中加入了对抗样本,且总损失中包含针对对抗样本的损失,因此,根据总损失调节模型参数,可以提高模型鲁棒性。
S140、将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
本申请实施例中,在对模型参数进行调节,得到本轮训练模型后,将本轮训练模型作为下一轮的待训练模型,并从样本集合中继续取出下一轮样本集输入至待训练模型中,重复S110-S130中的操作,响应于满足结束训练条件,将最后一轮获取的本轮训练模型,作为最终的目标训练模型。其中,结束训练条件可以是计算出的总损失小于设定阈值,或者到达指定迭代次数,在此不做具体限制。
本申请实施例的技术方案,首先在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失,然后根据本轮训练损失,确定本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集,进一步的,使用本轮样本集和本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型,最终将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型,解决了现有技术中通过预先生成的对抗样本进行模型训练,不能对当前待训练模型最薄弱环节进行针对性训练的问题,通过在模型训练过程中生成针对当前模型的对抗样本,并使用原始样本和训练过程中生成的对抗样本共同进行模型训练,对抗样本更加有针对性,有效提高模型鲁棒性。
图2a是本申请实施例中的一种模型训练方法的示意图,在上述实施例的基础上进一步细化,提供了根据本轮训练损失,确定本轮扰动项的具体步骤,以及在获取目标训练模型之后的具体步骤。下面结合图2a对本申请实施例提供的一种模型训练方法进行说明,包括以下:
S210、在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失。
可选的,样本集合中的样本包括输入特征,以及与输入特征对应的标注数据。
本可选的实施例中,提供了样本集合中样本的组成,包括输入特征,以及与输入特征对应的标注数据,其中,标注数据是与输入特征对应的分类信息,示例性的,样本包含的输入特征是与用户行为对应的词向量,与输入特征对应的标注数据是用户年龄段,又示例性的,输入特征是图片,输出特征是与图片对应的分类结果。
可选的,样本中的输入特征为:与用户行为数据匹配的多项词向量,样本中的标注数据为用户画像。
本可选的实施例中,提供一种具体应用场景,样本中的输入特征是与用户行为数据匹配的词向量,对应的标注数据是用户画像。其中,用户行为数据用于表征用户在各应用程序中的行为,示例性的,用户行为数据包括用户下载或者浏览过的内容,对应的标注数据可以是用户画像中包含的多个用户特征,例如,用户性别、年龄段以及婚姻状态等。
S220、根据本轮训练损失,对输入特征求梯度,得到梯度值。
本申请实施例中,针对当前待训练模型生成对抗样本的流程如图2b所示,基于输入本轮样本集得到的本轮训练损失,对输入特征求梯度,得到梯度值,具体的,通过梯度上升法,确定对待训练模型影响最大的扰动项,以提高模型鲁棒性。
S230、对梯度值进行归一化处理,得到本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集。
本申请实施例中,对得到的梯度值进行归一化处理,得到本轮扰动项,最终将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集。相较于直接采用符号函数的的形式确定扰动项,对梯度值进行归一化处理得到的本轮扰动项中,每一维度都是不同的浮点数,数据松弛度更佳。
可选的,本轮扰动项的计算方式如下:
其中,η表示扰动项,x表示样本的输入特征,y表示与所述输入特征对应的标注数据,θ表示模型参数,∈表示添加扰动的最大强度,f(x;θ)表示待训练模型针对输入特征x的输出结果,L(f(x;θ),y)表示所述待训练模型的本轮训练损失,g表示基于所述本轮训练损失,对所述输入特征x的梯度值,‖g‖2表示梯度值的二范数。
本可选的实施例中,提供了一种具体的扰动计算方式,具体为,首先基于本轮训练损失对输入特征求梯度,得到梯度值g,然后用梯度值的二范数进行约束,得到数据松弛度更佳的本轮扰动项,其中,∈表示添加扰动的最大强度,用于控制对抗样本的偏离度。
S240、使用本轮样本集和本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型。
S250、将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
S260、获取待识别用户的目标用户行为数据,并提取与目标用户的行为数据匹配的多项目标词向量。
本申请实施例中,为了构建用户画像,首先获取待识别用户的目标用户行为数据,然后提取与目标用户的行为数据匹配的多项目标词向量,例如,用户行为数据可以是用户下载或者浏览内容等,包括用户下载的应用软件以及用户浏览的网站等,目标词向量可以是通过对用户行为数据进行分层词,然后转换为词向量获取的。
S270、将各目标词向量输入至目标训练模型中,并获取目标训练模型输出的,待识别用户的目标用户画像。
本申请实施例中,将得到的各目标词向量输入至目标训练模型中,获取目标训练模型输出的待识别用户的目标用户画像,其中,目标用户画像可以包括用户性别、年龄、文化程度以及婚姻状态等特征。由于目标训练模型是由在模型训练过程中生成的各轮对抗样本和各轮原始样本训练得到的,具有较强的鲁棒性,提高了目标用户画像的精确度。
本申请实施例的技术方案,首先在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失,然后根据本轮训练损失确定的本轮对抗样本集,将本轮样本集和本轮对抗样本集一起输入至待训练模型进行模型训练,得到本轮训练模型,并将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型,进一步的,将与目标用户行为数据匹配的多项目标词向量输入至目标训练模型中,获取目标训练模型输出的待识别用户的目标用户画像,通过在每一轮模型训练过程中,生成针对当前模型的对抗样本,使对抗样本更加有针对性,有效提高模型鲁棒性,进而提升了使用用户行为数据构建用户画像的准确度。
图3是本申请实施例中的一种模型训练方法的示意图,在上述实施例的基础上进一步细化,提供了使用本轮样本集和本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型的具体步骤。下面结合图3对本申请实施例提供的一种模型训练方法进行说明,包括以下:
S310、在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失。
S320、根据本轮训练损失,确定本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集。
S330、将本轮样本集和本轮对抗样本集共同输入至待训练模型中。
本申请实施例中,在得到针对当前待训练模型的本轮对抗样本集后,将本轮样本集和本轮对抗样本集共同输入到待训练模型中,进行模型训练,以提高待训练模型的泛化性能。其中,本轮对抗样本集中的标注数据与本轮样本集的标注数据相同。
S340、根据待训练模型的输出结果,计算总损失,总损失包括与本轮样本集对应的原始训练损失,以及与本轮对抗样本集对应的对抗训练损失。
本申请实施例中,在将本轮样本集和本轮对抗样本集共同输入到待训练模型后,根据待训练模型的输出结果,根据总损失函数,计算总损失,以根据总损失对待训练模型参数进行调节,其中,总损失包括本轮样本集对应的原始训练损失,以及与本轮对抗样本集对应的对抗训练损失,由于对抗训练样本集是根据当前待训练模型的本轮训练损失生成的,能够针对当前待训练模型最薄弱的环节,因此,采用本轮样本集和本轮对抗样本集一起进行模型训练,可以提高模型的鲁棒性,并且,可以通过调节原始训练损失和对抗训练损失的对应参数,来调节两种样本的贡献占比。
可选的,总损失的计算方式如下:
其中,Losstotal表示总损失,表示与所述本轮对抗样本集对应的对抗训练损失,L(f(x;θ),y)表示与所述本轮样本集对应的原始训练损失,α表示所述本轮对抗样本集贡献占比参数。
本申请实施例中,提供了一种计算总损失的方式,分别将原始训练损失和对抗训练损失与对应参数相乘后再相加,α取值越大,本轮对抗样本集贡献占比越大,也就是说训练得到的模型抵抗对抗样本攻击的能力相对较弱,具体的,可以根据实际需求灵活调整α取值。
S350、基于总损失,通过梯度下降法调节待训练模型的参数,以得到本轮训练模型。
本申请实施例中,计算得到总损失后,基于总损失,采用梯度下降法来调节待训练模型参数,得到本轮训练模型,以在下一轮训练时,生成针对调节参数后的本轮训练模型生成对抗样本,提升模型泛化性能。
S360、将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
本申请实施例的技术方案,在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失,然后根据本轮训练损失,确定本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集,进一步的,将本轮样本集和本轮对抗样本集共同输入至待训练模型中,并根据待训练模型的输出结果,计算总损失,最终基于总损失,通过梯度下降法调节待训练模型的参数,以得到本轮训练模型,通过在模型训练过程中生成有针对性的对抗样本,有效提升模型泛化性能。
图4是本申请实施例中的一种模型训练装置的结构示意图,该模型训练装置400,包括:损失计算模块410、对抗样本获取模块420、模型训练模块430和目标模型获取模块440。
损失计算模块410,用于在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;
对抗样本获取模块420,用于根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;
模型训练模块430,用于使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;
目标模型获取模块440,用于将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型。
本申请实施例的技术方案,首先在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失,然后根据本轮训练损失,确定本轮扰动项,并将本轮扰动项加入至本轮样本集中,得到本轮对抗样本集,进一步的,使用本轮样本集和本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型,最终将本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型,解决了现有技术中通过预先生成的对抗样本进行模型训练,不能对当前待训练模型最薄弱环节进行针对性训练的问题,通过在模型训练过程中生成针对当前模型的对抗样本,并使用原始样本和训练过程中生成的对抗样本共同进行模型训练,对抗样本更加有针对性,有效提高模型鲁棒性。
可选的,样本集合中的样本包括输入特征,以及与输入特征对应的标注数据;
对抗样本获取模块420,包括:
梯度值计算单元,用于根据所述本轮训练损失,对所述输入特征求梯度,得到梯度值;
扰动项计算单元,用于对所述梯度值进行归一化处理,得到所述本轮扰动项。
可选的,模型训练模块430,包括:
样本输入单元,用于将本轮样本集和所述本轮对抗样本集共同输入至待训练模型中;
总损失计算单元,用于根据待训练模型的输出结果,计算总损失,总损失包括与本轮样本集对应的原始训练损失,以及与本轮对抗样本集对应的对抗训练损失;
参数调节单元,用于基于总损失,通过梯度下降法调节待训练模型的参数,以得到本轮训练模型。
可选的,本轮扰动项的计算方式如下:
其中,η表示扰动项,x表示样本的输入特征,y表示与所述输入特征对应的标注数据,θ表示模型参数,∈表示添加扰动的最大强度,f(x;θ)表示待训练模型针对输入特征x的输出结果,L(f(x;θ),y)表示所述待训练模型的本轮训练损失,g表示基于所述本轮训练损失,对所述输入特征x的梯度值,‖g‖2表示梯度值的二范数。
可选的,总损失的计算方式如下;
其中,Losstotal表示总损失,表示与所述本轮对抗样本集对应的对抗训练损失,L(f(x;θ),y)表示与所述本轮样本集对应的原始训练损失,α表示所述本轮对抗样本集贡献占比参数。
可选的,样本中的输入特征为:与用户行为数据匹配的多项词向量,样本中的标注数据为用户画像。
可选的,模型训练装置400,还包括:
词向量获取模块,用于在获取目标训练模型之后,获取待识别用户的目标用户行为数据,并提取与所述目标用户的行为数据匹配的多项目标词向量;
用户画像输出模块,用于将各所述目标词向量输入至所述目标训练模型中,并获取所述目标训练模型输出的,所述待识别用户的目标用户画像。
本申请实施例所提供的模型训练装置400可执行本申请任意实施例所提供的模型训练方法,具备执行方法相应的功能模块和有益效果。
根据本申请的实施例,本申请还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图5示出了可以用来实施本申请的实施例的示例电子设备500的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图5所示,设备500包括计算单元501,其可以根据存储在只读存储器(ROM)502中的计算机程序或者从存储单元508加载到随机访问存储器(RAM)503中的计算机程序,来执行各种适当的动作和处理。在RAM 503中,还可存储设备500操作所需的各种程序和数据。计算单元501、ROM 502以及RAM 503通过总线504彼此相连。输入/输出(I/O)接口505也连接至总线504。
设备500中的多个部件连接至I/O接口505,包括:输入单元506,例如键盘、鼠标等;输出单元507,例如各种类型的显示器、扬声器等;存储单元508,例如磁盘、光盘等;以及通信单元509,例如网卡、调制解调器、无线通信收发机等。通信单元509允许设备500通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元501可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元501的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元501执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元508。在一些实施例中,计算机程序的部分或者全部可以经由ROM 502和/或通信单元509而被载入和/或安装到设备500上。当计算机程序加载到RAM 503并由计算单元501执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元501可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本申请的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本申请的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。
Claims (12)
1.一种模型训练方法,包括:
在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;其中,所述样本集合中的样本包括输入特征,以及与所述输入特征对应的标注数据;所述样本中的输入特征为:与用户行为数据匹配的多项词向量,所述样本中的标注数据为用户画像;
根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;
使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;
将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型;
其中,所述本轮扰动项的计算方式如下:
其中,/>表示扰动项,x表示样本的输入特征,y表示与所述输入特征对应的标注数据,/>表示模型参数,/>表示添加扰动的最大强度,/>表示待训练模型针对输入特征x的输出结果,/>表示所述待训练模型的本轮训练损失,/>表示基于所述本轮训练损失,对所述输入特征x的梯度值,/>表示梯度值的二范数。
2.根据权利要求1所述的方法,根据所述本轮训练损失,确定本轮扰动项,包括:
根据所述本轮训练损失,对所述输入特征求梯度,得到梯度值;
对所述梯度值进行归一化处理,得到所述本轮扰动项。
3.根据权利要求1所述的方法,其中,使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型,包括:
将所述本轮样本集和所述本轮对抗样本集共同输入至所述待训练模型中;
根据所述待训练模型的输出结果,计算总损失,所述总损失包括与所述本轮样本集对应的原始训练损失,以及与所述本轮对抗样本集对应的对抗训练损失;
基于所述总损失,通过梯度下降法调节所述待训练模型的参数,以得到所述本轮训练模型。
4.根据权利要求3所述的方法,其中,所述总损失的计算方式如下:
其中,/>表示总损失,/>表示与所述本轮对抗样本集对应的对抗训练损失,/>表示与所述本轮样本集对应的原始训练损失,/>表示所述本轮对抗样本集贡献占比参数。
5.根据权利要求1所述的方法,在获取目标训练模型之后,还包括:
获取待识别用户的目标用户行为数据,并提取与所述目标用户的行为数据匹配的多项目标词向量;
将各所述目标词向量输入至所述目标训练模型中,并获取所述目标训练模型输出的,所述待识别用户的目标用户画像。
6.一种模型训练装置,包括:
损失计算模块,用于在样本集合中获取本轮样本集输入至待训练模型中,并根据待训练模型的输出结果,计算本轮训练损失;其中,所述样本集合中的样本包括输入特征,以及与所述输入特征对应的标注数据;所述样本中的输入特征为:与用户行为数据匹配的多项词向量,所述样本中的标注数据为用户画像;
对抗样本获取模块,用于根据所述本轮训练损失,确定本轮扰动项,并将所述本轮扰动项加入至所述本轮样本集中,得到本轮对抗样本集;
模型训练模块,用于使用所述本轮样本集和所述本轮对抗样本集共同对待训练模型进行训练,得到本轮训练模型;
目标模型获取模块,用于将所述本轮训练模型确定为新的待训练模型后,返回在样本集合中获取本轮样本集输入至待训练模型中的操作,响应于满足结束训练条件,获取目标训练模型;
其中,所述本轮扰动项的计算方式如下:
其中,/>表示扰动项,x表示样本的输入特征,y表示与所述输入特征对应的标注数据,/>表示模型参数,/>表示添加扰动的最大强度,/>表示待训练模型针对输入特征x的输出结果,/>表示所述待训练模型的本轮训练损失,/>表示基于所述本轮训练损失,对所述输入特征x的梯度值,/>表示梯度值的二范数。
7.根据权利要求6所述的装置,
所述对抗样本获取模块,包括:
梯度值计算单元,用于根据所述本轮训练损失,对所述输入特征求梯度,得到梯度值;
扰动项计算单元,用于对所述梯度值进行归一化处理,得到所述本轮扰动项。
8.根据权利要求6所述的装置,其中,所述模型训练模块,包括:
样本输入单元,用于将所述本轮样本集和所述本轮对抗样本集共同输入至所述待训练模型中;
总损失计算单元,用于根据所述待训练模型的输出结果,计算总损失,所述总损失包括与所述本轮样本集对应的原始训练损失,以及与所述本轮对抗样本集对应的对抗训练损失;
参数调节单元,用于基于所述总损失,通过梯度下降法调节所述待训练模型的参数,以得到所述本轮训练模型。
9.根据权利要求8所述的装置,其中,所述总损失的计算方式如下;
其中,/>表示总损失,/>表示与所述本轮对抗样本集对应的对抗训练损失,/>表示与所述本轮样本集对应的原始训练损失,/>表示所述本轮对抗样本集贡献占比参数。
10.根据权利要求6所述的装置,还包括:
词向量获取模块,用于在获取目标训练模型之后,获取待识别用户的目标用户行为数据,并提取与所述目标用户的行为数据匹配的多项目标词向量;
用户画像输出模块,用于将各所述目标词向量输入至所述目标训练模型中,并获取所述目标训练模型输出的,所述待识别用户的目标用户画像。
11. 一种电子设备,其中,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-5中任一项所述的模型训练方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使计算机执行权利要求1-5中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011563834.5A CN112580732B (zh) | 2020-12-25 | 2020-12-25 | 模型训练方法、装置、设备、存储介质和程序产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011563834.5A CN112580732B (zh) | 2020-12-25 | 2020-12-25 | 模型训练方法、装置、设备、存储介质和程序产品 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112580732A CN112580732A (zh) | 2021-03-30 |
CN112580732B true CN112580732B (zh) | 2024-02-23 |
Family
ID=75140615
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011563834.5A Active CN112580732B (zh) | 2020-12-25 | 2020-12-25 | 模型训练方法、装置、设备、存储介质和程序产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112580732B (zh) |
Families Citing this family (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113222480B (zh) * | 2021-06-11 | 2023-05-12 | 支付宝(杭州)信息技术有限公司 | 对抗样本生成模型的训练方法及装置 |
CN114139781A (zh) * | 2021-11-17 | 2022-03-04 | 国网湖北省电力有限公司经济技术研究院 | 一种电力系统的运行趋势预测方法及系统 |
CN114821823B (zh) * | 2022-04-12 | 2023-07-25 | 马上消费金融股份有限公司 | 图像处理、人脸防伪模型的训练及活体检测方法和装置 |
CN117540791B (zh) * | 2024-01-03 | 2024-04-05 | 支付宝(杭州)信息技术有限公司 | 一种对抗训练的方法及装置 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109036389A (zh) * | 2018-08-28 | 2018-12-18 | 出门问问信息科技有限公司 | 一种对抗样本的生成方法及装置 |
CN109272031A (zh) * | 2018-09-05 | 2019-01-25 | 宽凳(北京)科技有限公司 | 一种训练样本生成方法及装置、设备、介质 |
CN109460814A (zh) * | 2018-09-28 | 2019-03-12 | 浙江工业大学 | 一种具有防御对抗样本攻击功能的深度学习分类方法 |
CN110647992A (zh) * | 2019-09-19 | 2020-01-03 | 腾讯云计算(北京)有限责任公司 | 卷积神经网络的训练方法、图像识别方法及其对应的装置 |
CN111523597A (zh) * | 2020-04-23 | 2020-08-11 | 北京百度网讯科技有限公司 | 目标识别模型训练方法、装置、设备以及存储介质 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180024508A1 (en) * | 2016-07-25 | 2018-01-25 | General Electric Company | System modeling, control and optimization |
US10522036B2 (en) * | 2018-03-05 | 2019-12-31 | Nec Corporation | Method for robust control of a machine learning system and robust control system |
-
2020
- 2020-12-25 CN CN202011563834.5A patent/CN112580732B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109036389A (zh) * | 2018-08-28 | 2018-12-18 | 出门问问信息科技有限公司 | 一种对抗样本的生成方法及装置 |
CN109272031A (zh) * | 2018-09-05 | 2019-01-25 | 宽凳(北京)科技有限公司 | 一种训练样本生成方法及装置、设备、介质 |
CN109460814A (zh) * | 2018-09-28 | 2019-03-12 | 浙江工业大学 | 一种具有防御对抗样本攻击功能的深度学习分类方法 |
CN110647992A (zh) * | 2019-09-19 | 2020-01-03 | 腾讯云计算(北京)有限责任公司 | 卷积神经网络的训练方法、图像识别方法及其对应的装置 |
CN111523597A (zh) * | 2020-04-23 | 2020-08-11 | 北京百度网讯科技有限公司 | 目标识别模型训练方法、装置、设备以及存储介质 |
Non-Patent Citations (3)
Title |
---|
"An Enhanced Anti-Disturbance Control Approach for Systems Subject to Multiple Disturbances";Lei Guo 等;《IEEE》;全文 * |
基于对抗训练的文本表示和分类算法;张晓辉;于双元;王全新;徐保民;;计算机科学(S1);全文 * |
基于生成式对抗网络的通用性对抗扰动生成方法;刘恒;吴德鑫;徐剑;;信息网络安全(05);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN112580732A (zh) | 2021-03-30 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112580732B (zh) | 模型训练方法、装置、设备、存储介质和程序产品 | |
CN112560996B (zh) | 用户画像识别模型训练方法、设备、可读存储介质及产品 | |
CN113360700B (zh) | 图文检索模型的训练和图文检索方法、装置、设备和介质 | |
CN115358392B (zh) | 深度学习网络的训练方法、文本检测方法及装置 | |
CN112907552A (zh) | 图像处理模型的鲁棒性检测方法、设备及程序产品 | |
CN114020950B (zh) | 图像检索模型的训练方法、装置、设备以及存储介质 | |
CN115147680B (zh) | 目标检测模型的预训练方法、装置以及设备 | |
CN114881129A (zh) | 一种模型训练方法、装置、电子设备及存储介质 | |
CN114564971B (zh) | 深度学习模型的训练方法、文本数据处理方法和装置 | |
CN113642710B (zh) | 一种网络模型的量化方法、装置、设备和存储介质 | |
CN113240177B (zh) | 训练预测模型的方法、预测方法、装置、电子设备及介质 | |
CN114494747A (zh) | 模型的训练方法、图像处理方法、装置、电子设备及介质 | |
CN114495101A (zh) | 文本检测方法、文本检测网络的训练方法及装置 | |
CN113052063A (zh) | 置信度阈值选择方法、装置、设备以及存储介质 | |
CN115690443B (zh) | 特征提取模型训练方法、图像分类方法及相关装置 | |
CN112784967B (zh) | 信息处理方法、装置以及电子设备 | |
CN113792849B (zh) | 字符生成模型的训练方法、字符生成方法、装置和设备 | |
CN113361621B (zh) | 用于训练模型的方法和装置 | |
CN113627526B (zh) | 车辆标识的识别方法、装置、电子设备和介质 | |
CN114817476A (zh) | 语言模型的训练方法、装置、电子设备和存储介质 | |
CN115641481A (zh) | 用于训练图像处理模型和图像处理的方法、装置 | |
CN114330592B (zh) | 模型生成方法、装置、电子设备及计算机存储介质 | |
CN116151215B (zh) | 文本处理方法、深度学习模型训练方法、装置以及设备 | |
CN115131709B (zh) | 视频类别预测方法、视频类别预测模型的训练方法及装置 | |
CN116188875B (zh) | 图像分类方法、装置、电子设备、介质和产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |