CN112633359A - 一种基于梯度平衡的多类别模型训练方法、介质及设备 - Google Patents
一种基于梯度平衡的多类别模型训练方法、介质及设备 Download PDFInfo
- Publication number
- CN112633359A CN112633359A CN202011509570.5A CN202011509570A CN112633359A CN 112633359 A CN112633359 A CN 112633359A CN 202011509570 A CN202011509570 A CN 202011509570A CN 112633359 A CN112633359 A CN 112633359A
- Authority
- CN
- China
- Prior art keywords
- gradient
- training
- model
- loss function
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
Abstract
本发明提供了一种基于梯度平衡的多类别模型训练方法、介质及设备,其中模型训练方法包括通过将训练样本数据输入到选定的神经网络模型进行训练获得损失函数;统计训练模型时损失函数的梯度分布;根据分布结果分配样本权重;权重的平滑处理;权重的衰减处理;获取新梯度更新网络模型;本发明通过在训练的过程中根据输入训练样本数据得到的损失函数梯度的分布,调整样本分配的权重,平衡不同程度难易样本对模型的影响,缩短了模型训练时间同时提升了模型的精度。
Description
技术领域
本发明涉及人工智能图像处理技术领域,特别涉及一种基于梯度平衡的多类别模型训练方法、介质及设备。
背景技术
随着大数据时代的到来以及计算能力的不断提高,基于大数据驱动的深度学习技术得到了飞速发展。深度学习在很多应用领域如图像分类、目标检测(典型应用如人脸识别、行人识别、车辆识别等)、图像分割等都有着广泛的应用。这些领域的应用背后是近几年来涌现出许多优秀的神经网络结构。
可以说:成熟的应用模型=大量数据+优秀的网络结构+合适的训练方法。通过合适的训练方法和大量的数据,可以使网络模型从数据中自动学习到有用的特征和知识,从而完成相应的任务目标。
在近年来,大量的公开数据集被建立。对于大规模图像多分类任务,典型的数据集有ImageNet,COCO,VOC,OpenImage。大量优秀的网络结构被设计出来,并且在这些数据集上的有着优秀的表现。
目前绝大多数网络的训练方式还是基于随机梯度下降的训练方式。这种训练方式非常有效,然而其潜力还没有被彻底发挥出来。其中一个重要的原因就是没有利用不同难易程度的样本。对于一些任务来说,数量过多的简单样本会覆盖掉数量较少的困难样本对模型的影响,导致模型难以学习困难样本;同样,过多关注困难样本也会降低模型对简单样本的学习能力。因此目前还没有一种训练策略能够同时兼顾学习简单样本和困难样本,这也侧面说明当前的主流训练方式还需要进一步提升。
发明内容
为解决上述技术问题,本发明提供了一种基于梯度平衡的多类别模型训练方法、介质及设备,将数据分批次输入到模型中进行训练,通过在训练的过程中,根据输入样本对应的损失函数梯度分布的统计结果,对样本分配的权重进行平滑和衰减处理,平衡不同程度难易样本对模型的影响。该方法能够更好的挖掘数据和模型的能力,在相同的训练数据和网络模型下,能够大幅缩短模型训练的时间,同时提升训练得到的模型精度。
本发明提供了一种基于梯度平衡的多类别模型训练方法,具体技术方案如下:
S1:通过将训练样本数据输入到选定的神经网络模型进行训练获得损失函数;
S2:统计训练模型时损失函数的梯度;
通过前向传播的方式训练构建好的网络模型,计算损失函数,并对损失函数的梯度进行统计;
S3:分配样本权重;
根据获得的梯度分布统计结果,对每个样本分配相应的权重;
S4:权重的平滑;
对获得的样本对应的权重进行平滑处理,缩小过大的权重;
S5:平滑后权重的衰减处理;
在训练过程中对平滑项进行衰减,使得训练结束后平滑项衰减到0
S6:获取新梯度更新网络模型;
将经过平滑和衰减处理的权重矩阵与原损失函数梯度相乘获得新的梯度,并通过反向传播更新网络模型的参数。
进一步的,步骤S1中所述神经网络模型为基于链式法则和随机梯度下降训练方式的机器学习模型,训练样本数据划分为若干个批次输入到模型中。
进一步的,步骤S2中,对所述损失函数的梯度进行统计获取统计样本数据,所述统计方式以每次模型迭代后的梯度进行独立统计。
进一步的,所述损失函数的梯度统计时可采用区间划分统计,将梯度平均划分为若干个区间。
进一步的,步骤S3中,对得到的梯度分布统计结果进行样本的权重计算,将所述梯度的分布统计结果取倒数获得分布矩阵,根据的得到的分布矩阵通过归一化得到样本的权重矩阵。
进一步的,步骤S4中,通过对权重设定小于1的幂对得到的权重矩阵进行平滑处理。
进一步的,步骤S5中,在训练的过程中可采用指数衰减方式将平滑项衰减至0。
本发明还提供了一种计算机可读存储介质,存储有计算机程序,该程序被处理器执行时实现上述所述的基于梯度平衡的多类别模型训练方法。
本发明还提供了一种电子设备,包括:
存储器,存储有算法的计算机程序;
处理器,与所述存储器数据连接,调用所述计算机程序时执行权利要求1-7任一项所述的基于梯度平衡的多类别模型训练方法。
显示器,与所述处理器和所述存储器数据连接,所述显示所述基于梯度平衡的多类别模型训练方法相关的操作交互界面。
本发明的有益效果如下:
1、该方法将训练样本数据分批次输入,获得损失函数的梯度分布,根据梯度分布分配样本权重,通过对权重的平滑和衰减处理,减小不同困难程度样本间的权重差异,避免随着训练过程中难易样本间数量差异的增加影响模型的训练效果,该方法充分利用了训练样本数据,在相同的训练数据集和网络模型下,大大提高了训练得到的模型的泛化能力和准确性,同时大幅缩短了模型的训练时间,提高了模型的精度。
2、该方法实现了梯度的自平衡,能够适用于卷积神经网络模型、循环神经网络模型等任何基于链式法则和随机梯度下降训练方式的机器学习模型的训练。
附图说明
图1是本发明的方法流程示意图;
图2是本发明的电子设备结构示意图。
具体实施方式
在下面的描述中对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
实施例1
本发明的实施例提供了一种基于梯度平衡的多类别模型训练方法,该方法能够广泛应用在计算机设备中,如个人电脑、云端服务器,计算集群中,或者其它支持可执行模型训练的电子设备中。
如图1所示,所述训练方法包括如下步骤:
S1:通过将训练样本数据分为若干个批次输入到选定的神经网络模型进行训练获得损失函数;
其中,本实施例采用具有多种类别且数据量较大的OpenImage数据集作为训练样本,所述神经网络模型为基于链式法则和随机梯度下降训练方式的机器学习模型,例如VGG、Inception,ResNet等卷积神经网络模型。
S2:统计训练模型时损失函数的梯度;
通过前向传播的方式训练构建好的网络模型,计算损失函数,并对损失函数的梯度进行统计;对梯度的统计方式包括以每一次迭代后的梯度进行独立统计、对每一次迭代后的梯度进行累加统计、对每一个epoch的梯度进行统计等方式;本实施例中以每次模型迭代后的梯度进行独立统计。
对获得的梯度分布统计可采用多种方法,根据不同的问题选取不同的梯度分布统计方法,本实施例中以多标签分类问题举例,因而采用损失函数为交叉熵损失函数,将梯度按照大小分成多个区间,统计在训练过程中某个区间上的梯度的分布,对每一个类别单独做梯度的统计,这样在训练中可以更有针对性的平衡每个类别中的难易样本的影响,区间的划分方式包括平均间隔划分区间、聚类区间、分段划分等方式。
具体步骤如下:
S202:每一批次输入的训练数据样本会得到对应的导数,对每个类别单独统计导数分布,根据各个区间的导数数量,建立大小为m×n的分布矩阵K进行统计,每输入一匹训练数据样本到模型中,更新分布矩阵K。
S3:分配样本权重;
根据获得的梯度分布统计结果,对每个样本分配相应的权重,对于类别k,其分布为:
a1k,a2k,…,amk
其中aik表示类别k导数在第i个区间内的数量。
根据类别k的分布计算出类别k在权重矩阵W中的行向量为:
[w1k,w2k,…,wmk]
其中wik表示所述类别k在对应i的权重,所述wik的计算公式如下:
wik=(a1k+a2k+…+amk)/(aik×m)
S4:权重的平滑;
对获得的样本对应的权重进行平滑处理,缩小过大的权重,使得网络训练的更平滑,对权重进行平滑处理可通过给权重设定最大值,超出最大值的权重将会被强制设定为最大值;还可通过给权重设定小于1的幂,对权重进行平滑,本实施采用对权重设定小于1的幂对上述获得的权重矩阵进行平滑处理,其中幂指数设定为0.5;
S5:平滑后权重的衰减处理;
在训练过程中对平滑项进行衰减,使得训练结束后平滑项衰减到0,平滑项Si的衰减公式为:
Si=0.5-0.5×i/Z
其中Z为模型训练是设定的总训练步数。
S6:获取新梯度更新网络模型;
实施例2
基于上述模型训练方法,本发明的实施例2提供了一种计算机可读存储介质,存储有计算机程序,该程序被处理器执行时实现所述的基于梯度平衡的多类别模型训练方法。
实施例3
基于上述模型训练方法,本发明的实施例3提供了一种电子设备,如图2所示,所述设备包括存储器、处理器和显示器;
所述存储器,存储有算法的计算机程序;
所述处理器,与所述存储器数据连接,调用所述计算机程序时执行权利要求1-7任一项所述的基于梯度平衡的多类别模型训练方法。
所述显示器,与所述处理器和所述存储器数据连接,所述显示所述基于梯度平衡的多类别模型训练方法相关的操作交互界面。
本发明并不局限于前述的具体实施方式。本发明扩展到任何在本说明书中披露的新特征或任何新的组合,以及披露的任一新的方法或过程的步骤或任何新的组合。
Claims (9)
1.一种基于梯度平衡的多类别模型训练方法,其特征在于,方法包括如下步骤:
S1:通过将训练样本数据输入到选定的神经网络模型进行训练获得损失函数;
S2:统计训练模型时损失函数的梯度分布;
通过前向传播的方式训练构建好的网络模型,计算损失函数,并对损失函数的梯度进行统计;
S3:分配样本权重;
根据获得的梯度分布统计结果,对每个样本分配相应的权重;
S4:权重的平滑;
对获得的样本对应的权重进行平滑处理,缩小过大的权重;
S5:平滑后权重的衰减处理;
在训练过程中对平滑项进行衰减,使得训练结束后平滑项衰减到0;
S6:获取新梯度更新网络模型;
将经过平滑和衰减处理的权重矩阵与原损失函数梯度相乘获得新的梯度,并通过反向传播更新网络模型的参数。
2.根据权利要求1所述的基于梯度平衡的多类别模型训练方法,其特征在于,步骤S1中所述神经网络模型为基于链式法则和随机梯度下降训练方式的机器学习模型,训练样本数据划分为若干个批次输入到模型中。
3.根据权利要求1所述的基于梯度平衡的多类别模型训练方法,其特征在于,步骤S2中,对所述损失函数的梯度进行统计获取统计样本数据,所述统计方式以每次模型迭代后的梯度进行独立统计。
4.根据权利要求2所述的基于梯度平衡的多类别模型训练方法,其特征在于,所述损失函数的梯度统计时可采用的统计方法采用区间划分统计,将梯度平均划分为若干个区间。
5.根据权利要求1所述的基于梯度平衡的多类别模型训练方法,其特征在于,步骤S3中,对得到的梯度分布统计结果进行样本的权重计算,将所述梯度的分布统计结果取倒数获得分布矩阵,根据的得到的分布矩阵通过归一化得到样本的权重矩阵。
6.根据权利要求1所述的基于梯度平衡的多类别模型训练方法,其特征在于,步骤S4中,通过对权重设定最大值或对权重设定小于1的幂对得到的权重矩阵进行平滑处理。
7.根据权利要求1所述的基于梯度平衡的多类别模型训练方法,其特征在于,步骤S5中,在训练的过程中可采用均匀衰减、指数衰减或cosine衰减的方式将平滑项衰减至0。
8.一种计算机可读存储介质,存储有计算机程序,其特征在于,该程序被处理器执行时实现权利要求1-7任一项所述的基于梯度平衡的多类别模型训练方法。
9.一种电子设备,其特征在于,所述电子设备包括:
存储器,存储有算法的计算机程序;
处理器,与所述存储器数据连接,调用所述计算机程序时执行权利要求1-7任一项所述的基于梯度平衡的多类别模型训练方法。
显示器,与所述处理器和所述存储器数据连接,所述显示所述基于梯度平衡的多类别模型训练方法相关的操作交互界面。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011509570.5A CN112633359A (zh) | 2020-12-18 | 2020-12-18 | 一种基于梯度平衡的多类别模型训练方法、介质及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011509570.5A CN112633359A (zh) | 2020-12-18 | 2020-12-18 | 一种基于梯度平衡的多类别模型训练方法、介质及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112633359A true CN112633359A (zh) | 2021-04-09 |
Family
ID=75317561
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011509570.5A Pending CN112633359A (zh) | 2020-12-18 | 2020-12-18 | 一种基于梯度平衡的多类别模型训练方法、介质及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112633359A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113704413A (zh) * | 2021-08-31 | 2021-11-26 | 平安普惠企业管理有限公司 | 基于多样本的意图分类方法、装置、设备及存储介质 |
CN114037660A (zh) * | 2021-10-26 | 2022-02-11 | 国药集团基因科技有限公司 | Oct视网膜病变图像识别方法及系统 |
CN114330573A (zh) * | 2021-12-30 | 2022-04-12 | 济南博观智能科技有限公司 | 一种目标检测方法、装置、电子设备及存储介质 |
CN114495229A (zh) * | 2022-01-26 | 2022-05-13 | 北京百度网讯科技有限公司 | 图像识别的处理方法及装置、设备、介质和产品 |
-
2020
- 2020-12-18 CN CN202011509570.5A patent/CN112633359A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113704413A (zh) * | 2021-08-31 | 2021-11-26 | 平安普惠企业管理有限公司 | 基于多样本的意图分类方法、装置、设备及存储介质 |
CN114037660A (zh) * | 2021-10-26 | 2022-02-11 | 国药集团基因科技有限公司 | Oct视网膜病变图像识别方法及系统 |
CN114330573A (zh) * | 2021-12-30 | 2022-04-12 | 济南博观智能科技有限公司 | 一种目标检测方法、装置、电子设备及存储介质 |
CN114495229A (zh) * | 2022-01-26 | 2022-05-13 | 北京百度网讯科技有限公司 | 图像识别的处理方法及装置、设备、介质和产品 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112633359A (zh) | 一种基于梯度平衡的多类别模型训练方法、介质及设备 | |
CN106228185B (zh) | 一种基于神经网络的通用图像分类识别系统及方法 | |
CN105005589B (zh) | 一种文本分类的方法和装置 | |
CN111125358B (zh) | 一种基于超图的文本分类方法 | |
CN110175628A (zh) | 一种基于自动搜索与知识蒸馏的神经网络剪枝的压缩算法 | |
Chang et al. | Automatic channel pruning via clustering and swarm intelligence optimization for CNN | |
CN110138595A (zh) | 动态加权网络的时间链路预测方法、装置、设备及介质 | |
CN114488140B (zh) | 一种基于深度迁移学习的小样本雷达一维像目标识别方法 | |
Li et al. | Hybrid optimization algorithm based on chaos, cloud and particle swarm optimization algorithm | |
CN113128671B (zh) | 一种基于多模态机器学习的服务需求动态预测方法及系统 | |
CN109472352A (zh) | 一种基于特征图统计特征的深度神经网络模型裁剪方法 | |
JP2020046784A (ja) | 計算装置、計算プログラム、記録媒体及び計算方法 | |
Wang et al. | Application of deep learning in analog circuit sizing | |
Chavan et al. | Mini batch K-Means clustering on large dataset | |
Zhou et al. | An analysis on the relationship between uncertainty and misclassification rate of classifiers | |
CN117057258B (zh) | 基于权重分配相关系数的黑启动过电压预测方法及系统 | |
CN103593504A (zh) | 一种基于改进质量放大技术的绳网动作可靠性仿真方法 | |
CN110837853A (zh) | 一种快速分类模型构建方法 | |
CN112686881B (zh) | 基于影像统计特征和lstm复合网络的颗粒物料混合均匀性检测方法 | |
CN112529637B (zh) | 基于情景感知的服务需求动态预测方法及系统 | |
Ji et al. | Fast progressive differentiable architecture search based on adaptive task granularity reorganization | |
CN112241811A (zh) | “互联网+”环境下定制产品的分层混合性能预测方法 | |
CN112560326A (zh) | 压力场的确定方法及装置 | |
CN111008692A (zh) | 基于改进生成对抗网络的多能计量特征数据生成方法及装置 | |
Li et al. | Data fine-pruning: a simple way to accelerate neural network training |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |