CN114298287A - Knowledge distillation-based prediction method and device, electronic equipment and storage medium - Google Patents
- Publication number
- CN114298287A CN202210028926.6A
- Authority
- CN
- China
- Prior art keywords
- neural network
- classifier
- self
- model
- network model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Abstract
The embodiments provide a knowledge distillation-based prediction method and device, electronic equipment and a storage medium, belonging to the technical field of machine learning. The method comprises the following steps: performing self-attention distillation on a second neural network model according to the Transformer layer of a first neural network model; training a trunk classifier of a trunk network according to downstream task data, and updating trunk parameters of the trunk network; performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain a branch classifier; training the branch classifier according to unlabeled task data to obtain a target classification model; and, based on a sample-wise adaptive mechanism, performing adaptive inference on the to-be-classified data input to the target classification model to obtain a prediction result.
Description
Technical Field
The invention relates to the technical field of machine learning, in particular to a prediction method and device based on knowledge distillation, electronic equipment and a storage medium.
Background
Large-scale pre-trained models such as BERT have achieved good results in natural language processing tasks and are widely accepted by the industry. However, such models have a very large number of parameters, which poses great challenges for fine-tuning and online deployment: the models are slow during fine-tuning and deployment, the computation cost is high, and real-time applications suffer from considerable latency and capacity limits. Model compression is therefore of great significance.
Knowledge distillation, one of the three major methods of model compression, is widely accepted and applied in academia and industry. However, most knowledge distillation methods need to consider the layer-by-layer structure and knowledge of the teacher model and the student model. Because the number of layers of the student model varies, the optimal mapping layer is very difficult to find, which increases the uncertainty of the distillation effect and leaves the prediction effect of the model less than ideal.
Disclosure of Invention
A primary object of the embodiments of the present disclosure is to provide a knowledge distillation-based prediction method and device, an electronic device, and a storage medium, so as to reduce the model size and improve the accuracy of model predictions.
To achieve the above object, a first aspect of the embodiments of the present disclosure provides a prediction method based on knowledge distillation, including:
acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a Transformer layer;
performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model;
training a trunk classifier of a trunk network according to downstream task data, and updating trunk parameters of the trunk network to fine-tune the first neural network model; wherein the first neural network model serves as the trunk network;
performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain a branch classifier; wherein the second neural network model serves as a branch network;
training the branch classifier according to unlabeled task data to obtain a target classification model;
and performing adaptive inference on the data to be classified input into the target classification model based on a sample-wise adaptive mechanism to obtain a prediction result.
In some embodiments, the pre-training step comprises:
acquiring a first self-attention module of the Transformer layer of the first neural network model;
generating, by the first self-attention module, a query vector, a key vector, and a value vector;
and performing a dot product operation on the query vector, the key vector and the value vector to guide training of the second neural network model, so as to perform self-attention distillation on the second neural network model.
In some embodiments, said performing a dot product operation on the query vector, the key vector, and the value vector to guide training of the second neural network model, so as to perform self-attention distillation on the second neural network model, comprises:
performing a first operation between the query vector and the key vector to obtain a first operation result;
performing a second operation on the value vector to obtain a second operation result;
measuring a first probability distribution of the first operation result and a second probability distribution of the second operation result according to the KL divergence;
calculating a fitting loss from the first probability distribution and the second probability distribution;
performing self-attention distillation on the second neural network model based on the fitting loss.
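The sub-steps above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the claimed implementation: the first operation is assumed to be a scaled query-key dot product followed by softmax, the second operation is assumed to be a scaled value-value dot product followed by softmax, the KL divergence is taken from the first (teacher) model to the second (student) model, and the fitting loss is assumed to be the sum of the two terms. All function names are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_rows(a, b):
    """Row-wise softmax of (a @ b^T) / sqrt(d): one distribution per token."""
    d = len(a[0])
    return [softmax([sum(x * y for x, y in zip(ra, rb)) / math.sqrt(d) for rb in b])
            for ra in a]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def mean_kl(teacher_rows, student_rows):
    """Average KL divergence over all token positions."""
    pairs = zip(teacher_rows, student_rows)
    return sum(kl_divergence(p, q) for p, q in pairs) / len(teacher_rows)

def self_attention_distillation_loss(teacher_qkv, student_qkv):
    """Fitting loss = KL on query-key distributions + KL on value relations."""
    tq, tk, tv = teacher_qkv
    sq, sk, sv = student_qkv
    attention_loss = mean_kl(scaled_dot_rows(tq, tk), scaled_dot_rows(sq, sk))
    value_loss = mean_kl(scaled_dot_rows(tv, tv), scaled_dot_rows(sv, sv))
    return attention_loss + value_loss
```

Because both terms compare matrices of shape (sequence length x sequence length), the sketch does not require the two models to share a hidden size or a layer count.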
In some embodiments, the training the trunk classifier of the first neural network model according to the downstream task data and updating trunk parameters of the trunk network to fine tune the first neural network model includes:
acquiring a trunk classifier of the first neural network model;
training the trunk classifier according to the downstream task data, and updating trunk parameters;
and adjusting the network structure of the trunk network with the updated trunk parameters.
In some embodiments, the self-distilling the fine-tuned trunk classifier of the first neural network model to obtain the branch classifier of the second neural network model includes:
freezing the trunk parameters;
and performing self-distillation on the trunk classifier whose trunk parameters are frozen to obtain the branch classifier.
In some embodiments, the training the branch classifier according to the unlabeled task data to obtain a target classification model includes:
acquiring the unlabeled task data;
and inputting the unlabeled task data into the second neural network model corresponding to the branch classifier, and training the branch classifier to obtain the target classification model.
In some embodiments, the self-distilling step comprises:
measuring the probability distribution distance of the trunk classifier and the branch classifier according to the KL divergence;
calculating target loss according to the probability distribution distance;
updating the second neural network model in accordance with the target loss.
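The self-distillation step above can be illustrated with a short sketch. It assumes that each branch classifier and the trunk classifier output logits over the same classes, that the distance is KL(branch || trunk), and that the target loss is the sum over all branch classifiers; the direction of the divergence and the function names are assumptions made for illustration. No ground-truth label appears in the loss, which is consistent with training the branch classifier on unlabeled task data.

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def self_distillation_loss(trunk_logits, branch_logits_per_layer):
    """Sum of KL distances between each branch output and the frozen trunk output."""
    teacher = softmax(trunk_logits)  # trunk parameters are frozen: a fixed target
    return sum(kl_divergence(softmax(branch), teacher)
               for branch in branch_logits_per_layer)
```

In an actual training loop, only the branch classifier parameters would receive gradients from this loss.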
To achieve the above object, a second aspect of the present disclosure provides a knowledge distillation-based prediction device, including:
a model acquisition module, used for acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a Transformer layer;
a pre-training module, used for performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model;
a fine-tuning module, used for training a trunk classifier of a trunk network according to downstream task data and updating trunk parameters of the trunk network so as to fine-tune the first neural network model; wherein the first neural network model serves as the trunk network;
a self-distillation module, used for performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain a branch classifier; wherein the second neural network model serves as a branch network;
a model training module, used for training the branch classifier according to unlabeled task data to obtain a target classification model;
and an adaptive mechanism module, used for performing adaptive inference on the data to be classified input into the target classification model to obtain a prediction result.
To achieve the above object, a third aspect of the present disclosure provides an electronic device, including:
at least one memory;
at least one processor;
at least one program;
wherein the at least one program is stored in the memory, and the processor executes the at least one program to implement the method of the present disclosure as described in the above first aspect.
To achieve the above object, a fourth aspect of the present disclosure proposes a storage medium that is a computer-readable storage medium storing computer-executable instructions for causing a computer to perform:
a method as described in the first aspect above.
According to the knowledge distillation-based prediction method and device, electronic equipment and storage medium of the embodiments of the present disclosure, self-attention distillation is performed on the second neural network model according to the Transformer layer of the first neural network model in the self-attention distillation stage; in the fine-tuning stage, the trunk classifier of the first neural network model is trained according to downstream task data and the trunk parameters are updated; in the self-distillation stage, the fine-tuned trunk classifier of the first neural network model is self-distilled to obtain the branch classifier of the second neural network model, and the branch classifier is trained according to unlabeled task data to obtain the target classification model; and, based on the sample-wise adaptive mechanism, adaptive inference is performed on the data to be classified input into the target classification model to obtain a prediction result. Compared with the first neural network model, the target classification model is compressed, so that it is smaller than the first neural network model while its processing precision is close to that of the first neural network model; compared with the second neural network model, the processing precision of the target classification model is higher. The target classification model obtained by the embodiments of the present disclosure is therefore reduced in size while the processing precision is ensured. Since a classifier costs far less than a Transformer, adding the classifier in the fine-tuning stage saves cost. In addition, by introducing the sample-wise adaptive mechanism, the amount of computation for each sample can be adjusted adaptively at prediction time, reducing the number of calculation steps.
Drawings
Fig. 1 is a flow chart of a predictive method based on knowledge distillation provided by an embodiment of the disclosure.
Fig. 2 is a flowchart of step 102 in fig. 1.
Fig. 3 is a flowchart of step 203 in fig. 2.
Fig. 4 is a diagram illustrating step 102 of Fig. 1 in a specific application scenario.
Fig. 5 is a flowchart of step 103 in fig. 1.
Fig. 6 is a flowchart of step 104 in fig. 1.
Fig. 7 is a schematic diagram of a hardware structure of an electronic device provided in an embodiment of the present disclosure.
Detailed Description
In order to make the objects, technical solutions and advantages of the present invention more apparent, the present invention is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are merely illustrative of the present application and are not intended to limit the present application.
It should be noted that although functional blocks are partitioned in a schematic diagram of an apparatus and a logical order is shown in a flowchart, in some cases, the steps shown or described may be performed in a different order than the partitioning of blocks in the apparatus or the order in the flowchart. The terms first, second and the like in the description and in the claims, and the drawings described above, are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order.
Unless defined otherwise, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this invention belongs. The terminology used herein is for the purpose of describing embodiments of the invention only and is not intended to be limiting of the invention.
First, several terms referred to in the present application are explained:
Artificial Intelligence (AI): a new technical science that studies and develops theories, methods, technologies and application systems for simulating, extending and expanding human intelligence. Artificial intelligence is a branch of computer science that attempts to understand the essence of intelligence and to produce new intelligent machines that can react in a manner similar to human intelligence; research in this field includes robotics, language recognition, image recognition, natural language processing, expert systems, and the like. Artificial intelligence can simulate the information processes of human consciousness and thinking. Artificial intelligence is also a theory, method, technique and application system that uses a digital computer, or a machine controlled by a digital computer, to simulate, extend and expand human intelligence, perceive the environment, acquire knowledge, and use the knowledge to obtain the best results.
Natural Language Processing (NLP): NLP uses computers to process, understand and use human languages (such as Chinese and English). It is a branch of artificial intelligence and an interdisciplinary field between computer science and linguistics, also commonly called computational linguistics. Natural language processing includes syntactic parsing, semantic analysis, discourse understanding, and the like. It is commonly used in technical fields such as machine translation, recognition of handwritten and printed characters, speech recognition and text-to-speech conversion, information retrieval, information extraction and filtering, text classification and clustering, and public opinion analysis and opinion mining, and it involves data mining, machine learning, knowledge acquisition, knowledge engineering, artificial intelligence research, and linguistic research related to language computation.
Model compression: model compression is mostly applied to complex deep models. Related techniques include low-rank decomposition (low-rank approximation/factorization), pruning, knowledge distillation, network quantization and weight sharing, wherein:
(1) Low-rank decomposition: if the weight matrix of the original network is regarded as a full-rank matrix, several low-rank matrices can be used to approximate it, thereby simplifying the model (compare personalized recommendation based on collaborative filtering in recommendation algorithms). The original dense full-rank matrix can be represented as a combination of several low-rank matrices, each of which can in turn be decomposed into a product of small matrices. For two-dimensional matrix operations, SVD is a very good simplification method. At present, higher-order low-rank decomposition algorithms mainly include CP decomposition, Tucker decomposition, Tensor Train decomposition and Block Term decomposition, which are mainly used for acceleration and compression.
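As a small illustration of why a low-rank factorization saves storage, the sketch below stores an m x n matrix as the product of an m x r factor and an r x n factor and compares parameter counts. The matrix sizes and the rank are hypothetical values chosen for the example, and the factors are given directly rather than computed by SVD.

```python
def factored_parameter_count(rows, cols, rank):
    """Parameters needed to store U (rows x rank) and V (rank x cols)."""
    return rank * (rows + cols)

def matmul(u, v):
    """Plain-Python matrix product, used to rebuild the approximated matrix."""
    return [[sum(u[i][k] * v[k][j] for k in range(len(v)))
             for j in range(len(v[0]))]
            for i in range(len(u))]

# A 3x3 matrix of rank 1 is stored exactly by a 3x1 and a 1x3 factor.
u = [[1.0], [2.0], [3.0]]
v = [[4.0, 5.0, 6.0]]
rebuilt = matmul(u, v)

# Hypothetical sizes: a 768x768 weight matrix approximated with rank 64.
full = 768 * 768
factored = factored_parameter_count(768, 768, 64)
print(full, factored, full / factored)
```

The factored form needs r(m + n) parameters instead of mn, so the saving grows as the chosen rank shrinks, at the cost of approximation error for matrices that are not truly low-rank.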
(2) Pruning: a model is built by connecting many floating-point neurons, and each layer passes information down according to the weights of its neurons. In each layer, however, some nodes have very small weights and contribute very little to the information carried by the model. If the neurons with smaller weights are deleted, the model becomes smaller with only a small effect on precision. Each layer removes nodes with small values, but the granularity of pruning needs to be considered: for example, 5 neurons or 3 neurons may be removed from each layer, or pruning may be driven by L1/L2 regularization. In practice, pruning is an iterative process, generally called iterative pruning: prune, train, repeat. By introducing an AutoML mechanism, the set of pruning candidates can be explored by neural architecture search (NAS, somewhat similar to grid search in machine learning), and pruning, verification and iteration are then performed automatically. At present, the performance of distillation, in terms of both compression ratio and post-distillation performance, still needs to be improved. Open problems and research trends include: exploring different forms of knowledge and removing the softmax limitation; choosing which intermediate feature layers to use; designing the loss function; selecting the data set used to train the student model; designing the student model; and combining distillation with compact network design and other compression methods. Model compression can be divided into two kinds: compressing an existing network and constructing a new small network. Pruning, quantization and low-rank decomposition belong to the first kind and distillation to the second; a better approach is to select a small and compact network, i.e. a compact network design, at the initial stage of model construction.
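One pruning step of the kind described above can be sketched as magnitude pruning: the weights with the smallest absolute values are set to zero. This is a minimal illustration on a flat list of weights; the function name and the pruning fraction are hypothetical, and a real iterative pruning loop would retrain between successive calls.

```python
def prune_smallest(weights, fraction):
    """Zero out the given fraction of weights with the smallest magnitude."""
    k = int(len(weights) * fraction)  # number of weights to remove
    if k == 0:
        return list(weights)
    # Indices ordered from smallest to largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    dropped = set(order[:k])
    return [0.0 if i in dropped else w for i, w in enumerate(weights)]

layer = [0.9, -0.01, 0.5, 0.02, -0.7]
print(prune_smallest(layer, 0.4))
```

Pruning by magnitude keeps large negative weights such as -0.7, since it is the absolute value, not the sign, that indicates how much a connection contributes.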
(3) Network quantization: the parameters of a neural network model are generally represented by 32-bit floating-point numbers, but such high precision is often not needed in practice, and quantization can reduce the space required by each weight, for example by representing the values originally stored in 32 bits with the integers 0 to 255, at the cost of some precision. In addition, SGD (Stochastic Gradient Descent) requires a precision of only 6 to 8 bits, so a reasonable quantized network can reduce the storage volume of the model while maintaining precision. According to the quantization method, network quantization can be roughly divided into binary quantization, ternary quantization and multi-value quantization, and network quantization needs to solve three basic problems. A lower-bit value type means a smaller representable range and sparser values, so numerical precision is lost when the data are quantized.
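The mapping from floating-point weights to the integers 0 to 255 mentioned above can be sketched with uniform (affine) quantization. This is one common scheme chosen for illustration, not necessarily the one intended by the source; the function names are hypothetical.

```python
def quantize(weights, levels=256):
    """Map floats onto integer codes 0..levels-1 with a uniform step."""
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / (levels - 1) or 1.0  # avoid a zero step for constant input
    codes = [round((w - lo) / scale) for w in weights]
    return codes, scale, lo

def dequantize(codes, scale, lo):
    """Recover approximate floats from the integer codes."""
    return [c * scale + lo for c in codes]

weights = [-1.0, -0.5, 0.0, 0.25, 1.0]
codes, scale, lo = quantize(weights)
restored = dequantize(codes, scale, lo)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, max_error)
```

Each weight now needs 8 bits instead of 32, and the reconstruction error of any single weight is bounded by half a quantization step.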
(4) Knowledge distillation: a large, complex network model is used to guide the training and learning of a small, simplified network model. In knowledge distillation, the large network is the Teacher network (complex, but with excellent inference performance) and the small network is the Student network (simplified and of low complexity). The soft targets produced by the Teacher network are used as part of the total loss to guide the training of the Student network and thereby achieve knowledge transfer. Specifically, two objectives are learned: the output of the small network should be close to the ground truth, and the output of the small network should be close to the output of the large network. Distillation adopts transfer learning: the output of a pre-trained complex model (the Teacher model) is used as a supervision signal to train another, simpler network, which is called the Student model.
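The two objectives described above are often combined into a single loss. The sketch below follows the widely used formulation with temperature-softened outputs, which is an assumption here rather than something the source specifies: a cross-entropy term against the true label plus a KL term against the Teacher's soft targets. The weighting `alpha` and the `temperature` are hypothetical hyperparameters.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, true_label,
                      alpha=0.5, temperature=2.0):
    """alpha * hard-label loss + (1 - alpha) * T^2 * soft-target loss."""
    # Objective 1: the Student output should be close to the ground truth.
    hard = -math.log(softmax(student_logits)[true_label])
    # Objective 2: the Student output should be close to the Teacher output.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft = sum(t * math.log(t / s)
               for t, s in zip(p_teacher, p_student) if t > 0)
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * soft
```

When the Student reproduces the Teacher's logits exactly, the soft-target term vanishes and only the hard-label term remains.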
(5) Weight sharing: weight sharing means that parameters used in building the model are not confined to one location but are reused many times throughout the model. If shareable weight coefficients can be found by clustering, so that the weights within each category share a single value, the model can be compressed.
BERT (Bidirectional Encoder Representations from Transformers) model: the BERT model further improves the generalization capability of word vector models, fully describes character-level, word-level, sentence-level and even inter-sentence relational features, and is built on the Transformer. BERT uses three embeddings, namely Token Embeddings, Segment Embeddings and Position Embeddings. Token Embeddings are word vectors; the first token is the CLS mark, which can be used for subsequent classification tasks. Segment Embeddings are used to distinguish two sentences, because pre-training performs not only language modeling (LM) but also a classification task that takes two sentences as input. Position Embeddings are not the trigonometric functions used in the Transformer but are learned during BERT training: BERT directly trains a position embedding to retain position information, randomly initializing a vector at each position and adding it to model training, which finally yields an embedding containing the position information. For combining the position embedding and the word embedding, BERT chooses direct splicing.
Fine-tuning: a Classifier layer (the Teacher Classifier) is attached after the BERT model to fine-tune the classification model.
Attention distillation: used to transfer the attention feature maps (attention maps) learned by a large network to a small network. Specifically, on the basis of knowledge distillation, the features of the intermediate layers of the network are transferred so that the intermediate attention feature maps of the large and small networks are similar.
Self-attention distillation (Self Attention Distillation): the large network is discarded, and the attention feature maps learned by different layers within the small network are transferred among those layers. Attention feature maps are of two types: activation-based attention maps and gradient-based attention maps, distinguished by whether an activation function is used. Self-attention is the way the Transformer incorporates the "understanding" of other related words into the word currently being processed. Each layer of the BERT encoder is followed by a Student Classifier layer; the parameters from the fine-tuning stage are frozen, the Student Classifier layers learn the distribution of the Teacher Classifier, and the loss is measured by KL divergence.
Inference: the Student Classifier of each layer predicts the sample, and the uncertainty of the prediction is measured by entropy: the larger the entropy, the larger the uncertainty. A threshold on this uncertainty, denoted speed, is proportional to the inference speed: the larger the speed threshold, the faster the inference. Some samples can be predicted after only a few layers, and in the worst case the result is predicted after all layers, so the amount of computation in the inference stage is reduced and inference is accelerated.
Transformer model: like the Attention model, the Transformer model adopts an encoder-decoder architecture, but its structure is more complex; it generally comprises several stacked encoder layers and several stacked decoder layers. Each encoder comprises a self-attention layer and a feedforward neural network layer, where self-attention helps the current node attend not only to the current word, so that contextual semantics can be obtained. Each decoder comprises the self-attention layer and the feedforward neural network layer and, between them, an additional attention layer that helps the current node obtain the important content that currently needs attention.
The Transformer layer: the neural network comprises an embedding layer and at least one Transformer layer, for example N Transformer layers (N is an integer greater than 0). The embedding layer comprises an input embedding layer and a positional encoding layer: in the input embedding layer, word embedding is performed on each word in the current input to obtain a word embedding vector for each word; in the positional encoding layer, the position of each word in the current input is obtained and a position vector is generated for it. Each Transformer layer comprises, adjacent in sequence, an attention layer, an add & norm (addition and normalization) layer, a feed forward layer and another add & norm layer. At the embedding layer, the current input (which may be text, such as a passage or a sentence, in Chinese, English or another language) is embedded to obtain a plurality of feature vectors; after the current input is obtained, the embedding layer may embed each word in it to obtain a feature vector for each word. In the attention layer, P input vectors are acquired from the layer preceding the Transformer layer; taking any first input vector of the P input vectors as the center, an intermediate vector corresponding to the first input vector is obtained based on the degree of association between the first input vector and each input vector within a preset attention window, thereby determining P intermediate vectors corresponding to the P input vectors. At the pooling layer, the P intermediate vectors are combined into Q output vectors, and the output vectors produced by the last of the at least one Transformer layer are used as the feature representation of the current input.
Sample-wise adaptive mechanism: the principle is that the amount of computation for each sample is adjusted adaptively: easy samples can be predicted after one or two layers, while difficult samples need to pass through all layers. The sample label is predicted after each Transformer layer, and if the confidence of the prediction for a sample is already high, the computation does not need to continue.
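The early-exit behaviour described above can be sketched as follows. The sketch assumes that uncertainty is the entropy of a classifier's output normalized to the range 0 to 1, and that a sample exits as soon as its uncertainty falls below the speed threshold; the normalization, the function names and the example distributions are assumptions for illustration. The per-layer outputs are passed as an iterable so that, when a generator is supplied, later layers are never evaluated after an early exit.

```python
import math

def normalized_entropy(probs):
    """Entropy divided by log(number of classes): 0 = certain, 1 = uniform."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return entropy / math.log(len(probs))

def adaptive_predict(layer_outputs, speed):
    """Return (label, layers_used), stopping at the first confident layer.

    layer_outputs yields one class distribution per layer, in layer order;
    the last one is taken to come from the trunk classifier.
    """
    label, depth = None, 0
    for depth, probs in enumerate(layer_outputs, start=1):
        label = max(range(len(probs)), key=probs.__getitem__)
        if normalized_entropy(probs) < speed:
            break  # confident enough: skip the remaining layers
    return label, depth

easy = [[0.98, 0.02], [0.99, 0.01]]
hard = [[0.5, 0.5], [0.55, 0.45], [0.3, 0.7]]
print(adaptive_predict(easy, speed=0.5))
print(adaptive_predict(hard, speed=0.5))
```

Raising `speed` lets more samples exit at shallow layers, which trades some accuracy for lower computation.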
Large-scale pre-trained models such as BERT have achieved good results in natural language processing tasks and are widely accepted in industry. However, such models usually have a huge number of parameters (for example, BERT-base has 110 million parameters and BERT-large has 340 million), which poses major challenges for fine-tuning and online deployment: fine-tuning and inference are slow, the computation cost is high, and real-time applications face significant latency and capacity limits. Model compression is therefore of great significance.
Knowledge distillation, one of the three major model compression methods, is widely accepted and applied in academia and industry. However, most knowledge distillation methods need to consider the layer-by-layer structure and knowledge of the teacher model and the student model, and student models vary in depth, so the optimal mapping layer is very difficult to find; this increases the uncertainty of the distillation effect and makes the prediction performance of the model less than ideal. For example, TinyBERT and MobileBERT require learning the self-attention distribution and output of each Transformer layer, and BERT-PKD distills the student model from multiple intermediate layers of the teacher model, which makes the student model structure relatively fixed, while it remains uncertain whether the teacher layer used for distillation is the best mapping layer.
The embodiments of the application can acquire and process related data based on artificial intelligence technology. Artificial intelligence (AI) is the theory, method, technology, and application system that uses a digital computer, or a machine controlled by a digital computer, to simulate, extend, and expand human intelligence, perceive the environment, acquire knowledge, and use that knowledge to obtain the best result.
Artificial intelligence infrastructure generally includes technologies such as sensors, dedicated artificial intelligence chips, cloud computing, distributed storage, big data processing, operation/interaction systems, and mechatronics. Artificial intelligence software technology mainly includes computer vision, robotics, biometric recognition, speech processing, natural language processing, and machine learning/deep learning.
Based on the above, the embodiments of the present disclosure provide a prediction method and apparatus, an electronic device, and a storage medium based on knowledge distillation, which can reduce the model size and improve the accuracy of the model's predictions while maintaining processing accuracy.
The prediction method and apparatus, electronic device, and storage medium based on knowledge distillation provided by the embodiments of the present disclosure are described in detail through the following embodiments, beginning with the prediction method based on knowledge distillation.
The embodiment of the disclosure provides a prediction method based on knowledge distillation, which relates to the technical field of machine learning. The knowledge distillation-based prediction method provided by the embodiment of the disclosure can be applied to a terminal, to a server side, or to software running in the terminal or the server side. In some embodiments, the terminal may be a smartphone, tablet, laptop, desktop computer, smart watch, or the like; the server side may be configured as an independent physical server, as a server cluster or distributed system formed by a plurality of physical servers, or as a cloud server providing basic cloud computing services such as cloud services, cloud databases, cloud computing, cloud functions, cloud storage, network services, cloud communication, middleware services, domain name services, security services, CDN (content delivery network), and big data and artificial intelligence platforms; the software may be an application or the like that implements the prediction method based on knowledge distillation, but is not limited to the above forms.
Fig. 1 is an optional flowchart of the prediction method based on knowledge distillation provided by an embodiment of the present disclosure; the method in fig. 1 may include, but is not limited to, steps 101 to 106.
step 101, acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a Transformer layer;
a pre-training step 102, performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model;
a fine-tuning step 103, training a trunk classifier of the trunk network according to downstream task data, and updating trunk parameters of the trunk network; wherein the first neural network model serves as the trunk network;
a self-distillation step 104, self-distilling the trunk classifier to obtain a branch classifier of a branch network; wherein the second neural network model serves as the branch network;
step 105, training the branch classifier according to label-free task data to obtain a target classification model;
and an adaptive inference step 106, performing adaptive inference, based on a sample-adaptive mechanism, on the data to be classified that is input into the target classification model to obtain a prediction result.
In some embodiments, the first neural network model is a Teacher model (Teacher model) and the second neural network model is a Student model (Student model).
Compared with the first neural network model, the target classification model obtained through steps 101 to 105 is compressed: it is smaller than the first neural network model, but its processing accuracy is close to that of the first neural network model. Compared with the second neural network model, the processing accuracy of the target classification model is higher.
In some embodiments, prior to performing step 101 of embodiments of the present disclosure, the knowledge distillation-based prediction method further comprises:
obtaining a pre-training model;
and training the pre-training model to obtain a first neural network model.
Specifically, the pre-training model may be a BERT model. Training the pre-training model to obtain the first neural network model comprises performing training such as parameter fine-tuning or knowledge distillation on the pre-training model. The first neural network model may be task-specific, obtained by fine-tuning or knowledge distillation from the pre-training model so that it performs better on the specific task. Because pre-training models such as BERT usually have a huge number of parameters (for example, BERT-base has 110 million parameters and BERT-large has 340 million), they are slow to fine-tune and deploy, computationally expensive, and impose considerable latency and capacity limits on real-time applications. Performing parameter fine-tuning or knowledge distillation training on the pre-training model therefore yields a lighter first neural network model.
Referring to fig. 2, the pre-training step 102 may include, but is not limited to, steps 201 to 203.
and step 203, performing dot product operations on the query vector, the key vector and the value vector to guide the training of the second neural network model, so as to perform self-attention distillation on the second neural network model.
In an application scenario, the first neural network model and the second neural network model each include at least two Transformer layers, wherein the last Transformer layer of the first neural network model includes a first self-attention module, and the last Transformer layer of the second neural network model includes a second self-attention module. In step 102, self-attention distillation is performed on the second self-attention module of the last Transformer layer of the second neural network model according to the first self-attention module of the last Transformer layer of the first neural network model.
Specifically, referring to fig. 3, step 203 may include, but is not limited to, steps 301 to 305.
step 301, performing a first operation on the query vector and the key vector to obtain a first operation result;
step 303, measuring a first probability distribution of the first operation result according to the KL divergence, and measuring a second probability distribution of the second operation result according to the KL divergence;
Specifically, in step 202, the first self-attention module may generate a query vector (Query, Q vector), a key vector (Key, K vector), and a value vector (Value, V vector). Referring to fig. 4, in the pre-training step 102 of a specific application scenario, a first operation is performed on the Q vector and the K vector to obtain a first operation result; a second operation is performed on the V vectors to obtain a second operation result; a first probability distribution of the first operation result and a second probability distribution of the second operation result are then each measured according to the KL divergence; and the fitting loss is calculated according to the first probability distribution and the second probability distribution, so that self-attention distillation can be performed on the second neural network model according to the fitting loss to update the second neural network model.
In the pre-training step 102, knowledge distillation is performed directly on the self-attention modules of the last layers of the teacher model and the student model, which removes the need to search for the optimal mapping layer. In addition to the first operation (a dot product) between the Q vectors and the K vectors in self-attention, a second operation (a dot product) among the V vectors is added, and the KL divergence is used to measure the difference between teacher and student for both. $L_{AT}$ is the first fitting loss, which measures the transfer of the attention distribution and is given by formula (1); $L_{VR}$ is the second fitting loss, which measures the transfer of the value relation and is given by formula (2). The total fitting loss $L_{sum}$ combines the first fitting loss $L_{AT}$ and the second fitting loss $L_{VR}$ as shown in formula (3), in which the weighting factors $\alpha$ and $\beta$ adjust the proportion of each in the loss calculation:
$L_{sum} = \alpha L_{AT} + \beta L_{VR}$    formula (3)
Both the first fitting loss $L_{AT}$ and the second fitting loss $L_{VR}$ of the disclosed embodiments are unsupervised self-distillation losses.
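As a rough illustration of how the two fitting losses and their weighted sum in formula (3) might be computed, the following sketch handles a single attention head. Formulas (1) and (2) are not reproduced in this text, so the concrete form used here (a row-wise softmax of scaled dot products, compared by KL divergence from teacher to student and averaged over positions) is an assumption, and all function names and toy values are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def relation(a, b):
    """Row-wise softmax of the scaled dot products between vectors in a and b."""
    scale = math.sqrt(len(a[0]))
    return [softmax([sum(x * y for x, y in zip(u, v)) / scale for v in b]) for u in a]

def mean_kl(teacher_rows, student_rows):
    """KL(teacher || student), averaged over sequence positions."""
    total = 0.0
    for p, q in zip(teacher_rows, student_rows):
        total += sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return total / len(teacher_rows)

def fitting_loss(teacher, student, alpha=1.0, beta=1.0):
    """Return (L_sum, L_AT, L_VR) for one attention head."""
    t_q, t_k, t_v = teacher
    s_q, s_k, s_v = student
    l_at = mean_kl(relation(t_q, t_k), relation(s_q, s_k))  # attention distribution
    l_vr = mean_kl(relation(t_v, t_v), relation(s_v, s_v))  # value relation
    return alpha * l_at + beta * l_vr, l_at, l_vr

# Toy last-layer Q, K, V for a 2-token input; teacher width 4, student width 2.
teacher = ([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
           [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
           [[0.5, 0.1, 0.0, 0.0], [0.0, 0.3, 0.2, 0.0]])
student = ([[0.2, 0.1], [0.4, 0.3]],
           [[0.1, 0.5], [0.3, 0.2]],
           [[0.3, 0.1], [0.2, 0.6]])
l_sum, l_at, l_vr = fitting_loss(teacher, student, alpha=0.5, beta=0.5)
```

Because both relation matrices have shape (sequence length) x (sequence length), the teacher and the student may use different hidden widths, which is consistent with distilling only the last layer without a layer-to-layer mapping.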
The principle of the self-attention mechanism is as follows. The self-attention mechanism can be understood as incorporating the understanding of all relevant words into the word currently being processed; equivalently, when a word is encoded, the representations of all words (V vectors) are weighted and summed, where the weights are obtained from the dot product of the representations of all input words (K vectors) with the representation of the word being encoded (Q vector), followed by softmax. First step of the self-attention mechanism: generate three vectors (a Q vector, a K vector, and a V vector) from each encoder input vector (the word vector of each word). Second step: calculate a score; each word of the input sentence is scored against the word being processed, and the score expresses how important the other words are when encoding the current word. Third and fourth steps: divide the score by the square root of the dimension of the key vector, which makes the gradient more stable, and then pass the result through softmax, which normalizes the scores of all words. Fifth step: multiply each V vector by its softmax score. Sixth step: sum the weighted value vectors to obtain the output of the self-attention layer at that position.
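The six steps above can be sketched for a single attention head as follows. This is a minimal illustration rather than the implementation used in the embodiments; the projection matrices `w_q`, `w_k`, `w_v` and the toy inputs are hypothetical.

```python
import math

def matvec(vec, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    return [sum(v * matrix[i][j] for i, v in enumerate(vec))
            for j in range(len(matrix[0]))]

def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(inputs, w_q, w_k, w_v):
    # Step 1: derive a Q, K and V vector from each input word vector.
    q = [matvec(x, w_q) for x in inputs]
    k = [matvec(x, w_k) for x in inputs]
    v = [matvec(x, w_v) for x in inputs]
    d_k = len(k[0])
    outputs = []
    for qi in q:
        # Step 2: score every word against the word being encoded.
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Step 3: divide by the square root of the key dimension.
        scaled = [s / math.sqrt(d_k) for s in scores]
        # Step 4: normalize the scores with softmax.
        weights = softmax(scaled)
        # Steps 5 and 6: weight each V vector and sum the results.
        out = [sum(w * vj[c] for w, vj in zip(weights, v))
               for c in range(len(v[0]))]
        outputs.append(out)
    return outputs

identity = [[1.0, 0.0], [0.0, 1.0]]   # hypothetical projection weights
inputs = [[1.0, 0.0], [0.0, 1.0]]     # two toy word vectors
outputs = self_attention(inputs, identity, identity, identity)
```

With identity projections each word attends most strongly to itself, so the first output vector is dominated by its first component and the second by its second.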
The disclosed embodiments mainly involve three stages: a pre-training stage, a trunk fine-tuning stage, and a branch self-distillation stage. In the pre-training stage, the pre-training step is performed: a first self-attention module is introduced into the last Transformer layer of the teacher model, a second self-attention module is introduced into the last Transformer layer of the student model, the second self-attention module is distilled through the first self-attention module, and a scaled dot product of the values is introduced into the self-attention modules to guide the training of the student model. In the trunk fine-tuning stage, the teacher model serves as the trunk network (a trunk classifier is added after the last layer of the trunk network), the student models obtained from the teacher model by self-distillation serve as the branch networks, and fine-tuning is performed by training on downstream task data. In the branch self-distillation stage, a branch classifier is added to each branch network, and the probability distribution predicted by the trunk classifier is distilled into the branch classifiers using label-free task data to complete branch self-distillation; self-distillation makes the training process more stable and the accuracy higher. An adaptive mechanism is introduced for inference: samples are filtered according to the inference results of the branch classifiers, so that easy samples output a prediction result after a small number of Transformer layers, which reduces the computing burden of the network and accelerates inference. Moreover, with the sample-adaptive mechanism, the amount of computation can be adjusted adaptively for each sample at prediction time, and the number of executed layers is adjusted dynamically to reduce the calculation steps.
By combining self-attention distillation and self-distillation across the pre-training stage, the trunk fine-tuning stage and the branch self-distillation stage, the disclosed embodiment improves the inference accuracy and speed of the second neural network model (i.e., the student model).
Since a classifier costs far less than a Transformer, adding the classifier in the fine-tuning stage saves cost.
Referring to fig. 5, in some embodiments the fine tuning step 103 may include, but is not limited to, steps 501-502:
Further, step 103 may include, but is not limited to, step 503:
In the fine-tuning step 103, the first neural network model is used as the trunk network, i.e. the trunk model, which includes an embedding layer, at least one Transformer layer, and a teacher classifier layer. In the self-distillation step, the second neural network model is used as the branch network, i.e. the branch model, which includes all the student classifier layers.
The self-distillation step 104 in some embodiments may include, but is not limited to including:
freezing the trunk parameters;
and carrying out self-distillation on the trunk classifier of the first neural network model with the frozen trunk parameters to obtain the branch classifier of the branch network.
In the self-distillation stage: the trunk parameters are frozen and the parameter distribution of the branch classifier (student classifier) is learned.
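Freezing the trunk while learning only the branch classifier can be pictured as a gradient step that skips the frozen parameter groups. The sketch below is a simplified illustration using plain gradient descent; the parameter names, values, and gradients are hypothetical and not taken from the embodiments.

```python
def sgd_step(params, grads, frozen, lr=0.1):
    """Apply one gradient step, skipping parameter groups listed in `frozen`."""
    updated = {}
    for name, values in params.items():
        if name in frozen:
            updated[name] = list(values)  # trunk weights are left untouched
        else:
            updated[name] = [v - lr * g for v, g in zip(values, grads[name])]
    return updated

# Hypothetical parameter groups and gradients.
params = {
    "trunk.transformer": [0.5, -0.2],
    "trunk.classifier": [0.1, 0.3],
    "branch.classifier": [0.0, 0.4],
}
grads = {name: [1.0, 1.0] for name in params}
frozen = {"trunk.transformer", "trunk.classifier"}
new_params = sgd_step(params, grads, frozen)
```

Only `branch.classifier` changes after the step, which mirrors the requirement that the trunk parameters stay fixed during self-distillation.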
Step 105 in some embodiments comprises:
acquiring label-free task data;
and inputting the label-free task data into a second neural network model corresponding to the branch classifier, and training the branch classifier to obtain a target classification model.
In a specific application scenario, traditional methods based on language model fine-tuning cannot make good use of unlabeled samples, while in practice labeling samples is costly and difficult. In the embodiment of the disclosure, BERT is used as the pre-training model, and unlabeled data (unlabeled samples) can be used for training in the self-distillation stage, which reduces cost.
Referring to fig. 6, in some embodiments, the self-distilling step 104 may include, but is not limited to, steps 601 to 603:
step 601, measuring the distance between the probability distributions of the trunk classifier and the branch classifier according to the KL divergence;
In some embodiments, the target loss between the probability distributions of the trunk classifier and the branch classifier is computed via step 602.
According to the embodiment of the disclosure, the trunk fine-tuning stage is entered after the pre-training stage is completed. In this stage, training is performed on the specific task data, and only the trunk classifier is added, without any branch classifier. The branch self-distillation stage is entered after the trunk fine-tuning stage is finished. By this point the trained trunk classifier can generate a highly reliable probability distribution (soft label); the trunk classifier is used to distill the branch classifiers so that each branch classifier learns the probability distribution predicted by the trunk classifier. The trunk network must be frozen at this time, and the KL divergence is again used to measure the difference in the distillation process, as shown in formula (4):
wherein $L_{ST}$ is the target loss, $L$ is the number of Transformer layers, $p_{si}$ is the prediction probability of the branch classifier for the i-th predicted sample label, and $p_t$ is the prediction probability of the trunk classifier for the t-th predicted sample label. It is understood that the KL divergence is used to measure the difference in the self-distillation process on the same principle as in step 102 above; since step 102 has been described in detail, the description is not repeated here.
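The branch self-distillation loss can be sketched as a sum of KL divergences between each branch classifier's predicted distribution and the frozen trunk classifier's distribution. Formula (4) is not reproduced in this text, so the direction of the divergence (branch relative to trunk) and the summation over branch classifiers are assumptions; the function names and probabilities are hypothetical.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions; q must be positive wherever p is."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def self_distillation_loss(branch_probs, trunk_probs):
    """Sum KL(branch || trunk) over all branch classifiers for one sample."""
    return sum(kl_divergence(p_s, trunk_probs) for p_s in branch_probs)

# Hypothetical soft labels for one unlabeled sample with three classes.
trunk = [0.7, 0.2, 0.1]                        # frozen trunk classifier
branches = [[0.4, 0.3, 0.3], [0.6, 0.3, 0.1]]  # two branch classifiers
loss = self_distillation_loss(branches, trunk)
```

No ground-truth label appears anywhere in the computation, which is why label-free task data suffices for this stage.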
In some embodiments, through the adaptive inference step 106, adaptive inference may be performed, based on the sample-adaptive mechanism, on the data to be classified that is input into the target classification model, so as to obtain a prediction result. Specifically, referring to fig. 4, in an embodiment, identification proceeds in hierarchical order from Transformer 0 to Transformer L; the sample label is predicted after each Transformer layer, and the uncertainty is determined according to formula (5). If the confidence of the prediction for a sample calculated according to formula (5) is high, that is, the uncertainty is small, the shallow layers of the network can already judge and output the sample with high accuracy; if the confidence is low, that is, the uncertainty is large, the sample needs to be propagated further forward. Based on this mechanism the network can adaptively adjust its inference speed, which helps accelerate inference in online deployment.
wherein Uncertainty is the entropy of the prediction result; N is the number of preset classification labels, which can be set according to the classification requirements of the data to be classified (prediction samples); and $p_{si}$ is the predicted probability of the branch classifier for the i-th preset classification label. If the entropy of the prediction result calculated according to formula (5) is smaller than a preset threshold, the prediction result is output directly according to the confidence distribution of the current classifier. Further, the preset classification label with the highest confidence may be determined as the classification label of the data to be classified. The preset threshold is set for the uncertainty index of the prediction result (the entropy of the prediction result) and controls the adaptive inference speed for the data to be classified: the smaller the preset threshold, the smaller the uncertainty (entropy) required for the class prediction of the data (samples) to be classified, the fewer samples are filtered by the lower-level classifiers, and the slower the inference.
For example, suppose there are 4 classification labels L1, L2, L3, and L4. The data to be classified is input into the target classification model and identified in hierarchical order from Transformer 0 to Transformer L; the classifier at each Transformer layer outputs a confidence for each of the 4 classification labels, and when the entropy of the prediction result is determined to be smaller than the preset threshold, the prediction result can be output directly by the classifier at the current level. If the classifier at the current level is that of Transformer 1, and its confidences for the 4 preset classification labels are 0.05, 0.02, 0.1, and 0.15 respectively, comparing them shows that the confidence of classification label L4 (0.15) is the largest, so classification label L4 is determined as the prediction result for the data to be classified.
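The early-exit decision described above can be sketched as follows. Formula (5) is not reproduced in this text, so the uncertainty is assumed here to be the entropy of the predicted distribution normalized by log N, which places it in the range 0 to 1; the function names, threshold values, and per-layer probabilities are hypothetical. For simplicity the per-layer outputs are precomputed, whereas a deployed model would stop running deeper Transformer layers once a sample exits.

```python
import math

def uncertainty(probs):
    """Entropy of a predicted distribution, normalized by log N (assumed form)."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return entropy / math.log(len(probs))

def adaptive_predict(layer_probs, threshold):
    """Return (exit_layer, label_index) for one sample.

    layer_probs holds each classifier's class probabilities, shallowest first;
    the last entry plays the role of the trunk classifier.
    """
    last = len(layer_probs) - 1
    for layer, probs in enumerate(layer_probs):
        if uncertainty(probs) < threshold or layer == last:
            label = max(range(len(probs)), key=probs.__getitem__)
            return layer, label

# Hypothetical outputs of three classifiers over 4 labels (L1..L4).
layer_probs = [
    [0.30, 0.30, 0.20, 0.20],   # very uncertain: keep going
    [0.05, 0.05, 0.10, 0.80],   # fairly confident
    [0.01, 0.01, 0.01, 0.97],   # final classifier
]
early = adaptive_predict(layer_probs, threshold=0.6)
strict = adaptive_predict(layer_probs, threshold=0.1)
```

With the looser threshold the sample exits at the second classifier, while the stricter threshold forces it through every layer; this is the speed and accuracy trade-off controlled by the preset threshold.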
According to the embodiment of the disclosure, the self-attention module of the last Transformer layer of the teacher model is used to distill the self-attention distribution and the value relation into the student model, so there is no need for a difficult search for the best mapping layer, and the student model is more efficient and flexible. Meanwhile, branch classifiers trained by self-distillation are introduced in the fine-tuning stage, and an adaptive mechanism is introduced into inference at deployment, so that the inference speed of the network is adjustable while high accuracy and high efficiency are maintained. Although introducing the branch classifiers increases the amount of computation at each layer of the network, the effect on the inference speed for complex samples is extremely small; for simple samples, the computation of most of the Transformer layers is avoided, which greatly reduces the computing burden and makes online deployment more stable and efficient.
According to the embodiment of the disclosure, in the pre-training stage, self-attention distillation is performed on the second neural network model according to the Transformer layer of the first neural network model; in the fine-tuning stage, the trunk classifier of the first neural network model is trained according to downstream task data and the trunk parameters are updated; in the self-distillation stage, the trunk classifier of the fine-tuned first neural network model is self-distilled to obtain the branch classifier of the second neural network model, and the branch classifier is trained according to label-free task data to obtain the target classification model; finally, based on the sample-adaptive mechanism, adaptive inference is performed on the data to be classified that is input into the target classification model to obtain a prediction result. Compared with the first neural network model, the target classification model is compressed: it is smaller than the first neural network model, but its processing accuracy is close to that of the first neural network model. Compared with the second neural network model, its processing accuracy is higher. The target classification model obtained by the embodiment of the disclosure is therefore smaller while the processing accuracy is maintained. Because a classifier costs far less than a Transformer, adding the classifier in the fine-tuning stage saves cost. And by introducing the sample-adaptive mechanism, the amount of computation can be adjusted adaptively for each sample at prediction time, which reduces the calculation steps.
The embodiment of the present disclosure further provides a prediction device based on knowledge distillation, which can implement the prediction method based on knowledge distillation, and the device includes:
the model acquisition module is used for acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a transformer layer;
the pre-training module is used for carrying out self-attention distillation on the second neural network model according to the transformer layer of the first neural network model;
the fine tuning module is used for training a trunk classifier of the trunk network according to the downstream task data and updating trunk parameters of the trunk network so as to perform fine tuning on the first neural network; wherein, the first neural network model is used as a backbone network;
the self-distillation module is used for carrying out self-distillation on the main classifier of the main network after fine adjustment to obtain a branch classifier of the branch network; wherein the second neural network model serves as a branch network;
the model training module is used for training the branch classifier according to the label-free task data to obtain a target classification model;
and the self-adaptive mechanism module is used for carrying out self-adaptive reasoning on the data to be classified input into the target classification model to obtain a prediction result.
An embodiment of the present disclosure further provides an electronic device, including:
at least one memory;
at least one processor;
at least one program;
the program is stored in the memory, and the processor executes the at least one program to implement the knowledge distillation-based prediction method of the present disclosure described above. The electronic device may be any intelligent terminal, including a mobile phone, a tablet computer, a personal digital assistant (PDA), a vehicle-mounted computer, and the like.
Referring to fig. 7, fig. 7 illustrates a hardware structure of an electronic device according to another embodiment, where the electronic device includes:
the processor 701 may be implemented by a general-purpose CPU (central processing unit), a microprocessor, an Application Specific Integrated Circuit (ASIC), or one or more integrated circuits, and is configured to execute a relevant program to implement the technical solution provided by the embodiment of the present disclosure;
the memory 702 may be implemented in a ROM (read only memory), a static memory device, a dynamic memory device, or a RAM (random access memory). The memory 702 may store an operating system and other application programs, and when the technical solution provided by the embodiments of the present disclosure is implemented by software or firmware, the relevant program codes are stored in the memory 702 and called by the processor 701 to execute the prediction method based on knowledge distillation according to the embodiments of the present disclosure;
an input/output interface 703 for realizing information input and output;
the communication interface 704 is used for realizing communication interaction between the device and other devices, and can realize communication in a wired manner (for example, USB, network cable, etc.) or in a wireless manner (for example, mobile network, WIFI, bluetooth, etc.); and
a bus 705 that transfers information between the various components of the device (e.g., the processor 701, the memory 702, the input/output interface 703, and the communication interface 704);
wherein the processor 701, the memory 702, the input/output interface 703 and the communication interface 704 are communicatively connected to each other within the device via a bus 705.
The disclosed embodiments also provide a storage medium that is a computer-readable storage medium having stored thereon computer-executable instructions for causing a computer to perform the above-described knowledge distillation based prediction method.
The knowledge distillation-based prediction method and device, electronic equipment and storage medium provided by the embodiment of the disclosure are realized through three stages: a pre-training stage (first stage), a trunk fine-tuning stage (second stage), and a branch self-distillation stage (third stage). In the pre-training stage, a first self-attention module is introduced into the last Transformer layer of the teacher model, a second self-attention module is introduced into the last Transformer layer of the student model, the second self-attention module is distilled through the first self-attention module, and a scaled dot product of the values is introduced into the self-attention modules to guide the training of the student model. In the trunk fine-tuning stage, the teacher model serves as the trunk network (a trunk classifier is added after the last layer of the trunk network), the student models obtained from the teacher model by self-distillation serve as the branch networks, and fine-tuning is performed by training on downstream task data. In the branch self-distillation stage, a branch classifier is added to each branch network, and the probability distribution predicted by the trunk classifier is distilled into the branch classifiers using label-free task data to complete branch self-distillation; self-distillation makes the training process more stable and the accuracy higher. An adaptive mechanism is introduced for inference: samples are filtered according to the inference results of the branch classifiers, so that easy samples output a prediction result after a small number of Transformer layers, which reduces the computing burden of the network and accelerates inference.
Moreover, with the sample-adaptive mechanism, the amount of computation can be adjusted adaptively for each sample at prediction time, and the number of executed layers is adjusted dynamically to reduce the calculation steps. By combining self-attention distillation and self-distillation across the pre-training stage, the trunk fine-tuning stage and the branch self-distillation stage, the inference accuracy and speed of the second neural network model (namely the student model) are improved. In addition, since a classifier costs far less than a Transformer, adding the classifier in the fine-tuning stage saves cost.
The memory, which is a non-transitory computer readable storage medium, may be used to store non-transitory software programs as well as non-transitory computer executable programs. Further, the memory may include high speed random access memory, and may also include non-transitory memory, such as at least one disk storage device, flash memory device, or other non-transitory solid state storage device. In some embodiments, the memory optionally includes memory located remotely from the processor, and these remote memories may be connected to the processor through a network. Examples of such networks include, but are not limited to, the internet, intranets, local area networks, mobile communication networks, and combinations thereof.
The embodiments described in the present disclosure are intended to illustrate the technical solutions of the embodiments of the present disclosure more clearly and do not limit them; those skilled in the art will understand that, as technology evolves and new application scenarios emerge, the technical solutions provided in the embodiments of the present disclosure are equally applicable to similar technical problems.
Those skilled in the art will appreciate that the embodiments shown in figs. 1-3 and 5-6 do not limit the embodiments of the present disclosure, which may include more or fewer steps than those shown, combine some steps, or include different steps.
The above-described embodiments of the apparatus are merely illustrative, wherein the units illustrated as separate components may or may not be physically separate, i.e. may be located in one place, or may also be distributed over a plurality of network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of the present embodiment.
One of ordinary skill in the art will appreciate that all or some of the steps of the methods, systems, functional modules/units in the devices disclosed above may be implemented as software, firmware, hardware, and suitable combinations thereof.
The terms "first," "second," "third," "fourth," and the like in the description of the application and the above-described figures, if any, are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used is interchangeable under appropriate circumstances such that the embodiments of the application described herein are capable of operation in sequences other than those illustrated or described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed, but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
It should be understood that in the present application, "at least one" means one or more, and "a plurality" means two or more. "And/or" describes an association relationship between associated objects and indicates that three relationships are possible; for example, "A and/or B" may indicate: only A is present, only B is present, or both A and B are present, wherein A and B may be singular or plural. The character "/" generally indicates that the associated objects before and after it are in an "or" relationship. "At least one of the following" or similar expressions refer to any combination of these items, including any combination of a single item or plural items. For example, at least one of a, b, or c may represent: a, b, c, "a and b", "a and c", "b and c", or "a and b and c", wherein a, b, c may be single or plural.
In the several embodiments provided in the present application, it should be understood that the disclosed apparatus and method may be implemented in other ways. The above-described apparatus embodiments are merely illustrative: for example, the division of the units is only one logical division, and other divisions are possible in practice; a plurality of units or components may be combined or integrated into another system, or some features may be omitted or not executed. In addition, the mutual coupling, direct coupling or communication connection shown or discussed may be an indirect coupling or communication connection through some interfaces, devices or units, and may be in an electrical, mechanical or other form.
The units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units, may be located in one place, or may be distributed on a plurality of network units. Some or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment.
In addition, functional units in the embodiments of the present application may be integrated into one processing unit, or each unit may exist alone physically, or two or more units are integrated into one unit. The integrated unit can be realized in a form of hardware, and can also be realized in a form of a software functional unit.
The integrated unit, if implemented in the form of a software functional unit and sold or used as a stand-alone product, may be stored in a computer-readable storage medium. Based on such understanding, the technical solution of the present application, in essence, or the part of it that contributes over the prior art, or all or part of the technical solution, may be embodied in the form of a software product, which is stored in a storage medium and includes multiple instructions for causing a computer device (which may be a personal computer, a server, or a network device) to perform all or part of the steps of the method according to the embodiments of the present application. The aforementioned storage medium includes various media capable of storing programs, such as a USB flash drive, a removable hard disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a magnetic disk, or an optical disk.
The preferred embodiments of the present disclosure have been described above with reference to the accompanying drawings, and this description does not limit the scope of the claims of the embodiments of the present disclosure. Any modifications, equivalents and improvements made by those skilled in the art within the scope and spirit of the embodiments of the present disclosure should be considered within the scope of the claims of the embodiments of the present disclosure.
Claims (10)
1. A prediction method based on knowledge distillation, comprising:
acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a Transformer layer;
performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model;
training a trunk classifier of a trunk network according to downstream task data, and updating trunk parameters of the trunk network to fine-tune the first neural network model; wherein the first neural network model serves as the trunk network;
performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain a branch classifier of a branch network; wherein the second neural network model serves as the branch network;
training the branch classifier according to unlabeled task data to obtain a target classification model;
and performing adaptive inference, based on a sample-adaptive mechanism, on data to be classified that is input into the target classification model, to obtain a prediction result.
2. The method of claim 1, wherein performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model comprises:
acquiring a first self-attention module of the Transformer layer of the first neural network model;
generating, by the first self-attention module, a query vector, a key vector, and a value vector;
and performing a dot product operation on the query vector, the key vector and the value vector to guide training of the second neural network model, so as to perform self-attention distillation on the second neural network model.
3. The method of claim 2, wherein performing the dot product operation on the query vector, the key vector and the value vector to guide training of the second neural network model, so as to perform self-attention distillation on the second neural network model, comprises:
performing a first operation between the query vector and the key vector to obtain a first operation result;
performing a second operation on the value vector to obtain a second operation result;
measuring, according to the KL divergence, a first probability distribution of the first operation result and a second probability distribution of the second operation result;
calculating a fitting loss from the first probability distribution and the second probability distribution;
performing self-attention distillation on the second neural network model based on the fitting loss.
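One possible reading of the steps in claim 3 can be sketched in plain Python. The sketch assumes that the first operation is a scaled dot product of queries and keys followed by a softmax, that the second operation is the same scaled dot product applied between the value vectors, and that the fitting loss is the sum of the two teacher-to-student KL divergences; none of these details are fixed by the claim, and all function names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_relation(a, b):
    """Row-wise softmax of (a @ b^T) / sqrt(d) for two lists of d-dimensional vectors."""
    d = len(a[0])
    return [
        softmax([sum(x * y for x, y in zip(u, v)) / math.sqrt(d) for v in b])
        for u in a
    ]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0.0)


def mean_row_kl(teacher, student):
    """Average KL(teacher row || student row) over all token positions."""
    return sum(kl_divergence(t, s) for t, s in zip(teacher, student)) / len(teacher)


def attention_distillation_loss(tq, tk, tv, sq, sk, sv):
    """Assumed fitting loss: query-key term plus value-value term."""
    attention_term = mean_row_kl(scaled_dot_relation(tq, tk), scaled_dot_relation(sq, sk))
    value_term = mean_row_kl(scaled_dot_relation(tv, tv), scaled_dot_relation(sv, sv))
    return attention_term + value_term
```

Because both relation matrices have shape sequence length by sequence length, the teacher and the student may use different hidden sizes in this sketch; only the number of token positions has to match.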
4. The method according to any one of claims 1 to 3, wherein training the trunk classifier of the trunk network according to the downstream task data and updating the trunk parameters of the trunk network to fine-tune the first neural network model comprises:
acquiring the trunk classifier;
training the trunk classifier according to the downstream task data, and updating the trunk parameters;
and adjusting the network structure of the trunk network for updating the trunk parameters.
5. The method according to any one of claims 1 to 3, wherein performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain the branch classifier of the branch network comprises:
freezing the trunk parameters;
and performing self-distillation on the trunk classifier with the trunk parameters frozen to obtain the branch classifier.
6. The method of any one of claims 1 to 3, wherein training the branch classifier according to the unlabeled task data to obtain the target classification model comprises:
acquiring the unlabeled task data;
and inputting the unlabeled task data into the second neural network model corresponding to the branch classifier, and training the branch classifier to obtain the target classification model.
7. The method according to any one of claims 1 to 3, wherein performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain the branch classifier of the branch network comprises:
measuring the probability distribution distance of the trunk classifier and the branch classifier according to the KL divergence;
calculating target loss according to the probability distribution distance;
updating the second neural network model in accordance with the target loss.
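The loss in claim 7 can be illustrated with a short Python sketch. It assumes that the target loss is the sum, over all branch classifiers, of KL(trunk || branch) for one sample; the claim does not fix the direction of the divergence or how the branches are aggregated, so both are assumptions, and the function names are hypothetical. No labels are needed, since the trunk classifier's predicted distribution is the only target.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0.0)


def self_distillation_loss(trunk_probs, branch_probs_list):
    """Assumed target loss for one unlabeled sample.

    trunk_probs: class distribution predicted by the frozen trunk classifier.
    branch_probs_list: one predicted class distribution per branch classifier.
    """
    return sum(kl_divergence(trunk_probs, branch) for branch in branch_probs_list)
```

In a training loop, this value would be averaged over a batch of unlabeled samples and minimized with respect to the branch classifier parameters only, the trunk parameters staying frozen as in claim 5.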
8. A prediction device based on knowledge distillation, comprising:
a model acquisition module, used for acquiring a first neural network model and a second neural network model; wherein the first neural network model comprises a Transformer layer;
a pre-training module, used for performing self-attention distillation on the second neural network model according to the Transformer layer of the first neural network model;
a fine-tuning module, used for training a trunk classifier of a trunk network according to downstream task data and updating trunk parameters of the trunk network so as to fine-tune the first neural network model; wherein the first neural network model serves as the trunk network;
a self-distillation module, used for performing self-distillation on the trunk classifier of the fine-tuned trunk network to obtain a branch classifier of a branch network; wherein the second neural network model serves as the branch network;
a model training module, used for training the branch classifier according to unlabeled task data to obtain a target classification model;
and an adaptive mechanism module, used for performing adaptive inference on data to be classified that is input into the target classification model to obtain a prediction result.
9. An electronic device, comprising:
at least one memory;
at least one processor;
at least one program;
wherein the at least one program is stored in the memory, and the processor executes the at least one program to implement:
the method of any one of claims 1 to 7.
10. A storage medium that is a computer-readable storage medium having stored thereon computer-executable instructions for causing a computer to perform:
the method of any one of claims 1 to 7.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210028926.6A CN114298287A (en) | 2022-01-11 | 2022-01-11 | Knowledge distillation-based prediction method and device, electronic equipment and storage medium |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210028926.6A CN114298287A (en) | 2022-01-11 | 2022-01-11 | Knowledge distillation-based prediction method and device, electronic equipment and storage medium |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114298287A (en) | 2022-04-08 |
Family
ID=80977808
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210028926.6A Pending CN114298287A (en) | 2022-01-11 | 2022-01-11 | Knowledge distillation-based prediction method and device, electronic equipment and storage medium |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114298287A (en) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114913400A (en) * | 2022-05-25 | 2022-08-16 | 天津大学 | Knowledge distillation-based early warning method and device for collaborative representation of ocean big data |
CN115050355A (en) * | 2022-05-31 | 2022-09-13 | 北京小米移动软件有限公司 | Training method and device of speech recognition model, electronic equipment and storage medium |
CN116186317A (en) * | 2023-04-23 | 2023-05-30 | 中国海洋大学 | Cross-modal cross-guidance-based image-text retrieval method and system |
CN116416456A (en) * | 2023-01-13 | 2023-07-11 | 北京数美时代科技有限公司 | Self-distillation-based image classification method, system, storage medium and electronic device |
CN116778300A (en) * | 2023-06-25 | 2023-09-19 | 北京数美时代科技有限公司 | Knowledge distillation-based small target detection method, system and storage medium |
WO2023221940A1 (en) * | 2022-05-16 | 2023-11-23 | 中兴通讯股份有限公司 | Sparse attention computation model and method, electronic device, and storage medium |
CN117787922A (en) * | 2024-02-27 | 2024-03-29 | 东亚银行(中国)有限公司 | Method, system, equipment and medium for processing money-back service based on distillation learning and automatic learning |
CN117974672A (en) * | 2024-04-02 | 2024-05-03 | 华侨大学 | Method and device for detecting breast ultrasonic tumor lesion area based on global information |
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023221940A1 (en) * | 2022-05-16 | 2023-11-23 | 中兴通讯股份有限公司 | Sparse attention computation model and method, electronic device, and storage medium |
CN114913400A (en) * | 2022-05-25 | 2022-08-16 | 天津大学 | Knowledge distillation-based early warning method and device for collaborative representation of ocean big data |
CN115050355A (en) * | 2022-05-31 | 2022-09-13 | 北京小米移动软件有限公司 | Training method and device of speech recognition model, electronic equipment and storage medium |
CN116416456A (en) * | 2023-01-13 | 2023-07-11 | 北京数美时代科技有限公司 | Self-distillation-based image classification method, system, storage medium and electronic device |
CN116416456B (en) * | 2023-01-13 | 2023-10-24 | 北京数美时代科技有限公司 | Self-distillation-based image classification method, system, storage medium and electronic device |
CN116186317A (en) * | 2023-04-23 | 2023-05-30 | 中国海洋大学 | Cross-modal cross-guidance-based image-text retrieval method and system |
CN116186317B (en) * | 2023-04-23 | 2023-06-30 | 中国海洋大学 | Cross-modal cross-guidance-based image-text retrieval method and system |
CN116778300A (en) * | 2023-06-25 | 2023-09-19 | 北京数美时代科技有限公司 | Knowledge distillation-based small target detection method, system and storage medium |
CN116778300B (en) * | 2023-06-25 | 2023-12-05 | 北京数美时代科技有限公司 | Knowledge distillation-based small target detection method, system and storage medium |
CN117787922A (en) * | 2024-02-27 | 2024-03-29 | 东亚银行(中国)有限公司 | Method, system, equipment and medium for processing money-back service based on distillation learning and automatic learning |
CN117787922B (en) * | 2024-02-27 | 2024-05-31 | 东亚银行(中国)有限公司 | Method, system, equipment and medium for processing money-back service based on distillation learning and automatic learning |
CN117974672A (en) * | 2024-04-02 | 2024-05-03 | 华侨大学 | Method and device for detecting breast ultrasonic tumor lesion area based on global information |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114298287A (en) | Knowledge distillation-based prediction method and device, electronic equipment and storage medium | |
CN113792818B (en) | Intention classification method and device, electronic equipment and computer readable storage medium | |
US20230100376A1 (en) | Text sentence processing method and apparatus, computer device, and storage medium | |
CN109992629B (en) | Neural network relation extraction method and system fusing entity type constraints | |
CN110309514A (en) | A kind of method for recognizing semantics and device | |
CN109214006B (en) | Natural language reasoning method for image enhanced hierarchical semantic representation | |
CN114358007A (en) | Multi-label identification method and device, electronic equipment and storage medium | |
CN113887215A (en) | Text similarity calculation method and device, electronic equipment and storage medium | |
CN114359810B (en) | Video abstract generation method and device, electronic equipment and storage medium | |
WO2023137911A1 (en) | Intention classification method and apparatus based on small-sample corpus, and computer device | |
CN113705313A (en) | Text recognition method, device, equipment and medium | |
CN114330354A (en) | Event extraction method and device based on vocabulary enhancement and storage medium | |
CN111145914B (en) | Method and device for determining text entity of lung cancer clinical disease seed bank | |
CN114358201A (en) | Text-based emotion classification method and device, computer equipment and storage medium | |
CN114627282A (en) | Target detection model establishing method, target detection model application method, target detection model establishing device, target detection model application device and target detection model establishing medium | |
CN111368531A (en) | Translation text processing method and device, computer equipment and storage medium | |
WO2023134085A1 (en) | Question answer prediction method and prediction apparatus, electronic device, and storage medium | |
CN115964459A (en) | Multi-hop inference question-answering method and system based on food safety cognitive map | |
CN114492661B (en) | Text data classification method and device, computer equipment and storage medium | |
CN115510232A (en) | Text sentence classification method and classification device, electronic equipment and storage medium | |
CN118093834A (en) | AIGC large model-based language processing question-answering system and method | |
CN117217277A (en) | Pre-training method, device, equipment, storage medium and product of language model | |
CN114936274A (en) | Model training method, dialogue generating device, dialogue training equipment and storage medium | |
CN114239599A (en) | Method, system, equipment and medium for realizing machine reading understanding | |
US11822887B2 (en) | Robust name matching with regularized embeddings |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||