CN114676755A - 基于图卷积网络的无监督域自适应的分类方法 - Google Patents

基于图卷积网络的无监督域自适应的分类方法 Download PDF

Info

Publication number
CN114676755A
CN114676755A CN202210208723.5A CN202210208723A CN114676755A CN 114676755 A CN114676755 A CN 114676755A CN 202210208723 A CN202210208723 A CN 202210208723A CN 114676755 A CN114676755 A CN 114676755A
Authority
CN
China
Prior art keywords
domain
feature
source
source domain
classification
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210208723.5A
Other languages
English (en)
Inventor
吴飞
魏鹏飞
高广谓
胡长晖
季一木
蒋国平
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University of Posts and Telecommunications
Original Assignee
Nanjing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University of Posts and Telecommunications filed Critical Nanjing University of Posts and Telecommunications
Priority to CN202210208723.5A priority Critical patent/CN114676755A/zh
Publication of CN114676755A publication Critical patent/CN114676755A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/253Fusion techniques of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Molecular Biology (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请涉及一种基于图卷积网络的无监督域自适应的分类。所述方法包括:获取源域中的样本数据和目标域中样本数据作为训练数据;根据源域和目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型;输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。提高了基于图卷积的无监督域自适应模型性能。

Description

基于图卷积网络的无监督域自适应的分类方法
技术领域
本申请涉及深度学习技术领域,特别是涉及一种基于图卷积网络的无监督域自适应的分类方法。
背景技术
无监督域自适应任务,即利用源域的信息来辅助完成目标域上的任务,其中源域的样本是标记的或者是部分标记的,目标域的样本是无标记的。无监督域自适应的主要挑战是如何对齐源域和目标域的数据分布。
对于无监督域适应任务,一般深度学习方法通常会将源域和目标域中的样本转换到同一公共空间中。例如曾等人为了减少公共空间中源域和目标域之间的分布差异在网络中共享参数层上设计了最大平均差异(MMD,Maximum Mean Discrepancy)损失。Ganin等人设计了一个域鉴别器来区分每个样本来自哪个域,并提出了一个梯度反转层(GRL,Gradient Reversal Layer)来最大化域分类损失以减少域之间的分布差异。丁等人提出了一种自适应探索(AE,Adaptive Exploration)方法,通过最大化所有行人图像之间的距离并最小化相似行人图像之间的距离来解决行人重识别的域转移问题。尽管深度学习方法在减少域差异方面取得了一些进展,但源域中的标签率还是会影响无监督域适应任务的预测结果。源域中的标签率越低,目标域的预测结果越差。
随着图神经网络的引入,Kipf等人提出的图卷积网络(GCN,Graph ConvolutionalNetwork)在半监督分类任务中取得了理想的结果。在领域自适应任务中,给定少量标记的源数据,图卷积网络通常能够通过在源网络中传播样本信息来构建一个性能良好的分类器。例如戴等人结合图卷积网络和对抗域自适应模型来减少分布差异并进行准确的标签预测。
现有的基于图卷积网络的无监督域自适应的分类方法关注于两个域之间的公共信息,而没有去对域的特定信息加以利用。而且没有进一步关注类别级别的分布对齐问题,这可能会导致跨域的同一类样本的分布负对齐,并且可能不利于目标域的任务,从而导致训练的基于图卷积的无监督域自适应分类模型的性能低。
发明内容
基于此,有必要针对上述技术问题,提供一种能够提高训练的基于图卷积的无监督域自适应分类模型性能的基于图卷积网络的无监督域自适应的分类方法。
一种基于图卷积网络的无监督域自适应的分类方法,所述方法包括:
获取源域中的样本数据和目标域中样本数据作为训练数据;
根据所述源域和所述目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;
将所述源域和所述目标域中的样本数据输入到域自适应网络中进行训练,所述域自适应网络是基于图卷积网络的无监督域自适应网络,所述域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;
训练所述域自适应网络不断更新迭代所述域自适应网络中的参数,当所述域自适应网络达到收敛条件时,获得域自适应分类模型;
输入待分类数据至所述域自适应分类模型进行分类,获得所述待分类数据的分类结果。
在其中一个实施例中,所述跨域特征提取模型提取所述源域和所述目标域公共的样本特征,所述源域特征提取模型提取所述源域特定的样本特征,所述分类模型计算分类损失值,所述域对抗鉴别模型计算域特征对齐损失值,所述类对齐模型计算类特征对齐损失值。
在其中一个实施例中,所述总损失值为特征差异性损失值、分类损失值、域特征对齐损失值和类特征对齐损失值的和,其中,所述特征差异性损失值是所述源域的样本数据输入跨域特征提取模型和所述源域特征提取模型得到的特征差异,分类损失值是基于源域的样本数据输入分类模型的。
在其中一个实施例中,所述域自适应网络的构建方式包括:
源域的样本数据和目标域的样本数据输入到跨域特征提取模型中得到源域和目标域的公共嵌入特征表示;
源域的样本数据输入到源域特征提取模型中得到所述源域的特定嵌入特征表示;
计算所述源域的公共嵌入特征表示和所述特定嵌入特征表示的差异性构建特征差异性损失函数;
将目标域的样本数据输入到源域特征提取模型中得到带源域风格的目标域嵌入特征表示,再通过注意力机制与所述目标域的公共嵌入特征表示结合为目标域的嵌入特征表示;同时通过注意力机制将所述源域的公共嵌入特征表示与源域的特定嵌入特征表示结合为源域的嵌入特征表示;
将得到的所述源域和所述目标域的嵌入特征表示输入到分类模型中,所述源域的嵌入特征表示中有类别标签的部分构建分类损失函数,源域的其余无类别标签的部分和目标域的嵌入特征表示则生成其特征表示所对应的伪类别标签;
将得到的所述源域和所述目标域的公共嵌入特征表示输入到域对抗鉴别模型中,构建域特征对齐损失函数;
将所述源域的样本数据和所述目标域的样本数据按照类别标签和伪类别标签中的类别进行分组,同时将不同分组的样本的嵌入特征表示输入到类对齐模型中,构建类特征对齐损失函数。
在其中一个实施例中,所述跨域特征提取模型是由两层的图卷积神经网络的共享网络组成的,所述源域的样本数据与目标域的样本数据都输入到共享网络中得到其公共嵌入特征表示;
所述源域特征提取模型是由两层的图卷积神经网络模型组成的。
在其中一个实施例中,所述特征差异性损失函数为:
Figure BDA0003532260120000031
式中,
Figure BDA0003532260120000032
表示源域的公共嵌入特征表示,
Figure BDA0003532260120000033
表示源域的特定嵌入特征表示,Lm表示特征差异性损失函数,T表示转置运算。
在其中一个实施例中,所述分类损失函数为:
Figure BDA0003532260120000041
式中,
Figure BDA0003532260120000042
表示源域的有类别标签样本的嵌入特征表示,
Figure BDA0003532260120000043
为分类模型测得的分类结果,
Figure BDA0003532260120000044
为源域属于第k类的类别标签,k∈[1,C],C为样本的总类别数,nsl为源域中有类别标签的样本个数,Ls表示分类损失函数。
在其中一个实施例中,所述域特征对齐损失函数为:
Figure BDA0003532260120000045
式中,zci表示第i个公共嵌入特征表示,Gd(zi)为域对抗鉴别模型测得的结果,
Figure BDA0003532260120000046
为输入的公共特征表示属于源域的域标签还是目标域的域标签,ns为源域的样本总个数,nt为目标域的样本总个数,Ld表示域特征对齐损失函数。
在其中一个实施例中,所述类特征对齐损失函数为:
Figure BDA0003532260120000047
式中,Lc表示类特征对齐损失函数,
Figure BDA0003532260120000048
表示类别标签或者伪类别标签为第k类样本的源域嵌入特征表示,
Figure BDA0003532260120000049
表示伪类别标签为第k类样本的目标域嵌入特征表示,
Figure BDA00035322601200000410
Figure BDA00035322601200000411
分别为
Figure BDA00035322601200000412
Figure BDA00035322601200000413
的概率分布,C为样本的总类别数。
在其中一个实施例中,所述不断更新迭代所述域自适应网络中的参数的表达式为:
min(Ls+λLm-βLd+γLc)
式中,Lm表示特征差异性损失函数,Ls表示分类损失函数,Ld表示域特征对齐损失函数,Lc表示类特征对齐损失函数,λ,β和γ分别为对应损失函数之间的平衡因子。
上述基于图卷积网络的无监督域自适应的分类方法,通过获取源域中的样本数据和目标域中样本数据作为训练数据;根据源域和目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型;输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。提高了基于图卷积的无监督域自适应分类模型性能,进一步提高对数据分类的准确性。
附图说明
图1为一个实施例中基于图卷积网络的无监督域自适应的分类方法的流程示意图;
图2为一个实施例中域自适应网络的构建方式的流程示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
本申请提供的基于图卷积网络的无监督域自适应的分类方法,可以应用于终端或服务器。其中,终端可以但不限于是各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备,服务器可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
在一个实施例中,如图1所示,提供了一种基于图卷积网络的无监督域自适应的分类方法,以该方法应用于终端为例进行说明,包括以下步骤:
步骤S220,获取源域中的样本数据Xs和目标域中样本数据Xt作为训练数据。
其中,源域中的样本数据和目标域中样本数据的类别以及类别数相同。源域中的部分样本数据有类别标签,目标域中样本数据没有类别标签。类别标签是指用于标记样本属于哪一类别的标签。样本数据的类型可以是文本数据,也可以是图片数据,还可以是音频数据,根据分类任务的需要,确定样本数据的类型。如:需要训练用于分类论文属于哪个学科的分类模型时,可以将标记有属于哪个学科的论文作为源域中的样本数据,将没有标记属于哪个学科的论文作为目标域中样本数据,执行步骤S240至步骤S280,获得用于对论文属于哪个学科进行分类的域自适应分类模型。
步骤S240,根据源域和目标域中样本数据间的相似性分别去更新两个域中样本的图连接关系As和At
其中,使用正点互信息(PPMI,Positive Pointwise Mutual Information)来计算样本数据之间的相似性。PPMI的计算公式如下:
Figure BDA0003532260120000061
式中,
Figure BDA0003532260120000062
其中n为一个域内的样本个数,Aij为样本i和样本j连接的权重系数,ppmiij为样本i和样本j的样本相似度,ppmiij的值越大表明相似度越高。
步骤S260,将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型。
步骤S280,训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型。
步骤S300,输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。
其中,待分类数据是需要进行分类的数据,待分类数据可以是很多个数据,也可以是一个数据,如:需要对某一篇论文属于哪个学科进行分类,将该论文输入域自适应分类模型,输出该论文所属于的学科。
上述基于图卷积网络的无监督域自适应的分类方法,通过获取源域中的样本数据和目标域中样本数据作为训练数据;根据源域和目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型;输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。提高了基于图卷积的无监督域自适应分类模型性能,进一步提高对数据分类的准确性。
在一个实施例中,跨域特征提取模型提取源域和目标域公共的样本特征,源域特征提取模型提取源域特定的样本特征,分类模型计算分类损失值,域对抗鉴别模型计算域特征对齐损失值,类对齐模型计算类特征对齐损失值。
在一个实施例中,总损失值为特征差异性损失值、分类损失值、域特征对齐损失值和类特征对齐损失值的和,其中,特征差异性损失值是源域的样本数据输入跨域特征提取模型和源域特征提取模型得到的特征差异,分类损失值是基于源域的样本数据输入分类模型的。
如图2所示,在一个实施例中,域自适应网络的构建方式包括:源域的样本数据和目标域的样本数据输入到跨域特征提取模型中得到源域和目标域的公共嵌入特征表示;源域的样本数据输入到源域特征提取模型中得到源域的特定嵌入特征表示;计算源域的公共嵌入特征表示和特定嵌入特征表示的差异性构建特征差异性损失函数;将目标域的样本数据输入到源域特征提取模型中得到带源域风格的目标域嵌入特征表示,再通过注意力机制与目标域的公共嵌入特征表示结合为目标域的嵌入特征表示;同时通过注意力机制将源域的公共嵌入特征表示与源域的特定嵌入特征表示结合为源域的嵌入特征表示;将得到的源域和目标域的嵌入特征表示输入到分类模型中,源域的嵌入特征表示中有类别标签的部分构建分类损失函数,源域的其余无类别标签的部分和目标域的嵌入特征表示则生成其特征表示所对应的伪类别标签;将得到的源域和目标域的公共嵌入特征表示输入到域对抗鉴别模型中,构建域特征对齐损失函数;将源域的样本数据和目标域的样本数据按照类别标签和伪类别标签中的类别进行分组,同时将不同分组的样本的嵌入特征表示输入到类对齐模型中,构建类特征对齐损失函数。
其中,利用跨域特征提取模型得到源域和目标域的公共嵌入特征表示,再分别通过源域特征提取模型得到源域的特定嵌入特征表示和带源域风格的目标域的特定嵌入特征表示,通过注意力机制分别融合公共嵌入特征表示和特定嵌入特征表示,得到源域的嵌入特征表示和目标域的嵌入特征表示,从而混淆源域和目标域以缩小两个域的分布差异。将有源域中有类别标签的嵌入特征表示去训练分类模型,同时给两个域的无类别标签的样本生成伪类别标签。通过特征差异性损失函数使得公共嵌入特征表示和特定嵌入特征表示互斥,还分别设计了域特征对齐损失函数和类特征对齐损失函数消除域分布差异和相同类的分布差异,在仅依赖源域少量标签样本的情况下提高了域自适应任务的准确性。
在一个实施例中,跨域特征提取模型是由两层的图卷积神经网络(GCN)的共享网络组成的,源域的样本数据与目标域的样本数据都输入到共享网络中得到其公共嵌入特征表示;源域特征提取模型是由两层的图卷积神经网络模型组成的。
其中,图卷积神经网络提取模型不同域样本的嵌入特征表示,挖掘样本之间的连接关系,促进了样本之间的信息传递。其源域的样本数据与目标域的样本数据的公共嵌入特征表示计算公式如下:
Figure BDA0003532260120000081
Figure BDA0003532260120000082
其中,As为源域中样本之间的图连接关系,Xs为源域中的样本数据,θ0为图卷积神经网络第一层的网络参数,θ1为图卷积神经网络第二层的网络参数,At为目标域中样本的图连接关系,Xt为目标域中样本数据,
Figure BDA0003532260120000083
表示目标域的公共嵌入特征表示,
Figure BDA0003532260120000084
表示源域的公共嵌入特征表示。
源域特征提取模型是由两层的图卷积神经网络(GCN)构成的,其源域样本输入到该模型中得到源域特定嵌入特征表示
Figure BDA0003532260120000091
目标域输入到该模型中得到带源域风格的目标域特定嵌入特征表示
Figure BDA0003532260120000092
在一个实施例中,特征差异性损失函数为:
Figure BDA0003532260120000093
式中,
Figure BDA0003532260120000094
表示源域的公共嵌入特征表示,
Figure BDA0003532260120000095
表示源域的特定嵌入特征表示,Lm表示特征差异性损失函数,T表示转置运算。
其中,源域的嵌入特征表示Zs是由注意力机制将源域的特定嵌入特征表示和源域的公共嵌入特征表示结合起来,目标域的嵌入特征表示Zt是由注意力机制将带源域风格的目标域特定嵌入特征表示和目标域公共嵌入特征表示结合起来。其中,注意力机制的计算方法如下:
Figure BDA0003532260120000096
Figure BDA0003532260120000097
式中,w1和w2为列向量,且w1+w2=1。
在一个实施例中,分类损失函数为:
Figure BDA0003532260120000098
式中,
Figure BDA0003532260120000099
表示源域的有类别标签样本的嵌入特征表示,
Figure BDA00035322601200000910
为分类模型测得的分类结果,
Figure BDA00035322601200000911
为源域属于第k类的类别标签,k∈[1,C],C为样本的总类别数,nsl为源域中有类别标签的样本个数,Ls表示分类损失函数。
在一个实施例中,域特征对齐损失函数为:
Figure BDA00035322601200000912
式中,zci表示第i个公共嵌入特征表示,Gd(zi)为域对抗鉴别模型测得的结果,
Figure BDA00035322601200000913
为输入的公共特征表示属于源域的域标签还是目标域的域标签,ns为源域的样本总个数,nt为目标域的样本总个数,Ld表示域特征对齐损失函数。
其中,域标签是用于标识公共特征表示属于哪一个域的标识。
在一个实施例中,类特征对齐损失函数为:
Figure BDA0003532260120000101
式中,Lc表示类特征对齐损失函数,
Figure BDA0003532260120000102
表示类别标签或者伪类别标签为第k类样本的源域嵌入特征表示,
Figure BDA0003532260120000103
表示伪类别标签为第k类样本的目标域嵌入特征表示,
Figure BDA0003532260120000104
Figure BDA0003532260120000105
分别为
Figure BDA0003532260120000106
Figure BDA0003532260120000107
的概率分布,C为样本的总类别数。
在一个实施例中,不断更新迭代域自适应网络中的参数的表达式为:
min(Ls+λLm-βLd+γLc)
式中,Lm表示特征差异性损失函数,Ls表示分类损失函数,Ld表示域特征对齐损失函数,Lc表示类特征对齐损失函数,λ,β和γ分别为对应损失函数之间的平衡因子。
上述基于图卷积网络的无监督域自适应的分类方法,使用图卷积神经网络提取不同域样本的嵌入特征表示,挖掘样本之间的连接关系,促进了样本之间的信息传递。其次将目标域经过源域特征提取模型得到带源域风格的特定嵌入特征表示,再利用差异性损失使得公共嵌入特征表示与特定的嵌入特征表示不相关性。本发明中通过对抗机制设立域对抗鉴别模型最大化域分类损失消除了域之间公共嵌入特征的分布差异。通过注意力机制融合公共和特定的嵌入特征表示为源域嵌入特征表示和目标域嵌入特征表示,同时设立分类模型对有类别标签样本进行分类计算分类损失和给无类别标签样本打上伪类别标签,其中分类损失确保分类模型的有效性。最后在发明中设立类对齐模型消除同类样本不同域之间的分布差异,在类别级别上对齐两个域的样本分布。进一步有效的提升了基于图卷积的无监督域自适应分类模型的性能。
应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。

Claims (10)

1.一种基于图卷积网络的无监督域自适应的分类方法,其特征在于,所述方法包括:
获取源域中的样本数据和目标域中样本数据作为训练数据;
根据所述源域和所述目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;
将所述源域和所述目标域中的样本数据输入到域自适应网络中进行训练,所述域自适应网络是基于图卷积网络的无监督域自适应网络,所述域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;
训练所述域自适应网络不断更新迭代所述域自适应网络中的参数,当所述域自适应网络达到收敛条件时,获得域自适应分类模型;
输入待分类数据至所述域自适应分类模型进行分类,获得所述待分类数据的分类结果。
2.根据权利要求1所述的方法,其特征在于,所述跨域特征提取模型提取所述源域和所述目标域公共的样本特征,所述源域特征提取模型提取所述源域特定的样本特征,所述分类模型计算分类损失值,所述域对抗鉴别模型计算域特征对齐损失值,所述类对齐模型计算类特征对齐损失值。
3.根据权利要求2所述的方法,其特征在于,所述总损失值为特征差异性损失值、分类损失值、域特征对齐损失值和类特征对齐损失值的和,其中,所述特征差异性损失值是所述源域的样本数据输入跨域特征提取模型和所述源域特征提取模型得到的特征差异,分类损失值是基于源域的样本数据输入分类模型的。
4.根据权利要求1所述的方法,其特征在于,所述域自适应网络的构建方式包括:
源域的样本数据和目标域的样本数据输入到跨域特征提取模型中得到源域和目标域的公共嵌入特征表示;
源域的样本数据输入到源域特征提取模型中得到所述源域的特定嵌入特征表示;
计算所述源域的公共嵌入特征表示和所述特定嵌入特征表示的差异性构建特征差异性损失函数;
将目标域的样本数据输入到源域特征提取模型中得到带源域风格的目标域嵌入特征表示,再通过注意力机制与所述目标域的公共嵌入特征表示结合为目标域的嵌入特征表示;同时通过注意力机制将所述源域的公共嵌入特征表示与源域的特定嵌入特征表示结合为源域的嵌入特征表示;
将得到的所述源域和所述目标域的嵌入特征表示输入到分类模型中,所述源域的嵌入特征表示中有类别标签的部分构建分类损失函数,源域的其余无类别标签的部分和目标域的嵌入特征表示则生成其特征表示所对应的伪类别标签;
将得到的所述源域和所述目标域的公共嵌入特征表示输入到域对抗鉴别模型中,构建域特征对齐损失函数;
将所述源域的样本数据和所述目标域的样本数据按照类别标签和伪类别标签中的类别进行分组,同时将不同分组的样本的嵌入特征表示输入到类对齐模型中,构建类特征对齐损失函数。
5.根据权利要求4所述的方法,其特征在于,所述跨域特征提取模型是由两层的图卷积神经网络的共享网络组成的,所述源域的样本数据与目标域的样本数据都输入到共享网络中得到其公共嵌入特征表示;
所述源域特征提取模型是由两层的图卷积神经网络模型组成的。
6.根据权利要求4所述的方法,其特征在于,所述特征差异性损失函数为:
Figure FDA0003532260110000021
式中,
Figure FDA0003532260110000022
表示源域的公共嵌入特征表示,
Figure FDA0003532260110000023
表示源域的特定嵌入特征表示,Lm表示特征差异性损失函数,T表示转置运算。
7.根据权利要求4所述的方法,其特征在于,所述分类损失函数为:
Figure FDA0003532260110000024
式中,
Figure FDA0003532260110000025
表示源域的有类别标签样本的嵌入特征表示,
Figure FDA0003532260110000026
为分类模型测得的分类结果,
Figure FDA0003532260110000031
为源域属于第k类的类别标签,k∈[1,C],C为样本的总类别数,nsl为源域中有类别标签的样本个数,Ls表示分类损失函数。
8.根据权利要求4所述的方法,其特征在于,所述域特征对齐损失函数为:
Figure FDA0003532260110000032
式中,zci表示第i个公共嵌入特征表示,Gd(zi)为域对抗鉴别模型测得的结果,
Figure FDA0003532260110000033
为输入的公共特征表示属于源域的域标签还是目标域的域标签,ns为源域的样本总个数,nt为目标域的样本总个数,Ld表示域特征对齐损失函数。
9.根据权利要求4所述的方法,其特征在于,所述类特征对齐损失函数为:
Figure FDA0003532260110000034
式中,Lc表示类特征对齐损失函数,
Figure FDA0003532260110000035
表示类别标签或者伪类别标签为第k类样本的源域嵌入特征表示,
Figure FDA0003532260110000036
表示伪类别标签为第k类样本的目标域嵌入特征表示,
Figure FDA0003532260110000037
Figure FDA0003532260110000038
分别为
Figure FDA0003532260110000039
Figure FDA00035322601100000310
的概率分布,C为样本的总类别数。
10.根据权利要求1所述的方法,其特征在于,所述不断更新迭代所述域自适应网络中的参数的表达式为:
min(Ls+λLm-βLd+γLc)
式中,Lm表示特征差异性损失函数,Ls表示分类损失函数,Ld表示域特征对齐损失函数,Lc表示类特征对齐损失函数,λ,β和γ分别为对应损失函数之间的平衡因子。
CN202210208723.5A 2022-03-04 2022-03-04 基于图卷积网络的无监督域自适应的分类方法 Pending CN114676755A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210208723.5A CN114676755A (zh) 2022-03-04 2022-03-04 基于图卷积网络的无监督域自适应的分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210208723.5A CN114676755A (zh) 2022-03-04 2022-03-04 基于图卷积网络的无监督域自适应的分类方法

Publications (1)

Publication Number Publication Date
CN114676755A true CN114676755A (zh) 2022-06-28

Family

ID=82072060

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210208723.5A Pending CN114676755A (zh) 2022-03-04 2022-03-04 基于图卷积网络的无监督域自适应的分类方法

Country Status (1)

Country Link
CN (1) CN114676755A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116403058A (zh) * 2023-06-09 2023-07-07 昆明理工大学 一种遥感跨场景多光谱激光雷达点云分类方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116403058A (zh) * 2023-06-09 2023-07-07 昆明理工大学 一种遥感跨场景多光谱激光雷达点云分类方法
CN116403058B (zh) * 2023-06-09 2023-09-12 昆明理工大学 一种遥感跨场景多光谱激光雷达点云分类方法

Similar Documents

Publication Publication Date Title
CN110598206B (zh) 文本语义识别方法、装置、计算机设备和存储介质
CN109636658B (zh) 一种基于图卷积的社交网络对齐方法
CN110378366B (zh) 一种基于耦合知识迁移的跨域图像分类方法
CN112084331A (zh) 文本处理、模型训练方法、装置、计算机设备和存储介质
CN110263160B (zh) 一种计算机问答系统中的问句分类方法
CN111695415A (zh) 图像识别模型的构建方法、识别方法及相关设备
CN112380435A (zh) 基于异构图神经网络的文献推荐方法及推荐系统
CN111241992B (zh) 人脸识别模型构建方法、识别方法、装置、设备及存储介质
CN112231592B (zh) 基于图的网络社团发现方法、装置、设备以及存储介质
WO2022252458A1 (zh) 一种分类模型训练方法、装置、设备及介质
CN113657087B (zh) 信息的匹配方法及装置
CN111062036A (zh) 恶意软件识别模型构建、识别方法及介质和设备
CN113255714A (zh) 图像聚类方法、装置、电子设备及计算机可读存储介质
CN112819024B (zh) 模型处理方法、用户数据处理方法及装置、计算机设备
CN115130711A (zh) 一种数据处理方法、装置、计算机及可读存储介质
CN111159481B (zh) 图数据的边预测方法、装置及终端设备
CN114357151A (zh) 文本类目识别模型的处理方法、装置、设备及存储介质
CN116229170A (zh) 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备
CN114676755A (zh) 基于图卷积网络的无监督域自适应的分类方法
CN111309923A (zh) 对象向量确定、模型训练方法、装置、设备和存储介质
CN114463552A (zh) 迁移学习、行人重识别方法及相关设备
CN111783088B (zh) 一种恶意代码家族聚类方法、装置和计算机设备
CN112215629B (zh) 基于构造对抗样本的多目标广告生成系统及其方法
CN116758373A (zh) 深度学习模型的训练方法、图像处理方法、装置和设备
CN116307078A (zh) 账户标签预测方法、装置、存储介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination