CN112101544A - 适用于长尾分布数据集的神经网络的训练方法和装置 - Google Patents

适用于长尾分布数据集的神经网络的训练方法和装置 Download PDF

Info

Publication number
CN112101544A
CN112101544A CN202010851530.2A CN202010851530A CN112101544A CN 112101544 A CN112101544 A CN 112101544A CN 202010851530 A CN202010851530 A CN 202010851530A CN 112101544 A CN112101544 A CN 112101544A
Authority
CN
China
Prior art keywords
training
gradient
feature extraction
training sample
network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010851530.2A
Other languages
English (en)
Inventor
丁贵广
项刘宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Original Assignee
Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University filed Critical Tsinghua University
Priority to CN202010851530.2A priority Critical patent/CN112101544A/zh
Publication of CN112101544A publication Critical patent/CN112101544A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提出一种适用于长尾分布数据集的神经网络的训练方法和装置,神经网络包括:特征提取网络,分类器,类别梯度重加权网络,训练方法包括:获取训练样本集;特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。由此,解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。

Description

适用于长尾分布数据集的神经网络的训练方法和装置
技术领域
本发明涉及人工智能技术领域与深度学习技术领域,尤其涉及一种适用于长尾分布数据集的神经网络的训练方法和装置。
背景技术
随着深度学习和神经网络的快速发展,深度学习技术被广泛应用于计算机视觉应用中,如目标识别、目标检测、语义分割等。而训练一个神经网络往往需要数据量充足、数据分布均衡的训练数据。而收集这样的训练数据往往需要耗费大量的人力和物力。在早期的基于神经网络的目标识别算法中,往往使用数据均衡且数据量较小的数据集,如MNIST和CIFAR。前者是手写数字识别训练数据,后者则是通用目标识别训练数据,其类别分布均是均衡的,即每个类别的样本具有相同的数量。而这样的训练数据往往与实际场景是脱节的。
实际应用上述训练数据的一大区别是,在实际应用中,自然界的语义概念等的分布往往是服从长尾分布的,即少数训练数据占据了绝大多数的出现次数(头部数据类别),而大多数训练数据则出现频率较低(尾部数据类别),这样分布会导致在收集训练数据的过程中往往会引入长尾分布效应,造成训练数据类别分布的不均衡。而在传统的分布均衡的数据集上提出的深度学习算法则往往难以处理长尾分布下的目标识别问题。由于长尾分布的普遍性,在长尾分布下的深度学习方法也相继被研究者们提出。
针对长尾分布下的神经网络训练方法主要可以分为三大类:重加权法、重采样法和知识迁移法。
其中,重加权法主要通过提出与类别数量相关的损失函数,从而降低长尾分布带来的数据不均衡的影响。具体而言,由于长尾分布中的头部数据类别占据了数据量的大部分,而尾部数据类别占据了较小的部分,因此在训练集中头部数据类别对应的分类器神经元往往占据了主导地位,而尾部数据对应的分类器神经元被相应的抑制。而重加权方法则是通过修改损失函数,削弱头部数据类别的主导地位,即为头部类别的样本对应的损失函数分配一个较小的权重,而为尾部类别样本对应的损失函数分配一个较大的权重,最终通过端到端的训练完成对长尾分布数据的识别任务。由于重加权的方法往往是在损失函数上进行修改,针对极度不均衡的长尾分布往往难以有较好的效果。
重采样法则通过设计类别均衡的采样策略对长尾分布的目标识别问题进行处理。由于神经网络的训练过程大多是基于小批量的,因此每次训练过程中都需要对整个训练数据进行小批量的采样,从而得到小批量的训练数据进行神经网络的训练。而由于对长尾分布的训练数据的采样过程中,尾部数据数量较少,因此很难被采样到,而头部数据由于数量众多,被采样到的频率则会过高。重采样技术则通过增加尾部数据被采样到的概率,或降低头部数据被采样到的概率,较为常见的做法按照相同概率对每个类别进行采样,而不是按照相同概率对每个样本进行采样,从而缓解长尾分布下的类别不均衡带来的挑战,但与此同时,也会因为重采样带来尾部数据过拟合或头部数据信息丢失的问题。
知识迁移法则是利用头部数据的丰富隐含的知识借以辅助尾部数据的训练,首先利用神经网络的头部数据进行训练,从而隐式地获取数据类别的分布信息,进而对数据量稀疏的尾部数据进行增强。具体形式包括训练从头部类别到尾部类别的分类器映射,或借助头部数据和预训练网络产生“伪”尾部数据辅助神经网络训练。这类方法往往需要数据集的类别具有较高的相似性,才能够完成较高质量的知识迁移。
上述三种方法大多都能够缓解神经网络在长尾分布下的识别困难,但也都有各自的局限性,同时针对神经网络本身的特点分析不足。
因此针对实际应用中的长尾数据分布提出有效的神经网络训练算法具有重要的意义和价值。
发明内容
本发明旨在至少在一定程度上解决相关技术中的技术问题之一。
为此,本发明的目的在于提出一种适用于长尾分布数据集的神经网络的训练方法和装置,其目的是解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。
根据本发明的适用于长尾分布数据集的神经网络,包括:特征提取网络,分类器,类别梯度重加权网络,其中,类别梯度重加权网络设置在特征提取网络和分类器之间。
根据本发明的适用于长尾分布数据集的神经网络的训练方法包括:获取训练样本集;特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。
另外,根据本发明的适用于长尾分布数据集的神经网络的训练方法还可以具有如下附加的技术特征:
根据本发明的一些实施例,获取训练样本集的方法可以是按照类别均衡采样方式从训练数据中获取训练样本集。
根据本发明的一些实施例,获取前向传导函数R(x),其中,前向传导函数R(x)用于指示特征提取网络中各个神经元的传播方向。
根据本发明的一些实施例,类别梯度重加权网络计算各个模块的重加权梯度权重的公式为:
Figure BDA0002644890970000031
Figure BDA0002644890970000032
其中,Nc,Nmax分别为元素x所属的类别和训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
根据本发明的一些实施例,根据分类结果建立损失函数,包括:获取各个分类结果对应的概率,以及训练样本集中各个元素对应的标注分类结果;根据各个分类结果对应的概率和标注分类结果建立损失函数。
为达到上述目的,本发明第二方面实施例提出了一种适用于长尾分布数据集的神经网络的训练装置,包括:获取模块、特征提取模块、梯度计算模块、梯度重加权模块,其中,获取模块用于获取训练样本集;特征提取模块用于控制特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;梯度计算模块用于根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;梯度重加权模块用于在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。
另外,根据本发明上述实施例的适用于长尾分布数据集的神经网络的训练装置,还可以具有如下附加的技术特征:
进一步地,在本申请实施例的一种可能的实现方式中,获取模块具体用于按照类别均衡采样法从训练数据中获取训练样本集。
进一步地,在本申请实施例的一种可能的实现方式中,特征提取模块获取前向传导函数R(x),其中,前向传导函数R(x)用于指示特征提取网络中各个神经元的传播方向。
进一步地,在本申请实施例的一种可能的实现方式中,梯度重加权模块计算各个模块的重加权梯度权重的公式为:
Figure BDA0002644890970000033
Figure BDA0002644890970000041
其中,Nc,Nmax分别为元素x所属的类别和训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
进一步地,在本申请实施例的一种可能的实现方式中,特征提取模块具体用于获取各个分类结果对应的概率,以及训练样本集中各个元素对应的标注分类结果;根据各个分类结果对应的概率和标注分类结果建立损失函数。
本发明实施例提供的适用于长尾分布数据集的神经网络的训练方法可以包含如下有益效果:
根据本发明的适用于长尾分布数据集的神经网络包括:特征提取网络,分类器,类别梯度重加权网络,其中,类别梯度重加权网络设置在特征提取网络和分类器之间。根据本发明的适用于长尾分布数据集的神经网络的训练方法包括:获取训练样本集;特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。由此,解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。
本发明附加的方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
本发明上述的和/或附加的方面和优点从下面结合附图对实施例的描述中将变得明显和容易理解,其中:
图1为本发明实施例的一种适用于长尾分布数据集的神经网络的训练方法的流程示意图;
图2为本发明实施例的一种适用于长尾分布数据集的神经网络的示意图;
图3为本发明实施例的一种适用于长尾分布数据集的神经网络的训练装置的结构示意图。
具体实施方式
下面详细描述本发明的实施例,实施例的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施例是示例性的,旨在用于解释本发明,而不能理解为对本发明的限制。
长尾分布是一种广泛适用于人类日常生活中的一种数据分布形式,以互联网上歌曲和软件的下载为例,流行度前几名的热门歌曲和软件会被大量下载,随着流行度的下降,大量的歌曲和软件的下载量很快越来越低,但是即使是流行度非常低的歌曲和软件仍然保持着一定的下载量,这类数据分布形式就属于长尾分布,除此以外,人类语言的使用频次也符合长尾分布的特点,即少数词汇的使用占据日常用语的绝大部分,大多数的词汇使用频率很低,但是仍然有一定的使用频次。诸如此类的还包括:网页的点击量,书籍的购买需求等,所以研究一种适用于长尾分布数据集的神经网络的训练方法有着重大的实际意义。
本发明的目的是提供一种适用于长尾分布数据集的神经网络的训练方法和装置,从而解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。
下面参照附图描述根据本发明实施例提出的适用于长尾分布数据集的神经网络的训练方法。
图1为本发明实施例的一种适用于长尾分布数据集的神经网络的训练方法的流程示意图。如图1所示,该适用于长尾分布数据集的神经网络的训练方法,包括:
步骤101,获取训练样本集。
其中,训练样本集可以理解为从大量已经标记好不同类别标签的训练数据中,按照一定方法挑选出的一部分元素组成的集合,在本实施例中,挑选的方法为类别均衡采样法,即训练样本集中标记有不同类别标签的样本元素的数量是一样的。
具体的,按照指定的方法从大量训练数据中挑选出一定数量的元素作为训练样本集。
假设有训练数据集Di={(xi,yi)}i=1,2,...,N。一共有C个类别。传统随机采样方法往往是为每个元素分配一样的权重,若对于c类别而言,记属于该类别的样本数量为Nc,则对于训练过程中的某一训练样本集,采样到类别为的样本的概率为公式(1):
Figure BDA0002644890970000051
而类别均衡的数据采样则是通过为每个类别分配相同的概率被采样。对于训练过程中的某一训练样本集而言,其属于类别c的样本被采样到的概率为公式(2):
Figure BDA0002644890970000052
显然,通过类别均衡采样能够比传统随机采样更好的应对长尾分布的识别问题。
步骤102,特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数。
具体的,特征提取网络将训练样本集中的元素一一提出,从输入端输入,并在输出端获得特征预测概率,然后将特征预测概率输入到分类器中,分类器结合每个元素的标注分类结果和特征预测概率建立损失函数。
另外,特征提取网络还会获取前向传导函数R(x),在前向特征提取过程中不做任何修改,其中,前向传导函数R(x)用于指示特征提取网络中各个神经元的传播方向。
步骤103,根据损失函数计算特征提取网络中各个神经元在训练样本的梯度。
其中,各个神经元在训练样本的梯度可以理解为各个神经元的参数的大小或者参数的数量等信息。
具体的,对损失函数求导获得特征提取网络中各个神经元在训练样本的梯度。
步骤104,在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。
具体的,如图2所示,在本申请的适用于长尾分布数据集的神经网络的训练过程中,在反向传播阶段,类别梯度重加权网络计算属于不同类别的训练样本的重加权梯度权重的公式为公式(3)、公式(4):
Figure BDA0002644890970000061
Figure BDA0002644890970000062
其中,Nc,Nmax分别为训练样本集中任一元素x所属的类别和训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
由上述可知,在本发明的实施例中,本发明的适用于长尾分布数据集的神经网络的训练方法:获取训练样本集;特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。由此,解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。
为了实现上述实施例,本申请还提出一种适用于长尾分布数据集的神经网络的训练装置。
图3为本发明实施例提供的一种适用于长尾分布数据集的神经网络的训练装置的结构示意图。
如图3所示,该装置包括:获取模块301、特征提取模块302、梯度计算模块303、梯度重加权模块304。
获取模块301,用于获取训练样本集。
特征提取模块302,用于控制特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;
梯度计算模块303,用于根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;
梯度重加权模块304,用于在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。
进一步地,在本申请实施例的一种可能的实现方式中,获取模块301具体用于按照类别均衡采样法从训练数据中获取训练样本集。
进一步地,在本申请实施例的一种可能的实现方式中,特征提取模块302会获取前向传导函数R(x),在前向特征提取过程中不做任何修改,前向传导函数R(x)用于指示特征提取网络中各个神经元的传播方向。
进一步地,在本申请实施例的一种可能的实现方式中,梯度重加权模块304,具体用于:计算各个模块的重加权梯度权重的公式为公式(3)、公式(4):
Figure BDA0002644890970000071
Figure BDA0002644890970000072
其中,Nc,Nmax分别为训练样本集中任一元素x所属的类别和训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
进一步地,在本申请实施例的一种可能的实现方式中,特征提取模块,具体用于获取各个分类结果对应的概率,以及训练样本集中各个元素对应的标注分类结果;根据各个分类结果对应的概率和标注分类结果建立损失函数。
需要说明的是,前述对方法实施例的解释说明也适用于该实施例的装置,此处不再赘述。
本发明实施例的适用于长尾分布数据集的神经网络的训练装置,获取模块获取训练样本集;特征提取模块控制特征提取网络对训练样本集进行特征提取得到特征,并通过分类器对特征进行分类,根据分类结果建立损失函数;梯度计算模块根据损失函数计算特征提取网络中各个神经元在训练样本的梯度;梯度重加权模块在神经网络训练的反向传播的过程中,类别梯度重加权网络计算训练样本的重加权梯度权重,根据重加权梯度权重调整属于不同类别的训练样本的梯度。由此,解决神经网络在长尾分布的训练数据下识别准确率下降问题,缓解特征提取网络过拟合现象,提高深度神经网络在长尾分布下的识别准确率和鲁棒性。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。在本发明的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。
流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更多个用于实现定制逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分,并且本发明的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能,这应被本发明的实施例所属技术领域的技术人员所理解。
在流程图中表示或在此以其他方式描述的逻辑和/或步骤,例如,可以被认为是用于实现逻辑功能的可执行指令的定序列表,可以具体实现在任何计算机可读介质中,以供指令执行系统、装置或设备(如基于计算机的系统、包括处理器的系统或其他可以从指令执行系统、装置或设备取指令并执行指令的系统)使用,或结合这些指令执行系统、装置或设备而使用。就本说明书而言,"计算机可读介质"可以是任何可以包含、存储、通信、传播或传输程序以供指令执行系统、装置或设备或结合这些指令执行系统、装置或设备而使用的装置。计算机可读介质的更具体的示例(非穷尽性列表)包括以下:具有一个或多个布线的电连接部(电子装置),便携式计算机盘盒(磁装置),随机存取存储器(RAM),只读存储器(ROM),可擦除可编辑只读存储器(EPROM或闪速存储器),光纤装置,以及便携式光盘只读存储器(CDROM)。另外,计算机可读介质甚至可以是可在其上打印程序的纸或其他合适的介质,因为可以例如通过对纸或其他介质进行光学扫描,接着进行编辑、解译或必要时以其他合适方式进行处理来以电子方式获得程序,然后将其存储在计算机存储器中。
应当理解,本发明的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,多个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。如,如果用硬件来实现和在另一实施方式中一样,可用本领域公知的下列技术中的任一项或他们的组合来实现:具有用于对数据信号实现逻辑功能的逻辑门电路的离散逻辑电路,具有合适的组合逻辑门电路的专用集成电路,可编程门阵列(PGA),现场可编程门阵列(FPGA)等。
本技术领域的普通技术人员可以理解实现上述实施例方法携带的全部或部分步骤是可以通过程序来指令相关的硬件完成,程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。
此外,在本发明各个实施例中的各功能单元可以集成在一个处理模块中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。
上述提到的存储介质可以是只读存储器,磁盘或光盘等。尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。

Claims (10)

1.一种适用于长尾分布数据集的神经网络的训练方法,其特征在于,所述神经网络,包括:特征提取网络,分类器,类别梯度重加权网络,其中,所述类别梯度重加权网络设置在所述特征提取网络和所述分类器之间,包括以下步骤:
获取训练样本集;
所述特征提取网络对所述训练样本集进行特征提取得到特征,并通过所述分类器对所述特征进行分类,根据分类结果建立损失函数;
根据所述损失函数计算所述特征提取网络中各个神经元在训练样本的梯度;
在神经网络训练的反向传播的过程中,所述类别梯度重加权网络计算所述训练样本的重加权梯度权重,根据所述重加权梯度权重调整属于不同类别的训练样本的梯度。
2.如权利要求1所述的训练方法,其特征在于,所述获取训练样本集,包括:
按照类别均衡采样法从训练数据中获取所述训练样本集。
3.如权利要求1所述的训练方法,其特征在于,还包括:
获取前向传导函数R(x),其中,所述前向传导函数R(x)用于指示所述特征提取网络中各个神经元的传播方向。
4.如权利要求3所述的训练方法,其特征在于,所述类别梯度重加权网络计算所述各个模块的重加权梯度权重的公式为:
Figure FDA0002644890960000011
Figure FDA0002644890960000012
其中Nc,Nmax分别为所述训练样本集中任一元素x所属的类别和所述训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
5.如权利要求1所述的训练方法,其特征在于,所述根据分类结果建立损失函数,包括:
获取各个分类结果对应的概率,以及所述训练样本集中各个元素对应的标注分类结果;
根据所述各个分类结果对应的概率和所述标注分类结果建立所述损失函数。
6.一种适用于长尾分布数据集的神经网络的训练装置,其特征在于,包括:
获取模块,用于获取训练样本集;
特征提取模块,用于控制所述特征提取网络对所述训练样本集进行特征提取得到特征,并通过所述分类器对所述特征进行分类,根据分类结果建立损失函数;
梯度计算模块,用于根据所述损失函数计算所述特征提取网络中各个神经元在训练样本的梯度;
梯度重加权模块,用于在神经网络训练的反向传播的过程中,所述类别梯度重加权网络计算所述训练样本的重加权梯度权重,根据所述重加权梯度权重调整属于不同类别的训练样本的梯度。
7.如权利要求6所述的装置,其特征在于,所述获取模块,具体用于:
按照类别均衡采样法从训练数据中获取所述训练样本集。
8.如权利要求6所述的装置,其特征在于,所述特征提取模块获取前向传导函数R(x),其中,所述前向传导函数R(x)用于指示所述特征提取网络中各个神经元的传播方向。
9.如权利要求8所述的装置,其特征在于,所述梯度重加权模块,具体用于:
计算所述各个模块的重加权梯度权重的公式为:
Figure FDA0002644890960000021
Figure FDA0002644890960000022
其中,Nc,Nmax分别为所述训练样本集中任一元素x所属的类别和所述训练样本集中数量最多的类别包含的样本数量,β为超参数,I为单位矩阵。
10.如权利要求6所述的训练装置,其特征在于,所述特征提取模块,具体用于:
获取各个分类结果对应的概率,以及所述训练样本集中各个元素对应的标注分类结果;
根据所述各个分类结果对应的概率和所述标注分类结果建立所述损失函数。
CN202010851530.2A 2020-08-21 2020-08-21 适用于长尾分布数据集的神经网络的训练方法和装置 Pending CN112101544A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010851530.2A CN112101544A (zh) 2020-08-21 2020-08-21 适用于长尾分布数据集的神经网络的训练方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010851530.2A CN112101544A (zh) 2020-08-21 2020-08-21 适用于长尾分布数据集的神经网络的训练方法和装置

Publications (1)

Publication Number Publication Date
CN112101544A true CN112101544A (zh) 2020-12-18

Family

ID=73754578

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010851530.2A Pending CN112101544A (zh) 2020-08-21 2020-08-21 适用于长尾分布数据集的神经网络的训练方法和装置

Country Status (1)

Country Link
CN (1) CN112101544A (zh)

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112632319A (zh) * 2020-12-22 2021-04-09 天津大学 基于迁移学习的提升长尾分布语音总体分类准确度的方法
CN112632320A (zh) * 2020-12-22 2021-04-09 天津大学 基于长尾分布提升语音分类尾部识别准确度的方法
CN112966767A (zh) * 2021-03-19 2021-06-15 焦点科技股份有限公司 一种特征提取和分类任务分离的数据不均衡处理方法
CN113095304A (zh) * 2021-06-08 2021-07-09 成都考拉悠然科技有限公司 减弱重采样对行人重识别的影响的方法
CN113255832A (zh) * 2021-06-23 2021-08-13 成都考拉悠然科技有限公司 双分支多中心的长尾分布识别的方法
CN113688990A (zh) * 2021-09-09 2021-11-23 贵州电网有限责任公司 用于电力边缘计算分类神经网络的无数据量化训练方法
CN114283307A (zh) * 2021-12-24 2022-04-05 中国科学技术大学 一种基于重采样策略的网络训练方法
CN114330573A (zh) * 2021-12-30 2022-04-12 济南博观智能科技有限公司 一种目标检测方法、装置、电子设备及存储介质
CN114463576A (zh) * 2021-12-24 2022-05-10 中国科学技术大学 一种基于重加权策略的网络训练方法
CN114596590A (zh) * 2022-03-15 2022-06-07 北京信智文科技有限公司 一种用于具有长尾分布特性的单猴视频动作分类方法
WO2023137921A1 (zh) * 2022-01-21 2023-07-27 平安科技(深圳)有限公司 基于人工智能的实例分割模型训练方法、装置、存储介质

Cited By (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112632320A (zh) * 2020-12-22 2021-04-09 天津大学 基于长尾分布提升语音分类尾部识别准确度的方法
CN112632319A (zh) * 2020-12-22 2021-04-09 天津大学 基于迁移学习的提升长尾分布语音总体分类准确度的方法
CN112632319B (zh) * 2020-12-22 2023-04-11 天津大学 基于迁移学习的提升长尾分布语音总体分类准确度的方法
CN112966767B (zh) * 2021-03-19 2022-03-22 焦点科技股份有限公司 一种特征提取和分类任务分离的数据不均衡处理方法
CN112966767A (zh) * 2021-03-19 2021-06-15 焦点科技股份有限公司 一种特征提取和分类任务分离的数据不均衡处理方法
CN113095304A (zh) * 2021-06-08 2021-07-09 成都考拉悠然科技有限公司 减弱重采样对行人重识别的影响的方法
CN113255832A (zh) * 2021-06-23 2021-08-13 成都考拉悠然科技有限公司 双分支多中心的长尾分布识别的方法
CN113255832B (zh) * 2021-06-23 2021-10-01 成都考拉悠然科技有限公司 双分支多中心的长尾分布识别的方法
CN113688990A (zh) * 2021-09-09 2021-11-23 贵州电网有限责任公司 用于电力边缘计算分类神经网络的无数据量化训练方法
CN114283307A (zh) * 2021-12-24 2022-04-05 中国科学技术大学 一种基于重采样策略的网络训练方法
CN114463576A (zh) * 2021-12-24 2022-05-10 中国科学技术大学 一种基于重加权策略的网络训练方法
CN114283307B (zh) * 2021-12-24 2023-10-27 中国科学技术大学 一种基于重采样策略的网络训练方法
CN114463576B (zh) * 2021-12-24 2024-04-09 中国科学技术大学 一种基于重加权策略的网络训练方法
CN114330573A (zh) * 2021-12-30 2022-04-12 济南博观智能科技有限公司 一种目标检测方法、装置、电子设备及存储介质
WO2023137921A1 (zh) * 2022-01-21 2023-07-27 平安科技(深圳)有限公司 基于人工智能的实例分割模型训练方法、装置、存储介质
CN114596590A (zh) * 2022-03-15 2022-06-07 北京信智文科技有限公司 一种用于具有长尾分布特性的单猴视频动作分类方法

Similar Documents

Publication Publication Date Title
CN112101544A (zh) 适用于长尾分布数据集的神经网络的训练方法和装置
CN111126386B (zh) 场景文本识别中基于对抗学习的序列领域适应方法
Dvornik et al. Selecting relevant features from a multi-domain representation for few-shot classification
Bavkar et al. Multimodal sarcasm detection via hybrid classifier with optimistic logic
CN108960073B (zh) 面向生物医学文献的跨模态图像模式识别方法
CN106919951B (zh) 一种基于点击与视觉融合的弱监督双线性深度学习方法
CN109117793B (zh) 基于深度迁移学习的直推式雷达高分辨距离像识别方法
CN107944410B (zh) 一种基于卷积神经网络的跨领域面部特征解析方法
AU2020100052A4 (en) Unattended video classifying system based on transfer learning
US20230153577A1 (en) Trust-region aware neural network architecture search for knowledge distillation
CN109993236A (zh) 基于one-shot Siamese卷积神经网络的少样本满文匹配方法
CN113128620B (zh) 一种基于层次关系的半监督领域自适应图片分类方法
CN111783841A (zh) 基于迁移学习和模型融合的垃圾分类方法、系统及介质
US20220121949A1 (en) Personalized neural network pruning
CN111898685A (zh) 一种基于长尾分布数据集的目标检测方法
CN113011487B (zh) 一种基于联合学习与知识迁移的开放集图像分类方法
CN111723874A (zh) 一种基于宽度和深度神经网络的声场景分类方法
CN114882521A (zh) 基于多分支网络的无监督行人重识别方法及装置
CN113590876A (zh) 一种视频标签设置方法、装置、计算机设备及存储介质
CN111126361A (zh) 基于半监督学习和特征约束的sar目标鉴别方法
CN111832580B (zh) 结合少样本学习与目标属性特征的sar目标识别方法
CN112883931A (zh) 基于长短期记忆网络的实时真假运动判断方法
WO2023091428A1 (en) Trust-region aware neural network architecture search for knowledge distillation
CN117671246A (zh) 一种基于交叉验证识别机制的开放词表目标检测算法
WO2021072338A1 (en) Learned threshold pruning for deep neural networks

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20201218

RJ01 Rejection of invention patent application after publication