CN108460464A - 深度学习训练方法及装置 - Google Patents

深度学习训练方法及装置 Download PDF

Info

Publication number
CN108460464A
CN108460464A CN201710094563.5A CN201710094563A CN108460464A CN 108460464 A CN108460464 A CN 108460464A CN 201710094563 A CN201710094563 A CN 201710094563A CN 108460464 A CN108460464 A CN 108460464A
Authority
CN
China
Prior art keywords
training data
threshold value
predetermined threshold
data example
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201710094563.5A
Other languages
English (en)
Inventor
高燕
吕达
罗圣美
李伟华
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
ZTE Corp
Original Assignee
ZTE Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by ZTE Corp filed Critical ZTE Corp
Priority to CN201710094563.5A priority Critical patent/CN108460464A/zh
Priority to PCT/CN2018/073955 priority patent/WO2018153201A1/zh
Publication of CN108460464A publication Critical patent/CN108460464A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N99/00Subject matter not provided for in other groups of this subclass

Abstract

本发明公开了一种深度学习训练方法及装置,用以解现有深度学习领域中深度学习模型收敛较慢的问题。所述方法包括:在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例;放弃学习非困难实例的特征,学习所述所有困难实例的特征。本发明中训练方法及装置通过计算训练迭代中训练数据实例的损失值,获得对该次迭代具有较大作用数据实例,并用于对模型进行训练;也就是说集中训练困难实例,加快了模型的收敛速度。

Description

深度学习训练方法及装置
技术领域
本发明涉及智能学习领域,特别是涉及一种深度学习训练方法及装置。
背景技术
随着网络信息技术的发展,信息数据的存储和传播越来越便捷,人们可以方便地获得大量的信息数据用于学习、工作和生活。目前已进入大数据时代,数以亿计的数据,加之不断提高的计算能力,使得一度进入冰河期的神经网络领域开始再度复苏,深度学习(多层神经网络)掀起新一轮的热潮。
目前,深度学习是人工智能领域中研究重点,大量的学者和研究人员投身其中,推动着其迅速发展。尽管深度学习取得了极大的成就,但其依旧面临着很多难题。相比传统方法,更多的数据和更深的网络结构是深度学习最大的特色,也是其取得成功的关键。但这也意味着深度学习往往需要更大的训练存储空间和时间;训练一个深度学习的模型往往需要数天乃至数个月的时间,因而加速训练过程,节约时间成本是当下的一个重要研究方向。
对于加速训练,现有技术中一般采用在硬件方面采用GPU加速和集群计算,在算法上采用数据并行和模型并行方案。现有方案虽然加快了深度网络的训练迭代速度,但仍然面临着模型收敛较慢的问题。
发明内容
为了克服上述现有技术的缺陷,本发明要解决的技术问题是提供一种深度学习训练方法及装置,用以解现有深度学习领域中深度学习模型收敛较慢的问题。
为解决上述技术问题,本发明中的一种深度学习训练方法,包括:
在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例;
放弃学习非困难实例的特征,学习所述所有困难实例的特征。
可选地,所述根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例,包括:
针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
可选地,所述在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例之前,还包括:
在每次迭代训练的前向传播过程中,确定所述批量训练数据实例中每个训练数据实例的损失值。
可选地,所述根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例之后,还包括:
确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值θ2的大小关系;
若所述损失平均值超过所述预设阈值θ2,则放弃学习非困难实例的特征,学习所述所有困难实例的特征;
若所述损失平均值未超过所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
具体地,所述预设阈值θ2小于所述预设阈值θ1
具体地,所述方法还包括:
针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2
具体地,所述学习所述所有困难实例的特征,还包括:
在学习时,将各困难实例的损失值反向传播;
根据各损失值调整用于训练的网络参数。
为解决上述技术问题,本发明中的一种深度学习训练装置,包括:
实例选择模块,用于在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例;
学习模块,用于放弃学习非困难实例的特征,学习所述所有困难实例的特征。
可选地,所述实例选择模块,具体用于针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
可选地,所述装置还包括:
损失确定模块,用于在每次迭代训练的前向传播过程中,确定所述批量训练数据实例中每个训练数据实例的损失值
可选地,所述装置还包括:
判断模块,用于确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值θ2的大小关系;
若所述损失平均值超过所述预设阈值θ2,则触发所述学习模块放弃学习非困难实例的特征,学习所述所有困难实例的特征;
若所述损失平均值不小于所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
具体地,所述预设阈值θ2小于所述预设阈值θ1
具体地,所述装置还包括:
阈值设置模块,用于针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2
具体地,所述装置还包括:
参数调整模块,用于在学习时,将各困难实例的损失值反向传播;
根据各损失值调整用于训练的网络参数。
本发明有益效果如下:
本发明中训练方法及装置通过计算训练迭代中训练数据实例的损失值,获得对该次迭代具有较大作用数据实例,并用于对模型进行训练;也就是说集中训练困难实例,加快了模型的收敛速度;同时,学习训练过程忽略了无用数据实例,有效地改善了实际问题中训练数据不平衡的问题。本发明实施例通过对模型训练数据的分析,对现有的训练学习方法进行改进,可结合现有各种优化求解方法使用,并可以融合进当前的各个深度学习框架中。
附图说明
图1是本发明实施例中一种深度学习训练方法的主流程图;
图2是本发明实施例中一种深度学习训练方法的详细流程图;
图3是本发明实施例中一种深度学习训练装置的结构示意图。
具体实施方式
对于深度学习的网络训练而言,加快网络收敛相较于单纯加速更为重要。因此基于训练数据考虑,为了解决现有深度学习领域中深度学习模型收敛较慢的问题,本发明提供了一种深度学习训练方法及装置,以下结合附图以及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不限定本发明。
本发明实施例中一种深度学习训练方法,包括:
S101,在每次迭代训练的前向传播过程中,确定批量训练数据实例中每个训练数据实例的损失值。详细说,本步骤可以包括:
步骤1011,根据任务要求,获取足量训练样本(即训练数据实例或数据实例),并对所获取的训练样本进行筛选、处理、增强、均衡、标记标签等操作,构建训练样本集。
步骤1012,选定深度网络模型结构,设定相应的训练参数,初始化深度网络模型。
步骤1013,将一定数量的训练样本组成一个batch(批量训练数据实例)送入深度网络进行计算,得到此batch中每个样本数据的分类计算值Xc。
步骤1014,对比每个样本的真实标签XT,计算每个样本的Loss(损失)值
L。其中,计算Loss值L的方法为:
L=-log[softmax(ak)]k为该实例的真实类别 (1)
其中,a为类别概率,softmax(ak)为交叉损失函数。
S102,根据各训练数据实例的损失值,从所述批量训练数据实例中确定出所有困难实例;
S103,放弃学习非困难实例的特征,学习所述所有困难实例的特征。
本发明实施例通过计算训练迭代中训练数据实例的损失(数据实例实际输出与理想输出的差距)值,获得对该次迭代具有较大作用数据实例(即困难实例),并用于对模型进行训练;也就是说集中训练困难实例,加快了模型的收敛速度;同时,学习训练过程忽略了无用数据实例(即非困难实例),有效地改善了实际问题中训练数据不平衡的问题。本发明实施例通过对模型训练数据的分析,对现有的训练学习方法进行改进,可结合现有各种优化求解方法使用,并可以融合进当前的各个深度学习框架中。
在上述实施例的基础上,进一步提出上述实施例的变型实施例,在此需要说明的是,为了使描述简要,在各变型实施例中仅描述与上述实施例的不同之处。
在本发明的一个实施例中,所述根据各训练数据实例的损失值,从所述批量训练数据实例中确定出所有困难实例,包括:
针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
详细说,将batch中的每个训练样本的Loss与阈值θ1进行对比,若L超过阈值θ1,则认为此训练样本为困难实例,用于本次学习,反之则将其舍弃。
本发明实施例进一步加速了深度学习模型的收敛。
在本发明的另一个实施例中,所述根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例之后,还包括:
确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值的大小关系;
若所述损失平均值超过所述预设阈值θ2,则放弃学习非困难实例的特征,学习所述所有困难实例的特征;进一步说,在学习时,将各困难实例的损失值反向传播;根据各损失值调整用于训练的网络参数。
若所述损失平均值未超过所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
其中,所述预设阈值θ2小于所述预设阈值θ1
详细说,将一个batch训练样本的Loss均值Lavg与阈值θ2进行对比,若Lavg超过阈值,则认为此batch中绝大多数训练样本为困难实例,将Loss值进行反向传播,微调网络参数,对模型进行训练;若Lavg未超过阈值,则认为此batch训练样本几乎均为为非困难实例,所得Loss值不进行反向传播,舍弃该batch,阻止模型学习此batch中训练样本特征,进一步加速。
其中,Loss均值Lavg为batch样本中所有样本的loss值之和除以每个batch中样本数量N;即,
其中,a为类别概率,softmax(ai_k)为交叉损失函数。
进一步说,所述方法还包括:
针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2
详细说,对于阈值θ1和θ2由选取的loss计算公式和batch的大小确定。
θ1=-log(a)a∈(0.9,1) (3)
其中,a为类别概率,θ1为单个样本评价阈值,θ2为整个batch的评价阈值,N为一个batch中的样本数量。
本发明实施例设计了基于数据分析的深度学习加速收敛方法,可应用于各个深度学习开源框架。该方法主要包括数据预处理和深度学习训练。其中,在数据预处理部分,通过运用各种图像变换方法进行数据增强,从而极大地扩充了数据,并增加了数据的多样性。在深度学习训练部分,结合支持向量思想,通过对数据的损失分析,加速了收敛。
本发明实施例基于对训练过程中的数据分析,根据每次迭代中数据的损失大小,使得训练集中在困难数据实例上(损失大),从而加快了收敛的速度。相比于现有对于训练数据不加区分的学习方法,本发明实施例通过训练数据的损失对数据加以区分,使得训练更具有针对性。同时网络训练方法将所有的数据都用于学习,从而导致实际运用中训练数据的不平衡问题,则使得学习模型的训练倾向于数据量更多的数据类别,而本发明实施例则对该问题起到了遏制作用,一定程度上提升了训练效果。
举一具体应用例,详细说明本发明中方法。
实验数据采用ImageNet数据集,数据集训练图片共120万张,分为1000类,毎类1200张样本。对于ImageNet图像识别竞赛的分类任务,采用本发明中方法进行实现,同时与现有Caffe(卷积神经网络框架)开源框架训练方法进行对比。
详细说,如图2所示,本发明中方法主要分两大过程:数据预处理、深度学习训练。下面结合该实验分别说明每个步骤的具体实现。
①据预处理
数据预处理是进行数据分析、学习任务的必要过程。对于本实验而言,数据的分类、标注等任务数据集中已完成,因而所需的关键就在于数据增强。对样本进行数据增强(例如随机裁剪,镜像等增强方法)。图像分辨率调整至256×256,最终将数据保存为lmdb文件格式,供Caffe调用。
②度学习的学习训练
本发明中方法主要是针对本过程进行改进,依据训练数据实例损失大小区分数据进行迭代学习。主要涉及通过深度网络训练得到深度模型。
具体说,在通过深度网络(本文中可以简称网络)训练过程中包括如下步骤:
(1)根据任务要求,获取足量训练样本,并对所获取的训练样本进行筛选、处理、增强、均衡、标记标签等操作,构建训练样本集。
(2)选定深度网络模型结构,设定相应的训练参数,初始化深度网络模型。
(3)将一定数量的训练样本组成一个batch送入网络进行计算,得到此batch中每个样本数据的分类计算值Xc
(4)对比每个样本的真实标签XT,计算Loss(损失)值L。将batch中的每个训练样本的Loss与阈值θ1进行对比,若L超过阈值,则认为此训练样本为困难实例,用于本次学习,反之则将其舍弃。
损失计算公式有多种,本实验采用分类最常用的SoftmaxLoss进行介绍。
SoftmaxLoss是以Softmax函数作为交叉损失函数输入,计算公式如下:
Softmax的计算结果等于一个数据实例属于各个类别的概率。
进一步根据上述公式(2)可以计算出该数据实例的损失。
(5)计算整个batch中所有样本数据的Loss均值Lavg
(6)将batch训练样本中困难实例的Loss均值Lavg与阈值θ2进行对比,若Lavg超过阈值,则认为此batch中绝大多数训练样本为困难实例,将Loss值进行反向传播,微调网络参数,对模型进行训练;若Lavg未超过阈值,则认为此batch训练样本几乎均为为非困难实例,所得Loss值不进行反向传播,舍弃该batch,阻止模型学习此batch中训练样本特征,进一步加速。
由于θ1是根据单个实例的损失判定阈值,其确定方式根据上述公式(4)得来,a为类别概率。本次实验中设定a为0.99,计算得到θ1取值0.01。
θ2用于判定批量数据的平均损失,考虑防止个别实例损失值较小影响整体平均损失影响,θ2应小于θ1,且随着数据批量大小N的增大,该影响逐渐较小,θ2也不断接近θ1,因而采用上述公式(5)确定,计算得到θ2值为9.9×10-3
(7)若未达到终止条件,则返回步骤(3)继续训练。达到终止条件,接收学习过程。
综上,本发明中方法训练部分通过对单个数据实例和批量数据实例控制,实现将训练学习集中于困难实例。单个数据实例部分通过公式(2)计算出的损失值与阈值θ1比较,若大于阈值θ1,则该实例用于训练学习;反之,本次迭代中忽略该数据实例,即其反向传播梯度为0。对于批量数据实例控制部分,通过将整个批量数据的损失与阈值θ2比较,若大于阈值θ2,则执行反向传播,反之,则取消,即该批量数据不用于学习。
实验结果显示,原训练方法在4367次迭代后,loss开始下降,逐渐收敛;而使用本发明的方法后,在进行到第78次迭代后,loss开始下降,加速收敛效果明显。
本发明进一步提出一种深度学习训练装置,包括:
损失确定模块310,用于在每次迭代训练的前向传播过程中,确定批量训练数据实例中每个训练数据实例的损失值;
实例选择模块320,用于根据各训练数据实例的损失值,从所述批量训练数据实例中确定出所有困难实例;
学习模块330,用于放弃学习非困难实例的特征,学习所述所有困难实例的特征。
本发明实施例通过计算训练迭代中训练数据实例的损失值,获得对该次迭代具有较大作用数据实例,并用于对模型进行训练;也就是说集中训练困难实例,加快了模型的收敛速度;同时,学习训练过程忽略了无用数据实例,有效地改善了实际问题中训练数据不平衡的问题。本发明实施例通过对模型训练数据的分析,对现有的训练学习方法进行改进,可结合现有各种优化求解方法使用,并可以融合进当前的各个深度学习框架中。
在本发明的一个实施例中,所述实例选择模块320,具体用于针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
在本发明的另一个实施例中,所述装置还包括:
判断模块,用于确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值θ2的大小关系;
若所述损失平均值超过所述预设阈值θ2,则触发所述学习模块放弃学习非困难实例的特征,学习所述所有困难实例的特征;
若所述损失平均值未超过所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
其中,所述预设阈值θ2小于所述预设阈值θ1
进一步说,所述装置还包括:
阈值设置模块,用于针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2。和/或
参数调整模块,用于在学习时,将各困难实例的损失值反向传播;
根据各损失值调整用于训练的网络参数。
结合本申请所公开示例描述的方法,可直接体现为硬件、由处理器执行的软件模块或者二者结合。例如,附图中所示功能框图中的一个或多个功能框图和/或功能框图的一个和/或多个组合,既可以对应于计算机程序流程的各个软件模块,亦可以对应于各个硬件模块。这些软件模块,可以分别对应于附图所示的各个步骤。这些硬件模块例如可利用现场可编程门阵列(FPGA)将这些软件模块固化而实现。
软件模块可以位于RAM存储器、闪存、ROM存储器、EPROM存储器、EEPROM存储器、寄存器、硬盘、移动硬盘、CD-ROM或者本领域已知的任何其他形式的存储介质。可以将一种存储介质藕接至处理器,从而使处理器能够从该存储介质读取信息,且可向该存储介质写入信息;或者该存储介质可以是处理器的组成部分。处理器和存储介质可以位于专用集成电路中。该软件模块可以存储在移动终端的存储器中,也可以存储在可插入移动终端的存储卡中。例如,若移动终端采用的是较大容量的MEGA-SIM卡或者大容量的闪存装置,则该软件模块可存储在该MEGA-SIM卡或者大容量的闪存装置中。
针对附图中描述的功能框图中的一个或多个和/或功能框图的一个或多个组合,可以实现为用于执行本申请所描述功能的通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或晶体管逻辑器件、分立硬件组件或者其任意适当组合。针对附图中描述的功能框图中的一个或多个和/或功能框图的一个或多个组合,还可以实现为计算机设备的组合,例如,DSP和微处理器的组合、多个微处理器、与DSP通信结合的一个或多个微处理器或者任何其他这种配置。
虽然本申请描述了本发明的特定示例,但本领域技术人员可以在不脱离本发明概念的基础上设计出来本发明的变型。
本领域技术人员在本发明技术构思的启发下,在不脱离本发明内容的基础上,还可以对本发明做出各种改进,这仍落在本发明的保护范围之内。

Claims (14)

1.一种深度学习训练方法,其特征在于,所述方法包括:
在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例;
放弃学习非困难实例的特征,学习所述所有困难实例的特征。
2.如权利要求1所述的方法,其特征在于,所述根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例,包括:
针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
3.如权利要求1所述的方法,其特征在于,所述在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例之前,还包括:
在每次迭代训练的前向传播过程中,确定所述批量训练数据实例中每个训练数据实例的损失值。
4.如权利要求1-3中任意一项所述的方法,其特征在于,所述根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例之后,还包括:
确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值θ2的大小关系;
若所述损失平均值超过所述预设阈值θ2,则放弃学习非困难实例的特征,学习所述所有困难实例的特征;
若所述损失平均值未超过所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
5.如权利要求4所述的方法,其特征在于,所述预设阈值θ2小于所述预设阈值θ1
6.如权利要求4所述的方法,其特征在于,所述方法还包括:
针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2
7.如权利要求4所述的方法,其特征在于,所述学习所述所有困难实例的特征,还包括:
在学习时,将各困难实例的损失值反向传播;
根据各损失值调整用于训练的网络参数。
8.一种深度学习训练装置,其特征在于,所述装置包括:
实例选择模块,用于在每次迭代训练中,根据各训练数据实例的损失值,从批量训练数据实例中确定出所有困难实例;
学习模块,用于放弃学习非困难实例的特征,学习所述所有困难实例的特征。
9.如权利要求8所述的装置,其特征在于,所述实例选择模块,具体用于针对任一训练数据实例,对比该训练数据实例的损失值和预设阈值θ1的大小关系;若该损失值不小于所述预设阈值θ1,则该训练数据实例为困难实例;
遍历所述批量训练数据实例,对比出所有困难实例。
10.如权利要求8所述的装置,其特征在于,所述装置还包括:
损失确定模块,用于在每次迭代训练的前向传播过程中,确定所述批量训练数据实例中每个训练数据实例的损失值。
11.如权利要求8-10中任意一项所述的装置,其特征在于,所述装置还包括:
判断模块,用于确定所述批量训练数据实例的损失平均值;
对比所述损失平均值和预设阈值θ2的大小关系;
若所述损失平均值超过所述预设阈值θ2,则触发所述学习模块放弃学习非困难实例的特征,学习所述所有困难实例的特征;
若所述损失平均值未超过所述预设阈值θ2,则放弃学习所述批量训练数据实例的特征。
12.如权利要求11所述的装置,其特征在于,所述预设阈值θ2小于所述预设阈值θ1
13.如权利要求11所述的装置,其特征在于,所述装置还包括:
阈值设置模块,用于针对任一训练数据实例,根据该训练数据实例的类别概率,确定该训练数据实例的预设阈值θ1
根据任一训练数据实例预设阈值θ1,确定所述预设阈值θ2
14.如权利要求11所述的装置,其特征在于,所述装置还包括:
参数调整模块,用于在学习时,将各困难实例的损失值反向传播;
根据各损失值调整用于训练的网络参数。
CN201710094563.5A 2017-02-22 2017-02-22 深度学习训练方法及装置 Pending CN108460464A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN201710094563.5A CN108460464A (zh) 2017-02-22 2017-02-22 深度学习训练方法及装置
PCT/CN2018/073955 WO2018153201A1 (zh) 2017-02-22 2018-01-24 深度学习训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201710094563.5A CN108460464A (zh) 2017-02-22 2017-02-22 深度学习训练方法及装置

Publications (1)

Publication Number Publication Date
CN108460464A true CN108460464A (zh) 2018-08-28

Family

ID=63222016

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201710094563.5A Pending CN108460464A (zh) 2017-02-22 2017-02-22 深度学习训练方法及装置

Country Status (2)

Country Link
CN (1) CN108460464A (zh)
WO (1) WO2018153201A1 (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109858411A (zh) * 2019-01-18 2019-06-07 深圳壹账通智能科技有限公司 基于人工智能的案件审判方法、装置及计算机设备
WO2021057186A1 (zh) * 2019-09-24 2021-04-01 华为技术有限公司 训练神经网络的方法、数据处理方法和相关装置

Families Citing this family (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10504027B1 (en) * 2018-10-26 2019-12-10 StradVision, Inc. CNN-based learning method, learning device for selecting useful training data and test method, test device using the same
CN110659678B (zh) * 2019-09-09 2023-11-17 腾讯科技(深圳)有限公司 一种用户行为分类方法、系统及存储介质
CN111400915A (zh) * 2020-03-17 2020-07-10 桂林理工大学 一种基于深度学习的砂土液化判别方法及装置
CN113538079A (zh) * 2020-04-17 2021-10-22 北京金山数字娱乐科技有限公司 一种推荐模型的训练方法及装置、一种推荐方法及装置
CN113420792A (zh) * 2021-06-03 2021-09-21 阿波罗智联(北京)科技有限公司 图像模型的训练方法、电子设备、路侧设备及云控平台
CN115100249B (zh) * 2022-06-24 2023-08-04 王世莉 一种基于目标跟踪算法的智慧工厂监控系统
CN116610960B (zh) * 2023-07-20 2023-10-13 北京万界数据科技有限责任公司 一种人工智能训练参数的监测管理系统

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103593474B (zh) * 2013-11-28 2017-03-01 中国科学院自动化研究所 基于深度学习的图像检索排序方法
US9380224B2 (en) * 2014-02-28 2016-06-28 Microsoft Technology Licensing, Llc Depth sensing using an infrared camera
CN104992223B (zh) * 2015-06-12 2018-02-16 安徽大学 基于深度学习的密集人数估计方法
CN105608450B (zh) * 2016-03-01 2018-11-27 天津中科智能识别产业技术研究院有限公司 基于深度卷积神经网络的异质人脸识别方法
CN106096538B (zh) * 2016-06-08 2019-08-23 中国科学院自动化研究所 基于定序神经网络模型的人脸识别方法及装置

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109858411A (zh) * 2019-01-18 2019-06-07 深圳壹账通智能科技有限公司 基于人工智能的案件审判方法、装置及计算机设备
WO2021057186A1 (zh) * 2019-09-24 2021-04-01 华为技术有限公司 训练神经网络的方法、数据处理方法和相关装置

Also Published As

Publication number Publication date
WO2018153201A1 (zh) 2018-08-30

Similar Documents

Publication Publication Date Title
CN108460464A (zh) 深度学习训练方法及装置
CN107506799B (zh) 一种基于深度神经网络的开集类别发掘与扩展方法与装置
CN110674714B (zh) 基于迁移学习的人脸和人脸关键点联合检测方法
CN109376242B (zh) 基于循环神经网络变体和卷积神经网络的文本分类方法
CN108416440A (zh) 一种神经网络的训练方法、物体识别方法及装置
WO2021238262A1 (zh) 一种车辆识别方法、装置、设备及存储介质
Peng et al. Accelerating minibatch stochastic gradient descent using typicality sampling
CN111461226A (zh) 对抗样本生成方法、装置、终端及可读存储介质
CN106202032A (zh) 一种面向微博短文本的情感分析方法及其系统
CN108960301B (zh) 一种基于卷积神经网络的古彝文识别方法
Montalbo et al. Classification of fish species with augmented data using deep convolutional neural network
CN107292352A (zh) 基于卷积神经网络的图像分类方法和装置
CN108427665A (zh) 一种基于lstm型rnn模型的文本自动生成方法
CN110321967A (zh) 基于卷积神经网络的图像分类改进算法
CN112766399B (zh) 一种面向图像识别的自适应神经网络训练方法
CN108846120A (zh) 用于对文本集进行分类的方法、系统及存储介质
CN111653275A (zh) 基于lstm-ctc尾部卷积的语音识别模型的构建方法及装置、语音识别方法
CN110008961A (zh) 文字实时识别方法、装置、计算机设备及存储介质
CN110287985A (zh) 一种基于带变异粒子群算法的可变拓扑结构的深度神经网络图像识别方法
CN112883931A (zh) 基于长短期记忆网络的实时真假运动判断方法
CN113822434A (zh) 用于知识蒸馏的模型选择学习
CN117037006B (zh) 一种高续航能力的无人机跟踪方法
CN114049527A (zh) 基于在线协作与融合的自我知识蒸馏方法与系统
CN111783688B (zh) 一种基于卷积神经网络的遥感图像场景分类方法
CN108470212A (zh) 一种能利用事件持续时间的高效lstm设计方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
WD01 Invention patent application deemed withdrawn after publication
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20180828