CN113673242A - Text classification method based on K-neighborhood node algorithm and comparative learning - Google Patents
- Publication number
- CN113673242A
- Authority
- CN
- China
- Prior art keywords
- encoder
- training
- classification
- representing
- samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/20—Natural language analysis
- G06F40/279—Recognition of textual entities
- G06F40/289—Phrasal analysis, e.g. finite state techniques or chunking
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/20—Natural language analysis
- G06F40/205—Parsing
- G06F40/211—Syntactic parsing, e.g. based on context-free grammar [CFG] or unification grammars
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Probability & Statistics with Applications (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
The invention discloses a text classification method based on a K-nearest-neighbor (KNN) algorithm and contrastive learning. In the training stage, contrastive learning is used to shorten intra-class distances and enlarge inter-class distances, and a cross-entropy loss is combined with the contrastive loss for joint training. During inference, the jointly trained model and the KNN algorithm make a joint prediction to determine the class of the text to be inferred. The method achieves higher text classification accuracy than the text classification approaches currently used in industry and greatly improves the robustness of the model.
Description
Technical Field
The invention relates to deep learning and natural language processing, and in particular to a text classification method based on a K-nearest-neighbor algorithm and contrastive learning.
Background
Text classification is a fundamental task in natural language processing. The current mainstream approach places a linear classifier on top of a large-scale pre-trained model such as BERT. However, linear classifiers tend to be less robust and are easily fooled by adversarial attacks such as TextFooler or BERT-Attack.
Disclosure of Invention
To overcome these shortcomings of the prior art and improve the robustness of the model while also improving its classification accuracy, the invention adopts the following technical solution:
A text classification method based on a K-nearest-neighbor algorithm and contrastive learning comprises the following steps:
S1: During training, sentence vectors are constructed as the representations of the positive and negative samples k, and contrastive learning is performed to shorten intra-class distances and enlarge inter-class distances. The contrastive learning loss function is:
$$L_{sc} = -\frac{1}{M}\sum_{j=1}^{M}\log\frac{\exp(q \cdot k_j/\tau)}{\sum_{i}\exp(q \cdot k_i/\tau)}$$
where M is the number of positive samples, N is the number of negative samples, q is the sentence vector output by the pre-trained encoder encoder_q, k is the sentence vector output by the pre-trained encoder encoder_k (encoder_q and encoder_k are the same pre-trained model), k_j is the j-th positive sample k⁺, k_i ranges over the negative samples k⁻ together with k_j, exp(·) is the exponential function, and τ is a hyperparameter;
joint training is performed together with a cross-entropy loss function, and the joint loss function is:
$$L_{ec} = -\sum_{c=1}^{C} y_c \log \mathrm{Softmax}\big(F(q)\big)_c$$
$$L = \lambda L_{ec} + (1-\lambda) L_{sc}$$
where λ is the weight parameter balancing the cross-entropy loss L_ec and the contrastive loss L_sc, y_c is the label of q for class c (1 if q belongs to class c, otherwise 0), C is the number of text classes, and F(·) is a linear classifier;
the parameters of encoder_q and the linear classifier are updated by back-propagating the loss function;
the joint loss function is a weighted sum of the cross-entropy loss and the supervised contrastive loss, in which the contrastive loss L_sc assists the cross-entropy loss in training the model. Training with contrastive learning makes the model automatically cluster the embedding representations of the samples during training, which yields a better classification result;
S2: the text is classified with the trained encoder_q and linear classifier.
Further, in S2, the trained encoder_q produces the sentence vector representation q of the text to be predicted, and the text class is predicted with a joint prediction function, as follows:
where S is the probability value of the final classification, the coefficient weighting the two classifiers is a hyperparameter, Softmax(·) is the activation function, F(q) is the output of the trained linear classifier, and KNN(q) selects from the queue Q the K training samples closest to q in the sample space and gives the probability values of the KNN model by voting according to their class labels; the classification result is obtained from these probability values. When inferring the class of a sample, KNN and the linear classifier jointly predict it, and the K-nearest-neighbor algorithm markedly improves the robustness of the model.
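The joint prediction can be sketched in plain Python as below. Because the joint prediction formula itself is not reproduced in this text, the sketch assumes a convex combination S = mu·KNN(q) + (1 - mu)·Softmax(F(q)), where `mu` is a hypothetical name for the weighting hyperparameter; the logits F(q) and the KNN vote probabilities are taken as precomputed inputs.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over the linear classifier's logits F(q)."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def joint_predict(logits: Sequence[float], knn_probs: Sequence[float], mu: float) -> List[float]:
    """Blend KNN vote probabilities with Softmax(F(q)).

    ASSUMPTION: S = mu * KNN(q) + (1 - mu) * Softmax(F(q)); `mu` is a
    hypothetical name because the original weighting formula is not recoverable.
    """
    if not 0.0 <= mu <= 1.0:
        raise ValueError("mu must lie in [0, 1]")
    if len(logits) != len(knn_probs):
        raise ValueError("logits and knn_probs must cover the same classes")
    linear_probs = softmax(logits)
    return [mu * p_knn + (1.0 - mu) * p_lin
            for p_knn, p_lin in zip(knn_probs, linear_probs)]


def predict_class(scores: Sequence[float]) -> int:
    """Return the index of the class with the highest joint probability S."""
    return max(range(len(scores)), key=lambda c: scores[c])


if __name__ == "__main__":
    # The linear classifier favours class 0, the KNN vote favours class 1.
    logits, knn_probs = [2.0, 0.0], [0.2, 0.8]
    for mu in (0.0, 0.5, 1.0):
        print(mu, predict_class(joint_predict(logits, knn_probs, mu)))
```

Under this assumed form, mu = 1 reduces to the KNN vote alone and mu = 0 to the plain linear classifier, which makes the two extremes easy to compare.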
Further, the K training samples closest to q are selected. Since their class labels are known, K = s_1 + s_2 + … + s_c, where s_i is the number of these samples whose class label belongs to the i-th class and c is the number of classes of the training samples; the KNN model gives the probability that q belongs to class y_i as s_i / K.
Further, the similarity between q and a training sample is computed with the cosine function.
Further, in S1, the parameters of encoder_k are updated by momentum with the hyperparameter m:
θ_k ← mθ_k + (1-m)θ_q
where θ_k denotes the parameters of encoder_k and θ_q denotes the parameters of encoder_q. In each batch iteration, the keys k encoded by encoder_k are stored in a queue Q, and encoder_k is updated by momentum in each iteration so that the k it produces stays close to the q obtained directly from encoder_q.
Further, the queue Q replaces its elements k in sequential order.
Further, M elements k with the same class label as the sample are taken from the queue Q as positive samples k⁺, and N elements k with class labels different from the sample's are taken as negative samples k⁻.
The advantages and beneficial effects of the invention are as follows:
the method greatly improves the robustness of the model and also improves its accuracy. In addition, so that the K-nearest-neighbor algorithm can be used to predict the class of a sample, contrastive learning is added during training to pull samples of the same class closer together. Meanwhile, the MoCo training scheme is introduced into contrastive learning, which greatly increases the number of positive and negative samples.
Drawings
FIG. 1 is a flow chart of the method of the present invention.
FIG. 2 is a line graph of the effect of the λ value on model accuracy on different data sets.
FIG. 3 is a line graph of the effect of the joint-prediction weighting hyperparameter on model accuracy on different data sets.
FIG. 4a shows the sample space distribution of a general linear classifier.
FIG. 4b shows the sample space distribution of KNN-BERT of the present invention.
FIG. 5 is a graph comparing the results of the model classification accuracy test of the present invention.
FIG. 6 is a graph comparing the results of the model robustness tests of the present invention.
Detailed Description
Embodiments of the invention are described in detail below with reference to the accompanying drawings. It should be understood that the detailed description and specific examples are given only to illustrate and explain the invention, not to limit it.
As shown in FIG. 1, a text classification method based on a K-nearest-neighbor algorithm and contrastive learning comprises the following steps:
Part 1: the model training process, which comprises the following steps:
Step 1.1: the pre-trained model BERT is used as the sample encoder encoder_q, and the same pre-trained BERT model is used as the sample encoder encoder_k.
Step 1.2: the parameters of encoder_k are updated with the hyperparameter m = 0.999; the momentum update formula is:
θ_k ← mθ_k + (1-m)θ_q
where θ_k denotes the parameters of the sample encoder encoder_k and θ_q denotes the parameters of the sample encoder encoder_q. Conventional contrastive learning selects positive and negative samples within a batch, so too few of them are available during training. MoCo instead uses momentum updating: in each batch iteration, the samples encoded by the encoder are stored in a queue, and to keep the representations in the queue consistent, encoder_k is updated by momentum in each iteration so that its outputs stay close to the representations obtained directly from encoder_q;
the encoders encoder_q and encoder_k are updated iteratively, once per batch of training data.
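The momentum update of step 1.2 is simple enough to show directly. This is a minimal sketch, assuming encoder weights are represented as a dictionary of flat float lists (a hypothetical stand-in for real tensors); only encoder_k is modified, because encoder_q is trained by back-propagation.

```python
from typing import Dict, List

# Hypothetical stand-in for encoder weights: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def momentum_update(theta_k: Params, theta_q: Params, m: float = 0.999) -> None:
    """Update encoder_k in place: theta_k <- m * theta_k + (1 - m) * theta_q.

    theta_q (encoder_q) is only read; it is trained by back-propagation instead.
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError("m must lie in [0, 1]")
    for name, key_weights in theta_k.items():
        query_weights = theta_q[name]
        if len(key_weights) != len(query_weights):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        for i, (w_k, w_q) in enumerate(zip(key_weights, query_weights)):
            key_weights[i] = m * w_k + (1.0 - m) * w_q


if __name__ == "__main__":
    theta_q = {"dense": [1.0, 2.0]}
    theta_k = {"dense": [0.0, 0.0]}
    for _ in range(3):
        momentum_update(theta_k, theta_q, m=0.999)
    print(theta_k["dense"])  # still close to 0: encoder_k drifts slowly toward encoder_q
```

With m = 0.999, encoder_k changes very slowly, which is what keeps the keys already stored in the queue comparable with newly encoded ones. In a PyTorch setting the same update would typically loop over paired encoder parameters inside `torch.no_grad()`, as the MoCo reference implementation does.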
Step 1.3: encoder_q is used to obtain the sentence vector [CLS]_q (i.e., the sentence representation q of the sample), and encoder_k is used to obtain the sentence vector [CLS]_k (i.e., the sentence representation k of the sample).
For example, when encoding the training sentence "Beijing is the capital of China", BERT adds a [CLS] token at the beginning of the sentence and a [SEP] token at the end. The embedding vector of [CLS] is typically used as the representation of the entire sentence.
Step 1.4: [CLS]_k is stored in a queue Q of size 32000, and the elements of Q are replaced in sequential order;
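A fixed-size queue for step 1.4 can be sketched with `collections.deque`. The sketch assumes that "replaced in sequential order" means first-in, first-out, and that each key is stored together with its sample's label so that step 1.5 can select positives and negatives; the class name `KeyQueue` is hypothetical.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple

# One queue entry: (key vector [CLS]_k, class label of the sample it came from).
Entry = Tuple[List[float], int]


class KeyQueue:
    """Fixed-capacity queue Q of encoded keys.

    ASSUMPTIONS: "replaced in sequential order" is read as first-in, first-out,
    and each key is stored with its label so step 1.5 can pick positives/negatives.
    """

    def __init__(self, capacity: int = 32000) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        # deque(maxlen=...) drops the oldest entry automatically once full.
        self._entries: Deque[Entry] = deque(maxlen=capacity)

    def enqueue_batch(self, keys: Sequence[Sequence[float]], labels: Sequence[int]) -> None:
        """Append one batch of [CLS]_k vectors and their labels."""
        if len(keys) != len(labels):
            raise ValueError("keys and labels must have the same length")
        for key, label in zip(keys, labels):
            self._entries.append((list(key), label))

    def entries(self) -> List[Entry]:
        """Snapshot of the queue, oldest entry first."""
        return list(self._entries)

    def __len__(self) -> int:
        return len(self._entries)
```

Using `maxlen` keeps eviction O(1) per key and avoids manual index bookkeeping.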
Step 1.5: M samples whose labels match the sample's label are taken from Q as positive samples k⁺, and N samples whose labels differ from it are taken as negative samples k⁻;
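Step 1.5 amounts to filtering the queue by label. In this sketch the function name is hypothetical and the queue is assumed to hold (key, label) pairs; the text does not say which M and N keys are chosen when more are available, so the sketch simply takes the first matches in queue order.

```python
from typing import List, Sequence, Tuple

Entry = Tuple[List[float], int]  # (key vector from queue Q, class label)


def select_positives_negatives(
    queue_entries: Sequence[Entry], label: int, num_pos: int, num_neg: int
) -> Tuple[List[List[float]], List[List[float]]]:
    """Pick up to M same-label keys (k+) and up to N different-label keys (k-).

    ASSUMPTION: the first matches in queue order are taken.
    """
    positives = [key for key, key_label in queue_entries if key_label == label][:num_pos]
    negatives = [key for key, key_label in queue_entries if key_label != label][:num_neg]
    return positives, negatives
```

If fewer than M keys share the label, this version returns fewer positives; how that case should be handled is not specified in the text.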
Step 1.6: the contrastive loss is computed from the positive and negative samples to pull samples of the same class closer together. The contrastive learning loss function is:
$$L_{sc} = -\frac{1}{M}\sum_{j=1}^{M}\log\frac{\exp(q \cdot k_j/\tau)}{\sum_{i}\exp(q \cdot k_i/\tau)}$$
where q is the sentence representation output by encoder_q, k_j is the j-th positive sample k⁺, k_i ranges over the negative samples k⁻ together with k_j, exp(·) is the exponential function, and τ is a hyperparameter; here τ = 0.07.
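A minimal plain-Python sketch of the contrastive loss follows, assuming the InfoNCE-style form implied by the definitions above: for each positive k_j, a softmax of q·k_j/τ is taken against k_j and all negatives, and the negative log-probabilities are averaged over the M positives. L2-normalizing q and the keys before the dot product is an added assumption, common in MoCo-style training, not something the text states.

```python
import math
from typing import List, Sequence


def _l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def contrastive_loss(q: Sequence[float], positives: Sequence[Sequence[float]],
                     negatives: Sequence[Sequence[float]], tau: float = 0.07) -> float:
    """L_sc = -(1/M) * sum_j log( exp(q.k_j/tau) / sum_i exp(q.k_i/tau) ),
    where k_i ranges over k_j and every negative key (ASSUMED normalized)."""
    if not positives:
        raise ValueError("at least one positive key is required")
    q_n = _l2_normalize(q)
    neg_logits = [_dot(q_n, _l2_normalize(k)) / tau for k in negatives]
    total = 0.0
    for k_pos in positives:
        pos_logit = _dot(q_n, _l2_normalize(k_pos)) / tau
        logits = [pos_logit] + neg_logits
        peak = max(logits)  # log-sum-exp shift for numerical stability
        log_denominator = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_denominator - pos_logit  # = -log(softmax probability of k_j)
    return total / len(positives)
```

The loss is small when q is much more similar to its positives than to any negative, and grows as negatives move closer to q, which is the clustering pressure described in the text.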
Step 1.7: a cross-entropy loss function assists model training, with λ = 0.01; the training loss function is:
$$L_{ec} = -\sum_{c=1}^{C} y_c \log \mathrm{Softmax}\big(F(q)\big)_c$$
$$L = \lambda L_{ec} + (1-\lambda) L_{sc}$$
where λ is the weight parameter balancing L_ec and L_sc (FIG. 2 shows experimental results for different λ values on the RTE and MRPC data sets), y_c is the label of q, i.e., of the input sample x, for class c (q is the sentence representation of x produced by encoder_q), C is the number of text classes, and F(·) is the linear classifier.
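As a sketch only, the joint objective can be computed from precomputed logits F(q) and a precomputed contrastive loss value; treating y as a one-hot label is an assumption consistent with the definitions above. In an actual training loop these would be tensors and L would be back-propagated (step 1.8); this version only shows the arithmetic.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """L_ec = -log Softmax(F(q))[target], i.e. cross-entropy against a one-hot y."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_norm - logits[target]


def joint_loss(logits: Sequence[float], target: int, l_sc: float, lam: float = 0.01) -> float:
    """L = lam * L_ec + (1 - lam) * L_sc, with lam = 0.01 as in step 1.7."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return lam * cross_entropy(logits, target) + (1.0 - lam) * l_sc
```

With λ = 0.01 the contrastive term dominates, so the cross-entropy loss acts mainly as an auxiliary signal that keeps the linear classifier F(·) trained.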
The contrastive loss L_sc assists the cross-entropy loss in training the model. The benefit of training with contrastive learning is that the model automatically clusters the embedding representations of the samples during training, which yields a better classification result.
Step 1.8: the parameters of encoder_q and the linear classifier are updated by back-propagating the model's loss function.
Part 2: KNN and the linear classifier jointly predict the class of the sample to be predicted, through the following steps:
Step 2.1: encoder_q produces the sentence representation q of the sample to be predicted;
Step 2.2: the sample class is predicted with the joint prediction function, as follows:
where S is the probability value of the final model classification, the coefficient weighting the two classifiers is a hyperparameter, and KNN(q) takes from Q the K samples nearest to q in the sample space and then gives the probability values of the KNN model by voting according to the labels of those samples.
Specifically, the cosine function is used to compute the similarity between two samples, and the K training samples with the highest similarity are selected. Since the class labels of these K training samples are known, assume the training samples fall into c classes in total, with s_1 + s_2 + … + s_c = K, where s_i is the number of these samples whose label belongs to the i-th class; the KNN model then gives the probability that the sample x to be predicted belongs to class y_i as s_i / K. FIG. 3 shows the effect of the weighting hyperparameter's value on model accuracy on the RTE, MRPC and MNLI data sets.
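A minimal sketch of the KNN vote, assuming the stored training samples are available as (vector, label) pairs, for example taken from the queue Q. A brute-force sort by cosine similarity stands in for whatever nearest-neighbor search an implementation would actually use; the function names are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple

Entry = Tuple[Sequence[float], int]  # (training-sample vector, class label)


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def knn_probabilities(q: Sequence[float], bank: Sequence[Entry], k: int,
                      num_classes: int) -> List[float]:
    """KNN(q): vote among the K most cosine-similar entries; P(y_i) = s_i / K."""
    if not 0 < k <= len(bank):
        raise ValueError("k must be between 1 and the number of stored samples")
    ranked = sorted(bank, key=lambda entry: cosine(q, entry[0]), reverse=True)
    votes = Counter(label for _, label in ranked[:k])  # s_i for each class i
    return [votes[c] / k for c in range(num_classes)]
```

Because the returned list sums to 1 over the classes, it can be combined directly with the softmax output of the linear classifier in the joint prediction step.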
As shown in FIGS. 4a and 4b, the red and blue points represent data points of two different classes; the sample distribution of KNN-BERT clusters better than that of the conventional linear classifier.
As shown in FIG. 5, RTE, MRPC, QNLI, MNLI, SST-2, IMDB and AG's News are commonly used text classification data sets; BERT is the current general classification model, SCL-Train is a conventional classification model combining contrastive learning with BERT, MoCo is a classification model that enlarges the set of positive and negative samples through momentum parameter updating, and KNN-BERT is the classification model provided by the invention. The figure shows that the classification accuracy of the proposed method improves on every data set compared with the existing methods, with larger gains on the two smaller data sets, RTE and MRPC.
As shown in FIG. 6, IMDB and AG's News are two commonly used text classification data sets; Origin is the original accuracy, TextFooler and BERT-Attack give the classification accuracy under these two attacks, BERT is the conventional classification method, and KNN is the method of the invention. The experimental results show that the model is most robust when its prediction is given by the KNN classifier alone.
The above examples are only intended to illustrate the technical solution of the present invention, but not to limit it; although the present invention has been described in detail with reference to the foregoing embodiments, it will be understood by those of ordinary skill in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some or all of the technical features may be equivalently replaced; and the modifications or the substitutions do not make the essence of the corresponding technical solutions depart from the scope of the technical solutions of the embodiments of the present invention.
Claims (7)
1. A text classification method based on a K-nearest-neighbor algorithm and contrastive learning, characterized by comprising the following steps:
S1: during training, sentence vectors are constructed as the representations of the positive and negative samples k, and contrastive learning is performed, the contrastive learning loss function being:
$$L_{sc} = -\frac{1}{M}\sum_{j=1}^{M}\log\frac{\exp(q \cdot k_j/\tau)}{\sum_{i}\exp(q \cdot k_i/\tau)}$$
where M is the number of positive samples, N is the number of negative samples, q is the sentence vector output by the pre-trained encoder encoder_q, k is the sentence vector output by the pre-trained encoder encoder_k (encoder_q and encoder_k are the same pre-trained model), k_j is the j-th positive sample k⁺, k_i ranges over the negative samples k⁻ together with k_j, exp(·) is the exponential function, and τ is a hyperparameter;
joint training is performed together with a cross-entropy loss function, the joint loss function being:
$$L_{ec} = -\sum_{c=1}^{C} y_c \log \mathrm{Softmax}\big(F(q)\big)_c$$
$$L = \lambda L_{ec} + (1-\lambda) L_{sc}$$
where λ is the weight parameter balancing the cross-entropy loss L_ec and the contrastive loss L_sc, y_c is the label of q for class c (1 if q belongs to class c, otherwise 0), C is the number of text classes, and F(·) is a linear classifier;
the parameters of encoder_q and the linear classifier are updated by back-propagating the loss function; and
S2: the text is classified by the trained encoder_q and linear classifier.
2. The method according to claim 1, wherein in S2, the trained encoder_q obtains the sentence vector representation q of the text to be predicted, and the text class is predicted with a joint prediction function, the joint prediction function being as follows:
where S is the probability value of the final classification, the coefficient weighting the two classifiers is a hyperparameter, Softmax(·) is the activation function, F(q) is the output of the trained linear classifier, and KNN(q) selects from the queue Q the K training samples closest to q and gives the probability values of the KNN model by voting according to their class labels; the classification result is obtained from these probability values.
3. The method according to claim 2, wherein the K training samples closest to q are selected; since their class labels are known, K = s_1 + s_2 + … + s_c, where s_i is the number of those samples whose class label belongs to the i-th class and c is the number of classes of the training samples, and the KNN model gives the probability that q belongs to class y_i as s_i / K.
4. The method according to claim 2, wherein the similarity between q and a training sample is computed with the cosine function.
5. The text classification method based on a K-nearest-neighbor algorithm and contrastive learning according to claim 1, wherein in S1, the parameters of encoder_k are updated by momentum with the hyperparameter m:
θ_k ← mθ_k + (1-m)θ_q
where θ_k denotes the parameters of encoder_k and θ_q denotes the parameters of encoder_q; the keys k encoded by encoder_k are stored in a queue Q, and encoder_k is updated by momentum in each iteration so that the k it produces stays close to the q obtained directly from encoder_q.
6. The method according to claim 5, wherein the queue Q replaces its elements k in sequential order.
7. The method according to claim 5, wherein M elements k with the same class label as the sample are taken from the queue Q as positive samples k⁺, and N elements k with class labels different from the sample's are taken as negative samples k⁻.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110960433.1A CN113673242A (en) | 2021-08-20 | 2021-08-20 | Text classification method based on K-neighborhood node algorithm and comparative learning |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110960433.1A CN113673242A (en) | 2021-08-20 | 2021-08-20 | Text classification method based on K-neighborhood node algorithm and comparative learning |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113673242A | 2021-11-19 |
Family
ID=78544489
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110960433.1A Pending CN113673242A (en) | 2021-08-20 | 2021-08-20 | Text classification method based on K-neighborhood node algorithm and comparative learning |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113673242A (en) |
2021
- 2021-08-20 CN CN202110960433.1A patent/CN113673242A/en active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110533104A (en) * | 2019-08-30 | 2019-12-03 | 中山大学 | It is a kind of based on the different classes of classification method combined apart from mean value |
CN114299304A (en) * | 2021-12-15 | 2022-04-08 | 腾讯科技(深圳)有限公司 | Image processing method and related equipment |
CN114299304B (en) * | 2021-12-15 | 2024-04-12 | 腾讯科技(深圳)有限公司 | Image processing method and related equipment |
CN114090780A (en) * | 2022-01-20 | 2022-02-25 | 宏龙科技(杭州)有限公司 | Prompt learning-based rapid picture classification method |
CN115346084A (en) * | 2022-08-15 | 2022-11-15 | 腾讯科技(深圳)有限公司 | Sample processing method, sample processing apparatus, electronic device, storage medium, and program product |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113673242A (en) | Text classification method based on K-neighborhood node algorithm and comparative learning | |
CN109376242B (en) | Text classification method based on cyclic neural network variant and convolutional neural network | |
CN108388651B (en) | Text classification method based on graph kernel and convolutional neural network | |
Peng et al. | Accelerating minibatch stochastic gradient descent using typicality sampling | |
CN112069310A (en) | Text classification method and system based on active learning strategy | |
CN112560432A (en) | Text emotion analysis method based on graph attention network | |
CN111368920A (en) | Quantum twin neural network-based binary classification method and face recognition method thereof | |
CN116644755B (en) | Multi-task learning-based few-sample named entity recognition method, device and medium | |
CN112491891B (en) | Network attack detection method based on hybrid deep learning in Internet of things environment | |
WO2022241932A1 (en) | Prediction method based on non-intrusive attention preprocessing process and bilstm model | |
CN114547299A (en) | Short text sentiment classification method and device based on composite network model | |
CN113705238A (en) | Method and model for analyzing aspect level emotion based on BERT and aspect feature positioning model | |
CN114692605A (en) | Keyword generation method and device fusing syntactic structure information | |
CN114722835A (en) | Text emotion recognition method based on LDA and BERT fusion improved model | |
CN113553245B (en) | Log anomaly detection method combining bidirectional slice GRU and gate control attention mechanism | |
Fonseca et al. | Model-agnostic approaches to handling noisy labels when training sound event classifiers | |
CN116050419B (en) | Unsupervised identification method and system oriented to scientific literature knowledge entity | |
CN112861626A (en) | Fine-grained expression classification method based on small sample learning | |
CN115809346A (en) | Small sample knowledge graph completion method based on multi-view semantic enhancement | |
CN115270988A (en) | Fine adjustment method, device and application of knowledge representation decoupling classification model | |
CN114707483A (en) | Zero sample event extraction system and method based on contrast learning and data enhancement | |
Xu et al. | Multi text classification model based on bret-cnn-bilstm | |
Li et al. | A position weighted information based word embedding model for machine translation | |
CN113343710A (en) | Unsupervised word embedding representation learning method based on Ising model | |
Liu et al. | Hessian regularization of deep neural networks: A novel approach based on stochastic estimators of Hessian trace |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||