CN116385844B - 一种基于多教师模型的特征图蒸馏方法、系统和存储介质 - Google Patents
一种基于多教师模型的特征图蒸馏方法、系统和存储介质 Download PDFInfo
- Publication number
- CN116385844B CN116385844B CN202211598032.7A CN202211598032A CN116385844B CN 116385844 B CN116385844 B CN 116385844B CN 202211598032 A CN202211598032 A CN 202211598032A CN 116385844 B CN116385844 B CN 116385844B
- Authority
- CN
- China
- Prior art keywords
- model
- teacher
- training sample
- student
- stage
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 57
- 238000000034 method Methods 0.000 title claims abstract description 40
- 238000010586 diagram Methods 0.000 claims abstract description 129
- 238000003709 image segmentation Methods 0.000 claims description 12
- 238000001514 detection method Methods 0.000 claims description 11
- 238000013145 classification model Methods 0.000 claims description 10
- 238000010276 construction Methods 0.000 claims description 5
- 230000000694 effects Effects 0.000 abstract description 9
- 230000006870 function Effects 0.000 description 53
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000010606 normalization Methods 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000013140 knowledge distillation Methods 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000004590 computer program Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011423 initialization method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 210000001747 pupil Anatomy 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Medical Informatics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本发明公开了一种基于多教师模型的特征图蒸馏方法、系统和存储介质,包括:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。本发明通过多个教师模型对学生模型进行特征图蒸馏,在提高学生模型对于图像内容识别效果的同时,避免了因单个教师模型的特征图蒸馏所造成的偏差。
Description
技术领域
本发明涉及知识蒸馏技术领域,尤其涉及一种基于多教师模型的特征图蒸馏方法、系统和存储介质。
背景技术
深度卷积神经网络是目前图像内容识别中应用最广泛的深度学习技术,然而部署千万级别参数量的模型会耗费高额的成本。知识蒸馏是一种模型压缩技术,它通过一个大的教师模型指导小的学生模型训练,可以使小的模型达到与大模型接近的效果,在保证效果的前提下可以大大节省成本。
特征图蒸馏技术相比于Logit蒸馏,能够在早期为学生模型提供更多的指导信息,因此特征图蒸馏效果也是更好的。但使用一个教师模型对学生进行特征图蒸馏,会让学生模型学到的特征局限于教师模型的特征空间内,导致结果的偏差,既教师模型识别错误的图片学生模型通常也会识别错误。
因此,亟需提供一种技术方案解决上述技术问题。
发明内容
为解决上述技术问题,本发明提供了一种基于多教师模型的特征图蒸馏方法、系统和存储介质。
本发明的一种基于多教师模型的特征图蒸馏方法的技术方案如下:
利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;其中,所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数;
将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;
基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。
本发明的一种基于多教师模型的特征图蒸馏方法的有益效果如下:
本发明的方法通过多个教师模型对学生模型进行特征图蒸馏,在提高学生模型对于图像内容识别效果的同时,避免了因单个教师模型的特征图蒸馏所造成的偏差。
在上述方案的基础上,本发明的一种基于多教师模型的特征图蒸馏方法还可以做如下改进。
进一步,所述目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;
其中,Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数,
其中,j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数;/> 表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/>表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征。
进一步,获取所述任一训练样本在所述学生模型所对应的任一阶段的多个第一特征图的步骤,包括:
将所述任一训练样本输入至所述学生模型中,得到该训练样本在所述任一阶段的阶段特征图,并基于教师模型的数量,将该阶段特征图进行均分,得到该阶段特征图对应的多个均分特征图,并对该阶段特征图对应的每个均分特征图分别进行压缩和标准化处理,得到该训练样本在该阶段的多个第一特征图。
进一步,获取所述任一训练样本在任一教师模型所对应的任一阶段的第二特征图的步骤,包括:
对所述任一训练样本在所述任一教师模型中所对应的所述任一阶段的阶段特征图进行压缩和标准化处理,得到该训练样本在该阶段的第二特征图。
进一步,所述基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型的步骤,包括:
将任一训练样本输入至所述学生模型中,得到该训练样本在所述学生模型中所对应的所有阶段的多个第一特征图,并基于所述目标损失函数、所述任一训练样本在所述学生模型中所对应的所有阶段的多个第一特征图和所述任一训练样本在每个教师模型中所对应的所有阶段的第二特征图,得到该训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述学生模型的参数进行优化,得到优化后的学生模型,将所述优化后的学生模型作为所述学生模型并返回执行将任一训练样本输入至所述学生模型中的步骤,直至满足预设迭代训练条件时,将所述优化后的学生模型确定为所述训练好的学生模型。
进一步,所述学生模型和每个教师模型的类型相同,所述学生模型和所有的教师模型中的任一模型的类型为:图像分割模型、图像分类模型或目标检测模型。
进一步,还包括:
当所述学生模型为图像分割模型时,将待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分割结果;或,当所述学生模型为图像分类模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分类结果;或,当所述学生模型为目标检测模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的目标检测结果。
本发明的一种基于多教师模型的特征图蒸馏系统的技术方案如下:
包括:构建模块、处理模块和运行模块;
所述构建模块用于:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;其中,所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数;
所述处理模块用于:将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;
所述运行模块用于:基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。
本发明的一种基于多教师模型的特征图蒸馏系统的有益效果如下:
本发明的系统通过多个教师模型对学生模型进行特征图蒸馏,在提高学生模型对于图像内容识别效果的同时,避免了因单个教师模型的特征图蒸馏所造成的偏差。
在上述方案的基础上,本发明的一种基于多教师模型的特征图蒸馏系统还可以做如下改进。
进一步,所述目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;
其中,Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数,
其中,j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数;/> 表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/>表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征。
本发明的一种存储介质的技术方案如下:
存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如本发明的一种基于多教师模型的特征图蒸馏方法的步骤。
附图说明
图1示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第一实施例的流程示意图;
图2示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第一实施例中训练过程的整体示意图;
图3示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第一实施例中步骤130的流程示意图;
图4示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第二实施例的流程示意图;
图5示出了本发明提供的一种基于多教师模型的特征图蒸馏系统的实施例的结构示意图。
具体实施方式
图1示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第一实施例的流程示意图。如图1所示,包括如下步骤:
步骤110:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数。
其中,①所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数。②教师模型为:训练好的神经网络模型,其类型包括但不限于:图像分类模型、图像分割模型和目标检测模型等。③学生模型为:同教师模型作用相同的模型,如教师模型和学生模型均为图像分类模型。④教师模型和学生模型中均包含多个阶段(stage),每一阶段均需要进行特征图蒸馏。
需要说明的是,①选用的教师模型需要比学生模型的效果更好,如用于图像分类的教师模型的准确率为95%,则用于图像分类的学生模型的准确率则应当低于95%。②不同教师模型之间可以是不同的神经网络结构,也可以是相同的神经网络结构。当两个或多个教师模型采用相同的神经网络结构时,这两个或多个教师模型应当是通过不同超参数(不同的学习率、不同的初始化方法、不同的优化器)训练的模型。
步骤120:将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图。
其中,①训练样本为:任意选取的用于学生模型训练的图像,该图像具有标注信息。②阶段特征图包括:每个教师模型在相应阶段的特征图以及学生模型在相应阶段的特征图。假设教师模型(或学生模型)均包含5个阶段(stage),将任一图像(训练样本或待测图像)输入至教师模型或学生模型后,均可得到该图像在相应模型中所对应的5个阶段的阶段特征图。
步骤130:基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。Lossglobal
其中,①目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数, j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数;/>表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/>表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征。②训练好的学生模型为:经过多次迭代训练所得到的神经网络模型,该模型(经过多阶段的特征图蒸馏后的模型)的效果高于直接通过训练样本训练的学生模型的效果。
较优地,获取所述任一训练样本在所述学生模型所对应的任一阶段的多个第一特征图的步骤,包括:
将所述任一训练样本输入至所述学生模型中,得到该训练样本在所述任一阶段的阶段特征图,并基于教师模型的数量,将该阶段特征图进行均分,得到该阶段特征图对应的多个均分特征图,并对该阶段特征图对应的每个均分特征图分别进行压缩和标准化处理,得到该训练样本在该阶段的多个第一特征图。
其中,①训练样本在学生模型的每一阶段的阶段特征图的数量为1个,训练样本在学生模型的每一阶段的均分特征图和第一特征图的数量为多个,该数量与教师模型的数量相同。例如,当教师模型为2个时,则每一阶段的均分特征图和第一特征图的数量均为2个。②对均分特征图进行压缩处理的过程为:对均分特征图进行通道维度池化处理,以将该均分特征图进行压缩,得到压缩特征图。③对压缩特征图进行标准化处理的过程为:对压缩特征图进行L2标准化处理,得到第一特征图。
需要说明的是,L2标准化处理是把特征图像素值缩放至0-1的归一化手段,具体过程为现有技术,在此不过多赘述。
较优地,获取所述任一训练样本在任一教师模型所对应的任一阶段的第二特征图的步骤,包括:
对所述任一训练样本在所述任一教师模型中所对应的所述任一阶段的阶段特征图进行压缩和标准化处理,得到该训练样本在该阶段的第二特征图。
其中,对根据教师模型所得到的阶段特征图进行压缩和标准化处理的过程与对根据学生模型所得到的阶段特征图进行压缩和标准化处理的过程相同,在此不过多赘述。
具体地,图2示出了本实施例中的学生模型的训练过程示意图。如图2所示,将训练样本分别输入至两个教师模型和学生模型。在两个教师模型和学生模型的第k阶段,将训练样本在第一个教师模型所对应的第k阶段的阶段特征图进行压缩和标准化处理,得到该训练样本在第一个教师模型所对应的第k阶段的第二特征图。重复上述方式,得到该训练样本在第二个教师模型所对应的第k阶段的第二特征图和该训练样本在学生模型所对应的第k阶段的两个第一特征图。这两个第一特征图是根据相应的阶段特征图进行均分并经过压缩和标准化处理后所得到的特征图。此时,将训练样本在第一教师模型所对应的第二特征图和在学生模型所对应的一个第一特征图进行比对,得到该训练样本在第一教师模型的第k阶段的特征图蒸馏损失值。将训练样本在第二教师模型所对应的第二特征图和在学生模型所对应的另一个第一特征图进行比对,得到该训练样本在第二教师模型的第k阶段的特征图蒸馏损失值。重复上述方式,直至得到该训练样本在第一教师模型的所有阶段的特征图蒸馏损失值和该样本在第二教师模型的所有阶段的特征图蒸馏损失值。
如图3所示,步骤130包括:
步骤131:将任一训练样本输入至所述学生模型中,得到该训练样本在所述学生模型中所对应的所有阶段的多个第一特征图,并基于所述目标损失函数、所述任一训练样本在所述学生模型中所对应的所有阶段的多个第一特征图和所述任一训练样本在每个教师模型中所对应的所有阶段的第二特征图,得到该训练样本的目标损失值,直至得到每个训练样本的目标损失值。
其中,目标损失值为:根据目标损失函数所得到的损失值。
具体地,将任一训练样本输入至学生模型中,得到该训练样本在所述学生模型中所对应的所有阶段的多个第一特征图,并将所述任一训练样本在所述学生模型中所对应的所有阶段的多个第一特征图和所述任一训练样本在每个教师模型中所对应的所有阶段的第二特征图代入所述目标损失函数,得到该训练样本的目标损失值。重复上述方式,直至得到每个训练样本的目标损失值。
步骤132:基于所有的目标损失值,对所述学生模型的参数进行优化,得到优化后的学生模型,将所述优化后的学生模型作为所述学生模型并返回执行步骤131,直至满足预设迭代训练条件时,将所述优化后的学生模型确定为所述训练好的学生模型。
其中,预设迭代训练条件为:最大迭代训练次数或损失函数收敛。
具体地,基于所有的目标损失值,对所述学生模型的参数进行优化,得到优化后的学生模型,判断优化后的学生模型是否满足预设迭代训练条件;若是,则将所述优化后的学生模型确定为所述训练好的学生模型。若否,则将所述优化后的学生模型作为所述学生模型并返回执行步骤131,直至满足预设迭代训练条件时,将所述优化后的学生模型确定为所述训练好的学生模型。
本实施例的技术方案通过多个教师模型对学生模型进行特征图蒸馏,在提高学生模型对于图像内容识别效果的同时,避免了因单个教师模型的特征图蒸馏所造成的偏差。
图4示出了本发明提供的一种基于多教师模型的特征图蒸馏方法的第二实施例的流程示意图。如图4所示,包括如下步骤:
步骤210:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数。
其中,①所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数。②所述学生模型和每个教师模型的类型相同,所述学生模型和所有的教师模型中的任一模型的类型为:图像分割模型、图像分类模型或目标检测模型。
步骤220:将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图。
步骤230:基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。
步骤240:当所述学生模型为图像分割模型时,将待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分割结果;或,当所述学生模型为图像分类模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分类结果;或,当所述学生模型为目标检测模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的目标检测结果。
其中,①待测图像为:任意选取的图像。②当学生模型为图像分割模型时,待测图像的图像分割结果为:包含待测图像中的每类图像(物体)的分割预测值的图像。③当学生模型为图像分类模型时,待测图像的图像分类结果为:包含待测图像中的每类图像(物体)的分类预测值的图像。④当学生模型为图像分类模型时,待测图像的图像分类结果为:包含待测图像中的待测物所在位置以及该待测物为目标检测物的概率值的图像。
本实施例的技术方案在第一实施例的基础上,进一步通过训练好的学生模型对待测图像进行识别,得到更为精准的识别结果。
图5示出了本发明提供的一种基于多教师模型的特征图蒸馏系统的实施例的结构示意图。如图5所示,该系统300包括:构建模块310、处理模块320和运行模块330。
所述构建模块310用于:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;其中,所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数;
所述处理模块320用于:将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;
所述运行模块330用于:基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型。
较优地,所述目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;
其中,Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数,
其中,j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数;/> 表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/>表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征。
本实施例的技术方案通过多个教师模型对学生模型进行特征图蒸馏,在提高学生模型对于图像内容识别效果的同时,避免了因单个教师模型的特征图蒸馏所造成的偏差。
上述关于本实施例的一种基于多教师模型的特征图蒸馏系统300中的各参数和各个模块实现相应功能的步骤,可参考上文中关于一种基于多教师模型的特征图蒸馏方法的实施例中的各参数和步骤,在此不做赘述。
本发明实施例提供的一种存储介质,包括:存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如一种基于多教师模型的特征图蒸馏方法的步骤,具体可参考上文中一种基于多教师模型的特征图蒸馏方法的实施例中的各参数和步骤,在此不做赘述。
计算机存储介质例如:优盘、移动硬盘等。
所属技术领域的技术人员知道,本发明可以实现为方法、系统和存储介质。
因此,本发明可以具体实现为以下形式,即:可以是完全的硬件、也可以是完全的软件(包括固件、驻留软件、微代码等),还可以是硬件和软件结合的形式,本文一般称为“电路”、“模块”或“系统”。此外,在一些实施例中,本发明还可以实现为在一个或多个计算机可读介质中的计算机程序产品的形式,该计算机可读介质中包含计算机可读的程序代码。可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质。计算机可读存储介质例如可以是但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM),只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本文件中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。
Claims (6)
1.一种基于多教师模型的特征图蒸馏方法,其特征在于,包括:
利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;其中,所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数;
将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;
基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型;
所述目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;
其中,Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数,
其中,j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数; 表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/> 表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征;
获取任一训练样本在所述学生模型所对应的任一阶段的多个第一特征图的步骤,包括:
将所述任一训练样本输入至所述学生模型中,得到该训练样本在所述任一阶段的阶段特征图,并基于教师模型的数量,将该阶段特征图进行均分,得到该阶段特征图对应的多个均分特征图,并对该阶段特征图对应的每个均分特征图分别进行压缩和标准化处理,得到该训练样本在该阶段的多个第一特征图;
所述基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型的步骤,包括:
将任一训练样本输入至所述学生模型中,得到该训练样本在所述学生模型中所对应的所有阶段的多个第一特征图,并基于所述目标损失函数、所述任一训练样本在所述学生模型中所对应的所有阶段的多个第一特征图和所述任一训练样本在每个教师模型中所对应的所有阶段的第二特征图,得到该训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述学生模型的参数进行优化,得到优化后的学生模型,将所述优化后的学生模型作为所述学生模型并返回执行将任一训练样本输入至所述学生模型中的步骤,直至满足预设迭代训练条件时,将所述优化后的学生模型确定为所述训练好的学生模型。
2.根据权利要求1所述的基于多教师模型的特征图蒸馏方法,其特征在于,获取所述任一训练样本在任一教师模型所对应的任一阶段的第二特征图的步骤,包括:
对所述任一训练样本在所述任一教师模型中所对应的所述任一阶段的阶段特征图进行压缩和标准化处理,得到该训练样本在该阶段的第二特征图。
3.根据权利要求1或2所述的基于多教师模型的特征图蒸馏方法,其特征在于,所述学生模型和每个教师模型的类型相同,所述学生模型和所有的教师模型中的任一模型的类型为:图像分割模型、图像分类模型或目标检测模型。
4.根据权利要求3所述的基于多教师模型的特征图蒸馏方法,其特征在于,还包括:
当所述学生模型为图像分割模型时,将待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分割结果;或,当所述学生模型为图像分类模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的图像分类结果;或,当所述学生模型为目标检测模型时,将所述待测图像输入至所述训练好的学生模型,得到所述待测图像的目标检测结果。
5.一种基于多教师模型的特征图蒸馏系统,其特征在于,包括:构建模块、处理模块和运行模块;
所述构建模块用于:利用多个教师模型对学生模型进行多阶段的特征图蒸馏,得到所述学生模型的目标损失函数;其中,所述目标损失函数包括:原始损失函数和每个教师模型在每一阶段的特征图蒸馏损失函数;
所述处理模块用于:将任一训练样本分别输入每个教师模型中,得到该训练样本在每个教师模型中所对应的所有阶段的阶段特征图,直至得到每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图;
所述运行模块用于:基于所述目标损失函数、每个训练样本以及每个训练样本在每个教师模型中所对应的所有阶段的阶段特征图,对所述学生模型进行迭代训练,得到训练好的学生模型;
所述目标损失函数为:Loss=Lossglobal+a1Loss1+a2Loss2+…+anLossn;
其中,Loss为所述目标损失函数,Lossglobal为所述原始损失函数,Loss1为第一个教师模型在所有阶段的特征图蒸馏损失函数,Loss2为第二个教师在所有阶段的特征图蒸馏损失函数,Lossn为第n个教师模型在所有阶段的特征图蒸馏损失函数,a1为第一个教师模型的特征图蒸馏损失的系数,a2为第二个教师模型的特征图蒸馏损失的系数,an为第n个教师模型的特征图蒸馏损失的系数,
其中,j表示第j阶段,k表示总阶段数,/>表示第一个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第一个第一特征图的第i个像素点特征,m表示第j阶段特征图中的像素点的个数; 表示第二个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第二个第一特征图的第i个像素点特征;/> 表示第n个教师模型在第j阶段的第二特征图的第i个像素点特征,/>表示所述学生模型在第j阶段的第n个第一特征图的第i个像素点特征;
还包括:采集模块;
所述采集模块用于:将所述任一训练样本输入至所述学生模型中,得到该训练样本在所述任一阶段的阶段特征图,并基于教师模型的数量,将该阶段特征图进行均分,得到该阶段特征图对应的多个均分特征图,并对该阶段特征图对应的每个均分特征图分别进行压缩和标准化处理,得到该训练样本在该阶段的多个第一特征图;
所述运行模块具体用于:
将任一训练样本输入至所述学生模型中,得到该训练样本在所述学生模型中所对应的所有阶段的多个第一特征图,并基于所述目标损失函数、所述任一训练样本在所述学生模型中所对应的所有阶段的多个第一特征图和所述任一训练样本在每个教师模型中所对应的所有阶段的第二特征图,得到该训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述学生模型的参数进行优化,得到优化后的学生模型,将所述优化后的学生模型作为所述学生模型并返回执行将任一训练样本输入至所述学生模型中的过程,直至满足预设迭代训练条件时,将所述优化后的学生模型确定为所述训练好的学生模型。
6.一种存储介质,其特征在于,所述存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如权利要求1至4中任一项所述的基于多教师模型的特征图蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211598032.7A CN116385844B (zh) | 2022-12-12 | 2022-12-12 | 一种基于多教师模型的特征图蒸馏方法、系统和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211598032.7A CN116385844B (zh) | 2022-12-12 | 2022-12-12 | 一种基于多教师模型的特征图蒸馏方法、系统和存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116385844A CN116385844A (zh) | 2023-07-04 |
CN116385844B true CN116385844B (zh) | 2023-11-10 |
Family
ID=86979294
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211598032.7A Active CN116385844B (zh) | 2022-12-12 | 2022-12-12 | 一种基于多教师模型的特征图蒸馏方法、系统和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116385844B (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112560693A (zh) * | 2020-12-17 | 2021-03-26 | 华中科技大学 | 基于深度学习目标检测的高速公路异物识别方法和系统 |
CN112734789A (zh) * | 2021-01-28 | 2021-04-30 | 重庆兆琨智医科技有限公司 | 一种基于半监督学习和点渲染的图像分割方法及系统 |
CN112949766A (zh) * | 2021-04-07 | 2021-06-11 | 成都数之联科技有限公司 | 目标区域检测模型训练方法及系统及装置及介质 |
CN114298224A (zh) * | 2021-12-29 | 2022-04-08 | 云从科技集团股份有限公司 | 图像分类方法、装置以及计算机可读存储介质 |
KR20220096099A (ko) * | 2020-12-30 | 2022-07-07 | 성균관대학교산학협력단 | 지식 증류에서 총 cam 정보를 이용한 교사 지원 어텐션 전달의 학습 방법 및 장치 |
CN115204412A (zh) * | 2022-07-15 | 2022-10-18 | 润联软件系统(深圳)有限公司 | 基于知识蒸馏的问答模型压缩方法、装置及相关设备 |
-
2022
- 2022-12-12 CN CN202211598032.7A patent/CN116385844B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112560693A (zh) * | 2020-12-17 | 2021-03-26 | 华中科技大学 | 基于深度学习目标检测的高速公路异物识别方法和系统 |
KR20220096099A (ko) * | 2020-12-30 | 2022-07-07 | 성균관대학교산학협력단 | 지식 증류에서 총 cam 정보를 이용한 교사 지원 어텐션 전달의 학습 방법 및 장치 |
CN112734789A (zh) * | 2021-01-28 | 2021-04-30 | 重庆兆琨智医科技有限公司 | 一种基于半监督学习和点渲染的图像分割方法及系统 |
CN112949766A (zh) * | 2021-04-07 | 2021-06-11 | 成都数之联科技有限公司 | 目标区域检测模型训练方法及系统及装置及介质 |
CN114298224A (zh) * | 2021-12-29 | 2022-04-08 | 云从科技集团股份有限公司 | 图像分类方法、装置以及计算机可读存储介质 |
CN115204412A (zh) * | 2022-07-15 | 2022-10-18 | 润联软件系统(深圳)有限公司 | 基于知识蒸馏的问答模型压缩方法、装置及相关设备 |
Also Published As
Publication number | Publication date |
---|---|
CN116385844A (zh) | 2023-07-04 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111860573B (zh) | 模型训练方法、图像类别检测方法、装置和电子设备 | |
CN109189767B (zh) | 数据处理方法、装置、电子设备及存储介质 | |
CN111259625A (zh) | 意图识别方法、装置、设备及计算机可读存储介质 | |
CN113673346B (zh) | 一种基于多尺度SE-Resnet的电机振动数据处理与状态识别方法 | |
CN110826558B (zh) | 图像分类方法、计算机设备和存储介质 | |
CN113065525B (zh) | 年龄识别模型训练方法、人脸年龄识别方法及相关装置 | |
CN115482418B (zh) | 基于伪负标签的半监督模型训练方法、系统及应用 | |
CN110929524A (zh) | 数据筛选方法、装置、设备及计算机可读存储介质 | |
CN111680753A (zh) | 一种数据标注方法、装置、电子设备及存储介质 | |
CN114971375A (zh) | 基于人工智能的考核数据处理方法、装置、设备及介质 | |
CN118506846A (zh) | 一种硬盘测试装置、系统及方法 | |
CN117217277A (zh) | 语言模型的预训练方法、装置、设备、存储介质及产品 | |
CN117975464A (zh) | 基于U-Net的电气二次图纸文字信息的识别方法及系统 | |
CN116778300B (zh) | 一种基于知识蒸馏的小目标检测方法、系统和存储介质 | |
CN112464966B (zh) | 鲁棒性估计方法、数据处理方法和信息处理设备 | |
CN116385844B (zh) | 一种基于多教师模型的特征图蒸馏方法、系统和存储介质 | |
CN116504230A (zh) | 数据闭环方法、装置、计算机设备及计算机可读存储介质 | |
CN115984640A (zh) | 一种基于组合蒸馏技术的目标检测方法、系统和存储介质 | |
CN113378866B (zh) | 图像分类方法、系统、存储介质及电子设备 | |
CN115049546A (zh) | 样本数据处理方法、装置、电子设备及存储介质 | |
CN113987254A (zh) | 基于计算机视觉的图书图像检索方法 | |
CN110942179A (zh) | 一种自动驾驶路线规划方法、装置及车辆 | |
CN110555338A (zh) | 对象识别方法和装置、神经网络生成方法和装置 | |
CN116416456B (zh) | 基于自蒸馏的图像分类方法、系统、存储介质和电子设备 | |
CN116431757B (zh) | 基于主动学习的文本关系抽取方法、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |