CN112116025A - 用户分类模型的训练方法、装置、电子设备及存储介质 - Google Patents

用户分类模型的训练方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN112116025A
CN112116025A CN202011042272.XA CN202011042272A CN112116025A CN 112116025 A CN112116025 A CN 112116025A CN 202011042272 A CN202011042272 A CN 202011042272A CN 112116025 A CN112116025 A CN 112116025A
Authority
CN
China
Prior art keywords
network
sample data
loss
feature
domain sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011042272.XA
Other languages
English (en)
Inventor
李振鹏
姜佳男
郭玉红
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Didi Infinity Technology and Development Co Ltd
Original Assignee
Beijing Didi Infinity Technology and Development Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Didi Infinity Technology and Development Co Ltd filed Critical Beijing Didi Infinity Technology and Development Co Ltd
Priority to CN202011042272.XA priority Critical patent/CN112116025A/zh
Publication of CN112116025A publication Critical patent/CN112116025A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明实施例公开了一种用户分类模型的训练方法、装置、电子设备及存储介质,通过获取目标域数据集、源域数据集以及随机噪声,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定用户分类模型,由此,发明实施例通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。

Description

用户分类模型的训练方法、装置、电子设备及存储介质
技术领域
本发明涉及数据处理技术领域,更具体地,涉及一种用户分类模型的训练方法、装置、电子设备及存储介质。
背景技术
在数据处理领域,当有大量的类别标签时,使用有监督的深度神经网络可以有效地进行分类模型的训练和预测,但是,在实际应用中,对类别标签的标注需要大量的人力资源和时间。并且,在有新的数据集时,需要重新花费大量的人力和时间去标注新的数据集,这显然不是高效的方法。目前,存在着大量的已标注好的数据,和新的数据集具有相同的标签分布,因此,如何使用迁移学习利用已标注好的数据对新的具有较少类别标签数据集进行预测具有重要的意义。
发明内容
有鉴于此,本发明实施例提供一种用户分类模型的训练方法、装置、电子设备及存储介质,以通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。
第一方面,本发明实施例提供一种用户分类模型的训练方法,所述方法包括:
获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;
将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;
响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型;
其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
可选的,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
可选的,根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络包括:
将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布;
根据所述特征分布确定所述域判别网络对应的对抗损失;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值;
根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失;
根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失;
根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
可选的,根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
可选的,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络;
所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量;
所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量;
所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。
可选的,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
可选的,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络还包括:
根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失;
根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失;
根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失;
根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
可选的,根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
第二方面,本发明实施例提供一种用户分类模型的训练装置,所述装置包括
数据获取单元,被配置为获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;
掩码数据确定单元,被配置为将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;
训练单元,被配置为根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;
模型确定单元,被配置为响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型;
其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
可选的,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
可选的,训练单元包括:
第一处理子单元,被配置为将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量;
第二处理子单元,被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布;
对抗损失确定子单元,被配置为根据所述特征分布确定所述域判别网络对应的对抗损失;
第三处理子单元,被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值;
分类损失确定子单元,被配置为根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失;
对比损失确定子单元,被配置为根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失;
第一参数调节子单元,被配置为根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
可选的,第一参数调节子单元包括:
第一调节模块,被配置为保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
第二调节模块,被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
可选的,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络;
所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量;
所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量;
所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。
可选的,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
可选的,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失;
训练单元还包括:
第一自编码损失确定子单元,被配置为根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失;
第二自编码损失确定子单元,被配置为根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失;
特征损失确定子单元,被配置为根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失;
第二参数调节子单元,被配置为根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
可选的,第二参数调节子单元包括:
第三参数调节模块,被配置为第一调节子保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
第四参数调节模块,被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
第三方面,本发明实施例提供一种电子设备,包括存储器和处理器,所述存储器用于存储一条或多条计算机程序指令,其中,所述一条或多条计算机程序指令被所述处理器执行以实现如上所述的方法。
第四方面,本发明实施例提供一种计算机可读存储介质,其上存储计算机程序指令,所述计算机程序指令在被处理器执行时以实现如上所述的方法。
本发明实施例通过获取目标域数据集、源域数据集以及随机噪声,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定用户分类模型,由此,发明实施例通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。
附图说明
通过以下参照附图对本发明实施例的描述,本发明的上述以及其它目的、特征和优点将更为清楚,在附图中:
图1是本发明实施例的用户分类模型的训练方法的流程图;
图2是本发明实施例的用户分类模型的各网络训练过程的流程图;
图3是本发明实施例的一种各网络参数调节的方法流程图;
图4是本发明实施例的另一种各网络参数调节的方法流程图;
图5是本发明实施例的用户分类模型的训练过程示意图;
图6是本发明实施例的用户分类方法的流程图;
图7是本发明实施例的用户分类模型的训练装置的示意图;
图8是本发明实施例的电子设备的示意图。
具体实施方式
以下基于实施例对本发明进行描述,但是本发明并不仅仅限于这些实施例。在下文对本发明的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本发明。为了避免混淆本发明的实质,公知的方法、过程、流程、元件和电路并没有详细叙述。
此外,本领域普通技术人员应当理解,在此提供的附图都是为了说明的目的,并且附图不一定是按比例绘制的。
除非上下文明确要求,否则在说明书的“包括”、“包含”等类似词语应当解释为包含的含义而不是排他或穷举的含义;也就是说,是“包括但不限于”的含义。
在本发明的描述中,需要理解的是,术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性。此外,在本发明的描述中,除非另有说明,“多个”的含义是两个或两个以上。
图1是本发明实施例的用户分类模型的训练方法的流程图。如图1所示,本发明实施例的用户分类模型的训练方法包括以下步骤:
步骤S110,获取目标域数据集、源域数据集以及随机噪声。其中,源域数据集包括具有标签的多个源域样本数据,目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签。源域样本数据包括对应用户在预定时间范围内的历史任务记录,目标域样本数据中不存在预定时间范围内的历史任务记录。可选的,标签用于表征用户在未来预定时间范围内的任务状态,任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
以网约车为例,在一段时间内(例如1个月内)未使用网约车服务的用户可以定义为新用户,在一段时间内使用了网约车服务的用户可以定义为活跃用户,由此,可以采用数据库中的各用户的用户信息和历史任务记录训练分类模型,以预测新用户在未来预定时间范围内是否为使用网约车服务,进而可以基于此确定对应用户的推荐信息等,应理解,本实施例并不限制应用领域,其他需要预测新用户是否在未来预定时间段内使用对应服务的应用领域均可采用本实施例的用户分类方法,例如外卖领域、快递领域等。可选的,在本实施例中,新用户并不仅代表从未使用过对应服务的用户,还代表虽然使用过,但在最近较长时间内未使用该对应服务的用户。
传统的分类模型使用新用户的完整数据和大量标签可以训练出有效的模型对新用户数据进行分类。然而,在实际情况下很难获得新用户的大量带标签的数据,同时也无法得到所有的特征数据,比如用户在注册时上传的个人信息会略过一些无关的选项,这些缺失的特征数据会影响分类模型的分类精度。而活跃用户通常在使用过程中通常会不断完善个人信息等数据,因此在实际情况下可以采集到大量的活跃用户数据。由此,本实施例采用迁移学习来训练源域数据集(也即活跃用户数据集),以对目标用户类别进行预测,从而减少标注标签带来的时间和人力的浪费,并且可以减少新用户中缺失数据对用户分类模型训练的影响。在本实施例中,采用域适应方法(Domain Adaptation)通过从有标注的源域数据集(例如活跃用户数据集)中学习知识,并迁移到有少量标签的目标域数据(例如新用户数据集)中,减少源域数据集和目标域数据集的特征分布差异。其中,本实施例通过随机噪声掩盖源域数据集的部分特征,以减小源域数据和目标域数据的特征分布差异,从而提高训练后的用户分类模型的准确性。
步骤S120,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集。在本实施例中,通过掩码生成网络将随机噪声映射为对应的掩码,与源域样本数据进行组合,以掩盖源域样本数据的部分特征数据,从而可以增加用户分类模型对噪声和目标样本数据缺失特征数据的宽容性,减少目标域样本数据缺失特征数据对模型训练的影响。
步骤S130,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络。其中,特征生成网络用于获取目标域样本数据和源域样本数据的特征向量,分类网络用于根据目标域样本数据和源域样本数据的特征向量、源域样本数据的标签、部分目标域样本数据的标签确定对应的标签预测值,域判别网络用于使得目标域数据集与所述源域数据集进行特征对齐。
图2是本发明实施例的用户分类模型的各网络训练过程的流程图。在一种可选的实现方式中,如图2所示,步骤S130包括:
步骤S131,将目标域数据集、源域数据集和掩码数据集输入至特征生成网络进行处理,确定各目标域样本数据和源域样本数据的特征向量。在一种可选的实现方式中,特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络。第一特征生成子网络用于根据源域样本数据和对应的掩码数据生成该源域样本数据的特征向量。可选的,将源域样本数据与对应的掩码数据相乘获得该源域样本数据的掩码数据特征,将该源域样本数据的掩码数据特征输入至第一特征生成子网络,获取该源域样本数据的特征向量。第二特征生成子网络用于根据目标域样本数据生成目标域样本数据的特征向量。可选的,第一特征生成子网络和第二特征生成子网络输出的特征向量维度为1024维,应理解,本实施例并不对此进行限制。第三特征生成子网络用于对源域样本数据的特征向量和目标域样本数据的特征向量进行特征处理,获取预定维度的源域样本数据的特征向量和目标域样本数据的特征向量。可选的,第三特征生成子网络输出的特征向量的维度为256维,应理解,本实施例并不对此进行限制。在一种可选的实现方式中,特征生成网络包括多个第三特征生成子网络,其中,多个第三特征生成子网络共享权值,由此,可以提高数据处理效率。
步骤S132,将各目标域样本数据和源域样本数据的特征向量输入至域判别网络中进行处理,确定对应的特征分布。
步骤S133,根据上述目标域样本数据和源域样本数据的特征分布确定域判别网络对应的对抗损失Ladv
Figure BDA0002707000220000101
其中,
Figure BDA0002707000220000102
为第i个源域样本数据,ε为随机噪声,Gm()为掩码生成网络的输出,fm()表征第i个源域样本数据对应的掩码数据特征,Gs()为第一特征生成子网络的输出,G()为第三特征生成子网络的输出,ns为源域样本数据的数量,
Figure BDA0002707000220000103
为第i个目标域样本数据,nt为目标域样本数据的数量,Gt()为第二特征生成子网络的输出,
Figure BDA0002707000220000104
为第i个源域样本数据的域判别网络输出,
Figure BDA0002707000220000105
为第i各目标域样本数据的域判别网络输出。
步骤S134,将各目标域样本数据和源域样本数据的特征向量输入至分类网络,确定各目标域样本数据和源域样本数据的标签预测值。可选的,假设第一标签为用户在未来预定时间范围内会执行任务,第二标签为用户在未来预定时间范围内不会执行任务,则目标域样本数据的标签预测值包括该目标域样本数据的标签为第一标签的概率、以及该目标域样本数据的标签为第二标签的概率。同理,源域样本数据的标签预测值包括该源域样本数据的标签为第一标签的概率、以及该源域样本数据的标签为第二标签的概率。可选的,标签预测值为一个向量。
步骤S135,根据具有标签的目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定分类网络对应的分类损失Lc
Figure BDA0002707000220000111
Figure BDA0002707000220000112
Figure BDA0002707000220000113
其中,ns为源域样本数据的数量,nt-Iabeled为具有标签的目标域样本数据的数量,
Figure BDA0002707000220000114
为第i个源域样本数据的标签向量,T为向量的转置运算,
Figure BDA0002707000220000115
为第i个源域样本数据对应的标签预测值,
Figure BDA0002707000220000116
为第i个目标域样本数据的标签向量,
Figure BDA0002707000220000117
为第i个目标域域样本数据对应的标签预测值,
Figure BDA0002707000220000118
为第i个源域样本数据,ε为随机噪声,Gm()为掩码生成网络的输出,fm()表征第i个源域样本数据对应的掩码数据特征,Gs()为第一特征生成子网络的输出,G()为第三特征生成子网络的输出,F为分类网络的输出,ns为源域样本数据的数量,
Figure BDA0002707000220000119
为第i个目标域样本数据,Gt()为第二特征生成子网络的输出。
步骤S136,根据具有标签的目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失Lcontras。在一种可选的实现方式中,步骤S136具体可以为:采用欧式距离、余弦距离等方法计算特征向量对之间的相似度,根据特征向量对之间的相似度以及特征向量的标签,确定对应的对比损失。其中,特征向量对可以包括两个具有标签的源域样本数据的特征向量,也可以包括两个具有标签的目标域样本数据的特征向量,还可以包括一个具有标签的源域样本数据的特征向量和一个具有标签的目标域样本数据的特征向量。
在本实施例中,为了解决不同域之间的类内多样性和类间相似性问题,本实施例采用对比损失函数来约束跨域的类内距离和类间距离。对于一个特征向量对(Ga,Gb),设置其指示标签lab,当特征向量对表征的两个样本数据同类别时,lab=+1,当特征向量对表征的两个样本数据不同类别时lab=-1。其中,Ga和Gb可以为源域数据集中的两个具有标签源域样本数据的特征向量,也可以为目标域数据集中的两个具有标签目标域样本数据的特征向量,还可以一个为具有标签的源域样本数据的特征向量,另一个为具有标签的目标域数据集的特征向量。在一种可选的实现方式中,计算特征向量对之间的欧式距离来计算特征向量对之间的相似度,应理解,也可以通过计算余弦距离等相似度计算方法来计算特征向量对之间的相似度,本实施例并不对此进行限制。由此,在本实施例中,对比损失Lcontras为:
Lcontras=max(0,τ+lab(||Ga-Gb||2))
其中,Ga与Gb为具有标签的源域样本数据的特征向量或具有标签的目标域样本数据的特征向量,lab为特征向量对(Ga,Gb)的指示标签,||Ga-Gb||2为特征向量对(Ga,Gb)的欧式距离,τ为自定义阈值,用于控制不同类特征向量的距离边缘。应理解,在本实施例中,特征向量Ga和Gb可以为第一特征生成子网络输出的特征向量,或第二特征生成子网络输出的特征向量,也可以均为第三特征子网络输出的特征向量,本实施例并不对此进行限制。
步骤S137,根据对抗损失、分类损失和对比损失调节特征生成网络、分类网络和域判别网络的参数。
图3是本发明实施例的一种各网络参数调节的方法流程图。在一种可选的实现方式中,步骤S137可以包括:
步骤S137A:保持所述特征生成网络和分类网络的参数,调节域判别网络的参数以最大化所述对抗损失。也就是说,先使得特征生成网络和分类网络的参数不变,通过调节域判别网络的参数使得对抗损失Ladv最大化。
步骤S137B:保持域判别网络的参数,调节特征生成网络和分类网络的参数以最小化分类损失、对比损失和对抗损失。也就是说,在调节域判别网络的参数后,使得域判别网络的参数不变,通过调节特征生成网络和分类网络的参数,使得对抗损失Ladv、分类损失LC和对比损失Lcontras最小化,或者使得对抗损失Ladv、分类损失LC和对比损失Lcontras分别收敛到对应的预设值。
在本实施例中,保持特征生成网络和分类网络的参数不变,调节域判别网络的参数,使得对抗损失Ladv最大化,保持域判别网络的参数不变,调节特征生成网络和分类网络的参数使得对抗损失Ladv最小化,由此,迭代执行步骤S137,使得对抗损失Ladv在一定范围内上下波动、并使得分类损失LC和对比损失Lcontras最小化。
本实施例通过交替调节域判别网络的参数、以及特征生成网络和分类网络的参数,以交替优化对应的各损失,从而使得源域样本数据和目标域样本数据可以具有相同或基本相同的特征分布,由此可以更准确地将从源域数据集中学习到的知识迁移到目标域数据集,从而可以提高训练完成的用户分类模型的分类准确度。
在一种可选的实现方式中,如上所述,特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络。可选的,为了减少第一特征生成子网络和第二特征生成子网络对源域样本数据和目标域样本数据降维时特征内容的损失以及缺失数据对模型的影响,本实施例采用自编码模式对第一特征生成子网络和第二特征生成子网络进行约束。由此,在本实施例中,将第一特征生成子网络和第二特征生成子网络定义为第一编码器和第二编码器,并配置对应的第一解码器和第二解码器。也即,在本实施例中,第一特征生成子网络具有对应的第一解码器和第一自编码损失,第二特征生成子网络具有对应的第二解码器和第二自编码损失。
在一种可选的实现方式中,步骤S130还包括:确定特征生成网络的损失,根据特征生成网络的损失、对抗损失、分类损失和对比损失调节特征生成网络、分类网络和域判别网络的参数。
可选的,根据第一特征生成子网络的输入值和第一解码器的输出值计算第一自编码损失,根据第二特征生成子网络的输入值和第二解码器的输出值计算第二自编码损失,根据第一自编码损失和第二自编码损失确定特征生成网络的损失。可选的,特征生成网络对应的自编码损失函数LAE
Figure BDA0002707000220000141
其中,Gs()为第一特征生成子网络的输出,Gt()为第二特征生成子网络的输出,Des()为第一解码器的输出,Det()为第二解码器的输出,ns为源域样本数据的数量,nt为目标域样本数据的数量,
Figure BDA0002707000220000142
表征第i个源域样本数据对应的掩码数据特征(也即第一特征生成子网络的输入值),
Figure BDA0002707000220000143
为第i个目标域样本数据(也即第二特征生成子网络输入值),
Figure BDA0002707000220000144
表征第一自编码损失,
Figure BDA0002707000220000145
表征第二自编码损失。
图4是本发明实施例的另一种各网络参数调节的方法流程图。在一种可选的实现方式中,根据特征生成网络的损失、对抗损失、分类损失和对比损失调节特征生成网络、分类网络和域判别网络的参数包括:
步骤S137C:保持特征生成网络和分类网络的参数,调节域判别网络的参数以最大化所述对抗损失。也就是说,先使得特征生成网络和分类网络的参数不变,通过调节域判别网络的参数使得对抗损失Ladv最大化。
步骤S137D:保持域判别网络的参数,调节特征生成网络和分类网络的参数以最小化特征生成网络的损失、分类损失、对比损失和对抗损失。也就是说,在调节域判别网络的参数后,使得域判别网络的参数不变,通过调节特征生成网络和分类网络的参数,使得特征生成网络的损失LAE、对抗损失Ladv、分类损失LC和对比损失Lcontras最小化,或者使得特征生成网络的损失LAE、对抗损失Ladv、分类损失LC和对比损失Lcontras分别收敛到对应的预设值。
在本实施例中,保持特征生成网络和分类网络的参数不变,调节域判别网络的参数,使得对抗损失Ladv最大化,保持域判别网络的参数不变,调节特征生成网络和分类网络的参数使得对抗损失Ladv最小化,由此,迭代执行步骤S137,使得对抗损失Ladv在一定范围内上下波动、并使得特征生成网络的损失LAE、分类损失LC和对比损失Lcontras最小化。
本实施例通过交替调节域判别网络的参数、以及特征生成网络和分类网络的参数,以交替优化对应的各损失,从而使得源域样本数据和目标域样本数据可以具有相同或基本相同的特征分布,由此可以更准确地将从源域数据集中学习到的知识迁移到目标域数据集,从而可以提高训练完成的用户分类模型的分类准确度。
步骤S140,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型。也就是说,响应于特征生成网络、分类网络以及域判别网络对应的各损失达到对应的收敛值,确定用户分类模型训练完成。可选的,用户分类模型对应的各损失包括特征生成网络的损失、分类损失、对抗损失和对比损失,或者用户分类模型对应的各损失包括分类损失、对抗损失和对比损失。
本发明实施例通过获取目标域数据集、源域数据集以及随机噪声,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定用户分类模型,由此,发明实施例通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。
图5是本发明实施例的用户分类模型的训练过程示意图。如图5所示,获取目标域数据集xt、源域数据集xs以及随机噪声ε。其中,源域数据集xs包括具有标签的多个源域样本数据,目标域数据集xt包括多个目标域样本数据,其中部分目标域样本数据具有标签。源域样本数据包括对应用户在预定时间范围内的历史任务记录,目标域样本数据中不存在预定时间范围内的历史任务记录。可选的,标签用于表征用户在未来预定时间范围内的任务状态,任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
在本实施例中,将源域数据集xs和随机噪声ε输入至掩码生成网络Gm进行处理,获取掩码数据集Gm(xs,ε),将源域数据集xs和掩码数据集Gm(xs,ε)输入至掩码特征数据生成模块fm,以获取各源域样本数据的掩码特征数据,将各源域样本数据的掩码特征数据输入至第一特征生成子网络Gs中进行处理,获取第一维度的各源域样本数据的特征向量。在本实施例中,将目标域样本数据集xt输入至第二特征生成子网络Gt进行处理,获取第一维度的各目标域样本数据的特征向量,将第一维度的各源域样本数据的特征向量和各目标域样本数据的特征向量输入至第三特征生成子网络G中进行处理,获取第二维度(上述预定维度)的各源域样本数据和各目标域样本数据的特征向量。可选的,第一维度为1024维,第二维度为256维,应理解,本实施例并不对此进行限制。可选的,本实施例的特征生成网络包括多个第三特征生成子网络G,多个第三特征生成子网络G权值共享,由此,可以将第一维度的各源域样本数据的特征向量和各目标域样本数据的特征向量输入至不同的第三特征子网络G中进行处理,以提高数据处理效率。本实施例通过掩码生成网络Gm将随机噪声映射为对应的掩码,与源域样本数据进行组合,以掩盖源域样本数据的部分特征数据,从而可以增加用户分类模型对噪声和目标样本数据缺失特征数据的宽容性,减少目标域样本数据缺失特征数据对模型训练的影响。
在本实施例中,为了减少第一特征生成子网络和第二特征生成子网络对源域样本数据和目标域样本数据降维时特征内容的损失以及缺失数据对模型的影响,本实施例采用自编码模式对第一特征生成子网络和第二特征生成子网络进行约束。如图5所示,在本实施例中,将第一特征生成子网络和第二特征生成子网络定义为第一编码器和第二编码器,并配置对应的第一解码器Des和第二解码器Det。也即,在本实施例中,第一特征生成子网络具有对应的第一解码器Des和第一自编码损失
Figure BDA0002707000220000161
第二特征生成子网络具有对应的第二解码器Det和第二自编码损失
Figure BDA0002707000220000162
由此,本实施例可以采用特征生成网络对应的自编码损失函数LAE(包括第一自编码损失
Figure BDA0002707000220000163
和第二自编码损失
Figure BDA0002707000220000164
)对用户分类模型中的各网络进行调参以训练该用户分类模型。
在一种可选的实现方式中,为了解决不同域之间的类内多样性和类间相似性问题,本实施例采用对比损失函数来约束跨域的类内距离和类间距离。对于一个特征向量对(Ga,Gb),设置其指示标签lab,当特征向量对表征的两个样本数据同类别时,lab=+1,当特征向量对表征的两个样本数据不同类别时lab=-1。其中,Ga和Gb可以为源域数据集中的两个具有标签源域样本数据的特征向量,也可以为目标域数据集中的两个具有标签源域样本数据的特征向量,还可以一个为具有标签的源域样本数据的特征向量,另一个为具有标签的目标域数据集的特征向量。其中,Ga和Gb可以为第一维度的特征向量(也即Gs或Gt输出的特征向量),也可以均为第二维度的特征向量(也即G输出的特征向量),本实施例并不对此进行限制。在一种可选的实现方式中,计算特征向量对之间的欧式距离来计算特征向量对之间的相似度,应理解,也可以通过计算余弦距离等相似度计算方法来计算特征向量对之间的相似度,本实施例并不对此进行限制。由此,在本实施例中,对比损失Lcontras为:
Lcontras=max(0,τ+lab(||Ga-Gb||2))
其中,Ga与Gb为具有标签的源域样本数据的特征向量或具有标签的目标域样本数据的特征向量,lab为特征向量对(Ga,Gb)的指示标签,||Ga-Gb||2为特征向量对(Ga,Gb)的欧式距离,τ为自定义阈值,用于控制不同类特征向量的距离边缘。
在本实施例中,将第二维度的各目标域样本数据和源域样本数据的特征向量输入至域判别网络中进行处理,确定对应的特征分布,根据该特征分布确定域判别网络的对抗损失Ladv。其中,对抗损失Ladv的计算方法如上所述,在此不再赘述。
在本实施例中,将第二维度的各目标域样本数据和源域样本数据的特征向量输入至分类网络F中进行处理,确定各目标域样本数据和源域样本数据的标签预测值,并根据具有标签的目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定分类网络对应的分类损失Lc。在一种可选的实现方式中,本实施例的用户分类模型包括多个分类网络F,多个分类网络F共享权值,由此,可以将第二维度的各目标域样本数据和源域样本数据的特征向量输入至不同的分类网络F中进行处理,以提高数据处理效率。其中,分类损失Lc的计算方法如上所述,在此不再赘述。可选的,假设第一标签为用户在未来预定时间范围内会执行任务,第二标签为用户在未来预定时间范围内不会执行任务,则目标域样本数据的标签预测值包括该目标域样本数据的标签为第一标签的概率、以及该目标域样本数据的标签为第二标签的概率。同理,源域样本数据的标签预测值包括该源域样本数据的标签为第一标签的概率、以及该源域样本数据的标签为第二标签的概率。可选的,标签预测值为一个向量。
在本实施例中,在训练用户分类模型过程中,首先保持特征生成网络(包括Gs、Gt和G)和分类网络F的参数,调节域判别网络D的参数以最大化对抗损失Ladv,再保持域判别网络D的参数,调节特征生成网络和分类网络F的参数以最小化特征生成网络的损失LAE、分类损失Lc、对比损失Lcontras和对抗损失Ladv。也就是说,先使得特征生成网络和分类网络F的参数不变,通过调节域判别网络D的参数使得对抗损失Ladv最大化,在调节域判别网络D的参数后,使得域判别网络D的参数不变,通过调节特征生成网络和分类网络F的参数,使得特征生成网络的损失LAE、对抗损失Ladv、分类损失LC和对比损失Lcontras最小化,或者使得特征生成网络的损失LAE、对抗损失Ladv、分类损失LC和对比损失Lcontras分别收敛到对应的预设值。
由此,本实施例通过保持特征生成网络和分类网络F的参数不变,调节域判别网络D的参数,使得对抗损失Ladv最大化,保持域判别网络的D参数不变,调节特征生成网络和分类网络F的参数使得对抗损失Ladv最小化,由此,迭代执行上述步骤,使得对抗损失Ladv在一定范围内上下波动、并使得特征生成网络的损失LAE、分类损失LC和对比损失Lcontras最小化,以确定本实施例的用户分类模型。
本实施例通过交替调节域判别网络的参数、以及特征生成网络和分类网络的参数,以交替优化对应的各损失,从而使得源域样本数据和目标域样本数据可以具有相同或基本相同的特征分布,由此可以更准确地将从源域数据集中学习到的知识迁移到目标域数据集,从而可以提高训练完成的用户分类模型的分类准确度。
本发明实施例通过获取目标域数据集、源域数据集以及随机噪声,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定用户分类模型,由此,发明实施例通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。
图6是本发明实施例的用户分类方法的流程图。本实施例采用上述实施例训练获得的用户分类模型对用户进行分类,以网约车应用场景为例,采用上述训练获得的用户分类模型预测用户在未来预定时间段内是否使用网约车服务,由此可以基于此确定对应用户的推荐信息等。例如,假设用户的预测结果为在未来预定时间内会使用网约车服务,可以向用户推荐网约车相关信息等。如图6所示,本实施例的用户分类方法包括以下步骤:
步骤S210,获取目标用户的用户信息,用户信息可以包括历史行为记录和帐号信息等。
步骤S220,将目标用户的用户信息输入至预先训练的用户分类模型中进行处理,确定该目标用户的类别。
步骤S230,根据该目标用户的类别确定推荐信息。
本发明实施例通过获取目标用户的用户信息,并将目标用户的用户信息输入至预先训练的用户分类模型中进行处理,确定该目标用户的类别,根据该目标用户的类别确定推荐信息,其中,本实施例中的用户分类模型在训练过程中通过随机噪声掩盖源域数据集的部分特征,减小了源域数据集与目标域数据集特征分布差异,从而提高了用户分类模型的准确性,由此,本实施例可以更准确地确定目标用户的类别,进而更准确地向目标用户推荐对应的信息。
图7是本发明实施例的用户分类模型的训练装置的示意图。如图7所示,本发明实施例的用户分类模型的训练装置7包括数据获取单元71、掩码数据确定单元72、训练单元73和模型确定单元74。
数据获取单元71被配置为获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录。在一种可选的实现方式中,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
掩码数据确定单元72被配置为将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集。训练单元73被配置为根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络。模型确定单元74被配置为响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型。其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
在一种可选的实现方式中,训练单元73包括第一处理子单元731、第二处理子单元732、对抗损失确定子单元733、第三处理子单元734、分类损失确定子单元735、对比损失确定子单元736和第一参数调节子单元737。
第一处理子单元731被配置为将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量。第二处理子单元732被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布。对抗损失确定子单元733被配置为根据所述特征分布确定所述域判别网络对应的对抗损失。第三处理子单元734被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值。分类损失确定子单元735被配置为根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失。对比损失确定子单元736被配置为根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失。第一参数调节子单元737被配置为根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
在一种可选的实现方式中,第一参数调节子单元737包括第一调节模块7371和第二调节模块7372。第一调节模块7371被配置为保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失。第二调节模块7372被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
在一种可选的实现方式中,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络。所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量。所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量。所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。可选的,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
可选的,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失。训练单元73还包括第一自编码损失确定子单元738、第二自编码损失确定子单元739、特征损失确定子单元73A和第二参数调节子单元73B。第一自编码损失确定子单元738被配置为根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失。第二自编码损失确定子单元739被配置为根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失。特征生成网络确定子单元73A被配置为根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失。第二参数调节子单元73B被配置为根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
在一种可选的实现方式中,第二参数调节子单元73B包括第三参数调节模块73B1和第四参数调节模块73B2。第三参数调节模块73B1被配置为第一调节子保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失。第四参数调节模块73B2被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
本发明实施例通过获取目标域数据集、源域数据集以及随机噪声,将各源域样本数据和随机噪声输入至用户分类模型的掩码生成网络,确定源域数据集对应的掩码数据集,根据目标域数据集、源域数据集和掩码数据集训练用户分类模型中的特征生成网络、分类网络和域判别网络,响应于特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定用户分类模型,由此,发明实施例通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。
图8是本发明实施例的电子设备的示意图。如图8所示,图8所示的电子设备为通用地址查询装置,其包括通用的计算机硬件结构,其至少包括处理器81和存储器82。处理器81和存储器82通过总线83连接。存储器82适于存储处理器81可执行的指令或程序。处理器81可以是独立的微处理器,也可以是一个或者多个微处理器集合。由此,处理器81通过执行存储器82所存储的指令,从而执行如上所述的本发明实施例的方法流程实现对于数据的处理和对于其它装置的控制。总线83将上述多个组件连接在一起,同时将上述组件连接到显示控制器84和显示装置以及输入/输出(I/O)装置85。输入/输出(I/O)装置85可以是鼠标、键盘、调制解调器、网络接口、触控输入装置、体感输入装置、打印机以及本领域公知的其他装置。典型地,输入/输出装置85通过输入/输出(I/O)控制器86与系统相连。
本领域的技术人员应明白,本申请的实施例可提供为方法、装置(设备)或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可读存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品。
本申请是参照根据本申请实施例的方法、装置(设备)和计算机程序产品的流程图来描述的。应理解可由计算机程序指令实现流程图中的每一流程。
这些计算机程序指令可以存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现流程图一个流程或多个流程中指定的功能。
也可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程中指定的功能的装置。
本发明的另一实施例涉及一种非易失性存储介质,用于存储计算机可读程序,所述计算机可读程序用于供计算机执行上述部分或全部的方法实施例。
即,本领域技术人员可以理解,实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,该程序存储在一个存储介质中,包括若干指令用以使得一个设备(可以是单片机,芯片等)或处理器(processor)执行本申请各实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本发明的优选实施例,并不用于限制本发明,对于本领域技术人员而言,本发明可以有各种改动和变化。凡在本发明的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (18)

1.一种用户分类模型的训练方法,其特征在于,所述方法包括:
获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;
将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;
响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型;
其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
2.根据权利要求1所述的方法,其特征在于,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
3.根据权利要求1所述的方法,其特征在于,根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络包括:
将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布;
根据所述特征分布确定所述域判别网络对应的对抗损失;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值;
根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失;
根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失;
根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
4.根据权利要求3所述的方法,其特征在于,根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
5.根据权利要求3所述的方法,其特征在于,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络;
所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量;
所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量;
所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。
6.根据权利要求5所述的方法,其特征在于,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
7.根据权利要求5所述的方法,其特征在于,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络还包括:
根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失;
根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失;
根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失;
根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
8.根据权利要求7所述的方法,其特征在于,根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
9.一种用户分类模型的训练装置,其特征在于,所述装置包括
数据获取单元,被配置为获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;
掩码数据确定单元,被配置为将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;
训练单元,被配置为根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;
模型确定单元,被配置为响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型;
其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
10.根据权利要求9所述的装置,其特征在于,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
11.根据权利要求9所述的装置,其特征在于,训练单元包括:
第一处理子单元,被配置为将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量;
第二处理子单元,被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布;
对抗损失确定子单元,被配置为根据所述特征分布确定所述域判别网络对应的对抗损失;
第三处理子单元,被配置为将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值;
分类损失确定子单元,被配置为根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失;
对比损失确定子单元,被配置为根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失;
第一参数调节子单元,被配置为根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
12.根据权利要求11所述的装置,其特征在于,第一参数调节子单元包括:
第一调节模块,被配置为保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
第二调节模块,被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
13.根据权利要求11所述的装置,其特征在于,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络;
所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量;
所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量;
所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。
14.根据权利要求13所述的装置,其特征在于,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
15.根据权利要求13所述的装置,其特征在于,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失;
训练单元还包括:
第一自编码损失确定子单元,被配置为根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失;
第二自编码损失确定子单元,被配置为根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失;
特征损失确定子单元,被配置为根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失;
第二参数调节子单元,被配置为根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
16.根据权利要求15所述的装置,其特征在于,第二参数调节子单元包括:
第三参数调节模块,被配置为第一调节子保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
第四参数调节模块,被配置为保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
17.一种电子设备,包括存储器和处理器,其特征在于,所述存储器用于存储一条或多条计算机程序指令,其中,所述一条或多条计算机程序指令被所述处理器执行以实现如权利要求1-8中任一项所述的方法。
18.一种计算机可读存储介质,其上存储计算机程序指令,其特征在于,所述计算机程序指令在被处理器执行时以实现如权利要求1-8中任一项所述的方法。
CN202011042272.XA 2020-09-28 2020-09-28 用户分类模型的训练方法、装置、电子设备及存储介质 Pending CN112116025A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011042272.XA CN112116025A (zh) 2020-09-28 2020-09-28 用户分类模型的训练方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011042272.XA CN112116025A (zh) 2020-09-28 2020-09-28 用户分类模型的训练方法、装置、电子设备及存储介质

Publications (1)

Publication Number Publication Date
CN112116025A true CN112116025A (zh) 2020-12-22

Family

ID=73796850

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011042272.XA Pending CN112116025A (zh) 2020-09-28 2020-09-28 用户分类模型的训练方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN112116025A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112634048A (zh) * 2020-12-30 2021-04-09 第四范式(北京)技术有限公司 一种反洗钱模型的训练方法及装置
WO2022166578A1 (zh) * 2021-02-05 2022-08-11 北京嘀嘀无限科技发展有限公司 用于域自适应学习的方法、装置、设备、介质和产品
CN116502271A (zh) * 2023-06-21 2023-07-28 杭州金智塔科技有限公司 基于生成模型的隐私保护跨域推荐方法

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107392255A (zh) * 2017-07-31 2017-11-24 深圳先进技术研究院 少数类图片样本的生成方法、装置、计算设备及存储介质
CN108304876A (zh) * 2018-01-31 2018-07-20 国信优易数据有限公司 分类模型训练方法、装置及分类方法及装置
US20180240233A1 (en) * 2017-02-22 2018-08-23 Siemens Healthcare Gmbh Deep Convolutional Encoder-Decoder for Prostate Cancer Detection and Classification
CN109242029A (zh) * 2018-09-19 2019-01-18 广东省智能制造研究所 识别分类模型训练方法和系统
CN111028861A (zh) * 2019-12-10 2020-04-17 苏州思必驰信息科技有限公司 频谱掩码模型训练方法、音频场景识别方法及系统

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180240233A1 (en) * 2017-02-22 2018-08-23 Siemens Healthcare Gmbh Deep Convolutional Encoder-Decoder for Prostate Cancer Detection and Classification
CN107392255A (zh) * 2017-07-31 2017-11-24 深圳先进技术研究院 少数类图片样本的生成方法、装置、计算设备及存储介质
CN108304876A (zh) * 2018-01-31 2018-07-20 国信优易数据有限公司 分类模型训练方法、装置及分类方法及装置
CN109242029A (zh) * 2018-09-19 2019-01-18 广东省智能制造研究所 识别分类模型训练方法和系统
CN111028861A (zh) * 2019-12-10 2020-04-17 苏州思必驰信息科技有限公司 频谱掩码模型训练方法、音频场景识别方法及系统

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
王一鸣: ""基于生成对抗网络的图像修复算法研究"", 《中国优秀硕士学位论文全文数据库 信息科技辑》 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112634048A (zh) * 2020-12-30 2021-04-09 第四范式(北京)技术有限公司 一种反洗钱模型的训练方法及装置
WO2022166578A1 (zh) * 2021-02-05 2022-08-11 北京嘀嘀无限科技发展有限公司 用于域自适应学习的方法、装置、设备、介质和产品
CN116502271A (zh) * 2023-06-21 2023-07-28 杭州金智塔科技有限公司 基于生成模型的隐私保护跨域推荐方法
CN116502271B (zh) * 2023-06-21 2023-09-19 杭州金智塔科技有限公司 基于生成模型的隐私保护跨域推荐方法

Similar Documents

Publication Publication Date Title
US20220327714A1 (en) Motion Engine
CN112116025A (zh) 用户分类模型的训练方法、装置、电子设备及存储介质
CN108182389B (zh) 基于大数据与深度学习的用户数据处理方法、机器人系统
KR102504077B1 (ko) 이미지 기반의 captcha 과제
US10938927B2 (en) Machine learning techniques for processing tag-based representations of sequential interaction events
US11699095B2 (en) Cross-domain recommender systems using domain separation networks and autoencoders
KR20190101327A (ko) 구독 제품 가격 산정 방법 및 가격 산정 장치
US20210035183A1 (en) Method and system for a recommendation engine utilizing progressive labeling and user content enrichment
JP6062384B2 (ja) タスク割り当てサーバ、タスク割り当て方法およびプログラム
CN114463830B (zh) 亲缘关系判定方法、装置、电子设备及存储介质
Li et al. An uncertainty-based model of the effects of fixation on choice
CN112116024B (zh) 用户分类模型的方法、装置、电子设备和存储介质
Zhang et al. Statistical inference after adaptive sampling in non-markovian environments
Hung Robust Kalman filter based on a fuzzy GARCH model to forecast volatility using particle swarm optimization
CN110866609B (zh) 解释信息获取方法、装置、服务器和存储介质
CN117114139A (zh) 一种面向噪声标签的联邦学习方法
JP7015927B2 (ja) 学習モデル適用システム、学習モデル適用方法、及びプログラム
CN104598866B (zh) 一种基于人脸的社交情商促进方法及系统
CN112270571B (zh) 一种用于冷启动广告点击率预估模型的元模型训练方法
CN115631006A (zh) 智能推荐银行产品的方法、装置、存储介质及计算机设备
US20230027309A1 (en) System and method for image de-identification to humans while remaining recognizable by machines
Alikhani et al. DynaFuse: Dynamic Fusion for Resource Efficient Multi-Modal Machine Learning Inference
Yang et al. Collaborative Filtering Recommendation Algorithm Based on AdaBoost-Naïve Bayesian Algorithm
CN117912640B (zh) 基于域增量学习的抑郁障碍检测模型训练方法及电子设备
US11727273B2 (en) System improvement for deep neural networks

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
AD01 Patent right deemed abandoned
AD01 Patent right deemed abandoned

Effective date of abandoning: 20220125