Disclosure of Invention
In view of the above, it is necessary to provide an image segmentation model training method and apparatus, a computer device, and a storage medium that solve the above technical problems.
An image segmentation model training method comprises the following steps:
acquiring a trained teacher model and a student model to be trained, wherein the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor;
inputting a sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model;
inputting the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model to be trained, and a second secondary classification result and second category information output by the secondary classifier of the student model to be trained;
constructing a target loss function of the student model from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information;
iteratively training the student model to be trained according to the target loss function to obtain a trained student model, the trained student model being used to perform semantic segmentation on an input image.
In one embodiment, constructing the target loss function of the student model comprises:
constructing a first loss function based on the first primary classification result and the second primary classification result;
constructing a second loss function based on the first category information and the second category information;
constructing a third loss function based on the first secondary classification result and the second secondary classification result;
and determining the target loss function from the first loss function, the second loss function, the third loss function, and the cross-entropy loss function of the second primary classification result.
In one embodiment, the first loss function is obtained by:
$$L_{kd} = \frac{1}{H \times W}\sum_{x}\mathrm{KL}\left(p_t(x)\,\|\,p_s(x)\right)$$
where
L_kd is the first loss function;
H is the height of the sample image;
W is the width of the sample image;
x is a pixel in the sample image;
p_t is the first primary classification result;
p_s is the second primary classification result;
KL(·‖·) is the Kullback-Leibler (KL) divergence function.
In one embodiment, the second loss function is obtained by:
$$L_{rec} = \frac{1}{N}\sum_{i=1}^{N}\left(1 - \cos\left(c_i^{t}, c_i^{s}\right)\right)$$
where
L_rec is the second loss function; N is the number of categories contained in the sample image;
c_i^t denotes the i-th category in the first category information;
c_i^s denotes the i-th category in the second category information; and cos denotes cosine similarity.
In one embodiment, the third loss function is obtained by:
$$L_{ob} = \frac{1}{H \times W}\sum_{x}\mathrm{KL}\left(p_{a,t}(x)\,\|\,p_{a,s}(x)\right)$$
where
L_ob is the third loss function;
H is the height of the sample image;
W is the width of the sample image;
x is a pixel in the sample image;
p_a,t is the first secondary classification result;
p_a,s is the second secondary classification result;
KL(·‖·) is the KL divergence function.
In one embodiment, the target loss function is obtained by:
$$L_s = L_{ce} + \lambda_{kd} L_{kd} + \lambda_{rec} L_{rec} + L_{ob}$$
where
L_s is the target loss function;
L_ce is the cross-entropy loss function of the second primary classification result;
L_kd is the first loss function;
L_rec is the second loss function;
L_ob is the third loss function;
λ_kd is the hyperparameter of the first loss function;
λ_rec is the hyperparameter of the second loss function.
In one embodiment, iteratively training the student model to be trained according to the target loss function to obtain a trained student model comprises:
iteratively training the student model to be trained with the target loss function;
and when the number of training iterations of the student model reaches a preset number, taking the student model at that point as the trained student model.
A model training apparatus, the apparatus comprising:
the model acquisition module is configured to acquire a trained teacher model and a student model to be trained, wherein the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor;
the first classification module is configured to input a sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model;
the second classification module is configured to input the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model to be trained, and a second secondary classification result and second category information output by the secondary classifier of the student model to be trained;
the loss function construction module is configured to construct a target loss function of the student model from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information;
the model training module is configured to iteratively train the student model to be trained according to the target loss function to obtain a trained student model, the trained student model being used to perform semantic segmentation on an input image.
A computer device comprising a memory and a processor, the memory storing a computer program, the processor implementing the following steps when executing the computer program:
acquiring a trained teacher model and a student model to be trained, wherein the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor;
inputting a sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model;
inputting the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model to be trained, and a second secondary classification result and second category information output by the secondary classifier of the student model to be trained;
constructing a target loss function of the student model from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information;
iteratively training the student model to be trained according to the target loss function to obtain a trained student model, the trained student model being used to perform semantic segmentation on an input image.
A computer-readable storage medium, on which a computer program is stored which, when executed by a processor, carries out the steps of:
acquiring a trained teacher model and a student model to be trained, wherein the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor;
inputting a sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model;
inputting the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model to be trained, and a second secondary classification result and second category information output by the secondary classifier of the student model to be trained;
constructing a target loss function of the student model to be trained from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information;
iteratively training the student model to be trained according to the target loss function to obtain a trained student model, the trained student model being used to perform semantic segmentation on an input image.
According to the above model training method and apparatus, computer device, and storage medium, a trained teacher model and a student model to be trained are acquired, where the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor; the sample image is input into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model; the sample image is input into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model, and a second secondary classification result and second category information output by the secondary classifier of the student model; a target loss function of the student model is constructed from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information; and the student model to be trained is iteratively trained according to the target loss function to obtain a trained student model, which is used to perform semantic segmentation on an input image. Because secondary classifiers are provided in both the teacher model and the student model during training, the model can learn the characteristic features of each individual sample image while still fitting the global features well, which improves the performance of the student model.
Detailed Description
In order to make the objects, technical solutions and advantages of the present application more apparent, the present application is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are merely illustrative of the present application and are not intended to limit the present application.
The image segmentation model training method provided by the present application can be applied in the application environment shown in FIG. 1, in which the terminal 11 communicates with the server 12 over a network. The server 12 acquires sample images from the terminal 11 and acquires a trained teacher model and a student model to be trained, where the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor. The server 12 inputs the sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model; inputs the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model, and a second secondary classification result and second category information output by the secondary classifier of the student model; constructs a target loss function of the student model from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information; and iteratively trains the student model to be trained according to the target loss function to obtain a trained student model, which is used to perform semantic segmentation on an input image.
The terminal 11 may be, but is not limited to, a personal computer, notebook computer, smartphone, tablet computer, or portable wearable device, and the server 12 may be implemented as a standalone server or as a server cluster composed of multiple servers.
In one embodiment, as shown in FIG. 2, an image segmentation model training method is provided. It is described here using its application to the server 12 in FIG. 1 as an example, and includes the following steps:
Step 21, obtaining a trained teacher model and a student model to be trained, where the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor.
The teacher model and the student model are the two branch models of a knowledge distillation setting. Knowledge distillation is an important model compression technique: it transfers dark knowledge (as distinct from explicit knowledge, i.e., knowledge that can only be mined through a machine learning process) from a complex model (the teacher model) to a simple model (the student model), so that the student model can achieve a training effect similar to that of the teacher model with lower space complexity and less training time, and its fitting ability may approach or even exceed that of the teacher model.
The teacher model is pre-trained, while the student model is untrained or not fully trained. The teacher model and the student model may be neural network models of the same type, i.e., with the same network layer structure, or of different types, i.e., with different network layer structures. Typically, the student model has fewer network layers than the teacher model, i.e., a smaller model size and fewer weight parameters. Training the student model from the teacher model by knowledge distillation therefore yields a lightweight student model that produces predictions faster and saves computing resources.
Both the teacher model and the student model comprise a feature extractor, a primary classifier, and a secondary classifier; those of the teacher model are trained in advance. The primary and secondary classifiers may be implemented with neural networks, random forests, support vector machines, or the like. They differ in that the primary classifier identifies, learns, and classifies features across all sample images, whereas the secondary classifier is created for each individual sample image and can therefore capture the feature details of a training sample image better than the primary classifier. This helps the student model acquire more dark knowledge, improving both the knowledge distillation and the training performance of the student model.
Specifically, a previously trained model is obtained as the teacher model, and a model to be trained, either structurally similar to the teacher model or constructed directly from it, is obtained as the student model. Each model comprises at least a feature extractor, a primary classifier, and a secondary classifier, which are used to extract and classify features in the sample images.
Step 22, inputting the sample image into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model.
The first primary classification result is the prediction obtained after the features are processed by the primary classifier of the teacher model; the first secondary classification result is the prediction obtained after the dimension-reduced features are processed by the secondary classifier of the teacher model; the first category information is information on the categories contained in the sample image, as output by the secondary classifier of the teacher model.
Specifically, the server inputs the sample image into the teacher model. After the feature extractor of the teacher model extracts features from the sample image, the features are, on the one hand, input into the primary classifier to obtain the first primary classification result; on the other hand, they are reduced in dimension to match the structure of the secondary classifier and then input into the secondary classifier, which outputs the first secondary classification result and the first category information.
Step 23, inputting the sample image into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model to be trained, and a second secondary classification result and second category information output by the secondary classifier of the student model to be trained.
The second primary classification result is the prediction obtained after the features are processed by the primary classifier of the student model; the second secondary classification result is the prediction obtained after the dimension-reduced features are processed by the secondary classifier of the student model; the second category information is information on the categories contained in the sample image, as output by the secondary classifier of the student model.
Specifically, the server may input the same sample image as used for the teacher model into the student model. After the feature extractor of the student model extracts features from the sample image, the features are, on the one hand, input into the primary classifier to obtain the second primary classification result; on the other hand, they are reduced in dimension to match the structure of the secondary classifier and then input into the secondary classifier, which outputs the second secondary classification result and the second category information.
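The forward pass of steps 22 and 23 can be sketched as below, for either model. This is a minimal sketch under stated assumptions, not the method's actual implementation: the primary classifier is assumed to be a linear layer followed by a per-pixel softmax, the dimension reduction is assumed to be a plain projection matrix, and the category information is assumed to be one prototype vector per category, computed as the probability-weighted mean of the reduced features of one image (one common way of building an image-specific classifier). The names `forward`, `build_category_info`, `class_weights`, and `projection` are hypothetical.

```python
import math
from typing import List, Tuple

Vec = List[float]


def softmax(z: Vec) -> Vec:
    """Numerically stable softmax over one logit vector."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def matvec(rows: List[Vec], f: Vec) -> Vec:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(r_k * f_k for r_k, f_k in zip(row, f)) for row in rows]


def build_category_info(reduced: List[Vec], probs: List[Vec]) -> List[Vec]:
    """Assumed image-specific category information: for each category, the
    probability-weighted mean of the reduced features of all pixels."""
    n_classes, dim = len(probs[0]), len(reduced[0])
    prototypes = []
    for i in range(n_classes):
        weight = sum(p[i] for p in probs)
        acc = [0.0] * dim
        for f, p in zip(reduced, probs):
            for k in range(dim):
                acc[k] += p[i] * f[k]
        prototypes.append([a / max(weight, 1e-12) for a in acc])
    return prototypes


def forward(features: List[Vec], class_weights: List[Vec],
            projection: List[Vec]) -> Tuple[List[Vec], List[Vec], List[Vec]]:
    """Heads of one model on per-pixel features (H*W pixels, flattened).

    Returns the primary classification result, the dimension-reduced
    features, and the category information used by the secondary classifier."""
    primary = [softmax(matvec(class_weights, f)) for f in features]
    reduced = [matvec(projection, f) for f in features]
    return primary, reduced, build_category_info(reduced, primary)


# Two pixels with 2-D features, two categories, identity weights and projection.
feats = [[1.0, 0.0], [0.0, 1.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
primary, reduced, category_info = forward(feats, eye, eye)
```

Because the prototypes are recomputed from each image's own features, the secondary classifier changes from image to image, which is the property the description attributes to it.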
Step 24, constructing a target loss function of the student model from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information.
Specifically, for an ideal student model, the classification results and category information it produces for a sample image are identical or nearly identical to those of the teacher model, i.e., its classification predictions on the sample images reach an accuracy similar to that of the teacher model. Therefore, several loss functions can be constructed for the primary and secondary classifiers of the student model from the errors between the outputs of the teacher model and the student model, and these are combined into a total loss function that serves as the target loss function of the student model. The loss functions may be built from the Kullback-Leibler (KL) divergence, cosine similarity, cross entropy, and the like.
Step 25, performing iterative training on the student model to be trained according to the target loss function to obtain a trained student model; the trained student model is used for carrying out semantic segmentation on the input image.
Specifically, the server updates the weight parameters of the current student model based on the knowledge distillation loss, i.e., the target loss function, and iterates repeatedly until training is complete, yielding a student model whose predictions are the same as or close to those of the teacher model. The student model is a compressed version of the teacher model and can perform image segmentation on an input image; image segmentation divides the input image into regions whose pixels belong to the same category or the same instance and separates these regions from one another.
According to the above model training method and apparatus, computer device, and storage medium, a trained teacher model and a student model to be trained are acquired, where the teacher model and the student model each comprise a feature extractor, a primary classifier, and a secondary classifier, and the secondary classifier is created from the image features of each sample image output by the feature extractor; the sample image is input into the teacher model to obtain a first primary classification result output by the primary classifier of the teacher model, and a first secondary classification result and first category information output by the secondary classifier of the teacher model; the sample image is input into the student model to be trained to obtain a second primary classification result output by the primary classifier of the student model, and a second secondary classification result and second category information output by the secondary classifier of the student model; a target loss function of the student model is constructed from the first primary classification result, the first secondary classification result, the second primary classification result, the second secondary classification result, the first category information, and the second category information; and the student model to be trained is iteratively trained according to the target loss function to obtain a trained student model, which is used to perform semantic segmentation on an input image. Because secondary classifiers are provided in both the teacher model and the student model during training, the model compensates for learning the characteristic features of each individual sample image while fitting the global features well, which improves the efficiency of model training and ultimately the performance of the student model.
In one embodiment, as shown in FIG. 3, constructing the target loss function of the student model in step 24 includes:
step 31, constructing a first loss function based on the first primary classification result and the second primary classification result;
step 32, constructing a second loss function based on the first category information and the second category information;
step 33, constructing a third loss function based on the first secondary classification result and the second secondary classification result;
and step 34, determining the target loss function from the first loss function, the second loss function, the third loss function, and the cross-entropy loss function of the second primary classification result.
Specifically, the first loss function measures the overall difference between the classification results output by the primary classifiers of the teacher model and the student model. Constructing it from the first and second primary classification results therefore mainly adjusts the weight parameters of the primary classifier of the student model and reduces the difference between the primary classifiers of the student model and the teacher model. The second and third loss functions play the same role for the secondary classifiers. The cross-entropy loss function characterizes the difference between the ground-truth sample labels and the predicted probabilities.
The target loss function, determined from the first loss function, the second loss function, the third loss function, and the cross-entropy loss function of the second primary classification result, is the final loss used to optimize the weight parameters of the student model.
In this embodiment, the first loss function reduces the difference between the primary classifier of the student model and that of the teacher model; the second and third loss functions reduce the difference between the secondary classifier of the student model and that of the teacher model; and the cross-entropy loss function reduces the difference between the student model's predictions and the ground-truth sample labels. The target loss function determined from these four losses thus reduces the overall difference between the student model and the teacher model and improves the overall performance of the student model.
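The cross-entropy term L_ce of step 34 can be illustrated as follows. This is a minimal sketch assuming the usual per-pixel form: the student's primary output at each pixel is a probability vector, each pixel has one integer ground-truth label, and the loss is the mean negative log-probability of the true label over all H×W pixels (flattened to one list). The function name `pixel_cross_entropy` is hypothetical.

```python
import math
from typing import List


def pixel_cross_entropy(probs: List[List[float]], labels: List[int],
                        eps: float = 1e-12) -> float:
    """Assumed L_ce: mean over all H*W pixels of -log p_s(x)[y(x)], where
    probs[x] is the student's primary output at pixel x and labels[x] its
    ground-truth category. eps avoids log(0)."""
    if len(probs) != len(labels) or not probs:
        raise ValueError("need exactly one label per pixel")
    return -sum(math.log(p[y] + eps) for p, y in zip(probs, labels)) / len(probs)


# Two pixels, two categories: one confident correct pixel, one uncertain pixel.
loss = pixel_cross_entropy([[0.9, 0.1], [0.5, 0.5]], [0, 1])
```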
In one embodiment, step 31, the first loss function is obtained by:
$$L_{kd} = \frac{1}{H \times W}\sum_{x}\mathrm{KL}\left(p_t(x)\,\|\,p_s(x)\right)$$
where
L_kd is the first loss function;
H is the height of the sample image;
W is the width of the sample image;
x is a pixel in the sample image;
p_t is the first primary classification result;
p_s is the second primary classification result;
KL(·‖·) is the KL divergence function.
Specifically, the server obtains the spatial dimensions, i.e., the width and height, of the sample image, and invokes the KL divergence to measure the difference between the first primary classification result and the second primary classification result. Reducing the KL divergence between p_t and p_s encourages the student model to produce, through its primary classifier, a p_s similar to p_t.
By means of the KL divergence, this embodiment encourages the student model to output, through its primary classifier, classification results similar to those of the teacher model, which improves the final performance of the trained student model.
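The pixel-averaged KL divergence of step 31 can be sketched in plain Python as below. It assumes that each classification result is given as one class-probability vector per pixel (the H×W pixels flattened into a list) and that the divergence is taken in the direction KL(p_t(x) ‖ p_s(x)), the usual choice in distillation; the function names are hypothetical.

```python
import math
from typing import List

Dist = List[float]  # class-probability vector at one pixel


def kl_divergence(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """KL(p || q) for two discrete distributions over the same categories.
    Terms with p_i = 0 contribute nothing; eps guards against q_i = 0."""
    return sum(pi * math.log(pi / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def pixel_kd_loss(p_t: List[Dist], p_s: List[Dist]) -> float:
    """Assumed L_kd: KL(p_t(x) || p_s(x)) averaged over all H*W pixels x."""
    if len(p_t) != len(p_s) or not p_t:
        raise ValueError("teacher and student maps must cover the same pixels")
    return sum(kl_divergence(t, s) for t, s in zip(p_t, p_s)) / len(p_t)


teacher = [[1.0, 0.0], [0.5, 0.5]]  # H*W = 2 pixels, 2 categories
student = [[0.5, 0.5], [0.5, 0.5]]
l_kd = pixel_kd_loss(teacher, student)  # approximately (log 2 + 0) / 2
```

The same function applies unchanged to the third loss function of step 33 if the secondary classification results p_a,t and p_a,s are passed in place of p_t and p_s.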
In one embodiment, step 32, the second loss function is obtained by:
$$L_{rec} = \frac{1}{N}\sum_{i=1}^{N}\left(1 - \cos\left(c_i^{t}, c_i^{s}\right)\right)$$
where
L_rec is the second loss function; N is the number of categories contained in the sample image;
c_i^t denotes the i-th category in the first category information;
c_i^s denotes the i-th category in the second category information; and cos denotes cosine similarity.
Specifically, the server invokes cosine similarity to measure the difference between the first category information and the second category information, and corrects the secondary classifier of the student model by encouraging a high cosine similarity.
In this embodiment, the category information output by the secondary classifier of the student model is corrected by means of cosine similarity, which improves the final performance of the student model.
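A sketch of the second loss function is given below. It assumes that each piece of category information is a feature vector per category (one vector c_i for each of the N categories) and that the loss averages 1 − cos(c_i^t, c_i^s) over the categories, so that a higher cosine similarity gives a lower loss. Both the vector representation and the exact "1 − cos" form are assumptions; the function names are hypothetical.

```python
import math
from typing import List

Vec = List[float]


def cosine(a: Vec, b: Vec, eps: float = 1e-12) -> float:
    """Cosine similarity between two vectors; eps guards zero-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def category_rec_loss(info_t: List[Vec], info_s: List[Vec]) -> float:
    """Assumed L_rec: mean over the N categories of 1 - cos(c_i^t, c_i^s)."""
    if len(info_t) != len(info_s) or not info_t:
        raise ValueError("teacher and student must describe the same N categories")
    return sum(1.0 - cosine(t, s) for t, s in zip(info_t, info_s)) / len(info_t)


teacher_info = [[1.0, 0.0], [0.0, 1.0]]  # N = 2 category vectors
student_info = [[2.0, 0.0], [1.0, 0.0]]  # same direction; orthogonal
l_rec = category_rec_loss(teacher_info, student_info)  # (0 + 1) / 2
```

Cosine similarity ignores vector length, so in this sketch the student is only pushed to match the direction of each teacher category vector, not its magnitude.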
In one embodiment, step 33, the third loss function is obtained by:
$$L_{ob} = \frac{1}{H \times W}\sum_{x}\mathrm{KL}\left(p_{a,t}(x)\,\|\,p_{a,s}(x)\right)$$
where
L_ob is the third loss function;
H is the height of the sample image;
W is the width of the sample image;
x is a pixel in the sample image;
p_a,t is the first secondary classification result;
p_a,s is the second secondary classification result;
KL(·‖·) is the KL divergence function.
Specifically, the server obtains the spatial dimensions, i.e., the width and height, of the sample image, and invokes the KL divergence to measure the difference between the first secondary classification result and the second secondary classification result. Reducing the KL divergence between p_a,t and p_a,s encourages the student model to produce, through its secondary classifier, a p_a,s similar to p_a,t.
By means of the KL divergence, this embodiment encourages the student model to output, through its secondary classifier, classification results similar to those of the teacher model, which improves the final performance of the student model.
In one embodiment, step 34, the target loss function is obtained by:
$$L_s = L_{ce} + \lambda_{kd} L_{kd} + \lambda_{rec} L_{rec} + L_{ob}$$
where
L_s is the target loss function;
L_ce is the cross-entropy loss function of the second primary classification result;
L_kd is the first loss function;
L_rec is the second loss function;
L_ob is the third loss function;
λ_kd is the hyperparameter of the first loss function;
λ_rec is the hyperparameter of the second loss function.
Specifically, the hyperparameters of the first and second loss functions may be set to fixed values based on experience and adjusted during training; for example, the hyperparameter of the first loss function and that of the second loss function may each be set to 10 or to another value.
Setting the hyperparameters of the first and second loss functions improves the stability with which the model processes sample images and thus improves the performance of the student model.
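Combining the four terms into the target loss can be sketched as follows. The additive form is an assumption inferred from the symbol definitions: only the first and second loss functions are described as having hyperparameters, so the third loss function L_ob is added with weight 1 here. The default weights of 10 follow the example value given above; the function and parameter names are hypothetical.

```python
def target_loss(l_ce: float, l_kd: float, l_rec: float, l_ob: float,
                lambda_kd: float = 10.0, lambda_rec: float = 10.0) -> float:
    """Assumed form: L_s = L_ce + lambda_kd * L_kd + lambda_rec * L_rec + L_ob.
    L_ob is unweighted because hyperparameters are named only for the first
    and second loss functions."""
    return l_ce + lambda_kd * l_kd + lambda_rec * l_rec + l_ob


l_s = target_loss(l_ce=0.7, l_kd=0.05, l_rec=0.02, l_ob=0.1)  # 0.7 + 0.5 + 0.2 + 0.1
```

With weights of 10, small distillation errors in L_kd and L_rec still contribute noticeably relative to L_ce, which is one plausible reason for choosing values well above 1.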
In one embodiment, the output of the secondary classifier may be obtained as follows:
$$p_{a,t}(x)_i = \frac{\exp\left(\cos\left(f_t(x), c_i^{t}\right)/\tau\right)}{\sum_{j=1}^{N}\exp\left(\cos\left(f_t(x), c_j^{t}\right)/\tau\right)}, \qquad p_{a,s}(x)_i = \frac{\exp\left(\cos\left(f_s(x), c_i^{s}\right)/\tau\right)}{\sum_{j=1}^{N}\exp\left(\cos\left(f_s(x), c_j^{s}\right)/\tau\right)}$$
where
p_a,t(x)_i and p_a,s(x)_i are the probabilities of the i-th category at pixel x in the first and second secondary classification results, respectively;
τ may be set to a constant value, such as 0.1; cos denotes cosine similarity, and exp(x) denotes e^x; N is the number of categories contained in the sample image;
c_i^t denotes the i-th category in the first category information;
c_i^s denotes the i-th category in the second category information;
c_j^t denotes category j among the N pieces of category information of the teacher model;
c_j^s denotes category j among the N pieces of category information of the student model;
f_t(x) denotes the feature information in the teacher model;
f_s(x) denotes the feature information in the student model.
In this embodiment, the secondary classifiers of the teacher model and the student model output the category information and the classification results, which improves the final performance of the student model.
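The secondary classifier output can be sketched as a softmax over scaled cosine similarities between each pixel's feature and each category vector. This sketch assumes the form written above, with τ = 0.1 as in the example value; the features and category vectors are plain Python lists, and `secondary_classify` is a hypothetical name. The same function serves the teacher (f_t, c^t) and the student (f_s, c^s).

```python
import math
from typing import List

Vec = List[float]


def cosine(a: Vec, b: Vec, eps: float = 1e-12) -> float:
    """Cosine similarity between two vectors; eps guards zero-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def secondary_classify(features: List[Vec], category_info: List[Vec],
                       tau: float = 0.1) -> List[Vec]:
    """Per pixel x: p_a(x)_i = exp(cos(f(x), c_i)/tau) / sum_j exp(cos(f(x), c_j)/tau)."""
    result = []
    for f in features:
        logits = [cosine(f, c) / tau for c in category_info]
        m = max(logits)  # subtract the max before exp for numerical stability
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


feats = [[1.0, 0.0], [0.6, 0.8]]   # two pixels
protos = [[1.0, 0.0], [0.0, 1.0]]  # N = 2 category vectors
p_a = secondary_classify(feats, protos)
```

Since cosine similarity lies in [-1, 1], dividing by a small τ such as 0.1 widens the logit range to [-10, 10], which lets the softmax produce confident predictions.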
In one embodiment, to ensure that the teacher model generates more accurate slave classifiers and corresponding classification results, the slave classifier of the teacher model also needs to be constrained; the constraint is the following optimization loss function of the teacher model:
$$L_t = -\frac{1}{H \times W}\sum_{x}\log\frac{\exp\left(\cos\left(f^t(x), p^t_{c(x)}\right)/\tau\right)}{\sum_{j=1}^{N}\exp\left(\cos\left(f^t(x), p^t_j\right)/\tau\right)}$$
where
$L_t$ denotes the optimization loss function of the teacher model;
$H$ is the height of the sample image;
$W$ is the width of the sample image;
$c(x)$ denotes the category at position $x$;
$p^t_j$ denotes category $j$ among the $N$ pieces of category information of the teacher model;
$\tau$ may be set to a constant value, such as 0.1;
$f^t$ denotes the feature information in the teacher model.
In this embodiment, the optimization loss function constrains the output of the teacher model, so that the teacher model can better supervise the student model.
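As a sketch, the teacher's optimization loss can be read as a per-pixel cross-entropy between the teacher's slave-classifier output and the category $c(x)$ at each position, averaged over the $H \times W$ positions. This reading, the array layout and the function name are assumptions for illustration.

```python
import numpy as np

def teacher_slave_loss(features: np.ndarray, prototypes: np.ndarray,
                       labels: np.ndarray, tau: float = 0.1) -> float:
    """Assumed L_t: mean over H*W positions of -log softmax(cos(f^t(x), p_j)/tau)[c(x)].

    features:   (H, W, C) teacher feature map f^t
    prototypes: (N, C) teacher category information p^t_1..p^t_N
    labels:     (H, W) integer category c(x) in [0, N) at each position
    """
    H, W, C = features.shape
    f = features.reshape(-1, C)
    f = f / np.maximum(np.linalg.norm(f, axis=1, keepdims=True), 1e-12)
    p = prototypes / np.maximum(np.linalg.norm(prototypes, axis=1, keepdims=True), 1e-12)
    logits = (f @ p.T) / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    y = labels.reshape(-1)
    # pick log-probability of the labelled category at every position, average over H*W
    return float(-log_probs[np.arange(H * W), y].mean())
```

A teacher whose features align with the category information of the labelled class yields a loss close to zero; misaligned features are penalised heavily.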
In one embodiment, step 35, iteratively training the student model to be trained according to the target loss function to obtain a trained student model, includes: iteratively training the student model to be trained with the target loss function; and, when the number of training iterations of the student model reaches a preset number, taking the student model at that point as the trained student model.
In this embodiment, training of the student model is bounded by a preset number of training iterations, so that the student model completes training once it reaches that number of iterations, which improves the final performance of the student model.
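The stopping rule can be sketched with a generic gradient-descent loop that halts after a preset number of iterations. The parameters, the gradient callable and the learning rate are hypothetical stand-ins for the student model, the gradient of the target loss and the optimizer; only the stopping criterion mirrors the text.

```python
from typing import Callable, List, Tuple

def train_student(
    params: List[float],
    grad_fn: Callable[[List[float]], List[float]],
    lr: float = 0.1,
    preset_iterations: int = 100,
) -> Tuple[List[float], int]:
    """Iterate updates on the target loss and stop at a preset iteration count.

    params:  hypothetical student parameters
    grad_fn: hypothetical callable returning dL_s/dparams at the current params
    """
    iteration = 0
    while iteration < preset_iterations:
        grads = grad_fn(params)  # gradient of the target loss L_s
        params = [p - lr * g for p, g in zip(params, grads)]
        iteration += 1
    # The preset number of iterations has been reached: treat this model as trained.
    return params, iteration
```

For a toy loss $\sum_k (w_k - t_k)^2$, `grad_fn` would return `[2 * (w - t) for w, t in zip(params, targets)]`, and after 100 iterations with `lr=0.1` the parameters sit very close to the targets.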
In one embodiment, yet another model training method is provided, which may be represented schematically in FIG. 4.
Specifically, the server inputs the sample images into both the teacher model and the student model. The feature extractor of the teacher model extracts features from the sample image; these features are processed by the master classifier to obtain a first master classification result and, in parallel, by the slave classifier to obtain first category information and a first slave classification result. Similarly, the feature extractor of the student model extracts features from the sample image; these features are processed by the master classifier to obtain a second master classification result and, in parallel, by the slave classifier to obtain second category information and a second slave classification result. The server then constructs a first loss function from the first and second master classification results, a second loss function from the first and second category information, and a third loss function from the first and second slave classification results, and combines these with the cross-entropy loss function of the second master classification result to obtain the target loss function. The student model to be trained is trained iteratively with this target loss function until its number of training iterations reaches the preset number, which completes the training of the student model.
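The text states that each slave classifier is created from the image features of the individual sample image but does not say how. One common way to do this, shown here purely as an assumed illustration, is masked average pooling: the features at all positions belonging to a category are averaged, and each average serves as one piece of category information. The label map used for pooling (ground-truth mask or the model's own prediction) and the function name are assumptions.

```python
import numpy as np

def build_slave_classifier(features: np.ndarray, labels: np.ndarray,
                           num_classes: int):
    """Assumed masked-average-pooling construction of per-image category information.

    features: (H, W, C) feature map from the feature extractor
    labels:   (H, W) integer label map; values outside [0, num_classes) are ignored
    returns:  (prototypes of shape (K, C), class_ids of shape (K,)) for the
              K categories present in this sample image
    """
    C = features.shape[-1]
    f = features.reshape(-1, C)
    y = labels.reshape(-1)
    class_ids = np.unique(y)
    class_ids = class_ids[(class_ids >= 0) & (class_ids < num_classes)]
    # average the features of every position assigned to each present category
    prototypes = np.stack([f[y == k].mean(axis=0) for k in class_ids])
    return prototypes, class_ids
```

The resulting rows can be fed to a cosine-similarity classifier, so that the number of categories N matches the categories actually contained in the sample image.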
It should be understood that although the steps in the flowcharts of FIGS. 2 and 3 are shown in the sequence indicated by the arrows, they are not necessarily performed in that sequence. Unless explicitly stated herein, there is no strict restriction on the order in which these steps are performed, and they may be performed in other orders. Moreover, at least some of the steps in FIGS. 2 and 3 may include multiple sub-steps or stages, which are not necessarily performed at the same time but may be performed at different times; nor are they necessarily performed sequentially, but may be performed in turn or alternately with other steps or with at least some of the sub-steps or stages of other steps.
In one embodiment, as shown in fig. 5, there is provided a model training apparatus including: a model obtaining module 51, a first classification module 52, a second classification module 53, a loss function building module 54, and a model training module 55, wherein:
a model obtaining module 51, configured to obtain a trained teacher model and a student model to be trained; the teacher model and the student model each comprise a feature extractor, a master classifier and a slave classifier; the slave classifier is created from the image features of each sample image output by the feature extractor;
a first classification module 52, configured to input the sample image into the teacher model, and obtain a first master classification result output by a master classifier of the teacher model, and a first slave classification result and first class information output by a slave classifier of the teacher model;
a second classification module 53, configured to input the sample image into the student model to be trained, to obtain a second master classification result output by the master classifier of the student model to be trained, and a second slave classification result and second class information output by the slave classifier of the student model to be trained;
the loss function building module 54 is configured to build a target loss function of the student model according to the first master classification result, the first slave classification result, the second master classification result, the second slave classification result, the first class information, and the second class information;
the model training module 55 is used for performing iterative training on the student model to be trained according to the target loss function to obtain a trained student model; the trained student model is used for carrying out semantic segmentation on the input image.
In one embodiment, the loss function building module 54 is further configured to construct a first loss function based on the first master classification result and the second master classification result; construct a second loss function based on the first category information and the second category information; construct a third loss function based on the first slave classification result and the second slave classification result; and determine a target loss function according to the first loss function, the second loss function, the third loss function and the cross-entropy loss function of the second master classification result.
In one embodiment, the model training module 55 is further configured to iteratively train the student model to be trained with the target loss function, and, when the number of training iterations of the student model reaches a preset number, to take the student model at that point as the trained student model.
For specific limitations of the model training apparatus, reference may be made to the above limitations of the model training method, which are not repeated here. Each module in the model training apparatus may be implemented wholly or partly by software, hardware, or a combination thereof. The modules may be embedded in, or independent of, a processor of the computer device in hardware form, or stored in a memory of the computer device in software form, so that the processor can invoke and execute the operations corresponding to each module.
In one embodiment, a computer device is provided, which may be a server, and its internal structure diagram may be as shown in fig. 6. The computer device includes a processor, a memory, and a network interface connected by a system bus. Wherein the processor of the computer device is configured to provide computing and control capabilities. The memory of the computer device comprises a nonvolatile storage medium and an internal memory. The non-volatile storage medium stores an operating system, a computer program, and a database. The internal memory provides an environment for the operation of an operating system and computer programs in the non-volatile storage medium. The database of the computer device is used to store model training data. The network interface of the computer device is used for communicating with an external terminal through a network connection. The computer program is executed by a processor to implement a model training method.
Those skilled in the art will appreciate that the architecture shown in FIG. 6 is merely a block diagram of part of the structure related to the disclosed solution and does not limit the computer devices to which the disclosed solution applies; a particular computer device may include more or fewer components than shown, combine certain components, or have a different arrangement of components.
In one embodiment, a computer device is provided, comprising a memory and a processor, the memory having a computer program stored therein, the processor implementing the following steps when executing the computer program:
acquiring a trained teacher model and a student model to be trained; the teacher model and the student model each comprise a feature extractor, a master classifier and a slave classifier; the slave classifier is created from the image features of each sample image output by the feature extractor;
inputting the sample image into the teacher model to obtain a first master classification result output by the master classifier of the teacher model, and a first slave classification result and first category information output by the slave classifier of the teacher model;
inputting the sample image into the student model to be trained to obtain a second master classification result output by the master classifier of the student model to be trained, and a second slave classification result and second category information output by the slave classifier of the student model to be trained;
constructing a target loss function of the student model according to the first master classification result, the first slave classification result, the second master classification result, the second slave classification result, the first category information and the second category information;
performing iterative training on the student model to be trained according to the target loss function to obtain a trained student model; the trained student model is used for carrying out semantic segmentation on the input image.
In one embodiment, the processor, when executing the computer program, further performs the steps of:
constructing a first loss function based on the first master classification result and the second master classification result;
constructing a second loss function based on the first category information and the second category information;
constructing a third loss function based on the first slave classification result and the second slave classification result;
and determining a target loss function according to the first loss function, the second loss function, the third loss function and the cross entropy loss function of the second main classification result.
In one embodiment, the processor, when executing the computer program, further performs the steps of: iteratively training the student model to be trained with the target loss function; and, when the number of training iterations of the student model reaches a preset number, taking the student model at that point as the trained student model.
In one embodiment, a computer-readable storage medium is provided, having a computer program stored thereon, which when executed by a processor, performs the steps of:
acquiring a trained teacher model and a student model to be trained; the teacher model and the student model each comprise a feature extractor, a master classifier and a slave classifier; the slave classifier is created from the image features of each sample image output by the feature extractor;
inputting the sample image into the teacher model to obtain a first master classification result output by the master classifier of the teacher model, and a first slave classification result and first category information output by the slave classifier of the teacher model;
inputting the sample image into the student model to be trained to obtain a second master classification result output by the master classifier of the student model to be trained, and a second slave classification result and second category information output by the slave classifier of the student model to be trained;
constructing a target loss function of the student model according to the first master classification result, the first slave classification result, the second master classification result, the second slave classification result, the first category information and the second category information;
performing iterative training on the student model to be trained according to the target loss function to obtain a trained student model; the trained student model is used for carrying out semantic segmentation on the input image.
In one embodiment, the computer program when executed by the processor further performs the steps of:
constructing a first loss function based on the first master classification result and the second master classification result;
constructing a second loss function based on the first category information and the second category information;
constructing a third loss function based on the first slave classification result and the second slave classification result;
and determining a target loss function according to the first loss function, the second loss function, the third loss function and the cross entropy loss function of the second main classification result.
In one embodiment, the computer program, when executed by the processor, further performs the steps of: iteratively training the student model to be trained with the target loss function; and, when the number of training iterations of the student model reaches a preset number, taking the student model at that point as the trained student model.
It will be understood by those skilled in the art that all or part of the processes of the methods in the above embodiments can be implemented by a computer program instructing relevant hardware; the computer program can be stored in a non-volatile computer-readable storage medium and, when executed, can include the processes of the embodiments of the methods described above. Any reference to memory, storage, database or other media used in the embodiments provided herein can include at least one of non-volatile and volatile memory. Non-volatile memory may include read-only memory (ROM), magnetic tape, floppy disk, flash memory, optical memory, or the like. Volatile memory may include random access memory (RAM) or external cache memory. By way of illustration and not limitation, RAM can take many forms, such as static random access memory (SRAM) or dynamic random access memory (DRAM), among others.
The technical features of the above embodiments can be combined arbitrarily. For brevity, not all possible combinations of these technical features are described; however, as long as a combination of technical features contains no contradiction, it should be considered within the scope of this specification.
The above embodiments express only several implementations of the present application; although they are described specifically and in detail, they shall not be construed as limiting the scope of the invention. It should be noted that a person skilled in the art may make several variations and improvements without departing from the concept of the present application, and these all fall within the scope of protection of the present application. Therefore, the scope of protection of this patent shall be subject to the appended claims.