CN111222628B - 循环神经网络训练优化方法、设备、系统及可读存储介质 - Google Patents
循环神经网络训练优化方法、设备、系统及可读存储介质 Download PDFInfo
- Publication number
- CN111222628B CN111222628B CN201911141081.6A CN201911141081A CN111222628B CN 111222628 B CN111222628 B CN 111222628B CN 201911141081 A CN201911141081 A CN 201911141081A CN 111222628 B CN111222628 B CN 111222628B
- Authority
- CN
- China
- Prior art keywords
- rnn
- neural network
- participation
- equipment
- output result
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种循环神经网络训练优化方法、设备、系统及可读存储介质,所述方法包括:接收参与设备发送的RNN输出结果,其中,RNN输出结果为参与设备将上游传递信息,以及参与设备所处理时间步对应的本地训练数据输入对应RNN中得到的;根据RNN输出结果计算得到梯度信息;将梯度信息反向传播给参与设备,供各参与设备根据梯度信息更新RNN的模型参数;将从各参与设备接收到的更新后模型参数进行融合得到全局模型参数,并返回给各参与设备,迭代训练得到训练完成的RNN。本发明通过协调设备协调多个参与设备分别处理不同时间步,分担训练RNN的计算和电量开销,使得在设备计算资源和电量资源有限的场景下也可以进行RNN训练。
Description
技术领域
本发明涉及人工智能领域,尤其涉及一种循环神经网络训练优化方法、设备、系统及可读存储介质。
背景技术
循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络,适合用于处理视频、语音、文本等与时序相关的问题。在RNN中,神经元不但可以接收其他神经元的信息,还可以接收自身的信息,形成具有环路的网络结构。目前,RNN以及深度RNN(例如,stacked LSTM)已经在实践中证明了其强大的功能,特别是在自然语言处理领域被广泛应用。
然而,训练RNN的计算复杂度是非常高的,这是因为RNN有很多时间步(time-step),每个时间步都对应一个神经网络或者深度神经网络。例如,当总时间步的个数为1024,每个时间步对应的隐藏节点为1024,一共有8层RNN(例如,4个双向的LSTM(bidirectional LSTM)堆积起来),批量大小为64,那么这个计算量是非常庞大的,在训练的时候,一个LSTM(Long Short-Term Memory,长短期记忆网络)输入层维度即为1x64x1024x1024,还需要重复做8次这样的输入以及计算。如果处理的序列数据是遥感卫星拍摄的四通道图像数据,那么一个LSTM的输入维度就变成了4x64x1024x1024,即计算量又增加了4倍。如果考虑卫星在轨计算或者物联网(Internet of Things,IoT)等应用场景,训练RNN的计算复杂度可能会超过设备的计算能力和电量资源。
现有的解决办法是使用强大的CPU或者TPU硬件资源来训练复杂的RNN模型,但是CPU和TPU的成本和电量消耗都超过了卫星或者IoT设备的承受范围,不能用于卫星在轨计算等应用场景。
发明内容
本发明的主要目的在于提供一种循环神经网络训练优化方法、设备、系统及可读存储介质,旨在解决训练RNN的计算复杂度高,可能会超过设备的计算能力和电量资源,从而无法应用于设备计算能力和电量资源受限的应用场景的问题。
为实现上述目的,本发明提供一种循环神经网络训练优化方法,所述循环神经网络训练优化方法应用于基于联邦学习训练循环神经网络RNN的协调设备,所述协调设备与各参与设备通信连接,各所述参与设备按照各自处理的时间步的先后顺序通信连接,所述循环神经网络训练优化方法包括以下步骤:
接收所述参与设备发送的RNN输出结果,其中,所述RNN输出结果为所述参与设备将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息;
将所述梯度信息反向传播给所述参与设备,以供各所述参与设备根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数;
将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。
可选地,所述RNN是神经网络模型的一部分,所述根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息的步骤包括:
将所述RNN输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的输出结果;
根据所述神经网络模型的输出结果和预设的样本标签计算预设损失函数对所述RNN输出结果的梯度信息。
可选地,所述将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN的步骤之后,还包括:
接收各所述参与设备发送的RNN预测输出结果,其中,各所述参与设备将各自的上游传递信息和所处理时间步对应的本地预测数据输入训练完成的RNN中,得到所述RNN预测输出结果;
将所述RNN预测输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的预测结果。
可选地,所述接收所述参与设备发送的RNN输出结果的步骤之前,还包括:
根据所述RNN训练样本的时间步数和所述参与设备的数量确定各所述参与设备需处理的时间步;或,
根据各所述参与设备拥有的训练数据所对应的时间步确定各所述参与设备需处理的时间步。
为实现上述目的,本发明还提供一种循环神经网络训练优化方法,所述循环神经网络训练优化方法应用于基于联邦学习训练循环神经网络RNN的参与设备,所述参与设备与协调设备通信连接,各所述参与设备按照各自处理的时间步的先后顺序通信连接,所述循环神经网络训练优化方法包括以下步骤:
将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到RNN输出结果,并将所述RNN输出结果发送给所述协调设备,其中,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
接收所述协调设备反向传播的梯度信息,其中,所述协调设备根据从各所述参与设备接收到的RNN输出结果计算得到所述梯度信息;
根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数,并将更新后的模型参数发送给所述协调设备;
接收所述协调设备发送的全局模型参数,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN,其中,所述协调设备对各所述参与设备发送的模型参数进行融合得到所述全局模型参数。
可选地,所述上游传递信息包括所述上游参与设备对应的RNN的输出结果和/或记忆状态。
为实现上述目的,本发明还提供一种循环神经网络训练优化设备,所述循环神经网络训练优化设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的循环神经网络训练优化程序,所述循环神经网络训练优化程序被所述处理器执行时实现如上所述的循环神经网络训练优化方法的步骤。
为实现上述目的,本发明还提供一种循环神经网络训练优化系统,所述循环神经网络训练优化包括至少一个如上所述的协调设备和至少两个如上所述的参与设备。
此外,为实现上述目的,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有循环神经网络训练优化程序,所述循环神经网络训练优化程序被处理器执行时实现如上所述的循环神经网络训练优化方法的步骤。
本发明中,通过协调设备接收参与设备发送的RNN输出结果,其中,RNN输出结果为参与设备将上游参与设备传递的上游传递信息,以及参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的,上游传递信息为上游参与设备将上上游参与设备传递的上上游传递信息,以及上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;协调设备根据RNN输出结果计算得到预设损失函数对RNN输出结果的梯度信息;将梯度信息反向传播给参与设备,以供各参与设备根据梯度信息更新RNN的模型参数;将从各参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将全局模型参数发送给各参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。在本实施例中,通过在训练RNN时,由协调设备协调多个参与设备分别处理不同时间步,分担训练RNN的计算负担和电量开销,降低了训练RNN对单个设备的计算和电量消耗,使得在设备计算资源和电量资源有限的场景下也可以进行RNN训练,适用于卫星在轨计算或者IoT场景。
附图说明
图1是本发明实施例方案涉及的硬件运行环境的结构示意图;
图2为本发明循环神经网络训练优化方法第一实施例的流程示意图;
图3为本发明实施例涉及的一种一般RNN的结构示意图;
图4为为图3中RNN按照时间步展开的示意图;
图5为本发明实施例涉及的一种使用RNN的深度神经网络示意图;
图6为本发明实施例涉及的一种参与设备与协调设备联合训练RNN的流程示意图;
图7为本发明实施例涉及的一种参与设备之间进行信息传递,以及参与设备向协调设备发送RNN输出结果的示意图;
图8为本发明实施例涉及的一种协调设备向参与设备发送梯度信息以及全局模型参数更新的示意图;
图9为一种参与设备向协调设备发送更新后的模型参数的示意图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
如图1所示,图1是本发明实施例方案涉及的硬件运行环境的设备结构示意图。
需要说明的是,本发明实施例循环神经网络训练优化设备可以是智能手机、个人计算机和服务器等设备,在此不做具体限制。
如图1所示,该循环神经网络训练优化设备可以包括:处理器1001,例如CPU,网络接口1004,用户接口1003,存储器1005,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(Display)、输入单元比如键盘(Keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如WI-FI接口)。存储器1005可以是高速RAM存储器,也可以是稳定的存储器(non-volatile memory),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的设备结构并不构成对循环神经网络训练优化设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及循环神经网络训练优化程序。其中,操作系统是管理和控制设备硬件和软件资源的程序,支持循环神经网络训练优化程序以及其它软件或程序的运行。
当循环神经网络训练优化设备是基于联邦学习训练RNN的协调设备时,在图1所示的设备中,用户接口1003主要用于与客户端进行数据通信;网络接口1004主要用于与基于联邦学习训练RNN的参与设备建立通信连接,各个参与设备按照各自处理的时间步先后顺序通信连接;处理器1001可以用于调用存储器1005中存储的循环神经网络训练优化程序,并执行以下操作:
接收所述参与设备发送的RNN输出结果,其中,所述RNN输出结果为所述参与设备将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息;
将所述梯度信息反向传播给所述参与设备,以供各所述参与设备根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数;
将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。
进一步地,所述RNN是神经网络模型的一部分,所述根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息的步骤包括:
将所述RNN输出结果输入所述神经网路模型的其他网络部分,得到所述神经网络模型的输出结果;
根据所述神经网络模型的输出结果和预设的样本标签计算预设损失函数对所述RNN输出结果的梯度信息。
进一步地,所述将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN的步骤之后,处理器1001还可以用于调用存储器1005中存储的循环神经网络训练优化程序,执行以下操作:
接收各所述参与设备发送的RNN预测输出结果,其中,各所述参与设备将各自的上游传递信息和所处理时间步对应的本地预测数据输入训练完成的RNN中,得到所述RNN预测输出结果;
将所述RNN预测输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的预测结果。
进一步地,所述接收所述参与设备发送的RNN输出结果的步骤之前,处理器1001还可以用于调用存储器1005中存储的循环神经网络训练优化程序,执行以下操作:
根据所述RNN训练样本的时间步数和所述参与设备的数量确定各所述参与设备需处理的时间步;或,
根据各所述参与设备拥有的训练数据所对应的时间步确定各所述参与设备需处理的时间步。
当循环神经网络训练优化设备是基于联邦学习训练RNN的参与设备时,在图1所示的设备中,用户接口1003主要用于与客户端进行数据通信;网络接口1004主要用于与基于联邦学习训练RNN的协调设备建立通信连接,各个参与设备按照各自处理的时间步先后顺序通信连接;处理器1001可以用于调用存储器1005中存储的循环神经网络训练优化程序,并执行以下操作:
将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到RNN输出结果,并将所述RNN输出结果发送给所述协调设备,其中,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
接收所述协调设备反向传播的梯度信息,其中,所述协调设备根据从各所述参与设备接收到的RNN输出结果计算得到所述梯度信息;
根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数,并将更新后的模型参数发送给所述协调设备;
接收所述协调设备发送的全局模型参数,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN,其中,所述协调设备对各所述参与设备发送的模型参数进行融合得到所述全局模型参数。
进一步地,所述上游传递信息包括所述上游参与设备对应的RNN的输出结果和/或记忆状态。
此外,本发明实施例还提出一种联邦学习模型训练系统,所述联邦学习模型训练系统包括至少一个如上所述的协调设备、至少两个如上所述的参与设备。
基于上述的结构,提出循环神经网络训练优化方法的各个实施例。
参照图2,图2为本发明循环神经网络训练优化方法第一实施例的流程示意图。需要说明的是,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。
在本实施例中,循环神经网络训练优化方法应用于基于联邦学习训练RNN的协调设备,协调设备与各参与设备通信连接,各参与设备按照各自处理的时间步的先后顺序通信连接,协调设备和参与设备可以是智能手机、个人计算机和服务器等设备。在本实施例中,循环神经网络训练优化方法包括:
步骤S10,接收所述参与设备发送的RNN输出结果,其中,所述RNN输出结果为所述参与设备将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
为解决因训练RNN的计算复杂度高,对训练设备的计算资源和电量消耗过大,采用常规的训练方法可能超出设备的计算能力和电量资源,导致无法应用于设备计算能力和电量资源受限的应用场景的问题,在本实施例中,提出一种基于联邦学习的RNN训练优化方法,旨在降低训练设备的计算资源和电量资源消耗,从而使得在设备计算能力和电量资源受限的应用场景也能够进行RNN训练。
具体地,以一般RNN的结构为例说明如何基于联邦学习优化RNN的训练过程,但应当理解的是,本实施例中RNN训练优化方法不限于一般RNN,还可以适用于训练LSTM(LongShort-Term Memory,长短期记忆网络)、GRU(Gated Recurrent Unit,门控循环单元)和IndRNN(Independently Recurrent Neural Network,独立循环神经网络),还可以适用于深度RNN,例如,stacked LSTM。
如图3所示,为一般RNN结构,X是输入序列数据,S是记忆状态(记忆单元),O是输出。RNN会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻(即上一个时间步)的记忆状态。需要注意的是,X,S和O都是二维的向量(可以用矩阵来表示),例如,X的维度可以是NxT,其中,N是RNN处理单元的输入维度,T是时间步的个数。U是对应输入的权重矩阵,V是对应输出的权重矩阵,W是对应记忆状态的权重矩阵,U、V和W即RNN的模型参数。
如图4所示,为图3中RNN按照时间步展开的示意图。RNN的输入、输出和状态都是由时间步索引的(即时间步t);RNN的不同的时间步对应的神经网络节点是共享权重的,即所有的时间步共享权重矩阵U、V和W。需要说明的是,第t个时间步对应的输入xt、状态st以及输出ot都是向量,例如,xt的维度是Nx1,其中,N是RNN处理单元的输入维度。简单地说,RNN的每个时间步都对应一个深度神经网络,所有时间步对应的深度神经网络是共享权重矩阵的(U、V和W)。对一个时间步t,RNN的输入可能包括三个部分:外部输入xt、前一个时间步的输出ot-1,和前一个时间步的记忆状态st-1。
基于上述RNN的结构原理,在本实施例中,提出使用联邦学习技术,联合多个参与设备(例如,卫星或者IoT设备)一起训练RNN模型,不同的参与设备处理不同的时间步,将计算和电量开销分担到多个设备上。
具体地,可预先在各个参与设备中构建同样的RNN,也即各个参与设备的RNN结构相同,模型参数相同。可以预先确定各个参与设备需处理的时间步,一个参与设备可处理一个或多个时间步,但是避免一个参与设备处理所有的时间步。每个参与设备根据具体情况,可以处理不同数量的时间步,如可根据参与设备的计算资源和电量资源来进行适量分配。可以由协调设备收集各个参与设备的设备信息,根据设备信息进行时间步的分配。
在确定各个参与设备所处理的时间步后,各个参与设备按照各自处理的时间步的先后顺序通信连接,也即处理前一时间步的参与设备,与处理后一时间步的参与设备通信连接。需要说明的是,根据各个时间步之前的先后顺序关系不同,各参与设备通信连接的状态可以是交叉连接、串行连接或全连接等不同的连接状态。各个参与设备预先存储各自处理的时间步对应的训练数据,例如一条样本数据包括多个时间步的数据,则处理不同时间步的参与设备预先存储该条样本数据对应时间步的数据;训练数据包括多条样本数据,则各个参与设备本地存储了每条样本数据的部分时间步的数据。
各个参与设备处理各自负责的时间步,并按照时间步的先后顺序,传递各个时间步之间的传递信息,如前一时间步对应的记忆状态。具体地,参与设备接收其上游参与设备发送的上游传递信息,将上游传递信息和所处理时间步对应的本地训练数据输入本地的RNN中,得到RNN输出结果,以及要传递给下游参与设备的下游传递信息。参与设备将得到的下游传递信息发送给下游参与设备,对于下游参与设备来说该传递信息为上游传递信息,下游参与设备同样采样其上游传递信息对所负责的时间步进行处理。同样地,参与设备接收到的上游传递信息也是上游参与设备将上上游参与设备传递的上上游传递信息,以及该上游参与设备所处理时间步对应的本地训练数据输入上游参与设备本地的RNN中得到的。需要说明的是,处理第一个时间步的参与设备的上游传递信息根据具体情况,可以是预先设置的初始值;处理最后一个时间步的参与设备后面没有连接的下游参与设备,可以不向后传递下游传递信息。
需要说明的是,根据具体RNN的结构不同,下游传递信息可能不同,如下游传递信息可以是参与设备本地的RNN的输出结果,或者是本地的RNN的记忆状态,还可以是既包括输出结果又包括记忆状态。如当RNN是LSTM时,下游传递信息既包括输出结果又包括记忆状态。上游传递信息与下游传递信息等同。
参与设备将得到的RNN输出结果发送给协调设备。需要说明的是,根据根据具体应用场景或RNN结构不同,可以是各个参与设备均将各自处理的所有时间步对应的RNN输出结果都发送给协调设备,也可以是各个参与设备仅将各自处理的最后一个时间步对应的RNN输出结果发送给协调设备,还可以是仅处理最后一个时间步的参与设备将最后一个时间步对应的RNN输出结果发送给协调设备。如图3和4所示,每一个时间步都对应有一个RNN输出结果O,则参与设备可将每个时间步对应的O都发送给协调设备。如有的RNN结构是仅最后一个时间步对应一个RNN输出结果,则可以是仅处理最后一个时间步的参与设备将最后一个时间步对应的RNN输出结果发送给协调设备。
协调设备接收参与设备发送的RNN输出结果。
步骤S20,根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息;
RNN的模型参数学习可以通过随时间反向传播算法来学习,即按照时间的逆序把误差一步步往前传递。协调设备中预先设置与RNN结构对应的损失函数。协调设备根据RNN输出结果计算预设损失函数对RNN输出结果的梯度信息。需要说明的是,若协调设备接收各个参与设备发送的各时间步对应的RNN输出结果,则协调设备分别计算损失函数对各个RNN输出结果的梯度信息;若协调设备接收最后一个时间步对应的RNN输出结果,则协调设备计算损失函数对最后一个时间步对应的RNN输出结果的梯度信息。
步骤S30,将所述梯度信息反向传播给所述参与设备,以供各所述参与设备根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数;
协调设备计算得到梯度信息后反向传播给参与设备。需要说明的是,协调设备接收参与设备发送的RNN输出结果,计算该RNN输出结果的梯度信息,将梯度信息返回给该参与设备,即协调设备将RNN输出结果的梯度信息对应返回给发送该RNN输出结果的参与设备。参与设备接收到梯度信息后,根据梯度信息更新本地的RNN的模型参数,具体地,参与设备根据RNN输出结果的梯度信息,按照链式法则反向推导各个模型参数的梯度信息,根据各个模型参数的梯度信息对应更新各模型参数。参与设备将更新后的RNN的模型参数发送给协调设备。
步骤S40,将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。
协调设备接收各个参与设备发送的更新后的模型参数,并将各个更新后的模型参数进行融合得到全局模型参数。具体地融合方式可以是计算各个更新后模型参数的加权平均。协调设备将全局模型参数发送给各个参与设备,以确保各个参与设备获得相同的模型参数。各个参与设备根据全局模型参数更新本地RNN的模型参数,并采用更新后的RNN进行时间步的处理,迭代训练,直到协调设备检测到满足预设停止条件时,确定最终的模型参数,即得到了训练完成的RNN。其中,预设停止条件可以根据需要进行预先设置,如损失函数收敛,或迭代训练的次数达到最大次数,或迭代训练的时间达到最大训练时间。
在本实施例中,通过协调设备接收参与设备发送的RNN输出结果,其中,RNN输出结果为参与设备将上游参与设备传递的上游传递信息,以及参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的,上游传递信息为上游参与设备将上上游参与设备传递的上上游传递信息,以及上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;协调设备根据RNN输出结果计算得到预设损失函数对RNN输出结果的梯度信息;将梯度信息反向传播给参与设备,以供各参与设备根据梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新RNN的模型参数;将从各参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将全局模型参数发送给各参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。在本实施例中,通过在训练RNN时,由协调设备协调多个参与设备分别处理不同时间步,分担训练RNN的计算负担和电量开销,降低了训练RNN对单个设备的计算和电量消耗,使得在设备计算资源和电量资源有限的场景下也可以进行RNN训练,适用于卫星在轨计算或者IoT场景。
此外,在本实施例中,由于不同的参与设备处理不同的时间步数据,自然地适用于不同的参与设备拥有不同的时间步数据的场景。例如,不同的遥感卫星可能有不同的时间维度的照片。同样地,序列数据也可以是空间维度的序列数据。本发明实施例方案就很适合联合这些遥感卫星来训练RNN和使用RNN进行时间序列数据预测或者空间序列数据预测。
在一实施例中,各个参与设备可以是拥有不同时间维度图像数据的遥感卫星,各个遥感卫星联合利用各自的图像数据训练RNN完成预测任务。具体地,各个遥感卫星根据所拥有的图像数据之间的时间维度关系,确定各自所处理的时间步。协调设备可以是其中一个遥感卫星,也可以是位于地面的基站。遥感卫星采用上游的遥感卫星传递的上游传递信息,以及该遥感卫星所处理时间步对应的图像数据输入对应的RNN中,得到RNN输出结果,上游传递信息是上上游遥感卫星传递的上上游传递信息,以及上游遥感卫星所处理时间步对应的图像数据输入对应的RNN中得到的;协调设备各遥感卫星发送的RNN输出结果计算得到预设损失函数对RNN输出结果的梯度信息,并将梯度信息反向传播给遥感卫星,以供各遥感卫星根据梯度信息计算损失函数对RNN模型参数的梯度信息,并根据损失函数对RNN模型参数的梯度信息来更新RNN的模型参数;协调设备将从各遥感卫星接收到的更新后的模型参数进行融合得到全局模型参数,并将全局模型参数发送给各遥感卫星,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN。在得到训练完成的RNN后,协调设备可将训练完成的RNN发送给各遥感卫星,各个遥感卫星可采用训练完成的RNN,输入图像数据完成预测任务。通过在训练RNN时,由协调设备协调多个遥感卫星分别处理不同时间步,分担训练RNN的计算负担和电量开销,降低了训练RNN对单个遥感卫星的计算和电量消耗。
此外,在本实施例中,由于各个参与设备之间,以及参与设备和协调设备之间不需传输训练数据,而只需要传输时间步间的少量传递信息以及RNN的输出结果即可,参与设备之间,以及参与设备与协调设备之间传输的信息量很小,不会显著增加额外的通信开销。
进一步地,基于上述第一实施例,提出本发明循环神经网络训练优化方法第二实施例,在本实施例中,所述RNN是神经网络模型的一部分,所述步骤S20包括:
步骤S201,将所述RNN输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的输出结果;
在本实施例中,RNN可作为神经网络模型的一部分,如图5所示,为一种使用RNN的深度神经网络示意图,整个结构包括输入层、CNN(Convolutional Neural Networks,卷积神经网络)层、LSTM层、DNN全连接层和输出层,其中,LSTM是RNN的一种。CNN层的输出作为LSTM层的输入,LSTM层的输出作为DNN全连接层的输入。协调设备中预先设置神经网络模型的其他部分,如参与设备中可预置输入层、CNN层和LSTM层,协调设备则预置DNN全连接层和输出层。协调设备与参与设备可通过分裂学习(split learning)的方式来合作训练神经网络模型。
具体地,协调设备接收到参与设备发送的RNN输出结果后,将RNN输出结果输入神经网络模型的其他网络部分,如输入LSTM层后面的全连接层。经过全连接层和输出层的处理,得到神经网络模型的输出结果。
步骤S202,根据所述神经网络模型的输出结果和预设的样本标签计算预设损失函数对所述RNN输出结果的梯度信息。
协调设备中可以预设个条训练样本数据的样本标签。协调设备根据神经网络模型的输出结果和预设的样本标签计算损失函数对RNN输出结果的梯度信息。具体地,协调设备可计算损失函数对神经网络模型的输出结果的损失函数,并按照链式法则反向推导损失函数对其他网络部分模型参数的梯度信息,以及损失函数对RNN输出结果的梯度信息;根据其他网络部分模型参数的梯度信息更新其他网络部分。
在本实施例中,通过将RNN作为神经网络模型的一部分,并通过协调设备协调不同参与设备处理RNN不同的时间步,降低了整个神经网络模型训练对单个设备的计算资源和电量资源的消耗,从而使得在设备计算资源和电量资源受限的应用场景,也可以进行复杂神经网络模型的训练,适用于卫星在轨计算或者IoT场景。
进一步地,在本实施例中,参与设备之间、以及参与设备与协调设备之间可以根据具体应用场景的需要,如为保护各个参与设备内部隐私数据,可可以对需要发送的数据进行加密,如采用同态加密技术对数据进行加密。
进一步地,步骤S40之后,还包括:
步骤S50,接收各所述参与设备发送的RNN预测输出结果,其中,各所述参与设备将各自的上游传递信息和所处理时间步对应的本地预测数据输入训练完成的RNN中,得到所述RNN预测输出结果;
由于当时间步较多,RNN结构复杂的情况下,采用训练完成的RNN完成预测或分类任务时,计算复杂度依然会很高,可能超出设备的计算资源。因此,在本实施例中,在得到训练完成的RNN后,协调设备可以联合各个参与设备一起使用训练完成的RNN完成预测或分类任务。具体地,与训练过程类似,预先可确定各个参与设备需处理的时间步,各个参与设备中预先存储预测数据中不同时间步的数据。各个参与设备将上游参与设备发送的上游传递信息和各自所处理时间步对应的本地预测数据输入本地训练完成的RNN中,得到RNN预测输出结果,同样地,也将得到的下游传递信息传递给下游参与设备。各参与设备将得到的RNN预测输出结果发送给协调设备。
步骤S60,将所述RNN预测输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的预测结果。
协调设备可将RNN预测输出结果输入神经网络模型的其他网络部分,得到神经网络模型的预测结果。例如,包含RNN的神经网络模型可以用于预测海洋污染的变化,则预测数据可以来自时间序列的遥感卫星图片,通过将时间序列的遥感卫星图片输入神经网络模型中的CNN,通过CNN提取图片的特征,作为RNN的各个时间步的输入数据;不同的参与设备处理不同的时间步得到RNN预测输出结果,并将RNN预测输出结果发送给协调设备;由协调设备将RNN预测输出结果输入神经网络模型后面的网络部分,最终得到海洋污染变化的预测结果。
进一步地,步骤S10之前,还包括:
步骤S70,根据所述RNN训练样本的时间步数和所述参与设备的数量确定各所述参与设备需处理的时间步;
协调设备确定各参与设备需处理的时间步的方式有多种。在本实施例中,协调设备可以根据RNN训练样本的时间步数,以及参与设备的数量来确定各个参与设备需处理的时间步。例如时间步数是T,一共有K个联邦学习的参与设备(例如K个进行在轨计算的卫星),那么可确定第j个参与设备负责处理T/K个时间步的数据,若T/K不是整数,则可以对T/K向上取整,其中有一个参与设备处理的时间步个数是T-(K-1)*ceil(T/K)。其中,ceil(T/K)表示对T/K进行向上取整。也即,若有时间步数是10,参与设备数量是3,则可确定参与设备1处理第1至4个时间步,参与设备2处理第5至8个时间步,参与设备3处理第9个和第10个时间步。
步骤S80,根据各所述参与设备拥有的训练数据所对应的时间步确定各所述参与设备需处理的时间步。
或者,当不同的参与设备中存储有不同时间步对应的训练数据时,如不同的遥感卫星可能有不同的时间维度的照片,协调设备可根据各参与设备拥有的训练数据对应的时间步,确定各个参与设备需处理的时间步,如参与设备1存储有第一时间步和第二时间步的数据,则确定由参与设备处理第一时间步和第二时间步。也即,时间步对应的数据在哪个参与设备,就确定哪个参与设备处理该时间步,以适应不同的参与设备拥有不同的时间步数据的场景。
进一步地,基于上述第一和二实施例,提出本发明循环神经网络训练优化方法第三实施例。在本实施例中,循环神经网络训练优化方法应用于基于联邦学习训练RNN的参与设备,参与设备与协调设备通信连接,各参与设备按照各自处理的时间步的先后顺序串行连接,协调设备和参与设备可以是智能手机、个人计算机和服务器等设备。在本实施例中,循环神经网络训练优化方法包括:
步骤A10,将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到RNN输出结果,并将所述RNN输出结果发送给所述协调设备,其中,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;
在本实施例中,提出使用联邦学习技术,联合多个参与设备(例如,卫星或者IoT设备)一起训练RNN模型,不同的参与设备处理不同的时间步,将计算和电量开销分担到多个设备上。
具体地,可预先在各个参与设备中构建同样的RNN,也即各个参与设备的RNN结构相同,模型参数相同。可以预先确定各个参与设备需处理的时间步,一个参与设备可处理一个或多个时间步,但是避免一个参与设备处理所有的时间步。每个参与设备根据具体情况,可以处理不同数量的时间步,如可根据参与设备的计算资源和电量资源来进行适量分配。可以由协调设备收集各个参与设备的设备信息,根据设备信息进行时间步的分配。
在确定各个参与设备所处理的时间步后,各个参与设备按照各自处理的时间步的先后顺序串行连接,也即处理前一时间步的参与设备,与处理后一时间步的参与设备连接。各个参与设备预先存储各自处理的时间步对应的训练数据,例如一条样本数据包括多个时间步的数据,则处理不同时间步的参与设备预先存储该条样本数据对应时间步的数据;训练数据包括多条样本数据,则各个参与设备本地存储了每条样本数据的部分时间步的数据。
各个参与设备处理各自负责的时间步,并按照时间步的先后顺序,传递各个时间步之间的传递信息,如前一时间步对应的记忆状态。具体地,参与设备接收其上游参与设备发送的上游传递信息,将上游传递信息和所处理时间步对应的本地训练数据输入本地的RNN中,得到RNN输出结果,以及要传递给下游参与设备的下游传递信息。参与设备将得到的下游传递信息发送给下游参与设备,对于下游参与设备来说该传递信息为上游传递信息,下游参与设备同样采样其上游传递信息对所负责的时间步进行处理。同样地,参与设备接收到的上游传递信息也是上游参与设备将上上游参与设备传递的上上游传递信息,以及该上游参与设备所处理时间步对应的本地训练数据输入上游参与设备本地的RNN中得到的。需要说明的是,处理第一个时间步的参与设备的上游传递信息根据具体情况,可以是预先设置的初始值;处理最后一个时间步的参与设备后面没有连接的下游参与设备,可以不向后传递下游传递信息。
需要说明的是,根据具体RNN的结构不同,下游传递信息可能不同,如下游传递信息可以是参与设备本地的RNN的输出结果,或者是本地的RNN的记忆状态,还可以是既包括输出结果又包括记忆状态。如当RNN是LSTM时,下游传递信息既包括输出结果又包括记忆状态。上游传递信息与下游传递信息等同。
参与设备将得到的RNN输出结果发送给协调设备。需要说明的是,根据根据具体应用场景或RNN结构不同,可以是各个参与设备均将各自处理的所有时间步对应的RNN输出结果都发送给协调设备,也可以是各个参与设备仅将各自处理的最后一个时间步对应的RNN输出结果发送给协调设备,还可以是仅处理最后一个时间步的参与设备将最后一个时间步对应的RNN输出结果发送给协调设备。如图3和4所示,每一个时间步都对应有一个RNN输出结果O,则参与设备可将每个时间步对应的O都发送给协调设备。如有的RNN结构是仅最后一个时间步对应一个RNN输出结果,则可以是仅处理最后一个时间步的参与设备将最后一个时间步对应的RNN输出结果发送给协调设备。
步骤A20,接收所述协调设备反向传播的梯度信息,其中,所述协调设备根据从各所述参与设备接收到的RNN输出结果计算得到所述梯度信息;
协调设备接收参与设备发送的RNN输出结果。RNN的模型参数学习可以通过随时间反向传播算法来学习,即按照时间的逆序把误差一步步往前传递。协调设备中预先设置与RNN结构对应的损失函数。协调设备根据RNN输出结果计算预设损失函数对RNN输出结果的梯度信息。需要说明的是,若协调设备接收各个参与设备发送的各时间步对应的RNN输出结果,则协调设备分别计算损失函数对各个RNN输出结果的梯度信息;若协调设备接收最后一个时间步对应的RNN输出结果,则协调设备计算损失函数对最后一个时间步对应的RNN输出结果的梯度信息。
协调设备计算得到梯度信息后反向传播给参与设备。需要说明的是,协调设备接收参与设备发送的RNN输出结果,计算该RNN输出结果的梯度信息,将梯度信息返回给该参与设备,即协调设备将RNN输出结果的梯度信息对应返回给发送该RNN输出结果的参与设备。
参与设备接收协调设备返回的梯度信息。
步骤A30,根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数,并将更新后的模型参数发送给所述协调设备;
参与设备接收到梯度信息后,根据梯度信息更新本地的RNN的模型参数,具体地,参与设备根据RNN输出结果的梯度信息,按照链式法则反向推导各个模型参数的梯度信息,根据各个模型参数的梯度信息对应更新各模型参数。参与设备将更新后的RNN的模型参数发送给协调设备。
步骤A40,接收所述协调设备发送的全局模型参数,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN,其中,所述协调设备对各所述参与设备发送的模型参数进行融合得到所述全局模型参数。
协调设备接收各个参与设备发送的更新后的模型参数,并将各个更新后的模型参数进行融合得到全局模型参数。具体地融合方式可以是计算各个更新后模型参数的加权平均。协调设备将全局模型参数发送给各个参与设备,以确保各个参与设备获得相同的模型参数。
各个参与设备接收协调设备发送的全局模型参数,根据全局模型参数更新本地RNN的模型参数,并采用更新后的RNN进行时间步的处理,迭代训练,直到参与设备检测到满足预设停止条件时,确定最终的模型参数,即得到了训练完成的RNN。其中,预设停止条件可以根据需要进行预先设置,如损失函数收敛,或迭代训练的次数达到最大次数,或迭代训练的时间达到最大训练时间。
在本实施例中,通过参与设备将上游参与设备传递的上游传递信息,以及参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到RNN输出结果,并将RNN输出结果发送给协调设备,其中,上游传递信息为上游参与设备将上上游参与设备传递的上上游传递信息,以及上游参与设备所处理时间步对应的本地训练数据输入对应的RNN中得到的;接收协调设备反向传播的梯度信息,其中,协调设备根据从各参与设备接收到的RNN输出结果计算得到梯度信息;根据梯度信息更新RNN的模型参数,并将更新后的模型参数发送给协调设备;接收协调设备发送的全局模型参数,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN,其中,协调设备对各参与设备发送的模型参数进行融合得到全局模型参数。在本实施例中,通过在训练RNN时,由多个参与设备分别处理不同时间步,分担训练RNN的计算负担和电量开销,降低了训练RNN对单个设备的计算和电量消耗,使得在设备计算资源和电量资源有限的场景下也可以进行RNN训练,适用于卫星在轨计算或者IoT场景。
此外,在本实施例中,由于不同的参与设备处理不同的时间步数据,自然地适用于不同的参与设备拥有不同的时间步数据的场景。例如,不同的遥感卫星可能有不同的时间维度的照片。同样地,序列数据也可以是空间维度的序列数据。本发明实施例方案就很适合联合这些遥感卫星来训练RNN和使用RNN进行时间序列数据预测或者空间序列数据预测。
此外,在本实施例中,由于各个参与设备之间,以及参与设备和协调设备之间不需传输训练数据,而只需要传输时间步间的少量传递信息以及RNN的输出结果即可,参与设备之间,以及参与设备与协调设备之间传输的信息量很小,不会显著增加额外的通信开销。
如图6所示,为本实施例涉及的一种参与设备与协调设备联合训练RNN的流程示意图。如图7所示,为参与设备之间进行信息传递,以及参与设备向协调设备发送RNN输出结果的示意图。如图8所示,为协调设备向参与设备发送梯度信息以及全局模型参数更新的示意图。如图9所示,为参与设备向协调设备发送更新后的模型参数的示意图。
此外,本发明实施例还提出一种计算机可读存储介质,所述存储介质上存储有循环神经网络训练优化程序,所述循环神经网络训练优化程序被处理器执行时实现如下所述的循环神经网络训练优化方法的步骤。
本发明循环神经网络训练优化设备和计算机可读存储介质的各实施例,均可参照本发明循环神经网络训练优化方法各个实施例,此处不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (9)
1.一种循环神经网络训练优化方法,其特征在于,应用于基于联邦学习训练循环神经网络RNN的协调设备,所述协调设备与各参与设备通信连接,各所述参与设备按照各自处理的时间步的先后顺序通信连接,所述协调设备为遥感卫星或位于地面的基站,所述参与设备为拥有不同时间维度图像数据的遥感卫星,所述循环神经网络训练优化方法包括以下步骤:
接收所述参与设备发送的RNN输出结果,其中,所述RNN输出结果为所述参与设备将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的图像数据输入对应的RNN中得到的,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的图像数据输入对应的RNN中得到的;
根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息;
将所述梯度信息反向传播给所述参与设备,以供各所述参与设备根据所述梯度信息计算所述损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数;
将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN;
其中,所述接收所述参与设备发送的RNN输出结果的步骤之前,还包括:
根据RNN训练样本的时间步数和所述参与设备的数量确定各所述参与设备需处理的时间步;或,
根据各所述参与设备拥有的训练数据所对应的时间步确定各所述参与设备需处理的时间步。
2.如权利要求1所述的循环神经网络训练优化方法,其特征在于,所述RNN是神经网络模型的一部分,所述根据所述RNN输出结果计算得到预设损失函数对所述RNN输出结果的梯度信息的步骤包括:
将所述RNN输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的输出结果;
根据所述神经网络模型的输出结果和预设的样本标签计算预设损失函数对所述RNN输出结果的梯度信息。
3.如权利要求2所述的循环神经网络训练优化方法,其特征在于,所述将从各所述参与设备接收到的更新后的模型参数进行融合得到全局模型参数,并将所述全局模型参数发送给各所述参与设备,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN的步骤之后,还包括:
接收各所述参与设备发送的RNN预测输出结果,其中,各所述参与设备将各自的上游传递信息和所处理时间步对应的本地预测数据输入训练完成的RNN中,得到所述RNN预测输出结果;
将所述RNN预测输出结果输入所述神经网络模型的其他网络部分,得到所述神经网络模型的预测结果。
4.一种循环神经网络训练优化方法,其特征在于,应用于基于联邦学习训练循环神经网络RNN的参与设备,所述参与设备与协调设备通信连接,各所述参与设备按照各自处理的时间步的先后顺序通信连接,所述协调设备为遥感卫星或位于地面的基站,所述参与设备为拥有不同时间维度图像数据的遥感卫星,所述循环神经网络训练优化方法包括以下步骤:
将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的图像数据输入对应的RNN中得到RNN输出结果,并将所述RNN输出结果发送给所述协调设备,其中,所述上游传递信息为所述上游参与设备将上上游参与设备传递的上上游传递信息,以及所述上游参与设备所处理时间步对应的图像数据输入对应的RNN中得到的;
接收所述协调设备反向传播的梯度信息,其中,所述协调设备根据从各所述参与设备接收到的RNN输出结果计算得到所述梯度信息;
根据所述梯度信息计算损失函数对RNN模型参数的梯度信息,并根据所述损失函数对RNN模型参数的梯度信息来更新所述RNN的模型参数,并将更新后的模型参数发送给所述协调设备;
接收所述协调设备发送的全局模型参数,迭代训练直到检测到满足预设停止条件时得到训练完成的RNN,其中,所述协调设备对各所述参与设备发送的模型参数进行融合得到所述全局模型参数;
其中,所述将上游参与设备传递的上游传递信息,以及所述参与设备所处理时间步对应的图像数据输入对应的RNN中得到RNN输出结果的步骤之前,还包括:
根据所拥有的所述图像数据之间的时间维度关系,确定各所述参与设备所处理的时间步。
5.如权利要求4所述的循环神经网络训练优化方法,其特征在于,所述上游传递信息包括所述上游参与设备对应的RNN的输出结果和/或记忆状态。
6.一种循环神经网络训练优化设备,其特征在于,所述循环神经网络训练优化设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的循环神经网络训练优化程序,所述循环神经网络训练优化程序被所述处理器执行时实现如权利要求1至3中任一项所述的循环神经网络训练优化方法的步骤。
7.一种循环神经网络训练优化设备,其特征在于,所述循环神经网络训练优化设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的循环神经网络训练优化程序,所述循环神经网络训练优化程序被所述处理器执行时实现如权利要求4至5中任一项所述的循环神经网络训练优化方法的步骤。
8.一种循环神经网络训练优化系统,其特征在于,所述循环神经网络训练优化系统包括基于联邦学习训练循环神经网络RNN的至少一个协调设备和至少两个参与设备,其中,所述协调设备为权利要求6所述的循环神经网络训练优化设备,所述参与设备为权利要求7所述的循环神经网络训练优化设备。
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有循环神经网络训练优化程序,所述循环神经网络训练优化程序被处理器执行时实现如权利要求1至5中任一项所述的循环神经网络训练优化方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911141081.6A CN111222628B (zh) | 2019-11-20 | 2019-11-20 | 循环神经网络训练优化方法、设备、系统及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911141081.6A CN111222628B (zh) | 2019-11-20 | 2019-11-20 | 循环神经网络训练优化方法、设备、系统及可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111222628A CN111222628A (zh) | 2020-06-02 |
CN111222628B true CN111222628B (zh) | 2023-09-26 |
Family
ID=70827645
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911141081.6A Active CN111222628B (zh) | 2019-11-20 | 2019-11-20 | 循环神经网络训练优化方法、设备、系统及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111222628B (zh) |
Families Citing this family (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111737921B (zh) * | 2020-06-24 | 2024-04-26 | 深圳前海微众银行股份有限公司 | 基于循环神经网络的数据处理方法、设备及介质 |
CN111737920B (zh) * | 2020-06-24 | 2024-04-26 | 深圳前海微众银行股份有限公司 | 基于循环神经网络的数据处理方法、设备及介质 |
CN111737922A (zh) * | 2020-06-24 | 2020-10-02 | 深圳前海微众银行股份有限公司 | 基于循环神经网络的数据处理方法、装置、设备及介质 |
CN114363921B (zh) * | 2020-10-13 | 2024-05-10 | 维沃移动通信有限公司 | Ai网络参数的配置方法和设备 |
CN112564974B (zh) * | 2020-12-08 | 2022-06-14 | 武汉大学 | 一种基于深度学习的物联网设备指纹识别方法 |
CN112865116B (zh) * | 2021-01-11 | 2022-04-12 | 广西大学 | 一种平行联邦图神经网络的十三区图无功优化方法 |
CN112836816B (zh) * | 2021-02-04 | 2024-02-09 | 南京大学 | 一种适用于光电存算一体处理单元串扰的训练方法 |
CN112733967B (zh) * | 2021-03-30 | 2021-06-29 | 腾讯科技(深圳)有限公司 | 联邦学习的模型训练方法、装置、设备及存储介质 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108197701A (zh) * | 2018-02-05 | 2018-06-22 | 哈工大机器人(合肥)国际创新研究院 | 一种基于rnn的多任务学习方法 |
CN109325584A (zh) * | 2018-08-10 | 2019-02-12 | 深圳前海微众银行股份有限公司 | 基于神经网络的联邦建模方法、设备及可读存储介质 |
CN109447244A (zh) * | 2018-10-11 | 2019-03-08 | 中山大学 | 一种结合门控循环单元神经网络的广告推荐方法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11256990B2 (en) * | 2016-05-20 | 2022-02-22 | Deepmind Technologies Limited | Memory-efficient backpropagation through time |
EP3446259A1 (en) * | 2016-05-20 | 2019-02-27 | Deepmind Technologies Limited | Training machine learning models |
-
2019
- 2019-11-20 CN CN201911141081.6A patent/CN111222628B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108197701A (zh) * | 2018-02-05 | 2018-06-22 | 哈工大机器人(合肥)国际创新研究院 | 一种基于rnn的多任务学习方法 |
CN109325584A (zh) * | 2018-08-10 | 2019-02-12 | 深圳前海微众银行股份有限公司 | 基于神经网络的联邦建模方法、设备及可读存储介质 |
CN109447244A (zh) * | 2018-10-11 | 2019-03-08 | 中山大学 | 一种结合门控循环单元神经网络的广告推荐方法 |
Non-Patent Citations (1)
Title |
---|
循环神经网络在语音识别模型中的训练加速方法;冯诗影;韩文廷;金旭;迟孟贤;安虹;;小型微型计算机系统(12);第3-7页 * |
Also Published As
Publication number | Publication date |
---|---|
CN111222628A (zh) | 2020-06-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111222628B (zh) | 循环神经网络训练优化方法、设备、系统及可读存储介质 | |
US11544573B2 (en) | Projection neural networks | |
CN111784002B (zh) | 分布式数据处理方法、装置、计算机设备及存储介质 | |
EP3711000B1 (en) | Regularized neural network architecture search | |
US11244243B2 (en) | Coordinated learning using distributed average consensus | |
CN111625361A (zh) | 一种基于云端服务器和IoT设备协同的联合学习框架 | |
JP7383803B2 (ja) | 不均一モデルタイプおよびアーキテクチャを使用した連合学習 | |
CN113505882B (zh) | 基于联邦神经网络模型的数据处理方法、相关设备及介质 | |
WO2022257730A1 (zh) | 实现隐私保护的多方协同更新模型的方法、装置及系统 | |
Han et al. | A deep reinforcement learning based solution for flexible job shop scheduling problem | |
CN113159283A (zh) | 一种基于联邦迁移学习的模型训练方法及计算节点 | |
US20210065011A1 (en) | Training and application method apparatus system and stroage medium of neural network model | |
CN111222046B (zh) | 服务配置方法、用于服务配置的客户端、设备及电子设备 | |
CN113011603A (zh) | 模型参数更新方法、装置、设备、存储介质及程序产品 | |
Kondratenko et al. | Multi-criteria decision making and soft computing for the selection of specialized IoT platform | |
CN112948885B (zh) | 实现隐私保护的多方协同更新模型的方法、装置及系统 | |
CN111931901A (zh) | 一种神经网络构建方法以及装置 | |
WO2023213157A1 (zh) | 数据处理方法、装置、程序产品、计算机设备和介质 | |
CN114240506A (zh) | 多任务模型的建模方法、推广内容处理方法及相关装置 | |
US20230229963A1 (en) | Machine learning model training | |
CN112532251A (zh) | 一种数据处理的方法及设备 | |
CN114445692B (zh) | 图像识别模型构建方法、装置、计算机设备及存储介质 | |
CN115965078A (zh) | 分类预测模型训练方法、分类预测方法、设备及存储介质 | |
CN115660116A (zh) | 基于稀疏适配器的联邦学习方法及系统 | |
CN111967612A (zh) | 横向联邦建模优化方法、装置、设备及可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |