CN116975686A - 训练学生模型的方法、行为预测方法和装置 - Google Patents
训练学生模型的方法、行为预测方法和装置 Download PDFInfo
- Publication number
- CN116975686A CN116975686A CN202310907307.9A CN202310907307A CN116975686A CN 116975686 A CN116975686 A CN 116975686A CN 202310907307 A CN202310907307 A CN 202310907307A CN 116975686 A CN116975686 A CN 116975686A
- Authority
- CN
- China
- Prior art keywords
- bridge
- vector
- model
- node
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 64
- 238000012549 training Methods 0.000 title claims abstract description 28
- 239000013598 vector Substances 0.000 claims abstract description 226
- 238000000605 extraction Methods 0.000 claims abstract description 149
- 238000004821 distillation Methods 0.000 claims abstract description 67
- 230000002452 interceptive effect Effects 0.000 claims abstract description 49
- 239000002131 composite material Substances 0.000 claims abstract description 6
- 230000006399 behavior Effects 0.000 claims description 74
- 230000003993 interaction Effects 0.000 claims description 38
- 230000002776 aggregation Effects 0.000 claims description 11
- 238000004220 aggregation Methods 0.000 claims description 11
- 230000004931 aggregating effect Effects 0.000 claims description 10
- 238000013528 artificial neural network Methods 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 14
- 238000013140 knowledge distillation Methods 0.000 description 13
- 230000005012 migration Effects 0.000 description 11
- 238000013508 migration Methods 0.000 description 11
- 230000000670 limiting effect Effects 0.000 description 5
- 230000006870 function Effects 0.000 description 4
- 238000013459 approach Methods 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000007547 defect Effects 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 230000003542 behavioural effect Effects 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003786 synthesis reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q50/00—Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
- G06Q50/10—Services
- G06Q50/20—Education
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Business, Economics & Management (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Tourism & Hospitality (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Educational Administration (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Educational Technology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Economics (AREA)
- Human Resources & Organizations (AREA)
- Marketing (AREA)
- Primary Health Care (AREA)
- Strategic Management (AREA)
- General Business, Economics & Management (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本说明书实施例提供了一种训练学生模型的方法、行为预测方法和装置,该方法基于教师模型以及桥模型来训练学生模型,该方法包括:将原始特征数据输入第一嵌入层,得到第一嵌入特征并输入训练好的教师模型、以及桥模型,得到第一和第二预测结果并根据其更新桥模型;将第一嵌入特征输入桥模型所包括的桥特征提取网络、以及学生模型包括的学生特征提取网络,得到第一桥提取向量和学生提取向量,并根据其确定交互蒸馏损失;将原始特征数据输入第二嵌入层,得到第二嵌入特征并输入桥特征提取网络,得到第二桥提取向量;根据第一和第二桥提取向量,确定嵌入蒸馏损失;至少根据交互蒸馏损失和嵌入蒸馏损失,确定综合损失并根据其更新学生模型。
Description
技术领域
本说明书一个或多个实施例涉及机器学习和行为预测领域,尤其涉及训练学生模型的方法、行为预测方法和装置。
背景技术
行为预测(Cl ick Through Rate,CTR)模型,或称点击率预测模型,通常用于对用户针对目标对象的用户行为进行预测。目前,随着CTR模型的发展,其模型结构的复杂度和在线服务的耗时均会增加,会导致限制CTR模型的应用范围。对此,一种现有解决方案通过知识蒸馏(Knowledge Distillation)技术,将具有相对复杂网络结构的教师模型中的知识(Knowledge)迁移到具有相对简单的网络结构的学生模型中,并利用学生模型对用户行为进行实际预测。但是,在一些场景中,由于学生模型容量与老师的模型容量差异较大或者模型结构的差异较大,使用现有的知识蒸馏方案进行知识迁移的效率较低,迁移后学生模型的用户行为预测能力不足。
发明内容
本说明书中的实施例旨在提供一种训练学生模型的方法、行为预测方法和装置,提高了将教师模型中的知识迁移到学生模型的效率,进而提高了迁移后学生模型的用户行为预测能力,解决现有技术中的不足。
根据第一方面,提供了一种训练学生模型的方法,包括:
将目标用户与目标对象的原始特征数据输入训练好的第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型、以及桥模型,分别得到关于目标用户针对目标对象实施预定行为的第一预测结果和第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;
将所述第一嵌入特征输入更新后的桥模型所包括的桥特征提取网络,得到第一桥提取向量;以及输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
将所述原始特征数据输入学生模型包括的第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测所述预定行为。
在一种可能的实施方式中,所述学生模型还包括学生预测网络;
所述方法还包括:
将学生提取向量输入学生预测网络,得到关于所述预定行为的第三预测结果,根据第三预测结果和所述目标用户针对目标对象的行为标签,得到基础分类损失;
根据第一预测结果和第三预测结果,得到师生差异损失;
至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,包括:
根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失,确定综合损失。
在一种可能的实施方式中,根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失,确定综合损失,包括:
根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失的加权和,确定综合损失。
在一种可能的实施方式中,所述桥特征提取网络通过图神经网络实现;所述原始特征数据包括多个字段的特征数据;
将所述第一嵌入特征输入更新后的桥模型所包括的桥特征提取网络,得到第一桥提取向量,包括:
将所述第一嵌入特征包括的对应于所述多个字段的多个子特征,作为无向图的多个节点对应的初阶向量;
通过所述图神经网络对所述多个节点进行特征交互,得到各个节点的多个阶的向量;聚合各个节点的多个阶的向量,得到第一桥提取向量。
在一种可能的实施方式中,所述特征交互包括多轮迭代,任意一轮迭代包括:
根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量;
根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量。
在一种可能的实施方式中,根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量,包括:
对于多个节点中任意的第一节点和第二节点,将第一节点的上一阶向量和第二节点的上一阶向量的向量和、与第一节点的上一阶向量和第二节点的上一阶向量的哈德玛Hadamard积进行级联,得到本轮第一节点和第二节点之间的交互消息向量。
在一种可能的实施方式中,根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量,包括:
对于多个节点中任意的第三节点,根据第三节点的上一阶向量与其它节点的上一阶向量之间的交互消息向量与关系权重的乘积之和,得到第一更新向量,对第三节点的上一阶向量和第一更新向量级联,得到第三节点向量的本阶向量。
在一种可能的实施方式中,聚合各个节点的多个阶的向量,得到第一桥提取向量,包括:
对各个节点的多个阶的向量进行级联,得到第一桥提取向量。
在一种可能的实施方式中,桥特征提取网络包括交互消息子网路、节点更新子网络和节点聚合子网络;
根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量,包括:通过交互消息子网路,根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量;
根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量,包括:通过节点更新子网络,根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量;
聚合各个节点的多个阶的向量,得到第一桥提取向量,包括:通过节点聚合子网络,聚合各个节点的多个阶的向量,得到第一桥提取向量。
在一种可能的实施方式中,所述交互消息子网路、节点更新子网络和节点聚合子网络,基于多层感知机MLP构建。
在一种可能的实施方式中,根据第一预测结果和第二预测结果,更新所述桥模型,包括:
根据第一预测结果和第二预测结果之差的平方,确定第一损失;
根据所述第一损失,更新所述桥模型。
在一种可能的实施方式中,所述教师模型的模型容量大于所述学生模型。
在一种可能的实施方式中,所述目标对象包括商品、广告、短视频文章中的一种或多种;所述预定行为包括点击、购买、播放、收藏中的一种或多种。
根据第二方面,提供了一种行为预测方法,包括:
将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据第一方面所述方法训练好的学生模型;
将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
根据第三方面,提供了一种训练学生模型的装置,包括:
桥模型更新单元,配置为,将目标用户与目标对象的原始特征数据输入训练好的第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型、以及桥模型,分别得到关于目标用户针对目标对象实施预定行为的第一预测结果和第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;
交互损失确定单元,配置为,将所述第一嵌入特征输入更新后的桥模型所包括的桥特征提取网络,得到第一桥提取向量;以及输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
嵌入损失确定单元,配置为,将所述原始特征数据输入学生模型包括的第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
学生模型更新单元,配置为,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测所述预定行为。
根据第四方面,提供了一种行为预测装置,包括:
输入单元,配置为,将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据第一方面所述的方法训练好的学生模型;
预测单元,配置为,将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
根据第五方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行第一、第二方面所述的方法。
根据第六方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现第一、第二方面所述的方法。
利用以上各个方面中的方法、装置、计算设备、存储介质中的一个或多个,通过教师模型和学生模型之间构建的桥模型,提高将教师模型中的知识迁移到学生模型的效率,进而提高了迁移后学生模型的用户行为预测能力,解决现有技术中的不足。
附图说明
为了更清楚说明本发明实施例的技术方案,下面将对实施例描述中所需使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出一种知识蒸馏方案的示意图;
图2示出根据本说明书实施例的一种训练学生模型的方法的原理示意图;
图3示出根据本说明书实施例的一种训练学生模型的方法的流程图;
图4示出根据本说明书实施例的训练桥模型的示意图;
图5示出根据本说明书实施例的确定交互蒸馏损失的示意图;
图6示出根据本说明书实施例的确定嵌入蒸馏损失的示意图;
图7示出根据本说明书实施例的确定师生差异损失的示意图;
图8示出根据本说明书实施例的确定交互蒸馏损失的示意图;
图9示出根据本说明书实施例的确定基础分类损失的示意图;
图10示出根据本说明书实施例的一种行为预测方法的流程图;
图11示出根据本说明书实施例的一种训练学生模型的装置的结构图;
图12示出根据本说明书实施例的一种行为预测装置的结构图。
具体实施方式
下面将结合附图,对本说明书提供的方案进行描述。
如前所述,随着行为预测(Cl ick Through Rate,CTR)模型,或称点击率预测模型的发展,其模型结构的复杂度和在线服务的耗时均会增加,从到导致限制CTR模型的应用范围。对此,一种现有解决方案通过知识蒸馏(Knowledge Distillation)技术,将具有相对复杂网络结构的教师模型中的知识(Knowledge)迁移到具有相对简单的网络结构的学生模型中,并利用学生模型对用户行为进行实际预测。例如,图1示出一种知识蒸馏方案的示意图。如图1所示,教师模型和学生模型均为行为预测模型,其中,教师模型可以是已经训练好的、且模型容量较大的行为预测模型。可以通过教师模型,训练结构相对简单的学生模型。具体的,可以将训练样本分别输入教师模型和学生模型,根据教师模型的输出的预测结果与学生模型输出的预测结果,确定学生模型的识别损失,并根据识别损失更新学生模型。或者说,该方案利用教师模型的预测结果作为输入学生模型的训练样本的预测标签,使得在学生模型的训练过程中,学生模型针对训练样本的输出的预测结果趋向于教师模型的预测结果。但是,在一些场景中,使用现有的知识蒸馏方案进行知识迁移仍旧存在如下问题:例如由于学生模型与教师模型的模型结构差异,导致学生模型的模型容量与教师模型的模型容量差距较大时,使用现有的知识蒸馏方案进行知识迁移仍然存在迁移效率不足的问题,或者说,进行知识迁移后存在学生模型的行为预测能力不足的问题。
为了解决上述技术问题,本说明书实施例提供了一种训练学生模型的方法,以及相应的行为预测方法及装置。图2示出根据本说明书实施例的一种训练学生模型的方法的原理示意图。如图2所示,该训练方法主要通过包括教师模型、桥模型和学生模型在内的三个预测模型实施,其中,每个预测模型均可以包括一个特征提取网络和一个预测网络,例如,教师模型包括特征提取网络1和预测网络1,桥模型包括特征提取网络2和预测网络2,学生模型包括特征提取网络3和预测网络3。教师模型和桥模型可以共享一个嵌入层(例如为嵌入层1),而学生模型可以使用单独的嵌入层(例如为嵌入层2)。在一个实施例中,例如包括多个字段的原始特征,可以分别通过例如嵌入层1或嵌入层2被转换为维度较低而数据密集度较高的嵌入向量。如图2所示,具体的训练过程可以分为两个阶段:在第一阶段,通过教师模型向桥模型进行知识迁移。具体的,例如可以以桥模型的输出结果与教师模型的输出结果趋近为目标,更新桥模型。在第二阶段,通过桥模型向学生模型进行知识迁移。具体的,例如可以根据更新后的桥模型、以及学生模型包括的特征提取网络(例如,特征提取网络2和特征提取网络3,其中,特征提取网络2例如可以基于图神经网络GNN)提取的特征向量的差异,确定出特征交互损失(或称,交互蒸馏损失)。以及,分别将通过桥模型的嵌入层(例如嵌入层1)和学生网络的嵌入层(例如嵌入层2)得到的嵌入向量输入桥模型的特征提取网络,根据输出的特征向量的差异,确定出嵌入损失(或称,嵌入蒸馏损失)。至少根据交互蒸馏损失和嵌入蒸馏损失,更新学生模型。其本质所起的作用,是可以从桥模型中解耦出与特征交互相关的知识(或简称特征交互知识)、以及与特征嵌入相关的知识(或简称嵌入知识),然后将其迁移到学生模型。此后,例如在在线预测阶段,可以仅使用更新后的学生模型进行针对用户行为的在线预测。
使用该方法的优点在于:一方面,可以利用例如基于GNN的桥模型识别出特征中显著的交互关系,提高特征的表示能力并降低特征中的冗余信息。进而以桥模型作为教师和学生模型之间知识蒸馏的中间媒介层,可以提高了教师模型和学生模型之间的知识蒸馏效率,从而解决教师模型和学生模型之间的容量差距而导致的知识蒸馏效率较为低效的问题。另一方面,在从桥模型向学生模型知识迁移的过程中,从桥模型中解耦出特征交互知识和特征嵌入知识并分别传递,以提高嵌入层和特征提取网络各自知识迁移的充分性,并减少了知识迁移后嵌入层和特征提取网络之间的相互作用,更加提高了知识蒸馏的效率。即提高了更新后学生模型的用户行为识别能力。
下面阐述该方法的详细过程。图3示出根据本说明书实施例的一种训练学生模型的方法的流程图。如图3所示,该方法至少包括如下步骤:
步骤S301,将目标用户与目标对象的原始特征数据输入第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型得到第一预测结果,且将第一嵌入特征输入桥模型得到第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;所述桥模型包括桥特征提取网络和桥预测网络;
步骤S303,将所述第一嵌入特征输入所述桥特征提取网络,得到第一桥提取向量;以及将所述第一嵌入特征输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
步骤S305,将所述原始特征数据输入第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
步骤S307,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测预定行为。
首先,在步骤S301,将目标用户与目标对象的原始特征数据输入第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型、以及桥模型,分别得到第一预测结果和第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型。
在不同的实施例中,目标用户可以是不同的具体用户,本说明书对此不做限制。在不同的实施例中,目标对象、以及目标用户针对目标对象实施的预定行为也可以是不同的具体对象或行为。在一个实施例中,目标对象可以包括商品、广告、短视频文章中的一种或多种。在一个实施例中,预定行为可以包括点击、购买、播放、收藏中的一种或多种。原始特征数据中还可以包括多个字段的特征数据。在不同的实施例中,原始特征数据中包括的具体字段可以不同,本说明书对此不做限制。
教师模型和桥模型可以均为行为预测模型,其中,教师模型可以是预先训练好的行为预测模型。教师模型和桥模型可以具有各自的特征提取网络和预测网络。在不同的实施例中,教师模型的特征提取网络的具体网络结构可以不同,本说明书对此不做限制。为了方便后续步骤中提取特征交互向量,在一个实施例中,桥模型的特征提取网络可以基于图神经网络(Graph Neural Network,GNN)。
通过嵌入(embedding)层,可以将原始特征数据转换为后续训练中使用的特征向量。在一个具体的例子中,例如可以通过嵌入层对于原始特征数据进行降维,得到后续训练中使用的特征向量。教师模型和桥模型还可以使用同一嵌入层,例如训练好的第一嵌入层,输出的嵌入向量(例如,第一嵌入特征),作为它们各自的特征提取网络的输入向量。
在不同实施例中,根据第一预测结果和第二预测结果,更新所述桥模型的具体方式可以不同。如4图所示,在一个实施例中,可以根据第一预测结果和第二预测结果之差的平方,确定第一损失;根据所述第一损失,更新所述桥模型。
以及,可以在步骤S303,将所述第一嵌入特征输入桥模型所包括的桥特征提取网络(例如,图5中特征提取网络2),得到第一桥提取向量;以及输入学生模型包括的学生特征提取网络(例如,图5中特征提取网络3),得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失,如图5所示。学生模型也可以是行为预测模型,并包括其自己的特征提取网络和预测网络。在一个实施例中,由于学生特征提取网络的输入向量的维度可以小于桥特征提取网络的输入向量的维度,为了适配学生特征提取网络的输入向量的维度,可以对第一嵌入特征进行线性转换,使其维度等于学生特征提取网络的输入向量的维度之后,输入学生特征提取网络。
模型容量,通常是指模型拟合各种函数的能力。如前所述,相对于学生模型,教师模型的结构通常更为复杂,网络参数数量更多。所以,教师模型相对于学生模型,通常可以具有更强的函数拟合能力,即具有更大的模型容量。因此,在一个实施例中,教师模型的模型容量可以大于所述学生模型。
在不同的实施例中,获取第一桥提取向量的具体方式可以不同。如前所述,在一个实施例中,桥特征提取网络可以通过图神经网络实现,所述原始特征数据可以包括多个字段的特征数据。进而,在该实施例中,可以通过以下方式得到第一桥提取向量:将所述第一嵌入特征包括的对应于所述多个字段的多个子特征,作为无向图的多个节点对应的初阶向量;通过所述图神经网络对所述多个节点进行特征交互,得到各个节点的多个阶的向量;聚合各个节点的多个阶的向量,得到第一桥提取向量。在一个例子中,聚合各个节点的多个阶的向量,得到第一桥提取向量,可以表示为:
其中,v表示节点的向量,i表示节点序数,p表示向量阶数,N表示节点数量,P表示向量最高阶数,hb表示第一桥提取向量,施加于向量的[]表示向量级联操作。在不同的实施例中,进行聚合各个节点的多个阶的向量的方式可以不同。在一个实施例中,可以对各个节点的多个阶的向量进行级联,得到第一桥提取向量。
在不同的具体实施例中,进行特征交互的方式可以不同。在一个实施例中,特征交互可以包括多轮迭代,任意一轮迭代可以包括:根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量;根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量。在一个具体的实施例中,可以通过以下方式得到各个节点的本阶向量:对于多个节点中任意的第三节点,根据第三节点的上一阶向量与其它节点的上一阶向量之间的交互消息向量与关系权重的乘积之和,得到第一更新向量,对第三节点的上一阶向量和第一更新向量级联,得到第三节点向量的本阶向量。在一个例子中,任意一轮迭代可以表示为:
其中,j表示节点序数,m表示节点之间的交互消息向量(例如表示序数为i的节点和序数为j的节点之间的p+1阶交互消息向量)。w表示节点之间的关系权重(例如wij表示序数为i的节点和序数为j的节点之间的关系权重),在一个具体的例子中,wij可以表示为:
其中τ是温度系数。在一个例子中,当τ趋近于0时,权重wij趋近于0或1,使得节点之间的交互关系更具有区分性。θij表示邻接矩阵θ中第i行j列元素。
在一个具体的实施例中,可以通过以下方式确定出上述一轮的交互消息向量:对于多个节点中任意的第一节点和第二节点,将第一节点的上一阶向量和第二节点的上一阶向量的向量和、与第一节点的上一阶向量和第二节点的上一阶向量的哈德玛Hadamard积进行级联,得到本轮第一节点和第二节点之间的交互消息向量。在一个例子中,确定上述一轮的交互消息向量可以表示为:
其中,+表示向量加,⊙表示哈德玛(Hadamard)乘。
在不同的实施例中,桥特征提取网络的具体网络结构可以不同。在一个实施例中,桥特征提取网络可以包括交互消息子网路、节点更新子网络和节点聚合子网络,如图6所示。具体的,可以通过交互消息子网路,根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量。通过节点更新子网络,根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量。通过节点聚合子网络,聚合各个节点的多个阶的向量,得到第一桥提取向量。
在不同的具体实施例中,交互消息子网路、节点更新子网络和节点聚合子网络的具体结构可以不同,在一个具体的实施例中,交互消息子网路、节点更新子网络和节点聚合子网络,可以基于多层感知机(Multilayer Perceptron,MLP)构建。
在不同的实施例中,根据第一桥提取向量和学生提取向量,确定交互蒸馏损失的具体方式可以不同。在一个实施例中,可以根据第一桥提取向量和学生提取向量之差的平方,确定交互蒸馏损失。
并且,在步骤S305,将所述原始特征数据输入第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失,如图7所示。在一个实施例中,由于学生特征提取网络的输入向量的维度可以小于桥特征提取网络的输入向量的维度,为了适配桥特征提取网络的输入向量的维度,可以对第二嵌入特征进行线性转换,使其维度等于桥特征提取网络的输入向量的维度之后,输入桥特征提取网络。
在不同的实施例中,根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失的具体方式可以不同。在一个实施例中,可以根据第一桥提取向量和第二桥提取向量之差的平方,确定嵌入蒸馏损失。
此后,在步骤S307,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测所述预定行为。
在不同的实施例中,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失的具体方式可以不同。在一个实施例中,学生模型还可以包括学生预测网络。进而,可以将学生提取向量输入学生预测网络,得到关于所述预定行为的第三预测结果,根据第三预测结果和所述预定行为的行为标签,得到基础分类损失,如图8所示。以及,根据第一预测结果和第三预测结果,得到师生差异损失,如图9所示。进而,可以根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失,确定综合损失。在一个具体的实施例中,还可以根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失的加权和,确定综合损失。在不同的具体实施例中,基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失各自的加权权重可以是不同的具体权重值,本说明书对此不做限制。在不同的具体实施例中,根据综合损失,更新所述学生模型的具体方式也可以不同。在一个具体的实施例中,可以在确定综合损失后,基于BP(反向传播)算法,更新学生模型的网络参数。
上面描述了本说明书实施例提供的一种训练学生模型的方法。本说明书另一方面的实施例,还提供一种行为预测方法。图10示出根据本说明书实施例的一种行为预测方法的流程图,如图10所示,该方法至少包括如下步骤:
步骤S1001,将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据上述方法训练好的学生模型;
步骤S1002,将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
本说明书另一方面的实施例,还提供一种训练学生模型的装置。图11示出根据本说明书实施例的一种训练学生模型的装置的结构图。如图11所示,该装置1100包括:
桥模型更新单元1101,配置为,将目标用户与目标对象的原始特征数据输入第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型得到第一预测结果,且将第一嵌入特征输入桥模型得到第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;所述桥模型包括桥特征提取网络和桥预测网络;
交互损失确定单元1102,配置为,将所述第一嵌入特征输入所述桥特征提取网络,得到第一桥提取向量;以及将所述第一嵌入特征输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
嵌入损失确定单元1103,配置为,将所述原始特征数据输入第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
学生模型更新单元1104,配置为,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测预定行为。
本说明书另一方面的实施例,还提供一种行为预测装置。图12示出根据本说明书实施例的一种行为预测装置的结构图。如图12所示,该装置1200包括:
输入单元1201,配置为,将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据权利要求1的方法训练好的学生模型;
预测单元1202,配置为,将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
本说明书实施例又一方面提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行上述任一项方法。
本说明书实施例再一方面提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现上述任一项方法。
需要理解,本文中的“第一”,“第二”等描述,仅仅为了描述的简单而对相似概念进行区分,并不具有其他限定作用。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (18)
1.一种训练学生模型的方法,该方法基于教师模型以及桥模型来训练学生模型,包括:
将目标用户与目标对象的原始特征数据输入第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型得到第一预测结果,且将第一嵌入特征输入桥模型得到第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;所述桥模型包括桥特征提取网络和桥预测网络;
将所述第一嵌入特征输入所述桥特征提取网络,得到第一桥提取向量;以及将所述第一嵌入特征输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
将所述原始特征数据输入第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测预定行为。
2.根据权利要求1所述的方法,其中,所述学生模型还包括学生预测网络;
所述方法还包括:
将学生提取向量输入学生预测网络,得到关于所述预定行为的第三预测结果,根据第三预测结果和所述预定行为的行为标签,得到基础分类损失;
根据第一预测结果和第三预测结果,得到师生差异损失;
至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,包括:
根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失,确定综合损失。
3.根据权利要求2所述的方法,其中,根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失,确定综合损失,包括:
根据所述基础损失、师生差异损失、交互蒸馏损失和嵌入蒸馏损失的加权和,确定综合损失。
4.根据权利要求1所述的方法,其中,所述桥特征提取网络通过图神经网络实现;所述原始特征数据包括多个字段的特征数据;
将所述第一嵌入特征输入所述桥特征提取网络,得到第一桥提取向量,包括:
将所述第一嵌入特征包括的对应于所述多个字段的多个子特征,作为无向图的多个节点对应的初阶向量;
通过所述图神经网络对所述多个节点进行特征交互,得到各个节点的多个阶的向量;聚合各个节点的多个阶的向量,得到第一桥提取向量。
5.根据权利要求4所述的方法,其中,所述特征交互包括多轮迭代,任意一轮迭代包括:
根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量;
根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量。
6.根据权利要求5所述的方法,其中,根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量,包括:
对于多个节点中任意的第一节点和第二节点,将第一节点的上一阶向量和第二节点的上一阶向量的向量和、与第一节点的上一阶向量和第二节点的上一阶向量的哈德玛Hadamard积进行级联,得到本轮第一节点和第二节点之间的交互消息向量。
7.根据权利要求5所述的方法,其中,根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量,包括:
对于多个节点中任意的第三节点,根据第三节点的上一阶向量与其它节点的上一阶向量之间的交互消息向量与关系权重的乘积之和,得到第一更新向量,对第三节点的上一阶向量和第一更新向量级联,得到第三节点向量的本阶向量。
8.根据权利要求4所述的方法,其中,聚合各个节点的多个阶的向量,得到第一桥提取向量,包括:
对各个节点的多个阶的向量进行级联,得到第一桥提取向量。
9.根据权利要求8所述的方法,其中,桥特征提取网络包括交互消息子网路、节点更新子网络和节点聚合子网络;
根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量,包括:通过交互消息子网路,根据所述多个节点的上一阶向量,确定出本轮各个节点之间的交互消息向量;
根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量,包括:通过节点更新子网络,根据本轮的所述交互消息向量、各个节点的上一阶向量、以及各个节点之间的关系权重,更新得到各个节点的本阶向量;
聚合各个节点的多个阶的向量,得到第一桥提取向量,包括:通过节点聚合子网络,聚合各个节点的多个阶的向量,得到第一桥提取向量。
10.根据权利要求9所述的方法,其中,所述交互消息子网路、节点更新子网络和节点聚合子网络,基于多层感知机MLP构建。
11.根据权利要求1所述的方法,其中,根据第一预测结果和第二预测结果,更新所述桥模型,包括:
根据第一预测结果和第二预测结果之差的平方,确定第一损失;
根据所述第一损失,更新所述桥模型。
12.根据权利要求1所述的方法,其中,所述教师模型的模型容量大于所述学生模型。
13.根据权利要求1所述的方法,其中,所述目标对象包括商品、广告、短视频文章中的一种或多种;所述预定行为包括点击、购买、播放、收藏中的一种或多种。
14.一种行为预测方法,包括:
将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据权利要求1的方法训练好的学生模型;
将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
15.一种训练学生模型的装置,包括:
桥模型更新单元,配置为,将目标用户与目标对象的原始特征数据输入第一嵌入层,得到第一嵌入特征;将第一嵌入特征分别输入训练好的教师模型得到第一预测结果,且将第一嵌入特征输入桥模型得到第二预测结果,根据第一预测结果和第二预测结果,更新所述桥模型;所述桥模型包括桥特征提取网络和桥预测网络;
交互损失确定单元,配置为,将所述第一嵌入特征输入所述桥特征提取网络,得到第一桥提取向量;以及将所述第一嵌入特征输入学生模型包括的学生特征提取网络,得到学生提取向量;根据第一桥提取向量和学生提取向量,确定交互蒸馏损失;
嵌入损失确定单元,配置为,将所述原始特征数据输入第二嵌入层,得到第二嵌入特征;将所述第二嵌入特征输入所述桥特征提取网络,得到第二桥提取向量;根据第一桥提取向量和第二桥提取向量,确定嵌入蒸馏损失;
学生模型更新单元,配置为,至少根据所述交互蒸馏损失和嵌入蒸馏损失,确定综合损失,根据所述综合损失,更新所述学生模型;所述学生模型用于预测预定行为。
16.一种行为预测装置,包括:
输入单元,配置为,将待测用户与待测对象的原始特征数据输入行为预测模型,所述行为预测模型为根据权利要求1的方法训练好的学生模型;
预测单元,配置为,将所述行为预测模型的输出结果,作为所述待测用户针对所述待测对象施加预定行为的预测结果。
17.一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-14中任一项的所述的方法。
18.一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-14中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310907307.9A CN116975686A (zh) | 2023-07-21 | 2023-07-21 | 训练学生模型的方法、行为预测方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310907307.9A CN116975686A (zh) | 2023-07-21 | 2023-07-21 | 训练学生模型的方法、行为预测方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116975686A true CN116975686A (zh) | 2023-10-31 |
Family
ID=88480859
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310907307.9A Pending CN116975686A (zh) | 2023-07-21 | 2023-07-21 | 训练学生模型的方法、行为预测方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116975686A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117555489A (zh) * | 2024-01-11 | 2024-02-13 | 烟台大学 | 物联网数据存储交易异常检测方法、系统、设备和介质 |
-
2023
- 2023-07-21 CN CN202310907307.9A patent/CN116975686A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117555489A (zh) * | 2024-01-11 | 2024-02-13 | 烟台大学 | 物联网数据存储交易异常检测方法、系统、设备和介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110119467B (zh) | 一种基于会话的项目推荐方法、装置、设备及存储介质 | |
CN111080400B (zh) | 一种基于门控图卷积网络的商品推荐方法及系统、存储介质 | |
CN109544306B (zh) | 一种基于用户行为序列特征的跨领域推荐方法及装置 | |
CN110347932B (zh) | 一种基于深度学习的跨网络用户对齐方法 | |
CN111966914B (zh) | 基于人工智能的内容推荐方法、装置和计算机设备 | |
WO2023221928A1 (zh) | 一种推荐方法、训练方法以及装置 | |
CN113392359A (zh) | 多目标预测方法、装置、设备及存储介质 | |
CN113516522A (zh) | 媒体资源推荐方法、多目标融合模型的训练方法及装置 | |
WO2022252458A1 (zh) | 一种分类模型训练方法、装置、设备及介质 | |
CN113254792A (zh) | 训练推荐概率预测模型的方法、推荐概率预测方法及装置 | |
CN116010684A (zh) | 物品推荐方法、装置及存储介质 | |
CN109189922B (zh) | 评论评估模型的训练方法和装置 | |
CN110020877A (zh) | 点击率的预测方法、点击率的确定方法及服务器 | |
CN116975686A (zh) | 训练学生模型的方法、行为预测方法和装置 | |
CN112819024A (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN113610610B (zh) | 基于图神经网络和评论相似度的会话推荐方法和系统 | |
CN114925270A (zh) | 一种会话推荐方法和模型 | |
CN114580794A (zh) | 数据处理方法、装置、程序产品、计算机设备和介质 | |
CN116910357A (zh) | 一种数据处理方法及相关装置 | |
Zhao et al. | CapDRL: a deep capsule reinforcement learning for movie recommendation | |
CN116541592A (zh) | 向量生成方法、信息推荐方法、装置、设备及介质 | |
CN115795153A (zh) | 一种基于特征交互和分数集成的ctr推荐方法 | |
CN115631008B (zh) | 商品推荐方法、装置、设备及介质 | |
CN109299291A (zh) | 一种基于卷积神经网络的问答社区标签推荐方法 | |
CN115080795A (zh) | 一种多充电站协同负荷预测方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |