CN115965078A - Classification prediction model training method, classification prediction method, device and storage medium - Google Patents
Classification prediction model training method, classification prediction method, device and storage medium
- Publication number
- CN115965078A (application CN202211564028.9A)
- Authority
- CN
- China
- Prior art keywords
- classification prediction
- prediction model
- local
- model
- representing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
The invention discloses a classification prediction model training method, a classification prediction method, a device, and a storage medium. In the training method, each client trains a local classification prediction model with its local training data set, calculates the prototypes of each class of data in that set, and calculates soft decisions on a public data set; the central server aggregates all prototypes and all soft decisions, constructs an optimization objective function from the aggregated prototypes and aggregated soft decisions, trains a global classification prediction model, and calculates the global model's soft decisions on the public data set; each client then trains its local classification prediction model with the received soft decisions and the public data set. When the training round reaches the set number of rounds, the trained local classification prediction models and the global classification prediction model are obtained. The invention reduces the communication overhead between server and clients, realizes personalization of the model architecture, and improves model accuracy.
Description
Technical Field
The invention belongs to the technical field of privacy computation, and particularly relates to a classification prediction model training method based on a federated knowledge distillation algorithm, a classification prediction method, an electronic device, and a storage medium.
Background
Due to the rapid development of mobile devices (such as phones, watches, and computers) and advances in sensing technology, a large amount of data, much of it users' private data such as personal pictures, is collected by mobile devices at the edge. With the rapid development of artificial intelligence, such private data is usually aggregated and stored in the cloud, where machine learning or deep learning models realize various intelligent applications. However, uploading sensitive raw data to the cloud over the network and processing private data centrally in the cloud creates serious data privacy risks for the data contributors, and the concept of federated learning arose from this drive to protect data privacy and security. Unlike the centralized learning mode, federated learning lets distributed computing nodes collaboratively learn a global model using local data: the raw data is never sent to the cloud, and only the learned model updates are submitted to the cloud for aggregation; the global model on the cloud is then updated and sent back to the distributed computing nodes for the next iteration. Through this iterative mode, a global model can be learned without compromising user privacy. Besides mitigating the data privacy problem, federated learning brings many other benefits, such as improved security, autonomy, and efficiency.
As federated learning evolves, new challenges also arise. The most significant challenges come from two areas:
(1) Conventional federated learning algorithms share model parameters at each iteration, so the communication overhead can be excessive. Modern deep learning models may have millions of parameters: for example, MobileBERT, a deep learning model for natural language processing tasks, has 25 million parameters, corresponding to about 96 MB of memory. Mobile devices at the edge are often bandwidth-limited, and exchanging 96 MB of information in every communication round is challenging for them, so many mobile devices cannot participate in federated learning tasks that require large parameter exchanges.
(2) Heterogeneity poses a significant challenge to deploying federated learning systems in real-world scenarios. On one hand, there is model heterogeneity: most mobile devices participating in a federated learning task have different computing and bandwidth resources, and many lack the bandwidth or compute to train a large-scale deep learning model, so different participants may need to train models of different architectures; a federated learning architecture based on exchanging model parameters cannot satisfy this need. On the other hand, there is data heterogeneity: the local data distribution of each participating device is globally non-IID (not independent and identically distributed), and naively aggregating the clients' model parameters can hinder model convergence.
Because of these constraints, the global model trained by a federated learning task may not achieve high accuracy in practice.
Disclosure of Invention
The invention aims to provide a classification prediction model training method based on a federated knowledge distillation algorithm, together with a classification prediction method, a device, and a storage medium, to solve three problems of traditional federated learning algorithms: excessive communication overhead, inability to satisfy participants' need to use models of different architectures, and model accuracy that cannot be improved because of data heterogeneity.
The invention solves these technical problems through the following technical scheme. A classification prediction model training method based on a federated knowledge distillation algorithm comprises the following steps:
step 1: build by central server side and N clients C = { C 1 ,C 2 ,...,C i ,...,C N A federal learning system composed of (i) N.gtoreq.2, C i Representing the ith client;
and 2, step: each of the clients C i All locally constructing a local training data set D with labels i And local classification prediction model X i And let cycle t =1;
and step 3: each of the clients C i Locally using a local training data set D i For local classification prediction model X i Performing iterative training and using the trained local classification prediction model X i Calculate the local training data set D i Prototype of various kinds of dataWherein a prototype of k-like data +>Local classification prediction model X referring to class k data i Outputting the average value of the feature vectors;
each of the clients C i Locally utilizing the trained local classification prediction model X i Calculating public data set D without annotations P Soft decision ofWherein the soft decision->Refers to a local classification prediction model X i The predicted output of (2);
and 4, step 4: all clients C will count eachCalculated prototypeAnd a soft decision->Sending the data to the central server;
and 5: the central server side respectively aggregates all the received prototypes and all the soft decisions of each type to obtain various aggregated prototypes and aggregated soft decisions; constructing an optimization objective function by utilizing various types of prototypes after aggregation and soft decisions after aggregation, and utilizing a common data set D P The optimization objective function carries out iterative training on the constructed global classification prediction model;
calculating a public data set D without annotations by using the trained global classification prediction model P Soft decision ofWherein the soft decision->The prediction output of the global classification prediction model is referred to;
And 7: each of the clients C i Utilizing received soft decisionsAnd a common data set D P For local classification prediction model X i Performing iterative training;
and step 8: judging whether the cycle t is equal to the set cycle, if so, obtaining each trained local classification prediction model X i And a global classification prediction model; otherwise, let t = t +1, and jump to step3。
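As a concrete illustration, the eight steps above can be sketched as a single training loop. This is a minimal sketch under assumed interfaces: method names such as `train_on_local_data`, `class_prototypes`, `soft_decisions`, `aggregate`, and `distill_from` are hypothetical stand-ins for the client and server routines the steps describe, not an API defined by the patent.

```python
# Minimal sketch of the federated knowledge-distillation loop (Steps 1-8).
# Every method name here is a hypothetical placeholder for the client/server
# routines the steps describe.

def federated_kd_training(clients, server, public_data, total_rounds):
    for t in range(1, total_rounds + 1):                  # Steps 2/8: round counter
        prototypes, soft_decisions = [], []
        for client in clients:                            # Step 3: local phase
            client.train_on_local_data()
            prototypes.append(client.class_prototypes())
            soft_decisions.append(client.soft_decisions(public_data))
        server.aggregate(prototypes, soft_decisions)      # Steps 4-5: upload + aggregate
        server.train_global_model(public_data)            # Step 5: train global model
        global_soft = server.soft_decisions(public_data)  # Step 6: global soft decisions
        for client in clients:                            # Step 7: distill back to clients
            client.distill_from(global_soft, public_data)
    # Step 8: after the set number of rounds, return all trained models
    return [c.model for c in clients], server.model
```

Note that only prototypes and soft decisions cross the network in this loop; model parameters never leave a device, which is the source of the communication savings claimed above.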
Further, the local classification prediction model and the global classification prediction model both adopt a deep residual network model.
Further, for client C_i, the prototype P_k^i of the class-k data is calculated as:

$$P_k^i = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k, R_w(·) denotes the input-layer and hidden-layer network of the local classification prediction model X_i, (x_j, y_j) ∈ D_k denotes all data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j.
Further, the aggregation formula for aggregating all class-k prototypes is:

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^i$$

where N_k denotes the number of clients owning a class-k prototype, and P_k denotes the aggregated class-k prototype;
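The class-k aggregation amounts to averaging, over the N_k clients that own class-k data, their uploaded class-k prototypes. The sketch below illustrates this under an assumed data layout (a dict per client mapping class index to feature vector); the function name and layout are illustrative, not from the patent.

```python
# Sketch of server-side prototype aggregation: the aggregated class-k prototype
# P_k is the plain average of the class-k prototypes uploaded by the N_k clients
# that own class-k data. The dict-of-lists layout is an assumption for illustration.

def aggregate_prototypes(client_prototypes):
    """client_prototypes: list of {class_index: feature_vector} dicts, one per client."""
    sums, counts = {}, {}
    for protos in client_prototypes:
        for k, vec in protos.items():
            if k not in sums:
                sums[k] = [0.0] * len(vec)
                counts[k] = 0
            sums[k] = [s + v for s, v in zip(sums[k], vec)]
            counts[k] += 1                      # counts[k] is N_k
    return {k: [s / counts[k] for s in vec] for k, vec in sums.items()}
```

A client that owns no class-k data simply contributes nothing to P_k, which is why the divisor is N_k rather than the total client count N.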
the aggregation formula for aggregating all soft decisions is:
Further, the specific expression of the optimization objective function constructed from the aggregated prototypes and the aggregated soft decisions is:

$$\min_w \sum_{(x_j, k) \in D_P} \Big[ \sum_{C_i \in C_N} \alpha_i\, L_2\big(F_w(x_j),\, \hat{Y}_j^i\big) + \lambda\, \frac{1}{N_k} \sum_{i=1}^{N_k} L_M\big(R_w(x_j),\, P_k^i\big) \Big]$$

where (x_j, k) ∈ D_P denotes the unlabeled public data set D_P, with k the class of sample x_j as determined by the aggregated distribution Ȳ_j of all soft decisions; L_2(·) is the relative entropy loss function; C_i ∈ C_N denotes all clients; α_i denotes the weight of client C_i's soft decision Ŷ_j^i; F_w(x_j) denotes the soft decision of the global classification prediction model; L_M(·) denotes the root-mean-square loss function; N_k denotes the number of clients owning a class-k prototype; λ denotes a hyper-parameter; P_k^i denotes client C_i's class-k prototype; R_w(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
Based on the same inventive concept, the invention also provides a classification prediction method based on a classification prediction model, wherein the classification prediction model comprises a global classification prediction model and N local classification prediction models obtained by training with the above classification prediction model training method based on the federated knowledge distillation algorithm; the classification prediction method comprises the following steps:
acquiring data to be classified;
performing classification prediction on the data to be classified with the classification prediction model to obtain the class of the data to be classified.
Based on the same inventive concept, the present invention also provides an electronic device, the device comprising:
a memory for storing a computer program;
a processor for implementing, when executing the computer program, the steps of any one of the above methods for training a classification prediction model based on a federated knowledge distillation algorithm, or the steps of the above classification prediction method based on a classification prediction model.
Based on the same inventive concept, the invention further provides a computer-readable storage medium on which a computer program is stored; when executed by a processor, the computer program implements the steps of the above method for training a classification prediction model based on the federated knowledge distillation algorithm, or the steps of the above classification prediction method based on a classification prediction model.
Advantageous effects
Compared with the prior art, the invention has the following advantages:
In the classification prediction model training method, classification prediction method, electronic device, and storage medium of the invention, both the clients' private data and the local classification prediction models are stored locally at the clients, ensuring the privacy and security of the private data. Knowledge distillation upgrades traditional federated learning from model-parameter interaction to interaction based on the models' output soft decisions, which greatly reduces the communication overhead between server and clients, allows the clients and the server to choose models of suitable architectures according to their own bandwidth and computing resources, and thus realizes personalization of the model architecture.
Meanwhile, through the prototype network, the invention also alleviates the difficulty of improving model accuracy caused by highly heterogeneous client private data, thereby greatly improving model accuracy.
Drawings
To illustrate the technical solution of the invention more clearly, the drawings needed in the description of the embodiments are briefly introduced below. The drawings described below illustrate only one embodiment of the invention, and those skilled in the art can obtain other drawings from them without creative effort.
FIG. 1 is a flow chart of knowledge distillation in an embodiment of the present invention;
FIG. 2 is a flow chart of the federated knowledge distillation algorithm in an embodiment of the present invention;
FIG. 3 is a flow chart of the classification prediction model training method based on the federated knowledge distillation algorithm in an embodiment of the present invention.
Detailed Description
The technical solutions in the present invention are clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are only a part of the embodiments of the present invention, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
The technical solution of the present application will be described in detail below with specific examples. The following several specific embodiments may be combined with each other, and details of the same or similar concepts or processes may not be repeated in some embodiments.
To solve the problems that traditional federated learning algorithms based on model-parameter interaction have large communication overhead, cannot satisfy participants' need to use models of different architectures, and cannot improve model accuracy because of data heterogeneity, knowledge distillation and a prototype network are applied to federated learning.
Knowledge distillation is a method used for model compression in machine learning; its purpose is to inject the knowledge of a pre-trained teacher model into an untrained student model. Knowledge distillation differs from standard model training: standard training tries to match the model's predicted output to each sample's true label (e.g., [cat, dog] = [0, 1]), while knowledge distillation tries to match the student model's predicted output to the teacher model's predicted output, i.e., its logits (soft decisions), e.g., [cat, dog] = [0.3, 0.7]. These soft decisions carry more information than the true labels and can train the student model faster than standard training.
Consider a simple neural network comprising an input layer, a hidden layer, and an output layer. Let F_w(·) denote the entire network, R_w(·) the input-layer and hidden-layer network, and O_w(·) the output-layer network. In a multi-class problem, for a sample x_j the output layer finally outputs a classification prediction result Ŷ_j = F_w(x_j). In knowledge distillation, the student model updates the weights of its own neural network by minimizing a loss function together with a distillation regularizer; the optimization objective of the student model is:

$$\min_w \sum_{(x_j, y_j) \in D} \Big[ L_1\big(F_w(x_j),\, y_j\big) + \lambda\, L_2\big(F_w(x_j),\, \hat{Y}_j^T\big) \Big]$$

where (x_j, y_j) ∈ D denotes the public data set, λ denotes a fixed hyper-parameter, y_j denotes the true label of sample x_j, L_1(·) denotes the cross-entropy loss function, L_2(·) denotes the relative entropy loss function, Ŷ_j^T denotes the output logits (soft decision) of the teacher model, and F_w(x_j) denotes the output logits (soft decision) of the student model. The flow is shown in FIG. 1.
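The student objective can be illustrated with a small numeric sketch: the cross-entropy term L_1 against the hard label plus a λ-weighted relative-entropy (KL) term L_2 against the teacher's soft decision. The function names are illustrative, and the inputs are assumed to be already-normalized probability vectors.

```python
# Numeric sketch of the knowledge-distillation student objective: cross-entropy
# L_1 against the hard label plus a lambda-weighted relative-entropy (KL) term
# L_2 against the teacher's soft decision. Names are illustrative.
import math

def cross_entropy(student_probs, onehot_label):
    # L_1: standard cross-entropy against the true label
    return -sum(t * math.log(p) for p, t in zip(student_probs, onehot_label) if t > 0)

def relative_entropy(teacher_probs, student_probs):
    # L_2: KL divergence D_KL(teacher || student), the distillation regularizer
    return sum(t * math.log(t / s) for t, s in zip(teacher_probs, student_probs) if t > 0)

def kd_loss(student_probs, onehot_label, teacher_probs, lam=0.5):
    return (cross_entropy(student_probs, onehot_label)
            + lam * relative_entropy(teacher_probs, student_probs))
```

When the student's distribution matches the teacher's, the KL term vanishes and only the ordinary label loss remains, which is the sense in which distillation "regularizes" standard training.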
Applying the knowledge distillation described above to federated learning, the server and the clients in the federated learning framework no longer exchange model parameters but soft decisions of the models. A public data set is also deployed globally for model knowledge transfer; because of federated learning's requirement for data privacy protection, the public data set is unlabeled. The basic flow of the federated knowledge distillation algorithm is as follows:
(1) Each client locally trains its model on private data, and in each communication round uploads its model's logits (soft decisions) on the public data set to the server;
(2) After receiving the logits (soft decisions) uploaded by all clients participating in federated learning, the server aggregates them to obtain the final logits (soft decisions); then, with the server model acting as the student model and all client models acting as teacher models, knowledge distillation training is performed;
(3) After the knowledge distillation training is finished, the server side sends logits (soft decisions) of the server model (student model) on the public data set to each client side;
(4) After a client receives the logits (soft decisions) from the server, the client model acts as the student model and the server model as the teacher model, and knowledge distillation training is performed again; after training finishes, the next round of federated learning starts from step (1), until global convergence.
Since the public data set is unlabeled, the optimization objective of the server model during knowledge distillation training becomes:

$$\min_w \sum_{x_j \in D_P} \sum_{C_i \in C_N} \alpha_i\, L_2\big(F_w(x_j),\, \hat{Y}_j^i\big)$$

where x_j ∈ D_P denotes the unlabeled public data set, C_i ∈ C_N denotes all clients participating in federated learning, Ŷ_j^i denotes the logits (soft decision) of client C_i's model, α_i denotes the weight of client C_i's logits (soft decision), F_w(x_j) denotes the logits (soft decision) of the server model, and L_2(·) is the relative entropy loss function. The flow is shown in FIG. 2.
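The server-side objective can likewise be sketched as a weighted sum of relative-entropy terms, one per client teacher, evaluated on an unlabeled sample. The function name and the default uniform weights α_i are assumptions for illustration; inputs are assumed to be normalized probability vectors.

```python
# Sketch of the server-side distillation objective on an unlabeled public
# sample: a weighted sum of relative-entropy terms, one per client teacher.
import math

def server_distill_loss(server_probs, client_probs_list, weights=None):
    n = len(client_probs_list)
    if weights is None:
        weights = [1.0 / n] * n               # alpha_i: uniform by default
    loss = 0.0
    for alpha, client_probs in zip(weights, client_probs_list):
        loss += alpha * sum(c * math.log(c / s)       # KL(client || server)
                            for c, s in zip(client_probs, server_probs) if c > 0)
    return loss
```

The loss is zero exactly when the server's soft decision matches every client's, so minimizing it pulls the server model toward the (weighted) consensus of the client teachers.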
With knowledge distillation applied to federated learning, the models exchange their output logits rather than model parameters, which greatly reduces communication overhead. Moreover, the server and the clients (including among the clients themselves) no longer need to use models of the same architecture, which greatly increases the personalization capability of the federated learning framework: each device (client or server) can choose a model of a different architecture according to its own conditions.
However, when knowledge distillation is used in a federated learning framework, the distribution of private data across clients is typically non-IID (not independent and identically distributed), which slows model training convergence and makes model accuracy hard to improve. The reason is that when client private data is highly heterogeneous, the output logits (soft decisions) of the client models become over-confident, so the server's aggregated logits (soft decisions) contain wrong knowledge; training the server model with these logits (soft decisions) inevitably leads to slow convergence and low model accuracy.
For example, suppose two clients participate in federated learning: the private data of client A comprises image data of dogs, cats, and airplanes, while the private data of client B comprises image data of cats, frogs, and airplanes. After clients A and B each train their model on their private data, an image of a dog appears in the public data set. Since client A's private data set contains dog images, the distribution of client A's output logits (soft decisions) for this image is correct and leans toward the dog class; but since client B's private data set has no dog images yet has cat images, client B's model will very likely mispredict the dog image as a cat, and the distribution of its output logits (soft decisions) will wrongly lean toward the cat class. After the model output logits of clients A and B are aggregated, the aggregated distribution may be nearly uniform between cat and dog, which slows the server model's training convergence and prevents model accuracy from improving.
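The effect described in this example can be reproduced numerically: averaging a confidently correct soft decision with a confidently wrong one yields a nearly uniform distribution that teaches the server model little. The class order [cat, dog] and the probability values below are invented purely for illustration.

```python
# Numeric illustration of the over-confidence problem from the example above:
# client A (has dog data) predicts "dog" confidently, client B (no dog data)
# confidently mispredicts "cat"; their plain average is nearly uninformative.

def average_soft_decisions(decisions):
    n = len(decisions)
    return [sum(d[i] for d in decisions) / n for i in range(len(decisions[0]))]

# class order: [cat, dog]
client_a = [0.1, 0.9]   # correct: leans toward dog
client_b = [0.9, 0.1]   # wrong: leans toward cat
print(average_soft_decisions([client_a, client_b]))  # nearly uniform, about [0.5, 0.5]
```

A near-uniform target distribution carries almost no class information, which is exactly the "wrong knowledge" the passage describes.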
Based on the above problems, the invention proposes using a prototype network to mitigate this phenomenon. A prototype is the average representation of the embedded vectors of a class of data in feature space, i.e., an abstract feature representation of that class. Consider a simple neural network comprising an input layer, a hidden layer, and an output layer. Let F_w(·) denote the entire network, R_w(·) the input-layer and hidden-layer network, and O_w(·) the output-layer network; then the prototype of class k can be expressed as:

$$P_k = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k and (x_j, y_j) ∈ D_k denotes all the data in D_k.
The prototype network can distinguish the differences between classes in feature space, and this is used in federated learning: the server aggregates the class prototypes built from the clients' private data. Although the clients' prototypes differ slightly because of heterogeneous data distributions, by learning these representations the server obtains more abstract feature representations of each class and alleviates the model-training problems caused by over-confident logits (soft decisions).
Therefore, the classification prediction model training method based on the federated knowledge distillation algorithm provided by this embodiment of the invention comprises the following steps:
step 1: constructing a central server side and N client sides C = { C = } 1 ,C 2 ,...,C i ,...,C N A federal learning system composed of (i) N.gtoreq.2, C i Representing the ith client.
In this embodiment, N = 100, and the clients are devices such as watches, mobile phones, computers, and iPads.
Step 2: each of the clients C i All locally constructing a local training data set D with labels i And a local classification prediction model X i And let cycle t =1.
The local training data set D_i is the data set formed from client C_i's private data (e.g., image data), and the local classification prediction model X_i is the classification prediction model constructed at client C_i. Each client has its own local training data set and local classification prediction model; the clients' local classification prediction models may share one architecture or use different architectures. These clients participate in federated learning in order to train a powerful server model for classification tasks, such as image classification, while protecting their private data.
In this embodiment, the local classification prediction model X_i employs a deep residual network model, such as the ResNet11 model.
Step 3: each client C_i locally uses its local training data set D_i to iteratively train the local classification prediction model X_i, and uses the trained model X_i to calculate the prototype P_k^i of each class of data in D_i, where the class-k prototype P_k^i is the average of the feature vectors output by the local classification prediction model X_i for the class-k data.
Iteratively training the local classification prediction model X_i on the local training data set D_i at each client C_i is an existing training process. For client C_i, the prototype P_k^i of the class-k data is calculated as:

$$P_k^i = \frac{1}{|D_k|} \sum_{(x_j, y_j) \in D_k} R_w(x_j)$$

where D_k denotes the data set of class k, R_w(·) denotes the input-layer and hidden-layer network of the local classification prediction model X_i, (x_j, y_j) ∈ D_k denotes all data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j. A prototype is the average of the output feature vectors of a class of image data. For example, if client 1 has one hundred images of dogs and one hundred images of cats, the 100 dog images are each input into the local classification prediction model to obtain 100 output feature vectors of R_w(·); averaging these 100 output feature vectors yields the dog prototype, and the prototypes of the other classes are obtained in the same way.
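The prototype computation just described (average the feature vectors R_w(·) produces for a class's samples) can be sketched as follows; `feature_extractor` is a hypothetical placeholder for the model's input-plus-hidden-layer network R_w(·).

```python
# Sketch of the class-prototype computation: run every class-k sample through
# the feature extractor (input + hidden layers) and average the resulting
# feature vectors. `feature_extractor` is a placeholder for R_w.

def class_prototype(samples, feature_extractor):
    vectors = [feature_extractor(x) for x in samples]
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]
```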
Each client C_i also locally uses the trained local classification prediction model X_i to calculate the soft decision Ŷ_j^i on the unlabeled public data set D_P, where the soft decision Ŷ_j^i is the prediction output of the local classification prediction model X_i. A sample x_j in the public data set D_P is input into each client C_i's local classification prediction model X_i to obtain that client's soft decision Ŷ_j^i on the public data set D_P, i.e., the prediction output of the local classification prediction model X_i for the sample.
In the present embodiment, the public data set D_P adopts the CIFAR-10 color image data set of universal objects. The data set has M = 10 target classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks; each image is 32 × 32 pixels, and each object class has 6000 images.
An image from the public data set D_P serves as the input of the local classification prediction model X_i, and the output of X_i is its prediction result for that image: an array of ten values, each representing the model's predicted probability for one of the ten classes. For example, for an input image of a dog, the model might output an array such as [..., 0.1, 0.2, 0.5, ...], where each value represents the predicted probability of one class.
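The ten-value prediction array described above is typically obtained by applying a softmax to the model's raw class scores, so that the values form a probability distribution. This is a standard sketch, not code from the patent.

```python
# Standard sketch of how a ten-value prediction array arises: a softmax over
# the model's raw class scores yields a probability distribution over classes.
import math

def softmax(logits):
    m = max(logits)                            # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]
```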
Step 4: all clients send their calculated prototypes P_k^i and soft decisions Ŷ_j^i to the central server.
And 5: the central server side respectively aggregates all the received prototypes and all the soft decisions of each type to obtain various aggregated prototypes and aggregated soft decisions; constructing an optimization objective function by utilizing various types of prototypes after aggregation and soft decisions after aggregation, and utilizing a common data set D P And optimization purposesAnd (4) performing iterative training on the constructed global classification prediction model by using the standard function.
The aggregation may be average aggregation or weighted aggregation; average aggregation is adopted in this embodiment. The aggregation formula for all class-k prototypes is:

$$P_k = \frac{1}{N_k} \sum_{i=1}^{N_k} P_k^i$$

where N_k denotes the number of clients owning a class-k prototype and P_k denotes the aggregated class-k prototype.
The aggregation formula for all soft decisions is:

$$\bar{Y}_j = \frac{1}{N} \sum_{i=1}^{N} \hat{Y}_j^i$$

where Ȳ_j denotes the aggregated soft decision.
The concrete expression of the optimization objective function constructed from the aggregated prototypes and the aggregated soft decisions is:

$$\min_w \sum_{(x_j, k) \in D_P} \Big[ \sum_{C_i \in C_N} \alpha_i\, L_2\big(F_w(x_j),\, \hat{Y}_j^i\big) + \lambda\, \frac{1}{N_k} \sum_{i=1}^{N_k} L_M\big(R_w(x_j),\, P_k^i\big) \Big]$$

where (x_j, k) ∈ D_P denotes the unlabeled public data set D_P, with k the class of sample x_j as determined by the aggregated distribution Ȳ_j of all soft decisions; L_2(·) is the relative entropy loss function; C_i ∈ C_N denotes all clients; α_i denotes the weight of client C_i's soft decision Ŷ_j^i; F_w(x_j) denotes the soft decision of the global classification prediction model; L_M(·) denotes the root-mean-square loss function; N_k denotes the number of clients owning a class-k prototype; λ denotes a hyper-parameter; P_k^i denotes client C_i's class-k prototype; R_w(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
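The combined objective can be sketched numerically for one public sample: a weighted relative-entropy term pulling the global model's soft decision toward the client soft decisions, plus a λ-weighted root-mean-square term pulling the hidden representation toward the class-k client prototypes. All names are illustrative, and inputs are assumed to be normalized probability vectors and plain feature lists.

```python
# Numeric sketch of the combined server objective for one public sample:
# weighted KL terms against client soft decisions plus a lambda-weighted
# root-mean-square term against the class-k client prototypes.
import math

def rms_loss(hidden, prototype):
    # L_M: root-mean-square distance between hidden output and a prototype
    return math.sqrt(sum((h - p) ** 2 for h, p in zip(hidden, prototype)) / len(hidden))

def global_objective(server_probs, client_probs_list, alphas, hidden,
                     class_k_prototypes, lam):
    kl_term = sum(alpha * sum(c * math.log(c / s)
                              for c, s in zip(client_probs, server_probs) if c > 0)
                  for alpha, client_probs in zip(alphas, client_probs_list))
    proto_term = (sum(rms_loss(hidden, p) for p in class_k_prototypes)
                  / len(class_k_prototypes))          # average over the N_k prototypes
    return kl_term + lam * proto_term
```

The prototype term anchors the global model's hidden features even when the aggregated soft decisions are near-uniform, which is how the method counters the over-confidence problem described earlier.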
In this embodiment, the global classification prediction model uses a deep residual network model, such as a ResNet56 model.
Step 6: the central server calculates the soft decision Ŷ_j^G on the unlabeled public data set D_P with the trained global classification prediction model, where Ŷ_j^G is the prediction output of the global classification prediction model, and sends it to each client.
And 7: each of the clients C i Utilizing received soft decisionsAnd a common data set D P For local classification prediction model X i And performing iterative training.
And 8: judging whether the cycle number t is equal to the set cycle number, if yes, obtaining the trained local classification prediction model X i And a global classification prediction model; otherwise, let t = t +1 and jump to step 3.
Based on the same inventive concept, the invention also provides a classification prediction method based on a classification prediction model, wherein the classification prediction model comprises a global classification prediction model and N local classification prediction models obtained by training with the above classification prediction model training method based on the federated knowledge distillation algorithm; the classification prediction method comprises the following steps:
step 1: acquiring data to be classified;
step 2: and carrying out classification prediction on the data to be classified by using the classification prediction model to obtain the classification of the data to be classified.
In the invention, each client's private data and local classification prediction model are kept local to that client, ensuring the privacy and security of the private data. Through knowledge distillation, traditional federated learning based on exchanging model parameters is replaced by exchanging model-output soft decisions, which greatly reduces the communication overhead between the server and the clients and allows each client and the server to choose a model architecture suited to its own bandwidth and computing resources, realising personalisation of the model architecture.
Meanwhile, the prototype network alleviates the difficulty of improving model accuracy caused by highly heterogeneous client private data and substantially improves model accuracy; a federated learning framework using this method is stable and efficient.
The above disclosure covers only specific embodiments of the present invention, but the scope of the invention is not limited thereto; any changes or modifications that a person skilled in the art can readily conceive within the technical scope of the invention shall fall within the scope of the invention.
Claims (8)
1. A classification prediction model training method based on a federated knowledge distillation algorithm, characterized by comprising the following steps:
step 1: constructing a federated learning system composed of a central server and N clients C = {C_1, C_2, …, C_i, …, C_N}, where N ≥ 2 and C_i denotes the i-th client;
step 2: each client C_i locally constructing a labelled local training dataset D_i and a local classification prediction model X_i, and setting the round t = 1;
step 3: each client C_i locally performing iterative training of its local classification prediction model X_i on the local training dataset D_i, and using the trained X_i to calculate the prototype P_i^k of each class of data in D_i, wherein the prototype P_i^k of class-k data refers to the mean of the feature vectors output by X_i for the class-k data;
each client C_i locally using the trained local classification prediction model X_i to calculate the soft decision S_i of the unlabelled public dataset D_P, wherein the soft decision S_i is the prediction output of X_i;
step 4: each client C_i sending its calculated prototypes P_i^k and soft decision S_i to the central server;
step 5: the central server aggregating, class by class, all received prototypes and aggregating all soft decisions, to obtain the aggregated prototype of each class and the aggregated soft decision; constructing an optimization objective function from the aggregated prototypes and the aggregated soft decision, and iteratively training the constructed global classification prediction model on the public dataset D_P using the optimization objective function;
step 6: calculating the soft decision S_G of the unlabelled public dataset D_P with the trained global classification prediction model, wherein the soft decision S_G refers to the prediction output of the global classification prediction model;
step 7: each client C_i iteratively training its local classification prediction model X_i using the received soft decision S_G and the public dataset D_P;
step 8: judging whether the round t equals the set number of rounds; if so, obtaining the trained local classification prediction models X_i and the trained global classification prediction model; otherwise, letting t = t + 1 and jumping to step 3.
2. The method of claim 1, wherein the local classification prediction model and the global classification prediction model both employ a deep residual network model.
3. The classification prediction model training method based on the federated knowledge distillation algorithm of claim 1, wherein client C_i's prototype P_i^k of class-k data is calculated as:

P_i^k = (1/|D_k|) Σ_{(x_j, y_j) ∈ D_k} R_w(x_j)

wherein D_k denotes the dataset of class k, R_w(·) denotes the input layer and hidden layers of the local classification prediction model X_i, (x_j, y_j) ∈ D_k denotes the data in D_k, x_j denotes the j-th input sample, and y_j denotes the label corresponding to input sample x_j.
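The class-prototype definition above, the mean of the hidden-feature vectors R_w(x_j) over the class-k samples, can be sketched as follows; the toy feature vectors stand in for actual hidden-layer outputs.

```python
def class_prototype(features):
    """P_i^k = (1/|D_k|) * sum_j R_w(x_j):
    component-wise mean of one class's feature vectors."""
    n = len(features)
    dim = len(features[0])
    return [sum(f[d] for f in features) / n for d in range(dim)]

# hidden-layer outputs R_w(x_j) for three class-k samples (toy values)
feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
P_k = class_prototype(feats)
```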
4. The classification prediction model training method based on the federated knowledge distillation algorithm of claim 1, wherein the aggregation formula for aggregating all prototypes of class k is:

P^k = (1/N_k) Σ_{C_i owning class k} P_i^k

wherein N_k denotes the number of clients that own a prototype of class k, and P^k denotes the aggregated prototype of class k; and the aggregation formula for aggregating all soft decisions is:

S̄ = Σ_{C_i ∈ C_N} α_i S_i

wherein α_i denotes the weight of client C_i's soft decision S_i.
5. The classification prediction model training method based on the federated knowledge distillation algorithm according to any one of claims 1 to 4, wherein, in the optimization objective function constructed from the aggregated prototype of each class and the aggregated soft decision:

(x_j, k) ∈ D_P denotes the unlabelled public dataset D_P, with k the pseudo-label of sample x_j determined by the aggregated distribution of all soft decisions; L_2(·) is a relative entropy loss function; C_i ∈ C_N denotes the clients; α_i denotes the weight of client C_i's soft decision S_i; S_G denotes the soft decision of the global classification prediction model; L_M(·) denotes a root-mean-square loss function; N_k denotes the number of clients that own a prototype of class k; λ denotes a hyper-parameter; P_i^k denotes client C_i's prototype of class-k data; R_G(x_j) denotes the hidden-layer output of the global classification prediction model for sample x_j; and M denotes the number of classes in the classification task.
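For illustration only, outside the claim language: the legend above names two loss terms, a relative-entropy term L_2 between the aggregated soft decision and the global model's soft decision, and a root-mean-square term L_M between a class prototype and the global model's hidden-layer output, weighted by the hyper-parameter λ. The exact combination in the claimed objective is not reproduced in this text, so the per-sample structure below is an assumption.

```python
import math

def kl_div(p, q, eps=1e-12):
    """L_2: relative entropy KL(p || q) between two probability vectors."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def rms_loss(a, b):
    """L_M: root-mean-square distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))

def sample_loss(s_bar, s_g, proto_k, hidden, lam=0.1):
    """Assumed per-sample objective: distillation term plus a
    lambda-weighted prototype-alignment term."""
    return kl_div(s_bar, s_g) + lam * rms_loss(proto_k, hidden)

# when the global model already matches both targets, the loss vanishes
loss = sample_loss(s_bar=[0.6, 0.4], s_g=[0.6, 0.4],
                   proto_k=[1.0, 1.0], hidden=[1.0, 1.0])
```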
6. A classification prediction method based on a classification prediction model, characterized in that the classification prediction model comprises a global classification prediction model and N local classification prediction models, the global and local classification prediction models being obtained by the classification prediction model training method based on the federated knowledge distillation algorithm of any one of claims 1 to 5, the classification prediction method comprising the following steps:
acquiring data to be classified;
and carrying out classification prediction on the data to be classified by using the classification prediction model to obtain the classification of the data to be classified.
7. An electronic device, characterized in that the device comprises:
a memory for storing a computer program;
a processor for implementing the steps of the method for training a classification prediction model based on a federal knowledge distillation algorithm as claimed in any one of claims 1 to 5 or the steps of the method for classification prediction based on a classification prediction model as claimed in claim 6 when executing the computer program.
8. A computer readable storage medium having stored thereon a computer program which, when executed by a processor, implements the steps of the federal knowledge distillation algorithm based classification prediction model training method as defined in any one of claims 1 to 5, or implements the steps of the classification prediction model based classification prediction method as defined in claim 6.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211564028.9A CN115965078A (en) | 2022-12-07 | 2022-12-07 | Classification prediction model training method, classification prediction method, device and storage medium |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115965078A true CN115965078A (en) | 2023-04-14 |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116957067A (en) * | 2023-06-28 | 2023-10-27 | 北京邮电大学 | Reinforced federal learning method and device for public safety event prediction model |
CN116957067B (en) * | 2023-06-28 | 2024-04-26 | 北京邮电大学 | Reinforced federal learning method and device for public safety event prediction model |
Legal Events
Date | Code | Title | Description
---|---|---|---
| PB01 | Publication | |
| SE01 | Entry into force of request for substantive examination | |