CN112734049A - 一种基于域自适应的多初始值元学习框架及方法 - Google Patents
一种基于域自适应的多初始值元学习框架及方法 Download PDFInfo
- Publication number
- CN112734049A CN112734049A CN202110210507.XA CN202110210507A CN112734049A CN 112734049 A CN112734049 A CN 112734049A CN 202110210507 A CN202110210507 A CN 202110210507A CN 112734049 A CN112734049 A CN 112734049A
- Authority
- CN
- China
- Prior art keywords
- domain
- meta
- network
- modulation
- cross
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Digital Transmission Methods That Use Modulated Carrier Waves (AREA)
Abstract
本发明提供一种基于域自适应的多初始值元学习框架及方法,框架包括跨域编码器,将输入数据通过共有编码器编码为共有特征向量,通过私有编码器编码为私有特征向量;跨域调制网络,将共有特征向量编码为域公用调制向量,私有特征向量编码为域专用调制向量;元分离网络,用于在源域和目标域中更新元学习器,其中元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分,该学习框架及方法可在一定程度上提高算法在少样本问题中的准确率,并广泛适用于跨域数据的元学习中。
Description
技术领域
本发明涉及元学习技术领域,具体涉及一种基于域自适应的多初始值元学习框架及方法。
背景技术
人工智能在各种技术领域中都有着广泛的应用,其存在的基本问题是其无法像人类一样高效地学习,需要不断地用训练样本对其进行训练学习,训练样本越完善越多,则训练得到的人工智能模型的结果就越好。然而在实际过程中经常会出现训练样本数量不足的问题,因此如何进行有效的少样本学习,已成为人工智能学习领域的一个备受关注的问题。
元学习是解决少样本学习的一种有效方法,元学习也可被理解为“学习如何学习”,现有的元学习方法包括基于度量学习的方法、基于元优化的方法、基于循环模型的方法,但是这些元学习方法的损失函数仅与特定任务有关,而没有域无关或者域自适应的约束,因此,这些方法在单域任务上表现良好,而在跨域数据上都存在着泛化性能不足的缺陷。
具体而言,元测试阶段和元训练阶段中不同类别产生的分布不同,导致了元学习方法存在领域转换的问题,尽管多初始值技术在识别任务模式方面取得了成功,但依旧无法解决由不同类别分布产生的领域转移导致的其在跨域领域存在泛化不足的缺陷。也就是说,现有的领域适应方法只能使元学习方法适应单模态元测试领域,而不能适应多模态元测试领域,到目前为止,如何缓解多模式设置中元训练和元测试阶段之间的领域转换仍然是一个挑战。
总计而言,目前的元学习方法无法很好地适用于跨域数据的学习,也就限制了其在集合跨域数据的应用场景的应用。
发明内容
本发明的目的在于提供一种基于域自适应的多初始值元学习框架及方法,可广泛适用于跨域数据的元学习,且在一定程度上提高算法在少样本问题中的准确率。
为实现上述目的,本技术方案提供一种基于域自适应的多初始值元学习框架及方法,该基于域自适应的多初始化元学习框架包括:
跨域编码器,将输入数据通过共有编码器编码为共有特征向量,通过私有编码器编码为私有特征向量;
跨域调制网络,将共有特征向量编码为域公用调制向量,私有特征向量编码为域专用调制向量;
元分离网络,用于在源域和目标域中更新元学习器,其中元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分。
其中跨域编码器的损失函数的计算公式为:
Le=Lr+Ld+Ls,
其中Lr为重构误差损失函数,通过缩小重构误差,使得编码器提取的信息尽可能保留编码前的信息,Ld为跨域差异损失函数,起到让编码器输出的共有特征向量和私有特征向量的差异增大的作用,Ls为跨域相似损失函数,起到让编码器输出的跨域共有特征向量更接近的作用,本方案训练的目标是最小化跨域编码器参数的总损失。
重构误差损失函数Lr的具体公式如下:
跨域差异损失函数Ld的具体公式如下:
跨域相似损失函数Ls的具体公式如下:
其中跨域调制网络将共有特征向量编码为域公用调制向量的公式为:
其中跨域调制网络将私有特征向量编码为域专用调制向量的公式为:
其中元分离网络的更新参数公共部分的公式如下:
元分离网络的更新参数私有部分的公式如下:
具体的,由于神经网络的低层和高层显示出不同类型的信息。低层网络往往具有良好的迁移性,其不针对特定任务,而具有针对不同任务的通用性,而网络从低到高的过程中,其特征也从一般性过渡到特定性。本发明基于这一重要现象,设计跨域编码器、跨域调制网络、元分离网络三个部分。其中元分离网络的参数公共部分为低层网络参数,不同任务的低层网络跨域共享以进行元学习和联合训练,并由域公用调制向量调制;元分离网络的参数私有部分为高层网络参数,特定于单个任务,因此是在不同域分开训练,并由域专用调制向量调制。通过有效利用低层网络的通用性,本发明可以实现提高元学习对跨域数据的泛化性能;通过有效利用高层网络的特定性,且本发明也可以实现提高元学习对跨域数据的预测速度和效率,由于元学习广泛用于解决少样本问题,因此本发明可以应用于各应用领域的少样本问题。
该基于域自适应的多初始化元学习模型的结构如图1所示,不同域的数据 Xs和Xt输入跨域编码器的共有编码器中得到共有特征向量和共有特征向量输入到跨域调制网络里面转换为域公用调制向量和不同域的数据Xs和 Xt输入跨域编码器的私有编码器中分别得到私有特征向量和私有特征向量输入到跨域调制网络里面转换为域专用调制向量和元分离网络分别在不同域利用域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分更新元学习器。其中跨域调制网络用于识别多模式任务的模式,其中用跨域编码器提取隐藏的表示,然后利用网络寻找域共用和域私有调制向量。由于调制向量隐含了任务的模式信息,元分离网络将这些调制向量作为输入,学习一个好的多初始化元学习器,以实现对新的多模态任务的快速适应。可见,元分离网络通过分别使用源域和目标域的调制向量来学习跨域元学习器,其中在较低层使用公共调制向量来学习域不变知识,而在高层使用特定调制向量来学习特定领域的知识,最后,在泛化方面,利用共享编码器获取Dtest的特征到它们的调制向量,然后利用目标域中学习到的元学习器来实现快速自适应。
第二方面,本方案提供一种基于域自适应的多初始值元学习方法,利用上述基于域自适应的多初始值元学习模型进行学习,包括以下步骤:
初始化:随机初始化元分离网络参数和跨域编码器参数、跨域调制网络参数;
数据采样:从源域和目标域数据中分别采样支持集和查询集数据;
自适应于支持集的元分离网络参数获取:将支持集数据输入跨域编码器,输出共有特征向量和私有特征向量;将共有特征向量和私有特征向量输入跨域调制网络,输出域公用调制向量和域专用调制向量;将域公用调制向量调制得到元分离网络参数公共部分,将域专用调制向量调制成元分离网络参数私有部分,得到调制后的元分离网络;将支持集数据输入调制后的元分离网络,计算第一网络梯度;将第一网络梯度用于更新元分离网络参数,得到自适应于支持集的元分离网络参数,遍历源域和目标域的支持集数据。
更新网络参数:根据查询集在自适应于支持集的元分离网络上的误差,计算第二梯度,并更新元分离网络参数公共部分,元分离网络参数私有部分,跨域编码器,跨域调制网络后回归进行自适应于支持集的元分离网络参数获取步骤,直到网络收敛,输出所有网络参数。
其中计算第一网络梯度的公式如下:
其中更新元分离网络参数公共部分:
更新元分离网络参数私有部分:
更新跨域编码器
更新跨域调制网络:
相较现有技术,本技术方案的有益效果和特点如下:首先,提出了一种新的单分散网络结构的基于域自适应的多初始值元学习方法来提高元学习在多模态任务上的性能,针对多初始化域移位问题,提出了一种基于薄膜网络的跨域调制网络,将任务编码为域公共和域私有调制向量。基于生成的调制向量,一种新的元分离网络(MSN)提出了在源和目标域更新元学习器,元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分。此外,元学习中共享的公共参数是由前几层学习的,而私有参数是由后几层学习的。关键的原因是较低的层次可能生成通用的特征,而较高的层次可能学习特定的特征。此外,将不等式测度纳入元学习的更新过程中,以进一步提高元学习的泛化能力。
附图说明
图1是根据本发明的一种基于域自适应的多初始值元学习模型框架示意图。
图2是基于域自适应的多初始值元学习方法的伪代码图。
图3是强化学习实验用的实验图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员所获得的所有其他实施例,都属于本发明保护的范围。
实施例:
本申请人采用了大量的实验,包括回归、图像分类和强化学习(RL),以评估本方案提出的方法在各种多模态少样本学习任务中的应用,为了进行比较,本申请人考虑了以下元学习方法作为参考:
1.MAML:MAML是传统的模型无关元学习算法的代表,它们已经应用于广泛的研究领域;
2.Multi-MAML:由多个MAML模型组成,每个模型都是根据从单个模态中采样的任务进行专门训练的。值得注意的是,Multi-MAML是在单模态上进行评估,这意味着该方法无须像其他方法一样进行模态的判别,而这在实际应用中是无法实现的:无法提前知道即将到来的数据是什么模态。因此其性能是MAML算法在准确辨识模态情况下的性能上界,而且其性能在实际中无法得到。
3.MMAML:作为最近的一项研究成果,MMAML应用特征线性调制(FiLM)来识别任务的模式,然后调整元学习器参数以产生多个初始化。
本方案的方法用MIML-DA表示。
回归实验:
实验条件准备:进行multimodal few-shot regression实验,本申请人从一维函数中抽取五对输入输出数据{xk,yk}k=1....K,构建多模式任务分配,本方案考虑了五种不同的函数:正弦函数、线性函数、二次函数、转换范数和双曲正切函数,它们都被视为离散任务模式,本方案在两种模式(线性和正弦函数)、三种模式(二次、线性和正弦函数)和五种模式(所有五种函数)的混合数据集上训练模型。对于每个任务,采样5对数据,然后在输出值上添加高斯噪声,进一步增加识别生成数据的函数的难度。应用μ=0且σ=0.3的高斯噪声。
实验方法:
首先考虑了三种基线方法MAML,Multi-MAML和MMAML,这三种方法都有元网络,MMAML进一步用一个调制网络来增强元网络。首先,将按x值排序的数据点输入调制网络,生成特定任务的调制向量,用于调制元分离网络,然后,进一步调整调制后的元分离网络。
实验结果:如表一所示,表一说明了本方案的方法和其他基线方法在平均均方误差(MSE)方面的性能,其中每种情况下的最低值用黑体突出显示。结果表明,所提出的MIML-DA在所有情况下都达到了最佳性能。更具体地说,传统的 MAML在所有情况下都有最大的误差,并且加入任务身份的Multi-MAML的性能明显优于MAML,这表明在多模态任务分布下MAML会退化。由于调制网络产生的向量暗示了输入数据的模式,调制后的元学习器可以得到更好的初始化。因此,基于梯度的优化方法在这种情况下可以获得更好的性能,因此MIML-DA和MMAML的性能明显优于LSTM学习者。最后,MIML-DA的性能优于MMAML,因为跨域调制网络和元分离网络减少了训练和测试阶段之间的域偏移,从而提高了泛化能力。
表一
图像分类实验:
进行multi-modal few-shot image classification实验,这种分类任务考虑了将图像分类为N类,其中标记可用样本数为K的N类,称为N-way-shot分类;创建类似于Triantafliou等人的多模式任务,将多个广泛使用的数据集组合在一起,形成由OmniglotLake等人组成的元数据集。本申请人在元数据集上训练模型,包括两种模式(Omniglot和Mini Imagenet)、三种模式(Omniglot、 Mini Imagenet和FC100)和五种模式(所有五个数据集)。
总体结果见表二。观察到本方案提出的MIML-DA方法在几乎所有情况下都达到了最佳性能,只有一个值除外。总的来说,分类方法之间的性能比较类似于回归方法。随着模式数目的增加,MIML-DA与基线之间的性能差距越来越大,表明我们的方法能够更好地处理多模式任务分布。值得注意的是,Multi-MAML获得了很好的性能,因为每个Multi-MAML很可能会过度适应一个具有较少类的单个数据集。相反,MMAML和MIML-DA从所有数据集中学习模型。结果表明,由于调制网络的特性,MMML-DA的性能略好于MMAML和MMAML,由于跨域调制网络和MSN 的特性,MIML-DA的性能要好于MMAML和MMAML。
表二
强化学习实验:
在MuJoCo物理模拟器上验证MIML-DA在多模态元强化学习中的能力,以适应基于有限经验的新任务。考虑到图3中的三个环境,在每个时间点上对agent 进行奖励,以最小化从多模态分布中采样的到未知目标的距离。
用ProMP代替MAML作为我们的基准,此外,基线Multi-ProMP使用Vuorio 等人(2019年)提出的ProMP方法为每种模式训练一个策略。由于任务的对称分布和随机初始值,agent只接受一种模式的训练移动。同样利用了ProMP对 MMAML和MIML-DA的策略和调制网络进行了优化。
结果如表三、表四和表五所示。如所观察到的,在所有三种环境中,MIML-DA 在各种模式下的表现始终优于ProMP和MMAML。值得注意的是,由于每个多ProMP 只考虑单一模式,所以Multi-ProMP表现出良好的性能。
表三
表四
表五
上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,取决于设计要求和其他因素,可以发生各种各样的修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。
Claims (8)
1.一种基于域自适应的多初始值元学习框架,其特征在于,包括:
跨域编码器,将输入数据通过共有编码器编码为共有特征向量,通过私有编码器编码为私有特征向量;
跨域调制网络,将共有特征向量编码为域公用调制向量,私有特征向量编码为域专用调制向量;
元分离网络,用于在源域和目标域中更新元学习器,其中元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分。
5.根据权利要求1所述的一种基于域自适应的多初始值元学习框架,其特征在于,元分离网络的参数公共部分为低层网络参数,不同任务的低层网络跨域共享以进行元学习和联合训练,并由域公用调制向量调制;元分离网络的参数私有部分为高层网络参数,特定于单个任务,由域专用调制向量调制。
6.一种基于域自适应的多初始值元学习方法,其特征在于,包括以下步骤:
初始化:随机初始化元分离网络参数、跨域编码器参数以及跨域调制网络参数;
数据采样:从源域和目标域数据中分别采样支持集和查询集数据;
自适应于支持集的元分离网络参数获取:将支持集数据输入跨域编码器,输出共有特征向量和私有特征向量;将共有特征向量和私有特征向量输入跨域调制网络,输出域公用调制向量和域专用调制向量;将域公用调制向量调制得到元分离网络参数公共部分,将域专用调制向量调制成元分离网络参数私有部分,得到调制后的元分离网络;将支持集数据输入调制后的元分离网络,计算第一网络梯度;将第一网络梯度用于更新元分离网络参数,得到自适应于支持集的元分离网络参数,遍历源域和目标域的支持集数据;
更新网络参数:根据查询集在自适应于支持集的元分离网络上的误差,计算第二梯度,并更新元分离网络参数公共部分,元分离网络参数私有部分,跨域编码器,跨域调制网络后回归进行自适应于支持集的元分离网络参数获取步骤,直到网络收敛,输出所有网络参数。
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN2020113211400 | 2020-11-23 | ||
CN202011321140 | 2020-11-23 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112734049A true CN112734049A (zh) | 2021-04-30 |
Family
ID=75597013
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110210507.XA Pending CN112734049A (zh) | 2020-11-23 | 2021-02-25 | 一种基于域自适应的多初始值元学习框架及方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112734049A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113139536A (zh) * | 2021-05-12 | 2021-07-20 | 哈尔滨工业大学(威海) | 一种基于跨域元学习的文本验证码识别方法、设备及存储介质 |
CN113377990A (zh) * | 2021-06-09 | 2021-09-10 | 电子科技大学 | 基于元自步学习的视频/图片-文本跨模态匹配训练方法 |
CN114202028A (zh) * | 2021-12-13 | 2022-03-18 | 四川大学 | 基于mamtl的滚动轴承寿命阶段识别方法 |
-
2021
- 2021-02-25 CN CN202110210507.XA patent/CN112734049A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113139536A (zh) * | 2021-05-12 | 2021-07-20 | 哈尔滨工业大学(威海) | 一种基于跨域元学习的文本验证码识别方法、设备及存储介质 |
CN113377990A (zh) * | 2021-06-09 | 2021-09-10 | 电子科技大学 | 基于元自步学习的视频/图片-文本跨模态匹配训练方法 |
CN114202028A (zh) * | 2021-12-13 | 2022-03-18 | 四川大学 | 基于mamtl的滚动轴承寿命阶段识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112734049A (zh) | 一种基于域自适应的多初始值元学习框架及方法 | |
CN111291212B (zh) | 基于图卷积神经网络的零样本草图图像检索方法和系统 | |
US20190205334A1 (en) | Method for learning cross-domain relations based on generative adversarial networks | |
Kim et al. | Diffusionclip: Text-guided image manipulation using diffusion models | |
Park et al. | Learning symmetric embeddings for equivariant world models | |
CN112733965B (zh) | 一种基于小样本学习的无标签图像分类方法 | |
Boney et al. | Semi-supervised few-shot learning with prototypical networks | |
CN112307883B (zh) | 训练方法、装置、电子设备以及计算机可读存储介质 | |
CN109348229B (zh) | 基于异构特征子空间迁移的jpeg图像失配隐写分析方法 | |
CN111931814A (zh) | 一种基于类内结构紧致约束的无监督对抗域适应方法 | |
Vayer et al. | Fused Gromov-Wasserstein distance for structured objects: theoretical foundations and mathematical properties | |
Liu et al. | Mitigating barren plateaus with transfer-learning-inspired parameter initializations | |
CN114419323A (zh) | 基于跨模态学习与领域自适应rgbd图像语义分割方法 | |
Hughes et al. | A semi-supervised approach to SAR-optical image matching | |
CN108629374A (zh) | 一种基于卷积神经网络的无监督多模态子空间聚类方法 | |
Habib et al. | Knowledge distillation in vision transformers: A critical review | |
CN114372505A (zh) | 一种无监督网络对齐方法和系统 | |
Stergiopoulou et al. | Fluctuation-based deconvolution in fluorescence microscopy using plug-and-play denoisers | |
CN116383470B (zh) | 一种具有隐私保护的图像搜索方法 | |
CN116310545A (zh) | 一种基于深度层次化最优传输的跨域舌头图像分类方法 | |
Li et al. | Automatic Dictionary Learning Sparse Representation for Image Denoising. | |
Fabian et al. | Learning deep representation by increasing ConvNets Depth for few shot learning | |
Zeng et al. | Incomplete texture repair of iris based on generative adversarial networks | |
Zhou et al. | BPJDet: Extended Object Representation for Generic Body-Part Joint Detection | |
Lu et al. | Image Dehazing Based on CycleGAN with an Enhanced Generator and a Multiscale Discriminator |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |