CN110689048A - 用于样本分类的神经网络模型的训练方法和装置 - Google Patents

用于样本分类的神经网络模型的训练方法和装置 Download PDF

Info

Publication number
CN110689048A
CN110689048A CN201910822201.2A CN201910822201A CN110689048A CN 110689048 A CN110689048 A CN 110689048A CN 201910822201 A CN201910822201 A CN 201910822201A CN 110689048 A CN110689048 A CN 110689048A
Authority
CN
China
Prior art keywords
training
sample
model
prediction loss
feature extraction
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201910822201.2A
Other languages
English (en)
Inventor
马良庄
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Advanced New Technologies Co Ltd
Advantageous New Technologies Co Ltd
Original Assignee
Alibaba Group Holding Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Alibaba Group Holding Ltd filed Critical Alibaba Group Holding Ltd
Priority to CN201910822201.2A priority Critical patent/CN110689048A/zh
Publication of CN110689048A publication Critical patent/CN110689048A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本说明书实施例提供一种用于样本分类的神经网络模型的训练方法和装置。方法包括:获取训练样本集中的训练样本,训练样本具有样本标识和样本类别标签;将训练样本输入特征提取模型,得到特征表示向量;将特征表示向量输入鉴别器模型,得到识别标识;根据识别标识和样本标识确定第一预测损失,以最小化第一预测损失为目标,对鉴别器模型和特征提取模型进行第一训练;将特征表示向量输入分类器模型,得到识别类别;根据识别类别和样本类别标签确定第二预测损失,并根据与第一预测损失负相关和与第二预测损失正相关确定第三预测损失,以最小化第三预测损失为目标,对分类器模型和特征提取模型进行第二训练。能够提高模型的泛化性。

Description

用于样本分类的神经网络模型的训练方法和装置
技术领域
本说明书一个或多个实施例涉及计算机领域,尤其涉及用于样本分类的神经网络模型的训练方法和装置。
背景技术
在深度学习里面,通常会设置较大的神经网络模型,去拟合训练数据,以使该神经网络模型用于样本分类。但是,由于参数太多,神经网络模型复杂度过高,绝大多数情况下,神经网络模型都会过拟合,训练得到的神经网络模型泛化性并不会最好。
因此,希望能有改进的方案,能够在对神经网络模型的训练过程中,防止神经网络模型过拟合,提高训练得到的神经网络模型的泛化性。
发明内容
本说明书一个或多个实施例描述了一种用于样本分类的神经网络模型的训练方法和装置,能够在对神经网络模型的训练过程中,防止神经网络模型过拟合,提高训练得到的神经网络模型的泛化性。
第一方面,提供了一种用于样本分类的神经网络模型的训练方法,方法包括:
获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;
将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;
将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;
根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;
将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;
根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。
在一种可能的实施方式中,所述第一训练和所述第二训练交替重复进行。
进一步地,当交替重复的次数达到第一预设阈值时,终止训练。
进一步地,每执行完一轮所述第一训练和所述第二训练,计算预设指标的指标值;当所述指标值大于第二预设阈值时,终止训练。
在一种可能的实施方式中,所述根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,包括:
将所述第一预测损失和所述第二预测损失进行加权求和得到所述第三预测损失;其中,所述第一预测损失对应的权重因子为负数,所述第二预测损失对应的权重因子为正数。
在一种可能的实施方式中,所述方法还包括:
将待识别样本输入训练后的所述特征提取模型,通过所述特征提取模型输出所述待识别样本的特征表示向量;
将所述待识别样本的特征表示向量输入训练后的所述分类器模型,通过所述分类器模型输出所述待识别样本的识别类别。
在一种可能的实施方式中,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。
第二方面,提供了一种用于样本分类的神经网络模型的训练装置,装置包括:
获取单元,用于获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;
特征提取单元,用于将所述获取单元获取的所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;
鉴别单元,用于将所述特征提取单元得到的所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;
第一训练单元,用于根据所述鉴别单元得到的所述训练样本的识别标识和所述获取单元获取的所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;
分类单元,用于将所述特征提取单元得到的所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;
第二训练单元,用于根据所述分类单元得到的所述训练样本的识别类别和所述获取单元获取的所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。
第三方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行第一方面的方法。
第四方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现第一方面的方法。
通过本说明书实施例提供的方法和装置,首先获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;然后将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;接着将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;再根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;再然后将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;最后根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。由上可见,本说明书实施例,通过对抗的方法,能够在对神经网络模型的训练过程中,防止神经网络模型过拟合,提高训练得到的神经网络模型的泛化性。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1为本说明书披露的一个实施例的实施场景示意图;
图2示出根据一个实施例的用于样本分类的神经网络模型的训练方法流程图;
图3示出根据一个实施例的对抗训练网络组成结构图;
图4示出根据一个实施例的用于样本分类的神经网络模型的训练装置的示意性框图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
图1为本说明书披露的一个实施例的实施场景示意图。该实施场景涉及用于样本分类的神经网络模型的训练。具体地,可以基于训练样本对神经网络模型进行训练,本说明书实施例中,训练样本具有样本标识,以及预先标注的样本类别标签。可以理解的是,上述神经网络模型可以适用各种场景下的样本分类,例如,对物品进行分类,对用户进行分类等。在一个示例中,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。参照图1,训练样本集包括的各训练样本分别对应一个用户,用户1、用户2、用户3、用户4和用户5分别代表不同用户的标识,其中,用户1、用户4和用户5对应的样本类别标签为类别A,用户2和用户3对应的样本类别标签为类别B。
本说明书实施例,基于上述训练样本,通过对抗的方法对用于样本分类的神经网络模型进行训练,以使该神经网络模型提取出的特征不带有个例识别的信息,防止过拟合,以提高泛化能力。
过拟合(overfitting),也称过度拟合,是指在拟合一个统计模型时,使用过多参数。它的直观表现是算法在训练集上表现好,但在测试集上表现不好,泛化性能差。过拟合是在模型参数拟合过程中由于训练数据包含抽样误差,在训练时复杂的模型将抽样误差也进行了拟合导致的。所谓抽样误差,是指抽样得到的样本集和整体数据集之间的偏差。引起过拟合的可能原因有:模型本身过于复杂,以至于拟合了训练样本集中的噪声。本说明书实施例,针对引起过拟合的原因,对神经网络模型的训练过程进行了改进,利用了两个模型对抗训练的思想。
本说明书实施例,利用了两个模型对抗训练的思想来对用于样本分类的神经网络模型进行训练,从而使该神经网络模型提取出的特征不带有个例识别的信息,防止过拟合,以提高泛化能力。
图2示出根据一个实施例的用于样本分类的神经网络模型的训练方法流程图,该方法可以基于图1所示的实施场景。如图2所示,该实施例中用于样本分类的神经网络模型的训练方法包括以下步骤:步骤21,获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;步骤22,将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;步骤23,将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;步骤24,根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;步骤25,将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;步骤26,根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。下面描述以上各个步骤的具体执行方式。
首先在步骤21,获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签。可以理解的是,本说明书实施例用于样本分类的神经网络模型可以用于对物品分类或对用户分类,其场景广泛,在此不做一一列举。
在一个示例中,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。
然后在步骤22,将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量。可以理解的是,通常先对训练样本进行特征提取,得到特征表示向量,然后再基于该特征表示向量进行分类处理。
本说明书实施例,对于特征提取模型的具体形式不做限定,例如,可以包括卷积神经网络(convolutional neural networks,CNN)
接着在步骤23,将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识。可以理解的是,鉴别器模型需要识别到训练样本的标识,相应地,特征提取模型得到的特征表示向量需要带有个例的信息。
再在步骤24,根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练。可以理解的是,这符合通常的训练目标。
其中,具体可以基于第一损失函数确定第一预测损失,上述第一损失函数可以采用交叉熵损失函数等。
再然后在步骤25,将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别。可以理解的是,分类器模型需要识别到训练样本的类别,而不需要识别到训练样本的标识,相应地,特征提取模型得到的特征表示向量不需要带有个例的信息。
最后在步骤26,根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。可以理解的是,为防止过拟合,本说明书实施例并不单纯追求最小化第二预测损失,而是在拟合训练样本的过程中,还考虑了预测损失与第一预测损失负相关,以便模型提取的特征不带有个例的信息,从而提高泛化能力。
其中,具体可以基于第二损失函数确定第二预测损失,上述第二损失函数可以采用交叉熵损失函数等。
在一个示例中,所述第一训练和所述第二训练交替重复进行。
可以理解的是,所述第一训练和所述第二训练是一种对抗的关系,通常地,达到一个纳什均衡后,可以结束训练。
如何判断是否达到纳什均衡,从而结束训练,实际中可以有多种方式。
一种方式为,当交替重复的次数达到第一预设阈值时,终止训练。
另一种方式为,每执行完一轮所述第一训练和所述第二训练,计算预设指标的指标值;当所述指标值大于第二预设阈值时,终止训练。
在一个示例中,将所述第一预测损失和所述第二预测损失进行加权求和得到所述第三预测损失;其中,所述第一预测损失对应的权重因子为负数,所述第二预测损失对应的权重因子为正数。例如,将所述第二预测损失与所述第一预测损失的差值作为所述第三预测损失。其中,可以认为第三预测损失对应的损失函数为第三损失函数。
可以理解的是,在终止训练后,就可以将训练好的模型用来进行样本分类了,本说明书实施例中,前面提到了特征提取模型、鉴别器模型和分类器模型,其中,鉴别器模型只是用于模型训练过程中作为生成对抗网络的一部分,模型训练结束后,只利用训练后的特征提取模型和分类器模型。
在一个示例中,将待识别样本输入训练后的所述特征提取模型,通过所述特征提取模型输出所述待识别样本的特征表示向量;将所述待识别样本的特征表示向量输入训练后的所述分类器模型,通过所述分类器模型输出所述待识别样本的识别类别。
通过本说明书实施例提供的方法,首先获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;然后将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;接着将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;再根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;再然后将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;最后根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。由上可见,本说明书实施例,通过对抗的方法,能够在对神经网络模型的训练过程中,防止神经网络模型过拟合,提高训练得到的神经网络模型的泛化性。
图3示出根据一个实施例的对抗训练网络组成结构图。如图3所示,该对抗训练网络包括两部分,其中,一部分包括特征提取模型和鉴别器模型,用于识别训练样本的标识,损失函数可以为前述第一损失函数;另一部分包括特征提取模型和分类器模型,用于识别训练样本的类别,损失函数可以为前述第三损失函数,通过交替重复前述第一训练和第二训练,通过对抗的方法,可以防止过拟合,最终可以利用训练后的特征提取模型和分类器模型来用于样本分类。
根据另一方面的实施例,还提供一种用于样本分类的神经网络模型的训练装置,该装置用于执行本说明书实施例提供的用于样本分类的神经网络模型的训练方法。图4示出根据一个实施例的用于样本分类的神经网络模型的训练装置的示意性框图。如图4所示,该装置400包括:
获取单元41,用于获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;
特征提取单元42,用于将所述获取单元41获取的所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;
鉴别单元43,用于将所述特征提取单元42得到的所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;
第一训练单元44,用于根据所述鉴别单元43得到的所述训练样本的识别标识和所述获取单元41获取的所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;
分类单元45,用于将所述特征提取单元42得到的所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;
第二训练单元46,用于根据所述分类单元45得到的所述训练样本的识别类别和所述获取单元41获取的所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。
可选地,作为一个实施例,所述第一训练单元44进行所述第一训练和所述第二训练单元进行所述第二训练交替重复进行。
进一步地,当交替重复的次数达到第一预设阈值时,终止训练。
进一步地,所述装置还包括:
计算单元,用于每执行完一轮所述第一训练和所述第二训练,计算预设指标的指标值;当所述指标值大于第二预设阈值时,终止训练。
可选地,作为一个实施例,所述第二训练单元46,具体用于将所述第一预测损失和所述第二预测损失进行加权求和得到所述第三预测损失;其中,所述第一预测损失对应的权重因子为负数,所述第二预测损失对应的权重因子为正数。
可选地,作为一个实施例,所述特征提取单元42,还用于将待识别样本输入训练后的所述特征提取模型,通过所述特征提取模型输出所述待识别样本的特征表示向量;
所述分类单元45,还用于将所述特征提取单元42得到的所述待识别样本的特征表示向量输入训练后的所述分类器模型,通过所述分类器模型输出所述待识别样本的识别类别。
可选地,作为一个实施例,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。
通过本说明书实施例提供的装置,首先获取单元41获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;然后特征提取单元42将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;接着鉴别单元43将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;再由第一训练单元44根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;再然后由分类单元45将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;最后第二训练单元46根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。由上可见,本说明书实施例,通过对抗的方法,能够在对神经网络模型的训练过程中,防止神经网络模型过拟合,提高训练得到的神经网络模型的泛化性。
本说明书实施例提供的方法具有显著的算法投放效果。该方法训练的110层的resnet分类网络,比常规的训练方法,在cifar测试集上能够提高1.5%的准确性。泛化能力得到大大提高。
根据另一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行结合图2所描述的方法。
根据再一方面的实施例,还提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现结合图2所描述的方法。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。

Claims (16)

1.一种用于样本分类的神经网络模型的训练方法,所述方法包括:
获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;
将所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;
将所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;
根据所述训练样本的识别标识和所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;
将所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;
根据所述训练样本的识别类别和所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。
2.如权利要求1所述的方法,其中,所述第一训练和所述第二训练交替重复进行。
3.如权利要求2所述的方法,其中,当交替重复的次数达到第一预设阈值时,终止训练。
4.如权利要求2所述的方法,其中,所述方法还包括:
每执行完一轮所述第一训练和所述第二训练,计算预设指标的指标值;
当所述指标值大于第二预设阈值时,终止训练。
5.如权利要求1所述的方法,其中,所述根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,包括:
将所述第一预测损失和所述第二预测损失进行加权求和得到所述第三预测损失;其中,所述第一预测损失对应的权重因子为负数,所述第二预测损失对应的权重因子为正数。
6.如权利要求1所述的方法,其中,所述方法还包括:
将待识别样本输入训练后的所述特征提取模型,通过所述特征提取模型输出所述待识别样本的特征表示向量;
将所述待识别样本的特征表示向量输入训练后的所述分类器模型,通过所述分类器模型输出所述待识别样本的识别类别。
7.如权利要求1所述的方法,其中,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。
8.一种用于样本分类的神经网络模型的训练装置,所述装置包括:
获取单元,用于获取训练样本集中的训练样本,所述训练样本具有样本标识,以及预先标注的样本类别标签;
特征提取单元,用于将所述获取单元获取的所述训练样本输入待训练的特征提取模型,通过所述特征提取模型输出所述训练样本的特征表示向量;
鉴别单元,用于将所述特征提取单元得到的所述训练样本的特征表示向量输入待训练的鉴别器模型,通过所述鉴别器模型输出所述训练样本的识别标识;
第一训练单元,用于根据所述鉴别单元得到的所述训练样本的识别标识和所述获取单元获取的所述训练样本的样本标识确定第一预测损失,以最小化所述第一预测损失为目标,对所述鉴别器模型和所述特征提取模型进行第一训练;
分类单元,用于将所述特征提取单元得到的所述训练样本的特征表示向量输入待训练的分类器模型,通过所述分类器模型输出所述训练样本的识别类别;
第二训练单元,用于根据所述分类单元得到的所述训练样本的识别类别和所述获取单元获取的所述训练样本的样本类别标签确定第二预测损失,并根据与所述第一预测损失负相关和与所述第二预测损失正相关确定第三预测损失,以最小化所述第三预测损失为目标,对所述分类器模型和所述特征提取模型进行第二训练。
9.如权利要求8所述的装置,其中,所述第一训练单元进行所述第一训练和所述第二训练单元进行所述第二训练交替重复进行。
10.如权利要求9所述的装置,其中,当交替重复的次数达到第一预设阈值时,终止训练。
11.如权利要求9所述的装置,其中,所述装置还包括:
计算单元,用于每执行完一轮所述第一训练和所述第二训练,计算预设指标的指标值;当所述指标值大于第二预设阈值时,终止训练。
12.如权利要求8所述的装置,其中,所述第二训练单元,具体用于将所述第一预测损失和所述第二预测损失进行加权求和得到所述第三预测损失;其中,所述第一预测损失对应的权重因子为负数,所述第二预测损失对应的权重因子为正数。
13.如权利要求8所述的装置,其中,所述特征提取单元,还用于将待识别样本输入训练后的所述特征提取模型,通过所述特征提取模型输出所述待识别样本的特征表示向量;
所述分类单元,还用于将所述特征提取单元得到的所述待识别样本的特征表示向量输入训练后的所述分类器模型,通过所述分类器模型输出所述待识别样本的识别类别。
14.如权利要求8所述的装置,其中,所述训练样本对应一个用户,所述样本标识为所述一个用户的标识,所述样本类别标签对应包括多个用户的用户人群。
15.一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-7中任一项的所述的方法。
16.一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-7中任一项的所述的方法。
CN201910822201.2A 2019-09-02 2019-09-02 用于样本分类的神经网络模型的训练方法和装置 Pending CN110689048A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910822201.2A CN110689048A (zh) 2019-09-02 2019-09-02 用于样本分类的神经网络模型的训练方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910822201.2A CN110689048A (zh) 2019-09-02 2019-09-02 用于样本分类的神经网络模型的训练方法和装置

Publications (1)

Publication Number Publication Date
CN110689048A true CN110689048A (zh) 2020-01-14

Family

ID=69108687

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910822201.2A Pending CN110689048A (zh) 2019-09-02 2019-09-02 用于样本分类的神经网络模型的训练方法和装置

Country Status (1)

Country Link
CN (1) CN110689048A (zh)

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111400754A (zh) * 2020-03-11 2020-07-10 支付宝(杭州)信息技术有限公司 保护用户隐私的用户分类系统的构建方法及装置
CN112580733A (zh) * 2020-12-25 2021-03-30 北京百度网讯科技有限公司 分类模型的训练方法、装置、设备以及存储介质
CN113222964A (zh) * 2021-05-27 2021-08-06 推想医疗科技股份有限公司 一种冠脉中心线提取模型的生成方法及装置
CN113239975A (zh) * 2021-04-21 2021-08-10 洛阳青鸟网络科技有限公司 一种基于神经网络的目标检测方法和装置
CN113762508A (zh) * 2021-09-06 2021-12-07 京东鲲鹏(江苏)科技有限公司 一种图像分类网络模型的训练方法、装置、设备和介质
CN113780378A (zh) * 2021-08-26 2021-12-10 北京科技大学 一种疾病高危人群预测装置
CN114282684A (zh) * 2021-12-24 2022-04-05 支付宝(杭州)信息技术有限公司 训练用户相关的分类模型、进行用户分类的方法及装置

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111400754A (zh) * 2020-03-11 2020-07-10 支付宝(杭州)信息技术有限公司 保护用户隐私的用户分类系统的构建方法及装置
CN112580733A (zh) * 2020-12-25 2021-03-30 北京百度网讯科技有限公司 分类模型的训练方法、装置、设备以及存储介质
CN112580733B (zh) * 2020-12-25 2024-03-05 北京百度网讯科技有限公司 分类模型的训练方法、装置、设备以及存储介质
CN113239975A (zh) * 2021-04-21 2021-08-10 洛阳青鸟网络科技有限公司 一种基于神经网络的目标检测方法和装置
CN113239975B (zh) * 2021-04-21 2022-12-20 国网甘肃省电力公司白银供电公司 一种基于神经网络的目标检测方法和装置
CN113222964A (zh) * 2021-05-27 2021-08-06 推想医疗科技股份有限公司 一种冠脉中心线提取模型的生成方法及装置
CN113222964B (zh) * 2021-05-27 2021-11-12 推想医疗科技股份有限公司 一种冠脉中心线提取模型的生成方法及装置
CN113780378A (zh) * 2021-08-26 2021-12-10 北京科技大学 一种疾病高危人群预测装置
CN113780378B (zh) * 2021-08-26 2023-11-28 北京科技大学 一种疾病高危人群预测装置
CN113762508A (zh) * 2021-09-06 2021-12-07 京东鲲鹏(江苏)科技有限公司 一种图像分类网络模型的训练方法、装置、设备和介质
CN114282684A (zh) * 2021-12-24 2022-04-05 支付宝(杭州)信息技术有限公司 训练用户相关的分类模型、进行用户分类的方法及装置

Similar Documents

Publication Publication Date Title
CN110689048A (zh) 用于样本分类的神经网络模型的训练方法和装置
CN110852755B (zh) 针对交易场景的用户身份识别方法和装置
US8140450B2 (en) Active learning method for multi-class classifiers
CN110046706B (zh) 模型生成方法、装置及服务器
JP7130984B2 (ja) 画像判定システム、モデル更新方法およびモデル更新プログラム
EP2991003A2 (en) Method and apparatus for classification
CN110852450B (zh) 识别对抗样本以保护模型安全的方法及装置
JP7024515B2 (ja) 学習プログラム、学習方法および学習装置
CN110334488B (zh) 基于随机森林模型的用户认证口令安全评估方法及装置
CN111626367A (zh) 对抗样本检测方法、装置、设备及计算机可读存储介质
CN111340233B (zh) 机器学习模型的训练方法及装置、样本处理方法及装置
CN110288085B (zh) 一种数据处理方法、装置、系统及存储介质
CN111506709B (zh) 实体链接方法、装置、电子设备和存储介质
US20210326700A1 (en) Neural network optimization
US20200042883A1 (en) Dictionary learning device, dictionary learning method, data recognition method, and program storage medium
CN112182269B (zh) 图像分类模型的训练、图像分类方法、装置、设备及介质
JPWO2015146113A1 (ja) 識別辞書学習システム、識別辞書学習方法および識別辞書学習プログラム
CN112200862B (zh) 目标检测模型的训练方法、目标检测方法及装置
CN111783088B (zh) 一种恶意代码家族聚类方法、装置和计算机设备
CN111488950B (zh) 分类模型信息输出方法及装置
JP2010272004A (ja) 判別装置及び判別方法、並びにコンピューター・プログラム
CN116883786A (zh) 图数据增广方法、装置、计算机设备及可读存储介质
CN111429414A (zh) 基于人工智能的病灶影像样本确定方法和相关装置
CN112149121A (zh) 一种恶意文件识别方法、装置、设备及存储介质
CN111523308B (zh) 中文分词的方法、装置及计算机设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
TA01 Transfer of patent application right

Effective date of registration: 20200924

Address after: Cayman Enterprise Centre, 27 Hospital Road, George Town, Grand Cayman Islands

Applicant after: Innovative advanced technology Co.,Ltd.

Address before: Cayman Enterprise Centre, 27 Hospital Road, George Town, Grand Cayman Islands

Applicant before: Advanced innovation technology Co.,Ltd.

Effective date of registration: 20200924

Address after: Cayman Enterprise Centre, 27 Hospital Road, George Town, Grand Cayman Islands

Applicant after: Advanced innovation technology Co.,Ltd.

Address before: A four-storey 847 mailbox in Grand Cayman Capital Building, British Cayman Islands

Applicant before: Alibaba Group Holding Ltd.

TA01 Transfer of patent application right
RJ01 Rejection of invention patent application after publication

Application publication date: 20200114

RJ01 Rejection of invention patent application after publication