WO2023212997A1 - Knowledge distillation based neural network training method, device, and storage medium - Google Patents

Knowledge distillation based neural network training method, device, and storage medium Download PDF

Info

Publication number
WO2023212997A1
Authority
WO
WIPO (PCT)
Prior art keywords
student
teacher
network model
features
loss function
Application number
PCT/CN2022/098769
Other languages
French (fr)
Chinese (zh)
Inventor
崔岩
常青玲
任飞
徐世廷
Original Assignee
五邑大学
广东四维看看智能设备有限公司
中德(珠海)人工智能研究院有限公司
珠海市四维时代网络科技有限公司
Priority claimed from CN202210646401.9A (publication CN114936605A)
Application filed by 五邑大学, 广东四维看看智能设备有限公司, 中德(珠海)人工智能研究院有限公司, 珠海市四维时代网络科技有限公司
Publication of WO2023212997A1

Classifications

    • G: PHYSICS
    • G06: COMPUTING; CALCULATING OR COUNTING
    • G06N: COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N 3/00: Computing arrangements based on biological models
    • G06N 3/02: Neural networks
    • G06N 3/04: Architecture, e.g. interconnection topology
    • G06N 3/08: Learning methods

Definitions

  • the invention relates to the field of artificial intelligence, and in particular to a neural network training method, equipment and storage medium based on knowledge distillation.
  • three-dimensional plane restoration and reconstruction technology is one of the current research tasks in the field of computer vision.
  • three-dimensional plane recovery from a single image requires segmenting the plane instance regions of the scene in the image domain while simultaneously estimating the plane parameters of each instance region; non-planar regions are represented by the depth estimated by the network model.
  • three-dimensional plane recovery and reconstruction technology has broad application prospects in virtual reality, augmented reality, robotics and other fields.
  • Plane restoration and reconstruction is an important research direction in three-dimensional plane restoration and reconstruction.
  • three-dimensional plane restoration and reconstruction methods focus on reconstruction accuracy and enhance the accuracy of neural network models by analyzing the edges of plane structures and their embedding with the scene.
  • the neural network model used for plane restoration and reconstruction has the problem of losing scene structure information and lacking data generalization.
  • the present invention aims to solve at least one of the technical problems existing in the prior art.
  • the present invention provides a neural network training method, equipment and storage medium based on knowledge distillation, which can improve the scene information acquisition capability and data generalization performance of the neural network model.
  • the first embodiment of the present invention provides a neural network training method based on knowledge distillation, which includes the following steps:
  • distillation loss function group includes the encoding loss function, the decoding loss function and the prediction result loss function
  • the student network model is trained according to the distillation loss function group to obtain the trained student network model.
  • a distillation loss function group composed of multiple loss functions is obtained according to the teacher network model and the student network model, and the student network model is then trained with this group, which ensures the reliability of each processing stage in the student network model, thereby effectively improving the accuracy of the student network model's prediction results.
  • a distillation loss function group is obtained, including:
  • the training samples are input into the student network model to obtain the student feature group, including:
  • the training samples are input into the teacher network model to obtain the teacher feature group, including:
  • the teacher network model includes a teacher backbone network model, a teacher encoder and a teacher decoder.
  • the teacher backbone network model outputs teacher encoding features, the teacher encoder outputs teacher decoding features, and the teacher decoder outputs teacher prediction result features.
  • a distillation loss function group is obtained, including:
  • the coding loss function is obtained
  • the decoding loss function is obtained
  • the prediction result loss function is obtained.
  • the student network model includes a student encoder and a student decoder
  • the fused feature layer is obtained
  • the student decoding features and the fusion feature layer are input to the student decoder, where they are first fused and then up-sampled and decoded to obtain the student prediction result features.
  • a fused feature layer is obtained, including:
  • the fused feature layer at each scale is obtained.
  • the student decoding features and the fusion feature layer are input to the student decoder, and are first fused and then upsampled and decoded to obtain student prediction result features, including:
  • the student decoding features are input to the student decoder for upsampling decoding, and each fusion feature layer is fused with the upsampled intermediate feature map at the corresponding scale during the upsampling decoding process of the student decoder to obtain the student decoding features.
  • a second embodiment of the present invention provides an electronic device, including:
  • a memory, a processor, and a computer program stored in the memory and executable on the processor.
  • when the processor executes the computer program, the knowledge distillation based neural network training method of any one of the first aspect is implemented.
  • since the electronic device of the embodiment of the second aspect applies the knowledge distillation based neural network training method of any one of the first aspect, it has all the beneficial effects of the first aspect of the present invention.
  • a computer storage medium provided according to an embodiment of the third aspect of the present invention stores computer-executable instructions, and the computer-executable instructions are used to execute any one of the knowledge distillation-based neural network training methods of the first aspect.
  • since the computer storage medium of the embodiment of the third aspect can execute the knowledge distillation based neural network training method of any one of the first aspect, it has all the beneficial effects of the first aspect of the present invention.
  • Figure 1 is a main step diagram of the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 2 is a specific step diagram of step S2000 in the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 3 is a specific step diagram of step S2100 in the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 4 is a specific step diagram of step S2300 in the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 5 is a working principle diagram of the student network model in the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 6 is a working principle diagram of the teacher network model in the neural network training method based on knowledge distillation according to the embodiment of the present invention
  • Figure 7 is a working principle diagram of the neural network training method based on knowledge distillation according to the embodiment of the present invention.
  • Three-dimensional plane restoration and reconstruction technology is one of the current research tasks in the field of computer vision.
  • three-dimensional plane recovery from a single image needs to segment the plane instance regions of the scene in the image domain and estimate the plane parameters of each instance region.
  • non-planar regions are represented by the depth estimated by the network model.
  • three-dimensional plane recovery and reconstruction technology has broad application prospects in virtual reality, augmented reality, robotics and other fields.
  • the plane detection and restoration method of a single image requires simultaneous research on image depth, plane normal, plane segmentation, etc.
  • 3D reconstruction methods mainly generate point cloud data through 3D vision methods, then generate nonlinear scene surfaces by fitting relevant points, and then optimize the overall reconstruction model through global reasoning.
  • planar restoration and reconstruction combines visual instance segmentation methods: it identifies the plane regions of the scene and represents each plane with three parameters in Cartesian coordinates plus a segmentation mask, which gives better reconstruction accuracy and effect.
  • Three-dimensional reconstruction is achieved through segmented plane restoration and reconstruction. Segmented plane restoration and reconstruction is a multi-stage reconstruction method, and the accuracy of plane identification and parameter estimation will affect the results of the final model.
  • plane prediction and reconstruction can be achieved through a variety of methods: the convolutional neural network architecture PlaneNet can infer a fixed number of plane instance masks and plane parameters from a single RGB image; a fixed number of planes can also be predicted by learning directly from a depth modality induced by the plane structure; the two-stage Mask R-CNN framework replaces object category classification with plane geometry prediction and then refines the plane segmentation masks with a convolutional neural network; pixel-wise plane parameters can also be predicted with an associative embedding method that trains the network to map each pixel into an embedding space and then clusters the embedded pixels into plane instances; and a plane refinement method constrained by the Manhattan-world assumption strengthens the refinement of plane parameters by limiting the set relationships between plane instances.
  • a divide-and-conquer method segments panorama planes from the vertical and horizontal directions and can recover distorted plane instances; the transformer-based method PlaneTR adds plane instance center and edge features, which can effectively improve the efficiency of plane detection.
  • PlaneNet is a deep neural network used for piecewise planar depth map reconstruction from a single RGB image.
  • Mask R-CNN is a neural network framework.
  • the transformer module is a model based on a multi-head attention mechanism.
  • PlaneTR is a model used to extract 3D plane features from a scene.
  • Plane restoration and reconstruction is an important research direction in 3D plane restoration and reconstruction.
  • most current 3D reconstruction methods first generate point cloud data through 3D vision methods, then generate nonlinear scene surfaces by fitting the relevant points, and finally optimize the overall reconstruction model through global reasoning.
  • in related technologies, three-dimensional plane restoration and reconstruction methods focus on reconstruction accuracy and enhance model accuracy by analyzing the edges of the plane structure and its embedding in the scene.
  • however, the neural networks used for plane restoration and reconstruction suffer from loss of scene structure information and a lack of data generalization.
  • the student network model obtained by the knowledge distillation training of the present invention is used for plane restoration and reconstruction, which can avoid the problems of lost scene structure information and low data generalization.
  • Knowledge distillation is a training framework.
  • the student network model learns by using the softmax output vector of the powerful teacher network model as soft labels.
  • the student network model is generally a lightweight small network model.
  • the distillation process can effectively improve the prediction accuracy of the lightweight network model.
  • however, because there is a capacity gap between the teacher network model and the student network model, hard association of the model prediction results subjects the student network model to negative regularization, that is, overfitting, during distillation, which limits the effectiveness of the distillation process; and because the features extracted by a neural network become more abstract as depth increases, the student network model obtained through knowledge distillation has the problem of low prediction accuracy.
  • an attention map can be derived from the original feature map to express knowledge, knowledge can be transferred by matching probability distributions in the feature space, factors can be introduced as a more understandable intermediate representation, and the activation boundaries of hidden neurons can be used for knowledge transfer; however, the plasticity of knowledge transfer decreases rapidly after the first few training stages, and the effectiveness of knowledge distillation is then reduced.
  • the following describes the knowledge distillation based neural network training method, device and storage medium of the present invention with reference to Figures 1 to 7; they can not only improve the scene structure information acquisition capability and data generalization capability of the obtained student network model, but also improve the prediction accuracy of the student network model.
  • the neural network training method based on knowledge distillation includes the following steps:
  • Step S1000 Construct an untrained student network model and a trained teacher network model
  • Step S2000 Obtain a distillation loss function group based on the training samples, the student network model and the teacher network model, where the distillation loss function group includes the encoding loss function, the decoding loss function and the prediction result loss function;
  • Step S3000 Train the student network model according to the distillation loss function group to obtain a trained student network model, where the trained student network model can be used to implement plane restoration and reconstruction.
  • by using the trained teacher network model to perform knowledge distillation on the student network model, the scene information acquisition capability and data generalization capability of the student network model can be effectively improved, and a distillation loss function group composed of multiple loss functions is obtained according to the teacher network model and the student network model.
  • training the student network model with this distillation loss function group ensures the reliability of each processing stage in the student network model, thereby effectively improving the accuracy of the student network model's prediction results.
  • step S2000 is to obtain a distillation loss function group based on the training samples, the student network model and the teacher network model, including but not limited to the following steps:
  • Step S2100 Input the training samples into the student network model to obtain the student feature group
  • Step S2200 Input the training samples into the teacher network model to obtain the teacher feature group
  • Step S2300 Obtain a distillation loss function group based on the student feature group and the teacher feature group.
  • the student network model includes a student encoder and a student decoder
  • the student feature group includes student encoding features, student decoding features, and student prediction result features
  • step S2100, input the training samples into the student network model to obtain the student feature group, including but not limited to the following steps:
  • Step S2110 Input the training samples into the student network model to obtain a student feature group including student coding features, student decoding features and student prediction result features.
  • by omitting a pretrained student backbone model and composing the student network model of a student encoder and a student decoder, the structure of the student network model can be significantly simplified.
  • the student network model is highly lightweight, so when the trained student network model is used for prediction its prediction speed is effectively improved, and it is fast when used for plane detection and recovery; the intermediate features of the network also have a certain learning potential, and iteratively training the student network model with a distillation loss function group containing three groups of distillation loss functions, that is, through a step-by-step distillation process, helps mitigate the negative impact of hard associations, so the finally obtained trained student network model can satisfy both real-time and high-precision prediction performance.
  • step S2110 the training samples are input into the student network model to obtain a student feature group including student coding features, student decoding features and student prediction result features, including but not limited to the following steps:
  • Step S2111 Input the training samples to the student encoder for downsampling encoding to obtain student encoding features
  • Step S2112 Obtain the fusion feature layer according to the down-sampling encoding process of the student encoder
  • Step S2113 Convolve the student coding features to obtain the student decoding features
  • Step S2114 Input the student decoding features and the fused feature layer to the student decoder.
  • the student decoding features and the fused feature layer are first fused and then upsampled and decoded to obtain student prediction result features.
  • the student encoder performs down-sampling encoding processing on the input data of the training samples.
  • a fast down-sampling strategy can be used to extract and identify features with a large enough receptive field, which can effectively improve the recognition speed.
  • however, the down-sampling operation results in the loss of spatial information, and this lost information cannot be recovered during subsequent processing.
  • the corresponding features are extracted as a feature fusion layer during the downsampling process of the student encoder, which is used to correspond to the upsampling and decoding process of the student decoder.
  • the fusion of features can make corresponding compensation for the spatial information lost during the down-sampling process, and can effectively ensure the reliability of the student prediction result features obtained after up-sampling and decoding by the student decoder.
  • step S2112 the fusion feature layer is obtained according to the down-sampling encoding process of the student encoder, including but not limited to the following steps:
  • the fused feature layer at each scale is obtained.
  • Step S2114 input the student decoding features and the fused feature layer to the student decoder.
  • the student decoding features and the fused feature layer are first fused and then upsampled and decoded to obtain the student prediction result features, including but not limited to the following steps:
  • the final output of the student decoder is the student prediction result features.
  • each fusion feature layer is fused with the up-sampled deep features at the same scale, which gradually restores spatial detail and thereby effectively ensures the reliability of the features output by the student decoder, as sketched below.
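One plausible way to realize this scale-matched fusion is a decoder stage that upsamples by 2x and then adds a channel-aligned skip feature; this is only a sketch of the general technique, not the patent's implementation, and the module and parameter names are illustrative:

```python
import torch.nn as nn
import torch.nn.functional as F

class FusionUpBlock(nn.Module):
    """Upsample 2x, then fuse with the same-scale encoder (fusion feature layer) skip."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.match = nn.Conv2d(skip_ch, out_ch, kernel_size=1)            # align skip channels
        self.reduce = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)  # refine upsampled map

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = self.reduce(x)
        return x + self.match(skip)  # element-wise fusion at the matching scale
```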
  • step S2200 the training samples are input into the teacher network model to obtain the teacher feature group, including but not limited to the following steps:
  • Step S2210 Input the training samples into the teacher network model to obtain a teacher feature group including teacher coding features, teacher decoding features and teacher prediction result features.
  • the teacher network model includes a teacher backbone network model, a teacher encoder and a teacher decoder.
  • the teacher backbone network model outputs teacher coding features.
  • the teacher coding features are input to the teacher encoder and the teacher decoding features are output.
  • the teacher decoding features are input to the teacher decoder and the teacher prediction result features are output.
  • the teacher feature group includes teacher coding features, teacher decoding features and teacher prediction result features.
  • in the teacher network model, the training samples are input into the teacher backbone network model, which outputs the teacher coding features; the teacher coding features are input into the teacher encoder, which outputs the teacher decoding features; and the teacher decoding features are input into the teacher decoder, which outputs the teacher prediction result features.
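That data flow can be summarized schematically as follows, with `backbone`, `encoder` and `decoder` standing in for the teacher modules described elsewhere in this document (the dictionary keys and module names are assumptions made for illustration):

```python
import torch.nn as nn

class TeacherModel(nn.Module):
    """Schematic teacher pipeline: backbone -> encoder -> decoder."""
    def __init__(self, backbone, encoder, decoder):
        super().__init__()
        self.backbone, self.encoder, self.decoder = backbone, encoder, decoder

    def forward(self, x):
        enc = self.backbone(x)    # teacher coding features
        dec = self.encoder(enc)   # teacher decoding features
        pred = self.decoder(dec)  # teacher prediction result features
        return {"enc": enc, "dec": dec, "pred": pred}
```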
  • step S2300 is to obtain a distillation loss function group based on the student feature group and the teacher feature group, including but not limited to the following steps:
  • Step S2310 Obtain a coding loss function based on the student coding features and the teacher coding features, where the coding loss function is used to correct the downsampling coding of the student coder so that the student coder outputs more accurate student coding features;
  • Step S2320 Obtain a decoding loss function based on the student decoding features and the teacher decoding features, where the decoding loss function is used to correct the convolution before the student decoder to ensure the accuracy of the student decoding features input to the student decoder;
  • Step S2330 Obtain the prediction result loss function based on the student prediction result characteristics and the teacher prediction result characteristics.
  • the prediction result loss function is used to correct the up-sampling decoding of the student decoder so that the student network model outputs more accurate student prediction result features.
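The patent does not spell out the exact form of these three losses; a minimal PyTorch-style sketch, assuming each one is simply a mean-squared error between the corresponding student and teacher features, could look like this (the dictionary keys are illustrative):

```python
import torch.nn.functional as F

def distillation_loss_group(student_feats, teacher_feats):
    """Encoding, decoding and prediction-result losses from matching feature pairs.

    Both arguments are dicts with keys 'enc', 'dec' and 'pred' holding student /
    teacher tensors of matching shapes; MSE is an assumed choice of distance.
    """
    losses = {}
    for key in ("enc", "dec", "pred"):
        # Teacher features act as fixed targets, so no gradient flows into the teacher.
        losses[key] = F.mse_loss(student_feats[key], teacher_feats[key].detach())
    return losses
```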
  • the working principle diagram of the neural network training method based on knowledge distillation is shown in Figure 7.
  • the student network model is iteratively trained through three distillation loss functions corresponding to different network layers; that is, a direct and effective one-to-one matching is achieved between the corresponding levels of the student network model and the teacher network model, which effectively ensures the accuracy of data processing at the corresponding network layers in the student network model.
  • this effectively improves the performance of the student network model from an architectural perspective and effectively ensures the generalization of the student network model and the accuracy of its prediction results.
  • a knowledge distillation architecture with multiple student networks can use a critical-learning-aware KD (knowledge distillation) scheme to ensure the formation of key connections, allowing the students to effectively imitate the teacher's information flow rather than just training a single student; direct and effective one-to-one matching between the corresponding layers of the student network model and the teacher network model can be achieved by adaptively dividing the teacher network model and the student network model into three parts, assigning each part the adaptive parameters of the corresponding network layers, and performing knowledge distillation learning; semantic correction of shallow-network feature associations significantly improves the effectiveness of feature knowledge transfer, and an attention mechanism is used to achieve cross-layer rectification, which can alleviate the problem of semantic mismatch.
  • KD: knowledge distillation.
  • step S3000 is to train the student network model according to the distillation loss function group to obtain the trained student network model, including but not limited to the following steps:
  • the downsampling encoding of the student encoder is corrected according to the encoding loss function, the convolution before the student decoder is corrected according to the decoding loss function, and the upsampling decoding of the student decoder is corrected according to the prediction result loss function.
  • the student network model is trained to obtain a trained student network model.
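As a rough illustration of how such a loss group might drive one training step (reusing the `distillation_loss_group` sketch above; the loss weights and optimizer handling are assumptions, not taken from the patent):

```python
import torch

def train_step(student, teacher, batch, optimizer, weights=(1.0, 1.0, 1.0)):
    teacher.eval()
    with torch.no_grad():
        t_feats = teacher(batch)   # dict with 'enc', 'dec', 'pred' from the teacher
    s_feats = student(batch)       # same keys from the student model
    losses = distillation_loss_group(s_feats, t_feats)
    total = sum(w * losses[k] for w, k in zip(weights, ("enc", "dec", "pred")))
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item()
```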
  • the student network receives auxiliary training according to the distillation loss function group containing multiple intermediate-layer loss functions.
  • the estimation performance of the student network can be enhanced through transfer learning of the intermediate feature layer.
  • the encoding dimension, decoding dimension and prediction result dimension ensure that the teacher network model provides more reliable parameter learning for the student network model despite the capacity gap.
  • the teacher network based on the transformer module can achieve global area detection.
  • the teacher network model is set up based on the transformer module.
  • the HR-Net model is used as the teacher backbone network model for feature extraction, generating high-dimensional, low-scale features that are embedded as blocks.
  • the HR-Net model is a high-resolution network.
  • the size of the block embedding is p.
  • the H×W pixel image is divided into a set of feature block embeddings S_0 ∈ R^D, S_1 ∈ R^D, and so on, where R^D is the feature space output by the teacher backbone network model and the number of feature blocks is HW/p^2; finally, the feature block embeddings are input into the transformer module with a total of 12 layers.
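This corresponds to a standard vision-transformer block (patch) embedding; a hedged sketch, assuming square p x p blocks so that an H x W map yields HW/p^2 embeddings of dimension D:

```python
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split a feature map into non-overlapping p x p blocks and embed each into R^D."""
    def __init__(self, in_ch, embed_dim, patch_size):
        super().__init__()
        # A strided convolution both cuts the map into blocks and projects them to D dims.
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, C, H, W)
        x = self.proj(x)                     # (B, D, H/p, W/p)
        return x.flatten(2).transpose(1, 2)  # (B, H*W/p^2, D) token sequence
```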
  • the teacher network model includes a depth estimation branch.
  • the depth estimation branch uses the multi-scale features of the teacher backbone network model and the teacher coding features as input sources to estimate the image depth through a top-down decoding structure.
  • this structure adopts the upsampling module of bilinear interpolation.
  • the feature map after each upsampling step corresponds to a feature scale of the teacher backbone network model, that is, a 2x upsampling mechanism is implemented to estimate the image depth, and the teacher backbone network model outputs the corresponding feature dimensions.
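A sketch of such a top-down branch, assuming a cascade of 2x bilinear upsampling steps, each followed by a convolution, and ending in a one-channel depth head (the stage count and channel sizes are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

class DepthBranch(nn.Module):
    """Top-down depth head: repeated 2x bilinear upsampling with convolutions."""
    def __init__(self, channels, num_stages=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(num_stages)]
        )
        self.head = nn.Conv2d(channels, 1, kernel_size=3, padding=1)

    def forward(self, x):
        for block in self.blocks:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
            x = block(x)
        return self.head(x)  # per-pixel depth estimate
```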
  • the final output of the teacher network model and the student network model is corrected by the L2 loss function. Since there is no maximum value function in the network, the L2 loss function is used to correct the features before the last activation layer in the corresponding network model. When training the student network model, the distillation loss function group and the L2 loss function can achieve more reliable correction effects. When applied to plane restoration and reconstruction, the prediction accuracy of the trained student network model is higher.
  • the neural network training method based on knowledge distillation of the first aspect of the present invention is described in detail below with a specific embodiment. It should be understood that the following description is only illustrative and does not specifically limit the invention.
  • the teacher network model is designed based on the transformer module.
  • the teacher network model includes a teacher backbone network model, a teacher encoder and a teacher decoder, where the teacher backbone network model uses the HR-Net model.
  • the student network model includes a student encoder and a student decoder. The student network model omits the settings of its backbone network model;
  • mobilenet-v3 is a lightweight network and is used as the feature extractor.
  • the intermediate feature map is sampled to obtain the fusion feature layer.
  • the student encoding features are input to the student decoder for upsampling decoding to obtain the student decoding features;
  • the encoding loss function is obtained.
  • the decoding loss function is obtained.
  • the prediction result loss function is obtained;
  • the downsampling encoding of the student encoder is corrected according to the encoding loss function, the convolution before the student decoder is corrected according to the decoding loss function, and the upsampling decoding of the student decoder is corrected according to the prediction result loss function.
  • the student network model is trained to obtain a trained student network model, which can be used to achieve plane restoration and reconstruction.
  • the training samples are input to the student encoder for down-sampling encoding to obtain the student encoding features; according to each down-sampling intermediate feature map generated during the down-sampling encoding process of the student encoder, the fusion feature layer at each scale is obtained.
  • the fusion feature layers are fused with the up-sampled intermediate feature maps, and the student decoder outputs the student prediction result features.
  • the feature fusion module fuses the same-scale shallow features, that is, the fusion feature layers, with the up-sampled intermediate feature maps during the up-sampling decoding process, at resolutions of 1/32, 1/16, 1/8, 1/4 and 1/2 respectively, which ensures that features of the same scale have the same number of feature channels after each feature fusion.
  • transfer learning is performed separately on the student encoding features, the student decoding features and the student prediction result features.
  • the second embodiment of the present invention also provides an electronic device.
  • the electronic device includes: a memory, a processor, and a computer program stored in the memory and executable on the processor.
  • the processor and memory may be connected via a bus or other means.
  • memory can be used to store non-transitory software programs and non-transitory computer executable programs.
  • the memory may include high-speed random access memory and may also include non-transitory memory, such as at least one magnetic disk storage device, flash memory device, or other non-transitory solid-state storage device.
  • the memory may optionally include memory located remotely from the processor, and the remote memory may be connected to the processor via a network. Examples of the above-mentioned networks include but are not limited to the Internet, intranets, local area networks, mobile communication networks and combinations thereof.
  • the non-transient software programs and instructions required to implement the neural network training method based on knowledge distillation in the first embodiment are stored in the memory.
  • when these programs and instructions are executed by the processor, the knowledge distillation based neural network training method in the above embodiment is performed.
  • for example, the above-described method steps S1000 to S3000, method steps S2100 to S2300, method step S2110, method steps S2111 to S2114, method step S2210, and method steps S2310 to S2330 are executed.
  • the device embodiments described above are only illustrative, and the units described as separate components may or may not be physically separate, that is, they may be located in one place, or they may be distributed to multiple network units. Some or all of the modules can be selected according to actual needs to achieve the purpose of the solution of this embodiment.
  • an embodiment of the present invention also provides a computer-readable storage medium that stores computer-executable instructions; when the computer-executable instructions are executed by a processor or controller, for example by the processor in the above device embodiment, they can cause the processor to execute the knowledge distillation based neural network training method in the above embodiment, for example, to execute the above-described method steps S1000 to S3000, method steps S2100 to S2300, method step S2110, method steps S2111 to S2114, method step S2210, and method steps S2310 to S2330.
  • computer storage media includes, but is not limited to, RAM, ROM, EEPROM, flash memory or other memory technology, CD-ROM, digital versatile disks (DVD) or other optical disk storage, magnetic cassettes, magnetic tape, magnetic disk storage or other magnetic storage devices, or any other medium that can be used to store the desired information and that can be accessed by a computer.
  • communication media typically embodies computer-readable instructions, data structures, program modules or other data in a modulated data signal such as a carrier wave or other transport mechanism, and may include any information delivery media.

Abstract

A knowledge distillation based neural network training method, a device, and a storage medium. The method comprises the following steps: constructing an untrained student network model and a trained teacher network model (S1000); according to training samples, said student network model and said teacher network model, obtaining a distillation loss function group (S2000), the loss function group comprising an encoding loss function, a decoding loss function and a prediction result loss function; and according to the distillation loss function group, training said student network model, so as to obtain a trained student network model (S3000). Knowledge distillation is performed on a student network model by means of a trained teacher network model, so that scenario information acquisition capability and data generalization capability of the student network model can be effectively improved; and a distillation loss function group consisting of multiple loss functions is acquired, and then the student network model is trained by means of the distillation loss function group, so that the accuracy of a student network model prediction result can be effectively improved.

Description

Knowledge distillation based neural network training method, device, and storage medium
Technical field
The invention relates to the field of artificial intelligence, and in particular to a knowledge distillation based neural network training method, device, and storage medium.
Background art
With the development of deep learning, three-dimensional plane recovery and reconstruction has become one of the current research tasks in the field of computer vision. Plane recovery from a single image requires segmenting the plane instance regions of the scene in the image domain while simultaneously estimating the plane parameters of each instance region; non-planar regions are represented by the depth estimated by the network model. Three-dimensional plane recovery and reconstruction technology has broad application prospects in virtual reality, augmented reality, robotics and other fields.
Plane restoration and reconstruction is an important research direction within three-dimensional plane recovery and reconstruction. In related technologies, three-dimensional plane restoration and reconstruction methods focus on reconstruction accuracy and enhance the accuracy of neural network models by analyzing the edges of plane structures and their embedding in the scene, but the neural network models used for plane restoration and reconstruction suffer from loss of scene structure information and a lack of data generalization.
Summary of the invention
The present invention aims to solve at least one of the technical problems existing in the prior art. To this end, the present invention provides a knowledge distillation based neural network training method, device and storage medium, which can improve the scene information acquisition capability and the data generalization performance of a neural network model.
An embodiment of the first aspect of the present invention provides a knowledge distillation based neural network training method, which includes the following steps:
constructing an untrained student network model and a trained teacher network model;
obtaining a distillation loss function group according to training samples, the student network model and the teacher network model, where the distillation loss function group includes an encoding loss function, a decoding loss function and a prediction result loss function;
training the student network model according to the distillation loss function group to obtain a trained student network model.
The above embodiments of the present invention have at least the following beneficial effects: performing knowledge distillation on the student network model with the trained teacher network model can effectively improve the scene information acquisition capability and data generalization capability of the student network model; moreover, a distillation loss function group composed of multiple loss functions is obtained according to the teacher network model and the student network model, and the student network model is then trained with this group, which ensures the reliability of each processing stage in the student network model and thereby effectively improves the accuracy of the student network model's prediction results.
According to some embodiments of the first aspect of the present invention, obtaining the distillation loss function group according to the training samples, the student network model and the teacher network model includes:
inputting the training samples into the student network model to obtain a student feature group;
inputting the training samples into the teacher network model to obtain a teacher feature group;
obtaining the distillation loss function group according to the student feature group and the teacher feature group.
According to some embodiments of the first aspect of the present invention, inputting the training samples into the student network model to obtain the student feature group includes:
inputting the training samples into the student network model to obtain a student feature group including student encoding features, student decoding features and student prediction result features.
According to some embodiments of the first aspect of the present invention, inputting the training samples into the teacher network model to obtain the teacher feature group includes:
inputting the training samples into the teacher network model to obtain a teacher feature group including teacher encoding features, teacher decoding features and teacher prediction result features, where the teacher network model includes a teacher backbone network model, a teacher encoder and a teacher decoder; the teacher backbone network model outputs the teacher encoding features, the teacher encoder outputs the teacher decoding features, and the teacher decoder outputs the teacher prediction result features.
According to some embodiments of the first aspect of the present invention, obtaining the distillation loss function group according to the student feature group and the teacher feature group includes:
obtaining the encoding loss function according to the student encoding features and the teacher encoding features;
obtaining the decoding loss function according to the student decoding features and the teacher decoding features;
obtaining the prediction result loss function according to the student prediction result features and the teacher prediction result features.
According to some embodiments of the first aspect of the present invention, the student network model includes a student encoder and a student decoder;
inputting the training samples into the student network model to obtain the student feature group including student encoding features, student decoding features and student prediction result features includes:
inputting the training samples into the student encoder for down-sampling encoding to obtain the student encoding features;
obtaining a fusion feature layer according to the down-sampling encoding process of the student encoder;
convolving the student encoding features to obtain the student decoding features;
inputting the student decoding features and the fusion feature layer into the student decoder, where they are first fused and then up-sampled and decoded to obtain the student prediction result features.
According to some embodiments of the first aspect of the present invention, obtaining the fusion feature layer according to the down-sampling encoding process of the student encoder includes:
obtaining the fusion feature layer at each scale according to each down-sampling intermediate feature map formed during the down-sampling encoding process of the student encoder.
According to some embodiments of the first aspect of the present invention, inputting the student decoding features and the fusion feature layer into the student decoder, fusing them first and then up-sampling and decoding to obtain the student prediction result features includes:
inputting the student decoding features into the student decoder for up-sampling decoding, and fusing each fusion feature layer with the up-sampled intermediate feature map at the corresponding scale during the up-sampling decoding process of the student decoder to obtain the student decoding features.
An embodiment of the second aspect of the present invention provides an electronic device, including:
a memory, a processor, and a computer program stored in the memory and executable on the processor; when the processor executes the computer program, the knowledge distillation based neural network training method of any one of the first aspect is implemented.
Since the electronic device of the embodiment of the second aspect applies the knowledge distillation based neural network training method of any one of the first aspect, it has all the beneficial effects of the first aspect of the present invention.
A computer storage medium provided according to an embodiment of the third aspect of the present invention stores computer-executable instructions, and the computer-executable instructions are used to execute the knowledge distillation based neural network training method of any one of the first aspect.
Since the computer storage medium of the embodiment of the third aspect can execute the knowledge distillation based neural network training method of any one of the first aspect, it has all the beneficial effects of the first aspect of the present invention.
Additional aspects and advantages of the invention will be set forth in part in the following description, and in part will become apparent from the description or may be learned by practice of the invention.
Description of the drawings
The above and/or additional aspects and advantages of the present invention will become apparent and readily understood from the description of the embodiments in conjunction with the following drawings, in which:
Figure 1 is a main step diagram of the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 2 is a specific step diagram of step S2000 in the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 3 is a specific step diagram of step S2100 in the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 4 is a specific step diagram of step S2300 in the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 5 is a working principle diagram of the student network model in the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 6 is a working principle diagram of the teacher network model in the knowledge distillation based neural network training method according to an embodiment of the present invention;
Figure 7 is a working principle diagram of the knowledge distillation based neural network training method according to an embodiment of the present invention.
Detailed description
In the description of the present invention, unless otherwise explicitly limited, words such as "setting", "installation" and "connection" should be understood in a broad sense, and those skilled in the art can reasonably determine the specific meaning of these words in the present invention in combination with the specific content of the technical solution. In the description of the present invention, "several" means one or more, "plurality" means two or more, terms such as "greater than", "less than" and "more than" are understood to exclude the stated number, and terms such as "above", "below" and "within" are understood to include the stated number. In addition, features defined as "first" and "second" may explicitly or implicitly include one or more of these features. In the description of the present invention, unless otherwise specified, "plurality" means two or more.
With the development of deep learning, the field of computer vision has attracted the attention of more and more researchers, and three-dimensional plane recovery and reconstruction is one of the current research tasks in the field of computer vision. Plane recovery from a single image requires segmenting the plane instance regions of the scene in the image domain and simultaneously estimating the plane parameters of each instance region; non-planar regions are represented by the depth estimated by the network model. Three-dimensional plane recovery and reconstruction technology has broad application prospects in virtual reality, augmented reality, robotics and other fields. Plane detection and recovery from a single image requires simultaneous research on image depth, plane normals, plane segmentation and so on. Traditional three-dimensional plane recovery and reconstruction methods based on hand-crafted features only extract the shallow texture information of the image and rely on prior conditions of plane geometry, so they suffer from weak generalization ability. In reality, indoor scenes are very complex: multiple shadows produced by complex lighting and various folded occluders affect the quality of plane recovery and reconstruction, making it difficult for traditional methods to cope with plane recovery and reconstruction in complex indoor scenes.
At present, 3D reconstruction methods mainly generate point cloud data through 3D vision methods, then generate nonlinear scene surfaces by fitting the relevant points, and then optimize the overall reconstruction model through global reasoning. Plane restoration and reconstruction combines visual instance segmentation methods to identify the plane regions of the scene and represents each plane with three parameters in Cartesian coordinates plus a segmentation mask, which gives better reconstruction accuracy and effect; three-dimensional reconstruction is then achieved through segmented plane restoration and reconstruction. Segmented plane restoration and reconstruction is a multi-stage reconstruction method, and the accuracy of plane identification and parameter estimation both affect the result of the final model.
In related technologies, plane prediction and reconstruction can be achieved through a variety of methods: the convolutional neural network architecture PlaneNet can infer a fixed number of plane instance masks and plane parameters from a single RGB image; a fixed number of planes can also be predicted by learning directly from a depth modality induced by the plane structure; the two-stage Mask R-CNN framework replaces object category classification with plane geometry prediction and then refines the plane segmentation masks with a convolutional neural network; pixel-wise plane parameters can also be predicted with an associative embedding method that trains the network to map each pixel into an embedding space and then clusters the embedded pixels into plane instances; a plane refinement method constrained by the Manhattan-world assumption strengthens the refinement of plane parameters by limiting the set relationships between plane instances; a divide-and-conquer method segments panorama planes from the vertical and horizontal directions and, given the difference in pixel distribution between panoramas and ordinary images, can recover distorted plane instances; and PlaneTR, a method based on the transformer module, adds plane instance center and edge features, which effectively improves the efficiency of plane detection. PlaneNet is a deep neural network used for piecewise planar depth map reconstruction from a single RGB image, Mask R-CNN is a neural network framework, the transformer module is a model based on a multi-head attention mechanism, and PlaneTR is a model used to extract 3D plane features from a scene.
Plane restoration and reconstruction is an important research direction in 3D plane recovery and reconstruction. Most current 3D reconstruction methods first generate point cloud data through 3D vision methods, then generate nonlinear scene surfaces by fitting the relevant points, and then optimize the overall reconstruction model through global reasoning. In related technologies, three-dimensional plane restoration and reconstruction methods focus on reconstruction accuracy and enhance model accuracy by analyzing the edges of the plane structure and its embedding in the scene, but the neural networks used for plane restoration and reconstruction suffer from loss of scene structure information and a lack of data generalization.
To solve the problems of lost scene information and lack of data generalization in neural networks used for plane restoration and reconstruction, the student network model obtained through the knowledge distillation training of the present invention is used for plane restoration and reconstruction, which can avoid the problems of lost scene structure information and low data generalization.
Knowledge distillation is a training framework in which the student network model learns using the softmax output vector of a powerful teacher network model as soft labels. The student network model is generally a lightweight small network model, and the distillation process can effectively improve the prediction accuracy of the lightweight network model. However, because there is a capacity gap between the teacher network model and the student network model, hard association of the model prediction results subjects the student network model to negative regularization during distillation, that is, overfitting of the network, which limits the effectiveness of the distillation process; moreover, because the features extracted by a neural network become more abstract as depth increases, the student network model obtained through knowledge distillation suffers from low prediction accuracy.
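For reference, the classical soft-label distillation described here is usually implemented as a temperature-scaled KL divergence between teacher and student logits; a common sketch follows (the temperature T and the weighting against other losses are not specified by the patent):

```python
import torch.nn.functional as F

def soft_label_kd_loss(student_logits, teacher_logits, T=4.0):
    # Teacher softmax outputs at temperature T serve as the soft labels.
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_student, soft_targets, reduction="batchmean") * (T * T)
```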
An attention map can be derived from the original feature map to express knowledge, knowledge can be transferred by matching probability distributions in the feature space, factors can be introduced as a more understandable intermediate representation, and the activation boundaries of hidden neurons can be used for knowledge transfer; however, the plasticity of knowledge transfer decreases rapidly after the first few training stages, and the effectiveness of knowledge distillation is then reduced.
The knowledge distillation based neural network training method, device and storage medium of the present invention are described below with reference to Figures 1 to 7; they can not only improve the scene structure information acquisition capability and data generalization capability of the obtained student network model, but also improve the prediction accuracy of the student network model.
Referring to Figure 1, the knowledge distillation based neural network training method according to an embodiment of the first aspect of the present invention includes the following steps:
Step S1000: construct an untrained student network model and a trained teacher network model;
Step S2000: obtain a distillation loss function group according to training samples, the student network model, and the teacher network model, where the distillation loss function group includes an encoding loss function, a decoding loss function, and a prediction result loss function;
Step S3000: train the student network model according to the distillation loss function group to obtain a trained student network model, where the trained student network model can be used to implement plane restoration and reconstruction.
Performing knowledge distillation on the student network model with a trained teacher network model effectively improves the student network model's scene information acquisition capability and data generalization capability. Moreover, obtaining a distillation loss function group composed of multiple loss functions from the teacher network model and the student network model, and then training the student network model with this distillation loss function group, ensures the reliability of each processing stage in the student network model and thereby effectively improves the accuracy of the student network model's predictions.
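As an illustrative sketch only, one possible way to organize such training is a loop in which the frozen teacher and the student process the same batch and the three distillation losses are summed with the ordinary task loss; the tuple-returning model interface, the use of mean-squared error, and the loss weights below are assumptions of the sketch rather than limitations of the method:

    import torch
    import torch.nn.functional as F

    def distillation_step(student, teacher, images, targets, optimizer,
                          task_loss_fn, w_enc=1.0, w_dec=1.0, w_pred=1.0):
        teacher.eval()
        with torch.no_grad():                       # the trained teacher is frozen
            t_enc, t_dec, t_pred = teacher(images)  # teacher feature group
        s_enc, s_dec, s_pred = student(images)      # student feature group

        # Distillation loss function group: encoding, decoding, and prediction result losses.
        distill_loss = (w_enc * F.mse_loss(s_enc, t_enc)
                        + w_dec * F.mse_loss(s_dec, t_dec)
                        + w_pred * F.mse_loss(s_pred, t_pred))
        loss = task_loss_fn(s_pred, targets) + distill_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()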
It can be understood that, referring to Figure 2, step S2000 of obtaining the distillation loss function group according to the training samples, the student network model, and the teacher network model includes, but is not limited to, the following steps:
Step S2100: input the training samples into the student network model to obtain a student feature group;
Step S2200: input the training samples into the teacher network model to obtain a teacher feature group;
Step S2300: obtain the distillation loss function group according to the student feature group and the teacher feature group.
The same set of training samples is input into the student network model and the teacher network model respectively, the features of the two models are extracted by knowledge distillation to construct the distillation loss function group, and the student network model is trained with this distillation loss function group, which perceives multiple layers of the networks. This effectively improves the performance of the trained student network model, giving it high prediction accuracy.
It can be understood that the student network model includes a student encoder and a student decoder, and the student feature group includes student encoding features, student decoding features, and student prediction result features. Step S2100 of inputting the training samples into the student network model to obtain the student feature group includes, but is not limited to, the following step:
Step S2110: input the training samples into the student network model to obtain the student feature group including the student encoding features, the student decoding features, and the student prediction result features.
By configuring the student network model to omit a pre-trained student backbone model and composing it of only the student encoder and the student decoder, the structure of the student network model is significantly simplified and highly lightweight. When the trained student network model is used for prediction, its prediction speed is effectively improved, so that it is fast when used for plane detection and restoration. The intermediate network features also have learning potential; iteratively training the student network model with a distillation loss function group containing three distillation loss functions, that is, through a step-by-step distillation process, helps alleviate the negative impact of hard association, and the trained student network model finally obtained can deliver both real-time and high-accuracy prediction performance.
It can be understood that, referring to Figure 3, step S2110 of inputting the training samples into the student network model to obtain the student feature group including the student encoding features, the student decoding features, and the student prediction result features includes, but is not limited to, the following steps:
Step S2111: input the training samples into the student encoder for downsampling encoding to obtain the student encoding features;
Step S2112: obtain fusion feature layers according to the downsampling encoding process of the student encoder;
Step S2113: convolve the student encoding features to obtain the student decoding features;
Step S2114: input the student decoding features and the fusion feature layers into the student decoder, where the student decoding features and the fusion feature layers are first fused and then upsampled and decoded to obtain the student prediction result features.
The student encoder performs downsampling encoding on the input data of the training samples. During downsampling, a fast downsampling strategy can be adopted, and feature extraction and recognition with a sufficiently large receptive field effectively improves recognition speed. The downsampling operation, however, causes a loss of spatial information, and this lost information cannot be recovered in subsequent processing. By extracting corresponding features during the student encoder's downsampling process as fusion feature layers, to be fused with the corresponding features in the student decoder's upsampling decoding process, the spatial information lost during downsampling is compensated, which effectively ensures the reliability of the student prediction result features obtained after upsampling and decoding by the student decoder.
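For illustration only, a minimal encoder-decoder of this kind can be sketched as follows; the channel widths, number of downsampling stages, and output head are placeholders chosen for the sketch, not values taken from the embodiments:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class StudentNet(nn.Module):
        # Minimal encoder-decoder sketch with fusion feature layers; channel widths are illustrative.
        def __init__(self, chs=(16, 32, 64, 128), out_ch=1):
            super().__init__()
            enc, in_c = [], 3
            for c in chs:  # each encoder stage halves the resolution (downsampling encoding)
                enc.append(nn.Sequential(nn.Conv2d(in_c, c, 3, 2, 1), nn.ReLU(inplace=True)))
                in_c = c
            self.enc = nn.ModuleList(enc)
            self.mid = nn.Conv2d(chs[-1], chs[-1], 3, padding=1)  # convolution before the decoder
            dec = []
            for c_deep, c_skip in zip(chs[::-1][:-1], chs[::-1][1:]):
                dec.append(nn.Sequential(nn.Conv2d(c_deep + c_skip, c_skip, 3, padding=1),
                                         nn.ReLU(inplace=True)))
            self.dec = nn.ModuleList(dec)
            self.head = nn.Conv2d(chs[0], out_ch, 1)

        def forward(self, x):
            skips = []
            for stage in self.enc:
                x = stage(x)
                skips.append(x)                    # fusion feature layer at each scale
            s_enc = skips[-1]                      # student encoding features
            s_dec = self.mid(s_enc)                # student decoding features
            x = s_dec
            for stage, skip in zip(self.dec, skips[-2::-1]):
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
                x = stage(torch.cat([x, skip], dim=1))   # fuse, then continue upsampling decoding
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
            s_pred = self.head(x)                  # student prediction result features
            return s_enc, s_dec, s_pred

In this sketch the stored skip outputs, the output of the convolution before the decoder, and the final head output play the roles of the fusion feature layers, the student decoding features, and the student prediction result features, respectively.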
It can be understood that the working principle of the student network model is shown in Figure 5. Step S2112 of obtaining the fusion feature layers according to the downsampling encoding process of the student encoder includes, but is not limited to, the following step:
obtaining the fusion feature layer at each scale according to each downsampled intermediate feature map formed during the downsampling encoding process of the student encoder.
Step S2114, in which the student decoding features and the fusion feature layers are input into the student decoder and are first fused and then upsampled and decoded to obtain the student prediction result features, includes, but is not limited to, the following step:
inputting the student decoding features into the student decoder for upsampling decoding, and fusing each fusion feature layer with the upsampled intermediate feature map at the corresponding scale in the student decoder's upsampling decoding process, the student decoder finally outputting the student prediction result features.
When the student encoder downsamples and encodes the input data of the training samples, multiple layers of downsampling are usually required. Once too many downsampling layers are used, most of the spatial information is lost in the downsampling operation, and since this lost information cannot be recovered in subsequent processing, the data provided to the upsampling process is severely distorted and seriously affects the final prediction results. By taking the downsampled shallow features as fusion feature layers and fusing each fusion feature layer with the upsampled deep features at the same scale, spatial details are gradually restored, which in turn effectively guarantees the reliability of the features output by the student decoder.
It can be understood that the working principle of the teacher network model is shown in Figure 6. Step S2200 of inputting the training samples into the teacher network model to obtain the teacher feature group includes, but is not limited to, the following step:
Step S2210: input the training samples into the teacher network model to obtain the teacher feature group including teacher encoding features, teacher decoding features, and teacher prediction result features, where the teacher network model includes a teacher backbone network model, a teacher encoder, and a teacher decoder; the teacher backbone network model outputs the teacher encoding features, the teacher encoding features are input into the teacher encoder, which outputs the teacher decoding features, and the teacher decoding features are input into the teacher decoder, which outputs the teacher prediction result features.
Specifically, in the teacher network model, the training samples are input into the teacher backbone network model, which outputs the teacher encoding features; the teacher encoding features are input into the teacher encoder, which outputs the teacher decoding features; and the teacher decoding features are input into the teacher decoder, which outputs the teacher prediction result features.
It can be understood that, referring to Figure 4, step S2300 of obtaining the distillation loss function group according to the student feature group and the teacher feature group includes, but is not limited to, the following steps:
Step S2310: obtain an encoding loss function according to the student encoding features and the teacher encoding features, where the encoding loss function is used to correct the downsampling encoding of the student encoder so that the student encoder outputs more accurate student encoding features;
Step S2320: obtain a decoding loss function according to the student decoding features and the teacher decoding features, where the decoding loss function is used to correct the convolution before the student decoder to ensure the accuracy of the student decoding features input to the student decoder;
Step S2330: obtain a prediction result loss function according to the student prediction result features and the teacher prediction result features, where the prediction result loss function is used to correct the upsampling decoding of the student decoder so that the student network model outputs more accurate student prediction result features.
The working principle of the knowledge distillation based neural network training method is shown in Figure 7. Features are extracted from the outputs of three network layers in the teacher network model and the corresponding network layers in the student network model, and corresponding distillation loss functions are generated. The student network model is iteratively trained with the three distillation loss functions corresponding to the different network layers, that is, a direct and effective one-to-one matching is achieved between the corresponding levels of the student network model and the teacher network model. This effectively ensures the accuracy of data processing at the corresponding network layers in the student network model, improves the performance of the student network model at the architectural level, and effectively ensures the generalization of the student network model and the accuracy of its predictions.
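Purely as a sketch of how such layer-wise matching can be computed, the following helper compares each student feature with its teacher counterpart; the 1x1 projection layers, their channel widths, and the bilinear resizing are assumptions introduced here so that features of different sizes can be compared, not details recited in the embodiments:

    import torch.nn as nn
    import torch.nn.functional as F

    def feature_match_loss(s_feat, t_feat, adapter=None):
        # Optionally project the student channels to the teacher's width, then align
        # the spatial size before the L2 comparison.
        if adapter is not None:
            s_feat = adapter(s_feat)
        if s_feat.shape[-2:] != t_feat.shape[-2:]:
            s_feat = F.interpolate(s_feat, size=t_feat.shape[-2:],
                                   mode="bilinear", align_corners=False)
        return F.mse_loss(s_feat, t_feat)

    # Hypothetical channel widths; the real values depend on the chosen backbones.
    adapt_enc = nn.Conv2d(128, 768, kernel_size=1)
    adapt_dec = nn.Conv2d(64, 256, kernel_size=1)

    def distillation_loss_group(s_enc, s_dec, s_pred, t_enc, t_dec, t_pred):
        return (feature_match_loss(s_enc, t_enc, adapt_enc),   # encoding loss function
                feature_match_loss(s_dec, t_dec, adapt_dec),   # decoding loss function
                feature_match_loss(s_pred, t_pred))            # prediction result loss function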
A knowledge distillation architecture with multiple student networks may adopt a critical-learning-aware KD (Knowledge Distillation) scheme that ensures the formation of key connections and allows the teacher's information flow to be imitated effectively, rather than merely training a single student. Allowing direct and effective one-to-one matching between the corresponding levels of the student network model and the teacher network model, the teacher network model and the student network model are adaptively divided into three parts, each part is assigned the adaptive parameters of the network layers it contains, and knowledge distillation learning is performed. Semantic correction of shallow-layer feature associations significantly improves the effectiveness of feature knowledge transfer, and cross-layer rectification with an attention mechanism alleviates the problem of semantic mismatch.
It can be understood that step S3000 of training the student network model according to the distillation loss function group to obtain the trained student network model includes, but is not limited to, the following step:
correcting the downsampling encoding of the student encoder according to the encoding loss function, correcting the convolution before the student decoder according to the decoding loss function, and correcting the upsampling decoding of the student decoder according to the prediction result loss function; the student network model is trained in this way to obtain the trained student network model.
While being trained on the task itself, the student network is additionally trained with the distillation loss function group containing multiple intermediate-layer loss functions, and transfer learning of the intermediate feature layers strengthens the student network's estimation performance. Specifically, in the encoding dimension, the decoding dimension, and the prediction result dimension, this ensures that the teacher network model provides more reliable parameter learning for the student network model when capacity underflows.
It can be understood that a teacher network based on transformer modules can perform detection over the global region, so the teacher network model is built on transformer modules, with the HR-Net model serving as the teacher backbone network model for feature extraction and generating high-dimensional, low-scale features as patch embeddings. The HR-Net model is a high-resolution network. With a patch embedding size of p, an H×W pixel image is divided into a set of feature patch embeddings {S_0, S_1, S_2, …}, with S_0 ∈ R^D and so on, where R^D is the feature space output by the teacher backbone network model and the number of feature patches is N = H·W/p². These embeddings are finally fed into a transformer module with 12 layers in total. The teacher network model includes a depth estimation branch, which takes the multi-scale features of the teacher backbone network model and the teacher encoding features as input sources and estimates the image depth through a top-down decoding structure. This structure uses bilinear-interpolation upsampling modules, and the feature map after each upsampling corresponds to a feature scale of the teacher backbone network model, that is, a 2× upsampling mechanism is applied to estimate the image depth, with the teacher backbone network model outputting the corresponding feature dimensions.
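As a purely illustrative sketch of this patch-embedding step (the input channel count, embedding dimension D, and patch size p below are placeholders rather than values fixed by the disclosure), the division of an H×W image into N = H·W/p² patch embeddings can be implemented with a strided convolution:

    import torch
    import torch.nn as nn

    class PatchEmbed(nn.Module):
        # Splits an H x W input into p x p patches projected to a D-dimensional embedding.
        def __init__(self, in_ch=3, embed_dim=768, patch=16):
            super().__init__()
            self.patch = patch
            self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)

        def forward(self, x):
            b, _, h, w = x.shape
            tokens = self.proj(x).flatten(2).transpose(1, 2)   # shape (B, N, D)
            # Number of patch embeddings: N = H*W / p^2 (H and W assumed divisible by p).
            assert tokens.shape[1] == (h * w) // (self.patch ** 2)
            return tokens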
It can be understood that the final outputs of the teacher network model and the student network model are corrected with an L2 loss function. Since there is no maximum-value function in the network, the L2 loss function is applied to the features before the last activation layer of the corresponding network model. When training the student network model, the distillation loss function group together with the L2 loss function achieves a more reliable correction effect, and when applied to plane restoration and reconstruction, the trained student network model has higher prediction accuracy.
The knowledge distillation based neural network training method of the first aspect of the present invention is described in detail below with a specific embodiment. It should be understood that the following description is only illustrative and does not specifically limit the invention.
An untrained student network model and a trained teacher network model are constructed, where the teacher network model is designed on the basis of transformer modules and includes a teacher backbone network model, a teacher encoder, and a teacher decoder, the teacher backbone network model using the HR-Net model; the student network model includes a student encoder and a student decoder, and the setting of a backbone network model is omitted from the student network model;
The training samples are input into the student encoder for downsampling encoding to obtain student encoding features; mobilenet-v3, which is a lightweight network, is used as the feature extractor; fusion feature layers are obtained from the downsampled intermediate feature maps of the student encoder's downsampling encoding process; and, according to the fusion feature layers, the student encoding features are input into the student decoder for upsampling decoding to obtain student decoding features;
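A possible realization of this lightweight feature extractor is sketched below; it assumes a recent torchvision is available, and which intermediate blocks are tapped for fusion feature layers is a free design choice of the sketch rather than something fixed by the embodiment:

    import torch
    import torchvision

    # mobilenet_v3_small from torchvision used as the student feature extractor (weights=None for brevity).
    backbone = torchvision.models.mobilenet_v3_small(weights=None).features

    def extract_multiscale(img):
        per_scale, x = {}, img
        for block in backbone:
            x = block(x)
            per_scale[x.shape[-1]] = x        # keep the deepest feature map seen at each resolution
        return list(per_scale.values())       # shallow-to-deep fusion feature layers

    feats = extract_multiscale(torch.randn(1, 3, 224, 224))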
The training samples are input into the teacher network model, and the teacher encoding features of the teacher backbone network model, the teacher decoding features output by the teacher encoder, and the teacher prediction result features output by the teacher decoder are obtained; the student encoding features output by the student encoder, the student decoding features obtained by convolution before the student decoder, and the student prediction result features output by the student decoder are obtained;
The encoding loss function is obtained according to the student encoding features and the teacher encoding features, the decoding loss function is obtained according to the student decoding features and the teacher decoding features, and the prediction result loss function is obtained according to the student prediction result features and the teacher prediction result features;
The downsampling encoding of the student encoder is corrected according to the encoding loss function, the convolution before the student decoder is corrected according to the decoding loss function, and the upsampling decoding of the student decoder is corrected according to the prediction result loss function; the student network model is trained in this way to obtain a trained student network model, which can be used to implement plane restoration and reconstruction.
In operation, the student network model works as follows: the training samples are input into the student encoder for downsampling encoding to obtain the student encoding features; the fusion feature layer at each scale is obtained from each downsampled intermediate feature map generated during the student encoder's downsampling encoding process; the student encoding features are convolved to obtain the student decoding features; the student decoding features are input into the student decoder for upsampling decoding, each fusion feature layer is fused with the upsampled intermediate feature map at the corresponding scale in the student decoder, and the student decoder outputs the student prediction result features. At each decoding stage in the student decoder, a feature fusion module concatenates the same-scale shallow features, that is, the fusion feature layers, with the upsampled intermediate feature maps of the upsampling decoding process, at resolutions of 1/32, 1/16, 1/8, 1/4, and 1/2 respectively, which ensures that features of the same scale have the same feature channels after each feature fusion; finally, transfer learning is performed separately on the student encoding features, the student decoding features, and the student prediction result features.
In addition, an embodiment of the second aspect of the present invention further provides an electronic device, which includes a memory, a processor, and a computer program stored in the memory and executable on the processor.
The processor and the memory may be connected by a bus or in other ways.
As a non-transitory computer-readable storage medium, the memory can be used to store non-transitory software programs and non-transitory computer-executable programs. In addition, the memory may include high-speed random access memory and may also include non-transitory memory, such as at least one magnetic disk storage device, a flash memory device, or another non-transitory solid-state storage device. In some implementations, the memory may optionally include memory located remotely from the processor, and such remote memory may be connected to the processor via a network. Examples of such networks include, but are not limited to, the Internet, intranets, local area networks, mobile communication networks, and combinations thereof.
The non-transitory software programs and instructions required to implement the knowledge distillation based neural network training method of the embodiment of the first aspect are stored in the memory, and when executed by the processor, they perform the knowledge distillation based neural network training method of the above embodiments, for example, performing the method steps S1000 to S3000, method steps S2100 to S2300, method step S2110, method steps S2111 to S2114, method step S2210, and method steps S2310 to S2330 described above.
The device embodiments described above are merely illustrative, and the units described as separate components may or may not be physically separate, that is, they may be located in one place or distributed over multiple network units. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment.
In addition, an embodiment of the present invention further provides a computer-readable storage medium storing computer-executable instructions which, when executed by a processor or controller, for example by a processor in the above device embodiment, cause the processor to perform the knowledge distillation based neural network training method of the above embodiments, for example, performing the method steps S1000 to S3000, method steps S2100 to S2300, method step S2110, method steps S2111 to S2114, method step S2210, and method steps S2310 to S2330 described above.
Those of ordinary skill in the art will understand that all or some of the steps and systems of the methods disclosed above may be implemented as software, firmware, hardware, and suitable combinations thereof. Some or all of the physical components may be implemented as software executed by a processor such as a central processing unit, a digital signal processor, or a microprocessor, as hardware, or as an integrated circuit such as an application-specific integrated circuit. Such software may be distributed on computer-readable media, which may include computer storage media (or non-transitory media) and communication media (or transitory media). As is well known to those of ordinary skill in the art, the term computer storage media includes volatile and non-volatile, removable and non-removable media implemented in any method or technology for the storage of information such as computer-readable instructions, data structures, program modules, or other data. Computer storage media include, but are not limited to, RAM, ROM, EEPROM, flash memory or other memory technology, CD-ROM, digital versatile disks (DVD) or other optical disk storage, magnetic cassettes, magnetic tape, magnetic disk storage or other magnetic storage devices, or any other medium that can be used to store the desired information and that can be accessed by a computer. Furthermore, as is well known to those of ordinary skill in the art, communication media typically embody computer-readable instructions, data structures, program modules, or other data in a modulated data signal such as a carrier wave or other transport mechanism, and may include any information delivery media.
In the description of this specification, a description with reference to the terms "one embodiment", "some embodiments", "an illustrative embodiment", "an example", "a specific example", or "some examples" means that a specific feature, structure, material, or characteristic described in connection with that embodiment or example is included in at least one embodiment or example of the present invention. In this specification, schematic expressions of the above terms do not necessarily refer to the same embodiment or example. Moreover, the specific features, structures, materials, or characteristics described may be combined in a suitable manner in any one or more embodiments or examples.
Although embodiments of the present invention have been shown and described, those of ordinary skill in the art will understand that various changes, modifications, substitutions, and variations can be made to these embodiments without departing from the principles and purposes of the present invention, the scope of which is defined by the claims and their equivalents.

Claims (10)

  1. A knowledge distillation based neural network training method, characterized by comprising the following steps:
    constructing an untrained student network model and a trained teacher network model;
    obtaining a distillation loss function group according to training samples, the student network model, and the teacher network model, wherein the distillation loss function group comprises an encoding loss function, a decoding loss function, and a prediction result loss function; and
    training the student network model according to the distillation loss function group to obtain a trained student network model.
  2. The knowledge distillation based neural network training method according to claim 1, characterized in that the obtaining a distillation loss function group according to training samples, the student network model, and the teacher network model comprises:
    inputting the training samples into the student network model to obtain a student feature group;
    inputting the training samples into the teacher network model to obtain a teacher feature group; and
    obtaining the distillation loss function group according to the student feature group and the teacher feature group.
  3. The knowledge distillation based neural network training method according to claim 2, characterized in that the inputting the training samples into the student network model to obtain a student feature group comprises:
    inputting the training samples into the student network model to obtain the student feature group comprising student encoding features, student decoding features, and student prediction result features.
  4. The knowledge distillation based neural network training method according to claim 3, characterized in that the inputting the training samples into the teacher network model to obtain a teacher feature group comprises:
    inputting the training samples into the teacher network model to obtain the teacher feature group comprising teacher encoding features, teacher decoding features, and teacher prediction result features, wherein the teacher network model comprises a teacher backbone network model, a teacher encoder, and a teacher decoder, the teacher backbone network model outputs the teacher encoding features, the teacher encoder outputs the teacher decoding features, and the teacher decoder outputs the teacher prediction result features.
  5. The knowledge distillation based neural network training method according to claim 4, characterized in that the obtaining the distillation loss function group according to the student feature group and the teacher feature group comprises:
    obtaining an encoding loss function according to the student encoding features and the teacher encoding features;
    obtaining a decoding loss function according to the student decoding features and the teacher decoding features; and
    obtaining a prediction result loss function according to the student prediction result features and the teacher prediction result features.
  6. The knowledge distillation based neural network training method according to claim 3, characterized in that the student network model comprises a student encoder and a student decoder; and
    the inputting the training samples into the student network model to obtain the student feature group comprising student encoding features, student decoding features, and student prediction result features comprises:
    inputting the training samples into the student encoder for downsampling encoding to obtain the student encoding features;
    obtaining fusion feature layers according to the downsampling encoding process of the student encoder;
    convolving the student encoding features to obtain the student decoding features; and
    inputting the student decoding features and the fusion feature layers into the student decoder, where they are first fused and then upsampled and decoded to obtain the student prediction result features.
  7. The knowledge distillation based neural network training method according to claim 6, characterized in that the obtaining fusion feature layers according to the downsampling encoding process of the student encoder comprises:
    obtaining the fusion feature layer at each scale according to each downsampled intermediate feature map formed by the downsampling encoding process of the student encoder.
  8. The knowledge distillation based neural network training method according to claim 7, characterized in that the inputting the student decoding features and the fusion feature layers into the student decoder, where they are first fused and then upsampled and decoded to obtain the student prediction result features, comprises:
    inputting the student decoding features into the student decoder for upsampling decoding, and fusing each fusion feature layer with the upsampled intermediate feature map at the corresponding scale in the upsampling decoding process of the student decoder, the student decoder outputting the student prediction result features.
  9. An electronic device, characterized by comprising:
    a memory, a processor, and a computer program stored in the memory and executable on the processor, wherein the processor, when executing the computer program, implements the knowledge distillation based neural network training method according to any one of claims 1 to 8.
  10. A computer storage medium, characterized in that it stores computer-executable instructions for performing the knowledge distillation based neural network training method according to any one of claims 1 to 8.
PCT/CN2022/098769 2022-05-05 2022-06-14 Knowledge distillation based neural network training method, device, and storage medium WO2023212997A1 (en)

Applications Claiming Priority (4)

Application Number Priority Date Filing Date Title
CN202210479268.2 2022-05-05
CN202210479268 2022-05-05
CN202210646401.9A CN114936605A (en) 2022-06-09 2022-06-09 Knowledge distillation-based neural network training method, device and storage medium
CN202210646401.9 2022-06-09

Publications (1)

Publication Number Publication Date
WO2023212997A1 true WO2023212997A1 (en) 2023-11-09

Family

ID=88646181

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2022/098769 WO2023212997A1 (en) 2022-05-05 2022-06-14 Knowledge distillation based neural network training method, device, and storage medium

Country Status (1)

Country Link
WO (1) WO2023212997A1 (en)

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190046068A1 (en) * 2017-08-10 2019-02-14 Siemens Healthcare Gmbh Protocol independent image processing with adversarial networks
US20200125927A1 (en) * 2018-10-22 2020-04-23 Samsung Electronics Co., Ltd. Model training method and apparatus, and data recognition method
CN110852426A (en) * 2019-11-19 2020-02-28 成都晓多科技有限公司 Pre-training model integration acceleration method and device based on knowledge distillation
CN111950302A (en) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 Knowledge distillation-based machine translation model training method, device, equipment and medium
CN111932561A (en) * 2020-09-21 2020-11-13 深圳大学 Real-time enteroscopy image segmentation method and device based on integrated knowledge distillation
CN112992308A (en) * 2021-03-25 2021-06-18 腾讯科技(深圳)有限公司 Training method of medical image report generation model and image report generation method
CN114529796A (en) * 2022-01-30 2022-05-24 北京百度网讯科技有限公司 Model training method, image recognition method, device and electronic equipment

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117557857A (en) * 2023-11-23 2024-02-13 哈尔滨工业大学 Detection network light weight method combining progressive guided distillation and structural reconstruction
CN117421678A (en) * 2023-12-19 2024-01-19 西南石油大学 Single-lead atrial fibrillation recognition system based on knowledge distillation
CN117425013A (en) * 2023-12-19 2024-01-19 杭州靖安防务科技有限公司 Video transmission method and system based on reversible architecture
CN117421678B (en) * 2023-12-19 2024-03-22 西南石油大学 Single-lead atrial fibrillation recognition system based on knowledge distillation
CN117425013B (en) * 2023-12-19 2024-04-02 杭州靖安防务科技有限公司 Video transmission method and system based on reversible architecture

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 22940691

Country of ref document: EP

Kind code of ref document: A1