CN115965078A - Classification prediction model training method, classification prediction method, device and storage medium

Info

Publication number: CN115965078A
Application number: CN202211564028.9A
Authority: CN (China)
Filing date: 2022-12-07
Publication date: 2023-04-14
Legal status: Pending
Other languages: Chinese (zh)
Inventors: 唐程, 马骏, 王向阳, 刘小海, 刘彤, 吕丰
Applicants: Hunan Energy Big Data Center Co., Ltd.; Central South University

Landscapes

  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The invention discloses a classification prediction model training method, a classification prediction method, a device, and a storage medium. In the training method, each client trains a local classification prediction model with its local training data set, computes the prototypes of each class of data in the local training data set, and computes soft decisions on a public data set. The central server aggregates all prototypes and all soft decisions, constructs an optimization objective function from the aggregated prototypes and aggregated soft decisions, trains a global classification prediction model, and computes the global model's soft decisions on the public data set. Each client then trains its local classification prediction model with the received soft decisions and the public data set. When the training round reaches the set number of rounds, the trained local classification prediction models and global classification prediction model are obtained. The invention reduces the communication overhead between the server and the clients, enables personalization of the model architecture, and improves model accuracy.

Description

Classification prediction model training method, classification prediction method, device and storage medium
Technical Field
The invention belongs to the technical field of privacy-preserving computation, and particularly relates to a classification prediction model training method based on a federated knowledge distillation algorithm, a classification prediction method, an electronic device, and a storage medium.
Background
Owing to the rapid development of mobile devices (such as mobile phones, watches, and computers) and advances in sensing technology, a large amount of data is collected by mobile devices at the edge, much of it private user data such as personal pictures. With the rapid development of artificial intelligence, such private data is usually aggregated and stored in the cloud, where various intelligent applications are realized in cooperation with machine learning or deep learning models. However, when sensitive raw data is uploaded to the cloud over the network and processed there in a centralized manner, serious data privacy leakage can occur for the data owners. The concept of federated learning arose from this drive to protect data privacy and security. Unlike centralized learning, federated learning supports collaborative learning of a global model on distributed computing nodes using local data: the raw data is never sent to the cloud, and only the learned model updates are submitted to the cloud for aggregation; the updated global model on the cloud is then sent back to the distributed computing nodes for the next iteration. Through this iterative process, the global model can be learned without compromising user privacy. Besides improving data privacy, federated learning also brings many other benefits, such as improvements in security, autonomy, and efficiency.
As federated learning evolves, many new challenges arise. The two most significant challenges come from the following areas:
(1) Conventional federated learning algorithms share model parameters at each iteration, which can make the communication overhead excessive. Existing deep learning models may have millions of parameters; for example, MobileBERT, a deep learning model architecture for natural language processing tasks, has 25 million parameters, corresponding to a memory footprint of about 96 MB. Mobile devices at the edge are often bandwidth-limited, and exchanging 96 MB of information in every communication round is challenging for such devices, so many mobile devices cannot participate in federated learning tasks that require large parameter exchanges.
(2) Heterogeneity poses a significant challenge to deploying federated learning systems in real-world scenarios. On the one hand there is model heterogeneity: most mobile devices participating in a federated learning task have different computing and bandwidth resources, and many do not have enough bandwidth or computing capacity to train a large-scale deep learning model. Different participants may therefore need models of different architectures, and a federated learning architecture based on model parameter exchange cannot meet this requirement. On the other hand there is data heterogeneity: the local data distributions of the mobile devices participating in a federated learning task are globally non-independent and non-identically distributed, and simply aggregating the model parameters of the clients can hinder model convergence.
Owing to the above constraints, the global model trained by a federated learning task may not reach high accuracy in practice.
Disclosure of Invention
The invention aims to provide a classification prediction model training method based on a federated knowledge distillation algorithm, a classification prediction method, a device, and a storage medium, so as to solve the problems that the communication overhead of traditional federated learning algorithms is too high, that traditional federated learning cannot meet participants' need to use models of different architectures, and that data heterogeneity prevents model accuracy from improving.
The invention solves the above technical problems through the following technical scheme: a classification prediction model training method based on a federated knowledge distillation algorithm, comprising the following steps:
step 1: build by central server side and N clients C = { C 1 ,C 2 ,...,C i ,...,C N A federal learning system composed of (i) N.gtoreq.2, C i Representing the ith client;
Step 2: each client C_i locally constructs a labeled local training data set D_i and a local classification prediction model X_i, and sets the round counter t = 1;
Step 3: each client C_i locally performs iterative training of its local classification prediction model X_i with the local training data set D_i, and uses the trained model X_i to compute the prototype P_k^{C_i} of each class of data in D_i, where the prototype of class-k data P_k^{C_i} is the mean of the feature vectors output by the local classification prediction model X_i for the class-k data;
each client C_i locally uses the trained local classification prediction model X_i to compute the soft decision Z_P^{C_i} on the unlabeled public data set D_P, where the soft decision Z_P^{C_i} is the predicted output of the local classification prediction model X_i;
Step 4: every client sends its computed prototypes P_k^{C_i} and soft decision Z_P^{C_i} to the central server;
Step 5: the central server aggregates, for each class, all received prototypes, and aggregates all received soft decisions, obtaining the aggregated prototype of each class and the aggregated soft decision; it constructs an optimization objective function from the aggregated class prototypes and the aggregated soft decision, and iteratively trains the constructed global classification prediction model on the public data set D_P with the optimization objective function;
the trained global classification prediction model is then used to compute the soft decision Z_P^S on the unlabeled public data set D_P, where the soft decision Z_P^S is the predicted output of the global classification prediction model;
Step 6: the central server sends the soft decision Z_P^S to every client C_i;
And 7: each of the clients C i Utilizing received soft decisions
Figure BDA00039858891500000210
And a common data set D P For local classification prediction model X i Performing iterative training;
Step 8: determine whether the round t equals the set number of rounds; if so, obtain the trained local classification prediction models X_i and global classification prediction model; otherwise, set t = t + 1 and jump to Step 3.
Further, the local classification prediction model and the global classification prediction model both adopt a deep residual network model.
Further, for client C_i, the prototype P_k^{C_i} of class-k data is computed as:

$$P_k^{C_i} = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k, R_w(·) denotes the input-layer and hidden-layer network of the local classification prediction model X_i, (x_j, y_j) ∈ D_k ranges over all data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j.
Further, the aggregation formula for all prototypes of class k is:

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where N_k denotes the number of clients owning a prototype of class k and P_k denotes the aggregated prototype of class k;

the aggregation formula for all soft decisions is:

$$\bar{Z}_P = \frac{1}{N} \sum_{i=1}^{N} Z_P^{C_i}$$

where \bar{Z}_P denotes the aggregated soft decision.
Further, the optimization objective function constructed from the aggregated class prototypes and the aggregated soft decision is:

$$\min_{w} \frac{1}{|D_P|} \sum_{(x_j, k) \in D_P} \left[ L_2\left( \sum_{C_i \in C_N} \alpha_i Z_{x_j}^{C_i},\ Z_{x_j}^{S} \right) + \lambda\, L_M\left( P_k,\ R_w(x_j) \right) \right]$$

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where (x_j, k) ∈ D_P denotes the unlabeled public data set D_P, with k the pseudo-label of sample x_j determined by the distribution of the aggregated soft decisions; L_2(·) is the relative entropy loss function; C_i ∈ C_N denotes all clients; α_i denotes the weight of client C_i's soft decision Z_{x_j}^{C_i}; Z_{x_j}^{S} denotes the soft decision of the global classification prediction model; L_M(·) denotes the root-mean-square loss function; N_k denotes the number of clients owning a prototype of class k; λ denotes a hyper-parameter; P_k^{C_i} denotes client C_i's prototype of class-k data; R_w(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
Based on the same inventive concept, the invention also provides a classification prediction method based on a classification prediction model, wherein the classification prediction model comprises a global classification prediction model and N local classification prediction models obtained by training with the above classification prediction model training method based on the federated knowledge distillation algorithm. The classification prediction method comprises the following steps:
acquiring data to be classified;
and carrying out classification prediction on the data to be classified by utilizing the classification prediction model to obtain the classification of the data to be classified.
Based on the same inventive concept, the present invention also provides an electronic device, the device comprising:
a memory for storing a computer program;
a processor for implementing, when executing the computer program, the steps of any one of the above methods for training a classification prediction model based on a federated knowledge distillation algorithm, or the steps of the above classification prediction method based on a classification prediction model.
Based on the same inventive concept, the present invention further provides a computer-readable storage medium on which a computer program is stored; when executed by a processor, the computer program implements the steps of the above method for training a classification prediction model based on the federated knowledge distillation algorithm, or the steps of the above classification prediction method based on a classification prediction model.
Advantageous effects
Compared with the prior art, the invention has the advantages that:
according to the classification prediction model training method, the classification prediction method, the electronic device and the storage medium, both the private data of the client and the local classification prediction model are stored in the local of the client, so that the privacy safety of the private data is ensured; the traditional federal learning based on model parameter interaction is improved to be based on model output soft decision interaction by knowledge distillation, so that the communication overhead between a server and a client is greatly reduced, the client and the server are allowed to select a model with a proper architecture according to own bandwidth resources and computing resources, and the individuation of the model architecture is realized.
Meanwhile, through the prototype network, the invention also alleviates the difficulty of improving model accuracy caused by the high heterogeneity of the clients' private data, thereby greatly improving model accuracy.
Drawings
In order to illustrate the technical solution of the present invention more clearly, the drawings needed in the description of the embodiments are briefly introduced below. Obviously, the drawings in the following description show only some embodiments of the present invention, and those skilled in the art can obtain other drawings based on these drawings without creative effort.
FIG. 1 is a flow chart of knowledge distillation in an embodiment of the present invention;
FIG. 2 is a flow chart of the federated knowledge distillation algorithm in an embodiment of the present invention;
FIG. 3 is a flow chart of the method for training a classification prediction model based on the federated knowledge distillation algorithm in an embodiment of the present invention.
Detailed Description
The technical solutions in the present invention are clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are only a part of the embodiments of the present invention, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
The technical solution of the present application will be described in detail below with specific examples. The following several specific embodiments may be combined with each other, and details of the same or similar concepts or processes may not be repeated in some embodiments.
To solve the problems that traditional federated learning algorithms based on model parameter exchange incur large communication overhead, that they cannot meet participants' need to use models of different architectures, and that data heterogeneity prevents model accuracy from improving, knowledge distillation and the prototype network are applied to federated learning.
Knowledge distillation is a method used for model compression in machine learning; its purpose is to inject the knowledge of a pre-trained teacher model into an untrained student model. Knowledge distillation differs from standard model training: standard training tries to match the model's predicted output to the true label of each sample (e.g., [cat, dog] = [0, 1]), whereas knowledge distillation tries to match the student model's predicted output to the teacher model's predicted output, i.e., its logits (soft decisions), e.g., [cat, dog] = [0.3, 0.7]. These soft decisions contain more information than the true label values and can train the student model faster than standard training.
Consider a simple neural network comprising an input layer, a hidden layer, and an output layer. Let F_w(·) denote the entire network, R_w(·) the input-layer and hidden-layer network, and O_w(·) the output-layer network. In a multi-class problem, for a sample x_j the output layer finally outputs a classification prediction

$$Z_{x_j} = O_w(R_w(x_j))$$

In knowledge distillation, the student model updates its own network weights by minimizing a loss function together with a distillation regularizer; the optimization goal of the student model is:

$$\min_{w} \frac{1}{|D|} \sum_{(x_j, y_j) \in D} \left[ L_1\left( F_w(x_j),\ y_j \right) + \lambda\, L_2\left( Z_{x_j}^{T},\ Z_{x_j}^{S} \right) \right]$$

where (x_j, y_j) ∈ D denotes the public data set, λ denotes a fixed hyper-parameter, y_j denotes the true label of sample x_j, L_1(·) denotes the cross-entropy loss function, L_2(·) denotes the relative entropy loss function, Z_{x_j}^{T} denotes the output logits (soft decisions) of the teacher model, and Z_{x_j}^{S} denotes the output logits (soft decisions) of the student model. The flow is shown in FIG. 1.
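For illustration only (not part of the patent text), a minimal PyTorch sketch of this student objective follows; the temperature T, the weighting lam, and all names are assumptions of the example, with F.kl_div standing in for the relative entropy loss L_2 and F.cross_entropy for L_1:

```python
import torch.nn.functional as F

def student_distillation_loss(student_logits, teacher_logits, labels,
                              lam=0.5, T=1.0):
    """L1(F_w(x), y) + lambda * L2(Z^T, Z^S): cross-entropy on the true
    labels plus KL divergence toward the teacher's soft decisions."""
    hard_loss = F.cross_entropy(student_logits, labels)            # L1
    soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                         F.softmax(teacher_logits / T, dim=1),
                         reduction="batchmean")                    # L2
    return hard_loss + lam * soft_loss
```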
Building on the knowledge distillation described above, knowledge distillation is applied to federated learning. The server and the clients then exchange, within the federated learning framework, not model parameters but the soft decisions of their models, and a public data set is deployed globally for transferring model knowledge; because of federated learning's requirement of data privacy protection, the public data set is unlabeled. The basic flow of the federated knowledge distillation algorithm is as follows:
(1) Each client locally trains its model with its private data and, at each communication round, uploads its model's logits (soft decisions) on the public data set to the server;
(2) After receiving the logits (soft decisions) uploaded by all clients participating in federated learning, the server aggregates them to obtain the final logits (soft decisions); the server model then acts as the student model and all client models act as teacher models for knowledge distillation training;
(3) After the knowledge distillation training is finished, the server sends the logits (soft decisions) of the server model (student model) on the public data set to each client;
(4) After a client receives the logits (soft decisions) from the server, the client model acts as the student model and the server model as the teacher model, and knowledge distillation training is carried out again; after this training is finished, the next round of federated learning starts from step (1), until global convergence.
Since the public data set is unlabeled, the optimization objective of the server model during knowledge distillation training becomes:

$$\min_{w} \frac{1}{|D_P|} \sum_{x_j \in D_P} L_2\left( \sum_{C_i \in C_N} \alpha_i Z_{x_j}^{C_i},\ Z_{x_j}^{S} \right)$$

where x_j ∈ D_P denotes the unlabeled public data set, C_i ∈ C_N denotes all clients participating in federated learning, Z_{x_j}^{C_i} is the soft decision of client C_i's model, α_i is the weight of client C_i's soft decision, Z_{x_j}^{S} is the soft decision of the server model, and L_2(·) is the relative entropy loss function. The flow is shown in FIG. 2.
Applying knowledge distillation to federated learning means that the models exchange their output logits rather than model parameters, which greatly reduces the communication overhead. Moreover, once knowledge distillation is used, the server and the clients (and the clients among themselves) no longer need to use models of the same architecture; this greatly increases the personalization capability of the federated learning framework, since each device (client or server) can choose a model of a different architecture according to its own conditions.
However, when knowledge distillation is used in a federated learning framework, the distribution of the clients' private data is typically non-iid (not independent and identically distributed), which slows the convergence of model training and makes it hard to improve model accuracy. The reason is that, when the clients' private data is highly heterogeneous, the output logits (soft decisions) of the client models become over-confident, so the server's aggregated logits (soft decisions) contain wrong knowledge; training the server model with these logits inevitably leads to slow convergence and low model accuracy.
For example: two clients participate in federated learning; the private data of client A contains image data of dogs, cats, and airplanes, while the private data of client B contains image data of cats, frogs, and airplanes. After clients A and B each train their model with their private data, suppose an image of a dog appears in the public data set. Since client A's private data set contains image data of dogs, the distribution of client A's output logits (soft decisions) for this image is correct and leans toward the dog class. However, since client B's private data set has no image data of dogs but does have cats, client B's model will very likely mispredict the dog image as a cat, and the distribution of its output logits (soft decisions) will be wrongly biased toward the cat class. When the output logits of clients A and B are aggregated, the aggregated distribution may be nearly uniform between cat and dog, which slows the convergence of server model training and prevents the model accuracy from improving.
To address this problem, the present invention proposes using a prototype network to mitigate the phenomenon. A prototype is the average representation of the embedding vectors of one class of data in feature space; it is an abstract feature representation of that class. Consider a simple neural network comprising an input layer, a hidden layer, and an output layer. Let F_w(·) denote the entire network, R_w(·) the input-layer and hidden-layer network, and O_w(·) the output-layer network. The prototype of class k can then be expressed as

$$P_k = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k and (x_j, y_j) ∈ D_k ranges over all data in D_k.
The prototype network can distinguish the differences between classes in feature space, and this is exploited in federated learning: the server aggregates the class prototypes built from the clients' private data. Although the clients' prototypes differ slightly because of the heterogeneity of the data distributions, by learning these abstract feature representations the server obtains a more abstract characterization of each class and alleviates the training problems caused by over-confident logits (soft decisions).
Accordingly, the method for training a classification prediction model based on the federated knowledge distillation algorithm provided by the embodiment of the invention comprises the following steps:
Step 1: construct a federated learning system consisting of a central server and N clients C = {C_1, C_2, ..., C_i, ..., C_N}, where N ≥ 2 and C_i denotes the i-th client.
In this embodiment, N = 100, and a client may be a watch, a mobile phone, a computer, an iPad, or the like.
Step 2: each client C_i locally constructs a labeled local training data set D_i and a local classification prediction model X_i, and sets the round counter t = 1.
The local training data set D_i is the data set formed from client C_i's private data (e.g., image data), and the local classification prediction model X_i is the classification prediction model constructed on client C_i. Each client has its own local training data set and local classification prediction model, and the local models of different clients may have the same architecture or different architectures. These clients participate in federated learning in order to train a powerful server model for classification tasks, such as image classification, while protecting their private data.
In this embodiment, the local classification prediction model X_i adopts a deep residual network model, such as the ResNet11 model.
Step 3: each client C_i locally performs iterative training of its local classification prediction model X_i with the local training data set D_i, and uses the trained model X_i to compute the prototype P_k^{C_i} of each class of data in D_i, where the prototype of class-k data P_k^{C_i} is the mean of the feature vectors output by the local classification prediction model X_i for the class-k data.
Training the local classification prediction model X_i with the local training data set D_i at each client C_i is a standard training process. For client C_i, the prototype P_k^{C_i} of class-k data is computed as:

$$P_k^{C_i} = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k, R_w(·) denotes the input-layer and hidden-layer network of the local classification prediction model X_i, (x_j, y_j) ∈ D_k ranges over all data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j. The prototype is the mean of the output feature vectors of one class of image data. For example, if client 1 has one hundred images of dogs and one hundred images of cats, the 100 dog images are each fed into the local classification prediction model, yielding 100 output feature vectors of R_w(·); averaging these 100 feature vectors gives the dog prototype, and the prototypes of the other classes are obtained in the same way.
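As an illustrative sketch only (function names and data layout are assumptions), the per-class prototype computation could be written as:

```python
import torch

@torch.no_grad()
def class_prototypes(feature_extractor, data_loader):
    """Average the hidden-layer features R_w(x) per class: one prototype
    P_k per class k present in the client's local data."""
    sums, counts = {}, {}
    for x, y in data_loader:                  # labeled local data (x_j, y_j)
        feats = feature_extractor(x)          # R_w(x_j), shape [batch, d]
        for f, k in zip(feats, y.tolist()):
            sums[k] = sums.get(k, 0.0) + f
            counts[k] = counts.get(k, 0) + 1
    return {k: sums[k] / counts[k] for k in sums}
```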
Each client C_i locally uses the trained local classification prediction model X_i to compute the soft decision Z_P^{C_i} on the unlabeled public data set D_P, where the soft decision Z_P^{C_i} is the predicted output of the local classification prediction model X_i.
Each sample x_j in the public data set D_P is input to the local classification prediction model X_i of client C_i, yielding client C_i's soft decision Z_P^{C_i} on the public data set D_P, i.e., the predicted output of the local classification prediction model X_i for the samples in D_P.
In this embodiment, the public data set D_P is the CIFAR-10 color image data set of common objects. The number of target classes of this data set is M = 10: airplanes, cars, birds, cats, deer, dogs, frogs, horses, boats, and trucks. Each image has a pixel size of 32 × 32, and each object class has 6000 images.
An image in the public data set D_P serves as the input of the local classification prediction model X_i, and the output of X_i is its prediction for this image: an array of ten entries, each entry representing the model's predicted probability for one of the ten classes. For example, given an image of a dog, the model might output an array [..., 0.1, 0.2, 0.5, ...], where each number in the array is the model's predicted probability for one class.
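Purely as an illustration (the loader and model names are assumed, and the loader is assumed to yield unlabeled batches), collecting a client's soft decisions over the public set might look like:

```python
import torch

@torch.no_grad()
def public_soft_decisions(model, public_loader):
    """Predicted outputs (soft decisions Z_P) of a trained model over the
    unlabeled public data set D_P, stacked into one [|D_P|, M] tensor."""
    model.eval()
    return torch.cat([model(x) for x in public_loader], dim=0)
```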
Step 4: every client sends its computed prototypes P_k^{C_i} and soft decision Z_P^{C_i} to the central server.
Step 5: the central server aggregates, for each class, all received prototypes, and aggregates all received soft decisions, obtaining the aggregated prototype of each class and the aggregated soft decision; it constructs an optimization objective function from the aggregated class prototypes and the aggregated soft decision, and iteratively trains the constructed global classification prediction model on the public data set D_P with the optimization objective function.
Aggregation can be average aggregation or weighted aggregation; average aggregation is adopted in this embodiment. The aggregation formula for all prototypes of class k is:

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where N_k denotes the number of clients owning a prototype of class k and P_k denotes the aggregated prototype of class k.
The aggregation formula for all soft decisions is:

$$\bar{Z}_P = \frac{1}{N} \sum_{i=1}^{N} Z_P^{C_i}$$

where \bar{Z}_P denotes the aggregated soft decision.
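A small sketch of this average aggregation (illustrative only; the container layout is an assumption of the example) follows:

```python
import torch

def aggregate(prototypes_per_client, soft_decisions_per_client):
    """Average aggregation: per-class mean over the N_k clients owning a
    class-k prototype, and elementwise mean of all client soft decisions."""
    grouped = {}
    for protos in prototypes_per_client:       # each: {k: P_k^{C_i}}
        for k, p in protos.items():
            grouped.setdefault(k, []).append(p)
    agg_protos = {k: torch.stack(v).mean(dim=0) for k, v in grouped.items()}
    agg_soft = torch.stack(soft_decisions_per_client).mean(dim=0)
    return agg_protos, agg_soft
```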
The optimization objective function constructed from the aggregated class prototypes and the aggregated soft decision is:

$$\min_{w} \frac{1}{|D_P|} \sum_{(x_j, k) \in D_P} \left[ L_2\left( \sum_{C_i \in C_N} \alpha_i Z_{x_j}^{C_i},\ Z_{x_j}^{S} \right) + \lambda\, L_M\left( P_k,\ R_w(x_j) \right) \right]$$

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where (x_j, k) ∈ D_P denotes the unlabeled public data set D_P, with k the pseudo-label of sample x_j determined by the distribution of the aggregated soft decisions; L_2(·) is the relative entropy loss function; C_i ∈ C_N denotes all clients; α_i denotes the weight of client C_i's soft decision Z_{x_j}^{C_i}; Z_{x_j}^{S} denotes the soft decision of the global classification prediction model; L_M(·) denotes the root-mean-square loss function; N_k denotes the number of clients owning a prototype of class k; λ denotes a hyper-parameter; P_k^{C_i} denotes client C_i's prototype of class-k data; R_w(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
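For illustration only, a sketch of this objective under the stated definitions; mean-squared error is used here as a stand-in for the root-mean-square loss L_M, and all names are assumptions of the example:

```python
import torch
import torch.nn.functional as F

def global_objective(client_logits_list, alphas, server_logits,
                     hidden_feats, pseudo_labels, agg_protos, lam=1.0):
    """Distillation term L2 plus the prototype regularizer lambda * L_M:
    the global model's hidden features R_w(x_j) are pulled toward the
    aggregated prototype P_k of each sample's pseudo-label k."""
    teacher = sum(a * z for a, z in zip(alphas, client_logits_list))
    distill = F.kl_div(F.log_softmax(server_logits, dim=1),
                       F.softmax(teacher, dim=1), reduction="batchmean")
    targets = torch.stack([agg_protos[int(k)] for k in pseudo_labels])
    proto_reg = F.mse_loss(hidden_feats, targets)   # stand-in for L_M
    return distill + lam * proto_reg
```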
In this embodiment, the global classification prediction model uses a deep residual network model, such as a ResNet56 model.
The trained global classification prediction model is then used to compute the soft decision Z_P^S on the unlabeled public data set D_P, where the soft decision Z_P^S is the predicted output of the global classification prediction model.
Step 6: the central server sends the soft decision Z_P^S to every client C_i.
Step 7: each client C_i performs iterative training of its local classification prediction model X_i with the received soft decision Z_P^S and the public data set D_P.
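An illustrative sketch of one such local update (names assumed; the global model's soft decisions act as the teacher):

```python
import torch.nn.functional as F

def client_distillation_step(local_model, optimizer, x_public, server_logits):
    """One training step on the public set: the client (student) matches
    the global model's soft decisions Z_P^S via the relative entropy loss."""
    optimizer.zero_grad()
    student_logits = local_model(x_public)
    loss = F.kl_div(F.log_softmax(student_logits, dim=1),
                    F.softmax(server_logits, dim=1),
                    reduction="batchmean")
    loss.backward()
    optimizer.step()
    return loss.item()
```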
Step 8: determine whether the round t equals the set number of rounds; if so, obtain the trained local classification prediction models X_i and global classification prediction model; otherwise, set t = t + 1 and jump to Step 3.
Based on the same inventive concept, the invention also provides a classification prediction method based on a classification prediction model, wherein the classification prediction model comprises a global classification prediction model and N local classification prediction models obtained by training with the above classification prediction model training method based on the federated knowledge distillation algorithm. The classification prediction method comprises the following steps:
step 1: acquiring data to be classified;
step 2: and carrying out classification prediction on the data to be classified by using the classification prediction model to obtain the classification of the data to be classified.
In the invention, both the client's private data and its local classification prediction model are kept local to the client, ensuring the privacy and security of the private data. Knowledge distillation turns traditional federated learning based on model parameter exchange into federated learning based on exchanging the models' output soft decisions; this greatly reduces the communication overhead between server and clients and allows each client and the server to choose a model architecture suited to its own bandwidth and computing resources, realizing personalization of the model architecture.
Meanwhile, through the prototype network, the invention also alleviates the difficulty of improving model accuracy caused by the high heterogeneity of the clients' private data, greatly improving model accuracy; a federated learning framework using this method is stable and efficient.
The above disclosure is only for the specific embodiments of the present invention, but the scope of the present invention is not limited thereto, and any person skilled in the art can easily conceive of changes or modifications within the technical scope of the present invention, and shall be covered by the scope of the present invention.

Claims (8)

1. A classification prediction model training method based on a federated knowledge distillation algorithm, characterized by comprising the following steps:
Step 1: construct a federated learning system consisting of a central server and N clients C = {C_1, C_2, ..., C_i, ..., C_N}, where N ≥ 2 and C_i denotes the i-th client;
Step 2: each client C_i locally constructs a labeled local training data set D_i and a local classification prediction model X_i, and sets the round counter t = 1;
Step 3: each client C_i locally performs iterative training of its local classification prediction model X_i with the local training data set D_i, and uses the trained model X_i to compute the prototype P_k^{C_i} of each class of data in D_i, where the prototype of class-k data P_k^{C_i} is the mean of the feature vectors output by the local classification prediction model X_i for the class-k data;
each client C_i locally uses the trained local classification prediction model X_i to compute the soft decision Z_P^{C_i} on the unlabeled public data set D_P, where the soft decision Z_P^{C_i} is the predicted output of the local classification prediction model X_i;
Step 4: every client sends its computed prototypes P_k^{C_i} and soft decision Z_P^{C_i} to the central server;
Step 5: the central server aggregates, for each class, all received prototypes, and aggregates all received soft decisions, obtaining the aggregated prototype of each class and the aggregated soft decision; it constructs an optimization objective function from the aggregated class prototypes and the aggregated soft decision, and iteratively trains the constructed global classification prediction model on the public data set D_P with the optimization objective function;
the trained global classification prediction model is then used to compute the soft decision Z_P^S on the unlabeled public data set D_P, where the soft decision Z_P^S is the predicted output of the global classification prediction model;
Step 6: the central server sends the soft decision Z_P^S to every client C_i;
Step 7: each client C_i performs iterative training of its local classification prediction model X_i with the received soft decision Z_P^S and the public data set D_P;
Step 8: determine whether the round t equals the set number of rounds; if so, obtain the trained local classification prediction models X_i and global classification prediction model; otherwise, set t = t + 1 and jump to Step 3.
2. The method of claim 1, wherein the local classification prediction model and the global classification prediction model both employ a deep residual network model.
3. The classification prediction model training method based on a federated knowledge distillation algorithm of claim 1, wherein for client C_i the prototype P_k^{C_i} of class-k data is computed as:

$$P_k^{C_i} = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k, R_w(·) denotes the input-layer and hidden-layer network of the local classification prediction model X_i, (x_j, y_j) ∈ D_k ranges over all data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j.
4. The classification prediction model training method based on a federated knowledge distillation algorithm of claim 1, wherein the aggregation formula for all prototypes of class k is:

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where N_k denotes the number of clients owning a prototype of class k and P_k denotes the aggregated prototype of class k;

the aggregation formula for all soft decisions is:

$$\bar{Z}_P = \frac{1}{N} \sum_{i=1}^{N} Z_P^{C_i}$$

where \bar{Z}_P denotes the aggregated soft decision.
5. The classification prediction model training method based on a federated knowledge distillation algorithm of any one of claims 1 to 4, wherein the optimization objective function constructed from the aggregated class prototypes and the aggregated soft decision is:

$$\min_{w} \frac{1}{|D_P|} \sum_{(x_j, k) \in D_P} \left[ L_2\left( \sum_{C_i \in C_N} \alpha_i Z_{x_j}^{C_i},\ Z_{x_j}^{S} \right) + \lambda\, L_M\left( P_k,\ R_w(x_j) \right) \right]$$

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^{C_i}$$

where (x_j, k) ∈ D_P denotes the unlabeled public data set D_P, with k the pseudo-label of sample x_j determined by the distribution of the aggregated soft decisions; L_2(·) is the relative entropy loss function; C_i ∈ C_N denotes all clients; α_i denotes the weight of client C_i's soft decision Z_{x_j}^{C_i}; Z_{x_j}^{S} denotes the soft decision of the global classification prediction model; L_M(·) denotes the root-mean-square loss function; N_k denotes the number of clients owning a prototype of class k; λ denotes a hyper-parameter; P_k^{C_i} denotes client C_i's prototype of class-k data; R_w(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
6. A classification prediction method based on a classification prediction model, characterized in that the classification prediction model comprises a global classification prediction model and N local classification prediction models obtained by training with the classification prediction model training method based on the federated knowledge distillation algorithm of any one of claims 1 to 5, and the classification prediction method comprises the following steps:
acquiring data to be classified;
and carrying out classification prediction on the data to be classified by using the classification prediction model to obtain the classification of the data to be classified.
7. An electronic device, characterized in that the device comprises:
a memory for storing a computer program;
a processor for implementing, when executing the computer program, the steps of the classification prediction model training method based on a federated knowledge distillation algorithm of any one of claims 1 to 5, or the steps of the classification prediction method based on a classification prediction model of claim 6.
8. A computer-readable storage medium having stored thereon a computer program, characterized in that the computer program, when executed by a processor, implements the steps of the classification prediction model training method based on the federated knowledge distillation algorithm of any one of claims 1 to 5, or the steps of the classification prediction method based on a classification prediction model of claim 6.
