CN117609887A - 数据增强模型训练及数据处理方法、装置、设备、介质 - Google Patents
数据增强模型训练及数据处理方法、装置、设备、介质 Download PDFInfo
- Publication number
- CN117609887A CN117609887A CN202410078708.2A CN202410078708A CN117609887A CN 117609887 A CN117609887 A CN 117609887A CN 202410078708 A CN202410078708 A CN 202410078708A CN 117609887 A CN117609887 A CN 117609887A
- Authority
- CN
- China
- Prior art keywords
- data
- model
- training
- sample
- target domain
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 574
- 238000003672 processing method Methods 0.000 title claims abstract description 20
- 238000000034 method Methods 0.000 claims abstract description 89
- 238000004873 anchoring Methods 0.000 claims abstract description 41
- 238000012216 screening Methods 0.000 claims abstract description 19
- 230000006870 function Effects 0.000 claims description 195
- 238000012545 processing Methods 0.000 claims description 73
- 230000015654 memory Effects 0.000 claims description 28
- 238000005070 sampling Methods 0.000 claims description 18
- 238000013507 mapping Methods 0.000 claims description 14
- 238000004364 calculation method Methods 0.000 claims description 8
- 230000003190 augmentative effect Effects 0.000 claims description 6
- 230000008569 process Effects 0.000 description 18
- 238000010586 diagram Methods 0.000 description 11
- 230000000694 effects Effects 0.000 description 11
- 238000004590 computer program Methods 0.000 description 7
- 230000001360 synchronised effect Effects 0.000 description 7
- 230000003993 interaction Effects 0.000 description 6
- 238000013508 migration Methods 0.000 description 6
- 230000005012 migration Effects 0.000 description 6
- 238000010606 normalization Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 5
- 238000012163 sequencing technique Methods 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 230000000295 complement effect Effects 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 4
- 238000013526 transfer learning Methods 0.000 description 4
- 230000009286 beneficial effect Effects 0.000 description 3
- 230000009471 action Effects 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 2
- 230000006399 behavior Effects 0.000 description 2
- 238000012790 confirmation Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 125000004122 cyclic group Chemical group 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000013515 script Methods 0.000 description 2
- 238000009825 accumulation Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013434 data augmentation Methods 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000012217 deletion Methods 0.000 description 1
- 230000037430 deletion Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000014509 gene expression Effects 0.000 description 1
- 238000003780 insertion Methods 0.000 description 1
- 230000037431 insertion Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000001960 triggered effect Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/211—Selection of the most significant subset of features
- G06F18/2113—Selection of the most significant subset of features by ranking or filtering the set of features, e.g. using a measure of variance or of feature cross-correlation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Probability & Statistics with Applications (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供了一种数据增强模型训练及数据处理方法、装置、设备、介质;方法包括:基于源域训练样本数据,确定预训练模型的模型参数,并基于模型参数确定数据增强模型;针对每一目标域训练样本数据,在采用目标域训练样本数据对数据增强模型进行模型训练,得到更新后的数据增强模型之后,基于锚定数据,分别对预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;根据更新损失函数值与基础损失函数值,从目标域训练样本集中筛选出至少一个扩充样本数据;基于源域样本数据和至少一个扩充样本数据,对数据增强模型进行迭代训练,得到训练后的数据增强模型。通过本申请,能够提升数据增强模型的模型性能。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及一种数据增强模型训练及数据处理方法、装置、设备、介质。
背景技术
数据增强(Data Augmentation)是一种通过先验知识产生跟目标任务相似的更多数据来扩展训练数据集的方法。当数据增强应用于推荐系统时,数据增强可以在不显著提高系统查询、存储成本的情况下,提供对用户、物料和两者之间的交互等方面更为完整的样本分布的刻画。而常见的数据增强方法包括负采样、数据扰动、数据插值、迁移学习,通常有助于提升网络模型的泛化能力和准确性。由于训练数据集中数据的数量和质量直接影响到网络模型的效果上限,因此,研究更有效的数据增强模型训练方法对于提高网络模型的性能和准确性至关重要。
发明内容
本申请实施例提供一种数据增强模型训练及数据处理方法、装置、设备、介质,能够提升数据增强模型的模型性能。
本申请实施例的技术方案是这样实现的:
本申请实施例提供一种数据增强模型训练方法,所述方法包括:获取源域样本数据集和目标域训练样本集;所述源域样本数据集包括源域训练样本集和锚定数据集;基于所述源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于所述模型参数确定数据增强模型;针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;根据所述更新损失函数值与所述基础损失函数值,从所述目标域训练样本集中筛选出至少一个扩充样本数据;基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型。
本申请实施例提供一种数据处理方法,所述方法包括:获取目标业务下的待处理数据集;所述目标业务包括内容订阅业务或者内容推荐业务;将所述待处理数据集输入到训练后的数据增强模型中,通过所述训练后的数据增强模型在所述目标业务下对所述待处理数据集进行数据处理,得到所述目标业务下的数据处理结果;其中,所述训练后的数据增强模型采用本申请实施例所提供的数据增强模型训练方法训练得到。
本申请实施例提供一种数据增强模型训练装置,包括:获取模块,用于获取源域样本数据集和目标域训练样本集;所述源域样本数据集包括源域训练样本集和锚定数据集;模型预训练模块,用于基于所述源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于所述模型参数确定数据增强模型;模型训练模块,用于针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;筛选模块,用于根据所述更新损失函数与所述基础损失函数,从所述目标域训练样本集中筛选出至少一个扩充样本数据;迭代训练模块,用于基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型。
在一些实施例中,所述模型预训练模块还用于:获取与所述基础模型具有相同模型结构的网络模型;将所述模型参数同步到所述网络模型中,得到所述数据增强模型。
在一些实施例中,所述目标域训练样本集中包括N个目标域训练样本数据;N为大于1的整数;所述模型训练模块还用于:针对所述N个目标域训练样本数据中的第i个目标域训练样本数据,采用所述第i个目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述第i个目标域训练样本数据对应的第i个更新后的数据增强模型;基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述第i个更新后的数据增强模型进行模型训练,对应得到第i个基础损失函数值与第i个更新损失函数值;i为大于0且小于或等于N的任意整数。
在一些实施例中,所述筛选模块还用于:针对所述目标域训练样本集中的每一目标域训练样本数据,获取所述更新损失函数值与所述基础损失函数值之间的差值;将小于0的差值对应的目标域训练样本数据,确定为所述扩充样本数据。
在一些实施例中,所述迭代训练模块还用于:将所述源域样本数据和所述扩充样本数据输入至所述数据增强模型中;通过所述数据增强模型,对所述源域样本数据和所述扩充样本数据分别进行数据处理,对应得到源域样本预估概率和目标域样本预估概率;基于所述源域样本预估概率和所述源域样本数据的真实标签,构建源域样本损失函数;基于所述目标域样本预估概率、所述扩充样本数据的真实标签和所述扩充样本数据对应的所述差值,构建目标域样本损失函数;对所述源域样本损失函数和所述目标域样本损失函数分别进行损失计算,对应得到源域样本损失值和目标域样本损失值;根据所述源域样本损失值和所述目标域样本损失值,确定所述数据增强模型的总损失值;基于所述总损失值,按照预设的迭代次数对所述数据增强模型中的模型参数进行迭代更新,得到所述训练后的数据增强模型。
在一些实施例中,所述迭代训练模块还用于:通过所述数据增强模型的嵌入层,对所述源域样本数据和所述扩充样本数据分别进行特征提取,对应得到源域特征向量和目标域特征向量;通过所述数据增强模型的特征映射模块,对所述源域特征向量和所述目标域特征向量分别进行特征映射,对应得到所述源域样本预估概率和所述目标域样本预估概率。
在一些实施例中,所述迭代训练模块还用于:对所述扩充样本数据对应的差值进行取反操作,得到所述差值的取反结果;对所述取反结果进行数据标准化处理,得到标准化差值;获取所述目标域样本预估概率与所述扩充样本数据的真实标签构成的交叉熵损失函数;采用所述标准化差值对所述交叉熵损失函数进行加权处理,得到所述目标域损失函数。
在一些实施例中,所述目标域训练样本集是从预设的目标域样本数据库中采样得到的数据集,所述装置还包括再次训练模块,所述再次训练模块用于:当检测到所述源域样本数据集中具有新增源域样本数据或者所述目标域样本数据库中具有新增目标域样本数据时,将所述训练后的数据增强模型的模型参数同步到所述基础模型中,得到当前时刻的基础模型,并获取包括所述新增源域样本数据的新的源域样本数据集和包括所述新增目标域样本数据的新的目标域样本数据库;将所述新的源域样本数据集确定为当前时刻的源域样本数据集;将从所述新的目标域样本数据库中采样得到的数据集,确定为当前时刻的目标域训练样本集;基于所述当前时刻的源域样本数据集和所述当前时刻的目标域训练样本集再次对所述基础模型执行所述数据增强模型训练方法。
在一些实施例中,所述再次训练模块还用于:基于所述当前时刻的源域训练样本集中的源域训练样本数据,对所述当前时刻的基础模型进行模型预训练,得到当前时刻的预训练模型的模型参数,并基于所述模型参数确定当前时刻的数据增强模型;针对所述当前时刻的目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的当前时刻的更新后的数据增强模型之后,基于所述当前时刻的锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到当前时刻的基础损失函数值与当前时刻的更新损失函数值;根据所述当前时刻的更新损失函数值与所述当前时刻的基础损失函数值,从所述当前时刻的目标域训练样本集中筛选出至少一个当前时刻的扩充样本数据;基于所述当前时刻的源域样本数据集中的源域样本数据和所述至少一个当前时刻的扩充样本数据,对所述当前时刻的数据增强模型进行迭代训练,得到当前时刻的训练后的数据增强模型。
在一些实施例中,所述装置还包括锚定数据确定模块,所述锚定数据确定模块用于:从所述源域训练样本集中提取源域训练样本集;根据预设采样数量,从提取所述源域训练样本集后剩余的源域训练样本集中进行锚定数据采样处理,得到所述锚定数据集。
在一些实施例中,所述源域样本数据集包括应用于内容订阅业务的订阅内容数据集,所述目标域训练样本集包括应用于内容推荐业务的推荐内容数据集;所述装置还包括数据处理模块,所述数据处理模块用于:获取所述内容订阅业务下的待处理数据集;将所述待处理数据集输入到所述训练后的数据增强模型中,通过所述训练后的数据增强模型在所述内容订阅业务下对所述待处理数据集进行数据处理,得到所述内容订阅业务下的数据处理结果。
本申请实施例提供一种数据处理装置,包括:数据集获取模块,用于获取目标业务下的待处理数据集;目标业务包括内容订阅业务或者内容推荐业务;数据处理结果确定模块,用于将待处理数据集输入到训练后的数据增强模型中,通过所述训练后的数据增强模型在所述目标业务下对所述待处理数据集进行数据处理,得到所述目标业务下的数据处理结果;其中,所述训练后的数据增强模型采用本申请实施例所提供的数据增强模型训练方法训练得到。
本申请实施例提供一种电子设备,包括:存储器,用于存储计算机可执行指令;处理器,用于执行所述存储器中存储的计算机可执行指令时,实现本申请实施例提供的数据增强模型训练方法,或者,实现本申请实施例提供的数据处理方法。
本申请实施例提供一种计算机可读存储介质,存储有计算机可执行指令,用于被处理器执行时实现本申请实施例提供的数据增强模型训练方法,或者,实现本申请实施例提供的数据处理方法。
本申请实施例提供一种计算机程序产品,该计算机程序产品包括可执行指令,可执行指令存储在计算机可读存储介质中;其中,电子设备的处理器从计算机可读存储介质中读取可执行指令,并执行可执行指令时,实现本申请实施例提供的数据增强模型训练方法,或者,实现本申请实施例提供的数据处理方法。
本申请实施例具有以下有益效果:利用源域训练样本数据对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并确定数据增强模型,再利用每一目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型,接着,将锚定数据分别输入到预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值,并基于更新损失函数值与基础损失函数值之间的差值,从目标域训练样本集中筛选出至少一个扩充样本数据,最后,将源域样本数据和扩充样本数据输入到数据增强模型进行迭代训练,得到训练后的数据增强模型。如此,从目标域训练样本集中筛选出的扩充样本数据是对数据增强模型有正向影响的增强数据,使得数据增强模型的模型输入在源域的源域样本数据的基础上扩充了来自目标域的增强数据,从而提高了数据增强模型对模型输入的拟合能力,提升了数据增强模型的模型性能。
附图说明
图1是本申请实施例提供的数据增强模型训练系统架构的结构示意图;
图2是本申请实施例提供的数据增强模型训练装置的结构示意图;
图3是本申请实施例提供的数据增强模型训练方法的一个可选的流程示意图;
图4是本申请实施例提供的数据增强模型训练方法的另一个可选的流程示意图;
图5是本申请实施例提供的构建目标域样本损失函数的过程示意图;
图6是本申请实施例提供的生成订阅内容与推荐内容的过程示意图;
图7是本申请实施例提供的订阅场景与推荐场景之间实现跨场景数据迁移的过程示意图;
图8是本申请实施例提供的订阅号和订阅号消息的关系示意图;
图9是本申请实施例提供的引入跨域数据样本训练排序模型的流程框架示意图;
图10是本申请实施例提供的在线服务的推理环节的应用示意图。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地详细描述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解;“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
本申请实施例中,术语“模块”或“单元”是指有预定功能的计算机程序或计算机程序的一部分,并与其他相关部分一起工作以实现预定目标,并且可以通过使用软件、硬件(如处理电路或存储器)或其组合来全部或部分实现。同样的,一个处理器(或多个处理器或存储器)可以用来实现一个或多个模块或单元。此外,每个模块或单元都可以是包含该模块或单元功能的整体模块或单元的一部分。
如果申请文件中出现“第一/第二”的类似描述则增加以下的说明,在以下的描述中,所涉及的术语“第一\第二\第三”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二\第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本申请实施例所使用的所有的技术和科学术语与所属技术领域的技术人员通常理解的含义相同。本申请实施例中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
对本申请实施例进行进一步详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
1)标签(label):训练深度神经网络时依赖的标注数据,如“0/1”表示的“属于/不属于”某一类,或者,是否产生点击。
2)嵌入(Embedding):由多个浮点数组成的数值型向量,描述了内容或用户在高维空间中的各种属性、性质。嵌入层是深度学习中常用的使用在模型第一层的一个网络层,嵌入层的作用是将离散的输入特征(如单词、字符等)转化为密集的向量表示,使得这些特征能够被神经网络更好地处理。
3)负采样:当数据增强应用到推荐系统时,常见的数据增强方法包括:负采样。在推荐系统中,负样本通常远多于正样本,负采样是一种通过从曝光数据的负样本中随机抽取一部分样本,使得正负样本比例更加平衡的方法,这样可以提高模型的训练效率,同时避免过拟合。
4)数据扰动:通过对原始数据添加噪声,如随机删除、替换、插入等,可以创造新的数据样本,这种方法可以提高模型的泛化能力,避免模型对原始数据过拟合。
5)数据插值:在推荐系统中,用户对内容的评分数据通常是稀疏的。通过数据插值方法,如基于邻居的协同过滤、矩阵分解等,可以预测用户对未评分内容的评分,从而扩充数据集。
6)迁移学习:在推荐系统中,可以利用其他场景的数据进行预训练,然后将训练好的模型应用到目标领域,这种方法可以充分利用其他领域的知识,提高模型在数据稀疏的目标域的效果。在业内,推荐系统里的迁移学习主要聚焦于稀疏数据场景和冷启动场景的应用。例如,将一个成熟场景(源域)上用户的统计数据特征,用于一个全新的内容场景(目标域)的用户数据特征描述。还有一些方法考虑在模型结构和表示学习的角度处理迁移学习问题。
在相关技术中提出了一种基于嵌入和映射的跨领域推荐方法,首先在源域和目标域学习用户和项目的嵌入表示,然后通过映射函数将源域的嵌入表示映射到目标域。此外,还提出了一种基于注意力机制的知识迁移方法,用于跨领域推荐,利用注意力机制从源域选择与目标域相关的知识,并将其迁移到目标域。
基于上述对相关技术的分析可以看出,数据增强有利于提升训练数据集的数量和质量,严重影响网络模型的泛化能力和准确性。因此,本申请实施例提出了一种从样本选择角度出发的跨域迁移学习范式,从源域的大量样本数据中过滤、筛选出符合目标域的样本数据,在模型训练或者模型参数更新之前,避免引入负向作用的不匹配样本。同时,本申请实施例对于两个不同场景的样本数据互相增强具有可行性,即源域的样本数据也可以通过相同的跨域迁移学习范式,从目标域的样本数据中扩充数据来进行模型训练,从而提升训练数据集的数量和质量。
下面说明本申请实施例提供的数据增强模型训练设备(即电子设备)的示例性应用,本申请实施例提供的数据增强模型训练设备可以实施为笔记本电脑、平板电脑,台式计算机、机顶盒、移动设备(例如,移动电话,便携式音乐播放器,个人数字助理,专用消息设备,便携式游戏设备)、智能手机、智能音箱、智能手表、智能电视、车载终端等各种类型的能够进行数据增强模型训练的用户终端,也可以实施为服务器。下面,将说明数据增强模型训练设备实施为服务器时示例性应用。
参见图1,图1是本申请实施例提供的数据增强模型训练系统100架构的结构示意图,为实现支撑一个数据增强模型训练应用,终端400通过网络300连接服务器200,网络300可以是广域网或者局域网,又或者是二者的组合。
终端400用于向服务器200发送数据增强模型训练请求,服务器200构成本申请实施例的数据增强模型训练设备,服务器200用于响应数据增强模型训练请求,获取源域样本数据集和目标域训练样本集;源域样本数据集包括源域训练样本集和锚定数据集;基于源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于模型参数确定数据增强模型;针对目标域训练样本集中的每一目标域训练样本数据,在采用目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型之后,基于锚定数据集中的锚定数据,分别对预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;根据更新损失函数值与基础损失函数值,从目标域训练样本集中筛选出至少一个扩充样本数据;基于源域样本数据集中的源域样本数据和至少一个扩充样本数据,对数据增强模型进行迭代训练,得到训练后的数据增强模型,并将训练后的数据增强模型返回给终端400,以实现在终端400输出该数据增强模型或者在终端400基于该数据增强模型继续进行下一步业务处理,继续进行数据增强模型的训练或者得到该数据增强模型的数据处理结果。
在一些实施例中,服务器200可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。终端400可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表、车载终端等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请实施例中不做限制。
参见图2,图2是本申请实施例提供的电子设备40的结构示意图,图2所示的电子设备40可以是数据增强模型训练设备,数据增强模型训练设备包括:至少一个处理器410、存储器450、至少一个网络接口420和用户接口430。数据增强模型训练设备中的各个组件通过总线系统440耦合在一起。可理解,总线系统440用于实现这些组件之间的连接通信。总线系统440除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图2中将各种总线都标为总线系统440。
处理器410可以是一种集成电路芯片,具有信号的处理能力,例如通用处理器、数字信号处理器(Digital Signal Processor,DSP),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
用户接口430包括使得能够呈现媒体内容的一个或多个输出装置431,包括一个或多个扬声器和/或一个或多个视觉显示屏。用户接口430还包括一个或多个输入装置432,包括有助于用户输入的用户接口部件,比如键盘、鼠标、麦克风、触屏显示屏、摄像头、其他输入按钮和控件。
存储器450可以是可移除的,不可移除的或其组合。示例性的硬件设备包括固态存储器,硬盘驱动器,光盘驱动器等。存储器450可选地包括在物理位置上远离处理器410的一个或多个存储设备。
存储器450包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(Read Only Memory,ROM),易失性存储器可以是随机存取存储器(Random Access Memory,RAM)。本申请实施例描述的存储器450旨在包括任意适合类型的存储器。
在一些实施例中,存储器450能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统451,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务;
网络通信模块452,用于经由一个或多个(有线或无线)网络接口420到达其他电子设备,示例性的网络接口420包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(Universal Serial Bus,USB)等;呈现模块453,用于经由一个或多个与用户接口430相关联的输出装置431(例如,显示屏、扬声器等)使得能够呈现信息(例如,用于操作外围设备和显示内容和信息的用户接口);输入处理模块454,用于对一个或多个来自一个或多个输入装置432之一的一个或多个用户输入或互动进行检测以及翻译所检测的输入或互动。
在一些实施例中,本申请实施例提供的装置可以采用软件方式实现,图2示出了存储在存储器450中的数据增强模型训练装置455,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块4551、模型预训练模块4552、模型训练模块4553、筛选模块4554和迭代训练模块4555,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
在另一些实施例中,图2示出的存储器450中还可以包括数据处理装置,数据处理装置也可以是程序和插件等形式的软件,包括以下软件模块:数据集获取模块和数据处理结果确定模块,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
在再一些实施例中,本申请实施例提供的装置可以采用硬件方式实现,作为示例,本申请实施例提供的装置可以是采用硬件译码处理器形式的处理器,其被编程以执行本申请实施例提供的数据增强模型训练方法或者数据处理方法,例如,硬件译码处理器形式的处理器可以采用一个或多个应用专用集成电路(Application Specific IntegratedCircuit,ASIC)、数字信号处理器(Digital Signal Processor,DSP)、可编程逻辑器件(Programmable Logic Device,PLD)、复杂可编程逻辑器件(Complex Programmable LogicDevice,CPLD)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或其他电子元件。
在一些实施例中,终端或服务器可以通过运行各种计算机可执行指令或计算机程序来实现本申请实施例提供的数据增强模型训练方法。举例来说,计算机可执行指令可以是微程序级的命令、机器指令或软件指令。计算机程序可以是操作系统中的原生程序或软件模块;可以是本地(Native)应用程序(Application,APP),即需要在操作系统中安装才能运行的程序,也可以是可以嵌入至任意APP中的小程序,即只需要下载到浏览器环境中就可以运行的程序。总而言之,上述的计算机可执行指令可以是任意形式的指令,上述计算机程序可以是任意形式的应用程序、模块或插件。
本申请各实施例提供的数据增强模型训练方法可以由电子设备来执行,其中,该电子设备可以是服务器也可以是终端,即本申请各实施例的数据增强模型训练方法可以通过服务器来执行,也可以通过终端来执行,或者也可以通过服务器与终端之间交互执行。
参见图3,图3是本申请实施例提供的数据增强模型训练方法的一个可选的流程示意图,将结合图3示出的步骤进行说明,以数据增强模型训练方法的执行主体为服务器为例进行说明,方法包括以下步骤S101至步骤S105:
步骤S101,获取源域样本数据集和目标域训练样本集。
本申请实施例中,源域样本数据集与目标域训练样本集分别是包含源域样本数据与目标域训练样本数据的两个不同的数据集。源域是指模型进行预训练的领域,目标域是指模型将要应用到的新领域,即源域与目标域分别对应两个不同的领域或者应用场景。在目标域中,模型通常可以通过迁移学习来利用源域中学到的知识和特征,以便在目标任务上获得良好的性能表现。例如,对于订阅内容与推荐内容这两个不同的应用场景,将订阅内容的相关数据应用于推荐内容排序模型进行推荐内容排序,则订阅内容的相关数据属于目标域训练样本数据,而推荐内容的相关数据属于源域样本数据;将推荐内容的相关数据应用于订阅内容排序模型进行订阅内容排序,则推荐内容的相关数据属于目标域训练样本数据,而订阅内容的相关数据属于源域样本数据。
源域样本数据集和目标域训练样本集均可以是一个固定的样本数据集合,也可以是一个实时的样本数据流。源域样本数据集包括源域训练样本集和锚定数据集,源域训练样本集用于对本申请实施例的基础模型进行模型预训练,得到预训练模型;锚定数据集为源域样本数据集中对除了源域训练样本集之外的数据进行采样后得到的数据集,用于分别对预训练模型和更新后的数据增强模型进行模型训练。
这里,通过获取源域样本数据集和目标域训练样本集,便于后续将源域样本数据集和目标域训练样本集中的数据作为模型的输入进行模型预训练或者模型训练。
步骤S102,基于源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于模型参数确定数据增强模型。
本申请实施例中,源域训练样本数据包含于源域训练样本集,预设的基础模型可以是预先设置的任意一种网络模型,数据增强模型为与预设的基础模型或者与预训练模型具有相同模型结构,且同步预训练模型的模型参数的网络模型,因此,对应的数据增强模型也可以是任意一种网络模型,例如,卷积神经网络、循环神经网络等。
将源域训练样本集中的源域训练样本数据作为预设的基础模型的模型输入,并构建源域训练样本数据对应的损失函数。在对预设的基础模型的模型预训练过程中,基于对损失函数进行损失计算后的损失值,对预设的基础模型的模型参数进行参数更新。当损失值趋于收敛或达到预设的迭代次数时,停止对预设的基础模型的模型预训练,得到模型预训练后的预训练模型。然后,对预训练模型的模型参数进行保存,并将保存的预训练模型的模型参数同步到与预设的基础模型具有相同模型结构的网络模型中,并将该网络模型确定为数据增强模型。
这里,将预训练模型的模型参数同步到数据增强模型,便于后续利用目标域训练样本集中的每一目标域训练样本数据对数据增强模型进行模型训练,得到受到目标域训练样本数据影响而更新后的数据增强模型。
步骤S103,针对目标域训练样本集中的每一目标域训练样本数据,在采用目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型之后,基于锚定数据集中的锚定数据,分别对预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值。
本申请实施例中,利用目标域训练样本集中的每一目标域训练样本数据对数据增强模型进行模型训练,模型训练次数可以为多次,且与目标域训练样本集中的目标域训练样本数据的数据量相同,并在任一目标域训练样本数据对数据增强模型进行一轮模型训练之后,在进行下一轮数据增强模型的模型训练之前,对更新后的数据增强模型进行模型初始化,得到初始的数据增强模型,再利用新的目标域训练样本数据对该数据增强模型进行新一轮的模型训练,以此类推,直到得到目标域训练样本集中的每一目标域训练样本数据对应的更新后的数据增强模型。
将锚定数据集中的锚定数据输入到预训练模型进行模型训练,得到预训练模型的基础损失函数值,再将锚定数据集中的锚定数据输入到每一轮更新后的数据增强模型进行模型训练,得到每一轮更新后的数据增强模型的更新损失函数值。
这里,通过得到的预训练模型的基础损失函数值和每一轮更新后的数据增强模型的更新损失函数值,以便后续对基础损失函数值和更新损失函数值进行数值比较,从目标域训练样本集中筛选出对数据增强模型的模型性能具有提升效果的扩充样本数据。
步骤S104,根据更新损失函数值与基础损失函数值,从目标域训练样本集中筛选出至少一个扩充样本数据。
本申请实施例中,扩充样本数据为目标域训练样本集中对数据增强模型的模型性能具有提升效果的样本数据。将更新损失函数值与基础损失函数值进行数值比较,若更新损失函数值小于基础损失函数值,则表明目标域训练样本数据对对应的更新后的数据增强模型的模型性能有提升效果,可作为扩充样本数据对数据增强模型进行后续的迭代训练;若更新损失函数值大于或等于基础损失函数值,则表明目标域训练样本数据对对应的更新后的数据增强模型的模型性能没有提升效果,无法作为扩充样本数据对数据增强模型进行后续的迭代训练。
这里,通过比较更新损失函数值与基础损失,能够从目标域训练样本集中筛选出对数据增强模型的模型性能具有提升效果的扩充样本数据,提升了数据增强模型的模型输入的数量和质量。
步骤S105,基于源域样本数据集中的源域样本数据和至少一个扩充样本数据,对数据增强模型进行迭代训练,得到训练后的数据增强模型。
本申请实施例中,源域样本数据包括源域样本数据集中的所有数据,将源域样本数据集中的源域样本数据和至少一个扩充样本数据分别作为数据增强模型的模型输入,并分别构建源域样本数据对应的源域损失函数和扩充样本数据对应的目标域损失函数,在对数据增强模型的迭代训练过程中,基于损失计算后的源域损失值和目标域损失值对数据增强模型的模型参数进行参数更新。当源域损失值和目标域损失值趋于收敛或者达到预设的迭代次数时,停止对数据增强模型的迭代训练,得到训练后的数据增强模型。
本申请实施例提供的数据增强模型训练方法,通过利用源域训练样本数据对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并确定数据增强模型,再利用每一目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型,接着,将锚定数据分别输入到预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值,并基于更新损失函数值与基础损失函数值之间的差值,从目标域训练样本集中筛选出至少一个扩充样本数据,最后,将源域样本数据和扩充样本数据输入到数据增强模型进行迭代训练,得到训练后的数据增强模型。如此,从目标域训练样本集中筛选出的扩充样本数据是对数据增强模型有正向影响的增强数据,使得数据增强模型的模型输入在源域的源域样本数据的基础上扩充了来自目标域的增强数据,从而提高了数据增强模型对模型输入的拟合能力,提升了数据增强模型的模型性能。
下面将结合数据增强模型训练系统中的终端和服务器之间的交互,对本申请实施例中的数据增强模型训练方法进行说明。需要说明的是,这里的数据增强模型训练方法是通过终端与服务器进行交互实现的数据增强模型训练方法,与上述实施例中由服务器执行的数据增强模型训练方法实质上相同,所不同的仅在于本申请实施例中还描述了终端在数据增强模型训练方法的执行过程中所执行的动作,并且,有些步骤既可以由终端来执行也可以由服务器来执行,因此,对于本实施例中与上述实施例中内容相同但执行主体不同的步骤,本实施例只是示例性说明,在实现的过程中,可以由任意一个执行主体来执行,本申请实施例对此不做限定。
图4是本申请实施例提供的数据增强模型训练方法的另一个可选的流程示意图,如图4所示,方法包括以下步骤S201至步骤S219:
步骤S201,终端接收用户输入的数据增强模型训练操作。
本申请实施例中,终端上可以运行有数据增强模型训练应用,用户可以在数据增强模型训练应用的客户端输入数据增强模型训练操作,在数据增强模型训练应用中,可以提供数据增强模型训练功能,用户可以在该数据增强模型训练功能页面输入数据增强模型训练操作,以触发数据增强模型训练的请求。
在一些实施例中,用户在输入数据增强模型训练操作时,还可以同时输入源域样本数据集和目标域训练样本集,在终端接收到源域样本数据集和目标域训练样本集时,会在数据增强模型训练功能页面弹出确认数据增强模型训练窗口,在终端检测到用户点击确认数据增强模型训练按钮后,再对源域样本数据集和目标域训练样本集进行进一步处理来实现数据增强模型的模型训练。或者,在另一些实施例中,用户可以直接在数据增强模型训练功能页面输入源域样本数据集和目标域训练样本集,终端接收到源域样本数据集和目标域训练样本集就可直接触发数据增强模型训练功能,对源域样本数据集和目标域训练样本集进行进一步处理来实现数据增强模型的模型训练。
步骤S202,终端响应于数据增强模型训练操作生成数据增强模型训练请求。
本申请实施例中,可以将用户输入的数据封装至数据增强模型训练请求中。例如,在数据增强模型训练应用的显示界面,显示目标域样本数据库中的所有目标域样本数据,用户可以根据实际需求进行数据选取或者数据采样,得到目标域训练样本集,然后可以将用户输入的源域样本数据集和目标域训练样本集封装至数据增强模型训练请求中,或者,将用户输入的预设的基础模型封装至数据增强模型训练请求中。
步骤S203,终端将数据增强模型训练请求发送给服务器。
步骤S204,服务器响应于数据增强模型训练请求,获取源域样本数据集和目标域训练样本集。
本申请实施例中,如果数据增强模型训练请求中封装有源域样本数据集和目标域训练样本集,则可以直接解析得到源域样本数据集和目标域训练样本集,即获取源域样本数据集和目标域训练样本集。
本申请实施例中,源域样本数据集包括源域训练样本集和锚定数据集。锚定数据集的确定过程为:从源域训练样本集中提取源域训练样本集;根据预设采样数量,从提取源域训练样本集后剩余的源域训练样本集中进行锚定数据采样处理,得到锚定数据集。也就是说,锚定数据集为采样数据,属于剩余的源域训练样本集中的一部分。预设采样数量可以根据剩余的源域训练样本集的数据量大小来确定,也可以根据后续的数据增强模型的训练效果来确定,在此不做限定。
步骤S205,服务器基于源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数。
本申请实施例中,模型预训练是指利用源域训练样本集中的源域训练样本数据预先训练预设的基础模型的过程。首先,采用一个已搭建的网络模型作为预设的基础模型,并对预设的基础模型的模型参数进行随机初始化;然后,在开始训练该基础模型时,通过对模型参数进行梯度更新,不断调整该基础模型的模型参数,直到该基础模型的损失值越来越小,最终损失值趋于收敛状态,停止模型预训练过程,得到预设的基础模型的预训练模型;最后,将预训练模型的模型参数保存下来,以便预训练好的预训练模型可以在下次执行类似任务时获得较好的结果。
步骤S206,服务器获取与基础模型具有相同模型结构的网络模型。
本申请实施例中,模型结构包括网络模型的网络层、网络层数以及网络层功能,具有相同模型结构是指新的网络模型与基础模型的网络层、网络层数以及网络层功能均相同。通过模型结构对比,获取预基础模型具有相同模型结构的网络模型,以便用于后续的模型参数同步。
步骤S207,服务器将模型参数同步到网络模型中,得到数据增强模型。
本申请实施例中,在得到预训练模型的模型参数和与基础模型具有相同模型结构的网络模型之后,将模型参数同步到网络模型,即直接使用之前保存下来的预训练模型的模型参数作为数据增强模型的初始化参数,使得在后续数据增强模型的模型训练过程中,从初始化参数开始,依据数据增强模型的训练结果不断对模型参数进行参数调整,并使得模型参数适应于新的数据集。
步骤S208,服务器针对目标域训练样本集中的每一目标域训练样本数据,在采用目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型之后,基于锚定数据集中的锚定数据,分别对预训练模型和更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值。
本申请实施例中,目标域训练样本集中的目标域训练样本数据有多个,利用数据增强模型逐个地遍历目标域训练样本集中的每一目标域训练样本数据,即采用每一目标域训练样本数据对数据增强模型进行模型训练。在每次模型训练完成之后,都会得到一个更新后的数据增强模型,而每一个更新后的数据增强模型对应一个目标域训练样本数据。在每得到一个更新后的数据增强模型之后,将锚定数据集中的锚定数据作为模型输入,输入到预训练模型进行模型训练,得到预训练模型的基础损失函数值,并且,还将锚定数据集中的锚定数据作为模型输入,输入到更新后的数据增强模型进行模型训练,得到更新后的数据增强模型的更新损失函数值。因此,预训练模型的基础损失函数值与更新后的数据增强模型的更新损失函数值的个数和目标域训练样本集中的目标域训练样本数据的数据量相同,且预训练模型的基础损失函数值与更新后的数据增强模型的更新损失函数值的个数属于一一对应关系。
在一些实施例中,目标域训练样本集中包括N个目标域训练样本数据;N为大于1的整数。针对N个目标域训练样本数据中的第i个目标域训练样本数据,采用第i个目标域训练样本数据对数据增强模型进行模型训练,得到与第i个目标域训练样本数据对应的第i个更新后的数据增强模型;基于锚定数据集中的锚定数据,分别对预训练模型和第i个更新后的数据增强模型进行模型训练,对应得到第i个基础损失函数值与第i个更新损失函数值;i为大于0且小于或等于N的任意整数。也就是说,目标域训练样本集中的N个目标域训练样本数据会对应得到N个基础损失函数值与N个更新损失函数值。
这里,通过模型训练得到基础损失函数值与更新损失函数值,以便后续对基础损失函数值与更新损失函数值进行数值大小比较,确定目标域训练样本集中的每一目标域训练样本数据对数据增强模型的影响是正向影响还是负向影响,从而考虑是否采用目标域训练样本数据扩充源域样本数据集来实现数据增强。
步骤S209,服务器针对目标域训练样本集中的每一目标域训练样本数据,获取更新损失函数值与基础损失函数值之间的差值。
本申请实施例中,针对目标域训练样本集中的每一目标域训练样本数据,对每一目标域训练样本数据对应的更新损失函数值与基础损失函数值进行减法计算,来比较每一目标域训练样本数据对应的更新损失函数值与基础损失函数值的数值大小,得到两者之间的差值。利用差值来确定目标域训练样本集中的每一目标域训练样本数据对数据增强模型的影响是正向影响还是负向影响,从而考虑是否采用目标域训练样本数据扩充源域样本数据集来实现数据增强。
步骤S210,将小于0的差值对应的目标域训练样本数据,确定为扩充样本数据。
本申请实施例中,基于上述通过对更新损失函数值与基础损失函数值进行相减得到的更新损失函数值与基础损失函数值之间的差值,对目标域训练样本数据集中的目标域训练样本数据进行筛选。若差值小于0,则保留该差值对应的目标域训练样本数据,并将该目标域训练样本数据确定为扩充样本数据;若差值大于或等于0,则直接排除该差值对应的目标域训练样本数据,将目标域训练样本数据集中剩余的目标域训练样本数据确定为扩充样本数据。
这里,由于差值为锚定数据经过更新后的数据增强模型得到的更新损失函数值与经过预训练模型得到的更新损失函数值之间的差值,而更新后的数据增强模型是由目标域训练样本数据对数据增强模型进行模型训练得到,并且数据增强模型与预训练模型的模型结构与模型参数均相同,因此,通过该差值可以看出目标域训练样本数据对数据增强模型的影响是正向影响还是负向影响。当差值小于0时,则更新损失函数值小于基础损失函数值,表明将该差值对应的目标域训练样本数据引入对数据增强模型产生了正向影响,可以将该目标域训练样本数据作为源域样本数据集的扩充样本,来提升数据增强模型的模型性能。当差值大于或者等于0时,则更新损失函数值大于或者等于基础损失函数值,表明将该差值对应的目标域训练样本数据引入对数据增强模型产生了负向影响或者无影响,不能将该目标域训练样本数据作为源域样本数据集的扩充样本,且无法提升数据增强模型的模型性能。
步骤S211,服务器将源域样本数据和扩充样本数据输入至数据增强模型中。
本申请实施例中,在根据差值来筛选得到扩充样本数据之后,服务器就会将源域样本数据和扩充样本数据作为数据增强模型的输入来进行数据增强模型的迭代训练,以使得源域样本数据和扩充样本数据作为增强数据,提升数据增强模型的模型性能。
步骤S212,服务器通过数据增强模型,对源域样本数据和扩充样本数据分别进行数据处理,对应得到源域样本预估概率和目标域样本预估概率。
本申请实施例中,数据处理过程具体为:通过数据增强模型的嵌入层,对源域样本数据和扩充样本数据分别进行特征提取,对应得到源域特征向量和目标域特征向量;通过数据增强模型的特征映射模块,对源域特征向量和目标域特征向量分别进行特征映射,对应得到源域样本预估概率和目标域样本预估概率。
数据增强模型的嵌入层用于分别将源域样本数据和扩充样本数据经过特征提取过程,转化为密集的向量表示(即源域特征向量和目标域特征向量)。数据增强模型的特征映射模块用于分别将源域特征向量和目标域特征向量经过特征映射过程,计算得到源域样本预估概率和目标域样本预估概率。数据增强模型的特征映射模块可以是任意一种具有特征映射功能的网络结构,例如,卷积神经网络、循环神经网络等,在此不做限定。
步骤S213,服务器基于源域样本预估概率和源域样本数据的真实标签,构建源域样本损失函数。
本申请实施例中,目标域样本预估概率与扩充样本数据的真实标签构成的交叉熵损失函数通过以下公式(1)得到:
(1)
其中,为源域样本数据的数据量,表示源域样本数据的真实标签,表示源域
样本预估概率,为累加函数,为源域样本预估概率与源域样本数据的真实标签构
成的源域样本损失函数。
步骤S214,服务器基于目标域样本预估概率、扩充样本数据的真实标签和扩充样本数据对应的差值,构建目标域样本损失函数。
在一些实施例中,参见图5,图5示出了步骤S214中,服务器基于目标域样本预估概率、扩充样本数据的真实标签和扩充样本数据对应的差值,构建目标域样本损失函数,可以通过以下步骤S2141至步骤S2144实现:
步骤S2141,对扩充样本数据对应的差值进行取反操作,得到差值的取反结果。
本申请实施例中,由于扩充样本对应的差值为小于0的负数,因此,在计算目标域的损失函数之前,需要将该差值进行取反操作,使得该差值由负数转换为正数,便于后续基于该差值得到数据增强模型的目标域损失函数。
在一些实施例中,取反操作具体过程为:先将差值转换成二进制数,再取得二进制数的补码,之后对补码的每一位(包括第一位的符号位)进行取反运算:即将0变为1、将1变为0,得到的是取反结果的补码,再次取补码得到取反结果的原码,再将取反结果的原码转换为十进制数,得到差值的取反结果。或者,在另一些实施例中,取反操作还可以直接对差值a按位取反,得到差值的取反结果为-(a+1)。例如,假设差值为-5,则该差值的取反结果为4。在另一些实施例中,还可以对扩充样本数据对应的差值进行取绝对值操作,得到差值的取绝对值结果,并将取绝对值结果用于后续目标域损失函数的构建。
步骤S2142,对取反结果进行数据标准化处理,得到标准化差值。
本申请实施例中,在对数据增强模型建模之前,都需要对输入的数据进行数据标准化处理,以消除量纲的影响。如果对未标准化的数据直接进行建模,可能会导致数据增强模型对数值大的变量学习过多,而对数值小的变量训练不够充分,使得数据增强模型的模型训练效果变差。数据标准化处理方法包括最大最小归一化、均值方差标准化、小数定标法、定量特征二值化等。
这里,通过对差值的取反结果进行数据标准化处理,得到取反结果的标准化差值,使得标准化差值与目标域样本预估概率处于同一数量级,便于后续对数据增强模型的目标域损失函数的正确构建。
步骤S2143,获取目标域样本预估概率与扩充样本数据的真实标签构成的交叉熵损失函数。
本申请实施例中,目标域样本预估概率与扩充样本数据的真实标签构成的交叉熵损失函数通过以下公式(2)得到:
(2)
其中,为扩充样本数据的数据量,表示扩充样本数据的真实标签,表示目标
域样本预估概率,为目标域样本预估概率与扩充样本数据的真实标签构成的交叉熵损
失函数。
步骤S2144,采用标准化差值对交叉熵损失函数进行加权处理,得到目标域损失函数。
本申请实施例中,加权处理是指将标准化差值作为交叉熵损失函数的权值进行乘法操作,最终通过以下公式(3)得到目标域损失函数:
(3)
其中,为标准化差值,为目标域损失函数。
步骤S215,服务器对源域样本损失函数和目标域样本损失函数分别进行损失计算,对应得到源域样本损失值和目标域样本损失值。
步骤S216,服务器根据源域样本损失值和目标域样本损失值,确定数据增强模型的总损失值。
本申请实施例中,数据增强模型的总损失值为源域样本损失值和目标域样本损失
值之和,数据增强模型的总损失函数通过以下公式(4)得到:
(4)
基于数据增强模型的总损失函数,经过损失计算,得到数据增强模型的总损失值。通过该总损失值衡量数据增强模型的源域样本预估概率和源域样本数据的真实标签、目标域样本预估概率和扩充样本数据的真实标签之间的不一致程度,即计算模型每次迭代的源域样本预估概率和源域样本数据的真实标签、目标域样本预估概率和扩充样本数据的真实标签之间的差距,从而指导下一步的模型训练向正确的方向进行。
步骤S217,服务器基于总损失值,按照预设的迭代次数对数据增强模型中的模型参数进行迭代更新,得到训练后的数据增强模型。
本申请实施例中,根据数据增强模型的总损失函数的导数,沿梯度最小方向将总损失值回传,更新数据增强模型中的模型参数,如数据增强模型中的各个权重值。预先设定一个总损失阈值,当总损失值小于预先设定的总损失阈值时,则停止迭代训练,即停止模型参数更新;也可以预先设定一个最大迭代次数阈值,当迭代次数超过最大迭代次数阈值时,则停止模型参数更新;还可以预先设定一个截至迭代时间,当迭代时间到达截至迭代时间时,则停止模型参数更新,得到训练后的数据增强模型。
在一些实施例中,目标域训练样本集是从预设的目标域样本数据库中采样得到的数据集。当检测到源域样本数据集中具有新增源域样本数据或者目标域样本数据库中具有新增目标域样本数据时,将训练后的数据增强模型的模型参数同步到基础模型中,得到当前时刻的基础模型,并获取包括新增源域样本数据的新的源域样本数据集和包括新增目标域样本数据的新的目标域样本数据库;将新的源域样本数据集确定为当前时刻的源域样本数据集;将从新的目标域样本数据库中采样得到的数据集,确定为当前时刻的目标域训练样本集;基于当前时刻的源域样本数据集和当前时刻的目标域训练样本集再次对基础模型执行数据增强模型训练方法。
在一些实施例中,当前时刻的源域样本数据集包括当前时刻的源域训练样本集和当前时刻的锚定数据集。基于当前时刻的源域训练样本集中的源域训练样本数据,对当前时刻的基础模型进行模型预训练,得到当前时刻的预训练模型的模型参数,并基于模型参数确定当前时刻的数据增强模型;针对当前时刻的目标域训练样本集中的每一目标域训练样本数据,在采用目标域训练样本数据对数据增强模型进行模型训练,得到与目标域训练样本数据对应的当前时刻的更新后的数据增强模型之后,基于当前时刻的锚定数据集中的锚定数据,分别对预训练模型和更新后的数据增强模型进行模型训练,对应得到当前时刻的基础损失函数值与当前时刻的更新损失函数值;根据当前时刻的更新损失函数值与当前时刻的基础损失函数值,从当前时刻的目标域训练样本集中筛选出至少一个当前时刻的扩充样本数据;基于当前时刻的源域样本数据集中的源域样本数据和至少一个当前时刻的扩充样本数据,对当前时刻的数据增强模型进行迭代训练,得到当前时刻的训练后的数据增强模型。
这里,通过将训练后的数据增强模型的模型参数再次同步到基础模型,在当前时刻的源域样本数据集和当前时刻的目标域训练样本集上,重复以上数据增强模型的模型训练过程,即通过持续不断的更新数据流,使得数据增强模型的模型性能不断提升,从而得到更加准确的模型处理结果。
步骤S218,服务器向终端发送训练后的数据增强模型。
步骤S219,终端输出训练后的数据增强模型。
本申请实施例中,首先将基础模型预训练后的预训练模型的模型参数同步到与预训练模型具有相同模型结构的数据增强模型,再利用目标域训练样本集中的每一目标域训练样本数据对数据增强模型进行多次模型训练,得到与目标域训练样本数据对应的更新后的数据增强模型。然后,将锚定数据集中的锚定数据分别输入到预训练模型和更新后的数据增强模型进行模型训练,通过比较两个模型的基础损失函数值与更新损失函数值的数值大小,判断目标域训练样本集中的每一目标域训练样本数据对数据增强模型的影响是否是正向影响。最后,若为正向影响,则将具有正向影响的目标域训练样本数据作为源域样本数据集的扩充样本数据来对数据增强模型进行迭代训练,得到训练后的数据增强模型,实现利用目标域训练样本数据对源域样本数据集的数据增强,从而利用增强后的数据集提升数据增强模型的模型性能。
在一些实施例中,源域样本数据集包括应用于内容订阅业务的订阅内容数据集,目标域训练样本集包括应用于内容推荐业务的推荐内容数据集;获取内容订阅业务下的待处理数据集;将待处理数据集输入到训练后的数据增强模型中,通过训练后的数据增强模型在内容订阅业务下对待处理数据集进行数据处理,得到内容订阅业务下的数据处理结果。也就是说,训练后的数据增强模型可以为订阅内容排序模型,将训练后的数据增强模型应用于订阅内容场景下的内容订阅业务,并将内容订阅业务的订阅内容数据集作为训练后的数据增强模型的模型输入,利用训练后的数据增强模型的数据处理功能,得到训练后的数据增强模型数据处理后的数据处理结果,即订阅内容的排序结果。之后,能够根据订阅内容的排序结果实现内容订阅业务下的订阅内容的精准排序,增加用户对订阅内容的点击率。
在一些实施例中,源域样本数据集包括应用于内容推荐业务的推荐内容数据集,目标域训练样本集包括应用于内容订阅业务的订阅内容数据集;获取内容推荐业务下的待处理数据集;将待处理数据集输入到训练后的数据增强模型中,通过训练后的数据增强模型在内容推荐业务下对待处理数据集进行数据处理,得到内容推荐业务下的数据处理结果。也就是说,训练后的数据增强模型可以为推荐内容排序模型,将训练后的数据增强模型应用于推荐内容场景下的内容推荐业务,并将内容推荐业务的推荐内容数据集作为训练后的数据增强模型的模型输入,利用训练后的数据增强模型的数据处理功能,得到训练后的数据增强模型数据处理后的数据处理结果,即推荐内容的排序结果。之后,能够根据推荐内容的排序结果实现内容推荐业务下的推荐内容的精准排序,增加用户对推荐内容的点击率。
在一些实施例中,在通过上述任一实施例得到训练后的数据增强模型之后,还可以提供一数据处理方法,该数据处理方法可以通过数据处理设备来执行,该数据处理设备可以与上述用于实现数据增强模型训练方法的电子设备为同一电子设备,也可以是不同的电子设备,也就是说,用于实现数据处理方法的数据处理装置和用于实现数据增强模型训练方法的数据增强模型训练装置可以位于同一电子设备中,也可以位于不同的电子设备中。可以通过训练后的数据增强模型对目标业务下的待处理数据集进行数据处理,得到目标业务下的数据处理结果。其中,目标业务包括内容订阅业务或者内容推荐业务。
当目标业务包括内容订阅业务时,在数据处理方法的实现过程中,可以获取内容订阅业务下的待处理数据集;将待处理数据集输入到训练后的数据增强模型中,通过训练后的数据增强模型在内容订阅业务下对待处理数据集进行数据处理,得到内容订阅业务下的数据处理结果。之后,能够根据数据处理结果实现内容订阅业务下的订阅内容的精准排序,增加用户对订阅内容的点击率。
当目标业务包括内容推荐业务时,在数据处理方法的实现过程中,可以获取内容推荐业务下的待处理数据集;将待处理数据集输入到训练后的数据增强模型中,通过训练后的数据增强模型在内容推荐业务下对待处理数据集进行数据处理,得到内容推荐业务下的数据处理结果。之后,能够根据数据处理结果实现内容推荐业务下的推荐内容的精准排序,增加用户对推荐内容的点击率。
下面,将说明本申请实施例在一个实际的应用场景中的示例性应用。
本申请实施例提供一种数据增强模型训练方法,该方法的应用场景可以是对平台订阅号消息盒子和推送的消息卡片(包括用户自行订阅和平台主动推荐)的排序技术。
在信息流内容产品(如平台公众号、平台社区等)中,同时存在如图6所示的订阅内容与推荐内容,但通常分属于两个展示区域。由于两个场景的数据分布通常具有较大差异,用户的特征数据也大相径庭,因此通常由不同的排序模型(如订阅内容排序模型和推荐内容排序模型)控制排序,每个排序模型分别使用各自场景的数据进行训练。
因此,本申请实施例为每个场景的排序模型扩充了来自另一个场景的数据,例如,为订阅场景的推荐内容排序模型扩充了来自订阅场景的数据,实现如图7所示的订阅场景与推荐场景之间的跨场景数据迁移,以增强推荐内容排序模型对用户的特征数据的拟合及排序的泛化能力。对于推荐内容排序模型的模型输入,推荐内容排序模型处理的每一条输入样本都表示一个用户对一条推荐给它的订阅号消息相关的基础特征。基础特征可以抽象表示为:<目标用户相关特征,待推荐消息所属订阅号相关特征,待推荐消息本身相关特征,交叉关系的统计特征>。其中,“订阅号”和“订阅号消息”的关系如图8所示,即每条消息都有所属的订阅号。
目标用户相关特征包括用户ID,用户年龄,用户性别,用户地域,用户过去1天曝光消息数,用户过去7天曝光消息数等等;待推荐消息所属订阅号相关特征包括订阅号ID,订阅号人数,订阅号过去7天创作消息数,订阅号过去7天点击阅读数等等;待推荐消息本身相关特征包括消息ID,消息发送至今的小时数,消息过去1小时曝光数,消息过去1小时点击数等等;其中,订阅号ID联合消息ID得到了每个内容ID;交叉关系的统计特征包括该用户对该订阅号过去28天曝光数,该用户对该订阅号过去28天点击数等等。
对于订阅内容排序模型,订阅内容的曝光内容、用户的特征数据及其是否发生点击等信息组成了源域数据,推荐内容的信息则属于目标域数据。反之,对于推荐内容排序模型,推荐内容的信息是源域数据,订阅内容的信息是目标域数据。
在为源域的排序模型引入目标域的数据用于训练时,存在的问题是,一个在目标域曝光的内容引发的用户是否点击的行为日志,不一定在源域会有相同的结果。例如,用户可能在浏览推荐内容场景时,带有扩充知识面的目的,偏好点击具有“新鲜感”的科技文章;而在浏览订阅内容场景时,带有了解时事的目的,偏好点击“传统”的新闻文章。于是一篇在推荐场景被点击的文章,假设其在订阅场景被推送时,也不会得到点击。更进一步地,推荐系统永远无法采集到同一个用户对同一个内容同时处于两个不同场景时的反馈信息。也就是说,来自目标域的扩充样本以及点击标签无法直接用于对源域模型进行训练。
本申请实施例提出了一种引入跨域数据样本训练排序模型的流程框架,示意图如图9所示。流程框架主要由图9中的基础模型901和增强模型902两部分模型组成,基础模型901表示使用源域数据进行训练的基础模型。基础嵌入层(Base Embedding)存储了源域的内容和用户的特征以及内容ID的嵌入向量,随基础模型的训练而更新。增强模型902表示使用源域数据和目标域数据联合训练的增强模型。附属嵌入层(Subsidiary Embedding)存储了目标域的内容和用户的特征以及内容ID的嵌入向量,随增强模型的训练而更新。两部分模型(包括嵌入层Embedding部分)使用了相同的网络层结构和参数规模。在训练过程中,两部分模型交替更新,并在一定步骤后同步彼此的参数。具体过程如下:
第一步,基础模型在源域数据上进行训练。经过N_I次前向传播和反向传播(梯度更新),训练批量大小为M,即每次使用M个样本做计算。随后,基础模型的参数,同步到中间的增强模型(即上述数据增强模型)。
第二步,从剩余的源域数据内采样一批固定数目的样本,称为锚定数据。接着,在
目标域数据中,采样一批数量为N_scr的目标域样本集(即上述目标域训练样本集)。接着,
增强模型逐个地遍历该目标域样本集中的每个目标域样本,对增强模型进行训练更新(即
Stochastic Gradient Descent,一步只利用一个目标域样本),记单个目标域样本的损失
函数值为Loss_src。在模型更新后,分别使用更新后的增强模型和基础模型对锚定数据进
行处理,计算得到损失函数值分别为Loss_tgt(即上述更新损失函数值)和Loss_base(即上
述基础损失函数值)。两个损失函数值的第个目标域样本对应的差值用以衡量将该目标域样本引入源域的训练数据集后,对
源域模型的影响是否是正向的(或者产生负向作用)。当时,表示利用该目标域样本
训练源域模型后,锚定数据对应的损失函数值降低了,则可以认为该目标域样本是对提高
源域模型的预估准确率有帮助的样本。在单个目标域样本对增强模型进行训练后,将增强
模型重新初始化为基础模型,以确保增强模型可以继续评估其它的目标域样本对基础模型
的影响效果。当遍历完N_scr个目标域样本后,各个值对应了各个目标域样本的重要
性。
第三步,利用的目标域样本(即上述扩充样本数据)和源域数据(即上述源
域样本数据)组成扩充后的增强数据。同步了基础模型参数的增强模型,进行N_U次迭代,每
次迭代从增强数据内采样M个样本。当计算来自目标域样本的损失函数(即上述目标域样本
损失函数)时,使用在第二步中得到的目标域样本重要性对交叉熵损
失函数进行加权,如公式(5)所示:
(5)
其中,是对负数取反并做最大最小值归一化后的加权权重(即上述标准化
差值),括号内的部分是标准的交叉熵损失函数。
第四步,将经过训练的增强模型的模型参数同步回到基础模型。在持续更新的数据流上,重复以上第一步到第三步的过程。
由于持续的模型参数同步,在经过数据增强的训练模式后,要应用到在线服务的推理环节只需要部署其中一个模型,如图10所示。
本申请实施例所提模型(B18)应用于平台订阅号消息的推荐内容排序时,增加了日志数据的利用率。如表1所示,对比基线模型(A1+A2),人均消息阅读次数和阅读时长得到显著提升,表明了模型具有更好的排序能力。
表1 模型性能对比结果
本申请实施例提出了利用跨场景的样本数据扩充排序模型的训练样本集,使排序模型可以在更大规模和更完整的数据上学习拟合用户的数据特征和个性化兴趣,并且提出一个适用于利用未知标签的样本来辅助模型训练的训练模式,避免在引入扩充样本数据后为基础模型带来负向效果。
可以理解的是,在本申请实施例中,涉及到用户信息的内容,例如,源域样本数据集、目标域样本数据集等信息,如果涉及与用户信息或企业信息相关的数据,当本申请实施例运用到具体产品或技术中时,需要获得用户许可或者同意,或者对这些信息进行模糊化处理,以消除这些信息与用户之间的对应关系;且相关数据收集处理在实例应用时应该严格根据相关国家法律法规的要求,获取个人信息主体的知情同意或单独同意,并在法律法规及个人信息主体的授权范围内,开展后续数据使用及处理行为。
下面继续说明本申请实施例提供的数据增强模型训练装置455的实施为软件模块的示例性结构,在一些实施例中,如图2所示,存储在存储器450的数据增强模型训练装置455中的软件模块可以包括:获取模块4551,用于获取源域样本数据集和目标域训练样本集;所述源域样本数据集包括源域训练样本集和锚定数据集;模型预训练模块4552,用于基于所述源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于所述模型参数确定数据增强模型;模型训练模块4553,用于针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;筛选模块4554,用于根据所述更新损失函数与所述基础损失函数,从所述目标域训练样本集中筛选出至少一个扩充样本数据;迭代训练模块4555,用于基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型。
在一些实施例中,所述模型预训练模块4552还用于:获取与所述基础模型具有相同模型结构的网络模型;将所述模型参数同步到所述网络模型中,得到所述数据增强模型。
在一些实施例中,所述目标域训练样本集中包括N个目标域训练样本数据;N为大于1的整数;所述模型训练模块4553还用于:针对所述N个目标域训练样本数据中的第i个目标域训练样本数据,采用所述第i个目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述第i个目标域训练样本数据对应的第i个更新后的数据增强模型;基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述第i个更新后的数据增强模型进行模型训练,对应得到第i个基础损失函数值与第i个更新损失函数值;i为大于0且小于或等于N的任意整数。
在一些实施例中,所述筛选模块4554还用于:针对所述目标域训练样本集中的每一目标域训练样本数据,获取所述更新损失函数值与所述基础损失函数值之间的差值;将小于0的差值对应的目标域训练样本数据,确定为所述扩充样本数据。
在一些实施例中,所述迭代训练模块4555还用于:将所述源域样本数据和所述扩充样本数据输入至所述数据增强模型中;通过所述数据增强模型,对所述源域样本数据和所述扩充样本数据分别进行数据处理,对应得到源域样本预估概率和目标域样本预估概率;基于所述源域样本预估概率和所述源域样本数据的真实标签,构建源域样本损失函数;基于所述目标域样本预估概率、所述扩充样本数据的真实标签和所述扩充样本数据对应的所述差值,构建目标域样本损失函数;对所述源域样本损失函数和所述目标域样本损失函数分别进行损失计算,对应得到源域样本损失值和目标域样本损失值;根据所述源域样本损失值和所述目标域样本损失值,确定所述数据增强模型的总损失值;基于所述总损失值,按照预设的迭代次数对所述数据增强模型中的模型参数进行迭代更新,得到所述训练后的数据增强模型。
在一些实施例中,所述迭代训练模块4555还用于:通过所述数据增强模型的嵌入层,对所述源域样本数据和所述扩充样本数据分别进行特征提取,对应得到源域特征向量和目标域特征向量;通过所述数据增强模型的特征映射模块,对所述源域特征向量和所述目标域特征向量分别进行特征映射,对应得到所述源域样本预估概率和所述目标域样本预估概率。
在一些实施例中,所述迭代训练模块4555还用于:对所述扩充样本数据对应的差值进行取反操作,得到所述差值的取反结果;对所述取反结果进行数据标准化处理,得到标准化差值;获取所述目标域样本预估概率与所述扩充样本数据的真实标签构成的交叉熵损失函数;采用所述标准化差值对所述交叉熵损失函数进行加权处理,得到所述目标域损失函数。
在一些实施例中,所述目标域训练样本集是从预设的目标域样本数据库中采样得到的数据集,所述装置455还包括再次训练模块,所述再次训练模块用于:当检测到所述源域样本数据集中具有新增源域样本数据或者所述目标域样本数据库中具有新增目标域样本数据时,将所述训练后的数据增强模型的模型参数同步到所述基础模型中,得到当前时刻的基础模型,并获取包括所述新增源域样本数据的新的源域样本数据集和包括所述新增目标域样本数据的新的目标域样本数据库;将所述新的源域样本数据集确定为当前时刻的源域样本数据集;将从所述新的目标域样本数据库中采样得到的数据集,确定为当前时刻的目标域训练样本集;基于所述当前时刻的源域样本数据集和所述当前时刻的目标域训练样本集再次对所述基础模型执行所述数据增强模型训练方法。
在一些实施例中,所述再次训练模块还用于:基于所述当前时刻的源域训练样本集中的源域训练样本数据,对所述当前时刻的基础模型进行模型预训练,得到当前时刻的预训练模型的模型参数,并基于所述模型参数确定当前时刻的数据增强模型;针对所述当前时刻的目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的当前时刻的更新后的数据增强模型之后,基于所述当前时刻的锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到当前时刻的基础损失函数值与当前时刻的更新损失函数值;根据所述当前时刻的更新损失函数值与所述当前时刻的基础损失函数值,从所述当前时刻的目标域训练样本集中筛选出至少一个当前时刻的扩充样本数据;基于所述当前时刻的源域样本数据集中的源域样本数据和所述至少一个当前时刻的扩充样本数据,对所述当前时刻的数据增强模型进行迭代训练,得到当前时刻的训练后的数据增强模型。
在一些实施例中,所述装置455还包括锚定数据确定模块,所述锚定数据确定模块用于:从所述源域训练样本集中提取源域训练样本集;根据预设采样数量,从提取所述源域训练样本集后剩余的源域训练样本集中进行锚定数据采样处理,得到所述锚定数据集。
在一些实施例中,所述源域样本数据集包括应用于内容订阅业务的订阅内容数据集,所述目标域训练样本集包括应用于内容推荐业务的推荐内容数据集;所述装置455还包括数据处理模块,所述数据处理模块用于:获取所述内容订阅业务下的待处理数据集;将所述待处理数据集输入到所述训练后的数据增强模型中,通过所述训练后的数据增强模型在所述内容订阅业务下对所述待处理数据集进行数据处理,得到所述内容订阅业务下的数据处理结果。
下面继续说明本申请实施例提供的数据处理装置的实施为软件模块的示例性结构,在一些实施例中,数据处理装置也可存储在存储器450中,存储在存储器450的数据处理装置中的软件模块可以包括:数据集获取模块,用于获取目标业务下的待处理数据集;所述目标业务包括内容订阅业务或者内容推荐业务;数据处理结果确定模块,用于将所述待处理数据集输入到训练后的数据增强模型中,通过所述训练后的数据增强模型在所述目标业务下对所述待处理数据集进行数据处理,得到所述目标业务下的数据处理结果;其中,所述训练后的数据增强模型采用本申请实施例所提供的数据增强模型训练方法训练得到。
需要说明的是,本申请实施例装置的描述,与上述方法实施例的描述是类似的,具有同方法实施例相似的有益效果,因此不做赘述。对于本装置实施例中未披露的技术细节,请参照本申请方法实施例的描述而理解。
本申请实施例提供一种计算机可读存储介质,其中存储有计算机可执行指令,当计算机可执行指令被处理器执行时,将引起处理器执行本申请实施例提供的数据增强模型训练方法,例如,如图3示出的数据增强模型训练方法,或者,执行本申请实施例提供的数据处理方法。
本申请实施例提供了一种计算机程序产品,该计算机程序产品包括计算机可执行指令,该计算机可执行指令存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该计算机可执行指令,处理器执行该计算机可执行指令,使得该电子设备执行本申请实施例上述的数据增强模型训练方法,或者,执行本申请实施例提供的数据处理方法。
在一些实施例中,计算机可读存储介质可以是RAM、ROM、闪存、磁表面存储器、光盘、或CD-ROM等存储器;也可以是包括上述存储器之一或任意组合的各种设备。
在一些实施例中,计算机可执行指令可以采用程序、软件、软件模块、脚本或代码的形式,按任意形式的编程语言(包括编译或解释语言,或者声明性或过程性语言)来编写,并且其可按任意形式部署,包括被部署为独立的程序或者被部署为模块、组件、子例程或者适合在计算环境中使用的其它单元。
作为示例,计算机可执行指令可以但不一定对应于文件系统中的文件,可以可被存储在保存其它程序或数据的文件的一部分,例如,存储在超文本标记语言(Hyper TextMarkup Language,HTML)文档中的一个或多个脚本中,存储在专用于所讨论的程序的单个文件中,或者,存储在多个协同文件(例如,存储一个或多个模块、子程序或代码部分的文件)中。
作为示例,计算机可执行指令可被部署为在一个电子设备上执行,或者在位于一个地点的多个电子设备上执行,又或者,在分布在多个地点且通过通信网络互连的多个电子设备上执行。
以上所述,仅为本申请的实施例而已,并非用于限定本申请的保护范围。凡在本申请的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本申请的保护范围之内。
Claims (15)
1.一种数据增强模型训练方法,其特征在于,所述方法包括:
获取源域样本数据集和目标域训练样本集;所述源域样本数据集包括源域训练样本集和锚定数据集;
基于所述源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于所述模型参数确定数据增强模型;
针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;
根据所述更新损失函数值与所述基础损失函数值,从所述目标域训练样本集中筛选出至少一个扩充样本数据;
基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型。
2.根据权利要求1所述的方法,其特征在于,所述基于所述模型参数确定数据增强模型,包括:
获取与所述基础模型具有相同模型结构的网络模型;
将所述模型参数同步到所述网络模型中,得到所述数据增强模型。
3.根据权利要求1所述的方法,其特征在于,所述目标域训练样本集中包括N个目标域训练样本数据;N为大于1的整数;
所述针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值,包括:
针对所述N个目标域训练样本数据中的第i个目标域训练样本数据,采用所述第i个目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述第i个目标域训练样本数据对应的第i个更新后的数据增强模型;
基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述第i个更新后的数据增强模型进行模型训练,对应得到第i个基础损失函数值与第i个更新损失函数值;i为大于0且小于或等于N的任意整数。
4.根据权利要求1所述的方法,其特征在于,所述根据所述更新损失函数值与所述基础损失函数值,从所述目标域训练样本集中筛选出至少一个扩充样本数据,包括:
针对所述目标域训练样本集中的每一目标域训练样本数据,获取所述更新损失函数值与所述基础损失函数值之间的差值;
将小于0的差值对应的目标域训练样本数据,确定为所述扩充样本数据。
5.根据权利要求4所述的方法,其特征在于,所述基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型,包括:
将所述源域样本数据和所述扩充样本数据输入至所述数据增强模型中;
通过所述数据增强模型,对所述源域样本数据和所述扩充样本数据分别进行数据处理,对应得到源域样本预估概率和目标域样本预估概率;
基于所述源域样本预估概率和所述源域样本数据的真实标签,构建源域样本损失函数;
基于所述目标域样本预估概率、所述扩充样本数据的真实标签和所述扩充样本数据对应的所述差值,构建目标域样本损失函数;
对所述源域样本损失函数和所述目标域样本损失函数分别进行损失计算,对应得到源域样本损失值和目标域样本损失值;
根据所述源域样本损失值和所述目标域样本损失值,确定所述数据增强模型的总损失值;
基于所述总损失值,按照预设的迭代次数对所述数据增强模型中的模型参数进行迭代更新,得到所述训练后的数据增强模型。
6.根据权利要求5所述的方法,其特征在于,所述通过所述数据增强模型,对所述源域样本数据和所述扩充样本数据分别进行数据处理,对应得到源域样本预估概率和目标域样本预估概率,包括:
通过所述数据增强模型的嵌入层,对所述源域样本数据和所述扩充样本数据分别进行特征提取,对应得到源域特征向量和目标域特征向量;
通过所述数据增强模型的特征映射模块,对所述源域特征向量和所述目标域特征向量分别进行特征映射,对应得到所述源域样本预估概率和所述目标域样本预估概率。
7.根据权利要求5所述的方法,其特征在于,所述基于所述目标域样本预估概率、所述扩充样本数据的真实标签和所述扩充样本数据对应的所述差值,构建目标域样本损失函数,包括:
对所述扩充样本数据对应的差值进行取反操作,得到所述差值的取反结果;
对所述取反结果进行数据标准化处理,得到标准化差值;
获取所述目标域样本预估概率与所述扩充样本数据的真实标签构成的交叉熵损失函数;
采用所述标准化差值对所述交叉熵损失函数进行加权处理,得到所述目标域损失函数。
8.根据权利要求1至7任一项所述的方法,其特征在于,所述目标域训练样本集是从预设的目标域样本数据库中采样得到的数据集;所述方法还包括:
当检测到所述源域样本数据集中具有新增源域样本数据或者所述目标域样本数据库中具有新增目标域样本数据时,将所述训练后的数据增强模型的模型参数同步到所述基础模型中,得到当前时刻的基础模型,并获取包括所述新增源域样本数据的新的源域样本数据集和包括所述新增目标域样本数据的新的目标域样本数据库;
将所述新的源域样本数据集确定为当前时刻的源域样本数据集;
将从所述新的目标域样本数据库中采样得到的数据集,确定为当前时刻的目标域训练样本集;
基于所述当前时刻的源域样本数据集和所述当前时刻的目标域训练样本集再次对所述基础模型执行所述数据增强模型训练方法。
9.根据权利要求8所述的方法,其特征在于,所述当前时刻的源域样本数据集包括当前时刻的源域训练样本集和当前时刻的锚定数据集;所述基于所述当前时刻的源域样本数据集和所述当前时刻的目标域训练样本集再次对所述基础模型执行所述数据增强模型训练方法,包括:
基于所述当前时刻的源域训练样本集中的源域训练样本数据,对所述当前时刻的基础模型进行模型预训练,得到当前时刻的预训练模型的模型参数,并基于所述模型参数确定当前时刻的数据增强模型;
针对所述当前时刻的目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的当前时刻的更新后的数据增强模型之后,基于所述当前时刻的锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到当前时刻的基础损失函数值与当前时刻的更新损失函数值;
根据所述当前时刻的更新损失函数值与所述当前时刻的基础损失函数值,从所述当前时刻的目标域训练样本集中筛选出至少一个当前时刻的扩充样本数据;
基于所述当前时刻的源域样本数据集中的源域样本数据和所述至少一个当前时刻的扩充样本数据,对所述当前时刻的数据增强模型进行迭代训练,得到当前时刻的训练后的数据增强模型。
10.根据权利要求1至7任一项所述的方法,其特征在于,所述方法还包括:
从所述源域训练样本集中提取源域训练样本集;
根据预设采样数量,从提取所述源域训练样本集后剩余的源域训练样本集中进行锚定数据采样处理,得到所述锚定数据集。
11.根据权利要求1至7任一项所述的方法,其特征在于,所述源域样本数据集包括应用于内容订阅业务的订阅内容数据集,所述目标域训练样本集包括应用于内容推荐业务的推荐内容数据集;所述方法还包括:
获取所述内容订阅业务下的待处理数据集;
将所述待处理数据集输入到所述训练后的数据增强模型中,通过所述训练后的数据增强模型在所述内容订阅业务下对所述待处理数据集进行数据处理,得到所述内容订阅业务下的数据处理结果。
12.一种数据处理方法,其特征在于,所述方法包括:
获取目标业务下的待处理数据集;所述目标业务包括内容订阅业务或者内容推荐业务;
将所述待处理数据集输入到训练后的数据增强模型中,通过所述训练后的数据增强模型在所述目标业务下对所述待处理数据集进行数据处理,得到所述目标业务下的数据处理结果;其中,所述训练后的数据增强模型采用权利要求1至11任一项所提供的数据增强模型训练方法训练得到。
13.一种数据增强模型训练装置,其特征在于,所述数据增强模型训练装置包括:
获取模块,用于获取源域样本数据集和目标域训练样本集;所述源域样本数据集包括源域训练样本集和锚定数据集;
模型预训练模块,用于基于所述源域训练样本集中的源域训练样本数据,对预设的基础模型进行模型预训练,得到预训练模型的模型参数,并基于所述模型参数确定数据增强模型;
模型训练模块,用于针对所述目标域训练样本集中的每一目标域训练样本数据,在采用所述目标域训练样本数据对所述数据增强模型进行模型训练,得到与所述目标域训练样本数据对应的更新后的数据增强模型之后,基于所述锚定数据集中的锚定数据,分别对所述预训练模型和所述更新后的数据增强模型进行模型训练,对应得到基础损失函数值与更新损失函数值;
筛选模块,用于根据所述更新损失函数与所述基础损失函数,从所述目标域训练样本集中筛选出至少一个扩充样本数据;
迭代训练模块,用于基于所述源域样本数据集中的源域样本数据和所述至少一个扩充样本数据,对所述数据增强模型进行迭代训练,得到训练后的数据增强模型。
14.一种电子设备,其特征在于,包括:
存储器,用于存储计算机可执行指令;
处理器,用于执行所述存储器中存储的计算机可执行指令时,实现权利要求1至11任一项所述的数据增强模型训练方法,或者,实现权利要求12所述的数据处理方法。
15.一种计算机可读存储介质,其特征在于,存储有计算机可执行指令,所述计算机可执行指令被处理器执行时,实现权利要求1至11任一项所述的数据增强模型训练方法,或者,实现权利要求12所述的数据处理方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410078708.2A CN117609887B (zh) | 2024-01-19 | 2024-01-19 | 数据增强模型训练及数据处理方法、装置、设备、介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410078708.2A CN117609887B (zh) | 2024-01-19 | 2024-01-19 | 数据增强模型训练及数据处理方法、装置、设备、介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117609887A true CN117609887A (zh) | 2024-02-27 |
CN117609887B CN117609887B (zh) | 2024-05-10 |
Family
ID=89951959
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410078708.2A Active CN117609887B (zh) | 2024-01-19 | 2024-01-19 | 数据增强模型训练及数据处理方法、装置、设备、介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117609887B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210012198A1 (en) * | 2018-05-31 | 2021-01-14 | Huawei Technologies Co., Ltd. | Method for training deep neural network and apparatus |
CN112417293A (zh) * | 2020-12-03 | 2021-02-26 | 京东数字科技控股股份有限公司 | 信息推送方法和系统、模型训练方法及相关设备 |
CN115358410A (zh) * | 2022-08-08 | 2022-11-18 | 珠高智能科技(深圳)有限公司 | 预训练模型的领域增强方法、装置、设备及存储介质 |
CN117371511A (zh) * | 2023-11-01 | 2024-01-09 | 腾讯科技(深圳)有限公司 | 图像分类模型的训练方法、装置、设备及存储介质 |
-
2024
- 2024-01-19 CN CN202410078708.2A patent/CN117609887B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210012198A1 (en) * | 2018-05-31 | 2021-01-14 | Huawei Technologies Co., Ltd. | Method for training deep neural network and apparatus |
CN112417293A (zh) * | 2020-12-03 | 2021-02-26 | 京东数字科技控股股份有限公司 | 信息推送方法和系统、模型训练方法及相关设备 |
CN115358410A (zh) * | 2022-08-08 | 2022-11-18 | 珠高智能科技(深圳)有限公司 | 预训练模型的领域增强方法、装置、设备及存储介质 |
CN117371511A (zh) * | 2023-11-01 | 2024-01-09 | 腾讯科技(深圳)有限公司 | 图像分类模型的训练方法、装置、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN117609887B (zh) | 2024-05-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20230281448A1 (en) | Method and apparatus for information recommendation, electronic device, computer readable storage medium and computer program product | |
CN113302634B (zh) | 学习和预测关键短语以及生成预测的系统、介质和方法 | |
CN110838020B (zh) | 基于向量迁移的推荐方法、装置、计算机设备及存储介质 | |
US20230025317A1 (en) | Text classification model training method, text classification method, apparatus, device, storage medium and computer program product | |
CN105279146B (zh) | 针对短不相关文本的检测的上下文感知方法 | |
CN112418292B (zh) | 一种图像质量评价的方法、装置、计算机设备及存储介质 | |
CN110598070B (zh) | 应用类型识别方法及装置、服务器及存储介质 | |
CN111897934B (zh) | 问答对生成方法及装置 | |
US10909145B2 (en) | Techniques for determining whether to associate new user information with an existing user | |
US11874798B2 (en) | Smart dataset collection system | |
CN116700839B (zh) | 一种任务处理方法、装置、设备、存储介质及程序产品 | |
US20220083907A1 (en) | Data generation and annotation for machine learning | |
CN113761375A (zh) | 基于神经网络的消息推荐方法、装置、设备及存储介质 | |
CN113515625A (zh) | 测试结果分类模型训练方法、分类方法及装置 | |
CN117609887B (zh) | 数据增强模型训练及数据处理方法、装置、设备、介质 | |
CN112861474B (zh) | 一种信息标注方法、装置、设备及计算机可读存储介质 | |
CN112818658B (zh) | 文本对分类模型的训练方法、分类方法、设备及存储介质 | |
CN117648576B (zh) | 数据增强模型训练及数据处理方法、装置、设备、介质 | |
CN109885647B (zh) | 用户履历验证方法、装置、电子设备及存储介质 | |
CN115482019A (zh) | 一种活动关注度预测方法、装置、电子设备和存储介质 | |
CN116501993B (zh) | 房源数据推荐方法及装置 | |
US11868737B2 (en) | Method and server for processing text sequence for machine processing task | |
WO2023030932A1 (en) | Iterative training of computer model for machine learning | |
CN117440041A (zh) | 静默服务信息推送方法、装置、计算机设备和存储介质 | |
Yin | Personalized advertisement push method based on semantic similarity and data mining |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |