CN113449840A - 神经网络训练方法及装置、图像分类的方法及装置 - Google Patents

神经网络训练方法及装置、图像分类的方法及装置 Download PDF

Info

Publication number
CN113449840A
CN113449840A CN202010231122.7A CN202010231122A CN113449840A CN 113449840 A CN113449840 A CN 113449840A CN 202010231122 A CN202010231122 A CN 202010231122A CN 113449840 A CN113449840 A CN 113449840A
Authority
CN
China
Prior art keywords
neural network
loss value
network
attention
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010231122.7A
Other languages
English (en)
Inventor
李鹏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing Artificial Intelligence Advanced Research Institute Co ltd
Original Assignee
Nanjing Artificial Intelligence Advanced Research Institute Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing Artificial Intelligence Advanced Research Institute Co ltd filed Critical Nanjing Artificial Intelligence Advanced Research Institute Co ltd
Priority to CN202010231122.7A priority Critical patent/CN113449840A/zh
Publication of CN113449840A publication Critical patent/CN113449840A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本公开公开了一种神经网络训练方法及装置、基于神经网络进行图像分类的方法及装置、计算机可读存储介质及电子设备。神经网络训练方法包括:将训练样本输入被训练神经网络;通过训练样本以及被训练神经网络,确定被训练神经网络的第一损失值;通过训练样本以及至少一个注意力网络,确定至少一个注意力网络的第二损失值;根据第一损失值和第二损失值,更新被训练神经网络中的参数。本方案利用注意力网络辅助训练神经网络,提升了神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。

Description

神经网络训练方法及装置、图像分类的方法及装置
技术领域
本发明涉及深度学习领域,具体涉及一种神经网络训练方法及装置、基于神经网络进行图像分类的方法及装置、计算机可读存储介质及电子设备。
背景技术
神经网络大幅度提升了机器学习的性能,在图像分类、目标检测、模式识别、语义分割和自然语言处理等领域取得了极大的成功,成为目前机器学习理论研究和工业应用的一个主流分支。
但是,目前神经网络训练存在训练效果较差的问题,如何提高深度神经网络的训练效率,提升训练的效果,是目前影响深度神经网络发展和应用的关键问题之一。
发明内容
为了解决上述技术问题,提出了本公开。本公开的实施例提供了一种神经网络训练方法及装置、基于神经网络进行图像分类的方法及装置、计算机可读存储介质及电子设备。
根据本公开实施例的一个方面,提供了一种神经网络训练方法,包括:将训练样本输入被训练神经网络;通过训练样本以及被训练神经网络,确定被训练神经网络的第一损失值;通过训练样本以及至少一个注意力网络,确定至少一个注意力网络的第二损失值;根据第一损失值和第二损失值,更新被训练神经网络中的参数。
根据本公开实施例的第二方面,提供了一种基于神经网络进行图像分类的方法,包括:将待分类图像输入神经网络,其中,神经网络通过上述任一所述的神经网络训练方法训练得到;采用神经网络对待分类图像进行分类。
根据本公开实施例的第三方面,提供了一种神经网络训练装置,包括:输入模块,用于将训练样本输入被训练神经网络;第一确定模块,用于通过训练样本以及被训练神经网络,确定被训练神经网络的第一损失值;第二确定模块,用于通过训练样本以及至少一个注意力网络,确定至少一个注意力网络的第二损失值;更新模块,用于根据第一损失值和第二损失值,更新被训练神经网络中的参数。
根据本公开实施例的第四方面,提供了一种基于神经网络进行图像分类的装置,包括:输入模块,用于将待分类图像输入神经网络,其中,神经网络通过上述任一所述的神经网络训练方法训练得到;分类模块,用于采用神经网络对待分类图像进行分类。
根据本公开实施例的第五方面,提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序用于执行上述任一所述的方法。
根据本公开实施例的第六方面,提供了一种电子设备,所述电子设备包括:处理器;用于存储所述处理器可执行指令的存储器;所述处理器,用于执行上述任一所述的方法。
本公开实施例提供的技术方案至少可以带来如下有益效果:
通过分别确定被训练神经网络和注意力网络的损失值,以及根据被训练神经网络和注意力网络的损失值,调整被训练神经网络的参数,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提升神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。
通过利用上述神经网络训练方法训练得到的神经网络对待分类图像进行分类,能够提高神经网络的分类准确率。
附图说明
通过结合附图对本公开实施例进行更详细的描述,本公开的上述以及其他目的、特征和优势将变得更加明显。附图用来提供对本公开实施例的进一步理解,并且构成说明书的一部分,与本公开实施例一起用于解释本公开,并不构成对本公开的限制。在附图中,相同的参考标号通常代表相同部件或步骤。
图1是本公开实施例所提供的一种实施环境的示意图。
图2是本公开一示例性实施例提供的神经网络训练方法的流程示意图。
图3是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。
图4是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。
图5是本公开一示例性实施例提供的被训练神经网络和注意力网络的结构示意图。
图6是本公开一示例性实施例提供的第一注意力网络的结构示意图。
图7是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。
图8是本公开一示例性实施例提供的基于神经网络进行图像分类的方法的流程示意图。
图9是本公开一示例性实施例提供的神经网络训练装置的框图。
图10是本申请一示例性实施例提供的神经网络训练装置的第二确定模块的框图。
图11是本申请一示例性实施例提供的神经网络训练装置的确定单元的框图。
图12是本申请一示例性实施例提供的神经网络训练装置的更新模块的框图。
图13是本申请一示例性实施例提供的基于神经网络进行图像分类的装置的框图。
图14是本公开一示例性实施例提供的电子设备的结构图。
具体实施方式
下面,将参考附图详细地描述根据本公开的示例实施例。显然,所描述的实施例仅仅是本公开的一部分实施例,而不是本公开的全部实施例,应理解,本公开不受这里描述的示例实施例的限制。
申请概述
神经网络是一种运算模型,由大量的节点(或称神经元)之间相互连接构成,每个节点代表一种特定的输出函数,称为激励函数。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重。神经网络一般包括多个神经网络层,上下网络层之间相互级联,第i个神经网络层的输出与第i+1个神经网络层的输入相连,第i+1个神经网络层的输出与第i+2个神经网络层的输入相连,以此类推。训练样本输入级联的神经网络层后,通过每个神经网络层输出一个输出结果,该输出结果作为下一个神经网络层的输入,由此,通过多个神经网络层计算获得输出。根据目标结果和输出层输出的预测结果,利用损失函数计算得到出损失值,再根据损失值来反向调整每一层的权重矩阵和激励函数,神经网络利用训练样本不断地经过上述调整过程,使得神经网络的权重等参数得到调整,直到神经网络的预测结果与目标结果相符,该过程就被称为神经网络的训练过程。神经网络经过训练后,可得到神经网络模型。
另外,注意力(Attention)模型最近两年被广泛使用在自然语言处理、图像识别及语音识别等各种不同类型的深度学习任务中,是深度学习技术中最值得关注与深入了解的核心技术之一。注意力模型借鉴了人类的视觉注意力机制,人类视觉通过快速扫描全局图像,获得需要重点关注的目标区域,也就是一般所说的注意力焦点,而后对这一区域投入更多注意力资源,以获取更多所需要关注目标的细节信息,而抑制其他无用信息。人类视觉注意力机制极大地提高了视觉信息处理的效率与准确性。深度学习中的注意力机制从本质上讲和人类的视觉注意力机制类似,核心目标也是从众多信息中选择出对当前任务目标更关键的信息。
利用注意力网络可以辅助训练神经网络,以提高神经网络的性能。但是,利用注意力网络的损失值调整注意力网络和神经网络的参数,完成训练后,注意力网络会变成神经网络的一部分,注意力网络无法从神经网络中剥离,而且会导致训练完成的神经网络模型的参数量增多,提高了神经网络模型的复杂度。
针对上述问题,本公开实施例提供了一种神经网络训练方法,该方法通过分别确定被训练神经网络和注意力网络的损失值,以及根据被训练神经网络和注意力网络的损失值,调整被训练神经网络的参数,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提升神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。
另外,本公开实施例提供了一种基于神经网络进行图像分类的方法,该方法通过利用上述神经网络训练方法训练得到的神经网络对待分类图像进行分类,能够提高神经网络的分类准确率。
示例性系统
图1是本公开实施例所提供的一种实施环境的示意图。该实施环境包括:服务器120和多个终端设备110。
终端110可以是手机、游戏主机、平板电脑、照相机、摄像机、车载电脑等移动终端设备,或者,终端110也可以是个人计算机(Personal Computer,PC),比如膝上型便携计算机和台式计算机等等。本领域技术人员可以知晓,上述终端110类型可以相同或者不同,其数量可以为一个或大于一个。本公开实施例对终端的设备类型和数量不加以限定。
服务器120是一台服务器,或者由若干台服务器组成,或者是一个虚拟化平台,或者是一个云计算服务中心。终端110与服务器120之间通过通信网络相连。可选的,通信网络是有线网络或无线网络。
在一些可选的实施例中,服务器120接收终端110采集到的训练样本,并通过训练样本对神经网络进行训练,以更新神经网络中的参数。但本公开实施例对此不加以限定,在另一些可选的实施例中,终端110采集训练样本,并且通过训练样本对神经网络进行训练,以更新神经网络中的参数。
更新后的神经网络可以应用于图像分类、语义分割、目标检测等任务,本公开实施例对此不作限定。
示例性方法
图2是本公开一示例性实施例提供的神经网络训练方法的流程示意图。本实施例可应用在电子设备上,由终端设备或服务器等执行,本公开实施例对此不作限定。如图2所示,该方法可以包括如下步骤210、步骤220和步骤230。
步骤210,将训练样本输入被训练神经网络。
被训练神经网络可以为任意类型的神经网络。可选地,被训练神经网络可以为卷积神经网络(Convolutional Neural Network,CNN)、深度神经网络(Deep NeuralNetwork,DNN)或循环神经网络(Recurrent Neural Network,RNN)等,本公开实施例对被训练神经网络的具体类型不作限定。被训练神经网络可以包括输入层、卷积层、池化层、连接层等神经网络层,本公开实施例对此不作限定。另外,本公开实施例对每一种神经网络层的个数也不作限定。
步骤220,通过训练样本以及被训练神经网络,确定被训练神经网络的第一损失值。
在一实施例中,被训练神经网络采用第一损失函数,应当理解,第一损失函数可以为任意类型的损失函数。可选地,第一损失函数可以为交叉熵损失函数,用户可以根据应用场景不同选择不同的损失函数,本公开实施例对第一损失函数的类型不作限定。
将训练样本输入被训练神经网络后,通过卷积、池化等特征提取操作可以得到被训练神经网络输出的预测结果。根据预测结果、目标结果,利用第一损失函数计算可确定被训练神经网络的第一损失值。第一损失值越小,代表预测结果越接近目标结果,预测正确的准确率越高。相反,第一损失值越大,代表预测正确的准确率越低。
步骤230,通过训练样本以及至少一个注意力网络,确定至少一个注意力网络的第二损失值。
注意力网络为基于注意力机制的神经网络。至少一个注意力网络可以为一个或多个,本公开实施例对此不作限定。
在一实施例中,注意力网络采用第二损失函数,该第二损失函数可以为与第一损失函数相同类型或不同类型的损失函数,例如,第一损失函数和第二损失函数可以均为交叉熵损失函数,或者第一损失函数为交叉熵损失函数,第二损失函数为均方差损失函数等。第二损失函数的具体类型可以根据不同应用场景进行选择,本公开对此不作限定。
具体地,通过卷积、池化等特征提取操作可以得到注意力网络输出的预测结果。根据预测结果、目标结果,利用第二损失函数计算可确定注意力网络的第二损失值。同样,第二损失值越小,代表预测结果越接近目标结果,预测正确的准确率越高。相反,第二损失值越大,代表预测正确的准确率越低。
步骤240,根据第一损失值和第二损失值,更新被训练神经网络中的参数。
在一实施例中,结合被训练神经网络的第一损失值与注意力网络的第二损失值,对被训练神经网络中的参数进行调整。例如,可以根据第一损失值和第二损失值,得到总损失值;然后将总损失值反向传播,以更新第一神经网络中的权重等参数。神经网络中的参数包括权重、偏置等,本公开实施例对参数种类不作限定。
通过本公开实施例提供的神经网络训练方法,通过分别确定被训练神经网络和注意力网络的损失值,以及根据被训练神经网络和注意力网络的损失值,调整被训练神经网络的参数,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提升神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。
图3是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。在本公开图2所示实施例的基础上延伸出本公开图3所示实施例,下面着重叙述图3所示实施例与图2所示实施例的不同之处,相同之处不再赘述。
如图3所示,在本公开实施例提供的神经网络训练方法中,上述步骤230可以包括步骤2310和步骤2320。
步骤2310,获取被训练神经网络的至少一个中间层输出的至少一个特征图以及被训练神经网络的第一特征向量,中间层与其输出的特征图相对应。
在一实施例中,例如,被训练神经网络为用作图像分类的卷积神经网络,包括卷积层、池化层、全连接层和分类层。其中,中间层为全连接层和分类层之前的卷积层和池化层,从这些中间层中输出特征图,输出的特征图可以是多维,本公开实施例对特征图的通道维数不作限定。
至少一个中间层可以是一个或多个,本公开实施例对中间层的个数不作限定。例如,可以获取每个池化层输出的特征图。应当理解,上述描述仅为示例性描述,本公开实施例对此不作限定。
在一实施例中,第一特征向量可以为全连接层和分类层的前一层输出的特征向量,该特征向量的通道维数的维度数可以为1,本公开实施例对此不作限定。
步骤2320,根据至少一个特征图和第一特征向量,确定至少一个注意力网络的第二损失值。
具体地,注意力网络的特征提取依赖于被训练神经网络,可以将被训练神经网络的中间层输出的至少一个特征图和第一特征向量作为至少一个注意力网络的输入,得到注意力网络的第二损失值。
根据本公开实施例提供的神经网络训练方法,通过将被训练神经网络的至少一个中间层输出的至少一个特征图以及被训练神经网络的第一特征向量作为至少一个注意力网络的输入,以获得至少一个注意力网络的第二损失值,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提高神经网络的训练效果,并且注意力网络可以与训练完成的神经网络剥离,使得神经网络的参数量不会增加。
图4是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。在本公开图3所示实施例的基础上延伸出本公开图4所示实施例,下面着重叙述图4所示实施例与图3所示实施例的不同之处,相同之处不再赘述。
如图4所示,在本公开实施例提供的神经网络训练方法中,上述步骤2320可以包括步骤2321、步骤2322和步骤2323。
步骤2321,将被训练神经网络的第一特征向量和至少一个特征图输入至少一个注意力网络中,得到至少一个注意力网络各自的第二特征向量,其中,特征图分别与注意力网络一一对应。
例如,如图5所示,至少一个注意力网络520包括三个注意力网络,分别为第一注意力网络521、第二注意力网络522和第三注意力网络523。第一特征向量g为被训练神经网络510的全连接层和softmax层516的前一层(即第五中间层515)输出的特征向量。至少一个特征图选择被训练神经网络510中第一中间层511、第二中间层512和第三中间层513输出的多维特征图,分别为L1、L2和L3,其中,第一中间层511、第二中间层512和第三中间层513可以为卷积层或池化层等,本发明对此不作具体限定。
特征图分别与注意力网络一一对应。具体地,将被训练神经网络510的第一中间层511输出的第一特征图L1和第一特征向量g作为第一注意力网络521的输入;将第二中间层512输出的第二特征图L2和第一特征向量g作为第二注意力网络522的输入,将第三中间层513输出的第三特征图L3和第一特征向量g作为第三注意力网络523的输入。具体地,以第一注意力网络521为例,在第一注意力网络521中,如图6所示,首先将第一特征向量g经过1×1卷积,获得与第一特征图L1通道对齐的特征向量g’,然后将特征向量g’上采样为与第一特征图L1通道数相同的上采样特征图,使得上采样特征图与第一特征图的通道维数相同;将上采样特征图与第一特征图L1进行点加操作,获得加和特征图L1’;在第一注意力网络的卷积层,该加和特征图L1’经过与卷积核的卷积,使其通道维数降为1维;通过Softmax激活函数进行归一化,计算出注意力分数M;将注意力分数M与第一特征图L1点乘后,进行平均池化操作,从而得到第一注意力网络的特征向量,即第二特征向量X。应当理解,上述描述仅为示例性描述,本公开实施例对得到注意力网络的特征向量的方式不作具体限定。
步骤2322,对至少一个注意力网络各自的第二特征向量进行拼接,得到拼接后的第三特征向量。
当至少一个注意力网络为多个注意力网络时,将多个注意力网络的多个特征向量进行拼接。例如,如图5所示,注意力网络的数量为三个,第一注意力网络输出的第二特征向量为X=[x0,x1,...,xN-2,xN-1]T,第二注意力网络输出的第二特征向量为Y=[y0,y1,...,yN-2,yN-1]T,第三注意力网络输出的第二特征向量为Z=[z0,z1,...,zN-2,zN-1]T,则三个注意力网络的拼接后的第三特征向量为N=[x0,x1,...,xN-2,xN-1,y0,y1,...,yN-2,yN-1,z0,z1,...,zN-2,zN-1]T。应当理解,上述描述仅为示例性描述,本公开对此不作限定。
步骤2323,根据第三特征向量,确定至少一个注意力网络的第二损失值。
例如,如图5所示,拼接后的第三特征向量N经过全连接层和softmax层524可以获得注意力网络输出的预测结果。根据预测结果、目标结果,利用第二损失函数计算可以得出三个注意力网络的第二损失值。另外,第一特征向量g经过全连接层和softmax层516,可以确定被训练神经网络的第一损失值。
根据本公开实施例提供的神经网络训练方法,通过将被训练神经网络的第一特征向量和至少一个特征图输入至少一个注意力网络中,得到至少一个注意力网络各自的第二特征向量;对至少一个注意力网络各自的第二特征向量进行拼接,得到拼接后的第三特征向量;根据第三特征向量,确定至少一个注意力网络的第二损失值,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提高神经网络的训练效果,并且使得注意力网络可以与训练完成的神经网络剥离,使得神经网络的参数量不会增加。
图7是本公开另一示例性实施例提供的神经网络训练方法的流程示意图。在本公开图2所示实施例的基础上延伸出本公开图7所示实施例,下面着重叙述图7所示实施例与图2所示实施例的不同之处,相同之处不再赘述。
如图7所示,在本公开实施例提供的神经网络训练方法中,上述步骤240可以包括步骤2410和步骤2420。
步骤2410,根据第一损失值和第二损失值,得到总损失值。
在一实施例中,可以根据预设权重系数,计算第一损失值和第二损失值的加权和,得到总损失值。
示例性地,被训练神经网络和注意力神经网络均将训练样本中的猫正确识别为猫,其中,被训练神经网络的第一损失值为0.3,第一预设权重系数为0.4,注意力网络的第二损失值为0.1,第二预设权重系数为0.6,则总损失值为0.3*0.4+0.1*0.6=0.18。
在另一实施例中,第一损失值和第二损失值可以直接加和得到总损失值。
示例性地,被训练神经网络和注意力神经网络均将训练样本中的猫正确识别为猫,其中,被训练神经网络的第一损失值为0.3,注意力网络的第二损失值为0.1,则总损失值为0.3+0.1=0.4。
应当理解,本公开实施例对得到总损失值的具体方式不作限定。
步骤2420,根据总损失值,更新被训练神经网络中的参数。
根据本公开实施例提供的技术方案,通过根据第一损失值和第二损失值的总损失值,更新被训练神经网络中的参数,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提升神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。
在本公开一些实施例中,步骤2420可以包括:根据总损失值,更新各注意力网络的参数和被训练神经网络的参数,得到训练后的神经网络。
具体地,可以根据总损失值反向调整各个注意力网络的参数和第二神经网络的权重、偏置等参数,得到训练后的神经网络以及各注意力网络。利用注意力网络辅助训练后的神经网络性能大大提高,同时训练后的注意力网络可以与神经网络剥离开,从而不增加神经网络模型中的参数量。
通过本公开实施例提供的方法训练后的神经网络可以应用于图像分类、语义分割、目标检测等,本公开实施例对此不作限定。
图8是本公开一示例性实施例提供的基于神经网络进行图像分类的方法的流程示意图。本实施例可应用在电子设备上,由终端设备或服务器等执行,本公开实施例对此不作限定。如图8所示,该方法可以包括如下步骤810和步骤820。
步骤810,将待分类图像输入神经网络,其中,神经网络通过上述任一所述的神经网络训练方法训练得到。
步骤820,采用神经网络对待分类图像进行分类。
根据本公开实施例提供的技术方案,通过利用上述神经网络训练方法训练得到的神经网络对待分类图像进行分类,提高了神经网络的分类准确率。
示例性装置
本公开装置实施例,可以用于执行本公开方法实施例。对于本公开装置实施例中未披露的细节,请参照本公开方法实施例。
请参考图9,其示出本公开一示例性实施例提供的神经网络训练装置的框图。该装置具有实现上述方式实施例图2中的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置900可以包括:输入模块910、第一确定模块920、第二确定模块930和更新模块940。
输入模块910,用于将训练样本输入被训练神经网络;
第一确定模块920,用于通过训练样本以及被训练神经网络,确定被训练神经网络的第一损失值;
第二确定模块930,用于通过训练样本以及至少一个注意力网络,确定至少一个注意力网络的第二损失值;
更新模块940,用于根据第一损失值和第二损失值,更新被训练神经网络中的参数。
本公开实施例提供的神经网络训练装置,通过分别确定被训练神经网络和注意力网络的损失值,以及根据被训练神经网络和注意力网络的损失值,调整被训练神经网络的参数,由于注意力网络可以为被训练神经网络的参数提供更合适的梯度,因此可以充分利用注意力网络辅助训练神经网络,提升神经网络训练效果,同时,训练完成的神经网络可以与注意力网络剥离,使得神经网络的参数量不会增加。
图10是本公开一示例性实施例提供的神经网络训练装置的第二确定模块的框图。在本公开图9所示实施例的基础上延伸出本公开图10所示实施例,下面着重叙述图10所示实施例与图9所示实施例的不同之处,相同之处不再赘述。
如图10所示,在本公开实施例提供的神经网络训练装置中,第二确定模块930可以包括:获取单元9310和确定单元9320。
获取单元9310,用于获取被训练神经网络的至少一个中间层输出的至少一个特征图以及被训练神经网络的第一特征向量,中间层与其输出的特征图相对应。
确定单元9320,用于根据至少一个特征图和第一特征向量,确定至少一个注意力网络的第二损失值。
图11是本公开一示例性实施例提供的神经网络训练装置的确定单元的框图。在本公开图10所示实施例的基础上延伸出本公开图11所示实施例,下面着重叙述图11所示实施例与图10所示实施例的不同之处,相同之处不再赘述。
如图11所示,在本公开实施例提供的神经网络训练装置中,确定单元9320包括输入子单元9321、拼接子单元9322和确定子单元9323。
输入子单元9321,用于将被训练神经网络的第一特征向量和至少一个特征图输入至少一个注意力网络中,得到至少一个注意力网络各自的第二特征向量,其中,特征图分别与注意力网络一一对应。
拼接子单元9322,用于对至少一个注意力网络各自的第二特征向量进行拼接,得到拼接后的第三特征向量。
确定子单元9323,用于根据第三特征向量,确定至少一个注意力网络的第二损失值。
图12是本公开一示例性实施例提供的神经网络训练装置的更新模块的框图。在本公开图9所示实施例的基础上延伸出本公开图12所示实施例,下面着重叙述图12所示实施例与图9所示实施例的不同之处,相同之处不再赘述。
如图12所示,在本公开实施例提供的神经网络训练装置中,更新模块940包括总损失单元9410和更新单元9420。
总损失单元9410,用于根据第一损失值和第二损失值,得到总损失值;
更新单元9420,用于根据总损失值,更新被训练神经网络中的参数。
在基于图12所示实施例提供的一些实施例中,总损失单元9410还用于根据预设权重系数,计算第一损失值和第二损失值的加权和,得到总损失值。
在基于图12所示实施例提供的一些实施例中,更新单元9420还用于根据总损失值,更新各注意力网络的参数和被训练神经网络的参数,得到训练后的神经网络。
需要说明的是,以上第一确定模块920和第二确定模块930实际上可以为同一个软件或硬件模块,也可以为不同的软件或硬件模块,本公开实施例对此不作限定。
请参考图13,其示出本公开一示例性实施例提供的基于神经网络进行图像分类的装置的框图。该装置具有实现上述方式实施例图8中的功能,所述功能可以有硬件实现,也可以由硬件执行相应的软件实现。该装置1300可以包括:输入模块1310和分类模块1320。
输入模块1310,用于将待分类图像输入神经网络,其中,神经网络通过上述任一所述的神经网络训练方法训练得到。
分类模块1320,用于采用神经网络对待分类图像进行分类。
本公开实施例提供的神经网络训练装置,通过利用上述神经网络训练方法训练得到的神经网络对待分类图像进行分类,提高了神经网络的分类准确率。
示例性电子设备
下面,参考图14来描述根据本公开实施例的电子设备。图14图示了根据本公开实施例的电子设备的框图。
如图14所示,电子设备1400包括一个或多个处理器1410和存储器1420。
处理器1410可以是中央处理单元(CPU)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元,并且可以控制电子设备1400中的其他组件以执行期望的功能。
存储器1420可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器1410可以运行所述程序指令,以实现上文所述的本公开的各个实施例的神经网络训练方法、基于神经网络进行图像分类的方法以及/或者其他期望的功能。在所述计算机可读存储介质中还可以存储诸如输入信号、信号分量、噪声分量等各种内容。
在一个示例中,电子设备1400还可以包括:输入装置1430和输出装置1440,这些组件通过总线系统和/或其他形式的连接机构(未示出)互连。
例如,该输入装置1430可以是麦克风或麦克风阵列,摄像头等。在该电子设备是单机设备时,该输入装置1430可以是通信网络连接器。
此外,该输入设备1430还可以包括例如键盘、鼠标等等。
该输出装置1440可以向外部输出各种信息,包括确定出的距离信息、方向信息等。该输出设备1440可以包括例如显示器、扬声器、打印机、以及通信网络及其所连接的远程输出设备等等。
当然,为了简化,图14中仅示出了该电子设备1400中与本公开有关的组件中的一些,省略了诸如总线、输入/输出接口等等的组件。除此之外,根据具体应用情况,电子设备1400还可以包括任何其他适当的组件。
示例性计算机程序产品和计算机可读存储介质
除了上述方法和设备以外,本公开的实施例还可以是计算机程序产品,其包括计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本公开各种实施例的神经网络训练方法、基于神经网络进行图像分类的方法中的步骤。
所述计算机程序产品可以以一种或多种程序设计语言的任意组合来编写用于执行本公开实施例操作的程序代码,所述程序设计语言包括面向对象的程序设计语言,诸如Java、C++等,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。
此外,本公开的实施例还可以是计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本公开各种实施例的神经网络训练方法、基于神经网络进行图像分类的方法中的步骤。
所述计算机可读存储介质可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以包括但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
以上结合具体实施例描述了本公开的基本原理,但是,需要指出的是,在本公开中提及的优点、优势、效果等仅是示例而非限制,不能认为这些优点、优势、效果等是本公开的各个实施例必须具备的。另外,上述公开的具体细节仅是为了示例的作用和便于理解的作用,而非限制,上述细节并不限制本公开为必须采用上述具体的细节来实现。
本公开中涉及的器件、装置、设备、系统的方框图仅作为例示性的例子并且不意图要求或暗示必须按照方框图示出的方式进行连接、布置、配置。如本领域技术人员将认识到的,可以按任意方式连接、布置、配置这些器件、装置、设备、系统。诸如“包括”、“包含”、“具有”等等的词语是开放性词汇,指“包括但不限于”,且可与其互换使用。这里所使用的词汇“或”和“和”指词汇“和/或”,且可与其互换使用,除非上下文明确指示不是如此。这里所使用的词汇“诸如”指词组“诸如但不限于”,且可与其互换使用。
还需要指出的是,在本公开的装置、设备和方法中,各部件或各步骤是可以分解和/或重新组合的。这些分解和/或重新组合应视为本公开的等效方案。
提供所公开的方面的以上描述以使本领域的任何技术人员能够做出或者使用本公开。对这些方面的各种修改对于本领域技术人员而言是非常显而易见的,并且在此定义的一般原理可以应用于其他方面而不脱离本公开的范围。因此,本公开不意图被限制到在此示出的方面,而是按照与在此公开的原理和新颖的特征一致的最宽范围。
为了例示和描述的目的已经给出了以上描述。此外,此描述不意图将本公开的实施例限制到在此公开的形式。尽管以上已经讨论了多个示例方面和实施例,但是本领域技术人员将认识到其某些变型、修改、改变、添加和子组合。

Claims (11)

1.一种神经网络训练方法,包括:
将训练样本输入被训练神经网络;
通过所述训练样本以及所述被训练神经网络,确定所述被训练神经网络的第一损失值;
通过所述训练样本以及至少一个注意力网络,确定所述至少一个注意力网络的第二损失值;
根据所述第一损失值和所述第二损失值,更新所述被训练神经网络中的参数。
2.根据权利要求1所述的方法,其中,所述确定所述至少一个注意力网络的第二损失值,包括:
获取所述被训练神经网络的至少一个中间层输出的至少一个特征图以及所述被训练神经网络的第一特征向量,中间层与其输出的特征图相对应;
根据所述至少一个特征图和所述第一特征向量,确定所述至少一个注意力网络的第二损失值。
3.根据权利要求2所述的方法,其中,所述根据所述至少一个特征图和所述第一特征向量,确定所述至少一个注意力网络的第二损失值,包括:
将所述被训练神经网络的所述第一特征向量和所述至少一个特征图输入所述至少一个注意力网络中,得到所述至少一个注意力网络各自的第二特征向量,其中,特征图分别与注意力网络一一对应;
对所述至少一个注意力网络各自的第二特征向量进行拼接,得到拼接后的第三特征向量;
根据所述第三特征向量,确定所述至少一个注意力网络的第二损失值。
4.根据权利要求1-3任一项所述的方法,其中,所述根据所述第一损失值和所述第二损失值,更新所述被训练神经网络中的参数,包括:
根据所述第一损失值和所述第二损失值,得到总损失值;
根据所述总损失值,更新所述被训练神经网络中的参数。
5.根据权利要求4所述的方法,其中,所述根据所述第一损失值和所述第二损失值,得到总损失值,包括:
根据预设权重系数,计算所述第一损失值和所述第二损失值的加权和,得到所述总损失值。
6.根据权利要求4所述的方法,其中,所述根据所述总损失值,更新所述被训练神经网络中的参数,包括:
根据所述总损失值,更新各注意力网络的参数和所述被训练神经网络的参数,得到训练后的神经网络。
7.一种基于神经网络进行图像分类的方法,包括:
将待分类图像输入神经网络,其中,所述神经网络通过上述权利要求1-6任一所述的神经网络训练方法训练得到;
采用所述神经网络对所述待分类图像进行分类。
8.一种神经网络训练装置,包括:
输入模块,用于将训练样本输入被训练神经网络;
第一确定模块,用于通过所述训练样本以及所述被训练神经网络,确定所述被训练神经网络的第一损失值;
第二确定模块,用于通过所述训练样本以及至少一个注意力网络,确定所述至少一个注意力网络的第二损失值;
更新模块,用于根据所述第一损失值和所述第二损失值,更新所述被训练神经网络中的参数。
9.一种基于神经网络进行图像分类的装置,包括:
输入模块,用于将待分类图像输入神经网络,其中,所述神经网络通过上述权利要求1-6任一所述的神经网络训练方法训练得到;
分类模块,用于采用所述神经网络对所述待分类图像进行分类。
10.一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序用于执行上述权利要求1-7任一所述的方法。
11.一种电子设备,所述电子设备包括:
处理器;
用于存储所述处理器可执行指令的存储器;
所述处理器,用于执行上述权利要求1-7任一所述的方法。
CN202010231122.7A 2020-03-27 2020-03-27 神经网络训练方法及装置、图像分类的方法及装置 Pending CN113449840A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010231122.7A CN113449840A (zh) 2020-03-27 2020-03-27 神经网络训练方法及装置、图像分类的方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010231122.7A CN113449840A (zh) 2020-03-27 2020-03-27 神经网络训练方法及装置、图像分类的方法及装置

Publications (1)

Publication Number Publication Date
CN113449840A true CN113449840A (zh) 2021-09-28

Family

ID=77808001

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010231122.7A Pending CN113449840A (zh) 2020-03-27 2020-03-27 神经网络训练方法及装置、图像分类的方法及装置

Country Status (1)

Country Link
CN (1) CN113449840A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113780478A (zh) * 2021-10-26 2021-12-10 平安科技(深圳)有限公司 活动性分类模型训练方法、分类方法、装置、设备、介质
CN115761448A (zh) * 2022-12-02 2023-03-07 美的集团(上海)有限公司 神经网络的训练方法、训练装置和可读存储介质
CN116416456A (zh) * 2023-01-13 2023-07-11 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备
CN113780478B (zh) * 2021-10-26 2024-05-28 平安科技(深圳)有限公司 活动性分类模型训练方法、分类方法、装置、设备、介质

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113780478A (zh) * 2021-10-26 2021-12-10 平安科技(深圳)有限公司 活动性分类模型训练方法、分类方法、装置、设备、介质
CN113780478B (zh) * 2021-10-26 2024-05-28 平安科技(深圳)有限公司 活动性分类模型训练方法、分类方法、装置、设备、介质
CN115761448A (zh) * 2022-12-02 2023-03-07 美的集团(上海)有限公司 神经网络的训练方法、训练装置和可读存储介质
CN115761448B (zh) * 2022-12-02 2024-03-01 美的集团(上海)有限公司 神经网络的训练方法、训练装置和可读存储介质
CN116416456A (zh) * 2023-01-13 2023-07-11 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备
CN116416456B (zh) * 2023-01-13 2023-10-24 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备

Similar Documents

Publication Publication Date Title
CN111797893B (zh) 一种神经网络的训练方法、图像分类系统及相关设备
CN109948149B (zh) 一种文本分类方法及装置
WO2022068623A1 (zh) 一种模型训练方法及相关设备
GB2546360A (en) Image captioning with weak supervision
US20230153615A1 (en) Neural network distillation method and apparatus
CN113570029A (zh) 获取神经网络模型的方法、图像处理方法及装置
CN111666416B (zh) 用于生成语义匹配模型的方法和装置
US20230117973A1 (en) Data processing method and apparatus
EP4318313A1 (en) Data processing method, training method for neural network model, and apparatus
CN111831826A (zh) 跨领域的文本分类模型的训练方法、分类方法以及装置
CN111428805A (zh) 显著性物体的检测方法、装置、存储介质及电子设备
CN112749737A (zh) 图像分类方法及装置、电子设备、存储介质
CN113449840A (zh) 神经网络训练方法及装置、图像分类的方法及装置
CN112420125A (zh) 分子属性预测方法、装置、智能设备和终端
CN115238909A (zh) 一种基于联邦学习的数据价值评估方法及其相关设备
CN113435531B (zh) 零样本图像分类方法、系统、电子设备及存储介质
CN112989843B (zh) 意图识别方法、装置、计算设备及存储介质
CN113806501B (zh) 意图识别模型的训练方法、意图识别方法和设备
CN113870863A (zh) 声纹识别方法及装置、存储介质及电子设备
CN111882048A (zh) 一种神经网络结构搜索方法及相关设备
CN113010687B (zh) 一种习题标签预测方法、装置、存储介质以及计算机设备
CN113961765B (zh) 基于神经网络模型的搜索方法、装置、设备和介质
CN113569860B (zh) 实例分割方法和实例分割网络的训练方法及其装置
CN114707070A (zh) 一种用户行为预测方法及其相关设备
CN111767710B (zh) 印尼语的情感分类方法、装置、设备及介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination