CN113962472B - 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法 - Google Patents

一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法 Download PDF

Info

Publication number
CN113962472B
CN113962472B CN202111278744.6A CN202111278744A CN113962472B CN 113962472 B CN113962472 B CN 113962472B CN 202111278744 A CN202111278744 A CN 202111278744A CN 113962472 B CN113962472 B CN 113962472B
Authority
CN
China
Prior art keywords
time
subway
passenger flow
data
gat
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202111278744.6A
Other languages
English (en)
Other versions
CN113962472A (zh
Inventor
叶智锐
邵宜昌
施晓蒙
毕辉
张宇涵
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Southeast University
Original Assignee
Southeast University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Southeast University filed Critical Southeast University
Priority to CN202111278744.6A priority Critical patent/CN113962472B/zh
Publication of CN113962472A publication Critical patent/CN113962472A/zh
Application granted granted Critical
Publication of CN113962472B publication Critical patent/CN113962472B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00Administration; Management
    • G06Q10/04Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/253Fusion techniques of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q50/00Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
    • G06Q50/10Services
    • G06Q50/26Government or public services
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Business, Economics & Management (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Human Resources & Organizations (AREA)
  • Software Systems (AREA)
  • Strategic Management (AREA)
  • Tourism & Hospitality (AREA)
  • Mathematical Physics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • Economics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Marketing (AREA)
  • General Business, Economics & Management (AREA)
  • Evolutionary Biology (AREA)
  • Development Economics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Entrepreneurship & Innovation (AREA)
  • Primary Health Care (AREA)
  • Educational Administration (AREA)
  • Game Theory and Decision Science (AREA)
  • Quality & Reliability (AREA)
  • Operations Research (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于GAT‑Seq2seq模型的时空双注意力地铁客流短时预测方法,该方法包括以下步骤:获取地铁客流数据;数据预处理以及数据集划分;依据站点地理位置及运行线路生成图结构网络;构建基于多头注意力的GAT模型,输入训练集中的特征向量,结合站点图结构计算站点的空间关联性;将GAT输出的特征向量传入基于时序注意力的Seq2seq模型,提取客流的时间相关性,利用训练集计算均方误差,调整GAT中图结构的边权重矩阵和Seq2seq模型中的循环神经网络参数;使用测试集进行预测并评估模型。本发明利用时空双注意力机制解决了现有预测模型仅从短时间维度寻找特征,导致预测结果精度低的问题。

Description

一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预 测方法
技术领域
本发明涉及一种基于GAT-Seq2seq(Graph Attention Network&Sequence tosequence)模型的时空双注意力地铁客流短时预测方法,属于地铁客流量预测技术领域。
背景技术
随着社会经济的不断发展,城市化水平不断提高,地铁保障了居民对生活出行的基本需求。但由于早晚高峰以及天气的影响,地铁客流呈现周期性波动,造成某些时段内的站内拥堵情况,加大了工作人员的管理难度,而通过对地铁客流的准确预测可以协助地铁部门对全市范围内的地铁运输能力进行合理分配调度、对大规模人群密集活动做到提前预警、对未来地铁线路及站点进行规划布局等,对整体地铁网路乃至全市交通的管理都能提供巨大的帮助。
由于地铁客流的影响因素种类多且范围广,包括天气因素(如降雨降雪、温度等)、每日不同时段及节假日、站点的空间地理位置等。传统的地铁交通客流预测的历史数据获取需要依赖大量人力物力去实地长时间调查才能得到,并且最终的预测结果仅可用于观察日常的地铁流量变化,无法应对大规模人群活动时所带来的突发情况。并且该类预测方法无法将各影响因素之间的关联性强度数值化后纳入考虑范畴,从时间角度去提取日趋势、月趋势、季度趋势和年度趋势来粗略预测客流量,不具有普适性。
随着数据挖掘以及深度学习的不断发展,卷积计算、特征提取以及误差反向传播等技术在客流预测领域逐渐成熟,可以借助计算机高效的运算能力,建立神经网络模型在历史数据中挖掘出长短期时间的变化规律,使得预测精度有所提升,但仍未能全面考虑客流在各站点分布上的空间关系,或未能合理分配各相邻站点对待预测站点的影响权重值。同时基于LSTM模型的预测方法在提取时间序列的规律时采用简单的遗忘、新增、更新等操作单元遗漏了大量时序特征,导致整个预测模型精度较低。
发明内容
本发明所要解决的技术问题是:提供一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,从空间、时间两个角度考虑注意力机制解决了现有主流预测模型未能考虑时空角度都存在注意力机制,并且仅考虑短时间的时序规律提取,导致预测结果精度低的问题。同时,在Seq2seq模型部分改进了现有的时序注意力机制,使得模型能并行处理长时期长时段范围的客流预测。
本发明为解决上述技术问题采用以下技术方案:
一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,所述方法包括如下步骤:
步骤1,采集地铁客流相关数据,包括:地铁站点基本信息、地铁站点历史客流数据、历史气象数据和节假日数据;
步骤2,对采集的地铁客流相关数据进行预处理后,将所述地铁站点历史客流数据和所述历史气象数据按时间先后排序生成时序化数据,将一天划分为若干时段,统计不同时段内的地铁进站客流量、降雨量、降雪量和温度数据,对上述统计得到的数据进行归一化后得到数据集,将所述数据集按照3:1:1的比例划分为训练集、验证集和测试集;
步骤3,将所有地铁站点作为图结构的顶点,通过地铁运行线路及设置的距离阈值判断任意两个顶点之间是否存在边,生成地铁站点图结构G,顶点之间的边权重采用Xavier方法进行初始化,并服从均匀分布;
步骤4,构建图注意力卷积神经网络模型GAT,使用步骤2中的训练集构建特征矩阵将所述特征矩阵/>作为所述网络模型GAT的输入,计算步骤3所述图结构G中各顶点之间的注意力系数,再加权求和得到考虑邻域站点影响后的特征矩阵/>
步骤5,构建基于时间注意力的时序到时序模型Seq2seq,将步骤4得到的特征矩阵输入所述时序到时序模型Seq2seq,经过编码层与解码层处理后,计算地铁进站客流量预测值与实际值之间的均方根误差,根据均方根误差和训练集对所述时序到时序模型Seq2seq中循环神经网络单元的超参数进行调整,并利用验证集对超参数进一步优化,从而得到训练好的GAT-Seq2seq模型;
步骤6,将测试集输入所述训练好的GAT-Seq2seq模型,预测未来各个时段下地铁站点的客流量,并根据预测结果对所述训练好的GAT-Seq2seq模型进行评估。
作为本发明的一种优选方案,所述步骤1中,地铁站点基本信息包括:地铁站点的经纬度地理坐标和地铁运行线路数据;地铁站点历史客流数据包括:站点名称或站点序号、刷卡进站时间和进站客流量;历史气象数据包括:是否降雨、是否降雪、降雨量、降雪量和温度;节假日数据包括:是否星期几和是否节假日,其中,是否星期几采用热独编码转换成七个参数,即:是否星期一、是否星期二、是否星期三、是否星期四、是否星期五、是否星期六和是否星期日。
作为本发明的一种优选方案,所述步骤2的具体过程如下:
步骤21,采用整体剔除法删除地铁客流相关数据中地铁站点名称异常或缺失的数据,采用前向填充替换法处理地铁刷卡进站时间、进站客流量、降雨量、降雪量和温度中存在缺失或异常的数据;
步骤22,对地铁站点历史客流数据按照时间先后顺序进行排序,设置时段长度l为1小时,则一天划分为24个时段,分别统计每个时段内的地铁进站客流量、降雨量和降雪量;
步骤23,根据采集的温度数据,计算每个时段内的平均温度,每个时段内的最高温度、最低温度分别取该时段内所有温度的最大值、最小值;
步骤24,对地铁刷卡进站客流量、降雨量、降雪量、平均温度、最高温度和最低温度进行归一化处理后得到数据集;
步骤25,将数据集按照3:1:1的比例划分为训练集、验证集和测试集。
作为本发明的一种优选方案,所述步骤3的具体过程如下:
步骤31,将所有地铁站点作为图结构的顶点,已知各个地铁站点的经纬度坐标,对于任意一个顶点A,寻得以顶点A为圆心,半径为2.5Km的圆形范围,若某地铁站点同时满足以下两个条件:(1)该地铁站点在圆形范围内,(2)该地铁站点与顶点A之间存在地铁运行可达线路或该地铁站点与顶点A之间的实际步行时间小于10分钟;则认为该地铁站点与顶点A之间存在边,否则不存在;
步骤32,对所有顶点重复步骤31,得到地铁站点图结构G。
作为本发明的一种优选方案,所述步骤4的具体过程如下:
步骤41,图注意力卷积神经网络模型GAT的输入是一个四维特征矩阵 其中,D为训练集数据的时间总跨度,T为一天中的时段个数,N为地铁站点个数,F为特征向量维数,则特征矩阵/>为:
其中,表示在第d天的第t个时段下,编号为n的地铁站点的F个影响客流变化的特征向量,d∈[1,D],t∈[1,T],n∈[1,N],D=1096,T=24,N=83,F=16;
步骤42,对于图结构的顶点si,计算顶点si与邻居顶点sj之间的相关系数用于评估si与sj之间的相关性,/>计算公式如下:
其中,sj表示与si存在边的邻居顶点,表示与si存在边的邻居顶点集合,W表示边上权重值矩阵,||操作表示横向拼接操作,a(·)表示将高维特征映射到一个实数的变换函数,/>分别表示在第d天的第t个时段下,编号为i、j的地铁站点的F个影响客流变化的特征向量;
步骤43,将相关系数归一化得到注意力系数/>具体计算公式如下:
其中,LeakyReLU(·)为归一化激活函数,sk表示与si存在边的邻居顶点;
步骤44,根据注意力系数将特征向量/>加权求和得到新的特征向量/>具体计算公式如下:
其中,σ(·)为计算函数;
步骤45,考虑M组不同的边上权重值矩阵W和a(·)函数来生成不同的再将横向拼接,最后求得平均值,具体计算公式如下:
其中,表示使用第m组权重W和a(·)函数计算得到的注意力系数,Wm表示第m组边上权重值矩阵;
步骤46,对时间总跨度D内的每个时段的特征向量都进行步骤42-45得到新的特征向量/>将新的特征向量/>按步骤41中/>的结构拼接生成GAT的输出矩阵/>
作为本发明的一种优选方案,所述步骤5的具体过程如下:
步骤51,构建编码器,编码器采用循环神经网络单元,将步骤4得到的特征矩阵的每一天的数据都平铺展开,记作X,X={x1,x2,x3,…,xT},共有D个X向量,将循环神经网络单元记为p,t时刻的隐藏层状态变量ht只和t-1时刻的隐藏层状态变量ht-1与t时刻的输入xt有关,公式表达如下:
ht=p(xt,ht-1)
构建D个并行编码器,即D天中对每一天都构建包含T个时段的编码器,用于生成统一的上下文向量c,则上下文向量c由编码器所有的隐藏层状态变量变换得到:
其中,q(·)为变换函数;
步骤52,考虑注意力机制构建解码器,解码器采用循环神经网络单元,记为g;解码器在t′时刻的隐藏层状态变量st′只与t′-1时刻的隐藏层状态变量st′-1、t′时刻的上下文向量ct′、t′-1时刻的输出变量yt′-1有关,公式表达如下:
st′=g(yt′-1,ct′,st′-1)
将解码器在t′时刻的上下文向量ct′分为两部分计算,前半部分只考虑最后一天的24个时段的时序信息,后半部分考虑之前所有日期下相同预测时段的信息,具体计算公式如下:
其中,表示编码器中第D天t时刻的隐藏层状态变量,/>表示编码器中第d天t′时刻的隐藏层状态变量,/>表示解码器中t′时刻的状态与编码器中第D天t时刻状态的时间注意力系数,/>表示解码器中t′时刻的状态与编码器中第d天t′时刻状态的时间注意力系数,时间注意力系数/>计算公式如下:
其中,表示解码器t′时刻的状态与编码器第D天t时刻状态的相关系数;/>表示解码器t′时刻的状态与编码器第d天t′时刻状态的相关系数;
通过上述公式计算得到解码器中t′时刻对应的所有上下文向量ct′,用于计算预测数据yt′
yt′=σ(yt′-1,ct′,st′);
步骤53,使用均方根误差RMSE作为模型的损失函数,根据训练集使用Adam SGD优化器对GAT-Seq2seq模型中的相关参数进行调整,并通过验证集进行超参数优化,从而得到训练好的GAT-Seq2seq模型。
作为本发明的一种优选方案,所述步骤6的具体过程如下:
步骤61,将训练好的GAT-Seq2seq模型迁移给测试集,确定预测步长l为1小时,利用训练好的GAT-Seq2seq模型预测未来时段的地铁站点客流量;
步骤62,计算预测值与实际值的均方根误差RMSE,用于模型性能评估,其中均方根误差RMSE公式如下:
其中,R为预测样本总个数,yr为实际客流量,为预测客流量。
本发明采用以上技术方案与现有技术相比,具有以下技术效果:
1、本发明基于图卷积神经网络考虑了地铁站点在空间地理位置上对客流的相互影响,并结合注意力机制实现了对邻近站点的影响度权重分配。
2、本发明在采用时序到时序模型的基础上,添加了时间注意力机制,并创新性地采用多编码器融合的方式,在预测某一时段的客流时,结合考虑数据集中每一天在该时段下的历史特征向量,生成可供解码器不同隐藏状态单元使用的上下文向量。
3、本发明融合了时空特征,并在空间与时间上都采用了注意力机制,尤其改进了Seq2seq模型中的时序注意力机制,使得该模型能并行处理超长时段范围的客流预测,以上改进与应用使得预测模型的精度有明显提升,对于地铁站点日常规律性客流变化预测精准;对于大规格出行突发事件,该模型也能够结合邻近站点的客流数据做出短时准确预测。
附图说明
图1是本发明方法的整体流程逻辑图;
图2是本发明方法的模型详细运算结构图;
图3是本发明方法的数据预处理流程示意图;
图4是本发明方法的Seq2seq模型编码器的结构图;
图5是本发明模型预测结果与实际数据的对比图。
具体实施方式
下面详细描述本发明的实施方式,所述实施方式的示例在附图中示出。下面通过参考附图描述的实施方式是示例性的,仅用于解释本发明,而不能解释为对本发明的限制。
本发明的核心思想是提取数据内部特征的空间关联和时间关联,结合后用于预测未来时段的客流量。在空间关联特征提取上,使用了目前先进的图卷积神经网络使用多头注意力机制通过站点地理位置和路线来确定邻近站点之间的相互影响。在时间关联特征提取上,使用了改进后的时序到时序模型在时间层面提取影响特征,即前几个时段的客流数据会影响后时段的客流预测结果。
本发明实施例中所用数据为国外某市地铁系统2018年6月-2021年6月的各个站点的闸机客流记录数据,以及该市气象部门公开的全市各区域范围的气象、空气质量数据。地铁站点位置数据及运行线路数据来源于该市交通运输部门公开。最后自制2018-2021年度的节假日信息,包含该国法定节假日、工作日和非工作日信息。
如图1所示为本发明基于GAT-Seq2seq模型的时空双注意力地铁短时客流预测方法的总体流程图,图2为本发明各部分的详细运算结构图,下面结合图1与图2对本发明的方法作进一步说明,本发明方法包括以下步骤:
步骤S1:采集地铁客流相关数据,数据包括:地铁站点基本信息、站点历史客流数据、历史气象数据和节假日数据。其中站点基本信息包括:各地铁站点的经纬度地理坐标和地铁运行线路数据。历史客流刷卡数据包括:站点名称(或站点序号),闸机刷卡时间和该时间的进站人数。气象数据包括:是否降雨、是否降雪、降雨量、降雪量和温度。节假日数据包括:是否星期几、是否法定节假日。其中是否星期几参数采用热独编码,即转换成七个参数:是否星期一、是否星期二、是否星期三、是否星期四、是否星期五、是否星期六和是否星期日。
步骤S2:数据预处理。将地铁刷卡数据和气象数据按时间排序生成时序化数据,并统计出不同时段内的地铁客流总和、降雨量、降雪量、平均温度、最高温度和最低温度,所述时段为每1小时,数据归一化并生成输入特征矩阵H。将数据集按照3:1:1比例划分为训练集、验证集和测试集。
结合图3对数据预处理与划分的具体流程进一步解释:
步骤S21:对于历史客流数据中站点名称异常或缺失的数据,采用整体剔除法删除整行相关数据。对刷卡时间、进站人数、降雨量、降雪量、温度数据的缺失值和异常值,均采用前向填充替换法。
步骤S22:对历史客流刷卡数据按照时间顺序正序排列,设置预测时段长度l(单位:小时)为1小时,计算时段个数T,分别统计24个时间跨度内的进站人数、降雨量和降雪量。
步骤S23:对气象数据中温度数据按步骤S22中相同的时间跨度T=24计算时间段内的平均值。最高温度数据选取该时段内最大值,最低温度选取该时段内最小值。预处理后的数据如表1所示。
表1预处理后的数据结构
步骤S24:对预处理后的进站人数、降雨量、降雪量、平均温度、最高温度和最低温度数据进行归一化处理,将数据映射到[0,1]之间,归一化运算公式为:
步骤S25:将归一化后的数据集按照3:1:1比例划分为训练集、验证集和测试集。
步骤S3:生成地铁站点图结构G,各站点皆为图结构的顶点,通过地铁运行线路及设置距离阈值来决定任意两站点之间是否存在边。该图结构是学习邻近站点之间空间影响力度的基础结构,顶点之间的边上权重W采用Xavier方法初始化,并服从均匀分布。
其中生成站点图结构的具体流程为:
步骤S31:所有地铁站点均为图结构的顶点,已知各个站点的经纬度坐标,选择任意一个顶点,设置半径为2.5Km寻得一个圆形范围,其余站点若满足以下两个条件,(1)站点在圆圈范围内,(2)与顶点存在地铁运行可达线路或与顶点之间的实际步行时间小于10分钟,则认为满足条件的站点与该顶点存在边。
步骤S32:对所有顶点重复上述步骤,可得到地铁站点完整的图结构。
步骤S4:构建图注意力卷积神经网络模型,用于学习各地铁站点之间的地理空间关联性,使用步骤S3中的图结构G和步骤S2中的训练集输入特征矩阵计算步骤S31中各顶点之间的注意力系数α,再加权求和得到考虑邻域站点影响后的特征矩阵/>
其中图注意力卷积层的具体运行流程包括:
步骤S41:图注意力卷积网络的输入是一个四维特征矩阵 其中D为训练集数据的时间跨度(即总天数D=1096天),T为步骤S22中定义的每日时段个数(T=24),N为地铁站点个数(N=83),F为特征向量维数(历史进站人数、是否降雨、是否降雪、降雨量、降雪量、平均温度、最高温度、最低温度、热独编码后的星期几参数、是否法定节假日,共16个参数,即F=16)。输入特征矩阵为:
其中,表示为第d天第t个时段下,编号为n的地铁站点的F(F=16)个影响客流变化的特征向量,以/>为例,/> 表示为第1天第一个时段下(凌晨12点至早1点),编号为1的地铁站点的16个影响客流变化的特征向量。
步骤S42:计算相关系数对于顶点si,逐个计算顶点与邻居点之间的相关系数,用于评估si与sj之间的相关性,计算公式如下:
其中sj表示为与si存在边的邻居顶点,由步骤S31中图结构得到。式中存在一个共享参数W以线性映射的方式对顶点特征向量进行增维,||操作是将变更后的si和sj的特征向量进行横向拼接操作,最后通过a(·)函数操作将高维特征映射到一个实数。
步骤S43:相关系数归一化得到注意力系数具体计算公式如下:
其中归一化激活函数为LeakyReLU(·),是考虑了该函数的特性能保留节点si自身的特性,而不是仅受到邻居节点们的影响,具体完整公式如下:
步骤S44:根据计算好的注意力系数,将特征向量加权求和输出新的特征向量(融合了邻域特征信息),公式如下:
步骤S45:采用多头注意力机制增强准确性,即考虑M组不同的权重W和a(·)函数来生成不同的再将/>横向拼接,最后求得平均值,具体公式如下:
步骤S46:对数据时间总跨度(总天数D=1096)内的每个时段(总时段T=24)的特征向量经上述步骤输出新的特征向量/>按步骤S41中/>的结构拼接生成GAT层输出矩阵/>
步骤S5:构建基于时间注意力的时序到时序模型,学习时间层面的数据相关性。使用步骤S4得到的特征矩阵输入Seq2seq模型中,经过编码层与解码层处理后,计算客流量预测值与实际值之间的均方误差RMSE,用于调整Seq2seq模型中循环神经网络单元的参数。
其中时序到时序模型的具体流程包括:
步骤S51:构建编码器模块,编码器采用循环神经网络,编码器用于信息的序列编码,可以将任意长度的序列信息编码到一个上下文向量c中。下面结合图4对编码器模型具体说明,输入向量为步骤S46中的特征矩阵的每一天的数据都平铺展开,记作X,X={x1,x2,x3,…,xT},共有D=1096(总天数)个X向量。将循环神经网络单元记为p,t时刻的隐藏层变量ht只和t-1时刻的隐藏层状态变量ht-1与t时刻的输入xt有关,可用如下公式表示:
ht=p(xt,ht-1)
在目前的现有技术下,编码器输出的上下文向量c通常是由编码器所有的隐藏层变量ht变换所得(如下方公式所示)。其中q(·)即为变换函数,例如求和函数或求平均函数等。
c=q(h1,h2,…,hT)
如果将所有的输入向量展开后传入到模型中,将会生成D*T个隐藏层状态,从而导致模型计算量大幅增加,因此本发明提出一种并行运算方式以提高效率。由于共有D=1096个X输入向量,则可以构建D个并行编码器,即1096天中对每一天都构建包含24个时段的编码器,用于生成统一的上下文向量c,下方等式含义为上下文向量c可由编码器所有的隐藏状态变量变换得到:
步骤S52:考虑注意力机制构建解码器模块,注意力的目的是让解码器在不同时刻可以使用不同的上下文向量ct′,由于引入了注意力机制,则步骤S51中的上下文向量c需要按新的理论重新计算并分解成不同的ct′。解码器同样采用循环神经网络,记作g;解码器在t′时刻的隐藏层状态变量st′只与t′-1时刻的隐藏层状态变量st′-1、t′时刻的上下文向量ct′、t′-1时刻的输出变量yt′-1有关,公式表达如下:
st′=g(yt′-1,ct′,st′-1)
目前现有的注意力机制中,解码器在t′时刻的上下文向量ct′通常可以由编码器中的t个隐藏层状态变量ht按时间注意力系数αt′t加权求和表示,公式如下:
但该类注意力机制在本实例下时间跨度过长,且重复保留无关时序信息(保留了1096天24个时段的特征信息),故本发明创新性地将上下文向量ct′分成两部分计算(公式如下),前半部分是只考虑最后一天的24个时段的时序信息,后半部分是考虑之前所有日期下相同预测时段的信息。例如在预测21年6月2日8时至9时的客流时,上下文向量ct′需考虑21年6月2日24个时段的状态变量,以及18年6月1日至21年6月1日中所有8时至9时的状态变量,该方法可以在保留关键时序信息的同时,减少无用数据对上下文向量的影响。
其中时间注意力系数计算公式如下,需注意该注意力机制考虑筛选了D个编码器的所有状态ht
/>
通过上述公式可计算得到解码器中t′时刻对应的所有上下文向量ct′,用于计算预测数据yt′,其中σ(·)为计算函数:
yt′=σ(yt′-1,ct′,st′)
步骤S53:使用均方根误差RMSE得到作为模型的损失函数原理,根据训练集数据使用Adam SGD优化器对预测模型中的相关参数(步骤S3中权重矩阵W,步骤S43中的变换函数aT)进行调整,并在训练后期通过验证集数据进行超参数优化,提高整体模型的泛化能力,使得最终模型能在未知数据中获得更精准的结果。
步骤S6:将测试集数据输入训练好的模型中,预测各个时段下地铁站点的客流量,并根据预测结果进行模型评估。
测试并评估模型的具体流程包括:
步骤S61:将步骤S4与步骤S5中训练好的GAT-Seq2seq模型迁移给测试集数据,确定合适的预测步长l为1小时,利用模型及数据预测未来时段的站点客流量。
步骤S62:计算预测值与实际值的均方根误差RMSE,用于模型性能准确度评估,其中均方根误差RMSE公式如下:
其中,R为预测样本总个数,yr为实际客流量,为预测客流量,最终预测结果对比如图5所示。
以上实施例仅为说明本发明的技术思想,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明保护范围之内。

Claims (5)

1.一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,其特征在于,所述方法包括如下步骤:
步骤1,采集地铁客流相关数据,包括:地铁站点基本信息、地铁站点历史客流数据、历史气象数据和节假日数据;
步骤2,对采集的地铁客流相关数据进行预处理后,将所述地铁站点历史客流数据和所述历史气象数据按时间先后排序生成时序化数据,将一天划分为若干时段,统计不同时段内的地铁进站客流量、降雨量、降雪量和温度数据,对上述统计得到的数据进行归一化后得到数据集,将所述数据集按照3:1:1的比例划分为训练集、验证集和测试集;
步骤3,将所有地铁站点作为图结构的顶点,通过地铁运行线路及设置的距离阈值判断任意两个顶点之间是否存在边,生成地铁站点图结构G,顶点之间的边权重采用Xavier方法进行初始化,并服从均匀分布;
步骤4,构建图注意力卷积神经网络模型GAT,使用步骤2中的训练集构建特征矩阵将所述特征矩阵/>作为所述网络模型GAT的输入,计算步骤3所述图结构G中各顶点之间的注意力系数,再加权求和得到考虑邻域站点影响后的特征矩阵/>具体过程如下:
步骤41,图注意力卷积神经网络模型GAT的输入是一个四维特征矩阵 其中,D为训练集数据的时间总跨度,T为一天中的时段个数,N为地铁站点个数,F为特征向量维数,则特征矩阵/>为:
其中,表示在第d天的第t个时段下,编号为n的地铁站点的F个影响客流变化的特征向量,d∈[1,D],t∈[1,T],n∈[1,N],D=1096,T=24,N=83,F=16;
步骤42,对于图结构的顶点si,计算顶点si与邻居顶点sj之间的相关系数用于评估si与sj之间的相关性,/>计算公式如下:
其中,sj表示与si存在边的邻居顶点,表示与si存在边的邻居顶点集合,W表示边上权重值矩阵,||操作表示横向拼接操作,a(·)表示将高维特征映射到一个实数的变换函数,分别表示在第d天的第t个时段下,编号为i、j的地铁站点的F个影响客流变化的特征向量;
步骤43,将相关系数归一化得到注意力系数/>具体计算公式如下:
其中,LeakyReLU(·)为归一化激活函数,sk表示与si存在边的邻居顶点;
步骤44,根据注意力系数将特征向量/>加权求和得到新的特征向量/>具体计算公式如下:
其中,σ(·)为计算函数;
步骤45,考虑M组不同的边上权重值矩阵W和a(·)函数来生成不同的再将/>横向拼接,最后求得平均值,具体计算公式如下:
其中,表示使用第m组权重W和a(·)函数计算得到的注意力系数,Wm表示第m组边上权重值矩阵;
步骤46,对时间总跨度D内的每个时段的特征向量都进行步骤42-45得到新的特征向量/>将新的特征向量/>按步骤41中/>的结构拼接生成GAT的输出矩阵/>
步骤5,构建基于时间注意力的时序到时序模型Seq2seq,将步骤4得到的特征矩阵输入所述时序到时序模型Seq2seq,经过编码层与解码层处理后,计算地铁进站客流量预测值与实际值之间的均方根误差,根据均方根误差和训练集对所述时序到时序模型Seq2seq中循环神经网络单元的超参数进行调整,并利用验证集对超参数进一步优化,从而得到训练好的GAT-Seq2seq模型;具体过程如下:
步骤51,构建编码器,编码器采用循环神经网络单元,将步骤4得到的特征矩阵的每一天的数据都平铺展开,记作X,X={x1,x2,x3,…,xT},共有D个X向量,将循环神经网络单元记为p,t时刻的隐藏层状态变量ht只和t-1时刻的隐藏层状态变量ht-1与t时刻的输入xt有关,公式表达如下:
ht=p(xt,ht-1)
构建D个并行编码器,即D天中对每一天都构建包含T个时段的编码器,用于生成统一的上下文向量c,则上下文向量c由编码器所有的隐藏层状态变量变换得到:
其中,q(·)为变换函数;
步骤52,考虑注意力机制构建解码器,解码器采用循环神经网络单元,记为g;解码器在t′时刻的隐藏层状态变量st′只与t′-1时刻的隐藏层状态变量st′-1、t′时刻的上下文向量ct′、t′-1时刻的输出变量yt′-1有关,公式表达如下:
st′=g(yt′-1,ct′,st′-1)
将解码器在t′时刻的上下文向量ct′分为两部分计算,前半部分只考虑最后一天的24个时段的时序信息,后半部分考虑之前所有日期下相同预测时段的信息,具体计算公式如下:
其中,表示编码器中第D天t时刻的隐藏层状态变量,/>表示编码器中第d天t′时刻的隐藏层状态变量,/>表示解码器中t′时刻的状态与编码器中第D天t时刻状态的时间注意力系数,/>表示解码器中t′时刻的状态与编码器中第d天t′时刻状态的时间注意力系数,时间注意力系数/>计算公式如下:
其中,表示解码器t′时刻的状态与编码器第D天t时刻状态的相关系数;/>表示解码器t′时刻的状态与编码器第d天t′时刻状态的相关系数;
通过上述公式计算得到解码器中t′时刻对应的所有上下文向量ct′,用于计算预测数据yt′
yt′=σ(yt′-1,ct′,st′);
步骤53,使用均方根误差RMSE作为模型的损失函数,根据训练集使用Adam SGD优化器对GAT-Seq2seq模型中的相关参数进行调整,并通过验证集进行超参数优化,从而得到训练好的GAT-Seq2seq模型;
步骤6,将测试集输入所述训练好的GAT-Seq2seq模型,预测未来各个时段下地铁站点的客流量,并根据预测结果对所述训练好的GAT-Seq2seq模型进行评估。
2.根据权利要求1所述的基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,其特征在于,所述步骤1中,地铁站点基本信息包括:地铁站点的经纬度地理坐标和地铁运行线路数据;地铁站点历史客流数据包括:站点名称或站点序号、刷卡进站时间和进站客流量;历史气象数据包括:是否降雨、是否降雪、降雨量、降雪量和温度;节假日数据包括:是否星期几和是否节假日,其中,是否星期几采用热独编码转换成七个参数,即:是否星期一、是否星期二、是否星期三、是否星期四、是否星期五、是否星期六和是否星期日。
3.根据权利要求1所述的基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,其特征在于,所述步骤2的具体过程如下:
步骤21,采用整体剔除法删除地铁客流相关数据中地铁站点名称异常或缺失的数据,采用前向填充替换法处理地铁刷卡进站时间、进站客流量、降雨量、降雪量和温度中存在缺失或异常的数据;
步骤22,对地铁站点历史客流数据按照时间先后顺序进行排序,设置时段长度l为1小时,则一天划分为24个时段,分别统计每个时段内的地铁进站客流量、降雨量和降雪量;
步骤23,根据采集的温度数据,计算每个时段内的平均温度,每个时段内的最高温度、最低温度分别取该时段内所有温度的最大值、最小值;
步骤24,对地铁刷卡进站客流量、降雨量、降雪量、平均温度、最高温度和最低温度进行归一化处理后得到数据集;
步骤25,将数据集按照3:1:1的比例划分为训练集、验证集和测试集。
4.根据权利要求1所述的基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,其特征在于,所述步骤3的具体过程如下:
步骤31,将所有地铁站点作为图结构的顶点,已知各个地铁站点的经纬度坐标,对于任意一个顶点A,寻得以顶点A为圆心,半径为2.5Km的圆形范围,若某地铁站点同时满足以下两个条件:(1)该地铁站点在圆形范围内,(2)该地铁站点与顶点A之间存在地铁运行可达线路或该地铁站点与顶点A之间的实际步行时间小于10分钟;则认为该地铁站点与顶点A之间存在边,否则不存在;
步骤32,对所有顶点重复步骤31,得到地铁站点图结构G。
5.根据权利要求1所述的基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法,其特征在于,所述步骤6的具体过程如下:
步骤61,将训练好的GAT-Seq2seq模型迁移给测试集,确定预测步长l为1小时,利用训练好的GAT-Seq2seq模型预测未来时段的地铁站点客流量;
步骤62,计算预测值与实际值的均方根误差RMSE,用于模型性能评估,其中均方根误差RMSE公式如下:
其中,R为预测样本总个数,yr为实际客流量,为预测客流量。
CN202111278744.6A 2021-10-31 2021-10-31 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法 Active CN113962472B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111278744.6A CN113962472B (zh) 2021-10-31 2021-10-31 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111278744.6A CN113962472B (zh) 2021-10-31 2021-10-31 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法

Publications (2)

Publication Number Publication Date
CN113962472A CN113962472A (zh) 2022-01-21
CN113962472B true CN113962472B (zh) 2024-04-19

Family

ID=79468546

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111278744.6A Active CN113962472B (zh) 2021-10-31 2021-10-31 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法

Country Status (1)

Country Link
CN (1) CN113962472B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114004429B (zh) * 2022-01-04 2022-04-08 苏州元澄科技股份有限公司 一种用于构建数字城市的数据处理方法和系统
CN114819253A (zh) * 2022-03-02 2022-07-29 湖北大学 城市人群聚集热点区域预测方法、系统、介质及终端
CN117272848B (zh) * 2023-11-22 2024-02-02 上海随申行智慧交通科技有限公司 基于时空影响的地铁客流预测方法及模型训练方法
CN117591919B (zh) * 2024-01-17 2024-03-26 北京工业大学 客流预测方法、装置、电子设备和存储介质

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020215798A1 (zh) * 2019-04-22 2020-10-29 中国科学院深圳先进技术研究院 一种地铁站内区域客流估计方法、系统及电子设备
CN111860785A (zh) * 2020-07-24 2020-10-30 中山大学 基于注意力机制循环神经网络的时间序列预测方法及系统
CN112801355A (zh) * 2021-01-20 2021-05-14 南京航空航天大学 基于长短期时空数据多图融合时空注意力的数据预测方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11158048B2 (en) * 2019-06-28 2021-10-26 Shandong University Of Science And Technology CT lymph node detection system based on spatial-temporal recurrent attention mechanism

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020215798A1 (zh) * 2019-04-22 2020-10-29 中国科学院深圳先进技术研究院 一种地铁站内区域客流估计方法、系统及电子设备
CN111860785A (zh) * 2020-07-24 2020-10-30 中山大学 基于注意力机制循环神经网络的时间序列预测方法及系统
CN112801355A (zh) * 2021-01-20 2021-05-14 南京航空航天大学 基于长短期时空数据多图融合时空注意力的数据预测方法

Also Published As

Publication number Publication date
CN113962472A (zh) 2022-01-21

Similar Documents

Publication Publication Date Title
CN113962472B (zh) 一种基于GAT-Seq2seq模型的时空双注意力地铁客流短时预测方法
Liu et al. Contextualized spatial–temporal network for taxi origin-destination demand prediction
CN110570651B (zh) 一种基于深度学习的路网交通态势预测方法及系统
CN108564790B (zh) 一种基于交通流时空相似性的城市短时交通流预测方法
CN110956807B (zh) 基于多源数据与滑动窗口组合的高速公路流量预测方法
CN110555544B (zh) 一种基于gps导航数据的交通需求估计方法
CN112863182B (zh) 基于迁移学习的跨模态数据预测方法
CN111242395B (zh) 用于od数据的预测模型构建方法及装置
CN110163449B (zh) 一种基于主动时空图卷积的机动车排污监测节点部署方法
CN115440032A (zh) 一种长短期公共交通流量预测方法
Zhang et al. Battery maintenance of pedelec sharing system: Big data based usage prediction and replenishment scheduling
CN115412857A (zh) 一种居民出行信息预测方法
Zhao et al. Celltrademap: Delineating trade areas for urban commercial districts with cellular networks
CN116796904A (zh) 一种轨道交通新线客流预测方法、系统、电子设备及介质
CN116913088A (zh) 一种用于高速公路的智能流量预测方法
CN115204477A (zh) 一种上下文感知图递归网络的自行车流量预测方法
CN113537569B (zh) 一种基于权重堆叠决策树的短时公交客流预测方法及系统
CN116976702A (zh) 基于大场景gis轻量化引擎的城市数字孪生平台及方法
CN117494034A (zh) 基于交通拥堵指数和多源数据融合的空气质量预测方法
CN117037461A (zh) 一种基于多权重图三维卷积的路网交通拥堵预测方法
CN114330871A (zh) 一种通过公交运营数据结合gps数据预测城市路况的方法
AT&T
Niu et al. Highway Temporal‐Spatial Traffic Flow Performance Estimation by Using Gantry Toll Collection Samples: A Deep Learning Method
CN114139773A (zh) 一种基于时空图卷积网络的公共交通流量预测方法
CN114139984A (zh) 基于流量与事故协同感知的城市交通事故风险预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant