CN115965163A - 基于时空图生成对抗损失的轨道交通短时客流预测方法 - Google Patents
基于时空图生成对抗损失的轨道交通短时客流预测方法 Download PDFInfo
- Publication number
- CN115965163A CN115965163A CN202310115130.9A CN202310115130A CN115965163A CN 115965163 A CN115965163 A CN 115965163A CN 202310115130 A CN202310115130 A CN 202310115130A CN 115965163 A CN115965163 A CN 115965163A
- Authority
- CN
- China
- Prior art keywords
- time
- network
- passenger flow
- discriminator
- historical
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 37
- 238000010586 diagram Methods 0.000 title claims abstract description 29
- 230000006870 function Effects 0.000 claims abstract description 37
- 239000011159 matrix material Substances 0.000 claims abstract description 27
- 238000012549 training Methods 0.000 claims abstract description 21
- 238000013136 deep learning model Methods 0.000 claims abstract description 15
- 238000013507 mapping Methods 0.000 claims abstract description 11
- 230000002123 temporal effect Effects 0.000 claims description 22
- 238000013527 convolutional neural network Methods 0.000 claims description 16
- 238000003860 storage Methods 0.000 claims description 11
- 238000004590 computer program Methods 0.000 claims description 8
- 238000005457 optimization Methods 0.000 claims description 8
- 235000008694 Humulus lupulus Nutrition 0.000 claims description 5
- 230000004913 activation Effects 0.000 claims description 4
- 230000002708 enhancing effect Effects 0.000 claims description 4
- 230000003442 weekly effect Effects 0.000 description 8
- 238000012545 processing Methods 0.000 description 7
- 230000003042 antagnostic effect Effects 0.000 description 5
- 238000013135 deep learning Methods 0.000 description 5
- 230000006872 improvement Effects 0.000 description 5
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 4
- 238000002474 experimental method Methods 0.000 description 4
- 238000002679 ablation Methods 0.000 description 3
- 238000011156 evaluation Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000004458 analytical method Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000009792 diffusion process Methods 0.000 description 2
- 239000000835 fiber Substances 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 230000001902 propagating effect Effects 0.000 description 2
- 230000003068 static effect Effects 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- RYGMFSIKBFXOCR-UHFFFAOYSA-N Copper Chemical compound [Cu] RYGMFSIKBFXOCR-UHFFFAOYSA-N 0.000 description 1
- 238000013256 Gubra-Amylin NASH model Methods 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 230000008485 antagonism Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 229910052802 copper Inorganic materials 0.000 description 1
- 239000010949 copper Substances 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000014509 gene expression Effects 0.000 description 1
- 230000001788 irregular Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000006116 polymerization reaction Methods 0.000 description 1
- 238000007637 random forest analysis Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012706 support-vector machine Methods 0.000 description 1
Images
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Data Exchanges In Wide-Area Networks (AREA)
Abstract
本发明公开了一种基于时空图生成对抗损失的轨道交通短时客流预测方法。该方法包括:将城市轨道交通网络定义为图;将客流量作为城市轨道交通网络中地铁站的属性特征,其中每个地铁站包含进站客流矩阵和出站客流矩阵作为特征值;通过训练深度学习模型学习映射函数,以利用在历史时刻各地铁站的特征值来预测未来时刻的客流,其中深度学习模型包括生成器网络和判别器网络,该生成器网络以多模式历史客流数据作为输入,用于捕获每个模式中客流的时间和空间相关性,并合并多模式的输出传递至所述判别器网络;所述判别器网络用于区分真实的客流量和生成器网络所预测的客流量。本发明在时空尺度的约束下提升了预测精度和效率,并降低了内存占用。
Description
技术领域
本发明涉及交通客流预测技术领域,更具体地,涉及一种基于时空图生成对抗损失的轨道交通短时客流预测方法。
背景技术
短期客流预测可以用于更好地管理城市轨道交通系统,并且对于乘客来说,准确预测城市轨道交通各站的进站量可以有效地帮助他们规划出行路线,从而节省出行时间和成本。目前,一些新兴的深度学习模型能够提高短期客流预测精度。然而,城市轨道交通系统中存在许多复杂的时空依赖关系,而现有模型只以真实值和预测值之间的绝对误差作为优化目标,没有考虑预测值的空间和时间约束。此外,现有的一些预测模型尽管引入了复杂的神经网络层来提高准确性,但忽略了它们的训练效率和内存占用,难以很好的应用于现实世界。
短期客流预测作为一个热门且非常重要的研究课题,已经经历了很长的历史。最传统的模型是基于统计的预测模型,如自回归模型(AR),移动平均模型(MA),自回归综合移动平均(ARIMA)等。但是,预测精度有待提高,一般不能满足实时性的要求。对于基于机器学习的模型,如支持向量机,随机森林等,这种预测方法的主要局限性是它们通常适用于单个站点,而不是整个城市轨道交通网络。另一个限制是它们不能有效地捕捉站点之间复杂的非线性时空相关性。基于深度学习的模型已被证明在从大量数据中捕捉复杂相关性方面优于传统方法,显著提高了网络范围内的预测准确性。然而,大多数用于客流预测的深度学习方法仅采用绝对误差函数(如L1损失和L2损失)作为目标函数来优化其可学习参数。与简单的结构化数据不同,客流数据中存在多重空间和时间约束。图1的城市轨道交通网络为例,其中左图对应商业区,右图对应居住区,可以发现车站1和车站2在高峰时段的模式相似,因为它们是同一地区的相邻车站,因此预测需要满足此类空间约束。另一方面,对于每个车站,客流的变化需要满足一定的时间约束,例如在高峰期的持续增长。此外,最近的预测模型越来越复杂,只考虑模型精度的提高而忽略了模型的训练效率和内存占用,从现实应用的角度来看是不切实际的。
经分析,目前的短时客流预测模型主要存在以下几个问题:机器学习预测方法通常适用于单个站点,而不是整个城市轨道交通网络,并且不能有效地捕捉站点之间复杂的非线性时空相关性;最近的预测模型越来越复杂,只考虑模型精度的提高而忽略了模型的训练效率和内存占用,从现实应用的角度来看是不切实际的;现有模型通常只以真实值和预测值之间的绝对误差作为优化目标,没有考虑预测值的空间和时间约束。
综上,为了更好地运营城市轨道交通系统,迫切需要提供高精度、高效率的短期客流预测方案。
发明内容
本发明的目的是克服上述现有技术的缺陷,提供一种基于时空图生成对抗损失的轨道交通短时客流预测方法。该方法包括以下步骤:
将城市轨道交通网络定义为图G=(V,E,A),其中V表示地铁站数量,地铁站与地铁站之间有E条边,A表示邻接矩阵,用于指示两个地铁站是否相邻;
将客流量作为城市轨道交通网络中地铁站的属性特征,其中每个地铁站包含进站客流矩阵和出站客流矩阵作为特征值;
通过训练深度学习模型学习映射函数f,以利用在时间t时所有地铁站的所有特征值来预测t+1时刻的客流量Yt+1,表示为:
Yt+1=f(Xt)
其中,Xt表示在时间t时所有地铁站的所有特征值,所述深度学习模型包括生成器网络和判别器网络,该生成器网络以多模式历史客流数据作为输入,用于捕获每个模式中客流的时间和空间相关性,并合并多模式的输出传递至所述判别器网络;所述判别器网络用于区分真实的客流量和生成器网络所预测的客流量;映射函数f基于经训练的生成器网络获取。
与现有技术相比,本发明的优点在于,提出了一种基于深度学习的时空图生成对抗网络(STG-GAN)模型来预测城市轨道交通网络的短期客流,该模型结合门控时间卷积网络和权重共享图卷积网络来从多个时态模式中捕获时空依赖。此外,在判别器网络中,提出了一个空间判别器和一个时间判别器来鉴别预测值和真实值的时空特征,从而可以在时空尺度的约束下更好地考虑预测的一致性,并且提升了预测精度和效率,并降低了内存占用。
通过以下参照附图对本发明的示例性实施例的详细描述,本发明的其它特征及其优点将会变得清楚。
附图说明
被结合在说明书中并构成说明书的一部分的附图示出了本发明的实施例,并且连同其说明一起用于解释本发明的原理。
图1是现有技术中预测的空间和时间约束示意图;
图2是根据本发明一个实施例的基于时空图生成对抗损失的轨道交通短时客流预测方法的流程图;
图3是根据本发明一个实施例的STG-GAN模型示意图;
图4是根据本发明一个实施例的门控时间卷积网络的结构示意图;
图5是根据本发明一个实施例的图卷积网络的结构示意图;
图6是根据本发明一个实施例的空间判别器的结构示意图
图7是根据本发明一个实施例的时间判别器的结构示意图;
图8是根据本发明一个实施例的两个数据集上的时间消耗和GPU占用率对比图;
附图中,Generator Network-生成器网络;Discriminator Network-判别器网络;Spatial Discriminator-空间判别器;Temporial Discriminator-时间判别器;FC-全连接;MLP-多层感知器;Predictive Loss-预测损失;Input layer-输入层;Hidden layer-隐藏层;Output layer-输出层。
具体实施方式
现在将参照附图来详细描述本发明的各种示例性实施例。应注意到:除非另外具体说明,否则在这些实施例中阐述的部件和步骤的相对布置、数字表达式和数值不限制本发明的范围。
以下对至少一个示例性实施例的描述实际上仅仅是说明性的,决不作为对本发明及其应用或使用的任何限制。
对于相关领域普通技术人员已知的技术、方法和设备可能不作详细讨论,但在适当情况下,所述技术、方法和设备应当被视为说明书的一部分。
在这里示出和讨论的所有例子中,任何具体值应被解释为仅仅是示例性的,而不是作为限制。因此,示例性实施例的其它例子可以具有不同的值。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步讨论。
参见图2所示,所提供的基于时空图生成对抗损失的轨道交通短时客流预测方法总体上包括:步骤S210,将城市轨道交通网络定义为图G=(V,E,A),其中V表示地铁站数量,地铁站与地铁站之间有E条边,A表示邻接矩阵,用于指示两个地铁站是否相邻;步骤S220,将客流量作为城市轨道交通网络中地铁站的属性特征,其中每个地铁站包含进站客流矩阵和出站客流矩阵作为特征值;步骤S230,通过训练深度学习模型学习映射函数,以利用在时间t时所有地铁站的所有特征值来预测t+1时刻的客流量。其中所设计的深度学习模型包括生成器网络和判别器网络,该生成器网络以多模式历史客流数据作为输入,用于捕获每个模式中客流的时间和空间相关性,并合并多模式的输出传递至所述判别器网络;所述判别器网络用于区分真实的客流量和生成器网络所预测的客流量;映射函数f基于经训练的生成器网络获取。
在下文,将首先对所要解决的科学问题进行详细定义。然后,介绍所提出的深度学习模型框架。进而,对该框架所使用的图卷积神经网络以及生成对抗网络进行详细说明。最后,使用北京2016年和2018年连续五周的地铁卡数据进行试验,并与其他一些经典的短时客流预测模型进行对比。
一、问题定义
本发明的目的是同时预测城市轨道交通全网所有车站某一时段的进站客流量。将城市轨道交通网络定义为一张图G=(V,E,A),图中有V个地铁车站,车站与车站之间有E条边,代表该网络的邻接矩阵,只有0,1两种元素,0代表两个车站不相邻,1代表两个车站相邻。
将客流量作为城市轨道交通系统中地铁站的属性特征,表示为即第i个车站在第t个时间段的第k个特征,n为车站数量,m为时间步,k为特征矩阵的数量。在一个实施例中,每个车站有两个特征矩阵:进站客流矩阵和出站客流矩阵。因此,k=2。
其中,Xt表示在时间t时所有车站的所有特征值。
其中,X表示在时间T内所有车站的所有特征值。
Yt+1∈Rn*1*k表示未来t+1时刻所有车站的客流量。
因此,本发明所要解决的问题是利用过去所有车站m个时间间隔的客流量来预测t+1时刻的客流量Yt+1。
如公式(4)所示,其中f为本发明提出的深度学习框架将要学习的映射函数。
Yt+1=f(Xt)(4)
二、时空图生成对抗网络(STG-GAN)概述
图3显示了STG-GAN的框架。STG-GAN主要包含两部分,即生成器网络和判别器网络。生成器网络(G)包括门控时间卷积网络(TCN)、权重共享GCN(图卷积网络)以及全连接层。判别器网络(D)包括空间判别器和时间判别器。在该框架中,历史客流数据首先被聚合成三种模式:实时模式、日模式和周模式。然后,客流数据被输入到门控TCN和权重共享GCN的堆叠模块中,这些模块用来捕获每个模式中客流的时间和空间相关性。然后,利用全连接层合并三种模式的输出,并将其作为GAN中判别器网络的输入。判别器网络试图区分哪个是真实的客流,哪个是生成器网络所生成的客流。空间判别器可由混合跳图乘法组件和全连接层组成,旨在增强预测的空间约束。时间判别器可由门控TCN和全连接层组成,旨在增强预测的时间约束。G和D迭代训练,生成的数据和真实数据非常相似,以至于D无法区分。在训练结束之后,生成器网络可以用作短时客流预测模型。
三、生成器网络
在一个实施例中,生成器网络包括门控TCN和共享GCN框架。在介绍门控TCN和共享GCN框架之前,详细描述了生成器网络的输入。
1)不同的流入模式
由于进站客流和出站客流受相邻时段、日时段和周时段的客流影响,因此在一个实施例中利用三种模式下的客流:实时模式、日模式和周模式。假设时间粒度为ti,时间步为ts,当前时间段为t,预测t+1时刻的客流量。实时模式:Xreal=(Xt-ts+1,Xt-ts+2,…,Xt),临近预测时间段的历史时间序列。日模式:预测时间段前一天同一时间的历史时间序列。周模式:预测时间段前一周同一时间的历史时间序列。这三种模式分别输入到生成器网络的门控TCN中,使得门控TCN模块可以滤除原始客流数据中的噪声。
2)门控TCN
考虑到基于RNN和基于自我注意的时间学习方法都可能带来相对较大的计算量,它们很难部署在优先考虑效率的工业场景中。因此,在一个实施例中,采用了一个更轻量级的,基于CNN的门控TCN。门控TCN主要用于过滤原始客流数据中的噪声,如图4所示。来自单一模式的客流数据由两个分支分别处理。一个是由sigmoid门控激活函数的一维CNN,另一个是由双曲正切门控激活函数(Tanh)的一维CNN。它们的输出由Hadmard乘积运算合并,Hadmard乘积运算被定义为:
3)GCN
图5显示了GCN模型的一般结构。假设在静态图中有N个具有M维特征的节点。拓扑结构和节点特征可以分别用邻接矩阵A和特征矩阵Z表示。Kipf和Welling提出的典型GCN模型以特征矩阵Z为输入,对其进行局部一阶近似的图形卷积运算。无偏差的GCN函数可以定义如下:
4)时空聚合和预测
为了在生成器网络中更好地保留多尺度时空信息,首先利用一个全连接层来聚集来自不同模式的隐藏信息,然后利用跳过连接操作来总结来自每一层的隐藏信息。在生成器网络的底部,有另一个全连接层来生成预测,其定义为:
四、判别器网络
对于判别器网络,创新性地提出了一个空间判别器和一个时间判别器来增强预测的空间和时间约束。
1)空间判别器
在城市轨道交通网络的客流中存在一些空间限制,因为彼此相邻的车站具有相对较强的空间相关性,如图1所示。因此,在空间判别器中,使用混合跳乘法将空间信息嵌入预测值和实际值Y,然后是全连接层,如图6所示。从而空间判别器可以从空间角度区分预测值和实际值。例如,用来训练空间判别器的损失函数定义如下:
2)时间判别器
时间判别器的结构如图7所示。
除了空间限制外,城市轨道交通网络的客流也有一些时间限制,特别是时间模式的连续性,如图1所示。因此,还需要增强预测的时间约束。与空间判别器不同,通过将历史客流序列与预测值和实际值连接起来,以便更好地鉴别时间特征。然后,时间序列由门控TCN处理,以过滤原始序列中的一些明显的序列,并进一步输入到全连接层中,然后输出。用来训练时间判别器的损失函数定义如下:
其中fT(·)表示门控TCN和多层感知器的映射函数,θT表示时间判别器中的可学习参数。
五、STG-GAN的优化
在训练阶段,交替优化生成器网络和判别器网络的参数。对于判别器网络,优化目标函数是等式(10)和(11)的组合,其定义如下:
对于生成器网络,优化目标函数由绝对损耗部分、空间和时间对抗性损耗部分组成,定义为:
六、实验结果
在实验验证中,利用两个真实世界的数据集测试了所提出的模型,并将其与一些基线模型进行了比较,进而从多个角度对实验结果进行了分析。
1)数据集
实验使用了北京地铁的两个数据集,参见表1。第一个数据集是MetroBJ2016,包括从2016年2月29日到2016年4月3日连续5周工作日的北京地铁刷卡数据,共计17条线路和276个车站(不包括机场线)。第二个数据集是MetroBJ2018,包括2018年10月8日至2018年11月11日连续5周工作日的北京地铁刷卡数据,共计22条线路和308个车站(不包括机场线)。为了提取进站客流与出站客流,将进站时间和出站时间转换为分钟,从0到1080,代表05:00到23:00。然后,从数据中提取出15分钟时间粒度的进站客流时间序列。MetroBJ2016的客流维度为276*1800,MetroBJ2018的客流维度为308*1800。此外,为所有地铁站提供一个唯一的车站编号。
在实验中,由于数据量有限,并且考虑到预测模型中的实时模式、每日模式和每周模式三种模式,前四周的数据用于训练和验证所提出的模型,其余数据用于测试模型。
表1数据集的描述
2)模型配置与评价指标
(1)模型参数设置
STG-GAN的相同参数被应用于两个数据集。具体地,生成器网络的层数被设置为3。生成器网络中的并行TCN和权重共享GCN的隐藏维数都设置为64。在空间判别器中,图乘法的最大跳数设置为2,MLP的隐藏单元分别为[1024,512]和[512,256]。在时间判别器中,TCN的层数设置为2,MLP的隐藏单元与空间判别器中的相同。因为使用WGAN进行对抗训练,所以在判别器的输出层中没有激活功能。选择80%的数据作为训练数据和验证数据,剩下的20%作为测试数据。批量大小为32。优化器是RMSprop,学习率为0.0001。
(2)评价标准
使用的评价指标为均方根误差Root Mean Square Error(RMSE)、平均绝对误差Mean Absolute Error(MAE)和加权平均绝对百分比误差weighted Mean AbsolutePercentage Error(WMAPE),参见公式(14)到(16)。
3)模型对比
进一步地,将所提出的Graph-GAN模型将和以下几个模型进行对比,以证明模型的有效性。
ARIMA:自回归移动平均模型。将相同的ARIMA模型应用于所有城市轨道交通站点的客流。模型中的三个参数,即滞后阶数、差异程度、移动平均阶数,经过微调后分别设为9、1、0。
LSTM:长短时记忆网络模型。使用的LSTM模型具有两个LSTM层和三个全连接层。优化器是Adam,学习率为0.001,batch size为32。输入是10个时间步的进站客流序列,输出是下一个时间步的进站客流序列。
CNN:卷积神经网络模型。使用的CNN模型具有两个CNN层和三个全连接层。内核大小为3*3,优化器是Adam,学习率为0.001。batch size是32,输入和输出与LSTM模型相同。
ST-ResNet:时空残差网络模型。使用其中的三个分支,不包含天气数据分支。
ConvLSTM:卷积循环神经网络模型。使用的ConvLSTM模型具有两个ConvLSTM层和三个全连接层。其他参数与CNN模型相同。
STGCN:这是第一个将图卷积网络(GCN)和时间卷积网络(TCN)结合起来用于道路网络交通预测的模型。GCN的跳数被设置为2,而TCN的内核大小为3。优化器是Adam,学习率为0.001。批量大小为32。
DCRNN:该模型集成了扩散GCN和门控递归单元网络(GRU)用于时空学习。GCN的扩散跳数设置为2,GRU的层数设置为1,GRU的隐藏维数设置为128。优化器是Adam,学习率为0.0001,批处理大小为32。
Graph WaveNet(GWN):该模型在STGCN的基础上深化了TCN扩张网络。该模型的层数设置为6层。扩张的比例是[1,2,1,2,1,2]。其他参数与STGCN模型相同。
MVGCN:与ST-ResNet类似,该模型也涉及三个时态分支。不同的是,对于不规则的空间范围,它用GCN代替了CNN。不使用天气数据和事故数据。
Graph WaveNet(GWN):该模型提出了一种新的动态图学习模块,以自适应地表征时空相关性。其他架构的设置类似于GWN模型。
4)实验结果分析
(1)全网预测性能
表2显示了STG-GAN与其他基线方法在MetroBJ2016和MetroBJ2018数据集上的性能对比。由表2可知,深度学习模型显著优于基于数理统计的模型。在MetroBJ2016和MetroBJ2018数据集中,ARIMA是性能最差的模型,两个数据集的RMSE分别为81.4562和69.4250,MSE分别为42.8006和33.9540,这是因为ARIMA无法捕捉客流的综合非线性特征。
进一步将本发明的模型与其他的深度学习方法进行比较,如LSTM和CNN。在深度学习模型中,LSTM不会对空间相关性进行建模,而CNN不会捕捉时间相关性。因此,这两个模型的性能比STG-GAN模型差。ST-Resnet、ConvLSTM、STGCN、DCRNN、GWN、MVGCN和DMSTGCN是考虑空间相关性和时间相关性的客流预测方法。与LSTM和CNN相比,这些方法都实现了准确性的提高。然而,这些最先进的深度学习模型的性能比本发明提出的模型差。与这些基准相比,模型STG-GAN实现了5%-8%的提升。
对于GAN而言,随着对抗性训练过程的进行,生成器模型可以显著提高其获得更准确预测结果的能力。因此,本发明设计了用于时空依赖学习的具有平行TCN和共享GCN的堆叠层的生成器。然后设计了一个空间判别器和时间判别器,分别从空间和时间尺度上增强生成器的预测能力。这是该模型能够取得最好效果的主要原因。由表2可知,与传统和深度学习模型相比,STG-GAN实现了最准确的预测,两个数据集的最低RMSE分别为34.1493和32.3884,MAE为19.5467和15.6214,WMAPE分别为7.376%和8.004%。
表2不同模型的性能对比
(2)消融研究
对MetroBJ2016和MetroBJ2018数据集进行了消融研究,以评估模型中关键部分的有效性。参加表3,将完整的STG-GAN与以下变体进行了比较:1)w/o recent,删除了生成器网络中的最近分支。2)w/o daily,删除生成器网络中的每日分支。3)w/o weekly,删除生成器网络中的每周分支。4)不使用GCN,删除所有共享的GCN层5)不使用TCN,从模型中删除所有并行时间卷积模块。6)w/o SD,其在训练阶段移除空间判别器。7)w/o TD,其在训练阶段删除时间判别器。8)w/o GAN,删除对抗性学习的整个部分。从表III中的总体实验结果,可以发现STG-GAN优于所有消融变体。
实验证明了所提出模型优于w/o recent、w/o weekly和w/o daily,这表明该模型的每个分支在时空学习中都起着重要的作用。其中,最近分支是两个数据集上最重要的部分。与w/o最近的结果相比,STG-GAN在MetroBJ2016上的MAE和WMAPE分别提高了6.02%和6.54%。同时,在MetroBJ2018上,MAE和WMAPE也分别提高了6.76%和8.25%。
此外,本发明所提出模型优于w/o GCN和w/o TCN,这验证了生成器网络中空间和时间学习模块的有效性。与w/o GCN的结果相比,STG-GAN在MetroBJ2016上的MAE和WMAPE分别提高了5.45%和4.55%。同时,在MetroBJ2018上,MAE和WMAPE也分别提高了4.87%和6.52%。与w/o TCN的结果相比,STG-GAN在MetroBJ2016上的MAE和WMAPE分别提高了6.33%和5.62%。同时,在MetroBJ2018上,MAE和WMAPE也分别提高了6.61%和9.22%。
进一步地,还选择了w/o SD、w/o TD和w/o GAN来验证模型中对抗学习的有效性。与w/o GAN的结果相比,该模型在MetroBJ2016上的MAE和WMAPE分别提高了4.77%和5.05%。同时,在MetroBJ2018上,MAE和WMAPE也分别提高了4.23%和5.37%。这一显著改进表明,对抗性学习机制能够提高城市轨道交通客流预测的准确性。
表3不同消融变体性能比较
(3)效率和占用率的比较
在行业场景中,训练效率和占用率是衡量模型效果的两个重要指标。训练阶段较低的时间消耗和GPU占用率可以使模型更容易部署。为了进一步评估模型,将所提出的模型与五个最佳基线模型进行比较。
时间消耗和GPU占用率的对比结果如图8所示,其中图8(a)对应时间消耗,8(b)对应GPU的占用率。可以发现,所提出的模型的时间消耗和GPU占用率明显低于DCRNN,但略高于其他基线模型。由于所提出模型仅略微增加了训练时间和GPU占用率,但获得了比几个最佳基线模型更好的预测性能,这进一步证明了所提出模型在工业部署中的优越性。
从表3中的总体实验结果,可以发现STG-GAN优于所有消融变体。从时间消耗和GPU占用率的对比结果分别可以发现,所提出的模型的时间消耗和GPU占用率明显低于DCRNN,但略高于其他基线模型。由于所提出模型仅略微增加了训练时间和GPU占用率,但获得了比几个最佳基线模型更好的预测性能,这进一步证明了所提出模型在工业部署中的优越性。
综上所述,本发明具有以下优势:
1)本发明提供时空图生成对抗网络(STG-GAN)模型,在模型结构方面,本发明通过门控TCN和权重共享GCN的堆叠模块来捕获每个模式中客流的时间和空间相关性,并且在生成对抗网络中,空间判别器由混合跳图乘法组件和全连接层组成,旨在增强预测的空间约束,时间判别器由门控TCN和全连接层组成,旨在增强预测的时间约束;在模型侧重点方面,本发明通过对生成对抗网络具体结构的设计,分别定义了空间判别器和时间判别器的损失函数,并且生成器的优化目标由绝对损耗部分、空间和时间对抗性损耗部分组成,模型更加侧重对损失的精细把握。
2)所设计的生成器网络可以有效地从城市轨道交通网络中提取时空和拓扑信息并生成预测,能够用来捕获结构时空依赖性并以相对较小的计算量生成预测值
3)提出了包含一个空间判别器和一个时间判别器的判别器网络来鉴别来自生成网络的预测,并引入时空对抗损失来指导生成网络的学习,这优于以前仅基于数值损失来优化网络的方法。所设计的判别器网络能够增强预测的空间和时间约束。
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
计算机可读存储介质可以是可以保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以是但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、静态随机存取存储器(SRAM)、便携式压缩盘只读存储器(CD-ROM)、数字多功能盘(DVD)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。这里所使用的计算机可读存储介质不被解释为瞬时信号本身,诸如无线电波或者其他自由传播的电磁波、通过波导或其他传输媒介传播的电磁波(例如,通过光纤电缆的光脉冲)、或者通过电线传输的电信号。
这里所描述的计算机可读程序指令可以从计算机可读存储介质下载到各个计算/处理设备,或者通过网络、例如因特网、局域网、广域网和/或无线网下载到外部计算机或外部存储设备。网络可以包括铜传输电缆、光纤传输、无线传输、路由器、防火墙、交换机、网关计算机和/或边缘服务器。每个计算/处理设备中的网络适配卡或者网络接口从网络接收计算机可读程序指令,并转发该计算机可读程序指令,以供存储在各个计算/处理设备中的计算机可读存储介质中。
用于执行本发明操作的计算机程序指令可以是汇编指令、指令集架构(ISA)指令、机器指令、机器相关指令、微代码、固件指令、状态设置数据、或者以一种或多种编程语言的任意组合编写的源代码或目标代码,所述编程语言包括面向对象的编程语言—诸如Smalltalk、C++、Python等,以及常规的过程式编程语言—诸如“C”语言或类似的编程语言。计算机可读程序指令可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络—包括局域网(LAN)或广域网(WAN)—连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。在一些实施例中,通过利用计算机可读程序指令的状态信息来个性化定制电子电路,例如可编程逻辑电路、现场可编程门阵列(FPGA)或可编程逻辑阵列(PLA),该电子电路可以执行计算机可读程序指令,从而实现本发明的各个方面。
这里参照根据本发明实施例的方法、装置(系统)和计算机程序产品的流程图和/或框图描述了本发明的各个方面。应当理解,流程图和/或框图的每个方框以及流程图和/或框图中各方框的组合,都可以由计算机可读程序指令实现。
这些计算机可读程序指令可以提供给通用计算机、专用计算机或其它可编程数据处理装置的处理器,从而生产出一种机器,使得这些指令在通过计算机或其它可编程数据处理装置的处理器执行时,产生了实现流程图和/或框图中的一个或多个方框中规定的功能/动作的装置。也可以把这些计算机可读程序指令存储在计算机可读存储介质中,这些指令使得计算机、可编程数据处理装置和/或其他设备以特定方式工作,从而,存储有指令的计算机可读介质则包括一个制造品,其包括实现流程图和/或框图中的一个或多个方框中规定的功能/动作的各个方面的指令。
也可以把计算机可读程序指令加载到计算机、其它可编程数据处理装置、或其它设备上,使得在计算机、其它可编程数据处理装置或其它设备上执行一系列操作步骤,以产生计算机实现的过程,从而使得在计算机、其它可编程数据处理装置、或其它设备上执行的指令实现流程图和/或框图中的一个或多个方框中规定的功能/动作。
附图中的流程图和框图显示了根据本发明的多个实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或指令的一部分,所述模块、程序段或指令的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。对于本领域技术人员来说公知的是,通过硬件方式实现、通过软件方式实现以及通过软件和硬件结合的方式实现都是等价的。
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。本发明的范围由所附权利要求来限定。
Claims (10)
1.一种基于时空图生成对抗损失的轨道交通短时客流预测方法,包括以下步骤:
将城市轨道交通网络定义为图G=(V,E,A),其中V表示地铁站数量,地铁站与地铁站之间有E条边,A表示邻接矩阵,用于指示两个地铁站是否相邻;
将客流量作为城市轨道交通网络中地铁站的属性特征,其中每个地铁站包含进站客流矩阵和出站客流矩阵作为特征值;
通过训练深度学习模型学习映射函数f,以利用在时间t时所有地铁站的所有特征值来预测t+1时刻的客流量Yt+1,表示为:
Yt+1=f(Xt)
其中,Xt表示在时间t时所有地铁站的所有特征值,所述深度学习模型包括生成器网络和判别器网络,该生成器网络以多模式历史客流数据作为输入,用于捕获每个模式中客流的时间和空间相关性,并合并多模式的输出传递至所述判别器网络;所述判别器网络用于区分真实的客流量和生成器网络所预测的客流量;映射函数f基于经训练的生成器网络获取。
2.根据权利要求1是所述的方法,其特征在于,所述生成器网络包括门控时间卷积网络、权重共享图卷积神经网络以及全连接层,所述多模式历史客流数据被输入到门控时间卷积网络和权重共享图卷积神经网络的堆叠模块中,以捕获每种模式中客流的时间和空间相关性;所述全连接层用于合并多种模式的输出,并将其作为判别器网络的输入。
3.根据权利要求2所述的方法,其特征在于,所述判别器网络包括空间判别器和时间判别器,其中,所述空间判别器由混合跳图乘法组件和全连接层组成,用于增强预测的空间约束;所述时间判别器由门控时间卷积网络和全连接层组成,用于增强预测的时间约束。
4.根据权利要求2所述的方法,其特征在于,所述门控时间卷积网络包含第一分支和第二分支,其中第一分支是包含sigmoid门控激活函数的一维卷积神经网络,第而分支是包含双曲正切门控激活函数的一维卷积神经网络,第一分支和第二分支的输出由Hadmard乘积运算合并。
8.根据权利要求1所述的方法,其特征在于,所述多模式历史客流数据包括:实时模式历史客流数据,用于指示临近预测时间段的历史时间序列;日模式历史客流数据,用于指示预测时间段前一天同一时间的历史时间序列;周模式历史客流数据,用于指示预测时间段前一周同一时间的历史时间序列。
9.一种计算机可读存储介质,其上存储有计算机程序,其中,该计算机程序被处理器执行时实现根据权利要求1至8中任一项所述方法的步骤。
10.一种计算机设备,包括存储器和处理器,在所述存储器上存储有能够在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至8中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310115130.9A CN115965163A (zh) | 2023-02-07 | 2023-02-07 | 基于时空图生成对抗损失的轨道交通短时客流预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310115130.9A CN115965163A (zh) | 2023-02-07 | 2023-02-07 | 基于时空图生成对抗损失的轨道交通短时客流预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115965163A true CN115965163A (zh) | 2023-04-14 |
Family
ID=87361571
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310115130.9A Pending CN115965163A (zh) | 2023-02-07 | 2023-02-07 | 基于时空图生成对抗损失的轨道交通短时客流预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115965163A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117272848A (zh) * | 2023-11-22 | 2023-12-22 | 上海随申行智慧交通科技有限公司 | 基于时空影响的地铁客流预测方法及模型训练方法 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114626585A (zh) * | 2022-02-28 | 2022-06-14 | 北京交通大学 | 一种基于生成对抗网络的城市轨道交通短时客流预测方法 |
-
2023
- 2023-02-07 CN CN202310115130.9A patent/CN115965163A/zh active Pending
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114626585A (zh) * | 2022-02-28 | 2022-06-14 | 北京交通大学 | 一种基于生成对抗网络的城市轨道交通短时客流预测方法 |
Non-Patent Citations (2)
Title |
---|
JINLEI ZHANG ETAL: "STG-GAN: A spatiotemporal graph generative adversarial networks for short-term passenger flow prediction in urban rail transit systems", 《HTTPS://ARXIV.ORG/ABS/2202.06727》, pages 2 - 8 * |
柳毅主编: "《机器学习与Python实践》", 西安电子科技大学出版社, pages: 202 - 203 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117272848A (zh) * | 2023-11-22 | 2023-12-22 | 上海随申行智慧交通科技有限公司 | 基于时空影响的地铁客流预测方法及模型训练方法 |
CN117272848B (zh) * | 2023-11-22 | 2024-02-02 | 上海随申行智慧交通科技有限公司 | 基于时空影响的地铁客流预测方法及模型训练方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Tekouabou et al. | Improving parking availability prediction in smart cities with IoT and ensemble-based model | |
Chu et al. | Deep multi-scale convolutional LSTM network for travel demand and origin-destination predictions | |
CN110223517B (zh) | 基于时空相关性的短时交通流量预测方法 | |
CN112001548B (zh) | 一种基于深度学习的od客流预测方法 | |
CN111667092A (zh) | 基于图卷积神经网络的轨道交通短时客流预测方法和系统 | |
CN112150207A (zh) | 基于时空上下文注意力网络的网约车订单需求预测方法 | |
CN113762595B (zh) | 通行时间预测模型训练方法、通行时间预测方法及设备 | |
Zhu et al. | Multistep flow prediction on car-sharing systems: A multi-graph convolutional neural network with attention mechanism | |
CN115204478A (zh) | 一种结合城市兴趣点和时空因果关系的公共交通流量预测方法 | |
CN117392846A (zh) | 一种时空自适应图学习融合动态图卷积的交通流预测方法 | |
CN115034496A (zh) | 基于GCN-Transformer的城市轨道交通节假日短时客流预测方法 | |
CN113112791A (zh) | 一种基于滑动窗口长短时记忆网络的交通流量预测方法 | |
CN116468186A (zh) | 一种航班链延误时间预测方法、电子设备及存储介质 | |
CN115146844A (zh) | 一种基于多任务学习的多模式交通短时客流协同预测方法 | |
Abdelatif et al. | Vehicular-cloud simulation framework for predicting traffic flow data | |
CN115965163A (zh) | 基于时空图生成对抗损失的轨道交通短时客流预测方法 | |
CN110443422B (zh) | 基于od吸引度的城市轨道交通od客流预测方法 | |
Li et al. | Cycle-based signal timing with traffic flow prediction for dynamic environment | |
CN115481784A (zh) | 一种基于改进组合模型的交通流量预测方法及应用 | |
Tao et al. | A delay-based deep learning approach for urban traffic volume prediction | |
CN116432808A (zh) | 一种基于深度学习的网络级多模式交通短时客流预测方法 | |
Ayman et al. | Neural architecture and feature search for predicting the ridership of public transportation routes | |
Hsieh et al. | Recommending taxi routes with an advance reservation–a multi-criteria route planner | |
CN117593877A (zh) | 一种基于集成图卷积神经网络的短时交通流预测方法 | |
CN114638395A (zh) | 一种车辆出行时间预测的方法、系统、计算机设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20230414 |
|
RJ01 | Rejection of invention patent application after publication |