CN111291823A - 分类模型的融合方法、装置、电子设备及存储介质 - Google Patents

分类模型的融合方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN111291823A
CN111291823A CN202010113301.0A CN202010113301A CN111291823A CN 111291823 A CN111291823 A CN 111291823A CN 202010113301 A CN202010113301 A CN 202010113301A CN 111291823 A CN111291823 A CN 111291823A
Authority
CN
China
Prior art keywords
classification
classification model
model
training
fusion
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010113301.0A
Other languages
English (en)
Other versions
CN111291823B (zh
Inventor
路泽
肖万鹏
鞠奇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202010113301.0A priority Critical patent/CN111291823B/zh
Publication of CN111291823A publication Critical patent/CN111291823A/zh
Application granted granted Critical
Publication of CN111291823B publication Critical patent/CN111291823B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/254Fusion techniques of classification results, e.g. of results related to same input data
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提供了一种分类模型的融合方法、装置、电子设备及存储介质;方法包括:通过第i个分类模型对第j个分类模型的训练样本进行第i类别的分类预测,得到对应第j个分类模型的训练样本的第i分类结果;以第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注;对j进行遍历,得到标注有对应第i类别的分类标签的训练样本所构成的第i数据集;对i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;基于n个数据集中至少之一训练融合分类模型;通过本发明,能够实现不同任务类别的分类模型的快速融合,提高融合分类模型的分类精度及性能。

Description

分类模型的融合方法、装置、电子设备及存储介质
技术领域
本发明涉及人工智能技术领域,尤其涉及一种分类模型的融合方法、装置、电子设备及存储介质。
背景技术
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术,人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。其中,机器学习(ML,Machine Learning)是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习和归纳学习等技术。
在人工神经网络学习技术的研究过程中,发现神经网络模型的融合会带来分类精度、性能等方面的巨大提升,相关技术中,典型的分类模型的融合方案可以分为两类,测试阶段融合和训练阶段融合。对于第一类测试阶段融合的方法,待测样本通常需要经过多个分类模型,且最终输出是在多个分类模型的结果上求加权平均或者利用投票机制得到,从而导致机器内存占用过高、推理耗时过长。
对于第二类训练阶段融合的方法,通常假设多个分类模型是针对同一分类任务进行训练的,即不同分类模型所对应的训练样本均标注有相同的分类标签。但是对于不同任务间的分类模型进行融合时,每个分类模型是由标注有不同分类标签的训练样本训练所得到,因此该分类模型的融合方法是不适用的。
发明内容
本发明实施例提供一种分类模型的融合方法、装置、电子设备及存储介质,能够实现不同任务类别的分类模型的快速融合,提高融合分类模型的分类精度及性能。
本发明实施例的技术方案是这样实现的:
本发明实施例提供一种分类模型的融合方法,包括:
获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本;其中,n为不小于2的正整数,所述n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数;
通过所述第i个分类模型,对第j个分类模型的训练样本进行所述第i类别的分类预测,得到对应所述第j个分类模型的训练样本的第i分类结果;其中,j为不大于n的正整数,且j不等于i;
以所述第i分类结果作为所述第j个分类模型的训练样本的第i类别的分类标签,对所述第j个分类模型的训练样本进行标注;
对所述j进行遍历,得到标注有对应所述第i类别的分类标签的训练样本所构成的第i数据集;
对所述i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;
基于所述n个数据集中至少之一,训练融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述n个类别的分类预测,并得到相应的分类结果。
本发明实施例还提供一种分类模型的融合装置,包括:
获取模块,用于获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本;其中,n为不小于2的正整数,所述n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数;
分类预测模块,用于通过所述第i个分类模型,对第j个分类模型的训练样本进行所述第i类别的分类预测,得到对应所述第j个分类模型的训练样本的第i分类结果;其中,j为不大于n的正整数,且j不等于i;
标注模块,用于以所述第i分类结果作为所述第j个分类模型的训练样本的第i类别的分类标签,对所述第j个分类模型的训练样本进行标注;
第一遍历模块,用于对所述j进行遍历,得到标注有对应所述第i类别的分类标签的训练样本所构成的第i数据集;
第二遍历模块,用于对所述i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;
第一模型训练模块,用于基于所述n个数据集中至少之一,训练融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述n个类别的分类预测,并得到相应的分类结果。
上述方案中,所述装置还包括:
第二模型训练模块,用于分别将用于训练各个所述分类模型的训练样本,输入至相应的分类模型进行分类预测,得到相应的预测结果;其中,用于训练第i个分类模型的训练样本标注有对应第i类别的初始分类标签;
基于得到的所述预测结果,及用于训练各个所述分类模型的训练样本的初始分类标签,确定各个所述分类模型的损失函数的值;
基于各个所述分类模型的损失函数的值,更新各个所述分类模型的模型参数。
上述方案中,所述第一模型训练模块,还用于通过所述融合分类模型,对所述n个数据集中至少之一的标注有对应n个类别的分类标签的训练样本,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果;
获取所述n个类别中各类别的分类结果与相应类别的分类标签之间的差异;
基于所述差异,更新所述融合分类模型的模型参数。
上述方案中,所述训练样本中标注的对应n个类别的分类标签包括初始分类标签及软标签,所述初始分类标签为训练所述分类模型时所标注,所述软标签为基于所述分类模型对所述训练样本进行分类预测所得到的分类结果所标注;
所述第一模型训练模块,还用于获取所述n个类别中对应所述初始分类标签的类别的分类结果与所述初始分类标签之间的第一差异;以及
获取所述n个类别中对应软标签的类别的分类结果与所述软标签之间的第二差异;
所述第一模型训练模块,还用于基于所述第一差异及所述第二差异,确定所述融合分类模型的损失函数的值;
基于所述融合分类模型的损失函数的值,更新所述融合分类模型的模型参数。
上述方案中,所述融合分类模型的损失函数包括交叉熵损失函数及蒸馏损失函数,所述第一模型训练模块,还用于基于所述第一差异,确定所述交叉熵损失函数的值;
基于所述第二差异,确定所述蒸馏损失函数的值;
获取所述交叉熵损失函数对应的第一权重,及所述蒸馏损失函数对应的第二权重;
结合所述第一权重和第二权重、所述交叉熵损失函数的值和蒸馏损失函数的值,确定所述融合分类模型的损失函数的值。
上述方案中,所述第一模型训练模块,还用于当所述融合分类模型的损失函数的值超出第一损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第一误差信号;
将所述第一误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
上述方案中,所述第一模型训练模块,还用于基于所述差异,确定所述融合分类模型的损失函数的值;
当所述融合分类模型的损失函数的值超出第二损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第二误差信号;
将所述第二误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
上述方案中,所述装置还包括:
分类模块,用于通过所述融合分类模型的特征提取层,对输入的待分类对象进行特征提取,得到待分类对象的对象特征;
基于所述待分类对象的对象特征,通过所述融合分类模型的多分类层,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果。
上述方案中,当n为2时,所述n个分类模型包括:第一分类模型和第二分类模型;其中,所述第一分类模型用于第一类别的分类预测,所述第二分类模型用于第二类别的分类预测;
所述n个数据集包括:由第一训练样本构成的第一训练样本集、及由第二训练样本构成的第二训练样本集;其中,所述第一训练样本标注有对应所述第一类别的初始分类标签和对应所述第二类别的软标签,所述第二训练样本标注有对应所述第二类别的初始分类标签和对应所述第一类别的软标签;
所述第一模型训练模块,还用于基于所述第一训练样本集、所述第二训练样本集中至少之一,训练所述融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述第一类别和所述第二类别的分类预测,得到相应的分类结果。
本发明实施例还提供一种电子设备,包括:
存储器,用于存储可执行指令;
处理器,用于执行所述存储器中存储的可执行指令时,实现本发明实施例提供的分类模型的融合方法。
本发明实施例还提供一种计算机可读存储介质,存储有可执行指令,所述可执行指令被处理器执行时,实现本发明实施例提供的分类模型的融合方法。
本发明实施例具有以下有益效果:
通过第i个分类模型,对第j个分类模型的训练样本进行第i类别的分类预测,得到对应第j个分类模型的训练样本的第i分类结果,以第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注,分别对i和j进行遍历,最终得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;如此,基于该n个数据集中至少之一,对融合分类模型进行训练,使得融合分类模型能够进行n个类别的分类预测,实现了不同任务类别的分类模型的融合;并且仅需要对融合分类模型进行测试即可,不需要经过多个分类模型,降低了时间消耗;因此实现了不同任务类别的分类模型的快速融合,提高了融合分类模型的分类精度及性能。
附图说明
图1A-B是相关技术中提供的分类模型的融合方法的示意图;
图2是本发明实施例提供的分类模型的融合系统的架构示意图;
图3是本发明实施例提供的电子设备的结构示意图;
图4是本发明实施例提供的分类模型的融合方法的流程示意图;
图5是本发明实施例提供的分类模型的融合方法的数据流走向示意图一;
图6是本发明实施例提供的分类模型的融合方法的流程示意图;
图7A是本发明实施例提供的用于训练分类模型的训练样本的标签示意图;
图7B是本发明实施例提供的训练样本的软标签标注的流程示意图;
图8是本发明实施例提供的分类模型的融合方法的数据流走向示意图二;
图9是本发明实施例提供的分类模型的融合方法的流程示意图;
图10是本发明实施例提供的分类模型的融合装置的结构示意图。
具体实施方式
为了使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步地详细描述,所描述的实施例不应视为对本发明的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
在以下的描述中,所涉及的术语“第一\第二\第三”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二\第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本发明实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本文所使用的所有的技术和科学术语与属于本发明的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本发明实施例的目的,不是旨在限制本发明。
对本发明实施例进行进一步详细说明之前,对本发明实施例中涉及的名词和术语进行说明,本发明实施例中涉及的名词和术语适用于如下的解释。
1)响应于,用于表示所执行的操作所依赖的条件或者状态,当满足所依赖的条件或状态时,所执行的一个或多个操作可以是实时的,也可以具有设定的延迟;在没有特别说明的情况下,所执行的多个操作不存在执行先后顺序的限制。
2)初始分类标签,为训练样本在用于训练各相应分类模型时所标注,也可称为硬标签;
3)软标签,为基于各分类模型对训练样本进行分类预测所得到的分类结果所标注。
相关技术中,典型的分类模型的融合方案可以分为两类,测试阶段融合和训练阶段融合。对于测试阶段融合方法,一是在融合分类模型的训练收敛过程中,使融合分类模型输出多个局部最优解,以多个不同阶段模型所输出的最优解的平均,作为融合分类模型最终的输出,具体地该方案的模型训练优化过程如图1A所示。二是设目标分类任务有R个类别,分别对回归模型、多二分类模型、自编码模型进行训练,采用投票的方式进行模型融合。而对于此类测试阶段融合的方法,待测样本通常需要经过多个分类模型,且最终输出是在多个分类模型的结果上求加权平均或者利用投票机制得到,从而导致机器内存占用过高、推理耗时过长。
对于训练阶段融合方法,多通过对抗学习的方法将多个模型蒸馏成一个模型,基于各区分网络块的训练损失值来引导优化融合分类模型学习各分类模型的知识信息,该分类模型的融合方法的流程如图1B所示。而对于此类训练阶段融合的方法,通常假设多个分类模型是针对同一分类任务进行训练的,即不同分类模型所对应的训练样本均标注有相同的分类标签。但是对于不同任务间的分类模型进行融合时,每个分类模型是由标注有不同分类标签的训练样本训练所得到,因此该分类模型的融合方法是不适用的。
基于此,本发明实施例提供了一种分类模型的融合方法、装置、系统、电子设备及存储介质,以至少解决相关技术中的上述问题,接下来分别进行说明。
基于上述对本发明实施例中涉及的名词和术语的解释,首先说明本发明实施例提供的分类模型的融合系统,参见图2,图2是本发明实施例提供的分类模型的融合的架构示意图,为实现支撑一个示例性应用,终端(包括终端200-1和终端200-2)通过网络300连接服务器100,网络300可以是广域网或者局域网,又或者是二者的组合,使用无线或有线链路实现数据传输。
服务器100,用于获取训练得到的n个分类模型、及用于训练各个分类模型的训练样本;通过第i个分类模型,对第j个分类模型的训练样本进行第i类别的分类预测,得到对应第j个分类模型的训练样本的第i分类结果;以第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注;对j进行遍历,得到标注有对应第i类别的分类标签的训练样本所构成的第i数据集;对i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;基于n个数据集中至少之一,训练融合分类模型;
终端(如终端200-1),用于响应于针对待分类对象的分类预测指令,向服务器发送待分类对象的分类预测请求;
服务器100,用于接收到针对待分类对象的分类预测请求,通过训练完成的融合分类模型,对待分类对象进行n个类别的分类预测,得到相应的分类结果,并返回给终端;
终端(如终端200-1),用于接收并呈现对应待分类对象的n个类别的分类结果。
在实际应用中,服务器100既可以为单独配置的支持各种业务的一个服务器,亦可以配置为一个服务器集群;终端(如终端200-1)可以为智能手机、平板电脑、笔记本电脑等各种类型的用户终端,还可以为可穿戴计算设备、个人数字助理(PDA)、台式计算机、蜂窝电话、媒体播放器、导航设备、游戏机、电视机、或者这些数据处理设备或其他数据处理设备中任意两个或多个的组合。
下面对本发明实施例提供的分类模型的融合方法的电子设备的硬件结构做详细说明,参见图3,图3是本发明实施例提供的电子设备的结构示意图,图3所示的电子设备300包括:至少一个处理器310、存储器350、至少一个网络接口320和用户接口330。电子设备300中的各个组件通过总线系统340耦合在一起。可理解,总线系统340用于实现这些组件之间的连接通信。总线系统340除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图3中将各种总线都标为总线系统340。
处理器310可以是一种集成电路芯片,具有信号的处理能力,例如通用处理器、数字信号处理器(DSP,Digital Signal Processor),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
用户接口330包括使得能够呈现媒体内容的一个或多个输出装置331,包括一个或多个扬声器和/或一个或多个视觉显示屏。用户接口330还包括一个或多个输入装置332,包括有助于用户输入的用户接口部件,比如键盘、鼠标、麦克风、触屏显示屏、摄像头、其他输入按钮和控件。
存储器350可以是可移除的,不可移除的或其组合。示例性的硬件设备包括固态存储器,硬盘驱动器,光盘驱动器等。存储器350可选地包括在物理位置上远离处理器310的一个或多个存储设备。
存储器350包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(ROM,Read Only Me mory),易失性存储器可以是随机存取存储器(RAM,Random Access Memor y)。本发明实施例描述的存储器350旨在包括任意适合类型的存储器。
在一些实施例中,存储器350能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统351,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务;
网络通信模块352,用于经由一个或多个(有线或无线)网络接口320到达其他计算设备,示例性的网络接口320包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(USB,Universal Serial Bus)等;
呈现模块353,用于经由一个或多个与用户接口330相关联的输出装置331(例如,显示屏、扬声器等)使得能够呈现信息(例如,用于操作外围设备和显示内容和信息的用户接口);
输入处理模块354,用于对一个或多个来自一个或多个输入装置332之一的一个或多个用户输入或互动进行检测以及翻译所检测的输入或互动。
在一些实施例中,本发明实施例提供的分类模型的融合装置可以采用软件方式实现,图3示出了存储在存储器350中的分类模型的融合装置355,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块3551、分类预测模块3552、标注模块3553、第一遍历模块3554、第二遍历模块3555和第一模型训练模块3556,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分,将在下文中说明各个模块的功能。
在另一些实施例中,本发明实施例提供的分类模型的融合装置可以采用软硬件结合的方式实现,作为示例,本发明实施例提供的分类模型的融合装置可以是采用硬件译码处理器形式的处理器,其被编程以执行本发明实施例提供的分类模型的融合方法,例如,硬件译码处理器形式的处理器可以采用一个或多个应用专用集成电路(ASIC,ApplicationSpecific Integrated Circuit)、DSP、可编程逻辑器件(PLD,Programmable LogicDevice)、复杂可编程逻辑器件(C PLD,Complex Programmable Logic Device)、现场可编程门阵列(FPGA,Fi eld-Programmable Gate Array)或其他电子元件。
基于上述对本发明实施例的分类模型的融合系统及电子设备的说明,下面说明本发明实施例提供的分类模型的融合方法。参见图4,图4是本发明实施例提供的分类模型的融合方法的流程示意图;在一些实施例中,该分类模型的融合方法可由服务器或终端单独实施,或由服务器及终端协同实施,以服务器实施为例,本发明实施例提供的分类模型的融合方法包括:
步骤401:服务器获取训练得到的n个分类模型、及用于训练的各个分类模型的训练样本。
这里,n为不小于2的正整数,n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数。
在实际应用中,首先需要构建n个分类模型,每个分类模型分别用于不同类别的分类预测,比如可以是用于对待分类图像所包含的内容进行分类预测的图像分类模型、或者对待分类图像所呈现的颜色进行分类预测的图像分类模型等。然后获取用于训练各分类模型的训练样本,该每个训练样本均标注有相应类别的分类标签。将标注有分类标签的训练样本输入到相应的分类模型中,对各分类模型进行训练,以得到训练完成的n个分类模型。
在一些实施例中,服务器可通过如下方式训练上述分类模型:分别将用于训练各个分类模型的训练样本,输入至相应的分类模型进行分类预测,得到相应的预测结果;其中,用于训练第i个分类模型的训练样本标注有对应第i类别的初始分类标签;基于得到的预测结果,及用于训练各个分类模型的训练样本的初始分类标签,确定各个分类模型的损失函数的值;基于各个分类模型的损失函数的值,更新各个分类模型的模型参数。
基于此,服务器得到训练完成的n个分类模型,及用于训练各分类模型的训练样本。
步骤402:通过第i个分类模型,对第j个分类模型的训练样本进行第i类别的分类预测,得到对应第j个分类模型的训练样本的第i分类结果。
这里,j为不大于n的正整数,且j不等于i。
将第j个分类模型的训练样本输入到第i个分类模型中,对第j个分类模型的训练样本进行第i类别的分类预测,从而得到对应第j个分类模型的训练样本的第i分类结果。
示例地,第i个分类模型为用于对待分类图像所包含的内容(风景、动物萌宠、人物等)进行分类预测的图像分类模型,第j个分类模型为用于对待分类图像所展示的风格(中国风、文艺风、幽默风、西方艺术等)进行分类预测的图像分类模型。此时,第j个分类模型的训练样本可以标注有中国风、文艺风、幽默风、西方艺术等分类标签,将第j个分类模型的训练样本输入到第i个分类模型中,得到对应第j个分类模型的训练样本的第i分类结果,包括风景、动物萌宠、人物等分类结果。
步骤403:以第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注。
继续以第i个分类模型为用于对待分类图像所包含的内容(风景、动物萌宠、人物等)进行分类预测的图像分类模型,第j个分类模型为用于对待分类图像所展示的风格(中国风、文艺风、幽默风、西方艺术等)进行分类预测的图像分类模型为例,这里,第j个分类模型的训练样本可以标注有中国风、文艺风、幽默风、西方艺术等分类标签。
将第j个分类模型的训练样本输入到第i个分类模型中,得到对应第j个分类模型的训练样本的第i分类结果(风景、动物萌宠、人物等),将第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注,从而得到标注有“风景、动物萌宠、人物等”分类标签的第j个分类模型的训练样本。
步骤404:对j进行遍历,得到标注有对应第i类别的分类标签的训练样本所构成的第i数据集。
步骤405:对i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集。
这里,n个数据集中包含n个分类模型的训练样本,其中每个分类模型的训练样本均标注有对应n个类别的分类标签。
基于此,可实现对缺失相应类别的标签的训练样本进行自动化标注,极大地降低了标注工作中的人力消耗。
步骤406:基于n个数据集中至少之一,训练融合分类模型。
这里,融合分类模型能够基于输入的待分类对象,进行n个类别的分类预测,并得到相应的分类结果。
在得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集后,基于该n个数据集中的至少一个数据集,对融合分类模型进行训练,以得到能够对待分类对象进行上述n个类别的分类预测的融合分类模型。
在一些实施例中,服务器可通过如下方式训练融合分类模型:通过融合分类模型,对n个数据集中至少之一的标注有对应n个类别的分类标签的训练样本,进行n个类别的分类预测,得到对应n个类别的分类结果;获取n个类别中各类别的分类结果与相应类别的分类标签之间的差异;基于差异,更新融合分类模型的模型参数。
在实际应用中,当训练融合分类模型时,可以采用上述n个数据集中的任意一个或者几个数据集作为训练样本,也可以采用该n个数据集中的所有数据集作为训练样本,从而对融合分类模型进行训练。
具体地,将上述n个数据集中至少之一的数据集,输入融合分类模型中,通过该融合分类模型,对标注有n个类别的分类标签的训练样本进行该n个类别的分类预测,以得到对应n个类别的分类结果。进一步地,获取预测得到的n个类别的分类结果与分类标签之间存在的差异,进而基于获取的差异,在融合分类模型的训练过程中,更新融合分类模型的模型参数。
在一些实施例中,上述训练样本中标注的对应n个类别的分类标签包括初始分类标签及软标签。这里,初始分类标签为训练分类模型时所标注,软标签为基于分类模型对训练样本进行分类预测所得到的分类结果所标注。
因此,当获取各类别的分类结果与相应类别的分类标签之间的差异时,具体可以是:获取n个类别中对应初始分类标签的类别的分类结果与初始分类标签之间的第一差异;以及获取n个类别中对应软标签的类别的分类结果与软标签之间的第二差异。
基于此,在一些实施例中,基于差异,服务器可通过如下方式更新融合分类模型的模型参数:基于第一差异及第二差异,确定融合分类模型的损失函数的值;基于融合分类模型的损失函数的值,更新融合分类模型的模型参数。
在一些实施例中,上述融合分类模型的损失函数可包括交叉熵损失函数和蒸馏损失函数。基于此,服务器可通过如下方式确定融合分类模型的损失函数的值:基于第一差异,确定交叉熵损失函数的值;基于第二差异,确定蒸馏损失函数的值;获取交叉熵损失函数对应的第一权重,及蒸馏损失函数对应的第二权重;结合第一权重和第二权重、交叉熵损失函数的值和蒸馏损失函数的值,确定融合分类模型的损失函数的值。
在实际应用中,针对交叉熵损失函数和蒸馏损失函数,分别设置了对应的权重值。具体地,通过第一差异确定交叉熵损失函数的值,并通过第二差异确定蒸馏损失函数的值。分别获取交叉熵损失函数对应的第一权重、及蒸馏损失函数对应的第二权重,进而结合第一权重和第二权重,基于交叉熵损失函数的值和蒸馏损失函数的值,确定融合分类函数的损失函数的值。
确定融合分类函数的损失函数的值之后,在一些实施例中,服务器可通过如下方式,基于融合分类模型的损失函数的值,更新融合分类模型的模型参数:当融合分类模型的损失函数的值超出第一损失阈值时,基于融合分类模型的损失函数确定融合分类模型的第一误差信号;将第一误差信号在融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
具体地,可对比融合分类模型的损失函数的值与预设的第一损失阈值,当融合分类模型的损失函数的值超过第一损失阈值时,确定融合分类模型的第一误差信号,从而基于第一误差信号在融合分类模型中反向传播的过程中,更新融合分类模型各个层的模型参数。
在一些实施例中,融合分类模型的损失函数可仅包含一种损失函数,基于此,服务器还可通过如下方式更新融合分类模型的模型参数:基于差异,确定融合分类模型的损失函数的值;当融合分类模型的损失函数的值超出第二损失阈值时,基于融合分类模型的损失函数确定融合分类模型的第二误差信号;将第二误差信号在融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
在一些实施例中,当n为2时,n个分类模型包括:第一分类模型和第二分类模型;其中,第一分类模型用于第一类别的分类预测,第二分类模型用于第二类别的分类预测;
n个数据集包括:由第一训练样本构成的第一训练样本集、及由第二训练样本构成的第二训练样本集;其中,第一训练样本标注有对应第一类别的初始分类标签和对应第二类别的软标签,第二训练样本标注有对应第二类别的初始分类标签和对应第一类别的软标签;这里,每个分类模型的训练样本集由多个不同的训练样本所构成,不同分类模型的训练样本集的数量不一定是相同的。
基于此,服务器可通过如下方式训练融合分类模型:基于第一训练样本集、第二训练样本集中至少之一,训练融合分类模型,使得融合分类模型能够基于输入的待分类对象,进行第一类别和第二类别的分类预测,得到相应的分类结果。
在一些实施例中,服务器可通过以下方式对待分类对象进行分类预测:通过融合分类模型的特征提取层,对输入的待分类对象进行特征提取,得到待分类对象的对象特征;基于待分类对象的对象特征,通过融合分类模型的多分类层,进行n个类别的分类预测,得到对应n个类别的分类结果。
在得到训练完成的融合分类模型后,可通过该融合分类模型对待分类对象进行分类预测。在实际应用中,融合分类模型可包括特征提取层和多分类层。具体地,当通过融合分类模型进行分类预测时,首先通过特征提取层对待分类对象进行特征提取,得到待分类对象的对象特征;然后通过多分类层对待分类对象的对象特征进行n个类别的分类预测,从而实现对待分类对象的分类,得到待分类对象对应的n个类别的分类结果。
应用本发明上述实施例,通过第i个分类模型,对第j个分类模型的训练样本进行第i类别的分类预测,得到对应第j个分类模型的训练样本的第i分类结果,以第i分类结果作为第j个分类模型的训练样本的第i类别的分类标签,对第j个分类模型的训练样本进行标注,分别对i和j进行遍历,最终得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;如此,基于该n个数据集中至少之一,对融合分类模型进行训练,使得融合分类模型能够进行n个类别的分类预测,实现了不同任务类别的分类模型的融合;并且仅需要对融合分类模型进行测试即可,不需要经过多个分类模型,降低了时间消耗;因此实现了不同任务类别的分类模型的快速融合,提高了融合分类模型的分类精度及性能。
下面将说明本发明实施例在实际应用场景中的示例性应用。以两个分类模型的融合(n=2)为例,继续对本发明实施例提供的分类模型的融合方法进行说明。其中,该两个分类模型包括第一分类模型和第二分类模型。参见图5和图6,图5是本发明实施例提供的分类模型的融合方法的数据流走向示意图一,图6是本发明实施例提供的分类模型的融合方法的流程示意图,包括:
步骤601:服务器基于第一训练样本训练第一分类模型,基于第二训练样本训练第二分类模型。
这里,第一分类模型和第二分类模型分别对应不同类别的分类任务,第一训练样本标注有对应第一类别的分类标签(即硬标签),第二训练样本标注有对应第二类别的分类标签。
示例性地,该两个分类模型可以为对应有不同类别分类任务的文本分类模型。比如,第一分类模型用于对待分类文本进行所属领域的分类预测,此时第一训练样本所标注的第一类别的分类标签可以为医学领域、文学领域、经济学领域等;第二分类模型用于对待分类文本进行所属来源的分类预测,此时第二训练样本所标注的第二类别的分类标签可以为出版物、网络刊物、电子文库、报刊等。参见图7A,图7A是本发明实施例提供的用于训练分类模型的训练样本的标签示意图,这里,第一训练样本仅标注有硬标签a,第二训练样本仅标注有硬标签b,其中硬标签a和硬标签b均为初始分类标签。
步骤602:获取训练完成的第一分类模型、第二分类模型、以及用于训练的第一训练样本和第二训练样本。
步骤603:通过第一分类模型对第二训练样本进行第一类别的分类预测,得到对应第二训练样本的第一分类结果。
步骤604:通过第二分类模型对第一训练样本进行第二类别的分类预测,得到对应第一训练样本的第二分类结果。
这里,在步骤603-604中,通过第一分类模型对第二训练样本进行分类预测,得到对应第二训练样本的第一分类结果,通过第二分类模型对第一训练样本进行分类预测,得到对应第一训练样本的第二分类结果。
步骤605:以第一分类结果作为第二训练样本的第一类别的分类标签,对第二训练样本进行标注;以第二分类结果作为第一训练样本的第二类别的分类标签,对第一训练样本进行标注。
这里,对应第二训练样本的第一类别的分类标签即为第二训练样本的软标签,对应第一训练样本的第二类别的分类标签即为第一训练样本的软标签。
参见图7B,图7B是本发明实施例提供的训练样本的软标签标注的流程示意图,这里,将标注有硬标签a(第一类别的分类标签)的第一训练样本输入第二分类模型中,得到对应第一训练样本的第二分类结果;将标注有硬标签b(第二类别的分类标签)的第二训练样本输入第一分类模型中,得到对应第二训练样本的第一分类结果。将第二分类结果作为第一训练样本的软标签β,将第一分类结果作为第二训练样本的软标签α。
采用软标签β对第一训练样本进行标注,得到分别标注有硬标签a和软标签β的第一训练样本;采用软标签α对第二训练样本进行标注,得到分别标注有硬标签b和软标签α的第二训练样本。
继续以该两个分类模型为文本分类模型为例,即,将标注有领域硬标签(医学领域、文学领域、经济学领域等)的第一训练样本,输入到用于对待分类文本进行所属来源的分类预测的第二分类模型中,得到对应第一训练样本的来源软标签(出版物、网络刊物、电子文库、报刊等),进一步采用对应第一训练样本的来源软标签对第一训练样本进行标注,从而得到标注有领域硬标签和来源软标签的第一训练样本。相同地,基于同样的方式,得到标注有来源硬标签和领域软标签的第二训练样本。
步骤606:将标注有第一类别和第二类别的分类标签的第一训练样本、以及标注有第一类别和第二类别的分类标签的第二训练样本,输入融合分类模型,得到对应第一类别和第二类别的分类结果。
这里,融合分类模型能够基于输入的待分类对象,进行第一类别和第二类别的分类预测,并得到相应的分类结果。
继续以该两个分类模型为文本分类模型为例,第一分类模型用于对待分类文本进行所属领域的分类预测,第二分类模型用于对待分类文本进行所属来源的分类预测,则融合分类模型可用于对待分类文本进行所属领域和所属来源的分类预测。当对该融合分类模型进行训练时,则将标注有领域硬标签和来源软标签的第一训练样本、及标注有来源硬标签和领域软标签的第二训练样本输入融合分类模型中,以实现对融合分类模型的训练。
步骤607:获取第一类别和第二类别的分类结果与相应硬标签之间的第一差异,第一类别和第二类别的分类结果与相应软标签之间的第二差异。
步骤608:基于第一差异,确定融合分类模型的交叉熵损失函数的值。
步骤609:基于第二差异,确定融合分类模型的蒸馏损失函数的值。
步骤610:获取交叉熵损失函数对应的第一权重、及蒸馏损失函数对应的第二权重。
这里,第一权重和第二权重可根据经验自定义。
步骤611:结合第一权重和第二权重、交叉熵损失函数的值和蒸馏损失函数的值,确定融合分类模型的损失函数的值。
步骤612:基于融合分类模型的损失函数的值,更新融合分类模型各个层的模型参数,以实现对融合分类模型的训练。
步骤613:终端响应于针对待分类对象的分类预测指令,向服务器发送待分类对象的分类预测请求。
步骤614:服务器接收到针对待分类对象的分类预测请求,通过训练完成的融合分类模型对待分类对象进行第一类别和第二类别的分类预测,得到分类结果,并返回终端。
继续以该两个分类模型为文本分类模型为例,第一分类模型用于对待分类文本进行所属领域的分类预测,第二分类模型用于对待分类文本进行所属来源的分类预测,则融合分类模型可用于对待分类文本进行所属领域和所属来源的分类预测。
当基于融合分类模型对待分类文本进行分类预测时,通过融合分类模型的特征提取层,对待分类文本进行特征提取,比如one-hot编码、预先训练完成的TextCNN模型、word2vec词向量映射等,得到待分类文本的文本特征;再通过融合分类模型的多分类层,对待分类文本的文本特征进行分类预测,得到对应待分类文本的分类结果,即待分类文本所属的领域和来源。
步骤615:终端接收并呈现对应待分类对象的第一类别和第二类别的分类结果。
接下来以三个分类模型的融合(n=3)为例,继续对本发明实施例提供的分类模型的融合方法进行说明。其中,该三个分类模型包括第一分类模型、第二分类模型和第三分类模型。参见图8和图9,图8是本发明实施例提供的分类模型的融合方法的数据流走向示意图二,图9是本发明实施例提供的分类模型的融合方法的流程示意图,包括:
步骤901:服务器基于第一训练样本训练第一分类模型,基于第二训练样本训练第二分类模型,基于第三训练样本训练第三分类模型。
这里,第一分类模型、第二分类模型和第三分类模型分别对应不同类别的分类任务,第一训练样本标注有对应第一类别的分类标签(即硬标签),第二训练样本标注有对应第二类别的分类标签,第三训练样本标注有第三类别的分类标签。
示例性地,该三个分类模型可以为对应有不同类别分类任务的图像分类模型。比如,第一分类模型用于对待分类图像所包含的内容进行分类预测,此时第一训练样本所标注的第一类别的分类标签可以为风景、动物萌宠、人物等;第二分类模型用于对待分类图像所呈现的颜色进行分类预测,此时第二训练样本所标注的第二类别的分类标签可以为绿色、蓝色、红色、白色等;第三分类模型用于对待分类图像所展示的风格进行分类预测,此时第三训练样本所标注的第三类别的分类标签可以为中国风、文艺风、幽默风、西方艺术等。
步骤902:获取训练完成的第一分类模型、第二分类模型、第三分类模型、以及用于训练的第一训练样本、第二训练样本和第三训练样本。
步骤903:通过第一分类模型分别对第二训练样本和第三训练样本进行第一类别的分类预测,得到对应第二训练样本和第三训练样本的第一分类结果。
步骤904:通过第二分类模型分别对第一训练样本和第三训练样本进行第二类别的分类预测,得到对应第一训练样本和第三训练样本的第二分类结果。
步骤905:通过第三分类模型分别对第一训练样本和第二训练样本进行第三类别的分类预测,得到对应第一训练样本和第二训练样本的第三分类结果。
步骤906:以第一分类结果作为第二训练样本和第三训练样本的第一类别的分类标签,对第二训练样本和第三训练样本进行标注;以第二分类结果作为第一训练样本和第三训练样本的第二类别的分类标签,对第一训练样本和第三训练样本进行标注;以第三分类结果作为第一训练样本和第二训练样本的第三类别的分类标签,对第一训练样本和第二训练样本进行标注。
这里,对应第二训练样本和第三训练样本的第一类别的分类标签即为第二训练样本和第三训练样本的软标签,对应第一训练样本和第三训练样本的第二类别的分类标签即为第一训练样本和第三训练样本的软标签;对应第一训练样本和第二训练样本的第三类别的分类标签即为第一训练样本和第二训练样本的软标签。
继续以该三个分类模型为图像分类模型为例,即,将标注有内容硬标签(风景、动物萌宠、人物等)的第一训练样本,输入到用于对待分类图像所呈现的颜色进行分类预测的第二分类模型中,得到对应第一训练样本的颜色软标签(绿色、蓝色、红色、白色等);并将标注有内容硬标签(风景、动物萌宠、人物等)的第一训练样本,输入到用于对待分类图像所展示的风格进行分类预测的第三分类模型,得到对应第一训练样本的风格软标签(中国风、文艺风、幽默风、西方艺术等)。
进一步地,采用对应第一训练样本的颜色软标签、及对应第一训练样本的风格软标签,对第一训练样本进行标注,从而得到标注有内容硬标签、颜色软标签和风格软标签的第一训练样本。
相同地,基于同样的方式,分别得到标注有内容软标签、颜色硬标签和风格软标签的第二训练样本,以及分别标注有内容软标签、颜色软标签和风格硬标签的第三训练样本。
步骤907:将分别标注有第一类别、第二类别和第三类别的分类标签的第一训练样本、第二训练样本和第三训练样本输入融合分类模型,得到对应第一类别、第二类别和第三类别的分类结果。
这里,融合分类模型能够基于输入的待分类对象,进行第一类别、第二类别和第三类别的分类预测,并得到相应的分类结果。
在实际应用中,也可以将分别标注有第一类别、第二类别和第三类别的分类标签的第一训练样本、第二训练样本和第三训练样本中的至少之一输入融合分类模型。
步骤908:获取第一类别、第二类别和第三类别的分类结果与相应硬标签之间的第一差异,第一类别、第二类别和第三类别的分类结果与相应软标签之间的第二差异。
步骤909:基于第一差异,确定融合分类模型的交叉熵损失函数的值。
步骤910:基于第二差异,确定融合分类模型的蒸馏损失函数的值。
步骤911:获取交叉熵损失函数对应的第一权重、及蒸馏损失函数对应的第二权重。
这里,第一权重和第二权重可根据经验自定义。
步骤912:结合第一权重和第二权重、交叉熵损失函数的值和蒸馏损失函数的值,确定融合分类模型的损失函数的值。
步骤913:基于融合分类模型的损失函数的值,更新融合分类模型各个层的模型参数,以实现对融合分类模型的训练。
步骤914:终端接收到针对待分类对象的分类预测指令,向服务器发送待分类对象的分类预测请求。
步骤915:服务器接收到针对待分类对象的分类预测请求,通过训练完成的融合分类模型对待分类对象进行第一类别、第二类别和第三类别的分类预测,得到分类结果,并返回终端。
继续以该三个分类模型为图像分类模型为例,第一分类模型用于对待分类图像所包含的内容进行分类预测,第二分类模型用于对待分类图像所呈现的颜色进行分类预测,第三分类模型用于对待分类图像所展示的风格进行分类预测,则融合分类模型可用于对待分类图像进行所包含内容、所呈现颜色和所展示风格的分类预测。
通过融合分类模型的特征提取层,对待分类图像进行特征提取,得到待分类图像的图像特征;再通过融合分类模型的多分类层,对待分类图像的图像特征进行分类预测,得到对应待分类图像的分类结果,即待分类图像所包含内容、所呈现颜色和所展示风格。
步骤916:终端呈现对应待分类对象的第一类别、第二类别和第三类别的分类结果。
下面继续说明本发明实施例提供的分类模型的融合装置355,在一些实施例中,分类模型的融合装置可采用软件模块的方式实现。参见图10,图10是本发明实施例提供的分类模型的融合装置355的结构示意图,本发明实施例提供的分类模型的融合装置355包括:
获取模块3551,用于获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本;其中,n为不小于2的正整数,所述n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数;
分类预测模块3552,用于通过所述第i个分类模型,对第j个分类模型的训练样本进行所述第i类别的分类预测,得到对应所述第j个分类模型的训练样本的第i分类结果;其中,j为不大于n的正整数,且j不等于i;
标注模块3553,用于以所述第i分类结果作为所述第j个分类模型的训练样本的第i类别的分类标签,对所述第j个分类模型的训练样本进行标注;
第一遍历模块3554,用于对所述j进行遍历,得到标注有对应所述第i类别的分类标签的训练样本所构成的第i数据集;
第二遍历模块3555,用于对所述i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;
第一模型训练模块3556,用于基于所述n个数据集中至少之一,训练融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述n个类别的分类预测,并得到相应的分类结果。
在一些实施例中,所述装置还包括:
第二模型训练模块,用于分别将用于训练各个所述分类模型的训练样本,输入至相应的分类模型进行分类预测,得到相应的预测结果;其中,用于训练第i个分类模型的训练样本标注有对应第i类别的初始分类标签;
基于得到的所述预测结果,及用于训练各个所述分类模型的训练样本的初始分类标签,确定各个所述分类模型的损失函数的值;
基于各个所述分类模型的损失函数的值,更新各个所述分类模型的模型参数。
在一些实施例中,所述第一模型训练模块3556,还用于通过所述融合分类模型,对所述n个数据集中至少之一的标注有对应n个类别的分类标签的训练样本,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果;
获取所述n个类别中各类别的分类结果与相应类别的分类标签之间的差异;
基于所述差异,更新所述融合分类模型的模型参数。
在一些实施例中,所述训练样本中标注的对应n个类别的分类标签包括初始分类标签及软标签,所述初始分类标签为训练所述分类模型时所标注,所述软标签为基于所述分类模型对所述训练样本进行分类预测所得到的分类结果所标注;
所述第一模型训练模块3556,还用于获取所述n个类别中对应所述初始分类标签的类别的分类结果与所述初始分类标签之间的第一差异;以及
获取所述n个类别中对应软标签的类别的分类结果与所述软标签之间的第二差异;
所述第一模型训练模块3556,还用于基于所述第一差异及所述第二差异,确定所述融合分类模型的损失函数的值;
基于所述融合分类模型的损失函数的值,更新所述融合分类模型的模型参数。
在一些实施例中,所述融合分类模型的损失函数包括交叉熵损失函数及蒸馏损失函数,所述第一模型训练模块3556,还用于基于所述第一差异,确定所述交叉熵损失函数的值;
基于所述第二差异,确定所述蒸馏损失函数的值;
获取所述交叉熵损失函数对应的第一权重,及所述蒸馏损失函数对应的第二权重;
结合所述第一权重和第二权重、所述交叉熵损失函数的值和蒸馏损失函数的值,确定所述融合分类模型的损失函数的值。
在一些实施例中,所述第一模型训练模块3556,还用于当所述融合分类模型的损失函数的值超出第一损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第一误差信号;
将所述第一误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
在一些实施例中,所述第一模型训练模块3556,还用于基于所述差异,确定所述融合分类模型的损失函数的值;
当所述融合分类模型的损失函数的值超出第二损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第二误差信号;
将所述第二误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
在一些实施例中,所述装置还包括:
分类模块,用于通过所述融合分类模型的特征提取层,对输入的待分类对象进行特征提取,得到待分类对象的对象特征;
基于所述待分类对象的对象特征,通过所述融合分类模型的多分类层,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果。
在一些实施例中,当n为2时,所述n个分类模型包括:第一分类模型和第二分类模型;其中,所述第一分类模型用于第一类别的分类预测,所述第二分类模型用于第二类别的分类预测;
所述n个数据集包括:由第一训练样本构成的第一训练样本集、及由第二训练样本构成的第二训练样本集;其中,所述第一训练样本标注有对应所述第一类别的初始分类标签和对应所述第二类别的软标签,所述第二训练样本标注有对应所述第二类别的初始分类标签和对应所述第一类别的软标签;
所述第一模型训练模块3556,还用于基于所述第一训练样本集、所述第二训练样本集中至少之一,训练所述融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述第一类别和所述第二类别的分类预测,得到相应的分类结果。
本发明实施例还提供一种电子设备,所述电子设备包括:
存储器,用于存储可执行指令;
处理器,用于执行所述存储器中存储的可执行指令时,实现本发明实施例提供的分类模型的融合方法。
本发明实施例还提供一种计算机可读存储介质,存储有可执行指令,所述可执行指令被处理器执行时,实现本发明实施例提供的分类模型的融合方法。
在一些实施例中,存储介质可以是FRAM、ROM、PROM、EPROM、EE PROM、闪存、磁表面存储器、光盘、或CD-ROM等存储器;也可以是包括上述存储器之一或任意组合的各种设备。计算机可以是包括智能终端和服务器在内的各种计算设备。
在一些实施例中,可执行指令可以采用程序、软件、软件模块、脚本或代码的形式,按任意形式的编程语言(包括编译或解释语言,或者声明性或过程性语言)来编写,并且其可按任意形式部署,包括被部署为独立的程序或者被部署为模块、组件、子例程或者适合在计算环境中使用的其它单元。
作为示例,可执行指令可以但不一定对应于文件系统中的文件,可以可被存储在保存其它程序或数据的文件的一部分,例如,存储在超文本标记语言(H TML,Hyper TextMarkup Language)文档中的一个或多个脚本中,存储在专用于所讨论的程序的单个文件中,或者,存储在多个协同文件(例如,存储一个或多个模块、子程序或代码部分的文件)中。
作为示例,可执行指令可被部署为在一个计算设备上执行,或者在位于一个地点的多个计算设备上执行,又或者,在分布在多个地点且通过通信网络互连的多个计算设备上执行。
以上所述,仅为本发明的实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本发明的保护范围之内。

Claims (12)

1.一种分类模型的融合方法,其特征在于,所述方法包括:
获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本;其中,n为不小于2的正整数,所述n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数;
通过所述第i个分类模型,对第j个分类模型的训练样本进行所述第i类别的分类预测,得到对应所述第j个分类模型的训练样本的第i分类结果;其中,j为不大于n的正整数,且j不等于i;
以所述第i分类结果作为所述第j个分类模型的训练样本的第i类别的分类标签,对所述第j个分类模型的训练样本进行标注;
对所述j进行遍历,得到标注有对应所述第i类别的分类标签的训练样本所构成的第i数据集;
对所述i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;
基于所述n个数据集中至少之一,训练融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述n个类别的分类预测,并得到相应的分类结果。
2.如权利要求1所述的方法,其特征在于,所述获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本之前,所述方法还包括:
分别将用于训练各个所述分类模型的训练样本,输入至相应的分类模型进行分类预测,得到相应的预测结果;其中,用于训练第i个分类模型的训练样本标注有对应第i类别的初始分类标签;
基于得到的所述预测结果,及用于训练各个所述分类模型的训练样本的初始分类标签,确定各个所述分类模型的损失函数的值;
基于各个所述分类模型的损失函数的值,更新各个所述分类模型的模型参数。
3.如权利要求1所述的方法,其特征在于,所述基于所述n个数据集中至少之一,训练融合分类模型,包括:
通过所述融合分类模型,对所述n个数据集中至少之一的标注有对应n个类别的分类标签的训练样本,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果;
获取所述n个类别中各类别的分类结果与相应类别的分类标签之间的差异;
基于所述差异,更新所述融合分类模型的模型参数。
4.如权利要求3所述的方法,其特征在于,
所述训练样本中标注的对应n个类别的分类标签包括初始分类标签及软标签,所述初始分类标签为训练所述分类模型时所标注,所述软标签为基于所述分类模型对所述训练样本进行分类预测所得到的分类结果所标注;
所述获取所述n个类别中各类别的分类结果与相应类别的分类标签之间的差异,包括:
获取所述n个类别中对应所述初始分类标签的类别的分类结果与所述初始分类标签之间的第一差异;以及
获取所述n个类别中对应软标签的类别的分类结果与所述软标签之间的第二差异;
所述基于所述差异,更新所述融合分类模型的模型参数,包括:
基于所述第一差异及所述第二差异,确定所述融合分类模型的损失函数的值;
基于所述融合分类模型的损失函数的值,更新所述融合分类模型的模型参数。
5.如权利要求4所述的方法,其特征在于,所述融合分类模型的损失函数包括交叉熵损失函数及蒸馏损失函数,所述基于所述第一差异及所述第二差异,确定所述融合分类模型的损失函数的值,包括:
基于所述第一差异,确定所述交叉熵损失函数的值;
基于所述第二差异,确定所述蒸馏损失函数的值;
获取所述交叉熵损失函数对应的第一权重,及所述蒸馏损失函数对应的第二权重;
结合所述第一权重和第二权重、所述交叉熵损失函数的值和蒸馏损失函数的值,确定所述融合分类模型的损失函数的值。
6.如权利要求4所述的方法,其特征在于,所述基于所述融合分类模型的损失函数的值,更新所述融合分类模型的模型参数,包括:
当所述融合分类模型的损失函数的值超出第一损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第一误差信号;
将所述第一误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
7.如权利要求3所述的方法,其特征在于,所述基于所述差异,更新所述融合分类模型的模型参数,包括:
基于所述差异,确定所述融合分类模型的损失函数的值;
当所述融合分类模型的损失函数的值超出第二损失阈值时,基于所述融合分类模型的损失函数确定所述融合分类模型的第二误差信号;
将所述第二误差信号在所述融合分类模型中反向传播,并在传播的过程中更新各个层的模型参数。
8.如权利要求1所述的方法,其特征在于,所述方法还包括:
通过所述融合分类模型的特征提取层,对输入的待分类对象进行特征提取,得到待分类对象的对象特征;
基于所述待分类对象的对象特征,通过所述融合分类模型的多分类层,进行所述n个类别的分类预测,得到对应所述n个类别的分类结果。
9.如权利要求1所述的方法,其特征在于,
当n为2时,所述n个分类模型包括:第一分类模型和第二分类模型;其中,所述第一分类模型用于第一类别的分类预测,所述第二分类模型用于第二类别的分类预测;
所述n个数据集包括:由第一训练样本构成的第一训练样本集、及由第二训练样本构成的第二训练样本集;其中,所述第一训练样本标注有对应所述第一类别的初始分类标签和对应所述第二类别的软标签,所述第二训练样本标注有对应所述第二类别的初始分类标签和对应所述第一类别的软标签;
所述基于所述n个数据集中至少之一,训练融合分类模型,包括:
基于所述第一训练样本集、所述第二训练样本集中至少之一,训练所述融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述第一类别和所述第二类别的分类预测,得到相应的分类结果。
10.一种分类模型的融合装置,其特征在于,所述装置包括:
获取模块,用于获取训练得到的n个分类模型、及用于训练各个所述分类模型的训练样本;其中,n为不小于2的正整数,所述n个分类模型中第i个分类模型用于进行第i类别的分类预测,i为不大于n的正整数;
分类预测模块,用于通过所述第i个分类模型,对第j个分类模型的训练样本进行所述第i类别的分类预测,得到对应所述第j个分类模型的训练样本的第i分类结果;其中,j为不大于n的正整数,且j不等于i;
标注模块,用于以所述第i分类结果作为所述第j个分类模型的训练样本的第i类别的分类标签,对所述第j个分类模型的训练样本进行标注;
第一遍历模块,用于对所述j进行遍历,得到标注有对应所述第i类别的分类标签的训练样本所构成的第i数据集;
第二遍历模块,用于对所述i进行遍历,得到标注有对应n个类别的分类标签的训练样本所构成的n个数据集;
第一模型训练模块,用于基于所述n个数据集中至少之一,训练融合分类模型,使得所述融合分类模型能够基于输入的待分类对象,进行所述n个类别的分类预测,并得到相应的分类结果。
11.一种电子设备,其特征在于,所述电子设备包括:
存储器,用于存储可执行指令;
处理器,用于执行所述存储器中存储的可执行指令时,实现如权利要求1至9任一项所述的分类模型的融合方法。
12.一种计算机可读存储介质,其特征在于,存储有可执行指令,所述可执行指令被执行时,用于实现如权利要求1至9任一项所述的分类模型的融合方法。
CN202010113301.0A 2020-02-24 2020-02-24 分类模型的融合方法、装置、电子设备及存储介质 Active CN111291823B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010113301.0A CN111291823B (zh) 2020-02-24 2020-02-24 分类模型的融合方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010113301.0A CN111291823B (zh) 2020-02-24 2020-02-24 分类模型的融合方法、装置、电子设备及存储介质

Publications (2)

Publication Number Publication Date
CN111291823A true CN111291823A (zh) 2020-06-16
CN111291823B CN111291823B (zh) 2023-08-18

Family

ID=71031051

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010113301.0A Active CN111291823B (zh) 2020-02-24 2020-02-24 分类模型的融合方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN111291823B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112329824A (zh) * 2020-10-23 2021-02-05 北京中科智加科技有限公司 多模型融合训练方法、文本分类方法以及装置
CN112528109A (zh) * 2020-12-01 2021-03-19 中科讯飞互联(北京)信息科技有限公司 一种数据分类方法、装置、设备及存储介质
CN112561000A (zh) * 2021-02-22 2021-03-26 腾讯科技(深圳)有限公司 基于组合模型的分类方法、装置、设备及存储介质
CN113312445A (zh) * 2021-07-29 2021-08-27 阿里云计算有限公司 数据处理方法、模型构建方法、分类方法及计算设备

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR101623431B1 (ko) * 2015-08-06 2016-05-23 주식회사 루닛 의료 영상의 병리 진단 분류 장치 및 이를 이용한 병리 진단 시스템
US20170018089A1 (en) * 2014-02-28 2017-01-19 International Business Machines Corporation Segmentation Using Hybrid Discriminative Generative Label Fusion of Multiple Atlases
CN107526785A (zh) * 2017-07-31 2017-12-29 广州市香港科大霍英东研究院 文本分类方法及装置
US20180032846A1 (en) * 2016-08-01 2018-02-01 Nvidia Corporation Fusing multilayer and multimodal deep neural networks for video classification
CN108460415A (zh) * 2018-02-28 2018-08-28 国信优易数据有限公司 伪标签生成模型训练方法及伪标签生成方法
CN108875045A (zh) * 2018-06-28 2018-11-23 第四范式(北京)技术有限公司 针对文本分类来执行机器学习过程的方法及其系统
CN109902722A (zh) * 2019-01-28 2019-06-18 北京奇艺世纪科技有限公司 分类器、神经网络模型训练方法、数据处理设备及介质
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN110196908A (zh) * 2019-04-17 2019-09-03 深圳壹账通智能科技有限公司 数据分类方法、装置、计算机装置及存储介质
CN110263697A (zh) * 2019-06-17 2019-09-20 哈尔滨工业大学(深圳) 基于无监督学习的行人重识别方法、装置及介质
CN110413775A (zh) * 2019-06-25 2019-11-05 北京清博大数据科技有限公司 一种数据打标签分类方法、装置、终端及存储介质
CN110781934A (zh) * 2019-10-15 2020-02-11 深圳市商汤科技有限公司 监督学习、标签预测方法及装置、电子设备和存储介质

Patent Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170018089A1 (en) * 2014-02-28 2017-01-19 International Business Machines Corporation Segmentation Using Hybrid Discriminative Generative Label Fusion of Multiple Atlases
KR101623431B1 (ko) * 2015-08-06 2016-05-23 주식회사 루닛 의료 영상의 병리 진단 분류 장치 및 이를 이용한 병리 진단 시스템
US20180032846A1 (en) * 2016-08-01 2018-02-01 Nvidia Corporation Fusing multilayer and multimodal deep neural networks for video classification
CN107526785A (zh) * 2017-07-31 2017-12-29 广州市香港科大霍英东研究院 文本分类方法及装置
CN108460415A (zh) * 2018-02-28 2018-08-28 国信优易数据有限公司 伪标签生成模型训练方法及伪标签生成方法
CN108875045A (zh) * 2018-06-28 2018-11-23 第四范式(北京)技术有限公司 针对文本分类来执行机器学习过程的方法及其系统
CN109902722A (zh) * 2019-01-28 2019-06-18 北京奇艺世纪科技有限公司 分类器、神经网络模型训练方法、数据处理设备及介质
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN110196908A (zh) * 2019-04-17 2019-09-03 深圳壹账通智能科技有限公司 数据分类方法、装置、计算机装置及存储介质
CN110263697A (zh) * 2019-06-17 2019-09-20 哈尔滨工业大学(深圳) 基于无监督学习的行人重识别方法、装置及介质
CN110413775A (zh) * 2019-06-25 2019-11-05 北京清博大数据科技有限公司 一种数据打标签分类方法、装置、终端及存储介质
CN110781934A (zh) * 2019-10-15 2020-02-11 深圳市商汤科技有限公司 监督学习、标签预测方法及装置、电子设备和存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
CHUNYAN LU 等: "Multi-model dynamic fusion soft-sensing modeling and its application", 《2017 36TH CHINESE CONTROL CONFERENCE (CCC)》, pages 9682 - 9685 *
余游 等: "一种基于深度网络的少样本学习方法", 《小型微型计算机系统》, pages 2304 - 2308 *

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112329824A (zh) * 2020-10-23 2021-02-05 北京中科智加科技有限公司 多模型融合训练方法、文本分类方法以及装置
CN112528109A (zh) * 2020-12-01 2021-03-19 中科讯飞互联(北京)信息科技有限公司 一种数据分类方法、装置、设备及存储介质
CN112528109B (zh) * 2020-12-01 2023-10-27 科大讯飞(北京)有限公司 一种数据分类方法、装置、设备及存储介质
CN112561000A (zh) * 2021-02-22 2021-03-26 腾讯科技(深圳)有限公司 基于组合模型的分类方法、装置、设备及存储介质
CN112561000B (zh) * 2021-02-22 2021-05-28 腾讯科技(深圳)有限公司 基于组合模型的分类方法、装置、设备及存储介质
CN113312445A (zh) * 2021-07-29 2021-08-27 阿里云计算有限公司 数据处理方法、模型构建方法、分类方法及计算设备
CN113312445B (zh) * 2021-07-29 2022-02-11 阿里云计算有限公司 数据处理方法、模型构建方法、分类方法及计算设备

Also Published As

Publication number Publication date
CN111291823B (zh) 2023-08-18

Similar Documents

Publication Publication Date Title
CN111275133B (zh) 分类模型的融合方法、装置及存储介质
CN111090987B (zh) 用于输出信息的方法和装置
CN111291823A (zh) 分类模型的融合方法、装置、电子设备及存储介质
US20230025317A1 (en) Text classification model training method, text classification method, apparatus, device, storage medium and computer program product
CN111090756B (zh) 基于人工智能的多目标推荐模型的训练方法及装置
CN113762052A (zh) 视频封面提取方法、装置、设备及计算机可读存储介质
CN112287994A (zh) 伪标签处理方法、装置、设备及计算机可读存储介质
Singh et al. Mobile Deep Learning with TensorFlow Lite, ML Kit and Flutter: Build scalable real-world projects to implement end-to-end neural networks on Android and iOS
CN111274473B (zh) 基于人工智能的推荐模型的训练方法、装置及存储介质
CN111858898A (zh) 基于人工智能的文本处理方法、装置及电子设备
CN112749558B (zh) 一种目标内容获取方法、装置、计算机设备和存储介质
CN111831826A (zh) 跨领域的文本分类模型的训练方法、分类方法以及装置
CN116956116A (zh) 文本的处理方法和装置、存储介质及电子设备
CN114600196A (zh) 特定领域的人类模型协同注释工具
KR20180105501A (ko) 언어 정보를 처리하기 위한 방법 및 그 전자 장치
CN116912187A (zh) 图像生成模型训练及图像生成方法、装置、设备和介质
CN111062216A (zh) 命名实体识别方法、装置、终端及可读介质
Zhao et al. A multimodal model for college English teaching using text and image feature extraction
CN117033649A (zh) 文本处理模型的训练方法、装置、电子设备及存储介质
CN113407806B (zh) 网络结构搜索方法、装置、设备及计算机可读存储介质
CN112818084B (zh) 信息交互方法、相关装置、设备及计算机可读介质
CN112182179B (zh) 实体问答处理方法、装置、电子设备和存储介质
CN114331932A (zh) 目标图像生成方法和装置、计算设备以及计算机存储介质
CN113377951A (zh) 智能客服机器人的语料构建方法及装置
CN113761154A (zh) 智能问答方法、装置、设备及计算机可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
REG Reference to a national code

Ref country code: HK

Ref legal event code: DE

Ref document number: 40023600

Country of ref document: HK

SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant