CN114861885A - Knowledge distillation-based model training method, related equipment and readable storage medium - Google Patents



Publication number
CN114861885A
Authority
CN
China
Prior art keywords
model
student
student model
training
loss
Legal status
Pending
Application number
CN202210550915.4A
Other languages
Chinese (zh)
Inventor
唐海桃
王智国
方昕
李永超
Current Assignee
iFlytek Co Ltd
Original Assignee
iFlytek Co Ltd
Application filed by iFlytek Co Ltd
Priority to CN202210550915.4A
Publication of CN114861885A
Legal status: Pending

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Filters That Use Time-Delay Elements (AREA)

Abstract

The application discloses a knowledge distillation-based model training method, related equipment and a readable storage medium. After a teacher model, a student model, training data and labels of the training data are acquired, the student model to be trained is trained with the training data as training samples and with the following training targets: the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data. This yields the trained student model. During training, both the intermediate-layer outputs and the final output of the teacher model guide the learning of the student model, so that the output of the student model's intermediate network layer is as close as possible to that of the teacher model's intermediate network layer, and the final output of the student model is as close as possible to the final output of the teacher model.

Description

Knowledge distillation-based model training method, related equipment and readable storage medium
Technical Field
The application relates to the technical field of neural networks, in particular to a knowledge distillation-based model training method, related equipment and a readable storage medium.
Background
Knowledge distillation is a teacher-student approach to model compression: a large teacher model guides the training of a small student model, thereby transferring knowledge from the teacher to the student. In the conventional knowledge distillation-based training method, the teacher model is trained first, and the student model is then trained using the teacher model's final output together with the labels of the training samples. The student model thus learns to identify the correct class from the training samples and learns the relationships between classes from the teacher model.
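The conventional objective described above can be sketched as follows. This is a minimal illustration of the baseline, not the application's method: it assumes a single classification output per sample, and the mixing weight `alpha`, the `temperature` and the temperature-squared scaling are common conventions assumed here, not taken from the source.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of a list of logits, softened by a temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]

def conventional_kd_loss(student_logits, teacher_logits, label,
                         alpha=0.5, temperature=2.0):
    """Conventional distillation: a hard-label term plus a soft-label term that
    uses only the teacher's final output (no intermediate-layer term).

    alpha and temperature are illustrative hyperparameters (assumptions).
    """
    # Hard-label term: learn the correct class from the training label.
    hard = -log_softmax(student_logits)[label]
    # Soft-label term: KL(teacher || student) on temperature-softened outputs,
    # which carries the teacher's view of how the classes relate.
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    soft = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # Scaling by temperature**2 keeps the soft-term gradients comparable in size.
    return (1 - alpha) * hard + alpha * temperature ** 2 * soft
```

When student and teacher logits coincide, the soft term vanishes and only the weighted hard-label loss remains.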
However, in some scenarios both the teacher model and the student model contain an intermediate network layer, and the final output of each model depends on the output of its intermediate network layer. For example, when a streaming end-to-end speech recognition model is trained with the conventional knowledge distillation method, the teacher model and the student model each include an encoder, a decoder and a joint network layer; the encoder and the decoder serve as intermediate network layers, the output of the joint network layer is the final output of the model, and that output depends on the outputs of the encoder and the decoder. In such scenarios, training the student model using only the teacher model's final output and the labels of the training samples may leave the output of the student model's intermediate network layer dissimilar to that of the teacher model's intermediate network layer, which in turn may prevent the student model's final output from getting close to the teacher model's final output.
Therefore, how to optimize the conventional knowledge distillation-based model training method is a technical problem that those skilled in the art urgently need to solve.
Disclosure of Invention
In view of the above, the present application proposes a knowledge distillation-based model training method, related equipment and a readable storage medium. The specific scheme is as follows:
a knowledge distillation-based model training method, the method comprising:
acquiring a pre-trained teacher model, a student model to be trained, training data and labels of the training data, wherein the teacher model and the student model each comprise an intermediate network layer; and
training the student model to be trained with the training data as training samples and with the following training targets: the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain a trained student model.
Optionally, training the student model to be trained with the training data as training samples and with the above training targets, to obtain the trained student model, comprises:
constructing a student model loss function whose training targets are that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, that the final output distribution of the student model approaches that of the teacher model, and that the final output of the student model approaches the labels of the training data; and
training the student model based on the constructed student model loss function to obtain the trained student model.
Optionally, the constructed student model loss function comprises a first loss term, a second loss term and a third loss term, wherein the first loss term is used for representing an error between an output distribution of an intermediate network layer of the teacher model and an output distribution of an intermediate network layer of the student model, and the second loss term is used for representing an error between a final output distribution of the teacher model and a final output distribution of the student model; the third loss term is used to characterize an error between a final output of the student model and the label of the training data.
Optionally, the constructed student model loss function includes: a first student model loss function;
the first student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the first loss term, the second loss term, and the third loss term are the same.
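As a minimal sketch (the names `LossWeights` and `student_loss` are hypothetical), the first student model loss function is a weighted sum of the three loss terms with all coefficients equal. The same weighted sum with unequal coefficients also describes the fourth and fifth student model loss functions.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LossWeights:
    """Coefficients of the three loss terms (hypothetical container)."""
    inter: float = 1.0   # first term: intermediate-layer distribution error
    final: float = 1.0   # second term: teacher vs. student final-output error
    label: float = 1.0   # third term: student final output vs. label error

def student_loss(l_inter, l_final, l_label, w=LossWeights()):
    """Weighted sum of the three loss terms; the default weights (all 1.0)
    give the equal-coefficient first student model loss function."""
    return w.inter * l_inter + w.final * l_final + w.label * l_label

print(student_loss(0.5, 1.0, 2.0))                           # equal coefficients
print(student_loss(0.5, 1.0, 2.0, LossWeights(inter=100.0)))  # first term dominant
```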
Optionally, training the student model based on the constructed student model loss function to obtain the trained student model comprises:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model; and
determining the first student model as the trained student model;
or, alternatively,
inputting the training data into the teacher model and the first student model, and iteratively optimizing the parameters of the first student model based on the first student model loss function until training ends, to obtain the trained student model.
Optionally, the constructed student model loss function includes: a second student model loss function and a third student model loss function;
wherein the second student model loss function includes the first loss term; the third student model loss function includes the second loss term and the third loss term, and coefficients of the first loss term, the second loss term, and the third loss term are the same.
Optionally, training the student model based on the constructed student model loss function to obtain the trained student model comprises:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the second student model loss function until training ends, to obtain a second student model; and
inputting the training data into the teacher model and the second student model, and iteratively optimizing the parameters of the second student model based on the third student model loss function until training ends, to obtain the trained student model.
Optionally, the constructed student model loss function includes: a fourth student model loss function and a fifth student model loss function;
the fourth student model loss function comprises the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and the coefficient of the first loss term is much larger than the coefficients of the second loss term and the third loss term;
the fifth student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the second loss term and the third loss term are the same, and coefficients of the second loss term and the third loss term are much larger than a coefficient of the first loss term.
Optionally, training the student model based on the constructed student model loss function to obtain the trained student model comprises:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the fourth student model loss function until training ends, to obtain a third student model; and
inputting the training data into the teacher model and the third student model, and iteratively optimizing the parameters of the third student model based on the fifth student model loss function until training ends, to obtain the trained student model.
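The three training schemes differ only in which coefficients each stage uses and in whether a stage starts from the previous stage's student. The toy below makes that concrete. Everything in it is hypothetical: a two-parameter "student", made-up teacher and label targets, squared errors for all three terms, plain gradient descent, and weights of 1 versus 0.01 standing in for "much larger".

```python
# Hypothetical toy: the student's intermediate output is h = p[0] and its final
# output is z = p[0] + p[1]. The targets are made-up stand-ins for the teacher's
# intermediate output, the teacher's final output and the label.
H_TEACHER, Z_TEACHER, LABEL = 2.0, 3.0, 3.2

def loss_terms(p):
    """Return (first, second, third) loss terms as squared errors."""
    h, z = p[0], p[0] + p[1]
    return (h - H_TEACHER) ** 2, (z - Z_TEACHER) ** 2, (z - LABEL) ** 2

def gradient(p, w):
    """Gradient of w[0]*L1 + w[1]*L2 + w[2]*L3 with respect to p."""
    h, z = p[0], p[0] + p[1]
    d_h = 2 * w[0] * (h - H_TEACHER)
    d_z = 2 * w[1] * (z - Z_TEACHER) + 2 * w[2] * (z - LABEL)
    # h depends only on p[0]; z depends on both parameters.
    return [d_h + d_z, d_z]

def train_stage(p, w, steps=2000, lr=0.05):
    """Plain gradient descent on one stage's weighted loss."""
    p = list(p)
    for _ in range(steps):
        g = gradient(p, w)
        p = [pi - lr * gi for pi, gi in zip(p, g)]
    return p

# Per-stage coefficients (w1, w2, w3); a weight of 0 drops a term.
SCHEMES = {
    "equal_weights": [(1, 1, 1)],                   # first loss function (stage may be repeated)
    "inter_then_final": [(1, 0, 0), (0, 1, 1)],     # second, then third loss function
    "reweighted": [(1, 0.01, 0.01), (0.01, 1, 1)],  # fourth, then fifth loss function
}

def run_scheme(stages, p_init=(0.0, 0.0)):
    p = list(p_init)
    for w in stages:  # each stage starts from the previous stage's student
        p = train_stage(p, w)
    return p

for name, stages in SCHEMES.items():
    print(name, [round(t, 4) for t in loss_terms(run_scheme(stages))])
```

In this toy, the second stage of `inter_then_final` carries no intermediate term, so it lets the intermediate output drift away from the teacher's, whereas `reweighted` keeps a small intermediate weight in its second stage. Whether the same effect appears in real Transducer training is not something this sketch can show.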
Optionally, the error between the output distribution of the teacher model's intermediate network layer and that of the student model's intermediate network layer comprises the errors between the output distribution of each sub-layer of the teacher model's intermediate network layer and that of the corresponding sub-layer of the student model's intermediate network layer.
Optionally, the teacher model and the student model are both streaming end-to-end speech recognition models, the streaming end-to-end speech recognition model includes an encoder, a decoder and a joint network, the encoder and the decoder are intermediate network layers, and the output of the joint network is the final output of the streaming end-to-end speech recognition model.
A knowledge distillation-based model training apparatus, the apparatus comprising:
an acquisition unit, configured to acquire a pre-trained teacher model, a student model to be trained, training data and labels of the training data, wherein the teacher model and the student model each comprise an intermediate network layer; and
a training unit, configured to train the student model to be trained with the training data as training samples and with the following training targets: the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain a trained student model.
Optionally, the training unit comprises:
a student model loss function construction subunit, configured to construct a student model loss function whose training targets are that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, that the final output distribution of the student model approaches that of the teacher model, and that the final output of the student model approaches the labels of the training data; and
a training subunit, configured to train the student model based on the constructed student model loss function to obtain a trained student model.
Optionally, the constructed student model loss function comprises a first loss term, a second loss term and a third loss term, wherein the first loss term is used for representing an error between an output distribution of an intermediate network layer of the teacher model and an output distribution of an intermediate network layer of the student model, and the second loss term is used for representing an error between a final output distribution of the teacher model and a final output distribution of the student model; the third loss term is used to characterize an error between a final output of the student model and the label of the training data.
Optionally, the constructed student model loss function includes: a first student model loss function;
the first student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the first loss term, the second loss term, and the third loss term are the same.
Optionally, the training subunit is specifically configured to:
input the training data into the teacher model and the student model, and iteratively optimize the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model; and
determine the first student model as the trained student model;
or, alternatively,
input the training data into the teacher model and the first student model, and iteratively optimize the parameters of the first student model based on the first student model loss function until training ends, to obtain the trained student model.
Optionally, the constructed student model loss function includes: a second student model loss function and a third student model loss function;
wherein the second student model loss function includes the first loss term; the third student model loss function includes the second loss term and the third loss term, and coefficients of the first loss term, the second loss term, and the third loss term are the same.
Optionally, the training subunit is specifically configured to:
input the training data into the teacher model and the student model, and iteratively optimize the parameters of the student model based on the second student model loss function until training ends, to obtain a second student model; and
input the training data into the teacher model and the second student model, and iteratively optimize the parameters of the second student model based on the third student model loss function until training ends, to obtain the trained student model.
Optionally, the constructed student model loss function includes: a fourth student model loss function and a fifth student model loss function;
the fourth student model loss function comprises the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and the coefficient of the first loss term is much larger than the coefficients of the second loss term and the third loss term;
the fifth student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the second loss term and the third loss term are the same, and coefficients of the second loss term and the third loss term are much larger than a coefficient of the first loss term.
Optionally, the training subunit is specifically configured to:
input the training data into the teacher model and the student model, and iteratively optimize the parameters of the student model based on the fourth student model loss function until training ends, to obtain a third student model; and
input the training data into the teacher model and the third student model, and iteratively optimize the parameters of the third student model based on the fifth student model loss function until training ends, to obtain the trained student model.
Optionally, the error between the output distribution of the teacher model's intermediate network layer and that of the student model's intermediate network layer comprises the errors between the output distribution of each sub-layer of the teacher model's intermediate network layer and that of the corresponding sub-layer of the student model's intermediate network layer.
Optionally, the teacher model and the student model are both streaming end-to-end speech recognition models, the streaming end-to-end speech recognition model includes an encoder, a decoder and a joint network, the encoder and the decoder are intermediate network layers, and the output of the joint network is the final output of the streaming end-to-end speech recognition model.
A knowledge distillation-based model training apparatus, comprising a memory and a processor;
the memory is configured to store a program;
the processor is configured to execute the program to implement the steps of the knowledge distillation-based model training method described above.
A readable storage medium storing a computer program which, when executed by a processor, implements the steps of the knowledge distillation-based model training method described above.
With the above technical solution, the application discloses a knowledge distillation-based model training method, related equipment and a readable storage medium. In this solution, after a pre-trained teacher model, a student model to be trained, training data and labels of the training data are acquired, the student model is trained with the training data as training samples and with the following training targets: the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data. This yields the trained student model. During training, both the intermediate-layer outputs and the final output of the teacher model guide the learning of the student model, so that the output of the student model's intermediate network layer is as close as possible to that of the teacher model's intermediate network layer, and the final output of the student model is as close as possible to the final output of the teacher model.
Drawings
Various other advantages and benefits will become apparent to those of ordinary skill in the art upon reading the following detailed description of the preferred embodiments. The drawings are only for purposes of illustrating the preferred embodiments and are not to be construed as limiting the application. Also, like reference numerals are used to refer to like parts throughout the drawings. In the drawings:
FIG. 1 is a schematic flowchart of a knowledge distillation-based model training method disclosed in an embodiment of the present application;
FIG. 2 is a schematic diagram of a Transducer-based teacher model and student model according to an embodiment of the present application;
FIG. 3 is a schematic structural diagram of a knowledge distillation-based model training apparatus disclosed in an embodiment of the present application;
FIG. 4 is a block diagram of the hardware structure of a knowledge distillation-based model training apparatus disclosed in an embodiment of the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application, and it is obvious that the described embodiments are only a part of the embodiments of the present application, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present application.
Next, the knowledge distillation-based model training method provided in the present application is described through the following embodiments.
Referring to FIG. 1, which is a schematic flowchart of a knowledge distillation-based model training method disclosed in an embodiment of the present application, the method may include:
step S101: acquiring a pre-trained teacher model, a student model to be trained, training data and a label of the training data; the teacher model and the student model each include an intermediate network layer.
In this application, the teacher model and the student model may be models for any scenario. Note that the teacher model and the student model have similar structures; the difference is that the teacher model has more parameters.
As one implementation, the teacher model and the student model are both streaming end-to-end speech recognition models, each including an encoder, a decoder and a joint network; the encoder and the decoder are intermediate network layers, and the output of the joint network is the final output of the streaming end-to-end speech recognition model. In the present application, the streaming end-to-end speech recognition model may be a Transducer model.
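To show why the joint network's output depends on both intermediate network layers, here is a minimal plain-Python sketch of a Transducer joint network. The additive combination followed by tanh, the projection matrices `w_enc`, `w_dec`, `w_out` and the tiny sizes are common choices assumed for illustration, not details taken from the application.

```python
import math

def matvec(m, x):
    """Multiply matrix m (list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def log_softmax(z):
    m = max(z)
    log_norm = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_norm for v in z]

def joint_network(enc_out, dec_out, w_enc, w_dec, w_out):
    """Combine every encoder frame with every decoder state.

    enc_out: T encoder output vectors (one per acoustic frame).
    dec_out: U+1 decoder output vectors (one per label-history prefix).
    Returns a T x (U+1) grid of log-posterior vectors over the vocabulary
    (including blank); this grid is the model's final output.
    """
    enc_proj = [matvec(w_enc, f) for f in enc_out]
    dec_proj = [matvec(w_dec, g) for g in dec_out]
    return [[log_softmax(matvec(w_out, [math.tanh(a + b) for a, b in zip(pf, pg)]))
             for pg in dec_proj]
            for pf in enc_proj]

# Tiny made-up example: T = 3 frames, U + 1 = 2 decoder states, V = 4 symbols.
enc = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
dec = [[0.2, 0.0], [-0.4, 0.1]]
w_enc = [[1.0, 0.0], [0.0, 1.0]]
w_dec = [[0.5, 0.5], [-0.5, 0.5]]
w_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
grid = joint_network(enc, dec, w_enc, w_dec, w_out)
```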
For ease of understanding, refer to FIG. 2, a schematic diagram of a Transducer-based teacher model and student model according to an embodiment of the present application. In FIG. 2, the Transducer-based student model is on the left and the Transducer-based teacher model is on the right. The Transducer model consists of an encoder (the Encoder shown in FIG. 2), a decoder (the Decoder shown in FIG. 2) and a joint network (the Joint Network shown in FIG. 2). The encoder comprises N encoding layers (the encoder blocks shown in FIG. 2), the decoder comprises M decoding layers (the LSTM layers in the dashed box in FIG. 2), and the input of the joint network is determined jointly by the encoder and the decoder. As FIG. 2 shows, the Transducer-based teacher and student models are similar in structure, but their encoding layers differ in two ways. First, the Self-Attention of the teacher model (right side of FIG. 2) has a non-streaming structure, whereas the Self-Attention of the student model (left side of FIG. 2) has a streaming structure, so the teacher model has a wider field of view than the student model. Second, the intermediate mapping dimension of the teacher model's DNN (deep neural network) is larger than that of the student model. As a result, the teacher model has more parameters than the student model.
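The field-of-view difference between non-streaming and streaming Self-Attention can be illustrated with attention masks. The application does not state how the student's streaming attention is restricted; the chunk-based mask below, with a hypothetical `chunk` size, is one common choice and is assumed here purely for illustration.

```python
def full_mask(num_frames):
    """Non-streaming self-attention: each frame may attend to every frame."""
    return [[True] * num_frames for _ in range(num_frames)]

def chunk_streaming_mask(num_frames, chunk):
    """Assumed streaming variant: frame i may attend to every frame up to the end
    of its own chunk, but never to frames in later chunks."""
    return [[j < (i // chunk + 1) * chunk for j in range(num_frames)]
            for i in range(num_frames)]

def field_of_view(mask):
    """Number of frames each query frame can see."""
    return [sum(row) for row in mask]

print(field_of_view(full_mask(6)))                # [6, 6, 6, 6, 6, 6]
print(field_of_view(chunk_streaming_mask(6, 2)))  # [2, 2, 4, 4, 6, 6]
```

With `chunk=1` the mask reduces to a purely causal one; larger chunks trade latency for a wider view of future frames.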
Step S102: training the student model to be trained with the training data as training samples and with the following training targets: the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain a trained student model.
This embodiment discloses a knowledge distillation-based model training method. In the method, after a pre-trained teacher model, a student model to be trained, training data and labels of the training data are acquired, the student model is trained with the training data as training samples toward the three training targets above, yielding the trained student model. During training, both the intermediate-layer outputs and the final output of the teacher model guide the learning of the student model, so that the output of the student model's intermediate network layer is as close as possible to that of the teacher model's intermediate network layer, and the final output of the student model is as close as possible to the final output of the teacher model.
In another embodiment of the present application, a specific implementation manner of step S102 is described in detail as follows:
Step S201: constructing a student model loss function whose training targets are that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, that the final output distribution of the student model approaches that of the teacher model, and that the final output of the student model approaches the labels of the training data.
It should be noted that, in the present application, the student model loss function constructed for these training targets can take various forms. Regardless of its form, it includes a first loss term, a second loss term and a third loss term: the first loss term characterizes the error between the output distribution of the teacher model's intermediate network layer and that of the student model's intermediate network layer; the second loss term characterizes the error between the final output distribution of the teacher model and that of the student model; and the third loss term characterizes the error between the final output of the student model and the labels of the training data.
It should further be noted that the intermediate network layers of the student model and of the teacher model may each comprise a plurality of sub-layers. In one implementable embodiment, the error between the output distribution of the teacher model's intermediate network layer and that of the student model's intermediate network layer comprises the error between the output distribution of the last sub-layer of the teacher model's intermediate network layer and that of the last sub-layer of the student model's intermediate network layer.
However, the intermediate network layer is usually deep, so gradients tend to vanish and the lower sub-layers are not updated well. Considering only the error between the last sub-layer outputs of the teacher and student intermediate network layers may therefore leave the final output of the student model far from the final output of the teacher model. To address this, in another possible embodiment of the present application, the error between the output distribution of the teacher model's intermediate network layer and that of the student model's intermediate network layer comprises the errors between the output distribution of each sub-layer of the teacher model's intermediate network layer and that of the corresponding sub-layer of the student model's intermediate network layer.
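The two options (last sub-layer only versus every sub-layer) can be contrasted in a short sketch. It assumes that teacher and student have the same number of sub-layers and that corresponding sub-layer outputs have the same dimension; if they differ, some pairing rule or projection would be needed, which is not specified here. `intermediate_loss` and the MSE choice are illustrative.

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def intermediate_loss(student_layers, teacher_layers, every_layer=True):
    """First-loss-term variants over a stack of sub-layer outputs (bottom to top)."""
    if len(student_layers) != len(teacher_layers):
        raise ValueError("teacher and student must have the same number of sub-layers")
    if not every_layer:
        # Only the top sub-layer is matched; lower sub-layers get no direct signal.
        return mse(student_layers[-1], teacher_layers[-1])
    # Every sub-layer is matched, so each one receives its own training signal.
    return sum(mse(s, t) for s, t in zip(student_layers, teacher_layers))
```

For example, with student sub-layer outputs `[[0, 0], [1, 1]]` and teacher outputs `[[1, 1], [1, 1]]`, the last-sub-layer loss is 0 even though the bottom sub-layer disagrees, while the every-sub-layer loss is 1.0.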
In this application, the loss form used for the first, second and third loss terms may be chosen according to the requirements of the scenario and is not limited in any way. For example, any of cross-entropy (CE) loss, Kullback-Leibler (KL) divergence loss, mean squared error (MSE) loss and Transducer loss may be used.
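For reference, three of the listed loss forms are sketched below for a single pair of discrete distributions (the Transducer loss needs the whole output lattice and is not covered by this snippet). The function names and the example distributions are illustrative; the application leaves open which form is paired with which loss term.

```python
import math

def cross_entropy(p, q):
    """CE(p, q) = -sum_i p_i log q_i, with p as the target distribution."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i); zero when p == q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mean_squared_error(a, b):
    """Mean of squared element-wise differences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
# CE and KL differ only by the teacher's entropy, which does not depend on the
# student, so both give the student the same gradients.
print(cross_entropy(teacher, student), entropy(teacher) + kl_divergence(teacher, student))
```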
For ease of understanding, taking a Transducer-based teacher model and student model as an example, the first loss term can be expressed as:
$$L_{\text{mid}} = \sum_{i=1}^{N} \mathrm{MSE}\left(E_i^{S}, E_i^{T}\right) + \sum_{j=1}^{M} \mathrm{MSE}\left(D_j^{S}, D_j^{T}\right)$$
where $E_i^{S}$ and $E_i^{T}$ are the outputs of the $i$-th encoder layer of the Transducer-based student model and teacher model, $D_j^{S}$ and $D_j^{T}$ are the outputs of the $j$-th decoder layer of the student model and teacher model, the number of encoder layers is $N$, and the number of decoder layers is $M$. The first loss term thus characterizes the mean squared error (MSE) loss between the output distribution of each encoder layer of the Transducer-based teacher model and that of the corresponding encoder layer of the student model, plus the MSE loss between the output distribution of each decoder layer of the teacher model and that of the corresponding decoder layer of the student model.
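A minimal Python sketch of the first loss term follows. It assumes that the teacher and student have the same number of encoder and decoder layers with matching output widths (a teacher with different widths would need a projection, which is not covered here) and represents each layer's output as one flat list of floats. The `last_only` flag switches between the last-sub-layer variant and the every-sub-layer variant described above; all function names are hypothetical.

```python
def mse(u, v):
    """Mean squared error between two equal-length feature vectors."""
    if len(u) != len(v):
        raise ValueError("feature dimensions must match")
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def intermediate_loss(student_enc, teacher_enc, student_dec, teacher_dec, last_only=False):
    """First loss term: layer-wise MSE summed over encoder and decoder layers.

    Each argument is a list of per-layer outputs (one flat vector per layer).
    With last_only=True only the last sub-layer of each stack is compared.
    """
    def stack_loss(student_layers, teacher_layers):
        if len(student_layers) != len(teacher_layers):
            raise ValueError("this sketch assumes a one-to-one layer mapping")
        pairs = list(zip(student_layers, teacher_layers))
        if last_only:
            pairs = pairs[-1:]
        return sum(mse(s, t) for s, t in pairs)

    return stack_loss(student_enc, teacher_enc) + stack_loss(student_dec, teacher_dec)


# Example: two encoder layers (2-dim outputs) and one decoder layer (1-dim output).
student_enc = [[0.0, 1.0], [1.0, 1.0]]
teacher_enc = [[0.0, 0.0], [1.0, 3.0]]
student_dec = [[2.0]]
teacher_dec = [[1.0]]
all_layers = intermediate_loss(student_enc, teacher_enc, student_dec, teacher_dec)  # 0.5 + 2.0 + 1.0
last_layer = intermediate_loss(student_enc, teacher_enc, student_dec, teacher_dec, last_only=True)  # 2.0 + 1.0
```

In the example, the first encoder layer's error (0.5) only contributes when every sub-layer is compared, which is how the per-sub-layer variant sends a training signal to the lower layers directly.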
The second loss term can be expressed as:
$$\mathrm{KL}(Q_S, Q_T)$$
where $Q_S$ is the final output of the Transducer-based student model and $Q_T$ is the final output of the Transducer-based teacher model (both typically in the form of log posterior probabilities). The second loss term characterizes the Kullback-Leibler (KL) divergence loss between the final output distribution of the Transducer-based teacher model and the final output distribution of the Transducer-based student model.
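The text does not fix the argument order of $\mathrm{KL}(Q_S, Q_T)$. The sketch below assumes the common distillation choice $\mathrm{KL}(P_T \,\|\, P_S)$, averaged over output positions, and takes log-probabilities as input as described above. Function names are hypothetical.

```python
import math


def kl_divergence(log_p, log_q):
    """KL(P || Q) for one discrete distribution pair given as log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def output_kd_loss(student_log_probs, teacher_log_probs):
    """Second loss term: mean KL(P_teacher || P_student) over all output positions.

    Each argument is a list of per-position log-probability vectors, for example
    every (t, u) node of a flattened Transducer output lattice.
    """
    pairs = list(zip(teacher_log_probs, student_log_probs))
    return sum(kl_divergence(t, s) for t, s in pairs) / len(pairs)


teacher = [[math.log(0.5), math.log(0.5)]]
student = [[math.log(0.25), math.log(0.75)]]
loss = output_kd_loss(student, teacher)  # 0.5 * ln(4/3), about 0.1438
```

The loss is zero only when the two distributions coincide at every position. Because KL divergence is not symmetric, swapping the arguments gives a different (also valid) objective.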
The third loss term can be expressed as:
$$\mathrm{Transducer}(Q_S, y_{\text{true}})$$
where $Q_S$ is the final output of the Transducer-based student model (typically in the form of log posterior probabilities) and $y_{\text{true}}$ is the label of the training data. The third loss term characterizes the Transducer loss between the final output of the Transducer-based student model and the labels of the training data.
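To make the third loss term concrete, here is a compact pure-Python sketch of the standard RNN-T (Transducer) loss. It uses the forward (alpha) recursion over the T x (U + 1) output lattice, assumes `log_probs` is already log-softmax normalized, and uses symbol 0 as blank. It is unbatched and unoptimized; production code would typically use an optimized implementation such as `torchaudio.functional.rnnt_loss`.

```python
import math

NEG_INF = float("-inf")


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def transducer_loss(log_probs, labels, blank=0):
    """RNN-T negative log-likelihood of `labels` via the forward (alpha) recursion.

    log_probs[t][u][k]: log-probability of symbol k at frame t after the first u
    labels have been emitted (a T x (U + 1) x K lattice, already log-softmaxed).
    labels: the U target symbol ids (none of them equal to `blank`).
    """
    T, U = len(log_probs), len(labels)
    alpha = [[NEG_INF] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Emitting blank at (t - 1, u) advances to the next frame.
            via_blank = alpha[t - 1][u] + log_probs[t - 1][u][blank] if t > 0 else NEG_INF
            # Emitting label u - 1 at (t, u - 1) stays on the same frame.
            via_label = (
                alpha[t][u - 1] + log_probs[t][u - 1][labels[u - 1]] if u > 0 else NEG_INF
            )
            alpha[t][u] = log_add(via_blank, via_label)
    # Every alignment ends with a blank emitted at the last node (T - 1, U).
    return -(alpha[T - 1][U] + log_probs[T - 1][U][blank])


half = math.log(0.5)
uniform = [[[half, half] for _ in range(2)] for _ in range(2)]  # T = 2, U + 1 = 2, K = 2
loss = transducer_loss(uniform, labels=[1])  # 2 alignments x (1/2)^3 each -> ln 4
```

With T = 2 frames and a single label, exactly two alignments exist (label-then-blank-blank, or blank-label-blank). Each has probability 1/8 under the uniform lattice, so the total likelihood is 1/4.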
Step S202: Train the student model based on the constructed student model loss function to obtain the trained student model.
It should be noted that student model loss functions of different forms may call for different training procedures and yield trained student models with different performance, as described in detail in the following embodiments.
In another embodiment of the present application, the constructed student model loss function includes a first student model loss function. The first student model loss function includes the first loss term, the second loss term, and the third loss term, all with the same coefficient. For example, the coefficients of the first, second and third loss terms may all be 1.0.
In one implementable embodiment, training the student model based on the constructed student model loss function to obtain the trained student model includes:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model; and determining the first student model as the trained student model.
It should be noted that using the same coefficient for the first, second and third loss terms means that the network structures associated with these three terms are optimized with equal emphasis; that is, the same parameter-update strength is applied to all parts of the student model.
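For ease of understanding, taking a Transducer-based teacher model and student model as an example, the first student model loss function $L_{KD}(W_{\text{student}})_1$ can be expressed as:
$$L_{KD}(W_{\text{student}})_1 = 1.0 \times \left(L_{\text{mid}} + \mathrm{KL}(Q_S, Q_T) + \mathrm{Transducer}(Q_S, y_{\text{true}})\right)$$
where $L_{\text{mid}}$ denotes the first loss term (the layer-wise MSE loss over the encoder and decoder layers).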
When the parameters of the Transducer-based student model are iteratively optimized with the first student model loss function, the same optimization strength is applied to the parameters of the entire student network.
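The first student model loss function is a plain weighted sum of the three terms. The following is a minimal sketch that assumes the three terms have already been computed as scalars for a training step; `LossWeights` and `student_loss` are hypothetical names. Other embodiments differ only in the coefficients passed in.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    """Coefficients of the first (intermediate), second (KL) and third (Transducer) terms."""
    mid: float
    kl: float
    transducer: float


def student_loss(l_mid, l_kl, l_transducer, weights):
    """Weighted sum of the three loss terms for one training step."""
    return weights.mid * l_mid + weights.kl * l_kl + weights.transducer * l_transducer


# First student model loss function: identical coefficients, 1.0 in the example above.
FIRST_LOSS_WEIGHTS = LossWeights(mid=1.0, kl=1.0, transducer=1.0)

total = student_loss(0.5, 0.25, 1.25, FIRST_LOSS_WEIGHTS)  # 2.0
```

With identical coefficients, no term is up- or down-weighted relative to the others, which corresponds to the equal optimization strength described above.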
In this embodiment, a single round of training uses both the intermediate-layer outputs and the final output of the teacher model to iteratively optimize the parameters of the student model. The resulting student model performs better than one obtained with a conventional knowledge-distillation-based training method, but its performance can still be improved further.
To further improve the performance of the trained student model, in another possible implementation, training the student model based on the constructed student model loss function to obtain the trained student model includes:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model;
and inputting the training data into the teacher model and the first student model, and iteratively optimizing the parameters of the first student model based on the first student model loss function until training ends, to obtain the trained student model.
It should be noted that, when the parameters of the first student model are iteratively optimized based on the first student model loss function, the same parameter iterative optimization strength is applied to the parameters of the entire network of the first student model.
In this embodiment, two rounds of training iteratively optimize the parameters of the entire student network twice. The resulting student model outperforms one obtained from a single round of training in which the entire network is optimized only once.
Take Transducer-based teacher and student models as an example. The training set consists of 5,000 hours of English audio-text pairs from the iFlytek Input Method, the test set consists of 60 hours of iFlytek Input Method data, the student model has 12M parameters, and the teacher model has 24M parameters. The word accuracy of the student model trained with a conventional knowledge-distillation-based method is 78.49%, while that of the student model obtained through two rounds of training is 78.95%, an improvement in performance.
In another embodiment of the present application, the constructed student model loss function includes: a second student model loss function and a third student model loss function; wherein the second student model loss function includes the first loss term; the third student model loss function includes the second loss term and the third loss term, and coefficients of the first loss term, the second loss term, and the third loss term are the same.
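For ease of understanding, taking Transducer-based teacher and student models as an example, the second student model loss function $L_{KD}(W_{\text{student}})_2$ and the third student model loss function $L_{KD}(W_{\text{student}})_3$ can be expressed as:
$$L_{KD}(W_{\text{student}})_2 = 1.0 \times L_{\text{mid}}$$
$$L_{KD}(W_{\text{student}})_3 = 1.0 \times \left(\mathrm{KL}(Q_S, Q_T) + \mathrm{Transducer}(Q_S, y_{\text{true}})\right)$$
where $L_{\text{mid}}$ denotes the first loss term (the layer-wise MSE loss over the encoder and decoder layers).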
In this case, training the student model based on the constructed student model loss function to obtain the trained student model includes:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the second student model loss function until training ends, to obtain a second student model; and inputting the training data into the teacher model and the second student model, and iteratively optimizing the parameters of the second student model based on the third student model loss function until training ends, to obtain the trained student model.
It should be noted that, when the parameters of the student model are iteratively optimized based on the second student model loss function, only the parameters of the intermediate network layer of the student model are iteratively optimized, and when the parameters of the second student model are iteratively optimized based on the third student model loss function, the parameters of the entire network of the second student model are iteratively optimized.
For the convenience of understanding, taking the teacher model and the student model based on the Transducer model as an example, when the parameters of the student model are iteratively optimized by using the second student model loss function of the student model based on the Transducer model, only the parameters of the encoder and decoder of the student model are iteratively optimized. And when the parameters of the second student model are subjected to iterative optimization by using a third student model loss function of the student model based on the Transducer model, performing iterative optimization on the parameters of the whole network of the second student model.
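The effect of restricting the first stage to the intermediate network layer can be illustrated on a deliberately tiny toy problem. This is a sketch, not the patented training code. The student is assumed to have one intermediate parameter `a` (standing in for the encoder and decoder) and one output parameter `b` (standing in for the joint network). The teacher is fixed at a = 2 and b = 3, and squared errors stand in for the MSE, KL and Transducer terms. Freezing is modeled by a set of trainable parameter names; in a real framework this would correspond to leaving parameters out of the optimizer.

```python
def train_stage(params, w_mid, w_out, trainable, xs, steps=2000, lr=0.01):
    """Plain gradient descent on a hypothetical two-parameter toy student.

    Student: intermediate output h = a * x, final output y = b * h.
    Teacher (fixed): h_T = 2 * x, y_T = 6 * x; labels y_true = 6 * x.
    Loss: w_mid * mean((h - h_T)^2)
          + w_out * (mean((y - y_T)^2) + mean((y - y_true)^2)),
    where squared errors stand in for the MSE, KL and Transducer terms.
    Only parameter names listed in `trainable` are updated (the rest are frozen).
    """
    a, b = params["a"], params["b"]
    s = sum(x * x for x in xs) / len(xs)  # mean(x^2) factors out of every gradient
    for _ in range(steps):
        r_mid = a - 2.0      # intermediate residual per unit of x
        r_out = a * b - 6.0  # final-output residual per unit of x (teacher and label agree here)
        grad_a = 2 * w_mid * r_mid * s + 4 * w_out * r_out * b * s
        grad_b = 4 * w_out * r_out * a * s
        if "a" in trainable:
            a -= lr * grad_a
        if "b" in trainable:
            b -= lr * grad_b
    return {"a": a, "b": b}


xs = [1.0, -1.0, 0.5]
student = {"a": 0.5, "b": 0.5}
# Stage 1 (second loss function): intermediate term only, intermediate parameter only.
second_student = train_stage(student, w_mid=1.0, w_out=0.0, trainable={"a"}, xs=xs)
# Stage 2 (third loss function): output terms only, whole network trainable.
trained = train_stage(second_student, w_mid=0.0, w_out=1.0, trainable={"a", "b"}, xs=xs)
```

After stage 1, `a` matches the teacher's intermediate value 2.0 while `b` is untouched. Stage 2 then drives `a * b` to 6. Because the third loss function contains no intermediate term, `a` drifts noticeably away from 2.0 in this toy. The embodiment with the fourth and fifth loss functions addresses this kind of drift by keeping a small intermediate term in the second stage.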
In this embodiment, only the parameters of the intermediate network layer of the student model are optimized first, and then the parameters of the entire network are optimized. Compared with two rounds of training that optimize the entire network twice with the same strength, this yields a student model with better performance.
Take Transducer-based teacher and student models as an example, with a training set of 5,000 hours of English iFlytek Input Method audio-text pairs, a test set of 60 hours of iFlytek Input Method data, a 12M-parameter student model, and a 24M-parameter teacher model. Optimizing all parameters of the student model twice through two rounds of training gives a word accuracy of 78.95%. First optimizing only the parameters of the intermediate network layer and then the parameters of the entire network gives a word accuracy of 80.3%, an improvement in performance.
In another embodiment of the present application, the constructed student model loss function includes: a fourth student model loss function and a fifth student model loss function; the fourth student model loss function comprises the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and the coefficient of the first loss term is much larger than the coefficients of the second loss term and the third loss term; for example, the coefficient of the first loss term may be 1.0, and the coefficients of the second loss term and the third loss term may be 0.01.
The fifth student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the second loss term and the third loss term are the same, and coefficients of the second loss term and the third loss term are much larger than a coefficient of the first loss term. For example, the coefficients of the second loss term and the third loss term may be 1.0, and the coefficient of the first loss term may be 0.01.
Training the student model based on the constructed student model loss function to obtain a trained student model, comprising:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the fourth student model loss function until training ends, to obtain a third student model;
and inputting the training data into the teacher model and the third student model, and performing iterative optimization on parameters of the third student model based on the fifth student model loss function until the training is finished to obtain the trained student model.
It should be noted that, when the parameters of the student model are iteratively optimized based on the fourth student model loss function, the updates fall mainly on the intermediate network layer, while the parameters of the other parts are only fine-tuned during training. When the parameters of the third student model are iteratively optimized based on the fifth student model loss function, the parameters of the intermediate network layer of the third student model are only fine-tuned, while the parameters of the other parts are updated more strongly.
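For ease of understanding, taking Transducer-based teacher and student models as an example, the fourth student model loss function $L_{KD}(W_{\text{student}})_4$ can be expressed as:
$$L_{KD}(W_{\text{student}})_4 = 1.0 \times L_{\text{mid}} + 0.01 \times \left(\mathrm{KL}(Q_S, Q_T) + \mathrm{Transducer}(Q_S, y_{\text{true}})\right)$$
and the fifth student model loss function $L_{KD}(W_{\text{student}})_5$ can be expressed as:
$$L_{KD}(W_{\text{student}})_5 = 0.01 \times L_{\text{mid}} + 1.0 \times \left(\mathrm{KL}(Q_S, Q_T) + \mathrm{Transducer}(Q_S, y_{\text{true}})\right)$$
where $L_{\text{mid}}$ denotes the first loss term (the layer-wise MSE loss over the encoder and decoder layers).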
When the parameters of the student model are iteratively optimized based on the fourth student model loss function, the updates fall mainly on the encoder and decoder, while the joint network is fine-tuned during training. This step lets the student model better learn the teacher model's output distribution in the encoder and decoder. When the parameters of the third student model are iteratively optimized based on the fifth student model loss function, the encoder and decoder of the third student model are only fine-tuned during training. Their outputs therefore stay close to those of the teacher model instead of drifting away from the teacher as the entire network is updated.
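The 1.0/0.01 coefficient schedule can be illustrated on the same kind of deliberately tiny toy problem. The following is a sketch under stated assumptions, not the patented training code: a hypothetical student with one intermediate parameter `a` (standing in for the encoder and decoder) and one output parameter `b` (standing in for the joint network), a fixed teacher at a = 2 and b = 3, and squared errors standing in for the MSE, KL and Transducer terms. Here all parameters stay trainable in both stages; only the coefficients change.

```python
def train_stage(params, w_mid, w_out, xs, steps=5000, lr=0.01):
    """Gradient descent on a hypothetical two-parameter toy student (all parameters trainable).

    Student: h = a * x (intermediate), y = b * h (final).
    Teacher: h_T = 2 * x, y_T = 6 * x; labels y_true = 6 * x.
    Loss: w_mid * mean((h - h_T)^2) + w_out * (mean((y - y_T)^2) + mean((y - y_true)^2)).
    """
    a, b = params["a"], params["b"]
    s = sum(x * x for x in xs) / len(xs)  # mean(x^2) factors out of every gradient
    for _ in range(steps):
        r_mid = a - 2.0      # intermediate residual per unit of x
        r_out = a * b - 6.0  # final-output residual per unit of x
        grad_a = 2 * w_mid * r_mid * s + 4 * w_out * r_out * b * s
        grad_b = 4 * w_out * r_out * a * s
        a, b = a - lr * grad_a, b - lr * grad_b  # simultaneous update of both parameters
    return {"a": a, "b": b}


xs = [1.0, -1.0, 0.5]
student = {"a": 0.5, "b": 0.5}
# Fourth loss function: intermediate term weighted 1.0, output terms 0.01.
third_student = train_stage(student, w_mid=1.0, w_out=0.01, xs=xs)
# Fifth loss function: output terms weighted 1.0, intermediate term 0.01.
trained = train_stage(third_student, w_mid=0.01, w_out=1.0, xs=xs)
```

In this toy the teacher is exactly reachable, so both stages share the optimum a = 2, b = 3. The run only shows how the coefficients shift the emphasis. In the first stage, `a` settles near 2.0 quickly while `b` creeps up slowly. In the second stage, the product `a * b` is corrected to 6 while the small intermediate term keeps `a` close to 2.0.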
In this embodiment, in the first training stage the intermediate network layer of the teacher model guides the learning of the student's intermediate network layer while the rest of the student network is fine-tuned. In the second training stage, the final output of the teacher model guides the learning of the entire student model while the intermediate network layer is fine-tuned. This improves the performance of the trained student network.
Take Transducer-based teacher and student models as an example, with a training set of 5,000 hours of English iFlytek Input Method audio-text pairs, a test set of 60 hours of iFlytek Input Method data, a 12M-parameter student model, and a 24M-parameter teacher model. The word accuracy of the student model trained in the above manner is 80.83%.
The knowledge-distillation-based model training apparatus disclosed in the embodiments of the present application is described below. The apparatus described below and the method described above may be read in correspondence with each other.
Referring to FIG. 3, FIG. 3 is a schematic structural diagram of a knowledge-distillation-based model training apparatus disclosed in an embodiment of the present application. As shown in FIG. 3, the apparatus may include:
an acquisition unit 11, configured to acquire a pre-trained teacher model, a student model to be trained, training data, and labels of the training data, where the teacher model and the student model each comprise an intermediate network layer;
a training unit 12, configured to train the student model to be trained using the training data as training samples, with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain a trained student model.
As an implementable embodiment, the training unit includes:
a student model loss function construction subunit, configured to construct a student model loss function with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data;
and a training subunit, configured to train the student model based on the constructed student model loss function to obtain a trained student model.
As an implementation, the constructed student model loss function includes a first loss term, a second loss term and a third loss term. The first loss term characterizes the error between the output distribution of the teacher model's intermediate network layer and the output distribution of the student model's intermediate network layer; the second loss term characterizes the error between the final output distribution of the teacher model and the final output distribution of the student model; and the third loss term characterizes the error between the final output of the student model and the labels of the training data.
As an implementable embodiment, the constructed student model loss function includes: a first student model loss function;
the first student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the first loss term, the second loss term, and the third loss term are the same.
As an implementation, the training subunit is specifically configured to:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model; and
determining the first student model as the trained student model;
or,
inputting the training data into the teacher model and the first student model, and iteratively optimizing the parameters of the first student model based on the first student model loss function until training ends, to obtain the trained student model.
As an implementable embodiment, the constructed student model loss function includes: a second student model loss function and a third student model loss function;
wherein the second student model loss function includes the first loss term; the third student model loss function includes the second loss term and the third loss term, and coefficients of the first loss term, the second loss term, and the third loss term are the same.
As an implementation, the training subunit is specifically configured to:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the second student model loss function until training ends, to obtain a second student model;
and inputting the training data into the teacher model and the second student model, and performing iterative optimization on parameters of the second student model based on the third student model loss function until the training is finished to obtain the trained student model.
As an implementable embodiment, the constructed student model loss function includes: a fourth student model loss function and a fifth student model loss function;
the fourth student model loss function comprises the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and the coefficient of the first loss term is much larger than the coefficients of the second loss term and the third loss term;
the fifth student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the second loss term and the third loss term are the same, and coefficients of the second loss term and the third loss term are much larger than a coefficient of the first loss term.
As an implementation, the training subunit is specifically configured to:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the fourth student model loss function until training ends, to obtain a third student model;
and inputting the training data into the teacher model and the third student model, and performing iterative optimization on parameters of the third student model based on the fifth student model loss function until training is finished to obtain a trained student model.
As an implementable embodiment, the error between the output distribution of the teacher model's intermediate network layer and the output distribution of the student model's intermediate network layer comprises the errors between the output distribution of each sub-layer of the teacher model's intermediate network layer and the output distribution of the corresponding sub-layer of the student model's intermediate network layer.
As an implementation, the teacher model and the student model are both streaming end-to-end speech recognition models, and the streaming end-to-end speech recognition model includes an encoder, a decoder, and a joint network, where the encoder and decoder are intermediate network layers, and the output of the joint network is the final output of the streaming end-to-end speech recognition model.
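To illustrate which parts play the roles of intermediate layers and final output, here is a toy additive joint network. Real Transducer joint networks usually project both inputs, apply a nonlinearity such as tanh, and then an output linear layer. This sketch assumes instead that the encoder and decoder (prediction network) outputs are already of vocabulary size K and simply adds them before a log-softmax. Function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def joint_network(enc_out, dec_out):
    """Toy additive joint network.

    enc_out: T encoder frames, each a length-K vector (intermediate layer output).
    dec_out: U + 1 decoder (prediction network) states, each length K
             (intermediate layer output).
    Returns the final output: a T x (U + 1) x K lattice of log-probabilities.
    """
    return [
        [log_softmax([f + g for f, g in zip(frame, state)]) for state in dec_out]
        for frame in enc_out
    ]


enc = [[0.1, 0.2, 0.3], [0.0, 0.0, 1.0]]  # T = 2 frames, K = 3
dec = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]  # U + 1 = 2 decoder states
lattice = joint_network(enc, dec)
```

The encoder and decoder outputs are what the first loss term compares layer by layer. The lattice returned by the joint network is the final output that enters the KL and Transducer terms.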
Referring to FIG. 4, FIG. 4 is a block diagram of the hardware structure of a knowledge-distillation-based model training device provided in an embodiment of the present application. As shown in FIG. 4, the hardware structure of the device may include: at least one processor 1, at least one communication interface 2, at least one memory 3, and at least one communication bus 4;
in the embodiment of the application, the number of the processor 1, the communication interface 2, the memory 3 and the communication bus 4 is at least one, and the processor 1, the communication interface 2 and the memory 3 complete mutual communication through the communication bus 4;
the processor 1 may be a central processing unit (CPU), an application-specific integrated circuit (ASIC), or one or more integrated circuits configured to implement embodiments of the present invention;
the memory 3 may include a high-speed RAM and may further include a non-volatile memory, such as at least one disk memory;
wherein the memory stores a program and the processor can call the program stored in the memory, the program for:
acquiring a pre-trained teacher model, a student model to be trained, training data, and labels of the training data, where the teacher model and the student model each comprise an intermediate network layer;
and training the student model to be trained using the training data as training samples, with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain the trained student model.
Optionally, for the detailed and extended functions of the program, refer to the description above.
Embodiments of the present application further provide a readable storage medium, where a program suitable for being executed by a processor may be stored, where the program is configured to:
acquiring a pre-trained teacher model, a student model to be trained, training data, and labels of the training data, where the teacher model and the student model each comprise an intermediate network layer;
and training the student model to be trained using the training data as training samples, with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain the trained student model.
Optionally, for the detailed and extended functions of the program, refer to the description above.
Finally, it should also be noted that, herein, relational terms such as first and second are used solely to distinguish one entity or action from another, without necessarily requiring or implying any actual relationship or order between such entities or actions. Moreover, the terms "comprises," "comprising," or any other variation thereof are intended to cover a non-exclusive inclusion, such that a process, method, article, or apparatus that comprises a list of elements includes not only those elements but may also include other elements not expressly listed or inherent to such process, method, article, or apparatus. Without further limitation, an element defined by the phrase "comprising a …" does not exclude the presence of other identical elements in the process, method, article, or apparatus that comprises the element.
The embodiments in the present description are described in a progressive manner, each embodiment focuses on differences from other embodiments, and the same and similar parts among the embodiments are referred to each other.
The previous description of the disclosed embodiments is provided to enable any person skilled in the art to make or use the present application. Various modifications to these embodiments will be readily apparent to those skilled in the art, and the generic principles defined herein may be applied to other embodiments without departing from the spirit or scope of the application. Thus, the present application is not intended to be limited to the embodiments shown herein but is to be accorded the widest scope consistent with the principles and novel features disclosed herein.

Claims (13)

1. A knowledge-distillation-based model training method, the method comprising:
acquiring a pre-trained teacher model, a student model to be trained, training data, and labels of the training data, where the teacher model and the student model each comprise an intermediate network layer;
and training the student model to be trained using the training data as training samples, with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain the trained student model.
2. The method according to claim 1, wherein training the student model to be trained using the training data as training samples, with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data, to obtain the trained student model, comprises:
constructing a student model loss function with the training target that the output distribution of the student model's intermediate network layer approaches that of the teacher model's intermediate network layer, the final output distribution of the student model approaches that of the teacher model, and the final output of the student model approaches the labels of the training data;
and training the student model based on the constructed student model loss function to obtain the trained student model.
3. The method of claim 2, wherein the constructed student model loss function comprises a first loss term, a second loss term and a third loss term, wherein the first loss term is used for characterizing an error between an output distribution of the teacher model's middle network layer and an output distribution of the student model's middle network layer, and the second loss term is used for characterizing an error between a final output distribution of the teacher model and a final output distribution of the student model; the third loss term is used to characterize an error between a final output of the student model and the label of the training data.
4. The method of claim 3, wherein the constructed student model loss function comprises: a first student model loss function;
the first student model loss function includes the first loss term, the second loss term, and the third loss term, wherein coefficients of the first loss term, the second loss term, and the third loss term are the same.
5. The method of claim 4, wherein training the student model based on the constructed student model loss function to obtain a trained student model comprises:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the first student model loss function until training ends, to obtain a first student model; and
determining the first student model as the trained student model;
or,
and inputting the training data into the teacher model and the first student model, and performing iterative optimization on parameters of the first student model based on the first student model loss function until the training is finished to obtain the trained student model.
6. The method of claim 3, wherein the constructed student model loss function comprises: a second student model loss function and a third student model loss function;
wherein the second student model loss function includes the first loss term; the third student model loss function includes the second loss term and the third loss term, and coefficients of the first loss term, the second loss term, and the third loss term are the same.
7. The method of claim 6, wherein training the student model based on the constructed student model loss function to obtain a trained student model comprises:
inputting the training data into the teacher model and the student model, and iteratively optimizing the parameters of the student model based on the second student model loss function until training ends, to obtain a second student model;
and inputting the training data into the teacher model and the second student model, and performing iterative optimization on parameters of the second student model based on the third student model loss function until the training is finished to obtain the trained student model.
8. The method of claim 3, wherein the constructed student model loss function comprises: a fourth student model loss function and a fifth student model loss function;
the fourth student model loss function comprises the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and the coefficient of the first loss term is much larger than the coefficients of the second loss term and the third loss term;
the fifth student model loss function includes the first loss term, the second loss term and the third loss term, wherein the second loss term and the third loss term have the same coefficient, and that coefficient is much larger than the coefficient of the first loss term.
9. The method of claim 8, wherein training the student model based on the constructed student model loss function to obtain a trained student model comprises:
inputting the training data into the teacher model and the student model, and performing iterative optimization on parameters of the student model based on the fourth student model loss function until the training is finished, to obtain a third student model;
and inputting the training data into the teacher model and the third student model, and performing iterative optimization on parameters of the third student model based on the fifth student model loss function until the training is finished to obtain the trained student model.
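Unlike claims 6 and 7, claims 8 and 9 keep all three loss terms active in both stages and only shift the emphasis. A minimal configuration sketch follows. The claims do not quantify "much larger", so the ratio `DOMINANCE = 100.0` is a hypothetical placeholder, and all names are illustrative. The per-term loss values are assumed to be computed elsewhere.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical stand-in for "much larger"; the claims give no ratio.
DOMINANCE = 100.0


@dataclass(frozen=True)
class LossWeights:
    """Coefficients of the first, second and third loss terms."""
    first: float
    second: float
    third: float

    def combine(self, terms: Tuple[float, float, float]) -> float:
        """Weighted sum of already-computed (first, second, third) loss values."""
        t1, t2, t3 = terms
        return self.first * t1 + self.second * t2 + self.third * t3


def fourth_loss_weights(base: float = 1.0) -> LossWeights:
    # Claim 8: second and third share a coefficient; the first dominates.
    return LossWeights(first=DOMINANCE * base, second=base, third=base)


def fifth_loss_weights(base: float = 1.0) -> LossWeights:
    # Claim 8: second and third share a coefficient and dominate the first.
    return LossWeights(first=base, second=DOMINANCE * base, third=DOMINANCE * base)


def claim9_schedule() -> List[Tuple[str, LossWeights]]:
    # Claim 9: stage 1 (fourth loss) yields the third student model;
    # stage 2 (fifth loss) continues from it and yields the trained model.
    return [
        ("stage 1: student model -> third student model", fourth_loss_weights()),
        ("stage 2: third student model -> trained student model", fifth_loss_weights()),
    ]


for name, weights in claim9_schedule():
    print(name, weights)
```

Each stage would call `combine` on the three loss values computed for a batch and optimize the result. Stage 2 starts from the third student model produced by stage 1. Setting the dominated coefficients to zero instead recovers the schedule of claims 6 and 7, where each stage drops terms entirely.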
10. The method of claim 3, wherein the error between the output distribution of the teacher model's intermediate network layer and the output distribution of the student model's intermediate network layer comprises: errors between the output distribution of each sub-layer of the teacher model's intermediate network layer and the output distribution of each sub-layer of the student model's intermediate network layer.
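Claim 10 splits the intermediate-layer error into per-sub-layer errors. The sketch below makes several choices that the claim text does not fix, for illustration only:

- each sub-layer output is turned into a distribution with a softmax;
- teacher and student sub-layers are paired one-to-one, in order;
- the error measure is KL divergence;
- the per-sub-layer errors are summed.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) between two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def intermediate_layer_error(
    teacher_sublayers: Sequence[Sequence[float]],
    student_sublayers: Sequence[Sequence[float]],
) -> float:
    """Sum of per-sub-layer errors between teacher and student intermediate layers."""
    if len(teacher_sublayers) != len(student_sublayers):
        raise ValueError("this sketch assumes one student sub-layer per teacher sub-layer")
    return sum(
        kl_divergence(softmax(t), softmax(s))
        for t, s in zip(teacher_sublayers, student_sublayers)
    )


teacher = [[1.0, 2.0, 3.0], [0.5, -0.5]]
student = [[3.0, 2.0, 1.0], [0.5, -0.5]]
print(intermediate_layer_error(teacher, student))  # only the first sub-layer contributes
```

If the student has fewer sub-layers than the teacher, some teacher-to-student sub-layer mapping would be needed. That mapping is likewise not specified here.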
11. A knowledge distillation-based model training apparatus, the apparatus comprising:
an acquisition unit, configured to acquire a pre-trained teacher model, a student model to be trained, training data and labels of the training data, wherein the teacher model and the student model each comprise an intermediate network layer;
and a training unit, configured to train the student model to be trained, using the training data as training samples, with the following training targets: the output distribution of the student model's intermediate network layer matches the output distribution of the teacher model's intermediate network layer; the final output distribution of the student model matches the final output distribution of the teacher model; and the final output of the student model matches the label of the training data. Training yields the trained student model.
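The three training targets of the training unit map onto three loss terms. Below is a minimal pure-Python sketch for a single sample, built on several assumptions:

- KL divergence is the distribution error (the claims only say "error");
- raw layer outputs are normalized with a softmax;
- the label target uses cross-entropy;
- all weights default to 1.

Which of these terms the claims call the first, second or third loss term is defined in claim 3, which is not reproduced here, so the names below are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q): distance of the student distribution q from the teacher distribution p."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(probs: Sequence[float], label: int, eps: float = 1e-12) -> float:
    return -math.log(probs[label] + eps)


def distillation_loss(
    teacher_mid: Sequence[float],
    student_mid: Sequence[float],
    teacher_logits: Sequence[float],
    student_logits: Sequence[float],
    label: int,
    weights: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> float:
    w_mid, w_out, w_label = weights
    student_probs = softmax(student_logits)
    # Target 1: student intermediate-layer distribution matches the teacher's.
    mid_term = kl_divergence(softmax(teacher_mid), softmax(student_mid))
    # Target 2: student final output distribution matches the teacher's.
    out_term = kl_divergence(softmax(teacher_logits), student_probs)
    # Target 3: student final output matches the ground-truth label.
    label_term = cross_entropy(student_probs, label)
    return w_mid * mid_term + w_out * out_term + w_label * label_term


# Teacher and student agree everywhere, so only the label term remains (about log 2).
print(distillation_loss([1.0, 2.0], [1.0, 2.0], [0.0, 0.0], [0.0, 0.0], label=0))
```

In a batched setting the same structure would be applied per sample and averaged over the batch. The `weights` tuple is where stage-specific coefficients, such as those in claims 6 to 9, could be supplied.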
12. A knowledge distillation-based model training apparatus, comprising a memory and a processor;
the memory is used for storing programs;
the processor is configured to execute the program to implement the steps of the knowledge distillation-based model training method according to any one of claims 1 to 10.
13. A readable storage medium storing a computer program which, when executed by a processor, implements the steps of the knowledge distillation-based model training method according to any one of claims 1 to 10.
CN202210550915.4A 2022-05-20 2022-05-20 Knowledge distillation-based model training method, related equipment and readable storage medium Pending CN114861885A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210550915.4A CN114861885A (en) 2022-05-20 2022-05-20 Knowledge distillation-based model training method, related equipment and readable storage medium

Publications (1)

Publication Number Publication Date
CN114861885A (en) 2022-08-05

Family

ID=82640217

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210550915.4A Pending CN114861885A (en) 2022-05-20 2022-05-20 Knowledge distillation-based model training method, related equipment and readable storage medium

Country Status (1)

Country Link
CN (1) CN114861885A (en)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116091773A (en) * 2023-02-02 2023-05-09 北京百度网讯科技有限公司 Training method of image segmentation model, image segmentation method and device
CN116091773B (en) * 2023-02-02 2024-04-05 北京百度网讯科技有限公司 Training method of image segmentation model, image segmentation method and device
CN117372785A (en) * 2023-12-04 2024-01-09 吉林大学 Image classification method based on feature cluster center compression
CN117372785B (en) * 2023-12-04 2024-03-26 吉林大学 Image classification method based on feature cluster center compression

Similar Documents

Publication Publication Date Title
CN110188331B (en) Model training method, dialogue system evaluation method, device, equipment and storage medium
CN111046152B (en) Automatic FAQ question-answer pair construction method and device, computer equipment and storage medium
CN114861885A (en) Knowledge distillation-based model training method, related equipment and readable storage medium
CN112328742B (en) Training method and device based on artificial intelligence, computer equipment and storage medium
CN111602148A (en) Regularized neural network architecture search
CN112214604A (en) Training method of text classification model, text classification method, device and equipment
CN114511472B (en) Visual positioning method, device, equipment and medium
CN112528637B (en) Text processing model training method, device, computer equipment and storage medium
CN110704626A (en) Short text classification method and device
US20190279036A1 (en) End-to-end modelling method and system
CN113609965B (en) Training method and device of character recognition model, storage medium and electronic equipment
CN112699215B (en) Grading prediction method and system based on capsule network and interactive attention mechanism
CN113192497B (en) Speech recognition method, device, equipment and medium based on natural language processing
CN114091450B (en) Judicial domain relation extraction method and system based on graph convolution network
CN110704510A (en) User portrait combined question recommendation method and system
CN115630145A (en) Multi-granularity emotion-based conversation recommendation method and system
CN114637911A (en) Next interest point recommendation method of attention fusion perception network
CN110675879B (en) Audio evaluation method, system, equipment and storage medium based on big data
CN114625882B (en) Network construction method for improving unique diversity of image text description
CN110929532A (en) Data processing method, device, equipment and storage medium
CN113593606A (en) Audio recognition method and device, computer equipment and computer-readable storage medium
EP4322066A1 (en) Method and apparatus for generating training data
JP2021039220A (en) Speech recognition device, learning device, speech recognition method, learning method, speech recognition program, and learning program
CN112735392B (en) Voice processing method, device, equipment and storage medium
CN113312445B (en) Data processing method, model construction method, classification method and computing equipment

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination