CN116882480A - Diffusion model driven unsupervised domain generalization method for privacy protection - Google Patents

Diffusion model driven unsupervised domain generalization method for privacy protection

Info

Publication number
CN116882480A
Authority
CN
China
Prior art keywords
domain
model
client
target
virtual
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311013570.XA
Other languages
Chinese (zh)
Inventor
王伟
孔文康
吕晓婷
刘鹏睿
陈国荣
陈政
刘敬楷
祝咏升
胡福强
段莉
李超
刘吉强
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Jiaotong University
Original Assignee
Beijing Jiaotong University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Jiaotong University filed Critical Beijing Jiaotong University
Priority to CN202311013570.XA
Publication of CN116882480A
Legal status: Pending

Classifications

    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 3/00 - Computing arrangements based on biological models
    • G06N 3/02 - Neural networks
    • G06N 3/08 - Learning methods
    • G06N 3/088 - Non-supervised learning, e.g. competitive learning
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F 18/00 - Pattern recognition
    • G06F 18/20 - Analysing
    • G06F 18/21 - Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F 18/214 - Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F 18/00 - Pattern recognition
    • G06F 18/20 - Analysing
    • G06F 18/24 - Classification techniques
    • G06F 18/241 - Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F 21/00 - Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F 21/60 - Protecting data
    • G06F 21/62 - Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F 21/6218 - Protecting access to data via a platform, e.g. using keys or access control rules, to a system of files or objects, e.g. local or distributed file system or database
    • G06F 21/6245 - Protecting personal data, e.g. for financial or medical purposes
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 3/00 - Computing arrangements based on biological models
    • G06N 3/02 - Neural networks
    • G06N 3/08 - Learning methods
    • G06N 3/098 - Distributed learning, e.g. federated learning
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 5/00 - Computing arrangements using knowledge-based models
    • G06N 5/02 - Knowledge representation; Symbolic representation


Abstract

The invention provides a privacy-preserving, diffusion-model-driven unsupervised domain generalization method. The method comprises the following steps: the target server sends a trained diffusion model and an initialized global model to each client; each client samples virtual target-domain data from the diffusion model, extracts domain-specific and domain-shared features, and uploads the domain-shared features and its trained client model to the target server in the target domain; the target server performs federated confidence voting on target-domain samples through the client models to generate a virtual prediction domain; the target server then dynamically adjusts the weight of each client model according to each client's contribution to the virtual prediction domain, and uses the virtual prediction domain produced by federated confidence voting to obtain the global model issued in the next federated round. By using the diffusion model, the invention better protects the privacy of the target-domain data, is sufficiently general, and reduces communication pressure.

Description

Diffusion model driven unsupervised domain generalization method for privacy protection
Technical Field
The invention relates to the technical field of computer applications, and in particular to a privacy-preserving, diffusion-model-driven unsupervised domain generalization method.
Background
Federated learning is a distributed learning paradigm that enables multiple participants to collaboratively train a machine learning model without sharing local data, thereby protecting their data privacy and reducing communication overhead. However, federated learning also faces challenges. Different participants may have different data characteristics, labels, or data quality, which can lead to inconsistent model performance across environments, and even to negative transfer. Furthermore, because of the large amount of data in real-world scenarios, specialists may face a heavy labeling workload, resulting in labeling errors, missing labels, or entirely unlabeled data, which may compromise the model quality of a participant or prevent training altogether.
Unsupervised federated domain generalization has therefore been proposed to address these issues. It does not rely on any label information in the target domain, but instead uses unsupervised learning techniques to align and transfer between the source and target domains. Unsupervised federated domain generalization can effectively address the challenges of data heterogeneity and missing labels in a federated learning environment, enhance the adaptability and robustness of a model across multiple domains, and enable the model to generalize to an unknown target domain. It can be applied in many fields, such as healthcare, finance, and the Internet of Things, where different participants can use the technique to improve model performance and efficiency while protecting data privacy. For example, in medical image segmentation, the image quality or segmentation accuracy achieved by different hospitals' equipment may differ, and the data may be subject to distribution drift, so simple average aggregation cannot be applied directly. Some hospitals may also lack the resources and capability to determine image case classifications, and thus need to learn from other hospitals while protecting their own data.
Unsupervised federated domain generalization is therefore a challenging and important research direction, but it is technically difficult. First, it requires the federated model to be robust and able to accommodate unknown target domains whose data distribution and characteristics may differ from those of the source domains. Second, it must protect the data privacy of both the source and target domains, meaning that the federated model cannot access or share any raw data during training or inference. Third, it must cope with source-domain heterogeneity and diversity, which may include noisy, biased, or even malicious data that can compromise the performance of the federated model. Fourth, it requires a generic algorithmic framework to meet a wide range of potential applications in real-world scenarios, such as cross-device image recognition, cross-language natural language processing, and cross-platform speech recognition.
Currently, most domain generalization methods in the prior art are based on adversarial training for feature alignment. One privacy-preserving federated domain generalization method in the prior art is the FedKA algorithm. FedKA applies feature distribution matching in a global workspace so that the global model can learn domain-invariant client features despite the constraint of not accessing client data. The FedKA algorithm adopts a federated voting mechanism to generate pseudo labels for the target domain according to the consensus of the clients, so as to fine-tune the global model.
The FedKA algorithm proceeds as follows:
Step 1: the target domain initializes the global model and distributes it to each client.
Step 2: each client uses its local data and the target-domain data to extract domain-specific features and domain-common features, and computes its latest local model parameters.
Step 3: each client uploads its domain-common features to the target domain, and the target domain aligns the clients' domain-common features using feature distribution matching, yielding a unified global feature space.
Step 4: the target server aggregates the model parameters of all clients and averages them to update the latest model parameters.
Step 5: each client uses federated voting to generate, from its locally trained model, virtual labels for the target-domain input samples. The target server uses these virtual labels to fine-tune the target-domain model.
Step 6: all clients perform a new round of federated domain generalization.
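As a rough, illustrative sketch (function names and toy values are ours, not from the FedKA publication), the average aggregation of step 4 and the average-probability voting of step 5 might look like:

```python
import numpy as np

def average_aggregate(client_weights):
    # Step 4: plain averaging of the clients' model parameters.
    return np.mean(np.stack(client_weights), axis=0)

def federated_vote(client_probs):
    # Step 5: average-probability voting over the clients' predicted
    # class distributions yields a pseudo label for one target sample.
    mean_prob = np.mean(np.stack(client_probs), axis=0)
    return int(np.argmax(mean_prob))

# Two clients with two parameters each, and three classes.
agg = average_aggregate([np.array([1.0, 2.0]), np.array([3.0, 4.0])])
label = federated_vote([np.array([0.1, 0.7, 0.2]),
                        np.array([0.2, 0.6, 0.2])])
```

Note that plain averaging treats every client equally, which is exactly the weakness of simple aggregation in the presence of poor-quality or poisoned clients.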
The above privacy-preserving federated domain generalization method based on the FedKA algorithm has the following disadvantages:
When FedKA extracts the domain-specific and domain-common features, although the clients' data never leave their local devices, the clients need to access the target-domain data, which leaks the data privacy of the target domain.
In each round of federated domain generalization, the feature distribution matching in the target domain uses the multi-kernel maximum mean discrepancy algorithm, which not only requires the number of kernels and the distribution of each task to be tuned, but also requires many rounds of communication between each client and the target domain, causing high communication pressure and making privacy leakage likely.
The FedKA algorithm adopts simple average aggregation, so it cannot handle clients with poor-quality data or malicious poisoning, resulting in negative domain transfer.
The federated voting uses average-probability voting; when facing poor-quality or maliciously poisoned clients, large deviations in the probabilities can produce erroneous prediction labels and negative domain transfer.
Disclosure of Invention
The invention provides a privacy-preserving, diffusion-model-driven unsupervised domain generalization method, which achieves effective privacy protection for the data of the target domain.
In order to achieve the above purpose, the present invention adopts the following technical scheme.
A privacy-preserving, diffusion-model-driven unsupervised domain generalization method comprises the following steps:
Step S1: train a diffusion model on the target server in the target domain;
Step S2: the target server sends the trained diffusion model and the initialized global model to each client;
Step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs decoupled representation learning on the virtual target-domain data, extracts domain-specific and domain-shared features, builds a local client model using a classifier on the specific features, and trains the client model;
Step S4: each client uploads its domain-shared features and trained client model to the target server in the target domain; the target server performs federated confidence voting on target-domain samples through the domain-shared features, obtains a probability weight matrix from the most likely predictions produced by the voting, and generates a virtual prediction domain;
Step S5: the target server dynamically adjusts the weight of each client model according to each client's contribution to the virtual prediction domain, and aggregates the updated client models to obtain an updated global model;
Step S6: the updated global model is fine-tuned with the virtual prediction domain produced by federated confidence voting, yielding a new model for the next federated round and for target-domain sample label prediction.
Preferably, the step S1, training a diffusion model on the target server in the target domain, comprises:
training a diffusion model $U_\theta$ on the target server: data $x_0$ are randomly drawn from the distribution $q(x_0)$, and Gaussian noise is gradually added according to $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I)$ for $t = 1, \ldots, T$, until after the $T$-th step $x_T$ approaches an isotropic Gaussian $\mathcal{N}(0, I)$; the diffusion model $U_\theta$ is thus described by a Markov chain, where $\beta_t$ is a parameter linearly interpolated from 0.0001 to 0.02 and $T$ is the number of diffusion steps;
the diffusion model is used for generating virtual data domains with the same distribution as the target-domain dataset; based on the diffusion model, each client's source domain recovers, from Gaussian noise, data distributed like the target domain.
Preferably, the step S3, in which each client, after receiving the global model, samples virtual target-domain data from the diffusion model, performs decoupled representation learning on the virtual target-domain data, extracts domain-specific and domain-shared features, builds a local client model using a classifier on the specific features, and trains the client model, comprises:
the issued global model comprises three parts: a domain-shared feature extractor, a label classifier, and a domain classifier, wherein the label classifier comprises domain-specific features and domain-independent classifier weights; each client samples virtual target-domain sample data $\hat{x}$ from the diffusion model, maps the sample data into the feature space $H$ via the feature extractor $F : x \mapsto H(x)$, and obtains the domain-shared features and the domain-specific features; a label classifier $f_c : H \to Y$ on the specific features forms the local client model, which is then trained; the label classifier $f_c$ predicts the label category from the features $H(x)$.
Preferably, the step S4, in which each client uploads its domain-shared features and trained client model to the target server in the target domain, the target server performs federated confidence voting on target-domain samples through the client models using the domain-shared features, and the most likely predictions produced by the voting yield a probability weight matrix, generating a virtual prediction domain, comprises:
each client uploads its domain-shared features $h_k$ and trained client model to the target server in the target domain; using the consensus knowledge $CK$ of the source domains, the target server expands a virtual prediction domain $\hat{D}_{K+2}$.
Preferably, the step S5, in which the target server dynamically adjusts the weight of each client model according to each client's contribution to the virtual prediction domain and aggregates the updated client models to obtain an updated global model, comprises:
define the overall knowledge quality as $CQ(S)$, computed over subsets $S' \subseteq S$, where $n_{CK}$ denotes the number of domains exceeding $CK$ and the max function takes the maximum probability of the $CK$ function, i.e. the predicted label class value; the knowledge contribution of each client is calculated as $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$; using $CQ(S)$ and $CC(S_k)$, the weight of the client model $L_k$ is readjusted by normalizing the contributions over all clients;
the target server aggregates the updated client models to obtain an updated global model $G$.
Preferably, the step S6, in which the target server performs federated confidence voting on the samples of the target domain through the updated global model and obtains the prediction labels of the target-domain samples, comprises:
the client model $L_k$ trained locally on each client and the global model $G$ aggregated by consensus focusing are regarded as a pre-training model; the global model $G$ learns the virtual prediction domain $\hat{D}_{K+2} = \{(x_{K+2}, y_{K+2})\}$ generated by federated confidence voting, where $y_{K+2} = \arg\max_c (P \times \eta)$, $P$ is the prediction probability assigned by each source domain to the target domain, and $\eta$ is the probability weight matrix;
virtual feature learning is performed on the global model $G$ with the samples $x_{K+2}$ of the virtual prediction domain $\hat{D}_{K+2}$; the target domain can thereby learn the feature distribution of each source domain, and the updated global model is fine-tuned to obtain a new global model for the next federated round and for target-domain sample label prediction.
According to the technical scheme provided by the embodiments of the invention, the diffusion model protects the privacy of the target-domain data, and the method is sufficiently general. Decoupled representation learning is used instead of domain distance minimization, which greatly reduces the communication pressure.
Additional aspects and advantages of the invention will be set forth in part in the description which follows, and in part will be obvious from the description, or may be learned by practice of the invention.
Drawings
In order to more clearly illustrate the technical solutions of the embodiments of the present invention, the drawings required for the description of the embodiments will be briefly described below, and it is obvious that the drawings in the following description are only some embodiments of the present invention, and other drawings may be obtained according to these drawings without inventive effort for a person skilled in the art.
FIG. 1 is a process flow diagram of an unsupervised domain generalization method driven by a diffusion model for privacy protection;
FIG. 2 is a schematic diagram of a virtual feature generation algorithm of a target domain according to the present invention;
FIG. 3 is a schematic diagram of an algorithm for aligning source domain features with virtual features according to the present invention;
FIG. 4 is a flow chart of an algorithm process for dynamic aggregation of models provided by the invention;
fig. 5 is a flowchart of an algorithm process for virtual feature learning provided by the present invention.
Detailed Description
Embodiments of the present invention are described in detail below, examples of which are illustrated in the accompanying drawings, wherein the same or similar reference numerals refer to the same or similar elements or elements having the same or similar functions throughout. The embodiments described below by referring to the drawings are exemplary only for explaining the present invention and are not to be construed as limiting the present invention.
As used herein, the singular forms "a", "an", and "the" are intended to include the plural forms as well, unless expressly stated otherwise, as understood by those skilled in the art. It will be further understood that the terms "comprises" and/or "comprising," when used in this specification, specify the presence of stated features, integers, steps, operations, elements, and/or components, but do not preclude the presence or addition of one or more other features, integers, steps, operations, elements, components, and/or groups thereof. It will be understood that when an element is referred to as being "connected" or "coupled" to another element, it can be directly connected or coupled to the other element, or intervening elements may also be present. Further, "connected" or "coupled" as used herein may include wirelessly connected or coupled. The term "and/or" as used herein includes any and all combinations of one or more of the associated listed items.
It will be understood by those skilled in the art that, unless otherwise defined, all terms (including technical and scientific terms) used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this invention belongs. It will be further understood that terms, such as those defined in commonly used dictionaries, should be interpreted as having a meaning that is consistent with their meaning in the context of the prior art and will not be interpreted in an idealized or overly formal sense unless expressly so defined herein.
For the purpose of facilitating an understanding of the embodiments of the invention, reference will now be made to several specific embodiments illustrated in the accompanying drawings; these drawings should in no way be taken to limit the embodiments of the invention.
The invention provides a general framework for unsupervised federated domain generalization based on a diffusion model, which ensures the data privacy of each domain during domain generalization, reduces the communication rounds of model aggregation, identifies and screens out poor-quality or malicious source domains, and improves the convergence speed and accuracy of the target-domain model. In the invention, the data privacy of the target domain and the generality of the framework are ensured by the diffusion model, while the data privacy of the clients is achieved through federated learning. Meanwhile, virtual feature distribution alignment is performed through decoupled representation learning, which removes the communication pressure of a multi-kernel maximum mean discrepancy computation module. In addition, negative domain transfer is mitigated by dynamically adjusting each client's knowledge contribution to the virtual prediction domain produced by federated confidence voting. To further accelerate target-model adaptation, the target-domain model learns the virtual prediction domain, so that the target model is fine-tuned.
The processing flow of the privacy-preserving, diffusion-model-driven unsupervised domain generalization method provided by the invention is shown in fig. 1 and comprises the following processing steps:
step S1: a diffusion model is trained on a target server in a target domain. The diffusion model trains unlabeled data of the target domain, such as (mt, mm, sv, sy) to up generalization task in the digit-5 dataset, and the diffusion model trains up this unlabeled picture dataset.
The diffusion model is used to generate a virtual data domain with the same distribution as the target-domain dataset; based on the diffusion model, data distributed similarly to the target domain can be recovered from Gaussian noise in each client's source domain, so the original target-domain data are protected while feature distribution alignment remains possible.
Step S2: the target server sends the trained diffusion model and the initialized global model to each client. The global model contains no data; it is the classification model to be trained. The global model must be issued to the clients in every round, and what is aggregated is also the global model, whereas the diffusion model only needs to be issued once.
Step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs decoupled representation learning on the virtual target-domain data, extracts domain-specific and domain-shared features, builds a local client model using a classifier on the specific features, and trains the client model.
the client model is a local classification model of the client after the global classification model is issued to the source domain of each client. The function is to extract the features and predict the classification labels. The process may be seen in the algorithm tables of fig. 2 and 3.
Step S4: each client uploads its domain-shared features and trained client model to the target server in the target domain, which generates a virtual dataset by voting.
The target server performs federated confidence voting on the target-domain samples through the client models, predicting the unlabeled target-domain data, and obtains a probability weight matrix from the most likely predictions produced by the voting. The probability of each client is compared with the average probability; if it is larger, the support count is incremented by 1, and the class with the most support becomes the final prediction label, thereby generating the virtual dataset.
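A toy sketch of the support-counting rule just described (our reading of the scheme; all names are illustrative): a client supports a class when its probability for that class exceeds the per-class average, and the most-supported class becomes the prediction label:

```python
import numpy as np

def confidence_vote(client_probs):
    P = np.stack(client_probs)             # (num_clients, num_classes)
    mean_prob = P.mean(axis=0)             # average probability per class
    support = (P > mean_prob).sum(axis=0)  # support +1 when above average
    eta = support / max(int(support.sum()), 1)  # illustrative weight row
    return int(np.argmax(support)), eta

label, eta = confidence_vote([
    np.array([0.6, 0.3, 0.1]),   # two clients confidently back class 0
    np.array([0.5, 0.4, 0.1]),
    np.array([0.2, 0.2, 0.6]),   # one dissenting client
])
```

Unlike averaging the probabilities directly, counting above-average support caps the influence of a single client with extreme (possibly poisoned) probability values.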
Step S5: the target server dynamically adjusts the weight of each client model according to each client's contribution to the virtual dataset, and aggregates the updated client models to obtain an updated global model.
Step S6: the updated global model is fine-tuned with the virtual prediction domain produced by federated confidence voting, yielding a new model for the next federated round and for target-domain sample label prediction.
Specifically, the step S1 includes: training a diffusion model $U_\theta$ on the target server. First, data $x_0$ are randomly drawn from the distribution $q(x_0)$; Gaussian noise is then gradually added according to $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I)$, until after the $T$-th step $x_T$ reaches a simple distribution (e.g. a Gaussian). The diffusion model can thus be described by a Markov chain over $t = 1, \ldots, T$, where $\beta_t$ is a parameter linearly interpolated from 0.0001 to 0.02 and $T$ is the number of diffusion steps. For different modalities, different $U_\theta$ are trained: DDPM for pictures, Diffusion-LM for text, and DiffWave for sound.
Specifically, the step S2 includes: after obtaining the trained diffusion model $U_\theta$, noise samples $x_T \sim \mathcal{N}(0, I)$ are randomly drawn from the Gaussian distribution; treating $x_T$ as the result of $T$ forward steps, the original sample $x_0$ is recovered in reverse, yielding a virtual dataset $\hat{D}$ whose feature distribution is approximately the same as that of the target domain. For pictures, DPM-Solver++ accelerated sampling is used, while text and sound keep the original models' sampling methods.
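A sketch of one ancestral (reverse) sampling step under the same $\beta$ schedule. The noise predictor `eps_pred` is passed in rather than produced by a trained $U_\theta$, so this only illustrates the arithmetic of stepping from $x_t$ back toward $x_0$, not the patent's actual sampler:

```python
import numpy as np

def reverse_step(x_t, t, eps_pred, betas, rng):
    # DDPM posterior mean given the predicted noise eps_pred, plus
    # fresh Gaussian noise at every step except the last (t == 0).
    alphas = 1.0 - betas
    alpha_bar_t = np.prod(alphas[: t + 1])
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bar_t) * eps_pred) \
           / np.sqrt(alphas[t])
    if t == 0:
        return mean
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

betas = np.linspace(1e-4, 0.02, 10)
# Final, deterministic step with a zero noise prediction (toy values).
x0_hat = reverse_step(np.array([1.0]), 0, np.array([0.0]),
                      betas, np.random.default_rng(0))
```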
The global model is a classification model, and the virtual dataset has a feature distribution similar to that of the target-domain data, so feature alignment can be performed without accessing the target-domain data. The target server sends the trained diffusion model and the initialized global model to each client.
Specifically, the step S3 includes: each client samples virtual target-domain sample data $\hat{x}$ from the diffusion model, maps the sample data into the feature space $H$ via $F : x \mapsto H(x)$, and obtains the domain-shared features; the label classifier $f_c : H \to Y$ then predicts the label category from the features $H(x)$.
The globally issued model is divided into three parts: 1) a domain-shared feature extractor, 2) a label classifier, and 3) a domain classifier. The label classifier comprises domain-specific features and domain-independent classifier weights. The goal of the algorithm is to generalize quickly to an unknown target domain by extracting reliable shared features of each domain for prediction.
Fig. 2 is a schematic diagram of the virtual feature generation algorithm of the target domain according to the present invention. To obtain a more accurate classifier and minimize the feature differences between the client domains and the target domain, the discriminator is designed as a domain classifier $f_d : H(x) \to DP$, indicating whether a feature comes from a client domain $D_k$ or from the virtually generated domain $\hat{D}$. The gradient of the feature function $H(x)$ is then reversed and the label classification loss is added, so as to confuse the discriminator and make the difference between the client features and the virtual features small. To reduce the influence of the domain-specific features, a hyperbolic tangent function of the training progress is used in the first few rounds of training to control the strength of the gradient reversal, where $b$ is the current batch number, $e$ is the current epoch number, $E$ is the total number of epochs, and $B$ is the total number of batches. The whole module $C$ can thus be expressed as the combination of the label classification loss and the reversed domain classification loss.
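The extraction leaves the exact schedule formula unrecoverable, so the following is a hypothetical tanh-based ramp over training progress consistent with the variables named above ($b$, $e$, $E$, $B$); the exact functional form is our assumption:

```python
import math

def grl_strength(e, b, E, B):
    # Training progress in [0, 1]: e = current epoch, b = current batch,
    # E = total epochs, B = batches per epoch (as defined in the text).
    progress = (e * B + b) / (E * B)
    # Hypothetical tanh ramp: 0 at the start, tanh(1) ~ 0.76 at the end,
    # so gradient reversal is weak early and strengthens over training.
    return math.tanh(progress)

lam_start = grl_strength(e=0, b=0, E=10, B=100)
lam_end = grl_strength(e=9, b=100, E=10, B=100)
```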
Fig. 3 is a schematic diagram of the algorithm for aligning source-domain features with virtual features according to the present invention. This process is the virtual feature alignment part; its purpose is to extract the shared features of the domains for alignment. The domain-specific features are not used for alignment, because features that belong to a single domain are not generic; they are used only for local model training.
Specifically, the step S4 includes: each client uploads its domain-shared features to the target server in the target domain; using the consensus knowledge of the source domains, the target server expands a virtual prediction domain $\hat{D}_{K+2}$.
Specifically, the step S5 includes:
define the overall knowledge quality as $CQ(S)$, computed over subsets $S' \subseteq S$, where $n_{CK}$ denotes the number of domains exceeding $CK$ and the max function takes the maximum probability of the $CK$ function, i.e. the predicted label class value. From $CQ$, the knowledge contribution of each client can be calculated as $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$; using $CQ(S)$ and $CC(S_k)$, the weight of the client model $L_k$ is readjusted by normalizing the contributions over all clients.
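The leave-one-out contribution $CC(S_k) = CQ(S) - CQ(S \setminus \{S_k\})$ can be turned into aggregation weights. Since the extraction leaves the normalization garbled, the softmax below is an assumption, and the toy $CQ$ (mean member accuracy) is purely illustrative:

```python
import numpy as np

def contribution_weights(quality_fn, clients):
    # CC(S_k) = CQ(S) - CQ(S \ {S_k}), then normalize (softmax assumed).
    full = quality_fn(clients)
    cc = np.array([full - quality_fn(clients[:i] + clients[i + 1:])
                   for i in range(len(clients))])
    w = np.exp(cc - cc.max())
    return w / w.sum()

# Toy knowledge quality: mean accuracy of the coalition's members.
accs = [0.9, 0.8, 0.3]        # the third client is a poor-quality source
w = contribution_weights(lambda cs: float(np.mean(cs)), accs)
```

A low-quality or poisoned client receives a negative contribution and hence a small weight, which is the negative-transfer mitigation that dynamic aggregation provides over simple averaging.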
The target server then aggregates the updated client models to obtain an updated global model $G$. Fig. 4 is a flow chart of the model dynamic aggregation algorithm provided by the invention.
Specifically, the step S6 includes: fig. 5 is a flowchart of the virtual feature learning algorithm provided by the present invention. Borrowing the idea of meta learning, a method called virtual feature learning is designed. In each round, the client model $L_k$ is trained locally on each client, the global model $G$ aggregated by consensus focusing is regarded as a pre-training model, and learning the virtual prediction domain is regarded as a task. The global model $G$ then learns the virtual prediction domain $\hat{D}_{K+2} = \{(x_{K+2}, y_{K+2})\}$ generated by federated confidence voting, where $y_{K+2} = \arg\max_c (P \times \eta)$, $P$ is the prediction probability assigned by each source domain to the target domain, and $\eta$ is the probability weight matrix. Virtual feature learning is performed on the global model $G$ with the samples $x_{K+2}$ of the virtual prediction domain $\hat{D}_{K+2}$, so that the target domain learns the feature distribution of each source domain faster and convergence is accelerated.
The target server learns from the data of the virtual prediction domain and adjusts the global model to obtain a new model for the next round of federated distribution and for target-domain sample label prediction.
The diffusion model remains unchanged and needs to be trained only once. The client models and the global model are the same classification model; they differ only in that the global model resides on the target server while the client models reside independently on the clients. Consistent with federated learning, the global model aggregated in each round is distributed to every client to replace the previous round's model. The model is divided into three parts: a feature extractor, a label classifier and a domain classifier.
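The forward noising of the diffusion model, trained once on the target server with the linear β schedule from 0.0001 to 0.02 described in this invention, can be sketched via the closed-form marginal q(x_t | x_0). This is a minimal numpy illustration of the noising schedule only, not the full denoiser U_θ.

```python
import numpy as np

T = 1000                                 # number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)       # linear schedule from the description
alphas_bar = np.cumprod(1.0 - betas)     # abar_t = prod_{s<=t} (1 - beta_s)

def q_sample(x0, t, rng):
    """Closed-form forward diffusion:
    x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, eps ~ N(0, I)."""
    eps = rng.normal(size=x0.shape)
    return np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * eps

rng = np.random.default_rng(3)
x0 = rng.normal(size=(4, 8))             # a toy batch of target-domain data
xT = q_sample(x0, T - 1, rng)            # after T steps, nearly pure noise
signal_coeff = float(np.sqrt(alphas_bar[-1]))
```

With this schedule the signal coefficient at step T is tiny, so x_T is effectively Gaussian noise, which is what lets the clients later restore target-like data from noise.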
In summary, the embodiment of the present invention uses the diffusion model to better protect the privacy of the target-domain data, with sufficient generality. Decoupled representation learning is used instead of domain distance-minimization computation, which greatly reduces communication pressure.
The method uses federated confidence voting and knowledge focusing to mitigate negative transfer and distribution shift across domains, and the virtual feature learning module accelerates the convergence of the model.
Those of ordinary skill in the art will appreciate that: the drawing is a schematic diagram of one embodiment and the modules or flows in the drawing are not necessarily required to practice the invention.
From the above description of embodiments, it will be apparent to those skilled in the art that the present invention may be implemented in software plus a necessary general hardware platform. Based on such understanding, the technical solution of the present invention may be embodied essentially or in a part contributing to the prior art in the form of a software product, which may be stored in a storage medium, such as a ROM/RAM, a magnetic disk, an optical disk, etc., including several instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute the method described in the embodiments or some parts of the embodiments of the present invention.
In this specification, each embodiment is described in a progressive manner, and identical and similar parts of each embodiment are all referred to each other, and each embodiment mainly describes differences from other embodiments. In particular, for apparatus or system embodiments, since they are substantially similar to method embodiments, the description is relatively simple, with reference to the description of method embodiments in part. The apparatus and system embodiments described above are merely illustrative, wherein the elements illustrated as separate elements may or may not be physically separate, and the elements shown as elements may or may not be physical elements, may be located in one place, or may be distributed over a plurality of network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment. Those of ordinary skill in the art will understand and implement the present invention without undue burden.
The present invention is not limited to the above-mentioned embodiments, and any changes or substitutions that can be easily understood by those skilled in the art within the technical scope of the present invention are intended to be included in the scope of the present invention. Therefore, the protection scope of the present invention should be subject to the protection scope of the claims.

Claims (6)

1. A diffusion-model-driven unsupervised domain generalization method for privacy protection, characterized by comprising the following steps:
step S1: training a diffusion model on a target server in a target domain;
step S2: the target server sends the trained diffusion model and the initialized global model to each client;
step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs decoupled representation learning on the virtual target-domain data, extracts domain-specific features and domain-shared features, constructs a local client model using the label classifier among the specific features, and trains the client model;
step S4: each client uploads the domain-shared features and the trained client model to the target server in the target domain; the target server performs federated confidence voting on samples of the target domain using the domain-shared features, obtains a probability weight matrix from the most likely predictions of the vote, and generates a virtual prediction domain;
step S5: the target server dynamically adjusts the weight of each client model according to the contribution of each client to the virtual prediction domain, and aggregates each updated client model to obtain an updated global model;
step S6: fine-tuning the updated global model with the virtual prediction domain obtained by federated confidence voting, to obtain a new model for the next round of federated distribution and for target-domain sample label prediction.
2. The method according to claim 1, wherein said step S1: training a diffusion model on a target server in a target domain, comprising:
training a diffusion model U_θ on the target server: data x_0 ~ q(x_0) are randomly drawn from the target-domain distribution, and noise is gradually added to the data until, after T steps, x_T is obtained; the diffusion model U_θ is described by a Markov chain q(x_t | x_{t−1}) = N(x_t; √(1−β_t)·x_{t−1}, β_t·I), t = 1, …, T, where β_t is a parameter linearly interpolated from 0.0001 to 0.02 and T is the number of diffusion steps;
the diffusion model is used for generating virtual data domains which are distributed in the same way as the target domain data set, and the source domain at the client side restores data which are distributed in the same way as the target domain through Gaussian noise based on the diffusion model.
3. The method according to claim 2, wherein said step S3: after receiving the global model, each client samples virtual target-domain data from the diffusion model, performs decoupled representation learning on the virtual target-domain data, extracts domain-specific features and domain-shared features, constructs a local client model using the label classifier among the specific features, and trains the client model, comprising:
the issuing of the global model includes three parts: extracting domain sharing characteristics, a mark classifier and a domain classifier, wherein the mark classifier comprises domain specific characteristics and domain independent classifier weights, and each client model samples sample data of a virtual target domain from a diffusion modelSample data +.>Mapping to feature space->And obtain the sharing feature of the domain->And a specific feature, using a tag classifier in the specific feature +.>Constructing a local client model, training the client model, and performing tag classifier +.>According to the characteristics->To predict tag categories.
4. The method according to claim 3, wherein said step S4: each client uploads the domain-shared features and the trained client model to the target server in the target domain; the target server performs federated confidence voting on samples of the target domain through each client model using the domain-shared features, obtains a probability weight matrix from the most likely predictions of the vote, and generates a virtual prediction domain, comprising:
sharing features of individual client domainsAnd uploading the trained client model to a target server in a target domain, wherein the target server utilizes consensus knowledge of each source domain> Expanding a virtual domain->
5. The method according to claim 4, wherein said step S5: the target server dynamically adjusts the weight of each client model according to the contribution of each client to the virtual prediction domain, and aggregates each updated client model to obtain an updated global model, comprising:
defining the overall knowledge quality as CQ(S'), where S' ⊆ S, n_CK denotes the number of entries exceeding the consensus knowledge CK, and the max function takes the maximum probability of the CK function, i.e., the predicted label class value; calculating the knowledge contribution of each client as CC(S_k) = CQ(S) − CQ(S \ {S_k}); and using CQ and CC(S_k) to readjust the weight w_k of the client model L_k, with the weights normalized so that they sum to 1;
the target server aggregates the updated client models to obtain an updated global model G.
6. The method according to claim 5, wherein said step S6: the target server performs federated confidence voting on the samples of the target domain through the updated global model and obtains the predicted labels of the target-domain samples, comprising:
regarding the client model L_k locally trained on each client as a task and the global model G aggregated by consensus focusing as a pre-trained model, and learning with the global model G the virtual prediction domain S_{K+2} = {(x_{K+2}, y_{K+2})} generated by federated confidence voting, wherein y_{K+2} = argmax_c(P·η), P is the prediction probability of each source domain on the target domain, and η is the probability weight matrix;
performing virtual feature learning on the global model G with the samples x_{K+2} of the virtual prediction domain, so that the target domain learns the feature distribution of each source domain; and fine-tuning the updated global model to obtain a new global model for the next round of federated distribution and target-domain sample label prediction.
CN202311013570.XA 2023-08-11 2023-08-11 Diffusion model driven unsupervised domain generalization method for privacy protection Pending CN116882480A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311013570.XA CN116882480A (en) 2023-08-11 2023-08-11 Diffusion model driven unsupervised domain generalization method for privacy protection

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311013570.XA CN116882480A (en) 2023-08-11 2023-08-11 Diffusion model driven unsupervised domain generalization method for privacy protection

Publications (1)

Publication Number Publication Date
CN116882480A true CN116882480A (en) 2023-10-13

Family

ID=88258785

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311013570.XA Pending CN116882480A (en) 2023-08-11 2023-08-11 Diffusion model driven unsupervised domain generalization method for privacy protection

Country Status (1)

Country Link
CN (1) CN116882480A (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117910601A (en) * 2024-03-20 2024-04-19 浙江大学滨江研究院 Personalized federal potential diffusion model learning method and system


Similar Documents

Publication Publication Date Title
Zhai et al. Multiple expert brainstorming for domain adaptive person re-identification
JP7470476B2 (en) Integration of models with different target classes using distillation
US10885383B2 (en) Unsupervised cross-domain distance metric adaptation with feature transfer network
Fang et al. Source-free unsupervised domain adaptation: A survey
US20220156507A1 (en) Unsupervised representation learning with contrastive prototypes
WO2022077646A1 (en) Method and apparatus for training student model for image processing
TWI832679B (en) Computer system and computer-implemented method for knowledge-preserving neural network pruning, and non-transitory computer-readable storage medium thereof
CN112446423A (en) Fast hybrid high-order attention domain confrontation network method based on transfer learning
Kimura et al. Anomaly detection using GANs for visual inspection in noisy training data
WO2022103682A1 (en) Face recognition from unseen domains via learning of semantic features
CN116882480A (en) Diffusion model driven unsupervised domain generalization method for privacy protection
US12002488B2 (en) Information processing apparatus and information processing method
WO2020092276A1 (en) Video recognition using multiple modalities
CN113821668A (en) Data classification identification method, device, equipment and readable storage medium
CN116227578A (en) Unsupervised domain adaptation method for passive domain data
Zhou et al. Unsupervised domain adaptation with adversarial distribution adaptation network
Vilalta et al. A general approach to domain adaptation with applications in astronomy
Guo et al. Fed-fsnet: Mitigating non-iid federated learning via fuzzy synthesizing network
US20240119307A1 (en) Personalized Federated Learning Via Sharable Basis Models
Firdaus et al. Personalized federated learning for heterogeneous data: A distributed edge clustering approach
Zhou et al. Progressive decoupled target-into-source multi-target domain adaptation
US20220067534A1 (en) Systems and methods for mutual information based self-supervised learning
Dhillon et al. Inference-driven metric learning for graph construction
Zuo et al. FedViT: Federated continual learning of vision transformer at edge
CN117113274A (en) Heterogeneous network data-free fusion method and system based on federal distillation

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination