CN117609788A - 端云接力的点击率预测模型训练方法、装置及存储介质 - Google Patents
端云接力的点击率预测模型训练方法、装置及存储介质 Download PDFInfo
- Publication number
- CN117609788A CN117609788A CN202311618807.7A CN202311618807A CN117609788A CN 117609788 A CN117609788 A CN 117609788A CN 202311618807 A CN202311618807 A CN 202311618807A CN 117609788 A CN117609788 A CN 117609788A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- user
- equipment end
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 157
- 238000000034 method Methods 0.000 title claims abstract description 104
- 230000008569 process Effects 0.000 claims description 22
- 239000013598 vector Substances 0.000 claims description 21
- 230000003993 interaction Effects 0.000 claims description 19
- 230000006870 function Effects 0.000 claims description 17
- 230000002452 interceptive effect Effects 0.000 claims description 9
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 8
- 230000007246 mechanism Effects 0.000 description 8
- 238000004590 computer program Methods 0.000 description 7
- 230000003287 optical effect Effects 0.000 description 5
- 238000012545 processing Methods 0.000 description 5
- 238000007477 logistic regression Methods 0.000 description 4
- 238000004891 communication Methods 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 238000003491 array Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000006399 behavior Effects 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 239000013307 optical fiber Substances 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 230000004931 aggregating effect Effects 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 235000020303 café frappé Nutrition 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 238000003064 k means clustering Methods 0.000 description 1
- 230000000670 limiting effect Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000002829 reductive effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
- G06F18/23213—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Bioethics (AREA)
- Medical Informatics (AREA)
- Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Probability & Statistics with Applications (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本公开提供的一种端云接力的点击率预测模型训练方法、装置及存储介质,包括:云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。本公开即保护了用户的数据隐私,又实现了模型的更新,能够为用户提供更准确的点击率预测结果。
Description
技术领域
本公开属于端智能推荐技术领域,特别涉及一种端云接力的点击率预测模型训练方法、装置及存储介质。
背景技术
点击率(CTR)预测旨在预测用户点击推荐物品的概率,在推荐系统、在线广告和Web搜索等领域都有着至关重要的作用。为了提高模型的预测准确性,一方面,模型从简单的逻辑回归(LR)、因子分解机(FM)等模型不断演进为深度神经网络(DNN)、Wide&Deep、深度兴趣网络(DIN)等深度模型;另一方面,大量的用户交互数据被收集到运营商,用于训练先进的点击率预测模型。尽管点击率预测取得了显著的进展,但模型在训练和推断阶段通常存在显著的性能差距。在实际应用中,由于用户需求或兴趣的变化,用户的点击行为通常会随时间变化。因此,部署到端设备上的模型很容易变得陈旧,导致预测性能下降。因此,为了保证预测性能,模型需要定期用新数据更新。然而,点击率预测涉及用户资料和行为信息等隐私敏感数据,例如用户ID、性别、年龄和点击的物品等。随着用户隐私意识的不断提高,用户可能拒绝共享他们的私人数据。此外,国内外许多数据保护法律法规也对用户数据的收集和使用做出了严格的监管和限制。在数据监管新形式下,如何充分利用数据监管前收集的旧数据和数据监管后散布在端设备上的实时数据更新端上点击率预测模型具有重要研究意义。
发明内容
本公开旨在至少在一定程度上解决相关技术中的技术问题之一。
为此,本公开提出一种端云接力的点击率预测模型训练方法,通过先在云端利用数据监管前的数据和模型训练多个轻量的、学习能力强的、个性化的端模型,再在每个端设备上利用数据监管后的实时数据协同训练端模型的方式,即保护了用户的数据隐私,又实现了模型的更新,能够为用户提供更准确的点击率预测结果。
本公开的另一个目的在于提出一种端云接力的点击率预测模型训练装置。
本公开的又一个目的在于提出一种计算机存储介质。
为了实现上述目的,本公开第一方面提出了一种端云接力的点击率预测模型训练方法,包括:
云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
在一些实施例中,所述分组元学习训练,先在云服务器利用数据监管之前云端收集的数据训练一个点击率预测模型,然后再基于用于表征用户特征的用户嵌入向量对所述云端收集的数据进行分组,分组训练所述点击率预测模型得到多个设备端模型,作为所述多个初始化设备端模型。
在一些实施例中,所述云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型,包括:
S11,云服务器利用数据集训练一个初始化参数为ws的点击率预测模型fs(ws),训练过程中,参数ws更新及损失函数/>分别设置为:
其中,代表由数据监管前云端收集的数据构建的数据集,/>代表由数据监管前从用户/>收集的数据构成的数据集,/>代表数据监管前的用户集,yi代表用户是否与物品i交互的标签,/>代表云服务器端的点击率预测模型fs(ws)预测用户与物品i是否有交互的概率,η1为第一学习率,/>代表某一数据集;
S12,云服务器利用训练得到的点击率预测模型fs(ws)得到每个用户的用户嵌入矩阵/>根据用户嵌入矩阵/>对用户进行聚类,得到用户分组结果其中,/>代表第k组用户,/>代表分组数;
S13,在每组k内,采用元学习的方式训练设备端模型,得到每组对应的一个初始化设备端模型参数具体包括:
S131,随机初始化设备端模型参数
S132,重复执行以下步骤直到设备端模型收敛:1)从第k组用户中随机采样一批次的用户/>2)对于B中的每一个用户/>先执行本地更新,将随机初始化设备端模型参数/>赋值给参数θu,然后计算损失函数/>并更新/> η2代表第二学习率,/>代表任务τu对应的支持集;3)执行全局更新/>η3代表第三学习率,/>代表任务τu对应的查询集;
S133,利用数据集训练参数/>
在一些实施例中,所述云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型,包括:
S21,通过云服务器利用数据监管之前云端收集的数据训练的所述点击率预测模型,得到每一个物品的嵌入向量ei,依据所有物品的嵌入向量对各物品进行聚类,将物品分为/>组/>对应/>个类别,对于每个物品类别/>对属于该类别的物品的嵌入向量/>取平均得到一个聚类中心嵌入向量/>表示每个物品类别c的特征;
S22,对于用户的每个交互物品/> 代表由数据监管前从用户/>收集的数据构成的数据集,通过计算交互物品/>的嵌入向量/>与每个聚类中心嵌入向量ec之间的相似性来确定用户/>的每个交互物品的类别;
S23,基于用户交互物品的类别,进行用户/>的偏好/>建模,公式如下:
式中,λc代表用户交互过的物品属于类别c的个数,/>代表该用户交互过的总物品个数;
S24,通过用户偏好建模,按照下式构建所述模型选择器的训练数据集
式中,是通过步骤S21得到的用户/>的聚类结果,对应云服务器一个初始化设备端模型ID;
S25,利用所述训练数据集对所述模型选择器进行训练,将所述模型选择器的训练任务视为一个分类任务,以用户/>的设备端新生成的交互数据/>构建模型选择器的输入,以初始化设备端模型ID作为模型选择器的输出,/>代表数据监管前的用户集。
在一些实施例中,采用多层感知机网络作为所述模型选择器s,并在所述训练数据集上使用交叉熵损失函数对所述模型选择器s进行训练。
在一些实施例中,所述设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型,包括:
S31,在每一轮次训练中,从数据监管后的用户集中随机选择一部分用户/>参与协同学习,每个参与用户/>利用其在数据监管后生成的本地数据对初始设备端模型进行若干轮次的本地训练更新,得到更新后的设备端模型参数;
S32,每个参与用户上传更新后的设备端模型参数到云服务器;
S33,云服务器对于每组k,将下载了该组初始化设备端模型的所有参与用户的模型参数聚合更新参数wk,更新公式如下:
式中,ku代表云服务器执行分组元学习训练时对用户u的聚类结果,/>代表数据监管后用户u的设备端生成的交互数据;
S34,利用每组参数wk更新所有用户的设备端模型;
S35,重复执行步骤S31~步骤S34若干轮次,直到模型的端上点击率预测性能达到目标水平。
在一些实施例中,S31中,使用FedProx算法对初始设备端模型进行若干轮次的本地训练更新,FedProx算法采用的损失函数及设备端模型参数的更新方式为:
式中,为用户/>的设备端在数据监管后生成的交互数据,/>为本地训练时用户/>的设备端模型参数,/>为用户/>对应的/>组的设备端模型参数,η4为第四学习率,μ为超参数。
在一些实施例中,S31中,使用FedAvg算法对初始设备端模型进行若干轮次的本地训练更新,FedAvg算法采用的损失函数及设备端模型参数的更新方式为:
式中,为用户/>的设备端在数据监管后生成的交互数据,/>为本地训练时用户/>的设备端模型参数,/>为用户/>对应的/>组的设备端模型参数,η5为第五学习率。
本公开第二方面提供的一种用于实现根据本公开第一方面任一实施例所述点击率预测模型训练方法的点击率预测模型训练装置,包括:
初始化设备端模型训练模块,用于通过云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
模型选择器训练模块,用于通过云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
设备端模型协同训练模块,用于使设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
本公开第三方面提供的一种计算机可读存储介质,所述计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行根据本公开第一方面任一实施例所述的点击率预测模型训练方法。
本公开提供的一种端云接力的点击率预测模型训练方法、装置及存储介质,具有以下特点及有益效果:
1.通过先在云端利用数据监管前的数据和模型训练多个轻量的、学习能力强的、个性化的端模型,再在每个端设备上利用数据监管后的实时数据协同训练端模型的方式,即保护了用户的数据隐私,又实现了模型的更新,能够为用户提供更准确的点击率预测结果。
2.包括一种分组元学习机制,能够充分利用数据监管前云端收集的数据为端设备训练具有良好初始化的模型。
3.包括一种自动化的模型选择机制,能够根据用户偏好自动为端设备选择初始化模型。
4.包括一种个性化的协同学习机制,在数据监管后能够在不共享用户数据的情况下完成用户设备端模型的实时更新。
附图说明
本公开上述的和/或附加的方面和优点从下面结合附图对实施例的描述中将变得明显和容易理解,其中:
图1是本公开实施例提供的端云接力的点击率预测模型训练方法面向的应用场景图;
图2是本公开实施例提供的端云接力的点击率预测模型训练方法的流程图;
图3是本公开实施例提供的端云接力的点击率预测模型训练方法的逻辑图;
图4是本公开实施例提供的端云接力的点击率预测模型训练方法的性能验证图;
图5是本公开实施例提供的端云接力的点击率预测模型训练装置的结构示意图;
图6是本公开实施例提供的用于实现计算机可读存储介质的电子设备的结构示意图。
具体实施方式
需要说明的是,在不冲突的情况下,本公开中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本公开。
为了使本技术领域的人员更好地理解本公开方案,下面将结合本公开实施例中的附图,对本公开实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本公开一部分的实施例,而不是全部的实施例。基于本公开中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本公开保护的范围。
下面参照附图描述根据本公开实施例提出的端云接力的点击率预测模型训练方法、装置及存储介质。
图1是本公开实施例提供的端云接力的点击率预测模型训练方法面向的应用场景图。如图1所示,在数据受监管前,云端服务器收集了大量用户交互数据,并利用这些数据训练了一个点击率预测模型;数据受到监管后,端设备产生的用户交互数据便无法再被收集到云端,而是分布式散布在端设备上。本公开实施例提供的端云接力的点击率预测模型训练方法提出要充分利用云端在数据受监管之前收集的用户交互数据和利用收集的用户交互数据训练得到的点击率预测模型,以及数据受到监管后散布在端设备上的分布式用户交互数据,训练更新端设备上的点击率预测模型,提升用户模型的预测性能。
图2是本公开实施例提供的端云接力的点击率预测模型训练方法的流程图。如图2所示,该方法包括但不限于以下步骤:
S1,云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
S2,云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
S3,设备端下载初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
作为一个实施例,图3为本公开实施例提供的端云接力的点击率预测模型训练方法的逻辑图,如图3所示:
具体地,为了充分利用数据监管前云服务器收集的用户交互数据,本公开实施例的端云接力的点击率预测模型训练方法提供一种分组元学习训练方法,先在云服务器利用数据监管之前云端收集的数据训练一个服务器模型,然后再基于用于表征用户特征的用户嵌入向量对云端数据进行分组,分组训练服务器模型得到多个设备端模型,作为设备端的初始化模型。设代表由数据监管前云端收集的数据构建的数据集,其中/>代表由数据监管前从用户/>收集的数据构成的数据集,/>代表数据监管前的用户集。本公开实施例的分组元学习训练方法具体包括以下步骤:
S11,云服务器先利用数据集训练一个较先进的初始化参数为ws的点击率预测模型fs(ws),训练过程中,参数ws更新及损失函数/>分别设置为:
其中,yi代表用户是否与物品i交互的标签,代表云服务器端的点击率预测模型fs(ws)预测用户与物品i是否有交互的概率,η1为第一学习率,/>代表泛指的某一数据集。
S12,然后云服务器利用训练得到的点击率预测模型fs(ws)得到每个用户的用户嵌入矩阵/>再根据用户嵌入矩阵/>对用户进行k-means聚类,得到用户分组结果/>其中,/>代表第k组用户,/>代表分组数。
S13,然后在每组k内,采用元学习的方式训练设备端模型f,得到每组对应的一个初始化设备端模型参数采用元学习,对于每个用户/>进行的点击率预测可以被视为一个任务/>每个任务/>被分为一个支持集/>和一个查询集/>它们分别用于计算每个任务的训练损失和测试损失。具体而言,支持集在本地更新阶段使用,算法在该阶段基于每个支持集调整初始化设备端模型参数/>促进学习过程。相反,查询集在全局更新阶段使用,算法利用本地更新阶段调整后的模型参数计算查询集的损失函数值,并通过最小化损失函数的方式训练更新初始化设备端模型参数/>实现学习如何学习的过程。通过上述元学习过程,云服务器得到的各初始化设备端模型可以高效地适应新的任务。具体元学习方法为:
S131,先随机初始化设备端模型参数
S132,然后重复执行以下步骤直到设备端模型收敛:1)从中随机采样一批次的用户/>2)对于B中的每一个用户u∈B,先执行本地更新,将随机初始化设备端模型参数/>赋值给参数θu,然后计算损失函数/>并更新/> η2代表第二学习率,/>代表任务τu对应的支持集;3)执行全局更新/>η3代表第三学习率,/>代表任务τu对应的查询集;
S133,最后再利用数据监管前云端收集的数据集训练/>
本公开实施例的端云接力的点击率预测模型训练方法提供的分组元学习方法最终得到个初始化设备端模型参数/>以及用户的分组结果/>
通过上述分组元学习训练方法,云服务器训练得到多个初始化设备端模型,然后设备可以下载其中一个初始设备端模型并在本地训练以适应用户产生的新数据。为了能为用户选择合适的初始化设备端模型,本公开实施例的端云接力的点击率预测模型训练方法提供一种自动化模型选择机制,云服务器依据分组结果和物品嵌入向量训练一个模型选择器,用于为用户自动选择初始化设备端模型。具体而言,包括以下步骤:
S21,通过利用步骤S11得到的预先训练的点击率预测模型fs(ws),可以得到每一个物品的嵌入向量ei(该向量为点击率预测模型fs(ws)的中间层输出的特征)。然后依据所有物品的嵌入向量对各物品进行聚类,将物品分为/>组/>对应/>个类别。对于每个物品类别/>对属于该类别的物品的嵌入向量/>取平均得到一个聚类中心嵌入向量/>表示每个物品类别c的特征。
S22,然后,对于用户的每个交互物品/>通过计算其嵌入向量/>与每个聚类中心嵌入向量ec之间的相似性来确定用户/>的每个交互物品的类别。
S23,基于用户交互物品的类别,进行用户/>的偏好/>建模,公式如下:
式中,λc代表该用户交互过的物品属于类别c的个数,/>代表该用户交互过的总物品个数。
S24,通过用户偏好建模,按照下式构建模型选择器的训练数据集
式中,是通过步骤S21得到的用户/>的聚类结果,对应云服务器一个初始化设备端模型ID。
S25,利用训练数据集对模型选择器进行训练,模型选择器的训练任务可以视为一个分类任务。具体地,采用多层感知机(MLP)网络作为模型选择器s,并在训练数据集上使用交叉熵损失函数对模型选择器s进行训练。模型选择器在云服务器上训练,并被每个设备端下载进行本地推断。在用户/>的设备端上,首先基于用户/>的设备端新生成的交互数据/>构建模型选择器的输入,即用户/>的偏好/>这与模型选择器的训练数据集的构建过程相一致。然后,模型选择器输出一个ID,设备将从云服务器下载相应的初始设备端模型。
当设备获取各自的初始设备端模型后,它们将在本地使用数据监管后新生成的数据对其进行接力训练,使其适应用户的即时偏好。为了促进用户之间的高效协作训练,本公开实施例的端云接力的点击率预测模型训练方法提供一种个性化协同训练机制,根据用户初始设备端模型的选择来执行个性化的设备端模型聚合。每个用户先从云服务器下载训练好的模型选择器s,然后利用本地新生成的交互数据构建模型选择器网络的输入pu,这与模型选择器的训练数据构建过程相一致。然后,模型选择器输出ku,对应一个分组,设备将从云服务器下载对应分组的初始化设备模型。接下来,设备之间便开始迭代的协同学习过程,具体过程包括:
S31,在每一轮次中,从数据监管后的用户集中随机选择一部分用户/>参与协同学习,每个参与用户/>利用其本地数据使用FedProx算法对初始设备端模型进行τ轮次的本地训练更新:
式中,为FedProx算法的损失函数,/>为用户/>的设备端新生成的交互数据,/>为本地训练时用户/>的设备端模型参数,/>为用户/>对应的/>组的设备端模型参数,η4为第四学习率,μ为超参数。
或者,采用使用FedAvg算法对初始设备端模型进行若干轮次的本地训练更新,FedAvg算法采用的损失函数及设备端模型参数的更新方式为:
式中,η5为第五学习率。
S32,之后,每个参与用户上传更新后的设备端模型参数/>到云服务器。
S33,云服务器对于每组k,将下载了该组初始化设备端模型的所有参与用户的模型参数聚合更新参数wk,更新公式如下:
式中, 代表选中参与的用户中现在了第k组初始化模型的用户集,ku代表云服务器执行分组元学习训练时对用户u的聚类结果,/>代表数据监管后用户u的设备端生成的交互数据。
S34,利用每组参数wk更新所有用户的设备端模型。
S35,重复执行步骤S31~步骤S34若干轮次,直到模型的端上预测性能达到目标水平。
本公开实施例的有效性验证:
采用MovieLens、Frappe和KKBox三个公开数据集和一个工业数据集Meituan来验证本公开实施例提供的端云接力的点击率预测模型训练方法的有效性。其中,对于三个公开数据集,训练数据被随机划分为两部分,5%的数据作为数据监管前云服务器收集的用户交互数据,其余作为散布在分布式设备上的隐私保护数据。对于Meituan数据集,将从2022年12月17日到2022年12月20日收集的数据标记为在数据隐私法规实施之前收集汇总的云端数据,将从2023年2月17日到2023年2月20日生成的数据,视为数据隐私法规实施后产生的数据,这些数据分散存储在用户设备上。设备端的点击率预测模型f采用两个广泛使用的逻辑回归(LR)模型或深度神经网络(DNN)模型,服务器端较先进的点击率预测模型fs采用xDeepFM模型。本公开实施例提供的端云接力的点击率预测模型训练方法取名为RelayRec,其性能与8种方法做对比,包括2种初始化模型训练方法与4种协同训练方法的组合。2种初始化模型训练方法包括直接用数据监管前云端收集的旧数据训练得到初始化设备端模型与利用元学习的方法训练得到初始化元模型。4种协同训练方法包括端上微调(只利用端上数据训练)、联邦平均(FedAvg)、FedProx和分组协同。包括两个性能评估指标,一是AUC(曲线下面积)量化了被选择的随机正样本排名高于随机选择的负样本的概率,更高的AUC值表示更好的点击率预测性能;另一个是LogLoss(对数损失)也称为二进制交叉熵损失,用作衡量二元分类场景中损失的指标,较低的Logloss值表示更好的点击率预测性能。
图4是RelayRec与其他方法的性能对比,可以看到,在大多数情况下,RelayRec的表现都优于对比方案。特别是,在Frappe数据集上表现出最显著的性能提升,平均AUC提高了7.2%,LogLoss降低了14.9%。此外,RelayRec在美团工业数据集上的表现也显著,AUC提高了约5.9%,LogLoss降低了约5.4%。总体而言,RelayRec平均提高了5%的AUC,降低了8.1%的LogLoss。这些性能改进表明了RelayRec端云接力的训练方式在点击率预测模型训练性能方面的优越性。
可以理解的是,根据本公开实施例提供的端云接力的点击率预测模型训练方法,通过先在云端利用数据监管前的数据和模型训练多个轻量的、学习能力强的、个性化的端模型,再在每个端设备上利用数据监管后的实时数据协同训练端模型的方式,即保护了用户的数据隐私,又实现了模型的更新,能够为用户提供更准确的点击率预测结果。具体地,通过提出的一种分组元学习机制,能够充分利用数据监管前云端收集的数据为端设备训练具有良好初始化的模型;通过提出的一种自动化的模型选择机制,能够根据用户偏好自动为端设备选择初始化模型;通过提出的一种个性化的协同学习机制,在数据监管后能够在不共享用户数据的情况下完成用户设备端模型的实时更新。
为了实现上述实施例,如图5所示,本实施例中还提供了端云接力的点击率预测模型训练装置100,该装置100包括,初始化设备端模型训练模块1、模型选择器训练模块2、和设备端模型协同训练模块3。
初始化设备端模型训练模块1,用于通过云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
模型选择器训练模块2,用于通过云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
设备端模型协同训练模块3,用于使设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
需要说明的是,前述对一种端云接力的点击率预测模型训练方法的实施例解释说明也适用于本实施例的一种端云接力的点击率预测模型训练装置,在此不再赘述。
为了实现上述实施例,本公开实施例还提出一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行,用于执行上述实施例的端云接力的点击率预测模型训练方法。
下面参考图6,其示出了适于用来实现本公开实施例的电子设备的结构示意图。其中,需要说明的是,本公开实施例中的电子设备可以包括但不限于诸如移动电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、车载终端(例如车载导航终端)等等的移动终端以及诸如数字TV、台式计算机、服务器等等的固定终端。图6示出的电子设备仅仅是一个示例,不应对本公开实施例的功能和使用范围带来任何限制。
如图6所示,电子设备可以包括处理装置(例如中央处理器、图形处理器等)101,其可以根据存储在只读存储器(ROM)102中的程序或者从存储装置108加载到随机访问存储器(RAM)103中的程序而执行各种适当的动作和处理。在RAM 103中,还存储有电子设备操作所需的各种程序和数据。处理装置101、ROM 102以及RAM 103通过总线104彼此相连。输入/输出(I/O)接口105也连接至总线104。
通常,以下装置可以连接至I/O接口105:包括例如触摸屏、触摸板、键盘、鼠标、摄像头、麦克风等的输入装置106;包括例如液晶显示器(LCD)、扬声器、振动器等的输出装置107;包括例如磁带、硬盘等的存储装置108;以及通信装置109。通信装置109可以允许电子设备与其他设备进行无线或有线通信以交换数据。虽然图6示出了具有各种装置的电子设备,但是应理解的是,并不要求实施或具备所有示出的装置。可以替代地实施或具备更多或更少的装置。
特别地,根据本公开的实施例,上文参考流程图描述的过程可以被实现为计算机软件程序。例如,本实施例包括一种计算机程序产品,其包括承载在计算机可读介质上的计算机程序,该计算机程序包含用于执行流程图中所示方法的程序代码。在这样的实施例中,该计算机程序可以通过通信装置109从网络上被下载和安装,或者从存储装置108被安装,或者从ROM 102被安装。在该计算机程序被处理装置101执行时,执行本公开实施例的方法中限定的上述功能。
需要说明的是,本公开上述的计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质或者是上述两者的任意组合。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子可以包括但不限于:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本公开中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。而在本公开中,计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读信号介质还可以是计算机可读存储介质以外的任何计算机可读介质,该计算机可读信号介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于:电线、光缆、RF(射频)等等,或者上述的任意合适的组合。
上述计算机可读介质可以是上述电子设备中所包含的;也可以是单独存在,而未装配入该电子设备中。
上述计算机可读介质承载有一个或者多个程序,当上述一个或者多个程序被该电子设备执行时,使得该电子设备执行上述端云接力的点击率预测模型训练方法。
可以以一种或多种程序设计语言或其组合来编写用于执行本公开的操作的计算机程序代码,上述程序设计语言包括面向对象的程序设计语言一诸如Java、Smalltalk、C++、python,还包括常规的过程式程序设计语言-诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络——包括局域网(LAN)或广域网(WAN)-连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本申请的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本申请的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更多个用于实现特定逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分,并且本申请的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能,这应被本申请的实施例所属技术领域的技术人员所理解。
在流程图中表示或在此以其他方式描述的逻辑和/或步骤,例如,可以被认为是用于实现逻辑功能的可执行指令的定序列表,可以具体实现在任何计算机可读介质中,以供指令执行系统、装置或设备(如基于计算机的系统、包括处理器的系统或其他可以从指令执行系统、装置或设备取指令并执行指令的系统)使用,或结合这些指令执行系统、装置或设备而使用。就本说明书而言,“计算机可读介质”可以是任何可以包含、存储、通信、传播或传输程序以供指令执行系统、装置或设备或结合这些指令执行系统、装置或设备而使用的装置。计算机可读介质的更具体的示例(非穷尽性列表)包括以下:具有一个或多个布线的电连接部(电子装置),便携式计算机盘盒(磁装置),随机存取存储器(RAM),只读存储器(ROM),可擦除可编辑只读存储器(EPROM或闪速存储器),光纤装置,以及便携式光盘只读存储器(CDROM)。另外,计算机可读介质甚至可以是可在其上打印程序的纸或其他合适的介质,因为可以例如通过对纸或其他介质进行光学扫描,接着进行编辑、解译或必要时以其他合适方式进行处理来以电子方式获得程序,然后将其存储在计算机存储器中。
应当理解,本申请的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,多个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。例如,如果用硬件来实现,和在另一实施方式中一样,可用本领域公知的下列技术中的任一项或他们的组合来实现:具有用于对数据信号实现逻辑功能的逻辑门电路的离散逻辑电路,具有合适的组合逻辑门电路的专用集成电路,可编程门阵列(PGA),现场可编程门阵列(FPGA)等。
本技术领域的普通技术人员可以理解实现上述实施例方法携带的全部或部分步骤,可以通过程序来指令相关的硬件完成,所开发的程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。
此外,在本申请各个实施例中的各功能单元可以集成在一个处理模块中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。
上述提到的存储介质可以是只读存储器,磁盘或光盘等。尽管上面已经示出和描述了本申请的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本申请的限制,本领域的普通技术人员在本申请的范围内可以对上述实施例进行变化、修改、替换和变型。
Claims (10)
1.一种端云接力的点击率预测模型训练方法,其特征在于,包括:
云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
2.根据权利要求1所述的点击率预测模型训练方法,其特征在于,所述分组元学习训练,先在云服务器利用数据监管之前云端收集的数据训练一个点击率预测模型,然后再基于用于表征用户特征的用户嵌入向量对所述云端收集的数据进行分组,分组训练所述点击率预测模型得到多个设备端模型,作为所述多个初始化设备端模型。
3.根据权利要求1所述的点击率预测模型训练方法,其特征在于,所述云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型,包括:
S11,云服务器利用数据集训练一个初始化参数为ws的点击率预测模型fs(ws),训练过程中,参数ws更新及损失函数/>分别设置为:
其中,代表由数据监管前云端收集的数据构建的数据集,/>代表由数据监管前从用户/>收集的数据构成的数据集,/>代表数据监管前的用户集,yi代表用户是否与物品i交互的标签,/>代表云服务器端的点击率预测模型fs(ws)预测用户与物品i是否有交互的概率,η1为第一学习率,/>代表某一数据集;
S12,云服务器利用训练得到的点击率预测模型fs(ws)得到每个用户的用户嵌入矩阵/> 根据用户嵌入矩阵/>对用户进行聚类,得到用户分组结果其中,/>代表第k组用户,/>代表分组数;
S13,在每组k内,采用元学习的方式训练设备端模型,得到每组对应的一个初始化设备端模型参数具体包括:
S131,随机初始化设备端模型参数
S132,重复执行以下步骤直到设备端模型收敛:1)从第k组用户中随机采样一批次的用户/>2)对于B中的每一个用户u∈B,先执行本地更新,将随机初始化设备端模型参数/>赋值给参数θu,然后计算损失函数/>并更新/> η2代表第二学习率,/>代表任务τu对应的支持集;3)执行全局更新η3代表第三学习率,/>代表任务τu对应的查询集;
S133,利用数据集训练参数/>
4.根据权利要求2所述的点击率预测模型训练方法,其特征在于,所述云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型,包括:
S21,通过云服务器利用数据监管之前云端收集的数据训练的所述点击率预测模型,得到每一个物品的嵌入向量ei,依据所有物品的嵌入向量对各物品进行聚类,将物品分为/>组/>对应/>个类别,对于每个物品类别/>对属于该类别的物品的嵌入向量/>取平均得到一个聚类中心嵌入向量/>表示每个物品类别c的特征;
S22,对于用户的每个交互物品/>代表由数据监管前从用户/>收集的数据构成的数据集,通过计算交互物品/>的嵌入向量/>与每个聚类中心嵌入向量ec之间的相似性来确定用户/>的每个交互物品的类别;
S23,基于用户交互物品的类别,进行用户/>的偏好/>建模,公式如下:
式中,λc代表用户交互过的物品属于类别c的个数,/>代表该用户交互过的总物品个数;
S24,通过用户偏好建模,按照下式构建所述模型选择器的训练数据集
式中,是通过步骤S21得到的用户/>的聚类结果,对应云服务器一个初始化设备端模型ID;
S25,利用所述训练数据集对所述模型选择器进行训练,将所述模型选择器的训练任务视为一个分类任务,以用户/>的设备端新生成的交互数据/>构建模型选择器的输入,以初始化设备端模型ID作为模型选择器的输出,/>代表数据监管前的用户集。
5.根据权利要求4所述的点击率预测模型训练方法,其特征在于,采用多层感知机网络作为所述模型选择器s,并在所述训练数据集上使用交叉熵损失函数对所述模型选择器s进行训练。
6.根据权利要求1所述的点击率预测模型训练方法,其特征在于,所述设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型,包括:
S31,在每一轮次训练中,从数据监管后的用户集中随机选择一部分用户/>参与协同学习,每个参与用户/>利用其在数据监管后生成的本地数据对初始设备端模型进行若干轮次的本地训练更新,得到更新后的设备端模型参数;
S32,每个参与用户上传更新后的设备端模型参数到云服务器;
S33,云服务器对于每组k,将下载了该组初始化设备端模型的所有参与用户的模型参数聚合更新参数wk,更新公式如下:
式中,ku代表云服务器执行分组元学习训练时对用户u的聚类结果,/>代表数据监管后用户u的设备端生成的交互数据;
S34,利用每组参数wk更新所有用户的设备端模型;
S35,重复执行步骤S31~步骤S34若干轮次,直到模型的端上点击率预测性能达到目标水平。
7.根据权利要求6所述的点击率预测模型训练方法,其特征在于,S31中,使用FedProx算法对初始设备端模型进行若干轮次的本地训练更新,FedProx算法采用的损失函数及设备端模型参数的更新方式为:
式中,为用户/>的设备端在数据监管后生成的交互数据,/>为本地训练时用户的设备端模型参数,/>为用户/>对应的/>组的设备端模型参数,η4为第四学习率,μ为超参数。
8.根据权利要求6所述的点击率预测模型训练方法,其特征在于,S31中,使用FedAvg算法对初始设备端模型进行若干轮次的本地训练更新,FedAvg算法采用的损失函数及设备端模型参数的更新方式为:
式中,为用户/>的设备端在数据监管后生成的交互数据,/>为本地训练时用户的设备端模型参数,/>为用户/>对应的/>组的设备端模型参数,η5为第五学习率。
9.一种用于实现根据权利要求1~8中任一项所述点击率预测模型训练方法的点击率预测模型训练装置,其特征在于,包括:
初始化设备端模型训练模块,用于通过云服务器利用数据监管前收集的旧数据执行分组元学习训练得到多个初始化设备端模型;
模型选择器训练模块,用于通过云服务器训练一个模型选择器以根据用户偏好来从多个初始化设备端模型中为用户自动化选择初始的设备端模型;
设备端模型协同训练模块,用于使设备端下载所述初始的设备端模型并利用数据监管后的本地数据接力训练设备端模型,设备间通过设备端模型参数的共享实现协同训练得到个性化的点击率预测模型。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行权利要求1~8中任一项所述的点击率预测模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311618807.7A CN117609788A (zh) | 2023-11-30 | 2023-11-30 | 端云接力的点击率预测模型训练方法、装置及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311618807.7A CN117609788A (zh) | 2023-11-30 | 2023-11-30 | 端云接力的点击率预测模型训练方法、装置及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117609788A true CN117609788A (zh) | 2024-02-27 |
Family
ID=89943950
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311618807.7A Pending CN117609788A (zh) | 2023-11-30 | 2023-11-30 | 端云接力的点击率预测模型训练方法、装置及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117609788A (zh) |
-
2023
- 2023-11-30 CN CN202311618807.7A patent/CN117609788A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11315030B2 (en) | Continuously learning, stable and robust online machine learning system | |
CN107463701B (zh) | 基于人工智能推送信息流的方法和装置 | |
CN113626719A (zh) | 信息推荐方法、装置、设备、存储介质及计算机程序产品 | |
US20170308535A1 (en) | Computational query modeling and action selection | |
WO2022016556A1 (zh) | 一种神经网络蒸馏方法以及装置 | |
CN109471978B (zh) | 一种电子资源推荐方法及装置 | |
US20210406761A1 (en) | Differentiable user-item co-clustering | |
CN111104599B (zh) | 用于输出信息的方法和装置 | |
CN112528164B (zh) | 一种用户协同过滤召回方法及装置 | |
CN112380449B (zh) | 信息推荐方法、模型训练方法及相关装置 | |
CN111783810A (zh) | 用于确定用户的属性信息的方法和装置 | |
CN114417174B (zh) | 内容推荐方法、装置、设备及计算机存储介质 | |
CN114764471A (zh) | 一种推荐方法、装置及存储介质 | |
CN116452263A (zh) | 一种信息推荐方法、装置、设备及存储介质、程序产品 | |
CN115841366A (zh) | 物品推荐模型训练方法、装置、电子设备及存储介质 | |
CN116957678A (zh) | 一种数据处理方法和相关装置 | |
CN114647789A (zh) | 一种推荐模型的确定方法和相关装置 | |
CN117609788A (zh) | 端云接力的点击率预测模型训练方法、装置及存储介质 | |
CN116484085A (zh) | 一种信息投放方法、装置、设备及存储介质、程序产品 | |
CN115700548A (zh) | 用户行为预测的方法、设备和计算机程序产品 | |
Chen et al. | Temporal-aware influence maximization solution in artificial intelligent edge application | |
CN116662672B (zh) | 价值对象信息发送方法、装置、设备和计算机可读介质 | |
CN112734462B (zh) | 一种信息推荐方法、装置、设备及介质 | |
CN116401398A (zh) | 一种数据处理方法及相关装置 | |
CN116629087A (zh) | 模型的预测置信度评估方法及装置、设备、存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |