CN114091667A - Federal mutual learning model training method oriented to non-independent same distribution data - Google Patents

Federal mutual learning model training method oriented to non-independent same distribution data

Info

Publication number
CN114091667A
CN114091667A (application CN202111386087.7A)
Authority
CN
China
Prior art keywords
client
model
edge
server
clients
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111386087.7A
Other languages
Chinese (zh)
Inventor
Li Kan (李侃)
Li Yang (李洋)
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Institute of Technology BIT
Original Assignee
Beijing Institute of Technology BIT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Institute of Technology BIT filed Critical Beijing Institute of Technology BIT
Priority to CN202111386087.7A priority Critical patent/CN114091667A/en
Publication of CN114091667A publication Critical patent/CN114091667A/en
Pending legal-status Critical Current

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Abstract

The invention provides a federated mutual learning model training method for non-independent identically distributed (non-IID) data, which comprises the following steps: S1, the server sends the initial global model parameters to the intermediate clients, and each intermediate client generates intermediate client model parameters; S2, each edge client generates edge client model parameters using its local data set; S3, the intermediate clients and the edge clients update their parameters with a mutual learning method; S4, the probability prediction values output by the intermediate client models are uploaded to the server, and the server updates the global model and the intermediate client models using a distillation technique; S5, steps S3-S4 are executed until the models meet the convergence condition, yielding the final intermediate client models, edge client models and global model, after which the server broadcasts the final global model to all edge clients. In the invention, grouped mutual learning and knowledge distillation address the communication bandwidth limitation of federated learning and the problem of generating models on non-independent identically distributed data.

Description

Federal mutual learning model training method for non-independent same distribution data
Technical Field
The invention relates to the technical field of federated learning, and in particular to a federated mutual learning model training method for non-independent identically distributed data.
Background
With the development of artificial intelligence and big data, information resources can be transmitted at high speed in distributed networks, enabling unified management of physical information. Cloud computing and edge computing have promoted the development of deep learning. A deep learning model usually contains millions of parameters, and enlarging the neural network generally improves model accuracy; in a distributed network, however, resource-constrained edge devices cannot deploy large neural network models, communication delays arise and low-priority user terminals may be denied access, so the strong computing power and storage space of a central server are needed. In practice, data collection, exchange and ownership among enterprises are strictly regulated by laws and regulations, making a large amount of high-quality training data very difficult to obtain, while the problem of privacy leakage is increasingly prominent.
Federated learning is a distributed framework that decouples data from models. It can solve the problems of data silos and privacy protection: the participants' data never need to be gathered at a central storage point, and joint modeling by the participants is achieved while the data stays local. In a federated learning framework there is usually a central computing party that collects the model parameter information transmitted by the other parties, updates it with a corresponding algorithm, and returns the task to the parties; the iteration is repeated until convergence, finally constructing an effective global model. Neither the clients nor the server can acquire or control the data of other clients at any point, and building the federated learning model does not affect the normal use of the client devices. The trained federated learning model can be shared and deployed among all data participants, giving the approach broad application prospects in smart healthcare, finance and insurance, the intelligent Internet of Things, and other fields.
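For reference, the conventional federated averaging aggregation that the following discussion contrasts with can be written as a weighted average of the client updates; the notation (K clients, n_k local samples on client k, n total samples, parameters w, learning rate \eta, local loss L_k) is ours, not the patent's:

w^{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_k^{t+1},
\qquad
w_k^{t+1} = w^{t} - \eta\,\nabla L_k\!\left(w^{t}\right),
\qquad
n = \sum_{k=1}^{K} n_k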
In real-world scenarios, the data distribution across device ends may be statistically heterogeneous, including label distribution skew, feature distribution skew and quantity skew. This work mainly studies how to train on non-independent identically distributed data in federated learning: when the traditional federated averaging algorithm is trained on such data it is easily affected by outliers, and the large variance introduced makes the resulting global model unsatisfactory. For example, data heterogeneity slows the convergence of the global model and increases both the number of communications between the server and the clients and the number of local training rounds on the clients. In terms of network bandwidth, federated learning can be orders of magnitude slower than traditional distributed systems; bandwidth limitations can cause client connections to fail or break, and clients with restricted network access or failed connections are dropped during communication. The number of nodes can be several orders of magnitude larger than in a distributed system, a large number of users are deployed, and there are strict requirements on latency and computing resources; as the numbers of communication rounds and users grow, the communication overhead problem becomes more prominent. For these problems, related research has been carried out from various angles:
For the communication overhead problem: a cloud server can access more data, but the communication overhead is too large and the communication delay is long. Compared with cloud-based federated learning, the client-edge-cloud hierarchical federated learning approach introduces intermediate edge servers, which reduces direct communication between the clients and the cloud server, shortens model training time and lowers the energy consumption of terminal devices; however, the number of clients and the amount of data an edge server can access are limited, and heterogeneity among devices and data brings more pronounced communication problems. Researchers have proposed a multi-center aggregated federated learning training method that clusters clients with similar tasks into several centers for training, accelerating the convergence of each center and improving the accuracy of each center's model. Increasing the amount of computation on the clients reduces the number of communication exchanges between the server and the clients and favors fast model generation. Methods such as compressing the scale of the model parameters and pruning the network reduce the exchange of large numbers of model parameters during communication. In knowledge distillation, a large and complex teacher network helps train a small and simple student network; the student network learns both the information of the true labels and the relational information between different labels, and the small model is easy to deploy, which can be used to address the bandwidth limitation of federated learning.
For the data heterogeneity problem: the traditional ensemble methods and the federated averaging algorithm are easily affected by outliers; on cross-device heterogeneous data the training process produces model bias, and the final global model does not improve the training models of the participants. Researchers have proposed sharing part of a global data subset to reduce the divergence of cross-device heterogeneous data, but this carries a risk of privacy leakage; to reduce that risk, a large amount of redundant data is added to the public data set and a generative adversarial network is used to generate pseudo samples. By combining knowledge distillation with generative adversarial networks, the samples missing on each device are completed from other devices, converting the data distribution of all clients to the independent identically distributed case and thereby constructing an effective global model, but the communication cost increases. Some scholars also hold that clients with similar tasks can be trained together to generate several different task models.
The methods mentioned above explore federated learning on non-independent identically distributed data, yet the following problems remain unsolved:
1. Compared with cloud-based federated learning, the client-edge-cloud hierarchical federated learning system introduces intermediate edge servers, which reduces direct traffic between the clients and the cloud server and shortens model training time, but a large number of parameters still need to be transmitted, and a scheme that can reduce the communication cost is still lacking.
2. In recent years knowledge distillation has been combined with federated learning, but when processing non-independent identically distributed data a teacher model must be generated in advance on a suitable proxy data set, and it is difficult to find a suitable data set for generating the teacher model; a federated learning method that addresses both the high communication cost and the loss of training performance encountered on non-independent identically distributed data is currently lacking.
Disclosure of Invention
In view of the shortcomings of the prior art, and aiming at the problems of communication bandwidth limitation and model generation when training on non-independent identically distributed data in federated learning, the invention constructs a model training method for federated mutual learning on heterogeneous data, adopting a combination of grouped mutual learning and distillation. Grouped mutual learning removes the need to generate a teacher model in advance; the student networks exchange model outputs with one another, which avoids transmitting large numbers of parameters, reduces the number of direct interactions between the server and the clients, improves the effect of the federated training model on heterogeneous data, and mitigates the problem of asymmetric uplink and downlink bandwidth. The method generates both a global model and local personalized models for users to choose between in practical applications.
In order to achieve this purpose, the invention is realized by the following technical scheme:
A federated mutual learning model training method for non-independent identically distributed data comprises the following steps:
S1, the server sends the initialized global model parameters to the intermediate client of each group, the intermediate clients generate intermediate client model parameters using the intermediate client data sets, and the generated intermediate client model parameters are sent to the edge clients in the group;
S2, the edge clients receive the intermediate client model parameters and generate edge client model parameters using their local data sets;
S3, the intermediate client and the edge clients in each group perform multiple rounds of training using the mutual learning method, updating the intermediate client model parameters and the edge client model parameters;
S4, the intermediate clients of all groups upload the label category probability prediction values of the intermediate client models to the server, and the global model parameters are updated;
S5, steps S3-S4 are executed repeatedly until a convergence condition is met, obtaining the intermediate client models, the edge client models and the global model, and the server broadcasts the final global model generated by training to all edge clients.
Further, the global model, the intermediate client model and the edge client model are neural network models.
Further, in step S2, the intermediate client and the edge client respectively update the model parameters on their local data sets using a stochastic gradient descent algorithm.
Further, in step S3, the mutual learning method includes:
S31, all edge clients in the group are taken as the student networks, and the outputs of all edge client models in round t are recorded, C being the number of edge clients; the probability prediction values of the label categories are computed and transmitted to the intermediate client, where i denotes the ith intermediate client and j denotes the jth edge client connected to the ith intermediate client;
S32, the intermediate client of group i computes the average of the probability prediction values of the label categories over the edge clients of the group (c indexes the edge clients), and the KL divergence D_KL between the intermediate client model of the ith group and the student networks of the edge client models in the group, where m denotes the number of intermediate clients and the probability prediction values output by the intermediate client of the ith group in round t enter the divergence; the loss function of the intermediate client model of the ith group is then formed (the explicit formulas appear as equation images in the original);
S33, the intermediate client of the ith group updates the intermediate client model parameters on the intermediate client data set using a stochastic gradient descent algorithm;
S34, the jth edge client of the ith group computes its KL divergence D_KL and loss function;
S35, the edge client updates the edge client model parameters on the edge client data set using a stochastic gradient descent algorithm;
S36, after N rounds of S31-S35 have been executed, all intermediate clients compute the probability prediction values of the label categories and transmit them to the server.
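The equation images referenced in S32 are not reproduced in this text. A hedged reconstruction consistent with the surrounding description, using notation of our own (p^t_{i_j} for the round-t label-category prediction of the jth edge client of group i, p^t_{m_i} for that of the intermediate client of group i, and a cross-entropy term L_CE following the standard deep mutual learning formulation rather than the patent's exact figures), would be:

\bar{p}^{\,t}_{i} = \frac{1}{C}\sum_{j=1}^{C} p^{t}_{i_j},
\qquad
D_{KL}\!\left(\bar{p}^{\,t}_{i}\,\middle\|\,p^{t}_{m_i}\right) = \sum_{k}\bar{p}^{\,t}_{i}(k)\,\log\frac{\bar{p}^{\,t}_{i}(k)}{p^{t}_{m_i}(k)},
\qquad
L_{m_i} = L_{CE}\!\left(y,\,p^{t}_{m_i}\right) + D_{KL}\!\left(\bar{p}^{\,t}_{i}\,\middle\|\,p^{t}_{m_i}\right)

The edge client loss in S34 would analogously combine its own supervised loss with a KL term toward the intermediate client's prediction.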
Further, a temperature parameter T is added to the Softmax function of the edge client model to adjust the output probability distribution, and each edge client computes the corresponding probability prediction values of the label categories.
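The temperature-softened Softmax referred to here is the standard one used in distillation; writing z_k for the logit of class k (our notation, since the patent's formula is given only as an image), the probability prediction for class k is:

p_k = \frac{\exp(z_k / T)}{\sum_{j}\exp(z_j / T)}

A larger T yields a softer distribution that carries more information about the relationships between label categories; T = 1 recovers the ordinary Softmax.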
Further, the server receives the probability prediction values of the corresponding label categories output by the intermediate client models, where i denotes the ith intermediate client and m denotes the intermediate clients, and updates the global model with a distillation learning method, the distillation learning method comprising:
S41, the server computes the loss function of the global model, where z denotes the number of interaction rounds between the server and the intermediate clients (the formula is given as an equation image in the original);
S42, the global model parameters are updated on the server's local data set using a gradient descent algorithm, z being the number of rounds of mutual learning between the intermediate clients and the edge clients;
S43, each group's intermediate client computes its KL divergence and loss function (the formulas are given as equation images in the original).
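As with the mutual learning step, the distillation formulas of S41-S43 are available only as images. A hedged reconstruction in the spirit of the description (our notation: p^z_{m_i} for the round-z output of the intermediate client of group i, p^z_s for the global model's output on the server's local data, w_s for the global model parameters, \eta the learning rate; the averaging over the m intermediate clients and the KL objective follow standard knowledge distillation rather than the patent's exact figures) would be:

\bar{p}^{\,z} = \frac{1}{m}\sum_{i=1}^{m} p^{z}_{m_i},
\qquad
L_{s} = D_{KL}\!\left(\bar{p}^{\,z}\,\middle\|\,p^{z}_{s}\right),
\qquad
w_{s}^{z+1} = w_{s}^{z} - \eta\,\nabla_{w_s} L_{s}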
in the invention, the problem of underutilization of bandwidth is solved, and the time of a model for training non-independent and uniformly distributed data by federal learning is shortened.
Drawings
In order to more clearly illustrate the embodiments of the present invention or the technical solutions in the prior art, the drawings used in the description of the embodiments or the prior art will be briefly introduced below, and it is obvious that the drawings in the following description are some embodiments of the present invention, and for those skilled in the art, other drawings can be obtained according to these drawings without creative efforts.
FIG. 1 is a schematic diagram of an intermediate client and edge client grouping scenario;
FIG. 2 is a flow diagram of a federated mutual learning model training method in accordance with one embodiment of the present invention;
FIG. 3 is a flow diagram of the mutual learning method according to one embodiment of the invention;
FIG. 4 is a schematic flow diagram of a distillation process according to one embodiment of the present invention.
Detailed Description
In order to make the objects, technical solutions and advantages of the embodiments of the present invention clearer, the technical solutions in the embodiments of the present invention will be clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are some, but not all, embodiments of the present invention. All other embodiments, which can be obtained by a person skilled in the art without any inventive step based on the embodiments of the present invention, are within the scope of the present invention.
The method provided by the invention can be executed on a mobile terminal, a computer terminal or a similar computing device, and each participant holds training data that it expects to use for training the model. As shown in FIG. 1, the clients are grouped: each group includes one intermediate client and several edge clients, an edge client may consist of several devices, and an intermediate client can access a certain number of edge clients. The invention provides a federated mutual learning model training method for non-independent identically distributed data which, as shown in FIG. 2, comprises the following steps:
s1, the server establishes communication with the intermediate client, the server sends the initialized global model parameters to the intermediate client, the intermediate client model parameters are generated on the local data set of the intermediate client, and the intermediate client model parameters are sent to the edge client;
s2, after grouping, the middle client of each group establishes communication with the edge client in the group, the edge client receives the model parameters of the middle client, and the edge client generates the model parameters of the edge client on the local data set;
s3, the intermediate client and the edge client update the intermediate client model parameters and the edge client model parameters by using a mutual learning algorithm;
s4, the intermediate client uploads the probability predicted value of the label category of the intermediate client model to the server, and the server updates the global model by using a distillation technology;
s5, repeating the steps S3-S4 until the convergence condition is met, obtaining an intermediate client model, an edge client model and a global model, and then broadcasting the final global model to all edge clients by the server.
In FIG. 1, the server is represented by S, the intermediate clients by M and the edge clients by C; the global model, the intermediate client models and the edge client models are all neural network models.
In step S1, the server seeks the participation of clients according to the states of the nodes (intermediate clients and edge clients) and sends a modeling task to the intermediate clients and edge clients, where 0 represents offline and 1 represents online; the intermediate client m belongs to {1, …, M}, the edge client c belongs to {1, …, C}, and when Class m is 1 and Class c is 1, intermediate client m and edge client c belong to the same group. The global model parameters at z = 0 serve as the initial global model; the intermediate client model parameters denote the parameters of the ith intermediate client at its tth update, and the edge client model parameters denote the parameters of the jth edge client of the ith group at its tth update.
The intermediate clients and the edge clients return a joint modeling response according to their own requirements, and the data sets of the edge clients and the intermediate client within each group satisfy the non-independent identically distributed condition; for example, the edge clients within each group contain only one or several mutually disjoint classes of data labels. After the edge clients respond, the server sends the initial global model parameters to the intermediate clients.
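A minimal sketch of this node-state bookkeeping and group selection, assuming a simple in-memory representation (the class and function names below are illustrative, not taken from the patent):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Group:
    intermediate_client: str                      # e.g. "m1"
    edge_clients: List[str] = field(default_factory=list)

def select_participants(groups: Dict[int, Group], online: Dict[str, int]) -> Dict[int, Group]:
    """Keep only groups whose intermediate client and at least one edge client report state 1 (online)."""
    selected = {}
    for gid, g in groups.items():
        if online.get(g.intermediate_client, 0) != 1:
            continue                              # intermediate client offline: skip the whole group
        active = [c for c in g.edge_clients if online.get(c, 0) == 1]
        if active:
            selected[gid] = Group(g.intermediate_client, active)
    return selected

# Toy usage mirroring the embodiment below: m1 with c1, c2 and m2 with c3, c4
groups = {1: Group("m1", ["c1", "c2"]), 2: Group("m2", ["c3", "c4"])}
online = {"m1": 1, "c1": 1, "c2": 0, "m2": 1, "c3": 1, "c4": 1}
print(select_participants(groups, online))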
Taking the ith group as an example, after receiving the global model parameters, the intermediate client m ∈ {1, …, M} initializes them as its intermediate client model parameters; on its local data set (where x_i is the feature of a data sample, y_i is the true label of the sample and n is the number of samples) it updates the intermediate client model parameters with a stochastic gradient descent algorithm, and then passes the intermediate client model parameters to the edge clients c ∈ {1, …, C} in the group.
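The stochastic gradient descent update referred to here, and again for the edge clients in step S2 below, is the usual one; with our assumed notation θ^t_{m_i} for the parameters of the ith intermediate client after t updates and L_{m_i} for its local loss (the symbols stand in for the patent's equation images), one update step reads:

\theta^{t+1}_{m_i} = \theta^{t}_{m_i} - \eta\,\nabla_{\theta} L_{m_i}\!\left(\theta^{t}_{m_i}\right)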
In step S2, after receiving the intermediate client model parameters, the edge client initializes its edge client model parameters with them. On the edge client's local data set (where x_i is the feature of a data sample, y_i is the true label of the sample, n is the number of samples, and the sample classes of the different edge clients are mutually disjoint, satisfying the non-independent identically distributed condition), the edge client model parameters are generated using a stochastic gradient descent algorithm, where η is the learning rate and the gradient term denotes the corresponding update gradient of the jth edge client of the ith group.
In step S3, the intermediate client and edge client model parameters are updated using the mutual learning algorithm. As shown in FIG. 3, the mutual learning algorithm proceeds as follows:
S31, all edge clients c in a group (an intermediate client and its corresponding edge clients form one group) act as student networks ("student network" is a term from mutual learning), indexed here as i_j to denote the jth edge client of the ith group. From the outputs of the edge client models, the probability prediction values of the label categories are computed; applying the temperature parameter T in the Softmax function (the activation function of the output layer of the edge client model) turns them into soft labels, which are transmitted to the intermediate client;
S32, the intermediate client computes the average of the received soft labels and the KL divergence D_KL between the intermediate client model and the edge client student networks, which measures the degree of difference between the distributions, and then the loss function of the intermediate client model;
S33, the intermediate client updates the intermediate client model parameters on its local data set using a stochastic gradient descent algorithm;
S34, each edge client computes its KL divergence D_KL and the loss function of its edge client model;
S35, the edge client updates the edge client model parameters on the edge client data set using a stochastic gradient descent algorithm;
S36, steps S31-S35 constitute one mutual learning update between the intermediate client and the edge clients; after N updates have been executed (N is set manually), the outputs of the intermediate clients are transmitted to the server.
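As a hedged illustration of one such mutual learning update (steps S31-S35), the sketch below uses PyTorch and the standard deep mutual learning losses, cross-entropy plus a KL divergence toward the peers' averaged soft labels; the function names, the cross-entropy term and the exact KL direction are our assumptions, not taken from the patent:

import torch
import torch.nn.functional as F

def class_probs(model, x, T):
    """Temperature-softened prediction averaged over the batch: one distribution over the label categories."""
    return F.softmax(model(x) / T, dim=1).mean(dim=0)

def kl(p, q):
    """KL divergence D_KL(p || q) between two probability vectors."""
    return torch.sum(p * (p.log() - q.log()))

def mutual_learning_round(intermediate, inter_batch, edge_models, edge_batches, T=2.0, lr=0.01):
    """One S31-S35 update inside a group (an illustrative sketch, not the patent's exact procedure)."""
    # S31: each edge client computes its soft label-category predictions and "sends" them to the intermediate client
    edge_p = [class_probs(m, x, T).detach() for m, (x, _) in zip(edge_models, edge_batches)]
    avg_edge = torch.stack(edge_p).mean(dim=0)              # average over the C edge clients in the group

    # S32-S33: intermediate client loss = cross-entropy on its own data + KL toward the edge average, then one SGD step
    opt_m = torch.optim.SGD(intermediate.parameters(), lr=lr)
    x_m, y_m = inter_batch
    loss_m = F.cross_entropy(intermediate(x_m), y_m) + kl(avg_edge, class_probs(intermediate, x_m, T))
    opt_m.zero_grad(); loss_m.backward(); opt_m.step()

    # S34-S35: each edge client loss = cross-entropy on its local data + KL toward the intermediate client's output
    p_m = class_probs(intermediate, x_m, T).detach()
    for model, (x, y) in zip(edge_models, edge_batches):
        opt_c = torch.optim.SGD(model.parameters(), lr=lr)
        loss_c = F.cross_entropy(model(x), y) + kl(p_m, class_probs(model, x, T))
        opt_c.zero_grad(); loss_c.backward(); opt_c.step()

# Toy usage: one intermediate client and two edge clients with small linear models on random 28x28 "images"
torch.manual_seed(0)
net = lambda: torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
batch = lambda n: (torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
mutual_learning_round(net(), batch(8), [net(), net()], [batch(8), batch(8)])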
In step S4, the server receives the probability prediction values of the label categories from the intermediate client models and updates the global model using a distillation learning algorithm; as shown in FIG. 4, the distillation learning method includes:
S41, the server computes the global model loss function (given as an equation image in the original);
S42, the global model parameters are updated on the server's local data set using a gradient descent algorithm;
S43, each intermediate client computes the intermediate client model loss function on its local data set (given as an equation image in the original).
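A hedged sketch of this server-side distillation step in the same PyTorch style as above, treating the averaged intermediate-client predictions as the teacher signal for the global model on the server's local data; the averaging, the KL objective, and the function and parameter names are our assumptions in the spirit of standard knowledge distillation, not the patent's exact formula:

import torch
import torch.nn.functional as F

def distill_global_model(global_model, server_batch, intermediate_probs, T=2.0, lr=0.01):
    """S41-S42 sketch: pull the global model toward the averaged intermediate-client predictions."""
    teacher = torch.stack(intermediate_probs).mean(dim=0)            # average over the m intermediate clients (detached)
    opt = torch.optim.SGD(global_model.parameters(), lr=lr)
    x_s, _ = server_batch                                            # server's local data; labels unused in this sketch
    student = F.softmax(global_model(x_s) / T, dim=1).mean(dim=0)    # global model's averaged soft prediction
    loss = torch.sum(teacher * (teacher.log() - student.log()))      # KL(teacher || student)
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

Here intermediate_probs would be the label-category probability vectors uploaded by the intermediate clients after N rounds of mutual learning.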
in step S5, steps S3 to S4 are repeatedly executed until the global model, the intermediate client model, and the edge client model satisfy the convergence condition, and finally the intermediate client model parameters are generated
Figure BDA0003367079170000098
Edge client model parameters
Figure BDA0003367079170000099
And global model parameters
Figure BDA00033670791700000910
Global model parameters
Figure BDA00033670791700000911
Distributed to all edge clients.
Example 1
According to the method steps provided by this embodiment, the model training method for federated learning on non-independent identically distributed data is illustrated with the MNIST image classification task. Each client is a data owner; at least one terminal and one intermediate client participate in training. The sample data held by the data owners may come from the same data set or from different data sets. The server and client databases store the MNIST data set; after manual partitioning, the edge clients in each group contain only one or several mutually disjoint data labels, satisfying the non-independent identically distributed condition. The server, the intermediate clients and the edge clients all use a LeNet-5 convolutional neural network (input dimension 28 × 28; a convolutional layer with 6 kernels of 5 × 5; a max-pooling layer with a 2 × 2 kernel; a convolutional layer with 16 kernels of 5 × 5; a max-pooling layer with a 2 × 2 kernel; a convolutional layer with 120 kernels of 5 × 5; a fully connected layer with 84 neurons; an output layer of dimension 1 × 10). The input is a 28 × 28 pixel image, which LeNet-5 converts into a 10-dimensional output vector; the label corresponding to the maximum predicted value of the output layer is the predicted label. For example, for the output [0.1028, -0.0138, -0.0664, -0.0575, -0.0387, 0.0269, 0.1911, 0.0022, -0.0441, -0.0797], the prediction given by the soft label is 6. The intermediate clients, the edge clients and the server all use this neural network. The specific steps are as follows:
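A hedged PyTorch sketch of a LeNet-5 matching this description (the padding of 2 on the first convolution is our assumption so that the 28 × 28 input reaches a 5 × 5 feature map before the 120-kernel convolution, as in the usual MNIST adaptation; the patent does not state the padding, and the ReLU activations are likewise assumed):

import torch
import torch.nn as nn

class LeNet5(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2),   # 6 kernels of 5x5 (padding assumed)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # 2x2 max pooling
            nn.Conv2d(6, 16, kernel_size=5),             # 16 kernels of 5x5
            nn.ReLU(),
            nn.MaxPool2d(2),                             # 2x2 max pooling
            nn.Conv2d(16, 120, kernel_size=5),           # 120 kernels of 5x5 -> 120 x 1 x 1
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(120, 84),                          # fully connected layer with 84 neurons
            nn.ReLU(),
            nn.Linear(84, num_classes),                  # 10-dimensional output
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

# A 28x28 MNIST-style input produces a 10-dimensional logit vector
logits = LeNet5()(torch.randn(1, 1, 28, 28))
print(logits.shape)   # torch.Size([1, 10])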
in step S1, the server is denoted by S, the intermediate client miI 1,2, edge client ciI 1,2,3,4, grouping the intermediate clients and the edge clients, the first group comprising m1,c1,c2The second group includes m2,c3,c4. The server seeks client participation according to the node state, and sends a modeling task to the edge client, and the edge client c1The data set of (1) is an image containing 1000 labels of 1 and 2, and the edge client c2Is an image of 500 tags 3,4, edge client c3Is 1000 images labeled 5, 6, edge client c4Is 2000 images labeled 7, 8, the intermediate client m1Is 1000 images labeled 1, 3, 5, the intermediate client m2The data set of (1) is 1000 images labeled 2, 4, 6, and the data set of the server S is 2000 images labeled 1, 9. Server generation of initial model
Figure BDA0003367079170000101
The server sends the initial model parameters
Figure BDA0003367079170000102
Transmitted to an intermediate client, the intermediate client generating the model on the local data set, for example in the first group
Figure BDA0003367079170000103
In step S2, taking m1, c1, c2 as an example, the intermediate client model parameters are transmitted to edge clients c1 and c2, and the edge clients generate their local models on their local data sets.
In step S3, edge clients c1 and c2 compute the probability prediction values of the label categories from their model outputs and transmit them to intermediate client m1. The intermediate client computes the average of these values, its KL divergence and its loss function, and updates the intermediate client model parameters; the loss function of each edge client model is then computed, with j = 1, 2. This constitutes one mutual learning update between the intermediate client and the edge clients in the group; after N updates have been executed, the probability prediction values of the label categories of all intermediate clients are obtained and transmitted to the server.
In step S4, the server receives the probability prediction values of the label categories and, using the distillation learning algorithm, computes the server model loss function; each intermediate client then computes its model loss function (the formulas are given as equation images in the original).
In step S5, after repeated execution of S3-S4 the convergence condition is satisfied, the final global model is generated at the server, and the final global model produced by training is broadcast to all edge clients.
The above steps are the whole process of the embodiment. The edge client models, the intermediate client models and the global model are all generated during this process; an edge client selects the global model when the accuracy of its locally generated model is low, and selects its local model when the accuracy of its local model is high.
The invention provides a federated mutual learning model training method for non-independent identically distributed data that reduces direct communication between the edge clients and the server during training. Although the number of edge clients an intermediate client can access is limited, communication efficiency is higher. In traditional methods, a more complex network produces more parameters (the parameters being the weights on the neural network nodes at each moment, their total being the number of nodes multiplied by the number of parameters per node), and transmitting this large number of parameters during training occupies a large amount of bandwidth. With the mutual learning method, only the model outputs are transmitted during training, independently of the complexity of the model structure, which reduces the transmission of large numbers of model parameters, makes full use of the uplink and downlink bandwidth between the clients, and weakens the influence of bandwidth limitation on federated training. The edge clients cooperate through the distillation algorithm to train a global model, while the parameter changes of each group's intermediate client caused by non-independent identically distributed data can be dynamically adjusted, thereby countering the influence of non-independent identically distributed data on federated model training.
The above examples are only for illustrating the technical solutions of the present invention, and not for limiting the same; although the present invention has been described in detail with reference to the foregoing embodiments, it should be understood by those of ordinary skill in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some technical features may be equivalently replaced; and such modifications or substitutions do not depart from the spirit and scope of the corresponding technical solutions.

Claims (8)

1. A federated mutual learning model training method oriented to non-independent identically distributed data, characterized by comprising the following steps:
S1, the server sends the initialized global model parameters to the intermediate client of each group, the intermediate clients generate intermediate client model parameters using the intermediate client data sets, and the generated intermediate client model parameters are sent to the edge clients in the group;
S2, the edge clients receive the intermediate client model parameters and generate edge client model parameters using their local data sets;
S3, the intermediate client and the edge clients in each group perform multiple rounds of training using the mutual learning method, updating the intermediate client model parameters and the edge client model parameters;
S4, the intermediate clients of all groups upload the label category probability prediction values of the intermediate client models to the server, and the global model parameters are updated;
S5, steps S3-S4 are executed repeatedly until a convergence condition is met, obtaining the intermediate client models, the edge client models and the global model, and the server broadcasts the generated final global model to all edge clients.
2. The model training method of claim 1, wherein the global model, the intermediate client model, and the edge client model are neural network models.
3. The model training method of claim 1, wherein in step S2, the intermediate client and the edge client respectively update model parameters on the local data set using a stochastic gradient descent algorithm.
4. The model training method according to claim 1, wherein in step S3, the mutual learning method comprises:
S31, recording the outputs of all edge client models in the group in round t, C being the number of edge clients, computing the probability prediction values of the label categories and transmitting them to the intermediate client, where i denotes the ith intermediate client and j denotes the jth edge client connected to the ith intermediate client;
S32, the intermediate client of group i computes its aggregate of the received probability prediction values (c denotes an edge client; the formula is given as an equation image in the original), together with the KL divergence D_KL between the intermediate client model of the ith group and the edge client models in the group and the loss function of the intermediate client model of the ith group;
S33, the intermediate client of the ith group updates the intermediate client model parameters on the intermediate client data set using a stochastic gradient descent algorithm;
S34, the jth edge client of the ith group computes its KL divergence D_KL and loss function;
S35, the edge client updates the edge client model parameters on the edge client data set using a stochastic gradient descent algorithm;
S36, after N rounds of S31-S35 have been executed, all intermediate clients compute the probability prediction values of the label categories and transmit them to the server.
5. The model training method according to claim 4, wherein in step S31 the divergence formula and the loss function of the intermediate client model of the ith group are as given in the original equation images, m representing the number of intermediate clients and the probability prediction value of the label categories of the intermediate client of the ith group in round t entering the divergence formula.
6. The model training method according to claim 4, wherein in step S34 the loss function is calculated by the formula given as an equation image in the original.
7. The model training method of claim 4, wherein a temperature parameter T is added to the Softmax function of the edge client model for adjusting the output probability distribution, and the edge client computes the corresponding probability prediction values of the label categories.
8. The model training method of claim 4, wherein the server receives the probability prediction values of the corresponding label categories output by the intermediate client models, i denoting the ith intermediate client and m denoting the intermediate clients, and updates the global model with a distillation learning method comprising:
S41, the server computes the loss function of the global model, z denoting the number of interaction rounds between the server and the intermediate clients (the formula is given as an equation image in the original);
S42, the global model parameters are updated on the server's local data set using a gradient descent algorithm, z being the number of rounds of mutual learning between the intermediate clients and the edge clients;
S43, each group's intermediate client computes its KL divergence and loss function according to the formulas given as equation images in the original.
CN202111386087.7A 2021-11-22 2021-11-22 Federal mutual learning model training method oriented to non-independent same distribution data Pending CN114091667A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111386087.7A CN114091667A (en) 2021-11-22 2021-11-22 Federal mutual learning model training method oriented to non-independent same distribution data

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111386087.7A CN114091667A (en) 2021-11-22 2021-11-22 Federal mutual learning model training method oriented to non-independent same distribution data

Publications (1)

Publication Number Publication Date
CN114091667A true CN114091667A (en) 2022-02-25

Family

ID=80302767

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111386087.7A Pending CN114091667A (en) 2021-11-22 2021-11-22 Federal mutual learning model training method oriented to non-independent same distribution data

Country Status (1)

Country Link
CN (1) CN114091667A (en)


Cited By (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114866545A (en) * 2022-04-19 2022-08-05 郑州大学 Semi-asynchronous layered federal learning method and system based on air calculation
CN114866545B (en) * 2022-04-19 2023-04-25 郑州大学 Semi-asynchronous hierarchical federal learning method and system based on air calculation
CN114863169A (en) * 2022-04-27 2022-08-05 电子科技大学 Image classification method combining parallel ensemble learning and federal learning
CN114863169B (en) * 2022-04-27 2023-05-02 电子科技大学 Image classification method combining parallel integrated learning and federal learning
CN115511108A (en) * 2022-09-27 2022-12-23 河南大学 Data set distillation-based federal learning personalized method
CN115879467A (en) * 2022-12-16 2023-03-31 浙江邦盛科技股份有限公司 Federal learning-based Chinese address word segmentation method and device
CN115879467B (en) * 2022-12-16 2024-04-30 浙江邦盛科技股份有限公司 Chinese address word segmentation method and device based on federal learning
CN116189874A (en) * 2023-03-03 2023-05-30 海南大学 Telemedicine system data sharing method based on federal learning and federation chain
CN116189874B (en) * 2023-03-03 2023-11-28 海南大学 Telemedicine system data sharing method based on federal learning and federation chain

Similar Documents

Publication Publication Date Title
CN114091667A (en) Federal mutual learning model training method oriented to non-independent same distribution data
Pu et al. Asymptotic network independence in distributed stochastic optimization for machine learning: Examining distributed and centralized stochastic gradient descent
Le et al. Federated continuous learning with broad network architecture
Liu et al. Resource-constrained federated edge learning with heterogeneous data: Formulation and analysis
CN113065974A (en) Link prediction method based on dynamic network representation learning
Zhang et al. Prediction for network traffic of radial basis function neural network model based on improved particle swarm optimization algorithm
CN115587633A (en) Personalized federal learning method based on parameter layering
CN115879542A (en) Federal learning method oriented to non-independent same-distribution heterogeneous data
Zhou Deep embedded clustering with adversarial distribution adaptation
CN113313266B (en) Federal learning model training method based on two-stage clustering and storage device
CN108009635A (en) A kind of depth convolutional calculation model for supporting incremental update
CN117236421B (en) Large model training method based on federal knowledge distillation
CN114626550A (en) Distributed model collaborative training method and system
CN114065033A (en) Training method of graph neural network model for recommending Web service combination
Tanghatari et al. Federated learning by employing knowledge distillation on edge devices with limited hardware resources
CN113240086A (en) Complex network link prediction method and system
Wang Multimodal emotion recognition algorithm based on edge network emotion element compensation and data fusion
Chen et al. Resource-aware knowledge distillation for federated learning
CN116187469A (en) Client member reasoning attack method based on federal distillation learning framework
CN116582442A (en) Multi-agent cooperation method based on hierarchical communication mechanism
CN114265954B (en) Graph representation learning method based on position and structure information
CN116227632A (en) Federation learning method and device for heterogeneous scenes of client and heterogeneous scenes of data
Gowgi et al. Temporal self-organization: A reaction–diffusion framework for spatiotemporal memories
Hettiarachchi et al. Time series regression and artificial neural network approaches for forecasting unit price of tea at Colombo auction
Tao et al. Communication Efficient Federated Learning via Channel-wise Dynamic Pruning

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination