CN117115547A - 基于自监督学习与自训练机制的跨域长尾图像分类方法 - Google Patents
基于自监督学习与自训练机制的跨域长尾图像分类方法 Download PDFInfo
- Publication number
- CN117115547A CN117115547A CN202311137191.1A CN202311137191A CN117115547A CN 117115547 A CN117115547 A CN 117115547A CN 202311137191 A CN202311137191 A CN 202311137191A CN 117115547 A CN117115547 A CN 117115547A
- Authority
- CN
- China
- Prior art keywords
- self
- training
- domain
- learning
- long
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 103
- 238000000034 method Methods 0.000 title claims abstract description 90
- 230000007246 mechanism Effects 0.000 title claims abstract description 20
- 238000009826 distribution Methods 0.000 claims abstract description 63
- 238000012512 characterization method Methods 0.000 claims abstract description 23
- 238000005457 optimization Methods 0.000 claims abstract description 22
- 230000008569 process Effects 0.000 claims abstract description 12
- 238000012360 testing method Methods 0.000 claims abstract description 7
- 230000006870 function Effects 0.000 claims description 23
- 230000003044 adaptive effect Effects 0.000 claims description 16
- 238000007781 pre-processing Methods 0.000 claims description 12
- 238000005070 sampling Methods 0.000 claims description 11
- 238000012952 Resampling Methods 0.000 claims description 9
- 230000003416 augmentation Effects 0.000 claims description 7
- 230000010354 integration Effects 0.000 claims description 7
- 238000013507 mapping Methods 0.000 claims description 6
- 238000013459 approach Methods 0.000 claims description 5
- 238000013434 data augmentation Methods 0.000 claims description 5
- 238000010606 normalization Methods 0.000 claims description 5
- 238000012545 processing Methods 0.000 claims description 4
- 230000002708 enhancing effect Effects 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 claims description 2
- 239000000284 extract Substances 0.000 claims description 2
- 238000000605 extraction Methods 0.000 abstract description 8
- 238000009825 accumulation Methods 0.000 abstract description 4
- 238000010586 diagram Methods 0.000 description 7
- 238000012800 visualization Methods 0.000 description 5
- 230000006978 adaptation Effects 0.000 description 4
- 239000011159 matrix material Substances 0.000 description 4
- 238000002474 experimental method Methods 0.000 description 3
- 238000013508 migration Methods 0.000 description 3
- 230000005012 migration Effects 0.000 description 3
- 238000002679 ablation Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 239000003245 coal Substances 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 230000004069 differentiation Effects 0.000 description 1
- 201000010099 disease Diseases 0.000 description 1
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- IOJNPSPGHUEJAQ-UHFFFAOYSA-N n,n-dimethyl-4-(pyridin-2-yldiazenyl)aniline Chemical compound C1=CC(N(C)C)=CC=C1N=NC1=CC=CC=N1 IOJNPSPGHUEJAQ-UHFFFAOYSA-N 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于自监督学习与自训练机制的跨域长尾图像分类方法,包括:首先获取跨域长尾图像的数据集,并将数据集划分训练集以及测试集;然后采用训练集中的跨域长尾图像数据对双分支网络模型进行训练:对于非监督式自适应分支,利用自监督学习与全局分布对齐方法进行表征学习;对于监督式自适应分支,对分类器进行自训练和互信息最大化;利用动态加权集成训练策略自动调整两个分支的学习权重,从而实现两个分支学习目标的同时优化;最后采用训练后的模型对测试集中的跨域长尾图像数据进行分类。本发明解决了现有分类方法由于图像判别性特征提取不充分、模型决策边界偏移、自训练过程的错误积累等容易造成跨域长尾图像分类错误的问题。
Description
技术领域
本发明属于计算机视觉技术领域,特别是涉及一种基于自监督学习与自训练机制的跨域长尾图像分类方法。
背景技术
在迁移学习中,通常将具有大量带标签信息的数据称为源域,而将待解决的未标注数据称为目标域。领域自适应作为一种具有代表性的迁移学习技术,通过将具有标签的源域和未标记的目标域映射到一个共享特征空间中来消除数据间的分布差异,使得源域训练的分类器能够很好地适用于目标任务。传统的领域自适应方法需假设源域和目标域之间的标签分布相同,然而在实际应用中,源域和目标域之间不仅存在概率分布差异,其标签分布偏移现象也普遍存在,其原因在于:在不同的领域中,获取标签的成本往往不同,例如在不同的地区,高发疾病的类型和数量会有所差异,从而导致不同地区的病患类别数量出现偏差。传统领域自适应方法中,由于仅考虑了源域和目标域之间标签分布相同的情况,导致了传统方法在实际的跨域长尾图像分类任务上失效,甚至出现严重的负迁移现象。
由于深度领域自适应方法的学习目标可以分为两个主要方面,即分类损失和领域差异损失。前者利用源域的监督信息来训练分类器,例如通过最小化交叉熵损失;后者通过减少源域和目标域之间的差异来实现分布对齐,例如,在分布差异的度量方法方面,最常用的度量分布差异的方法是最大均值差异(MMD),其利用核技巧将原始特征映射到再生核希尔伯特空间,然后计算不同域的平均嵌入距离作为域之间的差异度量。因此,不平衡的标签分布对于跨域长尾图像分类同样会带来两个方面的影响:
对于分类器模型训练而言,由于只有源域具有监督信息,标签的质量将直接影响其分类性能。与平衡类别的情况不同,不平衡的标签分布会造成分类器的决策边界偏向于多数类。虽然再平衡策略(例如重采样和重加权)为不平衡学习提供了解决方案,但再平衡策略会由于改变数据的分布而破坏其内在的特征表示,从而将学习过程分解为表示学习和分类器学习。此外,考虑到自监督学习在无监督表示学习方面的优越性,该类型方法在不平衡学习中的有效性已经得到证明。因此,在确保特征表示质量的同时,如何构建源域和目标域的共享特征空间成为跨领域不平衡学习的关键。
对于分布对齐而言,在标签分布偏移的情况下直接对领域之间的差异进行减小并不合适,原因在于标签分布的不一致性会对分布对齐造成影响,导致传统的特征对齐方法失效。尤其是对于需要借助标签信息来实现同一类别之间分布对齐的子域自适应方法,这些方法试图通过利用具有噪声的目标域伪标签进行自训练或条件熵最小化来解决领域分布偏移问题。然而,由于类别的不平衡会直接导致分类器的性能降低,尤其是在初始训练阶段,早期的伪标签分类错误将被迭代累积,这将对领域自适应的整个过程将产生灾难性的影响,导致最终分类错误。
发明内容
本发明实施例的目的在于提供一种基于自监督学习与自训练机制的跨域长尾图像分类方法,以实现跨域长尾图像分类的精确分类,解决了现有分类方法由于图像判别性特征提取不充分、模型决策边界偏移、自训练过程的错误积累等容易造成跨域长尾图像分类错误的问题。
为解决上述技术问题,本发明所采用的技术方案是一种基于自监督学习与自训练机制的跨域长尾图像分类方法,包括以下步骤:
步骤S1、获取跨域长尾图像的数据集,所述数据集包括源域和目标域长尾图像数据,并将数据集划分训练集以及测试集;
步骤S2、采用训练集中的跨域长尾图像数据对双分支网络模型进行训练,包括以下步骤;
步骤S21、对于双分支网络模型的非监督式自适应分支,利用自监督学习与全局分布对齐方法进行表征学习;
步骤S22、对于双分支网络模型的监督式自适应分支,对分类器进行自训练和互信息最大化,进一步提升分类器在目标任务上的分类能力;
步骤S23、利用动态加权集成训练策略自动调整两个分支的学习权重,从而实现两个分支学习目标的同时优化;
步骤S3、采用训练后的模型对测试集中的跨域长尾图像数据进行分类。
进一步地,所述步骤S21包括以下步骤:
步骤S211、对输入的训练集数据进行预处理;
步骤S212、采用全局分布对齐方法最小化源域和目标域长尾图像数据之间的分布距离;
步骤S213、采用自监督表征学习长尾图像数据的内在特征;
步骤S214、调整全局分布对齐与自监督表征学习之间的比例,得到非监督式自适应分支的优化目标,并求得最优的共享特征表示。
进一步地,所述步骤S211中数据预处理具体包括:调整跨域长尾图像的图像大小、归一化处理以及数据增广;
其中,数据增广指对输入的所有源域和目标域的长尾图像数据均采用强增强以及弱增强的方式进行增强;所述强增强包括随机颜色抖动、随机灰度、随机高斯模糊以及随机水平翻转;所述弱增强指随机水平翻转。
进一步地,所述步骤S212中全局分布对齐方法指使用MMD仅对图像预处理中弱增强后的源域和目标域的长尾图像数据之间的分布差异进行度量,并使用高斯核作为MMD中的核函数,并通过将分布距离最小化来训练特征提取器的网络参数,使特征提取器提取到源域和目标域的共享特征表示;其中,对于来自概率分布为Ps和Pt的样本,MMD分布距离表示如下:
式中,映射函数通过使用特征核函数k来替代,即/>从而将样本数量为n的源域Ds中的第i个样本/>以及样本数量为m的目标域Dt中的第j个样本/>通过该核函数k映射到再生核希尔伯特空间H中,其中xs、xt分别表示源域、目标域中的样本。
进一步地,所述步骤S213中自监督表征通过词典查找任务来提取跨域长尾图像数据的内在特征:当一个查询与键值是来自于同一张图片的不同增广时,则将该查询与键值进行匹配;其中查询采用强增强策略,键值则采用弱增强策略;
自监督表征过程中的自监督损失函数为:
其中,N是负例的数量,τ是温度参数,Q为查询,K+为键值,Kn表示第n个负例。
进一步地,所述步骤S214中非监督式自适应分支的优化目标表示为:
其中,Lcon为自监督损失函数,为概率分布为Ps和Pt的MMD分布距离,α为权衡参数。
进一步地,所述步骤S22具体为:
步骤S221、对具有真实标签的源域数据与具有伪标签的高置信度目标域的长尾图像数据进行预处理;
步骤S222、通过自训练方法训练分类器;
步骤S223、通过互信息最大化优化伪标签;
步骤S224、调整自训练损失与互信息之间的比例,得到监督式自适应分支的优化目标;
所述步骤S222具体为:
对于给定的目标域长尾图像样本假设其在每个类别上的输出概率中,最高的概率值为/>其次高的概率值为/>最高的概率值与次高的概率值之间的差值大于阈值,则将所述目标域样本视为高置信度样本,并将具有最高概率值的类别作为伪标签加入到分类器训练中;若不满足条件,则将该样本视为低置信度长尾图像样本,即将其视为空值Null,不参与分类器训练;
选择高置信度的目标域长尾图像样本后,使用重采样来构建用于分类器fcl训练的平衡训练集,其中包括带真实标签的源域长尾图像样本和带伪标签的高置信度目标域长尾图像样本;对于实例数为nc的类标c,重采样的采样概率Pc为:
其中,ηc表示第c个类的样本数占训练样本总数的倒数,计算方式为C表示为分类任务的总类标数,nl是训练集中第l个类的样本数;根据采样概率Pc,在训练过程中等量的采样每个类的样本来实现再平衡;
最终通过最小化分类损失Lcl来训练分类器:
Lcl=Lce(xtrain,ytrain)
其中,Lce是在平衡训练集的交叉熵损失,xtrain、ytrain分别表示带真实标签的源域图像样本和带伪标签的高置信度目标域图像样本。
进一步地,所述步骤S223中互信息最大化指通过计算目标域互信息损失,使目标域互信息损失最大化来增加目标域输出结果的多样性,从而提高伪标签的质量;其中,目标域互信息损失表示为:
其中,是目标域长尾图像样本/>通过Softmax之后得到的预测概率,/>是所有长尾图像样本在第c个类上的预测概率平均值,<·,·>是内积运算。
进一步地,所述步骤S224中监督式自适应分支的优化目标表示为:
Lloab=Lcl-βLmi
其中,Lcl为分类损失,Lmi为目标域互信息损失,β为权衡因子。
进一步地,所述步骤S23中动态加权集成训练策略指通过自动的控制表征学习和分类器学习的权重,让模型首先进行表征学习,然后逐渐关注到分类器学习,其中,动态加权集成训练策略的动态加权集成因子μ的计算方式为:
其中,iter是当前的迭代次数,而T是总的迭代次数,μ会随着迭代次数的增加逐渐趋近于0;通过以当前迭代次数iter为变量的反函数,使用μ作为损失函数的权重来动态的调整学习策略;
所述双分支网络模型的损失函数为:
Ldbdan=μLlfab+(1-μ)Lloab
其中,Llfab为非监督式自适应分支的优化目标,Lloab为监督式自适应分支的优化目标。
本发明的有益效果是:
(1)本发明的非监督式自适应分支的表征学习方法,解决了现有方法因标签分布差异而导致的特征提取器对少数类的特征提取不充分问题。通过利用自监督学习来构造一种简单的词典查找任务,从而获取数据的内在特征表示;并使用MMD来度量全局分布距离,并使该距离最小化来实现跨域长尾图像数据之间的特征对齐。
(2)本发明的监督式自适应分支的分类器学习方法,解决了现有方法因类别不平衡而导致的分类器决策边界偏移问题。通过采用自训练的方式,对高置信度目标域数据与源域数据进行重采样来构造平衡的训练集进行分类学习;并采用互信息最大化来保障目标域伪标签的质量,共同提升跨域长尾图像分类的精确性。
(3)本发明的动态加权集成训练策略,解决了现有自训练过程中由于引入目标域伪标签所带来的错误积累以及神经网络无法端到端训练问题。通过自动的调整表征学习与分类器学习的权重,在训练初期更关注于跨域长尾数据的表征学习,在训练后期更关注于分类学习,缓解了由于表征能力不足造成的自训练错误积累,同时,实现对不同表征学习与分类器学习的同时优化,端到端可训练。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单的介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例的基于自监督学习与自训练机制的跨域长尾图像分类方法的示意图。
图2是本发明实施例的不同方法在DomainNet数据集上的t-SNE特征可视化图,其中,(a)是ResNet50在DomainNet数据集上的t-SNE特征可视化图,(b)是DAN在DomainNet数据集上的t-SNE特征可视化图,(c)是JAN在DomainNet数据集上的t-SNE特征可视化图,(d)是本发明方法在DomainNet数据集上的t-SNE特征可视化图。
图3是本发明实施例的基于自监督学习与自训练机制的跨域长尾图像分类方法在DomainNet数据集上的混淆矩阵图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明公开了一种基于自监督学习与自训练机制的跨域长尾图像分类方法,如图1所示,构建了一种双分支网络结构模型将深度网络模型的训练过程拆分为表征学习和分类学习,首先,非监督式自适应分支利用自监督学习与全局分布对齐方法进行表征学习;其次,监督式分自适应分支对学习到的特征进行自训练和互信息最大化来进一步提升在目标任务上的分类能力;最后,设计了一种动态加权集成训练策略来自动调整不同表征学习与分类器学习的学习权重,从而实现两个表征学习与分类器学习目标的同时优化。
具体如下:
S1.非监督式自适应分支的表征学习方法
a.图像预处理:在本发明的实施例中,在非监督表征学习方法中,将输入的源域和目标域的长尾图像数据预处理可分为三步:1)调整跨域长尾图像的图像大小:对所有的输入的跨域长尾图像数据统一进行图像大小调整,并随机裁剪为224×224大小。2)归一化处理:对于输入的跨域长尾图像进行归一化处理,根据三个通道上的均值为{0.485,0.456,0.406},标准差为{0.229,0.224,0.225},将其每一个通道内的每一个像素减去均值,再除以标准差,得到归一化后的结果。3)数据增广:由于自监督学习方法需要利用数据的增广不变性去建立词典查找任务,对于该方法中输入的所有源域和目标域的长尾图像数据均采用了“强-弱”数据增强方式。其中,“强增强”方式采用了随机颜色抖动、随机灰度、随机高斯模糊以及随机水平翻转,从而提高学习特征表示的鲁棒性;“弱增强”方式仅采用了随机水平翻转,以保留更多的图片原始特征信息。由于本实施例采用的DomainNet数据集已预先划分了训练集与测试集,本实施例使用了原始的数据集划分,并将源域和目标域的训练集根据批大小为32进行特征提取器的网络参数训练。
b.全局分布对齐方法:在该步骤中,本实施例使用MMD仅对图像预处理中“弱增强”后的源域和目标域的长尾图像数据之间的分布差异进行度量,这是因为“强增强”会在一定程度上增加MMD的值,使得对于源域和目标域之间的分布差异度量不准确。作为一种非参数度量方法,在MMD的计算中不需要对分布进行任何先验假设,也并未涉及到任何的监督信息。虽然全局的分布差异度量在不平衡的情况下仍会有所偏差,但是相较于计算类内的分布差异而言,MMD受到标签偏移的影响较小,在初始情况下可以更好的进行分布差异度量。具体来说,MMD通过核空间中的均值嵌入在RKHS中以测量源域和目标域之间的分布差异。形式上,对于来自概率分布为Ps和Pt的样本,其MMD距离方式如式(1)所示:
其中,映射函数可以通过使用特征核函数k来替代,即/>其中xs、xt分别表示源域、目标域中的样本,从而将样本数量为n的源域Ds中的第i个样本/>以及样本数量为m的目标域Dt中的第j个样本/>通过该核函数k映射到再生核希尔伯特空间H中。在本实施例中,使用高斯核作为MMD中的核函数,并通过将分布距离/>最小化来训练特征提取器的网络参数,可以使得特征提取器提取到跨域长尾图像的源域和目标域的共享特征表示。
c.自监督表征学习方法:为了在最小化标签分布偏移的负面影响的情况下得到鲁棒的特征表示,本实施例利用自监督学习来建立一种非监督的特征提取方式,并遵循一种简单的词典查找任务来提取跨域长尾图像数据的内在特征,例如,若一个查询与键值是来自于同一张图片的不同增广,则将该查询与键值进行匹配。具体来说,对于给定的样本x∈{Ds∪Dt},目的是学习到它的内在表示Q作为一个查询,该查询可以从一组编码后的特征{K1,K2,K3,...,Kn}(即负例,其中Kn表示第n个负例)中将该样本经过增广后的键值K+(即正例)区分开来。因此,通过采用点积来衡量相似性,提出自监督损失函数,也称为InfoNCE,如式(2)所示:
其中,N是负例的数量,τ是温度参数,用于控制模型对负实例的区分度。
本实施例采用一个队列(大小为65536)来扩展负例的容量,从而有效增强了对比样本的多样性。此外,通过采用两个并行的编码器,查询编码器通过反向传播进行梯度更新,而键值编码器的参数则采用动量更新策略进行更新,以防止在初始阶段损失收敛的波动。本发明采用查询编码器作为深度领域自适应模型的特征提取器(主干网络)。将原始样本施加不同的增广策略用于查询和键,其中查询采用“强增强”策略,以提高学习表示的泛化能力,而键值则采用“弱增强”,以保留实例的原始特征信息。
d.非监督表征学习优化目标:通过在MMD和自监督损失之间添加一个权衡参数α,在本实施例中为0.1,用于调整全局分布对齐与自监督学习之间的比例,即重要性。因此,非监督式自适应分支的优化目标可以表示为式(3),并通过最小化该目标求得最优的共享特征表示:
S2.监督式自适应分支的分类器学习方法
a.图像预处理:通过非监督式表征学习方法对网络参数的“预热”,在一定程度上,增强了神经网络的特征提取能力,也在一定程度上间接增加了目标域伪标签的置信度。在分类器学习过程中,使用非监督表征学习方法中得到的特征提取器对输入数据进行特征提取。其中,输入的是具有真实标签的源域长尾图像数据与具有伪标签的高置信度目标域长尾图像数据,采用的调整图像大小、归一化处理方式与非监督式表征学习方法中的处理方式相同,不同点在于,该分类器学习方法中仅采用“弱增强”方式对输入的源域长尾图像数据和目标域长尾图像数据进行统一处理以进行后续的分类训练。
b.基于自训练的分类器学习方法:自训练方法通过将带有伪标签的高置信度目标域长尾图像样本加入模型训练,从而提升模型对于目标域的学习能力、提升分类器对目标域的预测置信度。因此,对于高置信度样本的选择尤为重要。考虑到在类别分布偏移条件下分类器的预测值很容易受到不平衡的影响,因此,对于目标域长尾图像样本的输出概率中,使用最高概率值与次高概率值的差值来定义该样本的置信度。对于给定的目标域样本/>假设其在每个类别上的输出概率中,最高的概率值为/>其次高的概率值为/>若满足式(4)条件,则将/>视为带有伪标签/>的高置信度样本。
其中,δ是阈值,c为特定的类标。若满足上述条件,即最高的概率值与次高的概率值之间的差值大于阈值,则将该样本视为高置信度样本,并将具有最高概率值的类别作为其伪标签加入到分类器训练中;若不满足条件,则将该样本视为低置信度样本,即将其视为空值Null,不参与分类器训练。在本发明的实验中,将阈值设置为0.9以确保分配到的伪标签的可靠性。
选择高置信度的目标域长尾图像样本后,对带真实标签的源域图像样本和带伪标签的高置信度目标域图像样本进行重采样,来构建用于分类器fcl训练的平衡训练集{xtrain,ytrain}。具体而言,调整每个类的采样概率,即实例数量越少的类别,其拥有的采样概率越高。记第c个类的样本数为nc,则第c个类的采样概率Pc可以通过式(5)计算:
其中,ηc表示第c个类的样本数占训练样本总数的倒数,其计算方式为C表示为该分类任务的总类标数,nl是训练集中第l个类的样本数。根据采样概率Pc,在训练过程中可以等量的采样每个类的样本数来实现再平衡。最终,通过最小化分类损失Lcl来训练分类器:
Lcl=Lce(xtrain,ytrain) (6)
其中,Lce是在平衡训练集上的交叉熵损失,xtrain、ytrain分别表示带真实标签的源域图像样本和带伪标签的高置信度目标域图像样本。
c.基于互信息最大化的伪标签优化方法:通常情况下,在选择高置信度目标域长尾图像样本时,预测的置信度分数越高,对目标领域中的样本进行正确分类的可能性就越大。然而,置信度过高容易出现平凡解的现象,即模型将过度的拟合到单一类别的信息上,缺失对于其他相似类别信息的学习,导致模型严重过拟合。因此,本实施例通过计算目标域互信息损失Lmi,并使其最大化来增加目标域输出结果的多样性,从而优化伪标签的质量,如式(7)所示。
其中,是目标域样本/>通过Softmax之后得到的预测概率,/>是所有样本在第c个类上的预测概率平均值,<·,·>是内积运算。
d.监督式分类器学习优化目标:通过在分类损失和互信息损失之间提供β作为权衡因子,在本发明中为0.1,用于调整自训练损失与互信息之间的比例,即重要性。监督式分类器学习方法的优化目标可以表示为式(8):
Lloab=Lcl-βLmi (8)
S3.动态加权集成训练策略
根据表征学习与分类器学习在网络的学习阶段所扮演的不同角色,该模型在训练分类器之前需要首先获得稳健的特征表示,以最小化标签分布偏移对模型造成的负面影响。因此,本实施例采用了一种动态加权集成训练策略来通过自动的控制表示学习和分类器学习的权重,让模型首先进行表征学习,然后逐渐关注到分类器学习。具体而言,动态加权集成因子μ的计算方式为:
其中,iter是当前的迭代次数,而T是总的迭代次数,μ会随着迭代次数的增加逐渐趋近于0。
通过以当前迭代次数iter为变量的反函数,使用μ作为损失函数的权重来动态的调整学习策略。因此,该模型的损失函数可以表示为式(10),本实施例使用带动量为0.9的小批量SGD对整个模型参数进行优化,得到最优的模型参数,来实现跨域长尾图像的共享特征提取与精确分类:
Ldbdan=μLlfab+(1-μ)Lloab (10)
本实施例选择了一些前沿的深度领域自适应方法对本发明的有效性进行验证。考虑到在类不平衡情况下,整体平均精度不适合评估分类性能,本实施例使用在每个类别上的平均准确率来进行模型性能验证。
本实施例将现有相关方法在具有标签分布偏移的基准数据集DomainNet、OfficeHome的跨域长尾图像分类任务的实验结果分别进行了展示,如表1和表2。总体来说,本发明的方法在大部分的任务上均取得了良好的分类效果,其中,在DomainNe上取得了微弱的领先,在所有的12个任务中,本发明取得了9个任务上的性能领先。而在OfficeHome中,本发明赢得了所有的迁移学习任务,并领先相关领域的先进算法超过2%。这些良好的实验结果都证明了本发明的方法对于标签分布偏移问题的优异性能,并表明本发明在这种特殊的数据分布情况下依然能够学习到更多的可迁移信息。
表1标签分布偏移数据集DomainNet的类平均正确率
表2标签分布偏移数据集OfficeHome的类平均正确率
Method | Rw-Pr | Rw-Cl | Pr-Rw | Pr-Cl | Cl-Rw | Cl-Pr | AVG |
source | 70.74 | 44.24 | 67.33 | 38.68 | 53.51 | 51.85 | 54.39 |
BSP | 72.80 | 23.82 | 66.19 | 20.05 | 32.59 | 30.36 | 40.97 |
PADA | 60.77 | 32.28 | 57.09 | 26.76 | 40.71 | 38.34 | 42.66 |
BBSE | 61.10 | 33.27 | 62.66 | 31.15 | 39.70 | 38.08 | 44.33 |
MCD | 66.03 | 33.17 | 62.95 | 29.99 | 44.47 | 39.01 | 45.94 |
DAN | 69.35 | 40.84 | 66.93 | 34.66 | 53.55 | 52.09 | 52.90 |
F-DANN | 68.56 | 40.57 | 67.32 | 37.33 | 55.84 | 53.67 | 53.88 |
JAN | 67.20 | 43.60 | 68.87 | 39.21 | 57.98 | 48.57 | 54.24 |
DANN | 71.62 | 46.51 | 68.40 | 38.07 | 58.83 | 58.05 | 56.91 |
MDD | 71.21 | 44.78 | 69.31 | 42.56 | 52.10 | 52.70 | 55.44 |
COAL | 73.65 | 42.58 | 73.26 | 40.61 | 59.22 | 57.33 | 58.40 |
InstaPBM | 75.56 | 42.93 | 70.30 | 39.32 | 61.87 | 63.40 | 58.90 |
MDD+I.A | 76.08 | 50.04 | 74.21 | 45.38 | 61.15 | 63.15 | 61.67 |
SENTRY | 76.12 | 56.80 | 73.60 | 54.75 | 65.94 | 64.29 | 65.25 |
本发明-w/o.μ | 75.56 | 52.31 | 73.45 | 49.52 | 63.68 | 63.84 | 63.06 |
本发明 | 79.34 | 57.57 | 75.86 | 54.95 | 67.59 | 67.71 | 67.17 |
为了证明本发明中不同模块的贡献程度,本实施例对模型进行了分解,并在OfficeHome上对6个跨域不平衡学习任务进行了消融实验。本实施例以经过类别重采样策略的交叉熵损失为基线,并将其与不同组合的实验结果(如领域差异损失、对比损失、互信息损失)进行了比较,如表3所示。可以观察到,在不平衡的情况下,类别重采样是有利于分类器的训练,平均性能提高了近4%。相较之下,领域差异损失和互信息损失对整体适应过程的贡献较大,因为它们的组合具有与最终结果相似的实验结果。
表3 DBDAN在数据集OfficeHome的消融实验
其次,本实施例利用t-SNE来可视化ResNet50、DAN、JAN和本发明在DomainNet的任务C-R上学习的特征表示,如图2(a)-(d)所示。从图中可以看出,本发明可以输出比其他特征更具区分性的特征,在图2(d)中具有更清晰的聚类结构,即类内距离较小,类间距离较大。将图2(a)和(b)进行比较可以看出,在LDS的情况下,简单地对齐全局的分布并不能取得良好的效果,这可能会导致比仅使用源域训练更差的域适配结果。相比之下,JAN比DAN获得了更好的适应性能,但很明显,在图2(c)中,有大量的实例分散在类簇之间。
最后,混淆矩阵能够有效地反映模型在类不平衡情况下的性能。从图3中可以看出,本发明的混淆矩阵不仅在每一类上取得了一致的预测结果,而且对混淆矩阵的对角线具有很高的置信度。这些实验结果均证明了本发明在跨域长尾图像分类任务上的有效性。
本发明提供的基于自监督学习与自训练机制的跨域长尾图像分类方法在两个标签分布偏移的数据集上得到了较好的验证。使用的DomainNet和OfficeHome数据集中,不同域之间不仅存在着明显的数据分布差异,其标签分布也存在着明显的不同。在传统领域自适应方法中,仅考虑了源域和目标域之间标签分布相同的情况,导致了传统方法在实际的跨域长尾图像分类任务上失效,甚至出现严重的负迁移现象。因此,本发明提出的技术方案不仅提升了模型在跨域长尾图像分类任务上的共享特征提取能力,也实现了在该复杂数据分布情况下的准确分类。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。
Claims (10)
1.一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,包括以下步骤:
步骤S1、获取跨域长尾图像的数据集,所述数据集包括源域和目标域长尾图像数据,并将数据集划分训练集以及测试集;
步骤S2、采用训练集中的跨域长尾图像数据对双分支网络模型进行训练,包括以下步骤;
步骤S21、对于双分支网络模型的非监督式自适应分支,利用自监督学习与全局分布对齐方法进行表征学习;
步骤S22、对于双分支网络模型的监督式自适应分支,对分类器进行自训练和互信息最大化,进一步提升分类器在目标任务上的分类能力;
步骤S23、利用动态加权集成训练策略自动调整两个分支的学习权重,从而实现两个分支学习目标的同时优化;
步骤S3、采用训练后的模型对测试集中的跨域长尾图像数据进行分类。
2.根据权利要求1所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S21包括以下步骤:
步骤S211、对输入的训练集数据进行预处理;
步骤S212、采用全局分布对齐方法最小化源域和目标域长尾图像数据之间的分布距离;
步骤S213、采用自监督表征学习长尾图像数据的内在特征;
步骤S214、调整全局分布对齐与自监督表征学习之间的比例,得到非监督式自适应分支的优化目标,并求得最优的共享特征表示。
3.根据权利要求2所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S211中数据预处理具体包括:调整跨域长尾图像的图像大小、归一化处理以及数据增广;
其中,数据增广指对输入的所有源域和目标域的长尾图像数据均采用强增强以及弱增强的方式进行增强;所述强增强包括随机颜色抖动、随机灰度、随机高斯模糊以及随机水平翻转;所述弱增强指随机水平翻转。
4.根据权利要求2所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S212中全局分布对齐方法指使用MMD仅对图像预处理中弱增强后的源域和目标域的长尾图像数据之间的分布差异进行度量,并使用高斯核作为MMD中的核函数,并通过将分布距离最小化来训练特征提取器的网络参数,使特征提取器提取到源域和目标域的共享特征表示;其中,对于来自概率分布为Ps和Pt的样本,MMD分布距离表示如下:
式中,映射函数通过使用特征核函数k来替代,即/>从而将样本数量为n的源域Ds中的第i个样本/>以及样本数量为m的目标域Dt中的第j个样本/>通过该核函数k映射到再生核希尔伯特空间H中,其中xs、xt分别表示源域、目标域中的样本。
5.根据权利要求2所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S213中自监督表征通过词典查找任务来提取跨域长尾图像数据的内在特征:当一个查询与键值是来自于同一张图片的不同增广时,则将该查询与键值进行匹配;其中查询采用强增强策略,键值则采用弱增强策略;
自监督表征过程中的自监督损失函数为:
其中,N是负例的数量,τ是温度参数,Q为查询,K+为键值,Kn表示第n个负例。
6.根据权利要求2所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S214中非监督式自适应分支的优化目标表示为:
其中,Lcon为自监督损失函数,为概率分布为Ps和Pt的MMD分布距离,α为权衡参数。
7.根据权利要求1所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S22具体为:
步骤S221、对具有真实标签的源域数据与具有伪标签的高置信度目标域的长尾图像数据进行预处理;
步骤S222、通过自训练方法训练分类器;
步骤S223、通过互信息最大化优化伪标签;
步骤S224、调整自训练损失与互信息之间的比例,得到监督式自适应分支的优化目标;
所述步骤S222具体为:
对于给定的目标域长尾图像样本假设其在每个类别上的输出概率中,最高的概率值为/>其次高的概率值为/>最高的概率值与次高的概率值之间的差值大于阈值,则将所述目标域样本视为高置信度样本,并将具有最高概率值的类别作为伪标签加入到分类器训练中;若不满足条件,则将该样本视为低置信度长尾图像样本,即将其视为空值Null,不参与分类器训练;
选择高置信度的目标域长尾图像样本后,使用重采样来构建用于分类器fcl训练的平衡训练集,其中包括带真实标签的源域长尾图像样本和带伪标签的高置信度目标域长尾图像样本;对于实例数为nc的类标c,重采样的采样概率Pc为:
其中,ηc表示第c个类的样本数占训练样本总数的倒数,计算方式为C表示为分类任务的总类标数,nl是训练集中第l个类的样本数;根据采样概率Pc,在训练过程中等量的采样每个类的样本来实现再平衡;
最终通过最小化分类损失Lcl来训练分类器:
Lcl=Lce(xtrain,ytrain)
其中,Lce是在平衡训练集的交叉熵损失,xtrain、ytrain分别表示带真实标签的源域图像样本和带伪标签的高置信度目标域图像样本。
8.根据权利要求7所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S223中互信息最大化指通过计算目标域互信息损失,使目标域互信息损失最大化来增加目标域输出结果的多样性,从而提高伪标签的质量;其中,目标域互信息损失表示为:
其中,是目标域长尾图像样本/>通过Softmax之后得到的预测概率,/>是所有长尾图像样本在第c个类上的预测概率平均值,<·,·>是内积运算。
9.根据权利要求7所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于,所述步骤S224中监督式自适应分支的优化目标表示为:
Lloab=Lcl-βLmi
其中,Lcl为分类损失,Lmi为目标域互信息损失,β为权衡因子。
10.根据权利要求1所述的一种基于自监督学习与自训练机制的跨域长尾图像分类方法,其特征在于:
所述步骤S23中动态加权集成训练策略指通过自动的控制表征学习和分类器学习的权重,让模型首先进行表征学习,然后逐渐关注到分类器学习,其中,动态加权集成训练策略的动态加权集成因子μ的计算方式为:
其中,iter是当前的迭代次数,而T是总的迭代次数,μ会随着迭代次数的增加逐渐趋近于0;通过以当前迭代次数iter为变量的反函数,使用μ作为损失函数的权重来动态的调整学习策略;
所述双分支网络模型的损失函数为:
Ldbdan=μLlfab+(1-μ)Lloab
其中,Llfab为非监督式自适应分支的优化目标,Lloab为监督式自适应分支的优化目标。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311137191.1A CN117115547A (zh) | 2023-09-05 | 2023-09-05 | 基于自监督学习与自训练机制的跨域长尾图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311137191.1A CN117115547A (zh) | 2023-09-05 | 2023-09-05 | 基于自监督学习与自训练机制的跨域长尾图像分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117115547A true CN117115547A (zh) | 2023-11-24 |
Family
ID=88801984
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311137191.1A Pending CN117115547A (zh) | 2023-09-05 | 2023-09-05 | 基于自监督学习与自训练机制的跨域长尾图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117115547A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117408330A (zh) * | 2023-12-14 | 2024-01-16 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
CN117688472A (zh) * | 2023-12-13 | 2024-03-12 | 华东师范大学 | 一种基于因果结构的无监督域适应多元时间序列分类方法 |
-
2023
- 2023-09-05 CN CN202311137191.1A patent/CN117115547A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117688472A (zh) * | 2023-12-13 | 2024-03-12 | 华东师范大学 | 一种基于因果结构的无监督域适应多元时间序列分类方法 |
CN117688472B (zh) * | 2023-12-13 | 2024-05-24 | 华东师范大学 | 一种基于因果结构的无监督域适应多元时间序列分类方法 |
CN117408330A (zh) * | 2023-12-14 | 2024-01-16 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
CN117408330B (zh) * | 2023-12-14 | 2024-03-15 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113378632B (zh) | 一种基于伪标签优化的无监督域适应行人重识别方法 | |
CN117115547A (zh) | 基于自监督学习与自训练机制的跨域长尾图像分类方法 | |
WO2021134871A1 (zh) | 基于局部二值模式和深度学习的合成人脸图像取证方法 | |
CN112926397B (zh) | 基于两轮投票策略集成学习的sar图像海冰类型分类方法 | |
CN113469236B (zh) | 一种自我标签学习的深度聚类图像识别系统及方法 | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN114842267A (zh) | 基于标签噪声域自适应的图像分类方法及系统 | |
CN113344044B (zh) | 一种基于领域自适应的跨物种医疗影像分类方法 | |
CN113269647B (zh) | 基于图的交易异常关联用户检测方法 | |
CN108877947B (zh) | 基于迭代均值聚类的深度样本学习方法 | |
CN103839033A (zh) | 一种基于模糊规则的人脸识别方法 | |
CN114139676A (zh) | 领域自适应神经网络的训练方法 | |
CN112115849A (zh) | 基于多粒度视频信息和注意力机制的视频场景识别方法 | |
CN117153268A (zh) | 一种细胞类别确定方法及系统 | |
CN114818963B (zh) | 一种基于跨图像特征融合的小样本检测方法 | |
CN109726703A (zh) | 一种基于改进集成学习策略的人脸图像年龄识别方法 | |
CN116912568A (zh) | 基于自适应类别均衡的含噪声标签图像识别方法 | |
CN117152606A (zh) | 一种基于置信度动态学习的遥感图像跨域小样本分类方法 | |
CN115761408A (zh) | 一种基于知识蒸馏的联邦域适应方法及系统 | |
CN109842614B (zh) | 基于数据挖掘的网络入侵检测方法 | |
CN114549909A (zh) | 一种基于自适应阈值的伪标签遥感图像场景分类方法 | |
CN113869451A (zh) | 一种基于改进jgsa算法的变工况下滚动轴承故障诊断方法 | |
CN116433909A (zh) | 基于相似度加权多教师网络模型的半监督图像语义分割方法 | |
CN112699782A (zh) | 基于N2N和Bert的雷达HRRP目标识别方法 | |
CN116523877A (zh) | 一种基于卷积神经网络的脑mri图像肿瘤块分割方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |