CN111950579A - 分类模型的训练方法和训练装置 - Google Patents
分类模型的训练方法和训练装置 Download PDFInfo
- Publication number
- CN111950579A CN111950579A CN201910414262.5A CN201910414262A CN111950579A CN 111950579 A CN111950579 A CN 111950579A CN 201910414262 A CN201910414262 A CN 201910414262A CN 111950579 A CN111950579 A CN 111950579A
- Authority
- CN
- China
- Prior art keywords
- classification model
- output
- training
- loss
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V30/00—Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
- G06V30/10—Character recognition
- G06V30/19—Recognition using electronic means
- G06V30/192—Recognition using electronic means using simultaneous comparisons or correlations of the image signals with a plurality of references
- G06V30/194—References adjustable by an adaptive method, e.g. learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本公开提出一种分类模型的训练方法和训练装置,涉及机器学习领域。本公开单独训练分类模型,相对于同时训练相关的两个模型,模型的稳定性更好;并且基于生成样本数据设置用于抑制所有输出类别激活的损失函数,在真实类别的基础上不需要额外增加虚假类别,有利于降低训练的复杂度。
Description
技术领域
本公开涉及机器学习领域,特别涉及一种分类模型的训练方法和训练装置。
背景技术
基于生成式对抗网络(Generative Adversarial Networks,GAN)的半监督分类方法:在训练阶段,同时训练生成式对抗网络的生成模型和分类模型。一般来说,训练分类模型所需要的迭代次数比训练生成模型所需要的迭代次数少,这会使得生成式对抗网络不太稳定。分类模型在训练时需要增加一个额外的虚假类别,专门用于识别生成模型生成的“虚假数据”,但该虚假类别在测试阶段不会被使用,这在一定程度上增加了训练的复杂性。此外,生成模型有时会生成足够真实的“虚假数据”,这样的训练数据对于训练没有帮助。
发明内容
本公开可以单独训练分类模型,相对于同时训练相关的两个模型,模型的稳定性更好;并且基于生成样本数据设置用于抑制所有输出类别激活的损失函数,在真实类别的基础上不需要额外增加虚假类别,有利于降低训练的复杂度。此外,通过在特征层添加噪声的方法,在一定程度上避免生成模型生成过于真实的“虚假数据”,有利于提升训练数据的有效性和提升训练效果。
根据本公开的一方面,提出一种分类模型的训练方法,包括:
将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;
将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;
根据所述第一损失和所述第二损失判断所述分类模型是否收敛,在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新所述分类模型的参数,并对所述分类模型继续进行训练。
在一些实施例中,第二损失函数根据所述第二组输出值中每个输出类别上的输出值与数值小的预设值之间的差值信息确定。
在一些实施例中,第二损失函数的公式表示为:
在一些实施例中,T小于或等于log0.0001。
在一些实施例中,所述生成样本数据通过生成模型生成,其中,所述生成模型的特征层被配置为添加噪声。
在一些实施例中,还包括:利用收敛的分类模型对输入的图像数据进行分类。
在一些实施例中,所述分类模型为图像分类模型;所述真实样本数据为真实事物的图像数据,所述真实样本数据的标签数据为标注的真实事物的种类,所述第一组输出值为真实事物的图像数据在各个种类上的概率;所述生成样本数据为对真实事物的图像数据添加噪声得到的虚假事物的图像数据,所述第二组输出值为虚假事物的图像数据在各个种类上的概率。
根据本公开的另一方面,提出一种分类模型的训练装置,包括:
第一训练单元,被配置为将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;
第二训练单元,被配置为将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;
判断单元,被配置为根据所述第一损失和所述第二损失判断所述分类模型是否收敛;
模型参数更新单元,被配置为在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新所述分类模型的参数,以便继续执行所述第一训练单元、所述第二训练单元、所述判断单元和所述模型参数更新单元,对所述分类模型继续进行训练。
根据本公开的再一方面,提出一种分类模型的训练装置,包括:
存储器;以及
耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行前述任一个实施例的分类模型的训练方法。
根据本公开的又一方面,提出一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现前述任一个实施例的分类模型的训练方法的步骤。
附图说明
下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍。根据下面参照附图的详细描述,可以更加清楚地理解本公开,
显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本公开分类模型的训练方法一些实施例的流程示意图。
图2示出分类模型训练过程的信息流转示意图。
图3为本公开分类模型的训练装置一些实施例的结构示意图。
图4为本公开分类模型的训练装置一些实施例的结构示意图。
具体实施方式
下面将结合本公开实施例中的附图,对本公开实施例中的技术方案进行清楚、完整地描述。
本公开的“第一”“第二”等描述,用来区分不同的对象,并不用来表示大小或时序等含义。例如,第一损失函数和第二损失函数表示两个损失函数。
本公开中的分类模型、生成模型等均为机器学习模型。本公开用开对分类模型进行训练,对分类模型具体为何种模型不做限制。训练用的真实样本数据及其标签数据为标记数据,训练用的生成样本数据为无标记数据,因此,本公开涉及一种半监督的分类方案。
图1为本公开分类模型的训练方法一些实施例的流程示意图。
如图1所示,该实施例的训练方法包括:
步骤11,将真实样本数据和该真实样本数据的标签数据输入待训练的分类模型,得到该分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算该第一损失函数在该分类模型当前参数下的第一梯度信息。
其中,第一损失函数针对真实训练数据设置,例如为交叉熵损失函数、指数损失函数、铰链损失函数等。
此外,在定义好损失函数和模型参数的情况下,损失和梯度信息的具体计算可以参考现有技术。
步骤12,将生成样本数据输入该分类模型,得到该分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算该第二损失函数在该分类模型当前参数下的第二梯度信息。
可理解的,步骤11和12的执行不分先后顺序。
生成样本数据相对于真实样本数据来说也称为“虚假样本数据”。在一些实施例中,训练用的生成样本数据可以通过生成模型生成,其中,该生成模型的特征层被配置为添加噪声,使得生成模型生成接近真实样本数据但又不会过于真实致使模型难以分辨的“虚假样本数据”,有利于提升训练数据的有效性和提升训练效果。
其中,第二损失函数根据该第二组输出值中每个输出类别上的输出值与数值小的预设值之间的差值信息确定。
在一些实施例中,第二损失函数的公式表示例如为:
其中,c表示输入输出类别的数量,i表示其中某个输出类别,表示该分类模型在输出类别i上的输出值,T表示数值小的预设值,,例如,T小于或等于log0.0001,max表示取最大值的运算,Lss,m表示多分类m下的第二损失。
如果分类模型在输出类别i上的输出值很大,说明分类模型会将该输入样本识别为输出类别i,即该输出类别i被激活。然而,通过第二损失函数,使得一旦超过T就受到惩罚,进而达到针对生成样本数据抑制各个输出类别i被激活的目的。
则,在二分类b下的第二损失Lss,b可以表示为:
步骤13,根据该第一损失和该第二损失判断该分类模型是否收敛。
例如,将第一损失和第二损失叠加起来得到总损失,如果总损失的变化均足够小,则判定分类模型收敛。其中,总损失的变化根据迭代训练中相邻两次训练的总损失之间的差值确定。
步骤14a,在该分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新该分类模型的参数,并对该分类模型继续进行训练,即继续从步骤11开始执行本方法。
其中,将第一梯度信息和第二梯度信息叠加起来得到梯度叠加信息。
其中,按照梯度下降的方法更新该分类模型的参数例如为:分类模型更新前的参数减去学习率与梯度叠加信息的乘积得到分类模型更新后的参数。
步骤14b,在该分类模型收敛的情况下,分类模型的训练结束。
此外,在一些应用中,利用收敛的分类模型可以对输入的图像数据进行分类。
上述实施例,单独训练分类模型,相对于同时训练相关的两个模型,模型的稳定性更好;并且基于生成样本数据设置用于抑制所有输出类别激活的损失函数,在真实类别的基础上不需要额外增加虚假类别,有利于降低训练的复杂度。此外,通过在特征层添加噪声的方法,在一定程度上避免生成模型生成过于真实的“虚假数据”,有利于提升训练数据的有效性和提升训练效果。
针对上述实施例描述的训练方法,图2示出分类模型训练过程的信息流转示意图。其中的箭头方向表示信息的流转方向。
在一些实施例中,分类模型为图像分类模型;真实样本数据为真实事物的图像数据,真实样本数据的标签数据为标注的真实事物的种类,第一组输出值为真实事物的图像数据在各个种类上的概率;生成样本数据为对真实事物的图像数据添加噪声得到的虚假事物的图像数据,第二组输出值为虚假事物的图像数据在各个种类上的概率。
下面以服饰图像的分类为例,具体说明本公开的方案。
模型训练阶段:
将真实的服饰图像和标注的服饰图像的种类输入待训练的图像分类模型,输出真实的服饰图像在各个种类上的概率(即第一组输出值),基于交叉熵损失函数和真实的服饰图像在各个种类上的概率计算第一损失,并计算交叉熵损失函数在图像分类模型当前参数下的第一梯度信息;
对真实的服饰图像添加噪声得到“虚假的”服饰图像,将“虚假的”服饰图像输入图像分类模型,输出“虚假的”服饰图像在各个种类上的概率(即第二组输出值),基于前述的用于抑制所有输出类别激活的第二损失函数Lss,m和第二组输出值计算第二损失,并计算第二损失函数Lss,m在图像分类模型当前参数下的第二梯度信息;
判断第一损失和第二损失叠加起来的总损失的变化是否足够小,以确定图像分类模型是否收敛,在图像分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新图像分类模型的参数,并对图像分类模型继续进行训练,直至图像分类模型收敛。从而,得到能够对服饰图像进行分类的图像分类模型。
该图像分类模型是单独训练得到的,相对于同时训练相关的生成模型和分类模型,图像分类模型的稳定性更好;并且,在训练过程中,仅涉及真实图像的种类,没有额外增加的虚假图像种类,有利于降低训练的复杂度。
模型使用阶段:
将待分类的服饰图像输入上述训练得到的收敛的分类模型中,输出待分类的服饰图像在各个种类上的概率,其中,概率最大的种类被判定为该服饰图像的种类。
图3为本公开分类模型的训练装置一些实施例的结构示意图。
如图3所示,该实施例的训练装置30包括:
第一训练单元31,被配置为将真实样本数据和该真实样本数据的标签数据输入待训练的分类模型,得到该分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算该第一损失函数在该分类模型当前参数下的第一梯度信息;
第二训练单元32,被配置为将生成样本数据输入该分类模型,得到该分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算该第二损失函数在该分类模型当前参数下的第二梯度信息;
判断单元33,被配置为根据该第一损失和该第二损失判断该分类模型是否收敛;
模型参数更新单元34,被配置为在该分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新该分类模型的参数,以便继续执行该第一训练单元、该第二训练单元、该判断单元和该模型参数更新单元,对该分类模型继续进行训练。
第二训练单元32涉及的第二损失函数根据第二组输出值中每个输出类别上的输出值与数值小的预设值之间的差值信息确定。例如,第二损失函数的公式表示为:
第二训练单元32涉及的生成样本数据通过生成模型生成,其中,该生成模型的特征层被配置为添加噪声。
图4为本公开分类模型的训练装置一些实施例的结构示意图。
如图4所示,该实施例的训练装置40包括:
存储器41;以及耦接至该存储器的处理器42,该处理器42被配置为基于存储在该存储器中的指令,执行前述任一个实施例的分类模型的训练方法。
其中,存储器41例如可以包括系统存储器、固定非易失性存储介质等。系统存储器例如存储有操作系统、应用程序、引导装载程序(Boot Loader)以及其他程序等。
本领域内的技术人员应当明白,本公开的实施例可提供为方法、系统、或计算机程序产品。因此,本公开可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本公开可采用在一个或多个其中包含有计算机可用程序代码的计算机可用非瞬时性存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
以上所述仅为本公开的较佳实施例,并不用以限制本公开,凡在本公开的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本公开的保护范围之内。
Claims (12)
1.一种分类模型的训练方法,包括:
将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;
将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;
根据所述第一损失和所述第二损失判断所述分类模型是否收敛,在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新所述分类模型的参数,并对所述分类模型继续进行训练。
2.如权利要求1所述的方法,其中,第二损失函数根据所述第二组输出值中每个输出类别上的输出值与数值小的预设值之间的差值信息确定。
4.如权利要求3所述的方法,其中,T小于或等于log0.0001。
5.如权利要求1所述的方法,其中,所述生成样本数据通过生成模型生成,其中,所述生成模型的特征层被配置为添加噪声。
6.如权利要求1所述的方法,还包括:
利用收敛的分类模型对输入的图像数据进行分类。
7.如权利要求1所述的方法,其中,
所述分类模型为图像分类模型;
所述真实样本数据为真实事物的图像数据,所述真实样本数据的标签数据为标注的真实事物的种类,所述第一组输出值为真实事物的图像数据在各个种类上的概率;
所述生成样本数据为对真实事物的图像数据添加噪声得到的虚假事物的图像数据,所述第二组输出值为虚假事物的图像数据在各个种类上的概率。
8.一种分类模型的训练装置,包括:
第一训练单元,被配置为将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;
第二训练单元,被配置为将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;
判断单元,被配置为根据所述第一损失和所述第二损失判断所述分类模型是否收敛;
模型参数更新单元,被配置为在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新所述分类模型的参数,以便继续执行所述第一训练单元、所述第二训练单元、所述判断单元和所述模型参数更新单元,对所述分类模型继续进行训练。
10.如权利要求8所述的装置,其中,所述生成样本数据通过生成模型生成,其中,所述生成模型的特征层被配置为添加噪声。
11.一种分类模型的训练装置,包括:
存储器;以及
耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行权利要求1-7中任一项所述的分类模型的训练方法。
12.一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现权利要求1-7中任一项所述的分类模型的训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910414262.5A CN111950579A (zh) | 2019-05-17 | 2019-05-17 | 分类模型的训练方法和训练装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910414262.5A CN111950579A (zh) | 2019-05-17 | 2019-05-17 | 分类模型的训练方法和训练装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN111950579A true CN111950579A (zh) | 2020-11-17 |
Family
ID=73336130
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910414262.5A Pending CN111950579A (zh) | 2019-05-17 | 2019-05-17 | 分类模型的训练方法和训练装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111950579A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112288032A (zh) * | 2020-11-18 | 2021-01-29 | 上海依图网络科技有限公司 | 一种基于生成对抗网络的量化模型训练的方法及装置 |
CN112651458A (zh) * | 2020-12-31 | 2021-04-13 | 深圳云天励飞技术股份有限公司 | 分类模型的训练方法、装置、电子设备及存储介质 |
WO2022188327A1 (zh) * | 2021-03-09 | 2022-09-15 | 北京百度网讯科技有限公司 | 定位图获取模型的训练方法和装置 |
-
2019
- 2019-05-17 CN CN201910414262.5A patent/CN111950579A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112288032A (zh) * | 2020-11-18 | 2021-01-29 | 上海依图网络科技有限公司 | 一种基于生成对抗网络的量化模型训练的方法及装置 |
CN112651458A (zh) * | 2020-12-31 | 2021-04-13 | 深圳云天励飞技术股份有限公司 | 分类模型的训练方法、装置、电子设备及存储介质 |
CN112651458B (zh) * | 2020-12-31 | 2024-04-02 | 深圳云天励飞技术股份有限公司 | 分类模型的训练方法、装置、电子设备及存储介质 |
WO2022188327A1 (zh) * | 2021-03-09 | 2022-09-15 | 北京百度网讯科技有限公司 | 定位图获取模型的训练方法和装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP6781415B2 (ja) | ニューラルネットワーク学習装置、方法、プログラム、およびパターン認識装置 | |
US11790237B2 (en) | Methods and apparatus to defend against adversarial machine learning | |
CN111950579A (zh) | 分类模型的训练方法和训练装置 | |
CN110362814B (zh) | 一种基于改进损失函数的命名实体识别方法及装置 | |
KR102074909B1 (ko) | 소프트웨어 취약점 분류 장치 및 방법 | |
CN109766259B (zh) | 一种基于复合蜕变关系的分类器测试方法及系统 | |
CN112949693A (zh) | 图像分类模型的训练方法、图像分类方法、装置和设备 | |
CN111586071A (zh) | 一种基于循环神经网络模型的加密攻击检测方法及装置 | |
JP6973197B2 (ja) | データセット検証装置、データセット検証方法、およびデータセット検証プログラム | |
US20170039484A1 (en) | Generating negative classifier data based on positive classifier data | |
CN113726545A (zh) | 基于知识增强生成对抗网络的网络流量生成方法及装置 | |
KR102152081B1 (ko) | 딥러닝 기반의 가치 평가 방법 및 그 장치 | |
CN110889316B (zh) | 一种目标对象识别方法、装置及存储介质 | |
CN117134958A (zh) | 用于网络技术服务的信息处理方法及系统 | |
CN114445656A (zh) | 多标签模型处理方法、装置、电子设备及存储介质 | |
CN114139636B (zh) | 异常作业处理方法及装置 | |
CN115082761A (zh) | 模型产生装置及方法 | |
JP5824429B2 (ja) | スパムアカウントスコア算出装置、スパムアカウントスコア算出方法、及びプログラム | |
KR20180082680A (ko) | 분류기를 학습시키는 방법 및 이를 이용한 예측 분류 장치 | |
CN116935102B (zh) | 一种轻量化模型训练方法、装置、设备和介质 | |
CN115393659B (zh) | 基于多级决策树的个性化分类流程优化方法和装置 | |
JP7118938B2 (ja) | 分類装置、学習装置、方法及びプログラム | |
US20220253691A1 (en) | Execution behavior analysis text-based ensemble malware detector | |
WO2021111831A1 (ja) | 情報処理方法、情報処理システム及び情報処理装置 | |
CN117454187B (zh) | 一种基于频域限制目标攻击的集成模型训练方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |