CN114626585B - 一种基于生成对抗网络的城市轨道交通短时客流预测方法 - Google Patents

一种基于生成对抗网络的城市轨道交通短时客流预测方法 Download PDF

Info

Publication number
CN114626585B
CN114626585B CN202210188660.1A CN202210188660A CN114626585B CN 114626585 B CN114626585 B CN 114626585B CN 202210188660 A CN202210188660 A CN 202210188660A CN 114626585 B CN114626585 B CN 114626585B
Authority
CN
China
Prior art keywords
passenger flow
time
network
model
generator
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210188660.1A
Other languages
English (en)
Other versions
CN114626585A (zh
Inventor
张金雷
杨立兴
李华
戚建国
阴佳腾
陈瑶
高自友
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Jiaotong University
Original Assignee
Beijing Jiaotong University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Jiaotong University filed Critical Beijing Jiaotong University
Priority to CN202210188660.1A priority Critical patent/CN114626585B/zh
Publication of CN114626585A publication Critical patent/CN114626585A/zh
Application granted granted Critical
Publication of CN114626585B publication Critical patent/CN114626585B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00Administration; Management
    • G06Q10/04Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q50/00Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
    • G06Q50/10Services
    • G06Q50/26Government or public services
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Business, Economics & Management (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Biomedical Technology (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Strategic Management (AREA)
  • Human Resources & Organizations (AREA)
  • Tourism & Hospitality (AREA)
  • Economics (AREA)
  • Development Economics (AREA)
  • Marketing (AREA)
  • General Business, Economics & Management (AREA)
  • Primary Health Care (AREA)
  • Educational Administration (AREA)
  • Game Theory and Decision Science (AREA)
  • Entrepreneurship & Innovation (AREA)
  • Operations Research (AREA)
  • Quality & Reliability (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于生成对抗网络的城市轨道交通短时客流预测方法。该方法包括:针对城市轨道交通网络构建图结构,并将客流量作为车站的属性特征;基于所述图结构获取多个模式下反映历史客流信息的时间序列数据,所述多个模式根据与客流预测时刻的不同时间间隔进行划分;将所述多个模式下的时间序列数据输入至图卷积神经网络获取各模式下客流的时空相关性;将所述图卷积神经网络输出的不同模式数据进行合并后输入到生成器,以生成城市轨道交通网络中目标车站在后续时刻的交通客流信息,其中所述生成器利用设定的目标函数通过训练生成对抗网络获得。本发明提高了客流预测精度,并降低了模型的复杂度。

Description

一种基于生成对抗网络的城市轨道交通短时客流预测方法
技术领域
本发明涉及交通客流预测技术领域,更具体地,涉及一种基于生成对抗网络的城市轨道交通短时客流预测方法。
背景技术
随着城市轨道交通的快速发展,轨道交通客流量越来越大,导致了城市轨道交通系统的严重拥堵。缓解交通拥堵的手段之一是准确预测城市轨道交通的短期客流,并采取相应的管理措施。因此,短期客流预测对城市轨道交通系统的管理具有重要意义。
短期客流预测是一个非常重要的研究课题,其历史悠久。最常规的模型是基于统计的预测模型,如自回归模型(AR)、移动平均模型(MA)、自回归综合移动平均模型(ARIMA)等。例如,Williams等人提出了一个季节性的ARIMA来预测城市高速公路的交通流。也有研究者将ARIMA模型扩展到空间维度,并将其进一步应用于轨迹预测。但基于统计的模型的预测精度有待提高,通常不能满足实时性要求。
近年来,基于机器学习和深度学习的预测模型在城市轨道交通系统中得到了广泛的应用。对于基于机器学习的模型,如支持向量机(SVM)、随机森林等。例如,有研究者提出了一种结合SARIMA(季节性差分自回归滑动平均模型)和SVM的混合模型,以解决短期交通预测的周期性、非线性、不确定性和复杂性等问题。又如,建立两个支持向量回归模型,专门用于预测季节性交通流量。再如,采用支持向量机模型来捕捉短期客流量的周期性和非线性特征。但是目前基于机器学习的预测方法存在两个局限性,第一个局限性是这些方法通常应用于单个站点,而不是整个城市轨道交通网络。另一个局限性是不能有效捕捉站与站之间复杂的非线性时空相关性。在现有技术方案中,有研究者提出了一种改进的LSTM来实现短期交通流预测。然而,一方面,这些深度学习算法只考虑客流的时间相关性,而忽略了地铁站之间的拓扑结构。另一方面,目前的预测短时客流预测模型越来越复杂,只考虑模型精度的提高,而忽略了模型的复杂性。
现有的短时客流预测模型总结起来存在以下几个问题:1)传统的基于数理统计的模型存在实时性较差,预测精度较低等问题。2)基于机器学习的模型虽然一定程度上提高了短时客流预测精度,但是在预测过程中没有考虑全网客流之间的时空特性对预测结果的影响,而且该类模型多数只针对一个或几个车站进行预测,无法做到使用一个模型对全网所有车站进行预测。3)基于深度学习的模型经历了长久的发展,能较好的考虑网络客流的时空特征,以及网络拓扑结构。但既有的深度学习模型复杂度高,训练时间过长,如何简化模型,降低模型复杂度而不影响模型预测精度,仍是当前短时客流预测领域重要的研究方向。总之,目前深度学习模型可以提高短期客流的预测精度。然而,大量模型为了提高精度而结合了不同的神经网络,使得模型结构极其复杂,难以应用于现实世界。因此,有必要在模型复杂性和预测性能之间进行权衡。
发明内容
本发明的目的是克服上述现有技术的缺陷,提供一种基于生成对抗网络的城市轨道交通短时客流预测方法。该方法包括以下步骤:
针对城市轨道交通网络构建图结构,标记为图G=(V,E,A),V表示地铁车站数目,车站与车站之间有E条边,代表邻接矩阵,用于标记车站之间是否相邻,并将客流量作为车站的属性特征;
基于所述图结构获取多个模式下反映历史客流信息的时间序列数据,所述多个模式根据与客流预测时刻的不同时间间隔进行划分;
将所述多个模式下的时间序列数据输入至图卷积神经网络获取各模式下客流的时空相关性;
将所述图卷积神经网络输出的不同模式数据进行合并后输入到生成器,以生成城市轨道交通网络中目标车站在后续时刻的交通客流信息,其中所述生成器利用设定的目标函数通过训练生成对抗网络获得。
与现有技术相比,本发明的优点在于,提出基于图卷积神经网络和生成对抗网络的深度学习框架,能够有机地将全网客流的时空特性以及网络的拓扑结构结合考虑,并利用简单的生成对抗网络进行城市轨道交通短时客流预测,进而提高了模型的预测精度,并降低了模型的复杂度。
通过以下参照附图对本发明的示例性实施例的详细描述,本发明的其它特征及其优点将会变得清楚。
附图说明
被结合在说明书中并构成说明书的一部分的附图示出了本发明的实施例,并且连同其说明一起用于解释本发明的原理。
图1是根据本发明一个实施例的基于图卷积神经网络和生成对抗网络的深度学习模型架构图;
图2是根据本发明一个实施例的图卷积神经网络模型的结构示意图;
图3是根据本发明一个实施例的生成对抗网络的框架图;
图4是根据本发明一个实施例的不同模型的性能对比示意图;
图5是根据本发明一个实施例的MetroBJ2016和MetroBJ2018中选取的三个站点的实际值和预测值对比示意图;
图6是根据本发明一个实施例的在MetroBJ2016和MetroBJ2018不同时间段的预测性能对比示意图;
附图中,Real-time pattern-实时模式;Daily pattern-日模式;Weeklypattern-周模式;Generator-生成器;Discriminator-判别器;Input layer-输入层;Hidden layer-隐藏层;Output layer-输出层;Dataset-数据集;Fully connectednetwork-全连接网络。
具体实施方式
现在将参照附图来详细描述本发明的各种示例性实施例。应注意到:除非另外具体说明,否则在这些实施例中阐述的部件和步骤的相对布置、数字表达式和数值不限制本发明的范围。
以下对至少一个示例性实施例的描述实际上仅仅是说明性的,决不作为对本发明及其应用或使用的任何限制。
对于相关领域普通技术人员已知的技术、方法和设备可能不作详细讨论,但在适当情况下,所述技术、方法和设备应当被视为说明书的一部分。
在这里示出和讨论的所有例子中,任何具体值应被解释为仅仅是示例性的,而不是作为限制。因此,示例性实施例的其它例子可以具有不同的值。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步讨论。
本发明的技术方案包括以下内容:首先,对所要解决的科学问题进行详细定义。然后,展示了所提出的深度学习框架,该深度学习框架总体上包括图卷积神经网络和生成对抗网络。进一步地,对该深度学习框架所使用的图卷积神经网络以及生成对抗网络进行详细说明。最后,使用北京2016年和2018年连续五周的地铁卡数据进行试验,并与多个现有经典的短时客流预测模型进行对比,从而验证了本发明模型的合理性和准确性。
(1)问题定义
本发明目的是同时预测城市轨道交通全网所有车站某一时段的进站客流量。将城市轨道交通网络定义为一张图G=(V,E,A),图中有V个地铁车站,车站与车站之间有E条边,代表该网络的邻接矩阵,例如采用0,1两种元素表示,0代表两个车站不相邻,1代表两个车站相邻。
将客流量作为城市轨道交通系统中地铁站的属性特征,表示为即第i个车站在第t个时间段的第k个特征,n为车站数量,m为时间步,k为特征矩阵的数量。在一个实施例中,每个车站有两个特征矩阵:进站客流矩阵和出站客流矩阵。因此,k=2。表示在时间t时所有车站的所有特征值。表示在时间T内所有车站的所有特征值。Yt+1∈Rn*1*k表示未来t+1时刻所有车站的客流量。
因此,本发明所要解决的问题是,利用过去所有车站m个时间间隔的客流量来预测t+1时刻的客流量如公式(1)所示,其中f是深度学习框架将要学习的映射函数。
Yt+1=f(Xt)(1)
(2)深度学习模型框架
本发明所提出的模型框架总体上包括图卷积神经网络模型(GCN,GraphConvolutional Network)和生成对抗网络模型(GAN,Generative Adversarial Network),或称为Graph-GAN,其结构如图1所示。GAN包含生成器G和判别器D。历史客流数据可分为三个模式:实时模式、日模式和周模式。首先应用GCN模型获取各模式客流时空相关性。然后将三个模式的输出合并作为GAN中的生成器的输入。对于生成器G,例如采用一个全连接网络模型,以合并后的数据作为输入,生成城市轨道交通网络中所有地铁站未来的交通流量。判别器D用于区分真实数据与生成数据,其输入为历史客流数据和生成数据。G和D经过不断地迭代训练,使得生成器生成的数据与真实数据非常相似,以至于判别器D无法区分它们。训练后的生成器可以作为预测模型。以下重点描述图卷积神经网络和生成对抗网络。
(3)关于图卷积神经网络
在本发明中,使用GCN模型来捕获城市轨道交通网络中站点之间的拓扑关系。传统的交通预测模型往往将交通网络视为网格矩阵,忽略了网络拓扑结构对预测精度的影响。GCN模型对时空特征和网络拓扑信息具有较强的提取能力。从谱图卷积滤波器到切比雪夫多项式滤波器,再到一阶近似滤波器,GCN的性能得到了很大的提高。因此,本发明使用GCN模型来获得地铁车站之间的内部拓扑关系。图2是GCN模型的结构。
假设在一个静态图中有N个具有M维特征的节点。拓扑结构和节点特征可以分别用邻接矩阵A和特征矩阵Z表示。在一个实施例中,所使用的GCN滤波器为2016年Kipf等人提出的,如公式(2)所示。
其中A为邻接矩阵,IN为N维单位矩阵,/>为矩阵/>的对角度矩阵,W为权重矩阵,Z为特征矩阵,f(·)是激活函数,X为最终的输出。
然而,考虑到随着GCN层堆叠数量的增加,GCN模型的性能会变得越来越差。更多的GCN层堆叠不仅会导致反向传播过程的复杂性增加,还会导致梯度消失等问题,从而降低GCN的性能。此外,深层GCN中存在严重的“过平滑”问题,即随着层数的增加顶点的几个特征收敛到相同的值。鉴于此,在一个实施例中,将一般GCN扩展为更简单的GCN,以克服公式(2)所示的GCN模型的不足,如公式(3)所示。
其中为归一化的拉普拉斯矩阵,In∈Rn*m*k为模型的输入,In′与In具有相同的维度,但是In′包含丰富的网络拓扑信息,后续作为GAN的输入。
由于进站客流和出站客流受相邻时段、日时段和周时段的客流影响,因此在一个实施例中,利用三种模式下的客流:实时模式、日模式和周模式。假设时间粒度为ti,时间步为ts,当前时间段为t,预测t+1时刻的客流量。三种模式的详细情况分别如下。
1)、实时模式
Xreal=(Xt-ts+1,Xt-ts+2,…,Xt),临近预测时间段的历史时间序列。相邻时段的客流量会影响下一时段客流量的增减。例如,在突发事件下,进出地铁站的乘客数量会发生相应的变化。
2)、日模式
预测时间段前一天同一时间的历史时间序列。在Xday中,当前时刻为t,前一天的同一时刻为/>由于早晚高峰,每天的客流都会呈现一定的趋势。因此,有必要根据前一天的客流量来预测当前的客流量。
3)、周模式
预测时间段前一周同一时间的历史时间序列。在Xweek中,当前时刻为t,前一周的同一时刻为/>即/>由于通勤人数众多,每周的客流也呈现一定的规律性。例如,本周一和上周一的交通模式有相似之处。因此,有必要根据上周的客流量来预测当前的客流量。
这三种模式共享相同的网络结构。三种模式的输出合并后输入到GAN的生成器中。应理解的是,根据应用场景不同,也可采用更多或更少的模式进行客流预测。
(4)关于生成对抗网络
在本发明中,采用生成预测结果的具有对抗性过程的GAN模型。在短期客流预测领域,许多深度学习模型只考虑提高模型精度,而忽略了模型的复杂性。在本发明中,从应用目的出发,充分考虑了模型复杂性和模型性能,不以增加模型复杂性来换取模型预测精度。因此,采用一种更简单的深度学习模型,即全连接层,再结合一种更先进的模型训练方法,即GAN,以达到平衡模型复杂性和模型性能的目的。
GAN是一种具有对抗性过程的生成模型,如图3所示,GAN由生成模型(Generator,G)和判别模型(Discriminator,D)两部分组成。G捕捉真实数据的分布,并从该分布生成新数据;D是一个二分类器,它区分输入是真实数据还是生成数据。优化过程类似于极小极大博弈过程。G和D进行迭代训练。通过反向传播算法,最终的目标可以达到纳什均衡,即生成器完全得到真实数据的分布。因此,生成器可以用来生成最终的预测结果。
例如,训练生成对抗网络的目标函数如公式(4)所示。
其中为真实数据,/>为真实数据的分布,/>为/>来自真实数据分布的概率,/>为随机噪声,/>为随机噪声/>的分布,/>为生成器生成的数据,为生成的数据来自于生成数据分布的概率。目标函数的是最大化真实数据和来自G的生成样本分配正确标签的概率。
在本发明中,目标是使用历史的进站客流和出站客流X=(X1,X2,…,XT)来预测未来的客流Yt+1。利用一个简单的全连接神经网络作为生成器G来生成预测结果,同样,利用一个简单的全连接神经网络作为判别器D来区分真实数据与生成器生成的数据。
首先,介绍利用生成器生成样本的过程。例如,利用一个全连接神经网络构成生成器G,该神经网络有两个隐藏层和一个输出层。生成器的输入为GCN模型的输出生成器的输出数据为(X′1,X′2,…,X′T)。
对于判别器D,用于区分真实数据和生成器生成的数据。同样利用一个全连接神经网络构成判别器D,该神经网络具有两个隐藏层和一个输出层。在训练过程中,真实数据X和生成的数据X′交替输入到判别器D中,然后将判别器D的误差反向传播到生成器G,使生成的数据与真实数据之间的误差最小。通过这种对抗式训练,能够提升生成器的预测精度。
在一个实施例中,使用Wasserstein GAN(WGAN)而不是初始GAN进行训练。WGAN与初始GAN的主要区别在于,WGAN引入了瓦瑟斯坦距离(Wasserstein distance)作为优化目标,而初始GAN以JS散度和KL散度作为优化目标。由于与KL散度和JS散度相比,WGAN具有更平滑的Wasserstein distance,从根本上解决了初始GAN的梯度消失问题。具体来说,WGAN具有以下两个优点:首先,WGAN更容易训练,因为训练过程更稳定,对模型结构和超参数的敏感性更低,而且不需要仔细平衡生成器和鉴别器的训练,可以通过训练一个简单的全连接网络而达到精度更高的预测结果。其次,WGAN解决了模型坍塌问题,保证了生成样本的多样性,从而加速生成器G的训练。WGAN的目标函数如公式(5)所示。
其中,ω临界参数,θ为生成器的参数,为带有参数的函数,/>都是Lipschitz连续的函数,/>表示生成器基于样本Z所生成的数据。
根据经验,总结了WGAN的训练技巧,这对实际应用来说非常重要,并为今后的研究提供了以下几点提示,这些设计有助于提高训练效率并提升模型的预测精度。
1)、在WGAN训练过程中,将判别器D最后一层的Sigmoid激活函数层去掉;
2)、生成器和判别器的损失函数不取对数;
3)、每次更新判别器的参数后,将其绝对值截断为不超过固定常数c;
4)、在选择优化器时,不使用基于动量的优化器算法,如Adam。推荐使用RMSProp和SGD优化器。
5)、在对判别器和生成器进行交替训练时,在每个epoch中建议判别器比生成器多训练几次,这样更容易地达到判别器和生成器之间的平衡状态。
为进一步验证本发明的效果,进行了实验仿真。以下将详细介绍模型使用的数据集、评价指标、基准模型、模型参数设置以及结果分析。
(1)数据集
实验使用了北京地铁的两个数据集,如表1所示。第一个数据集是MetroBJ2016,包括从2016年2月29日到2016年4月3日连续5周工作日的北京地铁刷卡数据,共计17条线路和276个车站(不包括机场线)。第二个数据集是MetroBJ2018,包括2018年10月8日至2018年11月11日连续5周工作日的北京地铁刷卡数据,共计22条线路和308个车站(不包括机场线)。为了提取进站客流与出站客流,将进站时间和出站时间转换为分钟,从0到1080,代表05:00到23:00。然后从数据中提取出15分钟时间粒度的进站客流时间序列。MetroBJ2016的客流维度为276*1800,MetroBJ2018的客流维度为308*1800。此外,为所有地铁站提供一个唯一的车站编号。
表1数据集的描述
(2)模型配置与评价指标
1)模型参数设置
对于两个数据集,GAN的参数相同:Generator G和discriminator D是有两个隐藏层和一个输出层的全连接神经网络。生成器G和判别器D中隐藏层的单元数分别为[1024,512]和[512,256]。生成器G和判别器D中隐藏层的激活函数分别为ReLU函数和LeakyReLU函数。由于使用WGAN进行训练,所以生成器输出层的激活函数为Tanh函数,判别器的输出层中没有激活函数。在实验中,选择80%的数据作为训练数据和验证数据,剩下的20%作为测试数据。Batch size为32,优化器是RMSprop,学习率为0.00005。所有的模型在一台带有CoreTMi9-10900X处理器,32GB运行内存,以及NVIDIA GeForce RTX3080 GPU的台式机进行运算。
2)数据预处理
在训练前,使用Min-Max归一化方法将数据缩放到[0,1]范围内。在训练后,将预测值重新调整到原始的尺度,用于与真实数据进行比较。
3)评价标准
所使用的评价指标为均方根误差Root Mean Square Error(RMSE)、平均绝对误差Mean Absolute Error(MAE)和加权平均绝对百分比误差weighted Mean AbsolutePercentage Error(WMAPE),如公式(6)到(8)所示。
其中,N为样本数量,yi为真实值,为预测值,/>为真实值总和。
(3)、模型对比
在实验中,所提出的Graph-GAN模型将和以下几个模型进行对比,以证明模型的有效性。
ARIMA:自回归移动平均模型。将相同的ARIMA模型应用于所有城市轨道交通站点的客流。模型中的三个参数,即滞后阶数、差异程度、移动平均阶数,经过微调后分别设为9、1、0。
LSTM:长短时记忆网络模型。LSTM模型在2015年被首次应用到交通领域,使用的LSTM模型具有两个LSTM层和三个全连接层。优化器是Adam,学习率为0.001,batch size为32。输入是10个时间步的进站客流序列,输出是下一个时间步的进站客流序列。
CNN:卷积神经网络模型。实验使用的CNN模型具有两个CNN层和三个全连接层。内核大小为3*3,优化器是Adam,学习率为0.001。batch size是32,输入和输出与LSTM模型相同。
ST-ResNet:时空残差网络模型。该模型是由张钧波等人2017年提出,实验使用其中的三个分支,不包含天气数据分支。
ConvLSTM:卷积循环神经网络模型。实验使用的ConvLSTM模型具有两个ConvLSTM层和三个全连接层。其他参数与CNN模型相同。
ResLSTM:一种融合了GCN、ResNet和注意力机制的LSTM模型的深度学习框架。
GAN:生成对抗网络模型。实验使用的GAN除了GCN模块外,GAN的参数与Graph-GAN相同。
Conv-GCN:结合图卷积网络(GCN)和三维卷积神经网络3D CNN的深度学习体系结构。
(4)结果分析
1)网络范围内的预测性能
表2和图4显示了Graph-GAN与其他基线方法在MetroBJ2016和MetroBJ2018数据集上的性能对比。图4是不同模型在全网范围内的预测效果的对比,其中图4(a)对应RMSE指标,图4(b)对应MAE指标,图4(c)对应WMAPE指标。由图可以看出,本发明所提出的Conv-GCN模型在所有情况下表现最好,展现了模型良好的鲁棒性和表现能力。
参见表2,深度学习模型显著优于基于数理统计的模型。在MetroBJ2016和MetroBJ2018数据集中,ARIMA是性能最差的模型,两个数据集的RMSE分别为81.4562和69.4250,MSE分别为42.8006和33.9540,这是因为ARIMA无法捕捉客流的综合非线性特征。
进一步将Graph-GAN模型与LSTM和CNN等深度学习方法进行了比较。在深度学习模型中,LSTM不能捕获数据之间的空间相关性,而CNN不能捕捉数据之间的时间相关性。因此,这两种模型比同时考虑空间和时间信息的Graph-GAN模型表现差。
ST-Resnet、Conv-GCN和Conv-LSTM是兼顾空间相关性和时间动态的客流预测方法。这些方法比LSTM和CNN的精度有所提高。然而,这些模型在结构上相对复杂。因此,本发明提出了一个更简单的模型,并且能够实现更好的预测性能。
上述已经介绍到GAN具有巨大的潜力,并得到了广泛的应用。利用对抗性训练过程,生成器模型可以显著提高其预测能力,获得更准确的预测结果。因此,本发明对抗性地训练由两个叠加的全连接神经网络构成的生成器和判别器,从而利用简单神经网络获得更好的预测结果。结果表明,GAN的性能优于LSTM和CNN,但由于拓扑信息没有得到充分利用,无法充分捕获空间和时间信息。
本发明提出了结合GCN和GAN的Graph-GAN模型,以更好地捕捉高维数据中的空间和时间关系。从表2可以看出,与传统和深度学习模型相比,Graph-GAN模型的预测精度最高,两个数据集的RMSE最低,分别为34.6653和32.9536,MAE最低,分别为20.3786和16.6860,WMAPE最低,分别为7.693%和8.549%。
表2不同模型性能的比较
2)模型在单个车站预测性能的比较
在实验中,选择了三个具有不同客流特征的站点来展示Graph-GAN的预测性能。第一个车站是回龙观站,这是一个有数百万人居住的大社区。第二个车站是东直门站,这是一个典型的交通枢纽,三条地铁线在此交汇。最后一个车站是北京南站,这是一个靠近大型火车站的地铁站。三个地铁站的预测结果参见图5,其展示了MetroBJ2016和MetroBJ2018中选取的三个站点的实际值和预测值比较。从结果可以得出以下结论:
回龙观站的预测结果如图5(a)所示,从图中可以看出,无论是在高峰时段还是非高峰时段,预测值始终与实际值一致,说明Graph-GAN模型具有较强的鲁棒性。此外,由于回龙观地铁站位于较大的居民区附近,平日乘客通勤极为频繁,且存在较强的早高峰和晚高峰特征,这有助于提高预测性能。
东直门站的预测结果如图5(b)所示,可以看出东直门地铁站客流呈现明显的晚峰特征。无论在高峰期还是非高峰期,预测效果都很好,说明该模型可以应用于换乘站。
北京南站的预测结果如图5(c)所示,可以看出北京南站客流波动较大,没有明显的早晚高峰。在这种情况下,所提出的模型仍然能够很好地捕捉客流变化,说明该模型在不同条件下表现良好。
图5详细展示了不同模型在不同类型车站的预测效果的对比。可以看出,所提出的Graph-GAN模型不仅在整个城市轨道交通网络有很好的预测效果,而且在每个站点都能取得良好的预测结果。
3)模型在不同时间段预测性能的比较
为了评估不同时间段的预测效果,本发明计算了MetroBJ2016和MetroBJ2018在每个时间段5:00-23:00的平均预测精度。图6显示了Graph-GAN模型与基线模型在不同时间段下的性能比较。由图可以得出了以下结论。
首先,讨论了不同时间段的预测性能与整体预测性能的关系。从图4(各模型总体预测结果)和图6(各模型在不同时间段的预测结果)可以看出,不同模型在不同时间段的性能与总体性能表现出相同的规律。例如,ARIMA在高峰期和非高峰期的表现最差。此外,ARIMA的预测误差波动最大,说明基于统计的模型不适合大规模数据的预测。在高峰时期和非高峰时期,Graph-GAN通常比基线模型表现得更好。这些结果表明了Graph-GAN模型的稳定性。
接着,分析了同一模型在不同时间段下的预测性能。以Graph-GAN模型为例,在非高峰时段,Graph-GAN的性能优于高峰时段。其他模型在不同时间段的预测性能与Graph-GAN类似,这表明,当城市轨道交通客流波动较大时,模型预测效果下降。但与其他基线模型相比,Graph-GAN模型的预测误差在峰值时段的波动最为小。
最后,讨论了各模型在不同数据集上的预测性能。从图6可以看出,模型在MetroBJ2016和MetroBJ2018上的预测结果都有相似的规律,证明了模型具有很好的泛化能力。
综上,本发明所提出的模型无论在全天还是在不同时间段都能取得良好的预测结果,具有显著的鲁棒性。
4)模型参数的比较
进一步地,将Graph-GAN模型的可训练模型参数数量与能够捕获时空相关性的基线模型的参数数量在MetroBJ2016和MetroBJ2018进行了比较,如表3所示。可以看到,本发明提出的Graph-GAN具有最少的模型参数。Graph-GAN的参数个数与GAN的个数相同。然而,Graph-GAN精度得到了提高,这也证明了GCN模块的有效性。通过对模型参数的比较,证明了本发明的思想,即不以增加模型复杂性为代价来提高模型的预测精度,而是从应用的角度充分考虑了模型复杂性和模型性能之间的权衡。仅将简单的全连接神经网络与一种更先进的模型训练方法结合使用,获得了更好的预测精度。
表3模型参数数量
上述实验结果说明,本发明所提出的模型在任何情况下均表现最好。在与其他基线模型进行对比时,Graph-GAN的效果最好,在MetroBJ2016数据集中,RMSE为34.6653,MAE为20.3786,WMAPE为7.693%;在MetroBJ2018数据集中,RMSE为32.9536,MAE为16.6860,WMAPE为8.549%。模型在不同类型车站的预测性能对比中,Graph-GAN模型不仅在整个城市轨道交通网络有很好的预测效果,而且在每个站点都能取得良好的预测结果。模型在不同时间段的预测性能对比中,所提出的模型无论在全天还是在不同时间段都能取得良好的预测结果,具有显著的鲁棒性。在模型参数数量的对比中,所提出的模型参数数量最少。
综上所述,本发明提出了一种结构简单、预测精度高的深度学习模型Graph-GAN来预测城市轨道交通网络的短期客流。该模型主要包括:简化版的图卷积网络用于提取网络拓扑信息;采用生成对抗网络来预测短期客流,生成对抗网络中的生成器和判别器由简单的全连接神经网络组成。Graph-GAN在北京地铁的两个大型真实数据集上进行了测试。最后将Graph-GAN与许多先进模型的预测性能进行比较,说明了其显著的优势和鲁棒性。本发明提出的模型在从大量数据中捕捉复杂相关性,显著提高了网络范围内的预测精度,可以从现实应用的角度为进行短期客流预测提供重要经验。
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
计算机可读存储介质可以是可以保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以是但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、静态随机存取存储器(SRAM)、便携式压缩盘只读存储器(CD-ROM)、数字多功能盘(DVD)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。这里所使用的计算机可读存储介质不被解释为瞬时信号本身,诸如无线电波或者其他自由传播的电磁波、通过波导或其他传输媒介传播的电磁波(例如,通过光纤电缆的光脉冲)、或者通过电线传输的电信号。
这里所描述的计算机可读程序指令可以从计算机可读存储介质下载到各个计算/处理设备,或者通过网络、例如因特网、局域网、广域网和/或无线网下载到外部计算机或外部存储设备。网络可以包括铜传输电缆、光纤传输、无线传输、路由器、防火墙、交换机、网关计算机和/或边缘服务器。每个计算/处理设备中的网络适配卡或者网络接口从网络接收计算机可读程序指令,并转发该计算机可读程序指令,以供存储在各个计算/处理设备中的计算机可读存储介质中。
用于执行本发明操作的计算机程序指令可以是汇编指令、指令集架构(ISA)指令、机器指令、机器相关指令、微代码、固件指令、状态设置数据、或者以一种或多种编程语言的任意组合编写的源代码或目标代码,所述编程语言包括面向对象的编程语言—诸如Smalltalk、C++、Python等,以及常规的过程式编程语言—诸如“C”语言或类似的编程语言。计算机可读程序指令可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络—包括局域网(LAN)或广域网(WAN)—连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。在一些实施例中,通过利用计算机可读程序指令的状态信息来个性化定制电子电路,例如可编程逻辑电路、现场可编程门阵列(FPGA)或可编程逻辑阵列(PLA),该电子电路可以执行计算机可读程序指令,从而实现本发明的各个方面。
这里参照根据本发明实施例的方法、装置(系统)和计算机程序产品的流程图和/或框图描述了本发明的各个方面。应当理解,流程图和/或框图的每个方框以及流程图和/或框图中各方框的组合,都可以由计算机可读程序指令实现。
这些计算机可读程序指令可以提供给通用计算机、专用计算机或其它可编程数据处理装置的处理器,从而生产出一种机器,使得这些指令在通过计算机或其它可编程数据处理装置的处理器执行时,产生了实现流程图和/或框图中的一个或多个方框中规定的功能/动作的装置。也可以把这些计算机可读程序指令存储在计算机可读存储介质中,这些指令使得计算机、可编程数据处理装置和/或其他设备以特定方式工作,从而,存储有指令的计算机可读介质则包括一个制造品,其包括实现流程图和/或框图中的一个或多个方框中规定的功能/动作的各个方面的指令。
也可以把计算机可读程序指令加载到计算机、其它可编程数据处理装置、或其它设备上,使得在计算机、其它可编程数据处理装置或其它设备上执行一系列操作步骤,以产生计算机实现的过程,从而使得在计算机、其它可编程数据处理装置、或其它设备上执行的指令实现流程图和/或框图中的一个或多个方框中规定的功能/动作。
附图中的流程图和框图显示了根据本发明的多个实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或指令的一部分,所述模块、程序段或指令的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。对于本领域技术人员来说公知的是,通过硬件方式实现、通过软件方式实现以及通过软件和硬件结合的方式实现都是等价的。
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。本发明的范围由所附权利要求来限定。

Claims (6)

1.一种基于生成对抗网络的城市轨道交通短时客流预测方法,包括以下步骤:
针对城市轨道交通网络构建图结构,标记为图G=(V,E,A),V表示地铁车站数目,车站与车站之间有E条边,代表邻接矩阵,用于标记车站之间是否相邻,并将客流量作为车站的属性特征;
基于所述图结构获取多个模式下反映历史客流信息的时间序列数据,所述多个模式根据与客流预测时刻的不同时间间隔进行划分;
将所述多个模式下的时间序列数据输入至图卷积神经网络获取各模式下客流的时空相关性;
将所述图卷积神经网络输出的不同模式数据进行合并后输入到生成器,以生成城市轨道交通网络中目标车站在后续时刻的交通客流信息,其中所述生成器利用设定的目标函数通过训练生成对抗网络获得;
其中,将客流量作为城市轨道交通网络中车站的属性特征,标记为其表示第i个车站在第t个时间段的第k个特征,n为车站数量,m为时间步,k为特征矩阵的数量,每个车站有两个特征矩阵:进站客流矩阵和出站客流矩阵,/>表示在时间t时所有车站的所有特征值,/>表示在时间T内所有车站的所有特征值,Yt+1∈Rn*1*k表示未来t+1时刻所有车站的客流量;
其中,所述生成对抗网络包括生成器G和判别器D,生成器G采用全连接神经网络,并设置两个隐藏层和一个输出层,生成器的输入为所述图卷积神经网络的输出;判别器D采用全连接神经网络,并具有两个隐藏层和一个输出层;所述生成对抗网络的目标是使用历史的进站客流和出站客流来预测未来时刻的客流,在训练过程中,真实数据和生成的数据交替输入到判别器D中,然后将判别器D的误差反向传播到生成器G,使生成的数据与真实数据之间的误差最小;
其中,所述图卷积神经网络的滤波器表示为:
其中,A为邻接矩阵,IN为N维单位矩阵,/>为矩阵/>的对角度矩阵,W为权重矩阵,Z为特征矩阵,f(·)激活函数,X为最终的输出,设置图结构中有N个具有M维特征的节点;
其中,将所述图卷积神经网络简化为:
其中为归一化的拉普拉斯矩阵,In∈Rn*m*k为输入,In′与In具有相同的维度,In′作为生成对抗网络的输入。
2.根据权利要求1所述的方法,其特征在于,所述多个模式包括实时模式、日模式和周模式,实时模式对应预测时间段的相邻时间段的历史时间序列,表示为:
Xreal=(Xt-ts+1,Xt-ts+2,…,Xt)
日模式对应预测时间段前一天同一时间的历史时间序列,表示为:
周模式对应预测时间段前一周同一时间的历史时间序列,表示为:
其中,时间粒度为ti,时间步为ts,当前时间段为t,预测t+1时刻的客流量,在Xday中,前一天的同一时刻为在Xweek中,前一周的同一时刻为/>
3.根据权利要求1所述的方法,其特征在于,所述生成对抗网络采用Wasserstein生成对抗网络,目标函数设定为:
其中,ω是临界参数,θ为生成器的参数,为带有参数的函数,/>是Lipschitz连续的函数,/>为真实数据,/>为真实数据的分布,/>为随机噪声,为随机噪声/>的分布,/>表示生成器基于样本Z所生成的数据。
4.根据权利要求1所述的方法,其特征在于,在所述生成对抗网络训练过程中,将判别器最后一层的Sigmoid激活函数层去掉;生成器和判别器的损失函数不取对数;每次更新判别器的参数后,将其绝对值截断为不超过固定常数c;在选择优化器时,使用RMSProp或SGD优化器;在对判别器和生成器进行交替训练时,在每个epoch中判别器比生成器多训练设定数量的次数。
5.一种计算机可读存储介质,其上存储有计算机程序,其中,该程序被处理器执行时实现根据权利要求1至4中任一项所述方法的步骤。
6.一种计算机设备,包括存储器和处理器,在所述存储器上存储有能够在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现权利要求1至4中任一项所述的方法的步骤。
CN202210188660.1A 2022-02-28 2022-02-28 一种基于生成对抗网络的城市轨道交通短时客流预测方法 Active CN114626585B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210188660.1A CN114626585B (zh) 2022-02-28 2022-02-28 一种基于生成对抗网络的城市轨道交通短时客流预测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210188660.1A CN114626585B (zh) 2022-02-28 2022-02-28 一种基于生成对抗网络的城市轨道交通短时客流预测方法

Publications (2)

Publication Number Publication Date
CN114626585A CN114626585A (zh) 2022-06-14
CN114626585B true CN114626585B (zh) 2023-09-08

Family

ID=81900114

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210188660.1A Active CN114626585B (zh) 2022-02-28 2022-02-28 一种基于生成对抗网络的城市轨道交通短时客流预测方法

Country Status (1)

Country Link
CN (1) CN114626585B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115081717B (zh) * 2022-06-27 2023-03-24 北京建筑大学 融合注意力机制和图神经网络的轨道交通客流预测方法
CN115564151A (zh) * 2022-12-06 2023-01-03 成都智元汇信息技术股份有限公司 一种基于形态识别的突发大客流形态识别方法及系统
CN116050640B (zh) * 2023-02-01 2023-10-13 北京交通大学 基于自适应多图卷积的多模式交通系统短时客流预测方法
CN115965163A (zh) * 2023-02-07 2023-04-14 北京交通大学 基于时空图生成对抗损失的轨道交通短时客流预测方法

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109493599A (zh) * 2018-11-16 2019-03-19 南京航空航天大学 一种基于生成式对抗网络的短时交通流预测方法
CN111667092A (zh) * 2020-04-21 2020-09-15 北京交通大学 基于图卷积神经网络的轨道交通短时客流预测方法和系统
CN113450561A (zh) * 2021-05-06 2021-09-28 浙江工业大学 一种基于时空图卷积-生成对抗网络的交通速度预测方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200074267A1 (en) * 2018-08-31 2020-03-05 International Business Machines Corporation Data prediction

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109493599A (zh) * 2018-11-16 2019-03-19 南京航空航天大学 一种基于生成式对抗网络的短时交通流预测方法
CN111667092A (zh) * 2020-04-21 2020-09-15 北京交通大学 基于图卷积神经网络的轨道交通短时客流预测方法和系统
CN113450561A (zh) * 2021-05-06 2021-09-28 浙江工业大学 一种基于时空图卷积-生成对抗网络的交通速度预测方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
梁强升 ; 许心越 ; 刘利强 ; .面向数据驱动的城市轨道交通短时客流预测模型.中国铁道科学.2020,(第04期),全文. *

Also Published As

Publication number Publication date
CN114626585A (zh) 2022-06-14

Similar Documents

Publication Publication Date Title
CN114626585B (zh) 一种基于生成对抗网络的城市轨道交通短时客流预测方法
Dong et al. Hourly energy consumption prediction of an office building based on ensemble learning and energy consumption pattern classification
Lv et al. Stacked autoencoder with echo-state regression for tourism demand forecasting using search query data
CN110570651A (zh) 一种基于深度学习的路网交通态势预测方法及系统
Kong et al. Big data‐driven machine learning‐enabled traffic flow prediction
Andariesta et al. Machine learning models for predicting international tourist arrivals in Indonesia during the COVID-19 pandemic: a multisource Internet data approach
Bao et al. Covid-gan: Estimating human mobility responses to covid-19 pandemic through spatio-temporal conditional generative adversarial networks
CN112001548A (zh) 一种基于深度学习的od客流预测方法
CN115204478A (zh) 一种结合城市兴趣点和时空因果关系的公共交通流量预测方法
CN116010684A (zh) 物品推荐方法、装置及存储介质
Yu et al. Real-time prediction system of train carriage load based on multi-stream fuzzy learning
Zhao et al. Historical pattern recognition with trajectory similarity for daily tourist arrivals forecasting
Dai et al. Spatio-temporal deep learning framework for traffic speed forecasting in IoT
CN115828990A (zh) 融合自适应图扩散卷积网络的时空图节点属性预测方法
Li Prediction of tourism demand in liuzhou region based on machine learning
Zahera et al. Jointly learning from social media and environmental data for typhoon intensity prediction
Xu et al. A taxi dispatch system based on prediction of demand and destination
Sandagiri et al. ANN Based Crime Detection and Prediction using Twitter Posts and Weather Data
Tran Grid Search of Convolutional Neural Network model in the case of load forecasting
Du et al. Structure tuning method on deep convolutional generative adversarial network with nondominated sorting genetic algorithm II
Wang et al. A novel GBDT-BiLSTM hybrid model on improving day-ahead photovoltaic prediction
Shaikh et al. Bayesian optimization with stacked sparse autoencoder based cryptocurrency price prediction model
Гавриленко et al. Тhe task of analyzing publications to build a forecast for changes in cryptocurrency rates
Zheng et al. Modeling stochastic service time for complex on-demand food delivery
Bansal et al. Cryptocurrency price prediction using Twitter and news articles analysis

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant