CN117932314A - 模型训练方法、装置、电子设备、存储介质及程序产品 - Google Patents
模型训练方法、装置、电子设备、存储介质及程序产品 Download PDFInfo
- Publication number
- CN117932314A CN117932314A CN202410344791.3A CN202410344791A CN117932314A CN 117932314 A CN117932314 A CN 117932314A CN 202410344791 A CN202410344791 A CN 202410344791A CN 117932314 A CN117932314 A CN 117932314A
- Authority
- CN
- China
- Prior art keywords
- sample
- network
- feature
- source domain
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 311
- 238000000034 method Methods 0.000 title claims abstract description 114
- 238000012545 processing Methods 0.000 claims abstract description 354
- 238000013473 artificial intelligence Methods 0.000 claims abstract description 101
- 238000000605 extraction Methods 0.000 claims abstract description 83
- 230000007704 transition Effects 0.000 claims abstract description 45
- 230000000875 corresponding effect Effects 0.000 claims description 73
- 230000015654 memory Effects 0.000 claims description 46
- 238000010586 diagram Methods 0.000 claims description 36
- 230000004927 fusion Effects 0.000 claims description 36
- 238000007499 fusion processing Methods 0.000 claims description 32
- 230000002596 correlated effect Effects 0.000 claims description 18
- 238000004590 computer program Methods 0.000 abstract description 14
- 239000011159 matrix material Substances 0.000 description 30
- 238000013145 classification model Methods 0.000 description 29
- 238000005516 engineering process Methods 0.000 description 18
- 230000008569 process Effects 0.000 description 18
- 238000003672 processing method Methods 0.000 description 15
- 230000000007 visual effect Effects 0.000 description 15
- 238000005295 random walk Methods 0.000 description 13
- 238000004364 calculation method Methods 0.000 description 12
- 238000009826 distribution Methods 0.000 description 8
- 238000004891 communication Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 7
- 230000006870 function Effects 0.000 description 7
- 238000012546 transfer Methods 0.000 description 6
- 239000013598 vector Substances 0.000 description 5
- 230000033228 biological regulation Effects 0.000 description 4
- 235000019800 disodium phosphate Nutrition 0.000 description 4
- 230000001965 increasing effect Effects 0.000 description 4
- 230000006978 adaptation Effects 0.000 description 3
- 230000006698 induction Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000011160 research Methods 0.000 description 3
- 238000003491 array Methods 0.000 description 2
- 230000002457 bidirectional effect Effects 0.000 description 2
- 238000012512 characterization method Methods 0.000 description 2
- 239000002131 composite material Substances 0.000 description 2
- 238000013480 data collection Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 238000013515 script Methods 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 238000000638 solvent extraction Methods 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 208000037170 Delayed Emergence from Anesthesia Diseases 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000008485 antagonism Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 239000006185 dispersion Substances 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 230000007774 longterm Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000008447 perception Effects 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 230000035945 sensitivity Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000005309 stochastic process Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/254—Fusion techniques of classification results, e.g. of results related to same input data
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Biophysics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请提供了一种基于人工智能的模型训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品。本申请可以应用于人工智能场景。该方法包括:获取样本集合;通过特征处理网络对样本集合中每个样本进行样本特征提取处理和图特征提取处理,得到每个样本的图结构特征;基于每个样本的图结构特征,确定每个样本对的样本间转移概率;基于每个样本的标签,确定每个样本对的标签损失;基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络。通过本申请提高特征处理网络在目标域上的泛化能力。
Description
技术领域
本申请涉及计算机技术领域,尤其涉及一种基于人工智能的模型训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品。
背景技术
在对业务数据进行分类处理时,常用分类模型实现分类处理,即通过分类模型的特征处理网络提取业务数据的数据特征,通过分类模型的分类网络对数据特征进行分类。在使用源域数据对分类模型进行训练后,分类模型能够在源域样本上进行精确的分类处理。随着时间的推移,业务数据不断增加,其中会出现新增的不同类别的目标域数据,此时需要充分利用这部分无标签的目标域数据对分类模型进行不断更新,使分类模型能够不断适应新增的目标域数据,以保证分类模型在目标域数据上的分类结果依然准确。
然而,由于目标域数据和源域数据的数据集分布存在显著差异,导致使用源域数据训练的分类模型在目标域数据的泛化能力不足,主要是因为分类模型的特征处理网络无法对目标域数据进行精确的特征提取处理。已有技术中使用最小化最大平均差异准则或对抗性训练,来缩小目标域数据和源域数据之间的差异,但已有技术都是从数据样本角度和目标业务和原业务之间的差异角度进行设计,未考虑样本在由源域与目标域组成的整体数据集中的特征信息,导致分类模型的特征处理网络在目标域数据上的泛化能力弱。
发明内容
本申请实施例提供一种基于人工智能的模型训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品,能够提高分类模型的特征处理网络在目标域数据上的泛化能力,并提高对应分类服务的计算资源利用率。
本申请实施例的技术方案是这样实现的:
本申请实施例提供一种基于人工智能的模型训练方法,所述方法包括:
获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;
通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;
基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;
基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;
基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。
本申请实施例提供一种基于人工智能的模型训练方法,所述方法包括:
获取待分类数据;
通过特征处理网络对所述待分类数据进行特征提取处理,得到所述待分类数据的样本特征,并通过所述特征处理网络对所述待分类数据的样本特征进行图特征提取处理,得到所述待分类数据的图结构特征;
其中,所述特征处理网络是通过本申请实施例提供的基于人工智能的模型训练方法得到的;
基于所述待分类数据的图结构特征,对所述待分类数据进行分类处理,得到所述待分类数据的预测标签。
本申请实施例提供一种基于人工智能的模型训练装置,包括:
获取模块,用于获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;
特征处理模块,用于通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;
分类处理模块,用于基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;
损失处理模块,用于基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;
更新处理模块,用于基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。
在上述方案中,更新处理模块,用于基于初始化的样本特征网络获取预训练三联体损失,并基于所述预训练三联体损失对所述初始化的样本特征网络进行更新,得到经过一次更新的样本特征网络;基于初始化的源域分类网络获取预训练分类器损失,并基于所述预训练分类器损失对所述经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络;基于初始化的判别网络获取预训练判别器损失,并基于所述预训练判别器损失对所述经过二次更新的样本特征网络以及所述经过一次更新的图特征网络进行更新,得到用于构成所述特征处理网络的样本特征网络以及用于构成所述特征处理网络的图特征网络;基于所述预训练分类器损失对所述初始化的源域分类网络进行更新,得到源域分类网络,其中,所述源域分类网络用于基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理;基于所述预训练判别器损失对所述初始化的判别网络进行更新,得到判别网络,其中,所述判别网络用于对所述目标域样本以及所述源域样本进行域判别处理。
在上述方案中,更新处理模块,用于通过初始化的样本特征网络对每个所述源域样本进行样本特征提取处理,得到每个所述源域样本的预训练样本特征;针对每个所述源域样本执行以下处理:获取与所述源域样本具有相同真实标签的正面源域样本以及与所述源域样本具有不同真实标签的负面源域样本;获取所述源域样本的预训练样本特征与所述负面源域样本的预训练样本特征之间的第一特征距离,以及所述源域样本的预训练样本特征与所述正面源域样本的预训练样本特征之间的第二特征距离;获取与所述第一特征距离负相关,且与所述第二特征距离正相关的预训练三联体损失。
在上述方案中,更新处理模块,用于通过初始化的图特征网络对所述源域样本的预训练样本特征进行图特征提取处理,得到所述源域样本的预训练图结构特征;通过初始化的源域分类网络对所述源域样本执行基于所述预训练图结构特征的分类处理,得到所述源域样本的预测标签;基于所述源域样本的预测标签与所述源域样本的真实标签之间的差异,确定所述预训练分类器损失。
在上述方案中,更新处理模块,用于通过所述初始化的判别网络对所述源域样本进行基于所述源域样本的预训练图结构特征的域判别处理,得到所述源域样本的预训练域判别结果;通过所述初始化的判别网络对所述目标域样本进行基于所述目标域样本的预训练图结构特征的域判别处理,得到所述目标域样本的预训练域判别结果;对所述源域样本的预训练域判别结果以及所述目标域样本的预训练域判别结果进行融合处理,得到所述预训练判别器损失。
在上述方案中,特征处理模块,用于针对每个所述源域样本,结合所述样本集合中每个所述源域样本的样本特征,对所述源域样本的样本特征进行图特征提取处理,得到所述源域样本的图结构特征;针对每个所述目标域样本,结合所述样本集合中每个所述目标域样本的样本特征,对所述目标域样本的样本特征进行图特征提取处理,得到所述目标域样本的图结构特征。
在上述方案中,特征处理模块,用于获取所述源域样本的样本特征与所述样本集合中每个所述源域样本的样本特征之间的第三特征距离;以所述样本集合中每个所述源域样本的样本特征为权重,将对应所述样本集合中每个所述源域样本的第三特征距离进行融合处理,得到对应所述源域样本的图结构特征。
在上述方案中,分类处理模块,用于对所述样本集合进行两两遍历处理,得到多个样本对;获取每个所述样本对对应的两个图结构特征之间的第四特征距离;将多个所述样本对的第四特征距离进行融合处理,得到第一融合结果;获取与每个所述样本对的第四特征距离正相关,且与所述第一融合结果负相关的数值作为每个所述样本对的样本间转移概率。
在上述方案中,损失处理模块,用于基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定对应每个类别的期望;对所述样本集合进行两两遍历处理,得到多个样本对,其中,所述样本对包括第一样本以及第二样本;针对每个所述样本对执行以下处理:从所述第一样本的真实标签或者预测标签中提取所述第一样本属于每个类别的第一概率,并从所述第二样本的真实标签或者预测标签中提取所述第二样本属于每个类别的第二概率;针对所述类别,获取与对应所述类别的第一概率以及第二概率正相关,且与所述类别的期望负相关的类别损失;对多个所述类别分别对应的类别损失进行融合处理,得到所述样本对的标签损失。
在上述方案中,损失处理模块,用于针对每个所述类别执行以下处理:从每个所述源域样本的真实标签中提取每个所述源域样本属于所述类别的第三概率,并从每个所述目标域样本的预测标签中提取每个所述目标域样本属于所述类别的第四概率;对多个所述第三概率以及多个所述第四概率进行融合处理,得到融合概率,并将所述融合概率与所述类别对应的标签值进行相乘,得到对应所述类别的期望。
在上述方案中,更新处理模块,用于针对每个所述样本对执行以下处理:对所述样本对的样本间转移概率以及所述样本对的标签损失进行相乘处理,得到所述样本对的子差异损失;对多个所述样本对的子差异损失进行融合处理,得到第二融合结果;获取与所述第二融合结果负相关的样本差异损失。
在上述方案中,更新处理模块,用于获取训练三联体损失;基于源域分类网络获取训练分类器损失;基于判别网络获取训练判别器损失;对所述训练三联体损失、所述训练分类器损失、所述训练判别器损失以及所述样本差异损失进行融合处理,得到综合损失;基于所述综合损失对所述特征处理网络进行更新,得到经过更新的特征处理网络;基于所述综合损失对所述判别网络、所述源域分类网络进行更新,得到更新后的判别网络以及更新后的源域分类网络。
本申请实施例提供一种基于人工智能的数据处理装置,包括:
获取模块,用于获取待分类数据;
特征处理模块,用于通过特征处理网络对所述待分类数据进行特征提取处理,得到所述待分类数据的样本特征,并通过所述特征处理网络对所述待分类数据的样本特征进行图特征提取处理,得到所述待分类数据的图结构特征;
其中,所述特征处理网络是通过本申请实施例提供的基于人工智能的模型训练方法得到的;
分类处理模块,用于基于所述待分类数据的图结构特征,对所述待分类数据进行分类处理,得到所述待分类数据的预测标签。
本申请实施例提供一种电子设备,所述电子设备包括:
存储器,用于存储计算机可执行指令;
处理器,用于执行所述存储器中存储的计算机可执行指令时,实现本申请实施例提供的基于人工智能的模型训练方法。
本申请实施例提供一种计算机可读存储介质,存储有计算机可执行指令,用于被处理器执行时实现本申请实施例提供的基于人工智能的模型训练方法。
本申请实施例提供一种计算机程序产品,包括计算机可执行指令,所述计算机可执行指令被处理器执行时,实现本申请实施例提供的基于人工智能的模型训练方法。
本申请实施例具有以下有益效果:
获取来源于不同业务领域的源域样本以及目标域样本,源域样本与目标域样本,通过特征处理网络对每个样本进行样本特征提取处理,得到每个样本的样本特征,以对每个样本的个体特征信息进行表征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征,以对每个样本与相同域中的其他样本之间的关联特征信息进行表征,基于每个样本的图结构特征,确定每个样本对的样本间转移概率,以确定样本对中的两个样本之间的相似程度,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签,以确定目标域样本的分类结果,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失,以表征样本对中两个样本之间的分类结果差异,基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,将样本之间的特征差异融合为样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络,使特征处理网络学习到相同域的样本之间以及不同域的样本之间的特征差异与特征共性,使特征处理网络能够对目标域样本进行精确的特征提取处理,从而提高分类模型的特征处理网络在目标域上的泛化能力,并提高对应分类服务的计算资源利用率。
附图说明
图1是本申请实施例提供的基于人工智能的模型训练系统的架构示意图;
图2A是本申请实施例提供的用于执行基于人工智能的模型训练方法的服务器的结构示意图;
图2B是本申请实施例提供的用于执行基于人工智能的数据处理方法的服务器的结构示意图;
图3A是本申请实施例提供的基于人工智能的模型训练方法的第一流程示意图;
图3B是本申请实施例提供的基于人工智能的模型训练方法的第二流程示意图;
图3C是本申请实施例提供的基于人工智能的模型训练方法的第三流程示意图;
图3D是本申请实施例提供的基于人工智能的模型训练方法的第四流程示意图;
图3E是本申请实施例提供的基于人工智能的模型训练方法的第五流程示意图;
图3F是本申请实施例提供的基于人工智能的模型训练方法的第六流程示意图;
图3G是本申请实施例提供的基于人工智能的模型训练方法的第七流程示意图;
图3H是本申请实施例提供的基于人工智能的模型训练方法的第八流程示意图;
图3I是本申请实施例提供的基于人工智能的模型训练方法的第九流程示意图;
图3J是本申请实施例提供的基于人工智能的模型训练方法的第十流程示意图;
图4是本申请实施例提供的基于人工智能的数据处理方法的流程示意图;
图5是本申请实施例提供的基于人工智能的模型训练方法的框架示意图;
图6是本申请实施例提供的视觉自注意力模型的模型架构示意图;
图7是本申请实施例提供的视觉自注意力模块的子模块组架构示意图;
图8是本申请实施例提供的基于自注意力模型的双向编码器的处理流程示意图。
需要指出,上述的“第一”、“第二”仅用于区分不同的方案,不代表用于区分方案的优劣程度或在实施过程中的优先级。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地详细描述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
在以下的描述中,所涉及的术语“第一\第二\第三”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二\第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
本申请实施例中,术语“模块”或“单元”是指有预定功能的计算机程序或计算机程序的一部分,并与其他相关部分一起工作以实现预定目标,并且可以通过使用软件、硬件(如处理电路或存储器)或其组合来全部或部分实现。同样的,一个处理器(或多个处理器或存储器)可以用来实现一个或多个模块或单元。此外,每个模块或单元都可以是包含该模块或单元功能的整体模块或单元的一部分。
除非另有定义,本申请实施例所使用的所有的技术和科学术语与所属技术领域的技术人员通常理解的含义相同。本申请实施例中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
本申请实施例中相关数据收集处理在实例应用时应该严格根据相关国家法律法规的要求,获取个人信息主体的知情同意或单独同意,并在法律法规及个人信息主体的授权范围内,开展后续数据使用及处理。
对本申请实施例进行进一步详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
1)基于时空局部性归纳偏置的视觉自注意力模型(Video Swin Transformers):一种用于视频理解任务的基于自注意力编码器架构的图像分类模型,通过自注意力机制和局部窗口交换的方式来捕捉图像中的全局和局部信息。基于时空局部性归纳偏置的视觉自注意力模型应用于视频领域,通过对视频序列中的每一帧进行特征提取和建模,实现了对视频内容的理解和分析。
2)马尔可夫随机游走:马尔可夫随机游走是一种随机过程,描述了一个在状态空间中随机移动的过程,其中,每一步的移动只依赖于当前状态,而与之前的状态无关。具体步骤如下:从初始状态分布中选择一个初始状态,根据当前状态的转移概率,随机选择下一个状态,重复上述步骤,直到达到某个终止条件。
在对业务数据进行分类处理时,常用分类模型实现分类处理,即通过分类模型的特征处理网络提取业务数据的数据特征,通过分类模型的分类网络对数据特征进行分类。在使用源域数据对分类模型进行训练后,分类模型能够在源域样本上进行精确的分类处理。随着时间的推移,业务数据不断增加,其中会出现新增的不同类别的目标域数据,此时需要充分利用这部分无标签的目标域数据对分类模型进行不断更新,使分类模型能够不断适应新增的目标域数据,以保证分类模型在目标域数据上的分类结果依然准确。
然而,由于目标域数据和源域数据的数据集分布存在显著差异,这导致两个核心问题:首先,在数据融合后往往难以达到预期效果;其次,一个业务领域训练的模型在另一个业务领域的泛化能力明显不足。因此,如何更有效地整合不同业务方的数据集,以及确保在一个业务领域训练的模型在其他领域依然表现出色,成为当前亟需解决的关键挑战。
目前,领域自适应的研究方案主要集中在采用各种对比损失来缩小目标域样本和源域样本之间的差异。例如,已有技术提供了以下方案:
1)最小化最大平均差异(Maximize Mean Discrepancy,MMD)准则:是一种用于度量两个分布之间差异的方法。在领域自适应中,MMD被广泛用于最小化源领域和目标领域之间的特征分布差异。通过最小化不同领域间的特征空间的分布差异,使得源领域的特征在目标领域上更具有泛化能力。
2)使用生成式对抗网络(Generative Adversarial Networks,GAN):该方法通过设计一个判别器,该判别器负责辨别生成器生成的特征是来自哪个业务领域。通过最大化判别器的输出,使其无法准确地辨别一个样本来自于哪个业务,从而实现两个业务之间的深度融合。
申请人在实施本申请实施例的过程中,发现已有技术存在以下问题:
1、已有技术从数据样本角度和目标业务和原业务之间的差异角度进行设计,未考虑样本在由源域与目标域组成的整体数据集中的特征信息,导致分类模型的特征处理网络在目标域数据上的泛化能力弱。
2、不能处理较大的领域差异:当源域和目标域之间存在较大的领域差异时,无法有效地缩小这些差异,导致模型的数据处理性能下降。
本申请实施例提供一种基于人工智能的模型训练方法、数据处理方法、装置、电子设备、计算机可读存储介质及计算机程序产品,能够提高旧模型在新数据集上的泛化能力,下面说明本申请实施例提供的电子设备的示例性应用,本申请实施例提供的设备可以实施为笔记本电脑、平板电脑,台式计算机、机顶盒、移动设备(例如,移动电话,便携式音乐播放器,个人数字助理,专用消息设备,便携式游戏设备)、智能手机、智能音箱、智能手表、智能电视、车载终端等各种类型的用户终端,也可以实施为服务器。下面,将说明设备实施为或服务器时示例性应用。
参见图1,图1是本申请实施例提供的基于人工智能的模型训练系统100的架构示意图,为实现支撑一个基于人工智能的模型训练应用,终端400通过网络300连接服务器200,网络300可以是广域网或者局域网,又或者是二者的组合。
终端400用于生成模型训练请求,例如用户通过终端400的图形界面410生成模型训练指令,终端400响应于用户的模型训练指令,生成模型训练请求并发送至服务器200,服务器200用于基于模型训练请求,获取样本集合,其中,样本集合包括源域样本以及目标域样本,源域样本与目标域样本来源于不同业务领域,通过特征处理网络对样本集合中每个样本进行样本特征提取处理,得到每个样本的样本特征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征,基于每个样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失,基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络。
在一些实施例中,在服务器200部署经过更新的特征处理网络,终端400生成数据处理请求,服务器200基于数据处理请求,获取待分类数据,待分类数据与目标域样本属于相同业务领域,通过特征处理网络对待分类数据进行特征提取处理,得到待分类数据的样本特征,并通过特征处理网络对待分类数据的样本特征进行图特征提取处理,得到待分类数据的图结构特征,基于待分类数据的图结构特征,对待分类数据进行分类处理,得到待分类数据的预测标签,并将待分类数据的预测标签返回至终端400。
在一些实施例中,在终端400部署经过更新的特征处理网络,终端400生成数据处理请求,并基于数据处理请求,获取待分类数据,待分类数据与目标域样本属于相同业务领域,通过特征处理网络对待分类数据进行特征提取处理,得到待分类数据的样本特征,并通过特征处理网络对待分类数据的样本特征进行图特征提取处理,得到待分类数据的图结构特征,基于待分类数据的图结构特征,对待分类数据进行分类处理,得到待分类数据的预测标签,得到待分类数据的分类结果。
在一些实施例中,服务器200可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。终端400可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表、车载终端等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请实施例中不做限制。
在一些实施例中,服务器包括硬盘、内存以及处理器,硬盘中储存样本集合,当服务器接收到模型训练请求时,将样本集合中的源域样本和目标域样本读取到内存,由处理器在内存中执行本申请实施例提供的基于人工智能的模型训练方法,得到经过更新的特征处理网络,完成一轮模型训练,然后重复上述过程,直到特征处理网络稳定,将经过更新的特征处理网络部署在服务器的硬盘或终端。相关技术从数据样本角度和目标业务和原业务之间的差异角度进行训练,未考虑样本在由源域与目标域组成的整体数据集中的特征信息,则会使特征处理网络无法对目标域样本进行精确的特征提取,需要进行更多轮模型训练,且训练得到的特征处理网络在目标域上的泛化能力弱。而本申请实施例可以通过根据样本的标签损失以及样本间转移概率,构建样本差异损失,使特征处理网络在每一轮模型训练中学习到样本在源域样本和目标域样本组成的样本集合中的特征信息,以使特征处理网络能够尽可能地忽略样本的来源业务域而精确提取其他特征信息,需要进行的训练轮次少且训练得到的特征处理网络在目标域上的泛化能力强,无需长时间占用处理器的计算资源,可以有效提高处理器的计算资源利用率。
在实际应用中,本申请实施例可以应用于新闻业务场景中。以新闻业务为例,样本为新闻样本,新闻样本可以为文本样本、图片样本以及视频样本中至少之一,源域样本可以为来源于原始新闻业务方的原始新闻样本,目标域样本可以为来源于新增新闻业务方的新增新闻样本。由于来自不同新闻业务方的新闻数据的内容形式相似且分类标签也相似,例如,财经标签、人文标签、娱乐标签等,因此可以将原始应用于原始新闻样本分类任务的特征处理网络迁移至新增新闻样本进行应用。将原始新闻样本和新增新闻样本组成样本集合,通过本申请实施例提供的基于人工智能的模型训练方法,对特征处理网络进行模型训练,得到经过更新的特征处理网络,以使经过更新的特征处理网络与源域分类网络能够对新增新闻业务方的新增新闻样本进行分类处理,得到待分类数据的预测标签。通过本申请实施例提供的基于人工智能的模型训练方法得到的特征处理网络在新增新闻业务方的新增新闻样本的泛化能力强,能够对新增新闻业务方的待分类数据进行精确分类,且训练轮次少,无需长时间占用处理器的计算资源,可以有效提高处理器的计算资源利用率。此外,本申请实施例还可以应用于在模型在同一业务方的不同数据域之间进行迁移的场景。以新闻业务场景为例,新闻样本具有实时性,数据更新速度快,如果使用人工标注方法为新增新闻样本进行打标处理,再进行分类,则会耗费大量成本,因此使用原始应用于原始新闻数据分类的特征处理网络对新增新闻数据进行分类,可以节约人力物力。原始新闻样本与新增新闻样本是同一新闻业务方的新闻样本,以原始新闻样本作为源域样本,以新增新闻样本作为目标域样本,以原始新闻样本和新增新闻样本组成样本集合,通过本申请实施例提供的基于人工智能的模型训练方法,对特征处理网络进行模型训练,得到经过更新的特征处理网络,以使经过更新的特征处理网络与源域分类网络能够对新增新闻样本中的待分类数据进行分类处理,得到待分类数据的预测标签。通过本申请实施例提供的基于人工智能的模型训练方法得到的特征处理网络在新增新闻样本上的泛化能力强,能够对新增新闻样本中的待分类数据进行精确分类,且训练轮次少,无需长时间占用处理器的计算资源,可以有效提高处理器的计算资源利用率。
在一些实施例中,本申请实施例还可以应用于在购物业务场景中。随着商品更新迭代,商品种类会越来越多,即出现大量新增商品样本,如果使用人工标注方法为新增商品样本进行打标处理,再进行分类,则会耗费大量成本,因此使用原始应用于原始商品样本分类的特征处理网络对新增商品样本进行分类,可以节约人力物力。源域样本为原始商品样本,目标域样本为新增商品样本。以原始商品样本和新增商品样本组成样本集合,通过本申请实施例提供的基于人工智能的模型训练方法,对特征处理网络进行模型训练,得到经过更新的特征处理网络,以使经过更新的特征处理网络与源域分类网络能够对新增商品样本中的待分类数据进行分类处理,得到待分类数据的预测标签。通过本申请实施例提供的基于人工智能的模型训练方法得到的特征处理网络在新增商品样本领域的泛化能力强,能够对新增商品样本领域的待分类数据进行精确分类,且训练轮次少,无需长时间占用处理器的计算资源,可以有效提高处理器的计算资源利用率。
在一些实施例中,本申请实施例还可以应用于视频业务场景。在视频业务场景中,无论是影视资源还是自媒体视频,视频样本的更新频率高,因此需要使用原始对原有视频样本分类的特征处理网络对新增视频样本进行分类。以原有视频样本为源域样本,以新增视频样本为目标域样本,以原有视频样本和新增新闻样本组成样本集合,通过本申请实施例提供的基于人工智能的模型训练方法,对特征处理网络进行模型训练,得到经过更新的特征处理网络,以使经过更新的特征处理网络与源域分类网络能够对新增视频样本中的待分类数据进行分类处理,得到待分类数据的预测标签。通过本申请实施例提供的基于人工智能的模型训练方法得到的特征处理网络在新增视频样本的泛化能力强,能够对新增视频样本中的待分类数据进行精确分类,且训练轮次少,无需长时间占用处理器的计算资源,可以有效提高处理器的计算资源利用率。
本申请实施例可应用于各种场景,包括但不限于人工智能等场景。人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、预训练模型技术、操作/交互系统、机电一体化等。其中,预训练模型又称大模型、基础模型,经过微调后可以广泛应用于人工智能各大方向下游任务。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如常见的智能家居、智能穿戴设备、虚拟助理、智能音箱、智能营销、无人驾驶、自动驾驶、无人机、数字孪生、虚拟人、机器人、人工智能生成内容(AIGC)、对话式交互、智能医疗、智能客服、游戏AI等,相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
参见图2A,图2A是本申请实施例提供的用于执行基于人工智能的模型训练方法的服务器200-1的结构示意图,图2A所示的服务器200-1包括:至少一个处理器210、存储器230和至少一个网络接口220。服务器200-1中的各个组件通过总线系统240耦合在一起。可理解,总线系统240用于实现这些组件之间的连接通信。总线系统240除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图2A中将各种总线都标为总线系统240。
处理器210可以是一种集成电路芯片,具有信号的处理能力,例如通用处理器、数字信号处理器(Digital Signal Processor,DSP),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
存储器230可以是可移除的,不可移除的或其组合。示例性的硬件设备包括固态存储器,硬盘驱动器,光盘驱动器等。存储器230可选地包括在物理位置上远离处理器210的一个或多个存储设备。
存储器230包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(ROM,Read Only Memory),易失性存储器可以是随机存取存储器(Random Access Memory,RAM)。本申请实施例描述的存储器230旨在包括任意适合类型的存储器。
在一些实施例中,存储器230能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统231,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务;
网络通信模块232,用于经由一个或多个(有线或无线)网络接口220到达其他电子设备,示例性的网络接口220包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(Universal Serial Bus,USB)等;
在一些实施例中,本申请实施例提供的基于人工智能的模型训练装置可以采用软件方式实现,图2A示出了存储在存储器230中的基于人工智能的模型训练装置233,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块2331、特征处理模块2332、分类处理模块2333、损失处理模块2334以及更新处理模块2335,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
在另一些实施例中,本申请实施例提供的基于人工智能的模型训练装置可以采用硬件方式实现,作为示例,本申请实施例提供的基于人工智能的模型训练装置可以是采用硬件译码处理器形式的处理器,其被编程以执行本申请实施例提供的基于人工智能的模型训练方法,例如,硬件译码处理器形式的处理器可以采用一个或多个应用专用集成电路(Application Specific Integrated Circuit,ASIC)、数字信号处理器(Digital SignalProcessor,DSP)、可编程逻辑器件(Programmable Logic Device,PLD)、复杂可编程逻辑器件(Complex Programmable Logic Device,CPLD)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或其他电子元件。
在一些实施例中,终端或服务器可以通过运行各种计算机可执行指令或计算机程序来实现本申请实施例提供的基于人工智能的模型训练方法。举例来说,计算机可执行指令可以是微程序级的命令、机器指令或软件指令。计算机程序可以是操作系统中的原生程序或软件模块;可以是本地(Native)应用程序(APPlication,APP),即需要在操作系统中安装才能运行的程序,如新闻APP或者购物APP;也可以是可以嵌入至任意APP中的小程序,即只需要下载到浏览器环境中就可以运行的程序。总而言之,上述的计算机可执行指令可以是任意形式的指令,上述计算机程序可以是任意形式的应用程序、模块或插件。
参见图2B,图2B是本申请实施例提供的用于执行基于人工智能的数据处理方法的服务器200-2的结构示意图,图2B所示的服务器200-2包括:至少一个处理器250、存储器260和至少一个网络接口270。服务器200-2中的各个组件通过总线系统280耦合在一起。可理解,总线系统280用于实现这些组件之间的连接通信。总线系统280除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图2B中将各种总线都标为总线系统280。
处理器250可以是一种集成电路芯片,具有信号的处理能力,例如通用处理器、数字信号处理器(Digital Signal Processor,DSP),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
存储器260可以是可移除的,不可移除的或其组合。示例性的硬件设备包括固态存储器,硬盘驱动器,光盘驱动器等。存储器260可选地包括在物理位置上远离处理器250的一个或多个存储设备。
存储器260包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(ROM,Read Only Memory),易失性存储器可以是随机存取存储器(Random Access Memory,RAM)。本申请实施例描述的存储器260旨在包括任意适合类型的存储器。
在一些实施例中,存储器260能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统261,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务;
网络通信模块262,用于经由一个或多个(有线或无线)网络接口270到达其他电子设备,示例性的网络接口270包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(Universal Serial Bus,USB)等;
在一些实施例中,本申请实施例提供的基于人工智能的数据处理装置可以采用软件方式实现,图2B示出了存储在存储器260中的基于人工智能的数据处理装置263,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块2631、特征处理模块2632以及分类处理模块2633,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
在另一些实施例中,本申请实施例提供的基于人工智能的数据处理装置可以采用硬件方式实现,作为示例,本申请实施例提供的基于人工智能的数据处理装置可以是采用硬件译码处理器形式的处理器,其被编程以执行本申请实施例提供的基于人工智能的模型训练方法,例如,硬件译码处理器形式的处理器可以采用一个或多个应用专用集成电路(Application Specific Integrated Circuit,ASIC)、数字信号处理器(Digital SignalProcessor,DSP)、可编程逻辑器件(Programmable Logic Device,PLD)、复杂可编程逻辑器件(Complex Programmable Logic Device,CPLD)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或其他电子元件。
在一些实施例中,终端或服务器可以通过运行各种计算机可执行指令或计算机程序来实现本申请实施例提供的基于人工智能的数据处理方法。举例来说,计算机可执行指令可以是微程序级的命令、机器指令或软件指令。计算机程序可以是操作系统中的原生程序或软件模块;可以是本地(Native)应用程序(APPlication,APP),即需要在操作系统中安装才能运行的程序,如新闻APP或者购物APP;也可以是可以嵌入至任意APP中的小程序,即只需要下载到浏览器环境中就可以运行的程序。总而言之,上述的计算机可执行指令可以是任意形式的指令,上述计算机程序可以是任意形式的应用程序、模块或插件。
下面,说明本申请实施例提供的基于人工智能的模型训练方法,如前,实现本申请实施例的基于人工智能的模型训练方法的电子设备可以是终端、服务器,又或者是二者的结合。因此下文中不再重复说明各个步骤的执行主体。
需要说明的是,下文中的基于人工智能的模型训练方法的示例中,是以对象为文本为例说明的,本领域技术人员根据对下文的理解,可以将本申请实施例提供的基于人工智能的模型训练方法应用于包括其他类型对象的样本集合的处理。
参见图3A,图3A是本申请实施例提供的基于人工智能的模型训练方法的第一流程示意图,将结合图3A示出的步骤101至步骤105进行说明。
在步骤101中,获取样本集合。
作为示例,样本集合包括源域样本以及目标域样本,样本集合中的样本的类型,可以根据实际业务需求确定,例如,在新闻业务场景、媒体流业务场景以及视频业务场景中,样本集合中的样本,可以是文本样本,可以是视频样本等,在购物业务场景中,样本集合中的样本,还可以是商品数据样本,本申请在此不作限制。
作为示例,源域样本与目标域样本来源于不同业务领域。不同业务领域可以是对应不同业务方的领域,例如,源域样本是对应新闻业务的样本,目标域样本是对应媒体流业务的样本,由于新闻业务与媒体流业务的分类处理要求相似,因此可以将应用于新闻业务分类处理的分类模型,同样在媒体流业务分类处理中进行应用。此外,不同业务领域还可以是对应相同业务方的原有数据集领域和新增数据集领域,例如,对于购物业务,购物业务需要根据商品的用途、种类等特点对商品进行分类,此时需要应用分类模型对原有的商品数据进行分类,随着商品的不断增加和更新,形成了新增的商品数据,此时需要将应用于原有的商品数据的分类处理的分类模型,同样对新增的商品数据进行分类处理。
在步骤102中,通过特征处理网络对样本集合中每个样本进行样本特征提取处理,得到每个样本的样本特征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征。
作为示例,特征处理网络包括用于执行样本特征提取处理的样本特征网络以及用于执行图结构特征提取处理的图特征网络。其中,样本特征网络的类型,可以根据样本的类型确定,例如,当样本为视频时,可以采用基于时空局部性归纳偏置的视觉自注意力模型(Video Swin Transformers)或其他能够对视频进行特征提取处理的网络模型作为样本特征网络,当样本为文本时,可以采用基于自注意力模型的双向编码器(BidirectionalEncoder Representations from Transformers,BERT)或其他能够对文本进行特征提取处理的网络模型作为样本特征网络。
作为示例,在实际应用中,样本可以为新闻业务领域中的新闻数据,源域样本为原始新闻数据,目标域样本为新增样本数据。例如,当样本为视频类型的新闻数据时,通过基于时空局部性归纳偏置的视觉自注意力模型对视频类型的新闻数据进行特征提取处理,得到每个视频类型的新闻数据的样本特征。例如,当样本为文本类型的新闻数据时,通过基于自注意力模型的双向编码器对文本类型的新闻数据进行特征提取处理,得到每个视频类型的新闻数据的样本特征。
作为示例,在实际应用中,样本可以为购物业务领域中的商品数据,源域样本为原始商品数据,目标域样本为新增商品数据。例如,当样本为图文类型的商品数据时,可以通过能够处理图文理解任务的模型,例如多模态模型,对商品数据进行特征提取处理,得到每个图文类型的商品数据的样本特征。例如,当样本为文本类型的商品数据时,可以通过基于自注意力模型的双向编码器对商品数据进行特征提取处理,得到每个文本类型的商品数据的样本特征。
参见图3B,图3B是本申请实施例提供的基于人工智能的模型训练方法的第二流程示意图。在一些实施例中,图3A中的步骤102中的对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征,可以通过图3B示出的步骤1021和步骤1022实现,下面进行详细说明。
在步骤1021中,针对每个源域样本,结合样本集合中每个源域样本的样本特征,对源域样本的样本特征进行图特征提取处理,得到源域样本的图结构特征。
在一些实施例中,步骤1021可以通过以下方式实现:获取源域样本的样本特征与样本集合中每个源域样本的样本特征之间的第三特征距离;以样本集合中每个源域样本的样本特征为权重,将对应样本集合中每个源域样本的第三特征距离进行融合处理,得到对应源域样本的图结构特征。
作为示例,以源域样本i的样本特征和样本集合中其他各个源域样本的样本特征,确定源域样本i与其他每个源域样本的第三特征距离,如余弦相似度,将每两个源域样本之间的余弦相似度作为元素,组成对应源域样本的第一邻接矩阵A,第一邻接矩阵A中的元素aij即为源域样本i的样本特征与源域样本j的样本特征之间的余弦相似度。同时,基于每个源域样本的样本特征,组成源域样本特征矩阵,将源域样本特征矩阵与第一邻接矩阵A相乘,得到源域样本的图结构特征矩阵,图结构特征矩阵中的每个元素对应每个源域样本的图结构特征,即相当于将每个源域样本的样本特征与每个源域样本的第三特征距离相乘,得到对应每个源域样本的图结构特征。
通过计算每个源域样本的样本特征与其他源域样本的样本特征之间的第三特征距离,确定每个源域样本的样本特征与其他源域样本的样本特征之间的相似程度,再以每个源域样本的样本特征作为权重和每个源域样本的样本特征与其他源域样本的样本特征之间的第三特征距离进行融合,得到每个源域样本的图结构特征,使源域样本的图结构特征中包含其与其他源域样本之间的特征关系信息和源域样本自身的特征信息,保证源域样本的图结构特征包含的信息丰富,进而保证在后续根据源域样本的图结构特征和目标域样本的图结构特征共同进行马尔科夫随机游走得到的样本间转移概率准确,使基于每个样本对的样本间转移概率确定的样本差异损失准确,从而提高特征提取网络在目标域上的泛化能力。
在步骤1022中,针对每个目标域样本,结合样本集合中每个目标域样本的样本特征,对目标域样本的样本特征进行图特征提取处理,得到目标域样本的图结构特征。
作为示例,获取目标域样本的样本特征与样本集合中每个目标域样本的样本特征之间的第五特征距离,例如,以目标域样本w的样本特征和样本集合中其他各个目标域样本的样本特征,确定目标域样本w与其他每个目标域样本的第五特征距离,如余弦相似度,将每两个目标域样本之间的余弦相似度作为元素,组成对应目标域样本的第二邻接矩阵B,第二邻接矩阵B中的元素bwv即为目标域样本w的样本特征与目标域样本v的样本特征之间的余弦相似度。以样本集合中每个目标域样本的样本特征为权重,将对应样本集合中每个目标域样本的第五特征距离进行融合处理,得到对应目标域样本的图结构特征,例如,基于每个目标域样本的样本特征,组成目标域样本特征矩阵,将目标域样本特征矩阵与第二邻接矩阵B相乘,得到目标域样本的图结构特征矩阵,图结构特征矩阵中的每个元素对应每个目标域样本的图结构特征,即相当于将每个目标域样本的样本特征与每个目标域样本的第五特征距离相乘,得到对应每个目标域样本的图结构特征。
需要说明的是,在模型训练早期,由于源域样本的样本特征与目标域样本的样本特征之间的特征差异较大,因此需要分别对源域样本的样本特征和目标域样本的样本特征进行特征处理,而在模型训练后期,由于已经对样本处理网络进行过多次训练,样本处理网络输出的源域样本的样本特征与目标域样本的样本特征的特征差异缩小,此时可以将源域样本的样本特征与目标域样本的样本特征同时输入图特征网络中进行图结构特征提取处理,可以使图结构网络学习到源域样本和目标域样本之间的特征关系信息。
通过分别对源域样本和目标域样本进行图结构特征提取处理,一方面使特征处理网络能够学习到源域样本的域内特征关系和目标域样本的域内特征关系另一方面,使源域样本的图结构特征中能够包含其与其他源域样本之间的特征关系信息,使目标域样本的图结构特征中能够包含其与其他目标域样本之间的特征关系信息,进而保证在后续根据源域样本的图结构特征和目标域样本的图结构特征共同进行马尔科夫随机游走得到的样本间转移概率准确,使基于每个样本对的样本间转移概率确定的样本差异损失准确,从而提高特征提取网络在目标域上的泛化能力。
继续参见图3A,在步骤103中,基于每个样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签。
作为示例,基于每个样本的图结构特征,通过马尔科夫随机游走模块计算每个样本对的样本间转移概率。由于目标域样本并不具有现有的真实标签,则需要调用源域分类网络对目标域样本的图结构特征进行分类处理,得到每个目标域样本的预测标签,其中,目标域样本的预测标签包括目标域样本属于每个类别的概率。
参见图3C,图3C是本申请实施例提供的基于人工智能的模型训练方法的第三流程示意图。在一些实施例中,图3A中的步骤103中的基于每个样本的图结构特征,确定每个样本对的样本间转移概率,可以通过图3C示出的步骤1031至步骤1034实现,下面进行详细说明。
在步骤1031中,对样本集合进行两两遍历处理,得到多个样本对。
作为示例,将样本i与样本集合中的每个样本分别组成样本对,例如,样本集合中存在样本i、样本j、样本w以及样本v,则组成样本对ij、样本对iw、样本对iv、样本对jw、样本对jv以及样本对wv。其中,此处的样本是从混合在同一样本集合中的源域样本和目标域样本中的任一样本,它可能是源域样本,也可能是目标域样本。
在步骤1032中,获取每个样本对对应的两个图结构特征之间的第四特征距离。
作为示例,以样本i为例,获取样本i的图结构特征与样本j的图结构特征之间的第
四特征距离,如余弦相似度,获取样本i的图结构特征与样本w的图结构特征之间的第四
特征距离,获取样本i的图结构特征与样本v的图结构特征之间的第四特征距离。
在步骤1033中,将多个样本对的第四特征距离进行融合处理,得到第一融合结果。
作为示例,对样本i的图结构特征与样本j的图结构特征之间的第四特征距离、
样本i的图结构特征与样本w的图结构特征之间的第四特征距离以及样本i的图结构特
征与样本v的图结构特征之间的第四特征距离进行融合处理,得到第一融合结果,此处
的融合处理可以为加和处理,也可以为其他处理方式,本申请在此不作限制。
在步骤1034中,获取与每个样本对的第四特征距离正相关,且与第一融合结果负相关的数值作为每个样本对的样本间转移概率。
作为示例,以样本对ij为例,根据公式(1)计算对应样本对ij的样本间转移概率:
(1)
其中,为样本i和样本j之间的第四特征距离,如余弦相似度,为样本i和任一
样本k之间的余弦相似度,k=1,…,N,N为样本集合中样本的总数量。
通过将源域样本和目标域样本混合在同一个样本集合中,计算任意两个样本之间的第四特征距离,以表征任意两个样本之间的特征关系,将对应同一样本的第四特征距离融合为第一融合结果,以表征该样本与其他样本之间的特征关系紧密程度,并根据样本对中两个样本的第四特征距离以及第一融合结果,确定该样本对的样本间转移概率,以表征该样本对中两个样本之间的关联紧密程度,并表征从当前样本转移至下一样本的概率,用于进行多步马尔科夫随机游走聚类,以计算样本损失差异,进而使后续特征处理网络能够根据样本损失差异学习到不同业务领域中的样本之间的样本特征差异,从而提高特征处理网络在目标域上的泛化能力。
继续参见图3A,在步骤104中,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失。
参见图3D,图3D是本申请实施例提供的基于人工智能的模型训练方法的第四流程示意图。在一些实施例中,图3A中的步骤104可以通过图3D示出的步骤1041至步骤1045实现,下面进行详细说明。
在步骤1041中,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定对应每个类别的期望。
在一些实施例中,步骤1041可以通过以下方式实现:针对每个类别执行以下处理:从每个源域样本的真实标签中提取每个源域样本属于类别的第三概率,并从每个目标域样本的预测标签中提取每个目标域样本属于类别的第四概率;对多个第三概率以及多个第四概率进行融合处理,得到融合概率,并将融合概率与类别对应的标签值进行相乘,得到对应类别的期望。
作为示例,当类别的数目为N时,真实标签与预测标签均是N维向量,每个维度对应一个类别,例如真实标签(0,1,0)表征该样本属于第2个类别,预测标签(0.2,0.7,0.1)表征该样本属于第2类别的概率为0.7,属于第1类别的概率为0.2,属于第3类别的概率为0.1。
作为示例,以第p类别为例,获取每个样本属于第p类别的概率。对于源域样本,
可以从源域样本的真实标签中直接获取该源域样本属于第p类别的第三概率,而对于目标
域样本,则可以从目标域样本的预测标签中获取该目标域样本属于第p类别的第四概率。对
应第p类别的期望可以通过公式(2)计算得到:
(2)
其中,为任一样本n属于第p类别的概率,n=1,…,N,N为样本的总数量,为第p
类别对应的标签值。
通过针对每个类别,从源域样本的真实标签中获取源域样本属于该类别的第三概率,并从目标域样本的预测标签中获取目标域样本获取目标域样本属于该类别的第四概率,再对多个第三概率以及多个第四概率进行融合处理,得到融合概率,并将融合概率与该类别对应的标签值进行相乘,得到对应该类别的期望,以表征该类别被命中的长期平均结果。
在步骤1042中,对样本集合进行两两遍历处理,得到多个样本对。
作为示例,其中,样本对包括第一样本以及第二样本,以样本集合中包括样本i、样本j以及样本w为例,对样本集合进行两两遍历处理,则得到样本对ij、样本对iw以及样本对jw。以样本对ij为例,样本对ij包括第一样本i和第二样本j。
针对每个样本对执行以下处理:
在步骤1043中,从第一样本的真实标签或者预测标签中提取第一样本属于每个类别的第一概率,并从第二样本的真实标签或者预测标签中提取第二样本属于每个类别的第二概率。
作为示例,对于第p类别,从第一样本i的真实标签或预测标签中提取第一样本i属于第p类别的第一概率,从第二样本j的真实标签或预测标签中提取第二样本j属于第p类别的第二概率。
在步骤1044中,针对类别,获取与对应类别的第一概率以及第二概率正相关,且与类别的期望负相关的类别损失。
作为示例,对于第p类别,可以将对应第p类别的第一概率和第二概率相乘,得到第一样本i和第二样本j均属于第p类别的概率,并以第一样本i和第二样本j均属于第p类别的概率与第p类别的期望进行求比值处理,得到对应第p类别的类别损失。此外,还可以用其他能够获取与对应类别的第一概率以及第二概率正相关,且与类别的期望负相关的类别损失的方式计算,本申请在此不作限制。
在步骤1045中,对多个类别分别对应的类别损失进行融合处理,得到样本对的标签损失。
作为示例,基于对应多个类别的多个类别损失,根据公式(3)计算样本对ij的标签
损失:
(3)
其中,p代表C个类别中的第p类别,C为类别总数量,为第一样本i属于第p类别
的第一概率,为第二样本j属于第p类别的第二概率,为第p类别的期望。
通过计算标签损失,对样本对中的两个样本的标签差异进行表征,并用于后续与样本间转移概率共同计算样本差异损失,以根据两个样本的标签差异与样本特征之间的相似程度对特征处理网络进行更新,使特征处理网络能够对学习到不同域的样本之间的差异与共性,从而提高特征处理网络在目标域上的泛化能力。
继续参见图3A,在步骤105中,基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络。
参见图3E,图3E是本申请实施例提供的基于人工智能的模型训练方法的第五流程示意图。在一些实施例中,图3A中的步骤105中的基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,可以通过图3E中示出的步骤1051至步骤1053实现,下面进行详细说明。
针对每个样本对执行以下处理:
在步骤1051中,对样本对的样本间转移概率以及样本对的标签损失进行相乘处理,得到样本对的子差异损失。
作为示例,对于样本对ij,对样本对ij的样本间转移概率以及样本对ij的标签
损失进行相乘处理,得到样本ij的子差异损失,也可以先对样本对ij的样本间转移概率进行取对数处理,得到样本对ij的样本间转移概率对数,将与样本对ij的
标签损失进行相乘处理,得到样本ij的子差异损失。
在步骤1052中,对多个样本对的子差异损失进行融合处理,得到第二融合结果。
作为示例,对样本对ij、样本对iw以及样本对jw的子差异损失进行融合处理,得到第二融合结果。
在步骤1053中,获取与第二融合结果负相关的样本差异损失。
作为示例,根据公式(4)获取与第二融合结果负相关的样本差异损失:
(4)
其中,为样本集合中的样本总数,为样本对ij的标签损失,为样本对ij的样
本间转移概率,则为样本对ij的子差异损失。
通过根据每个样本对的样本间转移概率和标签损失确定每个样本对的子差异损失,对每个样本对中第一样本和第二样本之间的特征差异进行表征,并基于每个样本对的子差异损失计算样本差异损失,以表征源域样本和目标域样本之间的样本差异,用于指导特征处理网络在尽可能学习源域样本和目标域样本之间的特征共性,进而使特征处理网络提取的图结构特征尽量模糊其业务领域来源,从而使特征处理网络能够在目标域上提高泛化能力。
参见图3F,图3F是本申请实施例提供的基于人工智能的模型训练方法的第六流程示意图。在一些实施例中,图3A示出的步骤105中的基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络,可以通过图3F示出的步骤1054至步骤1058实现,下面进行详细说明。
在步骤1054中,获取训练三联体损失。
作为示例,对于每个源域样本,获取与该源域样本具有相同真实标签的正面源域
样本,以及与该源域样本具有不同真实标签的负面源域样本,获取该源域样本的样本特征
与负面源域样本的样本特征之间的第六特征距离,以及该源域样本的样本特征与正面源域
样本的样本特征之间的第七特征距离。根据公式(5),获取与第一特征距离负相关,且与第
二特征距离正相关的训练三联体损失:
(5)
其中,为源域样本的样本特征,为正面源域样本的样本特征,为负面
源域样本的样本特征,d表示计算两个样本特征之间的第六特征距离或第七特征距离,如余
弦相似度,M1为一个预设值。
在步骤1055中,基于源域分类网络获取训练分类器损失。
作为示例,通过源域分类网络对源域样本的图结构特征执行分类处理,得到源域
样本的训练预测标签,基于源域样本的训练预测标签与真实标签之间的差异,根据公式(6)
确定训练分类器损失:
(6)
其中,T表示源域样本的数量,i=1,…,T,为源域样本的图结构特征,为对
应源域样本的预测标签y与源域样本的真实标签相同的概率。
在步骤1056中,基于判别网络获取训练判别器损失。
作为示例,通过判别网络对源域样本的图结构特征进行域判别处理,得到源域样
本的判别结果,并通过判别网络对目标域样本的图结构特征进行域判别处理,得到目标域
样本的判别结果,对源域样本的判别结果以及目标域样本的判别结果进行融合处理,根据
公式(7)得到训练判别器损失:
(7)
其中,为源域样本,为目标域样本,为对应样本x的图结构特征,
为判别网络以样本x的图结构特征为输入得到的判别结果,此处为样本x为源域样本的概
率,当样本x为源域样本时,则以计算,当样本x为目标域样本时,则以计算,表示取均值。
在步骤1057中,对训练三联体损失、训练分类器损失、训练判别器损失以及样本差异损失进行融合处理,得到综合损失。
作为示例,将样本差异损失、训练三联体损失、训练分类器损失以及训
练判别器损失进行融合处理,得到综合损失,此处的融合处理可以为相加处理,也
可以为其他能够实现损失融合的处理方法,本申请在此不作限制。
在步骤1058中,基于综合损失对特征处理网络进行更新,得到经过更新的特征处理网络。
作为示例,在分别获取样本差异损失、训练三联体损失、训练分类器损失以及训练判别器损失后,将其融合为综合损失,在本轮训练中,使用特征处理网
络进行一次性更新处理,得到更新后的特征处理网络,
在一些实施例中,还可以执行以下操作:基于综合损失对判别网络、源域分类网络进行更新,得到更新后的判别网络以及更新后的源域分类网络。
作为示例,在基于综合损失,对特征处理网络进行更新时,还可以基于综合损失,对用于执行域判别处理的判别网络以及用于执行分类处理的源域分类网络进行更新处理。
通过将训练三联体损失、训练分类器损失、训练判别器损失以及样本差异损失融合为综合损失,并基于综合损失对特征处理网络、判别网络以及源域分类网络进行更新,使特征处理网络、判别网络以及源域分类网络根据本轮训练的综合损失进行调参,一方面使特征处理网络能够学习到本轮训练中源域样本之间的特征差异、目标域样本之间的特征差异以及源域样本与目标域样本之间特征差异与特征共性,以提高特征处理网络在目标域上的泛化能力,另一方面,使判别网络以及源域分类网络的参数设置与特征处理网络相匹配,保证在下一轮的模型训练中,不会因为判别网络的域判别结果不准确或源域分类网络的分类结果不准确而影响特征处理网络的训练效果,从而提高特征处理网络在目标域上的泛化能力,且将训练三联体损失、训练分类器损失、训练判别器损失以及样本差异损失融合为综合损失,对特征处理网络、判别网络以及源域分类网络进行综合更新,减少参数更新处理次数,提高服务器的计算资源利用率。
参见图3G,图3G是本申请实施例提供的基于人工智能的模型训练方法的第七流程示意图。在一些实施例中,在执行步骤102之前,还可以执行步骤201至步骤203,下面进行详细说明。
在步骤201中,基于初始化的样本特征网络获取预训练三联体损失,并基于预训练三联体损失对初始化的样本特征网络进行更新,得到经过一次更新的样本特征网络。
参见图3H,图3H是本申请实施例提供的基于人工智能的模型训练方法的第八流程示意图。在一些实施例中,图3G中的步骤201中的基于初始化的样本特征网络获取预训练三联体损失,可以通过图3H示出的步骤2011至步骤2014实现,下面进行详细说明。
在步骤2011中,通过初始化的样本特征网络对每个源域样本进行样本特征提取处理,得到每个源域样本的预训练样本特征。
作为示例,初始化的样本特征网络可以为未经过训练的样本特征网络或经过上一轮训练得到的样本特征网络,初始化的样本特征网络的类型可以根据样本的类型进行选择,本申请在此不作限制。以源域样本作为初始化的样本特征网络的输入,初始化的样本特征网络对源域样本进行样本特征提取处理,得到每个源域样本的预训练样本特征。
针对每个源域样本执行以下处理:
在步骤2012中,获取与源域样本具有相同真实标签的正面源域样本以及与源域样本具有不同真实标签的负面源域样本。
作为示例,以新闻业务场景为例,源域样本a的真实标签为“财经”,源域样本b的真实标签为“娱乐”,源域样本c的真实标签为“财经”,则对于源域样本a而言,源域样本c为正面源域样本,源域样本b为负面源域样本。
在步骤2013中,获取源域样本的预训练样本特征与负面源域样本的预训练样本特征之间的第一特征距离,以及源域样本的预训练样本特征与正面源域样本的预训练样本特征之间的第二特征距离。
作为示例,由初始化的样本特征网络分别对正面源域样本和负面源域样本进行样本特征提取处理,得到正面源域样本的预训练样本特征和负面源域样本的预训练样本特征。计算源域样本的预训练样本特征与负面源域样本的预训练样本特征之间的第一特征距离,以及源域样本的预训练样本特征与正面源域样本的预训练样本特征之间的第二特征距离,如余弦相似度。
在步骤2014中,获取与第一特征距离负相关,且与第二特征距离正相关的预训练三联体损失。
作为示例,根据公式(8)计算与第一特征距离负相关,且与第二特征距离正相关的
预训练三联体损失:
(8)
其中,为源域样本的第一预训练样本特征,为正面源域样本的第一预训
练样本特征,为负面源域样本的第一预训练样本特征,d表示计算两个第一预训练样本
特征之间的第一特征距离或第二特征距离,M2为一个预设值。
通过基于预训练三联体损失对初始化的样本特征网络进行更新,使经过一次更新的样本特征网络学习根据源域样本的样本特征之间的分类差异,进而使经过一次更新的样本特征网络能够对样本的样本特征进行精确提取,保证在本轮的模型训练中,避免因特征处理网络中的样本特征网络提取的样本特征不准确而影响特征处理网络的训练效果,从而提高特征处理网络在目标域上的泛化能力。
继续参见图3G,在步骤202中,基于初始化的源域分类网络获取预训练分类器损失,并基于预训练分类器损失对经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络。
参见图3I,图3I是本申请实施例提供的基于人工智能的模型训练方法的第九流程示意图。在一些实施例中,图3G中的步骤202中的基于初始化的源域分类网络获取预训练分类器损失,可以通过图3I示出的步骤2021至步骤2023实现,下面进行详细说明。
在步骤2021中,通过初始化的图特征网络对源域样本的预训练样本特征进行图特征提取处理,得到源域样本的预训练图结构特征。
作为示例,初始化的图特征网络可以为未经过训练的图特征网络或经过上一轮训练得到的图特征网络,将源域样本的预训练样本特征作为初始化的图特征网络的输入,初始化的图特征网络计算源域样本的预训练样本特征的两两之间的第八特征距离,并基于第八特征距离构建第三邻接矩阵,并基于源域样本的预训练样本特征构建源域样本的预训练样本特征矩阵,将源域样本的预训练样本特征矩阵和第三邻接矩阵相乘,得到源域样本的预训练图结构特征。
在步骤2022中,通过初始化的源域分类网络对源域样本执行基于预训练图结构特征的分类处理,得到源域样本的预测标签。
作为示例,初始化的源域分类网络可以为未经过训练的源域分类网络或经过上一轮训练得到的源域分类网络,以源域样本的预训练图结构特征作为初始化的源域分类网络的输入,初始化的源域分类网络对源域样本的预训练图结构特征进行分类处理,得到源域样本的预测标签。
在步骤2023中,基于源域样本的预测标签与源域样本的真实标签之间的差异,确定预训练分类器损失。
作为示例,根据源域样本的真实标签,从源域样本的预测标签中确定源域样本被
正确分类的概率,并根据公式(9)计算得到预训练分类器损失:
(9)
其中,T表示源域样本的数量,i=1,…,T,为源域样本的预训练图结构特征,为对应源域样本的预测标签y与源域样本的真实标签相同的概率。
通过由初始化的图特征网络对源域样本的预训练样本特征进行图特征提取处理,得到源域样本的预训练图结构特征,并根据源域样本的预训练图结构特征,确定源域样本的预测标签,基于源域样本的预测标签与源域样本的真实标签之间的差异,确定预训练分类器损失,基于预训练分类器损失对经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络,令经过二次更新的样本特征网络和经过一次更新的图特征网络能够根据源域样本的预测标签和真实标签之间的差距进行参数更新,使经过二次更新的样本特征网络和经过一次更新的图特征网络能够在源域样本上实现精确的特征提取,用以后续对目标域样本的图结构特征进行精确提取,以对目标域样本之间的特征关系信息进行表征。
继续参见图3G,在步骤203中,基于初始化的判别网络获取预训练判别器损失,并基于预训练判别器损失对经过二次更新的样本特征网络以及经过一次更新的图特征网络进行更新,得到用于构成特征处理网络的样本特征网络以及用于构成特征处理网络的图特征网络。
参见图3J,图3J是本申请实施例提供的基于人工智能的模型训练方法的第十流程示意图。在一些实施例中,图3G中的步骤203中的基于初始化的判别网络获取预训练判别器损失,可以通过图3J示出的步骤2031至步骤2033实现,下面进行详细说明。
在步骤2031中,通过初始化的判别网络对源域样本进行基于源域样本的预训练图结构特征的域判别处理,得到源域样本的预训练域判别结果。
作为示例,初始化的判别网络可以为未经过训练的判别网络或经过上一轮训练得到的判别网络。以源域样本的预训练图结构特征为初始化的判别网络的输入,由初始化的判别网络对源域样本的预训练图结构特征进行域判别处理,得到源域样本的预训练域判别结果。
在步骤2032中,通过初始化的判别网络对目标域样本进行基于目标域样本的预训练图结构特征的域判别处理,得到目标域样本的预训练域判别结果。
作为示例,以目标域样本的预训练图结构特征为初始化的判别网络的输入,由初始化的判别网络对目标域样本的预训练图结构特征进行域判别处理,得到目标域样本的预训练域判别结果。
在步骤2033中,对源域样本的预训练域判别结果以及目标域样本的预训练域判别结果进行融合处理,得到预训练判别器损失。
作为示例,对源域样本的预训练域判别结果以及目标域样本的预训练域判别结果
进行融合处理,根据公式(10)得到预训练判别器损失:
(10)
其中,为源域样本,为目标域样本,为对应样本x的预训练图结构特征,为初始化的判别网络以样本x的预训练图结构特征为输入得到的预训练域判别结
果,此处为样本x为源域样本的概率,当样本x为源域样本时,则以计算,当
样本x为目标域样本时,则以计算,表示取均值。
通过由初始化的判别网络对源域样本的预训练图结构特征以及目标域样本的预训练图结构特征分别进行域判别处理,得到源域样本的预训练域判别结果以及目标域样本的预训练域判别结果,并根据源域样本的预训练域判别结果以及目标域样本的预训练域判别结果,确定预训练判别器损失,基于预训练判别器损失对经过二次更新的样本特征网络以及经过一次更新的图特征网络进行基于最大化梯度策略的更新,得到用于构成特征处理网络的样本特征网络以及用于构成特征处理网络的图特征网络,使用于构成特征处理网络的样本特征网络以及用于构成特征处理网络的图特征网络输出的样本的图结构特征中蕴含的关于样本来源域的信息尽量少,尽可能生成令判别网络无法正确区分样本的来源域的样本的图结构特征,进而使源域分类网络能够对特征处理网络输出的图结构特征进行精确分类,从而提高特征处理网络在目标域上的泛化能力。
在一些实施例中,还可以执行以下操作:基于预训练分类器损失对初始化的源域分类网络进行更新,得到源域分类网络,其中,源域分类网络用于基于每个目标域样本的图结构特征对每个目标域样本进行分类处理。
作为示例,在基于预训练分类器损失对初始化的样本特征网络进行更新时,可以基于预训练分类器损失对初始化的样本特征网络以及初始化的源域分类网络同时进行更新,则可得到特征处理网络以及源域分类网络。
在一些实施例中,还可以执行以下操作:基于预训练判别器损失对初始化的判别网络进行更新,得到判别网络,其中,判别网络用于对目标域样本以及源域样本进行域判别处理。
作为示例,在基于预训练判别器损失对初始化的样本特征网络进行更新时,可以基于预训练判别器损失对初始化的样本特征网络以及初始化的判别网络同时进行更新,则可得到特征处理网络以及判别网络。
通过基于预训练三联体损失对初始化的样本特征网络进行更新,基于预训练分类器损失对经过一次更新的样本特征网络、初始化的图特征网络以及初始化的源域分类网络进行更新,基于预训练判别器损失对经过二次更新的样本特征网络、经过一次更新的图特征网络以及初始化的判别网络进行更新,一方面使特征处理网络能够学习到上轮训练中源域样本之间的特征差异、目标域样本之间的特征差异以及源域样本与目标域样本之间特征差异与特征共性,另一方面,使判别网络以及源域分类网络的参数设置与特征处理网络相匹配,保证在本轮的模型训练中,不会因为判别网络的域判别结果不准确或源域分类网络的分类结果不准确而影响特征处理网络的训练效果,从而提高特征处理网络在目标域上的泛化能力。
下面,说明本申请实施例提供的基于人工智能的数据处理方法,如前,实现本申请实施例的基于人工智能的数据处理方法的电子设备可以是终端、服务器,又或者是二者的结合。因此下文中不再重复说明各个步骤的执行主体。
需要说明的是,下文中的基于人工智能的数据处理方法的示例中,是以对象为文本为例说明的,本领域技术人员根据对下文的理解,可以将本申请实施例提供的数据处理方法应用于包括其他类型对象的数据分类的处理。
参见图4,图4是本申请实施例提供的基于人工智能的数据处理方法的流程示意图,将结合图4示出的步骤301至步骤303进行说明。
在步骤301中,获取待分类数据。
作为示例,待分类数据与目标域样本属于相同业务领域。例如,目标域样本的样本领域是相对于源域样本的不同业务方的样本领域,则待分类数据与目标域样本来源于相同业务方的样本领域。又例如,目标域样本的样本领域是相对于源域样本的相同业务方的新增样本领域,则待分类数据与目标域样本均来源于新增样本领域。
作为示例,在新闻业务场景中,目标域样本和待分类数据可以属于新增新闻数据;在视频业务场景中,目标域样本和待分类数据可以属于新增视频数据;在购物业务场景中,目标域样本和待分类数据可以属于新增商品数据。此外,在不同业务场景之间实现领域自适应的场景下,目标域样本和待分类数据可以属于目标自适应领域数据,例如,将应用于媒体流业务场景的分类模型迁移至新闻业务场景时,目标域样本和待分类数据可以属于新闻数据。
在步骤302中,通过特征处理网络对待分类数据进行特征提取处理,得到待分类数据的样本特征,并通过特征处理网络对待分类数据的样本特征进行图特征提取处理,得到待分类数据的图结构特征。
作为示例,特征处理网络是通过本申请实施例提供的基于人工智能的模型训练方法得到的。以待分类数据的样本特征输入特征处理网络,特征处理网络中的样本特征网络对待分类数据的样本特征进行特征提取处理,得到待分类数据的样本特征,特征处理网络中的图特征网络对待分类数据的样本特征进行图结构特征提取处理,得到待分类数据的图结构特征。
在步骤303中,基于待分类数据的图结构特征,对待分类数据进行分类处理,得到待分类数据的预测标签。
作为示例,以待分类数据的图结构特征输入源域分类网络,源域分类网络对待分类数据的图结构特征进行分类处理,得到待分类数据的预测标签,待分类数据的预测标签包含待分类数据属于每个类别的概率。
通过使用由本申请实施例提供的基于人工智能的模型训练方法得到的特征处理网络,对与目标域样本属于相同业务领域的待分类数据进行特征提取处理,使得到的待分类数据的图结构特征中蕴含尽量少的目标域特征信息,以保证源域分类网络对待分类数据的图结构特征进行分类处理时,不会被待分类数据的图结构特征中的目标域特征信息干扰,而能够根据源域分类网络已经学习知识对待分类数据的图结构特征进行精确分类处理,得到准确的待分类数据的预测标签,从而提高特征处理网络在目标域上的泛化能力,无需针对目标域部署专门的特征处理网络,节约服务器的计算资源与内存资源,提高服务器的计算资源利用率与内存资源利用率。
下面,将说明本申请实施例在一个实际的新闻分类应用场景中的示例性应用。
以新增的新闻文章作为目标域的样本,原有的新闻文章作为源域的样本。在本申请实施例中,首先对两个域的样本进行同样的预处理,保证样本的一致性,然后通过一个样本特征网络提取两个域的样本特征。然后,将源域样本的样本特征和目标域样本的样本特征分别送入由图卷积神经网络(Graph Convolutional Network,GCN)构成的图特征网络,然后使用样本差异损失和训练判别器损失联合监督样本特征网络和图特征网络提取样本的特征信息。
参见图5,图5是本申请实施例提供的基于人工智能的模型训练方法的框架示意图。如图5所示,基于人工智能的模型训练方法的框架包括样本特征网络、图特征网络、马尔科夫随机游走模块、判别网络以及源域分类网络,其中,样本特征网络用于对源域样本以及目标域样本进行特征提取处理,得到源域样本的样本特征和目标域样本的样本特征;图特征网络用于对源域样本的样本特征和目标域样本的样本特征进行图结构特征提取处理,得到源域样本的图结构特征和目标域样本的图结构特征;源域分类网络用于对源域样本的图结构特征和目标域样本的图结构特征进行分类处理,得到源域样本的预测标签和目标域样本的预测标签;马尔科夫随机游走模块用于根据多步马尔科夫聚类方法对每个源域样本和每个目标域样本进行随机游走处理,得到每两个样本之间的样本间迁移概率;判别网络用于对源域样本的图结构特征和目标域样本的图结构特征进行域判别处理,得到源域样本和目标域样本的预训练域判别结果。基于上述框架,设计了四个损失以对框架中的网络进行训练:并根据源域样本的样本特征得到三联体损失;根据源域样本的真实损失和源域样本的预测样本得到分类器损失;根据源域样本的真实标签和目标域样本的预测标签计算标签损失,根据标签损失和样本间转移概率计算样本差异损失;根据源域样本和目标域样本的预训练域判别结果计算判别损失。下面结合图5对各模块以及框架进行详细说明。
1、样本特征网络
首先定义不同类型的样本,即不同类型的输入,对应不同类型的样本特征网络,得到不同类型的样本的样本特征。
如果样本是视频,采用基于时空局部性归纳偏置的视觉自注意力模型(VideoSwin Transformers)。
参见图6,图6是本申请实施例提供的视觉自注意力模型的模型架构示意图。如图6所示,输入的样本为视频样本,经过三维分区操作之后,视频样本会变成一个向量。三维分区操作之后会紧跟一个线性嵌入层,充当全连接层。之后分别会经过多个阶段,图6中示出了第一阶段、第二阶段、第三阶段以及第四阶段的特征提取处理,最终得到样本特征,每个阶段的特征提取处理会经过区融合层和多个视觉自注意力模块,区融合层用来改变特征的形状和维度,视觉自注意力模块是利用注意力机制对同一个窗口内的特征进行特征融合的模块。在图6中的视觉自注意力模块下方有×2或是×6这样的符号,表示有对应的视觉自注意力模块中的子模块数量,这必定是个偶数,因为视觉自注意力模块有两种子模块,两个不同的子模块需要连在一起搭配使用,作为一个子模块组,比如,×2表示对应的视觉自注意力模块中有一个子模块组,×6表示对应的视觉自注意力模块中有三个子模块组。
参见图7,图7是本申请实施例提供的视觉自注意力模块的子模块组架构示意图。如图7所示,视觉自注意力模块有两种子模块,两个不同的子模块需要连在一起搭配使用,作为一个子模块组,每个子模块都是先过一个归一化层,再过一个多头注意力层,再过一个归一化层,最后过一个多层感知器,其中有两处使用了残差模块。残差模块主要是为了缓解梯度弥散。两种子模块的区别在于第一个子模块的多头注意力层是基于窗口的多头自注意力,第二个子模块是基于移位窗口的多头自注意力。第一个子模块是为了进行窗口内(即局部)的信息交流,第二个子模块是为了不同窗口之间(即全局)的信息交流。
如果输入的样本是文本,则采用基于自注意力模型的双向编码器(BERT,Bidirectional Encoder Representations from Transformers)作为样本特征网络。参见图8,图8是本申请实施例提供的基于自注意力模型的双向编码器的处理流程示意图,如图8所示,将文本类型的样本输入自注意力模型的双向编码器,自注意力模型的双向编码器对文本类型的样本进行特征提取处理,得到样本特征。自注意力模型的双向编码器的输入层在原有静态词向量编码和位置编码的基础上,增加了语句分割编码,该层的输出结果是三种编码之和。
在本技术方案中,首先以源域样本为样本特征网络的输入,得到源域样本的样本
特征。
每个源域样本具有对应的标签分类,根据源域样本的标签分类计算三联体损失
(Triplet loss),计算方式如公式(11)所示:
(11)
其中,为当前源域样本的样本特征,为与当前源域样本具有相同真实标签
的正面源域样本的样本特征,为与当前源域样本具有不同真实标签的负面源域样本的
样本特征,d表示计算两个样本特征之间的余弦相似度,M为一个预设值。基于对样本特征
网络进行更新。
2、图特征网络
应用图特征网络来捕获图的拓扑结构。给定一个有m个节点的无向图,其中,无向
图中的一个节点对应一个样本。根据通过样本特征网络提取到的样本特征,构建样本特征
矩阵Hin∈,其中,R代表一个向量空间,m表示m个节点,每个节点对应的样本特征的
长度是din;计算每两个节点的样本特征之间的余弦相似度,并根据多个余弦相似度构建邻
接矩阵A∈R(m×m),其中,R代表一个向量空间,m×m表示m个节点之间的组合关系,例如,邻接
矩阵A中的元素aij表示样本i与样本j之间的余弦相似度Mij。
由于早期训练的目标域和源域差异性比较大,对于源域样本的样本特征和目标
域样本的样本特征,分别进行图结构特征提取处理。
首先以源域样本的样本特征作为图特征网络的输入:
第一步,对第一邻接矩阵进行归一化处理,将归一化后的第一邻接矩阵A与源域样本的样本特征矩阵Hin相乘,得到表征潜在的图结构信息的源域样本的融合特征矩阵。
在第二步,通过线性滤波器和激活函数ReLU对源域样本的融合特征矩阵进行映射
处理,得到源域样本的图结构特征。
3、源域分类网络
得到源域样本的图结构特征之后,将源域样本的图结构特征输入源域分类网
络(由两个全连接层构成的深度学习分类网络),对源域样本的图结构特征进行分类处理,
得到对应源域样本的预测标签y,根据预测标签y确定源域分类网络的分类器损失,计算
方式如公式(12)所示:
(12)
其中,T表示源域样本的数量,i=1,…,T,为对应源域样本的预测标签y与源
域样本的真实标签y'相同的概率。
基于分类器损失,对样本特征网络、图特征网络以及源域分类网络进行参数更
新处理。
利用更新后的样本特征网络对目标域样本进行特征提取处理,得到目标域样本的
样本特征,特征提取处理的过程与源域样本的样本特征的特征提取处理过程相同。
将目标域样本的样本特征输入更新后的图特征网络,利用更新后的图特征网络
对目标域样本的样本特征进行图结构特征提取处理,则得到目标域图结构特征。
将获得的源域图结构特征和目标域图结构特征一起输入马尔科夫随机游走模
块,以学习域不变特征和进行类内紧凑表示。
4、马尔科夫随机游走模块
为了增强图特征网络的特征表示能力,在每个训练批次中采用了一种创新的方
法:多步马尔科夫聚类,利用动态图重新创建的所有源域样本和目标域样本,更有针对性地
引导图特征网络学习到分歧的特征表达。首先,我们将目标域图结构特征输入到源域分类
网络,以对目标域的样本进行预测,生成了目标域样本的伪标签。然后,利用马尔科夫随机
游走模块来计算样本差异损失,以引导图特征网络学习源域样本和目标域样本之间的
特征关系,从而更好地捕捉两个域之间的差异性和共性。这个过程有助于提高图特征网络
在目标域上的泛化性能,使其能够更好地适应目标域的数据分布。通过结合多部马尔科夫
聚类和图特征网络的学习,实现了对特征表示的有效增强,为源域和目标域的信息传递提
供了更有力的指导。
通过图特征网络获取源域样本的图结构特征和目标域样本的图结构特征之
后,我们将源域样本的图结构特征和目标域样本的图结构特征一起输入到马尔科夫随
机游走模块中,相当于将源域样本和目标域样本混合在同一无向图中,通过余弦相似度来
衡量两个样本的图结构特征之间的距离,直觉上,如果两个样本越相似,两个样本的图结构
特征之间的相似度就越大,余弦相似度越接近于1,根据公式(13)计算从对应样本i转移到
对应样本j的样本间转移概率:
(13)
其中,为样本i和样本j之间的余弦相似度,为样本i和任一样本k之间的余弦
相似度,k=1,…,N,j=1,…,N,N为样本的总数量。
然后我们将目标域样本的图结构特征输入源域分类网络,使用源域分类网络为
目标域样本生成伪标签Y目标域,然后和源域样本的真实标签Y源域共同组成标签矩阵Y,标签矩
阵Y中的元素为样本i和样本j之间的标签损失。如前所述,样本i和样本j的来源域并不固
定,那么对于而言,是否有可能根据是两个伪标签计算得到,或是根据两个真实标签计算
得到,或是根据一真一伪两个标签计算得到。标签损失的计算方式如公式(14)所示:
(14)
其中,p代表C个类别中的第p类别,为样本i属于第p类别的概率,为样本j属
于第p类别的概率,为第p类别的所有样本的期望值,即所有样本各自属于第p类别的概率
与第p类别对应的标签值的总和,的计算方式如公式(15)所示:
(15)
其中,为任一样本n属于第p类别的概率,n=1,…,N,N为样本的总数量,为对
应第p类别的标签值。
获取样本间转移概率和标签矩阵之后,根据公式(16)计算样本差异损失,以
对样本特征网络以及图特征网络进行调参处理:
(16)
其中,N为样本的总数量,i代表N个样本中的样本i,j代表N个样本中的样本j,为
样本i和样本j的标签损失,为样本i和样本j的样本间转移矩阵。
5、判别网络
在获得图特征网络输出的源域图结构特征和目标域图结构特征后,我们引入
了一个专门的二分类器作为判别网络,以对源域样本和目标域样本进行区分判别。为了监
督模型,采用了一种最大化梯度的策略,旨在使模型在判别过程中尽可能分辨不清样本的
来源域,即无法确定其是来自源域还是目标域。通过操纵梯度的最大化来强制模型学习对
两个域之间微妙差异的敏感性,从而更好地实现领域间特征的区分性学习。这种监督机制
有助于增强模型对源域和目标域之间特征差异的感知,为领域自适应任务提供更为有力的
指导。根据公式(17)计算判别器损失:
(17)
其中,为源域样本,为目标域样本,为对应样本x的图结构特征,
为判别网络以样本x的图结构特征为输入得到的判别结果,此处为样本x为源域样本的概
率,当样本x为源域样本时,则以计算,当样本x为目标域样本时,则以计算,表示取均值。
计算得到后,即可根据对样本特征网络、图特征网络以及判别网
络进行参数调整。
在应用侧,对目标域的待分类数据,将其输入样本特征网络,对待分类数据进行特征提取处理,得到对应待分类数据的样本特征,将对应待分类数据的样本特征输入图特征网络中,对样本特征进行图卷积处理,得到对应待分类数据的图结构特征,此时将对应待分类数据的图结构特征输入源域分类网络中,对对应待分类数据的图结构特征进行分类处理,得到对应待分类数据的预测标签。
可以理解的是,在本申请实施例中,涉及到对象的操作数据、对象的对象特征等相关的数据,当本申请实施例运用到具体产品或技术中时,需要获得对象许可或者同意,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
在本申请中,涉及到的数据抓取技术方案实施,在本申请以上实施例运用到具体产品或技术中时,相关数据收集、使用和处理过程应该遵守国家法律法规要求,符合合法、正当、必要的原则,不涉及获取法律法规禁止或限制的数据类型,不会妨碍目标网站的正常运行。
下面继续说明本申请实施例提供的基于人工智能的模型训练装置233的实施为软件模块的示例性结构,在一些实施例中,如图2A所示,存储在存储器230的基于人工智能的模型训练装置233中的软件模块可以包括:获取模块2331,用于获取样本集合,其中,样本集合包括源域样本以及目标域样本,源域样本与目标域样本来源于不同业务领域;特征处理模块2332,用于通过特征处理网络对样本集合中每个样本进行样本特征提取处理,得到每个样本的样本特征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征;分类处理模块2333,用于基于每个样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签;损失处理模块2334,用于基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失;更新处理模块2335,用于基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络。
在上述方案中,更新处理模块2335,用于基于初始化的样本特征网络获取预训练三联体损失,并基于预训练三联体损失对初始化的样本特征网络进行更新,得到经过一次更新的样本特征网络;基于初始化的源域分类网络获取预训练分类器损失,并基于预训练分类器损失对经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络;基于初始化的判别网络获取预训练判别器损失,并基于预训练判别器损失对经过二次更新的样本特征网络以及经过一次更新的图特征网络进行更新,得到用于构成特征处理网络的样本特征网络以及用于构成特征处理网络的图特征网络;基于预训练分类器损失对初始化的源域分类网络进行更新,得到源域分类网络,其中,源域分类网络用于基于每个目标域样本的图结构特征对每个目标域样本进行分类处理;基于预训练判别器损失对初始化的判别网络进行更新,得到判别网络,其中,判别网络用于对目标域样本以及源域样本进行域判别处理。
在上述方案中,更新处理模块2335,用于通过初始化的样本特征网络对每个源域样本进行样本特征提取处理,得到每个源域样本的预训练样本特征;针对每个源域样本执行以下处理:获取与源域样本具有相同真实标签的正面源域样本以及与源域样本具有不同真实标签的负面源域样本;获取源域样本的预训练样本特征与负面源域样本的预训练样本特征之间的第一特征距离,以及源域样本的预训练样本特征与正面源域样本的预训练样本特征之间的第二特征距离;获取与第一特征距离负相关,且与第二特征距离正相关的预训练三联体损失。
在上述方案中,更新处理模块2335,用于通过初始化的图特征网络对源域样本的预训练样本特征进行图特征提取处理,得到源域样本的预训练图结构特征;通过初始化的源域分类网络对源域样本执行基于预训练图结构特征的分类处理,得到源域样本的预测标签;基于源域样本的预测标签与源域样本的真实标签之间的差异,确定预训练分类器损失。
在上述方案中,更新处理模块2335,用于通过初始化的判别网络对源域样本进行基于源域样本的预训练图结构特征的域判别处理,得到源域样本的预训练域判别结果;通过初始化的判别网络对目标域样本进行基于目标域样本的预训练图结构特征的域判别处理,得到目标域样本的预训练域判别结果;对源域样本的预训练域判别结果以及目标域样本的预训练域判别结果进行融合处理,得到预训练判别器损失。
在上述方案中,特征处理模块2332,用于针对每个源域样本,结合样本集合中每个源域样本的样本特征,对源域样本的样本特征进行图特征提取处理,得到源域样本的图结构特征;针对每个目标域样本,结合样本集合中每个目标域样本的样本特征,对目标域样本的样本特征进行图特征提取处理,得到目标域样本的图结构特征。
在上述方案中,特征处理模块2332,用于获取源域样本的样本特征与样本集合中每个源域样本的样本特征之间的第三特征距离;以样本集合中每个源域样本的样本特征为权重,将对应样本集合中每个源域样本的第三特征距离进行融合处理,得到对应源域样本的图结构特征。
在上述方案中,分类处理模块2333,用于对样本集合进行两两遍历处理,得到多个样本对;获取每个样本对对应的两个图结构特征之间的第四特征距离;将多个样本对的第四特征距离进行融合处理,得到第一融合结果;获取与每个样本对的第四特征距离正相关,且与第一融合结果负相关的数值作为每个样本对的样本间转移概率。
在上述方案中,损失处理模块2334,用于基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定对应每个类别的期望;对样本集合进行两两遍历处理,得到多个样本对,其中,样本对包括第一样本以及第二样本;针对每个样本对执行以下处理:从第一样本的真实标签或者预测标签中提取第一样本属于每个类别的第一概率,并从第二样本的真实标签或者预测标签中提取第二样本属于每个类别的第二概率;针对类别,获取与对应类别的第一概率以及第二概率正相关,且与类别的期望负相关的类别损失;对多个类别分别对应的类别损失进行融合处理,得到样本对的标签损失。
在上述方案中,损失处理模块2334,用于针对每个类别执行以下处理:从每个源域样本的真实标签中提取每个源域样本属于类别的第三概率,并从每个目标域样本的预测标签中提取每个目标域样本属于类别的第四概率;对多个第三概率以及多个第四概率进行融合处理,得到融合概率,并将融合概率与类别对应的标签值进行相乘,得到对应类别的期望。
在上述方案中,更新处理模块2335,用于针对每个样本对执行以下处理:对样本对的样本间转移概率以及样本对的标签损失进行相乘处理,得到样本对的子差异损失;对多个样本对的子差异损失进行融合处理,得到第二融合结果;获取与第二融合结果负相关的样本差异损失。
在上述方案中,更新处理模块2335,用于获取训练三联体损失;基于源域分类网络获取训练分类器损失;基于判别网络获取训练判别器损失;对训练三联体损失、训练分类器损失、训练判别器损失以及样本差异损失进行融合处理,得到综合损失;基于综合损失对特征处理网络进行更新,得到经过更新的特征处理网络;基于综合损失对判别网络、源域分类网络进行更新,得到更新后的判别网络以及更新后的源域分类网络。
下面继续说明本申请实施例提供的基于人工智能的数据处理装置263的实施为软件模块的示例性结构,在一些实施例中,如图2A所示,存储在存储器260的基于人工智能的数据处理装置263中的软件模块可以包括:获取模块2631,用于获取待分类数据;特征处理模块2632,用于通过特征处理网络对待分类数据进行特征提取处理,得到待分类数据的样本特征,并通过特征处理网络对待分类数据的样本特征进行图特征提取处理,得到待分类数据的图结构特征;其中,特征处理网络是通过本申请实施例提供的基于人工智能的模型训练方法得到的,待分类数据与目标域样本属于相同业务领域;分类处理模块2633,用于基于待分类数据的图结构特征,对待分类数据进行分类处理,得到待分类数据的预测标签。
本申请实施例提供了一种计算机程序产品,该计算机程序产品包括计算机可执行指令,该计算机可执行指令存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该计算机可执行指令,处理器执行该计算机可执行指令,使得该电子设备执行本申请实施例上述的基于人工智能的模型训练方法以及基于人工智能的数据处理方法。
本申请实施例提供一种存储有计算机可执行指令的计算机可读存储介质,其中存储有计算机可执行指令,当计算机可执行指令被处理器执行时,将引起处理器执行本申请实施例提供的基于人工智能的模型训练方法以及基于人工智能的数据处理方法,例如,如图3A示出的基于人工智能的模型训练方法以及图4示出的基于人工智能的数据处理方法。
在一些实施例中,计算机可读存储介质可以是RAM、ROM、闪存、磁表面存储器、光盘、或CD-ROM等存储器;也可以是包括上述存储器之一或任意组合的各种设备。
在一些实施例中,计算机可执行指令可以采用程序、软件、软件模块、脚本或代码的形式,按任意形式的编程语言(包括编译或解释语言,或者声明性或过程性语言)来编写,并且其可按任意形式部署,包括被部署为独立的程序或者被部署为模块、组件、子例程或者适合在计算环境中使用的其它单元。
作为示例,计算机可执行指令可以但不一定对应于文件系统中的文件,可以可被存储在保存其它程序或数据的文件的一部分,例如,存储在超文本标记语言(Hyper TextMarkup Language,HTML)文档中的一个或多个脚本中,存储在专用于所讨论的程序的单个文件中,或者,存储在多个协同文件(例如,存储一个或多个模块、子程序或代码部分的文件)中。
作为示例,计算机可执行指令可被部署为在一个电子设备上执行,或者在位于一个地点的多个电子设备上执行,又或者,在分布在多个地点且通过通信网络互连的多个电子设备上执行。
综上所述,通过本申请实施例获取来源于不同业务领域的源域样本以及目标域样本,源域样本与目标域样本,通过特征处理网络对每个样本进行样本特征提取处理,得到每个样本的样本特征,以对每个样本的个体特征信息进行表征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征,以对每个样本与相同域中的其他样本之间的关联特征信息进行表征,基于每个样本的图结构特征,确定每个样本对的样本间转移概率,以确定样本对中的两个样本之间的相似程度,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签,以确定目标域样本的分类结果,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失,以表征样本对中两个样本之间的分类结果差异,基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,将样本之间的特征差异融合为样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络,使特征处理网络学习到相同域的样本之间以及不同域的样本之间的特征差异与特征共性,使特征处理网络能够对目标域样本进行精确的特征提取处理,从而能够解决源域数据和目标域数据的数据分布存在较大差异时导致的模型迁移效果变差的问题,提高分类模型的特征处理网络在目标域上的泛化能力,并提高对应分类服务的计算资源利用率。经过试验,由本申请实施例提供的基于人工智能的模型训练方案得到的特征处理网络构成的分类模型,在目标域数据标签识别中有很好的效果,在标签识别的准召效果的准确度提高了1.4%。
以上所述,仅为本申请的实施例而已,并非用于限定本申请的保护范围。凡在本申请的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本申请的保护范围之内。
Claims (15)
1.一种基于人工智能的模型训练方法,其特征在于,所述方法包括:
获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;
通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;
基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;
基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;
基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。
2.根据权利要求1所述的方法,其特征在于,所述特征处理网络包括用于执行样本特征提取处理的样本特征网络以及用于执行图结构特征提取处理的图特征网络;
在通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征之前,所述方法还包括:
基于初始化的样本特征网络获取预训练三联体损失,并基于所述预训练三联体损失对所述初始化的样本特征网络进行更新,得到经过一次更新的样本特征网络;
基于初始化的源域分类网络获取预训练分类器损失,并基于所述预训练分类器损失对所述经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络;
基于初始化的判别网络获取预训练判别器损失,并基于所述预训练判别器损失对所述经过二次更新的样本特征网络以及所述经过一次更新的图特征网络进行更新,得到用于构成所述特征处理网络的样本特征网络以及用于构成所述特征处理网络的图特征网络;
所述方法还包括:
基于所述预训练分类器损失对所述初始化的源域分类网络进行更新,得到源域分类网络,其中,所述源域分类网络用于基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理;
基于所述预训练判别器损失对所述初始化的判别网络进行更新,得到判别网络,其中,所述判别网络用于对所述目标域样本以及所述源域样本进行域判别处理。
3.根据权利要求2所述的方法,其特征在于,所述基于初始化的样本特征网络获取预训练三联体损失,包括:
通过初始化的样本特征网络对每个所述源域样本进行样本特征提取处理,得到每个所述源域样本的预训练样本特征;
针对每个所述源域样本执行以下处理:
获取与所述源域样本具有相同真实标签的正面源域样本以及与所述源域样本具有不同真实标签的负面源域样本;
获取所述源域样本的预训练样本特征与所述负面源域样本的预训练样本特征之间的第一特征距离,以及所述源域样本的预训练样本特征与所述正面源域样本的预训练样本特征之间的第二特征距离;
获取与所述第一特征距离负相关,且与所述第二特征距离正相关的预训练三联体损失。
4.根据权利要求2所述的方法,其特征在于,所述基于初始化的源域分类网络获取预训练分类器损失,包括:
通过初始化的图特征网络对所述源域样本的预训练样本特征进行图特征提取处理,得到所述源域样本的预训练图结构特征;
通过初始化的源域分类网络对所述源域样本执行基于所述预训练图结构特征的分类处理,得到所述源域样本的预测标签;
基于所述源域样本的预测标签与所述源域样本的真实标签之间的差异,确定所述预训练分类器损失。
5.根据权利要求2所述的方法,其特征在于,所述基于初始化的判别网络获取预训练判别器损失,包括:
通过所述初始化的判别网络对所述源域样本进行基于所述源域样本的预训练图结构特征的域判别处理,得到所述源域样本的预训练域判别结果;
通过所述初始化的判别网络对所述目标域样本进行基于所述目标域样本的预训练图结构特征的域判别处理,得到所述目标域样本的预训练域判别结果;
对所述源域样本的预训练域判别结果以及所述目标域样本的预训练域判别结果进行融合处理,得到所述预训练判别器损失。
6.根据权利要求1所述的方法,其特征在于,所述对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征,包括:
针对每个所述源域样本,结合所述样本集合中每个所述源域样本的样本特征,对所述源域样本的样本特征进行图特征提取处理,得到所述源域样本的图结构特征;
针对每个所述目标域样本,结合所述样本集合中每个所述目标域样本的样本特征,对所述目标域样本的样本特征进行图特征提取处理,得到所述目标域样本的图结构特征。
7.根据权利要求6所述的方法,其特征在于,所述结合所述样本集合中每个所述源域样本的样本特征,对所述源域样本的样本特征进行图特征提取处理,得到所述源域样本的图结构特征,包括:
获取所述源域样本的样本特征与所述样本集合中每个所述源域样本的样本特征之间的第三特征距离;
以所述样本集合中每个所述源域样本的样本特征为权重,将对应所述样本集合中每个所述源域样本的第三特征距离进行融合处理,得到对应所述源域样本的图结构特征。
8.根据权利要求1所述的方法,其特征在于,所述基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,包括:
对所述样本集合进行两两遍历处理,得到多个样本对;
获取每个所述样本对对应的两个图结构特征之间的第四特征距离;
将多个所述样本对的第四特征距离进行融合处理,得到第一融合结果;
获取与每个所述样本对的第四特征距离正相关,且与所述第一融合结果负相关的数值作为每个所述样本对的样本间转移概率。
9.根据权利要求1所述的方法,其特征在于,所述基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失,包括:
基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定对应每个类别的期望;
对所述样本集合进行两两遍历处理,得到多个样本对,其中,所述样本对包括第一样本以及第二样本;
针对每个所述样本对执行以下处理:
从所述第一样本的真实标签或者预测标签中提取所述第一样本属于每个类别的第一概率,并从所述第二样本的真实标签或者预测标签中提取所述第二样本属于每个类别的第二概率;
针对所述类别,获取与对应所述类别的第一概率以及第二概率正相关,且与所述类别的期望负相关的类别损失;
对多个所述类别分别对应的类别损失进行融合处理,得到所述样本对的标签损失。
10.根据权利要求9所述的方法,其特征在于,所述基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定对应每个类别的期望,包括:
针对每个所述类别执行以下处理:
从每个所述源域样本的真实标签中提取每个所述源域样本属于所述类别的第三概率,并从每个所述目标域样本的预测标签中提取每个所述目标域样本属于所述类别的第四概率;
对多个所述第三概率以及多个所述第四概率进行融合处理,得到融合概率,并将所述融合概率与所述类别对应的标签值进行相乘,得到对应所述类别的期望。
11.根据权利要求1所述的方法,其特征在于,所述基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,包括:
针对每个所述样本对执行以下处理:对所述样本对的样本间转移概率以及所述样本对的标签损失进行相乘处理,得到所述样本对的子差异损失;
对多个所述样本对的子差异损失进行融合处理,得到第二融合结果;
获取与所述第二融合结果负相关的样本差异损失。
12.根据权利要求1所述的方法,其特征在于,所述基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络,包括:
获取训练三联体损失;
基于源域分类网络获取训练分类器损失;
基于判别网络获取训练判别器损失;
对所述训练三联体损失、所述训练分类器损失、所述训练判别器损失以及所述样本差异损失进行融合处理,得到综合损失;
基于所述综合损失对所述特征处理网络进行更新,得到经过更新的特征处理网络;
所述方法还包括:
基于所述综合损失对所述判别网络、所述源域分类网络进行更新,得到更新后的判别网络以及更新后的源域分类网络。
13.一种基于人工智能的模型训练装置,其特征在于,所述装置包括:
获取模块,用于获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;
特征处理模块,用于通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;
分类处理模块,用于基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;
损失处理模块,用于基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;
更新处理模块,用于基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。
14.一种电子设备,其特征在于,所述电子设备包括:
存储器,用于存储计算机可执行指令;
处理器,用于执行所述存储器中存储的计算机可执行指令时,实现权利要求1至12任一项所述的基于人工智能的模型训练方法。
15.一种计算机可读存储介质,存储有计算机可执行指令,其特征在于,所述计算机可执行指令被处理器执行时实现权利要求1至12任一项所述的基于人工智能的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410344791.3A CN117932314A (zh) | 2024-03-25 | 2024-03-25 | 模型训练方法、装置、电子设备、存储介质及程序产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410344791.3A CN117932314A (zh) | 2024-03-25 | 2024-03-25 | 模型训练方法、装置、电子设备、存储介质及程序产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117932314A true CN117932314A (zh) | 2024-04-26 |
Family
ID=90768811
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410344791.3A Pending CN117932314A (zh) | 2024-03-25 | 2024-03-25 | 模型训练方法、装置、电子设备、存储介质及程序产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117932314A (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111860588A (zh) * | 2020-06-12 | 2020-10-30 | 华为技术有限公司 | 一种用于图神经网络的训练方法以及相关设备 |
CN113822315A (zh) * | 2021-06-17 | 2021-12-21 | 深圳市腾讯计算机系统有限公司 | 属性图的处理方法、装置、电子设备及可读存储介质 |
WO2023273769A1 (zh) * | 2021-07-01 | 2023-01-05 | 北京百度网讯科技有限公司 | 视频标签推荐模型的训练方法和确定视频标签的方法 |
CN116263785A (zh) * | 2022-11-16 | 2023-06-16 | 中移(苏州)软件技术有限公司 | 跨领域文本分类模型的训练方法、分类方法和装置 |
-
2024
- 2024-03-25 CN CN202410344791.3A patent/CN117932314A/zh active Pending
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111860588A (zh) * | 2020-06-12 | 2020-10-30 | 华为技术有限公司 | 一种用于图神经网络的训练方法以及相关设备 |
CN113822315A (zh) * | 2021-06-17 | 2021-12-21 | 深圳市腾讯计算机系统有限公司 | 属性图的处理方法、装置、电子设备及可读存储介质 |
WO2023273769A1 (zh) * | 2021-07-01 | 2023-01-05 | 北京百度网讯科技有限公司 | 视频标签推荐模型的训练方法和确定视频标签的方法 |
CN116263785A (zh) * | 2022-11-16 | 2023-06-16 | 中移(苏州)软件技术有限公司 | 跨领域文本分类模型的训练方法、分类方法和装置 |
Non-Patent Citations (2)
Title |
---|
SHANMING YANG ET AL.: "Orthogonality Loss:learning discriminative representations for face recognition", 《IEEE TRANSACTIONS ON CIRCUITS AND SYSTEMS FOR VIDEO TECHNOLOGY》, vol. 31, no. 6, 24 June 2021 (2021-06-24), pages 2301 - 2314, XP011858261, DOI: 10.1109/TCSVT.2020.3021128 * |
杨善明: "面向长尾和跨域数据的图像识别研究", 《中国优秀硕士学位论文全文数据库 (信息科技辑)》, no. 1, 15 January 2024 (2024-01-15), pages 1 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112699855A (zh) | 基于人工智能的图像场景识别方法、装置及电子设备 | |
CN113344206A (zh) | 融合通道与关系特征学习的知识蒸馏方法、装置及设备 | |
CN114064974B (zh) | 信息处理方法、装置、电子设备、存储介质及程序产品 | |
WO2023168903A1 (zh) | 模型训练和身份匿名化方法、装置、设备、存储介质及程序产品 | |
CN113011320B (zh) | 视频处理方法、装置、电子设备及存储介质 | |
CN113449011A (zh) | 基于大数据预测的信息推送更新方法及大数据预测系统 | |
CN111291695B (zh) | 人员违章行为识别模型训练方法、识别方法及计算机设备 | |
CN114357319A (zh) | 网络请求处理方法、装置、设备、存储介质及程序产品 | |
CN117217368A (zh) | 预测模型的训练方法、装置、设备、介质及程序产品 | |
Sharjeel et al. | Real time drone detection by moving camera using COROLA and CNN algorithm | |
CN114334040A (zh) | 分子图重构模型的训练方法、装置以及电子设备 | |
CN114677611B (zh) | 数据识别方法、存储介质及设备 | |
CN114818707A (zh) | 一种基于知识图谱的自动驾驶决策方法和系统 | |
CN113705293A (zh) | 图像场景的识别方法、装置、设备及可读存储介质 | |
CN116975347A (zh) | 图像生成模型训练方法及相关装置 | |
CN116824572A (zh) | 基于全局和部件匹配的小样本点云物体识别方法、系统及介质 | |
CN116975743A (zh) | 行业信息分类方法、装置、计算机设备和存储介质 | |
CN117932314A (zh) | 模型训练方法、装置、电子设备、存储介质及程序产品 | |
Zhang et al. | MTSCANet: Multi temporal resolution temporal semantic context aggregation network | |
CN115129885A (zh) | 实体链指方法、装置、设备及存储介质 | |
CN114898184A (zh) | 模型训练方法、数据处理方法、装置及电子设备 | |
CN115131600A (zh) | 检测模型训练方法、检测方法、装置、设备及存储介质 | |
CN113487374A (zh) | 一种基于5g网络的区块电商平台交易系统 | |
CN112084331B (zh) | 文本处理、模型训练方法、装置、计算机设备和存储介质 | |
CN116521761B (zh) | 基于人工智能的传感器运行行为挖掘方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |