CN116523079A - Reinforced learning-based federal learning optimization method and system - Google Patents

Reinforced learning-based federal learning optimization method and system

Info

Publication number
CN116523079A
CN116523079A (Application CN202310230326.2A)
Authority
CN
China
Prior art keywords
model
training
learning
round
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310230326.2A
Other languages
Chinese (zh)
Inventor
郭恩威
汪秀敏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
South China University of Technology SCUT
Original Assignee
South China University of Technology SCUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by South China University of Technology SCUT filed Critical South China University of Technology SCUT
Priority to CN202310230326.2A priority Critical patent/CN116523079A/en
Publication of CN116523079A publication Critical patent/CN116523079A/en
Pending legal-status Critical Current

Classifications

    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00 - Machine learning
    • G06N20/20 - Ensemble learning
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00 - Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60 - Protecting data
    • G06F21/62 - Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218 - Protecting access to data via a platform, e.g. using keys or access control rules, to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245 - Protecting personal data, e.g. for financial or medical purposes
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 - Computing arrangements based on biological models
    • G06N3/02 - Neural networks
    • G06N3/04 - Architecture, e.g. interconnection topology
    • G06N3/047 - Probabilistic or stochastic networks
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 - Computing arrangements based on biological models
    • G06N3/02 - Neural networks
    • G06N3/08 - Learning methods
    • G06N3/092 - Reinforcement learning
    • Y - GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02 - TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02T - CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00 - Road transport of goods or passengers
    • Y02T10/10 - Internal combustion engine [ICE] based vehicles
    • Y02T10/40 - Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Medical Informatics (AREA)
  • Molecular Biology (AREA)
  • Bioethics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Probability & Statistics with Applications (AREA)
  • Databases & Information Systems (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The invention discloses a reinforcement learning-based federated learning optimization method and system for heterogeneous resources. The method can estimate the contribution of each heterogeneous device to model training, greatly accelerate federated learning training, reduce the interference of malicious data with training, and protect the data privacy of the multiple participants.

Description

Reinforced learning-based federal learning optimization method and system
Technical Field
The invention belongs to the field of federated learning efficiency optimization in edge computing, and particularly relates to a federated learning efficiency optimization method and system for heterogeneous resources.
Background
With the rapid development of the internet and the internet of things, the number of intelligent devices has grown rapidly over the past few years, and the volume of data generated every day is now unprecedented. These data play an important role in deep learning models. In general, the development of artificial intelligence cannot be separated from big data: complex large-scale deep learning models often need massive data for iterative training to obtain an effective model, and in the past such data were transmitted to a cloud center server. However, as the amount of data increases and the trained deep learning models become more complex, offloading all the data to the cloud server overloads the server; it is impractical to transmit all the data from local devices to a remote server for further processing under the constraint of limited network bandwidth, and the transmission process also raises data privacy and security problems. In view of these key factors, the trend in data storage and analysis is moving from cloud-based centralization to distributed and on-device processing. A key enabling technology for this transition is edge computing, which offloads complex computing tasks or applications by providing computing resources close to the internet of things devices. Through edge computing, problems such as the insufficient computing capacity of internet of things devices, high offloading delay to a cloud server, and data security can be effectively alleviated.
Due to the rapid development of edge computing and distributed machine learning (S. Arisdakessian, O. A. Wahab, A. Mourad, H. Otrok, and N. Kara, "Intelligent Multi-criterion IoT-Fog Scheduling Approach Using Game Theory," IEEE/ACM Transactions on Networking, 2020), a new distributed machine learning paradigm, federated learning (H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, "Communication-efficient learning of deep networks from decentralized data," in Proceedings of Machine Learning Research, vol. 54, pp. 1273-1282, Apr. 2017), has emerged to enable local and distributed machine learning training at the edge-node or end-device level by co-training the machine learning model without exposing the raw data. Federated learning typically employs a parameter server architecture in which the terminals train a model that is synchronized by the parameter server.
Federated learning is a distributed machine learning method that can be trained on large amounts of decentralized data held on local devices. In particular, in the well-known scheme named FedAvg, a model is initialized at a central server, downloaded by the edge devices, iterated through multiple local updates (using gradient descent), and then aggregated globally at the server side by a weighted average of the models obtained from the edge devices; this distributed machine learning method is particularly suitable for the internet of things. Owing to these advantages, federated learning is widely applied in the internet of things and mobile edge computing. For example, deep reinforcement learning agents have been trained in a distributed fashion using federated learning, and next-word prediction on the Google keyboard has also been improved by federated learning.
Many existing federated learning algorithms tend to aggregate the global model according to the size of each device's data volume when performing model aggregation. However, the relationship between model accuracy and training data volume is nonlinear (Zhan Y., Li P., Wang K., et al., "Big Data Analytics by CrowdLearning: Architecture and Mechanism Design," IEEE Network, 2020, PP(99): 1-5); in practice the data distributions over the edge devices vary widely, and because of the nature of distributed training and the need to protect user data privacy, it is difficult to evaluate the magnitude of each device's effect on global model training under limited information. Therefore, it is necessary to accurately estimate the effect of each heterogeneous device on the current model and then scientifically allocate the weight of each model in aggregation, so as to speed up training and improve its communication efficiency.
Disclosure of Invention
The invention aims to overcome the defects of the prior art and provides a federated learning training method based on reinforcement learning, in which a reinforcement learning algorithm is used to evaluate the effect of each heterogeneous device on the trained global model, and the weight of each heterogeneous device in the global model aggregation is then decided according to the evaluation result. At the same time, malicious data can be effectively identified, and the influence of such malicious data on training is reduced. Experiments show that the method provided by the invention improves the federated learning training procedure in a practical environment and accelerates federated learning training.
The invention is realized at least by one of the following technical schemes.
A reinforcement learning-based federated learning optimization method comprises the following steps:
S1, determining the machine learning training task and its image data set;
S2, the central server initializes the neural network model parameters for the corresponding task, namely initializes the global model;
S3, the central server selects the edge devices that need to participate in the current round of training, and the selected edge devices download the global model;
S4, under the federated learning framework, the edge devices train the model parameters locally and upload the model parameters after the training is completed;
S5, the central server evaluates the contribution of the edge devices participating in training and performs weighted aggregation of the edge device models according to the contribution;
S6, steps S3-S5 are continuously repeated until the trained model converges (a sketch of this loop is given below).
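As a reading aid rather than part of the claimed method, the loop formed by steps S3 to S6 can be sketched in Python as follows; the helpers client_update, evaluate_contributions and weighted_aggregate are hypothetical stand-ins for the edge-side training, the reinforcement learning agent and the aggregation rule described later, and the models are assumed to be dicts of numpy arrays.

```python
import copy
import random

def run_federated_training(global_params, devices, client_update, evaluate_contributions,
                           weighted_aggregate, num_selected, tau, max_rounds=200, tol=1e-4):
    """Repeat steps S3-S5 until the global model converges (step S6)."""
    for t in range(max_rounds):
        selected = random.sample(devices, num_selected)                         # S3: random device selection
        local_params = [client_update(copy.deepcopy(global_params), dev, tau)   # S4: tau local iterations per device
                        for dev in selected]
        contributions = evaluate_contributions(local_params, global_params)     # S5: RL agent scores each device
        new_params = weighted_aggregate(local_params, contributions)            # S5: contribution-weighted aggregation
        if max(abs(new_params[k] - global_params[k]).max() for k in new_params) < tol:
            return new_params                                                   # S6: simple convergence check
        global_params = new_params
    return global_params
```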
Further, in step S1, the edge devices include user smart devices and internet of things devices.
Further, the model initialization includes: according to the image recognition task and data set of step S1, a LeNet-5 convolutional neural network is adopted for training and the model parameters are initialized; the machine learning model parameters are denoted w, and a feature space X and a label set Y = {1, ..., B} are defined, where B is the total number of labels; (x, y) denotes a labeled sample, x ∈ X, y ∈ Y; f_w: X → S denotes the prediction function, where S = {z | Σ_{j∈Y} z_j = 1, z_j ≥ 0} is the set of probability vectors over the labels, z_j is the probability that a sample is classified as label j, and w are the trained machine learning model parameters; the function f_w generates a probability vector for each sample.
Further, the loss function of the prediction function for the machine learning task trained by federated learning is defined as the cross entropy l(w) = -(1/D) Σ_x Σ_{j∈Y} p_j log f_w^j(x),
where p represents the true probability distribution, f_w^j(x) is the predicted probability that sample x belongs to the j-th class of labels, and D is the total number of labeled samples; a stochastic gradient descent strategy is adopted during training to keep reducing the value of the loss function until the model converges.
Further, when the federated learning training model is used, the parameter settings of the federated learning framework include: let t denote the current federated training round, and let the communication content between the server and the edge devices be (w_t, τ) and w_t^k, where t = 0, 1, ..., n, w_t are the global model parameters of the t-th training round, τ is the number of local training iterations on the edge device, w_t^k are the model parameters on edge device k during the t-th round of training, and η is the learning rate used during edge device training.
Further, in step S3, the client devices that perform training in this round are selected at random.
Further, in step S4, the edge device trains the model parameters locally according to the preset parameters, and after the local model training is completed, the edge device uploads the model parameters to the central server.
Further, in step S5, the central server uses the reinforcement learning model to evaluate the contribution of each edge device to the global model from the features of the edge device model parameters and of the current global model parameters, specifically as follows:
A reinforcement learning agent is first pre-trained at the central server:
State: let K devices participate in training in each round; the state of the t-th round of training is expressed as S_t = (w_t^1, w_t^2, ..., w_t^K, w_t), where w_t^k are the model parameters of the k-th device participating in the t-th round of training and w_t are the global model parameters of the t-th round, so the state of each round contains the features of the model parameters uploaded by the K edge devices and of the current global model parameters. When each round of federated learning training ends, each edge device transmits its model parameters from the local device to the central server; before the global model is updated, the aggregation weight of each user device participating in aggregation must first be evaluated, and only then is the global model updated;
Action: in federated learning, when the global parameters are weighted and aggregated in each round, the current state S_t is used to estimate the proportion of each edge device model in the global model aggregation; the value evaluation function in the DDQN learns an approximation Q*(S_t, a) of the optimal action value, where a denotes the prediction action made by the reinforcement learning agent, so the DDQN neural network outputs the contribution of each edge device participating in training to the current global model, and the edge devices participating in global model aggregation are assigned aggregation weights according to their contribution to the global model, which speeds up the aggregation of the global model;
Reward: the reward observed at the end of each round t is set to r_t = m^(Φ_t - Ω) - 1, where t denotes the number of global iterations, Φ_t is the test accuracy achieved by the global model on a held-out validation set after round t, Ω is the target accuracy, and m (m > 1) is a positive constant that makes r_t grow exponentially with the test accuracy; since 0 ≤ Φ_t ≤ Ω ≤ 1, it follows that r_t ∈ (-1, 0].
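A one-function sketch of this reward, written from the reconstructed form r_t = m^(Φ_t - Ω) - 1; the base m and the target accuracy Ω are hyperparameters, and m = 64 is only an illustrative choice, not a value taken from the patent.

```python
def reward(test_accuracy: float, target_accuracy: float, m: float = 64.0) -> float:
    """Reward observed at the end of round t: r_t = m ** (phi_t - omega) - 1.

    With m > 1 and 0 <= phi_t <= omega <= 1, the exponent is non-positive, so the
    reward lies in (-1, 0] and grows exponentially as the test accuracy phi_t
    approaches the target accuracy omega.
    """
    assert m > 1.0 and 0.0 <= test_accuracy <= target_accuracy <= 1.0
    return m ** (test_accuracy - target_accuracy) - 1.0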
Further, the reinforcement learning agent in the central server learns and outputs the contribution β_t^k of client k to the global model in the t-th round; a weighted average is then carried out, and the weight that client k should occupy in the aggregation of the global model in the t-th round is calculated as p_t^k = β_t^k / Σ_{i=1}^n β_t^i,
where Σ_{i=1}^n β_t^i is the sum of the contributions of the n devices participating in the t-th round of federated learning training, and β_t^k is the contribution of client device k in this round of training;
global weighted aggregation is then performed according to the estimated contributions: w_{t+1} = Σ_{k=1}^n p_t^k · w_t^k,
where w_t^k are the model parameters on the k-th device in the t-th round and w_{t+1} are the global model parameters of round t+1. After the weighted aggregation of the global model is completed, the above steps are repeated until the model converges.
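A minimal sketch of this aggregation rule (it could also serve as the weighted_aggregate helper assumed in the earlier loop sketch); the models are treated as dicts of numpy arrays, and the contribution scores β_t^k are whatever the agent outputs.

```python
import numpy as np

def weighted_aggregate(local_params, contributions):
    """w_{t+1} = sum_k p_t^k * w_t^k, with p_t^k = beta_t^k / sum_i beta_t^i."""
    p = np.asarray(contributions, dtype=float)
    p = p / p.sum()                                   # normalize contributions into aggregation weights p_t^k
    return {name: sum(w * params[name] for w, params in zip(p, local_params))
            for name in local_params[0]}
```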
The system for realizing the reinforcement learning-based federated learning optimization method comprises:
a device selection unit: each time training is performed, devices are selected at random to participate in the current round of training;
a device contribution evaluation unit: after the device model parameters are uploaded, the central server evaluates the contribution of each local device through a reinforcement learning algorithm according to the features of each local model's parameters;
a global model updating unit: when the global model is updated, the current global model is updated by taking a weighted average of the model parameters according to each device's contribution to the global model;
in the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and input into the reinforcement learning agent for contribution evaluation, and the central server aggregates the global model according to each device's contribution and the set weighted-average algorithm.
Compared with the prior art, the invention has the following beneficial effects:
(1) The invention provides a DDQN-based heterogeneous federated learning training method that optimizes the aggregation process of the global model: the DDQN model is used to evaluate the contribution of each device participating in the current round of training, and the weights of the models participating in aggregation are adjusted adaptively, which greatly accelerates model aggregation, improves training speed, reduces the number of communication rounds required for training, and saves communication resources.
(2) The method and the system can accurately evaluate the contribution of each device to the global model, and can further provide different rewards to the devices participating in training according to their contributions.
(3) The invention can avoid the influence of malicious data on model training, so that the model can complete training normally.
Drawings
FIG. 1 is a schematic diagram of a system in the practice of the present invention;
FIG. 2 is a flow chart of a reinforcement learning-based federal learning optimization method in accordance with an embodiment of the present invention;
FIG. 3 is a block diagram of a reinforcement learning based federal learning optimization system according to an embodiment of the present invention.
Detailed Description
In order that those skilled in the art may better understand the present invention, it is described in detail below with reference to the accompanying drawings and specific embodiments. It is apparent that the described embodiments are only some, and not all, of the embodiments of the invention. All other embodiments obtained by those skilled in the art based on the embodiments of the invention without inventive effort fall within the scope of the invention.
The invention is further described below by means of specific embodiments.
By studying the problem of federated learning optimization in edge computing, the invention provides a DDQN-based federated learning training method for heterogeneous resources (edge devices whose data differ from one another), which protects users' private data, avoids the influence of malicious data on the experiments, improves communication efficiency, and accelerates training.
As shown in FIG. 1, the reinforcement learning-based federated learning optimization system comprises a central server and edge devices. The central server and the edge devices cooperatively train a machine learning model; the edge devices do not need to upload data to the central server, and only the model parameters trained for the machine learning task are transmitted between the edge devices and the central server.
Because federated learning is a sub-field of distributed computing, machine learning models that can be trained at a central server can also be trained within the federated learning framework; that is, it can be used to train various tasks such as speech recognition, image recognition, and input-method prediction models. Unlike traditional machine learning model training in a cloud center, federated learning does not need to upload data to the cloud: the machine learning model is trained locally on the user devices, and the trained model is then uploaded to the central server for model aggregation. In this example, taking the machine learning task of image classification as an example, training is performed to recognize data on datasets such as MNIST, FashionMNIST, and CIFAR.
As shown in FIG. 2, the optimization method of the reinforcement learning-based federated learning optimization system in this embodiment includes the following steps:
S1, determining the task of machine learning training, such as speech recognition, image recognition, or an input-method prediction model. Federated learning is a sub-field of distributed computing: a machine learning model that can be trained by a central server can also be trained through federated learning. In the federated learning framework, the edge devices can be various smart user devices such as mobile phones and computers, or internet of things devices. Unlike traditional machine learning models trained in a cloud center, federated learning does not need to upload data to the cloud; the machine learning model can be trained directly on the user device, and the trained model is then uploaded to the central server for model aggregation.
The basic parameters of the federated learning framework are set as follows, with t denoting the current federated learning training round (t = 0, 1, ..., n). The other parameters in the framework include the communication content between the server and the edge devices, (w_t, τ) and w_t^k, where w_t are the global model parameters of the t-th training round (i.e. the central server model parameters), τ is the number of local training iterations on the edge device, w_t^k are the model parameters on edge device k during the t-th round of training, and η is the learning rate used during edge device training.
As a preferred embodiment, a LeNet-5 convolutional neural network is used to train the model, and the machine learning training is described as follows: the task is defined on a compact feature space X and a label set Y = {1, ..., B}, where B is the total number of labels. Let (x, y) denote a particular labeled sample, x ∈ X, y ∈ Y; let f_w: X → S denote the prediction function, where S = {z | Σ_{j∈Y} z_j = 1, z_j ≥ 0} is the set of probability vectors over the labels and w are the parameters of the prediction function, i.e. the trained model parameters. The function f_w generates a probability vector for each sample, in which f_w^j(x) is the predicted probability that sample x belongs to the j-th class. The loss function of the prediction function can therefore be defined as the cross entropy l(w) = -(1/D) Σ_x Σ_{j∈Y} p_j log f_w^j(x),
where p represents the actual probability distribution that the training target needs to learn and D is the total number of labeled samples.
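A minimal PyTorch sketch of such a prediction function f_w: a LeNet-5-style network whose softmax output is the probability vector over the B labels, trained with the cross-entropy loss. The layer sizes are the classic LeNet-5 ones for 28x28 single-channel images and are an assumption, not taken from the patent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    """LeNet-5-style network for 1x28x28 inputs (e.g. MNIST or FashionMNIST) with B output classes."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.relu(self.fc1(x.flatten(1)))
        x = F.relu(self.fc2(x))
        return self.fc3(x)            # logits; softmax(logits) is the probability vector f_w(x)

criterion = nn.CrossEntropyLoss()     # cross-entropy between f_w(x) and the true label distribution
```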
The learning problem is to minimize this cross-entropy loss over the training samples.
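Written out, one standard formulation (assumed here, since the formula itself is not reproduced in the text) is the following, where the outer sum runs over the D labeled training samples and p is the true label distribution of each sample:

```latex
\min_{w} \; \ell(w) \;=\; -\frac{1}{D} \sum_{(x,y)} \sum_{j \in Y} p_j \,\log f_w^{\,j}(x)
```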
the federal learning system is provided with M total edge devices (the edge devices can be various intelligent user devices such as mobile phones, computers and the like or can be internet of things devices) and a central server, and each edge device k has data quantity d k The central server is responsible for initializing the model and coordinating the edge devices, selecting a plurality of edge devices in each t round, and downloading the current global model weight parameters w from the server by each device t And performs the following random gradient descent (SGD) training locally:
model parameters, w, generated for training edge device k for the t-th round t-1 Representing global model parameters, eta, which are learning rate and +.>Represents the gradient produced by the loss function, the loss function l (w t-1 ) Represented by cross entropy as-> To predict the probability that sample x belongs to the j-th class of tags, C is the total number of tags on the training-involved device k.
And uploading the trained model parameters to a central server by all the edge devices, estimating the contribution of each edge device by the central server according to the reinforcement learning agent, weighting and aggregating the model parameters uploaded by the edge devices to generate a new global model, and repeating the steps until the global model converges.
The specific implementation flow is as follows:
as stated above, the federal system on which the present invention is based has M edge devices and a central server. Each edge set k has a local data set D k Its data set size is d k Representation, deriving the size of all datasets
Initially, the server randomly selects C devices and stores the C devices to a collectionIn the server +.>Transmitting a current global model parameter w and a local iteration number tau;
the equipment i receiving the parameters carries out tau times of iterative training under the model parameters w on the local data set by using a random gradient descent SGD algorithm to obtain the local training model parameters w i The formula of each iteration training is as follows Representing the gradient produced by the loss function to the model parameters, and finally, w i The package is sent to the server, and then the global model aggregation stage is performed.
The classical federated learning algorithm FedAvg updates the global model at the central server as w_t = Σ_i (d_i / d) · w_t^i.
As in the formulas above, w_t denotes the global model parameters of the current round, w_t^i are the model parameters uploaded by edge device i in the t-th round, d_i is the data size on edge device i, and d is the total data size of all edge devices participating in the current round of training. This is the model aggregation method of FedAvg; clearly, it performs weighted aggregation according to the amount of data on the user devices. A large number of experiments show that this method works very well on independent and identically distributed data, but on non-independent and non-identically distributed data, i.e. in a heterogeneous environment, the usage habits, environments, computing power and storage of the users of the edge devices differ greatly, so the data quantity cannot represent the actual data quality and cannot represent the contribution to the global model; traditional algorithms such as FedAvg therefore do not perform well in heterogeneous environments.
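For contrast with the contribution-weighted rule proposed later, a sketch of the FedAvg aggregation just described, which weights each uploaded model only by its device's data size d_i; models are again treated as dicts of arrays.

```python
def fedavg_aggregate(local_params, data_sizes):
    """FedAvg: w_t = sum_i (d_i / d) * w_t^i, with d the total data size of the participating devices."""
    d = float(sum(data_sizes))
    return {name: sum((d_i / d) * params[name] for d_i, params in zip(data_sizes, local_params))
            for name in local_params[0]}
```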
Aiming at the heterogeneous data situation, the invention improves the global model aggregation method in federated learning: the central server intelligently evaluates the contribution of the edge devices participating in training through the trained reinforcement learning agent and performs weighted aggregation of the models according to their actual contributions, so as to accelerate model convergence and reduce the number of training rounds. The specific flow is as follows:
in the invention, a reinforcement learning agent is pre-trained in a central server, and as a preferred embodiment, the reinforcement learning algorithm is specifically designed as follows:
status: assuming that there are K devices per round to train, the state of each round can be expressed as Model parameters representing the kth device involved in training for each round, w t The method is characterized in that the method is a current global model parameter, the state of each round comprises parameters of k training models and global model parameters, if the model parameters are actually too large, the model parameters are subjected to dimension reduction treatment, and then the model parameters are input into the reinforcement learning agent for effect evaluation. Because the device model parameters actually reflect the training results of the model, the relevant information of the training effects of the edge devices can be learned through the parameter characteristics. And only after the local training of each round of edge equipment is finished, each edge equipment predicts the weight of each edge equipment and updates the global model when the global model is weighted and aggregated after uploading the model parameters to a server.
The actions are as follows: in federal learning, when global parameters are weighted and aggregated in each round, the state S of the current parameters is needed t Estimating the proportion of each edge device in global parameter aggregation, and learning an optimal action value function Q through a neural network * (S t The approximate value of a) is continuously trained, and the DDQN neural network can gradually output the action which maximizes the total income, namely in the state S t Lower output makes Q * (S t The method comprises the steps of a) carrying out the maximum prediction action a, wherein the output action is the contribution value of each estimated edge device to the global model, and carrying out the weighted aggregation of the global model by accurately estimating the action size of each edge device, so that the global model can be converged more quickly, the model accuracy is higher under the same iteration times, and the communication turn is reduced.
Rewarding: the observed rewards at the end of each round t are set to
t represents the number of global iterations, Φ t Is the test accuracy achieved by the global model on the maintained validation set after the t-round, Ω is the target accuracy, i.e. the accuracy that the trained model expects to achieve, m (m>1) Is a normal number, ensure r t Exponentially increasing with test accuracy-! Because 0.ltoreq.phi t Omega is not less than 1, so r t ∈(-1,0]。
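A sketch of how the state S_t described above might be assembled from the K uploaded models and the current global model; a fixed random projection stands in for the unspecified dimension-reduction step, and the feature size is an assumed placeholder.

```python
import numpy as np

def build_state(local_params, global_params, reduced_dim=128, seed=0):
    """Flatten each model's parameters, reduce them to a fixed-size feature vector,
    and concatenate the K device features with the global-model features."""
    rng = np.random.default_rng(seed)
    def flatten(params):
        return np.concatenate([np.asarray(v).ravel() for v in params.values()])
    full_dim = flatten(global_params).shape[0]
    projection = rng.standard_normal((full_dim, reduced_dim)) / np.sqrt(reduced_dim)  # fixed random projection
    features = [flatten(p) @ projection for p in local_params + [global_params]]
    return np.concatenate(features)      # state S_t = (features of w_t^1 ... w_t^K, features of w_t)
```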
The central server pre-trains the agent using a sample data set reserved at the central server; because the central server has strong computing power, this process is very fast.
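A compact sketch of the double-DQN machinery behind Q*(S_t, a) that such a pre-trained agent could use, assuming the contribution decision has been discretized into a finite action set; the network sizes are placeholders, since the patent does not fix an architecture.

```python
import torch
import torch.nn as nn

class QNet(nn.Module):
    """Value network mapping the state features S_t to one Q-value per discrete action."""
    def __init__(self, state_dim: int, num_actions: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, num_actions))

    def forward(self, state):
        return self.net(state)

def double_dqn_target(online: QNet, target: QNet, next_state, reward, gamma: float = 0.95):
    """Double-DQN target: the online network selects the next action, the target network
    evaluates it, which reduces the over-estimation of action values."""
    with torch.no_grad():
        best_action = online(next_state).argmax(dim=-1, keepdim=True)    # action selection
        next_q = target(next_state).gather(-1, best_action).squeeze(-1)  # action evaluation
    return reward + gamma * next_q
```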
The central server then uses the agent's (the reinforcement learning model's) evaluation of each edge device's contribution to improve the federated learning global model aggregation process. Using the reinforcement learning agent in the central server, the contribution β_t^k of client k to the global model in the t-th round is learned and output; a weighted average is then carried out, and the weight that client k should occupy in the aggregation of the global model in the t-th round is calculated as p_t^k = β_t^k / Σ_{i=1}^n β_t^i,
where Σ_{i=1}^n β_t^i is the sum of the contributions of the n devices participating in the t-th round of federated learning training and β_t^k is the contribution of client device k in this round of training.
Global weighted aggregation is performed according to the estimated contributions: w_{t+1} = Σ_{k=1}^n p_t^k · w_t^k,
where w_t^k are the model parameters on the k-th device in the t-th round and w_{t+1} are the global model parameters of round t+1. After the weighted aggregation of the global model is completed, the above steps are repeated until the model converges.
The system for the reinforcement learning-based federated learning optimization method, as shown in FIG. 3, comprises:
The device selection unit 301: each time training is performed, some devices are selected to participate in the current round; to ensure that every device has an equal opportunity to participate in training and that the data on every device is utilized, devices are selected at random.
The device contribution evaluation unit 302: after the device model parameters are uploaded, the central server evaluates the contribution of each local device through a reinforcement learning algorithm according to the features of each local model's parameters.
The global model updating unit 303: when the global model is updated, the current global model is updated by taking a weighted average of the model parameters according to each device's contribution to the global model.
In the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and then input into the reinforcement learning agent to evaluate the contribution. Control then passes to the global model updating unit, where the central server aggregates the global model according to each device's contribution and the set weighted-average algorithm.
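A compact sketch of how the three units 301 to 303 could be wired together; the unit interfaces and the agent object are hypothetical, and the models are again treated as dicts of arrays.

```python
import random

class DeviceSelectionUnit:
    def select(self, devices, num_selected):
        return random.sample(devices, num_selected)               # random selection for fair participation

class DeviceContributionEvaluationUnit:
    def __init__(self, agent):
        self.agent = agent                                        # pre-trained reinforcement learning agent
    def evaluate(self, local_params, global_params):
        return self.agent.evaluate(local_params, global_params)   # one contribution score per device

class GlobalModelUpdatingUnit:
    def update(self, local_params, contributions):
        total = float(sum(contributions))
        return {name: sum((c / total) * params[name] for c, params in zip(contributions, local_params))
                for name in local_params[0]}                      # contribution-weighted average
```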
The embodiments described above are preferred embodiments of the invention, but the embodiments of the invention are not limited to them; any other changes, modifications, substitutions, combinations, or simplifications made without departing from the spirit and principles of the invention are equivalent substitutions and are included in the scope of the invention.

Claims (10)

1. A reinforcement learning-based federated learning optimization method, characterized by comprising the following steps:
S1, determining the machine learning training task and its image data set;
S2, the central server initializing the neural network model parameters for the corresponding task, namely initializing the global model;
S3, the central server selecting the edge devices that need to participate in the current round of training, and the selected edge devices downloading the global model;
S4, under the federated learning framework, the edge devices training the model parameters locally, and uploading the model parameters after the training is completed;
S5, the central server evaluating the contribution of the edge devices participating in training, and performing weighted aggregation of the edge device models according to the contribution; and
S6, continuously repeating steps S3-S5 until the trained model converges.
2. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S1, the edge devices include user smart devices and internet of things devices.
3. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the model initialization comprises: according to the image recognition task and data set of step S1, training with a LeNet-5 convolutional neural network and initializing the model parameters; the machine learning model parameters are denoted w, and a feature space X and a label set Y = {1, ..., B} are defined, where B is the total number of labels; (x, y) denotes a labeled sample, x ∈ X, y ∈ Y; f_w: X → S denotes the prediction function, where S = {z | Σ_{j∈Y} z_j = 1, z_j ≥ 0} is the set of probability vectors over the labels, z_j is the probability that a sample is classified as label j, and w are the trained machine learning model parameters; the function f_w generates a probability vector for each sample.
4. The reinforcement learning-based federated learning optimization method according to claim 2, wherein the loss function of the prediction function for the trained machine learning task is defined as the cross entropy l(w) = -(1/D) Σ_x Σ_{j∈Y} p_j log f_w^j(x),
where p represents the true probability distribution, f_w^j(x) is the predicted probability that sample x belongs to the j-th class of labels, and D is the total number of labeled samples; a stochastic gradient descent strategy is adopted during training to keep reducing the value of the loss function until the model converges.
5. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the parameter settings of the federated learning framework when using the federated learning training model include: let t denote the current federated training round, and let the communication content between the server and the edge devices be (w_t, τ) and w_t^k, where t = 0, 1, ..., n, w_t are the global model parameters of the t-th training round, τ is the number of local training iterations on the edge device, w_t^k are the model parameters on edge device k during the t-th round of training, and η is the learning rate used during edge device training.
6. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S3, the client devices that perform training in this round are selected at random.
7. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S4, the edge device trains the model parameters locally according to preset parameters, and after the local model training is completed, the edge device uploads the model parameters to the central server.
8. The reinforcement learning-based federated learning optimization method according to claim 1, wherein in step S5, the central server uses a reinforcement learning agent to evaluate the contribution of each edge device to the global model from the features of the edge device model parameters and of the current global model parameters, specifically as follows:
a reinforcement learning agent is first pre-trained at the central server:
state: let K devices participate in training in each round; the state of the t-th round of training is expressed as S_t = (w_t^1, w_t^2, ..., w_t^K, w_t), where w_t^k are the model parameters of the k-th device participating in the t-th round of training and w_t are the global model parameters of the t-th round, so the state of each round contains the features of the model parameters uploaded by the K edge devices and of the current global model parameters; when each round of federated learning training ends, each edge device transmits its model parameters from the local device to the central server, and before the global model is updated, the aggregation weight of each user device participating in aggregation is first evaluated, after which the global model is updated;
action: in federated learning, when the global parameters are weighted and aggregated in each round, the current state S_t is used to estimate the proportion of each edge device model in the global model aggregation; the value evaluation function in the DDQN learns an approximation Q*(S_t, a) of the optimal action value, where a denotes the prediction action made by the reinforcement learning agent, so that the DDQN neural network outputs the contribution of each edge device participating in training to the current global model, and the edge devices participating in global model aggregation are assigned aggregation weights according to their contribution to the global model, thereby speeding up the aggregation of the global model;
reward: the reward observed at the end of each round t is set to r_t = m^(Φ_t - Ω) - 1, where t represents the number of global iterations, Φ_t is the test accuracy achieved by the global model on the held-out validation set after t rounds, Ω is the target accuracy, and m > 1 is a positive constant that ensures r_t grows exponentially with the test accuracy; 0 ≤ Φ_t ≤ Ω ≤ 1, and r_t ∈ (-1, 0].
9. The reinforcement learning-based federated learning optimization method according to claim 1, wherein the reinforcement learning agent in the central server learns and outputs the contribution β_t^k of client k to the global model in the t-th round; a weighted average is then carried out, and the weight that client k should occupy in the aggregation of the global model in the t-th round is calculated as p_t^k = β_t^k / Σ_{i=1}^n β_t^i,
where Σ_{i=1}^n β_t^i is the sum of the contributions of the n devices participating in the t-th round of federated learning training, and β_t^k is the contribution of client device k in this round of training;
global weighted aggregation is performed according to the estimated contributions: w_{t+1} = Σ_{k=1}^n p_t^k · w_t^k,
where w_t^k are the model parameters on the k-th device in the t-th round and w_{t+1} are the global model parameters of round t+1; after the weighted aggregation of the global model is completed, the above steps are repeated until the model converges.
10. A system for implementing the reinforcement learning-based federated learning optimization method of claim 1, comprising:
a device selection unit: each time training is performed, devices are selected at random to participate in the current round of training;
a device contribution evaluation unit: after the device model parameters are uploaded, the central server evaluates the contribution of each local device through a reinforcement learning algorithm according to the features of each local model's parameters;
a global model updating unit: when the global model is updated, the current global model is updated by taking a weighted average of the model parameters according to each device's contribution to the global model;
wherein, in the device contribution evaluation unit, features are extracted from the model parameters uploaded by each device and input into the reinforcement learning agent for contribution evaluation, and the central server aggregates the global model according to each device's contribution and the set weighted-average algorithm.
CN202310230326.2A 2023-03-10 2023-03-10 Reinforced learning-based federal learning optimization method and system Pending CN116523079A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310230326.2A CN116523079A (en) 2023-03-10 2023-03-10 Reinforced learning-based federal learning optimization method and system

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310230326.2A CN116523079A (en) 2023-03-10 2023-03-10 Reinforced learning-based federal learning optimization method and system

Publications (1)

Publication Number Publication Date
CN116523079A true CN116523079A (en) 2023-08-01

Family

ID=87405400

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310230326.2A Pending CN116523079A (en) 2023-03-10 2023-03-10 Reinforced learning-based federal learning optimization method and system

Country Status (1)

Country Link
CN (1) CN116523079A (en)


Cited By (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117094381A (en) * 2023-08-21 2023-11-21 哈尔滨工业大学 Multi-mode federal collaboration method taking high-efficiency communication and individuation into consideration
CN117094381B (en) * 2023-08-21 2024-04-12 哈尔滨工业大学 Multi-mode federal collaboration method taking high-efficiency communication and individuation into consideration
CN117273119A (en) * 2023-08-24 2023-12-22 北京邮电大学 Dynamic fairness federal learning method and device based on reinforcement learning
CN117313835A (en) * 2023-10-08 2023-12-29 湖北大学 Federal learning method based on client contribution clearance in heterogeneous data environment
CN117592555A (en) * 2023-11-28 2024-02-23 中国医学科学院北京协和医院 Federal learning method and system for multi-source heterogeneous medical data
CN117592555B (en) * 2023-11-28 2024-05-10 中国医学科学院北京协和医院 Federal learning method and system for multi-source heterogeneous medical data
CN117392483A (en) * 2023-12-06 2024-01-12 山东大学 Album classification model training acceleration method, system and medium based on reinforcement learning
CN117392483B (en) * 2023-12-06 2024-02-23 山东大学 Album classification model training acceleration method, system and medium based on reinforcement learning
CN117575291A (en) * 2024-01-15 2024-02-20 湖南科技大学 Federal learning data collaborative management method based on edge parameter entropy
CN117575291B (en) * 2024-01-15 2024-05-10 湖南科技大学 Federal learning data collaborative management method based on edge parameter entropy
CN117725965A (en) * 2024-02-06 2024-03-19 湘江实验室 Federal edge data communication method based on tensor mask semantic communication
CN117725965B (en) * 2024-02-06 2024-05-14 湘江实验室 Federal edge data communication method based on tensor mask semantic communication

Similar Documents

Publication Publication Date Title
CN116523079A (en) Reinforced learning-based federal learning optimization method and system
CN109948029B (en) Neural network self-adaptive depth Hash image searching method
CN113191484B (en) Federal learning client intelligent selection method and system based on deep reinforcement learning
WO2018161468A1 (en) Global optimization, searching and machine learning method based on lamarck acquired genetic principle
CN110851782A (en) Network flow prediction method based on lightweight spatiotemporal deep learning model
EP4350572A1 (en) Method, apparatus and system for generating neural network model, devices, medium and program product
CN111277434A (en) Network flow multi-step prediction method based on VMD and LSTM
WO2022126448A1 (en) Neural architecture search method and system based on evolutionary learning
CN111158912A (en) Task unloading decision method based on deep learning in cloud and mist collaborative computing environment
US20220318412A1 (en) Privacy-aware pruning in machine learning
CN115374853A (en) Asynchronous federal learning method and system based on T-Step polymerization algorithm
Long et al. Fedsiam: Towards adaptive federated semi-supervised learning
CN112836822A (en) Federal learning strategy optimization method and device based on width learning
CN114819143A (en) Model compression method suitable for communication network field maintenance
CN115359298A (en) Sparse neural network-based federal meta-learning image classification method
CN116645130A (en) Automobile order demand prediction method based on combination of federal learning and GRU
Kozat et al. Universal switching linear least squares prediction
CN111832817A (en) Small world echo state network time sequence prediction method based on MCP penalty function
CN113194493B (en) Wireless network data missing attribute recovery method and device based on graph neural network
Adeleke Echo-state networks for network traffic prediction
CN113128432B (en) Machine vision multitask neural network architecture searching method based on evolution calculation
Gong et al. Compressed particle-based federated bayesian learning and unlearning
CN116976405A (en) Variable component shadow quantum neural network based on immune optimization algorithm
CN114819196B (en) Noise distillation-based federal learning system and method
CN110768825A (en) Service flow prediction method based on network big data analysis

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination