CN114283307B - 一种基于重采样策略的网络训练方法 - Google Patents

一种基于重采样策略的网络训练方法 Download PDF

Info

Publication number
CN114283307B
CN114283307B CN202111600865.8A CN202111600865A CN114283307B CN 114283307 B CN114283307 B CN 114283307B CN 202111600865 A CN202111600865 A CN 202111600865A CN 114283307 B CN114283307 B CN 114283307B
Authority
CN
China
Prior art keywords
class
training
sampling
network
phase
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202111600865.8A
Other languages
English (en)
Other versions
CN114283307A (zh
Inventor
姚鹏
徐亮
程逸
申书伟
徐晓嵘
任维
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Yousheng Biotechnology Co ltd
University of Science and Technology of China USTC
Original Assignee
Shenzhen Yousheng Biotechnology Co ltd
University of Science and Technology of China USTC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Yousheng Biotechnology Co ltd, University of Science and Technology of China USTC filed Critical Shenzhen Yousheng Biotechnology Co ltd
Priority to CN202111600865.8A priority Critical patent/CN114283307B/zh
Publication of CN114283307A publication Critical patent/CN114283307A/zh
Application granted granted Critical
Publication of CN114283307B publication Critical patent/CN114283307B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了一种基于重采样策略的网络训练方法,其采用了基于阶段性渐进学习策略的类不均衡处理方案,可以减轻类不均衡数据集头部类与尾部类数量不均衡的问题,有效降低尾类数据过拟合与头类数据欠拟合的风险;同时,可以从学习表征的通用模式平滑过度到上层分类器的训练,在学习分类器的同时能够很好地保留原有学习到的深度表征;此外,具有较好的鲁棒性,通过控制阶段性超参数和渐进超参数来适应不均衡程度不同的数据集或样本数量不同的数据集,并进一步提高分类的准确率。

Description

一种基于重采样策略的网络训练方法
技术领域
本发明涉及深度学习技术领域,尤其涉及一种基于重采样策略的网络训练方法。
背景技术
随着大规模、高质量数据集(如ImageNet ILSVRC 2012,MS COCO等数据集)的发展,基于图像识别的深度学习方法在各个领域都取得了显著的效果。然而获取大量的、人工指定标签的数据是十分庞大的工作,且在现实场景中,往往获得的数据集样本类别的分布具有不均衡的特性,即少数类(又名头部类)包含大多数样本,而大多数类(又名尾部类)只包含少数样本。通用的深度学习的方法在这种不均衡的数据集上表现往往很差,因此,类不均衡视觉识别成为一项具有挑战性的任务。
针对这种类别不均衡问题,以前的方法更多的是采用基于数据分布的性能权衡方法来减轻网络在训练过程中将更多的注意力集中在样本多的头部类别而忽略了对尾部类别的拟合。例如,目前的重采样(Re-Sampling,RS)方法通过对多数类样本进行欠采样或对少数类进行过采样以调整数据的分布情况。然而,在训练过程中,重采样通常存在对尾部类别过拟合和对头部类数据欠拟合的风险。
与这些从训练的初始阶段就开始进行重采样的性能权衡方法相比,两阶段的延迟重新采样方法(Deferred Re-Sampling,DRS)得到了更加广泛的应用。在训练的第一阶段,不采用重采样策略,而是采用通用的训练方法,在原始数据分布上训练深度神经网络,以进行对模型深层特征的学习,并将模型特征参数带到一个更好的初始状态;在第二阶段,以较小的学习率,采用重采样方法来微调网络,进行上层分类器的学习,使分类器匹配训练样本的不均衡分布。由于网络训练过程中特征参数的更新是一个非凸性优化问题,在第二阶段学习率很小的情况下,模型深层特征参数不会偏离第一阶段得到的最佳值太远,从而能够使得到的分类器整体上有更好的性能。但是,这种两阶段的方法忽略了两个阶段中的数据集偏差或域偏移,在第二阶段由于训练模式或者训练样本分布的突然转变,会使得模型最终的分类性能有所下降。
目前还没有办法很好的解决重采样和两阶段的方法在处理类不均衡问题时存在的缺陷。因此,亟需设计一种更加弹性的、从学习表征的通用模式到学习分类器的专用模式之间平滑过度的深度学习方法来解决现实场景中类不均衡的问题,以提升网络模型分类性能。
发明内容
本发明的目的是提供一种基于重采样策略的网络训练方法,可以为深度学习模型的使用场景提供更好的扩展,降低尾类数据过拟合与头类数据欠拟合的风险,提升网络模型的分类性能。
本发明的目的是通过以下技术方案实现的:
一种基于重采样策略的网络训练方法,包括:
获取目标图像数据集,确定数据类别总数C以及各类别样本数目,设定循环训练的当前轮数为E,同时,设定阶段性超参数Emin和Emax以及渐进超参数γ;
根据当前轮数E与阶段性超参数Emin和Emax的关系,确定当前属于训练的前期阶段、后期阶段、或是前期阶段与后期阶段之间的过渡阶段;若为前期阶段,则采用实例采样,即按照数据的原始分布均匀采样;若为后期阶段,则采用类均衡的采样方法,即按照相同的概率对不同类别进行采样,进行分类器的学习;若为过渡阶段,则采用渐进采样的方法,即不断调采样方式,以渐进的方式的从实例采样过渡到类均衡采样;
利用各阶段采样得到的样本对卷积神经网络中进行训练,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标。
由上述本发明提供的技术方案可以看出,基于阶段性渐进学习策略的类不均衡处理方案可以减轻类不均衡数据集头部类与尾部类数量不均衡的问题,有效降低尾类数据过拟合与头类数据欠拟合的风险;同时,可以从学习表征的通用模式平滑过度到上层分类器的训练,在学习分类器的同时能够很好地保留原有学习到的深度表征;此外,具有较好的鲁棒性,通过控制阶段性超参数和渐进超参数来适应不均衡程度不同的数据集或样本数量不同的数据集,并进一步提高分类的准确率。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的一种基于重采样策略的网络训练方法的流程图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
首先对本文中可能使用的术语进行如下说明:
术语“包括”、“包含”、“含有”、“具有”或其它类似语义的描述,应被解释为非排它性的包括。例如:包括某技术特征要素(如原料、组分、成分、载体、剂型、材料、尺寸、零件、部件、机构、装置、步骤、工序、方法、反应条件、加工条件、参数、算法、信号、数据、产品或制品等),应被解释为不仅包括明确列出的某技术特征要素,还可以包括未明确列出的本领域公知的其它技术特征要素。
下面对本发明所提供的一种基于重采样策略的网络训练方法进行详细描述。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。本发明实施例中未注明具体条件者,按照本领域常规条件或制造商建议的条件进行。本发明实施例中所用试剂或仪器未注明生产厂商者,均为可以通过市售购买获得的常规产品。
本发明实施例提供的一种基于重采样策略的网络训练方法,它是一种针对类不均衡数据集的网络训练方法,其采用基于阶段性渐进采样(Phased Progressive Sampling,PPS)的策略,主要原理可以描述为:获取目标图像数据集,确定数据类别总数C以及各类别样本数目,设定循环训练的当前轮数为E,同时,设定阶段性超参数Emin和Emax以及渐进超参数γ;根据当前轮数E与阶段性超参数Emin和Emax的关系,确定当前属于训练的前期阶段、后期阶段、或是前期阶段与后期阶段之间的过渡阶段;若为前期阶段,则采用实例采样,即按照数据的原始分布均匀采样,以获取数据集特征空间的完整表征;若为后期阶段,则以较小的学习率(即学习率小于设定的门限值)采用类均衡的采样方法,即按照相同的概率对不同类别进行采样,进行分类器的学习;若为过渡阶段,则采用渐进采样的方法,即不断调采样方式,以渐进的方式缓慢的从实例采样过渡到类均衡采样,在保证分类器学习的同时降低对已经学习到的数据集特征空间表征的损坏;利用各阶段采样得到的样本对卷积神经网络中进行训练,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标。如图1所示,上述方案其主要包括如下步骤:
步骤1:获取目标图像数据集,确定不同数据类别总数C以及各类别样本数目ni,设定循环训练的当前轮数为E,同时确定阶段性超参数Emin和Emax以及渐进超参数γ。
本发明实施例中,所述目标图像数据集为不均衡数据集;本发明不对类别总数C以及各类别样本数目ni的具体数值进行限定。阶段性超参数Emin和Emax主要是用来界定当前轮数所处的训练阶段,这两个参数的具体数值可以由本领域技术人员根据实际情况或者经验自行设定。渐进超参数γ的具体数值可以根据目标图像数据集的数据分布来设定。
步骤2:根据当前轮数E与阶段性超参数Emin和Emax的关系,确定当前属于训练的前期阶段、后期阶段、或是前期阶段与后期阶段之间的过渡阶段;不同阶段使用不同的采样策略;可以描述为:
上式中,第j类被采样到的概率,ni、nj分别表示第i类和第j类的样本数。
通过上式可知,随着训练过程的进行,循环训练轮数逐渐增加,阶段性的渐进采样自动完成对不平衡类采样频率的改变:
1)如果E<Emin,则当前属于训练的前期阶段,采用实例采样,第j类被采样到的概率:此阶段采用实例采样,即按照数据的原始分布采样,采样频率仅与样本量占总量的比值正相关,由于数据集中的每一个样本都有相同的概率被采样,可以保证最大限度的获取数据集特征空间的完整表征。
2)如果E>Emax,则当前属于训练的后期阶段,采用类均衡采样,第j类被采样到的概率均为:1/C。此阶段采用类均衡的采样方法即按照相同的概率对不同类别进行采样,此时相当于完全忽略样本数量所造成的差异,每一类的采样频率均相等,为1/C,这样可以使分类器对头部类和尾部类都有同样的关注度,从而保证学习效果。
3)如果Emin≤E≤Emax,则当前属于前期阶段与后期阶段之间的过渡阶段,采用渐进采样,第j类被采样到的概率:此阶段,采用渐进采样的方法,即不断调整数据集的采样方式,以渐进的方式缓慢的从实例采样过渡到类均衡采样,渐进采样频率为实例采样频率与类均衡采样频率的线性叠加,而叠加部分的权值则是由E与阶段超参数Emin和Emax的函数[(E-Emin)/(Emax-Emin)]γ和1[(E-Emin)/(Emax-Emin)]γ进行控制,其中渐进超参数γ控制权重变化的趋势,按照不同的数据分布来确定,这样可以尽可能的在保证分类器学习的同时降低对已经学习到的深层特征表征的损坏。
步骤3:将采样后的样本进行数据增强并作为输入送入到卷积神经网络中进行训练模型并输出结果。
本步骤所涉及的数据增强方案可以通过常规技术实现,本发明不做赘述。
步骤4:将卷积神经网络训练输出的预测结果与样本的真实标签送入到损失函数中进行误差计算,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标,完成最终的训练。在整个神经网络的训练过程中,学习率逐渐降低,模型的训练逐渐从对网络深层特征的学习过渡到对浅层分类器的学习。
本步骤所涉及的损失函数可以是目前图像分类学习中的任意损失函数,例如交叉熵损失函数cross-entropy(CE),所涉及网络权重参数更新流程可参照常规技术实现,本发明不做赘述,所涉及的卷积神经网络可以是目前任意结构的形式的图像分类网络。
本发明提供的上述技术方案主要获得如下有益效果:
1)减轻了不均衡数据头部类与尾部类数量不均衡的问题,有效地缓解了现有重采样的方法对尾类数据过拟合与头类数据欠拟合的风险。
2)从学习表征的通用模式平滑过度到上层分类器的训练,在学习分类器的同时能够训练初始阶段学习到的深层特征表征。
3)具有较好的鲁棒性,通过控制阶段性超参数和渐进超参数很好地适应不均衡程度不同的数据集或样本数量不同的数据集,并进一步提高检测的准确率。
为了验证本发明上述方案的有效性,以现实场景中图像的分类为例,进行了相关实验。
选取的数据集为官方数据集CIFAR10,并通过常用的不均衡数据集转化方法,将均匀的十分类的原数据集按照指数衰减的形式转化为不均衡的样本,如表1所示。
类别 飞机 汽车 鹿 青蛙 卡车
数量 5000 2997 1796 1077 645 387 232 139 83 50
表1不均衡的样本数据分布
针对表1所示的不均衡的样本数据集,采用现有重采样方法(RS)与本发明阶段性渐进采样方法(PPS)的准确率比较,比较结果如表2所示。
表2准确率比较结果
在表1所示的不均衡的样本数据集上,交叉熵函数CE的准确率为70.54%,CE+RS的准确率为73.25%,CE+DRS的准确率为74.35%,而本发明提出的阶段性渐进采样方法CE+PPS准确率可达75.22%,相较于当前已知的技术方法准确率提高了0.87%。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例可以通过软件实现,也可以借助软件加必要的通用硬件平台的方式来实现。基于这样的理解,上述实施例的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。

Claims (2)

1.一种基于重采样策略的网络训练方法,其特征在于,包括:
获取目标图像数据集,确定数据类别总数C以及各类别样本数目,设定循环训练的当前轮数为E,同时,设定阶段性超参数Emin和Emax以及渐进超参数γ;
根据当前轮数E与阶段性超参数Emin和Emax的关系,确定当前属于训练的前期阶段、后期阶段、或是前期阶段与后期阶段之间的过渡阶段;若为前期阶段,则采用实例采样,即按照数据的原始分布均匀采样;若为后期阶段,则采用类均衡的采样方法,即按照相同的概率对不同类别进行采样,进行分类器的学习;若为过渡阶段,则采用渐进采样的方法,即不断调采样方式,以渐进的方式的从实例采样过渡到类均衡采样;
利用各阶段采样得到的样本对卷积神经网络中进行训练,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标,其中,卷积神经网络为图像分类网络;
所述根据当前轮数与阶段性超参数Emin和Emax的关系,确定当前属于训练的前期阶段、后期阶段、或是前期阶段与后期阶段之间的过渡阶段包括:
如果E<Emin,则当前属于训练的前期阶段;
如果E>Emax,则当前属于训练的后期阶段;
如果Emin≤E≤Emax,则当前属于前期阶段与后期阶段之间的过渡阶段;
所述实例采样方式表示为:
其中,第j类被采样到的概率,ni、nj分别表示第i类和第j类的样本数;
所述渐进采样的方法表示为:
其中,第j类被采样到的概率,ni、nj分别表示第i类和第j类的样本数;
所述类均衡的采样方法表示为:
其中,第j类被采样到的概率。
2.根据权利要求1所述的一种基于重采样策略的网络训练方法,其特征在于,所述利用各阶段采样得到的样本对卷积神经网络中进行训练,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标包括:
将卷积神经网络训练输出的预测结果与样本的真实标签送入至损失函数中进行误差计算,并使用反向传播不断更新网络的权重参数直至网络收敛达到预期目标,完成最终的训练。
CN202111600865.8A 2021-12-24 2021-12-24 一种基于重采样策略的网络训练方法 Active CN114283307B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111600865.8A CN114283307B (zh) 2021-12-24 2021-12-24 一种基于重采样策略的网络训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111600865.8A CN114283307B (zh) 2021-12-24 2021-12-24 一种基于重采样策略的网络训练方法

Publications (2)

Publication Number Publication Date
CN114283307A CN114283307A (zh) 2022-04-05
CN114283307B true CN114283307B (zh) 2023-10-27

Family

ID=80875165

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111600865.8A Active CN114283307B (zh) 2021-12-24 2021-12-24 一种基于重采样策略的网络训练方法

Country Status (1)

Country Link
CN (1) CN114283307B (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114866297B (zh) * 2022-04-20 2023-11-24 中国科学院信息工程研究所 网络数据检测方法、装置、电子设备及存储介质
CN115565681A (zh) * 2022-10-21 2023-01-03 电子科技大学(深圳)高等研究院 面向不平衡数据的IgA肾病的预测分析系统
CN115953631B (zh) * 2023-01-30 2023-09-15 南开大学 基于深度迁移学习的长尾小样本声纳图像分类方法及系统

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111680724A (zh) * 2020-05-26 2020-09-18 中国人民解放军96901部队21分队 一种基于特征距离与内点随机抽样一致性的模型估计方法
CN111738301A (zh) * 2020-05-28 2020-10-02 华南理工大学 一种基于双通道学习的长尾分布图像数据识别方法
CN112101544A (zh) * 2020-08-21 2020-12-18 清华大学 适用于长尾分布数据集的神经网络的训练方法和装置
CN112633517A (zh) * 2020-12-29 2021-04-09 重庆星环人工智能科技研究院有限公司 一种机器学习模型的训练方法、计算机设备及存储介质
CN112766379A (zh) * 2021-01-21 2021-05-07 中国科学技术大学 一种基于深度学习多权重损失函数的数据均衡方法
CN113407820A (zh) * 2021-05-29 2021-09-17 华为技术有限公司 模型训练方法及相关系统、存储介质
CN113792751A (zh) * 2021-07-28 2021-12-14 中国科学院自动化研究所 一种跨域行为识别方法、装置、设备及可读存储介质

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US9965717B2 (en) * 2015-11-13 2018-05-08 Adobe Systems Incorporated Learning image representation by distilling from multi-task networks
US10529077B2 (en) * 2017-12-19 2020-01-07 Canon Kabushiki Kaisha System and method for detecting interaction

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111680724A (zh) * 2020-05-26 2020-09-18 中国人民解放军96901部队21分队 一种基于特征距离与内点随机抽样一致性的模型估计方法
CN111738301A (zh) * 2020-05-28 2020-10-02 华南理工大学 一种基于双通道学习的长尾分布图像数据识别方法
CN112101544A (zh) * 2020-08-21 2020-12-18 清华大学 适用于长尾分布数据集的神经网络的训练方法和装置
CN112633517A (zh) * 2020-12-29 2021-04-09 重庆星环人工智能科技研究院有限公司 一种机器学习模型的训练方法、计算机设备及存储介质
CN112766379A (zh) * 2021-01-21 2021-05-07 中国科学技术大学 一种基于深度学习多权重损失函数的数据均衡方法
CN113407820A (zh) * 2021-05-29 2021-09-17 华为技术有限公司 模型训练方法及相关系统、存储介质
CN113792751A (zh) * 2021-07-28 2021-12-14 中国科学院自动化研究所 一种跨域行为识别方法、装置、设备及可读存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
A Real-Life Machine Learning Experience for Predicting University Dropout at Different Stages Using Academic Data;Antonio Jesús Fernández-García 等;《IEEE Access》;第133076-133090页 *
基于组合导航技术的粒子滤波改进方法综述;杜小菁 等;《Science Discovery》;第369-374页 *

Also Published As

Publication number Publication date
CN114283307A (zh) 2022-04-05

Similar Documents

Publication Publication Date Title
CN114283307B (zh) 一种基于重采样策略的网络训练方法
CN111638488B (zh) 一种基于lstm网络的雷达干扰信号识别方法
CN110197286B (zh) 一种基于混合高斯模型和稀疏贝叶斯的主动学习分类方法
CN107564513A (zh) 语音识别方法及装置
CN111429947B (zh) 一种基于多级残差卷积神经网络的语音情感识别方法
CN110598806A (zh) 一种基于参数优化生成对抗网络的手写数字生成方法
CN108847223A (zh) 一种基于深度残差神经网络的语音识别方法
CN111126226B (zh) 一种基于小样本学习和特征增强的辐射源个体识别方法
CN114332539A (zh) 针对类别不均衡数据集的网络训练方法
CN112331181B (zh) 一种基于多说话人条件下目标说话人语音提取方法
CN113256508A (zh) 一种改进的小波变换与卷积神经网络图像去噪声的方法
CN113761805B (zh) 一种基于时域卷积网络的可控源电磁数据去噪方法、系统、终端及可读存储介质
CN114463576B (zh) 一种基于重加权策略的网络训练方法
CN112014801A (zh) 一种基于SPWVD和改进AlexNet的复合干扰识别方法
CN112884059A (zh) 一种融合先验知识的小样本雷达工作模式分类方法
CN110930996A (zh) 模型训练方法、语音识别方法、装置、存储介质及设备
CN108109612A (zh) 一种基于自适应降维的语音识别分类方法
Pijackova et al. Radio modulation classification using deep learning architectures
CN109284662A (zh) 一种面向水下声音信号分类的迁移学习方法
CN111814963A (zh) 一种基于深度神经网络模型参数调制的图像识别方法
CN116246126A (zh) 迭代无监督域自适应方法和装置
CN115392285A (zh) 一种基于多模态的深度学习信号个体识别模型防御方法
CN114897884A (zh) 基于多尺度边缘特征融合的无参考屏幕内容图像质量评估方法
CN114863088A (zh) 一种面向长尾目标检测的分类对数归一化方法
US20240005531A1 (en) Method For Change Detection

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant