CN110110860B - 一种用于加速机器学习训练的自适应数据采样方法 - Google Patents
一种用于加速机器学习训练的自适应数据采样方法 Download PDFInfo
- Publication number
- CN110110860B CN110110860B CN201910371632.1A CN201910371632A CN110110860B CN 110110860 B CN110110860 B CN 110110860B CN 201910371632 A CN201910371632 A CN 201910371632A CN 110110860 B CN110110860 B CN 110110860B
- Authority
- CN
- China
- Prior art keywords
- training
- image
- iteration
- sample
- samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Feedback Control In General (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种用于加速机器学习训练的自适应数据采样方法,根据每个样本数据上损失函数的利普希茨常数从样本集合中选取一个子集。接下来的若干轮迭代中,将使用这个样本集合的子集代替完整的训练样本集合进行训练,直至下一次的样本数据选择。本发明的方法能在使用部分样本进行训练的情况下不损失最终结果的准确性,所以达到了加速机器学习训练过程的效果。
Description
技术领域
本发明涉及一种用于加速机器学习训练的自适应数据采样方法,属于数据挖掘技术领域。
背景技术
大部分机器学习模型可以被形式化为以下优化问题:
其中w代表了模型的参数,n代表了训练样本的总数,fi(·)则表示第i个样本所对应的损失函数。为了求解上述优化问题,随机梯度下降法(SGD)以及它的变体是目前应用最为广泛的方法。
在一个典型的随机梯度下降法中,每一次迭代可以描述为以下过程:首先从样本集合中随机选取一个样本(假设其样本编号为i),随后计算出该样本所对应的损失函数的梯度最后沿着该样本数据的负梯度方向去更新当前的模型参数。一般来说,将进行n次这样更新的过程称为模型训练的一轮迭代(Epoch),意味着遍历了一次样本集合。一次机器学习问题的训练过程,往往要经历多轮迭代才能接近全局最优解或是局部最优解。
随着训练样本数据量的增大,很多机器学习问题的训练过程需要花费大量的时间。本发明考虑在每一轮迭代中,选取一个训练样本集合的子集来代替完整的集合进行训练,以此来加速机器学习的训练过程。
发明内容
发明目的:目前的随机梯度下降法在模型训练的一轮迭中需要遍历完整的样本集合,对于样本数目较多的机器学习在训练过程中需要花费大量的时间。针对上述问题,本发明考虑到在模型训练的一轮迭中,不同的样本数据对参数接近最优解所做出的贡献也是不同的。本发明的方法认为如果一个样本上损失函数的利普希茨常数(Lipschitz constant)较小,可以认为这个样本在当前的训练阶段是不重要的。因此,在本发明提供一种用于加速机器学习训练的自适应数据采样方法,每经过若干个轮迭代的训练,就会进行一次自适应采样,即根据每个样本上损失函数的利普希茨常数从样本集合中选取一个子集,并使用这个样本集合的子集代替完整的训练样本集合进行训练。本发明的方法能在使用部分样本进行训练的情况下不损失最终结果的准确性。
技术方案:一种用于加速机器学习训练的自适应数据采样方法,具体步骤为:
步骤100,输入样本数量为n训练数据,机器学习模型w以及学习率η、采样间隔p、总共的迭代轮数T、阈值c;
步骤101,随机初始化模型参数w=w0,并将参与训练的样本集合S初始化为完整的训练样本集合S0=[n];
步骤102,使用当前的模型参数w0计算并存储所有样本对应的损失fi(w0),其中fi(·)则表示第i个样本所对应的损失函数;
步骤103,在当前样本集合S下使用随机梯度下降法进行模型训练的一轮迭代;
步骤104,判断当前已完成的迭代轮数t是否是采样间隔p的整数倍,如果是则进入自适应的样本选择阶段,获得新的样本集合S用于下一轮迭代的训练;否则保持样本集合S不变;
步骤105,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型w;否则返回步骤103继续进行训练。
所述进行模型训练的一轮迭代的具体流程为:首先输入当前模型参数wt,当前训练样本St以及学习率η等超参数;随后从样本集合St中随机选取一个样本(假设其样本编号为i),并计算出该样本所对应的损失函数的梯度最后使用随机梯度下降法更新模型参数wt;重复以上的步骤|St|次即是完成了模型训练的一轮迭代。
所述自适应的样本选择的具体步骤为;
步骤200,输入当前模型参数wt,所有的训练样本以及p轮迭代之前所有样本对应的损失fi(wt-p);
步骤201,初始化St为空集;
步骤202,计算当前模型参数wt下所有样本对应的损失fi(wt);
步骤203,计算经过p轮迭代的训练每个样本损失的变化量|fi(wt)-fi(wt-p)|;
步骤204,由于每个样本上损失函数的局部利普希茨常数可以通过估算,所以可以根据变化量|fi(wt)-fi(wt-p)|的从大到小的顺序依次将样本添加到St中,直至满足
即要求集合St中样本的局部利普希茨常数之和在所有样本的局部利普希茨常数之和中所占比例超过阈值c;
步骤205,输出集合St以用于随后的迭代。
本发明的方法在进行模型更新时既可以使用随机梯度下降法,也可以使用随机梯度下降法的变体,例如带动量的随机梯度下降法(momentum SGD)。
有益效果:与现有技术相比,本发明提供的用于加速机器学习训练的自适应数据采样方法,利用了样本损失函数的局部利普希茨常数,在训练过程中使用样本集合的子集代替完整的训练样本集合进行训练,从而加速了训练过程。
附图说明
图1为本发明实施的用于加速机器学习训练的自适应数据采样方法流程图;
图2为本发明实施的模型训练的一轮迭代的工作流程图;
图3为本发明实施的自适应的样本选择的工作流程图。
具体实施方式
下面结合具体实施例,进一步阐明本发明,应理解这些实施例仅用于说明本发明而不用于限制本发明的范围,在阅读了本发明之后,本领域技术人员对本发明的各种等价形式的修改均落于本申请所附权利要求所限定的范围。
本发明提供的用于加速机器学习训练的自适应采样方法,可应用于图像分类、文本分类等领域,适合于训练数据量大、训练过程耗时长的场景。以图像分类应用为例,本发明的方法的工作流程如下所述:
用于加速机器学习训练的自适应数据采样方法,方法流程如图1所示。首先输入n张训练图像数据,机器学习模型w以及学习率η、采样间隔p、总共的迭代轮数T、阈值c(步骤10),接着随机初始化模型参数w=w0(步骤11),计算并存储所有图像样本对应的损失fi(w0)(步骤12)。然后初始化迭代轮数计数器t=0以及参与训练的图像样本集合S0=[n](步骤13),接下来进入到模型训练的迭代阶段:在图像样本集合St下使用随机梯度下降法进行模型训练的一轮迭代(步骤14);随后将迭代轮数计数器t增加1,并根据迭代轮数计数器t是否是采样间隔p的整数倍来判断是否进入自适应的样本选择阶段(步骤16),若是进行自适应的样本选择以得到新的St(步骤17a),否则继续保持St不变(步骤17b)。每轮迭代结束时进行判断是否达到停止条件t=T(步骤18),若未达到停止条件则继续迭代,否则输出训练结果并保存模型(步骤19)。
模型训练的一轮迭代的工作流程如图2所示。首先读取当前的模型wt、图像样本集合St以及学习率η等超参数(步骤140),随后初始化更新次数计数器j=0并进入训练过程(步骤141):先从图像样本集合St中随机选取一个图像样本(步骤142),再计算出该图像样本所对应的损失函数的梯度(步骤143),使用随机梯度下降法或者其他优化方法更新模型参数(步骤144),最后将更新次数计数器j增加1(步骤145)。重复上述步骤,直到满足停止条件j=|St|(步骤146),输出参数模型(步骤147)。
自适应的样本选择阶段的工作流程图如图3所示。首先读取当前的模型wt和所有的训练图像样本(步骤170),然后初始化图像样本集合St为空集(步骤171)。接下来用当前模型计算并存储所有图像样本对应的损失fi(wt)(步骤172),随后计算经过p轮迭代的训练每个图像样本损失的变化量|fi(wt)-fi(wt-p)|(步骤173)。将图像样本按照损失变化量|fi(wt)-fi(wt-p)|从大到小进行排序以后(步骤174),依次将图像样本添加到St中(步骤175),直至满足以下终止条件(步骤177):
最后输出图像样本集合St(步骤178)。
本发明的方法在一个包含了60000张彩色图像的数据集Cifar10上进行了实验,实验中使用的深度学习模型为20层的ResNet模型。实验结果表明,本发明提出的方法最终只需要使用33.2%的训练样本,就达到了与使用全部样本训练的随机梯度下降法同样的精度,训练所需的总时间减少了35.2%。
Claims (2)
1.一种用于加速机器学习训练的自适应数据采样方法,其特征在于:应用于图像分类,具体步骤为:
步骤100,输入样本数量为n张训练图像数据,机器学习模型w以及学习率η、采样间隔p、总共的迭代轮数T、阈值c;
步骤101,随机初始化模型参数w=w0,并将参与训练的图像样本集合S初始化为完整的训练图像样本集合S0=[n];
步骤102,使用当前的模型参数w0计算并存储所有图像样本对应的损失fi(w0),其中fi(·)则表示第i个图像样本所对应的损失函数;
步骤103,在当前图像样本集合S下使用随机梯度下降法进行模型训练的一轮迭代;
步骤104,判断当前已完成的迭代轮数t是否是采样间隔p的整数倍,如果是则进入自适应的图像样本选择阶段,获得新的图像样本集合S用于下一轮迭代的训练;否则保持样本集合S不变;
步骤105,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型w;否则返回步骤103继续进行训练;
所述自适应的图像样本选择的具体步骤为:
步骤200,输入当前模型参数wt,所有的训练图像样本以及p轮迭代之前所有图像样本对应的损失fi(wt-p);
步骤201,初始化St为空集;
步骤202,计算当前模型参数wt下所有图像样本对应的损失fi(wt);
步骤203,计算经过p轮迭代的训练每个图像样本损失的变化量|fi(wt)-fi(wt-p)|;
步骤204,由于每个图像样本上损失函数的局部利普希茨常数通过估算,所以根据变化量|fi(wt)-fi(wt-p)|的从大到小的顺序依次将图像样本添加到St中,直至满足
即要求图像集合St中图像样本的局部利普希茨常数之和在所有图像样本的局部利普希茨常数之和中所占比例超过阈值c;
步骤205,输出集合St以用于随后的迭代。
2.如权利要求l所述的用于加速机器学习训练的自适应数据采样方法,其特征在于:进行模型训练的一轮迭代的具体流程为:首先输入当前模型参数wt,当前训练图像样本St以及学习率η等超参数;随后从图像样本集合St中随机选取一个样本,并计算出该图像样本所对应的损失函数的梯度最后使用随机梯度下降法更新模型参数wt;重复以上的步骤|St|次即是完成了模型训练的一轮迭代。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910371632.1A CN110110860B (zh) | 2019-05-06 | 2019-05-06 | 一种用于加速机器学习训练的自适应数据采样方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910371632.1A CN110110860B (zh) | 2019-05-06 | 2019-05-06 | 一种用于加速机器学习训练的自适应数据采样方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110110860A CN110110860A (zh) | 2019-08-09 |
CN110110860B true CN110110860B (zh) | 2023-07-25 |
Family
ID=67488332
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910371632.1A Active CN110110860B (zh) | 2019-05-06 | 2019-05-06 | 一种用于加速机器学习训练的自适应数据采样方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110110860B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112749565A (zh) * | 2019-10-31 | 2021-05-04 | 华为终端有限公司 | 基于人工智能的语义识别方法、装置和语义识别设备 |
CN110852420B (zh) * | 2019-11-11 | 2021-04-13 | 北京智能工场科技有限公司 | 一种基于人工智能的垃圾分类方法 |
CN111310901B (zh) * | 2020-02-24 | 2023-10-10 | 北京百度网讯科技有限公司 | 用于获取样本的方法及装置 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US8756175B1 (en) * | 2012-02-22 | 2014-06-17 | Google Inc. | Robust and fast model fitting by adaptive sampling |
CN107784312A (zh) * | 2016-08-24 | 2018-03-09 | 腾讯征信有限公司 | 机器学习模型训练方法及装置 |
CN108875933A (zh) * | 2018-05-08 | 2018-11-23 | 中国地质大学(武汉) | 一种无监督稀疏参数学习的超限学习机分类方法及系统 |
-
2019
- 2019-05-06 CN CN201910371632.1A patent/CN110110860B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US8756175B1 (en) * | 2012-02-22 | 2014-06-17 | Google Inc. | Robust and fast model fitting by adaptive sampling |
CN107784312A (zh) * | 2016-08-24 | 2018-03-09 | 腾讯征信有限公司 | 机器学习模型训练方法及装置 |
CN108875933A (zh) * | 2018-05-08 | 2018-11-23 | 中国地质大学(武汉) | 一种无监督稀疏参数学习的超限学习机分类方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN110110860A (zh) | 2019-08-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110110860B (zh) | 一种用于加速机器学习训练的自适应数据采样方法 | |
CN112069310B (zh) | 基于主动学习策略的文本分类方法及系统 | |
CN107729999A (zh) | 考虑矩阵相关性的深度神经网络压缩方法 | |
CN109948029A (zh) | 基于神经网络自适应的深度哈希图像搜索方法 | |
CN109934826A (zh) | 一种基于图卷积网络的图像特征分割方法 | |
CN108960301B (zh) | 一种基于卷积神经网络的古彝文识别方法 | |
CN116503676B (zh) | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 | |
CN108596204B (zh) | 一种基于改进型scdae的半监督调制方式分类模型的方法 | |
CN111784595A (zh) | 一种基于历史记录的动态标签平滑加权损失方法及装置 | |
CN116089883B (zh) | 用于提高已有类别增量学习新旧类别区分度的训练方法 | |
CN111814963B (zh) | 一种基于深度神经网络模型参数调制的图像识别方法 | |
CN111931813A (zh) | 一种基于cnn的宽度学习分类方法 | |
CN107240100B (zh) | 一种基于遗传算法的图像分割方法和系统 | |
CN110942141A (zh) | 基于全局稀疏动量sgd的深度神经网络剪枝方法 | |
CN113393051A (zh) | 基于深度迁移学习的配电网投资决策方法 | |
CN111582442A (zh) | 一种基于优化深度神经网络模型的图像识别方法 | |
CN115828100A (zh) | 基于深度神经网络的手机辐射源频谱图类别增量学习方法 | |
CN116432780A (zh) | 一种模型增量学习方法、装置、设备及存储介质 | |
CN113590748B (zh) | 基于迭代网络组合的情感分类持续学习方法及存储介质 | |
CN115035304A (zh) | 一种基于课程学习的图像描述生成方法及系统 | |
CN113420834B (zh) | 一种基于关系约束自注意力的图像描述自动生成方法 | |
CN113920124A (zh) | 基于分割和误差引导的脑神经元迭代分割方法 | |
CN113377884A (zh) | 基于多智能体增强学习的事件语料库提纯方法 | |
CN115205577A (zh) | 用于图像分类的卷积神经网络的自适应优化训练方法 | |
CN109858127B (zh) | 基于递归时序深度置信网络的蓝藻水华预测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |