CN114065858A - 一种模型训练方法、装置、电子设备及存储介质 - Google Patents

一种模型训练方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN114065858A
CN114065858A CN202111360462.0A CN202111360462A CN114065858A CN 114065858 A CN114065858 A CN 114065858A CN 202111360462 A CN202111360462 A CN 202111360462A CN 114065858 A CN114065858 A CN 114065858A
Authority
CN
China
Prior art keywords
category
basic
sample data
model
incremental
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111360462.0A
Other languages
English (en)
Inventor
李铎
程战战
钮毅
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Goldway Intelligent Transportation System Co Ltd
Original Assignee
Shanghai Goldway Intelligent Transportation System Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Goldway Intelligent Transportation System Co Ltd filed Critical Shanghai Goldway Intelligent Transportation System Co Ltd
Priority to CN202111360462.0A priority Critical patent/CN114065858A/zh
Publication of CN114065858A publication Critical patent/CN114065858A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例提供了一种模型训练方法、装置、电子设备及存储介质,涉及人工智能技术领域,包括:获得基础样本数据,并获得增量样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;利用所述增量样本数据训练用于进行分类的目标模型;利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。应用本申请实施例提供的方案可以提高增量训练后模型的准确度。

Description

一种模型训练方法、装置、电子设备及存储介质
技术领域
本申请涉及人工智能技术领域,特别是涉及一种模型训练方法、装置、电子设备及存储介质。
背景技术
在人工智能技术领域,为了对模型进行优化,通常需要利用新增的样本数据对基础模型进行增量训练,其中新增的样本数据的类别可能与训练基础模型时使用的、基础的样本数据的类别不同,使得增量训练后的模型能够对新的类别的数据进行分类。从而,在进行增量训练过程中,由于模型只关注了新增的样本数据的类别,可能导致模型对基础的样本数据的类别进行识别的性能降低,发生灾难性遗忘。
相关技术中,为了防止增量训练后的模型发生灾难性遗忘,一般在增量训练时,会使用基础的样本数据和新增的样本数据共同对模型进行训练。而由于基础的样本数据的类别可能与新增的样本数据的类别不同,导致增量训练所使用的样本数据中存在噪声,降低了增量训练后模型的准确度。
发明内容
本申请实施例的目的在于提供一种模型训练方法、装置、电子设备及存储介质,以提高增量训练后模型的准确度。具体技术方案如下:
第一方面,本申请实施例提供了一种模型训练方法,所述方法包括:
获得基础样本数据,并获得增量样本数据,其中,所述基础样本数据为:训练基础模型时所采用的样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;
利用所述增量样本数据训练用于进行分类的目标模型;
利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;
利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。
本申请的一个实施例中,在所述针对每一基础样本数据、在该基础样本数据的伪类别为新增类别时、更新该基础样本数据的类别为所识别到的伪类别步骤之前,所述方法还包括:
确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别;
将属于被关联类别的基础样本数据的类别更新为所述关联类别,其中,所述被关联类别为:所述基础样本类别中与所述关联类别相关联的类别。
本申请的一个实施例中,所述确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别,包括:
利用所述基础模型识别各个增量样本数据的类别,获得所述基础模型对不同增量样本类别的增量样本数据进行识别的准确度;
确定准确度最高的增量样本类别,作为与所述基础样本类别相关联的关联类别。
本申请的一个实施例中,所述利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练,包括:
将所述增量样本数据、更新类别后的基础样本数据作为输入数据输入所述基础模型,利用所述基础模型提取所述输入数据的第一特征、并根据所述第一特征预测所述输入数据的类别,得到输出结果;
计算所述输出结果相对所述输入数据的标注信息的第一损失,其中,每一输入数据的标注信息反映该输入数据的类别;
利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练。
本申请的一个实施例中,所述利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练,包括:
利用所述第一损失调整所述基础模型的参数,得到调参后模型;
将所述输入数据输入所述调参后模型,利用所述调参后模型提取所述输入数据的第二特征;
计算所述第二特征相对所述第一特征的第二损失,利用所述第二损失调整所述调参后模型的参数,实现模型训练。
第二方面,本申请实施例提供了一种模型训练装置,所述装置包括:
数据获得模块,用于获得基础样本数据,并获得增量样本数据,其中,所述基础样本数据为:训练基础模型时所采用的样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;
目标模型训练模块,用于利用所述增量样本数据训练用于进行分类的目标模型;
第一类别更新模块,用于利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;
模型训练模块,用于利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。
本申请的一个实施例中,所述装置还包括第二类别更新模块,包括:
关联类别确定单元,用于在针对每一基础样本数据、在该基础样本数据的伪类别为新增类别时、更新该基础样本数据的类别为所识别到的伪类别之前,确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别;
第二类别更新单元,用于将属于被关联类别的基础样本数据的类别更新为所述关联类别,其中,所述被关联类别为:所述基础样本类别中与所述关联类别相关联的类别。
本申请的一个实施例中,所述关联类别确定单元,具体用于:
利用所述基础模型识别各个增量样本数据的类别,获得所述基础模型对不同增量样本类别的增量样本数据进行识别的准确度;
确定准确度最高的增量样本类别,作为与所述基础样本类别相关联的关联类别。
本申请的一个实施例中,所述模型训练模块,包括:
输出结果获得单元,用于将所述增量样本数据、更新类别后的基础样本数据作为输入数据输入所述基础模型,利用所述基础模型提取所述输入数据的第一特征、并根据所述第一特征预测所述输入数据的类别,得到输出结果;
第一损失计算单元,用于计算所述输出结果相对所述输入数据的标注信息的第一损失,其中,每一输入数据的标注信息反映该输入数据的类别;
模型训练单元,用于利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练。
本申请的一个实施例中,所述模型训练单元,具体用于:
利用所述第一损失调整所述基础模型的参数,得到调参后模型;
将所述输入数据输入所述调参后模型,利用所述调参后模型提取所述输入数据的第二特征;
计算所述第二特征相对所述第一特征的第二损失,利用所述第二损失调整所述调参后模型的参数,实现模型训练。
第三方面,本申请实施例提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现第一方面任一所述的方法步骤。
第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现第一方面任一所述的方法步骤。
本申请实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述任一所述的模型训练方法。
本申请实施例有益效果:
本申请实施例提供的模型训练方案中,可以获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;利用增量样本数据训练用于进行分类的目标模型;利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用本申请实施例提供的方案,可以提高增量训练后模型的准确度。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的实施例。
图1为本申请实施例提供的一种模型训练方法的流程示意图;
图2为本申请实施例提供的另一种模型训练方法的流程示意图;
图3为本申请实施例提供的一种模型训练过程的示意图;
图4为本申请实施例提供的一种模型训练装置的结构示意图;
图5为本申请实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员基于本申请所获得的所有其他实施例,都属于本申请保护的范围。
为了提高增量训练后模型的准确度,本申请实施例提供了一种模型训练方法、装置、电子设备及存储介质,下面分别进行详细介绍。
本申请的一个实施例中,提供了一种模型训练方法,该方法可以应用于计算机、手机、平板电脑、服务器等电子设备,该方法包括:
获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;
利用增量样本数据训练用于进行分类的目标模型;
利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;
利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。
这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用上述实施例提供的方案,可以提高增量训练后模型的准确度。
上述模型训练方法可以应用于人脸识别模型、车牌号码识别模型、报警事件检测模型、目标对象分类模型等,本申请实施例并不对此进行限定。
下面对上述模型训练方法进行详细介绍。
参见图1,图1为本申请实施例提供的一种模型训练方法的流程示意图,该方法包括如下步骤S101-S104:
S101,获得基础样本数据,并获得增量样本数据。
其中,基础样本数据为:训练基础模型时所采用的样本数据。上述样本数据可以是图像、视频片段、文本、音频片段等。基础样本数据的类别称为基础样本类别,利用基础样本数据训练得到的基础模型能够对基础样本类别的数据进行分类。上述类别指的是:数据中所包含的内容所属的分类,例如,假设数据为图像,图像内容为人脸,则可以认为该数据的类别为“人脸”。
增量样本数据为:新增加的、未进行过模型训练的样本数据。增量样本数据的类别称为增量样本类别。
增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同,增量样本类别与基础样本类别之间存在交集,也存在不相交集。例如,假设基础样本类别为“手”,增量样本类别可以是“左手”、“右手”、“其他”等。
具体的,可以获得已使用过的、用于对基础模型进行训练的样本数据,作为基础样本数据,并获得未使用过的、所属类别与上述基础样本类别不完全相同的样本数据,作为增量样本数据。
本申请的一个实施例中,在获得上述增量样本数据时,可以获得上述基础模型在应用过程中输入该基础模型的数据,从上述数据中选择增量样本数据;也可以从公开的数据平台获得增量样本数据等,本申请实施例并不对此进行限定。
本申请的一个实施例中,可以将基础样本数据存储至基础数据集中,将增量样本数据存储至增量数据集中,在后续利用该增量样本数据进行增量训练后,可以将增量数据集中的数据转存至基础数据集,作为下一次进行增量训练的基础样本数据。
S102,利用增量样本数据训练用于进行分类的目标模型。
具体的,可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类。
本申请的一个实施例中,可以利用增量样本数据对未经过训练的原始模型进行训练,得到目标模型;
除此之外,还可以利用增量样本数据对经过通用数据预训练后的通用模型进行训练,得到目标模型;
另外,还可以利用增量样本数据对上述采用基础样本数据训练得到的基础模型进行再次训练,得到目标模型,本申请实施例并不对此进行限定。
S103,利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别。
其中,新增类别为:增量样本类别中不同于基础样本类别的类别。例如,假设基础样本类别包括:“水杯”、“笔筒”、“电话”,增量样本类别包括:“电话”、“订书机”、“书籍”,则新增类别为:“订书机”和“书籍”。
上述伪类别指的是:由目标模型识别到的基础样本数据的类别。
具体的,目标模型能够对属于增量样本类别的数据进行分类,而基础样本数据所属的基础样本类别与上述增量样本类别不完全相同,这种情况下,可以利用上述目标模型重新识别各个基础样本数据所属的类别,得到各个基础样本数据的伪类别。针对每一基础样本数据,若该基础样本数据的伪类别为新增类别,则为了降低基础样本数据中的噪声,可以将该基础样本数据原本的类别更新为该伪类别。
例如,假设新增类别包括“右手”,基础样本数据原本的类别为“手”,目标模型对该基础样本数据进行识别得到的伪类别为“右手”,则可以将该基础样本数据的类别更新为“右手”。
S104,利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。
具体的,可以利用增量样本数据、更新类别后的基础样本数据对上述基础模型进行增量训练,增联训练后的模型能够对属于增量样本类别的数据进行分类。
上述实施例提供的模型训练方案中,可以获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;利用增量样本数据训练用于进行分类的目标模型;利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用上述实施例提供的方案,可以提高增量训练后模型的准确度。
本申请的一个实施例中,在上述步骤S103中进行类别更新之前,还可以:
确定增量样本类别中与基础样本类别相关联的类别,作为关联类别;将属于被关联类别的基础样本数据的类别更新为关联类别。
其中,被关联类别为:基础样本类别中与关联类别相关联的类别。
具体的,可以从增量样本类别中确定与基础样本类别相关联的类别,作为关联类别,将基础样本类别中与上述关联类别相关联的类别,作为被关联类别,认为上述关联类别与被关联类别实质为同一类别,因此可以从基础样本数据中查找属于上述被关联类别的样本数据,将所查找到的样本数据的类别更新为关联类别。
上述方案中,由于增量样本类别与基础样本类别之间存在交集,也存在不相交集,可以对属于上述交集的类别进行一对一的映射,从而建立基础样本类别与增量样本类别之间的映射关系,进而可以利用上述映射关系对基础样本数据的类别进行更新,从而能够降低基础样本数据中的噪声。
本申请的一个实施例中,在确定关联类别时,可以利用基础模型识别各个增量样本数据的类别,获得基础模型对不同增量样本类别的增量样本数据进行识别的准确度;确定准确度最高的增量样本类别,作为与基础样本类别相关联的关联类别。
具体的,基础模型可以识别基础样本类别的数据所属的类别,可以将增量样本数据输入基础模型,利用基础模型识别各个增量样本数据的类别,并针对每一增量样本类别,获得该基础模型对该类别的增量样本数据进行识别的准确度,准确度越高,说明基础模型对该增量样本类别的识别效果越好,进而说明该增量样本类别与基础样本类别越相似,从而可以认为该增量样本类别为关联类别。
本申请的一个实施例中,在获得基础模型对不同增量样本类别的增量样本数据进行识别的准确度时,可以获得基础模型对不同增量样本类别的增量样本数据进行识别的置信度的数学统计值,作为准确度。其中,上述数学统计值可以是算数平均值、加权平均值、中位值等。
本申请的一个实施例中,在确定关联类别时,还可以计算各个增量样本类别与基础样本类别的相似度,确定上述相似度最高的增量样本类别与基础样本类别之间存在关联关系,进而将存在关联关系的增量样本类别作为关联类别,并将上述关联关系中的基础样本类别作为被关联类别。
本申请的一个实施例中,在上述步骤S104对基础模型进行训练时,可以包括如下步骤A-C:
步骤A,将增量样本数据、更新类别后的基础样本数据作为输入数据输入基础模型,利用基础模型提取输入数据的第一特征、并根据第一特征预测输入数据的类别,得到输出结果。
其中,上述输入数据包括增量样本数据和更新类别后的基础样本数据。
具体的,可以将输入数据输入待进行增量训练的基础模型,基础模型可以提取各个输入数据的特征,作为第一特征,然后利用上述第一特征对各个输入数据进行分类,得到各个输入数据对应的输出结果。
步骤B,计算输出结果相对输入数据的标注信息的第一损失。
其中,每一输入数据的标注信息反映该输入数据的类别。
具体的,每一输入数据对应的输出结果能够反映:基础模型所识别到的、该输入数据的类别,可以计算上述输出结果相对输入数据的标注信息的损失,作为第一损失。
步骤C,利用第一损失调整基础模型的参数,实现对基础模型的训练。
具体的,上述第一损失可以反映待进行增量训练的基础模型的输出结果相对标注信息的差异,利用该第一损失可以对上述基础模型进行参数调整,从而实现对基础模型的增量训练。
本申请的一个实施例中,可以利用第一损失调整基础模型的参数,得到调参后模型;将输入数据输入调参后模型,利用调参后模型提取输入数据的第二特征;计算第二特征相对第一特征的第二损失,利用第二损失调整调参后模型的参数,实现模型训练。
具体的,可以将上述输入数据重新输入调参后模型中,调参后模型可以提取各个输入数据的特征,作为第二特征,然后计算上述第二特征相对第一特征的损失,作为第二损失,最后利用第二损失再次调整调参后模型的参数,最终实现对待进行增量训练的基础模型的增量训练。
上述方案中,由于增量样本类别与基础样本类别之间存在交集,也存在不相交集,利用上述数据增量训练后的基础模型与增量训练前的基础模型所能够识别的类别之间存在相关性,也就说明增量训练后的基础模型与增量训练前的基础模型对数据进行特征提取时,所提取的特征之间也存在相关性,因此为了保证增量训练后的基础模型的可靠性,可以计算利用上述第二特征相对第一特征之间的第二损失对模型进行参数调整,使得参数调整后的模型所提取的特征与初始的基础模型所提取的特征相近,实现了利用基础模型对调参后模型的知识蒸馏。
利用上述第一特征对各个输入数据进行分类,得到各个输入数据对应的输出结果。
上述实施例提供的模型训练方案中,可以获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;利用增量样本数据训练用于进行分类的目标模型;利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用上述实施例提供的方案,可以提高增量训练后模型的准确度。
参见图2,图2为本申请实施例提供的另一种模型训练方法的流程示意图,该方法包括如下步骤S201-S207:
S201,获得基础样本数据,并获得增量样本数据。
其中,基础样本数据为:训练基础模型时所采用的样本数据。
增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同,增量样本类别与基础样本类别之间存在交集,也存在不相交集。
S202,确定增量样本类别中与基础样本类别相关联的类别,作为关联类别,将属于被关联类别的基础样本数据的类别更新为关联类别。
其中,被关联类别为:基础样本类别中与关联类别相关联的类别。
S203,利用增量样本数据训练用于进行分类的目标模型。
S204,利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别。
其中,新增类别为:增量样本类别中不同于基础样本类别的类别。
S205,将增量样本数据、更新类别后的基础样本数据作为输入数据输入基础模型,利用基础模型提取输入数据的第一特征、并根据第一特征预测输入数据的类别,得到输出结果。
其中,上述输入数据包括增量样本数据和更新类别后的基础样本数据。
S206,计算输出结果相对输入数据的标注信息的第一损失,利用第一损失调整基础模型的参数,得到调参后模型。
其中,每一输入数据的标注信息反映该输入数据的类别。
S207,将输入数据输入调参后模型,利用调参后模型提取输入数据的第二特征,计算第二特征相对第一特征的第二损失,利用第二损失调整调参后模型的参数,实现模型训练。
参见图3,图3为本申请实施例提供的一种模型训练过程的示意图,如图3所示:
在开始增量训练后,可以首先进行数据获得,具体可以获得基础样本数据,并获得增量样本数据;
之后可以利用所获得的数据对基础模型进行混合增量训练,在混合增量训练过程中,可以对属于增量样本类别与基础样本类别之间的交集的类别进行一对一的映射,进而利用上述映射关系对基础样本数据的类别进行更新,并利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,实现对基础样本数据中噪声的滤除,然后利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练,最终训练得到增量模型;
之后可以重新获得基础样本数据和增量样本数据,重复执行上述增量训练的过程,实现对基础模型的多次增量训练。
与上述模型训练方法相对应地,本申请实施例还提供了一种模型训练装置,下面进行详细介绍。
参见图4,图4为本申请实施例提供的一种模型训练装置的结构示意图,所述装置包括:
数据获得模块401,用于获得基础样本数据,并获得增量样本数据,其中,所述基础样本数据为:训练基础模型时所采用的样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;
目标模型训练模块402,用于利用所述增量样本数据训练用于进行分类的目标模型;
第一类别更新模块403,用于利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;
模型训练模块404,用于利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。
本申请的一个实施例中,所述装置还包括第二类别更新模块,包括:
关联类别确定单元,用于在针对每一基础样本数据、在该基础样本数据的伪类别为新增类别时、更新该基础样本数据的类别为所识别到的伪类别之前,确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别;
第二类别更新单元,用于将属于被关联类别的基础样本数据的类别更新为所述关联类别,其中,所述被关联类别为:所述基础样本类别中与所述关联类别相关联的类别。
本申请的一个实施例中,所述关联类别确定单元,具体用于:
利用所述基础模型识别各个增量样本数据的类别,获得所述基础模型对不同增量样本类别的增量样本数据进行识别的准确度;
确定准确度最高的增量样本类别,作为与所述基础样本类别相关联的关联类别。
本申请的一个实施例中,所述模型训练模块404,包括:
输出结果获得单元,用于将所述增量样本数据、更新类别后的基础样本数据作为输入数据输入所述基础模型,利用所述基础模型提取所述输入数据的第一特征、并根据所述第一特征预测所述输入数据的类别,得到输出结果;
第一损失计算单元,用于计算所述输出结果相对所述输入数据的标注信息的第一损失,其中,每一输入数据的标注信息反映该输入数据的类别;
模型训练单元,用于利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练。
本申请的一个实施例中,所述模型训练单元,具体用于:
利用所述第一损失调整所述基础模型的参数,得到调参后模型;
将所述输入数据输入所述调参后模型,利用所述调参后模型提取所述输入数据的第二特征;
计算所述第二特征相对所述第一特征的第二损失,利用所述第二损失调整所述调参后模型的参数,实现模型训练。
上述实施例提供的模型训练方案中,可以获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;利用增量样本数据训练用于进行分类的目标模型;利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用上述实施例提供的方案,可以提高增量训练后模型的准确度。
本申请实施例还提供了一种电子设备,如图5所示,包括处理器501、通信接口502、存储器503和通信总线504,其中,处理器501,通信接口502,存储器503通过通信总线504完成相互间的通信,
存储器503,用于存放计算机程序;
处理器501,用于执行存储器503上所存放的程序时,实现模型训练方法。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本申请提供的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现上述任一模型训练方法的步骤。
在本申请提供的又一实施例中,还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述实施例中任一模型训练方法。
上述实施例提供的模型训练方案中,可以获得基础样本数据,并获得增量样本数据,其中,基础样本数据为:训练基础模型时所采用的样本数据,增量样本数据的增量样本类别与基础样本数据的基础样本类别不完全相同;利用增量样本数据训练用于进行分类的目标模型;利用目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,新增类别为:增量样本类别中不同于基础样本类别的类别;利用增量样本数据、更新类别后的基础样本数据对基础模型进行训练。这样可以利用增量样本数据训练得到目标模型,该目标模型能够对属于增量样本类别的数据进行分类,因此可以利用上述目标模型更新基础样本数据中属于新增类别的样本数据的类别,使得更新类别后的基础样本数据的类别与新增样本类别一致,从而降低基础样本数据的噪声,后续可以利用新增样本数据和降低噪声后的基础样本数据对基础模型进行混合增量训练。由此可见,应用上述实施例提供的方案,可以提高增量训练后模型的准确度。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例、电子设备实施例、计算机可读存储介质实施例、计算机程序产品实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本申请的较佳实施例,并非用于限定本申请的保护范围。凡在本申请的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本申请的保护范围内。

Claims (12)

1.一种模型训练方法,其特征在于,所述方法包括:
获得基础样本数据,并获得增量样本数据,其中,所述基础样本数据为:训练基础模型时所采用的样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;
利用所述增量样本数据训练用于进行分类的目标模型;
利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;
利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。
2.根据权利要求1所述的方法,其特征在于,在所述针对每一基础样本数据、在该基础样本数据的伪类别为新增类别时、更新该基础样本数据的类别为所识别到的伪类别步骤之前,所述方法还包括:
确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别;
将属于被关联类别的基础样本数据的类别更新为所述关联类别,其中,所述被关联类别为:所述基础样本类别中与所述关联类别相关联的类别。
3.根据权利要求2所述的方法,其特征在于,所述确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别,包括:
利用所述基础模型识别各个增量样本数据的类别,获得所述基础模型对不同增量样本类别的增量样本数据进行识别的准确度;
确定准确度最高的增量样本类别,作为与所述基础样本类别相关联的关联类别。
4.根据权利要求1-3中任一项所述的方法,其特征在于,所述利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练,包括:
将所述增量样本数据、更新类别后的基础样本数据作为输入数据输入所述基础模型,利用所述基础模型提取所述输入数据的第一特征、并根据所述第一特征预测所述输入数据的类别,得到输出结果;
计算所述输出结果相对所述输入数据的标注信息的第一损失,其中,每一输入数据的标注信息反映该输入数据的类别;
利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练。
5.根据权利要求4所述的方法,其特征在于,所述利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练,包括:
利用所述第一损失调整所述基础模型的参数,得到调参后模型;
将所述输入数据输入所述调参后模型,利用所述调参后模型提取所述输入数据的第二特征;
计算所述第二特征相对所述第一特征的第二损失,利用所述第二损失调整所述调参后模型的参数,实现模型训练。
6.一种模型训练装置,其特征在于,所述装置包括:
数据获得模块,用于获得基础样本数据,并获得增量样本数据,其中,所述基础样本数据为:训练基础模型时所采用的样本数据,所述增量样本数据的增量样本类别与所述基础样本数据的基础样本类别不完全相同;
目标模型训练模块,用于利用所述增量样本数据训练用于进行分类的目标模型;
第一类别更新模块,用于利用所述目标模型识别各个基础样本数据的伪类别,针对每一基础样本数据,在该基础样本数据的伪类别为新增类别时,更新该基础样本数据的类别为所识别到的伪类别,其中,所述新增类别为:所述增量样本类别中不同于所述基础样本类别的类别;
模型训练模块,用于利用所述增量样本数据、更新类别后的基础样本数据对所述基础模型进行训练。
7.根据权利要求6所述的装置,其特征在于,所述装置还包括第二类别更新模块,包括:
关联类别确定单元,用于在针对每一基础样本数据、在该基础样本数据的伪类别为新增类别时、更新该基础样本数据的类别为所识别到的伪类别之前,确定所述增量样本类别中与所述基础样本类别相关联的类别,作为关联类别;
第二类别更新单元,用于将属于被关联类别的基础样本数据的类别更新为所述关联类别,其中,所述被关联类别为:所述基础样本类别中与所述关联类别相关联的类别。
8.根据权利要求7所述的装置,其特征在于,所述关联类别确定单元,具体用于:
利用所述基础模型识别各个增量样本数据的类别,获得所述基础模型对不同增量样本类别的增量样本数据进行识别的准确度;
确定准确度最高的增量样本类别,作为与所述基础样本类别相关联的关联类别。
9.根据权利要求6-8中任一项所述的装置,其特征在于,所述模型训练模块,包括:
输出结果获得单元,用于将所述增量样本数据、更新类别后的基础样本数据作为输入数据输入所述基础模型,利用所述基础模型提取所述输入数据的第一特征、并根据所述第一特征预测所述输入数据的类别,得到输出结果;
第一损失计算单元,用于计算所述输出结果相对所述输入数据的标注信息的第一损失,其中,每一输入数据的标注信息反映该输入数据的类别;
模型训练单元,用于利用所述第一损失调整所述基础模型的参数,实现对所述基础模型的训练。
10.根据权利要求9所述的装置,其特征在于,所述模型训练单元,具体用于:
利用所述第一损失调整所述基础模型的参数,得到调参后模型;
将所述输入数据输入所述调参后模型,利用所述调参后模型提取所述输入数据的第二特征;
计算所述第二特征相对所述第一特征的第二损失,利用所述第二损失调整所述调参后模型的参数,实现模型训练。
11.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-5任一所述的方法步骤。
12.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-5任一所述的方法步骤。
CN202111360462.0A 2021-11-17 2021-11-17 一种模型训练方法、装置、电子设备及存储介质 Pending CN114065858A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111360462.0A CN114065858A (zh) 2021-11-17 2021-11-17 一种模型训练方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111360462.0A CN114065858A (zh) 2021-11-17 2021-11-17 一种模型训练方法、装置、电子设备及存储介质

Publications (1)

Publication Number Publication Date
CN114065858A true CN114065858A (zh) 2022-02-18

Family

ID=80273327

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111360462.0A Pending CN114065858A (zh) 2021-11-17 2021-11-17 一种模型训练方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN114065858A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115438755A (zh) * 2022-11-08 2022-12-06 腾讯科技(深圳)有限公司 分类模型的增量训练方法、装置和计算机设备

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115438755A (zh) * 2022-11-08 2022-12-06 腾讯科技(深圳)有限公司 分类模型的增量训练方法、装置和计算机设备
CN115438755B (zh) * 2022-11-08 2024-04-02 腾讯科技(深圳)有限公司 分类模型的增量训练方法、装置和计算机设备

Similar Documents

Publication Publication Date Title
CN110222791B (zh) 样本标注信息的审核方法及装置
CN108376129B (zh) 一种纠错方法及装置
CN110909784B (zh) 一种图像识别模型的训练方法、装置及电子设备
WO2023011470A1 (zh) 一种机器学习系统及模型训练方法
CN111047429A (zh) 一种概率预测方法及装置
CN109947903B (zh) 一种成语查询方法及装置
EP2707808A2 (en) Exploiting query click logs for domain detection in spoken language understanding
CN115935344A (zh) 一种异常设备的识别方法、装置及电子设备
CN113378852A (zh) 关键点检测方法、装置、电子设备及存储介质
CN114241411B (zh) 基于目标检测的计数模型处理方法、装置及计算机设备
CN114065858A (zh) 一种模型训练方法、装置、电子设备及存储介质
CN114596570A (zh) 一种文字识别模型的训练方法、文字识别方法及装置
CN112434717A (zh) 一种模型训练方法及装置
CN113095067A (zh) 一种ocr错误纠正的方法、装置、电子设备及存储介质
CN112163415A (zh) 针对反馈内容的用户意图识别方法、装置及电子设备
CN114693011A (zh) 一种政策匹配方法、装置、设备和介质
CN112069806B (zh) 简历筛选方法、装置、电子设备及存储介质
CN110895924B (zh) 一种文档内容朗读方法、装置、电子设备及可读存储介质
CN110399803B (zh) 一种车辆检测方法及装置
CN115292008A (zh) 用于分布式系统的事务处理方法、装置、设备及介质
CN112017634B (zh) 数据的处理方法、装置、设备以及存储介质
CN111767710B (zh) 印尼语的情感分类方法、装置、设备及介质
CN110135464B (zh) 一种图像处理方法、装置、电子设备及存储介质
CN114461900A (zh) 职位名称的识别方法、装置、计算处理设备、程序及介质
CN113901817A (zh) 文档分类方法、装置、计算机设备和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination