CN112668633A - 一种基于细粒度领域自适应的图迁移学习方法 - Google Patents
一种基于细粒度领域自适应的图迁移学习方法 Download PDFInfo
- Publication number
- CN112668633A CN112668633A CN202011561512.7A CN202011561512A CN112668633A CN 112668633 A CN112668633 A CN 112668633A CN 202011561512 A CN202011561512 A CN 202011561512A CN 112668633 A CN112668633 A CN 112668633A
- Authority
- CN
- China
- Prior art keywords
- domain
- graph
- target domain
- source
- samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Abstract
本发明提供了一种基于细粒度领域自适应的图迁移学习方法,包括:(1)采集源域和目标域中的样本,并分别标注源域和目标域中的部分或全部样本,获得带标签的样本;(2)为带标签的样本分别构建源域图和目标域图;(3)将源域图中带标签的样本划分为训练集和验证集,将目标域图中带标签的样本视作测试集;(4)使用源域图的训练集及目标域样本训练图神经网络,得到至少两个参数不同的图神经网络;(5)使用源域图的验证集挑选图神经网络;(6)使用挑选出的图神经网络为目标域的样本预测标签;(7)通过对比目标域中全部带标签样本的真实标签和预测标签,得到评价结果。本发明使图迁移学习时可以从共享节点信息中学习领域无关特征。
Description
技术领域
本发明属于图表示学习领域,尤其是涉及一种基于细粒度领域自适应的图迁移学习方法。
背景技术
在过去几年中,图表示学习领域的研究受到越来越多的关注。图结构的数据可以方便地建模现实世界的关系型数据,比如社交网络,商品推荐,生物网络等。图表示学习旨在将具有复杂结构的图数据转换为保留多样化图属性和结构特征的低维空间中的密集表示。近年来,使用图神经网络(Graph Neural Network,GNN)进行图表示学习受到了广泛的关注。通常,GNN将具有属性的图作为输入,通过邻域状态的加权总和来更新节点的隐藏状态,利用卷积层逐层生成节点级别的表示。通过节点之间的信息传递,图神经网络能够捕捉来自其邻域的信息。
迁移学习是一种应对数据匮乏问题的有效方法。数据匮乏问题可能来源于:目标领域数据量较少;数据获取难度大或标注内容较为复杂,成本较高等。
如公开号为CN111680160A的中国专利文献公开了一种用于文本情感分类的深度迁移学习方法;公开号为CN111046731A的中国专利文献公开了一种基于表面肌电信号进行手势识别的迁移学习方法。
许多机器学习和数据挖掘算法的主要假设是,训练数据和测试数据必须在相同的特征空间中并且具有相同的分布。当分布发生变化时,大多数模型需要根据新收集的训练数据重新构建。但是在众多实际应用中,此假设可能不成立,且重新收集所需的训练数据并重建模型难以实现。在这种情况下,迁移学习放松了训练数据必须与测试数据独立同分布的假设,显著降低了对训练数据和训练时间的需求。使用迁移学习可使深度学习模型从标注数据较为丰富的源领域学习并训练至收敛状态,再将模型应用到目标领域。
领域自适应作为迁移学习中重要的一环,意图将分布不同的源域和目标域的数据映射到同一特征空间,并使两者在该空间中的距离尽可能相近。领域自适应认为领域无关的特征可通过捕捉源领域和目标领域的相似结构学习,但在图结构数据中,领域无关的特征不仅来源于源领域和目标领域的相似图结构,还来源于源领域和目标领域的共有节点。
发明内容
本发明提供了一种基于细粒度领域自适应的图迁移学习方法,使图迁移学习时可以从共享节点信息中学习领域无关特征。
一种基于细粒度领域自适应的图迁移学习方法,其特征在于,包括以下步骤:
(1)采集源域和目标域中的样本,并分别标注源域和目标域中的部分或全部样本,获得带标签的样本,作为监督信息;
(2)为带标签的样本分别构建源域图和目标域图;
(3)将源域图中带标签的样本划分为训练集和验证集,将目标域图中带标签的样本视作测试集;
(4)构建图神经网络,使用源域图的训练集和目标域不带标签的样本训练图神经网络,得到至少两个参数不同的图神经网络;(5)使用源域图的验证集挑选图神经网络;
(6)使用挑选出的图神经网络为目标域的样本预测标签;
(7)通过对比目标域中全部带标签样本的真实标签和预测标签,得到评价结果。
其中,xi表示源域内第i个样本,yi表示第i个样本的标签;xj表示目标域内第j个样本。
步骤(2)中,源域图和目标域图中,图的定义为G=(V,E),其中V为节点集合,E为边集合。
优选地,步骤(4)中,所述的图神经网络包括数据输入模块、特征抽取模块、标签预测模块和领域判别模块;
所述的数据输入模块用于将源域图和目标域图的相关数据转化并输入特征抽取模块;
所述的特征抽取模块用于从源域图和目标域图中抽取具有表达能力的特征,并输入至标签预测模块;
所述的标签预测模块用于预测指定节点的标签信息;
所述的领域判别模块用于判别所有节点来自于源域或来自于目标域。
所述的特征抽取模块包括节点嵌入层和特征抽取器,节点嵌入层的功能是将源域图和目标域图中每个节点变为固定大小的向量表示;特征抽取器的功能是从源域图和目标域图中抽取具有区分度的复杂高维特征,用于提升标签预测模块的标签预测能力。
节点嵌入层的参数使用正态分布随机初始化,随后在网络训练过程中进行更新。
节点嵌入层包含三类节点:源域非共享节点,目标域非共享节点和源域目标域共享节点。
在图神经网络中,本发明使用细粒度领域自适应技术,即预测目标领域标签的误差产生的梯度不变,预测领域来源的误差产生的梯度在领域共享节点上保持不变,在领域非共享节点上变为原来的-λ倍,称为梯度置反(Gradient Reversal)。
具体的,在反向传播更新过程中,源域目标域共享节点按计算所得的原始梯度进行更新,源域非共享节点和目标域非共享节点按梯度置反进行更新,计算所得的原始梯度变为原来的-λ倍。
网络训练过程中,特征抽取器中来自于领域判别模块的梯度将被梯度置反,而来自于标签预测模块的梯度保持不变。
与现有技术相比,本发明具有以下有益效果:
(1)本发明可从源域和目标域的共享节点中学习领域无关特征,避免仅对源域和目标域相似图结构的单一建模方式。
(2)本发明利用所学的领域无关特征,有利于从源域向目标域迁移过程中,充分利用共享节点中存储的信息。
(3)本发明大大减少了目标域对于数据标注的需求。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的基于细粒度领域自适应的图迁移学习方法的流程示意图;
图2为本发明实施例提供的单细胞测序数据图构建过程示意图;
图3为本发明实施例提供的图神经网络的模型结构示意图。
具体实施方式
下面结合附图和实施例对本发明做进一步详细描述,需要指出的是,以下所述实施例旨在便于对本发明的理解,而对其不起任何限定作用。
本发明提供的基于细粒度领域自适应的图迁移学习方法可用于结合单细胞测序数据进行细胞分类等应用场景。以下实施例中,源域为小鼠肺部的单细胞测序数据,目标域为人体肺部的单细胞测序数据。将单细胞测序数据按一定步骤转化为二部图,得到源域二部图和目标域二部图。利用本方法从源域二部图中训练一个最优的模型,迁移至目标域二部图上进行测试,并得到预测结果和评价指标,从而解决因伦理因素等而导致的人体肺部单细胞测序数据匮乏,标注稀疏等问题;同时,本方法能进一步提升细胞分类的精度。
以下结合附图,详细说明本申请各实施例提供的技术方案。
如图1所示,一种基于细粒度领域自适应的图迁移学习方法,总体流程大致分为几个步骤:(1)数据准备,可参考图2;(2)图神经网络随机初始化;(3)在源域训练集上训练模型,直至收敛;(4)保存多个模型;(5)在源域验证集上挑选模型;(6)在源域测试集上测试模型;(7)迁移至目标域测试模型;(8)得到预测结果,计算评价指标。
如图2所示,为本发明实施例提供的利用单细胞测序数据构建二部图的完整流程的示意图。源域为小鼠肺部的单细胞测序数据,目标域为人体肺部的单细胞测序数据。二部图的节点集合由两个互不相交的细胞节点集合和基因节点集合组成。边集合的每个元素对应某一基因在某一细胞中的表达值,可从原始单细胞测序的表达谱中直接读取。具体而言,单细胞测序数据可表达为一个矩阵M∈Rm×n,其中,m为单细胞测序细胞总个数,n为单细胞中基因总个数,则二部图的邻接矩阵为
至此,本发明提供的细胞-基因二部图样本数据构建完成,样本数据中包含三部分:源域二部图,目标域二部图和源域部分或全部节点的标注信息。
如图3所示,为本发明实施例提供的图神经网络的模型结构示意图,包含4个模块:数据输入模块,特征抽取模块,标签预测模块和领域判别模块。数据输入模块将单细胞测序数据转化为细胞-基因二部图,并输出至特征抽取模块。
特征抽取模块主要由两部分构成:节点嵌入层和特征抽取器。节点嵌入层的功能是将二部图中每个节点变为固定大小的向量表示。节点嵌入层的参数可使用正态分布随机初始化,随后在训练过程中进行更新。节点嵌入层包含三类节点:源域非共享节点,目标域非共享节点和源域目标域共享节点。在反向传播更新过程中,源域目标域共享节点按计算所得的原始梯度进行更新,源域非共享节点和目标域非共享节点按梯度置反进行更新,即计算所得的原始梯度变为原来的-λ倍。
特征抽取模块的另一部分,特征抽取器的功能是从源域二部图和目标域二部图中抽取具有区分度的复杂高维特征,用于提升标签预测模块的标签预测能力,进而提高相关评价指标,得到更可靠的细胞分类结果。为了使得这些特征具有领域无关的特性,特征抽取器中来自于领域判别模块的梯度将被梯度置反,而来自于标签预测模块的梯度保持不变。这使得特征抽取模块的目标之一变为最大化领域判别模块的损失函数。对于领域判别模块而言,区分这些特征来自于源域或目标域可以促进领域判别模块的领域判别能力提升。
标签预测模块的功能是基于给定的标签数据,通过数据输入模块和特征抽取模块得到一个稠密的固定维度的向量表示,并输入至标签预测模块进行预测,得到预测标签,可与真实标签通过预先定义的评价方式计算评价指标。在本实施例中,标签预测任务是为特定的细胞注释细胞类型,是一个多分类任务。在测试阶段,标签预测模块不加训练地直接迁移到目标域二部图上,对目标域二部图上的细胞节点进行预测,得到每个细胞节点的类别。
领域判别模块的功能是对源域二部图或目标域二部图中抽取的特征进行区分,判断来自源域或目标域,是一个二分类任务。在学习过程中,特征抽取模块生成领域无关的特征,尽可能使得领域判别模块难以区分,与领域判别模块本身的功能形成对抗。在对抗过程中,特征抽取模块的特征抽取能力得到提升,进一步提高标签预测模块的标签预测能力。
以上所述的实施例对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的具体实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换,均应包含在本发明的保护范围之内。
Claims (8)
1.一种基于细粒度领域自适应的图迁移学习方法,其特征在于,包括以下步骤:
(1)采集源域和目标域中的样本,并分别标注源域和目标域中的部分或全部样本,获得带标签的样本,作为监督信息;
(2)为带标签的样本分别构建源域图和目标域图;
(3)将源域图中带标签的样本划分为训练集和验证集,将目标域图中带标签的样本视作测试集;
(4)构建图神经网络,使用源域图的训练集和目标域不带标签的样本训练图神经网络,得到至少两个参数不同的图神经网络;
(5)使用源域图的验证集挑选图神经网络;
(6)使用挑选出的图神经网络为目标域的样本预测标签;
(7)通过对比目标域中全部带标签样本的真实标签和预测标签,得到评价结果。
3.根据权利要求1所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,步骤(2)中,源域图和目标域图中,图的定义为G=(V,E),其中V为节点集合,E为边集合。
4.根据权利要求1所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,步骤(4)中,所述的图神经网络包括数据输入模块、特征抽取模块、标签预测模块和领域判别模块;
所述的数据输入模块用于将源域图和目标域图的相关数据转化并输入特征抽取模块;
所述的特征抽取模块用于从源域图和目标域图中抽取具有表达能力的特征,并输入至标签预测模块;
所述的标签预测模块用于预测指定节点的标签信息;
所述的领域判别模块用于判别所有节点来自于源域或来自于目标域。
5.根据权利要求4所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,所述的特征抽取模块包括节点嵌入层和特征抽取器,节点嵌入层的功能是将源域图和目标域图中每个节点变为固定大小的向量表示;特征抽取器的功能是从源域图和目标域图中抽取具有区分度的复杂高维特征,用于提升标签预测模块的标签预测能力。
6.根据权利要求5所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,节点嵌入层的参数使用正态分布随机初始化,随后在网络训练过程中进行更新。
7.根据权利要求6所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,节点嵌入层包含三类节点:源域非共享节点,目标域非共享节点和源域目标域共享节点;
在反向传播更新过程中,源域和目标域的共享节点按计算所得的原始梯度进行更新,源域非共享节点和目标域非共享节点按梯度置反进行更新,计算所得的原始梯度变为原来的-λ倍。
8.根据权利要求5所述的基于细粒度领域自适应的图迁移学习方法,其特征在于,网络训练过程中,特征抽取器中来自于领域判别模块的梯度将被梯度置反,而来自于标签预测模块的梯度保持不变。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011561512.7A CN112668633B (zh) | 2020-12-25 | 2020-12-25 | 一种基于细粒度领域自适应的图迁移学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011561512.7A CN112668633B (zh) | 2020-12-25 | 2020-12-25 | 一种基于细粒度领域自适应的图迁移学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112668633A true CN112668633A (zh) | 2021-04-16 |
CN112668633B CN112668633B (zh) | 2022-10-14 |
Family
ID=75408938
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011561512.7A Active CN112668633B (zh) | 2020-12-25 | 2020-12-25 | 一种基于细粒度领域自适应的图迁移学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112668633B (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114098768A (zh) * | 2021-11-25 | 2022-03-01 | 哈尔滨工业大学 | 基于动态阈值和EasyTL的跨个体表面肌电信号手势识别方法 |
CN115719514A (zh) * | 2022-11-23 | 2023-02-28 | 南京理工大学 | 一种面向手势识别的领域自适应方法及系统 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108875827A (zh) * | 2018-06-15 | 2018-11-23 | 广州深域信息科技有限公司 | 一种细粒度图像分类的方法及系统 |
CN109947086A (zh) * | 2019-04-11 | 2019-06-28 | 清华大学 | 基于对抗学习的机械故障迁移诊断方法及系统 |
US20200160177A1 (en) * | 2018-11-16 | 2020-05-21 | Royal Bank Of Canada | System and method for a convolutional neural network for multi-label classification with partial annotations |
CN111814977A (zh) * | 2020-08-28 | 2020-10-23 | 支付宝(杭州)信息技术有限公司 | 训练事件预测模型的方法及装置 |
CN112016687A (zh) * | 2020-08-20 | 2020-12-01 | 浙江大学 | 一种基于互补伪标签的跨域行人重识别方法 |
-
2020
- 2020-12-25 CN CN202011561512.7A patent/CN112668633B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108875827A (zh) * | 2018-06-15 | 2018-11-23 | 广州深域信息科技有限公司 | 一种细粒度图像分类的方法及系统 |
US20200160177A1 (en) * | 2018-11-16 | 2020-05-21 | Royal Bank Of Canada | System and method for a convolutional neural network for multi-label classification with partial annotations |
CN109947086A (zh) * | 2019-04-11 | 2019-06-28 | 清华大学 | 基于对抗学习的机械故障迁移诊断方法及系统 |
CN112016687A (zh) * | 2020-08-20 | 2020-12-01 | 浙江大学 | 一种基于互补伪标签的跨域行人重识别方法 |
CN111814977A (zh) * | 2020-08-28 | 2020-10-23 | 支付宝(杭州)信息技术有限公司 | 训练事件预测模型的方法及装置 |
Non-Patent Citations (2)
Title |
---|
YUNSHENG BAI ET AL: "《Graph Edit Distance Computation via Graph Neural Networks》", 《ARXIV:1808.05689V3》, 3 October 2018 (2018-10-03), pages 1 - 21 * |
赵传君 等;: "《跨领域文本情感分类研究进展》", 《软件学报》, 21 April 2020 (2020-04-21), pages 1723 - 1746 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114098768A (zh) * | 2021-11-25 | 2022-03-01 | 哈尔滨工业大学 | 基于动态阈值和EasyTL的跨个体表面肌电信号手势识别方法 |
CN114098768B (zh) * | 2021-11-25 | 2024-05-03 | 哈尔滨工业大学 | 基于动态阈值和EasyTL的跨个体表面肌电信号手势识别方法 |
CN115719514A (zh) * | 2022-11-23 | 2023-02-28 | 南京理工大学 | 一种面向手势识别的领域自适应方法及系统 |
CN115719514B (zh) * | 2022-11-23 | 2023-06-30 | 南京理工大学 | 一种面向手势识别的领域自适应方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN112668633B (zh) | 2022-10-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111368528B (zh) | 一种面向医学文本的实体关系联合抽取方法 | |
CN111753101B (zh) | 一种融合实体描述及类型的知识图谱表示学习方法 | |
CN107944410B (zh) | 一种基于卷积神经网络的跨领域面部特征解析方法 | |
CN109002834B (zh) | 基于多模态表征的细粒度图像分类方法 | |
CN112131404A (zh) | 一种四险一金领域知识图谱中实体对齐方法 | |
CN112699247A (zh) | 一种基于多类交叉熵对比补全编码的知识表示学习框架 | |
CN112906770A (zh) | 一种基于跨模态融合的深度聚类方法及系统 | |
CN112711953A (zh) | 一种基于注意力机制和gcn的文本多标签分类方法和系统 | |
CN112308115B (zh) | 一种多标签图像深度学习分类方法及设备 | |
CN112199536A (zh) | 一种基于跨模态的快速多标签图像分类方法和系统 | |
CN112256866B (zh) | 一种基于深度学习的文本细粒度情感分析算法 | |
CN112668633B (zh) | 一种基于细粒度领域自适应的图迁移学习方法 | |
CN110598022B (zh) | 一种基于鲁棒深度哈希网络的图像检索系统与方法 | |
CN113535953B (zh) | 一种基于元学习的少样本分类方法 | |
CN105046323B (zh) | 一种正则化rbf网络多标签分类方法 | |
CN111723930A (zh) | 一种应用群智监督学习方法的系统 | |
CN114912423A (zh) | 一种基于迁移学习的方面级别情感分析方法及装置 | |
CN111222318A (zh) | 基于双通道双向lstm-crf网络的触发词识别方法 | |
CN111582506A (zh) | 基于全局和局部标记关系的偏多标记学习方法 | |
Chu et al. | Co-training based on semi-supervised ensemble classification approach for multi-label data stream | |
CN115661550A (zh) | 基于生成对抗网络的图数据类别不平衡分类方法及装置 | |
CN115062727A (zh) | 一种基于多阶超图卷积网络的图节点分类方法及系统 | |
CN114897085A (zh) | 一种基于封闭子图链路预测的聚类方法及计算机设备 | |
Lonij et al. | Open-world visual recognition using knowledge graphs | |
CN113849653A (zh) | 一种文本分类方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |