CN113822434A: Model selection learning for knowledge distillation (Google Patents)
Classifications
 G06N20/00—Machine learning
 G06N20/20—Ensemble learning
 G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
 G06N3/006—Artificial life based on simulated virtual individual or collective life forms, e.g. social simulations or particle swarm optimisation [PSO]
 G06N3/045—Combinations of networks
 G06N3/088—Non-supervised learning, e.g. competitive learning
Abstract
The present disclosure provides methods and apparatus for obtaining a target model based on knowledge distillation. A data set and a set of candidate reference models may be obtained. For each training sample in the data set, a set of selected reference models selected from the set of candidate reference models may be determined. A set of target probability distributions output by the set of selected reference models for the training sample may be obtained. The target model may be trained using the set of target probability distributions.
Description
Background
With the development of deep learning techniques, various deep pre-trained models are continuously proposed and achieve excellent performance in fields such as natural language processing and computer vision. For example, in the field of natural language processing, deep pre-trained models such as the Bidirectional Encoder Representations from Transformers (BERT) model and the Generative Pre-trained Transformer (GPT) model have proven to work well. Such deep pre-trained models tend to be complex models that rely on deep networks with a huge number of parameters; e.g., the BERT model may contain 340 million parameters across 24 Transformer layers, and the GPT model may contain 1.5 billion parameters across 48 Transformer layers. Training such complex models and using them for inference is time consuming, which makes them difficult to apply in practical business scenarios. Model compression methods are often employed to obtain a simple, deployable model with fewer parameters than the complex model.
Disclosure of Invention
This summary is provided to introduce a selection of concepts that are further described below in the detailed description. This summary is not intended to identify key features or essential features of the claimed subject matter, nor is it intended to be used to limit the scope of the claimed subject matter.
Embodiments of the present disclosure provide methods and apparatus for obtaining a target model based on knowledge distillation. A data set and a set of candidate reference models may be obtained. For each training sample in the data set, a set of selected reference models selected from the set of candidate reference models may be determined. A set of target probability distributions output by the set of selected reference models for the training sample may be obtained. The target model may be trained using the set of target probability distributions.
It should be noted that one or more of the above aspects include features that are specifically pointed out in the following detailed description and claims. The following description and the annexed drawings set forth in detail certain illustrative features of the one or more aspects. These features are indicative of but a few of the various ways in which the principles of various aspects may be employed and the present disclosure is intended to include all such aspects and their equivalents.
Drawings
The disclosed aspects will hereinafter be described in conjunction with the appended drawings, which are provided to illustrate, but not to limit, the disclosed aspects.
FIG. 1 illustrates an exemplary process for obtaining a target model according to an embodiment of the disclosure.
FIG. 2 illustrates an exemplary process for selecting a reference model in accordance with an embodiment of the disclosure.
FIG. 3 illustrates an exemplary process for training a target model according to an embodiment of the present disclosure.
FIG. 4 illustrates a specific example for training a target model according to an embodiment of the present disclosure.
FIG. 5 illustrates an exemplary process for updating policy parameters according to an embodiment of the present disclosure.
FIG. 6 illustrates an exemplary process for initializing policy parameters according to an embodiment of the disclosure.
FIG. 7 is a flow diagram of an exemplary method for obtaining a target model based on knowledge distillation in accordance with an embodiment of the present disclosure.
FIG. 8 illustrates an exemplary apparatus for obtaining a target model based on knowledge distillation in accordance with an embodiment of the disclosure.
FIG. 9 illustrates an exemplary apparatus for obtaining a target model based on knowledge distillation in accordance with an embodiment of the disclosure.
Detailed Description
The present disclosure will now be discussed with reference to several exemplary embodiments. It is to be understood that the discussion of these embodiments is merely intended to enable those skilled in the art to better understand and thereby practice the embodiments of the present disclosure, and does not suggest any limitation on the scope of the present disclosure.
One commonly used model compression method is based on knowledge distillation. This approach typically migrates knowledge from a complex model to a simple model by having the simple model learn the output distribution of the complex model. Knowledge can be thought of as the parameters of the complex model and the input-to-output mapping that the complex model implements. The method is based on a teacher-student architecture, wherein the model providing knowledge can be considered a teacher model and the model learning knowledge can be considered a student model. Specifically, when training a student model, training data having not only true labels, such as human-provided annotations, but also probability distributions output by a teacher model is provided to the student model. Thus, the student model can optimize its model parameters by learning the probability distributions output by the teacher model, in an attempt to match the performance of the teacher model. Current knowledge distillation methods can be classified into one-to-one, many-to-many, and many-to-one methods based on the number of teacher models and student models. In this context, a one-to-one approach refers to one teacher model providing knowledge to one student model; a many-to-many approach refers to multiple teacher models providing knowledge to multiple student models, which are combined into an ensemble of student models when applied; and a many-to-one approach refers to multiple teacher models providing knowledge to one student model. Recent research and experiments show that training the student model by a many-to-one method can improve the performance of the student model more effectively.
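For illustration only (not part of the claimed method), the basic distillation objective described above can be sketched in a few lines of Python: the student minimizes the cross-entropy between the teacher's output distribution and its own. The function names and the absence of a distillation temperature are assumptions made for this sketch:

```python
import math

def softmax(logits):
    # Convert raw scores into a probability distribution (numerically stable).
    mx = max(logits)
    exps = [math.exp(z - mx) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_probs):
    # Cross-entropy between the teacher's output distribution (soft labels)
    # and the student's predicted distribution: the student "learns the
    # probability distribution output by the teacher model".
    student_probs = softmax(student_logits)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))

# A student that roughly agrees with a confident teacher incurs a smaller
# loss than one that disagrees.
teacher = [0.8, 0.1, 0.1]
loss_agree = distillation_loss([2.0, 0.0, 0.0], teacher)
loss_disagree = distillation_loss([0.0, 2.0, 0.0], teacher)
```

In practice a temperature is often applied to soften both distributions; it is omitted here to keep the sketch minimal.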
Current many-to-one methods typically assign the same weight to each teacher model, or assign different but fixed weights that remain unchanged throughout knowledge distillation. However, even when the student model is trained using a set of training data for the same task, each teacher model performs differently on different training samples in that set. Taking two training samples for predicting semantic equivalence between two sentences as an example: for the first training sample, teacher model A may perform better than teacher model B, whereas for the second training sample, teacher model B may perform better than teacher model A. In addition, the performance of the student model gradually improves through the stages of knowledge distillation, and the teacher models used to train it should change correspondingly. For example, in the early stages of knowledge distillation, the student model is weak and may not learn well from a complex teacher model, since the complex teacher model captures finer-grained patterns in the training data, which may cause the student model to overfit parts of the training data. As training progresses, the student model becomes stronger, and training it with a teacher model whose performance is only slightly better may yield no obvious gain. Thus, assigning identical or fixed weights to the teacher models during knowledge distillation may limit the performance gains of the student model.
Embodiments of the present disclosure propose to improve the performance of a target model through an improved training process. For example, a target model may be trained using a set of reference models by knowledge distillation. Herein, the target model refers to a structurally simple and deployable model that is desired to be trained, which may also be referred to as a student model, and the reference model refers to a model having a higher complexity than the target model, which may also be referred to as a teacher model, which can be used to assist in training the target model.
In one aspect, embodiments of the present disclosure propose selecting reference models for training a target model from a set of candidate reference models through reinforcement learning. For example, different weights may be dynamically assigned to the respective reference models for each training sample in the data set used to train the target model. Herein, the weight assigned to a particular reference model may be implemented as a sampling probability corresponding to that reference model, which may be used to determine whether the reference model is selected for training the target model on the training sample.
In another aspect, embodiments of the present disclosure propose determining the sampling probabilities of the respective reference models for a current training sample by a policy function. The policy function for each reference model comprises, for example, policy parameters, as well as information about the current training sample and the performance of the reference model on the current training sample. Such a policy function may help select a reference model that performs well on the current training sample.
In another aspect, embodiments of the present disclosure propose updating policy parameters in a policy function based on the performance of an objective model. For example, a data set used to train a target model may be partitioned into a plurality of data subsets. The policy parameters in the policy function may be updated based on the performance of the trained target model after the target model is trained with a subset of data, thereby affecting the sampling probability of each reference model for the next subset of data. By updating the policy parameters in the policy function based on the performance of the target model, it may be helpful to select a reference model that matches the performance of the current target model.
In another aspect, embodiments of the present disclosure propose initializing policy parameters in the policy function before determining the sampling probabilities of the respective reference models by the policy function. For example, a set of reference models may be selected from a set of candidate reference models by a policy function, and policy parameters in the policy function may be initialized according to an average performance of the selected set of reference models.
In another aspect, embodiments of the present disclosure propose pretraining the target model before training the target model using a set of reference models. For example, a data set may be scored using all reference models in the set of reference models, and the scored data set may be utilized to pretrain the target model.
FIG. 1 illustrates an exemplary process 100 for obtaining a target model according to an embodiment of the disclosure. The target model is, for example, the target model 160 in FIG. 1, which may be a BERT model with a 3-layer or 6-layer Transformer.
A data set D for training the target model 160 may be obtained first. The data set D can be divided into a plurality of data subsets {D_1, D_2, …, D_M}, where M represents the number of data subsets. Each data subset may include a plurality of training samples. The samples used to train the target model 160 are referred to herein as training samples. Taking data subset D_j as an example, it may include m training samples {(x_1, y_1), (x_2, y_2), …, (x_m, y_m)}, for example, training sample i 102 (x_i, y_i), wherein x_i is the i-th input, and y_i is the true label for x_i, for example, a human-provided annotation.
A set of candidate reference models 110 for training the target model 160 may be obtained. A candidate reference model may be a model with a higher complexity than the target model 160, e.g., a BERT model with a 12-layer Transformer. The candidate reference model may be obtained by optimizing, e.g., fine-tuning, a pre-trained model with training data for a specific task.
A representation model 120 may also be obtained, which may be any pre-trained model capable of efficiently representing the content of x_i.
The training samples i102 may be provided as input to each reference model in the set of candidate reference models 110 and to the representation model 120 to obtain a set of state information. The set of state information may include at least information about the training sample i102 and the target probability distribution output by each reference model for the training sample i 102. The specific process of obtaining status information will be explained later in conjunction with fig. 2.
Subsequently, at 130, it may be determined, for each candidate reference model in the set of candidate reference models 110, by reinforcement learning, whether to select that candidate reference model for training the target model 160. For example, a policy function π_θ 132 may be utilized to determine whether to select the candidate reference model. The selected reference models may be combined into a set of selected reference models 140. The specific process of selecting the reference models will be explained later with reference to FIG. 2.
Next, a target probability distribution output by each reference model in the set of selected reference models 140 for the training sample i102 may be obtained to obtain a set of target probability distributions 150. The set of state information used to determine the set of selected reference models 140 may include information about the target probability distribution output by each reference model for the training sample i 102. A set of target probability distributions 150 corresponding to respective reference models in the set of selected reference models 140 may be extracted from the set of state information.
The target model 160 may be trained using the training sample i 102 and the set of target probability distributions 150. In one embodiment, the parameters of the target model 160 may be optimized after the above process is performed using a single training sample, thereby obtaining a trained target model 170. In another embodiment, the parameters of the target model 160 may be optimized after all the training samples in a data subset, e.g., data subset D_j, have been processed, to obtain the trained target model 170. In this case, the parameters of the target model 160 remain unchanged for all training samples within the same data subset. The specific process of training the target model 160 will be explained later with reference to FIGS. 3 and 4.
Subsequently, the performance of the trained target model 170 may be evaluated. Verification samples 180 may be utilized to evaluate the performance of the trained target model 170. The samples used to evaluate the performance of the target model are referred to herein as validation samples, which may be the same as or different from the training samples. The evaluated performance may be converted into a reward 190. The reward 190 may then be used to update the policy parameter θ of the policy function π_θ 132. The specific process of updating the policy parameters will be explained later in conjunction with FIG. 5.
In the case where the trained target model 170 is obtained using a single training sample, updating the policy parameters based on the performance of the trained target model 170 may affect the sampling probability of each reference model for the next training sample. In the case where all of the training samples in a data subset, e.g., data subset D_j, are used to obtain the trained target model 170, updating the policy parameters based on the performance of the trained target model 170 may affect the sampling probability of each reference model for the next data subset. In this case, the policy parameter θ of the policy function π_θ remains unchanged for all training samples within the same data subset.
The process 100 generally includes a process of training the target model and a process of updating the policy parameters, which may be performed iteratively until the performance of the target model converges. During the process of training the target model, the parameters of the target model are optimized while the policy parameters are fixed; during the process of updating the policy parameters, the policy parameters are optimized while the parameters of the target model are fixed. Take as an example utilizing a data subset D_b to obtain a trained target model and updating the policy parameters based on the performance of the trained target model. The current parameters of the target model can be denoted Θ_b^s, and the current policy parameter can be denoted θ_b. The policy parameter can first be fixed to θ_b, and after the target model is trained using the data subset D_b, the parameters of the target model are updated to Θ_{b+1}^s. Next, the parameters of the target model may be fixed to Θ_{b+1}^s, and the policy parameters may be updated to θ_{b+1} based on the performance of the target model with parameters Θ_{b+1}^s; and so on.
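For illustration only, the alternating optimization described above can be sketched as the following loop, where `train_target`, `update_policy`, and `converged` are hypothetical placeholders standing in for the procedures of FIGS. 3 through 5:

```python
def run_training(data_subsets, target_params, policy_params,
                 train_target, update_policy, converged):
    # Alternately optimize the target model (policy parameters fixed) and
    # the policy parameters (target-model parameters fixed), one data
    # subset at a time, until the target model converges.
    while not converged(target_params):
        for subset in data_subsets:
            # Fix theta_b; train the target model on subset D_b.
            target_params = train_target(target_params, policy_params, subset)
            # Fix the new target parameters; update theta_b -> theta_{b+1}.
            policy_params = update_policy(policy_params, target_params)
    return target_params, policy_params

# Tiny demonstration with toy "parameters": each call nudges a counter,
# and convergence is declared once the target counter reaches 4.
target, policy = run_training(
    data_subsets=[1, 2],
    target_params=0,
    policy_params=0,
    train_target=lambda t, p, s: t + 1,
    update_policy=lambda p, t: p + 1,
    converged=lambda t: t >= 4,
)
```

The key property the sketch captures is that exactly one of the two parameter sets changes at any given step.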
FIG. 2 illustrates an exemplary process 200 for selecting a reference model in accordance with an embodiment of the disclosure. Process 200 may correspond to step 130 in FIG. 1. Training sample i 202 (x_i, y_i) may be obtained first, which may correspond to training sample i 102 in FIG. 1. A set of candidate reference models 210 and a representation model 220 may also be obtained. The set of candidate reference models 210 may include K reference models, such as reference model 210-1, reference model 210-2, …, reference model 210-K. The set of candidate reference models 210 may correspond to the set of candidate reference models 110 in FIG. 1, and the representation model 220 may correspond to the representation model 120 in FIG. 1.
The process 200 may encode the state of each reference model for the training sample i 202 as state information, and determine whether to select that reference model based on the state information. The state information may include information about the training sample i 202 and the performance of the reference model on the training sample i 202. Taking reference model 210-k (1 ≤ k ≤ K) as an example, the state for reference model 210-k may be represented as s_jk, and the state information for state s_jk may be represented as F(s_jk). The state information F(s_jk) may be implemented as a real-valued vector, which for example comprises a concatenation of three features.
The first feature may be a vector representation of x_i in the training sample i 202 (x_i, y_i). For example, a vector representation E(x_i) ∈ R^d of x_i may be obtained by the representation model 220, where d is the hidden size.
The second feature may be the probability distribution output by the reference model 210-k for the training sample i 202. Taking the case where the training sample i 202 is a sample for a classification task, the probability distribution output by the reference model 210-k can be represented as P^{t_k}(y_i|x_i; Θ^{t_k}), wherein P^{t_k}(y_i=c|x_i; Θ^{t_k}) is the probability, output by the reference model 210-k, that x_i belongs to class c, c being an integer between 1 and C, C being the number of classes, and Θ^{t_k} being the parameters of the reference model 210-k.
The third feature may be a prediction loss corresponding to the probability distribution. In one embodiment, the prediction loss may be computed by a cross-entropy function. For example, the prediction loss L^{t_k}_i corresponding to the probability distribution P^{t_k}(y_i|x_i; Θ^{t_k}) output by the reference model 210-k for the training sample i 202 may be calculated by the following formula:

L^{t_k}_i = -Σ_{c=1}^{C} 1[y_i=c] · log P^{t_k}(y_i=c|x_i; Θ^{t_k})

wherein 1[y_i=c] is the c-th element of the one-hot vector of the true label y_i.
The probability distribution output by the reference model 210-k for the training sample i 202 and the prediction loss corresponding to that probability distribution can be considered as the performance of the reference model 210-k on the training sample i 202.
The vector representation E(x_i) output by the representation model 220, the probability distribution P^{t_k}(y_i|x_i; Θ^{t_k}) output by the reference model 210-k, and the prediction loss L^{t_k}_i corresponding to that probability distribution may be concatenated to obtain the state information 230-k, F(s_jk), for the reference model 210-k.
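For illustration, the concatenation of the three features into a single state vector can be sketched as follows; the toy hidden size, class count, and the helper name `state_features` are assumptions made for this sketch:

```python
import math

def state_features(sample_embedding, teacher_probs, true_class):
    # Concatenate (1) the representation-model embedding of x_i,
    # (2) the reference model's output distribution for x_i, and
    # (3) the reference model's cross-entropy prediction loss,
    # into one real-valued state vector F(s_jk).
    prediction_loss = -math.log(teacher_probs[true_class])
    return sample_embedding + teacher_probs + [prediction_loss]

embedding = [0.2, -0.1, 0.5]   # E(x_i), toy hidden size d = 3
probs = [0.7, 0.2, 0.1]        # reference-model distribution over C = 3 classes
state = state_features(embedding, probs, true_class=0)
# state has d + C + 1 = 7 entries; the last entry is -log(0.7).
```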
A set of state information 230 for the reference models in the set of candidate reference models 210 may be obtained by the process described above, including, for example, state information 230-1, state information 230-2, …, state information 230-K.
For each reference model, a sampling probability may then be determined from the state information through the policy function, for example by the following formula:

P_θ(a_jk|s_jk) = a_jk · σ(θ^T F(s_jk)) + (1 - a_jk) · (1 - σ(θ^T F(s_jk)))

wherein F(s_jk) is the state information, σ(·) is a sigmoid function with trainable parameters θ, and P_θ(a_jk|s_jk) is the sampling probability, which expresses the probability of selecting action value a_jk in state s_jk, a_jk ∈ {0, 1}. The action value a_jk can be sampled using P_θ(a_jk|s_jk). When a_jk is sampled to a "0" value, the reference model 210-k is not selected; when a_jk is sampled to a "1" value, the reference model 210-k is selected.
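A minimal sketch of this sampling step follows, assuming the policy gates each reference model with a sigmoid score σ(θ^T F(s_jk)); the fixed seed and toy state vectors are illustrative only:

```python
import math
import random

def sampling_probability(theta, features):
    # sigma(theta^T F(s_jk)): probability that this reference model
    # is selected (i.e., that a_jk = 1).
    score = sum(t * f for t, f in zip(theta, features))
    return 1.0 / (1.0 + math.exp(-score))

def sample_actions(theta, all_features, rng):
    # Draw a_jk in {0, 1} for each candidate reference model;
    # a "1" means the model joins the selected set.
    actions = []
    for features in all_features:
        p = sampling_probability(theta, features)
        actions.append(1 if rng.random() < p else 0)
    return actions

rng = random.Random(0)
theta = [0.5, -0.3, 0.8]
states = [[1.0, 0.2, 0.4], [-2.0, 1.0, -1.5]]   # F(s_j1), F(s_j2)
actions = sample_actions(theta, states, rng)
selected = [k for k, a in enumerate(actions) if a == 1]
```

Note that, as the description points out, `selected` may legitimately be empty when every sampling probability is low.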
A set of sampling probabilities 250 for the reference models in the set of candidate reference models 210, including, for example, sampling probability 250-1, sampling probability 250-2, …, sampling probability 250-K, and a set of action values 260, including, for example, action value 260-1, action value 260-2, …, action value 260-K, may be obtained by the above-described process.
After the set of action values 260 is determined, at 270, a set of selected reference models 280, including, for example, reference model 280-1, reference model 280-2, …, reference model 280-K' (0 ≤ K' ≤ K), may be determined based on the set of candidate reference models 210 and the set of action values 260. Each reference model in the set of selected reference models 280 is, for example, a reference model whose action value is sampled to "1".
It should be understood that while the foregoing and following discussion may refer to selecting at least one reference model to train the target model, it is also possible that no reference model is selected. For example, for some training samples, none of the reference models performs well, so the sampling probabilities of all reference models are low; the action values sampled according to these probabilities may then all be "0", which results in no reference model being selected.
After the set of selected reference models 280 is determined, a set of target probability distributions output by the set of selected reference models 280 for the training sample i 202 may be obtained. The set of state information 230 determined above includes the target probability distribution output by each reference model in the set of candidate reference models for the training sample i 202. A set of target probability distributions corresponding to each reference model in the set of selected reference models 280 may be extracted from the target probability distributions. The set of target probability distributions may be utilized to train a target model.
FIG. 3 illustrates an exemplary process 300 for training a target model according to an embodiment of the present disclosure. The process 300 may train the target model using the training sample i and a set of target probability distributions output by a selected set of reference models for the training sample i. The training sample i may include a true label.
At 310, the target model may score the training sample i to obtain a predicted probability distribution for the training sample i.
At 320, a sub-prediction loss corresponding to each target probability distribution in the set of target probability distributions can be calculated based on the prediction probability distribution and that target probability distribution, to obtain a set of sub-prediction losses. In one embodiment, the sub-prediction loss may be calculated by a cross-entropy function.
At 330, a first prediction loss corresponding to the training sample i may be calculated based on the number of reference models in the set of selected reference models and the set of sub-prediction losses. In one embodiment, the first prediction loss may be calculated by first summing the set of sub-prediction losses to obtain an intermediate prediction loss, and then dividing the intermediate prediction loss by the number of reference models in the set of selected reference models.
At 340, a second prediction loss corresponding to the training sample i may be calculated based on the prediction probability distribution and the true label in the training sample i. In one embodiment, the second prediction loss may be calculated by a cross-entropy function.
At 350, a composite prediction loss corresponding to the training sample i may be calculated based on the first prediction loss and the second prediction loss. In one embodiment, the composite prediction loss may be calculated as a weighted sum of the first prediction loss and the second prediction loss.
At 360, the target model may be optimized by minimizing the composite prediction loss.
FIG. 4 illustrates a specific example 400 for training a target model according to an embodiment of the present disclosure. In example 400, the training sample used to train the target model may be, for example, training sample 410 (x_i, y_i), wherein x_i is the input and y_i is the true label. A set of target probability distributions {P^{t_k}(y_i|x_i; Θ^{t_k})} output by a set of selected reference models 420 for the training sample 410 may be used to train the target model 430. The set of selected reference models 420 includes, for example, K' reference models, numbered reference model 420-1, reference model 420-2, …, reference model 420-K'.
The target model 430 may score the training sample 410 to obtain a predicted probability distribution P^s(y_i|x_i; Θ^s) of the training sample 410, wherein P^s(y_i=c|x_i; Θ^s) represents the probability, output by the target model 430, that x_i belongs to class c, c being an integer between 1 and C, C being the number of classes, and Θ^s being the parameters of the target model 430.
The sub-prediction loss L^{s,t_k}_i corresponding to each target probability distribution P^{t_k}(y_i|x_i; Θ^{t_k}) can then be calculated based on the predicted probability distribution P^s(y_i|x_i; Θ^s) and that target probability distribution, to obtain a set of sub-prediction losses {L^{s,t_1}_i, …, L^{s,t_{K'}}_i}. In one embodiment, the sub-prediction loss may be calculated by a cross-entropy function, as shown in the following equation:

L^{s,t_k}_i = -Σ_{c=1}^{C} P^{t_k}(y_i=c|x_i; Θ^{t_k}) · log P^s(y_i=c|x_i; Θ^s)
Next, a first prediction loss L^{kd}_i corresponding to the training sample i may be calculated based on the number K' of reference models in the set of selected reference models 420 and the set of sub-prediction losses. In one embodiment, the first prediction loss may be calculated by first summing the set of sub-prediction losses to obtain an intermediate prediction loss, and then dividing the intermediate prediction loss by the number K' of reference models in the set of selected reference models, as shown in the following equation:

L^{kd}_i = (1/K') · Σ_{k=1}^{K'} L^{s,t_k}_i
the predicted probability distribution may then be output based on the target model And the true labels y in the training samples i_{i}To calculate a second prediction loss corresponding to the training sample iIn one embodiment, the second prediction loss may be calculated by a crossentropy function, as shown in the following equation:
after obtaining the first prediction loss corresponding to the training sample iAnd a second prediction lossThereafter, a comprehensive prediction loss corresponding to the training sample i may be calculatedIn one embodiment, the aggregate prediction loss may be calculated by a weighted sum of the first prediction loss and the second prediction loss, as shown in the following equation:
where α is a hyperparameter for balancing the first prediction loss and the second prediction loss.
The target model can be optimized by minimizing the composite prediction loss L_i.
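For illustration, the single-sample composite loss of this example can be sketched as follows. The averaged-teacher cross-entropy and the true-label cross-entropy mirror the description above, while the exact α/(1−α) weighting is an assumption for this sketch:

```python
import math

def cross_entropy(target_probs, predicted_probs):
    # -sum_c q(c) * log p(c)
    return -sum(q * math.log(p) for q, p in zip(target_probs, predicted_probs))

def composite_loss(student_probs, teacher_prob_list, true_class, alpha):
    num_classes = len(student_probs)
    one_hot = [1.0 if c == true_class else 0.0 for c in range(num_classes)]
    # First prediction loss: sub-prediction losses against each of the
    # K' selected reference models, summed and divided by K'.
    kd = sum(cross_entropy(t, student_probs) for t in teacher_prob_list)
    kd /= len(teacher_prob_list)
    # Second prediction loss: cross-entropy against the true label.
    ce = cross_entropy(one_hot, student_probs)
    # Weighted sum balancing the two losses (weighting scheme assumed).
    return alpha * kd + (1.0 - alpha) * ce

student = [0.6, 0.3, 0.1]                         # P^s(y_i | x_i)
teachers = [[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]]     # K' = 2 target distributions
loss = composite_loss(student, teachers, true_class=0, alpha=0.5)
```

With `alpha=0.0` the composite loss reduces to the plain supervised cross-entropy; with `alpha=1.0` it reduces to the pure distillation term.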
The process described above in connection with FIGS. 3 and 4 trains the target model by minimizing the composite prediction loss corresponding to a single training sample. Alternatively, to improve training efficiency, the above process may be performed on all training samples in a data subset, e.g., data subset D_j, to obtain a composite prediction loss L_{D_j} corresponding to the data subset D_j. The target model can then be optimized by minimizing L_{D_j}. In this case, the parameters of the target model remain unchanged for all training samples within the same data subset. The composite prediction loss L_{D_j} corresponding to the data subset D_j may be obtained by accumulating the composite prediction losses L_i over the m training samples in D_j.
According to embodiments of the present disclosure, after the target model is trained using a training sample or a data subset, the policy function $\pi_\theta$ may be updated based on the performance of the trained target model, thereby updating the sampling probability of each reference model for the next training sample or next data subset. Fig. 5 illustrates an exemplary process 500 for updating policy parameters according to an embodiment of the disclosure. The process 500 may utilize a validation sample (x′, y′) to evaluate the performance of the trained target model. The validation sample may be the same as or different from the training sample $(x_i, y_i)$ used to train the target model. The evaluated performance may be converted into a reward, and the reward may then be used to update the policy parameters of the policy function $\pi_\theta$.
At 510, the validation sample (x′, y′) may be scored through the target model to obtain a predicted probability distribution $P(x'; \Theta^s)$ of the validation sample, where $\Theta^s$ denotes the current parameters of the target model.
At 520, a prediction loss $\mathcal{L}'_j$ corresponding to the validation sample may be calculated based on the true label y′ in the validation sample (x′, y′) and the predicted probability distribution $P(x'; \Theta^s)$. In one embodiment, the prediction loss may be calculated by a cross-entropy function, as shown in the following equation:

$$\mathcal{L}'_j = CE\big(y', P(x'; \Theta^s)\big)$$
At 530, a reward $v_j$ corresponding to the validation sample may be calculated based on the prediction loss $\mathcal{L}'_j$. In one embodiment, the reward $v_j$ may be calculated as the inverse of the prediction loss, as shown in the following equation:

$$v_j = \frac{1}{\mathcal{L}'_j}$$
at 540, the reward v may be based on_{j}To update the policy parameter theta. In one embodiment, the policy parameter θ may be updated by a standard policy gradient method, such as a MonteCarlo based policy gradient method, as shown in the following equation:
where β is the learning rate and π_{θ}(s_{jk}，a_{jk}) Is a policy function for the kth reference model.
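As an illustration, one Monte-Carlo policy-gradient step of this form can be sketched as follows. The logistic (Bernoulli) parameterization $\pi_\theta(s, 1) = \sigma(\theta \cdot s)$, the function names, and the default learning rate are assumptions made for the sketch, not part of the disclosure:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def reinforce_update(theta, states, actions, reward, beta=0.1):
    # One Monte-Carlo policy-gradient step: for each reference model k,
    # move theta along reward * grad log pi_theta(s_k, a_k).
    # Assumes a Bernoulli policy pi_theta(s, 1) = sigmoid(theta . s).
    theta = np.array(theta, dtype=float)
    for s, a in zip(states, actions):
        s = np.asarray(s, dtype=float)
        p = sigmoid(float(theta @ s))
        # For the Bernoulli/logistic policy, d/dtheta log pi = (a - p) * s.
        theta = theta + beta * reward * (a - p) * s
    return theta
```

With θ = 0 the selection probability is 0.5, so a positive reward for a "select" action (a = 1) increases the corresponding logit.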
According to an embodiment of the present disclosure, policy parameters in the policy function may be initialized before determining the sampling probability of each reference model by the policy function. For example, at least one reference model from a set of candidate reference models may be selected by a policy function, and policy parameters in the policy function may be initialized according to an average performance of the selected reference models.
Fig. 6 illustrates an exemplary process 600 for initializing policy parameters according to an embodiment of the disclosure. The process 600 may initialize the policy parameters using an initialization sample. Herein, a sample used to initialize the policy parameters is referred to as an initialization sample, which may be the same as or different from the training samples used to train the target model. The initialization sample may include an input and a true label corresponding to the input.
At 610, a selected set of reference models selected from a set of candidate reference models may be determined by a policy function for the initialization sample. The policy function may have an original policy parameter.
At 620, the initialization samples may be individually scored through the set of selected reference models to obtain a set of probability distributions for the initialization samples.
At 630, a set of sub-prediction losses corresponding to the initialization sample may be computed based on the true label in the initialization sample and the set of probability distributions. For example, a sub-prediction loss may be calculated for each probability distribution in the set of probability distributions based on the true label and that probability distribution, to obtain the set of sub-prediction losses. In one embodiment, each sub-prediction loss may be calculated by a cross-entropy function.
At 640, a prediction loss corresponding to the initialization sample may be calculated based on the number of reference models in the set of selected reference models and the set of sub-prediction losses. In one embodiment, the prediction loss may be calculated by first summing the set of sub-prediction losses to obtain an intermediate prediction loss, and then dividing the intermediate prediction loss by the number of reference models in the set of selected reference models.
At 650, a reward corresponding to the initialization sample may be calculated based on the predicted loss. In one embodiment, the reward may be calculated as the inverse of the predicted loss.
At 660, policy parameters may be initialized based on the reward. In one embodiment, the policy parameters may be initialized by updating the original policy parameters using a standard policy gradient method, as shown in equation (10) above.
According to embodiments of the present disclosure, the target model may be pretrained prior to training the target model using the set of candidate reference models. In one embodiment, a pretraining data set may be scored using all reference models in the set of candidate reference models, and the target model may be pretrained using the scored pretraining data set. The data set used to pretrain the target model is referred to herein as the pretraining data set. The process for pretraining the target model may be similar to the process for training the target model explained in connection with FIGS. 3 and 4, except that the set of target probability distributions involved are the probability distributions output by all reference models in the set of candidate reference models for the pretraining samples in the pretraining data set, rather than only by selected reference models from the set of candidate reference models.
FIG. 7 is a flow diagram of an exemplary method 700 for obtaining a target model based on knowledge distillation in accordance with an embodiment of the present disclosure.
At step 710, a data set and a set of candidate reference models may be obtained.
At step 720, a selected set of reference models selected from the set of candidate reference models may be determined for each training sample in the data set.
At step 730, a set of target probability distributions for the set of selected reference models output for the training sample may be obtained.
At step 740, the target model may be trained using the set of target probability distributions.
In one embodiment, the determining a selected set of reference models may include: determining, for each candidate reference model in the set of candidate reference models, whether to select the candidate reference model by reinforcement learning.
The determining whether to select the candidate reference model may include: determining sampling probabilities of the candidate reference models based on a policy function; sampling action values of the candidate reference model based on the sampling probabilities; and selecting the candidate reference model based on the sampled action value.
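The sampling step described above can be sketched as follows; treating each action value as an independent Bernoulli draw from the model's sampling probability is an assumption made for this illustration, and all names are hypothetical:

```python
import numpy as np

def sample_selected_models(sampling_probs, rng=None):
    # For each candidate reference model, sample a binary action value from
    # its sampling probability; action value 1 means "select this model".
    rng = rng or np.random.default_rng()
    probs = np.asarray(sampling_probs, dtype=float)
    actions = (rng.random(probs.shape) < probs).astype(int)
    selected = [k for k, a in enumerate(actions) if a == 1]
    return actions, selected
```

With degenerate probabilities the draw is deterministic: probabilities of 1.0 always select and 0.0 never select the corresponding model.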
The policy function may have policy parameters. The determining whether to select the candidate reference model may further include: updating the policy parameters based on performance of the target model.
The updating the policy parameters may include: scoring a verification sample through the target model to obtain a predicted probability distribution of the verification sample; calculating a prediction loss corresponding to the validation sample based on the true labels in the validation sample and the prediction probability distribution; calculating a reward corresponding to the validation sample based on the predicted loss; and updating the policy parameters based on the reward.
The data set may include a plurality of data subsets. The strategy parameters of the strategy function may remain unchanged for all training samples within the same data subset.
The determining of the sampling probability may be performed with respect to state information. The state information may include at least: a representation of the training sample, a probability distribution output by the candidate reference model for the training sample, and a prediction loss corresponding to the probability distribution.
The policy function may have policy parameters. The policy parameters may be initialized by: determining, by the policy function, a selected set of reference models selected from the candidate set of reference models for an initialization sample; scoring the initialization samples through the set of selected reference models, respectively, to obtain a set of probability distributions for the initialization samples; calculating a prediction loss corresponding to the initialization sample based on the true labels in the initialization sample and the set of probability distributions; calculating a reward corresponding to the initialization sample based on the predicted loss; and initializing the policy parameters based on the reward.
In one embodiment, the training samples may include true labels. The training the target model may include: scoring the training samples through the target model to obtain a predicted probability distribution of the training samples; calculating a first prediction loss corresponding to the training sample based on the predicted probability distribution and the set of target probability distributions; calculating a second prediction loss corresponding to the training sample based on the predicted probability distribution and the true label; calculating a composite prediction loss corresponding to the training sample based on the first prediction loss and the second prediction loss; and optimizing the target model by minimizing the composite prediction loss.
The calculating the first prediction loss may include: calculating, based on the predicted probability distribution and each target probability distribution in the set of target probability distributions respectively, a sub-prediction loss corresponding to that target probability distribution, to obtain a set of sub-prediction losses; and calculating the first prediction loss based on the number of reference models in the set of selected reference models and the set of sub-prediction losses.
The data set may include a plurality of data subsets. The parameters of the target model may remain unchanged for all training samples within the same data subset.
In one embodiment, the method 700 may further include: scoring a pretraining data set through the set of candidate reference models; and pretraining the target model using the scored pretraining data set.
In one embodiment, the set of candidate reference models may be models having a higher complexity than the target model.
It should be understood that the method 700 may also include any steps/processes for obtaining a target model based on knowledge distillation according to embodiments of the present disclosure described above.
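For illustration only, the interplay of the steps of method 700 (sampling a subset of reference models, distilling toward their averaged outputs, and reinforcing the policy with the inverse validation loss) can be sketched as a toy end-to-end loop. Every modeling choice below — the fixed two-class teacher distributions, the trivial logit "target model", the logistic per-model policy, and the learning rates — is an assumption made for this sketch, not part of the disclosure:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(0)

# Three fixed "teacher" output distributions over 2 classes; the true class is 0.
teacher_dists = [np.array([0.9, 0.1]), np.array([0.8, 0.2]), np.array([0.3, 0.7])]
theta = np.zeros(3)           # one policy logit per candidate reference model
student_logits = np.zeros(2)  # a deliberately trivial "target model"

for step in range(50):
    probs = 1.0 / (1.0 + np.exp(-theta))       # sampling probability per model
    actions = rng.random(3) < probs            # sample binary action values
    selected = [d for d, a in zip(teacher_dists, actions) if a]
    if not selected:                           # ensure at least one teacher
        selected = [teacher_dists[0]]
    target = np.mean(selected, axis=0)         # averaged teacher distribution
    student_probs = softmax(student_logits)
    student_logits += 0.5 * (target - student_probs)  # crude distillation step
    val_loss = -np.log(softmax(student_logits)[0] + 1e-12)  # "validation" CE
    reward = 1.0 / val_loss                    # reward = inverse of the loss
    theta += 0.05 * reward * (actions.astype(float) - probs)  # REINFORCE step
```

After a few dozen iterations the student's distribution is pulled toward the (mostly class-0) teachers, so its probability for the true class exceeds 0.5.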
Fig. 8 illustrates an exemplary apparatus 800 for obtaining a target model based on knowledge distillation in accordance with an embodiment of the disclosure. The apparatus 800 may include: an obtaining module 810 for obtaining a data set and a set of candidate reference models; a reference model determination module 820 for determining, for each training sample in the data set, a selected set of reference models selected from the candidate set of reference models; a probability distribution obtaining module 830 for obtaining a set of target probability distributions output by the set of selected reference models for the training sample; and a target model training module 840 for training the target model using the set of target probability distributions.
In one embodiment, the reference model determination module 820 may be further configured to: determining, for each candidate reference model in the set of candidate reference models, whether to select the candidate reference model by reinforcement learning.
The determining whether to select the candidate reference model may include: determining sampling probabilities of the candidate reference models based on a policy function; sampling action values of the candidate reference model based on the sampling probabilities; and selecting the candidate reference model based on the sampled action value.
The policy function may have policy parameters. The determining whether to select the candidate reference model may further include: updating the policy parameters based on performance of the target model.
The data set may include a plurality of data subsets. The strategy parameters of the strategy function may remain unchanged for all training samples within the same data subset.
The determining of the sampling probability may be performed with respect to state information. The state information may include at least: a representation of the training sample, a probability distribution output by the candidate reference model for the training sample, and a prediction loss corresponding to the probability distribution.
It should be understood that the apparatus 800 may also include any other module configured for obtaining a target model based on knowledge distillation according to embodiments of the present disclosure described above.
Fig. 9 illustrates an exemplary apparatus 900 for obtaining a target model based on knowledge distillation in accordance with an embodiment of the disclosure.
The apparatus 900 may include at least one processor 910. The apparatus 900 may also include a memory 920 coupled to the processor 910. The memory 920 may store computer-executable instructions that, when executed, cause the processor 910 to perform any operations of the method for obtaining a target model based on knowledge distillation according to embodiments of the present disclosure described above.
Embodiments of the present disclosure may be embodied in a non-transitory computer-readable medium. The non-transitory computer-readable medium may include instructions that, when executed, cause one or more processors to perform any operations of the method for obtaining a target model based on knowledge distillation according to embodiments of the present disclosure as described above.
It should be appreciated that all of the operations in the methods described above are exemplary only, and the present disclosure is not limited to any of the operations in the methods or the order of the operations, but rather should encompass all other equivalent variations under the same or similar concepts.
It should also be appreciated that all of the modules in the apparatus described above may be implemented in various ways. These modules may be implemented as hardware, software, or a combination thereof. In addition, any of these modules may be further divided functionally into submodules or combined together.
The processor has been described in connection with various apparatus and methods. These processors may be implemented using electronic hardware, computer software, or any combination thereof. Whether such processors are implemented as hardware or software depends upon the particular application and the overall design constraints imposed on the system. By way of example, a processor, any portion of a processor, or any combination of processors presented in this disclosure may be implemented with a microprocessor, a microcontroller, a Digital Signal Processor (DSP), a Field Programmable Gate Array (FPGA), a Programmable Logic Device (PLD), a state machine, gated logic units, discrete hardware circuits, and other suitable processing components configured to perform the various functions described in this disclosure. The functionality of a processor, any portion of a processor, or any combination of processors presented in this disclosure may be implemented using software executed by a microprocessor, microcontroller, DSP, or other suitable platform.
Software should be viewed broadly as meaning instructions, instruction sets, code segments, program code, programs, subroutines, software modules, applications, software packages, routines, objects, threads of execution, procedures, functions, and the like. The software may reside in a computer-readable medium. The computer-readable medium may include, for example, a memory, which may be, for example, a magnetic storage device (e.g., hard disk, floppy disk, magnetic strip), an optical disk, a smart card, a flash memory device, a Random Access Memory (RAM), a Read-Only Memory (ROM), a programmable ROM (PROM), an erasable PROM (EPROM), an electrically erasable PROM (EEPROM), a register, or a removable disk. Although the memory is shown as being separate from the processor in the aspects presented in this disclosure, the memory may also be located internal to the processor, such as a cache or a register.
The above description is provided to enable any person skilled in the art to practice the various aspects described herein. Various modifications to these aspects will be readily apparent to those skilled in the art, and the generic principles defined herein may be applied to other aspects. Thus, the claims are not intended to be limited to the aspects shown herein. All structural and functional equivalents to the elements of the various aspects described herein that are known or later come to be known to those of ordinary skill in the art are expressly incorporated herein by reference and are intended to be encompassed by the claims.
Claims (20)
1. A method for obtaining a target model based on knowledge distillation, comprising:
obtaining a data set and a set of candidate reference models;
determining, for each training sample in the data set, a selected set of reference models selected from the candidate set of reference models;
obtaining a set of target probability distributions output by the set of selected reference models for the training sample; and
training the target model using the set of target probability distributions.
2. The method of claim 1, wherein the determining a selected set of reference models comprises:
determining, for each candidate reference model in the set of candidate reference models, whether to select the candidate reference model by reinforcement learning.
3. The method of claim 2, wherein the determining whether to select the candidate reference model comprises:
determining sampling probabilities of the candidate reference models based on a policy function;
sampling action values of the candidate reference model based on the sampling probabilities; and
selecting the candidate reference model based on the sampled action values.
4. The method of claim 3, wherein the policy function has policy parameters, and the determining whether to select the candidate reference model further comprises:
updating the policy parameters based on performance of the target model.
5. The method of claim 4, wherein the updating the policy parameters comprises:
scoring a verification sample through the target model to obtain a predicted probability distribution of the verification sample;
calculating a prediction loss corresponding to the validation sample based on the true labels in the validation sample and the prediction probability distribution;
calculating a reward corresponding to the validation sample based on the predicted loss; and
updating the policy parameters based on the reward.
6. The method of claim 3, wherein the data set comprises a plurality of data subsets and the policy parameters of the policy function remain unchanged for all training samples within the same data subset.
7. The method of claim 3, wherein the determining sampling probabilities is performed for state information comprising at least: a representation of the training sample, a probability distribution output by the candidate reference model for the training sample, and a prediction loss corresponding to the probability distribution.
8. The method of claim 3, wherein the policy function has policy parameters, and the policy parameters are initialized by:
determining, by the policy function, a selected set of reference models selected from the candidate set of reference models for an initialization sample;
scoring the initialization samples through the set of selected reference models, respectively, to obtain a set of probability distributions for the initialization samples;
calculating a prediction loss corresponding to the initialization sample based on the true labels in the initialization sample and the set of probability distributions;
calculating a reward corresponding to the initialization sample based on the predicted loss; and
initializing the policy parameters based on the reward.
9. The method of claim 1, wherein the training samples include real labels and the training the target model comprises:
scoring the training samples through the target model to obtain a predicted probability distribution of the training samples;
calculating a first prediction loss corresponding to the training sample based on the prediction probability distribution and the set of target probability distributions;
calculating a second prediction loss corresponding to the training sample based on the prediction probability distribution and the true label;
calculating a composite prediction loss corresponding to the training samples based on the first prediction loss and the second prediction loss; and
optimizing the target model by minimizing the composite prediction loss.
10. The method of claim 9, wherein the calculating a first predicted loss comprises:
calculating sub-prediction losses corresponding to the target probability distributions based on the predicted probability distribution and each target probability distribution in the set of target probability distributions, respectively, to obtain a set of sub-prediction losses; and
calculating the first prediction loss based on the number of reference models in the set of selected reference models and the set of sub-prediction losses.
11. The method of claim 9, wherein the data set comprises a plurality of data subsets and the parameters of the target model remain unchanged for all training samples within the same data subset.
12. The method of claim 1, further comprising:
scoring a pretraining data set through the set of candidate reference models; and
pretraining the target model using the scored pretraining data set.
13. The method of claim 1, wherein the set of candidate reference models are models having a higher complexity than the target model.
14. An apparatus for obtaining a target model based on knowledge distillation, comprising:
an obtaining module for obtaining a data set and a set of candidate reference models;
a reference model determination module for determining, for each training sample in the data set, a selected set of reference models selected from the candidate set of reference models;
a probability distribution obtaining module, configured to obtain a set of target probability distributions output by the selected reference models for the training samples; and
a target model training module to train the target model using the set of target probability distributions.
15. The apparatus of claim 14, wherein the reference model determination module is further configured to:
determining, for each candidate reference model in the set of candidate reference models, whether to select the candidate reference model by reinforcement learning.
16. The apparatus of claim 15, wherein the determination of whether to select the candidate reference model comprises:
determining sampling probabilities of the candidate reference models based on a policy function;
sampling action values of the candidate reference model based on the sampling probabilities; and
selecting the candidate reference model based on the sampled action values.
17. The apparatus of claim 16, wherein the policy function has policy parameters, and the determining whether to select the candidate reference model further comprises:
updating the policy parameters based on performance of the target model.
18. The apparatus of claim 16, wherein the data set comprises a plurality of data subsets and the policy parameters of the policy function remain unchanged for all training samples within the same data subset.
19. The apparatus of claim 16, wherein the determining sampling probabilities is performed for state information comprising at least: a representation of the training sample, a probability distribution output by the candidate reference model for the training sample, and a prediction loss corresponding to the probability distribution.
20. An apparatus for obtaining a target model based on knowledge distillation, comprising:
at least one processor; and
a memory storing computer-executable instructions that, when executed, cause the at least one processor to:
obtain a data set and a set of candidate reference models,
determine, for each training sample in the data set, a selected set of reference models selected from the candidate set of reference models,
obtain a set of target probability distributions output by the set of selected reference models for the training sample, and
train the target model using the set of target probability distributions.
Priority Applications (2)
Application Number  Priority Date  Filing Date  Title
CN202010561319.7A (CN113822434A)  2020-06-18  2020-06-18  Model selection learning for knowledge distillation
PCT/US2021/026288 (WO2021257160A1)  2020-06-18  2021-04-08  Model selection learning for knowledge distillation
Publications (1)
Publication Number  Publication Date
CN113822434A  2021-12-21
Family
ID=75690703
Families Citing this family (1)
Publication number  Priority date  Publication date  Assignee  Title
CN115082920B  2022-08-16  2022-11-04  北京百度网讯科技有限公司  Deep learning model training method, image processing method and device

Also Published As
Publication number  Publication date
WO2021257160A1  2021-12-23
Legal Events
Date  Code  Title  Description
PB01  Publication
SE01  Entry into force of request for substantive examination