CN112836739A - 基于动态联合分布对齐的分类模型建立方法及其应用 - Google Patents
基于动态联合分布对齐的分类模型建立方法及其应用 Download PDFInfo
- Publication number
- CN112836739A CN112836739A CN202110128228.9A CN202110128228A CN112836739A CN 112836739 A CN112836739 A CN 112836739A CN 202110128228 A CN202110128228 A CN 202110128228A CN 112836739 A CN112836739 A CN 112836739A
- Authority
- CN
- China
- Prior art keywords
- distribution
- data set
- loss
- domain data
- alignment
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000009826 distribution Methods 0.000 title claims abstract description 291
- 238000000034 method Methods 0.000 title claims abstract description 91
- 238000013145 classification model Methods 0.000 title claims abstract description 52
- 238000012549 training Methods 0.000 claims abstract description 52
- 238000002372 labelling Methods 0.000 claims description 8
- 238000004590 computer program Methods 0.000 claims description 6
- 230000006978 adaptation Effects 0.000 abstract description 13
- 230000008569 process Effects 0.000 description 18
- 238000004364 calculation method Methods 0.000 description 17
- 238000005259 measurement Methods 0.000 description 16
- 230000006870 function Effects 0.000 description 7
- 238000013508 migration Methods 0.000 description 7
- 238000005457 optimization Methods 0.000 description 6
- 239000004576 sand Substances 0.000 description 6
- 230000005012 migration Effects 0.000 description 5
- 230000008859 change Effects 0.000 description 4
- 230000003247 decreasing effect Effects 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000003203 everyday effect Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000000463 material Substances 0.000 description 1
- 238000000691 measurement method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/245—Classification techniques relating to the decision surface
- G06F18/2451—Classification techniques relating to the decision surface linear, e.g. hyperplane
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于动态联合分布对齐的分类模型建立方法及其应用,属于域适应领域,包括:分别为源域数据集和目标域数据集中的样本赋权重,使类别分布相同;将两个数据集中的样本输入联合分布对齐模型,并计算损失;联合分布对齐模型包括:特征提取器,用于提取输入样本的特征;特征判别器,用于判断特征提取器提取的特征来源;分类器,用于对特征提取器提取的特征进行分类以产生相应的类别标签;类别判别器,用于判断分类器产生的类别标签的来源;根据损失更新联合分布对齐模型后为目标域数据集中的样本标注伪标签以更新样本权重;迭代训练结束后,由特征提取器和分类器构成分类模型。本发明能够解决训练数据缺乏的问题并减少训练资源和时间。
Description
技术领域
本发明属于域适应领域,更具体地,涉及一种基于动态联合分布对齐的分类模型建立方法及其应用。
背景技术
传统的分类问题要求训练集和测试集的数据是独立同分布的,且训练集数据丰富有利于分类器的训练和性能提升。随着网络技术的不断发展,我们进入了大数据时代,每天每时都会产生大量的信息数据,使得分类器可以依赖这些数据不断的训练和更新模型,极大地提高了分类器的性能。但是,这些数据中却很少有完善的数据标注,这提升了分类器的训练的困难程度,而人工标注数据又十分耗费人力物力,给机器学习和深度学习的模型训练和更新带来了新的挑战,该问题在图像分类和文本分类任务中尤为明显。为了解决这个问题,迁移学习应运而生。
域适应方法是迁移学习的一大类子问题,针对训练集数据不足的问题,寻找一个相似的、有标注的数据集来帮助待分类数据集训练分类器,从而可以准确地为数据集进行分类。域适应方法的关键在于利用数据集的相似性,减少数据集间的差异。进一步地,减少数据集差异的关键之一就在于差异的度量方式。常见的度量方式有两种,一种是利用常见的距离度量公式进行度量,其中最常用的是最大均值差异;另一种方法是基于对抗的方式,借助生成对抗网络结构进行差异的度量和模型的训练,这种对抗的方式可以避免显式的距离度量,学习更多的非线性特征,应用更加广泛,进一步提高了迁移性能。然而,目前大部分方法还是利用常见的距离度量公式进行度量,基于对抗的方法较少。
域适应中最常用的一类方法是数据分布自适应。通过学习一些变换减少数据集间数据分布的差异。边缘分布的差异体现在数据的整体不同,条件分布的差异体现在具体到类中的不同。根据数据分布的性质,可以具体分为边缘分布自适应、条件分布自适应和联合分布自适应。一般数据集间的边缘分布和条件分布都是不同的,但是目前的数据分布自适应方法大部分都是边缘分布自适应方法,联合分布自适应,即同时对齐边缘分布和条件分布的方法较少。
此外,对于不同的数据集,边缘分布对齐和条件分布对齐的相对重要性也不同,对于整体看着不同的数据集,应该优先对齐边缘分布,整体相似,具体类不同的应该优先对齐条件分布。但是目前大部分联合分布自适应方法都是认为这两部分有相等的重要性。针对这类问题,研究者提出了动态分布对齐方法,根据边缘分布和条件分布差异的大小比值赋予边缘分布对齐和条件分布对齐相应的权重,并在训练过程中不断更新。然而,这些方法没有考虑到边缘分布距离对条件分布距离的影响,此外,在度量条件分布时,每一类数据都需要一个类判别器,需要更多的训练资源和时间。
综上所述,针对图像或文本分类中的训练数据缺乏问题,目前已有的域适应方法大部分仅对齐了一种数据分布,基于对抗的联合分布对齐方法较少;对于联合分布对齐中边缘分布和条件分布相对重要性不同的问题,可以通过动态分布对齐的思想来解决,但这类方法需要更多的训练资源,且没有考虑边缘分布对条件分布的影响。因此,现有的域适应方法在解决图像或文本分类中的训练数据缺乏的问题时,需要耗费大量的训练资源和时间,且训练精度有待进一步提高。
发明内容
针对现有技术的缺陷和改进需求,本发明提供了一种基于动态联合分布对齐的分类模型建立方法及其应用,其目的在于,基于域适应的方法解决图像或文本分类中的训练数据缺乏的问题,同时有效减少所需的训练资源和时间。
为实现上述目的,按照本发明的一个方面,提供了一种基于动态联合分布对齐的分类模型建立方法,包括如下步骤:
(S1)分别为已标注类别标签的源域数据集和未标注类别标签的目标域数据集中的各样本赋予权重,使得加权后两个数据集的类别分布相同;两个数据集的特征空间和标签空间相同,但边缘分布和条件分布不同;目标域数据集属于目标分类任务,目标分类任务为图像分类任务或文本分类任务;
(S2)将两个数据集中的样本输入联合分布对齐模型,并计算相应的损失;联合分布对齐模型包括基于对抗的边缘分布对齐网络和条件分布对齐网络,边缘分布对齐网络包括一个特征提取器和一个特征判别器,条件分布对齐网络包括一个分类器和一个类别判别器,特征提取器用于提取输入样本的特征,特征判别器用于判断特征提取器提取的特征来自于哪一个数据集,分类器用于对特征提取器提取的特征进行分类以产生相应的类别标签,类别判别器用于判断分类器产生的类别标签来自于哪一个数据集;
(S3)根据所计算的损失更新联合分布对齐模型后,利用特征提取器和分类器为目标域数据集中的样本标注伪标签,并基于该标注结果更新目标域数据集中各样本的权重;
(S4)重复执行步骤(S2)~(S3)以对联合分布对齐模型进行迭代训练,直至达到预设的迭代终止条件;迭代终止后,利用特征提取器和分类器相连构成对目标分类任务进行分类的分类模型。
联合分布对齐分为边缘分布对齐和条件分布对齐两部分,由于条件分布不易求得,本发明通过对齐类别分布来代替条件分布,具体地,首先对源域数据集和目标域数据集中的样本赋予权重,以对样本进行放缩,使得加权后两个数据集中的类别分布相同,即加权后两个数据集中同一类别的样本数量相同,其次,在联合分布对齐模型中,条件分布对齐网络中仅包含一个类别判别器,用于判断分类器产生的类别标签具体来自于哪一个数据集,在完成条件分布对齐的同时,整个联合分布对齐模型中仅包含两个判别器,即一个特征判别器和一个类别判别器,相比于现有的动态分布对齐方法中需要一个边缘分布判别器和与C(类别总数)个条件分布判别器而言,本发明有效减少了需要训练的模块数量,简化了模型,从而很大程度上减少了计算所需的资源和时间。
进一步地,步骤(S2)中,所计算的损失包括生成部分损失Gen_loss和判别部分损失Dis_loss,计算表达式如下:
Gen_loss=LCla+αLFea
其中,LCla是分类器的损失,LFea是特征提取器的损失,是特征判别器的损失,是类别判别器的损失;α表示生成部分损失Gen_loss中LFea的权重,k表示源域数据集与目标域数据集之间的边缘分布距离对条件分布距离的影响因子,μ表示源域数据集与目标域数据集的条件分布对齐权重。
在分别可以完成边缘分布对齐和条件分布对齐之后,一个重要的问题就是如何将它们结合,可以进一步提高迁移性能。对于不同的数据集,由于数据分布差异不同,其边缘分布对齐和条件分布对齐的相对重要性也不同,若数据集看起来很不相同,则应优先对齐边缘分布,若数据集看起来相似,而具体到每一类的时候不同,这时应优先对齐条件分布。边缘分布距离会影响条件分布距离,本发明在计算判别部分损失时,引入了源域数据集与目标域数据集之间的边缘分布距离对条件分布距离的影响因子k,能够进一步提高迁移后分类器的分类性能。
其中,dC和dM分别表示条件分布距离和边缘分布距离。
传统的联合分布自适应方法大多是为边缘分布对齐和条件分布对齐赋予平等的权重进行训练,这种方法显然不适合所有的数据集;如果人工实验寻找最优解,虽然可以找到最合适的权重,但是却需要进行大量的重复实验,耗费了大量的计算资源和时间。本发明通过动态分布对齐方法,根据当前边缘分布和条件分布的差异,计算得到两个分布的权重,并在训练过程中不断更新,从而达到近似的结果,在准确衡量边缘分布对齐和条件分布对齐的相对重要性的同时,有效减少了资源和时间消耗。
进一步地,条件分布距离和边缘分布距离均通过A-distance距离度量。
本发明使用A-distance距离度量能够准确度量图像分类或文本分类任务中,源域数据集和目标域数据集之间的条件分布距离和边缘分布距离。
进一步地,计算条件分布距离时,以类别判别器作为用于计算A-distance距离的线性分类器;计算边缘分布距离时,以特征判别器作为用于计算A-distance距离的线性分类器。
A-distance是一种常见的用于度量分布距离的度量公式,其形式定义为建立一个线性分类器来区分两个数据领域的hinge损失(即进行二类分类的hinge损失),其计算方式是先在源域和目标域上训练一个二分类器h,使得这个分类器可以区分样本是来自哪一个领域。假设有两个数据分布Ds和Dt,则它们的A-distance距离可根据如下公式计算:
d(Ds,Dt)=2(1-2d(h))
其中,d(h)代表分类器h的损失;根据A-distance的定义可知,训练的二分类器与生成对抗网络中的判别器定义相同,本发明相应地以特征判别器和类别判别器替代边缘分布距离度量和条件分布距离度量中的二分类器,可以准确地完成距离度量,此时,d(h)就是特征判别器或类别判别器的损失,由于在训练过程中判别器的损失已经存在,可以分别将特征判别器和类别判别器的损失值代入A-distance公式中进行计算,得到边缘分布距离和条件分布距离,再代入权重公式中计算,得到边缘分布对齐和条件分布对齐的权重,而无需引入其他的模块训练,能够有效减少资源和时间消耗。
进一步地,在迭代训练过程中,生成部分损失Gen_loss中LFea的权重α动态递减。
本发明在训练过程中,动态减小生成部分损失Gen_loss中特征提取器损失LFea的权重α,能够在一个好的特征表示的基础上,减少其对分类器的影响,最终有利于得到质量较高的伪标签。
其中,ws(x,y)表示源域数据集中样本(x,y)的权重,Maxs表示源域数据集中含样本最多的类别所包含的样本数量,Num(y)表示源域数据集中标签为y的样本数量;wt(x)表示目标域数据集中样本x的权重,Maxt表示标注伪标签后目标域数据集中含样本数量最多的类别所包含的样本数量,Num(y′)表示标注伪标签后目标域数据集中伪标签为y′的样本数量。
本发明通过上述计算公式为源域数据集或者标注了伪标签的目标域数据集中的样本赋予权重,能够有效保证两个数据集的类别分布一致,即同一类别所包含的样本数量相同,从而便于通过类别分布对齐替代难以求解的条件分布对齐。
按照本发明的另一个方面,提供了一种图像分类方法,包括:
将图像分类任务中待分类的图像数据输入本发明提供的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由分类模型输出分类结果。
通过本发明提供的基于动态联合分布对齐的分类模型建立方法,能够在图像分类任务的训练数据缺乏的情况下,建立具有良好分类性能的分类模型,因此,基于该分类模型,本发明所提供的图像分类方法,能够准确完成图像分类。
按照本发明的又一个方面,提供了一种文本分类方法,包括:
将文本分类任务中待分类的文本数据输入本发明提供的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由分类模型输出分类结果。
通过本发明提供的基于动态联合分布对齐的分类模型建立方法,能够在文本分类任务的训练数据缺乏的情况下,建立具有良好分类性能的分类模型,因此,基于该分类模型,本发明所提供的文本分类方法,能够准确完成文本分类。
按照本发明的又一个方面,提供了一种计算机可读存储介质,包括存储的计算机程序;计算机程序被处理器执行时,控制计算机可读存储介质所在设备执行本发明提供的基于动态联合分布对齐的分类模型建立方法,和/或本发明提供的图像分类方法,和/或本发明提供的文本分类方法。
总体而言,通过本发明所构思的以上技术方案,能够取得以下有益效果:
(1)本发明对源域数据集和目标域数据集中的样本赋予权重,使得加权后两个数据集中的类别分布相同,并且在联合分布对齐模型中,条件分布对齐网络中仅包含一个类别判别器,用于判断分类器产生的类别标签具体来自于哪一个数据集,由此通过对齐类别分布来代替难以求解的条件分布,并且整个模型中,仅包含两个判别器,有效减少了需要训练的模块数量,简化了模型,从而很大程度上减少了计算所需的资源和时间。
(2)本发明在模型训练过程中,计算判别部分损失时,考虑了源域数据集与目标域数据集之间的边缘分布距离对条件分布距离的影响,能够进一步提高迁移后分类器的分类性能。
(3)本发明通过动态分布对齐方法,根据当前边缘分布和条件分布的差异,计算得到两个分布的权重,并在训练过程中不断更新,从而达到近似的结果,在准确衡量边缘分布对齐和条件分布对齐的相对重要性的同时,有效减少了资源和时间消耗。
(4)本发明使用A-distance距离度量图像分类或文本分类任务中,源域数据集和目标域数据集之间的条件分布距离和边缘分布距离,能够准确完成距离的度量;在其优选方案中,分别以特征判别器和类别判别器替代边缘分布距离度量和条件分布距离度量中的二分类器,可以准确地完成距离度量,并且无需引入其他的模块训练,能够有效减少资源和时间消耗。
附图说明
图1为本发明实施例提供的基于动态联合分布对齐的分类模型建立方法流程图;
图2为本发明实施例提供的联合分布对齐模型对齐示意图;
图3为本发明实施例提供的边缘分布距离和条件分布距离随迭代次数的变化示意图;
图4为本发明实施例提供的动态分布对齐方法的伪代码示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
在本发明中,本发明及附图中的术语“第一”、“第二”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
针对现有技术利用域适应的方法解决图像分类或文本分类任务中训练数据缺乏的问题时,需要消耗大量的训练资源和训练时间的技术问题,本发明提供了一种基于动态联合分布对齐的分类模型建立方法及其应用,其整体思路在于:优先对源域数据集和目标域数据集进行边缘分布对齐,在两个数据集的边缘分布尽可能相近的基础上,以类别分布对齐的方式代替难以求解的条件分布对齐,有效减少模型中判别器的数量,从而有效减少所需消耗的训练资源和时间;在此基础上,在训练过程中考虑边缘分布距离对条件分布距离的影响因子,利用条件分布距离和边缘分布距离计算两个分布的权重并动态更新,在准确衡量边缘分布对齐和条件分布对齐的相对重要性的同时,进一步减少资源和时间消耗。
以下为实施例。
实施例1:
一种基于动态联合分布对齐的分类模型建立方法,如图1所示,包括如下步骤:
(S1)分别为已标注类别标签的源域数据集和未标注类别标签的目标域数据集中的各样本赋予权重,使得加权后两个数据集的类别分布相同;两个数据集的特征空间和标签空间相同,但边缘分布和条件分布不同;
本实施例中,目标域数据集属于图像分类任务;
分别以和表示已标注类别标签的源域数据集和未标注类别标签的目标域数据集,其中,ns和nt分别表示源域数据集Ds和目标域数据集Dt中的样本数量,(xi,yi)表示源域数据集Ds中的样本,xi表示样本中的图像数据,yi表示该样本的类别标签,xj表示目标域数据集Dt中的样本;源域数据集Ds和目标域数据集Dt的特征空间分别是Xs和Xt,边缘分布(即特征分布)分别是P(xs)和P(xt),条件分布分别是P(ys|xs)和P(yt|xt);
联合分布对齐可以分为边缘分布对齐和条件分布对齐两部分,由于条件分布不易求得,本实施例通过对齐类别分布来代替条件分布,这要求数据集的类别分布相同,否则,对于某些特征,在不同的数据集可能会产生不同的分类结果;实际应用中,数据集的类别分布很可能是不同的,因此,本实施例通过对两个数据集中的样本赋予权重,以对数据集中的每个样本进行放缩,使得处理后的数据集的类别分布相同;
由于源域数据由充足的标签,可以直接根据其标签计算得到每个样本的权重并在训练过程中保持不变,相关的计算公式为:其中,ws(x,y)表示源域数据集中样本(x,y)的权重,Maxs表示源域数据集中含样本最多的类别所包含的样本数量,Num(y)表示源域数据集中标签为y的样本数量;
对于目标域数据集,其中的样本没有标签,直接各样本的权重初始化为相同的数值,可选地,本实施例中,该数值具体为1;在后续训练过程中会通过为其赋予伪标签的形式计算样本权重并在训练过程中不断更新;
(S2)将两个数据集中的样本输入联合分布对齐模型,并计算相应的损失;
如图2所示,本实施例中,联合分布对齐模型包括基于对抗的边缘分布对齐网络和条件分布对齐网络;边缘分布对齐网络包括一个特征提取器和一个特征判别器,即图2中的Fea和Disf,条件分布对齐网络包括一个分类器和一个类别判别器,即图2中的Cla和Disc;
传统的生成对抗网络主要分为两个部分:生成器和判别器。生成器生成尽量以假乱真的数据,希望可以骗过判别器,判别器判断输入的数据是真实的还是生成器产生的假数据,希望可以尽可能地判断准确;本实施例中,由于在迁移学习中,已经存在两个不同的数据分布,即源域数据分布和目标域数据分布,因此不需要生成数据,生成器只需要提取数据的特征,希望提取后的两个数据集的数据分布尽可能相同;判别器负责判断生成器提取的特征来自于哪个数据集;当判别器不能正确判断数据来源时,就认为提取的数据分布几乎相同;
相应地,本实施例中,特征提取器Fea用于提取输入样本的特征,特征判别器Disf用于判断特征提取器Fea提取的特征来自于哪一个数据集,分类器Cla用于对特征提取器Fea提取的特征进行分类以产生相应的类别标签,类别判别器用于判断分类器Disc产生的类别标签来自于哪一个数据集;
动态分布对齐主要包括各模块损失函数的计算以及动态分布对齐参数的计算,通过在训练过程中不断更新边缘分布对齐和条件分布对齐的权重实现动态联合分布对齐,以提高迁移后分类器的性能;
将源域数据集和目标域数据集都输入至图2所示的联合分布对齐模型,并得到相应的结果输出后,所计算的损失包括:生成部分损失Gen_loss和判别部分损失Dis_loss,计算表达式如下:
Gen_loss=LCla+αLFea
其中,LCla是分类器的损失,LFea是特征提取器的损失,是特征判别器的损失,是类别判别器的损失;α表示生成部分损失Gen_loss中LFea的权重,k表示源域数据集与目标域数据集之间的边缘分布距离对条件分布距离的影响因子,μ表示源域数据集与目标域数据集的条件分布对齐权重;
边缘分布距离会影响条件分布距离,如图3可知,条件分布距离随着边缘分布距离的变化而变化,良好的特征表示有利于条件分布对齐,为了进一步提高边缘分布对齐的权重,以减少其对条件分布对齐的损失,本实施例引入了边缘分布距离对条件分布距离的影响因子k,用于衡量边缘分布距离对条件分布距离的影响程度,并在边缘分布距离的基础上增加部分条件分布的权重,则边缘分布对齐的权重应为边缘分布的距离和部分条件分布距离之和占总分布距离的比重;由于边缘分布对齐权重有所增加,在动态分布对齐过程中,会优先对齐边缘分布;可选地,本实施例中,k默认取值为0.3,在本发明其他的一些实施例中,k的取值可根据两个数据集实际的特性调整该取值,或者将该取值作为超参数在训练过程中具体确定;
本实施例中,α也是一个超参数,本实施例为α设置了较大的初始值,可选地,α的初始值默认取为1,以提高初始时刻Fea训练的权重,以便得到一个较好的特征表示,从而提高伪标签的质量,为了减少训练过程中特征提取器对分类器的影响,同时将α设置为一个动态递减的超参数,例如按照指数形式动态递减,直至递减至预设的下限值,在其他的一些实施例中,也可以按照其他形式递减;
本实施例在计算判别部分损失时,引入了源域数据集与目标域数据集之间的边缘分布距离对条件分布距离的影响因子k,能够进一步提高迁移后分类器的分类性能;
联合分布对齐模型中,对于生成部分,特征提取器Fea提取源域和目标域的特征,希望提取到的目标域的边缘分布与源域的边缘分布相同,根据生成对抗网络的优化目标,可得其优化目标如下:
分类器Cla利用源域的标记数据进行训练,属于监督学习过程;优化目标如下:
对于判别部分,特征判别器Disf判断特征提取器Fea提取的特征来源,其优化目标如下:
类别判别器Disc判断分类器Cla生成的类别标签的来源;由于这里实际上对齐的是类别分布,因此优化目标需要在传统的生成对抗网络的损失函数的基础上加入实例的类别权重,以保证加权后数据集的类别分布相同。类别判别器Disc的损失函数如下:
其中,xt表示目标域样本,指目标域数据服从目标域的特征分布;(xs,ys)表示源域样本,指源域数据服从源域的特征分布和类别分布,cs为类别数;表示指示函数,在k=ys时,函数值为1,否则,函数值为0;
为了在不消耗大量训练资源和时间的情况下,找到边缘分布对齐和条件分布对齐的合适权重,作为一种可选的实施方式,本实施例中,具体使用边缘分布距离和条件分布距离来计算两个分布的权重,其中,条件分布对齐权重μ的计算表达式具体为:dC和dM分别表示条件分布距离和边缘分布距;
本实施例通过动态分布对齐方法,根据当前边缘分布和条件分布的差异,计算得到两个分布的权重,并在训练过程中不断更新,从而达到近似的结果,在准确衡量边缘分布对齐和条件分布对齐的相对重要性的同时,有效减少了资源和时间消耗;
动态分布对齐参数,即条件分布对齐权重μ和边缘分布权重γ的计算依赖于两个判别器的损失函数值;作为一种可选的实施方式,本实施例中,利用A-distance计算数据集的边缘分布差异和条件分布差异,根据两个差异的比值计算并动态更新两部分分布对齐的权重,实现联合分布对齐;
A-distance是一种常见的用于度量分布距离的度量公式,其形式定义为建立一个线性分类器来区分两个数据领域的hinge损失(即进行二类分类的hinge损失),其计算方式是先在源域和目标域上训练一个二分类器h,使得这个分类器可以区分样本是来自哪一个领域;假设有两个数据分布Ds和Dt,则它们的A-distance距离可根据如下公式计算:
d(Ds,Dt)=2(1-2d(h))
其中,d(h)代表分类器h的损失;根据A-distance的定义可知,训练的二分类器与生成对抗网络中的判别器定义相同,因此,为了进一步减少训练资源和时间的消耗,作为一种可选的实施方式,本实施例中,计算条件分布距离时,以类别判别器作为用于计算A-distance距离的线性分类器;计算边缘分布距离时,以特征判别器作为用于计算A-distance距离的线性分类器,由此可以准确地完成距离度量,此时,d(h)就是特征判别器或类别判别器的损失,由于在训练过程中判别器的损失已经存在,可以分别将Disf和Disc的损失值代入A-distance公式中进行计算,得到边缘分布距离和条件分布距离,再代入权重公式中计算,得到边缘分布对齐和条件分布对齐的权重,而无需引入其他的模块训练,能够有效减少资源和时间消耗;
(S3)根据所计算的损失更新联合分布对齐模型后,利用特征提取器和分类器为目标域数据集中的样本标注伪标签,并基于该标注结果更新目标域数据集中各样本的权重;
其中,wt(x)表示目标域数据集中样本x的权重,Maxt表示标注伪标签后目标域数据集中含样本数量最多的类别所包含的样本数量,Num(y′)表示标注伪标签后目标域数据集中伪标签为y′的样本数量;
(S4)重复执行步骤(S2)~(S3)以对联合分布对齐模型进行迭代训练,直至达到预设的迭代终止条件;迭代终止后,利用特征提取器和分类器相连构成对目标分类任务进行分类的分类模型;
通过重复执行步骤(S2)~(S3)对联合分布对齐模型进行迭代训练,即可实现动态联合分布对齐,图4所示为用于实现该动态联合分布对齐的一种可选的实现伪代码;可选地,在迭代训练过程中,利用RMSprop优化器进行优化迭代,取小批次数据进行处理,默认值为32;
本实施例中,预设的迭代终止条件,具体是指达到了预设的最大迭代次数。
本实施例中,整个联合分布对齐模型中仅包包含两个判别器,即一个特征判别器和一个类别判别器,相比于现有的动态分布对齐方法中需要一个边缘分布判别器和与C(类别总数)个条件分布判别器而言,本实施例有效减少了需要训练的模块数量,简化了模型,从而很大程度上减少了计算所需的资源和时间。
实施例2:
一种图像分类方法,包括:
将图像分类任务中待分类的图像数据输入上述实施例1提供的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由分类模型输出图像分类结果。
通过上述实施例1提供的基于动态联合分布对齐的分类模型建立方法,能够在图像分类任务的训练数据缺乏的情况下,建立具有良好分类性能的分类模型,因此,基于该分类模型,本发明所提供的图像分类方法,能够准确完成图像分类。
实施例3:
一种基于动态联合分布对齐的分类模型建立方法,本实施例与上述实施例1类似,所不同之处在于,本实施例中,目标分类任务为文本分类任务。
实施例4:
一种文本分类方法,包括:
将文本分类任务中待分类的文本数据输入上述实施例3提供的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由分类模型输出文本分类结果。
通过本实施例提供的基于动态联合分布对齐的分类模型建立方法,能够在文本分类任务的训练数据缺乏的情况下,建立具有良好分类性能的分类模型,因此,基于该分类模型,本实施例所提供的文本分类方法,能够准确完成文本分类。
实施例5:
一种计算机可读存储介质,包括存储的计算机程序;计算机程序被处理器执行时,控制计算机可读存储介质所在设备执行上述实施例1或3提供的基于动态联合分布对齐的分类模型建立方法,和/或上述实施例2提供的图像分类方法,和/或上述实施例4提供的文本分类方法。
本领域的技术人员容易理解,以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于动态联合分布对齐的分类模型建立方法,其特征在于,包括如下步骤:
(S1)分别为已标注类别标签的源域数据集和未标注类别标签的目标域数据集中的各样本赋予权重,使得加权后两个数据集的类别分布相同;两个数据集的特征空间和标签空间相同,但边缘分布和条件分布不同;所述目标域数据集属于目标分类任务,所述目标分类任务为图像分类任务或文本分类任务;
(S2)将两个数据集中的样本输入联合分布对齐模型,并计算相应的损失;所述联合分布对齐模型包括基于对抗的边缘分布对齐网络和条件分布对齐网络,所述边缘分布对齐网络包括一个特征提取器和一个特征判别器,所述条件分布对齐网络包括一个分类器和一个类别判别器,所述特征提取器用于提取输入样本的特征,所述特征判别器用于判断所述特征提取器提取的特征来自于哪一个数据集,所述分类器用于对所述特征提取器提取的特征进行分类以产生相应的类别标签,所述类别判别器用于判断所述分类器产生的类别标签来自于哪一个数据集;
(S3)根据所计算的损失更新所述联合分布对齐模型后,利用所述特征提取器和所述分类器为所述目标域数据集中的样本标注伪标签,并基于该标注结果更新所述目标域数据集中各样本的权重;
(S4)重复执行步骤(S2)~(S3)以对所述联合分布对齐模型进行迭代训练,直至达到预设的迭代终止条件;迭代终止后,利用所述特征提取器和所述分类器相连构成对所述目标分类任务进行分类的分类模型。
4.如权利要求3所述的基于动态联合分布对齐的分类模型建立方法,其特征在于,所述条件分布距离和所述边缘分布距离均通过A-distance距离度量。
5.如权利要求4所述的基于动态联合分布对齐的分类模型建立方法,其特征在于,计算所述条件分布距离时,以所述类别判别器作为用于计算A-distance距离的线性分类器;计算所述边缘分布距离时,以所述特征判别器作为用于计算A-distance距离的线性分类器。
6.如权利要求2所述的基于动态联合分布对齐的分类模型建立方法,其特征在于,在迭代训练过程中,生成部分损失Gen_loss中LFea的权重α动态递减。
其中,ws(x,y)表示所述源域数据集中样本(x,y)的权重,Maxs表示所述源域数据集中含样本最多的类别所包含的样本数量,Num(y)表示所述源域数据集中标签为y的样本数量;wt(x)表示所述目标域数据集中样本x的权重,Maxt表示标注伪标签后所述目标域数据集中含样本数量最多的类别所包含的样本数量,Num(y′)表示标注伪标签后所述目标域数据集中伪标签为y′的样本数量。
8.一种图像分类方法,其特征在于,包括:
将图像分类任务中待分类的图像数据输入权利要求1-7任一项所述的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由所述分类模型输出分类结果。
9.一种文本分类方法,其特征在于,包括:
将文本分类任务中待分类的文本数据输入权利要求1-7任一项所述的基于动态联合分布对齐的分类模型建立方法所建立的分类模型,以由所述分类模型输出分类结果。
10.一种计算机可读存储介质,其特征在于,包括存储的计算机程序;所述计算机程序被处理器执行时,控制所述计算机可读存储介质所在设备执行权利要求1-7任一项所述的基于动态联合分布对齐的分类模型建立方法,和/或权利要求8所述的图像分类方法,和/或权利要求9所述的文本分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110128228.9A CN112836739B (zh) | 2021-01-29 | 2021-01-29 | 基于动态联合分布对齐的分类模型建立方法及其应用 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110128228.9A CN112836739B (zh) | 2021-01-29 | 2021-01-29 | 基于动态联合分布对齐的分类模型建立方法及其应用 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112836739A true CN112836739A (zh) | 2021-05-25 |
CN112836739B CN112836739B (zh) | 2024-02-09 |
Family
ID=75932403
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110128228.9A Active CN112836739B (zh) | 2021-01-29 | 2021-01-29 | 基于动态联合分布对齐的分类模型建立方法及其应用 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112836739B (zh) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113268833A (zh) * | 2021-06-07 | 2021-08-17 | 重庆大学 | 一种基于深度联合分布对齐的迁移故障诊断方法 |
CN113469273A (zh) * | 2021-07-20 | 2021-10-01 | 南京信息工程大学 | 基于双向生成及中间域对齐的无监督域适应图像分类方法 |
CN113537403A (zh) * | 2021-08-14 | 2021-10-22 | 北京达佳互联信息技术有限公司 | 图像处理模型的训练方法和装置及预测方法和装置 |
CN113688867A (zh) * | 2021-07-20 | 2021-11-23 | 广东工业大学 | 一种跨域图像分类方法 |
CN114329003A (zh) * | 2021-12-27 | 2022-04-12 | 北京达佳互联信息技术有限公司 | 媒体资源数据处理方法、装置、电子设备及存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109753992A (zh) * | 2018-12-10 | 2019-05-14 | 南京师范大学 | 基于条件生成对抗网络的无监督域适应图像分类方法 |
WO2019204547A1 (en) * | 2018-04-18 | 2019-10-24 | Maneesh Kumar Singh | Systems and methods for automatic speech recognition using domain adaptation techniques |
CN111160462A (zh) * | 2019-12-30 | 2020-05-15 | 浙江大学 | 一种基于多传感器数据对齐的无监督个性化人类活动识别方法 |
CN112232241A (zh) * | 2020-10-22 | 2021-01-15 | 华中科技大学 | 一种行人重识别方法、装置、电子设备和可读存储介质 |
-
2021
- 2021-01-29 CN CN202110128228.9A patent/CN112836739B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019204547A1 (en) * | 2018-04-18 | 2019-10-24 | Maneesh Kumar Singh | Systems and methods for automatic speech recognition using domain adaptation techniques |
CN109753992A (zh) * | 2018-12-10 | 2019-05-14 | 南京师范大学 | 基于条件生成对抗网络的无监督域适应图像分类方法 |
CN111160462A (zh) * | 2019-12-30 | 2020-05-15 | 浙江大学 | 一种基于多传感器数据对齐的无监督个性化人类活动识别方法 |
CN112232241A (zh) * | 2020-10-22 | 2021-01-15 | 华中科技大学 | 一种行人重识别方法、装置、电子设备和可读存储介质 |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113268833A (zh) * | 2021-06-07 | 2021-08-17 | 重庆大学 | 一种基于深度联合分布对齐的迁移故障诊断方法 |
CN113268833B (zh) * | 2021-06-07 | 2023-07-04 | 重庆大学 | 一种基于深度联合分布对齐的迁移故障诊断方法 |
CN113469273A (zh) * | 2021-07-20 | 2021-10-01 | 南京信息工程大学 | 基于双向生成及中间域对齐的无监督域适应图像分类方法 |
CN113688867A (zh) * | 2021-07-20 | 2021-11-23 | 广东工业大学 | 一种跨域图像分类方法 |
CN113469273B (zh) * | 2021-07-20 | 2023-12-05 | 南京信息工程大学 | 基于双向生成及中间域对齐的无监督域适应图像分类方法 |
CN113537403A (zh) * | 2021-08-14 | 2021-10-22 | 北京达佳互联信息技术有限公司 | 图像处理模型的训练方法和装置及预测方法和装置 |
CN114329003A (zh) * | 2021-12-27 | 2022-04-12 | 北京达佳互联信息技术有限公司 | 媒体资源数据处理方法、装置、电子设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN112836739B (zh) | 2024-02-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112836739B (zh) | 基于动态联合分布对齐的分类模型建立方法及其应用 | |
CN110580501B (zh) | 一种基于变分自编码对抗网络的零样本图像分类方法 | |
CN110866536B (zh) | 一种基于PU learning的跨区域企业偷漏税识别方法 | |
CN111476294A (zh) | 一种基于生成对抗网络的零样本图像识别方法及系统 | |
CN110502277B (zh) | 一种基于bp神经网络的代码坏味检测方法 | |
CN108537168B (zh) | 基于迁移学习技术的面部表情识别方法 | |
US20230385333A1 (en) | Method and system for building training database using automatic anomaly detection and automatic labeling technology | |
CN112069310A (zh) | 基于主动学习策略的文本分类方法及系统 | |
CN116644755B (zh) | 基于多任务学习的少样本命名实体识别方法、装置及介质 | |
CN111444342A (zh) | 一种基于多重弱监督集成的短文本分类方法 | |
CN111325264A (zh) | 一种基于熵的多标签数据分类方法 | |
CN113139664A (zh) | 一种跨模态的迁移学习方法 | |
CN107766895B (zh) | 一种诱导式非负投影半监督数据分类方法及系统 | |
CN109656808A (zh) | 一种基于混合式主动学习策略的软件缺陷预测方法 | |
Lonij et al. | Open-world visual recognition using knowledge graphs | |
CN116704208A (zh) | 基于特征关系的局部可解释方法 | |
CN112199287A (zh) | 基于强化混合专家模型的跨项目软件缺陷预测方法 | |
CN112070127A (zh) | 一种基于智能分析的海量数据样本增量分析方法 | |
Tang et al. | Semi-supervised Contrastive Memory Network for Industrial Process Working Condition Monitoring | |
CN109711456A (zh) | 一种具备鲁棒性的半监督图像聚类方法 | |
Hirunyawanakul et al. | A Novel Heuristic Method for Misclassification Cost Tuning in Imbalanced Data | |
CN117456309B (zh) | 基于中间域引导与度量学习约束的跨域目标识别方法 | |
CN117435916B (zh) | 航片ai解译中的自适应迁移学习方法 | |
Vasudevan et al. | Determination of nuclear position by the arrangement of actin filaments using deep generative networks | |
Rebanowako et al. | Age-Invariant Facial Expression Classification Method Using Deep Learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |