CN114399055A - Domain generalization method based on federated learning - Google Patents

Info

Publication number
CN114399055A
Authority
CN
China
Prior art keywords
feature
distribution
data
discriminator
domain
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111626157.1A
Other languages
Chinese (zh)
Inventor
黄宏宇
张莉玲
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chongqing University
Original Assignee
Chongqing University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chongqing University
Priority to CN202111626157.1A
Publication of CN114399055A

Classifications

    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06N: COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00: Machine learning
    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06F: ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00: Pattern recognition
    • G06F18/20: Analysing
    • G06F18/21: Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214: Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06F: ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00: Pattern recognition
    • G06F18/20: Analysing
    • G06F18/24: Classification techniques
    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06F: ELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00: Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60: Protecting data
    • G06F21/62: Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218: Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245: Protecting personal data, e.g. for financial or medical purposes

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Bioethics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Databases & Information Systems (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Image Analysis (AREA)

Abstract

The invention relates to a domain generalization method based on federated learning, belonging to the technical field of computers. The method comprises the following steps: at each client, an adversarial learning network is used to align, by category, the feature distribution of the client's source domain data with a reference feature distribution; using a federated learning framework, the model parameters of the feature extractors, distribution generators and classifiers of a plurality of clients are averaged at a server, so that the reference feature distributions of the clients become consistent and lie close to the center of the feature distributions of all source domain data, the feature distributions output by the feature extractors become consistent, and the classifier can classify features drawn from that distribution; through multiple rounds of data interaction between the server and the clients, the feature distributions of the plurality of source domains are aligned by category, at which point the feature extractor and the classifier aggregated by the server generalize well from the multiple source domains to the target domain. While protecting the privacy of the source domain data, the method can learn a model with good generalization capability in the target domain without any data from the target scenario.

Description

Domain generalization method based on federated learning
Technical Field
The invention belongs to the technical field of computers, and relates to a domain generalization method based on federated learning.
Background
With the arrival of the big data era, the deep learning field's demand for large amounts of data has largely been met. However, deep learning tasks usually require labeled data, which is very hard to obtain in new fields, especially in internet settings subject to the cold start problem. There is therefore a need to train a deep learning model on existing labeled datasets and apply it to an unlabeled dataset that does not participate in training. To enlarge the training set as much as possible and improve the generalization ability of the model, the collected data often comes from multiple data sources. Distribution differences among the source domains, and between the source domains and the unknown target domain, prevent the model from performing well on the unknown target domain. How to train, on multiple source domains, a model that performs well in the target domain is called the domain generalization problem.
Specifically, domain generalization refers to learning, from the data of labeled source domains, a model with good generalization capability in an unknown target domain. To improve domain generalization performance, that is, the accuracy with which the model classifies target domain data, the data of the multiple source domains is usually uploaded to a central server for training. However, centralized domain generalization training is impractical when the data is privacy-sensitive or users are reluctant to share it directly. Federated learning can provide some protection for the privacy of source domain data in the domain generalization problem. Federated learning is a decentralized learning method: the data of the source domains is distributed over different clients, and a server receives model parameters from the clients and aggregates them into a global model. The source domain data therefore does not need to be revealed to untrusted third parties, and its privacy is protected. The main research scenario of the invention is thus how to use federated learning to learn, from labeled, distributed multi-source domain data, a generalization model that can be applied to an unlabeled target domain with as high a classification accuracy as possible. Because federated learning only exchanges model parameters with the clients, it does not account for differences in data distribution among the source domains, and the source domain data may also differ from the unknown target domain data; a global model obtained by server aggregation and applied directly to the target domain therefore cannot be guaranteed to be accurate. How to improve the generalization capability of the global model in the target domain, in a scenario where federated learning protects the privacy of source domain data, is the main research problem of the invention.
Disclosure of Invention
In view of this, the present invention provides a domain generalization method based on federated learning, which improves the generalization capability of the global model in the target domain in a scenario where federated learning protects the privacy of source domain data.
To achieve this purpose, the invention provides the following technical solution:
A domain generalization method based on federated learning, in which, under a federated learning framework, a generalization model for an unknown target domain is learned across domains from distributed multi-source domain data. The method specifically comprises the following steps:
S1: at each client, an adversarial learning network is used to align, by category, the feature distribution of the client's source domain data with a reference feature distribution, and a classifier that classifies the source domain features well is learned; the reference feature distribution is produced by a distribution generator, and through adversarial learning the generated reference feature distribution approaches the feature distribution of the source domain data by category, which reduces the feature shift needed to align the source domain feature distribution with the reference feature distribution and prevents feature distortion;
S2: using a federated learning framework, the model parameters of the feature extractors, distribution generators and classifiers of a plurality of clients are averaged at a server, so that the reference feature distributions of the clients become consistent and lie close to the center of the feature distributions of all source domain data, the feature distributions output by the feature extractors become consistent, and the classifier can classify features drawn from that distribution;
S3: through multiple rounds of data interaction between the server and the clients, the feature distributions of the plurality of source domains are aligned by category, and the feature distribution that the feature extractor outputs on source domain data is close to the generated reference feature distribution; at this point the feature extractor has learned domain-invariant features, and because the classifier performs similarly on features with the same distribution, the feature extractor and the classifier generalize well to the target domain.
Further, step S1 specifically includes the following steps:
S11: based on the local labeled source domain data, the client trains the feature extractor and the classifier so that the feature extractor extracts the data features that are key to the classification task and the classifier classifies those features accurately; the classification loss function of this training [equation image: Figure BDA0003440091630000021] is a standard cross-entropy function, and, to prevent overfitting, a label smoothing regularization term is used to fine-tune the loss function;
S12: given the source domain data features extracted by the feature extractor and the generated features output by the distribution generator, both of which carry the labels of the real data, the features are randomly mapped and then input into the discriminator, and the discriminator outputs the probability that a feature is a positive sample;
S13: updating the parameters of the feature extractor: the hyperparameters λ0 and λ1 are used to balance the influence of the adversarial loss and the classification loss on the feature extractor parameters;
S14: the loss function of the distribution generator [equation image: Figure BDA0003440091630000022] is based on the discriminator's judgment of the probability that the generated feature h' is a positive sample; during adversarial training of the distribution generator, with the discriminator parameters fixed, [equation image: Figure BDA0003440091630000023] is used to update the parameters of the distribution generator;
S15: in step S12, the discriminator is trained to distinguish as well as possible between the source domain data features, regarded as negative samples, and the generated features, regarded as positive samples; in steps S13 and S14, the discriminator parameters are fixed, the feature extractor is trained so that the discriminator misjudges the source domain features as positive samples, and the distribution generator is trained so that the discriminator correctly judges the generated features as positive samples; the adversarial training of steps S12 to S14 is repeated for multiple rounds until the discriminator judges both types of features to be positive samples, at which point the source domain feature distribution and the generated reference feature distribution are aligned by category; because the adversarial training also moves the generated reference feature distribution toward the source domain feature distribution by category, the feature shift that the source domain feature distribution must undergo during alignment is reduced.
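The classification loss of step S11 survives only as an image, so the following is a minimal Python sketch of a cross-entropy loss with label smoothing rather than a transcription of the patent's formula. The function names, the smoothing weight `epsilon` and the uniform smoothing convention are all assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def smoothed_cross_entropy(logits, label, epsilon=0.1):
    """Cross-entropy against a label-smoothed target.

    Assumption: the target keeps 1 - epsilon on the true class and spreads
    epsilon uniformly over all classes. The patent does not state which
    smoothing variant or which epsilon it uses.
    """
    num_classes = len(logits)
    probs = softmax(logits)
    loss = 0.0
    for c, p in enumerate(probs):
        target = epsilon / num_classes + (1.0 - epsilon if c == label else 0.0)
        loss -= target * math.log(p)
    return loss
```

With `epsilon=0.0` the function reduces to the standard cross-entropy; a positive `epsilon` penalizes very confident predictions, which is the usual motivation for using label smoothing against overfitting.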
Further, in step S12, in the adversarial learning process of the feature extractor, the distribution generator and the discriminator, the feature h extracted by the feature extractor is regarded as a negative sample, and the feature h' output by the distribution generator is regarded as a positive sample; the loss function of the discriminator on these two types of input features [equation image: Figure BDA0003440091630000031] is defined as:
[equation image: Figure BDA0003440091630000032]
where p(h) denotes the distribution of the feature h, p(h') denotes the distribution of the feature h', D denotes the discriminator model, y denotes the real label of the data corresponding to the feature h, and [equation image: Figure BDA0003440091630000033] denotes the expectation.
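Because the discriminator loss itself is available only as an image, the sketch below shows one loss that matches the verbal description: source features h are negative samples and generated features h' are positive samples. It assumes the classic binary cross-entropy GAN objective; the patent's actual formula may use a different form, and the function name is hypothetical.

```python
import math


def discriminator_loss(d_on_source, d_on_generated, eps=1e-12):
    """Binary cross-entropy loss for the discriminator.

    d_on_source:    values D(h | y) for features h from the feature
                    extractor (negative samples, target 0).
    d_on_generated: values D(h' | y) for features h' from the distribution
                    generator (positive samples, target 1).
    Assumption: the log-loss form is illustrative, not a confirmed
    transcription of the patent's equation.
    """
    neg = sum(-math.log(1.0 - p + eps) for p in d_on_source) / len(d_on_source)
    pos = sum(-math.log(p + eps) for p in d_on_generated) / len(d_on_generated)
    return neg + pos
```

The loss is low when the discriminator scores source features near 0 and generated features near 1, and it rises when the two kinds of features become indistinguishable.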
Further, step S13 specifically includes: the loss function used to update the feature extractor and the classifier is
[equation image: Figure BDA0003440091630000034]
where [equation image: Figure BDA0003440091630000035] is the loss function of the classification training, and [equation image: Figure BDA0003440091630000036], the loss function of the feature extractor in the adversarial learning process, is defined as:
[equation image: Figure BDA0003440091630000037]
During adversarial training of the feature extractor, the negative samples h output by the feature extractor are used to deceive the discriminator, so that the discriminator judges h to be a positive sample.
Further, in step S14, the loss function of the distribution generator [equation image: Figure BDA0003440091630000038] is used to update the parameters of the distribution generator, and [equation image: Figure BDA0003440091630000039] is defined as:
[equation image: Figure BDA00034400916300000310]
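The two updates made with the discriminator fixed (steps S13 and S14) can be sketched as follows. Both losses are assumed to take a log-loss form that is small when the discriminator scores a feature as positive, and λ0 is assumed to weight the adversarial term while λ1 weights the classification term. The text only says that the two hyperparameters balance the losses, so the pairing, the loss form and every function name here are illustrative assumptions.

```python
import math


def positive_score_loss(d_scores, eps=1e-12):
    """Mean of -log D(.): small when the fixed discriminator outputs values near 1."""
    return sum(-math.log(p + eps) for p in d_scores) / len(d_scores)


def extractor_adversarial_loss(d_on_source):
    """S13: favors source features h that the discriminator scores as positive."""
    return positive_score_loss(d_on_source)


def generator_loss(d_on_generated):
    """S14: favors generated features h' that the discriminator scores as positive."""
    return positive_score_loss(d_on_generated)


def extractor_total_loss(adv_loss, cls_loss, lambda0=1.0, lambda1=1.0):
    """Weighted sum of both terms; which weight scales which term is an assumption."""
    return lambda0 * adv_loss + lambda1 * cls_loss
```

Minimizing the first loss pushes source features toward the region the discriminator accepts, while minimizing the second keeps generated features in that region, which is how the two distributions are drawn together.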
Further, step S2 specifically includes the following steps:
S21: the server receives the model parameters uploaded by the clients, the models being the feature extractor, the distribution generator and the classifier, and stores them temporarily; after the model parameters uploaded by all clients have been received, the uploaded parameters are averaged separately for each model, using the following parameter averaging formula:
[equation image: Figure BDA00034400916300000311]
where w_t denotes a model parameter in the t-th period, and K denotes the number of clients;
S22: after averaging all the received model parameters, the server distributes the resulting new model parameters to all clients and waits for the next aggregation.
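The aggregation in step S21 can be illustrated with a short Python sketch. It assumes an unweighted element-wise mean over the K clients, which is what the word "averaged" and the client count K suggest; the formula itself is only an image. The flat-list parameter representation and the function names are simplifications introduced for this example.

```python
def average_parameters(client_params):
    """Element-wise mean of the clients' parameter vectors for one model.

    client_params: list of K equally long lists of floats, one per client.
    Assumption: an unweighted mean over the K clients.
    """
    k = len(client_params)
    if k == 0:
        raise ValueError("need at least one client")
    length = len(client_params[0])
    if any(len(p) != length for p in client_params):
        raise ValueError("all clients must upload parameters of the same shape")
    return [sum(p[i] for p in client_params) / k for i in range(length)]


def aggregate(uploads):
    """Average each uploaded model separately, keyed by model name."""
    names = uploads[0].keys()
    return {name: average_parameters([u[name] for u in uploads]) for name in names}
```

For example, `aggregate([{"classifier": [2.0]}, {"classifier": [4.0]}])` returns `{"classifier": [3.0]}`; the server would then send this dictionary back to every client.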
Further, step S3 specifically includes the following steps:
S31: at the client, the output distribution of the local distribution generator approaches the feature distribution of the local source domain data, and because its parameters are globally averaged in step S21, the reference feature distribution output by the distribution generator approaches the center of all source domain data distributions; the clients and the server exchange parameters over multiple rounds, with each client uploading its model parameters to the server, until the accuracy of the classifier on the local dataset converges to a high value;
S32: the server averages the parameters of the feature extractor and the classifier; at this point the feature extractor has learned domain-invariant features and can extract, from target domain data, features whose distribution is close to the generated reference feature distribution, and the globally averaged classifier can classify these features accurately.
The invention has the following beneficial effects:
(1) The method takes source domain data distributed over a plurality of clients as the training set and, based on the federated learning architecture, learns a model that generalizes to an unknown target domain.
(2) The method can be applied to problems such as cold start in newly launched internet services, and, while protecting the data privacy of the source domains, can learn a model with good generalization capability in the target domain without any target scenario data.
Additional advantages, objects, and features of the invention will be set forth in part in the description which follows and in part will become apparent to those having ordinary skill in the art upon examination of the following or may be learned from practice of the invention. The objectives and other advantages of the invention may be realized and attained by the means of the instrumentalities and combinations particularly pointed out hereinafter.
Drawings
For the purposes of promoting a better understanding of the objects, aspects and advantages of the invention, reference will now be made to the following detailed description taken in conjunction with the accompanying drawings in which:
FIG. 1 is a flow chart of a federated learning-based domain generalization methodology of the present invention;
FIG. 2 is a schematic diagram of a data interaction process of a multi-source domain generalization model in the present invention;
FIG. 3 is a schematic diagram of a client local training process according to the present invention.
Detailed Description
The embodiments of the present invention are described below with reference to specific embodiments, and other advantages and effects of the present invention will be easily understood by those skilled in the art from the disclosure of the present specification. The invention is capable of other and different embodiments and of being practiced or of being carried out in various ways, and its several details are capable of modification in various respects, all without departing from the spirit and scope of the present invention. It should be noted that the drawings provided in the following embodiments are only for illustrating the basic idea of the present invention in a schematic way, and the features in the following embodiments and examples may be combined with each other without conflict.
Referring to FIG. 1 to FIG. 3, the client network model adopted by the present invention includes:
Local model 1: feature extractor. The feature extractor extracts high-dimensional feature vectors from the client's source domain data, and the extracted feature vectors can be used for classification tasks. The data features extracted from the same source domain constitute a source domain feature distribution.
Local model 2: distribution generator. The distribution generator takes a random vector and a label of the real data as input and outputs a generated feature vector. The generated feature vectors constitute a definite distribution, namely the reference feature distribution.
Local model 3: discriminator. Given the high-dimensional source domain features and the generated features, both carrying real data labels and both passed through random mapping, the discriminator distinguishes the two types of features. During adversarial training, the discriminator separates the two types of features as well as it can.
Local model 4: classifier. Given a feature vector, the classifier outputs its predicted label.
Under a federated learning framework, the invention learns, across domains, a generalization model for an unknown target domain from distributed multi-source domain data. The method mainly comprises the following steps:
Step 1: the federated learning architecture includes a plurality of distributed clients and a server. Each client stores source domain data, trains local models on that data, and exchanges model parameters with the server; the server aggregates and distributes the model parameters, and the aggregated global model can be applied to data of an unknown target domain. The source domain data at the clients is labeled, while the target domain data is unlabeled and does not participate in training.
Step 2: the model parameters trained by each client are sent to the server, which computes their average to obtain new model parameters and distributes them to all clients; training iterates over multiple rounds until the accuracy of the client models on the local datasets stabilizes, after which the server's global model is applied to the target domain. The invention provides a federated adversarial learning training method. At each client, local adversarial learning aligns the feature distribution of the source domain data to the same reference feature distribution, which indirectly aligns the feature distributions of all source domain data. The server can then learn a domain-invariant feature extractor that, on the target domain, also extracts data features whose distribution is close to the reference feature distribution; this reduces the difference between source domain data and target domain data, so that the global model generalizes well from the multiple source domains to the target domain. In addition, to better predict the labels of target domain data, the invention trains a classifier locally at each client on the extracted features and the real labels, and, after multiple rounds of iterative training with the server, the global classifier is used for the target domain.
For step 2, the invention proposes a distribution generator, which participates in local adversarial training so that the reference feature distribution it outputs approaches the feature distribution of the source domain data. Unlike the fixed reference feature distribution commonly used in domain generalization, the generated reference feature distribution can, through federated adversarial training, approach the center of the data feature distributions of all source domains, which reduces both the shift of the source domain feature distributions during alignment and the loss of key feature information during extraction. The client model thus includes a feature extractor, a distribution generator, a discriminator and a classifier, of which the feature extractor, the distribution generator and the discriminator participate in federated adversarial training to align the data feature distributions of all source domains.
Step 3: local adversarial training at the client aligns the source domain feature distribution to the reference feature distribution; the model parameters of the local feature extractor are uploaded to the server, and aggregation makes the data distribution extracted by the global feature extractor approach the same reference feature distribution. At this stage, the model parameters uploaded to the server come from the feature extractor and the classifier; the discriminator, whose function is to discriminate between the source domain data features and the reference features, does not upload its parameters to the server. The significance of this step is that local adversarial training reduces the difference between the source domain data distribution and the reference feature distribution, which in turn indirectly reduces the distance between the data distributions of all source domains and improves the generalization performance of the global model in the target domain.
For step 3, the invention proposes a strategy for aligning the data distributions of the source domains by category. A source domain typically contains data of multiple classes, and data of different classes is distributed differently. To better align the data distributions of all source domains, the invention adds the real label of the data to the features input to the discriminator, so that the discriminator discriminates by category, and adds the label of the real data to the random vector input to the distribution generator, so that it generates features of the same category as, and similar to, the real data.
The invention also provides a strategy of randomly mapping the high-dimensional features before they are input to the discriminator. The high-dimensional feature vector produced by the feature extractor or the distribution generator is randomly mapped from the high-dimensional feature space to a low-dimensional feature space and then passed to the discriminator. With random mapping, the output of the discriminator helps train the parameters of the distribution generator and keeps local adversarial learning stable.
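The label conditioning and random mapping described above can be sketched in a few lines of Python. The sketch makes several assumptions that the text does not confirm: the label is attached as a one-hot vector by concatenation, the random mapping is a fixed Gaussian matrix, and the mapping is applied to the concatenated vector. All function names are hypothetical.

```python
import random


def one_hot(label, num_classes):
    """One-hot encoding of an integer class label."""
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def make_random_projection(in_dim, out_dim, seed=0):
    """Fixed Gaussian matrix mapping in_dim inputs to out_dim outputs (assumed form)."""
    rng = random.Random(seed)
    scale = out_dim ** -0.5
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(in_dim)]
            for _ in range(out_dim)]


def project(matrix, vector):
    """Matrix-vector product: one output value per matrix row."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def discriminator_input(feature, label, num_classes, matrix):
    """Attach the real label to a feature, then map it to the low-dimensional space."""
    return project(matrix, feature + one_hot(label, num_classes))


def generator_input(noise_dim, label, num_classes, rng):
    """Random vector with the real label attached, as fed to the distribution generator."""
    noise = [rng.gauss(0.0, 1.0) for _ in range(noise_dim)]
    return noise + one_hot(label, num_classes)
```

For an 8-dimensional feature and 3 classes, `make_random_projection(11, 4)` yields a matrix that reduces the 11-dimensional labeled feature to 4 dimensions before it reaches the discriminator. Features from the feature extractor and from the distribution generator would pass through the same matrix.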
The invention provides a domain generalization method based on federated learning, which takes source domain data distributed over a plurality of clients as the training set and, based on a federated learning framework, learns a model with good generalization capability in an unknown target domain. The method uses the generated reference feature distribution as an intermediary: by aligning the source domain feature distribution with the reference feature distribution locally at each client, it aligns the feature distributions of all source domains within the federated learning framework, reduces the differences among all source domain data, and learns invariant features, with the feature distributions of all source domains close to the reference feature distribution. The learned invariant features can be output by the global feature extractor aggregated at the server; that is, the global feature extractor can extract, from unknown target domain data, features whose distribution is close to the reference feature distribution, and these features behave on a given task much as the source domain features do, so the model performs well on the target domain. The method uses the federated learning framework for distributed training, so that the global model ultimately generalizes well to the unknown target domain.
As shown in fig. 1, the federate learning-based domain generalization method of the present invention is divided into two parts, the first part trains and aligns feature distribution and reference feature distribution of source domain data on client data and learns a classifier that can well classify the source domain features, and the second part aggregates model parameters uploaded by the client at the server to obtain new average model parameters and distributes the new average model parameters to all clients.
A first part comprising the following six steps:
s11: as shown in fig. 2, if the client trains for the first time, the client receives the initialization model from the server and initializes the arbiter model locally, otherwise, the client receives the model parameters aggregated from the server and uses the arbiter model trained in the previous cycle. In addition, during training, all clients require uploading of locally trained model parameters, including feature extractors, classifiers, and distribution generators.
S12: based on local source domain data with labels, the client trains the feature extractor and the classifier, so that the feature extractor can extract important data features for classification, and the classifier can accurately classify and predict the extracted features. The loss function of the training is
Figure BDA0003440091630000071
Is a standard cross entropy function, and to prevent overfitting, the label smoothing regularization term fine-tunes the classification loss function.
S13: and giving source domain data features extracted by the feature extractor and generation features output by the distribution generator, wherein the two types of features carry labels from real data, inputting low-dimensional feature vectors obtained after random mapping into a discriminator, and using the output features as the probability of a positive sample by the discriminator. In the countervailing learning process, the features h extracted by the feature extractor are treated as negative samples, while the features h' output by the distribution generator are treated as positive samples. The penalty function of the discriminator on these two types of input features is defined as:
Figure BDA0003440091630000072
wherein p (h) represents the distribution of the characteristic h, p (h ') represents the distribution of the characteristic h', D represents the discriminator model, and y represents the real label of the data corresponding to the characteristic h.
S14: the parameter update of the feature extractor is affected by both the counterpenalty function and the classification penalty function, therefore, the present invention uses the hyper-parameter λ0And λ1To balance the impact of countermeasures and classification penalties on the feature extractor parameters, updating the penalty functions of the feature extractor and classifier as
Figure BDA0003440091630000073
Wherein the content of the first and second substances,
Figure BDA0003440091630000074
a loss function representing the feature extractor during the counterlearning process is defined as:
Figure BDA0003440091630000075
in the process of training the feature extractor against the countermeasures, the negative sample h output by the feature extractor is used for deceiving the discriminator, so that the discriminator can discriminate the negative sample h output by the feature extractor as a positive sample as much as possible.
S15: loss function of distribution generator
Figure BDA0003440091630000076
And judging whether the generated feature h' is a positive sample by the discriminator. In the course of confrontation training of the distribution generator, given a fixed parameter arbiter,
Figure BDA0003440091630000077
for updating the parameters of the distribution generator and,
Figure BDA0003440091630000078
is defined as:
Figure BDA0003440091630000079
as can be seen from step S13, in the countercheck learning, the discriminator is trained to discriminate the feature of the source domain data as a negative sample and the generated feature as a positive sample as much as possible, and step S14 and step S15 fix the discriminator parameters, train the feature extractor to make the discriminator misjudge the source domain feature as a positive sample, train the distribution generator to make the discriminator correctly judge the generated feature as a positive sample, repeat the countercheck training of steps S13-S15 to make the discriminator judge both types of features as positive samples, and at this time, the source domain feature distribution output by the feature extractor is aligned with the reference feature distribution output by the distribution generator.
S16: Upload the model parameters. After local training at the client is finished, the parameters of the feature extractor, the distribution generator and the classifier are uploaded to the server, and the client waits for the next data interaction with the server.
The second part comprises the following two steps:
S21: When the server participates in federated learning for the first time, it initializes the feature extractor, distribution generator and classifier models and distributes them to all clients.
S22: During federated domain generalization training, the server receives the model parameters uploaded by the clients and averages them. The parameter averaging operation is:
Figure BDA0003440091630000081
where w_t denotes the model parameters in round t, and K denotes the number of clients. In federated learning, if the given number of interaction rounds between the server and the clients has not yet been completed, the newly computed model parameters are distributed to all clients; otherwise, the feature extractor and classifier models with the new parameters are applied to the target domain data.
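The server side of steps S21 and S22 can be sketched as below. The averaging formula itself appears only in a figure, so the unweighted mean over the K clients used here is an assumption based on the surrounding text. The function names and the representation of each model as a dictionary of flat parameter lists are hypothetical.

```python
def average_parameters(client_params):
    """Element-wise unweighted mean of the parameters uploaded by K clients.

    client_params: list of dicts mapping a parameter name to a list of floats.
    """
    k = len(client_params)
    if k == 0:
        raise ValueError("at least one client upload is required")
    averaged = {}
    for name in client_params[0]:
        vectors = [params[name] for params in client_params]
        averaged[name] = [sum(values) / k for values in zip(*vectors)]
    return averaged

def run_federated_training(initial_params, clients, num_rounds):
    """Distribute, train locally, upload and average for a fixed number of rounds.

    clients: list of callables that take the global parameters and
    return the locally trained parameters (steps S11 to S16).
    """
    params = initial_params
    for _ in range(num_rounds):
        uploads = [local_train(params) for local_train in clients]
        params = average_parameters(uploads)
    return params  # applied to the target domain after the last round
```

In this sketch the same routine would be applied separately to the feature extractor, the distribution generator and the classifier, since the text averages each model's parameters on its own.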
Finally, the above embodiments are intended only to illustrate the technical solutions of the present invention, not to limit it. Although the invention has been described in detail with reference to the preferred embodiments, those skilled in the art will understand that modifications or equivalent substitutions may be made to the technical solutions without departing from their spirit and scope, and all such changes are covered by the claims of the present invention.

Claims (7)

1. A domain generalization method based on federated learning, characterized in that, under a federated learning framework, a generalization model for an unknown target domain is learned across domains from distributed multi-source domain data, the method specifically comprising the following steps:
S1: using an adversarial learning network, aligning by category the feature distribution of the client's source domain data with a reference feature distribution, and learning a classifier that classifies the source domain features well; generating the reference feature distribution with a distribution generator, the generated reference feature distribution being brought close, by category, to the feature distribution of the source domain data through adversarial learning;
S2: using the federated learning framework, averaging at the server the model parameters of the feature extractors, distribution generators and classifiers of multiple clients, so that the reference feature distributions of the clients are consistent and close to the center of the feature distributions of all source domain data, the feature distributions output by the feature extractors are consistent, and the classifier can classify features drawn from these distributions;
S3: through multiple rounds of data interaction between the server and the clients, aligning the feature distributions of the multiple source domains by category, so that the feature distribution output by the feature extractor on the multi-source domain data is close to the generated reference feature distribution.
2. The federated learning-based domain generalization method of claim 1, wherein step S1 specifically comprises the following steps:
S11: based on labeled local source domain data, the client trains a feature extractor and a classifier, so that the feature extractor extracts the key data features needed for the classification task and the classifier classifies these features accurately;
S12: given the source domain data features extracted by the feature extractor and the generated features output by the distribution generator, both types of features carrying true data labels, the labels are input to a discriminator after random mapping, and the discriminator outputs the probability that a feature is a positive sample;
S13: updating the parameters of the feature extractor: the hyperparameters λ0 and λ1 are used to balance the influence of the adversarial loss and the classification loss on the feature extractor parameters;
S14: the loss function of the distribution generator,
Figure FDA0003440091620000011
is determined by the discriminator's judgment of the probability that the generated feature h' is a positive sample; during adversarial training of the distribution generator, with the discriminator parameters held fixed,
Figure FDA0003440091620000012
is used to update the parameters of the distribution generator;
S15: based on step S12, the discriminator distinguishes, as far as possible, the source domain data features regarded as negative samples from the generated features regarded as positive samples; in steps S13 and S14, the discriminator parameters are fixed, the feature extractor is trained so that the discriminator misjudges source domain features as positive samples, and the distribution generator is trained so that the discriminator correctly judges generated features as positive samples; the adversarial training of steps S12 to S14 is repeated over multiple rounds so that the discriminator judges both types of features as positive samples, at which point the source domain feature distribution and the generated reference feature distribution are aligned by category, the generated reference feature distribution having been brought close, by category, to the source domain feature distribution through adversarial training.
3. The federated learning-based domain generalization method of claim 2, wherein in step S12, during the adversarial learning of the feature extractor, the distribution generator and the discriminator, the feature h extracted by the feature extractor is regarded as a negative sample, and the feature h' output by the distribution generator is regarded as a positive sample; the loss function of the discriminator on the two types of input features,
Figure FDA0003440091620000013
is defined as:
Figure FDA0003440091620000021
where p(h) denotes the distribution of feature h, p(h') denotes the distribution of feature h', D denotes the discriminator model, y denotes the true label of the data corresponding to feature h, and
Figure FDA0003440091620000022
denotes the expectation.
4. The federated learning-based domain generalization method of claim 3, wherein step S13 specifically comprises: the loss function for updating the feature extractor and classifier is
Figure FDA0003440091620000023
where
Figure FDA0003440091620000024
is the loss function of the classification training process, and
Figure FDA0003440091620000025
denotes the loss function of the feature extractor during adversarial learning, defined as:
Figure FDA0003440091620000026
during adversarial training of the feature extractor, the negative sample h output by the feature extractor is used to deceive the discriminator, so that the discriminator classifies h as a positive sample.
5. The federated learning-based domain generalization method of claim 4, wherein in step S14, the loss function of the distribution generator,
Figure FDA0003440091620000027
is used to update the parameters of the distribution generator, and
Figure FDA0003440091620000028
is defined as:
Figure FDA0003440091620000029
6. The federated learning-based domain generalization method of claim 1, wherein step S2 specifically comprises the following steps:
S21: the server receives the parameters of the models uploaded by the clients, the models comprising the feature extractor, the distribution generator and the classifier, and stores them temporarily; after the model parameters uploaded by all clients have been received, the uploaded parameters are averaged separately for each model, the parameter averaging formula being:
Figure FDA00034400916200000210
where w_t denotes the model parameters in round t, and K denotes the number of clients;
S22: after the server has averaged all received model parameters, it distributes the resulting new model parameters to all clients and waits for the next aggregation operation.
7. The federated learning-based domain generalization method of claim 1, wherein step S3 specifically comprises the following steps:
S31: at the client, the output distribution of the local distribution generator approaches the feature distribution of the local source domain data, and after global parameter averaging the reference feature distribution output by the distribution generator approaches the center of the feature distributions of all source domain data; through multiple rounds of parameter interaction between the client and the server, until the classifier converges to a high accuracy on the local data set, the model parameters are uploaded to the server;
S32: the server averages the parameters of the feature extractor and the discriminator.
CN202111626157.1A 2021-12-28 2021-12-28 Domain generalization method based on federal learning Pending CN114399055A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111626157.1A CN114399055A (en) 2021-12-28 2021-12-28 Domain generalization method based on federal learning

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111626157.1A CN114399055A (en) 2021-12-28 2021-12-28 Domain generalization method based on federal learning

Publications (1)

Publication Number Publication Date
CN114399055A true CN114399055A (en) 2022-04-26

Family

ID=81229345

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111626157.1A Pending CN114399055A (en) 2021-12-28 2021-12-28 Domain generalization method based on federal learning

Country Status (1)

Country Link
CN (1) CN114399055A (en)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114818996A (en) * 2022-06-28 2022-07-29 山东大学 Method and system for diagnosing mechanical fault based on federal domain generalization
CN115952442A (en) * 2023-03-09 2023-04-11 山东大学 Global robust weighting-based federal domain generalized fault diagnosis method and system

Similar Documents

Publication Publication Date Title
Kim et al. Domain adaptation without source data
Liang et al. Do we really need to access the source data? source hypothesis transfer for unsupervised domain adaptation
Mancini et al. Adagraph: Unifying predictive and continuous domain adaptation through graphs
You et al. Position-aware graph neural networks
Segu et al. Batch normalization embeddings for deep domain generalization
Choudhuri et al. Distribution alignment using complement entropy objective and adaptive consensus-based label refinement for partial domain adaptation
CN112446423B (en) Fast hybrid high-order attention domain confrontation network method based on transfer learning
CN114399055A (en) Domain generalization method based on federal learning
CN114818996B (en) Method and system for diagnosing mechanical fault based on federal domain generalization
Mazzetto et al. Adversarial multi class learning under weak supervision with performance guarantees
AU2016218947A1 (en) Learning from distributed data
CN113037783B (en) Abnormal behavior detection method and system
Zhang et al. An open set domain adaptation algorithm via exploring transferability and discriminability for remote sensing image scene classification
CN114006870A (en) Network flow identification method based on self-supervision convolution subspace clustering network
Mathur et al. FlexAdapt: Flexible cycle-consistent adversarial domain adaptation
Yan et al. Domain adversarial disentanglement network with cross-domain synthesis for generalized face anti-spoofing
Shen et al. On balancing bias and variance in unsupervised multi-source-free domain adaptation
Zhang et al. C 3-GAN: Complex-Condition-Controlled Urban Traffic Estimation through Generative Adversarial Networks
Kim et al. Semi-supervised domain adaptation via selective pseudo labeling and progressive self-training
Chandhok et al. Structured latent embeddings for recognizing unseen classes in unseen domains
Zheng et al. Gnnevaluator: Evaluating gnn performance on unseen graphs without labels
Kong et al. A neural pre-conditioning active learning algorithm to reduce label complexity
Zhang et al. STrans-GAN: Spatially-Transferable Generative Adversarial Networks for Urban Traffic Estimation
CN116306969A (en) Federal learning method and system based on self-supervision learning
Du et al. A Few-Shot Class-Incremental Learning Method for Network Intrusion Detection

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination