WO2023050754A1 - Model training method and apparatus for private data set - Google Patents

Model training method and apparatus for private data set

Info

Publication number
WO2023050754A1
WO2023050754A1 (PCT/CN2022/085131)
Authority
WO
WIPO (PCT)
Prior art keywords
model
public data
data set
output
server
Application number
PCT/CN2022/085131
Other languages
French (fr)
Chinese (zh)
Inventor
刘洋
Original Assignee
清华大学
Application filed by 清华大学 (Tsinghua University)
Publication of WO2023050754A1


Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00 Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60 Protecting data
    • G06F21/62 Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218 Protecting access to data via a platform, e.g. using keys or access control rules, to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245 Protecting personal data, e.g. for financial or medical purposes
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods

Definitions

  • This application relates to the technical field of multi-party data cooperation, and in particular to a model training method and device based on private data sets.
  • the embodiment of the present application provides a model training method and device based on a private data set to solve the problem of the lack of a model training solution based on a multi-party private data set.
  • the embodiment of the present application provides a model training method based on a private data set, including:
  • training a server-side model based on a public data set and real labels corresponding to the public data set;
  • the first model output sent by each client is obtained by the client inputting the public data set into the local learning model; the local learning model is obtained by the client training a preset model based on the private data set and corresponding labels;
  • the second model output is sent to each of the clients, so that each of the clients can retrain the local learning model based on the second model output and the public data set.
  • the training of the server-side model based on the public data set and the real label corresponding to the public data set includes:
  • the server-side model is trained based on a cross-entropy loss function between the predicted result and the true label.
  • the training of the server-side model based on the public data set and the real label corresponding to the public data set further includes:
  • the first target model output is the model output corresponding to the target public data;
  • the target public data is the public data in the public data set whose prediction result, after being input into the server-side model, matches the corresponding real label;
  • the target public data to be distilled is the public data in the public data set whose prediction result, after being input into the server-side model, does not match the corresponding real label;
  • the first public data to be distilled is the part of the target public data to be distilled that has a corresponding first target model output;
  • the server-side model is trained based on the first public data to be distilled and a first target model output corresponding to the first public data to be distilled.
  • the acquiring the first model output sent by each client includes:
  • the second public data to be distilled is the part of the target public data to be distilled that does not have a corresponding first target model output;
  • the request is used to ask the client to return the first model output;
  • the first model output is the portion of each local learning model's output that corresponds to the second public data to be distilled;
  • retraining the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs includes:
  • the second target model output is the portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
  • the third public data to be distilled is the part of the second public data to be distilled that has corresponding second target model outputs;
  • the server-side model is retrained.
  • the retraining of the server-side model based on the third public data to be distilled and each second target model output includes: determining the information entropy of each second target model output; determining weights based on the magnitude of the information entropy; fusing the second target model outputs based on the weights to obtain a third target model output; and retraining the server-side model based on the third public data to be distilled and the third target model output.
  • the public data set and the private data set include: image data, text data or sound data related to entities.
  • the embodiment of the present application provides a model training device based on a private data set, including:
  • the first training unit is used to train the server-side model based on the public data set and the real label corresponding to the public data set;
  • the obtaining unit is used to obtain the first model output sent by each client; the first model output is obtained by the client inputting the public data set into the local learning model; the local learning model is obtained by the client training the preset model based on the private data set and corresponding labels;
  • a second training unit configured to retrain the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs;
  • an input unit configured to input the public data set into the retrained server-side model to obtain a second model output;
  • a sending unit configured to send the second model output to each of the clients, so that each of the clients can retrain the local learning model based on the second model output and the public data set.
  • the embodiment of the present application provides an electronic device, including a memory, a processor, and a computer program stored on the memory and executable on the processor, wherein the processor, when executing the program, implements the steps of the model training method based on a private data set provided by this application.
  • the embodiment of the present application provides a non-transitory computer-readable storage medium on which a computer program is stored, wherein, when the computer program is executed by a processor, the steps of the model training method based on a private data set provided by this application are implemented.
  • The embodiment of the present application provides a model training method based on a private data set. The public data set, the first model outputs, and the second model outputs serve as channels and media for information exchange between each local learning model and the server-side model, fully exploiting the server-side model's independent training ability: knowledge distillation and knowledge fusion are performed based on each first model output, and the fused knowledge is then sent back to each local learning model via the second model output, so that every local learning model obtains the fused knowledge. That is, with the public data set, the first model outputs, and the second model outputs as the media of knowledge transmission, all knowledge is stored in one powerful model (the server-side model) as a general knowledge base to help federated learning.
  • The server-side model not only uses ample computing resources to train itself, but also learns knowledge from all clients acting as multiple teachers, which further improves the effect of the server-side model.
  • In return, the knowledge accumulated on the server side is further passed to the clients to improve the effect of all clients' local learning models, so that each local learning model obtained after training contains knowledge from multiple private data sets; that is, each local learning model is effectively trained on multi-party private data sets.
  • the embodiment of the present application provides a feasible model training method based on private data sets, which can be specifically applied to the training of models related to private data in the medical field.
  • Fig. 1 is the first schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
  • Fig. 2 is the second schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
  • Fig. 3 is the third schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
  • Fig. 4 is the fourth schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
  • Fig. 5 is the fifth schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application.
  • FIG. 6 is a schematic structural diagram of a model training device based on a private data set provided in an embodiment of the present application
  • FIG. 7 is a schematic structural diagram of an electronic device provided by an embodiment of the present application.
  • A large amount of information data floods industries such as economy, culture, education, medical care, and public administration, and data processing such as data analysis, data mining, and trend prediction is widely applied to it in more and more scenarios.
  • Through data cooperation, multiple data owners can obtain better data processing results.
  • For example, more accurate model parameters can be obtained through joint training on multi-party data.
  • the joint training system for models based on private data can be applied to the scenario where all parties cooperate to train machine learning models for use by multiple parties while ensuring the data security of all parties.
  • multiple data parties have their own data, and they want to jointly use each other's data for unified modeling (for example, a linear regression model, a logistic regression model, etc.), but they do not want their own data (especially private data) to be leaked.
  • hospital A has a batch of patient data (such as photos of patients' diseased parts) that is not suitable for public disclosure because of patient privacy concerns, and hospital B likewise has a batch of patient data that is not suitable for public disclosure;
  • a training sample set determined from the patient data of hospital A and hospital B could be trained into a relatively good machine learning model. Both hospitals are willing to participate in joint model training using each other's patient data, but hospital A and hospital B must ensure that patient data is not leaked, and they cannot, or are unwilling to, let the other party see their patient data.
  • the embodiments of the present application provide a model training method and device based on private datasets based on knowledge distillation and federated learning.
  • In the traditional federated learning setting, clients upload model parameters or model gradients to a central server; the server aggregates them in some form and distributes them back to the clients, which then update further on their local data.
  • Transferring parameters or gradients will bring a series of privacy, heterogeneity, and communication cost issues.
  • Some existing work uses knowledge distillation to transfer knowledge between the clients and the server. However, because clients are in practice resource-constrained, large models cannot be used directly on the client, so the resource constraint remains a major challenge. Only by mining the server-side computing resources as much as possible, and using an auxiliary large model on the server to transfer and accumulate knowledge, can the same knowledge-fusion effect be achieved as centralized training with a large model.
  • Fig. 1 is the first schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application. As shown in Fig. 1, the method includes:
  • Step 110: train the server-side model based on the public data set and the real labels corresponding to the public data set;
  • the public data set and the private data set are the same type of data; the difference is that the public data set is data that can be made public, while the private data set is data that cannot be, or is not suitable to be, made public.
  • the public data set and the private data set may be image data, text data, or sound data related to entities, for example, the patient disease picture data of some hospitals or the user data of some Internet companies.
  • the server-side model is a large model, that is, a relatively complex model, so that knowledge can be mined and learned as fully as possible.
  • Step 120: obtain the first model output sent by each client; the first model output is obtained by the client inputting the public data set into the local learning model; the local learning model is obtained by the client training the preset model based on the private data set and corresponding labels;
  • Step 130: retrain the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs;
  • Step 140: input the public data set into the retrained server-side model to obtain a second model output;
  • Step 150: send the second model output to each of the clients, so that each of the clients can retrain the local learning model based on the second model output and the public data set.
  • The public data set, the first model outputs, and the second model outputs serve as the channels and media for information exchange between the local learning models and the server-side model, fully exploiting the server-side model's independent training ability: knowledge distillation and knowledge fusion are performed based on each first model output, and the fused knowledge is then sent back to each local learning model via the second model output, so that every local learning model obtains the fused knowledge. That is, with the public data set, the first model outputs, and the second model outputs as the media of knowledge transmission, all knowledge is stored in one powerful model (the server-side model) as a general knowledge base to help federated learning.
  • The server-side model not only uses ample computing resources to train itself, but also learns knowledge from all clients acting as multiple teachers, which further improves the effect of the server-side model.
  • In return, the knowledge accumulated on the server side is further passed to the clients to improve the effect of all clients' local learning models, so that each local learning model obtained after training contains knowledge from multiple private data sets; that is, each local learning model is effectively trained on multi-party private data sets.
  • the server-side model is the center of knowledge aggregation, and the knowledge it learns directly affects the knowledge obtained by each local learning model from the second model output; therefore, the training of the server-side model is a relatively important part.
  • Step 110 ("train the server-side model based on the public data set and the real labels corresponding to the public data set") and step 130 ("retrain the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs") constitute the training part of the server-side model.
  • The training of the server-side model is mainly divided into three parts: preliminary training, self-distillation, and aggregation distillation (retraining). It should be noted that these three parts are not executed in a strict chronological order but are interleaved with one another; a high-level sketch of one round follows.
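  • As a rough illustration only, the following Python sketch shows how one server-side training round might tie the three parts together. All helper names (pretrain_and_memorize, mispredicted_samples, self_distill, aggregation_distill, predict_all) are assumptions introduced for this sketch, not names from the patent; the self-distillation and fusion helpers are sketched in more detail further below.

```python
# Hypothetical orchestration of one server-side round; helper names are
# illustrative stand-ins for the patent's steps, not a definitive API.
def server_round(server_model, public_set, clients):
    # Preliminary training (steps 111-112) while memorizing correctly
    # predicted outputs as "first target model outputs" (step 113).
    memory = pretrain_and_memorize(server_model, public_set)
    # Mispredicted public data = "target public data to be distilled" (step 114).
    wrong = mispredicted_samples(server_model, public_set)
    first_to_distill = [s for s in wrong if s.sample_id in memory]       # step 115
    second_to_distill = [s for s in wrong if s.sample_id not in memory]  # step 123
    # Self-distillation toward the memorized correct outputs (step 116).
    self_distill(server_model, first_to_distill, memory)
    # Ask each client for its outputs on the second data to be distilled
    # (steps 124-126), then aggregate and distill (steps 131-136).
    first_outputs = {c: c.outputs_for(second_to_distill) for c in clients}
    aggregation_distill(server_model, second_to_distill, first_outputs)
    # Second model output on the whole public set, sent back to the
    # clients for local retraining (steps 140-150).
    second_output = predict_all(server_model, public_set)
    for c in clients:
        c.retrain_local(second_output, public_set)
```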
  • Step 111: input the public data set into the server-side model to obtain prediction results;
  • Step 112: train the server-side model based on the cross-entropy loss function between the prediction results and the real labels.
  • This part of the training is relatively routine: the cross-entropy loss function between the prediction results and the real labels is simply used to train the server-side model, and existing training examples can be referred to; a minimal sketch of one such step follows.
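  • A minimal PyTorch sketch of this preliminary training step, assuming a standard classification setting (the model, optimizer, and data shapes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def pretrain_step(server_model, optimizer, x, y):
    # Step 111: input public data x into the server-side model to get predictions.
    logits = server_model(x)
    # Step 112: cross-entropy between the prediction and the real label y.
    loss = F.cross_entropy(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Returning the logits lets the caller memorize correctly predicted
    # outputs (the "first target model output" of step 113).
    return logits.detach(), loss.item()
```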
  • Step 113: determine and store the first target model output;
  • the first target model output is the model output corresponding to the target public data;
  • the target public data is the public data in the public data set whose prediction result, after being input into the server-side model, matches the corresponding real label;
  • The first target model output is stored in preparation for the distillation of steps 114, 115, and 116: the correctly predicted model outputs (that is, the first target model outputs) are saved into a global model-output memory to help later correct samples that the model once predicted correctly but now predicts wrongly.
  • Self-distillation works as follows: for samples in the public data set on which the model's prediction is wrong (that is, the target public data to be distilled), first check whether a corresponding model output exists in the global model-output memory (that is, among the first target model outputs). If it does, this knowledge was once mastered by the model, so reviewing what it has already mastered can help the model correct its own mistakes. Following this idea, the self-distillation method is used to carry out distillation training on the model: the current model output is pulled toward the model output from when the prediction was correct, combined with the cross-entropy loss. The specific steps are as follows:
  • Step 114: determine the target public data to be distilled; the target public data to be distilled is the public data in the public data set whose prediction result, after being input into the server-side model, does not match the corresponding real label;
  • Step 115: determine the first public data to be distilled; the first public data to be distilled is the part of the target public data to be distilled that has a corresponding first target model output;
  • Step 116: train the server-side model based on the first public data to be distilled and the first target model output corresponding to the first public data to be distilled.
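  • A sketch of one self-distillation step under the same assumptions as above. The temperature-scaled KL term and the alpha mixing weight are common distillation choices assumed here for illustration; the patent only states that the current output is pulled toward the previously correct output, combined with the cross-entropy loss.

```python
import torch
import torch.nn.functional as F

def self_distill_step(server_model, optimizer, x, y, stored_logits,
                      temperature=2.0, alpha=0.5):
    # x, y: first public data to be distilled and its real label (step 115).
    # stored_logits: the memorized first target model output for this sample.
    logits = server_model(x)
    ce = F.cross_entropy(logits, y)
    # Pull the current output toward the output from when the model was right.
    kd = F.kl_div(
        F.log_softmax(logits / temperature, dim=1),
        F.softmax(stored_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature ** 2
    loss = alpha * kd + (1.0 - alpha) * ce
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```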
  • The self-distillation of the server-side model is completed in the above way. Focused distillation is then performed on the data in the target public data to be distilled other than the first public data to be distilled; focused distillation is a core part of the solution in the embodiments of this application and is mainly used to obtain the knowledge of other clients' private data sets.
  • Step 121: the client trains the preset model based on the private data set and corresponding labels to obtain the local learning model;
  • Step 122: the client inputs the public data set into the local learning model to obtain a model output;
  • Step 123: determine the second public data to be distilled; the second public data to be distilled is the part of the target public data to be distilled that does not have a corresponding first target model output;
  • Step 124: send a request to each of the clients;
  • Step 125: the client returns the first model output, where the first model output is the portion of each local learning model's output that corresponds to the second public data to be distilled;
  • Step 126: receive the first model output returned by each client.
  • The data sent by each client is the data used for aggregation distillation; the knowledge contained in each local learning model trained on a private data set is conveyed to the server-side model through these first model outputs.
  • This setting not only avoids leaking private data during knowledge transmission, but also reduces the amount of data that needs to be transmitted.
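  • A sketch of the client side of this exchange; the dictionary-of-logits return format and the function name are assumptions made for illustration:

```python
import torch

def client_first_model_output(local_model, requested_public_samples):
    # Steps 124-125: the server's request names the second public data to be
    # distilled; the client returns its local model's outputs on exactly that
    # subset, so neither private data nor the full output set is transmitted.
    local_model.eval()
    outputs = {}
    with torch.no_grad():
        for sample_id, x in requested_public_samples:
            outputs[sample_id] = local_model(x.unsqueeze(0)).squeeze(0)
    return outputs  # the "first model output" sent back to the server
```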
  • Then the aggregation distillation is carried out; the steps mainly include:
  • Step 131: filter the first model outputs to obtain the second target model outputs;
  • the second target model outputs are the portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
  • This step eliminates the first model outputs that cannot serve as good teachers for training on the second public data to be distilled.
  • Step 132: determine the third public data to be distilled;
  • the third public data to be distilled is the part of the second public data to be distilled that has corresponding second target model outputs;
  • Step 133: determine the information entropy of the model output in each of the second target model outputs;
  • Step 134: determine the weight of the model output in each of the second target model outputs based on the magnitude of the information entropy;
  • Step 135: fuse the second target model outputs based on the weights to obtain a third target model output;
  • Step 136: retrain the server-side model based on the third public data to be distilled and the third target model output; that is, distill the server-side model using the weighted model output combined with the cross-entropy loss.
  • In this way, the server-side model is retrained based on the third public data to be distilled and each of the second target model outputs.
  • The outputs are weighted according to their information entropy: the higher the information entropy, the lower the corresponding confidence, so knowledge fusion is carried out selectively.
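  • A sketch of the elimination and entropy-weighted fusion for a single public sample. The softmax over negative entropies is one plausible weighting that realizes "higher entropy means lower weight"; the patent does not fix the exact formula, so treat it as an assumption.

```python
import torch
import torch.nn.functional as F

def fuse_client_outputs(client_logits, true_label):
    # Step 131: keep only client outputs whose prediction matches the real label.
    kept = [z for z in client_logits if z.argmax().item() == true_label]
    if not kept:
        return None  # no usable teacher output for this sample
    # Steps 133-134: information entropy of each kept output, turned into a
    # weight so that higher entropy (lower confidence) gets a lower weight.
    entropies = []
    for z in kept:
        p = F.softmax(z, dim=-1)
        entropies.append(-(p * p.clamp_min(1e-12).log()).sum())
    weights = F.softmax(-torch.stack(entropies), dim=0)
    # Step 135: weighted fusion into the "third target model output".
    return sum(w * z for w, z in zip(weights, kept))
```

  • The fused output is then used, together with the cross-entropy loss, to distill the server-side model on the third public data to be distilled (step 136).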
  • Finally, steps 140 and 150 are executed to complete the retraining of the local learning models.
  • the embodiment of the present application provides a novel method, which uses selective knowledge fusion to store all knowledge in a powerful model as a general knowledge base to help federated learning.
  • The server-side model not only uses ample computing resources to train itself, but also learns knowledge from all clients acting as multiple teachers, which further improves the effect of the server-side model.
  • the accumulated knowledge on the server side will be further passed on to the client side to help all clients improve the effect of local learning models.
  • it can increase the robustness of the models at both ends, and reduce the communication cost of uploading knowledge from the client to the server.
  • The model training system based on a private data set includes a server end and a plurality of clients (hospital A and hospital B represent the plurality of clients in Fig. 5).
  • Step 501: hospital A trains the preset model based on the private data set A and corresponding labels to obtain a local learning model A;
  • Step 502: hospital B trains the preset model based on the private data set B and corresponding labels to obtain a local learning model B;
  • Step 503: the server end trains the server-side model based on the public data set and the real labels corresponding to the public data set;
  • the private data set A, the private data set B, and the public data set are pictures of patients' injured parts;
  • the main purpose of this embodiment is to obtain a model that can identify injured parts and make predictions about them;
  • Step 504: input the public data set into the server-side model for prediction;
  • Step 505: save the correctly predicted model outputs into the global model-output memory;
  • Step 506: based on the global model-output memory, perform self-distillation on some incorrectly predicted samples;
  • Step 507: obtain the model output A sent by hospital A;
  • Step 508: obtain the model output B sent by hospital B;
  • the model output A is obtained by inputting the public data set into the local learning model A;
  • the model output B is obtained by inputting the public data set into the local learning model B;
  • Step 509: carry out the elimination operation and weighted fusion on the model output A and the model output B;
  • Step 510: based on the fused model output, perform aggregation distillation on some of the wrongly predicted samples;
  • The data for aggregation distillation can be multiple pictures, and model output A and model output B contain a model output for each picture undergoing aggregation distillation. Fusion and elimination proceed picture by picture: first a picture for aggregation distillation is determined, and then the model outputs of hospital A and hospital B corresponding to this picture are obtained; whether the prediction results given by the two model outputs match the real label is judged, and for those that match, the information entropy of the outputs is determined and the outputs are weighted according to the level of information entropy, where higher information entropy means lower confidence.
  • Step 511: input the public data set into the server-side model for prediction to obtain the second model output;
  • Step 512: send the second model output to hospital A;
  • Step 513: train the local learning model A based on the second model output;
  • Step 514: send the second model output to hospital B;
  • Step 515: train the local learning model B based on the second model output.
  • This cycle repeats: through selective knowledge fusion, all knowledge is stored in one powerful model as a general knowledge base to help federated learning, and is then passed to the clients, hospital A and hospital B, to help improve the effect of local learning model A and local learning model B. In this way, hospital A and hospital B conduct joint training without disclosing their own private data sets, and obtain local learning models A and B with better actual prediction performance. A hypothetical wiring of this cycle is sketched below.
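  • A hypothetical top-level loop for the hospital example, reusing the server_round sketch from above; ServerModel, HospitalClient, num_rounds, and the data-set variables are illustrative assumptions:

```python
# Hypothetical wiring of the hospital example (steps 501-515).
server_model = ServerModel()
clients = [HospitalClient("A", private_set_a),   # step 501
           HospitalClient("B", private_set_b)]   # step 502
for client in clients:
    client.train_local_on_private_data()         # obtain local models A and B
for _ in range(num_rounds):                      # "this cycle is carried out"
    server_round(server_model, public_set, clients)  # steps 503-515 per round
```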
  • FIG. 6 is a schematic structural diagram of a model training device based on a private data set provided in the embodiment of the present application. As shown in FIG. 6, the device includes:
  • the first training unit 61 is used to train the server-side model based on the public data set and the real label corresponding to the public data set;
  • the obtaining unit 62 is configured to obtain the first model output sent by each client; the first model output is obtained by the client inputting the public data set into a local learning model; the local learning model is obtained by the client training the preset model based on the private data set and corresponding labels;
  • the second training unit 63 is configured to retrain the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs;
  • an input unit 64, configured to input the public data set into the retrained server-side model to obtain a second model output;
  • the sending unit 65 is configured to send the second model output to each of the clients, so that each of the clients retrains the local learning model based on the second model output and the public data set.
  • the first training unit 61 is specifically used for:
  • the server-side model is trained based on a cross-entropy loss function between the predicted result and the true label.
  • the first target model output is the model output corresponding to the target public data;
  • the target public data is the public data in the public data set whose prediction result, after being input into the server-side model, matches the corresponding real label;
  • the target public data to be distilled is the public data in the public data set whose prediction result, after being input into the server-side model, does not match the corresponding real label;
  • the first public data to be distilled is part of the target public data to be distilled that has a corresponding first target model output;
  • the server-side model is trained based on the first public data to be distilled and a first target model output corresponding to the first public data to be distilled.
  • said obtaining the first model output sent by each client includes:
  • the second public data to be distilled is the part of the target public data to be distilled that does not have a corresponding first target model output;
  • the request is used to request the client to return the first model output;
  • the first model output is the portion of each local learning model's output that corresponds to the second public data to be distilled;
  • the second training unit 63 is specifically used for:
  • the second target model outputs are the portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
  • the third public data to be distilled is the part of the second public data to be distilled that has corresponding second target model outputs;
  • the server-side model is retrained.
  • the retraining of the server-side model based on the third public data to be distilled and the output of each of the second target models includes:
  • each of the second target model outputs is fused to obtain a third target model output
  • the public data set and the private data set include: image data, text data or sound data related to entities.
  • Figure 7 is a schematic structural diagram of an electronic device provided by the embodiment of the present application. As shown in Figure 7, the electronic device may include: a processor (processor) 710, a communication interface (Communications Interface) 720, a memory (memory) 730, and a communication bus 740, where the processor 710, the communication interface 720, and the memory 730 communicate with each other through the communication bus 740.
  • the processor 710 can call the logic instructions in the memory 730 to execute the following method: train the server-side model based on the public data set and the real labels corresponding to the public data set; obtain the first model output sent by each client, the first model output being obtained by the client inputting the public data set into the local learning model, and the local learning model being obtained by the client training the preset model based on the private data set and corresponding labels; retrain the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs; input the public data set into the retrained server-side model to obtain the second model output; and send the second model output to each of the clients, so that each of the clients retrains the local learning model based on the second model output and the public data set.
  • The above-mentioned logic instructions in the memory 730 can be implemented in the form of software functional units and, when sold or used as an independent product, can be stored in a computer-readable storage medium.
  • In essence, the technical solution of the present application, or the part that contributes to the prior art, or a part of the technical solution, can be embodied in the form of a software product; the computer software product is stored in a storage medium and includes several instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute all or part of the steps of the methods described in the various embodiments of the present application.
  • The aforementioned storage media include various media that can store program code, such as a USB flash drive, a removable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disc.
  • The embodiment of the present application also provides a non-transitory computer-readable storage medium on which a computer program is stored; when the computer program is executed by a processor, the methods provided by the above embodiments are performed, for example including: training the server-side model based on the public data set and the real labels corresponding to the public data set; obtaining the first model output sent by each client, the first model output being obtained by the client inputting the public data set into the local learning model, and the local learning model being obtained by the client training the preset model based on the private data set and corresponding labels; retraining the server-side model based on the public data corresponding to each of the first model outputs and each of the first model outputs; inputting the public data set into the retrained server-side model to obtain a second model output; and sending the second model output to each of the clients, so that each of the clients retrains the local learning model based on the second model output and the public data set.
  • The device embodiments described above are only illustrative; the units described as separate components may or may not be physically separated, and the components shown as units may or may not be physical units, that is, they may be located in one place or distributed over multiple network elements. Some or all of the modules can be selected according to actual needs to achieve the purpose of the solution of this embodiment, and this can be understood and implemented by those skilled in the art without creative effort.
  • each implementation can be implemented by means of software plus a necessary general hardware platform, and of course also by hardware.
  • The essence of the above technical solution, or the part that contributes to the prior art, can be embodied in the form of a software product; the computer software product can be stored in a computer-readable storage medium, such as a ROM/RAM, a magnetic disk, or an optical disc, and includes several instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute the methods described in the various embodiments or in some parts of the embodiments.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Bioethics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computer Security & Cryptography (AREA)
  • Computer Hardware Design (AREA)
  • Databases & Information Systems (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

A private data set-based method and apparatus for model training, which relate to the technical field of multi-party data collaboration. The method comprises: training a server-side model on the basis of a public data set and a real label corresponding to the public data set; obtaining first model outputs sent by clients, the first model outputs being obtained by inputting the public data set into local learning models, and the local learning models being obtained by training on the basis of the private data set and the corresponding label; training the server-side model on the basis of public data corresponding to the first model outputs; inputting the public data set into the server-side model to obtain second model outputs; and sending the second model outputs to the clients, for the clients to perform retraining of the local learning models on the basis of the second model outputs and the public data set. As such, while avoiding private data set leakage, model training is performed by using the private data set as part of training samples on the basis of knowledge distillation and knowledge fusion.

Description

Model training method and device for private data sets
Cross-Reference to Related Applications
This application claims priority to the Chinese patent application No. 202111165679.6, filed on September 30, 2021 and entitled "Model training method and device based on private data sets", and the Chinese patent application No. 202111189306.2, filed on October 12, 2021 and entitled "Model training method and device based on private data sets", both of which are incorporated herein by reference in their entirety.
Technical Field
This application relates to the technical field of multi-party data cooperation, and in particular to a model training method and device based on private data sets.
Background
In fields such as data analysis, data mining, and economic forecasting, machine learning models can be used to analyze and discover potential data value. Since the data held by a single data owner may be incomplete, it is difficult to characterize the target accurately; to obtain better model prediction results, joint model training through the data cooperation of multiple data owners has been widely used. However, the process of multi-party data cooperation involves issues such as data security and model security.
Especially in the medical field, some data sets involve privacy, cannot be made public, and may only be used inside the hospital, which makes it very difficult to build a learning model on the private data sets of individual hospitals. In existing schemes, the information exchanged consists of the private data set and the model output obtained by feeding the private data set into the learning model (generally the output of the last neural network layer of the learning model), rather than model results and corresponding labels, and the model is trained by means of knowledge distillation and knowledge fusion. However, under this approach the problem of privacy leakage still exists.
Therefore, there is currently a lack of a model training solution based on multi-party private data sets.
Summary of the Invention
The embodiments of the present application provide a model training method and device based on a private data set, to solve the existing lack of a model training solution based on multi-party private data sets.
In a first aspect, an embodiment of the present application provides a model training method based on a private data set, including:
training a server-side model based on a public data set and real labels corresponding to the public data set;
obtaining a first model output sent by each client, the first model output being obtained by the client inputting the public data set into a local learning model, and the local learning model being obtained by the client training a preset model based on a private data set and corresponding labels;
retraining the server-side model based on the public data corresponding to each first model output and each first model output;
inputting the public data set into the retrained server-side model to obtain a second model output;
sending the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
Optionally, training the server-side model based on the public data set and the real labels corresponding to the public data set includes:
inputting the public data set into the server-side model to obtain prediction results;
training the server-side model based on a cross-entropy loss function between the prediction results and the real labels.
Optionally, training the server-side model based on the public data set and the real labels corresponding to the public data set further includes:
determining and storing a first target model output, the first target model output being the model output corresponding to target public data, and the target public data being the public data in the public data set whose prediction result, after being input into the server-side model, matches the corresponding real label;
determining target public data to be distilled, the target public data to be distilled being the public data in the public data set whose prediction result, after being input into the server-side model, does not match the corresponding real label;
determining first public data to be distilled, the first public data to be distilled being the part of the target public data to be distilled that has a corresponding first target model output;
training the server-side model based on the first public data to be distilled and the first target model output corresponding to the first public data to be distilled.
Optionally, obtaining the first model output sent by each client includes:
determining second public data to be distilled, the second public data to be distilled being the part of the target public data to be distilled that does not have a corresponding first target model output;
sending a request to each client, the request being used to ask the client to return a first model output, where the first model output is the portion of each local learning model's output that corresponds to the second public data to be distilled;
receiving the first model output returned by each client.
Optionally, retraining the server-side model based on the public data corresponding to each first model output and each first model output includes:
filtering the first model outputs to obtain second target model outputs, the second target model outputs being the portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
determining third public data to be distilled, the third public data to be distilled being the part of the second public data to be distilled that has corresponding second target model outputs;
retraining the server-side model based on the third public data to be distilled and each second target model output.
Optionally, retraining the server-side model based on the third public data to be distilled and each second target model output includes:
determining the information entropy of the model output in each second target model output;
determining a weight for the model output in each second target model output based on the magnitude of the information entropy;
fusing the second target model outputs based on the weights to obtain a third target model output;
retraining the server-side model based on the third public data to be distilled and the third target model output.
Optionally, the public data set and the private data set include image data, text data, or sound data related to entities.
In a second aspect, an embodiment of the present application provides a model training device based on a private data set, including:
a first training unit, configured to train a server-side model based on a public data set and real labels corresponding to the public data set;
an obtaining unit, configured to obtain a first model output sent by each client, the first model output being obtained by the client inputting the public data set into a local learning model, and the local learning model being obtained by the client training a preset model based on a private data set and corresponding labels;
a second training unit, configured to retrain the server-side model based on the public data corresponding to each first model output and each first model output;
an input unit, configured to input the public data set into the retrained server-side model to obtain a second model output;
a sending unit, configured to send the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
In a third aspect, an embodiment of the present application provides an electronic device, including a memory, a processor, and a computer program stored on the memory and executable on the processor, wherein the processor, when executing the program, implements the steps of the model training method based on a private data set provided by this application.
In a fourth aspect, an embodiment of the present application provides a non-transitory computer-readable storage medium on which a computer program is stored, wherein the computer program, when executed by a processor, implements the steps of the model training method based on a private data set provided by this application.
The embodiments of the present application provide a model training method based on a private data set. The public data set, the first model outputs, and the second model outputs serve as channels and media for information exchange between each local learning model and the server-side model, fully exploiting the server-side model's independent training ability: knowledge distillation and knowledge fusion are performed based on each first model output, and the fused knowledge is then sent back to each local learning model via the second model output, so that every local learning model obtains the fused knowledge. That is, with the public data set, the first model outputs, and the second model outputs as the media of knowledge transmission, all knowledge is stored in one powerful model (the server-side model) as a general knowledge base to help federated learning. The server-side model not only uses ample computing resources to train itself, but also learns knowledge from all clients acting as multiple teachers, which further improves the effect of the server-side model. In return, the knowledge accumulated on the server side is further passed to the clients to improve the effect of all clients' local learning models, so that each local learning model obtained after training contains knowledge from multiple private data sets; that is, each local learning model is effectively trained on multi-party private data sets. In this way, the embodiments of the present application provide a feasible model training method based on private data sets, which can be specifically applied to the training of models involving private data in the medical field.
Brief Description of the Drawings
To describe the technical solutions in this application or in the prior art more clearly, the accompanying drawings needed in the description of the embodiments or the prior art are briefly introduced below. Obviously, the drawings in the following description show some embodiments of this application, and those of ordinary skill in the art can obtain other drawings from them without creative effort.
Fig. 1 is the first schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
Fig. 2 is the second schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
Fig. 3 is the third schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
Fig. 4 is the fourth schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
Fig. 5 is the fifth schematic flowchart of the model training method based on a private data set provided by an embodiment of the present application;
Fig. 6 is a schematic structural diagram of the model training device based on a private data set provided by an embodiment of the present application;
Fig. 7 is a schematic structural diagram of an electronic device provided by an embodiment of the present application.
Detailed Description
To make the purpose, technical solutions, and advantages of this application clearer, the technical solutions in this application are described clearly and completely below with reference to the accompanying drawings. Obviously, the described embodiments are some, but not all, of the embodiments of this application. Based on the embodiments in this application, all other embodiments obtained by a person of ordinary skill in the art without creative effort fall within the protection scope of this application.
A large amount of information data floods industries such as economy, culture, education, medical care, and public administration, and data processing such as data analysis, data mining, and trend prediction is widely applied to it in more and more scenarios. Through data cooperation, multiple data owners can obtain better data processing results; for example, more accurate model parameters can be obtained through joint training on multi-party data.
In some embodiments, a joint model training system based on private data can be applied to a scenario in which multiple parties cooperate to train a machine learning model for shared use while the data security of all parties is guaranteed. In this scenario, multiple data parties have their own data and want to jointly use each other's data for unified modeling (for example, a linear regression model, a logistic regression model, etc.), but do not want their own data (especially private data) to be leaked. For example, hospital A has a batch of patient data (such as photos of patients' diseased parts) that is not suitable for public disclosure because of patient privacy concerns, and hospital B likewise has a batch of patient data that is not suitable for public disclosure; a training sample set determined from the patient data of hospital A and hospital B could be trained into a relatively good machine learning model. Both hospitals are willing to participate in joint model training using each other's patient data, but hospital A and hospital B must ensure that patient data is not leaked, and they cannot, or are unwilling to, let the other party see their patient data. Therefore, a model training method based on private data sets is needed that allows the private data of multiple parties to be jointly trained into a shared machine learning model without being leaked, achieving a win-win cooperation. On this basis, the embodiments of the present application provide a model training method and device based on private data sets, built on knowledge distillation and federated learning.
In the traditional federated learning setting, clients upload model parameters or model gradients to a central server; the server aggregates them in some form and distributes them back to the clients, which then update further on their local data. Transferring parameters or gradients brings a series of privacy, heterogeneity, and communication cost issues, and some existing work addresses this by using knowledge distillation to transfer knowledge between the clients and the server. However, because clients are in practice resource-constrained, large models cannot be used directly on the client, so the resource constraint remains a major challenge. Only by mining the server-side computing resources as much as possible, and using an auxiliary large model on the server to transfer and accumulate knowledge, can the same knowledge-fusion effect be achieved as centralized training with a large model.
Fig. 1 is one of the schematic flowcharts of the model training method based on private data sets provided by an embodiment of the present application. As shown in Fig. 1, the method includes:
Step 110: train the server-side model based on a public data set and the real labels corresponding to the public data set.
Here, the public data set and the private data sets contain the same type of data; the only difference is that the public data set may be disclosed, whereas the private data sets cannot, or should not, be disclosed. Specifically, the public and private data sets may be image data, text data, or sound data related to entities, for example, patient disease images held by hospitals or user data held by Internet companies. The server-side model is a large model, that is, a relatively complex model able to mine and learn as much knowledge as possible.
Step 120: obtain the first model output sent by each client, where the first model output is obtained by the client inputting the public data set into its local learning model, and the local learning model is obtained by the client training a preset model based on its private data set and the corresponding labels.
Step 130: retrain the server-side model based on each first model output and the public data corresponding to it.
Step 140: input the public data set into the retrained server-side model to obtain a second model output.
Step 150: send the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
By using the public data set, the first model outputs, and the second model output as the channel and medium for exchanging information between the local learning models and the server-side model, the method fully exploits the server-side model's capacity for independent training: knowledge distillation and knowledge fusion are performed on the first model outputs, and the fused knowledge is then sent back to each local learning model via the second model output, so that every local learning model receives the fused knowledge. In other words, with the public data set and the two kinds of model outputs acting as the medium of knowledge transfer, all knowledge is stored in one powerful model (the server-side model) that serves as a universal knowledge base supporting federated learning. The server-side model not only uses its ample computing resources to train itself but also treats all the clients as multiple teachers from which to learn, further improving its own performance. In return, the knowledge accumulated on the server side is passed back to the clients to improve every client's local learning model, so that each locally trained model ultimately contains the knowledge of the multiple private data sets; that is, each local learning model is effectively trained on multi-party private data.
In the solution provided by the embodiments of the present application, the server-side model is the center of knowledge aggregation, and the knowledge it learns directly determines the knowledge that each local learning model finally acquires from the second model output; the training of the server-side model is therefore a particularly important part of the method.
Specifically, step 110 ("train the server-side model based on the public data set and the real labels corresponding to the public data set") and step 130 ("retrain the server-side model based on each first model output and the public data corresponding to it") form the training part of the server-side model:
The training of the server-side model divides into three parts: preliminary training, self-distillation, and aggregation distillation (retraining). It should be noted that these three parts are not executed in strict chronological order; rather, they are interleaved with one another.
Referring to Fig. 2, the steps of preliminary training and self-distillation are as follows:
Step 111: input the public data set into the server-side model to obtain prediction results.
Step 112: train the server-side model based on the cross-entropy loss function between the prediction results and the real labels.
This part of the training is conventional: the server-side model is simply trained with the cross-entropy loss between its predictions and the real labels. Existing training embodiments may be consulted for the details.
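As one possible illustration, a minimal PyTorch sketch of steps 111 and 112 is given below; the model, data loader, and optimizer are placeholders assumed for the example rather than prescribed by this application:

import torch.nn.functional as F

def preliminary_train(server_model, public_loader, optimizer):
    """One epoch of supervised training on the public data set (steps 111-112)."""
    server_model.train()
    for x, y in public_loader:             # public samples and their real labels
        optimizer.zero_grad()
        logits = server_model(x)           # step 111: prediction results
        loss = F.cross_entropy(logits, y)  # step 112: cross-entropy against real labels
        loss.backward()
        optimizer.step()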
Step 113: determine and store the first target model output, where the first target model output is the model output corresponding to target public data, and the target public data is the public data in the public data set whose prediction result, obtained after input into the server-side model, matches the corresponding real label.
Specifically, the storage performed in step 113 prepares for the self-distillation of steps 114, 115, and 116: the model outputs that are predicted correctly in the current round (the first target model outputs) are saved into a global record of model outputs as a memory, which later helps correct samples that are predicted wrongly now but were once answered correctly.
Self-distillation proceeds as follows. For samples in the public data set that the model currently predicts wrongly (the target public data to be distilled), we first check whether the global memory of model outputs (the stored first target model outputs) contains an output for the sample. If it does, the model once possessed this piece of knowledge, so revisiting knowledge it previously mastered can help the model correct its own mistakes. Following this idea, we distill the model against itself: the current model output is trained to approach the model output recorded when the sample was answered correctly, combined with the cross-entropy loss. The specific steps are as follows:
Step 114: determine the target public data to be distilled, i.e., the public data in the public data set whose prediction result after input into the server-side model does not match the corresponding real label.
Step 115: determine the first public data to be distilled, i.e., the portion of the target public data to be distilled that has a corresponding first target model output.
Step 116: train the server-side model based on the first public data to be distilled and the first target model output corresponding to it.
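A hedged sketch of this self-distillation update, under the same PyTorch assumptions as above; the temperature T and the mixing weight alpha are illustrative choices that the application leaves open:

import torch.nn.functional as F

def self_distill_step(server_model, x, y, remembered_logits, optimizer, T=2.0, alpha=0.5):
    """Pull the current output toward a stored correct output (step 116)."""
    # x, y: a currently misclassified public sample and its real label
    # remembered_logits: the first target model output stored for this sample
    optimizer.zero_grad()
    logits = server_model(x)
    kd = F.kl_div(F.log_softmax(logits / T, dim=-1),
                  F.softmax(remembered_logits / T, dim=-1),
                  reduction="batchmean") * (T * T)  # approach the remembered output
    ce = F.cross_entropy(logits, y)                 # combined with cross-entropy loss
    loss = alpha * kd + (1.0 - alpha) * ce
    loss.backward()
    optimizer.step()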
The self-distillation of the server-side model is completed in the manner described above. Aggregation distillation is then performed on the remainder of the target public data to be distilled, i.e., everything other than the first public data to be distilled. Aggregation distillation is a core element of the solution of the embodiments of this application and is mainly used to acquire the knowledge contained in the private data sets of the other clients.
Referring to Fig. 3, before aggregation distillation can be performed, the operation of step 120 ("obtain the first model output sent by each client") must be executed. The specific steps are as follows:
Step 121: each client trains the preset model based on its private data set and the corresponding labels to obtain its local learning model.
Step 122: each client inputs the public data set into its local learning model to obtain a model output.
Step 123: determine the second public data to be distilled, i.e., the portion of the target public data to be distilled that has no corresponding first target model output.
Step 124: send a request to each client.
Step 125: each client returns its first model output, where the first model output is the portion of the local learning model's outputs that corresponds to the second public data to be distilled.
Step 126: receive the first model output returned by each client.
With this arrangement, the data sent by each client is exactly the data used for aggregation distillation: the knowledge contained in each local learning model trained on a private data set is conveyed to the server-side model through these first model outputs. This not only avoids leaking private data during the transfer of knowledge but also reduces the amount of data that must be transmitted.
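On the client side, answering such a request might look like the following sketch; sample_ids and the returned dictionary are hypothetical message formats introduced only for illustration:

import torch

@torch.no_grad()
def answer_server_request(local_model, public_dataset, sample_ids):
    """Return local outputs only for the requested public samples (steps 124-126)."""
    local_model.eval()
    outputs = {}
    for i in sample_ids:                    # indices of the second public data to be distilled
        x = public_dataset[i].unsqueeze(0)  # one public sample as a batch of one
        outputs[i] = local_model(x).squeeze(0)
    return outputs                          # this client's first model output

Because only the outputs for the requested samples are returned, no raw private data leaves the client and the upload stays small.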
For samples that the server side has never predicted correctly (the second public data to be distilled), we consider that the server does not yet have the ability to predict them correctly on its own, so we choose to aggregate knowledge from the clients to guide the server's learning. First, from all the clients we select the models that predict the correct answer; then we weight their outputs by the information entropy of each output, on the principle that the higher the information entropy, the lower the corresponding confidence.
Specifically, referring to Fig. 4, aggregation distillation mainly includes the following steps:
Step 131: filter the first model outputs to obtain the second target model outputs, i.e., the portion of the first model outputs whose corresponding prediction results match the corresponding real labels.
The purpose of this step is to eliminate those first model outputs that cannot serve as good teachers for training on the second public data to be distilled.
Step 132: determine the third public data to be distilled, i.e., the portion of the second public data to be distilled that has corresponding second target model outputs.
Step 133: determine the information entropy of each model output among the second target model outputs.
Step 134: determine the weight of each model output among the second target model outputs based on the magnitude of its information entropy.
Step 135: fuse the second target model outputs based on these weights to obtain a third target model output.
Step 136: retrain the server-side model based on the third public data to be distilled and the third target model output; that is, distill the server side using the weighted, fused model output combined with the cross-entropy loss.
In steps 133 to 136, the server-side model is retrained based on the third public data to be distilled and the second target model outputs. During the fusion itself, each output is weighted according to its information entropy, a higher entropy being taken to indicate a lower confidence, so that knowledge is fused selectively. Steps 140 and 150 are then executed to complete the retraining of the local learning models.
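One way to realize the entropy-based weighting of steps 133 to 135 is sketched below; normalizing the weights with a softmax over negative entropies is an assumption of this sketch, since the application only fixes the principle that higher entropy means lower weight:

import torch.nn.functional as F

def entropy_weighted_fusion(teacher_logits):
    """Fuse the second target model outputs for one sample (steps 133-135)."""
    # teacher_logits: torch tensor of shape (num_teachers, num_classes)
    probs = F.softmax(teacher_logits, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)  # step 133
    weights = F.softmax(-entropy, dim=0)  # step 134: higher entropy -> lower weight
    fused = (weights.unsqueeze(-1) * probs).sum(dim=0)  # step 135
    return fused                          # the third target model output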
On the basis of the above scheme, the embodiments of the present application provide a novel method that uses selective knowledge fusion to store all knowledge in one powerful model serving as a universal knowledge base for federated learning. The server-side model not only uses ample computing resources to train itself but also learns from all the clients as multiple teachers, further improving its own performance. In return, the knowledge accumulated on the server side is passed on to the clients to improve the local learning models of all clients. At the same time, the approach increases the robustness of the models on both ends and reduces the communication cost of uploading knowledge from the clients to the server.
The solution provided by the embodiments of the present application is described below with reference to a specific embodiment.
Referring to Fig. 5, the model training system based on private data sets includes one server side and multiple clients (represented in Fig. 5 by hospital A and hospital B).
Step 501: hospital A trains the preset model based on private data set A and the corresponding labels to obtain local learning model A.
Step 502: hospital B trains the preset model based on private data set B and the corresponding labels to obtain local learning model B.
Step 503: the server side trains the server-side model based on the public data set and the real labels corresponding to the public data set.
Here, private data set A, private data set B, and the public data set consist of pictures of patients' injured areas; the main objective of this embodiment is to obtain a model that can identify injuries and make predictions about them.
Step 504: input the public data set into the server-side model for prediction.
Step 505: save the correctly predicted model outputs into the global memory of model outputs.
Step 506: based on the global memory of model outputs, perform self-distillation on part of the wrongly predicted samples.
Step 507: obtain model output A sent by hospital A.
Step 508: obtain model output B sent by hospital B.
Here, model output A is obtained by inputting the public data set into local learning model A, and model output B is obtained by inputting the public data set into local learning model B.
Step 509: perform the elimination operation and weighted fusion on model output A and model output B.
Step 510: based on the fused model output, perform aggregation distillation on part of the wrongly predicted samples.
It should be noted that the data undergoing aggregation distillation may comprise multiple pictures, with model output A and model output B containing a model output for every picture to be aggregated and distilled; fusion and elimination should therefore proceed picture by picture. That is, first select one picture for aggregation distillation, then obtain the model outputs of hospital A and hospital B corresponding to that picture, and judge whether the prediction results given by the two model outputs match the real label; if they do, determine the information entropy of the two outputs and weight them accordingly, taking a higher information entropy to indicate a lower confidence.
Step 511: input the public data set into the server-side model for prediction to obtain the second model output.
Step 512: send the second model output to hospital A.
Step 513: train local learning model A based on the second model output.
Step 514: send the second model output to hospital B.
Step 515: train local learning model B based on the second model output.
This cycle repeats: selective knowledge fusion stores all knowledge in one powerful model that serves as a universal knowledge base for federated learning, and that knowledge is then passed to the client hospitals A and B to help improve local learning models A and B. Hospital A and hospital B thereby train jointly without disclosing their own private data sets, each obtaining a local learning model (A and B respectively) with better practical prediction performance.
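A sketch of the client-side retraining in steps 513 and 515 (and step 150 generally) is given below; distilling against the server's second model output with a KL loss at temperature T is one plausible reading, and second_outputs is a hypothetical container for the per-sample server logits:

import torch.nn.functional as F

def client_retrain(local_model, public_loader, second_outputs, optimizer, T=2.0):
    """Retrain a local learning model against the server's second model output."""
    local_model.train()
    for i, (x, _) in enumerate(public_loader):  # labels unused; the server output teaches
        optimizer.zero_grad()
        logits = local_model(x)
        loss = F.kl_div(F.log_softmax(logits / T, dim=-1),
                        F.softmax(second_outputs[i] / T, dim=-1),
                        reduction="batchmean") * (T * T)
        loss.backward()
        optimizer.step()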
Based on any of the above embodiments, Fig. 6 is a schematic structural diagram of the model training apparatus based on private data sets provided by an embodiment of the present application. As shown in Fig. 6, the apparatus includes:
a first training unit 61, configured to train the server-side model based on a public data set and the real labels corresponding to the public data set;
an obtaining unit 62, configured to obtain the first model output sent by each client, where the first model output is obtained by the client inputting the public data set into its local learning model, and the local learning model is obtained by the client training a preset model based on its private data set and the corresponding labels;
a second training unit 63, configured to retrain the server-side model based on each first model output and the public data corresponding to it;
an input unit 64, configured to input the public data set into the retrained server-side model to obtain a second model output; and
a sending unit 65, configured to send the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
The first training unit 61 is specifically configured to:
input the public data set into the server-side model to obtain prediction results;
train the server-side model based on the cross-entropy loss function between the prediction results and the real labels;
determine and store the first target model output, the first target model output being the model output corresponding to target public data, and the target public data being the public data in the public data set whose prediction result after input into the server-side model matches the corresponding real label;
determine the target public data to be distilled, i.e., the public data in the public data set whose prediction result after input into the server-side model does not match the corresponding real label;
determine the first public data to be distilled, i.e., the portion of the target public data to be distilled that has a corresponding first target model output; and
train the server-side model based on the first public data to be distilled and the first target model output corresponding to it.
Obtaining the first model output sent by each client includes:
determining the second public data to be distilled, i.e., the portion of the target public data to be distilled that has no corresponding first target model output;
sending a request to each client, the request being used to ask the client to return its first model output, where the first model output is the portion of each local learning model's outputs corresponding to the second public data to be distilled; and
receiving the first model output returned by each client.
Optionally, the second training unit 63 is specifically configured to:
filter the first model outputs to obtain the second target model outputs, the second target model outputs being the portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
determine the third public data to be distilled, i.e., the portion of the second public data to be distilled that has corresponding second target model outputs; and
retrain the server-side model based on the third public data to be distilled and the second target model outputs.
Optionally, retraining the server-side model based on the third public data to be distilled and the second target model outputs includes:
determining the information entropy of each model output among the second target model outputs;
determining the weight of each model output among the second target model outputs based on the magnitude of its information entropy;
fusing the second target model outputs based on the weights to obtain a third target model output; and
retraining the server-side model based on the third public data to be distilled and the third target model output.
Optionally, the public data set and the private data sets include image data, text data, or sound data related to entities.
Fig. 7 is a schematic structural diagram of the electronic device provided by an embodiment of the present application. As shown in Fig. 7, the electronic device may include a processor 710, a communications interface 720, a memory 730, and a communication bus 740, where the processor 710, the communications interface 720, and the memory 730 communicate with one another through the communication bus 740. The processor 710 may invoke logic instructions in the memory 730 to execute the following method: training the server-side model based on a public data set and the real labels corresponding to the public data set; obtaining the first model output sent by each client, where the first model output is obtained by the client inputting the public data set into its local learning model, and the local learning model is obtained by the client training a preset model based on its private data set and the corresponding labels; retraining the server-side model based on each first model output and the public data corresponding to it; inputting the public data set into the retrained server-side model to obtain a second model output; and sending the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
In addition, the logic instructions in the memory 730 may be implemented in the form of software functional units and, when sold or used as an independent product, may be stored in a computer-readable storage medium. On the basis of this understanding, the technical solution of the present application, in essence or in the part contributing to the prior art, may be embodied in the form of a software product. The computer software product is stored in a storage medium and includes several instructions for causing a computer device (which may be a personal computer, a server, a network device, or the like) to execute all or part of the steps of the methods described in the embodiments of the present application. The aforementioned storage medium includes any medium capable of storing program code, such as a USB flash drive, a removable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disc.
An embodiment of the present application further provides a non-transitory computer-readable storage medium on which a computer program is stored. When executed by a processor, the computer program performs the methods provided by the above embodiments, for example: training the server-side model based on a public data set and the real labels corresponding to the public data set; obtaining the first model output sent by each client, where the first model output is obtained by the client inputting the public data set into its local learning model, and the local learning model is obtained by the client training a preset model based on its private data set and the corresponding labels; retraining the server-side model based on each first model output and the public data corresponding to it; inputting the public data set into the retrained server-side model to obtain a second model output; and sending the second model output to each client, so that each client retrains its local learning model based on the second model output and the public data set.
The apparatus embodiments described above are merely illustrative. Units described as separate components may or may not be physically separate, and components shown as units may or may not be physical units; that is, they may be located in one place or distributed over multiple network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment, which a person of ordinary skill in the art can understand and implement without creative effort.
From the above description of the implementations, a person skilled in the art can clearly understand that each implementation may be realized by means of software plus the necessary general-purpose hardware platform, or of course by hardware. On the basis of this understanding, the above technical solution, in essence or in the part contributing to the prior art, may be embodied in the form of a software product. The computer software product may be stored in a computer-readable storage medium, such as a ROM/RAM, a magnetic disk, or an optical disc, and includes several instructions for causing a computer device (which may be a personal computer, a server, a network device, or the like) to execute the methods described in the embodiments or in certain parts of the embodiments.
Finally, it should be noted that the above embodiments are intended only to illustrate the technical solutions of the present application, not to limit them. Although the present application has been described in detail with reference to the foregoing embodiments, a person of ordinary skill in the art should understand that the technical solutions recorded in the foregoing embodiments may still be modified, or some of their technical features may be replaced by equivalents, without such modifications or replacements causing the essence of the corresponding technical solutions to depart from the spirit and scope of the technical solutions of the embodiments of the present application.

Claims (10)

  1. A model training method based on private data sets, comprising:
    training a server-side model based on a public data set and real labels corresponding to the public data set;
    obtaining a first model output sent by each client, wherein the first model output is obtained by the client inputting the public data set into a local learning model, and the local learning model is obtained by the client training a preset model based on a private data set and corresponding labels;
    retraining the server-side model based on public data corresponding to each first model output and each first model output;
    inputting the public data set into the retrained server-side model to obtain a second model output; and
    sending the second model output to each client, so that each client retrains the local learning model based on the second model output and the public data set.
  2. The model training method based on private data sets according to claim 1, wherein the training of the server-side model based on the public data set and the real labels corresponding to the public data set comprises:
    inputting the public data set into the server-side model to obtain prediction results; and
    training the server-side model based on a cross-entropy loss function between the prediction results and the real labels.
  3. The model training method based on private data sets according to claim 2, wherein the training of the server-side model based on the public data set and the real labels corresponding to the public data set further comprises:
    determining and storing a first target model output, the first target model output being a model output corresponding to target public data, and the target public data being public data in the public data set whose prediction result obtained after input into the server-side model matches the corresponding real label;
    determining target public data to be distilled, the target public data to be distilled being public data in the public data set whose prediction result obtained after input into the server-side model does not match the corresponding real label;
    determining first public data to be distilled, the first public data to be distilled being a portion of the target public data to be distilled that has a corresponding first target model output; and
    training the server-side model based on the first public data to be distilled and the first target model output corresponding to the first public data to be distilled.
  4. The model training method based on private data sets according to claim 3, wherein the obtaining of the first model output sent by each client comprises:
    determining second public data to be distilled, the second public data to be distilled being a portion of the target public data to be distilled that has no corresponding first target model output;
    sending a request to each client, the request being used to ask the client to return the first model output, wherein the first model output is a portion of the model outputs of each local learning model corresponding to the second public data to be distilled; and
    receiving the first model output returned by each client.
  5. The model training method based on private data sets according to claim 4, wherein the retraining of the server-side model based on the public data corresponding to each first model output and each first model output comprises:
    filtering the first model outputs to obtain second target model outputs, the second target model outputs being a portion of the first model outputs whose corresponding prediction results match the corresponding real labels;
    determining third public data to be distilled, the third public data to be distilled being a portion of the second public data to be distilled that has corresponding second target model outputs; and
    retraining the server-side model based on the third public data to be distilled and each second target model output.
  6. The model training method based on private data sets according to claim 5, wherein the retraining of the server-side model based on the third public data to be distilled and each second target model output comprises:
    determining an information entropy of each model output among the second target model outputs;
    determining a weight of each model output among the second target model outputs based on the magnitude of the information entropy;
    fusing the second target model outputs based on the weights to obtain a third target model output; and
    retraining the server-side model based on the third public data to be distilled and the third target model output.
  7. The model training method based on private data sets according to claim 1, wherein the public data set and the private data sets comprise image data, text data, or sound data related to entities.
  8. A model training apparatus based on private data sets, comprising:
    a first training unit, configured to train a server-side model based on a public data set and real labels corresponding to the public data set;
    an obtaining unit, configured to obtain a first model output sent by each client, wherein the first model output is obtained by the client inputting the public data set into a local learning model, and the local learning model is obtained by the client training a preset model based on a private data set and corresponding labels;
    a second training unit, configured to retrain the server-side model based on public data corresponding to each first model output and each first model output;
    an input unit, configured to input the public data set into the retrained server-side model to obtain a second model output; and
    a sending unit, configured to send the second model output to each client, so that each client retrains the local learning model based on the second model output and the public data set.
  9. An electronic device, comprising a memory, a processor, and a computer program stored in the memory and executable on the processor, wherein the processor, when executing the program, implements the steps of the model training method based on private data sets according to any one of claims 1 to 7.
  10. A non-transitory computer-readable storage medium on which a computer program is stored, wherein the computer program, when executed by a processor, implements the steps of the model training method based on private data sets according to any one of claims 1 to 7.
PCT/CN2022/085131 2021-09-30 2022-04-02 Model training method and apparatus for private data set WO2023050754A1 (en)

Applications Claiming Priority (4)

Application Number Priority Date Filing Date Title
CN202111165679 2021-09-30
CN202111165679.6 2021-09-30
CN202111189306.2 2021-10-12
CN202111189306.2A CN114003949B (en) 2021-09-30 2021-10-12 Model training method and device based on private data set

Publications (1)

Publication Number Publication Date
WO2023050754A1 true WO2023050754A1 (en) 2023-04-06

Family

ID=79922769

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2022/085131 WO2023050754A1 (en) 2021-09-30 2022-04-02 Model training method and apparatus for private data set

Country Status (2)

Country Link
CN (1) CN114003949B (en)
WO (1) WO2023050754A1 (en)

Families Citing this family (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114003949B (en) * 2021-09-30 2022-08-30 清华大学 Model training method and device based on private data set
CN115238826B (en) * 2022-09-15 2022-12-27 支付宝(杭州)信息技术有限公司 Model training method and device, storage medium and electronic equipment
CN115270001B (en) * 2022-09-23 2022-12-23 宁波大学 Privacy protection recommendation method and system based on cloud collaborative learning
CN115578369B (en) * 2022-10-28 2023-09-15 佐健(上海)生物医疗科技有限公司 Online cervical cell TCT slice detection method and system based on federal learning
CN116797829A (en) * 2023-06-13 2023-09-22 北京百度网讯科技有限公司 Model generation method, image classification method, device, equipment and medium

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111428881B (en) * 2020-03-20 2021-12-07 深圳前海微众银行股份有限公司 Recognition model training method, device, equipment and readable storage medium
CN112329052A (en) * 2020-10-26 2021-02-05 哈尔滨工业大学(深圳) Model privacy protection method and device
CN112862011A (en) * 2021-03-31 2021-05-28 中国工商银行股份有限公司 Model training method and device based on federal learning and federal learning system

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210272011A1 (en) * 2020-02-27 2021-09-02 Omron Corporation Adaptive co-distillation model
CN113052334A (en) * 2021-04-14 2021-06-29 中南大学 Method and system for realizing federated learning, terminal equipment and readable storage medium
CN113222175A (en) * 2021-04-29 2021-08-06 深圳前海微众银行股份有限公司 Information processing method and system
CN114003949A (en) * 2021-09-30 2022-02-01 清华大学 Model training method and device based on private data set

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117313869A (en) * 2023-10-30 2023-12-29 浙江大学 Large model privacy protection reasoning method based on model segmentation
CN117313869B (en) * 2023-10-30 2024-04-05 浙江大学 Large model privacy protection reasoning method based on model segmentation

Also Published As

Publication number Publication date
CN114003949A (en) 2022-02-01
CN114003949B (en) 2022-08-30

Similar Documents

Publication Publication Date Title
WO2023050754A1 (en) Model training method and apparatus for private data set
CN110189192B (en) Information recommendation model generation method and device
US11138521B2 (en) System and method for defining and using different levels of ground truth
US20230039182A1 (en) Method, apparatus, computer device, storage medium, and program product for processing data
CN112085159B (en) User tag data prediction system, method and device and electronic equipment
WO2016090326A1 (en) Intent based digital collaboration platform architecture and design
CN111125760B (en) Model training and predicting method and system for protecting data privacy
Martínez-Villaseñor et al. Enrichment of learner profile with ubiquitous user model interoperability
Yin Research and analysis of intelligent English learning system based on improved neural network
US11847421B2 (en) Discussion support device and program for discussion support device
CN110377827B (en) Course training scene pushing method and device, medium and electronic equipment
Hodhod et al. Cybersecurity curriculum development using ai and decision support expert system
KR101429446B1 (en) System for creating contents and the method thereof
CN112241417B (en) Page data verification method and device, medium and electronic equipment
US20230351153A1 (en) Knowledge graph reasoning model, system, and reasoning method based on bayesian few-shot learning
CN116431915A (en) Cross-domain recommendation method and device based on federal learning and attention mechanism
CN109118151B (en) Work order transaction processing method and work order transaction processing system
CN114528392A (en) Block chain-based collaborative question-answering model construction method, device and equipment
Li Design and implementation of mental health consultation system for primary and secondary school students based on credibility matching model
He et al. Design of shared Internet of Things system for English translation teaching using deep learning text classification
WO2019169422A1 (en) Knowledge management system
Zhao et al. Construction of Higher Education Management Data Analysis Model Based on Association Rules
WO2023273237A1 (en) Model compression method and system, electronic device, and storage medium
Mejia A New Proposal for Virtual Academic Advisories Using ChatBots
Bushell Transition for Transformation for Sustainable Automation

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 22874155

Country of ref document: EP

Kind code of ref document: A1