CN113610069A - Knowledge distillation-based target detection model training method - Google Patents
- Publication number: CN113610069A (application CN202111179182.XA)
- Authority: CN (China)
- Prior art keywords: target detection, pixel position, detection frame, label, model
- Legal status: Granted (the legal status is an assumption and is not a legal conclusion; Google has not performed a legal analysis)
Classifications
- G06F18/214 (G—Physics; G06F—Electric digital data processing; Pattern recognition) — Generating training patterns; bootstrap methods, e.g. bagging or boosting
- G06F18/2415 — Classification techniques based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus false rejection rate
- G06N20/00 (G06N—Computing arrangements based on specific computational models) — Machine learning
Abstract
The invention provides a knowledge distillation-based target detection model training method, comprising the following steps: training a target detection teacher model using a training sample image set, where each training sample image has three labels — a first label, the hard-label probability matrix of the pixel position of the target detection frame's center point; a second label, the width and height of the target detection frame; and a third label, the pixel-position offset of the target detection frame's center point. The prediction output of the teacher model comprises the probability heat map of the center-point pixel position, the width and height of the target detection frame, and the center-point pixel-position offset. After the loss function of the target detection student model is improved by knowledge distillation, the student model is generated by training. The invention solves the problem that a target detection model trained with existing knowledge distillation methods cannot simultaneously keep its network structure simple enough for terminal devices and keep its recognition rate high enough to guarantee detection precision.
Description
Technical Field
The invention relates to the technical field of artificial intelligence model training, in particular to a knowledge distillation-based target detection model training method.
Background
Knowledge distillation guides the training of a student model's network structure by introducing a teacher model's network structure, thereby realizing knowledge transfer. Concretely, a teacher model is trained first, and the student model is then trained using both the teacher model's outputs and the ground-truth labels of the data. The knowledge of the teacher network is thus transferred to the student network, so that the student network can approach the teacher network's performance while keeping its structure as small as possible with fewer parameters — which reduces the computing-power requirement for deploying the model and improves the model's inference efficiency.
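As a minimal, self-contained sketch of this idea (a generic classification distillation, not this patent's detection losses; the temperature T and mixing weight alpha are assumptions):

```python
import numpy as np

def softmax(z, T=1.0):
    """Temperature-scaled softmax over a 1-D logit vector."""
    z = np.asarray(z, dtype=float) / T
    e = np.exp(z - z.max())
    return e / e.sum()

def distillation_loss(student_logits, teacher_logits, true_label, T=2.0, alpha=0.5):
    """Weighted sum of the ground-truth cross-entropy and the cross-entropy
    against the teacher's temperature-softened output (soft targets)."""
    p_student = softmax(student_logits)        # for the hard-label term
    p_student_T = softmax(student_logits, T)   # softened student output
    p_teacher_T = softmax(teacher_logits, T)   # teacher soft targets
    hard = -np.log(p_student[true_label])
    soft = -np.sum(p_teacher_T * np.log(p_student_T))
    return alpha * hard + (1.0 - alpha) * soft
```

A student whose logits agree with both the label and the teacher incurs a lower loss than one that disagrees, which is the transfer mechanism the paragraph describes.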
Terminal devices that perform object detection tasks are generally small devices such as video cameras, still cameras, or surveillance probes; because the computing power of their onboard chips is limited, the size of the object detection model's network structure is strictly constrained. Although a target detection model trained with the traditional knowledge distillation method can match the terminal device's computing-power budget in network size, the accuracy of the resulting model on the detection task cannot be guaranteed.
The reason is that the traditional knowledge distillation method is usually used to train models for a single classification task, whereas the detection task performed by a target detection model with a CenterNet network structure comprises a classification task and a regression task simultaneously, making the network structure relatively complex. The traditional method directly replaces the ground-truth-label part of the student model's loss function with the teacher model's output, and performs no hierarchical, per-task guidance and optimization of the loss function; the finally trained target detection model therefore suffers from poor recognition and low detection precision.
Therefore, how to train a target detection model by knowledge distillation so that the network structure remains simple enough for terminal devices while the recognition rate remains high enough to guarantee detection precision has become an urgent problem in the prior art.
Disclosure of Invention
The invention mainly aims to provide a knowledge distillation-based target detection model training method that solves the prior-art problem that a target detection model trained by knowledge distillation cannot simultaneously keep its network structure simple enough for terminal devices and its recognition rate high enough to guarantee detection precision.
In order to achieve the above object, the present invention provides a knowledge distillation-based target detection model training method, comprising: step S1, training a target detection teacher model using a training sample image set, each training sample image in the set having: a first label — the hard-label probability matrix of the pixel position of the target detection frame's center point; a second label — the width and height of the target detection frame; and a third label — the pixel-position offset of the target detection frame's center point; the teacher model's prediction outputs corresponding to the three label types comprise the center-point pixel-position probability heat map, the width and height of the target detection frame, and the center-point pixel-position offset; and step S2, after the loss function of the target detection student model is improved through the teacher model by knowledge distillation, training the student model using the training sample image set and the teacher's prediction outputs.
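For illustration only, the three label types for a single ground-truth box might be constructed as follows; the output stride, array shapes, and function name are assumptions, not taken from this disclosure:

```python
import numpy as np

def make_labels(box, map_h, map_w, stride=4):
    """Build the three CenterNet-style labels for one ground-truth box.

    box: (x1, y1, x2, y2) in input-image pixels.
    Returns: the hard-label center-point matrix, the (w, h) of the box on
    the feature map, and the sub-pixel offset of the quantized center.
    """
    x1, y1, x2, y2 = box
    # Box center on the down-sampled feature map (floating point).
    cx, cy = (x1 + x2) / 2.0 / stride, (y1 + y2) / 2.0 / stride
    ix, iy = int(cx), int(cy)                 # quantized center cell

    # First label: hard-label probability matrix (1 at the center cell).
    hard = np.zeros((map_h, map_w))
    hard[iy, ix] = 1.0

    # Second label: width and height of the detection frame.
    wh = ((x2 - x1) / stride, (y2 - y1) / stride)

    # Third label: pixel-position offset lost to quantization.
    offset = (cx - ix, cy - iy)
    return hard, wh, offset
```

The offset label exists because the down-sampling quantizes the true center to a grid cell; the student later learns to predict this residual back.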
Further, the loss function Loss_total of the target detection student model is defined as:

Loss_total = Loss_hm + λ_wh · Loss_wh + λ_reg · Loss_reg … … (1)

where Loss_hm is the loss-function part corresponding to the center-point pixel-position probability heat map predicted by the student model; Loss_wh is the loss-function part corresponding to the width and height of the target detection frame predicted by the student model; Loss_reg is the loss-function part corresponding to the center-point pixel-position offset predicted by the student model; λ_wh is the weight coefficient of the width-and-height part; and λ_reg is the weight coefficient of the center-point-offset part.
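The weighted combination described above can be sketched directly; the λ values below are placeholders inside the range [0.5, 1) stated later in the description:

```python
def total_loss(loss_hm, loss_wh, loss_reg, lambda_wh=0.7, lambda_reg=0.5):
    """Loss_total = Loss_hm + lambda_wh * Loss_wh + lambda_reg * Loss_reg.

    The default lambda values are illustrative placeholders only.
    """
    return loss_hm + lambda_wh * loss_wh + lambda_reg * loss_reg
```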
Further, the loss-function part Loss_hm corresponding to the center-point pixel-position probability heat map predicted by the student model is defined as:

Loss_hm = Loss_hm^gt + λ_hm · Loss_hm^T … … (2)

where Loss_hm^gt is the sub-loss guided by the soft-label probability matrix of the center-point pixel position, obtained by transforming the hard-label probability matrix of the first label; Loss_hm^T is the sub-loss guided jointly by the teacher model and that soft-label probability matrix; and λ_hm is the weight coefficient of the jointly guided sub-loss.
Further, both sub-losses are focal-loss functions. The sub-loss Loss_hm^gt guided by the soft-label probability matrix is defined as:

Loss_hm^gt = −(1/N) · Σ_{x,y} { (1 − S_xy)^α · log(S_xy), if Y_xy = 1; (1 − Y_xy)^β · (S_xy)^α · log(1 − S_xy), otherwise } … … (3)

and Loss_hm^T is the analogous focal loss in which the teacher's predicted probability serves, together with the soft label, as the guidance target. Here N is the number of pixel points in the center-point pixel-position probability heat map predicted by the student model; Y_xy is the probability value at coordinate point (x, y) of the center-point soft-label probability matrix obtained by coordinate-transforming the hard-label probability matrix; T_xy is the probability value of pixel (x, y) in the heat map predicted by the teacher model; S_xy is the probability value of pixel (x, y) in the heat map predicted by the student model; and α and β are both exponential constants.
Further, the hard-label probability matrix of the center-point pixel position is coordinate-transformed by a Gaussian kernel function to obtain the soft-label probability matrix of the center-point pixel position; the probability value Y_xy at coordinate point (x, y) of the soft-label matrix is the Gaussian kernel result value G. The Gaussian kernel function is:

G = exp( −((x − m)² + (y − n)²) / (2σ²) ) … … (5)

where m and n are respectively the abscissa and the ordinate of the coordinate point whose probability value is 1 in the hard-label probability matrix; x and y are respectively the abscissa and the ordinate of any coordinate point in the soft-label probability matrix; and σ is a scale constant corresponding to the target detection frame.
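The kernel can be written directly as a function (the name is illustrative):

```python
import math

def gaussian_kernel(x, y, m, n, sigma):
    """G = exp(-((x - m)**2 + (y - n)**2) / (2 * sigma**2)).

    (m, n): coordinate of a hard-label cell with probability 1;
    (x, y): any coordinate in the soft-label matrix;
    sigma:  scale constant of the detection frame."""
    return math.exp(-((x - m) ** 2 + (y - n) ** 2) / (2.0 * sigma ** 2))
```

At the center itself the value is exactly 1, and it decays symmetrically with squared distance, which is what makes the resulting matrix a valid "softened" version of the hard label.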
Further, when the hard-label probability matrix of the center-point pixel position contains a plurality of coordinate points with probability value 1, the probability value Y_xy at each coordinate point (x, y) of the soft-label probability matrix takes the largest of the corresponding Gaussian kernel result values G.
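A sketch of the whole hard-to-soft transformation, taking the per-cell maximum over the Gaussians of all center points (array layout and the σ value are assumptions):

```python
import numpy as np

def soft_label_matrix(hard, sigma=2.0):
    """Transform a hard-label matrix into a soft-label matrix: place a
    Gaussian at every cell whose value is 1 and keep, per cell, the
    maximum over all Gaussians."""
    h, w = hard.shape
    ys, xs = np.mgrid[0:h, 0:w]                 # row and column grids
    soft = np.zeros_like(hard, dtype=float)
    for n, m in zip(*np.nonzero(hard == 1)):    # centers: rows (n), cols (m)
        g = np.exp(-((xs - m) ** 2 + (ys - n) ** 2) / (2.0 * sigma ** 2))
        soft = np.maximum(soft, g)
    return soft
```

Taking the maximum (rather than the sum) keeps every value a valid probability in [0, 1] even when two detection frames overlap.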
Further, the loss-function part Loss_wh corresponding to the width and height of the target detection frame predicted by the student model is:

Loss_wh = Loss_wh^gt + λ_wh^T · Loss_wh^T … … (6)

where Loss_wh^gt is the sub-loss guided by the width and height of the target detection frame of the second label; Loss_wh^T is the sub-loss guided jointly by the teacher model and the second label; and λ_wh^T is the weight coefficient of the jointly guided sub-loss.

K is the number of width-height pairs (second labels) of target detection frames in a training sample image, and k indexes any one second label; P_k^gt is the product of the width and the height of the frame of the k-th second label; P_k^S is the product of the predicted width and height output by the student model; P_k^T is the product of the predicted width and height output by the teacher model. The sub-losses are built from the L1 distance between P_k^S and P_k^gt, the L2 distance between P_k^S and P_k^gt, the L2 distance between P_k^S and P_k^T, and a first margin constant m₁.
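The distance quantities listed above can be sketched as follows; how the patent combines them is given by equations not reproduced in this text, so only the individual terms are shown (the names d1, d2, d3 are placeholders):

```python
import numpy as np

def wh_distance_terms(p_student, p_teacher, p_gt):
    """Per-image distance terms for the width-height distillation loss:
    L1(student, gt), L2(student, gt), and L2(student, teacher).
    Inputs are sequences of width*height products, one entry per box."""
    s, t, g = (np.asarray(v, dtype=float) for v in (p_student, p_teacher, p_gt))
    d1 = np.abs(s - g).sum()               # L1 distance to the ground truth
    d2 = np.sqrt(((s - g) ** 2).sum())     # L2 distance to the ground truth
    d3 = np.sqrt(((s - t) ** 2).sum())     # L2 distance to the teacher
    return d1, d2, d3
```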
Further, the loss-function part Loss_reg corresponding to the center-point pixel-position offset predicted by the student model is:

Loss_reg = Loss_reg^gt + λ_reg^T · Loss_reg^T … … (7)

where Loss_reg^gt is the sub-loss guided by the center-point pixel-position offset of the third label; Loss_reg^T is the sub-loss guided jointly by the teacher model and the third label; and λ_reg^T is the weight coefficient of the jointly guided sub-loss.

Z is the number of center-point pixel-position offsets (third labels) in a training sample image, and z indexes any one third label; Q_z^gt is the product of the horizontal-axis offset and the vertical-axis offset of the z-th third label; Q_z^S is the corresponding product predicted by the student model; Q_z^T is the corresponding product predicted by the teacher model. The sub-losses are built from the L1 distance between Q_z^S and Q_z^gt, the L2 distance between Q_z^S and Q_z^gt, the L2 distance between Q_z^S and Q_z^T, and a second margin constant m₂.
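The margin ("spacing") constant suggests a teacher-bounded combination of these distances. The following is only an assumed form borrowed from the general distillation literature, not the patent's own formula:

```python
def bounded_teacher_term(d_student_gt, d_student_teacher, margin):
    """Keep the student-teacher L2 term only while the student is still
    farther than `margin` from the ground truth; once the student is
    close enough, the teacher guidance is switched off.
    An assumed combination, not the patent's exact equation."""
    return d_student_teacher if d_student_gt > margin else 0.0
```

The design intent of such a term is that the teacher stops dominating the gradient once the student already matches the ground truth well.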
By applying the technical scheme of the invention, the training sample images of the training sample image set are label-classified, and the detection tasks of the trained teacher model can be clearly distinguished by label type: among the teacher model's prediction outputs, obtaining the center-point pixel-position probability heat map belongs to the classification task, while obtaining the width and height of the detection frame and the center-point pixel-position offset belong to the regression task. Therefore, while the teacher model guides the training of the student model, the student's loss function can be classified, improved, and optimized specifically by task type. Relying on knowledge distillation, the resulting student network is simple enough to meet the use requirements of terminal devices, and the student model reliably migrates and acquires the teacher's knowledge, inheriting its performance; it therefore achieves excellent recognition and detection precision and has good practicability.
Drawings
The accompanying drawings, which are incorporated in and constitute a part of this application, illustrate embodiments of the invention and, together with the description, serve to explain the invention and not to limit the invention. In the drawings:
FIG. 1 illustrates a flow chart of the steps of the knowledge distillation-based target detection model training method according to the present invention;
FIG. 2 is a schematic diagram of a training sample image from an alternative embodiment of a training sample image set used when implementing the method, showing a target pedestrian whose head, the detection target, is selected with a target detection frame;
FIG. 3 illustrates a first label, i.e., a hard label probability matrix of the pixel position of the center point of the target detection frame, of the training sample image in FIG. 2;
fig. 4 shows the soft label probability matrix of the pixel position of the center point of the target detection frame after the hard label probability matrix of the pixel position of the center point of the target detection frame in fig. 3 is transformed.
Wherein the figures include the following reference numerals:
A. a target pedestrian; B. the head of the target pedestrian; C. the target detection frame.
Detailed Description
It should be noted that the embodiments and features of the embodiments in the present application may be combined with each other without conflict. The present invention will be described in detail below with reference to the embodiments with reference to the attached drawings.
In order to make the technical solutions of the present invention better understood, the technical solutions in the embodiments of the present invention will be clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are only a part of the embodiments of the present invention, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
It should be noted that the terms "first," "second," and the like in the description and claims of the present invention and in the drawings described above are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used may be interchanged under appropriate circumstances in order to facilitate the description of the embodiments of the invention herein. Furthermore, the terms "comprises," "comprising," "includes," "including," "has," "having," and any variation thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements explicitly listed, but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
The invention provides a knowledge distillation-based target detection model training method, aiming to solve the prior-art problem that a target detection model trained by knowledge distillation cannot simultaneously keep its network structure simple enough for terminal devices and its recognition rate high enough to guarantee detection precision.
FIG. 1 is a flow chart of the steps of a knowledge distillation-based target detection model training method according to an alternative embodiment of the invention. As shown in fig. 1, the training method includes: step S1, training a target detection teacher model using a training sample image set, each training sample image in the set having: a first label — the hard-label probability matrix of the pixel position of the target detection frame's center point; a second label — the width and height of the target detection frame; and a third label — the pixel-position offset of the target detection frame's center point; the teacher model's prediction outputs corresponding to the three label types comprise the center-point pixel-position probability heat map, the width and height of the target detection frame, and the center-point pixel-position offset; and step S2, after the loss function of the target detection student model is improved through the teacher model by knowledge distillation, training the student model using the training sample image set and the teacher's prediction outputs.
Label classification of the training sample images lets the detection tasks of the trained teacher model be clearly distinguished by label type: among the teacher model's prediction outputs, obtaining the center-point pixel-position probability heat map belongs to the classification task, while obtaining the width and height of the detection frame and the center-point pixel-position offset belong to the regression task. Therefore, while the teacher model guides the training of the student model, the student's loss function can be classified, improved, and optimized specifically by task type. Relying on knowledge distillation, the trained student network is simple enough to meet the use requirements of terminal devices, and the student model reliably migrates and acquires the teacher's knowledge, inheriting its performance; it therefore achieves excellent recognition and detection precision and has good practicability.
Optionally, obtaining the center-point pixel-position probability heat map belongs to a binary classification task.
It should be explained that, before training the teacher model or the student model with the training sample images, all training sample images must be annotated with the three label types. As shown in fig. 2, for example, only one target pedestrian A exists in the training sample image, and the head B of the target pedestrian is selected with target detection frame C by manual annotation.
The training sample image is then labeled by a preset program. The first label is the hard-label probability matrix of the center-point pixel position (shown in fig. 3): its entries correspond one-to-one to the probabilities that the pixels of the training sample image are the center point of the target detection frame, each value being 0 or 1. The coordinate point with value 1 is the geometric center of target detection frame C; all other values are 0. Of course, when there are multiple target pedestrians in the training sample image, there are correspondingly multiple coordinate points with value 1. To ensure that the teacher and student models learn the feature information of the first label well enough to improve detection precision, the hard-label probability matrix must be transformed into a soft-label probability matrix of the center-point pixel position. This is because, although each target detection frame has only one center point in the training sample image, the pixels around that center point still represent the head features of the target pedestrian and genuinely differ from pixels outside the head; using the soft-label probability matrix therefore lets the teacher and student models learn more realistic feature information from the training sample image. FIG. 4 shows the soft-label probability matrix obtained by transforming the hard-label matrix of fig. 3; in it, values at coordinate points near the point with value 1 are closer to 1, and values at coordinate points farther from it are closer to 0 (not shown).
In this embodiment, the transformation between the two is as follows: the hard-label probability matrix of the center-point pixel position is coordinate-transformed by a Gaussian kernel function to obtain the soft-label probability matrix; the probability value Y_xy at coordinate point (x, y) of the soft-label matrix is the Gaussian kernel result value G. The Gaussian kernel function is:

G = exp( −((x − m)² + (y − n)²) / (2σ²) ) … … (5)

where m and n are respectively the abscissa and the ordinate of the coordinate point with probability value 1 in the hard-label probability matrix — i.e., its m-th column and n-th row; x and y are respectively the abscissa and the ordinate of any coordinate point in the soft-label probability matrix — i.e., its x-th column and y-th row; and σ is a scale constant corresponding to the target detection frame. Optionally, the scale constant σ of the target detection frame ranges from 10 pixels to 80 pixels.
Of course, when there are multiple coordinate points with probability value 1 in the hard-label probability matrix — that is, when there are multiple target detection frames C in fig. 2 — the probability value Y_xy at each coordinate point (x, y) of the soft-label probability matrix takes the largest of the corresponding Gaussian kernel result values G.
The second label labeled on the training sample image is the width and height of the target detection frame (not shown), and the third label labeled on the training sample image is the pixel position offset of the center point of the target detection frame (not shown).
In this embodiment, the loss function Loss_total of the target detection student model is defined as:

Loss_total = Loss_hm + λ_wh · Loss_wh + λ_reg · Loss_reg … … (1)

where Loss_hm is the loss-function part corresponding to the center-point pixel-position probability heat map predicted by the student model; Loss_wh is the loss-function part corresponding to the predicted width and height of the target detection frame; Loss_reg is the loss-function part corresponding to the predicted center-point pixel-position offset; λ_wh is the weight coefficient of the width-and-height part; and λ_reg is the weight coefficient of the center-point-offset part.
Optionally, the weight proportion coefficient λ_wh of the loss function part corresponding to the width and height of the target detection frame and the weight proportion coefficient λ_reg of the loss function part corresponding to the target detection frame center point pixel position offset both take values in the range [0.5, 1), which shows that the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model occupies the largest weight and is the most critical factor influencing the later detection precision of the target detection student model.
Optionally, the weight proportion coefficient λ_wh of the loss function part corresponding to the width and height of the target detection frame is larger than the weight proportion coefficient λ_reg of the loss function part corresponding to the target detection frame center point pixel position offset. This is because the later detection precision of the target detection student model is influenced more heavily by the width and height of the target detection frame than by the target detection frame center point pixel position offset.
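A minimal sketch of how the three loss parts might be combined under the constraints just stated (λ_wh and λ_reg in [0.5, 1) with λ_wh > λ_reg, so the heatmap part with its implicit weight of 1 always dominates); the function name and the default coefficient values are illustrative assumptions, not values from the source:

```python
def total_loss(loss_hm, loss_wh, loss_reg, lam_wh=0.9, lam_reg=0.7):
    """Combine the three loss parts of the student model.

    The heatmap part carries an implicit weight of 1, so with
    lam_wh, lam_reg in [0.5, 1) it always occupies the largest weight,
    and lam_wh > lam_reg reflects the larger influence of the
    width/height part compared with the offset part.
    """
    assert 0.5 <= lam_reg < lam_wh < 1.0  # constraints stated in the text
    return loss_hm + lam_wh * loss_wh + lam_reg * loss_reg
```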
Specifically, the first part of the loss function Loss_total of the target detection student model is the loss function part Loss_hm corresponding to the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model. The loss function of this classification task is optimized and improved through knowledge distillation, and the corresponding loss function part Loss_hm is defined as:

Loss_hm = Loss_hm^gt + λ_hm · Loss_hm^kd
wherein Loss_hm^gt is the sub-loss function guided by the soft label probability matrix of the target detection frame center point pixel position, obtained by transforming the hard label probability matrix of the center point pixel position corresponding to the first label; Loss_hm^kd is the sub-loss function guided jointly by the target detection teacher model and the soft label probability matrix of the target detection frame center point pixel position corresponding to the first label; λ_hm is the weight proportion coefficient of the sub-loss function guided jointly by the target detection teacher model and the soft label probability matrix of the target detection frame center point pixel position corresponding to the first label.
Optionally, the weight proportion coefficient λ_hm of the sub-loss function Loss_hm^kd takes values in the range [0.5, 1), which ensures that its weight does not exceed that of the sub-loss function Loss_hm^gt guided by the soft label probability matrix of the target detection frame center point pixel position.
It should be noted that, in this embodiment, the target detection frame center point pixel position probability thermodynamic diagrams predicted and output by the target detection student model and by the target detection teacher model are not shown. The ideal training state of the model is that the center point pixel position probability matrices corresponding to both predicted probability thermodynamic diagrams learn to approach the soft label probability matrix of the target detection frame center point pixel position in fig. 4, thereby ensuring that both the target detection teacher model and the target detection student model achieve good detection accuracy.
The sub-loss function Loss_hm^gt guided by the soft label probability matrix of the target detection frame center point pixel position is used for evaluating the difference between the probability matrix corresponding to the center point pixel position probability thermodynamic diagram predicted and output by the target detection student model and the soft label probability matrix of the target detection frame center point pixel position.
In this embodiment, Loss_hm^gt is a focal loss function, which is mainly used for balancing the imbalance between positive and negative samples and the hard sample problem in the detection task. The sub-loss function Loss_hm^gt is defined as:

Loss_hm^gt = −(1/N) · Σ_{x,y} { (1 − Ŷ_xy^s)^α · log(Ŷ_xy^s),                 if Y_xy = 1;
                                 (1 − Y_xy)^β · (Ŷ_xy^s)^α · log(1 − Ŷ_xy^s),  otherwise }    (3)
Loss_hm^kd is a loss function based on knowledge distillation, used for evaluating the distribution difference between the prediction output of the target detection student model and the prediction output of the target detection teacher model. Compared with the sub-loss function Loss_hm^gt guided by the soft label probability matrix of the target detection frame center point pixel position, the sub-loss function Loss_hm^kd additionally incorporates the teacher output distribution Ŷ^t together with the student output distribution Ŷ^s, so as to guide the output distribution of the student network toward that of the teacher network. The sub-loss function Loss_hm^kd is defined as:
wherein N is the number of pixel points in the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model; Y_xy is the probability value of the digital coordinate point (x, y) in the soft label probability matrix of the target detection frame center point pixel position, obtained after coordinate transformation of the hard label probability matrix of the center point pixel position; Ŷ_xy^t is the probability value of the pixel point (x, y) in the center point pixel position probability thermodynamic diagram predicted and output by the target detection teacher model; Ŷ_xy^s is the probability value of the pixel point (x, y) in the center point pixel position probability thermodynamic diagram predicted and output by the target detection student model; α and β are both exponential constants.
In the above formula (3) and formula (4), (1 − Ŷ_xy^s)^α and (Ŷ_xy^s)^α increase the weight coefficients of hard samples: the larger the deviation of the prediction output of the target detection student model, the larger these two weight coefficients become. (1 − Y_xy)^β is a weighting factor used to adjust the proportion of the loss contributed by negative samples: the more a negative sample deviates from the target, the greater the weighting factor. Optionally, α and β both take values in the range [2, 4].
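Assuming formula (3) takes the CenterNet-style penalty-reduced form consistent with the surrounding definitions (soft label Y, student output Ŷ^s, exponential constants α and β in [2, 4]), a sketch might look like this; the function name is an assumption, and normalization by the total pixel count follows the text's definition of N:

```python
import numpy as np

def focal_loss_soft(pred, soft_label, alpha=2.0, beta=4.0, eps=1e-12):
    """Penalty-reduced focal loss over a center-point heatmap.

    pred       -- student heatmap probabilities, same shape as soft_label
    soft_label -- Gaussian soft label matrix; positions equal to 1 are
                  positives, elsewhere (1 - Y)^beta down-weights negatives
                  that lie close to a Gaussian peak
    """
    pred = np.clip(pred, eps, 1.0 - eps)      # keep log() finite
    pos = soft_label >= 1.0
    pos_term = ((1.0 - pred[pos]) ** alpha) * np.log(pred[pos])
    neg_term = ((1.0 - soft_label[~pos]) ** beta) \
        * (pred[~pos] ** alpha) * np.log(1.0 - pred[~pos])
    n_pixels = pred.size                      # N = number of heatmap pixels
    return -(pos_term.sum() + neg_term.sum()) / n_pixels
```

A confident, correct heatmap should score a lower loss than a uniform one, since both the positive term and the down-weighted negative terms shrink.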
The second part of the loss function Loss_total of the target detection student model is the loss function part Loss_wh corresponding to the width and height of the target detection frame predicted and output by the target detection student model. The loss function of this regression task is optimized and improved through knowledge distillation, and the corresponding loss function part Loss_wh combines the L1 and L2 loss functions and is defined as:

Loss_wh = Loss_wh^gt + λ'_wh · Loss_wh^kd
wherein Loss_wh^gt is the sub-loss function guided by the width and height of the target detection frame corresponding to the second label; Loss_wh^kd is the sub-loss function guided jointly by the target detection teacher model and the width and height of the target detection frame corresponding to the second label; λ'_wh is the weight proportion coefficient of the sub-loss function guided jointly by the target detection teacher model and the second label.
Optionally, the weight proportion coefficient λ'_wh of the sub-loss function Loss_wh^kd takes values in the range [0.5, 1), which ensures that its weight does not exceed that of the sub-loss function Loss_wh^gt guided by the width and height of the target detection frame corresponding to the second label.
wherein K is the number of target detection frame width-and-height pairs corresponding to second labels in the training sample image; k refers to any one second label in the training sample image; S_k^gt is the product of the width and the height of the target detection frame corresponding to the k-th second label in the training sample image; S_k^s is the product of the width and the height of the target detection frame predicted and output by the target detection student model; S_k^t is the product of the width and the height of the target detection frame predicted and output by the target detection teacher model; the three distance terms are respectively the L1 distance between S_k^s and S_k^gt, the L2 distance between S_k^s and S_k^gt, and the L2 distance between S_k^s and S_k^t; and m_1 is a first spacing constant.
That is, when it is determined that the gap between the prediction output of the target detection student model and the second label of the originally input training sample image is greater than the gap between the prediction output of the target detection student model and the prediction output of the target detection teacher model, and the difference exceeds the first spacing constant m_1, the L2 loss with respect to the second label is additionally applied to the target detection student model.
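One plausible reading of the teacher-bounded rule above (always apply the L1 term toward the label, and switch on the extra L2 term toward the label only when the student's gap to the label exceeds its gap to the teacher by more than the first spacing constant) can be sketched as follows; the function name, the margin, and the inner weight default are illustrative assumptions:

```python
import numpy as np

def distilled_wh_loss(student, teacher, label, margin=0.5, lam=0.8):
    """Combined L1/L2 width-height regression loss with teacher bound.

    student, teacher, label -- arrays of per-box width*height products
    margin -- the first spacing constant m_1 (illustrative value)
    lam    -- inner weight coefficient in [0.5, 1) (illustrative value)
    """
    l1 = np.abs(student - label)                   # always-on L1 term
    sq_to_label = np.square(student - label)       # L2 gap to ground truth
    sq_to_teacher = np.square(student - teacher)   # L2 gap to teacher
    # Add the L2 term to the label only when the student is much farther
    # from the label than from the teacher (gap difference > margin).
    extra_l2 = np.where(sq_to_label - sq_to_teacher > margin,
                        sq_to_label, 0.0)
    return float((l1 + lam * extra_l2).mean())
```

The same structure would apply to the center point offset part, with the second spacing constant in place of the first.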
The third part of the loss function Loss_total of the target detection student model is the loss function part Loss_reg corresponding to the target detection frame center point pixel position offset predicted and output by the target detection student model. The loss function of this regression task is optimized and improved through knowledge distillation, and the corresponding loss function part Loss_reg combines the L1 and L2 loss functions and is defined as:

Loss_reg = Loss_reg^gt + λ'_reg · Loss_reg^kd
wherein Loss_reg^gt is the sub-loss function guided by the target detection frame center point pixel position offset corresponding to the third label; Loss_reg^kd is the sub-loss function guided jointly by the target detection teacher model and the target detection frame center point pixel position offset corresponding to the third label; λ'_reg is the weight proportion coefficient of the sub-loss function guided jointly by the target detection teacher model and the center point pixel position offset corresponding to the third label.
Optionally, the weight proportion coefficient λ'_reg of the sub-loss function Loss_reg^kd takes values in the range [0.5, 1), which ensures that its weight does not exceed that of the sub-loss function Loss_reg^gt guided by the center point pixel position offset corresponding to the third label. It should be noted that the target detection frame center point pixel position offset is the difference between the center point pixel coordinate position of the target detection frame predicted and output by the target detection student model and the actual position in the training sample image.
wherein Z is the number of target detection frame center point pixel position offsets corresponding to third labels in the training sample image; z refers to any one third label in the training sample image; O_z^gt is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset corresponding to the z-th third label in the training sample image; O_z^s is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset predicted and output by the target detection student model; O_z^t is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset predicted and output by the target detection teacher model; the three distance terms are respectively the L1 distance between O_z^s and O_z^gt, the L2 distance between O_z^s and O_z^gt, and the L2 distance between O_z^s and O_z^t; and m_2 is a second spacing constant.
That is, when it is determined that the gap between the prediction output of the target detection student model and the third label of the originally input training sample image is greater than the gap between the prediction output of the target detection student model and the prediction output of the target detection teacher model, and the difference exceeds the second spacing constant m_2, the L2 loss with respect to the third label is additionally applied to the target detection student model.
It should be noted that the network structures of the target detection teacher model and the target detection student model both adopt hourglass network structures; the difference is that the network depth and width of the teacher model's network structure are both greater than those of the student model's, and the parameter count of the teacher model's network structure is 5 to 10 times that of the student model's. The recall rate and detection precision of a target detection student model trained by the knowledge distillation-based target detection model training method provided by the invention are superior to those of a target detection student model trained in a general knowledge distillation training manner.
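The overall two-stage flow described above (the larger hourglass teacher trained first in step S1, then the smaller hourglass student trained in step S2 against both the labels and the teacher's first prediction output) can be sketched as follows; all names and callables are hypothetical stand-ins, not an interface from the source:

```python
def train_with_distillation(fit_teacher, teacher_predict, student_step,
                            dataset, epochs=2):
    """Minimal sketch of the two-stage distillation training flow.

    fit_teacher     -- callable that trains the teacher on the dataset (S1)
    teacher_predict -- callable: image -> first prediction output
                       (heatmap, width/height, offsets)
    student_step    -- callable: (image, labels, teacher_out) -> loss;
                       performs one student update with the combined loss
    """
    fit_teacher(dataset)                          # step S1: teacher first
    losses = []
    for _ in range(epochs):                       # step S2: distilled student
        for image, labels in dataset:
            teacher_out = teacher_predict(image)  # teacher's prediction output
            losses.append(student_step(image, labels, teacher_out))
    return losses
```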
The above-mentioned serial numbers of the embodiments of the present invention are merely for description and do not represent the merits of the embodiments.
The integrated unit in the above embodiments, if implemented in the form of a software functional unit and sold or used as a separate product, may be stored in the above computer-readable storage medium. Based on such understanding, the technical solution of the present invention may be embodied in the form of a software product, which is stored in a storage medium and includes several instructions for causing one or more computer devices (which may be personal computers, servers, network devices, etc.) to execute all or part of the steps of the method according to the embodiments of the present invention.
In the above embodiments of the present invention, the descriptions of the respective embodiments have respective emphasis, and for parts that are not described in detail in a certain embodiment, reference may be made to related descriptions of other embodiments.
In the several embodiments provided in the present application, it should be understood that the disclosed client may be implemented in other manners. The above-described embodiments of the apparatus are merely illustrative, and for example, the division of the units is only one type of division of logical functions, and there may be other divisions when actually implemented, for example, a plurality of units or components may be combined or may be integrated into another system, or some features may be omitted, or not executed. In addition, the shown or discussed mutual coupling or direct coupling or communication connection may be an indirect coupling or communication connection through some interfaces, units or modules, and may be in an electrical or other form.
The units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units, may be located in one place, or may be distributed on a plurality of network units. Some or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment.
In addition, functional units in the embodiments of the present invention may be integrated into one processing unit, or each unit may exist alone physically, or two or more units are integrated into one unit. The integrated unit can be realized in a form of hardware, and can also be realized in a form of a software functional unit.
The above description is only a preferred embodiment of the present invention and is not intended to limit the present invention, and various modifications and changes may be made by those skilled in the art. Any modification, equivalent replacement, or improvement made within the spirit and principle of the present invention should be included in the protection scope of the present invention.
Claims (10)
1. A knowledge distillation-based target detection model training method is characterized by comprising the following steps:
step S1, training to generate a target detection teacher model using a training sample image set, each training sample image in the training sample image set having a first label: a hard label probability matrix of the target detection frame center point pixel position; a second label: the width and the height of the target detection frame; and a third label: the target detection frame center point pixel position offset; wherein a first prediction output result of the target detection teacher model corresponding to the three types of labels comprises: the target detection frame center point pixel position probability thermodynamic diagram, the width and the height of the target detection frame, and the target detection frame center point pixel position offset;
step S2, after the loss function of the target detection student model is improved through the target detection teacher model in a knowledge distillation mode, the training sample image set and the first prediction output result are used for training to generate the target detection student model.
2. The knowledge-distillation-based target detection model training method as claimed in claim 1, wherein the loss function Loss_total of the target detection student model is defined as:

Loss_total = Loss_hm + λ_wh · Loss_wh + λ_reg · Loss_reg

wherein:
Loss_hm is a loss function part corresponding to the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model;
Loss_wh is a loss function part corresponding to the width and height of the target detection frame predicted and output by the target detection student model;
Loss_reg is a loss function part corresponding to the target detection frame center point pixel position offset predicted and output by the target detection student model;
λ_wh is a weight proportion coefficient of the loss function part corresponding to the width and height of the target detection frame;
λ_reg is a weight proportion coefficient of the loss function part corresponding to the target detection frame center point pixel position offset.
3. The knowledge-distillation-based target detection model training method as claimed in claim 2, wherein the loss function part Loss_hm corresponding to the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model is defined as:

Loss_hm = Loss_hm^gt + λ_hm · Loss_hm^kd

wherein:
Loss_hm^gt is a sub-loss function guided by the soft label probability matrix of the target detection frame center point pixel position, obtained by transforming the hard label probability matrix of the center point pixel position corresponding to the first label;
Loss_hm^kd is a sub-loss function guided jointly by the target detection teacher model and the soft label probability matrix of the target detection frame center point pixel position corresponding to the first label;
λ_hm is the weight proportion coefficient of the sub-loss function guided jointly by the target detection teacher model and the soft label probability matrix of the target detection frame center point pixel position corresponding to the first label.
4. The knowledge-distillation-based target detection model training method according to claim 3, wherein:
N is the number of pixel points in the target detection frame center point pixel position probability thermodynamic diagram predicted and output by the target detection student model;
Y_xy is the probability value of the digital coordinate point (x, y) in the soft label probability matrix of the target detection frame center point pixel position, obtained after coordinate transformation of the hard label probability matrix of the center point pixel position;
Ŷ_xy^t is the probability value of the pixel point (x, y) in the center point pixel position probability thermodynamic diagram predicted and output by the target detection teacher model;
Ŷ_xy^s is the probability value of the pixel point (x, y) in the center point pixel position probability thermodynamic diagram predicted and output by the target detection student model.
5. The knowledge-distillation-based target detection model training method as claimed in claim 3, wherein the soft label probability matrix of the target detection frame center point pixel position is obtained by Gaussian kernel function coordinate transformation of the hard label probability matrix of the target detection frame center point pixel position; the probability value of a digital coordinate point (x, y) of the soft label probability matrix is the result value G of the Gaussian kernel function; the Gaussian kernel function is:

G(x, y) = exp(−((x − m)² + (y − n)²) / (2σ²))
wherein m and n are respectively the abscissa and the ordinate of the digital coordinate point with a probability value of 1 in the hard label probability matrix of the target detection frame center point pixel position;
x and y are respectively the abscissa and the ordinate of any digital coordinate point in the soft label probability matrix of the target detection frame center point pixel position; and σ is a scale constant corresponding to the target detection frame.
6. The knowledge-distillation-based target detection model training method according to claim 5, wherein when there are a plurality of digital coordinate points with a probability value of 1 in the hard label probability matrix of the target detection frame center point pixel position, the probability value of each digital coordinate point (x, y) in the soft label probability matrix of the target detection frame center point pixel position takes the largest of the multiple Gaussian kernel function result values G.
7. The knowledge-distillation-based target detection model training method as claimed in claim 2, wherein the loss function part Loss_wh corresponding to the width and height of the target detection frame predicted and output by the target detection student model is defined as:

Loss_wh = Loss_wh^gt + λ'_wh · Loss_wh^kd

wherein:
Loss_wh^gt is a sub-loss function guided by the width and height of the target detection frame corresponding to the second label;
Loss_wh^kd is a sub-loss function guided jointly by the target detection teacher model and the width and height of the target detection frame corresponding to the second label.
8. The knowledge-distillation-based target detection model training method according to claim 7, wherein:
K is the number of target detection frame width-and-height pairs corresponding to second labels in the training sample image;
k refers to any one second label in the training sample image;
S_k^gt is the product of the width and the height of the target detection frame corresponding to the k-th second label in the training sample image;
S_k^s is the product of the width and the height of the target detection frame predicted and output by the target detection student model;
S_k^t is the product of the width and the height of the target detection frame predicted and output by the target detection teacher model.
9. The knowledge-distillation-based target detection model training method as claimed in claim 2, wherein the loss function part Loss_reg corresponding to the target detection frame center point pixel position offset predicted and output by the target detection student model is defined as:

Loss_reg = Loss_reg^gt + λ'_reg · Loss_reg^kd

wherein:
Loss_reg^gt is a sub-loss function guided by the target detection frame center point pixel position offset corresponding to the third label;
Loss_reg^kd is a sub-loss function guided jointly by the target detection teacher model and the target detection frame center point pixel position offset corresponding to the third label.
10. The knowledge-distillation-based target detection model training method according to claim 9, wherein:
Z is the number of target detection frame center point pixel position offsets corresponding to third labels in the training sample image;
z refers to any one third label in the training sample image;
O_z^gt is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset corresponding to the z-th third label in the training sample image;
O_z^s is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset predicted and output by the target detection student model;
O_z^t is the product of the horizontal axis offset and the vertical axis offset of the center point pixel position offset predicted and output by the target detection teacher model.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111179182.XA CN113610069B (en) | 2021-10-11 | 2021-10-11 | Knowledge distillation-based target detection model training method |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113610069A true CN113610069A (en) | 2021-11-05 |
CN113610069B CN113610069B (en) | 2022-02-08 |
Also Published As
Publication number | Publication date |
---|---|
CN113610069B (en) | 2022-02-08 |
Similar Documents
Publication | Title |
---|---|
CN113610069B (en) | Knowledge distillation-based target detection model training method |
CN114241282B (en) | Knowledge distillation-based edge equipment scene recognition method and device |
CN109118479B (en) | Capsule network-based insulator defect identification and positioning device and method |
CN109086811B (en) | Multi-label image classification method and device and electronic equipment |
CN105446988B (en) | Method and apparatus for predicting classification |
CN108133172A (en) | Method for classifying moving objects in video, and vehicle flow analysis method and device |
CN114332578A (en) | Image anomaly detection model training method, image anomaly detection method and device |
CN111368634B (en) | Neural network-based human head detection method, system and storage medium |
CN110175657B (en) | Image multi-label labeling method, device, equipment and readable storage medium |
CN110969200A (en) | Image target detection model training method and device based on consistency negative samples |
CN113128478A (en) | Model training method, pedestrian analysis method, device, equipment and storage medium |
CN109583367A (en) | Image text line detection method and device, storage medium and electronic equipment |
CN116151479B (en) | Flight delay prediction method and prediction system |
CN114332457A (en) | Image instance segmentation model training method, image instance segmentation method and device |
CN118172935A (en) | Intelligent expressway management system and method based on digital twins |
CN115063664A (en) | Model learning method, training method and system for industrial vision detection |
CN115017970A (en) | Transfer learning-based gas consumption behavior anomaly detection method and system |
CN114528913A (en) | Model migration method, device, equipment and medium based on trust and consistency |
CN117710792A (en) | Knowledge distillation method, greening region detection method, electronic device, and storage medium |
CN113065533A (en) | Feature extraction model generation method and device, electronic equipment and storage medium |
CN113255701A (en) | Few-shot learning method and system based on an absolute-relative learning framework |
CN117371511A (en) | Training method, device, equipment and storage medium for image classification model |
CN116894593A (en) | Photovoltaic power generation power prediction method and device, electronic equipment and storage medium |
CN111444833A (en) | Fruit yield measurement method and device, computer equipment and storage medium |
CN113010687B (en) | Exercise label prediction method and device, storage medium and computer equipment |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | |
SE01 | Entry into force of request for substantive examination | |
GR01 | Patent grant | |