CN115146021A - Training method and device for text retrieval matching model, electronic equipment and medium - Google Patents

Training method and device for text retrieval matching model, electronic equipment and medium

Info

Publication number
CN115146021A
CN115146021A (Application No. CN202110343807.5A)
Authority
CN
China
Prior art keywords
confidence
sample
target
sample set
label
Prior art date
Legal status
Pending
Application number
CN202110343807.5A
Other languages
Chinese (zh)
Inventor
张辰
胡燊
刘怀军
Current Assignee
Beijing Sankuai Online Technology Co Ltd
Original Assignee
Beijing Sankuai Online Technology Co Ltd
Priority date
Filing date
Publication date
Application filed by Beijing Sankuai Online Technology Co Ltd
Priority to CN202110343807.5A
Publication of CN115146021A

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Databases & Information Systems (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

The embodiments of this application disclose a training method and apparatus for a text retrieval matching model, an electronic device, and a storage medium. The method includes: performing fine-tuning training on a pre-trained language model according to an initial sample set to obtain an initial text retrieval matching model; predicting unlabeled data with the initial text retrieval matching model to obtain output embedding vectors and predicted label probability distributions; screening the unlabeled data set according to the predicted label probability distributions to obtain a high-confidence sample set; determining the similarity relationship and vector distance between every two high-confidence samples in the high-confidence sample set, and determining a confidence weight for each high-confidence sample; adding the confidence weights, similarity relationships, and vector distances to the high-confidence sample set to obtain a target sample set; and training the initial text retrieval matching model according to the target sample set to obtain a target text retrieval matching model. The embodiments of this application improve the training efficiency and accuracy of the model.

Description

Training method and device for text retrieval matching model, electronic equipment and medium
Technical Field
The embodiments of this application relate to the technical field of deep learning, and in particular to a training method and apparatus for a text retrieval matching model, an electronic device, and a storage medium.
Background
With the excellent performance of pre-trained language models such as BERT, more and more NLP (Natural Language Processing) tasks use BERT-style pre-training followed by fine-tuning on downstream tasks, and text retrieval matching is one of the most common such tasks. With a model such as BERT, text input is passed through attention networks to obtain corresponding dense vectors, which are then used for matching and ranking. However, training these models inevitably requires a large amount of labeled corpus; when labeled data is insufficient, their effectiveness drops sharply and overfitting may even occur.
In the field of retrieval and matching, model training methods for small-sample learning mainly include the following. Unsupervised similarity: auxiliary ranking is performed with text matching scores such as BM25 and TF-IDF, or with semantic similarity scores such as those from Word2Vec; computing these scores requires no labeled corpus, and the corresponding scores are obtained directly from the relevant formulas or pre-trained models. Data enhancement: common NLP augmentation techniques such as back-translation and random word or synonym replacement. Semi-supervised learning: the most common approach combines soft and hard labels, i.e., labeled data is used for preliminary training of a model, the model then predicts unlabeled data, high-confidence samples in the prediction results are used as hard-label samples and medium-confidence samples as soft-label samples, and the label-extended data is used to train the final model.
In short-text scenarios, matching by similarity score works poorly and can even fail, because the search term entered by the user does not necessarily appear in the document, and a literal match may itself be noise. Data-enhancement-based methods carry the risk of label corruption because they alter the original text; in the short-text domain in particular, replacing a single word, or even a single character, often changes the meaning of the original text considerably and greatly increases the difficulty of model learning, and because generic data enhancement is not designed for the specific scenario of text retrieval, its benefit cannot be fully realized. In semi-supervised learning, the soft-label results produced by the model often contain a large amount of noise, the approach depends heavily on the model's ability to fit confident samples, and it ignores the commonalities and similarities that exist between samples, which greatly reduces learning efficiency and limits cost savings. Therefore, the prior-art methods still suffer from poor text retrieval matching accuracy.
Disclosure of Invention
The embodiments of this application provide a training method and apparatus for a text retrieval matching model, an electronic device, and a storage medium, which help improve the accuracy of the text retrieval matching model.
In order to solve the above problem, in a first aspect, an embodiment of the present application provides a method for training a text retrieval matching model, including:
acquiring an initial sample set and acquiring an unlabeled data set, wherein each initial sample in the initial sample set comprises a search word, a document and a label, and the unlabeled data in the unlabeled data set comprises a search word and a document;
performing fine tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model;
predicting the unmarked data in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a predictive label probability distribution corresponding to each unmarked data in the unmarked data set;
screening the unmarked data set according to the probability distribution of the predictive label corresponding to each unmarked data to obtain a high-confidence sample set;
according to the predictive label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, determining the similarity relation between each high-confidence sample and other high-confidence samples in the high-confidence sample set and the vector distance of the output embedded vector, and determining the confidence weight of each high-confidence sample;
adding the confidence weight of each high-confidence sample, and the similarity relation and the vector distance between each high-confidence sample and other high-confidence samples, into the high-confidence sample set to obtain a target sample set;
and training the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence weight, the similarity relation and the vector distance to obtain a target text retrieval matching model.
In a second aspect, an embodiment of the present application provides a training apparatus for text retrieval matching models, including:
the data set acquisition module is used for acquiring an initial sample set and acquiring an unlabeled data set, wherein each initial sample in the initial sample set comprises a search word, a document and a label, and the unlabeled data in the unlabeled data set comprises a search word and a document;
the first fine tuning training module is used for performing fine tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model;
the model prediction module is used for predicting the unmarked data in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction label probability distribution corresponding to each unmarked data in the unmarked data set;
the data set screening module is used for screening the unmarked data sets according to the probability distribution of the predictive label corresponding to each unmarked data to obtain a high-confidence sample set;
the parameter determination module is used for determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples and the vector distance of the output embedded vector according to the predictive label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, and determining the confidence weight of each high-confidence sample;
the target sample construction module is used for adding the confidence weight of each high-confidence sample, the similarity relation and the vector distance between each high-confidence sample and other high-confidence samples to the high-confidence sample set to obtain a target sample set;
and the model training module is used for training the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence weight, the similarity relation and the vector distance to obtain a target text retrieval matching model.
In a third aspect, an embodiment of the present application further provides an electronic device, which includes a memory, a processor, and a computer program stored on the memory and executable on the processor, where the processor, when executing the computer program, implements the method for training the text retrieval matching model according to the embodiment of the present application.
In a fourth aspect, the present application provides a computer-readable storage medium, on which a computer program is stored, where the computer program, when executed by a processor, implements the steps of the training method for text retrieval matching model disclosed in the present application.
According to the training method and apparatus for a text retrieval matching model, the electronic device, and the storage medium provided by the embodiments of this application, an initial sample set and an unlabeled data set are obtained; a pre-trained language model is fine-tuned on the initial sample set to obtain an initial text retrieval matching model; the unlabeled data in the unlabeled data set are predicted by the initial text retrieval matching model to obtain an output embedding vector and a predicted label probability distribution for each unlabeled data item; the unlabeled data set is screened according to the predicted label probability distributions to obtain a high-confidence sample set; the similarity relationship and the output-embedding-vector distance between each high-confidence sample and the other high-confidence samples are determined from the predicted label probability distributions and output embedding vectors of the high-confidence samples, along with a confidence weight for each high-confidence sample; the confidence weights, similarity relationships, and vector distances are added to the high-confidence sample set to obtain a target sample set; and the initial text retrieval matching model is trained on the target sample set with a target loss function associated with the confidence weights, similarity relationships, and vector distances to obtain the target text retrieval matching model. Because only a small amount of manually labeled initial samples is needed, labor cost is effectively saved; and by performing self-supervised learning with a target loss function associated with confidence weights, similarity relationships, and vector distances, interference from noisy data in the unlabeled samples is effectively avoided and the training efficiency and accuracy of the model are improved.
Drawings
In order to more clearly illustrate the technical solutions of the embodiments of this application, the drawings required in the description of the embodiments or the prior art are briefly introduced below. Obviously, the drawings in the following description are only some embodiments of this application, and those skilled in the art can obtain other drawings based on these drawings without creative effort.
FIG. 1 is a flowchart of a training method of a text retrieval matching model according to a first embodiment of the present application;
FIG. 2 is a schematic structural diagram of a training apparatus for a text retrieval matching model according to a second embodiment of the present application;
fig. 3 is a schematic structural diagram of an electronic device according to a third embodiment of the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application, and it is obvious that the described embodiments are some, but not all, embodiments of the present application. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present application.
Example one
As shown in fig. 1, the method for training a text retrieval matching model according to this embodiment includes: step 110 to step 170.
Step 110, obtaining an initial sample set and obtaining an unlabeled data set, where each initial sample in the initial sample set includes a search term, a document and a tag, and the unlabeled data in the unlabeled data set includes a search term and a document.
An initial sample set labeled manually is obtained. Each initial sample includes three fields: a search term, a document, and a label. The search term is the user's search query (query); the document is the document or entity text to be retrieved (doc), which in the e-commerce domain may be a merchant name (poi_name), a product name, or the like; the label identifies the relevance between the search term and the document, and may be 0 or 1 in binary classification, with more label values possible in multi-class settings.
An unlabeled data set is obtained from search logs and/or a text retrieval database; each unlabeled data item in the set contains at least two fields, a search term and a document.
Step 120, performing fine tuning training on the pre-training language model according to the initial sample set to obtain an initial text retrieval matching model.
The pre-trained language model may be an open-source BERT model pre-trained on open-domain corpus, or a model obtained by further pre-training such an open-source BERT model on corpus from the target domain. In the Chinese domain, open-source BERT models pre-trained on open-domain corpus, including but not limited to Google's original BERT, ALBERT, RoBERTa, etc., can be used directly if word-level segmentation is used.
Fine-tuning training is performed on the pre-trained language model with the initial sample set to obtain the initial text retrieval matching model. The structure of the pre-trained language model may adopt a cross network (Cross-encoders) or a twin network (Bi-encoders): with a cross network, the search term and document in an initial sample are input to the BERT model as a single sequence; with a twin network, the search term and document each pass through one of two parameter-sharing BERT encoders. During fine-tuning, an early-stopping mechanism can be set, for example a limit on the number of iterations or a threshold on the loss value, to ensure that the model does not overfit and retains some generalization ability.
When fine-tuning the pre-trained language model, different initialization random seeds (Random Seed) can additionally be set, i.e., the model parameters are randomly initialized several times and each instance is trained separately; following a cross-validation idea, the top-N best-performing models are then selected as candidates for the subsequent steps, i.e., as the initial text retrieval matching model. N is usually 1, but may of course take other values, which is not limited here.
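For illustration only, a minimal cross-encoder fine-tuning step might look like the following sketch, assuming the HuggingFace Transformers library; the model name, sample pair, learning rate, and labels are hypothetical and not prescribed by this application.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical cross-encoder fine-tuning step: the search term and the document
# are fed to BERT as one sentence pair; model name and data are illustrative.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

queries, docs = ["川菜"], ["成都老妈火锅店"]
labels = torch.tensor([1])  # 1 = relevant, 0 = irrelevant
enc = tokenizer(queries, docs, padding=True, truncation=True, return_tensors="pt")
loss = model(**enc, labels=labels).loss  # cross-entropy over the label set
loss.backward()
optimizer.step()
```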
In an embodiment of the present application, before the fine-tuning training of the pre-trained language model according to the initial sample set, the method further includes: pre-training an initial pre-trained language model with corpus from the target domain to obtain the pre-trained language model.
The target field is a field to which the trained text retrieval matching model is applied, and may be, for example, an e-commerce field or an open-domain search field.
The corpus in the target domain is obtained, and the open-source BERT model pre-trained on open-domain corpus is further trained with the target-domain corpus so that the BERT model fully learns the knowledge of the target domain. Since the target domain belongs to the search field, the target-domain corpus can be constructed from search exposure and click data in that field.
If Chinese word-level segmentation is used, the corpus in the target domain can also be used for pre-training. Of course, if the search domain itself is an open domain and the segmentation level is word level, the open-source version of the BERT-like model can be used directly and this step can be omitted.
Step 130, predicting the unmarked data in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction label probability distribution corresponding to each unmarked data in the unmarked data set.
After the initial text retrieval matching model is obtained, the initial text retrieval matching model can be trained by using the unlabeled data set to obtain a target text retrieval matching model.
First, the unlabeled data in the unlabeled data set are predicted by the initial text retrieval matching model, i.e., the search term and document in each unlabeled data item are input to the initial text retrieval matching model to obtain the corresponding output embedding vector and predicted label probability distribution. The output embedding vector is generally the network output vector of the layer before softmax: in a cross network (Cross-encoders) the first [CLS] vector is usually taken, while in a twin network (Bi-encoders) it is usually the vector after pooling at the two ends of the network. The predicted label probability distribution consists of the normalized probability for each label (label value) output after softmax; in semi-supervised learning this probability distribution is commonly called a "soft label", so the predicted label probability distribution can be used as a soft label.
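As a rough sketch of this prediction step (again assuming the HuggingFace Transformers API; in practice the model would be the fine-tuned initial text retrieval matching model, and the query/document pair is illustrative):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)

model.eval()
with torch.no_grad():
    enc = tokenizer(["奶茶"], ["一点点奶茶店"], padding=True, truncation=True,
                    return_tensors="pt")
    out = model(**enc, output_hidden_states=True)
    cls_embedding = out.hidden_states[-1][:, 0, :]   # output embedding vector ([CLS])
    label_probs = torch.softmax(out.logits, dim=-1)  # predicted label probabilities
```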
In an embodiment of the present application, predicting, by the initial text retrieval matching model, unlabeled data in the unlabeled data set to obtain an output embedded vector and a predictive tag probability distribution corresponding to each unlabeled data in the unlabeled data set, includes: predicting the unmarked data of the current batch in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction tag probability corresponding to each unmarked data in the current batch; and respectively carrying out normalization processing on the probability of the predicted label corresponding to each unlabeled data in the current batch to obtain the probability distribution of the predicted label corresponding to each unlabeled data in the current batch.
When training a model, it is common to train for multiple rounds and to divide each round into multiple batches. The initial text retrieval matching model is trained in the same way: the unlabeled data of the current batch are first fed into the initial text retrieval matching model to obtain, for each unlabeled data item in the current batch, an output embedding vector and the predicted label probabilities for the several labels; these predicted label probabilities are then normalized to obtain the predicted label probability distribution for each unlabeled data item in the current batch. The normalized predicted label probability distribution accurately reflects the distribution of predicted labels for each unlabeled data item in the current batch.
In an optional implementation manner, the normalizing the probability of the predicted label corresponding to each unlabeled data in the current batch respectively includes:
respectively carrying out normalization processing on the probability of the predicted label corresponding to each unlabeled data in the current batch according to the following normalization calculation formula:
\tilde{f}(x)_j = \frac{ f(x;\theta)_j^2 / \sum_{x' \in B} f(x';\theta)_j^2 }{ \sum_{j'=1}^{y} \left( f(x;\theta)_{j'}^2 / \sum_{x' \in B} f(x';\theta)_{j'}^2 \right) }

where \sum_{x' \in B} f(x';\theta)_j^2 represents the sum of the squared predicted label probabilities for the j-th label over all unlabeled data in the current batch B, f(x;\theta)_j^2 represents the square of the predicted probability of the j-th label for the current unlabeled data item, f(x;\theta) represents the predicted label probability output by the initial text retrieval matching model for the current unlabeled data item, x represents the current unlabeled data item, \theta represents the model parameters of the initial text retrieval matching model, y represents the total number of all labels, \tilde{f}(x)_j represents the normalized probability of the j-th label in the predicted label probability distribution of the current unlabeled data item, and j' traverses all labels.
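A small NumPy sketch of this batch-level normalization (a hypothetical reading of the formula above; function and variable names are illustrative):

```python
import numpy as np

def normalize_soft_labels(probs):
    """probs: [batch, num_labels] predicted label probabilities f(x;θ).
    Returns the normalized (sharpened) predicted label probability
    distributions, following the formula described above."""
    squared = probs ** 2                              # f(x;θ)_j²
    batch_sums = squared.sum(axis=0, keepdims=True)   # Σ_x' f(x';θ)_j² per label j
    weighted = squared / batch_sums
    return weighted / weighted.sum(axis=1, keepdims=True)  # renormalize over labels j'
```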
Step 140, screening the unmarked data set according to the probability distribution of the predictive label corresponding to each unmarked data to obtain a high-confidence sample set.
And screening the unmarked data in the current batch in the unmarked data set according to the probability distribution of the predictive label corresponding to each unmarked data to screen out the sample with high confidence level and obtain the sample set with high confidence level.
In an embodiment of the present application, the screening the unlabeled data set according to the probability distribution of the predictive label corresponding to each unlabeled data to obtain a high-confidence sample set includes: screening unlabeled data with normalized probability of the labels being greater than or equal to a confidence threshold in the probability distribution of the predicted labels from the unlabeled dataset; and determining the unmarked data obtained by screening and the output embedded vector and the predictive label probability distribution corresponding to the unmarked data as a high confidence sample set.
For each unlabeled data item in the unlabeled data set, the normalized probability of each label in its predicted label probability distribution is compared with a confidence threshold; if the normalized probability of some label is greater than or equal to the confidence threshold, the unlabeled data item is selected as a high-confidence sample for that label. For example, in binary classification with a confidence threshold of 0.7, if an unlabeled data item has a normalized probability of 0.8 for label 0 and 0.2 for label 1, it is taken as a high-confidence sample of label 0. The unlabeled data obtained by screening, together with their output embedding vectors and predicted label probability distributions, are taken as the high-confidence sample set, which is then used to continue training the initial text retrieval matching model.
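A minimal NumPy sketch of this screening step (the 0.7 threshold is just the example value from the text):

```python
import numpy as np

def select_high_confidence(soft_labels, embeddings, threshold=0.7):
    """Keep unlabeled items whose normalized distribution assigns at least
    `threshold` probability to some label; returns the high-confidence subset."""
    keep = soft_labels.max(axis=1) >= threshold
    return soft_labels[keep], embeddings[keep], np.flatnonzero(keep)
```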
Step 150, determining the similarity relation between each high-confidence sample and other high-confidence samples in the high-confidence sample set and the vector distance of the output embedded vector according to the predictive label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, and determining the confidence weight of each high-confidence sample.
One high confidence sample in the high confidence sample set includes a search term, a document, a predictive tag probability distribution, and an output embedding vector. And determining the similarity relation between the high-confidence sample and other high-confidence samples according to the probability distribution of the predictive label of each high-confidence sample in the high-confidence sample set. And determining the vector distance between each high-confidence sample and other high-confidence samples according to the output embedded vector of each high-confidence sample in the high-confidence sample set, wherein the vector distance can be Euclidean distance or cosine distance and the like. The confidence weight of the high confidence samples can be determined according to the entropy of the predictive label probability distribution of each high confidence sample in the high confidence sample set.
In an embodiment of the present application, determining a similarity relationship between each high-confidence sample in the high-confidence sample set and other high-confidence samples and a vector distance of an output embedded vector according to a predictive label probability distribution and an output embedded vector of each high-confidence sample in the high-confidence sample set, and determining a confidence weight of each high-confidence sample includes:
determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set;
determining a vector distance between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to an output embedded vector of each high-confidence sample in the high-confidence sample set;
and determining the confidence weight of each high-confidence sample in the high-confidence sample set according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set.
The label probability distribution of each high-confidence sample is compared with those of the other high-confidence samples; if the two label probability distributions are similar, the similarity relationship of the two high-confidence samples is determined to be "similar", otherwise "not similar". The distance between the output embedding vectors of two high-confidence samples in the high-confidence sample set is computed to obtain the vector distance between them; the vector distance may be a Euclidean distance or a cosine distance. Taking the Euclidean distance as an example, it can be computed by the following formula:
d_{ij} = \lVert v_i - v_j \rVert_2

where x_i and x_j are two high-confidence samples in the high-confidence sample set, v_i is the output embedding vector of x_i, v_j is the output embedding vector of x_j, and d_{ij} is the vector distance between x_i and x_j.
When determining the confidence weights of the high-confidence samples, the confidence weights of the high-confidence samples may be determined according to entropy values of predictive label probability distributions of each high-confidence sample in the high-confidence sample set.
In an alternative embodiment, determining a similarity relationship between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set includes:
according to the probability distribution of the prediction label of each high-confidence sample in the high-confidence sample set, determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the following formula:
S_{ij} = \begin{cases} 1, & \arg\max_{k} \tilde{f}(x_i)_k = \arg\max_{k} \tilde{f}(x_j)_k \\ 0, & \text{otherwise} \end{cases}

where x_i and x_j are two high-confidence samples in the high-confidence sample set, S_{ij} represents the similarity relationship between x_i and x_j, y represents the total number of all labels (k = 1, ..., y), \tilde{f}(x_i)_k represents the normalized probability of the k-th label in the predicted label probability distribution of x_i, and \tilde{f}(x_j)_k represents the normalized probability of the k-th label in the predicted label probability distribution of x_j.
When determining the similarity relationship of two high-confidence samples, their predicted label probability distributions are compared: if the two distributions assign their highest normalized probability to the same label, the value of the similarity relationship is 1; otherwise it is 0. Determining the similarity relationship of pairs of high-confidence samples constrains the training of the initial text retrieval matching model.
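The pairwise quantities described above could be computed, for example, as in the following NumPy sketch (the comparison of most-probable labels follows the reading above and is an assumption):

```python
import numpy as np

def pairwise_relations(soft_labels, embeddings):
    """S[i, j] = 1 if samples i and j share the same most-probable label,
    else 0; D[i, j] is the Euclidean distance between output embeddings."""
    hard = soft_labels.argmax(axis=1)
    S = (hard[:, None] == hard[None, :]).astype(float)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    D = np.sqrt((diff ** 2).sum(axis=-1))
    return S, D
```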
In an alternative embodiment, the determining the confidence weight of each high-confidence sample in the high-confidence sample set according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set includes:
determining an entropy value of predictive label probability distribution of each high-confidence sample in the high-confidence sample set;
determining the confidence weight of each high-confidence sample from the entropy value according to the following formula:

w = 1 - \frac{ H(\tilde{f}(x)) }{ \log y }

where w is the confidence weight of the current high-confidence sample, H(\tilde{f}(x)) is the entropy value of the predicted label probability distribution of the current high-confidence sample, and y is the total number of labels in the predicted label probability distribution.
Determining an entropy value of a predictive label probability distribution for each high confidence sample in the set of high confidence samples by:
H(\tilde{f}(x)) = - \sum_{i=1}^{y} \tilde{f}(x)_i \log \tilde{f}(x)_i

where \tilde{f}(x)_i represents the normalized probability of the i-th label in the predicted label probability distribution of the current high-confidence sample, y is the total number of all labels, and H(\tilde{f}(x)) is the entropy value of the predicted label probability distribution of the current high-confidence sample.
And after entropy of the predictive label probability distribution of the high-confidence sample is obtained, determining the confidence weight of the high-confidence sample according to the confidence weight calculation formula.
Because the confidence of each high-confidence sample in the high-confidence sample set differs, the weight of each sample's soft label (i.e., its predicted label probability distribution) during training also differs; the confidence weight is therefore determined from the entropy of the sample's predicted label probability distribution and used during training. The entropy is used because it measures uncertainty: the smaller the entropy, the smaller the uncertainty and the more confident the sample. A more accurate confidence weight can thus be obtained from the entropy, which speeds up model training and improves the retrieval matching accuracy of the trained model.
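A NumPy sketch of this entropy-based confidence weight, assuming the normalized form w = 1 - H/\log y given above:

```python
import numpy as np

def confidence_weights(soft_labels, eps=1e-12):
    """Low-entropy (confident) soft labels receive weights close to 1;
    dividing by log(y) keeps the weight within [0, 1]."""
    y = soft_labels.shape[1]
    entropy = -(soft_labels * np.log(soft_labels + eps)).sum(axis=1)
    return 1.0 - entropy / np.log(y)
```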
Step 160, adding the confidence weight of each high-confidence sample, the similarity relation between each high-confidence sample and other high-confidence samples and the vector distance to the high-confidence sample set to obtain a target sample set.
The number of the target samples in the target sample set is the same as that of the high-confidence samples in the high-confidence sample set, and the target samples in the target sample set comprise search words, documents, prediction label probability distribution, output embedded vectors, confidence weights, similarity relations with other target samples and vector distances.
Step 170, training the initial text retrieval matching model according to the target sample set and the target loss function associated with the confidence weight, the similarity relation and the vector distance to obtain a target text retrieval matching model.
The target loss function is a loss function associated with the confidence weights, similarity relationships, and vector distances, and includes: a confidence-weighted KL divergence loss on the predicted label probability distribution, a contrastive loss based on sample similarity relationships, and a sample confidence regularization loss.
The target loss value of the current batch is determined according to the confidence weight of each target sample in the target sample set and its similarity relationships and vector distances to the other target samples; back-propagation is performed on this value, and the model parameters of the initial text retrieval matching model are adjusted. After adjustment, steps 130-170 are performed iteratively to complete one round of training, and then steps 130-170 are repeated with the same unlabeled data set for the next round, until the training target is reached; training then ends and the trained target text retrieval matching model is obtained.
In an embodiment of the present application, training the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence weights, similarity relationships, and vector distances includes: determining the confidence-weighted KL divergence loss value of the predicted label probability distributions for the current batch according to the confidence weight and predicted label probability distribution of each target sample in the current batch of the target sample set; determining the contrastive loss value of the current batch based on sample similarity relationships according to the similarity relationship between each target sample and the other target samples in the current batch and the vector distance of their output embedding vectors; determining the sample confidence regularization loss value of the current batch according to the predicted label probability of each target sample in the current batch; determining the target loss value of the current batch from the confidence-weighted KL divergence loss value, the similarity-based contrastive loss value, and the sample confidence regularization loss value of the current batch; and adjusting the model parameters of the initial text retrieval matching model according to the target loss value of the current batch.
According to the confidence weight and predicted label probability distribution of each target sample in the current batch of the target sample set, the KL divergences of the predicted label probability distributions are weighted and summed to obtain the confidence-weighted KL divergence loss value of the current batch. The similarity relationship and vector distance of two target samples measure how similar they are, so the contrastive loss value of the current batch based on sample similarity relationships is determined from the similarity relationship and vector distance between each target sample and the other target samples. The sample confidence regularization loss value of the current batch is determined from the predicted label probabilities of each target sample in the current batch; introducing this confidence regularization term prevents the model from overfitting to high-confidence samples and makes the soft-label prediction distribution smoother. The confidence-weighted KL divergence loss value, the similarity-based contrastive loss value, and the sample confidence regularization loss value of the current batch are added to obtain the target loss value of the current batch, i.e., the target loss function is L = L_1 + L_2 + \lambda L_3, where \lambda is the regularization coefficient (generally set to 0.05 to 0.1), L_1 is the confidence-weighted KL divergence loss on the predicted label probability distribution, L_2 is the contrastive loss of the current batch based on sample similarity relationships, and L_3 is the sample confidence regularization loss of the current batch. Back-propagation is performed with the target loss value of the current batch, and the model parameters of the initial text retrieval matching model are adjusted. Adopting a target loss function that comprises the confidence-weighted KL divergence loss, the similarity-based contrastive loss, and the sample confidence regularization loss speeds up training and improves the retrieval matching accuracy of the model.
In an optional embodiment, determining a loss function value of a prediction label probability distribution KL divergence of a current batch based on confidence weighting according to the confidence weight and the prediction label probability distribution of each target sample in the current batch of the target sample set includes: determining KL divergence of the probability distribution of the predicted label of each target sample in the current batch relative to the probability of the predicted label as first KL divergence; and according to the confidence weight of each target sample in the current batch, carrying out weighted summation on the first KL divergence of all the target samples in the current batch and averaging to obtain a loss function value of the KL divergence of the probability distribution of the prediction label weighted by the confidence in the current batch.
The first KL divergence of the predicted label probability distribution relative to the predicted label probability is determined for each target sample in the current batch according to the following formula:

D_{KL}(\tilde{f}(x) \,\|\, f(x;\theta)) = \sum_{k=1}^{y} \tilde{f}(x)_k \log \frac{ \tilde{f}(x)_k }{ f(x;\theta)_k }

where D_{KL}(\tilde{f}(x) \,\|\, f(x;\theta)) is the first KL divergence, \tilde{f}(x) represents the predicted label probability distribution of the target sample, f(x;\theta) represents the predicted label probability of the target sample, x represents the target sample, \theta represents the model parameters of the initial text retrieval matching model, \tilde{f}(x)_k represents the normalized probability of the k-th label in the predicted label probability distribution of the target sample, and f(x;\theta)_k represents the predicted probability of the k-th label among the predicted label probabilities of the target sample.
According to the confidence weight of each target sample in the current batch, the first KL divergences of all target samples in the current batch are weighted, summed, and averaged to obtain the confidence-weighted KL divergence loss value of the predicted label probability distribution for the current batch, according to the following formula:

L_1 = \frac{1}{|C|} \sum_{x \in C} w(x) \, D_{KL}(\tilde{f}(x) \,\|\, f(x;\theta))

where L_1 represents the confidence-weighted KL divergence loss of the predicted label probability distribution, C represents all target samples of the current batch, w(x) represents the confidence weight of target sample x, and D_{KL}(\tilde{f}(x) \,\|\, f(x;\theta)) represents the first KL divergence.
This confidence-weighted KL divergence loss builds on the KL divergence loss of conventional semi-supervised learning by using the confidence weights to adjust each sample's contribution; the different confidence levels of the target samples are fully taken into account, which improves the model training speed.
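A differentiable PyTorch sketch of L_1 (an illustration, not the reference implementation of this application):

```python
import torch

def confidence_weighted_kl(soft_labels, logits, weights, eps=1e-12):
    """L1: KL(soft label || model prediction) per target sample, weighted by
    the sample's confidence weight and averaged over the current batch."""
    log_pred = torch.log_softmax(logits, dim=-1)
    kl = (soft_labels * (torch.log(soft_labels + eps) - log_pred)).sum(dim=-1)
    return (weights * kl).mean()
```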
In an embodiment of the present application, determining the contrastive loss value of the current batch based on sample similarity relationships, according to the similarity relationship between each target sample in the current batch of the target sample set and the other target samples and the vector distance of their output embedding vectors, includes:

determining the contrastive loss value of the current batch based on sample similarity relationships according to the following formula:

L_2 = \frac{1}{|C|^2} \sum_{x_i, x_j \in C} \left[ S_{ij} \, d_{ij}^2 + (1 - S_{ij}) \max(0, \gamma - d_{ij})^2 \right]

where L_2 represents the contrastive loss value of the current batch based on sample similarity relationships, x_i and x_j are any two target samples in the current batch of the target sample set, S_{ij} is the similarity relationship between x_i and x_j, d_{ij} is the vector distance between their output embedding vectors, \gamma is a hyperparameter, and C refers to all samples of the current batch. The hyperparameter \gamma is a margin that controls how far apart dissimilar samples are kept.
The contrastive loss based on sample similarity relationships constrains the degree of similarity among all target samples in the current batch, improving the training speed of the model and its retrieval matching accuracy.
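A PyTorch sketch of L_2; the pairwise averaging and squared-distance form follow the reconstruction above and are assumptions:

```python
import torch

def similarity_contrastive_loss(S, D, gamma=1.0):
    """L2: S and D are [batch, batch] tensors. Similar pairs (S_ij = 1) are
    pulled together, dissimilar pairs pushed at least gamma apart;
    averaged over all sample pairs of the current batch."""
    pos = S * D.pow(2)
    neg = (1.0 - S) * torch.clamp(gamma - D, min=0.0).pow(2)
    return (pos + neg).mean()
```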
In an embodiment of the present application, determining a current batch sample confidence regularization loss function value according to a predicted label probability of each target sample in a current batch of the target sample set includes:
respectively determining KL divergence of the uniform distribution with the same dimensionality as the predicted label probability relative to the predicted label probability of each target sample in the current batch as second KL divergence;
and averaging the second KL divergence corresponding to all the target samples in the current batch to obtain a confidence coefficient regular loss function value of the samples in the current batch.
Respectively determining KL divergence of the uniform distribution with the same dimension as the predicted label probability relative to the predicted label probability of each target sample in the current batch as a second KL divergence according to the following formula:
D_{KL}(u \,\|\, f(x;\theta)) = \sum_{k=1}^{y} u_k \log \frac{ u_k }{ f(x;\theta)_k }, \quad u_k = \frac{1}{y}

where D_{KL}(u \,\|\, f(x;\theta)) represents the second KL divergence, u is the uniform distribution with the same dimension as the predicted label probability, u_k represents the k-th value in the uniform distribution, and f(x;\theta)_k represents the predicted probability of the k-th label among the predicted label probabilities of the target sample.
Averaging the second KL divergence corresponding to all the target samples in the current batch according to the following formula to obtain a confidence regular loss function value of the samples in the current batch:
L_3 = \frac{1}{|C|} \sum_{x \in C} D_{KL}(u \,\|\, f(x;\theta))

where L_3 represents the sample confidence regularization loss and C represents all samples of the current batch.
Overfitting of the model to high-confidence samples can be prevented by the sample confidence regularization loss.
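A PyTorch sketch of L_3 and of assembling the full objective L = L_1 + L_2 + \lambda L_3; it reuses the confidence_weighted_kl and similarity_contrastive_loss sketches defined earlier, and the \lambda value is only illustrative within the 0.05-0.1 range mentioned above:

```python
import math
import torch

def confidence_regularizer(logits):
    """L3: KL(uniform || model prediction) averaged over the batch,
    discouraging over-confident predictions."""
    y = logits.shape[-1]
    log_pred = torch.log_softmax(logits, dim=-1)
    return ((-log_pred).mean(dim=-1) - math.log(y)).mean()

def target_loss(soft_labels, logits, weights, S, D, gamma=1.0, lam=0.05):
    """Total objective L = L1 + L2 + λ·L3 used to update the model;
    the first two terms come from the earlier sketches."""
    return (confidence_weighted_kl(soft_labels, logits, weights)
            + similarity_contrastive_loss(S, D, gamma)
            + lam * confidence_regularizer(logits))
```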
The training method for a text retrieval matching model provided by the embodiments of this application obtains an initial sample set and an unlabeled data set; fine-tunes a pre-trained language model on the initial sample set to obtain an initial text retrieval matching model; predicts the unlabeled data in the unlabeled data set with the initial text retrieval matching model to obtain an output embedding vector and a predicted label probability distribution for each unlabeled data item; screens the unlabeled data set according to the predicted label probability distributions to obtain a high-confidence sample set; determines, from the predicted label probability distributions and output embedding vectors of the high-confidence samples, the similarity relationship and output-embedding-vector distance between each high-confidence sample and the other high-confidence samples, as well as a confidence weight for each high-confidence sample; adds the confidence weights, similarity relationships, and vector distances to the high-confidence sample set to obtain a target sample set; and trains the initial text retrieval matching model on the target sample set with a target loss function associated with the confidence weights, similarity relationships, and vector distances to obtain the target text retrieval matching model. Because only a small amount of manually labeled initial samples is needed, labor cost is effectively saved; and self-supervised learning with a target loss function associated with confidence weights, similarity relationships, and vector distances effectively avoids interference from noisy data in the unlabeled samples and improves the training efficiency and accuracy of the model.
On the basis of the above technical solution, before predicting the unlabeled data in the unlabeled dataset by the initial text retrieval matching model, the method further includes:
performing data enhancement processing on the unmarked data set by a sampling method to obtain a data enhancement sample set; merging the data enhancement sample set and the initial sample set to obtain a fusion sample set; and performing fine tuning training on the initial text retrieval matching model according to the fusion sample set to obtain the initial text retrieval matching model after data enhancement.
Wherein the sampling method may include at least one of random negative example sampling, BM25 sampling, and semantic similarity sampling.
In binary classification, data enhancement can be performed on the unlabeled data set by sampling: random negative-example sampling of the unlabeled data set yields negative samples, BM25 sampling and/or semantic similarity sampling yields positive samples, and all obtained negative and positive samples are merged with the initial sample set into a fused sample set. The initial text retrieval matching model is fine-tuned on the fused sample set to obtain a data-enhanced initial text retrieval matching model, which is then self-supervised-trained with the unlabeled data set. Continuing to train the model, which was initially fine-tuned on the small initial sample set, with the data-enhanced fused data set further increases the semantic representation ability of the model and improves its semantic generalization, thereby improving the accuracy of retrieval matching.
On the basis of the above technical solution, the performing data enhancement processing on the unlabeled data set by using a sampling method to obtain a data enhancement sample set includes:
for a current search term in the unlabeled data set, randomly selecting a document from the unlabeled data set, and forming a negative example sample by the current search term and the selected document to obtain a negative example set;
establishing indexes for all documents in the unlabeled data set, determining a first document set with a first preset number most relevant to the current search term according to the indexes by using a BM25 algorithm, and respectively forming a positive example sample by the current search term and each document in the first document set to obtain a first positive example set;
for a current search word in the unlabeled data set, respectively predicting the current search word and each document in the unlabeled data set through the initial text search matching model, determining a second document set with a second preset number of positive case labels corresponding to the highest prediction probability, and respectively forming a positive case sample by the current search word and each document in the second document set to obtain a second positive case set;
taking at least one of the set of negative examples, the first set of positive examples, and the second set of positive examples as the set of data enhancement samples.
Random negative-example sampling selects a document at random from the unlabeled data set for each search term and forms a search-term/document pair from the search term and the selected document; with high probability the pair is irrelevant, i.e., label = 0, so it is used as a negative sample.
For BM25 sampling, an index is built for all documents in the unlabeled data set using Elasticsearch; the BM25 algorithm is used to determine, from the index, the set of the Top-K documents most similar to the current search term, where K is the first preset number, yielding the first document set. The current search term is paired with each document in the first document set; these pairs are treated as relevant, i.e., label = 1, and serve as positive samples. A corresponding set of positive samples is determined for each search term, yielding the first positive example set. Although using BM25 directly for retrieval/ranking introduces noise, BM25 is more reliable at the top positions.
For any current search term in the unlabeled data set, the current search term is paired with each document in the unlabeled data set and scored by the initial text retrieval matching model; the second preset number of documents with the highest predicted probability for the positive label form the second document set, and the current search term is paired with each document in the second document set as a positive sample. In other words, the Top-K documents most semantically similar to each search term are selected to form positive samples, yielding the second positive example set. Predicting the corpus with the initial text retrieval matching model to expand the positive example set alleviates BM25's emphasis on literal matching and its lack of semantic generalization, so semantic generalization ability can be improved and the accuracy of the model increased.
And taking at least one of a negative example set, a first positive example set and the second positive example set as the data enhancement sample set. The initial text retrieval matching model is subjected to fine tuning training through the data enhancement sample set, so that the semantic representation capability of the model can be further enhanced, and the semantic generalization of the model is improved.
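A rough sketch of the random-negative and BM25-based positive sampling; the rank_bm25 package is just one possible BM25 implementation, and its use here (as well as character-level tokenization) is an assumption rather than the method prescribed by this application:

```python
import random
from rank_bm25 import BM25Okapi  # assumed third-party BM25 implementation

def sample_enhancement_pairs(queries, docs, top_k=3):
    """Returns (query, doc, label) triples: one random negative per query
    (label 0) and the BM25 Top-K documents per query as positives (label 1)."""
    bm25 = BM25Okapi([list(d) for d in docs])  # character-level tokens
    negatives, positives = [], []
    for q in queries:
        negatives.append((q, random.choice(docs), 0))
        for d in bm25.get_top_n(list(q), docs, n=top_k):
            positives.append((q, d, 1))
    return negatives, positives
```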
On the basis of the above technical solution, taking at least one of the negative case set, the first positive case set, and the second positive case set as the data enhancement sample set, includes: determining the distribution proportion of each label in the initial sample set; and adjusting the distribution proportion of each label in at least one of the negative example set, the first positive example set and the second positive example set to the distribution proportion of each label in the initial sample set to obtain the data enhancement sample set.
To ensure that the sample distribution does not change too much after data enhancement, the data enhancement sample set can be sampled and weighted according to the sample distribution of the initial sample set. The distribution proportion of each label in the initial sample set is counted, and the distribution proportion of each label in at least one of the negative example set, the first positive example set, and the second positive example set obtained by sampling is adjusted to match the distribution proportion of each label in the initial sample set, yielding the data enhancement sample set. For example, the enhanced samples can be resampled so that the ratio of positive-example labels to negative-example labels matches that of the initial sample set.
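One simple, hypothetical way to resample the enhanced samples so that their label ratio matches the initial sample set:

```python
import random

def match_label_ratio(enhanced, initial):
    """enhanced / initial: lists of (query, doc, label) triples. Downsamples
    the over-represented label of `enhanced` so its positive proportion
    matches the initial sample set (illustrative; assumes 0 < target < 1)."""
    target = sum(1 for s in initial if s[2] == 1) / len(initial)
    pos = [s for s in enhanced if s[2] == 1]
    neg = [s for s in enhanced if s[2] == 0]
    if len(pos) / max(len(pos) + len(neg), 1) > target:
        pos = random.sample(pos, int(target / (1 - target) * len(neg)))
    else:
        neg = random.sample(neg, int((1 - target) / target * len(pos)))
    return pos + neg
```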
Example two
In the training apparatus for text retrieval matching model according to this embodiment, as shown in fig. 2, the training apparatus 200 for text retrieval matching model includes:
a data set obtaining module 210, configured to obtain an initial sample set and obtain an unlabeled data set, where each initial sample in the initial sample set includes a search term, a document, and a tag, and the unlabeled data in the unlabeled data set includes a search term and a document;
the first fine tuning training module 220 is configured to perform fine tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model;
the model prediction module 230 is configured to predict the unlabeled data in the unlabeled data set through the initial text retrieval matching model, so as to obtain an output embedded vector and a prediction tag probability distribution corresponding to each unlabeled data in the unlabeled data set;
the data set screening module 240 is configured to screen the unlabeled data set according to the predicted label probability distribution corresponding to each unlabeled data, so as to obtain a high-confidence sample set;
a parameter determining module 250, configured to determine, according to the prediction label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, a similarity relationship between each high-confidence sample and another high-confidence sample in the high-confidence sample set and a vector distance of the output embedded vector, and determine a confidence weight of each high-confidence sample;
a target sample construction module 260, configured to add the confidence weight of each high-confidence sample, and the similarity relationship and vector distance between each high-confidence sample and another high-confidence sample to the high-confidence sample set to obtain a target sample set;
and a model training module 270, configured to train the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence weight, the similarity relationship, and the vector distance, to obtain a target text retrieval matching model.
Optionally, the model prediction module includes:
the model prediction unit is used for predicting the unmarked data of the current batch in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction tag probability corresponding to each unmarked data in the current batch;
and the normalization processing unit is used for respectively performing normalization processing on the prediction label probability corresponding to each unlabeled data in the current batch to obtain the prediction label probability distribution corresponding to each unlabeled data in the current batch.
Optionally, the normalization processing unit is specifically configured to:
respectively carrying out normalization processing on the prediction label probability corresponding to each unmarked data in the current batch according to the following normalization calculation formula:
$$\tilde{f}_j(x;\theta)=\frac{f_j(x;\theta)^2\Big/\sum_{x'\in B}f_j(x';\theta)^2}{\sum_{k=1}^{y}\left(f_k(x;\theta)^2\Big/\sum_{x'\in B}f_k(x';\theta)^2\right)}$$

wherein $\sum_{x'\in B}f_j(x';\theta)^2$ represents the sum of the squares of the predicted label probabilities corresponding to the jth label over all the unlabeled data in the current batch B, $f_j(x;\theta)^2$ represents the square of the predicted label probability corresponding to the jth label of the current unlabeled data, f(x; θ) represents the predicted label probability obtained by passing the current unlabeled data through the initial text retrieval matching model, x represents the current unlabeled data, θ represents the model parameters of the initial text retrieval matching model, y represents the total number of all labels, and $\tilde{f}_j(x;\theta)$ represents the normalized probability corresponding to the jth label in the predicted label probability distribution corresponding to the current unlabeled data.
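A small numerical sketch of this normalization, assuming the batch predictions are available as an (N, y) array; the names and shapes are assumptions for illustration.

```python
import numpy as np

def normalize_batch_predictions(batch_probs):
    """batch_probs: (N, y) predicted label probabilities f(x; theta) for the current batch.
    Squares each probability, divides by the per-label sum of squares over the batch,
    then renormalizes each row to sum to 1."""
    squared = np.asarray(batch_probs) ** 2                 # f_j(x; theta)^2
    per_label = squared.sum(axis=0, keepdims=True)         # batch-wise sum per label j
    sharpened = squared / np.clip(per_label, 1e-12, None)
    return sharpened / sharpened.sum(axis=1, keepdims=True)

print(normalize_batch_predictions([[0.7, 0.3], [0.6, 0.4], [0.9, 0.1]]))
```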
Optionally, the data set filtering module includes:
the data screening unit is used for screening, from the unlabeled data set, the unlabeled data for which a normalized label probability in the predicted label probability distribution is greater than or equal to a confidence threshold;
and the high confidence sample construction unit is used for determining the unmarked data obtained by screening, and the output embedded vector and the predictive label probability distribution corresponding to the unmarked data as a high confidence sample set.
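A minimal sketch of this screening step; the threshold value 0.9 is an arbitrary illustration, since the text only specifies "a confidence threshold".

```python
def select_high_confidence(unlabeled, embeddings, label_dists, threshold=0.9):
    """Keep unlabeled items whose predicted label distribution contains a
    normalized probability at or above the confidence threshold, together with
    their output embedded vectors and distributions."""
    return [
        {"data": item, "embedding": emb, "label_dist": dist}
        for item, emb, dist in zip(unlabeled, embeddings, label_dists)
        if max(dist) >= threshold
    ]
```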
Optionally, the parameter determining module includes:
a similarity relation determining unit, configured to determine, according to the prediction label probability distribution of each high-confidence sample in the high-confidence sample set, a similarity relation between each high-confidence sample in the high-confidence sample set and another high-confidence sample;
a vector distance determination unit, configured to determine a vector distance between each high-confidence sample and another high-confidence sample in the high-confidence sample set according to an output embedded vector of each high-confidence sample in the high-confidence sample set;
and the confidence weight determining unit is used for determining the confidence weight of each high-confidence sample in the high-confidence sample set according to the predicted label probability distribution of each high-confidence sample in the high-confidence sample set.
Optionally, the similarity relation determining unit is specifically configured to:
according to the probability distribution of the prediction label of each high-confidence sample in the high-confidence sample set, determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the following formula:
$$S_{ij}=\sum_{k=1}^{y}\tilde{f}_k(x_i;\theta)\,\tilde{f}_k(x_j;\theta)$$

wherein x_i and x_j are two high-confidence samples in the high-confidence sample set, S_ij represents the similarity relation between the two high-confidence samples x_i and x_j, y represents the total number of all labels, $\tilde{f}_k(x_i;\theta)$ represents the normalized probability of the kth label in the predicted label probability distribution of the high-confidence sample x_i, and $\tilde{f}_k(x_j;\theta)$ represents the normalized probability of the kth label in the predicted label probability distribution of the high-confidence sample x_j.
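A sketch of computing the pairwise similarity relation and vector distance for a high-confidence batch; the Euclidean distance is one plausible choice of "vector distance", which the text does not pin down, and all names are illustrative.

```python
import numpy as np

def pairwise_similarity_and_distance(label_dists, embeddings):
    """S: inner product of normalized predicted label distributions (formula above).
    D: Euclidean distance between output embedded vectors."""
    P = np.asarray(label_dists)   # (N, y)
    E = np.asarray(embeddings)    # (N, dim)
    S = P @ P.T
    diff = E[:, None, :] - E[None, :, :]
    D = np.linalg.norm(diff, axis=-1)
    return S, D
```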
Optionally, the confidence weight determining unit is specifically configured to:
determining an entropy value of predictive label probability distribution of each high-confidence sample in the high-confidence sample set;
determining the confidence weight of each high-confidence sample according to the formula as follows according to the entropy value:
$$w=1-\frac{H\!\left(\tilde{f}(x;\theta)\right)}{\log y}$$

wherein w is the confidence weight of the current high-confidence sample, $H(\tilde{f}(x;\theta))$ is the entropy value of the predicted label probability distribution of the current high-confidence sample, and y is the total number of labels in the predicted label probability distribution.
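A sketch of the entropy-based confidence weight as written above; the natural logarithm is an assumption, since the base is not specified in the text.

```python
import numpy as np

def confidence_weight(label_dist, eps=1e-12):
    """w = 1 - H(p) / log(y): low-entropy (peaked) distributions get weights near 1."""
    p = np.asarray(label_dist)
    entropy = -np.sum(p * np.log(p + eps))
    return 1.0 - entropy / np.log(len(p))
```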
Optionally, the objective loss function includes: a confidence-weighted KL divergence loss function of the predicted label probability distribution, a contrast loss function based on the sample similarity relation, and a sample confidence regular loss function.
Optionally, the model training module includes:
a KL divergence loss value determining unit, configured to determine a loss function value of KL divergence of a prediction label probability distribution weighted by a confidence coefficient for a current batch according to the confidence coefficient weight and the prediction label probability distribution of each target sample in the current batch of the target sample set;
the contrast loss value determining unit is used for determining a contrast loss function value of the current batch based on the sample similarity relation according to the similarity relation between each target sample and other target samples in the current batch of the target sample set and the vector distance of the output embedded vector;
the confidence regular loss value determining unit is used for determining a confidence regular loss function value of the samples in the current batch according to the prediction label probability of each target sample in the current batch of the target sample set;
a target loss value determining unit, configured to determine a current batch target loss function value according to the loss function value of the prediction tag probability distribution KL divergence weighted by the current batch based on the confidence, the contrast loss function value of the current batch based on the sample similarity relationship, and the current batch sample confidence regular loss function value;
and the model parameter adjusting unit is used for adjusting the model parameters of the initial text retrieval matching model according to the target loss function value of the current batch.
Optionally, the KL divergence loss value determination unit is specifically configured to:
determining KL divergence of the probability distribution of the predicted label of each target sample in the current batch relative to the probability of the predicted label as first KL divergence;
and according to the confidence weight of each target sample in the current batch, carrying out weighted summation on the first KL divergence of all the target samples in the current batch and averaging to obtain a loss function value of the KL divergence of the prediction label probability distribution weighted by the confidence in the current batch.
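A sketch of the confidence-weighted KL divergence loss for one batch, written with PyTorch; the tensor shapes and function names are assumptions rather than the patent's notation.

```python
import torch

def weighted_kl_loss(target_dist, model_probs, weights, eps=1e-12):
    """Per-sample KL(target normalized distribution || model prediction),
    weighted by the confidence weights and averaged over the batch.
    target_dist, model_probs: (N, y); weights: (N,)."""
    kl = torch.sum(
        target_dist * (torch.log(target_dist + eps) - torch.log(model_probs + eps)), dim=1
    )
    return torch.mean(weights * kl)
```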
Optionally, the contrast loss value determining unit is specifically configured to:
according to the similarity relation between each target sample and other target samples in the current batch of the target sample set and the vector distance of the output embedded vector, determining the contrast loss function value of the current batch based on the sample similarity relation according to the following formula:
$$L_2=\frac{1}{|C|^2}\sum_{x_i,\,x_j\in C}\left[S_{ij}\,d_{ij}^{2}+\left(1-S_{ij}\right)\max\!\left(0,\;\gamma-d_{ij}\right)^{2}\right]$$

wherein L_2 represents the contrast loss function value of the current batch based on the sample similarity relation, x_i and x_j are any two target samples in the current batch of the target sample set, S_ij is the similarity relation between the target samples x_i and x_j, d_ij is the vector distance between the output embedded vectors of the target samples x_i and x_j, γ is a hyperparameter, and C refers to all samples of the current batch.
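A PyTorch sketch of the contrast loss above; the margin form and the averaging over all pairs are a standard reading of the defined quantities, and the shapes are assumptions.

```python
import torch

def contrastive_loss(distances, similarities, gamma=1.0):
    """Pull similar pairs together and push dissimilar pairs at least gamma apart.
    distances, similarities: (N, N) tensors; gamma is the margin hyperparameter."""
    pulled = similarities * distances.pow(2)
    pushed = (1.0 - similarities) * torch.clamp(gamma - distances, min=0.0).pow(2)
    n = distances.shape[0]
    return (pulled + pushed).sum() / (n * n)
```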
Optionally, the confidence regularization loss value determination unit is specifically configured to:
respectively determining KL divergence of the uniform distribution with the same dimensionality as the predicted label probability relative to the predicted label probability of each target sample in the current batch as second KL divergence;
and averaging the second KL divergence corresponding to all the target samples in the current batch to obtain a confidence coefficient regular loss function value of the samples in the current batch.
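A sketch of the sample confidence regular loss as the mean KL divergence from the uniform distribution to the predictions (the second KL divergence described above); PyTorch is used as before and the shapes are assumptions. The three loss values would then be combined into the batch target loss function; since the text does not specify combination weights, a plain sum is one possible choice.

```python
import torch

def confidence_regularization(model_probs, eps=1e-12):
    """Mean KL(uniform || predicted probabilities) over the batch,
    penalizing overly peaked predictions."""
    y = model_probs.shape[1]
    uniform = torch.full_like(model_probs, 1.0 / y)
    kl = torch.sum(uniform * (torch.log(uniform) - torch.log(model_probs + eps)), dim=1)
    return kl.mean()
```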
Optionally, the apparatus further comprises:
and the pre-training module is used for pre-training the initial pre-training language model by using the corpora in the target field to obtain the pre-training language model.
Optionally, the apparatus further comprises:
the data enhancement module is used for carrying out data enhancement processing on the unlabeled data set by a sampling method to obtain a data enhancement sample set;
the sample fusion module is used for merging the data enhancement sample set and the initial sample set to obtain a fusion sample set;
and the second fine tuning training module is used for performing fine tuning training on the initial text retrieval matching model according to the fusion sample set to obtain the initial text retrieval matching model after data enhancement.
Optionally, the data enhancement module includes:
a random negative case sampling unit, configured to select a document from the unlabeled data set at random for a current search term in the unlabeled data set, and form a negative case sample with the current search term and the selected document to obtain a negative case set;
the BM25 sampling unit is used for establishing indexes for all documents in the unlabeled data set, determining, according to the indexes and by using the BM25 algorithm, a first document set consisting of the first preset number of documents most relevant to the current search term, and forming a positive example sample from the current search term and each document in the first document set respectively, so as to obtain a first positive example set;
the semantic similarity sampling unit is used for predicting, for the current search term in the unlabeled data set, the current search term paired with each document in the unlabeled data set respectively through the initial text retrieval matching model, determining a second document set consisting of the second preset number of documents with the highest predicted probability for the positive example label, and forming a positive example sample from the current search term and each document in the second document set respectively, so as to obtain a second positive example set;
a data enhancement sample set determination unit configured to use at least one of the negative case set, the first positive case set, and the second positive case set as the data enhancement sample set.
Optionally, the data enhancement sample set determining unit is specifically configured to:
determining the distribution proportion of each label in the initial sample set;
and adjusting the distribution proportion of each label in at least one of the negative example set, the first positive example set and the second positive example set to the distribution proportion of each label in the initial sample set to obtain the data enhancement sample set.
The training device for the text retrieval matching model provided in the embodiment of the present application is used to implement each step of the training method for the text retrieval matching model described in the first embodiment of the present application, and specific implementation modes of each module of the device refer to the corresponding step, which is not described herein again.
The training apparatus for the text retrieval matching model provided in the embodiment of the present application obtains an initial sample set and an unlabeled data set; performs fine-tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model; predicts the unlabeled data in the unlabeled data set through the initial text retrieval matching model to obtain an output embedded vector and a predicted label probability distribution corresponding to each unlabeled data; screens the unlabeled data set according to the predicted label probability distribution corresponding to each unlabeled data to obtain a high-confidence sample set; determines, according to the predicted label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, the similarity relation between each high-confidence sample and the other high-confidence samples and the vector distance between their output embedded vectors, and determines the confidence weight of each high-confidence sample; adds the confidence weight of each high-confidence sample, together with the similarity relation and vector distance between each high-confidence sample and the other high-confidence samples, to the high-confidence sample set to obtain a target sample set; and trains the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence weight, the similarity relation, and the vector distance, so as to obtain a target text retrieval matching model. Because only a small amount of manually labeled initial samples is needed, labor cost is effectively saved; and performing self-supervised learning with the target loss function associated with the confidence weight, the similarity relation, and the vector distance effectively avoids interference from noisy data in the unlabeled samples and improves the training efficiency and accuracy of the model.
EXAMPLE III
Embodiments of the present application also provide an electronic device, as shown in fig. 3, the electronic device 300 may include one or more processors 310 and one or more memories 320 connected to the processors 310. Electronic device 300 may also include input interface 330 and output interface 340 for communicating with another apparatus or system. Program code executed by processor 310 may be stored in memory 320.
The processor 310 in the electronic device 300 invokes the program code stored in the memory 320 to perform the training method of the text retrieval matching model in the above embodiment.
The embodiment of the present application further provides a computer-readable storage medium, on which a computer program is stored, where the computer program, when executed by a processor, implements the steps of the training method for text retrieval matching model according to the first embodiment of the present application.
The embodiments in the present specification are described in a progressive manner, each embodiment focuses on differences from other embodiments, and the same and similar parts among the embodiments are referred to each other. For the device embodiment, since it is basically similar to the method embodiment, the description is simple, and for the relevant points, refer to the partial description of the method embodiment.
The method, the device, the electronic device and the storage medium for training the text retrieval matching model provided by the embodiment of the application are introduced in detail, a specific example is applied in the text to explain the principle and the implementation of the application, and the description of the embodiment is only used for helping to understand the method and the core idea of the application; meanwhile, for a person skilled in the art, according to the idea of the present application, the specific implementation manner and the application scope may be changed, and in summary, the content of the present specification should not be construed as a limitation to the present application.
Through the above description of the embodiments, those skilled in the art will clearly understand that each embodiment can be implemented by software plus a necessary general hardware platform, and certainly can also be implemented by hardware. With this understanding in mind, the above-described technical solutions may be embodied in the form of a software product, which can be stored in a computer-readable storage medium such as ROM/RAM, magnetic disk, optical disk, etc., and includes instructions for causing a computer device (which may be a personal computer, a server, or a network device, etc.) to execute the methods described in the embodiments or some parts of the embodiments.

Claims (19)

1. A training method for a text retrieval matching model is characterized by comprising the following steps:
acquiring an initial sample set and acquiring an unlabeled data set, wherein each initial sample in the initial sample set comprises a search word, a document and a label, and the unlabeled data in the unlabeled data set comprises a search word and a document;
performing fine tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model;
predicting the unmarked data in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a predictive label probability distribution corresponding to each unmarked data in the unmarked data set;
screening the unmarked data set according to the probability distribution of the predictive label corresponding to each unmarked data to obtain a high-confidence sample set;
according to the predictive label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, determining the similarity relation between each high-confidence sample and other high-confidence samples in the high-confidence sample set and the vector distance of the output embedded vector, and determining the confidence weight of each high-confidence sample;
adding the confidence weight of each high-confidence sample, and the similarity relation and the vector distance between each high-confidence sample and other high-confidence samples to the high-confidence sample set to obtain a target sample set;
and training the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence coefficient weight, the similarity relation and the vector distance to obtain a target text retrieval matching model.
2. The method of claim 1, wherein predicting unlabeled data in the unlabeled data set by the initial text retrieval matching model to obtain an output embedded vector and a predictive label probability distribution corresponding to each unlabeled data in the unlabeled data set comprises:
predicting the unmarked data of the current batch in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction tag probability corresponding to each unmarked data in the current batch;
and respectively carrying out normalization processing on the probability of the predicted label corresponding to each unlabeled data in the current batch to obtain the probability distribution of the predicted label corresponding to each unlabeled data in the current batch.
3. The method according to claim 2, wherein the normalizing the probability of the predictive label corresponding to each unlabeled data in the current batch respectively comprises:
respectively carrying out normalization processing on the prediction label probability corresponding to each unmarked data in the current batch according to the following normalization calculation formula:
$$\tilde{f}_j(x;\theta)=\frac{f_j(x;\theta)^2\Big/\sum_{x'\in B}f_j(x';\theta)^2}{\sum_{k=1}^{y}\left(f_k(x;\theta)^2\Big/\sum_{x'\in B}f_k(x';\theta)^2\right)}$$

wherein $\sum_{x'\in B}f_j(x';\theta)^2$ represents the sum of the squares of the predicted label probabilities corresponding to the jth label over all the unlabeled data in the current batch B, $f_j(x;\theta)^2$ represents the square of the predicted label probability corresponding to the jth label of the current unlabeled data, f(x; θ) represents the predicted label probability obtained by passing the current unlabeled data through the initial text retrieval matching model, x represents the current unlabeled data, θ represents the model parameters of the initial text retrieval matching model, y represents the total number of all labels, and $\tilde{f}_j(x;\theta)$ represents the normalized probability corresponding to the jth label in the predicted label probability distribution corresponding to the current unlabeled data.
4. The method of claim 1, wherein the screening the unlabeled data set according to the probability distribution of the predictive label corresponding to each unlabeled data to obtain a high-confidence sample set comprises:
screening unlabeled data with normalized probability of the labels being greater than or equal to a confidence threshold in the probability distribution of the predicted labels from the unlabeled dataset;
and determining the unmarked data obtained by screening and the output embedded vector and the predictive label probability distribution corresponding to the unmarked data as a high confidence sample set.
5. The method of claim 1, wherein determining a similarity relationship between each high-confidence sample in the set of high-confidence samples and other high-confidence samples and a vector distance of an output embedding vector based on the predictive label probability distribution and the output embedding vector for each high-confidence sample in the set of high-confidence samples, and determining a confidence weight for each high-confidence sample comprises:
determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set;
determining a vector distance between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to an output embedded vector of each high-confidence sample in the high-confidence sample set;
and determining the confidence weight of each high-confidence sample in the high-confidence sample set according to the predictive label probability distribution of each high-confidence sample in the high-confidence sample set.
6. The method of claim 5, wherein determining similarity relationships between each high-confidence sample in the high-confidence sample set and other high-confidence samples based on the predictive label probability distribution for each high-confidence sample in the high-confidence sample set comprises:
according to the probability distribution of the prediction label of each high-confidence sample in the high-confidence sample set, determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples according to the following formula:
$$S_{ij}=\sum_{k=1}^{y}\tilde{f}_k(x_i;\theta)\,\tilde{f}_k(x_j;\theta)$$

wherein x_i and x_j are two high-confidence samples in the high-confidence sample set, S_ij represents the similarity relation between the two high-confidence samples x_i and x_j, y represents the total number of all labels, $\tilde{f}_k(x_i;\theta)$ represents the normalized probability of the kth label in the predicted label probability distribution of the high-confidence sample x_i, and $\tilde{f}_k(x_j;\theta)$ represents the normalized probability of the kth label in the predicted label probability distribution of the high-confidence sample x_j.
7. The method of claim 5, wherein determining a confidence weight for each high confidence sample in the high confidence sample set based on the predictive label probability distribution for each high confidence sample in the high confidence sample set comprises:
determining an entropy value of predictive label probability distribution of each high-confidence sample in the high-confidence sample set;
according to the entropy value, determining the confidence coefficient weight of each high confidence coefficient sample according to the following formula:
$$w=1-\frac{H\!\left(\tilde{f}(x;\theta)\right)}{\log y}$$

wherein w is the confidence weight of the current high-confidence sample, $H(\tilde{f}(x;\theta))$ is the entropy value of the predicted label probability distribution of the current high-confidence sample, and y is the total number of labels in the predicted label probability distribution.
8. The method of claim 1, wherein the objective loss function comprises: a confidence-weighted KL divergence loss function of the predicted label probability distribution, a contrast loss function based on the sample similarity relation, and a sample confidence regular loss function.
9. The method of claim 8, wherein training the initial text retrieval matching model according to the target sample set and a target loss function associated with confidence weights, similarity relationships, and vector distances comprises:
determining a loss function value of KL divergence of the prediction label probability distribution weighted by the confidence coefficient in the current batch according to the confidence coefficient weight and the prediction label probability distribution of each target sample in the current batch of the target sample set;
determining a comparison loss function value of the current batch based on the sample similarity relation according to the similarity relation between each target sample and other target samples in the current batch of the target sample set and the vector distance of the output embedded vector;
determining a confidence regular loss function value of a sample in the current batch according to the prediction label probability of each target sample in the current batch of the target sample set;
determining a target loss function value of the current batch according to the loss function value of the probability distribution KL divergence of the current batch of the prediction labels weighted based on the confidence coefficient, the comparison loss function value of the current batch based on the sample similarity relation and the regular loss function value of the current batch of the sample confidence coefficient;
and adjusting the model parameters of the initial text retrieval matching model according to the target loss function value of the current batch.
10. The method of claim 9, wherein determining the loss function value of the current lot based on the KL divergence of the confidence weighted predicted label probability distribution according to the confidence weight and the predicted label probability distribution of each target sample in the current lot of the target sample set comprises:
determining KL divergence of the probability distribution of the predicted label of each target sample in the current batch relative to the probability of the predicted label as first KL divergence;
and according to the confidence weight of each target sample in the current batch, carrying out weighted summation on the first KL divergence of all the target samples in the current batch and averaging to obtain a loss function value of the KL divergence of the prediction label probability distribution weighted by the confidence in the current batch.
11. The method of claim 9, wherein determining a comparison loss function value of the current batch based on the sample similarity relation according to the similarity relation between each target sample in the current batch of the target sample set and other target samples and the vector distance of the output embedded vector comprises:
determining the comparison loss function value of the current batch based on the sample similarity relation according to the following formula:

$$L_2=\frac{1}{|C|^2}\sum_{x_i,\,x_j\in C}\left[S_{ij}\,d_{ij}^{2}+\left(1-S_{ij}\right)\max\!\left(0,\;\gamma-d_{ij}\right)^{2}\right]$$

wherein L_2 represents the comparison loss function value of the current batch based on the sample similarity relation, x_i and x_j are any two target samples in the current batch of the target sample set, S_ij is the similarity relation between the target samples x_i and x_j, d_ij is the vector distance between the output embedded vectors of the target samples x_i and x_j, γ is a hyperparameter, and C refers to all samples of the current batch.
12. The method of claim 9, wherein determining a current batch sample confidence regularization loss function value according to the predicted label probability of each target sample in the current batch of the target sample set comprises:
respectively determining KL divergence of the uniform distribution with the same dimensionality as the predicted label probability relative to the predicted label probability of each target sample in the current batch as second KL divergence;
and averaging the second KL divergence corresponding to all the target samples in the current batch to obtain a confidence regular loss function value of the samples in the current batch.
13. The method of claim 1, further comprising, prior to the fine-tuning training of the pre-trained language model based on the initial sample set:
and pre-training the initial pre-training language model by using the linguistic data in the target field to obtain a pre-training language model.
14. The method of claim 1, further comprising, prior to predicting unlabeled data in the unlabeled dataset by the initial text retrieval matching model:
performing data enhancement processing on the unmarked data set by a sampling method to obtain a data enhancement sample set;
merging the data enhancement sample set and the initial sample set to obtain a fusion sample set;
and performing fine tuning training on the initial text retrieval matching model according to the fusion sample set to obtain the initial text retrieval matching model after data enhancement.
15. The method of claim 14, wherein the performing data enhancement processing on the unlabeled data set by a sampling method to obtain a data-enhanced sample set comprises:
for a current search word in the unlabeled data set, randomly selecting a document from the unlabeled data set, and forming a negative example sample by the current search word and the selected document to obtain a negative example set;
establishing indexes for all documents in the unlabeled data set, determining, according to the indexes and by using a BM25 algorithm, a first document set consisting of a first preset number of documents most relevant to the current search term, and forming a positive example sample from the current search term and each document in the first document set respectively, so as to obtain a first positive example set;
for a current search term in the unlabeled data set, predicting the current search term paired with each document in the unlabeled data set respectively through the initial text retrieval matching model, determining a second document set consisting of a second preset number of documents with the highest predicted probability for the positive example label, and forming a positive example sample from the current search term and each document in the second document set respectively, so as to obtain a second positive example set;
taking at least one of the set of negative examples, the first set of positive examples, and the second set of positive examples as the set of data enhancement samples.
16. The method of claim 15, wherein taking at least one of the negative case set, the first positive case set, and the second positive case set as the data enhancement sample set comprises:
determining the distribution proportion of each label in the initial sample set;
and adjusting the distribution proportion of each label in at least one of the negative example set, the first positive example set and the second positive example set to the distribution proportion of each label in the initial sample set to obtain the data enhancement sample set.
17. An apparatus for training a text search matching model, comprising:
the data set acquisition module is used for acquiring an initial sample set and acquiring an unlabeled data set, wherein each initial sample in the initial sample set comprises a search word, a document and a label, and the unlabeled data in the unlabeled data set comprises a search word and a document;
the first fine tuning training module is used for performing fine tuning training on a pre-training language model according to the initial sample set to obtain an initial text retrieval matching model;
the model prediction module is used for predicting the unmarked data in the unmarked data set through the initial text retrieval matching model to obtain an output embedded vector and a prediction label probability distribution corresponding to each unmarked data in the unmarked data set;
the data set screening module is used for screening the unmarked data sets according to the probability distribution of the predictive label corresponding to each unmarked data to obtain a high-confidence sample set;
the parameter determination module is used for determining the similarity relation between each high-confidence sample in the high-confidence sample set and other high-confidence samples and the vector distance of the output embedded vector according to the predictive label probability distribution and the output embedded vector of each high-confidence sample in the high-confidence sample set, and determining the confidence weight of each high-confidence sample;
the target sample construction module is used for adding the confidence weight of each high-confidence sample, the similarity relation and the vector distance between each high-confidence sample and other high-confidence samples to the high-confidence sample set to obtain a target sample set;
and the model training module is used for training the initial text retrieval matching model according to the target sample set and a target loss function associated with the confidence coefficient weight, the similarity relation and the vector distance to obtain a target text retrieval matching model.
18. An electronic device comprising a memory, a processor and a computer program stored on the memory and executable on the processor, wherein the processor implements the method for training a text retrieval matching model according to any one of claims 1 to 16 when executing the computer program.
19. A computer-readable storage medium, on which a computer program is stored, which program, when being executed by a processor, carries out the steps of the method for training a text retrieval matching model according to any one of claims 1 to 16.
CN202110343807.5A 2021-03-30 2021-03-30 Training method and device for text retrieval matching model, electronic equipment and medium Pending CN115146021A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110343807.5A CN115146021A (en) 2021-03-30 2021-03-30 Training method and device for text retrieval matching model, electronic equipment and medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110343807.5A CN115146021A (en) 2021-03-30 2021-03-30 Training method and device for text retrieval matching model, electronic equipment and medium

Publications (1)

Publication Number Publication Date
CN115146021A true CN115146021A (en) 2022-10-04

Family

ID=83404621

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110343807.5A Pending CN115146021A (en) 2021-03-30 2021-03-30 Training method and device for text retrieval matching model, electronic equipment and medium

Country Status (1)

Country Link
CN (1) CN115146021A (en)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117033612A (en) * 2023-08-18 2023-11-10 中航信移动科技有限公司 Text matching method, electronic equipment and storage medium
CN117033612B (en) * 2023-08-18 2024-06-04 中航信移动科技有限公司 Text matching method, electronic equipment and storage medium
CN117786121A (en) * 2024-02-28 2024-03-29 珠海泰坦软件系统有限公司 File identification method and system based on artificial intelligence
CN117786121B (en) * 2024-02-28 2024-05-03 珠海泰坦软件系统有限公司 File identification method and system based on artificial intelligence

Similar Documents

Publication Publication Date Title
CN113011533B (en) Text classification method, apparatus, computer device and storage medium
Zhang et al. Discovering new intents with deep aligned clustering
CN111738003B (en) Named entity recognition model training method, named entity recognition method and medium
CN113268995B (en) Chinese academy keyword extraction method, device and storage medium
CN109299228B (en) Computer-implemented text risk prediction method and device
CN112464656B (en) Keyword extraction method, keyword extraction device, electronic equipment and storage medium
CN113392209B (en) Text clustering method based on artificial intelligence, related equipment and storage medium
CN109359302B (en) Optimization method of domain word vectors and fusion ordering method based on optimization method
CN109086265B (en) Semantic training method and multi-semantic word disambiguation method in short text
CN115952292B (en) Multi-label classification method, apparatus and computer readable medium
CN116992007B (en) Limiting question-answering system based on question intention understanding
CN114298055B (en) Retrieval method and device based on multilevel semantic matching, computer equipment and storage medium
CN114428850A (en) Text retrieval matching method and system
CN116756303A (en) Automatic generation method and system for multi-topic text abstract
CN114722176A (en) Intelligent question answering method, device, medium and electronic equipment
CN114676346A (en) News event processing method and device, computer equipment and storage medium
CN117094291B (en) Automatic news generation system based on intelligent writing
CN111581365B (en) Predicate extraction method
CN116720498A (en) Training method and device for text similarity detection model and related medium thereof
CN115146021A (en) Training method and device for text retrieval matching model, electronic equipment and medium
US11822887B2 (en) Robust name matching with regularized embeddings
CN116150306A (en) Training method of question-answering robot, question-answering method and device
JP7121819B2 (en) Image processing method and apparatus, electronic device, computer-readable storage medium, and computer program
CN111767388B (en) Candidate pool generation method
CN114969324A (en) Chinese news title classification method based on subject word feature expansion

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination