CN114611617A - 基于原型网络的深度领域自适应图像分类方法 - Google Patents

基于原型网络的深度领域自适应图像分类方法 Download PDF

Info

Publication number
CN114611617A
CN114611617A CN202210259161.7A CN202210259161A CN114611617A CN 114611617 A CN114611617 A CN 114611617A CN 202210259161 A CN202210259161 A CN 202210259161A CN 114611617 A CN114611617 A CN 114611617A
Authority
CN
China
Prior art keywords
domain
loss function
prototype
classifier
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210259161.7A
Other languages
English (en)
Inventor
刘龙
刘泽宁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xian University of Technology
Original Assignee
Xian University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xian University of Technology filed Critical Xian University of Technology
Priority to CN202210259161.7A priority Critical patent/CN114611617A/zh
Publication of CN114611617A publication Critical patent/CN114611617A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本发明公开一种基于原型网络的深度领域自适应图像分类方法,包括:预训练好的特征提取器分别对扩充的DS和DT提取特征得到F(DS)和F(DT);将F(DS)、F(DT)依次输入第一非线性投影器、第一分类器得到分类损失函数Ld2,分别输入至嵌入模块得F′(DS)和F′(DT)和域混淆的损失函数Ld1;定义原型网络,将F′(DS)和F′(DT)输入原型网络中到原型损失函数LP1、LP2,通过Ld1、Ld2、LP1、LP2获得总损失函数用于反向训练;采用训练好的原型网络进行分类。本发明解决了现有技术中存在的源域和无标签目标域类别不一致且给定样本数量较少情况下导致训练后分类器预测分类正确率低下的问题。

Description

基于原型网络的深度领域自适应图像分类方法
技术领域
本发明属于迁移学习技术领域,涉及一种基于原型网络的深度领域自适应图像分类方法。
背景技术
机器学习特别是深度学习推动了计算机视觉领域的迅猛发展,其算法理论已经被广泛的应用于解决各类实际的工程问题,极大地提高了包括图像分类、目标检测、目标跟踪和语义分割在内的各类计算机视觉的性能表现。通常一个深度网络模型的优劣水平很大程度上取决于该任务相关的带标签训练数据集,然而获得大规模带标签的数据需要投入难以想象的人力和物力成本,依靠带标签数据集训练得到的深度网络模型弱鲁棒性,泛化能力极其低下。此外,机器学习要求训练、测试数据满足独立同分布的条件,但在现实应用场景中很难保证这一约束性的条件成立。针对特定图像分类任务训练得到的深度网络模型,在当前任务上具有很强的正确分类能力,但应用场景发生变化有别于训练任务时,其模型的性能表现可能会出现大幅下降,乃至完全不可用。
领域自适应放宽了传统机器学习的两个基本约束条件,即足够的训练样本以及训练、测试数据独立同分布。领域自适应利用源域与目标域数据分布的相关性,旨在通过借助有标签源域的知识达到解决无标签目标域分类任务的目的。领域自适应通常假设源域数据和目标域数据的类别保持一致,并通过不断拉近源域和目标域数据在高维共享空间中的分布距离,从而将有标签源域数据训练的模型应用于无标签目标域中,可以确保模型在解决目标域分类问题时性能不会发生大幅度下降。
与上述介绍的领域自适应略有不同,源域和目标域数据通过不同的渠道获取且目标域数据完全没有被标注,所以并不能明确目标域数据是否与源域数据共享所有类别。目标域数据很有可能包含和源域数据类别完全不同的样本,故在实际的应用场景中很难确保源域数据和目标域数据的类别完全一致,因此在保持源域和目标域类别区分性的同时能够实现全局域分布对齐尤为重要。
传统分类问题通常假设各类别的样本相对均衡,也即每类的样本数量大致相当,但往往源域中某些特定类别的样本数量较少,样本数量并不均衡,可能会造成错误映射导致目标域对应类别的样本错分。另外,当目标域数据规模和样本标签大幅缩减时,深度网络模型仅依据源域已有知识正确预测目标域类别的可能性较低。
发明内容
本发明的目的是提供一种基于原型网络的深度领域自适应图像分类方法,解决了现有技术中存在的源域和无标签目标域类别不一致且给定样本数量较少情况下对深度网络模型性能的影响,以及导致训练后分类器预测分类正确率低下的问题。
本发明所采用的技术方案是:
基于原型网络的深度领域自适应图像分类方法,包括以下步骤:
步骤1,对包括源域数据DS和目标域数据DT的数据集进行扩充;
步骤2、建立共享特征提取器F(·),采用预训练好的共享特征提取器F(·)分别对扩充的源域数据DS和目标域数据DT提取特征,得到特征向量F(DS)和F(DT);
步骤3、构建第一非线性投影器、第一分类器,将特征向量F(DS)、F(DT)输入至第一非线性投影器后再经第一分类器,得到分类损失函数Ld2
步骤4,构建并训练嵌入模块,将特征向量F(DS)和F(DT)分别输入至嵌入模块得嵌入特征F′(DS)和F′(DT)和域混淆的损失函数Ld1
步骤5定义原型网络,将嵌入模块得到中得到F′(DS)和F′(DT)输入原型网络中,对原型网络进行训练得到原型损失函数LP1、LP2,对损失函数Ld1、Ld2、LP1、LP2进行加权叠加得到总损失函数用于反向训练网络模型;
步骤6,采用训练好的原型网络对待分类图像进行分类。
本发明的特点还在于:
其中步骤1具体包括将源域、目标域数据集分批次输入到随机数据增广网络中,随机数据增广网络对原始的源域以及目标域数据集样本旋转、裁剪和加入高斯白噪声变换后恢复至原始输入大小,形成新的样本加入至原始数据集中,从而扩充数据集。
步骤2中共享的特征提取器F(·)采用预训练的VGG16深度卷积网络。
步骤3包括:
步骤3.1,第一非线性投影器f(·)中包括依次连接的全连接层、ReLU激活函数层、全连接层,将特征向量F(DS)和F(DT)输入至第一非线性投影器f(·)中得到特征向量f(DS)和f(DT);
步骤3.2,将特征向量f(DS)和f(DT)入至第一分类器G中得到关于源域和目标域各自的分类预测结果g(DS)和g(DT);
步骤3.3,定义领域判别器D,领域判别器D包括依次连接的三层全连接层,定义在关于源域数据分布PS(x)目标域数据分布PT(x)的领域判别器D、特征提取器F和第一分类器G上的分类损失E,该损失函数形式化定义为:
Figure BDA0003550074660000041
上式中,fi S
Figure BDA0003550074660000042
分别代表源域第i个样本
Figure BDA0003550074660000043
经过第一非线性投影、分类器后得到的特征向量,fi T
Figure BDA0003550074660000044
分别代表目标域第i个样本
Figure BDA0003550074660000045
经过第一非线性投影、分类器后得到的特征向量,设h(·)=(f,g),定义
Figure BDA0003550074660000046
式中
Figure BDA0003550074660000047
表示f和g的外积运算,用于融合第一非线性投影层和第一分类器层的输出,则学习领域独有特征表示的损失函数进一步定义为:
Figure BDA0003550074660000048
采用损失函数Ld2,对共享特征提取器、第一非线性投影器、第一分类器进行反向训练。
步骤4包括:
步骤4.1,嵌入模块包括自编码器,自编码器以F(DS)、F(DT)为输入,输出对应的降维重构向量,在自编码器中加入注意力机制作为强制约束,注意力机制包括全连接层,其中注意力机制采用注意力分数Sigmoid(FC(F))用于删除任何域特定信息;
将特征向量F(DS)、F(DT)输入嵌入模块最终输出嵌入向量F′(DS)和F′(DT);
步骤4.2,将嵌入向量F′(DS)和F′(DT)输入至第二非线性投影器得到f′(DS)和f′(DT);再经过第二分类器G′中得到对应源域和目标域各自的分类预测结果g′(DS)和g′(DT);
步骤4.3、定义在源域数据分布PS(x)目标域数据分布PT(x)的嵌入模块F′和第二分类器G′上的域对抗损失E,该损失函数形式化定义为:
Figure BDA0003550074660000051
上式中,fi S′
Figure BDA0003550074660000052
分别代表源域第i个样本
Figure BDA0003550074660000053
经过第二非线性投影、分类器后得到的特征向量,fi T′
Figure BDA0003550074660000054
分别代表目标域第i个样本经过
Figure BDA0003550074660000055
第二非线性投影、分类器后得到的特征向量,设h′(·)=(f′,g′),定义
Figure BDA0003550074660000056
式中
Figure BDA0003550074660000057
仍表示f′和g′的外积运算,则域对抗损失E等价为:
Figure BDA0003550074660000058
步骤4.4、虽然经过嵌入模块后源域、目标域数据尽可能的实现域混淆,但不同样本域混淆的困难程度不同,网络模型对不同的样本应保持一致性,故引入广义熵指数:
Figure BDA0003550074660000061
上式中,M代表每批次训练中的样本个数,Lk代表着第k个样本的对应的损失函数值,
Figure BDA0003550074660000065
代表着一批次样本的平均损失函数值。广义熵指数越大,对分布顶端的差异敏感性越大。令α=2,计算每批次源域、目标域数据对应的广义熵指数并记为wS和wT。则最终通过域对抗以实现域混淆的损失函数定义为:
Figure BDA0003550074660000062
采用损失函数Ld1对嵌入模块、第二非线性投影器、第二分类器进行反向训练。
步骤5包括:
步骤5.1,对嵌入向量F′(DS)、F′(DT),对于每一个类别,取其各自的特征向量的均值作为类别原型Ck
Figure BDA0003550074660000063
上式中,Sk表示类别为k的样本集合,所有的类别原型组成原型网络。
步骤5.2、通过距离函数d(·)计算测试样本
Figure BDA0003550074660000064
对应的特征向量与各个类别原型之间的平方欧式距离,并利用Softmax函数将其转化为概率值,预测该样本的类别概率:
Figure BDA0003550074660000071
步骤5.3、利用步骤5.2的类别预测规则,对于源域数据中支持集样本SS,其对应的真实类别标签是c,基于负对数似然定义源域每轮次上的原型损失函数为:
Figure BDA0003550074660000072
同样对于目标域数据中支持集样本ST,其对应的真实类别标签是c,基于负对数似然定义目标域每轮次上的原型损失函数为:
Figure BDA0003550074660000073
步骤5.4、利用步骤3.3、4.4得到的域混淆的损失函数Ld1、Ld1和步骤5.4到的原型损失函数LP1、LP2,定义总损失函数,通过反向训练网络模型参数,总损失定义为:
Ltotal=(LP1+LP2)+λ(Ld1+Ld1) (11)
上式中λ代表学习率超参数设置为5~10。
本发明的有益效果是:
发明一种基于原型网络的深度域自适应图像分类方法,在保持全局域对齐的基础上,同时也保留了源域和目标域中各领域间特有的域区分信息。经过域对抗、域混淆和原型网络训练后,通过计算查询集图像和类别原型之间的相似度程度,使网络模型在目标域中能够拥有更好的分类性能。
附图说明
图1是本发明基于原型网络的深度领域自适应图像分类方法的网络结构图;
图2是本发明步骤4的嵌入模块中深度自编码器结构示意图。
具体实施方式
下面结合附图和具体实施方式对本发明进行详细说明。
本发明基于原型网络的深度领域自适应图像分类方法,如图1,按照以下步骤实施:
步骤1,对包括源域数据DS和目标域数据DT的数据集进行扩充;
步骤2、建立共享特征提取器F(·),预训练好的共享特征提取器F(·)分别对扩充的源域数据DS和目标域数据DT提取特征,得到特征向量F(DS)和F(DT);
步骤3、构建第一非线性投影器、第一分类器,将源域和目标域的特征向量F(DS)、F(DT)输入至第一非线性投影器后再经第一分类器得到预测分类结果g(DS)、g(DT),进而得到分类损失函数Ld2,用以训练共享特征提取器、第一非线性投影器、第一分类器进行反向训练;
步骤4,构建并训练嵌入模块,将源域的特征向量F(DS)和目标域的特征向量F(DT)分别输入至嵌入模块得到中得到F′(DS)和F′(DT);
嵌入向量F′(DS)、F′(DT)输入至第二非线性投影得到f′(DS)和f′(DT),再经过第二分类器G′中得到对应源域和目标域各自的分类预测结果g′(DS)和g′(DT),进而得到域混淆的损失函数Ld1,用以对嵌入模块、第二非线性投影器、第二分类器进行反向训练;
步骤5,得到原型网络,将嵌入模块得到中得到F′(DS)和F′(DT)输入原型网络中,对原型网络进行训练得到原型损失函数LP1、LP2,对损失函数Ld1、Ld2、LP1、LP2进行加权叠加得到总损失函数用于反向训练网络模型;
步骤6,采用训练好的原型网络对待分类图像进行分类。
本发明的特点还在于:
步骤1具体为,源域数据集DS服从某种分布PS(x),类别标签为CS,即
Figure BDA0003550074660000091
目标域数据集DT服从某种分布PT(x),类别标签为CT,即
Figure BDA0003550074660000092
将源域、目标域数据集分批次输入到随机数据增广网络中,随机数据增广网络对原始的源域以及目标域数据集样本旋转、裁剪和加入高斯白噪声变换后恢复至原始输入大小,形成新的样本加入至原始数据集中,从而实现数据集扩充的目的;
步骤2具体为,利用共享的特征提取器F(·)对源域和目标域进行特征提取,特征提取器采用预训练的VGG16深度卷积网络,得到源域的特征向量F(DS)和目标域的特征向量F(DT),然后将得到的特征向量分别输入至第一非线性投影器和嵌入模块等下游支路。
步骤3具体按照以下步骤实施:
步骤3.1、将源域、目标域的特征向量F(DS)和F(DT)输入至第一非线性投影器f(·)中,第一非线性投影器(·)中由全连接层、ReLU激活函数层、全连接层等共三层结构组成,旨在拓展输出空间,保留更多的特征信息,经过第一非线性投影后得到f(DS)和f(DT);
步骤3.2、源域、目标域的特征向量F(DS)和F(DT)经过第一非线性投影后得到f(DS)和f(DT),将其输入至第一分类器G中得到关于源域和目标域各自的分类预测结果g(DS)和g(DT);
步骤3.3、因源域和目标域中每类的数据分布并不完全相同,引入领域判别器D,该领域判别器D由三层全连接层组成,旨在保留各领域中特有的域区分信息,定义在关于源域数据分布PS(x)目标域数据分布PT(x)的领域判别器D、特征提取器F和第一分类器G上的分类损失E,该损失函数形式化定义为:
Figure BDA0003550074660000101
上式中,fi S
Figure BDA0003550074660000102
分别代表源域第i个样本
Figure BDA0003550074660000103
经过第一非线性投影、分类器后得到的特征向量,fi T
Figure BDA0003550074660000104
分别代表目标域第i个样本
Figure BDA0003550074660000105
经过第一非线性投影、分类器后得到的特征向量,设h(·)=(f,g),定义
Figure BDA0003550074660000106
式中
Figure BDA0003550074660000107
表示f和g的外积运算,用于融合第一非线性投影层和第一分类器层的输出,则学习领域独有特征表示的损失函数进一步定义为:
Figure BDA0003550074660000108
采用损失函数Ld2,对共享特征提取器、第一非线性投影器、第一分类器进行反向训练。
如图2,步骤4具体按照以下步骤实施:
步骤4.1、将经特征提取器得到的源域和目标域的特征向量F(DS)、F(DT)输入至嵌入模块F′。嵌入模块由自编码器和注意力机制组成,自编码器以F(DS)、F(DT)为输入,输出对应的降维重构向量。为保持全局域对齐,也即源域和目标域尽可能的实现域混淆,需加入一个强制约束。该约束是由全连接层组成的注意力机制FC(·),注意力分数Sigmoid(FC(F))用于删除任何域特定信息。将自编码器和关注子模块结合在一起,得到嵌入模块的最终输出为F′(DS)和F′(DT);
步骤4.2、将经过嵌入模块后得到域混淆嵌入向量F′(DS)、F′(DT)输入至第二非线性投影得到f′(DS)和f′(DT),再经过第二分类器G′中得到对应源域和目标域各自的分类预测结果g′(DS)和g′(DT);
步骤4.3、定义在源域数据分布PS(x)目标域数据分布PT(x)的嵌入模块F′和第二分类器G′上的域对抗损失E,该损失函数形式化定义为:
Figure BDA0003550074660000111
上式中,fi S′
Figure BDA0003550074660000112
分别代表源域第i个样本
Figure BDA0003550074660000113
经过第二非线性投影、分类器后得到的特征向量,fi T′
Figure BDA0003550074660000114
分别代表目标域第i个样本经过
Figure BDA0003550074660000115
第二非线性投影、分类器后得到的特征向量,设h′(·)=(f′,g′),定义
Figure BDA0003550074660000116
式中
Figure BDA0003550074660000117
仍表示f′和g′的外积运算,则域对抗损失E等价为:
Figure BDA0003550074660000118
步骤4.4、虽然经过嵌入模块后源域、目标域数据尽可能的实现域混淆,但不同样本域混淆的困难程度不同,网络模型对不同的样本应保持一致性,故引入广义熵指数:
Figure BDA0003550074660000119
上式中,M代表每批次训练中的样本个数,Lk代表着第k个样本的对应的损失函数值,
Figure BDA0003550074660000121
代表着一批次样本的平均损失函数值。广义熵指数越大,对分布顶端的差异敏感性越大。令α=2,计算每批次源域、目标域数据对应的广义熵指数并记为wS和wT。则最终通过域对抗以实现域混淆的损失函数定义为:
Figure BDA0003550074660000122
采用损失函数Ld1对嵌入模块、第二非线性投影器、第二分类器进行反向训练。
步骤5具体按照以下步骤实施:
步骤5.1、经过嵌入模块后得到源域和目标域的域混淆嵌入向量F′(DS)、F′(DT),对于每一个类别,取其各自的特征向量的均值作为类别原型Ck,则类别原型定义为:
Figure BDA0003550074660000123
上式中,Sk表示类别为k的样本集合,所有的类别原型组成原型网络。
步骤5.2、通过距离函数d(·)计算测试样本
Figure BDA0003550074660000124
对应的特征向量与各个类别原型之间的平方欧式距离,并利用Softmax函数将其转化为概率值,预预测该样本的类别概率:
Figure BDA0003550074660000125
步骤5.3、利用步骤5.2的类别预测规则,对于源域数据中支持集样本SS,其对应的真实类别标签是c,基于负对数似然定义源域每轮次上的原型损失函数为:
Figure BDA0003550074660000131
同样对于目标域数据中支持集样本ST,其对应的真实类别标签是c,基于负对数似然定义目标域每轮次上的原型损失函数为:
Figure BDA0003550074660000132
步骤5.4、利用步骤3.3、4.4得到的域混淆的损失函数Ld1、Ld1和步骤5.4到的原型损失函数LP1、LP2,定义总损失函数,通过反向训练网络模型参数,总损失定义为:
Ltotal=(LP1+LP2)+λ(Ld1+Ld1) (11)
上式中λ代表学习率超参数,根据网络收敛情况设置为510。
步骤6具体按照以下步骤实施:
在测试阶段,给定目标域测试任务
Figure BDA0003550074660000133
中需要分类的查询集图像SQ,针对N-way、K-shot任务,利用已经训练好的原型网络,对查询集图像SQ进行分类,通过比较查询集图像SQ与各类别原型之间的平方欧式距离,预测其对应的类别。

Claims (6)

1.基于原型网络的深度领域自适应图像分类方法,其特征在于,包括以下步骤:
步骤1,对包括源域数据DS和目标域数据DT的数据集进行扩充;
步骤2、建立共享特征提取器F(·),采用预训练好的共享特征提取器F(·)分别对扩充的源域数据DS和目标域数据DT提取特征,得到特征向量F(DS)和F(DT);
步骤3、构建第一非线性投影器、第一分类器,将特征向量F(DS)、F(DT)输入至第一非线性投影器后再经第一分类器,得到分类损失函数Ld2
步骤4,构建并训练嵌入模块,将特征向量F(DS)和F(DT)分别输入至嵌入模块得嵌入特征F′(DS)和F′(DT)和域混淆的损失函数Ld1
步骤5定义原型网络,将嵌入模块得到中得到F′(DS)和F′(DT)输入原型网络中,对原型网络进行训练得到原型损失函数LP1、LP2,对损失函数Ld1、Ld2、LP1、LP2进行加权叠加得到总损失函数用于反向训练网络模型;
步骤6,采用训练好的原型网络对待分类图像进行分类。
2.如权利要求1所述的基于原型网络的深度领域自适应图像分类方法,其特征在于,所述步骤1具体包括将源域、目标域数据集分批次输入到随机数据增广网络中,随机数据增广网络对原始的源域以及目标域数据集样本旋转、裁剪和加入高斯白噪声变换后恢复至原始输入大小,形成新的样本加入至原始数据集中,从而扩充数据集。
3.如权利要求1所述的基于原型网络的深度领域自适应图像分类方法,其特征在于,所述步骤2中共享的特征提取器F(·)采用预训练的VGG16深度卷积网络。
4.如权利要求1所述的基于原型网络的深度领域自适应图像分类方法,其特征在于,所述步骤3包括:
步骤3.1,第一非线性投影器f(·)中包括依次连接的全连接层、ReLU激活函数层、全连接层,将特征向量F(DS)和F(DT)输入至第一非线性投影器f(·)中得到特征向量f(DS)和f(DT);
步骤3.2,将特征向量f(DS)和f(DT)入至第一分类器G中得到关于源域和目标域各自的分类预测结果g(DS)和g(DT);
步骤3.3,定义领域判别器D,领域判别器D包括依次连接的三层全连接层,定义在关于源域数据分布PS(x)目标域数据分布PT(x)的领域判别器D、特征提取器F和第一分类器G上的分类损失E,该损失函数形式化定义为:
Figure FDA0003550074650000021
上式中,
Figure FDA0003550074650000028
分别代表源域第i个样本
Figure FDA0003550074650000022
经过第一非线性投影、分类器后得到的特征向量,fi T
Figure FDA0003550074650000023
分别代表目标域第i个样本
Figure FDA0003550074650000024
经过第一非线性投影、分类器后得到的特征向量,设h(·)=(f,g),定义
Figure FDA0003550074650000025
式中
Figure FDA0003550074650000026
表示f和g的外积运算,用于融合第一非线性投影层和第一分类器层的输出,则学习领域独有特征表示的损失函数进一步定义为:
Figure FDA0003550074650000027
采用损失函数Ld2,对共享特征提取器、第一非线性投影器、第一分类器进行反向训练。
5.如权利要求1所述的基于原型网络的深度领域自适应图像分类方法,其特征在于,所述步骤4包括:
步骤4.1,嵌入模块包括自编码器,自编码器以F(DS)、F(DT)为输入,输出对应的降维重构向量,在自编码器中加入注意力机制作为强制约束,注意力机制包括全连接层,其中注意力机制采用注意力分数Sigmoid(FC(F))用于删除任何域特定信息;
将特征向量F(DS)、F(DT)输入嵌入模块最终输出嵌入向量F′(DS)和F′(DT);
步骤4.2,将嵌入向量F′(DS)和F′(DT)输入至第二非线性投影器得到f′(DS)和f′(DT);再经过第二分类器G′中得到对应源域和目标域各自的分类预测结果g′(DS)和g′(DT);
步骤4.3、定义在源域数据分布PS(x)目标域数据分布PT(x)的嵌入模块F′和第二分类器G′上的域对抗损失E,该损失函数形式化定义为:
Figure FDA0003550074650000031
上式中,fi S′
Figure FDA0003550074650000032
分别代表源域第i个样本
Figure FDA0003550074650000033
经过第二非线性投影、分类器后得到的特征向量,fi T′
Figure FDA0003550074650000034
分别代表目标域第i个样本经过
Figure FDA0003550074650000035
第二非线性投影、分类器后得到的特征向量,设h′(·)=(f′,g′),定义
Figure FDA0003550074650000036
式中
Figure FDA0003550074650000037
仍表示f′和g′的外积运算,则域对抗损失E等价为:
Figure FDA0003550074650000041
步骤4.4、虽然经过嵌入模块后源域、目标域数据尽可能的实现域混淆,但不同样本域混淆的困难程度不同,网络模型对不同的样本应保持一致性,故引入广义熵指数:
Figure FDA0003550074650000042
上式中,M代表每批次训练中的样本个数,Lk代表着第k个样本的对应的损失函数值,
Figure FDA0003550074650000043
代表着一批次样本的平均损失函数值,令α=2,计算每批次源域、目标域数据对应的广义熵指数并记为wS和wT,则最终通过域对抗以实现域混淆的损失函数定义为:
Figure FDA0003550074650000044
采用损失函数Ld1对嵌入模块、第二非线性投影器、第二分类器进行反向训练。
6.如权利要求1所述的基于原型网络的深度领域自适应图像分类方法,其特征在于,所述步骤5包括:
步骤5.1,对嵌入向量F′(DS)、F′(DT),对于每一个类别,取其各自的特征向量的均值作为类别原型Ck
Figure FDA0003550074650000045
上式中,Sk表示类别为k的样本集合,所有的类别原型组成原型网络;
步骤5.2、通过距离函数d(·)计算测试样本
Figure FDA0003550074650000054
对应的特征向量与各个类别原型之间的平方欧式距离,并利用Softmax函数将其转化为概率值,预测该样本的类别概率:
Figure FDA0003550074650000051
步骤5.3、利用步骤5.2的类别预测规则,对于源域数据中支持集样本SS,其对应的真实类别标签是c,基于负对数似然定义源域每轮次上的原型损失函数为:
Figure FDA0003550074650000052
同样对于目标域数据中支持集样本ST,其对应的真实类别标签是c,基于负对数似然定义目标域每轮次上的原型损失函数为:
Figure FDA0003550074650000053
步骤5.4、利用步骤3.3、4.4得到的域混淆的损失函数Ld1、Ld1和步骤5.4到的原型损失函数LP1、LP2,定义总损失函数,通过反向训练网络模型参数,总损失定义为:
Ltotal=(LP1+LP2)+λ(Ld1+Ld1) (11)
上式中λ代表学习率超参数,设置为5~10。
CN202210259161.7A 2022-03-16 2022-03-16 基于原型网络的深度领域自适应图像分类方法 Pending CN114611617A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210259161.7A CN114611617A (zh) 2022-03-16 2022-03-16 基于原型网络的深度领域自适应图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210259161.7A CN114611617A (zh) 2022-03-16 2022-03-16 基于原型网络的深度领域自适应图像分类方法

Publications (1)

Publication Number Publication Date
CN114611617A true CN114611617A (zh) 2022-06-10

Family

ID=81863290

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210259161.7A Pending CN114611617A (zh) 2022-03-16 2022-03-16 基于原型网络的深度领域自适应图像分类方法

Country Status (1)

Country Link
CN (1) CN114611617A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115861720A (zh) * 2023-02-28 2023-03-28 人工智能与数字经济广东省实验室(广州) 一种小样本亚类图像分类识别方法
CN116910571A (zh) * 2023-09-13 2023-10-20 南京大数据集团有限公司 一种基于原型对比学习的开集域适应方法及系统
CN117132841A (zh) * 2023-10-26 2023-11-28 之江实验室 一种保守渐进的领域自适应图像分类方法和装置

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115861720A (zh) * 2023-02-28 2023-03-28 人工智能与数字经济广东省实验室(广州) 一种小样本亚类图像分类识别方法
CN116910571A (zh) * 2023-09-13 2023-10-20 南京大数据集团有限公司 一种基于原型对比学习的开集域适应方法及系统
CN116910571B (zh) * 2023-09-13 2023-12-08 南京大数据集团有限公司 一种基于原型对比学习的开集域适应方法及系统
CN117132841A (zh) * 2023-10-26 2023-11-28 之江实验室 一种保守渐进的领域自适应图像分类方法和装置
CN117132841B (zh) * 2023-10-26 2024-03-29 之江实验室 一种保守渐进的领域自适应图像分类方法和装置

Similar Documents

Publication Publication Date Title
CN114611617A (zh) 基于原型网络的深度领域自适应图像分类方法
Kye et al. Meta-learned confidence for few-shot learning
Xia et al. Metalearning-based alternating minimization algorithm for nonconvex optimization
Lin et al. Structure-coherent deep feature learning for robust face alignment
CN108830301A (zh) 基于锚图结构的双拉普拉斯正则化的半监督数据分类方法
Reddy et al. AdaCrowd: Unlabeled scene adaptation for crowd counting
Li et al. Image manipulation localization using attentional cross-domain CNN features
Singh Gill et al. Efficient image classification technique for weather degraded fruit images
Lin et al. Rethinking crowdsourcing annotation: partial annotation with salient labels for multilabel aerial image classification
CN111144500A (zh) 基于解析高斯机制的差分隐私深度学习分类方法
Zheng et al. Detach and unite: A simple meta-transfer for few-shot learning
Feng et al. Introspective robot perception using smoothed predictions from bayesian neural networks
Truong et al. Domain generalization via universal non-volume preserving approach
Zhuang et al. Local label propagation for large-scale semi-supervised learning
CN117011219A (zh) 物品质量检测方法、装置、设备、存储介质和程序产品
Zhang et al. VESC: a new variational autoencoder based model for anomaly detection
CN115661539A (zh) 一种嵌入不确定性信息的少样本图像识别方法
Sagawa et al. Gradual domain adaptation via normalizing flows
Wang et al. A dynamic feature weighting method for mangrove pests image classification with heavy-tailed distributions
Liang et al. AMEMD-FSL: fuse attention mechanism and earth mover’s distance metric network to deep learning for few-shot image recognition
Azeez Joodi et al. A New Proposed Hybrid Learning Approach with Features for Extraction of Image Classification
CN117131936B (zh) 一种基于多层级类比推理的知识图谱嵌入方法
Liu et al. Hybrid learning network: a novel architecture for fast learning
Bai et al. Universal replication of chaotic characteristics by classical and quantum machine learning
Bhattacharjee et al. Addressing Class Imbalance in Fake News Detection with Latent Space Resampling

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination