CN113469283A - 一种图像分类方法、图像分类模型的训练方法及设备 - Google Patents

一种图像分类方法、图像分类模型的训练方法及设备 Download PDF

Info

Publication number
CN113469283A
CN113469283A CN202110838884.8A CN202110838884A CN113469283A CN 113469283 A CN113469283 A CN 113469283A CN 202110838884 A CN202110838884 A CN 202110838884A CN 113469283 A CN113469283 A CN 113469283A
Authority
CN
China
Prior art keywords
vector
image
data
model
image classification
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Withdrawn
Application number
CN202110838884.8A
Other languages
English (en)
Inventor
张凯
王瑞丰
丁冬睿
杨光远
逯天斌
王潇涵
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shandong Huanke Information Technology Co ltd
Original Assignee
Shandong Liju Robot Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shandong Liju Robot Technology Co ltd filed Critical Shandong Liju Robot Technology Co ltd
Priority to CN202110838884.8A priority Critical patent/CN113469283A/zh
Publication of CN113469283A publication Critical patent/CN113469283A/zh
Withdrawn legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/213Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种图像分类方法、图像分类模型的训练方法及设备。所述图像分类方法包括:将待分类图像切分成多个patch,通过线性层对每patch向量进行降维,得到第一序列向量;在第一序列向量的首部嵌入一个可变向量,得到第二序列向量;初始化第二序列向量的位置编码向量,将初始化后的位置编码向量嵌入到第二序列向量中,得到输入向量;将输入向量输入到Transformer模型的编码器,得到编码向量;取编码向量的首部的可变向量作为待分类图像的特征向量;将特征向量输入到Transformer模型的分类器,得到待分类图像的预测类别概率。本发明提高了在图像分类的分类效果。

Description

一种图像分类方法、图像分类模型的训练方法及设备
技术领域
本发明实施例涉及图像分类技术领域,特别涉及一种图像分类方法、图像分类模型的训练方法及设备。
背景技术
在图像识别与分类领域中,机器学习范围内的深度学习是一种有效的方法,产生了很多优秀的算法和网络,包括常见的卷积神经网络(Convolutional Neural Network,简称“CNN”)、循环神经网络(Recurrent Neural Network)、生成对抗网络(GenerativeAdversarial Networks)、深度强化学习(Reinforcement learning)四大主流网络结构。
但由于某些应用图像领域(例如医学图像领域)样本数据繁杂、需要专业人员才能进行标注,导致样本数据标注代价巨大,无法轻易获取大量的标注数据。图像数据集中数据集标签少、特殊分类任务效果差。
发明内容
本发明提供一种图像分类方法、图像分类模型的训练方法及设备,以解决现有技术中存在的上述问题。
第一方面,本发明实施例提供了一种图像分类方法,该方法包括:
S10:将待分类图像切分成多个patch,生成每个patch对应的patch向量;通过线性层对每patch向量进行降维,将多个降维后的patch向量进行拼接,得到第一序列向量;在所述第一序列向量的首部嵌入一个可变向量,得到第二序列向量,其中,所述可变向量与每个降维后的patch向量尺寸相同,且所述可变向量对应所述多个patch中最能代表所述待分类图像的特征的patch;
S20:初始化所述第二序列向量的位置编码向量,其中,所述位置编码向量中包含所述多个patch在所述待分类图像中的位置信息;将所述初始化后的位置编码向量嵌入到所述第二序列向量中,得到输入向量;
S30:将所述输入向量输入到Transformer模型的编码器,得到编码向量;取所述编码向量的首部的可变向量作为所述待分类图像的特征向量;将所述特征向量输入到所述Transformer模型的分类器,得到所述待分类图像的预测类别概率。
在一实施例中,S10包括:
S110:将尺寸为H×W×C的所述待分类图像切分成m个尺寸为P×P×C的patch,其中,H和W分别表示所述待分类图像的高度和宽度,C表示所述待分类图像的通道数,P表示每个patch的宽度;
S120:将每个patch展开成一个patch向量,通过所述线性层将每个patch向量降至D维,生成所述第一序列向量X1=[x1;x2;…;xm],其中,xi表示第i个patch的patch向量,i=1、2…m,
Figure BDA0003178208020000021
表示维度为D的向量域;
S130:在X1的首部嵌入所述可变向量xclass,得到所述第二序列向量X2=[xclass;x1;x2;…;xm],其中,
Figure BDA0003178208020000022
在一实施例中,S20包括:
S21:初始化xclass的位置编码向量P0,初始化xi的位置编码向量Pi,其中,所述第二序列向量的位置编码向量P=[P0;P1;P2;…;Pm],
Figure BDA0003178208020000023
j=0、1、2…m,Pj中包含Pj对应的patch在所述待分类图像中的位置信息;
S22:将P嵌入到X2中,得到所述输入向量X[xclass+P0;x1+P1;x2+P2;…;xm+Pm]。
在一实施例中,所述Transformer模型包含所述编码器和所述分类器,不包含解码器,其中,
所述编码器包括串行排列的的多头自注意力(Multiheaded Self-Attention,MSA)和第一多层感知器(Multilayer Perceptron,MLP),所述MSA的输出为所述第一MLP的输入;所述MSA与所述第一MLP的内部均采用残差连接方式;所述MSA和所述第一MLP之前均连接有一个归一化层(Layernorm,LN),待处理信号经过一个LN后再输入所述MSA或所述第一MLP进行处理;
所述分类器包括第二MLP。
第二方面,本发明实施例还提供了一种图像分类模型的训练方法。该方法包括:
S01:获取一个训练数据集D,其中,所述训练数据集中包括有标签数据集Dl和无标签数据集Du,每个训练数据为一幅训练图像,每个有标签数据dl的标签为dl的真实类别yl
S02:对每个有标签数据dl进行一次随机数据增强,得到增强后的有标签数据集
Figure BDA0003178208020000031
对每个无标签数据du进行K次随机数据增强,得到K个增强后的无标签数据集
Figure BDA0003178208020000032
k=1,...,K,将所有du的K个
Figure BDA0003178208020000033
的并集记为
Figure BDA0003178208020000034
将每个无标签数据du的K个
Figure BDA0003178208020000035
分别输入本发明实施例所述的图像分类方法对应的图像分类模型,得到K个预测类别,对所述K个预测类别取平均,将得到的平均值作为du的伪标签;
S03:将
Figure BDA0003178208020000041
输入所述图像分类模型,得到
Figure BDA0003178208020000042
中的每个数据
Figure BDA0003178208020000043
的预测类别概率;利用
Figure BDA0003178208020000044
中所有数据的预测类别概率和真实类别,计算交叉熵损失;
S04:将
Figure BDA0003178208020000045
输入所述图像分类模型,得到
Figure BDA0003178208020000046
中的每个数据
Figure BDA0003178208020000047
的预测类别概率,将所述预测类别概率中的最大概率值对应的类别作为
Figure BDA0003178208020000048
的预测类别;利用
Figure BDA0003178208020000049
中的所有数据的预测类别和伪标签,计算一致性损失;
S05:将所述交叉熵损失和所述一致性损失的加权和作为本轮训练的总损失,对所述图像分类模型中的网络参数进行训练,其中,所述网络参数包括:所述线性层的参数、所述编码器的参数和所述分类器的参数;
S06:返回S01,直到满足设定的终止条件,保存训练过程中总损失最小时的网络参数,将对应的图像分类模型作为训练好的图像分类模型。
在一实施例中,所述随机数据增强包括图像位移、改变图像的亮度、改变图像的对比度和改变图像的饱和度中的至少一种方式的随机组合,其中,图像的位移、图像的亮度、图像的对比度和图像的饱和度的改变值均为预设范围内的随机数。
在一实施例中,S03中,利用
Figure BDA00031782080200000410
中所有数据的预测类别概率和真实类别,计算交叉熵损失,包括:
根据公式(1),利用
Figure BDA00031782080200000411
中所有数据的预测类别概率和真实类别,计算所述交叉熵损失Lossl
Figure BDA00031782080200000412
其中,n表示
Figure BDA00031782080200000413
中的数据
Figure BDA00031782080200000414
的个数,
Figure BDA00031782080200000415
表示
Figure BDA00031782080200000416
的真实类别,pl,a表示所述图像分类模型预测得到的
Figure BDA00031782080200000417
的类别为
Figure BDA00031782080200000418
的概率。
在一实施例中,S04中,利用
Figure BDA0003178208020000051
中的所有数据的预测类别和伪标签,计算一致性损失,包括:
根据公式(2),利用
Figure BDA0003178208020000052
中的所有数据的预测类别和伪标签,计算一致性损失Lossu
Figure BDA0003178208020000053
其中,M表示
Figure BDA0003178208020000054
中的数据
Figure BDA0003178208020000055
的个数,ω(·)表示坡度函数,t表示全局迭代次数,yu,k,b表示
Figure BDA0003178208020000056
的预测类别,
Figure BDA0003178208020000057
表示
Figure BDA0003178208020000058
的伪标签。
在一实施例中,在S01之前,所述训练方法还包括:
S011:对所述图像分类模型进行初始化,利用大数据集对初始化的模型进行预训练,得到源模型;
S012:复制所述源模型的中的Transformer模型的编码器的参数,并初始化所述Transformer的分类器的参数,得到中间模型;
在S02中,将每个无标签数据du的K个
Figure BDA0003178208020000059
分别输入本发明实施例所述的图像分类方法对应的图像分类模型,包括:
将所述K个
Figure BDA00031782080200000510
分别输入所述中间模型。
第三方面,本发明实施例还提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述实施例所述的图像分类方法,或实现所述实施例所述的图像分类模型的训练方法。
本发明提出了一种基于Transformer的半监督网络的图像分类方法与图像分类模型的训练方法。本发明具有如下有益效果:
1.本发明针对图像分类领域的特殊性,利用注意力机制思想,将Transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
2.本发明通过伪标签预测和Consistency Regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
3.本发明设计了适用于图像数据的数据结构,在Transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了Transformer模型及自注意力机制进行图像分类中的应用;
4.本发明采用基于图像的Transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
5.本发明将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
附图说明
图1是本发明实施例提供的一种图像分类方法的流程图。
图2是本发明实施例提供的一种图像分类模型的训练方法的流程图。
图3是本发明实施例提供的一种图像分类模型的整个训练过程的流程图。
图4是本发明实施例提供的另一种图像分类方法的流程图。
图5为本发明实施例提供的一种计算机设备的结构示意图。
具体实施方式
下面结合附图与实施例对本发明做进一步说明。在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
应该指出,以下详细说明都是示例性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
本发明实施例提出一种基于transformer的半监督网络的图像分类方法与相应的模型训练方法。
由于某些应用图像领域(例如医学图像领域)样本数据繁杂、需要专业人员才能进行标注,导致样本数据标注代价巨大,无法轻易获取大量的标注数据。图像数据集中数据集标签少、特殊分类任务效果差。基于现有的少量标记样本数据和大量未标记的样本数据进行深度学习的半监督学习算法(Semi-Supervised Learning,简称为“SSL”)可以利用仅仅一小部分的有标注数据就可以完成网络模型的训练。
目前计算机视觉领域的Transformer模型已经比肩甚至超越了传统的卷积神经网络,达到了SOTA(State Of The Art,指在该项研究任务中,目前最好/最先进的技术)的水平。Transformer模型具有CNN模型不具有的捕捉长期数据之间的依赖信息,容易获得全局图像的有效信息,是一种性能优异提取特征的模型。
实施例一
本实施例提出一种图像分类方法。图1是本发明实施例提供的一种图像分类方法的流程图。如图1所示,该方法包括S10-S30。
S10:将待分类图像切分成多个patch,生成每个patch对应的patch向量;通过线性层对每patch向量进行降维,将多个降维后的patch向量进行拼接,得到第一序列向量;在所述第一序列向量的首部嵌入一个可变向量,得到第二序列向量,其中,所述可变向量与每个降维后的patch向量尺寸相同,且所述可变向量对应所述多个patch中最能代表所述待分类图像的特征的patch。
可选地,将待分类图像切分成多个patch,并且铺平成一个序列,接着通过学习好的线性投影进行降维操作,减少维度,得到一个序列向量。最后在序列向量的首部嵌入一个同patch大小的向量,这个向量初始化是随机的,在预测阶段中是可变的,最终得到的patch向量是所有patch中最具有分类代表性的一个patch对应的向量。这个向量是专门用来做解码功能的,它与Transformer的编码器对应。
S20:初始化所述第二序列向量的位置编码向量,其中,所述位置编码向量中包含所述多个patch在所述待分类图像中的位置信息;将所述初始化后的位置编码向量嵌入到所述第二序列向量中,得到输入向量。
S30:将所述输入向量输入到Transformer模型的编码器,得到编码向量;取所述编码向量的首部的可变向量作为所述待分类图像的特征向量;将所述特征向量输入到所述Transformer模型的分类器,得到所述待分类图像的预测类别概率。
可选地,将处理后的图像数据向量输入Transformer模型的编码器内得到编码后的结果,其中,编码后的结果同样为向量形式,同输入向量的维度一致。再取位置零处的patch向量作为整个图片的特征向量,并且输入到Transformer模型的分类器,最后得到预测类别概率。“位置零处”的patch向量是指第二序列向量的首部位置的嵌入向量。
在一实施例中,S10包括S110-S130。
S110:将尺寸为H×W×C的所述待分类图像切分成m个尺寸为P×P×C的patch,其中,H和W分别表示所述待分类图像的高度和宽度,C表示所述待分类图像的通道数,P表示每个patch的宽度。
S120:将每个patch展开成一个patch向量,通过所述线性层将每个patch向量降至D维,生成所述第一序列向量X1=[x1;x2;…;xm],其中,xi表示第i个patch的patch向量,i=1、2…m,
Figure BDA0003178208020000091
表示维度为D的向量域。
S130:在X1的首部嵌入所述可变向量xclass,得到所述第二序列向量X2=[xclass;x1;x2;…;xm],其中,
Figure BDA0003178208020000101
由于Transformer模块需要连续化输入,因此需要把输入图像进行切分。可选地,将图像切分成等大的正方形patch,使得一张原始图片大小从H×W×C切分成m个P×P×C大小的patch。其中H、W分别是图像的高度和宽度,C表示图像通道数,P表示正方形patch的宽度,则m=WH/P2。可以根据具实际情况W和H的大小来选择P。可选地,将P选取为2的整数次幂。
接着用一个线性层进行数据降至D维,减少无用特征输入。最后在输入向量的起始位置嵌入一个可变的特征向量,用于输出进行分类预测的依据X,即X[xclass;x1;x2;…;xN],
Figure BDA0003178208020000102
需要说明的是,这里的“连续化输入”是指Transformer的输入需要满足连续化要求。例如,输入一句话时,其中的每个单词都具有关系性、连续性。在本发明中,将Transformer模型用于图像任务,因此需要把一幅图像切分成多个patch后排列成一个序列,也就如同“输入一句话”一样,多个patch之间具有关联性和连续性。
在一实施例中,S20包括:S21-S22。
S21:初始化xclass的位置编码向量P0,初始化xi的位置编码向量Pi,其中,所述第二序列向量的位置编码向量P=[P0;P1;P2;…;Pm],
Figure BDA0003178208020000103
j=0、1、2…m,Pj中包含Pj对应的patch在所述待分类图像中的位置信息。
S22:将P嵌入到X2中,得到所述输入向量X[xclass+P0;x1+P1;x2+P2;…;xm+Pm]。
经过S10的分块操作,必然会丢失图像原本的位置信息,因此需要在输入向量中增加可学习的位置编码向量。弥补了丢失的位置信息,同时可学习的设定保证了获得价值最高的位置信息。patch的嵌入向量(即patch向量)和patch的位置编码向量一同作为Transformer模型的编码器的输入,此过程可表示为X=[xclass+P0;x1+P1;x2+P2;…;xN+PN]。
需要说明的是,这里的“可学习的设定”在序列首部嵌入的可变向量在网络学习训练过程中会一直变化,根据注意力机制着重哪个patch,那么这个向量就会更新为这个patch对应的特征向量。
在一实施例中,所述Transformer模型包含所述编码器和所述分类器,不包含解码器。
所述编码器包括串行排列的的MSA和第一MLP,所述MSA的输出为所述第一MLP的输入。所述MSA与所述第一MLP的内部均采用残差连接方式。所述MSA和所述第一MLP之前均连接有一个LN,待处理信号经过一个LN后再输入所述MSA或所述第一MLP进行处理。所述分类器包括第二MLP。
可选地,将S20得到的结果向量输入Transformer模型中获得类别概率。Transformer模型包括两部分:一是编码器,另一个是分类器,编码器与分类器串联。编码器包括MSA和第一MLP,MSA和第一MLP均是残差连接方式,MSA的输出为第一MLP的输入。MSA和第一MLP都经过LN进行图像通道归一化操作。分类器包括第二MLP。
可选地,编码器部分负责提取图像的全局信息以及对任务有帮助的图像区域,而分类器负责根据图像特征进行分类获得其对应每类的概率值。整个Transformer模型并没有设计解码器组件,由于S10中提前嵌入了一个可变向量,由此向量充当解码器组件,作用是选一个最有效分类的patch。这一设置使得整个模型更加简单、有效。
本发明提出了一种基于Transformer的半监督网络的图像分类方法。本发明具有如下有益效果:
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将Transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
2.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,通过伪标签预测和Consistency Regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
3.本发明实施例设计了适用于图像数据的数据结构,在Transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了Transformer模型及自注意力机制进行图像分类中的应用;
4.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,采用基于图像的Transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
5.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
实施例二
本实施例提供一种图像分类模型的训练方法,用于对实施例一所述的图像分类方法所构成的图像分类模型进行训练。图2是本发明实施例提供的一种图像分类模型的训练方法的流程图。如图2所示,该方法包括步骤S01-S06。
S01:获取一个训练数据集D,其中,所述训练数据集中包括有标签数据集Dl和无标签数据集Du,每个训练数据为一幅训练图像,每个有标签数据dl的标签为dl的真实类别yl
S02:对每个有标签数据dl进行一次随机数据增强,得到增强后的有标签数据集
Figure BDA0003178208020000131
对每个无标签数据du进行K次随机数据增强,得到K个增强后的无标签数据集
Figure BDA0003178208020000132
k=1,...,K,将所有du的K个
Figure BDA0003178208020000133
的并集记为
Figure BDA0003178208020000134
将每个无标签数据du的K个
Figure BDA0003178208020000135
分别输入实施例一所述的图像分类方法对应的图像分类模型,最终得到K个预测类别,对所述K个预测类别取平均,将得到的平均值作为du的伪标签。
可选地,将所有的无标签数据进行随机数据增强,重复K次,然后把增强后的无标签数据输入模型中进行预测,得到K个预测类别,最后进行取平均操作作为无标签数据的伪标签。需要说明的是,由于在代码实现过程中已将类别数字化,因此取类别的平均值可以预测出无标签数据所述的类别。
可选地,首先对原始数据集做数据增强处理。有标签数据集Dl={d1,d2,…,dn1}(其中,n1表示有标签数据的数量)。无标签数据集为Du={dn1+1,dn1+2,…,dn2}(其中,n2-n1表示无标签数据的数量)。将数据集Dl作一次随机数据增强操作,得到集合
Figure BDA0003178208020000141
将数据集Xu做K次随机数据增强操作,得到K个集合
Figure BDA0003178208020000142
k∈(1,...,K)。然后将
Figure BDA0003178208020000143
输入所述图像分类模型的初始化网络中进行伪标签预测,得到
Figure BDA0003178208020000144
k∈(1,...,K)。最后利用K次预测结果进行取平均得到最终的伪标签,即
Figure BDA0003178208020000145
Figure BDA0003178208020000146
需要说明的是,这里的“初始化网络”是指先利用大数据集对网络模型进行预训练后再用来作具体的分类任务。关于整个网络模型的训练阶段,将在后面进行详细描述。
在一实施例中,所述随机数据增强包括图像位移、改变图像的亮度、改变图像的对比度和改变图像的饱和度中的至少一种方式的随机组合,其中,图像的位移、图像的亮度、图像的对比度和图像的饱和度的改变值均为预设范围内的随机数。
对于基于伪标签和预测标签一致性实现半监督学习算法来说,随机数据增强的好坏很大程度上决定了算法的好坏。本发明针对图像领域数据集特点设计了合理的数据增强方法。
S03:将
Figure BDA0003178208020000147
输入所述图像分类模型,得到
Figure BDA0003178208020000148
中的每个数据
Figure BDA0003178208020000149
的预测类别概率;利用
Figure BDA00031782080200001410
中所有数据的预测类别概率和真实类别,计算交叉熵损失。
可选地,利用全部有标签数据的预测类别概率,计算其概率最大值所对应类别为预测类别,利用预测类别与真实标签类别进行交叉熵损失计算。交叉熵损失函数可以约束网络模型对有标记数据类别的预测与真实样本类别,使得网络模型输出更加逼近真实样本数据分布。
在一实施例中,S03中,利用
Figure BDA0003178208020000151
中所有数据的预测类别概率和真实类别,计算交叉熵损失,包括:根据公式(1),利用
Figure BDA0003178208020000152
中所有数据的预测类别概率和真实类别,计算所述交叉熵损失Lossl
Figure BDA0003178208020000153
其中,n表示
Figure BDA0003178208020000154
中的数据
Figure BDA0003178208020000155
的个数,
Figure BDA0003178208020000156
表示
Figure BDA0003178208020000157
的真实类别,pl,a表示所述图像分类模型预测得到的
Figure BDA0003178208020000158
的类别为
Figure BDA0003178208020000159
的概率。
S04:将
Figure BDA00031782080200001510
输入所述图像分类模型,得到
Figure BDA00031782080200001511
中的每个数据
Figure BDA00031782080200001512
的预测类别概率,将所述预测类别概率中的概率最大值所对应的类别作为
Figure BDA00031782080200001513
的预测类别;利用
Figure BDA00031782080200001514
中的所有数据的预测类别和伪标签,计算一致性损失。
可选地,利用全部无标签数据的预测类别概率,计算其概率最大值所对应类别为预测类别。将现在输出的预测结果(预测类别)与历史输出的预测结果(伪标签)做一致性损失计算;一致性损失函数可以约束网络模型对无标签数据类别的预测与历史输出的预测结果,使得它们尽量保持一致。由于同一数据的预测结果不变性,它们应当保持一致。基于此原理,可以挖掘无标签数据的有益信息,并且不需要已知标签信息。
在一实施例中,S04中,利用
Figure BDA00031782080200001515
中的所有数据的预测类别和伪标签,计算一致性损失,包括:根据公式(2),利用
Figure BDA00031782080200001516
中的所有数据的预测类别和伪标签,计算一致性损失Lossu
Figure BDA00031782080200001517
其中,M表示
Figure BDA00031782080200001518
中的数据
Figure BDA00031782080200001519
的个数,M=(n2-n1)×K,ω(·)表示坡度函数,t表示全局迭代次数,yu,k,b表示
Figure BDA00031782080200001520
的预测类别,
Figure BDA00031782080200001521
表示
Figure BDA00031782080200001522
的伪标签。
需要说明的是,交叉熵损失只能用有标签数据计算,因为它需要用到数据的真实标签信息。如果使用伪标签信息,那么会造成强噪声干扰,不利于模型训练。而一致性损失只用到了无标签数据结果,因为有标签数据的价值信息已经被交叉熵损失利用了,而无标签数据的伪标签信息还没有被利用。
S05:将所述交叉熵损失和所述一致性损失的加权和作为本轮训练的总损失,对所述图像分类模型中的网络参数进行训练,其中,所述网络参数包括:所述线性层的参数、所述编码器的参数和所述分类器的参数。
可选地,将交叉熵损失和一致性损失加权和做为总损失,不断进行训练,直到训练轮次达到设定值,保存其最小损失值时得网络模型。两种损失函数结合起来,可以同时使用有标签数据和无标签数据学习训练,同时得到一个批次内的有标签数据和无标签数据的有益信息,更正模型参数,为下一轮训练做准备。
可选地,将交叉熵损失Lossl和一致性损失Lossu加权和做为总损失Loss=Lossl+λLossu(其中λ是超参数),不断进行训练,使得Loss呈现下降趋势,直到训练轮次达到设定值或者Loss呈现平稳趋势。
S06:返回S01,直到满足设定的终止条件,保存训练过程中总损失最小时的网络参数,将对应的图像分类模型作为训练好的图像分类模型。
在一实施例中,在S01之前,所述训练方法还包括:S011-S012。
S011:对所述图像分类模型进行初始化,利用大数据集对初始化的模型进行预训练,得到源模型。
S012:复制所述源模型的中的Transformer模型的编码器的参数,并初始化所述Transformer的分类器的参数,得到中间模型。
这时,在S02中,将每个无标签数据du的K个
Figure BDA0003178208020000171
分别输入实施例一所述的图像分类方法对应的图像分类模型,包括:将所述K个
Figure BDA0003178208020000172
分别输入所述中间模型。
图3是本发明实施例提供的一种图像分类模型的整个训练过程的流程图。下面将结合图3,对整个图像分类模型的完整训练过程进行说明。模型的完整的训练过程需要经过初始化、预训练、复制、微调四个环节。
首先对模型的中间层以及输出层的参数进行初始化,然后用大数据集进行模型的预训练,训练完成后获得源模型以及参数。接着复制源模型的中间层参数并且初始化输出层组成中间模型。最后用任务数据集对中间模型进行训练,微调中间层参数,学习目标输出层的参数,获得鲁棒性能优良的目标模型。
在本发明实施例中,模型的中间层包括:线性层和整个Transformer模型的编码器,输出层包括整个Transformer的分类器。
在预测无标签数据的伪标签的过程中,用到的是整个图像分类模型的中间模型。
整个图像分类模型的流程可概括如下:(1)用户输入待测试图像数据进入分类系统,(2)分类系统内部自动进行图像分块处理、获取类别概率和确定预测类别三个过程,(3)输出预测类别与用户进行交互。
本发明提出了一种基于Transformer的半监督网络的图像分类模型的训练方法。本发明具有如下有益效果:
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将Transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
2.本发明实施例通过伪标签预测和Consistency Regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
3.本发明实施例设计了适用于图像数据的数据结构,在Transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了Transformer模型及自注意力机制进行图像分类中的应用;
4.本发明实施例采用基于图像的Transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
5.本发明实施例将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
实施例三
图4是本发明实施例提供的另一种图像分类方法的流程图。该方法基于Transformer的半监督算法实现图像分类的网络学习过程,包括训练阶段和预测阶段。如图4所示,该方法包括S1-S8。
S1:预测伪标签。首先将所有的无标签数据进行随机数据增强,重复K次,然后把增强后的无标签数据输入模型中进行预测,得到K个伪标签,最后进行取平均操作作为无标签数据的伪标签。
S2:图像分块处理。将输入的图像切分成多个patch,并且铺平成一个序列,接着通过可学习的线性投影进行降维操作。最后在所有patch对应的序列向量的首部嵌入一个同patch大小的向量(简称为“patch嵌入向量”)。这个向量初始化是随机的,在训练过程中可学习(即是可变的)。这个向量是专门用来做解码功能,它与编码器对应,学习得到的patch嵌入向量是所有patch中最具有分类代表性的一个。
S3:嵌入位置编码。初始化位置编码向量加入到图像分块处理操作后的序列向量,一同作为输入向量。
S4:获取类别概率:将处理后的图像数据向量输入Transformer模型的编码器内得到编码后的结果,其中,编码后的结果同样为向量形式,同输入向量的维度一致。再取位置零处的patch嵌入向量作为整个图片的特征向量,并且输入到Transformer模型的分类器,最后得到预测类别概率。
S5:计算交叉熵损失。利用全部有标签数据的预测类别概率,其概率最大值所对应类别为预测类别。利用预测类别与真实标签类别进行交叉熵损失计算。交叉熵损失函数可以约束网络模型对有标记数据类别的预测与真实样本类别,使得网络模型输出更加逼近真实样本数据分布。
S6:计算一致性损失。利用全部无标签数据的预测类别概率,其概率最大值所对应类别为预测类别。将现在输出的预测结果与历史输出的预测结果(伪标签)做一致性损失计算。一致性损失函数可以约束网络模型对无标签数据类别的预测与历史输出的预测结果,使得它们尽量保持一致。由于同一数据的预测结果不变性,它们应当保持一致。基于此原理,可以挖掘无标签数据的有益信息,并且不需要已知标签信息。
S7:联合训练。将交叉熵损失和一致性损失加权和做为总损失,不断进行训练,直到训练轮次达到设定值。保存其最小损失值时得网络模型。两种损失函数结合起来,可以同时使用有标签数据和无标签数据学习训练,同时得到一个批次内的有标签数据和无标签数据的有益信息,更正模型参数,为下一轮训练做准备。
S8:预测类别。利用训练好得网络模型对输入的图像数据进行预测,得到预测类别概率,将最大概率值对应的类别确定为预测结果。
在上述方法中,S1和S7属于训练阶段,S8属于预测阶段。在预测阶段,将图像输入到训练好的网络模型后,在网络模型中只执行S2-S4。
在一实施例中,在S1:预测伪标签的步骤中,首先,对原始数据集做数据增强处理。有标签数据集Dl={d1,d2,…,dn1}(其中,n1表示有标签数据的数量)。无标签数据集为Du={dn1+1,dn1+2,…,dn2}(其中,n2-n1表示无标签数据的数量)。将数据集Dl作一次随机数据增强操作,得到集合
Figure BDA0003178208020000201
Figure BDA0003178208020000202
将数据集Xu做K次随机数据增强操作,得到K个集合
Figure BDA0003178208020000203
Figure BDA0003178208020000204
k∈(1,...,K)。然后将
Figure BDA0003178208020000205
输入所述图像分类模型的初始化网络中进行伪标签预测,得到
Figure BDA0003178208020000206
k∈(1,...,K)。最后利用K次预测结果进行取平均得到最终的伪标签,即
Figure BDA0003178208020000207
其中,“初始化网络”是指图3中的中间模型,即先用大数据集预训练后再用来处理具体的分类任务。
对于基于伪标签和预测标签一致性实现半监督学习算法来说,随机数据增强的好坏很大程度上决定了算法的好坏。本发明实施例针对图像领域数据集的特点设计了合理的数据增强方法。
随机数据增强包括图像的位移、改变图像的亮度、改变图像的对比度、改变图像的饱和度四种方式中的至少一种随机组合。其中,图像的位移、图像的亮度、图像的对比度、图像的饱和度的改变值全部采用一定范围内的随机数。
在S2:图像分块处理的步骤中,由于Transformer模型需要连续化输入,因此需要把输入图像进行切分,切分成等大的正方形patch,使得一张原始图片大小从H×W×C切分成m个P×P×C大小的patch。其中,H、W分别表示图像的高度和宽度,C表示图像的通道数,P表示正方形patch的宽度,则m=WH/P2。可以根据具实际情况中的W和H的大小来选择P,一般P是2的整数次幂。接着用一个线性层将数据降至D维,减少无用特征输入。最后在输入向量的起始位置嵌入一个可学习(即可变的)的特征向量,用于输出进行分类预测的依据X,即X=[xclass;x1;x2;…;xN],
Figure BDA0003178208020000211
其中,“连续化输入”是指Transformer用需要输入满足连续化要求,比如输入一句话,其中的每个单词都具有关系性、连续性。在本发明中,将Transformer模型用于图像任务,因此需要把一幅图像切分成多个patch后排列成一个序列,也就如同“输入一句话”一样,多个patch之间具有关联性和连续性。
在S2中,输入图像是指全部数据集,也就是包括了数据增强后的有标签数据
Figure BDA0003178208020000212
和无标签数据
Figure BDA0003178208020000213
在S3:嵌入位置编码的步骤中,经过S2的分块操作,必然会丢失图像原本的位置信息,因此需要在输入向量中增加可学习的位置编码向量。弥补了丢失的位置信息,同时,可学习的设定保证了获得价值最高的位置信息。patch的嵌入向量(即patch向量)和patch的位置编码向量一同作为Transformer模型的编码器的输入,此过程可表示为X=[xclass+P0;x1+P1;x2+P2;…;xN+PN]。“可学习的设定”是指这个向量在网络学习训练的过程中会一直变化,根据注意力机制着重哪个patch,那么这个向量就会更新为这个patch对应的特征向量。
在S4:获取类别概率的步骤中,将S3的结果向量输入Transformer模型中获得类别概率。Transformer模型包括两部分:一是编码器,另一个是分类器,编码器与分类器串联。编码器包括MSA和第一MLP,MSA和第一MLP均是残差连接方式,MSA的输出为第一MLP的输入。MSA和第一MLP都经过LN进行图像通道归一化操作。分类器包括第二MLP。
编码器部分负责提取图像的全局信息以及对任务有帮助的图像区域,而分类器负责根据图像特征进行分类获得其对应每类的概率值。整个Transformer模型并没有设计解码器组件,由于S2中提前嵌入了一个可变向量,由此向量充当解码器组件,作用是选一个最有效分类的patch。这一设置使得整个模型更加简单、有效。
在S5:计算交叉熵损失的步骤中,利用全部随机数据增强后的有标签数据(
Figure BDA0003178208020000221
中的数据)的预测概率pl与真实标签类别
Figure BDA0003178208020000222
计算交叉熵损失Lossl
Figure BDA0003178208020000223
其中,n表示
Figure BDA0003178208020000224
中的数据
Figure BDA0003178208020000225
的个数,
Figure BDA0003178208020000226
表示
Figure BDA0003178208020000227
的真实类别,pl,a表示所述图像分类模型预测得到的
Figure BDA0003178208020000228
的类别为
Figure BDA0003178208020000229
的概率。
在S6:计算一致性损失的步骤中,利用全部随机数据增强后无标签数据(
Figure BDA00031782080200002210
中的数据)的预测类别yu,k,与S1步骤的伪标签预测结果
Figure BDA00031782080200002211
做一致性损失计算:
Figure BDA0003178208020000231
其中,M表示
Figure BDA0003178208020000232
中的数据
Figure BDA0003178208020000233
的个数,M=(n2-n1)×K,ω(·)表示坡度函数,t表示全局迭代次数,yu,k,b表示
Figure BDA0003178208020000234
的预测类别,
Figure BDA0003178208020000235
表示
Figure BDA0003178208020000236
的伪标签。
交叉熵损失只能用有标签数据计算,因为它需要用到数据的真实标签信息。如果使用伪标签信息,那么会造成强噪声干扰,不利于模型训练。而一致性损失只用到了无标签数据结果,因为有标签数据的价值信息已经被交叉熵损失利用了,而无标签数据的伪标签信息还没有被利用。
在S7:联合训练的步骤中,将交叉熵损失Lossl和一致性损失Lossu加权和做为总损失Loss=Lossl+λLossu(其中λ是超参数),不断进行训练,使得Loss呈现下降趋势,直到训练轮次达到设定值或者Loss呈现平稳趋势。保存其最小损失值时得网络模型。
在S8:预测类别的步骤中,将待分类的图像数据输入已训练好的网络模型中进行预测,得到类别概率,将概率值最大的类别作为预测结果。
需要说明的是,在利用训练好的模型进行预测时,patch向量是可变的,通过Tranformer模型最终更新为待分类的图像中最具有分类代表性的一个patch对应的特征向量;同时,在对模型进行训练时,patch向量也是可学习的,在整个训练过程中不断更新,这一更新不但包括在Tranformer模型中的更新,还包括在训练过程中因网络梯度下降而获得的更新。
本发明提出了一种基于Transformer的半监督网络的图像分类模型的训练方法。本发明具有如下有益效果:
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将Transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
2.本发明实施例通过伪标签预测和Consistency Regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
3.本发明实施例设计了适用于图像数据的数据结构,在Transformer模型的基础上增加了图像分块处理、可学习特征向量嵌入及图像位置信息编码操作,实现了Transformer模型及自注意力机制进行图像分类中的应用;
4.本发明实施例采用基于图像的Transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
5.本发明实施例将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
实施例四
图5为本发明实施例提供的一种计算机设备的结构示意图。如图5所示,该设备包括处理器510和存储器520。处理器510的数量可以是一个或多个,图5中以一个处理器510为例。
存储器520作为一种计算机可读存储介质,可用于存储软件程序、计算机可执行程序以及模块,如本发明实施例一、三所述的图像分类方法的程序指令/模块,或实施例二所述的图像分类模型的训练方法的程序指令/模块。
相应地,处理器510通过运行存储在存储器520中的软件程序、指令以及模块,实现本发明实施例一、三所述的图像分类方法,或实施例二所述的图像分类模型的训练方法。
存储器520可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器520可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他非易失性固态存储器件。在一些实例中,存储器520可进一步包括相对于处理器510远程设置的存储器,这些远程存储器可以通过网络连接至设备/终端/服务器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用硬件实施例、软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器和光学存储器等)上实施的计算机程序产品的形式。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (10)

1.一种图像分类方法,其特征在于,包括:
S10:将待分类图像切分成多个patch,生成每个patch对应的patch向量;通过线性层对每patch向量进行降维,将多个降维后的patch向量进行拼接,得到第一序列向量;在所述第一序列向量的首部嵌入一个可变向量,得到第二序列向量,其中,所述可变向量与每个降维后的patch向量尺寸相同,且所述可变向量对应所述多个patch中最能代表所述待分类图像的特征的patch;
S20:初始化所述第二序列向量的位置编码向量,其中,所述位置编码向量中包含所述多个patch在所述待分类图像中的位置信息;将所述初始化后的位置编码向量嵌入到所述第二序列向量中,得到输入向量;
S30:将所述输入向量输入到Transformer模型的编码器,得到编码向量;取所述编码向量的首部的可变向量作为所述待分类图像的特征向量;将所述特征向量输入到所述Transformer模型的分类器,得到所述待分类图像的预测类别概率。
2.如权利要求1所述的图像分类方法,其特征在于,S10包括:
S110:将尺寸为H×W×C的所述待分类图像切分成m个尺寸为P×P×C的patch,其中,H和W分别表示所述待分类图像的高度和宽度,C表示所述待分类图像的通道数,P表示每个patch的宽度;
S120:将每个patch展开成一个patch向量,通过所述线性层将每个patch向量降至D维,生成所述第一序列向量X1=[x1;x2;…;xm],其中,xi表示第i个patch的patch向量,i=1、2…m,
Figure FDA0003178208010000011
Figure FDA0003178208010000012
表示维度为D的向量域;
S130:在X1的首部嵌入所述可变向量xclass,得到所述第二序列向量X2=[xclass;x1;x2;…;xm],其中,
Figure FDA0003178208010000013
3.如权利要求2所述的图像分类方法,其特征在于,S20包括:
S21:初始化xclass的位置编码向量P0,初始化xi的位置编码向量Pi,其中,所述第二序列向量的位置编码向量P[P0;P1;P2;…;Pm],
Figure FDA0003178208010000021
j=0、1、2…m,Pj中包含Pj对应的patch在所述待分类图像中的位置信息;
S22:将P嵌入到X2中,得到所述输入向量X[xclass+P0;x1+P1;x2+P2;…;xm+Pm]。
4.如权利要求1所述的图像分类方法,其特征在于,所述Transformer模型包含所述编码器和所述分类器,不包含解码器,其中,
所述编码器包括串行排列的的多头自注意力MSA和第一多层感知器MLP,所述MSA的输出为所述第一MLP的输入;所述MSA与所述第一MLP的内部均采用残差连接方式;所述MSA和所述第一MLP之前均连接有一个归一化层LN,待处理信号经过一个LN后再输入所述MSA或所述第一MLP进行处理;
所述分类器包括第二MLP。
5.一种图像分类模型的训练方法,其特征在于,包括:
S01:获取一个训练数据集D,其中,所述训练数据集中包括有标签数据集Dl和无标签数据集Du,每个训练数据为一幅训练图像,每个有标签数据dl的标签为dl的真实类别yl
S02:对每个有标签数据dl进行一次随机数据增强,得到增强后的有标签数据集
Figure FDA0003178208010000022
对每个无标签数据du进行K次随机数据增强,得到K个增强后的无标签数据集
Figure FDA0003178208010000023
k=1,...,K,将所有du的K个
Figure FDA0003178208010000024
的并集记为
Figure FDA0003178208010000025
将每个无标签数据du的K个
Figure FDA0003178208010000026
分别输入如权利要求1-4中任意一项所述的图像分类方法对应的图像分类模型,最终得到K个预测类别,对所述K个预测类别取平均,将得到的平均值作为du的伪标签;
S03:将
Figure FDA0003178208010000031
输入所述图像分类模型,得到
Figure FDA0003178208010000032
中的每个数据
Figure FDA0003178208010000033
的预测类别概率;利用
Figure FDA0003178208010000034
中所有数据的预测类别概率和真实类别,计算交叉熵损失;
S04:将
Figure FDA0003178208010000035
输入所述图像分类模型,得到
Figure FDA0003178208010000036
中的每个数据
Figure FDA0003178208010000037
的预测类别概率,将所述预测类别概率中的概率最大值所对应的类别作为
Figure FDA0003178208010000038
的预测类别;利用
Figure FDA0003178208010000039
中的所有数据的预测类别和伪标签,计算一致性损失;
S05:将所述交叉熵损失和所述一致性损失的加权和作为本轮训练的总损失,对所述图像分类模型中的网络参数进行训练,其中,所述网络参数包括:所述线性层的参数、所述编码器的参数和所述分类器的参数;
S06:返回S01,直到满足设定的终止条件,保存训练过程中总损失最小时的网络参数,将对应的图像分类模型作为训练好的图像分类模型。
6.如权利要求5所述的训练方法,其特征在于,所述随机数据增强包括图像位移、改变图像的亮度、改变图像的对比度和改变图像的饱和度中的至少一种方式的随机组合,其中,图像的位移、图像的亮度、图像的对比度和图像的饱和度的改变值均为预设范围内的随机数。
7.如权利要求6所述的训练方法,其特征在于,S03中,利用
Figure FDA00031782080100000310
中所有数据的预测类别概率和真实类别,计算交叉熵损失,包括:
根据公式(1),利用
Figure FDA00031782080100000311
中所有数据的预测类别概率和真实类别,计算所述交叉熵损失Lossl
Figure FDA00031782080100000312
其中,n表示
Figure FDA00031782080100000313
中的数据
Figure FDA00031782080100000314
的个数,
Figure FDA00031782080100000315
表示
Figure FDA00031782080100000316
的真实类别,pl,a表示所述图像分类模型预测得到的
Figure FDA0003178208010000041
的类别为
Figure FDA0003178208010000042
的概率。
8.如权利要求6所述的训练方法,其特征在于,S04中,利用
Figure FDA0003178208010000043
中的所有数据的预测类别和伪标签,计算一致性损失,包括:
根据公式(2),利用
Figure FDA0003178208010000044
中的所有数据的预测类别和伪标签,计算一致性损失Lossu
Figure FDA0003178208010000045
其中,M表示
Figure FDA0003178208010000046
中的数据
Figure FDA0003178208010000047
的个数,ω(·)表示坡度函数,t表示全局迭代次数,yu,k,b表示
Figure FDA0003178208010000048
的预测类别,
Figure FDA0003178208010000049
表示
Figure FDA00031782080100000410
的伪标签。
9.如权利要求5所述的训练方法,其特征在于,在S01之前,还包括:
S011:对所述图像分类模型进行初始化,利用大数据集对初始化的模型进行预训练,得到源模型;
S012:复制所述源模型的中的Transformer模型的编码器的参数,并初始化所述Transformer的分类器的参数,得到中间模型;
在S02中,将每个无标签数据du的K个
Figure FDA00031782080100000411
分别输入如权利要求1-4中任意一项所述的图像分类方法对应的图像分类模型,包括:
将所述K个
Figure FDA00031782080100000412
分别输入所述中间模型。
10.一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1-4中任意一项所述的图像分类方法,或实现如权利要求5-9中任意一项所述的图像分类模型的训练方法。
CN202110838884.8A 2021-07-23 2021-07-23 一种图像分类方法、图像分类模型的训练方法及设备 Withdrawn CN113469283A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110838884.8A CN113469283A (zh) 2021-07-23 2021-07-23 一种图像分类方法、图像分类模型的训练方法及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110838884.8A CN113469283A (zh) 2021-07-23 2021-07-23 一种图像分类方法、图像分类模型的训练方法及设备

Publications (1)

Publication Number Publication Date
CN113469283A true CN113469283A (zh) 2021-10-01

Family

ID=77882260

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110838884.8A Withdrawn CN113469283A (zh) 2021-07-23 2021-07-23 一种图像分类方法、图像分类模型的训练方法及设备

Country Status (1)

Country Link
CN (1) CN113469283A (zh)

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113920583A (zh) * 2021-10-14 2022-01-11 根尖体育科技(北京)有限公司 细粒度行为识别模型构建方法及系统
CN114418030A (zh) * 2022-01-27 2022-04-29 腾讯科技(深圳)有限公司 图像分类方法、图像分类模型的训练方法及装置
CN115131607A (zh) * 2022-06-15 2022-09-30 北京工业大学 图像分类方法及装置
CN115880727A (zh) * 2023-03-01 2023-03-31 杭州海康威视数字技术股份有限公司 人体识别模型的训练方法和装置
CN116310520A (zh) * 2023-02-10 2023-06-23 中国科学院自动化研究所 目标检测方法、装置、电子设备以及存储介质
CN117173401A (zh) * 2022-12-06 2023-12-05 南华大学 基于交叉指导和特征级一致性双正则化的半监督医学图像分割方法及系统
CN117253044A (zh) * 2023-10-16 2023-12-19 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法
CN117593557A (zh) * 2023-09-27 2024-02-23 北京邮电大学 一种基于Transformer模型的细粒度生物图像分类方法

Cited By (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113920583A (zh) * 2021-10-14 2022-01-11 根尖体育科技(北京)有限公司 细粒度行为识别模型构建方法及系统
CN114418030A (zh) * 2022-01-27 2022-04-29 腾讯科技(深圳)有限公司 图像分类方法、图像分类模型的训练方法及装置
CN114418030B (zh) * 2022-01-27 2024-04-23 腾讯科技(深圳)有限公司 图像分类方法、图像分类模型的训练方法及装置
CN115131607A (zh) * 2022-06-15 2022-09-30 北京工业大学 图像分类方法及装置
CN115131607B (zh) * 2022-06-15 2024-07-26 北京工业大学 图像分类方法及装置
CN117173401A (zh) * 2022-12-06 2023-12-05 南华大学 基于交叉指导和特征级一致性双正则化的半监督医学图像分割方法及系统
CN117173401B (zh) * 2022-12-06 2024-05-03 南华大学 基于交叉指导和特征级一致性双正则化的半监督医学图像分割方法及系统
CN116310520A (zh) * 2023-02-10 2023-06-23 中国科学院自动化研究所 目标检测方法、装置、电子设备以及存储介质
CN115880727A (zh) * 2023-03-01 2023-03-31 杭州海康威视数字技术股份有限公司 人体识别模型的训练方法和装置
CN117593557A (zh) * 2023-09-27 2024-02-23 北京邮电大学 一种基于Transformer模型的细粒度生物图像分类方法
CN117253044A (zh) * 2023-10-16 2023-12-19 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法
CN117253044B (zh) * 2023-10-16 2024-05-24 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法

Similar Documents

Publication Publication Date Title
CN113469283A (zh) 一种图像分类方法、图像分类模型的训练方法及设备
CN110322446B (zh) 一种基于相似性空间对齐的域自适应语义分割方法
Fleuret Uncertainty reduction for model adaptation in semantic segmentation
Niculae et al. A regularized framework for sparse and structured neural attention
Campos et al. Skip rnn: Learning to skip state updates in recurrent neural networks
Liu et al. Multi-objective convolutional learning for face labeling
US20190095787A1 (en) Sparse coding based classification
Kortylewski et al. Probabilistic Compositional Active Basis Models for Robust Pattern Recognition.
Wei et al. Compact MQDF classifiers using sparse coding for handwritten Chinese character recognition
Mukherjee et al. Predicting video-frames using encoder-convlstm combination
CN114663798B (zh) 一种基于强化学习的单步视频内容识别方法
US20110299789A1 (en) Systems and methods for determining image representations at a pixel level
CN115293348A (zh) 一种多模态特征提取网络的预训练方法及装置
CN112307883A (zh) 训练方法、装置、电子设备以及计算机可读存储介质
Uddin et al. A perceptually inspired new blind image denoising method using $ L_ {1} $ and perceptual loss
Xiao et al. Apple ripeness identification from digital images using transformers
CN115905613A (zh) 音视频多任务学习、评估方法、计算机设备及介质
WO2019234291A1 (en) An apparatus, a method and a computer program for selecting a neural network
TW202348029A (zh) 使用限幅輸入數據操作神經網路
Wang et al. Efficient crowd counting via dual knowledge distillation
CN115426671A (zh) 图神经网络训练、无线小区故障预测方法、系统及设备
Koohzadi et al. A context based deep temporal embedding network in action recognition
Wu et al. Extreme Learning Machine Combining Hidden-Layer Feature Weighting and Batch Training for Classification
Li et al. A self-adjusting transformer network for detecting transmission line defects
Lin et al. Face localization and enhancement

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
CB03 Change of inventor or designer information
CB03 Change of inventor or designer information

Inventor after: Wang Zijie

Inventor after: Wang Ruifeng

Inventor after: Ding Dongrui

Inventor after: Zhu Guoli

Inventor after: Lu Tianbin

Inventor after: Wang Xiaohan

Inventor before: Zhang Kai

Inventor before: Wang Ruifeng

Inventor before: Ding Dongrui

Inventor before: Yang Guangyuan

Inventor before: Lu Tianbin

Inventor before: Wang Xiaohan

TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20220816

Address after: 277400 courtyard 29, Longwan villa, South Gate of the ancient city, Yunhe North Bank Road, Canal Street, Taierzhuang, Zaozhuang City, Shandong Province

Applicant after: Shandong huanke Information Technology Co.,Ltd.

Address before: 276808 No.99, Yuquan 2nd Road, antonwei street, Lanshan District, Rizhao City, Shandong Province

Applicant before: Shandong Liju Robot Technology Co.,Ltd.

WW01 Invention patent application withdrawn after publication
WW01 Invention patent application withdrawn after publication

Application publication date: 20211001