CN117011640A - Model distillation real-time target detection method and device based on pseudo tag filtering - Google Patents

Model distillation real-time target detection method and device based on pseudo tag filtering

Info

Publication number
CN117011640A
CN117011640A (application number CN202310815686.9A)
Authority
CN
China
Prior art keywords
model
loss
real
data set
student model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310815686.9A
Other languages
Chinese (zh)
Inventor
梅少辉 (Mei Shaohui)
马明阳 (Ma Mingyang)
张子晔 (Zhang Ziye)
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Northwestern Polytechnical University
Original Assignee
Northwestern Polytechnical University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Northwestern Polytechnical University filed Critical Northwestern Polytechnical University
Priority to CN202310815686.9A priority Critical patent/CN117011640A/en
Publication of CN117011640A publication Critical patent/CN117011640A/en
Pending legal-status Critical Current

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77 Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774 Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/0464 Convolutional networks [CNN, ConvNet]
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/096 Transfer learning
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77 Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/778 Active pattern-learning, e.g. online learning of image or video features
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00 Indexing scheme relating to image or video recognition or understanding
    • G06V2201/07 Target detection
    • Y GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02 TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02T CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00 Road transport of goods or passengers
    • Y02T10/10 Internal combustion engine [ICE] based vehicles
    • Y02T10/40 Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

The application relates to a model distillation real-time target detection method and device based on pseudo tag filtering. The method trains a student model based on a teacher model: an extended data set is first obtained and input into the trained teacher model to generate a pseudo tag data set; the pseudo tag data set is passed through a quality classifier to generate a new pseudo tag data set; the union of the new pseudo tag data set and the original data is then input into the student model and the trained teacher model respectively; based on the detection results, the original loss of the student model and the knowledge distillation loss between the two models are calculated; finally, the overall loss is calculated from the original loss and the knowledge distillation loss and back-propagated to update the student model parameters. The resulting target detection model has high real-time performance, strong generalization capability and high detection precision.

Description

Model distillation real-time target detection method and device based on pseudo tag filtering
Technical Field
The application relates to the technical field of target detection, in particular to a model distillation real-time target detection method and device based on pseudo tag filtering.
Background
Real-time target detection aims to detect and identify objects in images or videos under real-time constraints, and is widely applied in fields such as autonomous driving, security monitoring, smart home and medical imaging.
In recent years, the development of deep learning technology has provided strong support for research on real-time target detection. Current applications of deep learning in real-time target detection fall mainly into two categories: (1) single-stage detection methods: these generally adopt a convolutional neural network (CNN) structure for feature extraction and classification, and detect objects by regressing the position and size of bounding boxes; typical single-stage detection algorithms include YOLO and SSD. (2) Two-stage detection methods: these first generate candidate boxes through a convolutional neural network, and then classify and localize the candidates; typical two-stage detection algorithms include R-CNN and Faster R-CNN. Overall, real-time target detection technology has developed and improved rapidly in recent years, and newly proposed methods and algorithms give it ever broader application prospects in practice.
However, training existing target detection methods typically requires a large amount of annotated data to learn the features and contextual information of the target. Due to the complexity and diversity of target detection tasks, a deep network model, such as one based on a convolutional neural network (CNN) or a Transformer, is needed to obtain good performance. These models typically have millions to billions of parameters, require significant computational resources and memory, and exhibit poor real-time performance at the inference stage. Moreover, existing real-time target detection methods are generally sensitive to changes in illumination and viewing angle. Different viewing angles, occlusion and other factors cause appearance changes of targets in the image that affect their apparent size and shape, making accurate detection and localization difficult for existing methods. When the camera zooms, the viewing angle changes or the target is occluded, existing target detection methods suffer from poor generalization ability.
Disclosure of Invention
In view of the above, it is necessary to provide a model distillation real-time target detection method and device based on pseudo tag filtering, which have high model real-time performance and high generalization capability.
According to the model distillation real-time target detection method based on pseudo tag filtering, a student model is trained based on a teacher model, and data to be detected are then input into the trained student model to obtain a real-time target detection result; the teacher model and the student model are both target detection models, and the teacher model has more layers and greater depth than the student model;
training the student model based on the teacher model specifically comprises the following steps:
acquiring an extended data set;
training the teacher model based on the extended data set, inputting the extended data set into the trained teacher model, and generating a pseudo tag data set;
inputting the pseudo tag data set into a quality classifier to generate a new pseudo tag data set;
inputting the new pseudo tag data set and the original data set into a student model and a trained teacher model respectively to obtain a detection result and a pre-training result;
calculating the original loss of the student model based on the detection result, and calculating the knowledge distillation loss based on the detection result and the pre-training result;
calculating the overall loss according to the original loss and the knowledge distillation loss of the student model;
and adjusting parameters of the student model based on the overall loss to obtain the trained student model.
In one embodiment, acquiring the extended data set includes:
acquiring an original data set;
a random affine transformation is performed on the original data set.
In one embodiment, the teacher model is a YOLOv5-l model, and 300 epochs are trained when training the teacher model based on the extended data set;
the mass classifier is a positive and negative sample mass separator.
In one embodiment, the original loss of the student model includes a confidence loss, a category loss and a frame regression loss;
the confidence loss is
L_CE_obj = -α·log(β) - (1-α)·log(1-β)    (1);
wherein β represents the probability that the sample belongs to the foreground or the background, i.e. the confidence value of the bounding box, and α is a flag indicating whether the real label contains the target (1 indicates that the target is contained, 0 that it is not);
the category loss is
L_CE_cls = -Σ_{i=1}^{nc} y_i·log(p_i)    (2);
wherein p(x) is the actual probability distribution, each element p_i represents the probability that the sample belongs to class i, y_i = 1 when the sample belongs to class i and 0 otherwise, and nc is the total number of sample categories;
the frame regression loss is
L_α-CIoU = 1 - IoU^α + ρ^{2α}(b, b^gt)/c^{2α} + (β·ν)^α    (3);
wherein c is the smallest rectangle containing both the predicted frame and the real frame, b^gt is the real frame, b is the predicted frame, ρ(b^gt, b) is the Euclidean distance between the center points of the real and predicted frames, β is a parameter balancing the weights of the terms, ν is a parameter measuring the consistency of the aspect ratio between the two frames, and the IoU term and α represent an additional power regularization;
the original loss of the student model is
L_STU = λ1×L_CE_cls + λ2×L_CE_obj + λ3×L_α-CIoU    (4);
wherein λ1 is 0.3, λ2 is 0.4 and λ3 is 0.3.
In one embodiment, the knowledge distillation loss is
L_Distill = (1/(m·n))·Σ_{i=1}^{m} Σ_{j=1}^{n} (output_T(i,j) - output_S(i,j))²    (5);
wherein m and n are the numbers of rows and columns of the output tensor, and output_T and output_S are the output results of the teacher model and the student model respectively.
In one embodiment, the overall loss is
L_total = α1×L_STU + α2×L_Distill    (6);
wherein α1 is 0.8 and α2 is 0.2.
In one embodiment, adjusting the parameters of the student model based on the overall loss to obtain the trained student model comprises back-propagating the overall loss through the student model and updating the student model parameters.
In one embodiment, a depth separable convolution module is used in place of the feature extraction portion of the traditional convolutional neural network module in the student model.
In one embodiment, training the student model based on the teacher model further includes, after the trained student model is obtained, transfer-training the trained student model using the original data.
In a second aspect, the application also provides a model distillation real-time target detection device based on pseudo tag filtering, comprising a memory and a processor, wherein the memory stores a computer program and the processor, when executing the computer program, implements the steps of the model distillation real-time target detection method based on pseudo tag filtering.
The application has the beneficial effects that:
(1) The application adopts random affine transformation to simulate different viewing angles and illumination conditions: translation, scaling, rotation and shearing of images simulate changes in target position, changes in target size at different distances, and occlusion of targets and target parts under different viewing angles, while the original labels are transformed accordingly. This expands the data set, improves the detection of multi-view, occluded and multi-scale targets, improves the generalization capability of the target detection model (the student model), and reduces overfitting of the student model.
(2) The application trains the teacher model and the student model jointly. The feature extraction portion of the traditional convolutional neural network module in the student model is replaced with a depth separable convolution module, realizing structural optimization and design of the deep neural network model. Combined with the model distillation method for model learning and fusion, knowledge of the complex teacher model is transferred into the student model; because the student model has fewer layers and less depth, a lightweight effect is achieved, further improving the real-time performance and generalization capability of the student model while retaining high detection precision.
Drawings
FIG. 1 is a schematic flow chart of a model distillation real-time target detection method based on pseudo tag filtering according to an embodiment of the present application;
FIG. 2 is a schematic flow chart of a model distillation real-time target detection method based on pseudo tag filtering according to an embodiment of the present application;
FIG. 3 is a schematic flow chart of a model distillation real-time target detection method based on pseudo tag filtering according to an embodiment of the present application;
fig. 4 is a schematic diagram of a student model after replacing a feature extraction part of a traditional convolutional neural network module with a depth separable convolutional module according to an embodiment of the present application.
Detailed Description
The present application will be described in further detail with reference to the drawings and examples, in order to make the objects, technical solutions and advantages of the present application more apparent. It should be understood that the specific embodiments described herein are for purposes of illustration only and are not intended to limit the scope of the application.
In one embodiment, as shown in fig. 1, fig. 1 is a schematic flowchart of the model distillation real-time target detection method based on pseudo tag filtering according to an embodiment of the present application; the method is applied to a computer device and includes the following steps:
s101, training a student model based on a teacher model.
Specifically, training the student model using the teacher model transfers knowledge of the teacher model into the student model.
S102, inputting the data to be tested into the trained student model to obtain a real-time target detection result.
Specifically, the teacher model and the student model are both target detection models, the number of layers of the teacher model is more than that of the student model, and the depth of the teacher model is greater than that of the student model.
In this embodiment, as shown in fig. 2, fig. 2 is a schematic flowchart of the model distillation real-time target detection method based on pseudo tag filtering provided in the embodiment of the present application; training the student model based on the teacher model specifically includes the following steps:
s201, acquiring an extended data set.
S202, training a teacher model based on the extended data set, inputting the extended data set into the trained teacher model, and generating a pseudo tag data set.
Specifically, the output of the teacher model serves as the pseudo tag. We denote the data set images as X_U and the teacher model as F_T : X_U → Y_U. When training the teacher model, a temperature factor T is introduced to adjust the Softmax probability distribution from which the pseudo tags are generated. The smaller the value of T, the smaller the negative labels become, and the less attention the student model pays to them during training. The pseudo tag is defined as:
soft_labels = softmax(Y_U / T)
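As an illustration, a minimal PyTorch sketch of this temperature-scaled pseudo-label generation follows; the tensor shapes and the value of T are illustrative assumptions, not values fixed by the application.

```python
import torch
import torch.nn.functional as F

def make_soft_labels(teacher_logits: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    """soft_labels = softmax(Y_U / T): temperature-scaled softmax over teacher logits."""
    return F.softmax(teacher_logits / T, dim=-1)

# Example: a batch of 4 predictions over 80 classes (shapes are illustrative).
logits = torch.randn(4, 80)
soft = make_soft_labels(logits, T=2.0)  # each row sums to 1; smaller T sharpens the distribution
```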
S203, inputting the pseudo tag data set into a quality classifier to generate a new pseudo tag data set.
After the random affine transformation, many severely deformed original data images would otherwise be used to train the student model. A quality classifier is therefore adopted to select suitable pseudo-label samples according to the detection results of the teacher model; the filtered-out samples account for 5-20% of the total.
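The application does not spell out the internal structure of the quality classifier at this point; the sketch below therefore only assumes it yields a per-sample quality score, and drops the lowest-scoring 5-20% of pseudo-label samples. The function and variable names are hypothetical.

```python
import torch

def filter_pseudo_labels(samples: list, quality_scores: torch.Tensor, drop_ratio: float = 0.1) -> list:
    """Keep the highest-quality pseudo-label samples.

    quality_scores: one score per sample, assumed to come from the quality classifier.
    drop_ratio: fraction filtered out, within the 5-20% range stated in the text.
    """
    assert 0.05 <= drop_ratio <= 0.20
    n_keep = len(samples) - int(len(samples) * drop_ratio)
    keep_idx = torch.argsort(quality_scores, descending=True)[:n_keep]
    return [samples[i] for i in keep_idx.tolist()]  # the "new pseudo tag data set"
```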
S204, inputting the new pseudo tag data set and the original data set into the student model and the trained teacher model respectively to obtain a detection result and a pre-training result.
It should be noted that the student model is generally chosen from the same series of algorithms as the teacher model. The new pseudo tag data set is then input into the student model together with the original data set for training, so that the student model can learn the knowledge in the teacher model.
Denoting the images of the data set reconstructed from the new pseudo tag data set and the original data set as X_C, the labels as Y_C, and the student model as F_S : X_C → Y_C, the prediction results of the student model are expressed as Ŷ_C = F_S(X_C).
S205, calculating the original loss of the student model based on the detection result, and calculating the knowledge distillation loss based on the detection result and the pre-training result. The output of the teacher model is regarded as a pseudo tag, and the student model needs to imitate the output of the teacher model as closely as possible.
S206, calculating the overall loss according to the original loss and the knowledge distillation loss of the student model.
S207, adjusting parameters of the student model based on the overall loss to obtain a trained student model.
In this embodiment, through model learning and fusion, knowledge of a complex teacher model is transferred to a student model, so that the student model has higher detection accuracy.
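Putting steps S204-S207 together, the following is a minimal sketch of one distillation training step under the loss definitions given later in formulas (4)-(6); the model objects, the batch, and the loss-function arguments are placeholders, not components defined by the application.

```python
import torch

def train_step(student, teacher, batch, optimizer, student_loss_fn, distill_loss_fn):
    """One training step of S204-S207: forward both models on a batch drawn from the
    combined (new pseudo tag + original) data set, then update the student only."""
    images, labels = batch
    with torch.no_grad():
        out_t = teacher(images)                 # pre-training result (pseudo tags)
    out_s = student(images)                     # detection result
    l_stu = student_loss_fn(out_s, labels)      # original loss, formula (4)
    l_distill = distill_loss_fn(out_t, out_s)   # knowledge distillation loss, formula (5)
    loss = 0.8 * l_stu + 0.2 * l_distill        # overall loss, formula (6)
    optimizer.zero_grad()
    loss.backward()                             # back-propagate into the student (S207)
    optimizer.step()
    return loss.item()
```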
In one embodiment, as shown in fig. 3, fig. 3 is a schematic flowchart of the model distillation real-time target detection method based on pseudo tag filtering provided in the embodiment of the present application. This embodiment concerns how to obtain the extended data set; on the basis of the above embodiment, step S201 includes:
s301, acquiring an original data set.
S302, carrying out random affine transformation on the original data set.
Specifically, the random affine transformation is an image warping method based on random sampling. The basic steps for implementing it are as follows:
Randomly generating the affine matrix parameters: these include a randomly generated rotation angle, scaling factor, translation distance and shear parameters; through these parameters, a random affine transformation can be realized.
Constructing the affine transformation matrix: an affine transformation matrix is constructed from the generated parameters to realize the affine transformation of the image.
Applying the affine transformation to the images: the constructed affine transformation matrix is applied to the original data set to realize the random affine transformation. It should be noted that the original data set is the originally acquired image set.
As shown in Table 1, the parameters of the 3×3 affine transformation matrix are given in detail:
Table 1 Parameters of the affine transformation matrix
[  sx·cos(θ)      -sy·sin(θ+hx)   tx ]
[ -sx·sin(θ+hy)    sy·cos(θ)      ty ]
[  0               0              1  ]
where sx and sy denote the scaling of the image along the x-axis and the y-axis, tx and ty are the translation distances of the image along the x-axis and the y-axis respectively, θ is the rotation angle, and hx and hy are the shear parameters of the image along the x-axis and the y-axis respectively. Random affine transformation increases the diversity and size of the data set, and simulates targets at different distances and angles and under occlusion through scaling, rotation and shear, so as to improve the detection precision and generalization capability of the model.
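A small NumPy sketch of constructing and sampling this matrix follows; the sampling ranges are illustrative assumptions only, since the application does not fix them.

```python
import numpy as np

def random_affine_matrix(sx, sy, tx, ty, theta, hx, hy) -> np.ndarray:
    """Build the 3x3 affine matrix of Table 1 from scale (sx, sy), translation
    (tx, ty), rotation theta and shear (hx, hy)."""
    return np.array([
        [ sx * np.cos(theta), -sy * np.sin(theta + hx), tx],
        [-sx * np.sin(theta + hy), sy * np.cos(theta),  ty],
        [0.0, 0.0, 1.0],
    ])

rng = np.random.default_rng()
M = random_affine_matrix(
    sx=rng.uniform(0.8, 1.2), sy=rng.uniform(0.8, 1.2),          # illustrative ranges
    tx=rng.uniform(-20, 20), ty=rng.uniform(-20, 20),
    theta=rng.uniform(-np.pi / 12, np.pi / 12),
    hx=rng.uniform(-0.1, 0.1), hy=rng.uniform(-0.1, 0.1),
)
# M is applied both to each image (e.g., via a warp function) and to the
# corner points of the original labels so that the boxes stay aligned.
```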
Random affine transformation is adopted to simulate different viewing angles and illumination conditions: image translation, scaling, rotation and shearing simulate changes in target position, changes in target size at different distances, and occlusion of targets and target parts under different viewing angles. Transforming the original data set in this way expands it, which improves the detection of multi-view, occluded and multi-scale targets, improves the generalization capability of the target detection model, and reduces its overfitting.
In an alternative embodiment, the teacher model is a YOLOv5-l model, and 300 epochs are trained when training the teacher model based on the extended data set. The quality classifier is a positive and negative sample quality discriminator.
In one embodiment, the original loss of the student model includes a confidence loss, a category loss and a frame regression loss;
the confidence loss is
L_CE_obj = -α·log(β) - (1-α)·log(1-β)    (1);
where β represents the probability that the sample belongs to the foreground or the background, i.e. the confidence value of the bounding box, and α is a flag indicating whether the real label contains the target (1 indicates that the target is contained, 0 that it is not).
It should be noted that the samples in this embodiment refer to the data in the combined data set of the new pseudo tag data set and the original data set.
The confidence loss and the category loss are calculated using cross entropy. Entropy is a measure of the uncertainty of information and is widely used in the communications and information fields. For a random variable x with probability distribution p(x), the entropy f(x) is shown in formula (9):
f(x) = -∫ p(x)·log p(x) dx    (9);
Cross entropy measures the uncertainty of the predicted category and confidence: the larger the entropy value, the greater the uncertainty and the worse the prediction; the smaller the entropy value, the smaller the uncertainty and the more accurate the prediction.
For each object, the student model will output a class probability distribution representing the probability that the object belongs to each class. While for each target there is only one category of real tags. Thus, the cross entropy loss function can be used to evaluate the gap between the class of model predictions and the true tags.
The category loss is
L_CE_cls = -Σ_{i=1}^{nc} y_i·log(p_i)    (2);
where p(x) is the actual probability distribution, each element p_i represents the probability that the sample belongs to class i, y_i = 1 when the sample belongs to class i and 0 otherwise, and nc is the total number of sample categories.
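As a sketch, the two cross-entropy losses of formulas (1) and (2) can be written in PyTorch as below; the tensor shapes and the numerical-stability epsilon are assumptions of the example.

```python
import torch

def confidence_loss(beta: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Formula (1): beta is the predicted bounding-box confidence in (0, 1);
    alpha is 1 where the real label contains a target and 0 otherwise."""
    eps = 1e-7
    beta = beta.clamp(eps, 1 - eps)  # keep log() finite
    return -(alpha * torch.log(beta) + (1 - alpha) * torch.log(1 - beta)).mean()

def category_loss(p: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Formula (2): p is the predicted class distribution of shape (N, nc);
    y is the one-hot true label of the same shape."""
    return -(y * torch.log(p.clamp_min(1e-7))).sum(dim=-1).mean()
```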
The frame regression loss is
L_α-CIoU = 1 - IoU^α + ρ^{2α}(b, b^gt)/c^{2α} + (β·ν)^α    (3);
where c is the smallest rectangle containing both the predicted frame and the real frame (its diagonal length enters the distance term), b^gt is the real frame, b is the predicted frame, ρ(b^gt, b) is the Euclidean distance between the center points of the real and predicted frames, β is a parameter balancing the weights of the terms, ν is a parameter measuring the consistency of the aspect ratio between the two frames, and the IoU term and α represent an additional power regularization.
In target detection models, the most common loss functions for bounding box regression are the IoU series. The original IoU-based loss can be represented by formulas (10) and (11):
IoU = |b ∩ b^gt| / |b ∪ b^gt|    (10);
L_IoU = 1 - IoU    (11);
where b^gt is the real frame and b is the prediction frame. This embodiment uses the new power IoU loss family, with an IoU term and a parameter α representing the additional power regularization term, to measure the bounding box regression loss. α-IoU is a family of power IoU losses for bounding box regression: by adjusting α, the weight of the loss and gradient of high-IoU objects is adaptively increased to improve bounding box regression accuracy. The α-IoU loss extends to the more general form shown in formula (12):
L_α-IoU = 1 - IoU^{α1} + R^{α2}(b, b^gt)    (12);
where α1 = α2 = 3 is generally taken, and R(b, b^gt) represents any penalty term calculated from b and b^gt. The IoU above can be replaced by any of GIoU, DIoU or CIoU, with R(b, b^gt) the corresponding penalty term in its formula. The loss function adopted by this application for bounding box regression is the α-CIoU of formula (3); the parameters ν and β are calculated by formulas (13) and (14):
ν = (4/π²)·(arctan(w^gt/h^gt) - arctan(w/h))²    (13);
β = ν / ((1 - IoU) + ν)    (14);
where w^gt, h^gt and w, h are the widths and heights of the real and predicted frames respectively.
The original loss of the student model is
L_STU = λ1×L_CE_cls + λ2×L_CE_obj + λ3×L_α-CIoU    (4);
where λ1 is 0.3, λ2 is 0.4 and λ3 is 0.3.
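A sketch of the α-CIoU term of formula (3), with ν and β from formulas (13) and (14), follows; the (x1, y1, x2, y2) box format, α = 3, and the epsilon terms are assumptions of the example, and c² is taken as the squared diagonal of the smallest enclosing rectangle.

```python
import math
import torch

def alpha_ciou_loss(pred: torch.Tensor, target: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
    """Sketch of formula (3): 1 - IoU^a + (rho^2/c^2)^a + (beta*nu)^a for
    boxes given as (x1, y1, x2, y2); pred and target have shape (N, 4)."""
    eps = 1e-7
    # intersection and union -> IoU
    lt = torch.max(pred[..., :2], target[..., :2])
    rb = torch.min(pred[..., 2:], target[..., 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area_p = (pred[..., 2] - pred[..., 0]) * (pred[..., 3] - pred[..., 1])
    area_t = (target[..., 2] - target[..., 0]) * (target[..., 3] - target[..., 1])
    iou = inter / (area_p + area_t - inter + eps)
    # squared center distance rho^2 and squared enclosing-box diagonal c^2
    c_lt = torch.min(pred[..., :2], target[..., :2])
    c_rb = torch.max(pred[..., 2:], target[..., 2:])
    c2 = ((c_rb - c_lt) ** 2).sum(-1) + eps
    rho2 = (((pred[..., :2] + pred[..., 2:]) - (target[..., :2] + target[..., 2:])) ** 2).sum(-1) / 4
    # aspect-ratio consistency nu (formula 13) and trade-off weight beta (formula 14)
    w_p, h_p = pred[..., 2] - pred[..., 0], pred[..., 3] - pred[..., 1]
    w_t, h_t = target[..., 2] - target[..., 0], target[..., 3] - target[..., 1]
    nu = (4 / math.pi ** 2) * (torch.atan(w_t / (h_t + eps)) - torch.atan(w_p / (h_p + eps))) ** 2
    beta = nu / ((1 - iou) + nu + eps)
    return (1 - iou ** alpha + (rho2 / c2) ** alpha + (beta * nu) ** alpha).mean()
```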
In one embodiment, the knowledge distillation loss is
L_Distill = (1/(m·n))·Σ_{i=1}^{m} Σ_{j=1}^{n} (output_T(i,j) - output_S(i,j))²    (5);
where m and n are the numbers of rows and columns of the output tensor, and output_T and output_S are the output results of the teacher model and the student model respectively.
In one embodiment, the overall loss is
L_total = α1×L_STU + α2×L_Distill    (6);
where α1 is 0.8 and α2 is 0.2.
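Assuming formula (5) is read as the mean squared error over the m×n output tensor, the distillation and overall losses can be sketched as below.

```python
import torch
import torch.nn.functional as F

def distill_loss(output_t: torch.Tensor, output_s: torch.Tensor) -> torch.Tensor:
    """Formula (5), read as the mean squared error over the m x n output tensor."""
    return F.mse_loss(output_s, output_t.detach())  # detach: no gradient flows into the teacher

def overall_loss(l_stu: torch.Tensor, l_distill: torch.Tensor,
                 a1: float = 0.8, a2: float = 0.2) -> torch.Tensor:
    """Formula (6) with the weights alpha1 = 0.8 and alpha2 = 0.2 given in the text."""
    return a1 * l_stu + a2 * l_distill
```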
In one embodiment, adjusting the parameters of the student model based on the overall loss to obtain the trained student model means back-propagating the overall loss through the student model and updating the student model parameters; the resulting model is the trained student model.
Common object detectors all employ deep convolutional neural networks. In one embodiment, a depth separable convolution module is used in place of the feature extraction portion of the traditional convolutional neural network module in the student model. As shown in fig. 4, the embodiment of the application provides a schematic diagram of the student model structure after the feature extraction portion of the traditional convolutional neural network module is replaced with a depth separable convolution module. Applying depth separable convolution to the last three convolution layers of the feature extraction (backbone) network of the target detection model maintains model accuracy while reducing the number of model parameters and the amount of computation. The depth separable convolution module decomposes the convolution operation into two operations: a depth convolution and a point-by-point convolution. Depth separable convolution thereby accelerates the computation of the convolutional neural network and reduces the number of parameters and the amount of computation while maintaining model accuracy.
In the embodiment, a depth separable convolution module is adopted to replace the traditional convolution operation, a lightweight network feature extraction structure is designed, a target detection model structure is optimized, and the instantaneity of the model is further improved.
The depth convolution splits the standard convolution into a channel-wise convolution and a spatial convolution. Let the shape of the input feature map be [H, W, C]; in the depth convolution, each input channel has its own convolution kernel of size [k, k], where k is the kernel size and C is the number of channels of the input feature map. The computation of the depth convolution can be expressed as:
(1) For each channel of the input feature map, convolve that channel with its own [k, k] kernel to obtain a two-dimensional feature map of shape [H, W].
(2) Concatenate all the resulting two-dimensional feature maps along the channel dimension to obtain an output feature map of shape [H, W, C].
The point-by-point convolution then convolves the feature map obtained by the depth convolution with a kernel of size [1, 1, C, D'], where D' is the number of output channels. It is used to exchange information between channels and to fuse the low-level and high-level features in the feature map produced by the depth convolution. Its computation can be expressed as:
(1) For each position [i, j] in the feature map, use a kernel of size [1, 1, C] to weight and sum the C channels, obtaining a vector of length D'.
(2) Concatenate all the resulting vectors along the channel dimension to obtain an output feature map of shape [H, W, D'].
The overall computation of the depth separable convolution is therefore: first perform the depth convolution to obtain a feature map of shape [H, W, C], and then perform the point-by-point convolution to obtain the output feature map of shape [H, W, D'].
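For illustration, a minimal PyTorch module implementing this depthwise-plus-pointwise decomposition is sketched below; the BatchNorm and SiLU activation are assumptions in the style of YOLOv5 blocks, not layers specified by the application.

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv (one k x k filter per input channel, groups=c_in) followed
    by a 1 x 1 pointwise conv that mixes the channels, as described above."""
    def __init__(self, c_in: int, c_out: int, k: int = 3):
        super().__init__()
        self.depthwise = nn.Conv2d(c_in, c_in, k, padding=k // 2, groups=c_in, bias=False)
        self.pointwise = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)   # assumed, YOLOv5-style; not specified by the text
        self.act = nn.SiLU()              # assumed activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.bn(self.pointwise(self.depthwise(x))))

x = torch.randn(1, 64, 80, 80)            # [N, C, H, W]
y = DepthwiseSeparableConv(64, 128)(x)    # -> [1, 128, 80, 80]
```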
In one embodiment, training the student model based on the teacher model further includes, after the trained student model is obtained, transfer-training the trained student model using the original data.
Preferably, the application uses a test data set to evaluate the overall performance of the finally generated student model. The detection accuracy of the model is evaluated using indexes such as precision, recall and mAP@0.5. The frame rate (FPS) or the inference time (ms) is used to measure the inference speed of the model; a higher frame rate or a shorter inference time indicates better real-time performance.
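A rough way to measure the FPS and per-image inference time mentioned here is sketched below; the input resolution, warm-up count and run count are illustrative assumptions.

```python
import time
import torch

@torch.no_grad()
def measure_speed(model: torch.nn.Module, img_size=(1, 3, 640, 640),
                  n_warmup: int = 10, n_runs: int = 100):
    """Return (FPS, ms per image) for a model on random input of the given size."""
    model.eval()
    x = torch.randn(*img_size)
    for _ in range(n_warmup):          # warm-up runs are excluded from timing
        model(x)
    start = time.perf_counter()
    for _ in range(n_runs):
        model(x)
    ms = (time.perf_counter() - start) / n_runs * 1000.0
    return 1000.0 / ms, ms
```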
Based on the same inventive concept, the embodiment of the application also provides a model distillation real-time target detection device based on pseudo tag filtering, which is used for implementing the model distillation real-time target detection method based on pseudo tag filtering. The implementation of the solution provided by the device is similar to that described in the above method, so for the specific limitations in the embodiments of the model distillation real-time target detection device based on pseudo tag filtering provided below, reference can be made to the limitations of the method above, which will not be repeated here.
In one embodiment, the model distillation real-time target detection device based on pseudo tag filtering comprises a memory and a processor, wherein the memory stores a computer program and the processor, when executing the computer program, implements the steps of the model distillation real-time target detection method based on pseudo tag filtering.
It should be understood that, although the steps in the flowcharts related to the embodiments described above are sequentially shown as indicated by arrows, these steps are not necessarily sequentially performed in the order indicated by the arrows. The steps are not strictly limited to the order of execution unless explicitly recited herein, and the steps may be executed in other orders. Moreover, at least some of the steps in the flowcharts described in the above embodiments may include a plurality of steps or a plurality of stages, which are not necessarily performed at the same time, but may be performed at different times, and the order of the steps or stages is not necessarily performed sequentially, but may be performed alternately or alternately with at least some of the other steps or stages.
The foregoing examples illustrate only a few embodiments of the application and are described in detail herein without thereby limiting the scope of the application. It should be noted that it will be apparent to those skilled in the art that several variations and modifications can be made without departing from the spirit of the application, which are all within the scope of the application. Accordingly, the scope of the application should be assessed as that of the appended claims.

Claims (10)

1. The model distillation real-time target detection method based on pseudo tag filtering is characterized in that a student model is trained based on a teacher model, and then data to be detected is input into the trained student model to obtain a real-time target detection result, wherein the teacher model and the student model are both target detection models, the number of layers of the teacher model is more than that of the student model, and the depth of the teacher model is greater than that of the student model;
the training of the student model based on the teacher model specifically comprises the following steps:
acquiring an extended data set;
training the teacher model based on the extended data set, inputting the extended data set into the trained teacher model, and generating a pseudo tag data set;
inputting the pseudo tag data set into a quality classifier to generate a new pseudo tag data set;
inputting the new pseudo tag data set and the original data set into a student model and a trained teacher model respectively to obtain a detection result and a pre-training result;
calculating the original loss of the student model based on the detection result, and calculating the knowledge distillation loss based on the detection result and the pre-training result;
calculating the overall loss according to the original loss and the knowledge distillation loss of the student model;
and adjusting parameters of the student model based on the overall loss to obtain a trained student model.
2. The model distillation real-time target detection method based on pseudo tag filtering according to claim 1, wherein acquiring the extended data set comprises:
acquiring an original data set;
a random affine transformation is performed on the original dataset.
3. The model distillation real-time target detection method based on pseudo tag filtering according to claim 2, wherein the teacher model is a YOLOv5-l model, and training 300 epochs is performed when training the teacher model based on an extended data set;
the mass classifier is a positive and negative sample mass separator.
4. The model distillation real-time target detection method based on pseudo tag filtering according to claim 1, wherein the original loss of the student model comprises a confidence loss, a category loss and a frame regression loss;
the confidence loss is
L_CE_obj = -α·log(β) - (1-α)·log(1-β)    (1);
wherein β represents the probability that the sample belongs to the foreground or the background, i.e. the confidence value of the bounding box, and α is a flag indicating whether the real label contains the target (1 indicates that the target is contained, 0 that it is not);
the category loss is
L_CE_cls = -Σ_{i=1}^{nc} y_i·log(p_i)    (2);
wherein p(x) is the actual probability distribution, each element p_i represents the probability that the sample belongs to class i, y_i = 1 when the sample belongs to class i and 0 otherwise, and nc is the total number of sample categories;
the frame regression loss is
L_α-CIoU = 1 - IoU^α + ρ^{2α}(b, b^gt)/c^{2α} + (β·ν)^α    (3);
wherein c is the smallest rectangle containing both the predicted frame and the real frame, b^gt is the real frame, b is the predicted frame, ρ(b^gt, b) is the Euclidean distance between the center points of the real and predicted frames, β is a parameter balancing the weights of the terms, ν is a parameter measuring the consistency of the aspect ratio between the two frames, and the IoU term and α represent an additional power regularization;
the original loss of the student model is
L_STU = λ1×L_CE_cls + λ2×L_CE_obj + λ3×L_α-CIoU    (4);
wherein λ1 is 0.3, λ2 is 0.4 and λ3 is 0.3.
5. The model distillation real-time target detection method based on pseudo tag filtering according to claim 4, wherein the knowledge distillation loss is
L_Distill = (1/(m·n))·Σ_{i=1}^{m} Σ_{j=1}^{n} (output_T(i,j) - output_S(i,j))²    (5);
wherein m and n are the numbers of rows and columns of the output tensor, and output_T and output_S are the output results of the teacher model and the student model respectively.
6. The model distillation real-time target detection method based on pseudo tag filtering according to claim 5, wherein the overall loss is
L_total = α1×L_STU + α2×L_Distill    (6);
wherein α1 is 0.8 and α2 is 0.2.
7. The model distillation real-time target detection method based on pseudo tag filtering according to claim 6, wherein adjusting the parameters of the student model based on the overall loss to obtain a trained student model comprises back-propagating the overall loss through the student model and updating the student model parameters to obtain the trained student model.
8. The model distillation real-time target detection method based on pseudo tag filtering according to claim 3, wherein a depth separable convolution module is used to replace the feature extraction portion of the traditional convolutional neural network module in the student model.
9. The model distillation real-time target detection method based on pseudo tag filtering according to any one of claims 2 to 8, wherein training the student model based on the teacher model further comprises, after obtaining the trained student model, transfer-training the trained student model using the original data.
10. Model distillation real-time object detection device based on pseudo tag filtering, comprising a memory and a processor, the memory storing a computer program, characterized in that the processor implements the steps of the method according to any one of claims 1 to 9 when executing the computer program.
CN202310815686.9A 2023-07-04 2023-07-04 Model distillation real-time target detection method and device based on pseudo tag filtering Pending CN117011640A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310815686.9A CN117011640A (en) 2023-07-04 2023-07-04 Model distillation real-time target detection method and device based on pseudo tag filtering

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310815686.9A CN117011640A (en) 2023-07-04 2023-07-04 Model distillation real-time target detection method and device based on pseudo tag filtering

Publications (1)

Publication Number Publication Date
CN117011640A true CN117011640A (en) 2023-11-07

Family

ID=88566434

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310815686.9A Pending CN117011640A (en) 2023-07-04 2023-07-04 Model distillation real-time target detection method and device based on pseudo tag filtering

Country Status (1)

Country Link
CN (1) CN117011640A (en)


Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117372819A (en) * 2023-12-07 2024-01-09 神思电子技术股份有限公司 Target detection increment learning method, device and medium for limited model space
CN117372819B (en) * 2023-12-07 2024-02-20 神思电子技术股份有限公司 Target detection increment learning method, device and medium for limited model space


Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination