CN116524289A - 一种模型训练方法及相关系统 - Google Patents

一种模型训练方法及相关系统 Download PDF

Info

Publication number
CN116524289A
CN116524289A CN202210074265.0A CN202210074265A CN116524289A CN 116524289 A CN116524289 A CN 116524289A CN 202210074265 A CN202210074265 A CN 202210074265A CN 116524289 A CN116524289 A CN 116524289A
Authority
CN
China
Prior art keywords
model
label
target
image
loss
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210074265.0A
Other languages
English (en)
Inventor
韩承志
吴学文
王小辉
李冠彬
王阔
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huawei Technologies Co Ltd
Original Assignee
Huawei Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huawei Technologies Co Ltd filed Critical Huawei Technologies Co Ltd
Priority to CN202210074265.0A priority Critical patent/CN116524289A/zh
Publication of CN116524289A publication Critical patent/CN116524289A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Image Analysis (AREA)

Abstract

本申请提供了一种模型训练方法,应用于人工智能(AI)技术领域,包括:获取训练数据集,对训练数据集中的无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像,通过第一模型和弱增强图像对无标签图像中的目标进行预测获得第一标签,通过第二模型和强增强图像、第一标签对无标签图像中的目标进行预测,获得预测结果,预测结果包括目标的候选框的预测位置和所述目标的预测类别,通过第一模型和弱增强图像、候选框的预测位置对无标签图像中的目标进行再次预测获得第二标签,根据第二标签更新第二模型的参数,以训练第二模型。如此,通过二次检验解决了伪标签质量较低导致AI模型精度或召回率不足的问题。

Description

一种模型训练方法及相关系统
技术领域
本申请涉及人工智能(artificial intelligence,AI)技术领域,尤其涉及一种模型训练方法、模型训练系统以及计算设备集群、计算机可读存储介质、计算机程序产品。
背景技术
随着算力的提升以及大量标注数据的出现,以数据驱动的AI技术尤其是深度学习(deep learning,DL)技术在很多领域取得了较大的进展。然而,标注数据非常耗时耗力,尤其是针对目标检测任务,一张图像通常要标注多个目标。为此,业界提出了半监督学习(semi-supervised learning,SSL),以充分利用无标注数据。
半监督学习是一种将监督学习和无监督学习结合的机器学习方法,该方法将少量的标注数据和大量的无标注数据结合,用于训练AI模型。具体地,先为无标注数据生成伪标签,然后将无标注数据和伪标签组合形成伪标注数据,然后基于标注数据和伪标注数据,采用监督学习算法训练AI模型。
当前一个较为有效的范式是基于一致性的思想,使用Teacher-Student架构来进行半监督学习。以目标检测任务为例,首先可以对无标注数据(例如为无标签图像)分别进行强弱增强,得到强增强图像和弱增强图像,Teacher模型读入弱增强图像,并根据预测结果以及一个设置好的阈值来筛选高置信度的伪标签。然后,Student模型结合强增强图像以及伪标签,来进行监督训练。训练过程中,Student模型根据标注数据和伪标注数据计算损失值loss并根据该loss更新参数,而Teacher模型则是在Student模型更新完参数之后,读入Student的模型参数,然后通过指数移动平均(Exponential Moving Average,EMA)参数更新策略更新自身的参数。
然而,基于伪标签的半监督学习对于伪标签的质量要求较高,伪标签的质量直接影响训练得到的AI模型的精度和/或召回率。
发明内容
本申请提供了一种模型训练方法,该方法从“对抗低质量的伪标签”的角度出发,提供了一种二次检验机制,以解决伪标签质量较低影响AI模型的精度和/或召回率的问题。相应地,本申请还提供了模型训练系统、计算设备集群、计算机可读存储介质以及计算机程序产品。
第一方面,本申请提供了一种模型训练方法。该方法可以由模型训练系统执行。在一些实施例中,模型训练系统可以是软件系统。计算设备或计算设备集群通过运行该软件系统的程序代码,以执行模型训练方法。在另一些实施例中,该模型训练系统也可以是用于训练AI模型的硬件系统。
具体地,模型训练系统获取训练数据集,该训练数据集包括无标签图像,对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像,然后模型训练系统通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果,该预测结果包括所述目标的候选框的预测位置和所述目标的预测类别,接着模型训练系统通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签,根据所述第二标签更新所述第二模型的参数,以训练所述第二模型。
在该方法中,模型训练系统使用二次检验,以纠正利用第一标签(也即伪标签)进行模型训练时由于伪标签质量不足导致的候选分类标注错误,从根本上解决伪标签质量不足所带来的不良影响。其中,该方法所使用的第二标签(也即软标签)可以打破类别(包括前景类和背景类)之间的对立性,使得图像中外观相似的目标也能在训练中发挥正面作用。
在一些可能的实现方式中,模型训练系统可以按照所述候选框的预测位置裁剪所述弱增强图像,将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分。然后模型训练系统根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成第二标签,所述K大于1。
如此,可以打破类别之间的对立性,使得相似的目标(如电视机、笔记本电脑)也能在训练中发挥正面作用,由此提高训练得到的模型的精度。
在一些可能的实现方式中,所述各个类别包括预设的N个前景类别和一个背景类别中的多个类别,所述N为正整数。如此,可以打破前景类和背景类之间的对立性,以便于将标注错误的背景候选框纠正为前景候选框,或者将标注错误的前景候选框纠正为背景候选框或其他前景类,进而提高训练得到的模型的精度。
在一些可能的实现方式中,所述第二模型包括区域候选网络RPN和感兴趣区域ROI网络,模型训练系统可以根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述第二标签确定所述ROI网络的无监督分类损失,然后根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失,接着模型训练系统可以根据所述第二模型的损失更新所述第二模型的参数。
由于第二标签包括多个可能的类别,相较于第一标签,具有较高的准确度,而确定目标归属的前景类,通常是由ROI网络实现,因此,基于第一标签确定RPN的无监督分类损失、无监督回归损失以及ROI网络的无监督回归损失,基于第二标签确定ROI网络的无监督分类损失具有较高的准确度,基于准确度较高的损失更新第二模型的参数可以使得第二模型快速收敛,提高训练效率。
在一些可能的实现方式中,第二标签包括多个类别,模型训练系统可以根据所述第二标签和所述预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失。由于不仅仅计算在某一类别上的损失,而是对在多个类别上的损失进行累计,因而具有较高准确度。
在一些可能的实现方式中,所述无标签图像中包括多个所述目标,针对多个所述目标中的每个目标,模型训练系统可以根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重。相应地,针对多个所述目标中的每个目标,模型训练系统可以根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失,然后根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。
该方法通过引入TOP-K权重机制确定ROI网络的无监督分类损失,可以有效降低二次检验后得到难以分辨前景或背景的候选框的有害影响,从而保障训练得到的模型的精度。
在一些可能的实现方式中,模型训练系统可以根据所述第一模型预测的所述目标归属于各个类别的评分中排名前K的评分之和,更新所述第一标签。换言之,模型训练系统可以根据TOP-K评分之和重新筛选第一标签,由此可以提高第一标签的召回率。
相应地,模型训练系统可以根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失,根据更新后的所述第一标签和所述目标的预测类别确定所述ROI网络的无监督回归损失。如此可以进一步提高损失的准确度,使得第二模型能够尽快收敛。
在一些可能的实现方式中,所述第一模型为教师模型,所述第二模型为学生模型。其中,教师模型和学生模型的输出具有一致性,因此,模型训练系统可以将教师模型的输出作为学生模型的监督信息进行监督学习,学生模型可以将监督学习的结果如候选框的预测位置再反馈到教师模型进行二次检验,由此解决伪标签质量较低,导致标注错误的问题。
在一些可能的实现方式中,模型训练系统可以根据更新后的所述第二模型的参数,更新所述第一模型的参数。具体地,模型训练系统可以根据更新后的第二模型的参数,通过指数移动平均参数更新策略,更新第一模型的参数。如此可以实现对第一模型的高效训练。
第二方面,本申请提供了一种模型训练系统。所述系统包括:
通信模块,用于获取训练数据集,所述训练数据集包括无标签图像;
预处理模块,用于对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像;
标签管理模块,用于通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果,所述预测结果包括所述目标的候选框的预测位置和所述目标的预测类别;
所述标签管理模块,还用于通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签;
训练模块,用于根据所述第二标签更新所述第二模型的参数,以训练所述第二模型。
在一些可能的实现方式中,所述标签管理模块具体用于:
按照所述候选框的预测位置裁剪所述弱增强图像,将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分;
根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成第二标签,所述K大于1。
在一些可能的实现方式中,所述各个类别包括预设的N个前景类别和一个背景类别中的多个类别,所述N为正整数。
在一些可能的实现方式中,所述第二模型包括区域候选网络RPN和感兴趣区域ROI网络,所述训练模块具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述第二标签确定所述ROI网络的无监督分类损失;
根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失;
根据所述第二模型的损失更新所述第二模型的参数。
在一些可能的实现方式中,所述训练模块具体用于:
根据所述第二标签和所述预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失。
在一些可能的实现方式中,所述无标签图像中包括多个所述目标,所述训练模块还用于:
针对多个所述目标中的每个目标,根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重;
所述训练模块具体用于:
针对多个所述目标中的每个目标,根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失;
根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。
在一些可能的实现方式中,所述标签管理模块还用于:
根据所述第一模型预测的所述目标归属于各个类别的评分中排名前K的评分之和,更新所述第一标签;
所述训练模块具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失,根据更新后的所述第一标签和所述目标的预测类别确定所述ROI网络的无监督回归损失。
在一些可能的实现方式中,所述第一模型为教师模型,所述第二模型为学生模型。
在一些可能的实现方式中,所述训练模块还用于:
根据更新后的所述第二模型的参数,更新所述第一模型的参数。
第三方面,本申请提供一种计算设备集群。所述计算设备集群包括至少一台计算设备,所述至少一台计算设备包括至少一个处理器和至少一个存储器。所述至少一个处理器、所述至少一个存储器进行相互的通信。所述至少一个处理器用于执行所述至少一个存储器中存储的指令,以使得计算设备或计算设备集群执行如第一方面或第一方面的任一种实现方式中的模型训练方法。
第四方面,本申请提供一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,所述指令指示计算设备或计算设备集群执行上述第一方面或第一方面的任一种实现方式所述的模型训练方法。
第五方面,本申请提供了一种包含指令的计算机程序产品,当其在计算设备或计算设备集群上运行时,使得计算设备或计算设备集群执行上述第一方面或第一方面的任一种实现方式所述的模型训练方法。
本申请在上述各方面提供的实现方式的基础上,还可以进行进一步组合以提供更多实现方式。
附图说明
为了更清楚地说明本申请实施例的技术方法,下面将对实施例中所需使用的附图作以简单地介绍。
图1为本申请实施例提供的一种模型训练系统的架构示意图;
图2为本申请实施例提供的一种任务配置界面的示意图;
图3为本申请实施例提供的一种软标签的示意图;
图4为本申请实施例提供的一种模型训练方法的流程图;
图5为本申请实施例提供的一种模型训练方法的流程示意图;
图6为本申请实施例提供的一种计算设备集群的结构示意图。
具体实施方式
本申请实施例中的术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。
首先对本申请实施例中所涉及到的一些技术术语进行介绍。
深度学习(deep learning,DL),是机器学习的一个分支,具体是一种以深度神经网络为架构,对数据(也可以称作观测值,例如可以为图像)进行表征学习的算法。深度学习的基础是机器学习中的分散表示(distributed representation)。分散表示假定观测值是由不同因子相互作用生成。在此基础上,深度学习进一步假定这一相互作用的过程可分为多个层次,代表对观测值的多层抽象。不同的层数和层的规模可用于不同程度的抽象。相较于浅层神经网络,深度神经网络能够提供更高的抽象层次,因而提高了基于该深度神经网络的AI模型的能力。
深度学习可以采用非监督式或半监督式的特征学习和分层特征提取高效算法来替代手工获取特征。其中,非监督式的特征学习也称作非监督学习、无监督学习,无监督学习是一种从无标注数据中学习特征的机器学习算法。半监督式的特征学习也称作半监督学习,半监督学习是一种将监督学习和无监督学习结合的机器学习算法。监督学习是一种从标注数据中学习特征的机器学习算法。相应地,半监督学习是一种将监督学习和无监督学习结合,采用少量的标注数据和大量的无标注数据训练AI模型的算法。
半监督学习通常可以用于目标检测任务。目标检测任务是指从图像中检测出目标的任务。图像可以包括前景和背景,前景是指图像中目标检测任务关注的、感兴趣的部分,背景是指图像中目标检测任务不关注的、不感兴趣的部分。其中,目标可以是自然界中的物体,该物体可以是生物体或非生物体。在配置目标检测任务时,可以配置需要检测的物体的类别,以便于后续按照配置的类别,检测相应的物体。以在沙滩边拍摄的一张图像为例,目标检测任务检测到的目标可以包括人、狗、猫等生物体,以及帐篷、滑板等非生物体。
针对目标检测任务,无标注数据可以是无标签图像,标注数据可以是标签图像。在训练阶段,可以对无标签图像分别进行强弱增强,得到强增强图像和弱增强图像,一个模型(为了便于描述,有些情况下也称之为第一模型)读入弱增强图像,并根据预测结果以及一个设置好的阈值来筛选高置信度的伪标签。然后,另一个模型(为了便于描述,有些情况下也称之为第二模型)结合强增强图像以及伪标签,来进行监督训练。训练过程中,第二模型根据标签图像和伪标签图像计算损失值loss并根据该loss更新参数,而第一模型则是在第二模型更新完参数之后,读入第二模型的参数,然后基于第二模型的参数更新自身的参数。
上述基于伪标签的半监督学习对于伪标签的质量要求较高,伪标签的质量直接影响训练得到的AI模型(例如上述第一模型、第二模型)的精度和/或召回率。然而,一张图像内含有多个物体,每个物体都有自己的类别和位置,这使得生成高质量的伪标签变得十分困难。另外,由于标签图像是有限的,训练出来的AI模型本身的能力有限,导致生成高质量的伪标签难上加难。
业内针对“提升伪标签的质量”进行了大量探索,例如可以采用自适应的阈值生成伪标签,或者是为目标检测任务中的分类任务和回归任务分别生成伪标签。即使采用了上述方法,生成的伪标签依然存在着很多的错误。这些错误主要分为两部分:伪标签的位置不够准确;伪标签的召回率过低。
在进行模型训练时,模型可以从为图像生成的锚“anchor”中提取候选框,根据伪标签(具体是伪标签中表征目标位置的框)与候选框的交并比(Intersection-over-Union,IoU)为候选框打标签,从而进行监督训练。其中,伪标签的位置不准确可以导致候选框被错误地标注为某一前景类;伪标签的召回率不足可以导致候选框被错误地标注为背景类。
有鉴于此,本申请实施例从“对抗低质量的伪标签”的角度出发,提供了一种基于二次检验机制的模型训练方法,以解决上述问题。该方法可以由模型训练系统(一些情况下,也可以简称为训练系统、训练平台)执行。在一些实施例中,模型训练系统可以是软件系统。计算设备或计算设备集群通过运行该软件系统的程序代码,以执行模型训练方法。在另一些实施例中,该模型训练系统也可以是用于训练AI模型的硬件系统。本申请实施例以模型训练系统为软件系统进行示例说明。
参见图1所示的模型训练系统的架构示意图,模型训练系统100部署在云环境10中,云环境10指示云服务提供商拥有的,用于提供计算、存储、通信资源的中心计算设备集群。模型训练系统100具体可以部署在中心计算设备集群的一个或多个计算设备(例如:中心服务器)中。
用户(例如:AI模型的开发者)可以通过终端20访问云环境10中的模型训练系统100,并通过与模型训练系统100交互,如创建模型训练任务,使得模型训练系统100执行模型训练方法。其中,模型训练系统100可以从数据库300中获取训练数据集,进而根据该训练数据集进行模型训练。
在本实施例中,终端20为具有交互功能的设备。例如终端20可以是具有界面交互功能的设备,该设备包括但不限于台式机、笔记本电脑、平板电脑或者智能手机。参见图2所示的任务配置界面的示意图,任务配置界面200承载有数据集配置控件202和模型结构配置控件204。
其中,数据集配置控件202用于配置训练数据集。具体地,模型训练系统100可以提供多种训练数据集,例如包括多种开源的训练数据集,或者是自定义的训练数据集,用户可以从中选择一种或多种训练数据集用于模型训练。模型结构配置控件204用于配置待训练的AI模型的结构。待训练的AI模型的结构可以通过层级数据格式(Hierarchical DataFormat,HDF)文件,如HDF5文件表示。用户可以通过模型结构配置控件204以浏览文件系统方式,从文件系统中选择相应的HDF5文件,从而实现配置模型结构。
在一些可能的实现方式中,任务配置界面200还可以包括用于配置超参数的控件,例如任务配置界面可以包括优化器配置控件205和学习率配置控件206,其中,优化器配置控件205用于配置训练AI模型所采用的优化器,学习率配置控件206用于配置训练AI模型所采用的学习率。在一些实施例中,用户也可以不配置上述优化器或学习率,相应地,模型训练系统100可以通过默认的超参数进行模型训练。
在本实施例中,任务配置界面200还承载有确定控件207和取消控件208。当用户触发确定控件207时,终端20可以向云环境10中的模型训练系统100发送任务创建请求,以便于模型训练系统100根据任务创建请求中的任务配置信息,创建相应的训练任务,并执行该训练任务,以实现模型训练。
其中,模型训练系统100包括通信模块102、预处理模块104、标签管理模块106和训练模块108。下面对模型训练系统100的各个功能模块分别进行介绍。
通信模块102用于获取训练数据集。例如通信模块102可以根据接收的任务配置信息中训练数据集的标识,从数据库30中获取与该标识匹配的训练数据集。该训练数据集包括无标签图像。在一些实施例中,训练数据集包括少量的标签图像和大量的无标签图像。其中,标签图像标注有目标的位置和类别。
预处理模块104用于对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像。其中,强增强和弱增强是对图像进行不同程度的增强处理。弱增强通常是基于标准的翻转或平移实现。例如,弱增强可以包括随机水平翻转(Random HorizontalFlip)或者随机水平和垂直移动(random translation)。强增强通常是随机删除图像的正方形部分,并用灰色或黑色填充。强增强通常可以通过RandAugment或CTAugment处理,然后应用CutOut函数实现。
标签管理模块106用于通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果。在图1的示例中,第一模型可以为教师(Teacher)模型,第二模型可以为学生(Student)模型。在本申请实施例其他可能的实现方式中,第一模型和第二模型也可以是能够使得输出具有一致性的其他模型,而不局限于Teacher-Student架构。为了便于理解,下文以Teacher-Student架构进行示例说明。
具体地,标签管理模块106可以调用Teacher模型对弱增强图像进行推理,得到该图像中所有的目标(前景物体)的位置和类别向量。其中,类别向量包括多个元素值,每个元素值表示该目标归属于一个类别(例如为前景类)的评分。标签管理模块106可以根据评分由高至低排序第一即TOP-1的评分和第一阈值筛选生成第一标签。该第一标签也称作伪标签。一个伪标签中包括目标归属的一个类别,目标归属于该类别的评分大于第一阈值,具有较高置信度。
由于强增强图像和弱增强图像是对同一无标签图像进行处理得到。标签管理模块106可以将基于弱增强图像筛选的伪标签作为强增强图像的监督信息,将所述强增强图像输入Student模型,结合所述监督信息进行监督学习,从而实现对所述无标签图像中的所述目标进行预测,获得预测结果。该预测结果包括所述目标的候选框的预测位置和所述目标的预测类别。
进一步地,标签管理模块106还用于通过所述Teacher模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签。为了便于描述,第二标签也称作软标签。具体地,标签管理模块106可以按照Student模型预测得到的候选框的预测位置裁剪弱增强图像,将裁剪后的图像输入所述第一模型Teacher模型进行二次检验,获得软标签。
该软标签可以纠正伪标签中候选框的标注错误。如果伪标签对目标的候选框的标注是正确的,那么新的软标签将会和原来的标注别无二异;如果伪标签对目标的候选框的标注是错误的,那么新的软标签可以纠正这些错误。具体地,标签管理模块106将裁剪后的图像输入Teacher模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分,当TOP-1的评分大于第一阈值时,新的软标签可以与伪标签一致,当TOP1的评分不大于第一阈值时,可以根据目标属于各个类别的评分中由高至低排序的前K个评分(即TOP-K的评分)对应的类别生成软标签。
为了便于理解,下面结合一示例进行说明。参见图3所示的软标签的示意图,无标签图像中包括外观相似的多个目标,例如为目标301和目标302。目标301的软标签可以为目标301的类别为LAPTOP或者TV,目标302的软标签可以为目标302的类别为LAPTOP或者TV。进一步地,软标签还可以包括置信度。例如,目标301的类别为LAPTOP或者TV的置信度为0.9。
训练模块108用于根据软标签更新Student模型的参数,以训练所述Student模型。其中,训练模块108可以根据软标签和Student模型的预测结果中的预测类别确定Student模型的损失,根据该Student模型的损失更新Student模型的参数,从而实现训练所述Student模型。进一步地,训练模块108还用于根据更新后的所述Student模型的参数,更新所述Teacher模型的参数。
需要说明的是,图1以模型训练系统100部署在云环境10中进行示例说明。在一些可能的实现方式中,模型训练系统100也可以部署在边缘环境或是部署在终端20中。其中,边缘环境指示在地理位置上距离终端20(即端侧设备)较近的,用于提供计算、存储、通信资源的边缘计算设备集群。边缘计算设备集群包括一个或多个边缘计算设备,该边缘计算设备可以为服务器、计算盒子等。
还需要说明的是,图1所示的模型训练系统100的各个模块是按照功能划分,在本申请实施例其他可能的实现方式中,模型训练系统100也可以划分为不同模块,模型训练系统100的各个模块还可以分布式地部署在不同环境中。
接下来,将结合附图从模型训练系统100的角度对本申请实施例提供的模型训练方法进行详细说明。
参见图4所示的模型训练方法的流程图,该方法包括:
S402:模型训练系统100获取训练数据集。
所述训练数据集包括无标签图像。其中,训练数据集还包括标签图像,标签图像中标注有目标的位置和类别。其中,目标的位置可以通过标注框的位置标注。在一些实施例中,标注框可以是包围目标的矩形框。相应地,标注框的位置可以通过标注框位于对角的两个顶点的坐标表征,或者是通过标注框的中心点与其中一个顶点的坐标表征。
训练数据集中无标签图像和标签图像的比例可以根据需求设置。考虑到目标检测任务中,一张图像中通常包括较多的目标,导致标注工作量大且复杂度高,训练数据集可以配置少量的标签图像和大量的无标签图像。
模型训练系统100可以从数据库中获取训练数据集。该训练数据集可以是开源的数据集,也可以是自定义的数据集。本实施例对此不作限制。
S404:模型训练系统100对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像。
强增强和弱增强是对图像进行不同程度的增强处理。弱增强通常是基于标准的翻转或平移实现。例如,弱增强可以通过随机水平翻转实现,或者通过随机水平和垂直移动实现。强增强通常是随机删除图像的正方形部分,并用灰色或黑色填充。强增强通常可以通过RandAugment或CTAugment处理,然后应用CutOut实现。考虑到PyTorch没有内置的Cutout函数,模型训练系统100可以通过重用其RandomErasing函数来达到CutOut的效果。
具体地,模型训练系统100可以针对训练数据集的每张无标签图像,分别进行强增强和弱增强,从而得到对应的强增强图像和弱增强图像。由于强增强图像和弱增强图像是对同一张无标签图像进行增强处理得到,基于此,基于强增强图像检测的目标和基于弱增强图像检测的目标的类别具有一致性。模型训练系统100可以借助该一致性训练AI模型。
S406:模型训练系统100通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得伪标签。
具体地,模型训练系统100可以将弱增强图像输入第一模型,通过第一模型学习弱增强图像中的特征,并基于该特征进行推理,从而实现预测该弱增强图像对应的无标签图像中目标的位置和类别,获得伪标签。该伪标签包括第一模型预测的目标的候选框的位置和第一模型预测的目标的类别。
第一模型预测的目标的类别可以是前景类中的一个类别。该类别具体可以根据第一模型预测的目标归属于各个类别的评分和第一阈值筛选得到。具体地,第一模型预测得到目标归属于各个类别的评分后,按照由高至低顺序进行排序,然后将排名TOP-1的评分与第一阈值进行比较。当排名TOP-1的评分大于第一阈值时,模型训练系统100可以将排名TOP-1的评分对应的类别生成伪标签。当排名TOP-1的评分不大于第一阈值时,模型训练系统100可以将候选框对应的物体的类别标注为背景类。
在一些可能的实现方式中,伪标签中还可以包括第一模型预测的目标的类别的置信度。该置信度可以根据第一模型预测的目标归属于该类别的评分确定。例如,评分为归一化的评分时,置信度可以等于第一模型预测的目标归属于该类别的评分。
S408:模型训练系统100通过第二模型和所述强增强图像、所述伪标签对所述无标签图像中的所述目标进行预测,获得预测结果。
由于弱增强图像和强增强图像是对同一无标签图像进行不同程度的增强处理得到,因此,基于弱增强图像和强增强图像预测得到的标签具有一致性,基于此,模型训练系统100可以将强增强图像和伪标签结合用于以监督学习的方式训练第二模型。
具体地,模型训练系统100将伪标签作为强增强图像的监督信息,将强增强图像输入第二模型,通过第二模型学习强增强图像中的特征,并基于该特征进行推理,从而实现预测该强增强图像对应的无标签图像中目标的位置和类别,获得预测结果。该预测结果包括所述目标的候选框的预测位置和所述目标的预测类别。
S410:模型训练系统100通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得软标签。
考虑到伪标签可能出现候选框的位置不准确或者召回率较低的问题,模型训练系统100可以结合第二模型的预测结果进行二次检验。具体地,模型训练系统100可以按照所述候选框的预测位置裁剪所述弱增强图像,然后将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分。其中,各个类别可以是预设的N个前景类和一个背景类中的各个类别。接着模型训练系统100可以根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成软标签,所述K大于1,例如K可以取值为2、3或4。其中,软标签可以包括第一模型根据裁剪后的图像预测的目标的候选框的位置和上述K个类别。
在一些可能的实现方式中,软标签中还可以包括目标归属于上述K个类别中任意一个类别的置信度。其中,目标归属于上述K个类别中任意一个类别的置信度可以等于目标归属于K个类别中各个类别的置信度之和。当置信度等于评分时,目标归属于上述K个类别中任意一个类别的置信度可以等于目标归属于K个类别中各个类别的评分之和。
具体地,训练数据集中总共有N类目标时,模型训练系统100在二次检验时,可以为每个候选框生成一个N+1维的类别向量,通过对类别向量进行Softmax归一化,可以得到软标签。该软标签表达候选框是不同类目标或者是背景类的评分或者置信度。一方面,可以打破候选框中前景和背景的对立性,使得二次检验之后,错误的前景候选框能够被纠正为背景或者其他类别,错误的背景候选框也能够被纠正为前景。另一方面,也可以利用置信度较低的候选框,解决召回率较低的问题。
S412:模型训练系统100根据所述软标签更新所述第二模型的参数,以训练所述第二模型。
具体地,模型训练系统100可以根据软标签和第二模型预测的目标的预测类型确定第二模型的损失,然后根据反向传播(backpropagation,BP)算法通过损失回传的方式更新第二模型的参数。
其中,第二模型可以包括区域候选网络(Region Proposal Network,RPN)和感兴趣区域(region of interest,ROI)网络。在一些实施例中,第二模型还可以包括提取器Extractor。其中,Extractor用于从图像中提取特征,获得特征图feature map,RPN用于根据Extractor提取的特征图生成候选框,ROI网络用于对候选框进行分类,并调整候选框的位置与大小。
RPN在生成候选框时,通常是根据特征图中特征点对应的anchor进行分类和回归,从而生成候选框。需要说明的是,RPN在分类时通常是进行二分类,具体是对anchor是否包括目标进行分类,换言之,RPN将anchor分为前景类和背景类,而由后续的网络如ROI网络对前景类进行进一步细分。其中,ROI网络可以对候选框进行分类和回归,以预测候选框的位置和候选框中目标的类别。基于此,RPN和ROI网络均会产生分类损失和回归损失。
相应地,模型训练系统100可以根据所述伪标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述软标签和所述目标的预测类别确定所述ROI网络的无监督分类损失。然后模型训练系统100可以根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失,并根据所述第二模型的损失更新所述第二模型的参数。
为了提升训练效率,模型训练系统100通常可以将训练数据集分为多个批次batch。每次迭代时,采用一个批次的数据更新第二模型的参数。如此可以减少参数更新的次数,提高训练效率。
在一些实施例中,每个批次包括一组标签图像和一组无标签图像。其中,标签图像记作无标签图像记作Ns表示一组标签图像的数量,Nu表示一组无标签图像的数量,表示第i个标签图像,表示第i个标签图像的标签,该标签包括图像中目标的标注框的位置和类别,表示第i个无标签图像。
每次迭代时,模型训练系统100将一个批次中的一组标签图像和一组无标签图像进行处理,然后根据一组标签图像的处理结果确定第二模型的监督损失,根据一组无标签图像的处理结果确定第二模型的无监督损失,接着根据第二模型的监督损失和第二模型的无监督损失确定第二模型的损失,基于该损失可以更新第二模型的参数。
第二模型的监督损失可以通过如下公式计算得到:
其中,Lsup表示监督损失,为基于第i个标签图像确定的RPN的分类损失,为基于第i个标签图像确定的RPN的回归损失,为基于第i个标签图像确定的ROI网络的分类损失,为基于第i个标签图像确定的ROI网络的回归损失,为第i个标签图像的标签。
第二模型的无监督损失可以通过如下公式计算得到:
其中,Lunsup表示无监督损失,为基于第i个无标签图像确定的RPN的分类损失,为基于第i个无标签图像确定的RPN的回归损失,为基于第i个无标签图像确定的ROI网络的分类损失,为基于第i个无标签图像确定的ROI网络的回归损失,为第i个标签图像的伪标签,为更新后的伪标签。
伪标签可以是第一模型对第i个无标签图像的弱增强图像进行推理,得到图像中目标的位置和归属于各个类别的评分,根据排名TOP-1的评分和第一阈值σ1筛选得到。更新后的伪标签可以是根据第一模型推理得到的目标归属于各个类别的评分中排名TOP-K的评分之和与第二阈值σ2筛选得到。其中,第二阈值σ2大于第一阈值σ1。σ1、σ2可以根据经验值设置,例如σ1可以设置为0.5,σ2可以设置为0.7。
进一步地,由于模型训练系统100生成了软标签,因此,模型训练系统100可以采用如下公式计算ROI网络的无监督分类损失:
其中,Ni是无标签图像中用于训练的候选框的数量,pj是第二模型在第j个候选框上的预测结果,是第二模型在二次检验阶段为候选框生成的软标签。其中,pj可以采用N+1维向量表示,N为训练数据集中前景类的数量,也即训练数据集中包括目标的类别数量。
进一步地,第二模型根据伪标签生成的候选框经过二次检验之后,有可能会出现一些均匀分布的软标签。这种类型的候选框,第二模型难以分辨其是前景还是背景,对于训练过程的帮助作用十分有限。为了降低这种候选框给训练过程带来的影响,本申请还提出了TOP-K权重机制。
TOP-K权重机制是指采用软标签的TOP-K评分之和作为使用该软标签计算得到的loss的权重。具体地,无标签图像中包括多个目标时,针对多个所述目标中的每个目标,模型训练系统100可以根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重,然后针对每个目标,根据所述软标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失,接着根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。假如软标签的分布较为均匀,则其权重较低,从而使第二模型更加关注那些易于分辨的候选框。
在确定第二模型的监督损失和无监督损失后,模型训练系统100可以通过如下公式确定第二模型的损失:
L=Lsup+λLunsup (4)
其中,λ为系数,通常可以根据标签图像和无标签图像的比例进行设置。例如,无标签图像的数量远多于标签图像的数量,则λ可以取值为4,无标签图像的数量与标签图像的数量接近,例如比例约为1:1时,则λ可以取值为2。
S414:模型训练系统100根据更新后的第二模型的参数,更新第一模型的参数。
具体地,模型训练系统100可以根据更新后的第二模型的参数,采用EMA策略更新第一模型的参数。其中,第一模型可以为Teacher模型,第二模型可以为Student模型。假设Teacher模型的参数为θt,Student模型的参数为θs,那么每次迭代,θt可以按照如下公式被更新:
θt=αθt+(1-α)θs (5)
其中,α可以为θt的比例系数,通常可以根据经验值设置,例如可以设置为0.999。本实施例对此不作限制。
需要说明的是,执行本申请实施例的模型训练方法也可以不执行S414。例如,当第二模型训练完成时,模型训练系统100可以停止训练,输出训练好的第二模型,以用于执行相应的任务。
基于上述内容描述,本申请实施例提供了一种模型训练方法。该方法中,模型训练系统100使用二次检验,以纠正利用伪标签进行模型训练时由于伪标签质量不足导致的候选分类标注错误,从根本上解决伪标签质量不足所带来的不良影响。其中,该方法所使用的软标签可以打破类别(包括前景类和背景类)之间的对立性,使得图像中外观相似的目标也能在训练中发挥正面作用。并且,TOP-K权重机制可以有效降低二次检验后得到的难以分辨类别的候选框带来的有害影响,保障模型的性能。此外,基于TOP-K评分之和更新伪标签,可以提升伪标签的召回率,进一步改善目标检测的效果。
接下来,结合一具体示例对模型训练方法进行说明。
参见图5所示的模型训练方法的流程示意图,如图5所示,待训练的模型包括Teacher模型和Student模型。Teacher模型包括RPN和ROI网络,Student模型也包括RPN和RPOI网络。训练数据集中包括标签图像和无标签图像,基于此,每次迭代可以分为监督分支(supervised branch)和无监督分支(unsupervised branch)。
监督分支的流程具体为:将标签图像输入Student模型,根据标签图像的标签以及Student模型对该标签图像的推理结果确定监督损失。无监督分支的流程具体为:先对无标签图像分别进行弱增强和强增强,然后将弱增强图像输入Teacher模型,Teacher模型生成伪标签,然后Student模型根据强增强图像和上述伪标签进行学习,获得预测结果。该预测结果包括候选框的预测位置和预测类别。其中,候选框的预测位置可以再输入Teacher模型,Teacher模型根据该预测位置对弱增强图像进行裁剪,然后Teacher模型对裁剪后的图像进行推理,以进行二次检验,获得软标签。
伪标签用于计算无监督RPN分类损失、无监督RPN回归损失和无监督ROI回归损失,软标签用于计算无监督ROI分类损失。根据上述无监督RPN分类损失、无监督RPN回归损失和无监督ROI回归损失、无监督ROI分类损失可以确定无监督损失。上述监督损失和无监督损失一起被用于更新Student模型的参数。接着,模型训练系统100使用EMA机制更新Teacher模型的参数。
基于本申请实施例提供的模型训练方法,本申请实施例还提供了一种如前述的模型训练系统100。下面将结合附图对本申请实施例提供的模型训练系统100进行介绍。
参见图1所示的模型训练系统100的结构示意图,该系统100包括:
通信模块102,用于获取训练数据集,所述训练数据集包括无标签图像;
预处理模块104,用于对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像;
标签管理模块106,用于通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果,所述预测结果包括所述目标的候选框的预测位置和所述目标的预测类别;
所述标签管理模块106,还用于通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签;
训练模块108,用于根据所述第二标签更新所述第二模型的参数,以训练所述第二模型。
在一些可能的实现方式中,所述标签管理模块106具体用于:
按照所述候选框的预测位置裁剪所述弱增强图像,将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分;
根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成第二标签,所述K大于1。
在一些可能的实现方式中,所述各个类别包括预设的N个前景类别和一个背景类别中的多个类别,所述N为正整数。
在一些可能的实现方式中,所述第二模型包括区域候选网络RPN和感兴趣区域ROI网络,所述训练模块108具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述第二标签确定所述ROI网络的无监督分类损失;
根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失;
根据所述第二模型的损失更新所述第二模型的参数。
在一些可能的实现方式中,所述训练模块108具体用于:
根据所述第二标签和所述预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失。
在一些可能的实现方式中,所述无标签图像中包括多个所述目标,所述训练模块108还用于:
针对多个所述目标中的每个目标,根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重;
所述训练模块108具体用于:
针对多个所述目标中的每个目标,根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失;
根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。
在一些可能的实现方式中,所述标签管理模块106还用于:
根据所述第一模型预测的所述目标归属于各个类别的评分中排名前K的评分之和,更新所述第一标签;
所述训练模块108具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失,根据更新后的所述第一标签和所述目标的预测类别确定所述ROI网络的无监督回归损失。
在一些可能的实现方式中,所述第一模型为教师模型,所述第二模型为学生模型。
在一些可能的实现方式中,所述训练模块108还用于:
根据更新后的所述第二模型的参数,更新所述第一模型的参数。
根据本申请实施例的模型训练系统100可对应于执行本申请实施例中描述的方法,并且模型训练系统100的各个模块/单元的上述和其它操作和/或功能分别为了实现图4所示实施例中的各个方法的相应流程,为了简洁,在此不再赘述。
本申请实施例还提供一种计算设备集群。该计算设备集群包括至少一台计算设备,该至少一台计算设备中的任一台计算设备可以来自云环境或者边缘环境,也可以是终端设备。该计算设备集群具体用于实现如图1所示实施例中模型训练系统100的功能。
图6提供了一种计算设备集群的结构示意图,如图6所示,计算设备集群60包括多台计算设备600,计算设备600包括总线601、处理器602、通信接口603和存储器604。处理器602、存储器604和通信接口603之间通过总线601通信。
总线601可以是外设部件互连标准(peripheral component interconnect,PCI)总线或扩展工业标准结构(extended industry standard architecture,EISA)总线等。总线可以分为地址总线、数据总线、控制总线等。为便于表示,图6中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
处理器602可以为中央处理器(central processing unit,CPU)、图形处理器(graphics processing unit,GPU)、微处理器(micro processor,MP)或者数字信号处理器(digital signal processor,DSP)等处理器中的任意一种或多种。
通信接口603用于与外部通信。例如,通信接口603用于获取训练数据集,输出训练好的第二模型的参数,或者输出训练好的第一模型的参数等。
存储器604可以包括易失性存储器(volatile memory),例如随机存取存储器(random access memory,RAM)。存储器604还可以包括非易失性存储器(non-volatilememory),例如只读存储器(read-only memory,ROM),快闪存储器,硬盘驱动器(hard diskdrive,HDD)或固态驱动器(solid state drive,SSD)。
存储器604中存储有计算机可读指令,处理器602执行该计算机可读指令,以使得计算设备集群60执行前述模型训练方法(或实现前述模型训练系统100的功能)。
具体地,在实现图1所示系统的实施例的情况下,且图1中所描述的模型训练系统100的各模块如通信模块102、预处理模块104、标签管理模块106、训练模块108的功能为通过软件实现的情况下,执行图1中各模块的功能所需的软件或程序代码可以存储在计算设备集群60中的至少一个存储器604中。至少一个处理器602执行存储器604中存储的程序代码,以使得计算设备集群60执行前述模型训练方法。
本申请实施例还提供了一种计算机可读存储介质。所述计算机可读存储介质可以是计算设备能够存储的任何可用介质或者是包含一个或多个可用介质的数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘)等。该计算机可读存储介质包括指令,所述指令指示计算设备或计算设备集群执行上述模型训练方法。
本申请实施例还提供了一种计算机程序产品。所述计算机程序产品包括一个或多个计算机指令。在计算设备上加载和执行所述计算机指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算设备或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算设备或数据中心进行传输。所述计算机程序产品可以为一个软件安装包,在需要使用前述模型训练方法的任一方法的情况下,可以下载该计算机程序产品并在计算设备或计算设备集群上执行该计算机程序产品。
上述各个附图对应的流程或结构的描述各有侧重,某个流程或结构中没有详述的部分,可以参见其他流程或结构的相关描述。

Claims (21)

1.一种模型训练方法,其特征在于,所述方法包括:
获取训练数据集,所述训练数据集包括无标签图像;
对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像;
通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果,所述预测结果包括所述目标的候选框的预测位置和所述目标的预测类别;
通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签;
根据所述第二标签更新所述第二模型的参数,以训练所述第二模型。
2.根据权利要求1所述的方法,其特征在于,所述通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的目标进行再次预测获得第二标签,包括:
按照所述候选框的预测位置裁剪所述弱增强图像,将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分;
根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成第二标签,所述K大于1。
3.根据权利要求2所述的方法,其特征在于,所述各个类别包括预设的N个前景类别和一个背景类别中的多个类别,所述N为正整数。
4.根据权利要求1至3任一项所述的方法,其特征在于,所述第二模型包括区域候选网络RPN和感兴趣区域ROI网络,所述根据所述第二标签更新所述第二模型的参数,包括:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述第二标签确定所述ROI网络的无监督分类损失;
根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失;
根据所述第二模型的损失更新所述第二模型的参数。
5.根据权利要求4所述的方法,其特征在于,所述根据所述第二标签确定所述ROI网络的无监督分类损失,包括:
根据所述第二标签和所述预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失。
6.根据权利要求5所述的方法,其特征在于,所述无标签图像中包括多个所述目标,所述方法还包括:
针对多个所述目标中的每个目标,根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重;
所述根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失,包括:
针对多个所述目标中的每个目标,根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失;
根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。
7.根据权利要求4所述的方法,其特征在于,所述方法还包括:
根据所述第一模型预测的所述目标归属于各个类别的评分中排名前K的评分之和,更新所述第一标签;
所述根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,包括:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失,根据更新后的所述第一标签和所述目标的预测类别确定所述ROI网络的无监督回归损失。
8.根据权利要求1至7任一项所述的方法,其特征在于,所述第一模型为教师模型,所述第二模型为学生模型。
9.根据权利要求8所述的方法,其特征在于,所述方法还包括:
根据更新后的所述第二模型的参数,更新所述第一模型的参数。
10.一种模型训练系统,其特征在于,所述系统包括:
通信模块,用于获取训练数据集,所述训练数据集包括无标签图像;
预处理模块,用于对所述无标签图像分别进行强增强和弱增强,得到强增强图像和弱增强图像;
标签管理模块,用于通过第一模型和所述弱增强图像对所述无标签图像中的目标进行预测获得第一标签,通过第二模型和所述强增强图像、所述第一标签对所述无标签图像中的所述目标进行预测,获得预测结果,所述预测结果包括所述目标的候选框的预测位置和所述目标的预测类别;
所述标签管理模块,还用于通过所述第一模型和所述弱增强图像、所述候选框的预测位置对所述无标签图像中的所述目标进行再次预测获得第二标签;
训练模块,用于根据所述第二标签更新所述第二模型的参数,以训练所述第二模型。
11.根据权利要求10所述的系统,其特征在于,所述标签管理模块具体用于:
按照所述候选框的预测位置裁剪所述弱增强图像,将裁剪后的图像输入所述第一模型,对所述无标签图像中的所述目标进行再次预测,获得所述目标属于各个类别的评分;
根据所述目标属于各个类别的评分中由高至低排序的前K个评分对应的类别生成第二标签,所述K大于1。
12.根据权利要求11所述的系统,其特征在于,所述各个类别包括预设的N个前景类别和一个背景类别中的多个类别,所述N为正整数。
13.根据权利要求10至12任一项所述的系统,其特征在于,所述第二模型包括区域候选网络RPN和感兴趣区域ROI网络,所述训练模块具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失和所述ROI网络的无监督回归损失,根据所述第二标签确定所述ROI网络的无监督分类损失;
根据所述RPN的无监督分类损失和所述RPN的无监督回归损失以及所述ROI网络的无监督分类损失和所述ROI网络的无监督回归损失,确定所述第二模型的损失;
根据所述第二模型的损失更新所述第二模型的参数。
14.根据权利要求13所述的系统,其特征在于,所述训练模块具体用于:
根据所述第二标签和所述预测类别,采用软交叉熵损失函数确定所述ROI网络的无监督分类损失。
15.根据权利要求14所述的系统,其特征在于,所述无标签图像中包括多个所述目标,所述训练模块还用于:
针对多个所述目标中的每个目标,根据所述第二模型预测的所述目标归属于各个类别的评分中由高至低排名前K个评分的和值确定所述目标的权重;
所述训练模块具体用于:
针对多个所述目标中的每个目标,根据所述第二标签和所述目标的预测类别,采用软交叉熵损失函数,确定所述ROI网络对所述目标的无监督分类损失;
根据每个目标的所述权重和所述ROI网络对所述目标的无监督分类损失进行加权运算,确定所述ROI网络的无监督分类损失。
16.根据权利要求13所述的系统,其特征在于,所述标签管理模块还用于:
根据所述第一模型预测的所述目标归属于各个类别的评分中排名前K的评分之和,更新所述第一标签;
所述训练模块具体用于:
根据所述第一标签和所述目标的预测类别确定所述RPN的无监督分类损失、所述RPN的无监督回归损失,根据更新后的所述第一标签和所述目标的预测类别确定所述ROI网络的无监督回归损失。
17.根据权利要求10至16任一项所述的系统,其特征在于,所述第一模型为教师模型,所述第二模型为学生模型。
18.根据权利要求17所述的系统,其特征在于,所述训练模块还用于:
根据更新后的所述第二模型的参数,更新所述第一模型的参数。
19.一种计算设备集群,其特征在于,所述计算设备集群包括至少一台计算设备,所述至少一台计算设备包括至少一个处理器和至少一个存储器,所述至少一个存储器中存储有计算机可读指令,所述至少一个处理器执行所述计算机可读指令,使得所述计算设备集群执行如权利要求1至9任一项所述的方法。
20.一种计算机可读存储介质,其特征在于,包括计算机可读指令,当所述计算机可读指令在计算设备或计算设备集群上运行时,使得所述计算设备或计算设备集群执行如权利要求1至9任一项所述的方法。
21.一种计算机程序产品,其特征在于,包括计算机可读指令,当所述计算机可读指令在计算设备或计算设备集群上运行时,使得所述计算设备或计算设备集群执行如权利要求1至9任一项所述的方法。
CN202210074265.0A 2022-01-21 2022-01-21 一种模型训练方法及相关系统 Pending CN116524289A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210074265.0A CN116524289A (zh) 2022-01-21 2022-01-21 一种模型训练方法及相关系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210074265.0A CN116524289A (zh) 2022-01-21 2022-01-21 一种模型训练方法及相关系统

Publications (1)

Publication Number Publication Date
CN116524289A true CN116524289A (zh) 2023-08-01

Family

ID=87394558

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210074265.0A Pending CN116524289A (zh) 2022-01-21 2022-01-21 一种模型训练方法及相关系统

Country Status (1)

Country Link
CN (1) CN116524289A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117636086A (zh) * 2023-10-13 2024-03-01 中国科学院自动化研究所 无源域适应目标检测方法及装置
CN117635917A (zh) * 2023-11-29 2024-03-01 北京声迅电子股份有限公司 基于半监督学习的目标检测模型训练方法及目标检测方法

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117636086A (zh) * 2023-10-13 2024-03-01 中国科学院自动化研究所 无源域适应目标检测方法及装置
CN117635917A (zh) * 2023-11-29 2024-03-01 北京声迅电子股份有限公司 基于半监督学习的目标检测模型训练方法及目标检测方法
CN117635917B (zh) * 2023-11-29 2024-09-13 北京声迅电子股份有限公司 基于半监督学习的目标检测模型训练方法及目标检测方法

Similar Documents

Publication Publication Date Title
US20230195845A1 (en) Fast annotation of samples for machine learning model development
US11151417B2 (en) Method of and system for generating training images for instance segmentation machine learning algorithm
US11537506B1 (en) System for visually diagnosing machine learning models
US10803398B2 (en) Apparatus and method for information processing
US11379718B2 (en) Ground truth quality for machine learning models
CN110674880A (zh) 用于知识蒸馏的网络训练方法、装置、介质与电子设备
US20170344848A1 (en) Generating image features based on robust feature-learning
US20230153622A1 (en) Method, Apparatus, and Computing Device for Updating AI Model, and Storage Medium
CN114912612A (zh) 鸟类识别方法、装置、计算机设备及存储介质
US20220067588A1 (en) Transforming a trained artificial intelligence model into a trustworthy artificial intelligence model
CN111966914A (zh) 基于人工智能的内容推荐方法、装置和计算机设备
CN116524289A (zh) 一种模型训练方法及相关系统
KR20190029083A (ko) 신경망 학습 방법 및 이를 적용한 장치
US11901969B2 (en) Systems and methods for managing physical connections of a connector panel
JP6751816B2 (ja) 新規学習データセット生成方法および新規学習データセット生成装置
KR20210066545A (ko) 반도체 소자의 시뮬레이션을 위한 전자 장치, 방법, 및 컴퓨터 판독가능 매체
CN112819024B (zh) 模型处理方法、用户数据处理方法及装置、计算机设备
JPWO2017188048A1 (ja) 作成装置、作成プログラム、および作成方法
CN115131604A (zh) 一种多标签图像分类方法、装置、电子设备及存储介质
CN117011737A (zh) 一种视频分类方法、装置、电子设备和存储介质
CN114723989A (zh) 多任务学习方法、装置及电子设备
Liao et al. ML-LUM: A system for land use mapping by machine learning algorithms
US20220114480A1 (en) Apparatus and method for labeling data
CN110059743B (zh) 确定预测的可靠性度量的方法、设备和存储介质
US20230141408A1 (en) Utilizing machine learning and natural language generation models to generate a digitized dynamic client solution

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination