CN115563510A - 一种点击率预估模型的训练方法及相关装置 - Google Patents

一种点击率预估模型的训练方法及相关装置 Download PDF

Info

Publication number
CN115563510A
CN115563510A CN202211533551.5A CN202211533551A CN115563510A CN 115563510 A CN115563510 A CN 115563510A CN 202211533551 A CN202211533551 A CN 202211533551A CN 115563510 A CN115563510 A CN 115563510A
Authority
CN
China
Prior art keywords
network
model
output
bert
layer
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202211533551.5A
Other languages
English (en)
Other versions
CN115563510B (zh
Inventor
王龙滔
蔡振宇
纪承
张智慧
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Sohu New Power Information Technology Co ltd
Original Assignee
Beijing Sohu New Power Information Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Sohu New Power Information Technology Co ltd filed Critical Beijing Sohu New Power Information Technology Co ltd
Priority to CN202211533551.5A priority Critical patent/CN115563510B/zh
Publication of CN115563510A publication Critical patent/CN115563510A/zh
Application granted granted Critical
Publication of CN115563510B publication Critical patent/CN115563510B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Abstract

本申请公开了一种点击率预估模型的训练方法及相关装置。以xDeepFM模型为基础,利用Bert网络构建新的嵌入层得到新的模型架构;在新的模型架构中,输入层对训练数据集中样本特征组独热编码;将输入层输出的独热编码后的特征输入到新的嵌入层中,由CAN网络和Bert网络分别进行特征嵌入;采用新的模型架构中CIN网络和DNN网络分别基于CAN网络及Bert网络的输出学习;基于输入层的输出、CIN网络的输出和DNN网络的输出,对架构进行网络参数调整得到点击率预估模型。由于Bert网络发生的特征交互是基于元素级的,因此注意到特征向量内部特征信息的相互作用,加强特征的挖掘潜力,提升模型预估效果的准确性。

Description

一种点击率预估模型的训练方法及相关装置
技术领域
本申请涉及机器学习技术领域,特别是涉及一种点击率预估模型的训练方法及相关装置。
背景技术
机器学习在推荐系统中具有广泛的应用。推荐系统中进行内容推荐的重要数据依据之一是点击率。为了进行内容推荐,常需要通过机器学习的方式预测内容的点击率。
xDeepFM是预估点击率(Click-Through-Rate,CTR)的一种经典模型。目前大多数CTR预估模型例如xDeepFM等,其嵌入层的特征提取部分只是简单地进行特征嵌入,并不能有效深入地探索到潜在的特征联合性质,导致一定程度上有用信息被忽视。Co-Action指多个特征相互关联,共同影响最终输出,同时共同影响最初的输入。即便在推荐系统的CTR模型中引入特征交叉网络(Co-Action Net, CAN)进行特征提取和交叉融合,也仅仅只是在特征的向量级(vector-wise)上进行上述操作,忽略了特征向量内部信息之间的相互作用,导致预估的CTR准确性不足。目前,对内容数据中的特性进行有效的抽取和对特征信息的联合是整个推荐问题的重中之重,影响着预估的CTR准确性,甚至决定着推荐系统的推荐效果。
发明内容
基于上述问题,本申请提供了一种点击率预估模型的训练方法及相关装置,目的是改进现有的点击率预估模型在特征提取方面的不足,提升模型预估效果的准确性。
本申请实施例公开了如下技术方案:
本申请第一方面提供了一种点击率预估模型的训练方法,包括:
以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络;
在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;
将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;
采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;
基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。
可选地,所述基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型,具体包括:
在所述新的模型架构中,在输出层利用激活函数基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出进行运算,得到模型激活后的预测值;
利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;
根据所述损失值对网络参数进行调整,得到点击率预估模型。
可选地,点击率预估模型的训练方法还包括:
以操作者接收特性曲线下与坐标轴围成的面积AUC作为准确率评价指标,获得Bert网络具有不同层数时模型的AUC和损失值;
根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数。
可选地,所述获得Bert网络具有不同层数时模型的AUC和损失值,具体包括:
利用训练数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述训练数据集的第一AUC和第一平均损失值;
利用测试数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述测试数据集的第二AUC和第二平均损失值;
所述根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数,具体包括:
根据Bert网络具有不同层数时模型的第一AUC、第一平均损失值、第二AUC和第二平均损失值,确定Bert网络的目标层数。
可选地,所述利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值,具体包括:
采用交叉熵损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;交叉熵损失函数计算损失值L的公式如下:
Figure 347324DEST_PATH_IMAGE001
其中,yi表示第i个所述样本特征组对应的标签,ŷi表示基于第i个样本特征组得到的预测值,N表示样本特征组的组数,i表示样本特征组的序数。
可选地,所述激活函数为sigmoid激活函数;所述预测值为通过以下公式运算得到的:
Figure 818757DEST_PATH_IMAGE002
其中,σ表示sigmoid激活函数,α表示所述输入层的独热编码后的特征;b表示全局偏置项;公式中参数上的T表示对参数的转置;
wLinear表示线性层对应的网络参数;
wCAN-DNN表示DNN网络对应于CAN网络的网络参数;
wCAN-CIN表示CIN网络对应于CAN网络的网络参数;
wBert-DNN表示DNN网络对应于Bert网络的网络参数;
wBert-CIN表示CIN网络对应于Bert网络的网络参数;
wLinear、wCAN-DNN、wCAN-CIN、wBert-DNN、wBert-CIN和b均为可学习的网络参数;
xk CAN-DNN表示经过所述CAN网络进行特征嵌入后进一步由DNN网络的输出;
p+表示经过所述CAN网络进行特征嵌入后进一步由CIN网络的输出;
xk Bert-DNN表示经过所述Bert网络进行特征嵌入后进一步由DNN网络的输出;
p++表示经过所述Bert网络进行特征嵌入后进一步由CIN网络的输出。
可选地,所述样本特征组包括:原始特征、计数特征、标签均值特征、nunique特征、视频特征、音频特征和标题特征;
其中,所述原始特征包括:用户编号,用户所在城市,视频编号,作者编号,视频所在城市,背景音乐编号,播放次数和视频持续时长;
所述计数特征包括:用户编号,播放次数,视频编号,作者编号,用户-作者编号;
所述标签均值特征包括:用户编号,播放次数,视频编号,用户-作者编号,用户播放次数和播放渠道;
所述nunique特征包括:用户-城市编号,用户-视频编号,用户-作者编号,用户-音乐编号,视频-城市编号,视频-用户编号和作者-用户编号。
可选地,点击率预估模型的训练方法还包括:
获取待预估的视频的特征组;
将所述待预估的视频的特征组输入到所述点击率预估模型中,获得所述点击率预估模型输出的点击预估结果。
本申请第二方面提供了一种点击率预估模型的训练装置,包括:
嵌入层构建模块,用于以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络;
独热编码模块,用于在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;
特征嵌入模块,用于将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;
交互信息学习模块,用于采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;
参数调整模块,用于基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。
可选地,所述参数调整模块,具体用于:
在所述新的模型架构中,在输出层利用激活函数基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出进行运算,得到模型激活后的预测值;
利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;
根据所述损失值对网络参数进行调整,得到点击率预估模型。
相较于现有技术,本申请具有以下有益效果:
本申请提供的点击率预估模型的训练方法,以xDeepFM模型为基础,利用Bert网络构建新的嵌入层得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络。在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。由于Bert网络发生的特征交互是基于元素级(bit-wise)的,因此改进的新的嵌入层能够注意到特征向量内部特征信息的相互作用,新的模型架构加强了Co-Action特征的挖掘潜力,提升模型预估效果的准确性。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为一种xDeepFM网络结构的示意图;
图2为本申请实施例提供的一种点击率预估模型的架构图;
图3为本申请实施例提供的一种点击率预估模型的训练方法的流程图;
图4为本申请实施例提供的一种点击率预估模型的数据流示意图;
图5为本申请实施例提供的一种点击率预估模型的训练装置的结构示意图。
具体实施方式
点击率预估模型的嵌入层的特征提取部分往往仅做简单的特征嵌入,不能有效深入地探索到潜在的特征联合性质,导致模型存在准确性偏低的缺陷。为了解决此问题,本申请中提出一种点击率预估模型的训练方法及相关装置,通过新的模型结构,尤其是嵌入层结构,实现对特征的向量级和元素级的抽取、联合。从而实现对特征关联的深度挖掘,提升模型的预估准确性。
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
图1为一种xDeepFM网络结构的示意图。xDeepFM使用CIN(CompressedInteraction Network,压缩交叉网络)和DNN(Deep Neural Network,深度神经网络)的双路结构,同时以显式和隐式的方式学习高阶特征。其中CIN网络负责进行显示交互特征的提取,DNN网络负责进行隐式交互特征的提取。CIN与DNN两个部分同时共享嵌入(Embedding)层。
如图1所示的结构中,包括输入层、嵌入层、线性层、CIN网络、DNN网络、输出层。输入层对输入至自身的信息独热编码,形成编码后的特征,进入到嵌入层。在xDeepFM中将嵌入层输出的独热编码后的特征视为输入特征。由嵌入层处理后交给CIN网络和DNN网络。线性层也会对输入层输出的内容做线性处理后,与CIN网络和DNN网络的输出,一并提供到输出层(或称输出单元)。对于传统的xDeepFM网络结构,其中嵌入层具有CAN网络。
本申请技术方案中提出对图1所示的结构进行改变,具体地,改变嵌入层以及嵌入层与其之后的网络的连接关系。在本申请技术方案中,提供的点击率预估模型的结构示意图如图2。在该图所示的模型架构中,嵌入层中在原有的CAN网络基础上还加入了Bert网络。Bert的英文全称为:Bidirectional Encoder Representations from Transformers;中文全称为基于转换器的双向编码器表示。故Bert网络的含义是基于转换器的双向编码器表示网络。Bert网络也会做特征嵌入的工作并将自身输出分别提供给后续的CIN网络和DNN网络。即,CIN网络接收Bert网络以及CAN网络的输出,并且DNN网络也接收Bert网络以及CAN网络的输出。
图3为本申请实施例提供的一种点击率预估模型的训练方法的流程图。图4为本申请实施例提供的一种点击率预估模型的数据流示意图。如图3所示,点击率预估模型的训练方法包括:
S301,以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构。
如图2所示的新的模型架构,在模型架构中的嵌入层不同于图1所示的xDeepFM模型的嵌入层。为了便于区分说明,将新的模型架构中的嵌入层称为新的嵌入层。新的嵌入层中包括xDeepFM模型的嵌入层中原有的CAN网络和Bert网络。在嵌入层中,原有的CAN网络与新加入的Bert网络均发挥作用。
S302,在新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码。
在正式训练模型之前,准备了训练数据集。训练数据集中包括多个样本特征组。每个样本特征组可以包括:原始特征、计数特征、标签均值特征、nunique特征、视频特征、音频特征和标题特征。由于模型训练完成后需要执行的任务为点击率预估,因此训练数据集中所准备的样本特征组中的特征,也是基于这一需求背景下提供的。具体涉及到视频相关的诸多特征。下面对上述部分特征进行举例说明。
原始特征包括:用户编号,用户所在城市,视频编号,作者编号,视频所在城市,背景音乐编号,播放次数和视频持续时长;
计数特征包括:用户编号,播放次数,视频编号,作者编号,用户-作者编号;
标签均值特征包括:用户编号,播放次数,视频编号,用户-作者编号,用户播放次数和播放渠道;
nunique特征包括:用户-城市编号,用户-视频编号,用户-作者编号,用户-音乐编号,视频-城市编号,视频-用户编号和作者-用户编号。
以上样本特征组提供到输入层后,由输入层对其进行独热编码,形成0/1形式的独热向量作为特征,并由输入层输出独热编码后的特征。
S303,将输入层输出的独热编码后的特征输入到新的嵌入层中,由CAN网络和Bert网络分别进行特征嵌入。
Bert网络可以视为提取器。其舍弃了Transformer结构中的解码模块,保留了编码模块。由于Bert网络的加入,能够使提取部分具有双向编码的能力以及更加强大的特征提取能力。Bert网络所执行的嵌入操作,是将词嵌入张量、语句分块张量和位置编码张量三部分直接做加和处理。
总体而言,在本申请技术方案中,嵌入层的CAN网络对经过独热编码后的特征进行基于特征向量级的特征提取,Bert网络对经过独热编码后的特征进行基于元素级的特征提取。通过两大角度的特征提取,将嵌入层提取的交叉特征再分别独立输入到后续的CIN网络和DNN网络中进行监督学习训练。
S304,采用新的模型架构中的CIN网络和DNN网络分别基于CAN网络的输出进行学习,并采用CIN网络和DNN网络分别基于Bert网络的输出进行学习。
结合图2所示的模型结构,CIN网络对CAN网络以及Bert网络的输出进行学习,与此同时DNN网络也对CAN网络以及Bert网络的输出进行学习。
S305,基于输入层的输出、CIN网络的输出和DNN网络的输出,对新的模型架构进行网络参数调整,得到点击率预估模型。
在具体实现时,本步骤可以包含如下过程;在新的模型架构中,在输出层利用激活函数基于输入层的输出、CIN网络的输出和DNN网络的输出进行运算,得到模型激活后的预测值;利用损失函数基于预测值和样本特征组对应的标签得到模型的损失值;根据损失值对网络参数进行调整,得到点击率预估模型。
在实际应用中,输入层从一个分支输出到了线性层。线性层可以对输入层的输出进行简单的线性处理,因此,上述针对三种输出进行的运算也可以视为:基于线性层的输出、CNN网络的输出和DNN网络的输出进行运算。由于输入模型输入层的样本特征组具有对应的标签,因此根据标签以及运算得到的预测值可以获得损失值。
下面介绍利用损失函数基于预测值和样本特征组对应的标签得到模型的损失值的一种示例实现方式。在本示例中,采用交叉熵损失函数进行二分类损失的计算,基于预测值和样本特征组对应的标签得到模型的损失值。交叉熵损失函数计算损失值L的公式如下:
Figure 138880DEST_PATH_IMAGE001
在以上函数公式中,yi表示第i个样本特征组对应的标签,ŷi表示基于第i个样本特征组得到的预测值,N表示样本特征组的组数,i表示样本特征组的序数。标签可以通过预先的准备直接获取,下面提供了一种预测值计算方式。
对于基于任一个样本特征组得到的预测值,其可以表示为ŷ。若激活函数为sigmoid激活函数,则预测值ŷ可以通过以下公式计算:
Figure 303145DEST_PATH_IMAGE002
其中,σ表示sigmoid激活函数,α表示输入层的独热编码后的特征;b表示全局偏置项;
wLinear表示线性层对应的网络参数;公式中参数上的T表示对参数的转置;
wCAN-DNN表示DNN网络对应于CAN网络的网络参数;
wCAN-CIN表示CIN网络对应于CAN网络的网络参数;
wBert-DNN表示DNN网络对应于Bert网络的网络参数;
wBert-CIN表示CIN网络对应于Bert网络的网络参数;
wLinear、wCAN-DNN、wCAN-CIN、wBert-DNN、wBert-CIN和b均为可学习的网络参数;
xk CAN-DNN表示经过CAN网络进行特征嵌入后进一步由DNN网络的输出;p+表示经过CAN网络进行特征嵌入后进一步由CIN网络的输出;
xk Bert-DNN表示经过Bert网络进行特征嵌入后进一步由DNN网络的输出;p++表示经过Bert网络进行特征嵌入后进一步由CIN网络的输出。
以上实施例介绍的点击率预估模型的训练方法中,对于具有新的嵌入层结构的模型结构进行训练,目的是训练得到预估准确率更高的点击率预估模型。在模型的嵌入层在已有CAN网络的基础上加入了Bert网络。由于Bert网络发生的特征交互是基于元素级的,因此新的嵌入层能够注意到特征向量内部特征信息的相互作用,加强特征的挖掘潜力。加强了模型的嵌入层特征提取与交叉能力。相比于仅有CAN网络的嵌入层,该提升模型预估效果的准确性。
并且在原本利用CAN网络做特征工程的xDeepFM模型的损失函数基础上,添加独立引入的Bert-DNN和Bert-CIN部分所产生的额外损失进行加权。使损失函数运算得到的损失值更加客观,与实际模型的新的结构相互匹配,从而调优更加准确,便捷。对于原先模型的损失,其包含线性层的部分、CAN-DNN部分、CAN-CIN部分,现在的模型中增加了Bert-DNN和Bert-CIN部分的损失。这些损失可以按照不同比例加权,再通过激活函数输出,经过多次调参搜索后得到最优的训练参数。
在实际应用中,作为可选的实现方式,还可以通过模型的AUC(Area Under Curve)和模型的损失值,确定Bert网络的优选层数。在本申请实施例中将优选的层数称为目标层数。例如,在训练之初,何种层数的Bert网络能够在训练后,使得模型的AUC较高,损失值较低,是并不确定的。通过依据AUC和损失值,可以确定出目标层数。
上述Area具体指操作者接收特性曲线(receiver operating characteristiccurve,ROC)下与坐标轴围成的面积。AUC这个面积的数值不会大于1。又由于ROC曲线一般都处于y=x这条直线的上方,所以AUC的取值范围在0.5和1之间。AUC越接近1.0,检测方法真实性越高;AUC等于0.5时,则真实性最低,无应用价值。
在本申请中,点击率预估模型的训练方法可以进一步包括:
以AUC作为准确率评价指标,获得Bert网络具有不同层数时模型的AUC和损失值;根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数。如此,也实现了对Bert网络的层数的调优。
通过调优Bert网络的层数,提升模型性能,并增加模型的可训练性,加快了模型的训练速度,同时提升模型的训练上限。是模型具有更好的容错和鲁棒性。通过AUC能够用来评估新的模型结构相比于现有的xDeepFM模型是否有改进和提高,以及改进的具体幅度。
作为可选实现方式,本申请中除了训练数据集,还建立了测试数据集。测试数据集中也包含样本特征组,测试数据集中的样本特征组与训练数据集中的样本特征组存在差异。在调优Bert网络的层数时,具体可以基于训练数据集以及测试数据集分别的AUC和损失值来执行:
利用训练数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在训练数据集的第一AUC和第一平均损失值;利用测试数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在测试数据集的第二AUC和第二平均损失值;根据Bert网络具有不同层数时模型的第一AUC、第一平均损失值、第二AUC和第二平均损失值,确定Bert网络的目标层数。
此外,上述AUC也可以是在每一种Bert层数下的最高AUC,也可以是每一种Bert层数下这些不同的样本特征组的平均AUC。具体视需求而定。
实际应用中,取不同的层数1~5进行上述调优测试。模型在训练5轮之后,模型的AUC值趋于稳定并且能够持续小幅度提升。利用原有的xDeepFM模型,训练数据集平均损失达到了0.0315,以AUC衡量的准确率达到了96.25%,测试数据集平均损失为0.0422,准确率达到92.07%。
在采用图2所示的模型结构的前提下,相比于原有的xDeepFM模型,训练难度增加,但是模型的收敛上限有所提高,且这种趋势随着Bert网络的层数的增加变得更加明显。测试中当Bert网络层数为两层时,达到了训练模型的最优结果,平均训练数据集的损失能够达到0.0306,准确率达到96.41%。测试数据集的损失达到了0.0423,准确率达到92.17%。准确率在训练数据集和测试数据集上分别得到了0.16%和0.1%性能的提升。在千万级的数据集规模上,这种提升是较为显著的。
结合图4所示的模型数据流,对输入的样本特征组构建one-hot编码作为独热编码后的特征,以CAN网络+Bert网络进行特征嵌入(embedding),分别在输入到后续的CIN网络和DNN网络中进行训练。将CIN网络和DNN网络的输出结果以及输入层输出的one-hot编码后的特征一并输入到输出层。输出层利用sigmoid函数进行激活并计算交叉熵损失。随后可以通过AdaGrad自适应梯度最小化参数,随后得到输出结果,可以计算AUC。计算AUC后和/或计算得到损失值后,可以对模型中嵌入层的网络参数以及CIN和DNN网络的参数进行调整,以使模型具备更优的准确率。除此之外,本申请还可以采用softmax函数作为激活函数使用。
当模型训练完成后,可以获取待预估的视频的特征组,将待预估的视频的特征组输入到点击率预估模型中,获得点击率预估模型输出的点击预估结果。如此,完成了对于所训练的点击率预估模型的应用。
基于前述实施例提供的点击率预估模型的训练方法,相应地,本申请还提供了一种点击率预估模型的训练装置。图5为一种点击率预估模型的训练装置的结构示意图。如图5所示,该点击率预估模型的训练装置包括:
嵌入层构建模块,用于以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络;
独热编码模块,用于在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;
特征嵌入模块,用于将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;
交互信息学习模块,用于采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;
参数调整模块,用于基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。
由于Bert网络发生的特征交互是基于元素级的,因此注意到特征向量内部特征信息的相互作用,加强特征的挖掘潜力,提升模型预估效果的准确性。
可选地,所述参数调整模块,具体用于:
在所述新的模型架构中,在输出层利用激活函数基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出进行运算,得到模型激活后的预测值;
利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;
根据所述损失值对网络参数进行调整,得到点击率预估模型。
可选地,点击率预估模型的训练装置还包括:
目标层数确定模块,用于以操作者接收特性曲线下与坐标轴围成的面积AUC作为准确率评价指标,获得Bert网络具有不同层数时模型的AUC和损失值;根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数。
可选地,目标层数确定模块,具体用于:
利用训练数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述训练数据集的第一AUC和第一平均损失值;
利用测试数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述测试数据集的第二AUC和第二平均损失值;
根据Bert网络具有不同层数时模型的第一AUC、第一平均损失值、第二AUC和第二平均损失值,确定Bert网络的目标层数。
可选地,所述参数调整模块,具体用于:
采用交叉熵损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;交叉熵损失函数计算损失值L的公式如下:
Figure 75929DEST_PATH_IMAGE001
其中,yi表示第i个所述样本特征组对应的标签,ŷi表示基于第i个样本特征组得到的预测值,N表示样本特征组的组数,i表示样本特征组的序数。
可选地,所述激活函数为sigmoid激活函数;所述预测值为通过以下公式运算得到的:
Figure 147790DEST_PATH_IMAGE002
其中,σ表示sigmoid激活函数,α表示所述输入层的独热编码后的特征;b表示全局偏置项;公式中参数上的T表示对参数的转置;
wLinear表示线性层对应的网络参数;
wCAN-DNN表示DNN网络对应于CAN网络的网络参数;
wCAN-CIN表示CIN网络对应于CAN网络的网络参数;
wBert-DNN表示DNN网络对应于Bert网络的网络参数;
wBert-CIN表示CIN网络对应于Bert网络的网络参数;
wLinear、wCAN-DNN、wCAN-CIN、wBert-DNN、wBert-CIN和b均为可学习的网络参数;
xk CAN-DNN表示经过所述CAN网络进行特征嵌入后进一步由DNN网络的输出;
p+表示经过所述CAN网络进行特征嵌入后进一步由CIN网络的输出;
xk Bert-DNN表示经过所述Bert网络进行特征嵌入后进一步由DNN网络的输出;
p++表示经过所述Bert网络进行特征嵌入后进一步由CIN网络的输出。
可选地,点击率预估模型的训练装置还包括:
特征组获取模块,用于获取待预估的视频的特征组;
预估模块,用于将所述待预估的视频的特征组输入到所述点击率预估模型中,获得所述点击率预估模型输出的点击预估结果。
需要说明的是,本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元提示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
以上所述,仅为本申请的一种具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应该以权利要求的保护范围为准。

Claims (10)

1.一种点击率预估模型的训练方法,其特征在于,包括:
以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络;
在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;
将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;
采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;
基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。
2.根据权利要求1所述的训练方法,其特征在于,所述基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型,具体包括:
在所述新的模型架构中,在输出层利用激活函数基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出进行运算,得到模型激活后的预测值;
利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;
根据所述损失值对网络参数进行调整,得到点击率预估模型。
3.根据权利要求2所述的训练方法,其特征在于,所述方法还包括:
以操作者接收特性曲线下与坐标轴围成的面积AUC作为准确率评价指标,获得Bert网络具有不同层数时模型的AUC和损失值;
根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数。
4.根据权利要求3所述的训练方法,其特征在于,所述获得Bert网络具有不同层数时模型的AUC和损失值,具体包括:
利用训练数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述训练数据集的第一AUC和第一平均损失值;
利用测试数据集中的样本特征组和对应的标签获取Bert网络具有不同层数时模型在所述测试数据集的第二AUC和第二平均损失值;
所述根据Bert网络具有不同层数时模型的AUC和损失值,确定Bert网络的目标层数,具体包括:
根据Bert网络具有不同层数时模型的第一AUC、第一平均损失值、第二AUC和第二平均损失值,确定Bert网络的目标层数。
5.根据权利要求2所述的训练方法,其特征在于,所述利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值,具体包括:
采用交叉熵损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;交叉熵损失函数计算损失值L的公式如下:
Figure 583396DEST_PATH_IMAGE001
其中,yi表示第i个所述样本特征组对应的标签,ŷi表示基于第i个样本特征组得到的预测值,N表示样本特征组的组数,i表示样本特征组的序数。
6.根据权利要求2所述的训练方法,其特征在于,所述激活函数为sigmoid激活函数;所述预测值为通过以下公式运算得到的:
Figure 311180DEST_PATH_IMAGE002
其中,σ表示sigmoid激活函数,α表示所述输入层的独热编码后的特征;b表示全局偏置项;公式中参数上的T表示对参数的转置;
wLinear表示线性层对应的网络参数;
wCAN-DNN表示DNN网络对应于CAN网络的网络参数;
wCAN-CIN表示CIN网络对应于CAN网络的网络参数;
wBert-DNN表示DNN网络对应于Bert网络的网络参数;
wBert-CIN表示CIN网络对应于Bert网络的网络参数;
wLinear、wCAN-DNN、wCAN-CIN、wBert-DNN、wBert-CIN和b均为可学习的网络参数;
xk CAN-DNN表示经过所述CAN网络进行特征嵌入后进一步由DNN网络的输出;
p+表示经过所述CAN网络进行特征嵌入后进一步由CIN网络的输出;
xk Bert-DNN表示经过所述Bert网络进行特征嵌入后进一步由DNN网络的输出;
p++表示经过所述Bert网络进行特征嵌入后进一步由CIN网络的输出。
7.根据权利要求1-6任一项所述的训练方法,其特征在于,所述样本特征组包括:原始特征、计数特征、标签均值特征、nunique特征、视频特征、音频特征和标题特征;
其中,所述原始特征包括:用户编号,用户所在城市,视频编号,作者编号,视频所在城市,背景音乐编号,播放次数和视频持续时长;
所述计数特征包括:用户编号,播放次数,视频编号,作者编号,用户-作者编号;
所述标签均值特征包括:用户编号,播放次数,视频编号,用户-作者编号,用户播放次数和播放渠道;
所述nunique特征包括:用户-城市编号,用户-视频编号,用户-作者编号,用户-音乐编号,视频-城市编号,视频-用户编号和作者-用户编号。
8.根据权利要求1-6任一项所述的训练方法,其特征在于,还包括:
获取待预估的视频的特征组;
将所述待预估的视频的特征组输入到所述点击率预估模型中,获得所述点击率预估模型输出的点击预估结果。
9.一种点击率预估模型的训练装置,其特征在于,包括:
嵌入层构建模块,用于以xDeepFM模型为基础,利用Bert网络构建新的嵌入层,得到新的模型架构;所述新的嵌入层中包括所述xDeepFM模型的嵌入层中原有的CAN网络和所述Bert网络;
独热编码模块,用于在所述新的模型架构中,在输入层对训练数据集中的样本特征组进行独热编码;
特征嵌入模块,用于将所述输入层输出的独热编码后的特征输入到所述新的嵌入层中,由所述CAN网络和所述Bert网络分别进行特征嵌入;
交互信息学习模块,用于采用所述新的模型架构中的CIN网络和DNN网络分别基于所述CAN网络的输出进行学习,并采用所述CIN网络和所述DNN网络分别基于所述Bert网络的输出进行学习;
参数调整模块,用于基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出,对所述新的模型架构进行网络参数调整,得到点击率预估模型。
10.根据权利要求9所述的训练装置,其特征在于,所述参数调整模块,具体用于:
在所述新的模型架构中,在输出层利用激活函数基于所述输入层的输出、所述CIN网络的输出和所述DNN网络的输出进行运算,得到模型激活后的预测值;
利用损失函数基于所述预测值和所述样本特征组对应的标签得到模型的损失值;
根据所述损失值对网络参数进行调整,得到点击率预估模型。
CN202211533551.5A 2022-12-01 2022-12-01 一种点击率预估模型的训练方法及相关装置 Active CN115563510B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211533551.5A CN115563510B (zh) 2022-12-01 2022-12-01 一种点击率预估模型的训练方法及相关装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211533551.5A CN115563510B (zh) 2022-12-01 2022-12-01 一种点击率预估模型的训练方法及相关装置

Publications (2)

Publication Number Publication Date
CN115563510A true CN115563510A (zh) 2023-01-03
CN115563510B CN115563510B (zh) 2023-04-07

Family

ID=84770589

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211533551.5A Active CN115563510B (zh) 2022-12-01 2022-12-01 一种点击率预估模型的训练方法及相关装置

Country Status (1)

Country Link
CN (1) CN115563510B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115994632A (zh) * 2023-03-24 2023-04-21 北京搜狐新动力信息技术有限公司 一种点击率预测方法、装置、设备及可读存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210264272A1 (en) * 2018-07-23 2021-08-26 The Fourth Paradigm (Beijing) Tech Co Ltd Training method and system of neural network model and prediction method and system
KR20210137643A (ko) * 2020-05-11 2021-11-18 네이버 주식회사 쇼핑 검색을 위한 상품 속성 추출 방법
CN113971404A (zh) * 2021-10-29 2022-01-25 中南民族大学 一种基于解耦注意力的文物安全命名实体识别方法
CN114154565A (zh) * 2021-11-18 2022-03-08 北京科技大学 一种基于多层次特征交互的点击率预测方法及装置
CN115048855A (zh) * 2022-05-06 2022-09-13 南宁师范大学 点击率预测模型及其训练方法与应用装置
CN115222066A (zh) * 2022-07-21 2022-10-21 中国平安人寿保险股份有限公司 模型训练方法和装置、行为预测方法、设备及存储介质
CN115239429A (zh) * 2022-07-29 2022-10-25 广州华多网络科技有限公司 属性信息编码方法及其装置、设备、介质、产品

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210264272A1 (en) * 2018-07-23 2021-08-26 The Fourth Paradigm (Beijing) Tech Co Ltd Training method and system of neural network model and prediction method and system
KR20210137643A (ko) * 2020-05-11 2021-11-18 네이버 주식회사 쇼핑 검색을 위한 상품 속성 추출 방법
CN113971404A (zh) * 2021-10-29 2022-01-25 中南民族大学 一种基于解耦注意力的文物安全命名实体识别方法
CN114154565A (zh) * 2021-11-18 2022-03-08 北京科技大学 一种基于多层次特征交互的点击率预测方法及装置
CN115048855A (zh) * 2022-05-06 2022-09-13 南宁师范大学 点击率预测模型及其训练方法与应用装置
CN115222066A (zh) * 2022-07-21 2022-10-21 中国平安人寿保险股份有限公司 模型训练方法和装置、行为预测方法、设备及存储介质
CN115239429A (zh) * 2022-07-29 2022-10-25 广州华多网络科技有限公司 属性信息编码方法及其装置、设备、介质、产品

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115994632A (zh) * 2023-03-24 2023-04-21 北京搜狐新动力信息技术有限公司 一种点击率预测方法、装置、设备及可读存储介质

Also Published As

Publication number Publication date
CN115563510B (zh) 2023-04-07

Similar Documents

Publication Publication Date Title
CN110474815B (zh) 带宽预测方法、装置、电子设备及存储介质
CN110097755A (zh) 基于深度神经网络的高速公路交通流量状态识别方法
CN111126202A (zh) 基于空洞特征金字塔网络的光学遥感图像目标检测方法
CN110399850A (zh) 一种基于深度神经网络的连续手语识别方法
CN111986180B (zh) 基于多相关帧注意力机制的人脸伪造视频检测方法
CN115563510B (zh) 一种点击率预估模型的训练方法及相关装置
CN110349597A (zh) 一种语音检测方法及装置
CN106503853A (zh) 一种基于多标度卷积神经网络的外汇交易预测模型
CN112488243A (zh) 一种图像翻译方法
CN115527150A (zh) 一种结合卷积注意力模块的双分支视频异常检测方法
CN114154016A (zh) 基于目标空间语义对齐的视频描述方法
CN115373879A (zh) 一种面向大规模云数据中心智能运维的磁盘故障预测方法
CN114926742A (zh) 一种基于二阶注意力机制的回环检测及优化方法
CN111144462A (zh) 一种雷达信号的未知个体识别方法及装置
CN112967227B (zh) 基于病灶感知建模的糖尿病视网膜病变自动评估系统
CN112417890B (zh) 一种基于多样化语义注意力模型的细粒度实体分类方法
CN102222237B (zh) 手语视频的相似度评估模型的建立方法
CN113034940A (zh) 一种基于Fisher有序聚类的单点信号交叉口优化配时方法
CN114494999B (zh) 一种双分支联合型目标密集预测方法及系统
CN115510915A (zh) 一种基于门控循环网络原理的已知雷达信号分选方法
Sun et al. MMINR: Multi-frame-to-multi-frame inference with noise resistance for precipitation nowcasting with radar
Zhong et al. Encoding broad learning system: An effective shallow model for anti-fraud
CN114529794A (zh) 一种红外与可见光图像融合方法、系统及介质
CN114528762A (zh) 一种模型训练方法、装置、设备和存储介质
Zhou et al. Weakly perceived object detection based on an improved CenterNet

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant