CN114638343B - 模型训练方法、预测方法、装置、设备及存储介质 - Google Patents
模型训练方法、预测方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN114638343B CN114638343B CN202210310844.0A CN202210310844A CN114638343B CN 114638343 B CN114638343 B CN 114638343B CN 202210310844 A CN202210310844 A CN 202210310844A CN 114638343 B CN114638343 B CN 114638343B
- Authority
- CN
- China
- Prior art keywords
- time window
- application end
- transaction amount
- state data
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 133
- 238000012549 training Methods 0.000 title claims abstract description 104
- 230000006870 function Effects 0.000 claims description 33
- 238000006243 chemical reaction Methods 0.000 claims description 30
- 238000000605 extraction Methods 0.000 claims description 26
- 230000015654 memory Effects 0.000 claims description 17
- 238000004590 computer program Methods 0.000 claims description 10
- 238000012545 processing Methods 0.000 claims description 9
- 230000006403 short-term memory Effects 0.000 claims description 4
- 230000007704 transition Effects 0.000 claims description 4
- 238000010586 diagram Methods 0.000 description 16
- 230000008569 process Effects 0.000 description 13
- 230000003993 interaction Effects 0.000 description 7
- 230000009286 beneficial effect Effects 0.000 description 4
- 238000004458 analytical method Methods 0.000 description 3
- 238000013473 artificial intelligence Methods 0.000 description 3
- 235000015203 fruit juice Nutrition 0.000 description 3
- 238000011161 development Methods 0.000 description 2
- YHXISWVBGDMDLQ-UHFFFAOYSA-N moclobemide Chemical compound C1=CC(Cl)=CC=C1C(=O)NCCN1CCOCC1 YHXISWVBGDMDLQ-UHFFFAOYSA-N 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 1
- 238000007405 data analysis Methods 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 239000003814 drug Substances 0.000 description 1
- 235000013399 edible fruits Nutrition 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- PCHJSUWPFVWCPO-UHFFFAOYSA-N gold Chemical compound [Au] PCHJSUWPFVWCPO-UHFFFAOYSA-N 0.000 description 1
- 239000010931 gold Substances 0.000 description 1
- 229910052737 gold Inorganic materials 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000001556 precipitation Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000006641 stabilisation Effects 0.000 description 1
- 238000011105 stabilization Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012731 temporal analysis Methods 0.000 description 1
- 238000000700 time series analysis Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Computer And Data Communications (AREA)
Abstract
本发明提供一种应用于交易量预测的模型训练方法、预测方法、装置、设备及存储介质,该模型训练方法,包括:将样本数据输入各个应用端的底层网络,得到状态数据,由服务端将各个底层网络输出的状态数据拼接;将拼接状态数据输入第一应用端的上层网络,得到预测交易量;根据预测交易量和实际交易量,生成第一梯度,并根据第一梯度更新第一应用端的上层网络的模型参数;根据更新后的上层网络以及第一梯度,生成第二梯度,并将第二梯度发送至服务端,服务端将第二梯度拆分为各个应用端对应的第三梯度,以基于第三梯度更新各个应用端的底层网络的模型参数,直至交易量预测模型稳定,基于多应用端的数据进行模型训练,提高了模型预测的准确度。
Description
技术领域
本发明涉及互联网、人工智能及大数据技术领域,尤其涉及一种应用于交易量预测的模型训练方法、预测方法、装置、设备及存储介质。
背景技术
随着人工智能技术的发展,各类网络模型被广泛地应用于各个领域。
对于时间序列预测,通常采用传统的自回归模型、ARIMA(AutoregressiveIntegrated Moving Average Mode,差分整合移动平均自回归模型)模型或LSTM(LongShort Term Memory,长短期记忆网络)等深度学习模型进行。
然而,发明人在实现本发明的过程中,发现基于上述现有技术中的方案至少存在以下技术问题:由于在模型训练时仅考虑了时间序列数据本身,仅为单平台的数据,预测精度较低,无法满足需求。
发明内容
本发明实施例提供一种应用于交易量预测的模型训练方法、预测方法、装置、设备及存储介质,用于解决模型训练的精度较低的问题。
第一方面,本发明实施例提供一种应用于交易量预测的模型训练方法,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于第一应用端,所述方法包括:
依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据;将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的所述时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据所述时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量;针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数;根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
在一种具体实施方式中,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据,包括:
针对每个时间窗口,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层;基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
在一种具体实施方式中,所述输入层为编码器,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层,包括:
基于所述第一应用端的底层网络编码器,对各个时间窗口对应的交易量样本数据进行编码,得到特征编码数据。
相应的,基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据,包括:
基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
在一种具体实施方式中,所述上层网络包括数据转换层和输出层,将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,包括:
依次将服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出各个时间窗口对应的状态转换数据;基于所述上层网络的输出层,对各个时间窗口对应的状态转换数据进行特征提取,以得到各个时间窗口对应的预测交易量。
在一种具体实施方式中,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并基于根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,包括:
根据所述时间窗口对应的预设交易量和实际交易量,计算所述上层网络的损失函数所述时间窗口对应的第一梯度;通过在所述上层网络中反向传导所述时间窗口对应的第一梯度,更新所述上层网络的模型参数。
在一种具体实施方式中,在得到各个时间窗口对应的状态数据之后,所述方法还包括:
对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
在一种具体实施方式中,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,包括:
根据所述时间窗口对应的预设交易量和实际交易量,计算所述时间窗口对应的损失函数的值;判断所述时间窗口以及所述时间窗口之前预设数量的各个时间窗口对应的损失函数的值是否满足预设条件;若否,则根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度。
第二方面,本发明实施例提供一种应用于交易量预测的模型训练方法,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于第二应用端,所述方法包括:
依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据;将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据,其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到;针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数,其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
在一种具体实施方式中,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据,包括:
针对每个时间窗口,将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络的输入层;
基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
在一种具体实施方式中,所述输入层为编码器,将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络的输入层,包括:
基于所述第二应用端的底层网络编码器,对各个时间窗口对应的协变量样本数据进行编码,得到特征编码数据。
相应的,基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的协变量样本数据进行特征提取,得到所述时间窗口对应的状态数据,包括:
基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
在一种具体实施方式中,在得到各个时间窗口对应的状态数据之后,所述方法还包括:
对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
第三方面,本发明实施例提供一种应用于交易量预测的模型训练方法,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于服务端,所述方法包括:
针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据,其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到;
对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据;
将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络,以通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度;
接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
在一种具体实施方式中,每个应用端所发送的所述时间窗口的状态数据为加密后的状态数据,对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据,包括:
对每个应用端发送的所述时间窗口的状态数据进行解密和拼接,以得到所述时间窗口对应的拼接状态数据。
第四方面,本发明实施例提供一种基于模型的交易量预测方法,所述方法应用于第一应用端,所述方法包括:
获取历史交易量;基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量,其中,所述交易量预测模型为基于本发明第一方面对应的任意实施例提供的应用于交易量预测的模型训练方法生成的。
在一种具体实施方式中,基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,包括:
将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据;
将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,其中,所述拼接状态数据为通过所述服务端将各个应用端的状态数据拼接后生成,所述各个应用端包括所述第一应用端和至少一个第二应用端,所述第二应用端的状态数据为由所述第二应用端的底层网络根据所述历史交易量对应的协变量数据生成。
在一种具体实施方式中,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据,包括:
将所述历史交易量输入所述第一应用端的底层网络的输入层;
基于所述第一应用端的底层网络的隐藏层,对所述历史交易量进行特征提取,以得到所述历史交易量对应的状态数据。
在一种具体实施方式中,所述输入层为编码器,将所述历史交易量输入所述第一应用端的底层网络的输入层,包括:
基于所述第一应用端的底层网络编码器,对所述历史交易量进行编码,得到特征编码数据。
相应的,基于所述第一应用端的底层网络的隐藏层,对所述历史交易量进行特征提取,以得到所述历史交易量对应的状态数据,包括:
基于所述第一应用端的底层网络的隐藏层,对所述特征编码数据进行特征提取,以得到所述历史交易量对应的状态数据。
在一种具体实施方式中,所述上层网络包括数据转换层和输出层,将所述服务端发送的各个时间窗口对应的拼接状态数据,将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,包括:
将服务端发送的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出状态转换数据;
基于所述上层网络的输出层,对所述状态转换数据进行特征提取,以得到所述历史交易量对应的预测交易量。
在一种具体实施方式中,在得到所述第一应用端的底层网络输出的状态数据之后,所述方法还包括:
对所述状态数据进行加密,并将加密后的状态数据发送至服务端,以使所述服务端对接收到的来自各个应用端的状态数据进行解密和拼接,以得到拼接状态数据。
第五方面,本发明实施例提供一种应用于交易量预测的模型训练装置,应用于交易量预测的模型训练装置,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第一应用端,所述装置包括:
第一输入模块,用于依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据;第二输入模块,用于将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的每一时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据每一时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量;第一梯度生成模块,用于针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数;第一参数更新模块,用于根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
第六方面,本发明实施例提供一种应用于交易量预测的模型训练装置,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第二应用端,所述装置包括:
第三输入模块,用于依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据;第一数据发送模块,用于将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据,其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到;第二参数更新模块,用于针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数,其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
第七方面,本发明实施例提供一种应用于交易量预测的模型训练装置,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于服务端,所述装置包括:
数据接收模块,用于针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据,其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到;数据拼接模块,用于对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据;第二数据发送模块,用于将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络,以通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度;梯度拆分模块,用于接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
第八方面,本发明实施例提供一种基于模型的交易量预测装置,所述装置应用于第一应用端,所述装置包括:
历史交易量获取模块,用于获取历史交易量;交易量预测模块,用于基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量,其中,所述交易量预测模型为基于本发明第一方面对应的任意实施例提供的应用于交易量预测的模型训练方法生成的。
第九方面,本发明实施例提供一种电子设备,包括:存储器和处理器;所述存储器用于存储所述处理器可执行指令;其中,所述处理器配置为经由执行所述可执行指令来执行上述任一项实施例所述的方法。
第十方面,本发明实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述任一项实施例所述的方法。
第十一方面,本发明实施例提供一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时用于实现上述任一项实施例所述的方法。
本发明实施例提供一种应用于交易量预测的模型训练方法、预测方法、装置、设备及存储介质,针对参与联邦学习的多个参与方,即第一应用端、服务端和至少一个第二应用端,提供了一种交易预测模型及其训练方法,该交易预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该训练方法主要包括:通过各个应用端设置的底层网络输入样本数据,包括第一应用端对应的交易量样本数据以及各个第二应用端对应的协变量样本数据,通过各个应用端的底层网络输出各个应用端的状态数据,基于服务端拼接各个应用端的状态数据,并将拼接状态数据发送至第一应用端的上层网络,从而得到预测交易量,实现了一次迭代过程,通过每次迭代对应的梯度并基于服务端实现不同应用端梯度的下发,进行模型参数的更新,直至模型训练完毕,从而输出训练好的交易量预测模型,以基于该交易量预测模型进行交易量的预测,实现了基于多方数据进行交易量预测,提高了预测准确度,且通过联邦学习的方式,通过服务端实现不同应用端的数据的交互,提高了各个应用端的数据的安全性。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为相关技术中交易量预测模型的示意图;
图2为本发明实施例提供的时序数据预测模型的示意图;
图3为本发明一个实施例提供的应用于交易量预测的模型训练方法的流程示意图;
图4为本发明图3所示实施例中交易量预测模型的结构示意图;
图5为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图;
图6为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图;
图7为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图;
图8为本发明一个实施例提供的基于模型的交易量预测方法的流程示意图;
图9为本发明一个实施例提供的应用于交易量预测的模型训练装置的示意图;
图10为本发明另一个实施例提供的应用于交易量预测的模型训练装置的示意图;
图11为本发明另一个实施例提供的应用于交易量预测的模型训练装置的示意图;
图12为本发明一个实施例提供的基于模型的交易量预测装置的结构示意图;
图13为本发明一个实施例提供的一种电子设备的结构示意图。
通过上述附图,已示出本公开明确的实施例,后文中将有更详细的描述。这些附图和文字描述并不是为了通过任何方式限制本公开构思的范围,而是通过参考特定实施例为本领域技术人员说明本公开的概念。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例例如能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
时间序列预测分析是利用过去一段时间内事务的时间特征来预测未来一段时间内该事务的时间特征,如商圈的人流量预测、物品的价格预测、销量预测等。
传统的时间序列预测算法中,多采用自回归、移动平均等算法进行,如ARIMA算法。随着人工智能技术的不断发展,基于深度学习的更细粒度的数据分析算法被应用于时间序列分析中,如XGBoost(Extreme Gradient Boosting,极限梯度提升算法)、LSTM等算法。
图1为相关技术中交易量预测模型的示意图,如图1所示,交易量预测模型往往设置在应用端100,将历史时间采集的应用端100的历史交易量输入交易预测模型进行模型训练,从而得到训练好的交易预测模型,以基于该交易预测模型进行未来时间的交易量预测,即输出预测交易量。
然而,在基于如图1所示的交易量预测模型进行训练时,由于仅基于单一应用端的交易量进行模型训练,即仅采用单一维度的样本数据进行模型训练,可能导致训练的可靠性偏低,从而导致交易量预测模型的准确性偏低的问题。
为了解决上述问题,本发明的发明人在经过创造性地劳动之后,得到了本发明的发明构思:结合其他维度的样本进行模型训练,即结合多个应用端对应的样本数据,通过联邦学习的方式进行模型训练,生成交易量预测模型,以提高交易量预测模型的准确度。
下面以具体地实施例对本发明的技术方案以及本发明的技术方案如何解决上述技术问题进行详细说明。下面这几个具体的实施例可以相互结合,对于相同或相似的概念或过程可能在某些实施例中不再赘述。下面将结合附图,对本发明的实施例进行描述。
图2为本发明实施例提供的时序数据预测模型的示意图,如图2所示,该时序数据预测模型包括设置在各个应用端(Party)的底层网络以及设置在第一应用端(ActiveParty)的上层网络,该各个应用端(Party)包括第一应用端(Active Party)以及至少一个第二应用端(Passive Party),图2中以一个第二应用端为例。底层网络用于接收对应的应用端的输入数据,如时序数据以及该时序数据对应的协变量数据,并输出状态数据,每个应用端对应的状态数据通过服务端进行拼接,如通过服务端的中间网络进行拼接,得到涉及各个应用端的输入数据的拼接状态数据,进而服务端将该拼接状态数据传递至第一应用端的上层网络,从而得到交易量预测模型输出的预测交易量,进而计算损失函数和梯度,基于该梯度进行上层网络的更新,并通过服务端拆分梯度,从而得到各个应用端对应梯度,实现各个应用端的底层网络的更新,得到更新后的时序数据预测模型,如此便完成一次迭代过程,循环执行上述迭代过程,直至得到稳定的时序数据预测模型,以基于该稳定的时序数据预测模型进行预测。
通过参与联邦学习的各个应用端以及服务端,实现了基于多平台数据的时序数据预测模型的训练,提高了时序数据预测模型的准确度,且通过联邦学习的方式进行数据交互,提高了数据的安全性。
本发明实施例提供的时序数据预测模型可以应用于任意一种时序数据的分析中,如商圈人流量、价格预测、销量预测等。为了便于描述,本发明以下实施例以交易量这一时序数据为例进行说明。
图3为本发明一个实施例提供的应用于交易量预测的模型训练方法的流程示意图,本发明实施例中提及的交易量预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该交易量预测模型应用于参与联邦学习的多个参与方,该多个参与方中包括第一应用端、服务端和至少一个第二应用端,本实施例提供的模型训练方法应用于第一应用端,该第一应用端为交易量对应的历史数据的提供方,如图3所示,该模型训练方法包括:
S301:依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据。
其中,时间窗口可以包括一个或多个时间节点,每个时间节点对应一个交易量样本数据。交易量样本数据为第一应用端在历史时间对应的交易量。
具体的,第一应用端的底层网络包括输入层和隐藏层,输入层用于接收第一应用端输入的交易量样本数据,对该样本数据进行处理,如数据降维、格式转换等,进而将处理后的交易量样本数据传递至隐藏层,通过隐藏层对交易量样本数据进行特征提取,生成状态数据,进而将状态数据发送至服务端。底层网络的输入层和隐藏层可以为任意结构,本发明实施例对此不进行限定。
具体的,可以先获取训练数据,包括交易量训练数据和协变量训练数据,接着,根据预设窗口长度,分别将交易量训练数据和协变量训练数据划分为各个时间窗口对应的交易量样本数据和协变量样本数据。该预设窗口长度可以自定义设置,可以包括3、4、5或者其他数量的时间节点或时间步。
可选的,所述方法还包括:获取所述第一应用端的交易量训练数据,其中,所述交易量训练数据包括每一时间节点对应的交易量;基于预设窗口长度,对所述交易量训练数据进行划分,得到每个时间窗口对应的交易量样本数据。
具体的,针对每个应用端,在将样本数据,交易量样本数据或协变量样本数据,输入底层网络之前,需要对样本数据进行预处理,如确定每个样本数据的数据标识或ID,以及对样本数据进行对齐,使得不同应用端的样本数据的维度一致,以保证在进行联邦学习时,可以提取到同一时序或时间节点对应的来自不同应用端的数据标识。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据,包括:针对每个时间窗口,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层;基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。其中,预设层数可以为2层或3层。
在一些实施例中,为了提高数据交互的安全性,在生成状态数据之后,还可以对该状态数据进行加密,从而将加密后的状态数据发送至服务端,服务端则需要对各个应用端发送的加密的状态数据进行解密和拼接,从而得到包括各个应用端的状态数据的拼接状态数据。
可选的,在得到各个时间窗口对应的状态数据之后,所述方法还包括:对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
每个应用端均可以在生成各个时间窗口对应的状态数据之后,对该状态数据进行加密,从而将加密后的状态数据发送至服务端,服务端在初始化时,存储有各个应用端的状态数据的解密算法,从而基于对应的解密算法对各个加密的状态数据进行解密,得到各个应用端各个时间窗口的状态数据,进而将每一时间窗口对应的各个应用端的状态数据进行拼接,得到该时间窗口对应的拼接状态数据。
具体的,第一应用端的底层网络的隐藏层向前传递的公式可以为:
hactive,i,t=H(hactive,i,t-1,zi,t-1,θ1)
其中,θ1为第一应用端的底层网络的模型参数;hactive,i,t为第一应用端的底层网络的隐藏层输出的第i个时间窗口t时间的状态数据;zi,t-1为第i个时间窗口t-1时间的交易量样本数据;H(·)为第一应用端的底层网络的隐藏层对应的函数。
第二应用端的底层网络的隐藏层向前传递的公式可以为:
hpassive,i,t=H(hpassive,i,t-1,xi,t-1,θ2)
其中,θ2为第二应用端的底层网络的模型参数;hpassive,i,t为第二应用端的底层网络的隐藏层输出的第i个时间窗口t时间的状态数据;xi,t-1为第二应用端对应的第i个时间窗口t-1时间的协变量样本数据;第二应用端的底层网络的隐藏层对应的函数与第一应用端的底层网络的隐藏层对应的函数相同,均为H(·)。
服务端在对接收到的各个应用端的状态数据进行解密之后,拼接的具体公式为:
hi,t=concat_func(hactive,i,t,hpassive,i,t)
其中,concat_func(·)为服务端拼接对应的函数。
通过服务端将各个时间窗口的各个时间节点的状态数据进行拼接,得到包括全部应用端的状态数据的拼接状态数据hi,t,实现了每个时间窗口下各个应用端的数据的收集,同时避免了源数据的泄露,提高了数据的安全性。
S302:将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量。
其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的所述时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据所述时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量。
本发明实施例中提及的交易量可以为任意一种商品的交易量,如生鲜销量、黄金交易量等,协变量则为影响该交易量的一个或多个因素。以A市场的生鲜销量为例,该协变量可以包括天气数据,如降水量、温度、风速等,还可以包括位于A市场附近的各个生鲜销售方的生鲜销量。以果汁的销量为例,协变量可以包括该果汁对应的水果的产量,还可以包括该果汁对应的广告宣传的阅读量等因素。
在每个时间窗口,第一应用端的上层网络用于接收服务端发送的由各个应用端的该时间窗口的状态数据组成的拼接状态数据,并输出该时间窗口对应的预设交易量。该上层网络可以为任意结构,本发明实施例对此不进行限定。
在一些实施例中,上层网络可以通过计算对数似然函数的参数,以最大化对数似然函数的方式更新或学习网络的参数。
可选的,所述上层网络包括数据转换层和输出层,将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,包括:依次将服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出各个时间窗口对应的状态转换数据;基于所述上层网络的输出层,对各个时间窗口对应的状态转换数据进行特征提取,以得到各个时间窗口对应的预测交易量。
其中,上层网络的数据转换层可以通过线性或非线性函数,对拼接状态数据hi,t进行数据转换。
在一些实施例中,上层网络的数据转换层的表达式可以为:
其中,μ(hi,t)为拼接状态数据hi,t对应的状态转换数据;和bμ分别为数据转换层的模型参数。
输出层的表达式可以为:
其中,为第i个时间窗口t时间的预测交易量。
示例性的,图4为本发明图3所示实施例中交易量预测模型的结构示意图,如图4所示,该交易量预测模型包括设置在各个应用端的底层网络,设置在第一应用端的上层网络,以及设置在服务端的中间层,图4中以一个第二应用端为例,该第二应用端数据可以为多个,如2、3、4或者其他数据。该底层网络的输入层采用嵌入层Embedding,图4中以时间窗口的长度为3为例,假设需要预测时间t的交易量yt,则将交易量样本数据yt-3、yt-2和yt-1作为当前时间窗口对应的模型输入,以及将协变量样本数据xt-2、xt-1和xt作为当前时间窗口对应的模型输入,通过Embedding分别将交易量样本数据和协变量样本数据嵌入对应的底层网络的隐藏层,该隐藏层可以由3层的LSTM网络模型构成,通过LSTM进行特征提取,生成当前时间窗口各个时间节点对应的状态数据,即ht-3、ht-2和ht-1,并对当前时间窗口的状态数据加密,生成加密的状态数据encryped(hactive)和encryped(hpassive)。加密后的状态数据传递至服务端,通过服务端的中间层对加密后的状态数据进行解密和拼接,得到拼接状态数据hi,t,进而服务端将拼接状态数据hi,t发送至第一应用端的上层网络,通过该上层网络根据该拼接状态数据,生成预测交易量
在一些实施例中,各个应用端的底层网络的具体结构可以均相同。
S303:针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数。
具体的,针对每个时间窗口或者每次迭代过程,在生成该时间窗口对应的预测交易量之后,基于该时间窗口对应的预测交易量以及该时间窗口对应的实际交易量,计算用于进行模型参数更新的第一梯度,从而基于该第一梯度更新第一应用端的上层网络的模型参数。
进一步地,针对每个时间窗口,在生成该时间窗口对应的预测交易量之后,可以基于该时间窗口对应的实际交易量和预测交易,计算损失函数的值以及该损失函数的第一梯度。
具体的,可以通过反向传导该第一梯度的方式,更新上层网络的模型参数。
S304,根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
具体的,在每个时间窗口,在上层网络更新完毕之后,基于链式法则,根据模型参数更新后的上层网络的模型参数以及该第一梯度,计算出用于进行各个底层网络的模型参数更新的第二梯度。
在一些实施例中,在将第二梯度发送至服务端之前,为了提高数据交互的安全性,可以对第二梯度进行加密,进而由服务端对加密后的第二梯度进行解密和拆分,从而得到各个应用端的底层网络对应的第三梯度。其中,服务端的拼接操作与拆分操作,可以为互逆的两个操作。服务度将各个第三梯度发送至对应的应用端,从而通过第三梯度在底层网络的反向传导,更新各个底层网络的模型参数,从而得到更新后的交易量预测模型,如此便完成了一次迭代过程,可以通过多次迭代,直至得到稳定的交易量预测模型,或者迭代次数达到预设的最大次数。
通过所设置的服务端进行状态数据和梯度的交互,避免了不同的应用端之间的数据交互,提高了应用端数据的安全性。
基于上述分析可知,本发明实施例提供一种应用于交易量预测的模型训练方法,该交易预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该训练方法主要包括:通过各个应用端设置的底层网络输入样本数据,包括第一应用端对应的交易量样本数据以及各个第二应用端对应的协变量样本数据,通过各个应用端的底层网络输出各个应用端的状态数据,基于服务端拼接各个应用端的状态数据,并将拼接状态数据发送至第一应用端的上层网络,从而得到预测交易量,实现了一次迭代过程,通过每次迭代对应的梯度并基于服务端实现不同应用端梯度的下发,进行模型参数的更新,直至模型训练完毕,从而输出训练好的交易量预测模型,以基于该交易量预测模型进行交易量的预测,实现了基于多方数据进行交易量预测,提高了预测准确度,且通过联邦学习的方式,通过服务端实现不同应用端的数据的交互,提高了各个应用端的数据的安全性。
图5为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图,本发明实施例是在图3所示实施例的基础上,对步骤S301以及步骤S303的进一步细化,如图5所示,本实施例提供的模型训练方法可以包括以下步骤:
S501,依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,基于所述第一应用端的底层网络的编码器,对各个时间窗口对应的交易量样本数据进行编码,得到特征编码数据。
具体的,基于各个应用端的底层网络中的编码器,对各个时间窗口对应的样本数据进行自编码,得到各个时间窗口对应的特征编码数据,如128维的特征编码数据。
S502,基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
S503,对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端。
服务端对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,从而生成各个时间窗口对应的拼接状态数据。
S504,将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量。
其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的所述时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据所述时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量。
S505,针对每个时间窗口,根据所述时间窗口对应的预设交易量和实际交易量,计算所述时间窗口对应的损失函数的值。
其中,损失函数可以为任意一种损失函数,如均方差(Mean Squared Error,MSE)损失函数、平均绝对误差(Mean Absolute Error,MAE)损失函数等。
在每个时间窗口,当上层网络输出该时间窗口的交易量预测值之后,结合该时间窗口的交易量实际值和损失函数,确定该时间窗口对应的预测误差,即该时间窗口对应的损失函数的值。
S506,判断所述时间窗口以及所述时间窗口之前预设数量的各个时间窗口对应的损失函数的值是否满足预设条件。
其中,预设数量可以为3、4、5或者其他数值,可以进行自定义设置。
具体的,若当前时间窗口以及当前时间窗口之前预设数量的各个时间窗口对应的损失函数的值均小于或等于预设阈值,则该预设条件成立,即满足该预设条件,反之,则不满足。
若满足预设条件,则模型训练结束,得到稳定的交易量预测模型,从而可以基于该训练好的交易量预测模型进行交易量预测。
在一些实施例中,还可以设置迭代次数的上限值,如最大迭代次数,当迭代次数达到该最大迭代次数时,则模型训练强制停止。
S507,若否,则根据所述时间窗口对应的预设交易量和实际交易量,计算所述上层网络的损失函数所述时间窗口对应的第一梯度。
当不满足预设条件时,即表示交易量预测模型尚不稳定,则需要通过梯度的反向传导更新模型参数。
具体的,当不满足预设条件时,可以基于梯度下降算法,根据该时间窗口对应的预设交易量和实际交易量,计算上层网络的损失函数该时间窗口的第一梯度,以便于确定使得损失函数下降最快的梯度,从而加快模型收敛或稳定的速度,提高模型训练的效率。
S508,通过在所述上层网络中反向传导所述时间窗口对应的第一梯度,更新所述上层网络的模型参数。
S509,根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
在本实施例中,针对每个应用端,通过该应用端设置的底层网络的编码器对样本数据进行编码,得到特征编码数据,进而将特征编码数据输入该应用端的底层网络的隐藏层,通过该隐藏层输出该应用终端的状态数据,各个应用终端将加密后的状态数据发送至服务端,基于服务端对加密后的各个状态数据进行解密和拼接,得到拼接状态药数据,并将拼接状态数据发送至第一应用端的上层网络,从而得到预测交易量,当基于损失函数确定模型尚不稳定时,计算损失函数的梯度,实现了一次迭代过程,通过每次迭代对应的梯度进行上层网络的更新,并基于服务端对梯度进行拆分和下发至各个应用端,实现了各个应用端的底层网络的更新,从而得到更新后的模型参数,通过多次迭代,直至交易量预测模型稳定或迭代次数达到预设次数,从而输出训练好的交易量预测模型,以基于该交易量预测模型进行交易量的预测,实现了基于多方数据进行交易量预测,提高了预测准确度,且通过联邦学习的方式,通过服务端实现不同应用端的数据的交互,提高了各个应用端的数据的安全性。
图6为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图,本发明实施例中提及的交易量预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该交易量预测模型应用于参与联邦学习的多个参与方,该多个参与方中包括第一应用端、服务端和至少一个第二应用端,本实施例提供的模型训练方法应用于第二应用端,该第二应用端为交易量的协变量的历史数据的提供方,如图6所示,该模型训练方法包括:
S601,依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据。
S602,将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据。
其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到。
S603,针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数。
其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据,包括:针对每个时间窗口,将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络的输入层;基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,所述输入层为编码器,将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络的输入层,包括:基于所述第二应用端的底层网络编码器,对各个时间窗口对应的协变量样本数据进行编码,得到特征编码数据。
相应的,基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的协变量样本数据进行特征提取,得到所述时间窗口对应的状态数据,包括:基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,在得到各个时间窗口对应的状态数据之后,所述方法还包括:对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
第二应用端由于仅设置有底层网络,其对应的模型训练方法可以与第一应用端对应的模型训练方法中关于底层网络训练的部分类似,仅将底层网络输入的数据由交易量样本数据替换为协变量样本数据即可,在此不再赘述。
图7为本发明另一个实施例提供的应用于交易量预测的模型训练方法的流程示意图,本发明实施例中提及的交易量预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该交易量预测模型应用于参与联邦学习的多个参与方,该多个参与方中包括第一应用端、服务端和至少一个第二应用端,本实施例提供的模型训练方法应用于服务端,该服务端不是数据提供方,仅用于进行数据的拼接、拆分和中转,如图7所示,该模型训练方法包括:
S701,针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据。
其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到。
S702,对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据。
S703,将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络。
通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量以及实际交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度。
S704,接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
可选的,每个应用端所发送的所述时间窗口的状态数据为加密后的状态数据,对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据包括:对每个应用端发送的所述时间窗口的状态数据进行解密和拼接,以得到所述时间窗口对应的拼接状态数据。
可选的,在生成各个应用端所述时间窗口对应的第三梯度之后,还包括:
对各个应用端所述时间窗口对应的第三梯度进行加密。
相应的,每个应用端在接收到所述时间窗口对应的加密的第三梯度之后,对该加密的第三梯度进行解密,并基于所述时间窗口对应的第三梯度更新所述应用端的底层网络的模型参数。
图8为本发明一个实施例提供的基于模型的交易量预测方法的流程示意图,本发明实施例中提及的交易量预测模型包括设置在各个应用端的底层网络,以及设置在第一应用端的上层网络,该交易量预测模型应用于参与联邦学习的多个参与方,该多个参与方中包括第一应用端、服务端和至少一个第二应用端,本实施例提供的模型训练方法应用于第一应用端,如图8所示,该模型训练方法包括:
S801,获取历史交易量。
其中,历史交易量可以为最新的预设时间长度的第一应用端的交易量,该预设时间长度可以为上述时间窗口对应的长度。
具体的,第一应用端可以通过监测历史时间的每一时间节点的交易量,得到上述历史交易量,进而将该历史交易量输入预先训练的交易量预测模型中,从而基于该交易量预测模型进行交易量预测。
S802,基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量。
其中,交易量预测模型为基于图3或图5对应的任意实施例提供的模型训练方法生成的。
可选的,基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,包括:将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据;将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,其中,所述拼接状态数据为通过所述服务端将各个应用端的状态数据拼接后生成,所述各个应用端包括所述第一应用端和至少一个第二应用端,所述第二应用端的状态数据为由所述第二应用端的底层网络根据所述历史交易量对应的协变量数据生成。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据,包括:将所述历史交易量输入所述第一应用端的底层网络的输入层;基于所述第一应用端的底层网络的隐藏层,对所述历史交易量进行特征提取,以得到所述历史交易量对应的状态数据。
可选的,所述输入层为编码器,将所述历史交易量输入所述第一应用端的底层网络的输入层,包括:基于所述第一应用端的底层网络编码器,对所述历史交易量进行编码,得到特征编码数据。
相应的,基于所述第一应用端的底层网络的隐藏层,对所述历史交易量进行特征提取,以得到所述历史交易量对应的状态数据,包括:基于所述第一应用端的底层网络的隐藏层,对所述特征编码数据进行特征提取,以得到所述历史交易量对应的状态数据。
可选的,所述上层网络包括数据转换层和输出层,将所述服务端发送的各个时间窗口对应的拼接状态数据,将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,包括:将服务端发送的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出状态转换数据;基于所述上层网络的输出层,对所述状态转换数据进行特征提取,以得到所述历史交易量对应的预测交易量。
可选的,在得到所述第一应用端的底层网络输出的状态数据之后,所述方法还包括:对所述状态数据进行加密,并将加密后的状态数据发送至服务端,以使所述服务端对接收到的来自各个应用端的状态数据进行解密和拼接,以得到拼接状态数据。
图9为本发明一个实施例提供的应用于交易量预测的模型训练装置的示意图,该交易量预测模型应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第一应用端,如图9所示,该模型训练装置包括:第一输入模块910、第二输入模块920、第一梯度生成模块930和第一参数更新模块940。
其中,第一输入模块910,用于依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据;第二输入模块920,用于将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的每一时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据每一时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量;第一梯度生成模块930,用于针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数;第一参数更新模块940,用于根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,第一输入模块910,具体用于:针对每个时间窗口,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层;基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,所述输入层为编码器,第一输入模块910,具体用于:将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层,基于所述第一应用端的底层网络的编码器,对各个时间窗口对应的交易量样本数据进行编码,得到特征编码数据;基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,所述上层网络包括数据转换层和输出层,第二输入模块920,具体用于:依次将服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出各个时间窗口对应的状态转换数据;基于所述上层网络的输出层,对各个时间窗口对应的状态转换数据进行特征提取,以得到各个时间窗口对应的预测交易量。
可选的,第一梯度生成模块930,具体用于:针对每个时间窗口,根据所述时间窗口对应的预设交易量和实际交易量,计算所述上层网络的损失函数所述时间窗口对应的第一梯度;通过在所述上层网络中反向传导所述时间窗口对应的第一梯度,更新所述上层网络的模型参数。
可选的,所述装置还包括:第一加密模块,用于在得到各个时间窗口对应的状态数据之后,对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
可选的,第一梯度生成模块930,具体用于:针对每个时间窗口,根据所述时间窗口对应的预设交易量和实际交易量,计算所述时间窗口对应的损失函数的值;判断所述时间窗口以及所述时间窗口之前预设数量的各个时间窗口对应的损失函数的值是否满足预设条件;若否,则根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数。
本实施例提供的应用于交易量预测的模型训练装置可执行本发明图3或图5任意实施例提供的应用于交易量预测的模型训练方法,具备执行方法相应的功能模块和有益效果。
图10为本发明另一个实施例提供的应用于交易量预测的模型训练装置的示意图,该交易量预测模型应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第二应用端,如图10所示,该模型训练装置包括:第三输入模块1010、第一数据发送模块1020和第二参数更新模块1030。
其中,第三输入模块1010,用于依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据;第一数据发送模块1020,用于将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据,其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到;第二参数更新模块1030,用于针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数,其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,第三输入模块1010,具体用于:针对每个时间窗口,将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络的输入层;基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,所述输入层为编码器,第三输入模块1010,具体用于:依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,基于所述第二应用端的底层网络编码器,对各个时间窗口对应的协变量样本数据进行编码,得到特征编码数据;基于所述第二应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
可选的,所述装置还包括:第二加密模块,用于在得到各个时间窗口对应的状态数据之后,对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
本实施例提供的应用于交易量预测的模型训练装置可执行本发明图6对应的实施例提供的应用于交易量预测的模型训练方法,具备执行方法相应的功能模块和有益效果。
图11为本发明另一个实施例提供的应用于交易量预测的模型训练装置的示意图,该交易量预测模型应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于服务端,如图11所示,该模型训练装置包括:数据接收模块1110、数据拼接模块1120、第二数据发送模块1130和梯度拆分模块1140。
其中,数据接收模块1110,用于针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据,其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到;数据拼接模块1120,用于对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据;第二数据发送模块1130,用于将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络,以通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量以及实际交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度;梯度拆分模块1140,用于接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
可选的,每个应用端所发送的所述时间窗口的状态数据为加密后的状态数据,数据拼接模块1120,具体用于:对每个应用端发送的所述时间窗口的状态数据进行解密和拼接,以得到所述时间窗口对应的拼接状态数据。
可选的,所述装置还包括:第三加密模块,用于在生成各个应用端所述时间窗口对应的第三梯度之后,对各个应用端所述时间窗口对应的第三梯度进行加密。
本实施例提供的应用于交易量预测的模型训练装置可执行本发明图7对应的实施例提供的应用于交易量预测的模型训练方法,具备执行方法相应的功能模块和有益效果。
图12为本发明一个实施例提供的基于模型的交易量预测装置的结构示意图,所述装置应用于第一应用端,如图12所示,该交易量预测装置包括:历史交易量获取模块1210和交易量预测模块1220。
其中,历史交易量获取模块1210,用于获取历史交易量;交易量预测模块1220,用于基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量,其中,所述交易量预测模型为基于图3或图5对应的任意实施例提供的模型训练方法生成的。
可选的,交易量预测模块1220,包括:状态数据生成单元,用于将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据;交易量预测单元,用于将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,其中,所述拼接状态数据为通过所述服务端将各个应用端的状态数据拼接后生成,所述各个应用端包括所述第一应用端和至少一个第二应用端,所述第二应用端的状态数据为由所述第二应用端的底层网络根据所述历史交易量对应的协变量数据生成。
可选的,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,状态数据生成单元,具体用于:将所述历史交易量输入所述第一应用端的底层网络的输入层;基于所述第一应用端的底层网络的隐藏层,对所述历史交易量进行特征提取,以得到所述历史交易量对应的状态数据。
可选的,所述输入层为编码器,状态数据生成单元,具体用于:将所述历史交易量输入所述第一应用端的底层网络的输入层,基于所述第一应用端的底层网络编码器,对所述历史交易量进行编码,得到特征编码数据;基于所述第一应用端的底层网络的隐藏层,对所述特征编码数据进行特征提取,以得到所述历史交易量对应的状态数据。
可选的,所述上层网络包括数据转换层和输出层,交易量预测单元,具体用于:将服务端发送的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出状态转换数据;基于所述上层网络的输出层,对所述状态转换数据进行特征提取,以得到所述历史交易量对应的预测交易量。
可选的,所述装置还包括:数据加密模块,用于在得到所述第一应用端的底层网络输出的状态数据之后,对所述状态数据进行加密,并将加密后的状态数据发送至服务端,以使所述服务端对接收到的来自各个应用端的状态数据进行解密和拼接,以得到拼接状态数据。
本实施例提供的基于模型的交易量预测装置可执行本发明图8对应的实施例提供的基于模型的交易量预测方法,具备执行方法相应的功能模块和有益效果。
图13为本发明一个实施例提供的一种电子设备的结构示意图。如图13所示,该电子设备包括:存储器1310以及至少一个处理器1320;所述存储器1310用于存储所述处理器1320的可执行指令;其中,所述处理器1320被配置为经由执行所述可执行指令来执行本发明任意实施例提供的应用于交易量预测的模型训练方法,和/或,基于模型的交易量预测方法。
可选的,存储器1310既可以是独立的,也可以跟处理器1320集成在一起。
可选的,当所述存储器1310是独立于处理器1320之外的器件时,所述电子设备还可以包括:
总线1330,用于将上述器件,即存储器1310和处理器1320连接起来。
相关说明可以对应参见图3、图5以及图6至图8的步骤所对应的相关描述和效果进行理解,此处不做过多赘述。
本发明实施例还提供一种可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现本发明任意实施例提供的交易量预测的模型训练方法或基于模型的交易量预测方法。
本发明实施例还提供一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时用于实现本发明任意实施例提供的交易量预测的模型训练方法或基于模型的交易量预测方法。
本领域普通技术人员可以理解:实现上述各方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成。前述的程序可以存储于一计算机可读取存储介质中。该程序在执行时,执行包括上述各方法实施例的步骤;而前述的存储介质包括:ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
最后应说明的是:以上各实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述各实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或对其中部分或全部技术特征进行等同替换;而这些修改或替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的范围。
Claims (18)
1.一种应用于交易量预测的模型训练方法,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于第一应用端,所述方法包括:
依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据;
将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的所述时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据所述时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量;
针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数;
根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
2.根据权利要求1所述的方法,其特征在于,所述底层网络包括输入层和由预设层数的长短期记忆网络构成的隐藏层,依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据,包括:
针对每个时间窗口,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层;
基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据。
3.根据权利要求2所述的方法,其特征在于,所述输入层为编码器,将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络的输入层,包括:
基于所述第一应用端的底层网络编码器,对各个时间窗口对应的交易量样本数据进行编码,得到特征编码数据;
相应的,基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的交易量样本数据进行特征提取,得到所述时间窗口对应的状态数据,包括:
基于所述第一应用端的底层网络的隐藏层,对所述时间窗口对应的特征编码数据进行特征提取,得到所述时间窗口对应的状态数据。
4.根据权利要求1所述的方法,其特征在于,所述上层网络包括数据转换层和输出层,将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,包括:
依次将服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络的数据转换层,以输出各个时间窗口对应的状态转换数据;
基于所述上层网络的输出层,对各个时间窗口对应的状态转换数据进行特征提取,以得到各个时间窗口对应的预测交易量。
5.根据权利要求1-4任一所述的方法,其特征在于,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,包括:
根据所述时间窗口对应的预设交易量以及实际交易量,计算所述上层网络的损失函数所述时间窗口对应的第一梯度;
通过在所述上层网络中反向传导所述时间窗口对应的第一梯度,更新所述上层网络的模型参数。
6.根据权利要求1-4任一项所述的方法,其特征在于,在得到各个时间窗口对应的状态数据之后,所述方法还包括:
对各个时间窗口对应的状态数据进行加密,并将加密后的各个时间窗口的状态数据发送至所述服务端,以使所述服务度对接收到的来自各个应用端的各个时间窗口的状态数据进行解密和拼接,以得到各个时间窗口对应的拼接状态数据。
7.根据权利要求1-4任一项所述的方法,其特征在于,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,包括:
根据所述时间窗口对应的预设交易量和实际交易量,计算所述时间窗口对应的损失函数的值;
判断所述时间窗口以及所述时间窗口之前预设数量的各个时间窗口对应的损失函数的值是否满足预设条件;
若否,则根据所述时间窗口对应的预测交易量以及实际交易,生成所述时间窗口对应的第一梯度。
8.一种应用于交易量预测的模型训练方法,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于第二应用端,所述方法包括:
依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据;
将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据,其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到;
针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数,其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
9.一种应用于交易量预测的模型训练方法,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述方法应用于服务端,所述方法包括:
针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据,其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到;
对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据;
将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络,以通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量以及实际交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度;
接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
10.一种基于模型的交易量预测方法,其特征在于,所述方法应用于第一应用端,所述方法包括:
获取历史交易量;
基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量,其中,所述交易量预测模型为基于权利要求1至7中任一项所述的方法生成的。
11.根据权利要求10所述的方法,其特征在于,基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,包括:
将所述历史交易量输入所述第一应用端的底层网络,以得到所述第一应用端的底层网络输出的状态数据;
将服务端发送的拼接状态数据,输入所述第一应用端的上层网络,以生成所述历史交易量对应的预测交易量,其中,所述拼接状态数据为通过所述服务端将各个应用端的状态数据拼接后生成,所述各个应用端包括所述第一应用端和至少一个第二应用端,所述第二应用端的状态数据为由所述第二应用端的底层网络根据所述历史交易量对应的协变量数据生成。
12.一种应用于交易量预测的模型训练装置,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第一应用端,所述装置包括:
第一输入模块,用于依次将每个时间窗口对应的交易量样本数据输入所述第一应用端的底层网络,以得到各个时间窗口对应的状态数据;
第二输入模块,用于将所述服务端发送的各个时间窗口对应的拼接状态数据,输入所述第一应用端的上层网络,以得到各个时间窗口对应的预测交易量,其中,每一时间窗口对应的所述拼接状态数据由所述服务端对各个应用端的底层网络输出的每一时间窗口对应的状态数据拼接而成,所述第二应用端的每一时间窗口对应的状态数据由所述第二应用端的底层网络根据每一时间窗口对应的协变量样本数据生成,所述协变量为与所述交易量相关的特征变量;
第一梯度生成模块,用于针对每个时间窗口,根据所述时间窗口对应的预测交易量以及实际交易量,生成所述时间窗口对应的第一梯度,并根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数;
第一参数更新模块,用于根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度,并将所述时间窗口对应的第二梯度发送至所述服务端,以经由所述服务端将所述时间窗口对应的第二梯度拆分为各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的各个第三梯度更新各个应用端的底层网络的模型参数,得到更新后的交易量预测模型。
13.一种应用于交易量预测的模型训练装置,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于第二应用端,所述装置包括:
第三输入模块,用于依次将每个时间窗口对应的协变量样本数据输入所述第二应用端的底层网络,以得到各个时间窗口对应的状态数据;
第一数据发送模块,用于将各个时间窗口对应的状态数据发送至服务端,以使所述服务端对各个时间窗口接收到的所述第二应用端以及第一应用端对应的状态数据进行拼接,以得到各个时间窗口对应的拼接状态数据,其中,所述第一应用端每一时间窗口对应的状态数据通过将每一时间窗口对应的交易量样本数据输入第一应用端的底层网络得到;
第二参数更新模块,用于针对每个时间窗口,根据所述时间窗口对应的所述第二应用端的第三梯度更新所述第二应用端的底层网络的模型参数,其中,所述时间窗口对应的第三梯度为服务端将所述时间窗口对应的第二梯度拆分后所得,所述时间窗口对应的所述第二梯度为根据所述时间窗口对应的第一梯度以及基于所述时间窗口对应的第一梯度更新模型参数后的第一应用端的上层网络生成的,所述时间窗口对应的第一梯度为第一应用端根据所述时间窗口对应的第一应用端的上层网络输出的预测交易量以及实际交易量生成的,所述时间窗口对应的预测交易量通过将服务端发送的所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络得到。
14.一种应用于交易量预测的模型训练装置,其特征在于,应用于参与联邦学习的多个参与方,所述多个参与方中包括第一应用端、服务端和至少一个第二应用端,交易量预测模型包括设置在各个应用端的底层网络,以及设置在所述第一应用端的上层网络,所述装置应用于服务端,所述装置包括:
数据接收模块,用于针对每个时间窗口,接收各个应用端发送的所述时间窗口的状态数据,其中,所述第一应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的交易量样本数据输入所述第一应用端的底层网络得到,所述第二应用端所述时间窗口对应的状态数据通过将所述时间窗口对应的协变量样本数据输入所述第二应用端的底层网络得到;
数据拼接模块,用于对各个应用端发送的所述时间窗口的状态数据进行拼接,以得到所述时间窗口对应的拼接状态数据;
第二数据发送模块,用于将所述时间窗口对应的所述拼接状态数据发送至第一应用端的上层网络,以通过将所述时间窗口对应的拼接状态数据输入所述第一应用端的上层网络,得到所述时间窗口对应的预测交易量,以基于所述时间窗口对应的预测交易量以及实际交易量生成所述时间窗口对应的第一梯度,根据所述时间窗口对应的第一梯度更新所述第一应用端的上层网络的模型参数,以及根据模型参数更新后的上层网络以及所述时间窗口对应的第一梯度,生成所述时间窗口对应的第二梯度;
梯度拆分模块,用于接收所述时间窗口对应的第二梯度,并对所述时间窗口对应的第二梯度进行拆分,以生成各个应用端所述时间窗口对应的第三梯度,以基于所述时间窗口对应的第三梯度更新各个应用端的底层网络的模型参数,以得到更新后的交易量预测模型。
15.一种基于模型的交易量预测装置,其特征在于,所述装置应用于第一应用端,所述装置包括:
历史交易量获取模块,用于获取历史交易量;
交易量预测模块,用于基于预先训练的交易量预测模型,对所述历史交易量以及所述历史交易量对应的协变量数据进行处理,以生成所述历史交易量对应的预测交易量,其中,所述交易量预测模型为基于权利要求1至7中任一项所述的方法生成的。
16.一种电子设备,其特征在于,包括:存储器和处理器;
所述存储器用于存储所述处理器可执行指令;
其中,所述处理器配置为经由执行所述可执行指令来执行权利要求1-9任一项所述的方法,和/或,权利要求10-11任一项所述的方法。
17.一种可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-11任一项所述的方法。
18.一种计算机程序产品,其特征在于,包括计算机程序,所述计算机程序被处理器执行时用于实现权利要求1-11任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210310844.0A CN114638343B (zh) | 2022-03-28 | 2022-03-28 | 模型训练方法、预测方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210310844.0A CN114638343B (zh) | 2022-03-28 | 2022-03-28 | 模型训练方法、预测方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114638343A CN114638343A (zh) | 2022-06-17 |
CN114638343B true CN114638343B (zh) | 2025-02-21 |
Family
ID=81950576
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210310844.0A Active CN114638343B (zh) | 2022-03-28 | 2022-03-28 | 模型训练方法、预测方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114638343B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2025000392A1 (zh) * | 2023-06-29 | 2025-01-02 | 华为技术有限公司 | 模型训练方法和通信装置 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113487423A (zh) * | 2021-07-29 | 2021-10-08 | 中国银行股份有限公司 | 个人信贷风险预测模型训练方法及装置 |
CN113807921A (zh) * | 2021-09-17 | 2021-12-17 | 深圳市数聚湾区大数据研究院 | 数据商品推荐方法与装置、电子设备及计算机可读存储介质 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190130476A1 (en) * | 2017-04-25 | 2019-05-02 | Yada Zhu | Management System and Predictive Modeling Method for Optimal Decision of Cargo Bidding Price |
CN113724059A (zh) * | 2020-12-29 | 2021-11-30 | 京东城市(北京)数字科技有限公司 | 联邦学习模型的训练方法、装置和电子设备 |
CN113570192A (zh) * | 2021-06-21 | 2021-10-29 | 天津大学 | 一种基于大数据的农业社交智能服务系统 |
-
2022
- 2022-03-28 CN CN202210310844.0A patent/CN114638343B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113487423A (zh) * | 2021-07-29 | 2021-10-08 | 中国银行股份有限公司 | 个人信贷风险预测模型训练方法及装置 |
CN113807921A (zh) * | 2021-09-17 | 2021-12-17 | 深圳市数聚湾区大数据研究院 | 数据商品推荐方法与装置、电子设备及计算机可读存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114638343A (zh) | 2022-06-17 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Qin et al. | Deep learning in physical layer communications | |
CN108874914B (zh) | 一种基于图卷积与神经协同过滤的信息推荐方法 | |
CN110597991B (zh) | 文本分类方法、装置、计算机设备及存储介质 | |
CN110929869A (zh) | 注意力模型的训练方法、装置、设备及存储介质 | |
CN112131888B (zh) | 分析语义情感的方法、装置、设备及存储介质 | |
CN110309275B (zh) | 一种对话生成的方法和装置 | |
CN111881350A (zh) | 一种基于混合图结构化建模的推荐方法与系统 | |
CN110490128A (zh) | 一种基于加密神经网络的手写识别方法 | |
CN113420212B (zh) | 基于深度特征学习的推荐方法、装置、设备及存储介质 | |
CN113542228A (zh) | 基于联邦学习的数据传输方法、装置以及可读存储介质 | |
CN113487423A (zh) | 个人信贷风险预测模型训练方法及装置 | |
CN116644778A (zh) | 量子同态神经网络的构建方法及加密图像分类方法 | |
CN114186256A (zh) | 神经网络模型的训练方法、装置、设备和存储介质 | |
CN113240071A (zh) | 图神经网络处理方法、装置、计算机设备及存储介质 | |
CN114638343B (zh) | 模型训练方法、预测方法、装置、设备及存储介质 | |
CN116415034A (zh) | 一种基于特征重构的文本-视频时序定位方法 | |
CN116186769A (zh) | 基于隐私计算的纵向联邦XGBoost特征衍生方法及相关设备 | |
Tariq et al. | Integrating sustainable big AI: Quantum anonymous semantic broadcast | |
CN115510186A (zh) | 基于意图识别的即时问答方法、装置、设备及存储介质 | |
CN118333105B (zh) | 数据处理方法、装置、设备及可读存储介质 | |
CN108234195B (zh) | 预测网络性能的方法和装置、设备、介质 | |
CN114330514A (zh) | 一种基于深度特征与梯度信息的数据重建方法及系统 | |
Zhang et al. | Obstacle‐transformer: A trajectory prediction network based on surrounding trajectories | |
CN113542271A (zh) | 基于生成对抗网络gan的网络背景流量生成方法 | |
CN111935259A (zh) | 目标帐号集合的确定方法和装置、存储介质及电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |