CN116524351A - Lightweight method and system for rotating target detection based on knowledge distillation - Google Patents
Lightweight method and system for rotating target detection based on knowledge distillation
- Publication number
- CN116524351A CN116524351A CN202310299030.6A CN202310299030A CN116524351A CN 116524351 A CN116524351 A CN 116524351A CN 202310299030 A CN202310299030 A CN 202310299030A CN 116524351 A CN116524351 A CN 116524351A
- Authority
- CN
- China
- Prior art keywords
- divergence
- detection target
- gaussian distribution
- model
- representation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 168
- 238000000034 method Methods 0.000 title claims abstract description 111
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 44
- 238000009826 distribution Methods 0.000 claims abstract description 105
- 230000008569 process Effects 0.000 claims abstract description 17
- 238000010606 normalization Methods 0.000 claims abstract description 10
- 238000004364 calculation method Methods 0.000 claims description 27
- 238000012216 screening Methods 0.000 claims description 18
- 238000012545 processing Methods 0.000 claims description 10
- 238000006243 chemical reaction Methods 0.000 claims description 7
- 239000011159 matrix material Substances 0.000 claims description 6
- 238000012805 post-processing Methods 0.000 claims description 5
- 230000001174 ascending effect Effects 0.000 claims description 3
- 238000004821 distillation Methods 0.000 abstract description 40
- 238000012549 training Methods 0.000 abstract description 14
- 230000003287 optical effect Effects 0.000 abstract description 6
- 230000006870 function Effects 0.000 description 14
- 238000004590 computer program Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 7
- 238000002474 experimental method Methods 0.000 description 6
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000000605 extraction Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 230000000052 comparative effect Effects 0.000 description 2
- 230000007547 defect Effects 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000001629 suppression Effects 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013434 data augmentation Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000003384 imaging method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000010365 information processing Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000737 periodic effect Effects 0.000 description 1
- 238000013139 quantization Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/10—Terrestrial scenes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/766—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using regression, e.g. by projecting features on hyperplanes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract

The invention discloses a lightweight method and system for rotating target detection based on knowledge distillation. Knowledge is distilled from the coordinate-encoding outputs of the teacher and student models: the KL divergence between the Gaussian distribution representations of the rotated coordinates is computed and added as a distillation loss for training the student model parameters, providing more accurate localization information for training the lightweight network and improving its detection performance on optical remote sensing images.
Description
Technical Field

The present invention relates to the field of high-resolution remote sensing image information processing, and in particular to a lightweight method and system for rotating target detection based on knowledge distillation.
Background Art

Object detection in remote sensing images aims to locate objects of interest on the Earth's surface (such as vehicles, aircraft, and ships) and predict their categories. Remote sensing images are mostly captured from an overhead viewpoint; the spatial scenes they contain are larger and more complex, and the objects of interest are unevenly distributed, appearing in both sparse and dense configurations. Remote sensing object detection therefore faces small, densely packed targets with arbitrary orientations. When targets have large aspect ratios, are tilted, and are tightly arranged, a horizontal bounding box cannot faithfully reflect the target's size and aspect ratio: target and background pixels cannot be separated effectively, densely packed targets are hard to tell apart, each box tends to enclose neighboring targets, and the high intersection-over-union between neighboring horizontal boxes causes many of them to be removed during non-maximum suppression, leading to low final detection accuracy. Moreover, because the target orientation (the rotation angle) is periodic, a suitable loss function design is crucial for model training and optimization.
Traditional object detection methods often require complex hand-crafted feature extraction and fine parameter tuning; such models generalize poorly and cannot adapt to changing environments. In recent years, with increasing hardware computing power, advances in remote sensing imaging technology, and easier access to remote sensing imagery, deep learning has developed rapidly. Convolutional neural network models in particular generalize well thanks to their robust feature extraction, strong function-fitting capacity, and end-to-end network design. Building on convolutional neural networks, most state-of-the-art detectors focus on designing advanced network structures or loss functions to improve detection performance, but at high computational cost; consequently, a line of lightweight methods distills knowledge from a large (teacher) model to improve the performance of a small model. For example, knowledge distillation based on intermediate-layer features has demonstrated the importance of such features during training and inference, but distilling intermediate features alone discards information from the coordinate and category spaces, which harms the final detection results. Knowledge distillation based on generalized distribution information uses a classification loss to solve regression problems, including the rotated target position and rotation angle, and thereby addresses localization uncertainty, but it suffers from quantization error. Converting rotated coordinates into a Gaussian distribution representation has recently seen wide use in rotated object detection, but it has been little studied for detector lightweighting. There is therefore an urgent need for a knowledge-distillation-based lightweight method for rotating target detection that overcomes the above technical defects of the prior art.
Summary of the Invention

To this end, the technical problem to be solved by the present invention is to overcome the defects of the prior art by proposing a lightweight method and system for rotating target detection based on knowledge distillation. The method distills knowledge from the coordinate-encoding outputs of the teacher and student models: the KL divergence between the Gaussian distribution representations of the rotated coordinates is computed and added as a distillation loss for training the student model parameters, providing more accurate localization information for training the lightweight network and improving its detection performance on optical remote sensing images.
To solve the above technical problem, the present invention provides a lightweight method for rotating target detection based on knowledge distillation, comprising the following steps:
S1: decode the coordinate encodings of the detection targets output by the student model and the trained teacher model into rotated-coordinate form, obtaining the rotated-coordinate representation of each detection target;

S2: convert the rotated-coordinate representation into a two-dimensional Gaussian distribution, obtaining the Gaussian distribution representation of the detection target;

S3: compute the KL divergence between the Gaussian distribution representations of the detection targets output by the student and teacher models, obtaining a first KL divergence, and compute the KL divergence between the Gaussian distribution representation output by the student model and the one obtained from the ground-truth box, obtaining a second KL divergence;

S4: normalize the first KL divergence and the second KL divergence;

S5: from the normalized first and second KL divergences, compute the overall loss function used while the teacher model distills knowledge into the student model;

S6: use the distilled student model for prediction.
In one embodiment of the present invention, in step S1, decoding the coordinate encodings of the detection targets output by the student model and the trained teacher model into rotated-coordinate form includes:
decoding the coordinate encoding Y: (dx, dy, dw, dh, dθ) of the detection target output by the student model and the trained teacher model into rotated-coordinate form (x, y, w, h, θ), where:

x = x_a + dx · w_a
y = y_a + dy · h_a
w = w_a · e^{dw}
h = h_a · e^{dh}
θ = θ_a + dθ

in which A: (x_a, y_a, w_a, h_a, θ_a) is the anchor box preset by the model.
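As an illustration, this decoding step can be sketched in a few lines of NumPy. This is a minimal sketch for exposition only: the function name and array layout are assumptions, not the patent's reference implementation.

```python
import numpy as np

def decode_rboxes(deltas, anchors):
    """Decode offsets (dx, dy, dw, dh, dtheta) against rotated anchors
    (xa, ya, wa, ha, theta_a) into rotated boxes (x, y, w, h, theta).
    Both inputs are arrays of shape (N, 5)."""
    dx, dy, dw, dh, dt = np.split(deltas, 5, axis=1)
    xa, ya, wa, ha, ta = np.split(anchors, 5, axis=1)
    x = xa + dx * wa        # center shift, scaled by anchor width
    y = ya + dy * ha        # center shift, scaled by anchor height
    w = wa * np.exp(dw)     # width offset predicted in log space
    h = ha * np.exp(dh)     # height offset predicted in log space
    theta = ta + dt         # additive angle offset
    return np.concatenate([x, y, w, h, theta], axis=1)
```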
In one embodiment of the present invention, in step S2, converting the rotated-coordinate representation into a two-dimensional Gaussian distribution includes:
converting the rotated-coordinate representation (x, y, w, h, θ) into a two-dimensional Gaussian distribution (μ, Σ) as:

μ = (x, y)^T
Σ = R Λ R^T, with R = [[cos θ, -sin θ], [sin θ, cos θ]] and Λ = diag(w²/4, h²/4)

where (x, y, w, h, θ) denotes the rotated-coordinate representation of the detection target, namely the horizontal and vertical coordinates of its center point and its width, height, and rotation angle, and (μ, Σ) denote the mean and covariance matrix of the Gaussian distribution.
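As a concrete illustration, this conversion can be written as the following minimal NumPy sketch (the function name is assumed for exposition):

```python
import numpy as np

def rbox_to_gaussian(x, y, w, h, theta):
    """Convert a rotated box (x, y, w, h, theta) into a 2-D Gaussian
    (mu, sigma): mu is the box center, sigma = R diag(w^2/4, h^2/4) R^T."""
    mu = np.array([x, y])
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, -s], [s, c]])              # 2-D rotation matrix
    Lam = np.diag([w ** 2 / 4.0, h ** 2 / 4.0])  # squared half-extents
    sigma = R @ Lam @ R.T                        # covariance of the box
    return mu, sigma
```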
In one embodiment of the present invention, in step S3, calculating the KL divergence between the Gaussian distribution representations of the detection targets output by the student model and the teacher model includes:

computing the KL divergence between the two Gaussian distribution representations as:

D_KL(N_S || N_T) = 1/2 (μ_S - μ_T)^T Σ_T^{-1} (μ_S - μ_T) + 1/2 tr(Σ_T^{-1} Σ_S) + 1/2 ln(|Σ_T| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_T, y_T, w_T, h_T, θ_T), (μ_T, Σ_T) denote the rotated-coordinate and Gaussian distribution representations of the detection targets output by the student model and the teacher model respectively, and Δx = x_S - x_T, Δy = y_S - y_T, Δθ = θ_S - θ_T denote the differences between the center-point horizontal coordinates, center-point vertical coordinates, and rotation angles of the detection targets output by the student and teacher models.
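For reference, the closed-form KL divergence between two 2-D Gaussians can be evaluated as in the sketch below; it implements the standard multivariate-Gaussian formula stated above, with illustrative names. The same routine serves both distillation terms: the first KL divergence uses the (student, teacher) pair and the second the (student, ground truth) pair.

```python
import numpy as np

def kl_gaussian_2d(mu_s, sigma_s, mu_t, sigma_t):
    """KL(N_S || N_T) = 1/2 (mu_s-mu_t)^T St^-1 (mu_s-mu_t)
       + 1/2 tr(St^-1 Ss) + 1/2 ln(|St|/|Ss|) - 1   (dimension k = 2)."""
    inv_t = np.linalg.inv(sigma_t)
    d = mu_s - mu_t
    quad = 0.5 * d @ inv_t @ d                     # location mismatch
    tr = 0.5 * np.trace(inv_t @ sigma_s)           # scale/shape mismatch
    logdet = 0.5 * np.log(np.linalg.det(sigma_t) / np.linalg.det(sigma_s))
    return quad + tr + logdet - 1.0                # -k/2 with k = 2
```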
In one embodiment of the present invention, in step S3, calculating the KL divergence between the Gaussian distribution representation of the detection target output by the student model and the one obtained from the ground-truth box includes:

computing this KL divergence as:

D_KL(N_S || N_G) = 1/2 (μ_S - μ_G)^T Σ_G^{-1} (μ_S - μ_G) + 1/2 tr(Σ_G^{-1} Σ_S) + 1/2 ln(|Σ_G| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_G, y_G, w_G, h_G, θ_G), (μ_G, Σ_G) denote the rotated-coordinate and Gaussian distribution representations of the detection target output by the student model and of the one obtained from the ground-truth box respectively, and Δx = x_S - x_G, Δy = y_S - y_G, Δθ = θ_S - θ_G denote the corresponding differences between their center-point horizontal coordinates, center-point vertical coordinates, and rotation angles.
In one embodiment of the present invention, in step S6, after prediction with the distilled student model, the prediction boxes output by the student model are post-processed.

In one embodiment of the present invention, in step S6, post-processing the prediction boxes output by the student model includes the following steps (a code sketch follows the list):
sort all prediction boxes of each category by score, in ascending or descending order;

take the highest-scoring prediction box as the reference and compute the intersection-over-union of every other prediction box with it;

screen the prediction boxes of each category by this intersection-over-union, deleting the boxes that fail the screening condition, where the screening condition is that the intersection-over-union between the box under screening and the reference box is less than or equal to a preset threshold;

among the remaining boxes that pass the screening, take the highest-scoring one as the new reference box and repeat the above screening steps until every box that passes the screening has served as a reference.
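A simplified sketch of this screening procedure (rotated non-maximum suppression) is given below. The rotated-IoU helper built on the Shapely library is an illustrative assumption, not the patent's implementation, and the threshold value is a stand-in for the preset threshold.

```python
import numpy as np
from shapely.geometry import Polygon

def rbox_corners(box):
    """Corner points of a rotated box (x, y, w, h, theta)."""
    x, y, w, h, t = box
    c, s = np.cos(t), np.sin(t)
    R = np.array([[c, -s], [s, c]])
    pts = np.array([[-w, -h], [w, -h], [w, h], [-w, h]]) / 2.0
    return pts @ R.T + np.array([x, y])

def rotated_iou(a, b):
    pa, pb = Polygon(rbox_corners(a)), Polygon(rbox_corners(b))
    inter = pa.intersection(pb).area
    return inter / (pa.area + pb.area - inter + 1e-9)

def rotated_nms(boxes, scores, iou_thr=0.1):
    """Keep the highest-scoring box, drop boxes whose IoU with it exceeds
    the threshold, then repeat on the remainder (applied per category)."""
    order = np.argsort(scores)[::-1]       # highest score first
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        ious = np.array([rotated_iou(boxes[best], boxes[i]) for i in rest])
        order = rest[ious <= iou_thr]      # boxes that pass the screening
    return keep
```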
In addition, the present invention further provides a lightweight system for rotating target detection based on knowledge distillation, comprising the following modules (an illustrative wiring sketch follows the list):
a coordinate decoding module, configured to decode the coordinate encodings of the detection targets output by the student model and the trained teacher model into rotated-coordinate form, obtaining the rotated-coordinate representation of each detection target;

a coordinate conversion module, configured to convert the rotated-coordinate representation into a two-dimensional Gaussian distribution, obtaining the Gaussian distribution representation of the detection target;

a KL divergence calculation module, configured to compute the KL divergence between the Gaussian distribution representations of the detection targets output by the student and teacher models, obtaining a first KL divergence, and to compute the KL divergence between the Gaussian distribution representation output by the student model and the one obtained from the ground-truth box, obtaining a second KL divergence;

a normalization processing module, configured to normalize the first KL divergence and the second KL divergence;

a loss function calculation module, configured to compute, from the normalized first and second KL divergences, the overall loss function used while the teacher model distills knowledge into the student model;

a prediction module, configured to perform prediction with the distilled student model.
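To make the module decomposition concrete, a minimal sketch of how such a system could be wired is given below; every class and attribute name here is hypothetical, chosen only to mirror the modules listed above.

```python
class RotatedDistillationSystem:
    """Illustrative wiring of the modules described above."""

    def __init__(self, decode, to_gaussian, kl_div, normalize, combine):
        self.decode = decode            # coordinate decoding module
        self.to_gaussian = to_gaussian  # coordinate conversion module
        self.kl_div = kl_div            # KL divergence calculation module
        self.normalize = normalize      # normalization processing module
        self.combine = combine          # loss function calculation module

    def distillation_loss(self, enc_student, enc_teacher, anchors, gt_boxes):
        box_s = self.decode(enc_student, anchors)
        box_t = self.decode(enc_teacher, anchors)
        g_s, g_t = self.to_gaussian(box_s), self.to_gaussian(box_t)
        g_g = self.to_gaussian(gt_boxes)
        kl_st = self.normalize(self.kl_div(g_s, g_t))  # first KL divergence
        kl_sg = self.normalize(self.kl_div(g_s, g_g))  # second KL divergence
        return self.combine(kl_st, kl_sg)
```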
In one embodiment of the present invention, in the coordinate conversion module, converting the rotated-coordinate representation into a two-dimensional Gaussian distribution includes:

converting the rotated-coordinate representation (x, y, w, h, θ) into a two-dimensional Gaussian distribution (μ, Σ) as:

μ = (x, y)^T
Σ = R Λ R^T, with R = [[cos θ, -sin θ], [sin θ, cos θ]] and Λ = diag(w²/4, h²/4)

where (x, y, w, h, θ) denotes the rotated-coordinate representation of the detection target, namely the horizontal and vertical coordinates of its center point and its width, height, and rotation angle, and (μ, Σ) denote the mean and covariance matrix of the Gaussian distribution.
In one embodiment of the present invention, in the KL divergence calculation module, calculating the KL divergence between the Gaussian distribution representations of the detection targets output by the student model and the teacher model includes:

computing the KL divergence between the two Gaussian distribution representations as:

D_KL(N_S || N_T) = 1/2 (μ_S - μ_T)^T Σ_T^{-1} (μ_S - μ_T) + 1/2 tr(Σ_T^{-1} Σ_S) + 1/2 ln(|Σ_T| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_T, y_T, w_T, h_T, θ_T), (μ_T, Σ_T) denote the rotated-coordinate and Gaussian distribution representations of the detection targets output by the student model and the teacher model respectively, and Δx = x_S - x_T, Δy = y_S - y_T, Δθ = θ_S - θ_T denote the differences between the center-point horizontal coordinates, center-point vertical coordinates, and rotation angles of the detection targets output by the student and teacher models.
Compared with the prior art, the above technical solution of the present invention has the following advantages:

The lightweight method and system for rotating target detection based on knowledge distillation according to the present invention distill knowledge from the coordinate-encoding outputs of the teacher and student models: the KL divergence between the Gaussian distribution representations of the rotated coordinates is computed and added as a distillation loss for training the student model parameters, providing more accurate localization information for training the lightweight network and improving its detection performance on optical remote sensing images.
Brief Description of the Drawings

To make the content of the present invention easier to understand clearly, the present invention is described in further detail below with reference to specific embodiments and the accompanying drawings, in which:
FIG. 1 shows the training process of the proposed knowledge-distillation-based lightweight method for rotating target detection.

FIG. 2 shows the prediction flow of the proposed knowledge-distillation-based lightweight method for rotating target detection.
FIG. 3 shows the detection accuracy of the proposed distillation method and of other distillation methods over the final 20 training epochs on the HRSC2016 dataset with a ResNet18-FPN-RetinaNet detector, where w/o-d is the original network without distillation; FitNets is a distillation method based on all intermediate-layer features; DeFeat is a distillation method based on decoupled features; LD is a distillation method based on generalized distribution information between coordinates; and KLDD is the proposed distillation method based on the KL divergence between coordinate Gaussian distributions.

FIG. 4 shows the detection accuracy of the proposed distillation method and of other distillation methods over the final 20 training epochs on the HRSC2016 dataset with a MobileNetv2-FPN-RetinaNet detector, where w/o-d is the original network without distillation; FitNets is a distillation method based on all intermediate-layer features; DeFeat is a distillation method based on decoupled features; LD is a distillation method based on generalized distribution information between coordinates; and KLDD is the proposed distillation method based on the KL divergence between coordinate Gaussian distributions.
FIG. 5 compares the detection results of different distillation methods on the HRSC2016 dataset: panel (5a) shows the original dataset annotations; panel (5b) the results of the original ResNet18-FPN-RetinaNet detector without any distillation; panel (5c) the results of the FitNets method based on all intermediate-layer features; panel (5d) the results of the DeFeat method based on decoupled features of the target region; panel (5e) the results of the LD method based on generalized distribution information between target position coordinates; and panel (5f) the results of the proposed KLDD method based on the KL divergence between coordinate Gaussian distributions. Cyan boxes mark correctly detected targets (TP), yellow boxes false detections (FP), and red boxes missed targets (FN).
Detailed Description

The present invention is further described below with reference to the accompanying drawings and specific embodiments, so that those skilled in the art can better understand and implement it; the embodiments given are not intended to limit the present invention.
An embodiment of the present invention provides a lightweight method for rotating target detection based on knowledge distillation, comprising the following steps:
S1: decode the coordinate encodings of the detection targets output by the student model and the teacher model into rotated-coordinate form, obtaining the rotated-coordinate representation of each detection target;

S2: convert the rotated-coordinate representation into a two-dimensional Gaussian distribution, obtaining the Gaussian distribution representation of the detection target;

S3: compute the KL divergence between the Gaussian distribution representations of the detection targets output by the student and teacher models, obtaining a first KL divergence, and compute the KL divergence between the Gaussian distribution representation output by the student model and the one obtained from the ground-truth box, obtaining a second KL divergence;

S4: normalize the first KL divergence and the second KL divergence;

S5: from the normalized first and second KL divergences, compute the overall loss function used while the teacher model distills knowledge into the student model;

S6: use the distilled student model for prediction.
The knowledge-distillation-based lightweight method for rotating target detection according to the present invention distills knowledge from the coordinate-encoding outputs of the teacher and student models: the KL divergence between the Gaussian distribution representations of the rotated coordinates is computed and added as a distillation loss for training the student model parameters, providing more accurate localization information for training the lightweight network and improving its detection performance on optical remote sensing images.
Referring to FIG. 1 and FIG. 2, the above lightweight method for rotating target detection based on knowledge distillation mainly comprises the following three parts:
(1) Training a deep convolutional neural network with strong detection accuracy as the teacher model:

This part generates a teacher model that provides additional supervision for training the student model. In this experiment, a RetinaNet detector with a ResNet50-FPN feature-extraction backbone is chosen as the teacher model.
(2) The KLDD distillation method based on the KL divergence between coordinate Gaussian distributions:

This part generates the two-dimensional Gaussian distribution representation of the rotated coordinates and computes the KL divergence between Gaussian distributions. The specific steps are coordinate decoding, Gaussian distribution conversion, and computation of the KL-divergence loss between Gaussian distributions.
The coordinate decoding is computed as:

x = x_a + dx · w_a
y = y_a + dy · h_a
w = w_a · e^{dw}
h = h_a · e^{dh}
θ = θ_a + dθ

yielding the decoded rotated-coordinate representation (x, y, w, h, θ), where Y: (dx, dy, dw, dh, dθ) is the coordinate encoding output by the model and A: (x_a, y_a, w_a, h_a, θ_a) is the anchor box preset by the model.
The rotated-coordinate representation (x, y, w, h, θ) is converted into a two-dimensional Gaussian distribution (μ, Σ) as:

μ = (x, y)^T
Σ = R Λ R^T, with R = [[cos θ, -sin θ], [sin θ, cos θ]] and Λ = diag(w²/4, h²/4)

where (x, y, w, h, θ) denotes the rotated-coordinate representation of the detection target, namely the horizontal and vertical coordinates of its center point and its width, height, and rotation angle, and (μ, Σ) denote the mean and covariance matrix of the Gaussian distribution.
The KL divergence between the Gaussian distribution representations of the detection targets output by the student model S and the teacher model T is computed as:

D_KL(N_S || N_T) = 1/2 (μ_S - μ_T)^T Σ_T^{-1} (μ_S - μ_T) + 1/2 tr(Σ_T^{-1} Σ_S) + 1/2 ln(|Σ_T| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_T, y_T, w_T, h_T, θ_T), (μ_T, Σ_T) denote the rotated-coordinate and Gaussian distribution representations of the detection targets output by the student model and the teacher model respectively, and Δx = x_S - x_T, Δy = y_S - y_T, Δθ = θ_S - θ_T denote the differences between the center-point horizontal coordinates, center-point vertical coordinates, and rotation angles of the detection targets output by the student and teacher models.
The KL divergence between the Gaussian distribution representation of the detection target output by the student model S and that of the detection target from the ground-truth box G is computed as:

D_KL(N_S || N_G) = 1/2 (μ_S - μ_G)^T Σ_G^{-1} (μ_S - μ_G) + 1/2 tr(Σ_G^{-1} Σ_S) + 1/2 ln(|Σ_G| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_G, y_G, w_G, h_G, θ_G), (μ_G, Σ_G) denote the rotated-coordinate and Gaussian distribution representations of the detection target output by the student model and of the one obtained from the ground-truth box respectively, and Δx = x_S - x_G, Δy = y_S - y_G, Δθ = θ_S - θ_G denote the corresponding differences between their center-point horizontal coordinates, center-point vertical coordinates, and rotation angles.
The KL divergence between the Gaussian distributions is then normalized by a function F(KL), yielding the final KL-divergence distillation loss between the Gaussian representations of the coordinate parameters.
During training, the overall loss function of the model combines the detection and distillation terms, where λ0, λ1, λ2 are weight parameters set to 1, 15, and 5.5 respectively; a positive-sample region mask, determined by the intersection-over-union between the anchor boxes and the ground-truth boxes, restricts the distillation terms to positive regions; C_S, C_T, C_G denote the class confidences output by the student and teacher models and the one-hot encoding of the ground-truth class; the rotated coordinates output by the student and teacher models and the ground-truth label coordinates enter the regression terms; the classification terms comprise a cross-entropy loss and a KL-divergence loss between class distributions; and the regression terms are the normalized KL divergences between the student and ground-truth Gaussian distributions and between the student and teacher Gaussian distributions.
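The text above does not reproduce F(KL) or the combined loss in closed form, so the sketch below makes two explicit assumptions: a normalization of the form F(D) = 1 - 1/(τ + ln(D + 1)), as commonly paired with Gaussian-based rotated-box losses, and a simple weighted sum with λ1 attached to the student-ground-truth term and λ2 to the student-teacher term. Both choices are illustrative only.

```python
import numpy as np

TAU = 1.0  # assumed normalization temperature; not specified in the text

def normalize_kl(kl):
    """Assumed normalization F(KL): maps [0, inf) to a bounded loss."""
    return 1.0 - 1.0 / (TAU + np.log1p(kl))

def total_loss(ce_cls, kl_cls, kl_sg, kl_st, mask,
               lam0=1.0, lam1=15.0, lam2=5.5):
    """Illustrative weighted sum of detection and distillation terms.

    ce_cls: cross-entropy between student scores and ground-truth labels
    kl_cls: KL divergence between student and teacher class distributions
    kl_sg, kl_st: per-box Gaussian KL terms (student-GT, student-teacher)
    mask: positive-sample region mask from the anchor/ground-truth IoU
    """
    n_pos = max(float(mask.sum()), 1.0)
    reg_sg = float((mask * normalize_kl(kl_sg)).sum()) / n_pos
    reg_st = float((mask * normalize_kl(kl_st)).sum()) / n_pos
    return lam0 * (ce_cls + kl_cls) + lam1 * reg_sg + lam2 * reg_st
```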
(3) Prediction stage: only the distilled student model is used for prediction.

This part post-processes the prediction boxes output by the network with non-maximum suppression, which avoids detecting the same target repeatedly. The specific post-processing steps are as follows:
sort all prediction boxes of each category by score, in ascending or descending order;

take the highest-scoring prediction box as the reference and compute the intersection-over-union of every other prediction box with it;

screen the prediction boxes of each category by this intersection-over-union, deleting the boxes that fail the screening condition, where the screening condition is that the intersection-over-union between the box under screening and the reference box is less than or equal to a preset threshold;

among the remaining boxes that pass the screening, take the highest-scoring one as the new reference box and repeat the above screening steps until every box that passes the screening has served as a reference.
The beneficial effects of the proposed knowledge-distillation-based lightweight method for rotating target detection are demonstrated below by experimental verification.

Example:

1. Experimental data
To verify the effectiveness of the proposed detection-model lightweighting method, two optical remote sensing rotated object detection datasets were used in the experiments: DOTA and HRSC2016.

1) DOTA dataset:
DOTA is one of the largest rotated object detection datasets for aerial images. It contains 2,806 aerial images, with sizes ranging from 800×800 to 4,000×4,000 pixels, and more than 188,000 objects of various scales, orientations, and shapes, annotated with arbitrary quadrilaterals. It covers 15 object categories: plane, baseball diamond, bridge, ground track field, small vehicle, large vehicle, ship, tennis court, basketball court, storage tank, soccer field, roundabout, harbor, swimming pool, and helicopter. The training, validation, and test sets are split in proportions of 1/2, 1/6, and 1/3. Because the images are large, they are cropped here into 1024×1024-pixel sub-images with a 200-pixel overlap between neighboring sub-images; a tiling sketch follows.
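Since this tiling step is easy to get wrong, a minimal sketch is given below; the function name is assumed, and boundary handling (padding or shifting the final tile to a full 1024×1024 window) is deliberately simplified.

```python
def tile_image(img, tile=1024, overlap=200):
    """Crop a large image (H x W x C array) into tile x tile sub-images
    with the given overlap. Returns ((x0, y0), sub_image) pairs; tiles at
    the right/bottom edges may be smaller than tile x tile."""
    step = tile - overlap                  # 824-pixel stride
    h, w = img.shape[:2]
    tiles = []
    for y0 in range(0, max(h - overlap, 1), step):
        for x0 in range(0, max(w - overlap, 1), step):
            sub = img[y0:y0 + tile, x0:x0 + tile]
            tiles.append(((x0, y0), sub))
    return tiles
```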
2) HRSC2016 dataset:

HRSC2016 is a single-category ship detection dataset whose images come from six well-known harbors and cover two scenes, ships at sea and ships near the coast, with spatial resolutions of 0.4-2 m. The dataset contains 1,070 images with 2,976 ship targets in total; the training, validation, and test sets contain 436, 181, and 453 images respectively. Because the image sizes vary, the images are proportionally rescaled here to 800×512 pixels. Panel (5a) of FIG. 5 shows some images from the dataset together with their box annotations.
2. Experimental results

The experiments were built on the MMRotate framework and run on a single NVIDIA RTX 3090 GPU. A RetinaNet network with a ResNet50-FPN backbone was chosen as the teacher model, and RetinaNet networks with ResNet18-FPN and MobileNetV2-FPN backbones as the student models. During training, random horizontal, vertical, and diagonal flipping were used for data augmentation, and a standard momentum optimizer was used with weight decay 0.0001, momentum 0.9, and an initial learning rate of 0.0025; training ran for 24 epochs in total on DOTA and 72 epochs on HRSC2016, with a batch size of 2 in both cases (a configuration fragment follows). At test time, mAP50 (the average precision with an intersection-over-union threshold of 0.5 for counting true positives) was used to evaluate detection accuracy.
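For orientation, the stated optimizer and schedule correspond roughly to the following MMRotate-style configuration fragment. The field names follow the mmdet/mmrotate config convention; any detail beyond the settings stated above (for example, the worker count) is an assumption.

```python
# Fragment of an MMRotate-style config reflecting the stated settings.
optimizer = dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
runner = dict(type='EpochBasedRunner', max_epochs=24)  # 72 for HRSC2016
data = dict(samples_per_gpu=2, workers_per_gpu=2)      # batch size 2
```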
1) DOTA dataset:

Tables 1 and 2 list the experimental results obtained with different comparative distillation methods, together with the baseline teacher model ResNet50-FPN and the baseline student models ResNet18-FPN and MobileNetv2-FPN. The results show that the proposed method (KLDD) raises the overall accuracy of the baseline student models (ResNet18-FPN and MobileNetv2-FPN) from 65.2% and 56.8% to 68.1% and 61.0% respectively, relative improvements of 4.45% and 7.39%, and outperforms the detection results of current mainstream distillation methods (FitNets, DeFeat, GIImitation, LD, and others).

Table 3 compares the accuracy (mAP50), floating-point operations (FLOPS), and number of parameters (Params) of the two lightweight models used in this experiment, ResNet18-FPN-RetinaNet and MobileNetv2-FPN-RetinaNet, and of the teacher model ResNet50-FPN-RetinaNet, with the network input size set to 3×1024×1024.
Table 1: Experimental results of different comparative distillation methods on the DOTA dataset

Table 2: Experimental results of different comparative distillation methods on the DOTA dataset

Table 3: Experimental results before and after applying the KLDD distillation method on the DOTA dataset
2) HRSC2016 dataset:

Tables 4 and 5 list the experimental results obtained with different comparative distillation methods, together with the baseline teacher model ResNet50-FPN and the baseline student models ResNet18-FPN and MobileNetv2-FPN. The results show that the proposed method (KLDD) raises the overall accuracy of the baseline student models (ResNet18-FPN and MobileNetv2-FPN) from 86.0% and 79.2% to 89.2% and 80.7% respectively, relative improvements of 3.72% and 1.89%, and outperforms the detection results of current mainstream distillation methods (FitNets, DeFeat, GIImitation, LD, and others).

Table 6 compares the accuracy (mAP50), floating-point operations (FLOPS), and number of parameters (Params) of the two lightweight models used in this experiment, ResNet18-FPN-RetinaNet and MobileNetv2-FPN-RetinaNet, and of the teacher model ResNet50-FPN-RetinaNet, with the network input size set to 3×1024×1024.
Table 4: Experimental results of different comparative distillation methods on the HRSC2016 dataset

Table 5: Experimental results of different comparative distillation methods on the HRSC2016 dataset

Table 6: Experimental results before and after applying the KLDD distillation method on the HRSC2016 dataset
FIG. 5 shows intuitively that the proposed method outperforms the other distillation methods. On the first image, although the proposed method also produces a false alarm, it is still better than the other distillation methods overall. On the second image, the proposed method correctly detects all ship targets, whereas the DeFeat method based on decoupled features and the LD method based on generalized distribution information between coordinates produce false alarms and missed detections. On the third image, the proposed method again correctly detects all ship targets, while the other distillation methods generally produce false alarms and missed detections.
A lightweight system for rotating target detection based on knowledge distillation disclosed in an embodiment of the present invention is introduced below; the system described below and the method described above may be referred to in correspondence with each other.

The present invention further provides a lightweight system for rotating target detection based on knowledge distillation, comprising:
a coordinate decoding module, configured to decode the coordinate encodings of the detection targets output by the student model and the trained teacher model into rotated-coordinate form, obtaining the rotated-coordinate representation of each detection target;

a coordinate conversion module, configured to convert the rotated-coordinate representation into a two-dimensional Gaussian distribution, obtaining the Gaussian distribution representation of the detection target;

a KL divergence calculation module, configured to compute the KL divergence between the Gaussian distribution representations of the detection targets output by the student and teacher models, obtaining a first KL divergence, and to compute the KL divergence between the Gaussian distribution representation output by the student model and the one obtained from the ground-truth box, obtaining a second KL divergence;

a normalization processing module, configured to normalize the first KL divergence and the second KL divergence;

a loss function calculation module, configured to compute, from the normalized first and second KL divergences, the overall loss function used while the teacher model distills knowledge into the student model;

a prediction module, configured to perform prediction with the distilled student model.
In one embodiment of the present invention, in the coordinate conversion module, converting the rotated-coordinate representation into a two-dimensional Gaussian distribution includes:

converting the rotated-coordinate representation (x, y, w, h, θ) into a two-dimensional Gaussian distribution (μ, Σ) as:

μ = (x, y)^T
Σ = R Λ R^T, with R = [[cos θ, -sin θ], [sin θ, cos θ]] and Λ = diag(w²/4, h²/4)

where (x, y, w, h, θ) denotes the rotated-coordinate representation of the detection target, namely the horizontal and vertical coordinates of its center point and its width, height, and rotation angle, and (μ, Σ) denote the mean and covariance matrix of the Gaussian distribution.
In one embodiment of the present invention, in the KL divergence calculation module, calculating the KL divergence between the Gaussian distribution representations of the detection targets output by the student model and the teacher model includes:

computing the KL divergence between the two Gaussian distribution representations as:

D_KL(N_S || N_T) = 1/2 (μ_S - μ_T)^T Σ_T^{-1} (μ_S - μ_T) + 1/2 tr(Σ_T^{-1} Σ_S) + 1/2 ln(|Σ_T| / |Σ_S|) - 1

where (x_S, y_S, w_S, h_S, θ_S), (μ_S, Σ_S) and (x_T, y_T, w_T, h_T, θ_T), (μ_T, Σ_T) denote the rotated-coordinate and Gaussian distribution representations of the detection targets output by the student model and the teacher model respectively, and Δx = x_S - x_T, Δy = y_S - y_T, Δθ = θ_S - θ_T denote the differences between the center-point horizontal coordinates, center-point vertical coordinates, and rotation angles of the detection targets output by the student and teacher models.
The lightweight system for rotating target detection based on knowledge distillation of this embodiment is used to implement the corresponding parts of the foregoing embodiments of the lightweight method for rotating target detection based on knowledge distillation, so its specific implementation can be found in the description of those corresponding embodiments and is not repeated here.

In addition, since the system of this embodiment implements the foregoing method, its functions correspond to those of the method and are likewise not repeated here.
Those skilled in the art will appreciate that embodiments of the present application may be provided as a method, a system, or a computer program product. Accordingly, the present application may take the form of an entirely hardware embodiment, an entirely software embodiment, or an embodiment combining software and hardware aspects. Furthermore, the present application may take the form of a computer program product implemented on one or more computer-usable storage media (including but not limited to disk storage, CD-ROM, and optical storage) containing computer-usable program code.
The present application is described with reference to flowcharts and/or block diagrams of methods, devices (systems), and computer program products according to embodiments of the present application. It should be understood that each flow and/or block in the flowcharts and/or block diagrams, and combinations thereof, can be implemented by computer program instructions. These computer program instructions may be provided to a processor of a general-purpose computer, a special-purpose computer, an embedded processor, or another programmable data processing device to produce a machine, such that the instructions executed by the processor of the computer or other programmable data processing device produce an apparatus for implementing the functions specified in one or more flows of the flowcharts and/or one or more blocks of the block diagrams.
These computer program instructions may also be stored in a computer-readable memory capable of directing a computer or another programmable data processing device to operate in a particular manner, such that the instructions stored in the computer-readable memory produce an article of manufacture including an instruction apparatus that implements the functions specified in one or more flows of the flowcharts and/or one or more blocks of the block diagrams.
These computer program instructions may also be loaded onto a computer or another programmable data processing device, causing a series of operational steps to be performed on the computer or other programmable device to produce computer-implemented processing, such that the instructions executed on the computer or other programmable device provide steps for implementing the functions specified in one or more flows of the flowcharts and/or one or more blocks of the block diagrams.
Obviously, the above embodiments are merely examples given for clarity of description and are not intended to limit the implementations. On the basis of the above description, those of ordinary skill in the art may make other changes or modifications of different forms. It is neither necessary nor possible to enumerate all implementations here, and obvious changes or modifications derived therefrom remain within the scope of protection of the present invention.
Claims (10)
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310299030.6A CN116524351B (en) | 2023-03-24 | 2023-03-24 | Lightweight method and system for rotating target detection based on knowledge distillation |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310299030.6A CN116524351B (en) | 2023-03-24 | 2023-03-24 | Lightweight method and system for rotating target detection based on knowledge distillation |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116524351A true CN116524351A (en) | 2023-08-01 |
CN116524351B CN116524351B (en) | 2024-12-24 |
Family
ID=87389373
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310299030.6A Active CN116524351B (en) | 2023-03-24 | 2023-03-24 | Lightweight method and system for rotating target detection based on knowledge distillation |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116524351B (en) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521848A (en) * | 2023-11-10 | 2024-02-06 | 中国科学院空天信息创新研究院 | Remote sensing basic model light-weight method and device for resource-constrained scene |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113610126A (en) * | 2021-07-23 | 2021-11-05 | 武汉工程大学 | Label-free knowledge distillation method based on multi-target detection model and storage medium |
CN115577305A (en) * | 2022-10-31 | 2023-01-06 | 中国人民解放军军事科学院系统工程研究院 | Intelligent unmanned aerial vehicle signal identification method and device |
- 2023-03-24: Application CN202310299030.6A filed in China; granted as patent CN116524351B (active)
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113610126A (en) * | 2021-07-23 | 2021-11-05 | 武汉工程大学 | Label-free knowledge distillation method based on multi-target detection model and storage medium |
CN115577305A (en) * | 2022-10-31 | 2023-01-06 | 中国人民解放军军事科学院系统工程研究院 | Intelligent unmanned aerial vehicle signal identification method and device |
Non-Patent Citations (2)
Title |
---|
- YINGJIE CUI et al.: "Quantitive short-term precipitation model using multimodal data fusion based on a cross-attention mechanism", MDPI, 31 December 2022 (2022-12-31) *
- WANG, Yao: "Research on Knowledge Distillation Based on Information Quantification" (in Chinese), China Masters' Theses Full-text Database, 15 February 2023 (2023-02-15) *
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521848A (en) * | 2023-11-10 | 2024-02-06 | 中国科学院空天信息创新研究院 | Remote sensing basic model light-weight method and device for resource-constrained scene |
CN117521848B (en) * | 2023-11-10 | 2024-05-28 | 中国科学院空天信息创新研究院 | Remote sensing basic model light-weight method and device for resource-constrained scene |
Also Published As
Publication number | Publication date |
---|---|
CN116524351B (en) | 2024-12-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110276269B (en) | A target detection method for remote sensing images based on attention mechanism | |
Wang et al. | Gaussian focal loss: Learning distribution polarized angle prediction for rotated object detection in aerial images | |
CN108564097A (en) | A kind of multiscale target detection method based on depth convolutional neural networks | |
CN111833273B (en) | Semantic Boundary Enhancement Based on Long-distance Dependency | |
CN115147731A (en) | A SAR Image Target Detection Method Based on Full Spatial Coding Attention Module | |
CN112800955A (en) | Remote sensing image rotating target detection method and system based on weighted bidirectional feature pyramid | |
CN113177503A (en) | Arbitrary orientation target twelve parameter detection method based on YOLOV5 | |
CN110826485B (en) | Target detection method and system for remote sensing image | |
CN110781962A (en) | Target detection method based on lightweight convolutional neural network | |
CN115049923A (en) | SAR image ship target instance segmentation training method, system and device | |
Yang et al. | BorderPointsMask: One-stage instance segmentation with boundary points representation | |
EP4323952A1 (en) | Semantically accurate super-resolution generative adversarial networks | |
He et al. | An improved method MSS-YOLOv5 for object detection with balancing speed-accuracy | |
CN117789030A (en) | A method and system for detecting small ship targets in remote sensing images | |
CN116524351A (en) | Lightweight method and system for rotating target detection based on knowledge distillation | |
Liu | TS2Anet: Ship detection network based on transformer | |
Yuan et al. | Dynamic pyramid attention networks for multi-orientation object detection | |
CN115546171A (en) | Shadow detection method and device based on attention shadow boundary and feature correction | |
Liu et al. | SRFAD-Net: Scale-Robust Feature Aggregation and Diffusion Network for Object Detection in Remote Sensing Images | |
CN114332638B (en) | Remote sensing image target detection method and device, electronic equipment and medium | |
Li et al. | LWS-YOLOv7: A Lightweight Water-Surface Object-Detection Model | |
Zhong et al. | An Improved Mask R-CNN: Extraction of Door and Window Instances on Village Building Façade Images | |
Guo et al. | A Surface Target Recognition Algorithm Based on Coordinate Attention and Double‐Layer Cascade | |
Saini et al. | DG-YOLOT: A lightweight density guided YOLO-transformer for remote sensing object detection | |
Zhang et al. | Research on text location and recognition in natural images with deep learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |