CN114677535A - 域适应图像分类网络的训练方法、图像分类方法及装置 - Google Patents

域适应图像分类网络的训练方法、图像分类方法及装置 Download PDF

Info

Publication number
CN114677535A
CN114677535A CN202210193844.7A CN202210193844A CN114677535A CN 114677535 A CN114677535 A CN 114677535A CN 202210193844 A CN202210193844 A CN 202210193844A CN 114677535 A CN114677535 A CN 114677535A
Authority
CN
China
Prior art keywords
domain
image
target domain
cross
target
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210193844.7A
Other languages
English (en)
Inventor
林兰芬
马旭
袁俊坤
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202210193844.7A priority Critical patent/CN114677535A/zh
Publication of CN114677535A publication Critical patent/CN114677535A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种域适应图像分类网络的训练方法、图像分类方法及装置,其中域适应图像分类网络的训练方法包括:获取若干对源域图像和目标域图像;提取其中一对源域图像和目标域图像的跨层特征;利用注意力机制计算跨层特征之间的相似度;根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;根据源域图像和目标域图像的跨层特征,计算分类损失;根据域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;根据总损失,更新域适应图像分类网络的参数;对其余源域图像和目标域图像执行从提取其中一对源域图像和目标域图像的跨层特征至根据总损失更新域适应图像分类网络的参数的步骤,直至跨层对齐损失收敛。

Description

域适应图像分类网络的训练方法、图像分类方法及装置
技术领域
本申请涉及图像分类技术领域,尤其涉及域适应图像分类网络的训练方法、图像分类方法及装置。
背景技术
机器学习算法和深度神经网络等技术的快速发展,使得图像分类模型性能得以大幅度提升。当有足够的有标签训练样本,且训练样本与测试样本满足独立同分布的假设时,分类模型可以获得较好的效果。但是,在实际应用中,收集足够多的有标签的训练图像经常耗时长、代价昂贵甚至无法实现。同时,由于各种因素不可能保证训练样本始终与测试样本具有相同的分布,数据分布的差异导致传统深度学习方法训练得到的模型难以在新的数据集上取得较好的表现,这限制了机器学习模型的泛化能力。而无监督域适应图像分类方法可以有效解决以上问题。
无监督域适应图像分类方法主要分为两类。一类是基于对抗学习的方法,另一类是基于距离度量的对齐方法。后者通过减小源域和目标域的数据分布差异来降低模型在目标域的泛化误差,从而得到一个能在目标域上表现优异的模型。具体为,基于差异度量指标将源域和目标域的特征映射到一个公共的再生核希尔伯特空间(Reproducing KernelHilbert Space,RKHS)中,通过最小化域间分布差异的度量指标,实现源域和目标域的分布对齐。其中,度量域间分布差异的指标包括KL散度、最大均值差异(Maximum MeanDiscrepancy,MMD)、Wasserstein距离等。该类方法相较于基于对抗学习的方法,具有操作简单、训练耗时短等优点,因此成为目前主流的研究方法。
在实现本发明的过程中,发明人发现现有技术中至少存在如下问题:
基于距离度量的对齐方法均基于一个假设,即模型中的每层网络在两个域之间提取的语义特征蕴含相同层次的语义信息。具体来说,上述方法只对模型中同层网络提取的两个域的语义特征进行对齐。而现有的研究表明,由于两个域之间存在域偏移,导致相同层次的语义信息分散在模型各层网络的输出中。在这种情况下,只对同层网络提取的特征进行对齐,就会出现训练后的模型在目标域上分类准确率下降的情况,即负迁移。
发明内容
本申请实施例的目的是提供一种域适应图像分类网络的训练方法、图像分类方法及装置,以解决相关技术中存在的无法解决同层次语义信息分布在模型各层网络的输出中的技术问题。
根据本申请实施例的第一方面,提供一种域适应图像分类网络的训练方法,包括:
获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
提取其中一对所述源域图像和目标域图像的跨层特征;
利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
根据所述源域图像和目标域图像的跨层特征,计算分类损失;
根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
根据所述总损失,更新所述域适应图像分类网络的参数;
对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
进一步地,获取若干对源域图像和目标域图像之后,还包括:
调整所述源域图像和目标域图像,以使得所述源域图像和目标域图像的尺寸相同;
对调整后的源域图像和目标域图像编码。
进一步地,所述源域图像的跨层特征包括第一源域特征和第二源域特征,所述目标域图像的跨层特征包括第一目标域特征和第二目标域特征。
进一步地,利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度,包括:
根据所述第一源域特征和第一目标域特征,提取局部源域特征和局部目标域特征;
计算每一对局部源域特征和局部目标域特征的通道相似度和空间相似度;
对所述通道相似度和空间相似度进行平均,得到所述源域图像和目标域图像的跨层特征之间的相似度。
进一步地,根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失,包括:
计算所述第一源域特征和第一目标域特征的第一多核最大均值差异;
计算所述第二源域特征和第二目标域特征的第二多核最大均值差异;
计算所述第一多核最大均值差异与所述相似度的乘积的和;
对所述乘积的和与第二多核最大均值差异进行加权求和,得到域泛化损失。
根据本申请实施例的第二方面,提供一种域适应图像分类网络的训练装置,包括:
第一获取模块,用于获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
提取模块,用于提取其中一对所述源域图像和目标域图像的跨层特征;
第一计算模块,用于利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
第二计算模块,用于根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
第三计算模块,用于根据所述源域图像和目标域图像的跨层特征,计算分类损失;
第四计算模块,用于根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
第一更新模块,用于根据所述总损失,更新所述域适应图像分类网络的参数;
第二更新模块,用于对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
根据本申请实施例的第三方面,提供一种图像分类方法,包括:
获取待分类的目标域图像;
将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据第一方面所述方法训练得到的网络;
获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
将值最大的所述概率对应的类别设置为所述目标域图像的类别。
根据本申请实施例的第四方面,提供一种图像分类装置,包括:
第二获取模块,用于获取待分类的目标域图像;
输入模块,用于将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据第一方面所述方法训练得到的网络;
第三获取模块,用于获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
设置模块,用于将值最大的所述概率对应的类别设置为所述目标域图像的类别。
根据本申请实施例的第五方面,提供一种电子设备,包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如第一方面或第三方面任一项所述的方法。
根据本申请实施例的第六方面,提供一种计算机可读存储介质,其上存储有计算机指令,该指令被处理器执行时实现如第一方面或第三方面任一项所述方法的步骤。
本申请的实施例提供的技术方案可以包括以下有益效果:
由上述实施例可知,本申请通过对域适应图像分类网络的各层网络提取的特征进行对齐,解决了图像的语义信息分散在模型各层网络的输出中的问题,避免了负迁移造成的影响;提取其中一对所述源域图像和目标域图像的跨层特征,利用注意力机制计算所述跨层特征之间的相似度,并根据所述相似度和所述源域图像和目标域图像的跨层特征的多核最大均值差异,计算域对齐泛化损失,解决了现有技术无法解决同层次语义信息分布在模型各层网络的输出中的问题,在处理无监督域适应图像分类任务时,自动匹配并对齐分散在模型不同层网络中的同层次语义信息,进而提升域适应图像分类网络的的分类精度。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本申请。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本申请的实施例,并与说明书一起用于解释本申请的原理。
图1是根据一示例性实施例示出的一种域适应图像分类网络的训练方法的流程图。
图2是根据一示例性实施例示出的域适应图像分类网络的结构示意图。
图3是根据一示例性实施例示出的步骤S11之后还可以包括的步骤的流程图。
图4是根据一示例性实施例示出的步骤S13的流程图。
图5是根据一示例性实施例示出的步骤S14的流程图。
图6是根据一示例性实施例示出的一种域适应图像分类网络的训练装置的框图。
图7是根据一示例性实施例示出的一种图像分类方法的流程图。
图8是根据一示例性实施例示出的一种图像分类装置的框图。
具体实施方式
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本申请相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本申请的一些方面相一致的装置和方法的例子。
在本申请使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本申请。在本申请和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。
应当理解,尽管在本申请可能采用术语第一、第二、第三等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本申请范围的情况下,第一信息也可以被称为第二信息,类似地,第二信息也可以被称为第一信息。取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”。
实施例1:
图1是根据一示例性实施例示出的一种域适应图像分类网络的训练方法的流程图,如图1所示,该方法可以包括以下步骤:
步骤S11:获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
步骤S12:提取其中一对所述源域图像和目标域图像的跨层特征;
步骤S13:利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
步骤S14:根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
步骤S15:根据所述源域图像和目标域图像的跨层特征,计算分类损失;
步骤S16:根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
步骤S17:根据所述总损失,更新所述域适应图像分类网络的参数;
步骤S18:对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
由上述实施例可知,本申请提取其中一对所述源域图像和目标域图像的跨层特征,利用注意力机制计算所述跨层特征之间的相似度,并根据所述相似度和所述源域图像和目标域图像的跨层特征的多核最大均值差异,计算域对齐泛化损失,在处理无监督域适应图像分类任务时,自动匹配并对齐分散在模型不同层网络中的同层次语义信息,进而提升模型的的分类精度。
需要说明的是,该训练方法对应的域适应图像分类网络的结构如图2所示。所述域适应图像分类网络包括特征提取器F、基于注意力机制的特征对齐模块H和分类器C,其中F由一个将输出维度被改为d的ResNet-50网络组成;H由6个卷积层(convolutional layer)、一个通道注意力模块(channel attention module)和一个空间注意力模块(spatialattention module)组成;C由一个全连接层(fully connected layer)组成。
在步骤S11的具体实施中,获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
具体地,随机选取若干对源域图像gs和目标域图像gt。用于后续计算域对齐泛化损失。
具体地,如图3所示,在步骤S11之后还可以包括以下步骤:
步骤S21:调整所述源域图像和目标域图像,以使得所述源域图像和目标域图像的尺寸相同;
具体地,统一图像的尺寸。使用双线性插值算法将图像gs和gt都缩放成尺寸为224px*224px的图像。处理后,得到的图像大小相同,符合ResNet-50网络的输入规格。
步骤S22:对调整后的源域图像和目标域图像编码;
具体地,对源域图像gs和目标域图像gt进行编码,将大小相同的图像gs和gt中的所有像素值除以255,然后使用公式(1)对图像RGB三个通道的数值(val)进行归一化处理,得到编码矩阵xs和xt,其中RGB三个通道的均值(mean)分别为0.485、0.456和0.406,标准偏差(std)分别为0.229、0.224和0.225。
Figure BDA0003526106200000081
通过归一化处理,防止神经网络模型在训练过程中发生梯度爆炸。
在步骤S12的具体实施中,提取其中一对所述源域图像和目标域图像的跨层特征;
具体地,所述源域图像的跨层特征包括第一源域特征和第二源域特征,所述目标域图像的跨层特征包括第一目标域特征和第二目标域特征。
具体地,将编码矩阵xs和xt输入特征提取器F中进行特征提取,得到第一源域特征
Figure BDA0003526106200000091
第一目标域特征
Figure BDA0003526106200000092
第二源域特征fs和第二目标域特征ft,其中
Figure BDA0003526106200000093
Figure BDA0003526106200000094
分别表示F中倒数第i个残差块从源域和目标域中提取出来的mi×ai×ai维的实数矩阵特征,fs和ft分别表示F中最后一层从源域和目标域中提取出来的d维实数向量特征。通过收集不同残差块的输出,进而获取源域和目标域上不同层次的语义信息,使获得的语义信息更加完整。
在步骤S13的具体实施中,利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
具体地,如图4所示,步骤S13可以包括以下子步骤:
步骤S31:根据所述第一源域特征和第一目标域特征,提取局部源域特征和局部目标域特征;
具体地,特征对齐模块H中的6个卷积层使用k个大小为hin-6的卷积核,提取k种局部特征,其中hin为卷积层输入特征的第2和第3维度大小。第一源域特征
Figure BDA0003526106200000095
和第一目标域特征
Figure BDA0003526106200000096
经过卷积层后得到k×7×7维的实数矩阵
Figure BDA0003526106200000097
Figure BDA0003526106200000098
其中
Figure BDA0003526106200000099
为局部源域特征,
Figure BDA00035261062000000910
Figure BDA00035261062000000911
为局部目标域特征。通过上述操作,对特征进行压缩,减少训练过程中对显存的占用。
步骤S32:计算每一对局部源域特征和局部目标域特征的通道相似度和空间相似度;
具体地,通过通道注意力模块,计算每一对跨层语义特征
Figure BDA00035261062000000912
的通道相似度,其中i,j∈{1,2,3}。每次计算时,先将
Figure BDA00035261062000000913
Figure BDA00035261062000000914
的大小拉伸为k×49,然后通过公式(2)得到
Figure BDA00035261062000000915
Figure BDA00035261062000000916
之间的通道相似度αi,j
Figure BDA00035261062000000917
其中avg(X)表示对矩阵X中所有元素求均值。
具体地,通过空间注意力模块,计算每一对跨层语义特征
Figure BDA0003526106200000101
的空间相似度,其中i,j∈{1,2,3}。每次计算时,先将
Figure BDA0003526106200000102
Figure BDA0003526106200000103
的大小拉伸为k×49,然后通过公式(3)得到
Figure BDA0003526106200000104
Figure BDA0003526106200000105
之间的空间相似度βi,j
Figure BDA0003526106200000106
其中avg(X)表示对矩阵X中所有元素求均值。
此步骤将每对跨层语义特征之间的相似度自动量化为两个实数αi,j和βi,j,避免人为调参,减少模型训练难度。
步骤S33:对所述通道相似度和空间相似度进行平均,得到所述源域图像和目标域图像的跨层特征之间的相似度;
具体地,得到所述通道相似度αi,j和空间相似度βi,j后,对两者进行平均运算,得到所述源域图像和目标域图像的跨层特征之间的相似度
Figure BDA0003526106200000107
在步骤S14的具体实施中,根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
具体地,如图5所示,步骤S14可以包括以下子步骤:
步骤S41:计算所述第一源域特征和第一目标域特征的第一多核最大均值差异;
具体地,获取第一源域特征
Figure BDA0003526106200000108
和第一目标域特征
Figure BDA0003526106200000109
使用多核最大均值差异,计算第一多核最大均值差异,计算公式为:
Figure BDA00035261062000001010
其中Dk是多核最大均值差异的计算公式。
步骤S42:计算所述第二源域特征和第二目标域特征的第二多核最大均值差异;
具体地,获取第二源域特征fs和第二目标域特征ft,使用多核最大均值差异,计算第二源域和目标域之间的距离,计算公式为:
Figure BDA0003526106200000111
其中Dk是多核最大均值差异的计算公式。
步骤S43:计算所述第一多核最大均值差异与所述相似度的乘积的和;
具体地,计算源域和目标域的跨层之间的相似度
Figure BDA0003526106200000112
与所述第一多核最大均值差异
Figure BDA0003526106200000113
的乘积的和
Figure BDA0003526106200000114
步骤S44:对所述乘积的和与第二多核最大均值差异进行加权求和,得到域泛化损失;
具体地,使用公式(4)计算域对齐泛化损失。
Figure BDA0003526106200000115
其中,δ是第一超参数,在实施例中设置为0.3。
在处理无监督域适应图像分类任务时,通过自动匹配并对齐分散在模型不同层网络中的同层次语义信息,进而提升模型的分类精度。
在步骤S15的具体实施中,根据所述源域图像和目标域图像的跨层特征,计算分类损失;
具体地,分类器C中的全连接层以特征提取器F的输出f作为输入,使用Softmax函数作为激活函数,并使用公式(5)所示的交叉熵损失(cross entropy loss)计算分类损失。
Figure BDA0003526106200000116
式中c表示共有c个图像类别,pu是范围在0到1之间实数,表示图像属于类别u的概率。
在步骤S16的具体实施中,根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
具体地,通过最小化公式(6)计算域适应图像分类网络的总损失,得到的总损失可用于训练整个网络。
Figure BDA0003526106200000121
式中的
Figure BDA0003526106200000122
表示使用特征fs计算得到的交叉熵损失,γ为第二超参数,本实施例中设置为0.3。
在步骤S17的具体实施中,根据所述总损失,更新所述域适应图像分类网络的参数;
具体地,使用随机梯度下降(SGD)方法更新分类网络的参数。
在步骤S18的具体实施中,对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛;
具体地,对于其余源域图像和目标域图像,重复执行步骤S11~S17,直至总损失值达到预设的收敛条件时,将收敛之后的域适应图像分类网络记录为训练完成的域适应图像分类模型。
与前述的域适应图像分类网络的训练方法的实施例相对应,本申请还提供了域适应图像分类网络的训练装置的实施例。
图6是根据一示例性实施例示出的一种域适应图像分类网络的训练装置的框图。参照图6,该装置可以包括:
第一获取模块21,用于获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
提取模块22,用于提取其中一对所述源域图像和目标域图像的跨层特征;
第一计算模块23,用于利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
第二计算模块24,用于根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
第三计算模块25,用于根据所述源域图像和目标域图像的跨层特征,计算分类损失;
第四计算模块26,用于根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
第一更新模块27,用于根据所述总损失,更新所述域适应图像分类网络的参数;
第二更新模块28,用于对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
实施例2:
图7是根据一示例性实施例示出的一种图像分类方法的流程图,如图7所示,该方法可以包括以下步骤:
步骤S51:获取待分类的目标域图像;
具体地,对于每一个目标域图像,按照上述步骤S21和步骤S22,先统一图像的尺寸,然后统一图像尺寸并做归一化处理,得到目标域图像xt
步骤S52:将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据实施例1中所述域适应图像分类网络的训练方法训练得到的网络;
具体地,将源域图像和目标域图像的编码矩阵xt输入特征提取器F中进行特征提取,得到第二目标域特征ft。将得到的ft输入分类器C中。
步骤S53:获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
具体地,分类器C中的全连接层以上述ft作为输入,使用Softmax函数作为激活函数,输出目标域图像属于各个已知类别的概率[p1,p2,p3,...,pc],其中c表示类别总数,pi表示该图像属于类别i的概率。
步骤S54:将值最大的所述概率对应的类别设置为所述目标域图像的类别。
具体地,对于上述得到的图像属于已知类别的概率[p1,p2,p3,...,pc],假设其中最大的概率值为pu,则预测该图像属于第u类。计算上述得到的图像属于已知类别概率[p1,p2,p3,...,pc]中的最大值pu,则预测该图像属于第u类。
由上述实施例可知,本申请提出了一种图像分类方法,通过将待分类的目标域图像输入通过实施例1中的方法训练的域适应图像分类网络中,获得目标域图像属于各已知类别的概率,从而得到目标域图像的类别。由于实施例1中的方法通过将各层网络提取的特征进行对齐,避免了负迁移造成的影响,提升了分类精度,因此本图像分类方法也避免了负迁移造成的影响,具有高分类精度。
与前述的图像分类方法的实施例相对应,本申请还提供了图像分类装置的实施例。
图8是根据一示例性实施例示出的一种图像分类装置的框图,如图8所示,该装置可以包括:
第二获取模块31,用于获取待分类的目标域图像;
输入模块32,用于将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据实施例1中所述域适应图像分类网络的训练方法训练得到的网络;
第三获取模块33,用于获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
设置模块34,用于将值最大的所述概率对应的类别设置为所述目标域图像的类别。
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本申请方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
实施例3:
相应的,本申请还提供一种电子设备,包括:一个或多个处理器;存储器,用于存储一个或多个程序;当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如上述的域适应图像分类网络的训练方法或图像分类方法。
实施例4:
相应的,本申请还提供一种计算机可读存储介质,其上存储有计算机指令,其特征在于,该指令被处理器执行时实现如上述的域适应图像分类网络的训练方法或图像分类方法。
本领域技术人员在考虑说明书及实践这里公开的内容后,将容易想到本申请的其它实施方案。本申请旨在涵盖本申请的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本申请的一般性原理并包括本申请未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本申请的真正范围和精神由下面的权利要求指出。
应当理解的是,本申请并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本申请的范围仅由所附的权利要求来限制。

Claims (10)

1.一种域适应图像分类网络的训练方法,其特征在于,包括:
获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
提取其中一对所述源域图像和目标域图像的跨层特征;
利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
根据所述源域图像和目标域图像的跨层特征,计算分类损失;
根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
根据所述总损失,更新所述域适应图像分类网络的参数;
对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
2.根据权利要求1所述的训练方法,其特征在于,获取若干对源域图像和目标域图像之后,还包括:
调整所述源域图像和目标域图像,以使得所述源域图像和目标域图像的尺寸相同;
对调整后的源域图像和目标域图像编码。
3.根据权利要求1所述的训练方法,其特征在于,所述源域图像的跨层特征包括第一源域特征和第二源域特征,所述目标域图像的跨层特征包括第一目标域特征和第二目标域特征。
4.根据权利要求3所述的训练方法,其特征在于,利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度,包括:
根据所述第一源域特征和第一目标域特征,提取局部源域特征和局部目标域特征;
计算每一对局部源域特征和局部目标域特征的通道相似度和空间相似度;
对所述通道相似度和空间相似度进行平均,得到所述源域图像和目标域图像的跨层特征之间的相似度。
5.根据权利要求3所述的训练方法,其特征在于,根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失,包括:
计算所述第一源域特征和第一目标域特征的第一多核最大均值差异;
计算所述第二源域特征和第二目标域特征的第二多核最大均值差异;
计算所述第一多核最大均值差异与所述相似度的乘积的和;
对所述乘积的和与第二多核最大均值差异进行加权求和,得到域泛化损失。
6.一种域适应图像分类网络的训练装置,其特征在于,包括:
第一获取模块,用于获取若干对源域图像和目标域图像,其中每一对源域图像和目标域图像的类别均相同;
提取模块,用于提取其中一对所述源域图像和目标域图像的跨层特征;
第一计算模块,用于利用注意力机制,计算所述源域图像和目标域图像的跨层特征之间的相似度;
第二计算模块,用于根据所述跨层特征的多核最大均值差异和所述相似度,计算域对齐泛化损失;
第三计算模块,用于根据所述源域图像和目标域图像的跨层特征,计算分类损失;
第四计算模块,用于根据所述域对齐泛化损失和分类损失,加权计算域适应图像分类网络的总损失;
第一更新模块,用于根据所述总损失,更新所述域适应图像分类网络的参数;
第二更新模块,用于对其余源域图像和目标域图像执行从提取其中一对所述源域图像和目标域图像的跨层特征至根据所述总损失,更新所述域适应图像分类网络的参数的步骤,直至所述跨层对齐损失收敛。
7.一种图像分类方法,其特征在于,包括:
获取待分类的目标域图像;
将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据权利要求1-5中任一项所述方法训练得到的网络;
获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
将值最大的所述概率对应的类别设置为所述目标域图像的类别。
8.一种图像分类装置,其特征在于,包括:
第二获取模块,用于获取待分类的目标域图像;
输入模块,用于将所述目标域图像输入域适应图像分类网络,其中所述域适应图像分类网络为根据权利要求1-5中任一项所述方法训练得到的网络;
第三获取模块,用于获取所述域适应图像分类网络输出概率组,所述概率组中包括所述目标域图像分别属于各已知类别的概率;
设置模块,用于将值最大的所述概率对应的类别设置为所述目标域图像的类别。
9.一种电子设备,其特征在于,包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如权利要求1-5或权利要求7任一项所述的方法。
10.一种计算机可读存储介质,其上存储有计算机指令,其特征在于,该指令被处理器执行时实现如权利要求1-5或权利要求7中任一项所述方法的步骤。
CN202210193844.7A 2022-03-01 2022-03-01 域适应图像分类网络的训练方法、图像分类方法及装置 Pending CN114677535A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210193844.7A CN114677535A (zh) 2022-03-01 2022-03-01 域适应图像分类网络的训练方法、图像分类方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210193844.7A CN114677535A (zh) 2022-03-01 2022-03-01 域适应图像分类网络的训练方法、图像分类方法及装置

Publications (1)

Publication Number Publication Date
CN114677535A true CN114677535A (zh) 2022-06-28

Family

ID=82072958

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210193844.7A Pending CN114677535A (zh) 2022-03-01 2022-03-01 域适应图像分类网络的训练方法、图像分类方法及装置

Country Status (1)

Country Link
CN (1) CN114677535A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115578593A (zh) * 2022-10-19 2023-01-06 北京建筑大学 一种使用残差注意力模块的域适应方法

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115578593A (zh) * 2022-10-19 2023-01-06 北京建筑大学 一种使用残差注意力模块的域适应方法

Similar Documents

Publication Publication Date Title
WO2021042828A1 (zh) 神经网络模型压缩的方法、装置、存储介质和芯片
CN110209859B (zh) 地点识别及其模型训练的方法和装置以及电子设备
CN110555399B (zh) 手指静脉识别方法、装置、计算机设备及可读存储介质
WO2022042123A1 (zh) 图像识别模型生成方法、装置、计算机设备和存储介质
CN109063719B (zh) 一种联合结构相似性和类信息的图像分类方法
CN111444951B (zh) 样本识别模型的生成方法、装置、计算机设备和存储介质
CN107680077A (zh) 一种基于多阶梯度特征的无参考图像质量评价方法
CN110309835B (zh) 一种图像局部特征提取方法及装置
CN109949200B (zh) 基于滤波器子集选择和cnn的隐写分析框架构建方法
CN114549913A (zh) 一种语义分割方法、装置、计算机设备和存储介质
CN111639230B (zh) 一种相似视频的筛选方法、装置、设备和存储介质
CN112861659A (zh) 一种图像模型训练方法、装置及电子设备、存储介质
CN113705596A (zh) 图像识别方法、装置、计算机设备和存储介质
WO2023020214A1 (zh) 检索模型的训练和检索方法、装置、设备及介质
CN112786160A (zh) 基于图神经网络的多图片输入的多标签胃镜图片分类方法
CN111126155B (zh) 一种基于语义约束生成对抗网络的行人再识别方法
CN114677535A (zh) 域适应图像分类网络的训练方法、图像分类方法及装置
CN109101984B (zh) 一种基于卷积神经网络的图像识别方法及装置
CN114299304A (zh) 一种图像处理方法及相关设备
CN111079930B (zh) 数据集质量参数的确定方法、装置及电子设备
CN116384471A (zh) 模型剪枝方法、装置、计算机设备、存储介质和程序产品
CN114155388B (zh) 一种图像识别方法、装置、计算机设备和存储介质
TWI803243B (zh) 圖像擴增方法、電腦設備及儲存介質
CN114937166A (zh) 图像分类模型构建方法、图像分类方法及装置、电子设备
CN115457638A (zh) 模型训练方法、数据检索方法、装置、设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination