CN114970673A - 一种半监督模型训练方法、系统及相关设备 - Google Patents

一种半监督模型训练方法、系统及相关设备 Download PDF

Info

Publication number
CN114970673A
CN114970673A CN202210412186.6A CN202210412186A CN114970673A CN 114970673 A CN114970673 A CN 114970673A CN 202210412186 A CN202210412186 A CN 202210412186A CN 114970673 A CN114970673 A CN 114970673A
Authority
CN
China
Prior art keywords
model
pseudo
sample
label
tag
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210412186.6A
Other languages
English (en)
Other versions
CN114970673B (zh
Inventor
徐强
徐晓忻
黄全充
傅蓉蓉
蔡晨东
纪荣嵘
周奕毅
罗根
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huawei Technologies Co Ltd
Original Assignee
Huawei Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huawei Technologies Co Ltd filed Critical Huawei Technologies Co Ltd
Priority to CN202210412186.6A priority Critical patent/CN114970673B/zh
Publication of CN114970673A publication Critical patent/CN114970673A/zh
Application granted granted Critical
Publication of CN114970673B publication Critical patent/CN114970673B/zh
Priority to PCT/CN2023/089098 priority patent/WO2023202596A1/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请提供了一种半监督模型训练方法、系统及相关设备,该方法可包括以下步骤将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签,将第一扩充样本输入第一模型,获得扩充样本的第二伪标签,其中,第一模型为采用有标签样本进行训练后的人工智能AI模型,第一扩充样本为对第一无标签样本进行数据增强后获得的样本,根据第一伪标签和第二伪标签,获得第一无标签样本的第三伪标签,使用第一无标签样本和第三伪标签对第二模型进行训练,其中,第二模型是根据第一模型的权重参数获得的AI模型,这样获得的第三伪标签的精度更高,从而提高半监督模型训练的训练效率以及最终获得模型的精度。

Description

一种半监督模型训练方法、系统及相关设备
技术领域
本申请涉及人工智能(artificial intelligence,AI)技术领域,尤其涉及一种半监督模型训练方法、系统及相关设备。
背景技术
半监督学习(semi-supervised learning,SSL)指的是使用有标签样本和无标签样本对AI模型(本申请也简称为“模型”)进行训练的方法,通过半监督学习,可以有效减少有标签样本数量,降低模型训练的成本。
通常情况下,半监督学习采用第一-第二模型(又可称为教师-学生模型,其中,第一模型为教师模型,第二模型为学生模型)的半监督模型训练方法。具体地:可先使用有标签样本对第一模型进行训练,然后将无标签样本输入至经有标签样本训练后的第一模型,推理出无标签样本的伪标签,再将第一模型的权重参数复制给结构相同的第二模型,使用上述伪标签以及无标签样本对第二模型进行训练,获得更新后的第二模型的权重参数,然后将第二模型的部分权重参数更新给第一模型。进而,在第二轮迭代中,第一模型推理新的伪标签样本给第二模型进行训练,以此迭代,使得第一模型的权重参数得以稳定更新,最终训练获得的第一模型鲁棒性更强、模型性能更强。
但是,由于通常无标签样本的数量远高于有标签样本,伪标签样本对第二模型训练的影响很大,若第一模型生成的伪标签误差较大时,会导致第二模型的训练效率低,模型性能差,进而影响最终获得第一模型的训练效率以及模型性能。
发明内容
本申请提供了一种半监督模型训练方法、系统及相关设备,用于解决半监督学习过程中,伪标签质量差导致模型训练效率低、模型训练差的问题。
第一方面,提供了一种半监督模型训练方法,该方法可包括以下步骤:将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签,将第一扩充样本输入第一模型,获得第一扩充样本的第二伪标签,其中,第一模型为采用有标签样本进行训练后的人工智能AI模型,第一扩充样本为对第一无标签样本进行数据增强后获得的样本,根据第一伪标签和第二伪标签,获得第一无标签样本的第三伪标签,使用第一无标签样本和第三伪标签对第二模型进行训练,其中,第二模型是根据第一模型的权重参数获得的AI模型。
实施第一方面描述的方法,通过对第一无标签样本进行数据增强获得第一无标签样本的第一扩充样本,然后将第一无标签样本和第一扩充样本输入第一模型推理获得第一无标签样本的第一伪标签以及第一扩充样本的第二伪标签,然后根据第一伪标签和第二伪标签获得第三伪标签,这样获得的第三伪标签质量更高,使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
在一可能的实现方式中,对第一无标签样本进行数据增强时,数据增强方法可包括但不限于翻转变换(flip)、平移变换(shift)、尺度变换(scale)、旋转变换/反射变换(rotation/reflection)、缩放变换(zoom)、修剪(crop)、颜色变换(color space)、噪声扰动(noise)、内核过滤(kernel filters)中的一种或者多种。
其中,翻转变换指的是对图像进行水平或垂直翻转,水平翻转还可分为向上水平翻转和向下水平翻转,垂直翻转还可分为向左垂直翻转和向右垂直翻转;平移变换指的是对图像进行平移操作,比如x方向向右平移(param xoffset),y方向向下平移(paramyoffset),其中x方向和y方向指的是图像坐标系的横轴方向和纵轴方向;旋转变换也可称为反射变换,指的是对图像进行某个角度的旋转,该角度可以是0~360度中的任意角度;缩放变换指的是将图像按照一定比例进行放大或者缩小,而不会改变图像中的内容;修剪也可称为裁剪,包括统一裁剪和随机裁剪,统一裁剪指的是将不同尺寸的图像裁剪至设定大小,随机裁剪指的是将不同尺寸的图像随机裁剪成不同尺寸大小;颜色变换指的是对图像某种颜色通道进行修改,比如关闭通道或者改变通道亮度值,举例来说,图像通常包括RGB三个通道,颜色变换可以将R通道值减少或增大;噪声扰动指的是从高斯分布中采样出的随机值矩阵加入到图像的RGB像素矩阵中;内核过滤指的是使用特定功能的内核过滤器与图像进行卷积操作,比如锐化、模糊等内核过滤器。
应理解,上述数据增强方法用于举例说明,本申请还可通过其他数据增强的方法对第一无标签样本进行扩充获得第一扩充样本,比如对图像进行增强还可以通过对抗生成(adversarial training)、特征空间增强(feature space augmentation)、基于GAN的数据增强(gan-based data augmentation)等数据增强方法,本申请不一一展开赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在在第一模型是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对第一无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对第一无标签样本进行数据增强。
上述实现方式,通过针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本可以增加模型的泛化性能,提高模型的鲁棒性。
在一可能的实现方式中,可根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度;在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,第一模型为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,可以通过非极大抑制(non maximum suppression,NMS)方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和扩充样本的第二伪标签的精度。
可选地,若第一模型为目标检测模型,在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果,根据上述匹配结果确定上述匹配度。其中,逆操作指的是与数据增强单元执行的数据增强方法相反的操作,比如数据增强方法是对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么匹配单元此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如数据增强方法是对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么匹配单元123可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90操作获得第四伪标签,以此类推,这里不一一举例说明。
上述实现方式,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,若第一模型为目标检测模型,可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的交并比(intersection over union,IOU)。
具体实现中,在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签时,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。
上述实现方式,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而提高后续半监督训练过程中所使用伪标签的精度,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,若第一模型为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。
具体实现中,可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
上述实现方式,在第一模型为图像识别模型时,将第一伪标签和第二伪标签进行匹配,大于阈值的情况下将二者进行融合,可以进一步提高最终获得的第三伪标签的精度,从而提高后续半监督训练过程中所使用伪标签的精度,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。可以根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式和参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述这里不重复赘述。简单来说,无标签样本集和扩充样本集进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度低于阈值则将二者的伪标签都进行删除。
上述实现方式,通过将无标签样本集的伪标签和扩充样本集的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,第一模型的模型结构可以与第二模型的模型结构相同。
可选地,在对第二模型进行训练时,可先将第一模型的权重参数拷贝给第二模型,然后使用第一无标签样本、第三伪标签以及上述训练第一模型时使用的有标签样本对第二模型进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型进行第一轮训练,获得第一轮更新后第二模型的权重参数,然后将其发送给第一模型进行对第一模型的更新获得新的第一模型,再将上述第一无标签样本和第一扩充样本输入新的第一模型,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
上述实现方式,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型推理出的伪标签精度越来越高,直至第一模型的预测精度达到用户所需的标准,从而获得目标模型。
在一可能的实现方式中,第二模型每轮训练获得的权重参数可以全部更新至第一模型,也可以将每轮训练获得的部分权重参数更新至第一模型,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。
具体实现中,可通过指数滑动平均(exponential moving average,EMA)方法将学生的权重更新给第一模型,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型。应理解,上述举例用于说明,本申请不作具体限定。
上述实现方式,第二模型与第一模型的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型不仅鲁棒性好,而且模型性能好,训练效率高。
在一可能的实现方式中,第一模型的模型结构也可以包括第二模型的模型结构,也就是说,第二模型是小模型,第一模型是大模型,比如第二模型是第一模型的一个子模型,同样的,第二模型将每轮训练获得的更新后的权重同步至第一模型,新的第一模型再预测新的伪标签对第二模型进行训练,以此类推,稳步更新第二模型和第一模型。或者,不需要多轮训练,第一轮在第二模型训练收敛后,将训练好的第二模型作为目标模型。
上述实现方式,使用的第二模型是小模型,第一模型是大模型,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型趋于接近,甚至可以比第一模型性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型进行训练时,可以将有标签样本集中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和无标签损失L2,该损失值L是根据第一输出值和真实标签之间的差距获得的,伪标签损失是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型和第一模型的模型性能影响越大,通过对λ的值进行调整,可以认为干预模型训练的学习方向。
上述实现方式,通过有标签损失和无标签损失共同影响第二模型的训练方向,并且无标签损失是基于上述过滤和融合后获得的第三伪标签与输出值之间的差距确定的,使得第二模型在半监督学习过程中,可以使用大量的无标签样本进行训练,从而降低样本获取的成本,同时不影响最终获得目标模型的性能。
在一可能的实现方式中,上述半监督模型训练方法也可以打包成一个软件模块,对现有的一些模型训练的设备进行软件升级,使其能够拥有对伪标签过滤、融合的功能,使得升级后的模型训练系统可以有更好的半监督训练功能。
举例来说,在公有云场景下,用于实现上述半监督模型训练方法的各个单元模块可以打包为一个配置模块,作为公有云模型训练服务中的一个小的配置功能,如果公有云用户购买该功能,即可为用户提供相应的权限。在非公有云场景下,用于实现上述半监督模型训练方法的各个单元模块可以打包为一个微服务或者软件包,用户购买本申请提供的伪标签过滤、融合功能之后,可以向用户提供相应的权限的许可(license),不同权限可设置不同的收费程度。本申请不作具体限定。
上述实现方式,通过软件打包为微服务、提供licencse或者提供云服务的方式,不仅用户获取方法简单快捷,而且开发者可以对原有的模型训练系统进行简单的软件升级即可实现上述各种功能,对开发者来说升级、维护都十分便捷,本申请提供的半监督模型训练方法部署方便,可用性高。
第二方面,提供了一种半监督模型训练系统,该系统包括:推理单元,用于将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签;推理单元,用于将第一扩充样本输入第一模型,获得扩充样本的第二伪标签,其中,第一模型为采用有标签样本进行训练后的人工智能AI模型,第一扩充样本为对第一无标签样本进行数据增强后获得的样本;匹配单元,用于根据第一伪标签和第二伪标签,获得第一无标签样本的第三伪标签;训练单元,用于使用第一无标签样本和第三伪标签对第二模型进行训练,其中,第二模型是根据第一模型的权重参数获得的AI模型。
实施第二方面描述的方法,本申请提供的模型训练系统,通过对无标签样本集进行数据增强获得无标签样本集的扩充样本集,然后将无标签样本集和扩充样本集输入第一模型推理获得无标签样本的集的多个第一伪标签以及扩充样本集的多个第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本集合的伪标签的质量,使得后续使用第三伪标签对学生模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
在一可能的实现方式中,第二模型与第一模型具有相同的结构。
在一可能的实现方式中,匹配单元,用于根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度;匹配单元,用于在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
在一可能的实现方式中,第一模型包括目标检测模型,数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
在一可能的实现方式中,匹配单元,用于对第二伪标签进行数据增强的逆操作,获得第四伪标签;匹配单元,用于对第一伪标签和第四伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果;匹配单元,用于根据第一伪标签与第四伪标签之间的匹配结果确定匹配度。
在一可能的实现方式中,第一模型包括图像识别模型,数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
在一可能的实现方式中,匹配单元,用于对第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果;匹配单元,用于根据第一伪标签和第二伪标签之间的匹配结果获得匹配度。
在一可能的实现方式中,推理单元,用于将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签;推理单元,用于将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,第二扩充样本为对第二无标签样本进行数据增强后获得的样本;匹配单元,用于根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度;匹配单元,用于在匹配度不高于阈值的情况下,删除第五伪标签和第六伪标签。
在一可能的实现方式中,训练单元,用于使用有标签样本、第一无标签样本和第三伪标签样本对第二模型进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型。
在一可能的实现方式中,训练单元,用于将输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,其中,损失值包括有标签损失和伪标签损失,有标签损失是根据第一输出值和真实标签之间的差值获得的,伪标签损失是根据第二输出值和第三伪标签之间的差值获得的;训练单元,用于根据损失值对第二模型进行迭代训练。
第三方面,提供了一种计算设备,该计算设备包括处理器和存储器,存储器存储有代码,处理器包括用于执行第一方面或第一方面任一种可能实现方式描述的方法。
第四方面,提供了一种计算机存储介质,所述存储介质中存储有指令,当其在计算设备上运行时,使得计算设备执行第一方面或第一方面任一种可能实现方式描述的方法。
第五方面,提供了一种计算机程序指令,该计算机程序指令在计算设备上运行时,使得计算设备执行第一方面或第一方面任一种可能实现方式描述的方法。
本申请在上述各方面提供的实现方式的基础上,还可以进行进一步组合以提供更多实现方式。
附图说明
图1是一种半监督学习的步骤流程示意图;
图2是本申请提供的一种半监督模型训练系统的架构示意图;
图3是本申请提供的一种半监督模型训练方法的步骤流程示意图;
图4是本申请提供的一种半监督模型训练方法在一应用场景下步骤流程示意图;
图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签的融合流程示意图;
图6是本申请提供的一种计算设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
首先,对本申请涉及的术语进行简单解释。
有标签样本(labeled)和无标签样本(unlabeled):有标签样本指的是拥有标签(label)的样本,该样本的标签表示该样本的真实值,该真实值用于在模型训练时与模型的预测值一起计算损失值,进而对模型的权重参数进行调整。例如:在分类模型训练时,将有标签样本作为输入数据输入至初始分类模型后,初始分类模型提供的预测结果与该有标签样本的标签进行比较,得到该轮训练的损失值,进而根据损失值可以调整分类模型的权重参数。相反,未标注样本即为不包含标签的样本。
损失函数(loss function):损失函数用于在模型训练过程中,评估模型的输出结果与样本标签之间的差距,损失值(loss)即为损失函数对应的值。损失值越低,模型的鲁棒性越好,因此,模型训练过程中通常会在样本输入模型获得输出值后,根据输出值与样本标签之间的差距确定损失值,根据损失值的大小对模型的权重参数进行调整,以此迭代,直至模型的损失函数最小化,获得目标模型。
其次,对本申请涉及的“半监督学习”应用场景进行说明。
AI是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。人工智能领域的应用场景包括机器人,自然语言处理,计算机视觉,决策与推理,人机交互,推荐与搜索等。
一般来说,AI领域的各种应用通过AI模型实现,而AI模型则是通过样本集对模型训练获得。其中,半监督学习/半监督训练是降低模型训练时所需的有标签样本集数量的一种重要方法,因理解,有标签样本集通常是人为标注的样本集,因此有标签样本集的数量越多,模型训练的效率越低、成本越高,因此,半监督学习在标注预算受限下,具有很强的实践价值。
图1是一种半监督学习的步骤流程示意图,如图1所示,半监督学习过程使用的样本集包括有标签样本101和无标签样本102,第一模型112以及第二模型113的网络结构相同,具体地,半监督学习的步骤流程可以如下:
步骤1、使用有标签样本101对初始第一模型进行训练,模型收敛后获得第一模型112。其中,初始第一模型可以是网络初始化后,还未进行训练的AI模型。
步骤2、将无标签样本102输入第一模型112,第一模型112推理出无标签样本102的伪标签,获得伪标签样本103。
步骤3、再将第一模型112的权重参数复制给第二模型113。
步骤4、使用上述有标签样本101和伪标签样本103对第二模型113进行训练,获得更新后的第二模型的权重参数。
具体实现中,使用上述有标签样本101和伪标签样本103对第二模型113进行训练时,模型的损失包括有标签损失和伪标签损失,其中,有标签损失是根据模型输出与有标签样本101的标签之间的差值获得的,伪标签损失是根据模型输出与伪标签样本103的伪标签之间的差值获得的。
步骤5、将更新后的第二模型的权重参数反馈给第一模型,对第一模型的权重参数进行更新。
重复步骤2~步骤5,直至模型收敛,获得训练好的目标模型114。这样,第一模型可以缓慢、稳定的进行权重更新,更具鲁棒性,模型性能更强。
但是,由于有标签样本101需要人工标注,数量有限,通常情况下,无标签样本102的数量远大于有标签样本101,因此第二模型113的训练过程中,伪标签样本103占比远大于有标签样本101,伪标签损失的权重也远高于有标签损失,导致伪标签样本的优劣决定了第二模型和第一模型的优化方向,而伪标签是第一模型推理得到的标签,伪标签会出现较多错误和噪声,半监督学习又无法对伪标签进行校正,若人工对伪标签进行校正,不仅效率低下且人力成本太高。
综上可知,半监督学习过程中,由于无标签样本的数量远高于有标签样本,伪标签样本对第二模型训练的影响很大,若第一模型生成的伪标签误差较大时,会导致第二模型的训练效率低,模型性能差,进而影响最终获得第一模型的训练效率以及模型性能。
为了解决上述问题,本申请提供了一种半监督模型训练系统,该系统通过对无标签样本进行数据增强获得无标签样本的扩充样本,然后将无标签样本和扩充样本输入第一模型推理获得无标签样本的第一伪标签以及扩充样本的第二伪标签,然后根据第一伪标签和第二伪标签的匹配结果获得第三伪标签,这样获得的第三伪标签质量更高,使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图2是本申请提供的一种半监督模型训练系统的架构示意图,如图2所示,该半监督模型训练系统的架构包括推理设备110、半监督模型训练系统120、用户设备140、数据存储系统150以及数据采集设备160。其中,推理设备110、半监督模型训练系统120、用户设备140、数据存储系统150以及数据采集设备160之间可建立通信连接,具体可通过有线网络或者无线网络的方式建立通信连接,本申请不作具体限定。
数据采集设备160用于采集原始样本,将其发送给半监督模型训练系统120进行模型训练,其中,该原始样本可以是图像类型的原始样本,数据采集设备160可包括图像采集装置、雷达采集装置等用于采集原始样本的传感器等,图像采集装置可以是监控摄像头、电子警察、深度摄像机、无人机等等,雷达采集装置可以是雷达、卫星等,本申请不作具体限定。应理解,在不同的应用场景下,训练机器学习模型所需的原始样本不同,对应的数据采集设备160也不同,本申请不对此进行具体限定。
半监督模型训练系统120用于接收数据采集设备160采集的原始样本,对原始样本进行处理后获得图2所示的有标签样本集131、无标签样本集132以及扩充样本集133,使用上述样本集对第一模型125和第二模型126进行训练后,获得训练好的目标模型127,并将其发送给推理设备110。
其中,半监督模型训练系统120可以部署于计算设备上,该计算设备可以是裸金属服务器(bare metal server,BMS)、虚拟机或容器。其中,BMS指的是通用的物理服务器,例如,ARM服务器或者X86服务器;虚拟机指的是网络功能虚拟化(network functionsvirtualization,NFV)技术实现的、通过软件模拟的具有完整硬件系统功能的、运行在一个完全隔离环境中的完整计算机系统,容器指的是一组受到资源限制,彼此间相互隔离的进程,计算设备还可以是边缘计算设备,本申请不作具体限定。可选地,半监督模型训练系统120也可以是服务器集群,比如集中式服务器或者分布式服务器。
可选地,半监督模型训练系统120也可部署于公有云中,作为一项模型训练的云服务提供给公有云用户,用户可通过购买该服务获取半监督模型训练系统120的使用权限,本申请不作具体限定。
可选地,半监督模型训练系统120也可以通过软件打包的方式提供给用户,用户将软件安装于自己的计算设备上,或者,以微服务的方式提供给用户。举例来说,软件打包的方式提供给用户后,用户可根据自己的需求购买所需的软件版本或者能力,获得相应权限的许可(license),不同权限可设置不同的收费。
推理设备110用于接收半监督模型训练系统120发送的目标模型127,使用目标模型127对用户设备140发送的输入数据进行推理,获得输出数据,并将其返回给用户设备140,或者,将其存储于数据存储系统150。上述推理设备110可以是计算设备,具体可以是BMS、虚拟机、容器、终端设备或者边缘计算设备,本申请不作具体限定。
用户设备140可以是用户所持有的终端设备,包括计算机、智能手机、掌上处理设备、平板电脑、移动笔记本、增强现实(augmented reality,AR)设备、虚拟现实(virtualreality,VR)设备、一体化掌机、穿戴设备、车载设备、智能会议设备、智能广告设备、智能家电等等,此处不作具体限定。
可选地,用户设备140也可以是数据采集设备,其采集的输入数据可以输入至推理设备110进行目标检测或者图像识别,输出结果存储至数据存储系统150。举例来说,用户设备140可以是道路上的电子警察,推理设备110是道路两侧的边缘计算设备,目标模型是车牌识别模型,数据存储系统是交警大队维护的数据库,电子警察采集的超速车辆图片可以输入边缘计算设备中的车牌识别模型,识别出超速车辆图片中超速车辆的车牌号,将其存储于交警大队维护的数据库中,应理解,上述举例用于说明,本申请不作具体限定。
可选地,用户设备140与推理设备110可以是同一个设备,比如用户的智能手机下载半监督模型训练系统120训练好的人脸识别模型后,用户通过智能手机上的摄像头采集人脸输入数据,将其输入至人脸识别模型,获得人脸识别结果,人脸识别结果可以直接显示于用户设备140上,或者,存储于远程服务器以便后续进行安全解锁、安全支付等认证匹配,本申请不作具体限定。
数据存储系统150可以是具有存储功能的服务器或者存储阵列,该服务器可以是物理服务器比如ARM服务器或者X86服务器,还可以是虚拟机,本申请不作具体限定。数据存储系统150用于存储推理设备110的输出数据。
进一步地,半监督模型训练系统120可进一步划分为多个单元模块,图2是一种示例性划分方式,如图2所示,半监督模型训练系统120可包括数据增强单元121、推理单元122、匹配单元123以及训练单元124,其中,数据增强单元121、推理单元122、匹配单元123以及训练单元124,之间建立通信连接,具体可以是有线连接或者无线连接,本申请不作具体限定。需要说明的,样本数据库130可以是如图2所示的,存储于半监督模型训练系统120内,也可以存储于半监督模型训练系统120的外部存储器中,本申请不作具体限定。
半监督模型训练系统120还可包括样本数据库130,其中,样本数据库130包括无标签样本集132、有标签样本集131以及扩充样本集133,上述无标签样本集132可以包括多个没有标签的无标签样本,该无标签样本可以是数据采集设备160采集的原始样本,或者,对原始样本进行数据预处理之后获得的(比如裁剪、降噪等提高样本质量的预处理手段);有标签样本集131可以包括多个有标签样本,每个有标签样本包括输入样本和真实标签,其中,输入样本可以是上述原始样本,也可以是对原始样本进行数据预处理之后获得的,输入样本的真实标签可以是人工标注后获得的;扩充样本集133包括多个没有标签的扩充样本,该扩充该样本是半监督模型训练系统120的数据增强单元121对无标签样本集132进行数据增强后获得的。
数据增强单元121可通过数据增强方法对无标签样本集132中的无标签样本进行数据增强,获得无标签样本对应的扩充样本,以此类推,获得扩充样本集,其中,数据增强方法可包括但不限于翻转变换(flip)、平移变换(shift)、尺度变换(scale)、旋转变换/反射变换(rotation/reflection)、缩放变换(zoom)、修剪(crop)、颜色变换(color space)、噪声扰动(noise)、内核过滤(kernel filters)中的一种或者多种。
其中,翻转变换指的是对图像进行水平或垂直翻转,水平翻转还可分为向上水平翻转和向下水平翻转,垂直翻转还可分为向左垂直翻转和向右垂直翻转;平移变换指的是对图像进行平移操作,比如x方向向右平移(param xoffset),y方向向下平移(paramyoffset),其中x方向和y方向指的是图像坐标系的横轴方向和纵轴方向;旋转变换也可称为反射变换,指的是对图像进行某个角度的旋转,该角度可以是0~360度中的任意角度;缩放变换指的是将图像按照一定比例进行放大或者缩小,而不会改变图像中的内容;修剪也可称为裁剪,包括统一裁剪和随机裁剪,统一裁剪指的是将不同尺寸的图像裁剪至设定大小,随机裁剪指的是将不同尺寸的图像随机裁剪成不同尺寸大小;颜色变换指的是对图像某种颜色通道进行修改,比如关闭通道或者改变通道亮度值,举例来说,图像通常包括RGB三个通道,颜色变换可以将R通道值减少或增大;噪声扰动指的是从高斯分布中采样出的随机值矩阵加入到图像的RGB像素矩阵中;内核过滤指的是使用特定功能的内核过滤器与图像进行卷积操作,比如锐化、模糊等内核过滤器。
应理解,上述数据增强方法用于举例说明,本申请还可通过其他数据增强的方法对无标签样本进行扩充获得扩充样本,比如对图像进行还可以通过对抗生成(adversarialtraining)、特征空间增强(feature space augmentation)、基于GAN的数据增强(gan-based data augmentation)等数据增强方法对无标签样本集132进行扩充,获得扩充样本集133,本申请不一一展开赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在在第一模型125是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型125是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对无标签样本进行数据增强。应理解,针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本可以增加模型的泛化性能,提高模型的鲁棒性。
需要说明的,数据增强单元121和样本数据库130也可以部署于半监督模型训练系统120之外,比如半监督模型训练系统120与预处理系统建立连接,数据增强单元121和样本数据库130部署于该预处理系统中,通过预处理系统对样本数据库130进行维护,以及对无标签样本集132进行数据增强操作,本申请不作具体限定。
推理单元122用于将第一无标签样本输入第一模型125生成第一无标签样本的第一伪标签,将第一扩充样本输入第一模型125生成第一扩充样本的第二伪标签。其中,第一模型125是使用有标签样本进行训练后获得的AI模型,需要说明的,使用有标签样本集对第一模型进行训练时,可以控制训练的轮数,使得第一模型具备一定的检测能力,防止后续半监督训练过程中出现第一模型125和第二模型126过拟合的现象。
具体实现中,第一模型可以是目标检测模型或者图像识别模型。目标检测模型可以是一阶段统一实时目标检测(You Only Look Once:Unified,Real-Time ObjectDetection,Yolo)模型、单镜头多盒检测器(Single Shot multi box Detector,SSD)模型、区域卷积神经网络(Region Convolutional Neural Network,RCNN)模型或快速区域卷积神经网络(Fast Region Convolutional Neural Network,Fast-RCNN)模型等,本申请不作具体限定。
可选地,第一模型125为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,推理单元122可以通过非极大抑制(non maximum suppression,NMS)方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和扩充样本的第二伪标签的精度。
匹配单元123用于将第一无标签样本的第一伪标签和第一扩充样本的第二伪标签进行匹配,获得第一无标签样本的第三伪标签。
可选地,匹配单元123可以根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,若第一模型125为目标检测模型,匹配单元123在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果,根据上述匹配结果确定上述匹配度。其中,逆操作指的是与数据增强单元121执行的数据增强方法相反的操作,比如数据增强单元121对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么匹配单元123此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如数据增强单元121对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么匹配单元123可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90操作获得第四伪标签,以此类推,这里不一一举例说明。
可以理解的,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
具体实现中,匹配单元123可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的交并比(intersection over union,IOU)。
具体实现中,在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签时,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。举例来说,第一伪标签对应的检测框坐标为
Figure BDA0003604427170000121
第四伪标签对应的检测框坐标为
Figure BDA0003604427170000122
那么融合后的第三伪标签yu的公式可以如下:
Figure BDA0003604427170000123
应理解,上述公式(1)用于助说明,还可以其他方式对第一伪标签和第四伪标签对应的检测框进行融合处理,比如加权平均,本申请不对此进行具体限定。
应理解,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
可选地,若第一模型125为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。具体实现中,匹配单元123可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。匹配单元123将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
可选地,推理单元122也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。匹配单元123可以根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式和参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述这里不重复赘述。简单来说,无标签样本集132和扩充样本集133进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度低于阈值则将二者的伪标签都进行删除。
可以理解的,通过将无标签样本集132的伪标签和扩充样本集133的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
训练单元124用于使用上述第一无标签样本和第三伪标签对第二模型126进行训练,其中,第二模型126的权重参数与第一模型125相同,简单来说,使用有标签样本集131对机器学习模型进行训练,获得第一模型125,然后将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本和第三伪标签对第二模型126进行训练。
可选地,第一模型125的模型结构可以与第二模型126的模型结构相同。在该应用场景下,在训练单元124对第二模型126进行训练时,先将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本、第三伪标签以及上述训练第一模型125时使用的有标签样本对第二模型126进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型127。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型126进行第一轮训练,获得第一轮更新后第二模型126的权重参数,然后将其发送给第一模型125进行对第一模型125的更新获得新的第一模型125,再将上述第一无标签样本和第一扩充样本输入新的第一模型125,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型126训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型125,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型125,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型125,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
可以理解的,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型125推理出的伪标签精度越来越高,直至第一模型125的预测精度达到用户所需的标准,从而获得目标模型127,并将其发送至推理设备110。
具体实现中,第二模型126每轮训练获得的权重参数可以全部更新至第一模型125,也可以将每轮训练获得的部分权重参数更新至第一模型125,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。具体实现中,可通过指数滑动平均(exponential moving average,EMA)方法将学生的权重更新给第一模型125,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型125。应理解,上述举例用于说明,本申请不作具体限定。
可以理解的,第二模型126与第一模型125的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型127不仅鲁棒性好,而且模型性能好,训练效率高。
可选地,第一模型125的模型结构也可以包括第二模型126的模型结构,也就是说,第二模型126是小模型,第一模型125是大模型,比如第二模型126是第一模型125的一个子模型,同样的,第二模型126将每轮训练获得的更新后的权重同步至第一模型125,新的第一模型125再预测新的伪标签对第二模型126进行训练,以此类推,稳步更新第二模型126和第一模型125,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型125趋于接近,甚至可以比第一模型125性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型126进行训练时,可以将有标签样本集131中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型125中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和无标签损失L2,该损失值L是根据第一输出值和真实标签之间的差距获得的,伪标签损失是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型126和第一模型125的模型性能影响越大,通过对λ的值进行调整,可以认为干预模型训练的学习方向。
需要说明的,本申请的半监督模型训练系统120也可以打包成一个软件模块,对现有的一些模型训练系统进行软件升级,使其能够拥有对伪标签过滤、融合的功能,使得升级后的件模型训练系统可以有更好的半监督训练性能。
参考前述内容可知,在公有云场景下,上述数据增强单元121、推理单元122、匹配单元123以及训练单元124可以打包为一个配置模块,作为公有云模型训练服务中的一个小的配置功能,如果公有云用户购买该功能,即可为用户提供相应的权限。在非公有云场景下,上述数据增强单元121、推理单元122、匹配单元123以及训练单元124可以打包为一个微服务或者软件包,用户购买本申请提供的伪标签过滤、融合功能之后,可以向用户提供相应的权限的许可(license),不同权限可设置不同的收费程度。本申请不作具体限定。
需要说明的,图2展示了模型训练系统的一种示例性划分方式,具体实现中,半监督模型训练系统120可包括更多或者更少的单元模块,比如样本数据库130和数据增强单元121可以部署于半监督模型训练系统120之外,比如匹配单元123可进一步划分为阈值判断单元、融合单元和删除单元,其中,阈值判断单元用于确定第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,通过融合单元将第一伪标签和第二伪标签进行融合获得第三伪标签,在匹配度不高于阈值的情况下,通过删除单元将第一伪标签和第二伪标签删除,这里不重复展开赘述。
综上可知,本申请提供的模型训练系统,通过对无标签样本集进行数据增强获得无标签样本集的扩充样本集,然后将无标签样本集和扩充样本集输入第一模型推理获得无标签样本的集的多个第一伪标签以及扩充样本集的多个第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本集合的伪标签的质量,使得后续使用第三伪标签对学生模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图3是本申请提供的一种半监督模型训练方法的步骤流程示意图,该方法可应用于图2所示的半监督模型训练系统120中,如图3所示,该方法可包括以下步骤:
步骤S310:半监督模型训练系统120将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签。该步骤可以由图2中的推理单元122实现。
其中,第一模型是使用有标签样本集对机器学习模型进行训练后获得的,机器学习模型可以是目标检测模型或者图像识别模型。其中,目标检测模型和图像识别模型的描述可参考图2实施例中的详细描述,这里不重复赘述。
可选地,第一模型为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,半监督模型训练系统120可以NMS方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和第一扩充样本的第二伪标签的精度。
步骤S320:将第一扩充样本输入第一模型,获得第一扩充样本的第二伪标签。该步骤可以由图2中的推理单元122实现。
具体实现中,在步骤S320之前,可以对第一无标签样本进行数据增强,获得第一无标签样本的第一扩充样本。该步骤可以由图2中的数据增强单元121实现。
其中,半监督模型训练系统120的描述可参考图2实施例,这里不重复赘述,第一无标签样本的描述可参考术语解释,这里也不重复赘述。
在本申请实施例中,数据增强方法可包括但不限于翻转变换、平移变换、尺度变换、旋转变换/反射变换、缩放变换、修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种,每种数据增强方法的具体描述可参考图2实施例中的描述,这里不重复赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在在第一模型是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对第一无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对第一无标签样本进行数据增强。应理解,针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本可以增加模型的泛化性能,提高模型的鲁棒性。
需要说明的,步骤S310也可以由其他数据预处理系统实现,换句话说,半监督模型训练系统120也可以向其他数据预处理系统获取上述第一扩充样本和第一无标签样本,也可以自己对第一无标签样本进行数据增强获得相应的第一扩充样本,具体可根据实际业务处理情况决定,本申请不作具体限定。
步骤S330:根据第一伪标签和第二伪标签获得第一无标签样本的第三伪标签。该步骤可以由图2中的匹配单元123实现。
具体实现中,半监督模型训练系统120可以根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,若第一模型为目标检测模型,半监督模型训练系统120在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得上述匹配度。其中,逆操作指的是与数据增强单元121执行的数据增强方法相反的操作,比如步骤S310对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么步骤S330此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如步骤S310对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么步骤S330可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90操作获得第四伪标签,以此类推,这里不一一举例说明。
可以理解的,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
具体实现中,匹配单元123可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的IOU。在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签时,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。具体可参考公式(1)的相关描述,这里不重复赘述。
应理解,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
可选地,若第一模型为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。具体实现中,可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
可选地,也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式和参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述这里不重复赘述。简单来说,无标签样本集132和扩充样本集133进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度低于阈值则将二者的伪标签都进行删除。
可以理解的,通过将无标签样本集132的伪标签和扩充样本集133的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
步骤S340:使用第一无标签样本和第三伪标签对第二模型进行训练,该第二模型与第一模型的权重参数相同。该步骤可以由图2中的训练单元124实现。简单来说,使用有标签样本集对机器学习模型进行训练,获得第一模型,然后将第一模型的权重参数拷贝给第二模型,然后使用第一无标签样本和第三伪标签对第二模型进行训练。
可选地,第一模型的模型结构可以与第二模型的模型结构相同,在该应用场景下,在对第二模型126进行训练时,先将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本、第三伪标签以及上述训练第一模型125时使用的有标签样本对第二模型126进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型127。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型126进行第一轮训练,获得第一轮更新后第二模型126的权重参数,然后将其发送给第一模型125进行对第一模型125的更新获得新的第一模型125,再将上述第一无标签样本和第一扩充样本输入新的第一模型125,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型126训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型125,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型125,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型125,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
可以理解的,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型125推理出的伪标签精度越来越高,直至第一模型125的预测精度达到用户所需的标准,从而获得目标模型127,并将其发送至推理设备110。
具体实现中,第二模型每轮训练获得的权重参数可以全部更新至第一模型,也可以将每轮训练获得的部分权重参数更新至第一模型,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。具体实现中,可通过EMA方法将学生的权重更新给第一模型,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型。应理解,上述举例用于说明,本申请不作具体限定。
可以理解的,第二模型与第一模型的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型127不仅鲁棒性好,而且模型性能好,训练效率高。
可选地,第一模型的模型结构也可以包括第二模型的模型结构,也就是说,第二模型是第一模型的一个子模型,同样的,第二模型将每轮训练获得的更新后的权重同步至第一模型,新的第一模型再预测新的伪标签对第二模型进行训练,以此类推,稳步更新第二模型和第一模型,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型趋于接近,甚至可以比第一模型性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型进行训练时,可以将有标签样本集131中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和无标签损失L2,该损失值L是根据第一输出值和真实标签之间的差距获得的,伪标签损失是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型和第一模型的模型性能影响越大,通过对λ的值进行调整,可以认为干预模型训练的学习方向。
下面结合附图4和图5,对图3所示的半监督模型训练方法进行举例说明,图4是本申请提供的半监督模型训练方法在一应用场景下的步骤流程示意图,图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签的融合流程示意图,该应用场景中,第一模型为目标检测模型,第二模型为网络结构与第一模型相同的目标检测模型。
如图4所示,该应用场景下的半监督模型训练方法可包括以下步骤。
步骤1.获取训练样本集,该训练样本集可包括图2所示的有标签样本集131和无标签样本集132。
步骤2.判断是否为有标签样本,具体地,将训练样本集中的每个样本进行判断分类,有标签样本执行步骤3,无标签样本执行步骤4。
步骤3.使用有标签样本训练第一模型。该第一模型即为图2和图3实施例中的第一模型125。步骤3执行完毕后执行步骤5或步骤6。
需要说明的,使用有标签样本集对机器学习模型进行训练时,可以控制训练的轮数,使得第一模型具备一定的检测能力,防止后续半监督训练过程中出现第一模型和第二模型过拟合的现象。
步骤4.对无标签样本进行数据增强,获得扩充样本。该步骤的具体描述可参考图3实施例中的步骤S320,该步骤可以由图2实施例中的数据增强单元121实现。步骤4执行完毕后执行步骤7。
具体实现中,图4所示的应用场景使用的数据增强方法为:将第一无标签样本垂直向右翻转。
需要说明的,步骤3和步骤4可以并行或者串行处理,具体可根据半监督模型训练系统120所部署计算设备的处理能力决定,本申请不作具体限定。
步骤5.将第一模型的模型参数拷贝至第二模型。第二模型即为图2和图3实施例中描述的第二模型126。
步骤6.将无标签样本输入第一模型获得第一伪标签。该步骤可由图2实施例中的推理单元122实现,具体可参考图3实施例步骤S310的描述,这里不重复赘述。
步骤7.将扩充样本输入第一模型获得第二伪标签。该步骤可由图2实施例中的推理单元122实现,具体可参考图3实施例步骤S320的描述,这里不重复赘述。
需要说明的步骤6可以与步骤7并行或串行处理,本申请不作具体限定。
步骤8.对第二伪标签进行数据增强逆操作,获得第四伪标签。
具体实现中,数据增强的逆操作指的是对步骤4中的数据增强对应的逆操作,步骤4将第一无标签样本进行垂直向右的翻转,那么步骤8可以对第二伪标签进行垂直向左的翻转。
步骤9.将第一伪标签和第四伪标签进行匹配,获得匹配度。
具体实现中,步骤9可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,这里的匹配度可以是两个检测框之间的IOU。
步骤10.判断匹配度是否高于阈值,在匹配度高于阈值的情况下执行步骤11,在匹配度低于阈值的情况下执行步骤14。
具体地,假设阈值为0.45,那么步骤10可以将IOU大于0.45的第一伪标签和第四伪标签进行融合,即执行步骤11,将其删除。
步骤11.将匹配度高于阈值的第一伪标签和第四伪标签融合为第三伪标签。上述步骤8~步骤11可以由图2实施例中的匹配单元123实现,具体可参考图3实施例中的步骤S330,这里不重复赘述。
示例性地,如图5所示,图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签融合为第三伪标签的步骤流程示意图。对应图4实施例中的步骤4、步骤6~步骤11,由图5可知,步骤4对第一无标签样本进行扩充后,步骤6将第一无标签样本输入第一模型125获得第一伪标签,步骤7将扩充样本输入第一模型125获得第二伪标签,其中,第一伪标签和第二位标签的精度较低,两个伪标签对应的检测框都没有完整的标注出目标(也就是车辆)。
可以理解的,图5清晰的显示了第二伪标签所标注的样本与第一伪标签所标注的样本是不同的样本,第一伪标签所标注的未标注样本中货车在上方车道,第二伪标签所标注的扩充样本中货车在下方车道,因此不能直接执行步骤10将二者进行匹配。通过步骤8将第二伪标签进行数据增强的逆操作,即水平翻转-180度之后,获得第四伪标签,此时第一伪标签和第四伪标签所标注的样本是同一个样本,即货车在上的样本。这样,执行步骤10将第一伪标签和第四伪标签进行匹配,二者的匹配度可以良好的指示第一伪标签的精度,将匹配度高于阈值的第一伪标签和第四伪标签融合后,还可以进一步提高第一伪标签的精度,如图5所示,步骤11获得的第三伪标签相比第一伪标签和第二伪标签的精度更高,准确度更高,使用第三伪标签训练第二模型不仅可以提高训练效率,提高第二模型的性能,在第二模型的权重参数引入第一模型后,还可以提高第一模型的性能,整个过程不需要人工辅助校正样本,提高整个训练过程的训练效率,降低人力成本。
需要说明的,如果步骤10对第一标签和第二标签之间的匹配度判定为不高于阈值,那么可执行步骤14将第一伪标签、第二伪标签以及第四伪标签删除,本轮训练过程中,该第一无标签样本可以不参与训练,下一轮训练过程中,如果第一伪标签和第二伪标签之间的匹配度高于阈值,那么可以参与下一轮的训练,以此类推,这里不展开赘述。
步骤12.使用第三伪标签、无标签样本和有标签样本对第二模型进行训练。
步骤13.将训练好的第二模型的权重参数引入第一模型。上述步骤12和13可以由图1实施例中的训练单元124实现,具体可参考图3实施例中的步骤S340,这里不重复赘述。
可选地,可以根据EMA方法将部分训练好的第二模型的权重参数引入第一模型,获得新的第一模型,继续执行步骤6~步骤14,以此类推,直至第一模型收敛。
步骤14.删除第一伪标签和第二伪标签。
需要说明的,在重复执行步骤6~步骤14时,即使上一轮第一伪标签和第二伪标签进行了删除,下一轮中新的第一伪标签和第二伪标签可继续进行匹配、融合或删除的过程,使得第一模型推理出的第一伪标签和第二伪标签的精度越来越高,从而使得模型训练效果越来越好。
综上可知,本申请提供的半监督模型训练方法,通过对无标签样本进行数据增强获得无标签样本的扩充样本,然后将无标签样本和扩充样本输入第一模型推理获得无标签样本的第一伪标签以及扩充样本的第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本的伪标签的质量,使得后续使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图6是本申请提供的一种计算设备的结构示意图,该计算设备600是图1至图5实施例中的半监督模型训练系统120。
进一步地,计算设备600包括处理器601、存储单元602、存储介质603和通信接口604,其中,处理器601、存储单元602、存储介质603和通信接口604通过总线605进行通信,也通过无线传输等其他手段实现通信。
处理器601由至少一个通用处理器构成,例如CPU、NPU或者CPU和硬件芯片的组合。上述硬件芯片是专用集成电路(Application-Specific Integrated Circuit,ASIC)、编程逻辑器件(Programmable Logic Device,PLD)或其组合。上述PLD是复杂编程逻辑器件(Complex Programmable Logic Device,CPLD)、现场编程逻辑门阵列(Field-Programmable Gate Array,FPGA)、通用阵列逻辑(Generic Array Logic,GAL)或其任意组合。处理器601执行各种类型的数字存储指令,例如存储在存储单元602中的软件或者固件程序,它能使计算设备600提供较宽的多种服务。
具体实现中,作为一种实施例,处理器601包括一个或多个CPU,例如图6中所示的CPU0和CPU1。
在具体实现中,作为一种实施例,计算设备600也包括多个处理器,例如图6中所示的处理器601和处理器606。这些处理器中的每一个可以是一个单核处理器(single-CPU),也可以是一个多核处理器(multi-CPU)。这里的处理器指一个或多个设备、电路、和/或用于处理数据(例如计算机程序指令)的处理核。
存储单元602用于存储程序代码,并由处理器601来控制执行,以执行上述图1-图5中任一实施例中半监督模型训练系统120的处理步骤。程序代码中包括一个或多个软件单元,上述一个或多个软件单元是图2实施例中的推理单元、匹配单元和训练单元,其中,推理单元用于将第一无标签样本输入第一模型获得第一无标签样本的第一伪标签,将第一扩充样本输入第一模型获得第一扩充样本的第二伪标签;匹配单元用于将第一伪标签和第二伪标签进行匹配,获得无标签样本的第三伪标签;训练单元用于使用第三伪标签和第一无标签样本对第二模型进行训练。其中,推理单元用于执行图3中的步骤S310~步骤S320以及图4和图5中的步骤6和步骤7,匹配单元用于执行图3中的步骤S330以及图4和图5中的步骤8~步骤11,训练单元用于执行图3中的步骤后S340以及图4中的步骤12和步骤13。具体实现方式参考图1~图5实施例,此处不再赘述。
存储单元602包括只读存储器和随机存取存储器,并向处理器601提供指令和数据。存储单元602还包括非易失性随机存取存储器。存储单元602是易失性存储器或非易失性存储器,或包括易失性和非易失性存储器两者。其中,非易失性存储器是只读存储器(read-only memory,ROM)、编程只读存储器(programmable ROM,PROM)、擦除编程只读存储器(erasable PROM,EPROM)、电擦除编程只读存储器(electrically EPROM,EEPROM)或闪存。易失性存储器是随机存取存储器(random access memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM用,例如静态随机存取存储器(static RAM,SRAM)、动态随机存取存储器(DRAM)、同步动态随机存取存储器(synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(double data date SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(direct rambus RAM,DRRAM)。还是硬盘(hard disk)、U盘(universal serial bus,USB)、闪存(flash)、SD卡(secure digital memory Card,SD card)、记忆棒等等,硬盘是硬盘驱动器(hard diskdrive,HDD)、固态硬盘(solid state disk,SSD)、机械硬盘(mechanical hard disk,HDD)等,本申请不作具体限定。
存储介质603是存储数据的载体,比如硬盘(hard disk)、U盘(universal serialbus,USB)、闪存(flash)、SD卡(secure digital memory Card,SD card)、记忆棒等等,硬盘可以是硬盘驱动器(hard disk drive,HDD)、固态硬盘(solid state disk,SSD)、机械硬盘(mechanical hard disk,HDD)等,本申请不作具体限定。
通信接口604为有线接口(例如以太网接口),为内部接口(例如高速串行计算机扩展总线(Peripheral Component Interconnect express,PCIe)总线接口)、有线接口(例如以太网接口)或无线接口(例如蜂窝网络接口或使用无线局域网接口),用于与其他服务器或单元进行通信。
总线605是快捷外围部件互联标准(Peripheral Component InterconnectExpress,PCIe)总线,或扩展工业标准结构(extended industry standard architecture,EISA)总线、统一总线(unified bus,Ubus或UB)、计算机快速链接(compute express link,CXL)、缓存一致互联协议(cache coherent interconnect for accelerators,CCIX)等。总线605分为地址总线、数据总线、控制总线等。
总线605除包括数据总线之外,还包括电源总线、控制总线和状态信号总线等。但是为了清楚说明起见,在图中将各种总线都标为总线605。
需要说明的,图6仅仅是本申请实施例的一种能的实现方式,实际应用中,计算设备600还包括更多或更少的部件,这里不作限制。关于本申请实施例中未示出或未描述的内容,参见前述图1-图5实施例中的相关阐述,这里不再赘述。
本申请实施例提供一种计算机存储介质,包括:该计算机存储介质中存储有指令;当该指令在计算设备上运行时,使得该计算设备执行上述图1至图5描述的半监督模型训练方法。
本申请实施例提供了一种包含指令的程序产品,包括程序或指令,当该程序或指令在计算设备上运行时,使得该计算设备执行上述图1至图5描述的半监督模型训练方法。
上述实施例,全部或部分地通过软件、硬件、固件或其他任意组合来实现。当使用软件实现时,上述实施例全部或部分地以计算机程序产品的形式实现。计算机程序产品包括至少一个计算机指令。在计算机上加载或执行计算机程序指令时,全部或部分地产生按照本发明实施例的流程或功能。计算机为通用计算机、专用计算机、计算机网络、或者其他编程装置。计算机指令存储在计算机读存储介质中,或者从一个计算机读存储介质向另一个计算机读存储介质传输,例如,计算机指令从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(digital subscriber line,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。计算机读存储介质是计算机能够存取的任何用介质或者是包含至少一个用介质集合的服务器、数据中心等数据存储节点。用介质是磁性介质(例如,软盘、硬盘、磁带)、光介质(例如,高密度数字视频光盘(digital video disc,DVD)、或者半导体介质。半导体介质是SSD。
以上,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,轻易想到各种等效的修复或替换,这些修复或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (21)

1.一种半监督模型训练方法,其特征在于,所述方法包括:
将第一无标签样本输入第一模型,获得所述第一无标签样本的第一伪标签;
将第一扩充样本输入所述第一模型,获得所述第一扩充样本的第二伪标签,其中,所述第一模型为采用有标签样本进行训练后的人工智能AI模型,所述第一扩充样本为对所述第一无标签样本进行数据增强后获得的样本;
根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签;
使用所述第一无标签样本和所述第三伪标签对第二模型进行训练,其中,所述第二模型是根据所述第一模型的权重参数获得的AI模型。
2.根据权利要求1所述的方法,其特征在于,所述第二模型与所述第一模型具有相同的结构。
3.根据权利要求1或2所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签包括:
根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度;
在所述匹配度高于阈值的情况下,将所述第一伪标签和第二伪标签进行融合,获得所述第三伪标签。
4.根据权利要求1至3任一权利要求所述的方法,其特征在于,所述第一模型包括目标检测模型,所述数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
5.根据权利要求4所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度包括:
对所述第二伪标签进行所述数据增强的逆操作,获得第四伪标签;
对所述第一伪标签和所述第四伪标签进行匹配,获得所述第一伪标签和所述第四伪标签之间的匹配结果;
根据所述第一伪标签与所述第四伪标签之间的匹配结果确定所述匹配度。
6.根据权利要求1至4任一权利要求所述的方法,其特征在于,所述第一模型包括图像识别模型,所述数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
7.根据权利要求6所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度包括:
对所述第一伪标签和所述第二伪标签进行匹配,获得所述第一伪标签和所述第二伪标签之间的匹配结果;
根据所述第一伪标签和所述第二伪标签之间的匹配结果获得所述匹配度。
8.根据权利要求3至7任一权利要求所述的方法,其特征在于,所述方法还包括:
将第二无标签样本输入所述第一模型,获得所述第二无标签样本的第五伪标签;
将第二扩充样本输入所述第一模型,获得所述第二扩充样本的第六伪标签,所述第二扩充样本为对所述第二无标签样本进行数据增强后获得的样本;
根据所述第五伪标签和所述第六伪标签,获得所述第五伪标签和所述第六伪标签之间的匹配度;
在所述匹配度不高于所述阈值的情况下,删除所述第五伪标签和所述第六伪标签。
9.根据权利要求1至8任一权利要求所述的方法,其特征在于,所述使用所述第一无标签样本和所述第三伪标签样本对第二模型进行训练,包括:
使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练,根据每次迭代训练获得的所述第二模型的权重参数对所述第一模型的权重参数进行迭代更新,获得目标模型。
10.根据权利要求9所述的方法,其特征在于,所述使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练包括:
将所述输入样本输入所述第二模型获得第一输出值,将所述第一无标签样本输入所述第二模型获得第二输出值,根据所述第一输出值和所述第二输出值确定所述第二模型的损失值,其中,所述损失值包括有标签损失和伪标签损失,所述有标签损失是根据所述第一输出值和所述真实标签之间的差值获得的,所述伪标签损失是根据所述第二输出值和所述第三伪标签之间的差值获得的;
根据所述损失值对所述第二模型进行迭代训练。
11.一种半监督模型训练系统,其特征在于,所述系统包括:
推理单元,用于将第一无标签样本输入第一模型,获得所述第一无标签样本的第一伪标签;
推理单元,用于将第一扩充样本输入所述第一模型,获得所述扩充样本的第二伪标签,其中,所述第一模型为采用有标签样本进行训练后的人工智能AI模型,所述第一扩充样本为对所述第一无标签样本进行数据增强后获得的样本;
匹配单元,用于根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签;
训练单元,用于使用所述第一无标签样本和所述第三伪标签对第二模型进行训练,其中,所述第二模型是根据所述第一模型的权重参数获得的AI模型。
12.根据权利要求11所述的系统,其特征在于,所述第二模型与所述第一模型具有相同的结构。
13.根据权利要求11或12所述的系统,其特征在于,
所述匹配单元,用于根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度;
所述匹配单元,用于在所述匹配度高于阈值的情况下,将所述第一伪标签和第二伪标签进行融合,获得所述第三伪标签。
14.根据权利要求11至13任一权利要求所述的系统,其特征在于,所述第一模型包括目标检测模型,所述数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
15.根据权利要求14所述的系统,其特征在于,
所述匹配单元,用于对所述第二伪标签进行所述数据增强的逆操作,获得第四伪标签;
所述匹配单元,用于对所述第一伪标签和所述第四伪标签进行匹配,获得所述第一伪标签和所述第四伪标签之间的匹配结果;
所述匹配单元,用于根据所述第一伪标签与所述第四伪标签之间的匹配结果确定所述匹配度。
16.根据权利要求11至14任一权利要求所述的系统,其特征在于,所述第一模型包括图像识别模型,所述数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
17.根据权利要求16所述的系统,其特征在于,
所述匹配单元,用于对所述第一伪标签和所述第二伪标签进行匹配,获得所述第一伪标签和所述第二伪标签之间的匹配结果;
所述匹配单元,用于根据所述第一伪标签和所述第二伪标签之间的匹配结果获得所述匹配度。
18.根据权利要求11至17任一权利要求所述的系统,其特征在于,
所述推理单元,用于将第二无标签样本输入所述第一模型,获得所述第二无标签样本的第五伪标签;
所述推理单元,用于将第二扩充样本输入所述第一模型,获得所述第二扩充样本的第六伪标签,所述第二扩充样本为对所述第二无标签样本进行数据增强后获得的样本;
所述匹配单元,用于根据所述第五伪标签和所述第六伪标签,获得所述第五伪标签和所述第六伪标签之间的匹配度;
所述匹配单元,用于在所述匹配度不高于所述阈值的情况下,删除所述第五伪标签和所述第六伪标签。
19.根据权利要求11至18任一权利要求所述的系统,其特征在于,所述训练单元,用于使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练,根据每次迭代训练获得的所述第二模型的权重参数对所述第一模型的权重参数进行迭代更新,获得目标模型。
20.根据权利要求19所述的系统,其特征在于,所述训练单元,用于将所述输入样本输入所述第二模型获得第一输出值,将所述第一无标签样本输入所述第二模型获得第二输出值,根据所述第一输出值和所述第二输出值确定所述第二模型的损失值,其中,所述损失值包括有标签损失和伪标签损失,所述有标签损失是根据所述第一输出值和所述真实标签之间的差值获得的,所述伪标签损失是根据所述第二输出值和所述第三伪标签之间的差值获得的;
所述训练单元,用于根据所述损失值对所述第二模型进行迭代训练。
21.一种计算设备,其特征在于,所述计算设备包括处理器和存储器,所述存储器用于存储代码,所述处理器用于执行所述代码实现如权利要求1至10任一权利要求所述的方法。
CN202210412186.6A 2022-04-19 2022-04-19 一种半监督模型训练方法、系统及相关设备 Active CN114970673B (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202210412186.6A CN114970673B (zh) 2022-04-19 2022-04-19 一种半监督模型训练方法、系统及相关设备
PCT/CN2023/089098 WO2023202596A1 (zh) 2022-04-19 2023-04-19 一种半监督模型训练方法、系统及相关设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210412186.6A CN114970673B (zh) 2022-04-19 2022-04-19 一种半监督模型训练方法、系统及相关设备

Publications (2)

Publication Number Publication Date
CN114970673A true CN114970673A (zh) 2022-08-30
CN114970673B CN114970673B (zh) 2023-04-07

Family

ID=82977875

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210412186.6A Active CN114970673B (zh) 2022-04-19 2022-04-19 一种半监督模型训练方法、系统及相关设备

Country Status (2)

Country Link
CN (1) CN114970673B (zh)
WO (1) WO2023202596A1 (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115471717A (zh) * 2022-09-20 2022-12-13 北京百度网讯科技有限公司 模型的半监督训练、分类方法装置、设备、介质及产品
WO2023202596A1 (zh) * 2022-04-19 2023-10-26 华为技术有限公司 一种半监督模型训练方法、系统及相关设备
CN117151200A (zh) * 2023-10-27 2023-12-01 成都合能创越软件有限公司 基于半监督训练提升yolo检测模型精度方法及系统

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117237343B (zh) * 2023-11-13 2024-01-30 安徽大学 半监督rgb-d图像镜面检测方法、存储介质及计算机设备

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108009589A (zh) * 2017-12-12 2018-05-08 腾讯科技(深圳)有限公司 样本数据处理方法、装置和计算机可读存储介质
CN112183099A (zh) * 2020-10-09 2021-01-05 上海明略人工智能(集团)有限公司 基于半监督小样本扩展的命名实体识别方法及系统
US20210279644A1 (en) * 2020-03-06 2021-09-09 International Business Machines Corporation Modification of Machine Learning Model Ensembles Based on User Feedback
CN113705769A (zh) * 2021-05-17 2021-11-26 华为技术有限公司 一种神经网络训练方法以及装置
CN114330588A (zh) * 2022-01-04 2022-04-12 杭州网易智企科技有限公司 一种图片分类方法、图片分类模型训练方法及相关装置

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20220083840A1 (en) * 2020-09-11 2022-03-17 Google Llc Self-training technique for generating neural network models
CN112232416B (zh) * 2020-10-16 2021-09-14 浙江大学 一种基于伪标签加权的半监督学习方法
CN114067444A (zh) * 2021-10-12 2022-02-18 中新国际联合研究院 基于元伪标签和光照不变特征的人脸欺骗检测方法和系统
CN114970673B (zh) * 2022-04-19 2023-04-07 华为技术有限公司 一种半监督模型训练方法、系统及相关设备

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108009589A (zh) * 2017-12-12 2018-05-08 腾讯科技(深圳)有限公司 样本数据处理方法、装置和计算机可读存储介质
US20210279644A1 (en) * 2020-03-06 2021-09-09 International Business Machines Corporation Modification of Machine Learning Model Ensembles Based on User Feedback
CN112183099A (zh) * 2020-10-09 2021-01-05 上海明略人工智能(集团)有限公司 基于半监督小样本扩展的命名实体识别方法及系统
CN113705769A (zh) * 2021-05-17 2021-11-26 华为技术有限公司 一种神经网络训练方法以及装置
CN114330588A (zh) * 2022-01-04 2022-04-12 杭州网易智企科技有限公司 一种图片分类方法、图片分类模型训练方法及相关装置

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2023202596A1 (zh) * 2022-04-19 2023-10-26 华为技术有限公司 一种半监督模型训练方法、系统及相关设备
CN115471717A (zh) * 2022-09-20 2022-12-13 北京百度网讯科技有限公司 模型的半监督训练、分类方法装置、设备、介质及产品
CN117151200A (zh) * 2023-10-27 2023-12-01 成都合能创越软件有限公司 基于半监督训练提升yolo检测模型精度方法及系统

Also Published As

Publication number Publication date
CN114970673B (zh) 2023-04-07
WO2023202596A1 (zh) 2023-10-26

Similar Documents

Publication Publication Date Title
CN114970673B (zh) 一种半监督模型训练方法、系统及相关设备
JP7058669B2 (ja) 車両外観特徴識別及び車両検索方法、装置、記憶媒体、電子デバイス
US11176423B2 (en) Edge-based adaptive machine learning for object recognition
US20220215259A1 (en) Neural network training method, data processing method, and related apparatus
JP7425147B2 (ja) 画像処理方法、テキスト認識方法及び装置
CN112348081A (zh) 用于图像分类的迁移学习方法、相关装置及存储介质
US20230281974A1 (en) Method and system for adaptation of a trained object detection model to account for domain shift
CN113516227B (zh) 一种基于联邦学习的神经网络训练方法及设备
Sharma et al. Vehicle identification using modified region based convolution network for intelligent transportation system
CN113011568A (zh) 一种模型的训练方法、数据处理方法及设备
WO2024083121A1 (zh) 一种数据处理方法及其装置
Mittal et al. Accelerated computer vision inference with AI on the edge
Song et al. Visibility estimation via deep label distribution learning in cloud environment
Kim et al. Convolutional neural network-based multi-target detection and recognition method for unmanned airborne surveillance systems
CN116861262A (zh) 一种感知模型训练方法、装置及电子设备和存储介质
CN117688984A (zh) 神经网络结构搜索方法、装置及存储介质
CN112990305B (zh) 一种遮挡关系的确定方法、装置、设备及存储介质
Tanner et al. Large-scale outdoor scene reconstruction and correction with vision
Gao et al. Air infrared small target local dehazing based on multiple-factor fusion cascade network
Kang et al. Inception network-based weather image classification with pre-filtering process
Niroshan et al. Poly-GAN: Regularizing Polygons with Generative Adversarial Networks
US11741152B2 (en) Object recognition and detection using reinforcement learning
Yu et al. Multi-view Stereo by Fusing Monocular and a Combination of Depth Representation Methods
JP7208314B1 (ja) 学習装置、学習方法及び学習プログラム
TSAKANIKAS SCHOOL OF ENGINEERING

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant