CN118038181A - 一种基于元迁移梯度更新策略的高光谱图像分类方法 - Google Patents
一种基于元迁移梯度更新策略的高光谱图像分类方法 Download PDFInfo
- Publication number
- CN118038181A CN118038181A CN202410331617.5A CN202410331617A CN118038181A CN 118038181 A CN118038181 A CN 118038181A CN 202410331617 A CN202410331617 A CN 202410331617A CN 118038181 A CN118038181 A CN 118038181A
- Authority
- CN
- China
- Prior art keywords
- meta
- domain
- task
- source
- source domain
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000013508 migration Methods 0.000 title claims abstract description 46
- 238000000034 method Methods 0.000 title claims abstract description 44
- 238000012549 training Methods 0.000 claims description 17
- 238000005070 sampling Methods 0.000 claims description 13
- 230000006870 function Effects 0.000 claims description 11
- 230000005012 migration Effects 0.000 claims description 5
- 238000005457 optimization Methods 0.000 claims description 5
- 230000008485 antagonism Effects 0.000 claims description 3
- 238000009826 distribution Methods 0.000 abstract description 14
- 230000000694 effects Effects 0.000 abstract description 3
- 230000001939 inductive effect Effects 0.000 abstract description 3
- 239000011159 matrix material Substances 0.000 description 8
- 230000008569 process Effects 0.000 description 4
- 238000004088 simulation Methods 0.000 description 4
- 238000012360 testing method Methods 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 3
- 238000013507 mapping Methods 0.000 description 3
- 102000002274 Matrix Metalloproteinases Human genes 0.000 description 2
- 108010000684 Matrix Metalloproteinases Proteins 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 2
- 238000013145 classification model Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 238000012706 support-vector machine Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000006978 adaptation Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000000701 chemical imaging Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000007613 environmental effect Effects 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 229910052500 inorganic mineral Inorganic materials 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 239000011707 mineral Substances 0.000 description 1
- 238000003909 pattern recognition Methods 0.000 description 1
- 230000003595 spectral effect Effects 0.000 description 1
- 238000001228 spectrum Methods 0.000 description 1
- 238000003786 synthesis reaction Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/75—Organisation of the matching processes, e.g. simultaneous or sequential comparisons of image or video features; Coarse-fine approaches, e.g. multi-scale approaches; using context analysis; Selection of dictionaries
- G06V10/757—Matching configurations of points or features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/762—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
- G06V10/763—Non-hierarchical techniques, e.g. based on statistics of modelling distributions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/10—Terrestrial scenes
- G06V20/194—Terrestrial scenes using hyperspectral data, i.e. more or other wavelengths than RGB
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A40/00—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production
- Y02A40/10—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production in agriculture
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Remote Sensing (AREA)
- Spectroscopy & Molecular Physics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于元迁移梯度更新策略的高光谱图像分类方法。通过任务分布对齐策略构建了平衡的元任务簇,以解决不同领域之间由于类别关系差异而引起的任务分布错位问题。接着,利用领域投影头捕获与域相关的特定知识,确保共享特征嵌入模块能够专注于捕获两个域之间共享的域不变知识。最后,采用元迁移梯度更新策略来更新模型,聚焦于让模型从元迁移任务集合中归纳出适用于各类型元迁移任务的无偏知识,从而提升模型的泛化性能并优化元迁移学习效果。
Description
技术领域
本发明涉及模式识别技术领域,主要涉及一种基于元迁移梯度更新策略的高光谱图像分类方法。
背景技术
高光谱成像是遥感领域的一项重要技术,用于收集从可见光到近红外波段的电磁频谱数据。高光谱图像中的每个像素都包含了数百个波段的光谱信息。这些图像通常用于分类任务,如在农业、林业和矿产勘探等领域。
传统的高光谱图像分类方法需要手工设计特征和分类器,限制了其在复杂场景下的应用。近年来,深度学习技术在高光谱图像分类方面取得了进展,它可以自动学习特征表示,减少了手工特征提取的需求。然而,深度学习方法通常需要大量标记样本,而在高光谱图像分类中,标记样本数量通常很有限。
为了应对小样本问题,跨域小样本学习成为一种解决方法。这种方法的目标是在源域学习知识,然后将这些知识迁移到目标域,以实现准确分类。跨域小样本学习需要解决领域差异和类别差异问题。领域差异问题可以通过迁移学习来解决,因为源域和目标域的数据通常具有不同的分布。迁移学习可以帮助模型适应目标域的数据分布。一些方法利用物理特征将数据投影到共享的空间中,以弥合领域之间的差距。类别差异问题可以利用元学习来解决,因为源域和目标域的样本类别通常不一致。元学习可以帮助模型从有限的标记样本中学习泛化特征表示。
为了同时应对跨域和类别差异,可以采用元迁移学习方法,它结合了迁移学习和元学习的优势,以增强模型的泛化性能和适应性。尽管元迁移学习方法在性能提升方面已有显著进展,但它们主要通过对齐源域和目标域的边缘分布来促进知识迁移,而忽略了两个领域之间的异质性可能导致跨域知识干扰的问题。这些问题包括设备、光照和环境差异,类别差异以及数据分布不完全对齐。
发明内容
发明目的:针对上述背景技术中存在的问题,本发明提供了一种基于元迁移梯度更新策略的高光谱图像分类方法,从多个视角解决了元迁移学习中的跨域知识干扰问题,为元迁移学习提供了新范式。
本发明的一种基于元迁移梯度更新策略的高光谱图像分类方法,包括如下步骤:
步骤1,从源域和目标域训练数据中随机采样一组元任务,输入元任务嵌入头以获得元任务嵌入集合T,基于所述元任务嵌入集合T构建平衡元任务簇;所述源域和目标域训练数据的数据维度相同;
步骤2,从平衡元任务簇中采样源域-目标域元任务对,将源域和目标域的支持样本和查询样本输入特征嵌入模块,获得源域的支持特征和查询特征,以及目标域的支持特征和查询特征;
然后,将两域的支持特征和查询特征输入领域投影头中,在领域投影头中,首先通过领域自适应得到两域的领域不变特征,即源域领域不变特征和目标域领域不变特征。之后,将源域领域不变特征输入领域特定投影模块,获得源域的域特定特征。
步骤3,首先,基于源域的域特定特征通过原型匹配对源域查询样本进行类别预测,并通过计算源域小样本分类损失对模型进行伪更新;
所述模型包括元任务嵌入头、特征嵌入模块和领域投影头,所述领域投影头中还包括领域特定投影模块;
然后,基于目标域的域不变特征通过原型匹配对目标域查询样本进行类别预测,计算目标域小样本分类损失。
步骤4,将源域的小样本分类损失得到的源更新梯度和目标域的小样本分类损失得到的目标更新梯度相结合,获得元任务对的元梯度,通过元迁移梯度更新策略进行模型更新。
进一步的,步骤1中从源域和目标域训练数据中随机采样一组元任务,输入元任务嵌入头以获得元任务嵌入集合T,基于所述元任务嵌入集合T构建平衡元任务簇,具体包括如下步骤:
步骤1.1,分别从两域的训练数据中随机采样一组元任务,所述一组元任务包括源域元任务和目标域元任务,将所述源域元任务和目标域元任务均输入元任务嵌入头以获得元任务嵌入集合T,表示如下:
其中,和/>分别表示源域和目标域的元任务嵌入,N为目标域和源域元任务的数量。
步骤1.2,基于元任务嵌入集合T,通过任务聚类将两域元任务划分为K个任务簇,第k个任务簇Tk表示为:
其中,mk和dk分别表示第k个任务簇中源域元任务嵌入和目标域元任务嵌入的数量。采用K-means实现任务聚类。
步骤1.3,在每一个任务簇中随机采样一个源域元任务和一个目标域元任务,构建平衡元任务簇。
所述元任务嵌入头是基于Transformer的元任务嵌入头,包括均值函数和条件编码模块;将源域元任务和目标域元任务均输入元任务嵌入头以获得元任务嵌入集合T,具体步骤如下:
首先,基于两域元任务中的各类样本集合,通过均值函数计算类锚点以捕获类相关的统计和结构信息。
之后,将类锚点输入条件编码模块,为元任务中的各个样本生成类相关的位置编码,结构化的类知识被整合到元任务嵌入的学习中。
最后,为元任务中的每个样本注入位置编码,并设置一个可学习的task token,用以捕获各样本之间的依赖关系,进而获得元任务嵌入,源域和目标域的元任务嵌入共同组成了元任务嵌入集合T。
进一步的,步骤2中,所述领域投影头包括领域判别器;在领域投影头中,首先通过领域自适应得到两域的领域不变特征,具体是通过在领域投影头中,特征嵌入模块和领域判别器之间的对抗学习最小化域不变损失而得到两域的领域不变特征,领域适应的优化目标表示为:
其中,Ds和Dt分别表示源域和目标域,zs和zt表示源域和目标域样本的特征,E表示期望,D表示领域判别器,Z表示特征嵌入模块。
进一步的,步骤2中所述将源域领域不变特征输入领域特定投影模块,获得源域的域特定特征,表示为:
ψ(x)=Γ(Z(xs)|x∈SsΓQs)
其中,Γ(·)表示域对域函数,xs表示源域HSI,Ss和Qs分别表示源域的支持集和查询集。
进一步的,采用Transformer拟合域对域函数Γ(·),具体而言:为了捕获源域元任务的域特定知识,首先建立了样本间的关联矩阵。通过建立关联矩阵,能够识别并整合这些共性特征,进一步揭示样本间共享的域特定知识。通过投影矩阵分别对源域样本xs和另一个样本进行映射,再通过softmax计算两个样本之间的关联矩阵αi。最后,基于关联矩阵可以得到源域的域特定特征/>通过这种方式,域特定特征提取器ψ(·)整合了源域元任务的域特定知识,确保两个域共享的Z能够专注于捕获两域共享的域不变知识。
进一步的,所述步骤3中所述对源域查询样本进行类别预测与对目标域查询样本进行类别预测的方法相同,所述计算源域小样本分类损失与计算目标域小样本分类损失的方法相同,均按照如下步骤进行:
源域和目标域的每个元任务都包含了C个类别的J个支持样本和F个查询样本,将支持样本和查询样本映射到共同的特征空间,得到支持特征和查询特征,并利用支持特征计算类别原型,接着,通过度量各查询样本和类别原型的在特征空间的距离,获得查询样本的类别概率。
查询样本属于第i类的概率表示为:
其中,为/>的查询样本的类别标签和特征,d(·)为欧式距离度量,Q表示查询集,类别原型mc的计算如下:
其中,表示第n个支持集样本,J表示支持集样本数;
最后,利用小样本分类损失对模型进行优化,小样本分类损失Lfsl表示为:
其中,C、F分别表示类别数和查询样本数,d(·)为欧式距离度量。
进一步的,步骤3中所述计算源域小样本分类损失对模型进行伪更新是指,在元迁移学习中,进行源域梯度更新时,将其梯度保存下来,只进行伪更新,不作用于模型上;
进一步的,步骤4中,所述获得元任务对的元梯度表示为
其中,Si为源域元任务,Ti为目标域元任务,θ为模型参数;使用元梯度对模型进行更新。
进一步的,步骤4中所述元迁移梯度更新策略是指:
首先,遍历平衡元任务簇中所有的源域-目标域元任务对按批次进行训练,一个批次中采样多个元任务对;在每个元任务对均得到对应的元梯度后,求取平均元梯度元梯度Gradall,利用Gradall对模型进行优化:
θ′=θ-ρGradall
其中,ρ为集合元梯度步长。
有益效果:本发明的一种基于元迁移梯度更新策略的高光谱图像分类方法。通过任务分布对齐策略构建了平衡的元任务簇,以解决不同领域之间由于类别关系差异而引起的任务分布错位问题。接着,利用领域投影头捕获与域相关的特定知识,确保共享特征嵌入模块能够专注于捕获两个域之间共享的域不变知识。最后,采用元迁移梯度更新策略来更新模型,聚焦于让模型从元迁移任务集合中归纳出适用于各类型元迁移任务的无偏知识,从而提升模型的泛化性能并优化元迁移学习效果。
附图说明
图1是本发明提供的基于元迁移梯度更新策略的高光谱图像分类方法原理框图。
具体实施方式
下面结合附图对本发明作更进一步的说明。显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明提供的基于元迁移梯度更新策略的高光谱图像分类方法,并基于所述方法构建了高光谱图像分类模型,模型包括元任务嵌入头、特征嵌入模块和领域投影头,所述领域投影头中还包括领域判别器和领域特定投影模块;具体原理如图1所示,本发明的基于元迁移梯度更新策略的高光谱图像分类方法,包括如下步骤:
步骤1、利用主成分分析统一源域和目标域数据的数据维度,获得具有相同维度的两域训练数据。然后,通过任务分布对齐策略获取平衡元任务簇。
具体来说,首先,分别从两域的训练集中随机采样一组元任务输入元任务嵌入头以获得元任务嵌入集合。之后,基于元任务嵌入集合通过任务聚类将元任务划分为若干个任务簇。最后,在每个任务簇中进行平衡元任务采样,获得平衡元任务簇。
步骤S1中利用任务分布对齐策略获取平衡元任务簇,具体包括:
步骤S1.1、设计了一个基于Transformer的元任务嵌入头,将元任务视作token序列,利用Transformer强大的数据关联建模能力,捕获两域元任务中各样本间的依赖关系。
首先,基于元任务中的各类样本集合,通过均值函数计算类锚点以捕获类相关的统计和结构信息。之后,将类锚点输入条件编码模块,为元任务中的各个样本生成类相关的位置编码。在基于类锚点的引导下,不同类别的样本被赋予具有区分性位置编码,通过这种方式,结构化的类知识被整合到元任务嵌入的学习中。最后,为元任务中的每个样本注入位置编码,并设置一个可学习的task token,用以捕获各样本之间的依赖关系,进而获得元任务嵌入。需要指出的是,源域和目标域的元任务被输入元任务嵌入头,得到元任务嵌入集合:
其中,和/>分别表示源域和目标域的元任务嵌入,N为目标域和源域元任务的数量。
步骤S1.2、在获得两域的元任务嵌入后,利用任务聚类将元任务嵌入集合划分为K个任务簇,第k个任务簇Tk可表示为:
其中,mk和dk分别表示第k个任务簇中源域元任务嵌入和目标域元任务嵌入的数量。在发明中,采用K-means实现任务聚类。最后,在每一个任务簇中随机采样一个源域元任务和一个目标域元任务,构建平衡元任务簇。
步骤S2、从平衡元任务簇中采样源域-目标域元任务对,将源域和目标域的支持样本和查询样本输入特征嵌入模块,获得两域分别的支持特征和查询特征。然后,将两域的支持特征和查询特征输入领域投影头中。在领域投影头中,首先通过领域自适应得到两域的领域不变特征,即源域领域不变特征和目标域领域不变特征。之后,将源域领域不变特征输入领域特定投影模块,获得域特定特征。目标域领域不变特征则不进入领域特定投影模块;具体过程如下:
首先,输入通过特征嵌入模块得到的源域和目标域的支持特征和查询特征,进行领域自适应,具体来说,通过特征嵌入模块和领域判别器之间的对抗学习最小化域不变损失,得到源域的领域不变特征和目标域的领域不变特征。这一过程旨在缓解源域和目标域之间的数据分布差异,从而解决在知识迁移过程中可能遇到的困难。领域适应的优化目标可以表示为:
其中,Ds和Dt分别表示源域和目标域,zs和zt表示源域和目标域样本的特征,E表示期望,D表示领域判别器,Z表示特征嵌入模块。
在对目标域训练时,将上述提取出来的目标域领域不变特征直接输出到后续模型中。在对源域训练时,将源域领域不变特征作为输入,通过学习一个域对域函数Γ(·),将特征嵌入模块Z转化成域特定特征提取器ψ(·)。通过这种方式,可以捕获与该域相关的特定知识,进而确保两域共享的特征嵌入模块Z能够更好地捕获两域共享域不变知识。
ψ(x)=Γ(Z(xs)|x∈SsΓQs)
其中,xs表示源域HSI,Ss和Qs分别表示源域的支持集和查询集。
本发明使用Transformer拟合域对域函数Γ(·)。具体而言,为了捕获源域元任务的域特定知识,首先建立了样本间的关联矩阵。通过建立关联矩阵,能够识别并整合这些共性特征,进一步揭示样本间共享的域特定知识。通过投影矩阵分别对源域样本xs和另一个样本进行映射,再通过softmax计算两个样本之间的关联矩阵αi。最后,基于关联矩阵可以得到源域的域特定特征/>通过这种方式,域特定特征提取器ψ(·)整合了源域元任务的域特定知识,确保两个域共享的Z能够专注于捕获两域共享的域不变知识。
步骤S3、首先,基于源域的域特定特征通过原型匹配对查询样本进行类别预测,并通过计算源域小样本分类损失对模型进行伪更新。然后,基于目标域的域不变特征通过原型匹配对查询样本进行类别预测,计算目标域小样本分类损失;
其中,类别预测与计算小样本分类损失具体步骤如下:
源域和目标域的每个元任务都包含了C个类别的J个支持样本和F个查询样本,两域的类别预测及分类损失均按照如下步骤进行:
将支持样本和查询样本映射到共同的特征空间,得到支持特征和查询特征,并利用支持特征计算类别原型,接着,通过度量各查询样本和类别原型的在特征空间的距离,获得查询样本的的类别概率。查询样本属于第i类的概率可以表示为:
其中,为/>的查询样本的类别标签和特征,d(·)为欧式距离度量,Q表示查询集,类别原型mc的计算如下:
其中,表示第n个支持集样本,J表示支持集样本数
最后,利用小样本分类损失对模型进行优化,小样本损失Lfsl可表示为:
其中,C、F分别表示类别数和查询样本数,d(·)为欧式距离度量。
步骤S4、通过元迁移梯度更新策略进行模型更新。具体地,将源域的分类损失得到的源更新梯度和目标域的分类损失得到的目标更新梯度相结合,获得该组元任务对的元梯度。训练时遍历平衡元任务簇中所有的源域-目标域元任务对按批次进行训练,一个批次中采样多个元任务对。在每个元任务对均得到对应的元梯度后,求取平均元梯度,再利用平均元梯度对模型进行更新。通过这种方式,使得模型训练时的更新不依赖于单一任务,而是试图在所有任务上都取得良好的表现;
本具体实施方式中,元迁移梯度更新策略包括如下过程:
在元迁移学习中,在进行源域梯度更新时,将其梯度保存下来,只进行伪更新,不作用于模型上,在进行目标域梯度更新时,将源域的梯度和目标域的梯度合成为元梯度,使用元梯度对模型进行更新,同时将元梯度保存下来。假设模型的参数为θ,所以元梯度可以表示为
其中,Si为源域元任务,Ti为目标域元任务,通过将两个域的更新梯度相结合,使得梯度的更新方向不仅适用于源域,而且考虑到了目标域,使两个域的更新达到一个均衡,从而缓解了一定的知识干扰的问题。
同时,构建了包含平衡元任务簇中所有源域-目标域任务对的元迁移任务集合,并聚焦于让模型从元迁移任务集合中归纳出适用于各类型元迁移任务的无偏知识。具体而言,首先,遍历元迁移任务集合,并在每个元迁移任务中利用元梯度优化对模型进行伪更新。之后,计算伪更新后的模型在整个元任务簇的整体损失,并通过对所有元梯度取平均获得集合元梯度Gradall,接着利用Gradall对模型进行优化:
θ′=θ-rGradall
其中,r为集合元梯度步长。值得注意的是,元梯度直接作用于伪更新前的模型参数,这种更新机制实际上是鼓励模型沿着具有高优化潜力的路径进行学习,这不仅需要模型能够适应当前元迁移任务,还需要其从任务中归纳出可在各种类型元迁移任务中取得均衡表现的无偏知识。
下面结合仿真试验对本发明的效果做进一步的说明:
1.仿真试验条件:
本发明的仿真实验采用的硬件测试平台是:3.80GHz Intel Core i7-10700KFCPU、32GB RAM和RTX 2080Ti GPU,并使用PyTorch作为实验环境。
选择了两个经典的高光谱图像数据集进行实验,分别是Chikusei(CH)、IndianPines(IP),将Chikusei数据集作为源域,将IP数据集作为目标域,进行了实验。
实验使用了三个评价指标,包括总体准确性(OA)、各类别的准确性(AA)、Kappa系数。OA是模型在所有样本中正确分类的比例,是一个总体性能的度量;AA衡量了模型对每个类别的分类准确性,给出了每个类别单独的性能评估;Kappa系数是一个衡量分类模型与随机分类之间一致性的指标,考虑了分类的随机性。在高光谱图像分类任务中,三个指标值越高,分类性能越好。
2.仿真实验及结果分析:
本发明与支持向量机(SVM)、深度小样本学习(DFSL)和深度跨域小样本学习(DCFSL)三种现有的技术在IP高光谱数据集上的整体分类精度OA,平均分类精度AA和Kappa系数的对比如表1所示。
表1本发明与现有技术在分类精度上的对比结果
从表1分析可得,本发明的分类结果在OA,AA和Kappa系数上的试验结果均优于三种现有技术。相比之下,本发明利用元学习的思想,通过在源域中学习任务的特征表示,能够更好地适应目标域中的任务,从而在小样本情况下获得更好的分类性能。通过引入领域投影头和任务分布对齐策略,能够更好地缓解跨域知识干扰,提高模型的稳定性和泛化性能,从而在元学习任务中具有优势。在元迁移梯度更新策略中采用跨域梯度合成和平均梯度更新,使模型能够更好地适应目标域任务的特点,同时提高对未见任务的泛化能力,从而在元迁移学习中取得更好的性能。
综上所述,本发明通过引入领域投影头,任务分布对齐策略和元迁移梯度更新策略,综合处理了两域之间的异质性、任务差异和数据分布不一致等问题,为解决小样本跨域知识干扰问题提供了一定的参考价值。
Claims (10)
1.一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,包括如下步骤:
步骤1,从源域和目标域训练数据中随机采样一组元任务,输入元任务嵌入头以获得元任务嵌入集合T,基于所述元任务嵌入集合T构建平衡元任务簇;
步骤2,从平衡元任务簇中采样源域-目标域元任务对,将源域和目标域的支持样本和查询样本输入特征嵌入模块,获得源域的支持特征和查询特征,以及目标域的支持特征和查询特征;
然后,将两域的支持特征和查询特征输入领域投影头中,在领域投影头中,首先通过领域自适应得到两域的领域不变特征,即源域领域不变特征和目标域领域不变特征;之后,将源域领域不变特征输入领域特定投影模块,获得源域的域特定特征;
步骤3,首先,基于源域的域特定特征通过原型匹配对源域查询样本进行类别预测,并通过计算源域小样本分类损失对模型进行伪更新;
然后,基于目标域的域不变特征通过原型匹配对目标域查询样本进行类别预测,计算目标域小样本分类损失;
步骤4,将源域的小样本分类损失得到的源更新梯度和目标域的小样本分类损失得到的目标更新梯度相结合,获得元任务对的元梯度,通过元迁移梯度更新策略进行模型更新。
2.根据权利要求1所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,步骤1中从源域和目标域训练数据中随机采样一组元任务,输入元任务嵌入头以获得元任务嵌入集合T,基于所述元任务嵌入集合T构建平衡元任务簇,具体包括如下步骤:
步骤1.1,分别从两域的训练数据中随机采样一组元任务,所述一组元任务包括源域元任务和目标域元任务,将所述源域元任务和目标域元任务均输入元任务嵌入头以获得元任务嵌入集合T,表示如下:
其中,和/>分别表示源域和目标域的元任务嵌入,N为目标域和源域元任务的数量;
步骤1.2,基于元任务嵌入集合T,通过任务聚类将两域元任务划分为K个任务簇,第k个任务簇Tk表示为:
其中,mk和dk分别表示第k个任务簇中源域元任务嵌入和目标域元任务嵌入的数量;采用K-means实现任务聚类;
步骤1.3,在每一个任务簇中随机采样一个源域元任务和一个目标域元任务,构建平衡元任务簇。
3.根据权利要求2所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,所述元任务嵌入头是基于Transformer的元任务嵌入头,包括均值函数和条件编码模块;步骤1.1中所述将源域元任务和目标域元任务均输入元任务嵌入头以获得元任务嵌入集合T,具体步骤如下具体步骤如下:
首先,基于两域元任务中的各类样本集合,通过均值函数计算类锚点以捕获类相关的统计和结构信息;
之后,将类锚点输入条件编码模块,为元任务中的各个样本生成类相关的位置编码,结构化的类知识被整合到元任务嵌入的学习中;
最后,为元任务中的每个样本注入位置编码,并设置一个可学习的task token,用以捕获各样本之间的依赖关系,进而获得元任务嵌入,源域和目标域的元任务嵌入共同组成了元任务嵌入集合T。
4.根据权利要求1所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,步骤2中,所述领域投影头包括领域判别器;
在领域投影头中,首先通过领域自适应得到两域的领域不变特征,具体是通过在领域投影头中,特征嵌入模块和领域判别器之间的对抗学习最小化域不变损失而得到两域的领域不变特征,领域适应的优化目标表示为:
其中,Ds和Dt分别表示源域和目标域,zs和zt表示源域和目标域样本的特征,E表示期望,D表示领域判别器,Z表示特征嵌入模块。
5.根据权利要求4所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,步骤2中所述将源域领域不变特征输入领域特定投影模块,获得源域的域特定特征,表示为:
ψ(x)=Γ(Z(xs)|x∈Ss∪Qs)
其中,Γ(·)表示域对域函数,xs表示源域HSI,Ss和Qs分别表示源域的支持集和查询集。
6.根据权利要求5所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,采用Transformer拟合域对域函数Γ(·)。
7.根据权利要求1所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,所述步骤3中所述对源域查询样本进行类别预测与对目标域查询样本进行类别预测的方法相同,所述计算源域小样本分类损失与计算目标域小样本分类损失的方法相同,均按照如下步骤进行:
源域和目标域的每个元任务都包含了C个类别的J个支持样本和F个查询样本,将支持样本和查询样本映射到共同的特征空间,得到支持特征和查询特征,并利用支持特征计算类别原型,接着,通过度量各查询样本和类别原型的在特征空间的距离,获得查询样本的类别概率;
查询样本属于第i类的概率表示为:
其中,为/>的查询样本的类别标签和特征,d(·)为欧式距离度量,Q表示查询集,类别原型mc的计算如下:
其中,表示第n个支持集样本,J表示支持集样本数;
最后,利用小样本分类损失对模型进行优化,小样本分类损失Lfsl表示为:
其中,C、F分别表示类别数和查询样本数,d(·)为欧式距离度量。
8.根据权利要求1所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,步骤3中所述计算源域小样本分类损失对模型进行伪更新是指,在元迁移学习中,进行源域梯度更新时,将其梯度保存下来,只进行伪更新,不作用于模型上。
9.根据权利要求1所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,步骤4中,所述获得元任务对的元梯度表示为
其中,Si为源域元任务,Ti为目标域元任务,θ为模型参数;使用元梯度对模型进行更新。
10.根据权利要求9所述一种基于元迁移梯度更新策略的高光谱图像分类方法,其特征在于,所述元迁移梯度更新策略是指:
首先,遍历平衡元任务簇中所有的源域-目标域元任务对按批次进行训练,一个批次中采样多个元任务对;在每个元任务对均得到对应的元梯度后,求取平均元梯度元梯度Gradall,利用Gradall对模型进行优化:
θ′=θ-ρGradall
其中,ρ为集合元梯度步长。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410331617.5A CN118038181B (zh) | 2024-03-22 | 2024-03-22 | 一种基于元迁移梯度更新策略的高光谱图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410331617.5A CN118038181B (zh) | 2024-03-22 | 2024-03-22 | 一种基于元迁移梯度更新策略的高光谱图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN118038181A true CN118038181A (zh) | 2024-05-14 |
CN118038181B CN118038181B (zh) | 2024-08-13 |
Family
ID=91002388
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410331617.5A Active CN118038181B (zh) | 2024-03-22 | 2024-03-22 | 一种基于元迁移梯度更新策略的高光谱图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN118038181B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112767368A (zh) * | 2021-01-25 | 2021-05-07 | 京科互联科技(山东)有限公司 | 一种基于域自适应的跨域医疗图像分类系统及分类方法 |
WO2022113534A1 (ja) * | 2020-11-27 | 2022-06-02 | 株式会社Jvcケンウッド | 機械学習装置、機械学習方法、および学習済みモデル |
CN115375951A (zh) * | 2022-09-20 | 2022-11-22 | 中国矿业大学 | 一种基于图元迁移网络的小样本高光谱图像分类方法 |
US20230140828A1 (en) * | 2021-10-28 | 2023-05-04 | Mckinsey & Company, Inc. | Machine Learning Methods And Systems For Cataloging And Making Recommendations Based On Domain-Specific Knowledge |
WO2023087558A1 (zh) * | 2021-11-22 | 2023-05-25 | 重庆邮电大学 | 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法 |
-
2024
- 2024-03-22 CN CN202410331617.5A patent/CN118038181B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2022113534A1 (ja) * | 2020-11-27 | 2022-06-02 | 株式会社Jvcケンウッド | 機械学習装置、機械学習方法、および学習済みモデル |
CN112767368A (zh) * | 2021-01-25 | 2021-05-07 | 京科互联科技(山东)有限公司 | 一种基于域自适应的跨域医疗图像分类系统及分类方法 |
US20230140828A1 (en) * | 2021-10-28 | 2023-05-04 | Mckinsey & Company, Inc. | Machine Learning Methods And Systems For Cataloging And Making Recommendations Based On Domain-Specific Knowledge |
WO2023087558A1 (zh) * | 2021-11-22 | 2023-05-25 | 重庆邮电大学 | 基于嵌入平滑图神经网络的小样本遥感图像场景分类方法 |
CN115375951A (zh) * | 2022-09-20 | 2022-11-22 | 中国矿业大学 | 一种基于图元迁移网络的小样本高光谱图像分类方法 |
Non-Patent Citations (3)
Title |
---|
HAOYU WANG ETAL.: "Graph Meta Transfer Network for Heterogeneous Few-Shot Hyperspectral Image Classification", 《IEEE TRANSACTIONS ON GEOSCIENCE AND REMOTE SENSING》, vol. 61, 31 December 2023 (2023-12-31), pages 1 - 12 * |
杜彦东 等: "元迁移学习在少样本跨域图像分类中的研究", 《中国图象图形学报》, vol. 28, no. 09, 31 December 2023 (2023-12-31), pages 2899 - 2912 * |
王浩宇 等: "关联子域对齐网络的跨域高光谱图像分类", 《中国图象图形学报》, vol. 28, no. 10, 31 December 2023 (2023-12-31), pages 3255 - 3266 * |
Also Published As
Publication number | Publication date |
---|---|
CN118038181B (zh) | 2024-08-13 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Du et al. | Spatial and spectral unmixing using the beta compositional model | |
Qin et al. | Cross-domain collaborative learning via cluster canonical correlation analysis and random walker for hyperspectral image classification | |
CN111461067B (zh) | 基于先验知识映射及修正的零样本遥感影像场景识别方法 | |
CN108446613A (zh) | 一种基于距离中心化与投影向量学习的行人重识别方法 | |
CN104573699A (zh) | 基于中等场强磁共振解剖成像的实蝇识别方法 | |
CN110163274B (zh) | 一种基于鬼成像和线性判别分析的物体分类方法 | |
Djorgovski et al. | The Palomar digital sky survey (DPOSS) | |
Moorthy et al. | Pattern similarity measures applied to mass spectra | |
Pino et al. | Semantic segmentation of radio-astronomical images | |
Han et al. | Improving sar automatic target recognition via trusted knowledge distillation from simulated data | |
CN118038181B (zh) | 一种基于元迁移梯度更新策略的高光谱图像分类方法 | |
CN116863327B (zh) | 一种基于双域分类器协同对抗的跨域小样本分类方法 | |
CN105869161A (zh) | 基于图像质量评价的高光谱图像波段选择方法 | |
Shi et al. | Proxylesskd: Direct knowledge distillation with inherited classifier for face recognition | |
CN117034017A (zh) | 一种基于深度学习的质谱图分类方法、系统、介质及设备 | |
Xin et al. | Hyperspectral Image Few-Shot Classification Network With Brownian Distance Covariance | |
CN111896609A (zh) | 一种基于人工智能分析质谱数据的方法 | |
CN115392375A (zh) | 一种多源数据融合度智能评估方法及其系统 | |
CN114863291B (zh) | 基于mcl和光谱差异度量的高光谱影像波段选择方法 | |
CN112330622B (zh) | 一种基于地物最大区分度的高光谱图像波段选择方法 | |
Webb-Robertson et al. | A Bayesian integration model of high-throughput proteomics and metabolomics data for improved early detection of microbial infections | |
CN104573746A (zh) | 基于磁共振成像的实蝇种类识别方法 | |
Zeng et al. | Data-Scarce Animal Face Alignment via Bi-Directional Cross-Species Knowledge Transfer | |
Sui et al. | An unsupervised band selection method based on overall accuracy prediction | |
CN113191259B (zh) | 用于高光谱图像分类的动态数据扩张方法及图像分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |