WO2022105121A1 - Distillation method and apparatus applied to bert model, device, and storage medium - Google Patents

Distillation method and apparatus applied to bert model, device, and storage medium Download PDF

Info

Publication number
WO2022105121A1
WO2022105121A1 (Application No. PCT/CN2021/090524)
Authority
WO
WIPO (PCT)
Prior art keywords
model
original
distillation
target
layer
Prior art date
Application number
PCT/CN2021/090524
Other languages
French (fr)
Chinese (zh)
Inventor
朱桂良
Original Assignee
平安科技(深圳)有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 平安科技(深圳)有限公司 filed Critical 平安科技(深圳)有限公司
Publication of WO2022105121A1 publication Critical patent/WO2022105121A1/en

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F 18/00 Pattern recognition
    • G06F 18/20 Analysing
    • G06F 18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F 18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F 17/00 Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F 17/10 Complex mathematical operations
    • G06F 17/16 Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 3/00 Computing arrangements based on biological models
    • G06N 3/02 Neural networks
    • G06N 3/04 Architecture, e.g. interconnection topology
    • G06N 3/044 Recurrent networks, e.g. Hopfield networks

Definitions

  • the present application relates to the technical field of deep learning, and in particular, to a distillation method, apparatus, computer equipment and storage medium applied to a BERT model.
  • the purpose of the embodiments of the present application is to propose a distillation method, device, computer equipment and storage medium applied to the BERT model, so as to solve the problem that it is difficult to balance the loss parameters in the traditional deep model distillation method.
  • the embodiment of the present application provides a distillation method applied to the BERT model, which adopts the following technical solutions:
  • receiving a model distillation request sent by the user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
  • a model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  • the embodiment of the present application also provides a distillation device applied to the BERT model, which adopts the following technical solutions:
  • a request receiving module configured to receive a model distillation request sent by a user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
  • the original model acquisition module is used to read the local database and obtain, from the local database, the trained original BERT model corresponding to the distillation object identifier, where the loss function of the original BERT model is cross entropy;
  • the default model building module is used to construct a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
  • a distillation operation module configured to perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model
  • a training data acquisition module used for acquiring the training data of the intermediate reduced model in the local database
  • a model training module configured to perform a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model.
  • the embodiment of the present application also provides a computer device, which adopts the following technical solutions:
  • receiving a model distillation request sent by the user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
  • a model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  • the embodiments of the present application also provide a computer-readable storage medium, which adopts the following technical solutions:
  • the computer-readable storage medium stores computer-readable instructions, and when the computer-readable instructions are executed by a processor, the steps of the distillation method applied to the BERT model described below are implemented:
  • receiving a model distillation request sent by the user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
  • a model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  • the distillation method, device, computer equipment and storage medium applied to the BERT model provided by the embodiments of the present application mainly have the following beneficial effects:
  • the embodiment of the present application provides a distillation method applied to a BERT model: receiving a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient; reading a local database and obtaining, from the local database, a trained original BERT model corresponding to the distillation object identifier, where the loss function of the original BERT model is cross entropy; constructing a default reduced model to be trained that is consistent with the structure of the trained original BERT model, where the loss function of the default reduced model is cross entropy; performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; obtaining the training data of the intermediate reduced model from the local database; and performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model.
  • because the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused;
  • during the distillation process there is no need to balance the weights of each loss parameter, thereby reducing the difficulty of the deep model distillation method;
  • the tasks in each stage of training the simplified BERT model remain consistent, which makes the convergence of the simplified BERT model more stable.
  • Fig. 1 is a flowchart of the implementation of the distillation method applied to the BERT model provided by the first embodiment of the present application;
  • Fig. 2 is a flowchart of the implementation of step S104 in Fig. 1;
  • Fig. 3 is a flowchart of the implementation of step S105 in Fig. 1;
  • Fig. 4 is a flowchart of the implementation of the parameter optimization operation provided by Embodiment 1 of the present application;
  • Fig. 5 is a flowchart of the implementation of step S403 in Fig. 4;
  • Fig. 6 is a schematic structural diagram of the distillation apparatus applied to the BERT model provided by the second embodiment of the present application.
  • FIG. 7 is a schematic structural diagram of an embodiment of a computer device according to the present application.
  • FIG. 1 shows the implementation flow chart of the distillation method applied to the BERT model provided according to the first embodiment of the present application. For the convenience of description, only the part related to the present application is shown.
  • In step S101, a model distillation request sent by a user terminal is received, where the model distillation request carries at least a distillation object identifier and a distillation coefficient.
  • a user terminal refers to a terminal device used to execute the distillation method applied to the BERT model provided by the present application.
  • the user terminal may be, for example, a mobile terminal such as a mobile phone, a smartphone, a notebook computer, a digital broadcast receiver, a PDA (personal digital assistant), a PAD (tablet computer), a PMP (portable multimedia player) or a navigation device, or a stationary terminal such as a digital TV or a desktop computer.
  • the examples are only for the convenience of understanding, and are not used to limit the present application.
  • the distillation object identifier is mainly used to uniquely identify the model object that needs to be distilled, and the distillation object identifier may be named based on the model name, for example, a visual recognition model, a speech recognition model, etc.
  • the identifier may also be named based on the abbreviation of the name, for example: sjsbmx, yysbmx, etc.; or the distillation object identifier may be named by a serial number, for example: 001, 002, etc.
  • it should be understood that the examples of distillation object identifiers here are only for convenience of understanding and are not intended to limit the present application.
  • the distillation coefficient is mainly used to determine the factor by which the number of layers of the original BERT model is reduced.
  • as an example, the distillation coefficient may be 3. It should be understood that the examples of distillation coefficients here are only for convenience of understanding and are not intended to limit the present application.
  • In step S102, the local database is read, the trained original BERT model corresponding to the distillation object identifier is obtained from the local database, and the loss function of the original BERT model is cross entropy.
  • the local database refers to a database resident on a machine running a client application.
  • the local database provides the fastest response time because there is no network transfer between the client (application) and the server.
  • the local database pre-stores a variety of trained original BERT models to solve problems in many fields such as computer vision and speech recognition.
  • the BERT model can be divided into an embedding layer, a transformer layer, and a prediction layer, each of which is a different representation of knowledge.
  • the original BERT model consists of a 12-layer transformer (a model based on an "encoder-decoder" structure), and the original BERT model uses cross-entropy as the loss function.
  • the cross entropy is mainly used to measure the difference information between two probability distributions.
  • the performance of language models is usually measured by cross-entropy and perplexity.
  • the meaning of cross-entropy is the difficulty of text recognition with the model, or from a compression point of view, how many bits are used to encode each word on average.
  • the meaning of perplexity is the average number of branches of the text as represented by the model, and its inverse can be regarded as the average probability of each word.
  • Smoothing refers to assigning a probability value to the unobserved N-gram combination to ensure that the word sequence can always obtain a probability value through the language model.
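  • as a concrete illustration of the relationship between cross entropy and perplexity described above, the following Python sketch (not part of the original disclosure; the per-token probabilities are hypothetical) computes both quantities for a short token sequence:

```python
import math

# Hypothetical per-token probabilities assigned by a language model
# to the observed next tokens of a short text.
token_probs = [0.25, 0.10, 0.60, 0.05]

# Cross entropy: average number of bits needed to encode each token.
cross_entropy = -sum(math.log2(p) for p in token_probs) / len(token_probs)

# Perplexity: the average branching factor of the text under the model;
# its inverse is the average per-token probability.
perplexity = 2 ** cross_entropy

print(f"cross entropy = {cross_entropy:.3f} bits/token")
print(f"perplexity    = {perplexity:.3f}")
```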
  • In step S103, a default reduced model to be trained that is consistent with the structure of the trained original BERT model is constructed, and the loss function of the default reduced model is cross entropy.
  • the constructed default reduced model retains the same model structure as BERT; the difference lies in the number of transformer layers.
  • In step S104, a distillation operation is performed on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model.
  • the distillation operation specifically includes the distillation of the transformer layer and parameter initialization.
  • distilling the transformer layers means that, if the distillation coefficient is 3, the first to third layers of the trained original BERT model are mapped to the first layer of the default reduced model; the fourth to sixth layers of the trained original BERT model are mapped to the second layer of the default reduced model; the seventh to ninth layers of the trained original BERT model are mapped to the third layer of the default reduced model; and the tenth to twelfth layers are mapped to the fourth layer of the default reduced model.
  • the probability of each layer being replaced may be determined by using the Bernoulli distribution probability.
  • parameter initialization refers to copying the parameters of the embedding, pooler, and fully connected layers of the trained original BERT model into the corresponding parameter positions of the default reduced model.
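  • as a rough illustration of the layer mapping and parameter initialization described above (a sketch only; the index mapping helper and parameter names are assumptions for illustration, not the exact implementation of the disclosure), the following Python snippet maps a 12-layer original model onto a 4-layer reduced model for a distillation coefficient of 3:

```python
def build_layer_mapping(num_original_layers: int, distillation_coefficient: int):
    """Map each original transformer layer onto a reduced-model layer.

    With 12 original layers and a coefficient of 3, layers 1-3 map to
    reduced layer 1, layers 4-6 to reduced layer 2, and so on.
    """
    mapping = {}
    for original_layer in range(1, num_original_layers + 1):
        reduced_layer = (original_layer - 1) // distillation_coefficient + 1
        mapping[original_layer] = reduced_layer
    return mapping

# Hypothetical parameter dictionaries keyed by component name.
original_params = {"embedding": "...", "pooler": "...", "fully_connected": "..."}
reduced_params = {}

# Parameter initialization: copy embedding, pooler, and fully connected
# layer parameters from the trained original model into the reduced model.
for name in ("embedding", "pooler", "fully_connected"):
    reduced_params[name] = original_params[name]

print(build_layer_mapping(12, 3))
# {1: 1, 2: 1, 3: 1, 4: 2, 5: 2, 6: 2, 7: 3, 8: 3, 9: 3, 10: 4, 11: 4, 12: 4}
```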
  • In step S105, the training data of the intermediate reduced model is obtained from the local database.
  • the training data of the reduced model may be labeled data obtained by training the above-mentioned original BERT model, or may be additional unlabeled data.
  • the original training data used for training the original BERT model can be obtained; the temperature parameter of the softmax layer of the original BERT model can be increased to obtain an increased BERT model; the original training data can be input into the increased BERT model for a prediction operation to obtain the mean result label; a screening operation can be performed on the original training data based on the label information to obtain the labelled screening result label; and the training data of the reduced model can be selected based on the enlarged training data and the screened training data.
  • In step S106, a model training operation is performed on the intermediate reduced model based on the training data to obtain the target reduced model.
  • in the embodiment of the present application, a distillation method applied to a BERT model is provided, which receives a model distillation request sent by a user terminal, the model distillation request carrying at least a distillation object identifier and a distillation coefficient, and obtains, from the local database, the trained original BERT model corresponding to the distillation object identifier.
  • the loss function of the original BERT model is cross entropy; a default reduced model to be trained that is consistent with the structure of the trained original BERT model is constructed, and the loss function of the default reduced model is cross entropy.
  • a distillation operation is performed on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; the training data of the intermediate reduced model is obtained from the local database; and a model training operation is performed on the intermediate reduced model based on the training data to obtain the target reduced model. Since the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused, so that during distillation there is no need to balance the weights of each loss parameter, thereby reducing the difficulty of the deep model distillation method. At the same time, the tasks in each stage of training the simplified BERT model remain consistent, which makes the convergence of the simplified BERT model more stable.
  • Referring to FIG. 2, a flowchart of the implementation of step S104 in FIG. 1 is shown. For the convenience of description, only the parts related to the present application are shown.
  • Step S104 specifically includes: step S201, step S202, and step S203.
  • In step S201, a grouping operation is performed on the transformer layers of the original BERT model based on the distillation coefficient to obtain grouped transformer layers.
  • the grouping operation refers to grouping the transformer layers according to the distillation coefficient. For example, if the number of transformer layers is 12 and the distillation coefficient is 3, the grouping operation divides the 12 transformer layers into 4 groups.
  • In step S202, extraction operations are respectively performed on the grouped transformer layers based on the Bernoulli distribution to obtain the transformer layers to be replaced.
  • In step S203, the transformer layers to be replaced are respectively substituted into the default reduced model to obtain the intermediate reduced model.
  • in the embodiment of the present application, the distillation method based on layer replacement retains the same model structure as BERT; the difference is the number of layers, so the amount of code change is small and the prediction codes of the large model and the small model are consistent,
  • which means the original code can be reused. Because, during distillation, some layers of the small model are initialized with the weights of the mapped layers of the trained large model selected by Bernoulli sampling, the model converges faster and the number of training rounds is reduced.
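  • a minimal sketch of the grouping and Bernoulli-based extraction of steps S201 to S203 might look as follows (illustrative only; the uniform selection probability within each group and the fallback to the last layer are assumptions, not the disclosed procedure):

```python
import random

def select_layers_to_replace(num_layers: int, distillation_coefficient: int, seed: int = 0):
    """Group the original transformer layers and draw one layer per group.

    Each layer in a group is subjected to a Bernoulli trial; the first layer
    that succeeds is selected, so every group contributes one layer whose
    weights initialize the corresponding reduced-model layer.
    """
    rng = random.Random(seed)
    layers = list(range(1, num_layers + 1))
    groups = [layers[i:i + distillation_coefficient]
              for i in range(0, num_layers, distillation_coefficient)]
    selected = []
    for group in groups:
        chosen = None
        for layer in group:
            if rng.random() < 1.0 / len(group):  # Bernoulli trial per layer
                chosen = layer
                break
        if chosen is None:                       # fallback: last layer in the group
            chosen = group[-1]
        selected.append(chosen)
    return selected

# 12 layers, coefficient 3 -> one selected layer per group of three.
print(select_layers_to_replace(12, 3))
```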
  • Referring to FIG. 3, a flowchart of the implementation of step S105 in FIG. 1 is shown. For the convenience of description, only the parts related to the present application are shown.
  • Step S105 specifically includes: step S301, step S302, step S303, step S304, and step S305.
  • In step S301, the original training data used to train the original BERT model is obtained.
  • the original training data refers to the training data that was input into the untrained original BERT model in order to obtain the trained original BERT model.
  • In step S302, the temperature parameter of the softmax layer of the original BERT model is increased to obtain an increased BERT model.
  • In step S303, the original training data is input into the increased BERT model for a prediction operation, and the mean result label is obtained.
  • each piece of original training data can obtain its final classification probability vector from the original BERT model, and selecting the maximum probability gives the model's judgment result for the current original training data.
  • for each piece of original training data, t probability vectors can be output, and the average of the t probability vectors can be calculated as the final probability output vector of the current original training data; after all the original training data have completed the prediction operation, the corresponding mean result labels are obtained.
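  • the mean result label described above can be illustrated with the following sketch (the temperature value, the logits, and the number of runs t are hypothetical placeholders, not values from the disclosure):

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Soften the output distribution by dividing the logits by a temperature > 1."""
    scaled = np.asarray(logits, dtype=float) / temperature
    exp = np.exp(scaled - scaled.max())
    return exp / exp.sum()

# Hypothetical logits produced by the increased BERT model for one training
# sample across t = 3 prediction runs.
logit_runs = [[2.0, 0.5, -1.0], [1.8, 0.7, -0.9], [2.2, 0.4, -1.1]]

temperature = 4.0  # raised temperature of the softmax layer
prob_vectors = [softmax_with_temperature(l, temperature) for l in logit_runs]

# Mean result label: the average of the t probability vectors.
mean_result_label = np.mean(prob_vectors, axis=0)
print(mean_result_label)
```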
  • In step S304, a screening operation is performed on the original training data based on the label information to obtain the labelled screening result label.
  • In step S305, the training data of the reduced model is selected based on the enlarged training data and the screened training data.
  • the label finally selected for the training data of the reduced model can be expressed as follows:
  • Target represents the label that is finally used as the training data of the intermediate reduced model;
  • hard_target represents the screening result label;
  • soft_target represents the mean result label;
  • a and b represent the weights controlling label fusion.
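  • the fusion formula itself is not reproduced in this text; given the variables listed above, a weighted combination of the two labels is the natural reading, and the following hedged Python sketch shows that assumed form (the weights and label values are placeholders):

```python
import numpy as np

# Hypothetical labels for a single training sample.
hard_target = np.array([0.0, 1.0, 0.0])   # screening result label (one-hot)
soft_target = np.array([0.2, 0.7, 0.1])   # mean result label from the increased model

a, b = 0.5, 0.5   # weights controlling label fusion (assumed values)

# Assumed fusion consistent with the variable definitions above:
# Target = a * hard_target + b * soft_target
target = a * hard_target + b * soft_target
print(target)
```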
  • Referring to FIG. 4, a flowchart of the implementation of the parameter optimization operation provided in Embodiment 1 of the present application is shown. For the convenience of description, only the part related to the present application is shown.
  • the foregoing method further includes: step S401, step S402, step S403, and step S404.
  • In step S401, the optimized training data is obtained from the local database.
  • the optimized training data is mainly used to optimize the parameters of the target reduced model.
  • the optimized training data is input into the trained original BERT model and the target reduced model respectively, so as to obtain the difference between the output of each transformer layer of the original BERT model and that of the target reduced model.
  • In step S402, the optimized training data is input into the trained original BERT model and the target reduced model respectively, and the original transformer layer output data and the target transformer layer output data are obtained respectively.
  • In step S403, the distillation loss data between the output data of the original transformer layer and the output data of the target transformer layer is calculated based on the earth mover's distance (EMD).
  • the earth mover's distance is a measure of the distance between two probability distributions over a region D.
  • the attention matrix data output by the original transformer layer and the target transformer layer can be obtained respectively, and the attention EMD distance between the two attention matrices can be calculated; then the FFN (fully connected feedforward neural network) hidden layer matrices output by the original transformer layer and the target transformer layer can be obtained respectively, and the FFN hidden layer EMD distance can be calculated.
  • In step S404, a parameter optimization operation is performed on the target reduced model according to the distillation loss data to obtain an optimized reduced model.
  • the parameters in the target reduced model are optimized until the distillation loss data is less than a preset value or the number of training rounds reaches a preset number, so as to obtain the optimized reduced model.
  • because the transformer layers of the target reduced model are selected based on Bernoulli distribution probabilities, there is a certain error in the parameters of the target reduced model. Since the transformer layer in the BERT model contributes the most to the model and contains the richest information, the learning ability of the reduced model at this layer is also the most important. Therefore, the loss data between the output of the transformer layer of the original BERT model and the output of the transformer layer of the target reduced model is calculated using the earth mover's distance (EMD), and the parameters of the target reduced model are optimized based on this loss data to improve the accuracy of the target reduced model, which ensures that the target model learns more of the knowledge of the original model.
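  • the stopping criterion described above (loss below a preset value or a preset number of training rounds reached) can be sketched as follows (a minimal illustration; the threshold, round count, and placeholder training step are assumptions, not the disclosed values):

```python
def optimize_reduced_model(step_fn, max_rounds=10, loss_threshold=0.01):
    """Optimize until the distillation loss falls below a preset value or the
    preset number of training rounds is reached.

    step_fn is assumed to perform one optimization step and return the
    current distillation loss (a placeholder for the real training step).
    """
    loss = float("inf")
    for _ in range(max_rounds):
        loss = step_fn()
        if loss < loss_threshold:
            break
    return loss

# Hypothetical training step that halves the loss on each round.
state = {"loss": 1.0}
def fake_step():
    state["loss"] *= 0.5
    return state["loss"]

print(optimize_reduced_model(fake_step))
```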
  • Referring to FIG. 5, a flowchart of the implementation of step S403 in FIG. 4 is shown. For convenience of description, only the part related to the present application is shown.
  • Step S403 specifically includes: step S501, step S502, step S503, step S504, and step S505.
  • In step S501, the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer are obtained.
  • In step S502, the attention EMD distance is calculated according to the original attention matrix and the target attention matrix.
  • the attention EMD distance is expressed as:
  • L_attn represents the attention EMD distance;
  • A^T represents the original attention matrix;
  • A^S represents the target attention matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • In step S503, the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer are obtained.
  • In step S504, the FFN hidden layer EMD distance is calculated according to the original FFN hidden layer matrix and the target FFN hidden layer matrix.
  • the EMD distance of the FFN hidden layer is expressed as:
  • L_ffn represents the FFN hidden layer EMD distance;
  • H^T represents the original FFN hidden layer matrix of the original transformer layer;
  • H^S represents the target FFN hidden layer matrix of the target transformer layer;
  • W_h represents the transformation matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • In step S505, the distillation loss data is obtained based on the attention EMD distance and the FFN hidden layer EMD distance.
  • the transformer layer is an important part of the BERT model, and long-distance dependencies can be captured through the self-attention mechanism.
  • a standard transformer layer mainly includes two parts: a multi-head attention mechanism (MHA) and a fully connected feedforward neural network (FFN).
  • EMD is a method of calculating the optimal distance between two distributions using linear programming, which can make the distillation of knowledge more reasonable.
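  • to make the role of the linear-programming-based EMD concrete, the following Python sketch (an illustration only, not the implementation of the present disclosure; the cost matrix and layer weights are placeholder values, and scipy is assumed to be available) solves the transport problem between M original layers and N target layers and returns the minimal total transport cost:

```python
import numpy as np
from scipy.optimize import linprog

def emd_distance(cost, weights_t, weights_s):
    """Earth mover's distance between two weighted sets of layers.

    cost[i, j] is the per-pair transport cost (e.g. a distance between the
    attention matrix of original layer i and that of target layer j);
    weights_t and weights_s are the normalized layer weights.
    """
    m, n = cost.shape
    c = cost.flatten()
    a_eq, b_eq = [], []
    for i in range(m):                       # each original layer ships all of its weight
        row = np.zeros(m * n)
        row[i * n:(i + 1) * n] = 1.0
        a_eq.append(row)
        b_eq.append(weights_t[i])
    for j in range(n):                       # each target layer receives all of its weight
        row = np.zeros(m * n)
        row[j::n] = 1.0
        a_eq.append(row)
        b_eq.append(weights_s[j])
    result = linprog(c, A_eq=np.array(a_eq), b_eq=np.array(b_eq),
                     bounds=(0, None), method="highs")
    return result.fun                        # minimal total transport cost

# Hypothetical costs between M = 4 original layers and N = 2 target layers.
rng = np.random.default_rng(0)
cost = rng.random((4, 2))
weights_t = np.full(4, 1 / 4)
weights_s = np.full(2, 1 / 2)
print(emd_distance(cost, weights_t, weights_s))
```

  • in the distillation setting described here, cost[i, j] would play the role of the distance between the attention matrices (or the FFN hidden layer matrices) of original layer i and target layer j, and the resulting minimum would play the role of the attention or FFN hidden layer EMD distance.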
  • the attention EMD distance is expressed as:
  • L_attn represents the attention EMD distance;
  • A^T represents the original attention matrix;
  • A^S represents the target attention matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • the FFN hidden layer EMD distance is expressed as:
  • L_ffn represents the FFN hidden layer EMD distance;
  • H^T represents the original FFN hidden layer matrix of the original transformer layer;
  • H^S represents the target FFN hidden layer matrix of the target transformer layer;
  • W_h represents the transformation matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • Embodiment 1 of the present application provides a distillation method applied to a BERT model: receiving a model distillation request sent by a user terminal, the model distillation request carrying at least a distillation object identifier and a distillation coefficient; obtaining, from a local database, the trained original BERT model corresponding to the distillation object identifier, the loss function of the original BERT model being cross entropy; constructing a default reduced model to be trained that is consistent with the structure of the trained original BERT model, the loss function of the default reduced model being cross entropy; performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; obtaining the training data of the intermediate reduced model from the local database; and performing a model training operation on the intermediate reduced model based on the training data to obtain the target reduced model.
  • because the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused;
  • during the distillation process there is no need to balance the weights of each loss parameter, thereby reducing the difficulty of the deep model distillation method;
  • the tasks in each stage of training the simplified BERT model remain consistent, which makes the convergence of the simplified BERT model more stable.
  • the distillation method based on layer replacement retains the same model structure as BERT.
  • the difference is the number of layers, which makes the code changes smaller, and the prediction codes of the large model and the small model are consistent, so the original code can be reused; during distillation, some layers of the small model are initialized with the weights of the mapped layers of the trained large model selected by Bernoulli sampling, which makes the model converge faster and reduces the number of training rounds.
  • the aforementioned storage medium may be a non-volatile storage medium such as a magnetic disk, an optical disk, a read-only memory (Read-Only Memory, ROM), or a random access memory (Random Access Memory, RAM) or the like.
  • the present application provides an embodiment of a distillation apparatus applied to a BERT model, and the apparatus embodiment corresponds to the method embodiment shown in FIG. 1. Specifically, the apparatus can be applied to various electronic devices.
  • the distillation apparatus 100 applied to the BERT model in this embodiment includes: a request receiving module 110, an original model acquisition module 120, a default model building module 130, a distillation operation module 140, a training data acquisition module 150, and a model training module 160, wherein:
  • the request receiving module 110 is configured to receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
  • the original model obtaining module 120 is used to read the local database, and obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
  • the default model building module 130 is used to construct a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
  • a distillation operation module 140 configured to perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model
  • a training data acquisition module 150 configured to acquire the training data of the intermediate reduced model in the local database
  • the model training module 160 is configured to perform a model training operation on the intermediate reduced model based on the training data to obtain the target reduced model.
  • a user terminal refers to a terminal device used to execute the distillation method applied to the BERT model provided by the present application.
  • the user terminal may be, for example, a mobile terminal such as a mobile phone, a smartphone, a notebook computer, a digital broadcast receiver, a PDA (personal digital assistant), a PAD (tablet computer), a PMP (portable multimedia player) or a navigation device, or a stationary terminal such as a digital TV or a desktop computer.
  • the examples are only for the convenience of understanding, and are not used to limit the present application.
  • the distillation object identifier is mainly used to uniquely identify the model object that needs to be distilled.
  • the distillation object identifier may be named based on the model name, for example, a visual recognition model, a speech recognition model, etc.; it may also be named based on the abbreviation of the name, for example: sjsbmx, yysbmx, etc.; or it may be named by a serial number, for example: 001, 002, etc.
  • it should be understood that the examples of distillation object identifiers here are only for convenience of understanding and are not intended to limit the present application.
  • the distillation coefficient is mainly used to determine the factor by which the number of layers of the original BERT model is reduced.
  • as an example, the distillation coefficient may be 3. It should be understood that the examples of distillation coefficients here are only for convenience of understanding and are not intended to limit the present application.
  • the local database refers to a database resident on a machine running a client application.
  • the local database provides the fastest response time because there is no network transfer between the client (application) and the server.
  • the local database pre-stores a variety of trained original BERT models to solve problems in many fields such as computer vision and speech recognition.
  • the BERT model can be divided into an embedding layer, a transformer layer, and a prediction layer, each of which is a different representation of knowledge.
  • the original BERT model consists of a 12-layer transformer (a model based on an "encoder-decoder" structure), and the original BERT model uses cross-entropy as the loss function.
  • the cross entropy is mainly used to measure the difference information between two probability distributions.
  • the performance of language models is usually measured by cross-entropy and perplexity.
  • the meaning of cross-entropy is the difficulty of text recognition with the model, or from a compression point of view, how many bits are used to encode each word on average.
  • the meaning of perplexity is the average number of branches of the text as represented by the model, and its inverse can be regarded as the average probability of each word.
  • Smoothing refers to assigning a probability value to the unobserved N-gram combination to ensure that the word sequence can always obtain a probability value through the language model.
  • the constructed default reduced model retains the same model structure as BERT; the difference lies in the number of transformer layers.
  • the distillation operation specifically includes the distillation of the transformer layer and parameter initialization.
  • distilling the transformer layers means that, if the distillation coefficient is 3, the first to third layers of the trained original BERT model are mapped to the first layer of the default reduced model; the fourth to sixth layers of the trained original BERT model are mapped to the second layer of the default reduced model; the seventh to ninth layers of the trained original BERT model are mapped to the third layer of the default reduced model; and the tenth to twelfth layers are mapped to the fourth layer of the default reduced model.
  • the probability of each layer being replaced may be determined by using the Bernoulli distribution probability.
  • parameter initialization refers to copying the parameters of the embedding, pooler, and fully connected layers of the trained original BERT model into the corresponding parameter positions of the default reduced model.
  • the training data of the reduced model may be labeled data obtained by training the above-mentioned original BERT model, or may be additional unlabeled data.
  • the original training data used for training the original BERT model can be obtained; the temperature parameter of the softmax layer of the original BERT model can be increased to obtain an increased BERT model; the original training data can be input into the increased BERT model for a prediction operation to obtain the mean result label; a screening operation can be performed on the original training data based on the label information to obtain the labelled screening result label; and the training data of the reduced model can be selected based on the enlarged training data and the screened training data.
  • in this embodiment, a distillation apparatus applied to the BERT model is provided. Since the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small,
  • the prediction codes of the large model and the small model are consistent, and the original code can be reused, so that the model does not need to balance the weight of each loss parameter during the distillation process, thereby reducing the difficulty of the deep model distillation method;
  • at the same time, the tasks in each stage of training the simplified BERT model are all consistent, which makes the convergence of the simplified BERT model more stable.
  • the above-mentioned distillation operation module 140 specifically includes: a grouping operation sub-module, an extraction operation sub-module, and a replacement operation sub-module, wherein:
  • the grouping operation sub-module is used to group the transformer layer of the original BERT model based on the distillation coefficient to obtain the grouped transformer layer;
  • the extraction operation sub-module is used to perform extraction operations in the grouped transformer layers based on the Bernoulli distribution to obtain the transformer layers to be replaced;
  • the replacement operation sub-module is used to replace the transformer layer to be replaced with the default reduced model respectively to obtain the intermediate reduced model.
  • the above-mentioned training data acquisition module 150 specifically includes: an original training data acquisition sub-module, a parameter adjustment sub-module, a prediction operation sub-module, a screening operation sub-module, and a training data acquisition sub-module, wherein:
  • the original training data acquisition sub-module is used to obtain the original training data used to train the original BERT model;
  • the parameter adjustment sub-module is used to increase the temperature parameter of the softmax layer of the original BERT model to obtain an increased BERT model;
  • the prediction operation sub-module is used to input the original training data into the increased BERT model for a prediction operation, and obtain the mean result label;
  • the screening operation sub-module is used to perform the screening operation on the original training data based on the label information, and obtain the labelled screening result label;
  • the training data acquisition sub-module is used to select the reduced model training data based on the enlarged training data and the screened training data.
  • the above-mentioned distillation apparatus 100 applied to the BERT model further includes: an optimized training data acquisition module, an optimized training data input module, a distillation loss data calculation module, and a parameter optimization module, wherein:
  • the optimized training data acquisition module is used to obtain optimized training data in the local database
  • the optimized training data input module is used to input the optimized training data into the trained original BERT model and the target reduced model, respectively, to obtain the original transformer layer output data and the target transformer layer output data;
  • the distillation loss data calculation module is used to calculate the distillation loss data between the output data of the original transformer layer and the output data of the target transformer layer based on the earth mover's distance (EMD);
  • the parameter optimization module is used to optimize the parameters of the target reduced model according to the distillation loss data to obtain the optimized reduced model.
  • the above-mentioned distillation loss data calculation module specifically includes: a target attention matrix acquisition sub-module, an attention EMD distance calculation sub-module, a target FFN hidden layer matrix acquisition sub-module, an FFN hidden layer EMD distance calculation sub-module, and a distillation loss data acquisition sub-module, wherein:
  • the target attention matrix acquisition sub-module is used to obtain the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer;
  • the attention EMD distance calculation sub-module is used to calculate the attention EMD distance according to the original attention matrix and the target attention matrix
  • the target FFN hidden layer matrix acquisition sub-module is used to obtain the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer;
  • the FFN hidden layer EMD distance calculation sub-module is used to calculate the FFN hidden layer EMD distance according to the original FFN hidden layer matrix and the target FFN hidden layer matrix;
  • the distillation loss data acquisition sub-module is used to obtain distillation loss data based on the attention EMD distance and the FFN hidden layer EMD distance.
  • the attention EMD distance is expressed as:
  • L_attn represents the attention EMD distance;
  • A^T represents the original attention matrix;
  • A^S represents the target attention matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • the FFN hidden layer EMD distance is expressed as:
  • L_ffn represents the FFN hidden layer EMD distance;
  • H^T represents the original FFN hidden layer matrix of the original transformer layer;
  • H^S represents the target FFN hidden layer matrix of the target transformer layer;
  • W_h represents the transformation matrix;
  • f_ij represents the amount of knowledge migrated from the i-th original transformer layer to the j-th target transformer layer;
  • M represents the number of layers of the original transformer;
  • N represents the number of layers of the target transformer.
  • the second embodiment of the present application provides a distillation apparatus applied to the BERT model. Since the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused, so that the model does not need to balance the weight of each loss parameter in the process of distillation, thereby reducing the difficulty of the deep model distillation method; at the same time, the tasks of each stage of training the simplified BERT model are kept consistent, which makes the convergence of the simplified BERT model more stable. In addition, the distillation method based on layer replacement retains the same model structure as BERT;
  • the difference is the number of layers, which makes the code changes smaller, and the prediction codes of the large model and the small model are consistent, so the original code can be reused; during distillation, some layers of the small model are initialized with the weights of the mapped layers of the trained large model selected by Bernoulli sampling, which makes the model converge faster and reduces the number of training rounds.
  • FIG. 7 is a block diagram of the basic structure of a computer device according to this embodiment.
  • the computer device 200 includes a memory 210, a processor 220, and a network interface 230 that communicate with each other through a system bus. It should be noted that only the computer device 200 with components 210-230 is shown in the figure, but it should be understood that implementation of all of the shown components is not required, and more or fewer components may be implemented instead.
  • the computer device here is a device that can automatically perform numerical calculation and/or information processing according to preset or stored instructions, and its hardware includes but is not limited to microprocessors, application-specific integrated circuits (ASIC), field-programmable gate arrays (FPGA), digital signal processors (DSP), embedded devices, etc.
  • the computer equipment may be a desktop computer, a notebook computer, a palmtop computer, a cloud server and other computing equipment.
  • the computer device can perform human-computer interaction with the user through a keyboard, a mouse, a remote control, a touch pad or a voice control device.
  • the memory 210 includes at least one type of readable storage medium, including flash memory, hard disk, multimedia card, card-type memory (eg, SD or DX memory, etc.), random access memory (RAM), static Random Access Memory (SRAM), Read Only Memory (ROM), Electrically Erasable Programmable Read Only Memory (EEPROM), Programmable Read Only Memory (PROM), magnetic memory, magnetic disks, optical disks, etc.
  • the computer-readable storage medium can be non-volatile or volatile.
  • the memory 210 may be an internal storage unit of the computer device 200 , such as a hard disk or a memory of the computer device 200 .
  • the memory 210 may also be an external storage device of the computer device 200, such as a plug-in hard disk, a smart memory card (Smart Media Card, SMC), a secure digital (Secure Digital, SD) card, flash memory card (Flash Card), etc.
  • the memory 210 may also include both the internal storage unit of the computer device 200 and its external storage device.
  • the memory 210 is generally used to store the operating system and various application software installed on the computer device 200 , such as computer-readable instructions applied to the distillation method of the BERT model.
  • the memory 210 can also be used to temporarily store various types of data that have been output or will be output.
  • the processor 220 may be a central processing unit (Central Processing Unit, CPU), a controller, a microcontroller, a microprocessor, or other data processing chips in some embodiments.
  • the processor 220 is typically used to control the overall operation of the computer device 200 .
  • the processor 220 is configured to execute the computer-readable instructions stored in the memory 210 or process data, for example, the computer-readable instructions for executing the distillation method applied to the BERT model.
  • the network interface 230 may include a wireless network interface or a wired network interface, and the network interface 230 is generally used to establish a communication connection between the computer device 200 and other electronic devices.
  • the steps of the above distillation method applied to the BERT model include:
  • receiving a model distillation request sent by the user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
  • a model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  • in the distillation method applied to the BERT model provided by this application, because the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused, so that the model does not need to balance the weight of each loss parameter in the process of distillation, thereby reducing the difficulty of the deep model distillation method; at the same time, the tasks in each stage of training the simplified BERT model remain consistent, making the convergence of the simplified BERT model more stable.
  • the present application also provides another embodiment, that is, a computer-readable storage medium is provided, where the computer-readable storage medium stores computer-readable instructions, and the computer-readable instructions can be executed by at least one processor to cause the at least one processor to perform the steps of the distillation method applied to the BERT model as follows:
  • receiving a model distillation request sent by the user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
  • a model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  • in the distillation method applied to the BERT model provided by this application, because the simplified BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction codes of the large model and the small model are consistent, and the original code can be reused, so that the model does not need to balance the weight of each loss parameter in the process of distillation, thereby reducing the difficulty of the deep model distillation method; at the same time, the tasks in each stage of training the simplified BERT model remain consistent, making the convergence of the simplified BERT model more stable.
  • the method of the above embodiment can be implemented by means of software plus a necessary general hardware platform, and of course can also be implemented by hardware, but in many cases the former is the better implementation.
  • the technical solution of the present application can be embodied in the form of a software product in essence or in a part that contributes to the prior art, and the computer software product is stored in a storage medium (such as ROM/RAM, magnetic disk, CD-ROM), including several instructions to make a terminal device (which may be a mobile phone, a computer, a server, an air conditioner, or a network device, etc.) execute the methods described in the various embodiments of this application.

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Pure & Applied Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Analysis (AREA)
  • Mathematical Optimization (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Computational Mathematics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Feedback Control In General (AREA)

Abstract

A distillation method and apparatus applied to a BERT model, a computer device, and a storage medium, which relate to the technical field of deep learning. In the method, because a refined BERT model retains the same model structure as a raw BERT model, the only difference being the number of layers, the amount of change in the code is relatively small. Moreover, the prediction codes of the large model and the small model are consistent and the source code may be reused, so that the weights of the loss parameters do not need to be balanced while a model is being distilled, thereby reducing the level of difficulty of the deep model distillation method. Meanwhile, the tasks of each stage of training the refined BERT model are consistent, so that the convergence of the refined BERT model is more stable.

Description

A distillation method, apparatus, device and storage medium applied to a BERT model
This application is based on the Chinese invention patent application No. 202011288877.7, filed on November 17, 2020 and entitled "A distillation method, apparatus, device and storage medium applied to a BERT model", and claims its priority.
Technical Field
The present application relates to the technical field of deep learning, and in particular, to a distillation method, apparatus, computer device and storage medium applied to a BERT model.
Background
In recent years, in many fields such as computer vision and speech recognition, people tend to design more complex networks and collect more data when using deep networks to solve problems, in order to obtain better results. However, the complexity of the model rises sharply as a result: the model has more and more parameters, its scale becomes larger and larger, and the hardware resources (memory, GPU) it requires become higher and higher, which is not conducive to deploying the model or extending its application to mobile terminals.
An existing deep model distillation method uses the advantages of the distillation model to match the data between the intermediate layers during model distillation, so as to achieve the purpose of compressing the model.
However, the applicant has realized that traditional deep model distillation methods are generally not intelligent: when matching the intermediate layer outputs during distillation, many loss parameters often need to be balanced, for example the downstream task loss, the intermediate layer output loss, the correlation matrix loss, the attention matrix loss, and so on, so that traditional deep model distillation methods have the problem that balancing the loss parameters is difficult.
Summary of the Invention
The purpose of the embodiments of the present application is to propose a distillation method, apparatus, computer device and storage medium applied to the BERT model, so as to solve the problem that it is difficult to balance the loss parameters in traditional deep model distillation methods.
In order to solve the above technical problem, an embodiment of the present application provides a distillation method applied to the BERT model, which adopts the following technical solution:
receiving a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
reading a local database, and obtaining, from the local database, a trained original BERT model corresponding to the distillation object identifier, where the loss function of the original BERT model is cross entropy;
constructing a default reduced model to be trained that is consistent with the structure of the trained original BERT model, where the loss function of the default reduced model is cross entropy;
performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
obtaining training data of the intermediate reduced model from the local database;
performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model.
In order to solve the above technical problem, an embodiment of the present application further provides a distillation apparatus applied to the BERT model, which adopts the following technical solution:
a request receiving module, configured to receive a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient;
an original model acquisition module, configured to read a local database and obtain, from the local database, a trained original BERT model corresponding to the distillation object identifier, where the loss function of the original BERT model is cross entropy;
a default model building module, configured to construct a default reduced model to be trained that is consistent with the structure of the trained original BERT model, where the loss function of the default reduced model is cross entropy;
a distillation operation module, configured to perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
a training data acquisition module, configured to obtain training data of the intermediate reduced model from the local database;
a model training module, configured to perform a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model.
为了解决上述技术问题,本申请实施例还提供一种计算机设备,采用了如下所述的技术方案:In order to solve the above-mentioned technical problems, the embodiment of the present application also provides a computer device, which adopts the following technical solutions:
包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如下所述的应用于BERT模型的蒸馏方法的步骤;comprising a memory and a processor, wherein computer-readable instructions are stored in the memory, and when the processor executes the computer-readable instructions, the steps of the distillation method applied to the BERT model as described below are implemented;
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,采用了如下所述的技术方案:In order to solve the above technical problems, the embodiments of the present application also provide a computer-readable storage medium, which adopts the following technical solutions:
所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如下所述的应用于BERT模型的蒸馏方法的步骤:The computer-readable storage medium stores computer-readable instructions, and when the computer-readable instructions are executed by the processor, implements the steps of the distillation method applied to the BERT model as described below:
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
与现有技术相比,本申请实施例提供的应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质主要有以下有益效果:Compared with the prior art, the distillation method, device, computer equipment and storage medium applied to the BERT model provided by the embodiments of the present application mainly have the following beneficial effects:
本申请实施例提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;在所述本地数据库中获取所述中间精简模型的训练数据;基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。The embodiment of the present application provides a distillation method applied to a BERT model: receiving a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient; reading a local database and obtaining, from the local database, a trained original BERT model corresponding to the distillation object identifier, the loss function of the original BERT model being cross entropy; constructing a default reduced model to be trained whose structure is consistent with that of the trained original BERT model, the loss function of the default reduced model being cross entropy; performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; obtaining training data of the intermediate reduced model from the local database; and performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model. Since the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, and the prediction code of the large model and the small model is the same, so the original code can be reused. As a result, there is no need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep model distillation method; at the same time, the tasks at each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably.
附图说明Description of drawings
为了更清楚地说明本申请中的方案,下面将对本申请实施例描述中所需要使用的附图作一个简单介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。In order to illustrate the solutions in the present application more clearly, the following will briefly introduce the accompanying drawings used in the description of the embodiments of the present application. For those of ordinary skill, other drawings can also be obtained from these drawings without any creative effort.
图1是本申请实施例一提供的应用于BERT模型的蒸馏方法的实现流程图;Fig. 1 is the realization flow chart of the distillation method applied to the BERT model provided by the first embodiment of the present application;
图2是图1中步骤S104的实现流程图;Fig. 2 is the realization flow chart of step S104 in Fig. 1;
图3是图1中步骤S105的实现流程图;Fig. 3 is the realization flow chart of step S105 in Fig. 1;
图4是本申请实施例一提供的参数优化操作的实现流程图;Fig. 4 is the realization flow chart of the parameter optimization operation provided by Embodiment 1 of the present application;
图5是图4中步骤S403的实现流程图;Fig. 5 is the realization flow chart of step S403 in Fig. 4;
图6是本申请实施例二提供的应用于BERT模型的蒸馏装置的结构示意图;Fig. 6 is the structural representation of the distillation apparatus applied to the BERT model provided by the second embodiment of the present application;
图7是根据本申请的计算机设备的一个实施例的结构示意图。FIG. 7 is a schematic structural diagram of an embodiment of a computer device according to the present application.
具体实施方式DETAILED DESCRIPTION OF EMBODIMENTS
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同;本文中在申请的说明书中所使用的术语只是为了描述具体的实施例的目的,不是旨在于限制本申请;本申请的说明书和权利要求书及上述附图说明中的术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。本申请的说明书和权利要求书或上述附图中的术语“第一”、“第二”等是用于区别不同对象,而不是用于描述特定顺序。Unless otherwise defined, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the technical field of this application; the terms used herein in the specification of the application are for the purpose of describing specific embodiments only It is not intended to limit the application; the terms "comprising" and "having" and any variations thereof in the description and claims of this application and the above description of the drawings are intended to cover non-exclusive inclusion. The terms "first", "second" and the like in the description and claims of the present application or the above drawings are used to distinguish different objects, rather than to describe a specific order.
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。Reference herein to an "embodiment" means that a particular feature, structure, or characteristic described in connection with the embodiment can be included in at least one embodiment of the present application. The appearances of the phrase in various places in the specification are not necessarily all referring to the same embodiment, nor a separate or alternative embodiment that is mutually exclusive of other embodiments. It is explicitly and implicitly understood by those skilled in the art that the embodiments described herein may be combined with other embodiments.
为了使本技术领域的人员更好地理解本申请方案,下面将结合附图,对本申请实施例中的技术方案进行清楚、完整地描述。In order to make those skilled in the art better understand the solutions of the present application, the technical solutions in the embodiments of the present application will be described clearly and completely below with reference to the accompanying drawings.
实施例一Example 1
如图1所示,示出了根据本申请实施例一提供的应用于BERT模型的蒸馏方法的实现流程图,为了便于说明,仅示出与本申请相关的部分。As shown in FIG. 1 , it shows the implementation flow chart of the distillation method applied to the BERT model provided according to the first embodiment of the present application. For the convenience of description, only the part related to the present application is shown.
在步骤S101中,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数。In step S101, a model distillation request sent by a user terminal is received, where the model distillation request at least carries a distillation object identifier and a distillation coefficient.
在本申请实施例中,用户终端指的是用于执行本申请提供的应用于BERT模型的蒸馏方法的终端设备,该用户终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。In this embodiment of the present application, the user terminal refers to the terminal device used to execute the distillation method applied to the BERT model provided by the present application. The user terminal may be a mobile terminal such as a mobile phone, a smart phone, a notebook computer, a digital broadcast receiver, a PDA (personal digital assistant), a PAD (tablet computer), a PMP (portable multimedia player), or a navigation device, or a fixed terminal such as a digital TV or a desktop computer. It should be understood that the examples of user terminals here are only for ease of understanding and are not used to limit the present application.
在本申请实施例中,蒸馏对象标识主要用于唯一标识需要蒸馏的模型对象,该蒸馏对象标识可以是基于模型名称命名,作为示例,例如:视觉识别模型、语音识别模型等等;该蒸馏对象标识可以是基于名称简称进行命名,作为示例,例如:sjsbmx、yysbmx等等;该蒸馏对象标识还可以是序号进行命名,作为示例,例如:001、002等等,应当理解,此处对蒸馏对象标识的举例仅为方便理解,不用于限定本申请。In this embodiment of the present application, the distillation object identifier is mainly used to uniquely identify the model object that needs to be distilled. The distillation object identifier may be named after the model name, for example, a visual recognition model, a speech recognition model, and so on; it may be named after an abbreviated name, for example, sjsbmx, yysbmx, and so on; or it may be named with a serial number, for example, 001, 002, and so on. It should be understood that the examples of distillation object identifiers here are only for ease of understanding and are not used to limit the present application.
在本申请实施例中,蒸馏系数主要用于确认将原始BERT模型的层数缩小的倍数,作为示例,例如:需要将BERT模型从12层蒸馏至4层,那么该蒸馏系数则为3,应当理解,此处对蒸馏系数的举例仅为方便理解,不用于限定本申请。In the embodiment of this application, the distillation coefficient is mainly used to confirm the multiple of reducing the number of layers of the original BERT model. As an example, for example, if the BERT model needs to be distilled from 12 layers to 4 layers, then the distillation coefficient is 3, which should be It is understood that the examples of distillation coefficients here are only for convenience of understanding, and are not intended to limit the present application.
在步骤S102中,读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵。In step S102, the local database is read, and the trained original BERT model corresponding to the distillation object identifier is obtained in the local database, and the loss function of the original BERT model is cross entropy.
在本申请实施例中,本地数据库是指驻留于运行客户应用程序的机器的数据库。本地数据库提供最快的响应时间,因为在客户(应用程序)和服务器之间没有网络传输。该本地数据库预先存储有各式各样的训练好的原始BERT模型,以解决在计算机视觉、语音识别等诸多领域存在的问题。In this embodiment of the present application, the local database refers to a database residing on the machine that runs the client application. The local database provides the fastest response time because there is no network transmission between the client (application) and the server. The local database pre-stores a variety of trained original BERT models to solve problems in many fields such as computer vision and speech recognition.
在本申请实施例中,Bert模型可以分为向量(embedding)层、转换器(transformer)层和预测(prediction)层,每种层是知识的不同表示形式。该原始BERT模型由12层transformer(一种基于“encoder-decoder”结构的模型)组成,该原始BERT模型选用的是交叉熵作为损失函数。该交叉熵主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。复杂度的意义是用该模型表示这一文本平均的分支数,其倒数可视为每个词的平均概率。平滑是指对没观察到的N元组合赋予一个概率值,以保证词序列总能通过语言模型得到一个概率值。In this embodiment of the present application, the Bert model can be divided into a vector (embedding) layer, a transformer (transformer) layer, and a prediction (prediction) layer, each of which is a different representation of knowledge. The original BERT model consists of a 12-layer transformer (a model based on an "encoder-decoder" structure), and the original BERT model uses cross-entropy as the loss function. The cross entropy is mainly used to measure the difference information between two probability distributions. The performance of language models is usually measured by cross-entropy and perplexity. The meaning of cross-entropy is the difficulty of text recognition with the model, or from a compression point of view, how many bits are used to encode each word on average. The meaning of complexity is to use the model to represent the average number of branches of this text, and its inverse can be regarded as the average probability of each word. Smoothing refers to assigning a probability value to the unobserved N-gram combination to ensure that the word sequence can always obtain a probability value through the language model.
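As a brief, non-limiting illustration of the cross-entropy loss mentioned above, the following is a minimal Python/PyTorch sketch; the variable names logits and labels are placeholders assumed for illustration and do not come from the application.

    import torch
    import torch.nn.functional as F

    # Cross entropy between the model's predicted distribution and the true labels;
    # both the original BERT model and the reduced model use this as their loss function.
    logits = torch.randn(8, 2)            # assumed batch of 8 samples, 2 classes
    labels = torch.randint(0, 2, (8,))    # assumed ground-truth labels
    loss = F.cross_entropy(logits, labels)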
在步骤S103中,构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵。In step S103, a default reduced model to be trained that is consistent with the trained original BERT model structure is constructed, and the loss function of the default reduced model is cross entropy.
在本申请实施例中,构建出来的默认精简模型保留了与BERT相同的模型结构,不同之处在于transformer层的数量。In the embodiment of the present application, the constructed default reduced model retains the same model structure as BERT, the difference lies in the number of transformer layers.
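For illustration only, a minimal sketch of constructing such a default reduced model is given below, assuming PyTorch and the Hugging Face transformers library; the checkpoint name and label count are assumptions and are not part of the application.

    from transformers import BertConfig, BertForSequenceClassification

    # Teacher: a trained 12-layer original BERT model (checkpoint name is an assumption).
    teacher = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)

    # Student: identical structure, only the number of transformer layers differs (12 -> 4);
    # the loss function remains cross entropy.
    student_config = BertConfig.from_pretrained("bert-base-chinese", num_labels=2)
    student_config.num_hidden_layers = 4
    student = BertForSequenceClassification(student_config)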
在步骤S104中,基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型。In step S104, a distillation operation is performed on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model.
在本申请实施例中,蒸馏操作具体包括蒸馏transformer层以及参数初始化。In this embodiment of the present application, the distillation operation specifically includes the distillation of the transformer layer and parameter initialization.
在本申请实施例中,蒸馏transformer层指的是倘若蒸馏系数为3,那么训练好的原始BERT模型的第一至第三层将替换至默认精简模型的第一层;训练好的原始BERT模型的第四至第六层将替换至默认精简模型的第二层;训练好的原始BERT模型的第七至第九层将替换至默认精简模型的第三层;训练好的原始BERT模型的第十至第十二层将替换至默认精简模型的第四层。In this embodiment of the present application, distilling the transformer layers means that, if the distillation coefficient is 3, the first to third layers of the trained original BERT model are used to replace the first layer of the default reduced model; the fourth to sixth layers of the trained original BERT model are used to replace the second layer of the default reduced model; the seventh to ninth layers of the trained original BERT model are used to replace the third layer of the default reduced model; and the tenth to twelfth layers of the trained original BERT model are used to replace the fourth layer of the default reduced model.
在本申请实施例中,在进行蒸馏替换的过程中,可采用伯努利分布概率确定每一层被替换的概率。In this embodiment of the present application, in the process of distillation replacement, the probability of each layer being replaced may be determined by using the Bernoulli distribution probability.
在本申请实施例中,参数初始化指的是embedding、pooler、全连接层参数依据训练好的原始BERT模型中各层级的参数,替换至默认精简模型对应的参数位置。In the embodiment of the present application, parameter initialization refers to replacing the parameters of the embedding, pooler, and fully connected layers to the parameter positions corresponding to the default simplified model according to the parameters of each level in the trained original BERT model.
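Continuing the sketch above, the embedding, pooler and fully connected (classifier) parameters of the reduced model could, for example, be initialized from the trained teacher as follows; this is illustrative only, and the attribute names follow the Hugging Face BERT implementation rather than the application itself.

    # Copy embedding, pooler and classifier (fully connected) weights from teacher to student.
    student.bert.embeddings.load_state_dict(teacher.bert.embeddings.state_dict())
    student.bert.pooler.load_state_dict(teacher.bert.pooler.state_dict())
    student.classifier.load_state_dict(teacher.classifier.state_dict())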
在步骤S105中,在本地数据库中获取中间精简模型的训练数据。In step S105, the training data of the intermediate reduced model is obtained from the local database.
在本申请实施例中,精简模型训练数据可以采用训练上述原始BERT模型得到的有标签数据,也可以是额外的无标签数据。In the embodiment of the present application, the training data of the reduced model may be labeled data obtained by training the above-mentioned original BERT model, or may be additional unlabeled data.
在本申请实施例中,可获取原始BERT模型训练后的原始训练数据;调高原始BERT模型softmax层的温度参数,得到调高BERT模型,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签;基于标签信息在原始训练数据上进行筛选操作,得到带标签的筛选结果标签;基于放大训练数据以及筛选训练数据选取精简模型训练数据。In this embodiment of the present application, the original training data used to train the original BERT model may be obtained; the temperature parameter of the softmax layer of the original BERT model is raised to obtain a temperature-raised BERT model, and the original training data is input into the temperature-raised BERT model for a prediction operation to obtain mean-result labels; a screening operation is performed on the original training data based on the label information to obtain labeled screening-result labels; and the reduced-model training data is selected based on the amplified training data and the screened training data.
在步骤S106中,基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。In step S106, a model training operation is performed on the intermediate reduced model based on the training data to obtain the target reduced model.
在本申请实施例中,提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;在本地数据库中获取中间精简模型的训练数据;基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。In the embodiment of the present application, a distillation method applied to a BERT model is provided: receiving a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient; reading a local database and obtaining, from the local database, a trained original BERT model corresponding to the distillation object identifier, the loss function of the original BERT model being cross entropy; constructing a default reduced model to be trained whose structure is consistent with that of the trained original BERT model, the loss function of the default reduced model being cross entropy; performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; obtaining training data of the intermediate reduced model from the local database; and performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model. Since the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, and the prediction code of the large model and the small model is the same, so the original code can be reused; as a result, there is no need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep model distillation method, and at the same time the tasks at each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably.
继续参阅图2,示出了图1中步骤S104的实现流程图,为了便于说明,仅示出与本申请相关的部分。Continuing to refer to FIG. 2 , a flowchart of the implementation of step S104 in FIG. 1 is shown. For the convenience of description, only the parts related to the present application are shown.
在本申请实施例一的一些可选的实现方式中,上述步骤S104具体包括:步骤S201、步骤S202以及步骤S203。In some optional implementation manners of Embodiment 1 of the present application, the foregoing step S104 specifically includes: step S201 , step S202 and step S203 .
在步骤S201中,基于蒸馏系数对原始BERT模型的transformer层进行分组操作,得到分组transformer层。In step S201, a grouping operation is performed on the transformer layer of the original BERT model based on the distillation coefficient to obtain a grouped transformer layer.
在本申请实施例中,分组操作指的是transformer层数按照蒸馏系数进行分组,作为示例,例如:transformer层数为12,蒸馏系数为3,分组操作则将12个transformer层划分成4组。In the embodiment of this application, the grouping operation refers to that the number of transformer layers is grouped according to the distillation coefficient. For example, for example, the number of transformer layers is 12 and the distillation coefficient is 3. The grouping operation divides the 12 transformer layers into 4 groups.
在步骤S202中,基于伯努利分布分别在分组transformer层中进行提取操作,得到待替换transformer层。In step S202, extraction operations are respectively performed in the grouped transformer layers based on the Bernoulli distribution to obtain the transformer layers to be replaced.
在本申请实施例中,伯努利分布指的是:对于参数为p(0<p<1)的随机变量X,它分别以概率p和1-p取1和0为值,EX=p,DX=p(1-p)。伯努利试验成功的次数服从伯努利分布,参数p是试验成功的概率。伯努利分布是一个离散型概率分布,是N=1时二项分布的特殊情况。In this embodiment of the present application, the Bernoulli distribution means that a random variable X with parameter p (0<p<1) takes the value 1 with probability p and the value 0 with probability 1-p, so that E(X)=p and D(X)=p(1-p). The number of successes of a Bernoulli trial follows the Bernoulli distribution, where the parameter p is the probability of success. The Bernoulli distribution is a discrete probability distribution and is the special case of the binomial distribution when N=1.
在步骤S203中,将待替换transformer层分别替换至默认精简模型,得到中间精简模型。In step S203, the transformer layers to be replaced are respectively replaced with default reduced models to obtain intermediate reduced models.
在本申请实施例中,基于层替换的蒸馏方式,保留了与BERT相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,由于蒸馏时,小模型的部分层基于伯努利采样,随机初始化成训练好的大模型映射层的权重,使模型收敛更快,减少训练轮数。In the embodiment of the present application, the distillation method based on layer replacement retains the same model structure as BERT, the difference is the number of layers, so that the amount of code changes is small, and the prediction codes of the large model and the small model are consistent, The original code can be reused, because during distillation, some layers of the small model are randomly initialized to the weight of the trained large model mapping layer based on Bernoulli sampling, which makes the model converge faster and reduces the number of training rounds.
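The following is a minimal sketch of the grouping, Bernoulli-based extraction and layer replacement described in steps S201 to S203, continuing the teacher/student sketch above; the sampling probability p and the fallback rule are assumptions, since the exact sampling scheme is not spelled out here.

    import torch

    coeff, p = 3, 0.5                                 # distillation coefficient and an assumed Bernoulli p
    teacher_layers = teacher.bert.encoder.layer       # 12 original transformer layers
    student_layers = student.bert.encoder.layer       # 4 target transformer layers

    for j in range(len(student_layers)):
        group = list(range(j * coeff, (j + 1) * coeff))   # e.g. teacher layers 0-2 map to student layer 0
        # One Bernoulli trial per candidate layer; fall back to the last layer of the group.
        chosen = next((i for i in group if torch.bernoulli(torch.tensor(p)).item() == 1.0), group[-1])
        student_layers[j].load_state_dict(teacher_layers[chosen].state_dict())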
继续参阅图3,示出了图1中步骤S105的实现流程图,为了便于说明,仅示出与本申请相关的部分。Continuing to refer to FIG. 3 , a flowchart of the implementation of step S105 in FIG. 1 is shown. For the convenience of description, only the parts related to the present application are shown.
在本申请实施例一的一些可选的实现方式中,上述步骤S105具体包括:步骤S301、步骤S302、步骤S303、步骤S304以及步骤S305。In some optional implementation manners of Embodiment 1 of the present application, the foregoing step S105 specifically includes: step S301 , step S302 , step S303 , step S304 and step S305 .
在步骤S301中,获取原始BERT模型训练后的原始训练数据。In step S301, the original training data after the original BERT model training is obtained.
在本申请实施例中,原始训练数据指的是在获得训练后的原始BERT模型之前,将训练数据输入至未训练的原始BERT模型的训练数据。In the embodiment of the present application, the original training data refers to the training data of inputting the training data into the untrained original BERT model before obtaining the trained original BERT model.
在步骤S302中,调高原始BERT模型softmax层的温度参数,得到调高BERT模型。In step S302, the temperature parameter of the softmax layer of the original BERT model is raised to obtain a temperature-raised BERT model.
在本申请实施例中,可将温度参数T调高至一个较大值,作为示例,例如:T=20,应当理解,此处对调高温度参数的举例仅为方便理解,不用于限定本申请。In the embodiment of the present application, the temperature parameter T can be increased to a larger value, for example, T=20. It should be understood that the example of increasing the temperature parameter here is only for convenience and is not intended to limit the present application .
在步骤S303中,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签。In step S303, the original training data is input into the temperature-raised BERT model for a prediction operation to obtain mean-result labels.
在本申请实施例中,每一个原始训练数据在每一个原始BERT模型可以得到其最终的分类概率向量,选取其中概率值最大的即为该模型对于当前原始训练数据的判定结果。对于t个原始BERT模型就可以输出t个概率向量,然后对t个概率向量求取均值作为当前原始训练数据最后的概率输出向量,当所有原始训练数据完成预测操作之后,得到该原始训练数据对应的均值结果标签。In this embodiment of the present application, each piece of original training data obtains its final classification probability vector from each original BERT model, and the class with the largest probability is taken as that model's judgment result for the current original training data. For t original BERT models, t probability vectors are output, and the mean of the t probability vectors is taken as the final probability output vector of the current original training data. After the prediction operation is completed for all the original training data, the mean-result labels corresponding to the original training data are obtained.
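A minimal sketch of the temperature-raised prediction and averaging described above is given below (PyTorch); teachers and batch are assumed placeholders, and T = 20 follows the example given earlier.

    import torch
    import torch.nn.functional as F

    T = 20.0                                            # raised softmax temperature
    with torch.no_grad():
        # One softened probability vector per teacher model, then the element-wise mean.
        probs = [F.softmax(model(**batch).logits / T, dim=-1) for model in teachers]
    soft_target = torch.stack(probs).mean(dim=0)        # the mean-result (soft) labels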
在步骤S304中,基于标签信息在原始训练数据进行筛选操作,得到带标签的筛选结果标签。In step S304, a screening operation is performed on the original training data based on the label information to obtain a labelled screening result label.
在本申请实施例中,由于在训练原始BERT模型时,会对部分样本数据附上标签数据,为获得有映射关系的训练数据,需要根据是否携带标签数据为条件对原始训练数据进行筛选操作,以得到有映射关系的训练数据,作为该筛选结果标签。In the embodiment of the present application, since label data will be attached to some sample data when training the original BERT model, in order to obtain training data with a mapping relationship, it is necessary to perform a screening operation on the original training data according to whether the label data is carried as a condition, In order to obtain the training data with the mapping relationship as the label of the screening result.
在步骤S305中,基于放大训练数据以及筛选训练数据选取精简模型训练数据。In step S305, the reduced model training data is selected based on the enlarged training data and the filtered training data.
在本申请实施例中,选取到的精简模型训练数据可表示为:In the embodiment of the present application, the selected training data of the reduced model can be expressed as:
Target = a * hard_target + b * soft_target, (a + b = 1)
其中,Target表示最终作为中间精简模型训练数据的标签;hard_target表示筛选结果标签;soft_target表示均值结果标签;a、b表示控制标签融合的权重。Among them, Target represents the label that is finally used as the training data of the intermediate reduced model; hard_target represents the label of the screening result; soft_target represents the label of the mean result; a and b represent the weight of the control label fusion.
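For illustration, the label fusion above could be computed as follows; the weights a and b are example values rather than values from the application, hard_target is assumed to be the one-hot screening-result label, and soft_target is the mean-result label from the sketch above.

    a, b = 0.7, 0.3                              # control weights, a + b = 1 (assumed example values)
    target = a * hard_target + b * soft_target   # final training label for the intermediate reduced model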
继续参阅图4,示出了本申请实施例一提供的参数优化操作的实现流程图,为了便于说明,仅示出与本申请相关的部分。Continuing to refer to FIG. 4 , a flowchart for realizing the parameter optimization operation provided in Embodiment 1 of the present application is shown. For the convenience of description, only the part related to the present application is shown.
在本申请实施例一的一些可选的实现方式中,在上述步骤S106之后,上述方法还包括:步骤S401、步骤S402、步骤S403以及步骤S404。In some optional implementation manners of Embodiment 1 of the present application, after the foregoing step S106, the foregoing method further includes: step S401, step S402, step S403, and step S404.
在步骤S401中,在本地数据库中获取优化训练数据。In step S401, the optimized training data is obtained from the local database.
在本申请实施例中,优化训练数据主要用于优化目标精简模型的参数,该优化训练数据分别输入至训练好的原始BERT模型和目标精简模型,在保证输入数据一致的前提下,可获知原始BERT模型和目标精简模型各个transformer层输出的差异。In this embodiment of the present application, the optimization training data is mainly used to optimize the parameters of the target reduced model. The optimization training data is input into the trained original BERT model and the target reduced model respectively, and on the premise that the input data is the same, the difference between the outputs of each transformer layer of the original BERT model and of the target reduced model can be obtained.
在步骤S402中,将优化训练数据分别输入至训练好的原始BERT模型以及目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据。In step S402, the optimized training data is input into the trained original BERT model and the target reduced model respectively, and the original transformer layer output data and the target transformer layer output data are obtained respectively.
在步骤S403中,基于搬土距离计算原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据。In step S403, the distillation loss data of the output data of the original transformer layer and the output data of the target transformer layer are calculated based on the soil removal distance.
在本申请实施例中,搬土距离(EMD)是在一个区域D上两个概率分布之间的距离的度量。可分别获取原始transformer层和目标transformer层分别输出的attention(注意力)矩阵数据,并计算二者attention(注意力)矩阵数据的注意力EMD距离;再获取原始transformer层和目标transformer层分别输出的FFN(全连接前馈神经网络)隐层矩阵数据,并计算二者FFN隐层矩阵数据的FFN隐层EMD距离,以得到该蒸馏损失数据。In the embodiment of the present application, the earth removal distance (EMD) is a measure of the distance between two probability distributions on a region D. The attention (attention) matrix data output by the original transformer layer and the target transformer layer can be obtained respectively, and the attention EMD distance of the attention (attention) matrix data of the two can be calculated; then the original transformer layer and the target transformer layer output respectively. FFN (Fully Connected Feedforward Neural Network) hidden layer matrix data, and calculate the FFN hidden layer EMD distance of the two FFN hidden layer matrix data to obtain the distillation loss data.
在步骤S404中,根据蒸馏损失数据对目标精简模型进行参数优化操作,得到优化精简模型。In step S404, a parameter optimization operation is performed on the target reduced model according to the distillation loss data to obtain an optimized reduced model.
在本申请实施例中,在获知蒸馏损失数据(即原始transformer层输出数据以及目标transformer层输出数据的距离度量)后,对目标精简模型中的参数进行优化,直至蒸馏损失数据小于预设值,或者训练的次数满足预设次数,从而获得该优化精简模型。In this embodiment of the present application, after the distillation loss data (i.e., the distance metric between the output data of the original transformer layers and the output data of the target transformer layers) is obtained, the parameters of the target reduced model are optimized until the distillation loss data is smaller than a preset value or the number of training rounds reaches a preset number, thereby obtaining the optimized reduced model.
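A minimal sketch of this parameter optimization loop is shown below (PyTorch, continuing the teacher/student sketches above); distillation_loss, optimization_batches, max_epochs and loss_threshold are assumed placeholders standing in for the EMD-based loss and the preset values described here.

    import torch

    optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)
    for epoch in range(max_epochs):
        for batch in optimization_batches:
            loss = distillation_loss(teacher, student, batch)   # EMD-based loss, see below
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if loss.item() < loss_threshold:                        # stop once the preset value is reached
            break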
在本申请实施例中,由于目标精简模型的transformer层是基于伯努利分布概率进行选取的,从而导致该目标精简模型的参数存在一定的误差,由于Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,精简模型在该层的学习能力也最为重要,因此通过采用"搬土距离EMD"计算原始BERT模型transformer层的输出以及目标精简模型transformer层的输出之间的损失数据,并基于该损失数据对该目标精简模型的参数进行优化,以提高该目标精简模型的准确率,能够保证目标模型学习到更多的原始模型的知识。In this embodiment of the present application, since the transformer layers of the target reduced model are selected based on Bernoulli distribution probabilities, the parameters of the target reduced model contain a certain error. Because the transformer layers in the BERT model contribute the most to the model, contain the richest information, and are where the learning ability of the reduced model matters most, the loss data between the output of the transformer layers of the original BERT model and the output of the transformer layers of the target reduced model is calculated by using the earth mover's distance (EMD), and the parameters of the target reduced model are optimized based on this loss data to improve the accuracy of the target reduced model, which ensures that the target model learns more of the original model's knowledge.
继续参阅图5,示出了图4中步骤S403的实现流程图,为了便于说明,仅示出与本申请相关的部分。Continuing to refer to FIG. 5 , a flowchart of the implementation of step S403 in FIG. 4 is shown. For convenience of description, only the part related to the present application is shown.
在本申请实施例一的一些可选的实现方式中,上述步骤S403具体包括:步骤S501、步骤S502、步骤503、步骤S504以及步骤S505。In some optional implementation manners of Embodiment 1 of the present application, the foregoing step S403 specifically includes: step S501 , step S502 , step 503 , step S504 and step S505 .
在步骤S501中,获取原始transformer层输出的原始注意力矩阵以及目标transformer层输出的目标注意力矩阵。In step S501, the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer are obtained.
在步骤S502中,根据原始注意力矩阵以及目标注意力矩阵计算注意力EMD距离。In step S502, the attention EMD distance is calculated according to the original attention matrix and the target attention matrix.
在本申请实施例中,注意力EMD距离表示为:In this embodiment of the present application, the attention EMD distance is expressed as:
L_{attn} = \sum_{i=1}^{M} \sum_{j=1}^{N} f_{ij} \, \mathrm{MSE}\big(A_i^T, A_j^S\big)
其中,L_attn表示注意力EMD距离;A^T表示原始注意力矩阵,A^S表示目标注意力矩阵;MSE(A_i^T, A_j^S)表示第i层原始transformer层的原始注意力矩阵A_i^T与第j层目标transformer层的目标注意力矩阵A_j^S之间的均方误差;f_ij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。Here, L_attn denotes the attention EMD distance; A^T denotes the original attention matrices and A^S the target attention matrices; MSE(A_i^T, A_j^S) denotes the mean squared error between the original attention matrix A_i^T of the i-th original transformer layer and the target attention matrix A_j^S of the j-th target transformer layer; f_ij denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
在步骤S503中,获取原始transformer层输出的原始FFN隐层矩阵以及目标transformer层输出的目标FFN隐层矩阵。In step S503, the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer are obtained.
在步骤S504中,根据原始FFN隐层矩阵以及目标FFN隐层矩阵计算FFN隐层EMD距离。In step S504, the FFN hidden layer EMD distance is calculated according to the original FFN hidden layer matrix and the target FFN hidden layer matrix.
在本申请实施例中,FFN隐层EMD距离表示为:In this embodiment of the present application, the EMD distance of the FFN hidden layer is expressed as:
L_{ffn} = \sum_{i=1}^{M} \sum_{j=1}^{N} f_{ij} \, \mathrm{MSE}\big(H_j^S W_h, H_i^T\big)
其中,L_ffn表示FFN隐层EMD距离;H^T表示原始transformer层的原始FFN隐层矩阵,H^S表示目标transformer层的目标FFN隐层矩阵;W_h表示转换矩阵;MSE(H_j^S W_h, H_i^T)表示第j层目标transformer层的目标FFN隐层矩阵H_j^S经转换矩阵W_h变换后与第i层原始transformer层的原始FFN隐层矩阵H_i^T之间的均方误差;f_ij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。Here, L_ffn denotes the FFN hidden-layer EMD distance; H^T denotes the original FFN hidden-layer matrices of the original transformer layers and H^S the target FFN hidden-layer matrices of the target transformer layers; W_h denotes a transformation matrix; MSE(H_j^S W_h, H_i^T) denotes the mean squared error between the target FFN hidden-layer matrix H_j^S of the j-th target transformer layer, transformed by W_h, and the original FFN hidden-layer matrix H_i^T of the i-th original transformer layer; f_ij denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
在步骤S505中,基于注意力EMD距离以及FFN隐层EMD距离获得蒸馏损失数据。In step S505, the distillation loss data is obtained based on the attention EMD distance and the FFN hidden layer EMD distance.
在本申请实施例中,transformer层是Bert模型中的重要组成部分,通过自注意力机制可以捕获长距离依赖关系,一个标准的transformer主要包括两部分:多头注意力机制(Multi-Head Attention,MHA)和全连接前馈神经网络(FFN)。EMD是使用线性规划计算两个分布之间最优距离的方法,可以使知识的蒸馏更加合理。In this embodiment of the present application, the transformer layer is an important component of the BERT model, and long-distance dependencies can be captured through the self-attention mechanism. A standard transformer mainly includes two parts: the multi-head attention mechanism (Multi-Head Attention, MHA) and the fully connected feed-forward network (FFN). EMD is a method that uses linear programming to calculate the optimal distance between two distributions, which makes the distillation of knowledge more reasonable.
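As an illustration of how such an EMD-based layer-matching loss could be computed with linear programming, the following sketch assumes PyTorch, NumPy and SciPy; the uniform layer weights and the helper names are assumptions, and the transformation matrix W_h is omitted on the assumption that teacher and student share the same hidden size.

    import numpy as np
    import torch
    import torch.nn.functional as F
    from scipy.optimize import linprog

    def emd_flow(cost, w_t, w_s):
        # Transportation problem: minimize sum_ij f_ij * cost_ij subject to row sums w_t
        # (teacher layer weights) and column sums w_s (student layer weights), f_ij >= 0.
        M, N = cost.shape
        A_eq, b_eq = [], []
        for i in range(M):
            row = np.zeros(M * N); row[i * N:(i + 1) * N] = 1.0
            A_eq.append(row); b_eq.append(w_t[i])
        for j in range(N):
            col = np.zeros(M * N); col[j::N] = 1.0
            A_eq.append(col); b_eq.append(w_s[j])
        res = linprog(cost.reshape(-1), A_eq=np.array(A_eq), b_eq=np.array(b_eq), bounds=(0, None))
        return res.x.reshape(M, N)                      # the knowledge flow f_ij

    def emd_loss(teacher_mats, student_mats):
        # Cost matrix: MSE between every teacher-layer output and every student-layer output.
        M, N = len(teacher_mats), len(student_mats)
        cost = torch.stack([torch.stack([F.mse_loss(t, s) for s in student_mats]) for t in teacher_mats])
        flow = emd_flow(cost.detach().cpu().numpy(), np.full(M, 1.0 / M), np.full(N, 1.0 / N))
        return (torch.as_tensor(flow, dtype=cost.dtype) * cost).sum()

    # total_loss = emd_loss(teacher_attentions, student_attentions) + emd_loss(teacher_hiddens, student_hiddens)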
在本申请实施例一的一些可选的实现方式中,注意力EMD距离表示为:In some optional implementations of Embodiment 1 of the present application, the attention EMD distance is expressed as:
L_{attn} = \sum_{i=1}^{M} \sum_{j=1}^{N} f_{ij} \, \mathrm{MSE}\big(A_i^T, A_j^S\big)
其中,L_attn表示注意力EMD距离;A^T表示原始注意力矩阵,A^S表示目标注意力矩阵;MSE(A_i^T, A_j^S)表示第i层原始transformer层的原始注意力矩阵A_i^T与第j层目标transformer层的目标注意力矩阵A_j^S之间的均方误差;f_ij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。Here, L_attn denotes the attention EMD distance; A^T denotes the original attention matrices and A^S the target attention matrices; MSE(A_i^T, A_j^S) denotes the mean squared error between the original attention matrix A_i^T of the i-th original transformer layer and the target attention matrix A_j^S of the j-th target transformer layer; f_ij denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
在本申请实施例一的一些可选的实现方式中,FFN隐层EMD距离表示为:In some optional implementations of Embodiment 1 of the present application, the FFN hidden layer EMD distance is expressed as:
L_{ffn} = \sum_{i=1}^{M} \sum_{j=1}^{N} f_{ij} \, \mathrm{MSE}\big(H_j^S W_h, H_i^T\big)
其中,L_ffn表示FFN隐层EMD距离;H^T表示原始transformer层的原始FFN隐层矩阵,H^S表示目标transformer层的目标FFN隐层矩阵;W_h表示转换矩阵;MSE(H_j^S W_h, H_i^T)表示第j层目标transformer层的目标FFN隐层矩阵H_j^S经转换矩阵W_h变换后与第i层原始transformer层的原始FFN隐层矩阵H_i^T之间的均方误差;f_ij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。Here, L_ffn denotes the FFN hidden-layer EMD distance; H^T denotes the original FFN hidden-layer matrices of the original transformer layers and H^S the target FFN hidden-layer matrices of the target transformer layers; W_h denotes a transformation matrix; MSE(H_j^S W_h, H_i^T) denotes the mean squared error between the target FFN hidden-layer matrix H_j^S of the j-th target transformer layer, transformed by W_h, and the original FFN hidden-layer matrix H_i^T of the i-th original transformer layer; f_ij denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
综上,本申请实施例一提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;在本地数据库中获取中间精简模型的训练数据;基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。另外,基于层替换的蒸馏方式,保留了与BERT相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,由于蒸馏时,小模型的部分层基于伯努利采样,随机初始化成训练好的大模型映射层的权重,使模型收敛更快,减少训练轮数。To sum up, Embodiment 1 of the present application provides a distillation method applied to a BERT model: receiving a model distillation request sent by a user terminal, where the model distillation request carries at least a distillation object identifier and a distillation coefficient; reading a local database and obtaining, from the local database, a trained original BERT model corresponding to the distillation object identifier, the loss function of the original BERT model being cross entropy; constructing a default reduced model to be trained whose structure is consistent with that of the trained original BERT model, the loss function of the default reduced model being cross entropy; performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model; obtaining training data of the intermediate reduced model from the local database; and performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model. Since the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction code of the large model and the small model is the same, and the original code can be reused, so there is no need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep model distillation method; at the same time, the tasks at each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably. In addition, the layer-replacement based distillation retains the same model structure as BERT and differs only in the number of layers, so the amount of code change is small, the prediction code of the large and small models is the same, and the original code can be reused; and because, during distillation, some layers of the small model are randomly initialized, based on Bernoulli sampling, to the weights of the mapped layers of the trained large model, the model converges faster and the number of training rounds is reduced.
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机可读指令来指令相关的硬件来完成,该计算机可读指令可存储于一计算机可读取存储介质中,该计算机可读指令在执行时,可包括如上述各方法的实施例的流程。其中,前述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)等非易失性存储介质,或随机存储记忆体(Random Access Memory,RAM)等。Those of ordinary skill in the art can understand that all or part of the processes in the methods of the above embodiments can be implemented by instructing relevant hardware through computer-readable instructions, and the computer-readable instructions can be stored in a computer-readable storage medium. , when the computer-readable instructions are executed, the processes of the above-mentioned method embodiments may be included. Wherein, the aforementioned storage medium may be a non-volatile storage medium such as a magnetic disk, an optical disk, a read-only memory (Read-Only Memory, ROM), or a random access memory (Random Access Memory, RAM) or the like.
应该理解的是,虽然附图的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,其可以以其他的顺序执行。而且,附图的流程图中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,其执行顺序也不必然是依次进行,而是可以与其他步骤或者其他步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。It should be understood that although the various steps in the flowchart of the accompanying drawings are sequentially shown in the order indicated by the arrows, these steps are not necessarily executed in sequence in the order indicated by the arrows. Unless explicitly stated herein, the execution of these steps is not strictly limited to the order and may be performed in other orders. Moreover, at least a part of the steps in the flowchart of the accompanying drawings may include multiple sub-steps or multiple stages, and these sub-steps or stages are not necessarily executed at the same time, but may be executed at different times, and the execution sequence is also It does not have to be performed sequentially, but may be performed alternately or alternately with other steps or at least a portion of sub-steps or stages of other steps.
实施例二Embodiment 2
进一步参考图6,作为对上述图1所示方法的实现,本申请提供了一种应用于BERT模型的蒸馏装置的一个实施例,该装置实施例与图1所示的方法实施例相对应,该装置具体可以应用于各种电子设备中。Further referring to FIG. 6 , as an implementation of the method shown in FIG. 1 above, the present application provides an embodiment of a distillation apparatus applied to a BERT model, and the apparatus embodiment corresponds to the method embodiment shown in FIG. 1 , Specifically, the device can be applied to various electronic devices.
如图6所示,本实施例的应用于BERT模型的蒸馏装置100包括:请求接收模块110、原始模型获取模块120、默认模型构建模块130、蒸馏操作模块140、训练数据获取模块150以及模型训练模块160。其中:As shown in FIG. 6, the distillation apparatus 100 applied to the BERT model in this embodiment includes: a request receiving module 110, an original model acquisition module 120, a default model construction module 130, a distillation operation module 140, a training data acquisition module 150, and a model training module 160, wherein:
请求接收模块110,用于接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;The request receiving module 110 is configured to receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
原始模型获取模块120,用于读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;The original model obtaining module 120 is used to read the local database, and obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
默认模型构建模块130,用于构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;The default model building module 130 is used to construct a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
蒸馏操作模块140,用于基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;a distillation operation module 140, configured to perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
训练数据获取模块150,用于在本地数据库中获取中间精简模型的训练数据;A training data acquisition module 150, configured to acquire the training data of the intermediate reduced model in the local database;
模型训练模块160,用于基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。The model training module 160 is configured to perform a model training operation on the intermediate reduced model based on the training data to obtain the target reduced model.
在本申请实施例中,用户终端指的是用于执行本申请提供的应用于BERT模型的蒸馏方法的终端设备,该用户终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。In this embodiment of the present application, the user terminal refers to the terminal device used to execute the distillation method applied to the BERT model provided by the present application. The user terminal may be a mobile terminal such as a mobile phone, a smart phone, a notebook computer, a digital broadcast receiver, a PDA (personal digital assistant), a PAD (tablet computer), a PMP (portable multimedia player), or a navigation device, or a fixed terminal such as a digital TV or a desktop computer. It should be understood that the examples of user terminals here are only for ease of understanding and are not used to limit the present application.
在本申请实施例中,蒸馏对象标识主要用于唯一标识需要蒸馏的模型对象,该蒸馏对象标识可以是基于模型名称命名,作为示例,例如:视觉识别模型、语音识别模型等等;该蒸馏对象标识可以是基于名称简称进行命名,作为示例,例如:sjsbmx、yysbmx等等;该蒸馏对象标识还可以是序号进行命名,作为示例,例如:001、002等等,应当理解,此处对蒸馏对象标识的举例仅为方便理解,不用于限定本申请。In this embodiment of the present application, the distillation object identifier is mainly used to uniquely identify the model object that needs to be distilled. The distillation object identifier may be named after the model name, for example, a visual recognition model, a speech recognition model, and so on; it may be named after an abbreviated name, for example, sjsbmx, yysbmx, and so on; or it may be named with a serial number, for example, 001, 002, and so on. It should be understood that the examples of distillation object identifiers here are only for ease of understanding and are not used to limit the present application.
在本申请实施例中,蒸馏系数主要用于确认将原始BERT模型的层数缩小的倍数,作为示例,例如:需要将BERT模型从12层蒸馏至4层,那么该蒸馏系数则为3,应当理解,此处对蒸馏系数的举例仅为方便理解,不用于限定本申请。In the embodiment of this application, the distillation coefficient is mainly used to confirm the multiple of reducing the number of layers of the original BERT model. As an example, for example, if the BERT model needs to be distilled from 12 layers to 4 layers, then the distillation coefficient is 3, which should be It is understood that the examples of distillation coefficients here are only for convenience of understanding, and are not intended to limit the present application.
在本申请实施例中,本地数据库是指驻留于运行客户应用程序的机器的数据库。本地数据库提供最快的响应时间,因为在客户(应用程序)和服务器之间没有网络传输。该本地数据库预先存储有各式各样的训练好的原始BERT模型,以解决在计算机视觉、语音识别等诸多领域存在的问题。In this embodiment of the present application, the local database refers to a database residing on the machine that runs the client application. The local database provides the fastest response time because there is no network transmission between the client (application) and the server. The local database pre-stores a variety of trained original BERT models to solve problems in many fields such as computer vision and speech recognition.
在本申请实施例中,Bert模型可以分为向量(embedding)层、转换器(transformer)层和预测(prediction)层,每种层是知识的不同表示形式。该原始BERT模型由12层transformer(一种基于“encoder-decoder”结构的模型)组成,该原始BERT模型选用的是交叉熵作为损失函数。该交叉熵主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。复杂度的意义是用该模型表示这一文本平均的分支数,其倒数可视为每个词的平均概率。平滑是指对没观察到的N元组合赋予一个概率值,以保证词序列总能通过语言模型得到一个概率值。In this embodiment of the present application, the Bert model can be divided into a vector (embedding) layer, a transformer (transformer) layer, and a prediction (prediction) layer, each of which is a different representation of knowledge. The original BERT model consists of a 12-layer transformer (a model based on an "encoder-decoder" structure), and the original BERT model uses cross-entropy as the loss function. The cross entropy is mainly used to measure the difference information between two probability distributions. The performance of language models is usually measured by cross-entropy and perplexity. The meaning of cross-entropy is the difficulty of text recognition with the model, or from a compression point of view, how many bits are used to encode each word on average. The meaning of complexity is to use the model to represent the average number of branches of this text, and its inverse can be regarded as the average probability of each word. Smoothing refers to assigning a probability value to the unobserved N-gram combination to ensure that the word sequence can always obtain a probability value through the language model.
在本申请实施例中,构建出来的默认精简模型保留了与BERT相同的模型结构,不同之处在于transformer层的数量。In the embodiment of the present application, the constructed default reduced model retains the same model structure as BERT, the difference lies in the number of transformer layers.
在本申请实施例中,蒸馏操作具体包括蒸馏transformer层以及参数初始化。In this embodiment of the present application, the distillation operation specifically includes the distillation of the transformer layer and parameter initialization.
在本申请实施例中,蒸馏transformer层指的是倘若蒸馏系数为3,那么训练好的原始BERT模型的第一至第三层将替换至默认精简模型的第一层;训练好的原始BERT模型的第四至第六层将替换至默认精简模型的第二层;训练好的原始BERT模型的第七至第九层将替换至默认精简模型的第三层;训练好的原始BERT模型的第十至第十二层将替换至默认精简模型的第四层。In this embodiment of the present application, distilling the transformer layers means that, if the distillation coefficient is 3, the first to third layers of the trained original BERT model are used to replace the first layer of the default reduced model; the fourth to sixth layers of the trained original BERT model are used to replace the second layer of the default reduced model; the seventh to ninth layers of the trained original BERT model are used to replace the third layer of the default reduced model; and the tenth to twelfth layers of the trained original BERT model are used to replace the fourth layer of the default reduced model.
在本申请实施例中,在进行蒸馏替换的过程中,可采用伯努利分布概率确定每一层被替换的概率。In this embodiment of the present application, in the process of distillation replacement, the probability of each layer being replaced may be determined by using the Bernoulli distribution probability.
在本申请实施例中,参数初始化指的是embedding、pooler、全连接层参数依据训练好的原始BERT模型中各层级的参数,替换至默认精简模型对应的参数位置。In the embodiment of the present application, parameter initialization refers to replacing the parameters of the embedding, pooler, and fully connected layers to the parameter positions corresponding to the default simplified model according to the parameters of each level in the trained original BERT model.
在本申请实施例中,精简模型训练数据可以采用训练上述原始BERT模型得到的有标签数据,也可以是额外的无标签数据。In the embodiment of the present application, the training data of the reduced model may be labeled data obtained by training the above-mentioned original BERT model, or may be additional unlabeled data.
在本申请实施例中,可获取原始BERT模型训练后的原始训练数据;调高原始BERT模型softmax层的温度参数,得到调高BERT模型,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签;基于标签信息在原始训练数据上进行筛选操作,得到带标签的筛选结果标签;基于放大训练数据以及筛选训练数据选取精简模型训练数据。In this embodiment of the present application, the original training data used to train the original BERT model may be obtained; the temperature parameter of the softmax layer of the original BERT model is raised to obtain a temperature-raised BERT model, and the original training data is input into the temperature-raised BERT model for a prediction operation to obtain mean-result labels; a screening operation is performed on the original training data based on the label information to obtain labeled screening-result labels; and the reduced-model training data is selected based on the amplified training data and the screened training data.
在本申请实施例中,提供了一种应用于BERT模型的蒸馏装置,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。In the embodiment of the present application, a distillation apparatus applied to a BERT model is provided. Since the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, and the prediction code of the large model and the small model is the same, so the original code can be reused; as a result, there is no need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep model distillation method, and at the same time the tasks at each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably.
在本申请实施例二的一些可选的实现方式中,上述蒸馏操作模块140具体包括:分组操作子模块、提取操作子模块以及替换操作子模块。其中:In some optional implementations of the second embodiment of the present application, the above-mentioned distillation operation module 140 specifically includes: a grouping operation sub-module, an extraction operation sub-module, and a replacement operation sub-module. in:
分组操作子模块,用于基于蒸馏系数对原始BERT模型的transformer层进行分组操作,得到分组transformer层;The grouping operation sub-module is used to group the transformer layer of the original BERT model based on the distillation coefficient to obtain the grouped transformer layer;
提取操作子模块,用于基于伯努利分布分别在分组transformer层中进行提取操作,得到待替换transformer层;The extraction operation sub-module is used to perform extraction operations in the grouped transformer layers based on the Bernoulli distribution to obtain the transformer layers to be replaced;
替换操作子模块,用于将待替换transformer层分别替换至默认精简模型,得到中间 精简模型。The replacement operation sub-module is used to replace the transformer layer to be replaced with the default reduced model respectively to obtain the intermediate reduced model.
In some optional implementations of the second embodiment of the present application, the above-mentioned training data acquisition module 150 specifically includes: an original training data acquisition sub-module, a parameter raising sub-module, a prediction operation sub-module, a screening operation sub-module, and a training data acquisition sub-module. Wherein:
原始训练数据获取子模块,用于获取原始BERT模型训练后的原始训练数据;The original training data acquisition sub-module is used to obtain the original training data after the original BERT model training;
The parameter raising sub-module is used to raise the temperature parameter of the softmax layer of the original BERT model to obtain a raised-temperature BERT model;
The prediction operation sub-module is used to input the original training data into the raised-temperature BERT model for a prediction operation to obtain mean result labels;
The screening operation sub-module is used to perform a screening operation on the original training data based on the label information to obtain labeled screening results;
The training data acquisition sub-module is used to select the reduced-model training data based on the amplified training data and the screened training data.
In some optional implementations of the second embodiment of the present application, the above-mentioned distillation apparatus 100 applied to the BERT model further includes: an optimized training data acquisition module, a distillation loss data calculation module, and a parameter optimization module. Wherein:
优化训练数据获取模块,用于在本地数据库中获取优化训练数据;The optimized training data acquisition module is used to obtain optimized training data in the local database;
优化训练数据输入模块,用于将优化训练数据分别输入至训练好的原始BERT模型以及目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;The optimized training data input module is used to input the optimized training data into the trained original BERT model and the target reduced model, respectively, to obtain the original transformer layer output data and the target transformer layer output data;
The distillation loss data calculation module is used to calculate the distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance (EMD);
参数优化模块,用于根据蒸馏损失数据对目标精简模型进行参数优化操作,得到优化精简模型。The parameter optimization module is used to optimize the parameters of the target reduced model according to the distillation loss data to obtain the optimized reduced model.
In some optional implementations of the second embodiment of the present application, the above-mentioned distillation loss data calculation module specifically includes: a target attention matrix acquisition sub-module, an attention EMD distance calculation sub-module, a target FFN hidden layer matrix acquisition sub-module, an FFN hidden layer EMD distance calculation sub-module, and a distillation loss data acquisition sub-module. Wherein:
目标注意力矩阵获取子模块,用于获取原始transformer层输出的原始注意力矩阵以及目标transformer层输出的目标注意力矩阵;The target attention matrix acquisition sub-module is used to obtain the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer;
注意力EMD距离计算子模块,用于根据原始注意力矩阵以及目标注意力矩阵计算注意力EMD距离;The attention EMD distance calculation sub-module is used to calculate the attention EMD distance according to the original attention matrix and the target attention matrix;
目标FFN隐层矩阵获取子模块,用于获取原始transformer层输出的原始FFN隐层矩阵以及目标transformer层输出的目标FFN隐层矩阵;The target FFN hidden layer matrix acquisition sub-module is used to obtain the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer;
FFN隐层EMD距离计算子模块,用于根据原始FFN隐层矩阵以及目标FFN隐层矩阵计算FFN隐层EMD距离;The FFN hidden layer EMD distance calculation sub-module is used to calculate the FFN hidden layer EMD distance according to the original FFN hidden layer matrix and the target FFN hidden layer matrix;
蒸馏损失数据获取子模块,用于基于注意力EMD距离以及FFN隐层EMD距离获得蒸馏损失数据。The distillation loss data acquisition sub-module is used to obtain distillation loss data based on the attention EMD distance and the FFN hidden layer EMD distance.
在本申请实施例二的一些可选的实现方式中,注意力EMD距离表示为:In some optional implementations of the second embodiment of the present application, the attention EMD distance is expressed as:
L_{attn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{A}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{A} = \mathrm{MSE}\left(A_{i}^{T}, A_{j}^{S}\right)

where L_{attn} denotes the attention EMD distance; A^{T} denotes the original attention matrices and A^{S} the target attention matrices; D_{ij}^{A} denotes the mean squared error between the original attention matrix A_{i}^{T} of the i-th original transformer layer and the target attention matrix A_{j}^{S} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
在本申请实施例二的一些可选的实现方式中,FFN隐层EMD距离表示为:In some optional implementation manners of the second embodiment of the present application, the FFN hidden layer EMD distance is expressed as:
L_{ffn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{H}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{H} = \mathrm{MSE}\left(H_{j}^{S} W_{h}, H_{i}^{T}\right)

where L_{ffn} denotes the FFN hidden layer EMD distance; H^{T} denotes the original FFN hidden layer matrices of the original transformer layers and H^{S} the target FFN hidden layer matrices of the target transformer layers; W_{h} denotes the transformation matrix; D_{ij}^{H} denotes the mean squared error between the original FFN hidden layer matrix H_{i}^{T} of the i-th original transformer layer and the transformed target FFN hidden layer matrix H_{j}^{S} W_{h} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
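For concreteness, one way the two EMD terms defined above could be computed is sketched below: the ground cost is the mean squared error between per-layer matrices (assumed to have matching shapes), uniform layer weights are assumed, and the flow f_ij is obtained by solving the underlying transportation problem with SciPy's linear-programming solver. This is an illustrative sketch under those assumptions, not the patented implementation.

```python
import numpy as np
from scipy.optimize import linprog

def emd_distance(teacher_mats, student_mats):
    """EMD between per-layer teacher/student matrices with an MSE ground cost."""
    M, N = len(teacher_mats), len(student_mats)
    D = np.array([[np.mean((t - s) ** 2) for s in student_mats] for t in teacher_mats])
    w_t, w_s = np.full(M, 1.0 / M), np.full(N, 1.0 / N)   # uniform layer weights (assumption)
    A_eq, b_eq = [], []
    for i in range(M):                                    # teacher-side marginals
        row = np.zeros(M * N); row[i * N:(i + 1) * N] = 1.0
        A_eq.append(row); b_eq.append(w_t[i])
    for j in range(N):                                    # student-side marginals
        col = np.zeros(M * N); col[j::N] = 1.0
        A_eq.append(col); b_eq.append(w_s[j])
    res = linprog(D.ravel(), A_eq=np.array(A_eq), b_eq=np.array(b_eq),
                  bounds=[(0, None)] * (M * N), method="highs")
    f = res.x                                             # optimal flow f_ij, flattened
    return float(f @ D.ravel() / f.sum())

def distillation_loss(attn_T, attn_S, ffn_T, ffn_S, W_h=None):
    """L = L_attn + L_ffn; W_h (if given) maps the target FFN hidden states to the
    original hidden size before the MSE cost is computed."""
    ffn_S_proj = [h @ W_h for h in ffn_S] if W_h is not None else ffn_S
    return emd_distance(attn_T, attn_S) + emd_distance(ffn_T, ffn_S_proj)
```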
In summary, the second embodiment of the present application provides a distillation apparatus applied to the BERT model. Since the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, the amount of code change is small, the prediction code of the large model and the small model is the same, and the original code can be reused, so that the weights of the individual loss parameters do not need to be balanced during distillation, which reduces the difficulty of the deep-model distillation method; at the same time, the tasks in each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably. In addition, the layer-replacement-based distillation retains the same model structure as BERT, differing only in the number of layers, so the code change is small and the prediction code of the large and small models is the same, allowing the original code to be reused; and because, during distillation, some layers of the small model are randomly initialized with the weights of the mapped layers of the trained large model based on Bernoulli sampling, the model converges faster and fewer training epochs are needed.
为解决上述技术问题,本申请实施例还提供计算机设备。具体请参阅图7,图7为本实施例计算机设备基本结构框图。To solve the above technical problems, the embodiments of the present application also provide computer equipment. For details, please refer to FIG. 7 , which is a block diagram of the basic structure of a computer device according to this embodiment.
所述计算机设备200包括通过系统总线相互通信连接存储器210、处理器220、网络接口230。需要指出的是,图中仅示出了具有组件210-230的计算机设备200,但是应理解的是,并不要求实施所有示出的组件,可以替代的实施更多或者更少的组件。其中,本技术领域技术人员可以理解,这里的计算机设备是一种能够按照事先设定或存储的指令,自动进行数值计算和/或信息处理的设备,其硬件包括但不限于微处理器、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程门阵列(Field-Programmable Gate Array,FPGA)、数字处理器(Digital Signal Processor,DSP)、嵌入式设备等。The computer device 200 includes a memory 210 , a processor 220 , and a network interface 230 that communicate with each other through a system bus. It should be noted that only the computer device 200 with components 210-230 is shown in the figure, but it should be understood that implementation of all of the shown components is not required, and more or less components may be implemented instead. Among them, those skilled in the art can understand that the computer device here is a device that can automatically perform numerical calculation and/or information processing according to pre-set or stored instructions, and its hardware includes but is not limited to microprocessors, special-purpose Integrated circuit (Application Specific Integrated Circuit, ASIC), programmable gate array (Field-Programmable Gate Array, FPGA), digital processor (Digital Signal Processor, DSP), embedded equipment, etc.
所述计算机设备可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机设备可以与用户通过键盘、鼠标、遥控器、触摸板或声控设备等方式进行人机交互。The computer equipment may be a desktop computer, a notebook computer, a palmtop computer, a cloud server and other computing equipment. The computer device can perform human-computer interaction with the user through a keyboard, a mouse, a remote control, a touch pad or a voice control device.
所述存储器210至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或DX存储器等)、随机访问存储器(RAM)、静态随机访问存储器(SRAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、可编程只读存储器(PROM)、磁性存储器、磁盘、光盘等,所述计算机可读存储介质可以是非易失性,也可以是易失性。在一些实施例中,所述存储器210可以是所述计算机设备200的内部存储单元,例如该计算机设备200的硬盘或内存。在另一些实施例中,所述存储器210也可以是所述计算机设备200的外部存储设备,例如该计算机设备200上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。当然,所述存储器210还可以既包括所述计算机设备200的内部存储单元也包括其外部存储设备。本实施例中,所述存储器210通常用于存储安装于所述计算机设备200的操作系统和各类应用软件,例如应用于BERT模型的蒸馏方法的计算机可读指令等。此外,所述存储器210还可以用于暂时地存储已经输出或者将要输出的各类数据。The memory 210 includes at least one type of readable storage medium, including flash memory, hard disk, multimedia card, card-type memory (eg, SD or DX memory, etc.), random access memory (RAM), static Random Access Memory (SRAM), Read Only Memory (ROM), Electrically Erasable Programmable Read Only Memory (EEPROM), Programmable Read Only Memory (PROM), magnetic memory, magnetic disks, optical disks, etc., the computer readable storage Media can be non-volatile or volatile. In some embodiments, the memory 210 may be an internal storage unit of the computer device 200 , such as a hard disk or a memory of the computer device 200 . In other embodiments, the memory 210 may also be an external storage device of the computer device 200, such as a plug-in hard disk, a smart memory card (Smart Media Card, SMC), a secure digital (Secure Digital, SD) card, flash memory card (Flash Card), etc. Of course, the memory 210 may also include both the internal storage unit of the computer device 200 and its external storage device. In this embodiment, the memory 210 is generally used to store the operating system and various application software installed on the computer device 200 , such as computer-readable instructions applied to the distillation method of the BERT model. In addition, the memory 210 can also be used to temporarily store various types of data that have been output or will be output.
所述处理器220在一些实施例中可以是中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器、或其他数据处理芯片。该处理器220通常用于控制所述计算机设备200的总体操作。本实施例中,所述处理器220用于运行所述存储器210中存储的计算机可读指令或者处理数据,例如运行所述应用于BERT模型的蒸馏方法的计算机可 读指令。The processor 220 may be a central processing unit (Central Processing Unit, CPU), a controller, a microcontroller, a microprocessor, or other data processing chips in some embodiments. The processor 220 is typically used to control the overall operation of the computer device 200 . In this embodiment, the processor 220 is configured to execute the computer-readable instructions stored in the memory 210 or process data, for example, the computer-readable instructions for executing the distillation method applied to the BERT model.
所述网络接口230可包括无线网络接口或有线网络接口,该网络接口230通常用于在所述计算机设备200与其他电子设备之间建立通信连接。The network interface 230 may include a wireless network interface or a wired network interface, and the network interface 230 is generally used to establish a communication connection between the computer device 200 and other electronic devices.
上述应用于BERT模型的蒸馏方法的步骤包括:The steps of the above distillation method applied to the BERT model include:
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
With the distillation method applied to the BERT model provided by the present application, the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, so the amount of code change is small, the prediction code of the large model and the small model is the same, and the original code can be reused; the model therefore does not need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep-model distillation method. At the same time, the tasks in each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably.
本申请还提供了另一种实施方式,即提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可读指令,所述计算机可读指令可被至少一个处理器执行,以使所述至少一个处理器执行如下述的应用于BERT模型的蒸馏方法的步骤:The present application also provides another embodiment, that is, to provide a computer-readable storage medium, where the computer-readable storage medium stores computer-readable instructions, and the computer-readable instructions can be executed by at least one processor to causing the at least one processor to perform the steps of the distillation method applied to the BERT model as follows:
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
With the distillation method applied to the BERT model provided by the present application, the reduced BERT model retains the same model structure as the original BERT model and differs only in the number of layers, so the amount of code change is small, the prediction code of the large model and the small model is the same, and the original code can be reused; the model therefore does not need to balance the weights of the individual loss parameters during distillation, which reduces the difficulty of the deep-model distillation method. At the same time, the tasks in each stage of training the reduced BERT model remain consistent, so that the reduced BERT model converges more stably.
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。From the description of the above embodiments, those skilled in the art can clearly understand that the method of the above embodiment can be implemented by means of software plus a necessary general hardware platform, and of course can also be implemented by hardware, but in many cases the former is better implementation. Based on this understanding, the technical solution of the present application can be embodied in the form of a software product in essence or in a part that contributes to the prior art, and the computer software product is stored in a storage medium (such as ROM/RAM, magnetic disk, CD-ROM), including several instructions to make a terminal device (which may be a mobile phone, a computer, a server, an air conditioner, or a network device, etc.) execute the methods described in the various embodiments of this application.
显然,以上所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例,附图中给出了本申请的较佳实施例,但并不限制本申请的专利范围。本申请可以以许多不同的形式来实现,相反地,提供这些实施例的目的是使对本申请的公开内容的理解更加透彻全面。尽管参照前述实施例对本申请进行了详细的说明,对于本领域的技术人员来而言,其依然可以对前述各具体实施方式所记载的技术方案进行修改,或者对其中部分技术特征进 行等效替换。凡是利用本申请说明书及附图内容所做的等效结构,直接或间接运用在其他相关的技术领域,均同理在本申请专利保护范围之内。Obviously, the above-described embodiments are only a part of the embodiments of the present application, rather than all of the embodiments. The accompanying drawings show the preferred embodiments of the present application, but do not limit the scope of the patent of the present application. This application may be embodied in many different forms, rather these embodiments are provided so that a thorough and complete understanding of the disclosure of this application is provided. Although the present application has been described in detail with reference to the foregoing embodiments, those skilled in the art can still modify the technical solutions described in the foregoing specific embodiments, or perform equivalent replacements for some of the technical features. . Any equivalent structure made by using the contents of the description and drawings of the present application, which is directly or indirectly used in other related technical fields, is also within the scope of protection of the patent of the present application.

Claims (20)

  1. 一种应用于BERT模型的蒸馏方法,其中,包括下述步骤:A distillation method applied to a BERT model, comprising the following steps:
    接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
    读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
    构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
    基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
    在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
    基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  2. 根据权利要求1所述的应用于BERT模型的蒸馏方法,其中,所述基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型的步骤,具体包括:The distillation method applied to the BERT model according to claim 1, wherein the step of performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model specifically includes:
    基于所述蒸馏系数对所述原始BERT模型的transformer层进行分组操作,得到分组transformer层;Perform a grouping operation on the transformer layer of the original BERT model based on the distillation coefficient to obtain a grouped transformer layer;
    基于伯努利分布分别在所述分组transformer层中进行提取操作,得到待替换transformer层;Based on the Bernoulli distribution, extracting operations are performed in the grouped transformer layers to obtain the transformer layers to be replaced;
    将所述待替换transformer层分别替换至所述默认精简模型,得到所述中间精简模型。The to-be-replaced transformer layers are respectively replaced with the default reduced model to obtain the intermediate reduced model.
  3. 根据权利要求1所述的应用于BERT模型的蒸馏方法,其中,所述在所述本地数据库中获取所述中间精简模型的训练数据的步骤,具体包括:The distillation method applied to the BERT model according to claim 1, wherein the step of acquiring the training data of the intermediate reduced model in the local database specifically includes:
    获取所述原始BERT模型训练后的原始训练数据;Obtain the original training data after the original BERT model is trained;
    调高所述原始BERT模型softmax层的温度参数,得到调高BERT模型;Increase the temperature parameter of the softmax layer of the original BERT model to obtain an increased BERT model;
    Inputting the original training data into the increased BERT model to perform a prediction operation to obtain mean result labels;
    基于标签信息在所述原始训练数据进行筛选操作,得到带标签的筛选结果标签;Perform a screening operation on the original training data based on the tag information to obtain a tagged screening result tag;
    基于所述放大训练数据以及所述筛选训练数据选取所述精简模型训练数据。The reduced model training data is selected based on the enlarged training data and the filtered training data.
  4. 根据权利要求1所述的应用于BERT模型的蒸馏方法,其中,在所述基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型的步骤之后还包括:The distillation method applied to the BERT model according to claim 1, wherein after the step of performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model, the method further comprises:
    在所述本地数据库中获取优化训练数据;obtaining optimized training data in the local database;
    将所述优化训练数据分别输入至所述训练好的原始BERT模型以及所述目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;Inputting the optimized training data into the trained original BERT model and the target reduced model, respectively, to obtain the original transformer layer output data and the target transformer layer output data;
    Calculating distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance;
    根据所述蒸馏损失数据对所述目标精简模型进行参数优化操作,得到优化精简模型。A parameter optimization operation is performed on the target reduced model according to the distillation loss data to obtain an optimized reduced model.
  5. The distillation method applied to the BERT model according to claim 4, wherein the step of calculating the distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance specifically includes:
    获取所述原始transformer层输出的原始注意力矩阵以及所述目标transformer层输出的目标注意力矩阵;Obtain the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer;
    根据所述原始注意力矩阵以及所述目标注意力矩阵计算注意力EMD距离;Calculate the attention EMD distance according to the original attention matrix and the target attention matrix;
    获取所述原始transformer层输出的原始FFN隐层矩阵以及所述目标transformer层输出的目标FFN隐层矩阵;Obtain the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer;
    根据所述原始FFN隐层矩阵以及所述目标FFN隐层矩阵计算FFN隐层EMD距离;Calculate the FFN hidden layer EMD distance according to the original FFN hidden layer matrix and the target FFN hidden layer matrix;
    基于所述注意力EMD距离以及所述FFN隐层EMD距离获得所述蒸馏损失数据。The distillation loss data is obtained based on the attention EMD distance and the FFN hidden layer EMD distance.
  6. 根据权利要求5所述的应用于BERT模型的蒸馏方法,其中,所述注意力EMD距离表示为:The distillation method applied to a BERT model according to claim 5, wherein the attention EMD distance is expressed as:
    L_{attn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{A}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{A} = \mathrm{MSE}\left(A_{i}^{T}, A_{j}^{S}\right)

    where L_{attn} denotes the attention EMD distance; A^{T} denotes the original attention matrices and A^{S} the target attention matrices; D_{ij}^{A} denotes the mean squared error between the original attention matrix A_{i}^{T} of the i-th original transformer layer and the target attention matrix A_{j}^{S} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
  7. 根据权利要求5所述的应用于BERT模型的蒸馏方法,其中,所述FFN隐层EMD距离表示为:The distillation method applied to the BERT model according to claim 5, wherein the FFN hidden layer EMD distance is expressed as:
    L_{ffn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{H}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{H} = \mathrm{MSE}\left(H_{j}^{S} W_{h}, H_{i}^{T}\right)

    where L_{ffn} denotes the FFN hidden layer EMD distance; H^{T} denotes the original FFN hidden layer matrices of the original transformer layers and H^{S} the target FFN hidden layer matrices of the target transformer layers; W_{h} denotes the transformation matrix; D_{ij}^{H} denotes the mean squared error between the original FFN hidden layer matrix H_{i}^{T} of the i-th original transformer layer and the transformed target FFN hidden layer matrix H_{j}^{S} W_{h} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
  8. 一种应用于BERT模型的蒸馏装置,其中,包括:A distillation apparatus applied to a BERT model, including:
    请求接收模块,用于接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;a request receiving module, configured to receive a model distillation request sent by a user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
    原始模型获取模块,用于读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;The original model acquisition module is used for reading the local database, and in the local database, the trained original BERT model corresponding to the distillation object identifier is obtained, and the loss function of the original BERT model is cross entropy;
    默认模型构建模块,用于构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;The default model building module is used to construct a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
    蒸馏操作模块,用于基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;a distillation operation module, configured to perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
    训练数据获取模块,用于在所述本地数据库中获取所述中间精简模型的训练数据;a training data acquisition module, used for acquiring the training data of the intermediate reduced model in the local database;
    模型训练模块,用于基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training module, configured to perform a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model.
  9. 一种计算机设备,包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如下所述的应用于BERT模型的蒸馏方法的步骤:A computer device, comprising a memory and a processor, wherein computer-readable instructions are stored in the memory, and when the processor executes the computer-readable instructions, the steps of the distillation method applied to the BERT model as described below are implemented:
    接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
    读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
    构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
    基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
    在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
    基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  10. 根据权利要求9所述的计算机设备,其中,所述基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型的步骤,具体包括:The computer device according to claim 9, wherein the step of performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model specifically includes:
    基于所述蒸馏系数对所述原始BERT模型的transformer层进行分组操作,得到分组transformer层;Perform a grouping operation on the transformer layer of the original BERT model based on the distillation coefficient to obtain a grouped transformer layer;
    基于伯努利分布分别在所述分组transformer层中进行提取操作,得到待替换 transformer层;Perform extraction operations in the grouped transformer layers based on the Bernoulli distribution to obtain the transformer layer to be replaced;
    将所述待替换transformer层分别替换至所述默认精简模型,得到所述中间精简模型。The to-be-replaced transformer layers are respectively replaced with the default reduced model to obtain the intermediate reduced model.
  11. 根据权利要求9所述的计算机设备,其中,所述在所述本地数据库中获取所述中间精简模型的训练数据的步骤,具体包括:The computer device according to claim 9, wherein the step of acquiring the training data of the intermediate reduced model in the local database specifically includes:
    获取所述原始BERT模型训练后的原始训练数据;Obtain the original training data after the original BERT model is trained;
    调高所述原始BERT模型softmax层的温度参数,得到调高BERT模型;Increase the temperature parameter of the softmax layer of the original BERT model to obtain an increased BERT model;
    Inputting the original training data into the increased BERT model to perform a prediction operation to obtain mean result labels;
    基于标签信息在所述原始训练数据进行筛选操作,得到带标签的筛选结果标签;Perform a screening operation on the original training data based on the tag information to obtain a tagged screening result tag;
    基于所述放大训练数据以及所述筛选训练数据选取所述精简模型训练数据。The reduced model training data is selected based on the enlarged training data and the filtered training data.
  12. 根据权利要求9所述的计算机设备,其中,在所述基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型的步骤之后还包括:The computer device according to claim 9, wherein after the step of performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model, it further comprises:
    在所述本地数据库中获取优化训练数据;obtaining optimized training data in the local database;
    将所述优化训练数据分别输入至所述训练好的原始BERT模型以及所述目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;Inputting the optimized training data into the trained original BERT model and the target reduced model, respectively, to obtain the original transformer layer output data and the target transformer layer output data;
    Calculating distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance;
    根据所述蒸馏损失数据对所述目标精简模型进行参数优化操作,得到优化精简模型。A parameter optimization operation is performed on the target reduced model according to the distillation loss data to obtain an optimized reduced model.
  13. The computer device according to claim 12, wherein the step of calculating the distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance specifically includes:
    获取所述原始transformer层输出的原始注意力矩阵以及所述目标transformer层输出的目标注意力矩阵;Obtain the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer;
    根据所述原始注意力矩阵以及所述目标注意力矩阵计算注意力EMD距离;Calculate the attention EMD distance according to the original attention matrix and the target attention matrix;
    获取所述原始transformer层输出的原始FFN隐层矩阵以及所述目标transformer层输出的目标FFN隐层矩阵;Obtain the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer;
    根据所述原始FFN隐层矩阵以及所述目标FFN隐层矩阵计算FFN隐层EMD距离;Calculate the FFN hidden layer EMD distance according to the original FFN hidden layer matrix and the target FFN hidden layer matrix;
    基于所述注意力EMD距离以及所述FFN隐层EMD距离获得所述蒸馏损失数据。The distillation loss data is obtained based on the attention EMD distance and the FFN hidden layer EMD distance.
  14. 根据权利要求13所述的计算机设备,其中,所述注意力EMD距离表示为:The computer device of claim 13, wherein the attention EMD distance is expressed as:
    L_{attn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{A}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{A} = \mathrm{MSE}\left(A_{i}^{T}, A_{j}^{S}\right)

    where L_{attn} denotes the attention EMD distance; A^{T} denotes the original attention matrices and A^{S} the target attention matrices; D_{ij}^{A} denotes the mean squared error between the original attention matrix A_{i}^{T} of the i-th original transformer layer and the target attention matrix A_{j}^{S} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
  15. 一种计算机可读存储介质,其中,所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如下所述的应用于BERT模型的蒸馏方法的步骤:A computer-readable storage medium, wherein computer-readable instructions are stored on the computer-readable storage medium, and when the computer-readable instructions are executed by a processor, the steps of the distillation method applied to the BERT model as described below are implemented :
    接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;Receive a model distillation request sent by the user terminal, where the model distillation request at least carries a distillation object identifier and a distillation coefficient;
    读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;Read the local database, obtain the trained original BERT model corresponding to the distillation object identifier in the local database, and the loss function of the original BERT model is cross entropy;
    构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;Build a default reduced model to be trained that is consistent with the trained original BERT model structure, and the loss function of the default reduced model is cross entropy;
    基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;Perform a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model;
    在所述本地数据库中获取所述中间精简模型的训练数据;Acquiring training data of the intermediate reduced model in the local database;
    基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。A model training operation is performed on the intermediate reduced model based on the training data to obtain a target reduced model.
  16. 根据权利要求15所述的计算机可读存储介质,其中,所述基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型的步骤,具体包括:The computer-readable storage medium according to claim 15, wherein the step of performing a distillation operation on the default reduced model based on the distillation coefficient to obtain an intermediate reduced model specifically includes:
    基于所述蒸馏系数对所述原始BERT模型的transformer层进行分组操作,得到分组transformer层;Perform a grouping operation on the transformer layer of the original BERT model based on the distillation coefficient to obtain a grouped transformer layer;
    基于伯努利分布分别在所述分组transformer层中进行提取操作,得到待替换transformer层;Based on the Bernoulli distribution, extracting operations are performed in the grouped transformer layers to obtain the transformer layers to be replaced;
    将所述待替换transformer层分别替换至所述默认精简模型,得到所述中间精简模型。The to-be-replaced transformer layers are respectively replaced with the default reduced model to obtain the intermediate reduced model.
  17. 根据权利要求15所述的计算机可读存储介质,其中,所述在所述本地数据库中获取所述中间精简模型的训练数据的步骤,具体包括:The computer-readable storage medium according to claim 15, wherein the step of acquiring the training data of the intermediate reduced model in the local database specifically comprises:
    获取所述原始BERT模型训练后的原始训练数据;Obtain the original training data after the original BERT model is trained;
    调高所述原始BERT模型softmax层的温度参数,得到调高BERT模型;Increase the temperature parameter of the softmax layer of the original BERT model to obtain an increased BERT model;
    Inputting the original training data into the increased BERT model to perform a prediction operation to obtain mean result labels;
    基于标签信息在所述原始训练数据进行筛选操作,得到带标签的筛选结果标签;Perform a screening operation on the original training data based on the tag information to obtain a tagged screening result tag;
    基于所述放大训练数据以及所述筛选训练数据选取所述精简模型训练数据。The reduced model training data is selected based on the enlarged training data and the filtered training data.
  18. 根据权利要求15所述的计算机可读存储介质,其中,在所述基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型的步骤之后还包括:The computer-readable storage medium according to claim 15, wherein after the step of performing a model training operation on the intermediate reduced model based on the training data to obtain a target reduced model, the method further comprises:
    在所述本地数据库中获取优化训练数据;Obtaining optimized training data in the local database;
    将所述优化训练数据分别输入至所述训练好的原始BERT模型以及所述目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;Inputting the optimized training data into the trained original BERT model and the target reduced model, respectively, to obtain the original transformer layer output data and the target transformer layer output data;
    Calculating distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance;
    根据所述蒸馏损失数据对所述目标精简模型进行参数优化操作,得到优化精简模型。A parameter optimization operation is performed on the target reduced model according to the distillation loss data to obtain an optimized reduced model.
  19. The computer-readable storage medium according to claim 18, wherein the step of calculating the distillation loss data of the original transformer layer output data and the target transformer layer output data based on the earth mover's distance specifically includes:
    获取所述原始transformer层输出的原始注意力矩阵以及所述目标transformer层输出的目标注意力矩阵;Obtain the original attention matrix output by the original transformer layer and the target attention matrix output by the target transformer layer;
    根据所述原始注意力矩阵以及所述目标注意力矩阵计算注意力EMD距离;Calculate the attention EMD distance according to the original attention matrix and the target attention matrix;
    获取所述原始transformer层输出的原始FFN隐层矩阵以及所述目标transformer层输出的目标FFN隐层矩阵;Obtain the original FFN hidden layer matrix output by the original transformer layer and the target FFN hidden layer matrix output by the target transformer layer;
    根据所述原始FFN隐层矩阵以及所述目标FFN隐层矩阵计算FFN隐层EMD距离;Calculate the FFN hidden layer EMD distance according to the original FFN hidden layer matrix and the target FFN hidden layer matrix;
    基于所述注意力EMD距离以及所述FFN隐层EMD距离获得所述蒸馏损失数据。The distillation loss data is obtained based on the attention EMD distance and the FFN hidden layer EMD distance.
  20. 根据权利要求19所述的计算机可读存储介质,其中,所述注意力EMD距离表示为:The computer-readable storage medium of claim 19, wherein the attention EMD distance is represented as:
    L_{attn} = \frac{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}\, D_{ij}^{A}}{\sum_{i=1}^{M}\sum_{j=1}^{N} f_{ij}}, \qquad D_{ij}^{A} = \mathrm{MSE}\left(A_{i}^{T}, A_{j}^{S}\right)

    where L_{attn} denotes the attention EMD distance; A^{T} denotes the original attention matrices and A^{S} the target attention matrices; D_{ij}^{A} denotes the mean squared error between the original attention matrix A_{i}^{T} of the i-th original transformer layer and the target attention matrix A_{j}^{S} of the j-th target transformer layer; f_{ij} denotes the amount of knowledge transferred from the i-th original transformer layer to the j-th target transformer layer; M denotes the number of original transformer layers; and N denotes the number of target transformer layers.
PCT/CN2021/090524 2020-11-17 2021-04-28 Distillation method and apparatus applied to bert model, device, and storage medium WO2022105121A1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202011288877.7A CN112418291A (en) 2020-11-17 2020-11-17 Distillation method, device, equipment and storage medium applied to BERT model
CN202011288877.7 2020-11-17

Publications (1)

Publication Number Publication Date
WO2022105121A1 true WO2022105121A1 (en) 2022-05-27

Family

ID=74832129

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/090524 WO2022105121A1 (en) 2020-11-17 2021-04-28 Distillation method and apparatus applied to bert model, device, and storage medium

Country Status (2)

Country Link
CN (1) CN112418291A (en)
WO (1) WO2022105121A1 (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116402811A (en) * 2023-06-05 2023-07-07 长沙海信智能系统研究院有限公司 Fighting behavior identification method and electronic equipment

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112418291A (en) * 2020-11-17 2021-02-26 平安科技(深圳)有限公司 Distillation method, device, equipment and storage medium applied to BERT model
GB2619569A (en) * 2020-12-15 2023-12-13 Zhejiang Lab Method and platform for automatically compressing multi-task-oriented pre-training language model

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10607598B1 (en) * 2019-04-05 2020-03-31 Capital One Services, Llc Determining input data for speech processing
CN111553479A (en) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 Model distillation method, text retrieval method and text retrieval device
CN111767711A (en) * 2020-09-02 2020-10-13 之江实验室 Compression method and platform of pre-training language model based on knowledge distillation
CN112418291A (en) * 2020-11-17 2021-02-26 平安科技(深圳)有限公司 Distillation method, device, equipment and storage medium applied to BERT model

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110188360B (en) * 2019-06-06 2023-04-25 北京百度网讯科技有限公司 Model training method and device

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10607598B1 (en) * 2019-04-05 2020-03-31 Capital One Services, Llc Determining input data for speech processing
CN111553479A (en) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 Model distillation method, text retrieval method and text retrieval device
CN111767711A (en) * 2020-09-02 2020-10-13 之江实验室 Compression method and platform of pre-training language model based on knowledge distillation
CN112418291A (en) * 2020-11-17 2021-02-26 平安科技(深圳)有限公司 Distillation method, device, equipment and storage medium applied to BERT model

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116402811A (en) * 2023-06-05 2023-07-07 长沙海信智能系统研究院有限公司 Fighting behavior identification method and electronic equipment
CN116402811B (en) * 2023-06-05 2023-08-18 长沙海信智能系统研究院有限公司 Fighting behavior identification method and electronic equipment

Also Published As

Publication number Publication date
CN112418291A (en) 2021-02-26

Similar Documents

Publication Publication Date Title
WO2022105121A1 (en) Distillation method and apparatus applied to bert model, device, and storage medium
US11030522B2 (en) Reducing the size of a neural network through reduction of the weight matrices
CN109190120B (en) Neural network training method and device and named entity identification method and device
WO2020232861A1 (en) Named entity recognition method, electronic device and storage medium
WO2021068329A1 (en) Chinese named-entity recognition method, device, and computer-readable storage medium
WO2021121198A1 (en) Semantic similarity-based entity relation extraction method and apparatus, device and medium
WO2020108063A1 (en) Feature word determining method, apparatus, and server
CN109697451B (en) Similar image clustering method and device, storage medium and electronic equipment
CN113792854A (en) Model training and word stock establishing method, device, equipment and storage medium
WO2020215683A1 (en) Semantic recognition method and apparatus based on convolutional neural network, and non-volatile readable storage medium and computer device
WO2023124005A1 (en) Map point of interest query method and apparatus, device, storage medium, and program product
CN113837308B (en) Knowledge distillation-based model training method and device and electronic equipment
WO2022110640A1 (en) Model optimization method and apparatus, computer device and storage medium
WO2023168909A1 (en) Pre-training method and model fine-tuning method for geographical pre-training model
CN112084752B (en) Sentence marking method, device, equipment and storage medium based on natural language
WO2023138188A1 (en) Feature fusion model training method and apparatus, sample retrieval method and apparatus, and computer device
CN115482395B (en) Model training method, image classification device, electronic equipment and medium
CN113190702B (en) Method and device for generating information
JP2022169743A (en) Information extraction method and device, electronic equipment, and storage medium
CN114780746A (en) Knowledge graph-based document retrieval method and related equipment thereof
WO2023040742A1 (en) Text data processing method, neural network training method, and related devices
CN114781611A (en) Natural language processing method, language model training method and related equipment
CN115730597A (en) Multi-level semantic intention recognition method and related equipment thereof
CN114861758A (en) Multi-modal data processing method and device, electronic equipment and readable storage medium
CN113435523B (en) Method, device, electronic equipment and storage medium for predicting content click rate

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21893277

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21893277

Country of ref document: EP

Kind code of ref document: A1