CN117150122A - 终端推荐模型的联邦训练方法、装置和存储介质 - Google Patents
终端推荐模型的联邦训练方法、装置和存储介质 Download PDFInfo
- Publication number
- CN117150122A CN117150122A CN202311021672.6A CN202311021672A CN117150122A CN 117150122 A CN117150122 A CN 117150122A CN 202311021672 A CN202311021672 A CN 202311021672A CN 117150122 A CN117150122 A CN 117150122A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- local
- parameters
- round
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 166
- 238000000034 method Methods 0.000 title claims abstract description 66
- 238000003860 storage Methods 0.000 title claims description 20
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 17
- 230000002776 aggregation Effects 0.000 claims abstract description 9
- 238000004220 aggregation Methods 0.000 claims abstract description 9
- 238000012512 characterization method Methods 0.000 claims description 10
- 238000004821 distillation Methods 0.000 claims description 9
- 230000000694 effects Effects 0.000 claims description 7
- 230000000052 comparative effect Effects 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 claims description 4
- 230000006870 function Effects 0.000 description 33
- 238000004891 communication Methods 0.000 description 7
- 238000004590 computer program Methods 0.000 description 7
- 230000008569 process Effects 0.000 description 6
- 238000012545 processing Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 230000003287 optical effect Effects 0.000 description 5
- 238000012360 testing method Methods 0.000 description 4
- 238000010801 machine learning Methods 0.000 description 3
- 238000003491 array Methods 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 2
- 238000011156 evaluation Methods 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 239000013307 optical fiber Substances 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 235000010627 Phaseolus vulgaris Nutrition 0.000 description 1
- 244000046052 Phaseolus vulgaris Species 0.000 description 1
- 206010044565 Tremor Diseases 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000004091 panning Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/90—Details of database functions independent of the retrieved data types
- G06F16/95—Retrieval from the web
- G06F16/953—Querying, e.g. by the use of web search engines
- G06F16/9535—Search customisation based on user profiles and personalisation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/06—Buying, selling or leasing transactions
- G06Q30/0601—Electronic shopping [e-shopping]
- G06Q30/0631—Item recommendations
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Databases & Information Systems (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Business, Economics & Management (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Accounting & Taxation (AREA)
- Finance (AREA)
- Development Economics (AREA)
- Marketing (AREA)
- Strategic Management (AREA)
- General Business, Economics & Management (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Economics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本公开提供的终端推荐模型的联邦训练方法,包括:在云端构建教师模型和学生模型,利用云端的历史数据集对教师、学生模型进行预训练,并将学生模型发送至各终端;各终端分别利用本地数据集对学生模型进行首轮联邦训练,得到本地学生模型;云端对本地学生模型的参数进行聚合,得到全局模型;在云端利用教师模型和历史数据集对全局模型的参数通过自监督知识蒸馏的方式进行参数增强,利用增强后的全局模型的参数对各终端的本地学生模型的参数进行更新;不断重复上述本地训练及云端的参数聚合和参数增强步骤,直至训练轮次达到上限。本公开能在保护隐私的同时提供泛化能力更强的轻量化推荐模型,减少了联邦推荐中因数据稀疏带来准确率较低的问题。
Description
技术领域
本公开属于数据处理技术领域,特别涉及一种终端推荐模型的联邦训练方法、装置和存储介质。
背景技术
推荐系统是一种信息过滤系统,目的是为了预测用户的喜好和兴趣并且实时地为用户提供个性化的建议。推荐系统在电子商务、广告、音乐、电影等多种领域都有广泛应用。联邦推荐系统是一种将联邦学习与推荐系统结合的分布式机器学习框架,其在保护用户推荐数据隐私的同时也能够准确的为用户提供个性化推荐。随着终端小型化和智能的发展,端上推荐系统近年来收到越来越多地关注,但由于终端(如智能手机)受到计算资源的限制,终端中的多款应用如淘宝、豆瓣、抖音等对于推荐模型的大小有着极强的限制,更希望利用较小的模型为用户提供更准确的推荐。
在越来越重视数据隐私保护的当下,联邦学习作为一种新兴的分布式机器学习范式,这项技术不仅减少了数据隐私泄露的风险还降低了数据传输的开销,并打破了数据孤岛的限制。但如何实现更准确、轻量化的终端推荐模型联邦训练方法,并允许参与联邦推荐的用户在资源受限的终端上能够准确的获得推荐建议至关重要。
发明内容
本公开旨在至少解决现有技术中存在的技术问题之一。
为此,本公开第一方面提供的一种终端推荐模型的联邦训练方法,能够在保护用户隐私的同时提供泛化能力更强的轻量化推荐模型,减少了联邦推荐中因数据稀疏带来准确率较低的问题。所述联邦训练方法包括:
步骤S1、在云端构建教师模型和学生模型,利用云端的历史数据集对所述教师模型和所述学生模型进行预训练,并将预训练得到的学生模型发送至参与联邦训练的多个终端;
步骤S2、各终端分别利用本地数据集对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
步骤S3、各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
步骤S4、在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督知识蒸馏的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
步骤S5、令训练轮次加1,各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练,返回步骤S3,直至训练轮次达到迭代轮次数上限。
在一些实施例中,所述教师模型和所述学生模型的输入数据为用户在历史时刻点击商品的序列,输出数据为根据所述输入数据以及待排序的商品集合得到的用户在下一时刻点击所述待排序的商品集合中各商品顺序的预测结果。
在一些实施例中,对所述教师模型和所述学生模型进行预训练,以及对所述学生模型进行首轮联邦训练时均采用交叉熵损失函数。
在一些实施例中,所述当前轮全局模型的参数按照以下公式计算得到:
其中,N表示当前轮参与联邦训练的终端的个数,Du表示由终端u所拥有的本地推荐数据构成的本地数据集,表示终端u利用本地数据集训练得到的当前轮本地学生模型的参数,/>表示云端经过参数聚合后得到的当前轮全局模型的参数,τ表示联邦训练的当前轮次。
在一些实施例中,设所述在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督知识蒸馏的方式进行参数增强时采用的损失函数为Lτ,计算公式如下:
其中,
为当前轮τ训练采用的交叉熵损失函数,/>为历史数据集中含有的样本总数,M为历史数据集中每个样本含有的商品类别总数,Yi,m为历史数据集中第i个样本xi的第m个商品类别的真实标签,/>表示当前轮全局模型对历史数据集中第i个样本xi的第m个商品类别的预测概率;
为当前轮τ训练采用的知识蒸馏损失函数,用于训练当前轮全局模型对于不同商品类别的识别能力,表示为当前轮教师模型对第v类商品的输出概率,/>表示当前轮全局模型对第v类商品的输出概率;
为当前轮τ训练采用的自监督知识蒸馏损失函数,用于使全局模型的输出更接近教师模型,λSSKD为SSKD系数,/>分别为将历史数据集Dh中的样本xi分别输入到当前轮教师模型和当前轮全局模型后得到的表征,/>分别为将历史数据集Dh中的样本xj分别输入到当前轮教师模型和当前轮全局模型后得到的表征,j≠i,ψ(·)为引入的中间函数,E(·)为用于计算期望的函数,/>为用于计算/>和/>之间的KL散度的函数,/>为温度系数。
在一些实施例中,步骤S5中,利用上一轮本地学生模型的参数对当前轮本地学生模型的参数基于自监督对比蒸馏的方法进行指导,以提高当前轮本地学生模型的个性化效果。
在一些实施例中,步骤S5中采用的损失函数为:
其中,为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u的经过个性化训练的本地学生模型后得到的表征;/>为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u接收的当前轮增强后的全局模型的参数所对应的模型后得到的表征;/>是上一轮训练时,将本地数据集Du中的样本/>输入到终端u上一轮接收的增强后的全局模型的参数所对应的模型后得到的表征;ψ(·)为引入的中间函数,E(·)为用于计算期望的函数,为用于计算/>和/>之间的KL散度的函数,/>为温度系数;/>为利用本地数据集Du对当前轮τ本地学生模型进行训练时采用的交叉熵损失函数。
本公开第二方面提供的一种终端推荐模型的联邦训练装置,包括:
预训练模块,被配置为在云端构建教师模型和学生模型,利用云端的历史数据集对所述教师模型和所述学生模型进行预训练,并将预训练得到的学生模型发送至参与联邦训练的多个终端;
首轮联邦训练模块,被配置为在各终端分别利用本地数据集对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
聚合模块,被配置为使各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
增强模块,被配置为在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督蒸馏和知识蒸馏的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
本地训练模块,被配置为从第二轮联邦训练开始,使各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练。
在一些实施例中,所述本地训练模块还被配置为利用上一轮本地学生模型的参数对当前轮本地学生模型的参数基于自监督对比蒸馏的方法进行指导,以提高当前轮本地学生模型的个性化效果。
本公开第三方面提供的一种计算机可读存储介质,所述计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行本公开第一方面任一实施例所述的终端推荐模型的联邦训练方法。
本公开实施例提供的终端推荐模型的联邦训练方法,具有以下特点及有益效果:
1、在用户推荐数据较稀疏并且资源受限的前提下,在云端利用教师模型和历史数据,对联邦学习的全局模型进行指导,可以能够在保护隐私的同时为用户提供泛化能力更强的轻量化推荐模型,从而减少了联邦推荐中因数据稀疏带来准确率较低的问题。
2、本公开采用的参数增强方法,在联邦学习的每一次参数聚合后,利用云端的教师模型在历史数据集上辅助全局模型在云端的历史数据集中训练,提高了全局模型的收敛速度和能力上限,从而提高了在本地部署的轻量化推荐模型的能力上限,在很大程度上减少了终端设备在训练达到设定精度时的通讯开销。
3、本公开通过个性化算法,使用户在利用本地数据进行训练时,利用上一轮的训练结果能够帮助本轮训练快速找到梯度下降方向,并加快梯度下降的速度,从而实现更好的个性化推荐能力。
附图说明
图1是本公开第一方面实施例提供的终端推荐模型的联邦训练方法的架构示意图;
图2是本公开第一方面实施例提供的终端推荐模型的联邦训练方法的流程示意图;
图3是根据本公开第一方面的一个具体实施例得到的终端推荐模型与现有联邦训练方法得到的推荐模型的推荐准确率对比图;
图4是根据本公开第一方面的一个具体实施例得到的终端推荐模型与现有联邦训练方法得到的推荐模型在个性化方面的推荐准确率对比图;
图5是根据本公开第一方面的一个具体实施例的联邦训练方法与现有联邦训练方法在通讯开销方面的对比图;
图6是本公开第三方面实施例提供的电子设备的结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细描述。应当理解,此处所描述的具体实施例仅仅用于解释本申请,并不用于限定本申请。
相反,本申请涵盖任何由权利要求定义的在本申请精髓和范围上做的替代、修改、等效方法以及方案。进一步,为了使公众对本申请有更好的了解,在下文对本申请的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本申请。
联邦学习作为一种新兴的分布式机器学习范式,虽然不仅减少了数据隐私泄露的风险还降低了数据传输的开销,并打破了数据孤岛的限制。但将联邦学习和推荐系统结合需要面临以下三方面的问题:
第一方面,在推荐系统领域往往会出现用户点击商品较少,但商品总数很多的场景,在这样的场景下,推荐模型对于商品给出不推荐的准确率高达90%,这种情况会使用户的使用体验较差,该问题在推荐系统领域称为推荐数据的稀疏性。
第二方面,联邦学习对数据隐私的保护,使得用户之间无法交换数据,并且每一个用户的本地数据较为少,利用本地数据训练的推荐模型无法实现为用户进行准确的商品推荐。
第三方面,参与联邦推荐的用户所使用的终端设备,其计算资源和存储资源受到限制。如何在用户资源有限的情况下,部署轻量化的推荐模型,在终端上为用户进行准确的推荐是一个迫切的问题。
因此,本申请旨在从因数据稀疏带来推荐准确率较低的技术优化角度,提出了对联邦推荐模型训练的优化方案。
参见图1、图2,本公开第一方面实施例提供的终端推荐模型的联邦训练方法,包括以下步骤:
步骤S1、在云端构建教师模型和学生模型,利用云端的历史数据集Dh对教师模型和学生模型进行预训练,并将预训练得到的学生模型发至参与联邦训练的多个终端;
步骤S2、各终端分别利用本地数据集Du对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
步骤S3、各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
步骤S4、在云端利用教师模型和历史数据集Dh对当前轮全局模型的参数通过自监督知识蒸馏(Self-Supervised Knowledge Distillation,SSKD)的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
步骤S5、令训练轮次加1,各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练,返回步骤S3,直至训练轮次达到迭代轮次数上限。
图1示出了根据本公开一个实施例提供的一种终端推荐模型的联邦训练方法所涉及的系统的架构示意图,该系统具体包括云端和多个终端。云端可采用云端服务器,其具备优渥的资源,在云端中存储有历史数据集Dh,并设有一个需要较多计算资源的教师模型Mt和一个轻量化适合在终端上部署的学生模型Ms,历史数据集Dh是在隐私保护条例(包括《数据安全法》、《关键信息基础设施安全保护条例》、《个人信息保护法》)发布前,云端收集到的数据组成,在本文中Dh是由前3000个用户的数据组成。每一个终端用于为单独的一个用户提供推荐结果,每一个终端中存储有其用户的本地数据集Du,且每一个终端所持有的数据与其他参与模型训练的终端不进行数据共享,具体地,本地数据集Du由在隐私保护条例发布后的用户本地的推荐数据组成。
进一步地,教师模型Mt和学生模型Ms可以采用任意适合序列推荐的模型,比如Bert模型(Bert4Rec),两者的不同之处在于:教师模型Mt和学生模型Ms中嵌入层的维度的不同,嵌入层的维度越高表示模型能够更好的对物品的嵌入层进行区分,从而实现更好的推荐效果,但嵌入层的维度越高会使模型的大小越大,不利于终端部署。在本申请的一个实施例中,教师模型Mt和学生模型Ms均采用Bert模型,教师模型Mt的嵌入层的维度为512,学生模型Ms的嵌入层的维度为32。
在本申请的一个实施例中,步骤S1中,在云端构建好教师模型Mt和学生模型Ms后,需要利用云端的历史数据集Dh对教师模型和学生模型进行预训练。在预训练阶段,教师模型和学生模型采用相同的训练方式,将预训练阶段的教师模型和学生模型统称为初始模型,设定预训练的迭代轮次数上限为1000轮。在预训练的每一轮迭代训练中,初始模型的输入数据input为用户u在0~T-1时刻点击商品的序列Su={itemid1,itemid2,....,itemidT-1},itemidT-1为用户u在T-1时刻点击的商品,初始模型的输出数据为根据序列Su以及待排序的商品集合得到的用户u在T时刻点击商品集合/>中各商品顺序的预测结果,对预测结果利用交叉熵损失函数计算损失值,随后初始模型根据该损失值进行反向传播,以更新初始模型的网络参数。不断重复上述预训练过程,直至预训练轮次数达到预训练的迭代轮次数上限,得到预训练完毕的教师模型和学生模型。云端将预训练得到的学生模型发送至参与联邦训练的用户的终端并进行部署。
在本申请的一个实施例中,步骤S2中,用户的终端u接收到云端发送的预训练学生模型后,利用用户的本地数据集Du对各自的预训练学生模型进行首轮联邦训练,得到当前轮本地学生模型。在首轮联邦训练中,参与联邦训练的终端利用本地推荐数据对各自的预训练学生模型进行训练,其中采用的损失函数和模型参数更新方法均与预训练过程一致,此处不再赘述。
在本申请的一个实施例中,步骤S3中,当用户的终端u利用本地数据集Du训练好各自的本地学生模型后,终端u将当前轮本地学生模型的参数发送至云端,云端对接收的所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型,设当前轮全局模型的参数为/>按照以下公式计算得到:
其中,N表示当前轮参与联邦学习训练的终端的个数,Du表示由终端u所拥有的本地推荐数据构成的本地数据集,表示终端u利用本地数据集训练得到的当前轮本地学生模型的参数,/>表示云端经过参数聚合后得到的当前轮全局模型的参数,τ表示联邦训练的当前轮次。
在本申请的一个实施例中,考虑到本地推荐数据具有较强的稀疏性,聚合后得到的全局模型无法准确的判断用户的喜好,会对用户推荐不满意的商品。因此,还需要使聚合后的全局模型获得更好的泛化能力。在本公开实施例的步骤S4中,在云端利用之前预训练好的教师模型Mt和历史数据集Dh对当前轮全局模型的参数通过SSKD的方式进行参数增强,得到增强后的当前轮全局模型的参数/>参数增强采用的损失函数考虑了3方面的因素,设参数增强采用的当前轮损失函数为Lτ,计算公式如下:
其中,
为云端利用历史数据集Dh对当前轮τ全局模型进行训练时采用的交叉熵损失函数,/>为云端存储的历史数据集Dh中含有的样本总数,M为历史数据集中每个样本含有的商品类别总数,Yi,m表示历史数据集中第i个样本xi的第m个商品类别的真实标签,/>表示当前轮全局模型对历史数据集中第i个样本xi的第m个商品类别的预测概率。
为利用教师模型对当前轮全局模型进行知识蒸馏所采用的损失函数,记为当前轮τ训练采用的知识蒸馏损失函数,用于训练当前轮全局模型对于不同商品类别的识别能力,具体来说,对于历史数据集Dh中的每一个样本xi,提高全局模型对于不同商品类别的识别能力。损失函数/>中,/>表示为当前轮教师模型对第v类商品的输出概率,/>表示当前轮全局模型对第v类商品的输出概率。这其中,由于全局模型对每一种商品类别没有较强的识别能力,因此利用损失函数/>借助强大的教师模型来指导全局模型,使全局模型对每一类商品的识别能力更接近于教师模型。
为通过SSKD方式对当前轮全局模型的参数进行增强训练时采用的损失函数,记为当前轮τ训练采用的自监督知识蒸馏损失函数,λSSKD为SSKD的系数,通过实验,当取为50时,效果最优。/>分别为当前轮τ训练将历史数据集Dh中的数据xi分别输入到当前轮教师模型/>和全局模型/>后得到的表征,/>分别为将历史数据集Dh中的样本xj分别输入到当前轮教师模型和当前轮全局模型后得到的表征,j≠i;ψ(·)为引入的中间函数,用于计算教师模型和全局模型对同一个输入样本xi的输出的KL散度,D(·)为用于计算期望的函数,/>为用于计算/>和/>之间的KL散度的函数,/>为温度系数,本实施例中,/>取0.5。通过SSKD的方式进行参数增强的目的是使全局模型的输出更接近教师模型,从而使全局模型能够从教师模型中学习更多的知识,通过模仿教师模型获得更好的泛化能力。
待终端接收到云端发送的增强后的当前轮全局模型的参数后,利用该参数替换终端的本地学生模型的参数/>实现对终端中本地学生模型参数的更新。
进一步地,在步骤S4之后,增强后的当前轮全局模型还远远不足以应付推荐系统对于每个用户的个性化能力,并且由于云端上的历史数据集不具有时效性(即本公开中对历史数据集不进行更新),无法实时满足用户的需求。因此为了提高每个用户的个性化能力,本公开第一方面实施例提供的联邦训练方法,在步骤S5中进行的本地训练采用的是基于自监督对比蒸馏(Self-Contrastive Distillation,SCD)的个性化训练方法,具体实现过程如下:
在终端第二次接收到增强后的全局模型的参数后,即τ≥2,对于任意终端u,利用本地数据集Du和增强后的当前轮全局模型的参数/>对本地学生模型进行个性化训练,设对终端u进行本地个性化训练采用的损失函数为/>具体公式如下:
其中,为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u的经过个性化训练的本地学生模型(即当前轮期望得到的本地学生模型)后得到的表征;/>为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u接收的当前轮增强后的全局模型的参数所对应的模型后得到的表征;/>是上一轮训练(τ-1轮)结束后,将本地数据集Du中的样本/>输入到终端u上一轮接收的增强后的全局模型的参数所对应的模型后得到的表征,即/> 的含义参见/>中/>的物理含义,此处不再赘述;为利用本地数据集Du对当前轮τ本地学生模型进行训练时采用的交叉熵损失函数。
到此,本公开使用上述个性化训练方法的目的是终端在利用本地数据进行训练时,利用上一轮的训练结果能够帮助本轮训练快速找到梯度下降方向,并加快梯度下降的速度,从而实现更好的个性化推荐能力。
本公开实施例方法的有效性验证:
为了验证本公开实施例提供的联邦训练方法在推荐准确性、个性化能力和通讯开销方面的性能,将本公开实施例提供的联邦训练方法(记为Fosses)与现有的联邦推荐系统学习方法——FedFast(KDD20)、DeepRec(WWW21)和集中式训练方法(Centralized)在推荐系统的公开数据集Yelp上进行对比,对比结果参见图3~图5。
图3表示,每次联邦训练选择50个客户端,一共进行300轮联邦训练后,在Yelp数据集测试的全局结果,本分析例选择了DeepRec、FedFast和集中式训练300轮本地学生模型进行比较,以NDCG@5作为评价指标,本公开实施例的联邦训练方法在NDCG@5上比FedFast方法提高了40%,与DeepRec相比,提高了8.7%,与集中式相比提高了0.46%。
图4表示,每次联邦训练选择1000个客户端参与,一共进行100轮联邦训练时,每10轮测试一次参与联邦训练的用户在本地测试任务上的准确率,并取平均值,以此反应各训练方法的个性化能力。如图4所示,以NDCG@10作为评价指标,本公开实施例提供的联邦训练方法明显高于DeepRec和FedFast方法。
图5表示,当模型测试精度NDCG@10达到0.53时,所需要的通讯开销(即推荐模型收敛所需要的训练轮次),本公开实施例的联邦训练方法相比DeepRec和FedFast,可以减少用户的通讯开销,提供更快的收敛速度。
本公开第二方面实施例提供的终端推荐模型的联邦训练装置,包括:
预训练模块,被配置为在云端构建教师模型和学生模型,利用云端的历史数据集对所述教师模型和所述学生模型进行预训练,并将预训练得到的学生模型发送至参与联邦训练的多个终端;
首轮联邦训练模块,被配置为在各终端分别利用本地数据集对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
聚合模块,被配置为使各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
增强模块,被配置为在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督知识蒸馏的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
本地训练模块,被配置为从第二轮联邦训练开始,使各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练。
在一些实施例中,本地训练模块还被配置为利用上一轮本地学生模型的参数对当前轮本地学生模型的参数基于自监督对比蒸馏的方法进行指导,以提高当前轮本地学生模型的个性化效果。
需要说明的是,前述对一种终端推荐模型的联邦训练方法的实施例解释说明也适用于本实施例的一种终端推荐模型的联邦训练装置,在此不再赘述。
为了实现上述实施例,本公开实施例还提出一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行,用于执行上述实施例的终端推荐模型的联邦训练方法。
下面参考图6,其示出了适于用来实现本公开实施例的电子设备的结构示意图。其中,需要说明的是,本公开实施例中的电子设备可以包括但不限于诸如移动电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、车载终端(例如车载导航终端)等等的移动终端以及诸如数字TV、台式计算机、服务器等等的固定终端。图6示出的电子设备仅仅是一个示例,不应对本公开实施例的功能和使用范围带来任何限制。
如图6所示,电子设备可以包括处理装置(例如中央处理器、图形处理器等)101,其可以根据存储在只读存储器(ROM)102中的程序或者从存储装置108加载到随机访问存储器(RAM)103中的程序而执行各种适当的动作和处理。在RAM 103中,还存储有电子设备操作所需的各种程序和数据。处理装置101、ROM 102以及RAM 103通过总线104彼此相连。输入/输出(I/O)接口105也连接至总线104。
通常,以下装置可以连接至I/O接口105:包括例如触摸屏、触摸板、键盘、鼠标、摄像头、麦克风等的输入装置106;包括例如液晶显示器(LCD)、扬声器、振动器等的输出装置107;包括例如磁带、硬盘等的存储装置108;以及通信装置109。通信装置109可以允许电子设备与其他设备进行无线或有线通信以交换数据。虽然图6示出了具有各种装置的电子设备,但是应理解的是,并不要求实施或具备所有示出的装置。可以替代地实施或具备更多或更少的装置。
特别地,根据本公开的实施例,上文参考流程图描述的过程可以被实现为计算机软件程序。例如,本实施例包括一种计算机程序产品,其包括承载在计算机可读介质上的计算机程序,该计算机程序包含用于执行流程图中所示方法的程序代码。在这样的实施例中,该计算机程序可以通过通信装置109从网络上被下载和安装,或者从存储装置108被安装,或者从ROM 102被安装。在该计算机程序被处理装置101执行时,执行本公开实施例的方法中限定的上述功能。
需要说明的是,本公开上述的计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质或者是上述两者的任意组合。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子可以包括但不限于:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本公开中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。而在本公开中,计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读信号介质还可以是计算机可读存储介质以外的任何计算机可读介质,该计算机可读信号介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于:电线、光缆、RF(射频)等等,或者上述的任意合适的组合。
上述计算机可读介质可以是上述电子设备中所包含的;也可以是单独存在,而未装配入该电子设备中。
上述计算机可读介质承载有一个或者多个程序,当上述一个或者多个程序被该电子设备执行时,使得该电子设备执行上述终端推荐模型的联邦训练方法。
可以以一种或多种程序设计语言或其组合来编写用于执行本公开的操作的计算机程序代码,上述程序设计语言包括面向对象的程序设计语言—诸如Java、Smalltalk、C++、python,还包括常规的过程式程序设计语言—诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络——包括局域网(LAN)或广域网(WAN)—连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本申请的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本申请的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更多个用于实现特定逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分,并且本申请的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能,这应被本申请的实施例所属技术领域的技术人员所理解。
在流程图中表示或在此以其他方式描述的逻辑和/或步骤,例如,可以被认为是用于实现逻辑功能的可执行指令的定序列表,可以具体实现在任何计算机可读介质中,以供指令执行系统、装置或设备(如基于计算机的系统、包括处理器的系统或其他可以从指令执行系统、装置或设备取指令并执行指令的系统)使用,或结合这些指令执行系统、装置或设备而使用。就本说明书而言,“计算机可读介质”可以是任何可以包含、存储、通信、传播或传输程序以供指令执行系统、装置或设备或结合这些指令执行系统、装置或设备而使用的装置。计算机可读介质的更具体的示例(非穷尽性列表)包括以下:具有一个或多个布线的电连接部(电子装置),便携式计算机盘盒(磁装置),随机存取存储器(RAM),只读存储器(ROM),可擦除可编辑只读存储器(EPROM或闪速存储器),光纤装置,以及便携式光盘只读存储器(CDROM)。另外,计算机可读介质甚至可以是可在其上打印程序的纸或其他合适的介质,因为可以例如通过对纸或其他介质进行光学扫描,接着进行编辑、解译或必要时以其他合适方式进行处理来以电子方式获得程序,然后将其存储在计算机存储器中。
应当理解,本申请的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,多个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。例如,如果用硬件来实现,和在另一实施方式中一样,可用本领域公知的下列技术中的任一项或他们的组合来实现:具有用于对数据信号实现逻辑功能的逻辑门电路的离散逻辑电路,具有合适的组合逻辑门电路的专用集成电路,可编程门阵列(PGA),现场可编程门阵列(FPGA)等。
本技术领域的普通技术人员可以理解实现上述实施例方法携带的全部或部分步骤,可以通过程序来指令相关的硬件完成,所开发的程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。
此外,在本申请各个实施例中的各功能单元可以集成在一个处理模块中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。
上述提到的存储介质可以是只读存储器,磁盘或光盘等。尽管上面已经示出和描述了本申请的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本申请的限制,本领域的普通技术人员在本申请的范围内可以对上述实施例进行变化、修改、替换和变型。
Claims (10)
1.一种终端推荐模型的联邦训练方法,其特征在于,包括:
步骤S1、在云端构建教师模型和学生模型,利用云端的历史数据集对所述教师模型和所述学生模型进行预训练,并将预训练得到的学生模型发送至参与联邦训练的多个终端;
步骤S2、各终端分别利用本地数据集对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
步骤S3、各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
步骤S4、在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督知识蒸馏的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
步骤S5、令训练轮次加1,各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练,返回步骤S3,直至训练轮次达到迭代轮次数上限。
2.根据权利要求1所述的联邦训练方法,其特征在于,所述教师模型和所述学生模型的输入数据为用户在历史时刻点击商品的序列,输出数据为根据所述输入数据以及待排序的商品集合得到的用户在下一时刻点击所述待排序的商品集合中各商品顺序的预测结果。
3.根据权利要求1所述的联邦训练方法,其特征在于,对所述教师模型和所述学生模型进行预训练,以及对所述学生模型进行首轮联邦训练时均采用交叉熵损失函数。
4.根据权利要求1所述的联邦训练方法,其特征在于,所述当前轮全局模型的参数按照以下公式计算得到:
其中,N表示当前轮参与联邦训练的终端的个数,Du表示由终端u所拥有的本地推荐数据构成的本地数据集,表示终端u利用本地数据集训练得到的当前轮本地学生模型的参数,表示云端经过参数聚合后得到的当前轮全局模型的参数,τ表示联邦训练的当前轮次。
5.根据权利要求1所述的联邦训练方法,其特征在于,设所述在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督知识蒸馏的方式进行参数增强时采用的损失函数为Lτ,计算公式如下:
其中,
为当前轮τ训练采用的交叉熵损失函数,/>为历史数据集中含有的样本总数,M为历史数据集中每个样本含有的商品类别总数,Yi,m为历史数据集中第i个样本xi的第m个商品类别的真实标签,/>表示当前轮全局模型对历史数据集中第i个样本xi的第m个商品类别的预测概率;
为当前轮τ训练采用的知识蒸馏损失函数,用于训练当前轮全局模型对于不同商品类别的识别能力,/>表示为当前轮教师模型对第v类商品的输出概率,/>表示当前轮全局模型对第v类商品的输出概率;
为当前轮τ训练采用的自监督知识蒸馏损失函数,用于使全局模型的输出更接近教师模型,λSSKD为SSKD系数,/>分别为将历史数据集Dh中的样本xi分别输入到当前轮教师模型和当前轮全局模型后得到的表征,/>分别为将历史数据集Dh中的样本xj分别输入到当前轮教师模型和当前轮全局模型后得到的表征,j≠i,ψ(·)为引入的中间函数,E(·)为用于计算期望的函数,/>为用于计算/>和/>之间的KL散度的函数,/>为温度系数。
6.根据权利要求1~5中任一项所述的联邦训练方法,其特征在于,步骤S5中,利用上一轮本地学生模型的参数对当前轮本地学生模型的参数基于自监督对比蒸馏的方法进行指导,以提高当前轮本地学生模型的个性化效果。
7.根据权利要求6所述的联邦训练方法,其特征在于,步骤S5中采用的损失函数为:
其中,为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u的经过个性化训练的本地学生模型后得到的表征;/>为在第τ轮训练时,将本地数据集Du中的样本/>输入到终端u接收的当前轮增强后的全局模型的参数所对应的模型后得到的表征;/>是上一轮训练时,将本地数据集Du中的样本/>输入到终端u上一轮接收的增强后的全局模型的参数所对应的模型后得到的表征;ψ(·)为引入的中间函数,E(·)为用于计算期望的函数,为用于计算/>和/>之间的KL散度的函数,/>为温度系数;/>为利用本地数据集Du对当前轮τ本地学生模型进行训练时采用的交叉熵损失函数。
8.一种终端推荐模型的联邦训练装置,其特征在于,包括:
预训练模块,被配置为在云端构建教师模型和学生模型,利用云端的历史数据集对所述教师模型和所述学生模型进行预训练,并将预训练得到的学生模型发送至参与联邦训练的多个终端;
首轮联邦训练模块,被配置为在各终端分别利用本地数据集对预训练完毕的学生模型进行首轮联邦训练,得到首轮本地学生模型;
聚合模块,被配置为使各终端将当前轮本地学生模型的参数发送至云端,云端对所有当前轮本地学生模型的参数进行聚合,得到当前轮全局模型及其参数;
增强模块,被配置为在云端利用教师模型和历史数据集对当前轮全局模型的参数通过自监督蒸馏和知识蒸馏的方式进行参数增强,得到增强后的当前轮全局模型的参数,将增强后的当前轮全局模型的参数发送至各终端,以对各终端的本地学生模型的参数进行更新;
本地训练模块,被配置为从第二轮联邦训练开始,使各终端利用各自的本地数据集对当前轮本地学生模型进行联邦训练。
9.根据权利要求8所述的联邦训练装置,其特征在于,所述本地训练模块还被配置为利用上一轮本地学生模型的参数对当前轮本地学生模型的参数基于自监督对比蒸馏的方法进行指导,以提高当前轮本地学生模型的个性化效果。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行权利要求1~7中任一项所述的终端推荐模型的联邦训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311021672.6A CN117150122A (zh) | 2023-08-15 | 2023-08-15 | 终端推荐模型的联邦训练方法、装置和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311021672.6A CN117150122A (zh) | 2023-08-15 | 2023-08-15 | 终端推荐模型的联邦训练方法、装置和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117150122A true CN117150122A (zh) | 2023-12-01 |
Family
ID=88899781
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311021672.6A Pending CN117150122A (zh) | 2023-08-15 | 2023-08-15 | 终端推荐模型的联邦训练方法、装置和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117150122A (zh) |
-
2023
- 2023-08-15 CN CN202311021672.6A patent/CN117150122A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109104620B (zh) | 一种短视频推荐方法、装置和可读介质 | |
CN110321958B (zh) | 神经网络模型的训练方法、视频相似度确定方法 | |
US20240127795A1 (en) | Model training method, speech recognition method, device, medium, and apparatus | |
CN112149699B (zh) | 用于生成模型的方法、装置和用于识别图像的方法、装置 | |
CN111104599B (zh) | 用于输出信息的方法和装置 | |
CN112650841A (zh) | 信息处理方法、装置和电子设备 | |
CN112766284B (zh) | 图像识别方法和装置、存储介质和电子设备 | |
CN112836128A (zh) | 信息推荐方法、装置、设备和存储介质 | |
CN115908640A (zh) | 生成图像的方法、装置、可读介质及电子设备 | |
CN111291715B (zh) | 基于多尺度卷积神经网络的车型识别方法、电子设备及存储介质 | |
CN114417174B (zh) | 内容推荐方法、装置、设备及计算机存储介质 | |
CN116128055A (zh) | 图谱构建方法、装置、电子设备和计算机可读介质 | |
CN113033707B (zh) | 视频分类方法、装置、可读介质及电子设备 | |
CN109903075B (zh) | 基于dnn的回归分布模型及其训练方法、电子设备 | |
EP4187472A1 (en) | Method and apparatus for detecting false transaction orders | |
CN112241761B (zh) | 模型训练方法、装置和电子设备 | |
CN113140012A (zh) | 图像处理方法、装置、介质及电子设备 | |
WO2023174075A1 (zh) | 内容检测模型的训练方法、内容检测方法及装置 | |
CN116258911A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
CN115482415A (zh) | 模型训练方法、图像分类方法和装置 | |
CN117150122A (zh) | 终端推荐模型的联邦训练方法、装置和存储介质 | |
CN112860999B (zh) | 信息推荐方法、装置、设备和存储介质 | |
CN112417260B (zh) | 本地化推荐方法、装置及存储介质 | |
CN111581455A (zh) | 文本生成模型的生成方法、装置和电子设备 | |
CN112347278A (zh) | 用于训练表征模型的方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |