CN114492843A - 一种基于半监督学习的分类方法、设备及存储介质 - Google Patents
一种基于半监督学习的分类方法、设备及存储介质 Download PDFInfo
- Publication number
- CN114492843A CN114492843A CN202210135599.4A CN202210135599A CN114492843A CN 114492843 A CN114492843 A CN 114492843A CN 202210135599 A CN202210135599 A CN 202210135599A CN 114492843 A CN114492843 A CN 114492843A
- Authority
- CN
- China
- Prior art keywords
- data
- loss function
- classification
- semi
- supervised learning
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于半监督学习的分类方法、设备及存储介质,其分类方法包括:获取训练数据以更新分类模型,基于更新后的所述分类模型对未标注数据进行伪标签预测,为预测所得的伪标签数据计算其对应的监督损失函数并对其进行正则化处理;对所述训练数据进行增广处理以获得目标数据,基于所述训练数据和所述目标数据之间的欧式距离计算出同一数据在高层语义特征上的相似性,并将其作为无监督损失函数与所述监督损失函数进行融合以获得总损失函数;根据所述总损失函数对所述分类模型进行优化,基于优化后的所述分类模型对预测样本进行分类。本发明可有效地提高了模型收敛速度及模型分类准确率,并降低了研发业务中的数据标注需求。
Description
技术领域
本发明涉及深度学习技术领域,尤其涉及一种基于半监督学习的分类方法、设备及存储介质。
背景技术
数据是驱动深度学习技术发展的主要因素之一,现实中有海量的数据,但仅有一小部分是经过标注的,目前的监督学习仅用已标注的数据进行训练,性能受限。而半监督学习同时使用已标注数据和未标注数据对模型进行优化,进而提升模型的泛化能力。而现在有的半监督学习过程中由于未标注数据未经过人工审核,其可信度较低;且针对半监督学习中无标签数据相对较多,使得模型优化难度提升,导致无法提升模型收敛速度以及模型分类准确率。
发明内容
为了克服现有技术的不足,本发明的目的之一在于提供一种基于半监督学习的分类方法,可有效地提高了模型收敛速度及模型分类准确率,并降低了研发业务中的数据标注需求。
本发明的目的之二在于提供一种电子设备。
本发明的目的之三在于提供一种计算机可读存储介质。
本发明的目的之一采用如下技术方案实现:
一种基于半监督学习的分类方法,包括:
获取训练数据以更新分类模型,基于更新后的所述分类模型对未标注数据进行伪标签预测,为预测所得的伪标签数据计算其对应的监督损失函数并对其进行正则化处理;
对所述训练数据进行增广处理以获得目标数据,基于所述训练数据和所述目标数据之间的欧式距离计算出同一数据在高层语义特征上的相似性,并将其作为无监督损失函数与所述监督损失函数进行融合以获得总损失函数;
根据所述总损失函数对所述分类模型进行优化,基于优化后的所述分类模型对预测样本进行分类。
进一步地,所述训练数据包括已标注数据和未标注数据;所述分类模型预先利用所述已标注数据训练获得。
进一步地,基于更新后的所述分类模型对未标注数据进行伪标签预测的方法为:
将所述训练数据导入所述分类模型中进行数据分类以区分出所述已标注数据以及所述未标注数据,为所述未标注数据生成对应的伪标签。
进一步地,对所述未标注数据进行伪标签预测过程中,还包括:
将分类置信度小于预设阈值的预测结果进行置零处理。
进一步地,对所述伪标签数据的监督损失函数进行正则化处理的方法包括:
利用交叉熵算法L=-∑ipilogpi对所述伪标签数据的监督损失函数进行正则化处理;其中,pi表示训练数据样本i的最大置信度。
进一步地,对所述伪标签数据的监督损失函数进行正则化处理的方法还包括:
进一步地,对增广处理前后的分类结果进行无监督损失函数计算的方法为:
根据计算同一训练数据在增广处理前和增广处理后之间的欧氏距离作为无监督损失函数;其中,fi表示第i个训练数据在增广处理前的特征向量,fAi表示第i个训练数据在增广处理后的特征向量,||·||表示求向量模长。
进一步地,对所述分类模型进行优化的方法为:
其中,L1表示所述已标注数据的损失函数,L2表示所述未标注数据的损失函数,deuclidean表示增广处理前后分类结果的无监督损失函数,λi表示第i类数据损失函数所对应的权重,q表示预测标签的最大概率值。
本发明的目的之二采用如下技术方案实现:
一种电子设备,其包括处理器、存储器及存储于所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的基于半监督学习的分类方法。
本发明的目的之三采用如下技术方案实现:
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被执行时实现上述的基于半监督学习的分类方法。
相比现有技术,本发明的有益效果在于:
本发明基于预先使用已标注数据训练获得的分类模型分类出未标注数据并为其标记上伪标签;对伪标签进行损失正则化处理有助于半监督学习训练平稳收敛;并引入额外的数据增广步骤进行无监督学习,对类内距离进行约束,提升了类内聚合能力,降低网络的优化难度,有效地提高了收敛速度及模型分类准确率,并降低了研发业务中的数据标注需求。
附图说明
图1为本发明基于半监督学习的分类方法的流程示意图。
具体实施方式
下面,结合附图以及具体实施方式,对本发明做进一步描述,需要说明的是,在不相冲突的前提下,以下描述的各实施例之间或各技术特征之间可以任意组合形成新的实施例。
实施例一
本实施例提供一种基于半监督学习的分类方法,针对生成的伪标签未经过人工审核使其可信度较低的问题,本实施例在使用伪标签计算损失时,引入联合损失函数对伪标签损失进行正则化,有助于半监督学习训练平稳收敛;同时,针对半监督学习中无标签数据多,模型优化难的问题,本实施例引入额外的数据增广模块及无监督的增广损失函数,对类内距离进行约束,降低网络的优化难度,从而有效地提高了收敛速度及模型分类准确率,并降低了研发业务中的数据标注需求。
参考图1所示,本实施例的基于半监督学习的分类方法包括具体包括如下步骤:
步骤S1:获取训练数据以更新分类模型,基于更新后的所述分类模型对未标注数据进行伪标签预测,为预测所得的所述伪标签计算其对应的监督损失函数并对其进行正则化处理;
步骤S2:对所述训练数据进行增广处理以获得目标数据,基于所述训练数据和所述目标数据之间的欧式距离计算出同一数据在高层语义特征上的相似性,并将其作为无监督损失函数与所述监督损失函数进行融合以获得总损失函数;
步骤S3:根据所述总损失函数对所述分类模型进行优化,基于优化后的所述分类模型对预测样本进行分类。
本实施例中,所述训练数据包括已标注数据和未标注数据;本实施例预先获取所述已标注数据,将所述已标注数据作为样本进行预训练以获得对应的分类模型;本方案实施例采用ResNet50作为分类模型进行预训练。
经过预训练所得的所述分类模型由于具有一定的分类能力,因此可用于预测未标注数据的伪标签;具体为:将所述已标注数据和所述未标注数据作为训练数据输入预训练所得的所述分类模型中以再次更新该模型,并在训练过程中所述分类模型为所述未标注数据生成对应的伪标签,实现将所述训练数据中的所述已标注数据以及所述伪标签数据进行区分。
本实施例使用所述分类模型对所述未标注数据进行标签预测过程中,估算出每个预测结果的分类置信度,并判断每个预测结果的分类置信度是否小于预设阈值,若存在任意一预测结果的分类置信度小于预设阈值,则对该预测结果进行置零,以减小错误标签的影响,同时减少运算量。
本实施例利用所述分类模型区分所述已标注数据和所述伪标签数据后,分别对所述已标注数据和所述伪标签数据进行监督损失计算。其中,由于伪标签未经过人工审核,其可信度较低,在使用伪标签计算损失时,本实施例采用两种方法损失进行正则化。
正则化处理的第一种方法为:降低伪标签的置信度。
基于本实施例的分类流程需对训练数据进行三种标签的分类任务,第一种分类任务是将数据分类并标记为已标注数据,第二种分类任务是将数据分类并标记为伪标签数据,第三种分类任务是将数据分类并标记为增广处理后的数据。而本实施例在进行正则化处理之前需对真实标签进行独热编码,如p=(1,0,0)表示该样本数据属于三分类任务的第一个类别,由于真实标签已经经过人工核验,其置信度为1。对于伪标签,我们采用软标签的方式进行编码,例如,某样本数据经过模型后,有0.9的置信度为第二个标签,则该样本的伪标为p=(0,0.9,0)。本实施例利用交叉熵算法L=-∑ipilogpi计算监督损失,可降低伪标签的置信度从而起到标签平滑和减小伪标签数据损失系数的作用;其中,pi表示训练数据样本i的最大置信度。
正则化处理的第二种方法为:减小多样本类别的权重系数。
由于某些类别的样本比较常见,在未标注数据中较多,为防止在半监督学习中较为常见且数量较多的样本类别主导优化方向;本实施例在计算损失函数时,需要对每个类别的损失分配一个权重,保证模型均衡优化,如下式:
其中,c表示类别数,ωi表示权重系数,Li表示第i类样本的损失函数。
本实施例为保证模型的连续性及类内距离最小化,本实施例引入数据增广模块及增广损失。本实施例采用随机擦除、随机翻转、随机旋转等方式对所述训练数据进行增广处理从而获得增广后的目标数据,增广后的目标数据经过所述分类模型进行数据分类后获得分类结果,结合增广前以及增广后的分类结果进行无监督损失计算。本实施例基于同一数据在增广前后的高层语义仍应相似的假设,引入无监督的损失函数计算同一数据在高层语义特征上的相似性,提升同一类别的类内距离,辅助半监督学习。
具体为:本实施例根据计算同一训练数据在增广处理前和增广处理后之间的欧氏距离,并将其作为所述增广训练数据的无监督损失函数以无监督损失的形式加入至训练损失中;其中,fi表示第i个训练数据在增广处理前的特征向量,fAi表示第i个训练数据在增广处理后的特征向量,i=1,2,…,n(n为实数),||·||表示求向量模长。
本实施例在获得所述已标注数据、所述伪标签数据以及增广后数据相对应的损失函数后,即可通过多个分类损失联合优化模型,以提升预测准确率。本实施例在步骤S1中得到的分类模型由于已经具备一定的分类能力,所以为了模型能逐渐适应新数据的加入,我们需要以一个不大的学习率训练模型,如10-3。而模型优化方法选用梯度下降法,模型测试可采用直推学习或归纳学习的形式;优化模型所使用的损失函数为所述已标注数据的监督损失函数、所述伪标签数据的监督损失函数,以及增广前后分类结果所对应的无监督损失函数的加权和;对所述分类模型进行优化的方法为如下式:
其中,L1表示所述已标注数据的损失函数,L2表示所述未标注数据的损失函数,deuclidean表示增广处理前后数据的损失函数,λi表示第i类数据损失函数所对应的权重,p表示真实标签,q表示预测标签的最大概率值,fi表示第i个训练数据在增广处理前的特征向量,fAi表示第i个训练数据在增广处理后的特征向量,i=1,2,…,n(n为实数),||·||表示求向量模长。
本实施例通过上述方法优化所述分类模型后,即可利用所述分类模型进行数据分类预测,提升预测准确率。采用本实施例的伪标签损失正则化有助于半监督学习训练平稳收敛,增广模块引入了无监督学习,提升了类内聚合能力,结合本实施例提出的训练流程,有效地提高了收敛速度及模型分类准确率,并降低了研发业务中的数据标注需求。
实施例二
本实施例提供一种电子设备,其包括处理器、存储器及存储于所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现实施例一中的基于半监督学习的分类方法;另外,本实施例还提供一种存储介质,其上存储有计算机程序,所述计算机程序被执行时实现上述的基于半监督学习的分类方法。
本实施例中的设备及存储介质与前述实施例中的方法是基于同一发明构思下的两个方面,在前面已经对方法实施过程作了详细的描述,所以本领域技术人员可根据前述描述清楚地了解本实施例中的设备及存储介质的结构及实施过程,为了说明书的简洁,在此就不再赘述。
上述实施方式仅为本发明的优选实施方式,不能以此来限定本发明保护的范围,本领域的技术人员在本发明的基础上所做的任何非实质性的变化及替换均属于本发明所要求保护的范围。
Claims (10)
1.一种基于半监督学习的分类方法,其特征在于,包括:
获取训练数据以更新分类模型,基于更新后的所述分类模型对未标注数据进行伪标签预测,为预测所得的伪标签数据计算其对应的监督损失函数并对其进行正则化处理;
对所述训练数据进行增广处理以获得目标数据,基于所述训练数据和所述目标数据之间的欧式距离计算出同一数据在高层语义特征上的相似性,并将其作为无监督损失函数与所述监督损失函数进行融合以获得总损失函数;
根据所述总损失函数对所述分类模型进行优化,基于优化后的所述分类模型对预测样本进行分类。
2.根据权利要求1所述的基于半监督学习的分类方法,其特征在于,所述训练数据包括已标注数据和未标注数据;所述分类模型预先利用所述已标注数据训练获得。
3.根据权利要求2所述的基于半监督学习的分类方法,其特征在于,基于更新后的所述分类模型对未标注数据进行伪标签预测的方法为:
将所述训练数据导入所述分类模型中进行数据分类以区分出所述已标注数据以及所述未标注数据,为所述未标注数据生成对应的伪标签。
4.根据权利要求1所述的基于半监督学习的分类方法,其特征在于,对所述未标注数据进行伪标签预测过程中,还包括:
将分类置信度小于预设阈值的预测结果进行置零处理。
5.根据权利要求1所述的基于半监督学习的分类方法,其特征在于,对所述伪标签数据的监督损失函数进行正则化处理的方法包括:
利用交叉熵算法L=-∑ipilog pi对所述伪标签数据的监督损失函数进行正则化处理;其中,pi表示训练数据样本i的最大置信度。
9.一种电子设备,其特征在于,其包括处理器、存储器及存储于所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现权利要求1~8任一所述的基于半监督学习的分类方法。
10.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序被执行时实现权利要求1~8任一所述的基于半监督学习的分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210135599.4A CN114492843A (zh) | 2022-02-14 | 2022-02-14 | 一种基于半监督学习的分类方法、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210135599.4A CN114492843A (zh) | 2022-02-14 | 2022-02-14 | 一种基于半监督学习的分类方法、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114492843A true CN114492843A (zh) | 2022-05-13 |
Family
ID=81479913
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210135599.4A Pending CN114492843A (zh) | 2022-02-14 | 2022-02-14 | 一种基于半监督学习的分类方法、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492843A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114925773A (zh) * | 2022-05-30 | 2022-08-19 | 阿里巴巴(中国)有限公司 | 模型训练方法、装置、电子设备以及存储介质 |
CN115272777A (zh) * | 2022-09-26 | 2022-11-01 | 山东大学 | 面向输电场景的半监督图像解析方法 |
CN115482436A (zh) * | 2022-09-21 | 2022-12-16 | 北京百度网讯科技有限公司 | 图像筛选模型的训练方法、装置以及图像筛选方法 |
-
2022
- 2022-02-14 CN CN202210135599.4A patent/CN114492843A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114925773A (zh) * | 2022-05-30 | 2022-08-19 | 阿里巴巴(中国)有限公司 | 模型训练方法、装置、电子设备以及存储介质 |
CN115482436A (zh) * | 2022-09-21 | 2022-12-16 | 北京百度网讯科技有限公司 | 图像筛选模型的训练方法、装置以及图像筛选方法 |
CN115272777A (zh) * | 2022-09-26 | 2022-11-01 | 山东大学 | 面向输电场景的半监督图像解析方法 |
CN115272777B (zh) * | 2022-09-26 | 2022-12-23 | 山东大学 | 面向输电场景的半监督图像解析方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP3767536A1 (en) | Latent code for unsupervised domain adaptation | |
CN114492843A (zh) | 一种基于半监督学习的分类方法、设备及存储介质 | |
CN111275175B (zh) | 神经网络训练方法、装置、图像分类方法、设备和介质 | |
CN111666427B (zh) | 一种实体关系联合抽取方法、装置、设备及介质 | |
CN112883714B (zh) | 基于依赖图卷积和迁移学习的absc任务句法约束方法 | |
CN111126576B (zh) | 一种深度学习的训练方法 | |
CN112069831A (zh) | 基于bert模型和增强混合神经网络的不实信息检测方法 | |
CN113312447A (zh) | 基于概率标签估计的半监督日志异常检测方法 | |
JP2022531620A (ja) | Aiによるディープラーニングネットワークを学習させる方法及びこれを利用した学習装置 | |
CN111666406A (zh) | 基于自注意力的单词和标签联合的短文本分类预测方法 | |
CN113434683B (zh) | 文本分类方法、装置、介质及电子设备 | |
US11948078B2 (en) | Joint representation learning from images and text | |
CN113139051B (zh) | 文本分类模型训练方法、文本分类方法、设备和介质 | |
CN114328942A (zh) | 关系抽取方法、装置、设备、存储介质和计算机程序产品 | |
CN112906398B (zh) | 句子语义匹配方法、系统、存储介质和电子设备 | |
CN114048314A (zh) | 一种自然语言隐写分析方法 | |
US20220253630A1 (en) | Optimized policy-based active learning for content detection | |
CN114266252A (zh) | 命名实体识别方法、装置、设备及存储介质 | |
CN111460224B (zh) | 评论数据的质量标注方法、装置、设备及存储介质 | |
CN116579345A (zh) | 命名实体识别模型的训练方法、命名实体识别方法及装置 | |
CN117218408A (zh) | 基于因果纠偏学习的开放世界目标检测方法及装置 | |
US20220392205A1 (en) | Method for training image recognition model based on semantic enhancement | |
CN116681961A (zh) | 基于半监督方法和噪声处理的弱监督目标检测方法 | |
CN111177381A (zh) | 基于语境向量反馈的槽填充和意图检测联合建模方法 | |
CN113033817B (zh) | 基于隐空间的ood检测方法、装置、服务器及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |