CN114972904A - Zero sample knowledge distillation method and system based on triple loss resistance - Google Patents
- Publication number
- CN114972904A (application CN202210401592.2A)
- Authority
- CN
- China
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Classifications (all under G06V10/00: Arrangements for image or video recognition or understanding; section G, Physics; class G06, Computing; calculating or counting; subclass G06V, Image or video recognition or understanding)
- G06V10/774: Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/764: Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/778: Active pattern-learning, e.g. online learning of image or video features
- G06V10/82: Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Databases & Information Systems (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
The invention relates to a zero-shot knowledge distillation method and system based on an adversarial triplet loss, and belongs to the technical field of computer vision. The distance-weighted triplet sampling strategy provided by the invention samples the triplet representation set more effectively from the student's feature space, and only encourages positive samples to keep a certain distance from each other. By using order-preserving regression, the margin-based loss focuses on relative order, which improves image generation efficiency and image quality. The method does not require access to the original training set: the user provides only the parameters of a pre-trained teacher model, and the student model is trained automatically from those parameters and the images produced by the generator. Without requiring special equipment, and with higher training speed and accuracy, the invention ensures the diversity and information richness of the generated images, improves the quality of the synthetic image set, and trains a high-accuracy lightweight model while preserving data privacy.
Description
Technical Field
The invention relates to a zero-shot knowledge distillation method and system, in particular to a zero-shot knowledge distillation method and system based on model inversion with an adversarial triplet loss, and belongs to the technical field of computer vision.
Background
Knowledge Distillation (KD) is a model compression method based on the teacher-student network paradigm, and it is widely used in industry because it is simple and effective. As a representative technique for model compression and acceleration, knowledge distillation can effectively learn a small "student model" from a large "teacher model", and it has quickly attracted attention from industry.
Currently, most extensions of knowledge distillation focus on compressing deep neural networks, and the resulting lightweight student networks can easily be deployed in applications such as visual recognition, speech recognition and Natural Language Processing (NLP). Furthermore, knowledge distillation, i.e. the transfer of knowledge from one model to another, can be extended to other tasks such as adversarial attacks, data augmentation, and data privacy and security. Motivated by the use of knowledge distillation for model compression, the idea of knowledge transfer has also been applied to compressing training data, i.e. dataset distillation, which transfers knowledge from a large dataset to a small one to reduce the training burden of deep models.
Model Inversion (MI) aims to reconstruct inputs from the parameters of a pre-trained model and was originally proposed for understanding the deep representations of neural networks. Given a function $\phi(x)$ that maps an input $x$ to its representation, the standard model inversion problem can be formulated as finding an input $x'$ that minimizes $d(\phi(x), \phi(x'))$, where $d(\cdot, \cdot)$ is an error function such as the mean squared error. This paradigm, known as the model inversion attack, is widely used in fields such as model security and interpretability. In recent years, inversion techniques have shown good results in knowledge transfer, enabling data-free knowledge distillation.
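To make the inversion objective concrete, here is a minimal sketch under deliberately simple assumptions: $\phi$ is a toy linear map $\phi(x) = Wx$ rather than a deep network, $d$ is the mean squared error, and $x'$ is found by plain gradient descent starting from zero. The helper names (`matvec`, `mse`, `invert`) are illustrative and do not come from the patent.

```python
# Toy model inversion (illustrative assumption): recover x' such that phi(x') matches phi(x),
# with phi(x) = W x standing in for a network and d = mean squared error.

def matvec(W, x):
    """phi(x) = W x for a matrix given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def mse(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)

def invert(W, target, steps=500, lr=0.1):
    """Gradient descent on x' for d(target, phi(x')), starting from x' = 0."""
    m, n = len(W), len(W[0])
    x = [0.0] * n
    for _ in range(steps):
        residual = [pi - ti for pi, ti in zip(matvec(W, x), target)]
        # gradient of (1/m) * ||W x' - target||^2 is (2/m) * W^T (W x' - target)
        grad = [2.0 / m * sum(W[i][j] * residual[i] for i in range(m)) for j in range(n)]
        x = [xj - lr * gj for xj, gj in zip(x, grad)]
    return x

W = [[1.0, 0.5], [0.0, 1.0], [1.0, -1.0]]
x_true = [0.3, -0.7]
x_rec = invert(W, matvec(W, x_true))  # only phi(x_true) is given, not x_true itself
print(x_rec)
```

Because this toy $W$ has full column rank, the minimizer is unique and gradient descent recovers the original input. For a real network the problem is typically ill-posed, which is why practical inversion methods add priors or regularizers to the objective.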
Early knowledge distillation frameworks typically contain one or more large pre-trained teacher models and a small student model, with the teacher usually much larger than the student. The main idea is to train an efficient student model under the teacher's guidance so that it reaches comparable accuracy. Supervision signals from the teacher model (often called the "knowledge" learned by the teacher) help the student model mimic the teacher's behavior.
Disclosure of Invention
Based on the data-privacy requirements of current deep learning model training services, the invention addresses two problems: traditional knowledge distillation methods are costly and inefficient and need access to the entire dataset, and existing model-inversion image generation methods consider only the generation loss of individual samples, so they cannot capture the differences between samples of different classes or the similarity among samples of the same class. To this end, the invention provides a zero-shot knowledge distillation method and system based on model inversion with an adversarial triplet loss, offering good generalization and high-quality generated data. The images generated by the method exhibit inter-class similarity and intra-class diversity in the feature space, and the trained student model extracts features that differ more between classes than within a class.
In order to achieve the above purpose, the invention adopts the following technical scheme.
A zero-shot knowledge distillation method based on an adversarial triplet loss comprises a pre-training stage, a model inversion stage and a model training stage.
Step 1: Pre-training.
First, the collected image training set is classified and labeled, and a suitable convolutional neural network model is selected. All images in the training set are fed in batches into the randomly initialized convolutional neural network, and the cross-entropy loss between the predicted values and the ground-truth labels is computed. The gradient of the loss with respect to each parameter of the network is then computed, and the model parameters are updated by stochastic gradient descent to obtain a trained teacher model.
Step 2: Model inversion.
First, the teacher model parameters obtained in the pre-training stage are frozen so that they are no longer updated. Then the generator and the student model are initialized with random parameters. The generator produces synthetic data under given conditions, and the student model is trained, by learning from the teacher model, to distinguish the different classes of the data produced by the generator. The generator tries to produce data that is as realistic as possible; correspondingly, the student model tries to distinguish the different classes perfectly and to match the output of the teacher model as closely as possible. The trained preliminary generation model is then used to generate and save the corresponding preliminary images.
Step 3: Model training.
The student model is trained, by learning from the teacher model, to distinguish the different classes of the data produced by the generator; it tries to distinguish the different classes perfectly and to match the output of the teacher model as closely as possible. The trained preliminary generation model is used to generate and save the corresponding preliminary images. The generated images are fed into both the teacher model and the student model, the teacher-student matching loss is computed, and the student model is trained and improved. Finally, the trained student model is exported and deployed.
Further, to achieve the object of the invention and implement the above method, the invention provides a zero-shot knowledge distillation system based on an adversarial triplet loss, which comprises a data synthesis module, a training module and a recognition module.
The data synthesis module synthesizes a dataset that protects user privacy: it computes the triplet loss and the synthesis loss from the trained model parameters provided by the user and generates the required data.
The training module takes the synthesized virtual data as input, computes the output matching loss between the teacher model provided by the user and the student model, and trains the student model with the triplet loss.
The recognition module deploys the trained student model to terminal devices for recognition.
The output of the data synthesis module is connected to the input of the training module, and the output of the training module is connected to the input of the recognition module.
Advantageous effects
Compared with the prior art, the invention has the following advantages:
1. Applied to data-free knowledge distillation, the invention generates harder, more confusing samples, ensures that the samples do not cluster in a narrow region of the feature space during generator training, and also helps the student model produce well-separated inter-class features during model training.
2. The invention provides a distance-weighted triplet sampling strategy that samples the triplet representation set more effectively from the student's feature space; instead of pulling positive samples as close together as possible, the strategy only encourages them to keep a certain distance from each other. This lowers the loss and makes it more robust. Furthermore, by using order-preserving regression, the margin-based loss focuses on relative order rather than absolute distance, which improves image generation efficiency and image quality.
3. The invention improves the accuracy of the student model while protecting data privacy. Unlike traditional knowledge distillation methods, it does not need access to the original training set, which matters because data privacy and security have always been major concerns in deep learning; for example, most biometric data is not shared publicly because it contains personal information. With this method, the user provides only the parameters of a pre-trained teacher model, and the student model is trained automatically from those parameters and the images produced by the generator.
4. Without requiring special equipment, and with higher training speed and accuracy, the invention ensures the diversity and information richness of the generated images and improves the quality of the synthetic image set, thereby training a high-accuracy lightweight model while preserving data privacy.
Drawings
FIG. 1 is a flow chart of the method of the present invention.
FIG. 2 is a schematic diagram of the model inversion and model training performed by the method of the present invention.
Detailed Description
To better illustrate the objects and advantages of the invention, it is described in detail below with reference to the accompanying drawings and examples.
Examples
As shown in FIG. 1, a zero-shot knowledge distillation method based on an adversarial triplet loss comprises a pre-training stage, a model inversion stage and a model training stage. The specific steps are as follows:
Step 1: Pre-training stage.
First, the collected image training set is classified and labeled, and a suitable convolutional neural network model is selected. All images in the training set are fed in batches into the randomly initialized convolutional neural network, and the cross-entropy loss between the predicted values and the ground-truth labels is computed. The gradient of the loss with respect to each parameter of the network is then computed, and the model parameters are updated by stochastic gradient descent to obtain a trained teacher model.
Specifically, step 1 comprises the steps of:
Step 1.1: Classify and label the images in the training dataset. The specific method is as follows:
All images in the training set are labeled. Each image is assigned a label from a predefined set of categories, and the image-label pairs $\{(x, y)\}_N$ are stored for subsequent training, where $x$ denotes an image, $y$ its label, and $N$ the number of samples.
Step 1.2: Draw batches from the training set for training and produce a preliminary teacher model. The specific method is as follows:
First, a batch of image-label pairs $\{(x, y)\}_n$ is randomly selected from the training dataset, the corresponding image data matrices are normalized, and the images are fed into the teacher model, where $n$ is the number of randomly drawn samples and $n < N$.
The teacher model then outputs the predicted probabilities $y'$ for the different classes. The number of predicted probabilities equals the total number of classes in the labeled training set, i.e. the output solution space is the set of integers in $[0, c-1]$, each integer representing one target class, where $c$ is the total number of target classes.
The class probability prediction $y'$ output by the model is then compared with the ground-truth class label $y$, the cross-entropy loss is computed, the loss is back-propagated, and the parameters of the preliminary model are updated.
This process is repeated until the set number of iterations is reached, and the network structure and model parameters are saved to obtain the trained model structure and parameters. Optionally, the cross-entropy loss to be minimized is:

$$\min_{\theta}\; \mathcal{L}_{CE} = -\sum_{i=0}^{c-1} y_i \log T_{\theta}(x)_i$$

where $\theta$ denotes the parameters of the teacher model $T$, whose output is a probability prediction for each class; $x \in \mathbb{R}^{B \times H \times W}$ denotes the input image data, with $B$, $H$ and $W$ the number of channels, height and width of the image, respectively; $y_i$ is the label for the $i$-th class, where $y_i = 1$ indicates that the sample is a target of that class and $y_i = 0$ indicates that it is not; $T_{\theta}(x)_i$ is the teacher model's predicted confidence for the $i$-th class; and $c$ is the total number of target classes.
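Step 1.2 can be sketched as below, assuming a linear softmax classifier on 2-D toy points in place of the convolutional teacher and omitting image normalization. Each iteration draws a random batch of $n < N$ pairs, computes the softmax probabilities, uses the standard softmax cross-entropy gradient $p_k - y_k$ with respect to each logit, and applies a stochastic gradient descent update. All names, data and hyper-parameters are hypothetical.

```python
import math
import random

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def logits(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def cross_entropy(probs, label):
    """-sum_i y_i log T(x)_i; with a one-hot y only the true-class term remains."""
    return -math.log(max(probs[label], 1e-12))

def train_teacher(data, c, iterations=300, n=4, lr=0.5, seed=0):
    """Mini-batch SGD on a linear softmax classifier (toy stand-in for the CNN teacher)."""
    rng = random.Random(seed)
    dim = len(data[0][0])
    W = [[rng.gauss(0.0, 0.01) for _ in range(dim)] for _ in range(c)]  # random init
    for _ in range(iterations):
        batch = rng.sample(data, n)  # n < N pairs drawn at random
        grad = [[0.0] * dim for _ in range(c)]
        for x, y in batch:
            p = softmax(logits(W, x))
            for k in range(c):
                g = p[k] - (1.0 if k == y else 0.0)  # d(CE)/d(logit_k)
                for j in range(dim):
                    grad[k][j] += g * x[j] / n
        W = [[w - lr * g for w, g in zip(Wk, gk)] for Wk, gk in zip(W, grad)]
    return W

def predict(W, x):
    """Predicted class index in [0, c-1]."""
    z = logits(W, x)
    return max(range(len(z)), key=z.__getitem__)

data = [([1.0, 0.1], 0), ([0.9, -0.1], 0), ([1.2, 0.0], 0),
        ([0.1, 1.0], 1), ([-0.1, 0.9], 1), ([0.0, 1.1], 1)]
W = train_teacher(data, c=2)
print([predict(W, x) for x, _ in data])
```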
In addition, different teacher models and loss functions can be chosen depending on the specific details of the model and the training dataset.
Step 2: Model inversion stage.
First, the parameters of the teacher model obtained in the pre-training stage are frozen so that they are no longer updated. Then the generator and the student model are initialized with random parameters. The generator produces synthetic data under given conditions, and the student model is trained, by learning from the teacher model, to distinguish the different classes of the data produced by the generator. The generator tries to produce data that is as realistic as possible; correspondingly, the student model tries to distinguish the different classes perfectly and to match the output of the teacher model as closely as possible. The trained preliminary generation model is then used to generate and save the corresponding preliminary images.
Specifically, step 2 comprises the steps of:
Step 2.1: Randomly initialize the noise input and sample feature triplets from the teacher model. The specific method is as follows:
After random initialization, a batch of noise $z$ is drawn and passed through the generator $G$ to obtain a synthetic image set $\{x_0\}_N$, where $x_0 = G(z)$. The corresponding image data matrices are normalized and the images are fed into the teacher model, which outputs the predicted probabilities for the different classes, $y_0 = \mathrm{softmax}(p_0)$, where $p_0$ are the logits output by the teacher model (the input vector of the softmax function in the model's output layer). The number of predicted probabilities equals the total number of classes in the labeled training set, i.e. the output solution space is the set of integers in $[0, c-1]$, each integer representing one target class, where $c$ is the total number of target classes.
From the class probability prediction $y_0$ output by the model, the (pseudo-)class label $y$ of each generated image is determined, and the distance $d$ in the $p_0$ space is computed between pairs of samples with different labels $y$, one serving as the anchor sample and the other as a negative sample with a different label.
For a given anchor sample, the feature-distance function $f(d)$ of each of its negative samples is then computed from this distance:
where c is the total number of target classes.
Then, from these feature distances, the probability that each negative sample is sampled for the given anchor sample is computed:
where $\lambda$ is a clipping hyper-parameter on the feature distance that avoids sampling noisy samples.
Triplets are sampled according to the computed probabilities to obtain the triplet set; the positive samples, which share the anchor's predicted label, are drawn with uniform probability.
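The text does not spell out $f(d)$ or the sampling probability. Its ingredients (a distance function that depends on the number of classes $c$, which is the dimension of the logit space, and a clipping hyper-parameter $\lambda$) appear consistent with the distance-weighted sampling of Wu et al., "Sampling Matters in Deep Embedding Learning" (2017), so the sketch below assumes that form: logits are L2-normalized, $f(d) \propto d^{c-2}(1 - d^2/4)^{(c-3)/2}$, and a negative is drawn with probability proportional to $\min(\lambda, f(d)^{-1})$. Positives are drawn uniformly among samples with the same argmax label. All function names are hypothetical.

```python
import math
import random

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def distance(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def log_inverse_density(d, dim, eps=1e-3):
    """-log f(d), assuming f(d) = d^(dim-2) * (1 - d^2/4)^((dim-3)/2) (unit-sphere distance density)."""
    d = min(max(d, eps), 2.0 - eps)  # unit vectors give 0 <= d <= 2; clamp so both logs stay finite
    return -((dim - 2) * math.log(d) + 0.5 * (dim - 3) * math.log(1.0 - 0.25 * d * d))

def negative_probs(anchor, negatives, lam):
    """Pr(n | a) proportional to min(lam, 1 / f(d(a, n))), computed in log space to avoid overflow."""
    dim = len(anchor)
    log_w = [min(math.log(lam), log_inverse_density(distance(anchor, n), dim))
             for n in negatives]
    top = max(log_w)
    w = [math.exp(v - top) for v in log_w]
    total = sum(w)
    return [v / total for v in w]

def sample_triplets(logits, lam=1.0, rng=random):
    """One (anchor, positive, negative) index triplet per anchor that has both."""
    labels = [max(range(len(z)), key=z.__getitem__) for z in logits]  # argmax softmax(p0) = argmax p0
    feats = [l2_normalize(z) for z in logits]  # assumption: distances measured on normalized logits
    triplets = []
    for a, ya in enumerate(labels):
        pos = [i for i, y in enumerate(labels) if y == ya and i != a]
        neg = [i for i, y in enumerate(labels) if y != ya]
        if not pos or not neg:
            continue
        p = rng.choice(pos)  # positives: uniform over same-label samples
        probs = negative_probs(feats[a], [feats[i] for i in neg], lam)
        n = rng.choices(neg, weights=probs, k=1)[0]  # negatives: distance-weighted
        triplets.append((a, p, n))
    return triplets

logits = [[3.0, 0.1, 0.0], [2.5, 0.3, 0.2], [0.1, 3.0, 0.0], [0.0, 2.0, 0.5]]
print(sample_triplets(logits, lam=1.0, rng=random.Random(0)))
```

The clipping by $\lambda$ caps the weight of very close (possibly noisy) negatives, while the inverse density spreads the selected negatives across the whole range of distances instead of concentrating them where distances are most common.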
Step 2.2: Compute the triplet loss of the generated images from the triplets and update the generator's parameters.
The specific method is as follows:
A network model is built from the model structure and parameters saved in the training stage; the generated image data matrices are normalized and fed into the student model to obtain the global feature $e$ of each sample, $e = \mathrm{GlobalPooling}(h)$, where $h$ is an intermediate-layer feature of the student model and GlobalPooling is a global pooling operation.
From the triplet set obtained by distance-based sampling, the corresponding triplet set of global features is obtained; the triplet loss is computed, and the parameters of the generator $G$ and the random noise $z$ are updated. The triplet loss to be minimized is:
where, for each triplet, the loss uses the global features in the student model of the anchor sample, the corresponding positive sample and the corresponding negative sample; $\tau$ is the margin between the positive and negative pairs of all possible triplets within a batch, representing the expected minimum gap between intra-class and inter-class feature distances; and $\theta_G$ denotes the parameters of the generator $G$.
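The pooling and generator-side loss of step 2.2 might look like the sketch below. Assumptions: global pooling is global average pooling over a [C][H][W] feature map, distances are squared Euclidean, and the generator's hinge is the usual triplet loss with the sign of the distance difference flipped. That last choice is one reading of the adversarial behaviour described for the generator in step 3.2 (pushing same-label pairs apart and pulling different-label pairs together). Names are hypothetical.

```python
def global_avg_pool(h):
    """e = GlobalPooling(h), assumed to be average pooling: one mean per channel of a [C][H][W] map."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in h]

def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def generator_triplet_loss(triplets, feats, tau):
    """Mean of max(0, ||e_a - e_n||^2 - ||e_a - e_p||^2 + tau) over (a, p, n) index triplets.

    The sign is flipped relative to the usual triplet loss, so minimizing it
    pushes same-label pairs apart and pulls different-label pairs together.
    """
    losses = [max(0.0, sq_dist(feats[a], feats[n]) - sq_dist(feats[a], feats[p]) + tau)
              for a, p, n in triplets]
    return sum(losses) / len(losses)

feats = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
print(generator_triplet_loss([(0, 1, 2)], feats, tau=0.5))  # 9 - 1 + 0.5 = 8.5
```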
This process is repeated until the set number of iterations is reached; the network structure and model parameters are saved to obtain the trained generation model, and the generated images with the smallest generation loss are stored. The generation loss to be minimized is:
where $y_t$ is the probability distribution predicted by the teacher model and $y_s$ the probability distribution predicted by the student model; $\mu_j(x)$ and $\sigma_j^2(x)$ are the mean and variance of feature layer $j$ of the teacher model for sample $x$; $\mu_j$ and $\sigma_j^2$ are the running mean and variance stored by the batch normalization layer of feature layer $j$ in the teacher model; $L$ is the number of model layers; the batch normalization statistics are computed by a corresponding statistical function; and $D$ is a function that computes feature distance.
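Only part of the generation loss can be inferred from the symbol list. The batch-normalization symbols suggest a term that matches the batch statistics of the synthetic images to the running statistics stored in the teacher's BN layers, as in DeepInversion-style methods. The sketch below implements only that term, assuming each layer's activations are already reduced to a [batch][channel] list and using an L2 distance; the terms involving $y_t$, $y_s$ and $D$ are omitted because their form is not given. Names are hypothetical.

```python
import math

def batch_stats(acts):
    """Per-channel mean and (biased) variance of activations given as [batch][channel]."""
    n, channels = len(acts), len(acts[0])
    mean = [sum(a[c] for a in acts) / n for c in range(channels)]
    var = [sum((a[c] - mean[c]) ** 2 for a in acts) / n for c in range(channels)]
    return mean, var

def l2(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def bn_statistics_loss(layer_acts, running_means, running_vars):
    """Sum over layers j of ||mu_j(x) - mu_j||_2 + ||sigma_j^2(x) - sigma_j^2||_2 (assumed form)."""
    loss = 0.0
    for acts, mu, var in zip(layer_acts, running_means, running_vars):
        batch_mu, batch_var = batch_stats(acts)
        loss += l2(batch_mu, mu) + l2(batch_var, var)
    return loss

acts = [[1.0, 2.0], [3.0, 4.0]]  # one layer, batch of 2, two channels
print(bn_statistics_loss([acts], [[0.0, 3.0]], [[1.0, 1.0]]))
```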
In addition, different generation loss functions can be chosen depending on the specific details of the model and the training dataset.
Step 3: Model training stage.
In the model training stage, the student model is trained, by learning from the teacher model, to distinguish the different classes of the data produced by the generator; it tries to distinguish the different classes perfectly and to match the output of the teacher model as closely as possible. The trained preliminary generation model is used to generate and save the corresponding preliminary images. The generated images are fed into both the teacher model and the student model, the teacher-student matching loss is computed, and the student model is trained and improved. Finally, the trained student model is exported and deployed.
Specifically, step 3 includes the steps of:
Step 3.1: Randomly initialize the student model and sample feature triplets from the teacher model. The specific method is as follows:
After random initialization, a batch of noise $z$ is drawn and passed through the trained generator $G$ to obtain a synthetic training image set $\{x_0\}_N$, where $x_0 = G(z)$. The corresponding image data matrices are normalized and the images are fed into the teacher model, which outputs the predicted probabilities for the different classes, $y_0 = \mathrm{softmax}(p_0)$, where $p_0$ are the logits output by the teacher model.
From the class probability prediction $y_0$ output by the model, the class label $y$ of each generated image is determined, and the distance $d$ in the $p_0$ space is computed between pairs of samples with different labels $y$, one serving as the anchor sample and the other as a negative sample with a different label. For a given anchor sample, the feature-distance function of each of its negative samples is then computed from this distance:
where $c$ is the total number of target classes.
Then, from these feature distances, the probability that each negative sample is sampled for the given anchor sample is computed:
where $\lambda$ is a clipping hyper-parameter on the feature distance that avoids sampling noisy samples.
Triplets are sampled according to the computed probabilities to obtain the triplet set; the positive samples, which share the anchor's predicted label, are drawn with uniform probability.
Step 3.2: Compute the triplet loss of the generated images from the triplets and update the student model's parameters. The specific method is as follows:
A network model is built from the model structure and parameters saved in the training stage; the generated image data matrices are normalized and fed into the student model to obtain the global feature $e$ of each sample, $e = \mathrm{GlobalPooling}(h)$, where $h$ is an intermediate-layer feature of the student model.
From the triplet set obtained by distance-based sampling, the corresponding triplet set of global features is obtained; the triplet loss is computed and the parameters of the student model $S$ are updated.
The triplet loss to be minimized is:
where, for each triplet, the loss uses the global features in the student model of the anchor sample, the corresponding positive sample and the corresponding negative sample; $\theta_S$ denotes the parameters of the student model; and $\tau$ is the margin between the positive and negative pairs of all possible triplets within a batch, representing the expected minimum gap between intra-class and inter-class feature distances.
This step computes the loss in the same way as step 2, but in the opposite direction. When training the generator, the adversarial triplet loss pushes apart positive pairs that share a pseudo-label, ensuring diversity among same-class samples, and pulls together negative pairs with different pseudo-labels, ensuring that the samples lie near the decision boundary. When training the student model, by contrast, positive pairs are pulled together, which enables the student to produce class-discriminative features for these samples and improves classification performance.
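To see the student-side direction concretely, the sketch below evaluates the usual triplet hinge $\max(0, \lVert e_a - e_p \rVert^2 - \lVert e_a - e_n \rVert^2 + \tau)$ on one triplet and takes one gradient step. As a simplifying assumption it updates the feature vectors directly instead of the student's parameters $\theta_S$, and it uses squared Euclidean distance; the symbols $e_a$, $e_p$, $e_n$ and all function names are introduced here for illustration only.

```python
def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def student_triplet_loss(a, p, n, tau):
    """max(0, ||a - p||^2 - ||a - n||^2 + tau): the usual (non-adversarial) direction."""
    return max(0.0, sq_dist(a, p) - sq_dist(a, n) + tau)

def sgd_step_on_features(a, p, n, tau, lr):
    """One gradient step applied directly to the three feature vectors (illustration only)."""
    if student_triplet_loss(a, p, n, tau) == 0.0:
        return a, p, n  # hinge inactive: zero gradient
    grad_a = [2 * (ni - pi) for pi, ni in zip(p, n)]   # d/da = 2(a - p) - 2(a - n)
    grad_p = [-2 * (ai - pi) for ai, pi in zip(a, p)]  # d/dp = -2(a - p)
    grad_n = [2 * (ai - ni) for ai, ni in zip(a, n)]   # d/dn = 2(a - n)

    def step(v, g):
        return [vi - lr * gi for vi, gi in zip(v, g)]

    return step(a, grad_a), step(p, grad_p), step(n, grad_n)

a, p, n = [0.0, 0.0], [2.0, 0.0], [1.0, 0.0]
print(student_triplet_loss(a, p, n, tau=0.5))  # positive farther than negative: 3.5
a2, p2, n2 = sgd_step_on_features(a, p, n, tau=0.5, lr=0.1)
print(student_triplet_loss(a2, p2, n2, tau=0.5))
```

In this example the anchor and the positive move toward each other, so the positive-pair distance shrinks and the loss drops; an adversarial generator-side version that flips the sign of the distance difference would push in the opposite direction.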
This process is repeated until the set number of iterations is reached; the network structure and model parameters are saved after training, and the model with the highest accuracy on the validation set is kept. The teacher-student matching loss to be minimized is:
where $\theta_S$ denotes the parameters of the student model, $y_t$ is the probability distribution predicted by the teacher model, $y_s$ is the probability distribution predicted by the student model, and $e_T$ and $e_S$ are the global features of the corresponding samples in the teacher model and the student model, respectively.
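The text lists the ingredients of the matching loss ($y_t$, $y_s$, $e_T$, $e_S$) but not how they are combined. One plausible form, assumed here and not confirmed by the source, is the KL divergence $\mathrm{KL}(y_t \parallel y_s)$ plus a weighted mean squared gap between the global features. The weight `beta` is hypothetical, and the sketch assumes $e_T$ and $e_S$ have the same length (otherwise a projection would be needed).

```python
import math

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) = sum_i p_i log(p_i / q_i), skipping zero-probability terms of p."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def matching_loss(teacher_logits, student_logits, e_t, e_s, beta=1.0):
    """Hypothetical matching loss: KL(y_t || y_s) + beta * mean squared gap between e_T and e_S."""
    y_t, y_s = softmax(teacher_logits), softmax(student_logits)
    feature_gap = sum((a - b) ** 2 for a, b in zip(e_t, e_s)) / len(e_t)
    return kl_divergence(y_t, y_s) + beta * feature_gap

print(matching_loss([2.0, 0.5, -1.0], [0.0, 0.0, 0.0], [0.3, 0.7], [0.3, 0.7]))
```

The loss is zero only when the student reproduces both the teacher's class distribution and its global features, which is the sense in which the student "matches the output of the teacher model".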
In addition, different matching loss functions can be chosen depending on the specific details of the model and the training dataset.
Further, to achieve the object of the invention and implement the above method, the invention provides a zero-shot knowledge distillation system based on an adversarial triplet loss, which comprises a data synthesis module, a training module and a recognition module.
The data synthesis module synthesizes a dataset that protects user privacy: it computes the triplet loss and the synthesis loss from the trained model parameters provided by the user and generates the required data.
The training module takes the synthesized virtual data as input, computes the output matching loss between the teacher model provided by the user and the student model, and trains the student model with the triplet loss.
The recognition module deploys the trained student model to terminal devices for recognition.
The output of the data synthesis module is connected to the input of the training module, and the output of the training module is connected to the input of the recognition module.
The above detailed description illustrates the objects, technical solutions and advantages of the invention. It should be understood that it is merely exemplary and is not intended to limit the scope of the invention; any modification, equivalent replacement, or improvement made within the spirit and principles of the invention shall fall within its scope of protection.
Claims (8)
1. A zero-shot knowledge distillation method based on adversarial triplet loss, comprising the following steps:
step 1: pre-training;
first, the collected image training set is classified and labeled, and a suitable convolutional neural network model is selected; all images in the training set are fed in batches into the randomly initialized convolutional neural network, and the cross-entropy loss between the predictions and the true labels is computed; the gradient of the loss with respect to each parameter of the network is then computed, and the model parameters are updated by stochastic gradient descent to obtain the trained teacher model;
step 2: model inversion;
first, the teacher model parameters obtained in the pre-training stage are frozen so that they are no longer updated; second, a generator and a student model are initialized with random parameters; the generator produces synthetic data under given conditions, and the student model is trained, by learning from the teacher model, to distinguish the classes of the data produced by the generator; the generator tries to produce data that is as realistic as possible, while the student model tries to separate the different classes perfectly and to match the output of the teacher model as closely as possible; finally, the trained preliminary generation model is used to generate and store the corresponding preliminary generated images;
step 3: model training;
the student model is trained, by learning from the teacher model, to distinguish the classes of the data produced by the generator, trying to separate the different classes perfectly while matching the output of the teacher model as closely as possible; the trained preliminary generation model is used to generate and store the corresponding preliminary generated images; the generated images are fed into the teacher model and the student model, and the teacher-student matching loss is computed to train and improve the student model; finally, the trained student model is exported and deployed.
2. The zero-shot knowledge distillation method based on adversarial triplet loss according to claim 1, wherein step 1 comprises the following steps:
step 1.1: classifying and labeling the images in the training data set;
all images in the training set are labeled: each image is assigned a label from a predefined set of classes, and the image-label pairs $\{(x, y)\}^N$ are stored for subsequent training, where x denotes an image, y its class label and N the number of samples;
step 1.2: train on mini-batches drawn from the training set to obtain a preliminary teacher model;
first, a batch of image-label pairs $\{(x, y)\}^n$ is randomly selected from the training data set, the corresponding image data matrices are normalized, and the images are fed into the teacher model, where n is the number of randomly drawn samples and $n < N$;
next, the teacher model outputs predicted probabilities y' for the different classes; the number of predicted probabilities equals the total number of classes in the labeled training set, i.e., the output space is the set of integers in $[0, c-1]$, each integer representing one target class, where c is the total number of target classes;
then, the class-probability prediction y' output by the model is compared with the true class label y to compute the cross-entropy loss, which is back-propagated to update the parameters of the preliminary model;
this process is repeated until the set number of iterations is reached, and the network structure and model parameters are saved, giving the structure and parameters of the trained preliminary model.
3. The zero-shot knowledge distillation method based on adversarial triplet loss according to claim 2, wherein in step 1.2 the cross-entropy loss to be minimized is:

$$\min_{\theta} \; \mathcal{L}_{CE} = -\sum_{i=1}^{c} y_i \log T_\theta(x)_i$$
where θ denotes the parameters of the teacher model T, whose output is a probability prediction for each class; $x \in \mathbb{R}^{B \times H \times W}$ denotes the input image data, with B, H and W the number of channels, the height and the width of the image, respectively; $y_i$ is the label entry for the i-th class, where $y_i = 1$ indicates that the sample is a target of class i and 0 indicates that it is not; $T_\theta(x)_i$ is the teacher model's predicted confidence for the i-th target class; c is the total number of target classes.
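For a one-hot label, the cross-entropy of claim 3 reduces to the negative log of the teacher's probability for the true class, and its gradient with respect to the logits is the softmax output minus the one-hot vector; that gradient is what back-propagation feeds into the stochastic gradient descent update of step 1. The plain-Python sketch below assumes that $T_\theta(x)$ is a softmax over logits and that the mini-batch loss is the mean over the n samples; the function names are illustrative only.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    # With one-hot y, -sum_i y_i * log T(x)_i reduces to -log of the true-class probability.
    return -math.log(softmax(logits)[label])


def cross_entropy_grad(logits: Sequence[float], label: int) -> List[float]:
    # Gradient of the loss w.r.t. the logits: softmax(logits) - one_hot(label).
    probs = softmax(logits)
    return [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]


def batch_loss(batch_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    # Mean loss over a mini-batch of n samples (averaging is an assumption).
    return sum(cross_entropy(z, y) for z, y in zip(batch_logits, labels)) / len(labels)


print(round(cross_entropy([0.0, 0.0], 0), 4))  # prints 0.6931, i.e. log 2
```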
4. The method of claim 1, wherein step 2 comprises the steps of:
step 2.1: randomly initialize the noise input and sample feature triplets with the teacher model;
with random initialization, a batch of noise z is selected and passed through the generator G to obtain a synthetic image set $\{x_0\}^N$, where $x_0 = G(z)$; the corresponding image data matrices are normalized and the images are fed into the teacher model; the teacher model outputs predicted probabilities $y_0 = \mathrm{softmax}(p_0)$ for the different classes, where $p_0$ are the logits output by the teacher model; the number of predicted probabilities equals the total number of classes in the labeled training set, i.e., the output space is the set of integers in $[0, c-1]$, where c is the total number of target classes;
according to the class-probability prediction $y_0$ output by the model, the class label y of each generated image is determined, and the distance d in $p_0$ space between two samples with different class labels y is computed, using the feature of the anchor sample and the feature of a negative sample with a different label;
for a given anchor sample, the distance function f(d) of each negative sample is computed from the feature distance:
where c is the total number of target classes;
then, for a given anchor sample, the probability that each negative sample is sampled is computed from the feature distances:
where λ is a clipping hyperparameter on the feature distance, used to avoid sampling noisy samples;
triplets are sampled according to the computed probabilities to obtain a triplet set, in which positive samples (samples with the same predicted label as the anchor) are drawn with uniform probability;
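The formulas for f(d) and for the sampling probability are not reproduced in the text. The description (a distance function parameterized by the number of classes c, plus a clipping hyperparameter λ against noisy samples) matches the distance-weighted sampling of Wu et al., "Sampling Matters in Deep Embedding Learning" (2017), so the sketch below assumes that scheme: features are the L2-normalized logits $p_0$, $f(d) = d^{c-2}(1 - d^2/4)^{(c-3)/2}$, and a negative is drawn with probability proportional to $\min(\lambda, f(d)^{-1})$. The Euclidean distance, the normalization and all function names are assumptions, not the patent's stated method.

```python
import math
import random
from typing import List, Sequence, Tuple


def l2_normalize(v: Sequence[float]) -> List[float]:
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def inverse_density_weight(d: float, c: int, lam: float, eps: float = 1e-6) -> float:
    """min(lam, 1 / f(d)) with f(d) = d^(c-2) * (1 - d^2/4)^((c-3)/2), evaluated in log space."""
    d = min(max(d, eps), 2.0 - eps)  # distances between unit vectors lie in [0, 2]
    log_f = (c - 2) * math.log(d) + 0.5 * (c - 3) * math.log(1.0 - d * d / 4.0)
    # Clip at lam so rare (likely noisy) distances cannot dominate the sampling.
    return lam if -log_f >= math.log(lam) else math.exp(-log_f)


def sample_triplets(logits: Sequence[Sequence[float]], lam: float,
                    rng: random.Random) -> List[Tuple[int, int, int]]:
    c = len(logits[0])  # dimension of the logit space = number of classes
    feats = [l2_normalize(p) for p in logits]
    # Pseudo-label = argmax of softmax(p0), which equals argmax of the logits.
    labels = [max(range(c), key=p.__getitem__) for p in logits]
    triplets = []
    for a in range(len(logits)):
        pos = [i for i in range(len(logits)) if i != a and labels[i] == labels[a]]
        neg = [i for i in range(len(logits)) if labels[i] != labels[a]]
        if not pos or not neg:
            continue  # anchor cannot form a triplet in this batch
        p = rng.choice(pos)  # positives: uniform sampling
        weights = [inverse_density_weight(euclidean(feats[a], feats[j]), c, lam) for j in neg]
        n = rng.choices(neg, weights=weights, k=1)[0]  # negatives: distance-weighted
        triplets.append((a, p, n))
    return triplets
```

Weighting by $1/f(d)$ offsets the concentration of pairwise distances in high-dimensional spaces, so negatives across the whole distance range get drawn, while the cap λ limits how much any rarely occurring, possibly noisy distance can dominate.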
step 2.2: compute the triplet loss of the generated images from the triplets and update the model parameters of the generator;
a network model is created from the preliminary model structure and parameters saved in the training stage; the generated image data matrices are normalized and fed into the student model to obtain the global feature e of each sample, where e = GlobalPooling(h), h is an intermediate-layer feature of the student model, and GlobalPooling is the global pooling operation;
from the triplet set obtained by feature-distance sampling, the corresponding triplet set of global features is obtained, the triplet loss is computed, and the parameters of the generator G and the random noise z are updated; the triplet loss to be minimized is:
where the three features in the loss are the global features, in the student model, of the anchor sample, the corresponding positive sample and the corresponding negative sample; τ is the margin between positive and negative pairs over all possible triplets within a batch, representing the expected minimum gap between intra-class and inter-class feature distances; $\theta_G$ denotes the parameters of the generator G;
this process is repeated until the set number of iterations is reached; the network structure and model parameters are saved to obtain the structure and parameters of the trained preliminary model, and the generated images with the smallest generation loss are stored.
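The triplet-loss formula itself is missing from the text. A common choice consistent with the listed symbols is the margin-based triplet loss averaged over the batch, $\max(0,\, d(e^a, e^p) - d(e^a, e^n) + \tau)$, applied to globally average-pooled student features. The sketch below assumes that form with Euclidean distance; the symbols $e^a, e^p, e^n$ and the function names are illustrative. It computes only the loss value: updating $\theta_G$ and z here (or $\theta_S$ in step 3.2, whose loss has the same listed ingredients) would require a framework with automatic differentiation.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def global_avg_pool(h: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """e = GlobalPooling(h) for a C x H x W feature map: average each channel over H x W."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in h]


def euclidean(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(triplets: Sequence[Tuple[Vector, Vector, Vector]], tau: float) -> float:
    """Mean over the batch of max(0, d(anchor, positive) - d(anchor, negative) + tau)."""
    hinge = [max(0.0, euclidean(a, p) - euclidean(a, n) + tau) for a, p, n in triplets]
    return sum(hinge) / len(hinge)


e = global_avg_pool([[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]])
print(e)  # [2.5, 0.0]
```

A triplet whose negative is already farther from the anchor than the positive by at least τ contributes zero, so only violating triplets drive the update.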
5. The zero-shot knowledge distillation method based on adversarial triplet loss according to claim 4, wherein in step 2.2 the generation loss to be minimized is:
where $y_t$ is the probability distribution predicted by the teacher model and $y_s$ is the probability distribution predicted by the student model; $\mu_j(x)$ and $\sigma_j^2(x)$ are the mean and variance of feature layer j of the teacher model for sample x; $\mu_j$ and $\sigma_j^2$ are the running mean and variance stored by the batch-normalization layer of feature layer j of the teacher model; L is the number of model layers; the batch-normalization statistics function provides these per-layer statistics; D is a function that computes the feature distance.
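The generation-loss formula is also missing. Its listed ingredients (per-layer batch means and variances compared with the teacher's stored batch-normalization statistics, plus a distance D between $y_t$ and $y_s$) resemble the feature-statistics regularizer of DeepInversion (Yin et al., 2020) combined with an adversarial teacher-student discrepancy term. The sketch below assumes an L2 distance on the statistics, KL divergence as a stand-in for D, and a minus sign on the discrepancy so that minimizing the loss pushes the generator toward samples on which the student still disagrees with the teacher. All three choices, and the weight `beta`, are assumptions.

```python
import math
from typing import List, Sequence, Tuple


def batch_stats(feats: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-channel mean and biased variance of one layer; rows are samples, columns are channels."""
    n, ch = len(feats), len(feats[0])
    mean = [sum(f[k] for f in feats) / n for k in range(ch)]
    var = [sum((f[k] - mean[k]) ** 2 for f in feats) / n for k in range(ch)]
    return mean, var


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def bn_statistics_loss(layers, running_means, running_vars) -> float:
    """Sum over layers j of ||mu_j(x) - mu_j||_2 + ||sigma_j^2(x) - sigma_j^2||_2."""
    total = 0.0
    for feats, mu, var in zip(layers, running_means, running_vars):
        m, v = batch_stats(feats)
        total += l2(m, mu) + l2(v, var)
    return total


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    # Stand-in for the distance D between y_t and y_s (assumption).
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def generation_loss(layers, running_means, running_vars, y_t, y_s, beta: float = 1.0) -> float:
    # Assumed adversarial term: subtracting D(y_t, y_s) rewards the generator
    # for samples on which teacher and student disagree.
    return bn_statistics_loss(layers, running_means, running_vars) - beta * kl_divergence(y_t, y_s)
```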
6. The method of claim 1, wherein step 3 comprises the steps of:
step 3.1: randomly initialize the student model and sample feature triplets with the teacher model;
with random initialization, a batch of noise z is selected and passed through the trained generator G to obtain a synthetic training image set $\{x_0\}^N$, where $x_0 = G(z)$; the corresponding image data matrices are normalized and the images are fed into the teacher model; the teacher model outputs predicted probabilities $y_0 = \mathrm{Softmax}(p_0)$ for the different classes, where $p_0$ are the logits output by the teacher model;
according to the class-probability prediction $y_0$ output by the model, the class label y of each generated image is determined, and the distance d in $p_0$ space between two samples with different class labels y is computed, using the feature of the anchor sample and the feature of a negative sample with a different label; for a given anchor sample, the distance function of each negative sample is then computed from the feature distance:
where c is the total number of target classes;
then, for a given anchor sample, the probability that each negative sample is sampled is computed from the feature distances:
where λ is a clipping hyperparameter on the feature distance, used to avoid sampling noisy samples;
triplets are sampled according to the computed probabilities to obtain a triplet set, in which positive samples (samples with the same predicted label as the anchor) are drawn with uniform probability;
step 3.2: compute the triplet loss of the generated images from the triplets and update the parameters of the student model, as follows:
a network model is created from the preliminary model structure and parameters saved in the training stage; the generated image data matrices are normalized and fed into the student model to obtain the global feature e of each sample, where e = GlobalPooling(h) and h is an intermediate-layer feature of the student model;
from the triplet set obtained by feature-distance sampling, the corresponding triplet set of global features is obtained, the triplet loss is computed, and the parameters of the student model S are updated; the triplet loss to be minimized is:
where the three features in the loss are the global features, in the student model, of the anchor sample, the corresponding positive sample and the corresponding negative sample; $\theta_S$ denotes the parameters of the student model; τ is the margin between positive and negative pairs over all possible triplets within a batch, representing the expected minimum gap between intra-class and inter-class feature distances;
this process is repeated until the set number of iterations is reached; the network structure and model parameters are saved to obtain the structure and parameters of the trained model, and the model with the highest accuracy on the validation set is kept.
7. The zero-shot knowledge distillation method based on adversarial triplet loss according to claim 6, wherein in step 3.2 the teacher-student matching loss to be minimized is:
where $\theta_S$ denotes the parameters of the student model; $y_t$ is the probability distribution predicted by the teacher model and $y_s$ is the probability distribution predicted by the student model; $e_T$ and $e_S$ are the global features of the corresponding sample in the teacher model and the student model, respectively; L is the number of model layers; D is a function that computes the feature distance.
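A plausible reading of the teacher-student matching loss, given only its listed symbols, is an output-matching term plus a feature-matching term over the L layers. The sketch below assumes KL divergence $\mathrm{KL}(y_t \,\|\, y_s)$ for the outputs and squared Euclidean distance between $e_T$ and $e_S$ summed over matched layers, with a hypothetical weight `alpha`. It also assumes that matched teacher and student features share a dimension, which the text does not state.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    # KL(p || q); eps guards against log(0) for zero-probability classes.
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def matching_loss(y_t, y_s, e_T_layers, e_S_layers, alpha: float = 1.0) -> float:
    """KL(y_t || y_s) + alpha * sum over matched layers of ||e_T^l - e_S^l||^2 (assumed form)."""
    output_term = kl_divergence(y_t, y_s)
    feature_term = sum(sq_dist(t, s) for t, s in zip(e_T_layers, e_S_layers))
    return output_term + alpha * feature_term
```

Using KL with the teacher distribution first penalizes the student most where the teacher is confident, which is the usual direction in distillation.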
8. A zero-shot knowledge distillation system based on adversarial triplet loss, comprising a data synthesis module, a training module and a recognition module;
the data synthesis module synthesizes a data set that protects user privacy, computing the triplet loss and the synthesis loss from the trained model parameters provided by the user to generate the required data;
the training module takes the synthesized virtual data as input and, based on the teacher model provided by the user, computes the output matching loss between the teacher model and the student model and trains the student model;
the recognition module deploys the trained student model to a terminal device for recognition;
the output end of the data synthesis module is connected with the input end of the training module, and the output end of the training module is connected with the input end of the recognition module.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210401592.2A CN114972904B (en) | 2022-04-18 | 2022-04-18 | Zero sample knowledge distillation method and system based on fighting against triplet loss |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114972904A true CN114972904A (en) | 2022-08-30 |
CN114972904B CN114972904B (en) | 2024-05-31 |
Family
ID=82976732
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116362351A (en) * | 2023-05-29 | 2023-06-30 | 深圳须弥云图空间科技有限公司 | Method and device for training pre-training language model by using noise disturbance |
CN117009830A (en) * | 2023-10-07 | 2023-11-07 | 之江实验室 | Knowledge distillation method and system based on embedded feature regularization |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113610173A (en) * | 2021-08-13 | 2021-11-05 | 天津大学 | Knowledge distillation-based multi-span domain few-sample classification method |
WO2021248868A1 (en) * | 2020-09-02 | 2021-12-16 | 之江实验室 | Knowledge distillation-based compression method for pre-trained language model, and platform |
CN114170332A (en) * | 2021-11-27 | 2022-03-11 | 北京工业大学 | Image recognition model compression method based on anti-distillation technology |
WO2022051856A1 (en) * | 2020-09-09 | 2022-03-17 | Huawei Technologies Co., Ltd. | Method and system for training a neural network model using adversarial learning and knowledge distillation |
CN114241282A (en) * | 2021-11-04 | 2022-03-25 | 河南工业大学 | Knowledge distillation-based edge equipment scene identification method and device |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |