CN112232395A - 一种基于联合训练生成对抗网络的半监督图像分类方法 - Google Patents
一种基于联合训练生成对抗网络的半监督图像分类方法 Download PDFInfo
- Publication number
- CN112232395A CN112232395A CN202011068394.6A CN202011068394A CN112232395A CN 112232395 A CN112232395 A CN 112232395A CN 202011068394 A CN202011068394 A CN 202011068394A CN 112232395 A CN112232395 A CN 112232395A
- Authority
- CN
- China
- Prior art keywords
- discriminator
- label
- data
- training
- unlabeled
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 70
- 238000000034 method Methods 0.000 title claims abstract description 36
- 238000012360 testing method Methods 0.000 claims abstract description 13
- 238000005070 sampling Methods 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 claims description 3
- 230000004913 activation Effects 0.000 description 3
- 238000013461 design Methods 0.000 description 3
- 230000006870 function Effects 0.000 description 3
- 238000012544 monitoring process Methods 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000003042 antagnostic effect Effects 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000011426 transformation method Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于联合训练生成对抗网络的半监督图像分类方法,包括以下步骤:步骤一、设置生成对抗网络;步骤二、划分标签数据集L和无标签数据集U;步骤三、训练生成器G;步骤四、训练判别器D1和判别器D2,迭代更新扩充了标签子样本集;步骤五、得到训练好的生成对抗网络;步骤六、利用训练好的生成对抗网络对测试集进行分类。本发明采用判别器D1和判别器D2联合训练,减小单个判别器存在的分布误差对生成对抗网络的影响;基于联合训练的生成对抗网络能够减小生成对抗网络对标签数据的依赖,利用无标签数据在训练时扩充标签数据集,加快网络收敛,提高生成对抗网络的分类准确率,从而进一步提高在小样本条件下网络图像分类的精度。
Description
技术领域
本发明属于图像处理技术领域,具体涉及一种基于联合训练生成对抗网络的半监督图像分类方法。
背景技术
作为计算机视觉领域最常见的任务之一,图像分类通过提取原始图像的特征并根据特征进行分类。传统的特征提取主要是通过对图像的颜色、纹理、局部特征等几个方面进行分析处理实现的,例如尺度不变特征变换法,方向梯度法以及局部二值法等。但是这些特征都是人为设计的特征,很大程度上靠人类对识别目标的先验知识进行设计,具有一定的局限性。随着大数据时代的到来,基于深度学习的图像分类方法具有对大量复杂数据进行处理和表征的能力,能够有效学习目标的特征信息,从而大大提高图像分类的精度。
深度学习以数据驱动方式进行训练学习,对标签数据依赖性强,而实际应用中往往难以获取大量的标签数据。当样本数量不足时,深度网络模型容易过拟合,导致分类性能较差。生成对抗网络,也称GAN网络,是由Goodfellow等在2014年提出的,由一个生成器和一个判别器构成。生成器根据输入数据分布来生成尽可能逼真的伪数据,判别器用于判断输入数据是真实数据还是生成器生成的伪数据。在训练期间,生成器不断尝试通过产生越来越好的假图片来超越判别器,与此同时判别器逐渐更好的检测并正确分类真假图片,生成器和判别器经过博弈对抗达到纳什均衡,此时生成的数据能够拟合真实的数据分布。GAN网络在训练时既能够生成样本,又能够提高特征提取能力,可以用来解决数据样本少的问题。但GAN网络还存在稳定性差和依赖标签数据的问题,不能直接应用于分类任务中。
针对GAN网络稳定性差的问题,目前已经有多种方法通过改进GAN网络结构或优化算法来解决。但是目前针对依赖标签数据的问题,并没有有效的分类方法,因此亟需一种在一定程度上减小网络对标签数据的依赖、且能提高网络分类准确率的改进的GAN网络。
发明内容
本发明所要解决的技术问题在于针对上述现有技术中的不足,提供一种基于联合训练生成对抗网络的半监督图像分类方法,其结构简单、设计合理,采用判别器D1和判别器D2联合训练,以减小单个判别器误差对生成对抗网络的影响;利用大量无标签数据和少量标签数据进行联合训练,能够学习到泛化能力较强的模型,在一定程度上减小生成对抗网络对标签数据的依赖,利用无标签数据在训练时扩充标签数据集,加快网络收敛,提高生成对抗网络的分类准确率。
为解决上述技术问题,本发明采用的技术方案是:一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:包括以下步骤:
步骤一、设置生成对抗网络,包括生成器G、判别器D1和判别器D2,设置生成对抗网络的训练初始参数;
步骤二、获取训练集和测试集,训练集包括标签数据集L和无标签数据集U,将标签数据集L打乱随机分为标签子样本集L1和标签子样本集L2,其中,标签子样本集L1、L2包括k类标签数据;将无标签数据集U打乱随机分为无标签子样本集U1和无标签子样本集U2,其中,无标签子样本集U1包括g个无标签数据,无标签子样本集U2包括r个无标签数据;
步骤三、训练生成器G:
步骤301、将随机高斯噪声z输入生成器G生成伪数据G(z);
步骤302、将伪数据G(z)输入到判别器D1,判别器D1对伪数据G(z)进行判别得到D1(G(z));
步骤303、将伪数据G(z)输入到判别器D2,判别器D2对伪数据G(z)进行判别得到D2(G(z));
步骤304、计算生成器G的损失minLG;
步骤305、更新生成器G的训练参数;
步骤四、训练判别器D1和判别器D2:
步骤401、将标签子样本集L1输入到判别器D1,判别器D1输出k+1维分类预测概率{l11,...l1i,...l1k,l1(k+1)},其中l11至l1k表示标签子样本集L1中k类标签数据的置信度,l1(k+1)表示伪数据G(z)由判别器D1判定为“伪”的置信度;
步骤402、将无标签子样本集U1中的第n个无标签数据输入到判别器D1,判别器D1针对第n个无标签数据输出k+1维分类预测概率{h11-n,...h1i-n,...h1k-n,h1(k+1)-n},若MAX{h11-n,...h1j-n,...h1g-n}>η,则将无标签子样本集U1中第n个无标签数据加入标签子样本集L2中MAX{h11-n,...h1j-n,...h1g-n}所对应的标签类别,η表示置信度阈值,1≤n≤g;
步骤403、将标签子样本集L2输入到判别器D2,判别器D2输出k+1维分类预测概率{l21,...l2i,...l2k,l2(k+1)},其中l21至l2k表示标签子样本集L2中k类标签数据的置信度,l2(k+1)表示伪数据G(z)由判别器D2判定为“伪”的置信度;
步骤404、将无标签子样本集U2中的第m个无标签数据输入到判别器D2,判别器D2针对第m个无标签数据输出k+1维分类预测概率{h21-m,...h2i-m,...h2k-m,h2(k+1)-m},若MAX{h21-m,...h2j-m,...h2g-m}>η,则将无标签子样本集U2中的第m个无标签数据加入标签子样本集L1中MAX{h21-m,...h2j-m,...h2g-m}所对应的标签类别,η表示置信度阈值,1≤m≤r;
步骤405、计算判别器总损失maxLD;
步骤406、更新判别器D1和判别器D2的训练参数;
步骤五、迭代更新:
步骤501、若判别器损失maxLD收敛,结束迭代,得到训练好的生成对抗网络,否则进入步骤502;
步骤502、迭代执行步骤二到步骤五,每次迭代后,迭代次数加1,直到迭代次数等于最大迭代次数,迭代结束。
步骤六、利用测试集对生成对抗网络进行测试,生成对抗网络输出对测试集的分类结果,获得生成对抗网络的分类精度。
上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:步骤403中判别器总损失的计算公式为其中表示判别器监督损失,其中yi表示标签数据集L中第i维数据的标签,Du(xi)表示判别器Du判别标签数据的标签为第i维的概率,maxLunsupD表示判别器无监督损失,y′i表示判别器前一次迭代时判别无标签数据的类别为第i维。
上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述生成器G的网络结构依次为:输入层→全连接层→上采样层→卷积层Conv1→上采样层→卷积层Conv2→卷积层Conv3。
上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述判别器D1和判别器D2的网络结构相同,判别器D1的网络结构依次为:输入层→卷积层Conv1→卷积层Conv2→卷积层Conv3。
本发明与现有技术相比具有以下优点:
1、本发明的结构简单、设计合理,实现及使用操作方便。
2、本发明的基于联合训练生成对抗网络中,采用判别器D1和判别器D2进行联合训练,判别器的总损失为判别器D1损失和判别器D2损失的均值,以消除单个判别器存在的分布误差,从而以减小单个判别器误差对生成对抗网络的影响,提高判别器训练的稳定性。
3、本发明设置置信度阈值η,对每次迭代得到的无标签样本集的分类结果进行置信度判断,如果大于该置信度阈值,则将该标签数据加入到标签样本集中继续迭代训练,利用无标签样本集扩充标签样本集,从而加快生成对抗网络收敛,提高图像分类效率。
综上所述,本发明结构简单、设计合理,采用判别器D1和判别器D2联合训练,以减小单个判别器误差对生成对抗网络的影响;利用大量无标签数据和少量标签数据进行联合训练,能够学习到泛化能力较强的模型,在一定程度上减小生成对抗网络对标签数据的依赖,利用无标签数据在训练时扩充标签数据集,加快网络收敛,提高生成对抗网络的分类准确率。
下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。
附图说明
图1为本发明的方法流程图。
图2为本发明生成器的结构示意图。
图3为本发明判别器的结构示意图。
具体实施方式
下面结合附图及本发明的实施例对本发明的方法作进一步详细的说明。
需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本发明。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本申请的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,当在本说明书中使用术语“包含”和/或“包括”时,其指明存在特征、步骤、操作、器件、组件和/或它们的组合。
需要说明的是,本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施方式例如能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
为了便于描述,在这里可以使用空间相对术语,如“在……之上”、“在……上方”、“在……上表面”、“上面的”等,用来描述如在图中所示的一个器件或特征与其他器件或特征的空间位置关系。应当理解的是,空间相对术语旨在包含除了器件在图中所描述的方位之外的在使用或操作中的不同方位。例如,如果附图中的器件被倒置,则描述为“在其他器件或构造上方”或“在其他器件或构造之上”的器件之后将被定位为“在其他器件或构造下方”或“在其他器件或构造之下”。因而,示例性术语“在……上方”可以包括“在……上方”和“在……下方”两种方位。该器件也可以其他不同方式定位(旋转90度或处于其他方位),并且对这里所使用的空间相对描述作出相应解释。
如图1所示,本发明的一种基于联合训练生成对抗网络的半监督图像分类方法,包括以下步骤:
步骤一、设置生成对抗网络,包括生成器G、判别器D1和判别器D2,设置生成对抗网络的训练初始参数。
在本申请基于联合训练生成对抗网络中,采用了判别器D1和判别器D2进行联合训练,以减小单个判别器误差对生成对抗网络的影响。判别器D1和判别器D2共享同一个生成器G,同时判别器D1和判别器D2的网络结构和训练初始参数设为相同。
步骤二、获取训练集和测试集,训练集包括标签数据集L和无标签数据集U,将标签数据集L打乱随机分为标签子样本集L1和标签子样本集L2,其中,标签子样本集L1、L2包括k类标签数据;将无标签数据集U打乱随机分为无标签子样本集U1和无标签子样本集U2,其中,无标签子样本集U1包括g个无标签数据,无标签子样本集U2包括r个无标签数据。
需要说明的是,将标签数据集L和无标签数据集U的顺序打乱随机分为两个子集,然后分别输入到判别器D1和判别器D2中,可以保证训练过程中,判别器D1和判别器D2是动态变化的。
步骤三、训练生成器G:
步骤301、将随机高斯噪声z输入生成器G生成伪数据G(z)。基于联合训练生成对抗网络的生成器G框架如图2所示,需要说明的是,所述生成器G的网络结构依次为:输入层→全连接层→上采样层→卷积层Conv1→上采样层→卷积层Conv2→卷积层Conv3。
具体实施时,生成器G的输入为(128,100)的随机噪声,首先通过(100,8192)的全连接层得到(128,8192)的张量,经过维度转换得到维度为(128,128,8,8)的图像,经过两次上采样操作和三次步长为1的3×3卷积核的卷积操作后得到维度为(128,3,32,32)的图像,其中每次完成卷积操作后都是用归一化操作加入RELU激活函数,最后一层通过Tanh激活函数输出伪数据G(z)。
步骤302、将伪数据G(z)输入到判别器D1,判别器D1对伪数据G(z)进行判别得到D1(G(z));
步骤303、将伪数据G(z)输入到判别器D2,判别器D2对伪数据G(z)进行判别得到D2(G(z));
步骤304、计算生成器G的损失minLG:原始生成对抗网络中生成器的损失表示为为了让生成器生成的数据分布更接近真实数据的统计分布,采用特征匹配的方法对生成器的损失进行约束,定义特征匹配损失为:其中fu(·)表示判别器Du中间层的特征值,u=1、2。因此生成器G的损失minLG的计算公式为:
步骤305、更新生成器G的训练参数。
步骤四、训练判别器D1和判别器D2:
步骤401、将标签子样本集L1输入到判别器D1,判别器D1输出k+1维分类结果{l11,...l1i,...l1k,l1(k+1)},其中l11至l1k表示标签子样本集L1中k类标签数据的置信度,l1(k+1)表示伪数据G(z)由判别器D1判定为“伪”的置信度;
步骤402、将无标签子样本集U1中的第n个无标签数据输入到判别器D1,判别器D1针对第n个无标签数据输出k+1维分类预测概率{h11-n,...h1i-n,...h1k-n,h1(k+1)-n},若MAX{h11-n,...h1j-n,...h1g-n}>η,则将无标签子样本集U1中第n个无标签数据加入标签子样本集L2中MAX{h11-n,...h1j-n,...h1g-n}所对应的标签类别,η表示置信度阈值,1≤n≤g。
具体实施时,如图3所示,判别器D1和判别器D2的网络结构相同,判别器D1的网络结构依次为:输入层→卷积层Conv1→卷积层Conv2→卷积层Conv3→全连接层→softmax分类器。
判别器D1的输入为大小为32×32的3通道RGB彩色图像,其维度为(128,3,32,32),经过四次步长为2的3×3的卷积核的卷积操作,最终输出图像维度为(128,128,2,2),其中每次完成卷积操作后都加入LeakyReLU激活函数和Dropout操作以防止过拟合,而除了首次卷积不使用归一化外,其余卷积操作后都是用归一化。
设置置信度阈值η,对每次迭代得到的无标签子样本集U1的分类结果进行置信度判断,如果大于该置信度阈值η,则将该标签数据加入到标签子样本集L2中继续迭代训练,利用无标签子样本集U1扩充标签子样本集L2,从而加快生成对抗网络收敛。
步骤403、将标签子样本集L2输入到判别器D2,判别器D2输出k+1维分类预测概率{l21,...l2i,...l2k,l2(k+1)},其中l21至l2k表示标签子样本集L2中k类标签数据的置信度,l2(k+1)表示伪数据G(z)由判别器D2判定为“伪”的置信度;
步骤404、将无标签子样本集U2中的第m个无标签数据输入到判别器D2,判别器D2针对第m个无标签数据输出k+1维分类预测概率{h21-m,...h2i-m,...h2k-m,h2(k+1)-m},若MAX{h21-m,...h2j-m,...h2g-m}>η,则将无标签子样本集U2中的第m个无标签数据加入标签子样本集L1中MAX{h21-m,...h2j-m,...h2g-m}所对应的标签类别,η表示置信度阈值,1≤m≤r;
同理,设置置信度阈值η,步骤404中的置信度阈值η与步骤402中的置信度阈值η相同。对每次迭代得到的无标签子样本集U2的分类结果进行置信度判断,如果大于该置信度阈值η,则将该标签数据加入到标签子样本集L1中继续迭代训练,利用无标签子样本集U2扩充标签子样本集L1,从而加快生成对抗网络收敛。
步骤405、判别器总损失的计算公式为其中表示判别器监督损失,对于判别器的监督损失,需要加入标签信息,因此监督损失以交叉熵的形式定义为,其中yi表示标签数据集L中第i维数据的标签,Du(xi)表示判别器Du判别标签数据的标签为第i维的概率。表示判别器无监督损失,基于联合训练的生成对抗网络需要判别无标签数据的类别标签,以此判别器的无监督损失既判断真伪,也判断类别概率,所以无监督损失由两部分组成,考虑到两个判别器联合训练的情况,无监督损失定义为:y′i表示判别器前一次迭代时判别无标签数据的类别为第i维。
需要说明的是,判别器的总损失maxLD为判别器D1损失和判别器D2损失的均值,以消除单个判别器存在的分布误差。
步骤406、更新判别器D1和判别器D2的训练参数。需要说明的是,所述判别器D1和判别器D2的初始训练参数相同及网络结构相同,在训练过程中动态变化,判别器D1和判别器D2参数共享。
本申请通过判别器D1和判别器D2的联合训练,一方面可以消除单个判别器存在的分布误差,提高判别器训练的稳定性;另一方面,利用无标签数据在训练时扩充标签数据集L,能够加快网络收敛。因此,本申请基于联合训练的生成对抗网络模型能够充分利用少量标签数据的标签信息和大量无标签数据的分布信息来获取整个样本的特征分布,迭代更新扩充了标签子样本集,从而进一步提高在小样本条件下网络图像分类的精度。
步骤五、迭代更新:
步骤501、若判别器损失maxLD收敛,结束迭代,得到训练好的生成对抗网络,否则进入步骤502;
步骤502、迭代执行步骤二到步骤五,每次迭代后,迭代次数加1,直到迭代次数等于最大迭代次数,迭代结束,得到训练好的生成对抗网络。
步骤六、利用测试集对训练好的生成对抗网络进行测试,生成对抗网络输出对测试集的分类结果,获得生成对抗网络的分类精度。
以上所述,仅是本发明的实施例,并非对本发明作任何限制,凡是根据本发明技术实质对以上实施例所作的任何简单修改、变更以及等效结构变化,均仍属于本发明技术方案的保护范围内。
Claims (5)
1.一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:包括以下步骤:
步骤一、设置生成对抗网络,包括生成器G、判别器D1和判别器D2,设置生成对抗网络的训练初始参数;
步骤二、获取训练集和测试集,训练集包括标签数据集L和无标签数据集U,将标签数据集L打乱随机分为标签子样本集L1和标签子样本集L2,其中,标签子样本集L1、L2包括k类标签数据;将无标签数据集U打乱随机分为无标签子样本集U1和无标签子样本集U2,其中,无标签子样本集U1包括g个无标签数据,无标签子样本集U2包括r个无标签数据;
步骤三、训练生成器G:
步骤301、将随机高斯噪声z输入生成器G生成伪数据G(z);
步骤302、将伪数据G(z)输入到判别器D1,判别器D1对伪数据G(z)进行判别得到D1(G(z));
步骤303、将伪数据G(z)输入到判别器D2,判别器D2对伪数据G(z)进行判别得到D2(G(z));
步骤304、计算生成器G的损失minLG;
步骤305、更新生成器G的训练参数;
步骤四、训练判别器D1和判别器D2:
步骤401、将标签子样本集L1输入到判别器D1,判别器D1输出k+1维分类预测概率{l11,...l1i,...l1k,l1(k+1)},其中l11至l1k表示标签子样本集L1中k类标签数据的置信度,l1(k+1)表示伪数据G(z)由判别器D1判定为“伪”的置信度;
步骤402、将无标签子样本集U1中的第n个无标签数据输入到判别器D1,判别器D1针对第n个无标签数据输出k+1维分类预测概率{h11-n,...h1i-n,...h1k-n,h1(k+1)-n},若MAX{h11-n,...h1j-n,...h1g-n}>η,则将无标签子样本集U1中第n个无标签数据加入标签子样本集L2中MAX{h11-n,...h1j-n,...h1g-n}所对应的标签类别,η表示置信度阈值,1≤n≤g;
步骤403、将标签子样本集L2输入到判别器D2,判别器D2输出k+1维分类预测概率{l21,...l2i,...l2k,l2(k+1)},其中l21至l2k表示标签子样本集L2中k类标签数据的置信度,l2(k+1)表示伪数据G(z)由判别器D2判定为“伪”的置信度;
步骤404、将无标签子样本集U2中的第m个无标签数据输入到判别器D2,判别器D2针对第m个无标签数据输出k+1维分类预测概率{h21-m,...h2i-m,...h2k-m,h2(k+1)-m},若MAX{h21-m,...h2j-m,...h2g-m}>η,则将无标签子样本集U2中的第m个无标签数据加入标签子样本集L1中MAX{h21-m,...h2j-m,...h2g-m}所对应的标签类别,η表示置信度阈值,1≤m≤r;
步骤405、计算判别器总损失maxLD;
步骤406、更新判别器D1和判别器D2的训练参数;
步骤五、迭代更新:
步骤501、若判别器损失maxLD收敛,结束迭代,得到训练好的生成对抗网络,否则进入步骤502;
步骤502、迭代执行步骤二到步骤五,每次迭代后,迭代次数加1,直到迭代次数等于最大迭代次数,迭代结束,得到训练好的生成对抗网络;
步骤六、利用测试集对训练好的生成对抗网络进行测试,生成对抗网络输出对测试集的分类结果,获得生成对抗网络的分类精度。
4.按照权利要求1所述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述生成器G的网络结构依次为:输入层→全连接层→上采样层→卷积层Conv1→上采样层→卷积层Conv2→卷积层Conv3。
5.按照权利要求1所述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述判别器D1和判别器D2的网络结构相同,判别器D1的网络结构依次为:输入层→卷积层Conv1→卷积层Conv2→卷积层Conv3→全连接层→softmax分类器。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011068394.6A CN112232395B (zh) | 2020-10-08 | 2020-10-08 | 一种基于联合训练生成对抗网络的半监督图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011068394.6A CN112232395B (zh) | 2020-10-08 | 2020-10-08 | 一种基于联合训练生成对抗网络的半监督图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112232395A true CN112232395A (zh) | 2021-01-15 |
CN112232395B CN112232395B (zh) | 2023-10-27 |
Family
ID=74120955
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011068394.6A Active CN112232395B (zh) | 2020-10-08 | 2020-10-08 | 一种基于联合训练生成对抗网络的半监督图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112232395B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113537031A (zh) * | 2021-07-12 | 2021-10-22 | 电子科技大学 | 基于多鉴别器条件生成对抗网络的雷达图像目标识别方法 |
CN113688953A (zh) * | 2021-10-25 | 2021-11-23 | 深圳市永达电子信息股份有限公司 | 基于多层gan网络的工控信号分类方法、装置和介质 |
CN114898159A (zh) * | 2022-06-01 | 2022-08-12 | 西北工业大学 | 基于解耦表征生成对抗网络的sar图像可解释性特征提取方法 |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
JP2016212772A (ja) * | 2015-05-13 | 2016-12-15 | 株式会社国際電気通信基礎技術研究所 | 推定システム、推定方法、推定装置 |
CN106843195A (zh) * | 2017-01-25 | 2017-06-13 | 浙江大学 | 基于自适应集成半监督费舍尔判别的故障分类方法 |
CN108460717A (zh) * | 2018-03-14 | 2018-08-28 | 儒安科技有限公司 | 一种基于双判别器的生成对抗网络的图像生成方法 |
CN108564039A (zh) * | 2018-04-16 | 2018-09-21 | 北京工业大学 | 一种基于半监督深层生成对抗网络的癫痫发作预测方法 |
CN109753992A (zh) * | 2018-12-10 | 2019-05-14 | 南京师范大学 | 基于条件生成对抗网络的无监督域适应图像分类方法 |
CN109977094A (zh) * | 2019-01-30 | 2019-07-05 | 中南大学 | 一种用于结构化数据的半监督学习的方法 |
CN110320162A (zh) * | 2019-05-20 | 2019-10-11 | 广东省智能制造研究所 | 一种基于生成对抗网络的半监督高光谱数据定量分析方法 |
CN110617966A (zh) * | 2019-09-23 | 2019-12-27 | 江南大学 | 一种基于半监督生成对抗网络的轴承故障诊断方法 |
CN110689086A (zh) * | 2019-10-08 | 2020-01-14 | 郑州轻工业学院 | 基于生成式对抗网络的半监督高分遥感图像场景分类方法 |
CN111028146A (zh) * | 2019-11-06 | 2020-04-17 | 武汉理工大学 | 基于双判别器的生成对抗网络的图像超分辨率方法 |
CN111260584A (zh) * | 2020-01-17 | 2020-06-09 | 北京工业大学 | 基于gan网络的水下退化图像增强的方法 |
CN111626317A (zh) * | 2019-08-14 | 2020-09-04 | 广东省智能制造研究所 | 基于双流条件对抗生成网络的半监督高光谱数据分析方法 |
-
2020
- 2020-10-08 CN CN202011068394.6A patent/CN112232395B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
JP2016212772A (ja) * | 2015-05-13 | 2016-12-15 | 株式会社国際電気通信基礎技術研究所 | 推定システム、推定方法、推定装置 |
CN106843195A (zh) * | 2017-01-25 | 2017-06-13 | 浙江大学 | 基于自适应集成半监督费舍尔判别的故障分类方法 |
CN108460717A (zh) * | 2018-03-14 | 2018-08-28 | 儒安科技有限公司 | 一种基于双判别器的生成对抗网络的图像生成方法 |
CN108564039A (zh) * | 2018-04-16 | 2018-09-21 | 北京工业大学 | 一种基于半监督深层生成对抗网络的癫痫发作预测方法 |
CN109753992A (zh) * | 2018-12-10 | 2019-05-14 | 南京师范大学 | 基于条件生成对抗网络的无监督域适应图像分类方法 |
CN109977094A (zh) * | 2019-01-30 | 2019-07-05 | 中南大学 | 一种用于结构化数据的半监督学习的方法 |
CN110320162A (zh) * | 2019-05-20 | 2019-10-11 | 广东省智能制造研究所 | 一种基于生成对抗网络的半监督高光谱数据定量分析方法 |
CN111626317A (zh) * | 2019-08-14 | 2020-09-04 | 广东省智能制造研究所 | 基于双流条件对抗生成网络的半监督高光谱数据分析方法 |
CN110617966A (zh) * | 2019-09-23 | 2019-12-27 | 江南大学 | 一种基于半监督生成对抗网络的轴承故障诊断方法 |
CN110689086A (zh) * | 2019-10-08 | 2020-01-14 | 郑州轻工业学院 | 基于生成式对抗网络的半监督高分遥感图像场景分类方法 |
CN111028146A (zh) * | 2019-11-06 | 2020-04-17 | 武汉理工大学 | 基于双判别器的生成对抗网络的图像超分辨率方法 |
CN111260584A (zh) * | 2020-01-17 | 2020-06-09 | 北京工业大学 | 基于gan网络的水下退化图像增强的方法 |
Non-Patent Citations (2)
Title |
---|
FEI GAO等: "A Deep Convolutional Generative Adversarial Networks (DCGANs)-Based Semi-Supervised Method for Object Recognition in Synthetic Aperture Radar (SAR) Images", 《REMOTE SENSING》 * |
吴飞: "基于生成对抗网络和非局部神经网络的SAR图像变化检测", 《中国优秀硕士学位论文全文数据库 信息科技辑》 * |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113537031A (zh) * | 2021-07-12 | 2021-10-22 | 电子科技大学 | 基于多鉴别器条件生成对抗网络的雷达图像目标识别方法 |
CN113537031B (zh) * | 2021-07-12 | 2023-04-07 | 电子科技大学 | 基于多鉴别器条件生成对抗网络的雷达图像目标识别方法 |
CN113688953A (zh) * | 2021-10-25 | 2021-11-23 | 深圳市永达电子信息股份有限公司 | 基于多层gan网络的工控信号分类方法、装置和介质 |
CN113688953B (zh) * | 2021-10-25 | 2022-02-22 | 深圳市永达电子信息股份有限公司 | 基于多层gan网络的工控信号分类方法、装置和介质 |
CN114898159A (zh) * | 2022-06-01 | 2022-08-12 | 西北工业大学 | 基于解耦表征生成对抗网络的sar图像可解释性特征提取方法 |
CN114898159B (zh) * | 2022-06-01 | 2024-03-08 | 西北工业大学 | 基于解耦表征生成对抗网络的sar图像可解释性特征提取方法 |
Also Published As
Publication number | Publication date |
---|---|
CN112232395B (zh) | 2023-10-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109919108B (zh) | 基于深度哈希辅助网络的遥感图像快速目标检测方法 | |
CN113378632B (zh) | 一种基于伪标签优化的无监督域适应行人重识别方法 | |
CN110443143B (zh) | 多分支卷积神经网络融合的遥感图像场景分类方法 | |
CN113190699B (zh) | 一种基于类别级语义哈希的遥感图像检索方法及装置 | |
CN112232395A (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
US11908457B2 (en) | Orthogonally constrained multi-head attention for speech tasks | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN110705636B (zh) | 一种基于多样本字典学习和局部约束编码的图像分类方法 | |
CN113222011A (zh) | 一种基于原型校正的小样本遥感图像分类方法 | |
US20220121949A1 (en) | Personalized neural network pruning | |
CN112784929A (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN108052959A (zh) | 一种提高深度学习图片识别算法鲁棒性的方法 | |
CN114972904B (zh) | 一种基于对抗三元组损失的零样本知识蒸馏方法及系统 | |
CN111639697B (zh) | 基于非重复采样与原型网络的高光谱图像分类方法 | |
CN114842238A (zh) | 一种嵌入式乳腺超声影像的识别方法 | |
CN111259938B (zh) | 基于流形学习和梯度提升模型的图片偏多标签分类方法 | |
CN116152554A (zh) | 基于知识引导的小样本图像识别系统 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
CN114863938A (zh) | 一种基于注意力残差和特征融合的鸟语识别方法和系统 | |
CN114898136A (zh) | 一种基于特征自适应的小样本图像分类方法 | |
CN114299326A (zh) | 一种基于转换网络与自监督的小样本分类方法 | |
CN115329821A (zh) | 一种基于配对编码网络和对比学习的舰船噪声识别方法 | |
CN116543250A (zh) | 一种基于类注意力传输的模型压缩方法 | |
CN113592045B (zh) | 从印刷体到手写体的模型自适应文本识别方法和系统 | |
Hallyal et al. | Optimized recognition of CAPTCHA through attention models |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |