CN111814813A - 神经网络训练和图像分类方法与装置 - Google Patents

神经网络训练和图像分类方法与装置 Download PDF

Info

Publication number
CN111814813A
CN111814813A CN201910284005.4A CN201910284005A CN111814813A CN 111814813 A CN111814813 A CN 111814813A CN 201910284005 A CN201910284005 A CN 201910284005A CN 111814813 A CN111814813 A CN 111814813A
Authority
CN
China
Prior art keywords
neural network
image
training
categories
determining
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201910284005.4A
Other languages
English (en)
Inventor
王飞
钱晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Sensetime Technology Development Co Ltd
Original Assignee
Beijing Sensetime Technology Development Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Sensetime Technology Development Co Ltd filed Critical Beijing Sensetime Technology Development Co Ltd
Priority to CN201910284005.4A priority Critical patent/CN111814813A/zh
Publication of CN111814813A publication Critical patent/CN111814813A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本申请实施例公开一种神经网络训练和图像分类方法与装置,其中神经网络训练方法包括:将标注有类别信息的图像输入神经网络,经神经网络预测图像属于预定N个类别中每个类别的预测值;确定除图像标注的类别之外预测值大于设定阈值的K‑1个类别,其中N大于K且N和K分别是大于2的正整数;根据K‑1个类别的预测值和图像标注的类别信息,确定图像的类别预测损失;根据类别预测损失调整神经网络的网络参数。这样通过获取图像的K‑1个易混淆类别,使用这K‑1个易混淆类别来训练神经网络,以提高神经网络对易混淆的K‑1类别的区分能力,从而提高了该神经网络的分类准确性。

Description

神经网络训练和图像分类方法与装置
技术领域
本申请实施例涉及计算机图像处理技术领域,尤其涉及一种神经网络训练和图像分类方法与装置。
背景技术
在计算机视觉领域,深度学习已经被广泛应用在图像分类,定位,分割,识别等任务中。分类任务是一项基本而普遍的任务,包括人脸识别、物体分类、文字识别、疾病监测等等。
在分类模型的训练过程中,当训练数据集规模较大时,训练进行到后期,损失函数值比较小,对应的梯度也较小,梯度方向不稳定,此时大量训练是冗余的,其分类结果的准确性差。
发明内容
本申请实施例提供一种神经网络训练和图像分类方法与装置,以提高神经网络对图像的分类准确性。
第一方面,本申请实施例提供一种神经网络训练方法,包括:
将标注有类别信息的图像输入神经网络;
经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值;
确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,所述N大于所述K且所述N和所述K分别是大于2的正整数;
根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失;
根据所述类别预测损失调整所述神经网络的网络参数。
在第一方面的一种可能的实现方式中,确定所述图像在所述N个类别中除所述K-1个类别之外的其他类别的类别预测损失为0。
在第一方面的另一种可能的实现方式中,基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,其中,每次迭代训练过程执行如上述的方法,且不同迭代次训练输入所述神经网络中的图像不完全相同。
在第一方面的另一种可能的实现方式中,
每次迭代训练过程向所述神经网络输入多张图像,分别预测所述多张图像中每个图像的类别预测损失;
根据所述类别预测损失调整所述神经网络的网络参数,包括:确定所述多张图像的平均类别预测损失,根据所述平均类别预测损失调整所述神经网络的网络参数。
可选的,所述N和所述K分别是大于1000的正整数,和/或,训练所述神经网络的图像总数量大于1000张。
在第一方面的另一种可能的实现方式中,
所述根据所述类别预测损失调整所述神经网络的网络参数,包括:
确定所述神经网络的损失函数;
确定所述损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数;
确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数;
根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度;
根据所述更新梯度,调整所述神经网络的网络参数;
其中,所述N个类别中除所述K个类别之外的类别对应的更新梯度为0。
第二方面,本申请实施例提供一种图像分类方法,包括:
获取待分类的图像;
将所述图像输入神经网络,确定所述图像的分类结果;
其中,所述神经网络为采用上述第一方面所述的训练方法训练得到的。
第三方面,本申请实施例提供一种神经网络训练装置,包括:
输入模块,用于将标注有类别信息的图像输入神经网络;
预测模块,用于经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值;
混淆类别确定模块,用于确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,所述N大于所述K且所述N和所述K分别是大于2的正整数;
损失确定模块,用于根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失;
调整模块,用于根据所述类别预测损失调整所述神经网络的网络参数。
在第三方面的一种可能的实现方式中,所述损失确定模块,还用于确定所述图像在所述N个类别中除所述K-1个类别之外的其他类别的类别预测损失为0。
在第三方面的另一种可能的实现方式中,所述装置还包括:
训练模块,用于基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,其中,每次迭代训练过程调用上述各模块,且不同迭代次训练输入所述神经网络中的图像不完全相同。
在第三方面的另一种可能的实现方式中,所述输入模块,具体用于每次迭代训练过程向所述神经网络输入多张图像;
所述损失确定模块,具体用于分别预测所述多张图像中每个图像的类别预测损失;
所述调整模块,具体用于确定所述多张图像的平均类别预测损失,根据所述平均类别预测损失调整所述神经网络的网络参数。
可选的,所述N和所述K分别是大于1000的正整数,和/或,训练所述神经网络的图像总数量大于1000张。
在第三方面的另一种可能的实现方式中,所述调整模块包括:确定单元和调整单元;
所述确定单元,用于确定所述神经网络的损失时候,并确定所述损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数;确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数;以及根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度;
所述调整单元,用于根据所述更新梯度,调整所述神经网络的网络参数;
其中,所述N个类别中除所述K个类别之外的类别对应的更新梯度为0。
第四方面,本申请实施例提供一种图像分类装置,包括:
获取单元,用于获取待分类的图像;
确定模块,用于将所述图像输入神经网络,确定所述图像的分类结果;
其中,所述神经网络为采用上述第一方面所述的训练方法训练得到的。
第五方面,本申请实施例提供一种电子设备,包括:
存储器,用于存储计算机程序;
处理器,用于执行所述计算机程序,以实现第一方面任一项所述的神经网络训练方法或者实现第二方面任一项所述的图像分类方法。
第六方面,本申请实施例提供一种计算机存储介质,所述存储介质中存储计算机程序,所述计算机程序在执行时实现第一方面任一项所述的神经网络训练方法,或者实现第二方面所述的图像分类方法。
本申请实施例提供的神经网络训练和图像分类方法与装置,通过将标注有类别信息的图像输入神经网络,经神经网络预测图像属于预定N个类别中每个类别的预测值,接着,确定除图像标注的类别之外预测值大于设定阈值的K-1个类别,其中N大于K且N和K分别是大于2的正整数,然后,根据K-1个类别的预测值和图像标注的类别信息,确定图像的类别预测损失,最后根据类别预测损失调整神经网络的网络参数。即本申请通过获取图像的K-1个易混淆类别,使用这K-1个易混淆类别来训练神经网络,以提高神经网络对易混淆的K-1类别的区分能力,从而提高了该神经网络的分类准确性。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的神经网络训练方法的流程图;
图2为本申请实施例涉及的一种神经网络的示意图;
图3为本申请实施例涉及的神经网络的又一示意图;
图4为本申请实施例提供的神经网络训练方法的流程图;
图5为本申请实施例提供的图像分类方法的流程图;
图6为本申请实施例提供的神经网络训练装置的结构示意图;
图7为本申请实施例提供的神经网络训练装置的结构示意图;
图8为本申请实施例提供的神经网络训练装置的结构示意图;
图9为本申请实施例提供的图像分类装置的结构示意图;
图10为本申请实施例提供的电子设备的结构示意图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本申请实施例提供的技术方案,具有广泛的通用性,适用领域包括但不限于计算机视觉、智能视频分析、高级辅助驾驶系统和自动驾驶等领域,用于对神经网络进行训练,以使训练后的神经网络可以实现图像的准确分类。
下面以具体地实施例对本发明的技术方案进行详细说明。下面这几个例如实施例可以相互结合,对于相同或相似的概念或过程可能在某些实施例不再赘述。
图1为本申请实施例提供的神经网络训练方法的流程图。该如图1所示,本实施例的方法可以包括:
S101、将标注有类别信息的图像输入神经网络。
本实施例的执行主体是电子设备或者电子设备中的处理器,该电子设备可以是计算机、智能手机、AR(Augmented Reality Technique,增强现实技术)眼镜、车载系统等。
示例性的,本实施例以执行主体为上述电子设备中的处理器为例进行说明。
可选的,本实施例的电子设备还可以具有摄像头,可以拍摄待分类的图像,并将该待分类的图像发送给电子设备的处理器。
可选的,本实施例的电子设备可以与其他的摄像头连接,该摄像头可以拍摄待分类的图像,电子设备可以从该摄像头处获得待分类的图像。
本实施例的电子设备还包括存储介质,该存储介质中保存有待训练的神经网络,且处理器可以调用该神经网络。
上述神经网络原则上可以是任何神经网络模型,例如Quick-CNN、NIN、AlexNet等。
参照图2所示,图2为本申请实施例涉及的一种神经网络的示意图,需要说明的是,本申请实施例的神经网络包括但不限于如图2所示。如图2所示,该神经网络包括输入层和全连接层,可选的,该神经网络在输入层与全连接层之间还包括隐藏层,该隐藏层包括至少一个卷积层、至少一个池化层、至少一个非线性层等网络层。图2中的h、w和c分别为输入的图像的高、宽和通道数,BS为一次迭代训练输入神经网络的图像数量。
在对神经网络训练时,首先获取训练样本,该训练样本包括多张训练图像,针对每张图像,标注每张图像的类别信息,例如标准每张图像的类别真值。
接着,将标注有类型信息的图像输入神经网络中。
S102、经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值。
假设本申请实施例涉及的分类任务为N分类任务,该N为大于等于2的正整数。
为了便于阐述,在此以一张图像x为例进行说明,将图像x输入图2所示的神经网络后,该神经网络针对N个类别中的每个类别输出一个预测值,记为oi,i=0,1,2,….N-1,i表示类别的标号。该oi可以理解为该图像x属于类别i的可能性为oi
在一种示例中,上述图像的N个类别的预测值可以为神经网络的全连接层输出的N个激活值,即全连接层包括激活函数,该激活函数针对N个类别中的每个类别输出一个激活值,激活值可以表示训练图像属于某一类别的可能性。
参照上述方法,假设上述N为10000,针对图像x,可以获得如表1所示的预测值:
表1
Figure BDA0002022654790000071
由表1可知,每个类别对应一个预测值,其中某类别的预测值越大,说明该图像x属于该类别的可能性越大,例如,类别9997的预测值为o9997=6,类别0的预测值为o0=3,说明该图像x的类别为类别9997的可能性要大于为类别1的可能性。预测值也可以根据实际需要进行规一化处理,本申请并不限定。
S103、确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,N大于K且N和K分别是大于2的正整数。
继续参照表1,从表1所示的N个预测值中选择预测值大于设定阈值的K个分类结果。示例性的,可以将上述N个预测值从大到小排序或者从小到大排序,从中获取预测值最大的K个预测值,这K个预测值中包括图像x标注的类别的真值。
获取上述K个预测值对应的类别,得到K个类别,并从这K个类别中剔除图像x标注的类别,即图像x的类别真值l,获得K-1个类别,这K-1个类别对于图像x来说为易混淆类别。即图像x对于这K-1个类别的预测值与图像x的类别真值之间差别较小,例如,图像x的类别真值为1,但是预测出的图像x属于这K-1个类别的预测值位于1至0.5之间,这样,无法准确区分图像x到底属于哪个类别,进而使得这K-1个类别对于图像x来说为易混淆类别。将这K-1个类别记为图像x的易混淆类别子集H。
举例说明,在进行动物类别识别时,假设输入的图像为动物小狗的图像,该小狗的真实类别为哈士奇,根据上述步骤,将动物小狗的图像输入神经网络,该神经网络预测出动物小狗的图像属于预定N个类别中每个类别的预测值,这N个类别包括多种动物的类别和不同种狗的类别。从这N个类别中选择出除类别为哈士奇之外的其他预测值大于设定阈值的K-1个类别,例如这K-1个类别包括:猫、老虎、狮子、秋田犬、柴犬、斗牛犬等。这样,在对动物小狗的图像进行分类时,会将小狗识别成猫、老虎、狮子、秋田犬、柴犬、斗牛犬等,这些猫、老虎、狮子、秋田犬、柴犬、斗牛犬等K-1个类别对于小狗图像来说为易混淆类别。
再例如,识别人脸表情时,假设输入的人脸图像的人脸表情为微笑,根据上述步骤,将人脸图像输入神经网络,该神经网络预测出人脸图像属于预定N个类别中每个类别的预测值,这N个类别包括人脸的各种表情。从这N个类别中选择出除类别为微笑之外的其他预测值大于设定阈值的K-1个类别,例如这K-1个类别包括:大笑、惊讶、恐惧、愤怒等。这样,在对上述微笑的人脸图像进行分类时,会将微笑识别成大笑、惊讶、恐惧、愤怒等,这些大笑、惊讶、恐惧、愤怒等K-1个类别对于表情为微笑的人脸图像来说为易混淆类别。
需要说明的是,本申请实施例对K的具体取值范围不做限制,只要保证K为大于2的正整数即可,可选的,在训练过程中,随着迭代次数的增大,K值可以逐渐减小,以提高神经网络的训练速度。
S104、根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失。
本申请实施例对确定图像的类别预测损失的具体形式不做限制,例如,可以是能量损失、交叉熵损失等方式来确定图像的类别预测损失。
在一种示例中,确定神经网络的损失函数,根据该损失函数,确定图像的类别预测损失。
可选的,该神经网络的损失函数如下式(1)所示,需要说明的是式(1)只是本申请实施例涉及的损失函数的一种表达形式,本申请实施例的损失函数还可以是式(1)的任一种变形,或者为其他形式的损失函数。
Figure BDA0002022654790000081
上述公式(1)中的Lh表示图像的类别预测损失,ol为图像标注的类别对应的真值,oi为类别子集H中第i个类别对应的预测值,i∈H。
S105、根据所述类别预测损失调整所述神经网络的网络参数。
本步骤对根据图像的类别预测损失调整神经网络的网络参数的方式不做限制,例如,基于牛顿算法、共轭梯度法、准牛顿法、衰减的最小平方法等方法,来根据图像的类别预测损失调整神经网络的网络参数。
可选的,为了提高神经网络的训练速度,则确定所述图像在所述N个类别中除所述K-1个类别之外的其他类别的类别预测损失为0。也就是说,其他类别反向回传的梯度为0,相当于在本次迭代训练过程中不考虑这些类别,而重点考虑K-1类别,由此提高神经网络对这K-1个易混淆类别的区分能力。
由上述各步骤可知,本申请实施例使用图像的易混淆类别子集来训练本神经网络,使得训练后的神经网络对易混淆的类别可以进行准确区分,进而在图像分类时,实现对图像的准确分类。例如,将哈士奇的图像输入神经网络,根据上述S103的方法,得到哈士奇的图像的K-1个易混淆类别,使用这K-1个易混淆类别来训练神经网络,这样在下次输入哈士奇的图像时,神经网络可以准确识别出该图像中动物的类别为哈士奇。再例如,将表情为微笑的人脸图像输入神经网络,根据上述S103的方法,得到人脸图像的K-1个易混淆类别,使用这K-1个易混淆类别来训练神经网络,这样在下次输入表情为微笑的人脸图像时,神经网络可以准确识别出该图像中人脸的表情类别为微笑。
上述S101至S105为神经网络的一次迭代训练过程,为了提高训练的准确性,需要对神经网络进行多次迭代训练,不同次迭代输入的图像对应的易混淆类别子集H可以不同。也就是说,不同的图像对应的K-1个易混淆类别可能不同,例如,图像1对应的易混淆类别为N个类别中的类别1、类别2和类别3,而图像2对应的易混淆类别为N个类别中的类别3、类别10。
本申请实施例的训练过程,是在一次训练结束后,进行下一次训练,每一次训练是在线确定本次输入的图像的K-1个易混淆类别,这样随着训练的进行可以动态进行图像的K-1个易混淆类别的挖掘,进而提高了神经网络的训练效果。
本申请实施例,由于每次训练过程选择的图像的类别可能不同,因此使得训练后的神经网络可以实现多类别的检测,例如可以对驾驶员状态检测,以识别驾驶员是在抽烟、喝水、打电话、吃东西、打哈欠、张嘴说话等。
在训练完成后,该神经网络可以实现端到端的输出,例如,将待分类的图像输入至该神经网络,该神经网络可以准确输出待分类的图像的分类结果。
本申请实施例提供的神经网络信训练方法,通过将标注有类别信息的图像输入神经网络,经神经网络预测图像属于预定N个类别中每个类别的预测值,接着,确定除图像标注的类别之外预测值大于设定阈值的K-1个类别,其中N大于K且N和K分别是大于2的正整数,然后,根据K-1个类别的预测值和图像标注的类别信息,确定图像的类别预测损失,最后根据类别预测损失调整神经网络的网络参数。即本申请通过获取图像的K-1个易混淆类别,使用这K-1个易混淆类别来训练神经网络,以提高神经网络对易混淆的K-1类别的区分能力,从而提高了该神经网络的分类准确性。
在一种示例中,在上述图1和图2的基础上,参照图3所示,图3为本申请实施例涉及的神经网络的又一示意图。如图3所示,本申请实施例的神经网络除了包括全连接层外,还包括与该全连接层连接的易混淆类别选择层、以及与易混淆类别选择层连接的softmaxloss层。
其中,易混淆类别选择层用于从全连接层输出的图像的N个类别的预测值中获取图像的K-1个易混淆类别和该图像标注的类别。
softmax loss层用于根据图像的K-1个易混淆类别的预测值和该图像标注的类别信息,确定该图像的类别预测损失。
图3所示的神经网络的训练过程可以包括S100:
S100、基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,其中,每次迭代训练过程执行上述S101至S105的步骤,且不同迭代次训练输入所述神经网络中的图像不完全相同。
神经网络的训练过程包括:前向传播和反向传播,其中S101至S104为前向传播的过程,S105为反向传播的过程。
为了提高训练的准确性,本申请实施例使用图像集来训练神经网络,该图像集中包括多张图像,其中每张图像均标注有类别信息。
继续参照图3所示,在一次迭代训练中,从图像集中任意选择一张或多张图像,将选择的图像输入神经网络的输入层,神经网络中的隐藏层对输入的每张图像进行处理,这些隐藏层在图3中未示出。隐藏层将每张图像的处理结果输入全连接层,全连接层输出每张图像的N个分类的预测值,其具体过程可以参照上述S102,在此不再赘述。接着,全连接层将这N个预测值输入易混淆类别选择层。易混淆类别选择层从全连接层输出的N个预测值中选择除图像标注的类别之外预测值大于设定阈值的K-1个预测值,并获得该K-1个预测值中每个预测值对应的类别,进而获得该图像的K-1个易混淆类别。
接着,softmax loss层根据上述获得的每张图像K-1个类别(即K-1个易混淆类别)的预测值和每张图像标注的类别信息,确定每张图像的类别预测损失。
然后,softmax loss层将每张图像的类别预测损失反向输入神经网络,以根据每张图像的类别预测损失调整神经网络的网络参数。
其中,根据每张图像的类别预测损失调整神经网络的网络参数的方式至少包括如下两种:
方式一,根据每张图像的类别预测损失逐一调整神经网络的网络参数,例如,一次训练迭代训练中输入100张图像,首先使用100张图像中的第一张图像的类别预测损失来调整神经网络的网络参数,获得第一次调整后的神经网络。接着,使用100张图像中的第二张图像的类别预测损失来调整第一次调整后的神经网络的网络参数,获得第二次整后的神经网络。这样,使用100张图像中每张图像的类别预测损失来逐一调整神经网络的网络参数,实现对神经网络的训练。
方式二,确定所述多张图像的平均类别预测损失,根据所述平均类别预测损失调整所述神经网络的网络参数。
具体的,假设一次迭代训练中输入100张图像,根据上述步骤,获得100张图像中每一张图像的类别预测损失。为了提高训练速度,可以确定100张图像的平均类别预测损失,根据该平均类别预测损失调整神经网络的网络参数,该调整过程为一次调整,可以减少一次迭代训练中对神经网络参数的调整次数,进而提高了神经网络的训练速度。
经过上述步骤,完成一次迭代训练完成后,从图像集中选择另一张或一组图像执行上述步骤,以进行下一次迭代训练,直至满足训练停止条件为止。
可选的,不同迭代次训练输入神经网络中的图像不完全相同。
可选的,上述训练停止条件可以是预设迭代次数,即神经网络的迭代循环次数达到预设的循环次数时,参数更新过程停止,将当前更新后的参数作为神经网络的新参数。
可选的,上述训练停止条件还可以是softmax loss层的输出结果的损失满足预设损失,例如,神经网络使用更新后的参数进行前向传播,softmax loss层根据图像的损失函数,确定图像的K-1易混淆类别中每个类别对应的损失值,将该损失值与预设损失值进行比较。若该损失值大于预设损失值,则说明此时的神经网络没有训练完成,继续对神经网络的参数进行更新,直到神经网络使用更新后的参数进行前向传播时,使得softmax loss层输出的损失值小于或等于预设损失值为止,此时神经网络训练完成。
这样,经过多次迭代训练的神经网络,可以准确预测出图像的类别。
由上述可知,本申请实施例中图像的K-1个易混淆类别的确定过程是在线确定的,这样随着训练的进行可以动态进行图像的K-1个易混淆类别的挖掘,进而提高了神经网络的训练效果。
可选的,针对图像集中的每一张图像,其对应的易混淆类别子集的大小可能相同,也可以不同,即每张图像对应的K-1个类别的数量可以相同,也可以不同。每张图像对应的K-1个易混淆类别可以不同,例如图像1对应的易混淆类别为N个类别中的类别1、类别2和类别3,而图像2对应的易混淆类别为N个类别中的类别3、类别10。
可选的,随着训练的进行,不同迭代所自动选择的易混淆类别子集也可能不同。
可选的,上述N和上述K分别是大于1000的正整数,针对多分类问题,可以使神经网络的训练效果更为显著。
可选的,上述训练神经网络的图像总数量可以大于1000张,可以保证神经网络训练的准确性。
本申请实施例的方法,基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,进而实现对神经网络的有效训练,使得训练后的神经网络可以准确区分图像易混淆的类别,实现对图像的准确分类。
图4为本申请实施例提供的神经网络训练方法的流程图,在上述各实施例的基础上,本申请实施例涉及的是根据所述类别预测损失调整所述神经网络的网络参数的一种可能的实现方式。参照图4所示,上述S105可以包括:
S201、确定神经网络的损失函数。
其中,softmax交叉熵损失函数,范畴交叉熵(Categorical Crossentropy),或者,二元交叉熵(Binary Crossentropy)等。
可选的,神经网络的损失函数可以为上述(1)所示的损失函数。
S202、确定所述损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数。
S203、确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数。
S204、根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度。
S205、根据所述更新梯度,调整所述神经网络的网络参数。
示例性的,以神经网络的损失函数为上式(1)为例进行说明:
可选的,该损失函数携带在图3所示的softmax loss层中。
确定损失函数关于图像的K-1个易混淆类别中每个类别的预测值的第一偏导数,可以获得如下式(2)所示:
Figure BDA0002022654790000131
确定所述损失函数关于图像标注的类别的预测值(即图像标注的类别的真值)的第二偏导数,可以获得如下式(3)所示:
Figure BDA0002022654790000132
结合上式(1)、上述(2)和上述(3),可以获得如下式(4)和式(5),
Figure BDA0002022654790000133
Figure BDA0002022654790000134
根据上述确定的第一偏导数和第二偏导数,可以确定出神经网络的更新梯度,例如,对上述第一偏导数和第二偏导数进行数学运算,将运算结果作为神经网络的更新梯度。
在一种示例中,可以直接将第一偏导数和第二偏导数,作为神经网络的更新梯度。
根据上述步骤,可以获得神经网络的更新梯度,这样基于神经网络的更新梯度反传输入神经网络,基于链式法则,实现对神经网络的参数进行更新。
继续参照图3所示,由上述公式(2)和公式(3)可知,softmax loss层的更新梯度包括K个值,即第一偏导数和第二偏导数共同对应K个类别的梯度,该K个类别的梯度分别对应图像易混淆的K-1个类别中类别和图像标记的类别。而全连接层包括N个类别,因此,在反向传播,将softmax loss层的更新梯度赋值给全连接层的N个类别时,将N个类别中除上述K个类别之外的类别对应的更新梯度设定为0。这样,全连接层使用上述更新梯度进行反向传播,且基于链式法,可以实现对神经网络中每一层参数的更新,进而完成对神经网络的准确训练。
本申请实施例的方法,通过确定损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数。确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数。根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度。根据所述更新梯度,调整所述神经网络的网络参数。进而实现对神经网络模型的准确训练,使得训练后的神经网络具有区分易混淆类别的能力,进而使用该训练后的神经网络进行图像分类时,可以实现对图像的准确分类。
图5为本申请实施例提供的图像分类方法的流程图,如图5所示,包括:
S301、获取待分类的图像。
S302、将所述待分类的图像输入神经网络,确定所述待分类图像的分类结果。
其中,所述神经网络为采用上述图1或图3所述的神经网络训练方法训练得到的。
本实施例的执行主体是电子设备或者电子设备中的处理器,该电子设备可以是计算机、智能手机、AR(Augmented Reality Technique,增强现实技术)眼镜、车载系统等。
可选的,本实施例的电子设备还可以具有摄像头,可以拍摄待分类的图像,并将该待分类的图像发送给电子设备的处理器。
可选的,本实施例的电子设备可以与其他的摄像头连接,该摄像头可以拍摄待分类的图像,电子设备可以从该摄像头处获得待分类的图像。
本实施例的电子设备还包括存储介质,该存储介质中保存有训练后的神经网络,且处理器可以调用该神经网络。其中,神经网络的训练过程可以参照上述实施例所述的神经网络训练方法的描述,在此不再赘述。
本申请实施例,通过获取待分类的图像,并将待分类的图像输入神经网络,可以得到图像的准确分别结果,这样因为本申请采用图像易混淆的K-1个类别来训练神经网络,进而提高神经网络对易混淆的K-1类别的区分能力,从而提高了该神经网络的分类准确性。
图6为本申请实施例提供的神经网络训练装置的结构示意图。如图6所示,本实施例的神经网络训练装置100可以包括:
输入模块110,用于将标注有类别信息的图像输入神经网络;
预测模块120,用于经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值;
混淆类别确定模块130,用于确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,所述N大于所述K且所述N和所述K分别是大于2的正整数;
损失确定模块140,用于根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失;
调整模块150,用于根据所述类别预测损失调整所述神经网络的网络参数。
本申请实施例的神经网络训练装置,可以用于执行上述所示方法实施例的技术方案,其实现原理和技术效果类似,此处不再赘述。
在一种可能的实现方式下,所述损失确定模块140,还用于确定所述图像在所述N个类别中除所述K-1个类别之外的其他类别的类别预测损失为0。
图7为本申请实施例提供的神经网络训练装置的结构示意图,所述神经网络训练装置100还包括:训练模块160,
上述训练模块160,用于基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,其中,每次迭代训练过程调用图6的各模块,且不同迭代次训练输入所述神经网络中的图像不完全相同。
在一种可能的实现方式下,所述输入模块110,具体用于每次迭代训练过程向所述神经网络输入多张图像。
所述损失确定模块140,具体用于分别预测所述多张图像中每个图像的类别预测损失。
所述调整模块150,具体用于确定所述多张图像的平均类别预测损失,根据所述平均类别预测损失调整所述神经网络的网络参数。
可选的,所述N和所述K分别是大于1000的正整数,和/或,训练所述神经网络的图像总数量大于1000张。
本申请实施例的神经网络训练装置,可以用于执行上述所示方法实施例的技术方案,其实现原理和技术效果类似,此处不再赘述。
图8为本申请实施例提供的神经网络训练装置的结构示意图,所述调整模块150包括:确定单元151和调整单元152,
所述确定单元151,用于确定所述神经网络的损失函数,并确定所述损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数;确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数;以及根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度;
所述调整单元152,用于根据所述更新梯度,调整所述神经网络的网络参数;
其中,所述N个类别中除所述K个类别之外的类别对应的更新梯度为0。
本申请实施例的神经网络训练装置,可以用于执行上述所示方法实施例的技术方案,其实现原理和技术效果类似,此处不再赘述。
图9为本申请实施例提供的图像分类装置的结构示意图,所述图像分类装置200包括:
获取模块210,用于获取待分类的图像。
确定模块220,用于将所述图像输入神经网络,确定所述图像的分类结果。
其中,所述神经网络为采用上述图1或图3所述的神经网络训练方法训练得到的。
本申请实施例的图像分类装置,可以用于执行上述所示图像分类方法实施例的技术方案,其实现原理和技术效果类似,此处不再赘述。
图10为本申请实施例提供的电子设备的结构示意图,如图10所示,本实施例的电子设备30包括:
存储器310,用于存储计算机程序;
处理器320,用于执行所述计算机程序,以实现上述的神经网络训练方法或者图像分类方法,其实现原理和技术效果类似,此处不再赘述。
进一步的,当本申请实施例中神经网络训练方法和/或图像分类方法的至少一部分功能通过软件实现时,本申请实施例还提供一种计算机存储介质,计算机存储介质用于储存为上述对神经网络训练和/或图像分类的计算机软件指令,当其在计算机上运行时,使得计算机可以执行上述方法实施例中各种可能的神经网络训练方法和/或图像分类方法。在计算机上加载和执行所述计算机执行指令时,可全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机指令可以存储在计算机存储介质中,或者从一个计算机存储介质向另一个计算机存储介质传输,所述传输可以通过无线(例如蜂窝通信、红外、短距离无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如SSD)等。
最后应说明的是:以上各实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述各实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的范围。

Claims (10)

1.一种神经网络训练方法,其特征在于,包括:
将标注有类别信息的图像输入神经网络;
经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值;
确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,所述N大于所述K且所述N和所述K分别是大于2的正整数;
根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失;
根据所述类别预测损失调整所述神经网络的网络参数。
2.根据权利要求1所述的方法,其特征在于,还包括:
确定所述图像在所述N个类别中除所述K-1个类别之外的其他类别的类别预测损失为0。
3.根据权利要求1所述的方法,其特征在于,基于分别标注有类别信息的图像集多次迭代训练神经网络,一次迭代训练完成后进行下一次迭代训练,直至满足训练停止条件,其中,每次迭代训练过程执行如权利要求1所述的方法,且不同迭代次训练输入所述神经网络中的图像不完全相同。
4.根据权利要求1所述的方法,其特征在于,每次迭代训练过程向所述神经网络输入多张图像,分别预测所述多张图像中每个图像的类别预测损失;
根据所述类别预测损失调整所述神经网络的网络参数,包括:确定所述多张图像的平均类别预测损失,根据所述平均类别预测损失调整所述神经网络的网络参数。
5.根据权利要求1-4任一项所述的方法,其特征在于,所述根据所述类别预测损失调整所述神经网络的网络参数,包括:
确定所述神经网络的损失函数;
确定所述损失函数关于所述K-1个类别中每个类别的预测值的第一偏导数;
确定所述损失函数关于所述图像标注的类别的预测值的第二偏导数;
根据所述第一偏导数和所述第二偏导数,确定所述神经网络的更新梯度;
根据所述更新梯度,调整所述神经网络的网络参数;
其中,所述N个类别中除所述K个类别之外的类别对应的更新梯度为0。
6.一种图像分类方法,其特征在于,包括:
获取待分类的图像;
将所述图像输入神经网络,确定所述图像的分类结果;
其中,所述神经网络为采用上述权利要求1-5任一项所述的训练方法训练得到的。
7.一种神经网络训练装置,其特征在于,包括:
输入模块,用于将标注有类别信息的图像输入神经网络;
预测模块,用于经所述神经网络预测所述图像属于预定N个类别中每个类别的预测值;
混淆类别确定模块,用于确定除所述图像标注的类别之外预测值大于设定阈值的K-1个类别,所述N大于所述K且所述N和所述K分别是大于2的正整数;
损失确定模块,用于根据所述K-1个类别的预测值和所述图像标注的类别信息,确定所述图像的类别预测损失;
调整模块,用于根据所述类别预测损失调整所述神经网络的网络参数。
8.一种图像分类装置,其特征在于,包括:
获取单元,用于获取待分类的图像;
确定模块,用于将所述图像输入神经网络,确定所述图像的分类结果;
其中,所述神经网络为采用上述权利要求1-5任一项所述的训练方法训练得到的。
9.一种电子设备,其特征在于,包括:
存储器,用于存储计算机程序;
处理器,用于执行所述计算机程序,以实现如权利要求1-5任一项所述的神经网络训练方法,或者实现如权利要求6所述的图像分类方法。
10.一种计算机存储介质,其特征在于,所述存储介质中存储计算机程序,所述计算机程序在执行时实现如权利要求1-5任一项所述的神经网络训练方法,或者实现如权利要求6所述的图像分类方法。
CN201910284005.4A 2019-04-10 2019-04-10 神经网络训练和图像分类方法与装置 Pending CN111814813A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910284005.4A CN111814813A (zh) 2019-04-10 2019-04-10 神经网络训练和图像分类方法与装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910284005.4A CN111814813A (zh) 2019-04-10 2019-04-10 神经网络训练和图像分类方法与装置

Publications (1)

Publication Number Publication Date
CN111814813A true CN111814813A (zh) 2020-10-23

Family

ID=72844499

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910284005.4A Pending CN111814813A (zh) 2019-04-10 2019-04-10 神经网络训练和图像分类方法与装置

Country Status (1)

Country Link
CN (1) CN111814813A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112070777A (zh) * 2020-11-10 2020-12-11 中南大学湘雅医院 一种基于增量学习的多场景下的危及器官分割方法及设备
CN112613577A (zh) * 2020-12-31 2021-04-06 上海商汤智能科技有限公司 神经网络的训练方法、装置、计算机设备及存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2017096758A1 (zh) * 2015-12-11 2017-06-15 腾讯科技(深圳)有限公司 图像分类方法、电子设备和存储介质
CN108171275A (zh) * 2018-01-17 2018-06-15 百度在线网络技术(北京)有限公司 用于识别花卉的方法和装置
CN108875821A (zh) * 2018-06-08 2018-11-23 Oppo广东移动通信有限公司 分类模型的训练方法和装置、移动终端、可读存储介质
CN108875779A (zh) * 2018-05-07 2018-11-23 深圳市恒扬数据股份有限公司 神经网络的训练方法、装置及终端设备
CN108875934A (zh) * 2018-05-28 2018-11-23 北京旷视科技有限公司 一种神经网络的训练方法、装置、系统及存储介质
CN109299716A (zh) * 2018-08-07 2019-02-01 北京市商汤科技开发有限公司 神经网络的训练方法、图像分割方法、装置、设备及介质
CN109508655A (zh) * 2018-10-28 2019-03-22 北京化工大学 基于孪生网络的不完备训练集的sar目标识别方法

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2017096758A1 (zh) * 2015-12-11 2017-06-15 腾讯科技(深圳)有限公司 图像分类方法、电子设备和存储介质
CN108171275A (zh) * 2018-01-17 2018-06-15 百度在线网络技术(北京)有限公司 用于识别花卉的方法和装置
CN108875779A (zh) * 2018-05-07 2018-11-23 深圳市恒扬数据股份有限公司 神经网络的训练方法、装置及终端设备
CN108875934A (zh) * 2018-05-28 2018-11-23 北京旷视科技有限公司 一种神经网络的训练方法、装置、系统及存储介质
CN108875821A (zh) * 2018-06-08 2018-11-23 Oppo广东移动通信有限公司 分类模型的训练方法和装置、移动终端、可读存储介质
CN109299716A (zh) * 2018-08-07 2019-02-01 北京市商汤科技开发有限公司 神经网络的训练方法、图像分割方法、装置、设备及介质
CN109508655A (zh) * 2018-10-28 2019-03-22 北京化工大学 基于孪生网络的不完备训练集的sar目标识别方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
MEHWISH MALIK ET AL.: "Exploiting Class Hierarchies for Large-Scale Scene Classification Using Hybrid Discriminative Approach", 2018 IEEE 18TH INTERNATIONAL CONFERENCE ON COMMUNICATION TECHNOLOGY (ICCT), 3 January 2019 (2019-01-03), pages 1217 - 1221 *
陈慧岩: "智能车辆理论与应用", 北京理工大学出版社, pages: 68 - 69 *
黎奉薪: "基于深层卷积神经网络的物体识别研究", 中国优秀硕士学位论文全文数据库 信息科技辑, vol. 2018, no. 02, 15 February 2018 (2018-02-15), pages 138 - 1809 *

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112070777A (zh) * 2020-11-10 2020-12-11 中南大学湘雅医院 一种基于增量学习的多场景下的危及器官分割方法及设备
CN112613577A (zh) * 2020-12-31 2021-04-06 上海商汤智能科技有限公司 神经网络的训练方法、装置、计算机设备及存储介质
CN112613577B (zh) * 2020-12-31 2024-06-11 上海商汤智能科技有限公司 神经网络的训练方法、装置、计算机设备及存储介质

Similar Documents

Publication Publication Date Title
US10223614B1 (en) Learning method, learning device for detecting lane through classification of lane candidate pixels and testing method, testing device using the same
US10726313B2 (en) Active learning method for temporal action localization in untrimmed videos
US10504027B1 (en) CNN-based learning method, learning device for selecting useful training data and test method, test device using the same
US10474713B1 (en) Learning method and learning device using multiple labeled databases with different label sets and testing method and testing device using the same
KR102582194B1 (ko) 선택적 역전파
WO2018119684A1 (zh) 一种图像识别系统及图像识别方法
CN111079780B (zh) 空间图卷积网络的训练方法、电子设备及存储介质
US10262214B1 (en) Learning method, learning device for detecting lane by using CNN and testing method, testing device using the same
CN110781960B (zh) 视频分类模型的训练方法、分类方法、装置及设备
JP7448562B2 (ja) 人工知能のための希な訓練データへの対処
CN113704522B (zh) 基于人工智能的目标图像快速检索方法及系统
KR20200027887A (ko) 복수의 비디오 프레임을 이용하여 cnn의 파라미터를 최적화하기 위한 학습 방법 및 학습 장치 그리고 이를 이용한 테스트 방법 및 테스트 장치
JP2023042582A (ja) サンプル分析の方法、電子装置、記憶媒体、及びプログラム製品
CN112926461B (zh) 神经网络训练、行驶控制方法及装置
CN112149754B (zh) 一种信息的分类方法、装置、设备及存储介质
CN112613617A (zh) 基于回归模型的不确定性估计方法和装置
CN111814813A (zh) 神经网络训练和图像分类方法与装置
CN114329022A (zh) 一种色情分类模型的训练、图像检测方法及相关装置
CN113140012A (zh) 图像处理方法、装置、介质及电子设备
CN117095460A (zh) 基于长短时关系预测编码的自监督群体行为识别方法及其识别系统
CN111652320A (zh) 一种样本分类方法、装置、电子设备及存储介质
CN114241411B (zh) 基于目标检测的计数模型处理方法、装置及计算机设备
CN117523218A (zh) 标签生成、图像分类模型的训练、图像分类方法及装置
CN110705695B (zh) 搜索模型结构的方法、装置、设备和存储介质
CN112989869B (zh) 人脸质量检测模型的优化方法、装置、设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination