CN113888524A - 缺陷检测模型训练方法、装置、设备及可读存储介质 - Google Patents

缺陷检测模型训练方法、装置、设备及可读存储介质 Download PDF

Info

Publication number
CN113888524A
CN113888524A CN202111218168.6A CN202111218168A CN113888524A CN 113888524 A CN113888524 A CN 113888524A CN 202111218168 A CN202111218168 A CN 202111218168A CN 113888524 A CN113888524 A CN 113888524A
Authority
CN
China
Prior art keywords
model
training
defect
gradient value
defect detection
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111218168.6A
Other languages
English (en)
Inventor
钱程浩
黄雪峰
熊海飞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Xinrun Fulian Digital Technology Co Ltd
Original Assignee
Shenzhen Xinrun Fulian Digital Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Xinrun Fulian Digital Technology Co Ltd filed Critical Shenzhen Xinrun Fulian Digital Technology Co Ltd
Priority to CN202111218168.6A priority Critical patent/CN113888524A/zh
Publication of CN113888524A publication Critical patent/CN113888524A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/0002Inspection of images, e.g. flaw detection
    • G06T7/0004Industrial image inspection
    • G06T7/0006Industrial image inspection using a design-rule based approach
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/10Image acquisition modality
    • G06T2207/10004Still image; Photographic image

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • Quality & Reliability (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种缺陷检测模型训练方法、装置、设备及计算机可读存储介质,通过混合精度训练,将模型参数进行精度降低调节后再进行前后向传播计算,能够在保证模型进度的情况下节省计算内存带宽以及显存占用,同时也节省了数据传输时间,减少模型训练时间;通过对模型梯度值进行溢出判断,采用比例因子对原始的梯度值进行扩大,并采用动态调整策略对模型梯度值进行调整,避免了因模型参数精度转换而可能导致的梯度消失最终致使模型训练失败的情况,最终达到了保证模型有效训练的情况下大大提升模型训练迭代效率,提高模型验证速度,使得缺陷检测模型在应对多种场景检测需求时可以做到模型快速迭代和结果验证。

Description

缺陷检测模型训练方法、装置、设备及可读存储介质
技术领域
本发明涉及自动化技术领域,尤其涉及缺陷检测模型训练方法、装置、设备及计算机可读存储介质。
背景技术
随着自动化技术的快速发展,在工业生产中很多需要人工操作的环节逐渐由机器完成。而机器零件产品表面缺陷检测是工业生产中的重要环节,是产品质量把控的关键步骤,借助缺陷检测技术可以有效的提高生产质量和效率。但是由于设备及工艺等因素的影响,产品表面的缺陷类型往往五花八门。目前,基于深度学习的缺陷检测已经应用于金属固件、布匹丝织物、建筑裂纹、钢筋裂纹等多个领域。
当前的缺陷检测系统主要使用单阶段目标检测器(SSD,Single shot multiboxdetector)实现对缺陷图像的处理。这一检测器在前期训练阶段同大多数基于深度神经网络的训练方式相同,因此为了使得模型收敛且达到较好的泛化效果,往往需要大量训练数据,并且为了使得深度网络表达的特征丰富,在设计时会使得网络结构越深、越复杂。具体地,在SSD缺陷检测模型中,使用了VGG-16(VGG指的是由牛津大学的一个研究组织开发出的一种卷积神经网络)作为特征提取网络。模型的总参数量为138M,所占内存为526M(138M*4bytes)。如此大的参数量和内存占用会对显卡计算产生较大压力,如果显卡算力不够会导致模型训练耗时较长,如果显卡内存不够会直接导致模型训练失败。
发明内容
本发明的主要目的在于提出一种缺陷检测模型训练方法、装置、设备及计算机可读存储介质,旨在解决现有基于深度学习的缺陷检测模型的参数量和计算量较大,在算力较小和显卡内存较小的设备上,容易出现训练效率低下或训练失败的技术问题。
为实现上述目的,本发明提供一种缺陷检测模型训练方法,所述缺陷检测模型训练方法包括:
在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
可选地,所述缺陷检测模型包括由多个卷积层组成的特征提取网络,
所述在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练的步骤包括:
在所述缺陷检测模型的特征提取网络中,根据所述缺陷图像训练数据得到权重和激活值,以作为所述模型参数;
将所述权重和激活值由32位浮点型转换为16位浮点型,以按照16位浮点型的权重和激活值在各所述卷积层之间进行前向传播。
可选地,所述缺陷检测模型还包括单阶段目标检测器,
所述获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况的步骤包括:
获取所述特征提取网络基于所述训练缺陷图像生成的特征图像,并根据所述单阶段目标检测器得到所述特征图像对应的缺陷预测值;
根据所述缺陷预测值和所述缺陷图像训练数据中的缺陷真实值得到原始损失值,并结合所述比例因子扩大所述原始损失值得到损失值;
根据所述损失值得到所述模型梯度值,并判断所述模型梯度值是否溢出。
可选地,所述动态调整策略包括第一调整策略,
所述针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值的步骤包括:
若所述模型梯度值溢出,则按照所述第一调整策略缩小所述比例因子,并基于缩小后的比例因子和所述原始损失值得到新的模型梯度值,作为所述更新模型梯度值。
可选地,所述动态调整策略包括第二调整策略,
所述针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值的步骤包括:
若所述模型梯度值未溢出,则按照所述第二调整策略将所述模型梯度值按照所述比例因子进行缩小还原,将还原后的模型梯度值作为所述更新模型梯度值。
可选地,所述基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件的步骤包括:
基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,并判断模型迭代训练过程中是否连续预设次数未出现模型梯度值溢出;
若是,则增大所述比例因子,并根据增大后的比例因子得到目标模型梯度值;
基于所述目标模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
可选地,所述基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件的步骤包括:
根据所述更新模型梯度值回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至当前迭代轮次对应的模型梯度值小于预设梯度阈值,或是当前迭代轮次达到预设轮次阈值时,判定满足所述模型收敛条件。
此外,为实现上述目的,本发明还提供一种缺陷检测模型训练装置,所述缺陷检测模型训练装置包括:
混合精度训练模块,用于在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
溢出情况确定模块,用于获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
模型梯度调整模块,用于针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
模型迭代训练模块,用于基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
此外,为实现上述目的,本发明还提供一种缺陷检测模型训练设备,所述缺陷检测模型训练设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的缺陷检测模型训练程序,所述缺陷检测模型训练程序被所述处理器执行时实现如上所述的缺陷检测模型训练方法的步骤。
此外,为实现上述目的,本发明还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有缺陷检测模型训练程序,所述缺陷检测模型训练程序被处理器执行时实现如上所述的缺陷检测模型训练方法的步骤。
此外,为实现上述目的,本发明还提供一种计算机可读存储介质,包括计算机程序,所述计算机程序被处理器执行时实现如上述的缺陷检测模型训练方法的步骤。
本发明通过混合精度训练,将模型参数进行精度降低调节后再进行前后向传播计算,能够在保证模型进度的情况下节省计算内存带宽以及显存占用,同时也节省了数据传输时间,减少模型训练时间;通过对模型梯度值进行溢出判断,采用比例因子对原始的梯度值进行扩大,并采用动态调整策略对模型梯度值进行调整,避免了因模型参数精度转换而可能导致的梯度消失最终致使模型训练失败的情况,最终达到了在保证模型有效训练的情况下大大提升模型训练迭代效率,提高模型验证速度,从而解决了现有基于深度学习的缺陷检测模型的参数量和计算量较大,在算力较小和显卡内存较小的设备上,容易出现训练效率低下或训练失败的技术问题。
附图说明
图1是本发明实施例方案涉及的硬件运行环境的设备结构示意图;
图2为本发明缺陷检测模型训练方法第一实施例的流程示意图;
图3为本发明缺陷检测模型训练方法第二实施例中一具体实施例的模型架构示意图;
图4为本发明缺陷检测模型训练方法第二实施例中一具体实施例的梯度动态调整流程示意图;
图5为本发明缺陷检测模型训练装置的功能模块示意图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
当前的缺陷检测系统主要使用单阶段目标检测器(SSD,Single shot multiboxdetector)实现对缺陷图像的处理。这一检测器在前期训练阶段同大多数基于深度神经网络的训练方式相同,因此为了使得模型收敛且达到较好的泛化效果,往往需要大量训练数据,并且为了使得深度网络表达的特征丰富,在设计时会使得网络结构越深、越复杂。具体地,在SSD缺陷检测模型中,使用了VGG-16(VGG指的是由牛津大学的一个研究组织开发出的一种卷积神经网络)作为特征提取网络。模型的总参数量为138M,所占内存为526M(138M*4bytes)。如此大的参数量和内存占用会对显卡计算产生较大压力,如果显卡算力不够会导致模型训练耗时较长,如果显卡内存不够会直接导致模型训练失败。并且为了使得模型收敛且达到较好的泛化效果,基于深度学习的模型训练需要的数据量非常大,使得在多轮训练时数据读取成为非常耗时的部分。
为解决上述问题,本发明提供一种缺陷检测模型训练方法,即通过混合精度训练,将模型参数进行精度降低调节后再进行前后向传播计算,能够在保证模型进度的情况下节省计算内存带宽以及显存占用,同时也节省了数据传输时间,减少模型训练时间;通过对模型梯度值进行溢出判断,采用比例因子对原始的梯度值进行扩大,并采用动态调整策略对模型梯度值进行调整,避免了因模型参数精度转换而可能导致的梯度消失最终致使模型训练失败的情况,最终达到了在保证模型有效训练的情况下大大提升模型训练迭代效率,提高模型验证速度,从而解决了现有基于深度学习的缺陷检测模型的参数量和计算量较大,在算力较小和显卡内存较小的设备上,容易出现训练效率低下或训练失败的技术问题。
如图1所示,图1是本发明实施例方案涉及的硬件运行环境的设备结构示意图。
如图1所示,该缺陷检测模型训练装置可以包括:处理器1001,例如CPU,用户接口1003,网络接口1004,存储器1005,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(Display)、输入单元比如键盘(Keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如WI-FI接口)。存储器1005可以是高速RAM存储器,也可以是稳定的存储器(non-volatile memory),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的设备结构并不构成对设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及缺陷检测模型训练程序。
在图1所示的设备中,网络接口1004主要用于连接后台服务器,与后台服务器进行数据通信;用户接口1003主要用于连接客户端(程序员端),与客户端进行数据通信;而处理器1001可以用于调用存储器1005中存储的缺陷检测模型训练程序,并执行下述缺陷检测模型训练方法中的操作。
基于上述硬件结构,提出本发明缺陷检测模型训练方法实施例。
参照图2,图2为本发明缺陷检测模型训练方法第一实施例的流程示意图。所述缺陷检测模型训练方法包括;
步骤S10,在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
在本实施例中,本发明应用于终端设备。缺陷检测模型指的是预先已搭建好,还未经训练的基于深度学习的神经网络模型。缺陷图像训练数据指的是用于搭建好的缺陷检测模型进行训练的训练数据,这些训练数据中包括已经标注了实际缺陷类别以及缺陷位置的缺陷产品图像。模型参数具体可以包括权重、激活值等。混合精度训练指的是在模型训练过程中,将模型参数转化为小于或等于原有数据格式精度的数值后,再进行混合训练运算。在模型前向传播计算时,将模型参数转换成小于原有数据格式精度的数值进行保存。
具体地,终端中已预搭建了上述缺陷检测模型,并定义有损失函数和梯度更新的方式。终端在接收到用于模型训练的已标注的缺陷图像(即上述缺陷图像训练数据)时,将此缺陷图像训练数据输入该缺陷检测模型。模型基于神经网络架构对缺陷图像训练数据进行混合精度训练,作为一种实施方案,在混合精度训练中,终端可以在内存中用16位浮点类型的模型参数做储存和乘法从而加速计算,用32位浮点类型的模型参数做累加避免舍入误差。
步骤S20,获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
在本实施例中,缺陷预测值指的是缺陷检测模型对缺陷图像训练数据中所存在缺陷的预测值,具体可以包括缺陷类型信息、缺陷位置信息等。比例因子指的是大于1的放大倍数,用于放大缺陷检测模型在训练过程中所得到的损失值和模型梯度值。溢出情况具体包括两种,溢出和未溢出。
具体地,将缺陷图像训练数据输入缺陷检测模型后,模型会输出对应的缺陷预测值,而模型从获取缺陷图像到输出缺陷预测值的过程可为:先对特征图像进行特征提取,生成特征图,然后在特征图上的每个点生成预设框,每个框负责预测相关的类别信息和位置信息,最后将所有预设框通过非极大抑制算法,筛选出最终的结果,并输出该框的缺陷类型信息和位置信息,作为上述缺陷预测值。模型在输出缺陷预测值后,根据缺陷预测值和训练数据标注的缺陷真实值,以及比例因子得到扩大后的损失函数值和模型梯度值,并在计算得到模型梯度值之后,判断其是否溢出,其中,溢出包括上溢出和下溢出。
步骤S30,针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
步骤S40,基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
在本实施例中,动态调整策略指的是针对模型梯度值溢出或者未溢出所采取的调整策略,通常情况下,对于模型梯度值溢出和未溢出所采取的调整策略并不相同。模型梯度溢出时,可适当减少模型梯度的放大比例;模型梯度未溢出时,可复原模型梯度以去除比例因子对梯度的扩大效果。
具体地,模型根据实际的溢出情况在预设的动态调整策略中选择对应的策略来对模型梯度值进行调整,并使用调整后所得到的更新模型梯度值对模型进行迭代训练,直至终端检测到当前已满足预设的模型收敛条件,则可以停止模型迭代训练过程,判定此时模型已训练完成。
本实施例提供一种缺陷检测模型训练方法,通过混合精度训练,将模型参数进行精度降低调节后再进行前后向传播计算,能够在保证模型进度的情况下节省计算内存带宽以及显存占用,同时也节省了数据传输时间,减少模型训练时间;通过对模型梯度值进行溢出判断,采用比例因子对原始的梯度值进行扩大,并采用动态调整策略对模型梯度值进行调整,避免了因模型参数精度转换而可能导致的梯度消失最终致使模型训练失败的情况,最终达到了在保证模型有效训练的情况下大大提升模型训练迭代效率,提高模型验证速度,从而解决了现有基于深度学习的缺陷检测模型的参数量和计算量较大,在算力较小和显卡内存较小的设备上,容易出现训练效率低下或训练失败的技术问题。
进一步地,基于上述图2所示的第一实施例,提出本发明缺陷检测模型训练方法的第二实施例。在本实施例中,所述缺陷检测模型包括由多个卷积层组成的特征提取网络,步骤S10包括:
步骤S11,在所述缺陷检测模型的特征提取网络中,根据所述缺陷图像训练数据得到权重和激活值,以作为所述模型参数;
步骤S12,将所述权重和激活值由32位浮点型转换为16位浮点型,以按照16位浮点型的权重和激活值在各所述卷积层之间进行前向传播。
在本实施例中,缺陷检测模型中包括由多个卷积层组成的特征提取网络VGG,用于对缺陷图像进行特征提取,生成特征图。
具体地,输入模型的缺陷图像训练数据在特征提取网络中,将模型参数(权重和激活值)转化为16位或32位浮点类型数,再进行混合训练运算。在模型前向传播进行计算时,将模型参数转换为16位浮点类型进行计算。在前向传播至最后一卷积层,特征提取网络会完成本轮的特征提取,输出缺陷图像对应的特征图,以将此特征图传递到模型中的下一部分。
本实施例通过在神经网络进行前向传播时将权重和激活值由原先的float32型转为float16型进行存储(在反向传播是将梯度也由float32型转为float16型进行计算)。整个模型的参数量为138M,以float32型存储模型大小为526M,使用float16型存储后模型大小缩减为268M,从而极大节省了模型的内存占用。
进一步地,所述缺陷检测模型还包括单阶段目标检测器,步骤S20包括:
步骤S21,获取所述特征提取网络基于所述训练缺陷图像生成的特征图像,并根据所述单阶段目标检测器得到所述特征图像对应的缺陷预测值;
步骤S22,根据所述缺陷预测值和所述缺陷图像训练数据中的缺陷真实值得到原始损失值,并结合所述比例因子扩大所述原始损失值得到损失值;
步骤S23,根据所述损失值得到所述模型梯度值,并判断所述模型梯度值是否溢出。
在本实施例中,由于16位浮点类型数在
Figure 621908DEST_PATH_IMAGE001
~65504之间,即16位浮点类型数所表达的数值范围比32位浮点类型数范围窄。这样使得从在用16位浮点类型数表示权值、梯度和激活值时,高于65504的数值因为溢出变为无穷大,低于
Figure 79434DEST_PATH_IMAGE001
的数值因为下溢变为0。在反向传播时,梯度可能因为下溢而变为0,导致梯度消失,模型难以训练至收敛。所以在训练过程中,我们将模型的损失乘以一个足够大的比例因子(例如取值为1024),来放大梯度。计算得到最终梯度后,便可得到正确值。
缺陷检测模型除了特征提取网络VGG外,还包括单阶段目标检测器SSD。VGG在生成特征图后,特征图作为SSD的输入,在SSD中经过生成框图、缺陷信息预测、通过非极大抑制算法对预测值进行筛选等步骤,最终将筛选出的预测值作为上述缺陷预测值。
作为一具体实施例,如图3所示。图3中整个缺陷检测模型主要由两部分组成:特征提取网络VGG16和SSD检测器。VGG16主要由5个卷积层组成,SSD主要由非极大性抑制等后处理组成的解码器组成,在神经网络进行前向传播时将权重和激活值由原先的float32型转为float16型进行存储,并在VGG16输出特征图后将其输入SSD,SSD最后输出缺陷类别和缺陷位置坐标作为上述缺陷预测值。
本实施例能够在保证模型精度的情况下,减少缺陷检测模型的内存占用和计算量,加速缺陷检测模型训练速度。可以在较小算力和显存的显卡上进行缺陷检测模型的训练,在较大算力的显卡上提升缺陷检测模型的训练速度。
进一步地,所述动态调整策略包括第一调整策略,步骤S30包括:
步骤S311,若所述模型梯度值溢出,则按照所述第一调整策略缩小所述比例因子,并基于缩小后的比例因子和所述原始损失值得到新的模型梯度值,作为所述更新模型梯度值。
在本实施例中,第一调整策略指的是与模型梯度值溢出对应的调整策略。比例因子的缩小幅度可根据实际情况灵活设定。
具体地,为了避免由float16型转换过程中可能出现的数值溢出导致梯度消失,最终使得模型训练失败,因此,在终端检测到本轮训练所得到的模型梯度值溢出(上溢出或下溢出)时,需要适当缩小比例因子,例如将比例因子缩小1/5,然后再根据缩小后的比例因子重新计算损失值和模型梯度值,并将重新计算得到的模型梯度值作为上述更新模型梯度值。
进一步地,所述动态调整策略包括第二调整策略,步骤S30包括:
步骤S321,若所述模型梯度值未溢出,则按照所述第二调整策略将所述模型梯度值按照所述比例因子进行缩小还原,将还原后的模型梯度值作为所述更新模型梯度值。
在本实施例中,在终端检测到本轮训练所得到的模型梯度值未溢出时,则可将此时因比例因子而放大的模型梯度值进行复原,复原为不受比例因子方法的原始值,并将此原始值作为上述更新模型梯度值。
作为一具体实施例,如图4所示。
以比例因子取值为1024为例。输入训练数据经由缺陷检测模型后会输出一个缺陷类别分类的值和缺陷位置的坐标值。用这两个输出与真实缺陷的类别和真实缺陷的定位坐标计算则可以得到损失值,得到损失值以后乘以一个比例因子1024,此时损失值和梯度值为:
损失值=原损失值*1024;
梯度值=原梯度值*1024。
在得到梯度值后,做溢出判断,如果梯度值无溢出,则将梯度复原为原始值后,进行网络更新;若梯度值为Inf(上溢出)或者NaN(下溢出),则损失值的比例因子缩小1/5,重新计算,如此反复迭代更新模型。
本实施例通过使用混合精度训练和动态损失值扩张,在保证模型精度的情况下节省计算内存带宽,减少约一半的显存占用,使得在算力较小的边缘设备上也可进行缺陷检测模型训练,在算力较大的设备上可训练更大的模型。由于模型减小,cpu和gpu间的数据传输量也减少,节省了数据传输时间,减少训练时间,实际情况下,GPU的利用率可由45%提高到85%。原先训练速度可由单精度的3840张每秒提高到混合精度5220张。大大提升训练迭代效率,提高模型验证速度,使得缺陷检测模型在应对多种场景检测需求时可以做到模型快速迭代和结果验证。
进一步地,基于上述第一实施例,提出本发明缺陷检测模型训练方法的第三实施例。在本实施例中,步骤S40包括:
步骤S411,基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,并判断模型迭代训练过程中是否连续预设次数未出现模型梯度值溢出;
步骤S412,若是,则增大所述比例因子,并根据增大后的比例因子得到目标模型梯度值;
步骤S413,基于所述目标模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
在本实施例中,连续预设次数的取值可根据实际情况灵活设置。
具体地,在进行模型迭代训练的过程中,由于每一轮训练都会得到对应的模型梯度值,为了加快模型收敛,可以根据多轮溢出情况的观察动态调整比例因子的大小,例如可判断连续10次的迭代过程中模型梯度值是否未溢出,若检测到连续10次的迭代中所得的模型梯度值均为溢出,则将比例因子增大1/5,用增大后的比例因子计算新的模型梯度值,然后继续迭代,直至满足模型收敛条件。
进一步地,步骤S40包括:
步骤S421,根据所述更新模型梯度值回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至当前迭代轮次对应的模型梯度值小于预设梯度阈值,或是当前迭代轮次达到预设轮次阈值时,判定满足所述模型收敛条件。
在本实施例中,在得到更新模型梯度值后,在缺陷检测模型中将更新模型梯度值以float16型进行后向传播,并基于更新模型梯度值更新模型参数,在新一轮训练中将更新后的模型参数进行前向传播,如此反复迭代更新模型。具体的模型收敛条件可为,模型梯度值小于某一梯度阈值,或是当前的迭代次数达到某一次数阈值,其中,梯度阈值和次数阈值都可根据实际情况灵活设置。当终端检测到当前轮次得到的模型梯度值小于某一梯度阈值,或是当前迭代次数达到某一次数阈值时,即可判定此时满足模型收敛条件,可以停止迭代训练过程,模型训练完成。
如图5所示,本发明还提供一种缺陷检测模型训练装置,所述缺陷检测模型训练装置包括:
混合精度训练模块10,用于在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
溢出情况确定模块20,用于获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
模型梯度调整模块30,用于针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
模型迭代训练模块40,用于基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
可选地,所述缺陷检测模型包括由多个卷积层组成的特征提取网络,
所述混合精度训练模块10包括:
模型参数获取单元,用于在所述缺陷检测模型的特征提取网络中,根据所述缺陷图像训练数据得到权重和激活值,以作为所述模型参数;
精度下降转换单元,用于将所述权重和激活值由32位浮点型转换为16位浮点型,以按照16位浮点型的权重和激活值在各所述卷积层之间进行前向传播。
可选地,所述缺陷检测模型还包括单阶段目标检测器,
所述溢出情况确定模块20包括:
缺陷预测获取单元,用于获取所述特征提取网络基于所述训练缺陷图像生成的特征图像,并根据所述单阶段目标检测器得到所述特征图像对应的缺陷预测值;
损失值获取单元,用于根据所述缺陷预测值和所述缺陷图像训练数据中的缺陷真实值得到原始损失值,并结合所述比例因子扩大所述原始损失值得到损失值;
梯度溢出判断单元,用于根据所述损失值得到所述模型梯度值,并判断所述模型梯度值是否溢出。
可选地,所述动态调整策略包括第一调整策略,
所述模型梯度调整模块30包括:
第一策略调整单元,用于若所述模型梯度值溢出,则按照所述第一调整策略缩小所述比例因子,并基于缩小后的比例因子和所述原始损失值得到新的模型梯度值,作为所述更新模型梯度值。
可选地,所述动态调整策略包括第二调整策略,
所述模型梯度调整模块30包括:
第二策略调整单元,用于若所述模型梯度值未溢出,则按照所述第二调整策略将所述模型梯度值按照所述比例因子进行缩小还原,将还原后的模型梯度值作为所述更新模型梯度值。
可选地,所述模型迭代训练模块40包括:
连续溢出判断单元,用于基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,并判断模型迭代训练过程中是否连续预设次数未出现模型梯度值溢出;
比例因子增大单元,用于若是,则增大所述比例因子,并根据增大后的比例因子得到目标模型梯度值;
目标返回执行单元,用于基于所述目标模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
可选地,所述模型迭代训练模块40包括:
收敛条件判定单元,用于根据所述更新模型梯度值回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至当前迭代轮次对应的模型梯度值小于预设梯度阈值,或是当前迭代轮次达到预设轮次阈值时,判定满足所述模型收敛条件。
本发明还提供一种缺陷检测模型训练设备。
所述缺陷检测模型训练设备包括处理器、存储器及存储在所述存储器上并可在所述处理器上运行的缺陷检测模型训练程序,其中所述缺陷检测模型训练程序被所述处理器执行时,实现如上所述的缺陷检测模型训练方法的步骤。
其中,所述缺陷检测模型训练程序被执行时所实现的方法可参照本发明缺陷检测模型训练方法的各个实施例,此处不再赘述。
本发明还提供一种计算机可读存储介质。
本发明计算机可读存储介质上存储有缺陷检测模型训练程序,所述缺陷检测模型训练程序被处理器执行时实现如上所述的缺陷检测模型训练方法的步骤。
其中,所述缺陷检测模型训练程序被执行时所实现的方法可参照本发明缺陷检测模型训练方法各个实施例,此处不再赘述。
本发明还提供一种计算机可读存储介质,包括计算机程序,所述计算机程序被处理器执行时实现如上述的缺陷检测模型训练方法的步骤。
其中,所述计算机程序被执行时所实现的方法可参照本发明缺陷检测模型训练方法各个实施例,此处不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者系统不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者系统所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者系统中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件系统的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在如上所述的一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。

Claims (10)

1.一种缺陷检测模型训练方法,其特征在于,所述缺陷检测模型训练方法包括:
在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
2.如权利要求1所述的缺陷检测模型训练方法,其特征在于,所述缺陷检测模型包括由多个卷积层组成的特征提取网络,
所述在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练的步骤包括:
在所述缺陷检测模型的特征提取网络中,根据所述缺陷图像训练数据得到权重和激活值,以作为所述模型参数;
将所述权重和激活值由32位浮点型转换为16位浮点型,以按照16位浮点型的权重和激活值在各所述卷积层之间进行前向传播。
3.如权利要求2所述的缺陷检测模型训练方法,其特征在于,所述缺陷检测模型还包括单阶段目标检测器,
所述获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况的步骤包括:
获取所述特征提取网络基于所述训练缺陷图像生成的特征图像,并根据所述单阶段目标检测器得到所述特征图像对应的缺陷预测值;
根据所述缺陷预测值和所述缺陷图像训练数据中的缺陷真实值得到原始损失值,并结合所述比例因子扩大所述原始损失值得到损失值;
根据所述损失值得到所述模型梯度值,并判断所述模型梯度值是否溢出。
4.如权利要求3所述的缺陷检测模型训练方法,其特征在于,所述动态调整策略包括第一调整策略,
所述针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值的步骤包括:
若所述模型梯度值溢出,则按照所述第一调整策略缩小所述比例因子,并基于缩小后的比例因子和所述原始损失值得到新的模型梯度值,作为所述更新模型梯度值。
5.如权利要求3所述的缺陷检测模型训练方法,其特征在于,所述动态调整策略包括第二调整策略,
所述针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值的步骤包括:
若所述模型梯度值未溢出,则按照所述第二调整策略将所述模型梯度值按照所述比例因子进行缩小还原,将还原后的模型梯度值作为所述更新模型梯度值。
6.如权利要求1所述的缺陷检测模型训练方法,其特征在于,所述基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件的步骤包括:
基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,并判断模型迭代训练过程中是否连续预设次数未出现模型梯度值溢出;
若是,则增大所述比例因子,并根据增大后的比例因子得到目标模型梯度值;
基于所述目标模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
7.如权利要求1-6任一项所述的缺陷检测模型训练方法,其特征在于,所述基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件的步骤包括:
根据所述更新模型梯度值回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至当前迭代轮次对应的模型梯度值小于预设梯度阈值,或是当前迭代轮次达到预设轮次阈值时,判定满足所述模型收敛条件。
8.一种缺陷检测模型训练装置,其特征在于,所述缺陷检测模型训练装置包括:
混合精度训练模块,用于在预搭建的基于深度学习的缺陷检测模型中,针对输入的缺陷图像训练数据进行混合精度训练,其中,所述混合精度训练中模型参数经过数据格式精度降低转换后再参与网络传播;
溢出情况确定模块,用于获取所述缺陷检测模型输出的缺陷预测值,根据预设的比例因子和所述缺陷预测值得到模型梯度值,并确定所述模型梯度值的溢出情况;
模型梯度调整模块,用于针对所述溢出情况,按照预设的动态调整策略对所述模型梯度值进行调整,得到更新模型梯度值;
模型迭代训练模块,用于基于所述更新模型梯度值返回执行针对输入的缺陷图像训练数据进行混合精度训练的步骤,直至满足预设的模型收敛条件。
9.一种缺陷检测模型训练设备,其特征在于,所述缺陷检测模型训练设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的缺陷检测模型训练程序,所述缺陷检测模型训练程序被所述处理器执行时实现如权利要求1至7中任一项所述的缺陷检测模型训练方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质包括计算机程序,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述的缺陷检测模型训练方法的步骤。
CN202111218168.6A 2021-10-20 2021-10-20 缺陷检测模型训练方法、装置、设备及可读存储介质 Pending CN113888524A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111218168.6A CN113888524A (zh) 2021-10-20 2021-10-20 缺陷检测模型训练方法、装置、设备及可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111218168.6A CN113888524A (zh) 2021-10-20 2021-10-20 缺陷检测模型训练方法、装置、设备及可读存储介质

Publications (1)

Publication Number Publication Date
CN113888524A true CN113888524A (zh) 2022-01-04

Family

ID=79003657

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111218168.6A Pending CN113888524A (zh) 2021-10-20 2021-10-20 缺陷检测模型训练方法、装置、设备及可读存储介质

Country Status (1)

Country Link
CN (1) CN113888524A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024012476A1 (zh) * 2022-07-15 2024-01-18 华为技术有限公司 一种模型训练方法及相关设备
CN117786415A (zh) * 2024-02-27 2024-03-29 常州微亿智造科技有限公司 缺陷检测方法和系统

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024012476A1 (zh) * 2022-07-15 2024-01-18 华为技术有限公司 一种模型训练方法及相关设备
CN117786415A (zh) * 2024-02-27 2024-03-29 常州微亿智造科技有限公司 缺陷检测方法和系统

Similar Documents

Publication Publication Date Title
US11756170B2 (en) Method and apparatus for correcting distorted document image
CN113888524A (zh) 缺陷检测模型训练方法、装置、设备及可读存储介质
US20210295473A1 (en) Method for image restoration, electronic device, and storage medium
CN101763627B (zh) 一种高斯模糊的实现方法和装置
CN108074211B (zh) 一种图像处理装置及方法
CN111489322B (zh) 给静态图片加天空滤镜的方法及装置
CN112200297A (zh) 神经网络优化方法、装置及处理器
CN107909537B (zh) 一种基于卷积神经网络的图像处理方法及移动终端
CN113222813B (zh) 图像超分辨率重建方法、装置、电子设备及存储介质
US11468600B2 (en) Information processing apparatus, information processing method, non-transitory computer-readable storage medium
CN107808394B (zh) 一种基于卷积神经网络的图像处理方法及移动终端
CN109447911B (zh) 图像复原的方法、装置、存储介质和终端设备
CN107977923B (zh) 图像处理方法、装置、电子设备及计算机可读存储介质
CN116385369A (zh) 深度图像质量评价方法、装置、电子设备及存储介质
CN112561050B (zh) 一种神经网络模型训练方法及装置
JP2021144428A (ja) データ処理装置、データ処理方法
CN107871162B (zh) 一种基于卷积神经网络的图像处理方法及移动终端
CN113689341A (zh) 图像处理方法及图像处理模型的训练方法
CN114708250B (zh) 一种图像处理方法、装置及存储介质
US11048971B1 (en) Method for training image generation model and computer device
CN111340215B (zh) 一种网络模型推理加速方法、装置、存储介质和智能设备
JP7114321B2 (ja) データ処理装置及びその方法
CN115410202A (zh) 文本区域检测方法、装置、存储介质以及计算机设备
JP2024004400A (ja) データ処理装置及びその方法
CN117591244A (zh) 基于机器学习平台的模型构建方法、装置及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination