CN115690534A - 一种基于迁移学习的图像分类模型的训练方法 - Google Patents
一种基于迁移学习的图像分类模型的训练方法 Download PDFInfo
- Publication number
- CN115690534A CN115690534A CN202211315590.8A CN202211315590A CN115690534A CN 115690534 A CN115690534 A CN 115690534A CN 202211315590 A CN202211315590 A CN 202211315590A CN 115690534 A CN115690534 A CN 115690534A
- Authority
- CN
- China
- Prior art keywords
- sample
- samples
- domain
- training
- label
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
一种图像分类模型的训练方法,所述图像分类模型通过特征提取器提取样本对应的样本特征并通过分类器根据样本特征对样本进行分类,所述方法包括:S1、获取初始的训练集,其包括来自多个源域的样本及指示每个样本所属源域的域属性以及所属类别的标签,其中,所述样本为图像,各源域的标签空间相同,一种标签对应一个子域;S2、利用初始的训练集训练特征提取器提取样本特征,并基于第一损失函数计算的损失值更新特征提取器的参数,所述第一损失函数被配置为基于域属性和标签惩罚不同源域的同一子域的样本对应的样本特征之间的差异;S3、对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,利用样本及其增强样本对特征提取器和两个分类器进行对抗训练。
Description
技术领域
本发明涉及计算机视觉范畴,具体来说,涉及机器学习中的图像分类技术领域,更具体地说,涉及一种基于迁移学习的图像分类模型的训练方法。
背景技术
传统机器学习尤其是图像分类模型中的最基本假设是训练数据(源域)和测试数据(目标域)之间是独立同分布的,在这个假设下,通过最小化训练数据的误差来优化模型在测试数据的性能。然而,在新环境下采集的目标域数据通常与源域的数据的分布并不相同。具体来说,图像分类模型在面对与训练分布无关的数据时,例如,当将不同视角和光照条件下的真实产品图像作为测试,在干净的产品图像上训练的模型,如果使用以最小化源域经验损失[1]的传统机器学习方法,会表现出较差的性能。
而迁移学习是缓解目标域分布差异的重要方法,迁移学习通过现有的包含较多标签的源域数据对模型进行训练来实现对标签未知或标签较少的目标域的正确识别。在实际应用中,训练用的带有标签的数据集是在不同环境下收集的,而这些不同源域的数据集往往遵循不同的分布,导致单一域的多源域迁移学习没有最大化多源域的信息优势。与此同时,多源域倾向产生一个过拟合模型,使得模型在未知的目标域上的表现不可预测。
领域泛化通常关注训练所使用的源域数据,而领域自适应通常同时关注训练所使用的源域数据以及测试所使用的目标域数据。领域泛化的出现是为了最小化数据集漂移导致的泛化风险,其典型技术特征是学习域不变的表示,从不同源域收集带有不同标签(一种标签即对应一种子域)的数据以获取更多对象的上下文信息,从而学习到不同数据与标签之间不变的关联特征,即本质特征。目前已有的领域自适应方法,例如一种使用域鉴别器的对抗源域对齐方法[2]、一种最大分类器差异源域对齐方法[3],在多源域的泛化问题上都取得了较好的表现,而目前领域泛化的主要做法是以整个源域为单位进行特征分布的对齐(后文简称:一般域对齐或者源域对齐),从而获得域不变特征。以整个源域为单位进行特征分布的对齐(强制的一般域对齐)可能会使模型学习到具有模糊类别分类边界的过压缩域不变特征,不仅存在域不变性,而且来自不同类别的表示也会被混淆,在源域和目标域的性能有待提高。
在领域泛化对抗方法中,例如一种面向域泛化的深度域对抗图像生成方法[4],强制对齐一般域可能会使模型学习到具有模糊类别分类边界的过压缩域不变特征。虽然在领域自适应中,子域对齐通常可以比传统的全域对齐方法获得更好的自适应性能,但是在领域泛化中,由于目标域的信息匮乏,已经证明了一般的域不变特征会增加潜在表示的跨度,特别是对于遥远的目标域;而如果直接采用对齐不同源域的同一子域的样本对应的样本特征(即:同一个类别的样本特征对齐,后面简称子域对齐)的方案,其模型在未知的目标域(特别是与源域的数据分布差距较大的目标域)的性能比源域对齐的方案得到的模型的性能更差,导致在领域泛化中子域对齐很难推广到不可见的目标域。
在例如一种面向域泛化的特征风格化域感知对比学习方法[5]中,使用了对比学习的思想来进行源域的对齐,但同样没有考虑到源域的子域的对齐。
总的来说,现有的域泛化方法在解决例如PACS数据集这类包含多个源域的图像分类问题中,不能有效对齐不同源域所包含的子域的样本特征,并且未知的目标域存在风险,导致分类效果不佳。
综上所述,目前已有的领域泛化方法中,以整个源域为单位进行对齐会导致模型学习到过度压缩的域不变特征,降低模型在源域和目标域上的性能,而直接采用对齐不同源域的同一子域的样本对应的样本特征的方案,很难推广到不可见的目标域。
参考文献:
[1]Vapnik V N.An overview of statistical learning theory[J].IEEEtransactions on neural networks,1999,10(5):988-999.
[2]Ganin Y,Ustinova E,Ajakan H,et al.Domain-adversarial training ofneural networks[J].The journal of machine learning research,2016,17(1):2096-2030.
[3]Saito K,Watanabe K,Ushiku Y,et al.Maximum classifier discrepancyfor unsupervised domain adaptation[C]//Proceedings of the IEEE conference oncomputer vision and pattern recognition.2018:3723-3732.
[4]Zhou K,Yang Y,Hospedales T,et al.Deep domain-adversarial imagegeneration for domain generalisation[C]//Proceedings of the AAAI Conferenceon Artificial Intelligence.2020,34(07):13025-13032.
[5]Jeon S,Hong K,Lee P,et al.Feature stylization and domain-awarecontrastive learning for domain generalization[C]//Proceedings of the 29thACM International Conference on Multimedia.2021:22-31.
发明内容
因此,本发明的目的在于克服上述现有技术的缺陷,提供一种图像分类模型的训练方法。
根据本发明的第一方面,提供一种图像分类模型的训练方法,所述图像分类模型包括特征提取器以及两个分类器,其通过特征提取器提取样本对应的样本特征并通过分类器根据样本特征对样本进行分类,所述方法包括:S1、获取初始的训练集,其包括来自多个源域的样本及指示每个样本所属源域的域属性以及所属类别的标签,其中,所述样本为图像,各源域的标签空间相同,一种标签对应一个子域;S2、利用初始的训练集训练特征提取器提取样本特征,并基于第一损失函数计算的损失值更新特征提取器的参数,所述第一损失函数被配置为基于域属性和标签惩罚不同源域的同一子域的样本对应的样本特征之间的差异;S3、对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,利用样本及其增强样本对特征提取器和两个分类器进行对抗训练。
在本发明的一些实施例中,所述步骤S3包括:在对抗训练时,基于第二损失函数计算的损失更新两个分类器的参数,其中,第二损失函数被配置为惩罚每个分类器分类的偏差以及奖励两个分类器对增强样本在每个类别的置信概率差异之和;以及在对抗训练时,基于第三损失函数计算的损失更新特征提取器的参数,其中,第三损失函数被配置为惩罚每个分类器分类的偏差以及惩罚两个分类器对增强样本在每个类别的置信概率差异之和。
在本发明的一些实施例中,第二损失函数计算的损失被配置为与两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和负相关。
优选的,第二损失函数如下:
在本发明的一些实施例中,第三损失函数计算的损失被配置为与第一损失函数正相关、与每个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和正相关。
优选的,第三损失函数如下:
其中,表示第一损失函数, 表示两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差的均值,表示两个分类器对增强样本在每个类别的置信概率差异之和的均值,α表示为预设的权重,μ表示为预设的权重。
其中,2N表示训练时采用的样本或者增强样本的数量,Lce(·)表示分类器输出的置信概率与对应标签之间的交叉熵损失,Gi(·)表示第i个分类器根据样本特征进行分类的置信概率,i=1时表示第1个分类器,i=2时表示第2个分类器,F(·)表示特征提取器提取的样本特征,xm表示样本m,表示样本m对应的增强样本,ym表示样本m对应的标签。
在本发明的一些实施例中,在步骤S2中,每次训练从初始的训练集中取N个样本作为锚样本并取与每个锚样本的标签相同但域属性不同的样本作为该锚样本的正样本,形成批次大小为2N的训练批次,并按照如下第一损失函数计算损失值:
其中,i表示训练批次中的一个锚样本的编号,i+N表示训练批次中编号为i的锚样本的正样本的编号,子损失l(i,i+N)或者l(i+N,i)按照如下方式计算:
其中,a=i时b=i+N,a=i+N时b=i,exp(·)表示以自然数e为底的指数函数,sim(·)表示计算余弦相似度的函数,分别表示样本a、样本b、样本l对应的样本特征,样本l为训练批次中编号为l的样本,τ表示温度超参数,k表示样本对应的域属性,y表示样本对应的标签,w(a,l)表示为设定的加权值。
在本发明的一些实施例中,w(a,l)按照以下方式确定:
其中,Mneg(a,l)、Mneg+(a,l)表示根据样本a和样本l的域属性以及标签确定的值,其中,样本a和样本l的域属性相同但标签不同时Mneg+(a,l)=1,否则Mneg+(a,l)=0;样本a和样本l的域属性和标签均不同时Mneg(a,l)=1,否则Mneg(a,l)=0;表示当前训练批次中和样本a的域属性不同但标签相同的样本的数量,表示当前训练批次中和样本a的域属性相同但标签不同的样本的数量。
在本发明的一些实施例中,步骤S3中对图像分类模型进行多轮对抗训练并在每一轮中对每个样本均随机采用所述多种图像增强方式中的至少一种方式进行增强,所述多种图像增强方式包括:左右翻转、颜色失真、高斯模糊和日光化或者其组合。
在本发明的一些实施例中,所述多源域包括艺术图像域、卡通域、照片域和手绘域或者其组合,样本对应的标签指示该图像所含有的对象的类别。
根据本发明的第二方面,提供一种图像分类方法,所述方法包括:获取待分类的图像样本;利用图像分类模型对待分类的图像样本所属的分类进行预测,其中所述图像分类模型由根据本发明第一方面所述的训练方法所得到的特征提取器以及两个分类器中的一个分类器组成。
与现有技术相比,本发明的优点在于:
1、基于对比学习,采用动态域加权对比损失来进行多源域的子域对齐,避免了过度压缩,并且具有更好的自适应能力。
2、通过在模型的训练过程中构建增强域来模拟潜在目标域,并将增强域的子域与多源域的子域对齐来实现子域对齐的模型外扩,降低未知的目标域的风险。
附图说明
以下参照附图对本发明实施例作进一步说明,其中:
图1为根据本发明实施例的一种不含头投影头的图像分类模型的结构示意图;
图2为根据本发明实施例的一种图像分类模型的训练方法流程示意图;
图3为根据本发明实施例的一种不含有投影头的对齐不同源域的同一子域的示意图;
图4为根据本发明实施例的一种包含投影头的图像分类模型的结构示意图;
图5为根据本发明实施例的一种包含投影头的对齐不同源域的同一子域的示意图;
图6为根据本发明实施例的一种图像分类模型进行训练的示意图。
具体实施方式
为了使本发明的目的,技术方案及优点更加清楚明白,以下通过具体实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
如背景技术提到的,目前已有的领域泛化方法中,以整个源域为单位进行对齐会导致模型学习到过度压缩的域不变特征,降低模型在源域和目标域上的性能,而直接采用对齐不同源域的同一子域的样本对应的样本特征的方案,很难推广到不可见的目标域。因此,发明人利用初始的训练集中的域属性以及标签,对齐不同源域的同一子域的样本对应的样本特征;并且,通过对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,得到增强的训练集,利用增强的训练集对特征提取器和两个分类器进行对抗训练,其目的是让特征提取器和两个分类器在对抗的过程中不断地见到源域之外的样本(即增强样本),从而更好地学习到更本质的样本特征,在保障模型对已知源域的样本的分类性能的情况下,同时提高模型对未知的目标域的样本的分类性能。
为了更好地理解本发明,下面结合附图以及实施例从模型结构、训练样本、模型训练、应用场景四个方面对本发明进行详细说明。
一、模型结构
根据本发明的一个实施例,参见图1,本发明提供一种图像分类模型,所述图像分类模型包括特征提取器以及两个分类器,即:分类器1和分类器2。根据本发明的一个实施例,所述特征提取器用于对样本进行特征提取以生成样本特征,所述分类器用于根据样本特征进行分类。其中,特征提取器可以采用现有的神经网络中的主干网络,例如ResNet模型(如ResNet18模型、ResNet34模型或者ResNet50模型)、AlexNet模型、VGG模型(如VGG16模型、VGG19模型等)、Transfomer网络的主干网络,又或者实施者自定义的特征提取器。比如,ResNet模型中,将conv1、conv2_x、conv3_x、conv4_x、conv5_x以及平均池化层(averagepool)构成的主干网络作为本发明的特征提取器。又比如,VGG模型中,将倒数第二个全连接层(FC-4096)及其之前的所有层构成的主干网络作为本发明的特征提取器。根据本发明的一个实施例,分类器1和分类器2可采用相同的结构。每个分类器包括用于对样本特征进行线性变换以得到分类特征的线性层(比如一层或多层的全连接层),以及用于根据分类特征计算各类别的置信概率的Softmax层。分类器最终输出置信概率最高的类别作为预测的分类结果。
二、训练样本
根据本发明的一个实施例,本发明的数据集可以采用现有的图像分类数据集,其包括来自多个源域的样本及指示每个样本所属源域的域属性以及所属类别的标签,其中,所述样本为图像,各源域的标签空间相同,一种标签对应一个子域。不同的源域的样本对应的数据分布不同。由于需要多个源域的数据,收集样本的数据的成本较高。为了测试,可以采用公开的数据集。比如:PACS图像数据集,其为迁移学习中多领域自适应常用数据集,包括多个源域的图像数据,所述多源域包括艺术图像域(Artpainting)、卡通域(Cartoon)、照片域(Photo)、手绘域(Sketch)或者其组合,且每个源域的样本为图像,样本对应的标签指示该图像所含有的对象的类别。标签包括:狗(Dog)、大象(Elephant)、长颈鹿(Giraffe)、吉他(Guitar)、马(Horse)、房子(House)、人(Person)。应当理解,本发明所述的图像分类模型可以应用于多种分类场景,以上分类场景仅为示意性的,本领域技术人员可以按需调整、设置,比如:添加一些标签,比如:牛、鸟,或者设置自定义的标签空间,比如:标签包括:马、牛、狗、猫、鼠、鸟、汽车、自行车、电动车、房子、人等,本发明对此不作任何限制。为了便于说明、测试本发明的效果,后面均以PACS图像数据集为例进行说明。
在训练前,数据集可以进一步被划分为训练集和验证集,或者,划分为训练集、验证集和测试集;通常按照一定的比例,比如按8:2的比例划分训练集和验证集,或者按8:1:1或者7:2:1的比例划分为训练集、验证集和测试集。若采用PACS图像数据集,初始的训练集即为从PACS图像数据集按一定的比例划分得到的训练集。另外,如果PACS图像数据集中样本的尺寸与选定的特征提取器的输入尺寸不一致,还可以先对样本的尺寸进行调整。比如:假设输入尺寸为224×224,为了便于样本在后续训练过程中的使用,在获取初始的训练集时首先将所有的图像样本大小统一调整为227×227,然后从每一个图像中心获取224×224的剪裁像素。
三、模型训练
根据本发明的一个实施例,参见图2,本发明提供一种针对图1的图像分类模型的训练方法,所述方法包括执行一次或者多次步骤:S1、S2、S3。为了更好地理解本发明,下面结合具体的实施例针对每一个步骤分别进行详细说明。
在步骤S1中,获取初始的训练集,其包括来自多个源域的样本及指示每个样本所属源域的域属性以及所属类别的标签,其中,所述样本为图像,各源域的标签空间相同,一种标签对应一个子域。
在步骤S2中,利用初始的训练集训练特征提取器提取样本特征,并基于预定的第一损失函数计算的损失值更新特征提取器的参数,所述第一损失函数被配置为基于域属性和标签惩罚不同源域的同一子域的样本对应的样本特征之间的差异。由此,以基于域属性对齐不同源域的同一子域的样本对应的样本特征。
为了更好地实现不同源域的子域对齐,可以针对性地利用初始的训练集中的域属性和标签,如果选定一个样本作为锚样本,比较锚样本与对应样本的域属性和标签可以划分不同类型样本,从而在子域对齐过程中更好地利用不同类型样本来提高对齐的效果。例如,以初始的训练集样本中的任意样本作为锚样本与锚样本的域属性不同但标签相同的样本为正样本(k1≠k2,其中k为域属性,y为标签)、与锚样本的域属性相同但标签不同的样本为第一类负样本 与锚样本的域属性和标签均不同的样本为第二类负样本
为了更好地理解所述第一损失函数,下面对其中涉及到的相关参数作出说明。在步骤S2中,每次训练从初始的训练集中取N个样本作为锚样本并取与每个锚样本的标签相同但域属性不同的样本作为该锚样本的正样本,形成批次大小为2N的训练批次。例如,本发明将初始的训练集中的样本平均地划分为若干组,每组均含有N个样本,以这N个样本作为锚样本分别对每一个样本选取与其构成正样本对(yi=yi+N,k≠ki+N)的样本,构成一个大小为的初始的训练集样本批次。在每个初始的训练集的样本批次中,每个样本i(其中i∈[1,N])对应的正样本为样本i+N(其中i+N∈[1+N,2N])。本发明对第一类负样本的样本增强权重,其中,通过样本的域属性以及标签类别计算两个样本构成的样本对所对应的加权值(也可称权重),加权值的集合构成了域加权值矩阵(也可称域加权权重矩阵)。为了获得域加权值矩阵,针对不同的样本对,可按照以下方式设置一些根据样本对中两个样本的域属性以及标签确定的值,以用于计算不同的样本对所对应的加权值:当样本i(其中i∈[1,N])和样本j(其中j∈[1,2N])构成正样本对时,Mpos(i,j)=1否则Mpos(i,j)=0,当两个样本i和j构成第一类负样本对时,Mneg+(i,j)=1,否则Mneg+(i,j)=0,当两个样本i和j构成第二类负样本对时,Mneg(i,j)=1,否则Mneg(i,j)=0,进一步地,
根据本发明的一个实施例,在步骤S2中,按照如下第一损失函数计算损失值:
其中,i表示训练批次中的一个锚样本的编号,i+N表示训练批次中编号为i的锚样本的正样本的编号,l(i,i+N)、l(i+N,i)表示子损失。训练时,基于第一损失函数计算的损失值更新特征提取器的参数,其目的是最小化第一损失函数计算的损失值,对应的公式如下:
根据本发明的一个实施例,子损失l(i,i+N)或者l(i+N,i)按照如下方式计算:
其中,a=i时b=i+N,a=i+N时b=i,exp(·)表示以自然数e为底的指数函数,sim(·)表示计算余弦相似度的函数,分别表示样本a、样本b、样本l对应的样本特征,样本l为训练批次中编号为l的样本,τ表示温度超参数,k表示样本对应的域属性,y表示样本对应的标签,w(a,l)表示为设定的加权值。温度超参数通常设置为小于1的数值,用于调整输出的样本特征的分布,增大相似样本对和不相似样本对的相似度差异以让分布变得更平滑,以此来放大类别相似性并提高模型的判别能力。当然,温度超参数不是必须的,可以设为1,相当于没有调整;又或者直接取消设置的温度超参数,后续涉及温度超参数的实施例类似,后面不再赘述。该实施例的技术方案至少能够实现以下有益技术效果:通过第一损失函数计算得到的损失值来更新特征提取器的参数,可以拉近潜空间(即样本特征对应的特征空间)中不同源域的同一子域的样本对应的样本特征之间的距离,推远不同子域的样本对应的样本特征之间的距离,以基于域属性和标签对齐不同源域的同一子域的样本对应的样本特征,使特征提取器能够更好地学习到多源域的初始的训练集中为表征样本分类相同与否更本质的特征。
根据本发明的一个实施例,加权值可以采用实施者预先设定的值,比如:分别设定样本l是样本a的正样本、第一类负样本、第二类负样本时的加权值。由于训练过程中通常是采用小批量梯度下降法,将初始的训练集划分为多个批次,以批次的形式训练所述图像分类模型并更新所述图像分类模型的参数。为了更好地保障子域对齐的效果,可根据一个批次中不同类型的样本的实际情况来动态地确定w(a,l),根据本发明的一个实施例,w(a,l)按照以下方式确定:
其中,Mneg(a,l)、Mneg+(a,l)表示根据样本a和样本l的域属性以及标签确定的值,其中,样本a和样本l的域属性相同但标签不同时Mneg+(a,l)=1,否则Mneg+(a,l)=0;样本a和样本l的域属性和标签均不同时Mneg(a,l)=1,否则Mneg(a,l)=0;表示当前训练批次中和样本a的域属性不同但标签相同的样本的数量,表示当前训练批次中和样本a的域属性相同但标签不同的样本的数量。该实施例的技术方案至少能够实现以下有益技术效果:通过计算w(a,l)可以增加一个批次中与每个样本a域属性相同但标签不同的样本的权重,提取到更好的去域特性的样本特征。需要说明的是,w(a,l)中的a和l分别对应上文的中的i和j;Mneg(a,l)对应于上文的Mneg(i,j),Mne(a,l)对应于上文的Mneg+(i,j)。
如果实施者有想要使用的确定的主干网络,可以直接构造图像分类模型并利用主干网络输出的样本特征直接用于第一损失函数计算损失值,例如参见图3,本发明通过将初始的训练集中的样本x通过特征提取器生成样本对应的样本特征z,并基于预定的第一损失函数计算的损失值更新特征提取器的参数,以基于域属性对齐不同源域的同一子域的样本对应的样本特征。在图3涉及到的参数中,表示初始的训练集中的一个锚样本,表示该锚样本经过特征提取器生成的样本特征,表示该锚样本的负样本表示该负样本经过特征提取器生成的样本特征,表示该锚样本的正样本(yj1=yj2,k1≠k2),表示该正样本经过特征提取器生成的样本特征,j表示该样本的编号,k表示样本的域属性,y表示样本的标签。
由于可以采用现有的不同主干网络,如果实施者想要从多个主干网络中选取一个最优的主干网络来构造模型,不同的主干网络输出的样本特征的维度可能不一致,不便于比较,因此,可以设计一个投影头(本发明所述图像分类模型在加入投影头后的结构如图4所示),利用投影头将主干网络提取到的样本特征转换到预定的特征维度,得到投影头转换后的样本特征。这样可以便于不同主干网络进行比较。根据本发明的一个实施例,参见图5,在加入投影头以后,第一损失函数中计算样本特征对应的余弦相似度采用的是经过投影头转换后的样本特征。所述投影头可以采用线性层(比如单层的全连接层,或者一个单隐层的多层感知机)实现,在步骤S2中训练特征提取器时,本发明通过将初始的训练集中的样本x先通过主干网络生成对应的样本特征z′,然后再通过投影头进行降维得到投影头转换后的样本特征z(此情况下,投影头转换后的样本特征z为特征提取最终的输出),以利用第一损失函数基于投影头转换后的样本特征计算的损失值更新主干网络以及投影头的参数。在图5涉及到的参数中,z′表示通过主干网络生成对应的样本特征,z表示通过投影头进行降维得到的投影头转换后的样本特征z,其余的参数说明与图3相同,此处不再进行赘述。
通过以上步骤最小化所述第一损失函数来更新所述特征提取器的参数以及投影头的参数以使当前批次中多源域样本的子域对齐(即基于域属性对齐不同源域的同一子域的样本对应的样本特征),使得同一标签的样本特征在潜空间中的分布尽可能一致,并使得子域的凸包直径缩小,不同子域之间的凸包中心距离增大。需要说明的是,多源域的凸包是指几何角度下,潜空间中包含所有初始的训练集样本对应的样本特征的所有凸集的交集,子域凸包是指标签相同的样本对应的样本特征的所有凸集的交集。对于每个训练批次,在基于域属性对齐不同源域的同一子域的样本对应的样本特征这一前提下进行后续的针对特征提取器和分类器的对抗训练过程。
在步骤S3中,对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,利用样本及其增强样本对特征提取器和两个分类器进行对抗训练。
根据本发明的一个实施例,本发明通过对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,利用样本及其增强样本对特征提取器和分类器进行对抗训练。为了确保数据集新颖的风格满足多样性和合理性,使增强样本构成的增强域能够近似表征目标域,因此要求增强域在语义不变的前提下尽可能多样,同时为了保证风格逼真,不偏离初始的训练集真实源风格的分布,本发明在步骤S3的对抗训练的过程中对每个样本均随机采用所述多种图像增强方式中的至少一种方式进行增强,其中,所述多种图像增强方式包括:左右翻转、颜色失真、高斯模糊和日光化或者其组合。例如,假设一个样本为x,本发明可基于SimCLR框架定义一个变换M来实现图像增强,得到增强样本:
根据本发明的一个实施例,所述对抗训练可进行一轮或者多个轮(Epoch),每轮训练中可以将初始的训练集分为多个批次(Batch),对每一个批次的样本进行图像增强获得增强样本。例如,每个批次的大小为2N(就是步骤S2中的一个批次的大小)。而且,优选的,对每个样本每次都是随机采用所述多种图像增强方式中的至少一种方式进行增强,由此,在不同的轮(Epoch)中,同一样本的增强样本可能采用了不一样的图像增强方式,以让图像分类模型在对抗训练中见到与源域的样本风格不一的增强样本,从而进一步提高领域泛化性,提升图像分类模型在潜在的目标域的性能。
通过上述步骤获得每个批次的样本的增强样本后,首先直接获取当前批次的初始的训练集中的样本在步骤S2中计算得到的第一损失值,再将当前批次的样本及其增强样本输入特征提取器进行特征提取生成样本特征,再将样本特征分别输入两个分类器中以输出预测的分类结果。并基于预定的第二损失函数计算的损失更新两个分类器的参数,以及基于预定的第三损失函数计算的损失更新特征提取器的参数。下面对第二损失函数与第三损失函数进行说明。
根据本发明的一个实施例,第二损失函数计算的损失被配置为与两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和负相关,用于惩罚每个分类器分类的偏差以及奖励两个分类器对增强样本在每个类别的置信概率差异之和。该实施例的技术方案至少能够实现以下有益技术效果:本发明的第二损失函数计算的损失被配置为与两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关,这样可以让两个分类器学习到正确分类的知识,同时,本发明的第二损失函数计算的损失被配置与两个分类器对增强样本在每个类别的置信概率差异之和负相关,相当于让两个分类器在正确分类的前提下,奖励二者在各个类别的置信概率上的差异,由此让两个分类器均学习到更泛化的分类知识,这比含有单个分类器的图像分类模型更优,有助于提高图像分类模型的性能。
根据本发明的一个实施例,第二损失函数如下:
其中,表示两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差的均值(本发明中也将其称为样本及其增强样本在两个分类器上的交叉熵损失期望),表示两个分类器对增强样本在每个类别的置信概率差异之和的均值(本发明也将其称为L1分类差异距离),β表示为预设的权重。优选的,在本发明中β预设为0.5。
应当理解,更新分类器的参数的目的是最小化第二损失函数,对应的公式如下:
根据本发明的一个实施例,第三损失函数计算的损失被配置为与第一损失函数正相关、与每个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和正相关,用于惩罚每个分类器分类的偏差以及惩罚两个分类器对增强样本在每个类别的置信概率差异之和。需要说明的是,第三损失函数只用于更新特征提取器的参数而不更新分类器的参数,通过最小化所述第三损失函数,来逼迫特征提取器能够生成泛化能力更强的特征(即能够使增强样本在两个分类器上的分类结果差异尽可能小),由此以完成特征提取器与两个分类器的一次对抗训练。
根据本发明的一个实施例,第三损失函数如下:
其中,表示第一损失函数, 表示两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差的均值,表示两个分类器对增强样本在每个类别的置信概率差异之和的均值,α表示为预设的权重,μ表示为预设的权重,其中μ∈[0,1]。优选的,在本发明中α预设为1,μ预设为0.4。
应当理解,步骤S3中,更新特征提取器的参数的目的是最小化第三损失函数,对应的公式如下:
其中,2N表示训练时采用的样本或者增强样本的数量,Lce(·)表示分类器输出的置信概率与对应标签之间的交叉熵损失,Gi(·)表示第i个分类器根据样本特征进行分类的置信概率,i=1时表示第1个分类器,i=2时表示第2个分类器,F(·)表示特征提取器提取的样本特征,xm表示样本m,表示样本m对应的增强样本,ym表示样本m对应的标签。
在上述计算的过程中,每次将一个样本以及其对应的增强样本分别输入到两个分类器G1和G2上,计算所述交叉熵损失期望,并直至将多源域样本中的每一个样本与其对应的增强样本完成计算,得到该样本批次整体的交叉熵损失期望。
其中,||·||1表示L1距离,M表示训练时采用的增强样本的数量,C表示类别的总数,表示分类器G1预测增强样本属于类别c的置信概率,表示分类器G2预测增强样本属于类别c的置信概率。应当说明,若是小批量训练,此处的训练时是指一个批次的训练。
在上述计算的过程中,每次将一个增强样本分别输入到两个分类器G1和G2上,计算所述L1分类差异损失(即增强样本在两个分类器上的概率输出的绝对值差值),并直至将增强样本中的每一个样本完成计算,得到该样本批次增强样本整体的L1分类差异损失(也可称为L1分类差异距离)。由于两个分类器的结构相似,很难在增强样本上产生不同的分类结果,因此需要通过增大两个分类器对增强样本在每个类别的置信概率差异之和以使增强样本能够在两个分类器上产生不同的分类结果,便于分类器各自能够学习到增强样本的不同特征。
需要说明的是,本发明中所说的参数是指通过反向传播算法更新特征提取器F和/或分类器G1,G2中的可训练参数(一些文献也称权重参数),目前神经网络中有多种参数更新方法,例如SGD、Momentum、Adam,本发明优选采用Adam优化器进行神经网络参数的优化。
下面以图6来更清晰地展示本发明的训练过程,其中,A类和B类表示不同的类别(分别对应不同的子域,用不同的灰色深度来区分,B类对应的深度相比A类更深),其中以DS表示源域,以DS1、DS2、DS3来区别表示不同的源域,以DAug表示增强样本对应的域(简称增强域)。图6a-b以一个子域的样本特征对应的分布在训练过程中变化过程来说明子域对齐,训练时,先按照本发明的步骤S1(图6未示出)获取初始的训练集,此时该子域的样本特征的分布对应于图6a所示初始状态,可以看到此时多个源域在该子域的样本特征整体的数据分布直径(后面简称凸包直径)较为分散,凸包直径较大,且不同源域的同一子域对应的样本特征对应的数据分布也相距较远;然后按照步骤S2的方式对特征提取器的参数进行更新,通过第一损失函数基于域属性和标签惩罚不同源域的同一子域的样本对应的样本特征之间的差异,由此拉近不同源域的同一子域的样本对应的样本特征的数据分布的距离,以对齐不同源域的同一子域的样本对应的样本特征,缩小该子域的凸包直径,得到对应于图6b的状态;再按照步骤S3的方式利用样本及其增强样本进行对抗训练,以让模型见到不属于任意源域的增强样本,同时拉近增强样本的样本特征与多个源域的样本对应的样本特征之间的数据分布的距离,得到对应于图6c的状态。另外,为了展现经对抗训练后,增强样本组成的增强域对应的不同子域的样本特征的分布变化,给出图6d,其中展示出了增强样本组成的增强域对应于A类的子域以及增强样本组成的增强域对应于B类的子域在经过步骤S3后的变化,可以看出,经过对抗训练后,能够拉近增强域的子域与源域的对应子域的样本特征的数据分布的距离,同时,增强域对应的不同子域之间的样本特征的数据分布之间的距离也被增大。
应当理解,本发明的前述实施例虽然给出了第一损失函数、第二损失函数、第三损失函数的一个具体实施细节,但是并非唯一的实现手段,本领域技术人员可以根据本发明的原理,采用能够实现相同技术效果的替代手段。例如:对于第一损失函数,也可采用传统的三元组损失函数,将样本组成包括锚样本、正样本、负样本的三元组,形如:{锚样本,正样本,负样本},正样本是与锚样本所属源域不同但属于同一子域的样本,负样本是与锚样本所属源域相同或者不同但属于不同子域的样本;步骤S2中,以三元组损失函数计算的损失值更新特征提取器的参数。对于第二损失函数,其用于惩罚每个分类器分类的偏差以及奖励两个分类器对增强样本在每个类别的置信概率差异之和,因此,在保证该前提的条件下,有多种可行的替代实施方式,比如将第二损失函数改为:当然,还可以修改第二损失函数中的一些子损失的系数,比如:将中的改为同样的,第三损失函数用于惩罚每个分类器分类的偏差以及惩罚两个分类器对增强样本在每个类别的置信概率差异之和,因此,在保证该前提的条件下,也有多种可行的替代实施方式,比如将第三损失函数改为:当然,还可以修改第三损失函数中的一些子损失的系数,比如:将中的改为或者,对于第一损失函数或者第三损失函数,还可将权重w(a,l)中的改为
四、应用场景
根据本发明的一个实施例,本发明还提供一种分类方法,所述方法包括:获取按照所述图像分类模型的训练方法训练所得到的图像分类模型中的特征提取器以及两个分类器中的一个分类器组成预测模型;获取待预测的样本,利用所述预测模型对待预测样本所属的分类进行预测,得到对应的分类结果。
为了验证本发明的效果,发明人还进行了实验。由于采集多个源域、未知的目标域的样本的成本较高,因此,在实验时,发明人从PACS数据集中选择一个域作为未知的目标域,而将其他域作为源域。例如,可以将手绘域作为未知的目标域,以艺术图像域、卡通域、照片域设置为多源域。从原始的PACS数据集移除手绘域的样本,得到修改后的PACS数据集;从修改后的PACS数据集中划分得到初始的训练集以及测试集。初始的训练集用于图像分类模型的训练,测试集用于对训练后的图像分类模型的精度进行测试。在通过所述图像分类模型对实际未知的目标域的待预测数据时进行分类预测之前,可先通过采用测试集测试两个分类器的分类结果,当两个分类器的分类结果相同时可选择其中任意一个分类器对未知的目标域的待预测属于进行分类预测,当两个分类器的分类结果不同时则选择分类结果相对更准确的分类器对未知的目标域的待预测属于进行分类预测。最终采用背景技术提到的5种现有方法与本发明的方案进行对比,得到的分类精度结果如表1所示,可以看出,本发明所提供的方案与现有技术相比,以每种源域作为未知的目标域进行分类的分类精度以及平均分类精度均比现有方法更高,相比现有方法能够实现更准确的分类效果。
表1
注:
方法1:一种最小化源域经验损失方法[1]
方法2:一种使用域鉴别器的对抗源域对齐方法[2]
方法3:一种最大分类器差异源域对齐方法[3]
方法4:一种面向域泛化的深度域对抗图像生成方法[4]
方法5:一种面向域泛化的特征风格化和域感知对比学习方法[5]
需要说明的是,虽然上文按照特定顺序描述了各个步骤,但是并不意味着必须按照上述特定顺序来执行各个步骤,实际上,这些步骤中的一些可以并发执行,甚至改变顺序,只要能够实现所需要的功能即可。
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
计算机可读存储介质可以是保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以包括但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、静态随机存取存储器(SRAM)、便携式压缩盘只读存储器(CD-ROM)、数字多功能盘(DVD)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
Claims (15)
1.一种图像分类模型的训练方法,其特征在于,所述图像分类模型包括特征提取器以及两个分类器,其通过特征提取器提取样本对应的样本特征并通过分类器根据样本特征对样本进行分类,所述方法包括:
S1、获取初始的训练集,其包括来自多个源域的样本及指示每个样本所属源域的域属性以及所属类别的标签,其中,所述样本为图像,各源域的标签空间相同,一种标签对应一个子域;
S2、利用初始的训练集训练特征提取器提取样本特征,并基于第一损失函数计算的损失值更新特征提取器的参数,所述第一损失函数被配置为基于域属性和标签惩罚不同源域的同一子域的样本对应的样本特征之间的差异;
S3、对初始的训练集中的样本进行增强处理以添加所述多个源域之外的增强样本,利用样本及其增强样本对特征提取器和两个分类器进行对抗训练。
2.根据权利要求1所述的方法,其特征在于,所述步骤S3包括:
在对抗训练时,基于第二损失函数计算的损失更新两个分类器的参数,其中,第二损失函数被配置为惩罚每个分类器分类的偏差以及奖励两个分类器对增强样本在每个类别的置信概率差异之和;以及
在对抗训练时,基于第三损失函数计算的损失更新特征提取器的参数,其中,第三损失函数被配置为惩罚每个分类器分类的偏差以及惩罚两个分类器对增强样本在每个类别的置信概率差异之和。
3.根据权利要求2所述的方法,其特征在于,第二损失函数计算的损失被配置为与两个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和负相关。
5.根据权利要求2所述的方法,其特征在于,第三损失函数计算的损失被配置为与第一损失函数正相关、与每个分类器对样本及其增强样本在每个类别的置信概率分别与对应标签之间的偏差正相关以及与两个分类器对增强样本在每个类别的置信概率差异之和正相关。
9.根据权利要求1-6任一项所述的方法,其特征在于,在步骤S2中,每次训练从初始的训练集中取N个样本作为锚样本并取与每个锚样本的标签相同但域属性不同的样本作为该锚样本的正样本,形成批次大小为2N的训练批次,并按照如下第一损失函数计算损失值:
11.根据权利要求1-6任一项所述的方法,其特征在于,步骤S3中对图像分类模型进行多轮对抗训练并在每一轮中对每个样本均随机采用所述多种图像增强方式中的至少一种方式进行增强,所述多种图像增强方式包括:左右翻转、颜色失真、高斯模糊和日光化或者其组合。
12.根据权利要求1-6任一项所述的方法,其特征在于,所述多源域包括艺术图像域、卡通域、照片域和手绘域或者其组合,样本对应的标签指示该图像所含有的对象的类别。
13.一种图像分类方法,其特征在于,所述方法包括:
获取待分类的图像样本;
利用图像分类模型对待分类的图像样本所属的分类进行预测,其中所述图像分类模型由根据权利要求1-12的训练方法所得到的特征提取器以及两个分类器中的一个分类器组成。
14.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序可被处理器执行以实现权利要求1-12以及13中任一项所述方法的步骤。
15.一种电子设备,其特征在于,包括:
一个或多个处理器;
存储装置,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述电子设备实现如权利要求1-12以及13中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211315590.8A CN115690534A (zh) | 2022-10-26 | 2022-10-26 | 一种基于迁移学习的图像分类模型的训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211315590.8A CN115690534A (zh) | 2022-10-26 | 2022-10-26 | 一种基于迁移学习的图像分类模型的训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115690534A true CN115690534A (zh) | 2023-02-03 |
Family
ID=85098787
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211315590.8A Pending CN115690534A (zh) | 2022-10-26 | 2022-10-26 | 一种基于迁移学习的图像分类模型的训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115690534A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116246349A (zh) * | 2023-05-06 | 2023-06-09 | 山东科技大学 | 一种基于渐进式子域挖掘的单源域领域泛化步态识别方法 |
CN117315487A (zh) * | 2023-11-02 | 2023-12-29 | 北京林业大学 | 一种基于深度迁移学习的野生动物图像跨域识别方法 |
CN117407698A (zh) * | 2023-12-14 | 2024-01-16 | 青岛明思为科技有限公司 | 一种混合距离引导的领域自适应故障诊断方法 |
-
2022
- 2022-10-26 CN CN202211315590.8A patent/CN115690534A/zh active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116246349A (zh) * | 2023-05-06 | 2023-06-09 | 山东科技大学 | 一种基于渐进式子域挖掘的单源域领域泛化步态识别方法 |
CN116246349B (zh) * | 2023-05-06 | 2023-08-15 | 山东科技大学 | 一种基于渐进式子域挖掘的单源域领域泛化步态识别方法 |
CN117315487A (zh) * | 2023-11-02 | 2023-12-29 | 北京林业大学 | 一种基于深度迁移学习的野生动物图像跨域识别方法 |
CN117407698A (zh) * | 2023-12-14 | 2024-01-16 | 青岛明思为科技有限公司 | 一种混合距离引导的领域自适应故障诊断方法 |
CN117407698B (zh) * | 2023-12-14 | 2024-03-08 | 青岛明思为科技有限公司 | 一种混合距离引导的领域自适应故障诊断方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Gao et al. | Deep leaf‐bootstrapping generative adversarial network for structural image data augmentation | |
CN110837836B (zh) | 基于最大化置信度的半监督语义分割方法 | |
CN115690534A (zh) | 一种基于迁移学习的图像分类模型的训练方法 | |
US20200134455A1 (en) | Apparatus and method for training deep learning model | |
CN113657561B (zh) | 一种基于多任务解耦学习的半监督夜间图像分类方法 | |
CN107480144A (zh) | 具备跨语言学习能力的图像自然语言描述生成方法和装置 | |
CN111008639B (zh) | 一种基于注意力机制的车牌字符识别方法 | |
CN113139664B (zh) | 一种跨模态的迁移学习方法 | |
CN110826609B (zh) | 一种基于强化学习的双流特征融合图像识别方法 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN114387366A (zh) | 一种感知联合空间注意力文本生成图像方法 | |
CN112560710B (zh) | 一种用于构建指静脉识别系统的方法及指静脉识别系统 | |
CN113392967A (zh) | 领域对抗神经网络的训练方法 | |
CN114842343A (zh) | 一种基于ViT的航空图像识别方法 | |
CN114998602A (zh) | 基于低置信度样本对比损失的域适应学习方法及系统 | |
CN114722892A (zh) | 基于机器学习的持续学习方法及装置 | |
CN115563327A (zh) | 基于Transformer网络选择性蒸馏的零样本跨模态检索方法 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN114187493A (zh) | 一种基于生成对抗网络的零样本学习算法 | |
CN115270752A (zh) | 一种基于多层次对比学习的模板句评估方法 | |
CN113160032A (zh) | 一种基于生成对抗网络的无监督多模态图像转换方法 | |
CN110717402B (zh) | 一种基于层级优化度量学习的行人再识别方法 | |
CN111382871A (zh) | 基于数据扩充一致性的领域泛化和领域自适应学习方法 | |
CN116232699A (zh) | 细粒度网络入侵检测模型的训练方法和网络入侵检测方法 | |
CN116997908A (zh) | 用于分类类型任务的连续学习神经网络系统训练 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |