CN113673242A - Text classification method based on K-neighborhood node algorithm and comparative learning - Google Patents

Text classification method based on K-neighborhood node algorithm and comparative learning

Info

Publication number
CN113673242A
Authority
CN
China
Prior art keywords
encoder
training
classification
representing
samples
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110960433.1A
Other languages
Chinese (zh)
Inventor
邱锡鹏
宋德敏
李林阳
傅家庆
杨非
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fudan University
Zhejiang Lab
Original Assignee
Fudan University
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fudan University, Zhejiang Lab filed Critical Fudan University
Priority to CN202110960433.1A
Publication of CN113673242A
Legal status: Pending (current)

Classifications

    • G06F 40/289: Phrasal analysis, e.g. finite state techniques or chunking
    • G06F 18/23: Clustering techniques
    • G06F 18/2415: Classification techniques based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • G06F 40/211: Syntactic parsing, e.g. based on context-free grammar [CFG] or unification grammars
    • G06N 3/084: Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

The invention discloses a text classification method based on the K-nearest-neighbor (KNN) algorithm and contrastive learning. In the training stage, contrastive learning is used to pull in the intra-class distance and push apart the inter-class distance, and a cross-entropy loss is combined with the contrastive learning for joint training; in the inference stage, the jointly trained model and the nearest-neighbor algorithm make a joint prediction to compute the class of the text to be inferred. The method achieves higher text classification accuracy than the classification approaches currently used in industry and greatly improves the robustness of the model.

Description

Text classification method based on K-neighborhood node algorithm and comparative learning
Technical Field
The invention relates to deep learning and natural language processing, and in particular to a text classification method based on the K-nearest-neighbor (KNN) algorithm and contrastive learning.
Background
Text classification is a basic task in natural language processing, and the current mainstream approach is to place a linear classifier on top of a large-scale pre-trained model (such as BERT). However, linear classifiers tend to be less robust and are easily fooled by adversarial attacks such as TextFooler or BERT-Attack.
Disclosure of Invention
In order to overcome the defects of the prior art and to improve both the robustness and the classification accuracy of the model, the invention adopts the following technical scheme:
a text classification method based on K adjacent node algorithm and contrast learning comprises the following steps:
S1, during training, positive and negative samples k are constructed as sentence-vector representations and contrastive learning is carried out, shortening the intra-class distance and lengthening the inter-class distance; the contrastive learning loss function is as follows:
L_{sc} = -\frac{1}{M} \sum_{j=1}^{M} \log \frac{\exp(q \cdot k_j^{+} / \tau)}{\exp(q \cdot k_j^{+} / \tau) + \sum_{i=1}^{N} \exp(q \cdot k_i^{-} / \tau)}
wherein M represents the number of positive samples, N represents the number of negative samples, q represents the sentence vector output by the pre-trained encoder_q, k represents the sentence vector output by the pre-trained encoder_k (encoder_q and encoder_k have the same structure), k_j^+ denotes the j-th positive sample, k_i traverses the negative samples k^- together with k_j^+, exp(·) denotes the exponential function, and τ is a hyper-parameter;
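For illustration only, a minimal Python/PyTorch sketch of this supervised contrastive loss follows; the tensor shapes, variable names and the use of dot-product similarity are assumptions of the sketch, not details fixed by the invention.

    import torch

    def contrastive_loss(q, k_pos, k_neg, tau=0.07):
        # q: (d,) sentence vector from encoder_q
        # k_pos: (M, d) positive sentence vectors from encoder_k
        # k_neg: (N, d) negative sentence vectors from encoder_k
        pos_logits = (k_pos @ q) / tau               # q . k_j^+ / tau, shape (M,)
        neg_logits = (k_neg @ q) / tau               # q . k_i^- / tau, shape (N,)
        neg_sum = torch.exp(neg_logits).sum()        # denominator term shared by all positives
        # each positive competes against itself plus all N negatives
        log_probs = pos_logits - torch.log(torch.exp(pos_logits) + neg_sum)
        return -log_probs.mean()                     # average over the M positives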
joint training is carried out in combination with a cross-entropy loss function; the joint loss function is as follows:
L = \lambda L_{ec} + (1 - \lambda) L_{sc}
L_{ec} = -\sum_{c=1}^{C} y_c \log\big(\mathrm{Softmax}(F(q))_c\big)
where λ is the weight parameter balancing the cross-entropy loss L_ec and the contrastive loss L_sc, y_c denotes the label indicator of q for class c, C is the number of text classes, and F(·) is the linear classifier;
updating the parameters of encoder_q and the linear classifier by back-propagating the loss function;
the joint loss function is a weighted sum of the cross-entropy loss and the supervised contrastive loss; the contrastive loss L_sc assists the cross-entropy loss in training the model, and training with contrastive learning lets the model automatically cluster the embedding representations of the samples during training, thereby achieving a better classification effect;
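The following sketch shows, under the same assumptions, how the joint loss L = λL_ec + (1-λ)L_sc could be assembled from a linear classifier and the contrastive_loss sketch above; the function and argument names are illustrative only.

    import torch
    import torch.nn.functional as nnF   # named nnF to avoid clashing with the classifier F(.)

    def joint_loss(q, label, classifier, k_pos, k_neg, lam=0.01, tau=0.07):
        # q: (d,) sentence vector; label: integer class id; classifier: torch.nn.Linear(d, C)
        logits = classifier(q.unsqueeze(0))                          # F(q), shape (1, C)
        l_ec = nnF.cross_entropy(logits, torch.tensor([label]))     # cross-entropy term L_ec
        l_sc = contrastive_loss(q, k_pos, k_neg, tau)                # supervised contrastive term L_sc
        return lam * l_ec + (1.0 - lam) * l_sc                       # L = lam*L_ec + (1-lam)*L_sc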
and S2, classifying the text with the trained encoder_q and the linear classifier.
Further, in S2, obtaining a sentence vector representation q of the text to be predicted by using the trained encoder_q, and predicting the text classification by using a joint prediction function, where the joint prediction function is as follows:
S = \hat{\lambda} \cdot \mathrm{Softmax}(F(q)) + (1 - \hat{\lambda}) \cdot \mathrm{KNN}(q, Q)
wherein S represents the probability value of the final classification, \hat{\lambda} represents a hyper-parameter, Softmax(·) represents the activation function, F(q) represents the trained linear classifier, and KNN(q, Q) selects from the queue Q the K training samples closest to q in the sample space and gives the probability value of the KNN model by voting according to the classification labels of those training samples; the classification result is obtained from the probability value. When inferring the class of a sample, the KNN and the linear classifier jointly predict the classification of the sample to be predicted, and the K-nearest-neighbor algorithm significantly improves the robustness of the model.
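As an illustration, a Python/PyTorch sketch of this joint prediction is given below; lam_hat stands in for the interpolation hyper-parameter \hat{\lambda}, knn_probs is an assumed helper (sketched below, after the cosine-similarity paragraph) that returns the KNN voting distribution over the C classes, and the default values are arbitrary.

    import torch

    def joint_predict(q, classifier, queue_vecs, queue_labels, num_classes, lam_hat=0.5, k=16):
        p_linear = torch.softmax(classifier(q.unsqueeze(0)).squeeze(0), dim=-1)   # Softmax(F(q))
        p_knn = knn_probs(q, queue_vecs, queue_labels, num_classes, k)            # KNN voting distribution
        s = lam_hat * p_linear + (1.0 - lam_hat) * p_knn                          # final class probabilities S
        return int(s.argmax())                                                    # predicted class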
Further, the K training samples closest to q are selected; since the class labels of these K training samples are known, K = s_1 + s_2 + ... + s_c, where s_i denotes the number of the K samples whose classification label belongs to the i-th class and c denotes the number of classes of the training samples; the probability, given by the KNN model, that q belongs to class y_i is s_i / K.
Further, the similarity between q and a training sample is calculated with the cosine (cos) function.
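A minimal sketch of this KNN voting term, assuming the queue stores sentence vectors together with their class labels; cosine similarity retrieves the K nearest training samples and the class probability is s_i / K.

    import torch
    import torch.nn.functional as nnF

    def knn_probs(q, queue_vecs, queue_labels, num_classes, k=16):
        # q: (d,); queue_vecs: (Q, d); queue_labels: (Q,) integer class ids
        sims = nnF.cosine_similarity(queue_vecs, q.unsqueeze(0), dim=-1)   # cos(q, training sample)
        topk = sims.topk(k).indices                                        # K nearest training samples
        votes = queue_labels[topk]                                         # their known class labels
        counts = torch.bincount(votes, minlength=num_classes).float()      # s_1, ..., s_c
        return counts / k                                                  # p(y_i | q) = s_i / K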
Further, in S1, the momentum parameters of encoder_k are updated using the hyper-parameter m:
\theta_k \leftarrow m \theta_k + (1 - m) \theta_q
where θ_k represents the momentum parameters of encoder_k and θ_q the parameters of encoder_q; in each batch iteration the k obtained by encoding with encoder_k is stored in a queue Q, and encoder_k is updated by this momentum update in each iteration so that the k it produces stays close to the q obtained directly from encoder_q.
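A short sketch of this momentum update, assuming encoder_q and encoder_k are two PyTorch modules with identical architectures (e.g. two copies of the same BERT model):

    import torch

    @torch.no_grad()
    def momentum_update(encoder_q, encoder_k, m=0.999):
        for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
            p_k.data.mul_(m).add_(p_q.data, alpha=1.0 - m)   # theta_k <- m*theta_k + (1-m)*theta_q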
Further, the queue Q replaces the elements k in it in first-in-first-out order.
Further, M elements k whose class label is the same as that of the sample are obtained from the queue Q as positive samples k^+, and N elements k whose class label differs from that of the sample as negative samples k^-.
The invention has the advantages and beneficial effects that:
the method not only greatly improves the robustness of the model, but also correspondingly improves the accuracy of the model. In addition, in order to predict the belonged classification of the sample by using the K-neighborhood algorithm, a distance which can be used for drawing close to the sample of the same class is added in the training process by contrast learning. Meanwhile, in the process of using contrast learning, an MOCO training mode is introduced, and the scale of positive and negative samples is greatly increased.
Drawings
FIG. 1 is a flow chart of the method of the present invention.
FIG. 2 is a line graph of the effect of the lambda value of the present invention on model accuracy over different datasets.
FIG. 3 is a line graph of the effect of the hyper-parameter \hat{\lambda} value of the present invention on model accuracy over different data sets.
Fig. 4a is a sample spatial distribution diagram of a general linear classifier.
FIG. 4b is a KNN-BERT sample spatial distribution map of the present invention.
FIG. 5 is a graph comparing the results of the model classification accuracy test of the present invention.
FIG. 6 is a graph comparing the results of the model robustness tests of the present invention.
Detailed Description
The following detailed description of embodiments of the invention refers to the accompanying drawings. It should be understood that the detailed description and specific examples, while indicating the present invention, are given by way of illustration and explanation only, not limitation.
A text classification method based on the K-nearest-neighbor (KNN) algorithm and contrastive learning is shown in FIG. 1 and comprises the following steps:
a first part: the model training process is specifically divided into the following steps:
step 1.1: the pre-training model BERT is used as the sample encoder _ q, and the same pre-training model BERT is used as the sample encoder _ k.
Step 1.2: updating the parameters of encoder_k using the hyper-parameter m = 0.999; specifically, the momentum parameter update formula is as follows:
\theta_k \leftarrow m \theta_k + (1 - m) \theta_q
where θ_k represents the momentum parameters of the sample encoder encoder_k and θ_q represents the parameters of the sample encoder encoder_q. Traditional contrastive learning selects positive and negative samples within a batch, so the number of positive and negative samples available during training is too small. MoCo instead adopts momentum updating: in each batch iteration the samples encoded by the encoder are stored in a queue, and, in order to keep the sample representations in the queue consistent, encoder_k is updated by the momentum parameter update in each iteration so that its outputs stay close to the sample representations obtained directly from the encoder;
and the encoders encoder_q and encoder_k are updated iteratively, once per batch, based on the training data.
Step 1.3: a vector representation [CLS]_q of the sentence (i.e. the sentence representation q of the sample) is obtained with encoder_q, and a vector representation [CLS]_k of the sentence (i.e. the sentence representation k of the sample) is obtained with encoder_k.
For example: for the training sentence "Beijing is the capital of China", BERT adds a token [CLS] to the beginning of the sentence and a token [SEP] to the end of the sentence when encoding. The embedding vector of [CLS] is typically used as the representation of the entire sentence.
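For illustration, a minimal sketch of extracting the [CLS] embedding as the sentence vector with the HuggingFace transformers library; the checkpoint name "bert-base-uncased" and the use of the English example sentence are assumptions of the sketch, since the patent does not fix a particular checkpoint.

    import torch
    from transformers import BertTokenizer, BertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoder_q = BertModel.from_pretrained("bert-base-uncased")

    inputs = tokenizer("Beijing is the capital of China", return_tensors="pt")
    with torch.no_grad():
        outputs = encoder_q(**inputs)
    cls_q = outputs.last_hidden_state[:, 0]   # (1, hidden) embedding of the [CLS] token, used as q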
Step 1.4: storing [CLS]_k in a queue Q of size 32000, and replacing the elements of Q in first-in-first-out order;
Step 1.5: obtaining M samples with the same label as the sample from Q as positive samples k^+, and N samples with a different label as negative samples k^-.
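A sketch of Steps 1.4 and 1.5 under the assumption that each queue entry stores a [CLS]_k vector together with its class label; the class name, its methods, and the deque-based FIFO replacement are illustrative choices, not requirements of the invention.

    import collections
    import torch

    class SampleQueue:
        def __init__(self, max_size=32000):
            self.buf = collections.deque(maxlen=max_size)     # oldest entries are dropped first

        def enqueue(self, k_vec, label):
            self.buf.append((k_vec.detach(), int(label)))     # store [CLS]_k with its class label

        def split_by_label(self, label, m, n):
            # same-label entries become positives k^+, different-label entries negatives k^-
            pos = [v for v, y in self.buf if y == label][:m]
            neg = [v for v, y in self.buf if y != label][:n]
            return torch.stack(pos), torch.stack(neg)         # assumes the queue holds enough of each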
Step 1.6: calculating the contrastive learning loss function with the positive and negative samples, pulling samples of the same class closer together; the contrastive learning loss function is specifically as follows:
L_{sc} = -\frac{1}{M} \sum_{j=1}^{M} \log \frac{\exp(q \cdot k_j^{+} / \tau)}{\exp(q \cdot k_j^{+} / \tau) + \sum_{i=1}^{N} \exp(q \cdot k_i^{-} / \tau)}
where q is the sentence representation output by encoder_q, k_j^+ denotes the j-th positive sample, k_i traverses the negative samples k^- together with k_j^+, exp(·) denotes the exponential function, and τ is a hyper-parameter, specifically τ = 0.07.
Step 1.7: using the cross-entropy loss function to assist model training, with λ = 0.01; specifically, the loss function for model training is as follows:
L = \lambda L_{ec} + (1 - \lambda) L_{sc}
L_{ec} = -\sum_{c=1}^{C} y_c \log\big(\mathrm{Softmax}(F(q))_c\big)
where λ is the weight parameter between L_ec and L_sc (the effect of the λ value, shown in FIG. 2, comes from experiments on the RTE and MRPC data sets), y_c denotes the class of q, i.e. of the input sample x (q is the sentence representation of x after encoder_q), C is the number of text classes, and F(·) is the linear classifier.
The contrastive learning loss L_sc assists the cross-entropy loss in training the model. The beneficial effect is that, by training with contrastive learning, the model automatically clusters the embedding representations of the samples during training, thereby achieving a better classification effect.
Step 1.8: updating the parameters of encoder_q and the linear classifier by back-propagating the model loss function.
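Putting Steps 1.2 to 1.8 together, a compact sketch of one training iteration is given below; it reuses the contrastive_loss, momentum_update and SampleQueue sketches above, and the optimizer (assumed to cover the parameters of encoder_q and the linear classifier) and the batch format are assumptions of the sketch.

    import torch

    def train_step(batch, encoder_q, encoder_k, classifier, queue, optimizer,
                   m=0.999, lam=0.01, tau=0.07, num_pos=8, num_neg=256):
        for text_inputs, label in batch:                       # batch of (tokenized text, class id)
            q = encoder_q(**text_inputs).last_hidden_state[:, 0].squeeze(0)       # [CLS]_q
            with torch.no_grad():
                k = encoder_k(**text_inputs).last_hidden_state[:, 0].squeeze(0)   # [CLS]_k
            k_pos, k_neg = queue.split_by_label(int(label), num_pos, num_neg)     # Step 1.5
            l_sc = contrastive_loss(q, k_pos, k_neg, tau)                         # Step 1.6
            l_ec = torch.nn.functional.cross_entropy(classifier(q.unsqueeze(0)),
                                                     torch.tensor([int(label)]))
            loss = lam * l_ec + (1.0 - lam) * l_sc                                # Step 1.7
            optimizer.zero_grad()
            loss.backward()                                                       # Step 1.8
            optimizer.step()
            momentum_update(encoder_q, encoder_k, m)                              # Step 1.2
            queue.enqueue(k, label)                                               # Step 1.4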
A second part: the classification of the sample to be predicted is jointly predicted by using the KNN and the linear classifier, and specifically, the method comprises the following steps:
step 2.1: obtaining a sentence representation q of a sample to be predicted by using an encoder _ q;
Step 2.2: predicting the sample classification by using the joint prediction function S.
Specifically, the joint prediction function is as follows:
S = \hat{\lambda} \cdot \mathrm{Softmax}(F(q)) + (1 - \hat{\lambda}) \cdot \mathrm{KNN}(q, Q)
where S is the probability value of the final model classification, \hat{\lambda} is a hyper-parameter, and KNN(q, Q) takes the K samples nearest to q in the sample space from Q and then gives the probability values of the KNN model by voting according to the labels of those samples.
Specifically, the cos function is used to calculate the similarity between two samples, and the K training samples with the greatest similarity are selected. Since the classification labels of these K training samples are known, assuming the training samples have c classes in total, s_1 + s_2 + ... + s_c = K, where s_i denotes the number of the K samples whose label belongs to the i-th class; the probability, given by the KNN model, that the sample x to be predicted belongs to class y_i is therefore s_i / K.
As shown in FIG. 3, the effect of the \hat{\lambda} value on model accuracy is given for the RTE, MRPC and MNLI data sets.
As shown in fig. 4a and 4b, the red and blue points represent two different classes of data points, and it can be seen that the clustering effect of the sample distribution of KNN-BERT is better than that of the conventional linear classifier.
As shown in FIG. 5, RTE, MRPC, QNLI, MNLI, SST-2, IMDB and AG's News are all commonly used text classification data sets; BERT is the current general-purpose classification model, SCL-Train is the traditional contrastive learning + BERT classification model, MoCo is the classification model that expands the positive and negative samples with the momentum parameter update method, and KNN-BERT is the classification model provided by the invention. The figure shows that the classification accuracy of the proposed method improves over the existing methods on every data set, with the largest gains on RTE and MRPC, the two data sets with the least data.
As shown in FIG. 6, IMDB and AG's News are two commonly used text classification data sets; Origin denotes the original accuracy, TextFooler and BERT-Attack denote the classification accuracy under those two attack modes, BERT denotes the traditional classification method, and KNN denotes the method of the invention. The experimental results show that, as regards the choice of the hyper-parameter \hat{\lambda}, the robustness of the model is best when the model's results are given by the KNN classifier only.
The above examples are only intended to illustrate the technical solution of the present invention, but not to limit it; although the present invention has been described in detail with reference to the foregoing embodiments, it will be understood by those of ordinary skill in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some or all of the technical features may be equivalently replaced; and the modifications or the substitutions do not make the essence of the corresponding technical solutions depart from the scope of the technical solutions of the embodiments of the present invention.

Claims (7)

1. A text classification method based on the K-nearest-neighbor (KNN) algorithm and contrastive learning, characterized by comprising the following steps:
S1, during training, positive and negative samples k are constructed as sentence-vector representations and contrastive learning is carried out; the contrastive learning loss function is as follows:
L_{sc} = -\frac{1}{M} \sum_{j=1}^{M} \log \frac{\exp(q \cdot k_j^{+} / \tau)}{\exp(q \cdot k_j^{+} / \tau) + \sum_{i=1}^{N} \exp(q \cdot k_i^{-} / \tau)}
wherein M represents the number of positive samples, N represents the number of negative samples, q represents the sentence vector output by the pre-trained encoder_q, k represents the sentence vector output by the pre-trained encoder_k (encoder_q and encoder_k have the same structure), k_j^+ denotes the j-th positive sample, k_i traverses the negative samples k^- together with k_j^+, exp(·) denotes the exponential function, and τ is a hyper-parameter;
joint training is carried out in combination with a cross-entropy loss function; the joint loss function is as follows:
L = \lambda L_{ec} + (1 - \lambda) L_{sc}
L_{ec} = -\sum_{c=1}^{C} y_c \log\big(\mathrm{Softmax}(F(q))_c\big)
where λ is the weight parameter balancing the cross-entropy loss L_ec and the contrastive loss L_sc, y_c denotes the label indicator of q for class c, C is the number of text classes, and F(·) is the linear classifier;
updating the parameters of encoder_q and the linear classifier by back-propagating the loss function;
and S2, classifying the text with the trained encoder_q and the linear classifier.
2. The method according to claim 1, wherein in S2, a sentence vector representation q of the text to be predicted is obtained by the trained encoder_q, and the text classification is predicted by using a joint prediction function, wherein the joint prediction function is as follows:
S = \hat{\lambda} \cdot \mathrm{Softmax}(F(q)) + (1 - \hat{\lambda}) \cdot \mathrm{KNN}(q, Q)
wherein S represents the probability value of the final classification, \hat{\lambda} represents a hyper-parameter, Softmax(·) represents the activation function, F(q) represents the trained linear classifier, and KNN(q, Q) selects from the queue Q the K training samples closest to q and gives the probability value of the KNN model by voting according to the classification labels of those training samples; the classification result is obtained from the probability value.
3. The method according to claim 2, wherein the K training samples closest to q are selected and, since the classification labels of these K training samples are known, K = s_1 + s_2 + ... + s_c, where s_i denotes the number of the K samples whose classification label belongs to the i-th class and c denotes the number of classes of the training samples; the probability, given by the KNN model, that q belongs to class y_i is s_i / K.
4. The method of claim 2, wherein the similarity between q and the training sample is calculated by a cos function.
5. The text classification method based on the K-nearest-neighbor algorithm and contrastive learning according to claim 1, wherein in S1 the momentum parameters of encoder_k are updated using the hyper-parameter m:
\theta_k \leftarrow m \theta_k + (1 - m) \theta_q
where θ_k represents the momentum parameters of encoder_k and θ_q the parameters of encoder_q; the k obtained by encoding with encoder_k is stored in a queue Q, and encoder_k is updated by the momentum parameter update in each iteration so that the k it produces stays close to the q obtained directly from encoder_q.
6. The method of claim 5, wherein the queue Q replaces the elements k in it in first-in-first-out order.
7. The method of claim 5, wherein M elements k whose classification label is the same as that of the sample are obtained from the queue Q as positive samples k^+, and N elements k whose classification label differs from that of the sample as negative samples k^-.
CN202110960433.1A 2021-08-20 2021-08-20 Text classification method based on K-neighborhood node algorithm and comparative learning Pending CN113673242A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110960433.1A CN113673242A (en) 2021-08-20 2021-08-20 Text classification method based on K-neighborhood node algorithm and comparative learning

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110960433.1A CN113673242A (en) 2021-08-20 2021-08-20 Text classification method based on K-neighborhood node algorithm and comparative learning

Publications (1)

Publication Number Publication Date
CN113673242A true CN113673242A (en) 2021-11-19

Family

ID=78544489

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110960433.1A Pending CN113673242A (en) 2021-08-20 2021-08-20 Text classification method based on K-neighborhood node algorithm and comparative learning

Country Status (1)

Country Link
CN (1) CN113673242A (en)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110533104A (en) * 2019-08-30 2019-12-03 中山大学 It is a kind of based on the different classes of classification method combined apart from mean value
CN114299304A (en) * 2021-12-15 2022-04-08 腾讯科技(深圳)有限公司 Image processing method and related equipment
CN114299304B (en) * 2021-12-15 2024-04-12 腾讯科技(深圳)有限公司 Image processing method and related equipment
CN114090780A (en) * 2022-01-20 2022-02-25 宏龙科技(杭州)有限公司 Prompt learning-based rapid picture classification method
CN115346084A (en) * 2022-08-15 2022-11-15 腾讯科技(深圳)有限公司 Sample processing method, sample processing apparatus, electronic device, storage medium, and program product

Similar Documents

Publication Publication Date Title
CN113673242A (en) Text classification method based on K-neighborhood node algorithm and comparative learning
CN109376242B (en) Text classification method based on cyclic neural network variant and convolutional neural network
CN108388651B (en) Text classification method based on graph kernel and convolutional neural network
Peng et al. Accelerating minibatch stochastic gradient descent using typicality sampling
CN112069310A (en) Text classification method and system based on active learning strategy
CN112560432A (en) Text emotion analysis method based on graph attention network
CN111368920A (en) Quantum twin neural network-based binary classification method and face recognition method thereof
CN116644755B (en) Multi-task learning-based few-sample named entity recognition method, device and medium
CN112491891B (en) Network attack detection method based on hybrid deep learning in Internet of things environment
WO2022241932A1 (en) Prediction method based on non-intrusive attention preprocessing process and bilstm model
CN114547299A (en) Short text sentiment classification method and device based on composite network model
CN113705238A (en) Method and model for analyzing aspect level emotion based on BERT and aspect feature positioning model
CN114692605A (en) Keyword generation method and device fusing syntactic structure information
CN114722835A (en) Text emotion recognition method based on LDA and BERT fusion improved model
CN113553245B (en) Log anomaly detection method combining bidirectional slice GRU and gate control attention mechanism
Fonseca et al. Model-agnostic approaches to handling noisy labels when training sound event classifiers
CN116050419B (en) Unsupervised identification method and system oriented to scientific literature knowledge entity
CN112861626A (en) Fine-grained expression classification method based on small sample learning
CN115809346A (en) Small sample knowledge graph completion method based on multi-view semantic enhancement
CN115270988A (en) Fine adjustment method, device and application of knowledge representation decoupling classification model
CN114707483A (en) Zero sample event extraction system and method based on contrast learning and data enhancement
Xu et al. Multi text classification model based on bret-cnn-bilstm
Li et al. A position weighted information based word embedding model for machine translation
CN113343710A (en) Unsupervised word embedding representation learning method based on Ising model
Liu et al. Hessian regularization of deep neural networks: A novel approach based on stochastic estimators of Hessian trace

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination