CN114792173B - 预测模型训练方法和装置 - Google Patents
预测模型训练方法和装置 Download PDFInfo
- Publication number
- CN114792173B CN114792173B CN202210694769.2A CN202210694769A CN114792173B CN 114792173 B CN114792173 B CN 114792173B CN 202210694769 A CN202210694769 A CN 202210694769A CN 114792173 B CN114792173 B CN 114792173B
- Authority
- CN
- China
- Prior art keywords
- loss
- target
- user
- sample
- probability
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 58
- 238000000034 method Methods 0.000 title claims abstract description 36
- 238000012545 processing Methods 0.000 claims abstract description 18
- 239000013598 vector Substances 0.000 claims description 66
- 230000006399 behavior Effects 0.000 claims description 46
- 238000004364 calculation method Methods 0.000 claims description 11
- 230000003993 interaction Effects 0.000 claims description 6
- 230000002452 interceptive effect Effects 0.000 claims description 5
- 230000004931 aggregating effect Effects 0.000 claims description 4
- 238000006243 chemical reaction Methods 0.000 description 16
- 230000006870 function Effects 0.000 description 14
- 230000008569 process Effects 0.000 description 9
- 238000010586 diagram Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 5
- 230000009471 action Effects 0.000 description 2
- 238000004590 computer program Methods 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000003542 behavioural effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/04—Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F3/00—Input arrangements for transferring data to be processed into a form capable of being handled by the computer; Output arrangements for transferring data from processing unit to output unit, e.g. interface arrangements
- G06F3/01—Input arrangements or combined input and output arrangements for interaction between user and computer
- G06F3/048—Interaction techniques based on graphical user interfaces [GUI]
- G06F3/0484—Interaction techniques based on graphical user interfaces [GUI] for the control of specific functions or operations, e.g. selecting or manipulating an object, an image or a displayed text element, setting a parameter value or selecting a range
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/02—Marketing; Price estimation or determination; Fundraising
- G06Q30/0201—Market modelling; Market analysis; Collecting market data
- G06Q30/0202—Market predictions or forecasting for commercial activities
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/02—Marketing; Price estimation or determination; Fundraising
- G06Q30/0241—Advertisements
- G06Q30/0251—Targeted advertisements
- G06Q30/0269—Targeted advertisements based on user profile or attribute
- G06Q30/0271—Personalized advertisement
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/06—Buying, selling or leasing transactions
- G06Q30/0601—Electronic shopping [e-shopping]
- G06Q30/0631—Item recommendations
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Business, Economics & Management (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Strategic Management (AREA)
- General Engineering & Computer Science (AREA)
- Development Economics (AREA)
- Accounting & Taxation (AREA)
- Finance (AREA)
- Data Mining & Analysis (AREA)
- Economics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- General Business, Economics & Management (AREA)
- Marketing (AREA)
- Entrepreneurship & Innovation (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Game Theory and Decision Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Human Resources & Organizations (AREA)
- Operations Research (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Tourism & Hospitality (AREA)
- Quality & Reliability (AREA)
- Human Computer Interaction (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本说明书实施例提供一种训练预测模型的方法和装置,该预测模型包括第一分支和第二分支;根据该方法,首先获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示用户是否点击了目标对象;第二标签表示该用户是否实施与目标对象有关的目标行为。利用预测模型对样本特征进行处理,第一分支输出用户点击目标对象的第一概率;第二分支输出用户实施目标行为的第二概率。基于第一标签值和第一概率,确定第一损失。并且,在预设条件满足的情况下,根据第二标签值和第二概率确定第二损失,并根据第一损失和第二损失确定该目标样本的预测损失,其中预设条件包括,第一标签值指示用户点击了目标对象。于是可以根据上述预测损失,训练该预测模型。
Description
技术领域
本说明书一个或多个实施例涉及人工智能领域,尤其涉及一种预测用户行为的预测模型的训练方法及装置。
背景技术
点击后的转化率(CVR),是指用户在点击目标对象后,对其实施预定转化行为的比例,该预定转化行为可以包括购买,收藏,加购,转发等等。对CVR进行预测,是当前多种技术场景中一项重要而核心的预测任务。例如,在推荐系统中,可以将待推荐的备选物品的CVR预测值作为重要的排序因素,用来帮助平衡用户的点击行为和后续转化行为。
然而,在用机器学习的方式训练CVR预测模型,从而对CVR进行预测时,常常面临多项困难,例如包括,训练数据的稀少,样本选择的偏差,等等。这使得CVR预测模型的训练效果仍不够理想。
希望有一种新的方案,能够提升用户行为预测模型的训练效果,从而进一步提升对用户行为的预测效果。
发明内容
本说明书一个或多个实施例描述了一种预测模型的训练方法,能够更有效地训练得到预测模型,用于在全空间进行用户转化行为的预测。
根据第一方面,提供一种训练预测模型的方法,所述预测模型包括第一分支和第二分支,所述包括:
获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示目标样本对应的用户是否点击了目标对象;第二标签表示该用户是否实施了与所述目标对象相关的目标行为;
利用所述预测模型对所述样本特征进行模型处理,从而所述第一分支输出所述用户点击所述目标对象的第一概率;所述第二分支输出该用户实施所述目标行为的第二概率;
根据第一标签的第一标签值和所述第一概率,确定第一损失;
在满足预设条件的情况下,根据第二标签的第二标签值和所述第二概率确定第二损失,并根据所述第一损失和第二损失确定所述目标样本的预测损失;其中所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象;
根据所述预测损失,训练所述预测模型。
根据一种实施方式,上述方法还包括:在不满足所述预设条件的情况下,根据所述第一标签值和第二标签值的第一乘积和所述第一概率和第二概率的第二乘积,确定第三损失;并根据所述第一损失和第三损失确定所述预测损失。
在一个实施例中,所述预设条件还包括,所述第二标签值指示所述用户没有实施目标行为。
在一个实施例中,所述预测模型还包括嵌入层;所述模型处理包括,利用所述嵌入层将所述样本特征编码为嵌入向量,将所述嵌入向量分别输入所述第一分支和第二分支。
进一步的,在一个示例中,样本特征包括所述用户的用户特征,和所述目标对象的对象特征;利用所述嵌入层将所述样本特征编码为嵌入向量,具体包括:将所述用户特征编码为第一向量;以及,将所述对象特征编码为第二向量;对所述第一向量和第二向量进行聚合,得到所述嵌入向量。
在一个更进一步的例子中,所述样本特征还包括所述用户和所述目标对象的交互特征;利用所述嵌入层将所述样本特征编码为嵌入向量,还包括:将所述交互特征编码为第三向量;其中,所述嵌入向量还基于所述第三向量得到。
根据一种实施方式,上述预测模型还包括门控单元和乘积计算单元;所述门控单元在所述预设条件得到满足时,阻断目标通路,在所述预设条件不满足时,导通所述目标通路,其中所述目标通路用于将所述第一概率和第二概率传输到所述乘积计算单元,以计算所述第二乘积。
在不同实施例中,所述目标行为包括以下之一:购买,收藏,加入购物车,下载,转发。
根据第二方面,提供了一种训练预测模型的装置,所述预测模型包括第一分支和第二分支,所述装置包括:
样本获取单元,配置为获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示目标样本对应的用户是否点击了目标对象;第二标签表示该用户是否实施了与所述目标对象有关的目标行为;
概率预测单元,配置为利用所述预测模型对所述样本特征进行模型处理,从而所述第一分支输出所述用户点击所述目标对象的第一概率;所述第二分支输出该用户实施所述目标行为的第二概率;
第一损失确定单元,配置为根据第一标签的第一标签值和所述第一概率,确定第一损失;
第二损失确定单元,配置为在满足预设条件的情况下,根据第二标签的第二标签值和所述第二概率确定第二损失,并根据所述第一损失和第二损失确定所述目标样本的预测损失;其中所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象;
训练单元,配置为根据所述预测损失,训练所述预测模型。
根据第三方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当该计算机程序在计算机中执行时,令计算机执行上述第一方面的方法。
根据第四方面,提供了一种计算设备,包括存储器和处理器,存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现上述第一方面的方法。
在本说明书实施例提供的训练方案中,在训练样本的标签指示出对应用户点击了目标对象的情况下,采用CVR预测任务的预测损失替代点击+转化(CTCVR)预测任务的预测损失,如此避免了该情况下CTR任务和CTCVR任务的梯度冲突,从而实现了更好的训练效果。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1示出全空间多任务模型的示意图;
图2示出第一标签y和第二标签z的各种可能取值下第一梯度和第二梯度的总结图表;
图3示出根据一个实施例的训练预测模型的方法流程图;
图4示出预测模型的结构和训练过程示意图;
图5示出根据一个实施例的训练装置的结构示意图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
在典型的互联网场景中,各种服务平台通过互联网展现多种对象。用户在浏览到感兴趣的对象时,首先会对其进行点击,以获得进一步信息,或者获得转化的入口。然后,用户可能对点击后的部分对象,实施转化行为。例如在典型的电商场景中,服务平台会呈现各种推荐的商品。用户会点击其中感兴趣的一些商品,进而有可能购买其中的一些。因此,用户行为遵循一个严格的序列模式,也就是,展现(impression)->点击->转化。
对于服务平台而言,最为关心的是用户的转化率,因此,预测点击后的转化率(post-click conversion rate,CVR)是服务平台中一项重要的任务。然而,如上所述,用户行为呈现严格的序列性,转化相对于点击的时序依赖,使得CVR预测模型的构建面临几项挑战和困难。
一项困难是训练数据的稀少和稀疏。可以理解的,对象的每一次展现都可以贡献出一个关于点击率CTR(click-through rate)预测的训练样本,有点击即为正样本,没有点击也可以形成负样本。然而,只有在展现之后用户进行了点击的那些对象,才能作为CVR预测的训练样本。实践中,发生点击的真实比例是很低的,因此,可用于CVR预测训练的样本数量远远小于可用于点击率CTR预测训练的样本。
另一项困难是样本选择偏差。可以理解,训练后的CVR预测模型需要针对所有展现对象的全空间进行预测,而不可能在用户真正点击某个对象后再去预测其转化率。然而,如前所述,CVR预测的训练样本仅仅是那些产生了点击的展现对象。取决于不同用户的个人选择,被点击的展现对象与所有展现对象的全空间往往具有非常不同的数据分布特点,而训练数据空间与预测空间的数据分布差异,会明显影响训练得到的CVR模型的预测性能。
为了解决上述难题,在一种实施方案中,提出了全空间多任务模型,该模型可以很好地利用多任务学习的方式,基于CVR预测中用户行为的序列模式特点,间接地对CVR预测进行建模和训练。
图1示出全空间多任务模型的示意图。如图1所示,该多任务模型包括共享的嵌入层和两个预测分支,即对应于点击率CTR预测的第一分支,和对应于CVR预测的第二分支。下面描述该模型的训练过程。
首先获取训练样本集,其中的每个样本用户具有是否点击和点击后是否转化的两个标签,如此,任意样本i可以表示为(xi,yi,zi),其中xi表示样本特征,例如样本用户和展现对象的特征,yi以不同的值表示样本用户是否点击了展现对象,例如yi=0表示没有点击,yi=1表示点击;zi则表示样本用户是否在点击之后实施了目标行为,该目标行为是在对应场景下被认为该用户得到转化的行为。例如,在电商场景下,目标行为可以是为购买,加入购物车,在推荐场景下,目标行为还可以是收藏,转发给朋友,等等。基于CVR预测的行为序列模式约束,要求仅在yi=1的情况下,zi才有可能等于1,如果yi=0,则zi必然为0,即:。下面在不引起混淆的情况下,在一些描述中省去下标i。
当将训练样本的样本特征x输入上述模型,嵌入层对样本特征x进行嵌入编码,然后将得到的嵌入向量分别输入到CTR分支和CVR分支。CTR分支用于预测用户对展现对象进行点击的概率,得到第一概率pCTR,该第一概率可以表示为:。CVR分支用于预测用户点击之后(即y=1情况下的用户)实施转化行为的概率,得到第二概率pCVR,该第二概率可以表示为:。
如图1所示,第一分支预测出的第一概率pCTR和第二分支预测的第二概率pCVR,传输到一个乘积计算单元或算子,得到上述总概率pCTCVR。通过上式(1),可以从总概率pCTCVR和第一概率pCTR,反推出第二概率pCVR。并且,总概率pCTCVR和第一概率pCTR都可以基于全量展现对象构成的全空间进行建模。具体的,对于CTR任务,可以基于任意样本的样本特征x和其y标签进行训练;而对于CTCVR任务,可以以任意样本的y标签和z标签的乘积y*z作为新标签进行训练,其中,只要y=0(没有发生点击),新标签必然为0;仅当y和z都为1时,新标签才为1,意味着,用户先点击又发生了转化。因此,上述全空间多任务模型引入了预测pCTR和pCTCVR这两个任务作为辅助任务,而将主要任务,预测pCVR作为辅助任务过程中的一个中间结果。由于上述两个辅助任务pCTR和pCTCVR都是在全空间建模的,因此导出的pCVR也适用于全空间,由此解决了训练数据的稀疏问题和样本偏差问题。
如上所述,上述全空间多任务模型基于上述两个辅助任务进行训练。具体的,可以将CTR任务和CTCVR任务各自的预测损失之和作为总预测损失,由此训练该模型。这可以表示为:
其中,Lesmm为该全空间多任务模型的总预测损失,Lctr为CTR预测任务的预测损失,Lctcvr为CTCVR任务的预测损失,Θ为整个模型中的参数。
更进一步的,Lctr和Lctcvr可以分别表示为:
其中,分别表示CTR分支的模型参数,CVR分支的模型参数和嵌入层的模型参数,表示模型中的预测函数,表示用于计算预测损失的损失函数,例如为交叉熵损失函数,均方差损失函数,等等。通过式(3)可见,CTR任务的预测损失Lctr基于第一概率pCTR(通过预测函数计算得出)和标签yi而确定;通过式(4)可见,CTCVR任务的预测损失Lctcvr基于总概率pCTCVR和新标签yi*zi而确定。
然而,实践操作中,上述全空间多任务模型却常常出现收敛很慢甚至无法收敛的训练困难。发明人对此进行了深入研究,发现,当按照以上(2)-(4)式进行训练时,不同任务针对模型中CTR分支的参数调整方向,即梯度,存在严重冲突,该冲突导致了难以收敛的训练困难。
具体的,针对任意单个样本(样本特征为x),假定上述模型针对CTR任务的预测分数(即第一概率pCTR的预测值)为,针对CVR任务的预测分数(即第二概率pCVR的预测值)为,则与上述式(3)和(4)对应的,CTR任务的预测损失可以表示为:
CTCVR任务的预测损失可以表示为:
典型地,上述损失函数往往采用如下的交叉熵形式:
式(7)中的y泛指标签值,h泛指预测分数。根据式(7),损失相对于预测分数h的梯度表示为:
分别将公式(5)和(6)中具体的标签值和预测分数形式代入公式(8)的梯度计算式,可以得到:
逐个分析y和z的各种可能取值下,式(9)和(10)中第一梯度和第二梯度的形式,可以得到图2所示的总结图表。如图2所示,在y=1,z=0的情况下,由于预测值均在(0,1)范围内,可以得到如下关系:
因此,不管gt的符号如何,在y=1,z=0的情况下,第一梯度和第二梯度必然一个大于0,一个小于0,也就是说,CTR任务和CTCVR任务对于模型参数的梯度是180度完全相反的方向。这一发现解释了模型训练困难的根本原因。
以上结合交叉熵形式的损失函数分析了梯度冲突的产生。针对其他形式的损失函数,例如均方差损失,铰链损失等形式,经过发明人分析,发现这些常用形式的损失函数下,两个辅助任务也会针对同一模型参数部分产生不同程度的梯度冲突。
为了提升模型的训练效果,解决梯度冲突问题,本说明书中进一步提出一种优化的训练方法。图3示出根据一个实施例的训练预测模型的方法流程图,该方法流程可以通过任何具有计算、处理能力的计算单元、平台、服务器、设备等执行。如前所述,该预测模型用于预测用户点击后实施特定行为的概率,在结构上采用与前述全空间多任务模型相似的多分支结构。图4示出预测模型的结构和训练过程示意图。如图4所示,该预测模型包括第一分支和第二分支,第一分支用于预测用户的点击率,对应于CTR预测分支;第二分支用于预测点击后实施特定转化行为的概率,对应于CVR预测分支。下面结合图3和图4描述上述优化的训练方法。
根据该优化的训练方法,如图3所示,首先在步骤S31,获取目标样本,其包括样本特征x,第一标签y和第二标签z;第一标签y指示目标样本对应的用户是否点击了目标对象;第二标签z表示该用户是否在点击之后实施了与所述目标对象相关的目标行为。
可以理解,目标对象可以是互联网场景中的各种展现对象,例如,商品,广告,文章,音乐,图片,等等;上述目标行为可以是对应场景中,被认为用户得到转化的各种行为,例如,购买行为,加入购物车的行为,下载(音乐或图片)行为,收藏行为,转发行为,等等。
然后,在步骤S32,利用预测模型对上述样本特征进行模型处理,从而,预测模型的第一分支输出该用户点击所述目标对象的第一概率;第二分支输出该用户实施目标行为的第二概率。
为进行上述模型处理,在一个实施例中,预测模型包括用于对样本特征进行嵌入处理的嵌入层,如图4所示。在图4的示例中,两个分支的嵌入层使用共享的查找表进行特征映射处理,从而该嵌入层可以认为是两个分支共享。相应的,模型处理过程可以包括,首先利用嵌入层(其中的模型参数为)将样本特征x编码为嵌入向量,将该嵌入向量分别输入第一分支和第二分支。具体的,在一个示例中,如图4所示,样本特征x包括用户的用户特征和目标对象的对象特征。相应的,在嵌入层中,可以将用户特征编码为第一向量;以及,将对象特征编码为第二向量;然后对第一向量和第二向量进行聚合,得到上述嵌入向量。聚合的方式可以包括,如图4所示的拼接,或者也可以采用求和等组合方式。
在一个例子中,样本特征x还可以包括用户和对象之间的交互特征,例如描述二者之间的交互历史信息。在这样的情况下,嵌入层还可以对交互特征进行编码,得到第三向量,然后对上述第一向量、第二向量,以及该第三向量整体进行聚合,得到该目标样本对应的嵌入向量。
在基于嵌入层得到目标样本的嵌入向量后,该嵌入向量被分别输入到第一分支和第二分支。第一分支和第二分支可以通过各种神经网络结构实现,例如,图4中示出为多层感知机MLP。
第一分支对应于CTR分支,用于预测用户的点击率。当将上述嵌入向量输入该第一分支,第一分支利用其网络模型参数对该嵌入向量进行处理,输出该用户点击目标对象的第一概率,即pCTR。第二分支对应于CVR分支,用于预测用户的点击后转化率。当将上述嵌入向量输入该第二分支,第二分支利用其网络模型参数对该嵌入向量进行处理,输出该用户点击后实施目标行为的第二概率,即pCVR。
而关于另一项预测任务的损失,则预先设定一个与标签值有关的条件,根据该条件是否满足,采用不同方式确定另一项损失。该预设条件可以对应于,常规训练过程中容易发生梯度冲突的情况。
根据一种实施方式,该预设条件设置为,第一标签y指示出用户点击了目标对象。在该实施方式中,在步骤S34,判断预设条件是否满足,即判断第一标签y的取值(下文称为第一标签值)为第一值还是第二值,其中第一值指示用户点击了目标对象,第二值指示用户没有点击目标对象。典型的,第一值设置为1,第二值设置为0。
当第一标签值等于第一值时,即y=1的情况下,预设条件满足,执行第一流程分支,其中包括步骤S35和S36。在步骤S35,根据第二标签z的取值(下文称为第二标签值)和第二概率pCVR确定第二损失。该第二损失可以认为是直接进行CVR预测的预测损失,可以表示为:
而当第一标签值等于第二值时,即y=0的情况下,预设条件不满足,则执行第二流程分支,其中包括步骤S37和S38。在步骤S37,根据第一标签值和第二标签值的第一乘积和第一概率和第二概率的第二乘积,确定第三损失。该第三损失对应于前述的pCTCVR任务的预测损失,例如可以表示为前述公式(6)的形式。
综合上述两个流程分支,针对单个目标样本,其预测损失可以总结为:
根据公式(15),预测损失可以表示为第一损失与混合损失之和,当y=1时,混合损失取第二损失,当y=0时,混合损失取第三损失。
最终,两个流程分支汇总至步骤S39,根据上述预测损失,训练预测模型。可以理解,以上描述了单个目标样本的预测损失确定方式;当基于一批样本构成的样本集进行一次模型更新时,则针对其中的各个样本,均采用上述方式确定预测损失,并根据各个样本的预测损失之和,更新预测模型的模型参数。
在另一种实施方式中,该预设条件设置为,第一标签y指示出用户点击了目标对象,并且,第二标签z指示出用户没有实施目标行为。在该实施方式中,在步骤S34,判断预设条件是否满足,即判断是否同时满足:第一标签y取1,且第二标签z取0。
如果满足上述预设条件,则执行第一流程分支,其中包括步骤S35和S36。如果不满足上述预设条件,例如,对于除(y=1,z=0)之外的其他情况,执行第二流程分支,其中包括步骤S37和S38。步骤S35-S38的具体执行过程不复赘述。
以上方案的核心要点在于,根据至少关联于第一标签y的预设条件是否满足,采用不同的方式确定其预测损失。在一个实施例中,这可以通过在训练阶段,在预测模型中设置门控单元来实现。如图4所示,该预测模型还包括门控单元和乘积计算单元(乘积算子)。门控单元在上述预设条件不满足时,例如y=0时,导通目标通路,该通路用于将第一概率pCTR和第二概率pCVR传输到乘积计算单元,以计算第二乘积,即pCTCVR。该通路的导通意味着,可以基于pCTCVR计算第三损失,此时,模型结构如(A)部分所示。
而当上述预设条件得到满足时,例如y=1时,或者y=1且z=0时,门控单元阻断上述目标通路,如此使得,第一概率pCTR和第二概率pCVR不再传输到乘积计算单元计算pCTCVR,取而代之地,直接基于第一概率pCTR/第二概率pCVR,分别计算两个预测分支各自的预测损失,据此训练模型。此时,模型结构转化为如(B)部分所示。
如前所述,全空间多任务模型的主要梯度冲突发生y=1,z=0的情况下,此时,在交叉熵损失函数形式下,CTR任务和CTCVR任务对于模型参数的梯度完全相反。而根据图3的流程和图4的架构可见,在优化的训练方案中,在容易发生梯度冲突1的情况下(对应于预设条件满足的情况),采用CVR任务的预测损失替代CTCVR任务的预测损失,如此避免CTR任务和CTCVR任务的梯度冲突。
发明人针对图3的训练过程进行了进一步的数学分析和实验论证,结果表明,采用该训练过程,不管是针对CTR分支的模型参数,还是针对CVR分支的模型参数,均可以有效避免梯度冲突,并且达到更好的训练效果。
另一方面,与上述训练过程相对应的,本说明书实施例还披露一种预测模型的训练装置,该装置可以部署在任何具有计算、处理能力的计算单元、平台、服务器、设备中。图5示出根据一个实施例的训练装置的结构示意图,该装置用于训练预测模型,该预测模型包括第一分支和第二分支。如图5所示,该装置500包括:
样本获取单元51,配置为获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示目标样本对应的用户是否点击了目标对象;第二标签表示该用户是否实施了与所述目标对象相关的目标行为;
概率预测单元52,配置为利用所述预测模型对所述样本特征进行模型处理,从而所述第一分支输出所述用户点击所述目标对象的第一概率;所述第二分支输出该用户实施所述目标行为的第二概率;
第一损失确定单元53,配置为根据第一标签的第一标签值和所述第一概率,确定第一损失;
第二损失确定单元54,配置为在满足预设条件的情况下,根据第二标签的第二标签值和所述第二概率确定第二损失,并根据所述第一损失和第二损失确定所述目标样本的预测损失;其中所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象;
训练单元55,配置为根据所述预测损失,训练所述预测模型。
根据一种实施方式,该装置500还包括:第三损失确定单元56,配置为在不满足所述预设条件的情况下,根据所述第一标签值和第二标签值的第一乘积和所述第一概率和第二概率的第二乘积,确定第三损失;并根据所述第一损失和第三损失确定所述预测损失。
根据一种实施方式,所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象,并且,所述第二标签值指示所述用户没有实施所述目标行为。
根据一种实现方式,所述预测模型还包括嵌入层;概率预测单元52中涉及的模型处理包括,利用所述嵌入层将所述样本特征编码为嵌入向量,将所述嵌入向量分别输入所述第一分支和第二分支。
在一个实施例中,所述样本特征包括所述用户的用户特征,和所述目标对象的对象特征;上述利用所述嵌入层将所述样本特征编码为嵌入向量,具体包括:将所述用户特征编码为第一向量;以及,将所述对象特征编码为第二向量;对所述第一向量和第二向量进行聚合,得到所述嵌入向量。
进一步的,在一个示例中,所述样本特征还包括所述用户和所述目标对象的交互特征;利用所述嵌入层将所述样本特征编码为嵌入向量,还包括:将所述交互特征编码为第三向量;此时,所述嵌入向量还基于所述第三向量得到。
根据一种实施方式,所述预测模型还包括门控单元和乘积计算单元;所述门控单元在所述预设条件得到满足时,阻断目标通路,在所述预设条件不满足时,导通所述目标通路,其中所述目标通路用于将所述第一概率和第二概率传输到所述乘积计算单元,以计算所述第二乘积。
在不同实施例中,所述目标行为包括以下之一:购买,收藏,加入购物车,下载,转发。
根据另一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行前述优化的训练方法。
根据再一方面的实施例,还提供一种计算设备,包括存储器和处理器,该存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现前述优化的训练方法。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (15)
1.一种训练预测模型的方法,所述预测模型包括第一分支和第二分支,所述方法包括:
获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示目标样本对应的用户是否点击了目标对象;第二标签表示该用户是否在点击目标对象之后实施了与所述目标对象相关的目标行为;
利用所述预测模型对所述样本特征进行模型处理,从而所述第一分支输出所述用户点击所述目标对象的第一概率;所述第二分支输出该用户实施所述目标行为的第二概率;
根据第一标签的第一标签值和所述第一概率,确定第一损失;
在满足预设条件的情况下,根据第二标签的第二标签值和所述第二概率确定第二损失,并根据所述第一损失和第二损失确定所述目标样本的预测损失;其中所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象;
在不满足所述预设条件的情况下,根据所述第一标签值和第二标签值的第一乘积和所述第一概率和第二概率的第二乘积,确定第三损失;并根据所述第一损失和第三损失确定所述预测损失;
根据所述预测损失,训练所述预测模型。
2.根据权利要求1所述的方法,其中,所述预设条件还包括,所述第二标签值指示所述用户没有实施所述目标行为。
3.根据权利要求1所述的方法,其中,所述预测模型还包括嵌入层;
所述模型处理包括,利用所述嵌入层将所述样本特征编码为嵌入向量,将所述嵌入向量分别输入所述第一分支和第二分支。
4.根据权利要求3所述的方法,其中,所述样本特征包括所述用户的用户特征,和所述目标对象的对象特征;
利用所述嵌入层将所述样本特征编码为嵌入向量,包括:
将所述用户特征编码为第一向量;以及,将所述对象特征编码为第二向量;
对所述第一向量和第二向量进行聚合,得到所述嵌入向量。
5.根据权利要求4所述的方法,其中,所述样本特征还包括所述用户和所述目标对象的交互特征;
利用所述嵌入层将所述样本特征编码为嵌入向量,还包括:
将所述交互特征编码为第三向量;
所述嵌入向量还基于所述第三向量得到。
6.根据权利要求1所述的方法,其中,所述预测模型还包括门控单元和乘积计算单元;所述门控单元在所述预设条件得到满足时,阻断目标通路,在所述预设条件不满足时,导通所述目标通路,其中所述目标通路用于将所述第一概率和第二概率传输到所述乘积计算单元,以计算所述第二乘积。
7.根据权利要求1所述的方法,其中,所述目标行为包括以下之一:购买,收藏,加入购物车,下载,转发。
8.一种训练预测模型的装置,所述预测模型包括第一分支和第二分支,所述装置包括:
样本获取单元,配置为获取目标样本,其包括样本特征,第一标签和第二标签;第一标签指示目标样本对应的用户是否点击了目标对象;第二标签表示该用户是否在点击目标对象之后实施了与所述目标对象相关的目标行为;
概率预测单元,配置为利用所述预测模型对所述样本特征进行模型处理,从而所述第一分支输出所述用户点击所述目标对象的第一概率;所述第二分支输出该用户实施所述目标行为的第二概率;
第一损失确定单元,配置为根据第一标签的第一标签值和所述第一概率,确定第一损失;
第二损失确定单元,配置为在满足预设条件的情况下,根据第二标签的第二标签值和所述第二概率确定第二损失,并根据所述第一损失和第二损失确定所述目标样本的预测损失;其中所述预设条件包括,所述第一标签值指示所述用户点击了所述目标对象;
第三损失确定单元,配置为在不满足所述预设条件的情况下,根据所述第一标签值和第二标签值的第一乘积和所述第一概率和第二概率的第二乘积,确定第三损失;并根据所述第一损失和第三损失确定所述预测损失;
训练单元,配置为根据所述预测损失,训练所述预测模型。
9.根据权利要求8所述的装置,其中,所述预设条件还包括,所述第二标签值指示所述用户没有实施所述目标行为。
10.根据权利要求8所述的装置,其中,所述预测模型还包括嵌入层;
所述模型处理包括,利用所述嵌入层将所述样本特征编码为嵌入向量,将所述嵌入向量分别输入所述第一分支和第二分支。
11.根据权利要求10所述的装置,其中,所述样本特征包括所述用户的用户特征,和所述目标对象的对象特征;
利用所述嵌入层将所述样本特征编码为嵌入向量,包括:
将所述用户特征编码为第一向量;以及,将所述对象特征编码为第二向量;
对所述第一向量和第二向量进行聚合,得到所述嵌入向量。
12.根据权利要求11所述的装置,其中,所述样本特征还包括所述用户和所述目标对象的交互特征;
利用所述嵌入层将所述样本特征编码为嵌入向量,还包括:
将所述交互特征编码为第三向量;
所述嵌入向量还基于所述第三向量得到。
13.根据权利要求8所述的装置,其中,所述预测模型还包括门控单元和乘积计算单元;所述门控单元在所述预设条件得到满足时,阻断目标通路,在所述预设条件不满足时,导通所述目标通路,其中所述目标通路用于将所述第一概率和第二概率传输到所述乘积计算单元,以计算所述第二乘积。
14.根据权利要求8所述的装置,其中,所述目标行为包括以下之一:购买,收藏,加入购物车,下载,转发。
15.一种用于训练预测模型的计算设备,包括存储器和处理器,其特征在于,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-7中任一项所述的方法。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210694769.2A CN114792173B (zh) | 2022-06-20 | 2022-06-20 | 预测模型训练方法和装置 |
US18/337,960 US20230409929A1 (en) | 2022-06-20 | 2023-06-20 | Methods and apparatuses for training prediction model |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210694769.2A CN114792173B (zh) | 2022-06-20 | 2022-06-20 | 预测模型训练方法和装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114792173A CN114792173A (zh) | 2022-07-26 |
CN114792173B true CN114792173B (zh) | 2022-10-04 |
Family
ID=82463478
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210694769.2A Active CN114792173B (zh) | 2022-06-20 | 2022-06-20 | 预测模型训练方法和装置 |
Country Status (2)
Country | Link |
---|---|
US (1) | US20230409929A1 (zh) |
CN (1) | CN114792173B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116432039B (zh) * | 2023-06-13 | 2023-09-05 | 支付宝(杭州)信息技术有限公司 | 协同训练方法及装置、业务预测方法及装置 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114462526A (zh) * | 2022-01-28 | 2022-05-10 | 腾讯科技(深圳)有限公司 | 一种分类模型训练方法、装置、计算机设备及存储介质 |
Family Cites Families (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109754105B (zh) * | 2017-11-07 | 2024-01-05 | 华为技术有限公司 | 一种预测方法及终端、服务器 |
CN109284864B (zh) * | 2018-09-04 | 2021-08-24 | 广州视源电子科技股份有限公司 | 行为序列获取方法及装置、用户转化率预测方法及装置 |
EP3862893A4 (en) * | 2019-10-31 | 2021-12-01 | Huawei Technologies Co., Ltd. | RECOMMENDATION MODEL LEARNING PROCESS, RECOMMENDATION PROCESS, DEVICE, AND COMPUTER READABLE MEDIA |
CN111310814A (zh) * | 2020-02-07 | 2020-06-19 | 支付宝(杭州)信息技术有限公司 | 利用不平衡正负样本对业务预测模型训练的方法及装置 |
CN111460150B (zh) * | 2020-03-27 | 2023-11-10 | 北京小米松果电子有限公司 | 一种分类模型的训练方法、分类方法、装置及存储介质 |
CN111767982A (zh) * | 2020-05-20 | 2020-10-13 | 北京大米科技有限公司 | 用户转换预测模型的训练方法、装置、存储介质以及电子设备 |
CN111523044B (zh) * | 2020-07-06 | 2020-10-23 | 南京梦饷网络科技有限公司 | 用于推荐目标对象的方法、计算设备和计算机存储介质 |
CN112819024B (zh) * | 2020-07-10 | 2024-02-13 | 腾讯科技(深圳)有限公司 | 模型处理方法、用户数据处理方法及装置、计算机设备 |
CN111737584B (zh) * | 2020-07-31 | 2020-12-08 | 支付宝(杭州)信息技术有限公司 | 行为预测系统的更新方法及装置 |
CN113392359A (zh) * | 2021-08-18 | 2021-09-14 | 腾讯科技(深圳)有限公司 | 多目标预测方法、装置、设备及存储介质 |
CN114330499A (zh) * | 2021-11-30 | 2022-04-12 | 腾讯科技(深圳)有限公司 | 分类模型的训练方法、装置、设备、存储介质及程序产品 |
CN114240555A (zh) * | 2021-12-17 | 2022-03-25 | 北京沃东天骏信息技术有限公司 | 训练点击率预测模型和预测点击率的方法和装置 |
-
2022
- 2022-06-20 CN CN202210694769.2A patent/CN114792173B/zh active Active
-
2023
- 2023-06-20 US US18/337,960 patent/US20230409929A1/en active Pending
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114462526A (zh) * | 2022-01-28 | 2022-05-10 | 腾讯科技(深圳)有限公司 | 一种分类模型训练方法、装置、计算机设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114792173A (zh) | 2022-07-26 |
US20230409929A1 (en) | 2023-12-21 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20210248651A1 (en) | Recommendation model training method, recommendation method, apparatus, and computer-readable medium | |
CN111523044B (zh) | 用于推荐目标对象的方法、计算设备和计算机存储介质 | |
US10235403B2 (en) | Parallel collective matrix factorization framework for big data | |
CN113516522B (zh) | 媒体资源推荐方法、多目标融合模型的训练方法及装置 | |
CN114117216A (zh) | 推荐概率预测方法及装置、计算机存储介质和电子设备 | |
KR20210095579A (ko) | 아이템 추천방법, 시스템, 전자기기 및 기록매체 | |
CN114792173B (zh) | 预测模型训练方法和装置 | |
Liu et al. | Multi-task recommendations with reinforcement learning | |
JP2023024950A (ja) | 共有されたニューラルアイテム表現をコールドスタート推薦に用いる改良型のレコメンダシステム及び方法 | |
CN115564517A (zh) | 商品推荐方法、预测模型训练方法和相关设备 | |
CN114239675A (zh) | 融合多模态内容的知识图谱补全方法 | |
CN113592593B (zh) | 序列推荐模型的训练及应用方法、装置、设备及存储介质 | |
CN111340605B (zh) | 训练用户行为预测模型、用户行为预测的方法和装置 | |
CN117056595A (zh) | 一种交互式的项目推荐方法、装置及计算机可读存储介质 | |
CN111382846B (zh) | 基于迁移学习的训练神经网络模型的方法和装置 | |
CN110413946A (zh) | 使用交替最小二乘优化来在线训练和更新因子分解机 | |
CN111178987B (zh) | 训练用户行为预测模型的方法和装置 | |
Yin et al. | PeNet: A feature excitation learning approach to advertisement click-through rate prediction | |
KR102253365B1 (ko) | 예측 장치 및 이의 동작 방법 | |
CN110880141A (zh) | 一种深度双塔模型智能匹配算法及装置 | |
CN115423565B (zh) | 应用于云端互联网交互流程的大数据分析方法及ai系统 | |
CN113688315B (zh) | 一种基于无信息损失图编码的序列推荐方法 | |
Zhang et al. | Model-based Reinforcement Learning for Parameterized Action Spaces | |
WO2023242907A1 (ja) | 情報処理装置、情報処理方法及び情報処理プログラム | |
WO2023095680A1 (ja) | 予測装置、学習装置、予測方法、学習方法、予測プログラム及び学習プログラム |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |