CN112085041A - 神经网络的训练方法、训练装置和电子设备 - Google Patents
神经网络的训练方法、训练装置和电子设备 Download PDFInfo
- Publication number
- CN112085041A CN112085041A CN201910507780.1A CN201910507780A CN112085041A CN 112085041 A CN112085041 A CN 112085041A CN 201910507780 A CN201910507780 A CN 201910507780A CN 112085041 A CN112085041 A CN 112085041A
- Authority
- CN
- China
- Prior art keywords
- neural network
- samples
- matrix
- optimal transmission
- distance matrix
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 128
- 238000012549 training Methods 0.000 title claims abstract description 86
- 238000000034 method Methods 0.000 title claims abstract description 50
- 239000011159 matrix material Substances 0.000 claims abstract description 181
- 230000005540 biological transmission Effects 0.000 claims abstract description 116
- 230000006870 function Effects 0.000 claims description 74
- 238000004590 computer program Methods 0.000 claims description 11
- 238000005065 mining Methods 0.000 abstract 1
- 238000010586 diagram Methods 0.000 description 14
- 230000008569 process Effects 0.000 description 7
- 238000013527 convolutional neural network Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 238000005259 measurement Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 239000013598 vector Substances 0.000 description 2
- 230000016776 visual perception Effects 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 238000007792 addition Methods 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 230000006798 recombination Effects 0.000 description 1
- 238000005215 recombination Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000003252 repetitive effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
公开了一种神经网络的训练方法、神经网络的训练装置和电子设备。该神经网络的训练方法包括:通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;计算出与所述样本距离矩阵对应的最优传输规划矩阵;基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及,基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。这样,通过对一批样本当中的难样本进行挖掘,提高了网络的训练的收敛速率和性能。
Description
技术领域
本申请涉及深度学习领域,且更为具体地,涉及一种神经网络的训练方法,神经网络的训练装置和电子设备。
背景技术
在深度学习领域中,通过学习数据的语义嵌入度量,缩小数据类内差异(或距离),使相似的同类样本聚集在一起,以及扩大数据类间差异(或距离),使不相似的异类样本分开是物体识别任务的重要基础。
随着深度学习技术的迅速发展,深度度量学习近年来越来越受到重视。在深度度量学习中,通过端到端地训练深度神经网络,可以学习到复杂的高度非线性的数据深度特征表示(从输入空间到低维语义嵌入度量空间)。
深度度量学习到的深度特征表示和语义嵌入度量在视觉识别中有广泛的应用场景和优异的识别性能,例如,二维(2D)自然图像检索/分类、人脸识别、三维(3D)物体检索/分类、多源异构视觉感知数据跨模态检索(例如2D图像/视频、3D物体、文本数据之间的检索匹配)等。
因此,期望提供改进的神经网络的训练方案。
发明内容
为了解决上述技术问题,提出了本申请。本申请的实施例提供了一种神经网络的训练方法、神经网络的训练装置和电子设备,其使用区分同类样本和异类样本的样本距离矩阵及其对应的最优传输规划矩阵来构造最优传输损失函数值并以此训练神经网络,从而学习样本重要性驱动的距离度量,提高了网络训练的收敛速率。
根据本申请的一方面,提供了一种神经网络的训练方法,包括:通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;计算出与所述样本距离矩阵对应的最优传输规划矩阵;基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及,基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
根据本申请的另一方面,提供了一种神经网络的训练装置,包括:距离矩阵获得单元,用于通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;传输矩阵获得单元,用于计算出与所述距离矩阵获得单元所获得的所述样本距离矩阵对应的最优传输规划矩阵;损失函数确定单元,用于基于所述距离矩阵获得单元所获得的所述样本距离矩阵与所述传输矩阵获得单元所获得的所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及,参数更新单元,用于基于所述损失函数确定单元所确定的所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
根据本申请的再一方面,提供了一种电子设备,包括:处理器;以及,存储器,在所述存储器中存储有计算机程序指令,所述计算机程序指令在被所述处理器运行时使得所述处理器执行如上所述的神经网络的训练方法。
根据本申请的又一方面,提供了一种计算机可读介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行如上所述的神经网络的训练方法。
本申请提供的神经网络的训练方法、神经网络的训练装置和电子设备使用包括一批样本中的同类样本距离和异类样本距离的样本距离矩阵来通过最优传输方法计算出最优传输规划矩阵,使得能够按照样本的重要性赋予难样本更高的权重,这样,通过基于最优传输规划矩阵构造最优传输损失函数并以此训练神经网络,可以使得神经网络能够学习到样本重要性驱动的距离度量,从而提高网络训练的收敛速率。
附图说明
通过结合附图对本申请实施例进行更详细的描述,本申请的上述以及其他目的、特征和优势将变得更加明显。附图用来提供对本申请实施例的进一步理解,并且构成说明书的一部分,与本申请实施例一起用于解释本申请,并不构成对本申请的限制。在附图中,相同的参考标号通常代表相同部件或步骤。
图1A图示了根据本申请实施例的通过最优传输损失扩展样本间的语义信息的示意图。
图1B图示了现有的成对样本情况下的距离度量学习的示意图。
图1C图示了根据本申请实施例的批量样本情况下的距离度量学习的示意图。
图2图示了根据本申请实施例的神经网络的训练方法的流程图。
图3图示了根据本申请实施例的神经网络的训练方法中获得样本距离矩阵的示例的流程图。
图4图示了根据本申请实施例的神经网络的训练装置的框图。
图5图示了根据本申请实施例的神经网络的训练装置的距离矩阵获得单元的示例的框图。
图6图示了根据本申请实施例的电子设备的框图。
具体实施方式
下面,将参考附图详细地描述根据本申请的示例实施例。显然,所描述的实施例仅仅是本申请的一部分实施例,而不是本申请的全部实施例,应理解,本申请不受这里描述的示例实施例的限制。
申请概述
如上所述,目前在深度度量学习中,广泛使用基于对比损失和三元组损失的方案。
对比损失(Contrastive loss)用于训练孪生网络(Siamese Network),其输入为两个样本(样本对),每一对样本都有标签,表示两个样本属于同类(正样本对)或者异类(负样本对)。当输入为正样本对的时候,对比损失逐渐减小,相同类标签的样本会持续在特征空间形成聚类。反之,当网络输入负样本对时,对比损失会逐渐变大,直到超过设定的阈值。通过最小化对比损失函数,可以使正样本对之间距离逐渐变小,负样本对之间距离逐渐变大,从而满足识别任务的需要。
三元组损失(Triplet loss)是另一种广泛应用的度量学习损失函数。三元组损失同时输入三个样本。不同于对比损失,一个输入的三元组(Triplet)包括锚样本(Anchor)、正样本和负样本图片。通过优化三元组损失,可以使得网络不仅能够在特征空间把正负样本对推开,也能把正样本对之间的距离拉近。
但是,以上广泛使用的基于样本对或三元组的目标损失函数都不能充分利用训练样本中的语义信息。因为目前深度神经网络常用一批(batch)样本作为输入,而对比损失或三元组损失每次更新时仅考虑单个样本对或三元组内的语义信息,忽略与一批样本内其余样本之间的信息。这会使所学的嵌入度量和特征表示产生偏差。此外,这些损失函数也不能在深度网络优化过程中对难样本(在度量学习中,同类相似的样本距离较远,异类不相似的样本距离较近,称为难样本)给予足够的重视。因此,往往存在收敛速度慢和性能较差的问题。
基于上述技术问题,本申请的基本构思是从一批样本获得区分同类样本和异类样本的样本距离矩阵,并通过最优传输方法获得其对应的最优传输规划矩阵以构造最优传输损失函数,并以此来训练神经网络。
本申请提供的神经网络的训练方法、神经网络的训练装置和电子设备首先通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离,然后计算出与所述样本距离矩阵对应的最优传输规划矩阵,再基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值,最后基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
这样,由于在根据本申请的神经网络的训练方法、神经网络的训练装置和电子设备中,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离,因此通过最优传输方法计算出的最优传输规划矩阵能够按照样本的重要性赋予难样本更高的权重。因此,通过进一步基于最优传输规划矩阵构造最优传输损失函数并以此训练神经网络,可以使得神经网络能够学习到样本重要性驱动的距离度量,从而提高网络训练的收敛速率。
图1A图示了根据本申请实施例的通过最优传输损失扩展样本间的语义信息的示意图。
如图1A所示,通过经由最优传输损失函数将成对样本之间的语义信息扩展到批次的所有样本之间的语义信息,可以充分利用批次内的各个样本之间的语义信息,从而使得所学的嵌入度量和特征表示更加准确。
图1B图示了现有的成对样本情况下的距离度量学习的示意图。
如图1B所示,在成对的样本距离矩阵学习时,在每次更新时仅考虑一对样本之间的语义信息,也就是,通过学习,仅能够缩小样本x1和x2之间的距离。
图1C图示了根据本申请实施例的批量样本情况下的距离度量学习的示意图。如图1C所示,在基于批次的神经网络训练中,使用训练的批次内的所有可用信息来优化重要性驱动的样本距离矩阵,以使得自动强调具有大的距离的相似正样本和具有小的距离的不相似负样本,从而提高网络训练的收敛速率。也就是,在根据本申请实施例的批量样本情况下的距离度量学习中,不仅能够缩小作为相似正样本的样本x1和x2以及样本x1和x4之间的距离,还可能扩大作为不相似负样本的样本x1和x3之间的距离。
值得注意的是,根据本申请的神经网络的训练方法、神经网络的训练装置和电子设备中基于最优传输规划矩阵构造的最优传输损失函数不仅可以用于训练卷积神经网络等深层神经网络,也可以用于训练例如树模型、核模型等浅层模型。
在介绍了本申请的基本原理之后,下面将参考附图来具体介绍本申请的各种非限制性实施例。
示例性方法
图2图示了根据本申请实施例的神经网络的训练方法的流程图。
如图2所示,根据本申请实施例的神经网络的训练方法包括以下步骤。
步骤S110,通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离。如上所述,对于输入神经网络的一批样本,例如多幅图像,存在同类样本和异类样本,例如属于同一对象的图像和属于不同对象的图像。相应地,所述样本距离矩阵描述了所述一批样本中的每两个样本之间的距离,因此,为了缩小样本的类内差异(或距离)和扩大样本的类间差异(或距离),在本申请实施例中,所述样本距离矩阵针对同类样本和异类样本区分同类样本距离和异类样本距离。
值得注意的是,在本申请实施例中,所述同类样本和所述异类样本是按照样本之间的相似性来区分,因此,所述同类样本指的是相似度高的样本,而所述异类样本指的是较为不相似的样本,而并不是指所述同类样本必须属于相同类别,例如相同对象的图像。
步骤S120,计算出与所述样本距离矩阵对应的最优传输规划矩阵。这里,所述最优传输规划矩阵是通过最优传输方法从所述样本距离矩阵计算出的矩阵,其可以如下式定义:
其中,给定两个批次,每个批次包括n个样本,r和c是两个批次的n维概率向量,并且,Tij是最优传输规划矩阵,Mij是样本距离矩阵,h(Tij)是最优传输规划矩阵Tij的熵,且λ越大,越接近于初始的DM(r,c)。因此,最优传输规划矩阵实际上是用于在损失优化期间强调难的同类样本和异类样本的权重的概率分布。
这里,难的同类样本指的是样本之间距离远的同类样本,而难的异类样本指的是样本之间距离近的异类样本。所述最优传输规划矩阵是目的在于找到在样本之间传输的最小成本量的概率分布,且这种成本对应于传输样本的距离,因此,所获得的最优传输规划矩阵能够加大难样本的权重。
步骤S130,基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值。也就是,通过计算所述样本距离矩阵与所述最优传输规划矩阵的乘积,可以获得用于同类样本和异类样本的重要性驱动的距离度量。因此,在本申请实施例中,所述最优传输损失函数值能够通过所述最优运输规划矩阵从成批的样本中自动学习到一种样本重要性驱动的距离度量。
也就是,所述最优传输损失函数值能够在一批样本中使得相似的同类样本聚集在一起,以及使得不相似的异类样本分开,从而更加准确和快捷地区分同类样本和异类样本。
步骤S140,基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。如上所述,通过以所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵,可以实现深度度量学习网络架构,从而可以在训练过程中自动地发掘难样本并对其加大权重,显著地改善网络训练的收敛速率。
另外,由于所述神经网络在训练的过程中强调了难样本的学习,即赋予难样本更大权重,也可以改进对于难样本的学习准确性,从而提高神经网络的性能。例如,对于用于对象识别的神经网络,可以提高其识别性能,而对于用于分类的神经网络,可以提高其分类准确性。
图3图示了根据本申请实施例的神经网络的训练方法中获得样本距离矩阵的示例的流程图。
如图3所示,在如图2所示的实施例的基础上,所述步骤S110包括以下步骤。
步骤S1101,通过所述神经网络从所述一批样本获得同类样本距离矩阵。如上所述,因为样本的距离矩阵包括所述样本中的每两个样本之间的距离,在本申请实施例中,首先针对所述一批样本中的同类样本,通过所述神经网络获得所述同类样本距离矩阵。
例如,假设某批样本包括三个样本,则该批样本的距离矩阵应该为3×3矩阵,并且假设该批样本中的第一和第二样本为同类样本,且与所述第三样本为异类样本,则所获得的同类样本距离矩阵包括所述第一样本和所述第二样本之间的同类样本距离,且所述第一样本与所述第三样本以及所述第二样本与所述第三样本之间的异类样本距离为0。
步骤S1102,通过所述神经网络从所述一批样本获得异类样本距离矩阵。同样,按照上述示例,所获得的同类样本距离矩阵包括所述第一样本与所述第三样本以及所述第二样本与所述第三样本之间的异类样本距离,且所述第一样本与所述第二样本之间的同类样本距离为0。
步骤S1103,逐元素合并所述同类样本距离矩阵和所述异类样本距离矩阵以获得样本距离矩阵。也就是,对于如上所述的包括三个样本的一批样本,其同类样本距离矩阵和异类样本距离矩阵均为3×3矩阵,通过逐元素对这两个矩阵进行合并,就可以获得包含所述第一样本和所述第二样本之间的同类样本距离,以及所述第一样本与所述第三样本以及所述第二样本与所述第三样本之间的异类样本距离的距离矩阵。
因此,通过如上所述的获得样本距离矩阵的示例,可以在样本距离矩阵的计算中区分同类样本和异类样本,从而使得样本距离矩阵中包含同类样本距离和异类样本距离,以提高样本距离矩阵对于同类样本和异类样本的区分度,从而提高诸如识别任务等的性能。
在一个示例中,所述同类样本距离可以定义为第一幂函数,且所述第一幂函数的底数是自然常数,且所述第一幂函数的指数是缩放参数与样本间的欧氏距离的乘积,如下式所示:
通过如上所述的同类样本矩阵G+,训练得到的最优传输规划矩阵将对于那些在彼此之间具有较大的欧氏距离的相似样本(也就是,难的同类样本)给予较高的重要性值,同时相应地对于其它相似样本给予较低的重要性值。因此,这将加速相似样本彼此靠近的过程。
另外,在本申请实施例中,所述异类样本距离可以定义为第二幂函数,所述第二幂函数的底数是自然常数,所述第二幂函数的指数是所述缩放参数与样本间的欧氏距离的铰链损失值的乘积,如下式所示:
因此,与上述同类样本距离相对地,通过异类样本距离,训练得到的最优传输规划矩阵将对于那些具有小的欧氏距离的不相似样本(即,难的异类样本)给予较高的重要性值,同时对于其它异类样本给予较低的重要性值。因此,这将加速不相似样本彼此远离的过程。
这样,通过加速相似样本彼此靠近的过程和不相似样本彼此远离的过程,可以显著地改善网络训练的收敛速率。
在一个示例中,所述最优传输损失函数值可以看做对比损失或者三元组损失的n对扩展版本,针对如上所述的同类样本距离G+和异类样本距离G-,所述最优传输损失函数值可以相应地定义为:
其中Tij是最优传输规划矩阵,Mij是样本距离矩阵,Iij是单位矩阵,而Yij是分配给一对训练样本的二元标签。如果样本xi和xj被认为相似则Yij=1,否则Yij=0。
也就是,在本申请实施例中,首先将用于表示一对样本间的相似性的二元标签作为用于计算所述加权和的权重,再基于所述权重来对所述样本距离矩阵与所述最优传输规划矩阵的乘积加权,以获得所述最优传输损失函数值。
因此,如上所述,所述最优传输损失函数值实际上可以看做表示同类样本之间的重要性驱动的距离度量和异类样本之间的重要性驱动的距离度量的两项,且通过这两项的加权和,可以使得所述最优传输损失函数值对于同类样本强调第一项,而对于异类样本强调第二项,从而反映样本整体的重要性驱动的距离度量。这样,改善了网络训练的收敛速率并提高了网络性能。
在一个示例中,在基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的过程中,基于所述最优传输损失函数值通过梯度下降的方式更新所述神经网络。
因此,通过基于最优传输损失函数值以梯度下降的方式更新所述神经网络,可以方便地训练诸如指示对象识别任务或者分类任务的深度卷积神经网络的神经网络,从而改善网络训练的便利度。
在一个示例中,在基于所述最优传输损失函数值通过梯度下降的方式更新所述神经网络时,首先计算所述最优传输规划矩阵与其对应的特征差值以及所述权重的差值的乘积,然后将所述乘积对于所述多个样本求和以计算出所述最优传输损失函数值的梯度,接下来基于所述梯度通过梯度下降的方式更新所述神经网络。
在一个示例中,根据本申请实施例的神经网络的训练方法可以用于深度度量学习网络架构的训练,这种深度度量学习网络架构可以是通用的、可对多源异构视觉感知数据,如2D自然图像/视频/手绘草图、2.5D深度图像、3D物体形状等进行跨模态数据识别的深度度量学习网络架构。
例如,所述深度度量学习网络架构可以包括用于从一批样本提取特征以获得特征图的深度神经网络,例如Resnet-50等卷积神经网络,并且还包括用于从所述特征图获得样本距离矩阵的深度矩阵学习网络。所述深度矩阵学习网络例如可以包括四个全连接层以用于执行特征图的降维,并可以另外在这些全连接层当中添加三个sigmoid激活函数来生成归一化和致密的特征向量。
如上所述,在本申请实施例中,所述一批样本可以包括二维图像、二维手绘草图和三维物体形状中的至少一个。相应地,针对每种样本,例如二维图像和二维手绘草图,可以采用不同的深度神经网络来获得其特征图。例如,针对二维图像,可以使用LeNet-5深度神经网络获得其特征图,而对于二维手绘草图,可以使用Resnet-50卷积神经网络获得其特征图。
因此,根据本申请实施例的神经网络的训练方法可以端到端地训练用于跨模态数据识别的神经网络架构,从而以高收敛速率训练得到识别性能高的神经网络。
示例性装置
图4图示了根据本申请实施例的神经网络的训练装置的框图。
如图4所示,根据本申请实施例的神经网络的训练装置200包括:距离矩阵获得单元210,用于通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;传输矩阵获得单元220,用于计算出与所述距离矩阵获得单元210所获得的所述样本距离矩阵对应的最优传输规划矩阵;损失函数确定单元230,用于基于所述距离矩阵获得单元210所获得的所述样本距离矩阵与所述传输矩阵获得单元220所获得的所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及,参数更新单元240,用于基于所述损失函数确定单元230所确定的所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
图5图示了根据本申请实施例的神经网络的训练装置的距离矩阵获得单元的示例的框图。
如图5所示,在如图4所示的实施例的基础上,所述距离矩阵获得单元210包括:同类距离矩阵获得子单元2101,用于通过所述神经网络从所述一批样本获得同类样本距离矩阵;异类距离矩阵获得子单元2102,用于通过所述神经网络从所述一批样本获得异类样本距离矩阵;以及,矩阵合并子单元2103,用于逐元素合并所述同类距离矩阵获得子单元2101所获得的所述同类样本距离矩阵和所述异类距离矩阵获得子单元2102所获得的所述异类样本距离矩阵以获得所述样本距离矩阵。
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述同类样本距离是第一幂函数,所述第一幂函数的底数是自然常数,且所述第一幂函数的指数是缩放参数与样本间的欧氏距离的乘积;以及,所述异类样本距离是第二幂函数,所述第二幂函数的底数是自然常数,所述第二幂函数的指数是所述缩放参数与样本间的欧氏距离的铰链损失值的乘积;
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述损失函数确定单元230用于:将用于表示一对样本间的相似性的二元标签作为用于计算所述加权和的权重;以及,基于所述权重来对所述样本距离矩阵与所述最优传输规划矩阵的乘积加权以获得所述最优传输损失函数值。
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述参数更新单元240用于基于所述最优传输损失函数值通过梯度下降的方式更新所述神经网络。
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述参数更新单元240用于:计算所述最优传输规划矩阵与其对应的特征差值以及所述权重的差值的乘积;将所述乘积对于所述多个样本求和以计算出所述最优传输损失函数值的梯度;以及,基于所述梯度通过梯度下降的方式更新所述神经网络。
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述距离矩阵获得单元210用于:通过深度神经网络从所述一批样本获得特征图;以及,通过深度矩阵学习网络从所述特征图获得所述样本距离矩阵。
在一个示例中,在根据本申请实施例的神经网络的训练装置200中,所述一批样本包括二维图像、二维手绘草图和三维物体形状中的至少一个。
这里,本领域技术人员可以理解,上述神经网络的训练装置200中的各个单元和模块的具体功能和操作已经在上面参考图2和图3的神经网络的训练方法的描述中得到了详细介绍,并因此,将省略其重复描述。
如上所述,根据本申请实施例的神经网络的训练装置200可以实现在各种终端设备中,例如用于进行对象识别任务的服务器等。在一个示例中,根据本申请实施例的神经网络的训练装置200可以作为一个软件模块和/或硬件模块而集成到终端设备中。例如,该神经网络的训练装置200可以是该终端设备的操作系统中的一个软件模块,或者可以是针对于该终端设备所开发的一个应用程序;当然,该神经网络的训练装置200同样可以是该终端设备的众多硬件模块之一。
替换地,在另一示例中,该神经网络的训练装置200与该终端设备也可以是分立的设备,并且该神经网络的训练装置200可以通过有线和/或无线网络连接到该终端设备,并且按照约定的数据格式来传输交互信息。
示例性电子设备
下面,参考图6来描述根据本申请实施例的电子设备。
图6图示了根据本申请实施例的电子设备的框图。
如图6所示,电子设备10包括一个或多个处理器11和存储器12。
处理器13可以是中央处理单元(CPU)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元,并且可以控制电子设备10中的其他组件以执行期望的功能。
存储器12可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器11可以运行所述程序指令,以实现上文所述的本申请的各个实施例的神经网络的训练方法以及/或者其他期望的功能。在所述计算机可读存储介质中还可以存储诸如样本距离矩阵,最优传输规划矩阵等各种内容。
在一个示例中,电子设备10还可以包括:输入装置13和输出装置14,这些组件通过总线系统和/或其他形式的连接机构(未示出)互连。
该输入装置13可以包括例如键盘、鼠标等等。
该输出装置14可以向外部输出各种信息,包括训练好的神经网络等。该输出装置14可以包括例如显示器、扬声器、打印机、以及通信网络及其所连接的远程输出设备等等。
当然,为了简化,图6中仅示出了该电子设备10中与本申请有关的组件中的一些,省略了诸如总线、输入/输出接口等等的组件。除此之外,根据具体应用情况,电子设备10还可以包括任何其他适当的组件。
示例性计算机程序产品和计算机可读存储介质
除了上述方法和设备以外,本申请的实施例还可以是计算机程序产品,其包括计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本申请各种实施例的神经网络的训练方法中的步骤。
所述计算机程序产品可以以一种或多种程序设计语言的任意组合来编写用于执行本申请实施例操作的程序代码,所述程序设计语言包括面向对象的程序设计语言,诸如Java、C++等,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。
此外,本申请的实施例还可以是计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本申请各种实施例的神经网络的训练方法中的步骤。
所述计算机可读存储介质可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以包括但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
以上结合具体实施例描述了本申请的基本原理,但是,需要指出的是,在本申请中提及的优点、优势、效果等仅是示例而非限制,不能认为这些优点、优势、效果等是本申请的各个实施例必须具备的。另外,上述公开的具体细节仅是为了示例的作用和便于理解的作用,而非限制,上述细节并不限制本申请为必须采用上述具体的细节来实现。
本申请中涉及的器件、装置、设备、系统的方框图仅作为例示性的例子并且不意图要求或暗示必须按照方框图示出的方式进行连接、布置、配置。如本领域技术人员将认识到的,可以按任意方式连接、布置、配置这些器件、装置、设备、系统。诸如“包括”、“包含”、“具有”等等的词语是开放性词汇,指“包括但不限于”,且可与其互换使用。这里所使用的词汇“或”和“和”指词汇“和/或”,且可与其互换使用,除非上下文明确指示不是如此。这里所使用的词汇“诸如”指词组“诸如但不限于”,且可与其互换使用。
还需要指出的是,在本申请的装置、设备和方法中,各部件或各步骤是可以分解和/或重新组合的。这些分解和/或重新组合应视为本申请的等效方案。
提供所公开的方面的以上描述以使本领域的任何技术人员能够做出或者使用本申请。对这些方面的各种修改对于本领域技术人员而言是非常显而易见的,并且在此定义的一般原理可以应用于其他方面而不脱离本申请的范围。因此,本申请不意图被限制到在此示出的方面,而是按照与在此公开的原理和新颖的特征一致的最宽范围。
为了例示和描述的目的已经给出了以上描述。此外,此描述不意图将本申请的实施例限制到在此公开的形式。尽管以上已经讨论了多个示例方面和实施例,但是本领域技术人员将认识到其某些变型、修改、改变、添加和子组合。
Claims (12)
1.一种神经网络的训练方法,包括:
通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;
计算出与所述样本距离矩阵对应的最优传输规划矩阵;
基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及
基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
2.如权利要求1所述的神经网络的训练方法,其中,通过神经网络从一批样本获得样本距离矩阵包括:
通过所述神经网络从所述一批样本获得同类样本距离矩阵;
通过所述神经网络从所述一批样本获得异类样本距离矩阵;以及
逐元素合并所述同类样本距离矩阵和所述异类样本距离矩阵以获得样本距离矩阵。
3.如权利要求2所述的神经网络的训练方法,其中,
所述同类样本距离是第一幂函数,所述第一幂函数的底数是自然常数,且所述第一幂函数的指数是缩放参数与样本间的欧氏距离的乘积;以及
所述异类样本距离是第二幂函数,所述第二幂函数的底数是自然常数,所述第二幂函数的指数是所述缩放参数与样本间的欧氏距离的铰链损失值的乘积。
4.如权利要求1所述的神经网络的训练方法,其中,基于所述样本距离矩阵与所述最优传输规划矩阵的乘积的加权和计算最优传输损失函数值包括:
将用于表示一对样本间的相似性的二元标签作为用于计算所述加权和的权重;以及
基于所述权重来对所述样本距离矩阵与所述最优传输规划矩阵的乘积加权以获得所述最优传输损失函数值。
5.如权利要求4所述的神经网络的训练方法,其中,基于所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数包括:
基于所述最优传输损失函数值通过梯度下降的方式更新所述神经网络。
6.如权利要求5所述的神经网络的训练方法,其中,基于所述最优传输损失函数值通过梯度下降的方式更新所述神经网络包括:
计算所述最优传输规划矩阵与其对应的特征差值以及所述权重的差值的乘积;
将所述乘积对于所述多个样本求和以计算出所述最优传输损失函数值的梯度;以及
基于所述梯度通过梯度下降的方式更新所述神经网络。
7.如权利要求1所述的神经网络的训练方法,其中,通过神经网络从一批样本获得样本距离矩阵包括:
通过深度神经网络从所述一批样本获得特征图;以及
通过深度矩阵学习网络从所述特征图获得所述样本距离矩阵。
8.如权利要求1所述的神经网络的训练方法,其中,所述一批样本包括二维图像、二维手绘草图和三维物体形状中的至少一个。
9.一种神经网络的训练装置,包括:
距离矩阵获得单元,用于通过神经网络从一批样本获得样本距离矩阵,所述样本距离矩阵包括所述一批样本中的同类样本距离和异类样本距离;
传输矩阵获得单元,用于计算出与所述距离矩阵获得单元所获得的所述样本距离矩阵对应的最优传输规划矩阵;
损失函数确定单元,用于基于所述距离矩阵获得单元所获得的所述样本距离矩阵与所述传输矩阵获得单元所获得的所述最优传输规划矩阵的乘积的加权之和,确定最优传输损失函数值;以及
参数更新单元,用于基于所述损失函数确定单元所确定的所述最优传输损失函数值更新所述神经网络和所述最优传输规划矩阵的参数。
10.如权利要求9所述的神经网络的训练装置,其中,所述距离矩阵获得单元包括:
同类距离矩阵获得子单元,用于通过所述神经网络从所述一批样本获得同类样本距离矩阵;
异类距离矩阵获得子单元,用于通过所述神经网络从所述一批样本获得异类样本距离矩阵;以及
矩阵合并子单元,用于逐元素合并所述同类距离矩阵获得子单元所获得的所述同类样本距离矩阵和所述异类距离矩阵获得子单元所获得的所述异类样本距离矩阵以获得样本距离矩阵。
11.一种电子设备,包括:
处理器;以及
存储器,在所述存储器中存储有计算机程序指令,所述计算机程序指令在被所述处理器运行时使得所述处理器执行如权利要求1-8中任一项所述的神经网络的训练方法。
12.一种计算机可读介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行如权利要求1-8中任一项所述的神经网络的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910507780.1A CN112085041B (zh) | 2019-06-12 | 2019-06-12 | 神经网络的训练方法、训练装置和电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910507780.1A CN112085041B (zh) | 2019-06-12 | 2019-06-12 | 神经网络的训练方法、训练装置和电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112085041A true CN112085041A (zh) | 2020-12-15 |
CN112085041B CN112085041B (zh) | 2024-07-12 |
Family
ID=73733574
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910507780.1A Active CN112085041B (zh) | 2019-06-12 | 2019-06-12 | 神经网络的训练方法、训练装置和电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112085041B (zh) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112598091A (zh) * | 2021-03-08 | 2021-04-02 | 北京三快在线科技有限公司 | 一种训练模型和小样本分类的方法及装置 |
CN112699811A (zh) * | 2020-12-31 | 2021-04-23 | 中国联合网络通信集团有限公司 | 活体检测方法、装置、设备、储存介质及程序产品 |
CN112884204A (zh) * | 2021-01-22 | 2021-06-01 | 中国科学院信息工程研究所 | 网络安全风险事件预测方法及装置 |
CN113065636A (zh) * | 2021-02-27 | 2021-07-02 | 华为技术有限公司 | 一种卷积神经网络的剪枝处理方法、数据处理方法及设备 |
CN113516227A (zh) * | 2021-06-08 | 2021-10-19 | 华为技术有限公司 | 一种基于联邦学习的神经网络训练方法及设备 |
CN116628507A (zh) * | 2023-07-20 | 2023-08-22 | 腾讯科技(深圳)有限公司 | 数据处理方法、装置、设备及可读存储介质 |
CN116913259A (zh) * | 2023-09-08 | 2023-10-20 | 中国电子科技集团公司第十五研究所 | 结合梯度引导的语音识别对抗防御方法及装置 |
WO2023232031A1 (zh) * | 2022-05-31 | 2023-12-07 | 中国第一汽车股份有限公司 | 一种神经网络模型的训练方法、装置、电子设备及介质 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2017215240A1 (zh) * | 2016-06-14 | 2017-12-21 | 广州视源电子科技股份有限公司 | 基于神经网络的人脸特征提取建模、人脸识别方法及装置 |
CN108108754A (zh) * | 2017-12-15 | 2018-06-01 | 北京迈格威科技有限公司 | 重识别网络的训练、重识别方法、装置和系统 |
CN108399428A (zh) * | 2018-02-09 | 2018-08-14 | 哈尔滨工业大学深圳研究生院 | 一种基于迹比准则的三元组损失函数设计方法 |
CN109086871A (zh) * | 2018-07-27 | 2018-12-25 | 北京迈格威科技有限公司 | 神经网络的训练方法、装置、电子设备和计算机可读介质 |
CN109426858A (zh) * | 2017-08-29 | 2019-03-05 | 京东方科技集团股份有限公司 | 神经网络、训练方法、图像处理方法及图像处理装置 |
CN109558821A (zh) * | 2018-11-21 | 2019-04-02 | 哈尔滨工业大学(深圳) | 一种视频中特定人物的服装件数计算方法 |
CN109816092A (zh) * | 2018-12-13 | 2019-05-28 | 北京三快在线科技有限公司 | 深度神经网络训练方法、装置、电子设备及存储介质 |
-
2019
- 2019-06-12 CN CN201910507780.1A patent/CN112085041B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2017215240A1 (zh) * | 2016-06-14 | 2017-12-21 | 广州视源电子科技股份有限公司 | 基于神经网络的人脸特征提取建模、人脸识别方法及装置 |
CN109426858A (zh) * | 2017-08-29 | 2019-03-05 | 京东方科技集团股份有限公司 | 神经网络、训练方法、图像处理方法及图像处理装置 |
CN108108754A (zh) * | 2017-12-15 | 2018-06-01 | 北京迈格威科技有限公司 | 重识别网络的训练、重识别方法、装置和系统 |
CN108399428A (zh) * | 2018-02-09 | 2018-08-14 | 哈尔滨工业大学深圳研究生院 | 一种基于迹比准则的三元组损失函数设计方法 |
CN109086871A (zh) * | 2018-07-27 | 2018-12-25 | 北京迈格威科技有限公司 | 神经网络的训练方法、装置、电子设备和计算机可读介质 |
CN109558821A (zh) * | 2018-11-21 | 2019-04-02 | 哈尔滨工业大学(深圳) | 一种视频中特定人物的服装件数计算方法 |
CN109816092A (zh) * | 2018-12-13 | 2019-05-28 | 北京三快在线科技有限公司 | 深度神经网络训练方法、装置、电子设备及存储介质 |
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112699811B (zh) * | 2020-12-31 | 2023-11-03 | 中国联合网络通信集团有限公司 | 活体检测方法、装置、设备、储存介质及程序产品 |
CN112699811A (zh) * | 2020-12-31 | 2021-04-23 | 中国联合网络通信集团有限公司 | 活体检测方法、装置、设备、储存介质及程序产品 |
CN112884204A (zh) * | 2021-01-22 | 2021-06-01 | 中国科学院信息工程研究所 | 网络安全风险事件预测方法及装置 |
CN112884204B (zh) * | 2021-01-22 | 2024-04-12 | 中国科学院信息工程研究所 | 网络安全风险事件预测方法及装置 |
CN113065636A (zh) * | 2021-02-27 | 2021-07-02 | 华为技术有限公司 | 一种卷积神经网络的剪枝处理方法、数据处理方法及设备 |
CN113065636B (zh) * | 2021-02-27 | 2024-06-07 | 华为技术有限公司 | 一种卷积神经网络的剪枝处理方法、数据处理方法及设备 |
CN112598091A (zh) * | 2021-03-08 | 2021-04-02 | 北京三快在线科技有限公司 | 一种训练模型和小样本分类的方法及装置 |
CN113516227A (zh) * | 2021-06-08 | 2021-10-19 | 华为技术有限公司 | 一种基于联邦学习的神经网络训练方法及设备 |
WO2023232031A1 (zh) * | 2022-05-31 | 2023-12-07 | 中国第一汽车股份有限公司 | 一种神经网络模型的训练方法、装置、电子设备及介质 |
CN116628507B (zh) * | 2023-07-20 | 2023-10-27 | 腾讯科技(深圳)有限公司 | 数据处理方法、装置、设备及可读存储介质 |
CN116628507A (zh) * | 2023-07-20 | 2023-08-22 | 腾讯科技(深圳)有限公司 | 数据处理方法、装置、设备及可读存储介质 |
CN116913259B (zh) * | 2023-09-08 | 2023-12-15 | 中国电子科技集团公司第十五研究所 | 结合梯度引导的语音识别对抗防御方法及装置 |
CN116913259A (zh) * | 2023-09-08 | 2023-10-20 | 中国电子科技集团公司第十五研究所 | 结合梯度引导的语音识别对抗防御方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN112085041B (zh) | 2024-07-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112085041B (zh) | 神经网络的训练方法、训练装置和电子设备 | |
CN111797893B (zh) | 一种神经网络的训练方法、图像分类系统及相关设备 | |
US20230016365A1 (en) | Method and apparatus for training text classification model | |
EP3940638A1 (en) | Image region positioning method, model training method, and related apparatus | |
US20220180882A1 (en) | Training method and device for audio separation network, audio separation method and device, and medium | |
EP4131030A1 (en) | Method and apparatus for searching for target | |
US12100192B2 (en) | Method, apparatus, and electronic device for training place recognition model | |
US9519868B2 (en) | Semi-supervised random decision forests for machine learning using mahalanobis distance to identify geodesic paths | |
US20210406266A1 (en) | Computerized information extraction from tables | |
CN111414987A (zh) | 神经网络的训练方法、训练装置和电子设备 | |
CN111898374B (zh) | 文本识别方法、装置、存储介质和电子设备 | |
CN111930894B (zh) | 长文本匹配方法及装置、存储介质、电子设备 | |
CN114298122B (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
CN111898636B (zh) | 一种数据处理方法及装置 | |
CN109918506A (zh) | 一种文本分类方法及装置 | |
CN113254716B (zh) | 视频片段检索方法、装置、电子设备和可读存储介质 | |
WO2023231753A1 (zh) | 一种神经网络的训练方法、数据的处理方法以及设备 | |
CN112308131B (zh) | 样本拒识方法、装置、设备及存储介质 | |
CN114611672A (zh) | 模型训练方法、人脸识别方法及装置 | |
CN117992805A (zh) | 基于张量积图融合扩散的零样本跨模态检索方法、系统 | |
CN112861474B (zh) | 一种信息标注方法、装置、设备及计算机可读存储介质 | |
CN113822143A (zh) | 文本图像的处理方法、装置、设备以及存储介质 | |
CN113705293A (zh) | 图像场景的识别方法、装置、设备及可读存储介质 | |
US20240028952A1 (en) | Apparatus for attribute path generation | |
CN113822293A (zh) | 用于图数据的模型处理方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |