CN116523079A - Reinforcement learning-based federated learning optimization method and system - Google Patents
- Publication number
- CN116523079A (CN202310230326.2A)
- Authority
- CN
- China
- Prior art keywords
- model
- training
- learning
- round
- global
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Medical Informatics (AREA)
- Molecular Biology (AREA)
- Bioethics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Probability & Statistics with Applications (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
The invention discloses a reinforcement learning-based federated learning optimization method and system for heterogeneous resources. The method can estimate the effect of each heterogeneous device on model training, greatly accelerate federated learning training, reduce the interference of malicious data with training, and protect the data privacy of multiple participants.
Description
Technical Field
The invention belongs to the field of federated learning efficiency optimization in edge computing, and particularly relates to a federated learning efficiency optimization method and system for heterogeneous resources.
Background
With the rapid development of the Internet and the Internet of Things, the number of smart devices has grown rapidly over the past few years, and the volume of data generated every day is unprecedented. These data play an important role in deep learning models. In general, the development of artificial intelligence depends on big data: complex large-scale deep learning models often need massive data for iterative training to obtain an effective model, and in the past this data was transmitted to a cloud center server. However, as the amount of data increases and the trained deep learning models become increasingly complex, offloading all data to the cloud server overloads the server; under limited network bandwidth it is impractical to transmit all data from local devices to a remote server for further processing, and the transmission process also raises data privacy and security problems. In view of these factors, data storage and analysis are moving from cloud-based centralized approaches to distributed, on-device approaches. A key enabling technology for this transition is edge computing, which supports Internet of Things devices by providing computing resources to which complex computing tasks or applications can be offloaded. Edge computing can effectively address the insufficient computing capacity of Internet of Things devices, the high latency of offloading to a cloud server, data security, and similar problems.
Due to the rapid development of edge computing and distributed machine learning (S. Arisdakessian, O. A. Wahab, A. Mourad, H. Otrok, and N. Kara, "Intelligent Multi-Criteria IoT-Fog Scheduling Approach Using Game Theory," Transactions on Networking, 2020.), a new distributed machine learning paradigm, federated learning (H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, "Communication-efficient learning of deep networks from decentralized data," in Proceedings of Machine Learning Research, vol. 54, pp. 1273-1282, Apr. 2017.), has emerged to enable local and distributed machine learning training at the edge-node or end-device level by collaboratively training a machine learning model without exposing the raw data. Federated learning typically employs a parameter-server architecture in which the terminals train a model that is synchronized by the parameter server.
Federated learning is a distributed machine learning method that can train on a large amount of decentralized data held on local devices. In a well-known scheme named FedAvg, a model is initialized at a central server; edge devices then download the model and perform several local update iterations (using gradient descent), after which the server aggregates globally by taking a weighted average of the models obtained from the edge devices. This distributed machine learning method is particularly suitable for the Internet of Things. Owing to these advantages, federated learning is widely applied in the Internet of Things and mobile edge computing. For example, deep reinforcement learning agents have been trained in a distributed fashion using federated learning, and next-word prediction on the Google keyboard has also been improved by federated learning.
Many existing federated learning algorithms weight the global model aggregation by data volume. However, the relationship between model accuracy and training data volume is nonlinear (Zhan Y, Li P, Wang K, et al. Big Data Analytics by CrowdLearning: Architecture and Mechanism Design [J]. IEEE Network, 2020, PP(99): 1-5.). In practice, the data distribution on each edge device varies widely, and because training is distributed and user data privacy must be protected, it is difficult to evaluate each device's effect on global model training from the limited information available. It is therefore necessary to accurately estimate the effect of each heterogeneous device on the current model and to allocate each model's aggregation weight accordingly, so as to speed up training and improve its communication efficiency.
Disclosure of Invention
The invention aims to overcome the defects of the prior art and provides a federated learning training method based on reinforcement learning, in which a reinforcement learning algorithm evaluates the effect of each heterogeneous device on the trained global model, and the weight of each heterogeneous device in global model aggregation is then decided according to the evaluation result. The method can also effectively identify malicious data and reduce its influence on training. Experiments show that the method improves federated learning training in a practical environment and accelerates it.
The invention is realized by at least one of the following technical solutions.
A reinforcement learning-based federated learning optimization method comprises the following steps:
S1, determining the machine learning training task, namely an image recognition task on a data set;
S2, initializing the neural network model parameters of the central server for that task, namely initializing a global model;
S3, the central server selecting the edge devices that are to participate in the current round of training, and the selected edge devices downloading the global model;
S4, under the federated learning framework, each edge device training the model parameters locally and uploading the model parameters after training is completed;
S5, the central server evaluating the contribution of the edge devices participating in training, and aggregating the edge device models with weights according to the size of the contribution;
S6, repeating steps S3-S5 until the trained model converges.
Further, in step S1, the edge devices include user smart devices and Internet of Things devices.
Further, the model initialization includes: according to the image recognition task on the data set in step S1, a LeNet-5 convolutional neural network is used for training on the data set and its model parameters are initialized. Let the machine learning model parameters be $w$, and define a feature space $\mathcal{X}$ and a label set $\mathcal{Y} = \{1, \ldots, B\}$, where $B$ is the total number of labels. Let $(x, y)$ denote a labeled sample, with $x \in \mathcal{X}$ and $y \in \mathcal{Y}$. Let $f_w: \mathcal{X} \to \mathcal{S}$ denote the prediction function, where $\mathcal{S} = \{ z \mid \sum_{j \in \mathcal{Y}} z_j = 1,\ z_j \ge 0 \}$ is the set of probability vectors over the labels, $z_j$ denotes the probability that the sample is classified as label $j$, and $w$ denotes the trained machine learning model parameters. The function $f_w$ produces a probability vector for each sample.
Further, the loss function of the prediction function for the machine learning task trained by federated learning is defined as the cross entropy:

$$\ell(w) = -\frac{1}{D} \sum_{x} \sum_{j \in \mathcal{Y}} p_j(x) \log f_{w,j}(x)$$

where $p$ represents the true probability distribution, $f_{w,j}(x)$ is the predicted probability that sample $x$ belongs to the $j$-th label class, and $D$ is the total number of labeled samples. Stochastic gradient descent is used during training to continuously reduce the value of the loss function until the model converges.
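As an illustration of the cross-entropy loss described here, the following is a minimal pure-Python sketch. It assumes the empirical form $-(1/D) \sum_x \sum_j p_j(x) \log f_j(x)$; the helper names `softmax` and `cross_entropy`, the clipping constant `eps`, and the example scores are illustrative assumptions and do not come from the disclosure.

```python
import math

def softmax(logits):
    """Turn raw scores into a probability vector (entries >= 0, summing to 1)."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(true_dists, predicted_dists, eps=1e-12):
    """Mean cross entropy over D samples: -(1/D) * sum_x sum_j p_j(x) * log f_j(x)."""
    total = 0.0
    for p, f in zip(true_dists, predicted_dists):
        # eps (an assumed constant) avoids log(0) for zero predicted probabilities
        total -= sum(pj * math.log(max(fj, eps)) for pj, fj in zip(p, f))
    return total / len(true_dists)

# Two samples, three labels; the true distributions are one-hot.
predictions = [softmax([2.0, 0.5, 0.1]), softmax([0.2, 0.1, 3.0])]
targets = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
loss = cross_entropy(targets, predictions)
```

The loss is zero only when every sample's true label receives predicted probability 1, which is why gradient descent on it pushes the prediction function toward the true distribution.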
Further, when the model is trained with federated learning, the parameter settings of the federated learning framework include: let $t$ denote the current federated training round, $t = 0, 1, \ldots, n$; the content communicated between the server and the edge devices is $(w_t, \tau)$ and $w_t^k$, where $w_t$ denotes the global model parameters of the $t$-th training round, $\tau$ denotes the number of local training iterations on an edge device, $w_t^k$ denotes the model parameters on edge device $k$ in the $t$-th training round, and $\eta$ denotes the learning rate used in edge device training.
Further, in step S3, the client devices that train in the current round are selected at random.
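A minimal Python sketch of the random selection in step S3, under the assumption that devices are identified by integer indices. The function `select_devices`, its parameters, and the fixed seed are illustrative and are not part of the disclosure.

```python
import random

def select_devices(num_devices, num_selected, seed=None):
    """Pick num_selected distinct device indices uniformly at random."""
    if not 0 < num_selected <= num_devices:
        raise ValueError("num_selected must be between 1 and num_devices")
    rng = random.Random(seed)  # a local generator keeps the choice reproducible
    return sorted(rng.sample(range(num_devices), num_selected))

# Example: choose 10 of 100 edge devices for the current round.
selected = select_devices(num_devices=100, num_selected=10, seed=42)
```

Sampling without replacement guarantees that no device is asked to train twice in the same round.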
Further, in step S4, the edge device trains the model parameters locally according to the preset parameters, and after the local model training is completed, the edge device uploads the model parameters to the central server.
Further, in step S5, the central server uses the reinforcement learning model to evaluate the contribution of each edge device to the global model according to the features of the edge device model parameters and of the current global model parameters, specifically as follows:
a reinforcement learning agent is first pre-trained at the central server:
State: let $K$ devices participate in training in each round; the state of round $t$ is expressed as $S_t = (w_t, w_t^1, \ldots, w_t^K)$, where $w_t^k$ denotes the model parameters of the $k$-th device participating in training in round $t$ and $w_t$ denotes the global model parameters of round $t$. The state of each round comprises the features of the model parameters uploaded by the $K$ edge devices and of the current global model parameters. At the end of each round of federated learning training, each edge device transmits its local model parameters to the central server; before the global model is updated, the aggregation weight of each participating device must first be evaluated;
Action: when the global parameters are aggregated with weights in each round of federated learning, the proportion of each edge device model in the global model aggregation is estimated from the current parameter state $S_t$. The value function of the DDQN learns an approximation of the optimal action value $Q^*(S_t, a)$, where $a$ denotes the action predicted by the reinforcement learning agent, so that the DDQN neural network outputs the contribution of each edge device participating in training to the current global model. Aggregation weights are then assigned to the edge devices participating in global model aggregation according to the effect of each edge device on the global model, which speeds up aggregation of the global model;
Reward: the reward observed at the end of each round $t$ is set to

$$r_t = m^{(\Phi_t - \Omega)} - 1$$

where $t$ denotes the number of global iterations, $\Phi_t$ is the test accuracy achieved by the global model on the held-out validation set after round $t$, $\Omega$ is the target accuracy, and $m$ ($m > 1$) is a positive constant that ensures $r_t$ grows exponentially with the test accuracy. Because $0 \le \Phi_t \le \Omega \le 1$, $r_t \in (-1, 0]$.
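The behavior of the reward can be checked numerically. The sketch below assumes the reward has the form $r_t = m^{(\Phi_t - \Omega)} - 1$, which is consistent with the stated range $(-1, 0]$ and with exponential growth in the test accuracy. The function name `reward` and the value `m = 64.0` are assumptions chosen only for illustration.

```python
def reward(accuracy, target_accuracy, m=64.0):
    """Reward r_t = m ** (accuracy - target_accuracy) - 1; m = 64 is an assumed value."""
    if m <= 1:
        raise ValueError("m must be greater than 1")
    if not 0.0 <= accuracy <= target_accuracy <= 1.0:
        raise ValueError("expected 0 <= accuracy <= target_accuracy <= 1")
    return m ** (accuracy - target_accuracy) - 1.0

# Rewards for increasing test accuracy with a target accuracy of 0.99.
rewards = [reward(acc, 0.99) for acc in (0.10, 0.50, 0.90, 0.99)]
```

The reward is negative in every round that misses the target and reaches 0 only when the target accuracy is met, so an agent maximizing the cumulative reward is pushed to reach the target in as few rounds as possible.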
Further, the reinforcement learning agent in the central server outputs the contribution $c_t^k$ of client $k$ to the global model in round $t$; a weighted average is then taken, and the weight $\alpha_t^k$ that client $k$ should have in the aggregation of the global model in round $t$ is calculated as

$$\alpha_t^k = \frac{c_t^k}{\sum_{i=1}^{n} c_t^i}$$

where $\sum_{i=1}^{n} c_t^i$ denotes the sum of the contributions of the $n$ devices participating in round $t$ of federated learning training, and $c_t^k$ denotes the contribution of client device $k$ participating in round $t$;
global weighted aggregation is performed according to the estimated contributions:

$$w_{t+1} = \sum_{k=1}^{n} \alpha_t^k w_t^k$$

where $w_t^k$ denotes the model parameters on the $k$-th device in round $t$ and $w_{t+1}$ denotes the global model parameters of round $t+1$. After the weighted aggregation of the global model is completed, the above steps are repeated until the model converges.
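The contribution-weighted aggregation can be sketched in Python as follows. The sketch assumes each local model is a flat list of floats and that the agent's contribution scores are non-negative; the function names, the equal-weight fallback for an all-zero score vector, and the example numbers are assumptions made for illustration.

```python
def aggregation_weights(contributions):
    """Normalize non-negative contribution scores so that they sum to 1."""
    if any(c < 0 for c in contributions):
        raise ValueError("contributions are assumed to be non-negative")
    total = sum(contributions)
    if total == 0:
        # Assumed fallback: equal weights when no device is credited.
        return [1.0 / len(contributions)] * len(contributions)
    return [c / total for c in contributions]

def weighted_aggregate(local_models, weights):
    """Return sum_k weights[k] * local_models[k], coordinate by coordinate."""
    dim = len(local_models[0])
    return [sum(w * model[i] for w, model in zip(weights, local_models))
            for i in range(dim)]

# Three devices; the third uploads an outlying model and is scored 0.
local_models = [[1.0, 2.0], [3.0, 4.0], [100.0, -100.0]]
contributions = [3.0, 1.0, 0.0]  # hypothetical agent outputs
weights = aggregation_weights(contributions)
new_global = weighted_aggregate(local_models, weights)
```

In this example the outlying third model receives weight 0 and leaves the aggregate untouched, which illustrates how a low contribution score limits the influence of a harmful update.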
A system for implementing the reinforcement learning-based federated learning optimization method comprises:
a device selection unit: in each training round, devices are selected at random to participate in that round of training;
a device contribution evaluation unit: after the device model parameters are uploaded, the central server evaluates the contribution of each local device with a reinforcement learning algorithm according to the features of each local model's parameters;
a global model updating unit: when the global model is updated, the current global model is updated by taking a weighted average of the model parameters according to the contribution of each device to the global model;
in the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and input into the reinforcement learning agent for contribution evaluation, and the central server aggregates the global model according to the size of each device's contribution using a set weighted-average algorithm.
Compared with the prior art, the invention has the following beneficial effects:
(1) The invention provides a DDQN-based heterogeneous federated learning training method that optimizes the aggregation process of the global model. The DDQN model evaluates the contribution of each device participating in the current round of training and adaptively adjusts the weight of each model in the aggregation, which greatly accelerates model aggregation, improves training speed, reduces the communication rounds required for training, and saves communication resources.
(2) The method and the system can accurately evaluate the contribution of each device to the global model, and can further provide different rewards to the devices participating in training according to their contributions.
(3) The invention can avoid the influence of malicious data on model training, so that the model can complete training normally.
Drawings
FIG. 1 is a schematic diagram of a system in an embodiment of the present invention;
FIG. 2 is a flow chart of a reinforcement learning-based federated learning optimization method in accordance with an embodiment of the present invention;
FIG. 3 is a block diagram of a reinforcement learning-based federated learning optimization system according to an embodiment of the present invention.
Detailed Description
In order that those skilled in the art will better understand the present invention, the following description will be given in detail with reference to the accompanying drawings and detailed description. It will be apparent that the described embodiments are only some, but not all, embodiments of the invention. All other embodiments, which can be made by those skilled in the art based on the embodiments of the invention without making any inventive effort, are intended to be within the scope of the invention.
The invention is further described below by means of specific embodiments.
By studying the federated learning optimization problem in edge computing, the invention provides a DDQN-based federated learning training method for heterogeneous resources (edge devices whose data differ). The method also protects users' private data, avoids the influence of malicious data on experiments, improves communication efficiency, and accelerates training.
As shown in FIG. 1, the reinforcement learning-based federated learning optimization system comprises a central server and edge devices, which cooperatively train a machine learning model. The edge devices do not need to upload data to the central server; only the model parameters trained for the machine learning task are transmitted between the edge devices and the central server.
Because federated learning is a sub-field of distributed computing, a machine learning model that can be trained at a central server can also be trained with the federated learning framework; that is, it can be used to train various tasks such as speech recognition, image recognition, and input method prediction models. Unlike traditional training of a machine learning model in a cloud center, federated learning does not need to upload data to the cloud: the machine learning model is trained locally on the user device, and the trained model is then uploaded to the central server for model aggregation. This example takes an image classification task as the machine learning task, with the model trained to recognize data from data sets such as MNIST, Fashion-MNIST, and CIFAR.
As shown in FIG. 2, the optimization method of the reinforcement learning-based federated learning optimization system in this embodiment includes the following steps:
S1, determining the machine learning training task, such as speech recognition, image recognition, or an input method prediction model. Because federated learning is a sub-field of distributed computing, a machine learning model that can be trained at a central server can also be trained through federated learning. In the federated learning framework, edge devices may be smart user devices such as mobile phones and computers, or Internet of Things devices. Unlike traditional training of a machine learning model in a cloud center, federated learning does not need to upload data to the cloud: the machine learning model is trained locally on the user device, and the trained model is then uploaded to the central server for model aggregation.
The basic parameters of the federated learning framework are set as follows. Let $t$ denote the current federated learning training round ($t = 0, 1, \ldots, n$). The other parameters of the framework include the content communicated between the server and the edge devices, $(w_t, \tau)$ and $w_t^k$, where $w_t$ denotes the global model parameters of the $t$-th training round (i.e. the central server model parameters), $\tau$ denotes the number of local training iterations on an edge device, $w_t^k$ denotes the model parameters on edge device $k$ in the $t$-th training round, and $\eta$ denotes the learning rate used in edge device training.
As a preferred embodiment, a LeNet-5 convolutional neural network is used to train the model, and the machine learning training is described as follows. The task is defined on a compact feature space $\mathcal{X}$ and a label set $\mathcal{Y} = \{1, \ldots, B\}$, where $B$ is the total number of labels. Let $(x, y)$ denote a particular labeled sample, with $x \in \mathcal{X}$ and $y \in \mathcal{Y}$. Let $f_w: \mathcal{X} \to \mathcal{S}$ denote the prediction function, where $\mathcal{S} = \{ z \mid \sum_{j \in \mathcal{Y}} z_j = 1,\ z_j \ge 0 \}$ is the set of probability vectors over the labels and $w$ is the parameter of the prediction function, namely the trained model parameters. The function $f_w$ produces a probability vector for each sample, where $f_{w,j}(x)$ is the predicted probability that the sample belongs to the $j$-th class. The loss function of the prediction function can thus be defined as the cross entropy:

$$\ell(w) = -\sum_{j \in \mathcal{Y}} p(y = j)\, \mathbb{E}_{x \mid y = j}\left[\log f_{w,j}(x)\right]$$

where $p$ represents the actual probability distribution that the training target needs to learn.
The learning problem is to solve the following optimization problem:

$$\min_{w} \ell(w)$$
The federated learning system has $M$ edge devices in total (the edge devices may be smart user devices such as mobile phones and computers, or Internet of Things devices) and one central server, and each edge device $k$ holds a data volume $d_k$. The central server is responsible for initializing the model and coordinating the edge devices. In each round $t$, a number of edge devices are selected; each downloads the current global model weight parameters from the server and performs the following stochastic gradient descent (SGD) training locally:

$$w_t^k = w_{t-1} - \eta \nabla \ell(w_{t-1})$$

where $w_t^k$ denotes the model parameters produced by training on edge device $k$ in round $t$, $w_{t-1}$ denotes the global model parameters, $\eta$ is the learning rate, and $\nabla \ell(w_{t-1})$ denotes the gradient of the loss function. The loss function $\ell(w_{t-1})$ is expressed as the cross entropy

$$\ell(w_{t-1}) = -\sum_{j=1}^{C} p(y = j)\, \mathbb{E}_{x \mid y = j}\left[\log f_{w_{t-1},j}(x)\right]$$

where $f_{w_{t-1},j}(x)$ is the predicted probability that sample $x$ belongs to the $j$-th label class and $C$ is the total number of labels on the participating device $k$.
All edge devices upload their trained model parameters to the central server. The central server estimates the contribution of each edge device using the reinforcement learning agent and aggregates the uploaded model parameters with the corresponding weights to generate a new global model. These steps are repeated until the global model converges.
The specific implementation flow is as follows:
As stated above, the federated system on which the invention is based has $M$ edge devices and one central server. Each edge device $k$ has a local data set $D_k$ whose size is denoted $d_k$, from which the total size of all data sets, $\sum_{k=1}^{M} d_k$, is obtained.
Initially, the server randomly selects $C$ devices and stores them in a set of selected devices; the server then sends the current global model parameters $w$ and the local iteration number $\tau$ to the devices in this set;
each device $i$ that receives the parameters performs $\tau$ iterations of training on its local data set with the stochastic gradient descent (SGD) algorithm, starting from the model parameters $w$, to obtain the locally trained model parameters $w_i$. Each training iteration is $w_i \leftarrow w_i - \eta \nabla \ell(w_i)$, where $\nabla \ell(w_i)$ denotes the gradient of the loss function with respect to the model parameters. Finally, $w_i$ is packaged and sent to the server, and the global model aggregation stage begins.
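The local training step can be sketched in Python as below. To stay self-contained, the sketch replaces the LeNet-5 network and cross-entropy loss of the embodiment with a toy linear model and a squared loss; that substitution, the function name `local_sgd`, the sample data, and the values of `tau` and `eta` are all assumptions made for illustration.

```python
import random

def local_sgd(global_w, samples, tau, eta, seed=0):
    """Run tau SGD steps from global_w on the toy loss (w.x - y)^2 / 2."""
    rng = random.Random(seed)
    w = list(global_w)  # copy so the downloaded global parameters stay untouched
    for _ in range(tau):
        x, y = rng.choice(samples)  # one random local sample per iteration
        error = sum(wi * xi for wi, xi in zip(w, x)) - y
        # Gradient of the toy loss with respect to w is error * x.
        w = [wi - eta * error * xi for wi, xi in zip(w, x)]
    return w

# Local data of one device, generated by the (hidden) parameters [2, -1].
samples = [([1.0, 0.0], 2.0), ([0.0, 1.0], -1.0), ([1.0, 1.0], 1.0)]
global_w = [0.0, 0.0]
local_w = local_sgd(global_w, samples, tau=500, eta=0.1)
```

After the loop, `local_w` plays the role of $w_i$: it is the only thing the device would upload, while the raw samples stay on the device.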
In FedAvg, the classical federated learning algorithm, the central server updates the global model as

$$w_t = \sum_{i} \frac{d_i}{d}\, w_t^i$$

As in the formulas above, $w_t$ denotes the global model parameters of the current round, $w_t^i$ denotes the model parameters uploaded by edge device $i$ in round $t$, $d_i$ is the data size on edge device $i$, and $d$ is the total data size of all edge devices participating in the current round of training. This is how FedAvg aggregates models: it weights the aggregation by the data volume on each user device. Extensive experiments show that this works very well on independent and identically distributed (IID) data. On non-IID data, that is, in a heterogeneous environment, the edge devices differ greatly in user habits, environment, computing power, and storage, so data volume represents neither the actual data quality nor the contribution to the global model. Conventional algorithms such as FedAvg therefore perform poorly in heterogeneous environments.
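For comparison, the data-size weighting of FedAvg can be sketched in a few lines of Python. Models are assumed to be flat lists of floats; the function name `fedavg` and the example numbers, including the deliberately outlying third device, are illustrative assumptions.

```python
def fedavg(local_models, data_sizes):
    """FedAvg: average local parameter vectors with weights d_i / d."""
    d = sum(data_sizes)  # total data size of the devices in this round
    dim = len(local_models[0])
    return [sum(d_i / d * model[j] for d_i, model in zip(data_sizes, local_models))
            for j in range(dim)]

# The third device holds the most data, but its model is an outlier.
local_models = [[1.0, 1.0], [1.2, 0.8], [9.0, -9.0]]
data_sizes = [100, 100, 800]
global_w = fedavg(local_models, data_sizes)
```

Because the third device holds 80% of the data, its outlying model dominates the average. This is the weakness of volume-based weighting on heterogeneous data that the passage above describes.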
To address data heterogeneity, the invention improves the global model aggregation in federated learning: the central server uses a trained reinforcement learning agent to evaluate the contribution of the edge devices participating in training, and aggregates the models with weights according to their actual contributions, so as to accelerate model convergence and reduce the number of training rounds. The specific flow is as follows:
In the invention, a reinforcement learning agent is pre-trained at the central server. As a preferred embodiment, the reinforcement learning algorithm is designed as follows:
status: assuming that there are K devices per round to train, the state of each round can be expressed as Model parameters representing the kth device involved in training for each round, w t The method is characterized in that the method is a current global model parameter, the state of each round comprises parameters of k training models and global model parameters, if the model parameters are actually too large, the model parameters are subjected to dimension reduction treatment, and then the model parameters are input into the reinforcement learning agent for effect evaluation. Because the device model parameters actually reflect the training results of the model, the relevant information of the training effects of the edge devices can be learned through the parameter characteristics. And only after the local training of each round of edge equipment is finished, each edge equipment predicts the weight of each edge equipment and updates the global model when the global model is weighted and aggregated after uploading the model parameters to a server.
Action: in federated learning, each time the global parameters are aggregated by weight, the share of each edge device in the aggregation must be estimated from the current parameter state $S_t$. A neural network learns an approximation of the optimal action-value function $Q^*(S_t, a)$. With continued training, the DDQN neural network gradually learns to output the action that maximizes the total return, that is, in state $S_t$ it outputs the predicted action a that maximizes $Q^*(S_t, a)$. The output action is the estimated contribution of each edge device to the global model. Accurately estimating the effect of each edge device and using it for the weighted aggregation lets the global model converge faster, reach higher accuracy for the same number of iterations, and need fewer communication rounds.
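The value-learning step can be illustrated with a small sketch. It assumes that "DDQN" refers to Double DQN, where the online network selects the next action and a separate target network scores it; the text does not spell this out, nor how a discrete action maps to per-device contributions. The arrays below are hypothetical stand-ins for network outputs, and `gamma` is an assumed discount factor.

```python
import numpy as np

def greedy_action(q_values):
    """Pick the action index with the highest estimated value Q(S_t, a)."""
    return int(np.argmax(q_values))

def double_dqn_target(reward, next_q_online, next_q_target, gamma=0.99, done=False):
    """Double DQN target: the online net selects the action, the target net scores it."""
    if done:
        return float(reward)
    best_next = int(np.argmax(next_q_online))            # selection by the online network
    return float(reward + gamma * next_q_target[best_next])  # evaluation by the target network

# Hypothetical Q-value outputs for the next state S_{t+1} over three actions.
q_online_next = np.array([0.2, 0.9, 0.1])
q_target_next = np.array([0.5, 0.3, 0.8])
target = double_dqn_target(reward=-0.4, next_q_online=q_online_next,
                           next_q_target=q_target_next, gamma=0.9)
```

Separating selection from evaluation is what distinguishes Double DQN from plain DQN, which would use the maximum of `q_target_next` directly and tends to overestimate values.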
Reward: the reward observed at the end of each round t is set to

$$r_t = m^{(\Phi_t - \Omega)} - 1$$

where t is the global iteration number, $\Phi_t$ is the test accuracy achieved by the global model on the held-out validation set after round t, $\Omega$ is the target accuracy, i.e. the accuracy that the trained model is expected to reach, and m (m > 1) is a positive constant that makes $r_t$ grow exponentially with the test accuracy. Because $0 \le \Phi_t \le \Omega \le 1$, it follows that $r_t \in (-1, 0]$.
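A short sketch of a reward with the stated properties follows. It assumes the form $r_t = m^{(\Phi_t - \Omega)} - 1$, which is consistent with the description (exponential growth in accuracy, m > 1, range (-1, 0]); the function name and the default m = 64 are illustrative choices, not values from the patent.

```python
def round_reward(accuracy, target_accuracy, m=64.0):
    """Reward r_t = m ** (accuracy - target_accuracy) - 1, which lies in (-1, 0]."""
    if m <= 1:
        raise ValueError("m must be greater than 1")
    if not 0.0 <= accuracy <= target_accuracy <= 1.0:
        raise ValueError("expected 0 <= accuracy <= target_accuracy <= 1")
    return m ** (accuracy - target_accuracy) - 1.0

# The reward rises toward 0 as the validation accuracy approaches the target.
rewards = [round_reward(acc, 0.99) for acc in (0.50, 0.80, 0.95, 0.99)]
```

Because every round before the target is reached yields a negative reward, the agent is pushed to reach the target accuracy in as few rounds as possible.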
The central server pre-trains the agent using the sample data set it retains. Because the central server has ample computing power, this process is fast.
The central server then uses the agent's (the reinforcement learning model's) evaluation of each edge device's contribution to improve the federated learning global model aggregation process. The reinforcement learning agent on the central server outputs the contribution $c_t^k$ of client k to the global model in round t. The contributions are then normalized to obtain the weight $\alpha_t^k$ that client k should receive in the global model aggregation of round t:

$$\alpha_t^k = \frac{c_t^k}{\sum_{i=1}^{n} c_t^i}$$

Here $\sum_{i=1}^{n} c_t^i$ is the sum of the contributions of the n devices participating in round t of federated learning training, and $c_t^k$ is the contribution of client device k in that round.

Global weighted aggregation is then performed according to the estimated contributions:

$$w_{t+1} = \sum_{k=1}^{n} \alpha_t^k \, w_t^k$$

Here $w_t^k$ denotes the model parameters on the k-th device in round t, and $w_{t+1}$ denotes the global model parameters for round t+1. After the weighted aggregation of the global model is complete, the above steps are repeated until the model converges.
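The two steps just described, normalizing the contributions into weights and then taking the weighted sum of the device models, can be sketched as follows. The function names are hypothetical, models are assumed to be flattened NumPy vectors, and the contributions are assumed to be non-negative scores, which the text does not state explicitly.

```python
import numpy as np

def contribution_weights(contributions):
    """Normalize per-device contributions so that the weights sum to 1."""
    c = np.asarray(contributions, dtype=float)
    if np.any(c < 0) or c.sum() <= 0:
        raise ValueError("contributions must be non-negative with a positive sum")
    return c / c.sum()

def aggregate_by_contribution(client_params, contributions):
    """Next global model: sum over devices of weight_k * client parameters_k."""
    weights = contribution_weights(contributions)
    stacked = np.stack([np.ravel(p) for p in client_params]).astype(float)
    return weights @ stacked

# Two hypothetical devices; the first is judged three times as useful.
uploaded = [np.array([1.0, 1.0]), np.array([5.0, 9.0])]
next_global = aggregate_by_contribution(uploaded, contributions=[3.0, 1.0])
```

Compared with data-size weighting, only the source of the weights changes; the aggregation itself remains a convex combination of the uploaded models.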
As shown in Fig. 3, a system for the reinforcement learning-based federated learning optimization method comprises:
Device selection unit 301: in each training round, the system selects a subset of devices to participate. To give every device an equal opportunity to participate and to make use of the data on every device, the devices are selected at random.
Device contribution evaluation unit 302: after the device model parameters are uploaded, the central server evaluates the contribution of each local device with a reinforcement learning algorithm, based on the features of each local model's parameters.
Global model updating unit 303: the current global model is updated by taking a weighted average of the device model parameters, weighted by each device's contribution to the global model.
In the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and input into the reinforcement learning agent for contribution evaluation. Control then passes to the global model updating unit, where the central server aggregates the global model according to each device's contribution using the configured weighted average algorithm.
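How the three units fit together in one round can be sketched as below. Everything here is illustrative: the function names are hypothetical, the feature extraction (difference from the global model) is an assumption, and `placeholder_agent` is a stand-in scoring rule, not the trained reinforcement learning agent.

```python
import numpy as np

def select_devices(num_devices, num_selected, rng):
    """Unit 301: choose devices uniformly at random, without replacement."""
    return rng.choice(num_devices, size=num_selected, replace=False)

def evaluate_contributions(client_params, global_params, agent):
    """Unit 302: extract features from uploaded parameters and score them with the agent."""
    features = [np.ravel(p) - np.ravel(global_params) for p in client_params]
    return np.array([agent(f) for f in features], dtype=float)

def update_global(client_params, contributions):
    """Unit 303: contribution-weighted average of the uploaded parameters."""
    weights = contributions / contributions.sum()
    return weights @ np.stack([np.ravel(p) for p in client_params])

def placeholder_agent(feature):
    # NOT the trained agent: a stand-in that simply prefers small updates.
    return 1.0 / (1.0 + np.linalg.norm(feature))

rng = np.random.default_rng(0)
global_w = np.zeros(3)
all_device_params = [np.full(3, float(i)) for i in range(10)]  # pretend local training results
chosen = select_devices(len(all_device_params), 4, rng)
uploaded = [all_device_params[i] for i in chosen]
scores = evaluate_contributions(uploaded, global_w, placeholder_agent)
new_global_w = update_global(uploaded, scores)
```

Keeping the agent behind a plain callable makes it straightforward to swap the placeholder for a trained network without touching the selection or update units.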
The embodiments described above are preferred embodiments of the present invention, but the embodiments of the present invention are not limited to them. Any other change, modification, substitution, combination, or simplification made without departing from the spirit and principles of the present invention is an equivalent replacement and is included in the scope of protection of the present invention.
Claims (10)
1. A reinforcement learning-based federated learning optimization method, characterized by comprising the following steps:
S1, determining the data set and image task for machine learning training;
S2, initializing, at the central server, the neural network model parameters for the corresponding task, i.e. initializing a global model;
S3, the central server selecting the edge devices that are to participate in the current round of training, and the selected edge devices downloading the global model;
S4, under the federated learning framework, the edge devices training the model parameters locally and uploading the model parameters after training is complete;
S5, the central server evaluating the contribution of the edge devices participating in training, and aggregating the edge device models with weights according to contribution;
S6, repeating steps S3 to S5 until the trained model converges.
2. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S1 the edge devices include user smart devices and Internet of Things devices.
3. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the model initialization comprises: according to the data set image recognition task of step S1, training on the data set with a LeNet-5 convolutional neural network and initializing the model parameters; letting the machine learning model parameters be w; defining a feature space X and a label set $Y = \{1, \ldots, N\}$, where N is the total number of labels; letting (x, y) denote a labeled sample with $x \in X$ and $y \in Y$; letting $f_w : X \to S$ denote the prediction function, where $S = \{ z \mid \sum_{j \in Y} z_j = 1,\ z_j \ge 0 \}$ is the set of probability vectors over the labels, $z_j$ is the probability that a sample is classified as label j, and w is the trained machine learning model parameters; the function $f_w$ produces a probability vector for each sample.
4. The reinforcement learning-based federated learning optimization method according to claim 2, wherein the loss function of the prediction function for the machine learning task being trained is defined as the cross entropy:

$$L(w) = -\frac{1}{D} \sum_{x} \sum_{j \in Y} p(y = j \mid x) \, \log f_w^{(j)}(x)$$

where p denotes the true probability distribution, $f_w^{(j)}(x)$ is the predicted probability that sample x belongs to label j, D is the total number of labeled samples, and stochastic gradient descent is used during training to reduce the value of the loss function until the model converges.
5. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the parameter settings of the federated learning framework when training the model with federated learning comprise: letting t denote the current federated training round; the content communicated between the server and the edge devices is $(w_t, \tau)$ and $w_t^k$, where t = 0, 1, ..., n, $w_t$ is the global model parameters of training round t, $\tau$ is the number of local training iterations of the edge device model, $w_t^k$ is the model parameters on edge device k during training round t, and $\eta$ is the learning rate used in edge device training.
6. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S3 the client devices trained in the current round are selected at random.
7. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S4 the edge device trains the model parameters locally according to preset parameters and, after local model training is complete, uploads the model parameters to the central server.
8. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S5 the central server uses a reinforcement learning agent to evaluate the contribution of each edge device to the global model, based on the features of the edge device model parameters and of the current global model parameters, specifically as follows:
first, pre-training a reinforcement learning agent at the central server:
State: let K devices be trained in each round; the state of round t is expressed as $S_t = (w_t, w_t^1, w_t^2, \ldots, w_t^K)$, where $w_t^k$ denotes the model parameters of the k-th device participating in training in round t and $w_t$ denotes the global model parameters of round t; the state of each round comprises the features of the model parameters uploaded by the K edge devices and of the current global model parameters; at the end of each round of federated learning training, each edge device transmits its local model parameters to the central server, and before the global model is updated, the aggregation weight of each participating user device is first evaluated, after which the global model is updated;
Action: in federated learning, each time the global parameters are aggregated by weight, the share of each edge device model in the global model aggregation is estimated from the current parameter state $S_t$; the value evaluation function in the DDQN learns an approximation of the optimal action value $Q^*(S_t, a)$, where a denotes the predicted action made by the reinforcement learning agent, so that the DDQN neural network outputs the contribution of each edge device participating in training to the current global model; the edge devices participating in global model aggregation are assigned aggregation weights according to each device's contribution to the global model, which speeds up aggregation of the global model;
Reward: the reward observed at the end of each round t is set to

$$r_t = m^{(\Phi_t - \Omega)} - 1$$

where t is the global iteration number, $\Phi_t$ is the test accuracy achieved by the global model on the held-out validation set after round t, $\Omega$ is the target accuracy, and m > 1 is a positive constant that makes $r_t$ grow exponentially with the test accuracy; since $0 \le \Phi_t \le \Omega \le 1$, $r_t \in (-1, 0]$.
9. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the reinforcement learning agent in the central server outputs the contribution $c_t^k$ of client k to the global model in round t, and the contributions are then normalized to obtain the weight $\alpha_t^k$ that client k should receive in the global model aggregation of round t:

$$\alpha_t^k = \frac{c_t^k}{\sum_{i=1}^{n} c_t^i}$$

where $\sum_{i=1}^{n} c_t^i$ is the sum of the contributions of the n devices participating in round t of federated learning training, and $c_t^k$ is the contribution of client device k in that round;
global weighted aggregation is performed according to the estimated contributions:

$$w_{t+1} = \sum_{k=1}^{n} \alpha_t^k \, w_t^k$$

where $w_t^k$ denotes the model parameters on the k-th device in round t, and $w_{t+1}$ denotes the global model parameters for round t+1. After the weighted aggregation of the global model is complete, the above steps are repeated until the model converges.
10. A system for implementing the reinforcement learning-based federated learning optimization method of claim 1, comprising:
a device selection unit: in each training round, selecting the devices that participate in the round at random;
a device contribution evaluation unit: after the device model parameters are uploaded, the central server evaluates the contribution of each local device with a reinforcement learning algorithm, based on the features of each local model's parameters;
a global model updating unit: the current global model is updated by taking a weighted average of the device model parameters, weighted by each device's contribution to the global model;
wherein, in the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and input into the reinforcement learning agent for contribution evaluation, and the central server aggregates the global model according to each device's contribution using the configured weighted average algorithm.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310230326.2A CN116523079A (en) | 2023-03-10 | 2023-03-10 | Reinforced learning-based federal learning optimization method and system |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116523079A (en) | 2023-08-01 |
Family
ID=87405400
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310230326.2A Pending CN116523079A (en) | 2023-03-10 | 2023-03-10 | Reinforced learning-based federal learning optimization method and system |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116523079A (en) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117094381A (en) * | 2023-08-21 | 2023-11-21 | 哈尔滨工业大学 | Multi-mode federal collaboration method taking high-efficiency communication and individuation into consideration |
CN117273119A (en) * | 2023-08-24 | 2023-12-22 | 北京邮电大学 | Dynamic fairness federal learning method and device based on reinforcement learning |
CN117313835A (en) * | 2023-10-08 | 2023-12-29 | 湖北大学 | Federal learning method based on client contribution clearance in heterogeneous data environment |
CN117392483A (en) * | 2023-12-06 | 2024-01-12 | 山东大学 | Album classification model training acceleration method, system and medium based on reinforcement learning |
CN117575291A (en) * | 2024-01-15 | 2024-02-20 | 湖南科技大学 | Federal learning data collaborative management method based on edge parameter entropy |
CN117592555A (en) * | 2023-11-28 | 2024-02-23 | 中国医学科学院北京协和医院 | Federal learning method and system for multi-source heterogeneous medical data |
CN117725965A (en) * | 2024-02-06 | 2024-03-19 | 湘江实验室 | Federal edge data communication method based on tensor mask semantic communication |
- 2023-03-10 CN CN202310230326.2A patent/CN116523079A/en active Pending
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117094381A (en) * | 2023-08-21 | 2023-11-21 | 哈尔滨工业大学 | Multi-mode federal collaboration method taking high-efficiency communication and individuation into consideration |
CN117094381B (en) * | 2023-08-21 | 2024-04-12 | 哈尔滨工业大学 | Multi-mode federal collaboration method taking high-efficiency communication and individuation into consideration |
CN117273119A (en) * | 2023-08-24 | 2023-12-22 | 北京邮电大学 | Dynamic fairness federal learning method and device based on reinforcement learning |
CN117313835A (en) * | 2023-10-08 | 2023-12-29 | 湖北大学 | Federal learning method based on client contribution clearance in heterogeneous data environment |
CN117592555A (en) * | 2023-11-28 | 2024-02-23 | 中国医学科学院北京协和医院 | Federal learning method and system for multi-source heterogeneous medical data |
CN117592555B (en) * | 2023-11-28 | 2024-05-10 | 中国医学科学院北京协和医院 | Federal learning method and system for multi-source heterogeneous medical data |
CN117392483A (en) * | 2023-12-06 | 2024-01-12 | 山东大学 | Album classification model training acceleration method, system and medium based on reinforcement learning |
CN117392483B (en) * | 2023-12-06 | 2024-02-23 | 山东大学 | Album classification model training acceleration method, system and medium based on reinforcement learning |
CN117575291A (en) * | 2024-01-15 | 2024-02-20 | 湖南科技大学 | Federal learning data collaborative management method based on edge parameter entropy |
CN117575291B (en) * | 2024-01-15 | 2024-05-10 | 湖南科技大学 | Federal learning data collaborative management method based on edge parameter entropy |
CN117725965A (en) * | 2024-02-06 | 2024-03-19 | 湘江实验室 | Federal edge data communication method based on tensor mask semantic communication |
CN117725965B (en) * | 2024-02-06 | 2024-05-14 | 湘江实验室 | Federal edge data communication method based on tensor mask semantic communication |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116523079A (en) | Reinforced learning-based federal learning optimization method and system | |
CN109948029B (en) | Neural network self-adaptive depth Hash image searching method | |
CN113191484B (en) | Federal learning client intelligent selection method and system based on deep reinforcement learning | |
WO2018161468A1 (en) | Global optimization, searching and machine learning method based on lamarck acquired genetic principle | |
CN110851782A (en) | Network flow prediction method based on lightweight spatiotemporal deep learning model | |
EP4350572A1 (en) | Method, apparatus and system for generating neural network model, devices, medium and program product | |
CN111277434A (en) | Network flow multi-step prediction method based on VMD and LSTM | |
WO2022126448A1 (en) | Neural architecture search method and system based on evolutionary learning | |
CN111158912A (en) | Task unloading decision method based on deep learning in cloud and mist collaborative computing environment | |
US20220318412A1 (en) | Privacy-aware pruning in machine learning | |
CN115374853A (en) | Asynchronous federal learning method and system based on T-Step polymerization algorithm | |
Long et al. | Fedsiam: Towards adaptive federated semi-supervised learning | |
CN112836822A (en) | Federal learning strategy optimization method and device based on width learning | |
CN114819143A (en) | Model compression method suitable for communication network field maintenance | |
CN115359298A (en) | Sparse neural network-based federal meta-learning image classification method | |
CN116645130A (en) | Automobile order demand prediction method based on combination of federal learning and GRU | |
Kozat et al. | Universal switching linear least squares prediction | |
CN111832817A (en) | Small world echo state network time sequence prediction method based on MCP penalty function | |
CN113194493B (en) | Wireless network data missing attribute recovery method and device based on graph neural network | |
Adeleke | Echo-state networks for network traffic prediction | |
CN113128432B (en) | Machine vision multitask neural network architecture searching method based on evolution calculation | |
Gong et al. | Compressed particle-based federated bayesian learning and unlearning | |
CN116976405A (en) | Variable component shadow quantum neural network based on immune optimization algorithm | |
CN114819196B (en) | Noise distillation-based federal learning system and method | |
CN110768825A (en) | Service flow prediction method based on network big data analysis |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||