CN115270002B - 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质 - Google Patents

一种基于知识蒸馏的物品推荐方法、电子设备及存储介质 Download PDF

Info

Publication number
CN115270002B
CN115270002B CN202211161347.5A CN202211161347A CN115270002B CN 115270002 B CN115270002 B CN 115270002B CN 202211161347 A CN202211161347 A CN 202211161347A CN 115270002 B CN115270002 B CN 115270002B
Authority
CN
China
Prior art keywords
sequence
loss function
recommendation
training
parallel
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211161347.5A
Other languages
English (en)
Other versions
CN115270002A (zh
Inventor
沈利东
沈利辉
赵朋朋
堵瀚文
沈逸旸
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangsu Yiyou Huiyun Software Co ltd
Original Assignee
Jiangsu Yiyou Huiyun Software Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangsu Yiyou Huiyun Software Co ltd filed Critical Jiangsu Yiyou Huiyun Software Co ltd
Priority to CN202211161347.5A priority Critical patent/CN115270002B/zh
Publication of CN115270002A publication Critical patent/CN115270002A/zh
Application granted granted Critical
Publication of CN115270002B publication Critical patent/CN115270002B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/95Retrieval from the web
    • G06F16/953Querying, e.g. by the use of web search engines
    • G06F16/9535Search customisation based on user profiles and personalisation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Databases & Information Systems (AREA)
  • Computing Systems (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请涉及机器学习技术领域,具体涉及一种基于知识蒸馏的物品推荐方法、电子设备及存储介质,所述方法包括:基于用户的物品交互记录构建训练序列和测试序列;基于所述训练序列构建用户候选物品预测模型;将所述测试序列输入所述用户候选物品预测模型进行预测。其中,构建用户候选物品预测模型包括;使用随机数种子对训练序列做遮蔽处理生成若干个遮蔽序列;构建若干序列推荐平行网络,将训练序列和遮蔽序列作为输入分别输入到每个序列推荐平行网络进行预训练并进行知识蒸馏,确定用户候选物品预测模型的损失函数;基于所述损失函数对序列推荐平行网络进行迭代训练获得所述用户候选物品预测模型。本申请的方法,物品推荐准确率高,性能好。

Description

一种基于知识蒸馏的物品推荐方法、电子设备及存储介质
技术领域
本申请涉及机器学习技术领域,特别是涉及一种基于知识蒸馏的物品推荐方法、电子设备及存储介质。
背景技术
推荐系统是基于用户的历史交互信息为用户推荐感兴趣物品的一种系统,推荐系统通常搭建一个用户的动态兴趣模型通过模型的序列编码器捕捉用户的动态兴趣生成用户和物品隐藏表示,再基于用户和物品的隐藏表示推荐物品。现有技术中,都是通过深度神经网络改进单一的序列编码器来精准捕获用户兴趣,然而这种方式由于还是使用单一序列编码器,其使用的深度神经网络只能收敛到一个局部最优点,不能给出多个可能的最优点,其预测和推荐结果存在较大的局限性。
发明内容
为了解决现有技术存在的不足,本申请的目的在于提供一种基于知识蒸馏的物品推荐方法、电子设备及存储介质,提高推荐系统的推荐精准度。
为实现上述目的,本申请提供一种基于知识蒸馏的物品推荐方法,包括:
基于用户的物品交互记录构建训练序列和测试序列;
基于所述训练序列构建用户候选物品预测模型;
将所述测试序列输入所述用户候选物品预测模型进行预测;
其中,基于所述训练序列构建用户候选物品预测模型,包括;
使用随机数种子对训练序列做遮蔽处理生成若干个遮蔽序列,每个遮蔽序列中的一部分序列值随机使用遮蔽标志替换;
构建若干序列推荐平行网络,将所述训练序列和所述遮蔽序列作为输入分别输入到每个所述序列推荐平行网络进行预训练;
基于若干所述序列推荐平行网络的预训练进行知识蒸馏,确定用户候选物品预测模型的损失函数;
基于所述损失函数对所述序列推荐平行网络进行迭代训练获得所述用户候选物品预测模型。
进一步地,所述构建若干序列推荐平行网络,将所述训练序列和所述遮蔽序列作为输入分别输入到每个所述序列推荐平行网络进行预训练的步骤,还包括:
以不同参数初始化若干个序列编码器构建序列推荐平行网络;
基于输入的训练序列,各序列推荐平行网络生成训练序列的隐藏表示进而输出物品类概率;
基于输入的遮蔽序列,各序列推荐平行网络生成遮蔽序列的隐藏表示进而输出候选物品概率。
进一步地,所述基于若干序列推荐平行网络的训练进行知识蒸馏,确定用户候选物品预测模型的损失函数的步骤,还包括:
基于各序列推荐平行网络输出的候选物品概率,确定第一损失函数;
基于各序列推荐平行网络输出的物品类概率,确定第二损失函数;
基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络内知识蒸馏,确定第三损失函数;
基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络间知识蒸馏,确定第四损失函数;
基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定第五损失函数;
基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和第五损失函数,确定用户候选物品预测模型的损失函数。
进一步地,所述确定第一损失函数的具体步骤包括:
基于候选物品概率,确定各序列推荐平行网络中各遮蔽序列候选物品概率的交叉熵损失函数;
基于各序列推荐平行网络中各遮蔽序列候选物品的概率的交叉熵损失函数,确定第一损失函数。
进一步地,所述确定第二损失函数的具体步骤包括:
基于物品类概率,确定各序列推荐平行网络的物品类概率的交叉熵损失函数;
基于各序列推荐平行网络的交叉熵损失函数,确定第二损失函数。
进一步地,所述确定第三损失函数的具体步骤包括:
序列推荐平行网络以一个训练序列的隐藏表示为锚点,以该训练序列生成的遮蔽序列的隐藏表示为正样本,以训练批次内的其他训练序列的隐藏表示为负样本,进行网络内知识蒸馏,确定该序列推荐平行网络的网络内知识蒸馏的损失函数;
基于对各序列推荐平行网络的网络内知识蒸馏的损失函数,确定第三损失函数。
进一步地,所述确定第四损失函数的具体步骤包括:
序列推荐平行网络以一个序列推荐平行网络中训练序列的隐藏表示为锚点,以该序列推荐平行网络同一训练批次内的其他训练序列的隐藏表示为负样本,以另外一个序列推荐平行网络的遮蔽序列的隐藏表示为正样本进行网络间知识蒸馏,确定这个两个序列推荐平行网络的网络间知识蒸馏的损失函数;
基于序列推荐平行网络两两之间的网络间知识蒸馏的损失函数,确定第四损失函数。
进一步地,所述确定第五损失函数的具体步骤包括:
基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定各序列推荐平行网络两两之间的KL散度;
基于各序列推荐平行网络两两之间的KL散度,确定第五损失函数。
为实现上述目的,本申请提供的电子设备,包括处理器;
存储器,包括一个或多个计算机程序模块;
其中,所述一个或多个计算机程序模块被存储在所述存储器中并被配置为由所述处理器执行,所述一个或多个计算机程序模块包括用于实现如上所述的基于知识蒸馏的物品推荐方法。
为实现上述目的,本申请提供的计算机可读存储介质,其上存储有计算机指令,当计算机指令运行时执行如上所述的基于知识蒸馏的物品推荐方法的步骤。
本申请的一种基于知识蒸馏的物品推荐方法,使用多个序列推荐平行网络进行预训练并进行知识蒸馏,以此提高了用户候选物品预测和推荐的成功率、准确率和性能。
本申请的其它特征和优点将在随后的说明书中阐述,并且,部分地从说明书中变得显而易见,或者通过实施本申请而了解。
附图说明
附图用来提供对本申请的进一步理解,并且构成说明书的一部分,并与本申请的实施例一起,用于解释本申请,并不构成对本申请的限制。在附图中:
图1为本申请的基于知识蒸馏的物品推荐方法的流程示意图;
图2为基于所述训练序列构建用户候选物品预测模型的流程示意图;
图3为构建若干序列推荐平行网络进行预训练的流程示意图;
图4为确定用户候选物品预测模型的损失函数的流程示意图;
图5为用户候选物品预测模型的参数选择的示意图;
图6为用户候选物品预测模型的性能指标的示意图;
图7为本申请的一种电子设备的示意框图;
图8为本申请的一种存储介质的示意图。
具体实施方式
下面将参照附图更详细地描述本申请的实施例。虽然附图中显示了本申请的某些实施例,然而应当理解的是,本申请可以通过各种形式来实现,而且不应该被解释为限于这里阐述的实施例,相反提供这些实施例是为了更加透彻和完整地理解本申请。应当理解的是,本申请的附图及实施例仅用于示例性作用,并非用于限制本申请的保护范围。
应当理解,本申请的方法实施方式中记载的各个步骤可以按照不同的顺序执行,和/或并行执行。此外,方法实施方式可以包括附加的步骤和/或省略执行示出的步骤。本申请的范围在此方面不受限制。
本文使用的术语“包括”及其变形是开放性包括,即“包括但不限于”。术语“基于”是“至少部分地基于”。术语“一个实施例”表示“至少一个实施例”;术语“另一实施例”表示“至少一个另外的实施例”;术语“一些实施例”表示“至少一些实施例”。其他术语的相关定义将在下文描述中给出。
需要注意,本申请中提及的“一个”、“多个”的修饰是示意性而非限制性的,本领域技术人员应当理解,除非在上下文另有明确指出,否则应该理解为“一个或多个”。“多个”应理解为两个或以上。
下面,将参考附图详细地说明本申请的实施例。
实施例1
本申请的一个实施例,提供了一种基于知识蒸馏的物品推荐方法,用于使用多个序列推荐平行网络训练,提高用户候选物品的预测和推荐的准确率。
图1为本申请的基于知识蒸馏的物品推荐方法的流程示意图,下面将参考图1对本申请的基于知识蒸馏的物品推荐方法进行详细描述。
步骤S101:基于用户的物品交互记录构建训练序列和测试序列;
在本实施例中,定义
Figure 479856DEST_PATH_IMAGE001
为用户集合,
Figure 626804DEST_PATH_IMAGE002
为物品集合,
Figure 389223DEST_PATH_IMAGE003
为属性集合。根据每一个用 户
Figure 305227DEST_PATH_IMAGE004
的物品交互记录或同步生成一个按照时间顺序排序的物品交互序列
Figure 760479DEST_PATH_IMAGE005
,其中
Figure 660302DEST_PATH_IMAGE006
代表用户
Figure 226412DEST_PATH_IMAGE007
在第
Figure 996922DEST_PATH_IMAGE008
个时间戳交互的物品,T代表当前 的时间戳以及序列长度。每一个物品
Figure 623076DEST_PATH_IMAGE009
都有其自己的属性集合
Figure 10195DEST_PATH_IMAGE010
, 该集合是所有属性集合的子集。
将每一用户的物品交互序列分割成等长的训练序列,不足的部分补0,具体的如:一个用户的物品交互序列有38个,假设训练序列的长度为10,则该用户的物品交互序列对应的共有4个长度为10的训练序列,其中有一个训练序列中为补两个0的训练序列。
步骤S102:基于所述训练序列构建用户候选物品预测模型:
用户候选物品预测模型用于预测用户
Figure 973471DEST_PATH_IMAGE007
下一个可能交互的物品,即为用户
Figure 129646DEST_PATH_IMAGE007
生成在第T+1个时间戳可能交互的物品的概率分布:
Figure 926701DEST_PATH_IMAGE011
步骤S103:将所述测试序列输入所述用户候选物品预测模型进行预测:
参阅图2,步骤S102的具体步骤包括:
步骤S201:使用随机数种子对训练序列做遮蔽处理生成若干个遮蔽序列,每个遮蔽序列中一定比例的物品随机使用遮蔽标志替换。
具体的,每一个训练序列
Figure 535537DEST_PATH_IMAGE012
,使用不同的随机数种子生成
Figure 443450DEST_PATH_IMAGE013
个不同的遮蔽序列。 在每一个遮蔽序列
Figure 454131DEST_PATH_IMAGE014
中,比例为
Figure 422087DEST_PATH_IMAGE015
的物品
Figure 518219DEST_PATH_IMAGE016
会被随机用遮蔽标志
Figure 698665DEST_PATH_IMAGE017
替换。这里
Figure 829432DEST_PATH_IMAGE018
是遮蔽物品的数量,
Figure 968289DEST_PATH_IMAGE019
代表遮蔽物品的下标。随机序列遮蔽 的过程定义如下:
Figure 286138DEST_PATH_IMAGE020
其中:
Figure 535854DEST_PATH_IMAGE021
步骤S202:构建若干序列推荐平行网络,将训练序列和遮蔽序列作为输入分别输入到每个序列推荐平行网络进行预训练;
考虑到现有技术使用单一的序列推荐平行网络其只能收敛到一个局部最优点,因此在本实施方式中,构建多个即下文的N个序列推荐平行网络来进行训练,尽管每个序列推荐平行网络的架构相同,但在使用不同参数初始化时,各序列推荐平行网络最后收敛的局部最优点也不尽相同,而这种特性会让不同的序列推荐平行网络做出不同的预测,即使,每个序列推荐平行网络的预测只能达到一定的准确率,但将各种预测结合起来就能提高整个候选物品预测网络的准确率。
参阅图3,步骤S202的具体步骤包括:
S301:以不同参数初始化若干个序列编码器构建序列推荐平行网络;
具体的,在本实施方式中,每一序列推荐平行网络都使用双向Transformer的序列编码器构建,不同序列推荐平行网络的双向Transformer的序列编码器采用不同的参数进行初始化。
其中双向Transformer包括嵌入层和Tranformer模块,所述嵌入层使用位置矩阵与输入的训练序列或遮蔽序列相结合生成嵌入表示,Tranformer模块用于基于嵌入表示生成训练序列或遮蔽序列相应的隐藏表示。
示例性的,嵌入层对一个输入序列
Figure 521127DEST_PATH_IMAGE022
转换为一个物品矩阵
Figure 830886DEST_PATH_IMAGE023
结合一个位置矩阵
Figure 262130DEST_PATH_IMAGE024
一起作为物品的嵌入表示E(su),其公式如下:
Figure 315536DEST_PATH_IMAGE025
其中,T代表模型的最大序列长度,d代表隐藏维度。
示例性的,Tranformer模块的计算公式如下:
Figure 155316DEST_PATH_IMAGE026
其中,
Figure 370397DEST_PATH_IMAGE027
为定义的序列编码器,该序列编码器将序列
Figure 662838DEST_PATH_IMAGE028
作为输入并输出该序列 的隐藏表示
Figure 519936DEST_PATH_IMAGE029
,可以理解的是,循环神经网络、卷积神经网络、单向Transformer以及双 向Transforner均能基于用户的训练序列输出相应的隐藏表示,本申请中使用双向 Transformer模块
Figure 214222DEST_PATH_IMAGE030
将嵌入表示生成隐藏表示。
S302:基于输入的训练序列,各序列推荐平行网络生成训练序列的隐藏表示进而输出物品类概率;基于输入的遮蔽序列,各序列推荐平行网络生成遮蔽序列的隐藏表示进而输出候选物品概率。
具体的,以一个训练序列
Figure 600204DEST_PATH_IMAGE031
为例,其输入到第n个序列推荐平行网络中得到的隐藏 表示为
Figure 645521DEST_PATH_IMAGE032
,为对物品进行分类,将物品属性集合 完全一样的物品视作同一种类别,示例性的,对一组物品
Figure 40730DEST_PATH_IMAGE033
,当且仅当
Figure 589523DEST_PATH_IMAGE034
时,该组物品为同一类物品。将物品集合中的物品分为
Figure 411985DEST_PATH_IMAGE035
类,然后 使用一个线性层作为属性分类器来将位置
Figure 679019DEST_PATH_IMAGE036
的隐藏表示
Figure 612339DEST_PATH_IMAGE037
转化为在属性类上的概率分布, 示例性的,概率分布的计算公式如下:
Figure 281218DEST_PATH_IMAGE038
其中,其中
Figure 399216DEST_PATH_IMAGE039
是权重矩阵,
Figure 887966DEST_PATH_IMAGE040
是偏置矩阵,
Figure 890557DEST_PATH_IMAGE041
为第n个序 列推荐平行网络预测的第t个位置物品类别为c的概率。
具体的,以一个遮蔽序列
Figure 148363DEST_PATH_IMAGE042
为例,其输入到第n个序列推荐平行网络中得到的隐 藏表示为
Figure 312628DEST_PATH_IMAGE043
,然后使用一 个线性层作为分类器来将隐藏表示转化为在候选物品上的概率分布。给定在位置
Figure 554254DEST_PATH_IMAGE044
的隐藏 表示输出
Figure 94956DEST_PATH_IMAGE045
,计算过程如下:
Figure 207269DEST_PATH_IMAGE046
其中,其中
Figure 808015DEST_PATH_IMAGE047
是权重矩阵,
Figure 271357DEST_PATH_IMAGE048
是偏置矩阵,
Figure 350171DEST_PATH_IMAGE049
为第n个序列推荐平行网络预测的第m个遮蔽序列中第t个位置物品为c的概率。
步骤203:基于若干序列推荐平行网络的预训练进行知识蒸馏,确定用户候选物品预测模型的损失函数;
使用多个序列推荐平行网络确实能够提高性能,但这种方式带来的提升还不够,在本实施方式中使用知识蒸馏进行知识迁移,以此来提高各序列推荐平行网络的性能。
参阅图4,步骤S203的具体步骤包括:
步骤S401:基于各序列推荐平行网络输出的候选物品概率,确定第一损失函数;
优选的,基于候选物品概率,确定各序列推荐平行网络中各遮蔽序列候选物品概率的交叉熵损失函数,示例性的,第n网络第m个遮蔽序列的交叉熵损失函数如下:
Figure 848149DEST_PATH_IMAGE050
其中:
Figure 354216DEST_PATH_IMAGE051
其中,
Figure 570434DEST_PATH_IMAGE052
是一个指示器函数,当且仅当预测的物品
Figure 311994DEST_PATH_IMAGE053
是真实物品
Figure 664478DEST_PATH_IMAGE054
时,其值为1,否则为0。
对各序列推荐平行网络的
Figure 341447DEST_PATH_IMAGE055
就和得到第一损失函数:
Figure 513802DEST_PATH_IMAGE056
步骤S402:基于各序列推荐平行网络输出的物品类概率,确定第二损失函数;
优选的,基于物品类概率,确定各序列推荐平行网络的物品类概率的交叉熵损失函数,示例性的,第n个序列推荐平行网络的交叉熵损失函数:
Figure 199999DEST_PATH_IMAGE057
其中:
Figure 406989DEST_PATH_IMAGE058
其中,
Figure 520439DEST_PATH_IMAGE059
是一个指示器函数,当且仅当预测的物品类
Figure 180090DEST_PATH_IMAGE060
是真实的物品类
Figure 404398DEST_PATH_IMAGE061
时,其值为1,否则为0。
对各序列推荐平行网络的交叉熵损失函数求和得到第二损失函数:
Figure 731474DEST_PATH_IMAGE062
步骤S403:基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络内知识蒸馏,确定第二损失函数;
优选的,序列推荐平行网络以一个训练序列的隐藏表示
Figure 15825DEST_PATH_IMAGE063
为锚点,以该训练 序列生成的遮蔽序列的隐藏表示
Figure 897193DEST_PATH_IMAGE064
为正样本,以训练批次内的其他训练序列
Figure 925192DEST_PATH_IMAGE065
的隐藏表示为负样本
Figure 106775DEST_PATH_IMAGE066
,进行网络内知识蒸馏,确定该序列推荐平行网络 的网络内知识蒸馏的损失函数:
Figure 421081DEST_PATH_IMAGE067
其中
Figure 320904DEST_PATH_IMAGE068
是余弦相似度函数,
Figure 887015DEST_PATH_IMAGE069
是温度超参数,
Figure 923104DEST_PATH_IMAGE070
是从
Figure 283678DEST_PATH_IMAGE071
个不同的遮蔽 序列中随机选取的正样本的下标,
Figure 405218DEST_PATH_IMAGE066
代表来自同一批次的负样本。
需要说明的是,只将锚点
Figure 775019DEST_PATH_IMAGE072
与一个随机选取的正样本
Figure 931194DEST_PATH_IMAGE073
对比而不是与 所有正样本对比。
对各序列推荐平行网络的网络内知识蒸馏的损失函数求和得到第三损失函数:
Figure 462670DEST_PATH_IMAGE074
步骤S404:基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络间知识蒸馏,确定第四损失函数;
优选的,以序列推荐平行网络x中一个训练序列的隐藏表示
Figure 71506DEST_PATH_IMAGE075
为锚点,以该序 列推荐平行网络同一训练批次内的其他训练序列
Figure 979419DEST_PATH_IMAGE076
的隐藏表示为负样本
Figure 724521DEST_PATH_IMAGE077
,以另外一个序列推荐平行网络y的遮蔽序列的隐藏表示为正样本
Figure 692477DEST_PATH_IMAGE078
进行网 络间知识蒸馏,确定这个两个序列推荐平行网络的网络间知识蒸馏的损失函数:
Figure 788609DEST_PATH_IMAGE079
对各序列推荐平行网络两两之间的网络间知识蒸馏的损失函数求和得到第三损失函数:
Figure 500213DEST_PATH_IMAGE080
步骤S405:基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定第五损失函数;
在本实施方式中,各序列推荐平行网络可以同时作为从其他网络获得知识的学生网络以及向其他网络输送知识的教师网络,以此有效地分享和迁移序列推荐平行网络之间的知识。
优选的,基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定各序列推荐平行网络两两之间的KL散度。
具体的,先计算分别作为教师网络和学生网络的两序列推荐平行网络在温度
Figure 988570DEST_PATH_IMAGE081
下的softmax概率分布:
Figure 127427DEST_PATH_IMAGE082
其中,
Figure 445276DEST_PATH_IMAGE083
为作为教师网络的序列推荐网络在温度
Figure 694992DEST_PATH_IMAGE081
下的softmax概率分布。
进而确定两序列推荐平行网络两两之间的KL散度:
Figure 680265DEST_PATH_IMAGE084
将所有序列推荐平行网络两两组合计算并求和得到第五损失函数:
Figure 990024DEST_PATH_IMAGE085
步骤S406:将第一损失函数、第二损失函数、第二损失函数、第四损失函数和第五损失函数加权求和得到用户候选物品预测模型的损失函数:
Figure 795169DEST_PATH_IMAGE086
其中,其中
Figure 582996DEST_PATH_IMAGE087
是权重超参数。
步骤S204:基于所述用户候选物品预测模型的损失函数对序列推荐平行网络进行迭代训练获得所述用户候选物品预测模型。
下面将结合试验数据对本申请的基于知识蒸馏的物品推荐方法做进一步说明:
在本实施方式中,使用广泛用于推荐的Beauty、Toys以及ML-1M三个数据集进行训练测试。
参阅图5、图5为用户候选物品预测模型的参数选择的示意图,如图5所示,其内容使用本模型对三个数据集训练迭代后的最佳参数。
参阅图6,图6为用户候选物品预测模型的性能指标的示意图,如图6所示,以HR@K和NDCG@K作为主要的性能评价指标,可以看出本申请的用户候选物品预测模型(表中的EMKD)在Beauty、Toys以及ML-1M三个数据集上预测推荐的性能比其他如:GRU4Rec、Caser、SASRec、BertRec、FDSA、S3-Rec、MMInfoRec、CL4SRec和DuoRec等模型的性能指标中的最大值都有9.59%-33.48%的提升。
实施例2
本实施例中,还提供一种电子设备,图7为本申请提供的一种电子设备的示意框图。如图7所示,电子设备130包括处理器131和存储器132。存储器132用于存储非暂时性计算机可读指令(例如一个或多个计算机程序模块)。处理器131用于运行非暂时性计算机可读指令,非暂时性计算机可读指令被处理器131运行时可以执行上文所述的基于知识蒸馏的物品推荐方法的一个或多个步骤。存储器132和处理器131可以通过总线系统和/或其它形式的连接机构(未示出)互连。
例如,处理器131可以是中央处理单元(CPU)、数字信号处理器(DSP)或者具有数据处理能力和/或程序执行能力的其它形式的处理单元,例如现场可编程门阵列(FPGA)等;例如,中央处理单元(CPU)可以为X86或ARM架构等。
例如,存储器132可以包括一个或多个计算机程序产品的任意组合,计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。非易失性存储器例如可以包括只读存储器(ROM)、硬盘、可擦除可编程只读存储器(EPROM)、便携式紧致盘只读存储器(CD-ROM)、USB存储器、闪存等。在计算机可读存储介质上可以存储一个或多个计算机程序模块,处理器131可以运行一个或多个计算机程序模块,以实现电子设备130的各种功能。在计算机可读存储介质中还可以存储各种应用程序和各种数据以及应用程序使用和/或产生的各种数据等。
需要说明的是,本申请的实施例中,电子设备130的具体功能和技术效果可以参考上文中关于基于知识蒸馏的物品推荐方法的描述,此处不再赘述。
实施例3
本实施例中,还提供一种计算机可读存储介质,图8为本申请的一种存储介质的示意图。如图8所示,存储介质150用于存储非暂时性计算机可读指令151。例如,当非暂时性计算机可读指令151由计算机执行时可以执行根据上文所述的基于知识蒸馏的物品推荐方法中的一个或多个步骤。
例如,该存储介质150可以应用于上述电子设备130中。例如,存储介质150可以为图7所示的电子设备130中的存储器132。例如,关于存储介质150的相关说明可以参考图7所示的电子设备130中的存储器132的相应描述,此处不再赘述。
需要说明的是,本申请上述的存储介质(计算机可读介质)可以是计算机可读信号介质或者非暂时性计算机可读存储介质或者是上述两者的任意组合。非暂时性计算机可读存储介质例如可以是,但不限于,电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。非暂时性计算机可读存储介质的更具体的例子可以包括但不限于:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
在本申请中,非暂时性计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。而在本申请中,计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读信号介质还可以是非暂时性计算机可读存储介质以外的任何计算机可读介质,该计算机可读信号介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于:电线、光缆、RF(射频)等,或者上述的任意合适的组合。
上述计算机可读介质可以是上述电子设备中所包含的;也可以是单独存在,而未装配入该电子设备中。
可以以一种或多种程序设计语言或其组合来编写用于执行本申请的操作的计算机程序代码,上述程序设计语言包括但不限于面向对象的程序设计语言,诸如Java、Smalltalk、C++,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。
附图中的流程图和框图,图示了按照本申请各种实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段、或代码的一部分,该模块、程序段、或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个接连地表示的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这根据所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
描述于本申请实施例中所涉及到的单元可以通过软件的方式实现,也可以通过硬件的方式来实现。其中,单元的名称在某种情况下并不构成对该单元本身的限定。
本文中以上描述的功能可以至少部分地由一个或多个硬件逻辑部件来执行。例如,非限制性地,可以使用的示范类型的硬件逻辑部件包括:现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、片上系统(SOC)、复杂可编程逻辑设备(CPLD)等。
以上描述仅为本申请的部分实施例以及对所运用技术原理的说明。本领域技术人员应当理解,本申请中所涉及的公开范围,并不限于上述技术特征的特定组合而成的技术方案,同时也应涵盖在不脱离上述公开构思的情况下,由上述技术特征或其等同特征进行任意组合而形成的其它技术方案。例如上述特征与本申请中公开的(但不限于)具有类似功能的技术特征进行互相替换而形成的技术方案。
此外,虽然采用特定次序描绘了各操作,但是这不应当理解为要求这些操作以所示出的特定次序或以顺序次序来执行。在一定环境下,多任务和并行处理可能是有利的。同样地,虽然在上面论述中包含了若干具体实现细节,但是这些不应当被解释为对本申请的范围的限制。在单独的实施例的上下文中描述的某些特征还可以组合地实现在单个实施例中。相反地,在单个实施例的上下文中描述的各种特征也可以单独地或以任何合适的子组合的方式实现在多个实施例中。
尽管已经采用特定于结构特征和/或方法逻辑动作的语言描述了本主题,但是应当理解所附权利要求书中所限定的主题未必局限于上面描述的特定特征或动作。相反,上面所描述的特定特征和动作仅仅是实现权利要求书的示例形式。

Claims (10)

1.一种基于知识蒸馏的物品推荐方法,包括:
基于用户的物品交互记录构建训练序列和测试序列;
基于所述训练序列构建用户候选物品预测模型;
将所述测试序列输入所述用户候选物品预测模型进行预测;
其中,基于所述训练序列构建用户候选物品预测模型,包括:
使用随机数种子对训练序列做遮蔽处理生成若干个遮蔽序列,每个遮蔽序列中的一部分序列值随机使用遮蔽标志替换;
构建若干序列推荐平行网络,将所述训练序列和所述遮蔽序列作为输入分别输入到每个所述序列推荐平行网络进行预训练;
基于若干所述序列推荐平行网络的预训练进行知识蒸馏,确定用户候选物品预测模型的损失函数;
基于所述损失函数对所述序列推荐平行网络进行迭代训练获得所述用户候选物品预测模型。
2.根据权利要求1所述的基于知识蒸馏的物品推荐方法,其特征在于,所述构建若干序列推荐平行网络,将所述训练序列和所述遮蔽序列作为输入分别输入到每个所述序列推荐平行网络进行预训练的步骤,还包括:
以不同参数初始化若干个序列编码器构建序列推荐平行网络;
基于输入的训练序列,各序列推荐平行网络生成训练序列的隐藏表示进而输出物品类概率;
基于输入的遮蔽序列,各序列推荐平行网络生成遮蔽序列的隐藏表示进而输出候选物品概率。
3.根据权利要求1所述的基于知识蒸馏的物品推荐方法,其特征在于,所述基于若干所述序列推荐平行网络的训练进行知识蒸馏,确定用户候选物品预测模型的损失函数的步骤,还包括:
基于各序列推荐平行网络输出的候选物品概率,确定第一损失函数;
基于各序列推荐平行网络输出的物品类概率,确定第二损失函数;
基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络内知识蒸馏,确定第三损失函数;
基于各序列推荐平行网络内训练序列的隐藏表示和各遮蔽序列的隐藏表示进行网络间知识蒸馏,确定第四损失函数;
基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定第五损失函数;
基于第一损失函数、第二损失函数、第三损失函数、第四损失函数和第五损失函数,确定用户候选物品预测模型的损失函数。
4.根据权利要求3所述的基于知识蒸馏的物品推荐方法,其特征在于,所述确定第一损失函数的具体步骤包括:
基于候选物品概率,确定各序列推荐平行网络中各遮蔽序列候选物品概率的交叉熵损失函数;
基于各序列推荐平行网络中各遮蔽序列候选物品的概率的交叉熵损失函数,确定第一损失函数。
5.根据权利要求3所述的基于知识蒸馏的物品推荐方法,其特征在于,所述确定第二损失函数的具体步骤包括:
基于物品类概率,确定各序列推荐平行网络的物品类概率的交叉熵损失函数;
基于各序列推荐平行网络的物品类概率的交叉熵损失函数,确定第二损失函数。
6.根据权利要求3所述的基于知识蒸馏的物品推荐方法,其特征在于,所述确定第三损失函数的具体步骤包括:
序列推荐平行网络以一个训练序列的隐藏表示为锚点,以该训练序列生成的遮蔽序列的隐藏表示为正样本,以训练批次内的其他训练序列的隐藏表示为负样本,进行网络内知识蒸馏,确定该序列推荐平行网络的网络内知识蒸馏的损失函数;
基于对各序列推荐平行网络的网络内知识蒸馏的损失函数,确定第三损失函数。
7.根据权利要求3所述的基于知识蒸馏的物品推荐方法,其特征在于,所述确定第四损失函数的具体步骤包括:
序列推荐平行网络以一个序列推荐平行网络中训练序列的隐藏表示为锚点,以该序列推荐平行网络同一训练批次内的其他训练序列的隐藏表示为负样本,以另外一个序列推荐平行网络的遮蔽序列的隐藏表示为正样本进行网络间知识蒸馏,确定这个两个序列推荐平行网络的网络间知识蒸馏的损失函数;
基于序列推荐平行网络两两之间的网络间知识蒸馏的损失函数,确定第四损失函数。
8.根据权利要求3所述的基于知识蒸馏的物品推荐方法,其特征在于,所述确定第五损失函数的具体步骤包括:
基于各序列推荐平行网络输出的候选物品概率进行网络间知识蒸馏,确定各序列推荐平行网络两两之间的KL散度;
基于各序列推荐平行网络两两之间的KL散度,确定第五损失函数。
9.一种电子设备,其特征在于,包括:
处理器;
存储器,包括一个或多个计算机程序模块;
其中,所述一个或多个计算机程序模块被存储在所述存储器中并被配置为由所述处理器执行,所述一个或多个计算机程序模块包括用于实现权利要求1-8中任一项所述的基于知识蒸馏的物品推荐方法。
10.一种计算机可读存储介质,其特征在于,其上存储有计算机指令,当计算机指令运行时执行权利要求1-8中任一项所述的基于知识蒸馏的物品推荐方法的步骤。
CN202211161347.5A 2022-09-23 2022-09-23 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质 Active CN115270002B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211161347.5A CN115270002B (zh) 2022-09-23 2022-09-23 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211161347.5A CN115270002B (zh) 2022-09-23 2022-09-23 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质

Publications (2)

Publication Number Publication Date
CN115270002A CN115270002A (zh) 2022-11-01
CN115270002B true CN115270002B (zh) 2022-12-09

Family

ID=83756911

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211161347.5A Active CN115270002B (zh) 2022-09-23 2022-09-23 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN115270002B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116151353B (zh) * 2023-04-14 2023-07-18 中国科学技术大学 一种序列推荐模型的训练方法和对象推荐方法

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113590958A (zh) * 2021-08-02 2021-11-02 中国科学院深圳先进技术研究院 基于样本回放的序列推荐模型的持续学习方法
CN114817742A (zh) * 2022-05-18 2022-07-29 平安科技(深圳)有限公司 基于知识蒸馏的推荐模型配置方法、装置、设备、介质

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11727270B2 (en) * 2020-02-24 2023-08-15 Microsoft Technology Licensing, Llc Cross data set knowledge distillation for training machine learning models

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113590958A (zh) * 2021-08-02 2021-11-02 中国科学院深圳先进技术研究院 基于样本回放的序列推荐模型的持续学习方法
CN114817742A (zh) * 2022-05-18 2022-07-29 平安科技(深圳)有限公司 基于知识蒸馏的推荐模型配置方法、装置、设备、介质

Also Published As

Publication number Publication date
CN115270002A (zh) 2022-11-01

Similar Documents

Publication Publication Date Title
CN110366734B (zh) 优化神经网络架构
Lin et al. Toward compact convnets via structure-sparsity regularized filter pruning
CN113535984B (zh) 一种基于注意力机制的知识图谱关系预测方法及装置
Halvorsen et al. Opportunities for improved distribution modelling practice via a strict maximum likelihood interpretation of MaxEnt
US20150254554A1 (en) Information processing device and learning method
CN111860982A (zh) 一种基于vmd-fcm-gru的风电场短期风电功率预测方法
CN109816032B (zh) 基于生成式对抗网络的无偏映射零样本分类方法和装置
CN108319585B (zh) 数据处理方法及装置、电子设备、计算机可读介质
Wang et al. Learning efficient binarized object detectors with information compression
CN115270002B (zh) 一种基于知识蒸馏的物品推荐方法、电子设备及存储介质
Mena et al. Sinkhorn networks: Using optimal transport techniques to learn permutations
Mukunthu et al. Practical automated machine learning on Azure: using Azure machine learning to quickly build AI solutions
Li et al. The impact of feature selection techniques on effort‐aware defect prediction: An empirical study
CN115238855A (zh) 基于图神经网络的时序知识图谱的补全方法及相关设备
CN105138527B (zh) 一种数据分类回归方法及装置
CN114881343A (zh) 基于特征选择的电力系统短期负荷预测方法及装置
JP7427011B2 (ja) センサ入力信号からのコグニティブ・クエリへの応答
CN115759291B (zh) 一种基于集成学习的空间非线性回归方法及系统
Wang et al. DMFP: A dynamic multi-faceted fine-grained preference model for recommendation
CN114896138B (zh) 一种基于复杂网络和图神经网络的软件缺陷预测方法
CN115169433A (zh) 基于元学习的知识图谱分类方法及相关设备
CN115048530A (zh) 融合邻居重要度和特征学习的图卷积推荐系统
CN112818658B (zh) 文本对分类模型的训练方法、分类方法、设备及存储介质
KR20220097856A (ko) 완전 연결 레이어를 사용하는 신경망 실행 블록
Glushkovsky Ai giving back to statistics? discovery of the coordinate system of univariate distributions by beta variational autoencoder

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant