CN116796184A - 一种基于图原型网络和实例对比的领域泛化方法 - Google Patents
一种基于图原型网络和实例对比的领域泛化方法 Download PDFInfo
- Publication number
- CN116796184A CN116796184A CN202310289243.0A CN202310289243A CN116796184A CN 116796184 A CN116796184 A CN 116796184A CN 202310289243 A CN202310289243 A CN 202310289243A CN 116796184 A CN116796184 A CN 116796184A
- Authority
- CN
- China
- Prior art keywords
- domain
- node
- class
- graph
- nodes
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 29
- 238000009826 distribution Methods 0.000 claims abstract description 30
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 17
- 238000005096 rolling process Methods 0.000 claims abstract description 7
- 238000012549 training Methods 0.000 claims description 24
- 230000006870 function Effects 0.000 claims description 20
- 239000011159 matrix material Substances 0.000 claims description 18
- 238000000605 extraction Methods 0.000 claims description 12
- 230000009466 transformation Effects 0.000 claims description 7
- 230000004913 activation Effects 0.000 claims description 6
- 230000000295 complement effect Effects 0.000 claims description 6
- 238000012935 Averaging Methods 0.000 claims description 5
- 238000005259 measurement Methods 0.000 claims description 5
- 239000013598 vector Substances 0.000 claims description 5
- 241000540325 Prays epsilon Species 0.000 claims description 3
- 238000012795 verification Methods 0.000 claims description 3
- 238000010586 diagram Methods 0.000 claims description 2
- 239000000284 extract Substances 0.000 claims description 2
- 238000011524 similarity measure Methods 0.000 claims description 2
- 230000015556 catabolic process Effects 0.000 abstract description 2
- 238000006731 degradation reaction Methods 0.000 abstract description 2
- 230000000644 propagated effect Effects 0.000 abstract description 2
- 230000000007 visual effect Effects 0.000 abstract description 2
- 241000282326 Felis catus Species 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 238000004821 distillation Methods 0.000 description 2
- 238000003064 k means clustering Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000007796 conventional method Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 230000008569 process Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/90—Details of database functions independent of the retrieved data types
- G06F16/901—Indexing; Data structures therefor; Storage structures
- G06F16/9024—Graphs; Linked lists
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Databases & Information Systems (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
针对分布于未知域的图像的分类预测问题,源数据集和目标数据集的不匹配分布将导致源模型在目标域的性能显著下降,而目前已提出的针对跨域视觉表示的分布对齐方法没有考虑到跨域内部数据结构的差异。本发明通过样本结构特征的相似性,利用样本的CNN特征构建密集连接实例图。每个节点对应样本的CNN特征,该特征由标准卷积网络提取。然后,将图卷积网络应用于实例图,并将图结构信息沿着设计的网络学习加权图的边缘进行传播以更新节点。本发明利用类均值构造类原型进行分类,还考虑了实例节点的比较监督学习,以学习实例节点上类语义信息。本发明为了更好地学习和减少领域之间类别语义信息的差异,采用软标签进行领域之间知识蒸馏。
Description
技术领域
本发明属于机器学习域泛化领域,具体涉及一种基于图原型网络和实例对比的领域泛化方法。
背景技术
通常,大多数机器学习模型先在源域数据集上进行训练,然后将训练结果在目标域数据集上进行预测,其中往往隐含地假设源域数据集和目标域数据集都遵循相同的分布。然而,此类假设在现实世界中往往不能成立。例如,对于基于不同角度、设备、环境等条件收集的多领域图像,在一个领域上通过训练获得的分类器在其他领域的运用效果不佳。这里,将某一领域的知识迁移到其他不可见领域的过程被称为领域泛化。在迁移学习中,域泛化问题的困难主要来自于两方面,其一是不同源数据集的分布差异,其二是目标域的不可知性。域泛化旨将在源域数据集上通过训练获得的模型直接推广到具有不同分布的不可见目标域,而无需在目标域数据集上进行再训练或微调。领域泛化解决针对分布于未知域的图像的分类预测问题。源数据集和目标数据集的不匹配分布将导致源模型在目标域的性能显著下降。目前已提出的针对跨域视觉表示的分布对齐方法没有考虑到跨域内部数据结构的差异,并受制于不充分的对齐跨域表示。例如,深度对抗性自适应方法仅迫使全域分布的对齐,但可能会丢失每个类别的关键语义类标签信息,同时必须在训练中使用域标签进行监督学习。即便使用完美的混淆对齐,也不能保证特征空间中具有相同类标签的非同域样本的相邻映射。然而,对于传统的与数据结构分布对齐相关的方法,虽然可以减少域之间分布差异,并保留原始空间属性,但很难有效地模型化数据结构信息并集成到现有的深度网络中。
发明内容
基于图像原型网络和案例比较网络的图像分类方法的总体框架如图1所示。为了对深度网络下的数据结构进行建模,通过样本结构特征的相似性,利用样本的CNN特征构建密集连接实例图。每个节点对应样本的CNN特征,该特征由标准卷积网络(例如ResNet)提取。然后,将图卷积网络(GCN)应用于实例图,并将图结构信息沿着设计的网络学习加权图的边缘进行传播以更新节点。一方面,利用类均值构造类原型进行分类;另一方面,还考虑了实例节点的比较监督学习,以学习实例节点上类语义信息。同时,为了更好地学习和减少领域之间类别语义信息的差异,采用软标签(logit)进行领域之间知识蒸馏,即缩小Kullback-Leibler(KL)散度。知识蒸馏将具有相同类别标签但不同域的数据的预测分布集合与每个预测分布相匹配,通过使用多个域累积的有意义误差的集合惩罚样本的预测来增加模型预测的熵,鼓励模型收敛到宽局部最小值。本发明提出的基于图原型网络和实例对比的领域泛化方法的具体步骤如下:
步骤1:获取图像样本及其标签,并构建图像特征提取模型;
获取图像样本构建初始图像数据集,将所述初始图像数据集划分为源域数据集M={M1,...Mi,...,Mm}和目标域数据集T,其中Mm表示第m域数据集;所述目标域数据集在所述图像特征提取模型的训练过程中是不可访问的;
源域数据集M划分为训练集和验证集,将所述源域数据集M中的图像进行数据增强;
获取预训练模型,基于预训练模型构建所述图像特征提取模型;-通过所述图像特征提取模型,提取源域数据集M中的特征,作为图输入特征X;
步骤2:建立图卷积网络并获取类原型表示;
将提取源域数据集M的特征的图结构信息定义为G=<V,E,Z>,其中V={v1,...,vn}是n个节点的集合,是通过两层GCN层提取获得的节点特征,E={e11,...,eij,...,enn}表示节点之间距离;其中,采用余弦相似度/>表示节点i和节点j之间距离;
通过节点间距离E构造包含n个节点的无向图邻接矩阵A,将所述无向图邻接矩阵A转换其中,/>为度矩阵,/>j为节点i的邻接节点编号;
根据节点之间相似度,构建归一化后的邻接矩阵其中,I是单位矩阵;
对于一个给定的包含n个节点的无向图邻接矩阵A∈Rn×n,图卷积的线性变换取决于图输入特征X∈Rk×n与滤波器W∈Rk×d;
其中,图输入特征X中的列向量Xi∈Rk是节点的集合V中第i个节点的特征表示,d表示输出的特征维度;
按如下式所示的方法,进行两层的GCN处理得到嵌入特征
其中,σ为激活函数,表示为第i个节点在第l层的输出,/>是图卷积输入;
之后利用图卷积网络生成的嵌入特征计算类原型P∈Rc×d的表示,/>表示图卷积网络的第m源域中第i个节点输出;
所述类原型的定义为被同一类的节点紧密包围,这样同一类的节点就可以表示自己的类;第m源域的第c类的原型通过以下方式计算:
其中PROTP是计算类原型P的表示的方式,是第m源域中第i个节点的表示,mc为第m源域的第c类,vi为第m域的第c类的第i个节点,具体公式如下:
一般计算类原型时假设每个类只使用一个原型来表示,但原型分布不是单峰时,这种类表示是不充分的。此时,每个类可以使用多原型来表示,并用置换不变函数代替均值(如K-means聚类)。为了简便起见,按平均值进行计算。
将所述节点从原始嵌入空间投影到另一个距离空间来学习一个距离度量表示;
步骤3.通过比较节点的学习距离度量表示与类原型的距离度量表示进行分类;
计算距离度量损失:
由图卷积学习到的嵌入节点计算每个节点到每个类原型/>的距离度量表示:
其中,为第m源域中每个节点和每个类原型之间嵌入差异;
将节点嵌入差异联系到所有类原型,并按如下式(5)所示的方法,应用线性变换f对嵌入差异的不同维度给予不同程度的关注,同时自适应地提取嵌入差异信息,如下式所示:;
距离度量表示g表示节点v到所有类原型的距离信息,用于定义了第m源域中节点与所有类原型的相对位置,c∈C表示第c个类;按如下式所示的方法,将距离度量表示通过连接层concat连接起来,以计算在所有源域M中的类原型和节点的距离度量表示:
G=concat(g1,…,gm) (6)
然后计算第i个节点vi的softmax的值:
其中表示在整个源域中节点vi对类c的距离度量,P(y=c|vi)给出了节点vi对类c的预测概率分布,按如下式所示的方法构建交叉熵分类损失:
步骤4:嵌入空间实例节点监督对比学习;
按如下式所示的方法构建领域监督对比损失:
其中,I是所有小批量样本集合,i是一个锚点,p∈P(i)是I中与第i个样本相同的正样本,|P(i)|是集合P(i)中样本数量,是不同于第i个样本的其他类别样本且与第i个样本为同一域的负样本,/>表示对图卷积输出进行l2正则化的特征,τ表示温度参数,f表示不同节点的相似性度量,相同类别标签的样本表示/>和/>彼此靠近,而不同类别标签的样本表示/>与/>彼此远离;
步骤5:域不变性知识蒸馏;
按如下式所示的方法,通过域不变性知识蒸馏学习到来自不同域的节点信息之间的互补知识,其中,Xc表示来自各个域的具有相同类标签c的所有样本集合;通过对Xc取平均值获得相应的软标签值:
其中,h是来自GNN编码器最后一层的学习节点嵌入,表示第i个节点的GNN输出;
定义来自Xc的预测分布softmax函数为:
计算域间知识蒸馏,定义域损失函数为Lkd:通过KL散度来最小化域间语义层面的差异;
其中,M表示域的集合,m表示域的数量,表示第m域中第i节点的图像分类概率,τ表示温度参数;Dkl表示KL散度,用于计算节点与均值的输出分布差异;
步骤6:定义总目标损失函数如下;
其中,γ和为权重因子;
通过总目标损失函数对图卷积网络进行训练,实现领域泛化。
作为优选,所述步骤1中,所述数据增强的方式包括裁切、反转。
作为优选,所述步骤1中,所述预训练模型为ResNet或者AlexNet模型中的任意一种。
作为优选,所述步骤2中,图输入特征X由特征提取器ResNet获得,所述GCN的激活函数采用ReLu。
本发明的有益效果
本发明提出一种基于图原型网络和实例对比的领域泛化方法。与现有方法相比,该方法不仅强调了整个全局域间的原型分类,而且结合了内部样本的结构化信息,并对数据结构信息进行建模。通过构建域不变类原型分类和实例节点监督对比来最大化不同类之间裕度,明确地最大化目标领域不同类别之间差距。不仅可以缓解不同域间分布的差异,同时也保留了类别语义信息。
附图说明
为了更清楚地说明本发明具体实施方式或现有技术中的技术方案,下面将对具体实施方式或现有技术描述中所需要使用的附图作简单介绍,后文将参照附图以示例性而非限制性的方式详细描述本发明的一些具体实施例。附图中相同的附图标记标示了相同或类似的部件或部分。本领域技术人员应该理解,这些附图未必是按比例绘制的。
附图中:
图1是本发明基于图原型网络和实例对比网络的图像分类方法的原理示意图。
图2是图原型网络对类原型分类的流程图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅用以解释本申请,并不用于限定本申请。
步骤1:获取图像样本及其标签,并构建图像特征提取模型;
如图2所示,获取图像样本构建初始图像数据集,将所述初始图像数据集划分为源域数据集M={M1,...Mi,...,Mm}和目标域数据集T,其中Mm表示第m域数据集;所述目标域数据集在所述图像特征提取模型的训练过程中是不可访问的;
源域数据集M划分为训练集和验证集,将所述源域数据集M中的图像进行数据增强,所述数据增强的方式包括裁切、反转;
获取预训练模型,基于预训练模型构建所述图像特征提取模型;所述预训练模型为ResNet或者AlexNet模型中的任意一种;
通过所述图像特征提取模型,提取源域数据集M中的特征,作为图输入特征X;
步骤2:建立图卷积网络并获取类原型表示;
将提取源域数据集M的特征的图结构信息定义为G=<V,E,Z>,其中V={v1,...,vn}是n个节点的集合,是通过两层GCN层提取获得的节点特征,E={e11,...,eij,...,enn}表示节点之间距离;其中,采用余弦相似度/>表示节点i和节点j之间距离;
通过节点间距离E构造包含n个节点的无向图邻接矩阵A,将所述无向图邻接矩阵A转换其中,/>为度矩阵,/>j为节点i的邻接节点编号;
根据节点之间相似度,构建归一化后的邻接矩阵其中,I是单位矩阵;
对于一个给定的包含n个节点的无向图邻接矩阵A∈Rn×n,图卷积的线性变换取决于图输入信息X∈Rk×n与滤波器W∈Rk×d;
其中,图输入特征X中的列向量Xi∈Rk是节点的集合V中第i个节点的特征表示,d表示输出的特征维度;
按如下式所示的方法,进行两层的GCN处理得到嵌入特征
其中,σ为激活函数,表示为第i个节点在第l层的输出,/>X是图卷积输入。此外,节点特征X由特征提取器ResNet获得,所述GCN的激活函数采用ReLu;
之后利用图卷积网络生成的嵌入特征计算类原型P∈RC×d的表示,/>表示图卷积网络的第m源域中第i个节点输出;
所述类原型的定义为被同一类的节点紧密包围,这样同一类的节点就可以表示自己的类。第m源域的第c类的原型通过以下方式计算:
其中PROTP是计算类原型P的表示的方式,是第m源域中第i个节点的表示,mc为第m源域的第c类,vi为第m域的第c类的第i个节点,具体公式如下:
一般计算类原型时假设每个类只使用一个原型来表示,但原型分布不是单峰时,这种类表示是不充分的。此时,每个类可以使用多原型来表示,并用置换不变函数代替均值(如K-means聚类)。为了简便起见,按平均值进行计算;
由于原型是每个类的表示,可以通过选择最近的原型来对节点分类。然而,直接将类嵌入节点的平均向量作为原型可能不会提供预期的结果。因此,可以将节点从原始嵌入空间投影到另一个距离空间来学习一个距离度量表示,而不是直接基于其最近的原型对节点进行分类;
步骤3:通过比较节点的学习距离度量表示与类原型的距离度量表示进行分类;
计算距离度量损失:
由图卷积学习到的节点生成的嵌入节点计算每个节点到每个类原型的距离度量表示:
其中,为第m源域中每个节点和每个类原型之间嵌入差异;将节点嵌入差异联系到所有类原型,并按如下式所示的方法,应用线性变换f对嵌入差异的不同维度给予不同程度的关注,同时自适应地提取有用的嵌入差异信息,如下式所示;
距离度量表示g表示节点v到所有类原型的距离信息,用于定义了第m源域中节点与所有类原型的相对位置,c∈C表示第c个类。按如下式所示的方法,将距离度量表示通过连接层concat连接起来,以计算在所有源域M中的类原型和节点的距离度量表示:
G=concat(g1,…,gm) (6)
然后计算第i个节点vi的softmax的值:
其中表示在整个源域中节点vi对类c的距离度量,P(y=c|vi)给出了节点vi对类c的预测概率分布,按如下式所示的方法构建交叉熵分类损失:
步骤4:嵌入空间实例节点监督对比学习;
上一步骤的分类损失原型仅考虑每个类的中心点,忽略了类别内部的变化。相反,样本之间实例对比包含更细粒度的嵌入网络的实例节点特征变化。为了确保领域不变性,同时增加类别信息的可分性,计算了实例领域监督对比损失。传统的监督对比损失只考虑了类标签组成的正集和负集,但在领域泛化中传统的监督对比损却未考虑多领域的影响。其中,某些域的正负样本会对其他一些域的正样本造成错误判断,这可能是由于样本中域信息相对于类别信息的比重过大,故从负集中排除了具有不同域的样本;
其中,I是所有小批量样本集合,i是一个锚点,p∈P(i)是I中与第i个样本相同的正样本,|P(i)|是集合P(i)中样本数量,是不同于第i个样本的其他类别样本且与第i个样本为同一域的负样本,/>表示对图卷积输出进行l2正则化的特征,τ表示温度参数,f表示不同节点的相似性度量;
相同类别标签的样本表示和/>彼此靠近,而不同类别标签的样本表示/>与/>彼此远离。当对比损失监督被直接用于域泛化任务时,性能会下降。具体地说,由于来自不同域的正样本与该锚点受到域的影响而被推离,特征空间变得具有域区分性。为使特征空间更加适合领域泛化,这里提出新的对比损失监督,使得特征提取器不仅产生对类别标签有区分性的特征,而且也通过吸引来自不同领域的正样本来更好地提取领域不变性的特征;
步骤5:域不变性知识蒸馏;
由于上述分别从类原型和实例节点对比学习来训练的特征提取器没有直接考虑到域之间的差距。类原型是使用多域样本的平均向量,而实例节点对比则是去掉非本域的负样本。为此,引入了域不变性蒸馏去缩小源域和目标域的差距。域不变性蒸馏的目的是使用来自不同领域的节点信息来传递互补知识。在一组小批量样本中,通过对类标签相同的多个软标签值进行平均并依赖更多的互补知识来生成集合。由于不同领域表现出不同的类间关系,因此每个样本提供的其自身或其特定领域的信息有助于构建互补知识,而且还可以利用该知识进行监督以学习领域不变信息。Xc表示来自各个域的具有相同类标签c的所有样本集合。通过对Xc取平均值获得相应的软标签值:
其中,h是来自GNN编码器最后一层的学习节点嵌入,表示第i个节点的GNN输出;
定义来自Xc的预测分布softmax函数为:
计算域间知识蒸馏,定义域损失函数为Lkd:通过KL散度来最小化域间语义层面的差异;
其中,M表示域的集合,m表示域的数量,表示第m域中第i节点的图像分类概率,τ表示温度参数;Dkl表示KL散度,用于计算节点与均值的输出分布差异;
知识蒸馏原本是一种模型压缩的方法,通过最小化大模型的输出类概率(软标签)分类损失,训练较小的“学生”模型来模拟大“教师”模型的方法。对于领域泛化来说,可以把类标签相同的样本进行平均来生成的集合老师模型,通过与特定域的最小化类概率分布来减小域间差异。软标签提供了关于图像语义的更多信息。例如,给定CIFAR-10中的图像,狗图像是猫的类概率将远高于图像是汽车的类概率。因此,软标签为网络提供了关于“狗和猫图像是类似的”语义信息的额外提示。知识蒸馏还改善了损失情况,并帮助找到网络中的平坦的最小值,从而提高了泛化能力。知识蒸馏已被证明可以放大希尔伯特空间中的正则化,提升了泛化能力;
步骤6:定义总目标损失函数如下;
γ和为权重因子。整体训练以端到端的方式进行,其中后两项γLkd和/>只在训练过程中计算。Lclass是通过GNN原型学习获得的原型监督分类损失。Lkd是分类器上域不变知识蒸馏,反映了域间分布差异。Linstance是经过特征层面的图像实例节点网络结构分析而感知的实例类别信息的领域监督对比损失。
实施例:
为了对深度网络下的数据结构建模,首先使用标准卷积网络(例如ResNet)提取样本。然后,利用图卷积网络(GCN)获取嵌入空间中节点特征。之后,一方面利用嵌入式空间中节点特征构建类原型,并通过类原型进行分类;另一方面,通过对比监督学习从嵌入空间的节点特征中获得具有类别语义的信息。同时,为了减少多领域数据分布的差异,使用软标签缩小KL散度,进行域不变的知识蒸馏。
1)假定源域数据集包括M1,M2,M3三个数据集,每个域包含k张图片。输入多个域的图像样本数据构建源域的样本的图结构信息/>首先将包含n个节点的无向图邻接矩阵A可被转换为表达性更强的形式/>为度矩阵,/>j为节点i的所有邻接节点编号)。根据节点之间相似度,构建归一化后的邻接矩阵/>(I是单位矩阵)。通过公式(1)进行两层的GCN变换得到嵌入特征之后通过公式(2,3)计算各个域的类原型/>
2)然后在给定节点v中嵌入节点h∈Rd,通过公式(4)计算每个节点到每个类原型g∈Rd的距离度量表示公式(5)将节点嵌入差异联系到所有类原型,得到节点到所有类原型的距离信息,精确地定义了节点与所有类原型的相对位置。公式(6,7,8)通过连接层将类原型连接,计算交叉熵损失。
3)同时在通过两层GCN获得的嵌入空间中,公式(9)计算了嵌入空间实例节点监督对比损失,使得相同类别标签的样本特征彼此靠近,而不同的样本特征彼此远离,缓解不同域间分布的差异,同时也保留了类别语义信息。实例节点监督对比损失一方面考虑了类别内部的变化,另一方面也考虑了多领域的影响,其中某些域的正负样本会对其他一些域的正样本造成错误判断。
4)为了缩小领域之间的域分布差异,公式(10,11,12)通过域不变性知识蒸馏学习到来自不同域的节点信息之间的互补知识,缓解不同域间分布差异。
5)最后公式(13)定义了总目标损失函数。
以上所述,仅为本发明部分具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本领域的人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。
Claims (4)
1.一种基于图原型网络和实例对比的领域泛化方法,其特征在于,包括以下步骤:
步骤1:获取图像样本及其标签,并构建图像特征提取模型;
获取图像样本构建初始图像数据集,将所述初始图像数据集划分为源域数据集M={M1,...Mi,...,Mm}和目标域数据集T,其中Mm表示第m域数据集;所述目标域数据集在所述图像特征提取模型的训练过程中是不可访问的;
源域数据集M划分为训练集和验证集,将所述源域数据集M中的图像进行数据增强;
获取预训练模型,基于预训练模型构建所述图像特征提取模型;
通过所述图像特征提取模型,提取源域数据集M中的特征,作为图输入特征X;
步骤2:建立图卷积网络并获取类原型表示;
将提取源域数据集M的特征的图结构信息定义为G=<V,E,Z>,其中V={v1,...,vn}是n个节点的集合,是通过两层GCN层提取获得的节点特征,E={e11,...,eij,...,enn}表示节点之间距离;其中,采用余弦相似度/>表示节点i和节点j之间距离;
通过节点间距离E构造包含n个节点的无向图邻接矩阵A,将所述无向图邻接矩阵A转换其中,/>为度矩阵,/>j为节点i的邻接节点编号;
根据节点之间相似度,构建归一化后的邻接矩阵其中,I是单位矩阵;
对于一个给定的包含n个节点的无向图邻接矩阵A∈Rn×n,图卷积的线性变换取决于图输入特征X∈Rk×n与滤波器W∈Rk×d;
其中,图输入特征X中的列向量Xi∈Rk是节点的集合V中第i个节点的特征表示,d表示输出的特征维度;
按如下式所示的方法,进行两层的GCN处理得到嵌入特征
其中,σ为激活函数,表示为第i个节点在第l层的输出,/>是图卷积输入;
之后利用图卷积网络生成的嵌入特征计算类原型P∈Rc×d的表示,/>表示图卷积网络的第m源域中第i个节点输出;
所述类原型的定义为被同一类的节点紧密包围,这样同一类的节点就可以表示自己的类;第m源域的第c类的原型通过以下方式计算:
其中PROTP是计算类原型P的表示的方式,是第m源域中第i个节点的表示,mc为第m源域的第c类,vi为第m域的第c类的第i个节点,具体公式如下:
将所述节点从原始嵌入空间投影到另一个距离空间来学习一个距离度量表示;
步骤3:通过比较节点的学习距离度量表示与类原型的距离度量表示进行分类;
计算距离度量损失:
由图卷积学习到的嵌入节点计算每个节点到每个类原型/>的距离度量表示:
其中,为第m源域中每个节点和每个类原型之间嵌入差异;
将节点嵌入差异联系到所有类原型,并按如下式(5)所示的方法,应用线性变换f对嵌入差异的不同维度给予不同程度的关注,同时自适应地提取嵌入差异信息,如下式所示:
距离度量表示g表示节点v到所有类原型的距离信息,用于定义了第m源域中节点与所有类原型的相对位置,c∈C表示第c个类;按如下式所示的方法,将距离度量表示通过连接层concat连接起来,以计算在所有源域M中的类原型和节点的距离度量表示:
G=concat(g1,…,gm) (6)
然后计算第i个节点vi的softmax的值:
其中表示在整个源域中节点vi对类c的距离度量,P(y=c|vi)给出了节点vi对类c的预测概率分布,按如下式所示的方法构建交叉熵分类损失:
步骤4:嵌入空间实例节点监督对比学习;
按如下式所示的方法构建领域监督对比损失:
其中,I是所有小批量样本集合,i是一个锚点,p∈P(i)是I中与第i个样本相同的正样本,|P(i)|是集合P(i)中样本的数量,是不同于第i个样本的其他类别样本且与第i个样本为同一域的负样本,/>表示对图卷积输出进行l2正则化的特征,τ表示温度参数,f表示不同节点的相似性度量,相同类别标签的样本表示/>和/>彼此靠近,而不同类别标签的样本表示/>与/>彼此远离;
步骤5:域不变性知识蒸馏;
按如下式所示的方法,通过域不变性知识蒸馏学习到来自不同域的节点信息之间的互补知识,其中,Xc表示来自各个域的具有相同类标签c的所有样本集合;通过对Xc取平均值获得相应的软标签值:
其中,h是来自GNN编码器最后一层的学习节点嵌入,表示第i个节点的GNN输出;
定义来自Xc的预测分布softmax函数为:
计算域间知识蒸馏,定义域损失函数为Lkd:通过KL散度来最小化域间语义层面的差异;
其中,M表示域的集合,m表示域的数量,表示第m域中第i节点的图像分类概率,τ表示温度参数;Dkl表示KL散度,用于计算节点与均值的输出分布差异;
步骤6:定义总目标损失函数如下;
其中,γ和为权重因子;
通过总目标损失函数对图卷积网络进行训练,实现领域泛化。
2.如权利要求1所述的一种基于图原型网络和实例对比的领域泛化方法,其特征在于,所述步骤1中,所述数据增强的方式包括裁切、反转。
3.如权利要求1所述的一种基于图原型网络和实例对比的领域泛化方法,其特征在于,所述步骤1中,所述预训练模型为ResNet或者AlexNet模型中的任意一种。
4.如权利要求1所述的一种基于图原型网络和实例对比的领域泛化方法,其特征在于,所述步骤2中,图输入特征X由特征提取器ResNet获得,所述GCN的激活函数采用ReLu。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310289243.0A CN116796184A (zh) | 2023-03-23 | 2023-03-23 | 一种基于图原型网络和实例对比的领域泛化方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310289243.0A CN116796184A (zh) | 2023-03-23 | 2023-03-23 | 一种基于图原型网络和实例对比的领域泛化方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116796184A true CN116796184A (zh) | 2023-09-22 |
Family
ID=88048802
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310289243.0A Pending CN116796184A (zh) | 2023-03-23 | 2023-03-23 | 一种基于图原型网络和实例对比的领域泛化方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116796184A (zh) |
-
2023
- 2023-03-23 CN CN202310289243.0A patent/CN116796184A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111814854B (zh) | 一种无监督域适应的目标重识别方法 | |
CN113378632B (zh) | 一种基于伪标签优化的无监督域适应行人重识别方法 | |
CN108132968B (zh) | 网络文本与图像中关联语义基元的弱监督学习方法 | |
CN110363282B (zh) | 一种基于图卷积网络的网络节点标签主动学习方法和系统 | |
CN110909820A (zh) | 基于自监督学习的图像分类方法及系统 | |
CN113313232B (zh) | 一种基于预训练和图神经网络的功能脑网络分类方法 | |
CN106682696A (zh) | 基于在线示例分类器精化的多示例检测网络及其训练方法 | |
CN113469186B (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
TWI780567B (zh) | 對象再識別方法、儲存介質及電腦設備 | |
CN114692732B (zh) | 一种在线标签更新的方法、系统、装置及存储介质 | |
CN104268546A (zh) | 一种基于主题模型的动态场景分类方法 | |
CN113065409A (zh) | 一种基于摄像分头布差异对齐约束的无监督行人重识别方法 | |
CN116910571B (zh) | 一种基于原型对比学习的开集域适应方法及系统 | |
Cheng et al. | Leveraging semantic segmentation with learning-based confidence measure | |
CN112183464A (zh) | 基于深度神经网络和图卷积网络的视频行人识别方法 | |
CN115439685A (zh) | 一种小样本图像数据集划分方法及计算机可读存储介质 | |
CN117690098A (zh) | 一种基于动态图卷积的开放驾驶场景下多标签识别方法 | |
CN116977710A (zh) | 一种遥感图像长尾分布目标半监督检测方法 | |
CN116433909A (zh) | 基于相似度加权多教师网络模型的半监督图像语义分割方法 | |
CN117523295A (zh) | 基于类引导元学习的无源域适应的图像分类方法 | |
CN111930981A (zh) | 一种草图检索的数据处理方法 | |
CN117152427A (zh) | 基于扩散模型和知识蒸馏的遥感图像语义分割方法与系统 | |
Qin | Application of efficient recognition algorithm based on deep neural network in English teaching scene | |
Gori et al. | Semantic video labeling by developmental visual agents | |
CN116681128A (zh) | 一种带噪多标签数据的神经网络模型训练方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |