CN114429219A - Long-tail heterogeneous data-oriented federal learning method

Publication number: CN114429219A
Application number: CN202111502142.4A
Authority: CN (China)
Original language: Chinese (zh)
Inventors: 卢杨; 尚心怡; 黄刚; 华炜; 王菡子
Assignees: Xiamen University; Zhejiang Lab
Application filed by Xiamen University and Zhejiang Lab
Priority date / Filing date: 2021-12-09
Publication date: 2022-05-03
Legal status: Pending

Classifications

    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06N: COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 20/00: Machine learning

Abstract

The invention discloses a federated learning method oriented to long-tail heterogeneous data, which comprises the following steps: step one, a server side randomly initializes a global model w and sends the model parameters to each client side; each client side updates its local model using the received model parameters and uploads the updated model parameters to the server side; step two, the server side aggregates the received local model parameters to obtain a teacher model and a student model; step three, the server side calibrates the teacher model obtained in step two, so that the teacher model learns unbiased knowledge and can teach a good student model; step four, the unbiased knowledge of the teacher model is transferred to the student model through knowledge distillation, and the student model is then sent to each client side to start the next round of federated training.

Description

Long-tail heterogeneous data-oriented federal learning method
Technical Field
The invention relates to the technical field of artificial intelligence, in particular to a federated learning method for long-tailed heterogeneous data.
Background
With the further development of big data, attention to data privacy and security has become a worldwide trend. At the same time, most industrial data exhibit a data-island phenomenon, so how to carry out cross-organization data cooperation while satisfying user privacy protection, data security and regulatory requirements has become a major problem troubling artificial intelligence practitioners. Federated learning is expected to become a key technology for solving this industry problem.
Federated learning was originally proposed by Google in 2016. It is a machine learning paradigm whose objective is to realize joint modeling and improve the effect of artificial intelligence models on the basis of guaranteeing data privacy, security and legal compliance. One of the core challenges of federated learning is the heterogeneity of the data distributions among the participating parties, i.e., heterogeneous data, which can greatly degrade the performance of federated learning. Meanwhile, the overall data distribution of the parties participating in the training tends to exhibit a long-tailed distribution rather than a balanced distribution, which results in the model performing well on the categories with many samples (head categories) and poorly on the categories with few samples (tail categories). Therefore, research on the joint problem of long-tailed distribution and heterogeneous data in federated learning is of great significance.
Existing methods for handling heterogeneous data in federated learning can mainly be divided into client-side methods and server-side methods. The first type, client-side methods, mainly regularizes the local model update of each client, using the knowledge of the global model to limit the training direction of the client's local model so that local models that differ greatly from the global model do not drag the whole system off course, thereby reducing the influence caused by heterogeneous data. The second type, server-side methods, mainly employs special aggregation strategies to mitigate the negative impact caused by heterogeneous data.
These methods solve the problem of heterogeneous data in federated learning to a certain extent, but they are all designed on the premise that the global data distribution is balanced. In a real scene, the global data distribution is almost never balanced and instead tends to approach a long-tailed distribution; the performance of the global model trained by these methods on the tail classes is still poor, so a good effect cannot be achieved.
Existing methods for long-tailed data mainly fall into three categories: re-sampling, re-weighting and representation learning. The first category mainly over-samples the classes with few samples and under-samples the classes with many samples, aiming to reconstruct a balanced data set. The second category modifies the loss function to make it more favorable for training the tail classes. The third category uses the characteristics of deep learning to focus on representation learning of the input data.
The above methods for handling long-tailed data distributions gather the data together for centralized model training, which is not suitable for a real federated learning environment. To solve the data imbalance problem under federated learning, Duan, M et al propose the Astraea method, which first performs data sampling before training the model to construct a balanced training data set, thereby alleviating the global imbalance; it then uses several mediators, assigns the client group for which each mediator is responsible according to the KL divergence between mediators, and reschedules the model training of the clients. By selectively combining clients with heterogeneous data, a new local balance may be achieved. However, this results in some clients never being selected, i.e., never participating in the federated training process, so their own information cannot be utilized. Wang, L et al propose the Ratio Loss method, which realizes monitoring of the data imbalance in federated learning, and a novel loss function, Ratio Loss, is proposed to reduce the influence caused by the imbalance. However, the performance of this method drops sharply as the degree of data heterogeneity increases.
Disclosure of Invention
In order to overcome the defects of the prior art, the present invention aims to improve model performance under federated learning while satisfying user privacy protection and data security, so as to improve image recognition efficiency, and adopts the following technical scheme:
a federal learning method for long-tailed heterogeneous data comprises the following steps:
S1, the server side randomly initializes the global model w and sends the model parameters to each client side; each client side updates its local model using the received model parameters and uploads the updated local model parameters to the server side;
S2, the server side aggregates the received local model parameters to obtain a teacher model and a student model;
S3, the server side calibrates the teacher model so that the teacher model learns on unbiased knowledge and can therefore teach a good student model;
and S4, transmitting the unbiased knowledge of the teacher model to the student model through knowledge distillation, and then sending the student model to each client side to start the next round of federated training (an illustrative sketch of one such round follows).
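For illustration only, the following minimal PyTorch-style sketch summarizes one server round implementing steps S1 to S4; the helper names (local_update, aggregate_student, build_teacher, calibrate_teacher, distill) are assumed placeholders, not part of the claimed method, and they correspond to the more detailed sketches given in the embodiment below.

    # Minimal sketch of one federated round (steps S1 to S4). All helper names
    # are illustrative placeholders, not part of the claimed method.
    import copy
    import random

    def federated_round(global_model, clients, server, frac=0.4):
        # S1: broadcast global parameters; selected clients update locally and upload.
        selected = random.sample(clients, k=max(1, int(frac * len(clients))))
        local_states, local_sizes = [], []
        for client in selected:
            local_model = copy.deepcopy(global_model)
            client.local_update(local_model)            # SGD on the client's own data
            local_states.append(local_model.state_dict())
            local_sizes.append(client.num_samples)

        # S2: aggregate into a student (parameter average) and a teacher (logits ensemble).
        student = server.aggregate_student(local_states, local_sizes)
        teacher = server.build_teacher(local_states)

        # S3: calibrate the teacher on the server's balanced labeled set.
        teacher = server.calibrate_teacher(teacher)

        # S4: distill the calibrated teacher into the student and broadcast it.
        student = server.distill(teacher, student)
        return student                                  # sent to the clients next round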
Further, in step S1, the server side initializes the global model parameters w, randomly selects the set S of clients participating in the current round of training, and broadcasts the model parameters to each client in S. Each client in S performs stochastic gradient descent (SGD) using the received global model parameters w and its local data to update its local model; the local model parameters obtained by client k after updating are denoted w_k. After the update, each client sends its updated model parameters back to the server side.
Further, step S2 includes the following steps:
S21, the server side averages the local model parameters to obtain a student model, and the calculation formula is as follows:

w = Σ_{k=1}^{K} (|D_k| / |D|) · w_k  (formula 1)

φ_s(x) = φ_w(x)  (formula 2)

wherein |D_k| represents the amount of image data owned by the k-th client, |D| represents the total amount of image data owned by all clients, K represents the number of clients, x represents the input image data, φ_w(·) represents the network of the federated averaging model, and φ_s(·) represents the network of the student model;

S22, the server side performs a weighted aggregation of the local models to obtain a teacher model, and the calculation formula is as follows:

φ_t(x) = Σ_{k=1}^{K} e_k · φ_{w_k}(x)  (formula 3)

wherein φ_t(·) represents the network of the teacher model, e_k represents the weight assigned to client k, and φ_{w_k}(·) represents the network of the k-th client.
Further, in step S3, since the local models are trained on local data with different distributions, each local model may behave differently on the tail classes, so we need to assign higher weights to the local models that perform better on the tail classes. However, the server does not know which image classes are tail classes and which client local models perform well on them, so we do not give each client a fixed weight; instead, we propose a client-based weight assignment strategy to calculate the weight e_k of each client local model, and finally e_k is normalized so that the weights sum to 1, giving the final weight. The calculation formula of the weight e_k is as follows:

e_k = a_e^T · φ_{w_k}(x) + b_e  (formula 4)

wherein a_e ∈ R^C and b_e represent learnable network parameters, R^C represents a C-dimensional vector, and T denotes the transpose. The client-based calibration works like a self-attention mechanism: weights are computed for the local models from their original output logits, and the weights are then multiplied back onto the original output logits.
Further, in step S3, if none of the client local models can handle the tail classes well, the teacher model obtained by the weighted ensemble is still biased toward the head classes. To solve this problem, a class-based calibration strategy for the original output logits is proposed to further improve the performance of the model on the tail classes. Let the calibrated model output logits be z_cl; the calculation formula is as follows:

z_cl = a_z ⊙ φ_t(x) + b_z  (formula 5)

wherein a_z and b_z represent learnable network parameters, and ⊙ denotes the Hadamard product.
Further, in step S3, the premise that the above logits calibration strategy is valid is that the representation information extracted by the local models from the input image data is good enough; if the feature extraction of the client local models is seriously affected by the long-tailed distribution, it is not enough to calibrate only the output logits, so we need to update the feature extractor to further improve the model performance. We use additional images with balanced labels at the server side to form a balanced labeled data set D^b and fine-tune the global model w on it to obtain a fine-tuned model ŵ. Because D^b is balanced, the fine-tuned model ŵ can provide an unbiased feature extractor, and we can then obtain the fine-tuned model output logits for the input image data x as

z_ft = φ_ŵ(x)

wherein z_ft represents the output of the fine-tuned model for x, and φ_ŵ(·) represents the network of the fine-tuned model.
Further, the fine-tuned model is obtained by

ŵ = w - η · ∇L(w; D^b)

wherein η represents the learning rate, L(·) represents the loss function, and ∇ represents the derivative.
Further, in step S3, z_cl and z_ft calibrate the teacher model at two different levels: z_cl calibrates the teacher model at the output-logits level while the model's feature extractor is fixed, whereas z_ft is the result of fine-tuning the feature extractor, thereby improving the feature extraction capability of the model. To fully combine the advantages of the two, a calibration gating network is proposed to trade off z_cl and z_ft. The calibration gating network takes the ensemble feature as input and outputs a weight through a nonlinear layer, so that each sample obtains a different weight according to its own features. The weight calculation formula is as follows:

σ = sigmoid(u^T · v)  (formula 6)

wherein v = Σ_{k=1}^{K} e_k · f_{w_k}(x) represents the ensemble feature, f_{w_k}(·) represents the feature extractor of the k-th client, u ∈ R^d represents a learnable network parameter, and R^d represents a d-dimensional vector. Thus, the final calibrated model output logits produced by the calibration gating network are z', and the calculation formula is as follows:

z' = σ · z_cl + (1 - σ) · z_ft  (formula 7)

wherein σ ∈ (0,1) is used to trade off the two model output logits z_cl and z_ft.
Further, all learnable parameters in the whole ensemble-calibration process are updated with the cross-entropy loss on D^b, as follows:

L = - Σ_{j=1}^{C} y_j · log( exp(z'_j) / Σ_{i=1}^{C} exp(z'_i) )  (formula 8)

wherein C represents the number of categories, y_j represents the j-th dimension of the true label y of the input image data, exp(·) represents the exponential function with the natural constant e as base, z'_j represents the value of the j-th dimension of the final calibration z', and z'_i represents the value of the i-th dimension of the final calibration z'.
Further, in step S4, the unbiased knowledge of the teacher model is transferred to the student model by knowledge distillation. Specifically, to better teach the unbiased knowledge of the teacher model to the student model, we train the student model by combining labeled-data training and unlabeled-data distillation, and the loss function is as follows:

L' = (1 - λ) · L_CE + λ · L_KL  (formula 9)

wherein L_CE represents the cross-entropy loss between the student model's output logits and the image's true label (ground truth), and L_KL represents the Kullback-Leibler (KL) divergence between the teacher and student models' output logits; the balanced labeled data set D^b is used to calculate L_CE, and another set of unlabeled images is used to construct an unlabeled data set D^u for calculating L_KL, so as to further improve the knowledge distillation performance; λ ∈ [0,1] represents a hyper-parameter that trades off L_CE and L_KL.
The invention has the advantages and beneficial effects that:
the invention researches the joint problem of heterogeneous data and long tail distribution in federal learning, fully utilizes the diversity of a local model of a client to process the heterogeneous data, provides a new model calibration strategy and a gating network to effectively solve the long tail problem, and further improves the model performance under the federal learning.
Drawings
FIG. 1 is a flow chart of the method of the present invention.
Fig. 2 is a diagram of a client data distribution in the present invention.
Detailed Description
The following detailed description of embodiments of the invention refers to the accompanying drawings. It should be understood that the detailed description and specific examples, while indicating the preferred embodiment of the invention, are given by way of illustration and explanation only, not limitation.
As shown in fig. 1, a federated learning method oriented to long-tail heterogeneous data includes the following steps:
step one, preparing a data set, initializing a network, distributing the data set to each client side, and updating a model.
Step 1.1, the data sets used are CIFAR-10, CIFAR-100 and ImageNet-LT.
The CIFAR-10 data set has 60000 color images of size 32 × 32, divided into 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. There are 6000 images per class; 50000 images are used for training, forming 5 training batches of 10000 images each, and the other 10000 images are used for testing and form a single batch. The test batch contains 1000 images randomly taken from each of the 10 classes, and the remaining images are arranged in random order to form the training batches. Note that an individual training batch does not necessarily contain the same number of images from each class, but across all training batches there are 5000 images per class. In addition, 100 images are randomly selected from each class of CIFAR-10 to form the additional balanced data set D^b, and knowledge distillation is performed using CIFAR-100 as the unlabeled data.
CIFAR-100 has 100 classes, each containing 600 images: 500 training images and 100 test images per class. The 100 classes in CIFAR-100 are grouped into 20 superclasses, and each image carries a "fine" label (the class to which it belongs) and a "coarse" label (the superclass to which it belongs). For CIFAR-100, 10 images are randomly selected from each class to form the additional balanced data set D^b, and knowledge distillation is performed using down-sampled ImageNet (image size 32 × 32) as the unlabeled data.
ImageNet-LT is a large image classification data set, a long-tailed version of ImageNet obtained by sampling a subset following a Pareto distribution with α = 6. It contains 115800 images in 1000 categories; the largest and smallest categories contain 1280 and 5 images, respectively. The balanced data set D^b is obtained from the balanced validation data, and knowledge distillation is performed using oversampled CIFAR-100 (image size 224 × 224) as the unlabeled data.
Each of the three data sets is partitioned across the different clients as their local data according to a Dirichlet distribution with heterogeneity degree 0.1; the resulting data distribution on CIFAR-10 is shown in fig. 2.
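For illustration, a minimal sketch of such a Dirichlet-based partition is given below; the helper name and its interface are assumptions, not part of the claimed method.

    # Minimal sketch of partitioning a labeled data set across clients with a
    # Dirichlet distribution (heterogeneity degree 0.1), as used in step 1.1.
    import numpy as np

    def dirichlet_partition(labels, num_clients=20, alpha=0.1, seed=0):
        rng = np.random.default_rng(seed)
        labels = np.asarray(labels)
        client_indices = [[] for _ in range(num_clients)]
        for c in np.unique(labels):
            idx = rng.permutation(np.where(labels == c)[0])
            # Per-class client proportions drawn from Dirichlet(alpha, ..., alpha).
            proportions = rng.dirichlet(alpha * np.ones(num_clients))
            splits = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
            for client_id, part in enumerate(np.split(idx, splits)):
                client_indices[client_id].extend(part.tolist())
        return client_indices  # list of sample indices assigned to each client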
Step 1.2, building a federated learning environment and initializing the network.
Training was performed on CIFAR-10-LT and CIFAR-100-LT using a ResNet-8 network and on ImageNet-LT using a ResNet-50 network. All our experiments were run with PyTorch on two NVIDIA GeForce RTX 3080 GPUs. Typically, we design 20 clients for 200 rounds of training, with 40% of the clients selected per round to participate in federated training. For client training, the batch size is set to 128, the learning rate is 0.1, and SGD is used as the optimizer. For server-side global model training, we set the calibration epochs to 100 and the distillation epochs to 100, and perform knowledge distillation with an Adam optimizer and a learning rate of 0.001.
Step 1.3, updating the client model.
The server side initializes the global model parameters w, randomly selects a client set S participating in the current round of training, and broadcasts the model parameters to the clients in S. Each client in S performs stochastic gradient descent (SGD) using the received global model parameters w and its local data to update its local model; the model parameters obtained by client k after updating are denoted w_k. After the update, each client sends its updated model parameters back to the server side.
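A minimal PyTorch sketch of this client update is given below; the number of local epochs is not specified in this embodiment and is an assumed parameter, while batch size and learning rate follow step 1.2.

    # Minimal sketch of the client update in step 1.3: load the broadcast global
    # parameters and run SGD on the client's local data.
    import torch
    import torch.nn.functional as F

    def client_update(model, global_state, loader, local_epochs=1, lr=0.1, device="cpu"):
        model.load_state_dict(global_state)            # start from the global model w
        model.to(device).train()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(local_epochs):
            for images, targets in loader:
                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()
                loss = F.cross_entropy(model(images), targets)
                loss.backward()
                optimizer.step()
        return model.state_dict()                      # w_k, sent back to the server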
Step two, the server side obtains a teacher model and a student model, and the specific process comprises the following substeps:
step 2.1, firstly, the server side carries out average weighting on the received model parameters to obtain a student model, and the calculation formula is as follows:
Figure BDA0003402166730000061
φs(x)=φw(x) (formula 2)
Step 2.2, then the server side carries out weighting aggregation on the model parameters obtained from the client side to obtain a teacher model, and the calculation formula is as follows:
Figure BDA0003402166730000062
wherein e iskIs the weight assigned to client k.
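For illustration, a minimal sketch of the two aggregations is given below: the student parameters follow formula 1, and the teacher is an ensemble whose logits follow formula 3. The class and function names are assumptions, not part of the claimed method.

    # Minimal sketch of step 2: parameter-averaged student and logits-ensemble teacher.
    import copy
    import torch

    def aggregate_student(local_states, local_sizes):
        total = float(sum(local_sizes))
        avg_state = copy.deepcopy(local_states[0])
        for key in avg_state:
            # Data-size-weighted average of each parameter (buffers averaged as floats).
            avg_state[key] = sum(
                (n / total) * state[key].float() for state, n in zip(local_states, local_sizes)
            )
        return avg_state                               # parameters w of the student phi_s

    class TeacherEnsemble(torch.nn.Module):
        """phi_t(x) = sum_k e_k * phi_{w_k}(x); e_k comes from the calibration step."""
        def __init__(self, client_models):
            super().__init__()
            self.client_models = torch.nn.ModuleList(client_models)

        def forward(self, x, weights):                 # weights: (K,) or (batch, K)
            logits = torch.stack([m(x) for m in self.client_models], dim=1)  # (B, K, C)
            if weights.dim() == 1:
                weights = weights.expand(logits.size(0), -1)
            return (weights.unsqueeze(-1) * logits).sum(dim=1)               # (B, C)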
Step three, the server side calibrates the teacher model obtained in step two.
Since the global data has a long-tailed distribution, the model obtained by each client is biased toward the head classes and performs poorly on the tail classes, and the teacher model obtained by weighting the client models is likewise biased toward the head classes and ignores the tail classes. Because the teacher model is biased toward the head classes, the knowledge it contains is heavily biased, which causes the taught student model to be biased and seriously impairs its performance. Therefore, the teacher model is calibrated to obtain an unbiased teacher model, and the unbiased knowledge is then taught to the student model. The specific method comprises the following substeps:
step 3.1, because the local models are trained on local data with different distributions, and each local model may behave differently on the tail class, we assign higher weights to the local models that behave better on the tail class. However, the server does not know which classes are tail classes and which local models perform well on top, so we do not give each client a fixed weight, but instead we propose a client-based weight assignment strategy to compute the weight e for each client local modelkAnd e is combinedkNormalization makes the sum equal to 1, which is the final weight. The calculation formula is as follows:
Figure BDA0003402166730000063
wherein, ae∈RCAnd beAre parameters that can be learned. Client-based calibration just like the self-attention mechanism, weights are computed for the local model from the original logits, which are then multiplied back to the original logits.
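A minimal sketch of this client-based weighting module is shown below; softmax is used here as one way to normalize the weights so that they sum to 1, and the module and parameter names are assumptions.

    # Minimal sketch of step 3.1 (formula 4): e_k = a_e^T phi_{w_k}(x) + b_e,
    # computed from each client model's logits and normalized to sum to 1.
    import torch
    import torch.nn as nn

    class ClientWeighting(nn.Module):
        def __init__(self, num_classes):
            super().__init__()
            self.a_e = nn.Parameter(torch.randn(num_classes) * 0.01)  # a_e in R^C
            self.b_e = nn.Parameter(torch.zeros(1))                   # b_e

        def forward(self, client_logits):              # (B, K, C): logits of the K clients
            e = client_logits @ self.a_e + self.b_e    # (B, K), self-attention-like score
            e = torch.softmax(e, dim=1)                # normalize so the weights sum to 1
            return e                                   # multiplied back onto the logits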
Step 3.2, if no local model can handle the tail classes well, the teacher model obtained by the weighted ensemble may still be biased toward the head classes. To solve this problem, we propose a class-based logits calibration strategy to further improve the performance of the model on the tail classes. Let the calibrated model output logits be z_cl; the calculation formula is as follows:

z_cl = a_z ⊙ φ_t(x) + b_z  (formula 5)

wherein a_z and b_z are learnable network parameters, and ⊙ denotes the Hadamard product.
Step 3.3, the premise that the above logits calibration strategy is effective is that the representation information extracted by the local models from the input data is good enough; if the local models' feature extraction is seriously affected by the long-tailed distribution, calibrating only the logits is not enough. Therefore, we need to update the feature extractor to further improve the model performance. We use the additional balanced labeled data D^b on the server side to fine-tune the global model w, obtaining a fine-tuned model ŵ. Because D^b is balanced, the fine-tuned model ŵ yields an unbiased feature extractor. Then, we can obtain the fine-tuned logits for the input x as

z_ft = φ_ŵ(x)
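A minimal sketch of this fine-tuning step on D^b is given below; the optimizer, learning rate and epoch count in the sketch are assumptions, not values fixed by this embodiment.

    # Minimal sketch of step 3.3: fine-tune the global model w on the balanced
    # labeled server set D^b to obtain an unbiased feature extractor; the logits
    # of the fine-tuned model are z_ft.
    import copy
    import torch
    import torch.nn.functional as F

    def finetune_on_balanced(global_model, balanced_loader, lr=0.1, epochs=10, device="cpu"):
        model = copy.deepcopy(global_model).to(device).train()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(epochs):
            for images, targets in balanced_loader:    # D^b is class-balanced
                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()
                F.cross_entropy(model(images), targets).backward()   # w <- w - eta * grad L
                optimizer.step()
        return model                                   # z_ft = model(x)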
Step 3.4, from the above steps it can be seen that z_cl and z_ft calibrate the teacher model at two different levels: z_cl calibrates the teacher model at the output-logits level while the model's feature extractor is fixed, whereas z_ft is the result of fine-tuning the feature extractor, thereby improving the feature extraction capability of the model. To fully combine the advantages of the two, we propose a calibration gating network to trade off z_cl and z_ft. The gating network takes the ensemble feature as input and outputs a weight through a nonlinear layer, so that each sample obtains a different weight according to its own features. The weight calculation formula is as follows:

σ = sigmoid(u^T · v)  (formula 6)

wherein v = Σ_{k=1}^{K} e_k · f_{w_k}(x) is the ensemble feature, f_{w_k}(·) is the feature extractor of the k-th client, and u ∈ R^d is a learnable network parameter. Thus, the final calibrated logits produced by the calibration gating network are z', and the calculation formula is as follows:

z' = σ · z_cl + (1 - σ) · z_ft  (formula 7)

wherein σ ∈ (0,1) is used to trade off the two logits.
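A minimal sketch combining the class-based calibration of formula 5 with the gating of formulas 6 and 7 is given below; the module name and parameter initializations are assumptions.

    # Minimal sketch of steps 3.2 and 3.4: z_cl = a_z ⊙ phi_t(x) + b_z, then
    # sigma = sigmoid(u^T v) with v = sum_k e_k f_{w_k}(x), and
    # z' = sigma * z_cl + (1 - sigma) * z_ft.
    import torch
    import torch.nn as nn

    class CalibrationGate(nn.Module):
        def __init__(self, num_classes, feat_dim):
            super().__init__()
            self.a_z = nn.Parameter(torch.ones(num_classes))     # a_z, element-wise scale
            self.b_z = nn.Parameter(torch.zeros(num_classes))    # b_z, element-wise shift
            self.u = nn.Parameter(torch.randn(feat_dim) * 0.01)  # u in R^d

        def forward(self, teacher_logits, z_ft, client_feats, client_weights):
            # Class-based calibration of the teacher logits (Hadamard product).
            z_cl = self.a_z * teacher_logits + self.b_z
            # Ensemble feature v = sum_k e_k * f_{w_k}(x), then the gate sigma.
            v = (client_weights.unsqueeze(-1) * client_feats).sum(dim=1)   # (B, d)
            sigma = torch.sigmoid(v @ self.u).unsqueeze(-1)                # (B, 1)
            return sigma * z_cl + (1.0 - sigma) * z_ft                     # z'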
Step 3.5, all learnable parameters in the whole ensemble-calibration process are updated with the cross-entropy loss on D^b, as follows:

L = - Σ_{j=1}^{C} y_j · log( exp(z'_j) / Σ_{i=1}^{C} exp(z'_i) )  (formula 8)
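For illustration, a minimal sketch of this calibration training loop is given below. It assumes each client model exposes a features(x) method returning its penultimate-layer features; that interface, like the other names, is an assumption rather than part of the claimed method. The Adam optimizer and epoch count follow step 1.2, and the weighting and gate modules are the sketches above.

    # Minimal sketch of step 3.5: update a_e, b_e, a_z, b_z and u with the
    # cross-entropy loss of the final calibrated logits z' on D^b (formula 8).
    import torch
    import torch.nn.functional as F

    def train_calibration(weighting, gate, client_models, finetuned, balanced_loader,
                          epochs=100, lr=1e-3, device="cpu"):
        params = list(weighting.parameters()) + list(gate.parameters())
        optimizer = torch.optim.Adam(params, lr=lr)
        for _ in range(epochs):
            for images, targets in balanced_loader:
                images, targets = images.to(device), targets.to(device)
                logits = torch.stack([m(images) for m in client_models], dim=1)   # (B, K, C)
                e = weighting(logits)                       # client-based weights e_k
                teacher_logits = (e.unsqueeze(-1) * logits).sum(dim=1)
                feats = torch.stack([m.features(images) for m in client_models], dim=1)
                z_prime = gate(teacher_logits, finetuned(images), feats, e)
                loss = F.cross_entropy(z_prime, targets)    # formula 8
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()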
and step four, transmitting unbiased knowledge of the teacher model to the student model by using knowledge distillation.
To better teach the unbiased knowledge of the teacher model (i.e., the calibrated ensemble model) to the student model (i.e., the global model), we train the student model using a combination of labeled-data training and unlabeled-data distillation, with a loss function comprising two parts: (1) L_CE is the cross-entropy loss between the student model's logits and the ground-truth labels; (2) L_KL is the Kullback-Leibler (KL) divergence between the teacher and student models' logits. We use the balanced labeled data set D^b to calculate L_CE and another unlabeled data set D^u to calculate L_KL, so as to further improve the knowledge distillation performance. The final loss function is traded off by the hyper-parameter λ ∈ [0,1]:

L' = (1 - λ) · L_CE + λ · L_KL  (formula 9)
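A minimal sketch of one distillation step implementing formula 9 is given below; the temperature-free KL form and the value of λ are illustrative choices, not fixed by this embodiment.

    # Minimal sketch of step 4: cross-entropy on the labeled balanced set D^b plus
    # KL distillation from the calibrated teacher logits z' on the unlabeled set D^u.
    import torch
    import torch.nn.functional as F

    def distillation_step(student, teacher_logits_fn, labeled_batch, unlabeled_images,
                          lam=0.5):
        images, targets = labeled_batch
        loss_ce = F.cross_entropy(student(images), targets)             # L_CE on D^b
        with torch.no_grad():
            t_logits = teacher_logits_fn(unlabeled_images)              # calibrated z'
        s_log_prob = F.log_softmax(student(unlabeled_images), dim=1)
        loss_kl = F.kl_div(s_log_prob, F.softmax(t_logits, dim=1),
                           reduction="batchmean")                       # L_KL on D^u
        return (1.0 - lam) * loss_ce + lam * loss_kl                    # formula 9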
Table 1 compares the accuracy (%) of the present invention with several other federated learning methods on the CIFAR-10-LT and CIFAR-100-LT data sets with imbalance ratios of 100, 50 and 10. The bolded results in the table are the optimal results for each setting.
It can be seen from the results in Table 1 that the method of the present invention can solve the joint problem of long-tailed distribution and heterogeneous data in federated learning, and it achieves the highest test accuracy at all degrees of imbalance.
TABLE 1
Table 2 compares the accuracy (%) of the present invention with several federated learning methods on the ImageNet-LT data set. The bolded results in the table are the optimal results for each index.
Table 2 compares the accuracy of the methods on three groups of classes: head classes (more than 100 samples), middle classes (between 20 and 100 samples), and tail classes (fewer than 20 samples). Compared with the other methods, the method of the present invention achieves the best results. Meanwhile, the accuracy of the method on the tail classes reaches 15.91%; the method solves the joint problem of long-tailed distribution and heterogeneous data in federated learning, and greatly improves the performance of the model on the tail classes while improving its overall performance.
TABLE 2
In tables 1 and 2:
FedAvg corresponds to the method proposed by McMahan, B et al (McMahan, B.; Moore, E.; Ramage, D.; Hampson, S.; and y Arcas, B. A. 2017. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, 1273-);
FedAvgM corresponds to the method proposed by Hsu, T.-M. H et al (Hsu, T.-M. H.; Qi, H.; and Brown, M. 2019. Measuring the effects of non-identical data distribution for federated visual classification. arXiv preprint arXiv:1909.06335.);
FedProx corresponds to the method proposed by Li, T et al (Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2020b. Federated optimization in heterogeneous networks. In Machine Learning and Systems, 429-450.); FedProx improves model accuracy and the stability of convergence by adding a proximal-term correction to the loss function of the client update;
FedNova corresponds to the method proposed by Wang, J et al (Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; and Poor, H. V. 2020b. Tackling the objective inconsistency problem in heterogeneous federated optimization. In Advances in Neural Information Processing Systems, 7611-7623.);
FedDF corresponds to the method proposed by Lin, T et al (Lin, T.; Kong, L.; Stich, S. U.; and Jaggi, M. 2020. Ensemble distillation for robust model fusion in federated learning. In Advances in Neural Information Processing Systems, 2351-);
FedBE corresponds to the method proposed by Chen, H.-Y et al (Chen, H.-Y.; and Chao, W.-L. 2021. FedBE: Making Bayesian model ensemble applicable to federated learning. In International Conference on Learning Representations.); FedBE realizes robust aggregation from the perspective of Bayesian inference by sampling higher-quality global models and combining them through a Bayesian model ensemble;
Fed-Focal Loss corresponds to the method proposed by Sarkar, D et al (Sarkar, D.; Narang, A.; and Rai, S. 2020. Fed-Focal Loss for imbalanced data classification in federated learning. arXiv preprint arXiv:2011.06283.);
Ratio Loss corresponds to the method proposed by Wang, L et al (Wang, L.; Xu, S.; Wang, X.; and Zhu, Q. 2021a. Addressing class imbalance in federated learning. In AAAI Conference on Artificial Intelligence, 10165-);
cRT, τ-norm and LWS correspond to the methods proposed by Kang, B et al (Kang, B.; Xie, S.; Rohrbach, M.; Yan, Z.; Gordo, A.; Feng, J.; and Kalantidis, Y. 2020. Decoupling representation and classifier for long-tailed recognition. In International Conference on Learning Representations.), which show that data imbalance does not prevent learning high-quality representations of the input data and that strong long-tailed recognition capability can be achieved by adjusting only the classifier.
The above examples are only intended to illustrate the technical solution of the present invention, but not to limit it; although the present invention has been described in detail with reference to the foregoing embodiments, it will be understood by those of ordinary skill in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some or all of the technical features may be equivalently replaced; and the modifications or the substitutions do not make the essence of the corresponding technical solutions depart from the scope of the technical solutions of the embodiments of the present invention.

Claims (10)

1. A federated learning method for long-tail heterogeneous data, characterized by comprising the following steps:
S1, the server side randomly initializes the global model w and sends the model parameters to the client sides; the client sides update their local models using the received model parameters and upload the updated local model parameters to the server side;
S2, the server side aggregates the local model parameters to obtain a teacher model and a student model;
S3, the server side calibrates the teacher model so that the teacher model learns on unbiased knowledge;
and S4, transmitting the unbiased knowledge of the teacher model to the student model through knowledge distillation, and then sending the student model to the clients to start the next round of federated training.
2. The federated learning method for long-tail heterogeneous data according to claim 1, wherein in step S1, the server side initializes the global model parameters w, randomly selects a set S of clients participating in the current round of training, and broadcasts the model parameters to the clients in the set S; each client in S performs stochastic gradient descent using the received global model parameters w and its local data to update its local model, the local model parameters obtained by client k being w_k; after the update, each client sends the updated model parameters back to the server side.
3. The federated learning method for long-tail heterogeneous data according to claim 2, wherein step S2 comprises the following steps:
S21, the server side averages the local model parameters to obtain a student model, and the calculation formula is as follows:

w = Σ_{k=1}^{K} (|D_k| / |D|) · w_k  (formula 1)

φ_s(x) = φ_w(x)  (formula 2)

wherein |D_k| represents the amount of data owned by the k-th client, |D| represents the total amount of data owned by all clients, K represents the number of clients, x represents the input data, φ_w(·) represents the network of the federated averaging model, and φ_s(·) represents the network of the student model;
S22, the server side performs a weighted aggregation of the local models to obtain a teacher model, and the calculation formula is as follows:

φ_t(x) = Σ_{k=1}^{K} e_k · φ_{w_k}(x)  (formula 3)

wherein φ_t(·) represents the network of the teacher model, e_k represents the weight of client k, and φ_{w_k}(·) represents the network of the k-th client.
4. The federated learning method for long-tail heterogeneous data according to claim 3, wherein in step S3, a client-based weight assignment strategy is proposed to calculate the weight e_k of each client local model, and e_k is normalized so that the weights sum to 1, giving the final weight; the calculation formula of the weight e_k is as follows:

e_k = a_e^T · φ_{w_k}(x) + b_e  (formula 4)

wherein a_e ∈ R^C and b_e represent learnable network parameters, R^C represents a C-dimensional vector, and T denotes the transpose; the weight is calculated for each local model from the model's original output and is then multiplied back onto the original output.
5. The federated learning method for long-tail heterogeneous data according to claim 4, wherein in step S3, a class-based calibration strategy for the original output is proposed, and the calibrated model output is z_cl, calculated as follows:

z_cl = a_z ⊙ φ_t(x) + b_z  (formula 5)

wherein a_z and b_z represent learnable network parameters, and ⊙ denotes the Hadamard product.
6. The federated learning method for long-tail heterogeneous data according to claim 5, wherein in step S3, an additional balanced labeled data set D^b on the server side is used to fine-tune the global model w, obtaining a fine-tuned model ŵ; the fine-tuned model output for the input data x is

z_ft = φ_ŵ(x)

wherein z_ft represents the output of the fine-tuned model for x, and φ_ŵ(·) represents the network of the fine-tuned model.
7. The federated learning method for long-tail heterogeneous data according to claim 6, wherein the fine-tuned model is obtained by

ŵ = w - η · ∇L(w; D^b)

wherein η represents the learning rate, L(·) represents the loss function, and ∇ represents the derivative.
8. The federated learning method for long-tail heterogeneous data according to claim 6, wherein in step S3, a calibration gating network is used to trade off z_cl and z_ft; the calibration gating network takes the ensemble feature as input and outputs a weight through a nonlinear layer, and the weight calculation formula is as follows:

σ = sigmoid(u^T · v)  (formula 6)

wherein v = Σ_{k=1}^{K} e_k · f_{w_k}(x) represents the ensemble feature, f_{w_k}(·) represents the feature extractor of the k-th client, u ∈ R^d represents a learnable network parameter, and R^d represents a d-dimensional vector; the final calibrated model output produced by the calibration gating network is z', and the calculation formula is as follows:

z' = σ · z_cl + (1 - σ) · z_ft  (formula 7)

wherein σ ∈ (0,1) is used to trade off the outputs z_cl and z_ft of the two models.
9. The federated learning method for long-tail heterogeneous data according to claim 8, wherein the learnable parameters of the whole ensemble-calibration process are updated with the cross-entropy loss on D^b, as follows:

L = - Σ_{j=1}^{C} y_j · log( exp(z'_j) / Σ_{i=1}^{C} exp(z'_i) )  (formula 8)

wherein C represents the number of categories, y_j represents the j-th dimension of the true label y of the input data, exp(·) represents the exponential function, z'_j represents the value of the j-th dimension of the final calibration z', and z'_i represents the value of the i-th dimension of the final calibration z'.
10. The federated learning method for long-tail heterogeneous data according to claim 1, wherein in step S4, the unbiased knowledge of the teacher model is transferred to the student model through knowledge distillation; specifically, the student model is trained with a combination of labeled-data training and unlabeled-data distillation, and the loss function is as follows:

L' = (1 - λ) · L_CE + λ · L_KL  (formula 9)

wherein L_CE represents the cross-entropy loss between the model output of the student model and the true label, L_KL represents the Kullback-Leibler divergence between the model outputs of the teacher and student models, a balanced labeled data set D^b is used to calculate L_CE, an unlabeled data set D^u is used to calculate L_KL, and λ ∈ [0,1] represents a hyper-parameter trading off L_CE and L_KL.

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115271033A (en) * 2022-07-05 2022-11-01 西南财经大学 Medical image processing model construction and processing method based on federal knowledge distillation
CN115511108A (en) * 2022-09-27 2022-12-23 河南大学 Data set distillation-based federal learning personalized method
CN115907001A (en) * 2022-11-11 2023-04-04 中南大学 Knowledge distillation-based federal diagram learning method and automatic driving method
CN116701939A (en) * 2023-06-09 2023-09-05 浙江大学 Classifier training method and device based on machine learning
CN117010534A (en) * 2023-09-27 2023-11-07 中国人民解放军总医院 Dynamic model training method, system and equipment based on annular knowledge distillation and meta federal learning
CN117236421A (en) * 2023-11-14 2023-12-15 湘江实验室 Large model training method based on federal knowledge distillation
WO2024027164A1 (en) * 2022-08-01 2024-02-08 浙江大学 Adaptive personalized federated learning method supporting heterogeneous model

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination