CN111191791B - 基于机器学习模型的图片分类方法、装置及设备 - Google Patents
基于机器学习模型的图片分类方法、装置及设备 Download PDFInfo
- Publication number
- CN111191791B CN111191791B CN201911213128.5A CN201911213128A CN111191791B CN 111191791 B CN111191791 B CN 111191791B CN 201911213128 A CN201911213128 A CN 201911213128A CN 111191791 B CN111191791 B CN 111191791B
- Authority
- CN
- China
- Prior art keywords
- machine learning
- learning model
- task
- current task
- sample data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请提供了一种基于机器学习模型的图片分类方法、装置及设备,涉及人工智能技术领域。所述方法包括:确定当前任务与历史任务之间的相关性;根据相关性对历史任务的样本数据进行抽样,获得元样本数据;采用元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;采用当前任务的样本数据对训练后的机器学习模型的参数进行调整,得到适用于当前任务的机器学习模型。本申请一方面实现了对当前任务的训练样本的数据增扩;另一方面由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升模型在当前任务上的预测精度。
Description
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种基于机器学习模型的图片分类方法、装置及设备。
背景技术
随着人工智能技术的发展,机器学习也得到了广泛应用。
传统的机器学习算法建立在两个重要的前提下,即(1)训练样本与测试样本独立同分布,(2)可获得大量的标注数据。然而,在一些业务场景(如金融风控场景下的用户违约预测)中,存在如下特点:(1)任务特性随时间发生变化,新时刻的任务不能直接使用旧任务的模型;(2)每个任务只有少量的标注数据,大部分需要预测的数据无标注;(3)不同任务的场景迥异,样本分布差异较大。
因此,传统的基于大量标注数据进行训练的机器学习算法难以直接运用于上述业务场景,无法为上述业务场景提供准确可靠的机器学习模型。
发明内容
本申请实施例提供了一种基于机器学习模型的图片分类方法、装置及设备。本申请提供的技术方案如下:
一方面,本申请实施例提供了一种机器学习模型的应用方法,所述方法包括:
获取当前任务的预测样本;
调用适用于所述当前任务的机器学习模型;
通过所述机器学习模型输出所述预测样本对应的预测结果;
其中,所述机器学习模型是根据所述当前任务与历史任务之间的相关性,从所述历史任务的样本数据中抽样得到元样本数据之后,采用所述元样本数据和所述当前任务的样本数据进行训练得到的。
另一方面,本申请实施例提供了一种机器学习模型的训练方法,所述方法包括:
确定当前任务与历史任务之间的相关性;
根据所述相关性对所述历史任务的样本数据进行抽样,获得元样本数据;
采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
另一方面,本申请实施例提供了一种机器学习模型的应用装置,所述装置包括:
样本获取模块,用于获取当前任务的预测样本;
模型调用模块,用于调用适用于所述当前任务的机器学习模型;
结果输出模块,用于通过所述机器学习模型输出所述预测样本对应的预测结果;
其中,所述机器学习模型是根据所述当前任务与历史任务之间的相关性,从所述历史任务的样本数据中抽样得到元样本数据之后,采用所述元样本数据和所述当前任务的样本数据进行训练得到的。
另一方面,本申请实施例提供了一种机器学习模型的训练装置,所述装置包括:
相关性确定模块,用于确定当前任务与历史任务之间的相关性;
样本抽样模块,用于根据所述相关性对所述历史任务的样本数据进行抽样,获得元样本数据;
模型训练模块,用于采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
参数调整模块,用于采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
再一方面,本申请实施例提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现上述机器学习模型的应用方法,或者实现上述机器学习模型的训练方法。
再一方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现上述机器学习模型的应用方法,或者实现上述机器学习模型的训练方法。
还一方面,本申请实施例提供了一种计算机程序产品,所述计算机程序产品被处理器执行时,用于实现上述机器学习模型的应用方法,或者实现上述机器学习模型的训练方法。
本申请实施例提供的技术方案可以包括如下有益效果:
通过确定当前任务与历史任务之间的相关性,根据该相关性从历史任务中抽样获取元样本数据,然后采用该元样本数据和当前任务的样本数据进行模型训练,最终得到适用于当前任务的机器学习模型;一方面,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩;另一方面,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升最终训练得到的模型在当前任务上的预测精度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个实施例提供的机器学习模型的训练方法的流程图;
图2是本申请一个实施例提供的方案架构图;
图3示出了一种彩虹手写数字数据集的示意图;
图4示出了一种分类任务数据的示意图;
图5示出了一种实验所得的准确率统计结果的示意图;
图6示出了一种实验所得的注意力分量的示意图;
图7是本申请一个实施例提供的机器学习模型的应用方法的流程图;
图8是本申请一个实施例提供的机器学习模型的训练装置的框图;
图9是本申请另一个实施例提供的机器学习模型的训练装置的框图;
图10是本申请一个实施例提供的机器学习模型的应用装置的框图;
图11是本申请一个实施例提供的计算机设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、式教学习等技术。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如常见的智能家居、智能穿戴设备、虚拟助理、智能音箱、智能营销、无人驾驶、自动驾驶、无人机、机器人、智能医疗、智能客服等,相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
本申请实施例提供的方案涉及人工智能的机器学习等技术。在对本申请实施例进行介绍说明之前,首先对本申请实施例中涉及的一些术语进行介绍说明。
1、任务(task)
一个机器学习的任务,包含训练集与测试集,目的是寻找合适的方法,学习有限标注数据下的分布规律,从而对相同条件下的其他数据进行预测。
2、在线学习(Online Machine Learning)
一种机器学习方法,其目的是对按时间顺序输入的数据,实时地更新模型,并进行预测。
3、注意力机制(Attention Mechanism)
一种机器学习中计算相关性的计算方法,受人类视觉注意力的启发,可以使神经网络具有对输入数据的筛选能力。
注意力机制来源于人类视觉研究的启发,人类在对周围场景进行认知时,会有选择地关注一个局部的目标区域,给与该目标区域的信息更大的关注度,投入更多的视觉信息处理资源,从而更高效地获取有用信息,抑制不重要的信息。人类的注意力机制充分地利用了有限的思维资源,极大地提高了认知的准确性和效率。
受此启发,注意力机制被引入深度学习领域,成为一种极为有效的计算机制。注意力机制的引入,使得神经网络具有了对输入数据进行筛选的能力,能够适应地筛选出对最终目标更有利的输入,从而增加有利数据,抑制不利数据的输入,提升模型性能。从本质上来说,注意力机制是一种通过键值查询实现的相似性度量。
注意力机制有多种不同的形式,按对数据的筛选方式分,可分为硬注意力和软注意力,按权重的计算方式分,又分为点乘注意力与感知机注意力。目前,注意力机制凭借其简单的形式与显著的效果,已经被广泛运用于自然语言处理、图像描述等领域。
4、元学习(Meta Learning)
元学习也可以称为“learning to learn”,即解决学习如何学习的问题。传统的机器学习问题是基于海量数据集从头开始学习一个用于预测的数学模型,这与人类学习、积累历史经验(也称为元知识)指导新的机器学习任务的过程相差甚远。元学习则是学习不同的机器学习任务的学习训练过程,及学习如何更快更好地训练一个数学模型。
5、在线元学习(Online Meta Learning)
在线元学习是工业界常用的机器学习模型训练方法,其主要解决的问题是,在训练样本按时序提供的条件下,在每个时间节点更新模型并处理数据。与在线元学习对应的是常见的批次学习,即对整个训练数据集一起训练,求得一个最优化的模型。与批次学习相比,在线元学习更新时不需要遍历整个数据集,大大地缩短了计算的时间和算力成本,能够更高效地适应新任务的需求。
常见的在线元学习算法有贝叶斯在线元学习(Bayesian Online Learning)和跟随正则化领导(Follow The Regularized Leader,FTRL)等。
FTRL的思想是每次进行更新时,寻找到使得之前任务损失函数f之和最小的参数wt,即式1。然而,直接求解该参数十分困难,因而,寻找一个代理损失函数h,以代理函数的最优解作为当前时刻的近似解,即式2。为保证解的有效性,将所求得的解与真实解的损失差值定义为遗憾(Regret,式3),遗憾应满足式4。FTRL则在求解参数wt时加入了正则化项R(w),使得解稀疏化,即
式6。
wt=argminwht-1(w) 式2
FTML(Follow The Meta Leader,跟随元领导)是融合了元学习与在线学习的算法,其核心思想认为,元学习与在线学习都能利用先前任务的知识来帮助后续任务的学习,但元学习没有考虑任务输入顺序以及分布的变化问题,而在线学习提供了一套处理任务流的框架。因此该方法将元学习方法MAML(Model-agnostic Meta Learning,模型无关的元学习)引入到在线学习算法FTRL中,优化的参数w是网络的初始化参数,随后对当前时刻的任务做一个映射Ut(例如,一步梯度下降,即式7),使其更具有当前任务的特性,其遗憾(Regret)计算方法如式8。
其中,上述式1至式8中的t表示任务时刻,式7中的α为预先设定的常数。
在线元学习的算法,主要考虑和解决数据按照时序输入以及实时更新模型的问题,并没有针对新任务与以往任务的关系选择性地建立训练数据集,导致与新任务分布差异较大的旧任务也被应用于模型更新中,使得训练效果难以提升甚至导致负优化。
注意力机制主要用于自然语言处理、计算机视觉等领域的深度学习网络中,作为网络筛选输入数据的模块,尚无人应用于在线元学习中的抽样过程中。
针对上文背景技术部分提到的技术问题,本申请实施例提供一种技术方案,该技术方案可称为基于任务相关性抽样的在线元学习方法,通过注意力机制来计算当前任务与历史任务之间的相关性,根据相关性对历史任务的样本数据进行抽样组成新的元任务,再利用元任务进行在线元学习训练,最后使用当前任务对模型进行适应性更新。该方法针对上述业务场景的特点,充分地利用了现有任务的数据,减少达到最优模型所需的训练样本量,提升新任务模型的泛化性能,减少新任务训练的时间。
本申请实施例提供的技术方案,可以应用于数据按时序依次输入的机器学习领域。如在金融风控场景下,根据不同时刻的客户群体的任务输入,基于当前任务与历史任务之间的相关性,抽样组成元样本数据对模型进行训练,随后再采用当前任务的样本数据对模型进行适应性更新,即可快速有效地获得能够适应当前任务的客户群体的预测模型。本申请技术方案可适用于各类金融风控场景,如预测用户在支付、借贷、理财等金融业务环节中的欺诈风险,帮助银行、证券、互金等金融企业提升风险识别能力,降低企业损失。另外,本申请技术方案还可适用于内容推荐、计算机视觉处理等领域。
下面,通过几个实施例对本申请技术方案进行详细介绍说明。
请参考图1,其示出了本申请一个实施例提供的机器学习模型的训练方法的流程图。该方法各步骤的执行主体可以是计算机设备,所述计算机设备是指具备处理和存储能力的电子设备,如PC(Personal Computer,个人计算机)、服务器等。该方法可以包括如下几个步骤(101~104):
步骤101,确定当前任务与历史任务之间的相关性。
当前任务与历史任务之间的相关性,是指该当前任务与历史任务的任务特性之间的相关程度。一个任务的任务特性可以由该任务的特征信息来反映,一个任务的特征信息是该任务的具有区分度的特征,从两个不同任务的特征信息中,可以看出这两个任务之间的差别。
另外,历史任务的数量可以是一个,也可以是多个。当历史任务的数量为多个时,需要确定当前任务与每一个历史任务之间的相关性。当历史任务的数量为多个时,可以从该多个任务的样本数据中抽样获得用于训练模型的元样本数据,从而丰富样本数据的来源。
可选地,通过注意力机制确定当前任务与历史任务之间的相关性,本步骤可以包括如下几个子步骤:
1、对当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到当前任务的特征信息和各个历史任务的特征信息,n为正整数。
采用特征提取器对当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到当前任务的特征信息和各个历史任务的特征信息。特征提取器用于对任务的样本数据进行特征提取,将任务的样本数据映射到某个特征空间,获得一个具有区分度的任务的低维表示(也即任务的特征信息)。特征提取器可以是一个卷积神经网络,例如其可以包括4个卷积层和1个全连接层。当然,关于特征提取器的网络结构设计的上述介绍说明仅是示例性和解释性的,其可以结合实际情况进行设计,本申请实施例对此不作限定。
2、通过注意力网络根据当前任务的特征信息和各个历史任务的特征信息,计算注意力向量。
注意力网络可以包括第一注意力网络和第二注意力网络。通过第一注意力网络根据当前任务的特征信息,输出当前任务对应的查询向量。第一注意力网络可以是一个单层的全连接网络。通过第二注意力网络根据各个历史任务的特征信息,输出n个历史任务对应的键值矩阵。第二注意力网络也可以是一个单层的全连接网络。然后,根据查询向量和键值矩阵,计算注意力向量。例如,将查询向量和键值矩阵各个分量进行点乘运算后,经由激活函数层得到一个注意力向量。该激活函数层所采用的激活函数可以是softmax激活函数。其中,注意力向量可以是一个1×n的行向量,也可以是一个n×1的列向量。注意力向量包括n个注意力分量,第i个注意力分量用于表征当前任务与第i个历史任务之间的相关性,i为小于等于n的正整数。
示例性地,假设历史任务的数量n为5,经过上述过程计算得到的注意力向量为[0.1,0.25,0.2,0.3,0.15],表示当前任务与第1个历史任务之间的相关性为0.1,当前任务与第2个历史任务之间的相关性为0.25,当前任务与第3个历史任务之间的相关性为0.2,当前任务与第4个历史任务之间的相关性为0.3,当前任务与第5个历史任务之间的相关性为0.15。
需要说明的一点是,在确定当前任务与历史任务之间的相关性时,除了采用上文介绍的注意力机制之外,还可以采用其他相关性计算方式,如皮尔逊相关系数、最大互信息系数(Maximal Information Coefficient,MIC)、距离相关系数等算法,本申请实施例对此不作限定。
步骤102,根据相关性对历史任务的样本数据进行抽样,获得元样本数据。
在得到当前任务相对于各个历史任务的相关性之后,基于该相关性,从各个历史任务的样本数据中,分别抽取若干个样本数据,组成当前任务的元样本数据。
可选地,根据上述n个注意力分量,对n个历史任务的样本数据进行抽样,获得元样本数据;其中,从各个历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与n个注意力分量之间的比例相同。例如,注意力向量为[0.1,0.25,0.2,0.3,0.15],总共需要从历史任务中抽取100个样本数据作为元样本数据,那么从第1个历史任务中抽取10个样本数据,从第2个历史任务中抽取25个样本数据,从第3个历史任务中抽取20个样本数据,从第4个历史任务中抽取30个样本数据,从第5个历史任务中抽取15个样本数据,总共抽取100个样本数据得到元样本数据集。
步骤103,采用元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型。
初始的机器学习模型可以是未经过任何训练的模型,例如该初始的机器学习模型的参数可以随机设定或者按照经验设定;或者,初始的机器学习模型也可以是经过训练的模型,例如该初始的机器学习模型可以是适用于最近一个历史任务的机器学习模型。
在本申请实施例中,采用在线元学习的方式对该初始的机器学习模型进行训练,得到训练后的机器学习模型。可选地,本步骤可以包括如下几个子步骤:
1、根据元样本数据生成训练样本集;
例如,可以从元样本数据中选取部分样本数据,构成训练样本集。可选地,训练样本集中包含的来自于各个历史任务的样本数据的数量之间的比例,与上述注意力向量中的n个注意力分量之间的比例相同。
仍然以上述示例为例,元样本数据包括100个样本数据,从第1个至第5个历史任务中抽取的样本数据的数量依次为10、25、20、30和15,如果需要从上述元样本数据中选取60个样本数据构成训练样本集,那么训练样本集中包含的来自于第1个至第5个历史任务的样本数据的数量可以依次为6、15、12、18和9,也即从来自于第1个历史任务的10个样本数据中选取6个样本数据添加至训练样本集,从来自于第2个历史任务的25个样本数据中选取15个样本数据添加至训练样本集,从来自于第3个历史任务的20个样本数据中选取12个样本数据添加至训练样本集,从来自于第4个历史任务的30个样本数据中选取18个样本数据添加至训练样本集,从来自于第5个历史任务的15个样本数据中选取9个样本数据添加至训练样本集。
2、采用训练样本集分批次训练机器学习模型的参数;
在构建生成训练样本集之后,便可采用该训练样本集对机器学习模型进行训练,通过不断调整该机器学习模型的参数,来优化机器学习模型的表现。
可选地,采用分批次(batch)训练的方式对该机器学习模型进行训练。分批次训练有助于提升模型的训练效率。
3、计算机器学习模型的第一损失函数值;
假设在本步骤的训练过程中,机器学习模型的损失函数为第一损失函数,基于机器学习模型对训练样本的预测结果和该训练样本的标签,可以计算该第一损失函数的值,即得到第一损失函数值。
4、当第一损失函数值满足第一条件时,根据第一损失函数值计算机器学习模型的初始参数所对应的第一梯度;
第一条件可以预先设定,例如第一条件可以是第一损失函数值达到最小。当第一损失函数值满足第一条件时,根据该第一损失函数值计算机器学习模型的初始参数所对应的第一梯度。其中,机器学习模型的初始参数,即为上述初始的机器学习模型的参数。
5、根据第一梯度更新机器学习模型的初始参数,得到训练后的机器学习模型。
例如,根据该第一梯度以及机器学习模型的初始参数,计算该机器学习模型更新后的参数,从而得到训练后的机器学习模型。
另外,还可以根据元样本数据生成测试样本集,采用该测试样本集评价上述训练后的机器学习模型的准确度。如果训练后的机器学习模型的准确度不符合条件,则可以从训练样本集中重新选取训练样本对该机器学习模型进行训练,直至训练后的机器学习模型的准确度符合条件时,进入下一步流程。
在本申请实施例中,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩。另外,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,从而使得该在线元学习的过程对于当前任务更加鲁棒。
步骤104,采用当前任务的样本数据对训练后的机器学习模型的参数进行调整,得到适用于当前任务的机器学习模型。
在采用元样本数据对初始的机器学习模型进行训练,得到训练后的机器学习模型之后,进一步采用当前任务的样本数据对该训练后的机器学习模型的参数进行微调,从而使得最终训练得到的机器学习模型在当前任务上的表现更佳。
可选地,本步骤可以包括如下几个子步骤:
1、根据当前任务的样本数据,计算训练后的机器学习模型的第二损失函数值;
假设在本步骤的训练过程中,机器学习模型的损失函数为第二损失函数,基于机器学习模型对当前任务的样本数据的预测结果和该样本数据的标签,可以计算该第二损失函数的值,即得到第二损失函数值。
2、当第二损失函数值满足第二条件时,根据第二损失函数值计算训练后的机器学习模型的参数所对应的第二梯度;
第二条件可以预先设定,例如第二条件可以是第二损失函数值达到最小。当第二损失函数值满足第二条件时,根据该第二损失函数值计算训练后的机器学习模型的参数所对应的第二梯度。
3、根据第二梯度更新训练后的机器学习模型的参数,得到适用于当前任务的机器学习模型。
例如,根据该第二梯度以及训练后的机器学习模型的原始参数,计算该机器学习模型更新后的参数,从而得到适用于当前任务的机器学习模型。
综上所述,本申请实施例提供的技术方案,通过确定当前任务与历史任务之间的相关性,根据该相关性从历史任务中抽样获取元样本数据,然后采用该元样本数据和当前任务的样本数据进行模型训练,最终得到适用于当前任务的机器学习模型;一方面,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩;另一方面,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升最终训练得到的模型在当前任务上的预测精度。
结合参考图2,其示出了本申请技术方案的一个架构图,包括特征提取器21、注意力模块22、抽样器23和元分类器24。
特征提取器(Feature Extractor)21:由F表示,第t任务时刻对应的网络权重为用于对各个时刻任务Tt的样本数据Dt进行特征提取并将任务的样本数据映射到某个特征空间,获得一个具有区分度的任务的低维表示,也即任务的特征信息rt。
注意力模块(Attention Module)22:由A表示,第t任务时刻对应的网络权重为αt。用于根据不同任务的特征信息rt,计算当前任务与历史任务之间的注意力向量at。
抽样器(Sampler)23:根据注意力向量at中的各个注意力分量,对历史任务的样本数据进行抽样,组成一个新的元样本数据Dmeta。
元分类器(Meta Classifier)24:由G表示,第t任务时刻对应的网络权重为θt。采用在线元学习算法学习适用于当前任务的分类器的参数,并结合当前任务的样本数据,对该参数进行适应性更新,最终获得适用于当前任务的分类器。
结合参考图2,在获得当前任务之后,将当前任务的样本数据和各个历史任务的样本数据分别输入到特征提取器21进行特征提取,得到当前任务的特征信息和各个历史任务的特征信息。当前任务的特征信息和各个历史任务的特征信息输入到注意力模块22之后,当前任务的特征信息经过第一注意力网络(如全连接层LinearQ)得到当前任务对应的查询向量Q,各个历史任务的特征信息经过第二注意力网络(如全连接层LinearK)得到历史任务对应的键值矩阵K。查询向量Q和键值矩阵K各个分量进行点乘运算后,经由激活函数层(如softmax层)得到一个注意力向量a。采用抽样器23根据注意力向量a对历史任务的样本数据进行抽样,获得元样本数据Dmeta。之后,可以基于该元样本数据Dmeta,生成训练样本集Dtrain和测试样本集Dtest,采用训练样本集Dtrain分批次训练元分类器24的参数θ,采用测试样本集Dtest评价上述更新的参数θ,使其总损失最小。然后,对元分类器24的初始参数求梯度,根据梯度更新元分类器24的初始参数,得到训练后的元分类器24。最后,采用当前任务的特征数据对该训练后的元分类器24的参数进行微调,最终得到适用于当前任务的元分类器24。
在示例性实施例中,训练流程可以包括如下几个步骤:
步骤1,随机初始化特征提取器的参数注意力模块的参数α0,元分类器的参数θ0;
步骤2,初始化空任务池B,即B←[];
步骤3,对于每个顺序的任务输入时刻t,进行下述步骤4至步骤11的操作;
步骤4,将任务Tt加入任务池B,即B←B+[Tt];
步骤5,对于任务池中的所有任务T1,T2,…,Tt的样本数据,经一个由4层卷积层与1层全连接网络组成的特征提取器F后得到任务的特征信息r1,r2,…,rt,即
步骤6,根据任务的特征信息r1,r2,…,rt,经由一个含有两个单层全连接网络的注意力模块,rt与r1,r2,…,rt-1被分别转化为查询向量Q和键值矩阵K,查询向量Q和键值矩阵K各个分量进行点乘运算后,经由一层softmax层得到任务Tt与各个历史任务T1,T2,…,Tt-1的注意力向量at,即at=A(rt,R|αt);
步骤7,以注意力向量at各个分量作为权重,对历史任务T1,T2,…,Tt-1的样本数据进行抽取,组成元样本数据Dmeta;
步骤8,将元样本数据Dmeta划分为不相交的训练样本集Dtrain和测试样本集Dtest,在训练样本集Dtrain上分批抽取训练样本以式9分别计算各个批次(batch)上的优化参数θ′,再以式10计算元分类器G的参数θt;
其中,L(·)为分类损失函数,α和β为预先设定的常数;
步骤9,使用当前任务Tt的样本数据Dt作为输入,基于式11计算元分类器G上的分类损失Lupdate;
步骤10,计算梯度并基于式12更新元分类器G在当前任务Tt上的适应性网络权重,获得适用于当前任务Tt的分类器G(θ′t);
步骤11,更新系统状态,以当前时刻的参数作为下一时刻任务的初始参数,即
传统的在线元学习的方法不能考虑当前任务与历史任务的关系,容易陷入对当前任务的过拟合,而在一些业务场景下,各个任务具备相当程度的相关性,如果不能充分考虑任务间的关系,将限制模型的表现。针对这个问题,本申请实施例提供的技术方案具有如下优点:1)通过注意力机制建立当前任务与历史任务之间的相关性,根据最终的模型表现,优化出最适合的相关性计算参数,充分利用了历史任务的知识来帮助当前任务的训练;2)通过在线元学习的方式,充分适应了上述业务场景,能够在获得新任务时,迅速结合以往任务展开训练,使模型快速地适应新任务的分布,达到实时训练、实时预测的效果。
为了客观验证本申请技术方案的有效性,并对算法性能做定量评价,经过实验在多任务数据集上与其他在线元学习算法进行比对。
实验选用了彩虹手写数字数据集(Rainbow-Mnist),Mnist(Mixed NationalInstitute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60000个示例的训练集以及10000个示例的测试集,如图3所示。
在此基础上,对每个样本分别进行三种变换的组合,即改变7种不同的背景颜色,改变4个不同的旋转角度,改变字符的2种不同尺寸,共组合出了56个不同的分类任务数据集,如图4所示。在图4中,不同的背景颜色以不同的灰度表示,例如该不同的背景颜色可以是红色、黄色、蓝色、紫色、绿色、橙色、黑色,等等。
在实验时,假设56个任务以随机的顺序获得。每个任务分别取900张不重复的图片,其中450张作为训练集,450张作为测试集。元分类器训练时批处理尺寸(batch size)设为25,迭代次数5次,MAML内层使用梯度下降法更新参数,学习率为0.1,外层使用Adam优化器,学习率为0.001。每个任务下模型准确率如图5中实线51所示。
为体现本申请技术方案的优点,使用FTML作为参考实验,实验超参数同上。每个任务下模型准确率如图5中虚线52所示。
实验结果显示,本申请所提出方法在训练时能取得较优效果,在每个任务上都比FTML高4个百分点。根据实验结果发现,在训练时,注意力模块能够有效地选取与当前任务相似的历史任务,从而达到更加适应性地训练的目的。如图6所示,以当前任务为第17个任务(任务编号为16)为例,经过注意力模块计算得到的当前任务相对于各个历史任务(包括第1个历史任务至第16个历史任务,即任务编号从0至15)的注意力分量值。
综上所述,本申请技术方案通过注意力机制计算当前任务与历史任务之间的相关性,以此为依据进行抽样,建立元任务,在此基础上进行适应性的在线元学习,能够有效地捕捉任务之间的相关性,充分筛选和利用有效数据,提升最终训练效果;其次,与以往方法相比,本申请技术方案能够有效地降低训练过程中的波动,提升结果的置信度。
请参考图7,其示出了本申请一个实施例提供的机器学习模型的应用方法的流程图。该方法各步骤的执行主体可以是计算机设备,所述计算机设备是指具备处理和存储能力的电子设备,如手机、平板电脑、智能机器人、PC等终端设备,也可以是服务器等。该方法可以包括如下几个步骤(701~703):
步骤701,获取当前任务的预测样本;
步骤702,调用适用于当前任务的机器学习模型;
步骤703,通过机器学习模型输出预测样本对应的预测结果;
其中,机器学习模型是根据当前任务与历史任务之间的相关性,从历史任务的样本数据中抽样得到元样本数据之后,采用元样本数据和当前任务的样本数据进行训练得到的。
有关机器学习模型的训练过程可参见上文实施例中的介绍说明,此处不再赘述。在训练得到适用于当前任务的机器学习模型之后,可以采用该机器学习模型对当前任务的预测样本进行预测,得到相应的预测结果。
在不同的业务场景中,机器学习模型的预测样本和相应的预测结果会有所不同。例如,在金融风控场景下的用户违约预测中,预测样本可以包括某个目标用户的用户信息,如年龄、性别、社交信息、征信信息等用户信息,相应的预测结果可以是该目标用户是否为潜在的违约用户。又例如,在视频推荐场景下,预测样本可以包括某个目标用户的用户信息,如年龄、性别、地区、兴趣爱好、网络历史行为信息等用户信息,相应的预测结果可以是推荐给该目标用户的视频分类。
综上所述,本申请实施例提供的技术方案,通过注意力机制确定当前任务与历史任务之间的相关性,根据该相关性从历史任务中抽样获取元样本数据,然后采用该元样本数据和当前任务的样本数据进行模型训练,最终得到适用于当前任务的机器学习模型;一方面,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩;另一方面,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升最终训练得到的模型在当前任务上的预测精度。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图8,其示出了本申请一个实施例提供的机器学习模型的训练装置的框图。该装置具有实现上述训练方法示例的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是上文介绍的计算机设备,也可以设置在计算机设备中。该装置800可以包括:相关性确定模块810、样本抽样模块820、模型训练模块830和参数调整模块840。
相关性确定模块810,用于确定当前任务与历史任务之间的相关性。
样本抽样模块820,用于根据所述相关性对所述历史任务的样本数据进行抽样,获得元样本数据。
模型训练模块830,用于采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型。
参数调整模块840,用于采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
在示例性实施例中,如图9所示,所述相关性确定模块810,包括特征提取单元811和注意力计算单元812。
特征提取单元811,用于对所述当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到所述当前任务的特征信息和各个所述历史任务的特征信息,所述n为正整数。
注意力计算单元812,用于通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量。其中,所述注意力向量包括n个注意力分量,第i个注意力分量用于表征所述当前任务与第i个历史任务之间的相关性,所述i为小于等于所述n的正整数。
在示例性实施例中,所述注意力计算单元812,用于:
通过第一注意力网络根据所述当前任务的特征信息,输出所述当前任务对应的查询向量;
通过第二注意力网络根据各个所述历史任务的特征信息,输出所述n个历史任务对应的键值矩阵;
根据所述查询向量和所述键值矩阵,计算所述注意力向量。
在示例性实施例中,所述样本抽样模块820,用于:
根据所述n个注意力分量,对所述n个历史任务的样本数据进行抽样,获得所述元样本数据;
其中,从各个所述历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与所述n个注意力分量之间的比例相同。
在示例性实施例中,所述模型训练模块830,用于:
根据所述元样本数据生成训练样本集;
采用所述训练样本集分批次训练所述机器学习模型的参数;
计算所述机器学习模型的第一损失函数值;
当所述第一损失函数值满足第一条件时,根据所述第一损失函数值计算所述机器学习模型的初始参数所对应的第一梯度;
根据所述第一梯度更新所述机器学习模型的初始参数,得到所述训练后的机器学习模型。
在示例性实施例中,所述参数调整模块840,用于:
根据所述当前任务的样本数据,计算所述训练后的机器学习模型的第二损失函数值;
当所述第二损失函数值满足第二条件时,根据所述第二损失函数值计算所述训练后的机器学习模型的参数所对应的第二梯度;
根据所述第二梯度更新所述训练后的机器学习模型的参数,得到适用于所述当前任务的机器学习模型。
综上所述,本申请实施例提供的技术方案,通过注意力机制确定当前任务与历史任务之间的相关性,根据该相关性从历史任务中抽样获取元样本数据,然后采用该元样本数据和当前任务的样本数据进行模型训练,最终得到适用于当前任务的机器学习模型;一方面,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩;另一方面,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升最终训练得到的模型在当前任务上的预测精度。
请参考图10,其示出了本申请一个实施例提供的机器学习模型的应用装置的框图。该装置具有实现上述应用方法示例的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是上文介绍的计算机设备,也可以设置在计算机设备中。该装置1000可以包括:样本获取模块1010、模型调用模块1020和结果输出模块1030。
样本获取模块1010,用于获取当前任务的预测样本;
模型调用模块1020,用于调用适用于所述当前任务的机器学习模型;
结果输出模块1030,用于通过所述机器学习模型输出所述预测样本对应的预测结果。
其中,所述机器学习模型是根据所述当前任务与历史任务之间的相关性,从所述历史任务的样本数据中抽样得到元样本数据之后,采用所述元样本数据和所述当前任务的样本数据进行训练得到的。
在示例性实施例中,所述机器学习模型的训练过程如下:
确定所述当前任务与所述历史任务之间的相关性;
根据所述相关性对所述历史任务的样本数据进行抽样,获得元样本数据;
采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
有关模型训练过程的其它介绍说明可参见上文实施例,此处不再赘述。
综上所述,本申请实施例提供的技术方案,通过注意力机制确定当前任务与历史任务之间的相关性,根据该相关性从历史任务中抽样获取元样本数据,然后采用该元样本数据和当前任务的样本数据进行模型训练,最终得到适用于当前任务的机器学习模型;一方面,通过从历史任务中抽取样本数据,作为当前任务的训练样本,从而实现了对当前任务的训练样本的数据增扩;另一方面,由于在从历史任务中抽取样本数据时,充分考虑了当前任务与历史任务之间的相关性,使得该在线元学习的过程对于当前任务更加鲁棒,从而提升最终训练得到的模型在当前任务上的预测精度。
需要说明的是,上述实施例提供的装置,在实现其功能时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的装置与方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
请参考图11,其示出了本申请一个实施例提供的计算机设备的结构示意图。具体来讲:
所述计算机设备1100包括CPU(Central Processing Unit,中央处理单元)1101、包括RAM(Random Access Memory,随机存取存储器)1102和ROM(Read Only Memory,只读存储器)1103的系统存储器1104,以及连接系统存储器1104和中央处理单元1101的系统总线1105。所述计算机设备1100还包括帮助计算机内的各个器件之间传输信息的基本I/O(Input/Output输入/输出)系统1106,和用于存储操作系统1113、应用程序1114和其他程序模块1115的大容量存储设备1107。
所述基本输入/输出系统1106包括有用于显示信息的显示器1108和用于用户输入信息的诸如鼠标、键盘之类的输入设备1109。其中所述显示器1108和输入设备1109都通过连接到系统总线1105的输入输出控制器1110连接到中央处理单元1101。所述基本输入/输出系统1106还可以包括输入输出控制器1110以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器1110还提供输出到显示屏、打印机或其他类型的输出设备。
所述大容量存储设备1107通过连接到系统总线1105的大容量存储控制器(未示出)连接到中央处理单元1101。所述大容量存储设备1107及其相关联的计算机可读介质为计算机设备1100提供非易失性存储。也就是说,所述大容量存储设备1107可以包括诸如硬盘或者CD-ROM(Compact Disc Read-Only Memory,只读光盘)驱动器之类的计算机可读介质(未示出)。
不失一般性,所述计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、EPROM(Erasable Programmable Read Only Memory,可擦除可编程只读存储器)、闪存或其他固态存储其技术,CD-ROM或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知所述计算机存储介质不局限于上述几种。上述的系统存储器1104和大容量存储设备1107可以统称为存储器。
根据本申请的各种实施例,所述计算机设备1100还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即计算机设备1100可以通过连接在所述系统总线1105上的网络接口单元1111连接到网络1112,或者说,也可以使用网络接口单元1111来连接到其他类型的网络或远程计算机系统(未示出)。
所述存储器还包括至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、至少一段程序、代码集或指令集存储于存储器中,且经配置以由一个或者一个以上处理器执行,以实现上述机器学习模型的训练方法,或者实现上述机器学习模型的应用方法。
在示例性实施例中,还提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或所述指令集在被计算机设备的处理器执行时以实现上述机器学习模型的训练方法,或者实现上述机器学习模型的应用方法。
可选地,该计算机可读存储介质可以包括:ROM、RAM、SSD(Solid State Drives,固态硬盘)或光盘等。其中,随机存取记忆体可以包括ReRAM(Resistance Random AccessMemory,电阻式随机存取记忆体)和DRAM(Dynamic Random Access Memory,动态随机存取存储器)。
在示例性实施例中,还提供一种计算机程序产品,所述计算机程序产品被计算机设备的处理器执行时,用于实现上述机器学习模型的训练方法,或者实现上述机器学习模型的应用方法。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。另外,本文中描述的步骤编号,仅示例性示出了步骤间的一种可能的执行先后顺序,在一些其它实施例中,上述步骤也可以不按照编号顺序来执行,如两个不同编号的步骤同时执行,或者两个不同编号的步骤按照与图示相反的顺序执行,本申请实施例对此不作限定。
以上所述仅为本申请的示例性实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (10)
1.一种基于机器学习模型的图片分类方法,其特征在于,所述方法包括:
获取当前任务的预测样本,所述预测样本包括待分类的图片;
通过适用于所述当前任务的机器学习模型,根据所述预测样本输出所述预测样本对应的预测结果,所述预测结果用于指示所述待分类的图片的分类结果;
其中,适用于所述当前任务的机器学习模型的训练过程如下:
对所述当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到所述当前任务的特征信息和各个所述历史任务的特征信息,所述n为正整数;
通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量;其中,所述注意力向量包括n个注意力分量,所述n个注意力分量中的第i个注意力分量用于表征所述当前任务与所述n个历史任务中的第i个历史任务之间的相关性,所述i为小于等于所述n的正整数;
根据所述n个注意力分量,对所述n个历史任务的样本数据进行抽样,获得元样本数据;其中,从所述n个历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与所述n个注意力分量之间的比例相同;
采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
2.根据权利要求1所述的方法,其特征在于,所述通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量,包括:
通过第一注意力网络根据所述当前任务的特征信息,输出所述当前任务对应的查询向量;
通过第二注意力网络根据各个所述历史任务的特征信息,输出所述n个历史任务对应的键值矩阵;
根据所述查询向量和所述键值矩阵,计算所述注意力向量。
3.根据权利要求1或2所述的方法,其特征在于,所述采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型,包括:
根据所述元样本数据生成训练样本集;
采用所述训练样本集分批次训练所述机器学习模型的参数;
计算所述机器学习模型的第一损失函数值;
当所述第一损失函数值满足第一条件时,根据所述第一损失函数值计算所述机器学习模型的初始参数所对应的第一梯度;
根据所述第一梯度更新所述机器学习模型的初始参数,得到所述训练后的机器学习模型。
4.根据权利要求1或2所述的方法,其特征在于,所述采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型,包括:
根据所述当前任务的样本数据,计算所述训练后的机器学习模型的第二损失函数值;
当所述第二损失函数值满足第二条件时,根据所述第二损失函数值计算所述训练后的机器学习模型的参数所对应的第二梯度;
根据所述第二梯度更新所述训练后的机器学习模型的参数,得到适用于所述当前任务的机器学习模型。
5.一种用于图片分类的机器学习模型的训练方法,其特征在于,所述方法包括:
对当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到所述当前任务的特征信息和各个所述历史任务的特征信息,所述n为正整数;
通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量;其中,所述注意力向量包括n个注意力分量,所述n个注意力分量中的第i个注意力分量用于表征所述当前任务与所述n个历史任务中的第i个历史任务之间的相关性,所述i为小于等于所述n的正整数;
根据所述n个注意力分量,对所述n个历史任务的样本数据进行抽样,获得元样本数据;其中,从所述n个历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与所述n个注意力分量之间的比例相同;
采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型;其中,适用于所述当前任务的机器学习模型用于根据当前任务的预测样本中包含的待分类的图片,输出所述预测样本对应的预测结果,所述预测结果用于指示所述待分类的图片的分类结果。
6.根据权利要求5所述的方法,其特征在于,所述通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量,包括:
通过第一注意力网络根据所述当前任务的特征信息,输出所述当前任务对应的查询向量;
通过第二注意力网络根据各个所述历史任务的特征信息,输出所述n个历史任务对应的键值矩阵;
根据所述查询向量和所述键值矩阵,计算所述注意力向量。
7.一种基于机器学习模型的图片分类装置,其特征在于,所述装置包括:
样本获取模块,用于获取当前任务的预测样本,所述预测样本包括待分类的图片;
结果输出模块,用于通过适用于所述当前任务的机器学习模型,根据所述预测样本输出所述预测样本对应的预测结果,所述预测结果用于指示所述待分类的图片的分类结果;
其中,适用于所述当前任务的机器学习模型的训练过程如下:
对所述当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到所述当前任务的特征信息和各个所述历史任务的特征信息,所述n为正整数;
通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量;其中,所述注意力向量包括n个注意力分量,所述n个注意力分量中的第i个注意力分量用于表征所述当前任务与所述n个历史任务中的第i个历史任务之间的相关性,所述i为小于等于所述n的正整数;
根据所述n个注意力分量,对所述n个历史任务的样本数据进行抽样,获得元样本数据;其中,从所述n个历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与所述n个注意力分量之间的比例相同;
采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型。
8.一种用于图片分类的机器学习模型的训练装置,其特征在于,所述装置包括:
相关性确定模块,用于对当前任务的样本数据和n个历史任务的样本数据分别进行特征提取,得到所述当前任务的特征信息和各个所述历史任务的特征信息,所述n为正整数;通过注意力网络根据所述当前任务的特征信息和各个所述历史任务的特征信息,计算注意力向量;其中,所述注意力向量包括n个注意力分量,所述n个注意力分量中的第i个注意力分量用于表征所述当前任务与所述n个历史任务中的第i个历史任务之间的相关性,所述i为小于等于所述n的正整数;
样本抽样模块,用于根据所述n个注意力分量,对所述n个历史任务的样本数据进行抽样,获得元样本数据;其中,从所述n个历史任务的样本数据中抽样获取的样本数据的数量之间的比例,与所述n个注意力分量之间的比例相同;
模型训练模块,用于采用所述元样本数据对初始的机器学习模型进行在线元学习训练,得到训练后的机器学习模型;
参数调整模块,用于采用所述当前任务的样本数据对所述训练后的机器学习模型的参数进行调整,得到适用于所述当前任务的机器学习模型;其中,适用于所述当前任务的机器学习模型用于根据当前任务的预测样本中包含的待分类的图片,输出所述预测样本对应的预测结果,所述预测结果用于指示所述待分类的图片的分类结果。
9.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一段程序,所述至少一段程序由所述处理器加载并执行以实现如权利要求1至4任一项所述的方法,或者实现如权利要求5至6任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有至少一段程序,所述至少一段程序由处理器加载并执行以实现如权利要求1至4任一项所述的方法,或者实现如权利要求5至6任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911213128.5A CN111191791B (zh) | 2019-12-02 | 2019-12-02 | 基于机器学习模型的图片分类方法、装置及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911213128.5A CN111191791B (zh) | 2019-12-02 | 2019-12-02 | 基于机器学习模型的图片分类方法、装置及设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111191791A CN111191791A (zh) | 2020-05-22 |
CN111191791B true CN111191791B (zh) | 2023-09-29 |
Family
ID=70709175
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911213128.5A Active CN111191791B (zh) | 2019-12-02 | 2019-12-02 | 基于机器学习模型的图片分类方法、装置及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111191791B (zh) |
Families Citing this family (24)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111813143B (zh) * | 2020-06-09 | 2022-04-19 | 天津大学 | 一种基于强化学习的水下滑翔机智能控制系统及方法 |
CN111931991A (zh) * | 2020-07-14 | 2020-11-13 | 上海眼控科技股份有限公司 | 气象临近预报方法、装置、计算机设备和存储介质 |
CN111724083B (zh) * | 2020-07-21 | 2023-10-13 | 腾讯科技(深圳)有限公司 | 金融风险识别模型的训练方法、装置、计算机设备及介质 |
CN111738355B (zh) * | 2020-07-22 | 2020-12-01 | 中国人民解放军国防科技大学 | 注意力融合互信息的图像分类方法、装置及存储介质 |
CN113095440B (zh) * | 2020-09-01 | 2022-05-17 | 电子科技大学 | 基于元学习者的训练数据生成方法及因果效应异质反应差异估计方法 |
CN112085103B (zh) * | 2020-09-10 | 2023-06-27 | 北京百度网讯科技有限公司 | 基于历史行为的数据增强方法、装置、设备以及存储介质 |
CN112069329B (zh) * | 2020-09-11 | 2024-03-15 | 腾讯科技(深圳)有限公司 | 文本语料的处理方法、装置、设备及存储介质 |
CN114248265B (zh) * | 2020-09-25 | 2023-07-07 | 广州中国科学院先进技术研究所 | 一种基于元模拟学习的多任务智能机器人学习方法及装置 |
CN112364999B (zh) * | 2020-10-19 | 2021-11-19 | 深圳市超算科技开发有限公司 | 冷水机调节模型的训练方法、装置及电子设备 |
CN112269769B (zh) * | 2020-11-18 | 2023-12-05 | 远景智能国际私人投资有限公司 | 数据压缩方法、装置、计算机设备及存储介质 |
CN112446505B (zh) * | 2020-11-25 | 2023-12-29 | 创新奇智(广州)科技有限公司 | 一种元学习建模方法及装置、电子设备、存储介质 |
CN112598133B (zh) * | 2020-12-16 | 2023-07-28 | 联合汽车电子有限公司 | 车辆数据的处理方法、装置、设备和存储介质 |
CN112766388A (zh) * | 2021-01-25 | 2021-05-07 | 深圳中兴网信科技有限公司 | 模型获取方法、电子设备和可读存储介质 |
CN113298286A (zh) * | 2021-03-31 | 2021-08-24 | 捷佳润科技集团股份有限公司 | 一种基于机器学习的火龙果上市时间预测的方法 |
CN113408374B (zh) * | 2021-06-02 | 2022-09-23 | 湖北工程学院 | 基于人工智能的产量预估方法、装置、设备及存储介质 |
CN113392118B (zh) * | 2021-06-04 | 2022-10-18 | 中电四川数据服务有限公司 | 一种基于机器学习的数据更新检测系统及其方法 |
CN113537297B (zh) * | 2021-06-22 | 2023-07-28 | 同盾科技有限公司 | 一种行为数据预测方法及装置 |
CN113657926A (zh) * | 2021-07-28 | 2021-11-16 | 上海明略人工智能(集团)有限公司 | 一种广告效果预测方法、系统、电子设备及存储介质 |
CN113792845A (zh) * | 2021-09-07 | 2021-12-14 | 未鲲(上海)科技服务有限公司 | 基于元学习的预测模型训练方法、装置、电子设备及介质 |
CN114444717A (zh) * | 2022-01-25 | 2022-05-06 | 杭州海康威视数字技术股份有限公司 | 自主学习方法、装置、电子设备及机器可读存储介质 |
CN114491039B (zh) * | 2022-01-27 | 2023-10-03 | 四川大学 | 基于梯度改进的元学习少样本文本分类方法 |
CN114822855B (zh) * | 2022-06-28 | 2022-09-20 | 北京智精灵科技有限公司 | 基于ftrl模型的认知训练任务推送方法、系统及构建方法 |
CN115470936B (zh) * | 2022-09-23 | 2023-06-06 | 广州爱浦路网络技术有限公司 | 一种基于nwdaf的机器学习模型更新方法及装置 |
CN115919273B (zh) * | 2022-12-07 | 2024-10-18 | 北京中电普华信息技术有限公司 | 一种基于深度学习的亚健康预警系统及相关设备 |
Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN105279240A (zh) * | 2015-09-28 | 2016-01-27 | 暨南大学 | 客户端起源信息关联感知的元数据预取方法及系统 |
CN108665506A (zh) * | 2018-05-10 | 2018-10-16 | 腾讯科技(深圳)有限公司 | 图像处理方法、装置、计算机存储介质及服务器 |
CN109919299A (zh) * | 2019-02-19 | 2019-06-21 | 西安交通大学 | 一种基于元学习器逐步梯度校正的元学习算法 |
CN109948783A (zh) * | 2019-03-29 | 2019-06-28 | 中国石油大学(华东) | 一种基于注意力机制的网络结构优化方法 |
CN109961142A (zh) * | 2019-03-07 | 2019-07-02 | 腾讯科技(深圳)有限公司 | 一种基于元学习的神经网络优化方法及装置 |
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN109992866A (zh) * | 2019-03-25 | 2019-07-09 | 新奥数能科技有限公司 | 负荷预测模型的训练方法、装置、可读介质及电子设备 |
CN110163380A (zh) * | 2018-04-28 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 数据分析方法、模型训练方法、装置、设备及存储介质 |
CN110196946A (zh) * | 2019-05-29 | 2019-09-03 | 华南理工大学 | 一种基于深度学习的个性化推荐方法 |
CN110276446A (zh) * | 2019-06-26 | 2019-09-24 | 北京百度网讯科技有限公司 | 模型训练和选择推荐信息的方法和装置 |
EP3568811A1 (en) * | 2017-02-24 | 2019-11-20 | Deepmind Technologies Limited | Training machine learning models |
-
2019
- 2019-12-02 CN CN201911213128.5A patent/CN111191791B/zh active Active
Patent Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN105279240A (zh) * | 2015-09-28 | 2016-01-27 | 暨南大学 | 客户端起源信息关联感知的元数据预取方法及系统 |
EP3568811A1 (en) * | 2017-02-24 | 2019-11-20 | Deepmind Technologies Limited | Training machine learning models |
CN110163380A (zh) * | 2018-04-28 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 数据分析方法、模型训练方法、装置、设备及存储介质 |
CN108665506A (zh) * | 2018-05-10 | 2018-10-16 | 腾讯科技(深圳)有限公司 | 图像处理方法、装置、计算机存储介质及服务器 |
CN109919299A (zh) * | 2019-02-19 | 2019-06-21 | 西安交通大学 | 一种基于元学习器逐步梯度校正的元学习算法 |
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN109961142A (zh) * | 2019-03-07 | 2019-07-02 | 腾讯科技(深圳)有限公司 | 一种基于元学习的神经网络优化方法及装置 |
CN109992866A (zh) * | 2019-03-25 | 2019-07-09 | 新奥数能科技有限公司 | 负荷预测模型的训练方法、装置、可读介质及电子设备 |
CN109948783A (zh) * | 2019-03-29 | 2019-06-28 | 中国石油大学(华东) | 一种基于注意力机制的网络结构优化方法 |
CN110196946A (zh) * | 2019-05-29 | 2019-09-03 | 华南理工大学 | 一种基于深度学习的个性化推荐方法 |
CN110276446A (zh) * | 2019-06-26 | 2019-09-24 | 北京百度网讯科技有限公司 | 模型训练和选择推荐信息的方法和装置 |
Non-Patent Citations (5)
Title |
---|
A SIMPLE NEURAL ATTENTIVE META-LEARNER;Nikhil Mishra 等;《arXiv》;1-17 * |
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks;Chelsea Finn 等;《arXiv》;1-13 * |
On the Importance of Attention in Meta-Learning for Few-Shot Text Classification;Xiang Jiang 等;《arXiv》;1-13 * |
Online Meta-Learning;Chelsea Finn 等;《arXiv》;1-19 * |
基于 VAE 和注意力机制的小样本图像分类方法;郑欣悦 等;《计算机应用与软件》;第36卷(第10期);168-174 * |
Also Published As
Publication number | Publication date |
---|---|
CN111191791A (zh) | 2020-05-22 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111191791B (zh) | 基于机器学习模型的图片分类方法、装置及设备 | |
EP3968179A1 (en) | Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device | |
CN111507378A (zh) | 训练图像处理模型的方法和装置 | |
EP4163831A1 (en) | Neural network distillation method and device | |
CN111860573A (zh) | 模型训练方法、图像类别检测方法、装置和电子设备 | |
CN112651511A (zh) | 一种训练模型的方法、数据处理的方法以及装置 | |
Shi et al. | Recent advances in plant disease severity assessment using convolutional neural networks | |
US11621930B2 (en) | Systems and methods for generating dynamic conversational responses using trained machine learning models | |
US20220237917A1 (en) | Video comparison method and apparatus, computer device, and storage medium | |
US12008329B2 (en) | Systems and methods for generating dynamic conversational responses through aggregated outputs of machine learning models | |
WO2024001806A1 (zh) | 一种基于联邦学习的数据价值评估方法及其相关设备 | |
CN115631008B (zh) | 商品推荐方法、装置、设备及介质 | |
Schiliro et al. | A novel cognitive computing technique using convolutional networks for automating the criminal investigation process in policing | |
CN114330499A (zh) | 分类模型的训练方法、装置、设备、存储介质及程序产品 | |
CN114155388B (zh) | 一种图像识别方法、装置、计算机设备和存储介质 | |
Shehu et al. | Lateralized approach for robustness against attacks in emotion categorization from images | |
CN109508640A (zh) | 一种人群情感分析方法、装置和存储介质 | |
WO2024114659A1 (zh) | 一种摘要生成方法及其相关设备 | |
US11587323B2 (en) | Target model broker | |
CN117540336A (zh) | 时间序列预测方法、装置及电子设备 | |
CN116541507A (zh) | 一种基于动态语义图神经网络的视觉问答方法及系统 | |
CN111582404B (zh) | 内容分类方法、装置及可读存储介质 | |
CN115222112A (zh) | 一种行为预测方法、行为预测模型的生成方法及电子设备 | |
CN116992937A (zh) | 神经网络模型的修复方法和相关设备 | |
CA3196711A1 (en) | Systems and methods for generating dynamic conversational responses through aggregated outputs of machine learning models |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |