CN115292587B - 基于知识蒸馏和因果推理的推荐方法及系统 - Google Patents

基于知识蒸馏和因果推理的推荐方法及系统 Download PDF

Info

Publication number
CN115292587B
CN115292587B CN202210837534.4A CN202210837534A CN115292587B CN 115292587 B CN115292587 B CN 115292587B CN 202210837534 A CN202210837534 A CN 202210837534A CN 115292587 B CN115292587 B CN 115292587B
Authority
CN
China
Prior art keywords
model
user
training
loss
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210837534.4A
Other languages
English (en)
Other versions
CN115292587A (zh
Inventor
况琨
张圣宇
赵洲
吴飞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202210837534.4A priority Critical patent/CN115292587B/zh
Publication of CN115292587A publication Critical patent/CN115292587A/zh
Application granted granted Critical
Publication of CN115292587B publication Critical patent/CN115292587B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/95Retrieval from the web
    • G06F16/953Querying, e.g. by the use of web search engines
    • G06F16/9536Search customisation based on social or collaborative filtering
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/95Retrieval from the web
    • G06F16/953Querying, e.g. by the use of web search engines
    • G06F16/9535Search customisation based on user profiles and personalisation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q30/00Commerce
    • G06Q30/06Buying, selling or leasing transactions
    • G06Q30/0601Electronic shopping [e-shopping]
    • G06Q30/0631Item recommendations
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Databases & Information Systems (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Business, Economics & Management (AREA)
  • Data Mining & Analysis (AREA)
  • Finance (AREA)
  • Accounting & Taxation (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Development Economics (AREA)
  • Economics (AREA)
  • Marketing (AREA)
  • Strategic Management (AREA)
  • General Business, Economics & Management (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于知识蒸馏和因果推理的推荐方法及系统。本发明中,首先把训练数据集中所有用户按照敏感属性的高低等分成若干个用户组。再利用所有用户的行为数据训练一个基础推荐模型,继而按照用户分组,利用每一组用户数据,对基础推荐模型进行微调,为每一组用户训练一个教师模型;最后利用所有用户的数据,借助因果推断中的前门调整方法,通过每个用户分组的教师模型获取多个中间表征作为中介,继而利用Batch内采样机制和注意力机制,进行多模型多样本信息聚合,并将聚合后的信息蒸馏到学生模型。本发明将因果知识蒸馏技术应用于项目推荐中,相比于普通推荐算法,引入因果建模可以有效提升用户的推荐服务公平性,缓解马太效应。

Description

基于知识蒸馏和因果推理的推荐方法及系统
技术领域
本发明涉及推荐系统领域,尤其涉及一种旨在挖掘和利用在线电子商务中存在的海量数据价值从而缓解异质性的推荐方法及系统。
背景技术
随着互联网的发展,信息以指数爆炸的方式飞速增长。推荐系统通过为用户提供个性化服务来寻找信息来减轻网络上的信息过载。通过估计用户和项目之间的相关性,从历史用户-项目交互中学习推荐模型。目前大量模型通常表现出用户服务性能异质性,不同用户组得到的服务质量存在显着差异。
从数据和模型的角度来看,存在两个性能异质性问题的来源。一个是自然来源:训练数据分布在用户上的不平衡。例如,由于丰富的交互记录和更全面的兴趣建模,活跃用户可能会收到相对准确的推荐。在自然来源的基础上,协同过滤模型进一步放大了数据不平衡的影响,这是性能异质问题的模型根源。显然,自然来源是不可避免的,因此追求不同用户组的性能平均是不明智的。缓解性能异质性问题的关键在于解决模型对性能异质的放大问题,但如何推荐模型的性能异质性目前尚没有较好地解决方案。
发明内容
本发明的目的是克服现有推荐系统模型在异质性上的不足,利用因果图分析推荐模型的性能异质性问题,并制定基于前门调整的因果推荐模型来处理未观察到的混杂因素。
本发明采用的技术方案如下:
第一方面,本发明提供了一种基于知识蒸馏和因果推理的推荐方法,其包括如下步骤:
S1、获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组;
S2、利用所述训练数据集训练一个由编码器和预测期组成的基础推荐模型,再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失;
S3、借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次(Batch)内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失;
S4、将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
在上述方案基础上,各步骤可以采用如下优选的具体方式实现。
作为上述第一方面的优选,所述的步骤S1具体包括以下子步骤:
S101、获取训练样本构成的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;对于训练数据集中用户的历史行为数据,首先基于预先确定的用户出现频率阈值Nu和项目出现频率阈值Ni,过滤掉数据中出现频率小于Nu的用户以及出现频率小于Ni的项目;
S102、以用户活跃度作为所有用户的分组依据,设置分组的组数Ng;然后按照用户在S101过滤后的历史行为数据中出现的频率从高到底排序,将所有用户分成Ng个用户组,每个用户组的数据为这组用户各自所拥有的历史行为数据组成的集合。
作为上述第一方面的优选,所述的步骤S2具体包括以下子步骤:
S201、针对预先选定的由编码器和预测期组成的基础推荐模型,通过所有用户的历史行为数据对其进行训练,得到一个预训练模型M0
S202、在用户分组的基础上,基于每个用户组的数据对预训练模型M0进行微调,从而为每一个用户组训练一个独立的教师模型;然后使用非独立同分布的教师模型来估计期望,其中对于输入的同一个用户特征xi,非独立同分布教师模型Φ={φk}k=1,...,|Φ|对非独立同分布数据样本进行训练会产生中介变量M的不同取值,将不同教师模型编码器的输出进行加权求和得到中间表征的去偏估计期望:
Figure BDA0003749271120000021
式中:mk,i代表第k个教师模型对于第i个样本的中介特征表示,由第k个教师模型的编码器对输入的用户特征xi进行编码后输出得到;
S203、将去偏估计期望
Figure BDA0003749271120000031
作为训练学生模型的因果指导之一,构建用于通过因果后门调整进行特征蒸馏的第一蒸馏损失/>
Figure BDA0003749271120000032
其公式具体如下:
Figure BDA0003749271120000033
式中:
Figure BDA0003749271120000034
是学生模型对于第i个样本的中介特征表示,由学生模型的编码器对输入的用户特征xi进行编码后输出得到;Distance表示平均均方误差MSE计算函数。
作为上述第一方面的优选,所述的步骤S3具体包括以下子步骤:
S301、使用非独立同分布的教师模型在给定用户变量X=xi的情况下对中间表征变量M的取值进行抽样,且抽样过程中对于给定X=xi时M=mk,i的条件概率建模为P(M=mk,i|X=xi),该条件概率采用注意力机制进行求解:
Figure BDA0003749271120000035
其中,αk,i代表通过注意力机制得到的条件概率P(M=mk,i|X=xi),mk,i是第k个教师模型抽样得到的第i个中介变量M的取值,mk,i代表第k个教师模型对于第i个样本的中介特征表示,W1、W2均为可学习的参数矩阵;
Figure BDA0003749271120000038
是学生模型对于第i个样本的中介特征表示;
S302、采用批次内采样策略对用户变量X进行采样,对于总样本数为Nb+1个的批次中一个给定的训练样本,采样过程中将Batch中的其他每个训练样本对应的用户xj分别作为用户变量X的采样值,且X=xj的先验概率P(X=xj)采用均匀分布;
S303、将X的采样值为xj且M的采样值为mk,i时预测结果Y为yi的概率P(Y=yi|X=xj,M=mk,i)参数化为网络
Figure BDA0003749271120000036
并使用一个sigmoid层σ进行二分类:
Figure BDA0003749271120000037
其中mk,i和yi分别是mk,i和yi的特征表示;进一步将用户xj用对应的中介特征表示
Figure BDA0003749271120000041
进行替代,从而将概率建模转换为:
Figure BDA0003749271120000042
其中
Figure BDA0003749271120000043
是以xj为输入时学生模型提取的的中介特征表示,网络/>
Figure BDA0003749271120000044
的结构与所述基础推荐模型中的预测器结构一致;
S304、在通过知识蒸馏进行前门调整过程中,对于批次中的第i个样本,经过前门调整后的预测结果
Figure BDA0003749271120000045
为:
Figure BDA0003749271120000046
式中:Nb为X在一个批次内采样的总数;
将预测结果
Figure BDA0003749271120000047
作为训练学生模型的因果指导之一,构建用于通过因果前门调整进行特征蒸馏的第二蒸馏损失/>
Figure BDA0003749271120000048
其公式具体如下:
Figure BDA0003749271120000049
式中:
Figure BDA00037492711200000410
表示学生模型对于批次中的第i个样本的预测结果;oi表示批次中的第i个样本的预测结果的真实标签,代表用户xi是否点击了物品yi;/>
Figure BDA00037492711200000411
代表/>
Figure BDA00037492711200000412
和/>
Figure BDA00037492711200000413
的蒸馏损失,/>
Figure BDA00037492711200000414
代表/>
Figure BDA00037492711200000415
和oi的一致性损失。
作为上述第一方面的优选,所述的步骤S4具体包括以下子步骤:
S41、针对学生模型的训练构建总损失函数
Figure BDA00037492711200000416
其形式为:
Figure BDA00037492711200000417
式中:
Figure BDA00037492711200000418
为第一蒸馏损失,/>
Figure BDA00037492711200000419
为第二蒸馏损失,/>
Figure BDA00037492711200000420
为学生模型自带的推荐损失,α和β分别为权重值;
S42、以最小化总损失函数为目标,利用所述训练数据集中的所有训练样本对学生模型进行训练,直至模型收敛;利用训练后的学生模型对目标用户进行项目推荐。
作为上述第一方面的优选,所述的基础推荐模型、教师模型和学生模型均采用深度兴趣网络(Deep Interest Network,DIN)。
作为上述第一方面的优选,所述S302中,采用均匀分布的先验概率P(X=xj)为1/Nb
作为上述第一方面的优选,所述的学生模型自带的推荐损失为深度兴趣网络自身计算的推荐损失。
作为上述第一方面的优选,所述的项目为商品、应用程序。
第二方面,本发明提供了一种基于知识蒸馏和因果推理的推荐系统,其包括:
数据集获取模块,该模块获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组;
第一损失模块,该模块利用所述训练数据集训练一个由编码器和预测期组成的基础推荐模型,再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失;
第二损失模块,该模块借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次(Batch)内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失;
训练和推荐模块,该模块将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
相对于现有技术而言,本发明将因果知识蒸馏技术应用于在线商品、应用程序、视频推荐等领域中,可以挖掘和利用在线电子商务中存在的海量数据价值并缓解性能异质性。相比于普通推荐算法,本发明引入因果建模可以有效提升用户的推荐服务公平性,缓解马太效应。
附图说明
图1为基于知识蒸馏和因果推理的推荐方法的步骤流程图。
图2从因果视角对推荐中模型异质问题产生的原因分析机理以及本发明提出模型的因果建模方法原理图。
图3为基于知识蒸馏和因果推理的推荐系统的模块示意图。
具体实施方式
下面结合附图和具体实施方式对本发明做进一步阐述和说明。
本发明中为了进一步揭示推荐模型偏差的原因,深入研究了用户与项目交互的生成过程,该过程被抽象为因果图。在因果图中,X(用户)直接在Y(项目),它反映了用户偏好和项目属性之间的匹配。Z表示除用户-项目匹配之外直接影响用户和项目的一组因素。Z被确定为X和Y之间的混杂因素,当直接估计用户和项目之间的相关性时,这将导致虚假相关性(X←Z→Y)。虚假的相关性可能会导致不准确的推荐,从而有相对较高的机会伤害尾部用户。例如,用户活动(Z)将增加历史记录的大小(Z→X)并鼓励探索(Z→Y)。更多的探索(和更少的利用)可能会导致与与用户内部兴趣(X
Figure BDA0003749271120000061
Y)相关性较低的项目进行交互,并放大虚假相关性(X←Z→Y)。虽然虚假相关的用户-项目对似乎对相应的活跃用户的危害较小,因为他们乐于探索,但它会对协同过滤下的长尾用户造成伤害。因此,缓解性能异质性的关键在于阻止虚假相关性,即对因果效应X→Y进行建模。而利用知识蒸馏和因果推理领域的前后门调整技术可以有效的解决上述难点问题。
因此,如图1所示,在本发明的一个较佳实施例中,提供了一种基于知识蒸馏和因果推理的推荐方法,其包括如下步骤:
S1、获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组。
需要说明的是,本发明中的项目可以是为商品、应用程序(如App、小程序等)、在线内容(如视频、新闻、歌曲等等)。
在本发明的实施例中,上述步骤S1具体包括以下子步骤:
S101、获取训练样本构成的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;对于训练数据集中用户的历史行为数据,首先基于预先确定的用户出现频率阈值Nu和项目出现频率阈值Ni,过滤掉数据中出现频率小于Nu的用户以及出现频率小于Ni的项目。
S102、为了训练,需要将用户进行分组,可以采用一些与推荐系统中观察到的重要混杂因素相关的因素做为分组的基准。本实施例中以用户活跃度(即用户的历史交互行为次数)作为所有用户的分组依据,设置分组的组数Ng(编号0,1,...,Ng-1);然后按照用户在S101过滤后的历史行为数据中出现的频率从高到底排序,将所有用户分成Ng个用户组,每个用户组的数据为这组用户各自所拥有的历史行为数据组成的集合。
S2、利用所述训练数据集训练一个由编码器和预测期组成的基础推荐模型(basemodel),再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失。
在论述本步骤的具体实现过程之前,下面先对其基本理论部分进行叙述,以便于本领域技术人员更好地理解本发明的实现机制。
在序列推荐中,用户序列对目标项目的估计影响可能会受到由观察到的混杂因素(例如,用户的活跃度)和未观察到的混杂因素(例如,用户对项目受欢迎程度的态度)引起的虚假相关性的困扰。我们使用大写字母(例如X)表示变量,小写字母(例如x)表示变量的特定值,粗体字母(例如x)表示相应的向量表示。如图2所示,由于训练样本的数据生成过程充满了选择偏差,输入X和结果Y之间存在着混杂效应。换句话说,由于间接路径X←Z→Y存在,X和Y之间的边际关联并不能识别X对Y的纯因果效应(即从X到Y的直接路径)。Z被称为X和Y的混杂因素或共同原因。为了处理各种甚至未观察到的混杂因素,本发明从社会科学中借鉴了前门调整的想法,并研究了如何使用深度神经网络,特别是知识蒸馏框架来实现它。
前门调整要求。前门调整具有对抗未观察到的混杂因素的优势。与原始因果图相比,前门调整通过中介M估计X→Y的因果效应,即X→M→Y。有效的前门调整应满足以下条件:
(1)M截取所有从X到Y的有向路径;
(2)X到M没有未阻塞后门路径;
(3)从M到Y的所有后门路径都被X阻塞。
对于条件(1):在深度神经网络中,可以直接将X的中间特征表示视为M。使用中介M,训练现有模型的观察相关性P(Y|X)可以表述如下:
Figure BDA0003749271120000081
但是,由于M应该拦截从X到Y的所有有向路径,因此从X中提取的子部分不能被视为M。因此,给定X对M的采样仍然是一个挑战。至于第(2)和第(3)个条件,假设它们已经被满足是社会科学的惯例。然而,在DNN中,数据和标签中的偏差可能会反向传播到给定X估计M的学生模型参数。因此,本发明需要使用后门调整来处理一些重要且观察到的混杂因素,以更好地满足要求条件(2)。通过考虑观察到的混杂因素Zo,X→M的估计可以写成:
Figure BDA0003749271120000082
后门调整通过将P(Zo|X)替换为P(Zo)来切断Zo→X的效果,即
Figure BDA0003749271120000083
在实践中很难实现P(M|X,Zo=z),因为需要对每个Zo=z进行单独的估计。
在满足要求的情况下,前门调整采用两步估计:
首先,正常估计P(M=m|X)。这种估计不受路径X←Z→Y←M的影响,这要归功于碰撞效应Z→Y←M。不受控制的碰撞变量Y将阻止从Z到M的效应。同时,它也不受满足条件(2)的后门路径X→Zo←M的影响。
接着,由于未阻塞的后门路径M←X←Z→Y,对估计P(Y|M)进行do干预,即P(Y|do(M))。补救措施是以X为条件并阻止路径X→M。前面调整可以被形式化为:
Figure BDA0003749271120000084
Figure BDA0003749271120000085
本质上,前门调整将条件概率P(X=x|M=m)设置为先验概率P(X=x),从而阻塞路径Y←Z→X→M。
因此,可以假设不同教师模型的参数编码了不同Z=z的信息。首先,对X采用批量抽样策略。然后,使用非独立同分布的教师模型来估计期望。从技术上讲,给定相同的输入,非独立同分布教师模型
Figure BDA0003749271120000091
对非独立同分布数据样本进行训练会产生异构的M。由于当混杂变量设置为特定值Zo=z时,可以认为每组数据样本被选中,因此相应教师模型的参数学习了Zo=z下的知识。然后,教师模型可以视为估计量P(M|X,Zo=z)。对于第i个样本(包含用户xi,目标物品yi,该用户是否点击该目标物品oi),可以将P(M|do(X))近似为不同教师模型的估计的加权和:
Figure BDA0003749271120000092
式中:mk,i代表第k个教师模型对于第i个样本的中介特征表示,由第k个教师模型的编码器对输入的用户特征xi进行编码后输出得到;|Φ|表示Φ中的教师模型数量。
将P(z)设置为与训练
Figure BDA0003749271120000093
的用户数除以所有用户数成正比。因此,如果不同组的用户数量相同,则对P(z)采取均匀分布。/>
Figure BDA0003749271120000094
是从异质教师模型中因果提取的M的去偏估计,可以将/>
Figure BDA0003749271120000095
作为训练学生模型的因果指导之一,具体如下:
Figure BDA0003749271120000096
式中,
Figure BDA0003749271120000097
是学生模型对于第i个样本的中介特征表示,由学生模型的编码器对输入的用户特征xi进行编码后输出得到;Distance表示平均均方误差MSE计算函数。
基于上述理论描述,在本发明的实施例中,上述步骤S2具体包括以下子步骤:
S201、针对预先选定的由编码器和预测期组成的基础推荐模型,通过所有用户的历史行为数据对其进行训练,得到一个预训练模型M0
S202、在用户分组的基础上,基于每个用户组的数据对预训练模型M0进行微调,从而为每一个用户组训练一个独立的教师模型;然后使用非独立同分布的教师模型来估计期望,其中对于输入的同一个用户特征xi,非独立同分布教师模型Φ={φk}k=1,...,|Φ|对非独立同分布数据样本进行训练会产生中介变量M的不同取值,将不同教师模型编码器的输出进行加权求和得到中间表征的去偏估计期望:
Figure BDA0003749271120000101
式中:mk,i代表第k个教师模型对于第i个样本的中介特征表示,由第k个教师模型的编码器对输入的用户特征xi进行编码后输出得到。
S203、将去偏估计期望
Figure BDA0003749271120000102
作为训练学生模型的因果指导之一,构建用于通过因果后门调整进行特征蒸馏的第一蒸馏损失/>
Figure BDA0003749271120000103
其公式具体如下:
Figure BDA0003749271120000104
式中:
Figure BDA0003749271120000105
是学生模型对于第i个样本的中介特征表示,由学生模型的编码器对输入的用户特征xi进行编码后输出得到;Distance表示平均均方误差MSE计算函数。
在本发明的实施例中,上述基础推荐模型、教师模型和学生模型均采用深度兴趣网络(Deep Interest Network,DIN)。
S3、借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次(Batch)内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失。
在本发明的实施例中,上述步骤S3具体包括以下子步骤:
S301、M的抽样:使用非独立同分布的教师模型在给定用户变量X=xi的情况下对中间表征变量M的取值进行抽样,且抽样过程中对于给定X=xi时M=mk,i的条件概率建模为P(M=mk,i|X=xi),该条件概率采用注意力机制进行求解:
Figure BDA0003749271120000106
其中,αk,i代表通过注意力机制得到的条件概率P(M=mk,i|X=xi),mk,i是第k个教师模型抽样得到的第i个中介变量M的取值,mk,i代表第k个教师模型对于第i个样本的中介特征表示,W1、W2均为可学习的参数矩阵;
Figure BDA0003749271120000107
是学生模型对于第i个样本的中介特征表示,
S302、采用批次内采样策略对用户变量X进行采样,具体来说,对于总样本数为Nb+1个的批次中一个给定的训练样本,采样过程中将Batch中的其他每个训练样本对应的用户xj分别作为用户变量X的采样值,且X=xj的先验概率P(X=xj)采用均匀分布。当采用均匀分布时,所有采样样本的概率值是一样的,由于一个批次内的总采样次数是Nb次,因此先验概率P(X=xj)为1/Nb
S303、在推荐算法中,最终的预测一般是数据匹配。因此,将X的采样值为xj且M的采样值为mk,i时预测结果Y为yi的概率P(Y=yi|X=xj,M=mk,i)参数化为网络
Figure BDA00037492711200001112
并使用一个sigmoid层σ进行二分类:
Figure BDA0003749271120000111
其中mk,i和yi分别是mk,i和yi的特征表示;由于M截取了X→Y的所有影响,因此可进一步将用户xj用对应的中介特征表示
Figure BDA0003749271120000112
进行替代,从而将概率建模转换为:
Figure BDA0003749271120000113
其中
Figure BDA0003749271120000114
是以xj为输入时学生模型提取的的中介特征表示,网络/>
Figure BDA0003749271120000115
的结构与所述基础推荐模型中的预测器结构一致,
S304、估计P(Y|do(X)):通过上述分析和近似可知,在通过知识蒸馏进行前门调整过程中,对于批次中的第i个样本,经过前门调整后的预测结果
Figure BDA0003749271120000116
为:
mk,i=φk(xi),k=1,...,|Φ|
Figure BDA0003749271120000117
Figure BDA0003749271120000118
式中:Nb为X在一个批次内采样的总数;
由于上述估计引入了可训练参数,即W1、W2和
Figure BDA0003749271120000119
因此我们尝试将预测拉近真实标签oi。将预测结果/>
Figure BDA00037492711200001110
作为训练学生模型的因果指导之一,构建用于通过因果前门调整进行特征蒸馏的第二蒸馏损失/>
Figure BDA00037492711200001111
是一个基于前门干预蒸馏损失函数,其公式具体如下:
Figure BDA0003749271120000121
式中:
Figure BDA0003749271120000122
表示学生模型对于批次中的第i个样本的预测结果;oi表示批次中的第i个样本的预测结果的真实标签,代表用户xi是否点击了物品yi;/>
Figure BDA0003749271120000123
代表/>
Figure BDA0003749271120000124
和/>
Figure BDA0003749271120000125
的蒸馏损失,/>
Figure BDA0003749271120000126
代表/>
Figure BDA0003749271120000127
和oi的一致性损失。
S4、将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
在本发明的实施例中,上述步骤S4具体包括以下子步骤:
S41、针对学生模型的训练构建总损失函数
Figure BDA0003749271120000128
其形式为:
Figure BDA0003749271120000129
式中:
Figure BDA00037492711200001210
为第一蒸馏损失,/>
Figure BDA00037492711200001211
为第二蒸馏损失,/>
Figure BDA00037492711200001212
为学生模型自带的推荐损失,α和β分别为权重值。在本实施例中,学生模型为DIN网络,因此学生模型自带的推荐损失为DIN网络自身计算的推荐损失。
S42、以最小化总损失函数为目标,利用所述训练数据集中的所有训练样本对学生模型进行训练,直至模型收敛;利用训练后的学生模型对目标用户进行项目推荐。
学生模型在执行推荐任务时,其输入为一个目标用户的历史行为数据,预测结果为目标用户下一步是否会与某一个项目进行交互的二分类结果以及概率,当得到目标用户与每一个项目之间的二分类结果以及概率后,即可进行项目推荐。
同样的,基于相同的发明构思,如图3所示,本发明的另一较佳实施例中还提供了与上述实施例提供的基于知识蒸馏和因果推理的推荐方法对应的一种基于知识蒸馏和因果推理的推荐系统,其包括:
数据集获取模块,该模块获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组;
第一损失模块,该模块利用所述训练数据集训练一个由编码器和预测期组成的基础推荐模型,再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失;
第二损失模块,该模块借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次(Batch)内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失;
训练和推荐模块,该模块将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
由于上述基于知识蒸馏和因果推理的推荐方法解决问题的原理与本发明上述实施例的基于知识蒸馏和因果推理的推荐系统相似,因此该实施例中系统的各模块具体实现形式未尽之处亦可参见上述S1~S4所示方法部分的具体实现形式,重复之处不再赘述。
另外需要说明的是,上述各实施例系统中,各模块在被执行是相当于是按序执行的程序模块,因此其本质上是执行了一种数据处理的流程。且所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。在本申请所提供的各实施例中,所述方法和系统中对于步骤或者模块的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个模块或步骤可以结合或者可以集成到一起,一个模块或者步骤亦可进行拆分。
下面本发明将通过一个具体实例,来展示上述实施例中基于知识蒸馏和因果推理的推荐方法在具体数据集上的应用效果,以便于理解本发明的实质。
实施例
本实施例在第三方支付数据上进行了测试。该数据中,手机充值服务、政务服务等小程序被视为项目。在该第三方支付平台中,本实例将一位用户同时观察和点击的项目视为正面,将用户观察但未点击的项目视为负面。
为了客观评估本算法的性能,采用推荐系统领域常用评价指标,包括AUC,Recall(R@K),NDCG(N@K),Heterogeneity(H),对方法进行了评价。
本实例中所得实验结果如表1所示,结果表明,本发明的方法(CausalD)具有较高的准确率,同时也缓解了异质性。
表1第三方支付平台数据集上的实验结果(所有模型均将DIN作为Base model)
指标 DIN KD IPS DebiasD MEAL CausalD P-value
AUC 0.7691 0.7615 0.7623 07712 0.7749 0.7777 3.08e-05
R@5 0.1669 0.1727 0.1778 0.2057 0.2343 0.2547 5.73e-08
R@10 0.3518 0.3682 0.3953 0.3675 0.3938 0.4457 2.22e-06
N@5 0.1186 0.1221 0.1237 0.1532 0.1745 0.1851 1.53e-06
N@10 0.2076 0.2165 0.2289 0.2314 0.2517 0.2779 1.86e-05
H 4.7834 4.7798 4.2870 4.5788 4.5540 3.4622 1.03e-03
本实例中对于每个模型,运行5次并取平均结果。本发明提出的CausalD在大多数情况下实现了最好的性能和最低的性能异质性。其他相关实验证明,不同推荐架构和不同数据集的改进是一致的,显示了CausalD与模型无关和领域无关的优点,而且本发明在更大规模的数据集上改进更为显着。
以上所述的实施例只是本发明的一种较佳的方案,然其并非用以限制本发明。有关技术领域的普通技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变化和变型。因此凡采取等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。

Claims (10)

1.一种基于知识蒸馏和因果推理的推荐方法,其特征在于,包括如下步骤:
S1、获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组;
S2、利用所述训练数据集训练一个由编码器和预测器组成的基础推荐模型,再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失;
S3、借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次Batch内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失;
S4、将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
2.如权利要求1所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的步骤S1具体包括以下子步骤:
S101、获取训练样本构成的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;对于训练数据集中用户的历史行为数据,首先基于预先确定的用户出现频率阈值Nu和项目出现频率阈值Ni,过滤掉数据中出现频率小于Nu的用户以及出现频率小于Ni的项目;
S102、以用户活跃度作为所有用户的分组依据,设置分组的组数Ng;然后按照用户在S101过滤后的历史行为数据中出现的频率从高到低排序,将所有用户分成Ng个用户组,每个用户组的数据为这组用户各自所拥有的历史行为数据组成的集合。
3.如权利要求1所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的步骤S2具体包括以下子步骤:
S201、针对预先选定的由编码器和预测器组成的基础推荐模型,通过所有用户的历史行为数据对其进行训练,得到一个预训练模型M0
S202、在用户分组的基础上,基于每个用户组的数据对预训练模型M0进行微调,从而为每一个用户组训练一个独立的教师模型;然后使用非独立同分布的教师模型来估计期望,其中对于输入的同一个用户特征xi,非独立同分布教师模型Φ={φk}k=1,...,|Φ|对非独立同分布数据样本进行训练会产生中介变量M的不同取值,将不同教师模型编码器的输出进行加权求和得到中间表征的去偏估计期望:
Figure QLYQS_1
式中:mk,i代表第k个教师模型对于第i个样本的中介特征表示,由第k个教师模型的编码器对输入的用户特征xi进行编码后输出得到;
S203、将去偏估计期望
Figure QLYQS_2
作为训练学生模型的因果指导之一,构建用于通过因果后门调整进行特征蒸馏的第一蒸馏损失/>
Figure QLYQS_3
其公式具体如下:
Figure QLYQS_4
式中:
Figure QLYQS_5
是学生模型对于第i个样本的中介特征表示,由学生模型的编码器对输入的用户特征xi进行编码后输出得到;Distance表示平均均方误差MSE计算函数。
4.如权利要求1所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的步骤S3具体包括以下子步骤:
S301、使用非独立同分布的教师模型在给定用户变量X=xi的情况下对中间表征变量M的取值进行抽样,且抽样过程中对于给定X=xi时M=mk,i的条件概率建模为P(M=mk,i|X=xi),该条件概率采用注意力机制进行求解:
Figure QLYQS_6
其中,αk,i代表通过注意力机制得到的条件概率P(M=mk,i|X=xi),mk,i是第k个教师模型抽样得到的第i个中介变量M的取值,mk,i代表第k个教师模型对于第i个样本的中介特征表示,W1、W2均为可学习的参数矩阵;
Figure QLYQS_7
是学生模型对于第i个样本的中介特征表示;
S302、采用批次内采样策略对用户变量X进行采样,对于总样本数为Nb+1个的批次中一个给定的训练样本,采样过程中将Batch中的其他每个训练样本对应的用户xj分别作为用户变量X的采样值,且X=xj的先验概率P(X=xj)采用均匀分布;
S303、将X的采样值为xj且M的采样值为mk,i时预测结果Y为yi的概率P(Y=yi|X=xj,M=mk,i)参数化为网络
Figure QLYQS_8
并使用一个sigmoid层σ进行二分类:
Figure QLYQS_9
其中mk,i和yi分别是mk,i和yi的特征表示;进一步将用户xj用对应的中介特征表示
Figure QLYQS_10
进行替代,从而将概率建模转换为:
Figure QLYQS_11
其中
Figure QLYQS_12
是以xj为输入时学生模型提取的的中介特征表示,网络/>
Figure QLYQS_13
的结构与所述基础推荐模型中的预测器结构一致;
S304、在通过知识蒸馏进行前门调整过程中,对于批次中的第i个样本,经过前门调整后的预测结果
Figure QLYQS_14
为:
Figure QLYQS_15
式中:Nb为X在一个批次内采样的总数;
将预测结果
Figure QLYQS_16
作为训练学生模型的因果指导之一,构建用于通过因果前门调整进行特征蒸馏的第二蒸馏损失/>
Figure QLYQS_17
其公式具体如下:
Figure QLYQS_18
式中:
Figure QLYQS_19
表示学生模型对于批次中的第i个样本的预测结果;oi表示批次中的第i个样本的预测结果的真实标签,代表用户xi是否点击了物品yi;/>
Figure QLYQS_20
代表/>
Figure QLYQS_21
和/>
Figure QLYQS_22
的蒸馏损失,
Figure QLYQS_23
代表/>
Figure QLYQS_24
和oi的一致性损失。
5.如权利要求1所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的步骤S4具体包括以下子步骤:
S41、针对学生模型的训练构建总损失函数
Figure QLYQS_25
其形式为:
Figure QLYQS_26
式中:
Figure QLYQS_27
为第一蒸馏损失,/>
Figure QLYQS_28
为第二蒸馏损失,/>
Figure QLYQS_29
为学生模型自带的推荐损失,α和β分别为权重值;
S42、以最小化总损失函数为目标,利用所述训练数据集中的所有训练样本对学生模型进行训练,直至模型收敛;利用训练后的学生模型对目标用户进行项目推荐。
6.如权利要求1所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的基础推荐模型、教师模型和学生模型均采用深度兴趣网络(Deep Interest Network,DIN)。
7.如权利要求4所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述S302中,采用均匀分布的先验概率P(X=xj)为1/Nb
8.如权利要求5所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的学生模型自带的推荐损失为深度兴趣网络自身计算的推荐损失。
9.如权利要求5所述的基于知识蒸馏和因果推理的推荐方法,其特征在于,所述的项目为商品、应用程序、在线内容。
10.一种基于知识蒸馏和因果推理的推荐系统,其特征在于,包括:
数据集获取模块,该模块获取用于推荐模型训练的训练数据集,其中每个训练样本包含用户的编号、用户历史行为数据和项目编号;将所述训练数据集中所有用户按照敏感属性的高低排序后等分成若干个用户组;
第一损失模块,该模块利用所述训练数据集训练一个由编码器和预测器组成的基础推荐模型,再按照用户分组,利用每个用户组的数据分别对基础推荐模型进行微调,从而针对每一个用户组训练得到一个教师模型;在多个非独立同分布的教师模型基础上,构建一个用于通过因果后门调整进行特征蒸馏的第一蒸馏损失;
第二损失模块,该模块借助因果推断中的前门调整方法,通过每个用户组对应的教师模型获取多个中间表征作为中介,再利用批次(Batch)内采样机制和注意力机制,进行多模型多样本信息聚合,并构建一个用于将聚合后的信息蒸馏到学生模型的第二蒸馏损失;
训练和推荐模块,该模块将第一蒸馏损失、第二蒸馏损失和学生模型的推荐损失加权后作为总损失,利用所述训练数据集通过最小化总损失对学生模型进行训练,并利用训练后的学生模型对目标用户进行项目推荐。
CN202210837534.4A 2022-07-15 2022-07-15 基于知识蒸馏和因果推理的推荐方法及系统 Active CN115292587B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210837534.4A CN115292587B (zh) 2022-07-15 2022-07-15 基于知识蒸馏和因果推理的推荐方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210837534.4A CN115292587B (zh) 2022-07-15 2022-07-15 基于知识蒸馏和因果推理的推荐方法及系统

Publications (2)

Publication Number Publication Date
CN115292587A CN115292587A (zh) 2022-11-04
CN115292587B true CN115292587B (zh) 2023-07-14

Family

ID=83823717

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210837534.4A Active CN115292587B (zh) 2022-07-15 2022-07-15 基于知识蒸馏和因果推理的推荐方法及系统

Country Status (1)

Country Link
CN (1) CN115292587B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118446322A (zh) * 2024-06-28 2024-08-06 北京科技大学 一种基于大语言模型先验知识的推理状态控制方法及装置

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10499857B1 (en) * 2017-09-19 2019-12-10 Deepradiology Inc. Medical protocol change in real-time imaging
CN112163081A (zh) * 2020-10-14 2021-01-01 网易(杭州)网络有限公司 标签确定方法、装置、介质及电子设备
CN114357301A (zh) * 2021-12-31 2022-04-15 腾讯科技(深圳)有限公司 数据处理方法、设备及可读存储介质

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10499857B1 (en) * 2017-09-19 2019-12-10 Deepradiology Inc. Medical protocol change in real-time imaging
CN112163081A (zh) * 2020-10-14 2021-01-01 网易(杭州)网络有限公司 标签确定方法、装置、介质及电子设备
CN114357301A (zh) * 2021-12-31 2022-04-15 腾讯科技(深圳)有限公司 数据处理方法、设备及可读存储介质

Also Published As

Publication number Publication date
CN115292587A (zh) 2022-11-04

Similar Documents

Publication Publication Date Title
Athey et al. Machine learning methods that economists should know about
Ning et al. Slim: Sparse linear methods for top-n recommender systems
CN109840833B (zh) 贝叶斯协同过滤推荐方法
CN114202061A (zh) 基于生成对抗网络模型及深度强化学习的物品推荐方法、电子设备及介质
Livne et al. Evolving context-aware recommender systems with users in mind
Navgaran et al. Evolutionary based matrix factorization method for collaborative filtering systems
Ju et al. Robust boosting for regression problems
CN115292587B (zh) 基于知识蒸馏和因果推理的推荐方法及系统
Ho et al. Algorithmic progress in language models
Zhang et al. SEDGN: Sequence enhanced denoising graph neural network for session-based recommendation
Marcelino et al. Missing data analysis in regression
CN110633417B (zh) 一种基于服务质量的web服务推荐的方法及系统
US11144938B2 (en) Method and system for predictive modeling of consumer profiles
Li et al. Ensemble of fast learning stochastic gradient boosting
Babeetha et al. An enhanced kernel weighted collaborative recommended system to alleviate sparsity
Khanna et al. Parallel matrix factorization for binary response
Yin et al. PeNet: A feature excitation learning approach to advertisement click-through rate prediction
Liang et al. ASE: Anomaly scoring based ensemble learning for highly imbalanced datasets
Vairetti et al. Propensity score oversampling and matching for uplift modeling
CN112749345B (zh) 一种基于神经网络的k近邻矩阵分解推荐方法
Zhao et al. A novel in-depth analysis approach for domain-specific problems based on multidomain data
Faletto et al. Cluster Stability Selection
Sumalatha et al. Rough set based decision rule generation to find behavioural patterns of customers
Xie et al. Econometric methods and data science techniques: A review of two strands of literature and an introduction to hybrid methods
Poulakis Unsupervised AutoML: a study on automated machine learning in the context of clustering

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant