CN110738270A - 基于均值迭代的多任务学习模型训练以及预测方法 - Google Patents

基于均值迭代的多任务学习模型训练以及预测方法 Download PDF

Info

Publication number
CN110738270A
CN110738270A CN201911003056.1A CN201911003056A CN110738270A CN 110738270 A CN110738270 A CN 110738270A CN 201911003056 A CN201911003056 A CN 201911003056A CN 110738270 A CN110738270 A CN 110738270A
Authority
CN
China
Prior art keywords
task
train
label
test
iteration
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN201911003056.1A
Other languages
English (en)
Other versions
CN110738270B (zh
Inventor
周鋆
孙立健
符鹏涛
朱先强
张维明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
National University of Defense Technology
Original Assignee
National University of Defense Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National University of Defense Technology filed Critical National University of Defense Technology
Priority to CN201911003056.1A priority Critical patent/CN110738270B/zh
Publication of CN110738270A publication Critical patent/CN110738270A/zh
Application granted granted Critical
Publication of CN110738270B publication Critical patent/CN110738270B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Complex Calculations (AREA)

Abstract

一种基于均值迭代的多任务学习模型训练以及预测方法,先获取多个任务的样本数据集,并将各任务的样本数据集均划分为训练集和测试集。对于每一个任务,利用其对应的训练集得到各任务上各类标签的先验概率以及各任务中各特征变量在各类标签上的条件概率,计算每一个任务对应的测试集中各实例的类标签。基于均值迭代方法更新各任务上各类标签的先验概率。不断循环迭代,直至不同任务上类标签的先验概率误差的绝对值之和Δ小于设定的阈值时收敛,得到训练好的多任务学习模型。本发明可以显著提高多任务学习的效率。同时,本发明可以更加充分利用任务之间的共享信息和数据的先验知识,利用较少的计算资源便可以达到更佳的分类效果。

Description

基于均值迭代的多任务学习模型训练以及预测方法
技术领域
本发明涉及大数据处理技术领域,特别是涉及一种多任务学习模型训练以及 预测方法。
背景技术
随着信息技术的发展与大数据时代的到来,机器学习成为解决实际问题的重 要方法之一。目前大多数机器学习方法采用的都是单任务学习(Single-task Learning,STL)方法,即多个任务之间的学习过程是相互独立的,这种方法忽略了 任务之间的相关性。单任务学习方法在解决复杂问题时,往往将问题分解为简单 且相互独立的子问题,然后再通过合并结果得到复杂问题的解。
然而,这样做看似合理,实则是不正确的。因为现实世界中很多问题不能简 单地分解为一个个独立的子问题,这些子问题通常是相互关联的,即通过一些共 享因素或共享表示联系在一起。把现实世界中的问题当作相互独立的单任务处理, 忽略了问题之间的关联信息。为弥足单任务学习方法的不足,多任务学习 (Multi-task Learning,MTL)方法同时对这些任务进行学习,通过提取和利用任务之 间的共享信息,分类器中的参数更新相互影响,从而改善分类器的泛化性能。当 任务中已知标记样本数较少时,多任务学习可以有效地增加样本数,从而使分类 结果更加准确。
一些早期的多任务学习方法假设不同任务的目标函数参数是相似的或多个相 关任务共享同一特征子集,这些多任务方法均通过正则项约束使相关任务之间的 不同尽可能小。目前,多任务学习主要基于稀疏表示。2008年,Argyriou等人提 出MTL-FEAT模型,该模型通过学习多任务之间的稀疏表示来共享信息。2011 年,Kang等人对MTL-FEAT模型的约束进行松弛,提出了DG-MTL模型,该模 型通过将多任务学习问题转化为混合整数规划问题,显著提高了多任务学习模型 的性能。
在MTL-FEAT和DG-MTL模型的基础上,2012年Abhishek等人提出GO-MTL 模型,GO-MTL模型采用一种新的多任务学习分组和重叠组结构,每个任务组的 参数位于一个低维子空间中,不同分组的任务通过共享一个或多个潜在的基任务 来共享信息。2018年,Jeong等人在上述模型的基础上进行了改进,提出 VSTG-MTL模型,该模型在学习任务间重叠组结构的同时引入变量选择,模型将 系数矩阵分解成两个低秩矩阵的乘积,从而更加合理地利用多个相关任务之间的 共享信息。
2017年,一种区别于传统稀疏表示策略的自步多任务学习spMTFL模型被提 出,Murugesan等人采用一种类人学习策略,将自定步长的任务选择方法引入到 多任务学习中,模型通过迭代选择最合适的任务来学习任务参数和更新共享信息。 遗憾的是,在现有的多任务学习方法中,普遍存在模型运行效率低,数据的利用 不充分,特别是对数据的先验知识利用不充分等问题。上述模型主要缺陷如下, 在单任务学习模型中,如朴素贝叶斯学习模型,忽略了多个学习任务下不同任务 之间的相关性,导致学习得到的模型拥有较差的分类结果;在上述多任务学习模 型中,模型需要消耗大量的计算资源和计算时间,计算效率低。
发明内容
针对现有技术多任务学习方法运行效率低和数据信息利用不足的缺陷,本发 明提供一种基于均值迭代的多任务学习模型训练以及预测方法。
一种基于均值迭代的多任务学习模型训练方法,包括以下步骤:
(1)获取多个任务的样本数据集,并将各任务的样本数据集均划分为训练 集和测试集。
设有T个任务,对于各任务t(t=1,2,3…T)分别采集含D个特征变量的多个实 例,同时获取每一个任务的各实例所对应的类标签,得到各任务的样本数据集; 将各任务的样本数据集均划分为训练集和测试集。
(2)训练多任务学习模型。
(2.1)对于每一个任务,利用其对应的训练集得到各任务上各类标签的先 验概率以及各任务中各特征变量在各类标签上的条件概率。
(2.2)根据各任务上各类标签的先验概率以及各任务中各特征变量在各类 标签上的条件概率计算每一个任务对应的测试集中各实例的类标签;
(2.3)将步骤(2.2)计算得到的各任务测试集中各实例的类标签与步骤(1) 中采集得到的对应类标签进行比较,得到步骤(2.2)计算得到的分类结果的准确 率或者F1值,根据准确率或者F1值确定当前的分类结果评分;
(2.4)将步骤(2.2)计算得到的所有任务对应的测试集的分类结果集成成 总的分类结果,基于均值迭代方法更新各任务上各类标签的先验概率,返回步骤 (2.2);
(2.5)在循环迭代过程中,如果第m次迭代过程中得到的分类结果评分大 于m-1次迭代过程中得到的分类结果评分,则将第m次迭代过程中所采用的各 任务上类标签的先验概率作为当前的各任务上类标签的最优先验值;
在循环迭代过程中,计算第m次和第m-1次迭代时的不同任务上类标签的先 验概率误差的绝对值之和Δ,当Δ小于设定的超参数ε时收敛,得到训练好的多任 务学习模型,输出最终的各任务上各类标签的最优先验值。
本发明步骤(1)中,设有T个任务,对于各任务t(t=1,2,3…T)分别采集含 D个特征变量的多个实例,同时获取每一个任务的各实例所对应的类标签,各实 例所对应的类标签可用{C1,C2,…,CK}表示,K表示所有实例的类标签的总个数。 x1,x2,…,xD表示每一个实例的特征变量,特征变量的总数为D。
对于第t个任务的样本数据集,Traint表示第t个任务对应的训练集,其中 的实例数有Traint(all)个。Testt表示第t个任务对应的测试集,其中的实例数 有Testt(all)个。第t个任务的样本数据集中的所有实例的总数Nt即为 Traint(all)+Testt(all)。
T个任务总的训练集Train即为{Train1,……,Traint,……,TrainT},T个 任务总的测试集Test即为{Test1,……,Testt,……,TestT}。
本发明步骤(2.1)中,利用各任务对应的训练集得到各任务上各类标签的先 验概率以及各任务中各特征变量在各类标签上的条件概率,方法如下:
对于第t个任务对应的训练集Traint,令Traint(Ck)表示Traint的所有实例中 类标签为Ck的实例数量,则第t个任务中的第k个类标签Ck的先验概率为:
Pt(Ck)=Traint(Ck)/Traint(all)
对于第t个任务对应的训练集Traint,Traint(Ck,xd)表示Traint的所有实例 中类标签为Ck且第d个特征变量xd的取值相同的实例数量,则第t个任务中的第 d个的特征变量xd在第k个类标签Ck上的条件概率为:
Pt(xd|Ck)=Traint(Ck,xd)/Traint(Ck)
进一步地,为了避免其它特征变量所携带的信息被训练集中未出现的特征变 量所“抹去”,采用拉普拉斯修正各任务上各类标签的先验概率以及各任务中各 特征变量在各类标签上的条件概率,方法如下:
令kd表示训练集Train中的第d个特征变量xd可能的取值的总数量,修正后的 第t个任务中的第k个类标签Ck的先验概率Pt(Ck)以及第t个任务中的第d个的特征 变量xd在第k个类标签Ck上的条件概率Pt(xd|Ck),分别为:
Pt(Ck)=(Traint(Ck)+1)/(Traint(all)+K)
Pt(xd|Ck)=(Traint(Ck,xd)+1)/(Traint(Ck)+kd)
本发明步骤(2.2)中,根据各任务上各类标签的先验概率以及各任务中各特 征变量在各类标签上的条件概率计算每一个任务对应的测试集中各实例的类标 签的方法如下:
对于第t个任务对应的测试集Testt中的各实例的类标签通过下式计算得到:
Figure BDA0002241916070000051
其中:
Figure BDA0002241916070000052
()表示从K个类标签{C1,C2,…,CK}中选取概率值最大的那一个 类标签作为当前实例计算得到的类标签。
本发明步骤(2.4)中,将步骤(2.2)计算得到的所有任务对应的测试集的 分类结果集成成总的分类结果,基于均值迭代方法更新各任务上各类标签的先验 概率,方法如下:
(2.4.1)重新计算各任务上各类标签的先验概率。
按照下式重新计算第t个任务中的第k个类标签Ck的先验概率Pt'(Ck):
Figure BDA0002241916070000053
其中Testt(Ck)表示经步骤(2.2)计算得到的测试集Testt的所有实例中类标签为Ck的实例数量,Testt(all)表示第t个任务对应的测试集中的实例数目。
(2.4.2)对(2.4.1)中重新计算得到的各任务上各类标签的先验概率进行调 整,调整后的各任务上各类标签的先验概率作为第m次迭代循环过程中所采用的 各任务上各类标签的先验概率。
按照下式对重新计算得到的第t个任务中的第k个类标签Ck的先验概率 Pt'(Ck)进行调整,得到第m次迭代循环过程中第t个任务中的第k个类标签Ck的先验概率:
Pm(Ck)=α·Pt'(Ck)+(1-α)·Pm-1(Ck)
其中α设置为0.3;Pm(Ck)表示第m次迭代过程中第t个任务中的第k个类标签 Ck的先验概率
Figure BDA0002241916070000061
Pm-1(Ck)表示第m-1次迭代过程中第t个任务中的第k个 类标签Ck的先验概率
Figure BDA0002241916070000062
如果当前m=1,则即表示利用第t个任 务对应的训练集Traint计算得到的第t个任务中的第k个类标签Ck的先验概率 Pt(Ck)。
本发明步骤(2.5)中,计算第m次和第m-1次迭代时的不同任务上类标签的 先验概率误差的绝对值之和Δ,
Figure BDA0002241916070000064
超参数ε设置 为10-5。当Δ小于设定的超参数ε时收敛,得到训练好的多任务学习模型,输出最 终的各任务上各类标签的最优先验值
Figure BDA0002241916070000065
以及各任务中各特征变量在各类标 签上的条件概率Pt(xd|Ck)。
根据最终输出的各任务上各类标签的最优先验值
Figure BDA0002241916070000066
以及各任务中各特 征变量在各类标签上的条件概率Pt(xd|Ck),就能够实现对多任务中待进行预测的 实例进行类标签预测。因此,本发明还提供一种基于多任务学习模型的预测方法, 包括:
基于前述所提供的基于均值迭代的多任务学习模型训练方法,得到训练好的 多任务学习模型;
对于T个任务中的第t(t=1,2,3…T)个任务中待进行预测的实例n,获取该 实例的D个特征变量x1,n,x2,n,…,xD,n,基于训练好的多任务学习模型最终输出的 第t(t=1,2,3…T)个任务上各类标签的最优先验值
Figure BDA0002241916070000071
和第t(t=1,2,3…T)个 任务中各特征变量在各类标签上的条件概率Pt(xd|Ck),根据下式即可得到该实例 n的类标签。
Figure BDA0002241916070000072
本发明还提供一种计算机设备,包括存储器和处理器,所述存储器存储有计 算机程序,所述处理器执行所述计算机程序时实现所述基于多任务学习模型的预 测方法的步骤。
本发明还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算 机程序被处理器执行时实现所述基于多任务学习模型的预测方法的步骤。
本发明可以显著提高多任务学习的效率。同时,本发明可以更加充分利用任务 之间的共享信息和数据的先验知识,利用较少的计算资源便可以达到更佳的分类 效果。
附图说明
图1为一个实施例中的流程图;
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚明白,下面将以附图及 详细叙述清楚说明本发明所揭示内容的精神,任何所属技术领域技术人员在了解 本发明内容的实施例后,当可由本发明内容所教示的技术,加以改变及修饰,其 并不脱离本发明内容的精神与范围。本发明的示意性实施例及其说明用于解释本 发明,但并不作为对本发明的限定。
实施例1:
一种基于均值迭代的多任务学习模型训练方法,包括以下步骤:
(1)获取多个任务的样本数据集,并将各任务的样本数据集均划分为训练 集和测试集。
设有T个任务,对于各任务t(t=1,2,3…T)分别采集含D个特征变量的多个实 例,同时获取每一个任务的各实例所对应的类标签,得到各任务的样本数据集。 其中:各实例所对应的类标签可用{C1,C2,…,Ck}表示,K表示所有实例的类标签 的总个数。x1,x2,…,xD表示每一个实例的特征变量,特征变量的总数为D。
将各任务的样本数据集均划分为训练集和测试集。对于第t个任务的样本数据集,Traint表示第t个任务对应的训练集,其中的实例数有Traint(all)个。Testt表示第t)个任务对应的测试集,其中的实例数有Testt(all)个。第t个任务的样 本数据集中的所有实例的总数Nt即为Traint(all)+Testt(all)。
T个任务总的训练集Train即为{Train1,……,Traint,……,TrainT},T个 任务总的测试集Test即为{Test1,……,Testt,……,TestT}。
(2)训练多任务学习模型。
(2.1)对于每一个任务,利用其对应的训练集得到各任务上各类标签的先 验概率以及各任务中各特征变量在各类标签上的条件概率。
对于第t个任务对应的训练集Traint,令Traint(Ck)表示Traint的所有实例中 类标签为Ck的实例数量,则第t个任务中的第k个类标签Ck的先验概率为:
Pt(Ck)=Traint(Ck)/Traint(all) (1)
对于第t个任务对应的训练集Traint,Traint(Ck,xd)表示Traint的所有实例 中类标签为Ck且第d个特征变量xd的取值相同的实例数量,则第t个任务中的第 d个的特征变量xd在第k个类标签Ck上的条件概率为:
Pt(xd|Ck)=Traint(Ck,xd)/Traint(Ck) (2)
为了避免其它特征变量所携带的信息被训练集中未出现的特征变量所“抹 去”,采用拉普拉斯修正各任务上各类标签的先验概率以及各任务中各特征变量 在各类标签上的条件概率,方法如下:
令kd表示训练集Train中的第d个特征变量xd可能的取值的总数量,修正后的 第t个任务中的第k个类标签Ck的先验概率Pt(Ck)以及第t个任务中的第d个的特征 变量xd在第k个类标签Ck上的条件概率Pt(xd|Ck),分别为:
Pt(Ck)=(Traint(Ck)+1)/(Traint(all)+K) (3)
Pt(xd|Ck)=(Traint(Ck,xd)+1)/(Traint(Ck)+kd) (4)
(2.2)根据各任务上各类标签的先验概率以及各任务中各特征变量在各类标 签上的条件概率计算每一个任务对应的测试集中各实例的类标签。
朴素贝叶斯模型假设各个特征变量之间是相互独立的,即x1,x2,…,xD相互独 立。所有特征变量的条件概率为:
Figure BDA0002241916070000091
当给定类标签时,所有特征变量的联合概率Pt(x1,x2,…,xD)即可计算得到。因此,在第t个任务对应的测试集Testt中各实例上,每一个类标签的后验概率可以 表示为类标签的先验和所有特征变量的条件概率的乘积,即
Pt(Ck|x1,x2,…,xD)=Pt(x1,x2,…,xD|Ck)·Pt(Ck)/Pt(x1,x2,…,xD) (6)
具体计算过程中对于第t个任务中,为了防止连乘操作造成浮点下溢,通常采 用对数似然的方式进行计算,使乘法转换为加法。此时,计算公式为
Figure BDA0002241916070000101
根据最大后验准则,第t个任务对应的测试集Testt中各实例从各个类标签中 选取后验概率最大的类标签作为它的类标签。由于给定的第t个任务对应的测试集 Testt中各实例是确定的,对于所有的标签类别,P(x1,x2,…,xD)均为常数,故在 计算过程中通常省略计算。
最终得到,第t个任务对应的测试集Testt中的各实例的类标签的计算方式为:
Figure BDA0002241916070000102
其中:()表示从K个类标签{C1,C2,…,CK}中选取概率值最大的那一个 类标签作为当前实例计算得到的类标签。
(2.3)将步骤(2.2)计算得到的各任务测试集中各实例的类标签与步骤(1) 中采集得到的对应类标签进行比较,如果计算得到的类标签和步骤(1)中采集 得到的对应类标签一致,则表示分类结果正确,否则,分类结果不正确,以此得 到步骤(2.2)分类结果的准确率,根据准确率确定当前的分类结果评分。
(2.4)将步骤(2.2)计算得到的所有任务对应的测试集的分类结果集成成 总的分类结果,基于均值迭代方法更新各任务上各类标签的先验概率,返回步骤 (2.2)。
其中,基于均值迭代方法更新各任务上各类标签的先验概率的方法如下:
(2.4.1)重新计算各任务上各类标签的先验概率。
按照下式重新计算第t个任务中的第k个类标签Ck的先验概率Pt'(Ck):
Figure BDA0002241916070000111
其中Testt(Ck)表示经步骤(2.2)计算得到的测试集Testt的所有实例中类标签为Ck的实例数量,Testt(all)表示第t个任务对应的测试集中的实例数目。
(2.4.2)对(2.4.1)中重新计算得到的各任务上各类标签的先验概率进行调 整,调整后的各任务上各类标签的先验概率作为第m次迭代循环过程中所采用的 各任务上各类标签的先验概率。
按照下式对重新计算得到的第t个任务中的第k个类标签Ck的先验概率 Pt'(Ck)进行调整,得到第m次迭代循环过程中第t个任务中的第k个类标签Ck的先验概率:
Pm(Ck)=α·Pt'(Ck)+(1-α)·Pm-1(Ck)(10)
其中α设置为0.3;Pm(Ck)表示第m次迭代过程中第t个任务中的第k个类标签 Ck的先验概率
Figure BDA0002241916070000112
Pm-1(Ck)表示第m-1次迭代过程中第t个任务中的第k个 类标签Ck的先验概率
Figure BDA0002241916070000113
如果当前m=1,则
Figure BDA0002241916070000114
即表示利用第t个任 务对应的训练集Traint计算得到的第t个任务中的第k个类标签Ck的先验概率 Pt(Ck)。
随着各任务上各类标签的先验概率的不断更新,不同任务上先验补充信息 Pt'(Ck)的比重将不断增加。随着循环迭代次数的增多,各个任务的
Figure BDA0002241916070000115
将逐渐 趋于一致。
(2.5)在循环迭代过程中,如果第m次迭代过程中得到的分类结果评分大 于m-1次迭代过程中得到的分类结果评分,则将第m次迭代过程中所采用的各 任务上类标签的先验概率
Figure BDA0002241916070000121
作为当前的各任务上类标签的最优先验值
在循环迭代过程中,计算第m次和第m-1次迭代时的不同任务上类标签的先 验概率误差的绝对值之和Δ,
Figure BDA0002241916070000123
超参数ε设置为 10-5。当Δ小于设定的超参数ε时收敛,得到训练好的多任务学习模型,输出最终 的各任务上各类标签的最优先验值
Figure BDA0002241916070000124
以及各任务中各特征变量在各类标签 上的条件概率Pt(xd|Ck)。
实施例2:
参照图1,为一种基于多任务学习模型的预测方法的流程图,方法包括:
基于实施例1所提供的基于均值迭代的多任务学习模型训练方法,得到训练 好的多任务学习模型.
对于T个任务中的第t(t=1,2,3…T)个任务中待进行预测的实例n,获取该 实例的D个特征变量x1,n,x2,n,…,xD,n,基于训练好的多任务学习模型最终输出的 第t(t=1,2,3…T)个任务上各类标签的最优先验值
Figure BDA0002241916070000125
和第t(t=1,2,3…T)个 任务中各特征变量在各类标签上的条件概率Pt(xd|Ck),根据公式(11)即可得到 该实例n的类标签。
Figure BDA0002241916070000126
本实施例3:
一种多数据集学生成绩预测模型的训练方法,包括
(1)设有多所学校(一所学校对应一个任务),对于各学校分别采集含D个 特征变量的多个实例。实例即学生,学生对应的D个特征变量分别可以包括考试 年份、有资格获得免费校餐的学生百分比、VR第一等级(口头推理测试的最高等 级)学生百分比、学校性别(S.GN.)、学校教派、学生性别,学生民族、VR波段(可 以取值1、2或3)。同时获取每一个任务的各实例所对应的类标签,得到各任务的 样本数据集。将各任务t的样本数据集均划分为训练集和测试集;
(2)训练多任务学习模型;
(2.1)对于每一个任务,利用其对应的训练集得到各任务上各类标签的先验 概率以及各任务中各特征变量在各类标签上的条件概率;
(2.2)根据各任务上各类标签的先验概率以及各任务中各特征变量在各类 标签上的条件概率计算每一个任务对应的测试集中各实例的类标签;
(2.3)将步骤(2.2)计算得到的各任务测试集中各实例的类标签与步骤(1) 中采集得到的对应类标签进行比较,得到步骤(2.2)计算得到的分类结果的F1 值,根据F1值确定当前的分类结果评分;
(2.4)将步骤(2.2)计算得到的所有任务对应的测试集的分类结果集成成 总的分类结果,基于均值迭代方法更新各任务上各类标签的先验概率,返回步骤 (2.2);
(2.5)在循环迭代过程中,如果第m次迭代过程中得到的分类结果评分大 于m-1次迭代过程中得到的分类结果评分,则将第m次迭代过程中所采用的各 任务上类标签的先验概率作为当前的各任务上类标签的最优先验值;
在循环迭代过程中,计算第m次和第m-1次迭代时的不同任务上类标签的先 验概率误差的绝对值之和Δ,当Δ小于设定的超参数ε时收敛,得到训练好的多任 务学习模型,输出最终的各任务上各类标签的最优先验值。
将来自伦敦教育管理局的数据作为数据集,采用实施例1中提供的方法训练 得到对应的多数据集学生成绩预测分类器。数据集由伦敦139所中学的15362名学 生在1985年、1986年和1987年间的考试成绩组成。因此,伦敦139所中学对应139 个任务,对应于预测学生的表现。特征变量包括考试年份(YR)、4个学校属性和3 个学生属性。每一所学校在某一年不变的属性是:有资格获得免费校餐的学生百分 比、VR第一等级(口头推理测试的最高等级)学生百分比、学校性别(S.GN.)和学校 教派(S.DN.)。学生特有的属性有:性别(GEN)、VR波段(可以取值1、2或3)、民族 (ETH)。
在本实施例中,为每个可能的属性值使用一个二进制变量替换了分类属性 (即所有不是百分比的属性),总共得到27个属性。同时对该数据集进行划分,当 成绩大于20时为正样本,有6984条,占比45.46%;当成绩小于等于20时为负样本, 有8378条,占比54.54%,正负样本之比约为1:1.
通过5倍交叉验证法随机分割生成训练集和测试集,每个学校(task)80%的样 本属于训练集,20%的样本属于测试集。注意到每个任务(school)的实例(学生)数 量不同。平均每所学校有80名学生参加训练,每所学校有30名学生参加考试。
对于上述数据集,采用本发明提供的方法训练得到对应的多数据集学生成绩 预测分类器以及集中传统多任务学习模型训练得到的分类器,计算不同任务上最 终分类结果的F1值,得到实验结果如下表所示。
真实数据集
Figure BDA0002241916070000141
从上表可以看出,采用本发明提供的方法训练得到对应的多数据集学生成绩 预测分类器达到了最佳的分类效果,其分类效果优于其它传统模型。此外, NB-MTL(Optimal)模型虽然F1值仅高于VSTG-MTL模型0.004和spMTFL模型 0.008,但在实验过程中,本发明方法的计算时间仅为VSTG-MTL模型的1/5和 spMTFL模型的1/3,时间开销大幅下降。实验结果表明,本发明用于提高数据的 分类效果和模型的泛化性能是可行的。较传统多任务学习模型而言,本发明无论 在分类结果上,还是计算时间上,都能够取得最优的分类效果。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细, 但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普 通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进, 这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求 为准。

Claims (10)

1.一种基于均值迭代的多任务学习模型训练方法,其特征在于,包括:
(1)设有T个任务,对于各任务t(t=1,2,3…T)分别采集含D个特征变量的多个实例,同时获取每一个任务的各实例所对应的类标签,得到各任务的样本数据集;将各任务t的样本数据集均划分为训练集和测试集;
(2)训练多任务学习模型;
(2.1)对于每一个任务,利用其对应的训练集得到各任务上各类标签的先验概率以及各任务中各特征变量在各类标签上的条件概率;
(2.2)根据各任务上各类标签的先验概率以及各任务中各特征变量在各类标签上的条件概率计算每一个任务对应的测试集中各实例的类标签;
(2.3)将步骤(2.2)计算得到的各任务测试集中各实例的类标签与步骤(1)中采集得到的对应类标签进行比较,得到步骤(2.2)计算得到的分类结果的准确率或者F1值,根据准确率或者F1值确定当前的分类结果评分;
(2.4)将步骤(2.2)计算得到的所有任务对应的测试集的分类结果集成成总的分类结果,基于均值迭代方法更新各任务上各类标签的先验概率,返回步骤(2.2);
(2.5)在循环迭代过程中,如果第m次迭代过程中得到的分类结果评分大于m-1次迭代过程中得到的分类结果评分,则将第m次迭代过程中所采用的各任务上类标签的先验概率作为当前的各任务上类标签的最优先验值;
在循环迭代过程中,计算第m次和第m-1次迭代时的不同任务上类标签的先验概率误差的绝对值之和Δ,当Δ小于设定的超参数ε时收敛,得到训练好的多任务学习模型,输出最终的各任务上各类标签的最优先验值。
2.根据权利要求1所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(1)中,每一个任务的所有实例的类标签用{C1,C2,…,CK}表示,K表示所有实例的类标签的总个数,x1,x2,…,xD表示每一个实例的特征变量,特征变量的总数为D;
对于第t个任务的样本数据集,将其划分为训练集和测试集,Traint表示第t个任务对应的训练集,其中的实例数分别有Traint(all)个,Testt表示第t个任务对应的测试集,其中的实例数分别有Testt(all)个;第t个任务的样本数据集中的所有实例的总数Nt即为Traint(all)+Testt(all)。
T个任务总的训练集Train即为{Train1,……,Traint,……,TrainT},T个任务总的测试集Test即为{Test1,……,Testt,……,TestT}。
3.根据权利要求2所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(2.1)中,利用各任务对应的训练集计算各任务上各类标签的先验概率的方法如下:
对于第t个任务对应的训练集Traint,令Traint(Ck)表示Traint的所有实例中类标签为Ck的实例数量,则第t个任务中的第k个类标签Ck的先验概率为:
Pt(Ck)=Traint(Ck)/Traint(all)
步骤(2.1)中,各任务中各特征变量在各类标签上的条件概率的计算方法如下:
对于第t个任务对应的训练集Traint,Traint(Ck,xd)表示Traint的所有实例中类标签为Ck且第d个特征变量xd的取值相同的实例数量,则第t个任务中的第d个的特征变量xd在第k个类标签Ck上的条件概率为:
Pt(xd|Ck)=Traint(Ck,xd)/Traint(Ck)。
4.根据权利要求3所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(2.1)中,采用拉普拉斯修正各任务上各类标签的先验概率以及各任务中各特征变量在各类标签上的条件概率,方法如下:
令kd表示训练集Train中的第d个特征变量xd可能的取值的总数量,修正后的第t个任务中的第k个类标签Ck的先验概率Pt(Ck)以及第t个任务中的第d个的特征变量xd在第k个类标签Ck上的条件概率Pt(xd|Ck),分别为:
Pt(Ck)=(Traint(Ck)+1)/(Traint(all)+K)
Pt(xd|Ck)=(Traint(Ck,xd)+1)/(Traint(Ck)+kd)。
5.根据权利要求4所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(2.2)中,计算每一个任务对应的测试集中各实例的类标签的方法如下:
对于第t个任务对应的测试集Testt中的各实例的类标签通过下式计算得到:
Figure FDA0002241916060000031
其中:
Figure FDA0002241916060000032
表示从K个类标签{C1,C2,…,CK}中选取概率值最大的那一个类标签作为当前实例计算得到的类标签。
6.根据权利要求5所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(2.4)中,基于均值迭代方法更新各任务上各类标签的先验概率的方法如下:
(2.4.1)重新计算各任务上各类标签的先验概率;
按照下式重新计算第t个任务中的第k个类标签Ck的先验概率Pt'(Ck):
其中Testt(Ck)表示经步骤(2.2)计算得到的测试集Testt的所有实例中类标签为Ck的实例数量,Testt(all)表示第t个任务对应的测试集中的实例数目。
(2.4.2)对(2.4.1)中重新计算得到的各任务上各类标签的先验概率进行调整,调整后的各任务上各类标签的先验概率作为第m次迭代循环过程中所采用的各任务上各类标签的先验概率;
按照下式对重新计算得到的第t个任务中的第k个类标签Ck的先验概率Pt'(Ck)进行调整,得到第m次迭代循环过程中第t个任务中的第k个类标签Ck的先验概率:
Pm(Ck)=α·Pt'(Ck)+(1-α)·Pm-1(Ck)
其中α设置为0.3;Pm(Ck)表示第m次迭代过程中第t个任务中的第k个类标签Ck的先验概率Pm-1(Ck)表示第m-1次迭代过程中第t个任务中的第k个类标签Ck的先验概率
Figure FDA0002241916060000042
如果当前m=1,则
Figure FDA0002241916060000043
即表示利用第t个任务对应的训练集Traint计算得到的第t个任务中的第k个类标签Ck的先验概率Pt(Ck)。
7.根据权利要求6所述的基于均值迭代的多任务学习模型训练方法,其特征在于,步骤(2.5)中,
Figure FDA0002241916060000044
超参数ε设置为10-5
8.一种基于多任务学习模型的预测方法,其特征在于,包括
采用如权利要求1至7中任一权利要求所述的基于均值迭代的多任务学习模型训练方法,得到训练好的多任务学习模型;
对于T个任务中的第t(t=1,2,3…T)个任务中待进行预测的实例n,获取该实例的D个特征变量x1,n,x2,n,…,xD,n,基于训练好的多任务学习模型最终输出的第t(t=1,2,3…T)个任务上各类标签的最优先验值
Figure FDA0002241916060000045
和第t(t=1,2,3…T)个任务中各特征变量在各类标签上的条件概率Pt(xd|Ck),根据下式即可得到该实例n的类标签,
Figure FDA0002241916060000051
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求8所述基于多任务学习模型的预测方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求8所述基于多任务学习模型的预测方法的步骤。
CN201911003056.1A 2019-10-22 2019-10-22 基于均值迭代的多任务学习模型训练以及预测方法 Active CN110738270B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201911003056.1A CN110738270B (zh) 2019-10-22 2019-10-22 基于均值迭代的多任务学习模型训练以及预测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201911003056.1A CN110738270B (zh) 2019-10-22 2019-10-22 基于均值迭代的多任务学习模型训练以及预测方法

Publications (2)

Publication Number Publication Date
CN110738270A true CN110738270A (zh) 2020-01-31
CN110738270B CN110738270B (zh) 2022-03-11

Family

ID=69270785

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201911003056.1A Active CN110738270B (zh) 2019-10-22 2019-10-22 基于均值迭代的多任务学习模型训练以及预测方法

Country Status (1)

Country Link
CN (1) CN110738270B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112766514A (zh) * 2021-01-22 2021-05-07 支付宝(杭州)信息技术有限公司 一种联合训练机器学习模型的方法、系统及装置
CN112801203A (zh) * 2021-02-07 2021-05-14 新疆爱华盈通信息技术有限公司 基于多任务学习的数据分流训练方法及系统

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180336484A1 (en) * 2017-05-18 2018-11-22 Sas Institute Inc. Analytic system based on multiple task learning with incomplete data
CN109063743A (zh) * 2018-07-06 2018-12-21 云南大学 基于半监督多任务学习的医疗数据分类模型的构建方法
CN109450834A (zh) * 2018-10-30 2019-03-08 北京航空航天大学 基于多特征关联和贝叶斯网络的通信信号分类识别方法
CN109815826A (zh) * 2018-12-28 2019-05-28 新大陆数字技术股份有限公司 人脸属性模型的生成方法及装置
CN110188358A (zh) * 2019-05-31 2019-08-30 北京神州泰岳软件股份有限公司 自然语言处理模型的训练方法及装置

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180336484A1 (en) * 2017-05-18 2018-11-22 Sas Institute Inc. Analytic system based on multiple task learning with incomplete data
CN109063743A (zh) * 2018-07-06 2018-12-21 云南大学 基于半监督多任务学习的医疗数据分类模型的构建方法
CN109450834A (zh) * 2018-10-30 2019-03-08 北京航空航天大学 基于多特征关联和贝叶斯网络的通信信号分类识别方法
CN109815826A (zh) * 2018-12-28 2019-05-28 新大陆数字技术股份有限公司 人脸属性模型的生成方法及装置
CN110188358A (zh) * 2019-05-31 2019-08-30 北京神州泰岳软件股份有限公司 自然语言处理模型的训练方法及装置

Non-Patent Citations (4)

* Cited by examiner, † Cited by third party
Title
MURUGESAN K ET AL: "Self-paced multitask learning with shared knowledge", 《ARXIV》 *
ZHOU Y ET AL: "An ensemble learning approach for XSS attack detection with domain knowledge and threat intelligence", 《COMPUTERS & SECURITY》 *
周国军等: "基于Hadoop的并行朴素贝叶斯分类算法", 《玉林师范学院学报》 *
赵兴刚等: "基于K-L散度和散度均值的改进矩阵CFAR检测器", 《中国科学:信息科学》 *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112766514A (zh) * 2021-01-22 2021-05-07 支付宝(杭州)信息技术有限公司 一种联合训练机器学习模型的方法、系统及装置
CN112801203A (zh) * 2021-02-07 2021-05-14 新疆爱华盈通信息技术有限公司 基于多任务学习的数据分流训练方法及系统

Also Published As

Publication number Publication date
CN110738270B (zh) 2022-03-11

Similar Documents

Publication Publication Date Title
US20200143248A1 (en) Machine learning model training method and device, and expression image classification method and device
WO2020143130A1 (zh) 基于物理环境博弈的自主进化智能对话方法、系统、装置
US20220076136A1 (en) Method and system for training a neural network model using knowledge distillation
CN111753101B (zh) 一种融合实体描述及类型的知识图谱表示学习方法
US11093714B1 (en) Dynamic transfer learning for neural network modeling
CN107203600B (zh) 一种利用刻画因果依赖关系和时序影响机制增强答案质量排序的评判方法
CN110633786A (zh) 用于确定人工神经网络拓扑的技术
US11488309B2 (en) Robust machine learning for imperfect labeled image segmentation
CN107608953B (zh) 一种基于不定长上下文的词向量生成方法
CN109063743B (zh) 基于半监督多任务学习的医疗数据分类模型的构建方法
US20210271980A1 (en) Deterministic decoder variational autoencoder
CN110738270B (zh) 基于均值迭代的多任务学习模型训练以及预测方法
CN111008224A (zh) 一种基于深度多任务表示学习的时间序列分类和检索方法
CN113963165A (zh) 一种基于自监督学习的小样本图像分类方法及系统
CN111784595A (zh) 一种基于历史记录的动态标签平滑加权损失方法及装置
CN114943017A (zh) 一种基于相似性零样本哈希的跨模态检索方法
CN114743037A (zh) 一种基于多尺度结构学习的深度医学图像聚类方法
CN113901991A (zh) 一种基于伪标签的3d点云数据半自动标注方法及装置
CN115795065A (zh) 基于带权哈希码的多媒体数据跨模态检索方法及系统
Fine et al. Query by committee, linear separation and random walks
CN110222737A (zh) 一种基于长短时记忆网络的搜索引擎用户满意度评估方法
CN113408418A (zh) 一种书法字体与文字内容同步识别方法及系统
CN111737467B (zh) 一种基于分段卷积神经网络的对象级情感分类方法
CN110766069B (zh) 基于最优值迭代的多任务学习模型训练以及预测方法
CN115600595A (zh) 一种实体关系抽取方法、系统、设备及可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant