CN112200262A - 支持多任务和跨任务的小样本分类训练方法及装置 - Google Patents
支持多任务和跨任务的小样本分类训练方法及装置 Download PDFInfo
- Publication number
- CN112200262A CN112200262A CN202011133629.5A CN202011133629A CN112200262A CN 112200262 A CN112200262 A CN 112200262A CN 202011133629 A CN202011133629 A CN 202011133629A CN 112200262 A CN112200262 A CN 112200262A
- Authority
- CN
- China
- Prior art keywords
- task
- training
- class
- samples
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 183
- 238000000034 method Methods 0.000 title claims abstract description 57
- 238000012545 processing Methods 0.000 claims abstract description 33
- 238000005070 sampling Methods 0.000 claims description 43
- 230000006870 function Effects 0.000 claims description 42
- 238000006243 chemical reaction Methods 0.000 claims description 8
- 238000005516 engineering process Methods 0.000 abstract description 8
- 238000004088 simulation Methods 0.000 abstract 1
- 239000000523 sample Substances 0.000 description 112
- 238000012360 testing method Methods 0.000 description 10
- 238000010586 diagram Methods 0.000 description 7
- 238000013135 deep learning Methods 0.000 description 6
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 239000000126 substance Substances 0.000 description 3
- 230000001133 acceleration Effects 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000007429 general method Methods 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 241000282412 Homo Species 0.000 description 1
- 206010039203 Road traffic accident Diseases 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 238000003709 image segmentation Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000001988 toxicity Effects 0.000 description 1
- 231100000419 toxicity Toxicity 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请公开了一种支持多任务和跨任务的小样本分类训练方法及装置,所述方法包括:形式化类比,将小样本分类问题中的分类任务形式化成标准分类问题中的样本,并将小样本分类的目标形式化成给定大量任务样本的情况下学习一种任务解决器(能够估计任务是否完成);2)模拟标准分类问题中的批训练技术(每次迭代处理每个类别中的一些样本),提出一种多任务(multi‑episode)的小样本分类训练算法(每次迭代处理多个任务类别中的一些任务样本);3)模拟标准分类问题中的预训练技术(在大规模数据上为类似小规模数据任务预先训练一个基本模型),提出一种跨任务(cross‑way)的小样本分类训练算法(在多类别(高way)问题上为小类别(低way)问题预先训练一个基本模型)。
Description
技术领域
本申请实施例涉及深度学习、图像分类和计算机视觉处理技术,尤其涉及一种支持多任务和跨任务的小样本分类训练方法及装置。
背景技术
最近几年,得益于深度学习技术的发展,大规模监督学习取得了突破性进展,尤其是图像识别领域,例如ImageNet数据集上的精度由2012年的50%多上升到80%,机器的人脸识别正确率甚至超过了人眼。但是深度学习成功的背后,是对大数据集的依赖。而现实情况下,例如,交通事故自动识别、军事敏感目标分类以及医药分子毒性测试问题中,能够获取的样本都非常稀少。此时直接使用传统的深度学习技术进行训练非常容易产生过拟合问题。
如何在小样本下开展深度学习成为新的研究课题。小样本分类旨在模拟人类能够通过少量先验数据学习新概念的能力。人类之所以具备这种能力,主要原因就在于,人类能够从已有任务中学习知识,并将其应用到未来的模型训练。受到人类学习的启发,通用的做法是采用辅助的元学习或学会学习的模式训练小样本分类,以采用基于优化的方法学习可迁移的初始条件或采用基于记忆的方法和基于距离的方法学习可迁移的特征嵌入,然后再通过学习到的优化策略fine-tune求解目标小样本分类问题或在不更新网络权重的情况下直接前向计算求解目标小样本分类问题。
这些元学习模式已经在小样本分类中取得了重要的进展。其中最为有效的元学习模式都采用了基于episode的训练框架,每个episode中包含一个小的标记支撑集以及一个对应的查询集,以模拟测试环境下的小样本设置来增加模型的泛化能力。在这种基于episode的训练框架中,小样本分类可以被看成是通过在大量分类任务(给定一个小的标记支撑集的情况下分类一个未标注查询样本)上训练学习完成分类任务的能力。基于该视角,训练数据并不受限,而是非常巨大的,例如在Omniglot数据集上考虑5-way小样本分类问题,将会有C15200种任务,因此可以将小样本分类问题视为一种标准的分类问题。
但现有技术缺乏对该视角的形式化解释,从而无法充分利用标准分类问题中的学习技术快速从海量的任务中学习一种通用的知识。
发明内容
有鉴于此,本申请实施例提供一种支持多任务和跨任务的小样本分类训练方法及装置。
根据本申请的第一方面,提供一种支持多任务和跨任务的小样本分类训练方法,包括:
将小样本分类问题中的分类任务转换为标准分类问题中的样本,并将小样本分类的目标形式转换为给定大量任务样本的情况下学习一种任务解决器;
模拟标准分类问题中的批训练处理方式,每次迭代处理每个类别中的一些样本,采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本;
模拟标准分类问题中的预训练处理方式,在大规模数据上为类似小规模数据任务预先训练一个基本模型,采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,并利用预先训练基本模型对小类别问题进行微调。
在一些实施例中,所述将小样本分类的目标形式转换为给定大量任务样本的情况下学习一种任务解决器,包括:
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数;
对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出;
考虑在所述训练集D上的K-way S-shot学习问题,定义任务类别为包含M个类别中的K个索引的类别子空间V∈T,将每个任务类别V中的任务样本G(V):={(τi,yi)}定义为支撑集SV和对应的查询集QV中查询样本(xi,yi)的组合;假定RANDOMSAMPLE(C,N)表示从集合C中无放回地随机均匀采样N个样本,则V=RANDOMSAMPLE({1,…,M},K),每个任务样本(τi,yi)被表示成(τi={SV,xi},yi),其中(xi,yi)=RANDOMSAMPLE(QV,1),
在一些实施例中,所述方法还包括:
l(fθ;τi,yi)为原型网络的损失函数时,则有:
l(fθ;τi,yi)=l(fθ;SV,xi,yi)=-logpθ(y=yi|SV,xi) (5)
在一些实施例中,所述采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本,包括:
为解决式(1)中表示的监督学习问题,批训练随机梯度下降minibatch SGD运算执行以下的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;
对于定义在上述训练集D上的标准分类问题,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt;
对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt;
上述采样方法的一种路径是:首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为minibatch;
定义一个episode为一对支撑集SV和查询集QV,则采样一个episode可以被认为是随机从T中采样一个任务类别V,然后再从这个采样的任务类别V中采样KQ个具有相同支撑集SV的任务样本。因此,上述采样方法的一种具体实现是每次迭代采样多个episode。定义E-episode训练为使用E个episode的训练策略,则其中是随机采样的episode,Ve,e=1,…,E是从T中随机采样的E个任务类别。
在一些实施例中,所述采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,包括:
假设有另外一个数据分布Dpre,则能够将求解式(6)的初始值θ0设置为:
在标准分类问题中,通常将集合Dpre设置成具有大规模数据的集合,也就是|D|<|Dpre|;
根据本申请的第二方面,提供一种支持多任务和跨任务的小样本分类训练装置,包括:
转换单元,用于将小样本分类问题中的分类任务转换为标准分类问题中的样本,并将小样本分类的目标形式转换为给定大量任务样本的情况下学习一种任务解决器;
迭代处理单元,用于模拟标准分类问题中的批训练处理方式,每次迭代处理每个类别中的一些样本,采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本;
预训练单元,用于模拟标准分类问题中的预训练处理方式,在大规模数据上为类似小规模数据任务预先训练一个基本模型,采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,并利用预先训练基本模型对小类别问题进行微调。
在一些实施例中,所述转换单元,还用于:
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数;
对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出;
考虑在所述训练集D上的K-way S-shot学习问题,定义任务类别为包含M个类别中的K个索引的类别子空间V∈T,将每个任务类别V中的任务样本G(V):={(τi,yi)}定义为支撑集SV和对应的查询集QV中查询样本(xi,yi)的组合;假定RANDOMSAMPLE(C,N)表示从集合C中无放回地随机均匀采样N个样本,则V=RANDOMSAMPLE({1,…,M},K),每个任务样本(τi,yi)被表示成(τi={SV,xi},yi),其中(xi,yi)=RANDOMSAMPLE(QV,1),
在一些实施例中,所述转换单元,还用于:
l(fθ;τi,yi)为原型网络的损失函数时,则有:
l(fθ;τi,yi)=l(fθ;SV,xi,yi)=-logpθ(y=yi|SV,xi) (5)
在一些实施例中,所述迭代处理单元,还用于:
为解决式(1)中表示的监督学习问题,批训练随机梯度下降minibatch SGD运算执行以下的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;
对于定义在上述训练集D上的标准分类问题,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt;
对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt;
上述采样方法的一种路径是:首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为minibatch;
定义一个episode为一对支撑集SV和查询集QV,则采样一个episode可以被认为是随机从T中采样一个任务类别V,然后再从这个采样的任务类别V中采样KQ个具有相同支撑集SV的任务样本。因此,上述采样方法的一种具体实现是每次迭代采样多个episode。定义E-episode训练为使用E个episode的训练策略,则其中是随机采样的episode,Ve,e=1,…,E是从T中随机采样的E个任务类别。
在一些实施例中,所述预训练单元,还用于:
假设有另外一个数据分布Dpre,则能够将求解式(6)的初始值θ0设置为:
在标准分类问题中,通常将集合Dpre设置成具有大规模数据的集合,也就是|D|<|Dpre|;
本申请实施例从监督学习的视角给出了一种小样本分类与标准分类问题的形式化类比,进而模拟标准分类问题中的批训练(minibatch)和预训练策略,创新提出了多任务(multi-episode)和跨任务(cross-way)的小样本分类训练加速算法,能够在不损失精度的前提下提高小样本分类的收敛速度。本申请实施例提出的multi-episode训练(每次迭代处理多个任务类别中的一些任务样本)类比于标准分类问题中的minibatch训练(每次迭代处理每个类别中的一些样本)。由于multi-episode训练相比传统的one-episode训练(每次迭代处理一个任务类别中的一些样本)增加了minibatch的大小以及并行计算的程度,因此multi-episode训练相比one-episode训练能够加速目标小样本分类任务的收敛。另外,由于multi-episode训练能够很好地缓解one-episode训练中minibatch选择过程的不平衡的任务类别采样,因此multi-episode训练能够在不改变网络架构的情况下对目标小样本分类任务获得优于one-episode的精度性能。本申请实施例提出的cross-way训练(在多类别(高way)问题上为少类别(低way)问题预训练一个基本模型)类比于标准分类问题中的预训练(在相似的大规模数据集上训练一个基本模型),例如ImageNet预训练。由于采用高way训练时,每个episode中将有更多的数据,且采用高way预训练能够像ImageNet预训练一样生成更普适的特征表示,因此在高way上的小样本分类问题的预训练相比低way上的目标小样本分类问题的训练收敛更快,且能改善在目标小样本分类问题上的测试精度。
附图说明
图1为本申请实施例提供的3-way 1-shot分类问题的任务类别和任务样本示例示意图;
图2为本申请实施例提供的针对3-way 5-shot分类问题每次迭代的one-episode和multi-episode训练策略对比示意图;
图3为本申请实施例提供的采用5-way 5shot分类问题对3-way 5-shot分类问题进行预训练的cross-way训练策略示意图;
图4为本申请实施例提供的支持多任务和跨任务的小样本分类训练装置的组成结构示意图。
具体实施方式
为了解决将深度学习技术应用于小样本分类时出现的缺陷,一种通用的做法是采用一种辅助的元学习或学会学习的模式训练小样本分类,以学习一种可迁移的好的初始条件或特征嵌入,然后再通过学习到的优化策略fine-tune目标小样本分类问题或在不更新网络权重的情况下直接前向计算求解目标小样本分类问题。这些元学习模式已经在小样本分类中取得了重要的进展。其中最为有效的元学习模式都采用了基于episode的训练框架,每个episode中包含一个小的标记支撑集以及一个对应的查询集,以模拟测试环境下的小样本设置,从而增加模型的泛化能力。在这种基于episode的训练框架中,小样本分类可以被看成是通过在大量分类任务上训练学习完成分类任务的能力。基于该视角,训练数据并不受限,而是非常巨大的,因此可以将小样本分类问题视为一种标准的大数据集分类问题。但是现有的小样本分类方法缺乏与标准的分类问题的形式化类比,因此无法有效利用标准分类问题中的学习技术改善学习效率。针对该问题,本申请实施例首先从监督学习的视角给出了一种小样本分类与标准分类问题的形式化类比,进而提出了对应于标准分类问题中批训练和预训练策略的多任务和跨任务的小样本分类训练加速算法,包括:
1)形式化类比
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数。
标准分类。对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出。
小样本分类。考虑在所述训练集D上的K-way S-shot学习问题。基于episode的训练机制的目标是:通过在数据集D上训练,为一个具有不同类别的数据集Dtest生成分类器。基于episode的训练机制背后的思想是利用D中大量的标记样本模拟测试环境下的小样本设置。具体地,模型将在K-way S-shot的episode上进行训练,每个episode的构建方法如下,首先从D中采样一个具有K个类别的类别子集V,然后生成一个包含KS个样本(V指定的K个类别中每个类别S个样本)的支撑集SV以及一个包含指定K个类别中剩余样本的查询集QV。假定RANDOMSAMPLE(C,N)表示从集合C中无放回地随机均匀采样N个样本,则V=RANDOMSAMPLE({1,…,M},K),每个任务样本(τi,yi)被表示成(τi={SV,xi},yi),其中(xi,yi)=RANDOMSAMPLE(QV,1),Q是指定K个类别中每个类别的查询样本的数目。在episode上的训练通过给模型输入支撑集SV更新参数以最小化查询集QV中样本的预测损失完成。
受到基于episode的训练机制的启发,本申请实施例定义一个任务类别为一个包含M个类别中的K个索引的类别子空间V∈T。然后将每个任务类别V中的任务样本G(V):={(τi,yi)}定义成一个支撑集SV和对应的查询集QV中一个查询样本(xi,yi)的组合。具体地,每个任务样本(τi,yi)可以被表示成(τi={SV,xi},yi),其中(xi,yi)=RANDOMSAMPLE(QV,1)。因此,任务类别的总数量|T|为每个任务类别V能够产生个任务样本。图1展示了在一个具有4个类别,每个类别有6个样本的数据集上的3-way 1-shot分类问题的任务类别和任务样本示例。根据任务类别和任务样本的定义,任务类别的总数量为每个任务类别有个任务样本。
基于上述任务类别和任务样本的形式化定义,小样本分类问题可以像标准分类问题一样被形式化成式(1)表示的监督学习问题。特别地,标准分类问题的目标是在给定大量样本{(xi,yi)}的情况下学习一种分类器fθ(fθ(x)能够估计样本x的类别),小样本分类的目标是在给定大量任务样本{(τi,yi)}的情况下学习一种任务解决器fθ(fθ(τ)能够估计任务τ是否完成)。为了更加清晰地形式化小样本分类问题,本申请实施例中,采用原型网络的损失函数对l(fθ;τi,yi)进行示例说明。
原型网络是一个虽然简单但却具有优越性能的小样本分类模型。该方法使用支撑集SV来提取每个类别的原型向量,并根据查询集QV中样本与每个类别原型的距离对样本进行分类。
l(fθ;τi,yi)=l(fθ;SV,xi,yi)=-logpθ(y=yi|SV,xi) (5)
从而原型网络的整体训练可以通过最小化每个episode的所有查询样本的均值损失,并对每个episode执行一次梯度下降更新完成。原型网络的泛化性能可以通过在测试episode上测量得到,其中测试episode中的图像都来源于Dtest而非D。对于每个测试episode,原型网络使用支撑集SV生成的预测器将每个查询样本xi分类到最可能的类别
不难看出,在给定τi={SV,xi}的情况下,式(5)表示的损失函数与式(2)表示的标准分类问题中的交叉熵损失具有很好的对应关系。
2)批训练
监督学习。本申请实施例从随机梯度下降(Stochastic gradient descent,SGD)的角度讨论批(minibatch)训练。为了解决式(1)中表示的监督学习问题,minibatch SGD执行下面的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;而Minibatch SGD算法非常高效。
标准分类。对于定义在上述训练集D上的标准分类问题,形式化来说,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt。假设|Bt|=100,M=10,则从概率的角度,Bt将从每个Dj中采样大约10个样本。
小样本分类。对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt。直观上,需要显示地生成Df的所有任务样本,然后再从Df中均匀采样一个minibatch。但是,根据形式化类比中的定义,Df中所有任务样本的数量是非常巨大的,因此显示地生成Df几乎不太可能,会非常地耗时,同时也对内存容量提出了极大的要求。基于上述考虑,本申请实施例提出了多任务(multi-episode)的训练策略,即首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为一个minibatch。基于该视角,匹配网络(Matching Nets)中提出的基于episode的训练机制,每次迭代采样一个包含一对支撑集SV和查询集QV的episode的方法,可以被认为是随机从T中只采样一个任务类别V,然后再从这个采样的任务类别V中采样KQ个具有相同支撑集SV的任务样本作为Bt,即显然,这个方法不是一个合理的选择,因为倘若Bt是从Df中随机均匀采样得到的,则Bt中的任务样本几乎不可能全部都属于同一个任务类别。本申请实施例提出的multi-episode训练就是通过采用多个episode来构建Bt以缓解上述问题。定义E-episode训练为使用E个episode的训练策略,即其中是一个随机采样的episode,Ve,e=1,…,E是从T中随机采样的E个任务类别。图2展示了针对3-way 5-shot分类问题每次迭代的one-episode和multi-episode训练策略对比示意图。其中支撑样本、查询样本以及原型分别采用实心、空心、灰心形状表示。
3)预训练
监督学习。迭代求解式(1)表示的监督学习问题中的另一个关键点是式(6)中的初始值θ0。预训练给出了一种设定θ0的方法,也就是通过求解另一个构建在一个相似或更加复杂的数据分布之上的监督学习问题来生成θ0。具体地,假设有另外一个数据分布Dpre,则可以将求解式(6)的初始值θ0设置为:
值得说明的是,预训练中的网络并不需要与目标监督学习问题一致。对于一些图像分类问题,可以选择性地预训练网络中的一些特定层。
标准分类。在标准分类问题中,通常将集合Dpre设置成一个具有大规模数据的集合,也就是|D|<|Dpre|。最为出名的预训练方法就是应用于计算机视觉任务的ImageNet预训练。ImageNet预训练已经在多种机器学习任务上取得了成功应用,例如物体检测和图像分割任务。最近的研究表明,ImageNet预训练可以加速收敛,但并不一定能提高最终的收敛精度。
小样本分类。类比于标准分类问题中采用大规模数据进行预训练的思路,本申请实施例提出采用S-shot学习问题对K-way S-shot学习问题进行预训练,其中因为显而易见,任务类别数目任务样本总数本申请实施例将上述用于小样本分类的预训练策略命名为跨任务(cross-way)训练策略。图3展示了采用5-way 5shot分类问题对3-way 5-shot分类问题进行预训练的cross-way训练策略示意图。其中支撑样本、查询样本以及原型分别采用实心、空心、灰心形状表示。
图4为本申请实施例提供的支持多任务和跨任务的小样本分类训练装置的组成结构示意图,如图4所示,本申请实施例的支持多任务和跨任务的小样本分类训练装置包括:
转换单元40,用于将小样本分类问题中的分类任务转换为标准分类问题中的样本,并将小样本分类的目标形式转换为给定大量任务样本的情况下学习一种任务解决器;
迭代处理单元41,用于模拟标准分类问题中的批训练处理方式,每次迭代处理每个类别中的一些样本,采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本;
预训练单元42,用于模拟标准分类问题中的预训练处理方式,在大规模数据上为类似小规模数据任务预先训练基本模型,采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,并利用预先训练基本模型对小类别问题进行微调。
本申请实施例中,所述转换单元40,还用于:
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数;
对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出;
考虑在所述训练集D上的K-way S-shot学习问题,定义任务类别为包含M个类别中的K个索引的类别子空间V∈T,将每个任务类别V中的任务样本G(V):={(τi,yi)}定义为支撑集SV和对应的查询集QV中查询样本(xi,yi)的组合;假定RANDOMSAMPLE(C,N)表示从集合C中无放回地随机均匀采样N个样本,则V=RANDOMSAMPLE({1,…,M},K),每个任务样本(τi,yi)被表示成(τi={SV,xi},yi),其中(xi,yi)=RANDOMSAMPLE(QV,1),
述转换单元40,还用于:
l(fθ;τi,yi)为原型网络的损失函数时,则有:
l(fθ;τi,yi)=l(fθ;SV,xi,yi)=-logpθ(y=yi|SV,xi) (5)
本申请实施例中,所述迭代处理单元41,还用于:
为解决式(1)中表示的监督学习问题,批训练随机梯度下降minibatch SGD运算执行以下的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;
对于定义在上述训练集D上的标准分类问题,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt;
对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt;
上述采样方法的一种路径是:首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为minibatch;
定义一个episode为一对支撑集SV和查询集QV,则采样一个episode可以被认为是随机从T中采样一个任务类别V,然后再从这个采样的任务类别V中采样KQ个具有相同支撑集SV的任务样本。因此,上述采样方法的一种具体实现是每次迭代采样多个episode。定义E-episode训练为使用E个episode的训练策略,则其中是随机采样的episode,Ve,e=1,…,E是从T中随机采样的E个任务类别。
本申请实施例中,所述预训练单元42,还用于:
假设有另外一个数据分布Dpre,则能够将求解式(6)的初始值θ0设置为:
在标准分类问题中,通常将集合Dpre设置成具有大规模数据的集合,也就是|D|<|Dpre|;
在本公开实施例中,图4示出的支持多任务和跨任务的小样本分类训练装置中各个模块及单元执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
应理解,说明书通篇中提到的“一个实施例”或“一实施例”意味着与实施例有关的特定特征、结构或特性包括在本发明的至少一个实施例中。因此,在整个说明书各处出现的“在一个实施例中”或“在实施例中”未必一定指相同的实施例。此外,这些特定的特征、结构或特性可以任意适合的方式结合在一个或多个实施例中。应理解,在本发明的各种实施例中,上述各过程的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
在本申请所提供的几个实施例中,应该理解到,所揭露的设备和方法,可以通过其它的方式实现。以上所描述的设备实施例仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,如:多个单元或组件可以结合,或可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的各组成部分相互之间的耦合、或直接耦合、或通信连接可以是通过一些接口,设备或单元的间接耦合或通信连接,可以是电性的、机械的或其它形式的。
上述作为分离部件说明的单元可以是、或也可以不是物理上分开的,作为单元显示的部件可以是、或也可以不是物理单元;既可以位于一个地方,也可以分布到多个网络单元上;可以根据实际的需要选择其中的部分或全部单元来实现本实施例方案的目的。
另外,在本发明各实施例中的各功能单元可以全部集成在一个处理单元中,也可以是各单元分别单独作为一个单元,也可以两个或两个以上单元集成在一个单元中;上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。
以上所述,仅为本发明的实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以所述权利要求的保护范围为准。
Claims (10)
1.一种支持多任务和跨任务的小样本分类训练方法,其特征在于,所述方法包括:
将小样本分类问题中的分类任务转换为标准分类问题中的样本,并将小样本分类的目标形式转换为给定大量任务样本的情况下学习的任务解决器;
模拟标准分类问题中的批训练处理方式,每次迭代处理每个类别中的一些样本,采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本;
模拟标准分类问题中的预训练处理方式,在大规模数据上为类似小规模数据任务预先训练基本模型,采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,并利用预先训练基本模型对小类别问题进行微调。
2.根据权利要求1所述的方法,其特征在于,所述将小样本分类的目标形式转换为给定大量任务样本的情况下学习的任务解决器,包括:
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数;
对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出;
4.根据权利要求1所述的方法,其特征在于,所述采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本,包括:
为解决式(1)中表示的监督学习问题,批训练随机梯度下降minibatch SGD运算执行以下的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;
对于定义在上述训练集D上的标准分类问题,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt;
对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt;
上述采样方法的一种路径是:首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为minibatch;
6.一种支持多任务和跨任务的小样本分类训练装置,其特征在于,所述装置包括:
转换单元,用于将小样本分类问题中的分类任务转换为标准分类问题中的样本,并将小样本分类的目标形式转换为给定大量任务样本的情况下学习一种任务解决器;
迭代处理单元,用于模拟标准分类问题中的批训练处理方式,每次迭代处理每个类别中的一些样本,采用多任务的小样本分类训练算法,每次迭代处理多个任务类别中的一些任务样本;
预训练单元,用于模拟标准分类问题中的预训练处理方式,在大规模数据上为类似小规模数据任务预先训练基本模型,采用跨任务的小样本分类训练算法,在多类别问题上为小类别问题预先训练基本模型,并利用预先训练基本模型对小类别问题进行微调。
7.根据权利要求6所述的装置,其特征在于,所述转换单元,还用于:
其中,fθ是具有参数集θ的特定网络,l是给定的损失函数;
对于定义在具有M个类别、每个类别具有H个样本的训练集上的标准分类问题,其中是D维的输入向量,yi∈{1,2,…,M}是类别标号,Dj代表集合D中所有yi=j的样本(xi,yi)集合,fθ是需要学习的分类器,通用的损失函数l(fθ;xi,yi)是交叉熵,如下式(2)所示:
其中fθ(xi)j代表fθ(xi)的第j个输出;
9.根据权利要求6所述的装置,其特征在于,所述迭代处理单元,还用于:
为解决式(1)中表示的监督学习问题,批训练随机梯度下降minibatch SGD运算执行以下的更新策略:
其中α是学习速率,t是迭代步数,Bt是从整个数据集D中随机采样的一个minibatch;
对于定义在上述训练集D上的标准分类问题,在每个式(6)表示的训练步骤中,将从D中随机均匀采样一些样本作为Bt;
对于具有个任务类别,每个任务类别V∈T具有个任务样本的K-way S-shot学习问题,其数据集能够定义为其中τi={SV,xi};这里,对于式(6)表示的每个训练步骤,对应小样本分类的批训练应该是从Df中随机采样多个任务样本作为Bt;
上述采样方法的一种路径是:首先从T中采样一些任务类别,然后再从这些采样的任务类别中采样一些任务样本作为minibatch;
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011133629.5A CN112200262B (zh) | 2020-10-21 | 2020-10-21 | 支持多任务和跨任务的小样本分类训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011133629.5A CN112200262B (zh) | 2020-10-21 | 2020-10-21 | 支持多任务和跨任务的小样本分类训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112200262A true CN112200262A (zh) | 2021-01-08 |
CN112200262B CN112200262B (zh) | 2024-04-30 |
Family
ID=74010569
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011133629.5A Active CN112200262B (zh) | 2020-10-21 | 2020-10-21 | 支持多任务和跨任务的小样本分类训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112200262B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113468869A (zh) * | 2021-07-12 | 2021-10-01 | 北京有竹居网络技术有限公司 | 一种语义分析模型生成方法、语义分析方法、装置及设备 |
CN113837379A (zh) * | 2021-09-14 | 2021-12-24 | 上海商汤智能科技有限公司 | 神经网络的训练方法及装置、计算机可读存储介质 |
CN113887227A (zh) * | 2021-09-15 | 2022-01-04 | 北京三快在线科技有限公司 | 一种模型训练与实体识别方法及装置 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108764281A (zh) * | 2018-04-18 | 2018-11-06 | 华南理工大学 | 一种基于半监督自步学习跨任务深度网络的图像分类方法 |
CN109800811A (zh) * | 2019-01-24 | 2019-05-24 | 吉林大学 | 一种基于深度学习的小样本图像识别方法 |
CN110490227A (zh) * | 2019-07-09 | 2019-11-22 | 武汉理工大学 | 一种基于特征转换的少样本图像分类方法 |
US20200034694A1 (en) * | 2018-07-25 | 2020-01-30 | Element Ai Inc. | Multiple task transfer learning |
US20200143209A1 (en) * | 2018-11-07 | 2020-05-07 | Element Ai Inc. | Task dependent adaptive metric for classifying pieces of data |
CN111767949A (zh) * | 2020-06-28 | 2020-10-13 | 华南师范大学 | 一种基于特征和样本对抗共生的多任务学习方法及其系统 |
-
2020
- 2020-10-21 CN CN202011133629.5A patent/CN112200262B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108764281A (zh) * | 2018-04-18 | 2018-11-06 | 华南理工大学 | 一种基于半监督自步学习跨任务深度网络的图像分类方法 |
US20200034694A1 (en) * | 2018-07-25 | 2020-01-30 | Element Ai Inc. | Multiple task transfer learning |
US20200143209A1 (en) * | 2018-11-07 | 2020-05-07 | Element Ai Inc. | Task dependent adaptive metric for classifying pieces of data |
CN109800811A (zh) * | 2019-01-24 | 2019-05-24 | 吉林大学 | 一种基于深度学习的小样本图像识别方法 |
CN110490227A (zh) * | 2019-07-09 | 2019-11-22 | 武汉理工大学 | 一种基于特征转换的少样本图像分类方法 |
CN111767949A (zh) * | 2020-06-28 | 2020-10-13 | 华南师范大学 | 一种基于特征和样本对抗共生的多任务学习方法及其系统 |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113468869A (zh) * | 2021-07-12 | 2021-10-01 | 北京有竹居网络技术有限公司 | 一种语义分析模型生成方法、语义分析方法、装置及设备 |
CN113837379A (zh) * | 2021-09-14 | 2021-12-24 | 上海商汤智能科技有限公司 | 神经网络的训练方法及装置、计算机可读存储介质 |
CN113887227A (zh) * | 2021-09-15 | 2022-01-04 | 北京三快在线科技有限公司 | 一种模型训练与实体识别方法及装置 |
CN113887227B (zh) * | 2021-09-15 | 2023-05-02 | 北京三快在线科技有限公司 | 一种模型训练与实体识别方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN112200262B (zh) | 2024-04-30 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11361225B2 (en) | Neural network architecture for attention based efficient model adaptation | |
CN112200262A (zh) | 支持多任务和跨任务的小样本分类训练方法及装置 | |
CN110321957B (zh) | 融合三元组损失和生成对抗网络的多标签图像检索方法 | |
CN109583322A (zh) | 一种人脸识别深度网络训练方法和系统 | |
CN111639679B (zh) | 一种基于多尺度度量学习的小样本学习方法 | |
CN110209823A (zh) | 一种多标签文本分类方法及系统 | |
CN109299342A (zh) | 一种基于循环生成式对抗网络的跨模态检索方法 | |
Wang et al. | A network intrusion detection method based on deep multi-scale convolutional neural network | |
CN108984745A (zh) | 一种融合多知识图谱的神经网络文本分类方法 | |
CN109522942A (zh) | 一种图像分类方法、装置、终端设备和存储介质 | |
CN109446517A (zh) | 指代消解方法、电子装置及计算机可读存储介质 | |
CN104239858A (zh) | 一种人脸特征验证的方法和装置 | |
CN113887643B (zh) | 一种基于伪标签自训练和源域再训练的新对话意图识别方法 | |
CN114170332A (zh) | 一种基于对抗蒸馏技术的图像识别模型压缩方法 | |
CN103093247B (zh) | 一种植物图片的自动分类方法 | |
CN114444600A (zh) | 基于记忆增强原型网络的小样本图像分类方法 | |
CN110298434A (zh) | 一种基于模糊划分和模糊加权的集成深度信念网络 | |
CN105335619A (zh) | 适用于高计算代价数值计算模型参数反分析的协同优化法 | |
CN106021402A (zh) | 用于跨模态检索的多模态多类Boosting框架构建方法及装置 | |
Zhou et al. | Remote sensing image transfer classification based on weighted extreme learning machine | |
CN114898136B (zh) | 一种基于特征自适应的小样本图像分类方法 | |
CN114898158A (zh) | 基于多尺度注意力耦合机制的小样本交通异常图像采集方法及系统 | |
CN110263808B (zh) | 一种基于lstm网络和注意力机制的图像情感分类方法 | |
US10733499B2 (en) | Systems and methods for enhancing computer assisted high throughput screening processes | |
Lin et al. | Applying machine learning to fine classify construction and demolition waste based on deep residual network and knowledge transfer |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |