CN117474080A - 一种基于多判别器的对抗迁移学习方法和装置 - Google Patents
一种基于多判别器的对抗迁移学习方法和装置 Download PDFInfo
- Publication number
- CN117474080A CN117474080A CN202311540180.8A CN202311540180A CN117474080A CN 117474080 A CN117474080 A CN 117474080A CN 202311540180 A CN202311540180 A CN 202311540180A CN 117474080 A CN117474080 A CN 117474080A
- Authority
- CN
- China
- Prior art keywords
- domain
- trained
- samples
- discriminators
- inter
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 73
- 238000013508 migration Methods 0.000 title claims abstract description 31
- 230000005012 migration Effects 0.000 title claims abstract description 20
- 238000012549 training Methods 0.000 claims description 65
- 238000004364 calculation method Methods 0.000 claims description 9
- 230000006870 function Effects 0.000 claims description 8
- 238000013526 transfer learning Methods 0.000 claims description 8
- 238000004590 computer program Methods 0.000 claims description 6
- 238000003745 diagnosis Methods 0.000 description 12
- 238000004891 communication Methods 0.000 description 9
- 230000008569 process Effects 0.000 description 8
- 238000013473 artificial intelligence Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 6
- 238000000605 extraction Methods 0.000 description 5
- 238000005457 optimization Methods 0.000 description 4
- 238000012545 processing Methods 0.000 description 4
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000007547 defect Effects 0.000 description 3
- 201000010099 disease Diseases 0.000 description 3
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000013145 classification model Methods 0.000 description 2
- 230000001419 dependent effect Effects 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000002595 magnetic resonance imaging Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000009467 reduction Effects 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 241000607479 Yersinia pestis Species 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000006378 damage Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000009545 invasion Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000005065 mining Methods 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000037380 skin damage Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
- XLYOFNOQVPJJNP-UHFFFAOYSA-N water Substances O XLYOFNOQVPJJNP-UHFFFAOYSA-N 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Biomedical Technology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例涉及一种基于多判别器的对抗迁移学习方法和装置,本方法引入多个域判别器之间对所有待训练样本的总差异度来估计域间损失,量化并减小对待训练样本特征提取的域间差异,不仅具有良好的域迁移决策能力,而且使得域迁移具有良好的稳定性,在不需要依赖于大量完整的数据集,即可高效稳定的训练特征提取器,而且训练好的特征提取器能够对更多具有物理意义的特征进行提取,能够提高标签判别器对分类样本的分类精确度。
Description
技术领域
本申请实施例涉及人工智能技术领域,尤其涉及一种基于多判别器的对抗迁移学习方法和装置。
背景技术
人工智能(AI)的智能化和集成化应用已经在众多行业中获得了广泛的关注。基于数据驱动的智能诊断是一种端到端的方法,通过提取和挖掘大量数据中的信息和特征,可实现自动特征提取和损伤识别,并已被广泛应用于故障诊断分类任务中。相比传统的故障诊断和分类方法,其优点在于,不需要给出所有已知的故障的机理信息,也不需要人工给定分类阈值,通过机器学习和深度学习等人工智能方法自动完成特征提取和分类任务。
目前,利用人工智能方法进行故障诊断和分类的方法依赖大量已知的故障数据,通过参数估计和数值拟合等方法,结合优化算法对大批量数据进行特征提取和分类。通常现有人工智能方法需要大量已知数据进行训练、求解和优化,从而提高故障识别和分类的精度以及避免训练过拟合。但在实际应用中,由于设备很少工作在已知故障的状态下,因此难于获取实际工况下大量的已知数据,这就导致现有人工智能故障诊断方法还是存在效率低、不稳定等缺点。
发明内容
以下是对本文详细描述的主题的概述。本概述并非是为了限制权利要求的保护范围。
本公开实施例的主要目的在于提出一种基于多判别器的对抗迁移学习方法和装置,在不需要依赖于大量完整的数据集,即可高效稳定的训练特征提取器。
第一方面,本公开实施例提出一种基于多判别器的对抗迁移学习方法,所述基于多判别器的对抗迁移学习方法包括:
获取待分类样本;
通过特征提取器提取所述待分类样本的待分类特征,并将所述待分类特征输入至标签判别器,以使所述标签判别器对所述待分类样本进行分类;其中所述特征提取器通过如下方式进行训练:
获取多个待训练样本,所述多个待训练样本包括源域样本和目标域样本;
在当前训练次数下,通过所述特征提取器提取所述多个待训练样本对应的多个深层特征;
将所述多个深层特征分别输入至多个域判别器中,以得到每一个所述域判别器对每一个所述待训练样本的所述深层特征所属域的预测概率;
根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度;
根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失;
根据所述域间损失对所述特征提取器进行优化,使所述特征提取器根据所述多个待训练样本进行下一次训练,直至训练结束。
本申请的一些实施例中,所述根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度,包括:
计算每两个所述域判别器对同一个所述待训练样本的所述预测概率之间的距离值,将所述距离值作为每两个所述域判别器之间的差异值;
采用平均值法对所述多个域判别器中的所有所述距离值进行归一化,得到所述多个域判别器之间对同一个所述待训练样本的总差异度。
本申请的一些实施例中,所述总差异度的计算公式包括:
其中,DV(i)表示所述多个域判别器对第i个待训练样本的总差异度,Cd(m)(x(i))表示第m个域判别器对第i个待训练样本的所述深层特征所属域的预测概率,Cd(n)(x(i))表示第n个域判别器对第i个待训练样本的所述深层特征所属域的预测概率,x(i)表示第i个待训练样本,N表示域判别器的总数。
本申请的一些实施例中,根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失,包括:
利用信息熵公式将所述多个判别器对所述多个待训练样本的对应的多个所述总差异度转换成不确定度;
将所述不确定度进行梯度反转,得到所述域间损失。
本申请的一些实施例中,所述域间损失的计算公式包括:
其中,DLoss表示域间损失,DV(i)表示所述多个域判别器对第i个待训练样本的总差异度,K表示待训练样本的总数,表示不确定度,(.)gradient inverse表示梯度反转函数,log表示对数函数。
本申请的一些实施例中,判断所述特征提取器训练完成的方式,包括:
设置训练的结束次数和期望域间损失;
如果当前训练次数未达到所述训练的结束次数,根据所述域间损失对所述特征提取器进行优化,以使所述特征提取器对所述多个待训练样本进行下一次训练;
如果当前训练次数达到所述训练的结束次数,且所述域间损失大于或等于所述期望域间损失,根据所述域间损失对所述特征提取器进行优化,以使所述特征提取器对所述多个待训练样本进行下一次训练;
如果当前训练次数达到所述训练的结束次数,且所述域间损失小于所述期望域间损失,训练完成。
本申请的一些实施例中,在当前训练次数下,所述将所述多个深层特征分别输入至多个域判别器之后,所述基于多判别器的对抗迁移学习方法,还包括:
将所述多个深层特征输入至所述标签判别器中,以得到所述标签判别器输出的预测标签;
根据所述预测标签和预设真实标签,计算交叉熵损失;
所述根据所述域间损失对所述特征提取器进行优化,包括:
根据所述交叉熵损失和所述域间损失对所述特征提取器进行优化。
第二方面,本公开实施例提出一种基于多判别器的对抗迁移学习装置,所述基于多判别器的对抗迁移学习装置包括:
数据获取单元,用于获取待分类样本;
数据分类单元,用以通过特征提取器提取所述待分类样本对应的待分类特征,并将所述待分类特征输入至标签判别器,以使所述标签判别器对所述待分类样本进行分类;其中所述特征提取器通过如下方式进行训练:
获取多个待训练样本,所述多个待训练样本包括源域样本和目标域样本;
在当前训练次数下,通过所述特征提取器提取所述多个待训练样本对应的多个深层特征;将所述多个深层特征分别输入至多个域判别器中,以得到每一个所述域判别器对每一个所述待训练样本的所述深层特征所属域的预测概率;根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度;根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失;
根据所述域间损失对所述特征提取器进行优化,使所述特征提取器根据所述多个待训练样本进行下一次训练,直至训练结束。
第三方面,本公开实施例提出一种电子设备,包括至少一个存储器;
至少一个处理器;
至少一个计算机程序;
所述计算机程序被存储在所述存储器中,处理器执行所述至少一个计算机程序以实现:
如第一方面实施例任一项所述的基于多判别器的对抗迁移学习方法。
第四方面,本公开实施例提出一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行:
如第一方面实施例任一项所述的基于多判别器的对抗迁移学习方法。
本申请的一些实施例提供了一种基于多判别器的对抗迁移学习方法,本方法在特征提取器的每一次训练过程中,先获得每一个域判别器对每一个待训练样本的深层特征所属域的预测概率,其次根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,然后根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度,最后根据多个域判别器对多个待训练样本对应的多个差异度计算域间损失,根据域间损失对特征提取器进行优化;本方法引入多个域判别器之间对所有待训练样本的总差异度来估计域间损失,量化并减小对待训练样本特征提取的域间差异,不仅具有良好的域迁移决策能力,而且使得域迁移具有良好的稳定性,在不需要依赖于大量完整的数据集,即可高效稳定的训练特征提取器,而且训练好的特征提取器能够对更多具有物理意义的特征进行提取,能够提高标签判别器对分类样本的分类精确度。
可以理解的是,上述第二方面至第四方面与相关技术相比存在的有益效果与上述第一方面与相关技术相比存在的有益效果相同,可以参见上述第一方面中的相关描述,在此不再赘述。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请实施例的一些实施例,对于本领域普通技术人员来说,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个实施提供的一种分类模型的分类流程示意图;
图2是本申请一个实施例提供的一种基于多判别器的对抗迁移学习方法的流程示意图;
图3是指图2中步骤S240中计算总差异度的流程示意图;
图4是指图2中步骤S250中计算域间损失的流程示意图;
图5是本申请一个实施例提供的训练特征提取器的流程示意图;
图6是本申请一个实施例提供的训练特征提取器的架构图;
图7是本申请一个实施例提供的源域数据和目标域数据的时域对比图;
图8是本申请一个实施例提供的当域判别器数量不同时,源域和目标域特征t-SNE降维对比图;
图9是本申请一个实施例提供的当域判别器数目不同时,域聚类度DCD数值对比图;
图10是本申请一个实施例提供的一种基于多判别器的对抗迁移学习装置的结构图;
图11是本申请一个实施例提供的一种电子设备的结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本申请,并不用于限定本申请。
需要说明的是,虽然在装置示意图中进行了功能模块划分,在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于装置中的模块划分,或流程图中的顺序执行所示出或描述的步骤。说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
参照图1和图2,本申请的一个实施例,提供了一种基于多判别器的对抗迁移学习方法,基于多判别器的对抗迁移学习方法包括如下步骤:
步骤S100、获取待分类样本。
步骤S200、通过特征提取器提取待分类样本的待分类特征,并将待分类特征输入至标签判别器,以使标签判别器对待分类样本进行分类。其中特征提取器通过如下方式进行训练:
步骤S210、获取多个待训练样本,多个待训练样本包括源域样本和目标域样本。
步骤S220、在当前训练次数下,通过特征提取器提取多个待训练样本对应的多个深层特征。
步骤S230、将多个深层特征分别输入至多个域判别器中,以得到每一个域判别器对每一个待训练样本的深层特征所属域的预测概率。
步骤S240、根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度。
步骤S250、根据多个域判别器对多个待训练样本对应的多个总差异度计算域间损失。
步骤S260、根据域间损失对特征提取器进行优化,使特征提取器根据多个待训练样本进行下一次训练,直至训练结束。
源域数据是指不同数据集的同类型数据;目标域数据是指需要进行分类的未知的同类型数据。本实施例提供了一种用于分类的模型,通过迁移学习结合源域数据和目标域数据,对模型进行训练,使得模型可以有效的对目标域数据完成分类。
在实施例中,提供一种用于进行故障诊断(或者分类)的模型,其中故障的类别通过事先指定,此处不进行限定,用于分类的模型包括:
特征提取器,用于提取待分类样本的深层特征;
标签判别器,用于输出对待分类样本的分类结果;
多个域判别器,用于输出待训练样本的深层特征所属域的预测概率,预测概率用于后续生成域间损失,以基于损失函数对特征提取器进行优化。
在本实施例中,首先采用多个待训练样本,样本可以图像数据(如RGB图像、红外、MRI(磁共振成像)图像等)、文本数据(故障诊断文本等)等。待训练样本包括源域样本和目标域样本,在数量上,选择数量较多的完整的源域样本和少量不完整的目标域样本。
在当前训练次数下,通过特征提取器提取多个待训练样本对应的多个深层特征,假设有K个待训练样本,将K个待训练样本输入至特征提取器中,得到K个深层特征,当前训练次数是指目前正在训练轮数,例如设置M次训练,当超出M次,结束训练,当前训练次数是从第1次至M次中的一次训练流程。这里对特征提取器的类别不进行限定。
步骤S220和230中,将多个深层特征分别输入至多个域判别器中,以得到每一个域判别器对每一个待训练样本的深层特征所属域的预测概率。域判别器的数量与待输入样本的数量不相关,假设有K个深层特征,M个域判别器,将K个深层特征分别输入第1个域判别器、第2个域判别器、...和第M个域判别器中,分别得到M个域判别器输出的M个预测概率。通过设置多个判别器,可以减小分类器的偶然性,提高对单一数据特征学习的同一性。
步骤S240中,根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,例如有M个域判别器和K个待训练样本,首先针对第1个待训练样本,基于第1个域判别器和第2个域判别器分别对第1个待训练样本的概率计算第1个域判别器和第2个域判别器之间的差异度,基于第1个域判别器和第3个域判别器分别对第1个待训练样本的概率计算第1个域判别器和第3个域判别器之间的差异度,...,以此类推,M个域判别器中每两个域判别器对K个待训练样本中的每一个待训练样本的差异度。
步骤S240中,再根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度,例如对于第1个待训练样本,有X个差异度,然后基于X个差异度计算对该待训练样本的总差异度。
在步骤S250中,根据多个域判别器对多个待训练样本对应的多个总差异度计算域间损失。例如步骤S240对于每一个待训练样本,所有域判别器有一个总差异度,对于N个待训练样本,有N个总差异度,步骤250根据N各总差异度计算域间损失,引入多个域判别器之间对所有待训练样本的总差异度来估计域间损失,这样的话,能量化并减小对待训练样本特征提取的域间差异。
在步骤S260中,根据域间损失对特征提取器进行优化,使特征提取器根据多个待训练样本进行下一次训练,直至训练结束。
本方法在特征提取器的每一次训练过程中,先获得每一个域判别器对每一个待训练样本的深层特征所属域的预测概率,其次根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,然后根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度,最后根据多个域判别器对多个待训练样本对应的多个差异度计算域间损失,根据域间损失对特征提取器进行优化;本方法引入多个域判别器之间对所有待训练样本的总差异度来估计域间损失,量化并减小对待训练样本特征提取的域间差异,不仅具有良好的域迁移决策能力,而且使得域迁移具有良好的稳定性,在不需要依赖于大量完整的数据集,即可高效稳定的训练特征提取器,而且训练好的特征提取器能够对更多具有物理意义的特征进行提取,能够提高标签判别器对分类样本的分类精确度。本方法避免了一般人工智能方法的过度依赖大量完备数据集的问题,更适用于工业应用中的故障诊断,得到的多域对抗迁移学习,为建立稳定的少样本、缺类别的故障诊断方法提供了思路和方法。
如图3,在本申请的一些实施例中,根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度,包括:
步骤241、计算每两个域判别器对同一个待训练样本的预测概率之间的距离值,将距离值作为每两个域判别器之间的差异值。
步骤242、采用平均值法对多个域判别器中的所有距离值进行归一化,得到多个域判别器之间对同一个待训练样本的总差异度。
本方法采用计算每两个域判别器对同一个待训练样本的预测概率之间的距离值,将距离值作为每两个域判别器之间的差异值,然后采用平均值法对多个域判别器中的所有距离值进行归一化,得到多个域判别器之间对同一个待训练样本的总差异度。通过引入距离值作为每两个域判别器对同一个待训练样本的预测概率之间差异度,能够准确度量两个域判别器的预测概率之间的差异程度,提高差异度表示的精确度。
在本申请的一些实施例中,总差异度的计算公式包括:
其中,DV(i)表示多个域判别器对第i个待训练样本的总差异度,Cd(m)(x(i))表示第m个域判别器对第i个待训练样本的深层特征所属域的预测概率,Cd(n)(x(i))表示第n个域判别器对第i个待训练样本的深层特征所属域的预测概率,N表示多个判别器的数量。
参照图4,在本申请的一些实施例中,步骤S250中的域间损失的获得方式包括:
步骤S251、利用信息熵公式将多个判别器对多个待训练样本的对应的多个总差异度转换成不确定度。
步骤S252、将不确定度进行梯度反转,得到域间损失。
由于上述步骤在计算总差异度时,对数据归一化到0~1之间,这里采用信息熵对数函数形式进行映射,更有利于特征提取器的优化。
在本申请的一个实施例中,域间损失的计算公式包括:
其中,DLoss表示域间损失,DV(i)表示多个域判别器对第i个待训练样本的总差异度,K表示待训练样本的总数,表示不确定度,(.)gradient inverse表示梯度反转函数,log表示对数函数。
参照图5和图6,在本申请的一个例子中,提供了一种用于故障分类的模型的训练过程,包括如下的流程:
参数设置:设置训练的结束次数epoch,当前训练次数为t,t=0,期望域间损失ELoss,接近于0。
步骤S310、将源域和目标域样本输入特征提取器,提取出样本的深层特征,记为x(i),i=1,2,3…K,表示提取出的第i个样本的深层特征,样本总数为K。
源域(source domain)表示与测试样本不同的领域,但是有丰富的监督信息。
目标域(target domain)表示测试样本所在的领域,无标签或者只有少量标签。源域和目标域往往属于同一类任务,但是分布不同。
步骤S320、计算出每个域判别器对第i个样本的深层特征的所属域的预测概率,记为Cd(m)(x(i)),i=1,2,3…K;Cd(m)(·)表示第m个域判别器,m=1,2,3…N,共N个判别器。每一个域判别器为一层全连接层。
步骤S330、利用步骤S320中所得到的每个域判别器对同一个样本所属域的预测概率,通过如下公式计算出域判别器之间预测概率的差异度,并用平均值法来进行归一化,记为DV(i),m和n表示域判别器的编号。通过计算N个域判别器两两预测的样本所属域的概率之间的平均距离之和,来表示域判别器之间的总差异度。
步骤S340、通过信息熵公式将所有样本的总差异度转化为更直观的不确定度,并进行梯度反转,记为域间损失DLoss,此步骤由如下公式表示,(·)gradient inverse表示梯度反转运算,可通过取反或者取倒数来实现,K表示样本的总个数。
步骤S350、对同一批样本进行迭代优化,重复步骤S310至步骤S340,直至当前训练次数达到epoch次,使得域间损失DLoss下降并稳定在接近于0的数值,表示特征提取器对源域和目标域样本的深层次特征都可以有效识别,即完成对抗迁移学习。
步骤S360、在训练完成之后,在测试样本验证对抗迁移学习的效果。
步骤S370、利用分类的模型进行待分类数据的分类时,由特征提取器提取深层特征,将深层特征输入至标签判别器中,由标签判别器对待分类数据进行分类,标签判别器为一层全连接层。其中标签判别器在模型的训练过程中,也进行反向传播,即在训练时,深层特征也将输入至标签判别器,进而得到输出的预测标签,根据预测标签和预设真实标签,计算交叉熵损失,进而根据交叉熵损失对特征提取器进行优化。
通过上述步骤S310至S360训练得到的分类模型,可以作用于疾病辅助诊断(例如利用图像进行人体皮肤损伤的辨别)、植物病虫害识别(基于可见光图像和红外进行病叶的分类)、桥梁,建筑物,道路故障识别等众多基于图像对故障进行分类的应用场景。例如上述方法可以作用于桥梁,建筑物,道路的故障识别,首选选取对应的图像,标注相应的故障类别(裂缝、坑洼、侵水等),然后进行上述步骤310-360的训练过程,最后把模型应用于实际,进行分类,能够识别桥梁,建筑物,道路上的裂缝等缺陷。
本方法通过在计算域间损失的过程中,引入多个域判别器,分别将多个域判别器对同一样本特征的所属域概率预测进行估计,记为Cd(1)(x(i))~Cd(m)(x(i)),计算m个域判别器的概率估计后,进行差异度计算,并对所有样本通过信息熵公式和均值归一化处理,在梯度反转后作为域间损失DLoss对神经网络进行迭代优化,使得特征提取器可以提取源域和目标域数据中有意义的特征作为分类的一句。由于本方法通过引入多个域判别器之间的差异度来估计域间损失,量化并减小对样本特征提取的域间差异,并具有良好的决策能力,因此,使得域迁移具有良好的稳定性。而现有的域迁移方法构造的域间损失决策能力不足,缺乏稳定性。并且,本方法可以在人工智能方法上,对更多具有物理意义的特征进行提取,完成对少样本甚至缺类别的源域数据的诊断和分类。而现有的人工智能诊断方法,过于依赖于大量完整的数据集,现有的迁移学习方法存在效率低、不稳定等缺点。
参照图7至图9,本申请还提供一组实验数据:
图7示出了源域数据和目标域数据的时域对比图;图7中横坐标表示时间(微秒),纵坐标表示振幅。
下表1示出了当域判别器数量不同时,标签判别器在目标域数据上的分类准确率;
表1
图8示出了当域判别器数量不同时,源域和目标域特征t-SNE(降维法)降维对比图。
图9示出了当域判别器数目不同时,域聚类度DCD数值对比图,其中DCD值越大表明域对抗迁移效果越好。
参照图10,本申请的一个实施例,提供一种基于多判别器的对抗迁移学习装置,所述基于多判别器的对抗迁移学习装置包括:
数据获取单元1100用于获取待分类样本;
数据分类单元1200用以通过特征提取器提取待分类样本对应的待分类特征,并将待分类特征输入至标签判别器,以使标签判别器对待分类样本进行分类;其中特征提取器通过如下方式进行训练:
获取多个待训练样本,多个待训练样本包括源域样本和目标域样本;
在当前训练次数下,通过特征提取器提取多个待训练样本对应的多个深层特征;将多个深层特征分别输入至多个域判别器中,以得到每一个域判别器对每一个待训练样本的深层特征所属域的预测概率;根据每两个域判别器对同一个待训练样本的预测概率计算每两个域判别器之间的差异度,根据每两个域判别器之间的差异度计算多个域判别器之间对同一个待训练样本的总差异度;根据多个域判别器对多个待训练样本对应的多个总差异度计算域间损失;
根据域间损失对特征提取器进行优化,使特征提取器根据多个待训练样本进行下一次训练,直至训练结束。
需要注意的是,本申请提供的装置实施例与上述的方法实施例是基于相同的发发明构思,因此上述方法实施例的内容同样适用于本装置实施例,此处不再赘述。
本申请实施例还提供了一种电子设备,本电子设备包括:
至少一个存储器;
至少一个处理器;
至少一个程序;
程序被存储在存储器中,处理器执行至少一个程序以实现本公开实施上述的基于多判别器的对抗迁移学习方法。
该电子设备可以为包括手机、平板电脑、个人数字助理(Personal DigitalAssistant,PDA)、车载电脑等任意智能终端。
下面结合图11对本申请实施例的电子设备进行详细介绍。电子设备包括:
处理器1600,可以采用通用的中央处理器(Central Processing Unit,CPU)、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本公开实施例所提供的技术方案;
存储器1700,可以采用只读存储器(Read Only Memory,ROM)、静态存储设备、动态存储设备或者随机存取存储器(Random Access Memory,RAM)等形式实现。存储器1700可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1700中,并由处理器1600来调用执行本公开实施例的基于多判别器的对抗迁移学习方法。
输入/输出接口1800,用于实现信息输入及输出;
通信接口1900,用于实现本设备与其他设备的通信交互,可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信;
总线2000,在设备的各个组件(例如处理器1600、存储器1700、输入/输出接口1800和通信接口1900)之间传输信息;
其中处理器1600、存储器1700、输入/输出接口1800和通信接口1900通过总线2000实现彼此之间在设备内部的通信链接。
本公开实施例还提供了一种存储介质,该存储介质是计算机可读存储介质,该计算机可读存储介质存储有计算机可执行指令,该计算机可执行指令用于使计算机执行上述基于多判别器的对抗迁移学习方法。
存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序以及非暂态性计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络链接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
本公开实施例描述的实施例是为了更加清楚的说明本公开实施例的技术方案,并不构成对于本公开实施例提供的技术方案的限定,本领域技术人员可知,随着技术的演变和新应用场景的出现,本公开实施例提供的技术方案对于类似的技术问题,同样适用。
本领域技术人员可以理解的是,图中示出的技术方案并不构成对本公开实施例的限定,可以包括比图示更多或更少的步骤,或者组合某些步骤,或者不同的步骤。
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统、设备中的功能模块/单元可以被实施为软件、固件、硬件及其适当的组合。
本申请的说明书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
应当理解,在本申请中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:只存在A,只存在B以及同时存在A和B三种情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”,其中a,b,c可以是单个,也可以是多个。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信链接可以是通过一些接口,装置或单元的间接耦合或通信链接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括多指令用以使得一个电子设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序的介质。
以上是对本申请实施例的较佳实施进行了具体说明,但本申请实施例并不局限于上述实施方式,熟悉本领域的技术人员在不违背本申请实施例精神的前提下还可作出种种的等同变形或替换,这些等同的变形或替换均包含在本申请实施例权利要求所限定的范围内。
Claims (10)
1.一种基于多判别器的对抗迁移学习方法,其特征在于,所述基于多判别器的对抗迁移学习方法,包括:
获取待分类样本;
通过特征提取器提取所述待分类样本的待分类特征,并将所述待分类特征输入至标签判别器,以使所述标签判别器对所述待分类样本进行分类;其中所述特征提取器通过如下方式进行训练:
获取多个待训练样本,所述多个待训练样本包括源域样本和目标域样本;
在当前训练次数下,通过所述特征提取器提取所述多个待训练样本对应的多个深层特征;
将所述多个深层特征分别输入至多个域判别器中,以得到每一个所述域判别器对每一个所述待训练样本的所述深层特征所属域的预测概率;
根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度;
根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失;
根据所述域间损失对所述特征提取器进行优化,使所述特征提取器根据所述多个待训练样本进行下一次训练,直至训练结束。
2.根据权利要求1所述的基于多判别器的对抗迁移学习方法,其特征在于,所述根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度,包括:
计算每两个所述域判别器对同一个所述待训练样本的所述预测概率之间的距离值,将所述距离值作为每两个所述域判别器之间的差异值;
采用平均值法对所述多个域判别器中的所有所述距离值进行归一化,得到所述多个域判别器之间对同一个所述待训练样本的总差异度。
3.根据权利要求2所述的基于多判别器的对抗迁移学习方法,其特征在于,所述总差异度的计算公式包括:
其中,DV(i)表示所述多个域判别器对第i个待训练样本的总差异度,Cd(m)(x(i))表示第m个域判别器对第i个待训练样本的所述深层特征所属域的预测概率,Cd(n)(x(i))表示第n个域判别器对第i个待训练样本的所述深层特征所属域的预测概率,x(i)表示第i个待训练样本,N表示域判别器的总数。
4.根据权利要求2所述的基于多判别器的对抗迁移学习方法,其特征在于,根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失,包括:
利用信息熵公式将所述多个判别器对所述多个待训练样本的对应的多个所述总差异度转换成不确定度;
将所述不确定度进行梯度反转,得到所述域间损失。
5.根据权利要求4所述的基于多判别器的对抗迁移学习方法,其特征在于,所述域间损失的计算公式包括:
其中,DLoss表示域间损失,DV(i)表示所述多个域判别器对第i个待训练样本的总差异度,K表示待训练样本的总数,表示不确定度,(.)gradient inverse表示梯度反转函数,log表示对数函数。
6.根据权利要求1所述的基于多判别器的对抗迁移学习方法,其特征在于,判断所述特征提取器训练完成的方式,包括:
设置训练的结束次数和期望域间损失;
如果当前训练次数未达到所述训练的结束次数,根据所述域间损失对所述特征提取器进行优化,以使所述特征提取器对所述多个待训练样本进行下一次训练;
如果当前训练次数达到所述训练的结束次数,且所述域间损失大于或等于所述期望域间损失,根据所述域间损失对所述特征提取器进行优化,以使所述特征提取器对所述多个待训练样本进行下一次训练;
如果当前训练次数达到所述训练的结束次数,且所述域间损失小于所述期望域间损失,训练完成。
7.根据权利要求1至6任一项所述的基于多判别器的对抗迁移学习方法,其特征在于,在当前训练次数下,所述将所述多个深层特征分别输入至多个域判别器之后,所述基于多判别器的对抗迁移学习方法,还包括:
将所述多个深层特征输入至所述标签判别器中,以得到所述标签判别器输出的预测标签;
根据所述预测标签和预设真实标签,计算交叉熵损失;
所述根据所述域间损失对所述特征提取器进行优化,包括:
根据所述交叉熵损失和所述域间损失对所述特征提取器进行优化。
8.一种基于多判别器的对抗迁移学习装置,其特征在于,所述基于多判别器的对抗迁移学习装置包括:
数据获取单元,用于获取待分类样本;
数据分类单元,用以通过特征提取器提取所述待分类样本的待分类特征,并将所述待分类特征输入至标签判别器,以使所述标签判别器对所述待分类样本进行分类;其中所述特征提取器通过如下方式进行训练:
获取多个待训练样本,所述多个待训练样本包括源域样本和目标域样本;
在当前训练次数下,通过所述特征提取器提取所述多个待训练样本对应的多个深层特征;将所述多个深层特征分别输入至多个域判别器中,以得到每一个所述域判别器对每一个所述待训练样本的所述深层特征所属域的预测概率;根据每两个所述域判别器对同一个所述待训练样本的所述预测概率计算每两个所述域判别器之间的差异度,根据每两个所述域判别器之间的所述差异度计算所述多个域判别器之间对同一个所述待训练样本的总差异度;根据所述多个域判别器对所述多个待训练样本对应的多个所述总差异度计算域间损失;
根据所述域间损失对所述特征提取器进行优化,使所述特征提取器根据所述多个待训练样本进行下一次训练,直至训练结束。
9.一种电子设备,其特征在于,包括:
至少一个存储器;
至少一个处理器;
至少一个计算机程序;
所述计算机程序被存储在所述存储器中,处理器执行所述至少一个计算机程序以实现:
如权利要求1至7任一项所述基于多判别器的对抗迁移学习方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行:
如权利要求1至7任一项所述基于多判别器的对抗迁移学习方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311540180.8A CN117474080A (zh) | 2023-11-17 | 2023-11-17 | 一种基于多判别器的对抗迁移学习方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311540180.8A CN117474080A (zh) | 2023-11-17 | 2023-11-17 | 一种基于多判别器的对抗迁移学习方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117474080A true CN117474080A (zh) | 2024-01-30 |
Family
ID=89627408
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311540180.8A Pending CN117474080A (zh) | 2023-11-17 | 2023-11-17 | 一种基于多判别器的对抗迁移学习方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117474080A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117976198A (zh) * | 2024-03-28 | 2024-05-03 | 神州医疗科技股份有限公司 | 基于数据筛选和对抗网络的医学跨域辅助诊断方法及装置 |
-
2023
- 2023-11-17 CN CN202311540180.8A patent/CN117474080A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117976198A (zh) * | 2024-03-28 | 2024-05-03 | 神州医疗科技股份有限公司 | 基于数据筛选和对抗网络的医学跨域辅助诊断方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108171209B (zh) | 一种基于卷积神经网络进行度量学习的人脸年龄估计方法 | |
CN111860573B (zh) | 模型训练方法、图像类别检测方法、装置和电子设备 | |
CN110070029B (zh) | 一种步态识别方法及装置 | |
CN110033026B (zh) | 一种连续小样本图像的目标检测方法、装置及设备 | |
CN110929785B (zh) | 数据分类方法、装置、终端设备及可读存储介质 | |
CN110837846A (zh) | 一种图像识别模型的构建方法、图像识别方法及装置 | |
US20140286527A1 (en) | Systems and methods for accelerated face detection | |
CN117474080A (zh) | 一种基于多判别器的对抗迁移学习方法和装置 | |
CN111582358B (zh) | 户型识别模型的训练方法及装置、户型判重的方法及装置 | |
CN111401339A (zh) | 识别人脸图像中的人的年龄的方法、装置及电子设备 | |
CN114241505A (zh) | 化学结构图像的提取方法、装置、存储介质及电子设备 | |
CN112182269B (zh) | 图像分类模型的训练、图像分类方法、装置、设备及介质 | |
CN112115996B (zh) | 图像数据的处理方法、装置、设备及存储介质 | |
CN111915595A (zh) | 图像质量评价方法、图像质量评价模型的训练方法和装置 | |
CN117934463A (zh) | 一种基于光学测试的肉牛胴体质量评级方法 | |
CN117849193A (zh) | 钕铁硼烧结的裂纹损伤在线监测方法 | |
CN114328942A (zh) | 关系抽取方法、装置、设备、存储介质和计算机程序产品 | |
CN110349119B (zh) | 基于边缘检测神经网络的路面病害检测方法和装置 | |
CN112348011A (zh) | 一种车辆定损方法、装置及存储介质 | |
CN111967383A (zh) | 年龄估计方法、年龄估计模型的训练方法和装置 | |
CN110704678A (zh) | 评估排序方法、评估排序系统、计算机装置及储存介质 | |
CN113536845B (zh) | 人脸属性识别方法、装置、存储介质和智能设备 | |
CN111898531A (zh) | 卫星通信信号识别方法、装置及电子设备 | |
Bhanumathi et al. | Underwater Fish Species Classification Using Alexnet | |
CN114898186B (zh) | 细粒度图像识别模型训练、图像识别方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |