CN114419363A - Target classification model training method and device based on label-free sample data - Google Patents

Target classification model training method and device based on label-free sample data

Info

Publication number
CN114419363A
Authority
CN
China
Prior art keywords
sample
label
classification model
model
training
Prior art date
Legal status
Pending
Application number
CN202111591966.3A
Other languages
Chinese (zh)
Inventor
李冠彬
黄俊凯
张津津
柴振华
魏晓林
Current Assignee
Beijing Sankuai Online Technology Co Ltd
Sun Yat Sen University
Original Assignee
Beijing Sankuai Online Technology Co Ltd
Sun Yat Sen University
Priority date
Filing date
Publication date
Application filed by Beijing Sankuai Online Technology Co Ltd, Sun Yat Sen University filed Critical Beijing Sankuai Online Technology Co Ltd
Priority to CN202111591966.3A priority Critical patent/CN114419363A/en
Publication of CN114419363A publication Critical patent/CN114419363A/en
Pending legal-status Critical Current

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

The invention discloses a target classification model training method and device based on label-free sample data. The method comprises: labeling an unlabeled sample set in a sample data set based on a first classification model to obtain a pseudo-label sample set, wherein the first classification model is obtained by training a target classification model on the labeled sample set; screening out outlier samples in the sample data set according to a multi-modal matching model and the pseudo-label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is trained on the labeled sample set; and training the target classification model on the non-outlier sample set until the model converges. The method solves the technical problem in the related art that model training based on open-set semi-supervised learning is difficult, owing to the scarcity of labeled data in semi-supervised learning and the poor detection accuracy of outlier samples.

Description

Target classification model training method and device based on label-free sample data
Technical Field
The invention relates to the technical field of neural networks, in particular to a target classification model training method and device based on label-free sample data.
Background
The deep learning method has been successful in many computer vision tasks, but it requires a huge amount of labeled data, which restricts the wide application of deep learning. When only limited labeled samples are available, the semi-supervised learning method can improve the performance of the deep neural network by utilizing a large amount of label-free data. Most existing semi-supervised learning methods assume that annotated data and unlabeled data share the same class space.
In the prior art, tedious work is required to clean the unlabeled data. In recent years, researchers have begun to study a more challenging semi-supervised learning scenario, open-set semi-supervised learning, in which the unlabeled data includes outlier samples that do not belong to the categories of the labeled data. The inclusion of outlier samples in the unlabeled data can significantly degrade the performance of semi-supervised learning methods. Although various outlier-detection methods exist, they typically require a large number of labeled in-distribution samples.
However, because labeled data is scarce in semi-supervised learning, existing outlier-detection methods cannot achieve satisfactory performance and are therefore not suitable for the open-set semi-supervised learning problem.
Disclosure of Invention
The embodiment of the invention provides a target classification model training method and device based on label-free sample data, which at least solve the technical problem of high difficulty in model training based on open-set semi-supervised learning due to the scarcity of label data in semi-supervised learning and poor detection accuracy of outlier samples in the related technology.
According to an aspect of the embodiments of the present invention, there is provided a target classification model training method based on label-free samples, including: labeling an unmarked sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on the marked sample set; screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training; and training the target classification model based on the non-outlier sample set until the model converges.
Further, labeling the unlabeled sample set in the sample data set based on the first classification model includes: predicting a plurality of unmarked samples in the unmarked sample set according to the first classification model to obtain a plurality of pseudo labels corresponding to the unmarked samples; and correspondingly labeling the plurality of unlabeled samples based on the plurality of pseudo labels to obtain the pseudo label sample set.
Further, screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set, comprising: determining outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set; rejecting the outlier samples in the sample dataset.
Further, determining outlier samples in the sample data set according to a multi-modal matching model and a pseudo-label sample set, comprising: inputting label feature vectors corresponding to pseudo labels of the pseudo label samples in the pseudo label sample set and image feature vectors of the pseudo label samples into the multi-modal matching model; if the label characteristic vector is determined to be not matched with the image characteristic vector according to the multi-modal matching model, determining that an unmarked sample corresponding to the pseudo label sample is the outlier sample; otherwise, determining that the unmarked sample corresponding to the pseudo label sample is a non-outlier sample.
Further, the target classification model includes an image classification model, the image classification model includes a backbone network and a classifier, wherein the training the image classification model based on the non-outlier sample set until the model converges includes: training the image classification model based on the non-outlier sample set; and training an image rotation model based on a rotation sample set, wherein the rotation sample set is obtained according to the sample data set, and the image rotation model comprises the backbone network and a rotation classifier.
Further, training the image classification model based on the non-outlier sample set and training the image rotation model based on the rotation sample set includes: determining a first loss function corresponding to the image classification model according to the non-outlier sample set; determining a second loss function corresponding to the image rotation model according to the rotation sample set; and training the image classification model and the image rotation model based on the first loss function and the second loss function until a preset number of iterations is reached.
Further, before training the image rotation model based on the rotation sample set, the method further includes: rotating the training samples in the sample data set to a preset angle to obtain a rotating sample, wherein the rotating sample comprises a rotating angle label.
According to another aspect of the embodiments of the present invention, there is also provided a target classification model training apparatus based on label-free samples, including: the labeling module is used for labeling a label-free sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set; the screening module is used for screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training; and the training module is used for training the target classification model based on the non-outlier sample set until the model converges.
According to another aspect of the embodiments of the present invention, there is also provided an electronic device, including a processor, a memory, and a program or instructions stored on the memory and executable on the processor, where the program or instructions, when executed by the processor, implement the steps of the target classification model training method based on unlabeled samples as described above.
According to another aspect of the embodiments of the present invention, there is also provided a readable storage medium, on which a program or instructions are stored, which when executed by a processor, implement the steps of the target classification model training method based on label-free samples as described above.
In the embodiment of the invention, an unlabeled sample set in a sample data set is labeled based on a first classification model to obtain a pseudo-label sample set, wherein the first classification model is obtained by training a target classification model on the labeled sample set; outlier samples in the sample data set are screened according to a multi-modal matching model and the pseudo-label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is trained on the labeled sample set; and the target classification model is trained on the non-outlier sample set until the model converges. In this embodiment, outlier detection is completed by determining whether an image's feature vector matches its category embedding vector, achieving the purpose of rejecting outlier samples in the unlabeled data. Meanwhile, a self-supervised learning method is applied to all samples, including the outlier samples, so that all unlabeled data is fully utilized and even the outlier samples improve the model's feature-extraction capability, which solves the technical problem in the related art that model training based on open-set semi-supervised learning is difficult due to the scarcity of labeled data in semi-supervised learning and the poor detection accuracy of outlier samples.
Drawings
The accompanying drawings, which are included to provide a further understanding of the invention and are incorporated in and constitute a part of this application, illustrate embodiment(s) of the invention and together with the description serve to explain the invention without limiting the invention. In the drawings:
FIG. 1 is a schematic flow chart of an alternative target classification model training method based on label-free samples according to an embodiment of the present invention;
FIG. 2 is a schematic structural diagram of an alternative target classification model training apparatus based on unlabeled samples according to an embodiment of the present invention.
Detailed Description
In order to make the technical solutions of the present invention better understood, the technical solutions in the embodiments of the present invention will be clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are only a part of the embodiments of the present invention, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
It should be noted that the terms "first," "second," and the like in the description and claims of the present invention and in the drawings described above are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used is interchangeable under appropriate circumstances such that the embodiments of the invention described herein are capable of operation in sequences other than those illustrated or described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed, but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
Example 1
According to an embodiment of the present invention, a target classification model training method based on label-free samples is provided, as shown in fig. 1, the method includes:
s102, labeling a label-free sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
in this embodiment, the training samples in the sample data set include labeled samples and unlabeled samples, where the labeled samples are data with labels, and the unlabeled samples are data without labels. The unlabeled samples include outlier samples.
It should be noted that the target classification model includes, but is not limited to, an image classification model, an image recognition model, a voice recognition model, a character recognition model, and the like. In this embodiment, the specific model structure and the use of the target classification model are not limited at all.
In one example, the object classification model is an image classification model, and is used for identifying the animal image and determining a corresponding category of the animal image. The sample data set corresponding to the image classification model comprises an animal image, a vegetable image, a fruit image and a building image. Wherein, the animal image with the label is a marked sample, and the image without the label is a non-marked sample. The vegetable image, the fruit image and the building image in the unmarked sample are outlier samples.
In this embodiment, labeled samples in the sample data set are obtained, and the target classification model is trained through the labeled samples, so that a first classification model which is trained through the labeled samples is obtained. And then, identifying the unmarked samples in the sample data set through the first classification model to obtain labels corresponding to the unmarked samples, and marking the unmarked samples through the identified labels to obtain a pseudo label sample set.
Optionally, in this embodiment, labeling, based on the first classification model, a label-free sample set in the sample data set includes, but is not limited to: predicting a plurality of unmarked samples in the unmarked sample set according to the first classification model to obtain a plurality of pseudo labels corresponding to the unmarked samples; and correspondingly labeling the plurality of unlabeled samples based on the plurality of pseudo labels to obtain a pseudo label sample set.
In this embodiment, the data structure of the sample data in the sample data set is < sample feature, label >, and the label of the unlabeled sample is null.
And training the target classification model through the labeled samples to obtain the first classification model. At this stage the first classification model has been trained on only a small number of samples, so overfitting is likely to occur; although it can predict labels for input samples, its predictions contain errors.
In this embodiment, a preliminary prediction of each unlabeled sample is obtained through the first classification model, and a pseudo label is constructed from the prediction result. The unlabeled sample is then labeled with the pseudo label, so that its data structure becomes <sample feature, pseudo label>. In the same way, each unlabeled sample in the unlabeled sample set is predicted by the first classification model and labeled based on the prediction result, yielding the pseudo-label sample set.
Specifically, in the warm-up training stage, for each labeled sample, the backbone network g_θ in the target classification model first extracts the sample features, which are input into the classifier of the target classification model to obtain a class prediction p. The classifier is then trained with a cross-entropy loss function, calculated as:

L_ce = -ln(p[y])

where y is the class label of the sample and p[y] represents the model's predicted probability for the sample's true class. After warm-up training is completed, a pseudo class label is predicted for each unlabeled sample using the model's classifier.
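As a minimal illustration of the cross-entropy warm-up loss L_ce = -ln(p[y]), the sketch below computes it for a toy softmax prediction (the helper name `cross_entropy` is ours, not from the patent text):

```python
import numpy as np

def cross_entropy(p: np.ndarray, y: int) -> float:
    """L_ce = -ln(p[y]): negative log-probability of the true class y."""
    return float(-np.log(p[y]))

p = np.array([0.1, 0.7, 0.2])  # toy softmax class prediction
print(round(cross_entropy(p, y=1), 4))  # → 0.3567, i.e. -ln(0.7)
```

A perfectly confident correct prediction (p[y] = 1) gives a loss of zero, as expected.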
In one example, the object classification model is assumed to be an image classification model for identifying an image of an animal to predict the animal species. The sample data set corresponding to the image classification model comprises an animal image, a vegetable image, a fruit image and a building image. Wherein, the animal image with the label is a marked sample, and the image without the label is a non-marked sample. The vegetable image, the fruit image and the building image in the unmarked sample are outlier samples. The training samples in the sample dataset are < image feature vector, label 'animal species' >. Training a target classification model based on labeled samples of the animal image with the label to obtain a first classifier, preliminarily predicting the unlabeled samples through the first classifier, and labeling the unlabeled samples based on a prediction result to obtain a pseudo label sample set. The data structure of the swatches in the pseudo-tagged swatch set is < image feature vector, tag 'animal species' >.
Through the embodiment, the multiple unlabeled samples in the unlabeled sample set are predicted according to the first classification model, and the unlabeled samples are labeled based on the prediction result to obtain the pseudo label samples labeled by the pseudo labels, so that the preliminary classification of the unlabeled samples is realized.
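The pseudo-labeling step described above, predicting each unlabeled sample with the first classification model and attaching the highest-probability class as its pseudo label, can be sketched as follows (a toy illustration; `assign_pseudo_labels` is a hypothetical helper):

```python
import numpy as np

def assign_pseudo_labels(preds: np.ndarray) -> np.ndarray:
    """Take the highest-probability class of each unlabeled sample's
    prediction as its pseudo label."""
    return preds.argmax(axis=1)

# Toy class predictions for 3 unlabeled samples over 4 classes.
preds = np.array([
    [0.1, 0.7, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.2, 0.2, 0.5, 0.1],
])
pseudo_labels = assign_pseudo_labels(preds)
print(pseudo_labels.tolist())  # → [1, 0, 2]
```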
S104, screening outlier samples in the sample data set according to the multi-modal matching model and the pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained by training based on the labeled sample set;
after obtaining the pseudo label sample set labeled with the prediction result predicted by the first classification model, since the first classification model is likely to generate an overfitting phenomenon because the cardinality of the training samples is small, although the label of the input sample can be predicted, the prediction result of the sample is in error, that is, the label of the sample in the pseudo label sample set may be wrong.
Therefore, outlier samples in the sample data set are screened out through the multi-modal matching model and the pseudo label sample set, and the distribution difference between the samples in the distribution and the outlier samples is reduced. The multi-modal matching model is used for determining whether the prediction result of the sample characteristics of the samples in the pseudo label sample set is matched with the pseudo labels or not, and then extracting the outlier samples in the sample data set according to the matching result.
Optionally, in this embodiment, the outlier samples in the sample data set are filtered according to the multi-modal matching model and the pseudo label sample set, which includes but is not limited to: determining outlier samples in the sample data set according to the multi-modal matching model and the pseudo label sample set; and rejecting outlier samples in the sample data set.
The multi-modal matching model in this embodiment is used to determine whether the prediction result for the sample features of a sample in the pseudo-label sample set matches the pseudo label; its input includes the sample feature vector corresponding to the sample and the label embedding vector corresponding to the pseudo label. Whether the sample features of a sample match the label is judged according to the multi-modal matching model and the pseudo-label sample set. If the sample features match the label, the sample is an in-distribution sample; if they do not match, the unlabeled sample corresponding to the pseudo-label sample is an outlier sample, and the outlier samples are removed from the sample data set.
Further optionally, in this embodiment, determining outlier samples in the sample data set according to the multi-modal matching model and the pseudo tag sample set includes, but is not limited to: inputting label characteristic vectors corresponding to pseudo labels of pseudo label samples in the pseudo label sample set and image characteristic vectors of the pseudo label samples into a multi-modal matching model; if the label characteristic vector is determined to be not matched with the image characteristic vector according to the multi-modal matching model, determining that the unmarked sample corresponding to the pseudo label sample is an outlier sample; otherwise, determining that the unmarked sample corresponding to the pseudo label sample is a non-outlier sample.
In a specific example, assuming that the target classification model is an image classification model and is used for identifying an animal image to predict an animal type, if the label feature vector and the image feature vector are determined to be not matched according to the multi-modal matching model, determining that the unmarked sample corresponding to the pseudo label sample is an outlier sample; otherwise, determining that the unmarked sample corresponding to the pseudo label sample is a non-outlier sample.
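The screening decision can be sketched as a threshold on the discriminator's matching score: samples whose score falls below the threshold are treated as outliers and rejected. The 0.5 threshold here is an assumption for illustration; the patent text does not state a specific value.

```python
import numpy as np

def filter_outliers(match_scores: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return a boolean mask: True keeps the sample (non-outlier),
    False rejects it as an outlier. The threshold is an assumed value."""
    return match_scores >= threshold

# Toy matching scores S(x, pseudo label) for 4 unlabeled samples.
match_scores = np.array([0.92, 0.13, 0.67, 0.41])
keep_mask = filter_outliers(match_scores)
print(keep_mask.tolist())  # → [True, False, True, False]
```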
Specifically, the following first introduces a multi-modal matching model:
the input of the multi-modal matching model comprises a sample characteristic vector corresponding to a sample in a pseudo label sample set and a label embedding vector corresponding to a pseudo label, and the data structure of the input data is a < sample characteristic vector, label embedding vector >. The positive samples during the multi-modal matching model training can be directly obtained from the labeled samples in the sample data set, namely, the sample characteristics and the corresponding labels form a pair of positive samples. The negative examples consist of the example and the wrong label.
Specifically, the multi-modal matching model is a multi-modal matching discriminator: a multi-layer perceptron with one hidden layer, trained on the labeled samples in the sample data set. Its input is the feature vector g_θ(x) of a sample x concatenated with the class embedding vector g_φ(y), and its output is a matching score S(x, y) with a value between 0 and 1. The image classifier is first used to predict the sample x and obtain a class prediction p.
in the training process of the multi-modal matching model, the following two strategies can be adopted in the construction process of the training sample:
1) A hard-sample mining strategy: select the label y_h that receives the highest probability in the sample's classification prediction among all labels other than the true label y. That is, the true label of the sample is y, but y_h is the wrong label predicted with the highest probability.
2) A random selection strategy: randomly select from the category set a label y_s that is neither the true label y nor the label y_h; specifically, y_s is a randomly chosen label whose predicted probability is lower than that of y_h and which is not the true label.
Each labeled sample is then expanded using the labels y, y_h and y_s to obtain three training pairs. In one example, for a sample X and the three labels, the three pairs are <X, y>, <X, y_h> and <X, y_s>.
The labeled samples in the sample data set are constructed in this way to obtain a matched sample set. After the three pairs are obtained, the multi-modal matching discriminator is trained with a binary cross-entropy loss:

L_bce = -(m · ln S(x, y') + (1 - m) · ln(1 - S(x, y')))

where m = 1 for the positive pair <X, y> and m = 0 for the negative pairs <X, y_h> and <X, y_s>.
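The pair-construction step, one positive pair <X, y>, a hard negative <X, y_h> and a random negative <X, y_s>, can be sketched as below; the target value m (1 for matched, 0 for mismatched) feeds a standard binary cross-entropy. The helper names are ours, not the patent's:

```python
import numpy as np

def build_pairs(pred: np.ndarray, y: int, rng: np.random.Generator):
    """For one labeled sample, build (label, match-target) pairs:
    the true label y (m=1), the hard negative y_h (highest-probability
    wrong label, m=0), and a random other negative y_s (m=0)."""
    masked = pred.copy()
    masked[y] = -np.inf
    y_h = int(masked.argmax())  # highest-probability wrong label
    rest = [c for c in range(len(pred)) if c not in (y, y_h)]
    y_s = int(rng.choice(rest))  # random label that is neither y nor y_h
    return [(y, 1.0), (y_h, 0.0), (y_s, 0.0)]

def bce(s: float, m: float) -> float:
    """Binary cross-entropy on matching score s with target m."""
    return float(-(m * np.log(s) + (1 - m) * np.log(1 - s)))

rng = np.random.default_rng(0)
pred = np.array([0.5, 0.3, 0.15, 0.05])  # toy class prediction for sample X
pairs = build_pairs(pred, y=0, rng=rng)
print(pairs[0][0], pairs[1][0])  # → 0 1 (true label, then hard negative)
```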
Warm-up training is run for a preset number of iterations, which can be set according to actual needs.
To further improve the performance of the multi-modal matching discriminator, unlabeled samples are also added to training; because they have no real labels, an entropy-minimization method is used to let the model self-train.
For each unlabeled sample u, the label ŷ_h with the highest class-prediction probability is selected, and a label ŷ_s is randomly selected from the remaining labels. The entropy-minimization loss is then computed, treating <u, ŷ_h> as a matched pair and <u, ŷ_s> as an unmatched pair:

L_em = -ln S(u, ŷ_h) - ln(1 - S(u, ŷ_s))
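The label-selection step for this self-training can be sketched as follows: take the model's top class as ŷ_h and draw ŷ_s at random from the remaining labels (the helper name is ours):

```python
import numpy as np

def pick_self_train_labels(pred: np.ndarray, rng: np.random.Generator):
    """ŷ_h: label with the highest predicted probability;
    ŷ_s: a label drawn at random from the remaining ones."""
    y_hat_h = int(pred.argmax())
    remaining = [c for c in range(len(pred)) if c != y_hat_h]
    y_hat_s = int(rng.choice(remaining))
    return y_hat_h, y_hat_s

rng = np.random.default_rng(1)
pred = np.array([0.05, 0.8, 0.1, 0.05])  # toy prediction for an unlabeled sample
y_hat_h, y_hat_s = pick_self_train_labels(pred, rng)
print(y_hat_h)  # → 1
```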
and inputting label characteristic vectors corresponding to the pseudo labels of the pseudo label samples in the pseudo label sample set and image characteristic vectors of the pseudo label samples into the multi-modal matching model to obtain the sample characteristics of the pseudo label samples and the matching characteristics of the pseudo labels.
Through the above example, a multi-modal matching mechanism is adopted to screen outlier samples for the subsequent semi-supervised learning algorithm, improving the accuracy of the target classification model's predictions.
And S106, training the target classification model based on the non-outlier sample set until the model converges.
Specifically, outlier samples in the sample data set are removed according to the multi-modal matching model's matching results on the pseudo-label sample set, and the target classification model is trained based on the non-outlier sample set, eliminating the interference of outlier samples.
Optionally, in this embodiment, the target classification model includes an image classification model, and the image classification model includes a backbone network and a classifier, where the target classification model is trained based on a non-outlier sample set until the model converges, including but not limited to: training an image classification model based on a non-outlier sample set; and training an image rotation model based on a rotation sample set, wherein the rotation sample set is obtained according to the sample data set, and the image rotation model comprises a backbone network and a rotation classifier.
Specifically, while training the target classification model based on the in-distribution samples, an auxiliary task can be established for the target classification model in a self-supervision mode to improve the feature extraction capability of the model.
In some embodiments, the target classification model is an image classification model that is used to identify a class of the animal in the image. The image classification model comprises a backbone network and a classifier, and an image rotation model is constructed on the basis of the backbone network of the image classification model and used for determining the rotation angle of the image. The image rotation model comprises a backbone network and a rotation classifier, and a rotation sample is created according to a training sample in the sample data set to obtain a rotation sample set.
Training the image classification model based on the non-outlier sample set, establishing a rotation angle of an auxiliary task confirmation image, training the image rotation model based on the rotation sample set, and obtaining a converged target classification model through a training mode of self-supervision learning.
Optionally, in this embodiment, training the image classification model based on the non-outlier sample set and training the image rotation model based on the rotation sample set includes, but is not limited to: determining a first loss function corresponding to the image classification model according to the non-outlier sample set; determining a second loss function corresponding to the image rotation model according to the rotation sample set; and training the image classification model and the image rotation model, respectively, based on the first loss function and the second loss function until a preset number of iterations is reached.
Specifically, the multi-modal matching branch judges whether the label matches the sample features; if it does, the sample is added to the unlabeled data set of the semi-supervised image classification task. After outlier filtering is completed, the image classification model is further trained with a consistency-constrained semi-supervised learning method based on data augmentation. Concretely, a series of image transformations is applied to an input sample to obtain an augmented sample, and the model is required to make the predicted class distributions p (for the original sample) and p̂ (for the augmented sample) as consistent as possible. The KL divergence is used to measure the distance between the two distributions:

L_KL = D_KL(p ∥ p̂) = Σ_{j=1}^{K} p[j] · log( p[j] / p̂[j] )

where K is the number of label categories in the image classification task, and p[j] is the predicted probability of the j-th label category.
Optionally, in this embodiment, before the image rotation model is trained on the rotation sample set, the method further includes, but is not limited to: rotating the training samples in the sample data set by preset angles to obtain rotation samples, where each rotation sample carries a rotation-angle label.
Specifically, for all sample data in the sample data set (including the outlier samples), the rotation angle of the image is predicted by the rotation classification model. The input sample x is rotated by 0°, 90°, 180°, and 270° to obtain x₁, x₂, x₃, and x₄, and the rotation classification model is then trained with the cross-entropy loss:

L_rot = −(1/4) · Σ_{i=1}^{4} log q_i[i]

where q_i is the rotation classification model's predicted distribution over the rotation angles for x_i; since four angles must be predicted, the task can be regarded as a 4-class classification problem.
According to this embodiment, the unlabeled sample set in the sample data set is labeled with a first classification model to obtain a pseudo-label sample set, the first classification model being obtained by training the target classification model on the labeled sample set; outlier samples in the sample data set are screened out with a multi-modal matching model and the pseudo-label sample set to obtain a non-outlier sample set, the multi-modal matching model being trained on the labeled sample set; and the target classification model is trained on the non-outlier sample set until the model converges.

In this embodiment, outlier detection is completed by determining, for each input image, whether its feature vector matches the embedding vector of its category, thereby rejecting the outlier samples in the unlabeled data. At the same time, self-supervised learning over the non-outlier samples (including unlabeled samples) achieves the technical effects of making full use of all unlabeled data and using the outlier samples to improve the model's feature extraction capability. This solves the technical problem in the related art that model training under open-set semi-supervised learning is difficult because labeled data is scarce and outlier-sample detection accuracy is poor.
It should be noted that, for simplicity of description, the above-mentioned method embodiments are described as a series of acts or combination of acts, but those skilled in the art will recognize that the present invention is not limited by the order of acts, as some steps may occur in other orders or concurrently in accordance with the invention. Further, those skilled in the art should also appreciate that the embodiments described in the specification are preferred embodiments and that the acts and modules referred to are not necessarily required by the invention.
Through the above description of the embodiments, those skilled in the art can clearly understand that the method according to the above embodiments can be implemented by software plus a necessary general hardware platform, and certainly can also be implemented by hardware, but the former is a better implementation mode in many cases. Based on such understanding, the technical solutions of the present invention may be embodied in the form of a software product, which is stored in a storage medium (e.g., ROM/RAM, magnetic disk, optical disk) and includes instructions for enabling a terminal device (e.g., a mobile phone, a computer, a server, or a network device) to execute the method according to the embodiments of the present invention.
Example 2
According to an embodiment of the present invention, there is also provided an apparatus for training a target classification model based on an unlabeled sample, for implementing the above method for training a target classification model based on an unlabeled sample, as shown in fig. 2, the apparatus includes:
1) the labeling module 20 is configured to label a non-labeled sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, where the first classification model is obtained by training a target classification model based on a labeled sample set;
2) a screening module 22, configured to screen outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, where the multi-modal matching model is obtained based on the labeled sample set through training;
3) and the training module 24 is configured to train the target classification model based on the non-outlier sample set until the model converges.
Optionally, the specific examples in this embodiment may refer to the examples described in embodiment 1 and embodiment 2, and this embodiment is not described herein again.
Through this embodiment, outlier detection is completed by determining, for each input image, whether its feature vector matches the category embedding vector, thereby rejecting the outlier samples in the unlabeled data. Meanwhile, self-supervised learning over the non-outlier samples (including unlabeled samples) achieves the technical effects of making full use of all unlabeled data and using the outlier samples to improve the model's feature extraction capability, solving the technical problem in the related art that model training under open-set semi-supervised learning is difficult because labeled data is scarce and outlier-sample detection accuracy is poor.
Example 3
There is also provided, according to an embodiment of the present invention, an electronic device, including a processor, a memory, and a program or instructions stored on the memory and executable on the processor, where the program or instructions, when executed by the processor, implement the steps of the target classification model training method based on unlabeled samples as described above.
Optionally, in this embodiment, the memory is configured to store program code for performing the following steps:
s1, labeling a label-free sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
s2, screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training;
s3, training the target classification model based on the non-outlier sample set until the model converges.
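Steps S1 to S3 above can be sketched as a small pipeline. Here `first_model` maps a sample to a pseudo label, `matcher` stands in for the multi-modal matching model's match/no-match decision, and `train_fn` stands in for training the target classification model; all three names are hypothetical stand-ins, not identifiers from the patent.

```python
def train_pipeline(labeled, unlabeled, first_model, matcher, train_fn):
    """Minimal sketch of the S1-S3 pipeline with placeholder callables."""
    # S1: label the unlabeled set with the first classification model
    pseudo = [(x, first_model(x)) for x in unlabeled]
    # S2: screen out outliers whose pseudo label does not match the features
    non_outliers = [(x, y) for (x, y) in pseudo if matcher(x, y)]
    # S3: train the target classification model on the non-outlier set
    return train_fn(labeled + non_outliers)
```

The screening step is the key difference from standard pseudo-labeling: samples whose pseudo label fails the multi-modal match are dropped rather than passed on to training.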
Optionally, the specific example in this embodiment may refer to the example described in embodiment 1 above, and this embodiment is not described again here.
Example 4
Embodiments of the present invention also provide a readable storage medium on which a program or instructions are stored, which when executed by a processor, implement the steps of the target classification model training method based on label-free samples as described above.
Optionally, in this embodiment, the readable storage medium is configured to store program code for performing the following steps:
s1, labeling a label-free sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
s2, screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training;
s3, training the target classification model based on the non-outlier sample set until the model converges.
Optionally, the storage medium is further configured to store program codes for executing the steps included in the method in embodiment 1, which is not described in detail in this embodiment.
Optionally, in this embodiment, the storage medium may include, but is not limited to: a U-disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a removable hard disk, a magnetic or optical disk, and other various media capable of storing program codes.
Optionally, the specific examples in this embodiment may refer to the examples described in embodiment 1 and embodiment 2, and this embodiment is not described herein again.
The above-mentioned serial numbers of the embodiments of the present invention are merely for description and do not represent the merits of the embodiments.
The integrated unit in the above embodiments, if implemented in the form of a software functional unit and sold or used as a separate product, may be stored in the above computer-readable storage medium. Based on such understanding, the technical solution of the present invention may be embodied in the form of a software product, which is stored in a storage medium and includes several instructions for causing one or more computer devices (which may be personal computers, servers, network devices, etc.) to execute all or part of the steps of the method according to the embodiments of the present invention.
In the above embodiments of the present invention, the descriptions of the respective embodiments have respective emphasis, and for parts that are not described in detail in a certain embodiment, reference may be made to related descriptions of other embodiments.
In the several embodiments provided in the present application, it should be understood that the disclosed client may be implemented in other manners. The above-described embodiments of the apparatus are merely illustrative, and for example, the division of the units is only one type of division of logical functions, and there may be other divisions when actually implemented, for example, a plurality of units or components may be combined or may be integrated into another system, or some features may be omitted, or not executed. In addition, the shown or discussed mutual coupling or direct coupling or communication connection may be an indirect coupling or communication connection through some interfaces, units or modules, and may be in an electrical or other form.
The units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units, may be located in one place, or may be distributed on a plurality of network units. Some or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment.
In addition, functional units in the embodiments of the present invention may be integrated into one processing unit, or each unit may exist alone physically, or two or more units are integrated into one unit. The integrated unit can be realized in a form of hardware, and can also be realized in a form of a software functional unit.
The foregoing is only a preferred embodiment of the present invention, and it should be noted that, for those skilled in the art, various modifications and decorations can be made without departing from the principle of the present invention, and these modifications and decorations should also be regarded as the protection scope of the present invention.

Claims (10)

1. A target classification model training method based on label-free samples is characterized by comprising the following steps:
labeling an unmarked sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on the marked sample set;
screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training;
and training the target classification model based on the non-outlier sample set until the model converges.
2. The method of claim 1, wherein labeling a set of unlabeled samples of the set of sample data based on a first classification model comprises:
predicting a plurality of unmarked samples in the unmarked sample set according to the first classification model to obtain a plurality of pseudo labels corresponding to the unmarked samples;
and correspondingly labeling the plurality of unlabeled samples based on the plurality of pseudo labels to obtain the pseudo label sample set.
3. The method of claim 1, wherein screening outlier samples in the sample data set based on a multi-modal matching model and a pseudo-labeled sample set comprises:
determining outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set;
rejecting the outlier samples in the sample dataset.
4. The method of claim 3, wherein determining outlier samples in the sample data set from a multi-modal matching model and a pseudo-labeled sample set comprises:
inputting label feature vectors corresponding to pseudo labels of the pseudo label samples in the pseudo label sample set and image feature vectors of the pseudo label samples into the multi-modal matching model;
if the label characteristic vector is determined to be not matched with the image characteristic vector according to the multi-modal matching model, determining that an unmarked sample corresponding to the pseudo label sample is the outlier sample;
otherwise, determining that the unmarked sample corresponding to the pseudo label sample is a non-outlier sample.
5. The method of claim 1, wherein the target classification model comprises an image classification model comprising a backbone network and a classifier, wherein,
training the image classification model based on the non-outlier sample set until the model converges comprises:
training the image classification model based on the non-outlier sample set; and the number of the first and second groups,
an image rotation model is trained based on a set of rotation samples,
the rotating sample set is obtained according to the sample data set, and the image rotating model comprises the backbone network and a rotating classifier.
6. The method of claim 5, wherein the image classification model is trained based on the non-outlier sample set; and training an image rotation model based on the rotation sample set, including:
determining a first loss function corresponding to the image classification model according to the non-outlier sample set;
determining a second loss function corresponding to the image rotation model according to the rotation sample set;
and training the image classification model and the image rotation model respectively based on the first loss function and the second loss function until iteration reaches preset times.
7. The method of claim 1, further comprising, prior to training the image rotation model based on the rotation sample set:
rotating the training samples in the sample data set to a preset angle to obtain a rotating sample, wherein the rotating sample comprises a rotating angle label.
8. A target classification model training device based on label-free samples is characterized by comprising:
the labeling module is used for labeling a label-free sample set in the sample data set based on a first classification model to obtain a pseudo label sample set, wherein the first classification model is obtained by training a target classification model based on a labeled sample set;
the screening module is used for screening outlier samples in the sample data set according to a multi-modal matching model and a pseudo label sample set to obtain a non-outlier sample set, wherein the multi-modal matching model is obtained based on the labeled sample set through training;
and the training module is used for training the target classification model based on the non-outlier sample set until the model converges.
9. An electronic device comprising a processor, a memory, and a program or instructions stored on the memory and executable on the processor, the program or instructions when executed by the processor implementing the steps of the label-free sample based target classification model training method according to any one of claims 1 to 7.
10. A readable storage medium, on which a program or instructions are stored, which when executed by a processor, implement the steps of the label-free sample based target classification model training method according to any one of claims 1 to 7.
CN202111591966.3A 2021-12-23 2021-12-23 Target classification model training method and device based on label-free sample data Pending CN114419363A (en)


Publications (1)

Publication Number Publication Date
CN114419363A true CN114419363A (en) 2022-04-29



Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115512696A (en) * 2022-09-20 2022-12-23 中国第一汽车股份有限公司 Simulation training method and vehicle
CN115272777A (en) * 2022-09-26 2022-11-01 山东大学 Semi-supervised image analysis method for power transmission scene
CN115272777B (en) * 2022-09-26 2022-12-23 山东大学 Semi-supervised image analysis method for power transmission scene
CN115346076A (en) * 2022-10-18 2022-11-15 安翰科技(武汉)股份有限公司 Pathological image recognition method, model training method and system thereof, and storage medium
CN117611932A (en) * 2024-01-24 2024-02-27 山东建筑大学 Image classification method and system based on double pseudo tag refinement and sample re-weighting
CN117611932B (en) * 2024-01-24 2024-04-26 山东建筑大学 Image classification method and system based on double pseudo tag refinement and sample re-weighting


Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination