CN114549901A - Multi-network joint auxiliary generation type knowledge distillation method - Google Patents

Multi-network joint auxiliary generation type knowledge distillation method

Info

Publication number
CN114549901A
Authority
CN
China
Prior art keywords
training
network
sample generator
teacher
student
Legal status
Granted
Application number
CN202210172188.2A
Other languages
Chinese (zh)
Other versions
CN114549901B (en)
Inventor
匡振中
王一琳
丁佳骏
顾晓玲
俞俊
Current Assignee
Hangzhou Dianzi University
Original Assignee
Hangzhou Dianzi University
Application filed by Hangzhou Dianzi University
Priority to CN202210172188.2A
Publication of CN114549901A
Application granted
Publication of CN114549901B
Legal status: Active

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/084 Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The invention discloses a multi-network jointly assisted generative knowledge distillation method. First, an image classification data set is preprocessed. Next, a teacher network model is selected for the chosen data set and trained. A difficult sample generator G1 and a student network are then selected for the data set to form an adversarial knowledge distillation framework, and an objective function for generative adversarial knowledge distillation is established. The constructed adversarial knowledge distillation framework is trained iteratively. Finally, a simple sample generator G2 is introduced, and the student network is fine-tuned by alternating between the difficult sample generator G1 and the simple sample generator G2 to obtain the final result. The simple sample generator is created by directly copying the trained difficult sample generator, which adds no extra computation and is simple to implement. Because the simple sample generator helps the student network review simple samples, a better result is ultimately achieved on the target task.

Description

Multi-network joint auxiliary generation type knowledge distillation method
Technical Field
The invention belongs to the field of knowledge distillation within computer vision, and specifically provides a multi-network jointly assisted generative knowledge distillation method for image classification tasks.
Background
Convolutional neural networks (CNNs) have achieved remarkable results in image classification, segmentation, detection, and related fields thanks to their powerful feature extraction and representation capabilities. However, highly expressive networks have complex structures and huge numbers of parameters. Deploying a complete CNN therefore often requires a large memory footprint and a high-performance computing unit, which limits its use on embedded devices with limited computing resources and on mobile terminals with strict real-time requirements. There is thus an urgent need to make CNNs lightweight.
Knowledge distillation is currently a widely used model compression method. It treats the model to be compressed as the "teacher" and the compressed model as the "student". The teacher network has strong capacity, but its structure is complex and inconvenient to deploy; the student network has a simple structure, but performs poorly when trained directly. Knowledge distillation uses the teacher network to assist the training of the student network, improving the student's performance so that it approaches that of the teacher.
Most existing deep neural network compression and acceleration methods are very effective when the training data can be accessed directly. However, most of them fail when the training data is inaccessible for privacy or legal reasons, so some researchers have proposed model compression methods that do not require training data.
For example, one proposed method avoids retraining a lightweight network and instead prunes the fully connected layers of the original model directly, eliminating similar neurons so that the final output changes little. However, this method cannot be applied to convolutional layers, which greatly limits the achievable compression, and it fails when the internal structure of the model is unknown. Another proposed method reconstructs training data from "metadata" (such as the activation values of network layers) recorded while training the original model, but in most cases the original model does not retain such metadata.
It is easy to see that neither of the above methods is practical, and knowledge distillation combined with generative adversarial networks (GANs) was later proposed. As a generative model, a GAN can produce synthetic data that replaces the training data and can be used for knowledge distillation. Note that "data-free" does not mean that no data is used at all; it means that the data used to train the teacher network is not used.
For example, the adversarial distillation method (Fang G, Song J, Shen C, et al. Data-Free Adversarial Distillation [J]. 2019) introduces a generator alongside a pre-trained teacher network and combines the student network and the teacher network into a discriminator. The generator aims to produce difficult samples that enlarge the difference between the outputs of the teacher network and the student network, while the student network learns to reduce this difference. Through continued learning, the student gradually masters the difficult samples, which become simple samples that it handles easily. At that point the generator must keep searching the sample space for new difficult samples that the student has not yet mastered and that enlarge the output difference between the student and the teacher. The whole training process is therefore a generative adversarial process.
The problem with this adversarial knowledge distillation method is that the generator feeds only difficult samples to the student network during training. The student network may eventually forget the simple samples and mispredict them, which reduces overall performance.
Disclosure of Invention
The invention aims to overcome this shortcoming by providing a multi-network jointly assisted generative knowledge distillation method. The method adds a generator G2 that produces simple samples to assist the generator G1 that produces difficult samples. The two generators jointly fine-tune the student network so that, while pursuing performance on difficult samples, it does not neglect simple samples and lose overall performance.
During training, the difficult sample generator G1 and the simple sample generator G2 are used alternately. G1 aims to generate samples that make the output difference between the teacher network and the student network large, G2 aims to generate samples that make this difference small, and the student network always aims to reduce its output difference from the teacher network, whether G1 or G2 is in use. Because G2 helps the student network review simple samples, the method ultimately achieves better results on the target task than the original method.
A multi-network jointly assisted generative knowledge distillation method comprises the following steps:
Step 1: preprocess an image classification data set;
Step 2: select a teacher network model for the chosen image classification data set and train it;
Step 3: select a difficult sample generator G1 and a student network for the chosen data set to form an adversarial knowledge distillation framework;
Step 4: establish the objective functions for generative adversarial knowledge distillation;
Step 5: iteratively train the constructed adversarial knowledge distillation framework;
Step 6: introduce a simple sample generator G2 and fine-tune the student network by alternating between the difficult sample generator G1 and the simple sample generator G2 to obtain the final result.
Step 1 comprises the following sub-steps:
1-1. Data preparation and preprocessing.
A public data set is selected: MNIST, CIFAR10, or CIFAR100. For MNIST, the images are first upscaled to 32x32, then normalized, and finally standardized. For CIFAR10 and CIFAR100, the images are normalized directly and then standardized.
1-2. Image augmentation. MNIST is simple enough that no augmentation is applied. CIFAR10 and CIFAR100 receive the same augmentation: each image is padded by 4 pixels on the top, bottom, left, and right, then randomly cropped, and finally flipped horizontally at random with probability 0.5.
Step 2 comprises the following sub-steps:
2-1. Different image classification data sets use different teacher networks. On MNIST, the teacher network is LeNet. On both CIFAR10 and CIFAR100, the teacher network is a modified ResNet34: in the first convolutional layer, kernel_size is changed from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max-pooling layer that follows the first convolutional layer is removed; and finally the average pooling layer before the fully connected layer is removed.
2-2. The same training method is used for all data sets and teacher networks, but the parameter configurations differ so that each teacher network performs as well as possible. The training method is as follows: first set the total number of training rounds; in each round, feed the training split of the selected data set into the teacher network, compute the objective function from the teacher outputs and the training labels, back-propagate the error, and update the teacher network. The objective function is multi-class cross-entropy.
Step 3 comprises the following sub-steps:
3-1. For all data sets, the difficult sample generator G1 is the generator network of DCGAN (Deep Convolutional Generative Adversarial Networks).
3-2. Different data sets use different student networks.
On MNIST, the student network is LeNet-Half; on both CIFAR10 and CIFAR100, it is a modified ResNet18: in the first convolutional layer, kernel_size is changed from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max-pooling layer that follows the first convolutional layer is removed; and finally the average pooling layer before the fully connected layer is removed.
3-3. The difficult sample generator G1, the student network, and the trained teacher network together form the overall training framework, with the student network and the teacher network combined as the discriminator. Training is divided into two phases: an imitation phase and a generation phase. In the imitation phase, the generator is fixed and the student network is updated; in the generation phase, the student network is fixed and the generator is updated.
Step 4 comprises the following sub-steps:
4-1. The mean absolute error (MAE) is selected as the objective function, where $n$ is the number of samples, $y_i$ is the actual value, and $\hat{y}_i$ is the predicted value:
$$\mathrm{MAE}=\frac{1}{n}\sum_{i=1}^{n}\left|y_i-\hat{y}_i\right| \tag{1}$$
4-2. In the imitation phase, the objective function of the student network is denoted $L_S$ and given by Equation 2, where $T$ is the teacher network, $S$ is the student network, $n$ is the number of classes, $g_h$ is the difficult sample generator G1, $z$ is a random vector sampled from a normal distribution, and $\|\cdot\|_1$ is the L1 norm. $L_S$ is thus the MAE between the outputs of the student network and the teacher network. In the generation phase, the objective function of the difficult sample generator G1 is denoted $L_{G1}$ and given by Equation 3; it is the negative MAE between the outputs of the student network and the teacher network.
$$L_S=\frac{1}{n}\left\|T(g_h(z))-S(g_h(z))\right\|_1 \tag{2}$$
$$L_{G1}=-\frac{1}{n}\left\|T(g_h(z))-S(g_h(z))\right\|_1 \tag{3}$$
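Equations 2 and 3 can be sketched in NumPy as follows. This assumes that T and S return a raw output vector of length n (the number of classes) for each generated sample and that the per-sample losses are averaged over the generated batch; the function names are hypothetical. Both losses are computed on the same generated batch, and only the network being updated in the current phase would receive the gradient.

```python
import numpy as np


def output_mae(t_out: np.ndarray, s_out: np.ndarray) -> float:
    """(1/n) * ||T(x) - S(x)||_1 per sample, averaged over the batch.

    t_out, s_out: (batch, n) teacher and student outputs on the same
    generated batch; n is the number of classes.
    """
    n = t_out.shape[1]
    per_sample = np.abs(t_out - s_out).sum(axis=1) / n
    return float(per_sample.mean())


def student_loss(t_out: np.ndarray, s_out: np.ndarray) -> float:
    # Equation 2 (L_S): minimized by the student in the imitation phase.
    return output_mae(t_out, s_out)


def hard_generator_loss(t_out: np.ndarray, s_out: np.ndarray) -> float:
    # Equation 3 (L_G1): minimizing it makes G1 maximize the teacher-student gap.
    return -output_mae(t_out, s_out)


t = np.array([[1.0, 2.0, 3.0]])  # hypothetical teacher output, n = 3
s = np.array([[0.0, 2.0, 5.0]])  # hypothetical student output
print(student_loss(t, s), hard_generator_loss(t, s))  # 1.0 -1.0
```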
Step 5 comprises the following sub-steps:
5-1. Set the number of training rounds to 500, each comprising 50 training iterations. Each iteration consists of two phases: an imitation phase and a generation phase.
5-2. In the imitation phase, the parameters of the difficult sample generator G1 are not updated; only the student network is updated. First, random vectors following the standard normal distribution are generated and fed into G1 to obtain generated data. The generated data is then fed into both the student network and the teacher network to obtain their outputs, from which the objective function $L_S$ is computed. Finally, the error is back-propagated and the student network is optimized. Note that this process is repeated 5 times before entering the generation phase.
5-3. In the generation phase, the student network is not updated; only the difficult sample generator G1 is updated. First, random vectors following the standard normal distribution are generated and fed into G1 to obtain generated data. The generated data is then fed into both the student network and the teacher network to obtain their outputs, from which the objective function $L_{G1}$ is computed. Finally, the error is back-propagated and G1 is optimized.
5-4. The imitation phase and the generation phase are iterated 50 times; when the iterations finish, the next training round begins, until all training rounds are complete.
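The control flow of step 5 can be sketched as below. `student_step` and `generator_step` are hypothetical callbacks standing in for one optimizer update of the student (minimizing L_S with G1 fixed) and one update of G1 (minimizing L_G1 with the student fixed); the teacher stays frozen throughout. The sketch only shows the update schedule and counts how many updates each network receives.

```python
from typing import Callable, Dict


def run_step5(student_step: Callable[[], None],
              generator_step: Callable[[], None],
              rounds: int = 500,
              iters_per_round: int = 50,
              student_repeats: int = 5) -> None:
    """Alternate the imitation and generation phases as described in step 5."""
    for _ in range(rounds):
        for _ in range(iters_per_round):
            # Imitation phase: G1 frozen, student minimizes L_S, repeated 5 times.
            for _ in range(student_repeats):
                student_step()
            # Generation phase: student frozen, G1 minimizes L_G1 once.
            generator_step()


calls: Dict[str, int] = {"student": 0, "g1": 0}


def count(key: str) -> Callable[[], None]:
    # Hypothetical stand-in for a real optimizer step: just records the call.
    def step() -> None:
        calls[key] += 1
    return step


run_step5(count("student"), count("g1"))
print(calls)  # {'student': 125000, 'g1': 25000}
```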
Step 6 comprises the following sub-steps:
6-1. After the adversarial training of the student network and the difficult sample generator G1 in step 5 has completed the set number of rounds, G1 is copied to serve as the simple sample generator G2.
6-2. The simple sample generator G2 also participates in training. Its purpose is to generate simple samples, that is, samples for which the outputs of the student network and the teacher network differ little. Its objective function $L_{G2}$ is therefore given by Equation 4:
$$L_{G2}=\frac{1}{n}\left\|T(g_s(z))-S(g_s(z))\right\|_1 \tag{4}$$
where $g_s$ denotes the generator G2 that produces simple samples.
6-3. After G2 is introduced, adversarial training of the student network continues, again iteratively. The total number of training rounds is set to 100, each comprising 50 training iterations. At the start of each round, either G1 or G2 is chosen: one round with G2 is followed by three rounds with G1, and the first round uses G2. Whichever generator is used, each iteration has two phases: an imitation phase and a generation phase. If G1 is used, both phases proceed as in step 5. If G2 is used, the imitation phase runs only once before moving to the generation phase; the generation phase is otherwise the same as in step 5, except that the objective function computed is $L_{G2}$ and the generator optimized is G2.
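A schedule sketch of step 6, under the reading of 6-3 that rounds repeat in blocks of four (one G2 round followed by three G1 rounds, starting with G2). `run_step6`, `generator_for_round`, and the dictionary standing in for generator parameters are hypothetical; the sketch counts updates instead of training real networks.

```python
import copy
from typing import Dict, Tuple


def generator_for_round(r: int) -> str:
    """0-based round r: one G2 round followed by three G1 rounds, G2 first."""
    return "G2" if r % 4 == 0 else "G1"


def run_step6(g1_params: dict, rounds: int = 100,
              iters_per_round: int = 50) -> Tuple[dict, Dict[str, int]]:
    """Count updates per network under the step-6 schedule (no real training)."""
    # 6-1: G2 starts as an independent copy of the trained G1.
    g2_params = copy.deepcopy(g1_params)
    counts = {"G1": 0, "G2": 0, "student": 0}
    for r in range(rounds):
        gen = generator_for_round(r)
        # Imitation phase: 5 student updates with G1 (as in step 5), 1 with G2.
        student_repeats = 5 if gen == "G1" else 1
        for _ in range(iters_per_round):
            counts["student"] += student_repeats
            # Generation phase: one update of the active generator
            # (minimizing L_G1 for G1, or L_G2 for G2).
            counts[gen] += 1
    return g2_params, counts


g1 = {"w": [0.1, 0.2]}  # hypothetical stand-in for G1's trained parameters
g2, counts = run_step6(g1)
print(counts)  # {'G1': 3750, 'G2': 1250, 'student': 20000}
```

Because G2 is a deep copy, later updates to G2 leave G1's parameters untouched, which matches the two generators being optimized separately.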
Further, in step 2, the same training method is used for different data sets and teacher networks, but the parameter configurations differ, as follows:
For MNIST, the teacher model is optimized with stochastic gradient descent (SGD) with momentum 0.9, weight decay 0.0001, learning rate 0.01, and batch size 256, for 10 training rounds in total. For CIFAR10 and CIFAR100, the teacher model is also optimized with SGD, except that the weight decay is 0.0005, the batch size is 128, and the initial learning rate is 0.1; training runs for 200 rounds in total, with the learning rate divided by 10 every 80 rounds.
Further, the imitation phase of step 5 is configured as follows:
For all data sets, the student network is optimized with SGD, with data-set-specific parameters. For MNIST, the learning rate is 0.01, the weight decay 0.0001, the momentum 0.9, and the batch size of the generated data 512. For CIFAR10, the momentum is 0.9, the weight decay 0.0005, the batch size 128, and the initial learning rate 0.1; training runs for 500 rounds in total, with the learning rate divided by 10 at round 100 and again at round 200. CIFAR100 uses the same settings as CIFAR10, the only difference being a batch size of 256.
The generation phase is configured as follows:
For all data sets, the difficult sample generator G1 is optimized with Adam, with data-set-specific parameters. When the teacher network was trained on MNIST, the learning rate is 0.001 and the batch size of the generated data is 512. When the teacher network was trained on CIFAR10, the batch size of the generated data is 128 and the initial learning rate is 0.001; training runs for 500 rounds in total, with the learning rate divided by 10 at round 100 and again at round 200. CIFAR100 uses the same settings as CIFAR10, the only difference being a generated-data batch size of 256.
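The step-decay learning-rate schedules described for the CIFAR teacher in step 2 and for the student and G1 in step 5 can be sketched with one helper. Epochs are counted from 0 and a milestone m takes effect from epoch m onward, which matches the behavior of PyTorch's `torch.optim.lr_scheduler.MultiStepLR`; whether the patent's "round 100" is 0- or 1-based is not stated, so this indexing is an assumption.

```python
from typing import Sequence


def multistep_lr(base_lr: float, epoch: int, milestones: Sequence[int],
                 gamma: float = 0.1) -> float:
    """Learning rate at a 0-based epoch: base_lr * gamma**(milestones reached)."""
    reached = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** reached


# Step 2, teacher on CIFAR: 0.1, divided by 10 every 80 of 200 rounds.
teacher = [multistep_lr(0.1, e, milestones=(80, 160)) for e in range(200)]
# Step 5, student (SGD) and G1 (Adam) on CIFAR: divided by 10 at rounds 100 and 200.
student = [multistep_lr(0.1, e, milestones=(100, 200)) for e in range(500)]
g1 = [multistep_lr(1e-3, e, milestones=(100, 200)) for e in range(500)]
```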
Further, the adversarial training in step 6 is configured as follows:
For all data sets, the student network is optimized with SGD, with data-set-specific parameters. For MNIST, the learning rate is 0.01, the momentum 0.9, and the weight decay 0.00001. For CIFAR10 and CIFAR100, the learning rate is 0.001, the momentum 0.9, and the weight decay 0.00005. For all data sets, both the difficult sample generator G1 and the simple sample generator G2 are optimized with Adam, with data-set-specific parameters. For MNIST, the learning rate of G1's optimizer is 0.001, that of G2's optimizer is 0.0001, and the batch size of the generated data is 512. For CIFAR10, the learning rate is 0.00001 and the batch size of the generated data is 128. For CIFAR100, the learning rate is 0.00001 and the batch size of the generated data is 256.
The invention has the following beneficial effects:
Compared with the prior art, the method improves the experimental results to a certain extent. The invention additionally introduces a simple sample generator, which is created by directly copying the trained difficult sample generator; this adds no extra computation and is simple to implement. Because the simple sample generator helps the student network review simple samples, a better result is ultimately achieved on the target task.
Drawings
FIG. 1 is a flow chart of the steps of the present invention;
FIG. 2 is a flow chart of the overall architecture of the present invention.
Detailed Description
The invention will be further explained with reference to the drawings.
A multi-network jointly assisted generative knowledge distillation method comprises the specific steps shown in FIG. 1; the overall architecture flow chart is shown in FIG. 2:
Step 1: preprocess an image classification data set;
Step 2: select a teacher network model and train it;
Step 3: select a difficult sample generator G1 and a student network to form the overall training framework;
Step 4: establish the objective functions for generative adversarial knowledge distillation;
Step 5: iteratively train the constructed adversarial knowledge distillation framework;
Step 6: introduce a simple sample generator G2 and fine-tune the student network by alternating between the difficult sample generator G1 and the simple sample generator G2.
Step 1, data processing, comprises the following sub-steps:
1-1. An image classification data set is loaded. A public data set is selected: MNIST, CIFAR10, or CIFAR100. For MNIST, the images are first upscaled to 32x32, then normalized, and finally standardized. For CIFAR10 and CIFAR100, the images are normalized directly and then standardized. Taking a three-channel RGB image such as those in CIFAR10 as an example, the image is first normalized to [0, 1], because the raw values of each RGB channel lie in [0, 255], which hinders model convergence. Each channel of the image is then standardized with Equation 5 (so that its mean becomes 0 and its standard deviation 1), which further speeds up model convergence.
$$x'=\frac{x-\mu}{\sigma} \tag{5}$$
In the above equation, $x$ is a normalized pixel value, $x'$ is the standardized value, and $\mu$ and $\sigma$ are statistics of the original images: the mean and standard deviation of each channel.
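The [0, 1] normalization followed by Equation 5 can be sketched in NumPy as follows. The sketch assumes images stored as (H, W, C) uint8 arrays and computes μ and σ per channel over a batch of images; the patent does not say over which set of images its statistics are computed, and the random batch here is placeholder data, not a real data set.

```python
import numpy as np


def to_unit_range(img_uint8: np.ndarray) -> np.ndarray:
    # Normalization: map raw pixel values from [0, 255] to [0, 1].
    return img_uint8.astype(np.float64) / 255.0


def standardize(img: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Equation 5, applied per channel: (x - mu) / sigma on the last axis.
    return (img - mean) / std


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)  # placeholder images
x = to_unit_range(batch)
mu = x.mean(axis=(0, 1, 2))     # per-channel mean, shape (3,)
sigma = x.std(axis=(0, 1, 2))   # per-channel standard deviation, shape (3,)
z = standardize(x, mu, sigma)   # each channel now has mean 0 and std 1
```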
1-2. Image augmentation. Deep learning generally requires a sufficient number of samples: the more samples there are, the better the trained model performs and the stronger its generalization ability. In practice, however, samples are often insufficient in number or quality, so data augmentation is applied to them. Since all training data here are images, the term image augmentation is used below instead of data augmentation. The image augmentation operations applied to the three data sets in this method are as follows. MNIST is simple enough that no augmentation is applied. CIFAR10 and CIFAR100 receive the same augmentation: each image is padded by 4 pixels on the top, bottom, left, and right, then randomly cropped, and finally flipped horizontally at random with probability 0.5.
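The CIFAR augmentation can be sketched in NumPy as below. It assumes zero padding and a crop back to the original 32x32 size, neither of which the patent states explicitly; `augment` is a hypothetical helper. In a PyTorch pipeline the same effect is usually obtained with torchvision's `RandomCrop(32, padding=4)` followed by `RandomHorizontalFlip(p=0.5)`.

```python
import numpy as np


def augment(img: np.ndarray, rng: np.random.Generator,
            pad: int = 4, p_flip: float = 0.5) -> np.ndarray:
    """img: (H, W, C). Zero-pad, random-crop back to (H, W), random horizontal flip."""
    h, w = img.shape[:2]
    # Pad `pad` pixels on top/bottom and left/right; channels are not padded.
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    # Choose the crop's top-left corner uniformly among all valid offsets.
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    out = padded[top:top + h, left:left + w]
    # Flip left-right with probability p_flip.
    if rng.random() < p_flip:
        out = out[:, ::-1]
    return out


img = np.arange(32 * 32 * 3, dtype=np.float64).reshape(32, 32, 3)  # placeholder image
out = augment(img, np.random.default_rng(0))
print(out.shape)  # (32, 32, 3)
```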
Step 2, selecting and training a teacher network model, comprises the following sub-steps:
2-1. Different image classification data sets use different teacher networks. On MNIST, the teacher network is LeNet, which has 5 layers in total: 3 convolutional layers and 2 fully connected layers. The number of output units of the fully connected layer is set according to the training data set; for MNIST it is 10. On both CIFAR10 and CIFAR100, the teacher network is ResNet34. Because CIFAR10 and CIFAR100 images are only 32x32, which is relatively small, the original ResNet34 is modified: in the first convolutional layer, kernel_size is changed from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max-pooling layer that follows the first convolutional layer is removed; and finally the average pooling layer before the fully connected layer is removed.
2-2. For MNIST, the teacher model is optimized with stochastic gradient descent (SGD) with momentum 0.9, weight decay 0.0001, learning rate 0.01, and batch size 256, for 10 training rounds in total. For CIFAR10 and CIFAR100, the teacher model is also optimized with SGD, except that the weight decay is 0.0005, the batch size is 128, and the initial learning rate is 0.1; training runs for 200 rounds in total, with the learning rate divided by 10 every 80 rounds.
Step 3, selecting a difficult sample generator G1 and a student network to form the overall training framework, comprises the following sub-steps:
3-1. For all data sets, the difficult sample generator G1 is the generator network of DCGAN. Every GAN consists of two networks, a generator and a discriminator; this method does not need a dedicated discriminator, so only the DCGAN generator network is used as G1. DCGAN has several advantages over an ordinary GAN: for example, it removes all pooling layers and replaces them with convolutional layers, which makes the whole network scalable, and it uses batch normalization to speed up model convergence.
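The patent does not give G1's layer configuration. The sketch below assumes a typical DCGAN-style generator (z projected and reshaped to a 4x4 map with 512 channels, then transposed convolutions with kernel 4, stride 2, padding 1, halving the channels at each hidden layer) and only traces tensor shapes to show how such a generator reaches the 32x32 input size used here. The widths and function names are hypothetical.

```python
from typing import List, Tuple


def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2,
                       padding: int = 1) -> int:
    """Output side length of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def dcgan_generator_plan(img_size: int = 32, img_channels: int = 3,
                         base_channels: int = 512,
                         start: int = 4) -> List[Tuple[int, int]]:
    """(channels, side length) after the projection and each up-sampling layer."""
    plan = [(base_channels, start)]  # z projected and reshaped to start x start
    channels, size = base_channels, start
    while size < img_size:
        size = conv_transpose_out(size)  # kernel 4, stride 2, padding 1 doubles the size
        # Hidden layers halve the width; the last layer emits image channels.
        channels = img_channels if size >= img_size else channels // 2
        plan.append((channels, size))
    return plan


print(dcgan_generator_plan())                # [(512, 4), (256, 8), (128, 16), (3, 32)]
print(dcgan_generator_plan(img_channels=1))  # [(512, 4), (256, 8), (128, 16), (1, 32)]
```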
3-2. Different data sets use different student networks. On MNIST, the student network is LeNet-Half, which, compared with standard LeNet, halves the number of channels of the convolutional layers and the number of neurons of the fully connected layers. On both CIFAR10 and CIFAR100, the student network is ResNet18. Again, because CIFAR10 and CIFAR100 images have low resolution, the original ResNet18 is modified: in the first convolutional layer, kernel_size is changed from 7 to 3, stride from 2 to 1, and padding from 3 to 1; the max-pooling layer that follows the first convolutional layer is removed; and finally the average pooling layer before the fully connected layer is removed.
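To make the LeNet-Half description concrete, the sketch below counts parameters assuming the common LeNet-5 widths for 32x32 single-channel input: 5x5 convolutions with 6, 16, and 120 channels, then fully connected layers of 84 and 10 units. These widths are an assumption; the patent only states that LeNet-Half halves the convolution channels and fully connected neurons. The 10-unit output layer is kept at 10 because MNIST has 10 classes.

```python
def lenet_params(c1: int, c2: int, c3: int, fc: int,
                 num_classes: int = 10, in_ch: int = 1, k: int = 5) -> int:
    """Parameter count (weights plus biases) of a 3-conv + 2-FC LeNet."""
    def conv(cin: int, cout: int) -> int:
        return cin * cout * k * k + cout

    def dense(nin: int, nout: int) -> int:
        return nin * nout + nout

    # On 32x32 input the third 5x5 convolution yields a 1x1 map,
    # so the first fully connected layer takes c3 inputs.
    return (conv(in_ch, c1) + conv(c1, c2) + conv(c2, c3)
            + dense(c3, fc) + dense(fc, num_classes))


full = lenet_params(6, 16, 120, 84)  # assumed LeNet-5 widths (teacher)
half = lenet_params(3, 8, 60, 42)    # every hidden width halved (student)
print(full, half, round(half / full, 3))  # 61706 15738 0.255
```

Halving every hidden width roughly quarters the parameter count, because most layer sizes scale with the product of their input and output widths.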
3-3. The teacher network trained in the previous stage is introduced and combined with the student network (see the Background; reference: Fang G, Song J, Shen C, et al. Data-Free Adversarial Distillation [J]. 2019, Section 3.2) to form a discriminator that competes with the generator, which establishes the overall generative adversarial knowledge distillation framework. Because the training is generative and adversarial, it is divided into two phases, here called the imitation phase and the generation phase. In the imitation phase, the generator is fixed and only the student network within the discriminator is updated; in the generation phase, the student network is fixed and the generator is updated. Note that, so that the student network can keep pace with the generator during adversarial learning, 5 training iterations are performed in the imitation phase and only 1 in the generation phase.
Step 4, establishing the objective functions for generative adversarial knowledge distillation, comprises the following sub-steps:
4-1. The goal of knowledge distillation is for the student network, guided by the teacher network, to achieve performance similar to the teacher's, ideally matching the teacher's results. The objective function therefore typically makes the student network fit the teacher network's outputs. When the data used to train the teacher network is available, the KL (Kullback-Leibler) divergence or the MSE (mean squared error) is often used as the objective function, but these objectives are not suitable here. Because this method uses a generative adversarial learning strategy, the samples produced by the generator change constantly. Objectives such as MSE depend on the training data: they suit real, fixed training data but not an adversarial learning process. MAE is therefore chosen as the objective function, because its gradient does not easily vanish and becomes small only when the two distributions are extremely similar.
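The gradient argument for MAE over MSE can be checked numerically. The sketch below compares the gradients of the mean absolute error and the mean squared error with respect to a hypothetical student output as it approaches a teacher output: the MAE gradient keeps magnitude 1/n until the outputs coincide, whereas the MSE gradient shrinks in proportion to the remaining gap. The function names and the 10-element output are illustrative assumptions.

```python
import numpy as np


def mae_grad(s: np.ndarray, t: np.ndarray) -> np.ndarray:
    # Gradient of mean(|s - t|) with respect to s: sign(s - t) / n.
    return np.sign(s - t) / s.size


def mse_grad(s: np.ndarray, t: np.ndarray) -> np.ndarray:
    # Gradient of mean((s - t)**2) with respect to s: 2 (s - t) / n.
    return 2.0 * (s - t) / s.size


t = np.zeros(10)  # hypothetical teacher output over 10 classes
for gap in (1.0, 1e-2, 1e-4):
    s = np.full(10, gap)  # student output offset from the teacher by `gap`
    print(gap, np.abs(mae_grad(s, t)).max(), np.abs(mse_grad(s, t)).max())
```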
MAE measures the similarity of two output vectors: the more similar the vectors, the smaller the value; the larger their difference, the larger the value. Because the optimizer continually reduces the loss, the objective function L_S of the student network is set directly to the MAE between the outputs of the student network and the teacher network, as follows:
$$L_S = \frac{1}{n}\left\| T(g_h(z)) - S(g_h(z)) \right\|_1$$
In the above formula, T denotes the teacher network, S the student network, g_h the difficult sample generator G1, z a random vector sampled from a normal distribution, n the number of classification categories, and ||·||_1 the L1 norm. The difficult sample generator G1 aims to generate samples that the teacher network classifies easily but the student network does not; this is equivalent to making the difference between the two output vectors large. The objective function L_G1 of G1 is therefore set to the negative MAE between the outputs of the student network and the teacher network:
$$L_{G1} = -\frac{1}{n}\left\| T(g_h(z)) - S(g_h(z)) \right\|_1$$
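A minimal sketch of L_S and L_G1 in plain Python follows, assuming the network outputs are logit vectors and that, as is usual in mini-batch training, the per-sample MAE is averaged over the generated batch; the function names and numbers are illustrative, not taken from the patent. A real implementation would compute the same expression on framework tensors so that it can be backpropagated.

```python
def mae(teacher_out, student_out):
    """(1/n) * ||T - S||_1 for one sample, where n is the number of classes."""
    n = len(teacher_out)
    return sum(abs(t - s) for t, s in zip(teacher_out, student_out)) / n


def student_loss(teacher_batch, student_batch):
    """L_S: per-sample MAE averaged over the generated batch (minimized by S)."""
    per_sample = [mae(t, s) for t, s in zip(teacher_batch, student_batch)]
    return sum(per_sample) / len(per_sample)


def hard_generator_loss(teacher_batch, student_batch):
    """L_G1: the negative of the same MAE; minimizing it pushes G1 toward
    samples on which the student and the teacher disagree."""
    return -student_loss(teacher_batch, student_batch)


# Hypothetical logits for two generated samples and three classes.
tb = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.0]]
sb = [[1.0, 0.0, -1.0], [0.5, -0.5, 0.0]]
print(student_loss(tb, sb), hard_generator_loss(tb, sb))
```

Because L_G1 is exactly the negative of L_S on the same batch, the student and G1 optimize one quantity in opposite directions, which is what makes the setup a min-max game.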
Step 5, performing iterative training on the constructed adversarial knowledge distillation framework, with the following specific steps:
5-1. Set the total number of training rounds; this method uses 500 rounds, each comprising 50 training iterations. Each training iteration consists of two phases: a simulation phase and a generation phase.
5-2. In the simulation phase, the generator is fixed and the student network is optimized. First, random vectors following the standard normal distribution are generated and fed into the difficult sample generator G1 to obtain generated data; the generated data is then fed into both the student network and the teacher network to obtain their outputs, and the objective function L_S is computed. Finally, the error is backpropagated and the student network is optimized. This process is repeated 5 times. For all datasets, the student network is optimized with the SGD optimizer, with dataset-specific parameters. For MNIST, the learning rate is 0.01, the weight decay 0.0001, the momentum 0.9, and the batch size of generated data 512. For CIFAR10, the momentum is 0.9, the weight decay 0.0005, the batch size 128, and the initial learning rate 0.1; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200. CIFAR100 uses the same settings as CIFAR10 except that the batch size is 256.
5-3. In the generation phase, the student network is fixed and the difficult sample generator G1 is optimized. First, random vectors following the standard normal distribution are generated and fed into G1 to obtain generated data; the generated data is then fed into both the student network and the teacher network to obtain their outputs, and the objective function L_G1 is computed. Finally, the error is backpropagated and G1 is optimized. For every dataset, G1 is optimized with the Adam optimizer, with dataset-specific parameters. If the teacher network was trained on MNIST, the learning rate is 0.001 and the batch size of generated data is 512. If the teacher network was trained on CIFAR10, the batch size of generated data is 128 and the initial learning rate is 0.001; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200. CIFAR100 uses the same settings as CIFAR10 except that the batch size of generated data is 256.
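The optimizer settings of steps 5-2 and 5-3 can be collected into a single table. The sketch below is an illustrative configuration in plain Python with a multi-step learning-rate helper; the dictionary keys, the helper name, and the 0-based round indexing (the decayed rate taking effect from round index 100, as PyTorch's `MultiStepLR` does with milestones `[100, 200]` when stepped once per round) are assumptions. No decay schedule is given for MNIST, so none is applied.

```python
# Hyperparameters of steps 5-2 (student, SGD) and 5-3 (generator G1, Adam).
# The dictionary layout is illustrative; the values are those given in the text.
STAGE5_CONFIG = {
    "MNIST": {
        "student_sgd": {"lr": 0.01, "momentum": 0.9, "weight_decay": 1e-4},
        "generator_adam": {"lr": 1e-3},
        "batch_size": 512,   # batch size of generated data
        "milestones": [],    # no learning-rate decay is described for MNIST
    },
    "CIFAR10": {
        "student_sgd": {"lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4},
        "generator_adam": {"lr": 1e-3},
        "batch_size": 128,
        "milestones": [100, 200],  # both learning rates drop 10x at these rounds
    },
}
# CIFAR100 differs from CIFAR10 only in the batch size (shallow copy: the
# nested optimizer dicts are shared, which is fine as long as they are not mutated).
STAGE5_CONFIG["CIFAR100"] = {**STAGE5_CONFIG["CIFAR10"], "batch_size": 256}

TOTAL_ROUNDS = 500


def lr_at_round(initial_lr, milestones, round_idx, gamma=0.1):
    """Multi-step decay: multiply by gamma once for every milestone already reached.

    round_idx is assumed 0-based, so with milestones [100, 200] rounds 0-99 use
    the initial rate, rounds 100-199 use 1/10 of it, and rounds 200+ use 1/100.
    """
    passed = sum(1 for m in milestones if round_idx >= m)
    return initial_lr * gamma ** passed


cfg = STAGE5_CONFIG["CIFAR10"]
for r in (0, 100, 200, TOTAL_ROUNDS - 1):
    student_lr = lr_at_round(cfg["student_sgd"]["lr"], cfg["milestones"], r)
    generator_lr = lr_at_round(cfg["generator_adam"]["lr"], cfg["milestones"], r)
    print(r, student_lr, generator_lr)
```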
5-4. Steps 5-2 and 5-3 are iterated 50 times per round; when these iterations finish, the next round begins, until all training rounds are complete.
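The control flow of step 5 can be sketched as nested loops. In the sketch below, `student_step` and `generator_step` are hypothetical callbacks standing in for one SGD update of the student (with G1 frozen) and one Adam update of G1 (with the student frozen); here they only count calls, which makes the total number of updates implied by the schedule easy to check.

```python
from collections import Counter

ROUNDS = 500            # step 5-1
ITERS_PER_ROUND = 50    # step 5-1
SIMULATION_REPEATS = 5  # step 5-2: student updates per iteration
GENERATION_REPEATS = 1  # step 5-3: generator updates per iteration


def run_stage5(student_step, generator_step,
               rounds=ROUNDS, iters=ITERS_PER_ROUND):
    """Drive the two phases of step 5 in the order the text describes.

    student_step(r) and generator_step(r) are hypothetical callbacks that would
    perform one student update with G1 frozen (minimizing L_S) and one G1 update
    with the student frozen (minimizing L_G1); r is the 0-based round index,
    which a real implementation could use to look up the learning rate.
    """
    for r in range(rounds):
        for _ in range(iters):
            for _ in range(SIMULATION_REPEATS):
                student_step(r)      # simulation phase
            for _ in range(GENERATION_REPEATS):
                generator_step(r)    # generation phase


counts = Counter()
run_stage5(lambda r: counts.update(["student"]),
           lambda r: counts.update(["generator"]))
print(counts["student"], counts["generator"])  # 125000 25000
```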
Step 6, introducing a simple sample generator G2 and alternately adjusting the student network with the difficult sample generator G1 and the simple sample generator G2, with the following specific steps:
6-1. This step introduces a generator G2 that produces simple samples. G2 is introduced after the adversarial training of the student network and the difficult sample generator G1 in step 5 has completed the specified number of rounds, and is then added to the overall training process. How to initialize G2 is a concern, because a poor initialization would upset the balance of the whole generative adversarial training. The difficult sample generator G1 trained in step 5 is therefore copied directly to initialize the simple sample generator G2.
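Step 6-1 initializes G2 as a copy of the trained G1. The copy must be deep, so that later updates to G2 do not modify G1. The sketch below uses a hypothetical `ToyGenerator` holding a plain list of weights; with PyTorch modules, `copy.deepcopy(g1)` or `g2.load_state_dict(g1.state_dict())` would serve the same purpose, and G2 would need its own optimizer.

```python
import copy


class ToyGenerator:
    """Hypothetical stand-in for the DCGAN generator: just a list of weights."""

    def __init__(self, weights):
        self.weights = list(weights)


g1 = ToyGenerator([0.3, -1.2, 0.8])  # G1 after the step-5 adversarial training
g2 = copy.deepcopy(g1)               # G2 starts as an exact, independent copy

g2.weights[0] *= 0.5                 # a later update to G2 ...
print(g1.weights[0], g2.weights[0])  # ... leaves G1 untouched
```

A shallow copy (`copy.copy`) would share the `weights` list, so an update to G2 would silently change G1 as well.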
6-2. Next, consider the objective function of the simple sample generator G2. G2 aims to generate simple samples, for which the outputs of the teacher network and the student network should be relatively similar, so MAE is again used as the objective function L_G2 of G2, in the following form:
$$L_{G2} = \frac{1}{n}\left\| T(g_s(z)) - S(g_s(z)) \right\|_1$$
In the above formula, T denotes the teacher network, S the student network, g_s the simple sample generator G2, and z a random vector sampled from a normal distribution.
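A minimal sketch of L_G2 in plain Python follows, under the same assumptions as for L_S (logit outputs, per-sample MAE averaged over the generated batch); the names and values are illustrative. Unlike L_G1, L_G2 is the positive MAE, so minimizing it steers G2 toward samples on which the student already agrees with the teacher.

```python
def mae(teacher_out, student_out):
    """(1/n) * ||T - S||_1 for one sample, where n is the number of classes."""
    n = len(teacher_out)
    return sum(abs(t - s) for t, s in zip(teacher_out, student_out)) / n


def simple_generator_loss(teacher_batch, student_batch):
    """L_G2: per-sample MAE on samples from g_s, averaged over the batch."""
    per_sample = [mae(t, s) for t, s in zip(teacher_batch, student_batch)]
    return sum(per_sample) / len(per_sample)


# Hypothetical logits for two samples produced by G2 (three classes).
teacher = [[3.0, 0.0, -1.0], [1.0, 1.0, 0.0]]
student = [[3.0, 0.5, -1.0], [0.0, 1.0, 0.0]]
l_g2 = simple_generator_loss(teacher, student)
# On the same samples G1 would be scored with the opposite sign (-MAE).
print(l_g2, -l_g2)
```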
6-3. The iterative training process of this step is described in the summary of the invention; only the optimizers and training parameters are added here. For all datasets, the student network is optimized with the SGD optimizer, with dataset-specific parameters. For MNIST, the learning rate is 0.01, the momentum 0.9, and the weight decay 0.00001. For both CIFAR10 and CIFAR100, the learning rate is 0.001, the momentum 0.9, and the weight decay 0.00005. For all datasets, both the difficult sample generator G1 and the simple sample generator G2 are optimized with the Adam optimizer, with dataset-specific parameters. For MNIST, the learning rate of G1's optimizer is 0.001, that of G2's optimizer is 0.0001, and the batch size of generated data is 512. For CIFAR10, the learning rate is 0.00001 and the batch size of generated data is 128. For CIFAR100, the learning rate is 0.00001 and the batch size of generated data is 256.
The experimental results are as follows:
1. The classification accuracy of this method is compared with that of data-free adversarial knowledge distillation (DFAD) and data-driven knowledge distillation (KD-ORI) on the MNIST, CIFAR10, and CIFAR100 datasets; the accuracies of the student network and the teacher network, each trained independently without knowledge distillation, are also shown for comparison. The results are detailed in Table 1.
Table 1. Comparative experimental results of the present method and other training methods on different datasets

Claims (10)

1. A multi-network jointly assisted generative knowledge distillation method, characterized by comprising the following steps:
Step 1: preprocessing an image classification dataset;
Step 2: selecting and training a teacher network model according to the determined image classification dataset;
Step 3: selecting a difficult sample generator G1 and a student network according to the determined image classification dataset to form an adversarial knowledge distillation framework;
Step 4: establishing the objective functions for generative adversarial knowledge distillation;
Step 5: performing iterative training on the constructed adversarial knowledge distillation framework;
Step 6: introducing a simple sample generator G2 and alternately adjusting the student network with the difficult sample generator G1 and the simple sample generator G2 to obtain the final result.
2. The multi-network jointly assisted generative knowledge distillation method according to claim 1, wherein the specific steps of step 1 are as follows:
1-1. Data preparation and preprocessing:
a public dataset is selected, namely MNIST, CIFAR10, or CIFAR100; for MNIST, the images are first upscaled to 32x32, then normalized, and finally standardized; CIFAR10 and CIFAR100 images are normalized directly and then standardized;
1-2. Image augmentation: because MNIST data is very simple, no image augmentation is applied to it; the same augmentation is applied to CIFAR10 and CIFAR100: each image is first padded by 4 pixels on the top, bottom, left, and right and then randomly cropped; finally, the cropped image is randomly flipped horizontally with probability 0.5.
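The CIFAR augmentation of step 1-2 is sketched below in plain Python on a single-channel image stored as a list of rows. Zero padding and a crop size equal to the original 32x32 size are assumptions (the claim specifies neither); in torchvision the same pipeline would typically be written as `transforms.RandomCrop(32, padding=4)` followed by `transforms.RandomHorizontalFlip(p=0.5)`.

```python
import random


def pad(image, p=4, fill=0):
    """Pad a 2-D image (list of rows) by p pixels of `fill` on every side."""
    width = len(image[0]) + 2 * p
    rows = [[fill] * p + list(row) + [fill] * p for row in image]
    return [[fill] * width for _ in range(p)] + rows + [[fill] * width for _ in range(p)]


def random_crop(image, size, rng):
    """Cut a size x size window at a uniformly random position."""
    top = rng.randint(0, len(image) - size)      # randint includes both ends
    left = rng.randint(0, len(image[0]) - size)
    return [row[left:left + size] for row in image[top:top + size]]


def random_hflip(image, rng, p=0.5):
    """Mirror every row with probability p."""
    return [row[::-1] for row in image] if rng.random() < p else image


def augment_cifar(image, rng):
    """Pad by 4 px, crop back to the original size, then flip with probability 0.5."""
    size = len(image)
    return random_hflip(random_crop(pad(image, 4), size, rng), rng)


rng = random.Random(0)
img = [[r * 32 + c + 1 for c in range(32)] for r in range(32)]  # hypothetical 32x32 image
out = augment_cifar(img, rng)
print(len(out), len(out[0]))  # 32 32
```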
3. The multi-network jointly assisted generative knowledge distillation method according to claim 2, wherein the specific steps of step 2 are as follows:
2-1. Different image classification datasets use different teacher networks: on MNIST, the teacher network is LeNet; on CIFAR10 and CIFAR100, the teacher network is ResNet34, modified from the original ResNet34 as follows: the parameters of the first convolutional layer are changed (kernel_size from 7 to 3, stride from 2 to 1, padding from 3 to 1), the max-pooling layer following the first convolutional layer is deleted, and the average-pooling layer before the fully connected layer is deleted;
2-2. The same training method is used for every dataset and teacher network, with different parameter configurations so that each teacher network performs as well as possible. The training method is as follows: first, the total number of training rounds is set; in each round, the training split of the selected dataset is fed into the teacher network, the teacher network's outputs and the training labels are passed to the objective function, and finally the error is backpropagated and the teacher network is optimized; the objective function is the multi-class cross-entropy.
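The multi-class cross-entropy used to train the teacher in step 2-2 is -log softmax(z)[y] for logits z and label y. The sketch below computes it stably in plain Python via the log-sum-exp trick; the function names are illustrative, and in practice a framework loss such as PyTorch's `torch.nn.CrossEntropyLoss`, which also takes raw logits, would be used.

```python
import math


def cross_entropy(logits, label):
    """-log softmax(logits)[label], using log-sum-exp to avoid overflow."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def batch_cross_entropy(batch_logits, labels):
    """Mean cross-entropy over a batch of (logits, label) pairs."""
    return sum(cross_entropy(z, y) for z, y in zip(batch_logits, labels)) / len(labels)


print(cross_entropy([2.0, 0.5, -1.0], 0))  # small loss: class 0 has the largest logit
print(cross_entropy([2.0, 0.5, -1.0], 2))  # larger loss: class 2 has the smallest logit
```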
4. The multi-network jointly assisted generative knowledge distillation method according to claim 3, wherein the specific steps of step 3 are as follows:
3-1. For all datasets, the difficult sample generator G1 uses the generator network of DCGAN;
3-2. Different datasets use different student networks:
on MNIST, the student network is LeNet-Half; on both CIFAR10 and CIFAR100, the student network is ResNet18, modified from the original ResNet18 as follows: the parameters of the first convolutional layer are changed (kernel_size from 7 to 3, stride from 2 to 1, padding from 3 to 1), the max-pooling layer following the first convolutional layer is deleted, and the average-pooling layer before the fully connected layer is deleted;
3-3. The difficult sample generator G1, the student network, and the trained teacher network are combined into a complete training framework in which the student network and the teacher network together act as the discriminator; the training process is divided into two phases, a simulation phase and a generation phase: in the simulation phase, the generator is fixed and the student network is updated; in the generation phase, the student network is fixed and the generator is updated.
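One plausible reason for the stem modification in claims 3 and 4 is the small input size; the patent does not state the rationale, so this is an interpretation. The sketch below applies the standard output-size formula floor((size + 2*padding - kernel) / stride) + 1 to a 32x32 input: the standard ResNet stem (7x7 convolution with stride 2 and padding 3, followed by 3x3 max-pooling with stride 2 and padding 1) reduces it to 8x8 before the first residual stage, while the modified stem keeps the full 32x32 resolution.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a conv/pool layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output(size, modified):
    """Side length after the ResNet stem for a square input of side `size`."""
    if modified:
        # Modified stem: 3x3 conv, stride 1, padding 1, and no max-pooling layer.
        return conv_out(size, 3, 1, 1)
    # Standard stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1).
    return conv_out(conv_out(size, 7, 2, 3), 3, 2, 1)


print(stem_output(32, modified=False))  # 8
print(stem_output(32, modified=True))   # 32
```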
5. The multi-network jointly assisted generative knowledge distillation method according to claim 4, wherein the specific steps of step 4 are as follows:
4-1. The mean absolute error (MAE) is selected as the objective function, in the form of formula 1, where n denotes the number of samples, y_i the actual value, and \hat{y}_i the predicted value:
$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left| y_i - \hat{y}_i \right| \quad (1)$$
4-2. In the simulation phase, the objective function of the student network is denoted L_S and takes the form of formula 2, where T denotes the teacher network, S the student network, n the number of classification categories, g_h the difficult sample generator G1, z a random vector sampled from a normal distribution, and ||·||_1 the L1 norm; L_S is in effect the MAE between the outputs of the student network and the teacher network. In the generation phase, the objective function of the difficult sample generator G1 is denoted L_G1 and takes the form of formula 3, the negative MAE between the outputs of the student network and the teacher network:
$$L_S = \frac{1}{n}\left\| T(g_h(z)) - S(g_h(z)) \right\|_1 \quad (2)$$
$$L_{G1} = -\frac{1}{n}\left\| T(g_h(z)) - S(g_h(z)) \right\|_1 \quad (3)$$
6. The multi-network jointly assisted generative knowledge distillation method according to claim 5, wherein the specific steps of step 5 are as follows:
5-1. The number of training rounds is set to 500, each comprising 50 training iterations; each training iteration consists of two phases: a simulation phase and a generation phase;
5-2. In the simulation phase, the parameters of the difficult sample generator G1 are not updated and the parameters of the student network are updated: first, random vectors following the standard normal distribution are generated and fed into G1 to obtain generated data; the generated data is then fed into both the student network and the teacher network to obtain their outputs, and the objective function L_S is computed; finally, the error is backpropagated and the student network is optimized; this process is repeated 5 times before entering the generation phase;
5-3. In the generation phase, the parameters of the student network are not updated and the parameters of G1 are updated: first, random vectors following the standard normal distribution are generated and fed into G1 to obtain generated data; the generated data is then fed into both the student network and the teacher network to obtain their outputs, and the objective function L_G1 is computed; finally, the error is backpropagated and G1 is optimized;
5-4. The simulation phase and the generation phase are iterated 50 times; after these iterations finish, the next training round begins, until all training rounds are complete.
7. The multi-network jointly assisted generative knowledge distillation method according to claim 6, wherein the specific steps of step 6 are as follows:
6-1. After the adversarial training of the student network and the difficult sample generator G1 in step 5 has completed the set number of rounds, G1 is copied to serve as the simple sample generator G2;
6-2. The simple sample generator G2 also takes part in training; G2 aims to generate simple samples, for which the outputs of the student network and the teacher network differ only slightly, so its objective function L_G2 is as shown in formula 4:
$$L_{G2} = \frac{1}{n}\left\| T(g_s(z)) - S(g_s(z)) \right\|_1 \quad (4)$$
where g_s denotes the generator G2 that generates simple samples;
6-3. After G2 is introduced, adversarial training of the student network continues, again in an iterative manner: the total number of training rounds is set to 100, each comprising 50 training iterations; at the start of each round, it is decided whether to use the difficult sample generator G1 or the simple sample generator G2, following a strategy of 1 round with G2 for every 3 rounds with G1, with the first round using G2; whichever generator is used, each training iteration is divided into a simulation phase and a generation phase; if G1 is used, both phases operate as in step 5; if G2 is used, the simulation phase is run only once before moving to the generation phase, and the generation phase operates as in step 5 except that the objective function computed is L_G2 and the generator optimized is G2.
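The generator-selection rule of step 6-3 can be sketched as follows, assuming it means a repeating four-round cycle that starts with G2 (G2, G1, G1, G1, ...) and that rounds are 0-indexed; the function names are hypothetical. The sketch also counts how many student updates the 100-round schedule implies, given 5 simulation-phase repeats per iteration with G1 and 1 with G2.

```python
TOTAL_ROUNDS = 100     # step 6-3
ITERS_PER_ROUND = 50   # step 6-3


def generator_for_round(round_idx):
    """0-based rounds: G2 on round 0, then G1 for three rounds, repeating."""
    return "G2" if round_idx % 4 == 0 else "G1"


def simulation_repeats(generator):
    """Student updates per iteration: 5 with G1 (as in step 5), 1 with G2."""
    return 5 if generator == "G1" else 1


schedule = [generator_for_round(r) for r in range(TOTAL_ROUNDS)]
student_updates = sum(ITERS_PER_ROUND * simulation_repeats(g) for g in schedule)
print(schedule[:8])  # ['G2', 'G1', 'G1', 'G1', 'G2', 'G1', 'G1', 'G1']
print(schedule.count("G1"), schedule.count("G2"), student_updates)  # 75 25 20000
```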
8. The multi-network jointly assisted generative knowledge distillation method according to claim 3, wherein the training method adopted in step 2 is the same for different datasets and different teacher networks but with different parameter configurations, specifically as follows:
for MNIST, stochastic gradient descent (SGD) is used to optimize the teacher model, with momentum 0.9, weight decay 0.0001, learning rate 0.01, batch size 256, and 10 training rounds in total; for CIFAR10 and CIFAR100, the teacher model is also optimized with the SGD optimizer, except that the weight decay is 0.0005, the batch size 128, and the initial learning rate 0.1; training runs for 200 rounds in total, and the learning rate is reduced by a factor of 10 every 80 rounds.
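For the CIFAR teachers in claim 8, a rate divided by 10 every 80 rounds is a step decay, which PyTorch expresses as `StepLR(optimizer, step_size=80, gamma=0.1)`. The plain-Python sketch below computes the same rate per 0-based round; the function name and the configuration dictionary are illustrative, and the CIFAR momentum of 0.9 is read from the claim's "also ... except" wording.

```python
TEACHER_CONFIG = {
    # Values from claim 8; the dictionary layout is illustrative.
    "MNIST": {"lr": 0.01, "momentum": 0.9, "weight_decay": 1e-4,
              "batch_size": 256, "rounds": 10, "lr_step": None},
    "CIFAR": {"lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4,
              "batch_size": 128, "rounds": 200, "lr_step": 80},
}


def teacher_lr(round_idx, initial_lr, step, gamma=0.1):
    """Step decay: divide the rate by 10 every `step` rounds (0-based rounds).

    step=None means a constant rate, as no decay is described for MNIST.
    """
    if step is None:
        return initial_lr
    return initial_lr * gamma ** (round_idx // step)


cifar = TEACHER_CONFIG["CIFAR"]
# Rounds 0-79 use 0.1, rounds 80-159 use 0.01, rounds 160-199 use 0.001.
for r in (0, 79, 80, 159, 160, cifar["rounds"] - 1):
    print(r, teacher_lr(r, cifar["lr"], cifar["lr_step"]))
```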
9. The multi-network jointly assisted generative knowledge distillation method according to claim 6, wherein, in step 5, the simulation phase is as follows:
for all datasets, the student network is optimized with the SGD optimizer, with dataset-specific parameters; for MNIST, the learning rate is 0.01, the weight decay 0.0001, the momentum 0.9, and the batch size of generated data 512; for CIFAR10, the momentum is 0.9, the weight decay 0.0005, the batch size 128, and the initial learning rate 0.1; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200; CIFAR100 uses the same settings as CIFAR10 except that the batch size is 256;
and the generation phase is as follows:
for all datasets, the difficult sample generator G1 is optimized with the Adam optimizer, with dataset-specific parameters; when the teacher network was trained on MNIST, the learning rate is 0.001 and the batch size of generated data is 512; when the teacher network was trained on CIFAR10, the batch size of generated data is 128 and the initial learning rate 0.001; training runs for 500 rounds in total, and the learning rate is reduced by a factor of 10 at round 100 and again at round 200; CIFAR100 uses the same settings as CIFAR10 except that the batch size of generated data is 256.
10. The multi-network jointly assisted generative knowledge distillation method according to claim 7, wherein the adversarial training in step 6 is as follows:
for all datasets, the student network is optimized with the SGD optimizer, with dataset-specific parameters; for MNIST, the learning rate is 0.01, the momentum 0.9, and the weight decay 0.00001; for both CIFAR10 and CIFAR100, the learning rate is 0.001, the momentum 0.9, and the weight decay 0.00005; for all datasets, both the difficult sample generator G1 and the simple sample generator G2 are optimized with the Adam optimizer, with dataset-specific parameters; for MNIST, the learning rate of G1's optimizer is 0.001, that of G2's optimizer is 0.0001, and the batch size of generated data is 512 in both cases; for CIFAR10, the learning rate is 0.00001 and the batch size of generated data is 128; for CIFAR100, the learning rate is 0.00001 and the batch size of generated data is 256.
CN202210172188.2A 2022-02-24 2022-02-24 Multi-network combined auxiliary generation type knowledge distillation method Active CN114549901B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210172188.2A CN114549901B (en) 2022-02-24 2022-02-24 Multi-network combined auxiliary generation type knowledge distillation method

Publications (2)

Publication Number Publication Date
CN114549901A true CN114549901A (en) 2022-05-27
CN114549901B CN114549901B (en) 2024-05-14

Family

ID=81676859

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210172188.2A Active CN114549901B (en) 2022-02-24 2022-02-24 Multi-network combined auxiliary generation type knowledge distillation method

Country Status (1)

Country Link
CN (1) CN114549901B (en)

Also Published As

Publication number Publication date
CN114549901B (en) 2024-05-14

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
CB03 Change of inventor or designer information

Inventor after: Wang Yilin

Inventor after: Kuang Zhenzhong

Inventor after: Ding Jiajun

Inventor after: Gu Xiaoling

Inventor after: Yu Jun

Inventor before: Kuang Zhenzhong

Inventor before: Wang Yilin

Inventor before: Ding Jiajun

Inventor before: Gu Xiaoling

Inventor before: Yu Jun

GR01 Patent grant