CN114819076A - 网络蒸馏方法、装置、计算机设备、存储介质 - Google Patents
网络蒸馏方法、装置、计算机设备、存储介质 Download PDFInfo
- Publication number
- CN114819076A CN114819076A CN202210428589.XA CN202210428589A CN114819076A CN 114819076 A CN114819076 A CN 114819076A CN 202210428589 A CN202210428589 A CN 202210428589A CN 114819076 A CN114819076 A CN 114819076A
- Authority
- CN
- China
- Prior art keywords
- network
- loss
- distillation
- loss function
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Feedback Control In General (AREA)
Abstract
本公开涉及一种网络蒸馏方法、装置、计算机设备、存储介质和计算机程序产品。所述方法包括:获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;根据所述网络蒸馏要求,构建蒸馏网络的大网络;根据所述网络蒸馏要求,构建蒸馏网络的小网络;利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。采用本方法能够减小蒸馏训练过程中的精度损失。
Description
技术领域
本公开涉及大数据技术领域,特别是涉及一种网络蒸馏方法、装置、计算机设备、存储介质。
背景技术
随着大数据技术的发展,出现了网络蒸馏技术。网络蒸馏是深度学习的一种压缩加速方式,通过训练好的大网络(或称为教师网络),指导小网络训练(或称为学生网络),减少精度损失。
目前,网络蒸馏技术虽然有所发展,但是网络精度仍然在一个较低的水平,网络蒸馏的精度损失有待降低。
发明内容
基于此,有必要针对上述技术问题,提供一种能够提升网络蒸馏精度的网络蒸馏方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
第一方面,本公开提供了一种网络蒸馏方法。所述方法包括:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在其中一个实施例中,所述预设损失函数包括:
Loss_all=Loss_base+Loss_feature+Loss_embedding
Loss_base=Loss_triplet+Loss_CE
其中,Loss_all表示预设损失函数,Loss_base表示基本损失函数,Loss_feature表示特征损失函数,Loss_embedding表示嵌入层损失函数,Loss_triplet表示三元组损失函数,Loss_CE表示交叉熵损失函数。
在其中一个实施例中,所述嵌入层损失函数包括:
Loss_embedding=Loss_batch_distance+Loss_batch_angle
其中,Loss_embedding表示嵌入层损失函数,Loss_batch_angle表示批量内角度损失函数,Loss_batch_distance表示批量内间距损失函数。
在其中一个实施例中,所述特征损失函数包括:
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2
其中,Loss_feature表示特征损失函数,norm表示归一化,Teacher_feature表示大网络的特征表征,Student_feature表示小网络的特征表征。
在其中一个实施例中,所述构建网络蒸馏的大网络和所述构建网络蒸馏的小网络包括:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
在其中一个实施例中,所述蒸馏训练包括:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
在其中一个实施例中,所述训练所述大网络使用的损失函数为所述基本损失函数。
第二方面,本公开还提供了一种网络蒸馏装置。所述装置包括:
要求获取模块,用于获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
大网络构建模块,用于根据所述网络蒸馏要求,构建蒸馏网络的大网络;
小网络构建模块,用于根据所述网络蒸馏要求,构建蒸馏网络的小网络;
大网络训练模块,用于利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
小网络训练模块,用于使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在其中一个实施例中,所述蒸馏训练模块用于使用下述预设损失函数进行蒸馏训练:
Loss_all=Loss_base+Loss_feature+Loss_embedding
Loss_base=Loss_triplet+Loss_CE
其中,Loss_all表示预设损失函数,Loss_base表示基本损失函数,Loss_feature表示Resnet特征匹配损失函数,Loss_embedding表示嵌入层损失函数,Loss_triplet为三元组损失函数,Loss_CE为交叉熵损失函数。
在其中一个实施例中,所述小网络训练模块使用的嵌入层损失函数包括:
Loss_embedding=Loss_batch_distance+Loss_batch_angle
其中,Loss_embedding表示嵌入层损失函数,Loss_batch_angle表示批量内角度损失函数,Loss_batch_distance表示批量内间距损失函数。
在其中一个实施例中,所述小网络训练模块使用的特征损失函数包括:
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2
其中,Loss_feature表示特征损失函数,norm表示归一化,Teacher_feature表示大网络的特征表征,Student_feature表示小网络的特征表征。
在其中一个实施例中,所述大网络构建模块用于基于Resnet50残差网络,构建大网络;所述小网络构建模块用于基于Resnet18残差网络,构建目标小网络。
在其中一个实施例中,所述小网络训练模块用于:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
在其中一个实施例中,所述大网络训练模块使用的损失函数为所述基本损失函数。
第三方面,本公开还提供了一种计算机设备。所述计算机设备包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
第四方面,本公开还提供了一种计算机可读存储介质。所述计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
第五方面,本公开还提供了一种计算机程序产品。所述计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
上述网络蒸馏方法、装置、计算机设备、存储介质和计算机程序产品,通过在网络蒸馏过程中,设置与大网络层数相对应的小网络层数,并采用特殊设计的预设损失函数,能够达到减小蒸馏训练过程中的精度损失的目的。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理,并不构成对本公开的不当限定。
图1为一个实施例中网络蒸馏方法的应用环境图;
图2为一个实施例中网络蒸馏方法的流程示意图;
图3为另一个实施例中网络蒸馏方法的流程示意图;
图4为一个实施例中网络蒸馏装置的结构框图;
图5为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本公开的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本公开进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本公开,并不用于限定本公开。
需要说明的是,本公开的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本公开的实施例能够以除了在这里图示或描述的那些以外的顺序实施。以下示例性实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本公开的一些方面相一致的装置和方法的例子。
本公开实施例提供的网络蒸馏方法,可以应用于如图1所示的应用环境中。其中,数据存储系统可以存储服务器102需要处理的数据。数据存储系统可以集成在服务器102上,也可以放在云上或其他网络服务器上。服务器102拥有数据接收端。所述数据接收端获取网络蒸馏要求。服务器102根据所述网络蒸馏要求,构建蒸馏网络的大网络。服务器102根据所述网络蒸馏要求,构建蒸馏网络的小网络。服务器102利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果。服务器102使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络。其中,服务器102可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
在一个实施例中,如图2所示,提供了一种网络蒸馏方法,以该方法应用于图1中的应用环境为例进行说明,包括以下步骤:
S202,获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数。
具体地,获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数。所述网络蒸馏要求还可以包括对大网络和/或小网络的类型要求、大小要求、精度要求等内容。所述网络蒸馏要求还可以包括对损失函数的要求。所述网络蒸馏要求还可以包括对训练数据集的要求。
S204,根据所述网络蒸馏要求,构建蒸馏网络的大网络。
具体地,根据所述网络蒸馏要求中对大网络的要求,构建用于网络蒸馏的大网络。所述大网络一般被称为Teacher Net,或称为教师网络。
S206,根据所述网络蒸馏要求,构建蒸馏网络的小网络。
具体地,根据所述网络蒸馏要求中对小网络的要求,构建用于网络蒸馏的小网络。所述大网络一般被称为Student Net,或称为学生网络。
S208,利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果。
其中,精度可以是指神经网络的准确度。
具体地,利用训练数据集训练所述大网络。所述训练数据集以满足所述大网络的训练需求为准,具体不做限定。训练所述大网络可以使用下述损失函数:
Loss_Teacher=Loss_triplet+Loss_CE
其中,Loss_Teacher表示训练所述大网络使用的损失函数,Loss_triplet为三元组损失函数,Loss_CE为交叉熵损失函数。在所述大网络的精度超过预设阈值时,可以停止所述大网络的训练。将完成训练的大网络的输出结果作为第一输出结果。
S210,使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
具体地,利用预设损失函数和训练数据集对所述小网络进行网络蒸馏训练。使用所述第一输出结果为所述小网络的网络蒸馏训练提供基准。训练所述小网络使用的训练数据集可以与训练相应的大网络使用的数据集相同。可以根据计算出的损失值,采用反向传播的方式,更新小网络的参数,直至所述小网络收敛。当所述小网络符合收敛条件时,可以停止训练,将训练后的小网络确定为目标小网络。所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。所述特征损失函数可以是指与残差网络(残差网络的英文名称为Resnet)的特征匹配有关的损失函数。所述嵌入层(嵌入层的英文名称为Embedding Layer)损失函数可以是指与嵌入层有关的损失函数。
上述网络蒸馏方法中,通过在网络蒸馏过程中,设置与大网络层数相对应的小网络层数,并采用特殊设计的预设损失函数,能够达到减小蒸馏训练过程中的精度损失的目的。
在一个实施例中,所述预设损失函数包括:
Loss_all=Loss_base+Loss_feature+Loss_embedding
Loss_base=Loss_triplet+Loss_CE
其中,Loss_all表示预设损失函数,Loss_base表示基本损失函数,Loss_feature表示特征损失函数,Loss_embedding表示嵌入层损失函数,Loss_triplet表示三元组损失函数,Loss_CE表示交叉熵损失函数。
具体地,将三元组损失函数与交叉熵损失函数进行加和,构建基本损失函数。再将基本损失函数、特征损失函数和嵌入层损失函数进行加和,构建预设损失函数。三元组损失函数(Loss_triplet)可以表示如下:
Loss_triplet=max(0,(Dap-Dan+alpha))
其中Dap指的是基准特征向量与基准同类特征向量的欧式距离,Dan指的是基准特征向量与基准异类特征向量的欧氏距离,alpha表示边界值,可以设置alpha的具体值为0.2。
交叉熵损失函数(Loss_CE)可以表示如下:
其中,label表示神经网络中的label,中文含义为标签。output表示神经网络中的output,中文含义为输出。
本实施例中,通过对蒸馏训练中的损失函数进行特殊设计,结合三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数,得到所述预设损失函数,能够达到降低网络蒸馏的精度损失的目的。
在一个实施例中,所述嵌入层损失函数包括:
Loss_embedding=Loss_batch_distance+Loss_batch_angle
其中,Loss_embedding表示嵌入层损失函数,Loss_batch_angle表示批量内角度损失函数,Loss_batch_distance表示批量内间距损失函数。
具体地,选用批量内间距损失函数和批量内角度损失函数,对两种损失函数进行加和,构建嵌入层损失函数。所述批量内间距损失函数可以表示如下:
Loss_batch_distance=||D_teacher–D_student||
其中,Loss_batch_distance表示批量内间距损失函数,D_teacher表示大网络的特征矩阵,D_student表示小网络的特征矩阵。例如当图片批量为128张时,所述特征矩阵为一个大小为128×128的二维矩阵。批量内间距损失函数在网络蒸馏中可以使小网络的特征矩阵更加接近大网络的特征矩阵。
批量内角度损失函数可以表示如下:
Loss_batch_angle=||A_teacher–A_student||
其中,Loss_batch_angle表示批量内角度损失函数,A_teacher表示大网络的角度矩阵,A_student表示小网络的角度矩阵。
本实施例中,通过使用批量内间距损失函数和批量内角度损失函数构建嵌入层损失函数,有利于达到降低网络蒸馏的精度损失的目的。
在一个实施例中,所述特征损失函数包括:
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2
其中,Loss_feature表示特征损失函数,norm表示归一化,Teacher_feature表示大网络的特征表征,Student_feature表示小网络的特征表征。
具体地,所述归一化可以是指将数据整合为均值为0,方差为1的正态分布。在训练小网络的过程中,小网络中间层的特征图(特征图的英文名称为feature maps)的特征表征应当与大网络中的特征表征具有一致相关性。
本实施例中,通过在预设损失函数中使用上述自定义的特征损失函数,有利于达到降低网络蒸馏的精度损失的目的。
在一个实施例中,所述构建网络蒸馏的大网络和所述构建网络蒸馏的小网络包括:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
具体地,Resnet50残差网络和Resnet18残差网络均属于现有的残差网络。基于Resnet50残差网络,构建大网络。在构建大网络的过程中,可以在Resnet50残差网络和神经网络的输出层之间加入嵌入层(嵌入层的英文名称为embedding layer),用以增强网络学习表征的能力。基于Resnet18残差网络,构建小网络。在构建小网络的过程中,可以在Resnet18残差网络和神经网络的输出层之间也加入嵌入层,用以增强网络学习表征的能力。
本实施例中,通过对大网络和小网络构建过程中使用的基础神经网络进行选择,有利于从神经网络构建方面达到降低网络蒸馏的精度损失的目的。
在一个实施例中,所述蒸馏训练包括:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
具体地,随机梯度下降的英文简称为SGD,英文全称为Stochastic GradientDescent。动量法的英文名称为Momentum。在所述小网络的训练中,学习方法采用随机梯度下降法和动量法相结合的方法。训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4(1e-4表示10的-4次方)。本实施例中,通过对小网络的训练过程进行优化设置,有利于达到降低网络蒸馏的精度损失的目的。在大网络的训练中,可以使用与训练小网络相同的方法和参数设置。
在一个实施例中,所述训练所述大网络使用的损失函数为所述基本损失函数。
具体地,所述训练所述大网络使用的损失函数和所述基本损失函数相同,具体如下:
Loss_teacher=Loss_triplet+Loss_CE
其中,Loss_teacher表示训练所述大网络使用的损失函数,Loss_triplet表示三元组损失函数,Loss_CE表示交叉熵损失函数。
本实施例中,通过在训练大网络过程中使用基于三元组损失函数和交叉熵损失函数构建的损失函数,有利于达到降低网络蒸馏过程中的精度损失的目的。
在一个实施例中,基于Resnet50残差网络,构建大网络;基于Resnet18残差网络,构建小网络;训练大网络和小网络使用的方法和设置相同,均为:采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4;训练所述大网络使用的损失函数如下:
Loss_teacher=Loss_triplet+Loss_CE
训练小网络使用的预设损失函数如下:
Loss_all=Loss_base+Loss_feature+Loss_embedding
其中:
Loss_base=Loss_triplet+Loss_CE
Loss_embedding=Loss_batch_distance+Loss_batch_angle
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2训练过程如图3所示,先训练构建的大网络,保存满足精度要求的大网络;在蒸馏训练中,将图片输入保存的大网络和构建的小网络中;训练过程中使用损失函数,不断更新小网络参数,直至小网络收敛。训练得到的大网络的准确率为76.6%,小网络的准确率为74.5%。
应该理解的是,虽然如上所述的各实施例所涉及的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,如上所述的各实施例所涉及的流程图中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
基于同样的发明构思,本公开实施例还提供了一种用于实现上述所涉及的网络蒸馏方法的网络蒸馏装置。该装置所提供的解决问题的实现方案与上述方法中所记载的实现方案相似,故下面所提供的一个或多个网络蒸馏装置实施例中的具体限定可以参见上文中对于网络蒸馏方法的限定,在此不再赘述。
基于上述所述的网络蒸馏方法实施例的描述,本公开还提供网络蒸馏装置。所述装置可以包括使用了本说明书实施例所述方法的系统(包括分布式系统)、软件(应用)、模块、组件、服务器、客户端等并结合必要的实施硬件的装置。基于同一创新构思,本公开实施例提供的一个或多个实施例中的装置如下面的实施例所述。由于装置解决问题的实现方案与方法相似,因此本说明书实施例具体的装置的实施可以参见前述方法的实施,重复之处不再赘述。以下所使用的,术语“单元”或者“模块”可以实现预定功能的软件和/或硬件的组合。尽管以下实施例所描述的装置较佳地以软件来实现,但是硬件,或者软件和硬件的组合的实现也是可能并被构想的。
在一个实施例中,如图4所示,提供了一种网络蒸馏装置,包括:要求获取模块302、大网络构建模块304、小网络构建模块306、大网络训练模块308、和小网络训练模块310,其中:
要求获取模块302,用于获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
大网络构建模块304,用于根据所述网络蒸馏要求,构建蒸馏网络的大网络;
小网络构建模块306,用于根据所述网络蒸馏要求,构建蒸馏网络的小网络;
大网络训练模块308,用于利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
小网络训练模块310,用于使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在一个实施例中,所述小网络训练模块310用于使用下述预设损失函数进行蒸馏训练:
Loss_all=Loss_base+Loss_feature+Loss_embedding
Loss_base=Loss_triplet+Loss_CE
其中,Loss_all表示预设损失函数,Loss_base表示基本损失函数,Loss_feature表示Resnet特征匹配损失函数,Loss_embedding表示嵌入层损失函数,Loss_triplet为三元组损失函数,Loss_CE为交叉熵损失函数。
在一个实施例中,所述小网络训练模块310使用的嵌入层损失函数包括:
Loss_embedding=Loss_batch_distance+Loss_batch_angle
其中,Loss_embedding表示嵌入层损失函数,Loss_batch_angle表示批量内角度损失函数,Loss_batch_distance表示批量内间距损失函数。
在一个实施例中,所述小网络训练模块310使用的特征损失函数包括:
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2
其中,Loss_feature表示特征损失函数,norm表示归一化,Teacher_feature表示大网络的特征表征,Student_feature表示小网络的特征表征。
在一个实施例中,所述大网络构建模块308用于基于Resnet50残差网络,构建大网络;所述小网络构建模块用于基于Resnet18残差网络,构建目标小网络。
在一个实施例中,所述小网络训练模块310用于:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
在一个实施例中,所述大网络训练模块308使用的损失函数为所述基本损失函数。
上述网络蒸馏装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图5所示。该计算机设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质和内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储网络蒸馏相关数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种网络蒸馏方法。
本领域技术人员可以理解,图5中示出的结构,仅仅是与本公开方案相关的部分结构的框图,并不构成对本公开方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,还提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
在一个实施例中,提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现上述各方法实施例中的步骤。
在一个实施例中,提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现以下步骤:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
需要说明的是,本公开所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本公开所提供的各实施例中所使用的对存储器、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-OnlyMemory,ROM)、磁带、软盘、闪存、光存储器、高密度嵌入式非易失性存储器、阻变存储器(ReRAM)、磁变存储器(Magnetoresistive Random Access Memory,MRAM)、铁电存储器(Ferroelectric Random Access Memory,FRAM)、相变存储器(Phase Change Memory,PCM)、石墨烯存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器等。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic RandomAccess Memory,DRAM)等。本公开所提供的各实施例中所涉及的数据库可包括关系型数据库和非关系型数据库中至少一种。非关系型数据库可包括基于区块链的分布式数据库等,不限于此。本公开所提供的各实施例中所涉及的处理器可为通用处理器、中央处理器、图形处理器、数字信号处理器、可编程逻辑器、基于量子计算的数据处理逻辑器等,不限于此。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本公开的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对本公开专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本公开构思的前提下,还可以做出若干变形和改进,这些都属于本公开的保护范围。因此,本公开的保护范围应以所附权利要求为准。
Claims (10)
1.一种网络蒸馏方法,其特征在于,所述方法包括:
获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
根据所述网络蒸馏要求,构建蒸馏网络的大网络;
根据所述网络蒸馏要求,构建蒸馏网络的小网络;
利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
2.根据权利要求1所述的方法,其特征在于,所述预设损失函数包括:
Loss_all=Loss_base+Loss_feature+Loss_embedding
Loss_base=Loss_triplet+Loss_CE
其中,Loss_all表示预设损失函数,Loss_base表示基本损失函数,Loss_feature表示特征损失函数,Loss_embedding表示嵌入层损失函数,Loss_triplet表示三元组损失函数,Loss_CE表示交叉熵损失函数。
3.根据权利要求2所述的方法,其特征在于,所述嵌入层损失函数包括:
Loss_embedding=Loss_batch_distance+Loss_batch_angle
其中,Loss_embedding表示嵌入层损失函数,Loss_batch_angle表示批量内角度损失函数,Loss_batch_distance表示批量内间距损失函数。
4.根据权利要求2所述的方法,其特征在于,所述特征损失函数包括:
Loss_feature=||norm(Teacher_feature)–norm(Student_feature)||^2
其中,Loss_feature表示特征损失函数,norm表示归一化,Teacher_feature表示大网络的特征表征,Student_feature表示小网络的特征表征。
5.根据权利要求1所述的方法,其特征在于,所述构建网络蒸馏的大网络和所述构建网络蒸馏的小网络包括:
基于Resnet50残差网络,构建大网络;
基于Resnet18残差网络,构建小网络。
6.根据权利要求5所述的方法,其特征在于,所述蒸馏训练包括:
采用随机梯度下降+动量法学习,训练数据集采用公开分类数据集CIFAR100,迭代训练次数为80次,批量大小设定为128,学习率设定为1e-4。
7.一种网络蒸馏装置,其特征在于,所述装置包括:
要求获取模块,用于获取网络蒸馏要求,所述网络蒸馏要求包括网络蒸馏中的大网络层数,以及与所述大网络层数对应的小网络层数;
大网络构建模块,用于根据所述网络蒸馏要求,构建蒸馏网络的大网络;
小网络构建模块,用于根据所述网络蒸馏要求,构建蒸馏网络的小网络;
大网络训练模块,用于利用训练数据集训练所述大网络,在精度超过预设阈值时,得到所述大网络输出的第一输出结果;
小网络训练模块,用于使用所述小网络、所述第一输出结果和预设损失函数进行蒸馏训练,直至符合收敛条件时,确定目标小网络,所述预设损失函数至少基于三元组损失函数、交叉熵损失函数、特征损失函数、嵌入层损失函数得到。
8.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述的方法的步骤。
9.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
10.一种计算机程序产品,包括计算机程序,其特征在于,该计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210428589.XA CN114819076A (zh) | 2022-04-22 | 2022-04-22 | 网络蒸馏方法、装置、计算机设备、存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210428589.XA CN114819076A (zh) | 2022-04-22 | 2022-04-22 | 网络蒸馏方法、装置、计算机设备、存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114819076A true CN114819076A (zh) | 2022-07-29 |
Family
ID=82505914
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210428589.XA Pending CN114819076A (zh) | 2022-04-22 | 2022-04-22 | 网络蒸馏方法、装置、计算机设备、存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114819076A (zh) |
-
2022
- 2022-04-22 CN CN202210428589.XA patent/CN114819076A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2018068421A1 (zh) | 一种神经网络的优化方法及装置 | |
JP2021039758A (ja) | 画像間の類似度を利用した類似領域強調方法およびシステム | |
TWI740338B (zh) | 具有動態最小批次尺寸之運算方法,以及用於執行該方法之運算系統及電腦可讀儲存媒體 | |
CN112307243B (zh) | 用于检索图像的方法和装置 | |
CN114638823B (zh) | 基于注意力机制序列模型的全切片图像分类方法及装置 | |
CN116703598A (zh) | 交易行为检测方法、装置、计算机设备和存储介质 | |
CN114819076A (zh) | 网络蒸馏方法、装置、计算机设备、存储介质 | |
Kuzman | Poletsky theory of discs in almost complex manifolds | |
CN117235584B (zh) | 图数据分类方法、装置、电子装置和存储介质 | |
CN110033098A (zh) | 在线gbdt模型学习方法及装置 | |
CN116976464A (zh) | 去偏置联邦学习训练方法、装置、计算机设备和存储介质 | |
CN117216103A (zh) | 缓存失效时间的确定方法、装置、计算机设备和存储介质 | |
CN116822512A (zh) | 命名实体识别方法、装置、计算机设备和存储介质 | |
CN117473975A (zh) | 地址资源匹配方法、装置、计算机设备和存储介质 | |
CN117436484A (zh) | 图像识别模型构建、装置、计算机设备和存储介质 | |
CN117150311A (zh) | 数据处理方法、装置、设备和存储介质 | |
CN116861097A (zh) | 信息推荐方法、装置、计算机设备和存储介质 | |
CN115860099A (zh) | 神经网络模型的压缩方法、装置、计算机设备和存储介质 | |
CN117078427A (zh) | 产品推荐方法、装置、设备、存储介质和程序产品 | |
CN116883595A (zh) | 三维场景建模方法、装置、设备、存储介质和程序产品 | |
CN116881450A (zh) | 资讯分类方法、装置、计算机设备、存储介质和程序产品 | |
CN116578886A (zh) | 用户聚类方法、装置、计算机设备和存储介质 | |
CN117437010A (zh) | 资源借调等级预测方法、装置、设备、存储介质和程序产品 | |
CN117057439A (zh) | 模型参数更新方法、装置、计算机设备和存储介质 | |
CN117437674A (zh) | 人脸图像处理方法、装置、计算机设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |