CN114417975A - Data classification method and system based on deep PU learning and class prior estimation - Google Patents
- Publication number: CN114417975A (application CN202111591020.7A)
- Authority: CN (China)
- Prior art keywords: model, student, teacher, data, learning
- Legal status: Pending (the legal status is an assumption and is not a legal conclusion; Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed)
Classifications
- G06F18/214: Generating training patterns; bootstrap methods, e.g. bagging or boosting
- G06F18/2155: Generating training patterns characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
- G06F16/353: Clustering; classification into predefined classes
- G06F18/217: Validation; performance evaluation; active pattern learning techniques
- G06F18/24: Classification techniques
- G06N3/044: Recurrent networks, e.g. Hopfield networks
- G06N3/045: Combinations of networks
- G06N3/048: Activation functions
- G06N3/084: Backpropagation, e.g. using gradient descent
- G06N3/088: Non-supervised learning, e.g. competitive learning
Abstract
The invention provides a data classification method and system based on deep PU learning and class prior estimation, which can estimate the class prior and, at the same time, use the estimate to train a deep model, without requiring the true prior distribution of the data to be known; this makes it better suited to applying PU learning in real-world scenarios. The proposed iterative framework models the prediction scores of the network as a Gaussian mixture model (GMM) to estimate the positive-class prior; performs unbiased PU learning based on the estimated positive-class prior; and further improves the performance and stability of the algorithm by incorporating semi-supervised learning techniques such as the mean teacher and temperature sharpening. The framework can be applied to PU problems in many fields, including computer vision, recommendation systems and biomedicine, achieves excellent results, and has both scientific and practical value.
Description
Technical Field
The invention relates to the technical field of Positive-Unlabeled (PU) learning in machine learning, and in particular to cost-sensitive deep PU learning.
Background
In recent years, with the development of the internet and information technology, we have entered the era of big data. Deep learning on massive data has attracted wide attention and made breakthrough progress. However, the superior performance of deep learning algorithms relies on large amounts of data, and in particular on the guidance of complete class label information. In many practical application scenarios, data labels are difficult and expensive to acquire, and with limited manpower and resources only a small fraction of the data can be labeled. Learning paradigms that depend less on labeled data have therefore become a popular research direction. For example, in the classification of rare diseases, diagnosed samples can be regarded as positive samples, while the remaining undiagnosed samples, i.e. the unlabeled samples, may still contain cases of the rare disease; that is, the unlabeled samples comprise both positive and negative samples. Similar situations arise in tasks such as malicious URL detection, false comment detection and particle picking in cryo-electron microscopy. Learning from only positive samples and unlabeled data, known as PU learning, is therefore of significant value.
The training set for PU learning consists of a labeled positive subset and an unlabeled subset, and the key difficulty is how to exploit the unlabeled data. According to how they handle unlabeled data, PU learning methods fall into three categories: two-step methods, biased PU learning and unbiased PU learning. The two-step method is the most intuitive: it first extracts reliable negative (and positive) samples from the unlabeled data according to a smoothness or clustering assumption, thereby converting the PU problem into an ordinary semi-supervised learning problem, and then trains with semi-supervised or supervised learning methods. Another intuitive approach is biased PU learning, i.e. treating the unlabeled data as noisy negative samples. Finally, under the framework of cost-sensitive learning, unbiased PU learning estimates the positive-class loss from the positive samples alone, and indirectly constructs the negative-class loss from the loss obtained by treating the unlabeled data as negative together with class prior knowledge, thereby achieving an unbiased estimate of the usual classification objective and the current state-of-the-art performance.
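The cost-sensitive decomposition described above rests on the identity (1 - pi_p)·R_n^- = R_u^- - pi_p·R_p^-, which follows from the unlabeled marginal being the mixture pi_p·p_p + (1 - pi_p)·p_n. A minimal numerical check of this identity, using an arbitrary fixed linear scorer and synthetic Gaussian class-conditionals (all names and values here are illustrative, not from the patent):

```python
import numpy as np

rng = np.random.default_rng(0)
sig = lambda z: 1.0 / (1.0 + np.exp(-z))

# True positive-class prior and class-conditional distributions.
pi_p = 0.3
x_pos = rng.normal( 2.0, 1.0, 200_000)        # draws from p_p
x_neg = rng.normal(-2.0, 1.0, 200_000)        # draws from p_n
mask = rng.random(200_000) < pi_p
x_unl = np.where(mask, x_pos, x_neg)          # draws from the PU marginal

f = lambda x: 1.5 * x                         # any fixed scorer
loss_neg = lambda z: sig(z)                   # surrogate loss for label -1

# Left side needs negative labels; right side needs only PU data.
direct = (1.0 - pi_p) * loss_neg(f(x_neg)).mean()
indirect = loss_neg(f(x_unl)).mean() - pi_p * loss_neg(f(x_pos)).mean()
# The two agree up to Monte-Carlo error, so PU data suffices.
```

This is why, given the positive-class prior, the negative-class risk (and hence the full classification risk) can be estimated without any labeled negatives.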
Disclosure of Invention
In studying unbiased PU learning algorithms, the inventors found that such algorithms assume the true class prior can be obtained in advance. In fact, in most practical PU learning scenarios the class prior is unknown. Arbitrarily presetting a class prior value, especially a positive-class prior smaller than the true value, severely degrades the performance of the algorithm. This defect greatly reduces the practical value of such algorithms. Because the class prior plays a key role in these methods, it has driven research on class prior estimation. However, most existing estimators only perform well on small-scale datasets with traditional methods, and fall far behind when facing deep learning algorithms on massive data. Therefore, the crux of the problem is how to accurately estimate the positive-class prior at the lowest possible computational cost in the absence of class prior knowledge, so as to perform cost-sensitive deep PU learning.
Specifically, in order to overcome the above technical problems, the present invention provides a data classification method based on deep PU learning and class prior estimation, which includes:

step 1, acquiring a training set comprising a plurality of data samples, of which only a part are marked with class labels; inputting the training set simultaneously into a student model and a teacher model that share the same network structure but have different parameters, and obtaining the student prediction scores and teacher prediction scores output by the two models for the data samples;

step 2, inputting all teacher prediction scores into a Gaussian mixture model to obtain a positive-class prior; constructing a temperature sharpening loss based on all student prediction scores; constructing a consistency loss based on all student and teacher prediction scores; obtaining a non-negative PU risk based on the positive-class prior and all student prediction scores; combining the consistency loss, the non-negative PU risk and the temperature sharpening loss into a target loss; updating the parameters of the student model by gradient back-propagation of the target loss until the target converges or a preset number of iterations is reached; and saving the current student model or teacher model, for example whichever of the two performs better, as the data classification model;
and step 3, inputting the data to be classified into the data classification model to obtain the classification of the data to be classified.
When the data classification method is used for malicious URL detection, the data in the training set are URLs marked as malicious and unlabeled URLs, and the student model and the teacher model are both recurrent neural networks; when it is used for false comment detection, the data in the training set are comments marked as false and unlabeled comments, and the student model and the teacher model are both recurrent neural networks; when it is used for particle picking in cryo-electron microscopy, the data in the training set are particle regions marked as selected and unlabeled regions, and the student model and the teacher model are both convolutional neural networks.
In the data classification method based on deep PU learning and class prior estimation, the prediction scores of the student model and the teacher model are respectively:

S = sigmoid(f(X, Θ_t))
S′ = sigmoid(f(X, Θ′_t))

where X is the training set, S is the student prediction score output by the student model, S′ is the teacher prediction score output by the teacher model, Θ_t are the parameters of the student model at time t, and Θ′_t are the parameters of the teacher model at time t.
In the data classification method based on deep PU learning and class prior estimation, the consistency loss between the outputs of the student model and the teacher model is:

L_cons = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (sigmoid(f(x_i, Θ)) − sigmoid(f(x_i, Θ′)))²

where x_i ∈ X denotes the i-th sample in the training set X, c_i is its confidence, and N = |X|; 1(·) is the indicator function, which takes the value 1 when the condition (·) is satisfied and 0 otherwise; τ is the confidence threshold, Θ are the parameters of the student model, and Θ′ are the parameters of the teacher model.
In the data classification method based on deep PU learning and class prior estimation, the temperature sharpening loss is:

L_sharp = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (s_i − s_i^(1/T) / (s_i^(1/T) + (1 − s_i)^(1/T)))²

where T is the temperature of the class distribution and s_i is the student prediction score output by the student model for the i-th sample.
The invention also provides a data classification system based on deep PU learning and class prior estimation, which comprises:
the initial module is used for acquiring a training set comprising a plurality of data samples, of which only a part are marked with class labels, inputting the training set simultaneously into a student model and a teacher model that share the same network structure but have different parameters, and obtaining the student prediction scores and teacher prediction scores output by the two models for the data samples;
the training module is used for inputting all teacher prediction scores into the Gaussian mixture model to obtain a positive-class prior; constructing a temperature sharpening loss based on all student prediction scores; constructing a consistency loss based on all student and teacher prediction scores; obtaining a non-negative PU risk based on the positive-class prior and all student prediction scores; combining the consistency loss, the non-negative PU risk and the temperature sharpening loss into a target loss; updating the parameters of the student model by gradient back-propagation of the target loss until the target converges or a preset number of iterations is reached; and storing the current student model or teacher model as the data classification model;
and the classification module is used for inputting the data to be classified into the data classification model to obtain the classification of the data to be classified.
When the data classification system is used for malicious URL detection, the data in the training set are URLs marked as malicious and unlabeled URLs, and the student model and the teacher model are both recurrent neural networks; when it is used for false comment detection, the data in the training set are comments marked as false and unlabeled comments, and the student model and the teacher model are both recurrent neural networks; when it is used for particle picking in cryo-electron microscopy, the data in the training set are particle regions marked as selected and unlabeled regions, and the student model and the teacher model are both convolutional neural networks.
In the data classification system based on deep PU learning and class prior estimation, the training module computes the prediction scores of the student model and the teacher model respectively as:

S = sigmoid(f(X, Θ_t))
S′ = sigmoid(f(X, Θ′_t))

where X is the training set, S is the student prediction score output by the student model, S′ is the teacher prediction score output by the teacher model, Θ_t are the parameters of the student model at time t, and Θ′_t are the parameters of the teacher model at time t.
In the data classification system based on deep PU learning and class prior estimation, the consistency loss between the outputs of the student model and the teacher model is:

L_cons = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (sigmoid(f(x_i, Θ)) − sigmoid(f(x_i, Θ′)))²

where x_i ∈ X denotes the i-th sample in the training set X, c_i is its confidence, and N = |X|; 1(·) is the indicator function, which takes the value 1 when the condition (·) is satisfied and 0 otherwise; τ is the confidence threshold, Θ are the parameters of the student model, and Θ′ are the parameters of the teacher model.
In the data classification system based on deep PU learning and class prior estimation, the temperature sharpening loss is:

L_sharp = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (s_i − s_i^(1/T) / (s_i^(1/T) + (1 − s_i)^(1/T)))²

where T is the temperature of the class distribution and s_i is the student prediction score output by the student model for the i-th sample.
The invention also proposes a storage medium storing a program that executes any of the above data classification methods based on deep PU learning and class prior estimation.
The invention further provides a client for use with any of the above data classification systems based on deep PU learning and class prior estimation.
According to the scheme, the invention has the advantages that:
the invention provides an iterative deep PU learning and class prior estimation framework, which can simultaneously estimate class prior and utilize the obtained prior to estimate and learn a depth model without the actual prior distribution of known data, thereby being more suitable for the application of PU learning in actual scenes. The iteration framework provided by the invention comprises the following core key points: (1) modeling a prediction score of the network as a GMM to estimate a positive-class prior; (2) performing unbiased PU learning based on the estimated value of the positive-class prior; (3) and further, the performance and stability of the algorithm are improved by combining the technologies of an average teacher of semi-supervised learning, temperature sharpening and the like. The frame can be applied to PU problems in various fields including computer vision, recommendation systems, biological medicine and the like, has excellent effect and has scientific value and practical value.
Drawings
FIG. 1 is a block diagram of the process of the present invention.
Detailed Description
The aim of the invention is to solve the problem of how to perform unbiased PU learning in the absence of class prior knowledge. Existing unbiased PU learning methods assume by default that the class prior is known or easy to estimate, whereas in real PU problems the class prior is often unknown and difficult to estimate. In addition, existing class prior estimation algorithms are mainly designed for traditional machine learning classifiers and do not exploit the advantages of deep learning on large-scale datasets. To overcome these problems, the invention proposes an iterative deep PU learning framework based on an unsupervised mixture model. The method exploits the fact that the prediction scores a deep neural network assigns to samples of different classes (positive and negative) follow different distributions, and uses a Gaussian mixture model to approximately fit the mixture distribution of the prediction scores. Combined with common semi-supervised learning techniques and an optimization objective designed for the PU problem, it achieves classification performance comparable to PU algorithms given the true positive-class prior.
The invention comprises the following key technical points:
in the key point 1, before the deep neural network generates the overfitting phenomenon, the prediction scores of the positive samples and the prediction scores of the negative samples are distributed differently, the prediction scores of the positive samples are distributed in an interval with a higher score in a centralized manner, the prediction scores of the negative samples are distributed in an interval with a lower score in a centralized manner, and the two prediction scores form two bell-shaped curves with a high middle part and low two ends respectively. Based on the above observations, it is proposed to unsupervised Model the prediction scores using Gaussian Mixture Model (GMM). The GMM has less parameter quantity, is not restricted by PU learning missing negative class labels, and has both time complexity and space complexity required for solving. Therefore, the method occupies less computing resources and can be widely applied to data sets of various scales;
Key point 2: class prior estimation and PU learning are considered jointly, and an iterative solution to the PU problem is proposed, in which model training and GMM-based class prior estimation alternate. Ideally, as training proceeds, the better the classification performance of the deep neural network, the more reliable its prediction scores become; and the more reliable the prediction scores, the more accurate the positive-class prior estimated by the GMM, which in turn further improves the classification performance of the model. Model training and GMM class prior estimation thus give each other positive feedback, a property the iterative framework exploits;
Key point 3: common semi-supervised learning techniques, namely the Mean Teacher and Temperature Sharpening, are introduced into the framework. Without them, the class prior estimate is not stable enough: it keeps oscillating with large variance as the number of training epochs grows. With them, the mean teacher stabilizes the prediction scores by averaging historical parameters, while temperature sharpening pushes the prediction scores toward 0 or 1, making the bell-shaped curves of the different classes more distinguishable and improving the GMM fit. Together the two techniques stabilize the class prior estimate and effectively improve the classification performance of the algorithm.
In order to make the aforementioned features and effects of the present invention more comprehensible, embodiments accompanied with figures are described in detail below.
The invention provides a new iterative deep PU learning and class prior estimation framework, whose flow is shown in FIG. 1. First, a training set is acquired. For malicious URL detection, the acquired data are usually a small number of malicious URLs and a large number of unlabeled URLs. For false comment detection, recognized false comments often have significant features and can be regarded as positive samples, while comments whose authenticity is hard to determine are regarded as unlabeled samples. For particle picking in cryo-electron microscopy, the selected particles are positive samples; owing to the low signal-to-noise ratio of cryo-EM images and the diversity of the negative distribution, the unselected regions may contain both positive and negative examples, so the particles in those regions are unlabeled samples. The PU training set X is input into a student model f(·, Θ) and a teacher model f(·, Θ′) with the same network structure but different parameters (Θ′ is obtained from Θ by an Exponential Moving Average (EMA) operation); S = sigmoid(f(X, Θ)) and S′ = sigmoid(f(X, Θ′)) are the prediction scores of the student model and the teacher model respectively. The prediction scores S′ of the teacher model are fitted with a Gaussian mixture model (GMM) to estimate the positive-class prior π̂_p; with π̂_p the non-negative PU risk R_PU can be computed; combining the mean-teacher consistency loss L_cons and the temperature sharpening loss L_sharp yields the optimization objective L = R_PU + λ1·L_cons + λ2·L_sharp; finally, the parameters Θ of the student model are updated by gradient back-propagation. These steps are repeated iteratively until the optimization objective L converges. The computation of each step is explained in detail next.
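The whole loop can be exercised end-to-end on a toy problem. The following is a compact, runnable sketch, not the patent's implementation: the scorer is a 1-D linear model on synthetic Gaussian data, the GMM is fit only on the unlabeled teacher scores, the sigmoid surrogate loss is used, the consistency and sharpening terms are omitted, and all hyper-parameter values are arbitrary. It alternates GMM prior estimation, a gradient step on the non-negative PU risk, and an EMA teacher update:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic 1-D PU task: positives ~ N(2,1), negatives ~ N(-2,1), true prior 0.4.
x_p = rng.normal(2.0, 1.0, 200)                                   # labeled positives
x_u = np.concatenate([rng.normal(2.0, 1.0, 400),
                      rng.normal(-2.0, 1.0, 600)])                # unlabeled mixture

def sig(z):
    return 1.0 / (1.0 + np.exp(-z))

def gmm_prior(s, iters=30):
    # 2-component 1-D GMM on teacher scores; the mixing weight of the
    # higher-mean component estimates the positive-class prior.
    mu = np.array([np.quantile(s, 0.25), np.quantile(s, 0.75)])
    var, pi = np.full(2, s.var() + 1e-6), np.full(2, 0.5)
    for _ in range(iters):
        dens = np.exp(-(s[:, None] - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
        g = pi * dens
        g = g / (g.sum(axis=1, keepdims=True) + 1e-300)           # responsibilities
        nk = g.sum(axis=0) + 1e-12
        pi, mu = nk / len(s), (g * s[:, None]).sum(axis=0) / nk
        var = (g * (s[:, None] - mu) ** 2).sum(axis=0) / nk + 1e-6
    return pi[np.argmax(mu)]

w, b = 0.1, 0.0        # student: f(x) = w*x + b
wt, bt = w, b          # teacher: EMA of the student
eta, alpha = 0.5, 0.95

for step in range(200):
    pi_hat = gmm_prior(sig(wt * x_u + bt))            # (1) prior from teacher scores
    zp, zu = w * x_p + b, w * x_u + b
    dp, du = sig(zp) * sig(-zp), sig(zu) * sig(-zu)   # |l'(z)| for sigmoid loss
    # (2) gradient of R = pi*R_p+ + max(0, R_u- - pi*R_p-), with l(z) = sig(-z)
    active = sig(zu).mean() - pi_hat * sig(zp).mean() > 0.0
    gw = -pi_hat * (dp * x_p).mean() + active * ((du * x_u).mean() - pi_hat * (dp * x_p).mean())
    gb = -pi_hat * dp.mean() + active * (du.mean() - pi_hat * dp.mean())
    w, b = w - eta * gw, b - eta * gb                 # (3) student gradient step
    wt, bt = alpha * wt + (1 - alpha) * w, alpha * bt + (1 - alpha) * b  # (4) EMA

labels = np.r_[np.ones(400, bool), np.zeros(600, bool)]
acc = ((w * x_u + b > 0) == labels).mean()            # accuracy on the unlabeled pool
```

On this toy data the loop recovers a prior estimate close to 0.4 and a near-Bayes-optimal linear boundary, illustrating the positive feedback between prior estimation and classifier training described in key point 2.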
For text tasks, such as malicious URL detection and false comment detection, the student model and teacher model may employ a recurrent neural network; for image tasks, such as particle picking in cryo-electron microscopy, they may employ a convolutional neural network. The model, the confidence threshold τ, the temperature T and the hyper-parameters λ1, λ2 can be set freely according to the actual application. The prediction score can be viewed as the positive-class probability given by the model; it is obtained by passing a sample through the deep neural network followed by a sigmoid transformation.
(1) Calculating the prediction scores S, S′;
given a triplet (x, y, z), x is the input feature, y is its class label, and z ∈ {1, 0} indicates the presence or absence of its class label. In PU learning, the PU training set is typically composed of several tuples (x, z) because the true y is unknown. Furthermore, the labeled samples are all positive, i.e.There is Pr (y 1| z 1) ═ 1. Sample set X ═ Xl∪XuXl is a subset of positive samples with labels, XuIs a subset of unlabeled exemplars. the parameters of the student model and the teacher model at the time t are respectively thetat,Θ′tThe corresponding prediction functions are defined as f (·, Θ)t),f(·,Θ′t):Then, the respective prediction scores of the student model and the teacher model are respectively:
S=sigmoid(f(X,Θt)),
S′=sigmoid(f(X,Θ′t)).
(2) Modeling S′ with a GMM to obtain the estimate π̂_p of the positive-class prior π_p. Since the maximum-likelihood solution of the GMM has no closed form, an iterative approximation with the EM algorithm is required (see the E/M steps and the parameter update equations). Modeling S′ with a GMM to estimate the positive-class prior is one of the inventive points of this application; the technical progress it brings is applicability to the large-scale datasets of deep learning scenarios (which previous methods lack), with low time and space complexity for solving the GMM. In addition, the application's distinctive techniques include a positive-class prior estimation smoother (the mean teacher and temperature sharpening), whose contribution is a more accurate and smoother prior estimation process; and the interleaving of positive-class prior estimation with deep model training, whose contribution is that the two promote each other, making the prior estimate more accurate and the classification performance of the model better.
The GMM is an unsupervised modeling method. It models the prediction score S′ as follows:

p(s′) = π_p · N(s′; μ_p, σ_p²) + π_n · N(s′; μ_n, σ_n²)

where the class label y is a hidden variable; π_p is the mixing coefficient and also represents the positive-class prior; π_n = 1 − π_p; N(s′; μ_p, σ_p²) and N(s′; μ_n, σ_n²) are the Gaussian distributions followed by the positive- and negative-sample prediction scores respectively; for a Gaussian distribution N(·; μ, σ²), μ and σ² denote its mean and variance. Since the prediction scores of positive samples are overall larger than those of negative samples, μ_n < μ_p.
The GMM is typically solved with the Expectation-Maximization (EM) algorithm. Let Φ^(t) = {π_p^(t), μ_p^(t), σ_p^(t), μ_n^(t), σ_n^(t)} denote the parameters of the GMM at step t, and let γ^(t)(s_i) denote the conditional probability, under the parameters Φ^(t), that score s_i is generated by the positive component. EM alternately iterates (1) the expectation (E) step, i.e. computing

γ^(t)(s_i) = π_p^(t) N(s_i; μ_p^(t), (σ_p^(t))²) / (π_p^(t) N(s_i; μ_p^(t), (σ_p^(t))²) + π_n^(t) N(s_i; μ_n^(t), (σ_n^(t))²)),

and (2) the maximization (M) step, i.e.

Φ^(t+1) = argmax_Φ Σ_i [ γ^(t)(s_i) log(π_p N(s_i; μ_p, σ_p²)) + (1 − γ^(t)(s_i)) log(π_n N(s_i; μ_n, σ_n²)) ],

until convergence. Following the literature, the parameter update Φ^(t+1) can be expanded into the closed form:

π_p^(t+1) = (1/N) Σ_i γ^(t)(s_i),
μ_p^(t+1) = Σ_i γ^(t)(s_i) s_i / Σ_i γ^(t)(s_i),
(σ_p^(t+1))² = Σ_i γ^(t)(s_i) (s_i − μ_p^(t+1))² / Σ_i γ^(t)(s_i),

with symmetric updates for the negative component using 1 − γ^(t)(s_i).
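For a two-component one-dimensional mixture, the EM recursion is a few lines of numpy. The following from-scratch sketch (the quantile-based initialization and variance floor are implementation choices, not specified in the patent) returns the mixing weight of the higher-mean component as the positive-class prior estimate:

```python
import numpy as np

def estimate_prior_gmm(s, iters=100):
    """EM for a 2-component 1-D Gaussian mixture over prediction scores s;
    the mixing weight of the higher-mean component estimates pi_p."""
    s = np.asarray(s, dtype=float)
    mu = np.array([np.quantile(s, 0.25), np.quantile(s, 0.75)])  # init means
    var = np.full(2, s.var() + 1e-6)                             # init variances
    pi = np.array([0.5, 0.5])                                    # init weights
    for _ in range(iters):
        # E-step: responsibilities gamma[i, k] = Pr(component k | s_i)
        dens = np.exp(-(s[:, None] - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
        num = pi * dens
        gamma = num / (num.sum(axis=1, keepdims=True) + 1e-300)
        # M-step: closed-form weight, mean and variance updates
        nk = gamma.sum(axis=0) + 1e-12
        pi = nk / len(s)
        mu = (gamma * s[:, None]).sum(axis=0) / nk
        var = (gamma * (s[:, None] - mu) ** 2).sum(axis=0) / nk + 1e-6
    return pi[np.argmax(mu)]

# Scores drawn from a 0.3 / 0.7 mixture of "positive" and "negative" bells.
rng = np.random.default_rng(1)
scores = np.concatenate([rng.normal(0.8, 0.05, 300), rng.normal(0.2, 0.05, 700)])
prior_hat = estimate_prior_gmm(scores)
```

Because the model has only five free parameters and the data are one-dimensional scores, each EM sweep is linear in the number of samples, which is the low time and space complexity the text refers to.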
The positive-class prior estimate π̂_p obtained in step (2) makes it possible to compute the non-negative PU risk:

R_PU = π̂_p · R_p^+ + max(0, R_u^- − π̂_p · R_p^-),

where R_p^+ and R_p^- are the empirical risks of the labeled positive samples with respect to the positive and negative class respectively, R_u^- is the empirical risk of the unlabeled samples treated as negative, and the max(0, ·) keeps the estimated negative-class risk non-negative.
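The non-negative PU risk can be written as a small numpy function; this sketch uses the sigmoid surrogate loss, which is an illustrative choice since the patent does not fix a particular surrogate:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nnpu_risk(z_pos, z_unl, pi_p):
    """Non-negative PU risk with surrogate loss l(z, y) = sigmoid(-y*z).
    z_pos: raw scores f(x) on labeled positives; z_unl: on unlabeled samples."""
    r_p_pos = sigmoid(-z_pos).mean()      # labeled positives, treated as +1
    r_p_neg = sigmoid(z_pos).mean()       # labeled positives, treated as -1
    r_u_neg = sigmoid(z_unl).mean()       # unlabeled samples, treated as -1
    # Clipping at zero prevents the negative-risk estimate from going
    # below zero when the model starts to overfit the positives.
    return pi_p * r_p_pos + max(0.0, r_u_neg - pi_p * r_p_neg)

# Confident scorer on a 40/60 positive/negative unlabeled pool.
z_pos = np.full(10, 4.0)
z_unl = np.concatenate([np.full(4, 4.0), np.full(6, -4.0)])
risk = nnpu_risk(z_pos, z_unl, 0.4)
```

With an overestimated prior the unclipped term would be negative; the max(0, ·) guarantees the returned risk never is.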
Define the confidence c of the model output as the maximum class probability of a sample given by the student model, i.e.:

c = max(s, 1 − s).
Using confidence-based masking with confidence threshold τ, the consistency loss between the student model and teacher model outputs is:

L_cons = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (sigmoid(f(x_i, Θ)) − sigmoid(f(x_i, Θ′)))²,

where x_i ∈ X denotes the i-th sample in the sample set X, c_i is its confidence, and N = |X|; 1(·) is the indicator function, which takes the value 1 when the condition (·) is satisfied and 0 otherwise. The consistency loss stabilizes the class prior estimation process and ultimately improves the classification performance of the deep model.
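A direct numpy transcription of this masked consistency term, assuming a mean-squared penalty between student and teacher scores (one common choice for mean-teacher consistency; the function name is illustrative):

```python
import numpy as np

def consistency_loss(s_student, s_teacher, tau=0.95):
    """Masked MSE between student and teacher scores: only samples whose
    student confidence c_i = max(s, 1-s) reaches tau contribute; the sum
    is still divided by the full N = |X|."""
    c = np.maximum(s_student, 1.0 - s_student)     # confidence per sample
    mask = (c >= tau).astype(float)                # indicator 1(c_i >= tau)
    return (mask * (s_student - s_teacher) ** 2).sum() / len(s_student)

s_s = np.array([0.99, 0.50, 0.02])                 # student scores
s_t = np.array([0.97, 0.90, 0.10])                 # teacher scores
loss = consistency_loss(s_s, s_t, tau=0.9)         # middle sample is masked out
```

The middle sample has confidence 0.5 and is excluded, so only the two confident predictions are pulled toward the teacher.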
Given a prediction score s, a sharpening function is used to reduce the information entropy of its class distribution. Temperature sharpening achieves this by adjusting the temperature T of the class distribution, as follows:

sharpen(s, T) = s^(1/T) / (s^(1/T) + (1 − s)^(1/T)).

Combined with the confidence-based masking technique mentioned for the mean teacher, only trusted outputs are temperature sharpened, so the loss L_sharp is:

L_sharp = (1/N) Σ_{i=1..N} 1(c_i ≥ τ) (s_i − sharpen(s_i, T))².

The overall optimization objective is then

L = R_PU + λ1 · L_cons + λ2 · L_sharp,

where λ1 and λ2 are hyper-parameters.
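The sharpening operator itself is a one-liner; a sketch assuming the usual binary form s^(1/T) / (s^(1/T) + (1 − s)^(1/T)):

```python
def sharpen(s, T=0.5):
    """Temperature sharpening of a binary score: for T < 1 the score is
    pushed toward 0 or 1 (lower entropy); T = 1 leaves it unchanged.
    Works on floats and elementwise on numpy arrays."""
    p = s ** (1.0 / T)
    return p / (p + (1.0 - s) ** (1.0 / T))
```

Scores above 0.5 move up, scores below 0.5 move down, and 0.5 is a fixed point, which is exactly what makes the two bell-shaped score clusters more separable for the GMM fit.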
(3) Updating model parameters theta, theta'
Update the parameters Θ of the student model by the gradient back-propagation algorithm:

Θ_{t+1} = Θ_t − η ∂L̂/∂Θ

where η denotes the learning rate and ∂L̂/∂Θ denotes the derivative of the optimization objective L̂ with respect to the model parameters Θ. Update the parameters Θ′ of the teacher model by the EMA operation:
Θ′_{t+1} = α Θ′_t + (1 − α) Θ_{t+1}.
where α is a smoothing coefficient taking a value in [0, 1].
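The EMA operation can be sketched as a per-parameter update; `ema_update` is an illustrative helper (in a real framework it would run over the model's parameter tensors after each optimizer step):

```python
import numpy as np

def ema_update(teacher_params, student_params, alpha=0.99):
    """EMA teacher update: Theta'_{t+1} = alpha * Theta'_t + (1 - alpha) * Theta_{t+1}.

    Accepts any list of parameter arrays (one entry per layer).
    """
    return [alpha * np.asarray(tp) + (1.0 - alpha) * np.asarray(sp)
            for tp, sp in zip(teacher_params, student_params)]
```

With α close to 1 the teacher changes slowly and averages the student's trajectory; α = 0 would copy the student outright.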
Steps (1)-(4) are executed iteratively until the optimization objective L̂ converges. Upon convergence, π̂_p is the final positive-class prior estimate, and f(·, Θ) and f(·, Θ′) are the resulting classifiers.
The following are system embodiments corresponding to the above method embodiments, and this embodiment can be implemented in cooperation with the embodiments above. The technical details mentioned in the above embodiments remain valid in this embodiment and, to reduce repetition, are not repeated here; conversely, the technical details described in this embodiment are also applicable to the above embodiments.
The invention also provides a data classification system based on deep PU learning and class prior estimation, which comprises:
an initial module, configured to obtain a training set comprising a plurality of data samples, of which only some are labeled with class labels, and to input the training set simultaneously into a student model and a teacher model that share the same network structure but have different parameters, obtaining the student prediction scores and the teacher prediction scores that the student model and the teacher model output for each data sample;
a training module, configured to input all teacher prediction scores into a Gaussian mixture model to obtain the positive-class prior; to construct the temperature sharpening loss from all student prediction scores; to construct the consistency loss from all student prediction scores and teacher prediction scores; to obtain the non-negative PU risk from the positive-class prior and all student prediction scores; to combine the consistency loss, the non-negative PU risk, and the temperature sharpening loss into a target loss; and to update the parameters of the student model by gradient back-propagation based on the target loss until the target loss converges or a preset number of iterations is reached, saving the current student model or teacher model as the data classification model;
and the classification module is used for inputting the data to be classified into the data classification model to obtain the classification of the data to be classified.
When the data classification system is used for malicious URL detection, the data in the training set are URLs labeled with the malicious class and unlabeled URLs, and the student model and the teacher model are both recurrent neural networks; when it is used for fake-review detection, the data in the training set are reviews labeled with the fake class and unlabeled reviews, and the student model and the teacher model are both recurrent neural networks; when it is used for particle picking in cryo-electron microscopy, the data in the training set are particle regions labeled with the picked class and unlabeled particle regions, and the student model and the teacher model are both convolutional neural networks.
In the data classification system based on deep PU learning and class prior estimation, the training module computes the respective prediction scores of the student model and the teacher model by the following formulas:
S = sigmoid(f(X, Θ_t))
S′ = sigmoid(f(X, Θ′_t))
where X is the training set, S is the student prediction score output by the student model, S′ is the teacher prediction score output by the teacher model, Θ_t is the parameter of the student model at time t, and Θ′_t is the parameter of the teacher model at time t.
In the data classification system based on deep PU learning and class prior estimation, the consistency loss L_con between the outputs of the student model and the teacher model is:

L_con(X; Θ, Θ′) = (1/N) Σ_{i=1}^N 1(c_i > τ) (s_i − s′_i)²

where x_i ∈ X denotes the i-th sample in the training set X, s_i and s′_i are its student and teacher prediction scores, c_i is its confidence, and N = |X|; 1(·) is the indicator function, whose value is 1 when the condition holds and 0 otherwise; τ is the confidence threshold, Θ is the parameter of the student model, and Θ′ is the parameter of the teacher model.
In the data classification system based on deep PU learning and class prior estimation, the temperature sharpening loss L_sharp is:

L_sharp(X; Θ) = (1/N) Σ_{i=1}^N 1(c_i > τ) (s_i − Sharpen(s_i, T))²

where T is the temperature of the class distribution and s is the student prediction score output by the student model.
The invention also proposes a storage medium storing a program for executing any of the above data classification methods based on deep PU learning and class prior estimation.
The invention further provides a client for use with any of the above data classification systems based on deep PU learning and class prior estimation.
Claims (12)
1. A data classification method based on deep PU learning and class prior estimation is characterized by comprising the following steps:
step 1, obtaining a training set comprising a plurality of data samples, of which only some are labeled with class labels, and inputting the training set simultaneously into a student model and a teacher model that share the same network structure but have different parameters, to obtain the student prediction scores and the teacher prediction scores that the student model and the teacher model output for each data sample;
step 2, inputting all teacher prediction scores into a Gaussian mixture model to obtain the positive-class prior; constructing the temperature sharpening loss from all student prediction scores; constructing the consistency loss from all student prediction scores and teacher prediction scores; obtaining the non-negative PU risk from the positive-class prior and all student prediction scores; combining the consistency loss, the non-negative PU risk, and the temperature sharpening loss to obtain a target loss; and updating the parameters of the student model by gradient back-propagation based on the target loss until the target loss converges or a preset number of iterations is reached, saving the current student model or teacher model as the data classification model;
and 3, inputting the data to be classified into the data classification model to obtain the classification of the data to be classified.
2. The data classification method based on deep PU learning and class prior estimation according to claim 1, wherein, when used for malicious URL detection, the data in the training set are URLs labeled with the malicious class and unlabeled URLs, and the student model and the teacher model are both recurrent neural networks; when used for fake-review detection, the data in the training set are reviews labeled with the fake class and unlabeled reviews, and the student model and the teacher model are both recurrent neural networks; when used for particle picking in cryo-electron microscopy, the data in the training set are particle regions labeled with the picked class and unlabeled particle regions, and the student model and the teacher model are both convolutional neural networks.
3. The data classification method based on deep PU learning and class prior estimation according to claim 1 or 2, wherein the step 2 comprises:
computing the respective prediction scores of the student model and the teacher model as follows:
S = sigmoid(f(X, Θ_t))
S′ = sigmoid(f(X, Θ′_t))
4. The method according to claim 1 or 2, wherein the consistency loss L_con between the output of the student model and the output of the teacher model is:

L_con(X; Θ, Θ′) = (1/N) Σ_{i=1}^N 1(c_i > τ) (s_i − s′_i)²

where x_i ∈ X denotes the i-th sample in the training set X, s_i and s′_i are its student and teacher prediction scores, c_i is its confidence, and N = |X|; 1(·) is the indicator function, whose value is 1 when the condition holds and 0 otherwise; τ is the confidence threshold, Θ is the parameter of the student model, and Θ′ is the parameter of the teacher model.
6. A data classification system based on deep PU learning and class prior estimation, comprising:
the initial module is used for acquiring a training set comprising a plurality of data samples, only part of the data samples in the training set are marked with class labels, the training set is simultaneously input into two student models and a teacher model which have the same network structure and different parameters, and student prediction scores and teacher prediction scores corresponding to the data samples output by the student models and the teacher model are respectively obtained;
the training module is used for inputting all teacher prediction scores into the Gaussian mixture model to obtain a positive prior; constructing temperature sharpening loss based on all student prediction scores; constructing consistency loss based on all student prediction scores and teacher prediction scores; obtaining a non-negative PU risk based on the positive-class prior and all student prediction scores, combining the consistency loss, the non-negative PU risk and the temperature sharpening loss to obtain a target loss, updating parameters of the student model by using gradient back propagation based on the target loss until the target converges or reaches a preset iteration number, and storing a current student model or teacher model as a data classification model;
and the classification module is used for inputting the data to be classified into the data classification model to obtain the classification of the data to be classified.
7. The data classification system based on deep PU learning and class prior estimation according to claim 6, wherein, when used for malicious URL detection, the data in the training set are URLs labeled with the malicious class and unlabeled URLs, and the student model and the teacher model are both recurrent neural networks; when used for fake-review detection, the data in the training set are reviews labeled with the fake class and unlabeled reviews, and the student model and the teacher model are both recurrent neural networks; when used for particle picking in cryo-electron microscopy, the data in the training set are particle regions labeled with the picked class and unlabeled particle regions, and the student model and the teacher model are both convolutional neural networks.
8. The deep PU learning and class prior estimation based data classification system of claim 6 or 7, wherein the training module is configured to calculate the respective prediction scores of the student model and the teacher model based on the following formula:
S = sigmoid(f(X, Θ_t))
S′ = sigmoid(f(X, Θ′_t))
9. The system according to claim 6 or 7, wherein the consistency loss L_con between the outputs of the student model and the teacher model is:

L_con(X; Θ, Θ′) = (1/N) Σ_{i=1}^N 1(c_i > τ) (s_i − s′_i)²

where x_i ∈ X denotes the i-th sample in the training set X, s_i and s′_i are its student and teacher prediction scores, c_i is its confidence, and N = |X|; 1(·) is the indicator function, whose value is 1 when the condition holds and 0 otherwise; τ is the confidence threshold, Θ is the parameter of the student model, and Θ′ is the parameter of the teacher model.
11. A storage medium storing a program for executing the data classification method based on deep PU learning and class prior estimation according to any one of claims 1 to 5.
12. A client for use in the data classification system of any one of claims 6 to 10 based on deep PU learning and class prior estimation.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111591020.7A CN114417975A (en) | 2021-12-23 | 2021-12-23 | Data classification method and system based on deep PU learning and class prior estimation |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114417975A true CN114417975A (en) | 2022-04-29 |
Family
ID=81266728
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115859106A (en) * | 2022-12-05 | 2023-03-28 | 中国地质大学(北京) | Mineral exploration method and device based on semi-supervised learning and storage medium |
CN117574258A (en) * | 2024-01-15 | 2024-02-20 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | Text classification method based on text noise labels and collaborative training strategies |
CN117574258B (en) * | 2024-01-15 | 2024-04-26 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | Text classification method based on text noise labels and collaborative training strategies |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112308158B (en) | Multi-source field self-adaptive model and method based on partial feature alignment | |
CN113378632B (en) | Pseudo-label optimization-based unsupervised domain adaptive pedestrian re-identification method | |
CN109190524B (en) | Human body action recognition method based on generation of confrontation network | |
US20220188568A1 (en) | Methods and systems for mining minority-class data samples for training a neural network | |
CN109376242B (en) | Text classification method based on cyclic neural network variant and convolutional neural network | |
CN111126488B (en) | Dual-attention-based image recognition method | |
CN106778796B (en) | Human body action recognition method and system based on hybrid cooperative training | |
CN113326731B (en) | Cross-domain pedestrian re-identification method based on momentum network guidance | |
CN114492574A (en) | Pseudo label loss unsupervised countermeasure domain adaptive picture classification method based on Gaussian uniform mixing model | |
CN112085055B (en) | Black box attack method based on transfer model Jacobian array feature vector disturbance | |
CN110929848B (en) | Training and tracking method based on multi-challenge perception learning model | |
CN111564179B (en) | Species biology classification method and system based on triple neural network | |
CN110097060B (en) | Open set identification method for trunk image | |
CN107945210B (en) | Target tracking method based on deep learning and environment self-adaption | |
CN114417975A (en) | Data classification method and system based on deep PU learning and class prior estimation | |
CN109840595B (en) | Knowledge tracking method based on group learning behavior characteristics | |
CN110728694A (en) | Long-term visual target tracking method based on continuous learning | |
CN112232395B (en) | Semi-supervised image classification method for generating countermeasure network based on joint training | |
CN112784921A (en) | Task attention guided small sample image complementary learning classification algorithm | |
CN113743474A (en) | Digital picture classification method and system based on cooperative semi-supervised convolutional neural network | |
Qiao et al. | A multi-level thresholding image segmentation method using hybrid Arithmetic Optimization and Harris Hawks Optimizer algorithms | |
Demirel et al. | Meta-tuning loss functions and data augmentation for few-shot object detection | |
CN116152554A (en) | Knowledge-guided small sample image recognition system | |
Wang et al. | Prototype-based intent perception | |
Xia et al. | Detecting smiles of young children via deep transfer learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||