CN115631631B - 一种基于双向蒸馏网络的交通流量预测方法与装置 - Google Patents

一种基于双向蒸馏网络的交通流量预测方法与装置 Download PDF

Info

Publication number
CN115631631B
CN115631631B CN202211419913.8A CN202211419913A CN115631631B CN 115631631 B CN115631631 B CN 115631631B CN 202211419913 A CN202211419913 A CN 202211419913A CN 115631631 B CN115631631 B CN 115631631B
Authority
CN
China
Prior art keywords
traffic flow
flow prediction
prediction model
network
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211419913.8A
Other languages
English (en)
Other versions
CN115631631A (zh
Inventor
马宇晴
刘祥龙
刘卫
高雅君
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beihang University
Original Assignee
Beihang University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beihang University filed Critical Beihang University
Priority to CN202211419913.8A priority Critical patent/CN115631631B/zh
Publication of CN115631631A publication Critical patent/CN115631631A/zh
Application granted granted Critical
Publication of CN115631631B publication Critical patent/CN115631631B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G08SIGNALLING
    • G08GTRAFFIC CONTROL SYSTEMS
    • G08G1/00Traffic control systems for road vehicles
    • G08G1/01Detecting movement of traffic to be counted or controlled
    • G08G1/0104Measuring and analyzing of parameters relative to traffic conditions
    • G08G1/0125Traffic data processing
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Chemical & Material Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Analytical Chemistry (AREA)
  • Traffic Control Systems (AREA)

Abstract

本发明公开了一种基于双向蒸馏网络的交通流量预测方法与装置。在该交通流量预测方法中,包括如下步骤:从交通流量的训练数据集中随机采样至少一个交通流量时空序列;同时建立前向网络交通流量预测模型和反向网络交通流量预测模型,并在两个交通流量预测模型之间以知识蒸馏的方式构建双向复杂时空动态;利用层级特定的元适配器对前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的短期空间交互信息进行精细调整,使双向蒸馏网络完全收敛;基于双向蒸馏网络中的前向网络交通流量预测模型,获得针对当前输入的交通流量时空序列的未来预测结果。

Description

一种基于双向蒸馏网络的交通流量预测方法与装置
技术领域
本发明涉及一种基于双向蒸馏网络的交通流量预测方法,同时涉及相应的交通流量预测装置,属于智能交通技术领域。
背景技术
交通流量预测对于交通管理和公共安全而言具有重要意义。如果能够准确预测一个地区的交通流量变化情况,就可以利用交通管制、警告、提前疏散等应急机制,减少或防止各类交通事故和危害公共安全的事件发生。此外,高效的交通管控、匝道计量和许多其他交通管理策略也是物联网(IoT)的重要组成部分。但是,交通流量的预测同时受区域间交通、事件、天气等多种复杂因素影响,具有很大的挑战性,在实践中仍然是一个长期存在的研究课题。
现有技术中,有人将长短期记忆网络(LSTM)和门控循环单元(GRU)等循环神经网络引入到交通流量预测中,这有利于对交通流量时空序列数据中的长期时间依赖性进行建模。例如,在专利号为ZL 202011119621.3的中国发明专利中,公开了一种基于内嵌注意力机制的循环神经网络的交通流量预测方法,包括如下步骤:获取各检测站点的历史交通流量数据;将数据处理成以τ为时间间隔的连续等时长的数据集;将数据集按照各检测站点的空间分布排列成交通流量数据矩阵;将交通流量数据矩阵分割为样本数据集;利用内嵌注意力机制的循环神经网络模型提取数据集之间的时空特征;采用单层全连接网络预测得到下一时刻的交通流量预测结果。
但是,此类现有技术虽然增强了交通流量预测模型的性能,但有两个潜在的缺陷使它们无法获得更好的结果。一方面,它们只沿时间序列模拟前向的交通流量变化动态,而不考虑反向信息。而从直觉上看,人类既可以前向推理,也可以后向推理,有时后向推理会带来更多的洞见。人类有了向后推理和向前推理的能力,就可以充分理解给定历史数据中的双向动态,同时考虑前向推理和回溯理性进行预测。另一方面,交通流量预测模型中不同层次的空间相互作用呈现出不同的学习复杂性,它们不适合共享相似的学习范式。与高层的抽象语义交互相比,浅层的空间相关性更容易学习,简单地对每一层采用相同的学习范式会降低交通流量预测模型的预测性能。
发明内容
本发明所要解决的首要技术问题在于提供一种基于双向蒸馏网络的交通流量预测方法。
本发明所要解决的另一技术问题在于提供一种基于双向蒸馏网络的交通流量预测装置。
为实现上述发明目的,本发明采用下述的技术方案:
根据本发明实施例的第一方面,提供一种基于双向蒸馏网络的交通流量预测方法,包括如下步骤:
S1,从交通流量的训练数据集中随机采样至少一个交通流量时空序列;
S2,同时建立前向网络交通流量预测模型和反向网络交通流量预测模型,并在两个交通流量预测模型之间以知识蒸馏的方式构建双向复杂时空动态;
S3,利用层级特定的元适配器对前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的短期空间交互信息进行精细调整,使双向蒸馏网络完全收敛;
S4,基于所述双向蒸馏网络中的前向网络交通流量预测模型,获得针对当前输入的交通流量时空序列的未来预测结果。
其中较优地,在训练过程中,为所述前向网络交通流量预测模型和所述反向网络交通流量预测模型中的各层赋予初始相同的学习率,然后迭代执行步骤S1和步骤S2。
其中较优地,根据总损失函数
Figure 770920DEST_PATH_IMAGE002
更新所述双向蒸馏网络的网络参数,使交通流量预测模型初步收敛;然后,迭代执行步骤S1和步骤S3,交替优化网络参数和元参数,直至交通流量预测模型最终完全收敛。
其中较优地,所述总损失函数
Figure 666195DEST_PATH_IMAGE002
的计算公式如下:
Figure 934365DEST_PATH_IMAGE003
其中,
Figure 626726DEST_PATH_IMAGE005
为蒸馏损失函数,
Figure 322150DEST_PATH_IMAGE005
为重建损失函数。
其中较优地,所述重建损失函数
Figure 970300DEST_PATH_IMAGE006
采用如下公式计算:
Figure 42161DEST_PATH_IMAGE007
其中,Xt是t时刻的真实交通流量数据;X’t是前向网络交通流量预测模型预测的t时刻交通流量数据;
Figure 872583DEST_PATH_IMAGE008
是反向网络交通流量预测模型预测的t时刻交通流量数据。
其中较优地,所述蒸馏损失函数
Figure 879853DEST_PATH_IMAGE009
采用如下公式计算:
Figure 874354DEST_PATH_IMAGE010
其中,
Figure 953168DEST_PATH_IMAGE011
是前向网络交通流量预测模型预测的
Figure 608403DEST_PATH_IMAGE012
时刻交通流量数据;
Figure 645629DEST_PATH_IMAGE013
是前向网络交通流量预测模型预测的
Figure 268371DEST_PATH_IMAGE012
时刻的潜在表征;
Figure 682035DEST_PATH_IMAGE014
是反向网络交通流量预测模型预测的
Figure 955890DEST_PATH_IMAGE012
时刻的潜在表征;
Figure 164018DEST_PATH_IMAGE015
是反向网络交通流量预测模型预测的
Figure 8477DEST_PATH_IMAGE012
时刻交通流量数据。
其中较优地,在步骤S2中,前向网络交通流量预测模型将交通流量时空序列
Figure 429094DEST_PATH_IMAGE016
按时间顺序依次输入,根据
Figure 167243DEST_PATH_IMAGE012
时刻的交通流量数据
Figure 437950DEST_PATH_IMAGE017
,历史记忆函数
Figure 894339DEST_PATH_IMAGE018
和历史隐藏状态函数
Figure 525171DEST_PATH_IMAGE019
,输出
Figure 648985DEST_PATH_IMAGE020
时刻的交通流量预测函数
Figure 589128DEST_PATH_IMAGE021
Figure 267234DEST_PATH_IMAGE022
其中,
Figure 701758DEST_PATH_IMAGE023
表示前向网络交通流量预测模型;
Figure 414499DEST_PATH_IMAGE024
表示卷积层,作用在于将潜在表征函数
Figure 604172DEST_PATH_IMAGE025
投影到
Figure 661252DEST_PATH_IMAGE020
时刻,指定区域内的交通流量预测函数
Figure 758521DEST_PATH_IMAGE026
其中较优地,在步骤S2中,反向网络交通流量预测模型依次输入交通流量时空序列
Figure 201134DEST_PATH_IMAGE027
,回溯前置条件
Figure 358446DEST_PATH_IMAGE028
并结合历史记忆函数
Figure 135778DEST_PATH_IMAGE029
和历史隐藏状态函数
Figure 36738DEST_PATH_IMAGE030
,输出
Figure 865017DEST_PATH_IMAGE031
时刻的潜在表征函数
Figure 193230DEST_PATH_IMAGE032
Figure 959323DEST_PATH_IMAGE033
其中,
Figure 398395DEST_PATH_IMAGE034
表示反向网络交通流量预测模型;
Figure 815601DEST_PATH_IMAGE035
表示卷积层,作用在于将潜在表征函数
Figure 517977DEST_PATH_IMAGE036
投影到
Figure 145268DEST_PATH_IMAGE031
时刻,指定区域内的交通流量预测函数
Figure 512664DEST_PATH_IMAGE037
其中较优地,在步骤S3中,所述元适配器根据所述前向网络交通流量预测模型和所述反向网络交通流量预测模型中不同层级的学习复杂度对每一层生成相应的学习率,利用每一层的学习率对不同层级的短期空间交互信息进行精细调整。
根据本发明实施例的第二方面,提供一种基于双向蒸馏网络的交通流量预测装置,包括处理器和存储器,所述处理器读取所述存储器中的计算机程序,用于执行上述的交通流量预测方法。
与现有技术相比较,本发明所提供的基于双向蒸馏网络的交通流量预测方法与装置首次从知识转移的角度对跨越时空的交通流量预测任务进行建模,以知识蒸馏的方式构建双向复杂时空动态,并通过元学习方式细化多层级的空间相关性。它有效地捕获了交通流量时空序列的长期时间相关性和短期空间相关性,在推理过程中与基线模型相比,可以在不增加额外计算量的情况下有效提高交通流量预测的准确性。
附图说明
图1为本发明提供的交通流量预测方法中,双向蒸馏网络的生成过程流程图;
图2为本发明实施例中,基于双向蒸馏网络的交通流量预测装置示意图。
具体实施方式
下面结合附图和具体实施例对本发明的技术内容进行详细具体的说明。
目前,在交通流量预测任务中,公认性能表现较好的是时空序列预测模型PredRNN-V2(关于PredRNN-V2模型的详细介绍,可以参阅链接:https://arxiv.org/abs/2103.09504)。因此,本发明实施例中也采用该时空序列预测模型PredRNN-V2作为基线模型。
在此基础上,本发明实施例首先生成一个用于交通流量预测的双向蒸馏神经网络模型(简称为双向蒸馏网络)。该双向蒸馏网络从知识转移的角度对跨越时空的交通流量预测任务进行建模,以知识蒸馏的方式构建双向复杂时空动态,并通过元学习方式细化多层级的空间相关性。在本发明的一个实施例中,该双向蒸馏网络包括前向网络交通流量预测模型和反向网络交通流量预测模型两部分。在前向网络交通流量预测模型和反向网络交通流量预测模型中,分别包含4个ST-LSTM叠加层和1个卷积层,它们之间的连接关系可以参考现有的PredRNN-V2模型,在此就不赘述了。
参见图1所示,上述双向蒸馏网络的生成过程至少包括如下步骤:S1,从交通流量的训练数据集中随机采样至少一个交通流量时空序列;S2,同时建立前向网络交通流量预测模型和反向网络交通流量预测模型,并在两个交通流量预测模型之间以知识蒸馏的方式构建双向复杂时空动态;S3,利用层级特定的元适配器对前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的短期空间交互信息进行精细调整,使双向蒸馏网络完全收敛。
下面,分别对每个步骤的具体实施过程进行说明:
首先,在步骤S1中,从交通流量的训练数据集中随机采样一个交通流量时空序列
Figure 174590DEST_PATH_IMAGE038
,分别供双向蒸馏网络中的前向网络交通流量预测模型和反向网络交通流量预测模型使用。其中,
Figure 719972DEST_PATH_IMAGE039
为双向蒸馏网络中,前向网络交通流量预测模型输入的交通流量时空序列,
Figure 568979DEST_PATH_IMAGE017
为当前时刻
Figure 507110DEST_PATH_IMAGE012
输入的特定空间区域内各个位置的交通流量数据,
Figure 23542DEST_PATH_IMAGE040
为前向网络交通流量预测模型所要预测的交通流量时空序列。
类似地,将上述的交通流量时空序列
Figure 739825DEST_PATH_IMAGE041
反转,得到
Figure DEST_PATH_IMAGE042
。其中,
Figure 731921DEST_PATH_IMAGE043
为双向蒸馏网络中,反向网络交通流量预测模型输入的交通流量时空序列,
Figure DEST_PATH_IMAGE044
为反向网络交通流量预测模型所要预测的交通流量时空序列。
接下来,在步骤S2中,同时建立前向网络交通流量预测模型和反向网络交通流量预测模型,分别对未来和过去的交通流量变化进行双向推理;然后,构建蒸馏损失函数和重建损失函数,指导两个交通流量预测模型在保证自身预测准确性的同时,相互协作进行知识迁移。在知识迁移的过程中,一个交通流量预测模型不仅能够保持其自身的交通流量时空建模能力,而且能够学习另一个交通流量预测模型的输出结果和特征表示。
在本发明的一个实施例中,上述步骤S2中的双向推理过程,具体包括如下步骤:
前向推理:将交通流量时空序列
Figure 50907DEST_PATH_IMAGE045
按时间顺序依次输入前向网络交通流量预测模型,根据t时刻的交通流量数据Xt,历史记忆函数
Figure DEST_PATH_IMAGE046
和历史隐藏状态函数
Figure 828370DEST_PATH_IMAGE047
,输出t+1时刻的交通流量预测函数
Figure DEST_PATH_IMAGE048
Figure 263025DEST_PATH_IMAGE022
其中,
Figure 227570DEST_PATH_IMAGE023
表示前向网络交通流量预测模型,
Figure 350246DEST_PATH_IMAGE024
表示一个卷积层,它的作用在于将潜在表征函数
Figure 965904DEST_PATH_IMAGE025
投影到
Figure 351886DEST_PATH_IMAGE020
时刻,指定区域内的交通流量预测函数
Figure 662782DEST_PATH_IMAGE021
反向推理:将交通流量时空序列
Figure 464516DEST_PATH_IMAGE049
以相反的顺序依次输入反向网络交通流量预测模型,回溯前置条件
Figure 810047DEST_PATH_IMAGE028
并结合历史记忆函数
Figure 55345DEST_PATH_IMAGE029
和历史隐藏状态函数
Figure 853537DEST_PATH_IMAGE030
,输出
Figure 458962DEST_PATH_IMAGE031
时刻的潜在表征函数
Figure 658999DEST_PATH_IMAGE032
Figure 308155DEST_PATH_IMAGE033
其中,
Figure 796905DEST_PATH_IMAGE034
表示反向网络交通流量预测模型;
Figure 330655DEST_PATH_IMAGE035
表示一个卷积层,它的作用在于将潜在表征函数
Figure 260565DEST_PATH_IMAGE036
投影到
Figure 955988DEST_PATH_IMAGE031
时刻,指定区域内的交通流量预测函数
Figure DEST_PATH_IMAGE050
在本发明的一个实施例中,所构建的蒸馏损失函数
Figure 874311DEST_PATH_IMAGE051
采用如下公式计算:
Figure DEST_PATH_IMAGE052
其中,
Figure 618276DEST_PATH_IMAGE011
是前向网络交通流量预测模型预测的
Figure 651960DEST_PATH_IMAGE012
时刻交通流量数据;
Figure 783864DEST_PATH_IMAGE013
是前向网络交通流量预测模型预测的
Figure 653731DEST_PATH_IMAGE012
时刻的潜在表征;
Figure 732545DEST_PATH_IMAGE014
是反向网络交通流量预测模型预测的
Figure 496102DEST_PATH_IMAGE012
时刻的潜在表征;
Figure 425006DEST_PATH_IMAGE015
是反向网络交通流量预测模型预测的
Figure 906803DEST_PATH_IMAGE012
时刻交通流量数据。
上述蒸馏损失函数
Figure 461412DEST_PATH_IMAGE051
的作用在于促使前向网络交通流量预测模型和反向网络交通流量预测模型分别输出的交通流量预测函数和潜在表征函数分别互相逼近。
相应地,所构建的重建损失函数
Figure 345054DEST_PATH_IMAGE006
采用如下公式计算:
Figure 943395DEST_PATH_IMAGE053
其中,
Figure 912488DEST_PATH_IMAGE054
Figure 5209DEST_PATH_IMAGE012
时刻的真实交通流量数据;
Figure 946620DEST_PATH_IMAGE011
是前向网络交通流量预测模型预测的
Figure 591228DEST_PATH_IMAGE012
时刻交通流量数据;
Figure 673716DEST_PATH_IMAGE015
是反向网络交通流量预测模型预测的
Figure 429182DEST_PATH_IMAGE012
时刻交通流量数据。
上述重建损失函数
Figure 428362DEST_PATH_IMAGE006
的作用在于保证交通流量预测模型输出的交通流量预测函数的真实性和准确性,对每个交通流量预测模型预测的交通流量进行约束,使其接近于真实值。
在此基础上,整个双向蒸馏网络中的总损失函数
Figure 978292DEST_PATH_IMAGE055
通过如下公式计算:
Figure 46611DEST_PATH_IMAGE003
进一步地,在步骤S3中,层级特定的元适配器根据前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的学习复杂度对每一层生成相应的学习率,利用每一层的学习率对不同层级的短期空间交互信息进行精细调整,从而使双向蒸馏网络完全收敛。
在本发明的一个实施例中,针对每一层生成相应的学习率,包括如下的具体步骤:
以前向网络交通流量预测模型为例,将前向网络交通流量预测模型第
Figure 605768DEST_PATH_IMAGE056
层(
Figure 193876DEST_PATH_IMAGE057
)所对应的学习率(元参数)表示为
Figure 180286DEST_PATH_IMAGE058
。经过一次梯度更新后的网络参数
Figure 548951DEST_PATH_IMAGE059
可以表示为:
Figure 272318DEST_PATH_IMAGE060
其中,
Figure 105145DEST_PATH_IMAGE061
为总损失函数
Figure 137823DEST_PATH_IMAGE062
针对各个网络参数的梯度。
Figure 790521DEST_PATH_IMAGE056
层的元学习器根据不同层的学习经验进一步训练学习率:
Figure 816115DEST_PATH_IMAGE063
其中
Figure 769028DEST_PATH_IMAGE064
为对学习率
Figure 972607DEST_PATH_IMAGE058
进行更新的更新步长,
Figure 112601DEST_PATH_IMAGE065
是一个损失函数,其灵感来自于一个归纳偏置,即较浅层的学习更容易,在后续的微调中应该拥有较慢的学习率。在本发明的一个实施例中,用如下公式表示这种归纳偏置过程,并对学习率(元参数)进行正则化处理:
Figure 177772DEST_PATH_IMAGE066
其中,超参数
Figure 719611DEST_PATH_IMAGE067
,以保证浅层学习率小于上层学习率。在本发明的一个实施例中,超参数
Figure 421988DEST_PATH_IMAGE068
可以优选设置为2。
反向网络交通流量预测模型也可以采用同样的方式进行学习。经过多次迭代更新,获得适合每一层的学习率。
接下来,利用元适配器生成的每一层级特定的学习率
Figure 924645DEST_PATH_IMAGE058
Figure 167407DEST_PATH_IMAGE069
,对前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的网络参数
Figure 891650DEST_PATH_IMAGE070
Figure 764928DEST_PATH_IMAGE071
进行针对性的更新:
Figure 879514DEST_PATH_IMAGE072
Figure 552066DEST_PATH_IMAGE073
通过上述步骤,可以使双向蒸馏网络中的不同层级的网络参数和元参数经过精细调整,可以更加有效地捕获交通流量时空序列中的短期空间交互信息,从而使预测的交通流量更加准确。
利用上述步骤S1~S3所获得的双向蒸馏网络在训练过程中,首先为前向网络交通流量预测模型和反向网络交通流量预测模型中的各层赋予初始相同的学习率(元参数),然后迭代执行步骤S1和步骤S2,并通过总损失函数
Figure 802919DEST_PATH_IMAGE074
更新双向蒸馏网络的网络参数,使交通流量预测模型初步收敛。然后,迭代执行步骤S1和步骤S3,交替优化网络参数和元参数,直至交通流量预测模型最终完全收敛。双向蒸馏网络的网络参数根据层级特定的学习复杂性进行更新,而元参数则朝着最优收敛方向优化。在本发明的一个实施例中,在训练过程中采用两个不同的Adam优化器分别更新网络参数和元参数,其中元参数初始化为
Figure 519202DEST_PATH_IMAGE075
,并以
Figure 121085DEST_PATH_IMAGE076
的学习率进行优化,当元参数大于
Figure 830284DEST_PATH_IMAGE077
或小于0时进行裁剪。
相应地,上述双向蒸馏网络在用于交通流量预测时,由于无法直接获得未来的交通流量数据,仅执行上述步骤S1和步骤S2中的前向推理过程(即执行步骤S4:基于双向蒸馏网络中的前向网络交通流量预测模型进行推理),获得针对当前输入的交通流量时空序列的未来预测结果,不再执行反向推理过程。因此,本发明实施例所提供的基于双向蒸馏网络的交通流量预测方法与上述基线模型如PredRNN-V2相比,并不会增加额外的计算量。
为了验证本发明实施例提供的交通流量预测方法的实际效果,发明人在具有剧烈时空流动的真实交通流量预测任务场景上进行落地使用。具体地,发明人选取了北京出租车在2013年7月1日至2013年10月30日、2014年3月1日至2014年6月30日、2015年3月1日至2015年6月30日以及2015年11月1日至2016年4月10日四个时间段的交通流量数据,数据中每一帧包含大小为32×32的两个通道,表示同一区域内各个位置的输入流量和输出流量。不同帧之间的时间间隔为30分钟,随着时间的推移呈现出剧烈和不均匀的变化。为了和其他方法进行公平的比较,本发明使用过去2小时的4帧作为输入来预测未来2小时的4帧。每一帧预测结果和真实值之间的均方误差如表1所示:
表 1
Figure 466802DEST_PATH_IMAGE078
从表1中可以看到,本发明在真实交通流量预测任务场景可以取得较为优异的表现,具有较好的实用价值。
在上述基于双向蒸馏网络的交通流量预测方法的基础上,本发明进一步提供一种基于双向蒸馏网络的交通流量预测装置。如图2所示,该交通流量预测装置包括一个或多个处理器21和存储器22。其中,存储器22与处理器21耦接,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器21执行,使得所述一个或多个处理器21实现上述实施例中基于双向蒸馏网络的交通流量预测方法。
其中,处理器21用于控制该基于双向蒸馏网络的交通流量预测装置的整体操作,以完成上述基于双向蒸馏网络的交通流量预测方法的全部或部分步骤。在本发明的实施例中,该处理器21优选为GPU(图形处理单元),但也可以是FPGA(现场可编程逻辑门阵列)、ASIC(专用集成电路)、DSP(数字信号处理器)等。存储器22用于存储各种类型的数据以支持在该基于双向蒸馏网络的交通流量预测方法的操作,这些数据例如可以包括用于在该基于双向蒸馏网络的交通流量预测装置上操作的任何应用程序或方法的指令,以及应用程序相关的数据。
该存储器22可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,例如静态随机存取存储器(SRAM)、电可擦除可编程只读存储器(EEPROM)、可擦除可编程只读存储器(EPROM)、可编程只读存储器(PROM)、只读存储器(ROM)、磁存储器、快闪存储器等。
在一个示例性实施例中,基于双向蒸馏网络的交通流量预测装置具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现,用于执行上述基于双向蒸馏网络的交通流量预测方法,并达到如上述方法一致的技术效果。一种典型的实施例为计算机。具体地说,计算机例如可以为个人计算机、膝上型计算机、车载人机交互设备、公安卡口检查设备、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
在另一个示例性实施例中,本发明还提供一种包括程序指令的计算机可读存储介质,该程序指令被处理器执行时实现上述任意一个实施例中的基于双向蒸馏网络的交通流量预测方法的步骤。例如,该计算机可读存储介质可以为包括程序指令的存储器,上述程序指令可由基于双向蒸馏网络的交通流量预测装置的处理器执行以完成上述基于双向蒸馏网络的交通流量预测方法,并达到如上述方法一致的技术效果。
与现有技术相比较,本发明所提供的基于双向蒸馏网络的交通流量预测方法与装置首次从知识转移的角度对跨越时空的交通流量预测任务进行建模,以知识蒸馏的方式构建双向复杂时空动态,并通过元学习方式细化多层级的空间相关性。它有效地捕获了交通流量时空序列的长期时间相关性和短期空间相关性,在推理过程中与基线模型相比,可以在不增加额外计算量的情况下有效提高交通流量预测的准确性。
上面对本发明所提供的基于双向蒸馏网络的交通流量预测方法与装置进行了详细的说明。对本领域的一般技术人员而言,在不背离本发明实质内容的前提下对它所做的任何显而易见的改动,都将构成对本发明专利权的侵犯,将承担相应的法律责任。

Claims (8)

1.一种基于双向蒸馏网络的交通流量预测方法,其特征在于包括如下步骤:
S 1,从交通流量的训练数据集中随机采样至少一个交通流量时空序列;
S2,同时建立前向网络交通流量预测模型和反向网络交通流量预测模型,并在两个交通流量预测模型之间以知识蒸馏的方式构建双向复杂时空动态;其中,将交通流量时空序列xin={X1,…,XT}按时间顺序依次输入前向网络交通流量预测模型,根据t时刻的交通流量数据Xt,历史记忆函数Ct-1和历史隐藏状态函数Ht-1,输出t+1时刻的交通流量预测函数X t+1
Figure FDA0004061270370000011
Figure FDA0004061270370000012
同时,将交通流量时空序列
Figure FDA0004061270370000013
依次输入反向网络交通流量预测模型,回溯前置条件Xt-1并结合历史记忆函数
Figure FDA0004061270370000014
和历史隐藏状态函数
Figure FDA0004061270370000015
输出t-1时刻的潜在表征函数
Figure FDA0004061270370000016
Figure FDA0004061270370000017
Figure FDA0004061270370000018
其中,
Figure FDA0004061270370000019
表示前向网络交通流量预测模型;
Figure FDA00040612703700000110
表示卷积层,作用在于将潜在表征函数Vt +1投影到t+1时刻,指定区域内的交通流量预测函数X t+1
Figure FDA00040612703700000111
表示反向网络交通流量预测模型;
Figure FDA00040612703700000112
表示卷积层,作用在于将潜在表征函数
Figure FDA00040612703700000113
投影到t-1时刻,指定区域内的交通流量预测函数
Figure FDA00040612703700000114
S 3,利用层级特定的元适配器对前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的短期空间交互信息进行精细调整,使双向蒸馏网络完全收敛;
S4,基于所述双向蒸馏网络中的前向网络交通流量预测模型,获得针对当前输入的交通流量时空序列的未来预测结果。
2.如权利要求1所述的交通流量预测方法,其特征在于:
在训练过程中,为前向网络交通流量预测模型和反向网络交通流量预测模型中的各层赋予初始相同的学习率,然后迭代执行步骤S1和步骤S2。
3.如权利要求2所述的交通流量预测方法,其特征在于:
根据总损失函数lbid更新所述双向蒸馏网络的网络参数,使交通流量预测模型初步收敛;然后,迭代执行步骤S1和步骤S3,交替优化网络参数和元参数,直至交通流量预测模型最终完全收敛。
4.如权利要求2所述的交通流量预测方法,其特征在于所述总损失函数lbid的计算公式如下:
lbid=lrec+ldis
其中,lrec为重建损失函数,ldis为蒸馏损失函数。
5.如权利要求4所述的交通流量预测方法,其特征在于所述重建损失函数lrec采用如下公式计算:
Figure FDA0004061270370000021
其中,Xt是t时刻的真实交通流量数据;X′t是前向网络交通流量预测模型预测的t时刻交通流量数据;
Figure FDA0004061270370000022
是反向网络交通流量预测模型预测的t时刻交通流量数据。
6.如权利要求4所述的交通流量预测方法,其特征在于所述蒸馏损失函数ldis采用如下公式计算:
Figure FDA0004061270370000023
其中,X′t是前向网络交通流量预测模型预测的t时刻交通流量数据;V′t是前向网络交通流量预测模型预测的t时刻的潜在表征;
Figure FDA0004061270370000024
是反向网络交通流量预测模型预测的t时刻的潜在表征;
Figure FDA0004061270370000025
是反向网络交通流量预测模型预测的t时刻交通流量数据。
7.如权利要求1所述的交通流量预测方法,其特征在于:
在步骤S3中,所述元适配器根据前向网络交通流量预测模型和反向网络交通流量预测模型中不同层级的学习复杂度对每一层生成相应的学习率,利用每一层的学习率对不同层级的短期空间交互信息进行精细调整。
8.一种基于双向蒸馏网络的交通流量预测装置,其特征在于包括处理器和存储器,所述处理器读取所述存储器中的计算机程序,用于执行权利要求1~7中任意一项所述的交通流量预测方法。
CN202211419913.8A 2022-11-14 2022-11-14 一种基于双向蒸馏网络的交通流量预测方法与装置 Active CN115631631B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211419913.8A CN115631631B (zh) 2022-11-14 2022-11-14 一种基于双向蒸馏网络的交通流量预测方法与装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211419913.8A CN115631631B (zh) 2022-11-14 2022-11-14 一种基于双向蒸馏网络的交通流量预测方法与装置

Publications (2)

Publication Number Publication Date
CN115631631A CN115631631A (zh) 2023-01-20
CN115631631B true CN115631631B (zh) 2023-04-07

Family

ID=84910335

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211419913.8A Active CN115631631B (zh) 2022-11-14 2022-11-14 一种基于双向蒸馏网络的交通流量预测方法与装置

Country Status (1)

Country Link
CN (1) CN115631631B (zh)

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111243269A (zh) * 2019-12-10 2020-06-05 福州市联创智云信息科技有限公司 基于融合时空特征的深度网络的交通流预测方法

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110472730A (zh) * 2019-08-07 2019-11-19 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法
CN111130839B (zh) * 2019-11-04 2021-07-16 清华大学 一种流量需求矩阵预测方法及其系统
CN111882031A (zh) * 2020-06-30 2020-11-03 华为技术有限公司 一种神经网络蒸馏方法及装置
AU2020102350A4 (en) * 2020-09-21 2020-10-29 Guizhou Minzu University A Spark-Based Deep Learning Method for Data-Driven Traffic Flow Forecasting
CN113053115B (zh) * 2021-03-17 2022-04-22 中国科学院地理科学与资源研究所 一种基于多尺度图卷积网络模型的交通预测方法
CN113988263A (zh) * 2021-10-29 2022-01-28 内蒙古大学 工业物联网边缘设备中基于知识蒸馏的空时预测方法

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111243269A (zh) * 2019-12-10 2020-06-05 福州市联创智云信息科技有限公司 基于融合时空特征的深度网络的交通流预测方法

Also Published As

Publication number Publication date
CN115631631A (zh) 2023-01-20

Similar Documents

Publication Publication Date Title
Ziat et al. Spatio-temporal neural networks for space-time series forecasting and relations discovery
AU2020385049B2 (en) Identifying optimal weights to improve prediction accuracy in machine learning techniques
Quilodrán-Casas et al. Digital twins based on bidirectional LSTM and GAN for modelling the COVID-19 pandemic
WO2021160686A1 (en) Generative digital twin of complex systems
US20180121733A1 (en) Reducing computational overhead via predictions of subjective quality of automated image sequence processing
US20200265291A1 (en) Spatio temporal gated recurrent unit
CN113139446B (zh) 一种端到端自动驾驶行为决策方法、系统及终端设备
CN110570035B (zh) 同时建模时空依赖性和每日流量相关性的人流量预测系统
Coşkun et al. Deep reinforcement learning for traffic light optimization
CN114303177A (zh) 通过迁移学习生成具有不同疲劳程度的视频数据集的系统和方法
CN109559329A (zh) 一种基于深度去噪自动编码器的粒子滤波跟踪方法
KR102093577B1 (ko) 학습네트워크를 이용한 예측 영상 생성 방법 및 예측 영상 생성 장치
CN112633463B (zh) 用于建模序列数据中长期依赖性的双重递归神经网络架构
Hoy et al. Learning to predict pedestrian intention via variational tracking networks
Zhu et al. Multi-task credible pseudo-label learning for semi-supervised crowd counting
CN110047096A (zh) 一种基于深度条件随机场模型的多目标跟踪方法和系统
CN117237756A (zh) 一种训练目标分割模型的方法、目标分割方法及相关装置
Qiu et al. Iterative teaching by data hallucination
Zuo et al. Off-policy adversarial imitation learning for robotic tasks with low-quality demonstrations
CN115631631B (zh) 一种基于双向蒸馏网络的交通流量预测方法与装置
CN112395505B (zh) 一种基于协同注意力机制的短视频点击率预测方法
CN117636626A (zh) 强化道路周边空间特征的异质图交通预测方法及系统
CN108470212A (zh) 一种能利用事件持续时间的高效lstm设计方法
CN117056595A (zh) 一种交互式的项目推荐方法、装置及计算机可读存储介质
Yalçın Weather parameters forecasting with time series using deep hybrid neural networks

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant