CN114118272B - 用于深度学习模型的三段式训练方法 - Google Patents

用于深度学习模型的三段式训练方法 Download PDF

Info

Publication number
CN114118272B
CN114118272B CN202111425140.XA CN202111425140A CN114118272B CN 114118272 B CN114118272 B CN 114118272B CN 202111425140 A CN202111425140 A CN 202111425140A CN 114118272 B CN114118272 B CN 114118272B
Authority
CN
China
Prior art keywords
training
parameters
deep learning
learning model
parameter set
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202111425140.XA
Other languages
English (en)
Other versions
CN114118272A (zh
Inventor
黄�良
王晓峰
韩诚山
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Changchun Institute of Optics Fine Mechanics and Physics of CAS
Original Assignee
Changchun Institute of Optics Fine Mechanics and Physics of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Changchun Institute of Optics Fine Mechanics and Physics of CAS filed Critical Changchun Institute of Optics Fine Mechanics and Physics of CAS
Priority to CN202111425140.XA priority Critical patent/CN114118272B/zh
Publication of CN114118272A publication Critical patent/CN114118272A/zh
Application granted granted Critical
Publication of CN114118272B publication Critical patent/CN114118272B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A40/00Adaptation technologies in agriculture, forestry, livestock or agroalimentary production
    • Y02A40/10Adaptation technologies in agriculture, forestry, livestock or agroalimentary production in agriculture

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提出了一种用于深度学习模型的三段式训练方法,根据深度学习模型内部参数是否具有明确实际意义将其划分为两大部分,并且将整个训练过程分为三个阶段,在每个阶段都分别固定一部分参数,训练另一部分参数,对具有明确实际意义的参数给予额外的关注,用更多的人工介入来使深度学习模型获得更快的训练速度和更好的应用效果。本发明提供的用于深度学习模型的三段式训练方法在训练模型时收敛速度快,训练时间短,并且所得模型的性能与全局最优解差距小。

Description

用于深度学习模型的三段式训练方法
技术领域
本发明涉及深度学习的技术领域,具体的涉及一种用于深度学习模型的三段式训练方法。
背景技术
深度学习是机器学习领域中一个新的研究方向。深度学习是学习样本数据的内在规律和表示层次,这些学习过程中获得的信息对诸如文字,图像和声音等数据的解释有很大的帮助。它的最终目标是让机器能够像人一样具有分析学习能力,能够识别文字、图像和声音等数据。深度学习是一个复杂的机器学习算法,在语音和图像识别方面取得的效果,远远超过先前相关技术。深度学习在搜索技术,数据挖掘,机器学习,机器翻译,自然语言处理,多媒体学习,语音,推荐和个性化技术,以及其他相关领域都取得了很多成果。深度学习使机器模仿视听和思考等人类的活动,解决了很多复杂的模式识别难题,使得人工智能相关技术取得了很大进步。
深度学习技术在实际应用时,一般都需要先建立一个模型,再初始化一组参数,然后使用大量数据对其进行训练,最后才能将它用于解决相应的问题。在对深度学习模型进行训练时,现有的训练方法大多数都只在根据某种原则设定好一组超参数和初始值后,完全交由优化器自动调整内部参数去拟合数据,模型内部的每个参数都在训练过程中受到同种程度的训练,这样做虽然也能得到效果不错的模型,但往往收敛速度慢,训练时间长,且很可能最终所取得的分类精度会与全局最优解有较大差距。
发明内容
针对现有技术中深度学习模型收敛速度慢,训练时间长的不足之处,本发明提供一种用于深度学习模型的三段式训练方法,以解决现有技术缺陷。
为实现上述目的,本发明提供了一种用于深度学习模型的三段式训练方法,包括以下步骤:
S1、对深度学习模型中的参数进行分类,将深度学习模型中的参数划分为有明确实际意义的第一参数集合和无明确实际意义的第二参数集合;
S2、固定第一参数集合中的参数,对深度学习模型中的参数进行初始化,使用训练数据集对深度学习模型进行第一阶段训练,通过第一阶段训练所获得的参数来更新第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置;
S3、固定第二参数集合中的参数,使用训练数据集对深度学习模型进行第二阶段训练,通过第二阶段训练所获得的参数来更新第一参数集合中的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置;
S4、固定第一参数集合中的参数,使用训练数据集对深度学习模型进行第三阶段训练,通过第三阶段训练所获得的参数来更新第二参数集合中的参数;
S5、筛选出第三阶段训练中验证精度最高值所对应的参数,并载入至深度学习模型中,获得完成三阶段训练后的深度学习模型。
进一步地,步骤S2具体为:
固定第一参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行的训练,训练量不大于X个epoch;其中,X大于1;
深度学习模型每完成一个epoch数据量的训练进行一次验证精度计算,获得训练N个epoch后深度学习模型在验证集上的总体分类精度vOAN,并通过公式(1)计算深度学习模型的验证精度变化量,直至连续M数量的深度学习模型的验证精度变化量小于第一预设阈值时,停止训练完成第一阶段训练,并获取完成第一阶段训练后的参数,使用第一阶段训练后的参数更新第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置;,得到第一阶段训练后的第二参数集合;其中,X大于1,N大于1;公式(1)如下所示:
ΔvOAN=vOAN-vOAN-1 (1)
其中,ΔvOAN表示第N个epoch的验证精度变化量,vOAN表示第N个epoch的验证精度,vOAN-1表示第N-1个epoch的验证精度。
进一步地,步骤S3具体为:
固定经第一阶段训练后的第二参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行不大于X个epoch的训练;
深度学习模型每完成一个epoch数据量的训练进行一次验证精度计算,获得训练Q个epoch后深度学习模型在验证集上的总体分类精度vOAQ,并通过公式(2)计算深度学习模型的验证精度变化量,直至连续P数量的深度学习模型的验证精度变化量小于第二预设阈值时,停止训练完成第二阶段训练,并获取完成第二阶段训练后的参数,使用第二阶段训练后的参数更新第一参数集合中的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置,得到第二阶段训练后的第一参数集合;其中,Q大于1;公式(2)如下所示:
ΔvOAQ=vOAQ-vOAQ-1 (2)
其中,ΔvOAQ表示第Q个epoch的验证精度变化量,vOAQ表示第Q个epoch的验证精度,vOAQ-1表示第Q-1个epoch的验证精度。
进一步地,步骤S4具体为:
固定第二阶段训练后的第一参数集合中的参数,使用训练数据集对深度学习模型进行训练,深度学习模型每完成一次训练进行一次验证精度计算,获得验证精度vOA,直至验证精度vOA达到第三预设阈值时或epoch值达到预设定的上限时,停止训练完成第三阶段训练,得到第三阶段训练后的深度学习模型的参数。
与现有技术相比,本发明的有益效果为:
本发明提供的用于深度学习模型的三段式训练方法,根据深度学习模型内部参数是否具有明确实际意义将其分成两大部分,并且将整个训练过程分为三个阶段,在每个阶段都分别固定一部分参数,训练另一部分参数,对具有明确实际意义的参数给予额外的关注,用更多的人工介入来使深度学习模型获得更快的训练速度和更好的应用效果。本发明提供的用于深度学习模型的三段式训练方法在训练模型时收敛速度快,训练时间短,并且所得模型的性能与全局最优解差距小。
附图说明
图1是本发明实施例1中的空谱联合压缩激励残差网络模型的基本结构示意图;
图2(a)是本发明实施例1的University of Pavia数据集的假色彩合成示意图;
图2(b)是本发明实施例1的University of Pavia数据集的真实地物分布示意图;
图2(c)是本发明实施例1的University of Pavia数据集的样本分布示意图表;
图3是本发明实施例1中的用于深度学习模型的三段式训练方法的流程示意图;
图4(a)是本发明实施例1中使用最终的完成三阶段训练后的深度学习模型对高光谱图像进行处理后分类直观效果示意图;
图4(b)是本发明实施例1中使用最终的完成三阶段训练后的深度学习模型对高光谱图像进行处理后的分类精度详情示意图表;
图5(a)是本发明实施例1中传统训练方法验证精度收敛曲线示意图;
图5(b)是本发明实施例1中三段式训练方法验证精度收敛曲线示意图;
图6是本发明实施例2中时空联合特征学习块模型的基本结构示意图;
图7是本发明实施例2中UCF101数据集的基本结构示意图表;
图8(a)是本发明实施例2中传统训练方法验证精度收敛曲线示意图;
图8(b)是本发明实施例2中三段式训练方法验证精度收敛曲线示意图。
具体实施方式
下面结合附图和实施例对本发明的实施方式作进一步详细描述。需要说明的是,在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互结合。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明提供的用于深度学习模型的三段式训练方法适合应用在具有空谱特征融合权重参数的深度学习模型,提出了一种能够对不同参数在训练过程中给予不同程度关注的训练方法。
实施例1:
本发明提供的实施例1应用的领域为:图像处理中对于高光谱图像分类的技术领域。以高光谱图像分类的深度学习模型为实施例进行详细说明。根据高光谱图像分类的性质,本发明提供的实施例1使用空谱联合压缩激励残差网络(Spatial-Spectral Squeeze-and-Excitation Residual Network,以下简称SSSERN)作为深度学习的模型,该模型的基本结构如图1所示。本发明实施例1使用该领域常用的University of Pavia(以下简称UP)作为数据集,UP数据集的基本结构如图2所示。本发明提供的实施例1选择交叉熵损失函数作为损失函数,选择Adam优化器作为优化器。
图3示出了本发明实施例的用于深度学习模型的三段式训练方法的流程示意图。如图3所示,本发明提供的实施例1提供一种用于深度学习模型的三段式训练方法,包括如下步骤:
S1、以明确实际意义为依据对深度学习模型中的参数进行分类,划分为有明确实际意义的第一参数集合和无明确实际意义的第二参数集合。
本发明提供的实施例1将深度学习模型中每个基本子块内进行过压缩激励(Squeeze and Excitation,SE)操作的空间特征和光谱特征在加权融合时所用的权值,即∝1和∝2划分进第一参数集合,其他参数划分进第二参数集合。
S2、固定所述第一参数集合中的参数,对所述深度学习模型中的参数进行初始化,使用训练数据集对所述深度学习模型进行第一阶段训练,通过第一阶段训练所获得的参数来更新所述第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置。
对SSSERN进行第一阶段训练,建立SSSERN的第一个实例model1,在建立model1时设置第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以调整。对模型中的参数进行初始化,在本实施例1中第一参数集合中参数值的初始化方式为:以对具体应用目标来说效果最均衡为依据,令每个分支的权值都相等;第二参数集合中的参数值的确定方式为:采用任意初始化方法获取。使用训练数据集对SSSERN进行第一阶段训练,通过第一阶段训练所获得的参数来更新第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置。
S3、固定所述第二参数集合中的参数,使用训练数据集对所述深度学习模型进行第二阶段训练,通过第二阶段训练所获得的参数来更新所述第一参数集合中的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置。
建立SSSERN的第二个实例model2,在建立model2时设置第二参数集合中的参数固定不变,其他参数(即第一参数集合中的参数)可以调整。加载经第一阶段训练更新后的参数,使用训练数据集对SSSERN进行第二阶段训练,通过第二阶段训练所获得的参数来更新第一参数集合中的参数。
S4、固定所述第一参数集合中的参数,使用训练数据集对所述深度学习模型进行第三阶段训练,通过第三阶段训练所获得的参数来更新所述第二参数集合中的参数。
建立SSSERN的第三个实例model3,在建立model3时设置第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以调整。加载经第二阶段训练更新后的参数,使用训练数据集对SSSERN进行第三阶段训练,通过第三阶段训练所获得的参数来更新第二参数集合中的参数。
S5、筛选出第三阶段训练中验证精度最高值所对应的参数,并载入至所述深度学习模型中,获得完成三阶段训练后的深度学习模型。
本发明提供的实施例1提供一种优选方案,实施例1中步骤S2具体为:
固定第一参数集合中的参数,其他参数(即第二参数集合中的参数)可以调整。初始化model1中的参数,其中第一参数集合中的参数均取为0.5,此时空间特征和光谱特征对于最终的分类任务有相同的权重,第二参数集合中的参数采用随机初始化的方式。
使用已经确定好的数据集UP、数据加载器、交叉熵损失函数、Adam优化器对已经完成参数初始化的model1进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为100,训练所使用的方式为:早停(early stopping)策略。
固定第一参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行的训练,训练量不大于100个epoch。深度学习模型每完成一个epoch数据量的训练都进行一次验证精度计算,获得训练N个epoch后所述深度学习模型在验证集上的总体分类精度vOAN,并通过公式(1)计算深度学习模型的验证精度变化量,直至连续3个数量的深度学习模型的验证精度变化量小于0.05时,即连续3个ΔvOAN均小于0.05,停止训练完成第一阶段训练,获取并保存完成第一阶段训练后的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置,得到第一阶段训练后的第二参数集合,结束第一阶段训练。其中,N等于3;公式(1)如下所示:
ΔvOAN=vOAN-vOAN-1 (1)。
其中,ΔvOAN表示第N个epoch的验证精度变化量,vOAN表示第N个epoch的验证精度,vOAN-1表示第N-1个epoch的验证精度。
本发明提供的实施例1提供一种优选方案,实施例1中步骤S3具体为:
建立SSSERN的第二个实例model2,在建立model2时设置第二参数集合中的参数固定不变,其他参数(即第一参数集合中的参数)可以调整。加载经第一阶段训练更新后的参数。
使用已经确定好的数据集UP、数据加载器、交叉熵损失函数、Adam优化器对已经完成参数加载的model2进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为100,训练所使用的方式为:早停(early stopping)策略。
固定经第一阶段训练后的第二参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行不大于100个epoch的训练。深度学习模型每完成一个epoch数据量训练,都对所计算出的验证精度的进行判断,计算获得训练后模型在验证集上的总体分类精度vOA,并通过公式(2)计算深度学习模型的验证精度变化量,直至连续3个数量的深度学习模型的验证精度变化量小于0.03时,即连续3个ΔvOAQ均小于0.03,停止训练完成第二阶段训练,获取并保存完成第二阶段训练后的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置,得到第二阶段训练后的第一参数集合,结束第二阶段训练,Q等于3。公式(2)如下所示:
ΔvOAQ=vOAQ-vOAQ-1 (2)
其中,ΔvOAQ表示第Q个epoch的验证精度变化量,vOAQ表示第Q个epoch的验证精度,vOAQ-1表示第Q-1个epoch的验证精度。
本发明提供的实施例1提供一种优选方案,实施例1中步骤S4具体为:
建立SSSERN的第三个实例model3,在建立model3时设置第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以进行微调整。加载经第二阶段训练更新后的参数。
使用已经确定好的数据集UP、数据加载器、交叉熵损失函数、Adam优化器对已经完成参数加载的model3进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为100。当满足预设定的验证精度要求或epoch值达到上限值100时,停止训练完成第三阶段训练,通过第三阶段训练所获得的参数来更新第二参数集合中的参数。第三阶段训练终止后,整个第三阶段训练过程中最高的验证精度对应的那次训练所得模型中的参数就是最终参数。
将所得到的最终参数加载进入SSSERN中,使用最终的完成三阶段训练后的深度学习模型对高光谱图像进行处理,处理结果如图4所示。
图5(a)和图5(b)示出了本发明实施例1中传统训练方法和三段式训练方法对比示意图,在训练过程中所得模型的验证精度随epoch的变化关系,其中三段式训练方法对应的图5(b)中用两根实线分出了训练过程的三个阶段,可以看到通过三段式训练方法的曲线在后续训练过程中能够收敛到更高的地方,并且如果要达到相同的分类精度,三段式训练方法所需要的迭代次数更少。
本发明提供的实施例1提供一种优选方案,在步骤S1前还包括如下步骤:
S0、将深度学习模型中的UP数据集分为:训练集、验证集和测试集,三者的比例为15:5:80,设置数据加载器的批大小为100。
训练集、验证集、测试集的大小需要根据具体的数据集和对模型性能的要求确定,本发明实施例1对此不进行限定。其中,数据集越大、对泛化能力要求越低,则所选择的训练集的比例就可以越小;所选验证集的比例越大,则训练时间越长,同时也越容易获得更强的泛化能力。
实施例2:
本发明提供的实施例2应用的领域为姿态识别技术领域。根据姿态识别的性质,本发明提供的实施例2使用基于时空联合(Collaborative Spatiotemporal,CoST)特征学习块构建的网络作为深度学习的模型(以下简称CoSTNet),该特征学习块的基本结构如图6所示,输入的C1个特征图会经过三种卷积核的卷积得到C2×3个特征图,每种卷积都会分别提取出一类特征:xhw为空间特征,xtw为横向时间特征,xth为纵向时间特征,这C2×3个特征图会与C2×3个权重α值进行乘加操作,实现三类特征的加权融合。整个深度学习的模型网络的构建过程如下:以C2D网络为基础,将C2D网络中每两个残差块(参考C3D网络的结构)内的1个2维卷积操作替换为CoST块。本发明实施例2使用该领域常用的UCF101数据集作为数据集,UCF101数据集的基本结构如图7所示。本发明提供的实施例2选择交叉熵损失函数作为损失函数,选择动量梯度下降优化器(SGD with momentum)作为优化器。
图3示出了本发明实施例的用于深度学习模型的三段式训练方法的流程示意图。如图3所示,本发明提供的实施例2提供一种用于深度学习模型的三段式训练方法,包括如下步骤:
S1、以明确实际意义为依据对深度学习模型中的参数进行分类,划分为有明确实际意义的第一参数集合和无明确实际意义的第二参数集合。
本发明提供的实施例2将深度学习模型中每个CoST块内的C2×3个权重α值划分进第一参数集合,其他参数划分进第二参数集合。
S2、固定所述第一参数集合中的参数,对所述深度学习模型中的参数进行初始化,使用训练数据集对所述深度学习模型进行第一阶段训练,通过第一阶段训练所获得的参数来更新所述第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置。
对CoSTNet进行第一阶段训练,建立CoSTNet的第一个实例model1,在建立model1时设置第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以调整。对模型中的参数进行初始化,在本实施例2中第一参数集合中参数值的确定方式为:三类时空特征的每个通道的特征图对于最终的分类任务都有相同的权重;第二参数集合中的参数值的确定方式为:采用随机初始化方法获取,使用训练数据集对CoSTNet进行第一阶段训练,通过第一阶段训练所获得的参数来更新第二参数集合中的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置。
S3、固定所述第二参数集合中的参数,使用训练数据集对所述深度学习模型进行第二阶段训练,通过第二阶段训练所获得的参数来更新所述第一参数集合中的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置。
建立CoSTNet的第二个实例model2,在建立model2时设置第二参数集合中的参数固定不变,其他参数(即第一参数集合中的参数)可以调整。加载经第一阶段训练更新后的参数,使用训练数据集对CoSTNet进行第二阶段训练,通过第二阶段训练所获得的参数来更新第一参数集合中的参数。
S4、固定所述第一参数集合中的参数,使用训练数据集对所述深度学习模型进行第三阶段训练,通过第三阶段训练所获得的参数来更新所述第二参数集合中的参数。
建立CoSTNet的第三个实例model3,在建立model3时设置第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以调整。加载经第二阶段训练更新后的第一参数集合中的参数,使用训练数据集对CoSTNet进行第三阶段训练,通过第三阶段训练所获得的参数来更新第二参数集合中的参数。
S5、筛选出第三阶段训练中验证精度最高值所对应的参数,并载入至所述深度学习模型中,获得完成三阶段训练后的深度学习模型。
本发明提供的实施例2提供一种优选方案,实施例2中步骤S2具体为:
固定第一参数集合中的参数,其他参数(即第二参数集合中的参数)可以调整。初始化model1中的参数,其中第一参数集合中的参数值αi均按公式(3)计算,公式(3)如下所示:
其中,αi表示第i个CoST块中的权值α,C2-i表示第i个CoST块中一种卷积操作所得的特征图通道数。此时第一参数集合中的三类时空特征权重参数对于最终的分类任务有相同的权重,第二参数集合中的参数采用随机初始化的方式。
使用已经确定好的数据集UCF101、数据加载器、交叉熵损失函数、动量梯度下降优化器对已经完成参数初始化的model1进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为8,训练所使用的方式为:早停(early stopping)策略。
固定第一参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行的训练,训练量不大于100个epoch。深度学习模型每完成一个epoch数据量的训练都进行一次验证精度计算,获得训练N个epoch后所述深度学习模型在验证集上的总体分类精度vOAN,并通过公式(1)计算深度学习模型的验证精度变化量,直至连续3个数量的深度学习模型的验证精度变化量小于0.05时,即连续3个ΔvOAN均小于0.05,停止训练完成第一阶段训练,获取并保存完成第一阶段训练后的参数,更新后的第二参数集合中的参数将被载入至深度学习模型中的相应位置,得到第一阶段训练后的第二参数集合,结束第一阶段训练。其中,N等于3;公式(1)如下所示:
ΔvOAN=vOAN-vOAN-1 (1)。
其中,ΔvOAN表示第N个epoch的验证精度变化量,vOAN表示第N个epoch的验证精度,vOAN-1表示第N-1个epoch的验证精度。
本发明提供的实施例2提供一种优选方案,实施例2中步骤S3具体为:
建立CoSTNet的第二个实例model2,在建立model2时设置第一阶段训练后的第二参数集合中的参数固定不变,其他参数(即第一参数集合中的参数)可以调整。加载经第一阶段训练更新后的参数。
使用已经确定好的数据集UCF101、数据加载器、交叉熵损失函数、动量梯度下降优化器已经完成参数加载的model2进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为8,训练所使用的方式为:早停(early stopping)策略。
固定经第一阶段训练后的第二参数集合中的参数,使用训练数据集对深度学习模型采用早停策略进行不大于100个epoch的训练。深度学习模型每完成一个epoch数据量训练,都对所计算出的验证精度的进行判断,计算获得训练后模型在验证集上的总体分类精度vOA,并通过公式(2)计算深度学习模型的验证精度变化量,直至连续3个数量的深度学习模型的验证精度变化量小于0.03时,即连续3个ΔvOAQ均小于0.03,停止训练完成第二阶段训练,获取并保存完成第二阶段训练后的参数,更新后的第一参数集合中的参数将被载入至深度学习模型中的相应位置,得到第二阶段训练后的第一参数集合,结束第二阶段训练,其中,Q等于3。公式(2)如下所示:
ΔvOAQ=vOAQ-vOAQ-1 (2)
其中,ΔvOAQ表示第Q个epoch的验证精度变化量,vOAQ表示第Q个epoch的验证精度,vOAQ-1表示第Q-1个epoch的验证精度。
本发明提供的实施例2提供一种优选方案,实施例2中步骤S4具体为:
建立CoSTNet的第三个实例model3,在建立model3时设置第二阶段训练后的第一参数集合中的参数固定不变,其他参数(即第二参数集合中的参数)可以进行微调整。加载经第二阶段训练更新后的参数。
使用已经确定好的数据集UCF101、数据加载器、交叉熵损失函数、动量梯度下降优化器已经完成参数初始化的model3进行训练。其中,在训练时需要加载的数据集为训练数据集,批大小为8。当满足预设定的验证精度要求或epoch值达到上限100时,停止训练完成第三阶段训练,通过第三阶段训练所获得的参数来更新第二参数集合中的参数。第三阶段训练终止后,整个第三阶段训练过程中最高的验证精度对应的那次训练所得模型中的参数就是最终参数。
将所得到的最终参数加载进入CoSTNet中,即为最终的完成三阶段训练后的深度学习模型。
图8a和图8b示出了本发明实施例2中传统训练方法和三段式训练方法对比示意图,在训练过程中所得模型的验证精度随epoch的变化关系,其中三段式训练方法对应的图8b中用两根实线分出了训练过程的三个阶段,可以看到通过三段式训练方法的曲线在后续训练过程中能够收敛到更高的地方,并且如果要达到相同的分类精度,三段式训练方法所需要的迭代次数更少。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制。本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。
以上本发明的具体实施方式,并不构成对本发明保护范围的限定。任何根据本发明的技术构思所做出的各种其他相应的改变与变形,均应包含在本发明权利要求的保护范围内。

Claims (1)

1.一种应用于高光谱图像分类领域的用于深度学习模型的三段式训练方法,其特征在于,包括如下步骤:
S1、对深度学习模型中的参数进行分类,将所述深度学习模型中的参数划分为有明确实际意义的第一参数集合和无明确实际意义的第二参数集合;
S2、固定所述第一参数集合中的参数,对所述深度学习模型中的参数进行初始化,使用训练数据集对所述深度学习模型进行第一阶段训练,通过第一阶段训练所获得的参数来更新所述第二参数集合中的参数,更新后的所述第二参数集合中的参数将被载入至所述深度学习模型中的相应位置;
S3、固定所述第二参数集合中的参数,使用所述训练数据集对所述深度学习模型进行第二阶段训练,通过第二阶段训练所获得的参数来更新所述第一参数集合中的参数,更新后的所述第一参数集合中的参数将被载入至所述深度学习模型中的相应位置;
S4、固定所述第一参数集合中的参数,使用所述训练数据集对所述深度学习模型进行第三阶段训练,通过第三阶段训练所获得的参数来更新所述第二参数集合中的参数;
S5、筛选出第三阶段训练中验证精度最高值所对应的参数,并载入至所述深度学习模型中,获得完成三阶段训练后的深度学习模型;
所述步骤S2具体为:
固定所述第一参数集合中的参数,使用训练数据集对所述深度学习模型采用早停策略进行训练,训练量不大于X个epoch;其中,X大于1;
所述深度学习模型每完成一个epoch数据量的训练进行一次验证精度计算,获得训练N个epoch后所述深度学习模型在验证集上的总体分类精度,并通过公式(1)计算所述深度学习模型的验证精度变化量,直至连续M数量的所述深度学习模型的验证精度变化量小于第一预设阈值时,停止训练完成第一阶段训练,并获取完成第一阶段训练后的参数,使用所述第一阶段训练后的参数更新所述第二参数集合中的参数,更新后的所述第二参数集合中的参数将被载入至所述深度学习模型中的相应位置,得到所述第一阶段训练后的第二参数集合;其中,X大于1,N大于1;公式(1)如下所示:
(1)
其中,表示第N个epoch的验证精度变化量,/>表示第N个epoch的验证精度,表示第N-1个epoch的验证精度;
所述步骤S3具体为:
固定经第一阶段训练后的所述第二参数集合中的参数,使用训练数据集对所述深度学习模型采用早停策略进行不大于X个epoch的训练;
所述深度学习模型每完成一个epoch数据量的训练进行一次验证精度计算,获得训练Q个epoch后所述深度学习模型在验证集上的总体分类精度,并通过公式(2)计算所述深度学习模型的验证精度变化量,直至连续P数量的所述深度学习模型的验证精度变化量小于第二预设阈值时,停止训练完成第二阶段训练,并获取完成第二阶段训练后的参数,使用所述第二阶段训练后的参数更新所述第一参数集合中的参数,更新后的所述第一参数集合中的参数将被载入至所述深度学习模型中的相应位置,得到所述第二阶段训练后的第一参数集合;其中,Q大于1;公式(2)如下所示:
(2)
其中,表示第Q个epoch的验证精度变化量,/>表示第Q个epoch的验证精度,表示第Q-1个epoch的验证精度;
所述步骤S4具体为:
固定所述第二阶段训练后的所述第一参数集合中的参数,使用训练数据集对所述深度学习模型进行训练,所述深度学习模型每完成一次训练进行一次验证精度计算,获得验证精度,直至验证精度/>达到第三预设阈值时或epoch值达到预设定的上限时,停止训练完成第三阶段训练,得到所述第三阶段训练后的深度学习模型的参数。
CN202111425140.XA 2021-11-26 2021-11-26 用于深度学习模型的三段式训练方法 Active CN114118272B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111425140.XA CN114118272B (zh) 2021-11-26 2021-11-26 用于深度学习模型的三段式训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111425140.XA CN114118272B (zh) 2021-11-26 2021-11-26 用于深度学习模型的三段式训练方法

Publications (2)

Publication Number Publication Date
CN114118272A CN114118272A (zh) 2022-03-01
CN114118272B true CN114118272B (zh) 2024-04-30

Family

ID=80370640

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111425140.XA Active CN114118272B (zh) 2021-11-26 2021-11-26 用于深度学习模型的三段式训练方法

Country Status (1)

Country Link
CN (1) CN114118272B (zh)

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108734193A (zh) * 2018-03-27 2018-11-02 合肥麟图信息科技有限公司 一种深度学习模型的训练方法及装置
CN110633730A (zh) * 2019-08-07 2019-12-31 中山大学 一种基于课程学习的深度学习机器阅读理解训练方法
CN110689045A (zh) * 2019-08-23 2020-01-14 苏州千视通视觉科技股份有限公司 一种深度学习模型的分布式训练方法及装置
CN111160538A (zh) * 2020-04-02 2020-05-15 北京精诊医疗科技有限公司 一种损失函数中margin参数值的更新方法和系统
WO2020249125A1 (zh) * 2019-06-14 2020-12-17 第四范式(北京)技术有限公司 用于自动训练机器学习模型的方法和系统
CN112949837A (zh) * 2021-04-13 2021-06-11 中国人民武装警察部队警官学院 一种基于可信网络的目标识别联邦深度学习方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210142210A1 (en) * 2019-11-11 2021-05-13 Alibaba Group Holding Limited Multi-task segmented learning models
US20210158147A1 (en) * 2019-11-26 2021-05-27 International Business Machines Corporation Training approach determination for large deep learning models

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108734193A (zh) * 2018-03-27 2018-11-02 合肥麟图信息科技有限公司 一种深度学习模型的训练方法及装置
WO2020249125A1 (zh) * 2019-06-14 2020-12-17 第四范式(北京)技术有限公司 用于自动训练机器学习模型的方法和系统
CN110633730A (zh) * 2019-08-07 2019-12-31 中山大学 一种基于课程学习的深度学习机器阅读理解训练方法
CN110689045A (zh) * 2019-08-23 2020-01-14 苏州千视通视觉科技股份有限公司 一种深度学习模型的分布式训练方法及装置
CN111160538A (zh) * 2020-04-02 2020-05-15 北京精诊医疗科技有限公司 一种损失函数中margin参数值的更新方法和系统
CN112949837A (zh) * 2021-04-13 2021-06-11 中国人民武装警察部队警官学院 一种基于可信网络的目标识别联邦深度学习方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
一种基于两阶段深度学习的集成推荐模型;王瑞琴;吴宗大;蒋云良;楼俊钢;计算机研究与发展;20191231(第008期);全文 *

Also Published As

Publication number Publication date
CN114118272A (zh) 2022-03-01

Similar Documents

Publication Publication Date Title
Too et al. A comparative study of fine-tuning deep learning models for plant disease identification
JP6980958B1 (ja) 深層学習に基づく農村地域分けゴミ識別方法
WO2019228122A1 (zh) 模型的训练方法、存储介质及计算机设备
CN106874921B (zh) 图像分类方法和装置
Othman et al. A new deep learning application based on movidius ncs for embedded object detection and recognition
CN110334759B (zh) 一种评论驱动的深度序列推荐方法
Fu et al. DSAGAN: A generative adversarial network based on dual-stream attention mechanism for anatomical and functional image fusion
CN112561027A (zh) 神经网络架构搜索方法、图像处理方法、装置和存储介质
CN108614992A (zh) 一种高光谱遥感图像的分类方法、设备及存储设备
CN112101432B (zh) 一种基于深度学习的材料显微图像与性能双向预测方法
CN110852369B (zh) 联合3d/2d卷积网络和自适应光谱解混的高光谱图像分类方法
Doi et al. The effect of focal loss in semantic segmentation of high resolution aerial image
Kim et al. Label-preserving data augmentation for mobile sensor data
CN109284782A (zh) 用于检测特征的方法和装置
CN112580720A (zh) 一种模型训练方法及装置
CN110210278A (zh) 一种视频目标检测方法、装置及存储介质
CN114511710A (zh) 一种基于卷积神经网络的图像目标检测方法
KR20220116270A (ko) 학습 처리 장치 및 방법
CN110096976A (zh) 基于稀疏迁移网络的人体行为微多普勒分类方法
Yang et al. Ultra-lightweight CNN design based on neural architecture search and knowledge distillation: A novel method to build the automatic recognition model of space target ISAR images
CN112507114A (zh) 一种基于词注意力机制的多输入lstm_cnn文本分类方法及系统
CN112308825A (zh) 一种基于SqueezeNet的农作物叶片病害识别方法
CN106023268A (zh) 一种基于两步参数子空间优化的彩色图像灰度化方法
CN113936143B (zh) 基于注意力机制和生成对抗网络的图像识别泛化方法
Hu et al. Saliency-based YOLO for single target detection

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant