CN112232156B - 一种基于多头注意力生成对抗网络的遥感场景分类方法 - Google Patents

一种基于多头注意力生成对抗网络的遥感场景分类方法 Download PDF

Info

Publication number
CN112232156B
CN112232156B CN202011059707.1A CN202011059707A CN112232156B CN 112232156 B CN112232156 B CN 112232156B CN 202011059707 A CN202011059707 A CN 202011059707A CN 112232156 B CN112232156 B CN 112232156B
Authority
CN
China
Prior art keywords
remote sensing
sensing image
head
generator
generated
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011059707.1A
Other languages
English (en)
Other versions
CN112232156A (zh
Inventor
王鑫
李可
宁晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hohai University HHU
Original Assignee
Hohai University HHU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hohai University HHU filed Critical Hohai University HHU
Priority to CN202011059707.1A priority Critical patent/CN112232156B/zh
Publication of CN112232156A publication Critical patent/CN112232156A/zh
Application granted granted Critical
Publication of CN112232156B publication Critical patent/CN112232156B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/10Terrestrial scenes
    • G06V20/182Network patterns, e.g. roads or rivers
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/16Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A20/00Water conservation; Efficient water supply; Efficient water use
    • Y02A20/152Water filtration

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • Computational Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Pure & Applied Mathematics (AREA)
  • Mathematical Optimization (AREA)
  • Mathematical Analysis (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于多头注意力生成对抗网络的遥感场景分类方法,首先,将多头注意力机制引入生成对抗网络,该网络包含一个嵌入了多头注意力机制的生成器和一个嵌入了多头注意力机制的判别器;其次,将遥感图像样本输入该网络中,生成器通过输入随机噪声产生合成图像,判别器分别通过输入真实图像和生成模型生成的合成图像来判断输入图像的真假,交替训练判别器和生成器,直到当生成器和判别器达到纳什均衡时,模型达到最优;然后,提取生成器生成的遥感图像,和原始遥感图像数据集进行合并,形成新的遥感图像数据集;最后,将新的遥感图像数据集输入深度卷积神经网络分类器中,实现遥感图像的分类。

Description

一种基于多头注意力生成对抗网络的遥感场景分类方法
技术领域
本发明属于图像处理领域,尤其涉及一种基于多头注意力生成对抗网络的遥感场景分类方法。
背景技术
当前我国水质环境面临的形势十分严峻,加强对水质环境管理和保护已成为重中之重。然而传统的水质环境监测方法费时费力,很难在较短时间内全面获取信息,将获取信息快、覆盖范围大、受限条件少的遥感技术应用于水质环境监测中显得尤为重要。水利遥感图像分类是水域解译的重要基础,为水利遥感图像的应用(如水资源调查监测、水文情报预测和区域水文研究)提供了不可或缺的分析数据,因此对其深入研究具有很高的理论意义和实用价值。
在实际应用中,水利遥感图像中水体类别复杂,一级水域和水利设施用地可以粗略地分为十种类别,如:河流、湖泊、水库、坑塘和沟渠等,而其中的湖泊又可以分为堰塞湖、构造湖、火口湖和河成湖等等。复杂的水体类别使得获取大量水利遥感图像并对其进行标注显然是不现实的,而实际情况中,需要预测的水利遥感图像数量往往远超于现有的带标注的数据集数量,导致训练样本和测试样本极不均衡。传统的基于监督分类、非监督分类、和浅层学习的方法虽然能在一定程度上提升水利遥感图像的分类精度,但是容易受到光照、噪声等问题的干扰,且面对复杂分类问题时,泛化能力容易受到制约。基于深度学习的水利遥感图像分类方法强大的特征表达能力使得其能有效应用于水利遥感图像分类问题。但是无论是传统的还是基于深度学习的水利遥感图像分类方法,分类性能的瓶颈均受限于带标签水利遥感图像数据集的数目。现有的深度学习方法应用于水利遥感图像分类任务面临巨大的挑战。如何充分利用有限的标记数据集预测大量的未知类别的水利遥感图像,提升其分类性能是目前亟待解决的一个问题。
针对标记数据集匮乏的水利遥感图像而言,若要大幅度提升其分类效果,关键在于要有大量带标签的水利遥感图像训练样本,然而,对水利遥感图像进行训练样本的手工选取和标注是一个巨大的工程,而且随着具体分类任务的不同,大量训练样本的获取非常困难。目前常用的数据增强手段是通过裁剪、翻转图像、旋转图像和缩放比例的方式来获取足够多的数据,但是这种方法只是在原有的数据集上进行一些变换,并不是生成新的遥感图像信息,并且很难稳定地改善分类结果。尤其是面向细节较为丰富、信息量较大的水利遥感图像来说,通过裁剪图像等数据增强手段会损失掉一些信息。而生成对抗网络作为近年来最具潜力的深度学习方法,表示图像等高维数据的突出能力使其在遥感图像处理领域有很强的适用性,尤其具备图像合成的能力。因此,利用在图像生成方面成果显著的生成对抗网络增强标记水利遥感图像数据集,进一步提升水利遥感图像分类准确率成为了一个自然而然的选择。
近年来,以生成对抗网络为代表的生成模型备受关注,其不依赖于任何先验分布假设,能够用简单的方式从潜在空间中生成真实样本,能很好地表示图像、视频、音乐等高维数据,这些特点使得生成对抗网络具备处理复杂图像的优越能力,已经有学者将其应用于遥感图像处理领域。
Xiangyu Liu等人2018年在25th IEEE International Conference on ImageProcessing上发表的论文“Psgan:a generative adversaria network for remotesensing image pan-sharpening”将生成对抗网络应用于遥感图像融合,提出了一种用于遥感图像融合的生成对抗网络PSGAN,首先提出了一个用于生成高分辨率多光谱图像的双流融合结构,然后将一个全卷积网络作为判别器判别真实图像或者融合图像的真实性。结果表明,提出的PSGAN能够有效地融合多光谱图像和单波段图像,显著提高了图像的融合效果。
Lin Zhu等人2018年在IEEE Transactions on Geoscience and Remote Sensing上发表的论文“Generative adversarial networks for hyperspectral imageclassification”将生成对抗网络用于高光谱图像分类,提出了两种方案:将提出的1D-GAN作为光谱分类器;鲁棒性更强的3D-GAN作为光谱空间分类器,将生成的对抗性样本与真实训练样本相结合,对判别器进行微调,提高了最终的分类性能。
Qian Shi等人2017年在IEEE access上发表的论文“Road detection fromremote sensing images by generative adversarial networks”将生成对抗网络用于道路检测,提出了一种端到端的生成对抗网络,特别是构建了一个基于对抗训练的卷积神经网络,能够区分真实图像或由分割模型生成的分割图像,该方法通过修正分割模型输出结果与真实值之间的差异,提高了分割性能。
现有的基于生成对抗网络的遥感图像处理方法,存在的诸多局限性表现在:
(1)生成对抗网络模型生成的遥感图像过于自由不可控。生成对抗网络最大的优点是毋须预先建模,但是也带来了生成对抗网络训练过程和生成的水利遥感图像通常不太可控的问题。例如在生成河流的过程中,有时会看到蓝色的河流,而真实水利遥感图像的河流颜色通常呈绿色或者黄色。也就是说虽然生成对抗网络使用噪声z作为先验知识,但是生成器如何利用噪声z生成图像是无法控制的,使得生成对抗网络生成的图像太过自由,很不稳定。
(2)生成对抗网络生成的遥感图像过于关注局部点的信息,无法确保图像中长距离像素点之间的相关性。生成对抗网络中的生成器和判别器通常使用卷积神经网络作为特征提取器,卷积核受局部感受野的限制只能提取到近距离像素点之间的相关性,忽略了图像中较远距离像素点之间的相关性。具体表现为,在生成某一类别的图像时,只关注最重要的部分,而忽视了周围其他的像素点。
发明内容
发明目的:针对现有技术中存在的问题,本发明提供一种基于多头注意力生成对抗网络的遥感场景分类方法。该方法解决了目前水利遥感图像分类过程中,因标记数据集缺乏和获取困难,却常需要通过小样本数据预测大量数据而导致的整体分类性能不高的问题,通过提出的Multihead-attention GAN增强水利遥感图像数据集,同时考虑到了二维平面和通道上所有的像素点,并计算它们之间长距离的相关性,在生成图像时将这些相关性作为权重,生成的图像粒度和细节会更加精细,能有效提升分类准确率。
技术方案:为实现本发明的目的,本发明所采用的技术方案是:一种基于多头注意力生成对抗网络的遥感场景分类方法,该方法的具体步骤如下:
(1)将多头注意力机制引入生成对抗网络,构建Multihead-attention GAN网络,该网络包含一个嵌入多头注意力机制的生成器和一个嵌入多头注意力机制的判别器;
(2)获取原始遥感图像数据集,划分训练集、验证集和测试集,使用训练集交替训练判别器和生成器,直到生成器和判别器达到纳什均衡,得到最优的网络模型;
(3)提取步骤(2)得到的最优网络模型中生成器生成的遥感图像,与原始遥感图像训练集合并,得到新的遥感图像训练集;
(4)使用新的遥感图像训练集训练深度卷积神经网络分类器,使用验证集进行验证,在测试集上实现遥感图像的分类。
进一步的,所述步骤(1)中,构建Multihead-attention GAN结构如下:
Multihead-attention GAN由嵌入了Multihead-attention机制的生成器和判别器构成;生成器G包含七层:五层反卷积层和两层Multihead-attention层;判别器D包含七层:五层卷积层和两层Multihead-attention层;
所述生成器的输入是128维的随机噪声z,经过第一层反卷积层TConv1后,生成的特征图为4×4×512;经过第二层反卷积层Tconv2后,生成的特征图为8×8×256;经过第三层反卷积层Tconv3后,生成的特征图为16×16×128;经过第四层Multihead-attention层GAtten1之后,生成的特征图为16×16×128;经过第五层反卷积层Tconv4后,生成的特征图为32×32×64;经过第六层Multihead-attention层GAtten2之后,生成的特征图为32×32×64;经过第七层反卷积层Tconv5后,生成的特征图为64×64×3,即为生成的遥感图像;
所述判别器的输入是真实的遥感图像和生成器生成的遥感图像,大小均为64×64×3;经过第一层卷积层Conv1之后,生成的特征图为32×32×64;经过第二层卷积层Conv2之后,生成的特征图为16×16×128;经过第三层卷积层Conv3之后,生成的特征图为8×8×256;经过第四层Multihead-attention层DAtten1之后,生成的特征图为8×8×256;经过第五层卷积层Conv4之后,生成的特征图为4×4×512;经过第六层Multihead-attention层DAtten2之后,生成的特征图为4×4×512;经过第七层卷积层Conv5之后,最后输出的特征图是1×1×1,表示输入分类器中待分类的图像所属的类别。
进一步的,所述步骤(1)中,Multihead-attention机制是在Self-attention机制上进行改进,将Self-attention的单头改为多头,针对每一个head执行attention操作;步骤如下:
将输入多头注意力机制层的特征图,从通道上划分为L个head,并且在每个head内计算attention矩阵;将每个head的attention矩阵作用于该head对应的输入特征图中的像素点,得到每个head的输出结果;将所有head的输出结果进行融合,融合后的结果作为该多头注意力机制层的输出。
进一步的,计算每个head的attention矩阵方法如下:
首先,Multihead-attention对从判别器中第三层卷积层Conv3、第五层卷积层Conv4或生成器中第三层反卷积层Tconv3、第五层反卷积层Tconv4中提取出的特征图
Figure BDA0002711926660000041
赋予不同的权重矩阵WQ、WK和WV,生成特征图WQxh、WKxh和WVxh;其中C是特征图的通道数,N是特征图单个通道上像素点的个数;
其次,经过非线性变换f(x)=Wfx和g(x)=Wgx将特征图WQxh、WKxh和WVxh转换到特征空间f和g中;其中x表示输入的特征图,Wf,Wg分别表示将输入特征图转换到特征空间f和g的权重矩阵;
然后,计算第l个head的attention矩阵:
Figure BDA0002711926660000042
其中,
Figure BDA0002711926660000043
表示归一化之后第l个head的attention矩阵,
Figure BDA0002711926660000044
表示在Multihead-attention机制中,生成第j个区域时第i个位置对其的关注程度;N表示特征图单个通道上像素点的个数,xhi表示特征空间f中的像素点,xhj表示特征空间g中的像素点;
Figure BDA0002711926660000045
Figure BDA0002711926660000046
均为权重矩阵,
Figure BDA0002711926660000047
C是原始输入特征图的通道数,L为多头注意力机制中head的个数,
Figure BDA0002711926660000048
是每一个head划分到的通道数。
进一步的,将每个head的attention矩阵作用于输入特征图对应区域内的像素点,得到每个head的输出结果,将所有head的输出结果进行融合,方法如下:
第l个head中的像素点组成的区域内,第j个像素xhj经过
Figure BDA0002711926660000049
作用之后的结果
Figure BDA00027119266600000410
为:
Figure BDA00027119266600000411
其中,h(x)=Whx;Wh是将特征图转换到特征空间h时的权重矩阵;
第j个像素xhj经过第l个head的attention机制计算之后的输出为:
Figure BDA0002711926660000051
其中,γ(l)表示
Figure BDA0002711926660000052
的权重,
Figure BDA0002711926660000053
表示第j个像素经过第l个head中的attention矩阵计算之后的输出结果;
将所有head的输出进行融合,得到:
Output=v(concat(O(1),O(2),…,O(l),…,O(L))WO)
Figure BDA0002711926660000054
其中,O(l)是经过第l个head attention作用之后的所有像素点的集合,v(x)=WVx,
Figure BDA0002711926660000055
是权重矩阵,
Figure BDA0002711926660000056
WO是权重矩阵。
进一步的,准备原始遥感图像数据集方法如下:
设原始遥感图像数据集共包含K类,每个类别包含M张图像,原始遥感图像数据集表示为T={(x11,y1),…,(x1N,y1),…,(xij,yj),…,(xK1,yK),…,(xKM,yK)},其中xij表示为第i类遥感图像中的第j张图像,yj表示图像xij所属类别的标签;
将原始遥感图像数据集按照1:1:8的比例划分为训练集、验证集和测试集;
训练集Tr表示如下:
Tr={(x11,y1),…,(x1(M/10),y1),…,
(xi1,yi),…,(xi(M/10),yi),…,
(xK1,yK),…,(xK(M/10),yK)}
其中每一类遥感图像数目为M/10;
验证集Tv表示如下:
Tv={(x1(M/10+1),y1),…,(x1(2M/10),y1),…,
(xi(M/10+1),yi),…,(xi(2M/10),yi),…,
(xK(M/10+1),yK),…,(xK(2M/10),yK)}
其中每一类遥感图像数也为M/10;
测试集Te表示如下:
Te={(x1(2M/10+1),y1),…,(x1M,y1),…,
(xi(2M/10+1),yi),…,(xiM,yi),…,
(xK(2M/10+1),yK),…,(xKM,yK)}
其中每一类遥感图像数目为8M/10。
进一步的,使用准备好的原始遥感图像数据集训练生成对抗网络,方法如下:
将准备好的原始遥感图像训练集Tr输入构建的Multihead-attention GAN中,交替训练判别器D和生成器G,直到生成器和判别器达到纳什均衡;
生成器和判别器的损失函数采用对抗性损失的合页损失函数,形式为:
Figure BDA0002711926660000061
Figure BDA0002711926660000062
LM-GAN=LD+LG
其中,LD,LG分别表示判别器和生成器的损失函数,LM-GAN表示Multihead-attention GAN的损失函数;x表示输入的真实遥感图像,E(·)表示期望函数,z表示输入的随机噪声;pdata(x)表示在数据空间中真实数据的概率分布,pz(z)表示在潜在空间中的噪声z的先验概率分布;训练Multihead-attention GAN的目标即为最小化整体损失函数LM-GAN,当LM-GAN达到最小时,模型达到最优。
进一步的,所述步骤(3)中,形成新的遥感图像训练集,过程如下:
设合成的每类遥感图像数目为F,则合成的遥感图像训练集表示为:
Tf={(x11,y1),…,(x1F,y1),…,(xij,yi),…,(xK1,yK),…,(xKF,yK)};
将原始遥感图像训练集Tr和合成的遥感图像训练集Tf进行合并,合并方法为:将两个数据集中对应类别的遥感图像都添加进新的遥感图像数据集的对应类别中,即新的遥感图像数据集类别和训练集Tr、Tf完全相同,同一类别下,新的遥感图像训练集的数目是Tr、Tf中对应遥感图像数目的总和,新的遥感图像训练集表示为:
Tn={(x11,y1),…,(x1(M+F),y1),…,(xij,yj),…,(xK1,yK),…,(xK(M+F),yK)}。
进一步的,所述步骤(4)中,选用ResNet作为深度卷积神经网络分类器,将新的遥感图像训练集Tn输入ResNet中,进行网络的训练;在网络达到最优之后,将验证集Tv和测试集Te输入网络中进行验证和测试;针对测试集中的每一张图像xij,ResNet都输出一个预测的类别标签yp,将测试集中每一张图像的yp和真实标签yj进行对比,即得到分类准确率。
有益效果:与现有技术相比,本发明的技术方案具有以下有益技术效果:
(1)本方法将生成对抗网络用于水利遥感图像分类问题中,通过生成对抗网络生成与原始水利遥感图像极为相似的合成水利遥感图像,增强水利遥感图像数据集,提升小样本情况下预测大规模水利遥感图像分类问题的准确率,并能稳定改善分类结果;
(2)本方法在生成对抗网络的生成器和判别器中引入了Multihead-attention机制,针对每个head都执行Attention操作,多个head能够在不同的latent space中学习到侧重点不同的attention,提升模型的表示能力和容量;同时赋予不同的head经过attention输出之后的特征图不同的权重,即赋予来自不同latent space的特征不同的权重,最后将这些特征进行融合。融合后的特征既包含来自不同latent space中侧重点不同的多个特征,包含更多的信息,又针对每个特征自动学习有利于分类的权重,对分类的增益更大,生成的图像粒度和细节会更加精细,更接近于真实的水利遥感图像,有效提升分类准确率。
附图说明
图1为本发明实施的框架图;
图2为Multihead-attention的结构图;
图3为基于Multihead-attention生成对抗网络的结构图;
图4为基于ResNet的水利遥感图像分类的结构图;
图5为实验采用的19类水利遥感图像数据集;
图6为三种生成对抗网络生成的19类水利遥感图像。
具体实施方式
下面结合附图和实施例对本发明的技术方案作进一步的说明。
如图1所示,本发明提出的一种基于多头注意力生成对抗网络的遥感场景分类方法,该方法的具体步骤如下:
(1)将多头注意力机制引入生成对抗网络,构建Multihead-attention GAN网络,该网络包含一个嵌入多头注意力机制的生成器和一个嵌入多头注意力机制的判别器。
Multihead-attention是多个head的Attention,多个head实质上是在通道维度上对像素点进行划分。针对每个head都执行Attention操作,多个head能够在不同的latentspace中学习到侧重点不同的attention,提升模型的表示能力和容量;同时赋予不同的head经过attention输出之后的特征图不同的权重,即赋予来自不同latent space的特征不同的权重,最后将这些特征进行融合。融合后的特征既包含来自不同latent space中侧重点不同的多个特征,包含更多的信息,又针对每个特征自动学习有利于分类的权重,对分类的增益更大,生成的图像粒度和细节会更加精细,更接近于真实的水利遥感图像。
如图3所示,构建Multihead-attention GAN结构如下:
Multihead-attention GAN由嵌入了Multihead-attention机制的生成器和判别器构成;生成器G包含七层:五层反卷积层和两层Multihead-attention层;判别器D包含七层:五层卷积层和两层Multihead-attention层;其中每一层Multihead-attention都是4-head;
所述生成器的输入是128维的随机噪声z,经过第一层反卷积层TConv1后,生成的特征图为4×4×512;经过第二层反卷积层Tconv2后,生成的特征图为8×8×256;经过第三层反卷积层Tconv3后,生成的特征图为16×16×128;经过第四层Multihead-attention层GAtten1之后,生成的特征图为16×16×128;经过第五层反卷积层Tconv4后,生成的特征图为32×32×64;经过第六层Multihead-attention层GAtten2之后,生成的特征图为32×32×64;经过第七层反卷积层Tconv5后,生成的特征图为64×64×3,即为生成的遥感图像;
所述判别器的输入是真实的遥感图像和生成器生成的遥感图像,大小均为64×64×3;经过第一层卷积层Conv1之后,生成的特征图为32×32×64;经过第二层卷积层Conv2之后,生成的特征图为16×16×128;经过第三层卷积层Conv3之后,生成的特征图为8×8×256;经过第四层Multihead-attention层DAtten1之后,生成的特征图为8×8×256;经过第五层卷积层Conv4之后,生成的特征图为4×4×512;经过第六层Multihead-attention层DAtten2之后,生成的特征图为4×4×512;经过第七层卷积层Conv5之后,最后输出的特征图是1×1×1,表示输入分类器中待分类的图像所属的类别。
加入Multihead-attention机制之后,1)生成器在生成某一个像素时会综合其他像素点来自不同表示子空间的信息;2)判别器在判定图像类别时会观测较远点的细节信息是否一致。也就是说,生成器和判别器既能在卷积层的作用下重点关注局部感受野之内的像素点,又能同时兼顾在二维平面和通道维度上其他较远像素点的细节特征和不同特征之间的相关性。最终得到的特征既包含来自不同latent space中侧重点不同的多个特征,包含更多的信息,又针对每个特征自动学习有利于分类的权重,对分类的增益更大。生成的图像粒度和细节会更加精细,信息更加完整,相应的质量会有所提升,对遥感图像分类有积极影响。
Multihead-attention机制是在Self-attention机制上进行改进,将Self-attention的单头改为多头,针对每一个head执行attention操作;如图2所示,步骤如下:
将输入多头注意力机制层的特征图,从通道上划分为L个head,并且在每个head内计算attention矩阵;将每个head的attention矩阵作用于该head对应的输入特征图中的像素点,得到每个head的输出结果;将所有head的输出结果进行融合,融合后的结果作为该多头注意力机制层的输出。
计算每个head的attention矩阵方法如下:
首先,Multihead-attention对从判别器中第三层卷积层Conv3、第五层卷积层Conv4或生成器中第三层反卷积层Tconv3、第五层反卷积层Tconv4中提取出的特征图
Figure BDA0002711926660000081
赋予不同的权重矩阵WQ、WK和WV,生成特征图WQxh、WKxh和WVxh;其中C是特征图的通道数,N是特征图单个通道上像素点的个数;
其次,经过非线性变换f(x)=Wfx和g(x)=Wgx将特征图WQxh、WKxh和WVxh转换到特征空间f和g中;其中x表示输入的特征图,Wf,Wg分别表示将输入特征图转换到特征空间f和g的权重矩阵;
然后,计算第l个head的attention矩阵:
Figure BDA0002711926660000082
其中,
Figure BDA0002711926660000083
表示归一化之后第l个head的attention矩阵,
Figure BDA0002711926660000084
表示在Multihead-attention机制中,生成第j个区域时第i个位置对其的关注程度;N表示特征图单个通道上像素点的个数,xhi表示特征空间f中的像素点,xhj表示特征空间g中的像素点;
Figure BDA0002711926660000091
Figure BDA0002711926660000092
均为权重矩阵,
Figure BDA0002711926660000093
C是原始输入特征图的通道数,L为多头注意力机制中head的个数,
Figure BDA0002711926660000094
是每一个head划分到的通道数。
将每个head的attention矩阵作用于输入特征图对应区域内的像素点,得到每个head的输出结果,将所有head的输出结果进行融合,方法如下:
第l个head中的像素点组成的区域内,第j个像素xhj经过
Figure BDA0002711926660000095
作用之后的结果
Figure BDA0002711926660000096
为:
Figure BDA0002711926660000097
其中,h(x)=Whx;Wh是将特征图转换到特征空间h时的权重矩阵;
第j个像素xhj经过第l个head的attention机制计算之后的输出为:
Figure BDA0002711926660000098
其中,γ(l)表示
Figure BDA0002711926660000099
的权重,
Figure BDA00027119266600000910
表示第j个像素经过第l个head中的attention矩阵计算之后的输出结果;
将所有head的输出进行融合,得到:
Output=v(concat(O(1),O(2),…,O(l),…,O(L))WO)
Figure BDA00027119266600000911
其中,O(l)是经过第l个head attention作用之后的所有像素点的集合,v(x)=WVx,
Figure BDA00027119266600000912
是权重矩阵,
Figure BDA00027119266600000913
WO是权重矩阵。
(2)获取原始遥感图像数据集,划分训练集、验证集和测试集,使用训练集交替训练判别器和生成器,直到生成器和判别器达到纳什均衡,得到最优的网络模型。
准备原始遥感图像数据集方法如下:
设原始遥感图像数据集共包含K类,每个类别包含M张图像,原始遥感图像数据集表示为T={(x11,y1),…,(x1N,y1),…,(xij,yj),…,(xK1,yK),…,(xKM,yK)},其中xij表示为第i类遥感图像中的第j张图像,yj表示图像xij所属类别的标签;
将原始遥感图像数据集按照1:1:8的比例划分为训练集、验证集和测试集;
训练集Tr表示如下:
Tr={(x11,y1),…,(x1(M/10),y1),…,
(xi1,yi),…,(xi(M/10),yi),…,
(xK1,yK),…,(xK(M/10),yK)}
其中每一类遥感图像数目为M/10;
验证集Tv表示如下:
Tv={(x1(M/10+1),y1),…,(x1(2M/10),y1),…,
(xi(M/10+1),yi),…,(xi(2M/10),yi),…,
(xK(M/10+1),yK),…,(xK(2M/10),yK)}
其中每一类遥感图像数也为M/10;
测试集Te表示如下:
Te={(x1(2M/10+1),y1),…,(x1M,y1),…,
(xi(2M/10+1),yi),…,(xiM,yi),…,
(xK(2M/10+1),yK),…,(xKM,yK)}
其中每一类遥感图像数目为8M/10。
使用准备好的原始遥感图像数据集训练生成对抗网络,方法如下:
将准备好的原始遥感图像训练集Tr输入构建的Multihead-attention GAN中,交替训练判别器D和生成器G,直到生成器和判别器达到纳什均衡;
生成器和判别器的损失函数采用对抗性损失的合页损失函数,形式为:
Figure BDA0002711926660000101
Figure BDA0002711926660000102
LM-GAN=LD+LG
其中,LD,LG分别表示判别器和生成器的损失函数,LM-GAN表示Multihead-attention GAN的损失函数;x表示输入的真实遥感图像,E(·)表示期望函数,z表示输入的随机噪声;pdata(x)表示在数据空间中真实数据的概率分布,pz(z)表示在潜在空间中的噪声z的先验概率分布;训练Multihead-attention GAN的目标即为最小化整体损失函数LM-GAN,当LM-GAN达到最小时,模型达到最优。
(3)提取步骤(2)得到的最优网络模型中生成器生成的遥感图像,与原始遥感图像训练集合并,得到新的遥感图像训练集。
设合成的每类遥感图像数目为F,则合成的遥感图像训练集表示为:
Tf={(x11,y1),…,(x1F,y1),…,(xij,yi),…,(xK1,yK),…,(xKF,yK)};
将原始遥感图像训练集Tr和合成的遥感图像训练集Tf进行合并,合并方法为:将两个数据集中对应类别的遥感图像都添加进新的遥感图像数据集的对应类别中,即新的遥感图像数据集类别和训练集Tr、Tf完全相同,同一类别下,新的遥感图像训练集的数目是Tr、Tf中对应遥感图像数目的总和,新的遥感图像训练集表示为:
Tn={(x11,y1),…,(x1(M+F),y1),…,(xij,yj),…,(xK1,yK),…,(xK(M+F),yK)}。
(4)使用新的遥感图像训练集训练深度卷积神经网络分类器,使用验证集进行验证,在测试集上实现遥感图像的分类。
如图4所示,本实施例选用ResNet18网络作为深度卷积神经网络分类器,图4中ResNet的实线表示直接连接,虚线表示内核为1且步幅为2的卷积进行尺寸更改以匹配输入和输出特征的数量。将新的遥感图像训练集Tn输入ResNet18网络中进行网络的训练。假设Tn中的图像xij输入ResNet18之后,经过第l层残差单元(Residual block)之后的输入和输出分别为xij_l和xij_l+1,则有:
xij_l+1=R(h(xij_l)+F(xij_l,Wl))
其中,F(·)表示残差函数,Wl是计算残差时卷积操作的参数矩阵,h(·)表示恒等映射,R(·)表示ReLU激活函数。那么从浅层l到深层L学习到的特征为:
Figure BDA0002711926660000111
最后,根据链式规则,求得反向传播过程中的梯度为:
Figure BDA0002711926660000112
其中,loss表示ResNet的损失函数,
Figure BDA0002711926660000113
表示损失函数到达深层L的梯度。
在网络达到最优之后,将验证集Tv和测试集Te输入网络中进行验证和测试。针对测试集中的每一张图像xij,ResNet18都会输出一个预测的类别标签yp,将测试集中每一张图像的yp和真实标签yj进行对比,即可得到分类准确率。
本实施例选用两个不同的生成对抗网络与本发明提出的Multihead-attentionGAN进行对比,选用的两个生成对抗网络分别是A.Radford等人于2015年发表的“Unsupervised representation learning with deep convolutional generativeadversarial networks”中提出的Deep convolutional generative adversarialnetworks,简称DCGAN,以及H.Zhang等人与2018年发表的“Self-Attention generativeadversarial networks”中提出的Self-attention generative adversarial networks,简称SAGAN,将本发明提出的Multihead-attention GAN简称为MHAGAN。
实验采用的数据集为:从遥感图像数据集AID数据集、RSI-CB数据集、WHU-RS19数据集、PatternNet数据集、UC Merced Land Use数据集和NWPU-RESISC45中提取的与水利有关的高分遥感图像数据集,共19类,分别为:beach、bridge、dam、Farmland、ferry_terminal、Forest、harbor、lake、marina、pond、port、river、sea、sea_ice、stream、swimming_pool、terrace、wastewater_treament_plant、wetland。图5展示了使用的19类水利遥感图像数据集。图6展示了三种方法生成的19类水利遥感图像。
表1为三种数据增强方法进行数据增强之后的水利遥感图像分类准确率对比,表2分别为三种数据增强方法进行数据增强之后的Kappa系数对比,结果表明,本发明提出的基于多头注意力生成对抗网络的遥感场景分类方法性能最好。
表1
Figure BDA0002711926660000121
表2
Figure BDA0002711926660000122
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明技术原理的前提下,还可以做出若干改进和变形,这些改进和变形也应视为本发明的保护范围。

Claims (5)

1.一种基于多头注意力生成对抗网络的遥感场景分类方法,其特征在于,该方法的具体步骤如下:
(1)将多头注意力机制引入生成对抗网络,构建Multihead-attention GAN网络,该网络包含一个嵌入多头注意力机制的生成器和一个嵌入多头注意力机制的判别器;
(2)获取原始遥感图像数据集,划分训练集、验证集和测试集,使用训练集交替训练判别器和生成器,直到生成器和判别器达到纳什均衡,得到最优的网络模型;
(3)提取步骤(2)得到的最优网络模型中生成器生成的遥感图像,与原始遥感图像训练集合并,得到新的遥感图像训练集;
(4)使用新的遥感图像训练集训练深度卷积神经网络分类器,使用验证集进行验证,在测试集上实现遥感图像的分类;
所述步骤(1)中,构建Multihead-attention GAN结构如下:
Multihead-attention GAN由嵌入了Multihead-attention机制的生成器和判别器构成;生成器G包含七层:五层反卷积层和两层Multihead-attention层;判别器D包含七层:五层卷积层和两层Multihead-attention层;
所述生成器的输入是128维的随机噪声z,经过第一层反卷积层TConv1后,生成的特征图为4×4×512;经过第二层反卷积层Tconv2后,生成的特征图为8×8×256;经过第三层反卷积层Tconv3后,生成的特征图为16×16×128;经过第四层Multihead-attention层GAtten1之后,生成的特征图为16×16×128;经过第五层反卷积层Tconv4后,生成的特征图为32×32×64;经过第六层Multihead-attention层GAtten2之后,生成的特征图为32×32×64;经过第七层反卷积层Tconv5后,生成的特征图为64×64×3,即为生成的遥感图像;
所述判别器的输入是真实的遥感图像和生成器生成的遥感图像,大小均为64×64×3;经过第一层卷积层Conv1之后,生成的特征图为32×32×64;经过第二层卷积层Conv2之后,生成的特征图为16×16×128;经过第三层卷积层Conv3之后,生成的特征图为8×8×256;经过第四层Multihead-attention层DAtten1之后,生成的特征图为8×8×256;经过第五层卷积层Conv4之后,生成的特征图为4×4×512;经过第六层Multihead-attention层DAtten2之后,生成的特征图为4×4×512;经过第七层卷积层Conv5之后,最后输出的特征图是1×1×1,表示输入分类器中待分类的图像所属的类别;
其中,Multihead-attention机制是在Self-attention机制上进行改进,将Self-attention的单头改为多头,针对每一个head执行attention操作;步骤如下:
1)将输入多头注意力机制层的特征图,从通道上划分为L个head,并且在每个head内计算attention矩阵;计算每个head的attention矩阵方法如下:
首先,Multihead-attention对从判别器中第三层卷积层Conv3、第五层卷积层Conv4或生成器中第三层反卷积层Tconv3、第五层反卷积层Tconv4中提取出的特征图
Figure FDA0003709715680000021
赋予不同的权重矩阵WQ、WK和WV,生成特征图WQxh、WKxh和WVxh;其中C是特征图的通道数,N是特征图单个通道上像素点的个数;
其次,经过非线性变换f(x)=Wfx和g(x)=Wgx将特征图WQxh、WKxh和WVxh转换到特征空间f和g中;其中x表示输入的特征图,Wf,Wg分别表示将输入特征图转换到特征空间f和g的权重矩阵;
然后,计算第l个head的attention矩阵:
Figure FDA0003709715680000022
其中,
Figure FDA0003709715680000023
表示归一化之后第l个head的attention矩阵,
Figure FDA0003709715680000024
表示在Multihead-attention机制中,生成第j个区域时第i个位置对其的关注程度;N表示特征图单个通道上像素点的个数,xhi表示特征空间f中的像素点,xhj表示特征空间g中的像素点;
Figure FDA0003709715680000025
Figure FDA0003709715680000026
均为权重矩阵,
Figure FDA0003709715680000027
C是原始输入特征图的通道数,L为多头注意力机制中head的个数,
Figure FDA0003709715680000028
是每一个head划分到的通道数;
2)将每个head的attention矩阵作用于该head对应的输入特征图中的像素点,得到每个head的输出结果;将所有head的输出结果进行融合,融合后的结果作为该多头注意力机制层的输出;方法如下:
第l个head中的像素点组成的区域内,第j个像素xhj经过
Figure FDA0003709715680000029
作用之后的结果
Figure FDA00037097156800000210
为:
Figure FDA00037097156800000211
其中,h(x)=Whx;Wh是将特征图转换到特征空间h时的权重矩阵;
第j个像素xhj经过第l个head的attention机制计算之后的输出为:
Figure FDA00037097156800000212
其中,γ(l)表示
Figure FDA00037097156800000213
的权重,
Figure FDA00037097156800000214
表示第j个像素经过第l个head中的attention矩阵计算之后的输出结果;
将所有head的输出进行融合,得到:
Output=v(concat(O(1),O(2),…,O(l),…,O(L))WO)
Figure FDA00037097156800000215
其中,O(l)是经过第l个head attention作用之后的所有像素点的集合,v(x)=WVx,
Figure FDA00037097156800000216
是权重矩阵,
Figure FDA00037097156800000217
WO是权重矩阵。
2.根据权利要求1所述的一种基于多头注意力生成对抗网络的遥感场景分类方法,其特征在于,准备原始遥感图像数据集方法如下:
设原始遥感图像数据集共包含K类,每个类别包含M张图像,原始遥感图像数据集表示为T={(x11,y1),…,(x1N,y1),…,(xij,yj),…,(xK1,yK),…,(xKM,yK)},其中xij表示为第i类遥感图像中的第j张图像,yj表示图像xij所属类别的标签;
将原始遥感图像数据集按照1:1:8的比例划分为训练集、验证集和测试集;
训练集Tr表示如下:
Tr={(x11,y1),…,(x1(M/10),y1),…,(xi1,yi),…,(xi(M/10),yi),…,(xK1,yK),…,(xK(M/10),yK)}
其中每一类遥感图像数目为M/10;
验证集Tv表示如下:
Tv={(x1(M/10+1),y1),…,(x1(2M/10),y1),…,(xi(M/10+1),yi),…,(xi(2M/10),yi),…,(xK(M/10+1),yK),…,(xK(2M/10),yK)}
其中每一类遥感图像数也为M/10;
测试集Te表示如下:
Te={(x1(2M/10+1),y1),…,(x1M,y1),…,(xi(2M/10+1),yi),…,(xiM,yi),…,(xK(2M/10+1),yK),…,(xKM,yK)}
其中每一类遥感图像数目为8M/10。
3.根据权利要求2所述的一种基于多头注意力生成对抗网络的遥感场景分类方法,其特征在于,使用准备好的原始遥感图像数据集训练生成对抗网络,方法如下:
将准备好的原始遥感图像训练集Tr输入构建的Multihead-attention GAN中,交替训练判别器D和生成器G,直到生成器和判别器达到纳什均衡;
生成器和判别器的损失函数采用对抗性损失的合页损失函数,形式为:
Figure FDA0003709715680000031
Figure FDA0003709715680000032
LM-GAN=LD+LG
其中,LD,LG分别表示判别器和生成器的损失函数,LM-GAN表示Multihead-attention GAN的损失函数;x表示输入的真实遥感图像,E(×)表示期望函数,z表示输入的随机噪声;pdata(x)表示在数据空间中真实数据的概率分布,pz(z)表示在潜在空间中的噪声z的先验概率分布;训练Multihead-attention GAN的目标即为最小化整体损失函数LM-GAN,当LM-GAN达到最小时,模型达到最优。
4.根据权利要求2或3所述的一种基于多头注意力生成对抗网络的遥感场景分类方法,其特征在于,所述步骤(3)中,形成新的遥感图像训练集,过程如下:
设合成的每类遥感图像数目为F,则合成的遥感图像训练集表示为:
Tf={(x11,y1),…,(x1F,y1),…,(xij,yi),…,(xK1,yK),…,(xKF,yK)};
将原始遥感图像训练集Tr和合成的遥感图像训练集Tf进行合并,合并方法为:将两个数据集中对应类别的遥感图像都添加进新的遥感图像数据集的对应类别中,即新的遥感图像数据集类别和训练集Tr、Tf完全相同,同一类别下,新的遥感图像训练集的数目是Tr、Tf中对应遥感图像数目的总和,新的遥感图像训练集表示为:
Tn={(x11,y1),…,(x1(M+F),y1),…,(xij,yj),…,(xK1,yK),…,(xK(M+F),yK)}。
5.根据权利要求4所述的一种基于多头注意力生成对抗网络的遥感场景分类方法,其特征在于,所述步骤(4)中,选用ResNet作为深度卷积神经网络分类器,将新的遥感图像训练集Tn输入ResNet中,进行网络的训练;在网络达到最优之后,将验证集Tv和测试集Te输入网络中进行验证和测试;针对测试集中的每一张图像xij,ResNet都输出一个预测的类别标签yp,将测试集中每一张图像的yp和真实标签yj进行对比,即得到分类准确率。
CN202011059707.1A 2020-09-30 2020-09-30 一种基于多头注意力生成对抗网络的遥感场景分类方法 Active CN112232156B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011059707.1A CN112232156B (zh) 2020-09-30 2020-09-30 一种基于多头注意力生成对抗网络的遥感场景分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011059707.1A CN112232156B (zh) 2020-09-30 2020-09-30 一种基于多头注意力生成对抗网络的遥感场景分类方法

Publications (2)

Publication Number Publication Date
CN112232156A CN112232156A (zh) 2021-01-15
CN112232156B true CN112232156B (zh) 2022-08-16

Family

ID=74119853

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011059707.1A Active CN112232156B (zh) 2020-09-30 2020-09-30 一种基于多头注意力生成对抗网络的遥感场景分类方法

Country Status (1)

Country Link
CN (1) CN112232156B (zh)

Families Citing this family (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949384B (zh) * 2021-01-23 2024-03-08 西北工业大学 一种基于对抗性特征提取的遥感图像场景分类方法
CN113095437B (zh) * 2021-04-29 2022-03-01 中国电子科技集团公司第五十四研究所 一种Himawari-8遥感数据的火点检测方法
CN113177599B (zh) * 2021-05-10 2023-11-21 南京信息工程大学 一种基于gan的强化样本生成方法
CN113239844B (zh) * 2021-05-26 2022-11-01 哈尔滨理工大学 一种基于多头注意力目标检测的智能化妆镜系统
CN113344070A (zh) * 2021-06-01 2021-09-03 南京林业大学 一种基于多头自注意力模块的遥感图像分类系统及方法
CN113450313B (zh) * 2021-06-04 2022-03-15 电子科技大学 一种基于区域对比学习的图像显著性可视化方法
CN113538615B (zh) * 2021-06-29 2024-01-09 中国海洋大学 基于双流生成器深度卷积对抗生成网络的遥感图像上色方法
CN113822895B (zh) * 2021-08-29 2024-08-02 陕西师范大学 一种基于自注意力机制和CycleGAN的ScanSAR图像扇贝效应抑制方法
CN113887136B (zh) * 2021-10-08 2024-05-14 东北大学 一种基于改进GAN和ResNet的电动汽车电机轴承故障诊断方法
CN115661002B (zh) * 2022-12-14 2023-04-21 北京数慧时空信息技术有限公司 基于gan的多时相遥感数据修复方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110287800B (zh) * 2019-05-29 2022-08-16 河海大学 一种基于sgse-gan的遥感图像场景分类方法
CN110414377B (zh) * 2019-07-09 2020-11-13 武汉科技大学 一种基于尺度注意力网络的遥感图像场景分类方法

Also Published As

Publication number Publication date
CN112232156A (zh) 2021-01-15

Similar Documents

Publication Publication Date Title
CN112232156B (zh) 一种基于多头注意力生成对抗网络的遥感场景分类方法
Guo et al. Scene-driven multitask parallel attention network for building extraction in high-resolution remote sensing images
Wang et al. Adaptive DropBlock-enhanced generative adversarial networks for hyperspectral image classification
Zhang et al. Asymmetric cross-attention hierarchical network based on CNN and transformer for bitemporal remote sensing images change detection
Xue et al. Local transformer with spatial partition restore for hyperspectral image classification
Dong et al. Abundance matrix correlation analysis network based on hierarchical multihead self-cross-hybrid attention for hyperspectral change detection
Li et al. Few-shot hyperspectral image classification with self-supervised learning
CN111242061B (zh) 一种基于注意力机制的合成孔径雷达舰船目标检测方法
Zhang et al. Efficiently utilizing complex-valued PolSAR image data via a multi-task deep learning framework
CN114612476B (zh) 一种基于全分辨率混合注意力机制的图像篡改检测方法
Zhou et al. PVT-SAR: An arbitrarily oriented SAR ship detector with pyramid vision transformer
Zhao et al. Center attention network for hyperspectral image classification
Feng et al. Embranchment cnn based local climate zone classification using sar and multispectral remote sensing data
CN109977968A (zh) 一种深度学习分类后比较的sar变化检测方法
Zhao et al. High-resolution remote sensing bitemporal image change detection based on feature interaction and multitask learning
CN111091059A (zh) 一种生活垃圾塑料瓶分类中的数据均衡方法
Jeong et al. Enriching SAR ship detection via multistage domain alignment
Yamada et al. Geoclr: Georeference contrastive learning for efficient seafloor image interpretation
Huyan et al. AUD-Net: A unified deep detector for multiple hyperspectral image anomaly detection via relation and few-shot learning
Anwer et al. Accident vehicle types classification: a comparative study between different deep learning models
Chen et al. Class-aware domain adaptation for coastal land cover mapping using optical remote sensing imagery
Zhao et al. NAS-kernel: Learning suitable Gaussian kernel for remote sensing object counting
CN109902746A (zh) 非对称的细粒度红外图像生成系统及方法
Zhu et al. Adversarial fine-grained adaptation network for cross-scene classification
CN113920481A (zh) 基于航迹特征和深度神经网络MobileNet迁移训练的船舶分类识别方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant