WO2022178948A1 - Model distillation method and apparatus, device, and storage medium - Google Patents

Model distillation method and apparatus, device, and storage medium

Info

Publication number
WO2022178948A1
Authority
WO
WIPO (PCT)
Prior art keywords
distillation
model
student model
loss value
parameters
Prior art date
Application number
PCT/CN2021/084539
Other languages
French (fr)
Chinese (zh)
Inventor
王健宗
宋青原
吴天博
程宁
Original Assignee
平安科技(深圳)有限公司
Priority date
Filing date
Publication date
Application filed by 平安科技(深圳)有限公司
Publication of WO2022178948A1

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Definitions

  • the present application relates to the technical field of artificial intelligence, and in particular, to a model distillation method, device, equipment and storage medium.
  • at present, pre-training models have strong encoding and generalization ability, and using a pre-training model for downstream tasks can greatly reduce the amount of labeled data required, so such models play a large role in many fields. However, because a pre-training model usually has a large number of parameters, it cannot be used online.
  • the inventor realizes that the prior art reduces the number of parameters and improves inference speed by distilling a pre-trained model with a large number of parameters into a model with a small number of parameters.
  • however, there is a gap in accuracy between the small model obtained by current distillation methods and the original model, in many cases as large as about 10 points.
  • many distillation schemes currently require a large amount of labeled data, which greatly increases the cost of distillation.
  • the technical problem addressed is that the small model distilled by prior-art distillation methods has an accuracy gap with the original model, and that many distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation.
  • the main purpose of this application is to provide a model distillation method, device, equipment and storage medium, aiming to solve the technical problems that the small model distilled by prior-art distillation methods has an accuracy gap with the original model and that many distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation.
  • in order to achieve the above purpose, the present application proposes a model distillation method; the steps of the method are described below.
  • the application also proposes a model distillation device, which includes:
  • a data acquisition module for acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained through Bert network training;
  • the first-stage distillation module is used to perform overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation;
  • the second-stage distillation module is used to perform hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation;
  • the third-stage distillation module is used to perform hierarchical distillation learning on the student model after the second distillation using the labeled training samples to obtain a trained student model.
  • the present application also proposes a computer device, including a memory and a processor, where the memory stores a computer program and the processor, when executing the computer program, implements the steps of the above model distillation method.
  • the present application also proposes a computer-readable storage medium on which a computer program is stored; when the computer program is executed by a processor, the steps of the above model distillation method are implemented.
  • the model distillation method, device, equipment and storage medium of the present application use the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation, use the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation, and use the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
  • FIG. 1 is a schematic flowchart of a model distillation method according to an embodiment of the application.
  • FIG. 2 is a schematic block diagram of the structure of a model distillation apparatus according to an embodiment of the application.
  • FIG. 3 is a schematic structural block diagram of a computer device according to an embodiment of the present application.
  • This application proposes a model distillation method, which is applied in the field of artificial intelligence technology.
  • the model distillation method uses overall distillation learning in the first distillation and hierarchical distillation learning in the second and third distillations, and the three distillations improve the accuracy of the model obtained after distillation;
  • moreover, the first and second distillations use unlabeled training samples, which reduces the need for labeled training samples and reduces the cost of distillation.
  • a model distillation method is provided in the embodiment of the present application, and the method includes:
  • S1: Obtain a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained through Bert network training;
  • S2: Use the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation;
  • S3: Use the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation;
  • S4: Use the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
  • in this embodiment, the unlabeled training samples and the student model are used to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; the unlabeled training samples and the student model after the first distillation are used to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
  • the pre-training model may be obtained from a database, input by the user, or sent by a third-party application system.
  • the student model may be obtained from a database, input by the user, or sent by a third-party application system.
  • the plurality of labeled training samples may be obtained from a database, input by the user, or sent by a third-party application system.
  • the plurality of unlabeled training samples may be obtained from a database, input by the user, or sent by a third-party application system.
  • the student model includes: Embedding layer, BiLSTM layer, Dense layer.
  • the Embedding layer inputs data to the BiLSTM layer, and the BiLSTM layer outputs data to the Dense layer.
  • the Embedding layer is the embedding layer.
  • the output of the BiLSTM layer is the prediction score for each label.
  • the Dense layer is a fully connected layer that outputs predicted probabilities.
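  • As an illustration of this layout, the sketch below implements the Embedding/BiLSTM/Dense student in PyTorch. It is a minimal reading of the text rather than the claimed implementation: the pooling over the sequence, the hidden size, and the softmax inside the Dense layer are assumptions chosen so that the BiLSTM output can serve directly as a per-label prediction score and the Dense output as a prediction probability.

```python
import torch
import torch.nn as nn

class StudentModel(nn.Module):
    """Minimal Embedding -> BiLSTM -> Dense student (illustrative sizes and pooling)."""

    def __init__(self, vocab_size: int, embed_dim: int, num_labels: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)    # Embedding layer
        self.bilstm = nn.LSTM(embed_dim, num_labels,             # BiLSTM layer
                              bidirectional=True, batch_first=True)
        self.dense = nn.Linear(num_labels, num_labels)           # Dense layer

    def forward(self, token_ids: torch.Tensor):
        x = self.embedding(token_ids)            # (batch, seq_len, embed_dim)
        h, _ = self.bilstm(x)                    # (batch, seq_len, 2 * num_labels)
        fwd, bwd = h.chunk(2, dim=-1)            # split the two LSTM directions
        scores = (fwd + bwd).mean(dim=1)         # prediction score for each label (BiLSTM output)
        probs = torch.softmax(self.dense(scores), dim=-1)  # prediction probabilities (Dense output)
        return scores, probs
```

  • for example, StudentModel(vocab_size=30000, embed_dim=128, num_labels=10)(torch.randint(0, 30000, (4, 16))) returns the per-label scores used in the first distillation and the probabilities used in the second and third distillations.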
  • the labeled training samples include: sample data and sample calibration values, where the sample calibration values are the calibration results of the sample data.
  • Unlabeled training samples include: sample data.
  • the number of labeled training samples in the plurality of labeled training samples is smaller than the number of unlabeled training samples in the plurality of unlabeled training samples.
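  • For concreteness, one way to hold these two kinds of samples in code is sketched below; the container and field names are illustrative, not taken from the application.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Sample:
    """A training sample: unlabeled samples carry only data, labeled ones also a calibration value."""
    data: List[int]               # sample data, e.g. token ids
    label: Optional[int] = None   # sample calibration value; None for unlabeled samples

# Per this embodiment, the labeled set is (much) smaller than the unlabeled set.
labeled_samples = [Sample(data=[101, 2054, 2003, 102], label=2)]
unlabeled_samples = [Sample(data=[101, 7592, 102]), Sample(data=[101, 2129, 2024, 102])]
```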
  • the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation, that is, the parameters of the student model after the second distillation are updated layer by layer, and the student model after the second distillation, once trained in this way, is taken as the trained student model.
  • this avoids the catastrophic forgetting of the prior-art distillation methods, that is, forgetting the content of the second distillation during the third distillation.
  • the above-mentioned steps of using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation include:
  • S21: Input the unlabeled training samples into the pre-training model for score prediction, and obtain the first prediction score output by the score prediction layer of the pre-training model;
  • S22: Input the unlabeled training samples into the student model for score prediction to obtain a second prediction score;
  • S23: Input the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value, update all parameters of the student model according to the first loss value, and use the student model with updated parameters for the next calculation of the second prediction score;
  • S24: Repeat the above steps until the first loss value reaches a first convergence condition or the number of iterations reaches a second convergence condition, and determine the student model for which the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition as the student model after the first distillation.
  • This embodiment computes a loss value from the prediction scores obtained on the unlabeled training samples and updates all parameters of the student model accordingly, so that the knowledge learned by the pre-training model is learned through overall distillation.
  • For S21, input the sample data of the unlabeled training samples into the pre-training model for prediction, and use the score output by the score prediction layer of the pre-training model as the first prediction score.
  • For S22, input the sample data of the unlabeled training samples into the student model for prediction, and use the score output by the BiLSTM layer of the student model as the second prediction score.
  • the first predicted score and the second predicted score are input into the first loss function to calculate the loss value, and the calculated loss value is used as the first loss value.
  • the method for updating all the parameters of the student model according to the first loss value can be selected from the prior art, which will not be repeated here.
  • the first convergence condition means that the magnitudes of the first loss values from two adjacent calculations satisfy the Lipschitz condition (Lipschitz continuity condition).
  • the number of iterations used for the second convergence condition refers to the number of times the student model has been used to calculate the second prediction score; that is, the iteration count increases by 1 after each calculation.
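  • The application leaves the exact form of these convergence checks open; the helper below shows one common reading (the absolute change between adjacent loss values falling below a tolerance, or an iteration cap being reached), purely as an assumption for illustration.

```python
def stage_converged(prev_loss: float, curr_loss: float, iteration: int,
                    tol: float = 1e-4, max_iterations: int = 10_000) -> bool:
    """One possible stopping test for a distillation stage (tolerance and cap are illustrative)."""
    loss_stable = abs(curr_loss - prev_loss) <= tol     # adjacent loss values condition
    cap_reached = iteration >= max_iterations           # iteration-count condition
    return loss_stable or cap_reached
```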
  • the above-mentioned step of inputting the first predicted score and the second predicted score into a first loss function for calculation to obtain a first loss value includes:
  • the first prediction score and the second prediction score are input into the KL divergence loss function for calculation, and the first loss value is obtained.
  • the KL divergence loss function, also known as the K-L divergence loss function, is calculated as KL(p||q) = Σ_x p(x)·log(p(x)/q(x)), where:
  • x is the sample data of the unlabeled training samples;
  • p(x) is the first prediction score;
  • q(x) is the second prediction score;
  • log() is the logarithmic function.
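  • A hedged PyTorch sketch of this first-stage update is given below. It assumes the teacher exposes its score-prediction output, reuses the StudentModel interface sketched earlier (forward returning scores and probabilities), and normalizes the scores with softmax before applying the KL formula; the normalization is an assumption, since the application only states the formula over p(x) and q(x).

```python
import torch
import torch.nn.functional as F

def first_loss(teacher_scores: torch.Tensor, student_scores: torch.Tensor) -> torch.Tensor:
    """First loss value: KL(p || q) = sum_x p(x) * log(p(x) / q(x)), p = teacher, q = student."""
    p = F.softmax(teacher_scores, dim=-1)            # p(x), from the first prediction score
    log_q = F.log_softmax(student_scores, dim=-1)    # log q(x), from the second prediction score
    return (p * (torch.log(p.clamp_min(1e-12)) - log_q)).sum(dim=-1).mean()

def first_stage_step(teacher, student, token_ids, optimizer):
    """One overall-distillation update: every parameter of the student is updated."""
    with torch.no_grad():
        teacher_scores = teacher(token_ids)           # first prediction score (score prediction layer)
    student_scores, _ = student(token_ids)            # second prediction score (BiLSTM output)
    loss = first_loss(teacher_scores, student_scores)
    optimizer.zero_grad()
    loss.backward()                                   # gradients reach all student parameters
    optimizer.step()
    return loss.item()
```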
  • the above step of using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation includes:
  • S31: Input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
  • S32: Input the unlabeled training samples into the student model after the first distillation for probability prediction to obtain a second prediction probability;
  • S33: Input the first prediction probability and the second prediction probability into a second loss function for calculation to obtain a second loss value, update the parameters of the student model after the first distillation according to the second loss value and the first preset hierarchical parameter-update rule, and use the student model after the first distillation with updated parameters for the next calculation of the second prediction probability;
  • S34: Repeat the above steps until the second loss value reaches a third convergence condition or the number of iterations reaches a fourth convergence condition, and determine the student model after the first distillation for which the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition as the student model after the second distillation.
  • This embodiment computes a loss value from the prediction probabilities obtained on the unlabeled training samples and updates the parameters of the student model after the first distillation layer by layer, thereby avoiding the catastrophic forgetting seen in prior-art distillation, that is, forgetting the content of the first distillation during the second distillation.
  • For S31, input the sample data of the unlabeled training samples into the pre-training model for probability prediction, and use the probability output by the probability prediction layer of the pre-training model as the first prediction probability.
  • the sample data of the unlabeled training sample is input into the student model after the first distillation for probability prediction, and the probability output by the Dense layer of the student model after the first distillation is used as the second prediction probability.
  • the first predicted probability and the second predicted probability are input into the second loss function to calculate the loss value, and the calculated loss value is used as the second loss value.
  • according to the second loss value, only the parameters of one layer of the student model after the first distillation (that is, one of the Embedding layer, the BiLSTM layer, and the Dense layer) are updated at a time.
  • the third convergence condition means that the magnitudes of the second loss values from two adjacent calculations satisfy the Lipschitz condition (Lipschitz continuity condition).
  • the number of iterations used for the fourth convergence condition refers to the number of times the student model after the first distillation has been used to calculate the second prediction probability; that is, the iteration count increases by 1 after each calculation.
  • the above step of inputting the first prediction probability and the second prediction probability into a second loss function for calculation to obtain a second loss value and updating the parameters of the student model after the first distillation according to the second loss value and the first preset hierarchical parameter-update rule includes:
  • S331: Input the first prediction probability and the second prediction probability into the MSE loss function for calculation to obtain the second loss value;
  • S332: When the Dense layer parameters have not reached the convergence condition of the first Dense layer, update the parameters of the Dense layer of the student model after the first distillation according to the second loss value; otherwise, when the BiLSTM layer parameters have not reached the convergence condition of the first BiLSTM layer, update the parameters of the BiLSTM layer of the student model after the first distillation according to the second loss value; otherwise, update the parameters of the Embedding layer of the student model after the first distillation according to the second loss value.
  • This embodiment computes a loss value from the prediction probabilities obtained on the unlabeled training samples and updates the parameters of the student model after the first distillation layer by layer, thereby avoiding the catastrophic forgetting seen in prior-art distillation, that is, forgetting the content of the first distillation during the second distillation.
  • the convergence conditions of the first Dense layer and the convergence conditions of the first BiLSTM layer can be set according to training requirements, which are not specifically limited here.
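  • The sketch below shows one way to realize this second stage in PyTorch: an MSE loss between the two prediction probabilities, with the gradient step applied to a single layer chosen by the Dense -> BiLSTM -> Embedding cascade. The per-layer optimizers and the dictionary of convergence flags are illustrative devices, not part of the application's wording.

```python
import torch
import torch.nn.functional as F

def second_stage_step(teacher, student, token_ids, layer_optimizers, layer_converged):
    """One hierarchical update driven by the second loss value (MSE on probabilities).

    layer_optimizers: e.g. {"dense": opt_d, "bilstm": opt_b, "embedding": opt_e}, each built
    over only that layer's parameters; layer_converged: per-layer convergence flags.
    """
    with torch.no_grad():
        teacher_probs = teacher(token_ids)        # first prediction probability (teacher output)
    _, student_probs = student(token_ids)         # second prediction probability (Dense output)
    second_loss = F.mse_loss(student_probs, teacher_probs)

    # First preset hierarchical update rule: Dense first, then BiLSTM, then Embedding.
    if not layer_converged["dense"]:
        layer = "dense"
    elif not layer_converged["bilstm"]:
        layer = "bilstm"
    else:
        layer = "embedding"

    student.zero_grad(set_to_none=True)
    second_loss.backward()
    layer_optimizers[layer].step()                # only the selected layer's parameters move
    return second_loss.item(), layer
```

  • building layer_optimizers as, for example, {"dense": torch.optim.Adam(student.dense.parameters()), ...} keeps each optimizer step confined to one layer, which is what the hierarchical update rule requires.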
  • the above-mentioned steps of using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model include:
  • S41: Input the labeled training samples into the student model after the second distillation for probability prediction to obtain a third prediction probability;
  • S42: Input the third prediction probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, update the parameters of the student model after the second distillation according to the third loss value and the second preset hierarchical parameter-update rule, and use the student model after the second distillation with updated parameters for the next calculation of the third prediction probability;
  • S43: Repeat the above steps until the third loss value reaches a fifth convergence condition or the number of iterations reaches a sixth convergence condition, and determine the student model after the second distillation for which the third loss value reaches the fifth convergence condition or the number of iterations reaches the sixth convergence condition as the trained student model.
  • This embodiment computes a loss value from the prediction probabilities obtained on the labeled training samples and updates the parameters of the student model after the second distillation layer by layer, thereby avoiding the catastrophic forgetting seen in prior-art distillation, that is, forgetting the content of the second distillation during the third distillation.
  • the third predicted probability and the sample calibration value of the labeled training sample are input into the third loss function to calculate the loss value, and the calculated loss value is used as the third loss value.
  • according to the third loss value, only the parameters of one layer of the student model after the second distillation (that is, one of the Embedding layer, the BiLSTM layer, and the Dense layer) are updated at a time.
  • the fifth convergence condition means that the magnitudes of the third loss values from two adjacent calculations satisfy the Lipschitz condition (Lipschitz continuity condition).
  • the number of iterations used for the sixth convergence condition refers to the number of times the student model after the second distillation has been used to calculate the third prediction probability; that is, the iteration count increases by 1 after each calculation.
  • the above step of inputting the third prediction probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value and updating the parameters of the student model after the second distillation according to the third loss value and the second preset hierarchical parameter-update rule includes:
  • S421: Input the third prediction probability and the sample calibration value of the labeled training sample into a cross-entropy loss function for calculation to obtain the third loss value;
  • S422: When the Dense layer parameters have not reached the convergence condition of the second Dense layer, update the parameters of the Dense layer of the student model after the second distillation according to the third loss value; otherwise, when the BiLSTM layer parameters have not reached the convergence condition of the second BiLSTM layer, update the parameters of the BiLSTM layer of the student model after the second distillation according to the third loss value; otherwise, update the parameters of the Embedding layer of the student model after the second distillation according to the third loss value.
  • the parameters of the student model after the second distillation are updated layer by layer according to the loss value computed from the prediction probabilities on the labeled training samples, so as to avoid the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the second distillation during the third distillation.
  • the cross-entropy loss is calculated as L = -Σ_c y_c·log(p_c), where:
  • y_c is the sample calibration value of the labeled training sample;
  • p_c is the third prediction probability.
  • the convergence conditions of the second Dense layer and the convergence conditions of the second BiLSTM layer can be set according to training requirements, which are not specifically limited here.
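  • A short sketch of this third loss is given below; it assumes the student's Dense output is already a probability (as in the earlier student sketch), so the log is taken directly rather than calling a logits-based cross-entropy. The layer-wise update then proceeds exactly as in the second-stage sketch, only driven by this loss and the second Dense/BiLSTM convergence conditions.

```python
import torch
import torch.nn.functional as F

def third_loss(student_probs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Third loss value: cross-entropy -sum_c y_c * log(p_c) against the sample calibration values."""
    log_p = torch.log(student_probs.clamp_min(1e-12))   # log p_c from the third prediction probability
    return F.nll_loss(log_p, labels)                     # picks out -log p_c of the calibrated class y_c
```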
  • the present application also proposes a model distillation device, the device comprising:
  • the data acquisition module 100 is used to acquire a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, wherein the pre-training model is a model obtained based on Bert network training;
  • the first-stage distillation module 200 is used to perform overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation;
  • the second-stage distillation module 300 is configured to perform hierarchical distillation learning on the pre-training model using the unlabeled training samples and the student model after the first distillation, to obtain the student model after the second distillation;
  • the third-stage distillation module 400 is configured to perform hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain a trained student model.
  • with this apparatus, the unlabeled training samples and the student model are used to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; the unlabeled training samples and the student model after the first distillation are used to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
  • the first-stage distillation module 200 includes: a pre-training model score prediction sub-module, a student model score prediction sub-module, and a first-stage distillation training sub-module;
  • the pre-training model scoring prediction submodule is used to input the unlabeled training samples into the pre-training model for scoring prediction, and obtain the first prediction score output by the scoring prediction layer of the pre-training model;
  • the student model scoring prediction sub-module is used to input the unlabeled training samples into the student model for scoring prediction to obtain a second prediction score
  • the first-stage distillation training sub-module is used to input the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value, update all parameters of the student model according to the first loss value, and use the student model with updated parameters for the next calculation of the second prediction score; the above steps are repeated until the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition, and the student model whose first loss value reaches the first convergence condition or whose number of iterations reaches the second convergence condition is determined as the student model after the first distillation.
  • the first-stage distillation training sub-module includes: a first loss value calculation unit;
  • the first loss value calculation unit is configured to input the first prediction score and the second prediction score into the KL divergence loss function for calculation, and obtain the first loss value.
  • the second-stage distillation module 300 includes: a pre-training model probability prediction sub-module, a student model probability prediction sub-module after the first distillation, and a second-stage distillation training sub-module;
  • the pre-training model probability prediction sub-module is configured to input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
  • the student model probability prediction submodule after the first distillation is used to input the unlabeled training samples into the student model after the first distillation for probability prediction, and obtain a second predicted probability;
  • the second-stage distillation training sub-module is used to input the first prediction probability and the second prediction probability into a second loss function for calculation to obtain a second loss value, update the parameters of the student model after the first distillation according to the second loss value and the first preset hierarchical parameter-update rule, and use the student model after the first distillation with updated parameters for the next calculation of the second prediction probability; the above steps are repeated until the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition, and the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition is determined as the student model after the second distillation.
  • the second-stage distillation training sub-module includes: a second loss value calculation unit, and a first parameter update unit;
  • the second loss value calculation unit is configured to input the first prediction probability and the second prediction probability into an MSE loss function for calculation to obtain the second loss value;
  • the first parameter updating unit is configured to: when the Dense layer parameters have not reached the convergence condition of the first Dense layer, update the parameters of the Dense layer of the student model after the first distillation according to the second loss value; otherwise, when the BiLSTM layer parameters have not reached the convergence condition of the first BiLSTM layer, update the parameters of the BiLSTM layer of the student model after the first distillation according to the second loss value; otherwise, update the parameters of the Embedding layer of the student model after the first distillation according to the second loss value.
  • the third-stage distillation module 400 includes: a student model probability prediction sub-module after the second distillation, and a third-stage distillation training sub-module;
  • the student model probability prediction submodule after the second distillation is used to input the labeled training sample into the student model after the second distillation for probability prediction, and obtain a third prediction probability;
  • the third-stage distillation training sub-module is used to input the third prediction probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, update the parameters of the student model after the second distillation according to the third loss value and the second preset hierarchical parameter-update rule, and use the student model after the second distillation with updated parameters for the next calculation of the third prediction probability; the above steps are repeated until the third loss value reaches the fifth convergence condition or the number of iterations reaches the sixth convergence condition, and the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition is determined as the trained student model.
  • the third-stage distillation training sub-module includes: a third loss value calculation unit and a second parameter update unit;
  • the third loss value calculation unit is configured to input the third prediction probability and the sample calibration value of the labeled training sample into a cross-entropy loss function for calculation to obtain the third loss value;
  • the second parameter updating unit is configured to: when the Dense layer parameters have not reached the convergence condition of the second Dense layer, update the parameters of the Dense layer of the student model after the second distillation according to the third loss value; otherwise, when the BiLSTM layer parameters have not reached the convergence condition of the second BiLSTM layer, update the parameters of the BiLSTM layer of the student model after the second distillation according to the third loss value; otherwise, update the parameters of the Embedding layer of the student model after the second distillation according to the third loss value.
  • an embodiment of the present application further provides a computer device.
  • the computer device may be a server, and its internal structure may be as shown in FIG. 3 .
  • the computer device includes a processor, a memory, a network interface, and a database connected by a system bus, where the processor of the computer device is used to provide computing and control capabilities.
  • the memory of the computer device includes a non-volatile storage medium and an internal memory.
  • the non-volatile storage medium stores an operating system, a computer program, and a database.
  • the internal memory provides an environment for the operation of the operating system and the computer program in the non-volatile storage medium.
  • the database of the computer device is used to store data used by the model distillation method.
  • the network interface of the computer device is used to communicate with an external terminal through a network connection.
  • the computer program, when executed by the processor, implements a model distillation method.
  • the model distillation method includes: obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training; using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
  • with this computer device, the unlabeled training samples and the student model are used to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; the unlabeled training samples and the student model after the first distillation are used to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
  • An embodiment of the present application also provides a computer-readable storage medium on which a computer program is stored.
  • when the computer program is executed by a processor, a model distillation method is implemented, including the steps of: acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training; using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
  • the model distillation method implemented as above uses the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation, uses the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation, and uses the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
  • the computer storage medium can be non-volatile or volatile.
  • Nonvolatile memory may include read only memory (ROM), programmable ROM (PROM), electrically programmable ROM (EPROM), electrically erasable programmable ROM (EEPROM), or flash memory.
  • Volatile memory may include random access memory (RAM) or external cache memory.
  • RAM is available in various forms, such as static RAM (SRAM), dynamic RAM (DRAM), synchronous DRAM (SDRAM), double data rate SDRAM (DDR SDRAM), enhanced SDRAM (ESDRAM), Synchlink DRAM (SLDRAM), Rambus direct RAM (RDRAM), direct Rambus dynamic RAM (DRDRAM), and Rambus dynamic RAM (RDRAM), etc.

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Other Investigation Or Analysis Of Materials By Electrical Means (AREA)

Abstract

A model distillation method and apparatus, a device, and a storage medium, relating to the technical field of artificial intelligence. The method comprises: acquiring a pre-trained model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, the pre-trained model being a model obtained by training on the basis of a Bert network (S1); using the unlabeled training samples and the student model to perform overall distillation learning on the pre-trained model to obtain a student model after first distillation (S2); using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-trained model to obtain a student model after second distillation (S3); and using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model (S4). Thus, the accuracy of the model obtained after distillations is improved by means of three distillations, the requirements for labeling of training samples are reduced, and distillation costs are reduced.

Description

Model distillation method, device, equipment and storage medium
This application claims priority to the Chinese patent application filed with the China Patent Office on February 26, 2021, with application number 2021102205129 and entitled "Model distillation method, device, equipment and storage medium", the entire contents of which are incorporated by reference in this application.
Technical Field
The present application relates to the technical field of artificial intelligence, and in particular, to a model distillation method, device, equipment and storage medium.
Background
At present, pre-training models have strong encoding and generalization ability, and using a pre-training model for downstream tasks can greatly reduce the amount of labeled data required, so such models play a large role in many fields. However, because a pre-training model usually has a large number of parameters, it cannot be used online.
The inventor realizes that the prior art reduces the number of parameters and improves inference speed by distilling a pre-trained model with a large number of parameters into a model with a small number of parameters. However, there is a gap in accuracy between the small model obtained by current distillation methods and the original model, in many cases as large as about 10 points. At the same time, many current distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation.
Technical Problem
The distillation methods of the prior art leave a gap in accuracy between the distilled small model and the original model, and many distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation.
Technical Solution
The main purpose of this application is to provide a model distillation method, device, equipment and storage medium, aiming to solve the technical problems that the small model distilled by prior-art distillation methods has an accuracy gap with the original model and that many distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation.
In order to achieve the above purpose, the present application proposes a model distillation method, the method including:
obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training;
using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain a student model after the first distillation;
using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain a student model after the second distillation; and
using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
The present application also proposes a model distillation apparatus, the apparatus including:
a data acquisition module for acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training;
a first-stage distillation module for performing overall distillation learning on the pre-training model using the unlabeled training samples and the student model to obtain a student model after the first distillation;
a second-stage distillation module for performing hierarchical distillation learning on the pre-training model using the unlabeled training samples and the student model after the first distillation to obtain a student model after the second distillation; and
a third-stage distillation module for performing hierarchical distillation learning on the student model after the second distillation using the labeled training samples to obtain a trained student model.
The present application also proposes a computer device, including a memory and a processor, where the memory stores a computer program and the processor, when executing the computer program, implements the following method steps:
obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training;
using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain a student model after the first distillation;
using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain a student model after the second distillation; and
using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
The present application also proposes a computer-readable storage medium on which a computer program is stored, and when the computer program is executed by a processor, the following method steps are implemented:
obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training;
using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain a student model after the first distillation;
using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain a student model after the second distillation; and
using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
Beneficial Effects
The model distillation method, device, equipment and storage medium of the present application use the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation, use the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation, and use the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
Brief Description of the Drawings
FIG. 1 is a schematic flowchart of a model distillation method according to an embodiment of the present application;
FIG. 2 is a schematic structural block diagram of a model distillation apparatus according to an embodiment of the present application;
FIG. 3 is a schematic structural block diagram of a computer device according to an embodiment of the present application.
The realization of the purpose, functional features and advantages of the present application will be further described with reference to the accompanying drawings in conjunction with the embodiments.
Embodiments of the Present Invention
In order to make the purpose, technical solutions and advantages of the present application clearer, the present application is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are only used to explain the present application and are not intended to limit it.
In order to solve the technical problems that the small model distilled by prior-art distillation methods has an accuracy gap with the original model and that many distillation schemes require a large amount of labeled data, which greatly increases the cost of distillation, the present application proposes a model distillation method applied in the field of artificial intelligence technology. The model distillation method uses overall distillation learning in the first distillation and hierarchical distillation learning in the second and third distillations, so that the three distillations improve the accuracy of the model obtained after distillation; moreover, the first and second distillations use unlabeled training samples, which reduces the need for labeled training samples and reduces the cost of distillation.
Referring to FIG. 1, an embodiment of the present application provides a model distillation method, the method including:
S1: obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training;
S2: using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain a student model after the first distillation;
S3: using the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain a student model after the second distillation;
S4: using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
In this embodiment, the unlabeled training samples and the student model are used to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation; the unlabeled training samples and the student model after the first distillation are used to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation; and the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation to obtain the trained student model. The three distillations improve the accuracy of the model obtained after distillation, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is reduced.
For S1, the pre-training model may be obtained from a database, input by the user, or sent by a third-party application system.
The student model may be obtained from a database, input by the user, or sent by a third-party application system.
The plurality of labeled training samples may be obtained from a database, input by the user, or sent by a third-party application system.
The plurality of unlabeled training samples may be obtained from a database, input by the user, or sent by a third-party application system.
The student model includes an Embedding layer, a BiLSTM layer, and a Dense layer. The Embedding layer feeds data to the BiLSTM layer, and the BiLSTM layer feeds data to the Dense layer. The Embedding layer is an embedding layer. The output of the BiLSTM layer is the prediction score for each label. The Dense layer is a fully connected layer that outputs prediction probabilities.
A labeled training sample includes sample data and a sample calibration value, where the sample calibration value is the calibration result of the sample data.
An unlabeled training sample includes sample data.
Optionally, the number of labeled training samples in the plurality of labeled training samples is smaller than the number of unlabeled training samples in the plurality of unlabeled training samples.
For S2, the unlabeled training samples and the student model are used to perform overall distillation learning on the pre-training model, that is, all parameters of the student model are updated, and the trained student model is taken as the student model after the first distillation.
For S3, the unlabeled training samples and the student model after the first distillation are used to perform hierarchical distillation learning on the pre-training model, that is, the parameters of the student model after the first distillation are updated layer by layer, and the trained student model after the first distillation is taken as the student model after the second distillation. This avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the first distillation during the second distillation.
For S4, the labeled training samples are used to perform hierarchical distillation learning on the student model after the second distillation, that is, the parameters of the student model after the second distillation are updated layer by layer, and the trained student model after the second distillation is taken as the trained student model. This avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the second distillation during the third distillation.
在一个实施例中,上述采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型的步骤,包括:In one embodiment, the above-mentioned steps of using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation include:
S21:将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;S21: Input the unlabeled training samples into the pre-training model for scoring prediction, and obtain the first prediction score output by the scoring prediction layer of the pre-training model;
S22:将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;S22: Input the unlabeled training samples into the student model for scoring prediction to obtain a second prediction score;
S23:将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值,根据所述第一损失值更新所述学生模型的所有参数,将更新参数后的所述学生模型用于下一次计算所述第二预测评分;S23: Input the first predicted score and the second predicted score into a first loss function for calculation to obtain a first loss value, update all parameters of the student model according to the first loss value, and update the parameters after updating. The said student model is used for the next calculation of the second predicted score;
S24:重复执行上述方法步骤直至所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件,将所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件的所述学生模型,确定为所述第一次蒸馏后的学生模型。S24: Repeat the above method steps until the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition, and set the first loss value to meet the first convergence condition or the number of iterations to reach the second convergence condition The student model is determined as the student model after the first distillation.
本实施例实现了根据未标注的训练样本预测得到预测评分计算损失值对所述学生模型的所有参数进行更新,实现了整体蒸馏学习所述预训练模型学习到的知识。This embodiment realizes that all parameters of the student model are updated according to the prediction score obtained from the unlabeled training sample prediction and the loss value is calculated, and the knowledge learned by the pre-training model is learned by overall distillation.
对于S21,将所述未标注的训练样本的样本数据输入所述预训练模型进行预测,将所述预训练模型的评分预测层输出的评分作为第一预测评分。For S21, input the sample data of the unlabeled training samples into the pre-training model for prediction, and use the score output by the score prediction layer of the pre-training model as the first prediction score.
对于S22,将所述未标注的训练样本的样本数据输入所述学生模型进行预测,将所述学生模型的BiLSTM层输出的评分作为第二预测评分。For S22, input the sample data of the unlabeled training samples into the student model for prediction, and use the score output by the BiLSTM layer of the student model as the second prediction score.
对于S23,将所述第一预测评分、所述第二预测评分输入第一损失函数进行损失值计算,将计算得到的损失值作为第一损失值。For S23, the first predicted score and the second predicted score are input into the first loss function to calculate the loss value, and the calculated loss value is used as the first loss value.
根据所述第一损失值更新所述学生模型的所有参数的方法可以从现有技术中选择,在此不做赘述。The method for updating all the parameters of the student model according to the first loss value can be selected from the prior art, which will not be repeated here.
For S24, the first convergence condition means that the difference between the first loss values obtained in two consecutive calculations satisfies the Lipschitz condition (Lipschitz continuity condition).
The number of iterations used for the second convergence condition is the number of times the student model has been used to calculate the second prediction score; that is, each calculation increases the iteration count by 1.
It can be understood that, when the first loss value has not reached the first convergence condition and the number of iterations has not reached the second convergence condition, a new unlabeled training sample is obtained from the plurality of unlabeled training samples, and steps S21 to S24 are executed on the acquired unlabeled training sample.
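A minimal PyTorch-style sketch of the S21-S24 loop is given below, for illustration only. It assumes the teacher and student can be called directly on a batch to produce score vectors, applies a softmax before the KL divergence, and uses an Adam optimizer; these interface details, the learning rate, and the convergence test are assumptions rather than the patented implementation.

```python
import torch
import torch.nn.functional as F

def overall_distill(teacher, student, unlabeled_loader, max_steps=10000, tol=1e-4):
    """First distillation stage (S21-S24): every student parameter is updated."""
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
    prev_loss, step = None, 0
    for x in unlabeled_loader:                      # unlabeled training samples
        with torch.no_grad():
            p = teacher(x)                          # S21: first prediction score (teacher)
        q = student(x)                              # S22: second prediction score (student)
        # S23: KL divergence between the teacher and student score distributions.
        loss = F.kl_div(F.log_softmax(q, dim=-1),
                        F.softmax(p, dim=-1), reduction="batchmean")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                            # update ALL student parameters
        step += 1
        # S24: stop when the loss change is small (first condition) or the
        # iteration budget is exhausted (second condition).
        if (prev_loss is not None and abs(prev_loss - loss.item()) < tol) or step >= max_steps:
            break
        prev_loss = loss.item()
    return student                                  # student model after the first distillation
```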
在一个实施例中,上述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:In one embodiment, the above-mentioned step of inputting the first predicted score and the second predicted score into a first loss function for calculation to obtain a first loss value includes:
将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。The first prediction score and the second prediction score are input into the KL divergence loss function for calculation, and the first loss value is obtained.
KL散度损失函数,又称为K-L散度损失函数。KL divergence loss function, also known as K-L divergence loss function.
KL散度损失函数KL(p||q)的计算公式为:The calculation formula of the KL divergence loss function KL(p||q) is:
KL(p \| q) = \sum_{x} p(x)\,\log\frac{p(x)}{q(x)}
其中,x是所述未标注的训练样本的样本数据,p(x)是第一预测评分,q(x)是第二预测评分,log()是对数函数。where x is the sample data of the unlabeled training samples, p(x) is the first predicted score, q(x) is the second predicted score, and log() is a logarithmic function.
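For reference, the formula above can be written directly in plain Python; here p and q are assumed to be discrete distributions given as lists of probabilities over the same support (an illustrative convention, not stated in the disclosure).

```python
import math

def kl_divergence(p, q):
    """KL(p || q) = sum_x p(x) * log(p(x) / q(x)); p is the teacher's prediction,
    q the student's. Zero-probability teacher entries contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Example: identical distributions give a divergence of 0.
assert kl_divergence([0.5, 0.5], [0.5, 0.5]) == 0.0
```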
In one embodiment, the above step of performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation includes:
S31:将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;S31: Input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
S32:将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;S32: Input the unlabeled training sample into the student model after the first distillation for probability prediction, and obtain a second predicted probability;
S33: Input the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, update the parameters of the student model after the first distillation layer by layer according to the second loss value and the first preset parameter hierarchical update rule, and use the student model after the first distillation with the updated parameters for the next calculation of the second predicted probability;
S34: Repeat the above method steps until the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition, and determine the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition as the student model after the second distillation.
In this embodiment, a predicted probability is obtained from the unlabeled training samples, a loss value is calculated from it, and the parameters of the student model after the first distillation are updated layer by layer, which avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the first distillation during the second distillation.
对于S31,将所述未标注的训练样本的样本数据输入所述预训练模型进行概率预测,将所述预训练模型的概率预测层输出的概率作为第一预测概率。For S31, input the sample data of the unlabeled training samples into the pre-training model for probability prediction, and use the probability output by the probability prediction layer of the pre-training model as the first prediction probability.
对于S32,将所述未标注的训练样本的样本数据输入所述第一次蒸馏后学生模型进行概率预测,将所述第一次蒸馏后学生模型的Dense层输出的概率作为第二预测概率。For S32, the sample data of the unlabeled training sample is input into the student model after the first distillation for probability prediction, and the probability output by the Dense layer of the student model after the first distillation is used as the second prediction probability.
对于S33,将所述第一预测概率、所述第二预测概率输入第二损失函数进行损失值计算,将计算得到的损失值作为第二损失值。For S33, the first predicted probability and the second predicted probability are input into the second loss function to calculate the loss value, and the calculated loss value is used as the second loss value.
According to the second loss value, only one layer of the student model after the first distillation (that is, the Embedding layer, the BiLSTM layer, or the Dense layer) is updated at a time.
For S34, the third convergence condition means that the difference between the second loss values obtained in two consecutive calculations satisfies the Lipschitz condition (Lipschitz continuity condition).
The number of iterations used for the fourth convergence condition is the number of times the student model after the first distillation has been used to calculate the second predicted probability; that is, each calculation increases the iteration count by 1.
It can be understood that, when the second loss value has not reached the third convergence condition and the number of iterations has not reached the fourth convergence condition, a new unlabeled training sample is obtained from the plurality of unlabeled training samples, and steps S31 to S34 are executed on the acquired unlabeled training sample.
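The S31-S34 loop differs from the first stage mainly in that the loss is an MSE over predicted probabilities and that only one layer is updated per step. The hedged sketch below illustrates this; update_one_layer is a placeholder for the first preset parameter hierarchical update rule (a possible implementation is sketched after the MSE formula below), and the model interfaces are assumptions.

```python
import torch
import torch.nn.functional as F

def layered_distill_unlabeled(teacher, student, unlabeled_loader, update_one_layer,
                              max_steps=10000, tol=1e-4):
    """Second distillation stage (S31-S34): layer-wise updates on unlabeled data."""
    prev_loss, step = None, 0
    for x in unlabeled_loader:
        with torch.no_grad():
            p = teacher(x)                    # S31: first predicted probability (teacher)
        q = student(x)                        # S32: second predicted probability (student)
        loss = F.mse_loss(q, p)               # S33: second loss value (MSE)
        # Only one layer (Embedding, BiLSTM or Dense) is touched per step,
        # which is what limits forgetting of the first-stage knowledge.
        update_one_layer(student, loss)
        step += 1
        if (prev_loss is not None and abs(prev_loss - loss.item()) < tol) or step >= max_steps:
            break                             # S34: third/fourth convergence conditions
        prev_loss = loss.item()
    return student                            # student model after the second distillation
```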
In one embodiment, the above step of inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, and updating the parameters of the student model after the first distillation according to the second loss value and the first preset parameter hierarchical update rule, includes:
S331:将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计 算,得到所述第二损失值;S331: Input the first predicted probability and the second predicted probability into the MSE loss function for calculation to obtain the second loss value;
S332: When the Dense layer parameters in the second loss value have not reached the first Dense layer convergence condition, update the parameters of the Dense layer of the student model after the first distillation according to the Dense layer parameters in the second loss value; otherwise, when the BiLSTM layer parameters in the second loss value have not reached the first BiLSTM layer convergence condition, update the parameters of the BiLSTM layer of the student model after the first distillation according to the BiLSTM layer parameters in the second loss value; otherwise, update the parameters of the Embedding layer of the student model after the first distillation according to the Embedding layer parameters in the second loss value.
In this embodiment, a predicted probability is obtained from the unlabeled training samples, a loss value is calculated from it, and the parameters of the student model after the first distillation are updated layer by layer, which avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the first distillation during the second distillation.
对于S331,MSE损失函数公式MSE(p,q)如下:For S331, the MSE loss function formula MSE(p,q) is as follows:
MSE(p, q) = \frac{1}{n}\sum_{t=1}^{n}\left(p_t - q_t\right)^2
where p_t is the first predicted probability, q_t is the second predicted probability, and n is the number of terms in the summation.
对于S332,第一Dense层收敛条件、第一BiLSTM层收敛条件可以根据训练需求设置,在此不做具体限定。For S332, the convergence conditions of the first Dense layer and the convergence conditions of the first BiLSTM layer can be set according to training requirements, which are not specifically limited here.
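One way to read the S331/S332 rule is: back-propagate the MSE loss, then apply the resulting gradients to only one of the Dense, BiLSTM or Embedding layers, moving on to the next layer once the current one is considered converged. The sketch below is an interpretation under that reading; the attribute names student.dense, student.bilstm and student.embedding, the gradient-norm convergence test, and the plain gradient step are assumptions, not the patented criteria.

```python
import torch

def update_one_layer(student, loss, lr=1e-4, grad_tol=1e-3):
    """Illustrative first preset parameter hierarchical update rule (S332):
    prefer the Dense layer, fall back to BiLSTM, then to Embedding."""
    loss.backward()                                   # gradients for all layers
    for layer in (student.dense, student.bilstm, student.embedding):
        grads = [p.grad for p in layer.parameters() if p.grad is not None]
        if not grads:
            continue
        grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
        if grad_norm > grad_tol:                      # layer not yet "converged"
            with torch.no_grad():
                for p in layer.parameters():          # gradient step on this layer only
                    p -= lr * p.grad
            break                                     # update exactly one layer per step
    student.zero_grad()                               # clear gradients for the next step
```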
在一个实施例中,上述采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型的步骤,包括:In one embodiment, the above-mentioned steps of using the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model include:
S41:将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;S41: Input the labeled training sample into the student model after the second distillation for probability prediction, and obtain a third predicted probability;
S42: Input the third predicted probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, update the parameters of the student model after the second distillation layer by layer according to the third loss value and the second preset parameter hierarchical update rule, and use the student model after the second distillation with the updated parameters for the next calculation of the third predicted probability;
S43: Repeat the above method steps until the third loss value reaches the fifth convergence condition or the number of iterations reaches the sixth convergence condition, and determine the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition as the trained student model.
In this embodiment, a predicted probability is obtained from the labeled training samples, a loss value is calculated from it and the sample calibration values, and the parameters of the student model after the second distillation are updated layer by layer, which avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the second distillation during the third distillation.
对于S41,将所述带标注的训练样本的样本数据输入所述第二次蒸馏后的学生模型进行概率预测,将所述第二次蒸馏后的学生模型的Dense层输出的概率作为第三预测概率。For S41, input the sample data of the labeled training samples into the student model after the second distillation for probability prediction, and use the probability output by the Dense layer of the student model after the second distillation as the third prediction probability.
对于S42,将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行损失值计算,将计算得到的损失值作为第三损失值。For S42, the third predicted probability and the sample calibration value of the labeled training sample are input into the third loss function to calculate the loss value, and the calculated loss value is used as the third loss value.
According to the third loss value, only one layer of the student model after the second distillation (that is, the Embedding layer, the BiLSTM layer, or the Dense layer) is updated at a time.
For S43, the fifth convergence condition means that the difference between the third loss values obtained in two consecutive calculations satisfies the Lipschitz condition (Lipschitz continuity condition).
The number of iterations used for the sixth convergence condition is the number of times the student model after the second distillation has been used to calculate the third predicted probability; that is, each calculation increases the iteration count by 1.
It can be understood that, when the third loss value has not reached the fifth convergence condition and the number of iterations has not reached the sixth convergence condition, a new labeled training sample is obtained from the plurality of labeled training samples, and steps S41 to S43 are executed on the acquired labeled training sample.
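The third stage mirrors the second, except that supervision now comes from the labels instead of the teacher and the loss is a cross-entropy. A brief sketch follows, under the same assumptions as the earlier ones (hypothetical model interface, labels given as class indices, update_one_layer standing in for the second preset parameter hierarchical update rule).

```python
import torch
import torch.nn.functional as F

def layered_distill_labeled(student, labeled_loader, update_one_layer,
                            max_steps=10000, tol=1e-4):
    """Third distillation stage (S41-S43): layer-wise fine-tuning on labeled data."""
    prev_loss, step = None, 0
    for x, y in labeled_loader:               # sample data and sample calibration value
        logits = student(x)                   # S41: third predicted probability (as logits)
        loss = F.cross_entropy(logits, y)     # S42: third loss value (cross-entropy)
        update_one_layer(student, loss)       # second preset parameter hierarchical update rule
        step += 1
        if (prev_loss is not None and abs(prev_loss - loss.item()) < tol) or step >= max_steps:
            break                             # S43: fifth/sixth convergence conditions
        prev_loss = loss.item()
    return student                            # trained student model
```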
In one embodiment, the above step of inputting the third predicted probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, and updating the parameters of the student model after the second distillation according to the third loss value and the second preset parameter hierarchical update rule, includes:
S421:将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;S421: Input the third prediction probability and the sample calibration value of the labeled training sample into a cross-entropy loss function for calculation, to obtain the third loss value;
S422: When the Dense layer parameters in the third loss value have not reached the second Dense layer convergence condition, update the parameters of the Dense layer of the student model after the second distillation according to the Dense layer parameters in the third loss value; otherwise, when the BiLSTM layer parameters in the third loss value have not reached the second BiLSTM layer convergence condition, update the parameters of the BiLSTM layer of the student model after the second distillation according to the BiLSTM layer parameters in the third loss value; otherwise, update the parameters of the Embedding layer of the student model after the second distillation according to the Embedding layer parameters in the third loss value.
In this embodiment, a predicted probability is obtained from the labeled training samples, a loss value is calculated from it and the sample calibration values, and the parameters of the student model after the second distillation are updated layer by layer, which avoids the catastrophic forgetting of prior-art distillation, that is, forgetting the content of the second distillation during the third distillation.
对于S421,交叉熵损失函数CE的计算公式如下:For S421, the calculation formula of the cross entropy loss function CE is as follows:
CE = -\sum_{c} y_c \log\left(p_c\right)
where y_c is the sample calibration value of the labeled training sample and p_c is the third predicted probability.
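Written out in plain Python, the formula reads as follows; y is assumed to be a one-hot label vector and p the student's predicted distribution (an illustrative convention, not stated in the disclosure).

```python
import math

def cross_entropy(y, p, eps=1e-12):
    """CE(y, p) = -sum_c y_c * log(p_c); eps guards against log(0)."""
    return -sum(yc * math.log(pc + eps) for yc, pc in zip(y, p))

# Example: a confident correct prediction yields a small loss.
print(round(cross_entropy([0, 1, 0], [0.1, 0.8, 0.1]), 4))  # ~0.2231
```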
对于S422,第二Dense层收敛条件、第二BiLSTM层收敛条件可以根据训练需求设置,在此不做具体限定。For S422, the convergence conditions of the second Dense layer and the convergence conditions of the second BiLSTM layer can be set according to training requirements, which are not specifically limited here.
参照图2,本申请还提出了一种模型蒸馏装置,所述装置包括:Referring to Figure 2, the present application also proposes a model distillation device, the device comprising:
数据获取模块100,用于获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;The data acquisition module 100 is used to acquire a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, wherein the pre-training model is a model obtained based on Bert network training;
第一阶段蒸馏模块200,用于采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;The first-stage distillation module 200 is used to perform overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation;
第二阶段蒸馏模块300,用于采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;The second-stage distillation module 300 is configured to perform hierarchical distillation learning on the pre-training model using the unlabeled training samples and the student model after the first distillation, to obtain the student model after the second distillation;
第三阶段蒸馏模块400,用于采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。The third-stage distillation module 400 is configured to perform hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain a trained student model.
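Purely as an illustration of how the four modules of the apparatus map onto the method steps (the class and attribute names below are invented for the sketch and are not part of the disclosure):

```python
class ModelDistillationApparatus:
    """Hypothetical decomposition mirroring modules 100-400 of the apparatus."""
    def __init__(self, data_acquisition, stage1_distill, stage2_distill, stage3_distill):
        self.data_acquisition = data_acquisition    # module 100: teacher, student, samples
        self.stage1_distill = stage1_distill        # module 200: overall distillation (unlabeled)
        self.stage2_distill = stage2_distill        # module 300: hierarchical distillation (unlabeled)
        self.stage3_distill = stage3_distill        # module 400: hierarchical distillation (labeled)

    def run(self):
        teacher, student, labeled, unlabeled = self.data_acquisition()
        student = self.stage1_distill(teacher, student, unlabeled)
        student = self.stage2_distill(teacher, student, unlabeled)
        return self.stage3_distill(student, labeled)  # trained student model
```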
In this embodiment, overall distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation; hierarchical distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation; and hierarchical distillation learning is performed on the student model after the second distillation by using the labeled training samples to obtain the trained student model. The three distillations improve the accuracy of the distilled model, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is lowered.
在一个实施例中,所述第一阶段蒸馏模块200,包括:预训练模型评分预测子模块、学生模型评分预测子模块、第一阶段蒸馏训练子模块;In one embodiment, the first-stage distillation module 200 includes: a pre-training model score prediction sub-module, a student model score prediction sub-module, and a first-stage distillation training sub-module;
所述预训练模型评分预测子模块,用于将所述未标注的训练样本输入所述预 训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;The pre-training model scoring prediction submodule is used to input the unlabeled training samples into the pre-training model for scoring prediction, and obtain the first prediction score output by the scoring prediction layer of the pre-training model;
所述学生模型评分预测子模块,用于将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;The student model scoring prediction sub-module is used to input the unlabeled training samples into the student model for scoring prediction to obtain a second prediction score;
The first-stage distillation training sub-module is configured to input the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value, update all parameters of the student model according to the first loss value, use the student model with the updated parameters for the next calculation of the second prediction score, repeat the above steps until the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition, and determine the student model whose first loss value reaches the first convergence condition or whose number of iterations reaches the second convergence condition as the student model after the first distillation.
在一个实施例中,所述第一阶段蒸馏训练子模块包括:第一损失值计算单元;In one embodiment, the first-stage distillation training sub-module includes: a first loss value calculation unit;
所述第一损失值计算单元,用于将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。The first loss value calculation unit is configured to input the first prediction score and the second prediction score into the KL divergence loss function for calculation, and obtain the first loss value.
在一个实施例中,所述第二阶段蒸馏模块300包括:预训练模型概率预测子模块、第一次蒸馏后的学生模型概率预测子模块、第二阶段蒸馏训练子模块;In one embodiment, the second-stage distillation module 300 includes: a pre-training model probability prediction sub-module, a student model probability prediction sub-module after the first distillation, and a second-stage distillation training sub-module;
所述预训练模型概率预测子模块,用于将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;The pre-training model probability prediction sub-module is configured to input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
所述第一次蒸馏后的学生模型概率预测子模块,用于将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;The student model probability prediction submodule after the first distillation is used to input the unlabeled training samples into the student model after the first distillation for probability prediction, and obtain a second predicted probability;
The second-stage distillation training sub-module is configured to input the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, update the parameters of the student model after the first distillation layer by layer according to the second loss value and the first preset parameter hierarchical update rule, use the student model after the first distillation with the updated parameters for the next calculation of the second predicted probability, repeat the above steps until the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition, and determine the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition as the student model after the second distillation.
在一个实施例中,所述第二阶段蒸馏训练子模块包括:第二损失值计算单元、第一参数更新单元;In one embodiment, the second-stage distillation training sub-module includes: a second loss value calculation unit, and a first parameter update unit;
所述第二损失值计算单元,用于将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;the second loss value calculation unit, configured to input the first predicted probability and the second predicted probability into an MSE loss function for calculation, to obtain the second loss value;
The first parameter update unit is configured to: when the Dense layer parameters in the second loss value have not reached the first Dense layer convergence condition, update the parameters of the Dense layer of the student model after the first distillation according to the Dense layer parameters in the second loss value; otherwise, when the BiLSTM layer parameters in the second loss value have not reached the first BiLSTM layer convergence condition, update the parameters of the BiLSTM layer of the student model after the first distillation according to the BiLSTM layer parameters in the second loss value; otherwise, update the parameters of the Embedding layer of the student model after the first distillation according to the Embedding layer parameters in the second loss value.
在一个实施例中,所述第三阶段蒸馏模块400包括:第二次蒸馏后的学生模型概率预测子模块、第三阶段蒸馏训练子模块;In one embodiment, the third-stage distillation module 400 includes: a student model probability prediction sub-module after the second distillation, and a third-stage distillation training sub-module;
所述第二次蒸馏后的学生模型概率预测子模块,用于将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;The student model probability prediction submodule after the second distillation is used to input the labeled training sample into the student model after the second distillation for probability prediction, and obtain a third prediction probability;
The third-stage distillation training sub-module is configured to input the third predicted probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, update the parameters of the student model after the second distillation layer by layer according to the third loss value and the second preset parameter hierarchical update rule, use the student model after the second distillation with the updated parameters for the next calculation of the third predicted probability, repeat the above steps until the third loss value reaches the fifth convergence condition or the number of iterations reaches the sixth convergence condition, and determine the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition as the trained student model.
在一个实施例中,所述第三阶段蒸馏训练子模块包括:第三损失值计算单元、第二参数更新单元;In one embodiment, the third-stage distillation training sub-module includes: a third loss value calculation unit and a second parameter update unit;
所述第三损失值计算单元,用于将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;The third loss value calculation unit is configured to input the third prediction probability and the sample calibration value of the labeled training sample into a cross-entropy loss function for calculation to obtain the third loss value;
The second parameter update unit is configured to: when the Dense layer parameters in the third loss value have not reached the second Dense layer convergence condition, update the parameters of the Dense layer of the student model after the second distillation according to the Dense layer parameters in the third loss value; otherwise, when the BiLSTM layer parameters in the third loss value have not reached the second BiLSTM layer convergence condition, update the parameters of the BiLSTM layer of the student model after the second distillation according to the BiLSTM layer parameters in the third loss value; otherwise, update the parameters of the Embedding layer of the student model after the second distillation according to the Embedding layer parameters in the third loss value.
Referring to FIG. 3, an embodiment of the present application further provides a computer device. The computer device may be a server, and its internal structure may be as shown in FIG. 3. The computer device includes a processor, a memory, a network interface, and a database connected by a system bus. The processor of the computer device is used to provide computing and control capabilities. The memory of the computer device includes a non-volatile storage medium and an internal memory. The non-volatile storage medium stores an operating system, a computer program, and a database. The internal memory provides an environment for running the operating system and the computer program stored in the non-volatile storage medium. The database of the computer device is used to store data for the model distillation method. The network interface of the computer device is used to communicate with an external terminal through a network connection. When the computer program is executed by the processor, a model distillation method is implemented. The model distillation method includes: acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training; performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation; performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation; and performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain the trained student model.
In this embodiment, overall distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation; hierarchical distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation; and hierarchical distillation learning is performed on the student model after the second distillation by using the labeled training samples to obtain the trained student model. The three distillations improve the accuracy of the distilled model, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is lowered.
An embodiment of the present application further provides a computer-readable storage medium on which a computer program is stored. When the computer program is executed by a processor, a model distillation method is implemented, including the steps of: acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained based on Bert network training; performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation; performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation; and performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain the trained student model.
In the model distillation method executed above, overall distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation; hierarchical distillation learning is performed on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation; and hierarchical distillation learning is performed on the student model after the second distillation by using the labeled training samples to obtain the trained student model. The three distillations improve the accuracy of the distilled model, and because unlabeled training samples are used in the first and second distillations, the need for labeled training samples is reduced and the cost of distillation is lowered.
所述计算机存储介质可以是非易失性,也可以是易失性。The computer storage medium can be non-volatile or volatile.
Those of ordinary skill in the art can understand that all or part of the processes in the methods of the above embodiments can be implemented by instructing relevant hardware through a computer program, and the computer program can be stored in a non-volatile computer-readable storage medium. When the computer program is executed, it may include the processes of the above method embodiments. Any reference to memory, storage, database, or other media provided in this application and used in the embodiments may include non-volatile and/or volatile memory. Non-volatile memory may include read-only memory (ROM), programmable ROM (PROM), electrically programmable ROM (EPROM), electrically erasable programmable ROM (EEPROM), or flash memory. Volatile memory may include random access memory (RAM) or external cache memory. By way of illustration and not limitation, RAM is available in various forms, such as static RAM (SRAM), dynamic RAM (DRAM), synchronous DRAM (SDRAM), double data rate SDRAM (SSRSDRAM), enhanced SDRAM (ESDRAM), synchlink DRAM (SLDRAM), Rambus direct RAM (RDRAM), direct Rambus dynamic RAM (DRDRAM), and Rambus dynamic RAM (RDRAM).
It should be noted that, herein, the terms "include", "comprise", or any other variation thereof are intended to cover non-exclusive inclusion, so that a process, apparatus, article, or method that includes a series of elements includes not only those elements but also other elements not expressly listed, or elements inherent to such a process, apparatus, article, or method. Without further limitation, an element qualified by the phrase "including a ..." does not preclude the presence of additional identical elements in the process, apparatus, article, or method that includes the element.
The above are only preferred embodiments of the present application and are not intended to limit the patent scope of the present application. Any equivalent structure or equivalent process transformation made by using the contents of the specification and drawings of the present application, or any direct or indirect application in other related technical fields, is likewise included in the patent protection scope of the present application.

Claims (20)

  1. 一种模型蒸馏方法,其中,所述方法包括:A model distillation method, wherein the method comprises:
    获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;Obtain a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, and the pre-training model is a model obtained based on Bert network training;
    采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;Using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation;
    采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;Use the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation;
    采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。Use the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
  2. The model distillation method according to claim 1, wherein the step of performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation includes:
    将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;Input the unlabeled training samples into the pre-training model for scoring prediction, and obtain the first prediction score output by the scoring prediction layer of the pre-training model;
    将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;Inputting the unlabeled training samples into the student model for scoring prediction to obtain a second prediction score;
    Inputting the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value, updating all parameters of the student model according to the first loss value, and using the student model with the updated parameters for the next calculation of the second prediction score;
    Repeating the above method steps until the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition, and determining the student model whose first loss value reaches the first convergence condition or whose number of iterations reaches the second convergence condition as the student model after the first distillation.
  3. 根据权利要求2所述的模型蒸馏方法,其中,所述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:The model distillation method according to claim 2, wherein the step of inputting the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value comprises:
    将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。The first prediction score and the second prediction score are input into the KL divergence loss function for calculation, and the first loss value is obtained.
  4. The model distillation method according to claim 1, wherein the step of performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation includes:
    将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;Input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
    将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;Inputting the unlabeled training samples into the student model after the first distillation for probability prediction to obtain a second predicted probability;
    Inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, updating the parameters of the student model after the first distillation layer by layer according to the second loss value and the first preset parameter hierarchical update rule, and using the student model after the first distillation with the updated parameters for the next calculation of the second predicted probability;
    Repeating the above method steps until the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition, and determining the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition as the student model after the second distillation.
  5. The model distillation method according to claim 4, wherein the step of inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, and updating the parameters of the student model after the first distillation according to the second loss value and the first preset parameter hierarchical update rule, includes:
    将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;Inputting the first predicted probability and the second predicted probability into the MSE loss function for calculation to obtain the second loss value;
    When the Dense layer parameters in the second loss value have not reached the first Dense layer convergence condition, updating the parameters of the Dense layer of the student model after the first distillation according to the Dense layer parameters in the second loss value; otherwise, when the BiLSTM layer parameters in the second loss value have not reached the first BiLSTM layer convergence condition, updating the parameters of the BiLSTM layer of the student model after the first distillation according to the BiLSTM layer parameters in the second loss value; otherwise, updating the parameters of the Embedding layer of the student model after the first distillation according to the Embedding layer parameters in the second loss value.
  6. The model distillation method according to claim 1, wherein the step of performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain the trained student model includes:
    将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;Inputting the labeled training sample into the student model after the second distillation for probability prediction to obtain a third prediction probability;
    Inputting the third predicted probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, updating the parameters of the student model after the second distillation layer by layer according to the third loss value and the second preset parameter hierarchical update rule, and using the student model after the second distillation with the updated parameters for the next calculation of the third predicted probability;
    Repeating the above method steps until the third loss value reaches the fifth convergence condition or the number of iterations reaches the sixth convergence condition, and determining the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition as the trained student model.
  7. The model distillation method according to claim 6, wherein the step of inputting the third predicted probability and the sample calibration value of the labeled training sample into a third loss function for calculation to obtain a third loss value, and updating the parameters of the student model after the second distillation according to the third loss value and the second preset parameter hierarchical update rule, includes:
    将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;Inputting the third predicted probability and the sample calibration value of the labeled training sample into a cross-entropy loss function for calculation to obtain the third loss value;
    When the Dense layer parameters in the third loss value have not reached the second Dense layer convergence condition, updating the parameters of the Dense layer of the student model after the second distillation according to the Dense layer parameters in the third loss value; otherwise, when the BiLSTM layer parameters in the third loss value have not reached the second BiLSTM layer convergence condition, updating the parameters of the BiLSTM layer of the student model after the second distillation according to the BiLSTM layer parameters in the third loss value; otherwise, updating the parameters of the Embedding layer of the student model after the second distillation according to the Embedding layer parameters in the third loss value.
  8. 一种模型蒸馏装置,其中,所述装置包括:A model distillation apparatus, wherein the apparatus comprises:
    数据获取模块,用于获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;a data acquisition module for acquiring a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, where the pre-training model is a model obtained through Bert network training;
    第一阶段蒸馏模块,用于采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;The first-stage distillation module is used to perform overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation;
    第二阶段蒸馏模块,用于采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;The second-stage distillation module is used to perform hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation;
    第三阶段蒸馏模块,用于采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。The third-stage distillation module is used to perform hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain a trained student model.
  9. 一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序, 其中,所述处理器执行所述计算机程序时实现如下方法步骤:A computer device, comprising a memory and a processor, wherein the memory stores a computer program, wherein the processor implements the following method steps when executing the computer program:
    获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;Obtain a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, and the pre-training model is a model obtained based on Bert network training;
    采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;Using the unlabeled training samples and the student model to perform overall distillation learning on the pre-training model to obtain the student model after the first distillation;
    采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;Use the unlabeled training samples and the student model after the first distillation to perform hierarchical distillation learning on the pre-training model to obtain the student model after the second distillation;
    采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。Use the labeled training samples to perform hierarchical distillation learning on the student model after the second distillation to obtain a trained student model.
  10. The computer device according to claim 9, wherein the step of performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation includes:
    将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;Input the unlabeled training samples into the pre-training model for scoring prediction, and obtain the first prediction score output by the scoring prediction layer of the pre-training model;
    将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;Inputting the unlabeled training samples into the student model for scoring prediction to obtain a second prediction score;
    Inputting the first prediction score and the second prediction score into a first loss function for calculation to obtain a first loss value, updating all parameters of the student model according to the first loss value, and using the student model with the updated parameters for the next calculation of the second prediction score;
    Repeating the above method steps until the first loss value reaches the first convergence condition or the number of iterations reaches the second convergence condition, and determining the student model whose first loss value reaches the first convergence condition or whose number of iterations reaches the second convergence condition as the student model after the first distillation.
  11. 根据权利要求10所述的计算机设备,其中,所述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:The computer device according to claim 10, wherein the step of inputting the first predicted score and the second predicted score into a first loss function for calculation to obtain a first loss value comprises:
    将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。The first prediction score and the second prediction score are input into the KL divergence loss function for calculation, and the first loss value is obtained.
  12. The computer device according to claim 9, wherein the step of performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation includes:
    将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;Input the unlabeled training samples into the pre-training model for probability prediction, and obtain the first prediction probability output by the probability prediction layer of the pre-training model;
    将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;Inputting the unlabeled training samples into the student model after the first distillation for probability prediction to obtain a second predicted probability;
    Inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, updating the parameters of the student model after the first distillation layer by layer according to the second loss value and the first preset parameter hierarchical update rule, and using the student model after the first distillation with the updated parameters for the next calculation of the second predicted probability;
    Repeating the above method steps until the second loss value reaches the third convergence condition or the number of iterations reaches the fourth convergence condition, and determining the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition as the student model after the second distillation.
  13. The computer device according to claim 12, wherein the step of inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, and updating the parameters of the student model after the first distillation according to the second loss value and the first preset parameter hierarchical update rule, includes:
    将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;Inputting the first predicted probability and the second predicted probability into the MSE loss function for calculation to obtain the second loss value;
    When the Dense layer parameters in the second loss value have not reached the first Dense layer convergence condition, updating the parameters of the Dense layer of the student model after the first distillation according to the Dense layer parameters in the second loss value; otherwise, when the BiLSTM layer parameters in the second loss value have not reached the first BiLSTM layer convergence condition, updating the parameters of the BiLSTM layer of the student model after the first distillation according to the BiLSTM layer parameters in the second loss value; otherwise, updating the parameters of the Embedding layer of the student model after the first distillation according to the Embedding layer parameters in the second loss value.
  14. The computer device according to claim 9, wherein the step of performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain the trained student model comprises:
    inputting the labeled training samples into the student model after the second distillation for probability prediction to obtain a third predicted probability;
    inputting the third predicted probability and the sample calibration values of the labeled training samples into a third loss function for calculation to obtain a third loss value, updating the parameters of the student model after the second distillation according to the third loss value and a second preset hierarchical parameter update rule, and using the student model after the second distillation with the updated parameters for the next calculation of the third predicted probability;
    repeating the above method steps until the third loss value reaches a fifth convergence condition or the number of iterations reaches a sixth convergence condition, and determining the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition as the trained student model.
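As a rough illustration of this labeled-sample stage, the sketch below assumes cross-entropy as the third loss function (the text does not name it), uses a plain optimizer step in place of the second preset hierarchical parameter update rule, and treats `eps` and `max_iters` as stand-ins for the fifth and sixth convergence conditions:

```python
import torch.nn.functional as F

def labeled_stage(student, optimizer, loader, eps=1e-3, max_iters=10_000):
    """Final tuning of the twice-distilled student on labeled samples."""
    step, prev = 0, None
    while step < max_iters:
        for x, y in loader:                      # labeled training samples and their labels
            logits = student(x)                  # third predicted probability (as logits)
            loss = F.cross_entropy(logits, y)    # third loss value (assumed cross-entropy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            converged = prev is not None and abs(prev - loss.item()) < eps
            if converged or step >= max_iters:   # fifth / sixth convergence conditions
                return student                   # trained student model
            prev = loss.item()
    return student
```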
  15. A computer-readable storage medium on which a computer program is stored, wherein, when the computer program is executed by a processor, the following method steps are implemented:
    obtaining a pre-training model, a student model, a plurality of labeled training samples, and a plurality of unlabeled training samples, wherein the pre-training model is a model obtained through Bert network training;
    performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model, to obtain a student model after a first distillation;
    performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation, to obtain a student model after a second distillation;
    performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples, to obtain a trained student model.
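Taken together, these steps form a three-stage pipeline. A minimal orchestration sketch, assuming helper functions `overall_distill`, `hierarchical_distill` and `labeled_stage` that carry out the three stages (signatures abbreviated here):

```python
def distill_pipeline(teacher, student, unlabeled, labeled):
    s1 = overall_distill(teacher, student, unlabeled)    # first distillation: whole student updated
    s2 = hierarchical_distill(teacher, s1, unlabeled)    # second distillation: layer-wise updates
    trained = labeled_stage(s2, labeled)                 # final tuning on labeled samples
    return trained
```

Only the final stage consumes labeled samples; the first two stages run entirely on unlabeled data.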
  16. The computer-readable storage medium according to claim 15, wherein the step of performing overall distillation learning on the pre-training model by using the unlabeled training samples and the student model to obtain the student model after the first distillation comprises:
    inputting the unlabeled training samples into the pre-training model for score prediction, and obtaining a first predicted score output by a score prediction layer of the pre-training model;
    inputting the unlabeled training samples into the student model for score prediction to obtain a second predicted score;
    inputting the first predicted score and the second predicted score into a first loss function for calculation to obtain a first loss value, updating all parameters of the student model according to the first loss value, and using the student model with the updated parameters for the next calculation of the second predicted score;
    repeating the above method steps until the first loss value reaches a first convergence condition or the number of iterations reaches a second convergence condition, and determining the student model whose first loss value reaches the first convergence condition or whose number of iterations reaches the second convergence condition as the student model after the first distillation.
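A loop of this shape is one way to realize the overall distillation step; the first loss function is passed in as `loss_fn`, and `eps` / `max_iters` are assumed stand-ins for the first and second convergence conditions:

```python
import torch

def overall_distill(teacher, student, optimizer, loader, loss_fn,
                    eps=1e-3, max_iters=10_000):
    teacher.eval()
    step, prev = 0, None
    while step < max_iters:
        for x in loader:                          # unlabeled training samples
            with torch.no_grad():
                s_teacher = teacher(x)            # first predicted score (score prediction layer)
            s_student = student(x)                # second predicted score
            loss = loss_fn(s_student, s_teacher)  # first loss value
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()                      # all student parameters are updated
            step += 1
            converged = prev is not None and abs(prev - loss.item()) < eps
            if converged or step >= max_iters:    # first / second convergence conditions
                return student                    # student model after the first distillation
            prev = loss.item()
    return student
```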
  17. The computer-readable storage medium according to claim 16, wherein the step of inputting the first predicted score and the second predicted score into the first loss function for calculation to obtain the first loss value comprises:
    inputting the first predicted score and the second predicted score into a KL divergence loss function for calculation to obtain the first loss value.
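One plausible reading of this step, treating the predicted scores as logits and adding a temperature parameter as a common distillation assumption not stated in the text:

```python
import torch.nn.functional as F

def kl_first_loss(student_scores, teacher_scores, temperature=1.0):
    # KL divergence expects log-probabilities for the input and probabilities for the target
    log_p_student = F.log_softmax(student_scores / temperature, dim=-1)
    p_teacher = F.softmax(teacher_scores / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean")
```

With `reduction="batchmean"` the returned value is the KL divergence averaged over the batch, so the loss scale does not depend on batch size.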
  18. The computer-readable storage medium according to claim 15, wherein the step of performing hierarchical distillation learning on the pre-training model by using the unlabeled training samples and the student model after the first distillation to obtain the student model after the second distillation comprises:
    inputting the unlabeled training samples into the pre-training model for probability prediction, and obtaining a first predicted probability output by a probability prediction layer of the pre-training model;
    inputting the unlabeled training samples into the student model after the first distillation for probability prediction to obtain a second predicted probability;
    inputting the first predicted probability and the second predicted probability into a second loss function for calculation to obtain a second loss value, updating the parameters of the student model after the first distillation according to the second loss value and a first preset hierarchical parameter update rule, and using the student model after the first distillation with the updated parameters for the next calculation of the second predicted probability;
    repeating the above method steps until the second loss value reaches a third convergence condition or the number of iterations reaches a fourth convergence condition, and determining the student model after the first distillation whose second loss value reaches the third convergence condition or whose number of iterations reaches the fourth convergence condition as the student model after the second distillation.
  19. The computer-readable storage medium according to claim 18, wherein the step of inputting the first predicted probability and the second predicted probability into the second loss function for calculation to obtain the second loss value, and updating the parameters of the student model after the first distillation according to the second loss value and the first preset hierarchical parameter update rule comprises:
    inputting the first predicted probability and the second predicted probability into an MSE loss function for calculation to obtain the second loss value;
    when the Dense layer parameters in the second loss value have not reached a first Dense layer convergence condition, updating the parameters of the Dense layer of the student model after the first distillation according to the Dense layer parameters in the second loss value; otherwise, when the BiLSTM layer parameters in the second loss value have not reached a first BiLSTM layer convergence condition, updating the parameters of the BiLSTM layer of the student model after the first distillation according to the BiLSTM layer parameters in the second loss value; otherwise, updating the parameters of the Embedding layer of the student model after the first distillation according to the Embedding layer parameters in the second loss value.
  20. The computer-readable storage medium according to claim 15, wherein the step of performing hierarchical distillation learning on the student model after the second distillation by using the labeled training samples to obtain the trained student model comprises:
    inputting the labeled training samples into the student model after the second distillation for probability prediction to obtain a third predicted probability;
    inputting the third predicted probability and the sample calibration values of the labeled training samples into a third loss function for calculation to obtain a third loss value, updating the parameters of the student model after the second distillation according to the third loss value and a second preset hierarchical parameter update rule, and using the student model after the second distillation with the updated parameters for the next calculation of the third predicted probability;
    repeating the above method steps until the third loss value reaches a fifth convergence condition or the number of iterations reaches a sixth convergence condition, and determining the student model after the second distillation whose third loss value reaches the fifth convergence condition or whose number of iterations reaches the sixth convergence condition as the trained student model.
PCT/CN2021/084539 2021-02-26 2021-03-31 Model distillation method and apparatus, device, and storage medium WO2022178948A1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202110220512.9A CN112836762A (en) 2021-02-26 2021-02-26 Model distillation method, device, equipment and storage medium
CN202110220512.9 2021-02-26

Publications (1)

Publication Number Publication Date
WO2022178948A1 true WO2022178948A1 (en) 2022-09-01

Family

ID=75933941

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/084539 WO2022178948A1 (en) 2021-02-26 2021-03-31 Model distillation method and apparatus, device, and storage medium

Country Status (2)

Country Link
CN (1) CN112836762A (en)
WO (1) WO2022178948A1 (en)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113177616B (en) * 2021-06-29 2021-09-17 腾讯科技(深圳)有限公司 Image classification method, device, equipment and storage medium
CN113673698B (en) * 2021-08-24 2024-05-10 平安科技(深圳)有限公司 Distillation method, device, equipment and storage medium suitable for BERT model
CN115861847B (en) * 2023-02-24 2023-05-05 耕宇牧星(北京)空间科技有限公司 Intelligent auxiliary labeling method for visible light remote sensing image target

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110059740A (en) * 2019-04-12 2019-07-26 杭州电子科技大学 A kind of deep learning semantic segmentation model compression method for embedded mobile end
CN110852426A (en) * 2019-11-19 2020-02-28 成都晓多科技有限公司 Pre-training model integration acceleration method and device based on knowledge distillation
CN111242297A (en) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 Knowledge distillation-based model training method, image processing method and device
US20200302230A1 (en) * 2019-03-21 2020-09-24 International Business Machines Corporation Method of incremental learning for object detection
WO2021002968A1 (en) * 2019-07-02 2021-01-07 Microsoft Technology Licensing, Llc Model generation based on model compression

Also Published As

Publication number Publication date
CN112836762A (en) 2021-05-25

Similar Documents

Publication Publication Date Title
WO2022178948A1 (en) Model distillation method and apparatus, device, and storage medium
US20220382564A1 (en) Aggregate features for machine learning
Wen et al. Optimized backstepping for tracking control of strict-feedback systems
CN113673698B (en) Distillation method, device, equipment and storage medium suitable for BERT model
CN111061847A (en) Dialogue generation and corpus expansion method and device, computer equipment and storage medium
WO2020151310A1 (en) Text generation method and device, computer apparatus, and medium
CN111598213B (en) Network training method, data identification method, device, equipment and medium
WO2022142043A1 (en) Course recommendation method and apparatus, device, and storage medium
Kamalapurkar et al. Concurrent learning-based approximate optimal regulation
CN112418482A (en) Cloud computing energy consumption prediction method based on time series clustering
KR20210106398A (en) Conversation-based recommending method, conversation-based recommending apparatus, and device
CN113642707A (en) Model training method, device, equipment and storage medium based on federal learning
CN116363423A (en) Knowledge distillation method, device and storage medium for small sample learning
Xu et al. Optimal regulation of uncertain dynamic systems using adaptive dynamic programming
CN115186062A (en) Multi-modal prediction method, device, equipment and storage medium
CN114416984A (en) Text classification method, device and equipment based on artificial intelligence and storage medium
Shen et al. A unified analysis of AdaGrad with weighted aggregation and momentum acceleration
CN113268564B (en) Method, device, equipment and storage medium for generating similar problems
Elahi et al. Finite-time stabilisation of discrete networked cascade control systems under transmission delay and packet dropout via static output feedback control
Mera et al. Finite-time attractive ellipsoid method: implicit Lyapunov function approach
WO2022178950A1 (en) Method and apparatus for predicting statement entity, and computer device
CN113777965B (en) Spray quality control method, spray quality control device, computer equipment and storage medium
Biagiola et al. Robust model predictive control of Wiener systems
Bhatnagar et al. Feature search in the Grassmanian in online reinforcement learning
Wang et al. Intermediate variable normalization for gradient descent learning for hierarchical fuzzy system

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21927382

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21927382

Country of ref document: EP

Kind code of ref document: A1