CN116976461A - Federated learning method, apparatus, device and medium - Google Patents


Info

Publication number
CN116976461A
Authority
CN
China
Legal status
Pending
Application number
CN202310765118.2A
Other languages
Chinese (zh)
Inventor
刘吉
陈春璐
周吉文
荆博
熊昊一
Current Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202310765118.2A
Publication of CN116976461A
Legal status: Pending (current)


Classifications

    • G PHYSICS
        • G06 COMPUTING; CALCULATING OR COUNTING
            • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
                • G06N 20/00 Machine learning
            • G06F ELECTRIC DIGITAL DATA PROCESSING
                • G06F 18/00 Pattern recognition
                    • G06F 18/20 Analysing
                        • G06F 18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
                            • G06F 18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
                        • G06F 18/23 Clustering techniques
                            • G06F 18/232 Non-hierarchical techniques
                                • G06F 18/2321 Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
                                    • G06F 18/23213 Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
                        • G06F 18/24 Classification techniques
                        • G06F 18/25 Fusion techniques

Abstract

The present disclosure provides a federated learning method, apparatus, device and medium, and relates to the technical field of data processing, in particular to the technical fields of artificial intelligence, deep learning and the like. The specific implementation scheme is as follows: acquiring performance parameters of a plurality of edge devices; performing cluster analysis on the plurality of edge devices based on their performance parameters to obtain a plurality of device clusters; performing knowledge distillation on a cloud global model for the plurality of device clusters to distill out a plurality of first sub-models; and distributing the plurality of first sub-models to the corresponding device clusters for federated learning. For heterogeneous edge devices, embodiments of the disclosure divide the devices into a plurality of device clusters through cluster analysis and use knowledge distillation to produce a model adapted to each device cluster, which can reduce the communication burden.

Description

Federated learning method, apparatus, device and medium
Technical Field
The present disclosure relates to the field of data processing technologies, and in particular to the technical fields of artificial intelligence, deep learning and the like.
Background
Artificial intelligence technology is developing rapidly. Deep learning, one of its most important techniques, relies on very large amounts of data.
However, conventional machine learning first collects data and then trains on it centrally, which poses a serious threat to private information. Federated learning emerged in response. Federated learning is a distributed machine learning technique. Unlike earlier distributed machine learning techniques, it does not collect user data; instead, the data stays local, each user device trains the machine learning model in place, and the trained model is uploaded to a server. In this way, the data never leaves the local device and data security is ensured. However, how best to perform federated learning remains worth studying.
Disclosure of Invention
The present disclosure provides a federated learning method, apparatus, device, and storage medium.
According to an aspect of the present disclosure, there is provided a federated learning method comprising:
acquiring performance parameters of a plurality of edge devices;
performing cluster analysis on the plurality of edge devices based on the performance parameters of the plurality of edge devices to obtain a plurality of device clusters;
performing knowledge distillation on a cloud global model for the plurality of device clusters to distill out a plurality of first sub-models;
and distributing the plurality of first sub-models to the corresponding device clusters for federated learning.
According to another aspect of the present disclosure, there is provided a federated learning method comprising:
sending performance parameters of a current edge device to a server, so that the server performs cluster analysis on a plurality of edge devices based on their performance parameters to obtain a plurality of device clusters, and performs knowledge distillation on a cloud global model to obtain a first sub-model applicable to a target device cluster, the target device cluster being the device cluster in which the current edge device is located;
and performing federated learning based on the first sub-model.
According to another aspect of the present disclosure, there is provided a federated learning apparatus comprising:
an acquisition module, configured to acquire performance parameters of a plurality of edge devices;
a clustering module, configured to perform cluster analysis on the plurality of edge devices based on their performance parameters to obtain a plurality of device clusters;
a distillation module, configured to perform knowledge distillation on a cloud global model for the plurality of device clusters to distill out a plurality of first sub-models;
and a learning module, configured to distribute the plurality of first sub-models to the corresponding device clusters for federated learning.
According to another aspect of the present disclosure, there is provided a federated learning apparatus comprising:
a sending module, configured to send performance parameters of a current edge device to a server, so that the server performs cluster analysis on a plurality of edge devices based on their performance parameters to obtain a plurality of device clusters, and performs knowledge distillation on a cloud global model to obtain a first sub-model applicable to a target device cluster, the target device cluster being the device cluster in which the current edge device is located;
and a training module, configured to perform federated learning based on the first sub-model.
According to another aspect of the present disclosure, there is provided an electronic device including:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform the method of any one of the embodiments of the present disclosure.
According to another aspect of the present disclosure, there is provided a non-transitory computer-readable storage medium storing computer instructions for causing a computer to perform a method according to any one of the embodiments of the present disclosure.
According to another aspect of the present disclosure, there is provided a computer program product comprising a computer program which, when executed by a processor, implements a method according to any of the embodiments of the present disclosure.
For heterogeneous edge devices, embodiments of the present disclosure divide the devices into a plurality of device clusters through cluster analysis and use knowledge distillation to produce a model adapted to each device cluster, which can reduce the communication burden.
It should be understood that the description in this section is not intended to identify key or critical features of the embodiments of the disclosure, nor is it intended to be used to limit the scope of the disclosure. Other features of the present disclosure will become apparent from the following specification.
Drawings
The drawings are for a better understanding of the present solution and are not to be construed as limiting the present disclosure. Wherein:
FIG. 1 is a flow diagram of a federated learning method according to an embodiment of the present disclosure;
FIG. 2 is a flow diagram of guiding each device cluster to generate an augmented dataset based on the global student model and the second sub-model of each device cluster, according to an embodiment of the present disclosure;
FIG. 3 is a flow diagram of pseudo-sample updating according to an embodiment of the present disclosure;
FIG. 4 is a flow diagram of another federated learning method according to an embodiment of the present disclosure;
FIG. 5 is a schematic diagram of a diffusion model generating augmented samples according to an embodiment of the present disclosure;
FIG. 6 is a framework diagram of federated learning according to an embodiment of the present disclosure;
FIG. 7 is a schematic structural diagram of a federated learning apparatus according to an embodiment of the present disclosure;
FIG. 8 is a schematic structural diagram of a federated learning apparatus according to another embodiment of the present disclosure;
FIG. 9 is a block diagram of an electronic device for implementing the federated learning method of an embodiment of the present disclosure.
Detailed Description
Exemplary embodiments of the present disclosure are described below in conjunction with the accompanying drawings, which include various details of the embodiments of the present disclosure to facilitate understanding, and should be considered as merely exemplary. Accordingly, one of ordinary skill in the art will recognize that various changes and modifications of the embodiments described herein can be made without departing from the scope of the present disclosure. Also, descriptions of well-known functions and constructions are omitted in the following description for clarity and conciseness.
Furthermore, the terms "first," "second," and the like are used for descriptive purposes only and are not to be construed as indicating or implying relative importance or implicitly indicating the number of the technical features indicated. Thus, a feature defined as "first" or "second" may explicitly or implicitly include one or more such features. In the description of the present disclosure, "a plurality" means two or more, unless explicitly defined otherwise.
The continued development of science and technology has made devices such as smartphones, tablets and smartwatches increasingly powerful. These devices provide great convenience while also collecting large amounts of data. Meanwhile, artificial intelligence is developing rapidly; deep learning, one of its most important techniques, requires very large amounts of data, so the data on these smart devices is highly valuable.
However, conventional machine learning first collects data and then trains on it centrally, which poses a serious threat to private information. Federated learning offers a solution. Federated learning is a distributed machine learning technique. Unlike earlier distributed machine learning techniques, it does not collect user data; instead, the data stays local, each user device trains the machine learning model in place, and the trained model is uploaded to a server. In this way, the data never leaves the local device, data security is ensured, and because only model parameters need to be transmitted, communication pressure is greatly reduced.
In a federated learning system, the more devices participate in computation, the more resources are available. As computing and data resources grow, model sizes keep expanding, giving rise to large models. A large model is a data-driven model with a massive number of parameters and broad applicability. Large models must process large amounts of data, and computational cost grows as model size increases. Moreover, large models run inefficiently on device-side hardware: because of their computational complexity and high resource requirements, deploying them on resource-constrained devices such as mobile phones or Internet of Things devices is challenging. This typically forces a trade-off between model performance and device-side feasibility.
In view of this, to train a large model while making efficient use of device-side computing resources and data, embodiments of the present disclosure provide a federated learning method. As shown in FIG. 1, the method includes the following steps:
S101, acquiring performance parameters of a plurality of edge devices.
The edge devices are, for example, intelligent terminals, computers and the like.
The performance parameters of an edge device describe its data processing capability. For example, they may include computing resources (e.g., CPU (central processing unit) information, GPU (graphics processing unit) information and memory information), storage resources (e.g., disk space), network resources (e.g., bandwidth) and the like.
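Before any clustering, each device's report has to become a numeric vector. The sketch below is a minimal illustration under stated assumptions: the field names in the hypothetical `PerfParams` record are invented, and the min-max scaling step is our own addition (the text does not mention normalization) so that features measured in very different units, such as gigabytes and megabits per second, contribute comparably to later similarity or distance computations.

```python
from dataclasses import dataclass


@dataclass
class PerfParams:
    """Hypothetical performance record reported by one edge device."""
    cpu_cores: float       # computing resource
    gpu_tflops: float      # computing resource (0 if the device has no GPU)
    memory_gb: float       # computing resource
    disk_gb: float         # storage resource
    bandwidth_mbps: float  # network resource

    def as_vector(self) -> list[float]:
        return [self.cpu_cores, self.gpu_tflops, self.memory_gb,
                self.disk_gb, self.bandwidth_mbps]


def min_max_normalize(vectors: list[list[float]]) -> list[list[float]]:
    """Scale every feature to [0, 1] across devices (assumed preprocessing step)."""
    dims = len(vectors[0])
    lo = [min(v[d] for v in vectors) for d in range(dims)]
    hi = [max(v[d] for v in vectors) for d in range(dims)]
    return [[(v[d] - lo[d]) / (hi[d] - lo[d]) if hi[d] > lo[d] else 0.0
             for d in range(dims)]
            for v in vectors]


# Usage with three made-up devices: a weak, a strong and a mid-range one.
devices = [PerfParams(4, 0, 4, 64, 50),
           PerfParams(8, 2, 8, 128, 100),
           PerfParams(6, 1, 6, 96, 75)]
normalized = min_max_normalize([d.as_vector() for d in devices])
```

Here the mid-range device maps to 0.5 on every feature, which is the kind of unit-free representation the clustering steps below can work with.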
S102, performing cluster analysis on the plurality of edge devices based on performance parameters of the plurality of edge devices to obtain a plurality of device clusters.
In embodiments of the present disclosure, cluster analysis is performed based on the performance parameters of the edge devices, so that edge devices with similar performance are grouped into the same device cluster. The embodiments do not limit the cluster analysis method; for example, k-means clustering may be used.
S103, performing knowledge distillation on the cloud global model for the plurality of device clusters to distill out a plurality of first sub-models.
For example, the cloud global model is a very large model. Large models place high demands on device performance, and not every edge device can run one. Therefore, once a plurality of device clusters has been obtained by clustering, a first sub-model suited to each device cluster can be distilled out according to that cluster's performance parameters.
S104, distributing the plurality of first sub-models to the corresponding device clusters for federated learning.
Edge devices differ in performance. For heterogeneous edge devices, embodiments of the present disclosure use cluster analysis to place edge devices with similar performance into the same device cluster, and all devices in a device cluster can run the same first sub-model. Therefore, for a cloud global model that cannot run directly on edge devices, a first sub-model suited to each device cluster can be distilled. The number of distilled first sub-models thus depends on the number of device clusters rather than on the number of edge devices, so there is no need to distill a separate sub-model for each edge device, which reduces the computation required for distillation. Because the cloud global model is made lightweight by knowledge distillation before being distributed to the edge devices, and the lightweight model has fewer parameters than the cloud global model, the communication burden on the system can be reduced. In addition, adopting federated learning helps protect private information. In summary, the present disclosure provides a versatile and efficient federated learning scheme for heterogeneous edge devices.
Federated learning requires training with data on the edge devices. Regarding the data required for training the model in this disclosure, in the technical solution of the present disclosure, the collection, storage, application and other processing of any user personal information involved comply with the relevant laws and regulations and do not violate public order and good morals.
For ease of understanding, distilling the first sub-models is described in detail below in two parts: cluster analysis, and distillation of the first sub-models.
1) Cluster analysis
In some embodiments, obtaining a plurality of device clusters through cluster analysis may be implemented as follows:
Step A1, determining the performance similarity between edge devices based on the performance parameters of the plurality of edge devices;
Step A2, performing cluster analysis based on the performance similarity between the edge devices to obtain a plurality of device clusters.
For example, devices may be clustered by CPU performance, memory capacity and disk space. In implementation, the similarity between edge devices is calculated from the numerical values of these indicators, and similar devices are assigned to the same device cluster.
In implementation, the cosine similarity between the performance parameters of two edge devices may be used as their similarity; alternatively, the distance between their performance parameters may be used, where a smaller distance indicates greater similarity.
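The two similarity measures mentioned above can be written in a few lines. This is a plain-Python sketch, not the applicant's code; the function names are our own, and returning 0.0 for a zero vector is an assumed convention.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two performance vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # undefined for a zero vector; treated as dissimilar (assumption)
    return dot / (norm_a * norm_b)


def euclidean_distance(a, b):
    """Straight-line distance between two performance vectors; smaller = more similar."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
```

One design point worth noting: cosine similarity ignores scale, so a device with exactly twice the resources of another (for example `[1, 2]` versus `[2, 4]`) scores 1.0. When the goal is to group devices by capability level, a distance on normalized features may therefore be the more suitable of the two options.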
For example, all edge devices D are clustered according to predefined rules into a set of device clusters, where t denotes the number of device clusters. In embodiments of the present disclosure, when an edge device connects to the network, it may send predefined information (i.e., its performance parameters) to the server. The server randomly selects t samples from the collected performance parameters of the edge devices as cluster centers, where t is a positive integer. It then calculates the distance from each performance parameter to each cluster center (the distance represents the similarity between edge devices) and assigns each performance parameter to the device cluster whose center is closest. Next, for each device cluster, the server recalculates the mean of all performance parameters in the cluster and uses it as the new cluster center. These operations are repeated, assigning each edge device to the nearest cluster, until the cluster centers no longer change or the maximum number of iterations is reached, yielding a plurality of device clusters.
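The iteration just described is standard k-means. The following is a minimal sketch of it, assuming each device's performance parameters are already a numeric vector; the function and argument names are hypothetical, and keeping the old center when a cluster becomes empty is our own choice, since the text does not address that case.

```python
import math
import random


def _distance(a, b):
    # Euclidean distance between two performance-parameter vectors.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def _assign(params, centers):
    # Index of the nearest cluster center for every device.
    return [min(range(len(centers)), key=lambda k: _distance(p, centers[k]))
            for p in params]


def kmeans_device_clusters(params, t, max_iters=100, seed=0):
    """Partition edge devices into t device clusters with plain k-means.

    params: one performance-parameter vector per edge device.
    Returns (labels, centers); labels[i] is the cluster index of device i.
    """
    if not 0 < t <= len(params):
        raise ValueError("t must satisfy 0 < t <= number of devices")
    rng = random.Random(seed)
    # Randomly choose t reported vectors as the initial cluster centers.
    centers = [list(p) for p in rng.sample(params, t)]
    for _ in range(max_iters):
        labels = _assign(params, centers)
        new_centers = []
        for k in range(t):
            members = [p for p, lab in zip(params, labels) if lab == k]
            if members:
                # New center = mean of all performance vectors in the cluster.
                new_centers.append([sum(col) / len(members) for col in zip(*members)])
            else:
                # Empty cluster: keep the old center (not covered by the text).
                new_centers.append(centers[k])
        if new_centers == centers:  # centers stopped changing
            break
        centers = new_centers
    # Final assignment against the returned centers.
    return _assign(params, centers), centers
```

For instance, four devices at roughly `[0.1, 0.1]`, `[0.15, 0.12]`, `[0.9, 0.95]` and `[0.88, 0.9]` end up as two clusters of two with `t=2`. In practice a library implementation such as scikit-learn's `KMeans` (with smarter initialization) could replace this loop.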
In embodiments of the present disclosure, performance differences between edge devices are measured by performance similarity, so that even when edge devices are heterogeneous, devices with similar performance can be accurately grouped into the same cluster, which facilitates federated learning across heterogeneous edge devices.
2) Distilling out the first sub-model for each device cluster
In some embodiments, the server can access a public dataset. The public dataset consists of data collected by the server and contains a large number of diverse samples. After training the cloud global model on the public dataset, the server may distill a corresponding first sub-model for each device cluster.
In implementation, the full public dataset may be used to distill the plurality of first sub-models from the cloud global model by knowledge distillation.
In another possible embodiment, to improve efficiency, the required first sub-model may instead be distilled for each of the plurality of device clusters as follows:
Step B1, sampling, from the public dataset, a training sample set adapted to the performance parameters of the device cluster.
The performance parameters of a device cluster may be represented by those of the edge device with the lowest performance in the cluster, or by the average performance of the edge devices in the cluster. The embodiments of the present disclosure do not limit this.
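Both options for summarizing a device cluster can be expressed as one small helper. This is an illustrative sketch only: ranking devices by the sum of their (ideally normalized) features to find the "weakest" one is an assumption, and the `score` parameter lets a caller substitute a different ranking.

```python
def cluster_performance(members, mode="weakest", score=sum):
    """Summarize a device cluster's capability from its members' performance vectors.

    mode="weakest": return the vector of the lowest-scoring device, so the
                    cluster's model must fit its weakest member.
    mode="mean":    return the per-feature average over all members.
    score: hypothetical ranking function used to identify the weakest device.
    """
    if not members:
        raise ValueError("a device cluster must contain at least one device")
    if mode == "weakest":
        return list(min(members, key=score))
    if mode == "mean":
        n = len(members)
        return [sum(col) / n for col in zip(*members)]
    raise ValueError(f"unknown mode: {mode}")
```

The weakest-device option is the conservative choice (every member can run the resulting model), while the mean better reflects typical capacity but may produce a model that the slowest members struggle with.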
In some embodiments, when sampling a training sample set for a device cluster, the number of samples the device cluster supports per training round may be determined based on the cluster's computing resources, and the training sample set required by the device cluster is then downsampled from the public dataset according to that sample size. For example, if device cluster 1 supports 5000 samples per training round, 5000 samples are downsampled for device cluster 1 for training; if device cluster 2 supports 6000 samples per training round, 6000 samples are sampled for device cluster 2 for training.
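A per-cluster downsampling step might look like the sketch below. It assumes the per-round sample budget of each device cluster has already been determined (how it is derived from computing resources is left open by the text), and uses uniform sampling without replacement; the function name is hypothetical.

```python
import random


def downsample_for_clusters(public_dataset, budgets, seed=0):
    """Draw a training sample set for each device cluster from the public dataset.

    public_dataset: list of samples held by the server.
    budgets: dict mapping cluster id -> number of samples the cluster supports per round.
    Each cluster is sampled independently and without replacement.
    """
    rng = random.Random(seed)
    sample_sets = {}
    for cluster_id, n in budgets.items():
        if n > len(public_dataset):
            raise ValueError(f"cluster {cluster_id} asks for more samples than exist")
        sample_sets[cluster_id] = rng.sample(public_dataset, n)
    return sample_sets


# Usage mirroring the example in the text: 5000 samples for cluster 1, 6000 for cluster 2.
public = list(range(10000))
sets = downsample_for_clusters(public, {1: 5000, 2: 6000})
```

Because each cluster is sampled independently, the two training sets may overlap; that is consistent with the text, which only fixes the size of each set.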
Sampling may be random, or may take into account the categories contained in the device cluster. In the latter case, samples of categories not contained in the device cluster may be sampled preferentially, so that each device cluster can learn from samples of more categories. For example, in image classification, if the local data of device cluster 1 is known to consist mostly of class 1 and class 2 samples, class 3 or class 4 samples may be sampled from the public dataset to distill the first sub-model for that device cluster. In this way, the first sub-model of each device cluster is trained on samples of more categories, which prevents any device cluster from being biased toward particular categories.
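The category-aware variant can be sketched as "fill the budget with missing classes first, then top up". The helper below is hypothetical: it assumes the public dataset is a list of `(sample, label)` pairs and that the set of classes already well represented in a cluster's local data is known to the server.

```python
import random


def class_aware_sample(public_dataset, local_classes, n, seed=0):
    """Pick n public samples, preferring classes the device cluster lacks locally.

    public_dataset: list of (sample, label) pairs.
    local_classes: labels already common in the cluster's local data.
    Classes absent from the cluster fill the budget first; any remaining
    slots are topped up at random from the other classes.
    """
    if n > len(public_dataset):
        raise ValueError("budget exceeds the size of the public dataset")
    rng = random.Random(seed)
    missing = [s for s in public_dataset if s[1] not in local_classes]
    present = [s for s in public_dataset if s[1] in local_classes]
    rng.shuffle(missing)
    rng.shuffle(present)
    picked = missing[:n]
    if len(picked) < n:
        picked += present[: n - len(picked)]
    return picked


# Usage: device cluster 1 mostly holds classes 1 and 2, so classes 3 and 4 are preferred.
public = [(i, i % 4 + 1) for i in range(100)]  # 25 samples of each class 1..4
picked = class_aware_sample(public, local_classes={1, 2}, n=40)
```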
Step B2, creating an initial model adapted to the performance parameters of the device cluster.
The scale of the model may be determined based on the number of samples the device cluster supports per training round and the cluster's computing resources. A higher-performance device cluster supports a larger model, so a larger initial model can be designed for a device cluster with higher device performance. For example, an initial model with a 3-layer network is designed for the lower-performance device cluster 1, and an initial model with a 6-layer network is designed for the higher-performance device cluster 2. In this way, heterogeneous first sub-models can be designed according to differences between device clusters.
In a specific implementation, a correspondence between different performance parameters and key model indicators (such as the number of network layers) may be established, and an initial model is created for each device cluster based on this correspondence.
Of course, in other embodiments, the initial model of each device cluster may be created in response to user operations based on prior knowledge; the embodiments of the present disclosure do not limit this.
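The correspondence between cluster performance and model scale could be as simple as a threshold table. Everything in the sketch below is hypothetical: the score thresholds, the `hidden_dim` values and the middle tier are invented for illustration, and only the 3-layer and 6-layer endpoints echo the example in the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelSpec:
    num_layers: int  # depth of the initial (student) model
    hidden_dim: int  # width of each hidden layer (hypothetical)


# Hypothetical correspondence table: (minimum cluster performance score, model spec),
# ordered from the most demanding tier to the least demanding one.
TIERS = [
    (0.7, ModelSpec(num_layers=6, hidden_dim=512)),
    (0.3, ModelSpec(num_layers=4, hidden_dim=256)),
    (0.0, ModelSpec(num_layers=3, hidden_dim=128)),
]


def initial_model_spec(cluster_score):
    """Return the largest model spec whose threshold the cluster's score reaches."""
    for threshold, spec in TIERS:
        if cluster_score >= threshold:
            return spec
    return TIERS[-1][1]  # scores below every threshold get the smallest model
```

A score here might be, for example, the normalized performance of the cluster's weakest device; the table would in practice be tuned to the actual hardware and task.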
It should be noted that the embodiments of the present disclosure do not limit the execution order of step B1 and step B2.
Step B3, using the cloud global model as the teacher model and the initial model as the student model, training the initial model on the training sample set to distill out the first sub-model.
In an embodiment of the disclosure, the server performs knowledge distillation on the cloud global model $w_0$ based on the device cluster information and the server's public dataset. Specifically, the server randomly samples the public dataset to obtain a training sample set for each device cluster, from the training sample set of the 1st device cluster through the training sample set of the t-th device cluster.
Embodiments of the present disclosure distill the first sub-model using response-based knowledge, whose main idea is to directly mimic the final predictions of the teacher model. In a typical neural network, the output of the last layer (e.g., a fully connected layer), i.e., the logits, is passed through a softmax (activation function layer) to convert it into a probability distribution. Response-based knowledge is called a soft target. The soft target is the probability distribution for an input sample and can be estimated with a temperature-scaled softmax:
$$p(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \qquad (1)$$
In formula (1), $z_i$ is the logit of the $i$-th probability value in the distribution, and $T$ is a temperature factor that controls the importance of each soft target. Thus, for each device cluster, the loss can be determined from the gap between the logits of the first sub-model being distilled and those of the teacher model. The training goal is to minimize the gap between the logits predicted by the first sub-model and by the teacher model, so that the first sub-model learns the knowledge of the teacher model. The loss function can be expressed as formula (2):
$$L_R(z_t, z_s) = L_R\big(p(z_t, T),\ p(z_s, T)\big) \qquad (2)$$
In formula (2), $L_R(\cdot)$ denotes the loss between logits, and $z_t$ and $z_s$ denote the logits of the teacher model and the student model, respectively.
Taking classification as an example, the teacher model classifies the samples in the training sample set to obtain a probability distribution for each input sample, and the student model does the same. Given these probability distributions, a loss value can be determined according to formulas (1) and (2) and used to adjust the parameters of the student model, thereby distilling the first sub-model from the cloud global model.
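Formulas (1) and (2) translate directly into code. The text leaves the concrete form of $L_R$ open; the sketch below assumes the common choice of KL divergence between the teacher's and the student's softened distributions, and the function names are our own.

```python
import math


def soft_targets(logits, T=1.0):
    """Formula (1): temperature-scaled softmax p(z_i, T)."""
    m = max(logits)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(teacher_logits, student_logits, T=2.0):
    """One possible L_R for formula (2): KL(p(z_t, T) || p(z_s, T)) (assumed choice)."""
    p_t = soft_targets(teacher_logits, T)
    p_s = soft_targets(student_logits, T)
    return sum(pt * (math.log(pt) - math.log(ps))
               for pt, ps in zip(p_t, p_s) if pt > 0)
```

Raising `T` flattens the distribution, exposing how the teacher ranks the non-maximal classes, which is the extra signal response-based distillation relies on. When this term is mixed with an ordinary hard-label loss, Hinton et al. multiply it by $T^2$ so that its gradient scale stays comparable across temperatures.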
Finally, denote the set of student models corresponding to the device clusters as $w_s = \{w_1, w_2, \dots, w_t\}$, where $w_1$ denotes the model parameters of the first sub-model of the 1st device cluster and, similarly, $w_t$ denotes the model parameters of the first sub-model of the t-th device cluster.
Knowledge distillation is a technique that transfers the knowledge of a teacher model to a student model; here, the teacher model is the cloud global model. For each device cluster, the server trains a student model to obtain that cluster's first sub-model. Since the training sample set may differ between device clusters, the resulting first sub-models may also differ.
In summary, embodiments of the present disclosure can construct different first sub-models for different device clusters according to their performance differences, and use different training sample sets so that the training results of the first sub-models differ. In this way, personalized knowledge distillation can be performed for each device cluster based on its characteristics.
After the first sub-models are distilled, each is delivered to its corresponding device cluster for federated learning. Each edge device in the device cluster receives the first sub-model and trains it independently on its local private dataset. The device cluster may then aggregate the first sub-models trained by its edge devices to obtain a second sub-model. That is, each device cluster trains its corresponding first sub-model to obtain a second sub-model; the training process within a device cluster is described in later embodiments. After obtaining the second sub-model trained by each device cluster, the server trains the cloud global model based on the second sub-models, thereby updating it. For example, the predictions of each second sub-model on input samples are used as soft targets to train the cloud global model. In this way, with each second sub-model providing soft targets as a teacher and the cloud global model acting as the student, the knowledge of each second sub-model is transferred to the cloud global model.
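The text does not specify how a device cluster aggregates its members' trained copies of the first sub-model into the second sub-model. A common option, shown here purely as an assumption, is FedAvg-style averaging weighted by each edge device's local dataset size; parameters are represented as a hypothetical dict of named flat lists.

```python
def aggregate_cluster(client_params, client_sizes):
    """Average the edge devices' trained parameters within one device cluster.

    client_params: list of dicts {param_name: list of floats}, one per edge device,
                   all sharing the first sub-model's architecture.
    client_sizes: number of local training samples on each edge device; used as
                  averaging weights (FedAvg-style, an assumption).
    Returns the parameters of the aggregated second sub-model.
    """
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("total sample count must be positive")
    aggregated = {}
    for name, values in client_params[0].items():
        aggregated[name] = [
            sum(params[name][i] * size for params, size in zip(client_params, client_sizes)) / total
            for i in range(len(values))
        ]
    return aggregated
```

Averaging works here only because every device in a cluster trains the same architecture; across clusters the sub-models are heterogeneous, which is why the text moves knowledge between them through soft targets instead.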
In other embodiments, the cloud global model may further be trained with data augmentation. A specific implementation includes the following steps:
Step C1, performing knowledge distillation on the second sub-models of the device clusters to obtain a global student model.
To learn the knowledge of every second sub-model, embodiments of the present disclosure use knowledge distillation to compress the plurality of second sub-models into a single global student model.
In one possible implementation, since the parameter counts of the second sub-models may differ, the second sub-models may be used in turn as teacher models to successively guide the training of the global student model. Suppose three second sub-models are obtained: second sub-model A, second sub-model B and second sub-model C. First, with second sub-model A as the teacher model and an initial global model O as the student model, the student is trained by knowledge distillation to obtain initial global model A. Next, with initial global model A as the student model and second sub-model B as the teacher model, the knowledge of second sub-model B is distilled to obtain initial global model B. Similarly, with initial global model B as the student model and second sub-model C as the teacher model, the knowledge of second sub-model C is distilled to obtain the final global student model.
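The sequential scheme (A, then B, then C into one student) can be illustrated with a toy linear-softmax student trained by gradient descent on the KL divergence of formula (2) at temperature `T`. This is a sketch under loud assumptions: the teachers are hypothetical callables returning logits, the learning rate and epoch count are arbitrary, and a real system would use a deep-learning framework rather than hand-written gradients.

```python
import math


def softmax(z, T=1.0):
    # Temperature-scaled softmax, formula (1).
    m = max(z)
    e = [math.exp((v - m) / T) for v in z]
    s = sum(e)
    return [v / s for v in e]


def linear_logits(W, x):
    # Logits of a toy linear classifier: one weight row per class.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def distill_sequentially(student_W, teachers, X, T=2.0, lr=0.5, epochs=200):
    """Distill each teacher into the same student, one after another.

    teachers: callables x -> logits standing in for second sub-models A, B, C, ...;
    only their outputs are used, so their sizes may differ.
    """
    W = [row[:] for row in student_W]  # do not mutate the caller's weights
    n = len(X)
    for teacher in teachers:  # A first, then B, then C
        targets = [softmax(teacher(x), T) for x in X]  # teacher soft targets
        for _ in range(epochs):
            grad = [[0.0] * len(row) for row in W]
            for x, p_t in zip(X, targets):
                p_s = softmax(linear_logits(W, x), T)
                for c in range(len(W)):
                    # d KL(p_t || p_s) / d logit_c = (p_s[c] - p_t[c]) / T
                    g = (p_s[c] - p_t[c]) / T
                    for f, xf in enumerate(x):
                        grad[c][f] += g * xf / n
            W = [[w - lr * gw for w, gw in zip(row, grow)]
                 for row, grow in zip(W, grad)]
    return W
```

Because only teacher outputs are consumed, teachers of different sizes pose no problem, which matches the motivation in the text. A known drawback of strictly sequential distillation is that the student may partly forget earlier teachers while fitting later ones; the aggregated-target alternative below avoids this ordering effect.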
In another possible implementation, the second sub-models may be aggregated to obtain an aggregated learning objective, so that a global student model may be obtained by knowledge distillation. Specifically, knowledge distillation is performed on the second sub-model of each equipment cluster group to obtain a global student model, which can be implemented as follows:
Step C11: input the distillation samples into the second sub-model of each device cluster group to obtain each second sub-model's processing result for the distillation samples.
The distillation samples may be selected from a public dataset; that is, in implementation, some samples can be screened from the public dataset to serve as distillation samples.
The processing result may be a soft target as shown in formula (1), or the model's prediction, such as the probability distribution output by the final softmax layer. In a classification scenario where the target class of a distillation sample is known, the processing result may instead be the predicted probability of that target class rather than the whole probability distribution. The choice of processing result depends on the actual situation and is not limited by the present disclosure.
Step C12: fuse the processing results of the second sub-models for the distillation samples to obtain a soft target.
The fusion may be a simple average, i.e., averaging the processing results of the multiple second sub-models for the same distillation sample. Alternatively, to account for differences in how well the second sub-models were trained, a weighted average may be used, with weights assigned mainly according to model performance. For example, the weight of each second sub-model may be determined by its accuracy on a validation set, with higher accuracy yielding a higher weight; the processing results of the second sub-models for the distillation samples are then summed with these weights to obtain the soft target.
For example, under one possible weighting strategy, if the accuracy of second sub-model A is 80% and that of second sub-model B is 70%, the weights of second sub-models A and B in the weighted summation are 0.8 and 0.7, respectively.
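The accuracy-weighted fusion of formula (4) can be sketched as below. The weights are the validation accuracies from the example (0.8 and 0.7); dividing them by their sum so that the fused soft target remains a probability distribution is an assumption of this sketch, not something the text specifies, so it is exposed as a flag. The function name is hypothetical.

```python
import numpy as np

def fuse_soft_targets(teacher_probs, accuracies, normalize=True):
    """Weighted sum of teacher outputs: P_target = sum_k W_k * P_k.

    teacher_probs: list of K arrays, each (n_samples, n_classes)
    accuracies:    validation accuracy of each second sub-model, used as W_k
    normalize:     assumption; rescale weights to sum to 1 so rows stay distributions
    """
    w = np.asarray(accuracies, dtype=float)
    if normalize:
        w = w / w.sum()
    stacked = np.stack(teacher_probs)        # (K, n_samples, n_classes)
    return np.tensordot(w, stacked, axes=1)  # contract over K -> (n_samples, n_classes)
```

With the example weights 0.8 and 0.7 and normalization enabled, model A effectively contributes 0.8 / 1.5 ≈ 0.533 of the soft target and model B ≈ 0.467.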
Alternatively, the importance of each device cluster group can be determined by predefined rules, and the weight of each second sub-model assigned according to the importance of its device cluster group. The importance of a device cluster group may be determined with reference to, for example, the importance of the local datasets within the group or the accuracy of the second sub-model it learned, which is not limited by the embodiments of the present disclosure.
In the embodiments of the present disclosure, the multiple second sub-models are aggregated by weighted summation during knowledge distillation, with model accuracy explicitly taken into account. The resulting soft target makes the global student model focus on learning from the more accurate teacher models, so that a highly usable global student model is distilled.
Step C13: train the global student model based on the soft target and the distillation samples so as to minimize the loss between the global student model's processing result for the distillation samples and the soft target.
Minimizing this loss drives the global student model's processing results on the samples toward those of the multiple second sub-models, so that the global student model imitates them and knowledge distillation is completed.
The embodiments of the present disclosure thus fuse the processing results of multiple second sub-models for the distillation samples and construct from them a soft target that guides the distillation of the global student model. The global student model can learn the knowledge of multiple second sub-models simultaneously, so multiple heterogeneous second sub-models can be compressed into one global student model.
For example, the server collects the trained second sub-models of all device clusters, and the embodiments of the present disclosure compress the knowledge of these trained second sub-models (as teacher models) into a smaller model (the global student model) by knowledge distillation. The process mainly comprises the following three parts, (a1) to (a3):
(a1) Softening the teacher model outputs: for each teacher model, compute its output probability distribution on the distillation samples. These outputs are typically probability distributions produced by a softmax function. During knowledge distillation, a higher temperature T may be used to soften the softmax function so as to capture the relative confidence differences between teacher models. The softened probability function is shown in formula (3):
$$P(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \qquad (3)$$
In formula (3), $z_i$ is the logit (unnormalized score) of the i-th category before the softmax, T is the temperature parameter, and $P(z_i, T)$ is the teacher model's softened output for the i-th category, i.e., its processing result for the distillation sample.
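Formula (3) is the standard temperature-scaled softmax. A minimal sketch follows, assuming the teacher outputs are available as logits; subtracting the maximum logit before exponentiating leaves the result unchanged and avoids overflow. The function name and example logits are illustrative only.

```python
import numpy as np

def softened_softmax(z, T=1.0):
    """P(z_i, T) = exp(z_i / T) / sum_j exp(z_j / T), computed over the last axis."""
    z = np.asarray(z, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)  # stability shift; does not change the result
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([3.0, 1.0, 0.2])
print(softened_softmax(logits, T=1.0))
print(softened_softmax(logits, T=4.0))
```

Raising T from 1 to 4 flattens the distribution while preserving the ranking of the categories, which is the relative-confidence information the distillation step relies on.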
(a2) Fusing the teacher model outputs: for each distillation sample, the softened output probability distributions of all teacher models are summed with weights to obtain one fused target probability distribution (i.e., the soft target). Weights can be assigned according to the performance of the teacher models, and the fused soft target is shown in formula (4):
$$P_{\text{target}} = \sum_k W_k \, P_k \qquad (4)$$
In formula (4), $P_{\text{target}}$ is the soft target, $W_k$ is the weight of the k-th teacher model, and $P_k$ is the output probability distribution of the k-th teacher model (i.e., its processing result for the distillation sample).
(a3) Training the student model: the global student model is trained with the fused target probability distribution as its soft target. Its loss function may be defined as the cross-entropy between the fused target probability distribution and the probability distribution output by the global student model itself, as shown in formula (5):
$$L_{KD} = -\sum_i P_{\text{target},i} \log P(s_i, T) \qquad (5)$$
In formula (5), $L_{KD}$ is the cross-entropy loss, $P_{\text{target},i}$ is the probability of the i-th category in the fused target probability distribution, $s_i$ is the score of the i-th category output by the global student model, and $P(s_i, T)$ is that score softened as in formula (3). Updated model parameters are then computed to obtain the global student model $w_{\text{student}}^{t}$, where t denotes the global student model obtained by aggregation in round t.
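A minimal sketch of step (a3) follows, assuming the global student model is a linear model whose scores are `X @ W` and that the scores are softened with the same temperature T as the teachers. The loss is the cross-entropy described for formula (5), and plain gradient descent stands in for whatever optimizer an implementation would use; all names are hypothetical.

```python
import numpy as np

def log_softmax(s, T):
    """Row-wise log of the temperature-scaled softmax."""
    s = s / T
    s = s - s.max(axis=1, keepdims=True)
    return s - np.log(np.exp(s).sum(axis=1, keepdims=True))

def kd_loss_and_grad(W, X, p_target, T=2.0):
    """Mean over samples of -sum_i p_target_i * log q_i, with q = softmax(X @ W / T)."""
    log_q = log_softmax(X @ W, T)
    loss = -(p_target * log_q).sum(axis=1).mean()
    q = np.exp(log_q)
    grad = X.T @ (q - p_target) / (T * len(X))  # chain rule through s / T
    return loss, grad

def train_global_student(X, p_target, T=2.0, lr=1.0, steps=300):
    """Fit the hypothetical linear student to the fused soft target P_target."""
    W = np.zeros((X.shape[1], p_target.shape[1]))
    history = []
    for _ in range(steps):
        loss, grad = kd_loss_and_grad(W, X, p_target, T)
        history.append(loss)
        W -= lr * grad
    return W, history
```

With W initialized to zero the student's output is uniform, so the first recorded loss equals log K for K categories; for a sufficiently small learning rate the loss then decreases monotonically.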
Note that softened soft targets generally give better results, but whether to use soft targets can be decided according to the actual situation. In addition, a temperature-scaled softmax makes the model's output probability distribution smoother, so that the model's predicted probability for every category can be observed. The softmax function is used in the embodiments of the present disclosure only as an example and does not limit them.
In the embodiments of the present disclosure, the global student model serves to guide the device clusters in generating an augmentation dataset. Specifically, in step C2, each device cluster group is guided, based on the global student model and/or the second sub-model of each device cluster group, to generate an augmentation dataset that satisfies a preset data distribution characteristic; for each device cluster group, the preset data distribution characteristic is the data distribution characteristic of the local dataset within that group.
For example, under the guidance of the second sub-model and/or the global student model, the edge devices in a device cluster may generate new data as augmented samples according to the distribution characteristics of their local data, thereby building the augmentation dataset. The augmented samples differ from the actual private data but share its distribution characteristics. The augmented samples can therefore be transmitted over the network in place of the private data, and in step C3 the cloud global model is trained based on the augmentation dataset of each device cluster group.
For example, the cloud global model can be updated on the augmentation dataset $\tilde{D}$ according to formula (6):
$$w_{\text{global}}^{t+1} = w_{\text{global}}^{t} - \eta \, \nabla_{w} L\big(w_{\text{global}}^{t}; \tilde{D}\big) \qquad (6)$$
In formula (6), $w_{\text{global}}^{t}$ is the cloud global model at time step t (i.e., before the update), η is the learning rate, and $\nabla_{w} L(w_{\text{global}}^{t}; \tilde{D})$ is the gradient of the loss function of model $w_{\text{global}}^{t}$ on the augmentation dataset $\tilde{D}$. The specific loss function of the model can be determined by the actual business requirements. The formula shows that each update step computes the gradient of the loss on the augmentation dataset $\tilde{D}$. After the cloud global model has been updated on the augmentation dataset, corresponding sub-models can again be distilled for each device cluster and new augmentation datasets generated, so that the cloud global model is optimized iteratively.
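Formula (6) is a plain gradient-descent step on the augmentation dataset. The sketch below assumes, purely for illustration, that the cloud global model is a linear model with a mean-squared-error loss; the text leaves the loss to business requirements, so `mse_loss_and_grad` is a hypothetical placeholder that would be replaced by the real model's loss.

```python
import numpy as np

def mse_loss_and_grad(w, X_aug, y_aug):
    """Placeholder loss (assumption): 0.5 * mean squared error of a linear model."""
    residual = X_aug @ w - y_aug
    loss = 0.5 * np.mean(residual ** 2)
    grad = X_aug.T @ residual / len(y_aug)
    return loss, grad

def update_cloud_global_model(w_t, X_aug, y_aug, eta=0.1):
    """One step of formula (6): w_{t+1} = w_t - eta * grad L(w_t; augmentation dataset)."""
    _, grad = mse_loss_and_grad(w_t, X_aug, y_aug)
    return w_t - eta * grad
```

For a single augmented sample x = 1, y = 2 and w = 0, the gradient is -2, so one step with η = 0.1 moves w to 0.2.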
Thus, in the embodiments of the present disclosure, the cloud global model can be trained and updated with an augmentation dataset that resembles the local datasets of the device clusters. The augmentation dataset differs from the private data to be protected within the device clusters but has similar data distribution characteristics. Because such an augmentation dataset can be transmitted over the network, it both protects the security of private data and lets the cloud global model learn from the edge devices' local data. More importantly, because the second sub-models and the global student model guide the generation of the augmentation dataset, their prior knowledge can be exploited to constrain the device cluster groups to generate a high-quality augmentation dataset for optimizing the cloud global model.
Of course, in some embodiments, the augmentation dataset may also be generated by a GAN (Generative Adversarial Network) given the distribution characteristics of the device cluster's local dataset. For example, the GAN includes a generator that generates the augmentation dataset and a discriminator that judges whether the generated data match the distribution characteristics of the device cluster's local dataset; if they do not, the model parameters of the generator are optimized, so that the generator learns to produce a dataset with the required distribution characteristics.
When the second sub-model and/or the global student model guides the generation of the augmented data, a higher-quality augmentation dataset can be generated than when only the data distribution characteristics of the local dataset are used as a constraint.
In implementation, the guidance of the second sub-model and/or the global student model may take the form of requiring the augmented samples in the augmentation dataset to produce target results at the output of the second sub-model and/or the global student model. The device cluster generates augmented samples from its local raw data; an augmented sample is a pseudo sample but should resemble the raw data. Accordingly, the processing results of the global student model and/or the second sub-model for the augmented sample should resemble their processing results for the raw data. The second sub-model and/or the global student model can then screen, from the many pseudo samples generated by the device cluster, those that resemble the raw data and use them as augmented samples in the augmentation dataset. The resulting augmentation dataset thus performs well on the second sub-model and/or the global student model, and the knowledge of the second sub-models can be distilled into the cloud global model through it.
In other possible embodiments, to optimize the cloud global model further, each device cluster group is guided to generate an augmentation dataset based on both the global student model and the second sub-model of each device cluster group. The implementation can be as shown in Fig. 2:
S201: for each device cluster, obtain pseudo samples generated by the device cluster based on its local data distribution characteristics.
S202: for each pseudo sample, perform the following operations cyclically:
S2021: process the pseudo sample with the global student model to obtain a first comparison result of the pseudo sample; and
S2022: process the pseudo sample with each second sub-model to obtain multiple processing results of the pseudo sample.
In implementation, steps S2021 and S2022 may be executed in any order.
Note that the server may send the global student model to the edge devices, which input the pseudo sample into the global student model to obtain the first comparison result. Similarly, the server can have each device cluster input the pseudo sample into its second sub-model to obtain that second sub-model's processing result for the pseudo sample.
Because the second sub-models may be heterogeneous, in S2023 the multiple processing results are fused to obtain a second comparison result of the pseudo sample.
The fusion may be a simple average or a weighted average. For the weighted average, weights can be derived from the accuracy of each second sub-model on the validation set, i.e., higher accuracy yields a larger weight; weights may also be derived from the importance of the device cluster. The embodiments of the present disclosure are not limited in this regard.
In implementation, the first comparison result and the second comparison result may be the soft targets described above, computed as in formula (1) and formula (2), which are not repeated here. Alternatively, the probability distribution finally predicted by the model may be used directly.
S2024: if the difference between the first comparison result and the second comparison result satisfies a preset condition, add the pseudo sample to the augmentation dataset.
S2025: if the difference between the first comparison result and the second comparison result does not satisfy the preset condition, request the device cluster to update the pseudo sample until the preset condition is satisfied.
In this way, the pseudo samples are updated dynamically under the guidance of the second sub-models and the global student model so that they perform consistently on both, which helps generate a high-quality augmentation dataset and effectively distill the information of the global student model into the cloud global model.
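Steps S201 to S2025 with condition 1) can be sketched as the screening loop below. Several details are assumptions not fixed by the text: the student and teachers are callables returning probability vectors, the second comparison result is a normalized weighted average, the difference is measured with the L1 distance, `refine` is a hypothetical callback through which the device cluster updates a pseudo sample, and `max_rounds` caps the loop so that it always terminates (samples that never meet the condition are dropped).

```python
import numpy as np

def fused_teacher_output(x, teachers, weights):
    """S2022-S2023: run every second sub-model and fuse by a normalized weighted average."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    return sum(wk * teacher(x) for wk, teacher in zip(w, teachers))

def screen_pseudo_samples(pseudo_samples, student, teachers, weights, refine,
                          threshold=0.05, max_rounds=20):
    """Keep pseudo samples whose student and fused-teacher outputs agree (condition 1)."""
    augmentation_set = []
    for x in pseudo_samples:
        for _ in range(max_rounds):
            first = student(x)                                   # S2021
            second = fused_teacher_output(x, teachers, weights)  # S2022-S2023
            diff = np.abs(first - second).sum()                  # assumed metric: L1 distance
            if diff < threshold:                                 # S2024
                augmentation_set.append(x)
                break
            x = refine(x, first, second)                         # S2025 (hypothetical update)
    return augmentation_set
```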
The preset condition includes at least one of the following:
Condition 1): the difference between the first comparison result and the second comparison result of the pseudo sample is smaller than a preset threshold.
That is, this condition requires the difference between the prediction of the teacher models (i.e., the second sub-models) and the prediction of the global student model to be small enough that the generated pseudo sample performs consistently on both, which facilitates distilling the information of the second sub-models or the global student model into the cloud global model through the augmentation dataset.
The preset threshold may be set based on empirical values, which is not limited by the embodiments of the present disclosure.
Condition 2): the difference between the first comparison result and the second comparison result of the pseudo sample, after decreasing continuously, has stopped decreasing.
For example, as shown in Fig. 3, pseudo sample A is input into the global student model to obtain a first comparison result B11, and into each of the second sub-models (i.e., teacher models) to obtain multiple processing results, which are aggregated into a second comparison result B21. The difference between B11 and B21 is computed as C1. Pseudo sample A is updated based on C1 to obtain pseudo sample A', which is processed in the same way to obtain a first comparison result B12 and a second comparison result B22, whose difference is C2. If C2 is smaller than C1, pseudo sample A' is updated further to obtain pseudo sample A'', which likewise yields a first comparison result B13, a second comparison result B23 and their difference C3. If C3 is smaller than C2, the above operations can be repeated to keep updating pseudo sample A''. If C3 is not smaller than C2, the trend of decreasing difference has ended, and further iterative updates may no longer improve the pseudo sample's consistency between the teacher models and the student model. The pseudo sample with the smallest difference (A' or A'') can therefore be added to the augmentation dataset as augmented data. The method for updating pseudo samples is described later.
This condition likewise keeps the difference between the prediction of the teacher models (i.e., the second sub-models) and the prediction of the global student model small enough that the generated pseudo sample performs consistently on both, so that the information of the second sub-models and/or the global student model can be distilled into the cloud global model through the augmentation dataset.
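Condition 2) amounts to early stopping on the difference sequence C1, C2, C3, and so on. The sketch below assumes hypothetical callbacks `diff_fn` (returning the difference between the first and second comparison results for a pseudo sample) and `refine` (the device cluster's update step, which the text describes separately); it returns the pseudo sample with the smallest difference observed before the difference stopped decreasing.

```python
def refine_until_plateau(x, diff_fn, refine, max_rounds=50):
    """Update the pseudo sample while the difference keeps decreasing (condition 2).

    Returns the pseudo sample with the smallest difference and the full difference history.
    """
    best_x, best_diff = x, diff_fn(x)
    diffs = [best_diff]
    for _ in range(max_rounds):
        x = refine(x)
        d = diff_fn(x)
        diffs.append(d)
        if d >= best_diff:  # the decreasing trend has ended
            break
        best_x, best_diff = x, d
    return best_x, diffs
```

For instance, with `diff_fn = lambda x: abs(x - 2.0)` and `refine = lambda x: x + 0.75` starting from 0.0, the differences are 2.0, 1.25, 0.5, 0.25 and then 1.0, so the loop stops and returns 2.25.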
A generative model is a type of model capable of generating synthetic data. For example, a model trained to generate faces can produce, each time, a face that neither the model nor anyone else has seen before. The best-known generative model is the GAN, whose generator and discriminator compete against each other to produce augmented data. Because of this adversarial setup, a GAN is very difficult to train and an optimal balance is hard to reach. A diffusion model avoids this problem. Thus, in embodiments of the present disclosure, the augmentation dataset may be generated by each device cluster based on a diffusion model. A diffusion model includes a forward process and a reverse process. In implementation, the global student model may be used to construct the reverse process of the diffusion model. Because the global student model has already learned the data distributions of the edge devices, constructing the reverse process with it brings the data distribution of the generated augmented samples closer to that of the edge devices. The cloud global model can therefore learn the data distribution on the edge devices more accurately, which improves its prediction performance.
In summary, the embodiments of the present disclosure address the heterogeneity of edge devices, focusing on the variability of their resources: devices differ markedly in CPU, GPU, memory, storage space and network bandwidth. The challenge posed by this heterogeneity is that not every edge device can support deploying the cloud global model. Therefore, for training or inference, the model must be compressed by knowledge distillation to a size suited to each device's resource characteristics. Furthermore, data augmentation is used to expand the server dataset, and the cloud global model is updated dynamically to improve its accuracy.
Based on the same technical concept, the embodiments of the present disclosure further provide a federated learning method applied to an edge device, which, as shown in Fig. 4, includes:
S401: send the performance parameters of the current edge device to a server, so that the server performs cluster analysis on the multiple edge devices based on their performance parameters to obtain multiple device cluster groups, and performs knowledge distillation on the cloud global model to obtain a first sub-model applicable to the target device cluster group, where the target device cluster group is the device cluster group containing the current edge device.
The performance parameters of an edge device describe its data processing capability. For example, they may include computing resources (e.g., CPU, GPU, memory), storage resources (e.g., disk space) and network resources (e.g., bandwidth).
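The text does not name a clustering algorithm for S401. As one possible choice, the sketch below groups devices with plain k-means on standardized performance parameters; the feature columns (CPU cores, GPU memory, RAM, bandwidth) and the deterministic farthest-point initialization are assumptions made for illustration.

```python
import numpy as np

def cluster_devices(perf, k, iters=100):
    """Group edge devices by performance parameters with plain k-means (assumed algorithm).

    perf: one row per device, e.g. [cpu_cores, gpu_mem_gb, ram_gb, bandwidth_mbps]
    """
    X = np.asarray(perf, dtype=float)
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-12)  # put all resources on one scale
    # Deterministic farthest-point initialization: start from device 0.
    centers = [X[0]]
    for _ in range(1, k):
        d = np.min([((X - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(X[np.argmax(d)])
    centers = np.array(centers)
    for _ in range(iters):
        dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = np.argmin(dists, axis=1)  # assign each device to the nearest center
        new_centers = np.array([X[labels == j].mean(axis=0) if np.any(labels == j)
                                else centers[j] for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels
```

Standardizing first matters because raw bandwidth figures would otherwise dominate the distance computation.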
S402: perform federated learning based on the first sub-model.
Because edge devices differ in performance, the embodiments of the present disclosure use cluster analysis to place heterogeneous edge devices with similar performance into the same device cluster group, all of whose members can run the same first sub-model. Hence, for a large cloud global model that cannot run directly on edge devices, first sub-models suited to the different device cluster groups can be distilled. The large cloud global model is made lightweight through knowledge distillation before being distributed to the edge devices, and because the lightweight model has fewer parameters than the large cloud model, the communication burden on the system is reduced. In addition, adopting federated learning protects the security of private information. In summary, the present disclosure provides a versatile and efficient federated learning scheme for heterogeneous edge devices.
In some embodiments, performing federated learning based on the first sub-model may be implemented as follows: train the first sub-model on the local dataset of the current edge device to obtain a sub-model to be aggregated, and generate the second sub-model of the first sub-model based on the target device cluster group and the sub-model to be aggregated.
The trained second sub-model may be used to update the cloud global model, as described above; details are not repeated here.
Note that the edge devices in a device cluster complete the training of the first sub-model on their local datasets, so the cloud global model can be optimized through federated learning.
For example, the first sub-models $\{w_1, w_2, \dots, w_t\}$ of all device clusters are issued to the respective device clusters. For the i-th device cluster, its set of edge devices can be written as $S_i = \{x_1, x_2, \dots, x_N\}$, where N is the total number of edge devices in the i-th device cluster, and its dataset can be written as $D_i = \{d_1, d_2, \dots, d_N\}$. Let $d_n$ denote the local dataset size of a single edge device; from these, $D_t$ denotes the total local data of a single device cluster, and D the total local data of all device clusters.
The edge devices in a device cluster train the model on their local datasets. Assume that a collaborative-knowledge term represents the knowledge contributed by device cluster n, and that an intra-cluster common dataset $D_i$ is accessible to all edge devices in device cluster n; the data samples in this common dataset are training data obtained by the server sampling the device cluster. The loss of each device cluster can then be expressed as shown in formula (7):
In formula (7), λ > 0 is a hyperparameter, and $L_{KL}$ denotes the Kullback-Leibler (KL) divergence, used mainly to transfer personalized knowledge. $L_n$ denotes the loss on the local dataset $d_n$ of a single edge device within the device cluster and can be chosen according to business requirements. $c_{mn}$ is an entry of the knowledge coefficient matrix and estimates the contribution of device cluster m to device cluster n. $w_m$ denotes the model parameters of the first sub-model of device cluster m, and $w_n$ those of the first sub-model of device cluster n. In implementation, collaborative knowledge may be exchanged between device clusters so that each device cluster can compute its loss according to formula (7).
The knowledge coefficient matrix is defined as:
The objective of edge training is shown in formula (8):
In formula (8), the weight vector denotes the concatenation of the weights of the first sub-models. For example, suppose there are two device clusters whose first sub-models are w1 and w2, with model parameters w1 = [a, b] and w2 = [c, d, e]; the concatenated weight vector of the two first sub-models is then [a, b, c, d, e].
In formula (8), $d_n$ is the dimension of model parameters $w_n$; in the example above, the dimension of first sub-model w1 is 2 and that of first sub-model w2 is 3. ρ > 0 is a regularization parameter. D denotes the dataset of all device clusters, and $D_t$ the dataset of the t-th device cluster.
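Formula (7) is not reproduced in this text, so the sketch below only assumes a form consistent with the symbol definitions above: the local loss $L_n$ plus λ times a $c_{mn}$-weighted sum of KL divergences between the outputs of the other clusters' sub-models and cluster n's sub-model on the common dataset. The function names and the direction of the KL divergence are assumptions, not the disclosure's definition.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """Mean over samples of KL(p || q) for row-wise probability arrays."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.mean(np.sum(p * np.log(p / q), axis=1)))

def cluster_loss(local_loss, probs, C, n, lam=0.5):
    """Assumed form of formula (7) for device cluster n.

    local_loss: L_n computed on the cluster's local data
    probs:      probs[m] = outputs of cluster m's sub-model on the common dataset
    C:          knowledge coefficient matrix, C[m][n] = contribution of cluster m to n
    lam:        the hyperparameter lambda > 0
    """
    reg = sum(C[m][n] * kl_divergence(probs[m], probs[n])
              for m in range(len(probs)) if m != n)
    return local_loss + lam * reg
```

When every other cluster's sub-model agrees with cluster n on the common data, the KL terms vanish and the loss reduces to the local loss alone.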
In some embodiments, to learn the local dataset more thoroughly, the first sub-model may be trained iteratively a preset number of times on the local dataset of the current edge device. Requiring the edge device to iterate over its local dataset multiple times lets it learn the local data fully and transfer that knowledge into the sub-model to be aggregated, which in turn helps optimize the cloud global model.
A device cluster contains multiple edge devices, each of which trains its own sub-model to be aggregated. These sub-models can be aggregated within the device cluster group and then sent to the server, or the aggregation can be done by the server.
For aggregation within a device cluster, the devices in the cluster may be divided into a master device and slave devices, and an edge device with higher performance may be chosen as the master. When the current edge device is the master device of the target device cluster, it obtains the sub-models to be aggregated trained by the edge devices in the target device cluster and aggregates them to obtain the second sub-model of the first sub-model.
In the embodiments of the present disclosure, because the device cluster group aggregates its edge devices' sub-models itself, network bandwidth is saved and the amount of data transmitted within the system is reduced.
Illustratively, in some embodiments, aggregating the multiple sub-models to be aggregated into the second sub-model of the first sub-model may be implemented by any one of the following three methods:
Aggregation method 1): compute a weighted average of the model parameters of the multiple sub-models to be aggregated to obtain the second sub-model of the first sub-model.
The weights may be determined by the accuracy of each sub-model to be aggregated on the validation set.
Aggregation method 2): compute the mean of the model parameters of the multiple sub-models to be aggregated to obtain the second sub-model of the first sub-model.
In implementation, when the local datasets of the edge devices in the device cluster have large distribution deviations, the weighted average may be chosen: the model parameters of the sub-models to be aggregated are summed with weights and then divided by the sum of the weights. This balances the differences and contributions among edge devices.
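Aggregation methods 1) and 2) reduce to a (weighted) average of parameter vectors. A minimal sketch, assuming each sub-model to be aggregated has been flattened into one parameter vector of the same length; passing validation accuracies as `weights` gives method 1), and omitting them gives method 2). The function name is hypothetical.

```python
import numpy as np

def aggregate_weighted(param_list, weights=None):
    """Method 1) when weights are given, method 2) (plain mean) when weights is None.

    param_list: one flat parameter vector per sub-model to be aggregated
    """
    P = np.stack([np.asarray(p, dtype=float) for p in param_list])  # (n_models, n_params)
    if weights is None:
        return P.mean(axis=0)
    w = np.asarray(weights, dtype=float)
    return (w[:, None] * P).sum(axis=0) / w.sum()  # weighted sum divided by sum of weights
```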
The weighted-average and mean aggregation methods adopted in the embodiments of the present disclosure are simple, practical and applicable to all device clusters.
Aggregation method 3): when the validation set contains multiple categories, for each category determine the sub-model to be aggregated with the highest prediction accuracy on that category as an intermediate sub-model, obtaining multiple intermediate sub-models; then average the model parameters of the intermediate sub-models to obtain the second sub-model.
Take image classification as an example. The sub-models to be aggregated trained by the edge devices all predict the same category; votes are cast according to their predictions for that category, and the model parameters of the sub-model whose predictions for that category are most accurate are selected. Model parameters are selected for each category in this way, and the model parameters selected for all categories are averaged to obtain the model parameters of the second sub-model.
In the embodiments of the present disclosure, each edge device's sub-model casts votes according to its processing results on the input samples, and the aggregation result is the candidate receiving the most votes.
Aggregation method 3) suits cases where the processing results of the edge devices differ greatly; in such cases the final aggregation result is preferably determined by voting, so that the aggregated second sub-model represents the knowledge learned by the device cluster.
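Aggregation method 3) can be sketched as follows, assuming each sub-model's validation predictions are available as integer class labels. Per-class accuracy picks one intermediate sub-model per category; following the text, the parameters picked for all categories are averaged, so a sub-model that wins several categories is counted once per category (an interpretation, since the text does not address repeats). Ties go to the lower-indexed sub-model. The function names are hypothetical.

```python
import numpy as np

def per_class_accuracy(preds, labels, n_classes):
    """Accuracy of one sub-model restricted to each true class of the validation set."""
    acc = np.zeros(n_classes)
    for c in range(n_classes):
        mask = labels == c
        acc[c] = np.mean(preds[mask] == c) if mask.any() else 0.0
    return acc

def aggregate_per_class_best(param_list, preds_list, labels, n_classes):
    """Pick the best sub-model for each class (intermediate sub-models), then average them."""
    acc = np.stack([per_class_accuracy(p, labels, n_classes) for p in preds_list])
    best = acc.argmax(axis=0)  # index of the most accurate sub-model for each class
    picked = np.stack([np.asarray(param_list[i], dtype=float) for i in best])
    return picked.mean(axis=0), best
```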
In summary, within each device cluster, model aggregation is performed after the edge devices finish training, as shown in formula (9):
In formula (9), the terms denote the aggregated model obtained in round t of training within device cluster i and the model to be aggregated trained by the n-th edge device in the device cluster; η is the learning rate, and the remaining term is the average gradient of edge device n in round t. Aggregation yields the second sub-model corresponding to the device cluster, which is then sent to the server.
In some embodiments, the second sub-model is sent to the server; so that under the guidance of the global student model and/or the second sub-model of each device cluster, an amplification sample is generated based on the distribution characteristics of the local data set of the current edge device, so that the server obtains the amplification data set of the target device cluster. The global student model is obtained by distilling knowledge of a second sub-model of each equipment cluster by the server. This embodiment enables the guidance of generating high quality amplification data sets.
In the embodiments of the present disclosure, the ways of generating the augmentation data set under the guidance of the global student model and/or the second sub-models were described above and are not repeated here. This embodiment mainly describes how a diffusion model is used to generate the augmentation data set.
Specifically, the samples in the server's public data set are limited and do not include the data held within the edge devices, so the data set needs to be enhanced by a diffusion model. In the embodiments of the present disclosure, the global student model $w_{student}$ obtained in the previous step guides data enhancement with the diffusion model. Generating the augmentation data set involves two phases, a forward process and a reverse process, described below using images as an example.
Forward process: each step t of the forward process depends only on step t-1, so the process can be regarded as a Markov process. The purpose of diffusion is to map the original image $x_0$, through this Markov process, gradually to a multidimensional normal distribution $x_t$ (i.e., Gaussian noise). The random process of each step is $q(x_t \mid x_{t-1})$, and the noise-adding step is defined as shown in equation (10):
$$x_t = \alpha_t x_{t-1} + \beta_t \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, I) \tag{10}$$
In equation (10), $\alpha_t$ and $\beta_t$ satisfy $\alpha_t^2 + \beta_t^2 = 1$, and $\beta_t$ increases with t. The forward process as a whole can therefore be defined as shown in equation (11):
$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}), \quad q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \alpha_t x_{t-1},\ \beta_t^2 I\right) \tag{11}$$
Because a sum of independent Gaussian variables is Gaussian, expression (12) follows:
$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \bar\alpha_t x_0,\ (1 - \bar\alpha_t^2) I\right) \tag{12}$$
In expressions (10) to (12), I is the identity matrix and $\bar\alpha_t = \alpha_1 \alpha_2 \cdots \alpha_t$. Each $\alpha_t$ is less than 1 and decreases as t grows, so $\bar\alpha_t$ approaches 0. When t approaches infinity, expression (12) therefore becomes expression (13):
$$q(x_t \mid x_0) = \mathcal{N}(x_t;\ 0, I) \tag{13}$$
That is, $x_t$ then follows a standard normal distribution.
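The forward process can be checked numerically, with $\bar\alpha_t = \alpha_1 \cdots \alpha_t$. The sketch below is illustrative only. It assumes a linear schedule for $\beta_t^2$ running from 1e-4 to 0.02 over 1000 steps, which is a common DDPM choice and not a value given in this disclosure, and it derives $\alpha_t = \sqrt{1-\beta_t^2}$. It then confirms that applying equation (10) step by step gives the same variance $1-\bar\alpha_t^2$ as the closed form, and that $\bar\alpha_T$ ends up close to 0. The helper names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def make_alphas(T: int, beta_sq_start: float = 1e-4,
                beta_sq_end: float = 0.02) -> List[float]:
    """alpha_t = sqrt(1 - beta_t^2), beta_t^2 rising linearly (assumed schedule, T >= 2)."""
    beta_sq = [beta_sq_start + (beta_sq_end - beta_sq_start) * i / (T - 1)
               for i in range(T)]
    return [math.sqrt(1.0 - b) for b in beta_sq]


def alpha_bar(alphas: Sequence[float], t: int) -> float:
    """bar_alpha_t = alpha_1 * ... * alpha_t, with bar_alpha_0 = 1."""
    prod = 1.0
    for a in alphas[:t]:
        prod *= a
    return prod


def variance_by_iteration(alphas: Sequence[float], t: int) -> float:
    """Variance of x_t given x_0 from applying equation (10) t times."""
    var = 0.0
    for a in alphas[:t]:
        var = a * a * var + (1.0 - a * a)  # alpha_t^2 * Var + beta_t^2
    return var


def sample_x_t(x0: Sequence[float], alphas: Sequence[float], t: int,
               rng: random.Random) -> List[float]:
    """Draw x_t ~ N(bar_alpha_t * x0, (1 - bar_alpha_t^2) I), i.e. expression (12)."""
    ab = alpha_bar(alphas, t)
    sd = math.sqrt(1.0 - ab * ab)
    return [ab * x + sd * rng.gauss(0.0, 1.0) for x in x0]


alphas = make_alphas(1000)
for t in (1, 10, 500, 1000):
    closed_form = 1.0 - alpha_bar(alphas, t) ** 2
    assert abs(variance_by_iteration(alphas, t) - closed_form) < 1e-9
assert alpha_bar(alphas, 1000) < 0.01  # x_T is therefore close to N(0, I)
x_T = sample_x_t([0.5, -1.0], alphas, 1000, random.Random(0))
```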
Reverse process: going from $x_t$ back to $x_0$ is the inverse of diffusion, in which the image is gradually transformed from Gaussian noise into a normal image. The random process of each step is $q(x_{t-1} \mid x_t)$. Because this process is unknown, the diffusion model defines a learnable $p_\theta(x_{t-1} \mid x_t)$ to approximate it. The global student model $w_{student}$ obtained in the previous step serves as the optimization parameter θ. Training makes this process approximate the true reverse process $q(x_{t-1} \mid x_t)$ as closely as possible, so that a desired normal image can be generated from Gaussian noise. When $x_0$ is known, Bayes' formula gives expression (14):
$$q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \tag{14}$$
Expression (15) gives the mean $\tilde\mu_t$ and variance $\tilde\sigma_t^2$ of the posterior distribution, obtained from the Gaussian probability density and the result above by completing the square. Here $\bar\alpha_t = \alpha_1 \alpha_2 \cdots \alpha_t$ and $\bar\alpha_0 = 1$:
$$\tilde\mu_t(x_t, x_0) = \frac{\alpha_t (1 - \bar\alpha_{t-1}^2)}{1 - \bar\alpha_t^2}\, x_t + \frac{\bar\alpha_{t-1} \beta_t^2}{1 - \bar\alpha_t^2}\, x_0, \qquad \tilde\sigma_t^2 = \frac{(1 - \bar\alpha_{t-1}^2)\, \beta_t^2}{1 - \bar\alpha_t^2} \tag{15}$$
The parameters in expression (15) have the meanings given above.
Further, applying the reparameterization trick to the random process $q(x_t \mid x_0)$ gives expression (16), which relates $x_t$ to $x_0$:
$$x_t = \bar\alpha_t x_0 + \sqrt{1 - \bar\alpha_t^2}\, \epsilon_t, \quad \text{i.e.} \quad x_0 = \frac{1}{\bar\alpha_t}\left(x_t - \sqrt{1 - \bar\alpha_t^2}\, \epsilon_t\right) \tag{16}$$
Substituting expression (16) into $\tilde\mu_t$ gives expression (17):
$$\tilde\mu_t = \frac{1}{\alpha_t}\left(x_t - \frac{\beta_t^2}{\sqrt{1 - \bar\alpha_t^2}}\, \epsilon_t\right) \tag{17}$$
Here the random noise $\epsilon_t$ follows a standard normal distribution. Each step of the reverse process must accurately predict the specific noise value at that step, so the global student model $w_{student}$ is used to predict this noise. The reverse process is set as shown in expression (18):
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \frac{1}{\alpha_t}\left(x_t - \frac{\beta_t^2}{\sqrt{1 - \bar\alpha_t^2}}\, \epsilon_\theta(x_t, t)\right),\ \tilde\sigma_t^2 I\right) \tag{18}$$
In expression (18), $\epsilon_\theta$ denotes the noise predicted by the global student model $w_{student}$, and θ denotes the model parameters of the global student model. The global student model takes $x_t$ and t as input and predicts the noise at that step. Substituting this prediction for $\epsilon_t$ yields $x_{t-1}$.
Through training, the global student model learns to carry out the reverse process, yielding the reverse model.
In the reverse process, the output at each step is fitted to the noise sampled at the corresponding step of the forward process, as expressed by objective (19):
$$\mathcal{L} = \mathbb{E}_{x_0, t, \epsilon_t}\left[\left\| \epsilon_t - \epsilon_\theta(x_t, t) \right\|^2\right] \tag{19}$$
In expression (19), $\epsilon_\theta$ denotes the noise predicted by the reverse model, and $\epsilon_t$ denotes the noise generated by the forward process at step t.
The samples obtained through the forward and reverse processes described above are used as new data samples, i.e., augmented samples. The augmented data set generated by the reverse model therefore consists of n such samples, where n is the sample size.
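The reverse step can be sketched as below. It uses the mean $\frac{1}{\alpha_t}\left(x_t - \frac{\beta_t^2}{\sqrt{1-\bar\alpha_t^2}}\,\epsilon_\theta\right)$ and the posterior variance $\frac{\beta_t^2(1-\bar\alpha_{t-1}^2)}{1-\bar\alpha_t^2}$, with $\bar\alpha_t = \alpha_1 \cdots \alpha_t$. `predict_noise` is a hypothetical stand-in for the global student model $\epsilon_\theta(x_t, t)$. Here it is just a Python callable, and the helper names and toy schedule are also assumptions. Passing `rng=None` returns the mean without added noise, which makes the step easy to check. At t = 1, supplying the true noise recovers $x_0$ exactly.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

NoiseModel = Callable[[List[float], int], List[float]]


def alpha_bar(alphas: Sequence[float], t: int) -> float:
    """bar_alpha_t = alpha_1 * ... * alpha_t, with bar_alpha_0 = 1."""
    prod = 1.0
    for a in alphas[:t]:
        prod *= a
    return prod


def reverse_step(x_t: List[float], t: int, alphas: Sequence[float],
                 predict_noise: NoiseModel,
                 rng: Optional[random.Random] = None) -> List[float]:
    """Sample x_{t-1} from p_theta(x_{t-1} | x_t); t is 1-indexed."""
    a_t = alphas[t - 1]
    beta_sq = 1.0 - a_t * a_t
    ab_t, ab_prev = alpha_bar(alphas, t), alpha_bar(alphas, t - 1)
    eps = predict_noise(x_t, t)
    coef = beta_sq / math.sqrt(1.0 - ab_t * ab_t)
    # Mean: (x_t - beta_t^2 / sqrt(1 - bar_alpha_t^2) * eps) / alpha_t
    mean = [(x - coef * e) / a_t for x, e in zip(x_t, eps)]
    if rng is None:
        return mean
    # Posterior variance: beta_t^2 (1 - bar_alpha_{t-1}^2) / (1 - bar_alpha_t^2)
    sd = math.sqrt(beta_sq * (1.0 - ab_prev * ab_prev) / (1.0 - ab_t * ab_t))
    return [m + sd * rng.gauss(0.0, 1.0) for m in mean]


def generate(dim: int, alphas: Sequence[float], predict_noise: NoiseModel,
             rng: random.Random) -> List[float]:
    """Start from Gaussian noise and apply reverse_step from t = T down to 1."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for t in range(len(alphas), 0, -1):
        x = reverse_step(x, t, alphas, predict_noise, rng)
    return x


def zero_model(x_t: List[float], t: int) -> List[float]:
    return [0.0] * len(x_t)  # untrained placeholder for the student model


toy_alphas = [0.99, 0.98, 0.97]
sample = generate(2, toy_alphas, zero_model, random.Random(0))
```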
In brief, the process by which the diffusion model generates augmented samples is shown in FIG. 5. The edge device uses its local data set to learn that data set's distribution in the forward process. In the reverse process, it learns how to generate new data from the learned distribution. For raw data $x_0$ on an edge device, the forward process superimposes noise $\epsilon_t$ onto $x_0$ to obtain Gaussian noise. The reverse process then progressively removes the predicted noise $\epsilon_\theta$ to generate new data, thereby achieving data enhancement.
To improve the quality of the generated data, the embodiments of the present disclosure introduce the global student model and/or the second sub-model of each device cluster to guide data enhancement.
In some embodiments, augmented samples are generated based on the distribution characteristics of the local data set of the current edge device, under the guidance of the global student model and the second sub-model of each device cluster. This may be implemented as follows. First, a pseudo sample is generated based on the diffusion model. The diffusion model has learned the data distribution characteristics of the local data set of the current edge device, and its reverse process is constructed based on the global student model. The pseudo sample is then updated based on the diffusion model, under the guidance of the global student model and the second sub-model of each device cluster, to obtain an augmented sample.
Specific guiding methods are described in the foregoing, and are not repeated here.
According to the embodiments of the present disclosure, augmented samples can be generated based on the second sub-models and the global student model. The augmented sample data is therefore well represented by both the global student model and the second sub-models, and the information of the global student model is effectively distilled into the server's cloud global model.
Each edge device may separately construct the diffusion model based on the global student model. The construction can be summarized as follows:
Step D1: based on the forward process model of the diffusion model, determine for a training sample the Gaussian-noised sample at time step t and the forward superimposed noise corresponding to time step t, namely $x_t = \bar\alpha_t x_0 + \sqrt{1 - \bar\alpha_t^2}\, \epsilon_t$ with $\bar\alpha_t = \alpha_1 \alpha_2 \cdots \alpha_t$.
The training samples are samples from the local data set of the edge device, so that the forward process learns the data distribution of that local data set.
Step D2: input the noised training sample and the time step t into the global student model to obtain the predicted noise $\epsilon_\theta$ of the global student model.
Step D3: determine the loss between the predicted noise and the forward superimposed noise corresponding to time step t (expression (19)). With minimizing this loss as the training objective, adjust the parameters of the global student model to obtain the reverse process model of the diffusion model.
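Steps D1 to D3 amount to the usual noise-prediction objective. The sketch below computes that loss for one local training sample, using $x_t = \bar\alpha_t x_0 + \sqrt{1-\bar\alpha_t^2}\,\epsilon_t$ with $\bar\alpha_t = \alpha_1 \cdots \alpha_t$. `predict_noise` is again a hypothetical placeholder for the global student model. The parameter update of step D3 is left to whatever training framework hosts that model, so this is only a sketch of the loss under those assumptions.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

NoiseModel = Callable[[List[float], int], List[float]]


def alpha_bar(alphas: Sequence[float], t: int) -> float:
    """bar_alpha_t = alpha_1 * ... * alpha_t."""
    prod = 1.0
    for a in alphas[:t]:
        prod *= a
    return prod


def forward_noise(x0: Sequence[float], alphas: Sequence[float], t: int,
                  rng: random.Random) -> Tuple[List[float], List[float]]:
    """Step D1: return the noised sample x_t and the superimposed noise eps_t."""
    ab = alpha_bar(alphas, t)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [ab * x + math.sqrt(1.0 - ab * ab) * e for x, e in zip(x0, eps)]
    return x_t, eps


def noise_prediction_loss(x0: Sequence[float], alphas: Sequence[float],
                          predict_noise: NoiseModel, rng: random.Random) -> float:
    """Steps D1-D3 for one training sample: mean squared noise-prediction error."""
    t = rng.randint(1, len(alphas))               # random time step t
    x_t, eps = forward_noise(x0, alphas, t, rng)  # step D1
    pred = predict_noise(x_t, t)                  # step D2: student prediction
    return sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(eps)  # step D3 loss


def zero_model(x_t: List[float], t: int) -> List[float]:
    return [0.0] * len(x_t)  # untrained placeholder


alphas = [0.99, 0.98, 0.97]
x0 = [0.5, -1.0]
loss = noise_prediction_loss(x0, alphas, zero_model, random.Random(0))
```

A model that predicted the superimposed noise perfectly would drive this loss to zero. Minimizing it over the local data set is what turns the global student model into the reverse process model.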
In summary, the server's public data set has limited samples and lacks the data samples held by the edge devices. The embodiments of the present disclosure therefore introduce a diffusion-model strategy, which can generate new sample data highly similar to the edge devices' data. The end of the reverse process of the diffusion model is determined by evaluating how far apart the soft labels are that the global student model and the second sub-models of the device clusters assign to the generated pseudo samples. This stage involves three steps:
Step E1, generating pseudo samples: a set of pseudo samples is first generated by the diffusion model. The diffusion model is usually obtained by learning the data distribution of the edge device, so the generated pseudo samples have distribution characteristics similar to the edge device's data.
Step E2, evaluating the generated data with the global student model and each second sub-model: the generated pseudo samples are input into the global student model and into the second sub-model of each device cluster (collectively referred to as the teacher models). Each of these models produces a soft label, i.e., a predicted probability distribution.
Step E3, calculating the difference between the soft labels output by the global student model and by the teacher models. This difference measures how the generated sample data behaves in the two kinds of models, and the end of the reverse process is determined by tracking how it changes. A decreasing difference means the quality of the generated data is still improving, so the reverse process may continue. A difference that increases or stays the same means the quality is no longer improving, so the reverse process ends.
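Steps E2 and E3 can be sketched as follows. The disclosure does not fix the difference measure or how the teacher soft labels are combined. This sketch assumes that the teacher soft labels are averaged and uses the KL divergence from that average to the student's soft label. The stopping rule follows step E3 directly: end once the difference no longer decreases. All names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn logits into a soft label (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for probability vectors; eps guards against log(0)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def soft_label_gap(student_logits: Sequence[float],
                   teacher_logits: Sequence[Sequence[float]]) -> float:
    """Step E3 difference: KL from the averaged teacher soft label to the student's."""
    student = softmax(student_logits)
    teachers = [softmax(z) for z in teacher_logits]
    fused = [sum(col) / len(teachers) for col in zip(*teachers)]
    return kl_divergence(fused, student)


def reverse_process_should_end(gaps: Sequence[float]) -> bool:
    """End once the latest difference is no smaller than the previous one."""
    return len(gaps) >= 2 and gaps[-1] >= gaps[-2]
```

In use, the gap would be recorded after each reverse step on the current pseudo samples. The loop stops when `reverse_process_should_end` returns True.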
In summary, a training framework provided by an embodiment of the present disclosure is shown in FIG. 6. The central server performs cluster analysis on the heterogeneous edge devices and assigns edge devices with similar performance to the same device cluster. For each device cluster, the central server distills from the cloud global model a small first sub-model adapted to that cluster. It then issues each first sub-model to its device cluster, where the edge devices train it on their local private data. Each edge device takes its trained first sub-model as a model to be aggregated. The models to be aggregated in the same device cluster are aggregated into a second sub-model, which is sent to the central server. The central server performs knowledge distillation on the plurality of second sub-models, compressing them into a global student model. Each edge device generates an augmentation data set under the guidance of the second sub-models and the global student model. Finally, the central server optimizes the cloud global model by training on the augmentation data sets.
Based on the same technical concept, as shown in FIG. 7, an embodiment of the present disclosure further provides a federal learning apparatus 700, including:
an obtaining module 701, configured to obtain performance parameters of a plurality of edge devices;
the clustering module 702 is configured to perform cluster analysis on a plurality of edge devices based on performance parameters of the plurality of edge devices, to obtain a plurality of device clusters;
the distillation module 703 is configured to perform knowledge distillation on the cloud global model for a plurality of device clusters to distill out a plurality of first sub-models;
and the learning module 704 is configured to distribute the plurality of first sub-models to corresponding device cluster groups for federal learning.
In some embodiments, the distillation module is specifically configured to:
for each of the plurality of device cluster groups, performing:
sampling, from the public data set, a training sample set adapted to the performance parameters of the device cluster; and
creating an initial model adapted to the performance parameters of the device cluster;
and taking the cloud global model as a teacher model, taking the initial model as a student model, and training the initial model based on a training sample set to distill out the first sub-model.
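Training the initial model (student) against the cloud global model (teacher) on the sampled training set is typically driven by a temperature-scaled distillation loss. The disclosure does not specify this loss. The sketch below assumes the widely used Hinton-style form, which combines two terms by weight: the KL divergence between the softened teacher and student outputs, scaled by T², and the cross-entropy on the hard label. The temperature and weight defaults are illustrative assumptions, and the names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-softened probabilities (numerically stable)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]


def distillation_loss(student_logits: Sequence[float],
                      teacher_logits: Sequence[float],
                      label: int, temperature: float = 2.0,
                      alpha: float = 0.5) -> float:
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, label)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps)
             for pt, ps in zip(p_teacher, p_student) if pt > 0)
    ce = -math.log(softmax(student_logits)[label])
    return alpha * temperature ** 2 * kl + (1.0 - alpha) * ce
```

When the student already matches the teacher, the KL term vanishes and only the hard-label term remains.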
In some embodiments, the clustering module is specifically configured to:
determining performance similarity between the edge devices based on performance parameters of the plurality of edge devices;
and performing cluster analysis based on the performance similarity among the edge devices to obtain a plurality of device clusters.
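The disclosure leaves the similarity measure and the clustering algorithm open. One simple possibility is sketched below under stated assumptions. Each performance parameter is min-max normalized; the three features used here (compute, memory, bandwidth) are hypothetical. Similarity is defined as 1 / (1 + Euclidean distance), and devices are grouped into the connected components of the graph that joins every pair whose similarity reaches a threshold.

```python
import math
from typing import Dict, List, Sequence


def normalize(features: Sequence[Sequence[float]]) -> List[List[float]]:
    """Min-max normalize each performance parameter to [0, 1]."""
    cols = list(zip(*features))
    lo = [min(c) for c in cols]
    hi = [max(c) for c in cols]
    return [[(v - l) / (h - l) if h > l else 0.0 for v, l, h in zip(row, lo, hi)]
            for row in features]


def similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Performance similarity in (0, 1]; identical devices score 1."""
    return 1.0 / (1.0 + math.dist(a, b))


def cluster_devices(features: Sequence[Sequence[float]],
                    threshold: float = 0.7) -> List[List[int]]:
    """Group devices linked by similarity >= threshold (union-find)."""
    norm = normalize(features)
    parent = list(range(len(norm)))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(len(norm)):
        for j in range(i + 1, len(norm)):
            if similarity(norm[i], norm[j]) >= threshold:
                parent[find(i)] = find(j)
    groups: Dict[int, List[int]] = {}
    for i in range(len(norm)):
        groups.setdefault(find(i), []).append(i)
    return sorted(groups.values())


# Hypothetical [cpu_ghz, memory_gb, bandwidth_mbps] per edge device.
devices = [[2.0, 4.0, 100.0], [2.1, 4.0, 110.0], [0.5, 1.0, 10.0], [0.6, 1.0, 12.0]]
clusters = cluster_devices(devices)  # [[0, 1], [2, 3]]
```

A k-means or hierarchical method would serve equally well. The threshold form is used here only because it needs no preset cluster count.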
In some embodiments, each device cluster is configured to train a corresponding first sub-model to obtain a second sub-model;
the apparatus further comprises a guiding module, configured to guide each device cluster, based on the second sub-model of each device cluster and/or the global student model, to generate an augmentation data set satisfying a preset data distribution characteristic. For each device cluster, the preset data distribution characteristic is the data distribution characteristic of the local data set within that device cluster. The global student model is obtained by knowledge distillation of the second sub-models of the device clusters;
the learning module is further configured to train the cloud global model based on the augmentation data set of each device cluster.
In some embodiments, the distillation module comprises:
the input sub-module, configured to input a distillation sample into the second sub-model of each device cluster respectively, to obtain the processing result of each second sub-model on the distillation sample;
the fusion sub-module, configured to fuse the processing results of the second sub-models on the distillation sample to obtain a soft target;
and the distillation sub-module, configured to train the global student model based on the soft target and the distillation sample, so as to minimize the loss between the soft target and the global student model's processing result on the distillation sample.
In some embodiments, the fusion sub-module is specifically configured to:
determining the weight of each second sub-model based on the accuracy of each second sub-model on the verification set; wherein, the higher the accuracy, the higher the corresponding weight;
and performing a weighted summation of the processing results of the second sub-models on the distillation sample to obtain a soft target.
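The fusion and distillation steps above can be sketched as follows. The description only requires that higher accuracy means higher weight, so this sketch assumes the weights are the verification-set accuracies normalized to sum to 1. It also assumes that the loss between the global student model's output and the soft target is the KL divergence. Processing results are taken to be probability vectors, and the names are hypothetical.

```python
import math
from typing import List, Sequence


def accuracy_weights(accuracies: Sequence[float]) -> List[float]:
    """Weights proportional to verification-set accuracy (sum to 1)."""
    total = sum(accuracies)
    return [a / total for a in accuracies]


def fuse_soft_target(teacher_probs: Sequence[Sequence[float]],
                     accuracies: Sequence[float]) -> List[float]:
    """Weighted sum of the second sub-models' probability outputs."""
    weights = accuracy_weights(accuracies)
    n_classes = len(teacher_probs[0])
    return [sum(w * p[k] for w, p in zip(weights, teacher_probs))
            for k in range(n_classes)]


def student_loss(student_logits: Sequence[float],
                 soft_target: Sequence[float]) -> float:
    """KL(soft_target || softmax(student_logits)), to be minimized."""
    m = max(student_logits)
    exps = [math.exp(z - m) for z in student_logits]
    s = sum(exps)
    probs = [e / s for e in exps]
    return sum(t * math.log(t / q) for t, q in zip(soft_target, probs) if t > 0)


# Two second sub-models with accuracies 0.9 and 0.6 on the verification set.
target = fuse_soft_target([[1.0, 0.0], [0.0, 1.0]], [0.9, 0.6])  # about [0.6, 0.4]
```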
In some embodiments, the guidance module comprises:
the acquisition sub-module, configured to acquire, for each device cluster, pseudo samples generated by the device cluster based on its local data distribution characteristics;
the optimizing sub-module, configured to perform the following operations on the pseudo samples in a loop:
processing the pseudo sample based on the global student model to obtain a first to-be-compared result of the pseudo sample; and
processing the pseudo sample based on each second sub-model to obtain a plurality of processing results of the pseudo sample;
fusing the plurality of processing results to obtain a second to-be-compared result of the pseudo sample;
adding the pseudo sample to the augmentation data set when the difference between the first and second to-be-compared results satisfies the preset condition;
and, when the difference between the first and second to-be-compared results does not satisfy the preset condition, requesting the device cluster to update the pseudo sample until the preset condition is satisfied.
In some embodiments, the preset conditions include at least one of:
the difference between the first and second to-be-compared results of the pseudo sample is smaller than a preset threshold;
the difference between the first and second to-be-compared results of the pseudo sample changes from continuously decreasing to no longer decreasing.
In some embodiments, the augmentation data set is generated by each device cluster based on a diffusion model, and the global student model is used to construct the reverse process of the diffusion model.
Based on the same technical concept, as shown in FIG. 8, an embodiment of the present disclosure further provides a federal learning apparatus 800, including:
a sending module 801, configured to send the performance parameters of the current edge device to a server. The server then performs cluster analysis on a plurality of edge devices based on their performance parameters to obtain a plurality of device clusters, and performs knowledge distillation on a cloud global model to obtain a first sub-model applicable to a target device cluster. The target device cluster is the device cluster in which the current edge device is located;
a training module 802, configured to perform federal learning based on the first sub-model.
In some embodiments, the training module comprises:
a model training sub-module, configured to train the first sub-model with the local data set of the current edge device to obtain a sub-model to be aggregated;
and an aggregation sub-module, configured to generate a second sub-model of the first sub-model based on the target device cluster and the sub-model to be aggregated.
In some embodiments, the model training submodule is specifically configured to iteratively train the first sub-model a preset number of times using a local data set of the current edge device.
In some embodiments, the aggregation sub-module is specifically configured to:
when the current edge device is the master device of the target device cluster, acquiring the plurality of sub-models to be aggregated trained by the edge devices in the target device cluster;
and aggregating the plurality of sub-models to be aggregated to obtain a second sub-model of the first sub-model.
In some embodiments, the aggregation sub-module is specifically configured to:
computing a weighted average of the model parameters of the plurality of sub-models to be aggregated to obtain a second sub-model of the first sub-model; or
computing the mean of the model parameters of the plurality of sub-models to be aggregated to obtain a second sub-model of the first sub-model.
In some embodiments, the aggregation sub-module is specifically configured to:
when the verification set includes a plurality of categories, determining, for each category, the sub-model to be aggregated with the highest prediction accuracy on that category as an intermediate sub-model, so as to obtain a plurality of intermediate sub-models;
and averaging the model parameters of the plurality of intermediate sub-models to obtain the second sub-model.
In some embodiments, the sending module is further configured to send the second sub-model to a server;
a generation module, configured to generate augmented samples based on the distribution characteristics of the local data set of the current edge device, under the guidance of the second sub-model of each device cluster and/or the global student model, so that the server obtains the augmentation data set of the target device cluster;
the global student model is obtained by the server through knowledge distillation of the second sub-models of the device clusters.
In some embodiments, the generating module comprises:
a sample generation sub-module, configured to generate a pseudo sample based on the diffusion model. The diffusion model has learned the data distribution characteristics of the local data set of the current edge device, and its reverse process is constructed based on the global student model;
and an updating sub-module, configured to update the pseudo sample based on the diffusion model, under the guidance of the global student model and the second sub-model of each device cluster, to obtain an augmented sample.
For descriptions of specific functions and examples of each module and sub-module of the apparatus in the embodiments of the present disclosure, reference may be made to the related descriptions of corresponding steps in the foregoing method embodiments, which are not repeated herein.
In the technical solution of the present disclosure, the acquisition, storage, and application of any user personal information involved comply with the relevant laws and regulations and do not violate public order and good morals.
According to embodiments of the present disclosure, the present disclosure also provides an electronic device, a readable storage medium and a computer program product.
Fig. 9 shows a schematic block diagram of an example electronic device 900 that may be used to implement embodiments of the present disclosure. Electronic devices are intended to represent various forms of digital computers, such as laptops, desktops, workstations, personal digital assistants, servers, blade servers, mainframes, and other appropriate computers. The electronic device may also represent various forms of mobile apparatuses, such as personal digital assistants, cellular telephones, smartphones, wearable devices, and other similar computing apparatuses. The components shown herein, their connections and relationships, and their functions, are meant to be exemplary only, and are not meant to limit implementations of the disclosure described and/or claimed herein.
As shown in fig. 9, the apparatus 900 includes a computing unit 901 that can perform various appropriate actions and processes according to a computer program stored in a Read Only Memory (ROM) 902 or a computer program loaded from a storage unit 908 into a Random Access Memory (RAM) 903. In the RAM 903, various programs and data required for the operation of the device 900 can also be stored. The computing unit 901, the ROM 902, and the RAM 903 are connected to each other by a bus 904. An input/output (I/O) interface 905 is also connected to the bus 904.
Various components in device 900 are connected to I/O interface 905, including: an input unit 906 such as a keyboard, a mouse, or the like; an output unit 907 such as various types of displays, speakers, and the like; a storage unit 908 such as a magnetic disk, an optical disk, or the like; and a communication unit 909 such as a network card, modem, wireless communication transceiver, or the like. The communication unit 909 allows the device 900 to exchange information/data with other devices through a computer network such as the internet and/or various telecommunications networks.
The computing unit 901 may be a variety of general and/or special purpose processing components having processing and computing capabilities. Some examples of computing unit 901 include, but are not limited to, a Central Processing Unit (CPU), a Graphics Processing Unit (GPU), various specialized Artificial Intelligence (AI) computing chips, various computing units running machine learning model algorithms, a Digital Signal Processor (DSP), and any suitable processor, controller, microcontroller, etc. The computing unit 901 performs the respective methods and processes described above, such as the federal learning method. For example, in some embodiments, the federal learning method can be implemented as a computer software program tangibly embodied on a machine-readable medium, such as storage unit 908. In some embodiments, part or all of the computer program may be loaded and/or installed onto the device 900 via the ROM 902 and/or the communication unit 909. When the computer program is loaded into RAM 903 and executed by the computing unit 901, one or more steps of the federal learning method described above may be performed. Alternatively, in other embodiments, the computing unit 901 may be configured to perform the federal learning method by any other suitable means (e.g., by means of firmware).
Various implementations of the systems and techniques described above can be realized in digital electronic circuitry, integrated circuit systems, field-programmable gate arrays (FPGAs), application-specific integrated circuits (ASICs), application-specific standard products (ASSPs), systems on chip (SOCs), complex programmable logic devices (CPLDs), computer hardware, firmware, software, and/or combinations thereof. These implementations may include being implemented in one or more computer programs that can be executed and/or interpreted on a programmable system including at least one programmable processor. The programmable processor may be special-purpose or general-purpose, and can receive data and instructions from, and transmit data and instructions to, a storage system, at least one input device, and at least one output device.
Program code for carrying out the methods of the present disclosure may be written in any combination of one or more programming languages. This program code may be provided to a processor or controller of a general-purpose computer, special-purpose computer, or other programmable data processing apparatus, such that the program code, when executed by the processor or controller, causes the functions/operations specified in the flowcharts and/or block diagrams to be implemented. The program code may execute entirely on the machine, partly on the machine, as a stand-alone software package partly on the machine and partly on a remote machine, or entirely on the remote machine or server.
In the context of this disclosure, a machine-readable medium may be a tangible medium that can contain, or store a program for use by or in connection with an instruction execution system, apparatus, or device. The machine-readable medium may be a machine-readable signal medium or a machine-readable storage medium. The machine-readable medium may include, but is not limited to, an electronic, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any suitable combination of the foregoing. More specific examples of a machine-readable storage medium would include an electrical connection based on one or more wires, a portable computer diskette, a hard disk, a Random Access Memory (RAM), a read-only memory (ROM), an erasable programmable read-only memory (EPROM or flash memory), an optical fiber, a portable compact disc read-only memory (CD-ROM), an optical storage device, a magnetic storage device, or any suitable combination of the foregoing.
To provide for interaction with a user, the systems and techniques described here can be implemented on a computer having: a display device (e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor) for displaying information to a user; and a keyboard and pointing device (e.g., a mouse or trackball) by which a user can provide input to the computer. Other kinds of devices may also be used to provide for interaction with a user; for example, feedback provided to the user may be any form of sensory feedback (e.g., visual feedback, auditory feedback, or tactile feedback); and input from the user may be received in any form, including acoustic input, speech input, or tactile input.
The systems and techniques described here can be implemented in a computing system that includes a back-end component (e.g., as a data server), a middleware component (e.g., an application server), or a front-end component (e.g., a user computer having a graphical user interface or a web browser through which a user can interact with an implementation of the systems and techniques described here), or any combination of such back-end, middleware, or front-end components. The components of the system can be interconnected by any form or medium of digital data communication (e.g., a communication network). Examples of communication networks include local area networks (LANs), wide area networks (WANs), and the internet.
The computer system may include a client and a server. The client and server are typically remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other. The server may be a cloud server, a server of a distributed system, or a server incorporating a blockchain.
It should be understood that steps may be reordered, added, or deleted using the various forms of flows shown above. For example, the steps recited in the present disclosure may be performed in parallel, sequentially, or in a different order, as long as the desired results of the disclosed technical solutions can be achieved; no limitation is imposed herein.
The above detailed description should not be taken as limiting the scope of the present disclosure. It will be apparent to those skilled in the art that various modifications, combinations, sub-combinations and alternatives are possible, depending on design requirements and other factors. Any modifications, equivalent substitutions, improvements, etc. that are within the principles of the present disclosure are intended to be included within the scope of the present disclosure.

Claims (37)

1. A federal learning method, comprising:
acquiring performance parameters of a plurality of edge devices;
performing cluster analysis on the plurality of edge devices based on the performance parameters of the plurality of edge devices to obtain a plurality of device clusters;
performing knowledge distillation on the cloud global model for the plurality of device clusters to distill out a plurality of first sub-models;
and distributing the plurality of first sub-models to the corresponding device clusters for federal learning.
2. The method of claim 1, wherein the performing knowledge distillation on the cloud global model for the plurality of device clusters comprises:
for each of the plurality of device cluster groups, performing:
sampling, from the public data set, a training sample set adapted to the performance parameters of the device cluster; and
creating an initial model adapted to the performance parameters of the device cluster;
and training the initial model based on the training sample set by taking the cloud global model as a teacher model and the initial model as a student model so as to distill out the first sub-model.
3. The method according to claim 1 or 2, wherein the performing cluster analysis on the plurality of edge devices based on the performance parameters of the plurality of edge devices to obtain a plurality of device clusters includes:
determining performance similarity between the edge devices based on the performance parameters of the plurality of edge devices;
and performing cluster analysis based on the performance similarity among the edge devices to obtain the plurality of device clusters.
4. The method according to any one of claims 1-3, wherein each device cluster is configured to train a corresponding first sub-model to obtain a second sub-model, the method further comprising:
guiding each device cluster, based on the second sub-model of each device cluster and/or a global student model, to generate an augmentation data set satisfying preset data distribution characteristics; wherein the preset data distribution characteristics are the data distribution characteristics of the local data set in each device cluster, and the global student model is obtained by knowledge distillation of the second sub-models of the device clusters;
and training the cloud global model based on the augmentation data set of each device cluster.
5. The method of claim 4, wherein performing knowledge distillation on the second sub-models of the device clusters to obtain the global student model comprises:
inputting a distillation sample into the second sub-model of each device cluster to obtain each second sub-model's processing result for the distillation sample;
fusing the processing results of the second sub-models for the distillation sample to obtain a soft target; and
training the global student model based on the soft target and the distillation sample, so as to minimize the loss between the global student model's processing result for the distillation sample and the soft target.
6. The method of claim 5, wherein fusing the processing results of the second sub-models for the distillation sample to obtain the soft target comprises:
determining a weight for each second sub-model based on its accuracy on a validation set, wherein a higher accuracy yields a higher weight; and
computing a weighted sum of the second sub-models' processing results for the distillation sample to obtain the soft target.
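Claims 5 and 6 can be illustrated for a single distillation sample. The sketch assumes each second sub-model outputs class probabilities and that the weights are proportional to validation accuracy, normalized to sum to 1; the claim only requires that higher accuracy receive a higher weight. Soft cross-entropy is used as a stand-in for the unspecified loss between the global student's output and the soft target, and the numbers are hypothetical.

```python
import math


def accuracy_weights(accuracies):
    # Assumption: weights proportional to validation accuracy, normalized to sum to 1.
    total = sum(accuracies)
    if total == 0:
        return [1.0 / len(accuracies)] * len(accuracies)
    return [a / total for a in accuracies]


def fuse_soft_target(prob_outputs, accuracies):
    # Weighted sum of the second sub-models' class probabilities for one sample.
    weights = accuracy_weights(accuracies)
    num_classes = len(prob_outputs[0])
    return [sum(w * p[k] for w, p in zip(weights, prob_outputs)) for k in range(num_classes)]


def soft_cross_entropy(soft_target, student_probs, eps=1e-12):
    # Loss the global student would minimize against the soft target.
    return -sum(t * math.log(s + eps) for t, s in zip(soft_target, student_probs))


# Two hypothetical second sub-models scoring one distillation sample.
outputs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
val_acc = [0.9, 0.3]
target = fuse_soft_target(outputs, val_acc)
```

Because the weights sum to 1, the fused target is itself a probability distribution, and the soft cross-entropy is smallest when the student reproduces it exactly.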
7. The method of any one of claims 4-6, wherein guiding each device cluster to generate an augmented data set based on the second sub-model of each device cluster and the global student model comprises:
for each device cluster, acquiring a pseudo sample generated by the device cluster based on its local data distribution characteristics;
for the pseudo sample, performing the following operations in a loop:
processing the pseudo sample with the global student model to obtain a first comparison result of the pseudo sample;
processing the pseudo sample with each second sub-model to obtain a plurality of processing results of the pseudo sample;
fusing the plurality of processing results to obtain a second comparison result of the pseudo sample;
if the difference between the first comparison result and the second comparison result satisfies a preset condition, adding the pseudo sample to the augmented data set; and
if the difference between the first comparison result and the second comparison result does not satisfy the preset condition, requesting the device cluster to update the pseudo sample, until the preset condition is satisfied.
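A minimal sketch of the loop in claim 7, under stated assumptions: the difference is an L1 distance, fusion is an unweighted mean, and `update` is a hypothetical callback standing in for the device cluster regenerating the pseudo sample. The toy models treat the pseudo sample as a single scalar so the loop is easy to trace.

```python
def l1_distance(p, q):
    # Assumption: the "difference" is the L1 distance between two outputs.
    return sum(abs(a - b) for a, b in zip(p, q))


def fuse_mean(outputs):
    # Assumption: unweighted mean fusion of the second sub-models' outputs.
    n = len(outputs)
    return [sum(o[k] for o in outputs) / n for k in range(len(outputs[0]))]


def refine_pseudo_sample(sample, student, sub_models, update, condition, max_rounds=100):
    """Loop until the student/sub-model discrepancy meets the condition.

    Returns (accepted_sample, history); accepted_sample is None if
    max_rounds is exhausted without the condition being met.
    """
    history = []
    for _ in range(max_rounds):
        first = student(sample)                                # first comparison result
        second = fuse_mean([m(sample) for m in sub_models])   # second comparison result
        history.append(l1_distance(first, second))
        if condition(history):
            return sample, history                             # add to augmented set
        sample = update(sample, first, second)                 # ask cluster to regenerate
    return None, history


# Hypothetical toy stand-ins: the pseudo sample is a scalar in [0, 1].
def student(x):
    return [x, 1.0 - x]


sub_models = [lambda x: [0.6, 0.4], lambda x: [0.8, 0.2]]


def update(x, first, second):
    # Move the sample halfway toward where student and sub-models agree.
    return x + 0.5 * (second[0] - first[0])


accepted, history = refine_pseudo_sample(
    0.1, student, sub_models, update, condition=lambda h: h[-1] < 0.05
)
```

Passing the condition in as a function keeps the loop independent of which alternative of the preset condition is used.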
8. The method of claim 7, wherein the preset condition comprises at least one of the following:
the difference between the first and second comparison results of the pseudo sample is smaller than a preset threshold;
the difference between the first and second comparison results of the pseudo sample changes from continuously decreasing to no longer decreasing.
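The two alternatives in claim 8 could be checked over the history of per-round differences as follows. How many consecutive decreases count as "continuously decreasing" (`min_decreasing_steps`) and the tolerance `tol` are hypothetical parameters that the claim leaves open.

```python
def stopped_decreasing(history, min_decreasing_steps=2, tol=1e-6):
    """True if the difference fell for `min_decreasing_steps` consecutive
    rounds and then, in the latest round, did not fall any further."""
    if len(history) < min_decreasing_steps + 2:
        return False
    prior = history[-(min_decreasing_steps + 2):-1]
    was_decreasing = all(b < a - tol for a, b in zip(prior, prior[1:]))
    now_flat = history[-1] >= history[-2] - tol
    return was_decreasing and now_flat


def meets_preset_condition(history, threshold, min_decreasing_steps=2, tol=1e-6):
    # "At least one of": either alternative is sufficient.
    below_threshold = history[-1] < threshold
    return below_threshold or stopped_decreasing(history, min_decreasing_steps, tol)
```

The plateau test lets a pseudo sample be accepted once further regeneration stops helping, even if the threshold is never reached.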
9. The method of any one of claims 4-8, wherein the augmented data set is generated by each device cluster based on a diffusion model, and the global student model is used to construct the reverse process of the diffusion model.
10. A federated learning method, comprising:
sending performance parameters of a current edge device to a server, so that the server performs cluster analysis on a plurality of edge devices based on their performance parameters to obtain a plurality of device clusters, and performs knowledge distillation on a cloud global model to obtain a first sub-model applicable to a target device cluster, the target device cluster being the device cluster in which the current edge device is located; and
performing federated learning based on the first sub-model.
11. The method of claim 10, wherein performing federated learning based on the first sub-model comprises:
training the first sub-model with a local data set of the current edge device to obtain a sub-model to be aggregated; and
generating a second sub-model of the first sub-model based on the target device cluster and the sub-model to be aggregated.
12. The method of claim 11, wherein training the first sub-model with the local data set of the current edge device comprises:
iteratively training the first sub-model for a preset number of iterations using the local data set of the current edge device.
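For claims 11 and 12, a local update could look like the following sketch, which assumes a linear model trained by full-batch gradient descent on mean squared error; the model family, loss, learning rate and iteration count are all hypothetical. The returned parameters play the role of the sub-model to be aggregated.

```python
def local_train(weights, local_data, num_iterations, lr=0.1):
    """Full-batch gradient descent on mean squared error for a linear model.

    Runs exactly `num_iterations` updates (the preset number of iterations)
    and returns new parameters without modifying the input list.
    """
    w = list(weights)
    n = len(local_data)
    for _ in range(num_iterations):
        grad = [0.0] * len(w)
        for x, y in local_data:
            err = sum(wj * xj for wj, xj in zip(w, x)) - y
            for j, xj in enumerate(x):
                grad[j] += 2.0 * err * xj / n
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w


first_sub_model = [0.0]
local_data = [([1.0], 2.0), ([2.0], 4.0)]  # hypothetical samples of y = 2x
to_aggregate = local_train(first_sub_model, local_data, num_iterations=20)
```

Fixing the iteration count rather than training to convergence bounds the compute each edge device spends per round, which matters on heterogeneous hardware.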
13. The method of claim 11, wherein generating the second sub-model of the first sub-model based on the target device cluster and the sub-model to be aggregated comprises:
if the current edge device is the master device of the target device cluster, acquiring the plurality of sub-models to be aggregated that were trained by the edge devices in the target device cluster; and
aggregating the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model.
14. The method of claim 13, wherein aggregating the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model comprises:
computing a weighted average of the model parameters of the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model; or
computing the arithmetic mean of the model parameters of the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model.
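The two aggregation options in claim 14 reduce to parameter-wise averaging. The sketch treats each sub-model as a flat list of parameters and assumes, as in FedAvg, that the weights are local data set sizes; the claim itself does not say what the weights are.

```python
def weighted_average(models, weights):
    # Parameter-wise weighted mean; `models` are equal-length flat parameter lists.
    total = float(sum(weights))
    return [
        sum(w * m[j] for w, m in zip(weights, models)) / total
        for j in range(len(models[0]))
    ]


def simple_average(models):
    # Unweighted variant: every sub-model to be aggregated counts equally.
    return weighted_average(models, [1.0] * len(models))


sub_models = [[1.0, 2.0], [3.0, 4.0]]
local_sizes = [100, 300]  # hypothetical local data set sizes used as weights
second_weighted = weighted_average(sub_models, local_sizes)
second_simple = simple_average(sub_models)
```

Here `second_weighted` is pulled toward the sub-model trained on more data, while `second_simple` sits exactly halfway.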
15. The method of claim 13, wherein aggregating the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model comprises:
if the validation set comprises a plurality of categories, determining, for each category, the sub-model to be aggregated with the highest prediction accuracy on that category as an intermediate sub-model, so as to obtain a plurality of intermediate sub-models; and
computing the mean of the model parameters of the plurality of intermediate sub-models to obtain the second sub-model.
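Claim 15 can be sketched as: compute per-class validation accuracy for every sub-model to be aggregated, pick the best one per class, then average the chosen intermediate sub-models. Whether a sub-model that wins several classes is counted once or once per class is not stated, so the hypothetical `deduplicate` flag exposes both readings; the predictions and parameters below are made up for illustration.

```python
def per_class_accuracy(predictions, labels, num_classes):
    # Accuracy of one sub-model restricted to each class of the validation set.
    correct = [0] * num_classes
    total = [0] * num_classes
    for pred, label in zip(predictions, labels):
        total[label] += 1
        if pred == label:
            correct[label] += 1
    return [c / t if t else 0.0 for c, t in zip(correct, total)]


def aggregate_by_class_winners(models, model_predictions, labels, num_classes, deduplicate=False):
    """Pick the most accurate sub-model per class, then average their parameters.

    Assumption: with deduplicate=False a sub-model that wins several classes
    is counted once per class won; with deduplicate=True it is counted once.
    """
    accs = [per_class_accuracy(p, labels, num_classes) for p in model_predictions]
    winners = [max(range(len(models)), key=lambda m: accs[m][c]) for c in range(num_classes)]
    if deduplicate:
        winners = sorted(set(winners))
    intermediates = [models[m] for m in winners]
    n = len(intermediates)
    averaged = [sum(p[j] for p in intermediates) / n for j in range(len(models[0]))]
    return averaged, winners


labels = [0, 0, 1, 1, 2, 2]
model_predictions = [
    [0, 0, 1, 0, 0, 0],  # sub-model A: best on class 0
    [1, 0, 1, 1, 2, 0],  # sub-model B: best on class 1
    [2, 2, 2, 2, 2, 2],  # sub-model C: best on class 2
]
models = [[0.0, 3.0], [3.0, 0.0], [3.0, 3.0]]
second_sub_model, winners = aggregate_by_class_winners(models, model_predictions, labels, 3)
```

Selecting per-class winners keeps the sub-models that specialize in different categories, which plain averaging would dilute under non-IID local data.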
16. The method of any one of claims 10-15, further comprising:
sending the second sub-model to the server; and
generating, under the guidance of the second sub-models of the device clusters and/or a global student model, augmented samples based on the distribution characteristics of the local data set of the current edge device, so that the server obtains the augmented data set of the target device cluster,
wherein the global student model is obtained by the server through knowledge distillation from the second sub-models of the device clusters.
17. The method of claim 16, wherein generating the augmented samples based on the distribution characteristics of the local data set of the current edge device under the guidance of the second sub-models of the device clusters and the global student model comprises:
generating a pseudo sample with a diffusion model, wherein the diffusion model has learned the data distribution characteristics of the local data set of the current edge device, and the reverse process of the diffusion model is constructed based on the global student model; and
updating the pseudo sample with the diffusion model under the guidance of the global student model and the second sub-models of the device clusters to obtain the augmented sample.
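Claims 9 and 17 leave open how the global student model "constructs" the reverse process. One possible reading, sketched here in one dimension, is classifier guidance: each DDPM reverse step adds the gradient of the student's log-probability for a target class to the data score. The closed-form Gaussian score replaces a learned denoiser, a logistic classifier replaces the global student model, and the noise schedule, target class and guidance scale are hypothetical; guidance from the second sub-models could be added as a further gradient term but is omitted here.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noised_score(x, abar, mu, sigma):
    # Exact score of q(x_t) when local data ~ N(mu, sigma^2); stands in for a learned network.
    var = abar * sigma ** 2 + (1.0 - abar)
    return -(x - math.sqrt(abar) * mu) / var


def guidance_grad(x, w, b):
    # d/dx log p(y=1 | x) for a hypothetical logistic "global student" classifier.
    return w * (1.0 - sigmoid(w * x + b))


def generate(num_samples, mu=0.0, sigma=1.0, guidance_scale=0.0,
             w=4.0, b=0.0, num_steps=200, seed=0):
    rng = random.Random(seed)
    betas = [1e-3 + (0.05 - 1e-3) * i / (num_steps - 1) for i in range(num_steps)]
    abars, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        abars.append(prod)
    samples = []
    for _ in range(num_samples):
        x = rng.gauss(0.0, 1.0)  # start from pure noise
        for t in reversed(range(num_steps)):
            beta = betas[t]
            score = noised_score(x, abars[t], mu, sigma)
            score += guidance_scale * guidance_grad(x, w, b)
            x = (x + beta * score) / math.sqrt(1.0 - beta)  # DDPM reverse mean
            if t > 0:
                x += math.sqrt(beta) * rng.gauss(0.0, 1.0)  # ancestral noise
        samples.append(x)
    return samples


plain = generate(400)
guided = generate(400, guidance_scale=3.0)
```

With guidance switched on, the generated pseudo samples shift toward the region the classifier assigns to the target class. Classifier guidance normally uses a noise-aware classifier; evaluating a clean-data classifier on noisy inputs, as done here, is a simplification.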
18. A federated learning apparatus, comprising:
an acquisition module configured to acquire performance parameters of a plurality of edge devices;
a clustering module configured to perform cluster analysis on the plurality of edge devices based on their performance parameters to obtain a plurality of device clusters;
a distillation module configured to perform knowledge distillation on a cloud global model for the plurality of device clusters so as to distill a plurality of first sub-models; and
a learning module configured to distribute the plurality of first sub-models to the corresponding device clusters for federated learning.
19. The apparatus of claim 18, wherein the distillation module is specifically configured to:
for each of the plurality of device clusters, perform the following:
sample, from a public data set, a training sample set adapted to the performance parameters of the device cluster;
create an initial model adapted to the performance parameters of the device cluster; and
train the initial model on the training sample set, with the cloud global model as the teacher model and the initial model as the student model, so as to distill the first sub-model.
20. The apparatus of claim 18 or 19, wherein the clustering module is specifically configured to:
determine the performance similarity between the edge devices based on the performance parameters of the plurality of edge devices; and
perform cluster analysis based on the performance similarity between the edge devices to obtain the plurality of device clusters.
21. The apparatus of any one of claims 18-20, wherein each device cluster is configured to train its respective first sub-model to obtain a second sub-model;
the apparatus further comprises a guidance module configured to guide each device cluster, based on the second sub-model of each device cluster and/or a global student model, to generate an augmented data set satisfying preset data distribution characteristics, wherein the preset data distribution characteristics are the data distribution characteristics of the local data sets in that device cluster, and the global student model is obtained by knowledge distillation from the second sub-models of the device clusters; and
the learning module is further configured to train the cloud global model on the augmented data sets of the device clusters.
22. The apparatus of claim 21, wherein the distillation module comprises:
an input sub-module configured to input a distillation sample into the second sub-model of each device cluster to obtain each second sub-model's processing result for the distillation sample;
a fusion sub-module configured to fuse the second sub-models' processing results for the distillation sample to obtain a soft target; and
a distillation sub-module configured to train the global student model based on the soft target and the distillation sample, so as to minimize the loss between the global student model's processing result for the distillation sample and the soft target.
23. The apparatus of claim 22, wherein the fusion sub-module is specifically configured to:
determine a weight for each second sub-model based on its accuracy on a validation set, wherein a higher accuracy yields a higher weight; and
compute a weighted sum of the second sub-models' processing results for the distillation sample to obtain the soft target.
24. The apparatus of any of claims 21-23, wherein the guidance module comprises:
an acquisition sub-module configured to acquire, for each device cluster, a pseudo sample generated by the device cluster based on its local data distribution characteristics; and
an optimization sub-module configured to perform the following operations on the pseudo sample in a loop:
processing the pseudo sample with the global student model to obtain a first comparison result of the pseudo sample;
processing the pseudo sample with each second sub-model to obtain a plurality of processing results of the pseudo sample;
fusing the plurality of processing results to obtain a second comparison result of the pseudo sample;
if the difference between the first comparison result and the second comparison result satisfies a preset condition, adding the pseudo sample to the augmented data set; and
if the difference between the first comparison result and the second comparison result does not satisfy the preset condition, requesting the device cluster to update the pseudo sample, until the preset condition is satisfied.
25. The apparatus of claim 24, wherein the preset condition comprises at least one of the following:
the difference between the first and second comparison results of the pseudo sample is smaller than a preset threshold;
the difference between the first and second comparison results of the pseudo sample changes from continuously decreasing to no longer decreasing.
26. The apparatus of any one of claims 21-25, wherein the augmented data set is generated by each device cluster based on a diffusion model, and the global student model is used to construct the reverse process of the diffusion model.
27. A federated learning apparatus, comprising:
a sending module configured to send performance parameters of a current edge device to a server, so that the server performs cluster analysis on a plurality of edge devices based on their performance parameters to obtain a plurality of device clusters, and performs knowledge distillation on a cloud global model to obtain a first sub-model applicable to a target device cluster, the target device cluster being the device cluster in which the current edge device is located; and
a training module configured to perform federated learning based on the first sub-model.
28. The apparatus of claim 27, wherein the training module comprises:
a model training sub-module configured to train the first sub-model with a local data set of the current edge device to obtain a sub-model to be aggregated; and
an aggregation sub-module configured to generate a second sub-model of the first sub-model based on the target device cluster and the sub-model to be aggregated.
29. The apparatus of claim 28, wherein the model training sub-module is configured to iteratively train the first sub-model a preset number of times using a local data set of the current edge device.
30. The apparatus of claim 28, wherein the aggregation sub-module is specifically configured to:
if the current edge device is the master device of the target device cluster, acquire the plurality of sub-models to be aggregated that were trained by the edge devices in the target device cluster; and
aggregate the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model.
31. The apparatus of claim 30, wherein the aggregation sub-module is specifically configured to:
compute a weighted average of the model parameters of the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model; or
compute the arithmetic mean of the model parameters of the plurality of sub-models to be aggregated to obtain the second sub-model of the first sub-model.
32. The apparatus of claim 30, wherein the aggregation sub-module is specifically configured to:
if the validation set comprises a plurality of categories, determine, for each category, the sub-model to be aggregated with the highest prediction accuracy on that category as an intermediate sub-model, so as to obtain a plurality of intermediate sub-models; and
compute the mean of the model parameters of the plurality of intermediate sub-models to obtain the second sub-model.
33. The apparatus of any one of claims 27-32, wherein:
the sending module is further configured to send the second sub-model to the server;
the apparatus further comprises a generation module configured to generate, under the guidance of the second sub-models of the device clusters and/or a global student model, augmented samples based on the distribution characteristics of the local data set of the current edge device, so that the server obtains the augmented data set of the target device cluster; and
the global student model is obtained by the server through knowledge distillation from the second sub-models of the device clusters.
34. The apparatus of claim 33, wherein the generation module comprises:
a sample generation sub-module configured to generate a pseudo sample with a diffusion model, wherein the diffusion model has learned the data distribution characteristics of the local data set of the current edge device, and the reverse process of the diffusion model is constructed based on the global student model; and
an updating sub-module configured to update the pseudo sample with the diffusion model under the guidance of the global student model and the second sub-models of the device clusters to obtain the augmented sample.
35. An electronic device, comprising:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform the method of any one of claims 1-17.
36. A non-transitory computer-readable storage medium storing computer instructions for causing a computer to perform the method of any one of claims 1-17.
37. A computer program product comprising a computer program which, when executed by a processor, implements the method according to any of claims 1-17.
CN202310765118.2A 2023-06-26 2023-06-26 Federal learning method, apparatus, device and medium Pending CN116976461A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310765118.2A CN116976461A (en) 2023-06-26 2023-06-26 Federal learning method, apparatus, device and medium

Publications (1)

Publication Number Publication Date
CN116976461A true CN116976461A (en) 2023-10-31

Family

ID=88484054

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310765118.2A Pending CN116976461A (en) 2023-06-26 2023-06-26 Federal learning method, apparatus, device and medium

Country Status (1)

Country Link
CN (1) CN116976461A (en)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117196070A (en) * 2023-11-08 2023-12-08 山东省计算中心(国家超级计算济南中心) Heterogeneous data-oriented dual federal distillation learning method and device
CN117196070B (en) * 2023-11-08 2024-01-26 山东省计算中心(国家超级计算济南中心) Heterogeneous data-oriented dual federal distillation learning method and device

Similar Documents

Publication Publication Date Title
WO2021078027A1 (en) Method and apparatus for constructing network structure optimizer, and computer-readable storage medium
CN104951425B (en) A kind of cloud service performance self-adapting type of action system of selection based on deep learning
US20200265301A1 (en) Incremental training of machine learning tools
CN113905391B (en) Integrated learning network traffic prediction method, system, equipment, terminal and medium
KR20200022739A (en) Method and device to recognize image and method and device to train recognition model based on data augmentation
CN113570064A (en) Method and system for performing predictions using a composite machine learning model
CN111695415A (en) Construction method and identification method of image identification model and related equipment
WO2021042857A1 (en) Processing method and processing apparatus for image segmentation model
US11423307B2 (en) Taxonomy construction via graph-based cross-domain knowledge transfer
Azzouz et al. Steady state IBEA assisted by MLP neural networks for expensive multi-objective optimization problems
WO2021103675A1 (en) Neural network training and face detection method and apparatus, and device and storage medium
JP2022033695A (en) Method, device for generating model, electronic apparatus, storage medium and computer program product
JP2023523029A (en) Image recognition model generation method, apparatus, computer equipment and storage medium
WO2022227217A1 (en) Text classification model training method and apparatus, and device and readable storage medium
WO2023150912A1 (en) Operator scheduling operation time comparison method and device, and storage medium
Chen et al. Deep-broad learning system for traffic flow prediction toward 5G cellular wireless network
CN116976461A (en) Federal learning method, apparatus, device and medium
US20230368028A1 (en) Automated machine learning pre-trained model selector
CN113822315A (en) Attribute graph processing method and device, electronic equipment and readable storage medium
CN112785005A (en) Multi-target task assistant decision-making method and device, computer equipment and medium
US11475236B2 (en) Minimum-example/maximum-batch entropy-based clustering with neural networks
CN113011895A (en) Associated account sample screening method, device and equipment and computer storage medium
CN111667069A (en) Pre-training model compression method and device and electronic equipment
CN114650321A (en) Task scheduling method for edge computing and edge computing terminal
WO2020151017A1 (en) Scalable field human-machine dialogue system state tracking method and device

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination