CN116432746A - 一种基于提示学习的联邦建模方法、装置、设备、介质 - Google Patents
一种基于提示学习的联邦建模方法、装置、设备、介质 Download PDFInfo
- Publication number
- CN116432746A CN116432746A CN202310500646.5A CN202310500646A CN116432746A CN 116432746 A CN116432746 A CN 116432746A CN 202310500646 A CN202310500646 A CN 202310500646A CN 116432746 A CN116432746 A CN 116432746A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- prompt
- local
- gradient
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 59
- 238000012549 training Methods 0.000 claims abstract description 225
- 238000012545 processing Methods 0.000 claims abstract description 28
- 238000000605 extraction Methods 0.000 claims abstract description 24
- 238000004220 aggregation Methods 0.000 claims description 23
- 230000002776 aggregation Effects 0.000 claims description 23
- 239000013598 vector Substances 0.000 claims description 18
- 238000004590 computer program Methods 0.000 claims description 16
- 238000013140 knowledge distillation Methods 0.000 claims description 13
- 238000010276 construction Methods 0.000 claims description 12
- 238000004364 calculation method Methods 0.000 claims description 8
- 238000011478 gradient descent method Methods 0.000 claims description 7
- 238000003860 storage Methods 0.000 claims description 7
- 239000000284 extract Substances 0.000 claims description 3
- 230000001902 propagating effect Effects 0.000 claims description 2
- 230000008569 process Effects 0.000 abstract description 21
- 238000004891 communication Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 230000005540 biological transmission Effects 0.000 description 3
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000009826 distribution Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000006870 function Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000002360 preparation method Methods 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 241000283074 Equus asinus Species 0.000 description 1
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000007670 refining Methods 0.000 description 1
- 238000009877 rendering Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q40/00—Finance; Insurance; Tax strategies; Processing of corporate or income taxes
- G06Q40/02—Banking, e.g. interest calculation or account maintenance
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Business, Economics & Management (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Finance (AREA)
- Accounting & Taxation (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Development Economics (AREA)
- Economics (AREA)
- Marketing (AREA)
- Strategic Management (AREA)
- Technology Law (AREA)
- General Business, Economics & Management (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请公开了一种基于提示学习的联邦建模方法、装置、设备、介质,涉及联邦学习技术领域,包括:利用提示信息生成策略的数据任务提示模型控制参与方的本地预训练模型的学习方向;获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,计算第一预测结果与本地标签的第一梯度;利用本地训练数据训练本地预训练模型,以便输出经过数据任务提示模型进行特征提取和特征处理的第三预测结果,计算第三预测结果与所述第二预测结果的第二梯度;反向传播以更新全局训练共享模型的全局模型参数,更新本地预训练模型、数据任务提示模型,直至完成联邦建模。引导本地预训练模型的学习方向,使本地预训练模型的训练过程更符合联邦建模任务要求。
Description
技术领域
本发明涉及联邦学习技术领域,特别涉及一种基于提示学习的联邦建模方法、装置、设备、介质。
背景技术
目前,在深度学习与大数据技术发展的过程中,效果好的模型往往有着较大的规模和复杂的结构,往往计算效率与资源使用方面开销很大,因此训练出来一个高效的模型需要花费大量的资源,如何有效利用模型并高效扩展成为研究的热点。随着业务场景的日益复杂,传统的单方独立建模方式已经无法满足复杂场景的业务要求。因此,联邦建模成为一种重要的建模手段。而联邦建模的过程中,往往需要大量的训练数据用来训练支撑模型,而在训练模型的过程中又会涉及到不同参与方的隐私数据,并且所有参与方均提供训练数据的情况下,又会造成大量训练数据传输,且目前是根据已有模型进行大模型再训练,不存在大模型再训练的训练方向提示和引导。例如,多家银行通过联合风控提升各自的风控能力。在此背景下,各个联合建模参与方都希望充分利用各方的现有模型资源,如:银行各自拥有的金融风控系统,并尽量避免灾难性遗忘现象,以提高联邦建模效率,同时对数据隐私保护提出更高要求。
综上,如何有效利用现有的业务模型实现更高效、更安全的联邦建模,以提升建模效率和模型准确度是本领域有待解决的技术问题。
发明内容
有鉴于此,本发明的目的在于提供一种基于提示学习的联邦建模方法、装置、设备、介质,能够有效利用现有的业务模型实现更高效、更安全的联邦建模,以提升建模效率和模型准确度。其具体方案如下:
第一方面,本申请公开了一种基于提示学习的联邦建模方法,包括:
利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;
获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;
利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;
分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
可选的,所述利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向,包括:
基于联邦建模任务和先验知识构建提示信息生成策略的数据任务提示模型;
利用所述数据任务提示模型确定参与方各自的本地预训练模型的提示信息向量。
可选的,所述利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,包括:
将所述本地训练数据输入至所述本地预训练模型,以便所述本地预训练模型的特征提取模块提取所述本地训练数据的特征向量,并将所述特征向量和所述提示信息向量输入至所述本地预训练模型的任务输出单元处理,以获取第三预测结果。
可选的,所述计算所述第三预测结果与所述第二预测结果的第二梯度,包括:
设置知识蒸馏温度参数,基于所述知识蒸馏温度参数计算所述第三预测结果与所述第二预测结果的第二梯度。
可选的,所述分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,包括:
通过梯度下降法分别将所述第一梯度和所述第二梯度反向传播,以更新所述全局训练共享模型的全局模型参数;
通过各个参与方将所述全局训练共享模型的更新梯度进行压缩组装,并将所有组装后的更新梯度发送至联邦聚合节点,以便所述联邦聚合节点将所述组装后的更新梯度重新分解,并进行梯度聚合,以获取聚合后的梯度信息;
利用所述梯度信息组装本地模型参数,并将所述本地模型参数发送至对应的所述参与方。
可选的,所述将所述本地模型参数发送至对应的所述参与方之后,还包括:
根据所述本地模型参数对各个所述参与方的全局训练共享模型更新。
可选的,所述将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模,包括:
利用所述本地模型参数调整所述数据任务提示模型,并生成提示模型参数,然后利用所述提示模型参数更新所述数据任务提示模型,直至完成联邦建模。
第二方面,本申请公开了一种基于提示学习的联邦建模装置,包括:
方向确定模块,用于利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;
第一梯度计算模块,用于获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;
第二梯度计算模块,用于利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;
模型训练模块,用于分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
第三方面,本申请公开了一种电子设备,包括:
存储器,用于保存计算机程序;
处理器,用于执行所述计算机程序,以实现前述公开的基于提示学习的联邦建模方法的步骤。
第四方面,本申请公开了一种计算机可读存储介质,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现前述公开的基于提示学习的联邦建模方法的步骤。
由此可见,本申请公开了一种基于提示学习的联邦建模方法,包括:利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。可见,通过根据联邦建模任务构建提示信息生成策略的数据任务提示模型,然后利用该数据任务提示模型对本地预训练模型的训练过程的模型参数进行调整,且引导本地预训练模型的学习方向,使本地预训练模型的训练过程更符合联邦建模任务要求,然后训练全局训练共享模型和本地预训练模型,并且在全局训练共享模型和本地预训练模型之间不断训练调整过程中,不断更新模型参数,以实现最终联邦建模的预测模型。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本申请公开的一种基于提示学习的联邦建模方法流程图;
图2为本申请公开的一种具体的基于提示学习的联邦建模方法流程图;
图3为本申请公开的一种基于提示学习和知识蒸馏的联邦建模结构流程图;
图4为本申请公开的一种基于提示学习的联邦建模装置结构示意图;
图5为本申请公开的一种电子设备结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
目前,在深度学习与大数据技术发展的过程中,效果好的模型往往有着较大的规模和复杂的结构,往往计算效率与资源使用方面开销很大,因此训练出来一个高效的模型需要花费大量的资源,如何有效利用模型并高效扩展成为研究的热点。随着业务场景的日益复杂,传统的单方独立建模方式已经无法满足复杂场景的业务要求。因此,联邦建模成为一种重要的建模手段。而联邦建模的过程中,往往需要大量的训练数据用来训练支撑模型,而在训练模型的过程中又会涉及到不同参与方的隐私数据,并且所有参与方均提供训练数据的情况下,又会造成大量训练数据传输,且目前是根据已有模型进行大模型再训练,不存在大模型再训练的训练方向提示和引导。例如,多家银行通过联合风控提升各自的风控能力。在此背景下,各个联合建模参与方都希望充分利用各方的现有模型资源,如:银行各自拥有的金融风控系统,并尽量避免灾难性遗忘现象,以提高联邦建模效率,同时对数据隐私保护提出更高要求。
为此,本申请提供了一种基于提示学习的联邦建模方案,能够有效利用现有的业务模型实现更高效、更安全的联邦建模,以提升建模效率和模型准确度。
参照图1所示,本发明实施例公开了一种基于提示学习的联邦建模方法,包括:
步骤S11:利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向。
本实施例中,基于联邦建模任务和先验知识构建提示信息生成策略的数据任务提示模型;利用所述数据任务提示模型确定参与方各自的本地预训练模型的提示信息向量。可以理解的是,根据联邦建模任务需求,构建提示信息生成策略,并将先验知识融入提示信息生成策略中,以指导各参与方在本地预训练模型的学习方向。
步骤S12:获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度。
本实施例中,利用构建的全局训练共享模型对联合建模数据项进行预测,以得到第一预测结果GM-Hard-output和第二预测结果GM-Soft-output,可以理解的是,在预测之前,构建全局训练共享模型GM,其中,所述全局训练共享模型GM是各个联邦建模参与方共同训练形成的模型,用于完成各方的联邦建模任务,其核心是神经网络模型,具体由特征提取单元、特征处理单元以及任务输出单元等模块构成。构建完成后,通过该全局训练共享模型进行预测能够获取两种方式下的预测结果,第一预测结果GM-Hard-output为直接的具体的预测类别,而第二预测结果GM-Soft-output为携带不同概率的不同预测类别,例如:当输入一张图片,第一预测结果为“马”,第二预测结果为“0.8是马、0.12是驴、0.08是狗”。当获取到第一预测结果后,计算第一预测结果和预先获取的输入数据对应的本地标签的第一梯度。
步骤S13:利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度。
本实施例中,将所述本地训练数据输入至所述本地预训练模型,以便所述本地预训练模型的特征提取模块提取所述本地训练数据的特征向量,并将所述特征向量和所述提示信息向量输入至所述本地预训练模型的任务输出单元处理,以获取第三预测结果。可以理解的是,利用参与方的本地训练数据训练本地预训练模型,这样一来能够避免隐私数据的泄露,同时,在利用本地训练数据对本地预训练模型的训练之前,构建本地预训练模型,具体的,本地预训练模型LM是联邦建模参与方自身拥有的预测模型,实现其业务场景需求,其核心是深度学习神经网络模型,由特征提取单元、特征处理单元以及任务输出单元等模块构成,用于完成业务功能任务的预测。当将本地训练数据输入至本地预训练模型后,提示模型Prompt-Model根据联邦建模任务需求设置的提示信息构建策略,由数据提示生成模块Data-Prompt-Gen和任务提示生成模块Task-Prompt-Gen构成,提示信息的构建基于领域知识等手段,主要用于设定调整所述本地预训练模型LM的参数;所述的数据提示生成模块Data-Prompt-Gen主要负责调整所述的联邦建模本地预训练模型LM的输入参数特征提取单元,使其满足联合建模输入数据的要求,并引导特征提取单元更符合联邦建模任务要求;所述的任务提示生成模块Task-Prompt-Gen主要负责调整所述的联邦建模本地预训练模型LM的任务输出单元,使其生成的预测结果满足联合建模输出的要求,并引导任务输出单元输出更符合联邦建模任务要求的预测结果。使本地预训练模型输出第三预测结果LM-Prompt-Soft-output,然后计算第三预测结果LM-Prompt-Soft-output与第二预测结果GM-Soft-output之间的第二梯度。通过提示学习在不显著改变预训练语言模型结构和参数的情况下,通过向输入增加提示信息,提升下游任务预测效果的机器学习。
步骤S14:分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
本实施例中,将上述获取的第一梯度和第二梯度反向传播,并利用反向传播的误差更新全局训练共享模型GM参数;需要注意的是,全局训练共享模型GM与本地预训练模型之间构建一个Teacher-Student知识蒸馏模式,知识蒸馏作为模型压缩和训练加速的重要手段,实现将知识从大模型向小模型传输的高效传输,使在联邦建模参与方本地训练所述全局训练共享模型GM发挥联邦建模全局共享模型参数知识语义级别的压缩和提炼作用。可以减少模型参数传输量,降低通信成本,有效提高了联邦建模的训练效率和最终模型的准确性。然后,根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
由此可见,本申请公开了一种基于提示学习的联邦建模方法,包括:利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。可见,通过根据联邦建模任务构建提示信息生成策略的数据任务提示模型,然后利用该数据任务提示模型对本地预训练模型的训练过程的模型参数进行调整,且引导本地预训练模型的学习方向,使本地预训练模型的训练过程更符合联邦建模任务要求,然后训练全局训练共享模型和本地预训练模型,并且在全局训练共享模型和本地预训练模型之间不断训练调整过程中,不断更新模型参数,以实现最终联邦建模的预测模型。
参照图2所示,本发明实施例公开了一种具体的基于提示学习的联邦建模方法,相对于上一实施例,本实施例对技术方案作了进一步的说明和优化。
具体的:
步骤S21:利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向。
步骤S22:获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度。
步骤S23:利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果。
其中,步骤S21、步骤S22、步骤S23中更加详细的处理过程请参照前述公开的实施例内容,在此不再进行赘述。
步骤S24:设置知识蒸馏温度参数,基于所述知识蒸馏温度参数计算所述第三预测结果与所述第二预测结果的第二梯度。
本实施例中,设置知识蒸馏温度参数,也即实现了在本地预训练模型与全局训练共享模型之间的Teacher-Student模式交互策略的全局训练共享模型参数的知识语义级别的压缩和提炼,从而实现联邦建模全局共享模型的知识蒸馏和本地训练,进而获取计算的第三预测结果与第二预测结果的第二梯度。
步骤S25:通过梯度下降法分别将所述第一梯度和所述第二梯度反向传播,以更新所述全局训练共享模型的全局模型参数。
本实施例中,计算第一预测结果GM-Hard-Output与本地标签Real-Lable的第一梯度,采用梯度下降法,将误差反向传播更新所述的全局训练共享模型GM参数;当适当调整知识蒸馏温度参数后,计算所述的第二预测结果GM-Soft-Output与所述第三预测结果LM-Prompt-Soft-output的梯度,采用梯度下降法,将误差反向传播更新所述全局训练共享模型GM参数。
本实施例中,通过各个参与方将所述全局训练共享模型的更新梯度进行压缩组装,并将所有组装后的更新梯度发送至联邦聚合节点,以便所述联邦聚合节点将所述组装后的更新梯度重新分解,并进行梯度聚合,以获取聚合后的梯度信息;利用所述梯度信息组装本地模型参数,并将所述本地模型参数发送至对应的所述参与方。可以理解的是,各个联邦建模参与方将所述的全局训练共享模型GM的更新梯度进行压缩组装,汇聚到所述的联邦聚合节点FL-Server;所述联邦聚合节点FL-Server通过参数聚合模块Para-Aggregate将来各个联邦建模参与方的所述的全局训练共享模型GM的更新梯度进行重新分解,并实现梯度聚合;所述联邦聚合节点FL-Server通过参数更新模块Para-Upd根据各联邦建模参与方上传的梯度信息,进行参数选择组装,将更新后的参数反馈给各联邦建模方。其中,所述参数聚合模块Para-Aggregate是负责聚合联邦建模参与方的梯度参数;参数更新模块Para-Upd负责根据将更新后的参数反馈给各联邦建模方的所述的全局训练共享模型GM进行更新;提示模型分发模块Prompt-Dist将根据模型参数更新情况来调整所述的提示模型,并下发给各个联邦建模参与方。
本实施例中,所述将所述本地模型参数发送至对应的所述参与方之后,还包括:根据所述本地模型参数对各个所述参与方的全局训练共享模型更新。可以理解的是,各个联邦建模参与方接收到更新后的参数,对所述全局训练共享模型GM进行更新。
步骤S26:基于所述全局模型参数获取本地模型参数,并利用所述本地模型参数调整所述数据任务提示模型,并生成提示模型参数,然后利用所述提示模型参数更新所述数据任务提示模型,直至完成联邦建模。
本实施例中,所述提示模型分发模块Prompt-Dist根据模型参数更新情况来调整所述的提示模型,并下发给各个联邦建模参与方;各个联邦建模参与方接收到更新后的提示模型参数,更新本地的提示模型Prompt-Model;持续执行整个模型训练过程,直至满足联邦建模任务要求,形成最终联邦建模预测模型。
本实施例中,参照图2所示,在联邦建模训练前,首先进行联邦建模任务的准备过程和联邦建模过程,具体如下:
步骤101、根据联邦建模预测任务的需求,选定影响预测任务的数据特征,设定模型结构和损失函数。
步骤102、联邦聚合节点FL-Server将全局训练共享模型GM分发到各个联邦建模参与方。
步骤103、根据联邦建模预测任务的,设置提示模型Prompt-Model的初始参数。
步骤104、联邦聚合节点FL-Server将提示模型Prompt-Model分发到各个联邦建模参与方。
步骤105、各个参与方准备用于联邦建模的训练数据、联邦建模本地预训练模型LM以及本地标签数据。
当联邦建模的准备工作完成后,开始联邦建模任务训练,具体如下:
步骤201、各个联邦建模参与方将所述联合建模数据项FL-Input通过各自所述全局训练共享模型GM,生成所述第一预测结果GM-Hard-Output和第二预测结果GM-Soft-Output。
步骤202、各个联邦建模参与方将所述本地训练数据Input输入到所述的本地预训练模型LM中,所述提示模型Prompt-Model的数据提示生成模块Data-Prompt-Gen生成提示信息向量并入提示模型Prompt-Model的特征提取单元,实现针对联邦建模预测任务的特征提取。
步骤203、本地预训练模型LM将特征提取单元提取的特征进行特征处理,形成特征向量。
步骤204、将特征向量与提示模型Prompt-Model的任务提示生成模块Task-Prompt-Gen生成提示信息向量,输入到本地预训练模型LM的任务输出单元,产生第三预测结果LM-Prompt-Soft-output。
步骤205、计算第一预测结果GM-Hard-Output与所述的本地标签Real-Lable的梯度,采用梯度下降法,将误差反向传播更新所述的全局训练共享模型GM参数。
步骤206、适当调整知识蒸馏温度参数,计算第二预测结果GM-Soft-Output与第三预测结果LM-Prompt-Soft-output的梯度,采用梯度下降法,将误差反向传播更新所述的全局训练共享模型GM参数。
步骤207、各个联邦建模参与方将全局训练共享模型GM的更新梯度进行压缩组装,汇聚到所述的联邦聚合节点FL-Server。
步骤208、联邦聚合节点FL-Server通过所述的参数聚合模块Para-Aggregate将来各个联邦建模参与方的所述的全局训练共享模型GM的更新梯度进行重新分解,并实现梯度聚合。
步骤209、联邦聚合节点FL-Server通过所述的参数更新模块Para-Upd根据各联邦建模参与方上传的梯度信息,进行参数选择组装,将更新后的参数反馈给各联邦建模方。
步骤210、各个联邦建模参与方接收到更新后的参数,对全局训练共享模型GM进行更新。
步骤211、提示模型分发模块Prompt-Dist根据模型参数更新情况来调整所述的提示模型,并下发给各个联邦建模参与方。
步骤212、各个联邦建模参与方接收到更新后的提示模型参数,更新本地的提示模型Prompt-Model。
步骤213、持续执行步骤201至步骤212,直到满足联邦建模任务要求,形成最终联邦建模预测模型。
本实施例中,当联邦建模预测模型完成后,利用所述联邦建模预测模型进行联邦预测任务的推理,具体的,将所述的联邦建模形成的预测模型在数据应用端进行部署;根据联邦建模任务的需求进行数据输入,所述联邦建模预测模型将输出预测结果;持续收集预测结果反馈数据,不断优化模型,同时发掘预测任务内在联系,不断优化设计更新提示模型。
由此可见,通过采用梯度安全聚合方式的同时,针对不同参与方本地个性化模型进行有针对性的梯度参数分解压缩,进一步保障了各方的隐私数据和安全性。最后,通过持续收集反馈信息,不断提炼建模任务特征和优化提示模型,以形成最佳联合建模模型,用于联合推理任务。
参照图4所示,本发明实施例还相应公开了一种基于提示学习的联邦建模装置,包括:
方向确定模块11,用于利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;
第一梯度计算模块12,用于获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;
第二梯度计算模块13,用于利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;
模型训练模块14,用于分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
由此可见,本申请公开了利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。可见,通过根据联邦建模任务构建提示信息生成策略的数据任务提示模型,然后利用该数据任务提示模型对本地预训练模型的训练过程的模型参数进行调整,且引导本地预训练模型的学习方向,使本地预训练模型的训练过程更符合联邦建模任务要求,然后训练全局训练共享模型和本地预训练模型,并且在全局训练共享模型和本地预训练模型之间不断训练调整过程中,不断更新模型参数,以实现最终联邦建模的预测模型。
进一步的,本申请实施例还公开了一种电子设备,图5是根据一示例性实施例示出的电子设备20结构图,图中的内容不能认为是对本申请的使用范围的任何限制。
图5为本申请实施例提供的一种电子设备20的结构示意图。该电子设备20,具体可以包括:至少一个处理器21、至少一个存储器22、电源23、通信接口24、输入输出接口25和通信总线26。其中,所述存储器22用于存储计算机程序,所述计算机程序由所述处理器21加载并执行,以实现前述任一实施例公开的基于提示学习的联邦建模方法中的相关步骤。另外,本实施例中的电子设备20具体可以为电子计算机。
本实施例中,电源23用于为电子设备20上的各硬件设备提供工作电压;通信接口24能够为电子设备20创建与外界设备之间的数据传输通道,其所遵循的通信协议是能够适用于本申请技术方案的任意通信协议,在此不对其进行具体限定;输入输出接口25,用于获取外界输入数据或向外界输出数据,其具体的接口类型可以根据具体应用需要进行选取,在此不进行具体限定。
其中,处理器21可以包括一个或多个处理核心,比如4核心处理器、8核心处理器等。处理器21可以采用DSP(Digital Signal Processing,数字信号处理)、FPGA(Field-Programmable Gate Array,现场可编程门阵列)、PLA(Programmable Logic Array,可编程逻辑阵列)中的至少一种硬件形式来实现。处理器21也可以包括主处理器和协处理器,主处理器是用于对在唤醒状态下的数据进行处理的处理器,也称CPU(Central ProcessingUnit,中央处理器);协处理器是用于对在待机状态下的数据进行处理的低功耗处理器。在一些实施例中,处理器21可以在集成有GPU(Graphics Processing Unit,图像处理器),GPU用于负责显示屏所需要显示的内容的渲染和绘制。一些实施例中,处理器21还可以包括AI(Artificial Intelligence,人工智能)处理器,该AI处理器用于处理有关机器学习的计算操作。
另外,存储器22作为资源存储的载体,可以是只读存储器、随机存储器、磁盘或者光盘等,其上所存储的资源可以包括操作系统221、计算机程序222等,存储方式可以是短暂存储或者永久存储。
其中,操作系统221用于管理与控制电子设备20上的各硬件设备以及计算机程序222,以实现处理器21对存储器22中海量数据223的运算与处理,其可以是Windows Server、Netware、Unix、Linux等。计算机程序222除了包括能够用于完成前述任一实施例公开的由电子设备20执行的基于提示学习的联邦建模方法的计算机程序之外,还可以进一步包括能够用于完成其他特定工作的计算机程序。数据223除了可以包括电子设备接收到的由外部设备传输进来的数据,也可以包括由自身输入输出接口25采集到的数据等。
进一步的,本申请还公开了一种计算机可读存储介质,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现前述公开的基于提示学习的联邦建模方法。关于该方法的具体步骤可以参考前述实施例中公开的相应内容,在此不再进行赘述。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其它实施例的不同之处,各个实施例之间相同或相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
专业人员还可以进一步意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。
最后,还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上对本发明所提供的一种基于提示学习的联邦建模方法、装置、设备、介质进行了详细介绍,本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的一般技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本发明的限制。
Claims (10)
1.一种基于提示学习的联邦建模方法,其特征在于,包括:
利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;
获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;
利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;
分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
2.根据权利要求1所述的基于提示学习的联邦建模方法,其特征在于,所述利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向,包括:
基于联邦建模任务和先验知识构建提示信息生成策略的数据任务提示模型;
利用所述数据任务提示模型确定参与方各自的本地预训练模型的提示信息向量。
3.根据权利要求2所述的基于提示学习的联邦建模方法,其特征在于,所述利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,包括:
将所述本地训练数据输入至所述本地预训练模型,以便所述本地预训练模型的特征提取模块提取所述本地训练数据的特征向量,并将所述特征向量和所述提示信息向量输入至所述本地预训练模型的任务输出单元处理,以获取第三预测结果。
4.根据权利要求1所述的基于提示学习的联邦建模方法,其特征在于,所述计算所述第三预测结果与所述第二预测结果的第二梯度,包括:
设置知识蒸馏温度参数,基于所述知识蒸馏温度参数计算所述第三预测结果与所述第二预测结果的第二梯度。
5.根据权利要求1所述的基于提示学习的联邦建模方法,其特征在于,所述分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,包括:
通过梯度下降法分别将所述第一梯度和所述第二梯度反向传播,以更新所述全局训练共享模型的全局模型参数;
通过各个参与方将所述全局训练共享模型的更新梯度进行压缩组装,并将所有组装后的更新梯度发送至联邦聚合节点,以便所述联邦聚合节点将所述组装后的更新梯度重新分解,并进行梯度聚合,以获取聚合后的梯度信息;
利用所述梯度信息组装本地模型参数,并将所述本地模型参数发送至对应的所述参与方。
6.根据权利要求5所述的基于提示学习的联邦建模方法,其特征在于,所述将所述本地模型参数发送至对应的所述参与方之后,还包括:
根据所述本地模型参数对各个所述参与方的全局训练共享模型更新。
7.根据权利要求1所述的基于提示学习的联邦建模方法,其特征在于,所述将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模,包括:
利用所述本地模型参数调整所述数据任务提示模型,并生成提示模型参数,然后利用所述提示模型参数更新所述数据任务提示模型,直至完成联邦建模。
8.一种基于提示学习的联邦建模装置,其特征在于,包括:
方向确定模块,用于利用基于联邦建模任务构建提示信息生成策略的数据任务提示模型控制参与方各自的本地预训练模型的学习方向;
第一梯度计算模块,用于获取全局训练共享模型对联合建模数据项的第一预测结果和第二预测结果,并计算所述第一预测结果与本地标签的第一梯度;
第二梯度计算模块,用于利用所述参与方的本地训练数据训练所述本地预训练模型,以便所述本地预训练模型输出经过所述数据任务提示模型进行特征提取和特征处理的第三预测结果,计算所述第三预测结果与所述第二预测结果的第二梯度;
模型训练模块,用于分别将所述第一梯度和所述第二梯度反向传播以更新所述全局训练共享模型的全局模型参数,并根据所述第二梯度组装本地模型参数,并将所述本地模型参数反馈至所述本地预训练模型,以更新所述数据任务提示模型,直至完成联邦建模。
9.一种电子设备,其特征在于,包括:
存储器,用于保存计算机程序;
处理器,用于执行所述计算机程序,以实现如权利要求1至7任一项所述的基于提示学习的联邦建模方法的步骤。
10.一种计算机可读存储介质,其特征在于,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的基于提示学习的联邦建模方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310500646.5A CN116432746A (zh) | 2023-04-28 | 2023-04-28 | 一种基于提示学习的联邦建模方法、装置、设备、介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310500646.5A CN116432746A (zh) | 2023-04-28 | 2023-04-28 | 一种基于提示学习的联邦建模方法、装置、设备、介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116432746A true CN116432746A (zh) | 2023-07-14 |
Family
ID=87094418
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310500646.5A Pending CN116432746A (zh) | 2023-04-28 | 2023-04-28 | 一种基于提示学习的联邦建模方法、装置、设备、介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116432746A (zh) |
-
2023
- 2023-04-28 CN CN202310500646.5A patent/CN116432746A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
KR102422729B1 (ko) | 학습 데이터 증강 정책 | |
EP3596664B1 (en) | Generating discrete latent representations of input data items | |
US20200293838A1 (en) | Scheduling computation graphs using neural networks | |
EP4383136A2 (en) | Population based training of neural networks | |
US20210117786A1 (en) | Neural networks for scalable continual learning in domains with sequentially learned tasks | |
CN114172820A (zh) | 跨域sfc动态部署方法、装置、计算机设备及存储介质 | |
US20230306258A1 (en) | Training video data generation neural networks using video frame embeddings | |
CN111292262A (zh) | 图像处理方法、装置、电子设备以及存储介质 | |
CN111651989B (zh) | 命名实体识别方法和装置、存储介质及电子装置 | |
CN113034523A (zh) | 图像处理方法、装置、存储介质及计算机设备 | |
US11514313B2 (en) | Sampling from a generator neural network using a discriminator neural network | |
Liang et al. | Generative AI-driven semantic communication networks: Architecture, technologies and applications | |
CN117475020A (zh) | 图像生成方法、装置、设备及介质 | |
CN112394982A (zh) | 生成语音识别系统的方法、装置、介质及电子设备 | |
CN116432746A (zh) | 一种基于提示学习的联邦建模方法、装置、设备、介质 | |
CN115001692A (zh) | 模型更新方法及装置、计算机可读存储介质和电子设备 | |
CN111709784B (zh) | 用于生成用户留存时间的方法、装置、设备和介质 | |
CN114818613A (zh) | 一种基于深度强化学习a3c算法的对话管理模型构建方法 | |
CN112395490B (zh) | 用于生成信息的方法和装置 | |
CN113361574A (zh) | 数据处理模型的训练方法、装置、电子设备及存储介质 | |
CN113223121A (zh) | 视频生成方法、装置、电子设备及存储介质 | |
CN117093259B (zh) | 一种模型配置方法及相关设备 | |
CN113762532B (zh) | 联邦学习模型的训练方法、装置、电子设备和存储介质 | |
CN115661238B (zh) | 可行驶区域生成方法、装置、电子设备和计算机可读介质 | |
CN117788275A (zh) | 元宇宙ugc摄影作品ai风格化实现方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |