CN114118301A - 分类模型的训练方法及计算机可读存储介质 - Google Patents

分类模型的训练方法及计算机可读存储介质 Download PDF

Info

Publication number
CN114118301A
CN114118301A CN202210069084.9A CN202210069084A CN114118301A CN 114118301 A CN114118301 A CN 114118301A CN 202210069084 A CN202210069084 A CN 202210069084A CN 114118301 A CN114118301 A CN 114118301A
Authority
CN
China
Prior art keywords
classification model
data set
data
label
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210069084.9A
Other languages
English (en)
Inventor
刘国清
杨广
王启程
郑伟
刘德富
杨国武
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Minieye Innovation Technology Co Ltd
Original Assignee
Shenzhen Minieye Innovation Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Minieye Innovation Technology Co Ltd filed Critical Shenzhen Minieye Innovation Technology Co Ltd
Priority to CN202210069084.9A priority Critical patent/CN114118301A/zh
Publication of CN114118301A publication Critical patent/CN114118301A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00Administration; Management
    • G06Q10/04Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Business, Economics & Management (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Economics (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Strategic Management (AREA)
  • Evolutionary Biology (AREA)
  • Human Resources & Organizations (AREA)
  • Game Theory and Decision Science (AREA)
  • Development Economics (AREA)
  • Entrepreneurship & Innovation (AREA)
  • Marketing (AREA)
  • Operations Research (AREA)
  • Quality & Reliability (AREA)
  • Tourism & Hospitality (AREA)
  • General Business, Economics & Management (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提供了一种分类模型的训练方法,包括:根据第一数据集训练第一分类模型以得到预测分类模型,其中,第一数据集为有标签数据的集合;将第二数据集输入预测分类模型进行计算以得到预测标签,其中,第二数据集为无标签数据的集合;根据第一数据集和带有预测标签的第二数据集依次循环对第二分类模型进行训练,并更新第二分类模型的参数以得到中间分类模型,直至中间分类模型满足预设条件;以及当中间分类模型满足预设条件时,将满足预设条件的中间分类模型作为目标分类模型。此外,本发明还提供了一种计算机可读存储介质。本发明技术方案有效解决了有标签数据的数量少导致分类模型准确度不高的问题。

Description

分类模型的训练方法及计算机可读存储介质
技术领域
本发明涉及机器学习技术领域,尤其涉及一种分类模型的训练方法及计算机可读存储介质。
背景技术
深度学习模型在各个领域中已经取得了巨大的成就,特别是有监督学习算法在各种视觉任务中取得了巨大成功。深度学习一般是从大量已标注的训练样本中学习一个模型用于给未见过的样本预测一个尽可能正确的标签。然而在许多实际应用场景中,人工标注大规模的训练样本需要耗费巨大的人力和物力。因此,许多研究聚焦于半监督学习,即在只有部分已标注样本和大量未标注样本的情况下学习的一个模型。
发明内容
本发明提供了一种分类模型的训练方法及计算机可读存储介质,用于解决有标签数据的数量少导致分类模型准确度不高的问题。
第一方面,本发明实施例提供一种分类模型的训练方法,所述分类模型的训练方法包括:
根据第一数据集训练第一分类模型以得到预测分类模型,其中,所述第一数据集为有标签数据的集合;
将第二数据集输入所述预测分类模型进行计算以得到预测标签,其中,所述第二数据集为无标签数据的集合;
将所述第一数据集和带有预测标签的第二数据集两者之一输入第二分类模型中训练,并更新所述第二分类模型的参数以得到中间分类模型,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数;
判断所述中间分类模型是否满足预设条件;
当所述中间分类模型没有满足预设条件时,将所述第一数据集和带有预测标签的第二数据集两者之一输入所述中间分类模型中训练,并更新所述中间分类模型的参数,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数,直至所述中间分类模型满足预设条件;以及
当所述中间分类模型满足预设条件时,将满足预设条件的中间分类模型作为目标分类模型。
第二方面,本发明实施例提供一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储程序指令,所述程序指令可被处理器执行以实现如上所述的分类模型的训练方法。
上述分类模型的训练方法及计算机可读存储介质,根据有标签的第一数据训练第一分类模型以得到预测分类模型,利用预测分类模型为无标签的第二数据计算预测标签。根据第一数据集和带有预测标签的第二数据集两者之一训练第二分类模型以得到中间分类模型,再根据两者之另一训练中间分类模型,并按顺序对中间分类模型进行循环训练,直至中间分类模型满足预设条件,得到目标分类模型。利用预测分类模型为无标签的第二数据计算预测标签,根据第一数据集和带有预测标签的第二数据集对第二分类模型进行训练,极大增加了有标签的数据的数量,从而得到具有良好分类能力的目标分类模型,提高了目标分类模型的性能。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图示出的结构获得其他的附图。
图1为本发明第一实施例提供的分类模型的训练方法的流程图。
图2为本发明第一实施例提供的分类模型的训练方法的第一子流程图。
图3为本发明第一实施例提供的分类模型的训练方法的第二子流程图。
图4为本发明第二实施例提供的分类模型的训练方法的子流程图。
图5为本发明第三实施例提供的分类模型的训练方法的子流程图。
图6为本发明实施例提供的终端的内部结构示意图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的规划对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,换句话说,描述的实施例根据除了这里图示或描述的内容以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,还可以包含其他内容,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于只清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
需要说明的是,在本发明中涉及“第一”、“第二”等的描述仅用于描述目的,而不能理解为指示或暗示其相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者多个该特征。另外,各个实施例之间的技术方案可以相互结合,但是必须是以本领域普通技术人员能够实现为基础,当技术方案的结合出现相互矛盾或无法实现时应当认为这种技术方案的结合不存在,也不在本发明要求的保护范围之内。
请结合参看图1,其为本发明第一实施例提供的分类模型的训练方法的流程图。训练方法用于训练分类模型,训练得到的分类模型能够对无标签数据进行分类。分类模型的训练方法具体包括如下步骤。
步骤S102,根据第一数据集训练第一分类模型以得到预测分类模型。其中,第一数据集为有标签数据的集合。在本实施例中,训练第一分类模型之前,预先设定若干预设类别。第一数据集中的每一第一数据均包括一个真实标签,每一第一数据的真实标签与若干预设类别中的一个类别相对应。训练第一分类模型的具体过程为:将第一数据集中的第一数据输入第一分类模型以得到相应的训练标签,基于损失函数根据训练标签和真实标签计算训练损失值,根据训练损失值更新第一分类模型的参数,直至第一分类模型收敛。将收敛的第一分类模型作为预测分类模型。其中,损失函数包括但不限于分类交叉熵损失函数、均方误差损失函数、绝对误差损失函数等任何为有监督学习设计的损失函数。第一数据集中的第一数据包括但不限于图像、文字、音频等。
步骤S104,将第二数据集输入预测分类模型进行计算以得到预测标签。其中,第二数据集为无标签数据的集合。将第二数据集中的第二数据输入预测分类模型,可以得到与每一第二数据相对应的预测标签。第二数据集中的第二数据包括但不限于图像、文字、音频等。可以理解的是,第一数据和第二数据为相同类型的数据。即,若第一数据为图像,则第二数据也应该为图像;若第一数据为文字,则第二数据也应该为文字;若第一数据为音频,则第二数据也应该为音频。
步骤S106,将第一数据集和带有预测标签的第二数据集两者之一输入第二分类模型中训练,并更新第二分类模型的参数以得到中间分类模型,将两者之另一输入中间分类模型中训练,并更新中间分类模型的参数。即是说,可以先将第一数据集输入第二分类模型中训练,并更新第二分类模型的参数以得到中间分类模型,再将第二数据集输入中间分类模型中训练,并更新中间分类模型的参数;也可以先将第二数据集输入第二分类模型中训练,并更新第二分类模型的参数以得到中间分类模型,再将第一数据集输入中间分类模型中训练,并更新中间分类模型的参数。其中,第一分类模型的结构和第二分类模型的结构相同,第一分类模型的参数和第二分类模型的参数不同。可以理解的是,第一分类模型和第二分类模型设置的结构相同,但是第一分类模型和第二分类模型分别初始化后形成的参数不同。根据第一数据集和带有预测标签的第二数据集训练第二分类模型以得到中间分类模型的具体过程将在下文详细描述。
步骤S108,判断中间分类模型是否满足预设条件。当中间分类模型没有满足预设条件时,执行步骤S110;当中间分类模型满足预设条件时,执行步骤S112。判断中间分类模型是否满足预设条件的具体过程将在下文详细描述。
步骤S110,将第一数据集和带有预测标签的第二数据集两者之一输入中间分类模型中训练,并更新中间分类模型的参数,将两者之另一输入中间分类模型中训练,并更新中间分类模型的参数,直至中间分类模型满足预设条件。在本实施例中,利用第一数据集或第二数据集更新一次中间分类模型的参数,再利用第二数据集或第一数据集更新一次中间分类模型的参数,中间分类模型视为训练了一次。中间分类模型每训练一次,就判断一次训练好的中间分类模型是否满足预设条件。多次训练中间分类模型,直至中间分类模型满足预设条件。
步骤S112,将满足预设条件的中间分类模型作为目标分类模型。目标分类模型用于对无标签数据进行分类。
上述实施例中,根据有标签的第一数据训练第一分类模型以得到预测分类模型,利用预测分类模型为无标签的第二数据计算预测标签。根据第一数据集和带有预测标签的第二数据集两者之一训练第二分类模型以得到中间分类模型,再根据两者之另一训练中间分类模型,并按顺序对中间分类模型进行循环训练,直至中间分类模型满足预设条件,得到目标分类模型。利用预测分类模型为无标签的第二数据计算预测标签,根据第一数据集和带有预测标签的第二数据集对第二分类模型进行训练,极大增加了有标签的数据的数量,从而得到具有良好分类能力的目标分类模型,提高了目标分类模型的性能。同时,分类模型的训练方法能够应用于不同类型的数据集,具有广泛的应用价值。
请结合参看图2,其为本发明第一实施例提供的分类模型的训练方法的第一子流程图。步骤S106具体包括如下步骤。
步骤S202,将第一数据集输入第二分类模型以得到第一标签。将第一数据集中的第一数据输入第二分类模型以得到相应的第一标签。可以理解的是,每一第一数据均具有一个第一标签。
步骤S204,根据第一标签和真实标签构建第一损失值。在本实施例中,从第一数据集的第一数据中多次选取预设数量的第一数据。其中,每次选取的第一数据均不重复。即是说,同一次选取的第一数据均不相同,被选取过的第一数据不会再次被选取,所有第一数据均会被选取。但每次选取的第一数据的数量均相同,为预设数量。预设数量可以根据实际情况进行设置,在此不做限定。具体地,每次选取预设数量的第一数据的过程可以称为批处理,预设数量为batch size的大小。计算每次选取的所有第一数据的第一标签的平均值,根据每次选取的第一数据的真实标签和第一标签的平均值构建第一损失值。具体地,根据每一第一数据的真实标签和相应第一标签的平均值构建第一子损失,计算每次选取的所有第一数据的第一子损失的总和作为第一损失值。可以理解的是,根据第一标签和真实标签构建的第一损失值包括多个。其中,若第一数据集中第一数据的数量为n,预设数量为x,构建的第一损失值的数量为n/x。
在本实施例中,根据第一公式计算第一子损失。第一公式为
Figure 782792DEST_PATH_IMAGE001
。其中,
Figure 649117DEST_PATH_IMAGE002
表示第一子损失,
Figure 909197DEST_PATH_IMAGE003
表示第一数据,
Figure 519170DEST_PATH_IMAGE004
表示真实标签,
Figure 49771DEST_PATH_IMAGE005
表示预设数量的第一标签的平均值。在一些可行的实施例中,还可以计算每一第一数据的真实标签和相应第一标签的平均值的绝对误差或者均方误差作为第一子损失。
步骤S206,根据第一损失值更新第二分类模型的参数以得到中间分类模型。在本实施例中,根据多个第一损失值依次更新第二分类模型的参数,当第二分类模型的参数更新多次后,得到中间分类模型。其中,当第二分类模型的参数根据多个第一损失值更新完成后,第一数据集才算完成一次对第二分类模型参数的更新。
步骤S208,将第二数据集输入中间分类模型以得到第二标签。将第二数据集中的第二数据输入中间分类模型以得到相应的第二标签。可以理解的是,每一第二数据均包括一个第二标签。
步骤S210,根据第二标签和预测标签构建第二损失值。在本实施例中,从第二数据集的第二数据中多次选取预设数量的第二数据。其中,每次选取的第二数据均不重复。即是说,同一次选取的第二数据均不相同,被选取过的第二数据不会再次被选取,所有第二数据均会被选取。但每次选取的第二数据的数量均相同,为预设数量。预设数量可以根据实际情况进行设置,在此不做限定。具体地,每次选取预设数量的第一数据的过程可以称为批处理,预设数量为batch size的大小。计算每次选取的所有第二数据的第二标签的平均值,根据每次选取的第二数据的预测标签和第二标签的平均值构建第二损失值。具体地,根据每一第二数据的预测标签和相应第二标签的平均值构建第二子损失,计算每次选取的所有第二数据的第二子损失的总和作为第二损失值。可以理解的是,根据第二标签和预测标签构建的第二损失值包括多个。其中,若第二数据集中第二数据的数量为m,预设数量为x,构建的第二损失值的数量为m/x。
在本实施例中,根据第二公式计算第二子损失。第二公式为
Figure 352576DEST_PATH_IMAGE006
。其中,
Figure 834373DEST_PATH_IMAGE007
表示第二子损失,
Figure 513616DEST_PATH_IMAGE008
表示第二数据,
Figure 397258DEST_PATH_IMAGE009
表示预测标签,
Figure 870965DEST_PATH_IMAGE010
表示预设数量的第二标签的平均值。在一些可行的实施例中,还可以计算每一第二数据的预测标签和相应第二标签的平均值的绝对误差作为第二子损失。
在一些可行的实施例中,可以直接根据预设数量的第二数据的预测标签和第二标签的平均值构建第二损失值。其中,若根据每一第二数据的预测标签和相应第二标签的平均值构建第二子损失的过程是通过向量的方式进行,则根据预设数量的第二数据的预测标签和第二标签的平均值构建第二损失值的过程可以理解为通过矩阵的方式进行。即是说,可以直接对预设数量的第二数据的预测标签和第二标签的平均值进行批量处理得到第二损失值,而不是计算每一个第二数据所对应的第二子损失之后再求和得到第二损失值,对第二损失值的构建具有良好的鲁棒性,同时还能够缩短计算时间。
步骤S212,根据第二损失值更新中间分类模型的参数。在本实施例中,根据多个第二损失值依次更新中间分类模型的参数。其中,当中间分类模型的参数根据多个第二损失值更新完成后,第二数据集才算完成一次对中间分类模型参数的更新。可以理解的是,当第一数据集更新一次第二分类模型的参数以得到中间分类模型,第二数据集更新一次中间分类模型的参数,中间分类模型才视为训练了一次。
相应地,将第一数据集和带有预测标签的第二数据集两者之一输入中间分类模型中训练,并更新中间分类模型的参数,将两者之另一输入中间分类模型中训练,并更新中间分类模型的参数具体为:先将第一数据集输入中间分类模型中训练,并更新中间分类模型的参数,再将第二数据集输入更新参数后的中间分类模型中训练,并更新中间分类模型的参数。其中,更新中间分类模型参数的具体过程与上述步骤基本一致,在此不再一一赘述。当第一数据集更新一次中间分类模型的参数,第二数据集再更新一次中间分类模型的参数,中间分类模型才视为训练了一次。
上述实施例中,根据第一数据的第一标签和真实标签构建多个第一损失值,根据多个第一损失值依次更新第二分类模型;根据第二数据的第二标签和预测标签构建多个第二损失值,根据多个第二损失值依次更新中间分类模型。批量选取第一数据或者第二数据,根据第一数据第一标签的平均值或者第二数据第二标签的平均值相应构建第一损失值或者第二损失值,从而对模型的参数进行更新,可以有效地纠正模型的优化方向,引导模型向更一般的方向收敛,极大地减少噪声标签的干扰,分散错误的风险,有效地避免模型拟合可能错误的预测标签,从而提高目标分类模型的鲁棒性能。其中,根据第一数据第一标签的平均值或者第二数据第二标签的平均值相应构建第一损失值或者第二损失值,极大提升了半监督学习的效率,比单纯利用均方差损失函数(MSE)或者绝对误差损失函数(MAE)计算损失值更加稳妥、可靠。同时,批量选取第一数据或者第二数据,能够避免一个一个地计算更新过程太慢,或者全部一起计算更新过程太难、内存装不下的问题。
请结合参看图3,其为本发明第一实施例提供的分类模型的训练方法的第二子流程图。步骤S108具体包括如下步骤。
步骤S302,根据第一损失值和第二损失值计算总损失值。在本实施例中,将多个第一损失值和多个第二损失值进行求和以得到总损失值。可以理解的是,中间分类模型每训练一次,都需要计算相应的总损失值。
步骤S304,根据当前总损失值和上一总损失值判断当前总损失值是否不再减小。判断当前总损失值与上一总损失值是否相同,当当前总损失值与上一总损失值相同时,确认当前总损失值不再减小。
步骤S306,当当前总损失值不再减小时,确认中间分类模型满足预设条件。可以理解的是,当当前总损失值不再减小时,中间分类模型收敛。
请结合参看图4,其为本发明第二实施例提供的分类模型的训练方法的子流程图。第二实施例提供的分类模型的训练方法与第一实施例提供的分类模型的训练方法的不同之处在于,第二实施例提供的分类模型的训练方法中,步骤S106具体包括如下步骤。
步骤S402,将第二数据集输入第二分类模型以得到第三标签。将第二数据集中的第二数据输入第二分类模型以得到相应的第三标签。可以理解的是,每一第二数据均具有一个第三标签。
步骤S404,根据第三标签和预测标签构建第三损失值。在本实施例中,从第二数据集的第二数据中多次选取预设数量的第二数据。其中,每次选取的第二数据均不重复。即是说,同一次选取的第二数据均不相同,被选取过的第二数据不会再次被选取,所有第二数据均会被选取。但每次选取的第二数据的数量均相同,为预设数量。预设数量可以根据实际情况进行设置,在此不做限定。具体地,每次选取预设数量的第一数据的过程可以称为批处理,预设数量为batch size的大小。计算每次选取的所有第二数据的第三标签的平均值,根据每次选取的第二数据的预测标签和第三标签的平均值构建第三损失值。具体地,根据每一第二数据的预测标签和相应第三标签的平均值构建第三子损失,计算每次选取的所有第二数据的第三子损失的总和作为第三损失值。可以理解的是,根据第三标签和预测标签构建的第三损失值包括多个。其中,若第二数据集中第二数据的数量为m,预设数量为x,构建的第二损失值的数量为m/x。
在本实施例中,根据第三公式计算第三子损失。第三公式为
Figure 43320DEST_PATH_IMAGE011
。其中,
Figure 995096DEST_PATH_IMAGE012
表示第三子损失,
Figure 998824DEST_PATH_IMAGE008
表示第二数据,
Figure 144897DEST_PATH_IMAGE009
表示预测标签,
Figure 866865DEST_PATH_IMAGE013
表示预设数量的第三标签的平均值。在一些可行的实施例中,还可以计算每一第二数据的预测标签和相应第三标签的平均值的绝对误差作为第三子损失。
在一些可行的实施例中,可以直接根据预设数量的第二数据的预测标签和第三标签的平均值构建第三损失值。其中,若根据每一第二数据的预测标签和相应第三标签的平均值构建第三子损失的过程是通过向量的方式进行,则根据预设数量的第二数据的预测标签和第三标签的平均值构建第三损失值的过程可以理解为通过矩阵的方式进行。即是说,可以直接对预设数量的第二数据的预测标签和第三标签的平均值进行批量处理得到第三损失值,而不是计算每一个第二数据所对应的第三子损失之后再求和得到第三损失值,对第三损失值的构建具有良好的鲁棒性,同时还能够缩短计算时间。
步骤S406,根据第三损失值更新第二分类模型的参数以得到中间分类模型。在本实施例中,根据多个第三损失值依次更新第二分类模型的参数,当第二分类模型的参数更新多次后,得到中间分类模型。其中,当第二分类模型的参数根据多个第三损失值更新完成后,第二数据集才完成一次对第二分类模型参数的更新。
步骤S408,将第一数据集输入中间分类模型以得到第四标签。将第一数据集中的第一数据输入中间分类模型以得到相应的第四标签。可以理解的是,每一第一数据均具有一个第四标签。
步骤S410,根据第四标签和真实标签构建第四损失值。在本实施例中,从第一数据集的第一数据中多次选取预设数量的第一数据。其中,每次选取的第一数据均不重复。即是说,同一次选取的第一数据均不相同,被选取过的第一数据不会再次被选取,所有第一数据均会被选取。但每次选取的第一数据的数量均相同,为预设数量。预设数量可以根据实际情况进行设置,在此不做限定。具体地,每次选取预设数量的第一数据的过程可以称为批处理,预设数量为batch size的大小。计算每次选取的所有第一数据的第四标签的平均值,根据每次选取的第一数据的真实标签和第四标签的平均值构建第四损失值。具体地,根据每一第一数据的真实标签和相应第四标签的平均值构建第四子损失,计算每次选取的所有第一数据的第四子损失的总和作为第四损失值。可以理解的是,根据第四标签和真实标签构建的第四损失值包括多个。其中,若第一数据集中第一数据的数量为n,预设数量为x,构建的第一损失值的数量为n/x。
在本实施例中,根据第四公式计算第四子损失。第四公式
Figure 622332DEST_PATH_IMAGE014
。其中,
Figure 214987DEST_PATH_IMAGE015
表示第四子损失,
Figure 233759DEST_PATH_IMAGE003
表示第一数据,
Figure 177444DEST_PATH_IMAGE004
表示真实标签,
Figure 736601DEST_PATH_IMAGE016
表示预设数量的第四标签的平均值。在一些可行的实施例中,还可以计算每一第一数据的真实标签和相应第四标签的平均值的绝对误差或者均方误差作为第四子损失。
步骤S412,根据第四损失值更新中间分类模型的参数。在本实施例中,根据多个第四损失值依次更新中间分类模型的参数。其中,当中间分类模型的参数根据多个第四损失值更新完成后,第一数据集才算完成一次对中间分类模型参数的更新。可以理解的是,当第二数据集更新一次第二分类模型的参数以得到中间分类模型,第一数据集更新一次中间分类模型的参数,中间分类模型才视为训练了一次。
相应地,将第一数据集和带有预测标签的第二数据集两者之一输入中间分类模型中训练,并更新中间分类模型的参数,将两者之另一输入中间分类模型中训练,并更新中间分类模型的参数具体为:先将第二数据集输入中间分类模型中训练,并更新中间分类模型的参数,再将第一数据集输入更新参数后的中间分类模型中训练,并更新中间分类模型的参数。其中,更新中间分类模型参数的具体过程与上述步骤基本一致,在此不再一一赘述。当第二数据集更新一次中间分类模型的参数,第一数据集再更新一次中间分类模型的参数,中间分类模型才视为训练了一次。
相应地,根据第三损失值和第四损失值计算总损失值。具体地,将多个第三损失值和多个第四损失值进行求和以得到总损失值。根据总损失值判断中间分类模型是否满足预设条件的具体过程与第一实施例的基本一致,在此不再赘述。
第二实施例提供的分类模型的训练方法的其它步骤与第一实施例提供的分类模型的训练方法的基本一致,在此不再赘述。
请结合参看图5,其为本发明第三实施例提供的分类模型的训练方法的子流程图。第三实施例提供的分类模型的训练方法与第一实施例提供的分类模型的训练方法的不同之处在于,第三实施例提供的分类模型的训练方法中,步骤S108具体包括如下步骤。
步骤S502,根据验证数据集计算中间分类模型的准确率。其中,验证数据集中的验证数据包括正确标签。可以理解的是,验证数据集中的每一验证数据均包括一个正确标签,每一验证数据的正确标签与若干预设类别中的一个类别相对应。将验证数据集输入中间分类模型以得到验证标签,根据验证标签和正确标签计算准确率。具体地,将验证数据集中的验证数据输入中间分类模型以得到相应的验证标签,判断同一验证数据的验证标签和正确标签是否相同。统计同一验证数据的验证标签和正确标签相同的数量作为正确数量,根据正确数量和验证数据的数量计算中间分类模型的准确率。可以理解的是,中间分类模型每训练一次,都需要计算相应的准确率。
步骤S504,判断中间分类模型的准确率是否达到预设值。在本实施例中,预设值可以根据实际情况进行设置。优选地,预设值大于或者等于90%。
步骤S506,当中间分类模型的准确率达到预设值时,确认中间分类模型满足预设条件。可以理解的是,当中间分类模型的准确率达到预设值时,中间分类模型收敛。
第三实施例提供的分类模型的训练方法的其它步骤与第一实施例提供的分类模型的训练方法的基本一致,在此不再赘述。
请结合参看图6,其为本发明实施例提供的终端的内部结构示意图。终端10包括计算机可读存储介质11、处理器12以及总线13。其中,计算机可读存储介质11至少包括一种类型的可读存储介质,该可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或DX存储器等)、磁性存储器、磁盘、光盘等。计算机可读存储介质11在一些实施例中可以是终端10的内部存储单元,例如终端10的硬盘。计算机可读存储介质11在另一些实施例中也可以是终端10的外部存储设备,例如终端10上配备的插接式硬盘、智能存储卡(Smart MediaCard,SMC)、安全数字(Secure Digital,SD)卡、闪存卡(Flash Card)等。进一步地,计算机可读存储介质11还可以既包括终端10的内部存储单元也包括外部存储设备。计算机可读存储介质11不仅可以用于存储安装于终端10的应用软件及各类数据,还可以用于暂时地存储已经输出或者将要输出的数据。
总线13可以是外设部件互连标准(peripheral component interconnect,PCI)总线或扩展工业标准结构(extended industry standard architecture,EISA)总线等。该总线可以分为地址总线、数据总线、控制总线等。为便于表示,图6中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
进一步地,终端10还可以包括显示组件14。显示组件14可以是发光二极管(LightEmitting Diode,LED)显示器、液晶显示器、触控式液晶显示器以及有机发光二极管(Organic Light-Emitting Diode,OLED)触摸器等。其中,显示组件14也可以适当的称为显示装置或显示单元,用于显示在终端10中处理的信息以及用于显示可视化的用户界面。
进一步地,终端10还可以包括通信组件15。通信组件15可选地可以包括有线通信组件和/或无线通信组件,如WI-FI通信组件、蓝牙通信组件等,通常用于在终端10与其他智能控制设备之间建立通信连接。
处理器12在一些实施例中可以是一中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器或其他数据处理芯片,用于运行计算机可读存储介质11中存储的程序代码或处理数据。具体地,处理器12执行处理程序以控制终端10实现分类模型的训练方法。
图6仅示出了具有组件11-15、用于实现分类模型的训练方法的终端10,本领域技术人员可以理解的是,图6示出的结构并不构成对终端10的限定,终端10可以包括比图示更少或者更多的部件,或者组合某些部件,或者不同的部件布置。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘且本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
以上所列举的仅为本发明较佳实施例而已,当然不能以此来限定本发明之权利范围,因此依本发明权利要求所作的等同变化,仍属于本发明所涵盖的范围。

Claims (10)

1.一种分类模型的训练方法,其特征在于,所述分类模型的训练方法包括:
根据第一数据集训练第一分类模型以得到预测分类模型,其中,所述第一数据集为有标签数据的集合;
将第二数据集输入所述预测分类模型进行计算以得到预测标签,其中,所述第二数据集为无标签数据的集合;
将所述第一数据集和带有预测标签的第二数据集两者之一输入第二分类模型中训练,并更新所述第二分类模型的参数以得到中间分类模型,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数;
判断所述中间分类模型是否满足预设条件;
当所述中间分类模型没有满足预设条件时,将所述第一数据集和带有预测标签的第二数据集两者之一输入所述中间分类模型中训练,并更新所述中间分类模型的参数,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数,直至所述中间分类模型满足预设条件;以及
当所述中间分类模型满足预设条件时,将满足预设条件的中间分类模型作为目标分类模型。
2.如权利要求1所述的分类模型的训练方法,其特征在于,将所述第一数据集和带有预测标签的第二数据集两者之一输入第二分类模型中训练,并更新所述第二分类模型的参数以得到中间分类模型,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数具体包括:
将所述第一数据集输入所述第二分类模型以得到第一标签,其中,所述第一数据集中的第一数据包括真实标签;
根据所述第一标签和所述真实标签构建第一损失值;
根据所述第一损失值更新所述第二分类模型的参数以得到所述中间分类模型;
将所述第二数据集输入所述中间分类模型以得到第二标签;
根据所述第二标签和所述预测标签构建第二损失值;以及
根据所述第二损失值更新所述中间分类模型的参数。
3.如权利要求2所述的分类模型的训练方法,其特征在于,根据所述第一标签和所述真实标签构建第一损失值具体包括:
从所述第一数据集的第一数据中多次选取预设数量的第一数据,其中,每次选取的第一数据均不重复;
计算每次选取的所有第一数据的第一标签的平均值;以及
根据每次选取的第一数据的真实标签和第一标签的平均值构建所述第一损失值。
4.如权利要求2所述的分类模型的训练方法,其特征在于,根据所述第二标签和所述预测标签构建第二损失值具体包括:
从所述第二数据集的第二数据中多次选取预设数量的第二数据,其中,每次选取的第二数据均不重复;
计算每次选取的所有第二数据的第二标签的平均值;以及
根据每次选取的第二数据的预测标签和第二标签的平均值构建所述第二损失值。
5.如权利要求2所述的分类模型的训练方法,其特征在于,判断所述中间分类模型是否满足预设条件具体包括:
根据所述第一损失值和所述第二损失值计算总损失值;
根据当前总损失值和上一总损失值判断当前总损失值是否不再减小;以及
当当前总损失值不再减小时,确认所述中间分类模型满足预设条件。
6.如权利要求1所述的分类模型的训练方法,其特征在于,将所述第一数据集和带有预测标签的第二数据集两者之一输入第二分类模型中训练,并更新所述第二分类模型的参数以得到中间分类模型,将两者之另一输入所述中间分类模型中训练,并更新所述中间分类模型的参数具体包括:
将所述第二数据集输入所述第二分类模型以得到第三标签;
根据所述第三标签和所述预测标签构建第三损失值;
根据所述第三损失值更新所述第二分类模型的参数以得到所述中间分类模型;
将所述第一数据集输入所述中间分类模型以得到第四标签,其中,所述第一数据集中的第一数据包括真实标签;
根据所述第四标签和所述真实标签构建第四损失值;以及
根据所述第四损失值更新所述中间分类模型的参数。
7.如权利要求1所述的分类模型的训练方法,其特征在于,判断所述中间分类模型是否满足预设条件具体包括:
根据验证数据集计算所述中间分类模型的准确率;
判断所述中间分类模型的准确率是否达到预设值;以及
当所述中间分类模型的准确率达到预设值时,确认所述中间分类模型满足预设条件。
8.如权利要求7所述的分类模型的训练方法,其特征在于,根据验证数据集计算所述中间分类模型的准确率具体包括:
将所述验证数据集输入所述中间分类模型以得到验证标签,其中,所述验证数据集中的验证数据包括正确标签;以及
根据所述验证标签和所述正确标签计算所述准确率。
9.如权利要求1所述的分类模型的训练方法,其特征在于,所述第一分类模型的结构和所述第二分类模型的结构相同,所述第一分类模型的参数和所述第二分类模型的参数不同。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储程序指令,所述程序指令可被处理器执行以实现如权利要求1至9中任一项所述的分类模型的训练方法。
CN202210069084.9A 2022-01-21 2022-01-21 分类模型的训练方法及计算机可读存储介质 Pending CN114118301A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210069084.9A CN114118301A (zh) 2022-01-21 2022-01-21 分类模型的训练方法及计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210069084.9A CN114118301A (zh) 2022-01-21 2022-01-21 分类模型的训练方法及计算机可读存储介质

Publications (1)

Publication Number Publication Date
CN114118301A true CN114118301A (zh) 2022-03-01

Family

ID=80361119

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210069084.9A Pending CN114118301A (zh) 2022-01-21 2022-01-21 分类模型的训练方法及计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN114118301A (zh)

Similar Documents

Publication Publication Date Title
CN108304775B (zh) 遥感图像识别方法、装置、存储介质以及电子设备
CN108830329B (zh) 图片处理方法和装置
KR20210062687A (ko) 이미지 분류 모델 훈련 방법, 이미지 처리 방법 및 장치
CN113468338A (zh) 针对数字化云业务的大数据分析方法及大数据服务器
CN113869464B (zh) 图像分类模型的训练方法及图像分类方法
CN114021670A (zh) 分类模型的学习方法及终端
CN110781818A (zh) 视频分类方法、模型训练方法、装置及设备
CN111914949B (zh) 基于强化学习的零样本学习模型的训练方法及装置
CN117349424A (zh) 应用于语言模型的提示模板的处理方法、装置及电子设备
CN114139658A (zh) 分类模型的训练方法及计算机可读存储介质
CN112287140A (zh) 一种基于大数据的图像检索方法及系统
CN112418443A (zh) 基于迁移学习的数据处理方法、装置、设备及存储介质
WO2020088338A1 (zh) 一种建立识别模型的方法及装置
CN114118301A (zh) 分类模型的训练方法及计算机可读存储介质
CN111709475A (zh) 一种基于N-grams的多标签分类方法及装置
CN111143568A (zh) 一种论文分类时的缓冲方法、装置、设备及存储介质
CN110852261A (zh) 目标检测方法、装置、电子设备和可读存储介质
US20230137639A1 (en) Data processing system and method for operating an enterprise application
CN112949590B (zh) 一种跨域行人重识别模型构建方法及构建系统
CN113989596B (zh) 图像分类模型的训练方法及计算机可读存储介质
US11860769B1 (en) Automatic test maintenance leveraging machine learning algorithms
CN113591979A (zh) 行业类目识别方法、设备、介质及计算机程序产品
CN108960291B (zh) 一种基于并行化Softmax分类的图像处理方法和系统
CN113868240B (zh) 数据清洗方法及计算机可读存储介质
CN113936141B (zh) 图像语义分割方法及计算机可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20220301

RJ01 Rejection of invention patent application after publication