CN116882480A - Diffusion model driven unsupervised domain generalization method for privacy protection - Google Patents
- Publication number: CN116882480A
- Application number: CN202311013570.XA
- Authority: CN (China)
- Prior art keywords: domain, model, client, target, virtual
- Legal status: Pending (as listed by Google Patents; an assumption, not a legal conclusion)
Classifications
- G06N3/088: Non-supervised learning, e.g. competitive learning
- G06F18/214: Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/241: Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F21/6245: Protecting personal data, e.g. for financial or medical purposes
- G06N3/098: Distributed learning, e.g. federated learning
- G06N5/02: Knowledge representation; Symbolic representation
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Bioinformatics & Computational Biology (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Molecular Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioethics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
The invention provides a diffusion-model-driven, privacy-preserving unsupervised domain generalization method. The method comprises the following steps. A target server sends a trained diffusion model and an initialized global model to each client. Each client samples virtual target-domain data from the diffusion model, extracts domain-specific features and domain-shared features, and uploads the domain-shared features and its trained client model to the target server in the target domain. The target server performs federated confidence voting on target-domain samples through the client models to generate a virtual prediction domain. It then dynamically adjusts the weight of each client model according to that client's contribution to the virtual prediction domain, and uses the virtual prediction domain produced by federated confidence voting to obtain the global model that is distributed in the next federated round. Because it uses a diffusion model, the invention better protects the privacy of target-domain data, is sufficiently general, and reduces communication overhead.
Description
Technical Field
The invention relates to the technical field of computer applications, and in particular to a diffusion-model-driven, privacy-preserving unsupervised domain generalization method.
Background
Federated learning is a distributed learning paradigm that lets multiple participants collaboratively train a machine learning model without sharing their local data. This protects their data privacy and reduces communication overhead. However, federated learning also faces challenges. Participants may differ in data characteristics, labels, or quality, so the model may perform inconsistently across settings and may even suffer negative transfer. Furthermore, real-world data volumes are large, and domain experts face a heavy labeling workload. This leads to labeling errors, missing labels, or entirely unlabeled data, which can degrade a participant's model or make it impossible to train.
Unsupervised federated domain generalization has therefore been proposed to address these issues. It does not rely on any label information in the target domain. Instead, it uses unsupervised learning techniques to align and transfer knowledge between the source domains and the target domain. Unsupervised federated domain generalization can effectively address data heterogeneity and missing labels in a federated learning environment. It enhances the adaptability and robustness of a model across multiple domains and enables the model to generalize to an unknown target domain. It can be applied in fields such as healthcare, finance, and the Internet of Things, where participants can use it to improve model performance and efficiency while protecting data privacy. For example, in medical image segmentation, image quality and segmentation accuracy may differ across the equipment of different hospitals, and the data may exhibit distribution shift, so simple average aggregation cannot be applied directly. Some hospitals may also lack the resources and expertise to classify imaging cases. They therefore need to learn from other hospitals while protecting their own data.
Unsupervised federated domain generalization is therefore an important but technically difficult research direction. First, the federated model must be robust and able to adapt to unknown target domains whose data distributions and characteristics may differ from those of the source domains. Second, it must protect the data privacy of both the source and target domains, which means the federated model cannot access or share any raw data during training or inference. Third, it must cope with the heterogeneity and diversity of the source domains, which may include noisy, biased, or even malicious data that can degrade the federated model. Fourth, it requires a general algorithmic framework that covers a wide range of real-world applications, such as cross-device image recognition, cross-lingual natural language processing, and cross-platform speech recognition.
Currently, most prior-art domain generalization methods rely on adversarial training for representation alignment. One prior-art privacy-preserving federated domain generalization method is the FedKA algorithm. FedKA uses feature distribution matching in a global workspace, so that the global model can learn domain-invariant client features without access to client data. FedKA also uses a federated voting mechanism: it generates pseudo-labels for the target domain from client consensus and uses them to fine-tune the global model.
The FedKA algorithm proceeds as follows:
Step 1: The target domain initializes the global model and distributes it to each client.
Step 2: Each client uses its local data together with target-domain data to extract domain-specific features and domain-common features, and computes its latest local model parameters.
Step 3: Each client uploads its domain-common features to the target domain. The target domain aligns these features with a feature distribution matching method to obtain a unified global feature space.
Step 4: The target server aggregates the model parameters of all clients and averages them to obtain the latest model parameters.
Step 5: Each client uses federated voting to generate pseudo-labels for target-domain input samples from its locally trained model. The target server uses these pseudo-labels to fine-tune the target-domain model.
Step 6: All clients start a new round of federated domain generalization.
The FedKA-based privacy-preserving federated domain generalization method described above has the following disadvantages:
When FedKA extracts domain-specific and domain-common features, the clients do not upload their local data. However, they do need to access target-domain data, which leaks the data privacy of the target domain.
In each round of federated domain generalization, target-domain feature distribution matching uses the multi-kernel maximum mean discrepancy (MK-MMD). This requires tuning the number of kernels and their distribution for each task. It also requires many rounds of communication between each client and the target domain, which imposes a heavy communication load and readily leads to privacy leakage.
FedKA uses simple average aggregation, so it cannot handle low-quality or maliciously poisoned clients. This results in negative transfer across domains.
Federated voting uses average-probability voting. When low-quality or maliciously poisoned clients are present, their strongly deviating probabilities can produce wrong predicted labels and negative transfer across domains.
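The sketch below shows why MK-MMD matching is costly. It is a minimal NumPy version of a biased squared-MMD estimate built from a sum of Gaussian kernels. The estimator needs the client features and the target features in the same place, and its kernel bandwidths have to be chosen for each task. The fixed, equally weighted bandwidths and the function names are hypothetical simplifications. MK-MMD as used in practice usually also tunes or learns the kernel weights.

```python
import numpy as np

def gaussian_kernel_sum(A, B, bandwidths):
    """Sum of Gaussian kernels exp(-||a - b||^2 / (2 s^2)) over several bandwidths s."""
    d2 = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    d2 = np.maximum(d2, 0.0)  # guard against tiny negative distances from rounding
    return sum(np.exp(-d2 / (2.0 * s**2)) for s in bandwidths)

def mk_mmd2(X, Y, bandwidths=(0.5, 1.0, 2.0, 4.0)):
    """Biased squared MMD between feature sets X (n, d) and Y (m, d); bandwidths are hypothetical."""
    kxx = gaussian_kernel_sum(X, X, bandwidths).mean()
    kyy = gaussian_kernel_sum(Y, Y, bandwidths).mean()
    kxy = gaussian_kernel_sum(X, Y, bandwidths).mean()
    return float(kxx + kyy - 2.0 * kxy)

rng = np.random.default_rng(0)
client_feats = rng.normal(size=(200, 8))            # stand-in for one client's shared features
target_feats = rng.normal(loc=0.5, size=(200, 8))   # stand-in for target-domain features
print(mk_mmd2(client_feats, target_feats))
```

Every evaluation of this discrepancy needs both feature sets together. Repeating it per client and per round is the communication cost that the disadvantage above refers to.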
Disclosure of Invention
The invention provides a diffusion-model-driven, privacy-preserving unsupervised domain generalization method that effectively protects the privacy of target-domain data.
To achieve this purpose, the invention adopts the following technical solution.
A diffusion-model-driven, privacy-preserving unsupervised domain generalization method comprises the following steps:
step S1: training a diffusion model on the target server in the target domain;
step S2: the target server sends the trained diffusion model and the initialized global model to each client;
step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs disentangled representation learning on the virtual target-domain data to extract domain-specific features and domain-shared features, builds a local client model with the classifier in the specific features, and trains the client model;
step S4: each client uploads its domain-shared features and trained client model to the target server in the target domain; the target server performs federated confidence voting on target-domain samples using the domain-shared features, derives a probability weight matrix from the most probable predictions obtained by voting, and generates a virtual prediction domain;
step S5: the target server dynamically adjusts the weight of each client model according to that client's contribution to the virtual prediction domain, and aggregates the updated client models to obtain an updated global model;
step S6: the updated global model is fine-tuned on the virtual prediction domain produced by federated confidence voting, yielding a new model that is distributed in the next federated round and used to predict labels of target-domain samples.
Preferably, step S1 (training a diffusion model on the target server in the target domain) comprises:
training a diffusion model U on a target server θ From distribution ofRandom extraction data-> Gradually adding noise to the data>Until after calculation of T, obtain +.> Diffusion model U θ Described by a markov chain, i.e. t=1, …, T, where β t Is a parameter for linear interpolation from 0.0001 to 0.02, T is the number of diffusion steps;
the diffusion model is used to generate a virtual data domain with the same distribution as the target-domain dataset: starting from Gaussian noise, the source domain on each client uses the diffusion model to restore data distributed like the target domain.
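The forward noising process of step S1 can be illustrated with the NumPy sketch below. It assumes the standard DDPM noising step x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) eps with eps ~ N(0, I), together with the linear beta schedule from 0.0001 to 0.02 given in the text. The sketch shows the step-by-step Markov chain and the closed-form jump q(x_t | x_0). The function names and the toy data are illustrative, not taken from the patent.

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """beta_t linearly interpolated from 0.0001 to 0.02 over T diffusion steps."""
    return np.linspace(beta_start, beta_end, T)

def forward_step(x_prev, beta_t, rng):
    """One Markov-chain step: x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps, eps ~ N(0, I)."""
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * rng.standard_normal(x_prev.shape)

def q_sample(x0, t, betas, rng):
    """Jump straight to step t (1-indexed): q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I)."""
    alpha_bar_t = np.cumprod(1.0 - betas)[t - 1]     # abar_t = prod_{s<=t} (1 - beta_s)
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps, eps

betas = linear_beta_schedule(1000)
rng = np.random.default_rng(0)
x = np.full(20000, 3.0)          # toy stand-in for flattened target-domain data x_0
for beta_t in betas:             # t = 1, ..., T
    x = forward_step(x, beta_t, rng)
print(x.mean(), x.std())         # close to 0 and 1: x_T is approximately N(0, I)
```

With T = 1000, the product of (1 - beta_t) falls below 1e-4. Almost nothing of x_0 therefore survives in x_T, and this is what allows sampling to start from pure Gaussian noise.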
Preferably, step S3 (after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs disentangled representation learning on it to extract domain-specific and domain-shared features, builds a local client model with the classifier in the specific features, and trains the client model) comprises:
the distributed global model consists of three parts: a domain-shared feature extractor, a label classifier, and a domain classifier. The label classifier comprises the domain-specific features and the domain-independent classifier weights. Each client model samples virtual target-domain data $\tilde{x} \sim \tilde{D}_T$ from the diffusion model and maps the sample data into the feature space through $H: \mathcal{X} \to \mathcal{H}$. This yields the domain-shared features $H(x)$ and the domain-specific features. The label classifier $f_c: \mathcal{H} \to \mathcal{Y}$ forms the local client model, which is then trained; $f_c$ predicts the label class from the features $H(x)$.
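The sketch below outlines the three-part model described above. It assumes tiny fully connected NumPy layers; the layer sizes, the ReLU activations, and the class and method names are hypothetical, not the patent's architecture. It shows which part produces the uploadable domain-shared features H(x), where the domain-specific layer of f_c sits, and what the domain classifier f_d outputs.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def softmax(logits):
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

class DecoupledClientModel:
    """Hypothetical three-part model: shared extractor H, label classifier f_c, domain classifier f_d."""

    def __init__(self, in_dim, feat_dim, spec_dim, n_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.W_h = rng.normal(0.0, 0.1, (in_dim, feat_dim))     # H: domain-shared feature extractor
        self.W_s = rng.normal(0.0, 0.1, (feat_dim, spec_dim))   # f_c part 1: domain-specific layer
        self.W_c = rng.normal(0.0, 0.1, (spec_dim, n_classes))  # f_c part 2: domain-independent weights
        self.w_d = rng.normal(0.0, 0.1, feat_dim)               # f_d: client domain vs. virtual domain

    def shared_features(self, x):
        """H(x): the features a client uploads and aligns."""
        return relu(x @ self.W_h)

    def specific_features(self, x):
        """Domain-specific features, used only for local training."""
        return relu(self.shared_features(x) @ self.W_s)

    def predict_proba(self, x):
        """f_c: class probabilities computed from the features."""
        return softmax(self.specific_features(x) @ self.W_c)

    def domain_proba(self, x):
        """f_d: probability that x comes from the client domain rather than the virtual domain."""
        return 1.0 / (1.0 + np.exp(-(self.shared_features(x) @ self.w_d)))

model = DecoupledClientModel(in_dim=8, feat_dim=16, spec_dim=4, n_classes=5)
x_virtual = np.random.default_rng(1).normal(size=(3, 8))  # stand-in for diffusion-model samples
print(model.predict_proba(x_virtual).shape)               # (3, 5)
```

In this sketch, only the output of `shared_features` would leave the client for alignment. `W_s` stays local, in line with the later statement that domain-specific features serve only local training.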
Preferably, step S4 (each client uploads its domain-shared features and trained client model to the target server in the target domain; the target server performs federated confidence voting on target-domain samples through the client models using the domain-shared features, derives a probability weight matrix from the most probable predictions obtained by voting, and generates a virtual prediction domain) comprises:
each client uploads its domain-shared features $H_k(x)$ and its trained client model $L_k$ to the target server in the target domain, and the target server uses the consensus knowledge of the source domains to build the virtual prediction domain $\hat{D}_{K+2}$.
Preferably, step S5 (the target server dynamically adjusts the weight of each client model according to its contribution to the virtual prediction domain, and aggregates the updated client models into an updated global model) comprises:
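the overall consensus knowledge quality is defined as $CQ(S') = \sum_{x} n_{CK}(x) \cdot \max CK(x)$. Here $S' \subseteq S$ is a subset of the source domains, and $CK(x)$ is the consensus knowledge obtained by voting on target sample $x$. $n_{CK}(x)$ is the number of domains that support $CK(x)$, and $\max CK(x)$ is the largest probability in $CK(x)$, i.e., the probability of the predicted label class. The knowledge contribution of client $k$ is calculated as $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$, and $CC(S_k)$ is used to readjust the weight $\alpha_k$ of client model $L_k$;
the target server then aggregates the updated client models to obtain the updated global model $G$.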
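The consensus-quality and contribution computation above can be sketched as follows. The text does not fully specify how CK and n_CK are formed or how CC(S_k) becomes a weight, so the sketch makes three assumptions, loosely modeled on the knowledge-vote and consensus-focus idea from the KD3A paper. First, client predictions must pass a confidence gate (a hypothetical threshold of 0.9). Second, the class backed by the most confident clients wins, and CK is the mean prediction of its supporters. Third, the weights follow the hypothetical rule alpha_k proportional to N_k * max(CC(S_k), 0), where N_k is the client's sample count.

```python
import numpy as np

def knowledge_vote(P, conf=0.9):
    """Consensus knowledge per target sample for a subset of clients.

    P: (K', N, C) class probabilities. Returns n_ck (N,), the number of supporting clients,
    and ck_max (N,), the max probability of CK (mean of the supporters' predictions).
    """
    Kp, N, C = P.shape
    n_ck, ck_max = np.zeros(N), np.zeros(N)
    if Kp == 0:
        return n_ck, ck_max
    top = P.argmax(axis=2)                      # (K', N) predicted class per client
    confident = P.max(axis=2) >= conf           # confidence gate (hypothetical threshold)
    for n in range(N):
        votes = np.bincount(top[confident[:, n], n], minlength=C)
        if votes.sum() == 0:
            continue                            # no confident client: no consensus for this sample
        c = votes.argmax()                      # class backed by the most clients
        supporters = confident[:, n] & (top[:, n] == c)
        ck = P[supporters, n].mean(axis=0)      # consensus knowledge CK(x_n)
        n_ck[n], ck_max[n] = supporters.sum(), ck.max()
    return n_ck, ck_max

def consensus_quality(P, subset, conf=0.9):
    """CQ(S') = sum over samples of n_CK * max CK, for the clients listed in `subset`."""
    n_ck, ck_max = knowledge_vote(P[list(subset)], conf)
    return float(np.sum(n_ck * ck_max))

def consensus_weights(P, n_samples, conf=0.9):
    """CC(S_k) = CQ(S) - CQ(S \\ {S_k}); alpha_k proportional to N_k * max(CC(S_k), 0) (assumed rule)."""
    K = P.shape[0]
    cq_all = consensus_quality(P, range(K), conf)
    cc = np.array([cq_all - consensus_quality(P, [j for j in range(K) if j != k], conf)
                   for k in range(K)])
    raw = np.asarray(n_samples, dtype=float) * np.clip(cc, 0.0, None)
    if raw.sum() <= 0:
        return np.full(K, 1.0 / K), cc          # fall back to uniform weights
    return raw / raw.sum(), cc

# Two clients agree confidently; a third (e.g. poisoned) client confidently disagrees.
P = np.array([
    [[0.95, 0.05], [0.92, 0.08]],
    [[0.93, 0.07], [0.91, 0.09]],
    [[0.02, 0.98], [0.03, 0.97]],
])
alpha, cc = consensus_weights(P, n_samples=[100, 100, 100])
```

Here, removing the dissenting third client does not lower the consensus quality, so its contribution and its aggregation weight come out as zero. Clipping negative contributions to zero is a design choice of this sketch.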
Preferably, step S6 (the target server performs federated confidence voting on target-domain samples through the updated global model and obtains predicted labels for them) comprises:
the local training of client model $L_k$ on each client is treated as a task, and the global model $G$ aggregated by consensus focus is treated as a pre-trained model. The global model $G$ learns the virtual prediction domain $\hat{D}_{K+2} = \{(x_{K+2}, y_{K+2})\}$ generated by federated confidence voting, where $y_{K+2} = \arg\max_c (P \times \eta)$. $P$ collects the predictive probabilities of the source domains for the target domain, and $\eta$ is the probability weight matrix;
virtual feature learning is performed on the global model $G$ with samples $x_{K+2}$ from the virtual prediction domain $\hat{D}_{K+2}$. The target domain thereby learns the feature distributions of the source domains, and the updated global model is fine-tuned into a new global model that is used in the next federated round and to predict labels of target-domain samples.
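The pseudo-labeling rule y_{K+2} = argmax_c(P × η) above can be read as a weighted vote across clients. The sketch assumes P holds each client model's class probabilities on the unlabeled target samples, with shape (K, N, C). It also assumes that η is a non-negative per-client, per-sample weight of shape (K, N), which is one plausible reading of "probability weight matrix". The function name is hypothetical.

```python
import numpy as np

def virtual_prediction_labels(P, eta):
    """Pseudo-labels y_n = argmax_c sum_k eta[k, n] * P[k, n, c].

    P:   (K, N, C) class probabilities from the K client models on N target samples.
    eta: (K, N) non-negative weights (assumed reading of the probability weight matrix).
    """
    weighted = np.einsum("kn,knc->nc", eta, P)   # weight each client's prediction, sum over clients
    return weighted.argmax(axis=1), weighted

P = np.array([[[0.6, 0.3, 0.1]],     # client 1 on one target sample
              [[0.1, 0.2, 0.7]]])    # client 2 on the same sample
labels, weighted = virtual_prediction_labels(P, np.array([[0.8], [0.2]]))
```

If η is 1/K everywhere, this reduces to the plain average-probability voting criticized in the background. A non-uniform η is what lets the rule discount a client whose probabilities deviate strongly.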
According to the technical solution provided by the embodiments of the invention, the diffusion model protects the privacy of target-domain data, and the method is sufficiently general. Because disentangled representation learning replaces the computation of inter-domain distance minimization, the communication load is greatly reduced.
Additional aspects and advantages of the invention will be set forth in part in the description which follows, and in part will be obvious from the description, or may be learned by practice of the invention.
Drawings
In order to more clearly illustrate the technical solutions of the embodiments of the present invention, the drawings required for the description of the embodiments will be briefly described below, and it is obvious that the drawings in the following description are only some embodiments of the present invention, and other drawings may be obtained according to these drawings without inventive effort for a person skilled in the art.
FIG. 1 is a process flow diagram of the diffusion-model-driven, privacy-preserving unsupervised domain generalization method provided by the invention;
FIG. 2 is a schematic diagram of the virtual feature generation algorithm for the target domain provided by the invention;
FIG. 3 is a schematic diagram of the algorithm for aligning source-domain features with virtual features provided by the invention;
FIG. 4 is a flowchart of the algorithm for dynamic model aggregation provided by the invention;
FIG. 5 is a flowchart of the algorithm for virtual feature learning provided by the invention.
Detailed Description
Embodiments of the present invention are described in detail below, examples of which are illustrated in the accompanying drawings, wherein the same or similar reference numerals refer to the same or similar elements or elements having the same or similar functions throughout. The embodiments described below by referring to the drawings are exemplary only for explaining the present invention and are not to be construed as limiting the present invention.
As used herein, the singular forms "a", "an", and "the" are intended to include the plural forms as well, unless expressly stated otherwise, as understood by those skilled in the art. It will be further understood that the terms "comprises" and/or "comprising," when used in this specification, specify the presence of stated features, integers, steps, operations, elements, and/or components, but do not preclude the presence or addition of one or more other features, integers, steps, operations, elements, components, and/or groups thereof. It will be understood that when an element is referred to as being "connected" or "coupled" to another element, it can be directly connected or coupled to the other element or intervening elements may also be present. Further, "connected" or "coupled" as used herein may include wirelessly connected or coupled. The term "and/or" as used herein includes any and all combinations of one or more of the associated listed items.
It will be understood by those skilled in the art that, unless otherwise defined, all terms (including technical and scientific terms) used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this invention belongs. It will be further understood that terms, such as those defined in commonly used dictionaries, should be interpreted as having a meaning that is consistent with their meaning in the context of the prior art and will not be interpreted in an idealized or overly formal sense unless expressly so defined herein.
To facilitate understanding of the embodiments of the invention, several specific embodiments are explained further below with reference to the drawings. These embodiments in no way limit the embodiments of the invention.
The invention provides a general framework for unsupervised federated domain generalization based on a diffusion model. The framework preserves the data privacy of every domain during domain generalization, reduces the number of communication rounds needed for model aggregation, identifies and filters out low-quality or malicious source domains, and improves the convergence speed and accuracy of the target-domain model. The diffusion model ensures the data privacy of the target domain and the generality of the framework, while federated learning ensures the data privacy of the clients. Virtual feature distribution alignment is performed through disentangled representation learning, which reduces the communication load of the MK-MMD computation module. In addition, negative transfer across domains is mitigated by dynamically adjusting each client's knowledge contribution to the virtual prediction domain produced by federated confidence voting. To further accelerate adaptation, the target-domain model learns the virtual prediction domain and is fine-tuned accordingly.
The processing flow of the diffusion-model-driven, privacy-preserving unsupervised domain generalization method provided by the invention is shown in FIG. 1 and comprises the following steps:
Step S1: a diffusion model is trained on the target server in the target domain, using the unlabeled data of the target domain. For example, for the (mt, mm, sv, sy) → up generalization task on the Digit-5 dataset, the diffusion model is trained on the unlabeled up image dataset.
The diffusion model generates a virtual data domain distributed like the target-domain dataset. With the diffusion model, the source domain on each client can restore data distributed similarly to the target domain from Gaussian noise. This keeps the raw target-domain data protected while still allowing feature distribution alignment.
Step S2: the target server sends the trained diffusion model and the initialized global model to each client. The global model contains no data; it is the classification model to be trained. The global model is distributed to the clients in every round, and the result of aggregation is again a global model. The diffusion model only needs to be distributed once.
Step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs disentangled representation learning on the virtual target-domain data to extract domain-specific features and domain-shared features, builds a local client model with the classifier in the specific features, and trains the client model.
The client model is the client's local classification model, obtained after the global classification model is distributed to the source domain of each client. It extracts features and predicts classification labels. The process is shown in the algorithm tables of FIG. 2 and FIG. 3.
Step S4: each client uploads its domain-shared features and trained client model to the target server in the target domain, which generates a virtual dataset by voting.
The target server performs federated confidence voting on target-domain samples through the client models. It predicts the unlabeled target-domain data and derives a probability weight matrix from the most probable predictions obtained by voting. Each client's probabilities are compared with the average probability. If a client's probability exceeds the average, the support count of its predicted class is incremented by one, and the class with the highest support count becomes the final predicted label. This generates the virtual dataset.
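The support-count rule above is only loosely specified, so the sketch below implements one possible reading of it. Each client votes for its own top class only if its probability for that class is at least the clients' average probability for the same class. The class with the most votes becomes the pseudo-label, and ties are broken by the average probability. The returned 0/1 matrix eta, which marks the clients that backed the winning class, is an assumed form of the probability weight matrix. All names are hypothetical.

```python
import numpy as np

def support_vote(P):
    """Support-count voting over K client predictions (one reading of the described rule).

    P: (K, N, C) class probabilities of K client models on N unlabeled target samples.
    Returns labels (N,), a 0/1 weight matrix eta (K, N), and vote counts (N, C).
    """
    K, N, C = P.shape
    mean = P.mean(axis=0)                                              # (N, C) average over clients
    top = P.argmax(axis=2)                                             # (K, N) each client's class
    cols = np.arange(N)
    top_p = np.take_along_axis(P, top[:, :, None], axis=2)[:, :, 0]   # (K, N) client confidence
    mean_at_top = mean[cols[None, :], top]                             # (K, N) average for that class
    supports = top_p >= mean_at_top                                    # at least as confident as average
    votes = np.zeros((N, C))
    for k in range(K):
        votes[cols[supports[k]], top[k, supports[k]]] += 1             # +1 support for the client's class
    # Most-supported class wins; adding 0.5 * mean (< 1 vote) only breaks ties.
    labels = np.argmax(votes + 0.5 * mean, axis=1)
    eta = (supports & (top == labels[None, :])).astype(float)
    return labels, eta, votes

P = np.array([[[0.7, 0.2, 0.1]],
              [[0.6, 0.3, 0.1]],
              [[0.1, 0.1, 0.8]]])
labels, eta, votes = support_vote(P)
```

A client that is less confident in its top class than the other clients on average casts no vote. A weakly informed client therefore cannot tip the label on its own.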
Step S5: the target server dynamically adjusts the weight of each client model according to that client's contribution to the virtual dataset, and aggregates the updated client models to obtain an updated global model.
Step S6: the updated global model is fine-tuned on the virtual prediction domain produced by federated confidence voting, yielding a new model that is distributed in the next federated round and used to predict labels of target-domain samples.
Specifically, step S1 includes training a diffusion model $U_\theta$ on the target server. First, data $x_0 \sim q(x_0)$ are drawn at random from the target-domain distribution. Noise is then gradually added, $q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t \mathbf{I}\right)$, until after the $T$-th step $x_T$ reaches a simple distribution (e.g., a Gaussian distribution). The diffusion model is described by the Markov chain $q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$, $t = 1, \dots, T$, where $\beta_t$ is linearly interpolated from 0.0001 to 0.02 and $T$ is the number of diffusion steps. A different $U_\theta$ is trained for each data type: DDPM for images, Diffusion-LM for text, and DiffWave for audio.
Specifically, step S2 includes the following. After the trained diffusion model $U_\theta$ is obtained, noise samples $x_T \sim \mathcal{N}(0, \mathbf{I})$ are drawn at random from the Gaussian noise distribution. Treating $x_T$ as the result of $T$ noising steps, the original sample $x_0$ is recovered by the reverse process. This yields a virtual dataset $\tilde{D}_T$ whose feature distribution is approximately the same as that of the target domain. For images, DPM-Solver++ is used to accelerate sampling, while text and audio keep the sampling methods of their original models.
The global model is a classification model, and the virtual dataset has a feature distribution similar to that of the target-domain data, so feature alignment can be performed without accessing target-domain data. The target server sends the trained diffusion model and the initialized global model to each client.
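The reverse (sampling) process described above can be sketched with plain ancestral DDPM sampling. The sketch assumes a trained noise predictor `eps_model(x_t, t)` standing in for $U_\theta$; that name is hypothetical. It also assumes the sigma_t^2 = beta_t variance choice from the DDPM paper. This is the slower baseline sampler, not the DPM-Solver++ method that the text uses to accelerate image sampling.

```python
import numpy as np

def ddpm_sample(eps_model, shape, betas, rng):
    """Ancestral DDPM sampling: draw x_T ~ N(0, I) and denoise step by step back to x_0.

    eps_model(x_t, t) is a hypothetical stand-in for the trained U_theta; it predicts the noise in x_t.
    """
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    x = rng.standard_normal(shape)                    # x_T from the Gaussian noise distribution
    for t in range(len(betas), 0, -1):                # t = T, ..., 1
        i = t - 1
        eps = eps_model(x, t)
        # Mean of x_{t-1} given x_t under the epsilon parameterization.
        mean = (x - betas[i] / np.sqrt(1.0 - alpha_bars[i]) * eps) / np.sqrt(alphas[i])
        noise = rng.standard_normal(shape) if t > 1 else np.zeros(shape)  # no noise at the last step
        x = mean + np.sqrt(betas[i]) * noise          # sigma_t^2 = beta_t
    return x

# Toy check: for a "dataset" containing a single point, the ideal noise predictor has a closed form.
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)
point = np.array([0.5, -0.25])

def oracle(x_t, t):
    return (x_t - np.sqrt(alpha_bars[t - 1]) * point) / np.sqrt(1.0 - alpha_bars[t - 1])

virtual_sample = ddpm_sample(oracle, point.shape, betas, np.random.default_rng(0))
```

With this oracle predictor, the final noise-free step returns the single data point, up to floating-point rounding. A learned $U_\theta$ replaces the oracle in practice, and each call then yields one virtual sample that resembles the target domain.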
Specifically, step S3 includes the following. Each client model samples virtual target-domain data $\tilde{x} \sim \tilde{D}_T$ from the diffusion model and maps the sample data into the feature space through $H: \mathcal{X} \to \mathcal{H}$ to obtain the domain-shared features $H(x)$. The label classifier $f_c: \mathcal{H} \to \mathcal{Y}$ then predicts the label class from the features $H(x)$.
The distributed global model is divided into three parts: (1) a domain-shared feature extractor, (2) a label classifier, and (3) a domain classifier. The label classifier includes the domain-specific features and the domain-independent classifier weights. The goal of the algorithm is to generalize quickly to unknown target domains by extracting reliable shared features of each domain for prediction.
FIG. 2 is a schematic diagram of the virtual feature generation algorithm for the target domain according to the invention. The aim is a more accurate classifier and a minimal feature discrepancy between the client domains and the target domain. To this end, the discriminator is designed as a domain classifier $f_d: H(x) \to D_P$, which indicates whether $H(x)$ comes from the client domain $D_k$ or from the virtually generated domain $\tilde{D}_T$. The gradient flowing into the feature function $H(x)$ is then reversed. The discriminator's classification loss therefore increases, the discriminator is confused, and the client features and the virtual features become hard to tell apart. To reduce the influence of domain-specific features, a hyperbolic tangent of the training progress $p = \frac{b + eB}{EB}$ controls the strength $\lambda$ of the gradient reversal in the early stage of training. Here $b$ is the current batch index, $e$ is the current epoch index, $E$ is the total number of epochs, and $B$ is the total number of batches. The whole module $C$ is then expressed as a single training objective that combines the label classifier with the gradient-reversed domain classifier.
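The gradient reversal described above can be sketched by hand in NumPy, assuming a linear extractor H(x) = x W and a logistic domain classifier f_d. The reversal strength uses lambda = tanh(gamma * p) with progress p = (b + e*B)/(E*B). The scale gamma = 5 is a hypothetical choice: tanh(5p) equals the schedule 2 / (1 + exp(-10p)) - 1 from the DANN paper. The linear layers, the learning rate, and the function names are also assumptions made for illustration.

```python
import numpy as np

def grl_lambda(b, e, B, E, gamma=5.0):
    """Reversal strength lambda = tanh(gamma * p), with training progress p = (b + e*B) / (E*B)."""
    p = (b + e * B) / (E * B)
    return float(np.tanh(gamma * p))

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def adversarial_step(x, is_client, W, w_d, lam, lr=0.1):
    """One gradient step for a linear extractor H(x) = x @ W and a logistic domain classifier f_d.

    is_client: 1.0 for client-domain samples, 0.0 for virtual (diffusion-generated) samples.
    f_d descends the domain loss; the gradient reaching W is multiplied by -lam (gradient
    reversal), so the extractor ascends the same loss and learns features that confuse f_d.
    """
    z = x @ W                                   # shared features H(x), shape (n, f)
    d = sigmoid(z @ w_d)                        # f_d: probability of "client domain"
    err = (d - is_client) / len(x)              # d(mean BCE) / d(logit)
    grad_w_d = z.T @ err                        # ordinary gradient for the domain classifier
    grad_z = np.outer(err, w_d)                 # gradient of the domain loss w.r.t. the features
    grad_W = x.T @ (-lam * grad_z)              # gradient reversal layer: scale by -lambda
    return W - lr * grad_W, w_d - lr * grad_w_d

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))                     # 3 client samples followed by 3 virtual samples
is_client = np.array([1, 1, 1, 0, 0, 0], dtype=float)
W, w_d = rng.normal(size=(4, 3)), rng.normal(size=3)
W, w_d = adversarial_step(x, is_client, W, w_d, lam=grl_lambda(b=10, e=0, B=100, E=10))
```

Lambda starts at 0 and rises toward 1, so the reversal barely affects the extractor early in training. This matches the text's aim of limiting its influence in the first rounds.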
Fig. 3 is a schematic diagram of an algorithm for aligning source domain features with virtual features according to the present invention. This process is a virtual feature alignment part. The purpose is to extract the shared features of the domains for alignment. The domain-specific features are not used for alignment because the features belonging to each domain are not generic. The domain-specific features are used only for local model training.
Specifically, step S4 includes the following. Each client uploads its domain-shared features and its trained client model to the target server in the target domain, and the target server uses the consensus knowledge of the source domains to expand a virtual domain.
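One way to picture the consensus-knowledge step is a confidence-thresholded vote for each target sample, sketched below. Everything in this sketch is an assumption rather than the patent's specification. The threshold of 0.8, the majority rule, averaging the supporting clients' probability vectors into the consensus knowledge, and dropping samples with no confident vote are all modeled on knowledge-vote schemes from federated domain adaptation.

```python
from collections import Counter


def argmax(p):
    """Index of the largest entry."""
    return max(range(len(p)), key=p.__getitem__)


def knowledge_vote(client_probs, threshold=0.8):
    """Consensus knowledge for ONE target sample.

    client_probs: K probability vectors, one per source client.
    Returns (consensus_vector, n_ck, label), or None if no client is confident.
    """
    confident = [p for p in client_probs if max(p) >= threshold]
    if not confident:
        return None
    # Every confident client votes for its arg-max class; the majority wins.
    votes = Counter(argmax(p) for p in confident)
    label, n_ck = votes.most_common(1)[0]
    supporters = [p for p in confident if argmax(p) == label]
    # Consensus knowledge: mean probability vector of the supporting clients.
    consensus = [sum(col) / n_ck for col in zip(*supporters)]
    return consensus, n_ck, label


def build_virtual_domain(target_samples, probs_per_sample, threshold=0.8):
    """Keep target samples with a confident consensus and attach pseudo-labels."""
    domain = []
    for x, client_probs in zip(target_samples, probs_per_sample):
        vote = knowledge_vote(client_probs, threshold)
        if vote is not None:
            domain.append((x, vote[2]))
    return domain


probs_x1 = [[0.9, 0.05, 0.05], [0.85, 0.10, 0.05], [0.10, 0.85, 0.05]]
probs_x2 = [[0.5, 0.3, 0.2], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]]
virtual = build_virtual_domain(["x1", "x2"], [probs_x1, probs_x2])
```

Here `x2` is dropped because no client is confident about it, so only `x1` enters the virtual domain, pseudo-labelled with class 0.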
Specifically, the step S5 includes:
the overall knowledge quality $CQ(S')$ is defined for any subset $S' \subseteq S$ of the source domains, where $n_{CK}$ denotes the number of domains exceeding CK (the consensus knowledge), and the max function takes the maximum probability of CK, i.e., the value of the predicted label class. From CQ, the knowledge contribution of each client is computed as $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$. The weight of client model $L_k$ is then readjusted using $CC(S_k)$ together with a second per-client quantity.
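The consensus-quality and contribution calculation can be sketched as follows. The sketch assumes that $CQ(S')$ sums $n_{CK} \cdot \max(CK)$ over target samples, with CK and $n_{CK}$ obtained from a confidence vote restricted to the clients in $S'$. It also assumes that the second per-client quantity is the client's sample count $N_k$, giving $\alpha_k = N_k\,CC(S_k) / \sum_j N_j\,CC(S_j)$. Clipping negative contributions to zero and falling back to size-proportional weights are further assumptions.

```python
def _vote(probs, threshold):
    """Confidence vote for one sample; returns (consensus, n_ck) or None."""
    confident = [p for p in probs if max(p) >= threshold]
    if not confident:
        return None
    labels = [p.index(max(p)) for p in confident]
    # Majority label; ties go to the smallest class index.
    label = max(sorted(set(labels)), key=labels.count)
    supporters = [p for p, lab in zip(confident, labels) if lab == label]
    consensus = [sum(col) / len(supporters) for col in zip(*supporters)]
    return consensus, len(supporters)


def consensus_quality(client_ids, probs_by_client, threshold=0.8):
    """CQ(S'): sum over target samples of n_CK * max(CK), voting only with S'."""
    if not client_ids:
        return 0.0
    n_samples = len(probs_by_client[client_ids[0]])
    total = 0.0
    for i in range(n_samples):
        vote = _vote([probs_by_client[k][i] for k in client_ids], threshold)
        if vote is not None:
            consensus, n_ck = vote
            total += n_ck * max(consensus)
    return total


def consensus_focus_weights(probs_by_client, sizes, threshold=0.8):
    """Return (CC per client, normalized aggregation weight per client)."""
    clients = sorted(probs_by_client)
    cq_all = consensus_quality(clients, probs_by_client, threshold)
    cc = {
        k: cq_all - consensus_quality([j for j in clients if j != k], probs_by_client, threshold)
        for k in clients
    }
    raw = {k: sizes[k] * max(cc[k], 0.0) for k in clients}
    if sum(raw.values()) == 0.0:
        raw = {k: float(sizes[k]) for k in clients}  # fallback: size-proportional
    total = sum(raw.values())
    return cc, {k: raw[k] / total for k in clients}


# Two target samples, two classes, three hypothetical clients of equal size.
probs_by_client = {
    "A": [[0.9, 0.1], [0.85, 0.15]],
    "B": [[0.88, 0.12], [0.1, 0.9]],
    "C": [[0.3, 0.7], [0.15, 0.85]],
}
cc, weights = consensus_focus_weights(probs_by_client, {"A": 100, "B": 100, "C": 100})
```

In this toy input, client B agrees with the majority on both samples, so removing it lowers CQ the most and it receives the largest aggregation weight.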
The target server then aggregates the reweighted client models to obtain an updated global model G. Fig. 4 is a flowchart of the dynamic model aggregation algorithm provided by the invention.
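The aggregation itself can be illustrated as a FedAvg-style weighted average of client parameters using the readjusted weights. This is a minimal sketch that assumes each model is a dict of flat parameter lists. The patent does not spell out the aggregation rule beyond using the adjusted weights, and the `aggregate` function is hypothetical.

```python
def aggregate(client_models, weights):
    """Weighted average of client parameters: G = sum_k alpha_k * L_k / sum_k alpha_k.

    client_models: {client_id: {param_name: [floats]}} (hypothetical format)
    weights: {client_id: alpha_k} from the consensus-focus step
    """
    total = sum(weights[k] for k in client_models)
    first = next(iter(client_models.values()))
    global_model = {}
    for name, values in first.items():
        global_model[name] = [
            sum(weights[k] * client_models[k][name][i] for k in client_models) / total
            for i in range(len(values))
        ]
    return global_model


clients = {
    "A": {"w": [1.0, 2.0], "b": [0.0]},
    "B": {"w": [3.0, 6.0], "b": [1.0]},
}
G = aggregate(clients, {"A": 0.25, "B": 0.75})
```

Dividing by the weight sum keeps the result a convex combination even if the weights passed in are not already normalized.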
Specifically, step S6 includes the following. Fig. 5 is a flowchart of the virtual feature learning algorithm provided by the invention. Borrowing the idea of meta-learning, a method called virtual feature learning is designed. Each round of local training of the client model $L_k$ on each client is regarded as a task, and the global model $G$ aggregated by consensus focusing is regarded as the pre-trained model. The global model $G$ then learns the virtual prediction domain generated by federated confidence voting, whose labels are $y_{K+2} = \arg\max_c (P \cdot \eta)$, where $P$ is the prediction probability of each source domain for the target domain and $\eta$ is the probability weight matrix. Virtual feature learning is performed on $G$ with samples $x_{K+2}$ drawn from the virtual prediction domain. This lets the target domain learn the feature distribution of each source domain faster and accelerates convergence.
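The pseudo-labelling rule and a fine-tuning step can be sketched as below. The sketch assumes that, for one target sample, $P$ is a $C \times K$ matrix whose column $k$ holds source domain $k$'s class probabilities and that $\eta$ supplies a length-$K$ weight vector. It also replaces the meta-learning procedure with a single gradient-descent step of a linear softmax head on the pseudo-labelled virtual prediction domain. That is a simplification, and the function names are hypothetical.

```python
import math


def pseudo_label(P, eta):
    """y = argmax_c (P . eta) for one target sample.
    P: C x K matrix, column k holds source domain k's class probabilities.
    eta: length-K weights for this sample (assumed shape)."""
    scores = [sum(p_ck * w for p_ck, w in zip(row, eta)) for row in P]
    return max(range(len(scores)), key=scores.__getitem__)


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def logits(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def cross_entropy(W, data):
    """Mean cross-entropy of a linear softmax classifier W on (x, y) pairs."""
    return -sum(math.log(softmax(logits(W, x))[y]) for x, y in data) / len(data)


def finetune_step(W, data, lr=0.5):
    """One gradient-descent step on the virtual prediction domain."""
    grad = [[0.0] * len(row) for row in W]
    for x, y in data:
        p = softmax(logits(W, x))
        for c, row in enumerate(grad):
            err = p[c] - (1.0 if c == y else 0.0)  # d loss / d logit_c
            for j, xj in enumerate(x):
                row[j] += err * xj / len(data)
    return [[w - lr * g for w, g in zip(wr, gr)] for wr, gr in zip(W, grad)]


P1 = [[0.7, 0.2], [0.3, 0.8]]  # 2 classes x 2 source domains
P2 = [[0.6, 0.1], [0.4, 0.9]]
virtual_domain = [
    ([1.0, 0.0], pseudo_label(P1, [0.9, 0.1])),
    ([0.0, 1.0], pseudo_label(P2, [0.2, 0.8])),
]
W0 = [[0.0, 0.0], [0.0, 0.0]]  # hypothetical head of the aggregated global model G
W1 = finetune_step(W0, virtual_domain)
```

In a full implementation, this step would run on the target server after aggregation, and the fine-tuned weights would be distributed in the next round.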
The target server thus learns from the virtual prediction domain and fine-tunes the global model, producing the new model that is distributed to the clients in the next federated round and used to predict labels for target-domain samples.
The diffusion model does not change and needs to be distributed only once. The client models and the global model share the same classification architecture; they differ only in that the global model resides on the target side while each client model is trained independently on its own client. As in standard federated learning, the global model aggregated in each round is distributed to every client to replace the previous round's model, and it consists of three parts: a feature extractor, a label classifier, and a domain classifier.
In summary, by using a diffusion model, the embodiment of the invention better protects the privacy of target-domain data and is broadly applicable. Using decoupled representation learning instead of computing and minimizing inter-domain distances greatly reduces communication overhead.
The method uses federated confidence voting and consensus focusing to alleviate negative transfer and domain distribution shift, and the virtual feature learning module accelerates model convergence.
Those of ordinary skill in the art will appreciate that the drawings are schematic diagrams of one embodiment, and the modules or flows in the drawings are not necessarily required to practice the invention.
From the above description of embodiments, it will be apparent to those skilled in the art that the present invention may be implemented in software plus a necessary general hardware platform. Based on such understanding, the technical solution of the present invention may be embodied essentially or in a part contributing to the prior art in the form of a software product, which may be stored in a storage medium, such as a ROM/RAM, a magnetic disk, an optical disk, etc., including several instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute the method described in the embodiments or some parts of the embodiments of the present invention.
In this specification, each embodiment is described in a progressive manner, and identical and similar parts of each embodiment are all referred to each other, and each embodiment mainly describes differences from other embodiments. In particular, for apparatus or system embodiments, since they are substantially similar to method embodiments, the description is relatively simple, with reference to the description of method embodiments in part. The apparatus and system embodiments described above are merely illustrative, wherein the elements illustrated as separate elements may or may not be physically separate, and the elements shown as elements may or may not be physical elements, may be located in one place, or may be distributed over a plurality of network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment. Those of ordinary skill in the art will understand and implement the present invention without undue burden.
The present invention is not limited to the above-mentioned embodiments, and any changes or substitutions that can be easily understood by those skilled in the art within the technical scope of the present invention are intended to be included in the scope of the present invention. Therefore, the protection scope of the present invention should be subject to the protection scope of the claims.
Claims (6)
1. A diffusion-model-driven unsupervised domain generalization method for privacy protection, characterized by comprising the following steps:
step S1: training a diffusion model on a target server in a target domain;
step S2: the target server sends the trained diffusion model and the initialized global model to each client;
step S3: after each client receives the global model, sampling virtual target-domain data from the diffusion model, performing decoupled representation learning on the virtual target-domain data, extracting the domain-specific and domain-shared features, constructing a local client model using the classifier in the specific features, and training the client model;
step S4: each client uploading its domain-shared features and trained client model to the target server in the target domain; the target server performing federated confidence voting on target-domain samples using the domain-shared features, obtaining a probability weight matrix from the most probable predictions produced by the vote, and generating a virtual prediction domain;
step S5: the target server dynamically adjusts the weight of each client model according to the contribution of each client to the virtual prediction domain, and aggregates each updated client model to obtain an updated global model;
step S6: fine-tuning the updated global model using the virtual prediction domain produced by federated confidence voting, to obtain a new model for distribution in the next federated round and for label prediction on target-domain samples.
2. The method according to claim 1, wherein said step S1: training a diffusion model on a target server in a target domain, comprising:
training a diffusion model U on a target server θ From distribution ofRandom extraction data-> Gradually adding noise to the data>Until after calculation of T, obtain +.> Diffusion model U θ Described by a markov chain, i.e. t=1, …, T, where β t Is a parameter for linear interpolation from 0.0001 to 0.02, T is the number of diffusion steps;
the diffusion model is used to generate virtual data domains with the same distribution as the target-domain dataset, and on the client side each source domain uses the diffusion model to restore, from Gaussian noise, data with the same distribution as the target domain.
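The forward noising process and a single reverse (denoising) step can be sketched in standard-library Python as below. The sketch follows the standard DDPM formulation, which matches the linear $\beta_t$ schedule from 0.0001 to 0.02 stated in this claim. The noise predictor $U_\theta$ itself is not implemented, so `p_step` takes its noise prediction `eps_hat` as an argument. Function names are hypothetical, and `t` is zero-based.

```python
import math
import random


def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    """Linear schedule from 0.0001 to 0.02 over T steps."""
    if T == 1:
        return [beta_start]
    return [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def q_sample(x0, t, abar, rng):
    """Closed-form forward sample: x_t ~ N(sqrt(abar_t) x0, (1 - abar_t) I)."""
    a = abar[t]
    return [math.sqrt(a) * xi + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for xi in x0]


def p_step(x_t, t, eps_hat, betas, abar, rng):
    """One DDPM reverse step, given the noise eps_hat predicted by U_theta."""
    beta, alpha = betas[t], 1.0 - betas[t]
    coef = beta / math.sqrt(1.0 - abar[t])
    mean = [(xi - coef * ei) / math.sqrt(alpha) for xi, ei in zip(x_t, eps_hat)]
    if t == 0:
        return mean  # no noise is added at the final step
    sigma = math.sqrt(beta)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]


rng = random.Random(0)
betas = linear_betas(1000)
abar = alpha_bars(betas)
x_noisy = q_sample([0.3, -0.7], 999, abar, rng)  # almost pure Gaussian noise
```

In the scheme described here, each client would start from Gaussian noise and call `p_step` from `t = T - 1` down to 0, with `eps_hat` supplied by the received model, to obtain virtual target-domain samples.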
3. The method according to claim 2, wherein said step S3: after each client receives the global model, sampling virtual target-domain data from the diffusion model, performing decoupled representation learning on the virtual target-domain data, extracting the domain-specific features and domain-shared features, constructing a local client model using the classifier in the specific features, and training the client model, comprises:
the distributed global model comprises three parts: a domain-shared feature extractor, a label classifier, and a domain classifier, wherein the label classifier comprises domain-specific features and domain-independent classifier weights; each client model samples virtual target-domain data from the diffusion model, maps the sampled data to the feature space through the feature function $H$, and obtains the domain-shared features and the domain-specific features; a local client model is constructed using the label classifier in the specific features and is trained; and the label classifier $f_c$ predicts the label category from the features.
4. The method according to claim 3, wherein said step S4: each client uploads its domain-shared features and trained client model to the target server in the target domain; the target server, through each client model, performs federated confidence voting on target-domain samples using the domain-shared features, obtains a probability weight matrix from the most probable predictions produced by the vote, and thereby generates a virtual prediction domain, comprising:
uploading the domain-shared features of each client and the trained client models to the target server in the target domain, wherein the target server uses the consensus knowledge of the source domains to expand a virtual domain.
5. The method according to claim 4, wherein said step S5: the target server dynamically adjusts the weight of each client model according to the contribution of each client to the virtual prediction domain, and aggregates each updated client model to obtain an updated global model, comprising:
defining the overall knowledge quality $CQ(S')$ for any subset $S' \subseteq S$ of the source domains, where $n_{CK}$ denotes the number of domains exceeding CK and the max function takes the maximum probability of CK, i.e., the value of the predicted label class; computing the contribution of each client as $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$; and readjusting the weight of client model $L_k$ using $CC(S_k)$ together with a second per-client quantity;
and aggregating, by the target server, the updated client models to obtain an updated global model G.
6. The method according to claim 5, wherein said step S6: the target server performs federated confidence voting on target-domain samples through the updated global model and obtains predicted labels for the target-domain samples, comprising:
regarding each round of local training of the client model $L_k$ on each client as a task and the global model $G$ aggregated by consensus focusing as a pre-trained model, and using the global model $G$ to learn the virtual prediction domain generated by federated confidence voting, whose labels are $y_{K+2} = \arg\max_c (P \cdot \eta)$, where $P$ is the prediction probability of each source domain for the target domain and $\eta$ is the probability weight matrix;
performing virtual feature learning on the global model $G$ with samples $x_{K+2}$ drawn from the virtual prediction domain, so that the target domain learns the feature distribution of each source domain, and fine-tuning the updated global model to obtain a new global model for distribution in the next federated round and for label prediction on target-domain samples.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311013570.XA CN116882480A (en) | 2023-08-11 | 2023-08-11 | Diffusion model driven unsupervised domain generalization method for privacy protection |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311013570.XA CN116882480A (en) | 2023-08-11 | 2023-08-11 | Diffusion model driven unsupervised domain generalization method for privacy protection |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116882480A (en) | 2023-10-13 |
Family
ID=88258785
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311013570.XA Pending CN116882480A (en) | 2023-08-11 | 2023-08-11 | Diffusion model driven unsupervised domain generalization method for privacy protection |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116882480A (en) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117910601A (en) * | 2024-03-20 | 2024-04-19 | 浙江大学滨江研究院 | Personalized federal potential diffusion model learning method and system |
- 2023-08-11: CN202311013570.XA (CN116882480A/en), active, pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Zhai et al. | Multiple expert brainstorming for domain adaptive person re-identification | |
JP7470476B2 (en) | Integration of models with different target classes using distillation | |
US10885383B2 (en) | Unsupervised cross-domain distance metric adaptation with feature transfer network | |
Fang et al. | Source-free unsupervised domain adaptation: A survey | |
US20220156507A1 (en) | Unsupervised representation learning with contrastive prototypes | |
WO2022077646A1 (en) | Method and apparatus for training student model for image processing | |
TWI832679B (en) | Computer system and computer-implemented method for knowledge-preserving neural network pruning, and non-transitory computer-readable storage medium thereof | |
CN112446423A (en) | Fast hybrid high-order attention domain confrontation network method based on transfer learning | |
Kimura et al. | Anomaly detection using GANs for visual inspection in noisy training data | |
WO2022103682A1 (en) | Face recognition from unseen domains via learning of semantic features | |
CN116882480A (en) | Diffusion model driven unsupervised domain generalization method for privacy protection | |
US12002488B2 (en) | Information processing apparatus and information processing method | |
WO2020092276A1 (en) | Video recognition using multiple modalities | |
CN113821668A (en) | Data classification identification method, device, equipment and readable storage medium | |
CN116227578A (en) | Unsupervised domain adaptation method for passive domain data | |
Zhou et al. | Unsupervised domain adaptation with adversarial distribution adaptation network | |
Vilalta et al. | A general approach to domain adaptation with applications in astronomy | |
Guo et al. | Fed-fsnet: Mitigating non-iid federated learning via fuzzy synthesizing network | |
US20240119307A1 (en) | Personalized Federated Learning Via Sharable Basis Models | |
Firdaus et al. | Personalized federated learning for heterogeneous data: A distributed edge clustering approach | |
Zhou et al. | Progressive decoupled target-into-source multi-target domain adaptation | |
US20220067534A1 (en) | Systems and methods for mutual information based self-supervised learning | |
Dhillon et al. | Inference-driven metric learning for graph construction | |
Zuo et al. | FedViT: Federated continual learning of vision transformer at edge | |
CN117113274A (en) | Heterogeneous network data-free fusion method and system based on federal distillation |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||