CN115829024B - 一种模型训练方法、装置、设备及存储介质 - Google Patents
一种模型训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN115829024B CN115829024B CN202310108097.7A CN202310108097A CN115829024B CN 115829024 B CN115829024 B CN 115829024B CN 202310108097 A CN202310108097 A CN 202310108097A CN 115829024 B CN115829024 B CN 115829024B
- Authority
- CN
- China
- Prior art keywords
- initial
- training
- model
- pruning
- current
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T90/00—Enabling technologies or technologies with a potential or indirect contribution to GHG emissions mitigation
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请公开了一种模型训练方法、装置、设备及存储介质,涉及神经网络领域,包括:利用超球面学习算法对初始权重参数进行调整得到调整后权重参数;基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;基于正则化后目标函数、预设剪枝方法和初始剪枝掩码对初始轻量化模型进行训练,以将初始稀疏率降低至目标稀疏率,并对调整后权重参数进行相应调整得到目标权重参数以及与目标权重参数对应的剪枝后轻量化模型;对剪枝后轻量化模型进行微调得到目标轻量化模型。本申请利用超球面学习算法可以降低训练难度,加快收敛速度,并且结合正则化后目标函数同时完成对模型的训练和剪枝,再对剪枝后的模型进行一次微调,在保证模型精度的同时提高模型微调效率。
Description
技术领域
本发明涉及神经网络领域,特别涉及一种模型训练方法、装置、设备及存储介质。
背景技术
深度神经网络模型包含数百万个参数,因此无法在边缘设备上部署大型神经网络模型。在这种资源受限的情况下,我们必须考虑模型大小和推理效率。目前,量化和剪枝由于可以减少模型的规模和计算开销而受到我们广泛的关注。
模型剪枝的目的是得到一个具有最大精度和压缩比的神经网络。目前大多数剪枝都会遇到两个问题:如何减少微调时间和如何从剪枝中快速恢复网络的精度。而在实际剪枝过程中,剪枝和微调的步骤会重复多次,以逐渐减小模型尺寸并保持较高的精度。微调过程非常耗时,需要通过运行整个训练数据集来调整模型的参数。因此,如何利用较少的训练数据来提高神经网络的微调效率和恢复能力是目前有待解决的问题。
发明内容
有鉴于此,本发明的目的在于提供一种模型训练方法、装置、设备及存储介质,能够利用超球面学习算法降低训练难度,加快收敛速度,并且结合正则化后目标函数同时完成对模型的训练和剪枝过程,再对剪枝后的模型进行一次微调,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。其具体方案如下:
第一方面,本申请提供了一种模型训练方法,包括:
利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数;
基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;
基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;
基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型。
可选的,所述利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数,包括:
获取初始轻量化模型,并利用超球面学习算法对所述初始轻量化模型的初始权重参数进行调整,以得到模长为1的调整后权重参数。
可选的,所述基于预设剪枝方法和初始稀疏率确定初始剪枝掩码,包括:
基于预设剪枝方法和初始稀疏率从所述初始轻量化模型的每个网络层的每个通道对应的通道权重中确定出满足预设通道权重确定规则的若干通道权重,并为所述若干通道权重设置相应的预设掩码,以得到对应的初始剪枝掩码;所述网络层包括除第一个卷积层之外的所有线性层和卷积层。
可选的,所述基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练之前,还包括:
基于权重参数、剪枝掩码和单位矩阵构建矩阵的迹公式,并根据剪枝掩码对应的矩阵列数和所述矩阵的迹公式构建正则化项;
基于目标函数、正则化参数以及所述正则化项构建正则化后目标函数。
可选的,所述基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型,包括:
基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理,以得到当前稀疏率;所述当前训练周期为预先计算的训练周期数中的任意一个周期;
基于所述预设剪枝方法和所述当前稀疏率确定当前剪枝掩码;
基于预先构建的正则化后目标函数、所述预设剪枝方法和所述当前剪枝掩码对当前轻量化模型进行训练,并对当前权重参数进行调整得到与所述当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型;
判断所述当前稀疏率是否达到目标稀疏率,如果否则重新跳转至所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理的步骤,如果是则结束训练。
可选的,所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理之前,还包括:
获取预设训练总轮次和预设训练周期轮次;所述预设训练周期轮次为每个训练周期内需要进行的训练轮次;
基于所述预设训练总轮次与所述预设训练周期轮次计算训练周期数。
可选的,所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理之前,还包括:
获取初始稀疏率和目标稀疏率,并计算所述初始稀疏率与所述目标稀疏率之间的稀疏率差值;
基于所述稀疏率差值与所述训练周期数确定出稀疏率变化值。
第二方面,本申请提供了一种模型训练装置,包括:
权重调整模块,用于利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数;
掩码确定模块,用于基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;
模型训练模块,用于基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;
模型微调模块,用于基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型。
第三方面,本申请提供了一种电子设备,包括:
存储器,用于保存计算机程序;
处理器,用于执行所述计算机程序以实现前述的模型训练方法。
第四方面,本申请提供了一种计算机可读存储介质,用于保存计算机程序,所述计算机程序被处理器执行时实现前述的模型训练方法。
本申请中,利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数;基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型。由此可见,本申请利用超球面学习算法可以降低模型训练难度,加快模型收敛速度,以及保证模型的分类精度;并且结合正则化后目标函数同时对初始轻量化模型进行训练和剪枝操作,直至将初始稀疏率不断降低至目标稀疏率,以及对调整后权重参数不断进行调整,得到与目标权重参数对应的剪枝后轻量化模型,不仅可以提高模型训练的效率,减小模型尺寸,以及减少模型网络的复杂程度,而且可以优化模型性能;最后通过对剪枝后轻量化模型进行一次微调,避免了在每次剪枝后均对剪枝后的模型进行微调操作所带来的严重耗时问题,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本申请公开的一种模型训练方法流程图;
图2为本申请公开的一种模型训练流程图;
图3为本申请公开的一种具体的模型训练方法流程图;
图4为本申请公开的一种模型训练装置结构示意图;
图5为本申请公开的一种电子设备结构图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在实际剪枝过程中,剪枝和微调的步骤会重复多次,以逐渐减小模型尺寸并保持较高的精度。微调过程非常耗时,需要通过运行整个训练数据集来调整模型的参数。为此,本申请提供了一种模型训练方法,能够利用超球面学习算法降低训练难度,加快收敛速度,并且结合正则化后目标函数同时完成对模型的训练和剪枝过程,再对剪枝后的模型进行一次微调,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。
参见图1所示,本发明实施例公开了一种模型训练方法,包括:
步骤S11、利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数。
本实施例中,如图2所示,先获取训练数据集和将要进行训练的初始轻量化模型,一般情况下轻量化模型为MobileNet系列,调整初始轻量化模型的梯度下降算法,也即利用超球面学习算法替代原来的模型算法。并且由于采用超球面学习算法,需要对初始轻量化模型的每一层的初始权重参数和输入向量进行调整,以保证调整后的权重参数和调整后的输入向量在超球体上的模长为1。也即,权重参数和输入向量需要满足如下公式:
其中,W表示每一层的权重参数,T表示矩阵的转置,X表示输入向量,y表示输出向量。调整后的权重参数需要满足以及调整后的输入向量需要满足/>。例如,本实施例中的训练数据集可以采用ImageNet数据集,初始轻量化模型可以选择MobileNetV2模型,调整MobileNetV2模型的梯度下降算法,如模型原来采用的算法为随机梯度下降(Stochastic Gradient Descent,SGD),则使用超球面学习算法来代替随机梯度下降算法,并调整初始权重参数和输入向量的模长均为1。这样一来,通过利用超球面学习算法替代模型原来的算法,可以降低模型训练难度,使模型更容易优化,收敛速度更快,也可以保证模型的分类精度。
步骤S12、基于预设剪枝方法和初始稀疏率确定初始剪枝掩码。
本实施例中,如图2所示,在模型训练之前,需要先基于预设剪枝方法和初始稀疏率从初始轻量化模型的每个网络层的每个通道对应的通道权重中确定出满足预设通道权重确定规则的若干通道权重,并为若干通道权重设置相应的预设掩码,以得到对应的初始剪枝掩码。其中,初始稀疏率可以由用户根据自身的剪枝需求进行设置,但初始需要设置较高的稀疏率;网络层包括除第一个卷积层之外的所有线性层和卷积层;预设剪枝方法包括但不限于结构化剪枝方法。例如,本实施例中的初始稀疏率设置为0.9,意味着百分之九十的权重将会置0,预设剪枝方法选择结构化剪枝方法,按照每层每个通道中的最小权重和作为通道重要性的确定规则,可以基于初始稀疏率0.9从每层每个通道对应的权重中确定出通道重要性较低的若干通道权重,并为若干通道权重设置相应的掩码0,以得到与初始稀疏率对应的初始剪枝掩码。
步骤S13、基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型。
本实施例中,如图2所示,在进行模型训练之前,需要先构建正则化后目标函数,具体的,基于权重参数、剪枝掩码和单位矩阵构建矩阵的迹公式,并根据剪枝掩码对应的矩阵列数和矩阵的迹公式构建正则化项;基于目标函数、正则化参数以及正则化项构建正则化后目标函数。该步骤涉及到的公式如下:
其中,r表示稀疏率;M表示剪枝掩码;W表示权重参数;T表示矩阵的转置;I表示单位矩阵;m表示剪枝掩码M对应的矩阵列数;L表示L2正则化;λ表示正则化参数,用于使L(W)和Ltr(W,r)保持在相同的尺度上;trace()表示矩阵的迹,用于表示权重参数与剪枝掩码之间的余弦相似性。
本实施例中,在构建好正则化后目标函数之后,利用训练数据集并基于正则化后目标函数、预设剪枝方法和初始剪枝掩码对初始轻量化模型同时进行训练和剪枝操作,在此过程中,会不断降低初始稀疏率,并对调整后权重参数不断进行调整,使其不断向当前剪枝掩码进行靠近,直至初始稀疏率降低至目标稀疏率,同时会基于预设剪枝方法和目标稀疏率得到目标剪枝掩码,以及通过对当前权重参数不断进行调整,此时得到与目标剪枝掩码对应的目标权重参数,以及与目标权重参数对应的目标轻量化模型。其中,目标稀疏率小于初始稀疏率,并且可以由用户根据自身的剪枝需求进行设置。需要说明的是,在模型训练前,需要设置初始训练的一些参数,例如,正则化参数可以设置为2,初始稀疏率为0.9,目标稀疏率为0.7,权重衰减为0.0001,随机梯度下降动量为0.9%,初始学习率为0.01;并且在训练期间,本实施例选择使用余弦退火来调整学习速率。这样一来,本实施例通过同时完成对模型的训练和剪枝过程,可以提高模型训练的效率,减小模型尺寸,以及减少模型网络的复杂程度。
步骤S14、基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型。
本实施例中,如图2所示,在得到剪枝后轻量化模型之后,利用训练数据集并预先设定的模型微调规则对剪枝后的轻量化模型进行一次微调,以得到最终训练好的目标轻量化模型。其中,预设模型微调规则可以由用户根据自身的微调需求进行设置。这样一来,通过对剪枝后轻量化模型进行一次微调可以减少微调时间,并从剪枝中快速恢复网络的精度。
例如,本实施例的方法可以应用于交通领域中对交通路况进行检测的边缘设备,由于考虑到边缘设备的资源受限,对传统的深度学习模型提出了体积小、速度快、精度高的要求,因此在训练和推理阶段,本实施例的方法选择使用轻量化模型。获取当前路况图像,将当前路况图像输入边缘设备的初始轻量化模型中,并将其作为初始输入向量。利用超球面学习算法替代初始轻量化模型原来的算法,并对初始权重参数和初始输入向量进行调整,以保证调整后权重参数与调整后输入向量的模长均为1。然后在模型训练的过程中,不断对初始稀疏率进行数值降低处理,以得到当前稀疏率,同时利用正则化后目标函数和预先设定的剪枝方法对当前轻量化模型进行剪枝和训练,以将当前权重参数不断向与当前稀疏率对应的当前剪枝掩码进行靠近,直至将初始稀疏率降低至目标稀疏率,同时得到与目标稀疏率对应的目标剪枝掩码,以及向目标剪枝掩码靠近的目标权重参数,以及与目标权重参数对应的剪枝后轻量化模型。最后利用预先设定的模型微调规则对剪枝后轻量化模型进行一次微调,即可得到最终训练好的目标轻量化模型。利用目标轻量化模型对模长为1的调整后输入向量进行检测,以得到当前路况检测结果。通过本实施例的方法对当前路况进行检测,不仅可以具有更快的检测速度,而且可以具有更高的检测精度。
由此可见,本申请利用超球面学习算法可以降低模型训练难度,加快模型收敛速度,以及保证模型的分类精度;并且结合正则化后目标函数同时对初始轻量化模型进行训练和剪枝操作,直至将初始稀疏率不断降低至目标稀疏率,以及对调整后权重参数不断进行调整,得到与目标权重参数对应的剪枝后轻量化模型,不仅可以提高模型训练的效率,减小模型尺寸,以及减少模型网络的复杂程度,而且可以优化模型性能;最后通过对剪枝后轻量化模型进行一次微调,避免了在每次剪枝后均对剪枝后的模型进行微调操作所带来的严重耗时问题,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。
基于前一实施例可知,本申请描述了利用超球面学习算法替代模型原来的算法以及对模型进行训练、剪枝和微调的整体流程,接下来,本申请将对模型的训练和剪枝过程进行详细的阐述。为此,参见图3所示,本发明实施例公开了一种模型训练和剪枝的过程,包括:
步骤S21、基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对初始稀疏率进行数值降低处理,以得到当前稀疏率;所述当前训练周期为预先计算的训练周期数中的任意一个周期。
本实施例中,在模型训练之前,需要先获取训练周期数,具体的,先获取预设训练总轮次和预设训练周期轮次;其中,上述预设训练周期轮次为每个训练周期内需要进行的训练轮次;并基于预设训练总轮次与预设训练周期轮次计算训练周期数。可以理解的是,获取模型训练的总轮次以及每个训练周期内需要进行的训练轮次,利用训练总轮次除以每个训练周期内的训练轮次,即为训练周期数。例如,若预设训练总轮次为900个epoch,每个周期内想要进行的训练轮次为90个epoch,则900/90=10,也即训练周期数n为10。
本实施例中,在基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对初始稀疏率进行数值降低处理之前,还可以包括获取初始稀疏率和目标稀疏率,并计算上述初始稀疏率与上述目标稀疏率之间的稀疏率差值;基于稀疏率差值与所述训练周期数确定出稀疏率变化值。可以理解的是,先利用初始稀疏率减去目标稀疏率得到两者之间的稀疏率差值,再利用稀疏率差值除以训练周期数n得到稀疏率变化值。例如,初始稀疏率为0.9,目标稀疏率为0.7,训练周期数为10,则稀疏率变化值为(0.9-0.7)/10=0.02。
本实施例中,在当前训练周期的训练过程中,先根据当前训练周期对应的周期序号和稀疏率变化值,对初始稀疏率进行数值降低处理,以得到当前稀疏率。例如,当前训练周期对应的周期序号为3,则0.9-0.02*3=0.84,也即当前稀疏率为0.84。其中,当前训练周期为预先计算的训练周期数n中的任意一个周期。
步骤S22、基于预设剪枝方法和所述当前稀疏率确定当前剪枝掩码。
本实施例中,在得到当前稀疏率之后,需要先计算出与当前稀疏率对应的当前剪枝掩码,具体的,利用预先选择的剪枝方法和当前稀疏率从当前轻量化模型的每层每个通道对应的权重中确定出满足预设通道权重确定规则的若干通道权重,并为若干通道权重设置相应的预设掩码,以得到对应的当前剪枝掩码。其中,预设通道权重确定规则包括接近于0的权重或者重新缩放部分权重的大小将不会影响模型性能的权重。
步骤S23、基于预先构建的正则化后目标函数、所述预设剪枝方法和所述当前剪枝掩码对当前轻量化模型进行训练,并对当前权重参数进行调整得到与所述当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型。
本实施例中,基于正则化后目标函数、预设剪枝方法和当前剪枝掩码对当前轻量化模型同时进行训练和剪枝操作,以对当前权重参数进行调整,将其不断向当前剪枝掩码进行靠近,得到与当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型。这样一来,通过不断对当前权重参数进行调整,改变权重分布,可以得到性能更优的模型。
步骤S24、判断所述当前稀疏率是否达到目标稀疏率,如果否则重新跳转至所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对初始稀疏率进行数值降低处理的步骤,如果是则结束训练。
本实施例中,对当前稀疏率和目标稀疏率进行对比,若当前稀疏率未达到目标稀疏率,则表明当前训练周期对应的周期序号小于训练周期数,并重新跳转至步骤S21;若当前稀疏率达到目标稀疏率,则表明训练结束,此时会得到剪枝后轻量化模型。最后再利用预设模型微调规则对剪枝后轻量化模型进行一次微调,即可得到目标轻量化模型。这样一来,本实施例通过对剪枝后的模型进行一次微调操作,可以避免重复进行剪枝和微调,只需要在剪枝完成后对剪枝后的模型进行一次微调即可,不仅提高了模型的微调效率,还可以从剪枝中快速恢复网络的精度。
由此可见,本申请利用超球面学习算法可以降低模型训练难度,加快模型收敛速度,以及保证模型的分类精度;并且结合正则化后目标函数同时对初始轻量化模型进行训练和剪枝操作,直至将初始稀疏率不断降低至目标稀疏率,以及对调整后权重参数不断进行调整,得到与目标权重参数对应的剪枝后轻量化模型,不仅可以提高模型训练的效率,减小模型尺寸,以及减少模型网络的复杂程度,而且可以优化模型性能;最后通过对剪枝后轻量化模型进行一次微调,避免了在每次剪枝后均对剪枝后的模型进行微调操作所带来的严重耗时问题,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。
参见图4所示,本发明实施例公开了一种模型训练装置,包括:
权重调整模块11,用于利用超球面学习算法对初始轻量化模型的初始权重参数进行调整,以得到调整后权重参数;
掩码确定模块12,用于基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;
模型训练模块13,用于基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;
模型微调模块14,用于基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型。
由此可见,本申请利用超球面学习算法可以降低模型训练难度,加快模型收敛速度,以及保证模型的分类精度;并且结合正则化后目标函数同时对初始轻量化模型进行训练和剪枝操作,直至将初始稀疏率不断降低至目标稀疏率,以及对调整后权重参数不断进行调整,得到与目标权重参数对应的剪枝后轻量化模型,不仅可以提高模型训练的效率,减小模型尺寸,以及减少模型网络的复杂程度,而且可以优化模型性能;最后通过对剪枝后轻量化模型进行一次微调,避免了在每次剪枝后均对剪枝后的模型进行微调操作所带来的严重耗时问题,提高了模型的微调效率,并可以从剪枝中快速恢复网络的精度。
在一些具体实施例中,所述权重调整模块11,具体可以包括:
模型获取单元,用于获取初始轻量化模型;
初始权重调整单元,用于并利用超球面学习算法对所述初始轻量化模型的初始权重参数进行调整,以得到模长为1的调整后权重参数。
在一些具体实施例中,所述掩码确定模块12,具体可以包括:
初始掩码确定单元,用于基于预设剪枝方法和初始稀疏率从所述初始轻量化模型的每个网络层的每个通道对应的通道权重中确定出满足预设通道权重确定规则的若干通道权重,并为所述若干通道权重设置相应的预设掩码,以得到对应的初始剪枝掩码;所述网络层包括除第一个卷积层之外的所有线性层和卷积层。
在一些具体实施例中,所述模型训练装置,还可以包括:
正则化项构建单元,用于基于权重参数、剪枝掩码和单位矩阵构建矩阵的迹公式,并根据剪枝掩码对应的矩阵列数和所述矩阵的迹公式构建正则化项;
函数构建单元,用于基于目标函数、正则化参数以及所述正则化项构建正则化后目标函数。
在一些具体实施例中,所述模型训练模块13,具体可以包括:
当前稀疏率确定单元,用于基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理,以得到当前稀疏率;所述当前训练周期为预先计算的训练周期数中的任意一个周期;
当前掩码确定单元,用于基于所述预设剪枝方法和所述当前稀疏率确定当前剪枝掩码;
当前权重调整单元,用于基于预先构建的正则化后目标函数、所述预设剪枝方法和所述当前剪枝掩码对当前轻量化模型进行训练,并对当前权重参数进行调整得到与所述当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型;
稀疏率判断单元,用于判断所述当前稀疏率是否达到目标稀疏率,如果否则重新跳转至所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理的步骤,如果是则结束训练。
在一些具体实施例中,所述模型训练装置,还可以包括:
轮次获取单元,用于获取预设训练总轮次和预设训练周期轮次;所述预设训练周期轮次为每个训练周期内需要进行的训练轮次;
周期数计算单元,用于基于所述预设训练总轮次与所述预设训练周期轮次计算训练周期数。
在一些具体实施例中,所述模型训练装置,还可以包括:
差值计算单元,用于获取初始稀疏率和目标稀疏率,并计算所述初始稀疏率与所述目标稀疏率之间的稀疏率差值;
变化值确定单元,用于基于所述稀疏率差值与所述训练周期数确定出稀疏率变化值。
进一步的,本申请实施例还公开了一种电子设备,图5是根据一示例性实施例示出的电子设备20结构图,图中的内容不能认为是对本申请的使用范围的任何限制。
图5为本申请实施例提供的一种电子设备20的结构示意图。该电子设备 20,具体可以包括:至少一个处理器21、至少一个存储器22、电源23、通信接口24、输入输出接口25和通信总线26。其中,所述存储器22用于存储计算机程序,所述计算机程序由所述处理器21加载并执行,以实现前述任一实施例公开的模型训练方法中的相关步骤。另外,本实施例中的电子设备20具体可以为电子计算机。
本实施例中,电源23用于为电子设备20上的各硬件设备提供工作电压;通信接口24能够为电子设备20创建与外界设备之间的数据传输通道,其所遵 循的通信协议是能够适用于本申请技术方案的任意通信协议,在此不对其进 行具体限定;输入输出接口25,用于获取外界输入数据或向外界输出数据,其具体的接口类型可以根据具体应用需要进行选取,在此不进行具体限定。
另外,存储器22作为资源存储的载体,可以是只读存储器、随机存储器、 磁盘或者光盘等,其上所存储的资源可以包括操作系统221、计算机程序222 等,存储方式可以是短暂存储或者永久存储。
其中,操作系统221用于管理与控制电子设备20上的各硬件设备以及计算 机程序222,其可以是Windows Server、Netware、Unix、Linux等。计算机程序222除了包括能够用于完成前述任一实施例公开的由电子设备20执行的模型训练方法的计算机程序之外,还可以进一步包括能够用于完成其他特定工作的计算机程序。
进一步的,本申请还公开了一种计算机可读存储介质,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现前述公开的模型训练方法。关于该方法的具体步骤可以参考前述实施例中公开的相应内容,在此不再进行赘述。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其它实施例的不同之处,各个实施例之间相同或相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
专业人员还可以进一步意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。
最后,还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上对本申请所提供的技术方案进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (8)
1.一种模型训练方法,其特征在于,应用于对交通路况进行检测的边缘设备,包括:
利用超球面学习算法对初始轻量化模型的初始权重参数和初始输入向量进行调整,以得到调整后权重参数和调整后输入向量;所述初始输入向量为输入至所述边缘设备的路况图像,并且所述调整后权重参数和所述调整后输入向量的模长均为1;
基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;
基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;
基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型,以便利用所述目标轻量化模型进行路况检测;
其中,所述基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型,包括:
基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理,以得到当前稀疏率;所述当前训练周期为预先计算的训练周期数中的任意一个周期;
基于所述预设剪枝方法和所述当前稀疏率确定当前剪枝掩码;
基于预先构建的正则化后目标函数、所述预设剪枝方法和所述当前剪枝掩码对当前轻量化模型进行训练,并对当前权重参数进行调整得到与所述当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型;
判断所述当前稀疏率是否达到目标稀疏率,如果否则重新跳转至所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理的步骤,如果是则结束训练。
2.根据权利要求1所述的模型训练方法,其特征在于,所述基于预设剪枝方法和初始稀疏率确定初始剪枝掩码,包括:
基于预设剪枝方法和初始稀疏率从所述初始轻量化模型的每个网络层的每个通道对应的通道权重中确定出满足预设通道权重确定规则的若干通道权重,并为所述若干通道权重设置相应的预设掩码,以得到对应的初始剪枝掩码;所述网络层包括除第一个卷积层之外的所有线性层和卷积层。
3.根据权利要求1所述的模型训练方法,其特征在于,所述基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练之前,还包括:
基于权重参数、剪枝掩码和单位矩阵构建矩阵的迹公式,并根据剪枝掩码对应的矩阵列数和所述矩阵的迹公式构建正则化项;
基于目标函数、正则化参数以及所述正则化项构建正则化后目标函数。
4.根据权利要求1所述的模型训练方法,其特征在于,所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理之前,还包括:
获取预设训练总轮次和预设训练周期轮次;所述预设训练周期轮次为每个训练周期内需要进行的训练轮次;
基于所述预设训练总轮次与所述预设训练周期轮次计算训练周期数。
5.根据权利要求4所述的模型训练方法,其特征在于,所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理之前,还包括:
获取初始稀疏率和目标稀疏率,并计算所述初始稀疏率与所述目标稀疏率之间的稀疏率差值;
基于所述稀疏率差值与所述训练周期数确定出稀疏率变化值。
6.一种模型训练装置,其特征在于,应用于对交通路况进行检测的边缘设备,包括:
权重调整模块,利用超球面学习算法对初始轻量化模型的初始权重参数和初始输入向量进行调整,以得到调整后权重参数和调整后输入向量;所述初始输入向量为输入至所述边缘设备的路况图像,并且所述调整后权重参数和所述调整后输入向量的模长均为1;
掩码确定模块,用于基于预设剪枝方法和初始稀疏率确定初始剪枝掩码;
模型训练模块,用于基于预先构建的正则化后目标函数、所述预设剪枝方法和所述初始剪枝掩码对所述初始轻量化模型进行训练,以将所述初始稀疏率降低至目标稀疏率,并对所述调整后权重参数进行相应的调整以得到对应的目标权重参数,以及与所述目标权重参数对应的剪枝后轻量化模型;
模型微调模块,用于基于预设模型微调规则对所述剪枝后轻量化模型进行微调,以得到目标轻量化模型,以便利用所述目标轻量化模型进行路况检测;
其中,所述模型训练模块具体用于:基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理,以得到当前稀疏率;所述当前训练周期为预先计算的训练周期数中的任意一个周期;基于所述预设剪枝方法和所述当前稀疏率确定当前剪枝掩码;基于预先构建的正则化后目标函数、所述预设剪枝方法和所述当前剪枝掩码对当前轻量化模型进行训练,并对当前权重参数进行调整得到与所述当前剪枝掩码对应的新的当前权重参数,以及与新的当前权重参数对应的新的当前轻量化模型;判断所述当前稀疏率是否达到目标稀疏率,如果否则重新跳转至所述基于当前训练周期对应的周期序号以及预先计算的稀疏率变化值对所述初始稀疏率进行数值降低处理的步骤,如果是则结束训练。
7.一种电子设备,其特征在于,包括:
存储器,用于保存计算机程序;
处理器,用于执行所述计算机程序以实现如权利要求1至5任一项所述的模型训练方法。
8.一种计算机可读存储介质,其特征在于,用于保存计算机程序,所述计算机程序被处理器执行时实现如权利要求1至5任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310108097.7A CN115829024B (zh) | 2023-02-14 | 2023-02-14 | 一种模型训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310108097.7A CN115829024B (zh) | 2023-02-14 | 2023-02-14 | 一种模型训练方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115829024A CN115829024A (zh) | 2023-03-21 |
CN115829024B true CN115829024B (zh) | 2023-06-20 |
Family
ID=85521201
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310108097.7A Active CN115829024B (zh) | 2023-02-14 | 2023-02-14 | 一种模型训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115829024B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116167430B (zh) * | 2023-04-23 | 2023-07-18 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 基于均值感知稀疏的目标检测模型全局剪枝方法及设备 |
CN117058525B (zh) * | 2023-10-08 | 2024-02-06 | 之江实验室 | 一种模型的训练方法、装置、存储介质及电子设备 |
CN117474070B (zh) * | 2023-12-26 | 2024-04-23 | 苏州元脑智能科技有限公司 | 模型剪枝方法、人脸识别模型训练方法及人脸识别方法 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115481728A (zh) * | 2022-09-16 | 2022-12-16 | 云南电网有限责任公司电力科学研究院 | 输电线路缺陷检测方法、模型剪枝方法、设备和介质 |
Family Cites Families (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110197257A (zh) * | 2019-05-28 | 2019-09-03 | 浙江大学 | 一种基于增量正则化的神经网络结构化稀疏方法 |
CN110598731B (zh) * | 2019-07-31 | 2021-08-20 | 浙江大学 | 一种基于结构化剪枝的高效图像分类方法 |
CN111144566B (zh) * | 2019-12-30 | 2024-03-22 | 深圳云天励飞技术有限公司 | 神经网络权重参数的训练方法、特征分类方法及对应装置 |
CN112734029A (zh) * | 2020-12-30 | 2021-04-30 | 中国科学院计算技术研究所 | 一种神经网络通道剪枝方法、存储介质及电子设备 |
CN114282666A (zh) * | 2021-12-03 | 2022-04-05 | 中科视语(北京)科技有限公司 | 基于局部稀疏约束的结构化剪枝方法和装置 |
CN114594461A (zh) * | 2022-03-14 | 2022-06-07 | 杭州电子科技大学 | 基于注意力感知与缩放因子剪枝的声呐目标检测方法 |
-
2023
- 2023-02-14 CN CN202310108097.7A patent/CN115829024B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115481728A (zh) * | 2022-09-16 | 2022-12-16 | 云南电网有限责任公司电力科学研究院 | 输电线路缺陷检测方法、模型剪枝方法、设备和介质 |
Also Published As
Publication number | Publication date |
---|---|
CN115829024A (zh) | 2023-03-21 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN115829024B (zh) | 一种模型训练方法、装置、设备及存储介质 | |
CN110832509B (zh) | 使用神经网络的黑盒优化 | |
CN111260030A (zh) | 基于a-tcn电力负荷预测方法、装置、计算机设备及存储介质 | |
US11863397B2 (en) | Traffic prediction method, device, and storage medium | |
CN109918663A (zh) | 一种语义匹配方法、装置及存储介质 | |
CN113905391A (zh) | 集成学习网络流量预测方法、系统、设备、终端、介质 | |
CN112215353B (zh) | 一种基于变分结构优化网络的通道剪枝方法 | |
CN113657421B (zh) | 卷积神经网络压缩方法和装置、图像分类方法和装置 | |
CN113011570A (zh) | 一种卷积神经网络模型的自适应高精度压缩方法及系统 | |
CN111355633A (zh) | 一种基于pso-delm算法的比赛场馆内手机上网流量预测方法 | |
CN111738477A (zh) | 基于深层特征组合的电网新能源消纳能力预测方法 | |
CN110826692B (zh) | 一种自动化模型压缩方法、装置、设备及存储介质 | |
Tian et al. | A network traffic hybrid prediction model optimized by improved harmony search algorithm | |
CN112766600A (zh) | 一种城市区域人群流量预测方法及系统 | |
CN112598062A (zh) | 一种图像识别方法和装置 | |
CN114492978A (zh) | 一种基于多层注意力机制的时空序列预测方法及设备 | |
CN110874635B (zh) | 一种深度神经网络模型压缩方法及装置 | |
CN115659807A (zh) | 一种基于贝叶斯优化模型融合算法对人才表现预测的方法 | |
CN112272074B (zh) | 一种基于神经网络的信息传输速率控制方法及系统 | |
CN111260056B (zh) | 一种网络模型蒸馏方法及装置 | |
CN114385876B (zh) | 一种模型搜索空间生成方法、装置及系统 | |
CN115423162A (zh) | 一种车流量预测方法、装置、电子设备及存储介质 | |
CN113205182B (zh) | 一种基于稀疏剪枝方法的实时电力负荷预测系统 | |
CN112200275B (zh) | 人工神经网络的量化方法及装置 | |
CN113516163A (zh) | 基于网络剪枝的车辆分类模型压缩方法、装置及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |