Detailed Description
In order to make the technical solutions of the present invention better understood, the technical solutions in the embodiments of the present invention will be clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are only a part of the embodiments of the present invention, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
It should be noted that the terms "first," "second," and the like in the description and claims of the present invention and in the drawings described above are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used is interchangeable under appropriate circumstances such that the embodiments of the invention described herein are capable of operation in sequences other than those illustrated or described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed, but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
Example 1
According to an embodiment of the present invention, a target classification model training method based on unlabeled samples is provided. As shown in Fig. 1, the method includes:
S102, labeling an unlabeled sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
In this embodiment, the training samples in the sample data set include labeled samples and unlabeled samples, where the labeled samples are data with labels and the unlabeled samples are data without labels. The unlabeled samples include outlier samples.
It should be noted that the target classification model includes, but is not limited to, an image classification model, an image recognition model, a voice recognition model, a character recognition model, and the like. In this embodiment, the specific model structure and the use of the target classification model are not limited at all.
In one example, the target classification model is an image classification model used for identifying an animal image and determining the corresponding category of the animal image. The sample data set corresponding to the image classification model comprises animal images, vegetable images, fruit images and building images. The animal images with labels are labeled samples, and the images without labels are unlabeled samples. The vegetable images, fruit images and building images among the unlabeled samples are outlier samples.
In this embodiment, the labeled samples in the sample data set are obtained, and the target classification model is trained on them to obtain the first classification model. The unlabeled samples in the sample data set are then identified by the first classification model to obtain a label for each unlabeled sample, and the unlabeled samples are labeled with the identified labels to obtain the pseudo label sample set.
Optionally, in this embodiment, labeling the unlabeled sample set in the sample data set based on the first classification model includes, but is not limited to: predicting a plurality of unlabeled samples in the unlabeled sample set according to the first classification model to obtain a plurality of pseudo labels corresponding to the unlabeled samples; and labeling the plurality of unlabeled samples with the corresponding pseudo labels to obtain the pseudo label sample set.
In this embodiment, the data structure of the sample data in the sample data set is < sample feature, label >, and the label of the unlabeled sample is null.
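The < sample feature, label > structure can be sketched as follows. This is a minimal illustration only; the names `Sample` and `split_by_label` and the toy feature values are hypothetical and do not come from the embodiment.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Sample:
    """One training sample: <sample feature, label>; label is None when unlabeled."""
    features: List[float]
    label: Optional[str] = None


def split_by_label(dataset: List[Sample]) -> Tuple[List[Sample], List[Sample]]:
    """Split a sample data set into its labeled and unlabeled subsets."""
    labeled = [s for s in dataset if s.label is not None]
    unlabeled = [s for s in dataset if s.label is None]
    return labeled, unlabeled


# Toy data (assumed values): two labeled samples and one unlabeled sample.
dataset = [
    Sample([0.9, 0.1], "cat"),
    Sample([0.2, 0.8], "dog"),
    Sample([0.5, 0.5]),  # unlabeled; may or may not be an outlier
]
labeled, unlabeled = split_by_label(dataset)
```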
The target classification model is trained on the labeled samples to obtain the first classification model. Because the number of labeled training samples is small, the first classification model is prone to overfitting; although it can predict a label for an input sample, its predictions contain errors.
In this embodiment, a preliminary prediction is performed on an unlabeled sample through the first classification model to obtain a prediction result, and a pseudo label of the unlabeled sample is constructed from the prediction result. The unlabeled sample is then labeled with the pseudo label, so that its data structure becomes < sample feature, pseudo label >. In the same way, the first classification model predicts each unlabeled sample in the unlabeled sample set, and each unlabeled sample is labeled based on its prediction result to obtain the pseudo label sample set.
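The pseudo-labeling step can be sketched as below, assuming the first classification model exposes a function that returns one probability per class. The names `pseudo_label` and `toy_first_classifier` are hypothetical, and the toy classifier is only a stand-in for a trained model.

```python
from typing import Callable, List, Sequence, Tuple


def pseudo_label(
    unlabeled_features: Sequence[Sequence[float]],
    predict_proba: Callable[[Sequence[float]], Sequence[float]],
    classes: Sequence[str],
) -> List[Tuple[Sequence[float], str]]:
    """Attach the most probable class to each unlabeled sample as its pseudo label."""
    pseudo_set = []
    for features in unlabeled_features:
        probs = predict_proba(features)
        best = max(range(len(classes)), key=lambda j: probs[j])
        pseudo_set.append((features, classes[best]))
    return pseudo_set


def toy_first_classifier(features: Sequence[float]) -> List[float]:
    """Hypothetical stand-in: normalises two non-negative features into probabilities."""
    total = sum(features)
    return [f / total for f in features]


pseudo_set = pseudo_label(
    [[0.9, 0.1], [0.3, 0.7]], toy_first_classifier, ["cat", "dog"]
)
```

Taking the arg-max class is one straightforward way to build a pseudo label from a prediction result; the embodiment does not restrict how the pseudo label is constructed.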
Specifically, in the warm-up training stage, for each labeled sample, the backbone network $g_\theta$ in the target classification model first extracts the sample features, which are input into the classifier in the target classification model to obtain a category prediction $p$. The classifier is then trained using a cross entropy loss function, calculated as follows:
$L_{ce} = -\ln(p[y])$
where $y$ is the class label of the sample, and $p[y]$ represents the probability that the model predicts for the true class of the sample. After warm-up training is completed, a pseudo class label is predicted for each unlabeled sample using the classifier of the model.
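The cross entropy loss above can be evaluated numerically as in the following sketch. The softmax step and the example logits are assumptions added for illustration; the embodiment only states that the classifier outputs a category prediction p.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw classifier outputs into a probability distribution p."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(p: Sequence[float], y: int) -> float:
    """L_ce = -ln(p[y]), where y is the index of the true class."""
    return -math.log(p[y])


p = softmax([2.0, 0.5, 0.1])  # assumed logits for a 3-class toy problem
loss = cross_entropy(p, 0)    # loss when class 0 is the true class
```

The loss is small when the model assigns high probability to the true class and grows as that probability approaches zero.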
In one example, the target classification model is assumed to be an image classification model for identifying an image of an animal to predict the animal species. The sample data set corresponding to the image classification model comprises animal images, vegetable images, fruit images and building images. The animal images with labels are labeled samples, and the images without labels are unlabeled samples. The vegetable images, fruit images and building images among the unlabeled samples are outlier samples. The training samples in the sample data set have the structure < image feature vector, label 'animal species' >. The target classification model is trained on the labeled animal images to obtain a first classifier, the unlabeled samples are preliminarily predicted by the first classifier, and the unlabeled samples are labeled based on the prediction results to obtain a pseudo label sample set. The data structure of the samples in the pseudo label sample set is < image feature vector, label 'animal species' >.
Through the embodiment, the multiple unlabeled samples in the unlabeled sample set are predicted according to the first classification model, and the unlabeled samples are labeled based on the prediction result to obtain the pseudo label samples labeled by the pseudo labels, so that the preliminary classification of the unlabeled samples is realized.
S104, screening outlier samples in the sample data set according to the multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained by training based on the labeled sample set;
Because the first classification model is trained on a small number of labeled samples, it is prone to overfitting; although it can predict a label for an input sample, the prediction may be wrong. That is, the labels of the samples in the pseudo label sample set obtained from the first classification model may be incorrect.
Therefore, outlier samples in the sample data set are screened out through the multi-modal matching model and the pseudo label sample set, which reduces the distribution difference between in-distribution samples and outlier samples. The multi-modal matching model is used for determining whether the sample features of each sample in the pseudo label sample set match its pseudo label, and the outlier samples in the sample data set are then identified according to the matching result.
Optionally, in this embodiment, the outlier samples in the sample data set are filtered according to the multi-modal matching model and the pseudo label sample set, which includes but is not limited to: determining outlier samples in the sample data set according to the multi-modal matching model and the pseudo label sample set; and rejecting outlier samples in the sample data set.
The multi-modal matching model in this embodiment is used to determine whether the sample features of a sample in the pseudo label sample set match its pseudo label. The input of the multi-modal matching model includes the sample feature vector corresponding to the sample in the pseudo label sample set and the label embedding vector corresponding to the pseudo label. If the sample features match the label, the sample is an in-distribution sample; if the sample features do not match the label, the training sample corresponding to the pseudo label sample is an outlier sample, and it is removed from the sample data set.
Further optionally, in this embodiment, determining outlier samples in the sample data set according to the multi-modal matching model and the pseudo label sample set includes, but is not limited to: inputting the label feature vectors corresponding to the pseudo labels of the pseudo label samples in the pseudo label sample set and the image feature vectors of the pseudo label samples into the multi-modal matching model; if it is determined according to the multi-modal matching model that the label feature vector does not match the image feature vector, determining that the unlabeled sample corresponding to the pseudo label sample is an outlier sample; otherwise, determining that the unlabeled sample corresponding to the pseudo label sample is a non-outlier sample.
In a specific example, assume that the target classification model is an image classification model used for identifying an animal image to predict the animal species. If it is determined according to the multi-modal matching model that the label feature vector does not match the image feature vector, the unlabeled sample corresponding to the pseudo label sample is determined to be an outlier sample; otherwise, it is determined to be a non-outlier sample.
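The screening rule can be sketched as follows. The embodiment does not state how the match / no-match decision is derived from the model output, so the threshold of 0.5 on a matching score is an assumption, and `toy_match_score` with its class prototypes is a hypothetical stand-in for a trained multi-modal matching model.

```python
from typing import Callable, Dict, List, Sequence, Tuple

PseudoSample = Tuple[Sequence[float], str]


def screen_outliers(
    pseudo_set: Sequence[PseudoSample],
    match_score: Callable[[Sequence[float], str], float],
    threshold: float = 0.5,  # assumed decision threshold, not from the embodiment
) -> Tuple[List[PseudoSample], List[PseudoSample]]:
    """Split pseudo label samples into non-outliers (match) and outliers (no match)."""
    inliers: List[PseudoSample] = []
    outliers: List[PseudoSample] = []
    for features, label in pseudo_set:
        if match_score(features, label) >= threshold:
            inliers.append((features, label))
        else:
            outliers.append((features, label))
    return inliers, outliers


# Hypothetical stand-in for a trained matching model: one prototype per class.
PROTOTYPES: Dict[str, List[float]] = {"cat": [1.0, 0.0], "dog": [0.0, 1.0]}


def toy_match_score(features: Sequence[float], label: str) -> float:
    """Score in (0, 1]; close to 1 when the features lie near the class prototype."""
    proto = PROTOTYPES[label]
    dist_sq = sum((f - c) ** 2 for f, c in zip(features, proto))
    return 1.0 / (1.0 + dist_sq)


pseudo_set = [([0.9, 0.1], "cat"), ([5.0, 5.0], "dog")]
inliers, outliers = screen_outliers(pseudo_set, toy_match_score)
```

In this toy run the second sample lies far from every prototype, so it fails to match its pseudo label and is treated as an outlier.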
Specifically, the multi-modal matching model is introduced first:
The input of the multi-modal matching model comprises the sample feature vector corresponding to a sample in the pseudo label sample set and the label embedding vector corresponding to the pseudo label, and the data structure of the input data is < sample feature vector, label embedding vector >. The positive samples for training the multi-modal matching model can be obtained directly from the labeled samples in the sample data set; that is, the sample features and the corresponding label form a positive pair. A negative sample consists of the sample features paired with a wrong label.
In particular, the multi-modal matching model is a multi-modal matching discriminator implemented as a multi-layer perceptron with a hidden layer, and it is trained on the labeled samples in the sample data set. Its input is the concatenation of the feature vector $g_\theta(x)$ corresponding to a sample $x$ and a class embedding vector $g_\phi(y)$, and its output is a matching score $S(x, y)$ with a value between 0 and 1. First, the image classifier is used to predict the sample $x$ to obtain a class prediction $p$.
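A forward pass of such a matching discriminator can be sketched as follows. The ReLU activation, the sigmoid output, the hidden size of 8 and the random initialisation are all assumptions made for illustration; the embodiment only specifies a multi-layer perceptron with a hidden layer whose output lies between 0 and 1. The class name `MatchingDiscriminator` is hypothetical.

```python
import math
import random
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class MatchingDiscriminator:
    """MLP with one hidden layer that scores a (feature vector, label embedding) pair."""

    def __init__(self, feat_dim: int, embed_dim: int, hidden_dim: int = 8, seed: int = 0):
        rng = random.Random(seed)
        in_dim = feat_dim + embed_dim
        # Untrained random weights; a real model would learn these from labeled samples.
        self.w1 = [[rng.uniform(-0.5, 0.5) for _ in range(in_dim)] for _ in range(hidden_dim)]
        self.b1 = [0.0] * hidden_dim
        self.w2 = [rng.uniform(-0.5, 0.5) for _ in range(hidden_dim)]
        self.b2 = 0.0

    def score(self, feature: Sequence[float], label_embedding: Sequence[float]) -> float:
        """Return the matching score S(x, y) in (0, 1)."""
        z = list(feature) + list(label_embedding)  # concatenate g_theta(x) and g_phi(y)
        hidden: List[float] = [
            max(0.0, sum(w * v for w, v in zip(row, z)) + b)  # ReLU (assumed)
            for row, b in zip(self.w1, self.b1)
        ]
        return sigmoid(sum(w * h for w, h in zip(self.w2, hidden)) + self.b2)


discriminator = MatchingDiscriminator(feat_dim=3, embed_dim=2)
s = discriminator.score([0.1, 0.2, 0.3], [1.0, 0.0])
```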
In the training process of the multi-modal matching model, the following two strategies can be adopted when constructing the training samples:
1) Using a hard sample mining strategy.
Specifically, the label $y_h$ that has the highest probability in the class prediction of the sample but is not the true label $y$ is determined. That is, the true label of the sample is $y$, but among the other labels, $y_h$ has the highest predicted probability.
2) Randomly selecting from the category set one label $y_s$ that is equal to neither the true label $y$ nor the label $y_h$. Specifically, a label $y_s$ whose predicted probability is lower than that of the label $y_h$ and which is not the true label $y$ is randomly acquired.
Then, based on the label $y$, the label $y_h$ and the label $y_s$, each labeled sample is used to construct three samples. In one example, based on a sample $X$ and the three labels, the three samples are < X, y >, < X, y_h > and < X, y_s >.
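The construction of the three pairs can be sketched as follows, assuming labels are represented by integer indices into the class prediction p and that there are at least three classes. The function name `build_matching_triplet` and the 1/0 match targets are illustrative assumptions.

```python
import random
from typing import Any, List, Sequence, Tuple


def build_matching_triplet(
    x: Any, y: int, p: Sequence[float], rng: random.Random
) -> List[Tuple[Any, int, int]]:
    """Build <X, y>, <X, y_h>, <X, y_s> as (sample, label, target) tuples.

    target is 1 for the positive pair and 0 for the two negative pairs.
    Requires at least three classes so that y, y_h and y_s are distinct.
    """
    candidates = [j for j in range(len(p)) if j != y]
    y_h = max(candidates, key=lambda j: p[j])      # hard negative: most probable wrong label
    remaining = [j for j in candidates if j != y_h]
    y_s = rng.choice(remaining)                    # random negative: neither y nor y_h
    return [(x, y, 1), (x, y_h, 0), (x, y_s, 0)]


# Toy class prediction over four classes; the true label is class 0.
triplet = build_matching_triplet("X", 0, [0.5, 0.3, 0.15, 0.05], random.Random(0))
```

Because $y_h$ is the most probable wrong label, any $y_s$ drawn from the remaining labels has a predicted probability no higher than that of $y_h$, which is consistent with strategy 2).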
The labeled samples in the sample data set are constructed in the above manner to obtain a matching sample set. After the three samples are obtained, a binary cross entropy loss function is used to train the multi-modal matching discriminator.
The binary cross entropy function is:
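One plausible form of this loss, assuming the positive pair < X, y > has target 1, the two negative pairs < X, y_h > and < X, y_s > have target 0, and the three terms are simply summed, is sketched below. The summation (rather than averaging or weighting), the function name `matching_bce` and the small `eps` guard are assumptions.

```python
import math


def matching_bce(s_pos: float, s_hard: float, s_rand: float, eps: float = 1e-12) -> float:
    """Binary cross entropy over one triplet of matching scores.

    s_pos  = S(x, y)    score of the positive pair (target 1)
    s_hard = S(x, y_h)  score of the hard negative pair (target 0)
    s_rand = S(x, y_s)  score of the random negative pair (target 0)
    eps guards against log(0).
    """
    return -(
        math.log(s_pos + eps)
        + math.log(1.0 - s_hard + eps)
        + math.log(1.0 - s_rand + eps)
    )


good = matching_bce(0.9, 0.1, 0.1)   # discriminator separates the pairs well
poor = matching_bce(0.5, 0.5, 0.5)   # discriminator is undecided
```

Under these assumptions the loss decreases as the positive score approaches 1 and both negative scores approach 0.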
The warm-up training is performed a preset number of times, and the preset number can be set according to actual needs.
To further improve the performance of the multi-modal matching discriminator, unlabeled samples are also added to the training. Because these samples have no real labels, an entropy minimization method is used to let the model train itself.
For each unlabeled sample, the label with the highest class prediction probability from the model is selected, and one label is randomly selected from the remaining labels. The entropy minimization is calculated as follows:
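As a generic illustration of entropy minimization, and assuming that the quantity being minimized is the binary entropy of a matching score $S$ for a selected (sample, label) pair, a sketch is given below. Both that assumption and the function name `binary_entropy` are hypothetical; how the scores of the two selected labels are combined is not specified in this sketch.

```python
import math


def binary_entropy(s: float, eps: float = 1e-12) -> float:
    """H(S) = -S ln S - (1 - S) ln(1 - S) for a matching score S in [0, 1]."""
    s = min(max(s, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(s * math.log(s) + (1.0 - s) * math.log(1.0 - s))


uncertain = binary_entropy(0.5)  # maximal entropy: the model is undecided
confident = binary_entropy(0.9)  # lower entropy: the model is more decisive
```

Minimizing this quantity pushes the matching score of an unlabeled sample toward 0 or 1, which is the usual motivation for entropy minimization in self-training.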
The label feature vectors corresponding to the pseudo labels of the pseudo label samples in the pseudo label sample set and the image feature vectors of the pseudo label samples are then input into the multi-modal matching model to obtain the matching result between the sample features of each pseudo label sample and its pseudo label.
By the above example, a multi-mode matching mechanism is adopted to screen outlier samples for a subsequent semi-supervised learning algorithm, so as to improve the accuracy of the prediction result of the target classification model.
S106, training the target classification model based on the non-outlier sample set until the model converges.
Specifically, outlier samples in the sample data set are removed according to the matching result of the multi-modal matching model on the pseudo label sample set, and the target classification model is trained based on the non-outlier sample set, so that interference from outlier samples is eliminated.
Optionally, in this embodiment, the target classification model includes an image classification model, and the image classification model includes a backbone network and a classifier, where the target classification model is trained based on a non-outlier sample set until the model converges, including but not limited to: training an image classification model based on a non-outlier sample set; and training an image rotation model based on a rotation sample set, wherein the rotation sample set is obtained according to the sample data set, and the image rotation model comprises a backbone network and a rotation classifier.
Specifically, while training the target classification model based on the in-distribution samples, an auxiliary task can be established for the target classification model in a self-supervision mode to improve the feature extraction capability of the model.
In some embodiments, the target classification model is an image classification model that is used to identify a class of the animal in the image. The image classification model comprises a backbone network and a classifier, and an image rotation model is constructed on the basis of the backbone network of the image classification model and used for determining the rotation angle of the image. The image rotation model comprises a backbone network and a rotation classifier, and a rotation sample is created according to a training sample in the sample data set to obtain a rotation sample set.
The image classification model is trained based on the non-outlier sample set, an auxiliary task of determining the rotation angle of an image is established, and the image rotation model is trained based on the rotation sample set, so that a converged target classification model is obtained through self-supervised learning.
Optionally, in this embodiment, training the image classification model based on the non-outlier sample set and training the image rotation model based on the rotation sample set include, but are not limited to: determining a first loss function corresponding to the image classification model according to the non-outlier sample set; determining a second loss function corresponding to the image rotation model according to the rotation sample set; and training the image classification model and the image rotation model based on the first loss function and the second loss function respectively until a preset number of iterations is reached.
Specifically, the multi-modal matching branch is used to judge whether the label matches the sample features, and if so, the sample is added to the unlabeled data set of the semi-supervised image classification task. After the outlier sample filtering is completed, a consistency-constrained semi-supervised learning method based on data augmentation is used to further train the image classification model. Specifically, a series of image transformations is performed on an input sample to obtain an augmented sample, and the model is required to make the class prediction distribution $p$ of the input sample and the class prediction distribution of the augmented sample as consistent as possible. KL divergence is used to measure the distance between the two distributions, and the calculation method is as follows:
$K$ is the number of label categories in the image classification task, and $p[j]$ represents the probability of the $j$-th label category in the category prediction.
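The KL divergence between two class prediction distributions over K categories can be computed as in the sketch below. The direction of the divergence (original prediction first, augmented prediction second), the symbol `p_aug` for the augmented sample's prediction and the `eps` guard are assumptions.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], p_aug: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || p_aug) = sum over j of p[j] * ln(p[j] / p_aug[j]), for K categories.

    Terms with p[j] == 0 contribute nothing; eps guards against division by zero.
    """
    return sum(
        pj * math.log((pj + eps) / (qj + eps))
        for pj, qj in zip(p, p_aug)
        if pj > 0
    )


p = [0.7, 0.2, 0.1]               # toy prediction for the input sample (K = 3)
consistent = kl_divergence(p, p)  # identical predictions
inconsistent = kl_divergence(p, [0.1, 0.2, 0.7])
```

The divergence is zero when the two predictions agree and grows as they diverge, which is what makes it usable as a consistency constraint.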
Optionally, in this embodiment, before training the image rotation model based on the rotation sample set, the method further includes, but is not limited to: rotating the training samples in the sample data set by a preset angle to obtain rotation samples, wherein each rotation sample comprises a rotation angle label.
Specifically, for all sample data (including outlier samples) in the sample data set, the rotation angle of the image is predicted by the rotation classification model. The input sample $x$ is rotated by 0°, 90°, 180° and 270° to obtain $x_1, x_2, x_3, x_4$ respectively, and the rotation classification model is then trained using a cross entropy loss function, calculated as follows:
where $q$ is the rotation classification model's prediction of the rotation angle of the image. Since 4 angles need to be predicted, the problem can be regarded as a 4-class classification problem.
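Building the rotation sample set and evaluating a rotation loss can be sketched as follows. Images are represented here as small 2D lists, rotation is taken clockwise, and the per-image loss is taken as the average of $-\ln(q[r])$ over the four rotations; the rotation direction, the averaging and all function names are assumptions.

```python
import math
from typing import List, Sequence, Tuple

Image = List[List[int]]


def rotate90(image: Image) -> Image:
    """Rotate a 2D grid by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


def make_rotation_samples(image: Image) -> List[Tuple[Image, int]]:
    """Return x_1..x_4 with rotation labels 0..3 (0, 90, 180, 270 degrees)."""
    samples = []
    current = [list(row) for row in image]
    for angle_label in range(4):
        samples.append((current, angle_label))
        current = rotate90(current)
    return samples


def rotation_loss(predictions: Sequence[Sequence[float]]) -> float:
    """Average cross entropy over the four rotations.

    predictions[r] is the 4-way distribution q predicted for the sample rotated
    by r * 90 degrees, whose true rotation label is r.
    """
    return sum(-math.log(q[r]) for r, q in enumerate(predictions)) / len(predictions)


rotation_samples = make_rotation_samples([[1, 2], [3, 4]])
uniform_loss = rotation_loss([[0.25, 0.25, 0.25, 0.25]] * 4)  # untrained-like guess
```

The rotation labels come for free from the rotation itself, which is why this auxiliary task can use every sample in the data set, including outlier samples.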
According to this embodiment, an unlabeled sample set in the sample data set is labeled based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set; outlier samples in the sample data set are screened according to a multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained by training based on the labeled sample set; and the target classification model is trained based on the non-outlier sample set until the model converges. In this embodiment, outlier samples are detected by determining whether the features of an image match the category embedding vector, so that outlier samples in the unlabeled data are rejected. Meanwhile, a self-supervised learning method is used alongside training on the non-outlier samples, which include unlabeled samples, thereby achieving the technical effects of making full use of all the unlabeled data and using the outlier samples to improve the feature extraction capability of the model. This solves the technical problem in the related art that model training based on open-set semi-supervised learning is difficult because labeled data is scarce in semi-supervised learning and the detection accuracy for outlier samples is poor.
It should be noted that, for simplicity of description, the above-mentioned method embodiments are described as a series of acts or combination of acts, but those skilled in the art will recognize that the present invention is not limited by the order of acts, as some steps may occur in other orders or concurrently in accordance with the invention. Further, those skilled in the art should also appreciate that the embodiments described in the specification are preferred embodiments and that the acts and modules referred to are not necessarily required by the invention.
Through the above description of the embodiments, those skilled in the art can clearly understand that the method according to the above embodiments can be implemented by software plus a necessary general hardware platform, and certainly can also be implemented by hardware, but the former is a better implementation mode in many cases. Based on such understanding, the technical solutions of the present invention may be embodied in the form of a software product, which is stored in a storage medium (e.g., ROM/RAM, magnetic disk, optical disk) and includes instructions for enabling a terminal device (e.g., a mobile phone, a computer, a server, or a network device) to execute the method according to the embodiments of the present invention.
Example 2
According to an embodiment of the present invention, there is also provided an apparatus for training a target classification model based on an unlabeled sample, for implementing the above method for training a target classification model based on an unlabeled sample, as shown in fig. 2, the apparatus includes:
1) a labeling module 20, configured to label an unlabeled sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, where the first classification model is obtained by training a target classification model based on a labeled sample set;
2) a screening module 22, configured to screen outlier samples in the sample data set according to a multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, where the multi-modal matching model is obtained by training based on the labeled sample set;
3) a training module 24, configured to train the target classification model based on the non-outlier sample set until the model converges.
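How the three modules hand data to one another can be sketched as below. The class `TrainingApparatus`, its `run` method and the toy callables are hypothetical; they only illustrate the data flow from labeling to screening to training, not an actual implementation of the modules.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class TrainingApparatus:
    """Chains the labeling module 20, screening module 22 and training module 24."""
    labeling_module: Callable[[List[Any]], List[Any]]
    screening_module: Callable[[List[Any]], List[Any]]
    training_module: Callable[[List[Any]], Any]

    def run(self, unlabeled_set: List[Any]) -> Any:
        pseudo_label_set = self.labeling_module(unlabeled_set)     # step S102
        non_outlier_set = self.screening_module(pseudo_label_set)  # step S104
        return self.training_module(non_outlier_set)               # step S106


# Toy stand-ins: integers play the role of samples, and 1000 plays an outlier.
apparatus = TrainingApparatus(
    labeling_module=lambda xs: [(x, "even" if x % 2 == 0 else "odd") for x in xs],
    screening_module=lambda pairs: [(x, lbl) for x, lbl in pairs if x < 100],
    training_module=lambda pairs: {"trained_on": len(pairs)},
)
result = apparatus.run([2, 3, 1000])
```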
Optionally, for specific examples in this embodiment, reference may be made to the examples described in embodiment 1, which are not described herein again.
Through this embodiment, outlier samples are detected by determining whether the features of an image match the category embedding vector, so that outlier samples in the unlabeled data are rejected. Meanwhile, a self-supervised learning method is used alongside training on the non-outlier samples, which include unlabeled samples, thereby achieving the technical effects of making full use of all the unlabeled data and using the outlier samples to improve the feature extraction capability of the model. This solves the technical problem in the related art that model training based on open-set semi-supervised learning is difficult because labeled data is scarce in semi-supervised learning and the detection accuracy for outlier samples is poor.
Example 3
There is also provided, according to an embodiment of the present invention, an electronic device, including a processor, a memory, and a program or instructions stored on the memory and executable on the processor, where the program or instructions, when executed by the processor, implement the steps of the target classification model training method based on unlabeled samples as described above.
Optionally, in this embodiment, the memory is configured to store program code for performing the following steps:
S1, labeling an unlabeled sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
S2, screening outlier samples in the sample data set according to a multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained by training based on the labeled sample set;
S3, training the target classification model based on the non-outlier sample set until the model converges.
Optionally, the specific example in this embodiment may refer to the example described in embodiment 1 above, and this embodiment is not described again here.
Example 4
Embodiments of the present invention also provide a readable storage medium on which a program or instructions are stored, which when executed by a processor, implement the steps of the target classification model training method based on label-free samples as described above.
Optionally, in this embodiment, the readable storage medium is configured to store program code for performing the following steps:
S1, labeling an unlabeled sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
S2, screening outlier samples in the sample data set according to a multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained by training based on the labeled sample set;
S3, training the target classification model based on the non-outlier sample set until the model converges.
Optionally, the storage medium is further configured to store program codes for executing the steps included in the method in embodiment 1, which is not described in detail in this embodiment.
Optionally, in this embodiment, the storage medium may include, but is not limited to: a USB flash drive, a Read-Only Memory (ROM), a Random Access Memory (RAM), a removable hard disk, a magnetic disk, an optical disk, and various other media capable of storing program code.
Optionally, the specific examples in this embodiment may refer to the examples described in embodiment 1 and embodiment 2, and this embodiment is not described herein again.
The above-mentioned serial numbers of the embodiments of the present invention are merely for description and do not represent the merits of the embodiments.
The integrated unit in the above embodiments, if implemented in the form of a software functional unit and sold or used as a separate product, may be stored in the above computer-readable storage medium. Based on such understanding, the technical solution of the present invention may be embodied in the form of a software product, which is stored in a storage medium and includes several instructions for causing one or more computer devices (which may be personal computers, servers, network devices, etc.) to execute all or part of the steps of the method according to the embodiments of the present invention.
In the above embodiments of the present invention, the descriptions of the respective embodiments have respective emphasis, and for parts that are not described in detail in a certain embodiment, reference may be made to related descriptions of other embodiments.
In the several embodiments provided in the present application, it should be understood that the disclosed client may be implemented in other manners. The above-described embodiments of the apparatus are merely illustrative, and for example, the division of the units is only one type of division of logical functions, and there may be other divisions when actually implemented, for example, a plurality of units or components may be combined or may be integrated into another system, or some features may be omitted, or not executed. In addition, the shown or discussed mutual coupling or direct coupling or communication connection may be an indirect coupling or communication connection through some interfaces, units or modules, and may be in an electrical or other form.
The units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units, may be located in one place, or may be distributed on a plurality of network units. Some or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment.
In addition, functional units in the embodiments of the present invention may be integrated into one processing unit, or each unit may exist alone physically, or two or more units are integrated into one unit. The integrated unit can be realized in a form of hardware, and can also be realized in a form of a software functional unit.
The foregoing is only a preferred embodiment of the present invention, and it should be noted that, for those skilled in the art, various modifications and decorations can be made without departing from the principle of the present invention, and these modifications and decorations should also be regarded as the protection scope of the present invention.