CN114120041A - 一种基于双对抗变分自编码器的小样本分类方法 - Google Patents
一种基于双对抗变分自编码器的小样本分类方法 Download PDFInfo
- Publication number
- CN114120041A CN114120041A CN202111432553.0A CN202111432553A CN114120041A CN 114120041 A CN114120041 A CN 114120041A CN 202111432553 A CN202111432553 A CN 202111432553A CN 114120041 A CN114120041 A CN 114120041A
- Authority
- CN
- China
- Prior art keywords
- network
- data
- classification
- sub
- encoder
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 31
- 238000012549 training Methods 0.000 claims abstract description 44
- 238000012360 testing method Methods 0.000 claims abstract description 14
- 238000007781 pre-processing Methods 0.000 claims abstract description 4
- 238000004140 cleaning Methods 0.000 claims abstract description 3
- 238000009826 distribution Methods 0.000 claims description 21
- 210000000601 blood cell Anatomy 0.000 claims description 17
- 238000011156 evaluation Methods 0.000 claims description 9
- 238000005457 optimization Methods 0.000 claims description 6
- 238000005070 sampling Methods 0.000 claims description 6
- 238000009499 grossing Methods 0.000 claims description 5
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 230000003190 augmentative effect Effects 0.000 claims description 3
- 239000003623 enhancer Substances 0.000 claims description 3
- 238000007476 Maximum Likelihood Methods 0.000 claims description 2
- 210000001772 blood platelet Anatomy 0.000 claims description 2
- 238000012938 design process Methods 0.000 claims description 2
- 210000003743 erythrocyte Anatomy 0.000 claims description 2
- 210000000265 leukocyte Anatomy 0.000 claims description 2
- 238000013135 deep learning Methods 0.000 abstract description 6
- 239000000284 extract Substances 0.000 abstract description 3
- 210000002569 neuron Anatomy 0.000 description 20
- 230000006870 function Effects 0.000 description 15
- 238000011176 pooling Methods 0.000 description 9
- 208000009119 Giant Axonal Neuropathy Diseases 0.000 description 7
- 201000003382 giant axonal neuropathy 1 Diseases 0.000 description 7
- 238000010606 normalization Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 6
- 238000013145 classification model Methods 0.000 description 4
- 238000000605 extraction Methods 0.000 description 3
- 230000003213 activating effect Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 210000004027 cell Anatomy 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 210000004369 blood Anatomy 0.000 description 1
- 239000008280 blood Substances 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 230000010355 oscillation Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000007637 random forest analysis Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 239000013589 supplement Substances 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/10—Pre-processing; Data cleansing
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于双对抗变分自编码器的小样本分类方法,解决现有分类方法在小样本下网络难以训练、准确率较低的问题。该分类方法包括:数据预处理,对目标数据集进行清洗、填充和归一化;模型设计与优化步骤,利用VAE和GAN设计相应的网络结构,并针对过拟合和训练产生震荡等问题进行模型的优化;模型训练步骤,利用小样本数据集对模型进行训练,进而获得网络模型权重;模型预测步骤,输入测试集对模型进行预测,对比现有的一些经典深度学习分类方法,验证本发明的有效性。本发明公开的方法中构建的模型能够在小样本情况下实现数据增强,并且能够提取有效特征从而提高分类的准确率,普遍适用于不同类型的分类任务。
Description
技术领域
本发明涉及深度学习任务分类技术领域,具体涉及一种基于双对抗变分自编码器的小样本分类方法。
背景技术
深度学习已经广泛的用于各种行业当中,成为了解决问题的关键方法和技巧。其中,分类任务是该领域研究的重点,对于一些复杂任务来说,通过大量数据集的迭代训练可以取得较好的准确率。然而,模型训练所需的数据集并不总是足够的。这种数据驱动的训练方法使网络模型的性能很大程度上受数据集数量的影响。
由于不同分类任务的数据域之间存在很大差异,导致已有的一些小样本相关深度学习方法,如迁移学习,度量学习等受到了限制。特别是应用在非图像数据集时,容易出现准确率较低,泛化能力差的问题,并且随着深度神经网络层数的增加,网络容易出现过拟合,使用浅层网络又不能提取到有效的特征。在深度学习中,生成模型相比于判别模型更加注重对样本内在分布的学习。常用于小样本问题研究的生成式网络模型的代表是VAE和GAN。但VAE生成的样本质量没有GAN的高,GAN由于没有编码、解码这种结构使得训练起来比较困难,同时还容易产生模式崩塌的问题,无法捕捉到全部的样本分布。目前亟待针对这一系列问题,设计相应合理的深度学习网络框架,提高小样本分类的准确率。
发明内容
本发明的目的是为了解决现有技术中的上述缺陷,提供一种基于双对抗变分自编码器的小样本分类方法。该方法结合VAE和GAN的特性,在扩展样本空间的同时提取样本特征,实现小样本的分类任务。
本发明的目的可以通过采取如下技术方案达到:
一种基于双对抗变分自编码器的小样本分类方法,所述小样本分类方法包括以下步骤:
S1、数据预处理,由于原始待分类数据存在大量缺失和冗余值,因此对其进行清洗、填充和归一化,并划分成训练集、验证集和测试集;述待分类数据为血液细胞数据集和手写数字识别数据集,所述血液细胞数据集包括红细胞、白细胞和血小板的细胞数量与形态,所述手写数字识别数据集包括0-9的手写数字;
S2、构建基于双对抗变分自编码器的小样本分类网络模型,该网络模型包括三个子网络,分别是对数据进行特征编码的特征编码子网络、对数据进行扩充并对扩充数据以及其特征编码进行判别的数据增强子网络和对数据进行分类的分类子网络;
S3、模型训练,输入训练集,对特征编码子网络、数据增强子网络以及分类子网络设计损失函数,通过梯度下降来更新网络的参数,实现基于双对抗变分自编码器的小样本分类网络模型的收敛;
S4、模型预测,输入测试集,利用分类子网络,完成小样本的分类结果,得到小样本分类网络模型的分类准确率,其中,血液细胞数据集分类结果包括以下三种:血液细胞浓度指标正常、血液细胞浓度指标低于正常值、血液细胞浓度指标明显超过正常值一个数量级以上;手写数字识别数据集的分类结果为识别0-9的数字。
进一步地,所述步骤S2中三个子网络的处理过程如下:
所述特征编码子网络与数据增强子网络、分类子网络的神经网络共享同一组参数。其输入是原始数据x,输出是经编码后的重构数据x′;特征编码子网络采用VAE,目的是为了将数据投影到特定的潜在空间中,以便通过采样隐变量生成与原始样本不同的新数据,实现数据扩充。VAE包括一个编码器网络D和一个解码器网络E,其中,编码器网络D将原始数据x投影到特定的潜在空间中,实现输入数据的特征编码,特征编码直接决定了分类的效果;解码器网络E通过在潜在空间采样还原原始数据x,实现数据的重构;
所述数据增强子网络的输入是原始数据x,输出是扩充数据x″。为了实现小样本数据的扩充,在特征编码子网络训练完成后,将原始数据x再次送入特征编码子网络的编码器网络D中进行特征编码z*,将特征编码z*联合真实标签y送入数据增强子网络的生成子网络中得到扩充数据x″。真实标签y的加入,不仅能产生特定标签的数据,还能够提高生成数据的质量。数据增强子网络包括数据判别子网络Dx和特征判别子网络Dz,将扩充数据x″输入编码器网络D中得到特征编码z″,实现对扩充数据的特征提取。使用Dx判别数据增强子网络生成的扩充数据x″与原始数据x的差异,使得x″更加符合真实数据的分布。使用Dz判别特征编码z″与VAE中先验分布z的差异,使得z″符合真实的先验分布。上述子网络使得模型可以生成与原始数据不同的新数据,相当于扩大了训练集,有利于提高测试集的分类效果。
所述分类子网络,用于完成分类,输入是原始数据x,输出是模型分类正确的概率,所述分类子网络使用特征编码子网络中的编码器网络D作为神经网络来提取特征。
进一步地,所述步骤S2中还包括小样本分类网络模型的网络优化,过程如下:
将数据判别子网络Dx和特征判别子网络Dz的标签改成软标签方式,软标签相当于在标签中加入了随机噪声,可以防止判别子网络的判别效果太过绝对,在一定程度上解决生成网络梯度消失的问题;在分类子网络的原始标签中加入标签平滑,标签平滑使得在计算交叉熵损失时,所有标签位置都参与计算,解决原始交叉熵函数只考虑分类正确标签位置的损失,忽略错误标签位置的损失,提高模型的容错能力和泛化能力;采取误判样本重训练,为了让模型在训练过程中更加关注那些分类错误的样本,采取误判样本重训练的技巧,从而加快模型的收敛速度。
进一步地,所述步骤S3中模型训练通过优化损失函数,实现模型的收敛,其中,所述损失函数设计过程如下:
优化特征编码子网络生成的重构数据x′与原始数据x之间的差异:设置特征编码子网络损失函数,如下所示:
LVAE=-EQ(z|x)[log P(x|z)]+DKL[Q(z|x)||P(z)]
该损失函数由极大似然估计(即重构误差)和后验概率组成,其中,Q(z|x)表示近似后验概率分布,P(x|z)表示VAE的解码器,P(z)表示z的原始分布,DKL表示计算KL散度。其中,第一项越小,说明VAE隐变量映射更准确,同时也影响对抗网络生成器G的性能,因为G是从该先验分布中采样生成的数据,此过程也是预训练G的过程,在一定程度缓解GAN的不稳定。
优化条件式判别子网络的差异:设置条件式判别网络损失函数,即判别生成的扩充数据x″和原始数据x、扩充数据x″对应的特征编码z″和先验分布z之间的差异,条件式表达损失函数设计如下:
其中,m表示样本大小,xi、yi、zi分别表示第i个样本、第i个样本的标签、对第i个样本进行先验分布的采样;x″i和z″i代表通过第i个样本生成的扩充数据以及其特征编码;Dx(xi,yi)与Dx(x″i,yi)是GAN加入标签信息y后,数据判别子网络Dx对原始数据x和扩充数据x″的评判结果,Dz(zi)与Dz(z″i)分别为特征判别子网络Dz对先验分布z和特征编码z″的评判结果,该网络的目标是最大化LD的值,优化判别器参数;
优化条件式生成网络的差异:设置条件式生成网络损失函数,通过判别器对扩充数据x″i即其特征编码z″i的判别结果来更新生成网络的参数,条件式表达损失函数设计如下:
其中,Dz(z″i)是特征判别子网络Dz对特征编码z″的评判结果,Dx(x″i,yi)是数据判别子网络Dx对扩充数据x″的评判结果,Dz(zi)与Dz(z″i)分别为特征判别子网络Dz对先验分布z和特征编码z″的评判结果,该网络的目标是最小化LG的值,优化生成器的参数;
优化分类子网络中的分类结果与真实标签之间的差异:设置分类子网络损失函数如下所示:
其中,n表示标签类别数,当标签采用one-hot形式时,y(ij)表示第i个样本真实标签第j个位置的值,y′ij表示第i个样本预测标签第j个位置的值;
整个网络模型的损失函数如下:L=LVAE-LD+LG+LC
通过不断优化损失函数,实现基于双对抗变分自编码器的小样本分类网络模型的收敛。
本发明相对于现有技术具有如下的优点及效果:
1、本发明提出的基于双对抗变分自编码器的小样本分类方法中设计了一种小样本分类网络模型。利用VAE实现对原始数据的特征提取,通过VAE将输入数据映射到隐变量空间并还原这一特性,采样训练空间的隐变量来进行数据扩充,增加训练集的数量;使用两个GAN对扩充数据和对应的特征编码进行对抗训练,增加扩充数据的真实性的同时完成扩充数据的特征提取,使得网络模型可以生成一些符合原始数据分布但与原始数据不同的新数据;分类网络使用VAE的编码器来进行分类训练,增加了特征提取能力,从而在测试集上达到一个很好的分类效果。
2、本发明所使用的三个级联网络是参数共享的,从而降低了复杂网络的参数量。特征编码子网络是数据增强子网络和分类子网络的基础,数据增强子网络利用特征编码子网络实现数据增强,分类子网络加强了特征编码子网络的特征编码能力。三个子网络相辅相成、互相影响,这种预训练的效果减少了网络模型Loss值的震荡。本发明所述的分类方法普遍适用于所有的分类任务。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是本发明公开的一种基于双对抗变分自编码器的小样本分类模型的整体架构示意图;
图2是本发明公开的一种基于双对抗变分自编码器的小样本分类模型的网络结构图;
图3是本发明公开的一种基于双对抗变分自编码器的小样本分类模型的网络训练流程图;
图4是本发明公开的一种基于双对抗变分自编码器的小样本分类模型的图片分类网络结构图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
实施例1
本实施例以来自测试组血液各细胞真实的数量和状态为具体实例,血液细胞数据集包含了3个类别(分别为血液细胞浓度指标正常、血液细胞浓度指标低于正常值、血液细胞浓度指标明显超过正常值一个数量级以上)以及3645个样本,其中,每个样本包含43个变量。
该基于双对抗变分自编码器的小样本分类方法包括以下步骤:
S1、数据预处理,首先去除缺失值较多的变量所在的列,接着采用K近邻算法(KNN)对剩下的缺失值进行填充,最后使用sklearn特征化API对输入进行标准归一化,并将标签转化成one-hot形式,经过处理后的每个样本含29个变量。其中,训练集包含了2916个样本,测试集包含了729个样本;
S2、构建基于双对抗变分自编码器的小样本分类网络模型,该网络模型包括三个子网络,分别是对数据进行特征编码的特征编码子网络、对数据进行扩充并对扩充数据以及其特征编码进行判别的数据增强子网络和对数据进行分类的分类子网络;
附图2为网络结构图。各子网络具体结构如下:
特征编码子网络:从输入层至输出层依次连接为:输入层Input_x,为29个神经元;500个神经元的编码层Encoder_h1;批量标准化层BN;500个神经元的编码层Encoder_h2,批量标准化层BN;500个神经元的编码层Encoder_h3,500个神经元的特征编码层Encoder_z1;批量标准化层BN;500个神经元的特征编码层Encoder_z2;5个神经元的均值层mean和5个神经元的方差层log_var;5个神经元的隐变量生成层z;500个神经元的解码层Decoder_h1;批量标准化层BN;500个神经元的特征编码层Decoder_h2;批量标准化层BN;500个神经元的特征编码层Decoder_h3;
数据增强子网络:生成网络为三个1000个神经元的全连接层,两个判别子网络均为1000个神经元的全连接层;
分类子网络:该子网络共享特征编码子网络的前5层,再依次接入1000个神经元的全连接层cls_h1;批量标准化层BN;1000个神经元的全连接层cls_h2;n个神经元的输出层,n为类别的个数,取值为3;
S3、模型训练,输入训练集,对特征编码子网络、数据增强子网络以及分类子网络进行训练,通过梯度下降来更新网络的参数,三个子网络交替训练,直到网络收敛,完成基于双对抗变分自编码器的小样本分类网络模型的训练阶段。图3是模型训练流程图,包括以下步骤;
S31、对训练集T={(xi,yi),...(xn,yn)},采样一个batch大小的数据 其中,T表示训练集,n表示训练集个数,置为2916,X为从T中采样一个batch的数据量,m为batch大小,置为100,(xi,yi)表示该batch的第i个样本及其标签。将X送入特征编码子网络的编码器网络D和解码器网络E中进行特征编码和解码,更新特征子网络的参数;
S32、将X放入特征编码子网络的编码器网络D得到zi为第i个样本产生的隐变量,联合真实标签yi送入数据增强子网络的生成网络中得到新产生的数据集x″i为用第i个样本产生的同类数据。将X″置于特征编码子网络的编码器网络D中用于生成特征编码 z″i为第i个x″i的特征编码。将X″和X送入数据判别子网络Dx进行判别,对判别器子网络的标签采用软标签,对于X,将标签值置为0.8~1的随机值,对于X″,将标签值置为0~0.2的随机值;同时,将特征编码Z″和先验分布Z送入特征判别子网络Dz进行判别,对于Z,将标签值置为1,对于Z″,将标签值置0,更新判别器参数;
S33、根据数据判别子网络Dx和特征判别子网络Dz的判别结果,Dx对X″判别为1,Dz对Z″判别为1,更新生成器的参数,即更新特征编码子网络的编码器网络D和解码器网络E的参数;
S34、将X输入到分类子网络中,激活函数选取Sigmoid,得到预测结果,对真实标签进行标签平滑操作,
其中,k为类别个数,值为3,yk表示one-hot形式的标签第k个位置的值,α为平滑因子,本实验取值为0.2,对真实标签作如上的改变后,将预测结果与真实标签做交叉熵,更新分类子网络参数;
S35、对误判样本重新训练,对于分类子网络分类错误的样本,再进行一次分类子网络的训练。最后判断网络是否已经收敛,如果loss值不再下降,则停止训练,否则继续执行S31-S34;
S4、模型预测,将测试集输入分类子网络中,得到分类结果,计算出最终的分类准确率。
本实施例基于TensorFlow框架和Linux平台下的Pycharm开发环境。特征编码子网络的学习率设置为0.0001、判别子网络以及生成网络的学习率设置为的学习率设置为0.00001、分类子网络的学习率设置为0.001,均采用Adam优化器,训练的batch设置为100,网络迭代500次。在729个血液细胞测试集上得到93.49%准确率,对比经典机器学习算法随机森林和贝叶斯分别高出5%和12%,验证了本发明在小样本情况下实现数据增强,并且能够提取有效特征从而提高分类的准确率。
实施例2
本实施例以来自1000张手写数字的图片为例,手写数字数据集包含了10个类别(分别识别0-9的10个数字),其中,每张图片大小为28x28。
该基于双对抗变分自编码器的小样本分类方法包括以下步骤:
S1、将手写数字数据集进行归一化,并将标签转化成one-hot形式。按8:2将数据集划分,其中,训练集包含800张图片,测试集包含了200张图片;
S2、构建三个子网络,分别是特征编码子网络、数据增强子网络和分类子网络,附图4为网络结构图;
各子网络具体结构如下:
特征编码子网络:输入28x28x3大小的图片,编码部分padding为1,卷积核大小均为2x2,经过第一个卷积和池化输出大小为28x28x16,经过第二个卷积和池化输出大小为为14x14x32,经过第三个卷积和池化输出大小为7x7x64;经过Flatten层接入两个全连接层,输出100x1,采样生成的z为100x1;经过第四个卷积和池化输出大小为7x7x64,经过第五个卷积和池化输出大小为为14x14x32,经过第六个卷积和池化输出大小为28x28x16,最终输出28x28x3;
数据增强子网络:生成子网络输入为110x1,经过第一个卷积和池化输出大小为7x7x64,经过第二个卷积和池化输出大小为为14x14x32,经过第三个卷积和池化输出大小为28x28x16,最终输出28x28x3;两个判别子网络分别为1000个神经元的全连接层,最后一层神经元个数为10;
分类子网络:该子网络共享特征编码子网络的前5层,再依次接入2个1000个神经元的全连接层;
S3、模型训练,输入训练集,通过梯度下降来更新三个子网络的参数,直到网络收敛。包括以下步骤;
S31、对训练集T={(xi,yi),...(xn,yn)},采样一个batch大小的数据 其中,n置为800,X为从T中采样一个batch的数据量,m为batch大小,置为100,(xi,yi)表示该batch的第i个样本及其标签。将X送入D和E中进行特征编码和解码,更新特征子网络的参数;
S32、将X放入D得到联合真实标签yi送入生成网络中得到x″i为用第i个样本产生的同类数据。将X″置于D中用于生成特征编码z″i为第i个x″i的特征编码。将X″和X送入Dx进行判别,对判别器子网络的标签采用软标签,对于X,将标签值置为0.8~1的随机值,对于X″,将标签值置为0~0.2的随机值;同时,将Z″和先验分布Z送入Dz进行判别,对于Z,将标签值置为1,对于Z″,将标签值置0,更新判别器参数;
S33、根据Dx和Dz的判别结果,Dx对X″判别为1,Dz对Z″判别为1,更新生成器的参数;
S34、将X输入到分类子网络中,激活函数选取Sigmoid,得到预测结果,对真实标签进行标签平滑操作,
其中,k为10,yk表示one-hot形式的标签第k个位置的值,α为平滑因子,本实验取值为0.2,对真实标签作如上的改变后,将预测结果与真实标签做交叉熵,更新分类子网络参数;
S35、对误判样本重新训练,对于分类子网络分类错误的样本,再进行一次分类子网络的训练。最后判断网络是否已经收敛,如果loss值不再下降,则停止训练,否则继续执行S31-S34;
S4、模型预测,将测试集输入分类子网络中,得到分类结果,计算出最终的分类准确率。
本实施例在200个手写数字测试集上得到88.49%准确率,对比卷积神经网络CNN高出6%,验证了本发明在小样本情况下实现数据增强,并且能够提取有效特征从而提高分类的准确率。
上述实施例为本发明较佳的实施方式,但本发明的实施方式并不受上述实施例的限制,其他的任何未背离本发明的精神实质与原理下所作的改变、修饰、替代、组合、简化,均应为等效的置换方式,都包含在本发明的保护范围之内。
Claims (4)
1.一种基于双对抗变分自编码器的小样本分类方法,其特征在于,所述小样本分类方法包括下列步骤:
S1、数据预处理,对待分类数据进行清洗、填充和归一化,并划分成训练集和测试集,所述待分类数据为血液细胞数据集和手写数字数据集,所述血液细胞数据集包括红细胞、白细胞和血小板的浓度指标,所述手写数字识别数据集包括0-9的手写数字;
S2、构建基于双对抗变分自编码器的小样本分类网络模型,该网络模型包括三个级联的子网络,分别是对数据进行特征编码的特征编码子网络、对数据进行扩充并对扩充数据以及其特征编码进行判别的数据增强子网络和对数据进行分类的分类子网络;
S3、模型训练,输入训练集,对特征编码子网络、数据增强子网络以及分类子网络设计损失函数,通过梯度下降来更新网络的参数,实现基于双对抗变分自编码器的小样本分类网络模型的收敛;
S4、模型预测,输入测试集,利用分类子网络,完成小样本的分类结果,得到小样本分类网络模型的分类准确率,其中,血液细胞数据集分类结果包括以下三种:血液细胞浓度指标正常、血液细胞浓度指标低于正常值、血液细胞浓度指标超过正常值一个数量级以上;手写数字数据集的分类结果为识别图片表示的0-9的数字。
2.根据权利要求1所述的一种基于双对抗变分自编码器的小样本分类方法,其特征在于,所述特征编码子网络的输入是原始数据x,输出是经编码后的重构数据x′;特征编码子网络包括变分自编码器VAE,变分自编码器VAE包括一个编码器网络D和一个解码器网络E,其中,编码器网络D将原始数据x投影到特定的潜在空间中,解码器网络E通过在潜在空间采样还原原始数据x,从而实现原始数据x的特征编码;
所述数据增强子网络的输入是原始数据x,输出是扩充数据x″,数据增强子网络用于对小样本数据进行扩充;将原始数据x送入特征编码子网络的编码器网络D中进行特征编码得到z*,将特征编码z*联合真实标签y送入数据增强子网络中得到扩充数据x″,将扩充数据x″输入编码器网络D中进行特征编码得到z″,使用生成对抗网络GAN对扩充数据x″及其特征编码z″进行对抗训练,所述数据增强子网络包括数据判别子网络Dx和特征判别子网络Dz,分别用于判别数据增强子网络生成的扩充数据x″与原始数据x的差异,以及扩充数据x″进行特征编码生成的特征编码z″与变分自编码器VAE中先验分布z的差异;
所述分类子网络,用于完成分类,输入是原始数据x,输出是模型分类正确的概率,所述分类子网络使用特征编码子网络中的编码器网络D作为神经网络。
3.根据权利要求2所述的一种基于双对抗变分自编码器的小样本分类方法,其特征在于,所述小样本分类方法还包括小样本分类网络模型的网络优化,过程如下:
将数据判别子网络Dx和特征判别子网络Dz的标签改成软标签方式;在分类子网络的原始标签中加入标签平滑;采取误判样本重训练,对分类错误样本进行再训练,通过不断优化损失函数,实现基于双对抗变分自编码器的小样本分类网络模型的收敛。
4.根据权利要求1所述的一种基于双对抗变分自编码器的小样本分类方法,其特征在于,所述步骤S3中模型训练通过优化损失函数,实现模型的收敛,其中,所述损失函数设计过程如下:
优化特征编码子网络生成的重构数据x′与原始数据x之间的差异:设置特征编码子网络损失函数,如下所示:
LVAE=-EQ(Z|x)[log P(x|z)]+DKL[Q(z|x)||P(z)]
该损失函数由极大似然估计和后验概率组成,其中,Q(z|x)表示近似后验概率分布,P(x|z)表示VAE的解码器,P(z)表示z的原始分布,DKL表示计算KL散度;
优化条件式判别子网络的差异:设置条件式判别网络损失函数,即判别生成的扩充数据x″和原始数据x、扩充数据x″对应的特征编码z″和先验分布z之间的差异,条件式表达损失函数设计如下:
其中,m表示样本大小,xi、yi、zi分别表示第i个样本、第i个样本的标签以及对第i个样本进行先验分布的采样;x″i和z″i代表通过第i个样本生成的扩充数据以及其特征编码;Dx(xi,yi)与Dx(x″i,yi)是GAN加入标签信息y后,数据判别子网络Dx对原始数据x和扩充数据x″的评判结果,Dz(zi)与Dz(z″i)分别为特征判别子网络Dz对先验分布z和特征编码z″的评判结果,该网络的目标是最大化LD的值,优化判别器参数;
优化条件式生成网络的差异:设置条件式生成网络损失函数,通过判别器对扩充数据x″i即其特征编码z″i的判别结果来更新生成网络的参数,条件式表达损失函数设计如下:
其中,Dz(z″i)是特征判别子网络Dz对特征编码z″的评判结果,Dx(x″i,yi)是数据判别子网络Dx对扩充数据x″的评判结果,该网络的目标是最小化LG的值,优化生成器的参数;
优化分类子网络中的分类结果与真实标签之间的差异:设置分类子网络损失函数如下所示:
其中,n表示标签类别数,当标签采用one-hot形式时,y(ij)表示第i个样本真实标签第j个位置的值,y′ij表示第i个样本预测标签第j个位置的值;
整个小样本分类网络模型的损失函数如下:L=LVAE-LD+LG+LC。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111432553.0A CN114120041B (zh) | 2021-11-29 | 2021-11-29 | 一种基于双对抗变分自编码器的小样本分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111432553.0A CN114120041B (zh) | 2021-11-29 | 2021-11-29 | 一种基于双对抗变分自编码器的小样本分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114120041A true CN114120041A (zh) | 2022-03-01 |
CN114120041B CN114120041B (zh) | 2024-05-17 |
Family
ID=80371456
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111432553.0A Active CN114120041B (zh) | 2021-11-29 | 2021-11-29 | 一种基于双对抗变分自编码器的小样本分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114120041B (zh) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115291108A (zh) * | 2022-06-27 | 2022-11-04 | 东莞新能安科技有限公司 | 数据生成方法、装置、设备及计算机程序产品 |
CN115546652A (zh) * | 2022-11-29 | 2022-12-30 | 城云科技(中国)有限公司 | 一种多时态目标检测模型及其构建方法、装置及应用 |
WO2023168903A1 (zh) * | 2022-03-10 | 2023-09-14 | 腾讯科技(深圳)有限公司 | 模型训练和身份匿名化方法、装置、设备、存储介质及程序产品 |
WO2024016303A1 (zh) * | 2022-07-22 | 2024-01-25 | 京东方科技集团股份有限公司 | 分类模型的训练方法、分类方法、装置、电子设备及介质 |
CN117893528A (zh) * | 2024-03-13 | 2024-04-16 | 云南迪安医学检验所有限公司 | 一种心脑血管疾病分类模型的构建方法及装置 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109377452A (zh) * | 2018-08-31 | 2019-02-22 | 西安电子科技大学 | 基于vae和生成式对抗网络的人脸图像修复方法 |
CN110580501A (zh) * | 2019-08-20 | 2019-12-17 | 天津大学 | 一种基于变分自编码对抗网络的零样本图像分类方法 |
CN111563554A (zh) * | 2020-05-08 | 2020-08-21 | 河北工业大学 | 基于回归变分自编码器的零样本图像分类方法 |
CN111598805A (zh) * | 2020-05-13 | 2020-08-28 | 华中科技大学 | 一种基于vae-gan的对抗样本防御方法及系统 |
US20200320402A1 (en) * | 2019-04-08 | 2020-10-08 | MakinaRocks Co., Ltd. | Novelty detection using deep learning neural network |
CN112633386A (zh) * | 2020-12-26 | 2021-04-09 | 北京工业大学 | 基于sacvaegan的高光谱图像分类方法 |
CN113505477A (zh) * | 2021-06-29 | 2021-10-15 | 西北师范大学 | 一种基于svae-wgan的过程工业软测量数据补充方法 |
-
2021
- 2021-11-29 CN CN202111432553.0A patent/CN114120041B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109377452A (zh) * | 2018-08-31 | 2019-02-22 | 西安电子科技大学 | 基于vae和生成式对抗网络的人脸图像修复方法 |
US20200320402A1 (en) * | 2019-04-08 | 2020-10-08 | MakinaRocks Co., Ltd. | Novelty detection using deep learning neural network |
CN110580501A (zh) * | 2019-08-20 | 2019-12-17 | 天津大学 | 一种基于变分自编码对抗网络的零样本图像分类方法 |
CN111563554A (zh) * | 2020-05-08 | 2020-08-21 | 河北工业大学 | 基于回归变分自编码器的零样本图像分类方法 |
CN111598805A (zh) * | 2020-05-13 | 2020-08-28 | 华中科技大学 | 一种基于vae-gan的对抗样本防御方法及系统 |
CN112633386A (zh) * | 2020-12-26 | 2021-04-09 | 北京工业大学 | 基于sacvaegan的高光谱图像分类方法 |
CN113505477A (zh) * | 2021-06-29 | 2021-10-15 | 西北师范大学 | 一种基于svae-wgan的过程工业软测量数据补充方法 |
Non-Patent Citations (2)
Title |
---|
REMOTE SENSING: "Self-Attention-Based Conditional Variational Auto-Encoder Generative Adversarial Networks for Hyperspectral Classification", REMOTE SENSING, vol. 13, 21 August 2021 (2021-08-21) * |
苗壮;张湧;李伟华;: "基于双重对抗自编码网络的红外目标建模方法", 光学学报, no. 11, 10 June 2020 (2020-06-10) * |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023168903A1 (zh) * | 2022-03-10 | 2023-09-14 | 腾讯科技(深圳)有限公司 | 模型训练和身份匿名化方法、装置、设备、存储介质及程序产品 |
CN115291108A (zh) * | 2022-06-27 | 2022-11-04 | 东莞新能安科技有限公司 | 数据生成方法、装置、设备及计算机程序产品 |
WO2024016303A1 (zh) * | 2022-07-22 | 2024-01-25 | 京东方科技集团股份有限公司 | 分类模型的训练方法、分类方法、装置、电子设备及介质 |
CN115546652A (zh) * | 2022-11-29 | 2022-12-30 | 城云科技(中国)有限公司 | 一种多时态目标检测模型及其构建方法、装置及应用 |
CN117893528A (zh) * | 2024-03-13 | 2024-04-16 | 云南迪安医学检验所有限公司 | 一种心脑血管疾病分类模型的构建方法及装置 |
CN117893528B (zh) * | 2024-03-13 | 2024-05-17 | 云南迪安医学检验所有限公司 | 一种心脑血管疾病分类模型的构建方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN114120041B (zh) | 2024-05-17 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114120041B (zh) | 一种基于双对抗变分自编码器的小样本分类方法 | |
CN111581405B (zh) | 基于对偶学习生成对抗网络的跨模态泛化零样本检索方法 | |
CN110751698B (zh) | 一种基于混和网络模型的文本到图像的生成方法 | |
CN111860982A (zh) | 一种基于vmd-fcm-gru的风电场短期风电功率预测方法 | |
CN111428789A (zh) | 一种基于深度学习的网络流量异常检测方法 | |
CN114019370B (zh) | 基于灰度图像和轻量级cnn-svm模型的电机故障检测方法 | |
CN114842267A (zh) | 基于标签噪声域自适应的图像分类方法及系统 | |
CN113469236A (zh) | 一种自我标签学习的深度聚类图像识别系统及方法 | |
CN112464004A (zh) | 一种多视角深度生成图像聚类方法 | |
CN112784031B (zh) | 一种基于小样本学习的客服对话文本的分类方法和系统 | |
CN112784929A (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN112507778B (zh) | 一种基于线特征的改进词袋模型的回环检测方法 | |
CN112884758B (zh) | 一种基于风格迁移方法的缺陷绝缘子样本生成方法及系统 | |
CN114048468A (zh) | 入侵检测的方法、入侵检测模型训练的方法、装置及介质 | |
CN110991247B (zh) | 一种基于深度学习与nca融合的电子元器件识别方法 | |
CN115659254A (zh) | 一种双模态特征融合的配电网电能质量扰动分析方法 | |
CN112288700A (zh) | 一种铁轨缺陷检测方法 | |
CN114548199A (zh) | 一种基于深度迁移网络的多传感器数据融合方法 | |
CN115345222A (zh) | 一种基于TimeGAN模型的故障分类方法 | |
CN114821299A (zh) | 一种遥感图像变化检测方法 | |
CN111783688B (zh) | 一种基于卷积神经网络的遥感图像场景分类方法 | |
CN112699782A (zh) | 基于N2N和Bert的雷达HRRP目标识别方法 | |
CN115049852B (zh) | 一种轴承故障诊断方法、装置、存储介质及电子设备 | |
CN114387524B (zh) | 基于多层级二阶表征的小样本学习的图像识别方法和系统 | |
CN117011219A (zh) | 物品质量检测方法、装置、设备、存储介质和程序产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |