CN114549901B - Multi-network combined auxiliary generation type knowledge distillation method - Google Patents

Publication number
CN114549901B
Authority
CN
China
Prior art keywords
training
network
sample generator
teacher
cifar
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210172188.2A
Other languages
Chinese (zh)
Other versions
CN114549901A (en
Inventor
王一琳
匡振中
丁佳骏
顾晓玲
俞俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hangzhou Dianzi University
Original Assignee
Hangzhou Dianzi University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hangzhou Dianzi University
Priority to CN202210172188.2A
Publication of CN114549901A
Application granted
Publication of CN114549901B

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/084 Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The invention discloses a multi-network combined auxiliary generation type knowledge distillation method. First, an image classification data set is preprocessed. Next, a teacher network model is selected according to the chosen image classification data set and trained. A difficult sample generator G1 and a student network are then selected according to the data set to form an adversarial knowledge distillation framework, an objective function for generative adversarial knowledge distillation is established, and the framework is trained iteratively. Finally, a simple sample generator G2 is introduced, and the student network is adjusted alternately with the difficult sample generator G1 and the simple sample generator G2 to obtain the final result. The simple sample generator is a direct copy of the trained difficult sample generator, so introducing it does not increase the amount of computation and is simple to carry out. Because the simple sample generator helps the student network review simple samples, a better result is ultimately achieved on the target task.

Description

Multi-network combined auxiliary generation type knowledge distillation method
Technical Field
The invention belongs to the field of knowledge distillation within computer vision, and in particular provides a multi-network combined auxiliary generation type knowledge distillation method for image classification tasks.
Background
Convolutional neural networks (CNNs) have achieved remarkable results in image classification, segmentation, detection and related fields thanks to their strong feature extraction and representation capabilities. However, networks with high representational power tend to have complex structures and very large numbers of parameters. Deploying a complete CNN therefore often requires a large memory footprint and high-performance computing units, which limits its use on embedded devices with limited computing resources and on mobile terminals with strict real-time requirements. There is therefore a strong need to make CNNs lightweight.
Knowledge distillation is a widely used model compression method. It treats the model to be compressed as a "teacher" and the compressed model as a "student". The teacher network is highly capable, but its structure is complex and it is inconvenient to deploy; the student network has a simple structure, but training it directly does not give good results. Knowledge distillation uses the teacher network to assist the training of the student network, improving the student network's performance so that it approaches that of the teacher network.
Most existing compression and acceleration methods for deep neural networks are very effective when the training data can be accessed directly during model compression. However, if the training data is inaccessible for privacy or legal reasons, most model compression methods fail, so researchers have proposed some model compression methods that do not require the training data.
For example, one proposal dispenses with a separate lightweight neural network and prunes the fully connected layers of the original model directly, removing similar neurons with little change in the final output. However, this method cannot be applied to convolutional layers, which greatly reduces the achievable degree of compression, and it also fails when the internal structure of the model is unknown. Methods have also been proposed that reconstruct the training data from "metadata" recorded during training of the original model (such as the activation values of network layers), but in most cases the original model does not retain this "metadata".
None of the above methods is very practical, so it was later proposed to combine knowledge distillation with generative adversarial networks (GANs). As a generative model, a GAN can generate data that stands in for the training data and can be used for knowledge distillation. Note that "data-free" in this context does not mean that no data is needed at all, but that the data used to train the teacher network is unavailable.
For example, in data-free adversarial distillation ([1] Fang G, Song J, Shen C, et al. Data-Free Adversarial Distillation [J]. 2019), a generator is introduced in addition to a pre-trained teacher network, and the student network and the teacher network are combined to act as a discriminator. The goal of the generator is to produce difficult samples on which the outputs of the teacher network and the student network differ significantly, while the learning goal of the student network is to keep reducing the difference between its output and the teacher network's output. Through continued learning, the student network gradually masters the difficult samples, which thereby become simple samples that it can distinguish easily. The generator must then keep searching the sample space for difficult samples that the student network has not yet mastered and that enlarge the output difference between the student network and the teacher network. The whole training process is thus a generative adversarial process.
This adversarial knowledge distillation method has a problem: because the generator only produces difficult samples for training the student network, the student network eventually forgets simple samples, which causes prediction errors and lowers overall performance.
Disclosure of Invention
The invention aims to overcome the shortcomings of the above method and provides a multi-network combined auxiliary generation type knowledge distillation method. The method adds a generator G2 that produces simple samples to assist the generator G1 that produces difficult samples. The two generators jointly regulate the student network, which prevents the student network from neglecting simple samples in pursuit of performance on difficult samples and thereby degrading overall performance.
During training, the difficult sample generator G1 and the simple sample generator G2 are trained alternately. The goal of G1 is to generate samples on which the outputs of the teacher network and the student network differ greatly; the goal of G2 is to generate samples on which they differ little; and the goal of the student network, whether G1 or G2 is in use, is to reduce the difference between its output and that of the teacher network. With G2 helping the student network review simple samples, the method ultimately achieves better results on the target task than the original method.
A multi-network combined auxiliary generation type knowledge distillation method comprises the following steps:
Step 1: preprocessing an image classification data set;
Step 2: selecting a teacher network model according to the chosen image classification data set and training it;
Step 3: selecting a difficult sample generator G1 and a student network according to the chosen image classification data set to form an adversarial knowledge distillation framework;
Step 4: establishing an objective function for generative adversarial knowledge distillation;
Step 5: performing iterative training on the constructed adversarial knowledge distillation framework;
Step 6: introducing a simple sample generator G2, and alternately adjusting the student network with the difficult sample generator G1 and the simple sample generator G2 to obtain the final result.
The specific steps of the step 1 are as follows:
1-1. Data preparation and preprocessing.
A public data set is selected: MNIST, CIFAR10 or CIFAR100. For the MNIST data set, the images are first upscaled to a resolution of 32x32, then normalized, and finally standardized. For the CIFAR10 and CIFAR100 data sets, the images are normalized directly and then standardized.
1-2. Image enhancement. The MNIST data is too simple, so no image enhancement is applied to it. The same image enhancement operations are applied to CIFAR10 and CIFAR100: the image is padded by 4 pixels on each of its top, bottom, left and right sides and then randomly cropped, and the cropped image is finally flipped with a probability of 0.5.
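The pad, crop and flip sequence of step 1-2 can be sketched in plain Python on a single-channel image stored as a list of rows. This is only an illustrative sketch: the function name `pad_crop_flip`, the zero fill value, the crop back to the original size and the horizontal flip direction are assumptions, since the text does not state them.

```python
import random

def pad_crop_flip(image, pad=4, flip_prob=0.5, rng=random):
    """Pad an HxW image, take a random HxW crop, then flip it with probability flip_prob."""
    height, width = len(image), len(image[0])
    padded_width = width + 2 * pad
    # Pad all four sides; the zero fill value is an assumption.
    padded = [[0] * padded_width for _ in range(pad)]
    padded += [[0] * pad + list(row) + [0] * pad for row in image]
    padded += [[0] * padded_width for _ in range(pad)]
    # Random crop back to the original size (assumed crop size).
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + width] for row in padded[top:top + height]]
    # Random flip; a horizontal (left-right) flip is assumed.
    if rng.random() < flip_prob:
        crop = [row[::-1] for row in crop]
    return crop

image = [[row * 32 + col for col in range(32)] for row in range(32)]
augmented = pad_crop_flip(image, rng=random.Random(0))
```

Passing a seeded `random.Random` instance keeps the augmentation reproducible, which is convenient when checking that the output keeps the 32x32 shape.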
The specific steps of step 2 are as follows:
2-1. Different image classification data sets use different teacher networks. On the MNIST data set, the teacher network is LeNet. For CIFAR10 and CIFAR100, the teacher network is a ResNet modified from the original ResNet. First, the parameters of the first convolutional layer are changed: kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1. The max pooling layer that follows the first convolutional layer is then deleted. Finally, the average pooling layer before the fully connected layer is deleted.
2-2. The same training method is used for all data sets and teacher networks, but the parameters are configured differently so that each teacher network achieves its best performance. The training method is as follows: first the total number of training rounds is set; in each round, the training split of the selected data set is fed into the teacher network to obtain its output, the output and the training labels are passed to the objective function, and the error is back-propagated to optimize the teacher network. The objective function is multi-class cross entropy.
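The multi-class cross entropy used for the teacher network in step 2-2 can be illustrated with a minimal sketch. It assumes that the teacher outputs raw scores (logits) that are passed through a softmax and that labels are integer class indices; the function names are chosen here for illustration only.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one vector of raw scores."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(batch_logits, labels):
    """Mean multi-class cross entropy over a batch of logits and integer labels."""
    losses = []
    for logits, label in zip(batch_logits, labels):
        probs = softmax(logits)
        losses.append(-math.log(probs[label]))
    return sum(losses) / len(losses)

# A confident correct prediction gives a lower loss than a confident wrong one.
loss_correct = cross_entropy([[10.0, 0.0, 0.0]], [0])
loss_wrong = cross_entropy([[0.0, 10.0, 0.0]], [0])
```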
The specific steps of step 3 are as follows:
3-1. For all data sets, the difficult sample generator G1 uses the generator network of DCGAN (Deep Convolutional Generative Adversarial Networks).
3-2. Different data sets use different student networks.
On the MNIST data set, the student network is LeNet-Half; for the CIFAR10 and CIFAR100 data sets, the student network is a ResNet modified from the original ResNet. First, the parameters of the first convolutional layer are changed: kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1. The max pooling layer that follows the first convolutional layer is then deleted. Finally, the average pooling layer before the fully connected layer is deleted.
3-3. The difficult sample generator G1, the student network and the trained teacher network together form the whole training framework, in which the student network and the teacher network are combined to act as a discriminator. The whole training process is divided into two stages, an imitation stage and a generation stage. In the imitation stage, the generator is fixed and the student network is updated; in the generation stage, the student network is fixed and the generator is updated.
The specific steps of step 4 are as follows:
4-1. The mean absolute error (MAE) is chosen as the objective function. Its specific form is given by formula 1, where n represents the number of samples, y_i represents the true value, and \hat{y}_i represents the predicted value.
$$\mathrm{MAE} = \frac{1}{n} \sum_{i=1}^{n} \left| y_i - \hat{y}_i \right| \tag{1}$$
4-2. In the imitation stage, the objective function of the student network is denoted L_S and is given by formula 2, where T represents the teacher network, S represents the student network, n represents the number of classes, G_h represents the difficult sample generator G1, z represents a random vector sampled from a normal distribution, and \lVert \cdot \rVert_1 represents the L1 norm. In practical terms, L_S is the MAE between the outputs of the student network and the teacher network.
$$L_S = \mathbb{E}_{z \sim \mathcal{N}(0, I)} \left[ \frac{1}{n} \left\lVert T(G_h(z)) - S(G_h(z)) \right\rVert_1 \right] \tag{2}$$
In the generation stage, the objective function of the difficult sample generator G1 is denoted L_{G1} and is given by formula 3, which is the negative of the MAE between the outputs of the student network and the teacher network.
$$L_{G1} = -\mathbb{E}_{z \sim \mathcal{N}(0, I)} \left[ \frac{1}{n} \left\lVert T(G_h(z)) - S(G_h(z)) \right\rVert_1 \right] \tag{3}$$
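The two objectives of step 4-2 can be illustrated with a minimal sketch that works on already computed output vectors. The helper names are hypothetical, and the expectation over z is approximated here by a plain average over a batch, which is an assumption about how the objective is estimated in practice.

```python
def mae(teacher_out, student_out):
    """Mean absolute error between two equal-length output vectors."""
    n = len(teacher_out)
    return sum(abs(t - s) for t, s in zip(teacher_out, student_out)) / n

def student_loss(teacher_batch, student_batch):
    """L_S: batch average of the per-sample MAE between teacher and student outputs."""
    per_sample = [mae(t, s) for t, s in zip(teacher_batch, student_batch)]
    return sum(per_sample) / len(per_sample)

def hard_generator_loss(teacher_batch, student_batch):
    """L_G1: the negated student loss, so minimizing it widens the teacher-student gap."""
    return -student_loss(teacher_batch, student_batch)

teacher_outputs = [[1.0, 0.0], [0.0, 1.0]]
student_outputs = [[0.5, 0.5], [0.0, 1.0]]
loss_s = student_loss(teacher_outputs, student_outputs)
loss_g1 = hard_generator_loss(teacher_outputs, student_outputs)
```

Because the two losses differ only in sign, the same forward pass through the teacher and the student serves both stages; only the set of parameters being updated changes.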
The specific steps of the step 5 are as follows:
5-1. The total number of training rounds is set to 500, and each round contains 50 training iterations. Each iteration consists of two stages in sequence: an imitation stage and a generation stage.
5-2. In the imitation stage, the parameters of the difficult sample generator G1 are not updated, and the parameters of the student network are updated. First, random vectors following a standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data. The generated data is then fed into the student network and the teacher network to obtain their respective outputs, the value of the objective function L_S is computed, and the error is back-propagated to optimize the student network. Note that this procedure is repeated 5 times before the generation stage begins.
5-3. In the generation stage, the parameters of the student network are not updated, and the parameters of the difficult sample generator G1 are updated. First, random vectors following a standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data. The generated data is then fed into the student network and the teacher network to obtain their respective outputs, the value of the objective function L_{G1} is computed, and the error is back-propagated to optimize the difficult sample generator G1.
5-4. The imitation stage and the generation stage are iterated 50 times per round; once these iterations are complete, the next round of training begins, until all training rounds are finished.
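The nesting of rounds, iterations and stages in steps 5-1 to 5-4 can be sketched as a loop skeleton. The two callbacks stand in for the real student and generator updates and are purely hypothetical; the sketch only shows the order and number of calls, and it is run here with small counts rather than the 500 rounds of the method.

```python
def run_adversarial_distillation(student_step, generator_step,
                                 rounds=500, iters_per_round=50, imitation_steps=5):
    """Call the two update callbacks in the order described in steps 5-1 to 5-4."""
    for _round in range(rounds):
        for _iteration in range(iters_per_round):
            # Imitation stage: generator fixed, student updated several times.
            for _step in range(imitation_steps):
                student_step()
            # Generation stage: student fixed, generator updated once.
            generator_step()

counts = {"student": 0, "generator": 0}

def fake_student_step():
    counts["student"] += 1

def fake_generator_step():
    counts["generator"] += 1

# Shortened run: 2 rounds of 3 iterations instead of 500 rounds of 50.
run_adversarial_distillation(fake_student_step, fake_generator_step,
                             rounds=2, iters_per_round=3)
```

With the default arguments the student would be updated 5 times for every generator update, which is the ratio the text prescribes.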
The specific steps of the step 6 are as follows:
6-1. After the adversarial training of the student network and the difficult sample generator G1 in step 5 has completed the set number of rounds, the difficult sample generator G1 is copied to serve as the simple sample generator G2.
6-2. The simple sample generator G2 also takes part in training. Its purpose is to generate simple samples, that is, samples for which the outputs of the student network and the teacher network differ only slightly. Its objective function L_{G2} is therefore given by formula 4:
$$L_{G2} = \mathbb{E}_{z \sim \mathcal{N}(0, I)} \left[ \frac{1}{n} \left\lVert T(G_s(z)) - S(G_s(z)) \right\rVert_1 \right] \tag{4}$$
where G_s represents the simple sample generator G2.
6-3. After the simple sample generator G2 is introduced, adversarial training of the student network continues, again in an iterative manner. The total number of training rounds is set to 100, and each round contains 50 training iterations. At the start of each round it is decided whether the difficult sample generator G1 or the simple sample generator G2 is used. The strategy alternates 3 rounds of training with G1 and 1 round of training with G2, and the first round uses the simple sample generator G2. Whichever generator is used, each iteration consists of two stages in sequence: an imitation stage and a generation stage. If G1 is used, both stages are the same as in step 5. If G2 is used, the imitation-stage procedure is performed only once before moving on to the generation stage; the generation stage is also the same as in step 5, except that the objective function computed is L_{G2} and the generator optimized is the simple sample generator G2.
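The round-level choice between G1 and G2 in step 6-3 can be sketched as follows. The text fixes only the 3-to-1 ratio and the fact that the first round uses G2; reading this as a repeating 4-round cycle that starts with G2 is an assumption, as are the function names.

```python
def pick_generator(round_index, period=4):
    """Return 'G2' on the first round of each 4-round cycle, otherwise 'G1' (assumed reading)."""
    return "G2" if round_index % period == 0 else "G1"

def imitation_steps_for(generator_name):
    """Number of student updates per iteration: 5 with G1, 1 with G2."""
    return 5 if generator_name == "G1" else 1

# Schedule for the 100 rounds of step 6, using 0-based round indices.
schedule = [pick_generator(round_index) for round_index in range(100)]
```

Under this reading, 25 of the 100 rounds use G2 and 75 use G1.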
Further, in step 2 the same training method is used for different data sets and different teacher networks, but the parameters are configured differently, specifically as follows:
For the MNIST data set, the teacher model is optimized with stochastic gradient descent (SGD), with momentum 0.9, weight decay 0.0001, learning rate 0.01 and batch size 256, for a total of 10 training rounds. For the CIFAR10 and CIFAR100 data sets, the teacher model is also optimized with the SGD optimizer, but with weight decay 0.0005, batch size 128 and an initial learning rate of 0.1, for a total of 200 training rounds, with the learning rate divided by 10 every 80 rounds.
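The step decay used for the CIFAR teacher networks can be written as a small function. This sketch assumes 0-based round indices, so the drops fall at rounds 80 and 160 within the 200 rounds; the function name is hypothetical.

```python
def step_lr(round_index, initial_lr=0.1, step=80, factor=0.1):
    """Learning rate after dividing by 10 once per completed block of `step` rounds."""
    return initial_lr * factor ** (round_index // step)

# Learning rate for each of the 200 teacher training rounds.
teacher_lrs = [step_lr(round_index) for round_index in range(200)]
```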
Further, the imitation stage described in step 5 is as follows:
For all data sets, the student network is optimized with an SGD optimizer, with data-set-specific parameters. For the MNIST data set, the learning rate is 0.01, the weight decay is 0.0001, the momentum is 0.9, and the batch size of the generated data is 512. For the CIFAR10 data set, the momentum is 0.9, the weight decay is 0.0005, the batch size is 128 and the initial learning rate is 0.1; training runs for 500 rounds in total, and the learning rate is divided by 10 at rounds 100 and 200. The settings for CIFAR100 are essentially the same as for CIFAR10, the only difference being a batch size of 256.
The generation stage is as follows:
For all data sets, the difficult sample generator G1 is optimized with the Adam optimizer, with data-set-specific parameters. When the teacher network was trained on the MNIST data set, the learning rate is 0.001 and the batch size of the generated data is 512. When the teacher network was trained on the CIFAR10 data set, the batch size of the generated data is 128 and the initial learning rate is 0.001; training runs for 500 rounds in total, and the learning rate is divided by 10 at round 100 and again at round 200. The settings for CIFAR100 are essentially the same as for CIFAR10, the only difference being that the batch size of the generated data is 256.
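The milestone schedule shared by the CIFAR student (initial rate 0.1) and the CIFAR generator (initial rate 0.001) can be sketched as below. Whether rounds are counted from 0 or 1 is not stated in the text, so the 0-based indexing and the function name are assumptions.

```python
def milestone_lr(round_index, initial_lr, milestones=(100, 200), factor=0.1):
    """Divide the learning rate by 10 at each milestone round that has been reached."""
    passed = sum(1 for milestone in milestones if round_index >= milestone)
    return initial_lr * factor ** passed

# Student (SGD) and generator (Adam) rates at a few representative rounds.
student_lrs = [milestone_lr(r, 0.1) for r in (0, 100, 200, 499)]
generator_lrs = [milestone_lr(r, 0.001) for r in (0, 100, 200, 499)]
```

Both networks keep the same 100-to-1 ratio between their learning rates throughout the 500 rounds, since the two schedules drop at the same rounds.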
Further, the adversarial training in step 6 is specifically as follows:
For all data sets, the student network is optimized with an SGD optimizer, with data-set-specific parameters. For the MNIST data set, the learning rate is 0.01, the momentum is 0.9 and the weight decay is 0.00001. For the CIFAR10 and CIFAR100 data sets, the learning rate is 0.001, the momentum is 0.9 and the weight decay is 0.00005. For all data sets, the difficult sample generator G1 and the simple sample generator G2 are optimized with the Adam optimizer, with data-set-specific parameters. For the MNIST data set, the optimizer learning rate of G1 is 0.001, the optimizer learning rate of G2 is 0.0001, and the batch size of the generated data is 512. For the CIFAR10 data set, the optimizer learning rate is 0.00001 and the batch size of the generated data is 128. For the CIFAR100 data set, the optimizer learning rate is 0.00001 and the batch size of the generated data is 256.
The invention has the following beneficial effects:
Compared with the prior art, the invention improves the experimental results. The simple sample generator that it introduces is a direct copy of the trained difficult sample generator, so it does not increase the amount of computation and is simple to carry out. Because the simple sample generator helps the student network review simple samples, a better result is ultimately achieved on the target task.
Drawings
FIG. 1 is a flow chart of the steps of the present invention;
Fig. 2 is a flow chart of the overall architecture of the present invention.
Detailed Description
The invention is further described below with reference to the accompanying drawings.
The overall architecture of the multi-network combined auxiliary generation type knowledge distillation method is shown in the flow chart of FIG. 2, and the specific steps are shown in FIG. 1:
Step 1: preprocessing an image classification data set;
Step 2: selecting a teacher network model and training it;
Step 3: selecting a difficult sample generator G1 and a student network to form the whole training framework;
Step 4: establishing an objective function for generative adversarial knowledge distillation;
Step 5: performing iterative training on the constructed adversarial knowledge distillation framework;
Step 6: introducing a simple sample generator G2, and alternately adjusting the student network with the difficult sample generator G1 and the simple sample generator G2.
Step 1, data processing, which comprises the following specific steps:
1-1. An image classification data set is loaded. A public data set is selected: MNIST, CIFAR10 or CIFAR100. For the MNIST data set, the images are first upscaled to a resolution of 32x32, then normalized, and finally standardized. For the CIFAR10 and CIFAR100 data sets, the images are normalized directly and then standardized. Taking an RGB three-channel image such as a CIFAR image as an example, the image is first normalized, because each RGB channel takes values in the range [0, 255], which is detrimental to model convergence. Each channel of the image is then standardized according to expression 5 (so that the mean becomes 0 and the standard deviation becomes 1), which further accelerates model convergence.
$$x' = \frac{x - \mu}{\sigma} \tag{5}$$
In the above expression, μ and σ are statistics that represent the mean and the standard deviation, respectively, of each channel of the original images.
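The two preprocessing operations of step 1-1 can be sketched for one image channel stored as a list of rows. The sketch assumes that normalization means dividing 8-bit values by 255 to reach [0, 1] and that the channel statistics are computed on the scaled values; neither detail is stated in the text, and the function names are hypothetical.

```python
def channel_stats(channel):
    """Mean and standard deviation of one channel after scaling its values to [0, 1]."""
    values = [pixel / 255.0 for row in channel for pixel in row]
    mean = sum(values) / len(values)
    variance = sum((value - mean) ** 2 for value in values) / len(values)
    return mean, variance ** 0.5

def scale_and_standardize(channel, mean, std):
    """Scale 8-bit pixel values to [0, 1], then apply (x - mean) / std."""
    return [[(pixel / 255.0 - mean) / std for pixel in row] for row in channel]

channel = [[0, 255], [255, 0]]
mean, std = channel_stats(channel)
standardized = scale_and_standardize(channel, mean, std)
```

In the method the statistics would be computed once per channel over the whole data set rather than per image; the single 2x2 channel here only keeps the arithmetic easy to follow.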
1-2. Image enhancement. Deep learning generally requires a sufficient number of samples: the more samples there are, the better the trained model performs and the stronger its generalization ability. In practice, however, the number of samples is often insufficient or their quality is not good enough, so data enhancement is applied to the samples to improve their quality. Since the training data here is image data, the term image enhancement is used in place of data enhancement below. The image enhancement operations applied to the three image data sets in this method are as follows. The MNIST data is too simple, so no image enhancement is applied to it. The same image enhancement operations are applied to CIFAR10 and CIFAR100: the image is padded by 4 pixels on each of its top, bottom, left and right sides and then randomly cropped, and the cropped image is finally flipped with a probability of 0.5.
Step 2, selecting a teacher network model and training, wherein the specific steps are as follows:
2-1. Different image classification data sets use different teacher networks. On the MNIST data set, the teacher network is LeNet. The LeNet network has 5 layers in total, namely 3 convolutional layers and 2 fully connected layers; the number of output units of the final fully connected layer is adjusted to the training data set and is 10 on the MNIST data set. For the CIFAR10 and CIFAR100 data sets, the teacher network is ResNet. Notably, the CIFAR10 and CIFAR100 images have a relatively small resolution of 32x32, so the original ResNet is modified. First, the parameters of the first convolutional layer are changed: kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1. The max pooling layer that follows the first convolutional layer is then deleted. Finally, the average pooling layer before the fully connected layer is deleted.
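The effect of the first-layer changes on a 32x32 input can be checked with the usual output-size formula for convolution and pooling layers. In this sketch the max pooling parameters (kernel 3, stride 2, padding 1) are those of the standard ResNet stem and are an assumption, since the text does not list them.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode, no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# Original stem: 7x7 convolution with stride 2, then max pooling (assumed 3x3, stride 2).
original_conv = conv_output_size(32, kernel_size=7, stride=2, padding=3)
original_pool = conv_output_size(original_conv, kernel_size=3, stride=2, padding=1)

# Modified stem: 3x3 convolution with stride 1 and no max pooling.
modified_conv = conv_output_size(32, kernel_size=3, stride=1, padding=1)
```

Under these assumptions the original stem shrinks a 32x32 image to 8x8 before the first residual block, while the modified stem keeps the full 32x32 resolution.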
2-2. For the MNIST data set, the teacher model is optimized with stochastic gradient descent (SGD), with momentum 0.9, weight decay 0.0001, learning rate 0.01 and batch size 256, for a total of 10 training rounds. For the CIFAR10 and CIFAR100 data sets, the teacher model is also optimized with the SGD optimizer, but with weight decay 0.0005, batch size 128 and an initial learning rate of 0.1, for a total of 200 training rounds, with the learning rate divided by 10 every 80 rounds.
Step 3, selecting a difficult sample generator G1 and a student network to form a whole training framework, wherein the specific steps are as follows:
3-1. For all data sets, the difficult sample generator G1 uses the generator network of DCGAN. Every GAN consists of two networks, called the generator and the discriminator. This method does not need a dedicated discriminator, so only the generator network of DCGAN is used, as the difficult sample generator G1. DCGAN has some advantages over a conventional GAN: for example, it removes all pooling layers and replaces them with convolutional layers, which keeps the whole network differentiable, and it uses Batch Normalization to accelerate model convergence.
3-2. Different data sets use different student networks. On the MNIST data set, the student network is LeNet-Half, which, compared with the standard LeNet, halves the number of channels in the convolutional layers and the number of neurons in the fully connected layer. For the CIFAR10 and CIFAR100 data sets, the student network is ResNet. Again because CIFAR10 and CIFAR100 images have a small resolution, the original ResNet is modified. First, the parameters of the first convolutional layer are changed: kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1. The max pooling layer that follows the first convolutional layer is then deleted. Finally, the average pooling layer before the fully connected layer is deleted.
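The size reduction from LeNet to LeNet-Half can be estimated by counting parameters. The layer widths below (convolutions with 6, 16 and 120 channels, a hidden fully connected layer of 84 units, 5x5 kernels, one input channel, and a 1x1 feature map after the third convolution) are those of the classic LeNet-5 and are an assumption; the text gives only the number of layers and the halving rule.

```python
def conv_params(in_channels, out_channels, kernel=5):
    """Weights plus biases of one convolutional layer."""
    return in_channels * out_channels * kernel * kernel + out_channels

def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

def lenet_params(conv_widths, hidden, num_classes=10, in_channels=1):
    """Parameter count of a LeNet-style network with 3 conv and 2 fully connected layers."""
    total, previous = 0, in_channels
    for width in conv_widths:
        total += conv_params(previous, width)
        previous = width
    total += linear_params(previous, hidden)
    total += linear_params(hidden, num_classes)
    return total

full = lenet_params((6, 16, 120), hidden=84)
half = lenet_params((3, 8, 60), hidden=42)
```

Under these assumed widths, halving every layer leaves roughly a quarter of the parameters, because most weights sit in layers whose input and output widths are both halved.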
3-3. The teacher network trained in the previous stage is now introduced and combined with the student network to form a discriminator (see Fang G, Song J, Shen C, et al. Data-Free Adversarial Distillation [J]. 2019, Section 3.2), which is trained adversarially against the generator. This establishes the whole adversarial knowledge distillation framework. Because the training is generative adversarial training, it is divided into two stages, called the imitation stage and the generation stage. In the imitation stage, the generator is fixed and only the student network in the discriminator is updated; in the generation stage, the student network is fixed and the generator is updated. Notably, so that the student network can keep pace with the generator during adversarial learning, 5 training iterations are performed in the imitation stage, whereas only 1 training iteration is performed in the generation stage.
Step 4: establishing an objective function for generative adversarial knowledge distillation, with the following specific steps:
4-1. The purpose of knowledge distillation is for the student network, under the guidance of the teacher network, to achieve performance similar to that of the teacher network; in the best case its results are identical to the teacher network's. The objective function therefore usually makes the student network fit the teacher network's outputs. When the data used to train the teacher network is available, KL divergence (Kullback-Leibler divergence) or MSE (Mean Squared Error) is often used as the objective function, but these objective functions are not suitable here. Because the method uses a generative adversarial learning strategy, the samples produced by the generator keep changing. An objective function such as MSE depends relatively heavily on the training data; it suits the case where the training data is real and fixed, and does not suit an adversarial learning process. MAE is therefore chosen as the objective function: it is less prone to vanishing gradients, and its gradient becomes small only when the two distributions are very similar.
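The contrast between MAE and MSE drawn in step 4-1 can be made concrete by comparing their gradients when the student is already close to the teacher. The sketch differentiates both losses with respect to the student's output vector only, which is a simplifying assumption: the network beneath the output is ignored, and the function names are hypothetical.

```python
def mae_grad(teacher, student):
    """Gradient of mean |t - s| with respect to each student output."""
    n = len(teacher)
    return [(1.0 if s > t else -1.0 if s < t else 0.0) / n
            for t, s in zip(teacher, student)]

def mse_grad(teacher, student):
    """Gradient of mean (t - s)^2 with respect to each student output."""
    n = len(teacher)
    return [2.0 * (s - t) / n for t, s in zip(teacher, student)]

teacher_out = [0.9, 0.1]
student_out = [0.89, 0.11]
grad_mae = mae_grad(teacher_out, student_out)
grad_mse = mse_grad(teacher_out, student_out)
```

For this pair of outputs each MAE gradient component has magnitude 0.5, while each MSE component has magnitude of about 0.01, so the MSE signal shrinks with the gap whereas the MAE signal does not.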
MAE can be used to measure the degree of similarity of two output vectors, the more similar the two vectors, the smaller the value; the larger the difference, the larger the value. The purpose of the optimizer is to continuously reduce the loss value, so the objective function L S for the student network is directly set as the MAE value output by the student network and the teacher network, and the specific formula is as follows:
In the above formula, T represents the teacher network, S represents the student network, G_h represents the difficult sample generator G1, z represents a random vector sampled from a normal distribution, and ‖·‖_1 represents the L1 norm. The purpose of the difficult sample generator G1 is to keep generating samples that the teacher network distinguishes easily but the student network does not, which is equivalent to enlarging the difference between the two output vectors. The objective function L_G1 of G1 is therefore set to the negative MAE between the outputs of the student network and the teacher network. The specific formula is as follows:
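This formula is also missing from the text. Since the paragraph defines L_G1 as the negative of the MAE between the teacher and student outputs on G1's samples, a plausible reconstruction (formula 3 in claim 5; the normalization by n is an assumption carried over from that claim) is:

```latex
L_{G1} = -\frac{1}{n} \left\| T\big(G_h(z)\big) - S\big(G_h(z)\big) \right\|_1
```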
Step 5. Iteratively train the adversarial knowledge distillation framework that has been built. The specific steps are as follows:
5-1. Set the total number of training rounds. This method uses 500 rounds, each containing 50 training iterations. Each iteration consists of two phases in sequence: an imitation phase and a generation phase.
5-2. In the imitation phase, the generator is fixed and the student network is optimized. First, random vectors following the standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data. The generated data is then fed into the student network and the teacher network to obtain the corresponding outputs, the value of the objective function L_S is calculated, and finally the error is back-propagated to optimize the student network. This training process is repeated 5 times. An SGD optimizer is used to optimize the student network for all data sets, with dataset-specific parameters. For the MNIST data set, the learning rate is 0.01, the weight decay is 0.0001, the momentum is 0.9, and the batch size of the generated data is 512. For the CIFAR10 data set, the momentum is 0.9, the weight decay is 0.0005, the batch size is 128, and the initial learning rate is 0.1; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200. The settings for CIFAR100 are essentially the same as for CIFAR10, except that the batch size is 256.
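The stepwise learning-rate reduction described for the CIFAR settings can be sketched in plain Python. The function name `step_lr` and the choice to apply each reduction starting at the milestone round itself (1-indexed) are assumptions made for illustration; the text only says the rate is reduced at rounds 100 and 200.

```python
def step_lr(initial_lr, round_idx, milestones=(100, 200), factor=0.1):
    """Learning rate for a 1-indexed training round under a step schedule.

    Assumption: the reduction takes effect at the milestone round itself.
    """
    # Count how many milestones have been reached by this round.
    drops = sum(1 for m in milestones if round_idx >= m)
    return initial_lr * (factor ** drops)


# Student network on CIFAR10/CIFAR100: initial learning rate 0.1, 500 rounds.
lr_round_1 = step_lr(0.1, 1)      # before any reduction
lr_round_100 = step_lr(0.1, 100)  # after the first 10x reduction
lr_round_500 = step_lr(0.1, 500)  # after both reductions
```

The same helper would apply to the generator's schedule by passing its initial learning rate instead.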
5-3. In the generation phase, the student network is fixed and the difficult sample generator G1 is optimized. First, random vectors following the standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data. The generated data is then fed into the student network and the teacher network to obtain the corresponding outputs, the value of the objective function L_G1 is calculated, and finally the error is back-propagated to optimize the difficult sample generator G1. Regardless of the data set, an Adam optimizer is used to optimize the difficult sample generator G1, with dataset-specific parameters. If the teacher network was trained on the MNIST data set, the learning rate is 0.001 and the batch size of the generated data is 512. If the teacher network was trained on the CIFAR10 data set, the batch size of the generated data is 128 and the initial learning rate is 0.001; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200. The settings for CIFAR100 are essentially the same as for CIFAR10, except that the batch size of the generated data is 256.
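The two objective values used in steps 5-2 and 5-3 can be illustrated on plain Python lists standing in for the network outputs. This is a minimal sketch: the function names and the example vectors are hypothetical, and no gradient computation or network is involved.

```python
def mae(teacher_out, student_out):
    """Mean absolute error between two equal-length output vectors."""
    if len(teacher_out) != len(student_out):
        raise ValueError("output vectors must have the same length")
    n = len(teacher_out)
    return sum(abs(t - s) for t, s in zip(teacher_out, student_out)) / n


def student_loss(teacher_out, student_out):
    """L_S: the student is trained to shrink the teacher/student gap."""
    return mae(teacher_out, student_out)


def hard_generator_loss(teacher_out, student_out):
    """L_G1: the difficult sample generator is trained to widen the gap."""
    return -mae(teacher_out, student_out)


# Hypothetical outputs of the two networks for one generated sample.
teacher_out = [2.0, 0.5, -1.0]
student_out = [1.0, 0.5, 0.0]

loss_s = student_loss(teacher_out, student_out)
loss_g1 = hard_generator_loss(teacher_out, student_out)
```

Minimizing `loss_s` pulls the student toward the teacher, while minimizing `loss_g1` pushes the generator toward samples on which the two networks disagree.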
5-4. Steps 5-2 and 5-3 are iterated 50 times; once these iterations are complete, the next training round begins, until all training rounds are finished.
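The order of updates in step 5 can be sketched as a generator that only enumerates phases, without any actual networks. The function name `adversarial_schedule` and its parameter names are hypothetical; the default counts (500 rounds, 50 iterations per round, 5 imitation steps, 1 generation step) are taken from the text.

```python
def adversarial_schedule(rounds=500, iters_per_round=50,
                         student_steps=5, generator_steps=1):
    """Yield (round, iteration, phase) tuples in the order updates would run."""
    for r in range(1, rounds + 1):
        for it in range(1, iters_per_round + 1):
            # Imitation phase: generator fixed, student network updated.
            for _ in range(student_steps):
                yield r, it, "imitation"
            # Generation phase: student network fixed, generator updated.
            for _ in range(generator_steps):
                yield r, it, "generation"


# A shortened run (2 rounds of 3 iterations) to inspect the ordering.
steps = list(adversarial_schedule(rounds=2, iters_per_round=3))
first_iteration = [phase for r, it, phase in steps if r == 1 and it == 1]
```

With the default values this enumerates 500 × 50 × 6 update steps, five of every six being student updates.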
Step 6. Introduce a simple sample generator G2 and use the difficult sample generator G1 and the simple sample generator G2 alternately to adjust the student network. The specific steps are as follows:
6-1. After the adversarial training of the student network and the difficult sample generator G1 has completed the specified number of rounds in step 5, this step introduces a generator G2 that generates simple samples, which also joins the overall training process. One problem is how to initialize the simple sample generator G2, because a poorly initialized generator would break the balance of the overall generative adversarial training. Here, the difficult sample generator G1 trained in step 5 is copied directly to initialize the simple sample generator G2.
6-2. Next, consider the objective function of the simple sample generator G2. The purpose of G2 is to generate simple samples, that is, samples for which the outputs of the teacher network and the student network are relatively similar, so MAE is again used as the objective function L_G2 of G2. The specific form is as follows:
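The formula is missing from this text as well. Given that L_G2 is described as the MAE between the teacher and student outputs on G2's samples (formula 4 in claim 1), a plausible reconstruction, with the normalization by the number of classification categories n assumed by analogy with L_S, is:

```latex
L_{G2} = \frac{1}{n} \left\| T\big(G_s(z)\big) - S\big(G_s(z)\big) \right\|_1
```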
In the above formula, T represents the teacher network, S represents the student network, G_s represents the simple sample generator G2, and z represents a random vector sampled from a normal distribution.
6-3. The iterative training process of this step has already been described in the summary of the invention, so only the optimizers and training parameter settings are added here. For all data sets, the student network is optimized with an SGD optimizer, with dataset-specific parameters. For the MNIST data set, the learning rate is 0.01, the momentum is 0.9, and the weight decay is 0.00001. For the CIFAR10 and CIFAR100 data sets, the learning rate is 0.001, the momentum is 0.9, and the weight decay is 0.00005. For all data sets, the difficult sample generator G1 and the simple sample generator G2 are optimized with Adam optimizers, with dataset-specific parameters. For the MNIST data set, the optimizer learning rate of G1 is 0.001, the optimizer learning rate of G2 is 0.0001, and the batch size of the generated data is 512. For the CIFAR10 data set, the learning rate is 0.00001 and the batch size of the generated data is 128. For the CIFAR100 data set, the learning rate is 0.00001 and the batch size of the generated data is 256.
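The alternation between the two generators, as recited in step 6-3 of claim 1 (100 rounds, 3 rounds with G1 for every 1 round with G2, the first round using G2, and a single imitation step when G2 is used), can be sketched as follows. The helper names are hypothetical, and the repeating 4-round cycle G2, G1, G1, G1 is one reading of the claim, not a confirmed detail.

```python
def pick_generator(round_idx, cycle=4):
    """Return the generator used in a 0-indexed round.

    Assumption: each 4-round cycle starts with G2, followed by 3 rounds of G1.
    """
    return "G2" if round_idx % cycle == 0 else "G1"


def imitation_steps(generator_name):
    """Student updates per iteration: 5 with G1 (as in step 5), 1 with G2."""
    return 1 if generator_name == "G2" else 5


# Plan for the 100 rounds of step 6: (round, generator, imitation steps).
plan = []
for r in range(100):
    name = pick_generator(r)
    plan.append((r, name, imitation_steps(name)))

g1_rounds = sum(1 for _, name, _ in plan if name == "G1")
g2_rounds = sum(1 for _, name, _ in plan if name == "G2")
```

Under this reading, G2 is used in 25 of the 100 rounds and G1 in the remaining 75.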
Experimental results:
1. The classification accuracy of this method was tested on the MNIST, CIFAR10, and CIFAR100 data sets and compared against data-free adversarial distillation (DFAD) and data-free knowledge distillation (KD-ORI); the results of training the student network and the teacher network on their own, without knowledge distillation, are also shown. The detailed results are given in Table 1.
Table 1. Results of comparative experiments between this method and other training methods on different data sets

Claims (9)

1. A multi-network combined auxiliary generation type knowledge distillation method is characterized by comprising the following steps of:
Step 1: preprocessing an image classification data set;
Step 2: selecting a teacher network model according to the determined image classification data set and training it;
Step 3: selecting a difficult sample generator G1 and a student network according to the determined image classification data set to form an adversarial knowledge distillation framework;
Step 4: establishing the objective functions for generative adversarial knowledge distillation;
Step 5: iteratively training the adversarial knowledge distillation framework that has been built;
Step 6: introducing a simple sample generator G2, and alternately using the difficult sample generator G1 and the simple sample generator G2 to adjust the student network to obtain the final result;
the specific steps of step 6 are as follows:
6-1. in step 5, after the adversarial training of the student network and the difficult sample generator G1 has completed the set number of rounds, copying the difficult sample generator G1 as the simple sample generator G2;
6-2. the simple sample generator G2 also takes part in the training process; its purpose is to generate simple samples, for which the student network and the teacher network produce outputs with a small difference, so the objective function L_G2 is as shown in the following formula 4:
where G_s represents the generator G2 that generates simple samples;
6-3. after the simple sample generator G2 is introduced, adversarial training of the student network continues, again in an iterative manner; first, the total number of training rounds is set to 100, each round containing 50 training iterations; each round first determines whether to use the difficult sample generator G1 or the simple sample generator G2; the strategy is to train 3 rounds with G1 for every 1 round with G2, with the first round using the simple sample generator G2; whichever generator is used, each training iteration is divided into two successive phases: an imitation phase and a generation phase; if G1 is used, both phases operate as in step 5; if G2 is used, the imitation-phase training is performed only once before moving to the generation phase; the generation phase also operates as in step 5, except that the objective function calculated is L_G2 and the simple sample generator G2 is optimized.
2. The multi-network joint auxiliary generation type knowledge distillation method according to claim 1, wherein the specific steps of step 1 are as follows:
1-1. data preparation and preprocessing;
selecting a public data set, the public data set being the MNIST, CIFAR10, or CIFAR100 data set; for the MNIST data set, first upscaling the resolution to 32x32, then performing image normalization, and finally performing standardization; for the CIFAR10 and CIFAR100 data sets, directly performing image normalization and then standardization;
1-2. image augmentation; MNIST data is too simple, so no image augmentation is performed; the same image augmentation is performed on both the CIFAR10 and CIFAR100 data sets: padding the image by 4 pixels on the top, bottom, left, and right, then cropping randomly; finally, flipping the cropped image randomly with a probability of 0.5.
3. The multi-network joint auxiliary generation type knowledge distillation method according to claim 2, wherein the specific steps of step 2 are as follows:
2-1. the teacher networks corresponding to different image classification data sets are different; on the MNIST data set, the teacher network uses LeNet; for CIFAR10 and CIFAR100, the teacher network uses ResNet, modified from the original ResNet as follows: first, the parameters of the first convolution layer are modified, changing kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max pooling layer following the first convolution layer is deleted; finally, the average pooling layer before the fully connected layer is deleted;
2-2. for different data sets and different teacher networks, the training method is the same, but the parameter configuration differs so that each teacher network achieves its best effect; the training method is as follows: first, the total number of training rounds is set; in each round, the training set of the selected data set is input into the teacher network to obtain its output values; the output values and the training set labels are put into the objective function together for calculation; finally, the error is back-propagated to optimize the teacher network; the objective function is multi-class cross entropy.
4. The multi-network combined auxiliary generation type knowledge distillation method as claimed in claim 3, wherein the specific steps of the step 3 are as follows:
3-1. for all data sets, the difficult sample generator G1 uses the generator network of DCGAN;
3-2. different data sets correspond to different student networks;
on the MNIST data set, the student network uses LeNet-Half; for the CIFAR10 and CIFAR100 data sets, the student network uses ResNet18, modified from the original ResNet as follows: first, the parameters of the first convolution layer are modified, changing kernel_size from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max pooling layer following the first convolution layer is deleted; finally, the average pooling layer before the fully connected layer is deleted;
3-3. the difficult sample generator G1, the student network, and the trained teacher network together form the overall training framework, in which the student network and the teacher network are combined to serve as the discriminator; the whole training process is divided into two phases, an imitation phase and a generation phase; in the imitation phase, the generator is fixed and the student network is updated; in the generation phase, the student network is fixed and the generator is updated.
5. The multi-network joint auxiliary generation type knowledge distillation method according to claim 4, wherein the specific steps of step 4 are as follows:
4-1. selecting the mean absolute error (MAE) as the objective function, in the specific form shown in the following formula 1, wherein n represents the number of samples, y_i represents the true value, and ŷ_i represents the predicted value;
4-2. in the imitation phase, the objective function of the student network is denoted L_S, in the specific form shown in the following formula 2, wherein T represents the teacher network, S represents the student network, n represents the number of classification categories, G_h represents the difficult sample generator G1, z represents a random vector sampled from a normal distribution, and ‖·‖_1 represents the L1 norm; in practical terms, the objective function L_S is the MAE between the output results of the student network and the teacher network; in the generation phase, the objective function of the difficult sample generator G1 is denoted L_G1, in the specific form shown in the following formula 3, and corresponds to the negative MAE between the output results of the student network and the teacher network;
6. The multi-network joint auxiliary generation type knowledge distillation method according to claim 5, wherein the specific steps of step 5 are as follows:
5-1. setting the number of training rounds to 500, each round containing 50 training iterations; each training iteration comprises two phases in sequence: an imitation phase and a generation phase;
5-2. in the imitation phase, the difficult sample generator G1 does not update its parameters, and the student network updates its parameters; first, random vectors following the standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data; the generated data is then fed into the student network and the teacher network to obtain the corresponding outputs, the value of the objective function L_S is calculated, and finally the error is back-propagated to optimize the student network; this training process must be repeated 5 times before entering the generation phase;
5-3. in the generation phase, the student network does not update its parameters, and the difficult sample generator G1 updates its parameters; first, random vectors following the standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data; the generated data is then fed into the student network and the teacher network to obtain the corresponding outputs, the value of the objective function L_G1 is calculated, and finally the error is back-propagated to optimize the difficult sample generator G1;
5-4. the imitation phase and the generation phase are iterated 50 times; once these iterations are complete, the next training round begins, until all training rounds are finished.
7. The multi-network joint auxiliary generation type knowledge distillation method according to claim 3, wherein the training method adopted in step 2 is the same for different data sets and different teacher networks, but the parameter configuration differs, specifically as follows:
for the MNIST data set, stochastic gradient descent (SGD) is selected to optimize the teacher model, with the momentum set to 0.9, the weight decay to 0.0001, the learning rate to 0.01, and the batch size to 256, training for 10 rounds in total; for the CIFAR10 and CIFAR100 data sets, the teacher model is likewise optimized with the SGD optimizer, except that the weight decay is 0.0005, the batch size is 128, the initial learning rate is 0.1, and training runs for 200 rounds in total, with the learning rate reduced by a factor of 10 every 80 rounds.
8. The multi-network joint auxiliary generation type knowledge distillation method according to claim 6, wherein, in the imitation phase of step 5:
for all data sets, an SGD optimizer is used to optimize the student network, with dataset-specific parameters; for the MNIST data set, the learning rate is 0.01, the weight decay is 0.0001, the momentum is 0.9, and the batch size of the generated data is 512; for the CIFAR10 data set, the momentum is 0.9, the weight decay is 0.0005, the batch size is 128, and the initial learning rate is 0.1; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200; the settings for CIFAR100 are essentially the same as for CIFAR10, except that the batch size is 256;
and in the generation phase:
for all data sets, an Adam optimizer is used to optimize the difficult sample generator G1, with dataset-specific parameters; when the teacher network was trained on the MNIST data set, the learning rate is 0.001 and the batch size of the generated data is 512; when the teacher network was trained on the CIFAR10 data set, the batch size of the generated data is 128 and the initial learning rate is 0.001; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200; the settings for CIFAR100 are essentially the same as for CIFAR10, except that the batch size of the generated data is 256.
9. The multi-network joint auxiliary generation type knowledge distillation method according to claim 6, wherein the adversarial training in step 6 is specifically as follows:
for all data sets, the student network is optimized with an SGD optimizer, with dataset-specific parameters; for the MNIST data set, the learning rate is 0.01, the momentum is 0.9, and the weight decay is 0.00001; for the CIFAR10 and CIFAR100 data sets, the learning rate is 0.001, the momentum is 0.9, and the weight decay is 0.00005; for all data sets, the difficult sample generator G1 and the simple sample generator G2 are optimized with Adam optimizers, with dataset-specific parameters; for the MNIST data set, the optimizer learning rate of G1 is 0.001, the optimizer learning rate of G2 is 0.0001, and the batch size of the generated data is 512; for the CIFAR10 data set, the learning rate is 0.00001 and the batch size of the generated data is 128; for the CIFAR100 data set, the learning rate is 0.00001 and the batch size of the generated data is 256.
Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210172188.2A CN114549901B (en) 2022-02-24 2022-02-24 Multi-network combined auxiliary generation type knowledge distillation method

Publications (2)

Publication Number Publication Date
CN114549901A (en) 2022-05-27
CN114549901B (en) 2024-05-14

Family

ID=81676859

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210172188.2A Active CN114549901B (en) 2022-02-24 2022-02-24 Multi-network combined auxiliary generation type knowledge distillation method

Country Status (1)

Country Link
CN (1) CN114549901B (en)

Legal Events

Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
CB03 Change of inventor or designer information
Inventor after: Wang Yilin, Kuang Zhenzhong, Ding Jiajun, Gu Xiaoling, Yu Jun
Inventor before: Kuang Zhenzhong, Wang Yilin, Ding Jiajun, Gu Xiaoling, Yu Jun
GR01 Patent grant