CN113435519B - 基于对抗插值的样本数据增强方法、装置、设备及介质 - Google Patents
基于对抗插值的样本数据增强方法、装置、设备及介质 Download PDFInfo
- Publication number
- CN113435519B CN113435519B CN202110730469.0A CN202110730469A CN113435519B CN 113435519 B CN113435519 B CN 113435519B CN 202110730469 A CN202110730469 A CN 202110730469A CN 113435519 B CN113435519 B CN 113435519B
- Authority
- CN
- China
- Prior art keywords
- interpolation
- sample data
- proportion
- representing
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 54
- 238000011478 gradient descent method Methods 0.000 claims abstract description 13
- 230000006870 function Effects 0.000 claims description 29
- 230000004927 fusion Effects 0.000 claims description 15
- 230000002708 enhancing effect Effects 0.000 claims description 14
- 238000004364 calculation method Methods 0.000 claims 2
- 238000013145 classification model Methods 0.000 abstract description 21
- 238000010586 diagram Methods 0.000 description 18
- 238000012549 training Methods 0.000 description 9
- 239000000203 mixture Substances 0.000 description 7
- 238000004590 computer program Methods 0.000 description 5
- 230000000694 effects Effects 0.000 description 5
- 238000002372 labelling Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 2
- 102100033814 Alanine aminotransferase 2 Human genes 0.000 description 1
- 101710096000 Alanine aminotransferase 2 Proteins 0.000 description 1
- 230000008485 antagonism Effects 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000002156 mixing Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Complex Calculations (AREA)
Abstract
本发明公开了一种基于对抗插值的样本数据增强方法、装置、设备及介质,所述方法包括:获取已标注的第一样本数据,根据mixup算法对所述第一样本数据进行随机插值,得到第二样本数据;通过梯度下降方法调整插值比例,得到更新后的插值比例;根据所述更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。根据本公开实施例提供的基于对抗插值的样本数据增强方法,利用对抗学习的方法来搜索插值比例,增强程度可以控制,可以生成更“难”的增强样本,从而提高分类模型在低资源情况下的准确度。
Description
技术领域
本发明涉及数据处理技术领域,特别涉及一种基于对抗插值的样本数据增强方法、装置、设备及介质。
背景技术
现实场景使用图像分类模型、语音分类模型或者文本分类模型会遇到标注数据少(低资源)的问题。在低资源的情况下,比如每个类别只有少量样本,模型可能会过拟合导致其性能不达预期。这种过拟合情况在数据稀缺的情况下更加明显,例如每个类别只有5个样本的极端情况。
面对一个标注数据稀缺的低资源应用场景,数据增强是一种有效的技术方法,可以利用非常少量的标注语料得到一个有一定性能的基础模型,帮助破解低资源困局、减少对标注的需求,快速进入模型优化的迭代开发。
但是,现有技术中的数据增强方法都是单样本增强。例如:在文本分类的场景中,通常用GPT-2模型生成某个类别的合成样本,然后将合成样本放入训练集合训练模型,提升模型的泛化能力。这样的单样本增强的程度难以控制使得增强效果不能得到保证。而基于插值的数据增强利用两个不同类别的真实样本进行插值生成一个插值样本,会因插值比例不同而生成出不同“难易”程度的样本,从而影响到分类模型的效果。
发明内容
本公开实施例提供了一种基于对抗插值的样本数据增强方法、装置、设备及介质。解决了现有技术中标注数据少,影响模型训练效果的问题。为了对披露的实施例的一些方面有一个基本的理解,下面给出了简单的概括。该概括部分不是泛泛评述,也不是要确定关键/重要组成元素或描绘这些实施例的保护范围。其唯一目的是用简单的形式呈现一些概念,以此作为后面的详细说明的序言。
第一方面,本公开实施例提供了一种基于对抗插值的样本数据增强方法,包括:
获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据;
通过梯度下降方法调整插值比例,得到更新后的插值比例;
根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
在一个可选地实施例中,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据,包括:
从第一样本数据中随机抽取两个样本;
从Beta分布中随机抽取一个插值比例,得到随机插值比例;
根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值,得到第二样本数据。
在一个可选地实施例中,根据如下公式进行随机插值,得到第二样本数据:
λ~Beta(α,α)
其中,{xi,yi}和{xj,yj}表示抽取的样本数据,λ表示插值的比例,Beta(α,α)表示beta分布,gk(xi)和gk(xj)表示xi和xj经过网络编码后的数据,表示根据插值比例λ将位置K的词的表示gk(xi)和gk(xj)进行插值融合后的增强数据,/>表示对于xi和xj对应的标签yi和yj进行插值融合后的增强数据。
在一个可选地实施例中,通过梯度下降方法调整插值比例,得到更新后的插值比例,包括:
根据预设的损失函数计算每个位置的插值损失;
对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;
根据得到的梯度更新随机插值比例,得到对抗方向上的最新插值比例。
在一个可选地实施例中,预设的损失函数如下公式所示:
其中,θ是模型的参数,i和j是真实标记数据的编号,frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,λ~Beta(α,α),Beta(α,α)表示beta分布,lmix表示插值的损失函数。
在一个可选地实施例中,根据如下公式计算当前的梯度,
其中,η表示梯度,Δλ表示随机插值比例的偏导,表示损失值。
在一个可选地实施例中,根据如下公式计算最新插值比例,
λ′=λ+εη
其中,λ′表示最新插值比例,ε表示步长,η表示梯度。
第二方面,本公开实施例提供了一种基于对抗插值的样本数据增强装置,包括:
第一插值模块,用于获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据;
插值比例更新模块,用于通过梯度下降方法调整插值比例,得到更新后的插值比例;
第二插值模块,用于根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
第三方面,本公开实施例提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机可读指令,计算机可读指令被处理器执行时,使得处理器执行上述实施例提供的基于对抗插值的样本数据增强方法的步骤。
第四方面,本公开实施例提供了一种存储有计算机可读指令的存储介质,计算机可读指令被一个或多个处理器执行时,使得一个或多个处理器执行上述实施例提供的基于对抗插值的样本数据增强方法的步骤。
本公开实施例提供的技术方案可以包括以下有益效果:
根据本公开实施例提供的基于对抗插值的样本数据增强方法,利用对抗学习的方法来调整插值比例,增强程度可以控制,可以生成使机器学习算法产生误判的增强样本,从而提高分类模型在低资源情况下的准确度。而且根据该方法得到的增强数据样本,可以用于文本,图像和音频等多种数据分类,适用领域很广泛。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本发明。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本发明的实施例,并与说明书一起用于解释本发明的原理。
图1是根据一示例性实施例示出的一种基于对抗插值的样本数据增强方法的实施环境图;
图2是根据一示例性实施例示出的一种计算机设备的内部结构图;
图3是根据一示例性实施例示出的一种基于对抗插值的样本数据增强方法的流程示意图;
图4是根据一示例性实施例示出的一种随机插值方法的示意图;
图5是根据一示例性实施例示出的一种更新插值比例方法的示意图;
图6是根据一示例性实施例示出的一种随机插值的示意图;
图7是根据一示例性实施例示出的一种调整插值比例的示意图;
图8是根据一示例性实施例示出的一种基于对抗插值的样本数据增强的整体示意图;
图9是根据一示例性实施例示出的一种基于对抗插值的样本数据增强装置的结构示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
可以理解,本申请所使用的术语“第一”、“第二”等可在本文中用于描述各种元件,但这些元件不受这些术语限制。这些术语仅用于将第一个元件与另一个元件区分。举例来说,在不脱离本申请的范围的情况下,可以将第一字段及算法确定模块成为第二字段及算法确定模块,且类似地,可将第二字段及算法确定模块成为第一字段及算法确定模块。
图1是根据一示例性实施例示出的一种基于对抗插值的样本数据增强方法的实施环境图,如图1所示,在该实施环境中,包括服务器110以及终端120。
服务器110为基于对抗插值的样本数据增强设备,例如为技术人员使用的电脑等计算机设备,服务器110上安装有数据增强工具。终端120上安装有需要进行数据增强的应用,当需要提供数据增强服务时,技术人员可以在计算机设备110发出提供数据增强的请求,该请求中携带有请求标识,计算机设备110接收该请求,获取计算机设备110中存储的基于对抗插值的样本数据增强方法。然后利用该方法实现数据处理。
需要说明的是,终端120以及计算机设备110可为智能手机、平板电脑、笔记本电脑、台式计算机等,但并不局限于此。计算机设备110以及终端120可以通过蓝牙、USB(Universal Serial Bus,通用串行总线)或者其他通讯连接方式进行连接,本发明在此不做限制。
图2是根据一示例性实施例示出的一种计算机设备的内部结构图。如图2所示,该计算机设备包括通过系统总线连接的处理器、非易失性存储介质、存储器和网络接口。其中,该计算机设备的非易失性存储介质存储有操作系统、数据库和计算机可读指令,数据库中可存储有控件信息序列,该计算机可读指令被处理器执行时,可使得处理器实现一种基于对抗插值的样本数据增强方法。该计算机设备的处理器用于提供计算和控制能力,支撑整个计算机设备的运行。该计算机设备的存储器中可存储有计算机可读指令,该计算机可读指令被处理器执行时,可使得处理器执行一种基于对抗插值的样本数据增强方法。该计算机设备的网络接口用于与终端连接通信。本领域技术人员可以理解,图2中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
下面将结合附图3-附图8,对本申请实施例提供的基于对抗插值的样本数据增强方法进行详细介绍。该方法可依赖于计算机程序实现,可运行于基于冯诺依曼体系的数据传输装置上。该计算机程序可集成在应用中,也可作为独立的工具类应用运行。
请参见图3,为本申请实施例提供了一种基于对抗插值的样本数据增强方法的流程示意图,如图3所示,本申请实施例的方法可以包括以下步骤:
S301获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据。
在一种可能的实现方式中,首先获取目标分类模型已标注的第一样本数据,其中,第一样本数据包含已标注的标签。
在一种可能的实现方式中,可以从标注样本数据库中获取第一样本数据,或者本领域技术人员自行标注第一样本数据。
本实施例中,目标分类模型通过已标注的样本训练得到,目标分类模型可为文本分类模型、图像分类模型或语音分类模型,针对不同的模型功能,目标分类模型可以采用不同的神经网络来实现。
进一步地,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据。图4是根据一示例性实施例示出的一种随机插值方法的示意图,如图4所示,该方法包括:S401从第一样本数据中随机抽取两个样本;S402从Beta分布中随机抽取一个插值比例,得到随机插值比例;S403根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值。
mixup是一种运用在计算机视觉中的对图像进行混类增强的算法,它可以将不同类之间的图像进行混合,从而扩充训练数据集,图6是根据一示例性实施例示出的一种随机插值的示意图;mixup算法的步骤如图6所示,首先从得到的第一样本中随机选择两个样本{xi,yi}和{xj,yj},将两个输入xi和xj经过网络编码后得到gk(xi)和gk(xj);
再从Beta分布中随机抽样得到一个随机插值比例λ,该值属于[0,1],如下公式所示:
λ~Beta(α,α)
根据插值比例λ将位置K的词的表示gk(xi)和gk(xj)进行插值,得到融合后的
同时对于xi和xj对应的标签yi和yj进行插值,得到
和/>相当于新的增强数据。
S302通过梯度下降方法调整插值比例,得到更新后的插值比例。
图7是根据一示例性实施例示出的一种调整插值比例的示意图;如图7所示,为了提高模型的训练效果,本公开实施例引入对抗操作,通过梯度下降法在对抗方向上调整插值比例λ,生成难度更高的样本,对抗样本指的是使机器学习算法产生误判的样本。模型应该多用难度较高的样本进行训练。
图5是根据一示例性实施例示出的一种更新插值比例方法的示意图,如图5所示,该方法包括:
S501根据预设的损失函数计算每个位置的插值损失;
在一个可选地实施例中,损失函数如下所示:
其中,θ是模型的参数,i和j是真实标记数据的编号,frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,Beta(α,α)表示beta分布,lmix表示插值的损失函数。
S502对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;
在一个可选地实施例中,根据如下公式计算当前的梯度,包括:
其中,η表示梯度,Δλ表示随机插值比例的偏导,表示损失值。
S503根据得到的梯度更新随机插值比例,得到对抗方向上的最新插值比例。在一个可选地实施例中,根据如下公式计算最新插值比例:
λ′=λ+εη
其中,λ′表示最新插值比例,ε表示步长,η表示梯度。
根据该步骤,可以通过梯度下降方法更新插值比例,得到对抗方向上的最新插值比例。
S303根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
在一种可能的实现方式中,根据最新得到的插值比例,按照如下公式重新进行插值运算,得到增强的第三样本数据。
其中,λ′表示最新插值比例,{xi,yi}和{xj,yj}表示抽取的样本数据,Beta(α,α)表示beta分布,gk(xi)和gk(xj)表示xi和xj经过网络编码后的数据,表示根据插值比例λ′将位置K的词的表示gk(xi)和gk(xj)进行插值融合后的增强数据,/>表示对于xi和xj对应的标签yi和yj进行插值融合后的增强数据。
重新得到的增强数据是基于对抗方向上的插值比例生成的,因此,最新得到的增强数据难度更高,更适合训练模型,可以更好的正则化模型,提升分类模型在低资源下的效果。
进一步地,基于对抗插值方法得到增强后的第三样本数据,将所述第三样本数据作为分类模型的训练集,基于所述训练集训练所述目标分类模型,并最小化损失函数,得到训练好的分类模型。
在一种可能的实现方式中,可以对文字、语音、图像等数据进行增强,然后基于增强的数据训练文字分类模型、语音分类模型、图像分类模型,应用范围广泛。
为了便于理解本公开实施例提供的基于对抗插值的样本数据增强方法,下面结合附图8进一步说明。图8是根据一示例性实施例示出的一种基于对抗插值的样本数据增强的整体示意图,如图8所示:
本公开实施例提供的基于对抗插值的数据增强方法,主要包含三个步骤,首先是随机插值阶段,基于mixup方法生成随机插值比例,进行随机插值,得到第二样本数据。然后根据梯度下降方法调整插值比例,最大化插值损失函数,得到更新的插值比例。最后,根据更新后的插值比例,重新进行插值,最小化插值损失函数,得到更难的第三样本数据。
用数学公式来描述如下所示:
其中,θ是分类模型的参数,i和j是真实标记数据的编号,比如,i表示(xi,yi),frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,lmix表示插值的损失函数,ε表示步长,表示插值数据的集合。
根据本公开实施例提供的基于对抗插值的数据增强方法,增强程度可以控制,可以生成使机器学习算法产生误判的难度更高的增强样本,从而对模型在低资源情况下提升更明显,提高模型对高难度样本的识别准确率,解决标注数据少影响模型准确度的问题。而且可以用于文本、图像和音频等多种数据分类,适用领域很广泛。
下述为本申请装置实施例,可以用于执行本发明方法实施例。对于本发明装置实施例中未披露的细节,请参照本发明方法实施例。
请参见图9,其示出了本发明一个示例性实施例提供的基于对抗插值的样本数据增强装置的结构示意图。如图9所示,该基于对抗插值的样本数据增强装置可以集成于上述的计算机设备110中,具体可以包括第一插值模块901、插值比例更新模块902以及第二插值模块903。
第一插值模块901,用于获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据;
插值比例更新模块902,用于通过梯度下降算法调整插值比例,得到更新后的插值比例;
第二插值模块903,用于根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
在一个可选地实施例中,第一插值模块901具体用于从第一样本数据中随机抽取两个样本;从Beta分布中随机抽取一个插值比例,得到随机插值比例;根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值,得到第二样本数据。
在一个可选地实施例中,根据如下公式进行随机插值,得到第二样本数据:
λ~Beta(α,α)
其中,{xi,yi}和{xj,yj}表示抽取的样本数据,λ表示插值的比例,Beta(α,α)表示beta分布,gk(xi)和gk(xj)表示xi和xj经过网络编码后的数据,表示根据插值比例λ将位置K的词的表示gk(xi)和gk(xj)进行插值融合后的增强数据,/>表示对于xi和xj对应的标签yi和yj进行插值融合后的增强数据。
在一个可选地实施例中,插值比例更新模块902具体用于根据预设的损失函数计算每个位置的插值损失;对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;根据得到的梯度更新随机插值比例,得到对抗方向上的最新插值比例。
在一个可选地实施例中,预设的损失函数如下公式所示:
其中,θ是模型的参数,i和j是真实标记数据的编号,frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,λ~Beta(α,α),Beta(α,α)表示beta分布,lmix表示插值的损失函数。
在一个可选地实施例中,根据如下公式计算当前的梯度,
其中,η表示梯度,Δλ表示随机插值比例的偏导,表示损失值。
在一个可选地实施例中,根据如下公式计算最新插值比例,
λ′=λ+εη
其中,λ′表示最新插值比例,ε表示步长,η表示梯度。
根据本公开实施例提供的基于对抗插值的数据增强装置,增强程度可以控制,可以生成更“难”的增强样本,从而对模型在低资源情况下提升更明显,提高模型对高难度样本的识别准确率,解决标注数据少影响模型准确度的问题。而且可以用于文本、图像和音频等多种数据分类,适用领域很广泛。
需要说明的是,上述实施例提供的基于对抗插值的数据增强装置在执行基于对抗插值的数据增强方法时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的基于对抗插值的数据增强装置与基于对抗插值的数据增强方法实施例属于同一构思,其体现实现过程详见方法实施例,这里不再赘述。
在一个实施例中,提出了一种计算机设备,计算机设备包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现以下步骤:获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据;通过梯度下降算法调整插值比例,得到更新后的插值比例;根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
在一个可选地实施例中,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据,包括:
从第一样本数据中随机抽取两个样本;
从Beta分布中随机抽取一个插值比例,得到随机插值比例;
根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值,得到第二样本数据。
在一个可选地实施例中,根据如下公式进行随机插值,得到第二样本数据:
λ~Beta(α,α)
其中,{xi,yi}和{xj,yj}表示抽取的样本数据,λ表示插值的比例,Beta(α,α)表示beta分布,gk(xi)和gk(xj)表示xi和xj经过网络编码后的数据,表示根据插值比例λ将位置K的词的表示gk(xi)和gk(xj)进行插值融合后的增强数据,/>表示对于xi和xj对应的标签yi和yj进行插值融合后的增强数据。
在一个可选地实施例中,通过梯度下降方法调整插值比例,得到更新后的插值比例,包括:
根据预设的损失函数计算每个位置的插值损失;
对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;
根据得到的梯度更新随机插值比例,得到对抗方向上的最新插值比例。
在一个可选地实施例中,预设的损失函数如下公式所示:
其中,θ是模型的参数,i和j是真实标记数据的编号,frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,λ~Beta(α,α),Beta(α,α)表示beta分布,lmix表示插值的损失函数。
在一个可选地实施例中,根据如下公式计算当前的梯度,
其中,η表示梯度,Δλ表示随机插值比例的偏导,表示损失值。
在一个可选地实施例中,根据如下公式计算最新插值比例,
λ′=λ+εη
其中,λ′表示最新插值比例,ε表示步长,η表示梯度。
在一个实施例中,提出了一种存储有计算机可读指令的存储介质,该计算机可读指令被一个或多个处理器执行时,使得一个或多个处理器执行以下步骤:获取已标注的第一样本数据,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据;通过梯度下降方法调整插值比例,得到更新后的插值比例;根据更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
在一个可选地实施例中,根据mixup算法对第一样本数据进行随机插值,得到第二样本数据,包括:
从第一样本数据中随机抽取两个样本;
从Beta分布中随机抽取一个插值比例,得到随机插值比例;
根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值,得到第二样本数据。
在一个可选地实施例中,根据如下公式进行随机插值,得到第二样本数据:
λ~Beta(α,α)
其中,{xi,yi}和{xj,yj}表示抽取的样本数据,λ表示插值的比例,Beta(α,α)表示beta分布,gk(xi)和gk(xj)表示xi和xj经过网络编码后的数据,表示根据插值比例λ将位置K的词的表示gk(xi)和gk(xj)进行插值融合后的增强数据,/>表示对于xi和xj对应的标签yi和yj进行插值融合后的增强数据。
在一个可选地实施例中,通过梯度下降方法调整插值比例,得到更新后的插值比例,包括:
根据预设的损失函数计算每个位置的插值损失;
对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;
根据得到的梯度更新随机插值比例,得到对抗方向上的最新插值比例。
在一个可选地实施例中,预设的损失函数如下公式所示:
其中,θ是模型的参数,i和j是真实标记数据的编号,frand表示随机插值操作,λ表示插值的比例,η表示对抗噪音,λ~Beta(α,α),Beta(α,α)表示beta分布,lmix表示插值的损失函数。
在一个可选地实施例中,根据如下公式计算当前的梯度,
其中,η表示梯度,Δλ表示随机插值比例的偏导,表示损失值。
在一个可选地实施例中,根据如下公式计算最新插值比例,
λ′=λ+εη
其中,λ′表示最新插值比例,ε表示步长,η表示梯度。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,该计算机程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,前述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)等非易失性存储介质,或随机存储记忆体(Random Access Memory,RAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上实施例仅表达了本发明的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对本发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进,这些都属于本发明的保护范围。因此,本发明专利的保护范围应以所附权利要求为准。
Claims (7)
1.一种基于对抗插值的样本数据增强方法,其特征在于,包括:
获取已标注的第一样本数据,根据mixup算法对所述第一样本数据进行随机插值,得到第二样本数据;所述第一样本数据为文本、图像或音频中的一种数据;
通过梯度下降方法调整插值比例,得到更新后的插值比例,包括:
根据预设的损失函数计算每个位置的插值损失;其中,所述预设的损失函数如下公式所示:,其中,是模型的参数,i和j是真实标记数据的编号,/>表示随机插值操作,/>表示插值的比例,/>,/>表示beta分布,/>表示插值的损失函数;/>表示交叉熵损失函数;/>和/>为真实标记数据的标签;/>表示根据插值比例λ将位置K的词的表示进行插值融合后的增强数据;/>为所述增强数据的标签;
对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;所述当前的梯度的计算公式为:,其中,/>表示梯度,代表对抗噪音,/>表示随机插值比例的偏导,/>表示损失值;
根据得到的梯度更新所述随机插值比例,得到对抗方向上的最新插值比例;
根据所述更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
2.根据权利要求1所述的方法,其特征在于,根据mixup算法对所述第一样本数据进行随机插值,得到第二样本数据,包括:
从所述第一样本数据中随机抽取两个样本;
从Beta分布中随机抽取一个插值比例,得到随机插值比例;
根据抽取的样本数据、随机插值比例以及mixup算法进行随机插值,得到第二样本数据。
3.根据权利要求2所述的方法,其特征在于,根据如下公式进行随机插值,得到所述第二样本数据:
其中,{,/>}和{/>,/>}表示抽取的样本数据,λ表示插值的比例,/>表示beta分布,/>和/>表示/>和/>经过网络编码后的数据,/>表示根据插值比例λ将位置K的词的表示/>和/>进行插值融合后的增强数据,/>表示对于/>和/>对应的标签/>和/>进行插值融合后的增强数据。
4.根据权利要求1所述的方法,其特征在于,根据如下公式计算最新插值比例,
其中,表示最新插值比例,/>表示步长,/>表示梯度。
5.一种基于对抗插值的样本数据增强装置,其特征在于,包括:
第一插值模块,用于获取已标注的第一样本数据,根据mixup算法对所述第一样本数据进行随机插值,得到第二样本数据;所述第一样本数据为文本、图像或音频中的一种数据;
插值比例更新模块,用于通过梯度下降方法调整插值比例,得到更新后的插值比例;所述插值比例更新模块,具体用于根据预设的损失函数计算每个位置的插值损失;对随机插值比例求偏导,根据随机插值比例的偏导值以及损失值计算当前的梯度;根据得到的梯度更新所述随机插值比例,得到对抗方向上的最新插值比例;其中,所述预设的损失函数如下公式所示:,其中,/>是模型的参数,i和j是真实标记数据的编号,/>表示随机插值操作,/>表示插值的比例,/>表示对抗噪音,/>,/>表示beta分布,/>表示插值的损失函数;/>表示交叉熵损失函数;/>和/>为真实标记数据的标签;/>表示根据插值比例λ将位置K的词的表示进行插值融合后的增强数据;/>为所述增强数据的标签;所述当前的梯度的计算公式为:/>,其中,/>表示梯度,代表对抗噪音,/>表示随机插值比例的偏导,/>表示损失值;/>表示根据插值比例λ将位置K的词的表示进行插值融合后的增强数据;/>为所述增强数据的标签;
第二插值模块,用于根据所述更新后的插值比例,重新进行插值运算,得到增强的第三样本数据。
6.一种计算机设备,包括存储器和处理器,所述存储器中存储有计算机可读指令,所述计算机可读指令被所述处理器执行时,使得所述处理器执行如权利要求1至4中任一项权利要求所述的基于对抗插值的样本数据增强方法的步骤。
7.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质包括计算机指令,所述计算机指令被一个或多个处理器执行时,实现如权利要求1至4中任一项权利要求所述的基于对抗插值的样本数据增强方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110730469.0A CN113435519B (zh) | 2021-06-29 | 2021-06-29 | 基于对抗插值的样本数据增强方法、装置、设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110730469.0A CN113435519B (zh) | 2021-06-29 | 2021-06-29 | 基于对抗插值的样本数据增强方法、装置、设备及介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113435519A CN113435519A (zh) | 2021-09-24 |
CN113435519B true CN113435519B (zh) | 2024-03-01 |
Family
ID=77758026
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110730469.0A Active CN113435519B (zh) | 2021-06-29 | 2021-06-29 | 基于对抗插值的样本数据增强方法、装置、设备及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113435519B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115455177B (zh) * | 2022-08-02 | 2023-07-21 | 淮阴工学院 | 基于混合样本空间的不平衡化工文本数据增强方法及装置 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112836820A (zh) * | 2021-01-31 | 2021-05-25 | 云知声智能科技股份有限公司 | 用于图像分类任务的深度卷积网络训方法、装置及系统 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11461537B2 (en) * | 2019-11-13 | 2022-10-04 | Salesforce, Inc. | Systems and methods of data augmentation for pre-trained embeddings |
-
2021
- 2021-06-29 CN CN202110730469.0A patent/CN113435519B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112836820A (zh) * | 2021-01-31 | 2021-05-25 | 云知声智能科技股份有限公司 | 用于图像分类任务的深度卷积网络训方法、装置及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN113435519A (zh) | 2021-09-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11670071B2 (en) | Fine-grained image recognition | |
CN108564129B (zh) | 一种基于生成对抗网络的轨迹数据分类方法 | |
CN111275107A (zh) | 一种基于迁移学习的多标签场景图像分类方法及装置 | |
CN113593611B (zh) | 语音分类网络训练方法、装置、计算设备及存储介质 | |
WO2023029356A1 (zh) | 基于句向量模型的句向量生成方法、装置及计算机设备 | |
CN111984792A (zh) | 网站分类方法、装置、计算机设备及存储介质 | |
CN111259137A (zh) | 知识图谱摘要的生成方法及系统 | |
CN111444346B (zh) | 一种用于文本分类的词向量对抗样本生成方法及装置 | |
CN113435519B (zh) | 基于对抗插值的样本数据增强方法、装置、设备及介质 | |
CN112328735A (zh) | 热点话题确定方法、装置及终端设备 | |
CN113297355A (zh) | 基于对抗插值序列标注数据增强方法、装置、设备及介质 | |
Jin et al. | Dual low-rank multimodal fusion | |
CN114611672A (zh) | 模型训练方法、人脸识别方法及装置 | |
CN113723077A (zh) | 基于双向表征模型的句向量生成方法、装置及计算机设备 | |
CN116363374B (zh) | 图像语义分割网络持续学习方法、系统、设备及存储介质 | |
CN111597336A (zh) | 训练文本的处理方法、装置、电子设备及可读存储介质 | |
CN115130437B (zh) | 一种文档智能填写方法、装置及存储介质 | |
CN114758130B (zh) | 图像处理及模型训练方法、装置、设备和存储介质 | |
CN114281950B (zh) | 基于多图加权融合的数据检索方法与系统 | |
CN114241411B (zh) | 基于目标检测的计数模型处理方法、装置及计算机设备 | |
WO2022126917A1 (zh) | 基于深度学习的人脸图像评估方法、装置、设备及介质 | |
CN111091198A (zh) | 一种数据处理方法及装置 | |
CN111275201A (zh) | 一种基于子图划分的图半监督学习的分布式实现方法 | |
CN116563642B (zh) | 图像分类模型可信训练及图像分类方法、装置、设备 | |
CN116456289B (zh) | 一种富媒体信息处理方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |