CN114511737A - Training method of image recognition domain generalization model - Google Patents


Info

Publication number
CN114511737A
CN114511737A
Authority
CN
China
Prior art keywords
domain
network
classification
loss
calculating
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210081010.7A
Other languages
Chinese (zh)
Other versions
CN114511737B (en)
Inventor
谭志
滕昭飞
李贲
袁冬海
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Hezhong Huineng Technology Co ltd
Beijing University of Civil Engineering and Architecture
Original Assignee
Beijing Hezhong Huineng Technology Co ltd
Beijing University of Civil Engineering and Architecture
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Hezhong Huineng Technology Co ltd, Beijing University of Civil Engineering and Architecture filed Critical Beijing Hezhong Huineng Technology Co ltd
Priority to CN202210081010.7A priority Critical patent/CN114511737B/en
Publication of CN114511737A publication Critical patent/CN114511737A/en
Application granted granted Critical
Publication of CN114511737B publication Critical patent/CN114511737B/en
Active legal-status Critical Current
Anticipated expiration legal-status Critical


Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

The invention provides a training method for an image recognition domain generalization model. The method comprises the following steps: constructing an image recognition domain generalization model that comprises a classification network and a generation network; training the classification network and the generation network adversarially to obtain a generated domain; and simulating the data distribution of the target domain with the generated domain, calculating the loss of the generated domain with a cycle consistency loss function, obtaining the trained image recognition domain generalization model once the loss of the generated domain meets the requirement, and performing cross-domain generalization recognition on images with the trained model. The invention determines the validity of the generated data through the cycle consistency loss. The cross entropy function of the two classifiers is corrected with label smoothing regularization, which avoids a large gap between correct and wrong classifications and prevents overfitting. The portability of image recognition models is significantly improved, and the domain drift phenomenon of image recognition models is alleviated.

Description

Training method of image recognition domain generalization model
Technical Field
The invention relates to the technical field of image recognition, in particular to a training method of an image recognition domain generalization model.
Background
Traditional machine learning assumes that the training set and test set come from the same data distribution, whereas transfer learning allows them to come from different distributions. In transfer learning, the training set and test set are called the source domain and the target domain, respectively; the core goal is to train on source domain data while reducing the difference in data distribution between the source and target domains, so that the target domain can be learned. Domain generalization is one kind of transfer learning; it aims to generalize a model trained on source domains to differently distributed target domains.
With the rapid development of artificial intelligence, more and more everyday applications require domain generalization methods. For example, in the medical field, some medical data are difficult and scarce to acquire because of the danger and expense of certain procedures. Another example is water meter reading: meters are installed in different positions, and over time their appearance inevitably suffers damage, which hinders reading of the digits; collecting large amounts of meter digit data manually is unrealistic. These real-world problems have driven research into cross-domain generalization. In current applications, when a model trained on single-source or multi-source domain data is migrated to data with an unknown distribution, its recognition accuracy drops markedly, a phenomenon generally called "domain drift"; training a more robust model is therefore the most important task in domain generalization research. Data augmentation has proven to be an important and direct means of improving model generalization; many augmentation methods exist, and images can be augmented by physical operations such as rotation and cropping, or by changing image attributes.
To address the performance degradation after model migration in current domain generalization, many researchers have combined existing methods to tackle this challenging task, using techniques such as adversarial training, meta-learning, data distribution adaptation or domain expansion, with good results. Early domain generalization studies mainly followed the idea of distribution alignment, learning domain-invariant features through either kernel methods or domain-adversarial learning.
A prior-art cross-domain generalization method based on meta-learning works as follows: the training strategy of meta-learning is simulated, and at each iteration a second-order gradient is calculated on a random meta-test domain held out from the source domains and back-propagated. Subsequent meta-learning-based domain generalization approaches use similar strategies to meta-learn regularizers, feature-critic networks, or how to maintain semantic relationships. Another approach to the domain generalization problem is domain augmentation, which creates samples from virtual domains through a gradient-based image generator, adversarial perturbation, or adversarial training. Recently, inspired by the out-of-distribution robustness of shape-biased models, researchers have filtered texture features according to local self-information to bias their models toward shape features, and later work extends this by introducing a momentum metric learning scheme.
The above-mentioned prior-art meta-learning-based cross-domain generalization methods have the following disadvantages:
1. The types of data augmentation are relatively simple, and the stability and effectiveness of the augmented data cannot be guaranteed, which harms the robustness of model generalization.
2. When training the generated domain, only the loss of the correct label is considered and the loss of wrong labels is ignored, which increases the probability of prediction errors and reduces the classification performance of the classifier.
3. The model uses a single classifier and cannot associate semantic information among the features.
Disclosure of Invention
The embodiment of the invention provides a training method of an image recognition domain generalization model, which is used for improving the portability of the image recognition model.
In order to achieve the purpose, the invention adopts the following technical scheme.
A training method of an image recognition domain generalization model comprises the following steps:
constructing an image identification domain generalization model, wherein the image identification domain generalization model comprises a classification network and a generation network;
training the classification network and the generation network by using a countermeasure training mode to obtain a generation domain;
and simulating the data distribution of the target domain with the generated domain, calculating the loss of the generated domain with a cycle consistency loss function, obtaining the trained image recognition domain generalization model after the loss of the generated domain meets the requirement, and performing cross-domain generalization recognition on images with the trained model.
Preferably, the constructing an image recognition domain generalization model, wherein the image recognition domain generalization model comprises a classification network and a generation network, and comprises:
The constructed image recognition domain generalization model comprises a classification network and a generation network. The classification network comprises a feature extraction network, a full connection layer classifier and a graph neural network classifier; dual classifiers are used in the classification network, one being the full connection layer classifier and the other the graph neural network classifier. The generation network comprises an encoder, an AdaIN normalization layer and a decoder.
Preferably, the training the classification network and the generation network by using a countermeasure training method to obtain a generation domain includes:
step S1: preparing an image x and label y of source domain data of a training model, and an image x of target domain data for testingt
Step S2: in the pre-training stage, a plurality of images in source domain data are selected and input into a feature extraction network and a full-connection layer classifier to be circularly trained for multiple times to obtain a pre-trained recognition model, and the pre-training process is ended;
step S3: the length of the initialized and generated domain list is 0, and the source domain image x and the target domain image x are processedtCarrying out pretreatment;
step S4: inputting source domain data into the generatingA network for training the classification network and the generation network by means of countermeasure training to synthesize a generation domain SGData x ofG
Assuming S represents the source domain data set, x an image of the source domain data, y the label of the source domain data, and G the generation network, the data x_G of the generated domain S_G is represented by formula (1):
x_G = G(x), x ∈ S. (1)
preferably, the simulating the data distribution of the target domain by using the generated domain, calculating the loss of the generated domain by using a cyclic consistency loss function, and obtaining a trained image recognition domain generalization model after the loss of the generated domain meets requirements includes:
step S5: if the generated domain list length is equal to 0, executing the steps S6 and S7; if the value is greater than 0, executing the step S8 and the step S9;
step S6: respectively inputting the source domain data obtained in the step S3 into a feature extraction network and a full connection layer classifier to obtain the features of the source domain data, and calculating the prediction classification result of the full connection layer classifier by using a formula (5);
step S7: inputting the features obtained in the step S6 into a graph neural network classifier, taking the features as initial vertexes of the graph neural network classifier, firstly inputting the initial vertexes into an edge neural network, calculating affinity matrixes between the vertexes through a formula (8), and calculating loss between the affinity matrixes through a formula (11); updating the characteristics of the vertexes through semantic similarity, calculating to obtain a prediction classification result of the graph neural network through a formula (5), and executing the step S10;
step S8: randomly extracting generated domain data from the generated domain list, distributing a random weight with the sum of 1 to the domain and the source domain, and mixing the two domains to obtain a mixed domain;
step S9: respectively inputting source domain data, randomly extracted generated domain data and mixed domain data into a feature extraction network, a full-link layer classifier and a graph neural network classifier to obtain features of each domain data and respective prediction classification results calculated by the full-link layer classifier, calculating the loss of the full-link classifier through a formula (5), calculating the binary classification loss of an affinity matrix in the graph neural network through a formula (11) and calculating the classification loss of the graph neural network classifier by using the formula (5);
step S10: inputting the generated domain data generated in the step S4 into a feature extraction network, a full connection layer classifier and a graph neural network classifier, and calculating the classification loss of the generated domain and the cycle consistency loss obtained through the formula (2);
the calculating the classification loss of the generated domain comprises: calculating a loss of the fully-connected classifier by formula (5), calculating a two-classification loss of an affinity matrix in the graph neural network by formula (11), and calculating a classification loss of the graph neural network classifier using formula (5);
step S11: 30 rounds of circulation from step S4 to step S10;
step S12: using a gradient descent method, reversely propagating, updating the learned classification network and generating the loss of the network, and storing;
step S13: adding the data of the last round of domain generation in the step S11 into the domain list, and adding 1 to the length of the domain list;
step S14: inputting target domain data into a full-connection layer classifier of a classification network with the stored weight, calculating the identification accuracy of the target domain, and evaluating the generalization effect of image identification;
step S15: and (4) circulating the steps S4 to S14 for 20 rounds, namely generating 20 generated domains to obtain the final image recognition domain generalization model.
Preferably, the calculating the predicted classification result of the fully-connected layer classifier by using the formula (5) includes:
the cross entropy loss function after introducing label smoothing regularization correction is shown as formula (5):
L_lsr = (1 - ε) · ce(i) + (ε/(N - 1)) · Σ_{j≠i} ce(j). (5)
wherein ce denotes the standard cross entropy loss function shown in formula (3), N denotes the number of classes, and i denotes the corresponding correct class;
the cross entropy loss function of the standard is shown in equation (3):
L_ce = -Σ_{i=1}^{N} q_i · log(C(F(x))_i). (3)
wherein F denotes the feature extraction network, C the full connection layer classifier, and C(F(·)) the classification result predicted by the full connection layer classifier; q is an index variable that equals 1 when the predicted class is consistent with the true class and 0 otherwise; ε is assumed to be a smoothing factor, and the smoothed index variable q is calculated as in formula (4):
q_i = 1 - ε, if i = y'; q_i = ε/(N - 1), otherwise. (4)
preferably, the calculating the semantic similarity between the vertexes by the formula (8) includes:
vertex self-connections are added to the non-normalized affinity matrix; the initial vertices are input into the edge neural network, and the affinity symmetric matrix A_l between vertices is calculated by formula (8):
A_l = D^{-1/2} · (A' + I) · D^{-1/2}. (8)
wherein D is the degree matrix, I is the identity matrix, and each element a_ij of the affinity symmetric matrix A_l denotes the semantic similarity between vertices v_i and v_j.
Preferably, said calculating the loss between affinity matrices by equation (11) comprises:
The matrix loss of the edge neural network, calculated with the binary cross entropy loss function, is shown in formula (11):
L_bce = -Σ_{i,j} [a^t_ij · log(a^l_ij) + (1 - a^t_ij) · log(1 - a^l_ij)]. (11)
wherein a^t_ij is an element of the target affinity matrix and a^l_ij is the corresponding element of the affinity symmetric matrix A_l calculated by the edge neural network in the graph neural network.
Preferably, the cycle consistency loss obtained by the formula (2) comprises:
the cycle consistency loss function is shown in equation (2):
L_cyc = min_{G, G_cyc} ||x - G_cyc(G(x))||_2. (2)
wherein x denotes the source domain image, G the generation network, and G_cyc a network with the same structure and weights as G; G_cyc(G(x)) is the result obtained by inputting the domain generated by G into G_cyc.
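Formula (2) can be illustrated with a minimal sketch; G and G_cyc here are arbitrary callables standing in for the generation network and its weight-shared copy, not the patent's actual generator.

```python
import numpy as np

def cycle_consistency_loss(x, G, G_cyc):
    """L_cyc = ||x - G_cyc(G(x))||_2, formula (2): the source image x
    should be recoverable after passing through the generation network G
    and its weight-shared copy G_cyc."""
    return float(np.linalg.norm(x - G_cyc(G(x))))
```

When G_cyc exactly inverts G the loss is zero, which is how the validity of the generated domain is enforced.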
According to the technical scheme provided by the embodiment of the invention, the validity of the generated data is determined through the cycle consistency loss. The cross entropy function of the two classifiers is corrected with label smoothing regularization, which avoids a large gap between correct and wrong classifications and prevents overfitting. The portability of the image recognition model is significantly improved, and the domain drift phenomenon of the image recognition model is alleviated.
Additional aspects and advantages of the invention will be set forth in part in the description which follows, and in part will be obvious from the description, or may be learned by practice of the invention.
Drawings
In order to more clearly illustrate the technical solutions of the embodiments of the present invention, the drawings needed to be used in the description of the embodiments are briefly introduced below, and it is obvious that the drawings in the following description are only some embodiments of the present invention, and it is obvious for those skilled in the art to obtain other drawings based on these drawings without creative efforts.
FIG. 1 is a block diagram of an image recognition domain generalization model according to an embodiment of the present invention;
FIG. 2 is a flowchart of a training process of an image recognition domain generalization model according to an embodiment of the present invention;
FIG. 3 is a schematic diagram showing comparison between the method of the embodiment of the present invention and recognition effects of several existing image recognition domain generalization methods, such as ERM, d-SNE, GUD, and UGMG.
Detailed Description
Reference will now be made in detail to embodiments of the present invention, examples of which are illustrated in the accompanying drawings, wherein like reference numerals refer to the same or similar elements or elements having the same or similar function throughout. The embodiments described below with reference to the accompanying drawings are illustrative only for the purpose of explaining the present invention, and are not to be construed as limiting the present invention.
As used herein, the singular forms "a", "an" and "the" are intended to include the plural forms as well, unless the context clearly indicates otherwise. It will be further understood that the terms "comprises" and/or "comprising," when used in this specification, specify the presence of stated features, integers, steps, operations, elements, and/or components, but do not preclude the presence or addition of one or more other features, integers, steps, operations, elements, components, and/or groups thereof. It will be understood that when an element is referred to as being "connected" or "coupled" to another element, it can be directly connected or coupled to the other element or intervening elements may also be present. Further, "connected" or "coupled" as used herein may include wirelessly connected or coupled. As used herein, the term "and/or" includes any and all combinations of one or more of the associated listed items.
It will be understood by those skilled in the art that, unless otherwise defined, all terms (including technical and scientific terms) used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this invention belongs. It will be further understood that terms, such as those defined in commonly used dictionaries, should be interpreted as having a meaning that is consistent with their meaning in the context of the prior art and will not be interpreted in an idealized or overly formal sense unless expressly so defined herein.
For the convenience of understanding the embodiments of the present invention, the following description will be further explained by taking several specific embodiments as examples in conjunction with the drawings, and the embodiments are not to be construed as limiting the embodiments of the present invention.
The embodiment of the invention takes domain generalization for image recognition as its research focus. To remedy the shortcomings of the current research field and alleviate the domain drift phenomenon, it designs a method that improves domain generalization for image recognition by incorporating a graph neural network, adopts a strategy of data augmentation through adversarial training, and uses generated domain data to simulate unknown data distributions.
The embodiment of the invention designs an adversarially trained domain generalization model in which a domain generation network and a domain classification network are learned jointly. The domain classification network uses dual classifiers, one a full connection layer classifier and the other a graph neural network classifier. An AdaIN normalization layer is added to the domain generation network, and a cycle consistency loss function is used, which improves the diversity and validity of the data generated by the domain generation model. Label smoothing regularization avoids a large gap between correct and wrong classifications, prevents overfitting and improves generalization performance. The graph neural network classifier can aggregate similar semantic information among image features, and the two classifiers help learn class-invariant representations and improve classifier performance.
The structure of the image recognition domain generalization model of the embodiment of the invention is shown in fig. 1, and comprises a classification network and a generation network, wherein the classification network comprises a feature extraction network, a full connection layer classifier and a graph neural network classifier. The generation network consists of three parts, namely an encoder, an AdaIN normalization layer and a decoder.
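The AdaIN normalization layer between encoder and decoder can be sketched as follows. This is the standard AdaIN formulation (align the per-channel mean and standard deviation of content features to those of style features); the patent does not spell out its exact variant, so treat this as an illustrative sketch.

```python
import numpy as np

def adain(content, style, eps=1e-5):
    """Adaptive Instance Normalization sketch: re-normalize the content
    feature map so its per-channel statistics match the style feature
    map. Both inputs have shape (C, H, W)."""
    c_mean = content.mean(axis=(1, 2), keepdims=True)
    c_std = content.std(axis=(1, 2), keepdims=True) + eps
    s_mean = style.mean(axis=(1, 2), keepdims=True)
    s_std = style.std(axis=(1, 2), keepdims=True) + eps
    # shift/scale content statistics onto the style statistics
    return s_std * (content - c_mean) / c_std + s_mean
```

In the generation network, the decoder would consume the AdaIN output to synthesize generated-domain images whose feature statistics differ from the source.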
Assuming S represents the source domain data set, x an image of the source domain data, y the label of the source domain data, and G the generation network, the data x_G of the generated domain S_G is represented by formula (1):
x_G = G(x), x ∈ S. (1)
(1) Cycle consistency
A generated domain is obtained by adversarially training the generation network and the classification network, and is used to simulate the data distribution of the target domain. A cycle consistency loss function is applied to the generated domain produced by the generation network to ensure the validity of the generated domain, where G_cyc and G have the same structure and the same weights. The cycle consistency loss function is shown in formula (2):
L_cyc = min_{G, G_cyc} ||x - G_cyc(G(x))||_2. (2)
wherein x denotes the source domain image, G the generation network, and G_cyc a network with the same structure and weights as G; G_cyc(G(x)) is the result obtained by inputting the domain generated by G into G_cyc.
(2) Label smoothing regularization
When training transfer learning samples, one-hot labels are usually used to calculate the cross entropy loss. However, this considers only the loss at the correct label position in the training sample, i.e., the position where the one-hot label is 1, and ignores the loss at the wrong label positions, i.e., the positions where the one-hot label is 0. The model can thus fit the source domain data well, but the losses at the other, wrong label positions are never computed. When classifying prediction data, this increases the probability of prediction errors, so it is necessary to prevent the model from overfitting with respect to the target domain images. To address this problem, the invention adopts a label smoothing regularization method to prevent overfitting in the classification task. Because no label information of the target domain appears during prediction, training on the source domain images and learning the model weights, then predicting the target domain images, constitutes one complete learning task. The standard cross entropy loss function is shown in formula (3):
L_ce = -Σ_{i=1}^{N} q_i · log(C(F(x))_i). (3)
wherein F denotes the feature extraction network, C the full connection layer classifier, and C(F(·)) the classification result predicted by the full connection layer classifier. q is an index variable that equals 1 when the predicted class is consistent with the true class and 0 otherwise. Let ε be a smoothing factor that controls the relative weights, similar to injecting noise into the true distribution; ε is taken as 0.1. The correct-class target of the full connection layer classifier thus changes from 1 to 0.9, and the index variable is calculated as in formula (4):
q_i = 1 - ε, if i = y'; q_i = ε/(N - 1), otherwise. (4)
wherein y' denotes the predicted classification result of the classifier and N denotes the number of classes.
The cross entropy loss function after introducing label smoothing regularization correction is shown as formula (5):
L_lsr = (1 - ε) · ce(i) + (ε/(N - 1)) · Σ_{j≠i} ce(j). (5)
wherein ce denotes the standard cross entropy loss function shown in formula (3), N the number of classes, and i the corresponding correct class.
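Formulas (3) to (5) can be sketched in NumPy. The ε/(N-1) weight spread over the wrong classes follows formula (4), with ε = 0.1 as in the text; the function name and the single-sample interface are illustrative assumptions.

```python
import numpy as np

def smoothed_cross_entropy(logits, target, eps=0.1):
    """Label-smoothing cross entropy sketch, formula (5): the correct
    class keeps weight 1 - eps and the remaining eps is spread over the
    other N - 1 classes, per formula (4)."""
    n = logits.shape[0]
    p = np.exp(logits - logits.max())
    p /= p.sum()                       # softmax probabilities
    q = np.full(n, eps / (n - 1))      # smoothed target distribution
    q[target] = 1.0 - eps
    return -(q * np.log(p)).sum()
```

With eps = 0 the smoothed targets collapse back to one-hot labels and the function reduces to the standard cross entropy of formula (3).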
(3) Graph neural network
The graph neural network of the invention comprises two parts: an edge neural network f_e, which calculates the affinity matrix with a parameterized non-linear similarity function to capture semantic similarity, and a point neural network f_n, which is used for classification of the images. An undirected fully connected graph, Graph = (V, E, A), is constructed, where V denotes the vertices of the undirected graph, i.e., the images x_i of the input data; E denotes the edges of the undirected graph, i.e., connections between two vertices v_i and v_j; and A denotes the affinity matrix of the undirected graph, i.e., the semantic similarity between two vertices v_i and v_j. The similarity between vertices is therefore expressed by formula (6):
a'_ij = f_e^l(v_i^{l-1}, v_j^{l-1}). (6)
wherein v_i^{l-1} and v_j^{l-1} respectively denote the features of vertices v_i and v_j at layer l-1 of the graph neural network, taken from the feature extraction network of the classification network, and f_e^l denotes the parameterized non-linear similarity function of layer l of the edge neural network. The undirected graph therefore has a non-normalized affinity matrix A' as shown in formula (7):
A' = (a'_ij)_{N×N}. (7)
adding node self-connection for non-standardized affinity matrix and carrying out normalization operation to obtain affinity symmetric matrix A of edge neural networklThe calculation expression of the affinity symmetric matrix is shown in formula (8):
A_l = D^{-1/2} · (A' + I) · D^{-1/2}. (8)
where D is the degree matrix and I is the identity matrix. And finally, the obtained affinity matrix is propagated to a point neural network for updating the vertex characteristics.
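Formula (8) can be sketched directly. Here the degree matrix is derived from the self-connected matrix A' + I; this is an assumption, since the patent does not state which matrix D is computed from.

```python
import numpy as np

def normalize_affinity(A_prime):
    """Formula (8) sketch: add vertex self-connections to the raw
    affinity matrix and symmetrically normalize,
    A_l = D^{-1/2} (A' + I) D^{-1/2}, with D the degree matrix
    (row sums) of A' + I."""
    A_tilde = A_prime + np.eye(A_prime.shape[0])
    d = A_tilde.sum(axis=1)                   # vertex degrees
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt
```

The normalized output stays symmetric for symmetric input, which is what lets it be propagated to the point neural network as attention-like mixing weights.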
The expression for updating the vertex feature is shown in equation (9):
v_i^l = f_n^l(v_i^{l-1}, Σ_j a_ij · v_j^{l-1}). (9)
wherein f_n^l denotes layer l of the point neural network; the prediction result of the graph neural network classifier is obtained at the last layer of the graph neural network.
A target affinity matrix A_t is constructed according to the labels y of the source domain, and the predicted A_l gradually approaches it. Each element a^t_ij of the target affinity matrix A_t is calculated as shown in formula (10):
a^t_ij = 1, if y_i = y_j; a^t_ij = 0, otherwise. (10)
calculating the matrix loss of the edge neural network by using the two-classification cross entropy loss function is shown as the formula (11):
L_bce = -Σ_{i,j} [a^t_ij · log(a^l_ij) + (1 - a^t_ij) · log(1 - a^l_ij)]. (11)
wherein a^t_ij is an element of the target affinity matrix and a^l_ij is the corresponding element of the affinity matrix calculated by the edge neural network in the graph neural network (i.e., an element of the matrix of formula (8)).
The graph neural network classifier uses two loss functions: formula (11) calculates the loss of the edge neural network affinity matrix, and formula (5) calculates the classification loss of the whole graph neural network classifier through the point neural network.
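Formulas (10) and (11) can be sketched as follows. A small clipping epsilon is added for numerical safety, the loss is averaged over matrix elements, and both function names are illustrative.

```python
import numpy as np

def target_affinity(labels):
    """Formula (10): a^t[i, j] = 1 iff vertices i and j share a label."""
    labels = np.asarray(labels)
    return (labels[:, None] == labels[None, :]).astype(float)

def affinity_bce_loss(A_target, A_pred, eps=1e-7):
    """Formula (11) sketch: element-wise binary cross entropy between
    the label-derived target affinity matrix and the affinity matrix
    predicted by the edge neural network, averaged over elements."""
    A_pred = np.clip(A_pred, eps, 1.0 - eps)   # avoid log(0)
    return -np.mean(A_target * np.log(A_pred)
                    + (1 - A_target) * np.log(1 - A_pred))
```

Driving A_l toward A_t with this loss is what pushes the edge network's affinities to encode "same class" rather than raw feature proximity.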
The training process flow of the image recognition domain generalization model provided by the embodiment of the invention is shown in fig. 2, and comprises the following processing steps:
step S1: preparing an image x and label y of source domain data of a training model, and an image x of target domain data for testingt. The source domain data set selected by the invention is MNIST, and the target domain data set is four sets of data sets of MNIST-M, SVHN, SYN and USPS.
Step S2: pre-training stage. The first 10,000 images of the source domain data are selected and input into the feature extraction network and the full-connection layer classifier, and trained cyclically for 30 rounds so that subsequent training reaches a fitted state more efficiently; this yields a pre-trained recognition model and ends the pre-training process.
Step S3: initialize the generated domain list with length 0, and preprocess the source domain image x and the target domain image x_t.
Step S4: enter the domain generalization stage. MNIST is selected as the source domain data and input into the generator network, and a new data set is synthesized into the generated domain S_G using formula (1).
Step S5: if the generated domain list length is equal to 0, executing the steps S6 and S7; if the value is greater than 0, the steps S8 and S9 are executed.
Step S6: input the source domain data obtained in step S3 into the feature extraction network and the full-connection layer classifier to obtain the features of the source domain data, and calculate the predicted classification result of the full-connection layer classifier using formula (5).
Step S7: input the features obtained in step S6 into the graph neural network classifier as its initial vertices. The initial vertices are first input into the edge neural network, the semantic similarity between vertices is calculated by formula (8), and the loss between affinity matrices is calculated by formula (11). The vertex features are then updated through the semantic similarity, and finally the predicted classification result of the graph neural network is obtained by formula (5). Step S10 is then executed.
Step S8: randomly extract generated domain data from the generated domain list, assign the extracted domain and the source domain random weights summing to 1, and mix the two domains to obtain a mixed domain.
Step S9: input the source domain data, the randomly extracted generated domain data and the mixed domain data into the feature extraction network, the full-connection layer classifier and the graph neural network classifier to obtain the features of each domain and the classification results predicted by the full-connection layer classifier. Calculate the loss of the full-connection layer classifier by formula (5), the two-classification loss of the affinity matrix in the graph neural network by formula (11), and the classification loss of the graph neural network classifier by formula (5).
Step S10: input the generated domain data produced in step S4 into the feature extraction network, the full-connection layer classifier and the graph neural network classifier, and calculate the classification loss of the generated domain and the cycle consistency loss obtained through formula (2).
Calculating the classification loss of the generated domain comprises: calculating the loss of the full-connection layer classifier by formula (5), calculating the two-classification loss of the affinity matrix in the graph neural network by formula (11), and calculating the classification loss of the graph neural network classifier using formula (5).
Step S11: steps S4 to S10 are looped for 30 rounds. In the first 15 rounds a generated domain is trained and both the classification network and the generation network are updated, so the classifier losses, the generation-network loss and the cycle consistency loss are all computed. In the last 15 rounds only the classification network is trained, and only the two classifier losses are computed.
Step S12: using gradient descent and back-propagation, update the classification network and the generation network according to the learned losses, and save the weights.
In the first 15 rounds, updating the classification network and the generation network uses the sum of (a) the losses of the source domain data computed by the two classifiers with formula (5) in step S7 plus the matrix-similarity loss of formula (11), and (b) for the generated domain network in step S10, the classification losses of the two classifiers computed with formula (5), the matrix-similarity loss of formula (11) and the cycle consistency loss of formula (2).
The last 15 rounds update only the losses of step S7. When the generated domain list length is greater than 0, i.e. in the latter 19 rounds of the outer loop of step S15, all operations of step S7 are replaced by those of step S9. That is, in the first 15 rounds the classification network and the generation network are updated with the sum of the losses of the source domain data, the randomly extracted generated domain data and the mixed domain data, each computed by the two classifiers with formula (5) in step S9, plus the matrix-similarity loss of formula (11), together with the classification losses of the two classifiers of the generated domain network computed with formula (5) in step S10, the matrix-similarity loss of formula (11) and the cycle consistency loss of formula (2); the last 15 rounds update only step S9.
Step S13: add the generated domain data of the last round of step S11 to the generated domain list, and increase the length of the generated domain list by 1.
In the first execution of the outer loop, steps S8 and S9 are skipped and step S10 is executed after step S7. In the subsequent 19 loops of step S15, after the judgment of step S5, steps S6 and S7 are not executed; steps S8 and S9 are executed instead.
The role of adding the last round's generated domain data to the generated domain list in step S13 is to supply the latter 19 rounds of step S15, i.e. the random extraction in step S8.
Step S14: input the target domain data into the full-connection layer classifier of the classification network with the saved weights, calculate the recognition accuracy on the target domain, and evaluate the generalization effect of image recognition.
When evaluating the recognition accuracy of the target domain, the classification result of the full-connection layer classifier is used as the classification result of the target domain, and the average recognition accuracy over the target domains is used as the evaluation index.
Step S15: loop steps S4 to S14 for 20 rounds, i.e. generate 20 generated domains, to obtain the final image recognition domain generalization model.
Step S16: the whole process is ended.
Subsequently, the trained image recognition domain generalization model can be used to perform cross-domain generalization recognition processing on images.
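To make steps S8 and S10 of the flow above concrete, here is a hedged sketch of the domain mixing and of the cycle consistency loss of formula (2); the helper names, the NumPy setting and the use of plain callables for G and G_cyc are assumptions for illustration:

```python
import numpy as np

def mix_domains(x_source, x_generated, rng):
    """Step S8 sketch: weight the source domain and a randomly drawn
    generated domain with random weights summing to 1, then mix them."""
    w = rng.uniform(0.0, 1.0)
    return w * x_source + (1.0 - w) * x_generated

def cycle_consistency_loss(x, g, g_cyc):
    """Formula (2) sketch: L_cyc = ||x - G_cyc(G(x))||_2, where G_cyc
    shares G's structure and weights and maps the generated domain back."""
    return np.linalg.norm(x - g_cyc(g(x)))
```

When G_cyc perfectly inverts G the loss vanishes, which is how the cycle consistency term certifies that the generated data still carries the source image's content.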
In summary, during model construction the embodiment of the invention can judge the effectiveness of the generated data through the cycle consistency loss. The cross entropy function is smoothed and regularized according to the labels for the two classifiers, which avoids a large gap between correct and wrong classifications and prevents overfitting. The graph neural network classifier can aggregate similar semantic information among image features, and the dual classifiers help learn class-invariant representations and improve classifier performance. The MNIST data set is selected as the source-domain training set, the four digital data sets SVHN, MNIST-M, SYN and USPS are selected as the target-domain test sets, and the average recognition accuracy on the target domains is the evaluation index. As shown in fig. 3, the experimental results show that the transferability of the image recognition model is obviously improved and the domain drift phenomenon of the image recognition model is alleviated.
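The label-smoothing regularization summarized above (formulas (3) to (5) in the claims) can be sketched as follows; the softmax classifier head and the (1 − ε)·onehot + ε/N smoothing split are assumptions for illustration, not necessarily the patent's exact formulation:

```python
import numpy as np

def smoothed_cross_entropy(logits, label, eps=0.1):
    """Label-smoothing sketch: cross entropy against the softened target
    q = (1 - eps) * onehot(label) + eps / N, in the spirit of (3)-(5)."""
    n = logits.shape[0]
    p = np.exp(logits - logits.max())     # numerically stable softmax
    p /= p.sum()
    q = np.full(n, eps / n)               # eps/N mass on every class
    q[label] += 1.0 - eps                 # remaining mass on the true class
    return -np.sum(q * np.log(p))
```

Because some probability mass is reserved for the wrong classes, a confidently correct prediction is no longer pushed toward an infinitely large margin, which is the overfitting-avoidance effect the text describes.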
Those of ordinary skill in the art will understand that: the figures are merely schematic representations of one embodiment, and the blocks or flow diagrams in the figures are not necessarily required to practice the present invention.
From the above description of the embodiments, it is clear to those skilled in the art that the present invention can be implemented by software plus necessary general hardware platform. Based on such understanding, the technical solutions of the present invention may be embodied in the form of a software product, which may be stored in a storage medium, such as ROM/RAM, magnetic disk, optical disk, etc., and includes instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute the method according to the embodiments or some parts of the embodiments.
The embodiments in this specification are described in a progressive manner; for identical or similar parts the embodiments may refer to one another, and each embodiment focuses on its differences from the others. In particular, since the apparatus and system embodiments are substantially similar to the method embodiments, their description is relatively simple, and reference may be made to the relevant parts of the description of the method embodiments. The above-described apparatus and system embodiments are merely illustrative: units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units; they may be located in one place or distributed over a plurality of network units. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment. Those of ordinary skill in the art can understand and implement it without inventive effort.
The above description is only for the preferred embodiment of the present invention, but the scope of the present invention is not limited thereto, and any changes or substitutions that can be easily conceived by those skilled in the art within the technical scope of the present invention are included in the scope of the present invention. Therefore, the protection scope of the present invention shall be subject to the protection scope of the claims.

Claims (8)

1. A training method of an image recognition domain generalization model is characterized by comprising the following steps:
constructing an image identification domain generalization model, wherein the image identification domain generalization model comprises a classification network and a generation network;
training the classification network and the generation network by using a countermeasure training mode to obtain a generation domain;
and simulating the data distribution of the target domain by using the generated domain, calculating the loss of the generated domain by using a cycle consistency loss function, obtaining a trained image recognition domain generalization model after the loss of the generated domain meets the requirement, and performing cross-domain generalization recognition processing on the image by using the trained image recognition domain generalization model.
2. The method of claim 1, wherein constructing an image recognition domain generalization model, the image recognition domain generalization model comprising a classification network and a generation network, comprises:
the image recognition domain generalization model is constructed and comprises a classification network and a generation network, wherein the classification network comprises a feature extraction network, a full connection layer classifier and a graph neural network classifier, the generation network comprises an encoder, an AdaIN normalization layer and a decoder, and double classifiers are used in the classification network, one is the full connection layer classifier, and the other is the graph neural network classifier.
3. The method of claim 2, wherein training the classification network and the generation network using a countermeasure training method to obtain a generation domain comprises:
step S1: preparing an image x and label y of source domain data for training the model, and an image x_t of target domain data for testing;
Step S2: in the pre-training stage, a plurality of images in source domain data are selected and input into a feature extraction network and a full-connection layer classifier to be circularly trained for multiple times to obtain a pre-trained recognition model, and the pre-training process is ended;
step S3: initializing the generated domain list with length 0, and preprocessing the source domain image x and the target domain image x_t;
step S4: inputting source domain data into the generation network, training the classification network and the generation network in a countermeasure training mode, and synthesizing the data x_G of a generated domain S_G;
assuming S represents the source domain data set, x represents an image of the source domain data, y represents a label of the source domain data, and G represents the generation network, the data x_G of the generated domain S_G is expressed by formula (1):

x_G = G(x)  (1)
4. the method according to claim 3, wherein the simulating the data distribution of the target domain by using the generated domain, calculating the loss of the generated domain by using a cyclic consistency loss function, and obtaining the trained image recognition domain generalization model after the loss of the generated domain meets the requirement, comprises:
step S5: if the generated domain list length is equal to 0, executing the steps S6 and S7; if the value is greater than 0, executing the step S8 and the step S9;
step S6: respectively inputting the source domain data obtained in the step S3 into a feature extraction network and a full connection layer classifier to obtain the features of the source domain data, and calculating the prediction classification result of the full connection layer classifier by using a formula (5);
step S7: inputting the features obtained in the step S6 into a graph neural network classifier, taking the features as initial vertexes of the graph neural network classifier, firstly inputting the initial vertexes into an edge neural network, calculating affinity matrixes between the vertexes through a formula (8), and calculating loss between the affinity matrixes through a formula (11); updating the characteristics of the vertexes through semantic similarity, calculating to obtain a prediction classification result of the graph neural network through a formula (5), and executing the step S10;
step S8: randomly extracting generated domain data from the generated domain list, distributing a random weight with the sum of 1 to the domain and the source domain, and mixing the two domains to obtain a mixed domain;
step S9: respectively inputting source domain data, randomly extracted generated domain data and mixed domain data into a feature extraction network, a full-link layer classifier and a graph neural network classifier to obtain features of each domain data and respective prediction classification results calculated by the full-link layer classifier, calculating the loss of the full-link classifier through a formula (5), calculating the binary classification loss of an affinity matrix in the graph neural network through a formula (11) and calculating the classification loss of the graph neural network classifier by using the formula (5);
step S10: inputting the generated domain data generated in the step S4 into a feature extraction network, a full connection layer classifier and a graph neural network classifier, and calculating the classification loss of the generated domain and the cycle consistency loss obtained through the formula (2);
the calculating the classification loss of the generated domain comprises: calculating a loss of the fully-connected classifier by formula (5), calculating a two-classification loss of an affinity matrix in the graph neural network by formula (11), and calculating a classification loss of the graph neural network classifier using formula (5);
step S11: 30 rounds of circulation from step S4 to step S10;
step S12: using a gradient descent method, reversely propagating, updating the learned classification network and generating the loss of the network, and storing;
step S13: adding the data of the last round of domain generation in the step S11 into the domain list, and adding 1 to the length of the domain list;
step S14: inputting target domain data into a full-connection layer classifier of a classification network with the stored weight, calculating the identification accuracy of the target domain, and evaluating the generalization effect of image identification;
step S15: and (4) circulating the steps S4 to S14 for 20 rounds, namely generating 20 generated domains to obtain the final image recognition domain generalization model.
5. The method of claim 4, wherein the calculating the predicted classification result of the fully-connected layer classifier using equation (5) comprises:
the cross entropy loss function after introducing label smoothing regularization correction is shown in formula (5):

L_lsr = (1 - ε)·ce(i) + (ε/N)·Σ_{j=1}^{N} ce(j)  (5)

wherein ce represents the standard cross entropy loss function shown in formula (3), N represents the number of all classes, and i represents the corresponding correct class;
the standard cross entropy loss function is shown in formula (3):

ce = -Σ_{j=1}^{N} q_j·log C(F(x))_j  (3)

wherein F represents the feature extraction network, C is the full connection layer classifier, C(F(x)) is the classification result predicted by the full connection layer classifier, and q is an index variable: when the predicted class is consistent with the real class, q is 1, otherwise q is 0; assuming ε is the smoothing factor, the smoothed index variable q is calculated as shown in formula (4):

q_j = 1 - ε + ε/N when j = i, and q_j = ε/N otherwise.  (4)
6. the method of claim 5, wherein the calculating semantic similarity between vertices by formula (8) comprises:
adding node self-connections to the unnormalized affinity matrix, inputting the initial vertices into the edge neural network, and calculating the affinity matrix between vertices by formula (8); the affinity symmetric matrix A^l is expressed by formula (8):

A^l = D^{-1/2}(A + I)D^{-1/2}  (8)

wherein D is the degree matrix, I is the identity matrix, and each element a_{ij} of the affinity symmetric matrix A^l represents the semantic similarity between vertices.
7. The method of claim 5, wherein calculating the loss between affinity matrices by equation (11) comprises:
calculating the matrix loss of the edge neural network by using the two-classification cross entropy loss function, as shown in formula (11):

L_e = -Σ_{i,j} [ a^t_{ij}·log a^l_{ij} + (1 - a^t_{ij})·log(1 - a^l_{ij}) ]  (11)

wherein a^t_{ij} is an element of the target affinity matrix, and a^l_{ij} is the corresponding element of the affinity symmetric matrix A^l calculated by the edge neural network in the graph neural network.
8. The method of claim 5, wherein the cyclic consistency loss obtained by equation (2) comprises:
the cycle consistency loss function is shown in formula (2):

L_cyc = min_{G,G_cyc} ||x - G_cyc(G(x))||_2  (2)

wherein x represents the source domain image, G represents the generation network, G_cyc is a network with the same structure and weights as G, and G_cyc(G(x)) is the result of feeding the domain generated by G into G_cyc.
CN202210081010.7A 2022-01-24 2022-01-24 Training method of image recognition domain generalization model Active CN114511737B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210081010.7A CN114511737B (en) 2022-01-24 2022-01-24 Training method of image recognition domain generalization model


Publications (2)

Publication Number Publication Date
CN114511737A true CN114511737A (en) 2022-05-17
CN114511737B CN114511737B (en) 2022-09-09

Family

ID=81549736

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210081010.7A Active CN114511737B (en) 2022-01-24 2022-01-24 Training method of image recognition domain generalization model

Country Status (1)

Country Link
CN (1) CN114511737B (en)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115205599A (en) * 2022-07-25 2022-10-18 浙江大学 Multi-age-range child chest radiography pneumonia classification system based on domain generalization model
CN115272681A (en) * 2022-09-22 2022-11-01 中国海洋大学 Ocean remote sensing image semantic segmentation method and system based on high-order feature class decoupling
CN115880538A (en) * 2023-02-17 2023-03-31 阿里巴巴达摩院(杭州)科技有限公司 Method and equipment for domain generalization of image processing model and image processing

Citations (15)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109753992A (en) * 2018-12-10 2019-05-14 南京师范大学 The unsupervised domain for generating confrontation network based on condition adapts to image classification method
US20200160177A1 (en) * 2018-11-16 2020-05-21 Royal Bank Of Canada System and method for a convolutional neural network for multi-label classification with partial annotations
CN111476294A (en) * 2020-04-07 2020-07-31 南昌航空大学 Zero sample image identification method and system based on generation countermeasure network
CN111738315A (en) * 2020-06-10 2020-10-02 西安电子科技大学 Image classification method based on countermeasure fusion multi-source transfer learning
CN111860588A (en) * 2020-06-12 2020-10-30 华为技术有限公司 Training method for graph neural network and related equipment
CN112131967A (en) * 2020-09-01 2020-12-25 河海大学 Remote sensing scene classification method based on multi-classifier anti-transfer learning
CN113033716A (en) * 2021-05-26 2021-06-25 南京航空航天大学 Image mark estimation method based on confrontation fusion crowdsourcing label
CN113052810A (en) * 2021-03-17 2021-06-29 浙江工业大学 Small medical image focus segmentation method suitable for mobile application
CN113221848A (en) * 2021-06-09 2021-08-06 中国人民解放军国防科技大学 Hyperspectral open set field self-adaptive method based on multi-classifier domain confrontation network
CN113344044A (en) * 2021-05-21 2021-09-03 北京工业大学 Cross-species medical image classification method based on domain self-adaptation
CN113378904A (en) * 2021-06-01 2021-09-10 电子科技大学 Image classification method based on anti-domain adaptive network
US20210349954A1 (en) * 2020-04-14 2021-11-11 Naver Corporation System and method for performing cross-modal information retrieval using a neural network using learned rank images
CN113722439A (en) * 2021-08-31 2021-11-30 福州大学 Cross-domain emotion classification method and system based on antagonism type alignment network
CN113724880A (en) * 2021-11-03 2021-11-30 深圳先进技术研究院 Abnormal brain connection prediction system, method and device and readable storage medium
CN113936143A (en) * 2021-09-10 2022-01-14 北京建筑大学 Image identification generalization method based on attention mechanism and generation countermeasure network


Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
LIU Changtong et al.: "Portrait colorization based on joint-consistency cycle generative adversarial networks", Computer Engineering and Applications *
WANG Kunfeng et al.: "Generative adversarial networks GAN: research progress and prospects", Acta Automatica Sinica *



Similar Documents

Publication Publication Date Title
CN114511737B (en) Training method of image recognition domain generalization model
Gu et al. Projection convolutional neural networks for 1-bit cnns via discrete back propagation
Zhang et al. Efficient evolutionary search of attention convolutional networks via sampled training and node inheritance
Liu et al. Incdet: In defense of elastic weight consolidation for incremental object detection
Fischer et al. Training restricted Boltzmann machines: An introduction
Makhzani et al. Adversarial autoencoders
Yang et al. Online learning for group lasso
CN110097178A (en) It is a kind of paid attention to based on entropy neural network model compression and accelerated method
CN109743196B (en) Network characterization method based on cross-double-layer network random walk
CN109117943B (en) Method for enhancing network representation learning by utilizing multi-attribute information
CN111667016B (en) Incremental information classification method based on prototype
Chen et al. DMGAN: Discriminative metric-based generative adversarial networks
Zhou et al. Improved cross-label suppression dictionary learning for face recognition
CN112017255A (en) Method for generating food image according to recipe
Li et al. Adaptive momentum variance for attention-guided sparse adversarial attacks
Hu et al. Image super-resolution with self-similarity prior guided network and sample-discriminating learning
Xu et al. Graphical modeling for multi-source domain adaptation
Shi et al. AutoInfo GAN: Toward a better image synthesis GAN framework for high-fidelity few-shot datasets via NAS and contrastive learning
Sood et al. Neunets: An automated synthesis engine for neural network design
CN116258504A (en) Bank customer relationship management system and method thereof
CN115578593A (en) Domain adaptation method using residual attention module
CN113344189B (en) Neural network training method and device, computer equipment and storage medium
Kekeç et al. PAWE: Polysemy aware word embeddings
Jing et al. NASABN: A neural architecture search framework for attention-based networks
Joo et al. Towards more robust interpretation via local gradient alignment

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant