CN114139676A - 领域自适应神经网络的训练方法 - Google Patents

领域自适应神经网络的训练方法 Download PDF

Info

Publication number
CN114139676A
CN114139676A CN202010911149.0A CN202010911149A CN114139676A CN 114139676 A CN114139676 A CN 114139676A CN 202010911149 A CN202010911149 A CN 202010911149A CN 114139676 A CN114139676 A CN 114139676A
Authority
CN
China
Prior art keywords
target data
loss function
class
probability
source data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010911149.0A
Other languages
English (en)
Inventor
汪洁
钟朝亮
冯成
孙俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fujitsu Ltd
Original Assignee
Fujitsu Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fujitsu Ltd filed Critical Fujitsu Ltd
Priority to CN202010911149.0A priority Critical patent/CN114139676A/zh
Priority to JP2021136658A priority patent/JP2022042487A/ja
Publication of CN114139676A publication Critical patent/CN114139676A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

公开了领域自适应神经网络的训练方法,包括:针对源数据和目标数据提取特征;基于提取的特征为目标数据预测第一标签;基于源数据集合上每个类别的类中心与目标数据的特征之间的距离,为目标数据确定第二标签;在目标数据集合中选择第一标签与第二标签相同的目标数据,并且第一或第二标签作为所选择的目标数据的伪标签;基于所选择的目标数据计算目标数据集合上每个类别的类中心;基于源数据集合的类中心和所计算的目标数据集合的类中心之间的距离构建第一损失函数;基于所选择的目标数据以及其伪标签构建第二损失函数;针对源数据集合中的源数据以及所选择的目标数据构建第三损失函数;基于第一至第三损失函数来训练该神经网络。

Description

领域自适应神经网络的训练方法
技术领域
本发明总体上涉及领域自适应(domain adaptation),更具体地,涉及用于无监督领域自适应的神经网络及其训练方法。
背景技术
无监督领域自适应是指将利用已标记的源数据而训练的模型迁移到未标记数据的目标域,并且同时尽可能保持该模型在目标域的性能。由于源域和目标域之间存在数据集偏差,并且目标域缺乏已标记数据,因此利用已标记的源数据训练得到的模型在目标域的性能往往很差。无监督领域自适应的训练过程同时利用了源域的已标记数据和目标域的未标记数据,可以有效缓解域差异,提高模型的鲁棒性。
目前,无监督领域自适应的主流方法包括以对抗训练为代表的学习域不变特征的方法。一个典型的对抗训练方法是领域对抗神经网络,其中,在特征提取网络后增加域判别器来判断特征是来自于源域还是目标域,并且在特征提取网络和域判别器之间增加梯度反转层。在使域判别器的损失函数最小化的时候,通过梯度反转层使得特征提取网络能够学习到域不变的特征。
此外,知识蒸馏近来被引入到无监督领域自适应问题中,发展了很多新的方法,例如包括:使用自集成的平均教师模型来引导学生模型学习目标域的未标记数据;利用自集成的教师模型获得更准确的目标数据的伪标签;从源数据中蒸馏出与目标数据类似的数据,以对预训练模型进行微调;在语义层面(类别层面)对源域和目标域的特征进行对齐,即,拉近源域和目标域中的相同类别的平均特征(类中心)。
下面对这些现有的方法进行简要介绍。
图1示出了典型的领域对抗神经网络的架构。如图1所示,领域对抗神经网络包括特征提取器F、分类器Cs和域判别器D。域判别器D通过梯度反转层与特征提取器F相连,梯度反转层将梯度乘以特定的负数后传回至特征提取器F。Is表示已标记的源数据,It表示未标记的目标数据,二者均被输入至特征提取器F。由特征提取器F针对源数据提取的特征被输入至分类器Cs,以预测源数据的类别。此外,由特征提取器F针对源数据和目标数据二者提取的特征均被输入至域判别器D,域判别器D根据输入的特征来判别当前处理的数据是来自于源域还是目标域。领域对抗神经网络的训练中采用针对源域的分类交叉熵损失函数Lc以及域判别中的二值交叉熵损失函数Ladv,以使得损失函数Lc和Ladv最小化为目标,按照标准的反向传播算法进行训练,从而使得特征提取器F学习到域不变的特征。
图2示出了自集成教师模型的架构,其中利用学生网络的参数的指数移动平均来构造教师网络。在图2中,xSi表示已标记的源数据,xTi表示未标记的目标数据,ySi表示源数据的真实标签,zTi表示学生网络对目标数据的预测概率,
Figure BDA0002663318320000021
表示教师网络对目标数据的预测概率。
该方案的前提假设是教师网络的预测准确率高于学生网络,学生网络可以从教师网络的预测概率中学习到目标数据的隐知识,因此这是一种知识蒸馏方法。针对源数据xSi,采用基于学生网络的预测概率zTi和真实标签ySi的交叉熵损失函数。针对目标数据xTi,采用教师网络的预测概率
Figure BDA0002663318320000022
和学生网络的预测概率zTi的均方误差作为损失函数。然后,将以上两种损失函数进行加权相加,以获得最终的损失函数。
另外,关于语义层面的特征对齐,已经提出了以下损失函数:
Figure BDA0002663318320000023
Figure BDA0002663318320000024
其中,Xs,k表示源域Xs中属于第k类的所有数据样本(根据真实标签确定),Xt,k表示目标域Xt中被标记为第k类的所有数据样本(根据伪标签确定)。λs,k表示源域中第k类的类中心,即属于第k类的所有源数据的特征F的平均值。类似地,λt,k表示目标域中第k类的类中心,即被标记为第k类的所有目标数据的特征F的平均值。目标数据的伪标签是利用分类器对目标数据的类别进行预测而得到的。数学式(1)所示的语义对齐损失函数La(Xs,Xt)表示源域和目标域中同一类别的类中心之间的距离。
尽管上述方法实现了不错的效果,但它们仍然存在一些值得改进的问题。首先,对于语义对齐而言,目标数据的伪标签的正确性对目标域中的类中心影响较大。对于一些位于分界面周围的数据,如果伪标签错误,类中心的计算结果将会出现较大的偏差。其次,对于对比学习而言,错误的伪标签会损害类内数据样本聚集和类间数据样本分离的约束。此外,对于自集成的平均教师模型而言,在指数移动平均中往往使用一个固定的衰减率,然而当前模型的性能是变化的,固定的衰减率无法根据当前模型的性能来调节集成的速率。此外,对于使用蒸馏数据来进行微调而言,该方法需要两个阶段,增加了中间切换的操作,不能一步到位完成训练。
发明内容
根据本发明的一个方面,提供了一种由计算机实现的用于训练领域自适应神经网络的方法,其中所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,其中,所述计算机包括存储有指令的存储器和处理器,所述指令在被所述处理器执行时使得所述处理器执行所述方法,所述方法包括:由所述第一特征提取单元针对已标记的源数据集合中的源数据提取第一特征,以及由所述第一分类单元基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率;由所述第一特征提取单元针对未标记的目标数据集合中的目标数据提取第二特征,以及由所述第一分类单元基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;基于所选择的目标数据以及其伪标签来构建第二损失函数;针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
根据本发明的另一个方面,提供了一种用于训练领域自适应神经网络的装置,所述领域自适应神经网络包括:第一特征提取单元,其用于针对已标记的源数据集合中的源数据提取第一特征,并且针对未标记的目标数据集合中的目标数据提取第二特征;第一分类单元,其基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率,并且基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;以及判别单元,其基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;所述装置包括:存储有程序的存储器;以及一个或多个处理器,所述处理器通过执行所述程序而执行以下操作:计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;基于所选择的目标数据以及其伪标签来构建第二损失函数;针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
根据本发明的另一个方面,提供了一种存储有用于训练领域自适应神经网络的程序的存储介质,所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,所述程序在被计算机执行时使得所述计算机执行包括以下步骤的方法:由所述第一特征提取单元针对已标记的源数据集合中的源数据提取第一特征,以及由所述第一分类单元基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率;由所述第一特征提取单元针对未标记的目标数据集合中的目标数据提取第二特征,以及由所述第一分类单元基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;基于所选择的目标数据以及其伪标签来构建第二损失函数;针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
附图说明
图1示意性地示出了现有的领域对抗神经网络的架构。
图2示意性地示出了现有的自集成教师模型的架构。
图3示意性地示出了根据本发明的领域自适应神经网络的架构。
图4示意性地示出了根据本发明的自集成教师模型的架构。
图5示出了权值λ1的曲线。
图6示出了权值λ2的曲线。
图7示出了根据本发明的生成优选目标数据集的方法的流程图。
图8示出了根据本发明的领域自适应神经网络的训练方法的流程图。
图9示出了根据本发明的领域自适应神经网络的训练装置的模块化框图。
图10示出了实现本发明的计算机硬件的示例性配置框图。
具体实施方式
图3示意性地示出了根据本发明的用于无监督领域自适应的神经网络的架构。如图3所示,该神经网络包括参考图1所描述的领域对抗神经网络,其包括第一特征提取器310、第一分类器320、域判别器330以及梯度反转层(未示出)。此外,该神经网络还包括第二特征提取器310_T和第二分类器320_T。需要说明的是,作为已知技术,图3中的特征提取器310,310_T,分类器320,320_T,以及域判别器330均可以由卷积神经网络来实现。本文中将不再详细描述实现这些单元的卷积神经网络的结构。
第一特征提取器310和第一分类器320组成学生网络,第二特征提取器310_T和第二分类器320_T组成教师网络。第二(教师)特征提取器310_T的参数是第一(学生)特征提取器310的参数的指数移动平均,第二(教师)分类器320_T的参数是第一(学生)分类器320的参数的指数移动平均。
源数据Xs和目标数据Xt被输入到第一特征提取器310和第二特征提取器310_T中的每一个。第一特征提取器310将针对源数据Xs和目标数据Xt提取的特征输入至第一分类器320,第二特征提取器310_T将针对源数据Xs和目标数据Xt提取的特征输入至第二分类器320_T。
在图3所示的领域自适应神经网络的训练中,本发明提出了多个损失函数,将在下文中详细描述。
作为本发明的一个方面,提出了采用投票策略来提高目标数据的伪标签的准确性。投票策略是指使用至少两种预测方式来对目标数据的预测标签进行投票。作为一个示例,针对目标数据
Figure BDA0002663318320000061
利用分类器预测其类别标签以获得预测结果
Figure BDA0002663318320000062
此外还采用类中心最近邻算法来预测其标签,以获得预测结果ld,如以下数学式(2)和(3)所示。
Figure BDA0002663318320000063
Figure BDA0002663318320000064
其中,λs,k表示源域中第k类的类中心,即,属于第k类的所有源数据的特征的平均值,K表示源域中全部类别的数目,ld表示在源域的全部K个类中心当中、与目标数据
Figure BDA0002663318320000065
距离最近的类中心所对应的类别。
如果预测标签lc和预测标签ld一致,则将该目标数据
Figure BDA0002663318320000066
选中,并且将预测标签lc或ld作为该目标数据
Figure BDA00026633183200000611
的伪标签。如果预测标签lc和预测标签ld不一致,则丢弃该目标数据
Figure BDA0002663318320000067
所有被选中的目标数据
Figure BDA0002663318320000068
构成优选目标数据集
Figure BDA0002663318320000069
相比于仅执行分类器预测或仅执行类中心最近邻预测的情形,以此方式筛选出的数据集
Figure BDA00026633183200000610
中的每个目标数据的伪标签的准确率更高。因此,根据本发明的投票策略能够有效地筛选出预测结果更为准确的目标数据。
需要说明的是,以上描述的分类器预测和类中心最近邻预测仅作为至少两种不同预测方式的示例,但本发明并不限于此,本领域技术人员易于想到采用其它适当的预测方式。
然后,基于优选目标数据集
Figure BDA0002663318320000071
来构建用于训练图3所示的神经网络的语义对齐损失函数La(图3中未示出),也称为第一损失函数。具体来说,以K个预定类别中的第k类为例,首先根据数学式(2)计算源域中第k类的类中心λs,k,并且根据以下数学式(4)来计算优选目标数据集
Figure BDA0002663318320000072
中第k类的类中心λt,k,然后根据数学式(5)来计算类中心λs,k与类中心λt,k之间的距离
Figure BDA0002663318320000073
以此方式,针对所有K个类别分别计算源域的类中心与目标域的类中心之间的距离,作为语义对齐损失函数。在训练中以使得距离
Figure BDA0002663318320000074
最小化作为目标。
Figure BDA0002663318320000075
Figure BDA0002663318320000076
由于优选目标数据集
Figure BDA0002663318320000077
中的目标数据的伪标签具有更高的准确性,因此利用数据集
Figure BDA0002663318320000078
而计算的目标域类中心λt,k更为准确,从而有助于提升语义对齐损失函数的作用。
此外,可以利用优选目标数据集
Figure BDA0002663318320000079
中的目标数据及其伪标签来构建用于训练图3所示的第一分类器320的交叉熵损失函数(图3中的
Figure BDA00026633183200000710
),也称为第二损失函数,具体如以下数学式(6)所示:
Figure BDA00026633183200000711
其中,
Figure BDA00026633183200000712
表示在对优选目标数据集
Figure BDA00026633183200000713
中的目标数据
Figure BDA00026633183200000714
预测标签时,预测结果是其伪标签的概率。
现有技术中通常只利用具有真实标签的源数据对第一分类器320进行训练,而在本发明中,由于优选目标数据集
Figure BDA0002663318320000081
中的目标数据的伪标签的准确性较高,因此本发明进一步利用优选目标数据集
Figure BDA0002663318320000082
对第一分类器320进行训练,这有助于提升网络模型对目标数据的识别能力。
此外,优选目标数据集
Figure BDA0002663318320000083
中的目标数据可以与源数据一起被用于对比学习,以实现以下效果:约束类内特征,使其更为紧致,同时推开类间特征,使不同的类的特征之间的距离增大。在这一方面,可以构建如数学式(7)所示的对比学习损失函数Lcon(图3中未示出),也称为第三损失函数。
Figure BDA0002663318320000084
其中,xi或xj表示源数据集以及优选目标数据集
Figure BDA0002663318320000085
中的某一数据样本,f(xi)和f(xj)表示数据样本的特征。δij是指示变量,当xi和xj为同一类的数据时,δij为1;当xi和xj是不同类的数据时,δij为0。d(f(xi),f(xj))表示数据xi和数据xj的特征之间的距离。m是常数,例如m=3。
如前文所述,目前用于无监督领域自适应的知识蒸馏方法利用指数移动平均来构造教师网络,但是其中的衰减率通常设置为固定值,因此难以获得性能好的教师网络。具体而言,指数移动平均是指根据一定衰减率来缓慢地更新教师网络的参数,如以下数学式(8)所示:
Tt=decay*Tt-1+(1-decay)*S, -(8)
其中,S表示学生网络的当前参数,Tt表示教师网络的当前参数(更新后的参数),Tt-1表示教师网络的先前参数(未更新的参数),衰减率decay通常被固定地设置为0.99。
作为本发明的另一个方面,本发明提出了自学习的衰减率,以改善教师模型的性能。“自学习”是指衰减率是一个可学习的参数或是经学习网络(learnt network)的输出。在本发明中,可以使用可导变量作为衰减率,或者使用一个全连接层的输出作为衰减率。在后者情况下,例如可以将该全连接层设置在与第二分类器320_T的输出层相同的层级,使得该全连接层与输出层并行地连接到输出层的前一层。以这两种方式设置的衰减率不再是固定值,其能够根据模型变化的性能来调节集成的速率,因此有助于提升知识蒸馏的性能。
此外,作为本发明的另一个方面,本发明提出了基于域判别器的数据蒸馏。具体而言,在利用源数据基于交叉熵损失函数训练分类器时,对源数据中的与目标数据相似的源数据赋予更高的权重。通过这样做,与目标数据相似性高的源数据可以在训练中起到较大的作用,由此训练得到的分类器能够在目标域中实现更好的性能。
可以借助于域判别器的输出来判断哪些源数据与目标数据的相似性高。域判别器可以预测当前数据是源数据的概率,因此当此概率越小时,表明当前数据和目标数据的相似性越大。换言之,域判别器输出的概率与相似性之间存在相反的关系。由此,可以利用域判别器的输出来为源数据加权。
基于此原理,可以构建用于训练图3所示的神经网络的数据蒸馏损失函数Ldd(图3中未示出),也称为第四损失函数,如以下数学式(9)或(10)所示:
Ldd=∑-(1-pd)log(ps) -(9)
或者
Ldd=∑-(1/pd)log(ps) -(10)
其中,ps表示在对源数据预测标签时,预测结果是其真实标签的概率。pd表示域判别器确定的源数据来自于源域的概率,1-pd或1/pd表示对该源数据赋予的权重。
当域判别器确定的概率pd较小时(表明当前源数据和目标数据的相似性较高),则1-pd或1/pd的值较大,因此对当前源数据赋予的权重较大。由此,(与目标数据相似的)当前源数据可以在训练中起到较大的作用。
此外,作为本发明的另一个方面,本发明还改进了图2所示的自集成教师模型的架构。图4示出了改进后的网络架构。
如图4所示,源数据xSi和目标数据xTi不仅被输入到学生网络,也被输入到教师网络。相比之下,在图2中对教师网络仅输入了目标数据xTi。因此,本发明不仅针对目标域进行蒸馏学习,也针对源域进行蒸馏学习。
在图4中,ySi表示源数据xSi的真实标签,zTi表示学生网络针对目标数据xTi预测的概率(即目标数据xTi属于每个类别的概率),
Figure BDA0002663318320000101
表示教师网络针对目标数据xTi预测的概率,zSi表示学生网络针对源数据xSi预测的概率(即源数据xSi属于每个类别的概率),
Figure BDA0002663318320000102
表示教师网络针对源数据xSi预测的概率。此外,图4中的学生网络可以包括图3所示的第一特征提取器310和第一分类器320,图4中的教师网络可以包括图3中所示的第二特征提取器310_T和第二分类器320_T,上述各个预测概率可以由第一分类器320或第二分类器320_T产生。
基于以上预测概率,可以构建用于训练图3所示的神经网络的知识蒸馏损失函数Lkd(包括图3中的Lkd-s和Lkd-t),也称为第五损失函数,如以下数学式(11)所示:
Figure BDA0002663318320000103
其中,
Figure BDA0002663318320000104
表示第一分类器320和第二分类器320_T各自针对源数据xSi预测的概率的均方误差,
Figure BDA0002663318320000105
Figure BDA0002663318320000106
表示第一分类器320和第二分类器320_T各自针对目标数据xTi预测的概率的均方误差。n表示源数据的数目,m表示目标数据的数目。
基于在上文中讨论的第一至第五损失函数,可以构建用于训练图3所示的神经网络的最终损失函数L,如数学式(12)所示:
Figure BDA0002663318320000107
其中,Lc-s表示针对源数据的分类交叉熵损失函数,与图1中所示的损失函数Lc相同。Ladv表示域判别器的二值交叉熵损失函数,与图1中所示的损失函数Ladv相同。由于损失函数Lc-s和Ladv是现有技术中已知的损失函数,因此本文中将省略其详细描述。
此外,数学式(12)中的λ1和λ2分别是对第四损失函数Lkd和第五损失函数Ldd加权的权值,可以用于控制在训练过程中第四损失函数和第五损失函数起作用的程度。具体而言,可以根据数学式(13)来确定权值λ1
λ1=α·pn -(13)
其中,p=step/totalstep,即,当前迭代步数除以训练的总步数的商,因此p可以表示训练进度。α和n表示超参数,例如,可以设置α=200,n=10。图5示出了权值λ1随着训练步数的增加而变化的曲线(假设训练的总步数为5000)。
权值λ2可以根据数学式(14)来确定:
λ2=α·min((2p)n,1) -(14)
其中,p与数学式(13)中的p含义相同。α和n表示超参数,例如,可以设置α=5,n=10。图6示出了权值λ2随着训练步数的增加而变化的曲线(假设训练的总步数为5000)。
如图5和图6所示,在训练的开始阶段,由于分类器的预测和域判别器的预测都不准确,优选的是将λ1和λ2的值设置得较小,随着训练的进行,教师网络的分类器和域判别器的预测逐渐变得准确,因此可以逐渐增大λ1和λ2的值,以使得知识蒸馏损失函数Lkd和数据蒸馏损失函数Ldd可以起到更大的作用。
图7示出了根据本发明的生成优选目标数据集的方法的流程图。该方法可以由图9中的优选目标数据集生成单元960来执行。
如图7所示,在步骤S710,由第一特征提取器310针对源数据提取特征,并且由第一分类器320基于所提取的特征来预测源数据属于多个预定类别中的每个类别的概率。对应于最大概率的类别将被确定为该源数据的标签。
在步骤S720,由第一特征提取器310针对目标数据提取特征,并且由第一分类器320基于所提取的特征来预测目标数据属于每个类别的概率。对应于最大概率的类别将被确定为该目标数据的第一标签。
在步骤S730,根据数学式(2)和(3),采用类中心最近邻算法来确定该目标数据的第二标签。
在步骤S740,选择对其确定的第一标签与第二标签相同的目标数据,该第一标签或第二标签就作为所选择的目标数据的伪标签。然后,所有被选择的目标数据可以构成优选目标数据集。
图8示出了根据本发明的领域自适应神经网络的训练方法的流程图,图9示出了根据本发明的领域自适应神经网络的训练装置的模块化框图。
如图8所示,在步骤S810,根据数学式(2)、(4)和(5),基于源数据集的类中心和优选目标数据集的类中心之间的距离来构建第一损失函数La(语义对齐损失函数)。这一步骤可以由图9中的第一损失函数生成单元910来执行。
在步骤S820,根据数学式(6),基于优选目标数据集中的目标数据以及其伪标签来构建第二损失函数
Figure BDA0002663318320000121
(交叉熵损失函数)。这一步骤可以由图9中的第二损失函数生成单元920来执行。
在步骤S830,根据数学式(7),针对源数据集中的源数据以及优选目标数据集中的目标数据构建第三损失函数Lcon(对比学习损失函数)。这一步骤可以由图9中的第三损失函数生成单元930来执行。
结合图9可以看到,通过图7所示的方法而生成的优选目标数据集被用于第一损失函数至第三损失函数的构建。
然后,在步骤S840,根据数学式(9)或(10),基于域判别器输出的概率来构建第四损失函数Ldd(数据蒸馏损失函数)。这一步骤可以由图9中的第四损失函数生成单元940来执行。
在步骤S850,由第二(教师)特征提取器310_T提取源数据和目标数据的特征,并且由第二(教师)分类器320_T预测源数据和目标数据的标签。然后在步骤S860,根据数学式(11),基于第一分类器320的预测结果和第二分类器320_T的预测结果来构建第五损失函数Lkd(知识蒸馏损失函数)。步骤S860可以由图9中的第五损失函数生成单元950来执行。
然后在步骤S870,根据数学式(12),基于第一损失函数至第五损失函数的加权组合来训练神经网络。这一步骤可以由图9中的训练单元970来执行。
需要说明的是,不是必须按照图8中所示的顺序来执行本发明的训练方法。例如,生成第一损失函数至第五损失函数的顺序可以与图中所示的不同,或者可以同时生成。
本发明人已经基于MNIST,USPS,SVHN(均为公知的字符数据集)进行了测试,包括三个方向的领域自适应,即,MNIST→USPS,USPS→MNIST,SVHN→MNIST。以下表1示出了本发明的方案与现有技术(ADDA、DANN等)的性能对比。表1中的数值表示分类准确率,准确率越高,方案的性能越好。可以看出,本发明的方案与现有方案的性能相当或甚至更优。
MNIST→USPS USPS→MNIST SVHN→MNIST
source only 81.6±0.02 52.1±0.1 73.8±0.06
DANN 77.1±1.8 73.0±2.0 73.9
ADDA 89.4±0.2 90.1±0.8 76.0±1.8
CAT+RevGrad 94.0±0.7 96.0±0.9 98.8±0.02
本发明 96.5±0.01 96.1±0.0 98.3±0.0
特别地,表1中的“source only”表示只利用源数据、而不利用目标数据进行训练的方案,是最简单的方案,作为比较的基准。DANN(Domain-Adversarial Training ofNeural Networks)表示图1所示的领域对抗神经网络,ADDA(Adversarial DiscriminativeDomain Adaptation)表示对抗判别领域自适应。CAT+RevGrad在以下技术文献中有所描述:“Cluster Alignment with a Teacher for Unsupervised Domain Adaptation[C]”,DengZ等人,IEEE计算机视觉国际会议论文集,2019:9944-9953。
根据本发明的无监督领域自适应技术能够应用于广泛的领域,以下仅以举例方式给出有代表性的应用场景。
[应用场景一]语义分割(semantic segmentation)
语义分割是指将图像中表示不同物体的部分用不同颜色标识出来。在语义分割的应用场景中,由于对真实世界的图像进行人工标记的代价非常高,因此真实世界的图像很少是带有标签的。在此情况下,一种替代方法是利用仿真环境(如3D游戏)中的场景的图像来进行训练。由于在仿真环境中很容易通过编程来实现对物体的自动标记,因此很容易得到有标签的数据。这样,利用仿真环境中生成的有标签的数据来训练模型,然后利用经训练的模型来处理真实环境的图像。但是,由于仿真环境不可能与真实环境完全一致,因此利用仿真环境的数据所训练的模型在处理真实环境的图像时性能会大打折扣。
在此情况下,使用本发明的领域自适应技术,可以基于有标签的仿真环境数据和无标签的真实环境数据进行训练,从而提高模型处理真实环境图像的性能。
[应用场景二]手写字符的识别
手写字符通常包括手写的数字、文字(如中文、日文)等。在手写字符的识别中,常用的有标签的字符集包括MNIST、USPS、SVHN等,通常利用这些有标签的字符数据来训练模型。然而,在将经训练的模型应用于实际(无标签)的手写字符的识别时,其准确率可能会降低。
在此情况下,使用本发明的领域自适应技术,可以基于有标签的源数据和无标签的目标数据进行训练,从而提高模型处理目标数据的性能。
[应用场景三]时间序列数据的分类和预测
时间序列数据的预测例如包括空气污染指数预测、ICU病人住院时长(LOS)的预测、股票行情预测等等。以细颗粒物PM 2.5指数的时间序列数据为例,可以利用具有标签的训练样本集来训练预测模型。在训练完成后,可以将训练好的模型应用于实际预测中,例如,基于当前时刻之前24个小时的数据(无标签数据)来预测三天后的PM 2.5指数的范围。
在此场景中,通过使用本发明的领域自适应技术,可以基于有标签的数据和无标签的数据来训练模型,从而提高模型的预测准确度。
[应用场景四]表格型数据的分类和预测
表格型数据可以包括金融数据,例如网络借贷数据。在此示例中,为了预测贷款者是否存在逾期还款的可能性,可以构建预测模型,并且使用根据本发明的方法来训练模型。
[应用场景五]图像识别
与语义分割类似,在图像识别或图像分类的应用场景中,也存在着对于真实世界的图像数据集进行标记的代价高昂的问题。因此,可以使用本发明的领域自适应技术,选择一个已标记的数据集(如ImageNet)作为源数据集,基于该源数据集和未标记的目标数据集进行训练,从而获得性能满足要求的模型。
在上述实施例中描述的方法可以由软件、硬件或者软件和硬件的组合来实现。包括在软件中的程序可以事先存储在设备的内部或外部所设置的存储介质中。作为一个示例,在执行期间,这些程序被写入随机存取存储器(RAM)并且由处理器(例如CPU)来执行,从而实现在本文中描述的各种方法和处理。
图10示出了根据程序执行本发明的方法的计算机硬件的示例配置框图,该计算机硬件是用于训练本发明的领域自适应神经网络的装置的一个示例。此外,本发明的领域自适应神经网络也可以基于该计算机硬件来实现。
如图10所示,在计算机1000中,中央处理单元(CPU)1001、只读存储器(ROM)1002以及随机存取存储器(RAM)1003通过总线1004彼此连接。
输入/输出接口1005进一步与总线1004连接。输入/输出接口1005连接有以下组件:以键盘、鼠标、麦克风等形成的输入单元1006;以显示器、扬声器等形成的输出单元1007;以硬盘、非易失性存储器等形成的存储单元1008;以网络接口卡(诸如局域网(LAN)卡、调制解调器等)形成的通信单元1009;以及驱动移动介质1011的驱动器1010,该移动介质1011例如是磁盘、光盘、磁光盘或半导体存储器。
在具有上述结构的计算机中,CPU 1001将存储在存储单元1008中的程序经由输入/输出接口1005和总线1004加载到RAM 1003中,并且执行该程序,以便执行上文中描述的方法。
要由计算机(CPU 1001)执行的程序可以被记录在作为封装介质的移动介质1011上,该封装介质以例如磁盘(包括软盘)、光盘(包括压缩光盘-只读存储器(CD-ROM))、数字多功能光盘(DVD)等)、磁光盘、或半导体存储器来形成。此外,要由计算机(CPU 1001)执行的程序也可以经由诸如局域网、因特网、或数字卫星广播的有线或无线传输介质来提供。
当移动介质1011安装在驱动器1010中时,可以将程序经由输入/输出接口1005安装在存储单元1008中。另外,可以经由有线或无线传输介质由通信单元1009来接收程序,并且将程序安装在存储单元1008中。可替选地,可以将程序预先安装在ROM 1002或存储单元1008中。
由计算机执行的程序可以是根据本说明书中描述的顺序来执行处理的程序,或者可以是并行地执行处理或当需要时(诸如,当调用时)执行处理的程序。
本文中所描述的单元或装置仅是逻辑意义上的,并不严格对应于物理设备或实体。例如,本文所描述的每个单元的功能可能由多个物理实体来实现,或者,本文所描述的多个单元的功能可能由单个物理实体来实现。此外,在一个实施例中描述的特征、部件、元素、步骤等并不局限于该实施例,而是也可以应用于其它实施例,例如替代其它实施例中的特定特征、部件、元素、步骤等,或者与其相结合。
本发明的范围不限于在本文中描述的具体实施例。本领域普通技术人员应该理解的是,取决于设计要求和其他因素,在不偏离本发明的原理和精神的情况下,可以对本文中的实施例进行各种修改或变化。本发明的范围由所附权利要求及其等同方案来限定。
附记:
(1).一种由计算机实现的用于训练领域自适应神经网络的方法,其中所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,其中,所述计算机包括存储有指令的存储器和处理器,所述指令在被所述处理器执行时使得所述处理器执行所述方法,所述方法包括:
由所述第一特征提取单元针对已标记的源数据集合中的源数据提取第一特征,以及由所述第一分类单元基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率;
由所述第一特征提取单元针对未标记的目标数据集合中的目标数据提取第二特征,以及由所述第一分类单元基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;
计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;
在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;
基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
基于所选择的目标数据以及其伪标签来构建第二损失函数;
针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;
基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
(2).根据(1)所述的方法,还包括:
由所述判别单元基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;
基于所述判别单元确定的概率来构建第四损失函数;
基于所述第四损失函数来训练所述领域自适应神经网络。
(3).根据(2)所述的方法,其中,基于以下中的一个来构建所述第四损失函数:
所述判别单元确定的概率的倒数,以及
1减去所述判别单元确定的概率的差。
(4).根据(2)所述的方法,其中,所述领域自适应神经网络还包括第二特征提取单元和第二分类单元,上述方法还包括:
由所述第二特征提取单元针对所述源数据提取第三特征,并且由所述第二分类单元基于所述第三特征来预测所述源数据属于所述每个类别的概率;
由所述第二特征提取单元针对所述目标数据提取第四特征,并且由所述第二分类单元基于所述第四特征来预测所述目标数据属于所述每个类别的概率;
基于所述第一分类单元预测的概率和所述第二分类单元预测的概率,来构建第五损失函数;
基于所述第五损失函数来训练所述领域自适应神经网络。
(5).根据(4)所述的方法,其中,基于所述第一分类单元和所述第二分类单元各自针对所述源数据预测的概率的均方误差以及所述第一分类单元和所述第二分类单元各自针对所述目标数据预测的概率的均方误差,来构建所述第五损失函数。
(6).根据(4)所述的方法,其中,所述第二特征提取单元的参数是所述第一特征提取单元的参数的指数移动平均,并且所述第二分类单元的参数是所述第一分类单元的参数的指数移动平均,
其中,通过以下方式之一来获得在所述指数移动平均中使用的衰减率:
使用可导变量作为所述衰减率;
使用全连接层来生成所述衰减率,其中,所述全连接层被设置为与所述第二分类单元的输出层并行地连接到所述输出层的前一层。
(7).根据(4)所述的方法,其中,基于所述第一损失函数、所述第二损失函数、所述第三损失函数、所述第四损失函数和所述第五损失函数的加权组合来训练所述领域自适应神经网络,
其中,随着训练的进行,逐渐增大对于所述第四损失函数和所述第五损失函数的权值。
(8).根据(1)所述的方法,其中,所述第二损失函数是用于训练所述第一分类单元的交叉熵损失函数。
(9).根据(1)所述的方法,其中,所述判别单元经由梯度反转单元与所述第一特征提取单元连接,并且所述判别单元与所述第一特征提取单元以相互对抗的方式操作。
(10).根据(1)所述的方法,其中,所述领域自适应神经网络用于执行图像识别,并且所述源数据和所述目标数据是图像数据,或者
所述领域自适应神经网络用于处理金融数据,并且所述源数据和所述目标数据是表格类型数据,或者
所述领域自适应神经网络用于处理环境气象数据或医疗数据,并且所述源数据和所述目标数据是时间序列数据或图像数据。
(11).一种用于训练领域自适应神经网络的装置,所述领域自适应神经网络包括:
第一特征提取单元,其用于针对已标记的源数据集合中的源数据提取第一特征,并且针对未标记的目标数据集合中的目标数据提取第二特征;
第一分类单元,其基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率,并且基于所述第二特征来预测所述目标数据属于所述每个类别的概率以及将对应于最大概率的类别确定为所述目标数据的第一标签;以及
判别单元,其基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;
所述装置包括:
存储有程序的存储器;以及
一个或多个处理器,所述处理器通过执行所述程序而执行以下操作:
计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;
在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;
基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
基于所选择的目标数据以及其伪标签来构建第二损失函数;
针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;
基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
(12).一种用于训练领域自适应神经网络的装置,所述领域自适应神经网络包括:
第一特征提取单元,其用于针对已标记的源数据集合中的源数据提取第一特征,并且针对未标记的目标数据集合中的目标数据提取第二特征;
第一分类单元,其基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率,并且基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;以及
判别单元,其基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;
所述装置包括:
优选目标数据集生成单元,其被配置为:计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,以形成优选目标数据集,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
第一损失函数生成单元,其被配置为:基于所述优选目标数据集中的目标数据来计算所述目标数据集合的针对所述每个类别的类中心;基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
第二损失函数生成单元,其被配置为基于所述优选目标数据集中的目标数据以及其伪标签来构建第二损失函数;
第三损失函数生成单元,其被配置为针对所述源数据集合中的源数据以及所述优选目标数据集中的目标数据来构建第三损失函数;
训练单元,其基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
(13).一种存储有用于训练领域自适应神经网络的程序的存储介质,所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,所述程序在被计算机执行时使得所述计算机执行包括以下步骤的方法:
由所述第一特征提取单元针对已标记的源数据集合中的源数据提取第一特征,以及由所述第一分类单元基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率;
由所述第一特征提取单元针对未标记的目标数据集合中的目标数据提取第二特征,以及由所述第一分类单元基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;
计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;
在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;
基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
基于所选择的目标数据以及其伪标签来构建第二损失函数;
针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;
基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。

Claims (10)

1.一种由计算机实现的用于训练领域自适应神经网络的方法,其中所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,其中,所述计算机包括存储有指令的存储器和处理器,所述指令在被所述处理器执行时使得所述处理器执行所述方法,所述方法包括:
由所述第一特征提取单元针对已标记的源数据集合中的源数据提取第一特征,以及由所述第一分类单元基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率;
由所述第一特征提取单元针对未标记的目标数据集合中的目标数据提取第二特征,以及由所述第一分类单元基于所述第二特征来预测所述目标数据属于所述每个类别的概率,并且将对应于最大概率的类别确定为所述目标数据的第一标签;
计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;
在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;
基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
基于所选择的目标数据以及其伪标签来构建第二损失函数;
针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;
基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
2.根据权利要求1所述的方法,还包括:
由所述判别单元基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;
基于所述判别单元确定的概率来构建第四损失函数;
基于所述第四损失函数来训练所述领域自适应神经网络。
3.根据权利要求2所述的方法,其中,基于以下中的一个来构建所述第四损失函数:
所述判别单元确定的概率的倒数,以及
1减去所述判别单元确定的概率的差。
4.根据权利要求2所述的方法,其中,所述领域自适应神经网络还包括第二特征提取单元和第二分类单元,所述方法还包括:
由所述第二特征提取单元针对所述源数据提取第三特征,并且由所述第二分类单元基于所述第三特征来预测所述源数据属于所述每个类别的概率;
由所述第二特征提取单元针对所述目标数据提取第四特征,并且由所述第二分类单元基于所述第四特征来预测所述目标数据属于所述每个类别的概率;
基于所述第一分类单元预测的概率和所述第二分类单元预测的概率,来构建第五损失函数;
基于所述第五损失函数来训练所述领域自适应神经网络。
5.根据权利要求4所述的方法,其中,基于所述第一分类单元和所述第二分类单元各自针对所述源数据预测的概率的均方误差以及所述第一分类单元和所述第二分类单元各自针对所述目标数据预测的概率的均方误差,来构建所述第五损失函数。
6.根据权利要求4所述的方法,其中,所述第二特征提取单元的参数是所述第一特征提取单元的参数的指数移动平均,并且所述第二分类单元的参数是所述第一分类单元的参数的指数移动平均,
其中,通过以下方式之一来获得在所述指数移动平均中使用的衰减率:
使用可导变量作为所述衰减率;
使用全连接层来生成所述衰减率,其中,所述全连接层被设置为与所述第二分类单元的输出层并行地连接到所述输出层的前一层。
7.根据权利要求4所述的方法,其中,基于所述第一损失函数、所述第二损失函数、所述第三损失函数、所述第四损失函数和所述第五损失函数的加权组合来训练所述领域自适应神经网络,
其中,随着训练的进行,逐渐增大对于所述第四损失函数和所述第五损失函数的权值。
8.根据权利要求1所述的方法,其中,所述第二损失函数是用于训练所述第一分类单元的交叉熵损失函数。
9.一种用于训练领域自适应神经网络的装置,所述领域自适应神经网络包括:
第一特征提取单元,其被配置为针对已标记的源数据集合中的源数据提取第一特征,并且针对未标记的目标数据集合中的目标数据提取第二特征;
第一分类单元,其被配置为基于所述第一特征来预测所述源数据属于多个类别中的每个类别的概率,并且被配置为基于所述第二特征来预测所述目标数据属于所述每个类别的概率以及将对应于最大概率的类别确定为所述目标数据的第一标签;以及
判别单元,其被配置为基于所述第一特征和所述第二特征来确定当前输入的数据是源数据的概率;
所述装置包括:
存储有程序的存储器;以及
一个或多个处理器,所述处理器通过执行所述程序而执行以下操作:
计算源数据集合的针对所述每个类别的类中心与所述目标数据的特征之间的距离,并且将距离最近的类中心所对应的类别确定为所述目标数据的第二标签;
在所述目标数据集合中选择对其确定的所述第一标签与所述第二标签相同的目标数据,其中,所述第一标签或所述第二标签作为所选择的目标数据的伪标签;
基于所选择的目标数据计算所述目标数据集合的针对所述每个类别的类中心;
基于所述源数据集合的类中心和所计算的目标数据集合的类中心之间的距离来构建第一损失函数;
基于所选择的目标数据以及其伪标签来构建第二损失函数;
针对所述源数据集合中的源数据以及所选择的目标数据构建第三损失函数;
基于所述第一损失函数、所述第二损失函数和所述第三损失函数来训练所述领域自适应神经网络。
10.一种存储有用于训练领域自适应神经网络的程序的存储介质,所述领域自适应神经网络包括第一特征提取单元、第一分类单元以及判别单元,所述程序在被计算机执行时使得所述计算机执行根据权利要求1-8中任一项所述的方法。
CN202010911149.0A 2020-09-02 2020-09-02 领域自适应神经网络的训练方法 Pending CN114139676A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202010911149.0A CN114139676A (zh) 2020-09-02 2020-09-02 领域自适应神经网络的训练方法
JP2021136658A JP2022042487A (ja) 2020-09-02 2021-08-24 ドメイン適応型ニューラルネットワークの訓練方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010911149.0A CN114139676A (zh) 2020-09-02 2020-09-02 领域自适应神经网络的训练方法

Publications (1)

Publication Number Publication Date
CN114139676A true CN114139676A (zh) 2022-03-04

Family

ID=80438142

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010911149.0A Pending CN114139676A (zh) 2020-09-02 2020-09-02 领域自适应神经网络的训练方法

Country Status (2)

Country Link
JP (1) JP2022042487A (zh)
CN (1) CN114139676A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114399640A (zh) * 2022-03-24 2022-04-26 之江实验室 一种不确定区域发现与模型改进的道路分割方法及装置
CN114445670A (zh) * 2022-04-11 2022-05-06 腾讯科技(深圳)有限公司 图像处理模型的训练方法、装置、设备及存储介质

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116070796B (zh) * 2023-03-29 2023-06-23 中国科学技术大学 柴油车排放等级评估方法及系统
CN117017288B (zh) * 2023-06-14 2024-03-19 西南交通大学 跨被试情绪识别模型及其训练方法、情绪识别方法、设备
CN116452897B (zh) * 2023-06-16 2023-10-20 中国科学技术大学 跨域小样本分类方法、系统、设备及存储介质

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114399640A (zh) * 2022-03-24 2022-04-26 之江实验室 一种不确定区域发现与模型改进的道路分割方法及装置
CN114399640B (zh) * 2022-03-24 2022-07-15 之江实验室 一种不确定区域发现与模型改进的道路分割方法及装置
CN114445670A (zh) * 2022-04-11 2022-05-06 腾讯科技(深圳)有限公司 图像处理模型的训练方法、装置、设备及存储介质

Also Published As

Publication number Publication date
JP2022042487A (ja) 2022-03-14

Similar Documents

Publication Publication Date Title
CN113378632B (zh) 一种基于伪标签优化的无监督域适应行人重识别方法
CN114139676A (zh) 领域自适应神经网络的训练方法
CN109583501B (zh) 图片分类、分类识别模型的生成方法、装置、设备及介质
CN108647736B (zh) 一种基于感知损失和匹配注意力机制的图像分类方法
CN114841257B (zh) 一种基于自监督对比约束下的小样本目标检测方法
CN109800437A (zh) 一种基于特征融合的命名实体识别方法
CN112257449B (zh) 命名实体识别方法、装置、计算机设备和存储介质
CN111414461A (zh) 一种融合知识库与用户建模的智能问答方法及系统
CN113392967A (zh) 领域对抗神经网络的训练方法
CN115408525B (zh) 基于多层级标签的信访文本分类方法、装置、设备及介质
CN113469186A (zh) 一种基于少量点标注的跨域迁移图像分割方法
CN112232395B (zh) 一种基于联合训练生成对抗网络的半监督图像分类方法
CN113723083A (zh) 基于bert模型的带权消极监督文本情感分析方法
CN115984213A (zh) 基于深度聚类的工业产品外观缺陷检测方法
CN112527959B (zh) 基于无池化卷积嵌入和注意分布神经网络的新闻分类方法
CN114675249A (zh) 基于注意力机制的雷达信号调制方式识别方法
CN114048290A (zh) 一种文本分类方法及装置
CN113870863A (zh) 声纹识别方法及装置、存储介质及电子设备
CN116189671B (zh) 一种用于语言教学的数据挖掘方法及系统
CN116720498A (zh) 一种文本相似度检测模型的训练方法、装置及其相关介质
CN116433909A (zh) 基于相似度加权多教师网络模型的半监督图像语义分割方法
CN114495114A (zh) 基于ctc解码器的文本序列识别模型校准方法
CN114139655A (zh) 一种蒸馏式竞争学习的目标分类系统和方法
CN113851149A (zh) 一种基于对抗迁移和Frobenius范数的跨库语音情感识别方法
CN113239809A (zh) 基于多尺度稀疏sru分类模型的水声目标识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination