CN113344089A - 模型训练方法、装置及电子设备 - Google Patents
模型训练方法、装置及电子设备 Download PDFInfo
- Publication number
- CN113344089A CN113344089A CN202110670749.7A CN202110670749A CN113344089A CN 113344089 A CN113344089 A CN 113344089A CN 202110670749 A CN202110670749 A CN 202110670749A CN 113344089 A CN113344089 A CN 113344089A
- Authority
- CN
- China
- Prior art keywords
- model
- neural network
- network
- training
- training sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了模型训练方法、装置及电子设备,涉及计算机视觉、深度学习等人工智能技术领域。具体实现方案为:获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;对所述第一神经网络模型进行训练;基于训练好的第一神经网络模型,对所述任务模型进行训练。根据本申请的技术,解决了模型训练技术中存在的模型训练效果比较差的问题,提高了模型训练的效果。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及计算机视觉、深度学习技术领域,具体涉及一种模型训练方法、装置及电子设备。
背景技术
随着人工智能的高速发展,基于深度学习的神经网络模型得到了广泛的应用,比如,可以采用神经网络模型进行车辆检测。为了基于神经网络模型实现具体任务如车辆检测任务,需要对神经网络模型进行训练,以使神经网络模型可以学习到图像特征,并基于图像特征进行相应任务的实现。
目前,神经网络模型的训练方式通常是初始化模型参数,在训练过程中更新初始化的模型参数,直至训练完成。
发明内容
本公开提供了一种模型训练方法、装置及电子设备。
根据本公开的第一方面,提供了一种模型训练方法,包括:
获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;
对所述第一神经网络模型进行训练;
基于训练好的第一神经网络模型,对所述任务模型进行训练。
根据本公开的第二方面,提供了一种模型训练装置,包括:
获取模块,用于获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;
第一训练模块,用于对所述第一神经网络模型进行训练;
第二训练模块,用于基于训练好的第一神经网络模型,对所述任务模型进行训练。
根据本公开的第三方面,提供了一种电子设备,包括:
至少一个处理器;以及
与至少一个处理器通信连接的存储器;其中,
存储器存储有可被至少一个处理器执行的指令,该指令被至少一个处理器执行,以使至少一个处理器能够执行第一方面中的任一项方法。
根据本公开的第四方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,该计算机指令用于使计算机执行第一方面中的任一项方法。
根据本公开的第五方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现第一方面中的任一项方法。
根据本申请的技术解决了模型训练技术中存在的模型训练效果比较差的问题,提高了模型训练的效果。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本申请的限定。其中:
图1是根据本申请第一实施例的模型训练方法的流程示意图;
图2是根据本申请第二实施例的模型训练装置的结构示意图;
图3示出了可以用来实施本公开的实施例的示例电子设备300的示意性框图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
第一实施例
如图1所示,本申请提供一种模型训练方法,包括如下步骤:
步骤S101:获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别。
本实施例中,模型训练方法涉及人工智能技术,具体涉及计算机视觉、深度学习技术领域,其可以广泛应用于目标检测、语义分割等图像识别场景中。该方法可以由本申请实施例的模型训练装置执行。而模型训练装置可以配置在任意电子设备中,以执行本申请实施例的模型训练方法,该电子设备可以服务器,也可以为终端,这里不做具体限定。
所述第一神经网络模型和第二神经网络模型为用于进行特征提取的模型,第二神经网络模型可以为预先存储的神经网络模型,也可以为从搜索空间的众多神经网络模型中搜索得到的神经网络模型,还可以为其他电子设备发送的神经网络模型,这里不进行具体限定。其中,搜索空间可以指定神经网络模型的功能和大致结构等。
所述第二神经网络模型可以为与任务模型的特征提取网络的结构最为相似的神经网络模型,在一可选实施方式中,可以根据实际的任务模型如车辆检测模型中的特征提取网络的结构,从众多神经网络模型中搜索得到与任务模型的特征提取模型最为相似的神经网络模型作为第二神经网络模型。
所述第二神经网络模型可以为有监督的模型,即该模型的训练需要有图像标签数据的参与,也可以为自监督即无监督的模型,即该模型的训练可以从大规模的训练样本数据中挖掘自身的监督信息,通过监督信息对模型进行训练,这里不进行具体限定。
所述第二神经网络模型可以为残差ResNet系列的神经网络模型,如ResNet32或ResNet50_vd等,其网络骨架可以为backbone,所述第二神经网络模型也可以为其他结构的神经网络模型,这里不进行具体限定。
所述第二神经网络模型以残差ResNet系列的神经网络模型为例,所述第二神经网络模型可以为backbone架构的自监督学习模型,如ResNet50_vd MoCov2,该第二神经网络模型可以包括两个网络分支,通过这两个网络分支进行自我监督和学习,以实现模型的训练。
所述第一神经网络模型可以基于所述第二神经网络模型进行剪枝得到,而剪枝实质是将第二神经网络模型中冗余部分进行剔除,以对齐任务模型的特征提取网络,即使剪枝得到的神经网络模型与任务模型的特征提取网络的结构匹配。这样,所述任务模型可以复用该第一神经网络模型的模型参数,即可以将第一神经网络模型的模型参数迁移至任务模型中。
所述任务模型可以指的是实现具体任务的模型,如实现车辆检测任务、图像分割任务或人脸识别任务等,其用于进行图像识别。在进行图像识别过程中,需要基于所提取的图像特征,也就是说,该任务模型可以包括多个部分,特征提取网络即为这多个部分中其中之一,该特征提取网络所提取的特征通过图像识别可以用于实现具体任务。
其中,所述第一神经网络模型与任务模型的特征提取网络的结构匹配可以指的是第一神经网络模型的整个网络与特征提取网络的结构匹配,也可以指的是第一神经网络模型中的某个网络分支与特征提取网络的结构匹配,这里不进行具体限定。
对第二神经网络模型剪枝的具体部分可以结合任务模型的特征提取网络的结构、任务模型实现具体任务的效果和时间来综合评判,可以剪枝第二神经网络模型的一个模块,减少神经网络模型的深度,也可以剪枝第二神经网络模型的模块中的一些卷积层,还可以剪枝第二神经网络模型的某一个或某几个模块中卷积层的通道数,这里不进行具体限定。
以第二神经网络模型ResNet50_vd MoCov2的剪枝为例,ResNet50_vd MoCov2可以包括两个网络分支,而ResNet50_vd为其中的一个网络分支,其是由卷积层从浅至深堆叠而成,且按照网络深度从浅至深,这些卷积层的处理过程可以被分为5个阶段,分别为stage1至stage5,且网络深度越深,所提取的图像特征越深。
在对ResNet50_vd进行剪枝时,可以对网络深度比较深的网络模块进行剪枝,也可以对网络深度比较浅的网络模块进行剪枝,或者对网络深度比较深的网络模块和网络深度比较浅的网络模块同时剪枝。
在一可选实施方式中,可以对网络深度比较深的网络模块进行剪枝,具体可以对模块的通道数进行缩减,比如,可以减半stage4的每个瓶颈模块bottleneck中最后一个卷积层的通道数,同时减半stage5的模块中所有卷积层的输出通道数。这样,在保证任务模型实现具体任务如车辆检测任务的效果的同时,还可以提高任务处理的速度。
另外,所述第一神经网络模型的获取方式可以有多种,比如,可以获取第二神经网络模型,并对所述第二神经网络模型进行剪枝,得到第一神经网络模型,也可以接收其他电子设备发送的第一神经网络模型,所述第一神经网络模型可以为其他电子设备基于第二神经网络模型剪枝得到。
步骤S102:对所述第一神经网络模型进行训练。
该步骤中,所述第一神经网络模型可以作为任务模型的预训练模型,可以进行预先训练,并将训练好的第一神经网络模型迁移至任务模型中,以提高任务模型的训练效果,包括可以减少训练时间以及提高模型参数的训练准确性。
可以根据所述第一神经网络模型,按照相应的方式对所述第一神经网络模型进行训练,比如,在所述第一神经网络模型为有监督的模型的情况下,可以基于训练样本数据和图像标签数据对所述第一神经网络模型进行训练,比对图像特征和图像标签的差异,以基于差异信息更新第一神经网络模型的模型参数。
又比如,在所述第一神经网络模型为自监督的模型的情况下,可以基于训练样本数据对所述第一神经网络模型进行训练,通过第一神经网络模型基于自监督对比学习,从训练样本数据中挖掘自身的监督信息,基于监督信息更新第一神经网络模型的模型参数。
最终,在差异信息或监督信息达到收敛的情况下,则可以说明第一神经网络模型训练完成。
步骤S103:基于训练好的第一神经网络模型,对所述任务模型进行训练。
该步骤中,可以将训练好的第一神经网络模型作为所述任务模型的特征提取网络,将其整个网络或某个网络分支迁移至所述任务模型,也可以将训练好的第一神经网络模型中的全部或部分模型参数迁移至任务模型,即将第一神经网络模型中的模型参数作为任务模型的特征提取网络的初始参数。
迁移之后,对所述任务模型进行继续训练,最终训练一个可以执行具体任务的模型,比如,车辆检测模型,其可以对待检测图像进行车辆检测。
本实施例中,通过获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;对所述第一神经网络模型进行训练;基于训练好的第一神经网络模型,对所述任务模型进行训练。如此,通过剪枝和预训练模型,可以大大减少模型训练的时间,以及提高模型参数的训练准确性,从而可以提高模型训练的效果。
可选的,所述步骤102具体包括:
获取第一训练样本图像;
将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,所述第一操作用于基于所述第一神经网络模型对所述第一训练样本图像进行自监督学习处理;
基于所述监督信息更新所述第一神经网络模型的模型参数。
本实施方式中,所述第一训练样本图像可以为无监督的图像,无监督的图像指的是该图像无对应标签,其可以为所有数据域中的图像,即第一神经网络模型的训练可以适用于所有数据域,而无需限定为任务模型所规定的数据域中图像。
其中,一个数据域可以指的是一种类型的图像数据,比如,针对车辆检测任务,车辆检测模型训练或实际运行时所规定的数据域通常是包括车辆图像信息的图像数据,又比如,针对人脸识别任务,人脸识别模型训练或实际运行时所规定的数据域通常是包括人脸图像信息的图像数据。
也就是说,在训练第一神经网络模型时,无需限定第一训练样本图像中的图像内容,其图像内容可以包括人脸、车辆或其他对象等,如此可以大大提高模型训练的灵活性。
另外,所述第一训练样本图像可以为大规模数据集如ImageNet中的图像,即第一神经网络模型可以在ImageNet数据上完成模型训练。
可以将ImageNet数据中的图像分别输入至所述第一神经网络模型执行第一操作,得到图像的监督信息,所述第一操作可以用于基于所述第一神经网络模型对图像进行自监督学习处理,通过自监督学习处理,可以得到图像自身的监督信息,该监督信息可以表征所提取的图像特征是否准确。
其中,可以采用现有的或新的自监督学习处理的方式,通过第一神经网络模型来挖掘图像的监督信息。在一可选实施方式中,第一神经网络模型可以包括至少两个网络分支,这至少两个网络分支的结构可以相同或相似,可以基于这至少两个网络分支分别对第一训练样本图像和第一训练样本图像进行数据增强的图像进行特征提取,比对图像特征,以确定这至少两个网络分支提取的图像特征是否相同或相似,最终得到第一训练样本图像的监督信息。该监督信息表征这至少两个网络分支提取的图像特征差异,在图像特征差异比较小时,可以说明所提取的图像特征比较准确。
可以基于所述监督信息更新所述第一神经网络模型的模型参数,在所述监督信息表征所提取的图像特征不准确的情况下,可以基于该监督信息来更新所述第一神经网络模型的模型参数,之后输入其他的图像继续进行训练,直至基于第一神经网络模型所得到的监督信息表征所提取的图像特征比较准确。
在一可选实施方式中,第一神经网络模型可以包括两个网络分支,分别可以称之为k分支和q分支,q分支的模型参数可以梯度回传更新,而k分支的模型参数可以根据q分支的模型参数进行更新。具体的,可以基于监督信息,通过梯度回传的方式对q分支的模型参数进行更新,之后,可以对q分支的模型参数进行动量加权,基于动量加权的结果对k分支的模型参数进行更新。
本实施方式中,通过获取第一训练样本图像;将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,所述第一操作用于基于所述第一神经网络模型对所述第一训练样本图像进行自监督学习处理;基于所述监督信息更新所述第一神经网络模型的模型参数。如此,可以通过自监督学习从大规模的无监督数据中挖掘自身的监督信息,基于监督信息对第一神经网络模型进行训练,从而将第一神经网络模型作为预训练模型时,可以学习到对下游任务有价值的表征,进而可以提高任务处理的效果,如下游任务为车辆检测任务时,可以提高车辆检测的准确率。
可选的,所述第一神经网络模型的网络分支包括第一网络分支和第二网络分支,所述将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,包括:
对所述第一训练样本图像进行数据增强,得到第一图像和第二图像;
基于所述第一网络分支对所述第一图像进行特征提取,得到第一特征;
基于所述第二网络分支对所述第二图像进行特征提取,得到第二特征;
对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
本实施方式中,所述第一网络分支可以称之为q分支,所述第二网络分支可以称之为k分支,这两个网络分支的结构可以相同或相似,以分别对图像进行特征提取。
可以采用自监督对比学习的方式,挖掘图像自身的监督信息,具体的,可以首先对所述第一训练样本图像进行数据增强,得到第一图像和第二图像,所述第一图像和第二图像可以为图像内容相似的两张图像。比如,第一图像和第二图像中均可以包括图像内容“猫”,只是其“猫”的位置可以有所不同。
所述第一图像可以为第一训练样本图像,所述第二图像可以为基于第一训练样本图像进行数据增强得到的图像,所述第一图像和第二图像也可以为基于第一训练样本图像分别进行数据增强得到的图像。
可以采用现有的或新的数据增强的方式对所述第一训练样本图像进行数据增强,这里不进行具体阐述。
之后,可以基于所述第一网络分支对所述第一图像进行特征提取,得到第一特征,基于第二网络分支对所述第二图像进行特征提取,得到第二特征,对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
其中,所述第一神经网络模型还可以包括比对模块,该比对模块可以为判别器,可以采用该比对模块对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
本实施方式中,通过自监督对比学习的方式,挖掘图像自身的监督信息,如此可以非常简单地挖掘图像自身的监督信息,实现对第一神经网络模型的训练。
可选的,所述第一神经网络模型的模型参数包括所述第一网络分支的第一模型参数,所述步骤S103具体包括:
获取第二训练样本图像;
将所述第二训练样本图像输入至所述任务模型执行第二操作,得到所述第二训练样本图像的识别结果;
基于所述识别结果,更新所述任务模型的第二模型参数;
其中,所述第二操作包括:将所述第一模型参数作为所述任务模型的特征提取网络的参数,对所述第二训练样本图像进行特征提取,得到第三特征;基于所述第三特征进行图像识别,得到所述识别结果。
本实施方式中,可以抽取第一网络分支的第一模型参数,将其迁移至任务模型上,作为预训练模型参与任务模型的训练。
具体的,可以获取第二训练样本图像,所述第二训练样本图像的数据域需要与任务模型匹配,比如,当任务模型为车辆检测模型时,其训练数据通常是需要包括车辆图像信息的图像数据。
其获取方式可以有多种,比如,可以将预先存储的图像作为第二训练样本图像,可以接收其他电子设备发送的第二训练样本图像。
可以将所述第二训练样本图像输入至所述任务模型执行第二操作,得到所述第二训练样本图像的识别结果。其中,该步骤中的任务模型为基于预训练模型迁移模型参数后的任务模型,也就是说,在执行第二操作过程中,可以将所述第一模型参数作为所述任务模型的特征提取网络的参数,对所述第二训练样本图像进行特征提取,得到第三特征,并基于第三特征进行图像识别,得到识别结果。
之后,可以基于识别结果更新所述任务模型的第二模型参数,在一可选实施方式中,可以确定识别结果和图像标签的差异信息,基于差异信息更新所述任务模型的第二模型参数。
其中,所述第二模型参数可以包括第一模型参数,也就是说,任务模型除了特征提取网络的参数需要更新之外,还可能需要更新其他网络的参数。
本实施方式中,通过将所述第一模型参数作为所述任务模型的特征提取网络的参数,对第二训练样本图像进行特征提取,得到第三特征;基于所述第三特征进行图像识别,得到所述识别结果;基于所述识别结果,更新所述任务模型的第二模型参数。如此,可以通过迁移模型参数实现将预训练模型迁移至任务模型上,从而可以降低预训练模型的迁移难度。
可选的,所述第二神经网络模型的网络分支包括第三网络分支,所述步骤S101具体包括:
将所述第三网络分支中目标网络的卷积层的通道数进行缩减,得到所述第一网络分支;
其中,所述目标网络为所述第三网络分支中,网络深度大于其他网络的网络。
本实施方式中,可以剪枝第二神经网络模型的第三网络分支中某一个或某几个模块中卷积层的通道数,其中,由于深层网络所提取的特征在图像识别任务时作用相对比较小,因此,剪枝的模块可以位于第三网络分支中的深层网络。
以第二神经网络模型ResNet50_vd MoCov2的剪枝为例,ResNet50_vd MoCov2可以包括两个网络分支,而ResNet50_vd为其中的一个网络分支,可以减半stage4的每个瓶颈模块bottleneck中最后一个卷积层的通道数,同时减半stage5的模块中所有卷积层的输出通道数,最终可以得到第一网络分支,从而可以得到第一神经网络模型。其中,stage4和stage5的网络深度均大于其他阶段的网络深度。
这样,通过剪枝第三网络分支中深层网络的卷积层的通道,缩减深层网络的卷积层的通道数,可以在保证任务模型实现具体任务如车辆检测任务的效果的同时,还可以提高任务处理的速度。
第二实施例
如图2所示,本申请提供一种模型训练装置200,包括:
获取模块201,用于获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;
第一训练模块202,用于对所述第一神经网络模型进行训练;
第二训练模块203,用于基于训练好的第一神经网络模型,对所述任务模型进行训练。
可选的,所述第一训练模块203包括:
第一获取单元,用于获取第一训练样本图像;
第一执行单元,用于将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,所述第一操作用于基于所述第一神经网络模型对所述第一训练样本图像进行自监督学习处理;
第一更新单元,用于基于所述监督信息更新所述第一神经网络模型的模型参数。
可选的,所述第一神经网络模型的网络分支包括第一网络分支和第二网络分支,所述第一执行单元,具体用于:
对所述第一训练样本图像进行数据增强,得到第一图像和第二图像;
基于所述第一网络分支对所述第一图像进行特征提取,得到第一特征;
基于所述第二网络分支对所述第二图像进行特征提取,得到第二特征;
对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
可选的,所述第一神经网络模型的模型参数包括所述第一网络分支的第一模型参数,所述第二训练模块203包括:
第二获取单元,用于获取第二训练样本图像;
第二执行单元,用于将所述第二训练样本图像输入至所述任务模型执行第二操作,得到所述第二训练样本图像的识别结果;
第二更新单元,用于基于所述识别结果,更新所述任务模型的第二模型参数;
其中,所述第二操作包括:将所述第一模型参数作为所述任务模型的特征提取网络的参数,对所述第二训练样本图像进行特征提取,得到第三特征;基于所述第三特征进行图像识别,得到所述识别结果。
可选的,所述第二神经网络模型的网络分支包括第三网络分支,所述获取模块201,具体用于将所述第三网络分支中目标网络的卷积层的通道数进行缩减,得到所述第一网络分支;
其中,所述目标网络为所述第三网络分支中,网络深度大于其他网络的网络。
本申请提供的模型训练装置200能够实现上述模型训练方法实施例实现的各个过程,且能够达到相同的有益效果,为避免重复,这里不再赘述。
根据本申请的实施例,本申请还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图3示出了可以用来实施本公开的实施例的示例电子设备300的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图3所示,设备300包括计算单元301,其可以根据存储在只读存储器(ROM)302中的计算机程序或者从存储单元308加载到随机访问存储器(RAM)303中的计算机程序,来执行各种适当的动作和处理。在RAM303中,还可以存储设备300操作所需的各种程序和数据。计算单元301、ROM302以及RAM303通过总线304彼此相连。输入/输出(I/O)接口305也连接至总线304。
设备300中的多个部件连接至I/O接口305,包括:输入单元306,例如键盘、鼠标等;输出单元307,例如各种类型的显示器、扬声器等;存储单元308,例如磁盘、光盘等;以及通信单元309,例如网卡、调整解调器、无线通信收发机等。通信单元309允许设备300通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元301可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元301的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元301执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元308。在一些实施例中,计算机程序的部分或者全部可以经由ROM302和/或通信单元309而被载入和/或安装到设备300上。当计算机程序加载到RAM303并由计算单元301执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元301可以通过其他任何适当的方法(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编辑语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入、或者触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)、互联网和区块链网络。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务("Virtual Private Server",或简称"VPS")中,存在的管理难度大,业务扩展性弱的缺陷。服务器也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。
Claims (13)
1.一种模型训练方法,包括:
获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;
对所述第一神经网络模型进行训练;
基于训练好的第一神经网络模型,对所述任务模型进行训练。
2.根据权利要求1所述的方法,其中,所述对所述第一神经网络模型进行训练,包括:
获取第一训练样本图像;
将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,所述第一操作用于基于所述第一神经网络模型对所述第一训练样本图像进行自监督学习处理;
基于所述监督信息更新所述第一神经网络模型的模型参数。
3.根据权利要求2所述的方法,其中,所述第一神经网络模型的网络分支包括第一网络分支和第二网络分支,所述将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,包括:
对所述第一训练样本图像进行数据增强,得到第一图像和第二图像;
基于所述第一网络分支对所述第一图像进行特征提取,得到第一特征;
基于所述第二网络分支对所述第二图像进行特征提取,得到第二特征;
对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
4.根据权利要求3所述的方法,其中,所述第一神经网络模型的模型参数包括所述第一网络分支的第一模型参数,所述基于训练好的第一神经网络模型,对所述任务模型进行训练,包括:
获取第二训练样本图像;
将所述第二训练样本图像输入至所述任务模型执行第二操作,得到所述第二训练样本图像的识别结果;
基于所述识别结果,更新所述任务模型的第二模型参数;
其中,所述第二操作包括:将所述第一模型参数作为所述任务模型的特征提取网络的参数,对所述第二训练样本图像进行特征提取,得到第三特征;基于所述第三特征进行图像识别,得到所述识别结果。
5.根据权利要求3所述的方法,其中,所述第二神经网络模型的网络分支包括第三网络分支,所述获取第一神经网络模型,包括:
将所述第三网络分支中目标网络的卷积层的通道数进行缩减,得到所述第一网络分支;
其中,所述目标网络为所述第三网络分支中,网络深度大于其他网络的网络。
6.一种模型训练装置,包括:
获取模块,用于获取第一神经网络模型,所述第一神经网络模型基于第二神经网络模型进行剪枝得到,所述第一神经网络模型与任务模型的特征提取网络的结构匹配,所述任务模型用于进行图像识别;
第一训练模块,用于对所述第一神经网络模型进行训练;
第二训练模块,用于基于训练好的第一神经网络模型,对所述任务模型进行训练。
7.根据权利要求6所述的装置,其中,所述第一训练模块包括:
第一获取单元,用于获取第一训练样本图像;
第一执行单元,用于将所述第一训练样本图像输入至所述第一神经网络模型执行第一操作,得到所述第一训练样本图像的监督信息,所述第一操作用于基于所述第一神经网络模型对所述第一训练样本图像进行自监督学习处理;
第一更新单元,用于基于所述监督信息更新所述第一神经网络模型的模型参数。
8.根据权利要求7所述的装置,其中,所述第一神经网络模型的网络分支包括第一网络分支和第二网络分支,所述第一执行单元,具体用于:
对所述第一训练样本图像进行数据增强,得到第一图像和第二图像;
基于所述第一网络分支对所述第一图像进行特征提取,得到第一特征;
基于所述第二网络分支对所述第二图像进行特征提取,得到第二特征;
对所述第一特征和所述第二特征进行特征比对,得到所述第一训练样本图像的监督信息。
9.根据权利要求8所述的装置,其中,所述第一神经网络模型的模型参数包括所述第一网络分支的第一模型参数,所述第二训练模块包括:
第二获取单元,用于获取第二训练样本图像;
第二执行单元,用于将所述第二训练样本图像输入至所述任务模型执行第二操作,得到所述第二训练样本图像的识别结果;
第二更新单元,用于基于所述识别结果,更新所述任务模型的第二模型参数;
其中,所述第二操作包括:将所述第一模型参数作为所述任务模型的特征提取网络的参数,对所述第二训练样本图像进行特征提取,得到第三特征;基于所述第三特征进行图像识别,得到所述识别结果。
10.根据权利要求8所述的装置,其中,所述第二神经网络模型的网络分支包括第三网络分支,所述获取模块,具体用于将所述第三网络分支中目标网络的卷积层的通道数进行缩减,得到所述第一网络分支;
其中,所述目标网络为所述第三网络分支中,网络深度大于其他网络的网络。
11.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-5中任一项所述的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行权利要求1-5中任一项所述的方法。
13.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-5中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110670749.7A CN113344089B (zh) | 2021-06-17 | 2021-06-17 | 模型训练方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110670749.7A CN113344089B (zh) | 2021-06-17 | 2021-06-17 | 模型训练方法、装置及电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113344089A true CN113344089A (zh) | 2021-09-03 |
CN113344089B CN113344089B (zh) | 2022-07-01 |
Family
ID=77475909
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110670749.7A Active CN113344089B (zh) | 2021-06-17 | 2021-06-17 | 模型训练方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113344089B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114638961A (zh) * | 2022-03-28 | 2022-06-17 | 北京国电瑞源科技发展有限公司 | 一种指针表盘识别方法、系统及计算机存储介质 |
CN114743041A (zh) * | 2022-03-09 | 2022-07-12 | 中国科学院自动化研究所 | 一种预训练模型抽选框架的构建方法及装置 |
CN114972334A (zh) * | 2022-07-19 | 2022-08-30 | 杭州因推科技有限公司 | 一种管材瑕疵检测方法、装置、介质 |
CN116994309A (zh) * | 2023-05-06 | 2023-11-03 | 浙江大学 | 一种公平性感知的人脸识别模型剪枝方法 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110598504A (zh) * | 2018-06-12 | 2019-12-20 | 北京市商汤科技开发有限公司 | 图像识别方法及装置、电子设备和存储介质 |
CN111783949A (zh) * | 2020-06-24 | 2020-10-16 | 北京百度网讯科技有限公司 | 基于迁移学习的深度神经网络的训练方法和装置 |
CN112308034A (zh) * | 2020-11-25 | 2021-02-02 | 中国科学院深圳先进技术研究院 | 一种性别分类的方法、设备、终端及计算机存储介质 |
CN112508004A (zh) * | 2020-12-18 | 2021-03-16 | 北京百度网讯科技有限公司 | 一种文字识别方法、装置、电子设备及存储介质 |
CN112560874A (zh) * | 2020-12-25 | 2021-03-26 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、装置、设备和介质 |
-
2021
- 2021-06-17 CN CN202110670749.7A patent/CN113344089B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110598504A (zh) * | 2018-06-12 | 2019-12-20 | 北京市商汤科技开发有限公司 | 图像识别方法及装置、电子设备和存储介质 |
CN111783949A (zh) * | 2020-06-24 | 2020-10-16 | 北京百度网讯科技有限公司 | 基于迁移学习的深度神经网络的训练方法和装置 |
CN112308034A (zh) * | 2020-11-25 | 2021-02-02 | 中国科学院深圳先进技术研究院 | 一种性别分类的方法、设备、终端及计算机存储介质 |
CN112508004A (zh) * | 2020-12-18 | 2021-03-16 | 北京百度网讯科技有限公司 | 一种文字识别方法、装置、电子设备及存储介质 |
CN112560874A (zh) * | 2020-12-25 | 2021-03-26 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、装置、设备和介质 |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114743041A (zh) * | 2022-03-09 | 2022-07-12 | 中国科学院自动化研究所 | 一种预训练模型抽选框架的构建方法及装置 |
CN114638961A (zh) * | 2022-03-28 | 2022-06-17 | 北京国电瑞源科技发展有限公司 | 一种指针表盘识别方法、系统及计算机存储介质 |
CN114972334A (zh) * | 2022-07-19 | 2022-08-30 | 杭州因推科技有限公司 | 一种管材瑕疵检测方法、装置、介质 |
CN114972334B (zh) * | 2022-07-19 | 2023-09-15 | 杭州因推科技有限公司 | 一种管材瑕疵检测方法、装置、介质 |
CN116994309A (zh) * | 2023-05-06 | 2023-11-03 | 浙江大学 | 一种公平性感知的人脸识别模型剪枝方法 |
CN116994309B (zh) * | 2023-05-06 | 2024-04-09 | 浙江大学 | 一种公平性感知的人脸识别模型剪枝方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113344089B (zh) | 2022-07-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113344089B (zh) | 模型训练方法、装置及电子设备 | |
CN112560874B (zh) | 图像识别模型的训练方法、装置、设备和介质 | |
CN112507706B (zh) | 知识预训练模型的训练方法、装置和电子设备 | |
CN113705628B (zh) | 预训练模型的确定方法、装置、电子设备以及存储介质 | |
CN113657483A (zh) | 模型训练方法、目标检测方法、装置、设备以及存储介质 | |
CN113627536A (zh) | 模型训练、视频分类方法,装置,设备以及存储介质 | |
CN112528641A (zh) | 建立信息抽取模型的方法、装置、电子设备和可读存储介质 | |
CN112580666A (zh) | 图像特征的提取方法、训练方法、装置、电子设备及介质 | |
CN115359308A (zh) | 模型训练、难例识别方法、装置、设备、存储介质及程序 | |
CN112949433B (zh) | 视频分类模型的生成方法、装置、设备和存储介质 | |
CN114581732A (zh) | 一种图像处理及模型训练方法、装置、设备和存储介质 | |
CN114186681A (zh) | 用于生成模型簇的方法、装置及计算机程序产品 | |
CN112699237B (zh) | 标签确定方法、设备和存储介质 | |
CN112528146B (zh) | 内容资源推荐方法、装置、电子设备及存储介质 | |
CN113361519B (zh) | 目标处理方法、目标处理模型的训练方法及其装置 | |
CN114612971A (zh) | 人脸检测方法、模型训练方法、电子设备及程序产品 | |
CN114330576A (zh) | 模型处理方法、装置、图像识别方法及装置 | |
CN113657248A (zh) | 人脸识别模型的训练方法、装置及计算机程序产品 | |
CN113936158A (zh) | 一种标签匹配方法及装置 | |
CN114254650A (zh) | 一种信息处理方法、装置、设备及介质 | |
CN113989899A (zh) | 人脸识别模型中特征提取层的确定方法、设备和存储介质 | |
CN113641724A (zh) | 知识标签挖掘方法、装置、电子设备及存储介质 | |
CN113204616A (zh) | 文本抽取模型的训练与文本抽取的方法、装置 | |
CN112632999A (zh) | 命名实体识别模型获取及命名实体识别方法、装置及介质 | |
CN113205119A (zh) | 数据标注方法、装置、电子设备及可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |