CN112529183A - 一种基于知识蒸馏的模型自适应更新方法 - Google Patents

一种基于知识蒸馏的模型自适应更新方法 Download PDF

Info

Publication number
CN112529183A
CN112529183A CN202110178302.8A CN202110178302A CN112529183A CN 112529183 A CN112529183 A CN 112529183A CN 202110178302 A CN202110178302 A CN 202110178302A CN 112529183 A CN112529183 A CN 112529183A
Authority
CN
China
Prior art keywords
model
distance
samples
time instant
adaptive
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110178302.8A
Other languages
English (en)
Inventor
李劲松
朱世强
吕卫国
池胜强
田雨
周天舒
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202110178302.8A priority Critical patent/CN112529183A/zh
Publication of CN112529183A publication Critical patent/CN112529183A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computational Linguistics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于知识蒸馏的模型自适应更新方法,本发明采用模型自适应更新方法,代替模型重训练过程,减少了计算资源和人力资源的投入;采用模型参数相似性约束,提炼旧模型中的知识,避免了模型更新中的灾难性遗忘现象,保持预测模型的稳定性;利用知识蒸馏的思想,构建实时预测的神经网络模型,使预测模型适应数据分布的变化,保证预测模型的可塑性,实现模型自适应更新中稳定性和可塑性的最佳权衡。相较于在线维护模型池,对新数据同时预测的方法,大大减少了模型实时预测需要的计算资源和内存资源。相较于直接利用新数据增量更新模型的方法,有效解决了模型更新中的灾难性遗忘现象。

Description

一种基于知识蒸馏的模型自适应更新方法
技术领域
本发明属于机器学习技术领域,具体地,涉及一种基于知识蒸馏的模型自适应更新方法。
背景技术
基于机器学习的数据自动预测方法的一个假设是:模型的训练数据和测试数据来自于同一个总体分布。然而,随着时间的推移,数据分布会发生变化。数据分布的变化可以进一步分为样本的变化和类别的变化。所谓样本的变化,是指样本在特征同构空间下的特征值的变化,以及每一类样本所占比例的可能变化。类别的变化是指新的类别的出现,即原来的分类发生了变化。所以,一段时间后,基于历史数据训练的模型可能不适用于一些新的数据。因此,有必要面向自动化预测系统的实际应用,实现模型的自适应更新,以保证不断变化的数据能够被正确预测。常用的模型自适应更新方法有模型重训练、不同时间窗口的模型集成和增量学习三种。
模型重训练需要消耗大量的计算资源和建模时间。不同时间窗口的模型集成需要维护一个模型池,对新数据同时进行打分,会消耗大量的计算资源。增量学习方法则存在灾难性遗忘现象,即随着时间的推移,模型使用最新的数据进行更新,新获得的数据往往会抹去之前学习到的模式;增量学习方法需要具备从新数据中持续学习的能力,同时保留以前学到的知识,是模型自适应更新中的稳定性-可塑性困境。
发明内容
本发明的目的在于针对现有技术的不足,提供一种基于知识蒸馏的模型自适应更新方法。
本发明的目的是通过以下技术方案来实现的:一种基于知识蒸馏的模型自适应更新方法,该方法包括以下步骤:
(1)在时刻
Figure 984372DEST_PATH_IMAGE001
,基于初始数据
Figure 582844DEST_PATH_IMAGE002
Figure 351211DEST_PATH_IMAGE003
训练一个模型
Figure 376936DEST_PATH_IMAGE004
,其中,
Figure 847231DEST_PATH_IMAGE002
Figure 233082DEST_PATH_IMAGE001
时刻数据的特征,
Figure 105223DEST_PATH_IMAGE003
Figure 52582DEST_PATH_IMAGE001
时刻数据的标签;
(2)利用模型
Figure 744594DEST_PATH_IMAGE005
Figure 216027DEST_PATH_IMAGE002
预测,得到
Figure 926363DEST_PATH_IMAGE002
的预测软标签
Figure 559470DEST_PATH_IMAGE006
(3)基于
Figure 4357DEST_PATH_IMAGE007
Figure 233476DEST_PATH_IMAGE003
Figure 814630DEST_PATH_IMAGE006
训练一个神经网络模型
Figure 133484DEST_PATH_IMAGE008
,模型
Figure 800089DEST_PATH_IMAGE008
的输入为
Figure 347745DEST_PATH_IMAGE002
,标签 为
Figure 799717DEST_PATH_IMAGE009
Figure 774627DEST_PATH_IMAGE006
,输出为
Figure 194107DEST_PATH_IMAGE010
,损失函数为:
Figure 263563DEST_PATH_IMAGE011
其中,
Figure 819309DEST_PATH_IMAGE012
为神经网络模型中的参数,
Figure 965119DEST_PATH_IMAGE013
为调整损失函数中
Figure 357049DEST_PATH_IMAGE014
Figure 512087DEST_PATH_IMAGE015
权重的系数,基于模型
Figure 922339DEST_PATH_IMAGE005
预测的软标签
Figure 488319DEST_PATH_IMAGE006
的信息熵确定;
Figure 616812DEST_PATH_IMAGE014
Figure 795115DEST_PATH_IMAGE003
Figure 591032DEST_PATH_IMAGE016
之间的对数损失函数;
Figure 813066DEST_PATH_IMAGE017
Figure 412544DEST_PATH_IMAGE006
Figure 909384DEST_PATH_IMAGE010
之间的对数损失函数;
(4)在时刻
Figure 294229DEST_PATH_IMAGE018
Figure 703476DEST_PATH_IMAGE019
执行基于知识蒸馏的模型自适应更新,步骤如下:
a.在时刻
Figure 72140DEST_PATH_IMAGE020
,基于初始数据
Figure 841513DEST_PATH_IMAGE021
Figure 330132DEST_PATH_IMAGE022
训练一个模型
Figure 159548DEST_PATH_IMAGE023
b.利用模型
Figure 749929DEST_PATH_IMAGE024
Figure 73725DEST_PATH_IMAGE021
预测,得到
Figure 167583DEST_PATH_IMAGE021
的预测软标签
Figure 417168DEST_PATH_IMAGE025
c.基于
Figure 494845DEST_PATH_IMAGE021
Figure 606021DEST_PATH_IMAGE022
Figure 85544DEST_PATH_IMAGE025
和模型
Figure 7495DEST_PATH_IMAGE026
训练神经网络模型
Figure 306889DEST_PATH_IMAGE027
,模型
Figure 471023DEST_PATH_IMAGE027
的输入为
Figure 70631DEST_PATH_IMAGE021
, 标签为
Figure 412751DEST_PATH_IMAGE028
Figure 684595DEST_PATH_IMAGE025
,输出为
Figure 137573DEST_PATH_IMAGE029
;利用模型
Figure 591688DEST_PATH_IMAGE026
的参数对模型
Figure 353976DEST_PATH_IMAGE027
的参数进行初始化,模 型
Figure 627963DEST_PATH_IMAGE026
的参数在模型
Figure 884632DEST_PATH_IMAGE030
训练过程中保持不变;损失函数为:
Figure 654969DEST_PATH_IMAGE031
其中,
Figure 338891DEST_PATH_IMAGE032
为神经网络模型
Figure 100174DEST_PATH_IMAGE027
中的参数;
Figure 409801DEST_PATH_IMAGE033
为调整损失函数中
Figure 572930DEST_PATH_IMAGE034
Figure 912906DEST_PATH_IMAGE035
权重的系数,基于模型
Figure 161485DEST_PATH_IMAGE024
预测的软标签
Figure 759957DEST_PATH_IMAGE025
的信息熵确定;
Figure 26859DEST_PATH_IMAGE036
为调整 损失函数中
Figure 52584DEST_PATH_IMAGE037
权重的系数,基于数据集
Figure 522879DEST_PATH_IMAGE038
Figure 675774DEST_PATH_IMAGE021
的相似性确定;
Figure 547915DEST_PATH_IMAGE039
Figure 10121DEST_PATH_IMAGE040
Figure 951401DEST_PATH_IMAGE041
之间的对数损失函数;
Figure 157254DEST_PATH_IMAGE035
Figure 149481DEST_PATH_IMAGE025
Figure 267741DEST_PATH_IMAGE042
之间的对数损失函数;
Figure 978208DEST_PATH_IMAGE043
为模型参数相似性约束项,以模型
Figure 722173DEST_PATH_IMAGE026
Figure 818174DEST_PATH_IMAGE030
中所有参数的距离进行 度量;
利用真实数据进行模型训练,得到模型参数
Figure 356602DEST_PATH_IMAGE032
,从而确定模型。
进一步地,模型
Figure 773939DEST_PATH_IMAGE044
选用以下机器学习方法:神经网络、逻 辑回归、支持向量机、决策树、随机森林。
进一步地,所述步骤(3)中:
Figure 321595DEST_PATH_IMAGE045
其中,
Figure 757256DEST_PATH_IMAGE046
Figure 981433DEST_PATH_IMAGE047
时刻的样本总量,
Figure 400913DEST_PATH_IMAGE048
Figure 486681DEST_PATH_IMAGE047
时刻第j个样本的预测软标签。
进一步地,所述步骤(3)中:
Figure 58738DEST_PATH_IMAGE049
Figure 938970DEST_PATH_IMAGE050
其中,
Figure 580167DEST_PATH_IMAGE051
Figure 718893DEST_PATH_IMAGE047
时刻的样本总量,
Figure 394725DEST_PATH_IMAGE052
Figure 196590DEST_PATH_IMAGE047
时刻第j个样本的标签,
Figure 590662DEST_PATH_IMAGE048
Figure 18232DEST_PATH_IMAGE047
时刻第j 个样本的预测软标签,
Figure 814150DEST_PATH_IMAGE053
Figure 754293DEST_PATH_IMAGE047
时刻第j个样本的神经网络模型预测输出。
进一步地,所述步骤(4)中:
Figure 635661DEST_PATH_IMAGE054
其中,
Figure 617655DEST_PATH_IMAGE055
Figure 268079DEST_PATH_IMAGE056
时刻的样本总量,
Figure 661014DEST_PATH_IMAGE057
Figure 13367DEST_PATH_IMAGE056
时刻第j个样本的预测软标签。
进一步地,所述步骤(4)中:
Figure 782740DEST_PATH_IMAGE058
Figure 553250DEST_PATH_IMAGE059
其中,
Figure 133398DEST_PATH_IMAGE060
为数据集
Figure 458200DEST_PATH_IMAGE061
Figure 14952DEST_PATH_IMAGE062
之间的距离,
Figure 312073DEST_PATH_IMAGE055
Figure 797543DEST_PATH_IMAGE056
时刻的样本总量,
Figure 517554DEST_PATH_IMAGE063
Figure 715186DEST_PATH_IMAGE064
时刻的样本总量,
Figure 151984DEST_PATH_IMAGE065
分别为
Figure 451378DEST_PATH_IMAGE066
中的第p,q个样本,
Figure 382556DEST_PATH_IMAGE067
分别为
Figure 451006DEST_PATH_IMAGE062
中的第 p,q个样本;
Figure 793126DEST_PATH_IMAGE068
函数用于计算两个样本间的距离。
进一步地,
Figure 829084DEST_PATH_IMAGE069
函数用于计算两个样本间的距离,距离采用:曼哈顿距离、欧氏距 离、切比雪夫距离、余弦距离。
进一步地,所述步骤(4)中:
Figure 547641DEST_PATH_IMAGE070
Figure 736177DEST_PATH_IMAGE071
其中,
Figure 999931DEST_PATH_IMAGE055
Figure 8338DEST_PATH_IMAGE056
时刻的样本总量,
Figure 779854DEST_PATH_IMAGE072
Figure 822896DEST_PATH_IMAGE056
时刻第j个样本的标签,
Figure 506818DEST_PATH_IMAGE073
Figure 18833DEST_PATH_IMAGE056
时刻第j个 样本的预测软标签,
Figure 813614DEST_PATH_IMAGE074
Figure 976742DEST_PATH_IMAGE056
时刻第j个样本的神经网络模型预测输出。
进一步地,
Figure 346412DEST_PATH_IMAGE075
用于控制模型
Figure 63833DEST_PATH_IMAGE076
参数在训练过程中的更新幅度,以 模型
Figure 927883DEST_PATH_IMAGE026
Figure 696250DEST_PATH_IMAGE077
中所有参数的距离进行度量,距离采用:曼哈顿距离、欧氏距离、切比雪 夫距离、余弦距离。
本发明的有益效果是:本发明采用模型自适应更新方法,代替模型重训练过程,减少了计算资源和人力资源的投入;采用模型参数相似性约束,提炼旧模型中的知识,避免了模型更新中的灾难性遗忘现象,保持预测模型的稳定性;利用知识蒸馏的思想,构建实时预测的神经网络模型,使预测模型适应数据分布的变化,保证预测模型的可塑性,实现模型自适应更新中稳定性和可塑性的最佳权衡。相较于在线维护模型池,对新数据同时预测的方法,大大减少了模型实时预测需要的计算资源和内存资源。相较于直接利用新数据增量更新模型的方法,有效解决了模型更新中的灾难性遗忘现象。
附图说明
图1为本发明基于知识蒸馏的模型自适应更新方法模型结构图;
图2为本发明基于知识蒸馏的模型自适应更新方法流程图;
图3为自适应更新模型结构图。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图对本发明的具体实施方式做详细的说明。
在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是本发明还可以采用其他不同于在此描述的方式来实施,本领域技术人员可以在不违背本发明内涵的情况下做类似推广,因此本发明不受下面公开的具体实施例的限制。
本发明中所述的知识蒸馏作为一种有效的模型压缩方法,利用一个小模型来模仿大模型(或模型集合)的预测能力,从而保留大模型学习到的知识。
如图1、2所示,本发明提出的一种基于知识蒸馏的模型自适应更新方法,包括以下步骤:
(1)在时刻
Figure 987554DEST_PATH_IMAGE001
,基于初始数据
Figure 441538DEST_PATH_IMAGE002
Figure 843701DEST_PATH_IMAGE003
训练一个模型
Figure 981421DEST_PATH_IMAGE004
,其中,
Figure 928780DEST_PATH_IMAGE002
Figure 151951DEST_PATH_IMAGE001
时刻数据的特征,
Figure 92225DEST_PATH_IMAGE003
Figure 84452DEST_PATH_IMAGE001
时刻数据的标签;模型
Figure 966826DEST_PATH_IMAGE005
可以采用任意一种机器学习 方法,包括神经网络、逻辑回归、支持向量机、决策树、随机森林等。
(2)利用模型
Figure 146134DEST_PATH_IMAGE005
Figure 375253DEST_PATH_IMAGE002
预测,得到
Figure 487565DEST_PATH_IMAGE002
的预测软标签
Figure 25994DEST_PATH_IMAGE006
(3)基于
Figure 473025DEST_PATH_IMAGE007
Figure 755101DEST_PATH_IMAGE003
Figure 721920DEST_PATH_IMAGE006
训练一个神经网络模型
Figure 205420DEST_PATH_IMAGE008
Figure 359321DEST_PATH_IMAGE008
Figure 710668DEST_PATH_IMAGE001
时刻得到的最终模 型,用于
Figure 515682DEST_PATH_IMAGE078
时刻的数据预测,模型
Figure 927072DEST_PATH_IMAGE008
的输入为
Figure 568269DEST_PATH_IMAGE002
,标签为
Figure 208460DEST_PATH_IMAGE003
Figure 618712DEST_PATH_IMAGE006
,模型
Figure 935424DEST_PATH_IMAGE008
的输出为
Figure 313185DEST_PATH_IMAGE016
,损失函数为:
Figure 6334DEST_PATH_IMAGE011
其中,
Figure 271094DEST_PATH_IMAGE012
为神经网络模型中的参数,
Figure 775018DEST_PATH_IMAGE013
为调整损失函数中
Figure 125228DEST_PATH_IMAGE079
Figure 622069DEST_PATH_IMAGE015
权重的系数:
Figure 256181DEST_PATH_IMAGE080
Figure 649116DEST_PATH_IMAGE051
Figure 17781DEST_PATH_IMAGE047
时刻的样本总量,
Figure 537886DEST_PATH_IMAGE081
Figure 777238DEST_PATH_IMAGE047
时刻第j个样本的预测软标签,
Figure 855921DEST_PATH_IMAGE082
Figure 446302DEST_PATH_IMAGE014
Figure 19366DEST_PATH_IMAGE003
Figure 863956DEST_PATH_IMAGE010
之间的对数损失函数:
Figure 864273DEST_PATH_IMAGE083
Figure 941951DEST_PATH_IMAGE052
Figure 302394DEST_PATH_IMAGE047
时刻第j个样本的标签,
Figure 516338DEST_PATH_IMAGE082
Figure 438288DEST_PATH_IMAGE053
Figure 737683DEST_PATH_IMAGE047
时刻第j个样本的神 经网络模型预测输出;
Figure 652549DEST_PATH_IMAGE084
Figure 970267DEST_PATH_IMAGE006
Figure 312386DEST_PATH_IMAGE010
之间的对数损失函数:
Figure 99077DEST_PATH_IMAGE085
(4)在时刻
Figure 568367DEST_PATH_IMAGE018
Figure 491323DEST_PATH_IMAGE019
执行基于知识蒸馏的模型自适应更新,步骤如下:
a. 在时刻
Figure 253612DEST_PATH_IMAGE020
,基于初始数据
Figure 262019DEST_PATH_IMAGE021
Figure 784267DEST_PATH_IMAGE022
,训练一个模型
Figure 578042DEST_PATH_IMAGE023
, 其中,
Figure 996385DEST_PATH_IMAGE021
Figure 741356DEST_PATH_IMAGE056
时刻数据的特征,
Figure 801716DEST_PATH_IMAGE022
Figure 433685DEST_PATH_IMAGE056
时刻数据的标签;模型
Figure 39241DEST_PATH_IMAGE023
可以采用任意一 种机器学习方法,包括神经网络、逻辑回归、支持向量机、决策树、随机森林等;
b. 利用模型
Figure 553399DEST_PATH_IMAGE024
Figure 135559DEST_PATH_IMAGE021
预测,得到
Figure 153194DEST_PATH_IMAGE021
的预测软标签
Figure 444498DEST_PATH_IMAGE025
c. 基于
Figure 399947DEST_PATH_IMAGE021
Figure 67688DEST_PATH_IMAGE022
Figure 408671DEST_PATH_IMAGE025
和模型
Figure 120144DEST_PATH_IMAGE026
,训练神经网络模型
Figure 812156DEST_PATH_IMAGE076
Figure 768742DEST_PATH_IMAGE076
Figure 495390DEST_PATH_IMAGE056
时刻得到的 最终模型,用于
Figure 128497DEST_PATH_IMAGE086
时刻的数据预测,模型
Figure 557073DEST_PATH_IMAGE027
的输入为
Figure 301038DEST_PATH_IMAGE021
,标签为
Figure 632924DEST_PATH_IMAGE022
Figure 436932DEST_PATH_IMAGE025
,模型
Figure 837958DEST_PATH_IMAGE027
的 输出为
Figure 369302DEST_PATH_IMAGE042
,自适应更新模型结构如图3所示;
利用模型
Figure 70542DEST_PATH_IMAGE026
的参数对模型
Figure 530604DEST_PATH_IMAGE027
的参数进行初始化;模型
Figure 684505DEST_PATH_IMAGE026
的参数在模 型
Figure 770273DEST_PATH_IMAGE030
训练过程中保持不变;损失函数为:
Figure 840866DEST_PATH_IMAGE087
其中,
Figure 721097DEST_PATH_IMAGE032
为神经网络模型
Figure 627873DEST_PATH_IMAGE027
中的参数,
Figure 2485DEST_PATH_IMAGE088
为神经网络模型
Figure 943896DEST_PATH_IMAGE026
中的参数;
Figure 713138DEST_PATH_IMAGE033
为调整损失函数中
Figure 107210DEST_PATH_IMAGE039
Figure 534781DEST_PATH_IMAGE089
权重的系数,基于模型
Figure 81431DEST_PATH_IMAGE023
预测的软 标签
Figure 303465DEST_PATH_IMAGE025
的信息熵确定,计算公式如下:
Figure 184833DEST_PATH_IMAGE090
Figure 399782DEST_PATH_IMAGE055
Figure 519048DEST_PATH_IMAGE056
时刻的样本总量,
Figure 928295DEST_PATH_IMAGE057
Figure 296959DEST_PATH_IMAGE056
时刻第j个样本的预测软标签,
Figure 66332DEST_PATH_IMAGE091
Figure 554951DEST_PATH_IMAGE036
为调整损失函数中
Figure 384367DEST_PATH_IMAGE043
权重的系数,基于数据集
Figure 240328DEST_PATH_IMAGE038
Figure 743949DEST_PATH_IMAGE021
的相 似性确定,计算公式如下:
Figure 368965DEST_PATH_IMAGE092
Figure 369282DEST_PATH_IMAGE093
Figure 430648DEST_PATH_IMAGE094
为数据集
Figure 807403DEST_PATH_IMAGE061
Figure 286925DEST_PATH_IMAGE062
之间的距离,
Figure 943297DEST_PATH_IMAGE063
Figure 242691DEST_PATH_IMAGE064
时刻的样本总量,
Figure 141246DEST_PATH_IMAGE095
分 别为
Figure 475275DEST_PATH_IMAGE096
中的第
Figure 817395DEST_PATH_IMAGE097
Figure 620397DEST_PATH_IMAGE098
个样本,
Figure 807796DEST_PATH_IMAGE099
分别为
Figure 245599DEST_PATH_IMAGE100
中的第p,q个样本,
Figure 24200DEST_PATH_IMAGE067
分别为
Figure 32607DEST_PATH_IMAGE062
中 的第p,q个样本;
Figure 40008DEST_PATH_IMAGE101
函数用于计算两个样本间的距离,可以采用曼哈顿距离、欧氏距离、 切比雪夫距离、余弦距离等;
Figure 348630DEST_PATH_IMAGE102
Figure 766973DEST_PATH_IMAGE022
Figure 511944DEST_PATH_IMAGE103
之间的对数损失函数:
Figure 837883DEST_PATH_IMAGE104
Figure 220585DEST_PATH_IMAGE072
Figure 75409DEST_PATH_IMAGE056
时刻第j个样本的标签,
Figure 589567DEST_PATH_IMAGE105
Figure 171726DEST_PATH_IMAGE106
Figure 923782DEST_PATH_IMAGE056
时刻第j个样本的神 经网络模型预测输出;
Figure 215086DEST_PATH_IMAGE107
Figure 436114DEST_PATH_IMAGE025
Figure 307118DEST_PATH_IMAGE103
之间的对数损失函数:
Figure 428527DEST_PATH_IMAGE108
Figure 156311DEST_PATH_IMAGE073
Figure 848324DEST_PATH_IMAGE056
时刻第j个样本的预测软标签;
Figure 804909DEST_PATH_IMAGE109
为模型参数相似性约束项,控制模型
Figure 531557DEST_PATH_IMAGE110
参数在训练过程中的 更新幅度,以神经网络模型
Figure 164664DEST_PATH_IMAGE111
Figure 858819DEST_PATH_IMAGE110
中所有参数的距离进行度量,可以采用曼哈顿距 离、欧氏距离、切比雪夫距离、余弦距离等;优选地,本实施例采用欧氏距离进行度量,计算 公式如下:
Figure 337205DEST_PATH_IMAGE112
利用真实数据进行模型训练,得到模型参数
Figure 669091DEST_PATH_IMAGE113
,从而确定模型。
本发明基于分类器预测的软标签的信息熵,确定样本真实标签和预测软标签在新模型损失函数中的权重;基于数据集的相似性,确定模型参数相似性约束项在新模型损失函数中的权重,保持预测模型的稳定性。
以下给出本发明的具体应用场景,但不限于此:
基于人工智能方法的结直肠癌预后风险预测模型的预测准确率超过了临床常用的结直肠癌分期系统。但是,真实临床场景中,随着时间的推移,人口统计、疾病流行、临床实践和医疗保健系统作为一个整体可能会发生变化,这意味着基于静态截面数据的模型可能会过时,导致预测结果不再准确。其次,模型应用于临床实践会改变结直肠癌临床决策和干预措施,导致新数据的结果分布和预测因子-结果关联关系变化,从而导致模型性能快速衰退。因此,结直肠癌风险特征随时间变化的特性会降低模型临床效用,有必要实现临床风险预测模型的自适应更新,以保证不断变化的数据能够被正确预测,从而保证模型的时效性。
在金融风控领域,由于风险防控方的防御措施会抵御部分恶性攻击事件,风险施加方会不断寻找系统漏洞而采取新的攻击方式,导致新的恶性事件发生等。这些真实场景中的特征变化特性,要求风险防控方用于风险防御的模型具有自适应更新的能力,保证模型可以持续发挥作用。
在推荐系统领域,随着用户行为在系统中的记录不断增多,用户会表现出明显的倾向性;同时,用户也容易受当下热点信息的影响而改变使用行为特征。这些都要求推荐系统可以适应系统特征的变化而自适应更新。
本发明提出的基于知识蒸馏的模型自适应更新方法,可以解决医疗、金融风控、推荐系统等领域的预测系统,随着时间的推移数据分布发生变化,导致不断变化的数据不能被正确预测的问题。
以上所述仅是本发明的优选实施方式,虽然本发明已以较佳实施例披露如上,然而并非用以限定本发明。任何熟悉本领域的技术人员,在不脱离本发明技术方案范围情况下,都可利用上述揭示的方法和技术内容对本发明技术方案做出许多可能的变动和修饰,或修改为等同变化的等效实施例。因此,凡是未脱离本发明技术方案的内容,依据本发明的技术实质对以上实施例所做的任何的简单修改、等同变化及修饰,均仍属于本发明技术方案保护的范围内。

Claims (9)

1.一种基于知识蒸馏的模型自适应更新方法,其特征在于,该方法包括以下步骤:
(1)在时刻
Figure 472164DEST_PATH_IMAGE001
,基于初始数据
Figure 326988DEST_PATH_IMAGE002
Figure 824834DEST_PATH_IMAGE003
训练一个模型
Figure 892147DEST_PATH_IMAGE004
,其中,
Figure 394935DEST_PATH_IMAGE002
Figure 951818DEST_PATH_IMAGE001
时刻 数据的特征,
Figure 687693DEST_PATH_IMAGE005
Figure 807965DEST_PATH_IMAGE001
时刻数据的标签;
(2)利用模型
Figure 211264DEST_PATH_IMAGE006
Figure 876732DEST_PATH_IMAGE002
预测,得到
Figure 116215DEST_PATH_IMAGE002
的预测软标签
Figure 790910DEST_PATH_IMAGE007
(3)基于
Figure 783136DEST_PATH_IMAGE008
Figure 134352DEST_PATH_IMAGE009
Figure 110398DEST_PATH_IMAGE007
训练一个神经网络模型
Figure 808358DEST_PATH_IMAGE010
,模型
Figure 389512DEST_PATH_IMAGE010
的输入为
Figure 459099DEST_PATH_IMAGE002
,标签为
Figure 906130DEST_PATH_IMAGE011
Figure 657048DEST_PATH_IMAGE007
,输出为
Figure 623867DEST_PATH_IMAGE012
,损失函数为:
Figure 83930DEST_PATH_IMAGE013
其中,
Figure 503410DEST_PATH_IMAGE014
为神经网络模型中的参数,
Figure 120336DEST_PATH_IMAGE015
为调整损失函数中
Figure 190929DEST_PATH_IMAGE016
Figure 805581DEST_PATH_IMAGE017
权 重的系数,基于模型
Figure 931931DEST_PATH_IMAGE006
预测的软标签
Figure 86969DEST_PATH_IMAGE007
的信息熵确定;
Figure 28380DEST_PATH_IMAGE016
Figure 63201DEST_PATH_IMAGE009
Figure 457274DEST_PATH_IMAGE018
之间的对数损失函数;
Figure 416002DEST_PATH_IMAGE017
Figure 165915DEST_PATH_IMAGE007
Figure 919107DEST_PATH_IMAGE018
之间的对数损失函数;
(4)在时刻
Figure 3738DEST_PATH_IMAGE019
Figure 749846DEST_PATH_IMAGE020
执行基于知识蒸馏的模型自适应更新,步骤如下:
a.在时刻
Figure 665849DEST_PATH_IMAGE021
,基于初始数据
Figure 324364DEST_PATH_IMAGE022
Figure 912602DEST_PATH_IMAGE023
训练一个模型
Figure 150816DEST_PATH_IMAGE024
b.利用模型
Figure 639435DEST_PATH_IMAGE025
Figure 468851DEST_PATH_IMAGE022
预测,得到
Figure 590391DEST_PATH_IMAGE022
的预测软标签
Figure 429034DEST_PATH_IMAGE026
c.基于
Figure 742466DEST_PATH_IMAGE022
Figure 477204DEST_PATH_IMAGE023
Figure 804149DEST_PATH_IMAGE026
和模型
Figure 712062DEST_PATH_IMAGE027
训练神经网络模型
Figure 394847DEST_PATH_IMAGE028
,模型
Figure 582377DEST_PATH_IMAGE028
的输入为
Figure 616192DEST_PATH_IMAGE022
,标签 为
Figure 796638DEST_PATH_IMAGE029
Figure 645514DEST_PATH_IMAGE026
,输出为
Figure 253213DEST_PATH_IMAGE030
;利用模型
Figure 508745DEST_PATH_IMAGE027
的参数对模型
Figure 267051DEST_PATH_IMAGE028
的参数进行初始化,模型
Figure 190008DEST_PATH_IMAGE027
的参数在模型
Figure 968608DEST_PATH_IMAGE031
训练过程中保持不变;损失函数为:
Figure 960704DEST_PATH_IMAGE032
其中,
Figure 482952DEST_PATH_IMAGE033
为神经网络模型
Figure 57153DEST_PATH_IMAGE028
中的参数;
Figure 960649DEST_PATH_IMAGE034
为调整损失函数中
Figure 987511DEST_PATH_IMAGE035
Figure 765980DEST_PATH_IMAGE036
权重的系数,基于模型
Figure 929108DEST_PATH_IMAGE025
预测的软标签
Figure 315090DEST_PATH_IMAGE026
的信息熵确定;
Figure 829248DEST_PATH_IMAGE037
为调整损失函数中
Figure 912873DEST_PATH_IMAGE038
权重的系数,基于数据集
Figure 133769DEST_PATH_IMAGE039
Figure 939920DEST_PATH_IMAGE022
的相似性确定;
Figure 941374DEST_PATH_IMAGE040
Figure 812378DEST_PATH_IMAGE023
Figure 435252DEST_PATH_IMAGE041
之间的对数损失函数;
Figure 631878DEST_PATH_IMAGE036
Figure 855049DEST_PATH_IMAGE026
Figure 575749DEST_PATH_IMAGE042
之间的对数损失函数;
Figure 567976DEST_PATH_IMAGE038
为模型参数相似性约束项,以模型
Figure 669924DEST_PATH_IMAGE027
Figure 131124DEST_PATH_IMAGE031
中所有参数的距离进行 度量;
利用真实数据进行模型训练,得到模型参数
Figure 609509DEST_PATH_IMAGE033
,从而确定模型。
2.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,模型
Figure 721822DEST_PATH_IMAGE043
选用以下机器学习方法:神经网络、逻辑回归、支持向量机、决 策树、随机森林。
3.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,所述步骤(3)中:
Figure 978360DEST_PATH_IMAGE044
其中,
Figure 176123DEST_PATH_IMAGE045
Figure 989358DEST_PATH_IMAGE046
时刻的样本总量,
Figure 910172DEST_PATH_IMAGE047
Figure 150660DEST_PATH_IMAGE046
时刻第j个样本的预测软标签。
4.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,所述步骤(3)中:
Figure 288250DEST_PATH_IMAGE048
Figure 905176DEST_PATH_IMAGE049
其中,
Figure 726501DEST_PATH_IMAGE045
Figure 91886DEST_PATH_IMAGE046
时刻的样本总量,
Figure 998662DEST_PATH_IMAGE050
Figure 356962DEST_PATH_IMAGE046
时刻第j个样本的标签,
Figure 547641DEST_PATH_IMAGE047
Figure 129932DEST_PATH_IMAGE046
时刻第j个样 本的预测软标签,
Figure 727266DEST_PATH_IMAGE051
Figure 171148DEST_PATH_IMAGE046
时刻第j个样本的神经网络模型预测输出。
5.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,所述步骤(4)中:
Figure 435908DEST_PATH_IMAGE052
其中,
Figure 454679DEST_PATH_IMAGE053
Figure 585315DEST_PATH_IMAGE054
时刻的样本总量,
Figure 19839DEST_PATH_IMAGE055
Figure 670263DEST_PATH_IMAGE054
时刻第j个样本的预测软标签。
6.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,所述步骤(4)中:
Figure 813930DEST_PATH_IMAGE056
Figure 448174DEST_PATH_IMAGE057
其中,
Figure 483126DEST_PATH_IMAGE058
为数据集
Figure 706166DEST_PATH_IMAGE059
Figure 4423DEST_PATH_IMAGE060
之间的距离,
Figure 611116DEST_PATH_IMAGE053
Figure 653022DEST_PATH_IMAGE054
时刻的样本总量,
Figure 543617DEST_PATH_IMAGE061
Figure 58781DEST_PATH_IMAGE062
时 刻的样本总量,
Figure 339721DEST_PATH_IMAGE063
分别为
Figure 982055DEST_PATH_IMAGE064
中的第p,q个样本,
Figure 681152DEST_PATH_IMAGE065
分别为
Figure 383528DEST_PATH_IMAGE060
中的第p,q个样 本;
Figure 948502DEST_PATH_IMAGE066
函数用于计算两个样本间的距离。
7.根据权利要求6所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,
Figure 315898DEST_PATH_IMAGE066
函数用于计算两个样本间的距离,距离采用:曼哈顿距离、欧氏距离、切比雪夫距离、余弦距 离。
8.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,所述步骤(4)中:
Figure 181086DEST_PATH_IMAGE067
Figure 742780DEST_PATH_IMAGE068
其中,
Figure 60628DEST_PATH_IMAGE053
Figure 779186DEST_PATH_IMAGE054
时刻的样本总量,
Figure 685831DEST_PATH_IMAGE069
Figure 464431DEST_PATH_IMAGE054
时刻第j个样本的标签,
Figure 472838DEST_PATH_IMAGE070
Figure 11398DEST_PATH_IMAGE054
时刻第j个样本 的预测软标签,
Figure 523282DEST_PATH_IMAGE071
Figure 472783DEST_PATH_IMAGE054
时刻第j个样本的神经网络模型预测输出。
9.根据权利要求1所述的一种基于知识蒸馏的模型自适应更新方法,其特征在于,
Figure 748913DEST_PATH_IMAGE072
用于控制模型
Figure 278114DEST_PATH_IMAGE028
参数在训练过程中的更新幅度,以模型
Figure 706822DEST_PATH_IMAGE027
Figure 577957DEST_PATH_IMAGE028
中所有参数的距离进行度量,距离采用:曼哈顿距离、欧氏距离、切比雪夫距离、余弦距离。
CN202110178302.8A 2021-02-08 2021-02-08 一种基于知识蒸馏的模型自适应更新方法 Pending CN112529183A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110178302.8A CN112529183A (zh) 2021-02-08 2021-02-08 一种基于知识蒸馏的模型自适应更新方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110178302.8A CN112529183A (zh) 2021-02-08 2021-02-08 一种基于知识蒸馏的模型自适应更新方法

Publications (1)

Publication Number Publication Date
CN112529183A true CN112529183A (zh) 2021-03-19

Family

ID=74975541

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110178302.8A Pending CN112529183A (zh) 2021-02-08 2021-02-08 一种基于知识蒸馏的模型自适应更新方法

Country Status (1)

Country Link
CN (1) CN112529183A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114817742A (zh) * 2022-05-18 2022-07-29 平安科技(深圳)有限公司 基于知识蒸馏的推荐模型配置方法、装置、设备、介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114817742A (zh) * 2022-05-18 2022-07-29 平安科技(深圳)有限公司 基于知识蒸馏的推荐模型配置方法、装置、设备、介质
CN114817742B (zh) * 2022-05-18 2022-09-13 平安科技(深圳)有限公司 基于知识蒸馏的推荐模型配置方法、装置、设备、介质

Similar Documents

Publication Publication Date Title
Liang et al. Adversarial deep reinforcement learning in portfolio management
CN111563706A (zh) 一种基于lstm网络的多变量物流货运量预测方法
CN109902222A (zh) 一种推荐方法及装置
CN106874581A (zh) 一种基于bp神经网络模型的建筑空调能耗预测方法
CN112085254B (zh) 基于多重分形协同度量门控循环单元的预测方法及模型
CN107563542A (zh) 数据预测方法及装置和电子设备
CN111680786B (zh) 一种基于改进权重门控单元的时序预测方法
CN113393057A (zh) 一种基于深度融合机器学习模型的小麦产量集成预测方法
CN110097929A (zh) 一种高炉铁水硅含量在线预测方法
CN115983438A (zh) 数据中心末端空调系统运行策略确定方法及装置
CN116187835A (zh) 一种基于数据驱动的台区理论线损区间估算方法及系统
CN116526473A (zh) 基于粒子群优化lstm的电热负荷预测方法
Strulik Hyperbolic discounting and the time‐consistent solution of three canonical environmental problems
JPH04372046A (ja) 需要量予測方法及び装置
Chen APSO-LSTM: an improved LSTM neural network model based on APSO algorithm
CN112529183A (zh) 一种基于知识蒸馏的模型自适应更新方法
CN113821903B (zh) 温度控制方法和设备、模块化数据中心及存储介质
CN114202065A (zh) 一种基于增量式演化lstm的流数据预测方法及装置
Chen et al. Efficient approximate dynamic programming based on design and analysis of computer experiments for infinite-horizon optimization
CN108009859A (zh) 农产品价格波动预警方法及设备
CN113300884B (zh) 一种基于gwo-svr的分步网络流量预测方法
CN115221782A (zh) 一种大型公共建筑能耗混合预测方法及系统
CN114861555A (zh) 一种基于Copula理论的区域综合能源系统短期负荷预测方法
CN113627687A (zh) 一种基于arima-lstm组合模型的供水量预测方法
CN107273411A (zh) 业务操作与数据库操作数据的关联方法及设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20210319

RJ01 Rejection of invention patent application after publication