CN111178543A - 一种基于元学习的概率域泛化学习方法 - Google Patents
一种基于元学习的概率域泛化学习方法 Download PDFInfo
- Publication number
- CN111178543A CN111178543A CN201911399242.1A CN201911399242A CN111178543A CN 111178543 A CN111178543 A CN 111178543A CN 201911399242 A CN201911399242 A CN 201911399242A CN 111178543 A CN111178543 A CN 111178543A
- Authority
- CN
- China
- Prior art keywords
- learning
- meta
- category
- distribution
- domain
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 43
- 238000000605 extraction Methods 0.000 claims description 13
- 238000012549 training Methods 0.000 claims description 12
- 238000012360 testing method Methods 0.000 claims description 8
- 238000013145 classification model Methods 0.000 claims description 7
- 238000012935 Averaging Methods 0.000 claims description 6
- 238000013527 convolutional neural network Methods 0.000 claims description 3
- 230000006870 function Effects 0.000 claims description 3
- 239000011159 matrix material Substances 0.000 claims description 3
- 238000011176 pooling Methods 0.000 claims description 3
- VWDWKYIASSYTQR-UHFFFAOYSA-N sodium nitrate Chemical compound [Na+].[O-][N+]([O-])=O VWDWKYIASSYTQR-UHFFFAOYSA-N 0.000 claims description 3
- 238000010998 test method Methods 0.000 claims description 2
- 238000013459 approach Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000006978 adaptation Effects 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000001617 migratory effect Effects 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computational Linguistics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Biophysics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Medical Informatics (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于元学习的概率域泛化学习方法,属于元学习领域,一种基于元学习的概率域泛化学习方法,可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。
Description
技术领域
本发明涉及元学习领域,更具体地说,涉及一种基于元学习的概率域泛化学习方法。
背景技术
传统的机器学习假设训练数据与测试数据服从相同的数据分布,这个条件在实际应用中很难得到满足。解决这个问题有几种经典方法,包括1)迁移学习:迁移学习的目标是将从一个环境中学到的知识用来帮助新环境中的学习任务;2)域自适应:域自适应学习的重点在于如何克服源域分布和目标域分布不同,实现目标域上的学习任务;3)域泛化:目标域不可知的情况下,使得分布或者模型对未知情况具备良好特性。这几类方法的难度是递增的,在本发明中,提出一种基于变分信息瓶颈元学习的概率域泛化学习方法,就是针对域泛化的方法。
目前针对域泛化的方法主要有1)基于特征的方法,其只要通过设计跨域不变特征实现域泛化;2)基于分类器的方法,其针对每个数据集也就是源域中的每一个子域对子分类器进行设计,然后将子分类器结合成一个融合分类器来实现;3)信息瓶颈:任何神经网络可以通过隐层与输入和输出变量之间的共享信息(mutual information)来量化,深度学习的目标就是在学习的过程中最大化地压缩输入信息,最大化地保留输出信息。信息瓶颈就是通过控制输入和输出变量之间的共享信息达到泛化的目的。
域泛化方法的关键是权衡源域到目标域之间的变化。前述方法或者是尽量抽取对域变化不敏感的特征表示或者是通过在每一个域学习得到一个模型,然后选择和目标域相近的模型进行预测。这些方法中,参数的数量会随着源域的增加而线性增加,从而在数据不充分的应用中,很容易出现过拟合现象。
发明内容
1.要解决的技术问题
针对现有技术中存在的问题,本发明的目的在于提供一种基于元学习的概率域泛化学习方法,它可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。
2.技术方案
为解决上述问题,本发明采用如下的技术方案。
一种基于元学习的概率域泛化学习方法,包括以下步骤:
输入:具有K个源域的训练数据集S,学习率λ,迭代次数Niter;
输出:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;分类模型参数ψ;
S1、从K个源域随机选取一个作为目标域,其余K-1个作为源域;
S2、从每一个源域Ds中选取包含C个类别的M个样本,表示为
S3、从目标域Dt中选取N个样本,表示为
S9、对每一个类别重复S7-S8,并按照列排列构成矩阵如下:
ψ=[ψ1,ψ2,...,ψC]
S13、将目标域的每一个类别的每一个特征送入推理网络g2中,计算关于目标域的分布;
S15、计算每个类别的损失函数如下:
S16、重复S12-S15,使其覆盖所有类别。
S18、重复S2-S17,到所有K-1结束。
进一步的,所述元学习方法在训练阶段后进行需进行元学习的测试。
进一步的,所述元学习测试的方法为:
输入:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;参数ψ,分类模型;待分类的目标域任务;
输出:分类结果;
步骤5:利用分类器参数ψ计算分类结果,即ψzj得到的向量中最大维度表示的类别,即为分类结果。
3.有益效果
相比于现有技术,本发明的优点在于:
本方案可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;
首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;
本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。
附图说明
图1为本发明的元学习中数据/模型关系图;
图2为本发明的在旋转的MNIST数据库上的10次测试的平均分类准确率数据表;
图3为本发明的在CLVS四个数据库上的分类准确率数据表。
具体实施方式
下面将结合本发明实施例中的附图;对本发明实施例中的技术方案进行清楚、完整地描述;显然;所描述的实施例仅仅是本发明一部分实施例;而不是全部的实施例,基于本发明中的实施例;本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例;都属于本发明保护的范围。
在本发明的描述中,需要说明的是,术语“上”、“下”、“内”、“外”、“顶/底端”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性。
在本发明的描述中,需要说明的是,除非另有明确的规定和限定,术语“安装”、“设置有”、“套设/接”、“连接”等,应做广义理解,例如“连接”,可以是固定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,可以是两个元件内部的连通。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。
实施例1:
一种基于元学习的概率域泛化学习方法,包括以下步骤:
输入:具有K个源域的训练数据集S,学习率λ,迭代次数Niter;
输出:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;分类模型参数ψ;
S1、从K个源域随机选取一个作为目标域,其余K-1个作为源域;
S2、从每一个源域Ds中选取包含C个类别的M个样本,表示为
S3、从目标域Dt中选取N个样本,表示为
S9、对每一个类别重复S7-S8,并按照列排列构成矩阵如下:
ψ=[ψ1,ψ2,...,ψC]
S15、计算每个类别的损失函数如下:
S16、重复S12-S15,使其覆盖所有类别。
S18、重复S2-S17,到所有K-1结束。
上述方法主要为元学习的训练阶段。
元学习方法在训练阶段后进行需进行元学习的测试,元学习测试的方法为:
输入:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;参数ψ,分类模型;待分类的目标域任务;
输出:分类结果;
步骤5:利用分类器参数ψ计算分类结果,即ψzj得到的向量中最大维度表示的类别,即为分类结果。
S是元学习中源域,T是元学习中的目标域。
在元学习训练阶段,仅用到S,即源域。并在每一个任务场景中,将元学习中的源域S分成不相交的两个数据集Ds和Dt,其中的划分方法可以有很多种,其中一种划分方法即是一个任务。元学习的训练阶段就是针对每个任务场景,将模型从Ds域泛化到Dt域。
在元学习的测试阶段,就是将元学习训练得到的模型泛化到元学习目标域T中。
ψ是分类器的参数,θ是总体域泛化模型参数,其包含两个网络的参数,一个是用于特征提取的网络h,一个是用于变分推理的网络g,其中变分推理的网络包含g1和g2。
请参阅图1,图1中包括两个模型,一个模型是分类模型ψ,是在元学习训练阶段得到的,其目的是通过元学习的框架,建立起Ds到Dt域联系,使得分类器很快适应新的任务;另一个模型是整体域泛化模型,其参数为θ。
其中图2数据库是对MNIST数据集图像进行0°、15°、30°、45°、60°、75°旋转,表示在旋转的MNIST数据库上的10次测试的平均分类准确率,图3表示在CLVS四个数据库上的分类准确率。
可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。
以上所述;仅为本发明较佳的具体实施方式;但本发明的保护范围并不局限于此;任何熟悉本技术领域的技术人员在本发明揭露的技术范围内;根据本发明的技术方案及其改进构思加以等同替换或改变;都应涵盖在本发明的保护范围内。
Claims (6)
1.一种基于元学习的概率域泛化学习方法,其特征在于:包括以下步骤:
输入:具有K个源域的训练数据集S,学习率λ,迭代次数Niter;
输出:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;分类模型参数ψ;
S1、从K个源域随机选取一个作为目标域,其余K-1个作为源域;
S2、从每一个源域Ds中选取包含C个类别的M个样本,表示为
S3、从目标域Dt中选取N个样本,表示为
S9、对每一个类别重复S7-S8,并按照列排列构成矩阵如下:
ψ=[ψ1,ψ2,...,ψC]
S13、将目标域的每一个类别的每一个特征送入推理网络g2中,计算关于目标域的分布;
S15、计算每个类别的损失函数如下:
S16、重复S12-S15,使其覆盖所有类别;
S18、重复S2-S17,到所有K-1结束。
5.根据权利要求1所述的一种基于元学习的概率域泛化学习方法,其特征在于:所述元学习方法在训练阶段后进行需进行元学习的测试。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911399242.1A CN111178543B (zh) | 2019-12-30 | 2019-12-30 | 一种基于元学习的概率域泛化学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911399242.1A CN111178543B (zh) | 2019-12-30 | 2019-12-30 | 一种基于元学习的概率域泛化学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111178543A true CN111178543A (zh) | 2020-05-19 |
CN111178543B CN111178543B (zh) | 2024-01-09 |
Family
ID=70657598
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911399242.1A Active CN111178543B (zh) | 2019-12-30 | 2019-12-30 | 一种基于元学习的概率域泛化学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111178543B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111724596A (zh) * | 2020-06-23 | 2020-09-29 | 上海电科智能系统股份有限公司 | 一种智能精确自动识别预判快速路瓶颈区方法 |
CN112035649A (zh) * | 2020-09-02 | 2020-12-04 | 腾讯科技(深圳)有限公司 | 问答模型处理方法、装置、计算机设备及存储介质 |
CN112948506A (zh) * | 2021-04-01 | 2021-06-11 | 重庆邮电大学 | 一种基于卷积神经网络的改进元学习的关系预测方法 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN102930302A (zh) * | 2012-10-18 | 2013-02-13 | 山东大学 | 基于在线序贯极限学习机的递增式人体行为识别方法 |
CN105069400A (zh) * | 2015-07-16 | 2015-11-18 | 北京工业大学 | 基于栈式稀疏自编码的人脸图像性别识别系统 |
CN105654210A (zh) * | 2016-02-26 | 2016-06-08 | 中国水产科学研究院东海水产研究所 | 一种利用海洋遥感多环境要素的集成学习渔场预报方法 |
CN105787513A (zh) * | 2016-03-01 | 2016-07-20 | 南京邮电大学 | 多示例多标记框架下基于域适应迁移学习设计方法和系统 |
CN109583342A (zh) * | 2018-11-21 | 2019-04-05 | 重庆邮电大学 | 基于迁移学习的人脸活体检测方法 |
CN110619342A (zh) * | 2018-06-20 | 2019-12-27 | 鲁东大学 | 一种基于深度迁移学习的旋转机械故障诊断方法 |
-
2019
- 2019-12-30 CN CN201911399242.1A patent/CN111178543B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN102930302A (zh) * | 2012-10-18 | 2013-02-13 | 山东大学 | 基于在线序贯极限学习机的递增式人体行为识别方法 |
CN105069400A (zh) * | 2015-07-16 | 2015-11-18 | 北京工业大学 | 基于栈式稀疏自编码的人脸图像性别识别系统 |
CN105654210A (zh) * | 2016-02-26 | 2016-06-08 | 中国水产科学研究院东海水产研究所 | 一种利用海洋遥感多环境要素的集成学习渔场预报方法 |
CN105787513A (zh) * | 2016-03-01 | 2016-07-20 | 南京邮电大学 | 多示例多标记框架下基于域适应迁移学习设计方法和系统 |
CN110619342A (zh) * | 2018-06-20 | 2019-12-27 | 鲁东大学 | 一种基于深度迁移学习的旋转机械故障诊断方法 |
CN109583342A (zh) * | 2018-11-21 | 2019-04-05 | 重庆邮电大学 | 基于迁移学习的人脸活体检测方法 |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111724596A (zh) * | 2020-06-23 | 2020-09-29 | 上海电科智能系统股份有限公司 | 一种智能精确自动识别预判快速路瓶颈区方法 |
CN111724596B (zh) * | 2020-06-23 | 2022-11-11 | 上海电科智能系统股份有限公司 | 一种智能精确自动识别预判快速路瓶颈区方法 |
CN112035649A (zh) * | 2020-09-02 | 2020-12-04 | 腾讯科技(深圳)有限公司 | 问答模型处理方法、装置、计算机设备及存储介质 |
CN112035649B (zh) * | 2020-09-02 | 2023-11-17 | 腾讯科技(深圳)有限公司 | 问答模型处理方法、装置、计算机设备及存储介质 |
CN112948506A (zh) * | 2021-04-01 | 2021-06-11 | 重庆邮电大学 | 一种基于卷积神经网络的改进元学习的关系预测方法 |
Also Published As
Publication number | Publication date |
---|---|
CN111178543B (zh) | 2024-01-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Luo et al. | An inherently nonnegative latent factor model for high-dimensional and sparse matrices from industrial applications | |
Lo et al. | E-graphsage: A graph neural network based intrusion detection system for iot | |
Jadhav et al. | Comparative study of K-NN, naive Bayes and decision tree classification techniques | |
CN111178543A (zh) | 一种基于元学习的概率域泛化学习方法 | |
CN109299741B (zh) | 一种基于多层检测的网络攻击类型识别方法 | |
Lei et al. | Patent analytics based on feature vector space model: A case of IoT | |
Marastoni et al. | Data augmentation and transfer learning to classify malware images in a deep learning context | |
Prabha et al. | Improved particle swarm optimization based k-means clustering | |
CN114172688B (zh) | 基于gcn-dl的加密流量网络威胁关键节点自动提取方法 | |
CN112087447A (zh) | 面向稀有攻击的网络入侵检测方法 | |
CN112529638B (zh) | 基于用户分类和深度学习的服务需求动态预测方法及系统 | |
Ojugo et al. | Computational solution of networks versus cluster groupings for social network contacts: a recommender system | |
US20230062289A1 (en) | Learning method and processing apparatus regarding machine learning model classifying input image | |
CN116506181A (zh) | 一种基于异构图注意力网络的车联网入侵检测方法 | |
Shrivastav et al. | Network traffic classification using semi-supervised approach | |
Das et al. | FERNN: A fast and evolving recurrent neural network model for streaming data classification | |
CN104468276B (zh) | 基于随机抽样多分类器的网络流量识别方法 | |
d'Andecy et al. | Indus: Incremental document understanding system focus on document classification | |
Lopes et al. | Automatic cluster labeling through artificial neural networks | |
CN114124437B (zh) | 基于原型卷积网络的加密流量识别方法 | |
Nikolaou et al. | Calibrating AdaBoost for asymmetric learning | |
Jia et al. | Trojan traffic detection based on meta-learning | |
Papakostas et al. | Evolutionary feature subset selection for pattern recognition applications | |
Meng et al. | Adaptive resonance theory (ART) for social media analytics | |
Zhu et al. | Software defect prediction model based on stacked denoising auto-encoder |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |