CN111222648B - 半监督机器学习优化方法、装置、设备及存储介质 - Google Patents

半监督机器学习优化方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN111222648B
CN111222648B CN202010044134.9A CN202010044134A CN111222648B CN 111222648 B CN111222648 B CN 111222648B CN 202010044134 A CN202010044134 A CN 202010044134A CN 111222648 B CN111222648 B CN 111222648B
Authority
CN
China
Prior art keywords
machine learning
loss function
learning model
sample
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010044134.9A
Other languages
English (en)
Other versions
CN111222648A (zh
Inventor
魏锡光
鞠策
李�权
曹祥
刘洋
陈天健
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
WeBank Co Ltd
Original Assignee
WeBank Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by WeBank Co Ltd filed Critical WeBank Co Ltd
Priority to CN202010044134.9A priority Critical patent/CN111222648B/zh
Publication of CN111222648A publication Critical patent/CN111222648A/zh
Application granted granted Critical
Publication of CN111222648B publication Critical patent/CN111222648B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Abstract

本发明公开了一种半监督机器学习优化方法、装置、设备及存储介质,所述方法包括:获取训练样本,其中,训练样本包括有标签样本和无标签样本;根据训练样本与有标签样本之间的相似度以及有标签样本的真实标签,计算得到训练样本对应的伪标签;将训练样本的数据输入初始化机器学习模型得到训练样本对应的第一预测标签,并根据第一预测标签和伪标签计算初始化机器学习模型的损失函数;基于损失函数对初始化机器学习模型进行参数更新,迭代训练直到损失函数收敛时得到目标机器学习模型。本发明实现了使得采用少量的有标签数据也可训练得到效果很好的模型,节省了人工对数据进行标注的人力物力。

Description

半监督机器学习优化方法、装置、设备及存储介质
技术领域
本发明涉及人工智能领域,尤其涉及一种半监督机器学习优化方法、装置、设备及存储介质。
背景技术
随着人工智能技术的发展,机器学习也越来越多地被应用于各个领域。现今的机器学习通常需要大量的带标签数据训练机器学习模型才能够获得好的效果,训练数据可能很容易获得,而打标签则需要专门的人力完成。因此通常会面临有标签的数据比较少,而无标签数据比较多的情况,从而导致机器学习无法获得很好的效果,因此,如何采用少量的标签数据即可训练得到效果很好的模型,以节省人力物力成为了一个亟待解决的问题。
发明内容
本发明的主要目的在于提供一种半监督机器学习优化方法、装置、设备及存储介质,旨在解决如何采用少量的标签数据即可训练得到效果很好的模型,以节省人力物力的问题。
为实现上述目的,本发明提供一种半监督机器学习优化方法,所述半监督机器学习优化方法包括以下步骤:
获取训练样本,其中,所述训练样本包括有标签样本和无标签样本;
根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签;
将所述训练样本的数据输入初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
可选地,所述根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签的步骤包括:
采用所述初始化机器学习模型中的特征抽取层提取所述训练样本的特征;
根据所述训练样本的特征计算所述训练样本与所述有标签样本之间的相似度;
采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签。
可选地,所述根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数的步骤包括:
根据所述第一预测标签和所述伪标签计算双边一致性损失函数;
根据所述有标签样本计算有监督损失函数;
根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数。
可选地,所述根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数的步骤包括:
对所述训练样本进行数据增广得到增广样本;
将所述增广样本输入所述初始化机器学习模型得到第二预测标签;
根据所述第一预测标签和所述第二预测标签计算自监督一致性损失函数;
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数。
可选地,所述将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数的步骤包括:
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行加权平均,得到所述初始化机器学习模型的损失函数。
可选地,所述根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签的步骤之前,还包括:
采用所述有标签样本对待训练的机器学习模型进行有监督训练得到所述初始化机器学习模型。
可选地,所述训练样本是图像,所述有标签样本包括图像和图像中人脸的位置标注,所述目标机器学习模型用于对图像进行人脸位置检测,
所述基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型的步骤之后,还包括:
将待检测图像输入所述目标机器学习模型得到所述待检测图像中人脸位置的检测结果。
为实现上述目的,本发明还提供一种半监督机器学习优化装置,所述半监督机器学习优化装置包括:
获取模块,用于获取训练样本,其中,所述训练样本包括有标签样本和无标签样本;
计算模块,用于根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签;
输入模块,用于将所述训练样本的数据输入初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
训练模块,用于基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
为实现上述目的,本发明还提供一种半监督机器学习优化设备,所述半监督机器学习优化设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的半监督机器学习优化程序,所述半监督机器学习优化程序被所述处理器执行时实现如上所述的半监督机器学习优化方法的步骤。
此外,为实现上述目的,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有半监督机器学习优化程序,所述半监督机器学习优化程序被处理器执行时实现如上所述的半监督机器学习优化方法的步骤。
本发明中,通过获取包括有标签样本和无标签样本的训练样本,并根据训练样本与有标签样本之间的相似度以及有标签样本的真实标签,为训练样本打上伪标签,使得各个训练样本都获取一个伪标签,从而能够利用包括无标签样本和有标签样本的所有训练样本来对机器学习模型进行训练,从而解决了无标签样本不能利用于机器学习训练的问题,从而使得采用少量的有标签数据和大量的无标签数据即可训练得到效果很好的模型,节省了人工对数据进行标注的人力物力,从而使得机器学习能够应用到更广的领域。并且,由于是利用了训练样本与有标签样本之前的相似度来计算伪标签,使得伪标签更接近于训练样本的真实标签,从而使得机器学习模型的训练效果得到显著提升。
附图说明
图1是本发明实施例方案涉及的硬件运行环境的结构示意图;
图2为本发明半监督机器学习优化方法第一实施例的流程示意图;
图3为本发明实施例涉及的一种混合监督机器学习模型训练框架;
图4本发明半监督机器学习优化装置较佳实施例的功能示意图模块图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
如图1所示,图1是本发明实施例方案涉及的硬件运行环境的设备结构示意图。
需要说明的是,本发明实施例半监督机器学习优化设备可以是智能手机、个人计算机和服务器等设备,在此不做具体限制。
如图1所示,该半监督机器学习优化设备可以包括:处理器1001,例如CPU,网络接口1004,用户接口1003,存储器1005,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(Display)、输入单元比如键盘(Keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如WI-FI接口)。存储器1005可以是高速RAM存储器,也可以是稳定的存储器(non-volatile memory),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的设备结构并不构成对半监督机器学习优化设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及半监督机器学习优化程序。其中,操作系统是管理和控制设备硬件和软件资源的程序,支持半监督机器学习优化程序以及其它软件或程序的运行。
在图1所示的设备中,用户接口1003主要用于与客户端进行数据通信;网络接口1004主要用于与各参与设备建立通信连接;而处理器1001可以用于调用存储器1005中存储的半监督机器学习优化程序,并执行以下操作:
获取训练样本,其中,所述训练样本包括有标签样本和无标签样本;
根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签;
将所述训练样本的数据输入初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
进一步地,所述根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签的步骤包括:
采用所述初始化机器学习模型中的特征抽取层提取所述训练样本的特征;
根据所述训练样本的特征计算所述训练样本与所述有标签样本之间的相似度;
采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签。
进一步地,所述根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数的步骤包括:
根据所述第一预测标签和所述伪标签计算双边一致性损失函数;
根据所述有标签样本计算有监督损失函数;
根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数。
进一步地,所述根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数的步骤包括:
对所述训练样本进行数据增广得到增广样本;
将所述增广样本输入所述初始化机器学习模型得到第二预测标签;
根据所述第一预测标签和所述第二预测标签计算自监督一致性损失函数;
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数。
进一步地,所述将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数的步骤包括:
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行加权平均,得到所述初始化机器学习模型的损失函数。
进一步地,所述根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签的步骤之前,还包括:
采用所述有标签样本对待训练的机器学习模型进行有监督训练得到所述初始化机器学习模型。
进一步地,所述训练样本是图像,所述有标签样本包括图像和图像中人脸的位置标注,所述目标机器学习模型用于对图像进行人脸位置检测,
所述基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型的步骤之后,还包括:
将待检测图像输入所述目标机器学习模型得到所述待检测图像中人脸位置的检测结果。
基于上述的结构,提出半监督机器学习优化方法的各个实施例。
参照图2,图2为本发明半监督机器学习优化方法第一实施例的流程示意图。
本发明实施例提供了半监督机器学习优化方法的实施例,需要说明的是,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。本发明半监督机器学习优化方法各个实施例的执行主体可以是智能手机、个人计算机和服务器等设备,为便于描述,以下各实施例中省略执行主体进行阐述。在本实施例中,半监督机器学习优化方法包括:
步骤S10,获取训练样本,其中,所述训练样本包括有标签样本和无标签样本;
获取训练样本,其中,训练样本包括多条样本,既包括有标签样本也包括无标签样本。需要说明的是,一条无标签样本包括一条数据,一条有标签样本包括一条数据以及该数据对应的标签。无标签样本的数量可以远远大于有标签样本的数量,为节省人工打标签的人力物力,有标签数据的数量可以不用过多,也即,相比于常规的有监督学习方式采用大量的有标签数据,在实施例提供的半监督机器学习优化方案中,可采用较少的有标签数据。
根据具体的机器学习任务不同,可以获取不同的训练样本。例如,机器学习的任务是采用神经网络模型对图像进行人脸位置进行检测,则获取的训练样本是多个图像,有标签样本还包括图像中人脸的位置标签。又如,机器学习的任务是采用决策树模型进行用户购买意向的预测,则获取到的训练样本是多个用户数据,有标签样本还包括用户的购买意向标签。
步骤S20,根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签;
根据训练样本与有标签样本之间的相似度以及有标签样本的真实标签,计算得到训练样本对应的伪标签。具体地,在本实施例中,利用各训练样本与各有标签样本之间的相似度,来为各训练样本打上一个伪标签,也即,根据一条样本与有标签样本的相似度高,则该样本的真实标签与该有标签样本的真实标签越接近的原理,为样本打上与它相似度高的有标签样本类似或相同的标签作为该样本的伪标签,从而将该样本扩充为了一个有标签样本。需要说明的是,可以对所有的训练样本都打上伪标签,也即,忽略训练样本中的有标签样本的标签,将该有标签样本也作为无标签样本给其打上伪标签,这样可以提高训练数据的利用率。
具体计算相似度的方式有多种,如计算两样本之间的相似度,可将两样本的数据作为向量,采用传统的向量间相似度衡量方案,也可以计算两个样本之间在特征空间的相似度;根据相似度和真实标签计算伪标签的方式也有多种,如采用与该样本的相似度最高的有标签样本的真实标签作为该样本的伪标签,在此不作限定。
进一步地,步骤S20包括:
步骤S201,采用所述初始化机器学习模型中的特征抽取层提取所述训练样本的特征;
进一步地,在本实施例中,采用初始化机器学习模型中的特征抽取层提取训练样本的特征。具体地,初始化机器学习模型可以包括特征抽取层和预测层,特征抽取层用于提取样本数据的特征,预测层用于根据特征完成预测任务,如果机器学习模型是用于分类任务,则预测层用于根据特征完成分类任务。特征抽取层对输入的数据进行特征提取,得到向量形式的特征。
步骤S202,根据所述训练样本的特征计算所述训练样本与所述有标签样本之间的相似度;
根据训练样本的特征计算训练样本与有标签样本之间的相似度。也即,对每一个训练样本(每一个有标签样本和每一个无标签样本),采用该训练样本的特征,分别与每个有标签样本的特征计算相似度,即得到该训练样本分别与每个有标签样本的相似度。具体地,可采用传统的向量件相似度衡量方案来计算两个特征向量之间的相似度,如采用余弦相似性、欧几里得距离等。
步骤S203,采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签。
在计算得到一个训练样本与各个有标签样本的相似度后,可以采用该样本与各个有标签样本的相似度作为权重,对各个有标签样本的真实标签进行加权平均,得到该训练样本对应的伪标签。采用同样的方法,计算得到每个训练样本的伪标签。例如,有三个训练样本(U1、U2、U3),其中U1和U2是有标签样本,标签分别是Y1、Y2,U3是无标签样本,计算得到U3与U1、U2的相似度分别为P1、P2,则U3的伪标签是:(Y1*P1+Y2*P2)/(P1+P2),采用同样的方法,计算得到U1和U2的伪标签。
步骤S30,将所述训练样本的数据输入初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
将训练样本的数据输入初始化机器学习模型得到训练样本对应的第一预测标签,根据第一预测标签和伪标签计算初始化机器学习模型的损失函数。具体地,可以采用第一预测标签和伪标签构造一个损失函数,该损失函数由于采用的是伪标签而不是真实标签,在本实施例中将该损失函数命名为双边监督一致性损失函数,以区别于有监督损失函数。在本实施例中,可以将该双边监督一致性损失函数作为初始化机器学习模型的损失函数。可以是对机器学习模型的模型参数给定一个初始值,将拥有初始值的机器学习模型作为初始化机器学习模型。
步骤S40,基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
基于计算得到的初始化机器学习模型的损失函数对初始化机器学习模型进行参数更新。具体地,可以检测损失函数是否收敛;如检测损失函数的值是否小于一个预设阈值,若小于,则确定损失函数收敛,若不小于,则确定损失函数未收敛;还可以是检测迭代训练的次数是否大于预设次数,若大于,则确定损失函数收敛,若不大于,则确定损失函数未收敛;还可以是检测迭代训练的时间是否大于预设时间,若大于,则确定损失函数收敛,若不大于,则确定损失函数未收敛。若损失函数未收敛,则根据损失函数计算机器学习模型的各个模型参数的梯度值,根据梯度值来更新各个模型参数;再采用训练样本输入更新模型参数后的机器学习模型,得到新的预测标签,再计算新的伪标签,根据新的预测标签和伪标签计算新的损失函数,再进行收敛判断;若收敛,则停止训练,不再更新模型参数,得到目标机器学习模型,若未收敛,则继续训练。
在本实施例中,通过获取包括有标签样本和无标签样本的训练样本,并根据训练样本与有标签样本之间的相似度以及有标签样本的真实标签,为训练样本打上伪标签,使得各个训练样本都获取一个伪标签,从而能够利用包括无标签样本和有标签样本的所有训练样本来对机器学习模型进行训练,从而解决了无标签样本不能利用于机器学习训练的问题,从而使得采用少量的有标签数据和大量的无标签数据即可训练得到效果很好的模型,节省了人工对数据进行标注的人力物力,从而使得机器学习能够应用到更广的领域。并且,由于是利用了训练样本与有标签样本之前的相似度来计算伪标签,使得伪标签更接近于训练样本的真实标签,从而使得机器学习模型的训练效果得到显著提升。
进一步地,在步骤S20之前,还包括:
步骤S50,采用所述有标签样本对待训练的机器学习模型进行有监督训练得到所述初始化机器学习模型。
在获取到训练样本之后,可以先采用训练样本中的有标签样本对待训练的机器学习模型进行有监督训练得到初始化机器学习模型。具体地有监督训练方式与传统的有监督训练方式相同,在此不作详细的说明。通过采用有标签数据对待训练的机器学习模型进行一个初始化的训练,使得后续模型训练有一个相对较优化的模型参数作为基础,从而使得后续的训练过程能够更快速地达到收敛,节省训练时间,也节省计算机的计算资源。
进一步地,基于上述第一实施例,提出本发明半监督机器学习优化方法第二实施例,在本发明半监督机器学习优化方法第二实施例中,所述步骤S30中根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数的步骤包括:
步骤S301,根据所述第一预测标签和所述伪标签计算双边一致性损失函数;
可以采用第一预测标签和伪标签构造一个损失函数,该损失函数由于采用的是伪标签而不是真实标签,在本实施例中将该损失函数命名为双边监督一致性损失函数,以区别于有监督损失函数。也即双边监督一致性损失函数的计算方法与现有的有监督损失函数的计算方法类似,不同之处在于采用的不是真实标签而是伪标签。
步骤S302,根据所述有标签样本计算有监督损失函数;
采用有标签样本的真实标签和有标签样本的数据输入初始化机器学习模型得到的第一预测标签计算有监督损失函数。有监督损失函数可采用现有的损失函数计算方法,在此不作赘述。
步骤S303,根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数。
根据双边一致性损失函数和有监督损失函数计算初始化机器学习模型的损失函数。具体地,可以将双边一致性损失函数和有监督损失函数进行融合,得到初始化机器学习模型,融合可以是计算加权平均。
在本实施例中,通过采用双边一致性损失函数和有监督损失函数一起构造初始化机器学习模型的损失函数,利用了有标签数据对机器学习模型进行有监督学习,可以使得最终得到的目标机器学习模型的效果更好。
进一步地,在另一实施方式中,参照图3所示,为一种混合监督机器学习模型训练框架,步骤S303包括:
步骤S3031,对所述训练样本进行数据增广得到增广样本;
步骤S3032,将所述增广样本输入所述初始化机器学习模型得到第二预测标签;
步骤S3033,根据所述第一预测标签和所述第二预测标签计算自监督一致性损失函数;
步骤S3034,将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数
对训练样本进行数据增广得到增广样本。具体地,数据增广的目的是对训练样本进行一些变化,根据训练样本的类型不同,数据增广方式不同,如训练样本是图像,则可以采用平移、旋转或缩放等方式进行数据增广。
将增广样本分别输入初始化机器学习模型得到各个增广样本对应的第二预测标签。
根据第一预测标签和第二预测标签计算一个损失函数,具体可采用现有的损失函数构建方式,由于是采用预测标签与预测标签来计算损失函数,因此将该损失函数称为自监督一致性损失函数,以区别于前述的双边一致性损失函数和有监督损失函数。将双边一致性损失函数、有监督损失函数和自监督一致性损失函数进行融合,得到初始化机器学习模型的损失函数,如图3中,将a、b、c三部分获得的损失函数进行融合,得到最终的损失函数。其中,融合的方式有多种,可以采用相加的方式,也可以是加权平均的方式等等。
在本实施例中,通过对训练数据进行数据增广得到增广样本,使得训练数据扩增了至少一倍,再采用增广样本输入机器学习模型得到第二预测标签,根据第二预测标签与训练样本输入机器学习模型得到的第一预测标签计算自监督一致性损失函数,再根据双边一致性损失函数、有监督损失函数和自监督一致性损失函数得到机器学习模型的损失函数,使得训练样本得到充分的利用,从而使得在训练样本少的情况下,在有标签样本少的情况下,也能够训练得到效果好的机器学习模型,从而降低了人工采集数据和人工打标注的人力物力。
进一步地,基于上述第一和第二实施例,提出本发明半监督机器学习优化方法第三实施例,在本发明半监督机器学习优化方法第三实施例中,所述训练样本是图像,所述有标签样本包括图像和图像中人脸的位置标注,所述目标机器学习模型用于对图像进行人脸位置检测,所述步骤S40之后,还包括:
步骤S60,将待检测图像输入所述目标机器学习模型得到所述待检测图像中人脸位置的检测结果。
在本实施例中,当机器学习的任务是对图像进行人脸位置检测时,获取到的训练样本是大量包含人脸的图像,有标签样本则包括图像和图像中人脸的位置标注。由于图像很容易通过摄像头终端采集,但人脸位置标注却需要人工一个一个进行标注,会花费较多的人力物力和时间。为克服这一问题,可以采用上述实施例中的半监督机器学习优化方案利用包含大量无标签样本和少量有标签样本的图像训练样本对机器学习模型进行训练,得到用于对图像进行人脸位置检测的目标机器学习模型。
采用该目标机器学习模型对待检测图像进行人脸位置检测,具体地,可获取待检测图像,将待检测图像输入目标机器学习模型,机器学习模型直接输出得到待检测图像中人脸位置的检测结果。需要说明的是,根据机器学习模型的结构设计不同,人脸位置的检测结果的形式不同,可以是输出表示人脸所在位置的坐标,或输出一张人脸区域颜色不同于其他区域的图片,在此不作具体限制。
通过采用上述实施例中的半监督机器学习优化方案来人脸位置检测的机器学习模型进行训练,使得在有标签样本少的情况下,也能够获得人脸位置检测准确率高的机器学习模型,从而降低了人工采集数据和人工打标注的人力物力。
需要说明的是,本发明实施例涉及的半监督机器学习模型还可以应用于其他预测或分类任务,如还可以应用于绩效等级预测、论文价值评价等。
此外,此外本发明实施例还提出一种半监督机器学习优化装置,参照图4,所述半监督机器学习优化装置包括:
获取模块10,用于获取训练样本,其中,所述训练样本包括有标签样本和无标签样本;
计算模块20,用于根据所述训练样本与所述有标签样本之间的相似度以及所述有标签样本的真实标签,计算得到所述训练样本对应的伪标签;
输入模块30,用于将所述训练样本的数据输入初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
训练模块40,用于基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
进一步地,所述计算模块20包括:
提取单元,用于采用所述初始化机器学习模型中的特征抽取层提取所述训练样本的特征;
第一计算单元,用于根据所述训练样本的特征计算所述训练样本与所述有标签样本之间的相似度;
第二计算单元,用于采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签。
进一步地,所述输入模块30包括:
第三计算单元,用于根据所述第一预测标签和所述伪标签计算双边一致性损失函数;
第四计算单元,用于根据所述有标签样本计算有监督损失函数;
第五计算单元,用于根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数。
进一步地,所述第五计算单元包括:
数据增广子单元,用于对所述训练样本进行数据增广得到增广样本;
输入子单元,用于将所述增广样本输入所述初始化机器学习模型得到第二预测标签;
计算子单元,用于根据所述第一预测标签和所述第二预测标签计算自监督一致性损失函数;
融合子单元,用于将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数。
进一步地,所述融合子单元用于:将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行加权平均,得到所述初始化机器学习模型的损失函数。
进一步地,所述半监督机器学习优化装置还包括:
有监督训练模块,用于采用所述有标签样本对待训练的机器学习模型进行有监督训练得到所述初始化机器学习模型。
进一步地,所述训练样本是图像,所述有标签样本包括图像和图像中人脸的位置标注,所述目标机器学习模型用于对图像进行人脸位置检测,所述半监督机器学习优化装置还包括:
检测模块,用于将待检测图像输入所述目标机器学习模型得到所述待检测图像中人脸位置的检测结果。
本发明半监督机器学习优化装置的具体实施方式的拓展内容与上述半监督机器学习优化方法各实施例基本相同,在此不做赘述。
此外,本发明实施例还提出一种计算机可读存储介质,所述存储介质上存储有半监督机器学习优化程序,所述半监督机器学习优化程序被处理器执行时实现如下所述的半监督机器学习优化方法的步骤。
本发明半监督机器学习优化设备和计算机可读存储介质的各实施例,均可参照本发明半监督机器学习优化方法各个实施例,此处不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。

Claims (9)

1.一种半监督机器学习优化方法,其特征在于,所述半监督机器学习优化方法包括以下步骤:
获取训练样本,其中,所述训练样本包括有标签样本和无标签样本,所述训练样本是图像;
采用所述有标签样本对待训练的机器学习模型进行有监督训练得到初始化机器学习模型;
计算所述训练样本与所述有标签样本之间的相似度;
采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签;
将所述训练样本的数据输入所述初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
2.如权利要求1所述的半监督机器学习优化方法,其特征在于,所述计算所述训练样本与所述有标签样本之间的相似度的步骤包括:
采用所述初始化机器学习模型中的特征抽取层提取所述训练样本的特征;
根据所述训练样本的特征计算所述训练样本与所述有标签样本之间的相似度。
3.如权利要求1所述的半监督机器学习优化方法,其特征在于,所述根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数的步骤包括:
根据所述第一预测标签和所述伪标签计算双边一致性损失函数;
根据所述有标签样本计算有监督损失函数;
根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数。
4.如权利要求3所述的半监督机器学习优化方法,其特征在于,所述根据所述双边一致性损失函数和所述有监督损失函数计算所述初始化机器学习模型的损失函数的步骤包括:
对所述训练样本进行数据增广得到增广样本;
将所述增广样本输入所述初始化机器学习模型得到第二预测标签;
根据所述第一预测标签和所述第二预测标签计算自监督一致性损失函数;
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数。
5.如权利要求4所述的半监督机器学习优化方法,其特征在于,所述将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行融合,得到所述初始化机器学习模型的损失函数的步骤包括:
将所述双边一致性损失函数、所述有监督损失函数和所述自监督一致性损失函数进行加权平均,得到所述初始化机器学习模型的损失函数。
6.如权利要求1至5任一项所述的半监督机器学习优化方法,其特征在于,所述有标签样本包括图像和图像中人脸的位置标注,所述目标机器学习模型用于对图像进行人脸位置检测,
所述基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型的步骤之后,还包括:
将待检测图像输入所述目标机器学习模型得到所述待检测图像中人脸位置的检测结果。
7.一种半监督机器学习优化装置,其特征在于,所述半监督机器学习优化装置包括:
获取模块,用于获取训练样本,其中,所述训练样本包括有标签样本和无标签样本,所述训练样本是图像;
初始模块,用于采用所述有标签样本对待训练的机器学习模型进行有监督训练得到初始化机器学习模型;
计算模块,用于计算所述训练样本与所述有标签样本之间的相似度,再采用所述训练样本与各所述有标签样本之间的相似度做为权重,对各所述有标签样本的真实标签进行加权平均,得到所述训练样本对应的伪标签;
输入模块,用于将所述训练样本的数据输入所述初始化机器学习模型得到所述训练样本对应的第一预测标签,并根据所述第一预测标签和所述伪标签计算所述初始化机器学习模型的损失函数;
训练模块,用于基于所述损失函数对所述初始化机器学习模型进行参数更新,迭代训练直到所述损失函数收敛时得到目标机器学习模型。
8.一种半监督机器学习优化设备,其特征在于,所述半监督机器学习优化设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的半监督机器学习优化程序,所述半监督机器学习优化程序被所述处理器执行时实现如权利要求1至6中任一项所述的半监督机器学习优化方法的步骤。
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有半监督机器学习优化程序,所述半监督机器学习优化程序被处理器执行时实现如权利要求1至6中任一项所述的半监督机器学习优化方法的步骤。
CN202010044134.9A 2020-01-15 2020-01-15 半监督机器学习优化方法、装置、设备及存储介质 Active CN111222648B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010044134.9A CN111222648B (zh) 2020-01-15 2020-01-15 半监督机器学习优化方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010044134.9A CN111222648B (zh) 2020-01-15 2020-01-15 半监督机器学习优化方法、装置、设备及存储介质

Publications (2)

Publication Number Publication Date
CN111222648A CN111222648A (zh) 2020-06-02
CN111222648B true CN111222648B (zh) 2023-09-26

Family

ID=70831864

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010044134.9A Active CN111222648B (zh) 2020-01-15 2020-01-15 半监督机器学习优化方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN111222648B (zh)

Families Citing this family (37)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111784595B (zh) * 2020-06-10 2023-08-29 北京科技大学 一种基于历史记录的动态标签平滑加权损失方法及装置
CN111740991B (zh) * 2020-06-19 2022-08-09 上海仪电(集团)有限公司中央研究院 一种异常检测方法及系统
CN111724867B (zh) * 2020-06-24 2022-09-09 中国科学技术大学 分子属性测定方法、装置、电子设备及存储介质
CN111783870B (zh) * 2020-06-29 2023-09-01 北京百度网讯科技有限公司 人体属性的识别方法、装置、设备及存储介质
CN111917740B (zh) * 2020-07-15 2022-08-26 杭州安恒信息技术股份有限公司 一种异常流量告警日志检测方法、装置、设备及介质
CN112102062A (zh) * 2020-07-24 2020-12-18 北京淇瑀信息科技有限公司 一种基于弱监督学习的风险评估方法、装置及电子设备
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN112183321A (zh) * 2020-09-27 2021-01-05 深圳奇迹智慧网络有限公司 机器学习模型优化的方法、装置、计算机设备和存储介质
CN113392864A (zh) * 2020-10-13 2021-09-14 腾讯科技(深圳)有限公司 模型生成方法及视频筛选方法、相关装置、存储介质
CN112418264A (zh) * 2020-10-14 2021-02-26 上海眼控科技股份有限公司 检测模型的训练方法、装置、目标检测方法、设备和介质
CN112381116B (zh) * 2020-10-21 2022-10-28 福州大学 基于对比学习的自监督图像分类方法
CN112417986B (zh) * 2020-10-30 2023-03-10 四川天翼网络股份有限公司 一种基于深度神经网络模型的半监督在线人脸识别方法及系统
CN112307472A (zh) * 2020-11-03 2021-02-02 平安科技(深圳)有限公司 基于智能决策的异常用户识别方法、装置及计算机设备
CN112381098A (zh) * 2020-11-19 2021-02-19 上海交通大学 基于目标分割领域自学习的半监督学习方法和系统
CN112287089B (zh) * 2020-11-23 2022-09-20 腾讯科技(深圳)有限公司 用于自动问答系统的分类模型训练、自动问答方法及装置
CN112257855B (zh) * 2020-11-26 2022-08-16 Oppo(重庆)智能科技有限公司 一种神经网络的训练方法及装置、电子设备及存储介质
CN112417767B (zh) * 2020-12-09 2024-02-27 东软睿驰汽车技术(沈阳)有限公司 一种衰减趋势确定模型构建方法、衰减趋势确定方法
CN112541904B (zh) * 2020-12-16 2023-03-24 西安电子科技大学 一种无监督遥感图像变化检测方法、存储介质及计算设备
CN112733275B (zh) * 2021-01-19 2023-07-25 中国人民解放军军事科学院国防科技创新研究院 基于半监督学习的卫星组件热布局温度场预测方法
CN112784749B (zh) * 2021-01-22 2023-11-10 北京百度网讯科技有限公司 目标模型的训练方法、目标对象的识别方法、装置及介质
CN112598091B (zh) * 2021-03-08 2021-09-07 北京三快在线科技有限公司 一种训练模型和小样本分类的方法及装置
CN113724189A (zh) * 2021-03-17 2021-11-30 腾讯科技(深圳)有限公司 图像处理方法、装置、设备及存储介质
CN113158554B (zh) * 2021-03-25 2023-02-14 腾讯科技(深圳)有限公司 模型优化方法、装置、计算机设备及存储介质
CN113095423A (zh) * 2021-04-21 2021-07-09 南京大学 一种基于在线反绎学习的流式数据分类方法及其实现装置
CN113420786A (zh) * 2021-05-31 2021-09-21 杭州电子科技大学 一种特征混合图像的半监督分类方法
CN113282921A (zh) * 2021-06-11 2021-08-20 深信服科技股份有限公司 一种文件检测方法、装置、设备及存储介质
CN113591914A (zh) * 2021-06-28 2021-11-02 中国平安人寿保险股份有限公司 一种数据分类方法、装置、计算机设备和存储介质
CN113688665B (zh) * 2021-07-08 2024-02-20 华中科技大学 一种基于半监督迭代学习的遥感影像目标检测方法及系统
CN113516251B (zh) * 2021-08-05 2023-06-06 上海高德威智能交通系统有限公司 一种机器学习系统及模型训练方法
CN113780389B (zh) * 2021-08-31 2023-05-26 中国人民解放军战略支援部队信息工程大学 基于一致性约束的深度学习半监督密集匹配方法及系统
CN113743618A (zh) * 2021-09-03 2021-12-03 北京航空航天大学 时间序列数据处理方法、装置、可读介质及电子设备
CN114118259A (zh) * 2021-11-19 2022-03-01 杭州海康威视数字技术股份有限公司 一种目标检测方法及装置
CN114186615B (zh) * 2021-11-22 2022-07-08 浙江华是科技股份有限公司 船舶检测半监督在线训练方法、装置及计算机存储介质
CN114462621A (zh) * 2022-01-06 2022-05-10 深圳安巽科技有限公司 一种机器监督学习方法、装置
CN114529759B (zh) * 2022-01-25 2023-01-17 北京医准智能科技有限公司 一种甲状腺结节的分类方法、装置及计算机可读介质
CN115272777B (zh) * 2022-09-26 2022-12-23 山东大学 面向输电场景的半监督图像解析方法
CN117332090B (zh) * 2023-11-29 2024-02-23 苏州元脑智能科技有限公司 一种敏感信息识别方法、装置、设备和存储介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2009075737A (ja) * 2007-09-19 2009-04-09 Nec Corp 半教師あり学習方法、半教師あり学習装置及び半教師あり学習プログラム
CN108416370A (zh) * 2018-02-07 2018-08-17 深圳大学 基于半监督深度学习的图像分类方法、装置和存储介质
CN108764281A (zh) * 2018-04-18 2018-11-06 华南理工大学 一种基于半监督自步学习跨任务深度网络的图像分类方法
CN109146847A (zh) * 2018-07-18 2019-01-04 浙江大学 一种基于半监督学习的晶圆图批量分析方法
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、系统和计算机可读存储介质
CN110472533A (zh) * 2019-07-31 2019-11-19 北京理工大学 一种基于半监督训练的人脸识别方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2009075737A (ja) * 2007-09-19 2009-04-09 Nec Corp 半教師あり学習方法、半教師あり学習装置及び半教師あり学習プログラム
CN108416370A (zh) * 2018-02-07 2018-08-17 深圳大学 基于半监督深度学习的图像分类方法、装置和存储介质
CN108764281A (zh) * 2018-04-18 2018-11-06 华南理工大学 一种基于半监督自步学习跨任务深度网络的图像分类方法
CN109146847A (zh) * 2018-07-18 2019-01-04 浙江大学 一种基于半监督学习的晶圆图批量分析方法
CN110472533A (zh) * 2019-07-31 2019-11-19 北京理工大学 一种基于半监督训练的人脸识别方法
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、系统和计算机可读存储介质

Also Published As

Publication number Publication date
CN111222648A (zh) 2020-06-02

Similar Documents

Publication Publication Date Title
CN111222648B (zh) 半监督机器学习优化方法、装置、设备及存储介质
US10769496B2 (en) Logo detection
CN108229489B (zh) 关键点预测、网络训练、图像处理方法、装置及电子设备
US11837017B2 (en) System and method for face recognition based on dynamic updating of facial features
CN111598164B (zh) 识别目标对象的属性的方法、装置、电子设备和存储介质
JP2022532460A (ja) モデル訓練方法、装置、端末及びプログラム
CN108229673B (zh) 卷积神经网络的处理方法、装置和电子设备
CN112232293A (zh) 图像处理模型训练、图像处理方法及相关设备
CN111488873B (zh) 一种基于弱监督学习的字符级场景文字检测方法和装置
CN114511041B (zh) 模型训练方法、图像处理方法、装置、设备和存储介质
US11170581B1 (en) Supervised domain adaptation
CN112784835B (zh) 圆形印章的真实性识别方法、装置、电子设备及存储介质
CN114092759A (zh) 图像识别模型的训练方法、装置、电子设备及存储介质
CN113052295A (zh) 一种神经网络的训练方法、物体检测方法、装置及设备
CN112614117A (zh) 设备区域提取模型训练方法、设备区域提取方法及装置
CN114429577B (zh) 一种基于高置信标注策略的旗帜检测方法及系统及设备
CN113963186A (zh) 目标检测模型的训练方法、目标检测方法及相关装置
CN111476144B (zh) 行人属性识别模型确定方法、装置及计算机可读存储介质
KR20160128869A (ko) 사전 정보를 이용한 영상 물체 탐색 방법 및 이를 수행하는 장치
CN116977271A (zh) 缺陷检测方法、模型训练方法、装置及电子设备
CN112822393B (zh) 图像处理方法、装置及电子设备
CN115439734A (zh) 质量评估模型训练方法、装置、电子设备及存储介质
CN111124862B (zh) 智能设备性能测试方法、装置及智能设备
CN114972910A (zh) 图文识别模型的训练方法、装置、电子设备及存储介质
CN112348060A (zh) 分类向量生成方法、装置、计算机设备和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant