CN112766386B - 一种基于多输入多输出融合网络的广义零样本学习方法 - Google Patents
一种基于多输入多输出融合网络的广义零样本学习方法 Download PDFInfo
- Publication number
- CN112766386B CN112766386B CN202110096703.9A CN202110096703A CN112766386B CN 112766386 B CN112766386 B CN 112766386B CN 202110096703 A CN202110096703 A CN 202110096703A CN 112766386 B CN112766386 B CN 112766386B
- Authority
- CN
- China
- Prior art keywords
- data
- network
- visual
- pair
- attri
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2411—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on the proximity to a decision surface, e.g. support vector machines
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
一种基于多输入多输出融合网络的广义零样本学习方法,属于计算机技术领域。步骤包括:首先,构建多输入多输出的融合网络、用于对抗训练融合网络的判别网络、用于最小化重构损失的重构网络、引入分类损失优化公共语义表示学习的分类网络,其中,融合网络以不同方式串联组合作为输入,编码成潜在语义表示后分别输出。其次,流形学习和域结构的保持。最后,模型训练优化、生成伪样本训练分类器、和识别广义零样本图像。本发明主要针对零样本图像识别中生成模型训练不够稳定问题,设计融合网络,即使仅使用类别语义嵌入也能生成与配对数据相似的公共语义,以此生成伪样本,实现广义零样本图像识别,实验验证表明,能够有效解决广义零样本图像识别问题。
Description
技术领域
本发明属于计算机技术领域,涉及一种广义零样本学习方法。
背景技术
近年来,随着互联网上爆发性的数据增长,新事物不断涌现,难以收集足够的标记数据来训练图像识别模型。为此,专家学者提出了零样本学习的概念,引起了广泛的研究兴趣。零样本学习是受人类认知新事物的过程所启发,企图从有限的知识或者其他的已知事物中,辅以相应的知识描述,认知新事物。零样本学习中可见类是训练过程中能够使用的类,而不可见类是训练过程中不可使用,测试过程中使用的类。即从可见类数据中学习知识,进而能够迁移到不可见类中,进行零样本图像识别。因此能够解决标记数据不足和缺失的难题。现有的零样本学习工作大多数从可见类数据中学习一个视觉-语义映射,然后将此映射泛化到不可见类数据上。然而现有工作大部分在测试过程中仅仅使用不可见类,即仅做了零样本图像识别,未综合判断不可见类数据和可见类数据的准确率。该测试场景是不合理的,因为测试数据既可能来自不可见类,也可能来自可见类。因此,广义零样本学习越来越引起学者们的关注,在广义零样本学习中,新数据既可能来自不可见类,也可能来自可见类,更加贴合实际应用。
许多现有工作提出将视觉特征投影到语义空间,该语义空间可以是人为定义的属性空间或是类名词向量构成的嵌入空间。此类方法存在一个不可避免的缺陷便是域便宜问题,这是由于可见类和不可见类数据的分布存在差异,而在训练中仅仅使用了可见类数据,使得预测结果会偏向可见类数据。也有许多工作是学习视觉特征和语义的一个公共子空间,然而此类方法存在枢纽点问题。近些年,一些生成模型例如生成对抗网络(GAN)方法被提出,该类方法通过合成假的不可见类图片或者视觉特征来解决零样本问题。然而基于生成对抗网络的方法在模型训练时不够稳定,依然不能取得理想的结果。基于此,同样是生成方法的变分自动编码机方法被应用于假样本生成,通过匹配输入数据到一个已知的分布,然后利用该先验分布来合成假样本或者视觉特征,然而如何建立视觉特征和语义嵌入之间的关联性,找到一个较好的公共语义表示仍然存在巨大挑战。
为解决以上问题,本发明设计了一种基于多输入多输出融合网络的广义零样本学习方法,解决在图像领域的零样本识别问题。融合网络接收四种不同类型的输入,分别生成对应的公共语义表示并输出。对于其中的单输入问题,采用补零策略。核心目的是为了使不同类型的输入能够计算输出语义相似的公共表示,采用三个判别网络来区分不同的输出数据,以此通过对抗的方式来训练融合网络,达到生成高质量公共语义表示的目的,并且考虑了数据的重构损失,提高生成伪样本的质量,此外还考虑了局部域结构提高判别性。然后,便可以使用类别语义嵌入作为输入生成不可见类和可见类的伪样本,来训练分类器,进行广义零样本的识别任务。
发明内容
本发明针对广义零样本图像识别问题,探究一种基于多输入多输出融合网络的广义零样本学习方法,利用不同的数据类型来训练融合网络,以达到生成伪样本的目的。该方法利用了生成对抗的思想,试图使具有相同语义信息的输入数据生成相似的语义表示,构建融合网络,接收不同类型的输入,采用重构网络重构出输入,采用判别网络区分不同输入输出的语义表示,对抗训练融合网络。通过视觉特征和对应类别语义嵌入的不同方式的组合,构建4种不同类型的输入。网络训练好之后,基于类别语义嵌入,同时生成不可见类和可见类的视觉特征伪样本,然后,将零样本学习转化成一个完全有监督的分类问题。一旦测试样本到来,不管是来自可见类还是不可见类,都能够用训练好的分类器进行分类识别。
为了达到上述目的,本发明采用的技术方案为:
一种基于多输入多输出融合网络的广义零样本学习方法,所述的广义零样本学习方法采用融合网络,解决零样本图像识别问题。所述的多输入多输出融合网络采用不同类型的输入数据依然能够生成相似的公共语义表示,能够同时生成可见类和不可见类伪样本。所述广义零样本学习方法包括以下步骤:
步骤1,构建多输入多输出融合网络,以图像视觉特征和对应的类别语义嵌入,以4种不同方式串联组合作为输入,编码成潜在语义表示后分别输出;
步骤2,构建判别网络,形成3个判别网络,用于对抗训练融合网络;
步骤3,构建重构网络,从融合网络生成的公共语义表示中重构出输入,最小化重构损失;
步骤4,构建分类网络,引入分类损失优化公共语义表示学习;
步骤5,流形学习,使融合网络生成的公共语义表示保持局部结构和全局结构;
步骤6,构建和优化总目标函数,迭代更新模型的参数,保存模型参数。
步骤7,采用类别语义嵌入,串联高斯分布的随机采样数据,输入到融合网络,生成类别对应的伪样本,有监督训练分类器SVM;
步骤8,广义零样本图像识别,输入图像视觉特征,SVM分类结果为识别结果。
本发明接收4种不同类型的输入数据类型,生成类似的公共语义表示。
本发明仅仅使用类别语义嵌入便能生成可见类和不可见类的伪样本。
本发明步骤7中生成可见类和不可见类样本可以采用不同的数目。
本发明的有益效果为:
多输入多输出融合网络,接收不同组合类型的输入,生成语义表示,并通过对抗训练的方式,达到仅使用单独的类别语义嵌入作为输入,也能生成和类别语义嵌入与对应图像视觉特征配对输入相似的语义,如此便能够仅仅使用类别语义嵌入就能生成可见类和不可见类的伪样本;模型采用对抗训练的方式,结构分类损失和流形学习,提高了生成公共语义表示的判别性,生成的伪样本更接近数据的真实分布。
附图说明
图1为基于多输入多输出融合网络的广义零样本学习方法(MIMOFN)框架图;
图2为算法步骤图。
具体实施方式
下面结合附图对本发明的实施方式做进一步说明。
图1为本发明的总体框架图。其中,融合网络接收不同的数据类型输入,在最后一层生成对应的公共语义表示。四种输入数据类型,包含一种配对数据和三种非配对数据。核心思想就是采用不同类型的输入数据依然能够生成相似的公共语义表示。由于输入数据中携带了图像对应类别的语义嵌入,使得模型能够学习更具语义表达能力的公共语义表示的同时,实现从可见类数据到不可见类零样本数据的迁移。此外,为了能够对模型进行对抗训练,对于融合网络的输出采用多个判别器进行判定,以区分配对数据和非配对数据。如此,融合网络在仅仅使用类别语义嵌入的情况下,依然能够生成与配对数据相似的公共语义表示,为生成不可见类数据伪样本提供了可能。不仅如此,还采用重构网络从公共语义表示中分别重构出对应的原始病虫害视觉特征和语义嵌入,通过最小化重构误差,进一步提高模型学习公共语义表示的能力。采用分类网络对标签进行预测,最小化分类误差。另外,本项目还考虑了数据的域结构信息保持,提高公共语义表示的判别性。
其中,图像视觉特征提取于预先训练好的CNN模型如VGG19等;对应的语义嵌入特征,可以是类名词向量表示,提取于预先训练好的词向量模型如Word2Vec或者为人工标注的属性数据。对于输入层,本发明设计不同的输入数据类型,由视觉特征和语义嵌入进行不同形式的拼接,形成四种输入数据类型,分别是:①由图像的视觉特征和对应的语义嵌入拼接成的配对数据、②非配对数据仅包含语义嵌入、③非配对数据仅包含视觉特征、④负样本非配对数据由语义嵌入和不匹配的视觉特征拼接。对于②、③采用零填充策略补全数据。通过零填充和数据类型②,能够使模型即仅仅利用语义嵌入的情况下,依然能够生成和配对数据①一样的语义表示。如此,在模型测试和使用时,便能直接使用类别语义嵌入作为输入,为生成伪样本,解决两样本图像识别提供了可能。
具体步骤如下:
一种基于多输入多输出融合网络的广义零样本学习方法,包括以下步骤:
步骤1、构建多输入多输出融合网络;
多输入多输出融合网络,采用前向神经网络E(·,θE)实现,融合网络接收4种不同类型的输入数据,生成对应的公共语义表示Z。其中,θE表示网络参数,对应的4种输出表示为:Zpair=E((x,a),θE),Zattri=E((0,a),θE),Zvisual=E((x,0),θE),Znegative=E((x-,a),θE),其中x表示图像视觉特征,a表示对应的语义嵌入,x-表示随机选择的与a不匹配的语义嵌入,0表示零向量填充,4种表示分别为:Zpair为成对输入①(x,a)对应的公共语义表示、Zattri为非配对数据②(0,a)对应的公共语义表示、Zvisual为非配对数据③(x,0)对应的公共语义表示、Znegative为负样本非配对数据④(x-,a)对应的公共语义表示。
所述的4种不同类型的输入数据具体为:①由图像的视觉特征和对应的语义嵌入拼接成的配对数据、②非配对数据仅包含语义嵌入、③非配对数据仅包含视觉特征、④负样本非配对数据由语义嵌入和不匹配的视觉特征拼接。
步骤2、构建判别网络;
采用判别网络,通过区分配对数据和非配对数据,以对抗方式训练融合网络,学习语义表示更准确的网络模型。设计三种判别网络D1(·,θD1)、D2(·,θD2)、D3(·,θD3),其中θD1、θD2、θD3表示网络参数。三种判别网络都将Zpair作为真样本。对于D1,将Zattri作为假样本;对于D2,将Zvisual作为假样本;此外,负样本也被引入来强化对抗训练,因此对于D3,Znegative作为假样本。对抗损失Ladv定义如公式(1)所示:
Ladv=Ladv1+Ladv2+Ladv3 (1)
其中,Ladv1表示判别网络D1的对抗损失,Ladv2表示判别网络D2的对抗损失,Ladv3表示判别网络D3的对抗损失,分别如公式(2)、(3)、(4)所示。
式中:E表示求期望,~表示满足某分布的符号,pdata()表示数据分布符号,Zpair~pdata(Zpair)表示分别求logD1(Zpair)、logD2(Zpair)、logD3(Zpair)期望时的所有数据范围,代指所有Zpair的数据,Zattri~pdata(Zattri)表示取所有数据Zattri时log(1-D1(Zattri))的期望,Zvisual~pdata(Zvisual)表示取所有数据Zvisual时log(1-D2(Zvisual))的期望,Znegative~pdata(Znegative)表示取所有数据Znegative时log(1-D3(Znegative))的期望。
步骤3、构建重构网络;
构建重构网络,使模型能够从公共语义表示中分别重构出原始图像视觉特征,最小化重构误差,优化融合网络。当整个模型训练完毕的时候,重构网络同时也能够作为生成器,用于生产不可见类和可见类的伪样本。仅对除负样本之外的其他三种输入数据所获取的公共语义表示进行重构:将Zpair、Zattri、Zvisual都填入到视觉特征重构网络G1(·,θG1)中,其中θG1表示网络参数,网络的重构损失Lrec表示如公式(5)所示,其中重构网络有三个,共享权重;
其中,Xs是可见类图像数据的视觉特征,E表示求期望;Xs~pdata(Xs)表示取所有视觉特征Xs。表示取所有数据Zpair~pdata(Zpair)、Xs~pdata(Xs)时对应的期望。表示取所有数据Zattri~pdata(Zattri)、Xs~pdata(Xs)时对应的期望。表示取所有数据Zvisual~pdata(Zvisual)、Xs~pdata(Xs)时对应的期望。
步骤4、构建分类网络;
设计一个分类误差,用于提高所学习公共语义表示的判别能力,通过预测标签,最小化分类损失实现。分类网络同样为一个前向网络P(·,θP),接收Zpair、Zattri、Zvisual作为输入,其中θP表示网络参数。分类误差Llabel定义为公式(6)所示:
其中,YZ表示Z的标签数据,Ys表示可见类即训练数据的标签集合,Yz~pdata(Ys)表示YZ属于Ys。Z~pdata(Zattri∪Zvisual∪Zpair)表示Z属于数据Zattri,Zvisual,Zpair的并集,∪是求并集符号。P(Z)表示分类网络以Z为输入的标签预测结果。表示取所有的Z和YZ时的期望。
步骤5、流形学习、域结构保持;
采用流形学习的方式在公共语义表示上对原始数据的结构进行保持。对于输入类型①、②、③,在生成的对应公共语义表示上进行结构保持,采用拉普拉斯图正则化的方式,域结构保持定义如公式(7)所示:
Lmanifold=Tr(Zpair TLpZpair)+Tr(Zattri TLaZattri)+Tr(Zvisual TLvZvisual) (7)
其中:Tr(·)表示求矩阵的迹操作。Lp表示成对数据①的拉普拉斯矩阵,根据成对数据的相似度Spair计算,Lp=Dp-Spair,其中Dp是一个对角阵,其对角元为Spair的行求和。而Spair计算公式如公式(8)所示:
其中,i、j分别表示索引,Spair(i,j)表示Spair的第i行和第j列,Zpair(i)表示Zpair的第i个样本,Zpair(j)表示Zpair的第j个样本。
公式(7)中的La表示非配对数据②对应的拉普拉斯矩阵,根据非配对数据②的相似度Sattri计算得到。Sattri根据样本之间的邻居关系计算,具体如公式(9)所示:
其中,Sattri(i,j)表示Sattri的第i行和第j列,若输入数据的第i个样本和第j个样本的距离小于等于e则取值为1,否则为0。由于此时输入数据仅包含类别语义嵌入,则dist(a(i),a(j))表示类别语义嵌入输入数据的第i个样本a(i)和第j个样本a(j)之间的距离,其中e=max(dist(a(i),a(j)))/20是阈值参数,dist(·)表示距离函数,采用欧氏距离。用类似的方法可以计算非配对数据③,即仅包含图像视觉特征的相似度矩阵Svisual,然后便能够据此计算其对应的拉普拉斯矩阵Lv。
步骤6、构建融合网络总损失函数,训练网络模型,保存模型参数。
训练过程中,通过对抗的形式进行训练,首先训练判别网络,然后是融合网络、重构网络和分类网络。根据步骤2-步骤5的公式(1)、(5)、(6)、(7)得到模型的总体损失函数,如公式(10)所示:
L=Lmanifold+Llabel+Lrec-Ladv (10)
步骤7、生成伪样本,训练分类器SVM;
与现有的一些生成不可见类原始图像不同,本发明生成图像的视觉特征,降低生成模型的干扰。模型训练好之后,设置可见类和不可见类生成样本的数量,从高斯分布中采样对应数量的噪声数据,拼接需要生成的类别语义嵌入输入到融合网络中,生成公共语义表示,通过重构网络生成的视觉特征作为该类别的样本数据。训练分类器时仅使用生成的伪样本数据,以避免数据不平衡的干扰。
步骤8、广义零样本图像识别;
分类器如SVM训练完之后,新出现的测试样本便能够直接输入到分类器得到识别的结果。如此,无论新到来的图像属于可见类数据,还是不可见类的零样本数据,分类器都能对其进行诊断识别。
验证结果:
为了验证本发明提出的方法在处理零样本图像识别上的有效性,采用3个常用的数据集SUN、AWA2和aPY进行实验验证。数据集的详细信息如表1所示。其中AWA2和aPY是粗粒度的数据集,规模较小。而SUN是细粒度的数据集。视觉特征采用ResNet110提取的2048维的特征。
表1.数据集详细信息介绍
采用每类Top1准确率平均值作为算法性能的评价指标,在广义零样本图像识别中,还采用调和均值作为指标,其计算方式融合了可见类和不可见类的每类Top1准确率平均值,因此更加准确。
验证模型的实现细节如下,采用TensorFlow实现模型,训练过程中训练数据块的大小为128,采用Adam训练器,学习率为0.0002。公共语义表示大小为50。
每一个网络都是采用全连接网络组成,融合网络E的输入维度为2048+da,其中da表示类别属性的维度,包含一个隐藏层,具有512个隐藏单元。第一层的激活函数是ReLU,丢弃率为0.7,最后一层公共语义表示层大小为50。判别网络D包含2个隐藏层大小分别为64、32,输出为2维,输入为前一阶段的公共语义表示。重构网络G的包含一个512维的隐藏层,采用ReLU激活函数。在验证过程中各个数据集生成的伪样本数量如表2所示。
表2.数据集详细信息介绍
广义零样本图像识别的结果如表3所示,其中U表示不可见类的每类Top1准确率平均值,S表示可见类的每类Top1准确率平均值,H是调和均值,三者数值越高表明准确率越高,而H越高表明最终的广义零样本图像识别准确率越高。表中的对比算法如下:语义相似度嵌入(SSE)、潜在嵌入(LATEM)、属性标签嵌入(ALE)、语义自动编码机(SAE)、深度嵌入模型(DEM)、合成样本(SE-GZSL)、双视角排序(DARK)、联合投影子空间学习(JIL)、语义保持对抗嵌入网络(SP-AEN)、语义校正生成对抗网络(SR-GAN)以及条件变分自动编码机(CVAE)。CAAE表示本发明的方法。
表3.广义零样本图像识别结果
从表3的验证结果中可以看出,本发明的方法MIMOFN远远好于相融性学习方法,比如DEVISE。本发明的方法相比于生成类方法SP-AEN、CVAE、SR-GAN,在AWA2和aPY数据集上效果比较好,特别地,本发明的方法MIMOFN相比于CVAE在AWA2数据集上,调和均值提高了9.2%。在aPY数据集上也是优于SR-GAN。在SUN数据集上,相比于CVAE和SR-GAN,本发明的方法取得了相当的调和均值,仅次于DARK。但是DARK在另外两个数据集上,表现远远不如本发明的方法。总体上看,实验验证了本发明的多输入多输出融合网络对于解决广义零样本学习是有效的。
综上所述,本发明的基于多输入多输出融合网络的广义零样本学习方法,能够在仅仅使用视觉特征或者类别语义嵌入一个条件下,也能生成和配对数据相似的公共语义表示,并且重构出输入的图像视觉特征,因此能够利用类属性条件生成可见类和不可见类的伪样本,同时训练分类器,解决广义零样本图像识别问题。
以上所述实例仅表达本发明的实施方式,但并不能因此而理解为对本发明专利的范围的限制,应当指出,对于本领域的技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进,这些均属于本发明的保护范围。
Claims (3)
1.一种基于多输入多输出融合网络的广义零样本学习方法,其特征在于,所述的多输入多输出融合网络采用不同类型的输入数据依然能够生成相似的公共语义表示,能够同时生成可见类和不可见类伪样本,包括以下步骤:
步骤1、构建多输入多输出融合网络;
多输入多输出融合网络,采用前向神经网络E(·,θE)实现,融合网络接收4种不同类型的输入数据,生成对应的公共语义表示Z;其中,θE表示网络参数,对应的4种输出表示为:Zpair=E((x,a),θE),Zattri=E((0,a),θE),Zvisual=E((x,0),θE),Znegative=E((x-,a),θE),其中x表示图像视觉特征,a表示对应的语义嵌入,x-表示随机选择的与a不匹配的语义嵌入,0表示零向量填充,4种表示分别为:Zpair为成对输入①(x,a)对应的公共语义表示、Zattri为非配对数据②(0,a)对应的公共语义表示、Zvisual为非配对数据③(x,0)对应的公共语义表示、Znegative为负样本非配对数据④(x-,a)对应的公共语义表示;
所述的4种不同类型的输入数据具体为:①由图像的视觉特征和对应的语义嵌入拼接成的配对数据、②非配对数据仅包含语义嵌入、③非配对数据仅包含视觉特征、④负样本非配对数据由语义嵌入和不匹配的视觉特征拼接;
步骤2、构建判别网络;
采用判别网络,通过区分配对数据和非配对数据,以对抗方式训练融合网络,学习语义表示更准确的网络模型;设计三种判别网络D1(·,θD1)、D2(·,θD2)、D3(·,θD3),其中θD1、θD2、θD3表示网络参数;三种判别网络都将Zpair作为真样本;对于D1,将Zattri作为假样本;对于D2,将Zvisual作为假样本;此外,负样本也被引入来强化对抗训练,因此对于D3,Znegative作为假样本;对抗损失Ladv定义如公式(1)所示:
Ladv=Ladv1+Ladv2+Ladv3 (1)
其中,Ladv1表示判别网络D1的对抗损失,Ladv2表示判别网络D2的对抗损失,Ladv3表示判别网络D3的对抗损失,分别如公式(2)、(3)、(4)所示;
式中:E表示求期望,~表示满足某分布的符号,pdata()表示数据分布符号,Zpair~pdata(Zpair)表示分别求logD1(Zpair)、logD2(Zpair)、logD3(Zpair)期望时的所有数据范围,代指所有Zpair的数据,Zattri~pdata(Zattri)表示取所有数据Zattri时log(1-D1(Zattri))的期望,Zvisual~pdata(Zvisual)表示取所有数据Zvisual时log(1-D2(Zvisual))的期望,Znegative~pdata(Znegative)表示取所有数据Znegative时log(1-D3(Znegative))的期望;
步骤3、构建重构网络;
构建重构网络,使模型能够从公共语义表示中分别重构出原始图像视觉特征,最小化重构误差,优化融合网络;当整个模型训练完毕的时候,重构网络同时作为生成器,用于生产不可见类和可见类的伪样本;仅对除负样本之外的其他三种输入数据所获取的公共语义表示进行重构:将Zpair、Zattri、Zvisual都填入到视觉特征重构网络G1(·,θG1)中,其中θG1表示网络参数,网络的重构损失Lrec表示如公式(5)所示,其中重构网络有三个,共享权重;
其中,Xs是可见类图像数据的视觉特征,E表示求期望;Xs~pdata(Xs)表示取所有视觉特征Xs;表示取所有数据Zpair~pdata(Zpair)、Xs~pdata(Xs)时对应的期望;表示取所有数据Zattri~pdata(Zattri)、Xs~pdata(Xs)时对应的期望;表示取所有数据Zvisual~pdata(Zvisual)、Xs~pdata(Xs)时对应的期望;
步骤4、构建分类网络;
设计一个分类误差,用于提高所学习公共语义表示的判别能力,通过预测标签,最小化分类损失实现;分类网络同样为一个前向网络P(·,θP),接收Zpair、Zattri、Zvisual作为输入,其中θP表示网络参数;分类误差Llabel定义为公式(6)所示:
其中,YZ表示Z的标签数据,Ys表示可见类即训练数据的标签集合,Yz~pdata(Ys)表示YZ属于Ys;Z~pdata(Zattri∪Zvisual∪Zpair)表示Z属于数据Zattri,Zvisual,Zpair的并集,∪是求并集符号;P(Z)表示分类网络以Z为输入的标签预测结果;表示取所有的Z和YZ时的期望;
步骤5、流形学习、域结构保持;
采用流形学习的方式在公共语义表示上对原始数据的结构进行保持;对于输入类型①、②、③,在生成的对应公共语义表示上进行结构保持,采用拉普拉斯图正则化的方式,域结构保持定义如公式(7)所示:
Lmanifold=Tr(Zpair TLpZpair)+Tr(Zattri TLaZattri)+Tr(Zvisual TLvZvisual) (7)
公式(7)中Lp表示成对数据①的拉普拉斯矩阵,根据成对数据的相似度Spair计算,Lp=Dp-Spair,其中Dp是一个对角阵,其对角元为Spair的行求和,Spair计算公式如公式(8)所示:
公式(7)中La表示非配对数据②对应的拉普拉斯矩阵,根据非配对数据②的相似度Sattri计算得到;Sattri根据样本之间的邻居关系计算,如公式(9)所示:
其中,若输入数据的第i个样本和第j个样本的距离小于等于e则取值为1,否则为0;dist(a(i),a(j))表示类别语义嵌入输入数据的样本a(i)和样本a(j)之间的距离,dist(·)表示距离函数;e=max(dist(a(i),a(j)))/20是阈值参数;
采用类似的方法计算非配对数据③,即仅包含图像视觉特征的相似度矩阵Svisual,计算其对应的拉普拉斯矩阵Lv;
步骤6、构建融合网络总损失函数,训练网络模型,保存模型参数;
训练过程中,通过对抗的形式进行训练,首先训练判别网络,然后是融合网络、重构网络和分类网络;根据步骤2-步骤5的公式(1)、(5)、(6)、(7)得到模型的总体损失函数,如公式(10)所示:
L=Lmanifold+Llabel+Lrec-Ladv (10)
步骤7、生成伪样本,训练分类器SVM;
模型训练好之后,设置可见类和不可见类生成样本的数量,从高斯分布中采样对应数量的噪声数据,拼接需要生成的类别语义嵌入输入到融合网络中,生成公共语义表示,通过重构网络生成的视觉特征作为该类别的样本数据;训练分类器时仅使用生成的伪样本数据;
步骤8、广义零样本图像识别;
分类器如SVM训练完之后,新出现的测试样本能够直接输入到分类器得到识别的结果;如此,无论新到来的图像属于可见类数据,还是不可见类的零样本数据,分类器都能对其进行诊断识别。
2.根据权利要求1所述的一种基于多输入多输出融合网络的广义零样本学习方法,其特征在于,所述的步骤7中生成可见类和不可见类样本可以采用不同的数目。
3.根据权利要求1所述的一种基于多输入多输出融合网络的广义零样本学习方法,其特征在于,所述步骤5公式(7)中dist(·)采用欧氏距离。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110096703.9A CN112766386B (zh) | 2021-01-25 | 2021-01-25 | 一种基于多输入多输出融合网络的广义零样本学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110096703.9A CN112766386B (zh) | 2021-01-25 | 2021-01-25 | 一种基于多输入多输出融合网络的广义零样本学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112766386A CN112766386A (zh) | 2021-05-07 |
CN112766386B true CN112766386B (zh) | 2022-09-20 |
Family
ID=75707125
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110096703.9A Active CN112766386B (zh) | 2021-01-25 | 2021-01-25 | 一种基于多输入多输出融合网络的广义零样本学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112766386B (zh) |
Families Citing this family (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113378959B (zh) * | 2021-06-24 | 2022-03-15 | 中国矿业大学 | 一种基于语义纠错下生成对抗网络的零样本学习方法 |
CN113609569B (zh) * | 2021-07-01 | 2023-06-09 | 湖州师范学院 | 一种判别式的广义零样本学习故障诊断方法 |
CN113537322B (zh) * | 2021-07-02 | 2023-04-18 | 电子科技大学 | 一种跨模态语义增强生成对抗网络的零样本视觉分类方法 |
CN113642621B (zh) * | 2021-08-03 | 2024-06-28 | 南京邮电大学 | 基于生成对抗网络的零样本图像分类方法 |
CN113537389B (zh) * | 2021-08-05 | 2023-11-07 | 京东科技信息技术有限公司 | 基于模型嵌入的鲁棒图像分类方法和装置 |
CN113673685B (zh) * | 2021-08-31 | 2024-03-15 | 西湖大学 | 基于流形学习的数据嵌入方法 |
CN115424262A (zh) * | 2022-08-04 | 2022-12-02 | 暨南大学 | 一种用于优化零样本学习的方法 |
CN115758159B (zh) * | 2022-11-29 | 2023-07-21 | 东北林业大学 | 基于混合对比学习和生成式数据增强的零样本文本立场检测方法 |
CN117034020B (zh) * | 2023-10-09 | 2024-01-09 | 贵州大学 | 一种基于cvae-gan模型的无人机传感器零样本故障检测方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110232205A (zh) * | 2019-04-28 | 2019-09-13 | 大连理工大学 | 用于托卡马克中共振磁扰动控制新经典撕裂模的模拟方法 |
CN111476294A (zh) * | 2020-04-07 | 2020-07-31 | 南昌航空大学 | 一种基于生成对抗网络的零样本图像识别方法及系统 |
CN111914929A (zh) * | 2020-07-30 | 2020-11-10 | 南京邮电大学 | 零样本学习方法 |
-
2021
- 2021-01-25 CN CN202110096703.9A patent/CN112766386B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110232205A (zh) * | 2019-04-28 | 2019-09-13 | 大连理工大学 | 用于托卡马克中共振磁扰动控制新经典撕裂模的模拟方法 |
CN111476294A (zh) * | 2020-04-07 | 2020-07-31 | 南昌航空大学 | 一种基于生成对抗网络的零样本图像识别方法及系统 |
CN111914929A (zh) * | 2020-07-30 | 2020-11-10 | 南京邮电大学 | 零样本学习方法 |
Non-Patent Citations (1)
Title |
---|
基于跨域对抗学习的零样本分类;刘欢等;《计算机研究与发展》;20191215(第12期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN112766386A (zh) | 2021-05-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112766386B (zh) | 一种基于多输入多输出融合网络的广义零样本学习方法 | |
Abdar et al. | A review of uncertainty quantification in deep learning: Techniques, applications and challenges | |
CN109753992B (zh) | 基于条件生成对抗网络的无监督域适应图像分类方法 | |
CN109063724B (zh) | 一种增强型生成式对抗网络以及目标样本识别方法 | |
CN107273927B (zh) | 基于类间匹配的无监督领域适应分类方法 | |
CN110046671A (zh) | 一种基于胶囊网络的文本分类方法 | |
CN111724083A (zh) | 金融风险识别模型的训练方法、装置、计算机设备及介质 | |
CN102314614B (zh) | 一种基于类共享多核学习的图像语义分类方法 | |
Kastaniotis et al. | Attention-aware generative adversarial networks (ATA-GANs) | |
CN111445548B (zh) | 一种基于非配对图像的多视角人脸图像生成方法 | |
CN111258992A (zh) | 一种基于变分自编码器的地震数据扩充方法 | |
Krichen | Generative adversarial networks | |
CN101540047A (zh) | 基于独立高斯混合模型的纹理图像分割方法 | |
CN112465226B (zh) | 一种基于特征交互和图神经网络的用户行为预测方法 | |
CN112529063B (zh) | 一种适用于帕金森语音数据集的深度域适应分类方法 | |
Ning et al. | Conditional generative adversarial networks based on the principle of homologycontinuity for face aging | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN110110739A (zh) | 一种基于样本选择的域自适应降维方法 | |
CN114972904B (zh) | 一种基于对抗三元组损失的零样本知识蒸馏方法及系统 | |
CN114722892A (zh) | 基于机器学习的持续学习方法及装置 | |
CN115600137A (zh) | 面向不完备类别数据的多源域变工况机械故障诊断方法 | |
CN112241741A (zh) | 基于分类对抗网的自适应图像属性编辑模型和编辑方法 | |
CN105787045B (zh) | 一种用于可视媒体语义索引的精度增强方法 | |
CN113344146A (zh) | 基于双重注意力机制的图像分类方法、系统及电子设备 | |
CN117408845A (zh) | 一种基于知识图谱的人群定位方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |