WO2023142439A1 - 一种模型梯度更新方法及装置 - Google Patents

一种模型梯度更新方法及装置 (Model gradient update method and apparatus)

Info

Publication number
WO2023142439A1
WO2023142439A1 (PCT/CN2022/112615, CN2022112615W)
Authority
WO
WIPO (PCT)
Prior art keywords
gradient
update process
node
gradient update
probability
Prior art date
Application number
PCT/CN2022/112615
Other languages
English (en)
French (fr)
Inventor
程栋
程新
周雍恺
高鹏飞
姜铁城
Original Assignee
中国银联股份有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 中国银联股份有限公司 (China UnionPay Co., Ltd.)
Publication of WO2023142439A1

Classifications

    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06N: COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 20/00: Machine learning
    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06F: ELECTRIC DIGITAL DATA PROCESSING
    • G06F 8/00: Arrangements for software engineering
    • G06F 8/60: Software deployment
    • G06F 8/65: Updates

Definitions

  • the present application relates to the technical field of model training, in particular to a method and device for updating model gradients.
  • Horizontal federated learning, also known as sample-partitioned federated learning, applies to scenarios where the data sets of the participants in federated learning have the same features but different samples. In a horizontal federated learning system, multiple participants with the same data features collaboratively train a model with the help of a central server; the main steps are as follows.
  • each participant calculates the gradient of the model locally, and sends the gradient (the gradient needs to be encrypted) to the central server.
  • the central server aggregates multiple gradients.
  • the central server sends the aggregated gradient (the gradient also needs to be encrypted) to each participant.
  • Each participant uses the received gradients to update their respective model parameters.
  • at present, in the horizontal federated learning scenario, the central server aggregates the gradients by averaging them; because the performance of the participants differs, a model trained by plain gradient averaging performs poorly.
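  • For context, the following is a minimal Python sketch of the plain gradient-averaging aggregation described above, i.e. the baseline this application improves on; the function and variable names are illustrative and encryption of the gradients is omitted.

```python
import numpy as np

def average_aggregate(first_gradients):
    """Baseline horizontal federated aggregation: the central server simply
    averages the gradients reported by the participants."""
    return np.mean(np.stack(first_gradients), axis=0)

# three participants report their locally computed gradients
grads = [np.array([0.1, -0.2]), np.array([0.3, 0.0]), np.array([0.2, -0.1])]
aggregated = average_aggregate(grads)  # sent back to every participant
```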
  • the present application provides a method and device for updating model gradients to improve the accuracy of model training.
  • a model gradient update method which is applied to the central server, including:
  • the central server repeatedly executes the gradient update process until the stop condition is met; wherein, one gradient update process includes:
  • receiving first gradients respectively sent by multiple nodes, where each first gradient is obtained by a node training the model to be trained in that node one or more times on its sample data; obtaining a second gradient based on the multiple first gradients and the probability of each node in the current gradient update process, where the probability of each node in the current gradient update process is determined by the Actor-Critic network based on the probability of each node in the previous gradient update process;
  • the second gradients are respectively sent to the multiple nodes, so that the multiple nodes use the second gradients to update the weights of their respective models to be trained.
  • the Actor-Critic network includes an Actor network, at least one Critic network, and a reward function;
  • the reward function is used to determine a reward value based on the probabilities of the plurality of nodes determined in the last gradient update process, and transmit the reward value to the at least one critic network;
  • the at least one Critic network is used to determine a target Q value and transmit the target Q value to the Actor network;
  • the Actor network is used to determine the probability of each node in this gradient update process based on the target Q value.
  • the target Q value is the smallest Q value among the Q values determined by multiple critic networks.
  • the reward function is defined in terms of A, B and g, where A is the first accuracy rate, B is the second accuracy rate, and g is greater than or equal to 1; the first accuracy rate is the accuracy of the trained model obtained by the central server and the plurality of nodes based on the federated averaging algorithm; the second accuracy rate is the average of the third accuracy rates respectively sent by the plurality of nodes; and each third accuracy rate is obtained in the same local training run of the node's model to be trained as the corresponding first gradient.
  • the Actor-Critic network includes 3 Critic networks; for any Critic network, the Q value determined in the current gradient update process is determined based on a Q-value gradient and the Q value determined in the previous gradient update process, and the Q-value gradient is determined based on a first parameter J.
  • the α used in the current gradient update process is determined based on an α gradient and the α used in the previous gradient update process, where J(α) is the α gradient, α_{t-1} is the α used in the previous gradient update, and H is the ideal minimum expected entropy.
  • the k and l are determined based on variances of accuracy rates of models in the plurality of nodes before gradient updates are performed.
  • the node probabilities output by the Actor network in the current gradient update process are determined based on a probability gradient and the node probabilities output in the previous gradient update process, where:
  • J(π_φ) is the probability gradient;
  • t is the index of the current gradient update;
  • θ_1, θ_2, θ_3 denote the three Critic networks, and θ_i denotes the one whose most recently determined Q value is the minimum of the three;
  • s_t is the state in the t-th gradient update process;
  • a_t is the set of probabilities of the plurality of nodes in the t-th gradient update process;
  • π_t(a_t | s_t) is the probability of taking a_t in state s_t;
  • q is the entropic index, ln_q is the (Tsallis) entropy term, α_t is the α used in the current gradient update, and α_t is not 0.
  • the central server and the multiple nodes perform gradient updates based on a federated learning architecture.
  • the embodiment of the present application provides a model gradient update device, including:
  • the gradient update process is repeatedly executed until the stop condition is met, where one gradient update process includes:
  • a receiving module configured to receive first gradients sent by multiple nodes, the first gradients are obtained by each node using sample data to train the model to be trained in the node one or more times;
  • a processing module configured to obtain a second gradient based on a plurality of first gradients and the probability of each node in the current gradient update process, where the probability of each node in the current gradient update process is determined by the Actor-Critic network based on the probability of each node in the previous gradient update process;
  • a sending module configured to send the second gradients to the multiple nodes respectively, so that the multiple nodes use the second gradients to update the weights of their respective models to be trained.
  • the embodiment of the present application provides a model gradient update device, including a processor and a memory;
  • said memory for storing computer programs or instructions
  • the processor is configured to execute part or all of the computer programs or instructions in the memory, and when the part or all of the computer programs or instructions are executed, it is used to implement the model gradient update method described in any one of the above.
  • An embodiment of the present application provides a computer-readable storage medium for storing a computer program, where the computer program includes instructions for implementing any one of the model gradient updating methods.
  • by taking the probability of each node into account, this application can optimize node participation, so that the resulting model is better.
  • FIG. 1 is a schematic diagram of a model gradient update process provided by the present application
  • FIG. 2 is an architecture diagram of a model gradient update system provided by the present application.
  • FIG. 3 is a structural diagram of a model gradient updating device provided by the present application.
  • FIG. 4 is a structural diagram of a model gradient updating device provided by the present application.
  • the central server repeats the gradient update process until the stop condition is met.
  • the stop condition is, for example, that the loss function converges, or the upper limit of the allowed iteration number is reached, or the allowed training time is reached.
  • Step 101 multiple nodes (for example, node 1, node 2, ... node m) respectively send the first gradient to the central server.
  • the central server receives the first gradients respectively sent by the multiple nodes.
  • the first gradient is obtained by each node using sample data to train the model to be trained in the node one or more times.
  • the gradient sent by a node to the central server is called the first gradient, and the first gradients sent by multiple nodes may be the same or different.
  • the nodes may be vehicle terminals, and the sample data may be data generated during autonomous driving. Different driving data are generated while vehicles are driven; these data are scattered across the nodes (vehicle terminals) and are characterized by unbalanced data quality and unbalanced node performance.
  • the model may be a model used for autonomous driving or for user-habit-related tasks.
  • Step 102 Actor-Critic network determines the probability of each node in this gradient update process based on the probability of each node in the last gradient update process.
  • it can be understood that, if the current gradient update is the first one, the probability of each node in the previous gradient update process is the initially set probability of each node.
  • the sum of the probabilities of multiple nodes can be 1.
  • the Actor-Critic network may output the probabilities of the multiple nodes once per cycle, or only output node probabilities after several cycles.
  • the Actor-Critic network may or may not be located on the central server.
  • Step 103 The central server obtains the second gradient based on multiple first gradients and multiple probabilities.
  • the second gradient can be determined from the multiple first gradients and the multiple probabilities by weighted averaging; for example, if node 1, node 2 and node 3 respectively send first gradients p1, p2 and p3, and the probabilities of the three nodes are 0.2, 0.4 and 0.4, then the second gradient is the value of 0.2*p1 + 0.4*p2 + 0.4*p3.
  • This process can also be referred to as a data fusion algorithm based on deep reinforcement learning, where multiple first gradients and multiple probabilities are used to obtain the second gradient.
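  • As a concrete illustration of the weighted-average fusion in step 103, here is a minimal sketch; the names are illustrative, and the probabilities are assumed to come from the Actor-Critic network, which this example simply hard-codes.

```python
import numpy as np

def fuse_gradients(first_gradients, node_probabilities):
    """Step 103 (sketch): the second gradient is the probability-weighted sum of
    the first gradients; the probabilities are assumed to sum to 1."""
    return np.sum([p * g for p, g in zip(node_probabilities, first_gradients)], axis=0)

p1, p2, p3 = np.array([1.0, 2.0]), np.array([0.5, 0.5]), np.array([2.0, 1.0])
second_gradient = fuse_gradients([p1, p2, p3], [0.2, 0.4, 0.4])
# equals 0.2*p1 + 0.4*p2 + 0.4*p3, matching the numeric example above
```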
  • Step 104 The central server sends the second gradient to the multiple nodes (for example node 1, node 2, ... node m) respectively; correspondingly, the multiple nodes receive the second gradient from the central server.
  • the gradient sent by the central server to the nodes is called the second gradient, and the second gradients sent to multiple nodes are the same.
  • Step 105 multiple nodes update the weights of their respective models to be trained by using the second gradient.
  • if the stop condition is met, the updated model is the trained model.
  • if the stop condition is not met, each node may use its sample data to train the model to be trained in that node one or more times to obtain a new first gradient, and steps 101 - 105 continue to be executed repeatedly.
  • the node participation can be optimized to make the determined model better.
  • the central server and the multiple nodes perform gradient updates based on a federated learning architecture.
  • the sample data of each node is private and not shared with other nodes and the central server, and when the nodes and the central server transmit the first gradient and the second gradient, the first gradient and the second gradient are encrypted.
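  • Putting steps 101 - 105 together, the following is a minimal self-contained sketch of the repeated gradient update loop on the server side; the Node class is a toy stand-in for a vehicle terminal, the Actor-Critic update is replaced by a placeholder callback, encryption is omitted, and the stop condition is simply a round limit. None of these names come from the patent.

```python
import numpy as np

class Node:
    """Toy stand-in for a vehicle-terminal node: 'training' returns a noisy
    gradient, and applying the second gradient updates a local weight vector."""
    def __init__(self, seed):
        self.rng = np.random.default_rng(seed)
        self.weights = np.zeros(2)

    def train_and_get_gradient(self):                     # step 101 (local training)
        return self.rng.normal(size=2)

    def apply_gradient(self, second_gradient, lr=0.1):    # step 105 (weight update)
        self.weights -= lr * second_gradient

def run_training(nodes, next_probabilities, max_rounds=10):
    """Repeat the gradient update process (steps 101 - 105) until a stop
    condition is met; here the stop condition is a fixed number of rounds."""
    probs = np.full(len(nodes), 1.0 / len(nodes))         # initial node probabilities
    for _ in range(max_rounds):
        first_grads = [n.train_and_get_gradient() for n in nodes]          # step 101
        probs = next_probabilities(probs)                                  # step 102
        second_grad = np.sum([p * g for p, g in zip(probs, first_grads)],  # step 103
                             axis=0)
        for n in nodes:                                                    # steps 104-105
            n.apply_gradient(second_grad)

run_training([Node(s) for s in range(3)], next_probabilities=lambda p: p)
```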
  • as shown in FIG. 2, a gradient update system architecture is introduced, including multiple nodes, a central server, and an Actor-Critic network.
  • Actor-Critic networks may or may not be located on a central server.
  • as the name suggests, Actor-Critic consists of two parts: an actor (Actor) and an evaluator (Critic).
  • Actor is responsible for generating actions (Action) and interacting with the environment.
  • the Critic is responsible for evaluating the performance of the Actor and guiding the Actor's actions in the next stage.
  • the Actor-Critic network includes an Actor network, at least one Critic network, and a reward function;
  • the reward function is used to determine a reward value based on the probabilities of the plurality of nodes determined in the last gradient update process, and transmit the reward value to the at least one critic network;
  • the at least one Critic network is used to determine the target Q value and transmit the target Q value to the Actor network.
  • Each Critic network determines a Q value. If there is a single Critic network, the Q value it determines is the target Q value; if there are multiple Critic networks, one Q value can be selected from the multiple Q values as the target Q value, for example the minimum of the Q values determined by the multiple Critic networks. Having multiple Critic networks amounts to having multiple evaluators, so the Actor's performance is assessed more accurately, the Actor's actions become more accurate, and the resulting node probabilities better reflect the performance of the nodes.
  • the Actor network determines the probability of each node in the current gradient update process based on the target Q value, and the probabilities of the plurality of nodes are passed to the reward function; the loop repeats until the process stops.
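  • The following is a schematic sketch of the Actor-Critic interaction just described (state in, node probabilities out, and the minimum over several Critics used as the target Q value). The network classes, layer sizes and the use of PyTorch are illustrative assumptions, not the patent's implementation.

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Scores a (state, action) pair; here the action is the vector of node probabilities."""
    def __init__(self, state_dim, num_nodes):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim + num_nodes, 64),
                                 nn.ReLU(), nn.Linear(64, 1))

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))

class Actor(nn.Module):
    """Outputs one probability per node; the softmax makes them sum to 1."""
    def __init__(self, state_dim, num_nodes):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64),
                                 nn.ReLU(), nn.Linear(64, num_nodes))

    def forward(self, state):
        return torch.softmax(self.net(state), dim=-1)

state_dim, num_nodes = 4, 3
actor = Actor(state_dim, num_nodes)
critics = [Critic(state_dim, num_nodes) for _ in range(3)]   # three evaluators

state = torch.randn(1, state_dim)                 # e.g. per-node accuracies
action = actor(state)                             # node probabilities for this round
q_values = torch.cat([c(state, action) for c in critics], dim=-1)
target_q = q_values.min(dim=-1).values            # target Q = minimum over the Critics
```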
  • This application may use an existing Actor-Critic network to determine the probabilities, or may improve the existing Actor-Critic network, for example by setting up multiple Critic networks, improving the algorithm used in the Critic networks, improving the algorithm used in the Actor network, or improving the reward function. It can be understood that these improvements only concern specific details; the operating mechanism of the improved Actor-Critic network is similar to that of the existing Actor-Critic network.
  • the parameters involved in the Actor-Critic network can be initialized first, including but not limited to: the parameters involved in the Critic network, the parameters involved in the Actor network, and the parameters involved in the reward function.
  • the central server and multiple nodes obtain the trained model based on the federated average learning algorithm, and determine the first accuracy rate A of the trained model.
  • Node 1, node 2, ..., node m each perform one or more training runs on their currently saved model (which may also be called the model to be trained; for example, it may be a model that has not yet gone through the gradient update process of steps 101 - 105, or a model that has already gone through that process one or more times) to obtain a first gradient and a third accuracy rate B', and each node sends its first gradient and third accuracy rate B' to the central server; the third accuracy rate and the first gradient are obtained in the same model training run at the node.
  • the central server calculates the average value of multiple third accuracy rates B' to obtain the second accuracy rate B.
  • the reward function determines the reward value r based on the first accuracy rate A and the second accuracy rate B, where the first accuracy rate is the accuracy of the trained model obtained by the central server and the multiple nodes based on federated learning, the second accuracy rate is the average of the third accuracy rates respectively sent by the plurality of nodes, and each third accuracy rate is obtained in the same local training run of the node's model to be trained as the corresponding first gradient.
  • the reward function is expressed in terms of the ratio of B to A: the higher B/A is, the higher the reward value r.
  • it can be understood that, in any gradient update process, the value of A is the same, while the value of B may be the same or different. g is greater than or equal to 1.
  • optionally, this application sets two reward functions: when B/A is greater than 1, g takes a constant value greater than 1, which provides strong guidance so that gradient training completes faster; when B/A is less than or equal to 1, g is set to 1.
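  • The exact reward formula is reproduced only as an image in the publication; the sketch below therefore assumes one plausible form consistent with the surrounding text, namely r = (B/A)^g with g a constant greater than 1 when B/A > 1 and g = 1 otherwise. The functional form, the constant 2.0 and the names are assumptions.

```python
def reward(A, B, strong_g=2.0):
    """Hypothetical reward consistent with the description: monotonically
    increasing in B/A, with a stronger exponent when the ratio exceeds 1."""
    ratio = B / A
    g = strong_g if ratio > 1 else 1.0
    return ratio ** g

r = reward(A=0.80, B=0.84)   # B/A = 1.05 > 1, so the stronger guidance applies
```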
  • In one example, the Q-value gradient is determined based on a first algorithm and a second algorithm, where the first algorithm has the characteristic of being biased toward one specific action during training, and the second algorithm has the characteristic of selecting the various actions in a balanced manner during training.
  • the first algorithm may be the deep deterministic policy gradient (DDPG) algorithm, and the second algorithm may be a SAC (soft actor-critic) algorithm, for example the SAC with automatic entropy adjustment (SAC-AEA) reinforcement learning algorithm.
  • DDPG is constrained by its own update strategy and tends to favor one specific action late in training, which is unfavorable for scheduling the probabilities of multiple nodes and for achieving holistic model fusion: the finally trained model becomes highly correlated with the data of one particular node, the model data of the other nodes contribute less to the result, the utilization efficiency of multi-party data drops sharply, and the training result may even be poor or overfit.
  • the SAC-AEA reinforcement learning algorithm itself can select actions in a relatively balanced manner.
  • however, in an actual federated learning framework the nodes differ in data quality, contribution to the model, and local computing power (such as the computing efficiency of the local device), so that in this application they can be described as advantaged and disadvantaged parties; fusing their data in a fully balanced way is clearly unfavorable for improving the training result and may lead to underfitting, so the model cannot fully represent the complete data characteristics.
  • the Q value determined by the Critic network described in this application is updated based on the DDPG algorithm and the SAC algorithm. Combining the DDPG algorithm and the SAC algorithm can combine the advantages of the two algorithms, so that the trained model can integrate the performance of multiple nodes, and the model is better.
  • the present application may set a first weight for the DDPG algorithm and a second weight for the SAC algorithm, and determine the Q-value gradient based on the DDPG algorithm with the first weight and the SAC algorithm with the second weight, where the first weight and the second weight are determined based on the variance of the accuracy rates of the models in the plurality of nodes before the gradient update.
  • In one example, a composite, adaptively weighted Q-value update is used in the Critic networks: the Q-value gradient is determined based on the first parameter J, for example as the product of the first parameter J and a non-zero step size. Taking 3 Critic networks as an example, the first parameter J is defined using the following notation.
  • t is the number of gradient updates.
  • t is an integer greater than or equal to 1.
  • k > 0, l > 0, and k + l = 1; k and l may be fixed values, manually set values, or values determined based on the variance of the accuracy rates of the models in the plurality of nodes before the gradient update. The variance of the accuracy rates reflects, to some extent, the performance (e.g. computing power) differences between the nodes: the larger the variance, the larger the performance differences; conversely, the smaller the variance, the smaller the differences among the multiple nodes.
  • θ_1, θ_2 and θ_3 denote the three Critic networks, and θ_i denotes the network whose most recently determined Q value is the minimum among the three. It can be understood that, within the current gradient update, the most recently determined Q values of the three Critic networks are those determined in the previous gradient update (if the current update is the t-th, the previous one is the (t-1)-th);
  • s_t is the state in the t-th gradient update process; s_t may be the accuracy (or average accuracy) of each node, the gradient (or average gradient) of each node, or the variance of each node;
  • a_t is the set of probabilities of the plurality of nodes in the t-th gradient update process; a_t may also be called an action;
  • r(s_t, a_t) is the reward value under s_t, a_t in the t-th gradient update process;
  • γ is the decay factor, and γ is greater than 0;
  • π_t(a_t | s_t) is the conditional probability of taking a_t in state s_t;
  • q is the entropic index (e.g. of Tsallis entropy), an integer greater than or equal to 1 such as 1, 2 or 3; ln_q is the corresponding entropy term, a curve, and different values of q give a family of curves;
  • E is the mathematical expectation over the bracketed quantity; the independent variables are s_t and a_t, and the bracketed content is an implicit expression in s_t and a_t;
  • D is the memory bank (also called the experience replay pool or buffer), and (s, a) ~ D refers to the s, a stored in the memory bank. Assuming the memory bank D can store M states s and M actions a and is overwritten cyclically, the Actor-Critic network may first loop M times by itself to collect M states and M actions and only output node probabilities at the (M+1)-th cycle; in the first M cycles the Actor-Critic network does not output node probabilities, or those outputs can be ignored;
  • a ~ π_φ means that π_φ is determined from a; for example, multiple values of a form the π_φ curve (or set).
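  • The formula for the first parameter J is likewise published only as an image, so the Python sketch below is a hedged illustration rather than the patent's definition: it derives k and l from the variance of the nodes' pre-update accuracies using an arbitrary monotone mapping (the patent does not specify one), and mixes a DDPG-style target with a SAC-style, Tsallis-entropy-regularized target built on the minimum Q value over the three Critic networks. All names, mappings and functional forms are assumptions.

```python
import numpy as np

def q_log(x, q):
    """q-logarithm used with Tsallis entropy; reduces to ln(x) as q -> 1."""
    return np.log(x) if q == 1 else (x ** (1 - q) - 1) / (1 - q)

def weights_from_accuracy_variance(accuracies, scale=10.0):
    """Assumed mapping: variance of node accuracies -> (k, l) with k + l = 1.
    Larger variance (bigger gaps between nodes) shifts weight toward the
    DDPG-style term; the direction and shape of this mapping are guesses."""
    k = 0.5 + 0.5 * np.tanh(scale * float(np.var(accuracies)))
    return k, 1.0 - k

def composite_critic_target(q_next_per_critic, reward, pi, alpha, gamma, k, l, q=2):
    """Hypothetical composite target: k * (DDPG-style term) + l * (SAC-style term
    with a Tsallis entropy bonus), both using the minimum Q over the critics."""
    q_min = min(q_next_per_critic)                    # target Q = min over the 3 critics
    ddpg_target = reward + gamma * q_min              # no entropy term
    sac_target = reward + gamma * (q_min - alpha * q_log(pi, q))
    return k * ddpg_target + l * sac_target

k, l = weights_from_accuracy_variance([0.81, 0.74, 0.90])
target = composite_critic_target(q_next_per_critic=[1.2, 0.9, 1.1], reward=0.5,
                                 pi=0.3, alpha=0.2, gamma=0.99, k=k, l=l)
# each Critic theta_i would then be trained to reduce (Q_theta_i(s_t, a_t) - target)**2
```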
  • α_t is the α used in the t-th gradient update process; α may be a fixed non-zero value, or a variable (i.e. an adaptive parameter). The (updated) α used in the current gradient update is determined based on an α gradient and the (pre-update) α used in the previous gradient update; if the current gradient update is the first one, the previously used α is the initially set α. In the α gradient, J(α) is the α gradient, α_{t-1} is the α used in the previous gradient update, and H is the ideal minimum expected entropy; the entropy is, for example, Tsallis entropy.
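  • The α-gradient formula is also an image in the publication; as a hedged illustration, here is the standard SAC automatic-entropy-adjustment style of update that the description appears to follow, with illustrative names and a plain logarithm standing in for the Tsallis q-logarithm.

```python
import numpy as np

def alpha_step(alpha_prev, log_pi_samples, target_entropy_H, lr=1e-3):
    """Hypothetical temperature update: J(alpha) ~ E[-alpha * (log_pi + H)],
    so the gradient with respect to alpha is -E[log_pi + H] and alpha is moved
    one step against it (alpha_t = alpha_{t-1} - step * gradient)."""
    grad = -np.mean(np.asarray(log_pi_samples) + target_entropy_H)
    return alpha_prev - lr * grad

alpha_t = alpha_step(alpha_prev=0.2,
                     log_pi_samples=np.log([0.2, 0.5, 0.3]),
                     target_entropy_H=1.0)
```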
  • with this composite, adaptively weighted Q-value update, the gradient weight parameters can be adjusted adaptively: depending on the actual scenario, only the weight parameters need to be tuned, toward selecting a specific action or toward balanced action selection, so that the model information of the participants can be scheduled according to the concrete situation.
  • the node probabilities output by the Actor network this time (i.e. the updated probabilities) are the sum of the node probabilities output last time (i.e. the pre-update probabilities) and the probability gradient.
  • the Tsallis entropy concept and the adaptive parameter are integrated into the Actor network, with the probability gradient determined using the following notation:
  • J(π_φ) is the probability gradient;
  • α_t is the α obtained in the current (t-th) gradient update process;
  • θ_i denotes, among θ_1, θ_2 and θ_3, the network whose most recently determined Q value is the minimum. It can be understood that, here, the most recently determined Q values of the three Critic networks are the Q values determined during the current (for example the t-th) gradient update.
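  • The probability-gradient formula is again an image; under the same caveats, the sketch below shows how an Actor update of this general shape might look: the node probabilities are nudged along a finite-difference gradient of the minimum-critic Q value plus an α-weighted Tsallis-style entropy bonus, and the new probabilities are the old probabilities plus that probability gradient, re-normalized to sum to 1. Everything here is an assumption for illustration.

```python
import numpy as np

def q_log(x, q=2):
    """q-logarithm (Tsallis); reduces to ln(x) as q -> 1."""
    return np.log(x) if q == 1 else (x ** (1 - q) - 1) / (1 - q)

def actor_objective(probs, q_min_per_node, alpha, q=2):
    """Assumed actor objective: expected minimum-critic Q value of the chosen
    node probabilities plus an alpha-weighted Tsallis-style entropy bonus."""
    return float(np.dot(probs, q_min_per_node) - alpha * np.dot(probs, q_log(probs, q)))

def actor_step(probs, q_min_per_node, alpha, lr=0.05, eps=1e-4):
    """Finite-difference gradient ascent: new probabilities = old probabilities
    plus the probability gradient, then clipped and re-normalized."""
    grad = np.zeros_like(probs)
    base = actor_objective(probs, q_min_per_node, alpha)
    for i in range(len(probs)):
        bumped = probs.copy()
        bumped[i] += eps
        grad[i] = (actor_objective(bumped / bumped.sum(), q_min_per_node, alpha) - base) / eps
    new_probs = np.clip(probs + lr * grad, 1e-6, None)
    return new_probs / new_probs.sum()     # node probabilities still sum to 1

probs = np.array([1 / 3, 1 / 3, 1 / 3])
probs = actor_step(probs, q_min_per_node=np.array([0.9, 1.2, 1.0]), alpha=0.2)
```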
  • on the server side of the federated framework, the data fusion algorithm across different nodes is adjusted, and an optimized fusion strategy that combines the deep-reinforcement-learning data fusion model with the federated averaging algorithm is designed.
  • the participation of different nodes and the degree of data utilization during training can be adjusted.
  • the method in the embodiment of the present application is introduced above, and the device in the embodiment of the present application will be introduced in the following.
  • the method and the device are based on the same technical concept; since the principles by which the method and the device solve the problem are similar, their implementations can refer to each other, and repeated content is not described again.
  • the embodiment of the present application may divide the device into functional modules according to the above method example, for example, each function may be divided into each functional module, or two or more functions may be integrated into one module.
  • These modules can be implemented not only in the form of hardware, but also in the form of software function modules. It should be noted that the division of modules in the embodiment of the present application is schematic, and is only a logical function division, and there may be other division methods during specific implementation.
  • a model gradient update device including:
  • the receiving module 301 is configured to receive first gradients respectively sent by multiple nodes, and the first gradients are obtained by each node using sample data to train the model to be trained in the node one or more times;
  • the processing module 302 is configured to obtain the second gradient based on a plurality of first gradients and the probability of each node in the current gradient update process, where the probability of each node in the current gradient update process is determined by the Actor-Critic network based on the probability of each node in the previous gradient update process;
  • the sending module 303 is configured to send the second gradients to the multiple nodes respectively, so that the multiple nodes use the second gradients to update the weights of their respective models to be trained.
  • the above process is the gradient update process once, and the gradient update process is repeatedly executed until the stop condition is satisfied.
  • a model gradient update device including a processor 401 and a memory 402, and optionally, a transceiver 403;
  • the memory 402 is used to store computer programs or instructions
  • the processor 401 is configured to execute part or all of the computer programs or instructions in the memory, and when the part or all of the computer programs or instructions are executed, it is used to implement the model gradient update method described in any one of the above .
  • the transceiver 403 performs receiving and sending actions
  • the processor 401 performs other actions except the receiving and sending actions.
  • An embodiment of the present application provides a computer-readable storage medium for storing a computer program, where the computer program includes instructions for implementing any one of the model gradient updating methods.
  • the embodiment of the present application also provides a computer program product, including: computer program code, when the computer program code is run on the computer, the computer can execute the method for updating the model gradient provided above.
  • the embodiment of the present application also provides a communication system, and the communication system includes: a node and a central server that execute the above method for updating model gradients.
  • the processor mentioned in the embodiments of the present application may be a central processing unit (CPU) or a baseband processor (the baseband processor and the CPU may be integrated or separate), or may be a network processor (NP) or a combination of a CPU and an NP. The processor may further include a hardware chip or another general-purpose processor.
  • the aforementioned hardware chip may be an application-specific integrated circuit (application-specific integrated circuit, ASIC), a programmable logic device (programmable logic device, PLD) or a combination thereof.
  • the above PLD can be complex programmable logic device (complex programmable logic device, CPLD), field programmable logic gate array (field-programmable gate array, FPGA), general array logic (generic array logic, GAL) and other programmable logic devices , discrete gate or transistor logic devices, discrete hardware components, etc., or any combination thereof.
  • a general-purpose processor may be a microprocessor, or the processor may be any conventional processor, or the like.
  • the memory mentioned in the embodiments of the present application may be a volatile memory or a nonvolatile memory, or may include both volatile and nonvolatile memories.
  • the non-volatile memory can be read-only memory (Read-Only Memory, ROM), programmable read-only memory (Programmable ROM, PROM), erasable programmable read-only memory (Erasable PROM, EPROM), electronically programmable Erase Programmable Read-Only Memory (Electrically EPROM, EEPROM) or Flash.
  • the volatile memory can be Random Access Memory (RAM), which acts as external cache memory.
  • by way of example and not limitation, many forms of RAM are available, such as static RAM (SRAM), dynamic RAM (DRAM), synchronous DRAM (SDRAM), double data rate SDRAM (DDR SDRAM), enhanced SDRAM (ESDRAM), Synchlink DRAM (SLDRAM) and direct Rambus RAM (DR RAM). It should be noted that the memories described in the present application are intended to include, but are not limited to, these and any other suitable types of memory.
  • the transceiver mentioned in the embodiment of the present application may include a separate transmitter and/or a separate receiver, or the transmitter and the receiver may be integrated. Transceivers can operate under the direction of the corresponding processor.
  • the transmitter may correspond to the transmitter in the physical device
  • the receiver may correspond to the receiver in the physical device.
  • the disclosed systems, devices and methods may be implemented in other ways.
  • the device embodiments described above are only illustrative.
  • the division of the units is only a logical function division; in actual implementation there may be other division methods, for example multiple units or components may be combined or integrated into another system, or some features may be ignored or not implemented.
  • the mutual coupling or direct coupling or communication connection shown or discussed may be indirect coupling or communication connection through some interfaces, devices or units, and may also be electrical, mechanical or other forms of connection.
  • the units described as separate components may or may not be physically separated, and the components shown as units may or may not be physical units, that is, they may be located in one place, or may be distributed to multiple network units. Part or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment of the present application.
  • each functional unit in each embodiment of the present application may be integrated into one processing unit, each unit may exist separately physically, or two or more units may be integrated into one unit.
  • the above-mentioned integrated units can be implemented in the form of hardware or in the form of software functional units.
  • the integrated unit is realized in the form of a software function unit and sold or used as an independent product, it can be stored in a computer-readable storage medium.
  • the technical solution of the present application, in essence or in the part contributing to the prior art, or all or part of the technical solution, may be embodied in the form of a software product; the computer software product is stored in a storage medium and includes several instructions for causing a computer device (which may be a personal computer, a server, a network device, etc.) to execute all or part of the steps of the methods described in the embodiments of the present application.
  • the aforementioned storage medium includes various media that can store program code, such as a USB flash drive, a removable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disc.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • Computer Security & Cryptography (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Information Transfer Between Computers (AREA)

Abstract

The present application provides a model gradient update method and apparatus for improving the accuracy of model training. A central server repeatedly executes a gradient update process until a stop condition is met, where one gradient update process includes: receiving first gradients respectively sent by multiple nodes, each first gradient being obtained by a node training the model to be trained in that node one or more times on its sample data; obtaining a second gradient based on the multiple first gradients and the probability of each node in the current gradient update process, where the probability of each node in the current gradient update process is determined by an Actor-Critic network based on the probability of each node in the previous gradient update process; and sending the second gradient to the multiple nodes respectively, so that the multiple nodes use the second gradient to update the weights of their respective models to be trained. By taking the probability of each node into account, node participation can be optimized and the resulting model is better.

Description

一种模型梯度更新方法及装置
相关申请的交叉引用
本申请要求在2022年01月28日提交中国专利局、申请号为202210107380.3、申请名称为“一种模型梯度更新方法及装置”的中国专利申请的优先权,其全部内容通过引用结合在本申请中。
技术领域
本申请涉及模型训练技术领域,特别涉及一种模型梯度更新方法及装置。
背景技术
横向联邦学习也称为按样本划分的联邦学习,可以应用于联邦学习的各个参与方的数据集有相同的特征和不同的样本的场景。
通常假设一个横向联邦学习系统的参与方都是诚实的,需要防范的对象是一个诚实但好奇的中心服务器。即通常假设只有中心服务器才能使得数据参与方的隐私安全受到威胁。在横向联邦学习系统中,具有同样数据特征的多个参与方在中心服务器的帮助下,协作地训练一个模型。主要包括以下步骤:各参与方在本地计算模型梯度,并梯度(梯度需要加密)发送给中心服务器。中心服务器对多个梯度进行聚合。中心服务器将聚合后的梯度(梯度也需要加密)发送给各参与方。各参与方使用接收到的梯度更新各自的模型参数。
上述步骤持续迭代进行,直到损失函数收敛或者达到允许的迭代次数的上限或允许的训练时间,这种架构独立于特定的机器学习算法(如逻辑回归和深度神经网络),并且所有参与方将会共享最终的模型参数。
目前,横向联邦学习场景中,中心服务器对梯度进行平均聚合,考虑到不同参与方的性能不同,采用梯度平均的方式训练出的模型结果不佳。
发明内容
本申请提供一种模型梯度更新的方法及装置,用以提高模型训练的准确性。
为达到上述目的,本申请实施例公开了一种模型梯度更新方法,应用于中心服务器,包括:
中心服务器重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
一种可选的示例中,所述Actor-Critic网络包括Actor网络、至少一个Critic网络、及奖励函数;
所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
所述至少一个Critic网络用于确定目标Q值,并将所述目标Q值传输至所述Actor网络;
所述Actor网络用于基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。
一种可选的示例中,所述目标Q值为多个Critic网络确定的Q值中的最小Q值。
一种可选的示例中,奖励函数满足:
Figure PCTCN2022112615-appb-000001
其中,A为第一准确率,B为第二准确率,g大于或等于1,其中,第一 准确率为所述中心服务器与所述多个节点基于联邦平均学习算法得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节点中的待训练的模型进行同一次模型训练中得到的。
一种可选的示例中,当
Figure PCTCN2022112615-appb-000002
大于1时,g大于1;当
Figure PCTCN2022112615-appb-000003
小于或等于1时,g为1。
一种可选的示例中,所述Actor-Critic网络包括3个Critic网络,针对任一Critic网络,在本次梯度更新过程中确定的Q值基于Q值梯度和上一次梯度更新过程中确定的Q值确定,所述Q值梯度基于第一参数确定,所述第一参数满足以下公式:
Figure PCTCN2022112615-appb-000004
其中,
Figure PCTCN2022112615-appb-000005
其中,J为所述第一参数;t为本次梯度更新的次数;k>0,l>0,k+l=1;θ 1,θ 2,θ 3分别表示3个Critic网络,θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;s t为第t次梯度更新过程中的状态;a t为第t次梯度更新过程中所述多个节点的概率;
Figure PCTCN2022112615-appb-000006
为第t次梯度更新过程中θ i对应的Critic网络在s t,a t情况下确定的Q值;
Figure PCTCN2022112615-appb-000007
为第t次梯度更新过程中θ 3对应的Critic网络在s t,a t情况下输出的Q值;r(s t,a t)为第t次梯度更新过程中在s t,a t情况下的奖励值;γ大于0;π t(a t|s t)为在s t下做出a t的概率;q为熵的指数,ln q为熵,α t不为0。
一种可选的示例中,在本次梯度更新过程中采用的α基于α梯度和上一次梯度更新过程中采用的α确定,所述α梯度满足以下公式:
Figure PCTCN2022112615-appb-000008
其中,J(α)为α梯度,α t-1为上一次梯度更新采用的α,H为理想的最小 期望熵。
一种可选的示例中,所述k、l基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
一种可选的示例中,所述Actor网络在本次梯度更新过程中输出的节点的概率基于概率梯度和上一次梯度更新过程中输出的节点的概率确定,所述概率梯度满足以下公式:
Figure PCTCN2022112615-appb-000009
其中,J(π φ)为概率梯度,t为本次梯度更新的次数;θ 1,θ 2,θ 3分别表示3个Critic网络,θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;s t为第t次梯度更新过程中的状态;a t为第t次梯度更新过程中所述多个节点的概率;
Figure PCTCN2022112615-appb-000010
为第t次梯度更新过程中θ i对应的Critic网络在s t,a t情况下确定的Q值;π t(a t|s t)为在s t下做出a t的概率;q为熵的指数,ln q为熵,α t为本次梯度更新采用的α,α t不为0。
一种可选的示例中,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。
本申请实施例提供了一种模型梯度更新装置,包括:
重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
接收模块,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
处理模块,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
发送模块,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
本申请实施例提供了一种模型梯度更新装置,包括处理器和存储器;
所述存储器,用于存储计算机程序或指令;
所述处理器,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现上述任一项所述的模型梯度更新方法。
本申请实施例提供了一种计算机可读存储介质,用于存储计算机程序,所述计算机程序包括用于实现任一项所述的模型梯度更新方法的指令。
本申请考虑到每个节点的概率,可以优化节点参与度,使确定出的模型更优。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请提供的一种模型梯度更新流程示意图;
图2为本申请提供的一种模型梯度更新系统架构图;
图3为本申请提供的一种模型梯度更新装置结构图;
图4为本申请提供的一种模型梯度更新装置结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
中心服务器重复执行梯度更新过程,直至满足停止条件。停止条件例如为损失函数收敛、或者达到允许的迭代次数的上限、或达到允许的训练时间。
接下来如图1所示,对任一次所述梯度更新过程进行介绍:
步骤101:多个节点(例如节点1、节点2、……节点m)分别向中心服务器发送第一梯度。相应的,中心服务器接收所述多个节点分别发送的第一梯度。
所述第一梯度为每个节点采用样本数据对所述节点中的待训练的模型进行一次或多次训练得到。将节点向中心服务器发送的梯度称为第一梯度,多个节点发送的第一梯度可能是相同的,也可能是不同的。
节点可以是汽车终端,样本数据可以是在自动驾驶中产生的数据。汽车行驶中生成不同的行驶数据,行驶数据分散在各个节点(汽车终端),并存在数据质量不均衡,节点性能不均衡的特点。模型可以是需要进行自动驾驶、用户习惯相关的模型。
步骤102:Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率,确定本次梯度更新过程中的每个节点的概率。
可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次梯度更新过程中的每个节点的概率为初始设定的每个节点的概率。
多个节点的概率之和可以是1。
Actor-Critic网络可以每循环一次就输出多个节点的概率,也可以是循环多次才输出节点的概率。
Actor-Critic网络可以是在中心服务器上,也可以不在所述中心服务器上。
步骤103:中心服务器基于多个第一梯度和多个概率得到第二梯度。
在基于多个第一梯度和多个概率得到第二梯度时,可以将多个第一梯度和多个概率用加权平均的方式确定第二梯度,例如节点1、节点2、节点3分别发送的第一梯度为p1、p2、p3,三个节点的概率为0.2、0.4、0.4,则第二梯度为0.2p1+0.4p2+0.4p3的取值。该过程也可以称为基于深度强化学习数据融合算法,多个第一梯度和多个概率得到第二梯度。
步骤104:中心服务器将所述第二梯度分别发送给所述多个节点(例如节点1、节点2、……节点m),相应的,多个节点接收来自中心服务器的第二梯度。
将中心服务器向节点发送的梯度称为第二梯度,向多个节点发送的第二梯度是相同的。
步骤105:多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
如果满足停止条件,则更新后的模型即为训练完成的模型。
如果未满足停止条件,则每个节点可以采用样本数据对所述节点中的待训练的模型进行一次或多次训练,得到新的第一梯度,继续重复执行步骤101-步骤105。
考虑到每个节点的概率,可以优化节点参与度,使确定出的模型更优。
可选的,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。各个节点的样本数据私有,与其它节点和中心服务器不共享,并且节点和中心服务器在传输第一梯度和第二梯度时,第一梯度和第二梯度是加密的。
如图2所示,介绍一种梯度更新系统架构图,包括多个节点、中心服务器、和Actor-Critic网络。Actor-Critic网络可以位于中心服务器上,也可以不位于中心服务器上。Actor-Critic从名字上看包括两部分,演员(Actor)和评价者(Critic)。其中Actor负责生成动作(Action)并和环境交互。而Critic负责评估Actor的表现,并指导Actor下一阶段的动作。
所述Actor-Critic网络包括Actor网络、至少一个Critic网络,及奖励函数;
所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
所述至少一个Critic网络用于确定目标Q值,并将目标Q值传输至所述Actor网络。每个Critic网络确定出一个Q值,如果有一个Critic网络,则该Critic网络确定出的Q值即为目标Q值,如果有多个Critic网络,则可以在多个Q值选择出一个Q值作为目标Q值。例如目标Q值为多个Critic网络确定的Q值中的最小Q值。当有多个Critic网络时,相当于设置了多个评价者,评估Actor的表现更加准确,使Actor做出的动作更加准确,进而得出的多个 节点的概率符合多个节点的性能情况。
所述Actor网络基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。并将所述多个节点的概率传输至所述奖励函数,多次循环,直至停止。
本申请可以采用现有的Actor-Critic网络确定概率,也可以对现有的Actor-Critic网络进行改进,例如设置多个Critic网络,例如对Critic网络中涉及的算法进行改进,例如对Actor网络中涉及的算法进行改进,例如对奖励函数进行改进,可以理解的是,改进只是涉及到具体的细节,改进后的Actor-Critic网络与现有的Actor-Critic网络的运行机制是类似的。
结合图2介绍的系统,对本申请的梯度更新过程进行详细介绍。
可以先对Actor-Critic网络中涉及的参数进行初始化,包括但不限于:对Critic网络中涉及的参数,对Actor网络中涉及的参数,对奖励函数中涉及的参数进行初始化。
中心服务器与多个节点基于联邦平均学习算法得到的训练完成的模型,并确定所述训练完成的模型的第一准确率A。
节点1、节点2……、节点m基于当前保存的模型(也可以称为待训练的模型,例如可以是还执行步骤101-步骤105的梯度更新过程的模型,也可以是已经执行过一次或多次步骤101-步骤105的梯度更新过程的模型)进行一次或多次训练得到第一梯度及第三准确率B’,各个节点向中心服务器发送第一梯度及第三准确率B’,第三准确率与第一梯度在节点进行同一次模型训练中得到的。
中心服务器对多个第三准确率B’计算平均值,得到第二准确率B。
奖励函数基于第一准确率A和第二准确率B确定奖励值r,第一准确率为所述中心服务器与所述多个节点基于联邦学习得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节点中的待训练的模型进行同一次模型训练中得到的。
例如奖励函数表示为:
Figure PCTCN2022112615-appb-000011
结果B/结果A越高,奖励值r越高。可以理解的是,在任一次梯度更新过程中,A的取值都是相同的,B的取值可能相同,也可能不同。g大于或等于1。可选的,本申请设置2个奖励函数,分别为:当
Figure PCTCN2022112615-appb-000012
大于1时,g取值为大于1的常数,可以进行强引导,以更快的完成梯度训练;当
Figure PCTCN2022112615-appb-000013
小于或等于1时,g设置为1。
所述Critic网络本次梯度更新过程中确定的Q值(即更新后的Q值)基于Q值梯度和上一次梯度更新过程中确定的Q值(即更新前的Q值)确定,例如,更新后的Q值=更新前的Q值+Q值梯度。可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次输出的Q值为初始设定的Q值。
一种示例中,Q值梯度基于第一算法和第二算法确定,其中,第一算法具有在训练中偏向于一个特定的动作的特性,第二算法具有在训练中均衡选择各个工作的特性。例如,第一算法可以是深度确定性策略梯度算法(Deep Deterministic Policy Gradient,DDPG)算法,第二算法可以是SAC算法,SAC算法可以是具有自动熵调节功能的SAC(Soft actor critic with automatic entropy adjustment,SAC-AEA)强化学习算法。DDPG受制于算法本身的更新策略,训练后期会偏向于一个特定的动作(action),这不利于调度多个节点的概率,对实现整体性模型融合是不利的,会导致最后训练出来的模型与某个节点的数据高度相关,其他节点的模型数据对模型结果贡献度变低,也就是对于多方数据的利用效率大大降低,甚至会导致模型训练结果不佳,或出现过拟合等问题。SAC-AEA强化学习算法,本身可以较为均衡选择各个动作(action),然而对于实际的联邦学习框架,各个节点的数据质量、对模型的贡献度、本地算力(例如本地设备的计算效率)等都不同(在本申请中可以将他们表述为优势方和非优势方)的情况下,均衡地融合他们的数据显然不利于模型训练结果的提升的,或出现欠拟合,不能完全表征完整的数据特征。本申请中 所述Critic网络确定的Q值基于DDPG算法和SAC算法更新。融合DDPG算法和SAC算法,可以结合两个算法的优势,使训练出的模型融合多个节点的性能,模型较优。本申请可以对DDPG算法设置第一权重,对SAC算法设置第二权重,基于DDPG算法及第一权重、SAC算法及第二权重确定Q值梯度,所述第一权重和所述第二权重基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
一种示例中,在Critic网络中,提出了基于复合自适应可调权重的Q值更新算法;Q值梯度基于第一参数J确定,例如Q值梯度为第一参数J与步长的乘积,步长不为0。以3个Critic网络为例,第一参数J满足以下公式:
Figure PCTCN2022112615-appb-000014
Figure PCTCN2022112615-appb-000015
或者,
Figure PCTCN2022112615-appb-000016
其中,
Figure PCTCN2022112615-appb-000017
该示例以Actor-Critic网络可以每循环一次就输出多个节点的概率为例进行介绍,t为梯度更新的次数,例如,第t次梯度更新过程,t为大于或等于1的整数。
k>0,l>0,k+l=1;k和l可以是固定值,可以是人为设置的数值,还可以是基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定;准确率的方差可以在一定程度上表示节点的性能(例如算力)差异等。方差越大,节点的性能差异越大,反之,方差越小,多个节点的性能差异越小。
θ 1,θ 2,θ 3分别表示3个Critic网络,θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络。可以理解的是,在本次梯度更新过程中,3个Critic网络最新确定出的Q值是指在上一次梯度更新(如果本次梯度更新为第t次,则上一次为第t-1次)过程中确定出的Q值;
s t为第t次梯度更新过程中的状态;s t可以是各个节点的准确率或者准确率平均值,可以是各个节点的梯度或者梯度平均值,可以是各个节点的方差等;
a t为第t次梯度更新过程中所述多个节点的概率;a t也可以称为动作;
Figure PCTCN2022112615-appb-000018
为第t次梯度更新过程中θ i对应的Critic网络在s t,a t情况下确定的Q值;
Figure PCTCN2022112615-appb-000019
为第t次梯度更新过程中θ 3对应的Critic网络在s t,a t情况下输出的Q值;
r(s t,a t)为第t次梯度更新过程中在s t,a t情况下的奖励值;
γ为衰减因子,γ大于0;
π t(a t|s t)为条件概率,π t(a t|s t)为在s t下做出a t的概率;
q为熵(例如Tasslis熵)的指数,q为大于或等于1的整数,可以是1或2或3,ln q为熵,是一个曲线。当q不同时,ln q为一个曲线族;
E为数字期望,E对[*]内的数据求期望,自变量为s t和a t,上述公式中的[]内的内容是对s t和a t的隐式表达;
D为记忆库(也可以称为经验回放池、或缓存空间),(s,a)~D是指记忆库中的s,a;假设记忆库D中可以存储M个s及M个a,记忆库D为循环覆盖,当则Actor-Critic网络可以先自循环M次,得到M个s及M个a,在第M+1次循环时Actor-Critic网络才输出节点的概率,前面的M次循环中Actor-Critic网络不输出节点的概率,或者说前面的M次循环中Actor-Critic网络输出的节点的概率可以忽略不计。
a~π φ表示π φ基于a确定,例如多个a组成π φ曲线(或集合)。
α t为第t次梯度更新过程中采用的α,α可以是个固定值,不为0即可,α也可以是变量(即自适应参数),本次梯度更新采用的(更新后的)α基于α梯度和上一次梯度更新采用的(更新前的)α确定。例如,本次梯度更新采用的(更新后的)α=α梯度和上一次梯度更新采用的(更新前的)α。可以理解的是,如果本次梯度更新为第一次梯度更新,则上一次采用的α为初始设定的α。α梯 度满足以下公式:
Figure PCTCN2022112615-appb-000020
其中,J(α)为α梯度,α t-1为上一次梯度更新采用的α,H为理想的最小期望熵,熵例如为Tasslis熵。
通过复合自适应可调权重的Q值更新方法,可以自适应调整梯度权重参数,可以根据实际场景,只需调整权重参数,对于进行特定动作action选取,或均衡动作选取,从而可以根据具体情况调度各参与方的模型信息。
所述Actor网络本次输出的节点的概率(即更新后的节点的概率)为上一次输出的节点的概率(即更新前的节点的概率)与概率梯度的和值。
在Actor网络中融入Tasslis熵概念和自适应参数,概率梯度基于以下公式确定:
Figure PCTCN2022112615-appb-000021
其中,J(π φ)为概率梯度,α t为本次(第t次)梯度更新过程得到的α;θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络。可以理解的是,在本次梯度更新过程中,3个Critic网络最新确定出的Q值是指在本次梯度更新(例如第t次梯度更新)过程中确定出的Q值。
在联邦框架中的服务器端,调整了不同节点的数据融合算法,设计了将深度强化学习数据融合模型与联邦平均算法相结合的优化融合策略。可以调整不同节点的参与度和训练时的数据利用程度。
前文介绍了本申请实施例的方法,下文中将介绍本申请实施例中的装置。方法、装置是基于同一技术构思的,由于方法、装置解决问题的原理相似,因此装置与方法的实施可以相互参见,重复之处不再赘述。
本申请实施例可以根据上述方法示例,对装置进行功能模块的划分,例如,可以对应各个功能划分为各个功能模块,也可以将两个或两个以上的功能集成在一个模块中。这些模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。需要说明的是,本申请实施例中对模块的划分是示 意性的,仅仅为一种逻辑功能划分,具体实现时可以有另外的划分方式。
基于与上述方法的同一技术构思,参见图3,提供了一种模型梯度更新装置,包括:
接收模块301,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
处理模块302,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
发送模块303,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
以上过程为一次所述梯度更新过程,重复执行梯度更新过程,直至满足停止条件。
基于与上述方法的同一技术构思,参见图4,提供了一种模型梯度更新装置,包括处理器401和存储器402,可选的,还包括收发器403;
所述存储器402,用于存储计算机程序或指令;
所述处理器401,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现上述任一项所述的模型梯度更新方法。例如收发器403执行接收和发送动作,处理器401执行处接收和发送动作外的其它动作。
本申请实施例提供了一种计算机可读存储介质,用于存储计算机程序,所述计算机程序包括用于实现任一项所述的模型梯度更新方法的指令。
本申请实施例还提供了一种计算机程序产品,包括:计算机程序代码,当所述计算机程序代码在计算机上运行时,使得计算机可以执行上述提供的模型梯度更新的方法。
本申请实施例还提供了一种通信的系统,所述通信系统包括:执行上述模型梯度更新的方法的节点和中心服务器。
另外,本申请实施例中提及的处理器可以是中央处理器(central processing unit,CPU),基带处理器,基带处理器和CPU可以集成在一起,或者分开,还可以是网络处理器(network processor,NP)或者CPU和NP的组合。处理器还可以进一步包括硬件芯片或其他通用处理器。上述硬件芯片可以是专用集成电路(application-specific integrated circuit,ASIC),可编程逻辑器件(programmable logic device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(complex programmable logic device,CPLD),现场可编程逻辑门阵列(field-programmable gate array,FPGA),通用阵列逻辑(generic array logic,GAL)及其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等或其任意组合。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
本申请实施例中提及的存储器可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read-Only Memory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(Erasable PROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double Data Rate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(Direct Rambus RAM,DR RAM)。应注意,本申请描述的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
本申请实施例中提及的收发器中可以包括单独的发送器,和/或,单独的接收器,也可以是发送器和接收器集成一体。收发器可以在相应的处理器的 指示下工作。可选的,发送器可以对应物理设备中发射机,接收器可以对应物理设备中的接收机。
本领域普通技术人员可以意识到,结合本文中所公开的实施例中描述的各方法步骤和单元,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各实施例的步骤及组成。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。本领域普通技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统、装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口、装置或单元的间接耦合或通信连接,也可以是电的,机械的或其它的形式连接。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本申请实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案 的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(read-only memory,ROM)、随机存取存储器(random access memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
本申请中的“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。本申请中所涉及的多个,是指两个或两个以上。另外,需要理解的是,在本申请的描述中,“第一”、“第二”等词汇,仅用于区分描述的目的,而不能理解为指示或暗示相对重要性,也不能理解为指示或暗示顺序。
尽管已描述了本申请的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本申请范围的所有变更和修改。
显然,本领域的技术人员可以对本申请实施例进行各种改动和变型而不脱离本申请实施例的精神和范围。这样,倘若本申请实施例的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包括这些改动和变型在内。

Claims (13)

  1. 一种模型梯度更新方法,应用于中心服务器,包括:
    中心服务器重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
    接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
  2. 如权利要求1所述的方法,其中,所述Actor-Critic网络包括Actor网络、至少一个Critic网络、及奖励函数;
    所述奖励函数用于基于上一次梯度更新过程中确定的所述多个节点的概率,确定奖励值,并将奖励值传输至所述至少一个Critic网络;
    所述至少一个Critic网络用于确定目标Q值,并将所述目标Q值传输至所述Actor网络;
    所述Actor网络用于基于所述目标Q值确定本次梯度更新过程中的每个节点的概率。
  3. 如权利要求2所述的方法,其中,所述目标Q值为多个Critic网络确定的Q值中的最小Q值。
  4. 如权利要求2所述的方法,其中,奖励函数满足:
    Figure PCTCN2022112615-appb-100001
    其中,A为第一准确率,B为第二准确率,g大于或等于1,其中,第一准确率为所述中心服务器与所述多个节点基于联邦平均学习算法得到的训练完成的模型的准确率;第二准确率为所述多个节点分别发送的第三准确率的平均值,所述第三准确为与所述第一梯度在所述节点采用样本数据对所述节 点中的待训练的模型进行同一次模型训练中得到的。
  5. 如权利要求4所述的方法,其中,当
    Figure PCTCN2022112615-appb-100002
    大于1时,g大于1;当
    Figure PCTCN2022112615-appb-100003
    小于或等于1时,g为1。
  6. 如权利要求2所述的方法,其中,所述Actor-Critic网络包括3个Critic网络,针对任一Critic网络,在本次梯度更新过程中确定的Q值基于Q值梯度和上一次梯度更新过程中确定的Q值确定,所述Q值梯度基于第一参数确定,所述第一参数满足以下公式:
    Figure PCTCN2022112615-appb-100004
    其中,
    Figure PCTCN2022112615-appb-100005
    其中,J为所述第一参数;t为本次梯度更新的次数;k>0,l>0,k+l=1;θ 1,θ 2,θ 3分别表示3个Critic网络,θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;s t为第t次梯度更新过程中的状态;a t为第t次梯度更新过程中所述多个节点的概率;
    Figure PCTCN2022112615-appb-100006
    为第t次梯度更新过程中θ i对应的Critic网络在s t,a t情况下确定的Q值;
    Figure PCTCN2022112615-appb-100007
    为第t次梯度更新过程中θ 3对应的Critic网络在s t,a t情况下输出的Q值;r(s t,a t)为第t次梯度更新过程中在s t,a t情况下的奖励值;γ大于0;π t(a t|s t)为在s t下做出a t的概率;q为熵的指数,ln q为熵,α t不为0。
  7. 如权利要求6所述的方法,其中,在本次梯度更新过程中采用的α基于α梯度和上一次梯度更新过程中采用的α确定,所述α梯度满足以下公式:
    Figure PCTCN2022112615-appb-100008
    其中,J(α)为α梯度,α t-1为上一次梯度更新采用的α,H为理想的最小期望熵。
  8. 如权利要求6所述的方法,其中,所述k、l基于所述多个节点中的未进行梯度更新之前的模型的准确率的方差确定。
  9. 如权利要求2所述的方法,其中,所述Actor网络在本次梯度更新过程中输出的节点的概率基于概率梯度和上一次梯度更新过程中输出的节点的概率确定,所述概率梯度满足以下公式:
    Figure PCTCN2022112615-appb-100009
    其中,J(π φ)为概率梯度,t为本次梯度更新的次数;θ 1,θ 2,θ 3分别表示3个Critic网络,θ i为θ 1,θ 2,θ 3分别表示3个Critic网络最新确定出的Q值中的最小值对应的网络;s t为第t次梯度更新过程中的状态;a t为第t次梯度更新过程中所述多个节点的概率;
    Figure PCTCN2022112615-appb-100010
    为第t次梯度更新过程中θ i对应的Critic网络在s t,a t情况下确定的Q值;π t(a t|s t)为在s t下做出a t的概率;q为熵的指数,ln q为熵,α t为本次梯度更新采用的α,α t不为0。
  10. 如权利要求1-9任一项所述的方法,其中,所述中心服务器和所述多个节点基于联邦学习架构进行梯度更新。
  11. 一种模型梯度更新装置,包括:
    重复执行梯度更新过程,直至满足停止条件;其中,一次所述梯度更新过程包括:
    接收模块,用于接收多个节点分别发送的第一梯度,所述第一梯度为每个节点采用样本数据对节点中的待训练的模型进行一次或多次训练得到;
    处理模块,用于基于多个第一梯度和本次梯度更新过程中的每个节点的概率得到第二梯度,所述本次梯度更新过程中的每个节点的概率为Actor-Critic网络基于上一次梯度更新过程中的每个节点的概率确定的;
    发送模块,用于将所述第二梯度分别发送给所述多个节点,以使所述多个节点采用所述第二梯度对各自的待训练的模型的权重进行更新。
  12. 一种模型梯度更新装置,包括处理器和存储器;
    所述存储器,用于存储计算机程序或指令;
    所述处理器,用于执行所述存储器中的部分或者全部计算机程序或指令,当所述部分或者全部计算机程序或指令被执行时,用于实现如权利要求1-10 任一项所述的方法。
  13. 一种计算机可读存储介质,用于存储计算机程序,所述计算机程序包括用于实现权利要求1-10任一项所述的方法的指令。
PCT/CN2022/112615 2022-01-28 2022-08-15 一种模型梯度更新方法及装置 WO2023142439A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202210107380.3A CN114492841A (zh) 2022-01-28 2022-01-28 一种模型梯度更新方法及装置
CN202210107380.3 2022-01-28

Publications (1)

Publication Number Publication Date
WO2023142439A1 true WO2023142439A1 (zh) 2023-08-03

Family

ID=81477080

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2022/112615 WO2023142439A1 (zh) 2022-01-28 2022-08-15 一种模型梯度更新方法及装置

Country Status (2)

Country Link
CN (1) CN114492841A (zh)
WO (1) WO2023142439A1 (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114492841A (zh) * 2022-01-28 2022-05-13 中国银联股份有限公司 一种模型梯度更新方法及装置
CN117725979B (zh) * 2023-09-27 2024-09-20 行吟信息科技(上海)有限公司 模型训练方法及装置、电子设备及计算机可读存储介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112087518A (zh) * 2020-09-10 2020-12-15 工银科技有限公司 用于区块链的共识方法、装置、计算机系统和介质
US20210125032A1 (en) * 2019-10-24 2021-04-29 Alibaba Group Holding Limited Method and system for distributed neural network training
CN112818394A (zh) * 2021-01-29 2021-05-18 西安交通大学 具有本地隐私保护的自适应异步联邦学习方法
CN113282933A (zh) * 2020-07-17 2021-08-20 中兴通讯股份有限公司 联邦学习方法、装置和系统、电子设备、存储介质
CN113971089A (zh) * 2021-09-27 2022-01-25 国网冀北电力有限公司信息通信分公司 联邦学习系统设备节点选择的方法及装置
CN114492841A (zh) * 2022-01-28 2022-05-13 中国银联股份有限公司 一种模型梯度更新方法及装置

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210125032A1 (en) * 2019-10-24 2021-04-29 Alibaba Group Holding Limited Method and system for distributed neural network training
CN113282933A (zh) * 2020-07-17 2021-08-20 中兴通讯股份有限公司 联邦学习方法、装置和系统、电子设备、存储介质
CN112087518A (zh) * 2020-09-10 2020-12-15 工银科技有限公司 用于区块链的共识方法、装置、计算机系统和介质
CN112818394A (zh) * 2021-01-29 2021-05-18 西安交通大学 具有本地隐私保护的自适应异步联邦学习方法
CN113971089A (zh) * 2021-09-27 2022-01-25 国网冀北电力有限公司信息通信分公司 联邦学习系统设备节点选择的方法及装置
CN114492841A (zh) * 2022-01-28 2022-05-13 中国银联股份有限公司 一种模型梯度更新方法及装置

Also Published As

Publication number Publication date
CN114492841A (zh) 2022-05-13

Similar Documents

Publication Publication Date Title
WO2023142439A1 (zh) 一种模型梯度更新方法及装置
WO2021259090A1 (zh) 联邦学习的方法、装置和芯片
JP2023505973A (ja) 連合混合モデル
Shao et al. Almost optimal algorithms for linear stochastic bandits with heavy-tailed payoffs
WO2022193432A1 (zh) 模型参数更新方法、装置、设备、存储介质及程序产品
US20210027161A1 (en) Learning in communication systems
Farshbafan et al. Curriculum learning for goal-oriented semantic communications with a common language
US20240135191A1 (en) Method, apparatus, and system for generating neural network model, device, medium, and program product
CN112948885B (zh) 实现隐私保护的多方协同更新模型的方法、装置及系统
US20210065011A1 (en) Training and application method apparatus system and stroage medium of neural network model
CN112667400B (zh) 边缘自治中心管控的边云资源调度方法、装置及系统
US20220318412A1 (en) Privacy-aware pruning in machine learning
US11843587B2 (en) Systems and methods for tree-based model inference using multi-party computation
US20230153633A1 (en) Moderator for federated learning
CN115378788A (zh) 基于分层共识和强化学习的区块链性能自适应优化方法
CN114819196B (zh) 基于噪音蒸馏的联邦学习系统及方法
CN113298247A (zh) 智能体决策的方法和装置
US20220366266A1 (en) Agent training method, apparatus, and computer-readable storage medium
Li et al. Computation Offloading in Resource-Constrained Multi-Access Edge Computing
CN113806691B (zh) 一种分位数的获取方法、设备及存储介质
WO2024032694A1 (zh) Csi预测处理方法、装置、通信设备及可读存储介质
WO2023093229A1 (zh) 一种联合学习参数聚合方法、装置及系统
WO2024067280A1 (zh) 更新ai模型参数的方法、装置及通信设备
US11693989B2 (en) Computer-implemented methods and nodes implementing performance estimation of algorithms during evaluation of data sets using multiparty computation based random forest
WO2023225552A1 (en) Decentralized federated learning using a random walk over a communication graph

Legal Events

Date Code Title Description
121: Ep: the epo has been informed by wipo that ep was designated in this application (Ref document number: 22923243; Country of ref document: EP; Kind code of ref document: A1)
WWE: Wipo information: entry into national phase (Ref document number: 18690017; Country of ref document: US)
NENP: Non-entry into the national phase (Ref country code: DE)