WO2022105173A1 - Model distillation method and apparatus, and storage medium and device - Google Patents

Model distillation method and apparatus, and storage medium and device

Info

Publication number
WO2022105173A1
WO2022105173A1 (PCT/CN2021/096649)
Authority
WO
WIPO (PCT)
Prior art keywords
teacher
student
identification
preset
model
Prior art date
Application number
PCT/CN2021/096649
Other languages
French (fr)
Chinese (zh)
Inventor
吴天博 (Wu Tianbo)
王健宗 (Wang Jianzong)
程宁 (Cheng Ning)
Original Assignee
平安科技(深圳)有限公司 (Ping An Technology (Shenzhen) Co., Ltd.)
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 平安科技(深圳)有限公司 (Ping An Technology (Shenzhen) Co., Ltd.)
Publication of WO2022105173A1 publication Critical patent/WO2022105173A1/en

Links

Images

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 3/00 Computing arrangements based on biological models
    • G06N 3/02 Neural networks
    • G06N 3/08 Learning methods
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 5/00 Computing arrangements using knowledge-based models
    • G06N 5/02 Knowledge representation; Symbolic representation

Definitions

  • The present application relates to the technical field of artificial intelligence, and in particular to a model distillation method and apparatus, and a storage medium and computer device.
  • Model distillation refers to using a teacher model with high accuracy but a complex structure to guide the training of a student model with low accuracy but a simple structure, so as to improve the accuracy of the student model.
  • The inventors realized that although the student model can learn from the teacher model and thereby improve its accuracy, differences remain between the teacher model and the student model in existing distillation architectures, resulting in poor expressiveness and low accuracy of the student model.
  • The technical problem to be solved by the embodiments of the present application is to provide a model distillation method, apparatus, storage medium and device that can improve the accuracy and data processing capability of the student model.
  • the embodiments of the present application provide a method for model distillation, including:
  • the training sample data is identified by the preset student model and the preset teacher model respectively, and the teacher identification result and the student identification result of the training sample data are obtained, wherein the preset student model is obtained through training guided by the preset teacher model;
  • A model distillation apparatus, comprising:
  • a first acquisition module, used for acquiring training sample data for training a preset student model;
  • an identification module, configured to identify the training sample data by using the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;
  • a second acquisition module, configured to acquire, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and
  • an adjustment module, used to calculate the logarithm between the teacher identification result and the student identification result, perform a weighted operation on the logarithm using the weight parameter, and use the calculated value as a loss value to adjust the preset student model.
  • One aspect of the present application provides a computer device, including: a processor and a memory;
  • the above-mentioned memory is used to store a computer program;
  • the above-mentioned processor is used to call the above-mentioned computer program to perform the following steps:
  • the training sample data is identified by the preset student model and the preset teacher model respectively, and the teacher identification result and the student identification result of the training sample data are obtained, wherein the preset student model is obtained through training guided by the preset teacher model;
  • an embodiment of the present application provides a computer-readable storage medium.
  • the computer-readable storage medium stores a computer program, and the computer program includes program instructions. When executed by a processor, the program instructions perform the following steps:
  • the training sample data is identified by the preset student model and the preset teacher model respectively, and the teacher identification result and the student identification result of the training sample data are obtained, wherein the preset student model is obtained through training guided by the preset teacher model;
  • According to the weight parameters, the embodiments of the present application can reasonably allocate, across the different results in the identification results of the training sample data, the adjustment weights used to adjust the student model, which can make the adjusted student model more accurate.
  • the student model can have the data processing capability of the teacher model, and the accuracy of the student model can be improved.
  • FIG. 1 is a schematic flowchart of a model distillation method provided by an embodiment of the present application;
  • FIG. 2 is a schematic diagram of a method for calculating the average value of multiple teacher identification results in each teacher identification group provided by an embodiment of the present application;
  • FIG. 3 is a schematic diagram of a method for obtaining a loss value of a preset student model provided by an embodiment of the present application;
  • FIG. 4 is a schematic diagram of model distillation provided by an embodiment of the present application;
  • FIG. 5 is a schematic flowchart of another model distillation method provided by an embodiment of the present application;
  • FIG. 6 is a schematic structural diagram of a model distillation apparatus provided by an embodiment of the present application;
  • FIG. 7 is a schematic structural diagram of a computer device provided by an embodiment of the present application.
  • the technical solutions of the present application relate to the technical field of artificial intelligence and/or big data.
  • The data involved in this application, such as sample data, identification results and/or loss values, may be stored in a database, or may be stored in a blockchain, for example through distributed storage on a blockchain, which is not limited in this application.
  • FIG. 1 is a schematic flowchart of a model distillation method provided by an embodiment of the present application.
  • The method can be performed by a computer device, which can be a terminal or a server. The terminal can include, but is not limited to, a smart phone, a tablet computer, a notebook computer, a desktop computer, a smart speaker, and a smart watch.
  • The server can be an independent physical server, a server cluster or distributed system composed of multiple physical servers, or a cloud server providing basic cloud computing services such as cloud services, cloud databases, cloud computing, cloud functions, cloud storage, network services, cloud communications, middleware services, domain name services, security services, Content Delivery Network (CDN), big data, and artificial intelligence platforms.
  • the model distillation method may include steps S101-S105.
  • S101 Acquire training sample data for training a preset student model.
  • the preset student model refers to the student model in model distillation.
  • Model distillation (knowledge distillation) refers to using the teacher model to guide the training of the student model, so as to improve the accuracy of the student model.
  • the training sample data may refer to text data, image data, and the like.
  • a preset teacher model is used to identify the training sample data to obtain a teacher identification result of the training sample data
  • a preset student model is used to identify the training sample data to obtain a student identification result of the training sample data.
  • the preset student model is guided and trained by the preset teacher model.
  • The teacher model has high accuracy but high computational complexity, which makes it unsuitable for deployment on terminal equipment, while the computation of the student model is relatively simple and meets the requirements of terminal equipment, but its accuracy is insufficient. Model distillation can therefore be used to solve this problem: the preset student model is trained under the guidance of the preset teacher model, so as to improve the accuracy of the preset student model.
  • The data processing type of the teacher model and the student model is the same; the network in the teacher model is deeper or wider and its resolution is larger, that is, the data processing capability of the teacher model is higher than that of the student model.
  • The selection can be made according to the similarity between the teacher model and the student model.
  • The more similar the structures of the teacher model and the student model, the smaller the difference in accuracy between the two after distillation. Therefore, the teacher model and the student model can be selected from the same family of models, such as the resnet network series; a resnet network is a residual network whose width and depth can be easily adjusted to obtain networks with different expressive abilities, as in the sketch below.
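  • As an illustrative pairing of same-family models, the following sketch assumes the torchvision library is available; resnet50 as teacher, resnet18 as student, and the num_classes value are example choices, not requirements of this application.

```python
# Illustrative only: pair a deeper teacher with a shallower student from the
# same residual-network family (assumes the torchvision API is available).
import torch.nn as nn
from torchvision import models

def build_teacher_student(num_classes: int = 10):
    teacher = models.resnet50(weights=None)  # deeper/wider: higher capacity
    student = models.resnet18(weights=None)  # shallower: cheap to deploy
    # Replace the classification heads so both models share one output space.
    teacher.fc = nn.Linear(teacher.fc.in_features, num_classes)
    student.fc = nn.Linear(student.fc.in_features, num_classes)
    return teacher, student
```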
  • The preset student model is adjusted so that the knowledge of the preset teacher model is transferred to the preset student model, enabling the preset student model to acquire the data processing capability and accuracy of the preset teacher model.
  • The teacher recognition result can be used to generate a weight parameter for adjusting the recognition result of the preset student model. The weight parameter determines the weight of the recognition result of the preset student model when the loss value is generated; that is, the larger the weight parameter, the greater the weight of the recognition result of the preset student model in generating the loss value.
  • Specifically, a balance parameter used to balance the teacher recognition results can be obtained.
  • The obtained multiple teacher identification results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence, wherein each teacher identification group contains the same number of teacher identification results. The average value of the multiple teacher recognition results in each teacher recognition group is calculated respectively, and the obtained multiple average values are used as the weight parameters after the balancing processing.
  • The teacher identification result comprises a plurality of teacher identification results, and each teacher identification result represents an identification probability, that is, the probability obtained by the preset teacher model identifying the training sample data.
  • The balance parameter used to balance the teacher identification results can be obtained.
  • The balance parameter can be a positive integer greater than or equal to 1 and is used to balance abnormal results among the teacher identification results.
  • Abnormal results are results that are far larger or far smaller than the normal results, that is, abnormal compared with the other results.
  • Specifically, the multiple teacher recognition results are sorted to obtain the sorted multiple teacher recognition results, and the sorted multiple teacher recognition results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence.
  • The number of teacher identification results included in each of the multiple teacher identification groups is the same. The average value of the multiple teacher recognition results in each teacher recognition group is then calculated, and the recognition result of the preset student model can be adjusted by using the weight parameters after the balancing processing. The larger the average value of the multiple teacher recognition results in a teacher recognition group, the larger the corresponding balanced weight parameter; since a teacher recognition result represents a recognition probability, a larger average recognition probability in the teacher recognition group yields a larger balanced weight parameter.
  • In this way, the loss value of the preset student model calculated according to the weight parameters is more accurate, and the accuracy of the preset student model can be improved.
  • The obtained multiple teacher identification results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence, and the average value of the multiple teacher identification results in each teacher identification group is calculated; that is, the balance parameter is used to balance each teacher identification result.
  • This balances the abnormal results among the teacher identification results, reduces the error generated when the weight parameters are produced, and improves the accuracy of the preset student model.
  • The number of teacher identification results among the multiple teacher identification results can be obtained, and the preset threshold range to which this number belongs is determined.
  • The teacher identification result comprises a plurality of teacher identification results, each representing a recognition probability, that is, the probability obtained by the preset teacher model recognizing the training sample data.
  • After the preset threshold range to which the number of teacher identification results belongs is determined, the target balance parameter corresponding to that preset threshold range may be determined from a balance parameter library; the balance parameter library includes at least one balance parameter and the corresponding relationship between each balance parameter in the at least one balance parameter and a preset threshold range.
  • A balance parameter library is preset; the balance parameter library includes at least one balance parameter and the corresponding relationship between each balance parameter in the at least one balance parameter and a preset threshold range. For example, the first threshold range corresponds to the first balance parameter, the second threshold range corresponds to the second balance parameter, and so on.
  • Optionally, the balance parameter may also be determined according to the number of student identification results among the multiple student identification results:
  • determine the preset threshold range to which the number of student identification results belongs;
  • determine, from the balance parameter library, the target balance parameter corresponding to the preset threshold range to which the number of student identification results belongs; and
  • use that target balance parameter as the balance parameter for balancing the teacher identification results.
  • The balance parameter library used here likewise includes at least one balance parameter and the corresponding relationship between each balance parameter in the at least one balance parameter and a preset threshold range.
  • In other words, the balance parameter C can be determined according to the number of teacher recognition results, the number of student recognition results, or other specific circumstances, so as to balance abnormal results in the teacher identification results or the student identification results; a minimal lookup sketch follows.
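  • The threshold ranges and C values below are invented placeholders, since the application only specifies that preset threshold ranges map to balance parameters.

```python
# Hypothetical balance-parameter library: each preset threshold range
# (lower bound inclusive, upper bound exclusive) for the number of
# identification results maps to a balance parameter C. Values are placeholders.
BALANCE_PARAMETER_LIBRARY = [
    ((1, 10), 1),
    ((10, 100), 3),    # e.g. the 11-result example of FIG. 2 falls here
    ((100, 1000), 5),
]

def lookup_balance_parameter(num_results: int) -> int:
    for (low, high), c in BALANCE_PARAMETER_LIBRARY:
        if low <= num_results < high:
            return c
    return 1  # fallback: leave the results unbalanced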
  • FIG. 2 is a schematic diagram of a method for calculating the average value of multiple teacher identification results in each teacher identification group provided by an embodiment of the present application. As shown in FIG. 2, taking a distribution of 11 teacher recognition results as an example, the 11 teacher recognition results are sorted according to the preset recognition order of the teacher model to obtain the sorted teacher recognition results x1, x2, ..., x11.
  • The 11 sorted teacher identification results can then be grouped according to a balance parameter of 3 to obtain multiple teacher identification groups arranged in sequence, namely [x1, x2], [x1, x2, x3], [x2, x3, x4], [x3, x4, x5], ..., [x9, x10, x11], [x10, x11].
  • the teacher recognition group corresponding to the teacher recognition result x1 can be [x1, x2] or [x1, x2, x3].
  • the teacher recognition group corresponding to the teacher recognition result x11 can be [x10, x11] or [x9, x10, x11].
  • For example, in a distribution of teacher recognition results for the training sample data, if the third teacher recognition result is 0.0001, it becomes 0.0234 after the balancing processing; the abnormal teacher recognition result is thereby balanced, that is, the abnormality is eliminated.
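  • A minimal sketch of this grouping-and-averaging step, assuming the group of each sorted result is a window of at most C neighboring results truncated at the boundaries (consistent with the [x1, x2] and [x10, x11] groups of FIG. 2); the group means serve as the balanced values.

```python
from typing import List

def balance_results(results: List[float], c: int = 3) -> List[float]:
    """Group each sorted result with up to c neighbors and average each group.

    Mirrors FIG. 2 for c = 3: the group of x1 is [x1, x2] (truncated at the
    boundary) and the group of x6 is [x5, x6, x7].
    """
    half = c // 2
    balanced = []
    for i in range(len(results)):
        group = results[max(0, i - half): i + half + 1]  # truncated window
        balanced.append(sum(group) / len(group))         # group mean
    return balanced

# An outlier of 0.0001 among ordinary values is pulled toward its neighbors:
print(balance_results([0.03, 0.04, 0.0001, 0.05, 0.02], c=3))
```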
  • S104 Calculate the logarithm between the teacher identification result and the student identification result, and perform a weighting operation on the logarithm using a weight parameter, and use the calculated value as a loss value to adjust the preset student model.
  • the preset student model is adjusted according to the loss value of the preset student model to obtain an adjusted student model, and the adjusted student model is used as the target student model.
  • The target student model is used to identify the data to be processed, and the recognition result obtained by the target student model identifying the data to be processed matches the recognition result obtained by the preset teacher model identifying the data to be processed, that is, the target student model has the data processing capability of the preset teacher model.
  • When calculating the logarithm between the teacher identification result and the student identification result and weighting the logarithm with the weight parameter, the following formula (1) may be used:

$$D_{KL}(P\|Q)=\sum_{x\in X}P(x)\log\frac{P(x)}{Q(x)} \tag{1}$$

  • D_KL(P‖Q) refers to the loss value of the preset student model;
  • P(x) refers to the teacher recognition result of the training sample data, that is, the result obtained by the teacher model recognizing the training sample data;
  • Q(x) refers to the student recognition result of the training sample data, that is, the result obtained by the student model predicting the training sample data; log(P(x)/Q(x)) refers to the logarithm between the teacher recognition result and the student recognition result.
  • The weight parameter refers to P(x); x refers to any one of the teacher recognition results or student recognition results, and X refers to the set of teacher recognition results and student recognition results.
  • Formula (1) is known as the KL-divergence (relative entropy).
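  • A direct transcription of formula (1), assuming P and Q are probability distributions over the same ordered result set:

```python
import math

def kl_divergence(p: list, q: list) -> float:
    """Formula (1): D_KL(P||Q) = sum_x P(x) * log(P(x) / Q(x)).

    P(x) plays the role of the weight parameter on log(P(x) / Q(x)).
    """
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

print(kl_divergence([0.5, 0.5], [0.5, 0.5]))  # identical distributions: 0.0
print(kl_divergence([0.9, 0.1], [0.5, 0.5]))  # diverging distributions: > 0
```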
  • FIG. 3 is a schematic diagram of a method for obtaining the loss value of the preset student model provided by an embodiment of the present application. As shown in FIG. 3, the method for obtaining the loss value of the preset student model includes steps S21-S23.
  • Each student identification group in the multiple student identification groups contains the same number of student identification results;
  • each teacher identification group corresponds one-to-one to a student identification group according to the identification order.
  • the student identification result is a plurality of student identification results, and the student identification result represents the identification probability, which refers to the identification probability obtained by identifying the training sample data by the preset student model.
  • The preset recognition sequence of the teacher model may be the same as the preset recognition sequence of the student model, that is, the two models recognize the training sample data in the same arrangement order.
  • The obtained multiple student identification results are grouped according to the above-mentioned balance parameter to obtain multiple student identification groups arranged in sequence, wherein each student identification group in the multiple student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one to a student identification group according to the identification sequence.
  • The average value of the multiple student identification results in each student identification group is then calculated; that is, the balance parameter is used to balance each student identification result, and the balanced student identification results are obtained. In this way, the abnormal results among the student identification results can be balanced and the occurrence of extreme values greatly reduced, thereby improving the accuracy of the target student model and enabling the student model to acquire the data processing capability of the teacher model.
  • The following formulas (2), (3) and (4) can be used to perform the weighting operation with the balanced weight parameters on the logarithm between the teacher identification result and the student identification result:

$$P'(x)=\frac{1}{z}\sum_{e\in G_T(x)}P(e) \tag{2}$$

$$Q'(x)=\frac{1}{z}\sum_{e\in G_S(x)}Q(e) \tag{3}$$

$$D_{DKL}(P\|Q)_C=\sum_{x\in X}P'(x)\log\frac{P'(x)}{Q'(x)} \tag{4}$$

  • z in formulas (2) and (3) represents the number of teacher recognition results in a teacher recognition group or the number of student recognition results in a student recognition group;
  • e represents a result within the teacher recognition group G_T(x) or the student recognition group G_S(x) to which x belongs.
  • D_DKL(P‖Q)_C in formula (4) refers to the loss value;
  • P'(x) refers to the average value of the teacher recognition group, that is, the average value obtained by balancing the teacher recognition results;
  • Q'(x) refers to the average value of the student recognition group, that is, the average value obtained by balancing the student recognition results.
  • As in formula (1), the weight parameter may refer to P(x), where x refers to any one of the teacher identification results and student identification results, and X refers to the set of teacher identification results and student identification results.
  • When P(x) and Q(x) are identical, D_DKL(P‖Q)_C is 0.
  • The present application introduces a balance parameter into the KL-divergence (relative entropy) function; the KL-divergence loss function with the balance parameter introduced is referred to as the DKL-divergence.
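  • A sketch of the DKL-divergence of formulas (2) to (4) under the same windowed-grouping assumption as the earlier sketch: both result sequences are balanced with the same parameter C, and the balanced teacher means P'(x) weight the logarithms.

```python
import math
from typing import List

def group_means(results: List[float], c: int = 3) -> List[float]:
    # Formulas (2)/(3): mean over a window of at most c neighboring results.
    half = c // 2
    groups = [results[max(0, i - half): i + half + 1]
              for i in range(len(results))]
    return [sum(g) / len(g) for g in groups]

def dkl_divergence(p: List[float], q: List[float], c: int = 3) -> float:
    """Formula (4): sum_x P'(x) * log(P'(x) / Q'(x)) over the group means."""
    p_bal, q_bal = group_means(p, c), group_means(q, c)
    return sum(pb * math.log(pb / qb)
               for pb, qb in zip(p_bal, q_bal) if pb > 0)
```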
  • In some embodiments, the preset teacher model includes multiple teacher distillation layers and the preset student model includes multiple student distillation layers.
  • When the logarithm between the teacher identification result and the student identification result is calculated and weighted using the weight parameter, the teacher distillation layer corresponding to each student distillation layer in the multiple student distillation layers can be determined.
  • The logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer is calculated and weighted to obtain the loss value of each student distillation layer, and the loss value of each student distillation layer is used to adjust the corresponding student distillation layer in the preset student model.
  • Specifically, the preset teacher model includes multiple teacher distillation layers, and each teacher distillation layer outputs a corresponding teacher identification result.
  • Likewise, the preset student model includes multiple student distillation layers, and each student distillation layer outputs a corresponding student identification result, so that knowledge distillation can be performed between corresponding distillation layers of the preset teacher model and the preset student model.
  • The teacher distillation layer corresponding to each student distillation layer in the multiple student distillation layers is determined, and the weight parameter is used to weight the logarithm between the student identification result output by each student distillation layer and the teacher identification result output by the corresponding teacher distillation layer, to obtain the loss value corresponding to each student distillation layer.
  • The loss value of each student distillation layer is then used to adjust the corresponding student distillation layer in the preset student model, so that each student distillation layer in the preset student model can be adjusted more accurately and the accuracy of the student model improved.
  • The distillation model generally includes three parts of distillation, namely Transformer-layer (conversion layer) distillation, Embedding-layer (embedding layer) distillation, and Prediction-layer (prediction layer) distillation; knowledge distillation can be performed respectively between these three distillation layers in the teacher model and the corresponding three distillation layers in the student model.
  • That is, knowledge distillation is performed on the Transformer-layer in the student model according to the Transformer-layer in the teacher model, the loss value of the Transformer-layer is calculated, and the Transformer-layer is adjusted according to that loss value.
  • Knowledge distillation is performed on the Embedding-layer in the student model according to the Embedding-layer in the teacher model, the loss value of the Embedding-layer is calculated, and the loss value of the Embedding-layer is used to adjust the Embedding-layer; the Prediction-layer is handled in the same way.
  • Assuming that the student model (new model) has M Transformer (transformation) layers and the teacher model (original model) has N Transformer layers, the m-th layer of the student model obtains information from the n-th layer of the teacher model.
  • The following formula (5) can be used to represent the distillation loss of knowledge transfer from teacher to student:

$$L_{model}=\sum_{m}\lambda_m L_{layer}(S_m,T_n) \tag{5}$$

  • L_layer in formula (5) represents the loss function of the specified layer, where the specified layer may refer to the Transformer-layer (conversion layer), the Embedding-layer (embedding layer), or the Prediction-layer (prediction layer); S_m and T_n denote the m-th student layer and the n-th teacher layer it learns from.
  • λ_m is a hyperparameter representing the loss weight of each layer.
  • L_model represents the sum of the knowledge distillation losses over all layers.
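  • A sketch of formula (5): layer_losses holds the per-layer distillation losses L_layer already computed for the Transformer, embedding and prediction layers, and lambdas holds the λ_m hyperparameters; the uniform layer mapping g(m) is one common assumption and is not fixed by this application.

```python
from typing import List

def total_distillation_loss(layer_losses: List[float],
                            lambdas: List[float]) -> float:
    """Formula (5): L_model = sum over m of lambda_m * L_layer_m."""
    assert len(layer_losses) == len(lambdas)
    return sum(lam * loss for lam, loss in zip(lambdas, layer_losses))

def uniform_layer_map(m: int, M: int, N: int) -> int:
    """One common choice of n = g(m): student layer m reads teacher layer m*N/M."""
    return m * N // M
```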
  • A corresponding loss function can be set for the Transformer-layer (conversion layer) distillation, and adjustments can be made according to the loss value obtained by the corresponding loss function.
  • the Transformer-layer distillation in the model distillation includes self-attention-based distillation and hidden state-based distillation.
  • In the related art, the objective function of the self-attention matrix distillation of the Transformer-layer (conversion layer) is the following formula (6):

$$L_{attn}=\frac{1}{h}\sum_{i=1}^{h}\mathrm{MSE}(A_i^S,A_i^T) \tag{6}$$

  • h in formula (6) is the number of attention heads, i represents the i-th attention head, A_i^S and A_i^T denote the attention matrices of the student model and the teacher model respectively, and MSE refers to the mean squared error loss.
  • The objective function of the hidden-state-based distillation in the related art is the following formula (7):

$$L_{hidn}=\mathrm{MSE}(H^S W_h,H^T) \tag{7}$$

  • H^S ∈ R^{l×d'} and H^T ∈ R^{l×d} in formula (7) refer to the hidden state matrices of the student and the teacher respectively, and R^{l×d'} and R^{l×d} represent the sizes of the student's and teacher's hidden state matrix spaces respectively; l and d represent the length of the training sample data (i.e., the length of the input sentence) and the size of the hidden layer respectively.
  • W_h ∈ R^{d'×d} is a learnable linear transformation matrix that transforms the student's hidden state matrix into the same result space size as the teacher's.
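  • A numpy sketch of formulas (6) and (7); a_s and a_t are per-head attention matrices, h_s and h_t are hidden-state matrices, and w_h stands in for the learnable projection W_h (here an ordinary array for illustration).

```python
import numpy as np

def mse(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.mean((a - b) ** 2))

def attention_loss(a_s: list, a_t: list) -> float:
    """Formula (6): MSE between student and teacher attention, averaged over heads."""
    return sum(mse(s, t) for s, t in zip(a_s, a_t)) / len(a_s)

def hidden_loss(h_s: np.ndarray, h_t: np.ndarray, w_h: np.ndarray) -> float:
    """Formula (7): MSE(H_S @ W_h, H_T), with W_h projecting d' to d."""
    return mse(h_s @ w_h, h_t)
```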
  • The loss function of the Embedding-layer (embedding layer) in the related art is the following formula (8):

$$L_{embd}=\mathrm{MSE}(E^S W_e,E^T) \tag{8}$$

  • E^S and E^T in formula (8) refer to the embedding matrices of the student and teacher models respectively, and W_e is a linear transformation matrix similar to W_h.
  • The output layer of the Prediction-layer distillation adopts the soft cross-entropy loss, as in the following formula (9):

$$L_{pred}=-\mathrm{softmax}(z^T)\cdot\log\_\mathrm{softmax}(z^S/t) \tag{9}$$

  • z^S and z^T in formula (9) are the student recognition result (logits) of the preset student model and the teacher recognition result (logits) of the preset teacher model respectively, log_softmax() represents the log-likelihood, and t refers to the distillation temperature.
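  • A sketch of formula (9): the teacher's soft targets weight the student's log-softmax taken at distillation temperature t.

```python
import numpy as np

def log_softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()  # shift for numerical stability
    return z - np.log(np.exp(z).sum())

def prediction_loss(z_s: np.ndarray, z_t: np.ndarray, t: float = 1.0) -> float:
    """Formula (9): L_pred = -softmax(z_T) . log_softmax(z_S / t)."""
    teacher_probs = np.exp(log_softmax(z_t))  # teacher soft targets
    return float(-(teacher_probs * log_softmax(z_s / t)).sum())
```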
  • The MSE in formula (10), whose full name is Mean Squared Error (mean square error), is generally used to measure the deviation between a model's predicted value and the true value. Assuming the result distribution of the true value is y, the result distribution of the predicted value is ŷ, and the sample space size is n, the difference between the two distributions can be expressed as the following formula (11):

$$\mathrm{MSE}=\frac{1}{n}\sum_{i=1}^{n}(y_i-\hat{y}_i)^2 \tag{11}$$

  • In the related art, MSE is used to calculate the loss value of the student model, but as can be seen from formula (11), the MSE in the related art pays attention to all positions in the result distribution indiscriminately, and the obtained loss value does not reflect well the difference between the student model and the teacher model, so the preset student model cannot be adjusted accurately.
  • In this solution, the above-mentioned DKL-divergence is used to calculate the loss value of the student model.
  • Taking the loss value corresponding to the Transformer (transformation) layer as an example, for the knowledge distillation of each pair of corresponding layers, the attention matrices are A^T and A^S, where T and S refer to the teacher and the student respectively. For the teacher identification results and student identification results corresponding to the training sample data, after the balance parameter C is determined, the attention matrices A^T and A^S can be sampled to obtain the sub-distributions p^T and q^S corresponding to the teacher identification group and the student identification group respectively; the distributions obtained after averaging, p̃^T and q̃^S, are given by the following formulas (12) and (13):

$$\tilde{p}^T(x)=\frac{1}{z}\sum_{e\in G_T(x)}p^T(e) \tag{12}$$

$$\tilde{q}^S(x)=\frac{1}{z}\sum_{e\in G_S(x)}q^S(e) \tag{13}$$

  • z in formulas (12) and (13) represents the number of teacher recognition results in the teacher recognition group and the number of student recognition results in the student recognition group;
  • e represents a result within the teacher recognition group or the student recognition group.
  • Ω in formula (14) represents the probability space where the attention matrix A ∈ R^{1×l} is located, which actually represents the result distribution of the training sample data, and l represents the length of the training sample data.
  • t in formula (15) refers to a sub-training sample of length t within the training sample data of length l, and i refers to a certain attention head among the h attention heads.
  • Compared with formula (1), the DKL-divergence in this scheme uses P'(x) (a weight parameter generated from the teacher recognition probabilities) to weight the logarithm of the ratio of P'(x) to Q'(x) (the logarithm of the average value of each teacher identification group to the average value of the corresponding student identification group).
  • A balance constant can be used in this scheme to eliminate abnormal results (i.e., abnormal probabilities) in the teacher identification results and student identification results corresponding to the training sample data. For example, if the value of one teacher identification result is 0.2 while the value of the corresponding student identification result is 0.0001, the calculated logarithm between P'(x) and Q'(x) becomes abnormally large, resulting in a large gradient gap; in the opposite case, it may cause the problem of vanishing gradients. Introducing the balance parameter to balance such possible abnormal probability values can therefore greatly reduce the occurrence of extreme values.
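  • A numeric illustration of this failure mode with invented values: a single mismatched pair produces an extreme logarithm, while averaging the outlier with its neighbors first damps it.

```python
import math

# Without balancing: one mismatched pair dominates the loss term.
print(math.log(0.2 / 0.0001))   # log(2000), about 7.6: an extreme value

# With balancing (C = 3), 0.0001 is first averaged with its neighbors,
# e.g. mean([0.4, 0.0001, 0.1]) is about 0.1667, so the ratio collapses:
print(math.log(0.2 / 0.1667))   # about 0.18: the extreme value is damped
```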
  • In this way, the accuracy of the student model can be improved and its data processing capability enhanced, so that the recognition result obtained by the student model identifying the data to be processed matches the recognition result obtained by the teacher model identifying the data to be processed, that is, the student model acquires the data processing capability of the teacher model.
  • In this embodiment of the present application, the training sample data for training the preset student model is obtained, and the preset student model and the preset teacher model are used to identify the training sample data respectively, obtaining the teacher identification result and the student identification result of the training sample data.
  • The weight parameter used to adjust the recognition result of the preset student model is obtained from the teacher recognition result, the logarithm between the teacher recognition result and the student recognition result is calculated, and the weight parameter is used to perform a weighting operation on the logarithm.
  • The calculated value is then used as the loss value to adjust the preset student model.
  • According to the weight parameters, the adjustment weights used to adjust the student model can be reasonably allocated across the different results in the identification results of the training sample data, so that more attention is paid to the recognition results with higher probability values; this makes the obtained loss value more accurate and the adjusted student model more accurate.
  • a balance parameter is introduced to eliminate the abnormal results in the teacher identification results and the student identification results of the training sample data, so as to avoid inaccurate loss values due to abnormal results in the teacher identification results and the student identification results.
  • the student model can have the data processing capability of the teacher model, and the accuracy of the student model can be improved.
  • FIG. 5 is a schematic flowchart of another model distillation method provided by an embodiment of the present application.
  • the method may be performed by computer equipment, as shown in FIG. 5
  • the other model distillation method may include steps S201-S207.
  • S201 Acquire training sample data for training a preset student model.
  • S203 Obtain, from the teacher's recognition result, a weight parameter for adjusting the recognition result of the preset student model.
  • S204 Calculate the logarithm between the teacher's identification result and the student's identification result, and perform a weighting operation on the logarithm by using a weight parameter.
  • the convergence condition means that the loss value is less than the loss threshold preset by the user, or the loss value is the minimum value of the corresponding loss function.
  • Specifically, the minimum value of the loss function used to calculate the loss value can be obtained, and if the loss value is not equal to that minimum value, it is determined that the loss value does not meet the convergence condition; alternatively, it is verified whether the loss value is less than the preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, it is determined that the loss value does not satisfy the convergence condition.
  • the preset loss threshold can be set according to the data processing type of the student model or according to other indicators.
  • If the loss value does not meet the convergence condition, it means that the teacher recognition result obtained by the teacher model recognizing the training sample data differs considerably from the student recognition result obtained by the student model predicting the training sample data, that is, the recognition result obtained by the student model processing the data to be processed does not match the recognition result obtained by the teacher model identifying the data to be processed.
  • In that case, the loss degree to which the loss value belongs is determined, and the parameters in the preset student model are adjusted according to the loss degree: the larger the loss degree, the larger the adjustment of the parameters in the preset student model; the smaller the loss degree, the smaller the adjustment.
  • Adjusting the preset student model based on the loss value in this way realizes a greater degree of adjustment when the error of the student model is greater, thereby improving the convergence speed of the student model and the training efficiency.
  • It also makes the adjustment operation of the student model more accurate, thereby improving the training accuracy of the student model; see the sketch below.
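  • A sketch of such degree-based adjustment: the loss value is bucketed into a loss degree, and the size of the parameter update scales with that degree. The thresholds and scale factors are invented for illustration.

```python
from typing import List

def loss_degree(loss: float) -> float:
    """Map a loss value to an adjustment scale: larger loss, larger step."""
    if loss > 1.0:
        return 1.0   # large degree: full-size adjustment
    if loss > 0.1:
        return 0.5   # medium degree
    return 0.1       # small degree: fine-grained adjustment

def adjust_parameters(params: List[float], grads: List[float],
                      loss: float, base_lr: float = 0.01) -> List[float]:
    """Scale the gradient step by the loss degree (illustrative update rule)."""
    scale = loss_degree(loss)
    return [p - base_lr * scale * g for p, g in zip(params, grads)]
```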
  • In this embodiment of the present application, the training sample data for training the preset student model is obtained, and the preset student model and the preset teacher model are used to identify the training sample data respectively, obtaining the teacher identification result and the student identification result of the training sample data.
  • The weight parameter used to adjust the recognition result of the preset student model is obtained from the teacher recognition result, the logarithm between the teacher recognition result and the student recognition result is calculated, and the weight parameter is used to perform a weighting operation on the logarithm.
  • The calculated value is then used as the loss value to adjust the preset student model.
  • According to the weight parameters, the adjustment weights used to adjust the student model can be reasonably allocated across the different results in the identification results of the training sample data, so that more attention is paid to the recognition results with higher probability values; this makes the obtained loss value more accurate and the adjusted student model more accurate.
  • a balance parameter is introduced to eliminate the abnormal results in the teacher identification results and the student identification results of the training sample data, so as to avoid inaccurate loss values due to abnormal results in the teacher identification results and the student identification results.
  • In addition, the student model is adjusted according to the loss degree to which the loss value belongs, which realizes a greater degree of adjustment when the error of the student model is greater, thereby improving the accuracy of the student model training.
  • the student model can have the data processing capability of the teacher model, and the accuracy of the student model can be improved.
  • FIG. 6 is a schematic structural diagram of a model distillation apparatus provided by an embodiment of the present application.
  • the above-mentioned model distillation apparatus may be a computer program (including program code) running in a computer device, for example, the model distillation apparatus is an application software; the apparatus may be used to execute corresponding steps in the methods provided in the embodiments of the present application.
  • the model distillation apparatus may include: a first acquisition module 11 , an identification module 12 , a second acquisition module 13 , and an adjustment module 14 .
  • the first acquisition module 11 is used for acquiring training sample data for training a preset student model
  • the identification module 12 is configured to use the preset student model and the preset teacher model to identify the training sample data respectively, to obtain the teacher identification result and the student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;
  • the second obtaining module 13 is configured to obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model;
  • the adjustment module 14 is used to calculate the logarithm between the teacher identification result and the student identification result, perform a weighted operation on the logarithm using the weight parameter, and use the calculated value as a loss value to adjust the preset student model.
  • In some embodiments, there are multiple teacher identification results, and each teacher identification result represents an identification probability;
  • the above-mentioned second acquisition module 13 includes:
  • an acquisition unit for acquiring a balance parameter for balancing the teacher identification result
  • the first grouping unit is configured to group the obtained multiple teacher identification results according to the balance parameter in the preset identification sequence of the teacher model, to obtain multiple teacher identification groups arranged in sequence, wherein each teacher identification group in the multiple teacher identification groups contains the same number of teacher identification results;
  • the first calculation unit is configured to calculate the average value of multiple teacher identification results in each teacher identification group respectively, and use the obtained multiple average values as weight parameters after balancing processing.
  • the above-mentioned adjustment module 14 includes:
  • the second grouping unit is configured to group the obtained multiple student identification results according to the balance parameter in the identification sequence, to obtain multiple student identification groups arranged in sequence, wherein each student identification group in the multiple student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one to a student identification group according to the identification sequence;
  • a second calculation unit, used for calculating the average value of the multiple student identification results in each student identification group;
  • the third calculation unit is used to calculate the logarithm of the mean value of each student identification group and the mean value of the corresponding teacher identification group respectively, and obtain a plurality of logarithms after balanced processing;
  • the first weighting operation unit is configured to perform weighting operation on the weight parameter after the balance processing and the logarithm after the balance processing.
  • the above acquisition unit is specifically used for:
  • determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification result, wherein the balance parameter library includes at least one balance parameter and the corresponding relationship between each balance parameter in the at least one balance parameter and a preset threshold range.
  • the above-mentioned adjustment module 14 includes:
  • a verification unit, configured to verify whether the loss value satisfies the convergence condition;
  • a first determining unit configured to determine the degree of loss to which the loss value belongs if the loss value does not satisfy the convergence condition
  • a first adjustment unit configured to adjust parameters in the preset student model according to the loss degree.
  • the above verification unit is specifically used for:
  • verifying whether the loss value is less than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
  • the preset teacher model includes multiple teacher distillation layers, and the preset student model includes multiple student distillation layers;
  • the above-mentioned adjustment module 14 includes:
  • a second determining unit configured to determine the teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers
  • the fourth calculation unit is used to calculate the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;
  • the second weighting operation unit is configured to use the weight parameter to perform a weighted operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain the loss value of each student distillation layer;
  • the second adjustment unit is configured to adjust the corresponding student distillation layer in the preset student model by using the loss value of each student distillation layer respectively.
  • In this embodiment of the present application, the training sample data used for training the preset student model is obtained, and the preset student model and the preset teacher model are used to identify the training sample data respectively, obtaining the teacher identification result and the student identification result of the training sample data.
  • The weight parameter used to adjust the recognition result of the preset student model is obtained from the teacher recognition result, the logarithm between the teacher recognition result and the student recognition result is calculated, and the weight parameter is used to perform a weighting operation on the logarithm.
  • The calculated value is then used as the loss value to adjust the preset student model.
  • According to the weight parameters, the adjustment weights used to adjust the student model can be reasonably allocated across the different results in the identification results of the training sample data, so that more attention is paid to the recognition results with higher probability values; this makes the obtained loss value more accurate and the adjusted student model more accurate.
  • a balance parameter is introduced to eliminate the abnormal results in the teacher identification results and the student identification results of the training sample data, so as to avoid inaccurate loss values due to abnormal results in the teacher identification results and the student identification results.
  • Adjusting the student model according to the loss degree to which the loss value belongs realizes a greater degree of adjustment when the error of the student model is greater, thereby improving the accuracy of the student model training.
  • the student model can have the data processing capability of the teacher model, and the accuracy of the student model can be improved.
  • step S101 shown in FIG. 1 may be performed by the first acquisition module 11 shown in FIG. 6 ;
  • step S102 shown in FIG. 1 may be performed by the identification module 12 shown in FIG. 6 ;
  • step S103 shown in FIG. 1 can be performed by the second acquisition module 13 in FIG. 6;
  • step S104 shown in FIG. 1 can be performed by the adjustment module 14 in FIG. 6 .
  • FIG. 7 is a schematic structural diagram of a computer device provided by an embodiment of the present application.
  • the computer device may include a processor and memory.
  • the computer device may further include a network interface and/or a user interface.
  • the above-mentioned computer device 1000 may include a processor 1001, a network interface 1004 and a memory 1005; in addition, the above-mentioned computer device 1000 may further include a user interface 1003 and at least one communication bus 1002.
  • the communication bus 1002 is used to realize the connection and communication between these components.
  • the user interface 1003 may include a display screen (Display) and a keyboard (Keyboard); optionally, the user interface 1003 may also include a standard wired interface and a wireless interface.
  • the network interface 1004 may include a standard wired interface and a wireless interface (eg, a WI-FI interface).
  • the memory 1005 can be a high-speed RAM memory, or a non-volatile memory, such as at least one disk memory.
  • the memory 1005 may also be at least one storage device located away from the aforementioned processor 1001 .
  • the memory 1005 as a computer-readable storage medium may include an operating system, a network communication module, a user interface module, and a device control application program.
  • the network interface 1004 can provide a network communication function;
  • the user interface 1003 is mainly used to provide an input interface for the user; and
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to implement the following:
  • the training sample data is identified by the preset student model and the preset teacher model respectively, and the teacher identification result and the student identification result of the training sample data are obtained, wherein the preset student model is obtained through training guided by the preset teacher model;
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • the obtained multiple teacher identification results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence, wherein each teacher identification group in the multiple teacher identification groups contains the same number of teacher identification results;
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • each student identification group in the multiple student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one to a student identification group according to the identification sequence;
  • a weighting operation is performed on the weight parameter after the balance processing and the logarithm after the balance processing.
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification result, wherein the balance parameter library includes at least one balance parameter and the corresponding relationship between each balance parameter in the at least one balance parameter and a preset threshold range.
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • the parameters in the preset student model are adjusted according to the loss degree.
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • verifying whether the loss value is less than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
  • the processor 1001 can be used to call the device control application program stored in the memory 1005 to realize:
  • the logarithm between the teacher identification result and the student identification result is calculated, the logarithm is weighted by using the weight parameter, and the calculated value is used as a loss value to adjust the preset student model, which includes:
  • the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer is weighted to obtain the loss value of each student distillation layer;
  • the corresponding student distillation layer in the preset student model is adjusted by using the loss value of each student distillation layer.
  • In this embodiment of the present application, the training sample data for training the preset student model is obtained, and the preset student model and the preset teacher model are used to identify the training sample data respectively, obtaining the teacher identification result and the student identification result of the training sample data.
  • The weight parameter used to adjust the recognition result of the preset student model is obtained from the teacher recognition result, the logarithm between the teacher recognition result and the student recognition result is calculated, and the weight parameter is used to perform a weighting operation on the logarithm.
  • The calculated value is then used as the loss value to adjust the preset student model.
  • According to the weight parameters, the adjustment weights used to adjust the student model can be reasonably allocated across the different results in the identification results of the training sample data, so that more attention is paid to the recognition results with higher probability values; this makes the obtained loss value more accurate and the adjusted student model more accurate.
  • a balance parameter is introduced to eliminate the abnormal results in the teacher identification results and the student identification results of the training sample data, so as to avoid inaccurate loss values due to abnormal results in the teacher identification results and the student identification results.
  • adjusting the student model according to the magnitude of the loss value enables a greater adjustment when the error of the student model is greater, thereby improving the accuracy of student model training.
  • the student model can have the data processing capability of the teacher model, and the accuracy of the student model can be improved.
  • the computer device 1000 described in this embodiment of the present application can execute the description of the model distillation method in the foregoing embodiments corresponding to FIG. 1 and FIG. 5, and can also execute the description of the model distillation apparatus in the foregoing embodiment, which will not be repeated here.
  • the description of the beneficial effects of using the same method will also not be repeated.
  • an embodiment of the present application also provides a computer-readable storage medium, and the computer-readable storage medium stores the computer program executed by the model distillation apparatus mentioned above.
  • the computer program includes program instructions.
  • when the processor executes the program instructions, the description of the model distillation method in the embodiment corresponding to FIG. 1 or FIG. 5 can be executed.
  • the description of the beneficial effects of using the same method will not be repeated.
  • for technical details not disclosed in the storage medium embodiments of the present application, please refer to the description of the method embodiments of the present application.
  • the storage medium involved in this application, such as a computer-readable storage medium, may be non-volatile or volatile.
  • the program instructions may be deployed and executed on one computer device, or on multiple computer devices located at one site, or on multiple computer devices distributed across multiple sites and interconnected by a communication network.
  • multiple computer devices distributed in multiple locations and interconnected by a communication network can form a blockchain network.
  • the above-mentioned storage medium may be a magnetic disk, an optical disk, a read-only memory (Read-Only Memory, ROM) or a random access memory (Random Access Memory, RAM) and the like.

Abstract

Disclosed are a model distillation method and apparatus, and a storage medium and a device. The method comprises: acquiring training sample data which is used for training a preset student model; respectively recognizing the training sample data by using the preset student model and a preset teacher model, so as to obtain a teacher recognition result and a student recognition result of the training sample data; acquiring, from the teacher recognition result, a weight parameter which is used for adjusting a recognition result of the preset student model; and calculating a logarithm between the teacher recognition result and the student recognition result, performing a weighted operation on the logarithm by using the weight parameter, and adjusting the preset student model by taking a calculated numerical value as a loss value. By means of the present application, a student model can have the data processing capability of a teacher model, thereby improving the accuracy of the student model.

Description

Model distillation method, device, storage medium and equipment

This application claims priority to the Chinese patent application No. 202011313330.8, filed with the China Patent Office on November 20, 2020 and entitled "Model Distillation Method, Apparatus, Storage Medium and Equipment", the entire contents of which are incorporated herein by reference.
Technical Field

The present application relates to the technical field of artificial intelligence, and in particular to a model distillation method, device, storage medium and equipment.

Background

As an important technical solution for model compression and acceleration, model distillation has attracted much attention in recent years and has played an important role in advancing the field of natural language processing. Model distillation (knowledge distillation) refers to using a teacher model with high accuracy but a complex structure to guide the training of a student model with lower accuracy but a simple structure, so as to improve the accuracy of the student model.

The inventors realized that although the student model can learn knowledge from the teacher model and thereby improve its accuracy, certain differences still exist between the teacher model and the student model in existing distillation architectures, resulting in poor expressiveness and low accuracy of the student model.
Summary of the Invention

The technical problem to be solved by the embodiments of the present application is to provide a model distillation method, apparatus, storage medium and device that can improve the accuracy and data processing capability of a student model.

In one aspect, an embodiment of the present application provides a model distillation method, including:

acquiring training sample data used for training a preset student model;

identifying the training sample data by using the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter used for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by taking the calculated value as a loss value.

In one aspect, an embodiment of the present application provides a model distillation apparatus, including:

a first acquisition module, configured to acquire training sample data used for training a preset student model;

an identification module, configured to identify the training sample data by using the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

a second acquisition module, configured to obtain, from the teacher identification result, a weight parameter used for adjusting the identification result of the preset student model; and

an adjustment module, configured to calculate the logarithm between the teacher identification result and the student identification result, perform a weighting operation on the logarithm by using the weight parameter, and adjust the preset student model by taking the calculated value as a loss value.
In one aspect, the present application provides a computer device, including a processor and a memory;

wherein the memory is used to store a computer program, and the processor is used to call the computer program to perform the following steps:

acquiring training sample data used for training a preset student model;

identifying the training sample data by using the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter used for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by taking the calculated value as a loss value.

In one aspect, an embodiment of the present application provides a computer-readable storage medium storing a computer program, where the computer program includes program instructions that, when executed by a processor, perform the following steps:

acquiring training sample data used for training a preset student model;

identifying the training sample data by using the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter used for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by taking the calculated value as a loss value.
According to the embodiments of the present application, adjustment weights for adjusting the student model can be reasonably allocated, according to the weight parameters, among the different results in the identification results and prediction results of the training sample data, which makes the adjusted student model more accurate. Through the present application, the student model can acquire the data processing capability of the teacher model, and the accuracy of the student model can be improved.
Brief Description of the Drawings

FIG. 1 is a schematic flowchart of a model distillation method provided by the present application;

FIG. 2 is a schematic diagram of a method for calculating the average value of multiple teacher identification results in each teacher identification group, provided by an embodiment of the present application;

FIG. 3 is a schematic diagram of a method for obtaining the loss value of a preset student model, provided by an embodiment of the present application;

FIG. 4 is a schematic diagram of model distillation provided by an embodiment of the present application;

FIG. 5 is a schematic flowchart of another model distillation method provided by the present application;

FIG. 6 is a schematic structural diagram of a model distillation apparatus provided by the present application;

FIG. 7 is a schematic structural diagram of a computer device provided by an embodiment of the present application.
Detailed Description

The technical solutions in the embodiments of the present application will be described below with reference to the accompanying drawings in the embodiments of the present application.

The technical solutions of the present application relate to the field of artificial intelligence and/or big data technology. Optionally, the data involved in this application, such as sample data, identification results and/or loss values, may be stored in a database, or may be stored in a blockchain, for example through blockchain-based distributed storage, which is not limited in this application.
Please refer to FIG. 1, which is a schematic flowchart of a model distillation method provided by an embodiment of the present application. The method can be performed by a computer device, which may be a terminal or a server. The terminal may include, but is not limited to, a smartphone, a tablet computer, a notebook computer, a desktop computer, a smart speaker, a smart watch, and the like. The server may be an independent physical server, a server cluster or distributed system composed of multiple physical servers, or a cloud server providing basic cloud computing services such as cloud services, cloud databases, cloud computing, cloud functions, cloud storage, network services, cloud communications, middleware services, domain name services, security services, Content Delivery Network (CDN), and big data and artificial intelligence platforms. As shown in FIG. 1, the model distillation method may include steps S101-S104.
S101: Acquire training sample data for training a preset student model.

Training sample data used for training the preset student model is acquired. The preset student model refers to the student model in model distillation; model distillation (knowledge distillation) means using a teacher model to guide the training of a student model, thereby improving the accuracy of the student model. The training sample data may be text data, image data, and the like.
S102: Identify the training sample data by using the preset student model and the preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data.

The preset teacher model is used to identify the training sample data to obtain the teacher identification result of the training sample data, and the preset student model is used to identify the training sample data to obtain the student identification result of the training sample data. The preset student model is obtained through training guided by the preset teacher model. Generally speaking, the teacher model is highly accurate but computationally complex and thus unsuitable for deployment on terminal devices, whereas the student model is computationally simple enough to meet the requirements of terminal devices but insufficiently accurate. Model distillation solves this problem: the preset teacher model guides the training of the preset student model, improving the accuracy of the preset student model. The teacher model and the student model perform the same type of data processing; the network in the teacher model is deeper or wider and has a larger resolution, that is, the data processing capability of the teacher model is higher than that of the student model. Specifically, when obtaining the teacher model and the student model, the selection can be made according to their similarity: the more similar the structures of the teacher model and the student model, the smaller the accuracy gap between the two after distillation. The teacher model and the student model can therefore be selected as the same type of model, for example both from the ResNet family; a ResNet (residual network) can easily be adjusted in width and depth to obtain networks with different expressive abilities. Then, according to the teacher identification result of the training sample data output by the preset teacher model and the student identification result of the training sample data output by the preset student model, the preset student model is adjusted, so that the knowledge in the preset teacher model is transferred to the preset student model, enabling the preset student model to have the data processing capability and accuracy of the preset teacher model.
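As a concrete illustration of pairing models from the same family, the following is a minimal sketch assuming a PyTorch/torchvision environment; the specific model choices and class count are illustrative assumptions, not part of the present application:

```python
# A minimal sketch of selecting a deeper teacher and a shallower student
# from the same ResNet family (illustrative; assumes torchvision is available).
import torch
from torchvision.models import resnet18, resnet50

teacher = resnet50(num_classes=10)   # deeper network: higher data processing capability
student = resnet18(num_classes=10)   # shallower network: suitable for terminal devices
teacher.eval()                       # the teacher only guides; it is not trained here

x = torch.randn(4, 3, 224, 224)      # a dummy batch standing in for training sample data
with torch.no_grad():
    teacher_result = teacher(x)      # teacher identification result
student_result = student(x)          # student identification result
```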
S103: Obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model.

The teacher identification result can be used to generate a weight parameter for adjusting the identification result of the preset student model. The weight parameter determines the weight that the identification result of the preset student model carries when the loss value is generated: the larger the weight parameter, the greater that weight.
Optionally, when the weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result, a balance parameter used for balancing the teacher identification results can be obtained. Following the identification order of the preset teacher model, the obtained multiple teacher identification results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence, where each teacher identification group contains the same number of teacher identification results. The average value of the teacher identification results in each group is calculated, and the obtained averages are used as the balanced weight parameters.

The teacher identification result consists of multiple teacher identification results, each representing an identification probability, that is, the probability obtained when the preset teacher model identifies the training sample data. A balance parameter for balancing the teacher identification results can be obtained; the balance parameter may be a positive integer greater than or equal to 1 and is used to handle abnormal results among the teacher identification results, where an abnormal result is one that is far larger or far smaller than the normal results, i.e., anomalous compared with the other results. The multiple teacher identification results are sorted in the identification order of the preset teacher model, and the sorted results are grouped according to the balance parameter to obtain multiple teacher identification groups arranged in sequence, each containing the same number of teacher identification results. The average value of the teacher identification results in each group is then calculated, and the averages of the groups are used as the balanced weight parameters, which can be used to adjust the identification result of the preset student model. The larger the average value of the teacher identification results in a group, the larger the corresponding balanced weight parameter; since a teacher identification result represents an identification probability, a larger average probability in a group yields a larger balanced weight parameter. In this way, more attention is paid to the high-probability positions of the probability distribution, and prioritizing the correct matching of high-probability events in the distribution has practical value; the loss value of the preset student model calculated from such weight parameters is more accurate, which improves the accuracy of the preset student model more effectively. Grouping the teacher identification results according to the balance parameter and averaging within each group amounts to balancing each teacher identification result with the balance parameter, which smooths abnormal results among the teacher identification results, reduces the error introduced when generating the weight parameters, and improves the accuracy of the preset student model.
Optionally, when obtaining the balance parameter for balancing the teacher identification results, the number of teacher identification results can be obtained, and the preset threshold range to which this number belongs is determined. A target balance parameter corresponding to that preset threshold range is determined from a balance parameter library and used as the balance parameter for balancing the teacher identification results. The balance parameter library includes at least one balance parameter and the corresponding relationship between each balance parameter and a preset threshold range.

Since the teacher identification result consists of multiple identification probabilities obtained when the preset teacher model identifies the training sample data, the number of teacher identification results can be obtained and the preset threshold range to which this number belongs can be determined. After that, the target balance parameter corresponding to this range is determined from the balance parameter library, which includes at least one balance parameter and the correspondence between each balance parameter and a preset threshold range. For example, a balance parameter library may be preset in which a first threshold range corresponds to a first balance parameter, a second threshold range corresponds to a second balance parameter, and so on.
Optionally, since the student identification result also consists of multiple student identification results that correspond to the teacher identification results, the balance parameter may likewise be determined according to the number of student identification results: the preset threshold range to which the number of student identification results belongs is determined, and the target balance parameter corresponding to that range is determined from the balance parameter library and used as the balance parameter for balancing the identification results. It should be noted that the balance parameter C may be determined according to the number of teacher identification results, the number of student identification results, or other specific circumstances; its purpose is to balance, i.e., eliminate, the abnormal results among the teacher identification results or the student identification results.
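The following is a minimal sketch of such a lookup; the threshold ranges and parameter values in the library are hypothetical, chosen only to mirror the correspondence described above:

```python
# Hypothetical balance parameter library: each entry maps a preset threshold
# range (on the number of identification results) to a balance parameter C.
BALANCE_PARAMETER_LIBRARY = [
    (range(1, 10), 1),      # first threshold range  -> first balance parameter
    (range(10, 100), 3),    # second threshold range -> second balance parameter
    (range(100, 1000), 5),  # and so on
]

def lookup_balance_parameter(num_results):
    # Determine which preset threshold range the count belongs to,
    # and return the corresponding target balance parameter.
    for threshold_range, c in BALANCE_PARAMETER_LIBRARY:
        if num_results in threshold_range:
            return c
    return 1  # default: no smoothing
```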
FIG. 2 is a schematic diagram of a method for calculating the average value of multiple teacher identification results in each teacher identification group, provided by an embodiment of the present application. As shown in FIG. 2, take a distribution of 11 teacher identification results as an example. The 11 teacher identification results are sorted in the identification order of the preset teacher model, yielding the sorted results x1, x2, ..., x11. If the balance parameter is determined to be 3 according to the number of teacher identification results (11), the 11 sorted teacher identification results are grouped according to the balance parameter 3 to obtain teacher identification groups arranged in sequence: [x1, x2], [x1, x2, x3], [x2, x3, x4], [x3, x4, x5], ..., [x9, x10, x11], [x10, x11]. Since no teacher identification result is sorted before x1, the group corresponding to x1 can be [x1, x2] or [x1, x2, x3]; similarly, since no result is sorted after x11, the group corresponding to x11 can be [x10, x11] or [x9, x10, x11]. After the sequentially arranged groups are obtained, the average value of the teacher identification results in each group is calculated. For example, for the group [x1, x2, x3] corresponding to x2, the results x1, x2 and x3 are summed and divided by the group size 3 to obtain the average value for x2.
For example, the teacher identification result distribution of the training sample data is:

[0.1, 0.05, 0.0001, 0.02, 0.15, 0.28, 0.23, 0.06, 0.023, 0.05, 0.0369]

The balance parameter is determined to be 3 according to the number of teacher identification results, so the teacher identification groups obtained after sampling and grouping with C = 3 are:

[0.1, 0.05], [0.1, 0.05, 0.0001], [0.05, 0.0001, 0.02], [0.0001, 0.02, 0.15], [0.02, 0.15, 0.28], [0.15, 0.28, 0.23], [0.28, 0.23, 0.06], [0.23, 0.06, 0.023], [0.06, 0.023, 0.05], [0.023, 0.05, 0.0369], [0.05, 0.0369].

The average value of the teacher identification results in each group is calculated, giving:

[0.075, 0.05, 0.0234, 0.0567, 0.15, 0.22, 0.19, 0.1043, 0.0443, 0.0366, 0.0435]

The detailed calculation of each averaged teacher identification result is as follows:

result_0 = (0.1 + 0.05) / 2 = 0.0750; result_1 = (0.1 + 0.05 + 0.0001) / 3 = 0.0500; result_2 = (0.05 + 0.0001 + 0.02) / 3 = 0.0234; result_3 = (0.0001 + 0.02 + 0.15) / 3 = 0.0567; result_4 = (0.02 + 0.15 + 0.28) / 3 = 0.1500; result_5 = (0.15 + 0.28 + 0.23) / 3 = 0.2200; result_6 = (0.28 + 0.23 + 0.06) / 3 = 0.1900; result_7 = (0.23 + 0.06 + 0.023) / 3 = 0.1043; result_8 = (0.06 + 0.023 + 0.05) / 3 = 0.0443; result_9 = (0.023 + 0.05 + 0.0369) / 3 = 0.0366; result_10 = (0.05 + 0.0369) / 2 = 0.0435.

The third teacher identification result, 0.0001, becomes 0.0234 after the balancing process: the abnormal teacher identification result has been balanced, i.e., eliminated.
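The grouping-and-averaging step above can be read as a centered moving average of window size C that is truncated at the sequence boundaries. A minimal sketch under that reading (an interpretation of the worked example, not the patent's literal implementation):

```python
def balance(probs, C=3):
    # Centered moving average with window size C, truncated at the edges,
    # matching the grouping [x1,x2], [x1,x2,x3], ..., [x10,x11] above.
    half = C // 2
    balanced = []
    for i in range(len(probs)):
        window = probs[max(0, i - half): i + half + 1]
        balanced.append(sum(window) / len(window))
    return balanced

teacher_results = [0.1, 0.05, 0.0001, 0.02, 0.15, 0.28,
                   0.23, 0.06, 0.023, 0.05, 0.0369]
# Reproduces 0.0750, 0.0500, 0.0234, ..., 0.0435 (up to rounding).
print(balance(teacher_results, C=3))
```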
S104: Calculate the logarithm between the teacher identification result and the student identification result, perform a weighting operation on the logarithm by using the weight parameter, and adjust the preset student model by taking the calculated value as a loss value.

The logarithm between the teacher identification result and the student identification result is calculated, and the above weight parameter is used to perform a weighting operation on this logarithm; the resulting value is taken as the loss value of the preset student model. The preset student model is then adjusted according to this loss value to obtain an adjusted student model, which is taken as the target student model. The target student model is used to identify data to be processed, and the identification result it produces matches the identification result obtained when the preset teacher model identifies the same data, that is, the target student model has the data processing capability of the preset teacher model.
Optionally, when calculating the logarithm between the teacher identification result and the student identification result and performing the weighting operation on the logarithm with the weight parameter, the following formula (1) may be used:

$$D_{KL}(P\|Q)=\sum_{x\in X}P(x)\log\frac{P(x)}{Q(x)}\tag{1}$$

In formula (1), $D_{KL}(P\|Q)$ is the loss value of the preset student model; $P(x)$ is the teacher identification result of the training sample data, i.e., the result obtained when the teacher model identifies the training sample data; $Q(x)$ is the student identification result of the training sample data, i.e., the result obtained when the student model predicts the training sample data; $\log(P(x)/Q(x))$ is the logarithm between the teacher identification result and the student identification result; the weight parameter is $P(x)$; $x$ refers to any one of the teacher identification results and student identification results, and $X$ refers to the set of teacher identification results and student identification results. Formula (1) is known as the KL-divergence (relative entropy).
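For reference, a minimal sketch of formula (1) in Python follows; the small epsilon guard against log(0) is an added assumption, not part of the formula:

```python
import math

def kl_divergence(p_teacher, q_student, eps=1e-12):
    # Formula (1): D_KL(P||Q) = sum_x P(x) * log(P(x) / Q(x)),
    # i.e. each logarithm is weighted by the teacher probability P(x).
    return sum(p * math.log((p + eps) / (q + eps))
               for p, q in zip(p_teacher, q_student))
```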
FIG. 3 is a schematic diagram of a method for obtaining the loss value of a preset student model, provided by an embodiment of the present application. As shown in FIG. 3, the method includes steps S21-S23.

S21: Following the above identification order, group the obtained multiple student identification results according to the balance parameter to obtain multiple student identification groups arranged in sequence, where each student identification group contains the same number of student identification results, and each teacher identification group corresponds one-to-one with a student identification group in identification order.

S22: Calculate the average value of the student identification results in each student identification group.

The student identification result consists of multiple student identification results, each representing an identification probability obtained when the preset student model identifies the training sample data. The identification order of the preset teacher model is the same as that of the preset student model, i.e., the training sample data are identified in the same arrangement order. The obtained student identification results are grouped according to the above balance parameter to obtain student identification groups arranged in sequence, where each group contains the same number of student identification results and corresponds one-to-one with a teacher identification group in identification order. The average value of the student identification results in each group is then calculated; that is, the balance parameter is used to balance each student identification result, yielding the balanced student identification results. In this way, abnormal results among the student identification results are balanced, greatly reducing the occurrence of extreme values, thereby improving the accuracy of the target student model and enabling the student model to acquire the data processing capability of the teacher model.

S23: Calculate the logarithm between the average value of each student identification group and the average value of the corresponding teacher identification group to obtain multiple balanced logarithms, and perform a weighting operation on the balanced logarithms by using the balanced weight parameters.

The average value of the student identification results in each student identification group is calculated, the logarithm between each student group average and the corresponding teacher group average is computed to obtain the balanced logarithms, and the balanced weight parameters are used to perform a weighting operation on the balanced logarithms.
Optionally, the following formulas (2), (3) and (4) can be used to compute the balanced weight parameters and to perform the weighting operation on the logarithm between the teacher identification result and the student identification result:

$$P'(x)=\frac{1}{z}\sum_{x\in e}P(x)\tag{2}$$

$$Q'(x)=\frac{1}{z}\sum_{x\in e}Q(x)\tag{3}$$

$$D_{DKL}(P\|Q)_C=\sum_{x\in X}P'(x)\log\frac{P'(x)}{Q'(x)}\tag{4}$$

In formulas (2) and (3), $z$ represents the number of teacher identification results in a teacher identification group or the number of student identification results in a student identification group, and $e$ represents the teacher identification group or the student identification group. In formula (4), $D_{DKL}(P\|Q)_C$ is the loss value; $P'(x)$ is the average value of a teacher identification group, i.e., the average obtained by balancing the teacher identification results; $Q'(x)$ is the average value of a student identification group, i.e., the average obtained by balancing the student identification results; $\log(P'(x)/Q'(x))$ is the logarithm between the average of a student identification group and the average of the corresponding teacher identification group, and the weight parameter is $P'(x)$. Here $x$ refers to any one of the teacher identification results and student identification results, and $X$ refers to the set of teacher identification results and student identification results. When $P(x)$ and $Q(x)$ are identical, $D_{DKL}(P\|Q)_C$ equals zero.
The present application introduces a balance parameter into the KL-divergence (relative entropy) function; the KL-divergence loss function with the balance parameter introduced can be called DKL-divergence. Using DKL-divergence, each of the multiple teacher identification results and each of the multiple student identification results can be balanced, eliminating the abnormal results among them.
Optionally, the preset teacher model includes multiple teacher distillation layers and the preset student model includes multiple student distillation layers. When calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model with the calculated value as the loss value, the teacher distillation layer corresponding to each student distillation layer can be determined, and the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer can be calculated. Using the weight parameter, a weighting operation is performed on this logarithm to obtain the loss value of each student distillation layer, and the loss value of each student distillation layer is used to adjust the corresponding student distillation layer in the preset student model.

The preset teacher model includes multiple teacher distillation layers, each with a corresponding output teacher identification result; the preset student model includes multiple student distillation layers, each with a corresponding output student identification result. Knowledge distillation can be performed between corresponding distillation layers of the preset teacher model and the preset student model: the teacher distillation layer corresponding to each student distillation layer is determined, and the weight parameter is used to weight the logarithm between the student identification result output by each student distillation layer and the teacher identification result output by the corresponding teacher distillation layer, yielding the loss value of each student distillation layer. Each student distillation layer in the preset student model is then adjusted with its own loss value, which adjusts each layer more accurately and improves the accuracy of the student model.
For example, FIG. 4 is a schematic diagram of model distillation provided by an embodiment of the present application. As shown in FIG. 4, the distillation model generally includes three parts: Transformer-layer distillation, Embedding-layer distillation and Prediction-layer distillation; each of the three distillation layers in the teacher model can be distilled into the corresponding distillation layer in the student model.

Knowledge distillation is performed between corresponding distillation layers of the teacher model and the student model. That is, knowledge distillation is performed on the Transformer layer in the student model according to the Transformer layer in the teacher model, the loss value of the Transformer layer is calculated, and the Transformer layer is adjusted according to this loss value. Knowledge distillation is performed on the Embedding layer in the student model according to the Embedding layer in the teacher model, the loss value of the Embedding layer is calculated, and the Embedding layer is adjusted with this loss value. Knowledge distillation is performed on the Prediction layer in the student model according to the Prediction layer in the teacher model, the loss value of the Prediction layer is calculated, and the Prediction layer is adjusted with this loss value. In this way, the distillation layers in the student model can be adjusted more accurately, improving the accuracy of the preset student model.
As shown in FIG. 4, the student model (new model) has M Transformer layers and the teacher model (original model) has N Transformer layers; n = g(m) indicates that the m-th layer of the student model obtains information from the n-th layer of the teacher model. The Embedding-layer distillation is set as layer 0, the Output-layer distillation as layer M+1, and the Transformer layers as layers 1 to M. The following formula (5) can be used to represent the distillation loss of knowledge transfer from the teacher to the student:

$$\mathcal{L}_{model}=\sum_{m=0}^{M+1}\lambda_m\,\mathcal{L}_{layer}\big(S_m,T_{g(m)}\big)\tag{5}$$

In formula (5), $\mathcal{L}_{layer}$ represents the loss function of a specified layer, where the specified layer may be the Transformer layer, the Embedding layer or the Prediction layer; $S_m$ denotes the m-th layer of the student model and $T_{g(m)}$ the matched teacher layer; m = 0 refers to the Embedding layer, m = M+1 refers to the Output layer, and m = 1, 2, ..., M indexes the Transformer layers of the teacher model that the student model plans to learn; $\lambda_m$ is a hyperparameter representing the loss weight of each layer; $\mathcal{L}_{model}$ represents the sum of the knowledge distillation losses of all layers. Corresponding loss functions can be set for Transformer-layer distillation, Embedding-layer distillation and Prediction-layer distillation, and each corresponding layer can be adjusted according to the loss value obtained from its loss function.
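A minimal sketch of formula (5), assuming the per-layer losses and the per-layer weight hyperparameters λ_m have already been computed (the names are illustrative):

```python
def model_distillation_loss(layer_losses, lambdas):
    # layer_losses[m] is L_layer for layer m (m = 0: embedding layer,
    # m = 1..M: Transformer layers, m = M+1: prediction/output layer);
    # lambdas[m] is the loss-weight hyperparameter of layer m.
    return sum(lam * loss for lam, loss in zip(lambdas, layer_losses))
```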
Transformer-layer distillation in model distillation includes self-attention-based distillation and hidden-state-based distillation. In the related art, the objective function of self-attention matrix distillation for the Transformer layer is the following formula (6):

$$\mathcal{L}_{attn}=\frac{1}{h}\sum_{i=1}^{h}\text{MSE}\big(A_i^S,A_i^T\big)\tag{6}$$

In formula (6), $h$ is the number of attention heads, $i$ denotes the i-th attention head, $A_i^S$ and $A_i^T$ represent the attention matrices of the student model and the teacher model respectively, and MSE refers to the mean squared error loss.

The objective function for fitting the output matrix of each Transformer layer is the following formula (7):

$$\mathcal{L}_{hidn}=\text{MSE}\big(H^S W_h,H^T\big)\tag{7}$$

In formula (7), $H^S\in R^{l\times d'}$ and $H^T\in R^{l\times d}$ refer to the hidden-state matrices of the student and the teacher respectively, where $R^{l\times d'}$ and $R^{l\times d}$ denote the respective matrix spaces, $l$ denotes the length of the training sample data (i.e., the length of the input sentence), and $d'$, $d$ denote the sizes of the hidden layers. $W_h\in R^{d'\times d}$ is a learnable linear transformation matrix that maps the student's hidden-state matrix into the same result space as the teacher's.
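A minimal sketch of formulas (6) and (7) in PyTorch follows (the tensor shapes are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def attention_loss(attn_student, attn_teacher):
    # Formula (6): mean squared error averaged over the h attention heads;
    # both tensors are assumed to have shape (h, l, l).
    return F.mse_loss(attn_student, attn_teacher)

def hidden_loss(h_student, h_teacher, w_h):
    # Formula (7): project the student's hidden states (l, d') into the
    # teacher's space (l, d) via the learnable matrix W_h, then apply MSE.
    return F.mse_loss(h_student @ w_h, h_teacher)
```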
In the related art, the loss function of the Embedding layer is the following formula (8):

$$\mathcal{L}_{embd}=\text{MSE}\big(E^S W_e,E^T\big)\tag{8}$$

In formula (8), $E^S$ and $E^T$ refer to the embedding matrices of the student and teacher models respectively, and $W_e$ is a linear transformation matrix similar to $W_h$.
In the related art, the output layer of Prediction-layer distillation adopts the soft cross-entropy loss of the following formula (9):

$$\mathcal{L}_{pred}=-\text{softmax}\big(z^T\big)\cdot\text{log\_softmax}\big(z^S/t\big)\tag{9}$$

In formula (9), $z^T$ and $z^S$ are the teacher identification result of the preset teacher model and the student identification result of the preset student model (i.e., the teacher and student logits) respectively, log_softmax() represents the log-likelihood, and $t$ refers to the distillation temperature.
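A minimal sketch of the soft cross-entropy of formula (9), assuming logits as inputs:

```python
import torch.nn.functional as F

def prediction_loss(z_student, z_teacher, t=1.0):
    # Formula (9): soft cross-entropy between the teacher's softmax
    # distribution and the student's log-softmax at distillation temperature t.
    soft_targets = F.softmax(z_teacher, dim=-1)
    log_probs = F.log_softmax(z_student / t, dim=-1)
    return -(soft_targets * log_probs).sum(dim=-1).mean()
```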
In summary, the overall objective function of model distillation in the related art can be expressed as the following formula (10):

$$\mathcal{L}_{layer}=\begin{cases}\mathcal{L}_{embd}, & m=0\\ \mathcal{L}_{hidn}+\mathcal{L}_{attn}, & M\ge m>0\\ \mathcal{L}_{pred}, & m=M+1\end{cases}\tag{10}$$

In formula (10), MSE (Mean Squared Error) is generally used to measure the deviation between a model's predicted values and the true values. Assuming the distribution of true values is observed, the distribution of predicted values is predicted, and the sample space size is $n$, the difference between the two distributions can be expressed as the following formula (11):

$$\text{MSE}=\frac{1}{n}\sum_{i=1}^{n}\big(\text{observed}_i-\text{predicted}_i\big)^2\tag{11}$$
In the related art, MSE is used to calculate the loss value of the student model. However, as can be seen from formula (11), MSE attends to all positions in the result distribution indiscriminately; the resulting loss value does not reflect the difference between the student model and the teacher model well, and therefore cannot adjust the preset student model accurately.
In the embodiments of the present application, by contrast, the above DKL-divergence is used to calculate the loss value of the student model. For example, when calculating the loss value corresponding to the Transformer layer, suppose the corresponding distillation layers of the teacher model and the student model are U and V respectively. For the knowledge distillation of each pair of corresponding layers, the attention matrices are $A^T$ and $A^S$, where T and S denote the teacher and the student respectively. For the teacher identification result and the student identification result corresponding to the training sample data, after the balance parameter C is determined, the attention matrices $A^T$ and $A^S$ can be sampled to obtain the sub-distributions $p^T$ and $q^S$ corresponding to the teacher identification groups and the student identification groups respectively. The distributions obtained after averaging, $\bar{p}^T(x)$ and $\bar{q}^S(x)$, are given by the following formulas (12) and (13):

$$\bar{p}^T(x)=\frac{1}{z}\sum_{x\in e}p^T(x)\tag{12}$$

$$\bar{q}^S(x)=\frac{1}{z}\sum_{x\in e}q^S(x)\tag{13}$$
其中,公式(12)和公式(13)中的z表示老师识别组中老师识别结果的数目和学生识别组中学生识别结果的数目,e表示老师识别组和学生识别组。Among them, z in formula (12) and formula (13) represents the number of teacher recognition results in the teacher recognition group and the number of student recognition results in the student recognition group, and e represents the teacher recognition group and the student recognition group.
The loss function of the student model, L_KL, is then expressed as the following formula (14):

L_KL = D_KL(P′ ‖ Q′) = Σ_{x∈χ} P′(x) · log( P′(x) / Q′(x) )   (14)
Here, χ in formula (14) denotes the probability space in which the attention matrix A ∈ R^{1×l} lies, which in practice is the result distribution of the training sample data, and l denotes the length of the training sample data. Then, for a Transformer (transformation layer) with h attention heads, the total loss over training sample data of length l is expressed as the following formula (15):

L_attn = Σ_{a=1}^{h} Σ_{t=1}^{l} D_KL( P′_{a,t} ‖ Q′_{a,t} )   (15)
Here, t in formula (15) indexes, within the training sample data of length l, a sub-training-sample of length t, and a indexes one of the h attention heads.
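By way of illustration, the computation in formulas (12) to (15) for one Transformer layer could be sketched as follows. This is a hedged reading, not the definitive implementation: it assumes that grouping by the balance parameter C amounts to taking consecutive runs of z results in identification order, and that the sub-distributions are rows of the attention matrices.

    import numpy as np

    def grouped_kl(p_teacher, q_student, z, eps=1e-12):
        # Formulas (12)-(14): average consecutive runs of z identification
        # results (one run per identification group e), then compute the
        # DKL-divergence between the averaged distributions P' and Q'.
        g = len(p_teacher) // z                      # number of complete groups
        p_prime = np.asarray(p_teacher, float)[: g * z].reshape(g, z).mean(axis=1)
        q_prime = np.asarray(q_student, float)[: g * z].reshape(g, z).mean(axis=1)
        return float(np.sum(p_prime * np.log((p_prime + eps) / (q_prime + eps))))

    def transformer_attn_loss(attn_teacher, attn_student, z):
        # Formula (15): sum the grouped DKL-divergence over the h attention
        # heads and the l rows of the attention matrices (shape (h, l, l)).
        h, l, _ = attn_teacher.shape
        return sum(grouped_kl(attn_teacher[a, t], attn_student[a, t], z)
                   for a in range(h) for t in range(l))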
The DKL-divergence in this solution uses P′(x) (a weight parameter generated from the teacher identification probabilities) to weight the logarithm of the ratio of P′(x) to Q′(x) (the logarithm relating the average of each student identification group to the average of the corresponding teacher identification group): the larger the value of P′(x), the larger the resulting DKL-divergence. In other words, the DKL-divergence pays more attention to the high-probability positions of the probability distribution, and there is practical value in correctly matching the genuinely high-probability events of the distribution first, whereas the MSE of the related art attends to all positions of the distribution indiscriminately. The DKL-divergence is therefore better suited than MSE for calculating the loss value of the student model. At the same time, in this solution a balance parameter can be used to eliminate abnormal results (that is, abnormal probabilities) in the teacher identification results and the student identification results corresponding to the training sample data. For example, if one teacher identification result has the value 0.2 while the corresponding student identification result is 0.0001, the calculated logarithm between P′(x) and Q′(x) becomes abnormally large, producing an extreme gradient gap; conversely, such mismatches may also cause the gradient to vanish. Introducing the balance parameter to balance possible abnormal probability values therefore greatly reduces the occurrence of extreme values. In this way, the accuracy and the data processing capability of the student model can be improved, so that the identification result obtained by the student model on data to be processed matches the identification result obtained by the teacher model on the same data; that is, the student model acquires the data processing capability of the teacher model.
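The damping effect of the balance parameter described above can be checked numerically. The pair 0.2 / 0.0001 comes from the example in the text; the neighbouring values and the group size z = 2 are assumptions for illustration.

    import numpy as np

    p = np.array([0.2, 0.15])      # two adjacent teacher identification results
    q = np.array([0.0001, 0.18])   # 0.0001 is an abnormal student identification result

    raw_log = np.log(p[0] / q[0])             # about 7.6: an abnormally large log-ratio
    p_prime, q_prime = p.mean(), q.mean()     # group averages with z = 2
    balanced_log = np.log(p_prime / q_prime)  # about 0.66: the extreme value is damped
    print(raw_log, balanced_log)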
In the embodiments of the present application, training sample data for training a preset student model is acquired, and the preset student model and a preset teacher model are used to identify the training sample data respectively, yielding a teacher identification result and a student identification result for the training sample data. A weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result; the logarithm between the teacher identification result and the student identification result is calculated and weighted using the weight parameter; and the resulting value is used as a loss value to adjust the preset student model. In this way, the adjustment weights that the different identification results and prediction results of the training sample data contribute to adjusting the student model can be allocated reasonably according to the weight parameter, so that identification results with higher probability values receive more attention, which makes the obtained loss value more accurate and the adjusted student model more accurate. At the same time, a balance parameter is introduced to eliminate abnormal results in the teacher identification result and the student identification result of the training sample data, thereby avoiding an inaccurate loss value caused by such abnormal results. Through the present application, the student model can be endowed with the data processing capability of the teacher model, and the accuracy of the student model can be improved.
Please refer to FIG. 5, which is a schematic flowchart of another model distillation method provided by an embodiment of the present application. The method may be performed by a computer device. As shown in FIG. 5, this model distillation method may include steps S201 to S207.

S201: Acquire training sample data for training a preset student model.

S202: Use the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data.

S203: Obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model.

S204: Calculate the logarithm between the teacher identification result and the student identification result, and perform a weighting operation on the logarithm using the weight parameter.

For the specific content of steps S201 to S204 in this embodiment of the present application, reference may be made to the embodiment described in FIG. 1, which is not repeated here.
S205: Verify whether the loss value satisfies a convergence state condition.

After the loss value of the student model is obtained, it is determined whether the loss value satisfies the convergence state condition, where the convergence condition means that the loss value is smaller than a loss threshold preset by the user, or that the loss value equals the minimum value of the corresponding loss function.

Optionally, when verifying whether the loss value satisfies the convergence state condition, the minimum attainable value of the loss function used to calculate the loss value may be obtained; if the loss value differs from this minimum value, it is determined that the loss value does not satisfy the convergence condition. Alternatively, it is verified whether the loss value is smaller than a preset loss threshold; if the loss value is greater than or equal to the preset loss threshold, it is determined that the loss value does not satisfy the convergence condition. The preset loss threshold may be set according to the data processing type of the student model or according to other indicators.
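The two verification options of step S205 could be sketched as follows; the use of a small tolerance for the minimum-value comparison is an assumption, since exact floating-point equality is fragile.

    def satisfies_convergence(loss, loss_fn_minimum=None, preset_threshold=None, tol=1e-8):
        # Option 1: compare the loss value against the minimum value of the loss function.
        if loss_fn_minimum is not None:
            return abs(loss - loss_fn_minimum) <= tol
        # Option 2: compare the loss value against the user-preset loss threshold.
        return loss < preset_threshold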
S206: If the loss value does not satisfy the convergence condition, determine the loss degree to which the loss value belongs.

S207: Adjust the parameters in the preset student model according to the loss degree.

If the loss value does not satisfy the convergence condition, this indicates that the teacher identification result obtained by the teacher model on the training sample data differs substantially from the student identification result obtained by the student model on the same data; that is, the identification result the student model would produce on data to be processed does not match the identification result the teacher model would produce. In that case, the loss degree to which the loss value belongs is determined, and the parameters in the preset student model are adjusted according to that loss degree: the greater the loss degree, the larger the adjustment to the parameters in the preset student model; the smaller the loss degree, the smaller the adjustment. In this way, adjusting the preset student model based on the loss value makes it possible to adjust more aggressively when the student model's error is larger, which speeds up the convergence of the student model and improves training efficiency, while also making the adjustment of the student model more precise and thus improving the accuracy of student model training.
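The embodiments do not fix a concrete mechanism for scaling the adjustment with the loss degree; one hypothetical realization is to scale the update step size, as sketched below with invented thresholds and scale factors.

    def adjusted_step_size(base_step, loss):
        # Map the loss value to a loss degree and scale the parameter update
        # accordingly: the greater the loss degree, the larger the adjustment.
        if loss >= 1.0:         # severe loss degree
            return base_step * 4.0
        if loss >= 0.5:         # moderate loss degree
            return base_step * 2.0
        if loss >= 0.1:         # mild loss degree
            return base_step * 1.0
        return base_step * 0.5  # close to convergence: smallest adjustment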
For the specific content of this embodiment of the present application, reference may be made to the embodiment described in FIG. 1, which is not repeated here.
In the embodiments of the present application, training sample data for training a preset student model is acquired, and the preset student model and a preset teacher model are used to identify the training sample data respectively, yielding a teacher identification result and a student identification result for the training sample data. A weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result; the logarithm between the teacher identification result and the student identification result is calculated and weighted using the weight parameter; and the resulting value is used as a loss value to adjust the preset student model. In this way, the adjustment weights that the different identification results and prediction results of the training sample data contribute to adjusting the student model can be allocated reasonably according to the weight parameter, so that identification results with higher probability values receive more attention, which makes the obtained loss value more accurate and the adjusted student model more accurate. At the same time, a balance parameter is introduced to eliminate abnormal results in the teacher identification result and the student identification result of the training sample data, thereby avoiding an inaccurate loss value caused by such abnormal results. The student model is further adjusted according to the loss degree of the loss value, so that a larger adjustment is made when the error of the student model is greater, thereby improving the accuracy of student model training. Through the present application, the student model can be endowed with the data processing capability of the teacher model, and the accuracy of the student model can be improved.
Please refer to FIG. 6, which is a schematic structural diagram of a model distillation apparatus provided by an embodiment of the present application. The model distillation apparatus may be a computer program (including program code) running in a computer device; for example, the model distillation apparatus is application software. The apparatus may be used to perform the corresponding steps in the methods provided by the embodiments of the present application. As shown in FIG. 6, the model distillation apparatus may include: a first acquisition module 11, an identification module 12, a second acquisition module 13, and an adjustment module 14.
The first acquisition module 11 is configured to acquire training sample data for training a preset student model.

The identification module 12 is configured to use the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model.

The second acquisition module 13 is configured to obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model.

The adjustment module 14 is configured to calculate the logarithm between the teacher identification result and the student identification result, perform a weighting operation on the logarithm using the weight parameter, and adjust the preset student model using the calculated value as a loss value.
There are a plurality of teacher identification results, and the teacher identification results represent identification probabilities.

The second acquisition module 13 includes:

an acquisition unit, configured to acquire a balance parameter for balancing the teacher identification results;

a first grouping unit, configured to group, in the identification order of the preset teacher model, the obtained plurality of teacher identification results according to the balance parameter, to obtain a plurality of sequentially arranged teacher identification groups, wherein each of the teacher identification groups contains the same number of teacher identification results; and

a first calculation unit, configured to calculate the average value of the plurality of teacher identification results in each teacher identification group, and use the obtained plurality of average values as the weight parameters after balancing.
There are a plurality of student identification results.

The adjustment module 14 includes:

a second grouping unit, configured to group, in the identification order, the obtained plurality of student identification results according to the balance parameter, to obtain a plurality of sequentially arranged student identification groups, wherein each of the student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one with a student identification group according to the identification order;

a second calculation unit, configured to calculate the average value of the plurality of student identification results in each student identification group;

a third calculation unit, configured to calculate the logarithm of the average value of each student identification group against the average value of the corresponding teacher identification group, to obtain a plurality of logarithms after balancing; and

a first weighting operation unit, configured to perform a weighting operation on the balanced weight parameters and the balanced logarithms.
The above acquisition unit is specifically configured to:

acquire the number of teacher identification results among the plurality of teacher identification results;

determine the preset threshold range to which the number of teacher identification results belongs; and

determine, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and use the target balance parameter as the balance parameter for balancing the teacher identification results, wherein the balance parameter library includes at least one balance parameter and a correspondence between each balance parameter and a preset threshold range.
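The threshold-range lookup performed by the acquisition unit could be sketched as follows; the ranges and balance parameter values in this library are invented for the example.

    BALANCE_PARAMETER_LIBRARY = [
        # (lower bound, upper bound) of a preset threshold range -> balance parameter
        ((1, 64), 2),
        ((65, 512), 8),
        ((513, 4096), 32),
    ]

    def lookup_balance_parameter(num_teacher_results):
        # Pick the target balance parameter whose preset threshold range
        # contains the number of teacher identification results.
        for (low, high), parameter in BALANCE_PARAMETER_LIBRARY:
            if low <= num_teacher_results <= high:
                return parameter
        raise ValueError("no preset threshold range covers this count")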
The adjustment module 14 includes:

a verification unit, configured to verify whether the loss value satisfies the convergence state condition;

a first determination unit, configured to determine, if the loss value does not satisfy the convergence condition, the loss degree to which the loss value belongs; and

a first adjustment unit, configured to adjust the parameters in the preset student model according to the loss degree.
The above verification unit is specifically configured to:

obtain the minimum attainable value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determine that the loss value does not satisfy the convergence condition; or,

verify whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determine that the loss value does not satisfy the convergence condition.
The preset teacher model includes a plurality of teacher distillation layers, and the preset student model includes a plurality of student distillation layers.

The adjustment module 14 includes:

a second determination unit, configured to determine the teacher distillation layer corresponding to each student distillation layer among the plurality of student distillation layers;

a fourth calculation unit, configured to calculate the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;

a second weighting operation unit, configured to perform, using the weight parameter, a weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain the loss value of each student distillation layer; and

a second adjustment unit, configured to adjust the corresponding student distillation layer in the preset student model using the loss value of each student distillation layer.
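As a sketch of the per-distillation-layer processing above, the following assumes a uniform pairing between student and teacher distillation layers and repeats the grouped-average DKL computation from the earlier sketch; the mapping rule and the group size z are assumptions for illustration.

    import numpy as np

    def per_layer_losses(student_rows, teacher_rows, z, eps=1e-12):
        # student_rows / teacher_rows: one attention row (1-D distribution) per
        # distillation layer; student layer m is paired with teacher layer
        # m * (len(teacher_rows) // len(student_rows)), a uniform mapping.
        step = len(teacher_rows) // len(student_rows)
        losses = []
        for m, q in enumerate(student_rows):
            p = np.asarray(teacher_rows[m * step], float)
            q = np.asarray(q, float)
            g = len(p) // z
            p_prime = p[: g * z].reshape(g, z).mean(axis=1)   # formula (12)
            q_prime = q[: g * z].reshape(g, z).mean(axis=1)   # formula (13)
            losses.append(float(np.sum(p_prime * np.log((p_prime + eps) / (q_prime + eps)))))
        return losses  # one loss value per student distillation layer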
In the embodiments of the present application, training sample data for training a preset student model is acquired, and the preset student model and a preset teacher model are used to identify the training sample data respectively, yielding a teacher identification result and a student identification result for the training sample data. A weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result; the logarithm between the teacher identification result and the student identification result is calculated and weighted using the weight parameter; and the resulting value is used as a loss value to adjust the preset student model. In this way, the adjustment weights that the different identification results and prediction results of the training sample data contribute to adjusting the student model can be allocated reasonably according to the weight parameter, so that identification results with higher probability values receive more attention, which makes the obtained loss value more accurate and the adjusted student model more accurate. At the same time, a balance parameter is introduced to eliminate abnormal results in the teacher identification result and the student identification result of the training sample data, thereby avoiding an inaccurate loss value caused by such abnormal results. The student model is further adjusted according to the loss degree of the loss value, so that a larger adjustment is made when the error of the student model is greater, thereby improving the accuracy of student model training. Through the present application, the student model can be endowed with the data processing capability of the teacher model, and the accuracy of the student model can be improved.
According to an embodiment of the present application, the steps involved in the model distillation method shown in FIG. 1 or FIG. 5 may be performed by the respective modules of the model distillation apparatus shown in FIG. 6. For example, step S101 shown in FIG. 1 may be performed by the first acquisition module 11 in FIG. 6; step S102 shown in FIG. 1 may be performed by the identification module 12 in FIG. 6; step S103 shown in FIG. 1 may be performed by the second acquisition module 13 in FIG. 6; and step S104 shown in FIG. 1 may be performed by the adjustment module 14 in FIG. 6.
Please refer to FIG. 7, which is a schematic structural diagram of a computer device provided by an embodiment of the present application. The computer device may include a processor and a memory. Optionally, the computer device may further include a network interface and/or a user interface. For example, as shown in FIG. 7, the computer device 1000 may include a processor 1001, a network interface 1004 and a memory 1005; in addition, the computer device 1000 may further include a user interface 1003 and at least one communication bus 1002. The communication bus 1002 is used to implement connection and communication between these components. The user interface 1003 may include a display (Display) and a keyboard (Keyboard), and optionally may further include standard wired and wireless interfaces. The network interface 1004 may optionally include a standard wired interface and a wireless interface (such as a WI-FI interface). The memory 1005 may be a high-speed RAM memory or a non-volatile memory, for example at least one disk memory. Optionally, the memory 1005 may also be at least one storage device located remotely from the processor 1001. As shown in FIG. 7, the memory 1005, as a computer-readable storage medium, may include an operating system, a network communication module, a user interface module, and a device control application program.
In the computer device 1000 shown in FIG. 7, the network interface 1004 can provide a network communication function, the user interface 1003 is mainly used to provide an input interface for the user, and the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement the following:

acquiring training sample data for training a preset student model;

using the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement:

acquiring a balance parameter for balancing the teacher identification results;

grouping, in the identification order of the preset teacher model, the obtained plurality of teacher identification results according to the balance parameter, to obtain a plurality of sequentially arranged teacher identification groups, wherein each of the teacher identification groups contains the same number of teacher identification results; and

calculating the average value of the plurality of teacher identification results in each teacher identification group, and using the obtained plurality of average values as the weight parameters after balancing.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement:

grouping, in the identification order, the obtained plurality of student identification results according to the balance parameter, to obtain a plurality of sequentially arranged student identification groups, wherein each of the student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one with a student identification group according to the identification order;

calculating the average value of the plurality of student identification results in each student identification group;

calculating the logarithm of the average value of each student identification group against the average value of the corresponding teacher identification group, to obtain a plurality of logarithms after balancing; and

performing a weighting operation on the balanced weight parameters and the balanced logarithms.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement:

acquiring the number of teacher identification results among the plurality of teacher identification results;

determining the preset threshold range to which the number of teacher identification results belongs; and

determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification results, wherein the balance parameter library includes at least one balance parameter and a correspondence between each of the at least one balance parameter and a preset threshold range.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement:

verifying whether the loss value satisfies the convergence state condition;

if the loss value does not satisfy the convergence condition, determining the loss degree to which the loss value belongs; and

adjusting the parameters in the preset student model according to the loss degree.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement:

obtaining the minimum attainable value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determining that the loss value does not satisfy the convergence condition; or,

verifying whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
Optionally, the processor 1001 can be used to invoke the device control application program stored in the memory 1005 to implement the following, as part of calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as a loss value:

determining the teacher distillation layer corresponding to each student distillation layer among the plurality of student distillation layers;

calculating the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;

performing, using the weight parameter, a weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain the loss value of each student distillation layer; and

adjusting the corresponding student distillation layer in the preset student model using the loss value of each student distillation layer.
In the embodiments of the present application, training sample data for training a preset student model is acquired, and the preset student model and a preset teacher model are used to identify the training sample data respectively, yielding a teacher identification result and a student identification result for the training sample data. A weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result; the logarithm between the teacher identification result and the student identification result is calculated and weighted using the weight parameter; and the resulting value is used as a loss value to adjust the preset student model. In this way, the adjustment weights that the different identification results and prediction results of the training sample data contribute to adjusting the student model can be allocated reasonably according to the weight parameter, so that identification results with higher probability values receive more attention, which makes the obtained loss value more accurate and the adjusted student model more accurate. At the same time, a balance parameter is introduced to eliminate abnormal results in the teacher identification result and the student identification result of the training sample data, thereby avoiding an inaccurate loss value caused by such abnormal results. The student model is further adjusted according to the loss degree of the loss value, so that a larger adjustment is made when the error of the student model is greater, thereby improving the accuracy of student model training. Through the present application, the student model can be endowed with the data processing capability of the teacher model, and the accuracy of the student model can be improved.
It should be understood that the computer device 1000 described in this embodiment of the present application can carry out the description of the above model distillation method in the embodiments corresponding to FIG. 1 and FIG. 5, and can also carry out the description of the above model distillation apparatus in the embodiment corresponding to FIG. 6, which is not repeated here. In addition, the description of the beneficial effects of using the same method is also not repeated.
Furthermore, it should be pointed out that an embodiment of the present application also provides a computer-readable storage medium, in which the computer program executed by the aforementioned model distillation apparatus is stored; the computer program includes program instructions, and when the processor executes the program instructions, it can carry out the description of the above model distillation method in the embodiment corresponding to FIG. 1 or FIG. 5, which is therefore not repeated here. In addition, the description of the beneficial effects of using the same method is also not repeated. For technical details not disclosed in the computer-readable storage medium embodiments involved in the present application, please refer to the description of the method embodiments of the present application.
Optionally, the storage medium involved in the present application, such as the computer-readable storage medium, may be non-volatile or volatile.
As an example, the above program instructions may be deployed and executed on one computer device, or executed on multiple computer devices located at one site, or on multiple computer devices distributed across multiple sites and interconnected by a communication network; multiple computer devices distributed across multiple sites and interconnected by a communication network can form a blockchain network.
A person of ordinary skill in the art can understand that all or part of the processes in the methods of the above embodiments can be implemented by instructing the relevant hardware through a computer program; the program may be stored in a computer-readable storage medium, and when executed, may include the processes of the embodiments of the above methods. The storage medium may be a magnetic disk, an optical disc, a read-only memory (Read-Only Memory, ROM), a random access memory (Random Access Memory, RAM), or the like.
The above disclosure is merely a preferred embodiment of the present application and certainly cannot be used to limit the scope of the claims of the present application; therefore, equivalent changes made according to the claims of the present application still fall within the scope covered by the present application.

Claims (20)

1. A model distillation method, comprising:

acquiring training sample data for training a preset student model;

using the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
2. The method according to claim 1, wherein there are a plurality of teacher identification results, and the teacher identification results represent identification probabilities;

the obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model comprises:

acquiring a balance parameter for balancing the teacher identification results;

grouping, in the identification order of the preset teacher model, the obtained plurality of teacher identification results according to the balance parameter, to obtain a plurality of sequentially arranged teacher identification groups, wherein each of the teacher identification groups contains the same number of teacher identification results; and

calculating the average value of the plurality of teacher identification results in each teacher identification group, and using the obtained plurality of average values as the weight parameters after balancing.
3. The method according to claim 2, wherein there are a plurality of student identification results;

the calculating the logarithm between the teacher identification result and the student identification result and performing a weighting operation on the logarithm using the weight parameter comprises:

grouping, in the identification order, the obtained plurality of student identification results according to the balance parameter, to obtain a plurality of sequentially arranged student identification groups, wherein each of the student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one with a student identification group according to the identification order;

calculating the average value of the plurality of student identification results in each student identification group;

calculating the logarithm of the average value of each student identification group against the average value of the corresponding teacher identification group, to obtain a plurality of logarithms after balancing; and

performing a weighting operation on the balanced weight parameters and the balanced logarithms.
4. The method according to claim 2, wherein the acquiring a balance parameter for balancing the teacher identification results comprises:

acquiring the number of teacher identification results among the plurality of teacher identification results;

determining the preset threshold range to which the number of teacher identification results belongs; and

determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification results, wherein the balance parameter library includes at least one balance parameter and a correspondence between each of the at least one balance parameter and a preset threshold range.
5. The method according to claim 1, wherein the adjusting the preset student model using the calculated value as a loss value comprises:

verifying whether the loss value satisfies a convergence state condition;

if the loss value does not satisfy the convergence condition, determining the loss degree to which the loss value belongs; and

adjusting the parameters in the preset student model according to the loss degree.
6. The method according to claim 5, wherein the verifying whether the loss value satisfies a convergence state condition comprises:

obtaining the minimum attainable value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determining that the loss value does not satisfy the convergence condition; or,

verifying whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
7. The method according to claim 1, wherein the preset teacher model includes a plurality of teacher distillation layers, and the preset student model includes a plurality of student distillation layers;

the calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as a loss value comprises:

determining the teacher distillation layer corresponding to each student distillation layer among the plurality of student distillation layers;

calculating the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;

performing, using the weight parameter, a weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain a loss value of each student distillation layer; and

adjusting the corresponding student distillation layer in the preset student model using the loss value of each student distillation layer.
8. A model distillation apparatus, comprising:

a first acquisition module, configured to acquire training sample data for training a preset student model;

an identification module, configured to use the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

a second acquisition module, configured to obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and

an adjustment module, configured to calculate the logarithm between the teacher identification result and the student identification result, perform a weighting operation on the logarithm using the weight parameter, and adjust the preset student model using the calculated value as a loss value.
9. A computer device, comprising a processor and a memory;

wherein the memory is configured to store program code, and the processor is configured to invoke the program code to perform the following method:

acquiring training sample data for training a preset student model;

using the preset student model and a preset teacher model to identify the training sample data respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;

obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and

calculating the logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
10. The computer device according to claim 9, wherein there are a plurality of teacher identification results, and the teacher identification results represent identification probabilities;

performing the obtaining, from the teacher identification result, of the weight parameter for adjusting the identification result of the preset student model comprises:

acquiring a balance parameter for balancing the teacher identification results;

grouping, in the identification order of the preset teacher model, the obtained plurality of teacher identification results according to the balance parameter, to obtain a plurality of sequentially arranged teacher identification groups, wherein each of the teacher identification groups contains the same number of teacher identification results; and

calculating the average value of the plurality of teacher identification results in each teacher identification group, and using the obtained plurality of average values as the weight parameters after balancing.
11. The computer device according to claim 10, wherein there are a plurality of student identification results;

performing the calculating of the logarithm between the teacher identification result and the student identification result and the weighting operation on the logarithm using the weight parameter comprises:

grouping, in the identification order, the obtained plurality of student identification results according to the balance parameter, to obtain a plurality of sequentially arranged student identification groups, wherein each of the student identification groups contains the same number of student identification results, and each teacher identification group corresponds one-to-one with a student identification group according to the identification order;

calculating the average value of the plurality of student identification results in each student identification group;

calculating the logarithm of the average value of each student identification group against the average value of the corresponding teacher identification group, to obtain a plurality of logarithms after balancing; and

performing a weighting operation on the balanced weight parameters and the balanced logarithms.
  12. The computer device according to claim 10, wherein the obtaining a balance parameter for balancing the teacher identification results comprises:
    obtaining the number of the teacher identification results;
    determining a preset threshold range to which the number of the teacher identification results belongs; and
    determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification results, wherein the balance parameter library comprises at least one balance parameter and a correspondence between each of the at least one balance parameter and a preset threshold range.
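A sketch of the balance parameter library of claim 12 (mirrored in claim 18). The threshold ranges and balance values below are invented for illustration only; the claim requires merely that each stored balance parameter correspond to a preset threshold range on the number of teacher identification results:

```python
# Hypothetical balance parameter library: each entry maps a preset
# threshold range on the result count to a balance parameter. All
# ranges and values here are invented for illustration.
BALANCE_LIBRARY = [
    ((1, 100), 2),
    ((101, 1000), 5),
    ((1001, 10000), 10),
]

def select_balance_parameter(num_teacher_results):
    for (low, high), balance in BALANCE_LIBRARY:
        if low <= num_teacher_results <= high:
            return balance  # target balance parameter for this range
    raise ValueError(f"no preset threshold range covers {num_teacher_results}")
```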
  13. The computer device according to claim 9, wherein the adjusting the preset student model by using the calculated value as a loss value comprises:
    verifying whether the loss value satisfies a convergence condition;
    if the loss value does not satisfy the convergence condition, determining a loss degree to which the loss value belongs; and
    adjusting parameters in the preset student model according to the loss degree.
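The convergence check and degree-based adjustment of claim 13 (mirrored in claim 19) might look like the following sketch; the convergence threshold, the two-level loss degree, and the gradient-style update are all assumptions, since the claim leaves these mappings open:

```python
def adjust_student(params, grads, loss, convergence_threshold=1e-3):
    # Verify whether the loss value satisfies the convergence condition
    # (a simple threshold test is assumed).
    if loss < convergence_threshold:
        return params  # converged: no adjustment needed
    # Map the loss value to a loss degree, and the degree to a step
    # size; this two-level mapping is an illustrative assumption.
    step = 0.1 if loss > 1.0 else 0.01
    # Adjust the parameters of the preset student model accordingly.
    return {name: value - step * grads[name] for name, value in params.items()}
```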
  14. The computer device according to claim 9, wherein the preset teacher model comprises a plurality of teacher distillation layers, and the preset student model comprises a plurality of student distillation layers;
    the calculating a logarithm between the teacher identification result and the student identification result, performing a weighted operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated value as a loss value comprises:
    determining, for each of the plurality of student distillation layers, a corresponding teacher distillation layer;
    calculating a logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;
    performing, by using the weight parameter, a weighted operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain a loss value of each student distillation layer; and
    adjusting the corresponding student distillation layer in the preset student model by using the loss value of each student distillation layer.
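A sketch of the layer-wise variant of claim 14 (mirrored in claim 20), assuming the student and teacher distillation layers correspond one-to-one by index and that the weight parameter broadcasts against each layer's identification results:

```python
import numpy as np

def layerwise_losses(teacher_layer_results, student_layer_results,
                     weights, eps=1e-12):
    w = np.asarray(weights, dtype=float)
    losses = []
    # Pairing by index assumes each student distillation layer has been
    # matched to its corresponding teacher distillation layer.
    for t_out, s_out in zip(teacher_layer_results, student_layer_results):
        logs = np.log((np.asarray(t_out, dtype=float) + eps) /
                      (np.asarray(s_out, dtype=float) + eps))
        # The weighted operation gives one loss value per student layer,
        # used to adjust that layer alone.
        losses.append(float(np.sum(w * logs)))
    return losses
```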
  15. A computer-readable storage medium, wherein the computer-readable storage medium stores a computer program, the computer program comprises program instructions, and the program instructions, when executed by a processor, perform the following method:
    obtaining training sample data for training a preset student model;
    identifying the training sample data by using the preset student model and a preset teacher model, respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained through training guided by the preset teacher model;
    obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; and
    calculating a logarithm between the teacher identification result and the student identification result, performing a weighted operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated value as a loss value.
  16. The computer-readable storage medium according to claim 15, wherein there are a plurality of teacher identification results, and each teacher identification result represents an identification probability;
    the obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model comprises:
    obtaining a balance parameter for balancing the teacher identification results;
    grouping, in the identification order of the preset teacher model, the plurality of teacher identification results according to the balance parameter, to obtain a plurality of sequentially arranged teacher identification groups, wherein each teacher identification group contains the same number of teacher identification results; and
    calculating an average value of the teacher identification results in each teacher identification group, and using the obtained average values as balanced weight parameters.
  17. The computer-readable storage medium according to claim 16, wherein there are a plurality of student identification results;
    the calculating a logarithm between the teacher identification result and the student identification result, and performing a weighted operation on the logarithm by using the weight parameter comprises:
    grouping, in the identification order, the plurality of student identification results according to the balance parameter, to obtain a plurality of sequentially arranged student identification groups, wherein each student identification group contains the same number of student identification results, and the teacher identification groups correspond one-to-one to the student identification groups in the identification order;
    calculating an average value of the student identification results in each student identification group;
    calculating a logarithm of the average value of each student identification group and the average value of the corresponding teacher identification group, to obtain a plurality of balanced logarithms; and
    performing a weighted operation on the balanced logarithms by using the balanced weight parameters.
  18. The computer-readable storage medium according to claim 16, wherein the obtaining a balance parameter for balancing the teacher identification results comprises:
    obtaining the number of the teacher identification results;
    determining a preset threshold range to which the number of the teacher identification results belongs; and
    determining, from a balance parameter library, a target balance parameter corresponding to the preset threshold range, and using the target balance parameter as the balance parameter for balancing the teacher identification results, wherein the balance parameter library comprises at least one balance parameter and a correspondence between each of the at least one balance parameter and a preset threshold range.
  19. The computer-readable storage medium according to claim 15, wherein the adjusting the preset student model by using the calculated value as a loss value comprises:
    verifying whether the loss value satisfies a convergence condition;
    if the loss value does not satisfy the convergence condition, determining a loss degree to which the loss value belongs; and
    adjusting parameters in the preset student model according to the loss degree.
  20. The computer-readable storage medium according to claim 15, wherein the preset teacher model comprises a plurality of teacher distillation layers, and the preset student model comprises a plurality of student distillation layers;
    the calculating a logarithm between the teacher identification result and the student identification result, performing a weighted operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated value as a loss value comprises:
    determining, for each of the plurality of student distillation layers, a corresponding teacher distillation layer;
    calculating a logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer;
    performing, by using the weight parameter, a weighted operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer, to obtain a loss value of each student distillation layer; and
    adjusting the corresponding student distillation layer in the preset student model by using the loss value of each student distillation layer.
PCT/CN2021/096649 2020-11-20 2021-05-28 Model distillation method and apparatus, and storage medium and device WO2022105173A1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202011313330.8 2020-11-20
CN202011313330.8A CN112465138A (en) 2020-11-20 2020-11-20 Model distillation method, device, storage medium and equipment

Publications (1)

Publication Number Publication Date
WO2022105173A1 true WO2022105173A1 (en) 2022-05-27

Family

ID=74798380

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/096649 WO2022105173A1 (en) 2020-11-20 2021-05-28 Model distillation method and apparatus, and storage medium and device

Country Status (2)

Country Link
CN (1) CN112465138A (en)
WO (1) WO2022105173A1 (en)

Families Citing this family (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112465138A (en) * 2020-11-20 2021-03-09 平安科技(深圳)有限公司 Model distillation method, device, storage medium and equipment
JP7381814B2 (en) * 2020-12-15 2023-11-16 之江実験室 Automatic compression method and platform for pre-trained language models for multitasking
CN112990296B (en) * 2021-03-10 2022-10-11 中科人工智能创新技术研究院(青岛)有限公司 Image-text matching model compression and acceleration method and system based on orthogonal similarity distillation
US11200497B1 (en) * 2021-03-16 2021-12-14 Moffett Technologies Co., Limited System and method for knowledge-preserving neural network pruning
CN113239176B (en) * 2021-06-21 2022-08-23 中国平安人寿保险股份有限公司 Semantic matching model training method, device, equipment and storage medium
CN113807214B (en) * 2021-08-31 2024-01-05 中国科学院上海微系统与信息技术研究所 Small target face recognition method based on deit affiliated network knowledge distillation
CN114565759A (en) * 2022-02-22 2022-05-31 北京百度网讯科技有限公司 Image semantic segmentation model optimization method and device, electronic equipment and storage medium
CN115019060A (en) * 2022-07-12 2022-09-06 北京百度网讯科技有限公司 Target recognition method, and training method and device of target recognition model

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111105008A (en) * 2018-10-29 2020-05-05 富士通株式会社 Model training method, data recognition method and data recognition device
EP3736749A1 (en) * 2019-05-09 2020-11-11 Siemens Aktiengesellschaft Method and device for controlling a device using a dataset
CN110852426A (en) * 2019-11-19 2020-02-28 成都晓多科技有限公司 Pre-training model integration acceleration method and device based on knowledge distillation
CN110909815A (en) * 2019-11-29 2020-03-24 深圳市商汤科技有限公司 Neural network training method, neural network training device, neural network processing device, neural network training device, image processing device and electronic equipment
CN112465138A (en) * 2020-11-20 2021-03-09 平安科技(深圳)有限公司 Model distillation method, device, storage medium and equipment

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115099988A (en) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 Model training method, data processing method, device and computer medium
CN115170455A (en) * 2022-08-17 2022-10-11 荣耀终端有限公司 Image processing method and related device
CN115170455B (en) * 2022-08-17 2023-02-07 荣耀终端有限公司 Image processing method and related device

Also Published As

Publication number Publication date
CN112465138A (en) 2021-03-09

Similar Documents

Publication Publication Date Title
WO2022105173A1 (en) Model distillation method and apparatus, and storage medium and device
US20230100376A1 (en) Text sentence processing method and apparatus, computer device, and storage medium
US11081105B2 (en) Model learning device, method and recording medium for learning neural network model
EP2727103B1 (en) Speech recognition using variable-length context
US11240121B2 (en) Methods and systems for controlling data backup
WO2022116441A1 (en) Bert model fine-tuning method and apparatus based on convolutional neural network
CN110929515A (en) Reading understanding method and system based on cooperative attention and adaptive adjustment
WO2020160252A1 (en) Task-aware neural network architecture search
CN111400470A (en) Question processing method and device, computer equipment and storage medium
CN112101010B (en) Telecom industry OA office automation manuscript auditing method based on BERT
US11380301B2 (en) Learning apparatus, speech recognition rank estimating apparatus, methods thereof, and program
CN108052625A (en) A kind of entity sophisticated category method
CN114169442A (en) Remote sensing image small sample scene classification method based on double prototype network
JP2022529268A (en) Voice recognition methods and devices
CN111667069A (en) Pre-training model compression method and device and electronic equipment
WO2020154373A1 (en) Neural network training using the soft nearest neighbor loss
CN117539977A (en) Training method and device for language model
WO2020216286A1 (en) Method for training teaching style prediction model, and computer storage medium
CN116361655A (en) Model training method, standard problem prediction method, device, equipment and medium
CN113555005B (en) Model training method, model training device, confidence determining method, confidence determining device, electronic equipment and storage medium
CN110245331A (en) A kind of sentence conversion method, device, server and computer storage medium
CN114913871A (en) Target object classification method, system, electronic device and storage medium
CN116266266B (en) Multi-tone word disambiguation method, device, equipment and storage medium
Chen The Prediction of English Online Network Performance Based on the XGBoost Algorithm
CN116737888B (en) Training method of dialogue generation model and method and device for determining reply text

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21893328

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21893328

Country of ref document: EP

Kind code of ref document: A1