CN117422960B - 一种基于元学习的图像识别持续学习方法 - Google Patents
一种基于元学习的图像识别持续学习方法 Download PDFInfo
- Publication number
- CN117422960B CN117422960B CN202311719529.4A CN202311719529A CN117422960B CN 117422960 B CN117422960 B CN 117422960B CN 202311719529 A CN202311719529 A CN 202311719529A CN 117422960 B CN117422960 B CN 117422960B
- Authority
- CN
- China
- Prior art keywords
- training
- model
- shot
- test
- image
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 71
- 238000012937 correction Methods 0.000 claims abstract description 24
- 238000013526 transfer learning Methods 0.000 claims abstract description 9
- 238000012549 training Methods 0.000 claims description 193
- 238000012360 testing method Methods 0.000 claims description 90
- 238000011156 evaluation Methods 0.000 claims description 13
- 238000011478 gradient descent method Methods 0.000 claims description 10
- 238000012935 Averaging Methods 0.000 claims description 9
- 238000004364 calculation method Methods 0.000 claims description 9
- 238000007781 pre-processing Methods 0.000 claims description 9
- 238000001514 detection method Methods 0.000 claims description 7
- 238000013145 classification model Methods 0.000 claims description 5
- 239000011159 matrix material Substances 0.000 claims description 5
- 238000010276 construction Methods 0.000 claims description 3
- 238000002372 labelling Methods 0.000 claims description 3
- 238000013135 deep learning Methods 0.000 description 4
- 238000010801 machine learning Methods 0.000 description 4
- 230000003930 cognitive ability Effects 0.000 description 3
- 101100455978 Arabidopsis thaliana MAM1 gene Proteins 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000013178 mathematical model Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000007812 deficiency Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0985—Hyperparameter optimisation; Meta-learning; Learning-to-learn
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/35—Categorising the entire scene, e.g. birthday party or wedding scene
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Multimedia (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Data Mining & Analysis (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
- Human Computer Interaction (AREA)
Abstract
一种基于元学习的图像识别持续学习方法,通过迁移学习方式用实际使用场景图像的微调数据集训练预训练模型,得到微调模型,再调整微调模型架构得到few‑shot模型;然后使用few‑shot模型推理待预测图像,得到分类结果;再对分类结果进行人工矫正,将人工矫正的矫正数据和待预测图像加入微调数据集中,从而实现持续学习。本发明通过采用元学习进行迁移学习,通过大量容易采集数据得到预训练模型,迁移到少样本的实际使用场景图像中对其进行分类。而且本发明采用持续在线学习的模式,在使用过程中,对于分类结果定期人工矫正,修正后的数据自动加入微调数据集,快速迭代微调模型。
Description
技术领域
本发明涉及元学习技术领域,特别涉及一种基于元学习的图像识别持续学习方法。
背景技术
近年来深度学习在学术界、科技界应用广泛,尤其在图像领域,目前已经在图像分类领域取得较大进展,取得不错成效。深度学习能取得了巨大成功,最为关键的因素就是利用大量的数据去驱动模型训练,使其获得良好的测试效果。但是在实际应用中深度学习算法难以实施,因为深度学习训练需要大量的标记样本,但拥有大量数据样本毕竟是少数,大部分情况没有那么多标记样本;其次模型训练耗时,对于有些检测分类要求频繁变化的应用,每次更改一次检测要求,就需要重新训练大量样本,这大大增加了时间成本。
元学习(Meta Learning)是机器学习的子领域。传统的机器学习问题是基于海量数据集从头开始学习一个用于预测的数学模型,这与人类学习、积累历史经验指导新的机器学习任务的过程相差甚远。元学习则是学习不同的机器学习任务的学习训练过程,以及学习如何更快更好地训练一个数学模型。元学习可以在很少的样本上学习到如何快速适应新任务,因此非常适合少样本的情况。从少量数据中快速学习和适应的能力对于人工智能至关重要。但然而由于元学习基于少量样品得到的模型,模型存在识别精度差的缺点。
因此,针对现有技术不足,提供一种基于元学习的图像识别持续学习方法以解决现有技术不足甚为必要。
发明内容
本发明的目的在于避免现有技术的不足之处而提供一种基于元学习的图像识别持续学习方法。该基于元学习的图像识别持续学习方法根据分类结果定期人工矫正加入微调数据集,持续扩大样品量,从而提高识别精度。
本发明的上述目的通过以下技术措施实现:
提供一种基于元学习的图像识别持续学习方法,基于迁移学习方式用实际使用场景图像的微调数据集训练预训练模型,得到微调模型,再调整微调模型架构得到few-shot模型;然后使用few-shot模型推理待预测图像,得到分类结果;再对分类结果进行人工矫正,将人工矫正的矫正数据和待预测图像加入所述微调数据集中,从而实现持续学习。
本发明的基于元学习的图像识别持续学习方法,通过如下步骤进行:
S1、构造实际使用场景图像的微调数据集,然后通过所述微调数据集对预训练模型进行多次训练,得到微调模型;
S2、调整所述微调模型的模型架构,得到few-shot模型;
S3、构建few-shot推理集和待预测图像的few-shot查询集;
S4、通过所述few-shot模型对所述few-shot推理集和所述few-shot查询集分别进行推理,得到few-shot查询集的特征和few-shot推理集的特征,然后计算few-shot查询集的特征与few-shot推理集的特征之间的相似度,得到待预测图像的分类结果;
S5、对所述S4得到分类结果进行人工矫正得到矫正数据,分别将矫正数据和待预测图像均加入至所述few-shot推理集和所述微调数据集中。
优选的,上述预训练模型通过如下步骤获得:
L1、构造元学习的骨干模型和采集多个类别的图像并标注类别;
L2、对所述L1得到的图像进行预处理,得到预训练数据集;
L3、根据所述L2得到的预训练数据集对所述L1得到的骨干模型进行多次训练,每次训练对应得到一个评估指标loss'第一测试,并对所有评估指标loss'第一测试求均值,得到loss第一均值,对loss第一均值反向更新得到权重θ第一均值,然后采用梯度下降法得到最优的预训练模型。
优选的,上述S1通过如下步骤进行:
S1.1、采集实际使用场景图像并标注类别;
S1.2、对所述S1.1得到的图像进行预处理,得到微调数据集;
S1.3、根据所述S1.2得到的微调数据集对所述预训练模型进行多次训练,每次训练对应得到一个评估指标loss'第二测试,并对所有评估指标loss'第二测试求均值,得到loss第二均值,对loss第二均值反向更新得到权重θ第二均值,然后采用梯度下降法得到最优的微调模型。
优选的,上述S2具体是将所述微调模型中隐藏层直接作为输出长度E嵌入特征的图像特征采集器,最终得到所述few-shot模型。
优选的,上述L3的每次训练方法均通过如下步骤进行:
A1、从所述预训练数据集中的训练集中随机选出图像构建任务数据的task第一训练,且所述task第一训练设置有第一训练支持集和第一训练查询集,所述第一训练支持集为N-waysK-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第一训练查询集为Q张查询目标图像;
A2、使用所述骨干模型分别对所述A1得到的第一训练支持集和第一训练查询集进行推理,得到第一训练分类结果,然后通过交叉熵损失方法得到loss第一训练;
A3、根据所述A2得到的loss第一训练反向更新,得到权重θ第一训练;
A4、从所述预训练数据集中的测试集中构建任务数据task第一测试,且所述task第一测试设置有第一测试支持集和第一测试查询集,所述第一测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第一训练查询集为Q张查询目标图像;
A5、使用所述A3得到的权重θ第一训练对应的骨干模型对所述A4得到的第一测试支持集和第一测试查询集进行推理,得到测试分类结果,然后通过交叉熵损失方法得到loss'第一测试。
优选的,上述S1.3的每次训练方法均通过如下步骤进行:
B1、从所述微调数据集中的训练集中构建任务数据的task第二训练,且所述task第二训练设置有第二训练支持集和第二训练查询集,所述第二训练支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第二训练查询集为Q张查询目标图像;
B2、使用所述预训练模型分别对所述B1得到的第二训练支持集和第二训练查询集进行推理,得到第二训练分类结果,然后通过交叉熵损失方法得到loss第二训练;
B3、根据所述B2得到的loss第二训练反向更新,得到权重θ第二训练;
B4、从所述微调数据集中的测试集中构建任务数据task第二测试,且所述task第二测试设置有第二测试支持集和第二测试查询集,所述第二测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第二训练查询集为Q张查询目标图像;
B5、使用所述B3得到的权重θ第二训练对应的预训练模型对所述B4得到的第二测试支持集和第二测试查询集进行推理,得到第二测试分类结果,然后通过交叉熵损失方法得到loss'第二测试。
优选的,上述微调数据集与所述few-shot推理集相同或不相同。
优选的,上述few-shot推理集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数。
优选的,上述few-shot查询集为1张。
优选的,上述S4通过如下步骤进行:
S4.1、通过所述few-shot模型对所述few-shot推理集和所述few-shot查询集分别进行推理,即对N*K+1张图像分别进行推理,得到N*K*E的支持集特征矩阵,然后对支持集特征矩阵中每个类别的特征求均值并归一化,得到N*E的支持集平均特征矩阵;
S4.2、将N*E的支持集平均特征矩阵与查询集的图像的1*E特征做相似度计算进行相似度计算,得到查询集的图像分类结果。
优选的,上述L1的训练轮数为10000~100000,batch_size为2~5。
优选的,上述S1的训练轮数为3~50,batch_size为2~5。
优选的,上述骨干模型为图像分类模型或目标检测模型。
在所述L2和所述S1.2中的预处理方法均为将每张图像的长边统一缩放到size并裁剪为size×size的尺寸。
优选的,上述骨干模型为VIT模型。
在所述微调数据集和所述预训练数据集中,训练集的图像数量:测试集的图像数量=(1-x):x,x为0.1~0.3,且训练集的图像类别与测试集的图像类别均不相同。
本发明的一种基于元学习的图像识别持续学习方法,基于迁移学习方式用实际使用场景图像的微调数据集训练预训练模型,得到微调模型,再调整微调模型架构得到few-shot模型;然后使用few-shot模型推理待预测图像,得到分类结果;再对分类结果进行人工矫正,将人工矫正的矫正数据和待预测图像加入所述微调数据集中,从而实现持续学习。本发明通过采用元学习进行迁移学习,通过大量容易采集数据得到预训练模型,迁移到少样本的实际使用场景图像中对其进行分类。而且本发明采用持续在线学习的模式,在使用过程中,对于分类结果定期人工矫正,修正后的数据自动加入微调数据集,快速迭代模型。而且持续扩大微调数据集,使few-shot模型具有更加全面的认知能力,当持续学习时间越长时,积累的实际使用场景图像越多,其识别准确率则越高。
附图说明
利用附图对本发明作进一步的说明,但附图中的内容不构成对本发明的任何限制。
图1为一种基于元学习的图像识别持续学习方法的流程图。
图2为预训练模型的训练过程流程图。
图3为微调模型的训练过程流程图。
图4为实施例2的VIT模型的架构。
具体实施方式
结合以下实施例对本发明的技术方案作进一步说明。
实施例1
一种基于元学习的图像识别持续学习方法,如图1,基于迁移学习方式用实际使用场景图像的微调数据集训练预训练模型,得到微调模型,再调整微调模型架构得到few-shot模型;然后使用few-shot模型推理待预测图像,得到分类结果;再对分类结果进行人工矫正,将人工矫正的矫正数据和待预测图像加入所述微调数据集中,从而实现持续学习。
需要说明的是,本发明中的人工矫正会得到一些使用场景图像中的数据,这种数据对于少样本学习非常宝贵,因此将其加入至微调数据集进入学习,从而使得整个系统更加准确。
本发明基于元学习的图像识别持续学习方法,通过如下步骤进行:
S1、构造实际使用场景图像的微调数据集,然后通过微调数据集对预训练模型进行多次训练,得到微调模型;
S2、调整S1得到的微调模型的模型架构,得到few-shot模型;
S3、构建few-shot推理集和待预测图像的few-shot查询集;
S4、通过S2得到的few-shot模型对S3得到的few-shot推理集和few-shot查询集分别进行推理,得到few-shot查询集的特征和few-shot推理集的特征,然后计算few-shot查询集的特征与few-shot推理集的特征之间的相似度,得到待预测图像的分类结果;
S5、对S4得到分类结果进行人工矫正得到矫正数据,分别将矫正数据和待预测图像均加入至few-shot推理集和微调数据集中。
需要说明的是,因为本发明的实际使用场景图像的数据量是极少的,本发明将待预测图像在矫正数据后将其加入微调数据集中,可以扩充微调数据集数据,并续继训练微调模型,从而得到更准确的微调模型。而将待预测图像在矫正数据加入few-shot推理集中,同样可以扩充few-shot推理集数据,可以不断累积经验,提升S4的推理准确率。当持续学习时间越长时,实际使用场景图像越多,从而系统的识别准确率则越高。
其中,预训练模型通过如下步骤获得:
L1、构造元学习的骨干模型和采集多个类别的图像并标注类别;
L2、对L1得到的图像进行预处理,得到预训练数据集;
L3、根据L2得到的预训练数据集对L1得到的骨干模型进行多次训练,每次训练对应得到一个评估指标loss'第一测试,并对所有评估指标loss'第一测试求均值,得到loss第一均值,对loss第一均值反向更新得到权重θ第一均值,然后采用梯度下降法得到最优的预训练模型,如图2所示。
其中,L3的每次训练方法通过如下步骤进行:
A1、从预训练数据集中的训练集中随机选出图像构建任务数据的task第一训练,且task第一训练设置有第一训练支持集和第一训练查询集,第一训练支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,第一训练查询集为Q张查询目标图像;
A2、使用骨干模型分别对A1得到的第一训练支持集和第一训练查询集进行推理,得到第一训练分类结果,然后通过交叉熵损失方法得到loss第一训练;
A3、根据A2得到的loss第一训练反向更新,得到权重θ第一训练;
A4、从预训练数据集中的测试集中构建任务数据task第一测试,且task第一测试设置有第一测试支持集和第一测试查询集,第一测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,第一训练查询集为Q张查询目标图像;
A5、使用A3得到的权重θ第一训练对应的骨干模型对A4得到的第一测试支持集和第一测试查询集进行推理,得到测试分类结果,然后通过交叉熵损失方法得到loss'第一测试。
本发明的L1的训练轮数为10000~100000,batch_size为2~5;S1的训练轮数为3~50,batch_size为2~5。
需要说明的是,本发明的预训练数据集为常见的图像,其种类相对较多,例如一些常见类别:吃饭、打篮球、滑冰、跳舞等等。每种常见种类均可采样若干张图像。因为预训练数据集的数据量大,需要特别多的训练轮数,如60000,而微调数据集的一般情况下数量少,如10个类别,每个类别为6张时,则需要较少的训练轮数,如5~10轮。当然L1的训练轮数和S1的训练轮数可以根据实际的数据量进行调整。
其中,S1通过如下步骤进行:
S1.1、采集实际使用场景图像并标注类别;
S1.2、对S1.1得到的图像进行预处理,得到微调数据集;
S1.3、根据S1.2得到的微调数据集对预训练模型进行多次训练,每次训练对应得到一个评估指标loss'第二测试,并对所有评估指标loss'第二测试求均值,得到loss第二均值,对loss第二均值反向更新得到权重θ第二均值,然后采用梯度下降法得到最优的微调模型,如图3所示。
其中,S2具体是将S1得到的微调模型中的Class层删除,从而使隐藏层直接作为输出长度E嵌入特征的图像特征采集器,最终得到few-shot模型。
其中,S1.3的每次训练方法均通过如下步骤进行:
B1、从微调数据集中的训练集中构建任务数据的task第二训练,且task第二训练设置有第二训练支持集和第二训练查询集,第二训练支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,第二训练查询集为Q张查询目标图像;
B2、使用预训练模型分别对B1得到的第二训练支持集和第二训练查询集进行推理,得到第二训练分类结果,然后通过交叉熵损失方法得到loss第二训练;
B3、根据B2得到的loss第二训练反向更新,得到权重θ第二训练;
B4、从微调数据集中的测试集中构建任务数据task第二测试,且task第二测试中包括第二测试支持集和第二测试查询集,第二测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,第二训练查询集为Q张查询目标图像;
B5、使用B3得到的权重θ第二训练对应的预训练模型对B4得到的第二测试支持集和第二测试查询集进行推理,得到第二测试分类结果,然后通过交叉熵损失方法得到loss'第二测试。
对于L3和S1.3中的评估指标loss计算具体为:每次训练均构建任务数据,其中支持集为N*K张图像,然后对所有图像进行推理,得到N*K*N的矩阵A,此时对应的标签构成N*K*1的矩阵B,对矩阵A和B进行常用的交叉熵损失计算loss,这里使用pytorch中的torch.nn.CrossEntropyLoss。具体方式为:
loss=torch.nn.CrossEntropyLoss(reduction='mean')
本发明中具体采用了MAML算法,它是一种元学习算法,类似的还有其他元学习算法。在学习的过程中,MAML维护了两套模型权重和超参数,内层权重针对每个训练集任务(对N-ways K-shots图像做图像分类)独立训练求loss(即图像分类的交叉熵损失),并使用梯度下降法对权重进行调优,使得loss不断下降接近0,这个过程就是一般的训练图像分类模型的过程。外层权重将每次内层权重应用到测试集任务上,对所有测试集任务求loss,得到平均loss,采用梯度下降法得到最优权重。其中梯度下降法计算最优权重为本技术领域的常见计算方法,本领域技术人员应当知晓,在此不再一一赘述。
需要说明的是,本发明中在每次训练是均需要构建任务数据,其中训练支持集为N-ways K-shots形式数据集,也就是说每次任务是N*K张图像,总共N个类别。假设原始数据集有1000个分类,类别分布是0~999,将N*K张图像的类别重新分配为0~N-1。
以预训练数据集为例,假设预训练数据集共有1000种类的样本,并假设训练支持集为5-ways 3-shots的任务,也就是每次构造任务时,随机从1000分类中选出5个种类,每个种类抽出3张图像。 假设这1000个分类的分类定义0,1,2,......,999。第一次训练时训练支持集时,随机选出了分类为0、34、123、435、678这5个类别,因为是要训练模型根据举例来看目标与哪一类举例更接近,所以原始的分类是如何是需要让模型无视的。这个情况下,直接将原始分类重置为0、1、2、3和4,也就是:0→0、34→1、123→2、435→3和678→4,然后这样去训练分类。对于第二次训练时,同样构造训练支持集,假设随机选出分类为23、56、234、789、899,那么同样重置它的分类:23→0、56→1、234→2、789→3和899→4,然后这样去训练分类。
还需要说明的是元学习是一类学习方法,其目标是通过学习how to学习,使学习算法能够从过去解决的相关任务中获得经验,从而对新的未知任务进行快速有效的学习。它本身是一种学习方法,并不独立产生模型。本发明中的骨干模型是执行具体任务的模型,本发明中的骨干模型可以为图像分类模型、目标检测模型,其中图像分类模型作为图像分类,目标检测模型作为目标检测,它们均具有具体的模型结构。元学习是利用骨干模型的结构进行学习的过程,它通过让模型适应不同的任务,得到一组最优的适合所有任务共性的权重,这个权重也是匹配骨干模型的权重。
本发明的微调数据集可以与few-shot推理集相同,也可以与few-shot推理集不相同,具体可以根据实际情况设置。
本发明的few-shot推理集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数;few-shot查询集为1张。
需要说明的是对于本发明中的few-shot推理集类别重新分配方式与训练支持集相同,在此不再一一赘述。
本发明的S4通过如下步骤进行:
S4.1、通过S3得到的few-shot模型对S3得到的few-shot推理集和few-shot查询集分别进行推理,即对N*K+1张图像分别进行推理,得到N*K*E的支持集特征矩阵,然后对支持集特征矩阵中每个类别的特征求均值并归一化,得到N*E的支持集平均特征矩阵;
S4.2、将N*E的支持集平均特征矩阵与查询集的图像的1*E特征做相似度计算进行相似度计算,得到查询集的图像分类结果。
需要说明的是,本发明中的相似度计算为分类结果中常见的计算方式,旨在获得与支持集最相似的特征对应的分类,即为分类结果,本领域技术人员应当知晓相似度计算具体操作过程,在此不再一一赘述。本发明S4.1在推理时将同一类别的所有图像的特征平均,更容易得到共性特征。
还需要强调的是,本发明持续将待预测图像加入few-shot推理集,当图像越多对使用场景越适应。
在微调数据集和预训练数据集中,训练集的图像数量:测试集的图像数量=(1-x):x,x为0.1~0.3,且训练集的图像类别与测试集的图像类别均不相同。
需要说明的是,本发明的训练集的图像数量与测试集的图像数量比可根据实际情况设定,本实施例具体为0.8:0.2,且训练集和测试集中的图片种类是没有交集的,以预训练数据集为例,假设预训练数据集共有1000种类的样本,那么训练集占80%,即样本种类0~799属于训练集;验证集20%,样本种类800~999属于验证集。
在所述L2和S1.2中的预处理方法为将每张图像的长边统一缩放到size并裁剪为size×size的尺寸。
该基于元学习的图像识别持续学习方法通过采用元学习进行迁移学习,通过大量容易采集数据得到预训练模型,迁移到少样本的实际使用场景图像中对其进行分类。而且本发明采用持续在线学习的模式,在使用过程中,对于分类结果定期人工矫正,修正后的数据自动加入微调数据集,快速迭代模型。而且持续扩大微调数据集,使few-shot模型具有更加全面的认知能力,当持续学习时间越长时,积累的实际使用场景图像越多,其识别准确率则越高。
实施例2
一种基于元学习的图像识别持续学习方法,其他特征与实施例1相同,不同的是本实施例的骨干模型为VIT模型,如图4所示。
在S2具体是将S1得到的微调后VIT模型中的Class层删除,从而使隐藏层直接作为输出长度E嵌入特征的图像特征采集器,最终得到few-shot模型。
需要说明的是,VIT模型为谷歌搭建元学习模型,在S2中将VIT最后输出维度改为ways,也就是N。
与实施例1相比,本实施例通过使用VIT模型作为骨干模型,从而更好地进行图像分类识别。
实施例3
一种如实施例2的基于元学习的图像识别持续学习方法,本实施例的使用imagenet的数据作为预训练数据,这些已经标注过,从而可以减少大量工作量。
将标注好的文件,划分为训练集80%,测试集20%,分别存储于train.csv,test.csv两个文件中,结构如下:
/dataset
|——images
|——train.csv
|——test.csv。
对于微调数据集和few-shot查询集均为少量真实场景图像,这些图像一般较少,如只几十张,可以手工采集或者模拟生成。然后对这些图像进行人工标注。其中微调数据集和few-shot推理集中的种类,如包括各种鸟类的图像,以及各种风筝的图像。本实施例的few-shot查询集如一张鸟类的图像,通过S1-S4操作后,识别出该张图像为鸟。然后进入S5,进行人工矫正该张图像实际分类结果,如果两次识别同样得到分类结果为鸟,矫正将矫正数据和该张图像加入微调数据集和few-shot推理集中;如果当分类结果为风筝时,人工矫正该张图像的分类结果得到矫正数据为鸟时,将该张图像的分类替代成鸟再加入微调数据集和few-shot推理集中。
图像和标注依然采用如下结构进行存放:
/dataset
|——images
|——train.csv
|——test.csv
该基于元学习的图像识别持续学习方法通过采用元学习进行迁移学习,通过大量容易采集数据得到预训练模型,迁移到少样本的实际使用场景图像中对其进行分类。而且本发明采用持续在线学习的模式,在使用过程中,对于分类结果定期人工矫正,修正后的数据自动加入微调数据集,快速迭代模型。而且持续扩大微调数据集,使few-shot模型具有更加全面的认知能力,当持续学习时间越长时,积累的实际使用场景图像越多,其识别准确率则越高。
最后应当说明的是,以上实施例仅用以说明本发明的技术方案而非对本发明保护范围的限制,尽管参照较佳实施例对本发明作了详细说明,本领域的普通技术人员应当理解,可以对本发明技术方案进行修改或者等同替换,而不脱离本发明技术方案的实质和范围。
Claims (7)
1.一种基于元学习的图像识别持续学习方法,其特征在于:基于迁移学习方式用实际使用场景图像的微调数据集训练预训练模型,得到微调模型,再调整微调模型架构得到few-shot模型;然后使用few-shot模型推理待预测图像,得到分类结果;再对分类结果进行人工矫正,将人工矫正的矫正数据和待预测图像加入所述微调数据集中,从而实现持续学习;
通过如下步骤进行:
S1、构造实际使用场景图像的微调数据集,然后通过所述微调数据集对预训练模型进行多次训练,得到微调模型;
S2、调整所述微调模型的模型架构,得到few-shot模型;
S3、构建few-shot推理集和待预测图像的few-shot查询集;
S4、通过所述few-shot模型对所述few-shot推理集和所述few-shot查询集分别进行推理,得到few-shot查询集的特征和few-shot推理集的特征,然后计算few-shot查询集的特征与few-shot推理集的特征之间的相似度,得到待预测图像的分类结果;
S5、对所述S4得到的分类结果进行人工矫正得到矫正数据,分别将矫正数据和待预测图像均加入至所述few-shot推理集和所述微调数据集中;
所述S1通过如下步骤进行:
S1.1、采集实际使用场景图像并标注类别;
S1.2、对所述S1.1得到的图像进行预处理,得到微调数据集;
S1.3、根据所述S1.2得到的微调数据集对所述预训练模型进行多次训练,每次训练对应得到一个评估指标loss'第二测试,并对所有评估指标loss'第二测试求均值,得到loss第二均值,对loss第二均值反向更新得到权重θ第二均值,然后采用梯度下降法得到最优的微调模型;
所述预训练模型通过如下步骤获得:
L1、构造元学习的骨干模型和采集多个类别的图像并标注类别;
L2、对所述L1得到的图像进行预处理,得到预训练数据集;
L3、根据所述L2得到的预训练数据集对所述L1得到的骨干模型进行多次训练,每次训练对应得到一个评估指标loss'第一测试,并对所有评估指标loss'第一测试求均值,得到loss第一均值,对loss第一均值反向更新得到权重θ第一均值,然后采用梯度下降法得到最优的预训练模型。
2.根据权利要求1所述的基于元学习的图像识别持续学习方法,其特征在于:所述S2具体是将所述微调模型中隐藏层直接作为输出长度E嵌入特征的图像特征采集器,最终得到所述few-shot模型。
3.根据权利要求2所述的基于元学习的图像识别持续学习方法,其特征在于,所述L3的每次训练方法均通过如下步骤进行:
A1、从所述预训练数据集中的训练集中随机选出图像构建任务数据的task第一训练,且所述task第一训练设置有第一训练支持集和第一训练查询集,所述第一训练支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第一训练查询集为Q张查询目标图像;
A2、使用所述骨干模型分别对所述A1得到的第一训练支持集和第一训练查询集进行推理,得到第一训练分类结果,然后通过交叉熵损失方法得到loss第一训练;
A3、根据所述A2得到的loss第一训练反向更新,得到权重θ第一训练;
A4、从所述预训练数据集中的测试集中构建任务数据task第一测试,且所述task第一测试设置有第一测试支持集和第一测试查询集,所述第一测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第一测试查询集为Q张查询目标图像;
A5、使用所述A3得到的权重θ第一训练对应的骨干模型对所述A4得到的第一测试支持集和第一测试查询集进行推理,得到测试分类结果,然后通过交叉熵损失方法得到loss'第一测试。
4.根据权利要求3所述的基于元学习的图像识别持续学习方法,其特征在于,所述S1.3的每次训练方法均通过如下步骤进行:
B1、从所述微调数据集中的训练集中构建任务数据的task第二训练,且所述task第二训练设置有第二训练支持集和第二训练查询集,所述第二训练支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第二训练查询集为Q张查询目标图像;
B2、使用所述预训练模型分别对所述B1得到的第二训练支持集和第二训练查询集进行推理,得到第二训练分类结果,然后通过交叉熵损失方法得到loss第二训练;
B3、根据所述B2得到的loss第二训练反向更新,得到权重θ第二训练;
B4、从所述微调数据集中的测试集中构建任务数据task第二测试,且所述task第二测试设置有第二测试支持集和第二测试查询集,所述第二测试支持集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数,所述第二测试查询集为Q张查询目标图像;
B5、使用所述B3得到的权重θ第二训练对应的预训练模型对所述B4得到的第二测试支持集和第二测试查询集进行推理,得到第二测试分类结果,然后通过交叉熵损失方法得到loss'第二测试。
5.根据权利要求1所述的基于元学习的图像识别持续学习方法,其特征在于:所述微调数据集与所述few-shot推理集相同或不相同;
所述few-shot推理集为N-ways K-shots形式数据集,N为图像类别数量,K为每个类别的图像张数;
所述few-shot查询集为1张;
所述骨干模型为图像分类模型或目标检测模型。
6.根据权利要求5所述的基于元学习的图像识别持续学习方法,其特征在于,所述S4通过如下步骤进行:
S4.1、通过所述few-shot模型对所述few-shot推理集和所述few-shot查询集分别进行推理,即对N*K+1张图像分别进行推理,得到N*K*E的支持集特征矩阵,然后对支持集特征矩阵中每个类别的特征求均值并归一化,得到N*E的支持集平均特征矩阵;
S4.2、将N*E的支持集平均特征矩阵与查询集的图像的1*E特征进行相似度计算,得到查询集的图像分类结果。
7.根据权利要求1所述的基于元学习的图像识别持续学习方法,其特征在于:所述L1的训练轮数为10000~100000,batch_size为2~5;
所述S1的训练轮数为3~50,batch_size为2~5;
在所述L2和所述S1.2中的预处理方法均为将每张图像的长边统一缩放到size并裁剪为size×size的尺寸;
所述骨干模型为VIT模型;
在所述微调数据集和所述预训练数据集中,训练集的图像数量:测试集的图像数量=(1-x):x,x为0.1~0.3,且训练集的图像类别与测试集的图像类别均不相同。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311719529.4A CN117422960B (zh) | 2023-12-14 | 2023-12-14 | 一种基于元学习的图像识别持续学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311719529.4A CN117422960B (zh) | 2023-12-14 | 2023-12-14 | 一种基于元学习的图像识别持续学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117422960A CN117422960A (zh) | 2024-01-19 |
CN117422960B true CN117422960B (zh) | 2024-03-26 |
Family
ID=89530449
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311719529.4A Active CN117422960B (zh) | 2023-12-14 | 2023-12-14 | 一种基于元学习的图像识别持续学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117422960B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110569886A (zh) * | 2019-08-20 | 2019-12-13 | 天津大学 | 一种双向通道注意力元学习的图像分类方法 |
CN114118092A (zh) * | 2021-12-03 | 2022-03-01 | 东南大学 | 一种快速启动的交互式关系标注与抽取框架 |
WO2022088677A1 (zh) * | 2020-10-26 | 2022-05-05 | 北京百度网讯科技有限公司 | 建立区域热度预测模型、区域热度预测的方法及装置 |
CN115564987A (zh) * | 2022-09-16 | 2023-01-03 | 华中科技大学 | 一种基于元学习的图像分类模型的训练方法及应用 |
-
2023
- 2023-12-14 CN CN202311719529.4A patent/CN117422960B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110569886A (zh) * | 2019-08-20 | 2019-12-13 | 天津大学 | 一种双向通道注意力元学习的图像分类方法 |
WO2022088677A1 (zh) * | 2020-10-26 | 2022-05-05 | 北京百度网讯科技有限公司 | 建立区域热度预测模型、区域热度预测的方法及装置 |
CN114118092A (zh) * | 2021-12-03 | 2022-03-01 | 东南大学 | 一种快速启动的交互式关系标注与抽取框架 |
CN115564987A (zh) * | 2022-09-16 | 2023-01-03 | 华中科技大学 | 一种基于元学习的图像分类模型的训练方法及应用 |
Non-Patent Citations (1)
Title |
---|
Meta-Baseline: Exploring Simple Meta-Learning for Few-Shot Learning;Yinbo Chen et al.;2021 IEEE/CVF International Conference on Computer Vision (ICCV);20211231;第9042-9051页 * |
Also Published As
Publication number | Publication date |
---|---|
CN117422960A (zh) | 2024-01-19 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110516596B (zh) | 基于Octave卷积的空谱注意力高光谱图像分类方法 | |
CN110188227B (zh) | 一种基于深度学习与低秩矩阵优化的哈希图像检索方法 | |
CN106845401B (zh) | 一种基于多空间卷积神经网络的害虫图像识别方法 | |
CN115018021B (zh) | 基于图结构与异常注意力机制的机房异常检测方法及装置 | |
CN106991666B (zh) | 一种适用于多尺寸图片信息的病害图像识别方法 | |
CN111079784B (zh) | 基于卷积神经网络的烘烤过程中烤烟烘烤阶段识别方法 | |
CN113128620B (zh) | 一种基于层次关系的半监督领域自适应图片分类方法 | |
CN108021947A (zh) | 一种基于视觉的分层极限学习机目标识别方法 | |
CN106444379A (zh) | 基于物联网推荐的智能烘干远程控制方法及系统 | |
CN110807760A (zh) | 一种烟叶分级方法及系统 | |
CN112686376A (zh) | 一种基于时序图神经网络的节点表示方法及增量学习方法 | |
CN111239137B (zh) | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 | |
CN110990784A (zh) | 一种基于梯度提升回归树的烟支通风率预测方法 | |
CN110598848A (zh) | 一种基于通道剪枝的迁移学习加速方法 | |
CN113780242A (zh) | 一种基于模型迁移学习的跨场景水声目标分类方法 | |
CN116089883B (zh) | 用于提高已有类别增量学习新旧类别区分度的训练方法 | |
CN115115830A (zh) | 一种基于改进Transformer的家畜图像实例分割方法 | |
WO2023231204A1 (zh) | 一种基于 ics-bp 神经网络的传感器物理量回归方法 | |
CN112116002A (zh) | 一种检测模型的确定方法、验证方法和装置 | |
Gu et al. | No-reference image quality assessment with reinforcement recursive list-wise ranking | |
CN110163224B (zh) | 一种可在线学习的辅助数据标注方法 | |
CN117422960B (zh) | 一种基于元学习的图像识别持续学习方法 | |
CN116452904B (zh) | 图像美学质量确定方法 | |
CN109165587A (zh) | 智能图像信息抽取方法 | |
CN112183292B (zh) | 基于低空遥感的大田作物干旱表型提取与抗旱性评估方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |