CN114937166A - Image classification model construction method, image classification method and device and electronic equipment - Google Patents


Info

Publication number
CN114937166A
Authority
CN
China
Prior art keywords
student
model
group leader
loss
distillation
Legal status
Pending
Application number
CN202210424033.3A
Other languages
Chinese (zh)
Inventor
林兰芬
牛子未
袁俊坤
马旭
Current Assignee
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Application filed by Zhejiang University ZJU
Priority to CN202210424033.3A
Publication of CN114937166A
Current legal status: Pending

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06Q INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00 Administration; Management
    • G06Q10/04 Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"

Abstract

The invention discloses an image classification model construction method, an image classification method and apparatus, and an electronic device. The image classification model construction method comprises: training a corresponding auxiliary student model on each of a plurality of source domains; training a student group leader model on a mixture of the plurality of source domains; mutually distilling the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss; updating the plurality of auxiliary student models according to the first distillation loss; distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss; and updating the student group leader model according to the second distillation loss, wherein the student group leader model is used to classify images to be classified.

Description

Image classification model construction method, image classification method and device and electronic equipment
Technical Field
The present application relates to the field of computer vision, and in particular, to an image classification model construction method, an image classification method and apparatus, and an electronic device.
Background
Image classification, i.e., using a computer algorithm to assign one label from a given set of classification labels to a given input image, is one of the fundamental tasks in computer vision and is widely applied in fields such as smart medicine, smart agriculture, and autonomous driving.
The rapid development of deep neural networks has greatly improved the performance of image classification models. Most existing neural network systems rely on the independent and identically distributed (i.i.d.) assumption, i.e., that the training data (source domain data) and the test data (target domain data) share the same statistical distribution. In practical application scenarios, however, various factors often violate this assumption, so a well-trained model can degrade substantially on test data with a different distribution. Domain generalization image classification methods can effectively address this problem.
Existing domain generalization image classification methods fall mainly into two categories: data manipulation-based approaches and alignment-based approaches. The former enhance the source domain data with various image transformations, or generate new source domain data to simulate the target domain data. They have two drawbacks. On the one hand, data augmentation merely expands, transforms, or randomizes the input images, so it struggles to simulate the target domain data accurately, which limits the generalization capability of the trained model. On the other hand, data generation models often require task-specific parameters and therefore do not generalize themselves, and the generated pseudo source domain samples need specific constraints to ensure semantic consistency. The latter aim to reduce the representation differences between multiple source domains in a particular space and to learn domain-invariant latent representations for the target domain. However, such methods may confuse domain-specific information with domain-invariant information, causing negative transfer and lowering the model's classification accuracy on the target domain.
Disclosure of Invention
The embodiment of the application aims to provide an image classification model construction method, an image classification method and device and electronic equipment.
According to a first aspect of embodiments of the present application, there is provided an image classification model construction method, including:
training a corresponding auxiliary student model on each of a plurality of source domains;
training a student group leader model on a mixture of the plurality of source domains;
mutually distilling the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss;
updating the plurality of auxiliary student models according to the first distillation loss;
distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
and updating the student group leader model according to the second distillation loss, wherein the student group leader model is used to classify images to be classified.
Further, after the corresponding auxiliary student models are trained using the plurality of source domains and before the prediction outputs of the plurality of auxiliary student models are mutually distilled, the method further comprises:
comparing the prediction outputs of all the auxiliary student models with the real labels, retaining the correct prediction outputs, and discarding the incorrect ones.
Further, training the corresponding auxiliary student models using the plurality of source domains comprises:
acquiring image data of the same category from the plurality of source domains;
resizing and encoding the source domain images to obtain their encoding matrices;
inputting the plurality of encoding matrices into the plurality of pre-trained auxiliary student models, respectively, to obtain a plurality of prediction outputs;
calculating a classification loss according to the prediction outputs and the real labels;
and updating the parameters of the corresponding auxiliary student models according to the classification loss.
Further, training a student group leader model on a mixture of the plurality of source domains comprises:
acquiring image data randomly mixed by category from the plurality of source domains;
resizing and encoding the mixed image to obtain its encoding matrix;
inputting the encoding matrix into a pre-trained student group leader model to obtain a prediction output;
calculating a classification loss according to the prediction output and the real label;
and updating the parameters of the student group leader model according to the classification loss.
Further, mutually distilling the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss comprises:
selecting, in turn, the prediction output of one auxiliary student model as the temporary teacher distribution;
taking the mean of the prediction distributions of all the remaining auxiliary student models as the student distribution;
calculating, in turn, the corresponding distillation loss using an alignment loss function according to the teacher distribution and the student distribution;
taking the sum of all the distillation losses as the first distillation loss.
Further, distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss comprises:
taking the mean of the prediction outputs of all the auxiliary student models as the teacher distribution;
taking the prediction output of the student group leader model as the student distribution;
calculating the second distillation loss using an alignment loss function according to the teacher distribution and the student distribution.
According to a second aspect of the embodiments of the present invention, there is provided an image classification model construction apparatus, including:
a first training module, configured to train a corresponding auxiliary student model on each of a plurality of source domains;
a second training module, configured to train a student group leader model on a mixture of the plurality of source domains;
a first distillation module, configured to mutually distill the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss;
a first updating module, configured to update the plurality of auxiliary student models according to the first distillation loss;
a second distillation module, configured to distill the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
and a second updating module, configured to update the student group leader model according to the second distillation loss, the student group leader model being used to classify images to be classified.
According to a third aspect of embodiments of the present invention, there is provided an image classification method including:
acquiring a target domain image to be classified;
inputting the acquired image to be classified into the updated student group leader model constructed in the first aspect to obtain a prediction output probability;
and setting the category corresponding to the maximum probability value as the category of the target domain image.
According to a fourth aspect of the embodiments of the present invention, there is provided an image classification apparatus including:
the acquisition module is used for acquiring a target domain image to be classified;
the prediction output module is used for inputting the acquired images to be classified into the updated student group leader model constructed in the first aspect to obtain prediction output probability;
and the setting module is used for setting the category corresponding to the maximum probability value as the category of the target domain image.
According to a fifth aspect of an embodiment of the present invention, there is provided an electronic apparatus including:
one or more processors;
a memory for storing one or more programs;
wherein the one or more programs, when executed by the one or more processors, cause the one or more processors to implement the method of the first, third, or fourth aspect.
The technical scheme provided by the embodiment of the application can have the following beneficial effects:
(1) A domain generalization image classification network based on knowledge distillation and a multi-student network. Knowledge distillation among the multiple student models makes full use of the complementary information of the multiple source domains and learns domain-invariant representations, improving the generalization capability of the multi-student network.
(2) Two-stage knowledge distillation. The first-stage distillation improves the generalization capability of the model through knowledge exchange among the auxiliary student models, and the second-stage distillation improves the stability of the model's predictions through knowledge exchange between the set of auxiliary student models and the student group leader model.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory only and are not restrictive of the application.
Drawings
The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate embodiments consistent with the present application and together with the description, serve to explain the principles of the application.
FIG. 1 is a flowchart illustrating a method of constructing an image classification model according to an exemplary embodiment.
Fig. 2 is a flowchart illustrating step S11 according to an exemplary embodiment.
Fig. 3 is a flowchart illustrating step S12 according to an exemplary embodiment.
Fig. 4 is a flowchart illustrating step S13 according to an exemplary embodiment.
Fig. 5 is a flowchart illustrating step S15 according to an exemplary embodiment.
FIG. 6 is a flow diagram illustrating another image classification model construction method according to an exemplary embodiment.
FIG. 7 is a diagram illustrating a domain-generalized image classification network architecture according to an exemplary embodiment.
Fig. 8 is a flowchart illustrating step S17 according to an exemplary embodiment.
FIG. 9 is a schematic diagram illustrating an auxiliary student model misprediction culling flow, according to an example embodiment.
Fig. 10 is a block diagram illustrating an image classification model construction apparatus according to an exemplary embodiment.
Fig. 11 is a block diagram illustrating another image classification model construction apparatus according to an exemplary embodiment.
FIG. 12 is a flow diagram illustrating a method of image classification according to an exemplary embodiment.
Fig. 13 is a block diagram illustrating an image classification apparatus according to an exemplary embodiment.
Detailed Description
Reference will now be made in detail to the exemplary embodiments, examples of which are illustrated in the accompanying drawings. When the following description refers to the accompanying drawings, like numbers in different drawings represent the same or similar elements unless otherwise indicated. The implementations described in the following exemplary examples do not represent all implementations consistent with the present application. Rather, they are merely examples of apparatus and methods consistent with certain aspects of the present application, as detailed in the appended claims.
The terminology used herein is for the purpose of describing particular embodiments only and is not intended to be limiting of the application. As used in this application and the appended claims, the singular forms "a", "an", and "the" are intended to include the plural forms as well, unless the context clearly indicates otherwise. It should also be understood that the term "and/or" as used herein refers to and encompasses any and all possible combinations of one or more of the associated listed items.
It is to be understood that although the terms first, second, third, etc. may be used herein to describe various information, such information should not be limited to these terms. These terms are only used to distinguish one type of information from another. For example, first information may also be referred to as second information, and similarly, second information may also be referred to as first information, without departing from the scope of the present application. The word "if" as used herein may be interpreted as "upon", "when", or "in response to determining", depending on the context.
Example 1:
Fig. 1 is a flowchart illustrating an image classification model construction method according to an exemplary embodiment. As shown in Fig. 1, the method, applied in a terminal, may include the following steps:
Step S11, training a corresponding auxiliary student model on each of a plurality of source domains;
Step S12, training a student group leader model on a mixture of the plurality of source domains;
Step S13, mutually distilling the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss;
Step S14, updating the plurality of auxiliary student models according to the first distillation loss;
Step S15, distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
and Step S16, updating the student group leader model according to the second distillation loss, wherein the student group leader model is used to classify images to be classified.
In this technical solution, knowledge distillation among the multiple student models makes full use of the complementary information of the multiple source domains and learns domain-invariant representations, improving the generalization capability of the multi-student network. Distillation is performed in two stages: the first stage improves the generalization capability of the model through knowledge exchange among the auxiliary student models, and the second stage improves the stability of the model's predictions through knowledge exchange between the set of auxiliary student models and the student group leader model.
In a specific implementation of step S11, a corresponding auxiliary student model is trained on each of the plurality of source domains; referring to Fig. 2, this step may include the following sub-steps:
Step S111, acquiring image data of the same category from the plurality of source domains;
Specifically, source domain images $x_S = \{x_1, x_2, \dots, x_N\}$ sharing the same label Y are acquired. Using multi-source images with the same label ensures the validity of the subsequent knowledge distillation.
Step S112, resizing and encoding the source domain images to obtain their encoding matrices;
Specifically, all the source domain images $x_S = \{x_1, x_2, \dots, x_N\}$ are scaled to the same size using bilinear interpolation to fit the input specification of the pre-trained model;
all pixel values of the resized source domain images $x_S$ are divided by 255, and the values (val) of the three RGB channels are normalized using formula (1) to obtain the encoding matrices $g_S = \{g_1, g_2, \dots, g_N\}$ of the source domain images, where the means (mean) of the R, G, and B channels are 0.485, 0.456, and 0.406 and the standard deviations (std) are 0.229, 0.224, and 0.225, respectively. Normalization helps prevent gradient explosion while the neural network model is trained.
$$\mathrm{val}' = \frac{\mathrm{val} - \mathrm{mean}}{\mathrm{std}} \qquad (1)$$
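A minimal NumPy sketch of step S112 follows. It assumes H×W×3 uint8 RGB input and a 224×224 target size (the text only requires "the same size" expected by the pre-trained model), and it uses half-pixel-centre bilinear sampling, which is one common convention among several. The helper names `bilinear_resize` and `encode_image` are hypothetical.

```python
import numpy as np

# Per-channel statistics quoted in the text for formula (1).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def bilinear_resize(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Resize an H x W x C array with bilinear interpolation (half-pixel centres)."""
    in_h, in_w = img.shape[:2]
    # Map each output pixel centre back into input coordinates.
    ys = np.clip((np.arange(out_h) + 0.5) * in_h / out_h - 0.5, 0, in_h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) * in_w / out_w - 0.5, 0, in_w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None, None]  # vertical interpolation weights
    wx = (xs - x0)[None, :, None]  # horizontal interpolation weights
    src = img.astype(np.float64)
    top = src[y0][:, x0] * (1 - wx) + src[y0][:, x1] * wx
    bottom = src[y1][:, x0] * (1 - wx) + src[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy


def encode_image(img: np.ndarray, size: int = 224) -> np.ndarray:
    """Resize, divide by 255, apply formula (1) per channel, return C x H x W."""
    resized = bilinear_resize(img, size, size)
    scaled = resized / 255.0
    normalized = (scaled - MEAN) / STD  # broadcasts over the channel axis
    return normalized.transpose(2, 0, 1)
```

The channel-first layout returned at the end matches what most CNN backbones expect; if the backbone takes channel-last input, the final `transpose` can simply be dropped.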
Step S113, inputting the plurality of encoding matrices into the plurality of pre-trained auxiliary student models, respectively, to obtain a plurality of prediction outputs;
Specifically, there are N pre-trained auxiliary student models $S = \{S_1, S_2, \dots, S_N\}$. Each auxiliary student model contains a shared feature extractor F and its own classifier head $C_i$: the shared feature extractor reduces computational resource requirements, and the unique classifier heads learn domain-specific knowledge. F is built from a pre-trained network such as AlexNet, ResNet-18, or ResNet-50 with its output dimension changed to d; each classifier head $C_i$ consists of a fully connected layer and a Softmax activation function.
The feature extractor F extracts features from all the image encoding matrices:
$$f_i = F(g_i), \quad i = 1, 2, \dots, N$$
where $f_i \in \mathbb{R}^d$ is the d-dimensional real-valued feature vector that F extracts from the encoding matrix $g_i$ of the image in the i-th domain.
The fully connected layer in each classifier head $C_i$ takes the output $f_i$ of the feature extractor F as input and produces the unnormalized prediction $l_i$; the Softmax activation function shown in formula (2) is then applied to obtain the normalized prediction distribution $p_i$, which serves as the prediction output.
$$p_{ij} = \frac{\exp(l_{ij}/T)}{\sum_{k=1}^{C} \exp(l_{ik}/T)} \qquad (2)$$
where $l_{ij}$ is the output value of the j-th node of $l_i$; C is the number of output nodes, i.e., the number of image categories; and T is the distillation temperature, a hyperparameter set to 3. A higher distillation temperature T smooths the predicted distribution so that it carries richer inter-class relationships and semantic information.
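The shared-extractor, multi-head layout and the temperature softmax of formula (2) can be sketched as follows. This is an illustrative NumPy toy, not the patented network: the backbone F is a hypothetical stand-in (a random linear projection plus ReLU) for a pre-trained AlexNet/ResNet, and the class name `MultiStudentNet` and its parameters are assumptions.

```python
import numpy as np


def softmax_t(logits: np.ndarray, T: float = 3.0) -> np.ndarray:
    """Temperature softmax of formula (2), applied along the last axis."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


class MultiStudentNet:
    """Shared feature extractor F plus N + 1 classifier heads.

    Heads 0..N-1 play the auxiliary students S_1..S_N; head N is the
    student group leader S_{N+1}. F here is a random linear map + ReLU,
    a hypothetical stand-in for a pre-trained backbone.
    """

    def __init__(self, in_dim: int, d: int, num_classes: int, num_domains: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.F = rng.normal(0.0, 0.02, size=(in_dim, d))  # shared extractor weights
        # One fully connected head (W_i, b_i) per auxiliary student, plus the leader.
        self.heads = [
            (rng.normal(0.0, 0.02, size=(d, num_classes)), np.zeros(num_classes))
            for _ in range(num_domains + 1)
        ]

    def features(self, g: np.ndarray) -> np.ndarray:
        """f_i = F(g_i): flatten each encoding matrix and project to d dims."""
        flat = g.reshape(g.shape[0], -1)
        return np.maximum(flat @ self.F, 0.0)

    def predict(self, g: np.ndarray, head: int, T: float = 3.0) -> np.ndarray:
        """p_i = Softmax(C_i(F(g_i)) / T) for the chosen head."""
        W, b = self.heads[head]
        logits = self.features(g) @ W + b  # unnormalised l_i
        return softmax_t(logits, T)
```

Because every head reads the same `features` output, adding a domain costs only one extra `d × C` matrix, which reflects the resource argument made for the shared extractor.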
Step S114, calculating classification loss according to the prediction output and the real label;
Specifically, according to the prediction output and the one-hot encoding of the real label, the classification loss $\mathcal{L}_{cls}^{i}$ of each auxiliary student model is calculated using the cross-entropy loss shown in formula (3); it measures how far the predicted values of the auxiliary student model deviate from the actual values.
$$\mathcal{L}_{cls}^{i} = -\sum_{u=1}^{C} y_u \log p_{iu} \qquad (3)$$
where $y_u = 1$ when u is the true label and $y_u = 0$ otherwise; $p_{iu}$ is a real number between 0 and 1 representing the probability that the image belongs to class u.
And step S115, updating corresponding auxiliary student model parameters according to the classification loss.
Specifically, the gradient of the classification loss is computed automatically, and the parameters of the auxiliary student model are updated using stochastic gradient descent (SGD), improving the classification accuracy of the auxiliary student model.
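Steps S114 and S115 amount to a cross-entropy loss followed by an SGD step. The sketch below assumes a linear classifier head on frozen features (in the method itself F is trained as well) and uses the standard identity that, for a temperature softmax followed by cross-entropy, the gradient with respect to the logits is (p − y)/T. The function names and the learning rate are hypothetical.

```python
import numpy as np


def softmax_t(logits, T=3.0):
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def cross_entropy(p, labels):
    """Formula (3) averaged over a batch: -sum_u y_u log p_u with one-hot y."""
    eps = 1e-12  # guards against log(0)
    return float(-np.mean(np.log(p[np.arange(len(labels)), labels] + eps)))


def sgd_step(W, b, f, labels, lr=0.5, T=3.0):
    """Update a linear head (W, b) in place on features f; return the pre-update loss.

    d(loss)/d(logits) = (p - y) / T for temperature softmax + cross-entropy;
    the feature extractor is treated as frozen in this sketch.
    """
    p = softmax_t(f @ W + b, T)
    y = np.eye(W.shape[1])[labels]             # one-hot encoding of the true labels
    grad_logits = (p - y) / (T * len(labels))  # batch-averaged gradient
    W -= lr * f.T @ grad_logits
    b -= lr * grad_logits.sum(axis=0)
    return cross_entropy(p, labels)
```

Starting from zero weights, the first call returns log C, since every class then has probability 1/C; repeated calls should drive the loss down.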
In a specific implementation of step S12, a student group leader model is trained on a mixture of the plurality of source domains; referring to Fig. 3, this step may include the following sub-steps:
Step S121, acquiring image data randomly mixed by category from the plurality of source domains;
Specifically, source domain images of the category selected in step S111 are randomly mixed to obtain mixed images $x_{Mix}$, which are treated as the image data of an (N+1)-th source domain, i.e., $x_{N+1}$. Training the student group leader model on mixed multi-source images improves its cross-domain classification capability.
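Sampling for steps S111 and S121 might look like the sketch below. It assumes each source domain is stored as a dict mapping a category to a list of images, and it reads "randomly mixed" as drawing each image of the mixed batch from a randomly chosen domain, not as pixel-level blending. Both readings, and the function names, are assumptions.

```python
import random


def sample_same_category(domains, category, rng):
    """Step S111: draw one image of `category` from each source domain (x_1..x_N)."""
    return [rng.choice(domain[category]) for domain in domains]


def sample_mixed(domains, category, k, rng):
    """Step S121: build x_{N+1} from k images of `category`, each taken from
    a randomly chosen source domain. Returns (domain index, image) pairs."""
    batch = []
    for _ in range(k):
        d = rng.randrange(len(domains))  # pick a source domain at random
        batch.append((d, rng.choice(domains[d][category])))
    return batch
```

Keeping the domain index alongside each mixed image is optional, but it makes it easy to check that the leader actually sees several domains per batch.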
Step S122, resizing and encoding the mixed image to obtain its encoding matrix;
Specifically, the encoding matrix $g_{N+1}$ of the mixed image is obtained with the resizing and normalization method described in step S112.
Step S123, inputting the encoding matrix into a pre-trained student group leader model $S_{N+1}$ to obtain a prediction output;
Specifically, the student group leader model comprises the feature extractor F shared with the auxiliary student models and its own classifier head $C_{N+1}$.
The feature extractor F extracts a d-dimensional feature vector $f_{N+1} = F(g_{N+1})$ from the encoding matrix $g_{N+1}$ of the mixed image.
The fully connected layer in classifier head $C_{N+1}$ takes the output $f_{N+1}$ of the feature extractor F as input and produces the unnormalized prediction $l_{N+1}$; the Softmax activation function then yields the normalized prediction distribution $p_{N+1}$ as the prediction output of the student group leader.
Step S124, calculating classification loss according to the prediction output and the real label;
Specifically, according to the prediction output and the one-hot encoding of the real label, the classification loss $\mathcal{L}_{cls}^{N+1}$ of the student group leader model is calculated using the cross-entropy loss; it measures how far the predicted values of the student group leader model deviate from the actual values.
And step S125, updating the parameters of the student group leader model according to the classification loss.
Specifically, the gradient of the classification loss is computed automatically, and the parameters of the student group leader model are updated using stochastic gradient descent (SGD), improving the classification accuracy of the student group leader model.
In a specific implementation of step S13, mutually distilling the predicted outputs of the plurality of auxiliary student models to obtain a first distillation loss; referring to fig. 4, this step may include the following sub-steps:
Step S131, selecting, in turn, the prediction output of one auxiliary student model as the temporary teacher distribution;
Specifically, the N auxiliary student models $S = \{S_1, S_2, \dots, S_N\}$ in the multi-student network have prediction outputs $p = \{p_1, p_2, \dots, p_N\}$, respectively.
The prediction output $p_i$ of the i-th auxiliary student model $S_i$ is selected as the temporary teacher distribution for transferring domain-specific knowledge.
Step S132, taking the mean of the prediction distributions of all the remaining auxiliary student models as the student distribution;
Specifically, the mean of the prediction distributions of all the remaining auxiliary student models,
$$\bar{p}_{\setminus i} = \frac{1}{N-1} \sum_{j \neq i} p_j$$
is used as the student distribution. This form of collaborative learning aggregates gradients from different auxiliary student models and makes better use of the complementarity between different source domains.
Step S133, calculating corresponding distillation loss in sequence by using an alignment loss function according to the teacher distribution and the student distribution;
Specifically, the first-stage knowledge distillation loss is calculated using an alignment loss function such as the KL divergence (Kullback-Leibler divergence) or the mean squared error (MSE) loss, as shown in formula (4). Aligning the prediction output distributions of the auxiliary student models with one another through collaborative learning transfers knowledge between the models, so that they learn domain-invariant representations and generalize better.
$$\mathcal{L}_{KD}^{i} = \mathcal{L}_{align}\left(p_i, \bar{p}_{\setminus i}\right) \qquad (4)$$
where $\mathcal{L}_{align}$ is the alignment loss function.
And step S134, taking the sum of all distillation losses as a first distillation loss.
Specifically, the sum of the distillation losses of all the auxiliary student models is taken as the first distillation loss, as shown in formula (5).
$$\mathcal{L}_{KD1} = \sum_{i=1}^{N} \mathcal{L}_{KD}^{i} \qquad (5)$$
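Formulas (4) and (5) can be sketched in NumPy as below, assuming predictions are stacked into an array of shape (N, batch, C) with N ≥ 2. The KL direction, KL(teacher ‖ student), is an assumption, since the text names KL divergence and MSE only as examples of alignment losses; the function names are hypothetical.

```python
import numpy as np


def kl_div(teacher, student, eps=1e-12):
    """KL(teacher || student), summed over classes and averaged over the batch."""
    t = np.clip(teacher, eps, 1.0)
    s = np.clip(student, eps, 1.0)
    return float(np.mean(np.sum(t * (np.log(t) - np.log(s)), axis=-1)))


def mse(teacher, student):
    """Mean squared error, the alternative alignment loss named in the text."""
    return float(np.mean((teacher - student) ** 2))


def first_distillation_loss(preds, align=kl_div):
    """Formulas (4)-(5): each p_i in turn is the temporary teacher, the mean of
    the other N-1 predictions is the student, and the N losses are summed.

    preds: array of shape (N, batch, C) holding p_1..p_N.
    """
    n = preds.shape[0]
    total = 0.0
    for i in range(n):
        rest = np.delete(preds, i, axis=0)  # all auxiliary students except S_i
        student = rest.mean(axis=0)         # mean of the remaining distributions
        total += align(preds[i], student)   # L_KD^i
    return total
```

When all auxiliary students already agree, every term vanishes, so this loss only pushes the students toward a consensus and carries no label information of its own.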
In a specific implementation of step S14, the plurality of auxiliary student models are updated according to the first distillation loss.
Specifically, the gradient of the first distillation loss is computed automatically, and the parameters of the auxiliary student models are updated using stochastic gradient descent (SGD), guiding the next round of training in the correct direction.
In a specific implementation of step S15, the updated prediction outputs of the plurality of auxiliary student models are distilled into the student group leader model to obtain a second distillation loss; referring to Fig. 5, this step may include the following sub-steps:
Step S151, taking the mean of the prediction outputs of all the auxiliary student models as the teacher distribution;
Specifically, the mean of the predictions of the plurality of auxiliary student models,
$$\bar{p} = \frac{1}{N} \sum_{i=1}^{N} p_i$$
is used as the teacher distribution; averaging makes better use of the complementarity between data from different source domains.
Step S152, taking the prediction output of the student group leader model as student distribution;
Specifically, the prediction output $p_{N+1}$ of the student group leader model is used as the student distribution.
And step S153, calculating second distillation loss by using an alignment loss function according to the teacher distribution and the student distribution.
Specifically, the alignment loss function of step S133 is also used as the loss function of the second-stage knowledge distillation, aligning the student distribution with the teacher distribution to transfer knowledge, as shown in formula (6). Because the N auxiliary student models differ from one another, predicting with a single, randomly chosen auxiliary student model would make the robustness of the model hard to guarantee. Second-stage distillation is therefore performed, and the student group leader model serves as the final prediction model to keep the classification results stable.
$$\mathcal{L}_{KD2} = \mathcal{L}_{align}\left(\bar{p}, p_{N+1}\right) \qquad (6)$$
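Formula (6) reduces to a single alignment term between the averaged auxiliary predictions and the leader's prediction. The sketch assumes the same (N, batch, C) layout and KL(teacher ‖ student) direction as in a first-stage implementation, both of which are assumptions; `second_distillation_loss` is a hypothetical name.

```python
import numpy as np


def kl_div(teacher, student, eps=1e-12):
    """KL(teacher || student), summed over classes and averaged over the batch."""
    t = np.clip(teacher, eps, 1.0)
    s = np.clip(student, eps, 1.0)
    return float(np.mean(np.sum(t * (np.log(t) - np.log(s)), axis=-1)))


def second_distillation_loss(aux_preds, leader_pred, align=kl_div):
    """Formula (6): align the leader p_{N+1} with the mean of p_1..p_N.

    aux_preds: (N, batch, C); leader_pred: (batch, C).
    """
    teacher = aux_preds.mean(axis=0)  # averaged auxiliary-student distribution
    return align(teacher, leader_pred)
```

Unlike the first stage, the teacher here is fixed for the step, so only the leader's parameters are pushed toward the consensus of the auxiliary students.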
In an implementation of step S16, the student group leader model is updated based on the second distillation loss.
Specifically, the gradient of the second distillation loss is computed automatically, and the parameters of the student group leader model are updated using stochastic gradient descent (SGD), guiding the next round of training in the correct direction.
Referring to Figs. 6 and 7, the method further includes step S17: computing a domain generalization total loss as a weighted sum of all the classification losses and distillation losses, and iterating the above model parameters until the domain generalization total loss reaches a preset convergence condition.
Through the iterative updating of step S17, the model classifies more accurately; using the student group leader model obtained after the domain generalization total loss converges to classify the images to be classified yields more accurate classification results.
In a specific implementation of step S17, the domain generalization total loss is computed as a weighted sum of all the classification losses and distillation losses, and all the above model parameters are iterated until the domain generalization total loss reaches the preset convergence condition; referring to Fig. 8, this step may include the following sub-steps:
Step S171, taking the weighted sum of all the classification losses and distillation losses as the domain generalization total loss;
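Specifically, the weighted sum of all the classification losses and distillation losses is taken as the total domain generalization loss, as shown in formula (7):
[Formula (7): the total domain generalization loss, an α-weighted sum of the classification losses and the distillation losses; the formula and its loss symbols are images in the original and are not reproduced]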
Here, α is a hyperparameter that balances the contributions of the two types of loss to the total loss; it is set to 0.2.
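Because formula (7) survives only as an image, the exact placement of α cannot be recovered here. The sketch below assumes the form L_total = ΣL_cls + α·ΣL_dis with α = 0.2 as stated; both the form and the function name are assumptions for illustration.

```python
from typing import Sequence


def domain_generalization_total_loss(
    classification_losses: Sequence[float],
    distillation_losses: Sequence[float],
    alpha: float = 0.2,
) -> float:
    """Assumed form of formula (7): sum(L_cls) + alpha * sum(L_dis).

    classification_losses: e.g. one per auxiliary student plus the leader.
    distillation_losses:   e.g. the first and second distillation losses.
    """
    if alpha < 0:
        raise ValueError("alpha must be non-negative")
    return sum(classification_losses) + alpha * sum(distillation_losses)
```

For example, with four classification losses summing to 3.8 and two distillation losses summing to 0.5, this assumed form gives approximately 3.9.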
Step S172, iterating all the model parameters until the total domain generalization loss reaches the preset convergence condition;
Specifically, steps S11 to S17 are repeated on the remaining source domain images until the total domain generalization loss reaches the preset convergence condition.
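The outer loop of step S172 might look like the following sketch. The source does not specify the preset convergence condition, so a change in total loss smaller than `tol` between consecutive rounds is assumed, with `max_rounds` as a safety cap; `train_one_round` is a hypothetical callback standing in for one pass of steps S11 to S17.

```python
from typing import Callable, List


def train_until_converged(
    train_one_round: Callable[[], float],
    tol: float = 1e-4,
    max_rounds: int = 1000,
) -> List[float]:
    """Repeat steps S11-S17 until the total loss stops changing.

    train_one_round: hypothetical callback that runs S11-S17 on the next
        batch of source-domain images and returns the total loss.
    Convergence (assumed): |L_t - L_{t-1}| < tol.
    Returns the history of total losses, one entry per round.
    """
    history: List[float] = []
    for _ in range(max_rounds):
        history.append(train_one_round())
        if len(history) >= 2 and abs(history[-1] - history[-2]) < tol:
            break
    return history
```

Returning the loss history rather than a flag keeps the stopping decision inspectable, for instance to plot the convergence curve afterwards.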
The student group leader model obtained after the total domain generalization loss has converged is used to classify the images to be classified. Specifically, the acquired target domain image to be classified is input into the converged student group leader model to obtain prediction output probabilities, and the class with the maximum probability is assigned to the target domain image.
In this embodiment, preferably, referring to fig. 9, after the corresponding auxiliary student models have been trained on the respective source domains and before the prediction outputs of the auxiliary student models are mutually distilled, the method further includes:
comparing the prediction outputs of all the auxiliary student models with the ground-truth labels, retaining the correct predictions and discarding the incorrect ones. In the first-stage distillation, only the correct predictions of the auxiliary student models are passed on, so that the knowledge exchanged among the auxiliary student models is correct and is transferred efficiently.
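The correct-prediction filter can be sketched per sample as follows. The source does not say whether filtering is done per sample or per batch, so per-sample filtering, comparing each auxiliary model's arg-max class with the ground-truth label, is an assumption, as are the function names.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def keep_correct_predictions(
    aux_predictions: List[Sequence[float]], label: int
) -> List[int]:
    """Indices of auxiliary student models whose prediction for one sample
    matches the ground-truth label; the others are discarded before the
    first-stage (mutual) distillation."""
    return [i for i, probs in enumerate(aux_predictions) if argmax(probs) == label]
```

For predictions `[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.6, 0.3, 0.1]]` and label 0, only models 0 and 2 are kept.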
Corresponding to the embodiment of the image classification model construction method, the application also provides an embodiment of an image classification model construction device.
Fig. 10 is a block diagram of an image classification model construction apparatus according to an exemplary embodiment. Referring to fig. 10, the apparatus includes a first training module 11, a second training module 12, a first distillation module 13, a first updating module 14, a second distillation module 15, and a second updating module 16.
A first training module 11, configured to train a corresponding auxiliary student model on each of the plurality of source domains;
a second training module 12, configured to train a student group leader model on mixed data from the plurality of source domains;
a first distillation module 13, configured to mutually distill the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss;
a first updating module 14, configured to update the plurality of auxiliary student models according to the first distillation loss;
a second distillation module 15, configured to distill the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
a second updating module 16, configured to update the student group leader model according to the second distillation loss, where the student group leader model is used to classify the image to be classified.
Referring to fig. 11, the apparatus further includes an iterative update module 17, configured to compute the total domain generalization loss as a weighted combination of all the classification losses and distillation losses and to iterate the above model parameters until the total domain generalization loss reaches a preset convergence condition.
With regard to the apparatus in the above-described embodiment, the specific manner in which each module performs the operation has been described in detail in the embodiment related to the method, and will not be elaborated here.
Since the device embodiments essentially correspond to the method embodiments, reference may be made to the relevant parts of the description of the method embodiments. The device embodiments described above are merely illustrative: units described as separate components may or may not be physically separate, and components shown as units may or may not be physical units; they may be located in one place or distributed across multiple network units. Some or all of the modules may be selected according to actual needs to achieve the objectives of the solution of the present application, which those of ordinary skill in the art can understand and implement without inventive effort.
Example 2:
Fig. 12 is a flowchart of an image classification method according to an exemplary embodiment. The method is applied to a terminal and, as shown in fig. 12, may include the following steps:
step S21, acquiring a target domain image to be classified;
step S22, inputting the acquired image to be classified into the student group leader model constructed in Example 1 to obtain a prediction output probability;
step S23, setting the category corresponding to the maximum probability value as the category of the target domain image.
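Steps S22 and S23 amount to a softmax followed by an arg-max. The sketch below assumes the student group leader model has already produced raw logits for one image; `classify` and `class_names` are hypothetical names.

```python
import math
from typing import Sequence, Tuple


def classify(
    leader_logits: Sequence[float], class_names: Sequence[str]
) -> Tuple[str, float]:
    """Convert the leader's logits to probabilities (step S22) and return the
    class with the maximum probability together with that probability (S23)."""
    m = max(leader_logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in leader_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]
```

For logits `[0.1, 2.0, -1.0]` over classes `["a", "b", "c"]`, the function returns class `"b"` with a probability of roughly 0.83.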
The performance of the present invention is evaluated on PACS, a dataset commonly used for domain generalization in image classification. PACS contains images from four domains, Art Painting, Cartoon, Photo, and Sketch, which share the same 7 categories. In each experiment, one domain is held out as the inaccessible target domain and the remaining three serve as source domains, giving 4 tasks in total. Classification accuracy, the standard metric for domain generalization image classification, is used as the evaluation index because it measures the overall prediction performance of the model.
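The leave-one-domain-out protocol and the accuracy metric described above can be expressed as the sketch below. The domain names come from the text; the function names are hypothetical, and no experimental numbers are reproduced.

```python
from typing import Dict, List, Sequence, Tuple

PACS_DOMAINS = ("Art Painting", "Cartoon", "Photo", "Sketch")


def leave_one_domain_out(domains: Sequence[str]) -> List[Tuple[str, List[str]]]:
    """One task per domain: (unseen target domain, list of source domains)."""
    return [(target, [d for d in domains if d != target]) for target in domains]


def accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of samples whose predicted class equals the true class."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty and equal length")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def average_accuracy(per_task: Dict[str, float]) -> float:
    """The 'Avg' figure: mean accuracy over the target-domain tasks."""
    return sum(per_task.values()) / len(per_task)
```

With the four PACS domains this yields exactly 4 tasks, e.g. Sketch as target with Art Painting, Cartoon, and Photo as sources.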
The results are shown in the following table. The Art Painting entry denotes the task in which the Art Painting domain is the target domain and the remaining three domains are source domains; the other three entries are defined likewise, and the last column, Avg, gives the average accuracy of the model over the 4 tasks. The method achieves the highest accuracy on both the Photo and Sketch tasks of the PACS dataset, and its average accuracy over the 4 tasks is also the highest, 5.58% above DeepAll, the baseline trained without any domain generalization technique. On the Sketch task, the method improves on the best competing method, MixStyle, by 8.26%. Because Sketch is the only colorless domain, this suggests that the multi-student knowledge distillation learns domain-invariant features more effectively. In summary, the knowledge distillation-based multi-student network scheme effectively addresses the domain generalization problem in image classification and outperforms existing methods in domain generalization classification accuracy.
Table 1. Results of domain generalization image classification experiments on the PACS dataset
[Table 1 image not reproduced]
Corresponding to the embodiment of the image classification method, the application also provides an embodiment of the image classification device.
Fig. 13 is a block diagram illustrating an image classification device according to an exemplary embodiment. Referring to fig. 13, the apparatus includes: an acquisition module 21, a prediction output module 22 and a setting module 23.
An acquisition module 21, configured to acquire a target domain image to be classified;
a prediction output module 22, configured to input the acquired image to be classified into the updated student group leader model constructed according to any one of claims 1 to 6, so as to obtain a prediction output probability;
and the setting module 23 is configured to set the category corresponding to the maximum probability value as the category of the target domain image.
With regard to the apparatus in the above-described embodiment, the specific manner in which each module performs the operation has been described in detail in the embodiment related to the method, and will not be elaborated here.
Since the device embodiments essentially correspond to the method embodiments, reference may be made to the relevant parts of the description of the method embodiments. The device embodiments described above are merely illustrative: units described as separate components may or may not be physically separate, and components shown as units may or may not be physical units; they may be located in one place or distributed across multiple network units. Some or all of the modules may be selected according to actual needs to achieve the objectives of the solution of the present application, which those of ordinary skill in the art can understand and implement without inventive effort.
Correspondingly, the present application also provides an electronic device, comprising: one or more processors; and a memory for storing one or more programs, wherein the one or more programs, when executed by the one or more processors, cause the one or more processors to implement the image classification model construction method and the image classification method described above.
Accordingly, the present application further provides a computer readable storage medium, on which computer instructions are stored, wherein the instructions, when executed by a processor, implement the image classification model construction method and the image classification method as described above.
Other embodiments of the present application will be apparent to those skilled in the art from consideration of the specification and practice of the present disclosure. This application is intended to cover any variations, uses, or adaptations of the invention following, in general, the principles of the application and including such departures from the present disclosure as come within known or customary practice within the art to which the invention pertains. It is intended that the specification and examples be considered as exemplary only, with a true scope and spirit of the application being indicated by the following claims.
It will be understood that the present application is not limited to the precise arrangements described above and shown in the drawings and that various modifications and changes may be made without departing from the scope thereof. The scope of the application is limited only by the appended claims.

Claims (10)

1. An image classification model construction method is characterized by comprising the following steps:
respectively training corresponding auxiliary student models by using a plurality of source domains;
training a student group leader model on mixed data from the plurality of source domains;
mutually distilling the predicted outputs of the plurality of auxiliary student models to obtain a first distillation loss;
updating the plurality of auxiliary student models according to the first distillation loss;
distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
and updating the student group leader model according to the second distillation loss, wherein the student group leader model is used for classifying the images to be classified.
2. The method of claim 1, wherein, after the corresponding auxiliary student models are trained on the respective source domains and before the prediction outputs of the plurality of auxiliary student models are mutually distilled, the method further comprises:
comparing the prediction outputs of all the auxiliary student models with the ground-truth labels, retaining the correct prediction outputs and discarding the incorrect ones.
3. The method of claim 1, wherein training the corresponding auxiliary student models on the respective source domains comprises:
acquiring image data of a plurality of source domains according to the same category;
resizing the source domain images and encoding them to obtain encoding matrices of the source domain images;
inputting the plurality of encoding matrices into a plurality of pre-trained auxiliary student models, respectively, to obtain a plurality of prediction outputs;
calculating a classification loss according to the prediction outputs and the ground-truth labels;
and updating the parameters of the corresponding auxiliary student models according to the classification loss.
4. The method of claim 1, wherein training a student group leader model on mixed data from the plurality of source domains comprises:
acquiring image data from the plurality of source domains, randomly mixed by category;
resizing the mixed images and encoding them to obtain an encoding matrix of the mixed images;
inputting the encoding matrix into a pre-trained student group leader model to obtain a prediction output;
calculating a classification loss according to the prediction output and the ground-truth label;
and updating parameters of the student group leader model according to the classification loss.
5. The method of claim 1, wherein mutually distilling the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss comprises:
selecting, in turn, the prediction output of one auxiliary student model as a temporary teacher distribution;
taking the average of the prediction distributions of the remaining auxiliary student models as the student distribution;
according to the teacher distribution and the student distribution, corresponding distillation losses are calculated in sequence by using an alignment loss function;
the sum of all distillation losses is taken as the first distillation loss.
6. The method of claim 1, wherein distilling the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss comprises:
taking the average value of the prediction outputs of all the auxiliary student models as teacher distribution;
taking the predicted output of the student group leader model as student distribution;
calculating a second distillation loss using an alignment loss function based on the teacher distribution and the student distribution.
7. An image classification model construction apparatus, comprising:
a first training module, configured to train a corresponding auxiliary student model on each of a plurality of source domains;
a second training module, configured to train a student group leader model on mixed data from the plurality of source domains;
a first distillation module, configured to mutually distill the prediction outputs of the plurality of auxiliary student models to obtain a first distillation loss;
a first updating module, configured to update the plurality of auxiliary student models according to the first distillation loss;
a second distillation module, configured to distill the updated prediction outputs of the plurality of auxiliary student models into the student group leader model to obtain a second distillation loss;
and a second updating module, configured to update the student group leader model according to the second distillation loss, wherein the student group leader model is used to classify images to be classified.
8. An image classification method, comprising:
acquiring a target domain image to be classified;
inputting the acquired image to be classified into the updated student group leader model constructed by the method of claim 1 to obtain a prediction output probability;
and setting the category corresponding to the maximum probability value as the category of the target domain image.
9. An image classification apparatus, comprising:
the acquisition module is used for acquiring a target domain image to be classified;
a prediction output module, configured to input the acquired image to be classified into the updated student group leader model constructed according to any one of claims 1 to 6, so as to obtain a prediction output probability;
and the setting module is used for setting the category corresponding to the maximum probability value as the category of the target domain image.
10. An electronic device, comprising:
one or more processors;
a memory for storing one or more programs;
wherein the one or more programs, when executed by the one or more processors, cause the one or more processors to implement the method of any one of claims 1-6 and 8.
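The mutual distillation of claim 5 (the first distillation loss) can be sketched as follows. The alignment loss is not specified in this section, so KL(teacher || student) is assumed; the inputs are the auxiliary student models' probability distributions for one sample, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); assumed stand-in for the unspecified alignment loss."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def first_distillation_loss(aux_probs: List[Sequence[float]]) -> float:
    """Claim 5 as a sketch: each auxiliary model in turn is the temporary
    teacher, the mean of the remaining models' distributions is the student,
    and the per-teacher alignment losses are summed."""
    n = len(aux_probs)
    if n < 2:
        raise ValueError("mutual distillation needs at least two auxiliary models")
    num_classes = len(aux_probs[0])
    total = 0.0
    for i, teacher in enumerate(aux_probs):
        others = [p for j, p in enumerate(aux_probs) if j != i]
        # Student distribution: class-wise mean of the remaining models.
        student = [sum(p[c] for p in others) / len(others) for c in range(num_classes)]
        total += kl_divergence(teacher, student)
    return total
```

The loss is zero when all auxiliary students agree and positive otherwise, so minimizing it pulls the auxiliary models toward a shared prediction.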
CN202210424033.3A 2022-04-20 2022-04-20 Image classification model construction method, image classification method and device and electronic equipment Pending CN114937166A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210424033.3A CN114937166A (en) 2022-04-20 2022-04-20 Image classification model construction method, image classification method and device and electronic equipment

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210424033.3A CN114937166A (en) 2022-04-20 2022-04-20 Image classification model construction method, image classification method and device and electronic equipment

Publications (1)

Publication Number Publication Date
CN114937166A true CN114937166A (en) 2022-08-23

Family

ID=82861722

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210424033.3A Pending CN114937166A (en) 2022-04-20 2022-04-20 Image classification model construction method, image classification method and device and electronic equipment

Country Status (1)

Country Link
CN (1) CN114937166A (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115496955A (en) * 2022-11-18 2022-12-20 之江实验室 Image classification model training method, image classification method, apparatus and medium

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115496955A (en) * 2022-11-18 2022-12-20 之江实验室 Image classification model training method, image classification method, apparatus and medium
CN115496955B (en) * 2022-11-18 2023-03-24 之江实验室 Image classification model training method, image classification method, device and medium

Similar Documents

Publication Publication Date Title
CN111291183B (en) Method and device for carrying out classification prediction by using text classification model
CN112446423B (en) Fast hybrid high-order attention domain confrontation network method based on transfer learning
WO2020232874A1 (en) Modeling method and apparatus based on transfer learning, and computer device and storage medium
CN112418351B (en) Zero sample learning image classification method based on global and local context sensing
CN110619059B (en) Building marking method based on transfer learning
WO2020125404A1 (en) Method and apparatus for constructing neural network and computer-readable medium
CN115170449B (en) Multi-mode fusion scene graph generation method, system, equipment and medium
CN107871103A (en) Face authentication method and device
CN116227624A (en) Federal knowledge distillation method and system oriented to heterogeneous model
CN112131261A (en) Community query method and device based on community network and computer equipment
CN116089645A (en) Hierarchical style-based conditional text-e-commerce picture retrieval method and system
CN115984930A (en) Micro expression recognition method and device and micro expression recognition model training method
CN115270752A (en) Template sentence evaluation method based on multilevel comparison learning
CN114937166A (en) Image classification model construction method, image classification method and device and electronic equipment
CN111309823A (en) Data preprocessing method and device for knowledge graph
CN112541530B (en) Data preprocessing method and device for clustering model
CN114118370A (en) Model training method, electronic device, and computer-readable storage medium
Huang et al. Efficient optimization for linear dynamical systems with applications to clustering and sparse coding
CN116306969A (en) Federal learning method and system based on self-supervision learning
Huang Normalization Techniques in Deep Learning
CN116563602A (en) Fine granularity image classification model training method based on category-level soft target supervision
CN114677535A (en) Training method of domain-adaptive image classification network, image classification method and device
CN115578593A (en) Domain adaptation method using residual attention module
CN115063374A (en) Model training method, face image quality scoring method, electronic device and storage medium
CN115439791A (en) Cross-domain video action recognition method, device, equipment and computer-readable storage medium

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination