CN115240871A - 一种基于深度嵌入聚类元学习的流行病预测方法 - Google Patents

一种基于深度嵌入聚类元学习的流行病预测方法 Download PDF

Info

Publication number
CN115240871A
CN115240871A CN202210887157.5A CN202210887157A CN115240871A CN 115240871 A CN115240871 A CN 115240871A CN 202210887157 A CN202210887157 A CN 202210887157A CN 115240871 A CN115240871 A CN 115240871A
Authority
CN
China
Prior art keywords
segment
meta
clustering
time sequence
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210887157.5A
Other languages
English (en)
Inventor
赵学臣
张中
苗金凤
杨福强
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanchang Institute of Technology
Shandong Womens University
Original Assignee
Nanchang Institute of Technology
Shandong Womens University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanchang Institute of Technology, Shandong Womens University filed Critical Nanchang Institute of Technology
Priority to CN202210887157.5A priority Critical patent/CN115240871A/zh
Publication of CN115240871A publication Critical patent/CN115240871A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H50/00ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
    • G16H50/80ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for detecting, monitoring or modelling epidemics or pandemics, e.g. flu
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00Administration; Management
    • G06Q10/04Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • Physics & Mathematics (AREA)
  • Public Health (AREA)
  • Business, Economics & Management (AREA)
  • Theoretical Computer Science (AREA)
  • Medical Informatics (AREA)
  • General Physics & Mathematics (AREA)
  • Human Resources & Organizations (AREA)
  • General Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Biomedical Technology (AREA)
  • Strategic Management (AREA)
  • Economics (AREA)
  • Marketing (AREA)
  • Databases & Information Systems (AREA)
  • Development Economics (AREA)
  • Entrepreneurship & Innovation (AREA)
  • Primary Health Care (AREA)
  • Operations Research (AREA)
  • Quality & Reliability (AREA)
  • Tourism & Hospitality (AREA)
  • Epidemiology (AREA)
  • General Business, Economics & Management (AREA)
  • Pathology (AREA)
  • Game Theory and Decision Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于深度嵌入聚类元学习的流行病预测方法,包括以下步骤:S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。

Description

一种基于深度嵌入聚类元学习的流行病预测方法
技术领域
本发明涉及流行病预测技术领域,更具体的说是涉及一种基于深度嵌入聚类元学习的流行病预测方法。
背景技术
目前,用于预测流感或其他时间序列数据的机器/深度学习主要分为两类。首先,一些研究人员专注于寻找有效的“特征”。例如,搜索引擎查询数据用于预测Google FluTrends1中的流感。Twitter数据也用于其他研究论文。然而,这些模型通常受到来自互联网搜索等大量信息的不可靠来源的困扰。例如,谷歌的算法很容易过度拟合与流感无关的季节性术语,比如“高中篮球”。这个例子也证明了模型可解释性的重要性。其次,其他研究人员专注于寻找有效的“模型”,例如RF、Gradient Boosting、Multilayer Perceptron(MLP)、长短期记忆(LSTM)、变压器(TFR)等。基于深度学习的方法,例如Transformer因其准确性而受到更多关注,而它们中的大多数都因可解释性差而受苦。此外,统计模型和动态分析模型被认为是用于模拟流感感染模式的易于访问的工具,例如SI、SIS、SIR模型及其变体。然而,它们的参数会发生变化,并且参数的近似是困难的,例如基本再生数R0、人口流动性等。DEFSI将深度神经网络方法与因果模型相结合,以解决高分辨率ILI发病率预测。然而,这些模型中的大多数都严重依赖外部数据来提高准确性,例如经度和纬度以及气候信息
因此,提供一种基于深度嵌入聚类元学习的流行病预测方法,基于历史数据,针对疫情新爆发地区,利用少量初期数据,预测未来疫情发展情况是本领域技术人员亟需解决的问题。
发明内容
有鉴于此,本发明提供了一种基于深度嵌入聚类元学习的流行病预测方法;利用多个地区疫情传播的时间序列片段学习细粒度的传播模式,并可将学习到的传播模式用于新爆发疫情且仅存在少量历史数据地区的未来预测,仅需要很少的领域知识去构建元学习任务,并具有很好的可解释性;采用基于MAML的无监督元学习方法,将疾病传播模型从疫情传播稳定的地区迁移到疫情处于早期阶段的另一个地区。
为了实现上述目的,本发明采用如下技术方案:
一种基于深度嵌入聚类元学习的流行病预测方法,包括以下步骤:
S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;
S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;
S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;
S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。
优选的,所述步骤S1具体包括:
获取目标地区i长度为T的已知历史时间序列信息xi,将时间序列信息xi切分为多个长度为ω+ΔT的时间序列片段集合
Figure BDA0003766146160000031
Figure BDA0003766146160000032
其中,M为地区的数量,Ti为地区i的历史时序数据总长度,
Figure BDA0003766146160000033
为地区i在时刻t的时间序列片段,
Figure BDA0003766146160000034
为时间序列片段
Figure BDA0003766146160000035
在t时刻前的ω个数据,即历史片段部分,其与目标地区i的已知观测数据对齐,
Figure BDA0003766146160000036
为时间序列片段
Figure BDA0003766146160000037
在t时刻后的ΔT个数据,即未来片段部分,与待预测数据对齐。
优选的,所述步骤S2具体包括:
S21、分别对历史片段部分
Figure BDA0003766146160000038
和未来片段部分
Figure BDA0003766146160000039
进行标准化:
Figure BDA00037661461600000310
Figure BDA00037661461600000311
其中,
Figure BDA00037661461600000312
分别为时间序列片段
Figure BDA00037661461600000313
的历史片段部分
Figure BDA00037661461600000314
和未来片段部分
Figure BDA00037661461600000315
的均值,
Figure BDA00037661461600000316
分别为时间序列片段
Figure BDA00037661461600000317
的历史片段部分
Figure BDA00037661461600000318
和未来片段部分
Figure BDA00037661461600000319
的方差,将时间序列片段标准化到0和1之间;
S22、对于时间序列片段
Figure BDA00037661461600000320
基于CNN和RNN提取其序列局部特征和时序特征,时间序列片段
Figure BDA00037661461600000321
中的历史片段部分
Figure BDA00037661461600000322
对应已知数据的特征所在,因此时间序列的片段的嵌入表示仅从该部分特征中学习,将时间序列片段集合
Figure BDA0003766146160000041
Figure BDA0003766146160000042
投影到嵌入空间Z中,生成时间序列片段的特征集合
Figure BDA0003766146160000043
Figure BDA0003766146160000044
其中,ξ(·)为特征编码器,其由CNN和RNN两部分组成
Figure BDA0003766146160000045
为CNN特征提取操作,用于提取时间序列片段的局部特征,
Figure BDA0003766146160000046
为RNN特征提取操作,用于提取时间序列片段的时序特征,θc,θr分别为CNN模型参数和RNN模型参数。
优选的,所述步骤S3具体包括:
S31、对时间序列片段
Figure BDA0003766146160000047
进行聚类,并学习他们的嵌入,基于深度聚类模型IDEC,采用聚类损失来实现对给定输入进行聚类:
Figure BDA0003766146160000048
其中,qij表示由学生t分布测量的时间序列片段zi与聚类中心μj的相似度,pij是聚类的目标分布;
按时间序列片段特征集合
Figure BDA0003766146160000049
进行聚类,得到时间序列片段数据集合的一个划分
Figure BDA00037661461600000410
每个聚类都是多个时间序列片段特征的集合,聚类操作定义为:
Figure BDA00037661461600000411
Figure BDA00037661461600000412
其中,l为所有类别的总数,Pi为第i个聚类簇,|Pi|表示第i个聚类簇中元素的个数,z为Pi中的元素,
Figure BDA00037661461600000413
为l个类别的中心点,||·||为二范数;
S32、采样p个聚类构建元训练任务集
Figure BDA0003766146160000051
Mtrain={D1,D2,…,Dp}表示为p种传播模式,每个聚类Di分为Queryi和Supporti两部分,并对应一个预测任务
Figure BDA0003766146160000052
其中,Supporti用于任务
Figure BDA0003766146160000053
的学习适应,即用于基础学习器更新,Queryi用于更新元学习器参数;
采用最小均方误差作为预测损失:
Figure BDA0003766146160000054
其中,y为真实流行病确诊病例数,
Figure BDA0003766146160000055
为模型预测结果。
基学习器学习阶段,每个任务
Figure BDA0003766146160000056
对应一个基学习器,基于Supporti数据,基学习器计算损失
Figure BDA0003766146160000057
利用梯度下降最小化损失,找到使损失最小化的最优参数集:
Figure BDA0003766146160000058
其中,θ'i为任务i的最优参数,θ为模型初始参数,α为超参数,
Figure BDA0003766146160000059
为任务i的梯度;
元学习阶段,使用Queryi数据,基于基学习器学到的最优参数θ'i,元学习器计算相对于这些最优参数θ'i的梯度,更新随机初始化的参数θ,即元知识,使得θ调整到最佳数值,在该最佳数值状态下,应用到某地区未来疫情发展情况预测时,只需少量梯度更新,即可获得较好的预测效果:
Figure BDA00037661461600000510
其中,θ是模型初始参数,β是超参数,
Figure BDA00037661461600000511
是任务
Figure BDA00037661461600000512
在Queryi上获得的相对于参数θ'i的梯度。
优选的,所述步骤S4具体包括:
针对新的预测任务
Figure BDA0003766146160000061
将其归属到最相近时序片段聚类中,并采样获得Supporttest,基于学习到元知识θ,在Supporttest进行梯度梯度下降学习,获得适应新任务
Figure BDA0003766146160000062
的模型。
Figure BDA0003766146160000063
其中,θ'test为新任务的模型参数,θ为初始参数,即元知识,fθ为预测模型。
经由上述的技术方案可知,与现有技术相比,本发明公开提供了一种基于深度嵌入聚类元学习的流行病预测方法;利用多个地区疫情传播的时间序列片段学习细粒度的传播模式,并可将学习到的传播模式用于新爆发疫情且仅存在少量历史数据地区的未来预测,仅需要很少的领域知识去构建元学习任务,并具有很好的可解释性;采用基于MAML的无监督元学习方法,将疾病传播模型从疫情传播稳定的地区迁移到疫情处于早期阶段的另一个地区。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1附图为本发明提供的预测方法流程结构示意图。
图2附图为本发明提供的模型框架结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例公开了一种基于深度嵌入聚类元学习的流行病预测方法,包括以下步骤:
S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;
S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;
S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;
S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。
为进一步优化上述技术方案,,步骤S1具体包括:
获取目标地区i长度为T的已知历史时间序列信息xi,将时间序列信息xi切分为多个长度为ω+ΔT的时间序列片段集合
Figure BDA0003766146160000071
Figure BDA0003766146160000072
其中,M为地区的数量,Ti为地区i的历史时序数据总长度,
Figure BDA0003766146160000081
为地区i在时刻t的时间序列片段,
Figure BDA0003766146160000082
为时间序列片段
Figure BDA0003766146160000083
在t时刻前的ω个数据,即历史片段部分,其与目标地区i的已知观测数据对齐,
Figure BDA0003766146160000084
为时间序列片段
Figure BDA0003766146160000085
在t时刻后的ΔT个数据,即未来片段部分,与待预测数据对齐。
优选的,步骤S2具体包括:
S21、分别对历史片段部分
Figure BDA0003766146160000086
和未来片段部分
Figure BDA0003766146160000087
进行标准化:
Figure BDA0003766146160000088
Figure BDA0003766146160000089
其中,
Figure BDA00037661461600000810
分别为时间序列片段
Figure BDA00037661461600000811
的历史片段部分
Figure BDA00037661461600000812
和未来片段部分
Figure BDA00037661461600000813
的均值,
Figure BDA00037661461600000814
分别为时间序列片段
Figure BDA00037661461600000815
的历史片段部分
Figure BDA00037661461600000816
和未来片段部分
Figure BDA00037661461600000817
的方差,将时间序列片段标准化到0和1之间;
S22、对于时间序列片段
Figure BDA00037661461600000818
基于CNN和RNN提取其序列局部特征和时序特征,时间序列片段
Figure BDA00037661461600000819
中的历史片段部分
Figure BDA00037661461600000820
对应已知数据的特征所在,因此时间序列的片段的嵌入表示仅从该部分特征中学习,将时间序列片段集合
Figure BDA00037661461600000821
Figure BDA00037661461600000822
投影到嵌入空间Z中,生成时间序列片段的特征集合
Figure BDA00037661461600000823
Figure BDA00037661461600000824
其中,ξ(·)为特征编码器,其由CNN和RNN两部分组成
Figure BDA00037661461600000825
为CNN特征提取操作,用于提取时间序列片段的局部特征,
Figure BDA00037661461600000826
为RNN特征提取操作,用于提取时间序列片段的时序特征,θc,θr分别为CNN模型参数和RNN模型参数。
为进一步优化上述技术方案,,步骤S3具体包括:
S31、对时间序列片段
Figure BDA0003766146160000091
进行聚类,并学习他们的嵌入,基于深度聚类模型IDEC,采用聚类损失来实现对给定输入进行聚类:
Figure BDA0003766146160000092
其中,qij表示由学生t分布测量的时间序列片段zi与聚类中心μj的相似度,pij是聚类的目标分布;
按时间序列片段特征集合
Figure BDA0003766146160000093
进行聚类,得到时间序列片段数据集合的一个划分
Figure BDA0003766146160000094
每个聚类都是多个时间序列片段特征的集合,聚类操作定义为:
Figure BDA0003766146160000095
Figure BDA0003766146160000096
其中,l为所有类别的总数,Pi为第i个聚类簇,|Pi|表示第i个聚类簇中元素的个数,z为Pi中的元素,
Figure BDA0003766146160000097
为l个类别的中心点,||·||为二范数;
S32、采样p个聚类构建元训练任务集
Figure BDA0003766146160000098
Mtrain={D1,D2,…,Dp}表示为p种传播模式,每个聚类Di分为Queryi和Supporti两部分,并对应一个预测任务
Figure BDA0003766146160000099
其中,Supporti用于任务
Figure BDA00037661461600000910
的学习适应,即用于基础学习器更新,Queryi用于更新元学习器参数;
采用最小均方误差作为预测损失:
Figure BDA00037661461600000911
其中,y为真实流行病确诊病例数,
Figure BDA00037661461600000912
为模型预测结果。
基学习器学习阶段,每个任务
Figure BDA0003766146160000101
对应一个基学习器,基于Supporti数据,基学习器计算损失
Figure BDA0003766146160000102
利用梯度下降最小化损失,找到使损失最小化的最优参数集:
Figure BDA0003766146160000103
其中,θ'i为任务i的最优参数,θ为模型初始参数,α为超参数,
Figure BDA0003766146160000104
为任务i的梯度;
元学习阶段,使用Queryi数据,基于基学习器学到的最优参数θ'i,元学习器计算相对于这些最优参数θ'i的梯度,更新随机初始化的参数θ,即元知识,使得θ调整到最佳数值,在该最佳数值状态下,应用到某地区未来疫情发展情况预测时,只需少量梯度更新,即可获得较好的预测效果:
Figure BDA0003766146160000105
其中,θ是模型初始参数,β是超参数,
Figure BDA0003766146160000106
是任务
Figure BDA0003766146160000107
在Queryi上获得的相对于参数θ'i的梯度。
为进一步优化上述技术方案,步骤S4具体包括:
针对新的预测任务
Figure BDA0003766146160000108
将其归属到最相近时序片段聚类中,并采样获得Supporttest,基于学习到元知识θ,在Supporttest进行梯度梯度下降学习,获得适应新任务
Figure BDA0003766146160000109
的模型。
Figure BDA00037661461600001010
其中,θ'test为新任务的模型参数,θ为初始参数,即元知识,fθ为预测模型。
评价指标:我们采用均方根误差
Figure BDA0003766146160000111
和皮尔逊相关系数
Figure BDA0003766146160000112
作为度量。RMSE值越低越好,而PCC值越高越好。
对比方法:
–AR:标准自回归模型
–LSTM:使用LSTM单元的循环神经网络(RNN)
–TPA-LSTM:基于注意力的LSTM模型(Shih,S.Y.,Sun,F.K.,Lee,H.y.:Temporalpattern attention for multivariate time series forecasting.Machine Learning(2019))
–ST-GCN[20]:时空图神经网络
–CNNRNN-Res:一种结合CNN、RNN和残差链接进行流行病学预测的深度学习模型(Yu,B.,Yin,H.,Zhu,Z.:Spatio-temporal graph convolutional networks:A deeplearning framework for traffic forecasting.arXiv preprint arXiv:1709.04875(2017))
–SAIFlu-Net:基于自我注意的流感预测模型(Jung,S.,Moon,J.,Park,S.,Hwang,E.:Self-attention-based deep learning network forregional influenzaforecasting.IEEE JBHI(2021))
–Cola-GNN:一种结合CNN、RNN和GCN进行流行病预测的深度学习模型(Deng,S.,Wang,S.,Rangwala,H.,Wang,L.,Ning,Y.:Cola-gnn:Cross-location attention basedgraph neural networks for long-term ili prediction.In:Proc.of CIKM(2020))
不同方法在三个数据集上的RMSE和PCC性能,horizon=3,5,10,15。粗体表示每列的最佳结果,下划线表示次优。*表示结果在相应的参考文献中报告
Figure BDA0003766146160000121
我们在短期(范围<10)和长期(范围≥10)设置中评估每个模型。流感数据集如表所示。总体趋势是预测精度随着预测范围的增加而下降,因为范围越大,问题越难。不同数据集之间RMSE的巨大差异是由于数据集的规模和方差。
我们观察到我们的方法在大多数任务上都优于其他模型。我们的方法在流感预测任务中的RMSE分别比最佳基线低5.6%。在流感预测任务中,大多数基于深度学习的模型比统计模型(HA和AR)表现更好,因为它们努力处理时间序列背后的非线性特征和复杂模式。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。

Claims (5)

1.一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,包括以下步骤:
S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;
S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;
S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;
S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。
2.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S1具体包括:
获取目标地区i长度为T的已知历史时间序列信息xi,将时间序列信息xi切分为多个长度为ω+ΔT的时间序列片段集合
Figure FDA0003766146150000011
Figure FDA0003766146150000012
其中,M为地区的数量,Ti为地区i的历史时序数据总长度,
Figure FDA0003766146150000013
为地区i在时刻t的时间序列片段,
Figure FDA0003766146150000014
为时间序列片段
Figure FDA0003766146150000015
在t时刻前的ω个数据,即历史片段部分,其与目标地区i的已知观测数据对齐,
Figure FDA0003766146150000016
为时间序列片段
Figure FDA0003766146150000017
在t时刻后的ΔT个数据,即未来片段部分,与待预测数据对齐。
3.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S2具体包括:
S21、分别对历史片段部分
Figure FDA0003766146150000021
和未来片段部分
Figure FDA0003766146150000022
进行标准化:
Figure FDA0003766146150000023
Figure FDA0003766146150000024
其中,
Figure FDA0003766146150000025
分别为时间序列片段
Figure FDA0003766146150000026
的历史片段部分
Figure FDA0003766146150000027
和未来片段部分
Figure FDA0003766146150000028
的均值,
Figure FDA0003766146150000029
分别为时间序列片段
Figure FDA00037661461500000210
的历史片段部分
Figure FDA00037661461500000211
和未来片段部分
Figure FDA00037661461500000212
的方差,将时间序列片段标准化到0和1之间;
S22、对于时间序列片段
Figure FDA00037661461500000213
基于CNN和RNN提取其序列局部特征和时序特征,时间序列片段
Figure FDA00037661461500000214
中的历史片段部分
Figure FDA00037661461500000215
对应已知数据的特征所在,因此时间序列的片段的嵌入表示仅从该部分特征中学习,将时间序列片段集合
Figure FDA00037661461500000216
Figure FDA00037661461500000217
投影到嵌入空间Z中,生成时间序列片段的特征集合
Figure FDA00037661461500000218
Figure FDA00037661461500000219
其中,ξ(·)为特征编码器,其由CNN和RNN两部分组成
Figure FDA00037661461500000220
为CNN特征提取操作,用于提取时间序列片段的局部特征,
Figure FDA00037661461500000221
为RNN特征提取操作,用于提取时间序列片段的时序特征,θc,θr分别为CNN模型参数和RNN模型参数。
4.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S3具体包括:
S31、对时间序列片段
Figure FDA00037661461500000222
进行聚类,并学习他们的嵌入,基于深度聚类模型IDEC,采用聚类损失来实现对给定输入进行聚类:
Figure FDA0003766146150000031
其中,qij表示由学生t分布测量的时间序列片段zi与聚类中心μj的相似度,pij是聚类的目标分布;
按时间序列片段特征集合
Figure FDA0003766146150000032
进行聚类,得到时间序列片段数据集合的一个划分
Figure FDA0003766146150000033
每个聚类都是多个时间序列片段特征的集合,聚类操作定义为:
Figure FDA0003766146150000034
Figure FDA0003766146150000035
其中,l为所有类别的总数,Pi为第i个聚类簇,|Pi|表示第i个聚类簇中元素的个数,z为Pi中的元素,
Figure FDA0003766146150000036
为l个类别的中心点,||·||为二范数;
S32、采样p个聚类构建元训练任务集
Figure FDA0003766146150000037
Mtrain={D1,D2,…,Dp}表示为p种传播模式,每个聚类Di分为Queryi和Supporti两部分,并对应一个预测任务
Figure FDA0003766146150000038
其中,Supporti用于任务
Figure FDA0003766146150000039
的学习适应,即用于基础学习器更新,Queryi用于更新元学习器参数;
采用最小均方误差作为预测损失:
Figure FDA00037661461500000310
其中,y为真实流行病确诊病例数,
Figure FDA00037661461500000311
为模型预测结果。
基学习器学习阶段,每个任务
Figure FDA00037661461500000312
对应一个基学习器,基于Supporti数据,基学习器计算损失
Figure FDA00037661461500000313
利用梯度下降最小化损失,找到使损失最小化的最优参数集:
Figure FDA0003766146150000041
其中,θ'i为任务i的最优参数,θ为模型初始参数,α为超参数,
Figure FDA0003766146150000042
为任务i的梯度;
元学习阶段,使用Queryi数据,基于基学习器学到的最优参数θ'i,元学习器计算相对于这些最优参数θ'i的梯度,更新随机初始化的参数θ,即元知识,使得θ调整到最佳数值,在该最佳数值状态下,应用到某地区未来疫情发展情况预测时,只需少量梯度更新,即可获得较好的预测效果:
Figure FDA0003766146150000043
其中,θ是模型初始参数,β是超参数,
Figure FDA0003766146150000044
是任务
Figure FDA0003766146150000045
在Queryi上获得的相对于参数θ'i的梯度。
5.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S4具体包括:
针对新的预测任务
Figure FDA0003766146150000046
将其归属到最相近时序片段聚类中,并采样获得Supporttest,基于学习到元知识θ,在Supporttest进行梯度梯度下降学习,获得适应新任务
Figure FDA0003766146150000047
的模型。
Figure FDA0003766146150000048
其中,θ'test为新任务的模型参数,θ为初始参数,即元知识,fθ为预测模型。
CN202210887157.5A 2022-07-26 2022-07-26 一种基于深度嵌入聚类元学习的流行病预测方法 Pending CN115240871A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210887157.5A CN115240871A (zh) 2022-07-26 2022-07-26 一种基于深度嵌入聚类元学习的流行病预测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210887157.5A CN115240871A (zh) 2022-07-26 2022-07-26 一种基于深度嵌入聚类元学习的流行病预测方法

Publications (1)

Publication Number Publication Date
CN115240871A true CN115240871A (zh) 2022-10-25

Family

ID=83675157

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210887157.5A Pending CN115240871A (zh) 2022-07-26 2022-07-26 一种基于深度嵌入聚类元学习的流行病预测方法

Country Status (1)

Country Link
CN (1) CN115240871A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116011657A (zh) * 2023-01-29 2023-04-25 上海交通大学 基于微型pmu的配电网负荷预测模型优选方法、装置及系统
CN117711636A (zh) * 2023-12-15 2024-03-15 南京理工大学 基于注意力机制的张量时空图卷积的猴痘疫情预测方法

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116011657A (zh) * 2023-01-29 2023-04-25 上海交通大学 基于微型pmu的配电网负荷预测模型优选方法、装置及系统
CN116011657B (zh) * 2023-01-29 2023-06-27 上海交通大学 基于微型pmu的配电网负荷预测模型优选方法、装置及系统
CN117711636A (zh) * 2023-12-15 2024-03-15 南京理工大学 基于注意力机制的张量时空图卷积的猴痘疫情预测方法

Similar Documents

Publication Publication Date Title
CN110570651B (zh) 一种基于深度学习的路网交通态势预测方法及系统
Wu et al. Evolving RBF neural networks for rainfall prediction using hybrid particle swarm optimization and genetic algorithm
CN108629979B (zh) 一种基于历史和周边路口数据的拥堵预测算法
CN115240871A (zh) 一种基于深度嵌入聚类元学习的流行病预测方法
CN111563706A (zh) 一种基于lstm网络的多变量物流货运量预测方法
Qin et al. Simulating and Predicting of Hydrological Time Series Based on TensorFlow Deep Learning.
CN109902801A (zh) 一种基于变分推理贝叶斯神经网络的洪水集合预报方法
CN108256590B (zh) 一种基于复合元路径的相似出行者识别方法
CN106709588B (zh) 预测模型构建方法和设备以及实时预测方法和设备
CN108564790A (zh) 一种基于交通流时空相似性的城市短时交通流预测方法
Faiq et al. Prediction of energy consumption in campus buildings using long short-term memory
CN110781595B (zh) 能源使用效率pue的预测方法、装置、终端及介质
CN111862592B (zh) 一种基于rgcn的交通流预测方法
CN110084398A (zh) 一种基于企业电力大数据的行业景气自适应检测方法
CN110267206A (zh) 用户位置预测方法及装置
CN112863182A (zh) 基于迁移学习的跨模态数据预测方法
CN114065996A (zh) 基于变分自编码学习的交通流预测方法
CN105913078A (zh) 改进自适应仿射传播聚类的多模型软测量方法
CN116108984A (zh) 基于流量-poi因果关系推理的城市流量预测方法
CN107481523A (zh) 一种交通流速度预测方法及系统
CN116311921A (zh) 一种基于多空间尺度时空Transformer的交通速度预测方法
Kim et al. A daily tourism demand prediction framework based on multi-head attention CNN: The case of the foreign entrant in South Korea
CN116227716A (zh) 一种基于Stacking的多因素能源需求预测方法及系统
CN116244647A (zh) 一种无人机集群的运行状态估计方法
CN103745602A (zh) 一种基于滑窗平均的交通流量预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination