CN111460150B - 一种分类模型的训练方法、分类方法、装置及存储介质 - Google Patents

一种分类模型的训练方法、分类方法、装置及存储介质 Download PDF

Info

Publication number
CN111460150B
CN111460150B CN202010231207.5A CN202010231207A CN111460150B CN 111460150 B CN111460150 B CN 111460150B CN 202010231207 A CN202010231207 A CN 202010231207A CN 111460150 B CN111460150 B CN 111460150B
Authority
CN
China
Prior art keywords
model
loss
sample data
labels
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010231207.5A
Other languages
English (en)
Other versions
CN111460150A (zh
Inventor
唐可欣
齐保元
韩佳乘
孟二利
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Xiaomi Pinecone Electronic Co Ltd
Original Assignee
Beijing Xiaomi Pinecone Electronic Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Xiaomi Pinecone Electronic Co Ltd filed Critical Beijing Xiaomi Pinecone Electronic Co Ltd
Priority to CN202010231207.5A priority Critical patent/CN111460150B/zh
Publication of CN111460150A publication Critical patent/CN111460150A/zh
Priority to US16/995,765 priority patent/US20210304069A1/en
Priority to EP20193056.7A priority patent/EP3886004A1/en
Application granted granted Critical
Publication of CN111460150B publication Critical patent/CN111460150B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/04Inference or reasoning models
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/10Segmentation; Edge detection
    • G06T7/194Segmentation; Edge detection involving foreground-background segmentation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20081Training; Learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20084Artificial neural networks [ANN]

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Databases & Information Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本公开是关于一种分类模型的训练方法、分类方法、装置及存储介质,该训练方法包括:基于预先训练好的第一模型对已标注数据集进行处理,分别得到已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;分别从各已标注样本数据所对应的N个第一类别概率中选取最大的K个第一类别概率,并确定各已标注样本数据的K个第一类别概率所对应的K个第一预测标签;基于已标注数据集、各已标注样本数据的真实标签以及各已标注样本数据的K个第一预测标签,训练第二模型。相较于基于第一模型输出的所有第一预测标签训练第二模型,能减少存储第一预测标签需要占用的内存空间;当基于第一预测标签计算第二模型的训练损失时,还能提高数据的计算速度。

Description

一种分类模型的训练方法、分类方法、装置及存储介质
技术领域
本公开涉及数学模型技术领域,尤其涉及一种分类模型的训练方法和装置、分类方法和装置及存储介质。
背景技术
文本分类,是指根据任务目标将文档分为N个类别中的一个或多个。目前,随着神经网络语言模型在自然语言处理(Natural Language Processing,NLP)领域的发展,越来越多的研究人员选择对预训练语言模型进行微调的方法来取得高精确度的模型。然而,由于预训练模型复杂的编码结构,模型的微调和实际生产往往伴随着巨额的时间和空间开销。
知识蒸馏是一种常见的深度学习模型压缩方法,旨在把一个大模型或者多个模型融合学到的知识迁移到另一个轻量级单模型上。相关技术的知识蒸馏中,对于海量标签文本分类来说,需要保存每个样例的预测标签,而保存每个样例的预测标签需要占用大量的内存空间,并且在实际计算损失函数时,也会因为向量的纬度太高,导致计算过程非常缓慢。
发明内容
本公开提供一种分类模型的训练方法、分类方法、装置及存储介质。
根据本公开实施例的第一方面,提供一种分类模型的训练方法,应用于电子设备,包括:
基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
可选的,所述方法还包括:
基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
可选的,所述基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型,训练所述第二模型,包括:
将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
基于所述训练损失,调整所述第二模型的模型参数。
可选的,所述基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失,包括:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
可选的,所述基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失,包括:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
可选的,所述方法还包括:
当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。
根据本公开实施例的第二方面,提供一种分类方法,应用于电子设备,包括:
将待分类数据输入采用上述第一方面提供的分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
将所述预设数量个类别标签确定为所述待分类数据的类别标签。
根据本公开实施例的第三方面,提供一种分类模型的训练装置,应用于电子设备,包括:
第一确定模块,配置为基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
第一选取模块,配置为分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
第一训练模块,配置为基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
可选的,所述装置还包括:
第二确定模块,配置为基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
第二选取模块,配置为从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
第二训练模块,配置为基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
可选的,所述第二训练模块,包括:
第一确定子模块,配置为将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
第二确定子模块,配置为将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
第三确定子模块,配置为利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
调整子模块,配置为基于所述训练损失,调整所述第二模型的模型参数。
可选的,所述第三确定子模块,还配置为:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
可选的,所述第三确定子模块,还配置为:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
可选的,其特征在于,所述装置还包括:
停止模块,配置为当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。
根据本公开实施例的第四方面,提供一种分类装置,应用于电子设备,包括:
分类模块,用于将待分类数据输入采用上述第一方面提供的分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
标签确定模块,用于按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
类别确定模块,用于将所述预设数量个类别标签确定为所述待分类数据的类别标签。
根据本公开实施例的第五方面,提供一种分类模型的训练装置,包括:
处理器;
配置为存储处理器可执行指令的存储器;
其中,所述处理器配置为:执行时实现上述第一方面中任一种分类模型的训练方法或者上述第二方面中分类方法中的步骤。
根据本公开实施例的第六方面,提供一种非临时性计算机可读存储介质,当所述存储介质中的指令由分类模型的训练装置的处理器执行时,使得所述装置能够执行上述第一方面中任一种分类模型的训练方法或者上述第二方面中分类方法中的步骤。
本公开的实施例提供的技术方案可以包括以下有益效果:
由上述实施例可知,本公开可以基于第一模型对已标注数据集中的已标注数据进行预测,并输出各个已标注数据的第一类别概率以及各个已标注数据的第一预测标签,然后从第一模型输出的所有第一预测标签中选取概率最大的K个第一类别概率,以及所述K个第一类别概率所对应的K个第一预测标签。
由于在基于第一模型对第二模型进行训练的过程中,需要将第一模型输出的第一预测标签保存至设定存储空间,需要基于该第一预测标签训练第二模型时,再从该设定存储空间调用第一预测标签,如果所存储的第一预测标签数量较大,可能会浪费设定存储空间的内存资源。本公开实施例中,通过选取最大的K个第一类别概率所对应的K个第一预测标签训练第二模型,相较于直接基于第一模型输出的所有第一预测标签训练第二模型,第一方面,能够减少存储第一预测标签需要占用的内存空间;第二方面,由于数据量减少,在训练的过程中,如果需要基于第一预测标签计算第二模型的训练损失,能够提高数据的计算速度。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。
图1是根据一示例性实施例示出的一种分类模型的训练方法的流程示意图。
图2是根据一示例性实施例示出的另一种分类模型的训练方法的流程示意图。
图3是根据一示例性实施例示出的分类模型的训练装置框图。
图4是根据一示例性实施例示出的一种用于分类模型的训练装置的框图。
图5是根据一示例性实施例示出的另一种用于分类模型的训练装置的框图。
具体实施方式
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本公开的一些方面相一致的装置和方法的例子。
本公开实施例中提供了一种分类模型的训练方法,图1是根据一示例性实施例示出的一种分类模型的训练方法的流程示意图,如图1所示,该方法应用于电子设备,主要包括以下步骤:
在步骤101中,基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
在步骤102中,分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
在步骤103中,基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
这里,电子设备包括移动终端和固定终端,其中,移动终端包括:手机、平板电脑、笔记本电脑等;固定终端包括:个人计算机。在其他可选的实施例中,该信息处理方法也可以运行于网络侧设备,其中,网络侧设备包括:服务器、处理中心等。
本公开实施例第一模型和第二模型可为实现预定功能的数学模型,包括但不限于以下至少之一:
对输入的文本进行分类;
对输入图像中目标和背景进行分割的目标分割;
对输入图像中目标的分类;
基于收入图像的目标跟踪;
基于医疗图像的诊断辅助;
基于输入语音的语音识别、语音校正等功能。
以上仅是对所述第一模型和第二模型所实现预定功能的举例说明,具体实现不局限于上述举例。
在其他可选的实施例中,可以基于已标注的训练数据集对预设模型进行训练,得到第一模型,其中,预设模型包括包括预测精度高但数据处理速度慢的预训练模型。例如,Bert模型、知识增强语义表示模型(Enhanced Representation from KnowledgeIntegration,Ernie模型)、Xlnet模型、神经网络模型、快速文本分类模型以及支持向量机模型等。第二模型包括预测精度低但数据处理速度快的模型,例如,albert模型、tiny模型等。
以第一模型是Bert模型为例,可以基于训练数据集对Bert模型进行训练,得到训练好的目标Bert模型。这时,可以将已标注数据集中的已标注数据输入目标Bert模型,并基于目标Bert模型输出各个已标注数据样本在N各类别的N各第一类别概率,这里,第一类别概率的类型可以包括:非归一化类别概率和归一化类别概率,其中,非归一化概率是未经归一化函数(例如,softmax函数)进行归一化处理的概率值,归一化概率是指经过归一化函数进行归一化处理的概率值。由于非归一化类别概率相较于归一化类别概率所包含的信息量较高,本公开实施例中,可以基于第一模型输出非归一化类别概率,在其他可选的实施例中,也可以基于第一模型输出归一化类别概率。
以已标注数据集中某一个已标注样本数据(第一样本数据)为例,在将第一样本数据输入第一模型之后,可以基于第一模型输出第一样本数据在N个类别的N个第一类别概率。例如,第一样本数据在第一类的第一类别概率是0.4,第一样本数据在第二类的第一类别概率是0.001,第一样本数据在第三类的第一类别概率是0.05,.......第一样本数据在第N类的第一类别概率是0.35,这样,就能确定出第一样本数据在各个类别的第一类别概率,其中,第一类别概率越大,则第一样本数据属于该类别的可能性越大,第一类别概率越小,则第一样本数据属于该类别的可能性越小。例如,第一样本数据在第一类的第一类别概率是0.4,第一样本数据在第二类的第一类别概率是0.001,则可以确定第一样本数据属于第一类的概率高于属于第二类的概率。
在得到已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率之后,可以将N个第一类别概率按照从大到小的顺序进行排序,根据排序结果从N个第一类别概率中选取最大的K个第一类别概率。还是以已标注数据集中的第一样本数据为例,第一样本数据在第一类的第一类别概率是0.4,第一样本数据在第二类的第一类别概率是0.001,第一样本数据在第三类的第一类别概率是0.05,.......第一样本数据在第N类的第一类别概率是0.35,在将第一样本数据对应的N个第一类别概率按照从大大小进行排序之后,可以取前K个第一类别概率。以N是3000,K是20为例,可以对3000个第一类别概率按照从大到小的顺序进行排序,并选取最大的20个第一类别概率。
由于当第一类别概率小于设定概率阈值时,第一样本数据属于该类别的可能性会很小,本公开实施例中,能够选取数值较高的第一类别概率,摒弃数值较低的第一类别概率,能够在保证输出的类别概率的准确性的基础上减少数据量,进而减少训练模型的计算量。在选取出最大的K个第一类别概率之后,能够确定出最大的K个第一类别概率所对应的的K个第一预测标签,并基于已标注数据集、各个已标注样本数据的真实标签以及该K个第一预测标签,训练第二模型。
本公开实施例中,可以基于第一模型对已标注数据集中的已标注样本数据进行预测,并输出各个已标注样本数据的第一类别概率以及各个已标注样本数据的第一预测标签,然后从第一模型输出的所有第一预测标签中选取概率最大的K个第一类别概率,以及所述K个第一类别概率所对应的K个第一预测标签。
由于在基于第一模型对第二模型进行训练的过程中,需要将第一模型输出的第一预测标签保存至设定存储空间,需要基于该第一预测标签训练第二模型时,再从该设定存储空间调用第一预测标签,如果所存储的第一预测标签数量较大,可能会浪费设定存储空间的内存资源。本公开实施例中,通过选取最大的K个第一类别概率所对应的K个第一预测标签训练第二模型,相较于直接基于第一模型输出的所有第一预测标签训练第二模型,第一方面,能够减少存储第一预测标签需要占用的内存空间;第二方面,由于数据量减少,在训练的过程中,如果需要基于第一预测标签计算第二模型的训练损失,能够提高数据的计算速度。
在其他可选的实施例中,所述方法还包括:
基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
这里,第二类别概率的类型可以包括:非归一化类别概率和归一化类别概率。由于归一化类别概率相较于非归一化类别概率能够使各个类别之间的分别更加明显,本公开实施例中,可以基于第一模型输出归一化类别概率,在其他可选的实施例中,也可以基于第一模型输出非归一化类别概率。
以未标注数据集中某一个未标注样本数据(第二样本数据)为例,在将第二样本数据输入第一模型之后,可以基于第一模型输出第二样本数据在M个类别的M个第二类别概率。例如,第二样本数据在第一类的第二类别概率是0.01,第二样本数据在第二类的第二类别概率是0.0001,第二样本数据在第三类的第二类别概率是0.45,.......第二样本数据在第N类的第二类别概率是0.35,这样,就能确定出第二样本数据在各个类别的第二类别概率,其中,第二类别概率越大,则第二样本数据属于该类别的可能性越大,第二类别概率越小,则第二样本数据属于该类别的可能性越小。例如,第二样本数据在第三类的第二类别概率是0.45,第二样本数据在第二类的第二类别概率是0.0001,则可以确定第二样本数据属于第三类的概率高于属于第二类的概率。
在得到未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率之后,可以将M个第二类别概率按照从大到小的顺序进行排序,根据排序结果从M个第二类别概率中选取最大的H个第二类别概率。还是以未标注数据集中的第二样本数据为例,第二样本数据在第一类的第二类别概率是0.01,第二样本数据在第二类的第二类别概率是0.0001,第二样本数据在第三类的第二类别概率是0.45,.......第二样本数据在第N类的第二类别概率是0.35,在将第二样本数据对应的M个第二类别概率按照从大大小进行排序之后,可以取前H个第二类别概率。以M是300,H是1为例,可以对300个第二类别概率按照从大到小的顺序进行排序,并选取最大的一个第二类别概率,并将最大的一个第二类别概率所对应的第二预测标签确定为第二样本数据的标签。
本公开实施例中,可以基于第一模型对未标注数据集中的未标注样本数据进行预测,并输出各个未标注数据的第二类别概率以及各个未标注数据的第二预测标签,然后从第一模型输出的所有第二预测标签中选取概率最大的H个第二类别概率,以及所述H个第二类别概率所对应的H个第二预测标签。通过增加未标注样本数据的第二预测标签,并基于第二预测标签对第二模型进行训练,扩充了第二模型的训练语料,能够提高数据的多样性,以提高所训练的第二模型的泛化能力。
在其他可选的实施例中,所述基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型,训练所述第二模型,包括:
将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
基于所述训练损失,调整所述第二模型的模型参数。
这里,预设损失函数用于衡量第二模型预测的好坏,本公开实施例中,通过将已标注样本数据输入第二模型预测得到第三预测标签,将未标注样本数据输入第二模型得到第四预测标签,并利用预设损失函数,基于真实标签、各个已标注样本数据的K个第一预测标签、第三预测标签、各个未标注样本数据的H个第二预测标签、以及第四预测标签,确定第二模型的训练损失,进而利用基于预设损失函数得到的训练损失调整第二模型的模型参数。
本公开实施例中,第一方面,相较于直接基于第一模型输出的所有第一预测标签训练第二模型,能够减少存储第一预测标签需要占用的内存空间;第二方面,由于数据量减少,在训练的过程中,如果需要基于第一预测标签计算第二模型的训练损失,能够提高数据的计算速度;第三方面,通过增加未标注样本数据的第二预测标签,并基于第二预测标签对第二模型进行训练,扩充了第二模型的训练语料,能够提高数据的多样性,以提高所训练的第二模型的泛化能力;第四方面,还针对不同的损失计算任务使用新的预设损失函数,基于预设损失函数调整第二模型的模型参数,能够提升第二模型的性能。
在其他可选的实施例中,该方法还包括:基于测试数据集对训练后的第二模型进行性能评估,得到评估结果;其中,测试数据集中测试数据的类型包括以下至少之一:文本数据类型、图像数据类型、业务数据类型和音频数据类型。这里,在得到训练后的第二模型之后,可以在测试数据集上评估其性能,逐步优化第二模型,直至找到最优的第二模型,例如最小化验证损失或最大化奖励的第二模型。这里,可以将测试数据集中的测试数据输入训练好的第二模型,经由该第二模型输出评估结果,然后将输出的评估结果与预设的标准进行比较,得到比较结果,并根据比较结果评估第二模型的性能,其中,测试结果可以为第二模型处理测试数据的速度或者精度。
在其他可选的实施例中,所述基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失,包括:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
这里,所述第一损失为真实标签和第三预测标签的交叉熵,第一损失的计算公式包括:
公式(1)中,loss(hard)表示所述第一损失;N表示所述已标注数据集的大小;yi'表示第i个维度的真实标签,yi表示第i个维度的第三预测标签,i为正整数。其中,yi的计算公式包括:
公式(2)中,yi表示第i个维度的第三预测标签,Zi表示第i个维度的已标注数据的第一类别概率;Zj表示第j个维度的已标注数据的第一类别概率,i和j均为正整数。
所述第二损失为各个已标注样本数据的K个第一预测标签与第三预测标签的交叉熵,第二损失的计算公式包括:
公式(3)中,loss(soft)表示所述第二损失;表示第i个维度的第一预测标签;yi表示第i个维度第三预测标签;T表示预设温度参数;ST1表示第一预测标签的数量,可以等于K;i为正整数,这里,它所包含的分类信息就越多出的预测值更加平缓。其中,yi的计算公式包括:
公式(4)中,yi表示第i个维度第三预测标签,Zi表示第i个维度的已标注数据的第一类别概率;Zj表示第j个维度的已标注数据的第一类别概率,T表示预设温度参数;i和j均为正整数。这里,预设温度参数的数值越大,输出的概率分布就越平缓,输出结果所包含的分类信息就越多,通设置预设温度参数,可以基于预设温度参数调整输出的概率分布的平缓度,进而调整输出结果所包含的分类信息量,能够提高模型训练的精确度和灵活性。
所述第三损失为第二预测标签与第四预测标签的交叉熵,第三损失的计算公式包括:
公式(5)中,表示所述第三损失;y”i表示第i个维度的第二预测标签;yi表示第i个维度的第四预测标签;M表示未标注数据集的大小;i为正整数。本公开实施例中,针对不同的损失计算任务使用新的预设损失函数,基于预设损失函数调整第二模型的模型参数,能够提升第二模型的性能。
在其他可选的实施例中,所述基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失,包括:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
在其他可选的实施例中,所述训练损失计算公式包括:
公式(6)中,Loss表示所述第二模型的训练损失;loss(hard)表示所述第一损失;loss(soft)表示所述第二损失;表示所述第三损失;α为第一预设权重,且α大于0.5且小于1;β为第二预设权重,且β小于等于α。本公开实施例中,一方面,针对不同的损失计算任务使用新的预设损失函数,基于预设损失函数调整第二模型的模型参数,能够提升第二模型的性能;另一方面,通过设置可调整的第一预设权重和第二预设权重,可以根据需要调整第一损失、第二损失和第三损失在训练损失中所占的比重,增加了模型训练的灵活性。
在其他可选的实施例中,所述方法还包括:
当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。在其他可选的实施例中,还可以基于设定的验证集对第二模型的准确率进行验证,当准确率达到设定准确率时,停止训练第二模型,获得训练完成的目标模型。
图2是根据一示例性实施例示出的另一种模型的训练方法的流程示意图,如图2所示,在基于第一模型(Teacher模型)训练第二模型(Student模型)的过程中,可以预先确定好第一模型,在已标注的训练数据集L上对第一模型进行微调,并保存微调后的第一模型,这里,可以将微调后的第一模型记为TM。这里,第一模型可以为预测精度高但计算速度慢的预训练模型,例如,Bert模型、Ernie模型、Xlnet模型等。
在得到TM之后,可以使用TM对已标注数据集(转移集T)进行预测,分别得到已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率,分别从各个已标注样本数据所对应的N个第一类别概率中选取最大的K个第一类别概率,并确定最大的K个第一类别概率所对应的K个第一预测标签,其中,K为超参数,例如,K可以等于20。
本公开实施例中,还可以使用TM对未标注数据集U进行预测,分别得到未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率,从各个未标注样本数据对应的M个第二类别概率中选取最大的H个第二类别概率,并确定最大的H个第二类别概率所对应的H个第二预测标签,其中,H可以等于1。这里,当第二类别概率为非归一化类别概率时,可以利用激活函数softmax对第二类别概率进行归一化处理。这样,就能够确定训练第二模型所需的数据。
本公开实施例中,可以将已标注数据集中的各个已标注样本数据输入第二模型,得到第二模型输出的第三预测标签;将未标注数据集中的各个未标注样本数据输入第二模型,得到第二模型输出的第四预测标签;利用预设损失函数,基于真实标签、各个已标注样本数据的K个第一预测标签、第三预测标签、各个未标注样本数据的H个第二预测标签、以及第四预测标签,确定第二模型的训练损失;基于训练损失,调整第二模型的模型参数。
本公开实施例中,第一方面,通过选取第一模型输出的最大的K个第一预测标签代替传统模型蒸馏中选取所有第一预测标签来训练第二模型,在不会对第二模型的性能产生影响的基础上,减少了内存消耗,并提升了第二模型训练速度;第二方面,充分利用未标注数据集,将未标注数据引入数据蒸馏的过程中,扩充了第二模型的训练语料,能够提高数据的多样性,以提高所训练的第二模型的泛化能力;第三方面,针对联合任务使用新的预设损失函数,基于预设损失函数调整第二模型的模型参数,能够提升第二模型的性能。
本发明实施例还提供一种分类方法,可以采用经过训练的第二模型对待分类数据进行分类,可以包括如下步骤:
步骤一、将待分类数据输入采用上述介绍的任一分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率。其中,X为自然数。
步骤二、按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签。
步骤三、将所述预设数量个类别标签确定为所述待分类数据的类别标签。
可以根据实际需要确定待分类数据的类别标签数量,即预设数量,可以为一个或者多个。当预设数量为一个时,可以将类别概率最高的类别标签作为待分类数据的标签,当预设数量为多个时,可以按照类别概率从大到小的顺序,确定前多个类别概率,将这多个类别概率对应的类别标签,确定为待分类数据的类别标签。
图3是根据一示例性实施例示出的分类模型的训练装置框图,如图3所示,该分类模型的训练装置300应用于电子设备,主要包括:
第一确定模块301,配置为基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
第一选取模块302,配置为分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
第一训练模块303,配置为基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
在其他可选的实施例中,所述装置300还包括:
第二确定模块,配置为基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
第二选取模块,配置为从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
第二训练模块,配置为基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
在其他可选的实施例中,所述第二训练模块,包括:
第一确定子模块,配置为将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
第二确定子模块,配置为将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
第三确定子模块,配置为利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
调整子模块,配置为基于所述训练损失,调整所述第二模型的模型参数。
在其他可选的实施例中,所述第三确定子模块,还配置为:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
在其他可选的实施例中,所述第三确定子模块,还配置为:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
在其他可选的实施例中,所述装置300还包括:
停止模块,配置为当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。
本发明实施例还提供一种分类装置,应用于电子设备,包括:
分类模块,用于将待分类数据输入采用上述任一实施例提供的分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
标签确定模块,用于按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
类别确定模块,用于将所述预设数量个类别标签确定为所述待分类数据的类别标签。
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
图4是根据一示例性实施例示出的一种用于分类模型的训练装置1200,或者分类装置1200的框图。例如,装置1200可以是移动电话,计算机,数字广播终端,消息收发设备,游戏控制台,平板设备,医疗设备,健身设备,个人数字助理等。
参照图4,装置1200可以包括以下一个或多个组件:处理组件1202,存储器1204,电力组件1206,多媒体组件1208,音频组件1210,输入/输出(I/O)的接口1212,传感器组件1214,以及通信组件1216。
处理组件1202通常控制装置1200的整体操作,诸如与显示,电话呼叫,数据通信,相机操作和记录操作相关联的操作。处理组件1202可以包括一个或多个处理器1220来执行指令,以完成上述的方法的全部或部分步骤。此外,处理组件1202可以包括一个或多个模块,便于处理组件1202和其他组件之间的交互。例如,处理组件1202可以包括多媒体模块,以方便多媒体组件1208和处理组件1202之间的交互。
存储器1204被配置为存储各种类型的数据以支持在设备1200的操作。这些数据的示例包括用于在装置1200上操作的任何应用程序或方法的指令,联系人数据,电话簿数据,消息,图片,视频等。存储器1204可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,如静态随机存取存储器(SRAM),电可擦除可编程只读存储器(EEPROM),可擦除可编程只读存储器(EPROM),可编程只读存储器(PROM),只读存储器(ROM),磁存储器,快闪存储器,磁盘或光盘。
电力组件1206为装置1200的各种组件提供电力。电力组件1206可以包括电源管理系统,一个或多个电源,及其他与为装置1200生成、管理和分配电力相关联的组件。
多媒体组件1208包括在所述装置1200和用户之间的提供一个输出接口的屏幕。在一些实施例中,屏幕可以包括液晶显示器(LCD)和触摸面板(TP)。如果屏幕包括触摸面板,屏幕可以被实现为触摸屏,以接收来自用户的输入信号。触摸面板包括一个或多个触摸传感器以感测触摸、滑动和触摸面板上的手势。所述触摸传感器可以不仅感测触摸或滑动动作的边界,而且还检测与所述触摸或滑动操作相关的持续时间和压力。在一些实施例中,多媒体组件1208包括一个前置摄像头和/或后置摄像头。当设备1200处于操作模式,如拍摄模式或视频模式时,前置摄像头和/或后置摄像头可以接收外部的多媒体数据。每个前置摄像头和后置摄像头可以是一个固定的光学透镜系统或具有焦距和光学变焦能力。
音频组件1210被配置为输出和/或输入音频信号。例如,音频组件1210包括一个麦克风(MIC),当装置1200处于操作模式,如呼叫模式、记录模式和语音识别模式时,麦克风被配置为接收外部音频信号。所接收的音频信号可以被进一步存储在存储器1204或经由通信组件1216发送。在一些实施例中,音频组件1210还包括一个扬声器,用于输出音频信号。
I/O接口1212为处理组件1202和外围接口模块之间提供接口,上述外围接口模块可以是键盘,点击轮,按钮等。这些按钮可包括但不限于:主页按钮、音量按钮、启动按钮和锁定按钮。
传感器组件1214包括一个或多个传感器,用于为装置1200提供各个方面的状态评估。例如,传感器组件1214可以检测到设备1200的打开/关闭状态,组件的相对定位,例如所述组件为装置1200的显示器和小键盘,传感器组件1214还可以检测装置1200或装置1200一个组件的位置改变,用户与装置1200接触的存在或不存在,装置1200方位或加速/减速和装置1200的温度变化。传感器组件1214可以包括接近传感器,被配置用来在没有任何的物理接触时检测附近物体的存在。传感器组件1214还可以包括光传感器,如CMOS或CCD图像传感器,用于在成像应用中使用。在一些实施例中,该传感器组件1214还可以包括加速度传感器,陀螺仪传感器,磁传感器,压力传感器或温度传感器。
通信组件1216被配置为便于装置1200和其他设备之间有线或无线方式的通信。装置1200可以接入基于通信标准的无线网络,如WiFi,2G或3G,或它们的组合。在一个示例性实施例中,通信组件1216经由广播信道接收来自外部广播管理系统的广播信号或广播相关信息。在一个示例性实施例中,所述通信组件1216还包括近场通信(NFC)模块,以促进短程通信。例如,在NFC模块可基于射频识别(RFID)技术,红外数据协会(IrDA)技术,超宽带(UWB)技术,蓝牙(BT)技术和其他技术来实现。
在示例性实施例中,装置1200可以被一个或多个应用专用集成电路(ASIC)、数字信号处理器(DSP)、数字信号处理设备(DSPD)、可编程逻辑器件(PLD)、现场可编程门阵列(FPGA)、控制器、微控制器、微处理器或其他电子元件实现,用于执行上述方法。
在示例性实施例中,还提供了一种包括指令的非临时性计算机可读存储介质,例如包括指令的存储器1204,上述指令可由装置1200的处理器1220执行以完成上述方法。例如,所述非临时性计算机可读存储介质可以是ROM、随机存取存储器(RAM)、CD-ROM、磁带、软盘和光数据存储设备等。
一种非临时性计算机可读存储介质,当所述存储介质中的指令由移动终端的处理器执行时,使得移动终端能够执行一种分类模型的训练方法,所述方法包括:
基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
或者使得移动终端能够执行一种分类方法,所述方法包括:
将待分类数据输入采用上述任一实施例提供的分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
将所述预设数量个类别标签确定为所述待分类数据的类别标签。
图5是根据一示例性实施例示出的另一种用于分类模型的训练装置1300或者用于分类的装置1300的框图。例如,装置1300可以被提供为一服务器。参照图5,装置1300包括处理组件1322,其进一步包括一个或多个处理器,以及由存储器1332所代表的存储器资源,用于存储可由处理组件1322的执行的指令,例如应用程序。存储器1332中存储的应用程序可以包括一个或一个以上的每一个对应于一组指令的模块。此外,处理组件1322被配置为执行指令,以执行上述分类模型的训练方法,所述方法包括:
基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
基于所述已标注数据集、各个所述已标注样本数据的真实标签以及各个所述已标注样本数据的K个所述第一预测标签,训练第二模型。
或执行上述分类方法,所述方法包括:
将待分类数据输入采用上述任一实施例提供的分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
将所述预设数量个类别标签确定为所述待分类数据的类别标签。
装置1300还可以包括一个电源组件1326被配置为执行装置1300的电源管理,一个有线或无线网络接口1350被配置为将装置1300连接到网络,和一个输入输出(I/O)接口1358。装置1300可以操作基于存储在存储器1332的操作系统,例如Windows ServerTM,MacOS XTM,UnixTM,LinuxTM,FreeBSDTM或类似。
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本发明的其它实施方案。本申请旨在涵盖本发明的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本发明的一般性原理并包括本公开未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本发明的真正范围和精神由下面的权利要求指出。
应当理解的是,本发明并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本发明的范围仅由所附的权利要求来限制。

Claims (16)

1.一种分类模型的训练方法,其特征在于,应用于电子设备,包括:
基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N;
将K个所述第一预测标签保存至设定存储空间;
其中,所述设定存储空间内的K个所述第一预测标签在调用后用于训练第二模型,所述第一模型和所述第二模型用于文本分类,且所述第一类别概率为对应文本所在类别的概率;
或者,所述第一模型和所述第二模型用于图像中目标的分类,且所述第一类别概率为对应图像中目标的类别概率。
2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
3.根据权利要求2所述的方法,其特征在于,所述基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型,训练所述第二模型,包括:
将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
基于所述训练损失,调整所述第二模型的模型参数。
4.根据权利要求3所述的方法,其特征在于,所述基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失,包括:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
5.根据权利要求4所述的方法,其特征在于,所述基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失,包括:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
6.根据权利要求3至5任一项所述的方法,其特征在于,所述方法还包括:
当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。
7.一种分类方法,其特征在于,应用于电子设备,包括:
将待分类数据输入采用权利要求1~6任一项分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
将所述预设数量个类别标签确定为所述待分类数据的类别标签。
8.一种分类模型的训练装置,其特征在于,应用于电子设备,包括:
第一确定模块,配置为基于预先训练好的第一模型对已标注数据集进行处理,分别得到所述已标注数据集中各个已标注样本数据在N个类别的N个第一类别概率;
第一选取模块,配置为分别从各个所述已标注样本数据所对应的N个第一类别概率中选取最大的K个所述第一类别概率,并确定各个所述已标注样本数据的所述K个第一类别概率所对应的K个第一预测标签,其中,所述K和所述N为正整数,且所述K小于所述N,并将K个所述第一预测标签保存至设定存储空间;
其中,所述设定存储空间内的K个所述第一预测标签在调用后用于训练第二模型,所述第一模型和所述第二模型用于文本分类,且所述第一类别概率为对应文本所在类别的概率;
或者,所述第一模型和所述第二模型用于图像中目标的分类,且所述第一类别概率为对应图像中目标的类别概率。
9.根据权利要求8所述的装置,其特征在于,所述装置还包括:
第二确定模块,配置为基于所述第一模型对未标注数据集进行处理,分别得到所述未标注数据集中各个未标注样本数据在M个类别的M个第二类别概率;
第二选取模块,配置为从各个所述未标注样本数据对应有M个第二类别概率中选取最大的H个第二类别概率,并确定各个所述未标注样本数据的所述H个第二类别概率所对应的H个第二预测标签,其中,所述M和所述H为正整数,且所述H小于所述M;
第二训练模块,配置为基于所述已标注数据集、所述未标注数据集、各个所述已标注样本数据的真实标签、各个所述已标注样本数据的K个所述第一预测标签、以及各个所述未标注样本数据的H个所述第二预测标签,训练所述第二模型。
10.根据权利要求9所述的装置,其特征在于,所述第二训练模块,包括:
第一确定子模块,配置为将所述已标注数据集中的各个所述已标注样本数据输入所述第二模型,得到所述第二模型输出的第三预测标签;
第二确定子模块,配置为将所述未标注数据集中的各个所述未标注样本数据输入所述第二模型,得到所述第二模型输出的第四预测标签;
第三确定子模块,配置为利用预设损失函数,基于所述真实标签、各个所述已标注样本数据的K个所述第一预测标签、所述第三预测标签、各个所述未标注样本数据的H个所述第二预测标签、以及所述第四预测标签,确定所述第二模型的训练损失;
调整子模块,配置为基于所述训练损失,调整所述第二模型的模型参数。
11.根据权利要求10所述的装置,其特征在于,所述第三确定子模块,还配置为:
基于所述真实标签和所述第三预测标签确定所述第二模型在所述已标注数据集上的第一损失;
基各个所述已标注样本数据的K个所述第一预测标签和所述第三预测标签,确定所述第二模型在所述已标注数据集上的第二损失;
基于各个所述未标注样本数据的H个所述第二预测标签和所述第四预测标签,确定所述第二模型在所述未标注数据集上的第三损失;
基于所述第一损失、所述第二损失和所述第三损失的加权和,确定所述训练损失。
12.根据权利要求11所述的装置,其特征在于,所述第三确定子模块,还配置为:
确定所述第一损失值与第一预设权重的第一乘积;
根据所述第一预设权重确定损失权重,并确定所述第二损失值与所述损失权重的第二乘积;
确定所述第三损失值与第二预设权重的第三乘积,所述第二预设权重小于或者等于所述第一预设权重;
将所述第一乘积、所述第二乘积、以及所述第三乘积相加,得到所述训练损失。
13.根据权利要求10至12任一项所述的装置,其特征在于,所述装置还包括:
停止模块,配置为当所述训练损失在设定时长内的数值变化小于设定变化阈值时,停止训练所述第二模型。
14.一种分类装置,其特征在于,应用于电子设备,包括:
分类模块,用于将待分类数据输入采用权利要求1~6任一项分类模型训练方法训练得到的第二模型,输出所述待分类数据在X个类别的X个类别概率;
标签确定模块,用于按照类别概率从大到小,确定X个类别概率中前预设数量个类别概率对应的类别标签;
类别确定模块,用于将所述预设数量个类别标签确定为所述待分类数据的类别标签。
15.一种分类模型的训练装置,其特征在于,包括:
处理器;
配置为存储处理器可执行指令的存储器;
其中,所述处理器配置为:执行时实现上述权利要求1至6中任一种分类模型的训练方法或权利要求7中分类方法中的步骤。
16.一种非临时性计算机可读存储介质,其特征在于,当所述存储介质中的指令由分类模型的训练装置的处理器执行时,使得所述装置能够执行上述权利要求1至6中任一种分类模型的训练方法或者权利要求7中分类方法中的步骤。
CN202010231207.5A 2020-03-27 2020-03-27 一种分类模型的训练方法、分类方法、装置及存储介质 Active CN111460150B (zh)

Priority Applications (3)

Application Number Priority Date Filing Date Title
CN202010231207.5A CN111460150B (zh) 2020-03-27 2020-03-27 一种分类模型的训练方法、分类方法、装置及存储介质
US16/995,765 US20210304069A1 (en) 2020-03-27 2020-08-17 Method for training classification model, classification method and device, and storage medium
EP20193056.7A EP3886004A1 (en) 2020-03-27 2020-08-27 Method for training classification model, classification method and device, and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010231207.5A CN111460150B (zh) 2020-03-27 2020-03-27 一种分类模型的训练方法、分类方法、装置及存储介质

Publications (2)

Publication Number Publication Date
CN111460150A CN111460150A (zh) 2020-07-28
CN111460150B true CN111460150B (zh) 2023-11-10

Family

ID=71683548

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010231207.5A Active CN111460150B (zh) 2020-03-27 2020-03-27 一种分类模型的训练方法、分类方法、装置及存储介质

Country Status (3)

Country Link
US (1) US20210304069A1 (zh)
EP (1) EP3886004A1 (zh)
CN (1) CN111460150B (zh)

Families Citing this family (24)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111882063B (zh) * 2020-08-03 2022-12-02 清华大学 适应低预算的数据标注请求方法、装置、设备及存储介质
CN111951933B (zh) * 2020-08-07 2023-01-17 平安科技(深圳)有限公司 眼底彩照图像分级方法、装置、计算机设备及存储介质
CN111898696B (zh) * 2020-08-10 2023-10-27 腾讯云计算(长沙)有限责任公司 伪标签及标签预测模型的生成方法、装置、介质及设备
CN112749728A (zh) * 2020-08-13 2021-05-04 腾讯科技(深圳)有限公司 学生模型训练方法、装置、计算机设备及存储介质
CN112070233B (zh) * 2020-08-25 2024-03-22 北京百度网讯科技有限公司 模型联合训练方法、装置、电子设备和存储介质
CN112182214B (zh) * 2020-09-27 2024-03-19 中国建设银行股份有限公司 一种数据分类方法、装置、设备及介质
CN112464760A (zh) * 2020-11-16 2021-03-09 北京明略软件系统有限公司 一种目标识别模型的训练方法和装置
CN112528109B (zh) * 2020-12-01 2023-10-27 科大讯飞(北京)有限公司 一种数据分类方法、装置、设备及存储介质
CN112613938B (zh) * 2020-12-11 2023-04-07 上海哔哩哔哩科技有限公司 模型训练方法、装置及计算机设备
CN112686046A (zh) * 2021-01-06 2021-04-20 上海明略人工智能(集团)有限公司 模型训练方法、装置、设备及计算机可读介质
CN112861935A (zh) * 2021-01-25 2021-05-28 北京有竹居网络技术有限公司 模型生成方法、对象分类方法、装置、电子设备及介质
CN112800223A (zh) * 2021-01-26 2021-05-14 上海明略人工智能(集团)有限公司 基于长文本标签化的内容召回方法及系统
CN113239985B (zh) * 2021-04-25 2022-12-13 北京航空航天大学 一种面向分布式小规模医疗数据集的分类检测方法
CN113178189B (zh) * 2021-04-27 2023-10-27 科大讯飞股份有限公司 一种信息分类方法及装置、信息分类模型训练方法及装置
CN113792798A (zh) * 2021-09-16 2021-12-14 平安科技(深圳)有限公司 基于多源数据的模型训练方法、装置及计算机设备
US11450225B1 (en) * 2021-10-14 2022-09-20 Quizlet, Inc. Machine grading of short answers with explanations
CN114428858A (zh) * 2022-01-21 2022-05-03 平安科技(深圳)有限公司 基于分类模型的文本难度分类方法、装置及存储介质
CN114372978B (zh) * 2022-02-10 2022-06-28 北京安德医智科技有限公司 一种超声造影影像分类方法及装置、电子设备和存储介质
CN114186065B (zh) * 2022-02-14 2022-05-17 苏州浪潮智能科技有限公司 一种分类结果校正方法、系统、设备以及介质
CN114692724B (zh) * 2022-03-03 2023-03-28 支付宝(杭州)信息技术有限公司 数据分类模型的训练方法、数据分类方法和装置
CN114757169A (zh) * 2022-03-22 2022-07-15 中国电子科技集团公司第十研究所 基于albert模型自适应小样本学习智能纠错方法
CN114780709B (zh) * 2022-03-22 2023-04-07 北京三快在线科技有限公司 文本匹配方法、装置及电子设备
CN114419378B (zh) * 2022-03-28 2022-09-02 杭州未名信科科技有限公司 图像分类的方法、装置、电子设备及介质
CN114792173B (zh) * 2022-06-20 2022-10-04 支付宝(杭州)信息技术有限公司 预测模型训练方法和装置

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108509969A (zh) * 2017-09-06 2018-09-07 腾讯科技(深圳)有限公司 数据标注方法及终端
WO2019062414A1 (zh) * 2017-09-30 2019-04-04 Oppo广东移动通信有限公司 应用程序管控方法、装置、存储介质及电子设备
WO2020000961A1 (zh) * 2018-06-29 2020-01-02 北京达佳互联信息技术有限公司 图像标签识别方法、装置及服务器
CN110827253A (zh) * 2019-10-30 2020-02-21 北京达佳互联信息技术有限公司 一种目标检测模型的训练方法、装置及电子设备

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10885400B2 (en) * 2018-07-03 2021-01-05 General Electric Company Classification based on annotation information
US10719301B1 (en) * 2018-10-26 2020-07-21 Amazon Technologies, Inc. Development environment for machine learning media models
EP3671531A1 (en) * 2018-12-17 2020-06-24 Promaton Holding B.V. Semantic segmentation of non-euclidean 3d data sets using deep learning
WO2021007514A1 (en) * 2019-07-10 2021-01-14 Schlumberger Technology Corporation Active learning for inspection tool

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108509969A (zh) * 2017-09-06 2018-09-07 腾讯科技(深圳)有限公司 数据标注方法及终端
WO2019062414A1 (zh) * 2017-09-30 2019-04-04 Oppo广东移动通信有限公司 应用程序管控方法、装置、存储介质及电子设备
WO2020000961A1 (zh) * 2018-06-29 2020-01-02 北京达佳互联信息技术有限公司 图像标签识别方法、装置及服务器
CN110827253A (zh) * 2019-10-30 2020-02-21 北京达佳互联信息技术有限公司 一种目标检测模型的训练方法、装置及电子设备

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
基于标签相关性的卷积神经网络多标签分类算法;蒋俊钊;程良伦;李全杰;;工业控制计算机(第07期);全文 *

Also Published As

Publication number Publication date
CN111460150A (zh) 2020-07-28
EP3886004A1 (en) 2021-09-29
US20210304069A1 (en) 2021-09-30

Similar Documents

Publication Publication Date Title
CN111460150B (zh) 一种分类模型的训练方法、分类方法、装置及存储介质
RU2749970C1 (ru) Способ сжатия модели нейронной сети, а также способ и устройство для перевода языкового корпуса
US11556761B2 (en) Method and device for compressing a neural network model for machine translation and storage medium
CN111242303B (zh) 网络训练方法及装置、图像处理方法及装置
CN107564526B (zh) 处理方法、装置和机器可读介质
CN111753895A (zh) 数据处理方法、装置及存储介质
CN111831806B (zh) 语义完整性确定方法、装置、电子设备和存储介质
CN111210844B (zh) 语音情感识别模型的确定方法、装置、设备及存储介质
CN111160448A (zh) 一种图像分类模型的训练方法及装置
CN112148980B (zh) 基于用户点击的物品推荐方法、装置、设备和存储介质
CN111753091A (zh) 分类方法、分类模型的训练方法、装置、设备及存储介质
CN110764627B (zh) 一种输入方法、装置和电子设备
CN111160047A (zh) 一种数据处理方法、装置和用于数据处理的装置
CN112148923A (zh) 搜索结果的排序方法、排序模型的生成方法、装置及设备
CN111753917A (zh) 数据处理方法、装置及存储介质
CN111274389B (zh) 一种信息处理方法、装置、计算机设备及存储介质
CN112328809A (zh) 实体分类方法、装置及计算机可读存储介质
CN112035651A (zh) 语句补全方法、装置及计算机可读存储介质
CN113609380B (zh) 标签体系更新方法、搜索方法、装置以及电子设备
CN111400443B (zh) 信息处理方法、装置及存储介质
CN112579767B (zh) 搜索处理方法、装置和用于搜索处理的装置
CN114462410A (zh) 实体识别方法、装置、终端及存储介质
CN110858099B (zh) 候选词生成方法及装置
CN108345590B (zh) 一种翻译方法、装置、电子设备以及存储介质
CN112149653A (zh) 信息处理方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
CB02 Change of applicant information
CB02 Change of applicant information

Address after: 100085 unit C, building C, lin66, Zhufang Road, Qinghe, Haidian District, Beijing

Applicant after: Beijing Xiaomi pinecone Electronic Co.,Ltd.

Address before: 100085 unit C, building C, lin66, Zhufang Road, Qinghe, Haidian District, Beijing

Applicant before: BEIJING PINECONE ELECTRONICS Co.,Ltd.

SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant