CN116663619B - 基于gan网络的数据增强方法、设备以及介质 - Google Patents
基于gan网络的数据增强方法、设备以及介质 Download PDFInfo
- Publication number
- CN116663619B CN116663619B CN202310942682.7A CN202310942682A CN116663619B CN 116663619 B CN116663619 B CN 116663619B CN 202310942682 A CN202310942682 A CN 202310942682A CN 116663619 B CN116663619 B CN 116663619B
- Authority
- CN
- China
- Prior art keywords
- similarity
- batch
- layer
- data
- signal
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 54
- 230000002708 enhancing effect Effects 0.000 claims abstract description 4
- 238000012549 training Methods 0.000 claims description 63
- 238000012545 processing Methods 0.000 claims description 21
- 230000003044 adaptive effect Effects 0.000 claims description 20
- 230000008569 process Effects 0.000 claims description 19
- 230000006870 function Effects 0.000 claims description 16
- 239000011159 matrix material Substances 0.000 claims description 12
- 230000004913 activation Effects 0.000 claims description 11
- 238000013135 deep learning Methods 0.000 claims description 10
- 238000001514 detection method Methods 0.000 claims description 9
- 238000000605 extraction Methods 0.000 claims description 8
- 238000005070 sampling Methods 0.000 claims description 8
- 238000010606 normalization Methods 0.000 claims description 6
- 230000003213 activating effect Effects 0.000 claims description 2
- 238000003491 array Methods 0.000 claims description 2
- 238000012544 monitoring process Methods 0.000 claims description 2
- 238000007781 pre-processing Methods 0.000 claims description 2
- 238000001914 filtration Methods 0.000 claims 1
- VZCCETWTMQHEPK-QNEBEIHSSA-N gamma-linolenic acid Chemical compound CCCCC\C=C/C\C=C/C\C=C/CCCCC(O)=O VZCCETWTMQHEPK-QNEBEIHSSA-N 0.000 claims 1
- 238000013527 convolutional neural network Methods 0.000 abstract description 10
- 238000013459 approach Methods 0.000 abstract description 2
- 230000007547 defect Effects 0.000 abstract description 2
- 238000010586 diagram Methods 0.000 description 5
- 238000012360 testing method Methods 0.000 description 4
- 230000001965 increasing effect Effects 0.000 description 3
- 230000004069 differentiation Effects 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000002059 diagnostic imaging Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000000691 measurement method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Data Exchanges In Wide-Area Networks (AREA)
- Image Analysis (AREA)
Abstract
本发明属于数据增强技术领域,具体公开了一种基于GAN网络的数据增强方法、设备以及介质。本发明针对数据集小等缺点,提出利用改进的GAN网络对信号进行数据增强,从而扩大数据集,针对当前GAN网络收敛速度慢、生成一些固定样本的问题,本发明设计了一种自适应波形检测器和小批量判别器,从而为提高GAN网络收敛速度提供了一条可行的途径。在此基础上,本发明构建了包括卷积神经网络模块的生成器等结构,通过所提出得到数据增强模型,有效地扩大了数据集的大小。
Description
技术领域
本发明属于数据增强技术领域,具体涉及一种基于GAN网络的数据增强方法、设备以及介质。
背景技术
随着人工智能的快速发展,人们对深度学习算法来对信号进行分析越来越感兴趣,这些算法的性能取决于可用于训练数据的质量和数量。然而,在进行医学图像、语音识别和自然语言处理等任务,数据收集和标注可能需要专业知识和大量的时间,导致可用公共数据集的大小通常非常小。
为了解决数据可用性有限的技术问题,数据增强技术已被广泛用于深度学习。数据增强包括通过对现有数据进行各种转换(如旋转、缩放和翻转)来创建新的样本。然而,这些技术并不总是适用,因为它们可能改变一些信号的潜在特征。
近年来,生成对抗网络(Generative Adversarial Networks, GAN)一直受到图像和时间序列领域的关注,当涉及到图像领域时,GAN已经被广泛用作扩大图像数据集的有效技术。
例如,专利文献1公开了一种基于生成对抗网络及卷积循环神经网络的单导联心电异常信号识别方法,该方法在利用生成对抗网络对数据增强时存在如下缺陷:
1. 在GAN网络训练过程中,存在一些常见问题,其中之一是模式崩溃,模式崩溃是指生成器在训练过程中只生成一些固定的样本,而不是生成多样性的样本。
2. 目前GAN网络训练太慢,网络收敛速度慢,训练时间很久才能达到纳什均衡。
参考文献
专利文献1 中国发明专利申请 公开号:CN111990989A,公开日:2020.11.27。
发明内容
本发明的目的在于提出一种基于GAN网络的数据增强方法,该方法基于改进的GAN网络,即利用带有自适应波形检测器和小批量判别器的生成对抗网络进行数据增强,用于扩大数据集,以解决数据集小、数据不平衡等缺点。
本发明为了实现上述目的,采用如下技术方案:
基于GAN网络的数据增强方法,包括如下步骤:
步骤1. 搭建基于GAN网络的数据增强模型;
搭建的数据增强模型包括生成器、小批量判别器以及自适应波形检测器;
信号在数据增强模型中的处理流程如下:
在生成器中,符合正态分布的随机噪声输入到生成器中,经过生成器生成一批生成信号,生成信号首先进入到自适应波形检测器中;
在自适应波形检测器中设定动态阈值;通过欧几里德距离进行波形的相似度监测;
如果生成信号与真实信号的相似度小于动态阈值,则将该条生成信号丢弃,将相似度大于或等于动态阈值的生成信号输入小批量判别器进行判别;
小批量判别器分为两个分支,分别是CNN网络分支和小批量判别分支;
信号在小批量判别器中的处理流程如下:
首先,从生成信号中选择一个小批次的样本作为小批量判别器的输入;
在小批量判别分支,将输入的样本与权重矩阵W相乘,得到一个表示样本相似性的张量,计算样本之间的差异,并求取绝对差异的和,使用指数函数对差异处理,得到小批量特征,作为小批量判别分支的输出;
在CNN网络分支,对生成信号进行深度学习特征提取,并与小批量判别分支的输出进行合并,合并后在全连接层进行分类0或1,用来判别真假;
步骤2. 利用训练数据集对基于GAN网络的数据增强模型进行训练,并利用训练好的基于GAN网络的数据增强模型,对输入的信号进行数据增强。
在上述基于GAN网络的数据增强方法的基础上,本发明还提出了一种计算机设备,该计算机设备包括存储器和一个或多个处理器。
所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,用于实现上面述及的基于GAN网络的数据增强方法的步骤。
在上述基于GAN网络的数据增强方法的基础上,本发明还提出了一种计算机可读存储介质,在计算机可读存储介质上存储有程序。
该程序被处理器执行时用于实现上述基于GAN网络的数据增强方法的步骤。
本发明具有如下优点:
1. 本发明基于GAN的数据增强网络可以提升有效生成大量数据,其中,加入自适应波形检测器可以避免生成器陷入低质量信号的循环,加快生成器向更好的方向演进,通过将波形相似度作为额外的训练准则,可以引导生成器朝着更接近原始信号的方向进行学习和优化。这有助于提高生成器生成信号的逼真程度和准确性,总之,加入波形检测器有助于提升生成器的性能和生成结果的质量,加快网络的收敛速度,并提供对生成过程的控制能力,这样可以使GAN网络在生成信号任务中更加稳定和可控,产生更好的结果。
2. 本发明提出的小批量GAN网络可以生成与原始数据相似的样本,使其具备与原始信号更接近的特征并保持了差异化,让模型更快地收敛到最优解,帮助生成器更好地学习真实数据的分布,从而生成更加逼真的数据,增加了数据集的多样性,使用小批量判别器可以避免只学习到数据分布中的部分模式,而未能覆盖整个数据分布,因此可以通过计算样本之间的差异来帮助避免模式崩溃的问题,它通过比较生成样本与其他样本之间的差异来激励生成器生成更多不同的样本,有助于改进GAN的性能和生成样本的质量。
附图说明
图1为本发明实施例中基于GAN网络的数据增强方法的网络结构图。
图2为本发明实施例中数据预处理流程图。
图3为本发明实施例中生成器的结构图。
图4为本发明实施例中小批量判别器的结构图。
图5为以人体心电信号为例,原始数据与使用小批量判别器生成的数据的时域波形图对比图。
图6为图5中的A部放大图。
图7为图5中的B部放大图。
图8为以人体心电信号为例,原始数据与使用未使用小批量判别器生成的数据的时域波形图对比图。
图9为图8中的C部放大图。
图10为图8中的D部放大图。
图11为GAN网络在加入自适应波形检测器和未加入自适应波形检测器的收敛对比图。
具体实施方式
下面结合附图以及具体实施方式对本发明作进一步详细说明:
实施例1
如图1所示,基于GAN网络的数据增强方法,包括如下步骤:
步骤1. 搭建基于GAN网络的数据增强模型。本实施例中搭建的数据增强模型包括生成器、小批量判别器以及自适应波形检测器。
其中,信号在数据增强模型中的处理流程如下:
在生成器中,符合正态分布的随机噪声输入到生成器中,经过生成器生成一批生成信号,生成信号首先进入到自适应波形检测器中。
在自适应波形检测器中设定动态阈值。通过欧几里德距离进行波形的相似度监测。
如果生成信号与真实信号的相似度小于动态阈值,则将该条生成信号丢弃,将相似度大于或等于动态阈值的生成信号输入小批量判别器进行判别。
小批量判别器分为两个分支,分别是CNN网络分支和小批量判别分支。
信号在小批量判别器中的处理流程如下:
首先,从生成信号中选择一个小批次的样本作为小批量判别器的输入。
在小批量判别分支,将输入的样本与权重矩阵W相乘,得到一个表示样本相似性的张量,计算样本之间的差异,并求取绝对差异的和,使用指数函数对差异处理,得到小批量特征,作为小批量判别分支的输出。
在CNN网络分支,对生成信号进行深度学习特征提取,并与小批量判别分支的输出进行合并,合并后在全连接层进行分类0或1,用来判别真假。
本发明利用带有自适应波形检测器和小批量判别器的生成对抗网络进行数据增强,用于扩大数据集,能够很好的解决数据集小、数据不平衡等问题。
下面对基于GAN网络的数据增强模型中的各个组成部分进行详细说明:
如图3所示,生成器整体结构主要以卷积神经网络为主体,其包括重塑层、卷积模块、展平层、全连接层以及Tanh激活函数。
其中,重塑层有两个,分别为第一、第二重塑层;卷积模块有三个,且每个卷积模块均包括上采样层、一维卷积层、批归一化层以及激活函数。
生成器的输入是一个100维随机噪声向量,其在生成器中的处理流程如下:
首先经过第一重塑层将输入噪声向量重塑维度为(100,1),并输入到卷积模块。
在卷积模块中使用上采样层向上采样,通过线性插值将输入的时间序列长度加倍,可以在生成器中逐步增加时间序列的长度,并在后续的卷积层中进行处理和学习。
在上采样层后加入一维卷积层,在一维卷积层后还加入批归一化层。
批归一化层用于将每个batch的数据归一化到均值为0,方差为1的分布,这样可以加速模型训练、防止过拟合并提高模型精度。
依次经过三个卷积模块后的输出展开成一维向量,作为展平层的输入,再使用Tanh激活函数对全连接层的输出进行激活,使其输出范围在[-1,1]之间。
最后经过第二重塑层,将输出信号重塑为真实信号的维度,作为生成器的输出。
小批量判别器的结构也是基于卷积神经网络为主体,去除池化层,只保留了卷积层。在GAN训练过程中,存在一些常见问题,其中之一是模式崩溃。
模式崩溃是指生成器在训练过程中只生成一些固定的样本,而不是生成多样性的样本。模式崩溃可能发生在训练数据分布复杂或者训练过程中学习率过高或者过低的情况下。
为了解决模式崩溃问题,本发明采取了使用小批量判别器的方法,小批量判别器层的作用是通过引入关于样本差异的信息来解决这个问题,其大致思路如下:
计算每个样本与同一小批量中其他样本之间的相似性,并将这些信息与原始输入特征连接起来,因此生成器就能够学习产生更多样化和变化性的样本,因为判别器需要区分更多不同的样本,提高了判别器对生成样本的泛化能力,增加模型的多样性和鲁棒性。
这种方法可以使生成器更难以“欺骗”判别器,从而鼓励它生成更多的多样化样本。
小批量判别器的模型结构如图4所示,小批量判别器有两个分支,一个是小批量判别分支,另一个是深度卷积网络分支,即CNN网络分支。
小批量判别器层的实现逻辑是:将输入x与权重矩阵W相乘,得到一个表示样本相似性的张量。计算样本之间的差异,并求取绝对差异的和。使用指数函数对差异进行处理,得到小批量特征。将输入x和小批量特征连接在一起,作为最终的输出。
通过添加小批量判别器层,使得本发明中生成器在训练过程中能够学习到更多样化和变化性的样本生成,从而提升生成模型的性能和生成样本的质量。
小批量判别分支包括展平层以及小批量判别层。
其中,小批量判别层的网络结构如下:
输入层是一个2D张量,形状为(Batch_Size, input dim),其中,Batch_Size表示批量大小,input dim表示输入的通道数或特征维度;
权重矩阵W是一个3D张量,形状为(nb_kernels, input dim, kernel dim);
其中,nb_kernels表示判别器核的数量,input dim表示输入的通道数或特征维度,kernel dim表示计算样本相似性的空间的维度;
权重矩阵是通过层的build方法创建的,并在训练过程中进行更新;
前向传播逻辑是输入x与权重矩阵W进行矩阵乘法运算,得到表示样本相似性的张量;
对表示样本相似性的张量进行计算,包括计算样本之间的差异、绝对差异的和以及小批量特征,最后,将输入x和小批量特征连接在一起,并作为最终的输出;
输出层是一个2D张量,形状为(Batch_Size, input dim + nb_kernels),其中Batch_Size表示批量大小,input dim表示输入的通道数或特征维度,nb_kernels表示判别器核的数量。
小批量判别分支的输出将作为判别器的输出,用于生成对抗网络的训练过程。
如图4所示,CNN网络分支包括四个卷积模块以及展平层;其中,每个卷积模块均包括一维卷积层、激活函数层、以及Dropout层。
定义四个卷积模块依次为第一、第二、第三、第四卷积模块。
信号在CNN网络分支中的处理流程为:
首先通过第一卷积模块的一维卷积层,进行8个大小为8的卷积核进行卷积操作,步长为1;然后一维卷积层的输出输入到LeakyReLU激活函数。
为了防止过拟合,经过Dropout操作,其中丢弃率为0.25。
然后第一卷积模块的输出依次进入第二、第三、第四卷积模块重复上述操作,在经过第三卷积模块和第四卷积模块的一维卷积层时,步长变为2。
最后将第四卷积模块的Dropout层的输出输入到展平层中进行展平,然后和另一分支的小批量判别层进行特征合并,以增加判别器的多样性和稳定性。
将两个分支合并后的特征图输入到一个具有sigmoid激活函数的全连接层进行处理,输出一个值域为[0,1]的概率值,其公式为表示该输入信号为真实样本的概率。
使用二分类交叉熵作为损失函数,优化器为Adam,指定学习率和动量参数。
在GAN训练过程中,为了加快网络的收敛速度以及产生更高质量的信号,建立了自适应波形检测器。生成器生成信号后,先进入自适应波形检测器中,使用欧几里德距离度量方法检测生成信号和真实信号的波形相似度。
如果生成信号和真实信号的相似度大于动态阈值X,则才将生成信号输入到小批量判别器中判别,波形相似度小于动态阈值X的将会被丢弃并重新生成。
需要注意的是,动态阈值X的选择是一项复杂的任务,确定合适的阈值对于波形检测器的有效性非常重要。过高的阈值可能导致生成器很难满足要求,而过低的阈值可能导致过于严格,限制了生成器的学习能力,因此,确定阈值的大小是十分重要的。
在本发明中,建立了一个基于均值的方法来建立动态阈值X,具体如下:
先建立两个空列表similarities_batch和similarities_epoch,分别用于存放每个批量的波形相似度以及每轮训练的波形相似度;
每个批量的波形相似度则是该批量中所有生成信号得到的波形相似度的均值;
在第一轮训练时,由于similarities_epoch为空,此时的动态阈值X会由第一个批量波形相似度similarities_batch的均值代替,并将此均值追加到similarities_epoch;
从第二轮训练开始,动态阈值X则变为similarities_epoch的均值;
每个批量的所有生成信号波形都和similarities_epoch的均值作比较,并求此批量中similarities_batch的均值追加到similarities_epoch;
当similarities_epoch列表中的波形相似度个数大于10个时,最开始训练的得到的波形相似度的值已经不具备参考性,此时的动态阈值只求similarities_epoch中最后10个epoch的波形相似度均值。
举例说明:
在第一轮训练时,列表similarities_epoch为空。生成器生成了十条信号,第一条信号与真实信号进行波形相似度检测,得到一个相似度值为0.1%,并将此值添加到列表similarities_batch[0.1]。第二条信号与真实信号进行波形相似度检测,得到一个相似度值为0.5%,如果相似度大于列表similarities_batch的均值(即0.1/1=0.1),则输入到判别器中,并将此值添加到列表similarities_batch[0.1,0.5],第三条信号与真实信号进行波形相似度检测,得到一个相似度值为0.1%,如果相似度大于列表similarities_batch的均值(即(0.1+0.5)/2=0.3),则输入到判别器中,否则该条信号丢弃,并将此值添加到列表similarities_batch[0.1,0.5,0.1]。以此类推。在第十条时,列表similarities_batch[0.1,0.5,0.1,0.6,0.5,0.3,1.2,1.6,1.8,1.5]中则有10个相似度值。
在第二轮训练时,计算列表similarities_batch的均值(即(0.1+0.5+0.1+0.6+0.5+0.3+1.2+1.6+1.8+1.5)/10=0.82),并将此均值添加到列表similarities_epoch[0.82],然后清空列表similarities_batch。第二轮训练生成器又生成了十条信号,每条信号都与原始信号进行波形检测,并将波形相似度大于列表similarities_epoch均值(即0.82/1=0.82)的波形送入判别器判别,否则丢弃。第二轮训练完成后,列表similarities_batch又有了10个值,并重新计算similarities_batch均值添加到similarities_epoch[0.82,1.86],依次类推。在第十二轮训练时,列表similarities_epoch[0.82,1.86,1.88,1.99,2.63,2.84,2.91,2.98,3.12,3.96,4.52]中有11个值,此时,第一轮训练集的相似度值太低,不具备参考性,此时计算similarities_epoch的均值只计算后十个数值的均值(即(1.86+1.88+1.99+2.63+2.84+2.91+2.98+3.12+3.96+4.52)/10=2.869),并将生成器的波形都和该均值进行比较,直至训练完成。
步骤2. 利用训练数据集对基于GAN网络的数据增强模型进行训练,并利用训练好的基于GAN网络的数据增强模型,对输入的信号进行数据增强。
首先获取训练数据,其中训练数据的获取过程如下:
首先对数据进行预处理,采用通带频率为0.8Hz-45Hz的巴特沃斯带通滤波器对原始信号进行滤波。巴特沃斯带通滤波器在通频带内的频率响应曲线达到最大平坦,在阻频带能够快速的下降为零。
然后对数据进行Z-score标准化,以使数据的均值为0,标准差为1,公式如下:
Zdata= (Xdata-μ) /σ;
其中,Zdata是标准化后的数据;
Xdata是原始数据,μ是数据的均值,σ是数据的标准差;
最后将数据切割成10秒的固定窗口,并堆叠成阵列。每个窗口之间没有重叠,以避免训练和测试数据之间的数据重复。
本实施例中基于GAN网络的数据增强模型的训练过程如下:
初始化GAN类对象:创建GAN类的实例,传入一些参数,包括输入形状(inputshape)、随机噪声的维度(latent size)、训练轮数、批量大小等。
设置训练的总轮数和每批次的样本数量参数,循环遍历每个训练轮数。
先训练小批量判别器,从真实信号中随机选择一批信号样本。
随机噪声通过生成器生成一批生成信号,这一批生成信号先经过一个自适应波形检测器,并在自适应波形检测器中进行如下处理:
这一批生成信号中的每条生成信号先与真实信号进行欧几里德距离检测,计算出每条生成信号与真实信号的波形相似度;
将相似度低于动态阈值的生成信号丢弃掉,高于动态阈值的生成信号则进入小批量判别器中进行判别,从而训练小批量判别器,计算并记录小批量判别器的损失。
再训练生成器,并计算并记录生成器的损失。
在训练过程中小批量判别器和生成器交替训练,小批量判别器的目标是正确地区分真实信号和生成的伪造信号,生成器的目标是生成足够逼真的信号以欺骗判别器。
通过反复训练,小批量判别器和生成器的性能逐渐提高,生成器生成更逼真的信号。
在基于GAN网络的数据增强模型训练过程中,会打印出每个轮次的判别器损失、准确率和生成器损失,以及保存训练过程中的损失和准确率信息。
此外,本发明还给出了如下实验,以验证本发明所提出的数据增强方法的有效性。
1. GAN网络生成样本多样性测试。
为了验证GAN生成数据和原始数据的多样性,本实施例以人体心电信号为例,提供了原始数据与使用小批量判别器生成的数据的时域波形图对比图,如图5所示,以及原始数据与使用未使用小批量判别器生成的数据的时域波形图对比图,如图8所示。
图6展示了图5中A部的放大图,显示了生成信号的QRS波的多样性结果。图7展示了图5中B部的局部放大图,显示了生成信号的R-R间隔多样性结果。
图9展示了图8中C部的局部放大图,显示了数据增强前后信号的QRS波的多样性对比。图10展示了图8中D部的局部放大图,显示了数据增强前后信号的R-R间隔多样性对比。
由上组图对比结果可知,本发明使用人体信号为例,验证了小批量判别器生成的数据相比于原始数据波形趋势相似,幅值大小接近,稳定值波动范围接近,增强结果满足要求,而未使用小批量判别器只生成一些固定样本的数据,并不会增加数据集的多样性。
通过比较看出,本发明GAN网络能够生成与原始数据相似的增强数据,使其具备与原始测试信号更接近的特征并保持了差异化,让模型更快地收敛到最优解,帮助生成器更好地学习真实数据的分布,从而生成更加逼真的数据,增加了数据集的多样性。
2. GAN网络自适应波形相似度检测器测试。
在GAN训练过程中加入了自适应波形检测器,来提高GAN网络的收敛速度。为了验证自适应波形检测器的效果,本发明采用了inception score作为衡量GAN网络性能的指标,对比了使用了自适应波形检测器和未使用的结果,实验结果如图11所示。
实验结果表明,在未使用自适应波形检测器时,GAN网络在2000epoch时达到了纳什均衡,而本发明加入了自适应波形检测器的GAN网络后,仅仅在1000个epoch时就达到了纳什均衡,使得数据增强模型的收敛速度提升了近50%。
通过比较看出,本发明所提的自适应波形检测器能够很好的提高GAN网络的收敛速度。通过将波形相似度作为额外的训练准则,能够引导生成器朝着更接近原始信号的方向进行学习和优化,这有助于提高生成器生成信号的逼真程度和准确性。
实施例2
本实施例2述及了一种计算机设备,该计算机设备用于实现上述实施例1中的基于GAN网络的数据增强方法。
具体的,该计算机设备包括存储器和一个或多个处理器。在存储器中存储有可执行代码,当处理器执行可执行代码时,用于实现基于GAN网络的数据增强方法的步骤。
本实施例中计算机设备为任意具备数据数据处理能力的设备或装置,此处不再赘述。
实施例3
本实施例3述及了一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,用于实现上述基于GAN网络的数据增强方法的步骤。
该计算机可读存储介质可以是任意具备数据处理能力的设备或装置的内部存储单元,例如硬盘或内存,也可以是任意具备数据处理能力的设备的外部存储设备,例如设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、SD卡、闪存卡(Flash Card)等。
当然,以上说明仅仅为本发明的较佳实施例,本发明并不限于列举上述实施例,应当说明的是,任何熟悉本领域的技术人员在本说明书的教导下,所做出的所有等同替代、明显变形形式,均落在本说明书的实质范围之内,理应受到本发明的保护。
Claims (8)
1.基于GAN网络的数据增强方法,其特征在于,包括如下步骤:
步骤1.搭建基于GAN网络的数据增强模型;
搭建的数据增强模型包括生成器、小批量判别器以及自适应波形检测器;
信号在数据增强模型中的处理流程如下:
在生成器中,符合正态分布的随机噪声输入到生成器中,经过生成器生成一批生成信号,生成信号进入到自适应波形检测器中;
在自适应波形检测器中设定动态阈值,通过欧几里德距离进行波形的相似度监测;
如果生成信号与真实信号的相似度小于动态阈值,则将该条生成信号丢弃,将相似度大于或等于动态阈值的生成信号输入小批量判别器进行判别;其中真实信号为人体心电信号;
所述动态阈值基于均值的方法建立,具体如下:
先建立两个空列表similarities_batch和similarities_epoch,分别用于存放每个批量的波形相似度以及每轮训练的波形相似度;
每个批量的波形相似度则是该批量中所有生成信号得到的波形相似度的均值;
在第一轮训练时,由于similarities_epoch为空,此时的动态阈值X会由第一个批量波形相似度similarities_batch的均值代替,并将此均值追加到similarities_epoch;
具体的,在第一轮训练时,列表similarities_epoch为空;生成器生成十条信号,第一条信号与真实信号进行波形相似度检测得到一个相似度值,并将此值添加到列表similarities_batch;第二条信号与真实信号进行波形相似度检测得到一个相似度值,如果相似度大于列表similarities_batch的均值,则输入到判别器中,否则该条信号丢弃,并将此值添加到列表similarities_batch,第三条信号与真实信号进行波形相似度检测得到一个相似度值,如果相似度大于列表similarities_batch的均值,则输入到判别器中,否则该条信号丢弃,并将此值添加到列表similarities_batch;重复上述操作,则在第十条时,列表similarities_batch中则有10个相似度值;
从第二轮训练开始,动态阈值X则变为similarities_epoch的均值;
每个批量的所有生成信号波形都和similarities_epoch的均值作比较,并求此批量中similarities_batch的均值追加到similarities_epoch;
当similarities_epoch列表中的波形相似度个数大于10个时,此时的动态阈值只求similarities_epoch中最后10个epoch的波形相似度均值;
具体的,在第二轮训练时,计算列表similarities_batch的均值,并将此均值添加到列表similarities_epoch,然后清空列表similarities_batch;第二轮训练生成器又生成十条信号,每条信号都与原始信号进行波形检测,并将波形相似度大于列表similarities_epoch均值的波形送入判别器判别,否则丢弃;第二轮训练完成后,列表similarities_batch又有10个值,并重新计算similarities_batch均值添加到similarities_epoch;重复上述操作,在第n+2轮训练时,列表similarities_epoc中有n+1个值,此时计算similarities_epoch的均值只用后十个数值的均值,并将生成器的波形都和该均值进行比较,直至训练完成,n为大于或等于10的自然数;
小批量判别器分为两个分支,分别是CNN网络分支和小批量判别分支;
信号在小批量判别器中的处理流程如下:
首先,从生成信号中选择一个小批次的样本作为小批量判别器的输入;
在小批量判别分支,将输入的样本与权重矩阵W相乘,得到一个表示样本相似性的张量,利用表示样本相似性的张量计算样本之间的差异,并求取绝对差异的和,使用指数函数对差异处理,得到小批量特征,作为小批量判别分支的输出;
在CNN网络分支,对生成信号进行深度学习特征提取,并与小批量判别分支的输出进行合并,合并后在全连接层进行分类0或1,用来判别真假;
步骤2.利用训练数据集对基于GAN网络的数据增强模型进行训练,并利用训练好的基于GAN网络的数据增强模型,对输入的信号进行数据增强。
2.根据权利要求1所述的基于GAN网络的数据增强方法,其特征在于,
所述生成器包括重塑层、卷积模块、展平层、全连接层以及Tanh激活函数;
其中,重塑层有两个,分别为第一、第二重塑层;卷积模块有三个,且每个卷积模块均包括上采样层、一维卷积层、批归一化层以及激活函数;
生成器的输入是一个100维随机噪声向量,其在生成器中的处理流程如下:
首先经过第一重塑层将输入噪声向量重塑维度为(100,1),并输入到卷积模块;
在卷积模块中使用上采样层向上采样,通过线性插值将输入的时间序列长度加倍,在上采样层后加入一维卷积层,在一维卷积层后还加入批归一化层;
批归一化层用于将每个batch的数据归一化到均值为0,方差为1的分布;
依次经过三个卷积模块后的输出展开成一维向量,作为展平层的输入,再使用Tanh激活函数对全连接层的输出进行激活,使其输出范围在[-1,1]之间;
最后经过第二重塑层,将输出信号重塑为真实信号的维度,作为生成器的输出。
3.根据权利要求1所述的基于GAN网络的数据增强方法,其特征在于,
所述小批量判别分支包括展平层以及小批量判别层;
其中,小批量判别层的网络结构如下:
输入层是一个2D张量,形状为(Batch_Size,input dim),其中,Batch_Size表示批量大小,input dim表示输入的通道数或特征维度;
权重矩阵W是一个3D张量,形状为(nb_kernels,input dim,kernel dim);
其中,nb_kernels表示判别器核的数量,input dim表示输入的通道数或特征维度,kernel dim表示计算样本相似性的空间的维度;
权重矩阵是通过层的build方法创建的,并在训练过程中进行更新;
前向传播逻辑是输入x与权重矩阵W进行矩阵乘法运算,得到表示样本相似性的张量,对表示样本相似性的张量进行计算,包括计算样本之间的差异、绝对差异的和以及小批量特征;最后,将输入x和小批量特征连接在一起,并作为最终的输出;
输出层是一个2D张量,形状为(Batch_Size,input dim+nb_kernels),其中Batch_Size表示批量大小,input dim表示输入的通道数或特征维度,nb_kernels表示判别器核的数量。
4.根据权利要求1所述的基于GAN网络的数据增强方法,其特征在于,
所述CNN网络分支包括四个卷积模块以及展平层;其中,每个卷积模块均包括一维卷积层、激活函数层、以及Dropout层;
定义四个卷积模块依次为第一、第二、第三、第四卷积模块;
信号在CNN网络分支中的处理流程为:
首先通过第一卷积模块的一维卷积层,进行8个大小为8的卷积核进行卷积操作,步长为1;然后一维卷积层的输出输入到LeakyReLU激活函数,并经过Dropout操作;
然后第一卷积模块的输出依次进入第二、第三、第四卷积模块重复上述操作,在经过第三卷积模块和第四卷积模块的一维卷积层时,步长变为2;
最后将第四卷积模块的Dropout层的输出输入到展平层中进行展平操作。
5.根据权利要求1所述的基于GAN网络的数据增强方法,其特征在于,
所述步骤2中,训练数据的获取过程如下:
首先对数据进行预处理,采用通带频率为0.8Hz-45Hz的巴特沃斯带通滤波器对原始信号进行滤波处理;
然后对滤波数据进行Z-score标准化,以使数据的均值为0,标准差为1,公式如下:
Zdata=(Xdata-μ)/σ;
其中,Zdata是标准化后的数据;
Xdata是原始数据,μ是数据的均值,σ是数据的标准差;
最后将数据切割成10秒的固定窗口,并堆叠成阵列。
6.根据权利要求1所述的基于GAN网络的数据增强方法,其特征在于,
所述步骤2中,基于GAN网络的数据增强模型的训练过程如下:
设置训练的总轮数和每批次的样本数量参数,循环遍历每个训练轮数;
先训练小批量判别器,从真实信号中随机选择一批信号样本;
随机噪声通过生成器生成一批生成信号,这一批生成信号先经过一个自适应波形检测器,并在自适应波形检测器中进行如下处理:
这一批生成信号中的每条生成信号先与真实信号进行欧几里德距离检测,计算出每条生成信号与真实信号的波形相似度;
将相似度低于动态阈值的生成信号丢弃掉,高于动态阈值的生成信号则进入小批量判别器中进行判别,从而训练小批量判别器,计算并记录小批量判别器的损失;
再训练生成器,并计算并记录生成器的损失;
在训练过程中小批量判别器和生成器交替训练,小批量判别器的目标是正确地区分真实信号和生成的伪造信号,生成器的目标是生成足够逼真的信号以欺骗判别器;
通过反复训练,小批量判别器和生成器的性能逐渐提高,生成器生成更逼真的信号。
7.一种计算机设备,包括存储器和一个或多个处理器;所述存储器中存储有可执行代码,其特征在于,所述处理器执行所述可执行代码时,用于实现上述权利要求1至6任一项所述的基于GAN网络的数据增强方法的步骤。
8.一种计算机可读存储介质,在计算机可读存储介质上存储有程序;其特征在于,该程序被处理器执行时,用于实现上述权利要求1至6任一项所述的基于GAN网络的数据增强方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310942682.7A CN116663619B (zh) | 2023-07-31 | 2023-07-31 | 基于gan网络的数据增强方法、设备以及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310942682.7A CN116663619B (zh) | 2023-07-31 | 2023-07-31 | 基于gan网络的数据增强方法、设备以及介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116663619A CN116663619A (zh) | 2023-08-29 |
CN116663619B true CN116663619B (zh) | 2023-10-13 |
Family
ID=87721017
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310942682.7A Active CN116663619B (zh) | 2023-07-31 | 2023-07-31 | 基于gan网络的数据增强方法、设备以及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116663619B (zh) |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107392147A (zh) * | 2017-07-20 | 2017-11-24 | 北京工商大学 | 一种基于改进的生成式对抗网络的图像语句转换方法 |
CN109242000A (zh) * | 2018-08-09 | 2019-01-18 | 百度在线网络技术(北京)有限公司 | 图像处理方法、装置、设备及计算机可读存储介质 |
CN111986142A (zh) * | 2020-05-23 | 2020-11-24 | 冶金自动化研究设计院 | 一种热轧板卷表面缺陷图像数据无监督增强的方法 |
CN112529806A (zh) * | 2020-12-15 | 2021-03-19 | 哈尔滨工程大学 | 基于生成对抗网络信息最大化的sar图像数据增强方法 |
KR20210066730A (ko) * | 2019-11-28 | 2021-06-07 | 연세대학교 산학협력단 | 클라우드 플랫폼 서비스 기반 데이터 증강을 통한 건전성 예측 관리 모델 설계 방법 및 시스템 |
CN113052273A (zh) * | 2021-06-01 | 2021-06-29 | 之江实验室 | 基于像素组合约束和采样校正的gan图像生成方法 |
CN113962360A (zh) * | 2021-10-09 | 2022-01-21 | 西安交通大学 | 一种基于gan网络的样本数据增强方法及系统 |
CN114469120A (zh) * | 2022-01-12 | 2022-05-13 | 大连海事大学 | 一种基于相似度阈值迁移的多尺度Dtw-BiLstm-Gan心电信号生成方法 |
CN115290596A (zh) * | 2022-08-03 | 2022-11-04 | 广东工业大学 | 一种基于fcn-acgan数据增强的隐匿危险品识别方法及设备 |
CN115439323A (zh) * | 2022-08-30 | 2022-12-06 | 湖州师范学院 | 一种基于渐进式增长条件生成对抗网络的图像生成方法 |
CN115860113A (zh) * | 2023-03-03 | 2023-03-28 | 深圳精智达技术股份有限公司 | 一种自对抗神经网络模型的训练方法及相关装置 |
CN116484184A (zh) * | 2023-05-29 | 2023-07-25 | 广东电网有限责任公司广州供电局 | 一种电力设备局部放电缺陷样本增强方法及装置 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11024009B2 (en) * | 2016-09-15 | 2021-06-01 | Twitter, Inc. | Super resolution using a generative adversarial network |
US11501438B2 (en) * | 2018-04-26 | 2022-11-15 | Elekta, Inc. | Cone-beam CT image enhancement using generative adversarial networks |
-
2023
- 2023-07-31 CN CN202310942682.7A patent/CN116663619B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107392147A (zh) * | 2017-07-20 | 2017-11-24 | 北京工商大学 | 一种基于改进的生成式对抗网络的图像语句转换方法 |
CN109242000A (zh) * | 2018-08-09 | 2019-01-18 | 百度在线网络技术(北京)有限公司 | 图像处理方法、装置、设备及计算机可读存储介质 |
KR20210066730A (ko) * | 2019-11-28 | 2021-06-07 | 연세대학교 산학협력단 | 클라우드 플랫폼 서비스 기반 데이터 증강을 통한 건전성 예측 관리 모델 설계 방법 및 시스템 |
CN111986142A (zh) * | 2020-05-23 | 2020-11-24 | 冶金自动化研究设计院 | 一种热轧板卷表面缺陷图像数据无监督增强的方法 |
CN112529806A (zh) * | 2020-12-15 | 2021-03-19 | 哈尔滨工程大学 | 基于生成对抗网络信息最大化的sar图像数据增强方法 |
CN113052273A (zh) * | 2021-06-01 | 2021-06-29 | 之江实验室 | 基于像素组合约束和采样校正的gan图像生成方法 |
CN113962360A (zh) * | 2021-10-09 | 2022-01-21 | 西安交通大学 | 一种基于gan网络的样本数据增强方法及系统 |
CN114469120A (zh) * | 2022-01-12 | 2022-05-13 | 大连海事大学 | 一种基于相似度阈值迁移的多尺度Dtw-BiLstm-Gan心电信号生成方法 |
CN115290596A (zh) * | 2022-08-03 | 2022-11-04 | 广东工业大学 | 一种基于fcn-acgan数据增强的隐匿危险品识别方法及设备 |
CN115439323A (zh) * | 2022-08-30 | 2022-12-06 | 湖州师范学院 | 一种基于渐进式增长条件生成对抗网络的图像生成方法 |
CN115860113A (zh) * | 2023-03-03 | 2023-03-28 | 深圳精智达技术股份有限公司 | 一种自对抗神经网络模型的训练方法及相关装置 |
CN116484184A (zh) * | 2023-05-29 | 2023-07-25 | 广东电网有限责任公司广州供电局 | 一种电力设备局部放电缺陷样本增强方法及装置 |
Non-Patent Citations (3)
Title |
---|
M. H. -M. Khan et al.Investigating on Data Augmentation and Generative Adversarial Networks (GAN s) for Diabetic Retinopathy.《2022 International Conference on Electrical, Computer, Communications and Mechatronics Engineering》.2022,1-5. * |
于贺等.基于多尺寸卷积与残差单元的快速收敛GAN胸部X射线图像数据增强.《信号处理》.2019,第35卷(第12期),2045-2054. * |
邵海东等.基于改进ACGAN的齿轮箱多模式 数据增强与故障诊断.《交通运输工程学报》.2023,第23卷(第3期),188-197. * |
Also Published As
Publication number | Publication date |
---|---|
CN116663619A (zh) | 2023-08-29 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110084173B (zh) | 人头检测方法及装置 | |
CN109948647B (zh) | 一种基于深度残差网络的心电图分类方法及系统 | |
CN108108662B (zh) | 深度神经网络识别模型及识别方法 | |
Nejad et al. | A new enhanced learning approach to automatic image classification based on Salp Swarm Algorithm | |
CN111291727B (zh) | 一种光体积变化描记图法信号质量检测方法和装置 | |
CN112783327B (zh) | 基于表面肌电信号进行手势识别的方法及系统 | |
CN113989890A (zh) | 基于多通道融合和轻量级神经网络的人脸表情识别方法 | |
Kulasingham et al. | Deep belief networks and stacked autoencoders for the p300 guilty knowledge test | |
CN113143295A (zh) | 基于运动想象脑电信号的设备控制方法及终端 | |
CN115238835A (zh) | 基于双空间自适应融合的脑电情感识别方法、介质及设备 | |
He et al. | What catches the eye? Visualizing and understanding deep saliency models | |
Lakshmi et al. | Automated detection and segmentation of brain tumor using genetic algorithm | |
CN113486752A (zh) | 基于心电信号的情感识别方法及系统 | |
CN113133769A (zh) | 基于运动想象脑电信号的设备控制方法、装置及终端 | |
Li et al. | Study on the detection of pulmonary nodules in CT images based on deep learning | |
Asghar et al. | Semi-skipping layered gated unit and efficient network: hybrid deep feature selection method for edge computing in EEG-based emotion classification | |
Seeböck | Deep learning in medical image analysis | |
Amiri et al. | Improved sparse coding under the influence of perceptual attention | |
CN116663619B (zh) | 基于gan网络的数据增强方法、设备以及介质 | |
CN113421546A (zh) | 基于跨被试多模态的语音合成方法及相关设备 | |
Moftah et al. | Brain Diagnoses Detection Using Whale Optimization Algorithm Based on Ensemble Learning Classifier. | |
Wang et al. | A modified sparse representation method for facial expression recognition | |
CN114098691A (zh) | 基于混合高斯模型的脉搏波身份认证方法、装置和介质 | |
Paidja et al. | Engagement emotion classification through facial landmark using convolutional neural network | |
Iffath et al. | A Novel Three Stage Framework for Person Identification From Audio Aesthetic |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |