CN114329124A - 基于梯度重优化的半监督小样本分类方法 - Google Patents

基于梯度重优化的半监督小样本分类方法 Download PDF

Info

Publication number
CN114329124A
CN114329124A CN202111547919.9A CN202111547919A CN114329124A CN 114329124 A CN114329124 A CN 114329124A CN 202111547919 A CN202111547919 A CN 202111547919A CN 114329124 A CN114329124 A CN 114329124A
Authority
CN
China
Prior art keywords
gradient
data
optimization
classification method
small sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111547919.9A
Other languages
English (en)
Inventor
吴泽彬
陈华生
徐洋
刘倩
张毅
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University of Science and Technology
Original Assignee
Nanjing University of Science and Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University of Science and Technology filed Critical Nanjing University of Science and Technology
Priority to CN202111547919.9A priority Critical patent/CN114329124A/zh
Publication of CN114329124A publication Critical patent/CN114329124A/zh
Pending legal-status Critical Current

Links

Images

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于梯度重优化的半监督小样本分类方法,包括:根据类别信息将数据集划分为元训练阶段和元测试阶段,每个阶段有若干个任务,每个任务分为支持集和查询集,支持集包括有标签数据和无标签数据,计算每个支持集中有标签训练样本的梯度信息,进行梯度优化得到粗分类器;利用粗分类器预测无标签数据的伪标签,得到支持集的全部标签;对支持集数据进行梯度重优化得到精分类器,再测试得到查询集的结果。本发明充分利用少量的有标签数据和无标签数据的梯度信息,提高算法的准确度,并且在计算样本梯度信息和梯度重优化过程中使用元任务的一阶近似值来代替二阶导信息,从而提升分类的速度。

Description

基于梯度重优化的半监督小样本分类方法
技术领域
本发明涉及图像处理技术领域,具体涉及一种基于梯度重优化的半监督小样本分类方法。
背景技术
近年来,计算机视觉在各个领域得到了广泛的应用。而小样本学习已经成为计算机视觉领域中非常重要的前沿问题,在医疗图像等数据采集难度较大的领域具有十分广阔的应用前景。小样本学习问题存在两个难点:标记样本极少,多数类别少于10个;分类器需要适应新的类别,小样本学习问题的分类器必须调整以适应新的类,传统的方法是在新的数据上重新训练新的模型,但是由于样本太少,往往会导致过拟合。为了解决这些问题,有研究者提出了元学习的概念,它主要分成了元测试阶段和元训练阶段,元训练阶段使用的是有大量标签的基础数据,元测试阶段使用的是有少量标签的新类数据。在每个阶段将数据分成了很多个任务,每个任务上有支持集和查询集,分别对应传统深度学习的训练集和测试集。它的思路是寻找模型的参数和超参数,这样一来在不会让小样本过拟合的条件下可以很容易的适应新的任务,即在元测试阶段能够达到很好的效果。
目前,已有一些学者对小样本学习进行了研究,主要的方法可以分为三类:第一类是基于模型的方法,该方法主要通过设计模型的结构,使用少量样本来更新参数从而直接建立输入和预测值的映射函数。第二类是基于度量学习的方法,它的主要思想是将任务中的样本映射到一个特征空间内,通过最近邻的思想来完成分类。最后一类是基于梯度重优化的方法,通过梯度下降找到一组最优的参数,从而能够在新任务上经过少量的更新就能达到很好的效果。
然而上述这些方法都是基于监督学习的,在现实生活中还有大量的无标签数据可以利用,如果直接应用在小样本的算法中仍然存在如下的几个问题:1)元训练阶段只有包含很少的有标签数据,如何构建一个模型,以在元测试阶段获得更好的效果是有待解决的问题之一;2)虽然目前对无标签数据处理的方法在图像识别中取得了较高的识别率,但是这些都是基于一定量样本的情况下的,在元学习条件下依然没有一个较好的处理无标签数据的方法。
通过以上描述,如何在元学习的情况下充分利用无标签数据,并且进一步提高检测准确率是亟待解决的问题。
发明内容
本发明的目的在于提供一种基于梯度重优化的半监督小样本分类方法,充分利用无标签数据的信息,来进一步提高网络对当前任务的适应度,并且使用了一种新的可用于小样本学习的半监督方法,从而在查询集上能够获得更高的准确度,具有良好的应用前景。
为了达到上述目的,发明采用的技术方案是:一种基于梯度重优化的半监督小样本分类方法,包括以下步骤:
步骤(A),对有标签数据的特征进行建模,得到一个初步模型,然后结合该模型以及伪标签生成算法得到无标签数据的伪标签,从而得到支持集的全部标签;
步骤(B),将新的支持集输入梯度重优化模块中,对支持集数据的特征进行建模,得到最终模型,再测试得到查询集的结果。
进一步的,所述步骤(A)具体实现如下:
(A1)将有标签数据输入网络结构中,计算样本的梯度信息,再利用样本的梯度信息更新网络参数,得到一个初步模型;
(A2)利用初步模型以及伪标签生成算法得到无标签数据的伪标签。
进一步的,步骤(A1)中的样本梯度信息计算公式具体为:
Figure BDA0003416252410000021
其中,x(j),y(j)分别表示模块的输入数据以及其对应的标签,fφ(x(j))表示输入样本的预测值,
Figure BDA0003416252410000022
表示第i个任务,
Figure BDA0003416252410000023
表示对φ求梯度。
进一步的,步骤(A1)中的更新网络参数具体公式为:
Figure BDA0003416252410000024
其中,
Figure BDA0003416252410000025
表示粗分类器的网络参数,
Figure BDA0003416252410000026
表示有标签数据的损失函数,φ表示网络的初始参数。
进一步的,所述步骤(A2)中的伪标签生成算法,其步骤如下:
首先对无标签数据进行两次数据增强,然后利用得到的初步模型对增强后的数据进行预测,从而得到最终数据增强后的伪标签。
进一步的,所述步骤(B)具体实现如下,
(B1)将支持集输入到梯度重优化模块中,计算样本的梯度信息,利用梯度信息再次更新网络参数,得到最终模型;
(B2)利用最终模型计算查询集的分类结果;
(B3)如果当前处于元训练阶段,则需要更新初始化参数,直到当前处于元测试阶段,分类结果即所求结果。
进一步的,所述步骤(B1)中再次更新网络参数的公式具体为:
Figure BDA0003416252410000031
其中,θ*i表示精分类器的网络参数,
Figure BDA0003416252410000032
表示支持集的损失函数,
Figure BDA0003416252410000033
表示粗略分类器的网络参数。
进一步的,所述步骤(B3)中更新初始化参数的公式具体为:
Figure BDA0003416252410000034
其中,φ表示网络的初始化参数;(η,ε)表示超参数;
Figure BDA0003416252410000035
分别表示第n个任务时,有标签数据以及支持集对应的训练损失函数;
Figure BDA0003416252410000036
θ*n分别表示第n个任务时,粗略分类器以及精细分类器对应的网络参数。
一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述的基于梯度重优化的半监督小样本分类方法。
一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述的基于梯度重优化的半监督小样本分类方法。
本发明与现有技术相比,其显著优点在于:1)本发明提出了一种新的基于梯度重优化的半监督小样本分类框架,以少量有标签数据进行预训练,通过对梯度信息的合理使用,得到了一个粗略分类器,然后在此基础上利用支持集数据得到一个精细分类器,可以使得分类更加精准;2)该算法能够在只有少量样本的情况下帮助无标签数据生成伪标签,从而能够达到扩充训练样本的目的;3)整个框架在计算样本的梯度信息时,使用一阶导近似值替代了二阶导数,有效降低了算法的时间复杂度。
附图说明
图1是本发明的整体流程图。
具体实施方式
以下结合说明书附图,对本发明做详细说明。
如图1所示,一种基于梯度重优化的半监督小样本分类方法,具体步骤如下:
步骤(A),对数据集进行处理,分成若干个任务,每个任务包括支持集和查询集,支持集中包括有标签数据和无标签数据,具体步骤如下:
(A1)在数据集中抽取一定的类别用于元训练阶段,剩余的类别用于元测试阶段;
(A2)对于M-way K-shot问题,分别在元训练数据集与元测试数据集中抽取M个种类;
(A3)每个种类抽取K张有标签样本,以及u张无标签样本作为支持集,最后抽取v张样本作为查询集;
(A4)将支持集和查询集合成为一个任务。
(A5)重复进行上述步骤(A2)~步骤(A4),将用于元训练阶段与元测试阶段的数据集全都划分成任务形式;
步骤(B),计算每个支持集中有标签训练样本的梯度信息,得到粗分类器,再利用粗分类器预测无标签数据的伪标签,得到支持集的全部标签,具体步骤如下:
(B1)将有标签数据输入到网络中,计算样本的梯度信息
Figure BDA0003416252410000041
其中
Figure BDA0003416252410000042
x(j),y(j)表示有标签数据以及其对应的标签;
(B2)利用样本的梯度信息更新网络参数,从而得到一个粗分类器,公式为:
Figure BDA0003416252410000043
Figure BDA0003416252410000044
表示粗分类器的网络参数,
Figure BDA0003416252410000045
表示有标签数据的损失函数,φ表示网络的初始参数,η表示超参数;
(B3)利用得到的粗分类器以及伪标签生成算法得到无标签数据的伪标签。伪标签生成算法是指:首先对无标签数据进行两次数据增强,然后利用得到的粗略分类器对增强后的数据进行预测,得到特征图,再经过softmax操作后取平均值,最后利用Sharpen锐化算法得到最终数据增强后的伪标签;
(B4)利用一致性正则化的原则,得到支持集的全部标签;
步骤(C),对支持集进行梯度重优化,得到一个精分类器,再测试得到查询集的结果,具体步骤如下:
(C1)利用网络训练支持集,计算样本的梯度信息
Figure BDA0003416252410000051
其中
Figure BDA0003416252410000052
Figure BDA0003416252410000053
x(j),y(j)表示支持集数据以及其对应的标签;
(C2)利用样本的梯度信息再次更新网络,公式为:
Figure BDA0003416252410000054
θ*i表示精分类器的网络参数,
Figure BDA0003416252410000055
表示支持集的损失函数,
Figure BDA0003416252410000056
表示粗略分类器的网络参数,ε表示超参数;
(C3)利用精分类器计算查询集的分类结果。
(C4)如果当前处于元训练阶段,则需要根据以下公式来更新初始化参数:
Figure BDA0003416252410000057
其中,φ表示网络的初始化参数;(η,ε)表示超参数;N表示一共有N个任务;
Figure BDA0003416252410000058
分别表示第n个任务时,有标签数据以及支持集对应的训练损失函数;
Figure BDA0003416252410000059
θ*n分别表示第n个任务时,粗略分类器以及精细分类器对应的网络参数。
综上所述,本发明的基于梯度重优化的半监督小样本分类方法,充分利用无标签数据的梯度信息来使得网络对当前任务有更好的适应度,并且使用了一种新的可用于小样本学习的半监督方法,从而能够达到更高的精度。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。

Claims (10)

1.一种基于梯度重优化的半监督小样本分类方法,其特征在于,包括以下步骤:
步骤(A),对有标签数据的特征进行建模,得到一个初步模型,然后结合该模型以及伪标签生成算法得到无标签数据的伪标签,从而得到支持集的全部标签;
步骤(B),将新的支持集输入梯度重优化模块中,对支持集数据的特征进行建模,得到最终模型,再测试得到查询集的结果。
2.根据权利要求1所述的基于梯度重优化的半监督小样本分类方法,其特征在于,所述步骤(A)具体实现如下:
(A1)将有标签数据输入网络结构中,计算样本的梯度信息,再利用样本的梯度信息更新网络参数,得到一个初步模型;
(A2)利用初步模型以及伪标签生成算法得到无标签数据的伪标签。
3.根据权利要求2所述的基于梯度重优化的半监督小样本分类方法,其特征在于,步骤(A1)中的样本梯度信息计算公式具体为:
Figure FDA0003416252400000011
其中,x(j),y(j)分别表示模块的输入数据以及其对应的标签,fφ(x(j))表示输入样本的预测值,
Figure FDA0003416252400000012
表示第i个任务,
Figure FDA0003416252400000013
表示对φ求梯度。
4.根据权利要求3所述的基于梯度重优化的半监督小样本分类方法,其特征在于,步骤(A1)中的更新网络参数具体公式为:
Figure FDA0003416252400000014
其中,
Figure FDA0003416252400000015
表示粗分类器的网络参数,
Figure FDA0003416252400000016
表示有标签数据的损失函数,φ表示网络的初始参数,η表示超参数。
5.根据权利要求4所述的基于梯度重优化的半监督小样本分类方法,其特征在于,所述步骤(A2)中的伪标签生成算法,其步骤如下:
首先对无标签数据进行两次数据增强,然后利用得到的初步模型对增强后的数据进行预测,从而得到最终数据增强后的伪标签。
6.根据权利要求1所述的基于梯度重优化的半监督小样本分类方法,其特征在于:所述步骤(B)具体实现如下:
(B1)将支持集输入到梯度重优化模块中,计算样本的梯度信息,利用梯度信息再次更新网络参数,得到最终模型;
(B2)利用最终模型计算查询集的分类结果;
(B3)如果当前处于元训练阶段,则需要更新初始化参数,直到当前处于元测试阶段,分类结果即所求结果。
7.根据权利要求6所述的基于梯度重优化的半监督小样本分类方法,其特征在于,所述步骤(B1)中再次更新网络参数的公式具体为:
Figure FDA0003416252400000021
其中,θ*i表示精分类器的网络参数,
Figure FDA0003416252400000022
表示支持集的损失函数,
Figure FDA0003416252400000023
表示粗略分类器的网络参数,ε表示超参数。
8.根据权利要求7所述的基于梯度重优化的半监督小样本分类方法,其特征在于,所述步骤(B3)中更新初始化参数的公式具体为:
Figure FDA0003416252400000024
其中,φ表示网络的初始化参数;(η,ε)表示超参数;
Figure FDA0003416252400000025
分别表示第n个任务时,有标签数据以及支持集对应的训练损失函数;
Figure FDA0003416252400000026
θ*n分别表示第n个任务时,粗略分类器以及精细分类器对应的网络参数。
9.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1-8中任一所述的基于梯度重优化的半监督小样本分类方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1-8中任一所述的基于梯度重优化的半监督小样本分类方法。
CN202111547919.9A 2021-12-16 2021-12-16 基于梯度重优化的半监督小样本分类方法 Pending CN114329124A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111547919.9A CN114329124A (zh) 2021-12-16 2021-12-16 基于梯度重优化的半监督小样本分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111547919.9A CN114329124A (zh) 2021-12-16 2021-12-16 基于梯度重优化的半监督小样本分类方法

Publications (1)

Publication Number Publication Date
CN114329124A true CN114329124A (zh) 2022-04-12

Family

ID=81052243

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111547919.9A Pending CN114329124A (zh) 2021-12-16 2021-12-16 基于梯度重优化的半监督小样本分类方法

Country Status (1)

Country Link
CN (1) CN114329124A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114782752A (zh) * 2022-05-06 2022-07-22 兰州理工大学 基于自训练的小样本图像集成分类方法及装置
CN116563638A (zh) * 2023-05-19 2023-08-08 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114782752A (zh) * 2022-05-06 2022-07-22 兰州理工大学 基于自训练的小样本图像集成分类方法及装置
CN114782752B (zh) * 2022-05-06 2023-09-05 兰州理工大学 基于自训练的小样本图像集成分类方法及装置
CN116563638A (zh) * 2023-05-19 2023-08-08 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统
CN116563638B (zh) * 2023-05-19 2023-12-05 广东石油化工学院 一种基于情景记忆的图像分类模型优化方法和系统

Similar Documents

Publication Publication Date Title
Zhang et al. Intelligent fault diagnosis of machines with small & imbalanced data: A state-of-the-art review and possible extensions
US11960568B2 (en) Model and method for multi-source domain adaptation by aligning partial features
Shen et al. Wind speed prediction of unmanned sailboat based on CNN and LSTM hybrid neural network
CN110009030B (zh) 基于stacking元学习策略的污水处理故障诊断方法
CN114841257B (zh) 一种基于自监督对比约束下的小样本目标检测方法
CN114329124A (zh) 基于梯度重优化的半监督小样本分类方法
CN116644755B (zh) 基于多任务学习的少样本命名实体识别方法、装置及介质
CN112270345B (zh) 基于自监督字典学习的聚类算法
Zhang et al. Quantifying the knowledge in a DNN to explain knowledge distillation for classification
WO2023124342A1 (zh) 一种针对图像分类的神经网络结构低成本自动搜索方法
CN110598022A (zh) 一种基于鲁棒深度哈希网络的图像检索系统与方法
CN114609994A (zh) 基于多粒度正则化重平衡增量学习的故障诊断方法及装置
CN114255371A (zh) 一种基于组件监督网络的小样本图像分类方法
CN113920363B (zh) 一种基于轻量级深度学习网络的文物分类方法
Weber et al. Automated labeling of electron microscopy images using deep learning
CN117669656A (zh) 基于TCN-Semi PN的直流微电网稳定性实时监测方法及装置
Li Parallel two-class 3D-CNN classifiers for video classification
CN116168231A (zh) 基于增量式网络和动量对比学习的自监督图像分类方法
CN114881172A (zh) 一种基于加权词向量和神经网络的软件漏洞自动分类方法
CN114036947A (zh) 一种半监督学习的小样本文本分类方法和系统
CN113987170A (zh) 基于卷积神经网络的多标签文本分类方法
CN113011163A (zh) 基于深度学习模型的复合文本多分类方法及系统
CN114295967A (zh) 一种基于迁移神经网络的模拟电路故障诊断方法
CN112926670A (zh) 一种基于迁移学习的垃圾分类系统及方法
Jin Handwritten digit recognition based on classical machine learning methods

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination