CN116645130A - 基于联邦学习与gru结合的汽车订单需求量预测方法 - Google Patents
基于联邦学习与gru结合的汽车订单需求量预测方法 Download PDFInfo
- Publication number
- CN116645130A CN116645130A CN202310421463.4A CN202310421463A CN116645130A CN 116645130 A CN116645130 A CN 116645130A CN 202310421463 A CN202310421463 A CN 202310421463A CN 116645130 A CN116645130 A CN 116645130A
- Authority
- CN
- China
- Prior art keywords
- model
- client
- training
- gru
- local
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 76
- 238000012549 training Methods 0.000 claims abstract description 106
- 238000004891 communication Methods 0.000 claims abstract description 28
- 238000012545 processing Methods 0.000 claims abstract description 22
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 20
- 230000005012 migration Effects 0.000 claims abstract description 9
- 238000013508 migration Methods 0.000 claims abstract description 9
- 238000012795 verification Methods 0.000 claims abstract description 6
- 230000006870 function Effects 0.000 claims description 51
- 239000013598 vector Substances 0.000 claims description 30
- 238000004821 distillation Methods 0.000 claims description 18
- 230000008569 process Effects 0.000 claims description 13
- 238000009826 distribution Methods 0.000 claims description 12
- 238000010606 normalization Methods 0.000 claims description 9
- 230000004913 activation Effects 0.000 claims description 6
- 230000002776 aggregation Effects 0.000 claims description 6
- 238000004220 aggregation Methods 0.000 claims description 6
- 238000011478 gradient descent method Methods 0.000 claims description 6
- 239000011159 matrix material Substances 0.000 claims description 6
- 238000007781 pre-processing Methods 0.000 claims description 5
- 230000007246 mechanism Effects 0.000 claims description 4
- 239000000203 mixture Substances 0.000 claims description 3
- 238000005457 optimization Methods 0.000 claims description 3
- 238000005096 rolling process Methods 0.000 claims description 3
- 238000012360 testing method Methods 0.000 claims description 3
- 230000001502 supplementing effect Effects 0.000 claims description 2
- 230000008901 benefit Effects 0.000 abstract description 4
- 238000010801 machine learning Methods 0.000 description 9
- 230000005540 biological transmission Effects 0.000 description 2
- 238000013136 deep learning model Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 239000006185 dispersion Substances 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000005530 etching Methods 0.000 description 1
- 238000004880 explosion Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/02—Marketing; Price estimation or determination; Fundraising
- G06Q30/0201—Market modelling; Market analysis; Collecting market data
- G06Q30/0202—Market predictions or forecasting for commercial activities
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
- G06N3/0442—Recurrent networks, e.g. Hopfield networks characterised by memory or gating, e.g. long short-term memory [LSTM] or gated recurrent units [GRU]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Business, Economics & Management (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Strategic Management (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Development Economics (AREA)
- Accounting & Taxation (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Finance (AREA)
- Entrepreneurship & Innovation (AREA)
- General Business, Economics & Management (AREA)
- Marketing (AREA)
- Economics (AREA)
- Game Theory and Decision Science (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于联邦学习与GRU结合的汽车订单需求量预测方法,包括获取历史汽车订单数据并处理得到训练数据集和验证数据集;采用构建的训练数据集训练GRU模型构建GRU预测模型;针对GRU预测模型进行处理构建最终的GRU预测模型;采用基于联邦学习的方法针对最终的GRU预测模型进行训练处理,构建若干客户端本地模型;采用知识蒸馏的方法针对构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型;采用构建的联邦学习全局预测模型,完成对汽车订单需求量的预测处理;本发明充分考虑预测信息,有效保障数据的隐私安全;而且本发明的预测精度高、通信成本低、模型稳定性强。
Description
技术领域
本发明属于联邦机器学习与需求预测技术领域,具体涉及一种基于联邦学习与GRU结合的汽车订单需求量预测方法。
背景技术
在目前的机器学习中,针对汽车订单需求量进行预测的方法基本要使用大量的订单数据,通过集中式机器学习来训练模型,但集中式机器学习也会受限于单服务器的计算资源有限,大规模的数据会导致训练速度不佳等问题,严重阻碍机器学习的发展。虽然提出了分布式机器学习,但中心服务器需要收集使用多个下游零售商的订单数据进行模型训练,数据在传输过程中完全裸露,这无疑会导致零售商的真实数据泄露,致使数据隐私与数据安全无法得到保障,同时也面临着零售商可能不愿意提供数据的难题。
为了保护终端设备用户或零售商的敏感数据,同时又能满足于集中式机器学习的模型精度要求,降低中心服务器的资源负载,联邦学习应运而生。联邦学习作为一种分布式机器学习技术,各客户端无需上传其敏感数据,而只需共享其本地模型参数更新,中心服务器通过与终端设备不断通信传输信息,并且传输过程是经过加密的,中心服务器会聚合来自各客户端的本地模型,聚合后的模型成为下一轮迭代的全局模型,重复这个过程,直到中心服务器全局模型收敛,因此联邦学习能够使多个参与方在保护数据隐私、满足合法合规的要求下进行机器学习,并得到最终的可用模型。
但在联邦学习中,终端设备与中心服务器需要多轮通信交互才能够获得目标精度的全局模型对于复杂的模型训练,如:深度学习模型的训练,每次模型更新可能包含数百万个参数,每一次的更新模型将耗费大量的通信成本,甚至成为模型训练瓶颈;此外,由于终端设备的异构性,每个设备网络状态的不可靠性以及互联网连接速度的不对称性,如:下载速度大于上传速度,也将导致终端设备上传更新参数延迟,致使模型训练瓶颈进一步恶化。
综上所述,在当前的预测方法中,大多数方法的模型建立存在一定的问题;且用户数据的保密性不够,预测的准确度仍有待提高。
发明内容
本发明的目的在于提供一种预测精度高、通信成本低、数据安全性强的基于联邦学习与GRU结合的汽车订单需求量m预测方法。
本发明提供的这种基于联邦学习与GRU结合的汽车订单需求量预测方法,包括如下步骤:
S1.获取历史汽车订单数据,针对得到的数据进行预处理,并构建训练数据集和验证数据集;
S2.采用步骤S1构建的训练数据集,针对GRU模型进行训练处理,构建GRU预测模型;
S3.采用动态规划的方法针对步骤S2构建的GRU预测模型进行处理,构建最终的DP-GRU预测模型;
S4.采用基于联邦学习的方法针对步骤S3构建的DP-GRU预测模型进行训练处理,构建若干个客户端本地模型;
S5.采用知识蒸馏的方法针对步骤S4构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型;
S6.采用步骤S5构建的联邦学习全局预测模型,完成对汽车订单需求量的预测处理。
步骤S1所述的获取历史汽车订单数据,针对得到的数据进行预处理,并构建训练数据集和验证数据集,具体包括:
(1)获取历史汽车订单数据,针对其中缺失的数据进行补充;
(2)将步骤(1)得到的历史数据转换为时间序列数据,X=[x1,x2,…,xt-1,xt],同时设定时间步timestep,根据前timestep的数据预测后面的数据;
(3)将步骤(2)得到的时间序列数据按照设定尺寸的时间窗格式,转化成一个二维的矩阵,针对二维数据矩阵利用softmax函数进行归一化处理,并设定映射后数据值域的上界值和下界值。
步骤S2所述的采用步骤S1构建的训练数据集,针对GRU模型进行训练处理,构建GRU预测模型,具体包括:
采用下述公式描述GRU预测模型:
其中,xt是时刻t的输入变量,rt为重置门在时刻t的输出变量或权重,zt为更新门在时刻t的输出变量或权重,ht为当前时刻t的状态记忆变量,ht-1为上一时刻t-1的状态记忆变量,为候选集,/>Wr、Wz分别为候选集、更新门和重置门的权重系数,σ为sigmoid激活函数,[·]为两个向量间的连接,*为矩阵间的乘积;
采用Dropout方法针对GRU模型训练过程中存在的过拟合问题进行处理;
在模型的输出端采用全连接层,并选择ReLU作为激活函数,得到模型的预测值,采用下述公式进行表示:
其中,为GRU模型在t时刻的一个输出值,Wo为输出层的权重,bo为偏置项。
步骤S3所述的采用动态规划的方法针对步骤S2构建的GRU预测模型进行处理,构建最终的DP-GRU预测模型,具体包括:
将步骤S1中设定的timestep的序列任务转换为若干个长度为p的预测任务,p为输入值xi的个数;在构建预测t+1时刻的模型输入时,将t-1时刻的输入向量Xt-1和t时刻的预测值Yt同时引入,使得模型在预测t+1时刻的信息时能够考虑之前的预测信息;
设置模型输入序列X,长度为p;输出序列Y,长度为q;l为滚动步长,且0≤l≤p,0≤p-l≤q,得到模型输入向量序列的更新公式,表示如下所示:
Xt=Xt-1[x′1,x′2,…,x′l]+Yt[y1,y2,…,yp-l]
其中,Xt由Xt-1和Yt组成,Xt-1为DP-GRU模型预测t时刻的预测值Yt的输入序列,x′1,x′2,…,x′l为上一时刻逆序选择的l个输入值,y1,y2,…,yp-l为顺序选择的p-l个预测值,每个时刻都可以滚动选择l个输入值与p-l个预测值作为下一时刻的输入序列;采用下述公式表示最终的t+1时刻的预测结果:
Yt+1=fGRU(Xt)
其中,fGRU(·)为设置的GRU模型参数更新公式,Yt+1为t+1时刻的预测结果。
步骤S4所述的采用基于联邦学习的方法针对步骤S3构建的最终的GRU预测模型进行训练处理,构建若干个客户端本地模型,具体包括:
1)通过云中心服务器初始化全局模型WG,并初始化全局变量,包括:定义所有边缘客户端的数目为N,给定的客户端数量为M个,nk为客户端k拥有的数据量,n=n1+…nM为M个客户端的总数据量,向量表示客户端被选择参与训练的离散概率分布,表示客户端k在每一轮通信中被选择的概率,注意力向量分数为边缘客户端选择概率分布为P=[p1,p2,…,pM],每个客户端的注意力向量分数对应于客户端被选中参与训练的概率,并在第一轮通信中初始化
2)采取动态规划客户端方案,在联邦学习训练初始时,选择设定数量的边缘客户端参与到本地训练中,随着通信轮次不断增加,不断增加参与到本地训练中的边缘客户端数量,直到最终模型收敛或所有客户端都加入到训练中为止;
3)在第t轮通信时,云中心服务器根据客户端选择概率分布P=[p1,p2,…,pk]从客户端子集St中随机选择K个客户端加入到本地训练中,Wi (t)表示在第t轮通信时第i个客户端返回的本地模型训练结果;
4)各个边缘客户端首先获取当前通信轮次的全局模型然后采取随机梯度下降方法,使用本地训练集数据对GRU预测模型进行本地训练;
当本地训练结束后,将训练得到的网络输出值调用scaler.inverse_transform()函数进行反归一化处理,需要注意的是输出的预测值的shape要和归一化前的数据shape一致,再计算各个客户端的真实值与预测值之间的误差,判断误差是否满足设定的预测精度要求,若满足要求则对测试集进行预测;若不满足要求则各个本地客户端会将自己的本地模型权重参数Wi (t)及其他参数信息经加密后返回给云中心服务器;
5)云中心服务器收到各个客户端发送的本地模型参数信息后,基于注意力机制,云中心服务器会首先使用欧氏距离来度量每个参与训练的边缘客户端本地模型与全局模型之间的差异,采用下述公式进行表示:
其中,为第t轮训练后客户端i的欧氏距离,也就是客户端i的本地模型与全局模型之间的差异,/>为第t+1轮时的全局模型,Wi (t)为第t轮时客户端i上传给中心服务器的本地模型,||·||表示计算欧氏距离;
获取差异后,对于参与训练的每个客户端,更新其注意力向量分数,采用下述公式描述更新公式:
其中,α为注意力分数衰减率,且α∈[0,1],为客户端i在第t轮时的注意力向量分数,/>为客户端i在第t+1轮时的注意力向量分数,/>为当前轮次参与训练的客户端集合St中客户端k在第t轮时的注意力向量分数,/>为客户端k在第t轮时的欧氏距离,/>为客户端i在第t轮时的欧氏距离;
对于每一个没有被选择参与到训练中的边缘客户端j,令同时更新客户端选择概率分布/>针对欧氏距离不满足设定数值的客户端,或本地模型性能不满足设定要求的客户端而言,在下一轮通信将提高被选中参与到本地训练中的概率,从而减少全局模型在所有边缘客户端中的性能差异性;
6)云中心服务器对所有得到的客户端模型进行加权聚合,采用下述公式表示模型聚合:
得到加权平均处理后的模型参数W(t+1)。
步骤S5所述的采用知识蒸馏的方法针对步骤S4构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型,具体包括:
假设客户端为N={1,2,…,n},每个客户端i只能访问本地的隐私数据集Di,采用下述公式描述损失函数:
式中Li(WG,Di)为客户端i的本地损失函数,k表示客户端i的本地数据Di的样本数量,ωi为模型训练参数,xj为输入数据,yj为实际输出数据,f(xj)为DP-GRU模型的预测输出数据;利用SGD算法优化损失函数,通过梯度下降法改变参数ωi从而最小化本地损失函数,加速收敛;
其中L(WG)为全局模型的损失函数,N是客户端集合,| |为获取总的客户端数目,为获取一个全局模型WG,使得损失函数最小;
在损失函数公式的基础上,重新定义客户端i采取知识蒸馏后的本地损失函数Lper,i(Wi),采用下述公式进行描述:
其中,Li(WG,Di)为未进行知识蒸馏之前的客户端i的本地损失函数;s为student,表示联邦学习全局模型;t为teachers,表示集成后的本地模型;Dp表示公共数据集,每个参与训练的客户端均能够访问;σ(·)为softmax函数,LKL(·)表示Kullback-Leiblerdivergence函数,λ∈(0,1)为加权系数,用于控制student学习teachers的程度,T为蒸馏温度;
知识蒸馏方法通过基于梯度下降的优化方式,训练联邦学习全局模型与集成后的本地模型,使得联邦学习全局模型与集成后的本地模型具有相似的泛化能力,进行J轮蒸馏,在蒸馏过程中,各个本地模型通过蒸馏样本数据集n得到各自模型的logit输出f(Wi (t),n),并用于训练云中心服务器上的联邦学习全局模型,知识蒸馏过程中的模型参数更新采用下述公式表示:
其中,W(t,j)表示第t轮训练中第j次蒸馏的全局模型,j表示第j次蒸馏,η表示学习率,L表示客户端i的本地蒸馏损失函数,f(·)表示本地模型的logit输出的求解函数,logit输出也就是该模型的最后一个全连接层的输出,将各个本地模型的logit输出的平均值作为整体迁移的知识;
经过J轮知识蒸馏,令:
为新的全局模型,模型训练进入下一轮的迭代;重复上述步骤,直至全局模型收敛,得到最终的联邦学习全局预测模型/>同时联邦学习训练结束。
本发明提供的这种基于联邦学习与GRU结合的汽车订单需求量预测方法,提出基于动态规划思想改进的GRU网络预测模型,使得模型充分考虑预测信息;同时将GRU模型与联邦学习相结合,有效保障数据的隐私安全,针对各个汽车零售商的汽车订单需求量进行预测,解决“梯度弥散”和“梯度爆炸”问题;引入基于欧氏距离的注意力机制,提高所有边缘客户端设备对于全局模型的整体收益;而且本发明的预测精度高、通信成本低、模型稳定性强。
附图说明
图1为本发明方法的方法流程示意图。
具体实施方式
如图1所示为本发明方法的方法流程示意图:本发明提供的这种基于联邦学习与GRU结合的汽车订单需求量预测方法,包括如下步骤:
S1.获取历史汽车订单数据,针对得到的数据进行预处理,并构建训练数据集和验证数据集;具体包括:
(1)获取历史汽车订单数据,针对其中缺失的数据,例如某款汽车缺少某个日期的订单量,使用加权平均的方法对其进行补充;
(2)将步骤(1)得到的历史数据转换为时间序列数据,X=[x1,x2,…,xt-1,xt],同时设定时间步timestep,根据前timestep的数据预测后面的数据;
(3)将步骤(2)得到的时间序列数据按照设定尺寸的时间窗格式,转化成一个二维的矩阵,针对二维数据矩阵利用softmax函数进行归一化处理,并设定映射后数据值域的上界值和下界值;本发明中将上界值设定为1,下界值设定为-1;
S2.采用步骤S1构建的训练数据集,针对GRU模型进行训练处理,构建GRU预测模型;具体包括:
采用下述公式描述GRU预测模型:
其中,xt是时刻t的输入变量,rt为重置门在时刻t的输出变量或权重,zt为更新门在时刻t的输出变量或权重,ht为当前时刻t的状态记忆变量,ht-1为上一时刻t-1的状态记忆变量,为候选集,/>Wr、Wz分别为候选集、更新门和重置门的权重系数,σ为sigmoid激活函数,[·]为两个向量间的连接,*为矩阵间的乘积;
采用Dropout方法针对GRU模型训练过程中存在的过拟合问题进行处理;本发明中将Dropout的抛弃阈值设置为0.2;
在模型的输出端采用全连接层,并选择ReLU作为激活函数,得到模型的预测值,采用下述公式进行表示:
其中,为GRU模型在t时刻的一个输出值,Wo为输出层的权重,bo为偏置项;
S3.采用动态规划的方法针对步骤S2构建的GRU预测模型进行处理,构建最终的DP-GRU预测模型;具体包括:
将步骤S1中设定的timestep的序列任务转换为若干个长度为p的预测任务,p为输入值xi的个数;在构建预测t+1时刻的模型输入时,将t-1时刻的输入向量Xt-1和t时刻的预测值Yt同时引入,使得模型在预测t+1时刻的信息时能够考虑之前的预测信息;
设置模型输入序列X,长度为p;输出序列Y,长度为q;l为滚动步长,且0≤l≤p,0≤p-l≤q,得到模型输入向量序列的更新公式,表示如下所示:
Xt=Xt-1[x′1,x′2,…,x′l]+Yt[y1,y2,…,yp-l]
其中,Xt由Xt-1和Yt组成,Xt-1为DP-GRU模型预测t时刻的预测值Yt的输入序列,x′1,x′2,…,x′l为上一时刻逆序选择的l个输入值,y1,y2,…,yp-l为顺序选择的p-l个预测值,每个时刻都可以滚动选择l个输入值与p-l个预测值作为下一时刻的输入序列;采用下述公式表示最终的t+1时刻的预测结果:
Yt+1=fGRU(Xt)
其中,fGRU(·)为设置的GRU模型参数更新公式,Yt+1为t+1时刻的预测结果;
S4.采用基于联邦学习的方法针对步骤S3构建的最终的GRU预测模型进行训练处理,构建若干个客户端本地模型;具体包括:
1)通过云中心服务器初始化全局模型WG,并初始化全局变量,包括:定义所有边缘客户端的数目为N,给定的客户端数量为M个,nk为客户端k拥有的数据量,n=n1+…nM为M个客户端的总数据量,向量表示客户端被选择参与训练的离散概率分布,表示客户端k在每一轮通信中被选择的概率,注意力向量分数为边缘客户端选择概率分布为P=[p1,p2,…,pM],每个客户端的注意力向量分数分别对应于客户端被选中参与训练的概率,并在第一轮通信中初始化
2)采取动态规划客户端方案,在联邦学习训练初始时,选择设定数量的边缘客户端参与到本地训练中,随着通信轮次不断增加,逐步增加参与到本地训练中的边缘客户端数量,直到最终模型收敛或所有客户端都加入到训练中为止;在本发明的训练最开始时,只随机选择10%的边缘客户端参与到训练中,每经过固定ΔT通信轮次后,多选择10%的边缘客户端加入到训练中,直到最终模型收敛或所有客户端都加入到训练中为止;
3)在第t轮通信时,云中心服务器根据客户端选择概率分布P=[p1,p2,…,pk]从客户端子集St中随机选择K个客户端加入到本地训练中,Wi (t)表示在第t轮通信时第i个客户端返回的本地模型训练结果;
4)各个边缘客户端首先获取当前通信轮次的全局模型然后采取随机梯度下降方法,使用本地训练集数据对GRU预测模型进行本地训练;
当本地训练结束后,将训练得到的网络输出值调用scaler.inverse_transform()函数进行反归一化处理,需要注意的是输出的预测值的shape要和归一化前的数据shape一致,再计算各个客户端的真实值与预测值之间的误差,判断误差是否满足设定的预测精度要求,若满足要求则对测试集进行预测;若不满足要求则各个本地客户端会将自己的本地模型权重参数Wi (t)及其他参数信息经加密后返回给云中心服务器;
5)云中心服务器收到各个客户端发送的本地模型参数信息后,基于注意力机制,云中心服务器会首先使用欧氏距离来度量每个参与训练的边缘客户端本地模型与全局模型之间的差异,采用下述公式进行表示:
其中,为第t轮训练后客户端i的欧氏距离,也就是客户端i的本地模型与全局模型之间的差异,/>为第t+1轮时的全局模型,Wi (t)为第t轮时客户端i上传给中心服务器的本地模型,||·||表示计算欧氏距离;
获取差异后,对于参与训练的每个客户端,更新其注意力向量分数,采用下述公式描述更新公式:
其中,α为注意力分数衰减率,且α∈[0,1],为客户端i在第t轮时的注意力向量分数,/>为客户端i在第t+1轮时的注意力向量分数,/>为当前轮次参与训练的客户端集合St中客户端k在第t轮时的注意力向量分数,/>为客户端k在第t轮时的欧氏距离,/>为客户端i在第t轮时的欧氏距离;
对于每一个没有被选择参与到训练中的边缘客户端j,令同时更新客户端选择概率分布/>针对欧氏距离不满足设定数值的客户端,或本地模型性能不满足设定要求的客户端而言,在下一轮通信将提高被选中参与到本地训练中的概率,从而减少全局模型在所有边缘客户端中的性能差异性;
6)云中心服务器对所有得到的客户端模型进行加权聚合,采用下述公式表示模型聚合:
得到加权平均处理后的模型参数W(t+1);
S5.采用知识蒸馏的方法针对步骤S4构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型;具体包括:
假设客户端为N={1,2,…,n},每个客户端i只能访问本地的隐私数据集Di,采用下述公式描述损失函数:
式中Li(WG,Di)为客户端i的本地损失函数,k表示客户端i的本地数据Di的样本数量,ωi为模型训练参数,xj为输入数据,yj为实际输出数据,f(xj)为DP-GRU模型的预测输出数据;利用SGD算法优化损失函数,通过梯度下降法改变参数ωi从而最小化本地损失函数,加速收敛;
其中L(WG)为全局模型的损失函数,N是客户端集合,| |为获取总的客户端数目,为获取一个全局模型WG,使得损失函数最小;
在损失函数公式的基础上,重新定义客户端i采取知识蒸馏后的本地损失函数Lper,i(Wi),采用下述公式进行描述:
其中,Li(WG,Di)为未进行知识蒸馏之前的客户端i的本地损失函数;s为student,表示联邦学习全局模型;t为teachers,表示集成后的本地模型;Dp表示公共数据集,每个参与训练的客户端均能够访问;σ(·)为softmax函数,LKL(·)表示Kullback-Leiblerdivergence函数,λ∈(0,1)为加权系数,用于控制student学习teachers的程度,T为蒸馏温度;
知识蒸馏方法通过基于梯度下降的优化方式,训练联邦学习全局模型与集成后的本地模型,使得联邦学习全局模型与集成后的本地模型具有相似的泛化能力,进行J轮蒸馏,在蒸馏过程中,各个本地模型通过蒸馏样本数据集n得到各自模型的logit输出f(Wi (t),n),并用于训练云中心服务器上的联邦学习全局模型,知识蒸馏过程中的模型参数更新采用下述公式表示:
其中,W(t,j)表示第t轮训练中第j次蒸馏的全局模型,j表示第j次蒸馏,η表示学习率,L表示客户端i的本地蒸馏损失函数,f(·)表示本地模型的logit输出的求解函数,logit输出也就是该模型的最后一个全连接层的输出,将各个本地模型的logit输出的平均值作为整体迁移的知识;
经过J轮知识蒸馏,令:
为新的全局模型,模型训练进入下一轮的迭代;重复上述步骤,直至全局模型收敛,得到最终的联邦学习全局预测模型/>同时联邦学习训练结束;
S6.采用步骤S5构建的联邦学习全局预测模型,完成对汽车订单需求量的预测处理。
Claims (6)
1.一种基于联邦学习与GRU结合的汽车订单需求量预测方法,包括如下步骤:
S1.获取历史汽车订单数据,针对得到的数据进行预处理,并构建训练数据集和验证数据集;
S2.采用步骤S1构建的训练数据集,针对GRU模型进行训练处理,构建GRU预测模型;
S3.采用动态规划的方法针对步骤S2构建的GRU预测模型进行处理,构建最终的DP-GRU预测模型;
S4.采用基于联邦学习的方法针对步骤S3构建的DP-GRU预测模型进行训练处理,构建若干个客户端本地模型;
S5.采用知识蒸馏的方法针对步骤S4构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型;
S6.采用步骤S5构建的联邦学习全局预测模型,完成对汽车订单需求量的预测处理。
2.根据权利要求1所述的基于联邦学习与GRU结合的汽车订单需求量预测方法,其特征在于步骤S1所述的获取历史汽车订单数据,针对得到的数据进行预处理,并构建训练数据集和验证数据集,具体包括:
(1)获取历史汽车订单数据,针对其中缺失的数据进行补充;
(2)将步骤(1)得到的历史数据转换为时间序列数据,X=[x1,x2,…,xt-1,xt],同时设定时间步timestep,根据前timestep的数据预测后面的数据;
(3)将步骤(2)得到的时间序列数据按照设定尺寸的时间窗格式,转化成一个二维的矩阵,针对二维数据矩阵利用softmax函数进行归一化处理,并设定映射后数据值域的上界值和下界值。
3.根据权利要求2所述的基于联邦学习与GRU结合的汽车订单需求量预测方法,其特征在于步骤S2所述的采用步骤S1构建的训练数据集,针对GRU模型进行训练处理,构建GRU预测模型,具体包括:
采用下述公式描述GRU预测模型:
其中,xt是时刻t的输入变量,rt为重置门在时刻t的输出变量或权重,zt为更新门在时刻t的输出变量或权重,ht为当前时刻t的状态记忆变量,ht-1为上一时刻t-1的状态记忆变量,为候选集,/>Wr、Wz分别为候选集、更新门和重置门的权重系数,σ为sigmoid激活函数,[·]为两个向量间的连接,*为矩阵间的乘积;
采用Dropout方法针对GRU模型训练过程中存在的过拟合问题进行处理;
在模型的输出端采用全连接层,并选择ReLU作为激活函数,得到模型的预测值,采用下述公式进行表示:
其中,为GRU模型在t时刻的一个输出值,Wo为输出层的权重,bo为偏置项。
4.根据权利要求3所述的基于联邦学习与GRU结合的汽车订单需求量预测方法,其特征在于步骤S3所述的采用动态规划的方法针对步骤S2构建的GRU预测模型进行处理,构建最终的DP-GRU预测模型,具体包括:
将步骤S1中设定的timestep的序列任务转换为若干个长度为p的预测任务,p为输入值xi的个数;在构建预测t+1时刻的模型输入时,将t-1时刻的输入向量Xt-1和t时刻的预测值Yt同时引入,使得模型在预测t+1时刻的信息时能够考虑之前的预测信息;
设置模型输入序列X,长度为p;输出序列Y,长度为q;l为滚动步长,且0≤l≤p,0≤p-l≤q,得到模型输入向量序列的更新公式,表示如下所示:
Xt=Xt-1[x′1,x′2,…,x′l]+Yt[y1,y2,…,yp-l]
其中,Xt由Xt-1和Yt组成,Xt-1为DP-GRU模型预测t时刻的预测值Yt的输入序列,x′1,x′2,…,x′l为上一时刻逆序选择的l个输入值,y1,y2,…,yp-l为顺序选择的p-l个预测值,每个时刻都可以滚动选择l个输入值与p-l个预测值作为下一时刻的输入序列;采用下述公式表示最终的t+1时刻的预测结果:
Yt+1=fGRU(Xt)
其中,fGRU(·)为设置的GRU模型参数更新公式,Yt+1为t+1时刻的预测结果。
5.根据权利要求4所述的基于联邦学习与GRU结合的汽车订单需求量预测方法,其特征在于步骤S4所述的采用基于联邦学习的方法针对步骤S3构建的最终的GRU预测模型进行训练处理,构建若干个客户端本地模型,具体包括:
1)通过云中心服务器初始化全局模型WG,并初始化全局变量,包括:定义所有边缘客户端的数目为N,给定的客户端数量为M个,nk为客户端k拥有的数据量,n=n1+…nM为M个客户端的总数据量,向量表示客户端被选择参与训练的离散概率分布,表示客户端k在每一轮通信中被选择的概率,注意力向量分数为边缘客户端选择概率分布为P=[p1,p2,…,pM],每个客户端的注意力向量分数对应于客户端被选中参与训练的概率,并在第一轮通信中初始化
2)采取动态规划客户端方案,在联邦学习训练初始时,选择设定数量的边缘客户端参与到本地训练中,随着通信轮次不断增加,不断增加参与到本地训练中的边缘客户端数量,直到最终模型收敛或所有客户端都加入到训练中为止;
3)在第t轮通信时,云中心服务器根据客户端选择概率分布P=[p1,p2,…,pk]从客户端子集St中随机选择K个客户端加入到本地训练中,Wi (t)表示在第t轮通信时第i个客户端返回的本地模型训练结果;
4)各个边缘客户端首先获取当前通信轮次的全局模型然后采取随机梯度下降方法,使用本地训练集数据对GRU预测模型进行本地训练;
当本地训练结束后,将训练得到的网络输出值调用scaler.inverse_transform()函数进行反归一化处理,需要注意的是输出的预测值的shape要和归一化前的数据shape一致,再计算各个客户端的真实值与预测值之间的误差,判断误差是否满足设定的预测精度要求,若满足要求则对测试集进行预测;若不满足要求则各个本地客户端将自己的本地模型权重参数Wi (t)及其他参数信息经加密后返回给云中心服务器;
5)云中心服务器收到各个客户端发送的本地模型参数信息后,基于注意力机制,云中心服务器会首先使用欧氏距离来度量每个参与训练的边缘客户端本地模型与全局模型之间的差异,采用下述公式进行表示:
其中,为第t轮训练后客户端i的欧氏距离,/>为第t+1轮时的全局模型,Wi (t)为第t轮时客户端i上传给中心服务器的本地模型,||·||表示计算欧氏距离;
获取差异后,对于参与训练的每个客户端,更新其注意力向量分数,采用下述公式描述更新公式:
其中,α为注意力分数衰减率,且α∈[0,1],为客户端i在第t轮时的注意力向量分数,/>为客户端i在第t+1轮时的注意力向量分数,/>为当前轮次参与训练的客户端集合St中客户端k在第t轮时的注意力向量分数,/>为客户端k在第t轮时的欧氏距离,/>为客户端i在第t轮时的欧氏距离;
对于每一个没有被选择参与到训练中的边缘客户端j,令同时更新客户端选择概率分布/>针对欧氏距离不满足设定数值的客户端,或本地模型性能不满足设定要求的客户端而言,在下一轮通信将提高被选中参与到本地训练中的概率,从而减少全局模型在所有边缘客户端中的性能差异性;
6)云中心服务器对所有得到的客户端模型进行加权聚合,采用下述公式表示模型聚合:
得到加权平均处理后的模型参数W(t+1)。
6.根据权利要求5所述的基于联邦学习与GRU结合的汽车订单需求量预测方法,其特征在于步骤S5所述的采用知识蒸馏的方法针对步骤S4构建的若干个客户端本地模型进行迁移处理,构建联邦学习全局预测模型,具体包括:
假设客户端为N={1,2,…,n},每个客户端i只能访问本地的隐私数据集Di,采用下述公式描述损失函数:
式中Li(WG,Di)为客户端i的本地损失函数,k表示客户端i的本地数据Di的样本数量,ωi为模型训练参数,xj为输入数据,yj为实际输出数据,f(xj)为DP-GRU模型的预测输出数据;利用SGD算法优化损失函数,通过梯度下降法改变参数ωi从而最小化本地损失函数,加速收敛;
其中L(WG)为全局模型的损失函数,N是客户端集合,| |为获取总的客户端数目,为获取一个全局模型WG,使得损失函数最小;
在损失函数公式的基础上,重新定义客户端i采取知识蒸馏后的本地损失函数Lper,i(Wi),采用下述公式进行描述:
其中,Li(WG,Di)为未进行知识蒸馏之前的客户端i的本地损失函数;s为student,表示联邦学习全局模型;t为teachers,表示集成后的本地模型;Dp表示公共数据集,每个参与训练的客户端均能够访问;σ(·)为softmax函数,LKL(·)表示Kullback-Leibler divergence函数,λ∈(0,1)为加权系数,用于控制student学习teachers的程度,T为蒸馏温度;
知识蒸馏方法通过基于梯度下降的优化方式,训练联邦学习全局模型与集成后的本地模型,使得联邦学习全局模型与集成后的本地模型具有相似的泛化能力,进行J轮蒸馏,在蒸馏过程中,各个本地模型通过蒸馏样本数据集n得到各自模型的logit输出f(Wi (t),n),并用于训练云中心服务器上的联邦学习全局模型,知识蒸馏过程中的模型参数更新采用下述公式表示:
其中,W(t,j)表示第t轮训练中第j次蒸馏的全局模型,j表示第j次蒸馏,η表示学习率,L表示客户端i的本地蒸馏损失函数,f(·)表示本地模型的logit输出的求解函数,logit输出也就是该模型的最后一个全连接层的输出,将各个本地模型的logit输出的平均值作为整体迁移的知识;
经过J轮知识蒸馏,令:
为新的全局模型,模型训练进入下一轮的迭代;重复上述步骤,直至全局模型收敛,得到最终的联邦学习全局预测模型/>同时联邦学习训练结束。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310421463.4A CN116645130A (zh) | 2023-04-19 | 2023-04-19 | 基于联邦学习与gru结合的汽车订单需求量预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310421463.4A CN116645130A (zh) | 2023-04-19 | 2023-04-19 | 基于联邦学习与gru结合的汽车订单需求量预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116645130A true CN116645130A (zh) | 2023-08-25 |
Family
ID=87617704
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310421463.4A Pending CN116645130A (zh) | 2023-04-19 | 2023-04-19 | 基于联邦学习与gru结合的汽车订单需求量预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116645130A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117835329A (zh) * | 2024-03-04 | 2024-04-05 | 东北大学 | 车载边缘计算中基于移动性预测的服务迁移方法 |
CN117875535B (zh) * | 2024-03-13 | 2024-06-04 | 中南大学 | 基于历史信息嵌入的取送货路径规划方法及系统 |
-
2023
- 2023-04-19 CN CN202310421463.4A patent/CN116645130A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117835329A (zh) * | 2024-03-04 | 2024-04-05 | 东北大学 | 车载边缘计算中基于移动性预测的服务迁移方法 |
CN117875535B (zh) * | 2024-03-13 | 2024-06-04 | 中南大学 | 基于历史信息嵌入的取送货路径规划方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113705610B (zh) | 一种基于联邦学习的异构模型聚合方法和系统 | |
US20160012330A1 (en) | Neural network and method of neural network training | |
Le et al. | Federated continuous learning with broad network architecture | |
CN110659678B (zh) | 一种用户行为分类方法、系统及存储介质 | |
CN107103359A (zh) | 基于卷积神经网络的大服务系统在线可靠性预测方法 | |
CN114492833A (zh) | 基于梯度记忆的车联网联邦学习分层知识安全迁移方法 | |
CN116523079A (zh) | 一种基于强化学习联邦学习优化方法及系统 | |
CN113128432B (zh) | 基于演化计算的机器视觉多任务神经网络架构搜索方法 | |
CN116645130A (zh) | 基于联邦学习与gru结合的汽车订单需求量预测方法 | |
US20240027976A1 (en) | Industrial Process Soft Sensor Method Based on Federated Stochastic Configuration Network | |
CN116471286A (zh) | 基于区块链及联邦学习的物联网数据共享方法 | |
Long et al. | Fedsiam: Towards adaptive federated semi-supervised learning | |
CN117574429A (zh) | 一种边缘计算网络中隐私强化的联邦深度学习方法 | |
CN111832817A (zh) | 基于mcp罚函数的小世界回声状态网络时间序列预测方法 | |
CN117236421B (zh) | 一种基于联邦知识蒸馏的大模型训练方法 | |
CN114580747A (zh) | 基于数据相关性和模糊系统的异常数据预测方法及系统 | |
Tembine | Mean field stochastic games: Convergence, Q/H-learning and optimality | |
CN117523291A (zh) | 基于联邦知识蒸馏和集成学习的图像分类方法 | |
CN110321951B (zh) | 一种vr模拟飞行器训练评价方法 | |
CN115640852B (zh) | 联邦学习参与节点选择优化方法、联邦学习方法及系统 | |
CN115936110A (zh) | 一种缓解异构性问题的联邦学习方法 | |
Xue et al. | An improved extreme learning machine based on variable-length particle swarm optimization | |
CN117033997A (zh) | 数据切分方法、装置、电子设备和介质 | |
Miyajima et al. | Fast and secure back-propagation learning using vertically partitioned data with IoT | |
CN111563767A (zh) | 股票价格预测方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |