CN110705691A - 神经网络训练方法、装置及计算机可读存储介质 - Google Patents
神经网络训练方法、装置及计算机可读存储介质 Download PDFInfo
- Publication number
- CN110705691A CN110705691A CN201910907549.1A CN201910907549A CN110705691A CN 110705691 A CN110705691 A CN 110705691A CN 201910907549 A CN201910907549 A CN 201910907549A CN 110705691 A CN110705691 A CN 110705691A
- Authority
- CN
- China
- Prior art keywords
- output
- characteristic
- iteration
- neural network
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 119
- 238000012549 training Methods 0.000 title claims abstract description 80
- 238000000034 method Methods 0.000 title claims abstract description 57
- 238000000605 extraction Methods 0.000 claims abstract description 22
- 238000003062 neural network model Methods 0.000 claims abstract description 7
- 238000012544 monitoring process Methods 0.000 abstract description 9
- 238000004821 distillation Methods 0.000 abstract description 4
- 238000010586 diagram Methods 0.000 description 11
- 230000006870 function Effects 0.000 description 8
- 230000008569 process Effects 0.000 description 7
- 230000000694 effects Effects 0.000 description 5
- 238000012545 processing Methods 0.000 description 5
- 238000012360 testing method Methods 0.000 description 4
- 238000004590 computer program Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 2
- 230000014509 gene expression Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000003313 weakening effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种神经网络训练方法及装置,其中,方法包括:获取待训练的神经网络的多个中间层输出的多个特征图;通过特征提取网络对多个特征图进行特征提取,分别得到每个中间层的第一特征输出;根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失;基于第一损失,调整多个中间层的参数。通过自蒸馏的方式,将神经网络模型自身的各中间层、以及各次迭代中提取的结果特征输出,作为自身的监督信号充分利用,能够更快的收敛结果,完成训练,节约了时间和资源。
Description
技术领域
本公开一般地涉及人工智能领域,具体涉及一种神经网络训练方法及装置、电子设备及计算机可读存储介质。
背景技术
随着近几年深度学习的崛起,人们在图像分类、语音识别、自然语言处理、策略AI、自动驾驶等诸多领域取得了优异的成绩。然而,依靠着复杂的神经网络以及超大的数据集取得良好的成绩的基础是强大的计算能力。随着神经网络层数的加深以及数据集的不断扩展,训练神经网络对计算力要求以及调整参数试错成本也越来越高,这对于神经网络训练来说是极大的时间成本。
目前,通过蒸馏方法对神经网络进行训练的方式中,教师模型所需资源多,训练时间长。并且,不能统一优化,首先要有教师模型,再进行蒸馏,时间串行,所需时间长。
发明内容
为了解决现有技术中存在的上述问题,本公开的第一方面提供一种神经网络训练方法,其中,方法包括:获取待训练的神经网络的多个中间层输出的多个特征图;通过特征提取网络对多个特征图进行特征提取,分别得到每个中间层的第一特征输出;根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失;基于第一损失,调整多个中间层的参数。
在一例中,根据多个第一特征输出与神经网络输出的第二特征输出,计算得到第一损失,包括:每个第一特征输出分别与第二特征输出进行比对,得到每个中间层对应的中间层损失;基于第一损失,调整多个中间层的参数,包括:基于中间层损失,调整对应的中间层的参数以及对应的中间层前序全部中间层的参数。
在一例中,根据多个第一特征输出与神经网络输出的第二特征输出,计算得到第一损失,包括:根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,N为正整数;基于第一损失,调整多个中间层的参数,包括:基于第一迭代损失,调整多个中间层的参数。
在一例中,根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,包括:根据前N轮迭代的第二特征输出分别对应的第一权重系数,对多个前N轮迭代的第二特征输出进行加权拼接,得到前N轮迭代的加权特征输出;根据当前轮的第二特征输出与前N轮迭代的加权特征输出,计算得到第一迭代损失。
在一例中,第i轮迭代的第二特征输出对应的第一权重系数大于第j轮迭代的第二特征输出对应的第一权重系数,i、j均为正整数,且i>j。
在一例中,方法还包括:存储各轮迭代的第二特征输出。
在一例中,根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失,包括:对第一特征输出和第二特征输出进行拼接,得到第三特征输出;根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,M为正整数;基于第一损失,调整中间层的参数,包括:基于第二迭代损失,调整多个中间层的参数。
在一例中,根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,包括:根据前M轮迭代的第三特征输出分别对应的第二权重系数,对多个前M轮迭代的第三特征输出进行加权拼接,得到前M轮迭代的加权特征输出;根据当前轮的第二特征输出或当前轮的第三特征输出与前M轮迭代的加权特征输出,计算得到第二迭代损失。
在一例中,第g轮迭代的第三特征输出对应的第二权重系数大于第h轮迭代的第三特征输出对应的第二权重系数,g、h均为正整数,且g>h。
在一例中,方法还包括:存储各轮迭代的第三特征输出。
在一例中,特征提取网络包括卷积层和全连接层。
在一例中,方法还包括:根据待训练的神经网络输出的训练样本的预测值与训练样本的真实值,计算得到第二损失;基于第二损失,调整待训练的神经网络的参数。
在一例中,方法还包括:当第二损失小于训练阈值时,待训练的神经网络完成训练,得到神经网络模型。
本公开的第二方面提供一种神经网络训练装置,其中,装置包括:获取模块,用于获取待训练的神经网络的多个中间层输出的多个特征图;特征提取模块,用于通过特征提取网络对所述多个特征图进行特征提取,分别得到每个所述中间层的第一特征输出;损失确定模块,用于根据多个所述第一特征输出与所述待训练的神经网络输出的第二特征输出,计算得到第一损失;反馈模块,用于基于所述第一损失,调整所述中间层的参数。
本公开的第三方面提供一种电子设备,包括:存储器,用于存储指令;以及处理器,用于调用存储器存储的指令执行如第一方面的神经网络训练方法。
本公开的第四方面提供一种计算机可读存储介质,其中存储有指令,指令被处理器执行时,执行如第一方面的神经网络训练方法。
本公开提供的模型训练方法及装置,通过自蒸馏的方式,将待训练的神经网络的各中间层、以及各轮迭代中提取特征输出,作为自身的监督信号充分利用,用来对待训练的神经网络进行训练,能够提升训练速度,并能提升训练精度,使得训练得到的神经网络模型的准确度更高。
附图说明
通过参考附图阅读下文的详细描述,本公开实施方式的上述以及其他目的、特征和优点将变得易于理解。在附图中,以示例性而非限制性的方式示出了本公开的若干实施方式,其中:
图1示出了根据本公开一实施例神经网络训练方法的流程示意图;
图2示出了根据本公开一实施例神经网络训练的架构的示意框图;
图3示出了根据本公开一实施例通过迭代轮次和权重系数进行监督的示意框图;
图4示出了根据本公开一实施例的神经网络训练装置示意图。
图5是本公开实施例提供的一种电子设备示意图。
在附图中,相同或对应的标号表示相同或对应的部分。
具体实施方式
下面将参考若干示例性实施方式来描述本公开的原理和精神。应当理解,给出这些实施方式仅仅是为了使本领域技术人员能够更好地理解进而实现本公开,而并非以任何方式限制本公开的范围。
需要注意,虽然本文中使用“第一”、“第二”等表述来描述本公开的实施方式的不同模块、步骤和数据等,但是“第一”、“第二”等表述仅是为了在不同的模块、步骤和数据等之间进行区分,而并不表示特定的顺序或者重要程度。实际上,“第一”、“第二”等表述完全可以互换使用。
目前通过教师模型训练学生模型的方式,由于教师模型所需资源多,训练时间长,导致训练成本很高。本公开提供的实施例充分利用了神经网络的特征输出,对神经网络的参数进行调整,从而降低了训练的时间成本,提高数据利用率,
图1示出了本公开实施例提供的一种神经网络训练方法10,包括:步骤S11-S14,下文分别对上述步骤进行详细说明:
步骤S11,获取待训练的神经网络的多个中间层输出的多个特征图。
其中,待训练的神经网络包括多个依次连接的中间层,输入待训练的神经网络的训练样本经过多个中间层依次处理,每一中间层的输出的特征图即为其后一层中间层的输入。其中待训练的神经网络可以是残差神经网络(Residual Neural Network,ResNet),上述多个中间层可以是残差神经网络中的多个块(block);待训练的神经网络也可以是AlexNet,其中的多个卷积层即为上述多个中间层,但本公开实施例不以此为限,待训练的神经网络还可以具有其他结构,例如还可以为其他现有的神经网络或自主设计的神经网络。同时,根据实际情况,本公开实施例中的多个中间层可以是待训练的神经网络中的部分中间层,如在待训练的神经网络结构中靠近输出层的多个中间层,本公开实施例对此不作限定。
步骤S12,通过特征提取网络对多个特征图进行特征提取,分别得到对应的中间层的第一特征输出。
可以通过与中间层一一对应设置的特征提取网络,对相应中间层输出的特征图进行特征提取,特征提取网络可以是卷积神经网络(CNN),也可以是注意力网络(AttentionNetworks)等,本公开实施例对此不作限定。在本公开的一个实施例中,特征提取网络包括卷积层和全连接层。其中卷积层可以通过逐点卷积的方式对特征图进行特征提取,并通过全连接层得到第一特征输出。
步骤S13,根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失。
其中,待训练的神经网络可以包括一个或多个全连接层,在输入的训练样本经过全部中间层的输出后,通过神经网络的全连接层得到第二特征输出。在对待训练的神经网络进行训练的过程中,将待训练的神经网络的中间层或输出层输出的一些特征输出作为监督特征对另一些特征输出进行监督,对比监督特征和相应的特征输出,并通过损失函数计算损失,然后根据损失对被监督的特征输出相应的待训练的神经网络的中间层的参数或待训练的神经网络的参数进行更新。其中,监督特征为相对概念,某一特征输出可以在一些情况下可以作为监督特征,而在另一些情况下可以作为被监督的特征输出。损失函数可以包括但不限于交叉熵函数、指数损失函数、铰链损失函数等。
步骤S14,基于第一损失,调整多个中间层的参数。
上述实施例中,根据第一损失针对的对象不同,调整的中间层也相应不同。在一些实施例中,通过一个第一损失调整相应的中间层的参数、以及该中间层的全部前序中间层的参数;而在另一些实施例中,通过一个第一损失可调节全部中间层的参数。后文会以具体的实施例进行说明。
利用神经网络训练过程中的一些特征输出监督另一些特征输出,即被监督的特征输出,通过对比监督特征与被监督的特征输出之间的差值或比值,计算得到第一损失,并根据该第一损失调整多个中间层的参数,由于其中作为监督的特征输出为训练过程中神经网络自身的特征输出,无需凭借其他教师模型,从而提高模型达到收敛的速度,提高训练效率。
为便于理解,图2示出了本公开涉及的一种神经网络训练的架构,其中神经网络至少包括多个依次连接的中间层、全连接层,另外,独立于神经网络设置有包括卷积层和全连接层的特征提取网络。
在一实施例中,步骤S13根据多个第一特征输出与神经网络输出的第二特征输出,计算得到第一损失,包括:每个第一特征输出分别与第二特征输出进行比对,得到每个中间层对应的中间层损失;步骤S14基于第一损失,调整多个中间层的参数,包括:基于中间层损失,调整对应的中间层的参数以及对应的中间层前序全部中间层的参数。
在通过特征提取网络对进行特征提取并得到相应的第一特征输出后,将神经网络的第二特征输出作为监督特征,通过将相应多个中间层的第一特征输出分别与神经网络的第二特征输出进行比对,得到对应于该相应多个中间层的中间层损失,再根据中间层损失,调整中间层的参数。由于中间层是依次连接的,输出的特征图即为后一中间层的输入,因此,一个中间层的特征图包含了该中间层及其前序中间层的信息,从而可以根据一个中间层的中间层损失,调整该中间层的参数及其前序中间层的参数。示例性的,以图2中所示的神经网络训练的架构为例,通过第二特征输出与中间层3被提取对应的第一特征输出F3计算得到的第一损失L3,基于该第一损失L3,调整的是相应的中间层3的参数、以及中间层3前序的中间层2、中间层1的参数。通过充分利用神经网络训练过程中产生的特征输出等数据,并利用第二特征输出作为这些特征输出等数据的监督信息,训练神经网络,从而提高了效率。
在一实施例中,步骤S13根据多个第一特征输出与神经网络输出的第二特征输出,计算得到第一损失,包括:根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,N为正整数;步骤S14基于第一损失,调整多个中间层的参数,包括:基于第一迭代损失,调整多个中间层的参数。
该实施例中充分利用不同轮次迭代的结果,由于神经网络的训练总体是将结果进行收敛,但对于局部迭代结果来说具有不确定性,因此,可以根据前一轮或前几轮迭代中的特征输出,对本轮迭代的特征输出进行监督。具体而言,在对待训练的神经网络进行训练的过程中,往往需要进行多轮的迭代,才能使待训练的神经网络达到训练要求,及最终的结果足够收敛,得到神经网络模型。本实施例中,在训练过程中,将前N轮迭代的第二特征输出汇总,作为监督特征,与本轮的第二特征输出进行比对得到一个前N轮迭代损失,示例性的,如图2所示的神经网络训练的架构为例,可见第二特征输出是经过全部中间层以及全连接层(FC)得到的输出,因此,在通过前N轮迭代的第二特征输出对本轮的第二特征输出进行监督时,基于得到的第一迭代损失,可以对待训练的神经网络的全部中间层的参数进行调整。能够使得训练效率进一步提高。
在一实施例中,根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,包括:根据前N轮迭代的第二特征输出分别对应的第一权重系数,对多个前N轮迭代的第二特征输出进行加权拼接(concate),得到前N轮迭代的加权特征输出;根据当前轮的第二特征输出与前N轮迭代的加权特征输出,计算得到第一迭代损失。
图3示意性的示出了通过迭代轮次和权重系数进行监督的示意图,以图3所示为例,在基于前N轮的第二特征输出和当前轮的第二特征输出计算第一迭代损失时,可以将前N轮迭代的第二特征输出,即第k1轮迭代的第二特征输出、第k2轮迭代的第二特征输出……第ks轮迭代的第二特征输出,进行加权拼接,即对第k1-ks轮迭代的第二特征输出设置相应的权重系数:第一权重系数P(k1)、第一权重系数P(k2)、……、第一权重系数P(ks),在加权后进行拼接,得到一个加权特征输出。因此,加权特征输出包含了前N轮迭代的信息,将其与当前轮的第二特征输出进行比较,可以计算得到第一迭代损失。示例地,可以通过比较加权特征输出与当前轮的第二特征输出之间的差值,计算得到第一迭代损失。在一实施例中,可以通过计算加权特征输出与当前轮的第二特征输出之间的欧式距离,确定加权特征输出与当前轮的第二特征输出之间的差值。
在一实施例中,第i轮迭代的第二特征输出对应的第一权重系数大于第j轮迭代的第二特征输出对应的第一权重系数,i、j均为正整数,且i>j。在一实施例中,该前N轮迭代包括该第i轮迭代、第j轮迭代,也即第i轮迭代、第j轮迭代是当前轮迭代之前的迭代,且0<i-j<N。在一实施例中,该前N轮迭代中并非所有在前轮迭代的第二特征输出所对应第一权重系数均大于在后轮迭代的第二特征输出所对应第一权重系数,例如,相邻多轮迭代的第二特征输出所对应第一权重系数为相同值,本申请实施例对此不作限定。
图3示意性的示出了通过迭代轮次和权重系数进行监督的示意图,以图3为例,在通过前面迭代的特征输出进行监督时,考虑到神经网络的训练总体是将结果进行收敛的特性,越靠前轮次的迭代特征输出对当前轮次的特征输出的监督意义更小,因此,在对多个前N轮迭代的第二特征输出进行加权拼接时,迭代轮次越靠前,则对应的第一权重系数越小,从而提高距离当前轮次越近的迭代轮次特征输出监督作用,相对弱化距离当前轮次较远的迭代轮次特征输出的监督作用,从而保证了监督效果。
在一实施例中,神经网络训练方法S10还包括:存储各轮迭代的第二特征输出。示例地,可以将各轮迭代的第二特征输出,即当前轮次之前每个轮次迭代以及当前轮次的第二特征输出均保存在独立于神经网络的存储模块中,并记录迭代轮次,便于监督后续轮次特征输出时取用。在另一些示例中,也可以根据实际需要,仅将用于监督当前轮次第二特征输出的前N轮迭代的第二特征输出保存于存储模块,即保存与当前轮最近的N个在前轮次的第二特征输出,相应的,当前轮次的第二特征输出在进入下一轮迭代后进行存储,同时删除已存储的最靠前轮次的一个第二特征输出,存储的第二特征输维持为N个。采用该种方式能够节约存储资源,降低成本。
在一例中,步骤S13根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失,包括:对第一特征输出和第二特征输出进行拼接,得到第三特征输出;根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,M为正整数;步骤S14基于第一损失,调整中间层的参数,包括:基于第二迭代损失,调整多个中间层的参数。
与在前考虑前序迭代的特征输出相似,本实施例中,对单次迭代过程中的神经网络各中间层的第一特征输出和神经网络的第二特征输出进行拼接(concate),得到第三特征输出,将第三特征输出用于对之后迭代轮次的第二特征输出或第三特征输出进行监督。示例性的,如图2所示的神经网络训练的架构为例,第二特征输出是经过全部中间层以及全连接层(FC)得到的输出,而第三特征输出是通过第一特征输出与第二特征输出进行拼接得到,因此,第二特征输出或第三特征输出均包含了全部中间层的信息,在通过前N轮迭代的第三特征输出对本轮的第二特征输出或第三特征输出进行监督时,基于得到的第二迭代损失,可以对待训练的神经网络的全部中间层的参数进行调整。本实施例充分利用了训练过程中,不同迭代次数产生的结果,提高了模型的训练效率。
在一实施例中,根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,包括:根据前M轮迭代的第三特征输出分别对应的第二权重系数,对多个前M轮迭代的第三特征输出进行加权拼接,得到前M轮迭代的加权特征输出,根据当前轮的第二特征输出或当前轮的第三特征输出与前M轮迭代的加权特征输出,计算得到第二迭代损失。
与在前实施例中对多个前N轮迭代的第二特征输出进行加权拼接的原理相同,设置第二权重系数,对前M轮迭代的第三特征输出进行加权拼接得到前M轮迭代的加权特征输出以计算第二迭代损失。
在一实施例中,第g轮迭代的第三特征输出对应的第二权重系数大于第h轮迭代的第三特征输出对应的第二权重系数,g、h均为正整数,且g>h。
同样的,在通过前M轮迭代的第三特征输出得到前M轮迭代的加权特征输出时,也需要考虑越靠前轮次的迭代特征输出对当前轮次的特征输出的监督意义更小,因此第二权重系数也设置为轮次越靠前的第三特征输出对应的第二权重系数越小,从而加强了与本轮迭代距离更近的迭代轮次特征输出的监督作用,以提高在通过在前迭代轮次进行监督时的准确性。
在一实施例中,神经网络训练方法S10还包括:存储各轮迭代的第三特征输出。与前述实施例中存储第二特征输出原理相同,可以将当前轮次之前每个轮次迭代以及当前轮次的第三特征输出存储在独立于神经网络的存储模块中,便于监督后续轮次特征输出时使用。同样,在另一些示例中,也可以根据实际需要,仅将用于监督当前轮次的前M轮迭代的第三特征输出保存于存储模块,即保存与当前轮最近的M个在前轮次的第三特征输出,相应的,当前轮次的第M特征输出在进入下一轮迭代后进行存储,同时删除已存储的最靠前轮次的一个第三特征输出,存储的第三特征输维持为M个。采用该种方式能够节约存储资源,降低成本。
在一实施例中,神经网络训练方法S10还包括:根据待训练的神经网络输出的训练样本的预测值与训练样本的真实值计算得到的第二损失;基于第二损失,调整神经网络的参数。神经网络在全连接层后,可以通过设置softmax层得到分类结果,即本轮迭代神经网络针对训练样本的预测值,将预测值与训练样本的真实值比对,并通过损失函数计算两者的第二损失,可以不仅调整神经网络中全部中间层的参数,还可以调整包括全连接层的参数等神经网络中的其他参数,使得结果更加收敛。
在一实施例中,神经网络训练方法S10还包括:当第二损失小于训练阈值时,神经网络完成训练。可以根据神经网络对训练样本的输出的预测值进行判断是否待训练的神经网络是否完成训练,如果预测值足够准确,即与真实值之间的第二损失小于训练阈值,则说明模型训练完成。在另一示例中,还可以通过设置测试样本组,对神经网络进行测试,通过输入测试样本,通过神经网络得到测试样本的预测值,通过与测试样本真实值进行比对,如果正确率超过一预设阈值,则说明神经网络训练效果达到目标,可完成训练。
基于同一发明构思,本公开实施例还提供一种神经网络训练装置100。如图4所示,其中,神经网络训练装置100包括获取模块110、特征提取模块120、损失确定模块130和反馈模块140。
获取模块110,用于获取待训练的神经网络的多个中间层输出的多个特征图。
特征提取模块120,用于通过特征提取网络对多个特征图进行特征提取,分别得到每个中间层的第一特征输出。
损失确定模块130,用于根据多个第一特征输出与待训练的神经网络输出的第二特征输出,计算得到第一损失。
反馈模块140,用于基于第一损失,调整中间层的参数。
在一实施例中,损失确定模块130还用于:每个第一特征输出分别与第二特征输出进行比对,得到每个中间层对应的中间层损失;反馈模块140还用于:基于中间层损失,调整对应的中间层的参数以及对应的中间层前序全部中间层的参数。
在一实施例中,损失确定模块130还用于:根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,N为正整数;反馈模块140还用于:基于第一迭代损失,调整多个中间层的参数。
在一实施例中,损失确定模块130还用于:根据前N轮迭代的第二特征输出分别对应的第一权重系数,对多个前N轮迭代的第二特征输出进行加权拼接,得到前N轮迭代的加权特征输出;根据当前轮的第二特征输出与前N轮迭代的加权特征输出,计算得到第一迭代损失。
在一实施例中,第i轮迭代的第二特征输出对应的第一权重系数大于第j轮迭代的第二特征输出对应的第一权重系数,i、j均为正整数,且i>j。
在一实施例中,神经网络训练装置100还包括:存储模块,用于存储各轮迭代的第二特征输出。
在一实施例中,损失确定模块130还用于:对第一特征输出和第二特征输出进行拼接,得到第三特征输出;根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,M为正整数;反馈模块140还用于:基于第二迭代损失,调整多个中间层的参数。
在一实施例中,损失确定模块130还用于:根据前M轮迭代的第三特征输出分别对应的第二权重系数,对多个前M轮迭代的第三特征输出进行加权拼接,得到前M轮迭代的加权特征输出;根据当前轮的第二特征输出或当前轮的第三特征输出与前M轮迭代的加权特征输出,计算得到第二迭代损失。
在一实施例中,第g轮迭代的第三特征输出对应的第二权重系数大于第h轮迭代的第三特征输出对应的第二权重系数,g、h均为正整数,且g>h。
在一实施例中,神经网络训练装置100还包括:存储模块,用于存储各轮迭代的第三特征输出。
在一实施例中,特征提取网络包括卷积层和全连接层。
在一实施例中,损失确定模块130还用于:根据待训练的神经网络输出的训练样本的预测值与训练样本的真实值,计算得到第二损失;反馈模块140还用于:基于第二损失,调整待训练的神经网络的参数。
在一实施例中,神经网络训练装置100还包括:判断模块,用于当第二损失小于训练阈值时,待训练的神经网络完成训练,得到神经网络模型。
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
如图5所示,本公开的一个实施方式提供了一种电子设备300。其中,该电子设备300包括存储器301、处理器302、输入/输出(Input/Output,I/O)接口303。其中,存储器301,用于存储指令。处理器302,用于调用存储器301存储的指令执行本公开实施例的神经网络训练方法。其中,处理器302分别与存储器301、I/O接口303连接,例如可通过总线系统和/或其他形式的连接机构(未示出)进行连接。存储器301可用于存储程序和数据,包括本公开实施例中涉及的神经网络训练方法的程序,处理器302通过运行存储在存储器301的程序从而执行电子设备300的各种功能应用以及数据处理。
本公开实施例中处理器302可以采用数字信号处理器(Digital SignalProcessing,DSP)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)、可编程逻辑阵列(Programmable Logic Array,PLA)中的至少一种硬件形式来实现,所述处理器302可以是中央处理单元(Central Processing Unit,CPU)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元中的一种或几种的组合。
本公开实施例中的存储器301可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(Random Access Memory,RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(Read-OnlyMemory,ROM)、快闪存储器(Flash Memory)、硬盘(Hard Disk Drive,HDD)或固态硬盘(Solid-State Drive,SSD)等。
本公开实施例中,I/O接口303可用于接收输入的指令(例如数字或字符信息,以及产生与电子设备300的用户设置以及功能控制有关的键信号输入等),也可向外部输出各种信息(例如,图像或声音等)。本公开实施例中I/O接口303可包括物理键盘、功能按键(比如音量控制按键、开关按键等)、鼠标、操作杆、轨迹球、麦克风、扬声器、和触控面板等中的一个或多个。
可以理解的是,本公开实施例中尽管在附图中以特定的顺序描述操作,但是不应将其理解为要求按照所示的特定顺序或是串行顺序来执行这些操作,或是要求执行全部所示的操作以得到期望的结果。在特定环境中,多任务和并行处理可能是有利的。
本公开实施例涉及的方法和装置能够利用标准编程技术来完成,利用基于规则的逻辑或者其他逻辑来实现各种方法步骤。还应当注意的是,此处以及权利要求书中使用的词语“装置”和“模块”意在包括使用一行或者多行软件代码的实现和/或硬件实现和/或用于接收输入的设备。
此处描述的任何步骤、操作或程序可以使用单独的或与其他设备组合的一个或多个硬件或软件模块来执行或实现。在一个实施方式中,软件模块使用包括包含计算机程序代码的计算机可读介质的计算机程序产品实现,其能够由计算机处理器执行用于执行任何或全部的所描述的步骤、操作或程序。
出于示例和描述的目的,已经给出了本公开实施的前述说明。前述说明并非是穷举性的也并非要将本公开限制到所公开的确切形式,根据上述教导还可能存在各种变形和修改,或者是可能从本公开的实践中得到各种变形和修改。选择和描述这些实施例是为了说明本公开的原理及其实际应用,以使得本领域的技术人员能够以适合于构思的特定用途来以各种实施方式和各种修改而利用本公开。
Claims (16)
1.一种神经网络训练方法,其中,所述方法包括:
获取待训练的神经网络的多个中间层输出的多个特征图;
通过特征提取网络对所述多个特征图进行特征提取,分别得到每个所述中间层的第一特征输出;
根据多个所述第一特征输出与所述待训练的神经网络输出的第二特征输出,计算得到第一损失;
基于所述第一损失,调整所述多个中间层的参数。
2.根据权利要求1所述的方法,其中,所述根据多个所述第一特征输出与所述神经网络输出的第二特征输出,计算得到第一损失,包括:
每个所述第一特征输出分别与所述第二特征输出进行比对,得到每个所述中间层对应的中间层损失;
所述基于所述第一损失,调整所述多个中间层的参数,包括:基于所述中间层损失,调整对应的中间层的参数以及对应的中间层前序全部中间层的参数。
3.根据权利要求1所述的方法,其中,所述根据多个所述第一特征输出与所述神经网络输出的第二特征输出,计算得到第一损失,包括:
根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,N为正整数;
所述基于所述第一损失,调整所述多个中间层的参数,包括:基于所述第一迭代损失,调整所述多个中间层的参数。
4.根据权利要求3所述的方法,其中,所述根据当前轮的第二特征输出与前N轮迭代的第二特征输出,计算得到第一迭代损失,包括:
根据所述前N轮迭代的第二特征输出分别对应的第一权重系数,对多个所述前N轮迭代的第二特征输出进行加权拼接,得到前N轮迭代的加权特征输出;
根据当前轮的第二特征输出与所述前N轮迭代的加权特征输出,计算得到所述第一迭代损失。
5.根据权利要求4所述的方法,其中,第i轮迭代的第二特征输出对应的第一权重系数大于第j轮迭代的第二特征输出对应的第一权重系数,i、j均为正整数,且i>j。
6.根据权利要求3所述的方法,其中,所述方法还包括:存储各轮迭代的第二特征输出。
7.根据权利要求1所述的方法,其中,所述根据多个所述第一特征输出与所述待训练的神经网络输出的第二特征输出,计算得到第一损失,包括:
对所述第一特征输出和所述第二特征输出进行拼接,得到第三特征输出;
根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,M为正整数;
所述基于所述第一损失,调整所述中间层的参数,包括:基于所述第二迭代损失,调整所述多个中间层的参数。
8.根据权利要求7所述的方法,其中,所述根据当前轮的第二特征输出与前M轮迭代的第三特征输出或当前轮的第三特征输出与前M轮迭代的第三特征输出,计算得到第二迭代损失,包括:
根据所述前M轮迭代的第三特征输出分别对应的第二权重系数,对多个所述前M轮迭代的第三特征输出进行加权拼接,得到前M轮迭代的加权特征输出;
根据所述当前轮的第二特征输出或所述当前轮的第三特征输出与所述前M轮迭代的加权特征输出,计算得到所述第二迭代损失。
9.根据权利要求8所述的方法,其中,第g轮迭代的第三特征输出对应的第二权重系数大于第h轮迭代的第三特征输出对应的第二权重系数,g、h均为正整数,且g>h。
10.根据权利要求7所述的方法,其中,所述方法还包括:存储各轮迭代的第三特征输出。
11.根据权利要求1-10任一项所述的方法,其中,所述特征提取网络包括卷积层和全连接层。
12.根据权利要求1-10任一项所述的方法,其中,所述方法还包括:
根据所述待训练的神经网络输出的训练样本的预测值与所述训练样本的真实值,计算得到第二损失;
基于所述第二损失,调整所述待训练的神经网络的参数。
13.根据权利要求12所述的方法,其中,所述方法还包括:当所述第二损失小于训练阈值时,所述待训练的神经网络完成训练,得到神经网络模型。
14.一种神经网络训练装置,其中,所述装置包括:
获取模块,用于获取待训练的神经网络的多个中间层输出的多个特征图;
特征提取模块,用于通过特征提取网络对所述多个特征图进行特征提取,分别得到每个所述中间层的第一特征输出;
损失确定模块,用于根据多个所述第一特征输出与所述待训练的神经网络输出的第二特征输出,计算得到第一损失;
反馈模块,用于基于所述第一损失,调整所述中间层的参数。
15.一种电子设备,其中,所述电子设备包括:
存储器,用于存储指令;以及
处理器,用于调用所述存储器存储的指令执行如权利要求1-13中任一项所述的神经网络训练方法。
16.一种计算机可读存储介质,其中存储有指令,所述指令被处理器执行时,执行如权利要求1-13中任一项所述的神经网络训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910907549.1A CN110705691A (zh) | 2019-09-24 | 2019-09-24 | 神经网络训练方法、装置及计算机可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910907549.1A CN110705691A (zh) | 2019-09-24 | 2019-09-24 | 神经网络训练方法、装置及计算机可读存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN110705691A true CN110705691A (zh) | 2020-01-17 |
Family
ID=69196010
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910907549.1A Pending CN110705691A (zh) | 2019-09-24 | 2019-09-24 | 神经网络训练方法、装置及计算机可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110705691A (zh) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111767989A (zh) * | 2020-06-29 | 2020-10-13 | 北京百度网讯科技有限公司 | 神经网络的训练方法和装置 |
CN111898707A (zh) * | 2020-08-24 | 2020-11-06 | 鼎富智能科技有限公司 | 模型训练方法、文本分类方法、电子设备及存储介质 |
CN112183336A (zh) * | 2020-09-28 | 2021-01-05 | 平安科技(深圳)有限公司 | 表情识别模型训练方法、装置、终端设备及存储介质 |
CN113409769A (zh) * | 2020-11-24 | 2021-09-17 | 腾讯科技(深圳)有限公司 | 基于神经网络模型的数据识别方法、装置、设备及介质 |
CN116596916A (zh) * | 2023-06-09 | 2023-08-15 | 北京百度网讯科技有限公司 | 缺陷检测模型的训练和缺陷检测方法及其装置 |
-
2019
- 2019-09-24 CN CN201910907549.1A patent/CN110705691A/zh active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111767989A (zh) * | 2020-06-29 | 2020-10-13 | 北京百度网讯科技有限公司 | 神经网络的训练方法和装置 |
CN111898707A (zh) * | 2020-08-24 | 2020-11-06 | 鼎富智能科技有限公司 | 模型训练方法、文本分类方法、电子设备及存储介质 |
CN112183336A (zh) * | 2020-09-28 | 2021-01-05 | 平安科技(深圳)有限公司 | 表情识别模型训练方法、装置、终端设备及存储介质 |
WO2022062403A1 (zh) * | 2020-09-28 | 2022-03-31 | 平安科技(深圳)有限公司 | 表情识别模型训练方法、装置、终端设备及存储介质 |
CN113409769A (zh) * | 2020-11-24 | 2021-09-17 | 腾讯科技(深圳)有限公司 | 基于神经网络模型的数据识别方法、装置、设备及介质 |
CN113409769B (zh) * | 2020-11-24 | 2024-02-09 | 腾讯科技(深圳)有限公司 | 基于神经网络模型的数据识别方法、装置、设备及介质 |
CN116596916A (zh) * | 2023-06-09 | 2023-08-15 | 北京百度网讯科技有限公司 | 缺陷检测模型的训练和缺陷检测方法及其装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110705691A (zh) | 神经网络训练方法、装置及计算机可读存储介质 | |
KR102110486B1 (ko) | 인공 뉴럴 네트워크 클래스-기반 프루닝 | |
KR102492318B1 (ko) | 모델 학습 방법 및 장치, 및 데이터 인식 방법 | |
CN111612134B (zh) | 神经网络结构搜索方法、装置、电子设备及存储介质 | |
KR20190050141A (ko) | 고정 소수점 타입의 뉴럴 네트워크를 생성하는 방법 및 장치 | |
Kan et al. | Simple reservoir computing capitalizing on the nonlinear response of materials: theory and physical implementations | |
CN112990444B (zh) | 一种混合式神经网络训练方法、系统、设备及存储介质 | |
CN111260032A (zh) | 神经网络训练方法、图像处理方法及装置 | |
WO2021208455A1 (zh) | 一种面向家居口语环境的神经网络语音识别方法及系统 | |
CN109214502B (zh) | 神经网络权重离散化方法和系统 | |
CN110059804B (zh) | 数据处理方法及装置 | |
CN110321430B (zh) | 域名识别和域名识别模型生成方法、装置及存储介质 | |
KR20220098991A (ko) | 음성 신호에 기반한 감정 인식 장치 및 방법 | |
KR20190134965A (ko) | 뉴럴 네트워크 학습 방법 및 그 시스템 | |
KR20190136578A (ko) | 음성 인식 방법 및 장치 | |
CN114997287A (zh) | 模型训练和数据处理方法、装置、设备及存储介质 | |
CN109033413B (zh) | 一种基于神经网络的需求文档和服务文档匹配方法 | |
CN115345303A (zh) | 卷积神经网络权重调优方法、装置、存储介质和电子设备 | |
KR102292921B1 (ko) | 언어 모델 학습 방법 및 장치, 음성 인식 방법 및 장치 | |
Giannakopoulos et al. | Improving post-processing of audio event detectors using reinforcement learning | |
JP7438544B2 (ja) | ニューラルネットワーク処理装置、コンピュータプログラム、ニューラルネットワーク製造方法、ニューラルネットワークデータの製造方法、ニューラルネットワーク利用装置、及びニューラルネットワーク小規模化方法 | |
CN113297579B (zh) | 基于时序神经通路的语音识别模型中毒检测方法及装置 | |
Vassiljeva et al. | Neural networks based minimal or reduced model representation for control of nonlinear MIMO systems | |
Salaken et al. | Switch point finding using polynomial regression for fuzzy type reduction algorithms | |
CN117873904B (zh) | 基于t-分布鲸鱼优化算法生成浮点数测试激励方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20200117 |
|
RJ01 | Rejection of invention patent application after publication |