Disclosure of Invention
In view of the above situation, there is a need to solve the problem that conventional adversarial-learning-based methods transfer knowledge only at the feature extraction layer, so that their recognition performance is not ideal.
The embodiment of the invention provides a discourse relation identification method based on knowledge distillation and multitask learning, wherein the method comprises the following steps:
taking an implicit discourse relation instance annotated with a connective and an implicit discourse relation category as a training instance;
constructing a connective-enhanced teacher model based on a bidirectional attention mechanism classification model, and, taking the connective as additional input, iteratively minimizing the cost function corresponding to the connective-enhanced teacher model until convergence to obtain a trained teacher model;
constructing a multitask learning student model based on the bidirectional attention mechanism classification model, introducing connective classification as an auxiliary task to determine a multitask-learning-based cost function, calculating the features and prediction results of the training instances with the trained teacher model to determine a knowledge-distillation-based cost function, and then determining the total cost function of the student model;
and iteratively minimizing the total cost function of the student model until convergence so as to output the trained student model, which is then used to identify the implicit discourse relations of test instances.
The invention provides a discourse relation identification method based on knowledge distillation and multitask learning, which takes implicit discourse relation instances annotated with connectives and relation categories as training instances, with the aim of fully exploiting the information of the connectives inserted during corpus annotation. First, a connective-enhanced teacher model is constructed based on a bidirectional attention mechanism classification model; taking the connective as additional input, its cost function is iteratively minimized until convergence to obtain a trained teacher model. Then the constructed multitask learning student model is trained: a total cost function is built from the multitask learning and knowledge distillation methods and iteratively minimized until convergence, yielding the trained multitask learning student model. On the one hand, the proposed method shares knowledge between the connective classification auxiliary task and the implicit discourse relation identification main task through parameter sharing (a shared feature extraction layer); on the other hand, based on the knowledge distillation technique, the knowledge in the connective-enhanced teacher model is transferred from both the feature extraction layer and the classification layer to the corresponding implicit discourse relation identification model (the multitask learning student model). The recognition performance of the student model is thus improved by fully utilizing the connective information inserted during corpus annotation. The proposed method achieves better recognition performance than comparable methods on the first-level and second-level implicit discourse relations of the widely used PDTB data set.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the training instance, i.e., an implicit discourse relation instance annotated with a connective and an implicit discourse relation category, is represented as $(x^1, x^2, c, y)$, wherein $x^1$ and $x^2$ represent the two arguments of the implicit discourse relation training instance, $c$ represents the annotated connective, and $y$ represents the annotated implicit discourse relation category.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the input of the connective-enhanced teacher model is $(x^1, x^2, c)$ and the corresponding cost function is expressed as:

$$J_t(\theta_t) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{y}^{\top}\log \hat{y}_t\right]$$

wherein $\theta_t$ denotes the parameters of the teacher model, $\mathbf{y}$ is the one-hot encoding corresponding to the annotated implicit discourse relation category $y$, $\mathbb{E}[\cdot]$ denotes the expectation of the prediction result with respect to the labeled categories over the training instances, $\hat{y}_t$ denotes the prediction result obtained after the classification layer of the connective-enhanced teacher model, and $D$ is the training instance set.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that, in the multitask learning student model, the total cost function of the student model is expressed as:

$$J(\theta_s) = \lambda_1 J_{rel}(\theta_s) + \lambda_2 J_{conn}(\theta_s) + \lambda_3 J_{feat}(\theta_s) + \lambda_4 J_{cls}(\theta_s)$$

wherein $J(\theta_s)$ is the total cost function of the student model, $\theta_s$ denotes the parameters of the student model, and $\lambda_1, \lambda_2, \lambda_3, \lambda_4$ are the weight coefficients of the multitask-learning-based cost function and the knowledge-distillation-based cost function respectively; the multitask-learning-based cost function comprises two parts: $J_{rel}$, the cross-entropy cost function corresponding to implicit discourse relation identification, and $J_{conn}$, the cross-entropy cost function corresponding to connective classification; the knowledge-distillation-based cost function comprises two parts: $J_{feat}$, the cost function corresponding to feature extraction layer knowledge distillation, and $J_{cls}$, the cost function corresponding to classification layer knowledge distillation.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the input of the multitask learning student model is $(x^1, x^2)$ and the cross-entropy cost function corresponding to implicit discourse relation identification is expressed as:

$$J_{rel}(\theta_s) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{y}^{\top}\log \hat{y}_s\right]$$

wherein $\theta_s$ denotes the parameters of the student model, $\mathbf{y}$ is the one-hot encoding corresponding to the annotated implicit discourse relation category $y$, $\mathbb{E}[\cdot]$ denotes the expectation over the training instance set, $\hat{y}_s$ denotes the prediction result corresponding to implicit discourse relation identification obtained after classification layer 1 of the student model, and $D$ is the training instance set.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the cross-entropy cost function corresponding to connective classification in the multitask learning student model is expressed as:

$$J_{conn}(\theta_s) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{c}^{\top}\log \hat{c}_s\right]$$

wherein $\theta_s$ denotes the parameters of the student model, $\mathbf{c}$ is the one-hot encoding corresponding to the annotated connective $c$, $\mathbb{E}[\cdot]$ denotes the expectation over the training instance set, $\hat{c}_s$ denotes the prediction result corresponding to connective classification obtained after classification layer 2 of the student model, and $D$ is the training instance set.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the cost function corresponding to feature extraction layer knowledge distillation in the multitask learning student model is expressed as:

$$J_{feat}(\theta_s) = \mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathrm{MSE}(f_t, f_s)\right]$$

wherein $\mathrm{MSE}(\cdot,\cdot)$ denotes the mean squared error, $f_t$ denotes the features obtained after the feature extraction layer of the connective-enhanced teacher model, $f_s$ denotes the features obtained after the feature extraction layer of the multitask learning student model, and $D$ is the training instance set.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the cost function corresponding to classification layer knowledge distillation in the multitask learning student model is expressed as:

$$J_{cls}(\theta_s) = \mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathrm{KL}\left(\hat{y}_t \,\|\, \hat{y}_s\right)\right]$$

wherein $\mathrm{KL}(\cdot\,\|\,\cdot)$ denotes the KL divergence between two probability distributions, $\hat{y}_t$ denotes the prediction result obtained after the classification layer of the connective-enhanced teacher model, and $\hat{y}_s$ denotes the prediction result obtained after classification layer 1 of the multitask learning student model.
The discourse relation identification method based on knowledge distillation and multitask learning is characterized in that the bidirectional attention mechanism classification model comprises an encoding layer, an interaction layer, an aggregation layer and a classification layer, wherein the encoding layer is used to learn the contextual representation of the words in the arguments and is expressed as:

$$\bar{x}^1_i = \mathrm{BiLSTM}_1(x^1, i),\quad i \in \{1,\dots,\ell_1\}$$
$$\bar{x}^2_j = \mathrm{BiLSTM}_2(x^2, j),\quad j \in \{1,\dots,\ell_2\}$$

wherein $x^1_i$ and $\bar{x}^1_i$ are respectively the word vector of the $i$-th word in argument 1 and its representation in context, $x^2_j$ and $\bar{x}^2_j$ are respectively the word vector of the $j$-th word in argument 2 and its representation in context, $\ell_1$ and $\ell_2$ are respectively the numbers of words in the two arguments, and $\mathrm{BiLSTM}_1$ and $\mathrm{BiLSTM}_2$ are both bidirectional long short-term memory networks.
The invention also provides a discourse relation identification device based on knowledge distillation and multitask learning, wherein the device comprises:
a training input module, configured to take an implicit discourse relation instance annotated with a connective and an implicit discourse relation category as a training instance;
a first construction module, configured to construct a connective-enhanced teacher model based on a bidirectional attention mechanism classification model and, taking the connective as additional input, iteratively minimize the cost function corresponding to the connective-enhanced teacher model until convergence to obtain a trained teacher model;
a second construction module, configured to construct a multitask learning student model based on the bidirectional attention mechanism classification model, introduce connective classification as an auxiliary task to determine a multitask-learning-based cost function, calculate the features and prediction results of the training instances with the trained teacher model to determine a knowledge-distillation-based cost function, and then determine the total cost function of the student model;
and a training output module, configured to iteratively minimize the total cost function of the student model until convergence so as to output the trained student model, which is then used to identify the implicit discourse relations of test instances.
Additional aspects and advantages of the invention will be set forth in part in the description which follows and, in part, will be obvious from the description, or may be learned by practice of the invention.
Detailed Description
Reference will now be made in detail to embodiments of the present invention, examples of which are illustrated in the accompanying drawings, wherein like or similar reference numerals refer to the same or similar elements or elements having the same or similar function throughout. The embodiments described below with reference to the accompanying drawings are illustrative only for the purpose of explaining the present invention, and are not to be construed as limiting the present invention.
These and other aspects of embodiments of the invention will be apparent with reference to the following description and attached drawings. In the description and drawings, particular embodiments of the invention have been disclosed in detail as being indicative of some of the ways in which the principles of the embodiments of the invention may be practiced, but it is understood that the scope of the embodiments of the invention is not limited correspondingly. On the contrary, the embodiments of the invention include all changes, modifications and equivalents coming within the spirit and terms of the claims appended hereto.
The existing adversarial-learning-based methods make insufficient use of connective information: they transfer knowledge only at the feature extraction layer, and their recognition performance is not ideal.
To solve this technical problem, the present invention provides a discourse relation identification method based on knowledge distillation and multitask learning. Referring to fig. 1 to 3, the method includes the following steps:
S101, taking an implicit discourse relation instance annotated with a connective and an implicit discourse relation category as a training instance.
Specifically, any implicit discourse relation training instance in the corpus annotated with a connective and a relation category can be represented as $(x^1, x^2, c, y)$, wherein $x^1$ and $x^2$ represent the two arguments of the implicit discourse relation training instance, $c$ represents the connective inserted during annotation, i.e., the true connective label, and $y$ represents the annotated implicit discourse relation category, i.e., the true category label. For example, for the argument pair "He was exhausted" / "He went to bed early", an annotator might insert the connective "so" and label the relation category Contingency.
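By way of illustration, one such annotated instance could be held in a simple container like the following Python sketch; the field names are assumptions made for the code sketches later in this description, not part of the claimed method:

```python
from typing import NamedTuple

# Illustrative container for one annotated training instance (x^1, x^2, c, y).
class Instance(NamedTuple):
    arg1: str        # argument x^1
    arg2: str        # argument x^2
    connective: str  # connective c inserted during annotation, e.g. "so"
    relation: str    # annotated relation category y, e.g. "Contingency"
```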
S102, constructing a connective-enhanced teacher model based on the bidirectional attention mechanism classification model, and, taking the connective as additional input, iteratively minimizing the cost function corresponding to the connective-enhanced teacher model until convergence to obtain a trained teacher model.
It should be noted that the teacher model is a connective-enhanced implicit discourse relation identification model that takes the arguments $(x^1, x^2)$ and the connective $c$ inserted during annotation as input. The features of the teacher model obtained after the feature extraction layer are denoted $f_t$, and the prediction result of the teacher model obtained after the classification layer is denoted $\hat{y}_t$.
When training the teacher model, the teacher model cost function (a cross-entropy classification cost function) is minimized on the training corpus. The teacher model cost function is expressed as:

$$J_t(\theta_t) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{y}^{\top}\log \hat{y}_t\right]$$

wherein $\theta_t$ denotes the parameters of the teacher model, $y$ is the annotated implicit discourse relation category, $\mathbf{y}$ is the one-hot encoding (One-hot Encoding) corresponding to $y$, $c$ denotes the annotated connective, $\mathbb{E}[\cdot]$ denotes the expectation of the prediction result with respect to the labeled categories over the training instances, $\hat{y}_t$ denotes the prediction result obtained after the classification layer of the connective-enhanced teacher model, and $D$ is the training instance set.
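For illustration, a minimal PyTorch sketch of this training objective might look as follows; `teacher_model`, the tensor names, and the assumption that the connective has already been prepended to argument 2 are illustrative, not prescribed by the invention:

```python
import torch.nn.functional as F

# Hypothetical teacher objective J_t: standard cross-entropy between the
# teacher's relation logits and the annotated relation categories y.
def teacher_loss(teacher_model, arg1_ids, arg2_with_conn_ids, relation_labels):
    logits = teacher_model(arg1_ids, arg2_with_conn_ids)  # (batch, n_relations)
    # F.cross_entropy combines log-softmax with the -y^T log(y_hat_t) term
    return F.cross_entropy(logits, relation_labels)
```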
It should be added that the connective-enhanced teacher model simulates the process by which human annotators label implicit discourse relations. With the assistance of the inserted connective $c$, its recognition performance is far higher than that of the multitask learning student model, which takes only the arguments $(x^1, x^2)$ as input (for example, its accuracy on the first-level implicit discourse relation classification task of the PDTB corpus can reach more than 85%). This fully shows that the connective-enhanced teacher model effectively fuses the connective information inserted during corpus annotation.
S103, constructing a multitask learning student model based on the bidirectional attention mechanism classification model, introducing connective classification as an auxiliary task to determine a multitask-learning-based cost function, calculating the features and prediction results of the training instances with the trained teacher model to determine a knowledge-distillation-based cost function, and then determining the total cost function of the student model.
The multitask learning student model is a discourse relation identification model based on multitask learning. Connective classification serves as the auxiliary task: given an implicit discourse relation instance $(x^1, x^2)$, predict a connective suitable for connecting the two arguments. Implicit discourse relation identification serves as the main task. The models of the two related tasks (the implicit discourse relation identification task and the connective classification task) share the feature extraction layer but use their own classification layers. Specifically, referring to fig. 3, classification layer 1 is used for the implicit discourse relation identification task and classification layer 2 is used for the connective classification task. Through the shared feature extraction layer, the models of the two related tasks can exchange information and thereby promote each other. The multitask learning student model takes only the arguments $(x^1, x^2)$ as input; the student features obtained after the shared feature extraction layer are denoted $f_s$, the prediction result obtained after classification layer 1 (corresponding to implicit discourse relation identification) is denoted $\hat{y}_s$, and the prediction result obtained after classification layer 2 (corresponding to connective classification) is denoted $\hat{c}_s$.
When training the multitask learning student model, in order to fit the training instances $(x^1, x^2, c, y) \in D$ as closely as possible, the multitask-learning-based cost function is minimized, i.e., the cross-entropy classification cost function corresponding to implicit discourse relation identification and the cross-entropy classification cost function corresponding to connective classification are minimized simultaneously.
Specifically, the cross-entropy classification cost function corresponding to implicit discourse relation identification is expressed as:

$$J_{rel}(\theta_s) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{y}^{\top}\log \hat{y}_s\right]$$

wherein $\theta_s$ denotes the parameters of the student model, $y$ is the annotated implicit discourse relation category, $\mathbf{y}$ is the one-hot encoding corresponding to $y$, $\mathbb{E}[\cdot]$ denotes the expectation over the training instances, $\hat{y}_s$ denotes the prediction result about the implicit discourse relation obtained after classification layer 1 of the student model, and $D$ is the training instance set.
The cross-entropy classification cost function corresponding to connective classification is expressed as:

$$J_{conn}(\theta_s) = -\mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathbf{c}^{\top}\log \hat{c}_s\right]$$

wherein $\theta_s$ denotes the parameters of the multitask learning student model, $c$ is the annotated connective, $\mathbf{c}$ is the one-hot encoding corresponding to $c$, $\mathbb{E}[\cdot]$ denotes the expectation over the training instances, $\hat{c}_s$ denotes the prediction result about the connective obtained after classification layer 2 of the student model, and $D$ is the training instance set.
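A minimal PyTorch sketch of these two cross-entropy terms follows; the `extract_features`, `classifier1`, and `classifier2` attributes are assumed names for the shared feature extraction layer and the two classification layers, chosen for illustration only:

```python
import torch.nn.functional as F

# Hypothetical multitask losses: the student computes shared features f_s,
# then relation logits (classification layer 1) and connective logits
# (classification layer 2) from them.
def multitask_losses(student, arg1_ids, arg2_ids, relation_labels, conn_labels):
    f_s = student.extract_features(arg1_ids, arg2_ids)
    rel_logits = student.classifier1(f_s)     # classification layer 1
    conn_logits = student.classifier2(f_s)    # classification layer 2
    j_rel = F.cross_entropy(rel_logits, relation_labels)   # J_rel
    j_conn = F.cross_entropy(conn_logits, conn_labels)     # J_conn
    return j_rel, j_conn
```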
In order to learn, from the teacher model, classification knowledge that integrates the connective information, the invention adopts a knowledge distillation method, whose basic idea is to make the student model imitate the behavior of the teacher model as closely as possible.
On the one hand, it is desirable that the features learned by the multitask learning student model and by the connective-enhanced teacher model, $f_s$ and $f_t$, be as close as possible, thereby realizing knowledge transfer between the two models at the feature extraction layer. Since the recognition performance of the teacher model on the PDTB data set is much higher than that of the student model, the teacher features $f_t$ contain more information useful for implicit discourse relation identification than the student features $f_s$.
Specifically, the cost function corresponding to feature extraction layer knowledge distillation in the student model is defined as:

$$J_{feat}(\theta_s) = \mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathrm{MSE}(f_t, f_s)\right]$$

wherein $\mathrm{MSE}(\cdot,\cdot)$ denotes the mean squared error, $\theta_s$ denotes the parameters of the student model, $f_t$ denotes the features obtained after the feature extraction layer of the connective-enhanced teacher model, $f_s$ denotes the features obtained after the feature extraction layer of the multitask learning student model, and $D$ is the training instance set.
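In code, this term reduces to a mean-squared-error loss; a sketch (assuming `f_t` and `f_s` are the pooled feature tensors of teacher and student) could be:

```python
import torch.nn.functional as F

# Sketch of J_feat: MSE between teacher features f_t and student features f_s.
# f_t is detached so gradients flow only into the student; the teacher is
# already trained and frozen at this stage.
def feature_distillation_loss(f_t, f_s):
    return F.mse_loss(f_s, f_t.detach())
```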
On the other hand, it is desirable that the final prediction results of the multitask learning student model and of the connective-enhanced teacher model, $\hat{y}_s$ and $\hat{y}_t$, be as close as possible, thereby realizing knowledge transfer between the two models at the classification layer. The true category label represented by the one-hot encoding $\mathbf{y}$ can be regarded as a hard label (Hard Label), while the prediction result $\hat{y}_t$ of the teacher model can be regarded as a soft label (Soft Label); soft labels are generally considered to carry more category information, for example similarity information between categories. Specifically, the cost function corresponding to classification layer knowledge distillation in the multitask learning student model is defined as:
$$J_{cls}(\theta_s) = \mathbb{E}_{(x^1,x^2,c,y)\in D}\left[\mathrm{KL}\left(\hat{y}_t \,\|\, \hat{y}_s\right)\right]$$

wherein $\mathrm{KL}(\cdot\,\|\,\cdot)$ denotes the Kullback-Leibler (KL) divergence between two probability distributions, $(x^1, x^2, c, y)$ denotes an implicit discourse relation training instance carrying the connective information, $\hat{y}_t$ denotes the prediction result obtained after the classification layer of the connective-enhanced teacher model, and $\hat{y}_s$ denotes the prediction result obtained after classification layer 1 of the multitask learning student model.
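As an illustrative sketch, this KL term can be computed in PyTorch as follows, assuming the logits come from the teacher's classification layer and the student's classification layer 1; the text does not mention a distillation temperature, so none is applied here:

```python
import torch.nn.functional as F

# Sketch of J_cls: KL(y_hat_t || y_hat_s) between the teacher's soft labels
# and the student's relation predictions.
def classification_distillation_loss(teacher_logits, student_rel_logits):
    teacher_probs = F.softmax(teacher_logits.detach(), dim=-1)   # soft labels
    student_log_probs = F.log_softmax(student_rel_logits, dim=-1)
    # F.kl_div expects log-probabilities as its first argument
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
```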
Finally, the total cost function of the multitask learning student model is defined as a linear sum of the multitask-learning-based cost function and the knowledge-distillation-based cost function.
Specifically, the total cost function of the multitask learning student model is expressed as:

$$J(\theta_s) = \lambda_1 J_{rel}(\theta_s) + \lambda_2 J_{conn}(\theta_s) + \lambda_3 J_{feat}(\theta_s) + \lambda_4 J_{cls}(\theta_s)$$

wherein $\theta_s$ denotes the parameters of the student model and $\lambda_1, \lambda_2, \lambda_3, \lambda_4$ are the weight coefficients of the multitask-learning-based cost function and the knowledge-distillation-based cost function respectively; the multitask-learning-based cost function comprises two parts: $J_{rel}$, the cross-entropy cost function corresponding to implicit discourse relation identification, and $J_{conn}$, the cross-entropy cost function corresponding to connective classification; the knowledge-distillation-based cost function comprises two parts: $J_{feat}$, the cost function corresponding to feature extraction layer knowledge distillation, and $J_{cls}$, the cost function corresponding to classification layer knowledge distillation.
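Combining the four terms is then a plain weighted sum; in code (the equal default weights below are placeholders, since the text does not fix the coefficient values):

```python
# Sketch of the total student cost J: a weighted linear sum of the four terms.
def total_student_loss(j_rel, j_conn, j_feat, j_cls,
                       lambda1=1.0, lambda2=1.0, lambda3=1.0, lambda4=1.0):
    return (lambda1 * j_rel + lambda2 * j_conn
            + lambda3 * j_feat + lambda4 * j_cls)
```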
S104, iteratively minimizing the total cost function of the student model until convergence, so as to output the trained student model, which is then used to identify the implicit discourse relations of test instances.
Algorithm 1 describes the training process of the discourse relation identification method based on knowledge distillation and multitask learning.
Specifically, the whole training process is divided into two stages: the first stage trains the connective-enhanced teacher model based on the cost function $J_t$ (steps 1-5), and the second stage trains the multitask learning student model based on the total cost function $J$ (steps 6-12). For simplicity, the step of judging model convergence on a validation data set is omitted in Algorithm 1; the finally trained multitask learning student model is the required implicit discourse relation identification model. A condensed code sketch follows the algorithm listing.
Algorithm 1 Training algorithm
Input: training instance set $D$, maximum number of training rounds $T$
Output: trained multitask learning student model
1. Construct the teacher model and randomly initialize its parameters $\theta_t$
2. Repeat:
3.   Take a batch of instances $B$ from the training instance set $D$
4.   Minimize the connective-enhanced teacher model cost function $J_t$ and update the parameters $\theta_t$
5. Until: model convergence or the maximum number of training rounds $T$ is reached
6. Construct the multitask learning student model and randomly initialize its parameters $\theta_s$
7. Repeat:
8.   Take a batch of instances $B$ from the training instance set $D$
9.   Calculate the corresponding features $f_t$ based on the trained connective-enhanced teacher model
10.  Calculate the corresponding prediction results $\hat{y}_t$ based on the trained connective-enhanced teacher model
11.  Minimize the multitask learning student model total cost function $J$ and update the parameters $\theta_s$
12. Until: model convergence or the maximum number of training rounds $T$ is reached
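Below is a condensed, hypothetical PyTorch rendering of Algorithm 1. It assumes the loss helpers sketched earlier, a teacher exposing `extract_features`/`classify` methods, and a loader yielding (argument 1, argument 2, connective label, connective-augmented argument 2, relation label) batches; validation-based convergence checks are omitted, as in Algorithm 1:

```python
import torch
import torch.nn.functional as F

def train(teacher, student, loader, epochs_t, epochs_s, lr=1e-3):
    # Stage 1 (steps 1-5): train the connective-enhanced teacher on J_t.
    opt_t = torch.optim.Adam(teacher.parameters(), lr=lr)
    for _ in range(epochs_t):
        for arg1, arg2, conn, arg2_conn, y in loader:
            loss = F.cross_entropy(teacher(arg1, arg2_conn), y)
            opt_t.zero_grad()
            loss.backward()
            opt_t.step()

    # Stage 2 (steps 6-12): train the multitask student on the total cost J.
    teacher.eval()
    opt_s = torch.optim.Adam(student.parameters(), lr=lr)
    for _ in range(epochs_s):
        for arg1, arg2, conn, arg2_conn, y in loader:
            with torch.no_grad():  # steps 9-10: teacher features and soft labels
                f_t = teacher.extract_features(arg1, arg2_conn)
                teacher_logits = teacher.classify(f_t)
            f_s = student.extract_features(arg1, arg2)
            rel_logits = student.classifier1(f_s)
            conn_logits = student.classifier2(f_s)
            loss = total_student_loss(
                F.cross_entropy(rel_logits, y),        # J_rel
                F.cross_entropy(conn_logits, conn),    # J_conn
                feature_distillation_loss(f_t, f_s),   # J_feat
                classification_distillation_loss(teacher_logits, rel_logits))
            opt_s.zero_grad()
            loss.backward()
            opt_s.step()
    return student
```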
Meanwhile, in the present invention, the above-mentioned bidirectional attention mechanism classification model is of the kind commonly used to model semantic relations between two sentences, in tasks such as textual entailment recognition, automatic question answering, and sentence semantic matching.
Referring to fig. 4, the bidirectional attention mechanism classification model specifically includes an encoding layer, an interaction layer, an aggregation layer and a classification layer; the feature extraction layer is composed of the encoding layer, the interaction layer and the aggregation layer. The encoding layer is used to learn the contextual representations of the words in the arguments and is expressed as:
$$\bar{x}^1_i = \mathrm{BiLSTM}_1(x^1, i),\quad i \in \{1,\dots,\ell_1\}$$
$$\bar{x}^2_j = \mathrm{BiLSTM}_2(x^2, j),\quad j \in \{1,\dots,\ell_2\}$$

wherein $x^1_i$ and $\bar{x}^1_i$ are respectively the word vector of the $i$-th word in argument 1 and its representation in context, $x^2_j$ and $\bar{x}^2_j$ are respectively the word vector of the $j$-th word in argument 2 and its representation in context, $\ell_1$ and $\ell_2$ are respectively the numbers of words in the two arguments, and $\mathrm{BiLSTM}_1$ and $\mathrm{BiLSTM}_2$ are both bidirectional long short-term memory networks.
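A minimal sketch of such an encoding layer in PyTorch, with illustrative vocabulary and dimension sizes:

```python
import torch.nn as nn

# Sketch of the encoding layer: a bidirectional LSTM produces the contextual
# representation of each word vector.
class Encoder(nn.Module):
    def __init__(self, vocab_size=30000, emb_dim=300, hidden_dim=300):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.bilstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True,
                              bidirectional=True)

    def forward(self, token_ids):        # (batch, seq_len)
        x = self.embed(token_ids)        # word vectors x_i
        x_bar, _ = self.bilstm(x)        # contextual representations x_bar_i
        return x_bar                     # (batch, seq_len, 2 * hidden_dim)
```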
The interaction layer is expressed as:

$$e_{ij} = F(\bar{x}^1_i)^{\top} F(\bar{x}^2_j)$$
$$\tilde{x}^1_i = \sum_{j=1}^{\ell_2} \frac{\exp(e_{ij})}{\sum_{j'=1}^{\ell_2}\exp(e_{ij'})}\,\bar{x}^2_j, \qquad \tilde{x}^2_j = \sum_{i=1}^{\ell_1} \frac{\exp(e_{ij})}{\sum_{i'=1}^{\ell_1}\exp(e_{i'j})}\,\bar{x}^1_i$$
$$m^1_i = G\left([\bar{x}^1_i; \tilde{x}^1_i]\right), \qquad m^2_j = G\left([\bar{x}^2_j; \tilde{x}^2_j]\right)$$

wherein $F$ is a fully connected multi-layer feedforward neural network, $e_{ij}$ is the relevance weight between the $i$-th word of argument 1 and the $j$-th word of argument 2, $\tilde{x}^1_i$ is the representation of the words in argument 2 related to the $i$-th word of argument 1, $\tilde{x}^2_j$ is the representation of the words in argument 1 related to the $j$-th word of argument 2, $G$ is another fully connected multi-layer feedforward neural network, $[\cdot;\cdot]$ denotes the concatenation of representation vectors, and $m^1_i$ and $m^2_j$ can be regarded as the learned local semantic relation representations.
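The interaction layer can be sketched as follows; the `F` and `G` modules correspond to the two feedforward networks in the formulas above, and all dimensions are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Sketch of the interaction layer: relevance weights e_ij, soft alignment of
# each argument against the other, and projection G of the concatenated
# local representations.
class Interaction(nn.Module):
    def __init__(self, dim=600):
        super().__init__()
        self.F = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.G = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU())

    def forward(self, x1_bar, x2_bar):   # (batch, len1, dim), (batch, len2, dim)
        e = torch.bmm(self.F(x1_bar), self.F(x2_bar).transpose(1, 2))
        # normalize over argument 2 (dim=2) to align arg2 words to each arg1 word
        x1_tilde = torch.bmm(torch.softmax(e, dim=2), x2_bar)
        # normalize over argument 1 (dim=1) for the reverse direction
        x2_tilde = torch.bmm(torch.softmax(e, dim=1).transpose(1, 2), x1_bar)
        m1 = self.G(torch.cat([x1_bar, x1_tilde], dim=-1))  # local relations m^1
        m2 = self.G(torch.cat([x2_bar, x2_tilde], dim=-1))  # local relations m^2
        return m1, m2
```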
The aggregation layer calculates the global semantic relation representation $f$ from the local semantic relation representations $m^1$ and $m^2$ by pooling over the word positions of the two arguments. This representation $f$ is the feature extracted by the feature extraction layer, denoted $f_s$ in the student model and $f_t$ in the teacher model respectively.
In addition, the classification layer is used to calculate the final classification result, specifically:

$$\hat{y} = H(f)$$

wherein $H$ is composed of a fully connected multi-layer feedforward neural network and a softmax layer, and $\hat{y}$ is the final classification result.
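The following sketch combines the aggregation and classification layers. The max-pooling aggregation is an assumption made for illustration (the text specifies only that a global representation $f$ is computed from the local ones), and the dimensions and class count are likewise illustrative; in the training sketches above, the softmax is folded into the loss and logits are used instead:

```python
import torch
import torch.nn as nn

# Sketch of aggregation (pooling m^1, m^2 into f) plus classification H.
class AggregateAndClassify(nn.Module):
    def __init__(self, dim=600, n_classes=4):
        super().__init__()
        self.H = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(),
                               nn.Linear(dim, n_classes))

    def forward(self, m1, m2):
        # max-pool over word positions, then concatenate (an assumed choice)
        f = torch.cat([m1.max(dim=1).values, m2.max(dim=1).values], dim=-1)
        return f, torch.softmax(self.H(f), dim=-1)   # features f, prediction y_hat
```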
For the connective-enhanced teacher model, the model can be constructed directly from the bidirectional attention mechanism classification model; only the input needs to be enhanced with the connective, i.e., the input of the model becomes $(x^1, x^2, c)$. Specifically, the connective $c$ is spliced to the beginning of argument 2 to form a new argument 2. The learned features are denoted $f_t$ and the prediction result is denoted $\hat{y}_t$.
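In code, this connective enhancement reduces to prepending the connective's token ids to argument 2; a sketch under the same assumed tensor conventions (padding and truncation details omitted):

```python
import torch

# Form the "new argument 2" fed to the teacher: connective c before x^2.
def prepend_connective(conn_ids, arg2_ids):
    return torch.cat([conn_ids, arg2_ids], dim=1)  # (batch, conn_len + len2)
```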
For the multitask learning student model, the bidirectional attention mechanism classification model needs only a simple extension: the implicit discourse relation identification task and the connective classification task share the feature extraction layer but use separate classification layers. Specifically, for an input instance $(x^1, x^2)$, the features obtained through the shared feature extraction layer are $f_s$; the prediction result corresponding to implicit discourse relation identification is then calculated based on classification layer 1 as:

$$\hat{y}_s = H_1(f_s)$$

wherein $H_1$ is composed of a fully connected multi-layer feedforward neural network and a softmax layer; and the prediction result corresponding to connective classification is calculated based on classification layer 2 as:

$$\hat{c}_s = H_2(f_s)$$

wherein $H_2$ is composed of a fully connected multi-layer feedforward neural network and a softmax layer.
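A sketch of this shared-extractor, two-head structure (all names, dimensions, and label-set sizes are illustrative; single linear heads stand in for the multi-layer networks $H_1$ and $H_2$, with the softmax folded into the losses):

```python
import torch.nn as nn

# Sketch of the multitask student: one shared feature extractor, two heads.
class Student(nn.Module):
    def __init__(self, extractor, feat_dim=1200, n_relations=4, n_connectives=100):
        super().__init__()
        self.extractor = extractor  # encoding + interaction + aggregation layers
        self.classifier1 = nn.Linear(feat_dim, n_relations)    # classification layer 1
        self.classifier2 = nn.Linear(feat_dim, n_connectives)  # classification layer 2

    def extract_features(self, arg1_ids, arg2_ids):
        return self.extractor(arg1_ids, arg2_ids)              # shared features f_s

    def forward(self, arg1_ids, arg2_ids):
        f_s = self.extract_features(arg1_ids, arg2_ids)
        return self.classifier1(f_s), self.classifier2(f_s)    # y_hat_s, c_hat_s
```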
The invention provides a discourse relation identification method based on knowledge distillation and multitask learning, which takes implicit discourse relation instances annotated with connectives and relation categories as training instances, with the aim of fully exploiting the information of the connectives inserted during corpus annotation. First, a connective-enhanced teacher model is constructed based on a bidirectional attention mechanism classification model; taking the connective as additional input, its cost function is iteratively minimized until convergence to obtain a trained teacher model. Then the constructed multitask learning student model is trained: a total cost function is built from the multitask learning and knowledge distillation methods and iteratively minimized until convergence, yielding the trained multitask learning student model.
According to the discourse relation identification method based on knowledge distillation and multitask learning, on the one hand, knowledge is shared between the connective classification auxiliary task and the implicit discourse relation identification main task through parameter sharing (a shared feature extraction layer); on the other hand, based on the knowledge distillation technique, the knowledge in the connective-enhanced teacher model is transferred from both the feature extraction layer and the classification layer to the corresponding implicit discourse relation identification model (the multitask learning student model), so that the recognition performance of the student model is improved by fully utilizing the connective information inserted during corpus annotation. The proposed method achieves better recognition performance than comparable methods on the first-level and second-level implicit discourse relations of the widely used PDTB data set.
Referring to fig. 5, a discourse relation identification device based on knowledge distillation and multitask learning according to the second embodiment of the present invention includes a training input module 111, a first construction module 112, a second construction module 113, and a training output module 114, which are connected in sequence;
wherein the training input module 111 is specifically configured to:
take an implicit discourse relation instance annotated with a connective and an implicit discourse relation category as a training instance;
the first construction module 112 is specifically configured to:
construct a connective-enhanced teacher model based on a bidirectional attention mechanism classification model and, taking the connective as additional input, iteratively minimize the cost function corresponding to the connective-enhanced teacher model until convergence to obtain a trained teacher model;
the second construction module 113 is specifically configured to:
construct a multitask learning student model based on the bidirectional attention mechanism classification model, introduce connective classification as an auxiliary task to determine a multitask-learning-based cost function, calculate the features and prediction results of the training instances with the trained teacher model to determine a knowledge-distillation-based cost function, and then determine the total cost function of the student model;
the training output module 114 is specifically configured to:
iteratively minimize the total cost function of the student model until convergence so as to output the trained student model, which is then used to identify the implicit discourse relations of test instances.
In the description herein, references to the description of the term "one embodiment," "some embodiments," "an example," "a specific example," or "some examples," etc., mean that a particular feature, structure, material, or characteristic described in connection with the embodiment or example is included in at least one embodiment or example of the invention. In this specification, the schematic representations of the terms used above do not necessarily refer to the same embodiment or example. Furthermore, the particular features, structures, materials, or characteristics described may be combined in any suitable manner in any one or more embodiments or examples.
The above-described embodiments express only several implementations of the present invention, and their description is relatively specific and detailed, but they should not be construed as limiting the scope of the invention. It should be noted that a person skilled in the art can make several variations and modifications without departing from the inventive concept, all of which fall within the protection scope of the present invention. Therefore, the protection scope of this patent shall be subject to the appended claims.