CN112465138A - Model distillation method, device, storage medium and equipment - Google Patents


Info

Publication number
CN112465138A
Authority
CN
China
Prior art keywords
teacher
student
model
identification
preset
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011313330.8A
Other languages
Chinese (zh)
Inventor
吴天博
王健宗
程宁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Application filed by Ping An Technology Shenzhen Co Ltd
Priority to CN202011313330.8A
Publication of CN112465138A
Priority to PCT/CN2021/096649 (WO2022105173A1)
Pending (current legal status)


Classifications

    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06N: COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00: Computing arrangements based on biological models
    • G06N3/02: Neural networks
    • G06N3/08: Learning methods
    • G06N5/00: Computing arrangements using knowledge-based models
    • G06N5/02: Knowledge representation; Symbolic representation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

The embodiments of the application disclose a model distillation method, a model distillation device, a storage medium and equipment. The method comprises: obtaining training sample data for training a preset student model, and recognizing the training sample data with a preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data; obtaining, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model; calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model using the calculated value as a loss value. With the method and device, the student model can acquire the data processing capability of the teacher model, and the accuracy of the student model is improved.

Description

Model distillation method, device, storage medium and equipment
Technical Field
The present application relates to the field of computer technologies, and in particular, to a method, an apparatus, a storage medium, and a device for model distillation.
Background
Model distillation has attracted attention in recent years as an important technical approach to model compression and acceleration, and has played an important role in advancing the field of natural language processing. Model distillation (knowledge distillation) refers to using a teacher model with higher accuracy and a more complex structure to guide the training of a student model with lower accuracy and a simpler structure, so as to improve the accuracy of the student model.
Although the student model can learn knowledge from the teacher model and thereby improve its accuracy, a gap still exists between the teacher model and the student model in existing distillation architectures, which results in poor expressiveness and low accuracy of the student model.
Disclosure of Invention
The technical problem to be solved by the embodiments of the present application is to provide a model distillation method, device, storage medium, and apparatus, which can improve the accuracy and data processing capability of a student model.
In one aspect, the present embodiments provide a model distillation method, including:
acquiring training sample data for training a preset student model;
recognizing the training sample data with a preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is trained under the guidance of the preset teacher model;
acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
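The four steps above can be illustrated with a minimal Python sketch of the ungrouped case (no balance processing). The source does not define "the logarithm between" the two results; this sketch assumes it means the per-class difference log(p_t) - log(p_s), and that the weight parameter is the teacher probability itself. Under those assumptions the loss equals the Kullback-Leibler divergence KL(teacher ‖ student). The function name `weighted_log_loss` and the `eps` guard are hypothetical.

```python
import math
from typing import Sequence


def weighted_log_loss(teacher: Sequence[float], student: Sequence[float],
                      weights: Sequence[float], eps: float = 1e-12) -> float:
    """Sum of weight * (log p_teacher - log p_student) over all classes.

    teacher/student: class probabilities output by the two models for one sample.
    weights: one weight per class (assumed here to be the teacher probabilities).
    eps: small constant guarding against log(0).
    """
    if not (len(teacher) == len(student) == len(weights)):
        raise ValueError("teacher, student and weights must have equal length")
    return sum(w * (math.log(t + eps) - math.log(s + eps))
               for t, s, w in zip(teacher, student, weights))


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
# Weighting by the teacher probabilities gives KL(teacher || student).
print(weighted_log_loss(teacher, student, weights=teacher))
```

The loss is zero when the student reproduces the teacher exactly and grows as the student's probabilities drift away from the teacher's high-probability classes.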
The teacher identification result represents identification probability;
the acquiring, according to the teacher recognition result, of a weight parameter for adjusting the recognition result of the preset student model includes:
acquiring balance parameters for balancing the teacher identification result;
according to the recognition sequence of the preset teacher model, grouping the obtained plurality of teacher recognition results according to the balance parameters to obtain a plurality of teacher recognition groups which are sequentially arranged, wherein each teacher recognition group in the plurality of teacher recognition groups comprises the same number of teacher recognition results;
and respectively calculating the average value of the plurality of teacher identification results in each teacher identification group, and taking the obtained plurality of average values as the weight parameters after the balance processing.
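The grouping and averaging steps above can be sketched in Python as follows. This is an illustration under stated assumptions, not the applicant's implementation: the function names are hypothetical, the balance parameter is called `c`, and because the source does not say how a remainder is handled when the number of results is not a multiple of the balance parameter, the sketch simply rejects that case.

```python
from typing import List, Sequence


def group_in_order(results: Sequence[float], c: int) -> List[List[float]]:
    """Split results, kept in recognition order, into consecutive groups of size c."""
    if c < 1:
        raise ValueError("balance parameter must be a positive integer")
    if len(results) % c != 0:
        # Equal-sized groups are required; remainder handling is unspecified.
        raise ValueError("number of results must be divisible by the balance parameter")
    return [list(results[i:i + c]) for i in range(0, len(results), c)]


def balanced_weights(teacher_results: Sequence[float], c: int) -> List[float]:
    """Average each teacher group; the means serve as the balanced weight parameters."""
    return [sum(group) / len(group) for group in group_in_order(teacher_results, c)]


teacher_probs = [0.05, 0.15, 0.60, 0.10, 0.02, 0.08]
print(balanced_weights(teacher_probs, c=2))  # three weights, one per group of two
```

With `c=1` every teacher result is its own group and the weights are the raw teacher probabilities; larger values of `c` dilute any single extreme result by a factor of `c` within its group.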
Wherein the number of the student identification results is multiple;
the calculating a logarithm between the teacher identification result and the student identification result, and performing a weighting operation on the logarithm by using the weight parameter includes:
according to the identification sequence, grouping the obtained student identification results according to the balance parameters to obtain a plurality of student identification groups which are sequentially arranged, wherein each student identification group in the student identification groups contains the same number of student identification results, and each teacher identification group corresponds to each student identification group one by one according to the identification sequence;
respectively calculating the average value of the identification results of a plurality of students in each student identification group;
respectively calculating the logarithm of the average value of each student identification group and the logarithm of the average value of the corresponding teacher identification group to obtain a plurality of logarithms after balance processing;
and weighting the logarithms after the balance processing with the weight parameters after the balance processing.
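The balanced loss described in the steps above can be sketched as follows. The sketch assumes, since the source does not define it, that "the logarithm between" a student group and its teacher group means log(teacher mean) - log(student mean), and that the k-th student group is paired with the k-th teacher group by position in the recognition order. The names `group_means` and `balanced_distillation_loss` are hypothetical.

```python
import math
from typing import List, Sequence


def group_means(results: Sequence[float], c: int) -> List[float]:
    """Split results (in recognition order) into consecutive groups of c and average each."""
    if c < 1 or len(results) % c != 0:
        # Equal-sized groups are required; remainder handling is not specified.
        raise ValueError("results must split into equal groups of size c >= 1")
    return [sum(results[i:i + c]) / c for i in range(0, len(results), c)]


def balanced_distillation_loss(teacher: Sequence[float], student: Sequence[float],
                               c: int, eps: float = 1e-12) -> float:
    """Weighted sum of per-group log differences, weighted by the teacher group means."""
    if len(teacher) != len(student):
        raise ValueError("teacher and student must have the same number of results")
    t_means = group_means(teacher, c)  # balanced weight parameters
    s_means = group_means(student, c)  # k-th student group pairs with k-th teacher group
    # Assumed reading of "logarithm between": log(teacher mean) - log(student mean).
    balanced_logs = [math.log(t + eps) - math.log(s + eps)
                     for t, s in zip(t_means, s_means)]
    return sum(w * lg for w, lg in zip(t_means, balanced_logs))


teacher = [0.05, 0.15, 0.60, 0.10, 0.02, 0.08]
student = [0.10, 0.10, 0.40, 0.20, 0.10, 0.10]
print(balanced_distillation_loss(teacher, student, c=2))
```

With `c=1` this reduces to an ungrouped teacher-weighted log difference, which under the same assumption is the KL divergence between teacher and student outputs.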
Wherein the obtaining of the balance parameter for balancing the teacher identification result includes:
acquiring the number of teacher identification results in the plurality of teacher identification results;
determining a preset threshold range to which the number of teacher identification results in the plurality of teacher identification results belongs;
and determining a target balance parameter corresponding to the preset threshold range from a balance parameter library, and taking the target balance parameter as a balance parameter for balancing the teacher identification result, wherein the balance parameter library comprises at least one balance parameter and a corresponding relation between each balance parameter in the at least one balance parameter and the preset threshold range.
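The lookup of a balance parameter from a balance parameter library can be sketched as a table of threshold ranges. The ranges and balance parameter values below are purely illustrative assumptions; the source does not give any concrete thresholds or values, and `BALANCE_LIBRARY` and `lookup_balance_parameter` are hypothetical names.

```python
from typing import List, Tuple

# Hypothetical balance parameter library:
# (lower bound inclusive, upper bound exclusive, balance parameter C).
BALANCE_LIBRARY: List[Tuple[int, int, int]] = [
    (1, 100, 1),
    (100, 1000, 2),
    (1000, 10_000, 4),
]


def lookup_balance_parameter(num_results: int,
                             library: List[Tuple[int, int, int]] = BALANCE_LIBRARY) -> int:
    """Return the balance parameter whose preset threshold range contains num_results."""
    for low, high, c in library:
        if low <= num_results < high:
            return c
    raise LookupError(f"no preset threshold range contains {num_results}")


print(lookup_balance_parameter(500))  # falls in the hypothetical [100, 1000) range
```

Keeping the ranges half-open avoids ambiguity at the boundaries, so every count maps to at most one balance parameter.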
Wherein, the adjusting the preset student model by using the calculated numerical value as a loss value comprises the following steps:
verifying whether the loss value satisfies a convergence condition;
if the loss value does not satisfy the convergence condition, determining the loss degree to which the loss value belongs;
and adjusting parameters in the preset student model according to the loss degree.
Wherein the verifying whether the loss value satisfies the convergence condition comprises:
obtaining a minimum value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determining that the loss value does not satisfy the convergence condition; alternatively,
verifying whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
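The convergence check and the adjustment by loss degree can be sketched as follows. Both convergence criteria follow the text above; the tolerance used when comparing against the minimum, the three loss degrees, their thresholds, and the mapping from loss degree to a learning-rate multiplier are all hypothetical, since the source does not define what a "loss degree" is or how it scales the adjustment.

```python
from typing import Optional

# Hypothetical learning-rate multipliers per loss degree (not given in the source).
DEGREE_SCALE = {"high": 1.0, "medium": 0.5, "low": 0.1}


def is_converged(loss: float, loss_threshold: Optional[float] = None,
                 loss_minimum: Optional[float] = None, tol: float = 1e-9) -> bool:
    """Check either convergence criterion described in the text."""
    if loss_minimum is not None:
        # Criterion 1: the loss equals the minimum of the loss function (within an assumed tol).
        return abs(loss - loss_minimum) <= tol
    if loss_threshold is not None:
        # Criterion 2: the loss is strictly below the preset loss threshold.
        return loss < loss_threshold
    raise ValueError("provide loss_minimum or loss_threshold")


def loss_degree(loss: float) -> str:
    """Bucket a loss value into a hypothetical degree."""
    if loss >= 1.0:
        return "high"
    if loss >= 0.1:
        return "medium"
    return "low"


def step_size(loss: float, base_lr: float, loss_threshold: float) -> float:
    """Return 0.0 when converged, else a learning rate scaled by the loss degree."""
    if is_converged(loss, loss_threshold=loss_threshold):
        return 0.0
    return base_lr * DEGREE_SCALE[loss_degree(loss)]


print(step_size(2.0, base_lr=0.01, loss_threshold=0.05))
```

Scaling the step by loss degree lets a student far from the teacher move quickly while a nearly converged student is adjusted gently; other mappings are equally compatible with the text.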
The preset teacher model comprises a plurality of teacher distillation layers, and the preset student model comprises a plurality of student distillation layers;
the calculating a logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by using a calculated value as a loss value, includes:
determining a teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers;
calculating the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer;
performing weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer by using the weight parameters to obtain the loss value of each student distillation layer;
and respectively adopting the loss value of each student distillation layer to adjust the corresponding student distillation layer in the preset student model.
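The per-layer variant above can be sketched as a mapping from each student distillation layer to a teacher distillation layer, with one loss computed per student layer. The layer names, the mapping, and the reading of "logarithm between" as log(p_t) - log(p_s) weighted by the teacher probabilities are assumptions for illustration; the source does not specify how the layer correspondence is chosen.

```python
import math
from typing import Dict, Sequence


def layer_losses(student_layers: Dict[str, Sequence[float]],
                 teacher_layers: Dict[str, Sequence[float]],
                 layer_map: Dict[str, str],
                 eps: float = 1e-12) -> Dict[str, float]:
    """Compute one loss per student distillation layer against its mapped teacher layer."""
    losses: Dict[str, float] = {}
    for s_name, t_name in layer_map.items():
        s_out, t_out = student_layers[s_name], teacher_layers[t_name]
        if len(s_out) != len(t_out):
            raise ValueError(f"{s_name} and {t_name} output sizes differ")
        # Weight = teacher probability; log term = log(p_t) - log(p_s) (assumed reading).
        losses[s_name] = sum(t * (math.log(t + eps) - math.log(s + eps))
                             for t, s in zip(t_out, s_out))
    return losses


teacher_outs = {"t1": [0.7, 0.2, 0.1], "t2": [0.6, 0.3, 0.1],
                "t3": [0.5, 0.4, 0.1], "t4": [0.9, 0.05, 0.05]}
student_outs = {"s1": [0.6, 0.3, 0.1], "s2": [0.5, 0.3, 0.2]}
# Hypothetical mapping: each of 2 student layers follows every second teacher layer.
print(layer_losses(student_outs, teacher_outs, {"s1": "t2", "s2": "t4"}))
```

Each entry of the returned dictionary would then drive the update of its own student distillation layer, rather than a single loss being applied to the whole student model.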
In one aspect, the present application provides a model distillation apparatus, including:
the first acquisition module is used for acquiring training sample data for training a preset student model;
the identification module is used for recognizing the training sample data with a preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is trained under the guidance of the preset teacher model;
the second obtaining module is used for obtaining a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
and the adjusting module is used for calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
The teacher identification result represents identification probability;
the second obtaining module includes:
an acquisition unit configured to acquire a balance parameter for balancing the teacher identification result;
the first grouping unit is used for grouping the obtained multiple teacher identification results according to the balance parameters according to the identification sequence of the preset teacher model to obtain multiple teacher identification groups which are sequentially arranged, wherein each teacher identification group in the multiple teacher identification groups comprises the same number of teacher identification results;
and the first calculating unit is used for calculating the average value of the plurality of teacher identification results in each teacher identification group respectively, and taking the obtained plurality of average values as the weight parameters after the balance processing.
Wherein the number of the student identification results is multiple;
the above-mentioned adjusting module includes:
the second grouping unit is used for grouping the obtained student identification results according to the balance parameters in the identification sequence to obtain a plurality of student identification groups which are sequentially arranged, wherein each student identification group in the student identification groups contains the same number of student identification results, and each teacher identification group and each student identification group are in one-to-one correspondence in the identification sequence;
the second calculating unit is used for calculating the average value of the identification results of the students in each student identification group;
the third calculating unit is used for calculating the logarithm of the average value of each student identification group and the logarithm of the average value of the corresponding teacher identification group respectively to obtain a plurality of logarithms after balance processing;
and the first weighting operation unit is used for weighting the logarithms after the balance processing with the weight parameters after the balance processing.
Wherein, the obtaining unit is specifically configured to:
acquiring the number of teacher identification results in the plurality of teacher identification results;
determining a preset threshold range to which the number of teacher identification results in the plurality of teacher identification results belongs;
and determining a target balance parameter corresponding to the preset threshold range from a balance parameter library, and taking the target balance parameter as a balance parameter for balancing the teacher identification result, wherein the balance parameter library comprises at least one balance parameter and a corresponding relation between each balance parameter in the at least one balance parameter and the preset threshold range.
Wherein, the above-mentioned adjusting module includes:
a verification unit for verifying whether the loss value satisfies a convergence condition;
a first determining unit, configured to determine a loss degree to which the loss value belongs if the loss value does not satisfy the convergence condition;
and the first adjusting unit is used for adjusting parameters in the preset student model according to the loss degree.
Wherein, the verification unit is specifically configured to:
obtain a minimum value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determine that the loss value does not satisfy the convergence condition; alternatively,
verify whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determine that the loss value does not satisfy the convergence condition.
The preset teacher model comprises a plurality of teacher distillation layers, and the preset student model comprises a plurality of student distillation layers;
the above-mentioned adjusting module includes:
the second determining unit is used for determining a teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers;
the fourth calculation unit is used for calculating the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer;
the second weighting operation unit is used for performing weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer by using the weight parameters to obtain the loss value of each student distillation layer;
and the second adjusting unit is used for adjusting the corresponding student distillation layer in the preset student model by respectively adopting the loss value of each student distillation layer.
One aspect of the present application provides a computer device, comprising: a processor and a memory;
wherein, the memory is used for storing computer programs, and the processor is used for calling the computer programs to execute the following steps:
acquiring training sample data for training a preset student model;
recognizing the training sample data with a preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is trained under the guidance of the preset teacher model;
acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
An aspect of the embodiments of the present application provides a computer-readable storage medium, where a computer program is stored, where the computer program includes program instructions, and the program instructions, when executed by a processor, perform the following steps:
acquiring training sample data for training a preset student model;
recognizing the training sample data with a preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is trained under the guidance of the preset teacher model;
acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
calculating the logarithm between the teacher identification result and the student identification result, weighting the logarithm with the weight parameter, and adjusting the preset student model using the calculated value as a loss value.
In the embodiments of the application, training sample data for training a preset student model is acquired, and the preset teacher model and the preset student model recognize the training sample data respectively to obtain a teacher identification result and a student identification result of the training sample data. A weight parameter for adjusting the recognition result of the preset student model is obtained from the teacher recognition result; the logarithm between the teacher recognition result and the student recognition result is calculated and weighted with the weight parameter, and the preset student model is adjusted using the calculated value as a loss value. In this way, the weight parameter allocates appropriate adjustment weights to the different results in the prediction for the training sample data, so that the adjusted student model is more accurate. With the method and device, the student model can acquire the data processing capability of the teacher model, and the accuracy of the student model is improved.
Drawings
In order to more clearly illustrate the embodiments of the present application or the technical solutions in the prior art, the drawings used in the description of the embodiments or the prior art will be briefly described below, it is obvious that the drawings in the following description are only some embodiments of the present application, and for those skilled in the art, other drawings can be obtained according to the drawings without creative efforts.
FIG. 1 is a schematic flow diagram of a model distillation process provided herein;
FIG. 2 is a schematic diagram of a method for calculating an average of multiple teacher identification results in each teacher identification group according to an embodiment of the present application;
FIG. 3 is a schematic diagram of a method for obtaining loss values of a preset student model according to an embodiment of the present disclosure;
FIG. 4 is a schematic illustration of a model distillation provided in an embodiment of the present application;
FIG. 5 is a schematic flow diagram of another model distillation process provided herein;
FIG. 6 is a schematic flow diagram of a model distillation apparatus provided herein;
FIG. 7 is a schematic structural diagram of a computer device according to an embodiment of the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application, and it is obvious that the described embodiments are only a part of the embodiments of the present application, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present application.
Fig. 1 is a schematic flow chart of a model distillation method according to an embodiment of the present application. The method may be performed by a computer device, which may be a terminal or a server. The terminal may include but is not limited to: smart phones, tablet computers, notebook computers, desktop computers, smart speakers, smart watches, and the like; the server may be an independent physical server, a server cluster or a distributed system formed by a plurality of physical servers, or a cloud server providing basic cloud computing services such as cloud services, cloud databases, cloud computing, cloud functions, cloud storage, network services, cloud communication, middleware services, domain name services, security services, a Content Delivery Network (CDN), big data and artificial intelligence platforms. As shown in FIG. 1, the model distillation method may include steps S101-S105.
S101, obtaining training sample data for training a preset student model.
Training sample data for training a preset student model is obtained. The preset student model is the student model in model distillation; model distillation (knowledge distillation) uses a teacher model to guide the training of a student model, so as to improve the accuracy of the student model. The training sample data may be text data, image data, and the like.
S102, recognizing the training sample data with the preset teacher model and the preset student model respectively, to obtain a teacher identification result and a student identification result of the training sample data.
The training sample data is recognized by the preset student model to obtain the student identification result of the training sample data. The preset student model is trained under the guidance of the preset teacher model. Generally speaking, a teacher model has high accuracy but high computational complexity, which makes it unsuitable for deployment on terminal devices, whereas a student model is computationally simple enough to meet the requirements of terminal devices but lacks accuracy. Model distillation addresses this problem: the preset teacher model guides the training of the preset student model, thereby improving the accuracy of the preset student model. The teacher model and the student model process the same type of data, but the network of the teacher model is deeper or wider and its resolution is larger, that is, the teacher model has a higher data processing capability than the student model. Specifically, the teacher model and the student model may be selected according to their similarity: the more similar their structures, the smaller the accuracy gap between them after distillation. Therefore, the teacher model and the student model may be chosen from the same model family, such as the ResNet series. ResNet is a residual network whose width and depth can easily be adjusted to obtain networks with different representational capacities.
The preset student model is adjusted according to the teacher identification result of the training sample data output by the preset teacher model and the student identification result of the training sample data output by the preset student model, so that the knowledge in the preset teacher model is transferred to the preset student model and the preset student model acquires the data processing capability and precision of the preset teacher model.
S103, acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result.
The teacher identification result can be used to generate a weight parameter for adjusting the identification result of the preset student model. The weight parameter determines the weight that the identification result of the preset student model carries when the loss value is generated: the larger the weight parameter, the larger that weight.
Optionally, to obtain the weight parameter from the teacher identification results, a balance parameter for balancing the teacher identification results may be obtained, and the obtained teacher identification results are grouped according to the balance parameter, in the recognition order of the preset teacher model, into a plurality of sequentially arranged teacher identification groups, each of which contains the same number of teacher identification results. The average value of the teacher identification results in each teacher identification group is then calculated, and the resulting average values are used as the weight parameters after the balance processing.
There are a plurality of teacher identification results, each representing a recognition probability obtained by the preset teacher model recognizing the training sample data. A balance parameter for balancing the teacher identification results is obtained. The balance parameter may be a positive integer greater than or equal to 1 and is used to handle abnormal results among the teacher identification results, where an abnormal result is one far greater or far smaller than the normal results, that is, an outlier compared with the other results. The teacher identification results are sorted according to the recognition order of the preset teacher model, and the sorted results are grouped according to the balance parameter into a plurality of sequentially arranged teacher identification groups, each containing the same number of teacher identification results. The average value of the teacher identification results in each group is then calculated, and these average values are used as the weight parameters after the balance processing, with which the identification result of the preset student model can be adjusted.
The larger the average value of the teacher identification results in a teacher identification group, the larger the corresponding balanced weight parameter. Since the teacher identification results represent recognition probabilities, groups with a larger average probability produce larger balanced weight parameters. High-probability positions in the probability distribution therefore receive more attention; correctly matching the high-probability events of the distribution is of greater practical value, so the loss value of the preset student model calculated with these weight parameters is more accurate, which in turn improves the accuracy of the preset student model. Grouping the teacher identification results according to the balance parameter and averaging the results within each group amounts to balancing each teacher identification result with the balance parameter: abnormal results among the teacher identification results are smoothed out, errors introduced when generating the weight parameters are reduced, and the accuracy of the preset student model is improved.
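For illustration, the balanced weights and loss described above can be written compactly. The notation is introduced here and does not appear in the source: $p^{T}_{i}$ and $p^{S}_{i}$ denote the $i$-th teacher and student identification results in recognition order, $n$ their number, $C$ the balance parameter (assumed to divide $n$), and $K = n/C$ the number of groups. The loss assumes that "the logarithm between" two group averages means the difference of their logarithms.

```latex
\bar{p}^{T}_{k} = \frac{1}{C}\sum_{i=(k-1)C+1}^{kC} p^{T}_{i},
\qquad
\bar{p}^{S}_{k} = \frac{1}{C}\sum_{i=(k-1)C+1}^{kC} p^{S}_{i},
\qquad k = 1, \dots, K = n / C

L = \sum_{k=1}^{K} \bar{p}^{T}_{k} \left( \log \bar{p}^{T}_{k} - \log \bar{p}^{S}_{k} \right)
```

The weight $\bar{p}^{T}_{k}$ grows with the group's average teacher probability, and a single abnormal result contributes only $1/C$ of its value to its group average. With $C = 1$, $L$ reduces to $\sum_i p^{T}_{i} \log\left(p^{T}_{i} / p^{S}_{i}\right)$, the KL divergence from the student to the teacher distribution.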
Optionally, to obtain the balance parameter used to balance the teacher recognition results, the number of teacher recognition results can be counted and the preset threshold range into which that number falls can be determined. The target balance parameter corresponding to that threshold range is then looked up in a balance parameter library and used as the balance parameter for balancing the teacher recognition results. The balance parameter library contains at least one balance parameter and the correspondence between each balance parameter and a preset threshold range.
Here there are multiple teacher recognition results, each representing a recognition probability obtained by the preset teacher model recognizing the training sample data. Once the preset threshold range containing the number of teacher recognition results is known, the corresponding target balance parameter is read from the balance parameter library. For example, a balance parameter library can be preset in which a first threshold range corresponds to a first balance parameter, a second threshold range corresponds to a second balance parameter, and so on.
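The threshold-range lookup described above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the application's implementation: the name BALANCE_PARAMETER_LIBRARY, its threshold ranges, and its C values are hypothetical, chosen only so that 11 results map to C = 3 as in the Fig. 2 example. Odd values of C are assumed so that each group can be centred on its result.

```python
# Hypothetical balance parameter library: each entry pairs a preset threshold
# range [lower, upper) on the number of recognition results with a balance
# parameter C. These ranges and values are illustrative assumptions only.
BALANCE_PARAMETER_LIBRARY = [
    ((1, 10), 1),
    ((10, 100), 3),
    ((100, 1000), 5),
    ((1000, float("inf")), 7),
]


def select_balance_parameter(num_results, library=BALANCE_PARAMETER_LIBRARY):
    """Return the target balance parameter C whose preset threshold range
    contains num_results (the number of teacher or student recognition results)."""
    for (lower, upper), c in library:
        if lower <= num_results < upper:
            return c
    raise ValueError(f"no preset threshold range contains {num_results}")


print(select_balance_parameter(11))  # 3, matching the 11-result example of Fig. 2
```

Keeping the library as plain data makes it easy to tune the ranges per task without touching the lookup logic.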
Optionally, because there are multiple student recognition results corresponding one to one with the teacher recognition results, the balance parameter can also be determined from the number of student recognition results. In the same way, the preset threshold range containing the number of student recognition results is determined, and the target balance parameter corresponding to that range is taken from the balance parameter library and used as the balance parameter for balancing the teacher recognition results. Note that the balance parameter C may be determined from the number of teacher recognition results, from the number of student recognition results, or according to other specific circumstances; its purpose is to balance out, that is, eliminate, abnormal values among the teacher or student recognition results.
Fig. 2 is a schematic diagram of a method, provided by the embodiment of the present application, for calculating the average of the teacher recognition results in each teacher recognition group. Take a distribution of 11 teacher recognition results as an example. The 11 results are sorted according to the preset recognition order of the teacher model, giving x1, x2, ..., x11. If the balance parameter determined from the count of 11 is 3, the 11 sorted results are grouped with balance parameter 3 into a sequence of teacher recognition groups: [x1, x2], [x1, x2, x3], [x2, x3, x4], [x3, x4, x5], ..., [x9, x10, x11], [x10, x11]. Because no teacher recognition result precedes x1, the group corresponding to x1 may be [x1, x2] or [x1, x2, x3]. Likewise, because no result follows x11, the group corresponding to x11 may be [x10, x11] or [x9, x10, x11]. After the ordered teacher recognition groups are obtained, the average of the teacher recognition results in each group is calculated.
For example, for the teacher recognition group [x1, x2, x3] corresponding to x2, the results x1, x2, and x3 are summed and divided by the group size 3, giving the average of the teacher recognition results in the group corresponding to x2.
For example, the teacher recognition result distribution of training sample data is:
[0.1,0.05,0.0001,0.02,0.15,0.28,0.23,0.06,0.023,0.05,0.0369]
The balance parameter determined from the number of teacher recognition results of this training sample data is 3. Grouping with C = 3 gives the following teacher recognition groups:
[0.1, 0.05], [0.1, 0.05, 0.0001], [0.05, 0.0001, 0.02], [0.0001, 0.02, 0.15], [0.02, 0.15, 0.28], [0.15, 0.28, 0.23], [0.28, 0.23, 0.06], [0.23, 0.06, 0.023], [0.06, 0.023, 0.05], [0.023, 0.05, 0.0369], [0.05, 0.0369]
The average of the teacher recognition results in each teacher recognition group is then calculated, giving:
[0.075, 0.05, 0.0234, 0.0567, 0.15, 0.22, 0.19, 0.1043, 0.0443, 0.0366, 0.0435]
The detailed calculation of each balanced teacher recognition result is:
teacher recognition result_0 = (0.1 + 0.05) / 2 = 0.0750
teacher recognition result_1 = (0.1 + 0.05 + 0.0001) / 3 = 0.0500
teacher recognition result_2 = (0.05 + 0.0001 + 0.02) / 3 = 0.0234
teacher recognition result_3 = (0.0001 + 0.02 + 0.15) / 3 = 0.0567
teacher recognition result_4 = (0.02 + 0.15 + 0.28) / 3 = 0.1500
teacher recognition result_5 = (0.15 + 0.28 + 0.23) / 3 = 0.2200
teacher recognition result_6 = (0.28 + 0.23 + 0.06) / 3 = 0.1900
teacher recognition result_7 = (0.23 + 0.06 + 0.023) / 3 = 0.1043
teacher recognition result_8 = (0.06 + 0.023 + 0.05) / 3 = 0.0443
teacher recognition result_9 = (0.023 + 0.05 + 0.0369) / 3 = 0.0366
teacher recognition result_10 = (0.05 + 0.0369) / 2 = 0.0435
The third teacher recognition result, 0.0001, becomes 0.0234 after balancing, so this abnormal teacher recognition result is balanced out, that is, eliminated.
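A minimal Python sketch of the grouping and averaging walked through above. It assumes an odd balance parameter and the truncated end groups [x1, x2] and [x10, x11] used in the example; the function names group_results and balance are hypothetical. Running it on the 11 teacher recognition results reproduces the averages listed above.

```python
def group_results(results, c):
    """Group sorted recognition results with balance parameter c.

    The group for result i is the window of c consecutive results centred on i;
    at the two ends the window is truncated, e.g. [x1, x2] for x1 when c = 3.
    Assumes c is odd so that the window can be centred.
    """
    half = c // 2
    return [results[max(0, i - half): i + half + 1] for i in range(len(results))]


def balance(results, c):
    """Replace each result by the average of its group (the balancing step)."""
    return [sum(g) / len(g) for g in group_results(results, c)]


teacher = [0.1, 0.05, 0.0001, 0.02, 0.15, 0.28, 0.23, 0.06, 0.023, 0.05, 0.0369]
averages = balance(teacher, 3)
print([round(a, 4) for a in averages])
```

Using a centred window keeps each balanced value aligned with the position it replaces, which is what lets the teacher and student groups be paired one to one later.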
S104: Calculate the logarithm between the teacher recognition result and the student recognition result, weight the logarithm with the weight parameter, and adjust the preset student model using the resulting value as the loss value.
The logarithm between the teacher recognition result and the student recognition result is calculated and weighted with the weight parameter, and the weighted value is taken as the loss value of the preset student model. The preset student model is adjusted according to this loss value to obtain an adjusted student model, which is taken as the target student model. The target student model is used to recognize data to be processed, and its recognition result matches the recognition result obtained by the preset teacher model on the same data; that is, the target student model has the data processing capability of the preset teacher model.
Optionally, the logarithm between the teacher recognition result and the student recognition result can be calculated and weighted with the weight parameter using formula (1):
$$D_{KL}(P \parallel Q) = \sum_{x \in \chi} P(x) \log \frac{P(x)}{Q(x)} \quad (1)$$
In formula (1), $D_{KL}(P \parallel Q)$ is the loss value of the preset student model; P(x) is the teacher recognition result of the training sample data, that is, the result obtained by the teacher model recognizing the training sample data; and Q(x) is the student recognition result of the training sample data, that is, the result obtained by the student model predicting the training sample data. $\log \frac{P(x)}{Q(x)}$ is the logarithm between the teacher recognition result and the student recognition result, and the weight parameter is P(x); x is any one of the teacher or student recognition results, and χ is the set of teacher and student recognition results. Formula (1) is the KL divergence (relative entropy).
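Formula (1) can be sketched in a few lines of Python. The function name kl_loss is hypothetical, the natural logarithm is assumed (the application does not state a base), and every student probability Q(x) is assumed to be positive wherever P(x) is.

```python
import math


def kl_loss(teacher_probs, student_probs):
    """Formula (1): sum over x of P(x) * log(P(x) / Q(x)).

    P(x), the teacher recognition result, acts as the weight parameter on the
    logarithm log(P(x) / Q(x)). Terms with P(x) == 0 are skipped, following the
    usual convention 0 * log 0 = 0.
    """
    if len(teacher_probs) != len(student_probs):
        raise ValueError("teacher and student results must correspond one to one")
    loss = 0.0
    for p, q in zip(teacher_probs, student_probs):
        if p > 0:
            loss += p * math.log(p / q)
    return loss


print(kl_loss([0.7, 0.2, 0.1], [0.7, 0.2, 0.1]))  # 0.0: identical results, zero loss
print(kl_loss([0.7, 0.2, 0.1], [0.5, 0.3, 0.2]))
```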
Fig. 3 is a schematic diagram of a method for obtaining the loss value of the preset student model according to an embodiment of the present application. As shown in Fig. 3, the method includes steps S21 to S23.
S21: Group the obtained student recognition results in the recognition order according to the balance parameter to obtain a sequence of student recognition groups, where each student recognition group contains the same number of student recognition results and the teacher recognition groups correspond one to one, in the recognition order, with the student recognition groups.
S22: Calculate the average of the student recognition results in each student recognition group.
There are multiple student recognition results, each representing a recognition probability obtained by the preset student model recognizing the training sample data. The preset student model recognizes the training sample data in the same order as the preset teacher model, that is, the two models share the same recognition order. The student recognition results are grouped according to the balance parameter into a sequence of student recognition groups; each group contains the same number of student recognition results, and the teacher recognition groups correspond one to one with the student recognition groups in the recognition order. The average of the student recognition results in each group is then calculated, which balances each student recognition result with the balance parameter. In this way, abnormal values among the student recognition results are balanced out and extreme values are greatly reduced, which improves the accuracy of the target student model and gives the student model the data processing capability of the teacher model.
S23: Calculate the logarithm between the average of each student recognition group and the average of the corresponding teacher recognition group to obtain the balanced logarithms, and weight the balanced logarithms with the balanced weight parameters.
After the average of the student recognition results in each student recognition group has been calculated, the logarithm between each student group average and the corresponding teacher group average is computed, yielding the balanced logarithms, which are then weighted with the balanced weight parameters.
Optionally, the following formulas (2), (3) and (4) may be adopted to calculate the weight parameter after the balancing process, so as to perform a weighting operation on the logarithm between the teacher identification result and the student identification result.
$$P'(x) = \frac{1}{z} \sum_{x_i \in e} P(x_i) \quad (2)$$
$$Q'(x) = \frac{1}{z} \sum_{x_i \in e} Q(x_i) \quad (3)$$
$$D_{DKL}(P \parallel Q)_C = \sum_{x \in \chi} P'(x) \log \frac{P'(x)}{Q'(x)} \quad (4)$$
In formulas (2) and (3), z is the number of teacher recognition results in the teacher recognition group (or of student recognition results in the student recognition group), and e is the teacher recognition group (or student recognition group) corresponding to x. In formula (4), $D_{DKL}(P \parallel Q)_C$ is the loss value; P'(x) is the average of the teacher recognition group, that is, the balanced teacher recognition result; and Q'(x) is the average of the student recognition group, that is, the balanced student recognition result. $\log \frac{P'(x)}{Q'(x)}$ is the logarithm between the average of each student recognition group and the average of the corresponding teacher recognition group, and P'(x) acts as the balanced weight parameter; x is any one of the teacher or student recognition results, and χ is the set of teacher and student recognition results. When P(x) and Q(x) are identical, $D_{DKL}(P \parallel Q)_C$ equals zero.
In this way, the method introduces the balance parameter into the KL divergence (relative entropy); the KL-divergence loss function with the balance parameter is referred to here as the DKL-divergence. The DKL-divergence balances each of the teacher recognition results and each of the student recognition results, eliminating abnormal values among them.
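Steps S21 to S23 and formulas (2) to (4) can be combined into the following Python sketch. It assumes, as in the Fig. 2 example, centred groups of odd size c that are truncated at the two ends, and the natural logarithm; the function names and the toy probabilities in the usage lines are hypothetical.

```python
import math


def balance(probs, c):
    """Formulas (2)/(3): average each value over its truncated window of size c."""
    half = c // 2
    groups = [probs[max(0, i - half): i + half + 1] for i in range(len(probs))]
    return [sum(g) / len(g) for g in groups]  # (1/z) * sum over the group e


def balanced_kl_loss(teacher_probs, student_probs, c):
    """Formula (4): sum over x of P'(x) * log(P'(x) / Q'(x)).

    Teacher and student results are grouped in the same recognition order with
    the same c, so teacher group i corresponds to student group i (step S21).
    """
    if len(teacher_probs) != len(student_probs):
        raise ValueError("teacher and student results must correspond one to one")
    p_bal = balance(teacher_probs, c)  # step S22, teacher side
    q_bal = balance(student_probs, c)  # step S22, student side
    # Step S23: weight each balanced log ratio by the balanced teacher average.
    return sum(p * math.log(p / q) for p, q in zip(p_bal, q_bal) if p > 0)


teacher = [0.25, 0.25, 0.2, 0.3]
student = [0.3, 0.3, 0.0001, 0.3999]   # 0.0001 is an abnormal student value
print(balanced_kl_loss(teacher, student, 1))  # c = 1: no balancing
print(balanced_kl_loss(teacher, student, 3))  # c = 3: outlier averaged away
```

With c = 1 every group holds a single result and the function reduces to formula (1). In this toy data the outlier 0.0001 dominates the unbalanced loss (about 1.34) but contributes little once c = 3 (about 0.06).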
Optionally, the preset teacher model includes a plurality of teacher distillation layers and the preset student model includes a plurality of student distillation layers. In this case, when the logarithm between the teacher and student recognition results is calculated, weighted with the weight parameter, and used as the loss value to adjust the preset student model, the teacher distillation layer corresponding to each student distillation layer can be determined, and the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer is calculated. This logarithm is weighted with the weight parameter to obtain the loss value of each student distillation layer, and each loss value is used to adjust the corresponding student distillation layer in the preset student model.
Each teacher distillation layer of the preset teacher model outputs its own teacher recognition result, and each student distillation layer of the preset student model outputs its own student recognition result, so knowledge can be distilled between corresponding distillation layers of the two models. Adjusting each student distillation layer with its own loss value allows every student distillation layer to be tuned more precisely, which improves the accuracy of the student model.
For example, Fig. 4 is a schematic diagram of model distillation provided in the embodiments of the present application. As shown in Fig. 4, distillation generally comprises three parts: Transformer-layer distillation, Embedding-layer distillation, and Prediction-layer distillation. Knowledge can therefore be distilled separately from each of the three distillation layers in the teacher model into the corresponding distillation layer of the student model.
That is, the Transformer layer of the student model is distilled from the Transformer layer of the teacher model: the Transformer-layer loss value is calculated and used to adjust the Transformer layer. The Embedding layer of the student model is distilled from the Embedding layer of the teacher model: the Embedding-layer loss value is calculated and used to adjust the Embedding layer. The Prediction layer of the student model is distilled from the Prediction layer of the teacher model: the Prediction-layer loss value is calculated and used to adjust the Prediction layer. In this way, the distillation layers in the student model are adjusted more precisely, improving the accuracy of the preset student model.
As shown in Fig. 4, the student model (new model) has M Transformer layers and the teacher model (original model) has N Transformer layers. The mapping n = g(m) means that the m-th layer of the student model obtains information from the n-th layer of the teacher model. The Embedding layer is treated as layer 0, the Output layer as layer M + 1, and the Transformer layers as layers 1 to M. The distillation loss for transferring the teacher's knowledge to the student can be expressed by formula (5):
$$\mathcal{L}_{model} = \sum_{m=0}^{M+1} \lambda_m \mathcal{L}_{layer}\left(S_m, T_{g(m)}\right) \quad (5)$$
In formula (5), $\mathcal{L}_{layer}$ is the loss function of a given layer, which may be a Transformer layer, the Embedding layer, or the Prediction layer. m = 0 denotes the Embedding layer, m = M + 1 denotes the Output layer, and m = 1, 2, ..., M denotes the Transformer layers through which the student model learns from the teacher model; $\lambda_m$ is a hyperparameter giving the loss weight of each layer; and $\mathcal{L}_{model}$ is the sum of the knowledge distillation losses of all layers. Transformer-layer distillation, Embedding-layer distillation, and Prediction-layer distillation can each be given its own loss function, and each layer is adjusted according to the loss value produced by its loss function.
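Formula (5) and the layer mapping n = g(m) might be sketched as follows. The application leaves g unspecified; the uniform mapping below (student Transformer layer m learns from teacher layer m·N/M, the Embedding layer maps to 0 and the Prediction layer to N + 1) is an assumption based on a common uniform strategy, and the per-layer losses are taken as already computed. All names are hypothetical.

```python
def uniform_layer_mapping(m, num_student_layers, num_teacher_layers):
    """Hypothetical n = g(m): m = 0 is the Embedding layer, m = M + 1 the
    Prediction layer, and Transformer layers m = 1..M are spread uniformly over
    the teacher's N Transformer layers (assumes N is a multiple of M)."""
    M, N = num_student_layers, num_teacher_layers
    if m == 0:
        return 0
    if m == M + 1:
        return N + 1
    return m * (N // M)


def model_distillation_loss(layer_losses, layer_weights):
    """Formula (5): sum over m = 0..M+1 of lambda_m * L_layer(S_m, T_g(m)).

    layer_losses[m] is the loss already computed between student layer m and
    teacher layer g(m); layer_weights[m] is the hyperparameter lambda_m.
    """
    if len(layer_losses) != len(layer_weights):
        raise ValueError("one weight lambda_m is needed per layer loss")
    return sum(w * loss for w, loss in zip(layer_weights, layer_losses))


M, N = 4, 12
print([uniform_layer_mapping(m, M, N) for m in range(M + 2)])  # [0, 3, 6, 9, 12, 13]
```

Because each layer keeps its own loss term, the same structure also supports adjusting each student distillation layer with its own loss, as described above.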
Transformer-layer distillation in model distillation includes self-attention-based distillation and hidden-state-based distillation. In the related art, the objective function for distilling the self-attention matrices in Transformer-layer distillation is formula (6):
$$\mathcal{L}_{attn} = \frac{1}{h} \sum_{i=1}^{h} \mathrm{MSE}\left(A_i^S, A_i^T\right) \quad (6)$$
In formula (6), h is the number of attention heads, i denotes the i-th attention head, $A_i^S$ and $A_i^T$ are the attention matrices of the student model and the teacher model respectively, and MSE is the mean squared error loss.
The objective function for fitting the output (hidden-state) matrix of each Transformer layer is formula (7):
$$\mathcal{L}_{hidn} = \mathrm{MSE}\left(H^S W_h, H^T\right) \quad (7)$$
In formula (7), $H^S \in \mathbb{R}^{l \times d'}$ and $H^T \in \mathbb{R}^{l \times d}$ are the hidden-state matrices of the student and the teacher, so $\mathbb{R}^{l \times d'}$ and $\mathbb{R}^{l \times d}$ give their sizes; l is the length of the training sample data (that is, the length of the input sentence), and d' and d are the hidden sizes of the student and the teacher. $W_h \in \mathbb{R}^{d' \times d}$ is a learnable linear transformation matrix that maps the student's hidden-state matrix into a space of the same size as the teacher's.
In the related art, the loss function of the Embedding layer is formula (8):
$$\mathcal{L}_{embd} = \mathrm{MSE}\left(E^S W_e, E^T\right) \quad (8)$$
In formula (8), $E^S$ and $E^T$ are the embedding matrices of the student model and the teacher model respectively, and $W_e$ is a linear transformation matrix playing the same role as $W_h$.
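The related-art losses of formulas (6) to (8) reduce to mean squared errors, in (7) and (8) after a linear projection of the student matrix. The dependency-free Python sketch below stores matrices as lists of rows; the helper names are hypothetical, and the projection is passed in as a fixed matrix, whereas in training W_h and W_e would be learned.

```python
def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def mse(a, b):
    """Mean squared error over all entries of two equally shaped matrices."""
    sq = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(sq) / len(sq)


def attention_loss(student_heads, teacher_heads):
    """Formula (6): (1/h) * sum over heads i of MSE(A_i^S, A_i^T)."""
    h = len(teacher_heads)
    return sum(mse(s, t) for s, t in zip(student_heads, teacher_heads)) / h


def projected_mse(student, projection, teacher):
    """Formulas (7) and (8): MSE(X^S W, X^T), where W (W_h or W_e) maps the
    student width d' to the teacher width d."""
    return mse(matmul(student, projection), teacher)


# Toy example with l = 1, d' = 2 and d = 3 (values are made up).
h_student = [[1.0, 2.0]]
w_h = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
h_teacher = [[1.0, 2.0, 3.0]]
print(projected_mse(h_student, w_h, h_teacher))  # 3.0
```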
In the related art, the output layer in Prediction-layer distillation uses the soft cross-entropy loss of formula (9):
$$\mathcal{L}_{pred} = -\mathrm{softmax}\left(z^T\right) \cdot \mathrm{log\_softmax}\left(z^S / t\right) \quad (9)$$
In formula (9), $z^T$ and $z^S$ are the teacher recognition result of the preset teacher model and the student recognition result of the preset student model respectively, log_softmax() is the logarithm of the softmax output (the log-likelihood), and t is the distillation temperature.
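Formula (9) can be sketched in Python as below. As written in formula (9), only the student logits are divided by the temperature t; some formulations also scale the teacher logits by t. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def soft_cross_entropy(teacher_logits, student_logits, t=1.0):
    """Formula (9): -softmax(z^T) . log_softmax(z^S / t).

    Only the student logits are divided by the distillation temperature t,
    exactly as formula (9) is written.
    """
    teacher_probs = [math.exp(v) for v in log_softmax(teacher_logits)]
    student_log_probs = log_softmax([z / t for z in student_logits])
    return -sum(p * lq for p, lq in zip(teacher_probs, student_log_probs))


print(soft_cross_entropy([2.0, 1.0, 0.1], [1.5, 1.2, 0.3], t=2.0))
```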
In summary, the overall objective function of the model distillation in the related art can be expressed as the following formula (10).
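$$\mathcal{L}_{layer}\left(S_m, T_{g(m)}\right) = \begin{cases} \mathcal{L}_{embd}\left(S_0, T_0\right), & m = 0 \\ \mathcal{L}_{hidn}\left(S_m, T_{g(m)}\right) + \mathcal{L}_{attn}\left(S_m, T_{g(m)}\right), & M \ge m > 0 \\ \mathcal{L}_{pred}\left(S_{M+1}, T_{N+1}\right), & m = M + 1 \end{cases} \quad (10)$$
The MSE in formula (10) stands for Mean Squared Error, which is commonly used to measure the deviation between a model's predicted values and the true values. If the observed (true) result distribution is "observed", the predicted distribution is "predicted", and the sample space has size n, the difference between the two distributions is given by formula (11):
$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(\mathrm{observed}_i - \mathrm{predicted}_i\right)^2 \quad (11)$$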
In the related art, the loss value of the student model is computed with MSE. As formula (11) shows, MSE treats every position of the result distribution equally, without distinction, so the resulting loss value does not reflect the difference between the student model and the teacher model well, and the preset student model cannot be adjusted accurately.
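To illustrate why formula (11) attends to all positions indiscriminately, the Python sketch below computes MSE for two hypothetical student distributions that deviate from the same teacher distribution by the same amounts but at different positions. The distributions are made up for illustration.

```python
def mse(observed, predicted):
    """Formula (11): (1/n) * sum over i of (observed_i - predicted_i) ** 2."""
    n = len(observed)
    return sum((o - p) ** 2 for o, p in zip(observed, predicted)) / n


teacher = [0.7, 0.2, 0.1]
# Two deviations of 0.1, one of them at the teacher's most likely position.
student_a = [0.6, 0.3, 0.1]
# Two deviations of 0.1, both at the teacher's less likely positions.
student_b = [0.7, 0.1, 0.2]
print(mse(teacher, student_a), mse(teacher, student_b))  # equal up to rounding
```

Each squared difference enters with the same weight 1/n, whatever probability the teacher assigns to that position.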
In the embodiment of the present application, the DKL-divergence is used to calculate the loss value of the student model. For example, when calculating the loss value of a Transformer layer, suppose the corresponding distillation layers of the teacher model and the student model are U and V respectively. For the knowledge distillation of each pair of corresponding layers, denote the attention matrices by $A^T_U$ and $A^S_V$, where T and S refer to the teacher and the student respectively. After the teacher recognition results and student recognition results of the training sample data are determined, the attention matrices $A^T_U$ and $A^S_V$ are sampled (grouped) with the balance parameter C to obtain the sub-distributions $p_T$ and $q_S$ of the teacher recognition groups and student recognition groups, and averaging them yields the distributions $\tilde{p}_T$ and $\tilde{q}_S$, given by formulas (12) and (13):
$$\tilde{p}_T(x) = \frac{1}{z} \sum_{x_i \in e} p_T(x_i) \quad (12)$$
$$\tilde{q}_S(x) = \frac{1}{z} \sum_{x_i \in e} q_S(x_i) \quad (13)$$
In formulas (12) and (13), z is the number of teacher recognition results in the teacher recognition group (or of student recognition results in the student recognition group), and e is the teacher recognition group (or student recognition group).
The loss function of the student model, $D_{DKL}\left(A^T_U \parallel A^S_V\right)_C$, is then expressed as formula (14):
$$D_{DKL}\left(A^T_U \parallel A^S_V\right)_C = \sum_{x \in \chi} \tilde{p}_T(x) \log \frac{\tilde{p}_T(x)}{\tilde{q}_S(x)} \quad (14)$$
In formula (14), χ denotes the probability space of the attention matrix $A \in \mathbb{R}^{1 \times l}$, which in effect represents the result distribution of the training sample data, and l is the length of the training sample data. The total loss of a Transformer layer with h attention heads over training sample data of length l is given by formula (15).
[Formula (15), image not recoverable: the total attention loss, combining the losses of formula (14) over every attention head a and every position t.]
In formula (15), t indexes a position within the training sample data of length l, and a denotes one of the h attention heads.
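The attention-based variant in formulas (12) to (14) applies the same balancing to each row of the teacher and student attention matrices. The Python sketch below assumes each head's attention matrix is a list of l rows, each a probability distribution over l positions, and, because formula (15) is not recoverable here, simply sums the row losses over all heads a and positions t; the normalisation used in the application may differ. All names are hypothetical.

```python
import math


def balance(row, c):
    """Formulas (12)/(13): average each entry over its truncated window of size c."""
    half = c // 2
    groups = [row[max(0, i - half): i + half + 1] for i in range(len(row))]
    return [sum(g) / len(g) for g in groups]


def balanced_row_kl(teacher_row, student_row, c):
    """Formula (14) for one attention row (a distribution over l positions)."""
    p_bal, q_bal = balance(teacher_row, c), balance(student_row, c)
    return sum(p * math.log(p / q) for p, q in zip(p_bal, q_bal) if p > 0)


def attention_distillation_loss(teacher_attn, student_attn, c):
    """Accumulate formula (14) over heads (index a) and rows (index t).

    teacher_attn[a][t] and student_attn[a][t] are the t-th rows of head a.
    Plain summation is an assumption standing in for formula (15).
    """
    return sum(
        balanced_row_kl(t_row, s_row, c)
        for t_head, s_head in zip(teacher_attn, student_attn)
        for t_row, s_row in zip(t_head, s_head)
    )


T = [[0.5, 0.3, 0.2], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]  # toy teacher head
S = [[0.4, 0.4, 0.2], [0.2, 0.1, 0.7], [0.3, 0.2, 0.5]]  # toy student head
print(attention_distillation_loss([T], [S], 3))
```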
The DKL-divergence in this scheme weights the logarithm of the ratio of P'(x) to Q'(x) (the logarithm between the average of each student recognition group and the average of the corresponding teacher recognition group) by P'(x), the weight parameter generated from the teacher recognition probabilities. The larger P'(x) is, the larger that position's contribution to the DKL-divergence. In other words, the DKL-divergence focuses more on the high-probability positions of the probability distribution, and what matters in practice is correctly matching, first, the events that truly have high probability in the concept distribution, whereas MSE in the related art attends to all positions of the distribution without distinction. The DKL-divergence is therefore better suited than MSE for computing the loss value of the student model. In addition, the scheme uses the balance parameter to eliminate abnormal results (abnormal concepts) among the teacher and student recognition results of the training sample data. For example, if a teacher recognition result is 0.2 and the corresponding student recognition result is 0.0001, the logarithm of their ratio becomes abnormal (log 2000, about 7.6 with the natural logarithm) and the gradients differ enormously, which can in turn cause gradient problems such as vanishing gradients. Introducing the balance parameter balances such abnormal probability values and greatly reduces the occurrence of extreme values.
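The numerical effect of the 0.2 versus 0.0001 example can be checked with the short Python sketch below. Only 0.2 and 0.0001 come from the text; the neighbouring values and the balance parameter C = 3 are assumptions, and the derivative shown is that of P log(P/Q) with respect to Q, which equals -P/Q.

```python
import math

# Middle entries 0.2 and 0.0001 come from the text; their neighbours are assumed.
teacher = [0.25, 0.2, 0.25]
student = [0.3, 0.0001, 0.3]

# Unbalanced: log ratio and derivative -P/Q at the abnormal position.
raw_log_ratio = math.log(teacher[1] / student[1])
raw_grad = -teacher[1] / student[1]

# Balanced with C = 3: the middle position is averaged with both neighbours.
p_bal = sum(teacher) / len(teacher)
q_bal = sum(student) / len(student)
bal_log_ratio = math.log(p_bal / q_bal)
bal_grad = -p_bal / q_bal

print(f"unbalanced: log ratio {raw_log_ratio:.3f}, derivative {raw_grad:.1f}")
print(f"balanced:   log ratio {bal_log_ratio:.3f}, derivative {bal_grad:.3f}")
```

With these values the log ratio drops from about 7.60 to about 0.15, and the derivative from about -2000 to about -1.17.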
This improves both the accuracy and the data processing capability of the student model: the recognition result obtained by the student model on data to be processed matches the recognition result obtained by the teacher model on the same data, so the student model has the data processing capability of the teacher model.
In the embodiment of the present application, training sample data for training a preset student model is acquired, and the preset student model and the preset teacher model each recognize the training sample data to obtain a teacher recognition result and a student recognition result. A weight parameter, derived from the teacher recognition result, for adjusting the recognition result of the preset student model is obtained; the logarithm between the teacher recognition result and the student recognition result is calculated and weighted with the weight parameter; and the preset student model is adjusted using the resulting value as the loss value. The weight parameter thus allocates, across the different results for the training sample data, how strongly each one drives the adjustment of the student model, so that recognition results with higher probability receive more attention, the loss value is more accurate, and the adjusted student model is more accurate. The balance parameter additionally eliminates abnormal values among the teacher and student recognition results, preventing them from distorting the loss value. With this method and device, the student model acquires the data processing capability of the teacher model and its accuracy is improved.
Fig. 5 is a schematic flowchart of another model distillation method provided in the embodiments of the present application. The method may be performed by a computer device. As shown in Fig. 5, the method may include steps S201 to S207.
S201: Obtain training sample data for training a preset student model.
S202: Recognize the training sample data with the preset student model and the preset teacher model respectively to obtain a teacher recognition result and a student recognition result of the training sample data.
S203: Obtain, from the teacher recognition result, a weight parameter for adjusting the recognition result of the preset student model.
S204: Calculate the logarithm between the teacher recognition result and the student recognition result, and weight the logarithm with the weight parameter.
For details of steps S201 to S204, refer to the embodiment described with reference to Fig. 1; they are not repeated here.
S205, verifying whether the loss value satisfies the convergence condition.
After the loss value of the student model is obtained, it is determined whether the loss value satisfies a convergence condition, namely that the loss value is smaller than a loss threshold preset by the user, or that the loss value equals the minimum value of the corresponding loss function.
Optionally, when verifying whether the loss value satisfies the convergence condition, the minimum value of the loss function used to calculate the loss value may be obtained; if the loss value differs from that minimum value, the loss value is determined not to satisfy the convergence condition. Alternatively, it may be verified whether the loss value is smaller than a preset loss threshold; if the loss value is greater than or equal to the preset loss threshold, the loss value is determined not to satisfy the convergence condition. The preset loss threshold may be set according to the data processing type of the student model or according to other indicators.
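The two alternative convergence checks of step S205 can be sketched as below. This is an illustrative sketch only: the function name, its arguments, and the tolerance `tol` used in place of exact floating-point equality with the minimum are assumptions. If the loss is read as a KL divergence, its minimum value is 0, which is what `loss_min=0.0` expresses in the usage lines.

```python
def loss_converged(loss, loss_threshold=None, loss_min=None, tol=1e-9):
    """Hypothetical check for the convergence condition of step S205.

    Either criterion may be used:
    - the loss value is smaller than a preset loss threshold, or
    - the loss value equals the minimum value of the loss function
      (compared within `tol`, an assumption to avoid exact float equality).
    """
    if loss_threshold is not None and loss < loss_threshold:
        return True
    if loss_min is not None and abs(loss - loss_min) <= tol:
        return True
    return False

print(loss_converged(0.05, loss_threshold=0.1))  # True
print(loss_converged(0.10, loss_threshold=0.1))  # False: loss >= threshold
print(loss_converged(0.30, loss_min=0.0))        # False: loss differs from the minimum
```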
S206, if the loss value does not satisfy the convergence condition, determining the loss degree to which the loss value belongs.
S207, adjusting the parameters of the preset student model according to the loss degree.
If the loss value does not satisfy the convergence condition, the teacher identification result obtained by the teacher model for the training sample data differs substantially from the student identification result predicted by the student model; that is, the identification results the student model produces for data to be processed do not match those produced by the teacher model. In this case, the loss degree to which the loss value belongs is determined, and the parameters of the preset student model are adjusted according to that loss degree: the larger the loss degree, the larger the adjustment to the parameters; the smaller the loss degree, the smaller the adjustment. Adjusting the preset student model based on the loss value in this way applies larger adjustments when the student model's error is larger, which speeds up convergence of the student model and improves training efficiency. It also makes each adjustment of the student model more precise, which improves the accuracy of student model training.
For further details, refer to the embodiment described with reference to Fig. 1; they are not repeated here.
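Steps S205 to S207 can be combined into a small training loop. The sketch below is a toy illustration under several assumptions the source does not specify: the student is reduced to a single vector of logits, the loss is the teacher-weighted log difference (equivalently KL(p_t ‖ softmax(z))), the loss degrees and their step-size multipliers in the hypothetical `LOSS_DEGREES` table are made up, and the adjustment is a plain gradient step whose size grows with the loss degree. The update uses the standard fact that the gradient of KL(p_t ‖ softmax(z)) with respect to z is softmax(z) − p_t.

```python
import numpy as np

# Hypothetical loss degrees: (exclusive upper bound of the loss range, step multiplier).
# A larger loss degree gives a larger adjustment; the numbers are illustrative only.
LOSS_DEGREES = [(0.05, 0.5), (0.5, 1.0), (float("inf"), 2.0)]

def softmax(z):
    e = np.exp(z - np.max(z))  # shift for numerical stability
    return e / e.sum()

def kl_loss(p_t, p_s, eps=1e-12):
    p_t = np.clip(p_t, eps, 1.0)
    p_s = np.clip(p_s, eps, 1.0)
    return float(np.sum(p_t * (np.log(p_t) - np.log(p_s))))

def loss_degree_multiplier(loss):
    """S206: map the loss value to the (hypothetical) loss degree it belongs to."""
    for upper, multiplier in LOSS_DEGREES:
        if loss < upper:
            return multiplier
    return LOSS_DEGREES[-1][1]

def distill(logits, teacher_probs, base_lr=1.0, loss_threshold=1e-4, max_steps=1000):
    z = np.asarray(logits, dtype=float).copy()
    p_t = np.asarray(teacher_probs, dtype=float)
    for _ in range(max_steps):
        p_s = softmax(z)
        loss = kl_loss(p_t, p_s)
        if loss < loss_threshold:                    # S205: convergence condition met
            break
        lr = base_lr * loss_degree_multiplier(loss)  # S206: larger degree, larger step
        z -= lr * (p_s - p_t)                        # S207: gradient step on the student
    return z, kl_loss(p_t, softmax(z))

z, final_loss = distill(np.zeros(3), [0.7, 0.2, 0.1])
print(final_loss < 1e-4)  # True
```

Lowering the multiplier for small loss values trades some convergence speed for finer adjustments near the optimum, which is one way to read the statement that the adjustment should shrink as the loss degree decreases.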
In this embodiment of the present application, training sample data for training a preset student model is acquired, and the preset student model and a preset teacher model are used to identify the training sample data respectively, yielding a teacher identification result and a student identification result of the training sample data. A weight parameter for adjusting the identification result of the preset student model is obtained from the teacher identification result; the logarithm between the teacher identification result and the student identification result is calculated and weighted by the weight parameter; and the calculated value is used as a loss value to adjust the preset student model. In this way, the weight parameter reasonably distributes the adjustment weights across the different results in the teacher's identification results and the student's predictions for the training sample data, so that identification results with higher probability values receive more attention; this makes the loss value, and therefore the adjusted student model, more accurate. In addition, a balance parameter is introduced to eliminate abnormal results from the teacher and student identification results of the training sample data, preventing such abnormal results from making the loss value inaccurate. Finally, the student model is adjusted according to the loss degree to which the loss value belongs, so that the larger the error of the student model, the larger the adjustment, which further improves the accuracy of student model training. With the method and apparatus, the student model can acquire the data processing capability of the teacher model, and its accuracy is improved.
Fig. 6 is a schematic structural diagram of a model distillation apparatus according to an embodiment of the present application. The model distillation apparatus may be a computer program (including program code) running on a computer device; for example, it may be application software. The apparatus may be used to perform the corresponding steps of the methods provided in the embodiments of the present application. As shown in Fig. 6, the model distillation apparatus may include a first acquisition module 11, an identification module 12, a second obtaining module 13, and an adjusting module 14.
The first acquisition module 11 is configured to acquire training sample data for training a preset student model;
the identification module 12 is configured to identify the training sample data with the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, where the preset student model is trained under the guidance of the preset teacher model;
a second obtaining module 13, configured to obtain, from the teacher identification result, a weight parameter for adjusting the identification result of the preset student model;
and the adjusting module 14 is configured to calculate a logarithm between the teacher identification result and the student identification result, perform a weighting operation on the logarithm by using the weight parameter, and adjust the preset student model by using a calculated numerical value as a loss value.
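The division of the apparatus into modules 11 to 14 can be pictured with a small class. This is a hypothetical layout, not the applicant's software: the student and teacher models are stand-in callables returning probability vectors, the weight parameter is assumed to be the teacher probability, and how the student's parameters are actually changed is delegated to a caller-supplied `update_student` callback because the source leaves it open.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np

@dataclass
class ModelDistillationApparatus:
    """Hypothetical layout of the four modules of Fig. 6; callables stand in for real models."""
    load_samples: Callable[[], Sequence]           # data source for the first acquisition module 11
    student: Callable[[object], np.ndarray]        # preset student model (returns probabilities)
    teacher: Callable[[object], np.ndarray]        # preset teacher model (returns probabilities)
    update_student: Callable[[float], None]        # parameter update itself is left open

    def acquire(self):                              # first acquisition module 11
        return self.load_samples()

    def identify(self, sample):                     # identification module 12
        return self.teacher(sample), self.student(sample)

    def weight_parameters(self, teacher_result):    # second obtaining module 13
        # Assumption: the weight parameter is the teacher probability itself.
        return np.asarray(teacher_result, dtype=float)

    def adjust(self, teacher_result, student_result, eps=1e-12):  # adjusting module 14
        p_t = np.clip(np.asarray(teacher_result, dtype=float), eps, 1.0)
        p_s = np.clip(np.asarray(student_result, dtype=float), eps, 1.0)
        loss = float(np.sum(self.weight_parameters(p_t) * (np.log(p_t) - np.log(p_s))))
        self.update_student(loss)                   # use the calculated value as the loss value
        return loss

    def run(self):
        return [self.adjust(*self.identify(sample)) for sample in self.acquire()]

app = ModelDistillationApparatus(
    load_samples=lambda: ["a", "b"],
    student=lambda s: np.array([0.5, 0.5]),
    teacher=lambda s: np.array([0.5, 0.5]),
    update_student=lambda loss: None,
)
print(app.run())  # [0.0, 0.0]
```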
Wherein there are a plurality of teacher identification results, each representing an identification probability;
the second obtaining module 13 includes:
an acquisition unit configured to acquire a balance parameter for balancing the teacher identification result;
the first grouping unit is configured to group the plurality of obtained teacher identification results, in the identification sequence of the preset teacher model and according to the balance parameter, into a plurality of sequentially arranged teacher identification groups, each of which contains the same number of teacher identification results;
and the first calculating unit is used for calculating the average value of the plurality of teacher identification results in each teacher identification group respectively, and taking the obtained plurality of average values as the weight parameters after the balance processing.
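The grouping and averaging performed by the second obtaining module 13 can be sketched as follows. The source does not define exactly how the balance parameter controls grouping; this sketch assumes it is the number of consecutive teacher identification results per group and, because every group must contain the same number of results, rejects inputs whose length is not a multiple of the balance parameter. The function name is hypothetical.

```python
import numpy as np

def balanced_weight_parameters(teacher_probs, balance_param):
    """Hypothetical sketch: group teacher results in recognition order and average each group.

    Assumption: `balance_param` is the number of results per group.
    """
    p_t = np.asarray(teacher_probs, dtype=float)
    if balance_param <= 0 or p_t.size % balance_param != 0:
        raise ValueError("the number of results must be a positive multiple of balance_param")
    groups = p_t.reshape(-1, balance_param)  # consecutive, equally sized teacher groups
    return groups.mean(axis=1)               # one balanced weight parameter per group

weights = balanced_weight_parameters([0.4, 0.2, 0.2, 0.1, 0.05, 0.05], 2)
print(weights)  # group means: 0.3, 0.15, 0.05
```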
Wherein there are a plurality of student identification results;
the adjusting module 14 includes:
the second grouping unit is used for grouping the obtained student identification results according to the balance parameters in the identification sequence to obtain a plurality of student identification groups which are sequentially arranged, wherein each student identification group in the student identification groups contains the same number of student identification results, and each teacher identification group and each student identification group are in one-to-one correspondence in the identification sequence;
the second calculating unit is used for calculating the average value of the identification results of the students in each student identification group;
the third calculating unit is used for calculating the logarithm of the average value of each student identification group and the logarithm of the average value of the corresponding teacher identification group respectively to obtain a plurality of logarithms after balance processing;
and the first weighting operation unit is configured to perform a weighting operation on the logarithms after the balance processing using the weight parameters after the balance processing.
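The balanced loss computed by the adjusting module 14 can be sketched end to end as follows. As with the apparatus description, this rests on assumptions the source leaves open: the balance parameter is taken to be the group size, the logarithm for each pair of corresponding groups is read as log(teacher group mean) − log(student group mean), and the balanced weight parameters are the teacher group means. Function and argument names are hypothetical.

```python
import numpy as np

def balanced_weighted_log_loss(teacher_probs, student_probs, balance_param, eps=1e-12):
    """Hypothetical sketch of the balanced loss of the adjusting module 14."""
    p_t = np.asarray(teacher_probs, dtype=float)
    p_s = np.asarray(student_probs, dtype=float)
    if p_t.shape != p_s.shape or balance_param <= 0 or p_t.size % balance_param != 0:
        raise ValueError("inputs must match and split into equal groups of balance_param")
    # Teacher and student groups correspond one to one in the identification sequence.
    t_means = p_t.reshape(-1, balance_param).mean(axis=1)  # balanced weight parameters
    s_means = p_s.reshape(-1, balance_param).mean(axis=1)
    # Balanced logarithms: one per pair of corresponding groups (assumed log ratio).
    log_terms = np.log(np.clip(t_means, eps, None)) - np.log(np.clip(s_means, eps, None))
    return float(np.sum(t_means * log_terms))  # weighting operation -> loss value

teacher = [0.4, 0.2, 0.2, 0.1, 0.05, 0.05]
student = [0.2, 0.2, 0.2, 0.2, 0.1, 0.1]
print(balanced_weighted_log_loss(teacher, student, balance_param=2))  # about 0.0438
```

With `balance_param=1` every group holds a single result and the computation reduces to the unbalanced weighted logarithm; larger groups average out individual outlying probabilities before the logarithm is taken, which is one plausible reading of how the balance parameter suppresses abnormal results.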
Wherein, the obtaining unit is specifically configured to:
acquiring the number of teacher identification results in the plurality of teacher identification results;
determining a preset threshold range to which the number of teacher identification results in the plurality of teacher identification results belongs;
and determining a target balance parameter corresponding to the preset threshold range from a balance parameter library, and taking the target balance parameter as a balance parameter for balancing the teacher identification result, wherein the balance parameter library comprises at least one balance parameter and a corresponding relation between each balance parameter in the at least one balance parameter and the preset threshold range.
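The acquisition unit's lookup of a balance parameter can be sketched as a search over preset threshold ranges. The library contents below (the ranges and the parameter values) are invented for illustration; the source only states that the library maps each balance parameter to a preset threshold range of the number of teacher identification results.

```python
# Hypothetical balance parameter library: ((lower, upper) range of the number of
# teacher identification results, lower inclusive and upper exclusive) -> balance parameter.
BALANCE_PARAMETER_LIBRARY = [
    ((0, 100), 1),
    ((100, 1000), 5),
    ((1000, float("inf")), 10),
]

def select_balance_parameter(num_teacher_results, library=BALANCE_PARAMETER_LIBRARY):
    """Return the target balance parameter whose preset threshold range contains the count."""
    for (lower, upper), balance_param in library:
        if lower <= num_teacher_results < upper:
            return balance_param
    raise ValueError(f"no preset threshold range covers {num_teacher_results}")

print(select_balance_parameter(40))    # 1
print(select_balance_parameter(500))   # 5
print(select_balance_parameter(5000))  # 10
```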
Wherein, the adjusting module 14 includes:
a verification unit, configured to verify whether the loss value satisfies a convergence condition;
a first determining unit, configured to determine a loss degree to which the loss value belongs if the loss value does not satisfy the convergence condition;
and the first adjusting unit is used for adjusting parameters in the preset student model according to the loss degree.
Wherein, the verification unit is specifically configured to:
obtain a minimum value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determine that the loss value does not satisfy the convergence condition; or,
verify whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determine that the loss value does not satisfy the convergence condition.
The preset teacher model comprises a plurality of teacher distillation layers, and the preset student model comprises a plurality of student distillation layers;
the adjusting module 14 includes:
the second determining unit is used for determining a teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers;
the fourth calculation unit is used for calculating the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer;
the second weighting operation unit is used for performing weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer by using the weight parameters to obtain the loss value of each student distillation layer;
and the second adjusting unit is used for adjusting the corresponding student distillation layer in the preset student model by respectively adopting the loss value of each student distillation layer.
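The per-layer loss of the adjusting module 14 can be sketched as follows. The source does not say how teacher distillation layers are matched to student distillation layers or what each layer's recognition result looks like; this sketch assumes each layer yields a probability vector and uses a uniform mapping (student layer i to teacher layer (i + 1)·k − 1, with k = number of teacher layers // number of student layers), a common choice in layer-wise distillation but not necessarily the applicant's. The weight parameter is again assumed to be the teacher probability.

```python
import numpy as np

def uniform_layer_map(num_student_layers, num_teacher_layers):
    """Assumed mapping: student layer i -> teacher layer (i + 1) * k - 1."""
    if num_teacher_layers < num_student_layers:
        raise ValueError("expected at least as many teacher layers as student layers")
    k = num_teacher_layers // num_student_layers
    return [(i + 1) * k - 1 for i in range(num_student_layers)]

def layerwise_losses(student_outputs, teacher_outputs, layer_map, eps=1e-12):
    """Return one weighted-logarithm loss per student distillation layer."""
    losses = []
    for s_idx, t_idx in enumerate(layer_map):
        p_s = np.clip(np.asarray(student_outputs[s_idx], dtype=float), eps, 1.0)
        p_t = np.clip(np.asarray(teacher_outputs[t_idx], dtype=float), eps, 1.0)
        # Teacher probabilities weight the log difference for this layer pair.
        losses.append(float(np.sum(p_t * (np.log(p_t) - np.log(p_s)))))
    return losses

layer_map = uniform_layer_map(3, 6)
print(layer_map)  # [1, 3, 5]
teacher_outputs = [[0.5, 0.5]] * 6
student_outputs = [[0.5, 0.5], [0.9, 0.1], [0.5, 0.5]]
# First and last losses are 0.0; the middle one is positive.
print(layerwise_losses(student_outputs, teacher_outputs, layer_map))
```

Each entry of the returned list would then drive the adjustment of its own student distillation layer, as described for the second adjusting unit.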
According to one embodiment of the present application, the steps of the model distillation method shown in Fig. 1 or Fig. 5 may be performed by the modules of the model distillation apparatus shown in Fig. 6. For example, step S101 shown in Fig. 1 may be performed by the first acquisition module 11 in Fig. 6; step S102 by the identification module 12; step S103 by the second obtaining module 13; and step S104 by the adjusting module 14.
Fig. 7 is a schematic structural diagram of a computer device according to an embodiment of the present application. As shown in Fig. 7, the computer device 1000 may include a processor 1001, a network interface 1004, and a memory 1005, and may further include a user interface 1003 and at least one communication bus 1002. The communication bus 1002 is used to implement connection and communication between these components. The user interface 1003 may include a display and a keyboard, and may optionally also include a standard wired interface and a standard wireless interface. The network interface 1004 may optionally include a standard wired interface and a wireless interface (e.g., a Wi-Fi interface). The memory 1005 may be a high-speed RAM or a non-volatile memory (e.g., at least one disk memory), and may optionally be at least one storage device located remotely from the processor 1001. As shown in Fig. 7, the memory 1005, as a computer-readable storage medium, may store an operating system, a network communication module, a user interface module, and a device control application program.
In the computer device 1000 shown in fig. 7, the network interface 1004 may provide a network communication function; the user interface 1003 is an interface for providing a user with input; and the processor 1001 may be used to invoke a device control application stored in the memory 1005 to implement:
acquiring training sample data for training a preset student model;
respectively identifying the training sample data by adopting the preset student model and a preset teacher model to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained by guiding training of the preset teacher model;
acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
calculating the logarithm between the teacher identification result and the student identification result, performing weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated numerical value as a loss value.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
acquiring balance parameters for balancing the teacher identification result;
according to the recognition sequence of the preset teacher model, grouping the obtained plurality of teacher recognition results according to the balance parameters to obtain a plurality of teacher recognition groups which are sequentially arranged, wherein each teacher recognition group in the plurality of teacher recognition groups comprises the same number of teacher recognition results;
and respectively calculating the average value of the plurality of teacher identification results in each teacher identification group, and taking the obtained plurality of average values as the weight parameters after the balance processing.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
according to the identification sequence, grouping the obtained student identification results according to the balance parameters to obtain a plurality of student identification groups which are sequentially arranged, wherein each student identification group in the student identification groups contains the same number of student identification results, and each teacher identification group corresponds to each student identification group one by one according to the identification sequence;
respectively calculating the average value of the identification results of a plurality of students in each student identification group;
respectively calculating the logarithm of the average value of each student identification group and the logarithm of the average value of the corresponding teacher identification group to obtain a plurality of logarithms after balance processing;
and carrying out weighting operation on the weight parameters after the balance processing and the logarithms after the balance processing.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
acquiring the number of teacher identification results in the plurality of teacher identification results;
determining a preset threshold range to which the number of teacher identification results in the plurality of teacher identification results belongs;
and determining a target balance parameter corresponding to the preset threshold range from a balance parameter library, and taking the target balance parameter as a balance parameter for balancing the teacher identification result, wherein the balance parameter library comprises at least one balance parameter and a corresponding relation between each balance parameter in the at least one balance parameter and the preset threshold range.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
verifying whether the loss value satisfies a convergence condition;
if the loss value does not meet the convergence condition, determining the loss degree to which the loss value belongs;
and adjusting parameters in the preset student model according to the loss degree.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
obtaining a minimum value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determining that the loss value does not satisfy the convergence condition; or,
verifying whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
Optionally, the processor 1001 may be configured to invoke a device control application stored in the memory 1005 to implement:
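where the preset teacher model includes a plurality of teacher distillation layers and the preset student model includes a plurality of student distillation layers, calculating the logarithm between the teacher identification result and the student identification result, performing the weighting operation on the logarithm using the weight parameter, and adjusting the preset student model using the calculated value as the loss value includes: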
determining a teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers;
calculating the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer;
performing weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer by using the weight parameters to obtain the loss value of each student distillation layer;
and respectively adopting the loss value of each student distillation layer to adjust the corresponding student distillation layer in the preset student model.
It should be understood that the computer device 1000 described in this embodiment of the present application can perform the model distillation method described in the embodiments corresponding to Fig. 1 and Fig. 5, and can also implement the model distillation apparatus described in the embodiment corresponding to Fig. 6; these details, and the beneficial effects of the same method, are not repeated here.
In addition, an embodiment of the present application further provides a computer-readable storage medium storing a computer program executed by the aforementioned model distillation apparatus. The computer program includes program instructions which, when executed by a processor, perform the model distillation method described in the embodiment corresponding to Fig. 1 or Fig. 5; these details, and the beneficial effects of the same method, are not repeated here. For technical details not disclosed in the embodiments of the computer-readable storage medium of the present application, refer to the description of the method embodiments of the present application.
By way of example, the program instructions described above may be executed on one computer device, or on multiple computer devices located at one site, or distributed across multiple sites and interconnected by a communication network, which may comprise a blockchain network.
It will be understood by those skilled in the art that all or part of the processes of the methods of the embodiments described above can be implemented by a computer program, which can be stored in a computer-readable storage medium, and when executed, can include the processes of the embodiments of the methods described above. The storage medium may be a magnetic disk, an optical disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), or the like.
The above disclosure describes only preferred embodiments of the present application and is not intended to limit its scope; therefore, equivalent variations made according to the claims of the present application still fall within the scope of the present application.

Claims (10)

1. A model distillation method, comprising:
acquiring training sample data for training a preset student model;
respectively identifying the training sample data by adopting the preset student model and a preset teacher model to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is obtained by guiding training of the preset teacher model;
acquiring a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
calculating the logarithm between the teacher identification result and the student identification result, performing weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated numerical value as a loss value.
2. The method according to claim 1, wherein there are a plurality of teacher identification results, each representing an identification probability;
the acquiring, according to the teacher recognition result, a weight parameter for adjusting the recognition result of the preset student model comprises:
acquiring balance parameters for balancing the teacher identification result;
according to the recognition sequence of the preset teacher model, grouping the obtained plurality of teacher recognition results according to the balance parameters to obtain a plurality of teacher recognition groups which are sequentially arranged, wherein each teacher recognition group in the plurality of teacher recognition groups comprises the same number of teacher recognition results;
and respectively calculating the average value of the plurality of teacher identification results in each teacher identification group, and taking the obtained plurality of average values as the weight parameters after the balance processing.
3. The method according to claim 2, wherein there are a plurality of student identification results;
the calculating a logarithm between the teacher identification result and the student identification result, and performing a weighting operation on the logarithm by using the weight parameter includes:
according to the identification sequence, grouping the obtained student identification results according to the balance parameters to obtain a plurality of student identification groups which are sequentially arranged, wherein each student identification group in the student identification groups contains the same number of student identification results, and each teacher identification group corresponds to each student identification group one by one according to the identification sequence;
respectively calculating the average value of the identification results of a plurality of students in each student identification group;
respectively calculating the logarithm of the average value of each student identification group and the logarithm of the average value of the corresponding teacher identification group to obtain a plurality of logarithms after balance processing;
and carrying out weighting operation on the weight parameters after the balance processing and the logarithms after the balance processing.
4. The method according to claim 2, wherein acquiring the balance parameter for balancing the teacher identification result comprises:
acquiring the number of teacher identification results in the plurality of teacher identification results;
determining a preset threshold range to which the number of teacher identification results in the plurality of teacher identification results belongs;
and determining a target balance parameter corresponding to the preset threshold range from a balance parameter library, and taking the target balance parameter as a balance parameter for balancing the teacher identification result, wherein the balance parameter library comprises at least one balance parameter and a corresponding relation between each balance parameter in the at least one balance parameter and the preset threshold range.
5. The method of claim 1, wherein the adjusting the preset student model using the calculated value as a loss value comprises:
verifying whether the loss value satisfies a convergence condition;
if the loss value does not meet the convergence condition, determining the loss degree to which the loss value belongs;
and adjusting parameters in the preset student model according to the loss degree.
6. The method according to claim 5, wherein verifying whether the loss value satisfies the convergence condition comprises:
obtaining a minimum value of the loss function used to calculate the loss value, and if the loss value differs from the minimum value, determining that the loss value does not satisfy the convergence condition; or,
verifying whether the loss value is smaller than a preset loss threshold, and if the loss value is greater than or equal to the preset loss threshold, determining that the loss value does not satisfy the convergence condition.
7. The method according to claim 1, wherein the preset teacher model includes a plurality of teacher distillation layers, and the preset student model includes a plurality of student distillation layers;
the calculating a logarithm between the teacher identification result and the student identification result, performing a weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by using a calculated value as a loss value, includes:
determining a teacher distillation layer corresponding to each student distillation layer in the plurality of student distillation layers;
calculating the logarithm between the student recognition result of each student distillation layer and the teacher recognition result of the corresponding teacher distillation layer;
performing weighting operation on the logarithm between the student identification result of each student distillation layer and the teacher identification result of the corresponding teacher distillation layer by using the weight parameters to obtain the loss value of each student distillation layer;
and respectively adopting the loss value of each student distillation layer to adjust the corresponding student distillation layer in the preset student model.
8. A model distillation apparatus, comprising:
the first acquisition module is used for acquiring training sample data for training a preset student model;
the identification module is used for identifying the training sample data with the preset student model and a preset teacher model respectively, to obtain a teacher identification result and a student identification result of the training sample data, wherein the preset student model is trained under the guidance of the preset teacher model;
the second obtaining module is used for obtaining a weight parameter for adjusting the recognition result of the preset student model according to the teacher recognition result;
and the adjusting module is used for calculating the logarithm between the teacher identification result and the student identification result, performing weighting operation on the logarithm by using the weight parameter, and adjusting the preset student model by using the calculated numerical value as a loss value.
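The four modules of the apparatus can be pictured as four cooperating methods. This is a hypothetical sketch, not the applicant's implementation: the claim does not say how the weight parameter is derived, so it assumes the weight is the teacher's probability for each class, under which the weighted logarithm becomes KL(teacher || student). The class name `DistillationApparatus`, its method names, the linear student, and the learning rate `lr` are all illustrative.

```python
import math
from typing import Callable, Iterable, List, Sequence

Vector = List[float]


def softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class DistillationApparatus:
    """Hypothetical sketch mirroring the four modules of claim 8.

    The teacher is any frozen function features -> logits; the student is a
    linear model (logits = W @ x) so the adjustment step can be written by hand.
    """

    def __init__(self, teacher: Callable[[Sequence[float]], Vector],
                 student_weights: List[Vector], lr: float = 0.5):
        self.teacher = teacher
        self.W = student_weights
        self.lr = lr

    def acquire(self, samples: Iterable[Sequence[float]]) -> List[Vector]:
        # First acquisition module: collect training sample data.
        return [list(x) for x in samples]

    def identify(self, x: Sequence[float]):
        # Identification module: teacher and student results as probabilities.
        student_logits = [sum(w * xi for w, xi in zip(row, x)) for row in self.W]
        return softmax(self.teacher(x)), softmax(student_logits)

    def obtain_weights(self, teacher_probs: Vector) -> Vector:
        # Second acquisition module. Assumption: weight = teacher probability.
        return list(teacher_probs)

    def loss(self, x: Sequence[float]) -> float:
        # Weighted sum of log(teacher_i) - log(student_i).
        t, s = self.identify(x)
        w = self.obtain_weights(t)
        return sum(wi * (math.log(ti) - math.log(si)) for wi, ti, si in zip(w, t, s))

    def adjust(self, x: Sequence[float]) -> float:
        # Adjusting module: one gradient step on the loss, then return the new loss.
        # Because obtain_weights returns t (which sums to 1), the gradient of the
        # loss with respect to student logit j is s_j - t_j.
        t, s = self.identify(x)
        for j, row in enumerate(self.W):
            g = s[j] - t[j]
            for i in range(len(row)):
                row[i] -= self.lr * g * x[i]
        return self.loss(x)


if __name__ == "__main__":
    apparatus = DistillationApparatus(teacher=lambda x: [2.0 * x[0], 0.0],
                                      student_weights=[[0.0, 0.0], [0.0, 0.0]])
    sample = apparatus.acquire([[1.0, 1.0]])[0]
    print(apparatus.loss(sample))
    for _ in range(200):
        after = apparatus.adjust(sample)
    print(after)
```

The linear student is chosen only so that the gradient can be written out explicitly; with a neural student, the same loss would normally be handed to an autodiff framework. If the weight parameter were derived differently from the teacher result, the hand-written gradient in `adjust` would no longer apply.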
9. A computer device, comprising: a processor and a memory;
wherein the memory is configured to store program code and the processor is configured to invoke the program code to perform the method of any of claims 1 to 7.
10. A computer-readable storage medium, characterized in that the computer-readable storage medium stores a computer program comprising program instructions which, when executed by a processor, perform the steps of the method according to any one of claims 1 to 7.
CN202011313330.8A 2020-11-20 2020-11-20 Model distillation method, device, storage medium and equipment Pending CN112465138A (en)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202011313330.8A CN112465138A (en) 2020-11-20 2020-11-20 Model distillation method, device, storage medium and equipment
PCT/CN2021/096649 WO2022105173A1 (en) 2020-11-20 2021-05-28 Model distillation method and apparatus, and storage medium and device

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011313330.8A CN112465138A (en) 2020-11-20 2020-11-20 Model distillation method, device, storage medium and equipment

Publications (1)

Publication Number Publication Date
CN112465138A (en) 2021-03-09

Family

ID=74798380

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011313330.8A Pending CN112465138A (en) 2020-11-20 2020-11-20 Model distillation method, device, storage medium and equipment

Country Status (2)

Country Link
CN (1) CN112465138A (en)
WO (1) WO2022105173A1 (en)

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112990296A (en) * 2021-03-10 2021-06-18 中科人工智能创新技术研究院(青岛)有限公司 Image-text matching model compression and acceleration method and system based on orthogonal similarity distillation
CN113239176A (en) * 2021-06-21 2021-08-10 中国平安人寿保险股份有限公司 Semantic matching model training method, device, equipment and storage medium
CN113807214A (en) * 2021-08-31 2021-12-17 中国科学院上海微系统与信息技术研究所 Small target face recognition method based on deit attached network knowledge distillation
CN114065834A (en) * 2021-09-30 2022-02-18 中国科学院深圳先进技术研究院 Model training method, terminal device and computer storage medium
WO2022105173A1 (en) * 2020-11-20 2022-05-27 平安科技(深圳)有限公司 Model distillation method and apparatus, and storage medium and device
CN114565759A (en) * 2022-02-22 2022-05-31 北京百度网讯科技有限公司 Image semantic segmentation model optimization method and device, electronic equipment and storage medium
CN115019060A (en) * 2022-07-12 2022-09-06 北京百度网讯科技有限公司 Target recognition method, and training method and device of target recognition model
CN115049074A (en) * 2022-06-20 2022-09-13 腾讯科技(深圳)有限公司 Model training method and device, electronic equipment and storage medium
US11526774B2 (en) * 2020-12-15 2022-12-13 Zhejiang Lab Method for automatically compressing multitask-oriented pre-trained language model and platform thereof
CN116348892A (en) * 2021-03-16 2023-06-27 墨芯国际有限公司 System and method for knowledge-preserving neural network pruning
CN114065834B (en) * 2021-09-30 2024-07-02 中国科学院深圳先进技术研究院 Model training method, terminal equipment and computer storage medium

Families Citing this family (2)

Publication number Priority date Publication date Assignee Title
CN115099988A (en) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 Model training method, data processing method, device and computer medium
CN115170455B (en) * 2022-08-17 2023-02-07 荣耀终端有限公司 Image processing method and related device

Family Cites Families (5)

Publication number Priority date Publication date Assignee Title
CN111105008A (en) * 2018-10-29 2020-05-05 富士通株式会社 Model training method, data recognition method and data recognition device
EP3736749A1 (en) * 2019-05-09 2020-11-11 Siemens Aktiengesellschaft Method and device for controlling a device using a dataset
CN110852426B (en) * 2019-11-19 2023-03-24 成都晓多科技有限公司 Pre-training model integration acceleration method and device based on knowledge distillation
CN110909815B (en) * 2019-11-29 2022-08-12 深圳市商汤科技有限公司 Neural network training method, neural network training device, neural network processing device, neural network training device, image processing device and electronic equipment
CN112465138A (en) * 2020-11-20 2021-03-09 平安科技(深圳)有限公司 Model distillation method, device, storage medium and equipment

Cited By (12)

Publication number Priority date Publication date Assignee Title
WO2022105173A1 (en) * 2020-11-20 2022-05-27 平安科技(深圳)有限公司 Model distillation method and apparatus, and storage medium and device
US11526774B2 (en) * 2020-12-15 2022-12-13 Zhejiang Lab Method for automatically compressing multitask-oriented pre-trained language model and platform thereof
CN112990296A (en) * 2021-03-10 2021-06-18 中科人工智能创新技术研究院(青岛)有限公司 Image-text matching model compression and acceleration method and system based on orthogonal similarity distillation
CN116348892A (en) * 2021-03-16 2023-06-27 墨芯国际有限公司 System and method for knowledge-preserving neural network pruning
CN113239176A (en) * 2021-06-21 2021-08-10 中国平安人寿保险股份有限公司 Semantic matching model training method, device, equipment and storage medium
CN113807214A (en) * 2021-08-31 2021-12-17 中国科学院上海微系统与信息技术研究所 Small target face recognition method based on deit attached network knowledge distillation
CN113807214B (en) * 2021-08-31 2024-01-05 中国科学院上海微系统与信息技术研究所 Small target face recognition method based on deit affiliated network knowledge distillation
CN114065834A (en) * 2021-09-30 2022-02-18 中国科学院深圳先进技术研究院 Model training method, terminal device and computer storage medium
CN114065834B (en) * 2021-09-30 2024-07-02 中国科学院深圳先进技术研究院 Model training method, terminal equipment and computer storage medium
CN114565759A (en) * 2022-02-22 2022-05-31 北京百度网讯科技有限公司 Image semantic segmentation model optimization method and device, electronic equipment and storage medium
CN115049074A (en) * 2022-06-20 2022-09-13 腾讯科技(深圳)有限公司 Model training method and device, electronic equipment and storage medium
CN115019060A (en) * 2022-07-12 2022-09-06 北京百度网讯科技有限公司 Target recognition method, and training method and device of target recognition model

Also Published As

Publication number Publication date
WO2022105173A1 (en) 2022-05-27

Similar Documents

Publication Publication Date Title
CN112465138A (en) Model distillation method, device, storage medium and equipment
KR102213478B1 (en) A system for tracking user knowledge based on artificial intelligence learning and method thereof
WO2022022152A1 (en) Video clip positioning method and apparatus, and computer device and storage medium
WO2020182122A1 (en) Text matching model generation method and device
US20210081503A1 (en) Utilizing a gated self-attention memory network model for predicting a candidate answer match to a query
WO2019157251A1 (en) Neural network compression
US10685012B2 (en) Generating feature embeddings from a co-occurrence matrix
CN113268609A (en) Dialog content recommendation method, device, equipment and medium based on knowledge graph
CN115222566A (en) Learning method and system for international finance and finance metrology teaching
CN117033802B (en) Teaching subject pushing method and system based on AI assistance
KR102500782B1 (en) Method, apparatus and computer-readable recording medium for extracting customized question for each difficulty level based on the student's learning level
CN112740132A (en) Scoring prediction for short answer questions
CN113435208A (en) Student model training method and device and electronic equipment
CN112785005A (en) Multi-target task assistant decision-making method and device, computer equipment and medium
CN112966701A (en) Method and device for classifying objects
CN110929532B (en) Data processing method, device, equipment and storage medium
KR20220098698A (en) Learning content recommendation system that predicts the user's correct answer probability using collaborative filtering based on latent factors and operation method thereof
CN116684330A (en) Traffic prediction method, device, equipment and storage medium based on artificial intelligence
WO2021027257A1 (en) Computer-executed method and device using neural network for language processing
CN112434872A (en) Hotel yield prediction method, system, equipment and storage medium
CN115392594B (en) Electrical load model training method based on neural network and feature screening
CN116361655A (en) Model training method, standard problem prediction method, device, equipment and medium
CN115062769A (en) Knowledge distillation-based model training method, device, equipment and storage medium
CN113010687B (en) Exercise label prediction method and device, storage medium and computer equipment
CN114913871A (en) Target object classification method, system, electronic device and storage medium

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20210309