CN116578674B - 联邦变分自编码主题模型训练方法、主题预测方法及装置 - Google Patents

联邦变分自编码主题模型训练方法、主题预测方法及装置 Download PDF

Info

Publication number
CN116578674B
CN116578674B CN202310826329.2A CN202310826329A CN116578674B CN 116578674 B CN116578674 B CN 116578674B CN 202310826329 A CN202310826329 A CN 202310826329A CN 116578674 B CN116578674 B CN 116578674B
Authority
CN
China
Prior art keywords
pruning
coding
model
training
self
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202310826329.2A
Other languages
English (en)
Other versions
CN116578674A (zh
Inventor
李雅文
马成洁
梁美玉
薛哲
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Original Assignee
Beijing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications filed Critical Beijing University of Posts and Telecommunications
Priority to CN202310826329.2A priority Critical patent/CN116578674B/zh
Publication of CN116578674A publication Critical patent/CN116578674A/zh
Application granted granted Critical
Publication of CN116578674B publication Critical patent/CN116578674B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying
    • G06F16/3331Query processing
    • G06F16/3332Query translation
    • G06F16/3335Syntactic pre-processing, e.g. stopword elimination, stemming
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying
    • G06F16/3331Query processing
    • G06F16/334Query execution
    • G06F16/3344Query execution using natural language analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Databases & Information Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Machine Translation (AREA)

Abstract

本申请提供联邦变分自编码主题模型训练方法、主题预测方法及装置,方法包括:在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个局部变分自编码主题模型的模型参数进行聚类以生成目标变分自编码主题模型;基于各个局部变分自编码主题模型的神经元累计梯度对目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型。本申请能够在有效保护本地数据隐私的基础上,能够有效降低模型训练过程的通信和计算开销,能够有效提高采用训练得到的主题模型预测文本数据所属主题类型的预测精度及可靠性。

Description

联邦变分自编码主题模型训练方法、主题预测方法及装置
技术领域
本申请涉及文本主题预测技术领域,尤其涉及联邦变分自编码主题模型训练方法、主题预测方法及装置。
背景技术
主题模型是广泛适用于社会事件数据的建模。传统隐含狄利克雷分布(latentDirichlet Allocation,LDA)在概率隐语义分析(probabilistic latent semanticanalysis, pLSA)模型的基础上加入了贝叶斯概率思想,学习文档的特征表示,为每个文档建模多个主题,有效地解决了数据表示的维度和隐含语义挖掘的问题。近二十年来以LDA为首的贝叶斯主题模型一直是主题分析的主线。但是随着深度学习的发展,目前的新算法更多的转向了使用神经网络的神经主题模型(neuraltopic models,NTMs),旨在通过神经网络学习潜在的文档与主题之间的关系,在理想情况下获得更高质量的主题。
变分自动编码主题模型(autoencoding variational inference for topicmodels,AVITM)由一个编码器-解码器架构和一个推理网络组成,推理网络将词袋( bag ofword,BoW )文档表示映射为连续的潜在表示,解码器网络重构该词袋。它的生成过程类似于LDA,但狄利克雷先验是通过高斯分布来近似的,而加权的专家乘积代替了单个单词上的多项式分布,目的是为了更方便的神经网络训练和使主题更符合人类的判断。
然而,现有的变分自动编码主题模型的训练方式虽然考虑到数据集中场景下的如何去提取文档与主题之间的关系,但在现实情况中在构建一个共享的主题模型以进行多个文档集合之间的比较时,需要满足隐私约束条件。这种限制在多种分析领域中均会遇到,因为数据源机构可能因为机密性或数据保护条例等规定将文本数据作为个人隐私,不愿或不允许将其文本数据进行共享。因此,如何在满足隐私约束的同时保证变分自动编码主题模型的预测精度是亟需解决的问题。
发明内容
鉴于此,本申请实施例提供了联邦变分自编码主题模型训练方法、主题预测方法及装置,以消除或改善现有技术中存在的一个或更多个缺陷。
本申请的一个方面提供了一种联邦变分自编码主题模型训练方法,包括:
在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型;
基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型;
若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型。
进一步地,所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型,包括:
根据当前的剪枝训练轮次对应的单次剪枝率,其中,所述单次剪枝率小于或等于预设的针对联邦变分自编码主题模型的目标剪枝率;
以当前的剪枝训练轮次对应的单次剪枝率对所述目标变分自编码主题模型进行神经元剪枝处理以得到对应的剪枝后的目标变分自编码主题模型;
在被剪枝的神经元中查找是否包含有神经元累计梯度大于梯度阈值的神经元,若是,则在所述目标变分自编码主题模型中恢复该神经元累计梯度大于梯度阈值的神经元,以生成对应的全局变分自编码主题模型。
进一步地,在所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理之前,还包括:
接收针对联邦变分自编码主题模型的目标剪枝率以及预设的渐进式剪枝策略;
根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率。
进一步地,所述渐进式剪枝策略包括:平均剪枝策略;
相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述平均剪枝策略,以相同的差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
进一步地,所述渐进式剪枝策略包括:快速剪枝策略;
相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述快速剪枝策略,以依次递减的各个差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
进一步地,在所述接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度之前,还包括:
根据预设的剪枝轮次间隔,将预设训练次数中的各个训练轮次分别划分为剪枝训练轮次和非剪枝训练轮次,并将对应的划分结果分别发送至联邦学习系统中的各个节点进行存储,以使各个所述节点在非剪枝训练轮次中仅发生各自训练得到的局部变分自编码主题模型的模型参数;
相对应的,所述联邦变分自编码主题模型训练方法还包括:
在当前的非剪枝训练轮次中,接收各个所述节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数,并对各个所述局部变分自编码主题模型的模型参数进行聚类以得到当前的全局变分自编码主题模型。
进一步地,在所述接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度之前,还包括:
接收联邦学习系统中的各个节点分别发送的词汇集合,其中,每个所述节点预先对本地的语料库进行预处理以得到各自对应的词汇集合;
对各个词汇集合进行聚合处理以形成对应的全局词汇库;
将所述全局词汇库和全局变分自编码主题模型的初始权重分别发送至联邦学习系统中的各个节点,以使各个所述节点根据所述全局词汇库和全局变分自编码主题模型的初始权重对本地的局部变分自编码主题模型进行初始化处理,而后基于在本地词汇集合中获取的文本训练数据对已初始化的局部变分自编码主题模型进行训练,得到局部变分自编码主题模型的模型参数和神经元累计梯度,若经判定当前的训练轮次为剪枝训练轮次,则发出本地的局部变分自编码主题模型的模型参数和神经元累计梯度;
相对应的,在所述得到当前的全局变分自编码主题模型之后,还包括:
若所述全局变分自编码主题模型当前未收敛或当前的剪枝训练轮次不为预设训练次数中的最后一次,则将该全局变分自编码主题模型的模型参数分别发送至各个所述节点,以使各个所述节点基于接收到的模型参数针对各自对应的局部变分自编码主题模型执行下一个所述训练轮次的模型训练。
本申请的第二个方面提供了一种文本主题预测方法,包括:
接收文本数据;
将所述文本数据输入预设的联邦变分自编码主题模型,以使该联邦变分自编码主题模型输出所述文本数据对应的主题类型,其中,所述联邦变分自编码主题模型预先基于所述的联邦变分自编码主题模型训练方法训练得到。
本申请的第三个方面提供了一种联邦变分自编码主题模型训练装置,包括:
联邦学习模块,用于在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型;
模型剪枝模块,用于基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型;
模型生成模块,用于若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型。
本申请的第四个方面提供了一种文本主题预测装置,包括:
数据接收模块,用于接收文本数据;
模型预测模块,用于将所述文本数据输入预设的联邦变分自编码主题模型,以使该联邦变分自编码主题模型输出所述文本数据对应的主题类型,其中,所述联邦变分自编码主题模型预先基于所述的联邦变分自编码主题模型训练方法训练得到。
本申请的第五个方面提供了一种联邦学习系统,包括:服务器和分别与所述服务器之间通信连接的各个客户端设备;
所述服务器用于执行本申请的第一个方面提供的联邦变分自编码主题模型训练方法,各个所述客户端设备分别用于作为各个所述节点;
所述服务器和所述客户端设备还可以执行本申请的第二个方面提供的文本主题预测方法。
本申请的第六个方面提供了一种电子设备,包括存储器、处理器及存储在存储器上并在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现第一方面所述的联邦变分自编码主题模型训练方法,或者,实现第二方面所述的文本主题预测方法。
本申请的第七个方面提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现第一方面所述的联邦变分自编码主题模型训练方法,或者,实现第二方面所述的文本主题预测方法。
本申请提供的联邦变分自编码主题模型训练方法,在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型,通过采用联邦学习系统分别对变分自编码主题模型进行训练,能够采用多方协作的方式,在保护本地数据隐私的前提下共同训练变分自编码主题模型,使变分自编码主题模型能够获得更为全面的数据信息,能够在满足隐私约束的同时训练得到高质量的主题模型,进而能够提高主题模型预测文本数据所属的主题类型的预测精度及可靠性。基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型,利用模型剪枝技术,能够有效克服联邦学习的通信瓶颈和计算瓶颈,能够有效减少联邦学习过程中的网络上的通信开销以及客户端本地训练所占用的计算资源,进而能够有效提高联邦变分自编码主题模型的训练效率。若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型,能够进一步提高联邦学习过程的有效性及可靠性。也就是说,本申请能够在有效保护本地数据隐私的基础上,能够有效降低模型训练过程的通信和计算开销,能够有效提高采用训练得到的主题模型预测文本数据所属主题类型的预测精度及可靠性。
本申请的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本申请的实践而获知。本申请的目的和其它优点可以通过在说明书以及附图中具体指出的结构实现到并获得。
本领域技术人员将会理解的是,能够用本申请实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本申请能够实现的上述和其他目的。
附图说明
此处所说明的附图用来提供对本申请的进一步理解,构成本申请的一部分,并不构成对本申请的限定。附图中的部件不是成比例绘制的,而只是为了示出本申请的原理。为了便于示出和描述本申请的一些部分,附图中对应部分可能被放大,即,相对于依据本申请实际制造的示例性装置中的其它部件可能变得更大。在附图中:
图1为本申请一实施例中的联邦变分自编码主题模型训练方法的第一种流程示意图。
图2为本申请一实施例中的联邦变分自编码主题模型训练方法的第二种流程示意图。
图3为本申请一实施例中的联邦变分自编码主题模型训练方法的第三种流程示意图。
图4为本申请另一实施例中的文本主题预测方法的流程示意图。
图5为本申请一实施例中的联邦变分自编码主题模型训练装置的结构示意图。
图6为本申请另一实施例中的文本主题预测装置的结构示意图。
图7为本申请应用实例提供的联邦变分自编码主题模型的训练过程举例示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚明白,下面结合实施方式和附图,对本申请做进一步详细说明。在此,本申请的示意性实施方式及其说明用于解释本申请,但并不作为对本申请的限定。
在此,还需要说明的是,为了避免因不必要的细节而模糊了本申请,在附图中仅仅示出了与根据本申请的方案密切相关的结构和/或处理步骤,而省略了与本申请关系不大的其他细节。
应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。
在此,还需要说明的是,如果没有特殊说明,术语“连接”在本文不仅可以指直接连接,也可以表示存在中间物的间接连接。
在下文中,将参考附图描述本申请的实施例。在附图中,相同的附图标记代表相同或类似的部件,或者相同或类似的步骤。
主题模型是一种广泛应用于许多领域的数据分析方法,其中包括科技创新(science,technology,and innovation,STI)文档分析。例如,上下文主题模型(contextualized topic models,CTMs )是在AVITM的基础上构建的,通过上下文嵌入来整合先验知识。STI文档分析领域中,主题模型已被广泛应用于比较不同机构资助项目的主题,揭示特定区域或组织的研究优势等问题。然而,在构建一个共享的主题模型以进行多个文档集合之间的比较时,面临着一些挑战,因为这需要满足隐私约束条件。这种限制在STI分析领域中经常遇到,因为资助机构可能因为机密性或《通用数据保护条例》等规定的个人隐私而不愿或不允许共享其文档集合。
基于此,需要在满足隐私约束的同时获得高质量的主题模型,以有效提高在对文本数据进行主题预测的精度和可靠性。本申请的设计人员首先想到采用联邦学习来解决训练用文本数据的来源隐私约束的问题,联邦学习(federated learning, FL)是一种分布式框架,根据该框架有一个或多个中心服务器协调,充当设置协议、隐私保证和节点更新聚合的中介,然后在一组设备上充当客户端训练模型,在训练全局模型的同时保证数据在本地的隐私性。由于FL具有去 究专注于设计类似LDA或基于非负矩阵分解(nonnegativematrixfactorization, NMF)的联邦框架,而另一些研究则选择提出联邦通用主题模型。也就是说,联邦学习是一种分布式框架,根据该框架在一组设备上训练模型,同时保持数据本地化。因此可以考虑采用主题模型(federated topic modeling, FTM),一个基于梅特罗波利斯-黑斯廷斯算法的框架,用于LDA主题模型的集体训练;还有学者提出了另一种联邦主题模型私人和一致的主题发现(private andconsistent topic discovery, PC - TD),但采用了基于嵌入空间的联邦推理框架。最近,还有学者提出了基于NMF的主题建模框架FedNMF。尽管之前的算法通过将其传统的推理过程与安全聚合协议相结合来定制比特姆主题模型方法,但在使用NTM作为支持技术的联邦环境中构建主题模型的工作还很少。
然而,在联邦算法的实现中,在每一个全局模型训练轮次中,每一个参与方都需要给服务器发送完整的模型参数更新。由于现代的深度神经网络(deep neural networks,DNN)模型通常有数百万个参数,向服务器发送如此多的数值将会导致巨大的通信开销,并且这样的通信开销会随着客户端数量和迭代轮次的增加而增加。当存在大量客户端时,从客户端上传模型参数至服务器将成为强化联邦学习的瓶颈。并且联邦模型的应用场景下的节点往往是各种手机等终端设备,存在计算量有限的问题,难以部署较为复杂的模型。换句话说,在联邦算法的实现中,在每一个全局模型训练轮次中,每一个参与方都需要给服务器发送完整的模型参数更新,这会造成巨大的通信开销,导致模型的训练时间急剧增加。并且联邦模型的应用场景下的节点往往是各种手机等终端设备,存在计算量有限的问题,难以部署较为复杂的模型。
因此,在采用联邦学习来解决如何在满足隐私约束的同时保证变分自动编码主题模型的预测精度是亟需解决的问题的基础上,还需要解决联邦学习带来了的巨大的通信开销和计算资源开销的问题。
因此,本申请实施例分别提供一种联邦变分自编码主题模型训练方法、文本主题预测方法、用于实现联邦变分自编码主题模型训练方法的联邦变分自编码主题模型训练装置及实体设备(如服务器等)、用于实现文本主题预测方法的文本主题预测装置及实体设备(如客户端设备等)以及联邦学习系统等,目的是训练一个联邦变分自编码主题模型进行主题和文档的交叉比较,以预测文档的主题类型。在联邦学习过程中,不需要客户端彼此共享或与服务器共享本地语料,这样的全局模型增加了每个客户端从各自对应的语料库中学习到的非协作主题模型的知识增益,同时通过使用模型剪枝技术,保证联邦变分自编码主题模型可以更快的收敛,保证联邦变分自编码主题模型的质效均衡。
在本申请的一个或多个实施例中,联邦变分自编码主题模型是一种联邦主题模型,在联邦学习场景下,多个客户端在保证数据本地化和数据隐私的前提下,共同训练一个主题模型,达到与在数据集中情况下训练主题模型类似的效果。
在本申请的一个或多个实施例中,变分自编码主题模型(autoencodingvariational inference for topic models,AVITM)是基于变分自编码机的主题模型,也是一个神经主题模型,它用变分自编码机代替LDA中的狄利克雷分布对文档和主题分布进行建模,用于捕捉更复杂的文档与主题之间的分布关系。变分自编码机是一个使用自动编码机和变分推理技术的生成模型,它能学习复杂高维数据的潜在结构。
在本申请的一个或多个实施例中,模型剪枝是一种通过移除神经网络中不必要或者冗余的神经元结构来减小神经元规模的深度学习技术。它的目标是在保持模型精度的同时加速模型训练,降低模型大小和推断时间。
具体通过下述实施例进行详细说明。
本申请实施例提供一种可由联邦变分自编码主题模型训练装置实现的联邦变分自编码主题模型训练方法,参见图1,所述联邦变分自编码主题模型训练方法具体包含有如下内容:
步骤100:在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型。
在步骤100中,剪枝训练轮次是整个联邦变分自编码主题模型中的预设训练次数中的各个训练轮次的类型之一,可以理解的是,各个训练轮次还可以划分有非剪枝训练轮次,也即是说,为了进一步提高联邦变分自编码主题模型训练效率,并不需要在每一轮训练过程中都对模型的神经元进行剪枝处理。
在本申请的一个或多个实施例中,所述模型参数中可以至少包含有模型的权重,所述神经元累计梯度是指该局部变分自编码主题模型中的各个神经元各自依次在每一训练轮次中的梯度的累计加和值。
可以理解的是,局部变分自编码主题模型是指由联邦学习系统中的各个节点分别用各自本地文本训练数据训练得到的变分自编码主题模型,目标变分自编码主题模型是指在剪枝训练轮次中待进行神经元剪枝的聚合后的变分自编码主题模型,全局变分自编码主题模型是指在剪枝训练轮次中剪枝后的变分自编码主题模型,全局变分自编码主题模型还指在非剪枝训练轮次中由各个局部变分自编码主题模型聚合后得到的变分自编码主题模型。其中,节点可以采用客户端设备实现。
在本申请的一个或多个实施例中,全局模型即为全局变分自编码主题模型的简称,而局部模型即为局部变分自编码主题模型的简称。
步骤200:基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型。
在步骤200中,在服务器接收到客户端(即在两轮迭代之间的边界处)的模型参数更新后进行对目标变分自编码主题模型进行剪枝,此时剪枝间隔可以始终为每轮迭代次数的整数倍。
步骤300:若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型。
在步骤300中,可以直到网络中权重的相对变化不再进行,或者直到达到预定义的迭代次数后,停止迭代。
从上述描述可知,本申请实施例提供的联邦变分自编码主题模型训练方法,通过采用联邦学习系统分别对变分自编码主题模型进行训练,能够采用多方协作的方式,在保护本地数据隐私的前提下共同训练变分自编码主题模型,使变分自编码主题模型能够获得更为全面的数据信息,能够在满足隐私约束的同时训练得到高质量的主题模型,进而能够提高主题模型预测文本数据所属的主题类型的预测精度及可靠性。利用模型剪枝技术,能够有效克服联邦学习的通信瓶颈和计算瓶颈,能够有效减少联邦学习过程中的网络上的通信开销以及客户端本地训练所占用的计算资源,进而能够有效提高联邦变分自编码主题模型的训练效率。
为了在降低通信开销的基础上,有效提高剪枝的可靠性及智能化程度,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,参见图2,所述联邦变分自编码主题模型训练方法中的步骤200具体包含有如下内容:
步骤210:根据当前的剪枝训练轮次对应的单次剪枝率,其中,所述单次剪枝率小于或等于预设的针对联邦变分自编码主题模型的目标剪枝率。
具体来说,在剪枝中较常用的剪枝方法是震级剪枝,即通过神经元的权重绝对值大小来对神经元进行剪枝,神经元的权重越小代表这个神经元在组成模型的时候越没有显著的贡献。但考虑到一些神经元虽然初始权重很小,但在训练过程中可能会起到重要作用,且考虑到模训练后期模型较为稳定,因此需要适应性降低剪枝率。
步骤220:以当前的剪枝训练轮次对应的单次剪枝率对所述目标变分自编码主题模型进行神经元剪枝处理以得到对应的剪枝后的目标变分自编码主题模型。
步骤230:在被剪枝的神经元中查找是否包含有神经元累计梯度大于梯度阈值的神经元,若是,则在所述目标变分自编码主题模型中恢复该神经元累计梯度大于梯度阈值的神经元,以生成对应的全局变分自编码主题模型。
可以理解的是,所述梯度阈值可以根据实际应用情形进行设置,在步骤230中,若未在被剪枝的神经元中查找到包含有神经元累计梯度大于梯度阈值的神经元,则直接将剪枝后的目标变分自编码主题模型的作为当前的全局变分自编码主题模型。
为了提高剪枝的有效性及可靠性,进而进一步降低通信开销,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,参见图2,所述联邦变分自编码主题模型训练方法中的步骤200之前还具体包含有如下内容:
步骤010:接收针对联邦变分自编码主题模型的目标剪枝率以及预设的渐进式剪枝策略;
步骤020:根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率。
具体来说,联邦变分自编码主题模型训练装置(如服务器)预先根据用户录入的针对联邦变分自编码主题模型的目标剪枝率以及预设的渐进式剪枝策略来设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,能够有效提高后续对目标变分自编码主题模型进行神经元剪枝处理的效率及便捷性。
为了在降低通信开销的基础上提高模型训练及预测精度,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,所述渐进式剪枝策略包括:平均剪枝策略,参见图3,所述联邦变分自编码主题模型训练方法中的步骤020具体包含有如下内容:
步骤021:基于所述平均剪枝策略,以相同的差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
具体来说,为了在剪枝的同时保留训练过程中尽可能多的信息,将目标剪枝率平均分配到整个训练过程。如设定目标剪枝率为50%,则在模型训练过程达到一半时,达到25%的目标剪枝率。在模型训练完成的时候达到最终的50%的目标剪枝率。这种剪枝方式对于模型训练过程的加速有限,但是可以保证剪枝后的模型达到更高的精度。在模型预测过程中可以大幅减少模型预测时间。
为了提高训练效率并进一步降低通信开销,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,所述渐进式剪枝策略包括:快速剪枝策略,参见图3,所述联邦变分自编码主题模型训练方法中的步骤020还可以具体包含有如下内容:
步骤022:基于所述快速剪枝策略,以依次递减的各个差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
具体来说,为了加快模型训练速度,在模型训练初期就快速达到目标剪枝率,然后在达到目标剪枝率之后以较小的模型规模继续训练模型。这种方式可能丢失更多的有用信息,但是可以更快的完成模型的训练。这种剪枝方式可以大幅减少模型的训练时间,但是在剪枝过程中可能丢失过多的有用信息,模型最终的精度可能会受到影响。
为了进一步降低通信开销,并进一步提高联邦学习过程的有效性及可靠性,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,参见图3,所述联邦变分自编码主题模型训练方法中的步骤100之前还可以具体包含有如下内容:
步骤030:根据预设的剪枝轮次间隔,将预设训练次数中的各个训练轮次分别划分为剪枝训练轮次和非剪枝训练轮次,并将对应的划分结果分别发送至各个所述节点进行存储,以使各个所述节点在非剪枝训练轮次中仅发生各自训练得到的局部变分自编码主题模型的模型参数。
具体来说,联邦变分自编码主题模型训练装置(如服务器)预先将剪枝训练轮次和非剪枝训练轮次分别发送至各个所述节点进行存储,进而能够有效提高各个节点确定每次发送何种数据的便捷性。
相对应的,在所述联邦变分自编码主题模型训练方法中的步骤300之前还具体包含有如下内容:
步骤110:在当前的非剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数,并对各个所述局部变分自编码主题模型的模型参数进行聚类以得到当前的全局变分自编码主题模型。
为了进一步提高联邦学习过的有效性及可靠性,在本申请实施例提供的一种联邦变分自编码主题模型训练方法中,参见图3,所述联邦变分自编码主题模型训练方法中的步骤100之前还可以具体包含有如下内容:
步骤040:接收联邦学习系统中的各个节点分别发送的词汇集合,其中,每个所述节点预先对本地的语料库进行预处理以得到各自对应的词汇集合;
步骤050:对各个词汇集合进行聚合处理以形成对应的全局词汇库;
步骤060:将所述全局词汇库和全局变分自编码主题模型的初始权重分别发送至各个所述节点,以使各个所述节点根据所述全局词汇库和全局变分自编码主题模型的初始权重对本地的局部变分自编码主题模型进行初始化处理,而后基于在本地词汇集合中获取的文本训练数据对已初始化的局部变分自编码主题模型进行训练,得到局部变分自编码主题模型的模型参数和神经元累计梯度,若经判定当前的训练轮次为剪枝训练轮次,则发出本地的局部变分自编码主题模型的模型参数和神经元累计梯度。
相对应的,在所述联邦变分自编码主题模型训练方法中的步骤110或步骤200之后还可以具体包含有如下内容:
步骤310:若所述全局变分自编码主题模型当前未收敛或当前的剪枝训练轮次不为预设训练次数中的最后一次,则将该全局变分自编码主题模型的模型参数分别发送至各个所述节点,以使各个所述节点基于接收到的模型参数针对各自对应的局部变分自编码主题模型执行下一个所述训练轮次的模型训练。
基于前述的联邦变分自编码主题模型训练方法的实施例,本申请还提供一种可由文本主题预测装置执行的文本主题预测方法,参见图4,所述文本主题预测方法具体包含有如下内容:
步骤400:接收文本数据;
步骤500:将所述文本数据输入预设的联邦变分自编码主题模型,以使该联邦变分自编码主题模型输出所述文本数据对应的主题类型,其中,所述联邦变分自编码主题模型预先基于所述的联邦变分自编码主题模型训练方法训练得到。
从上述描述可知,本申请实施例提供的文本主题预测方法,能够在满足隐私约束的同时训练得到高质量的主题模型,进而能够提高主题模型预测文本数据所属的主题类型的预测精度及可靠性。
从软件层面来说,本申请还提供一种用于执行所述联邦变分自编码主题模型训练方法中全部或部分内的联邦变分自编码主题模型训练装置,参见图5,所述联邦变分自编码主题模型训练装置具体包含有如下内容:
联邦学习模块10,用于在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型。
模型剪枝模块20,用于基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型。
模型生成模块30,用于若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型。
本申请提供的联邦变分自编码主题模型训练装置的实施例具体可以用于执行上述实施例中的联邦变分自编码主题模型训练方法的实施例的处理流程,其功能在此不再赘述,可以参照上述联邦变分自编码主题模型训练方法实施例的详细描述。
所述联邦变分自编码主题模型训练装置进行联邦变分自编码主题模型训练的部分可以在服务器中执行,也可以在客户端设备中完成。具体可以根据所述客户端设备的处理能力,以及用户使用场景的限制等进行选择。本申请对此不作限定。若所有的操作都在所述客户端设备中完成,所述客户端设备还可以包括处理器,用于联邦变分自编码主题模型训练的具体处理。
上述的客户端设备可以具有通信模块(即通信单元),可以与远程的服务器进行通信连接,实现与所述服务器的数据传输。所述服务器可以包括任务调度中心一侧的服务器,其他的实施场景中也可以包括中间平台的服务器,例如与任务调度中心服务器有通信链接的第三方服务器平台的服务器。所述的服务器可以包括单台计算机设备,也可以包括多个服务器组成的服务器集群,或者分布式装置的服务器结构。
上述服务器与所述客户端设备端之间可以使用任何合适的网络协议进行通信,包括在本申请提交日尚未开发出的网络协议。所述网络协议例如可以包括TCP/IP协议、UDP/IP协议、HTTP协议、HTTPS协议等。当然,所述网络协议例如还可以包括在上述协议之上使用的RPC协议(Remote Procedure Call Protocol,远程过程调用协议)、REST协议(Representational State Transfer,表述性状态转移协议)等。
从上述描述可知,本申请实施例提供的联邦变分自编码主题模型训练装置,通过采用联邦学习系统分别对变分自编码主题模型进行训练,能够采用多方协作的方式,在保护本地数据隐私的前提下共同训练变分自编码主题模型,使变分自编码主题模型能够获得更为全面的数据信息,能够在满足隐私约束的同时训练得到高质量的主题模型,进而能够提高主题模型预测文本数据所属的主题类型的预测精度及可靠性。利用模型剪枝技术,能够有效克服联邦学习的通信瓶颈和计算瓶颈,能够有效减少联邦学习过程中的网络上的通信开销以及客户端本地训练所占用的计算资源,进而能够有效提高联邦变分自编码主题模型的训练效率。
从软件层面来说,本申请还提供一种用于执行所述文本主题预测方法中全部或部分内的文本主题预测装置,参见图6,所述文本主题预测装置具体包含有如下内容:
数据接收模块40,用于接收文本数据;
模型预测模块50,用于将所述文本数据输入预设的联邦变分自编码主题模型,以使该联邦变分自编码主题模型输出所述文本数据对应的主题类型,其中,所述联邦变分自编码主题模型预先基于所述的联邦变分自编码主题模型训练方法训练得到。
本申请提供的文本主题预测装置的实施例具体可以用于执行上述实施例中的文本主题预测方法的实施例的处理流程,其功能在此不再赘述,可以参照上述文本主题预测方法实施例的详细描述。
所述文本主题预测装置进行文本主题预测的部分可以在客户端设备中完成。
基于前述联邦变分自编码主题模型训练装置和文本主题预测装置的实施例,本申请还提供一种联邦学习系统的实施例,所述联邦学习系统具体包含有如下内容:
服务器和分别与所述服务器之间通信连接的各个客户端设备;
所述服务器用于执行所述联邦变分自编码主题模型训练方法,各个所述客户端设备分别用于作为各个所述节点。
所述服务器和所述客户端设备还可以执行本申请的第二个方面提供的文本主题预测方法。
为了进一步说明本方案,本申请还提供一种采用联邦学习系统执行联邦变分自编码主题模型训练方法的具体应用实例,主题建模已成为处理大量文档集合的有效技术,用于发现其中的潜在主题和模式。然而,涉及多方数据的交叉分析时,保护数据隐私成为一个重要问题。为此,联邦主题建模应运而生,它允许多方在不泄露私有数据的情况下共同训练主题模型。在每一个全局模型训练轮次中,每一个参与方都需要给服务器发送完整的模型参数更新,这会造成巨大的通信开销,导致模型的训练时间急剧增加。并且联邦模型的应用场景下的节点往往是各种手机等终端设备,存在计算量有限的问题,难以部署较为复杂的模型。基于此,本申请应用实例主要分为三个步骤:数据预处理、联邦变分自编码主题模型的联邦学习、联邦变分自编码主题模型的渐进式剪枝。
首先对文档数据进行预处理,主要保留名词这一些对于主题分辨能力较强的词语,并将数据集转化为词袋文档用于模型训练。然后由联邦学习系统中的各个客户端发送本地节点的词汇,联邦学习系统中的服务器等待接收到所有节点的词汇,然后将它们合并成一个公共的全局词汇库,用于初始化带有权重W(0)的全局模型。等所有客户端从服务器接收回公共的全局词汇库,用于初始化带有权重W(0)的全局模型后,进行联邦变分自编码主题模型的联邦训练过程。在联邦变分自编码主题模型的训练过程中,使用渐进式剪枝算法,隔一定轮数客户端会将神经网络结点的权重和累计梯度发送到服务器端,然后服务器会据此对联邦变分自编码主题模型进行剪枝操作。经过剪枝操作可以大大减少网络上的通信开销和客户端本地训练的运算开销,且设置模型剪枝率在联邦变分自编码主题模型训练后期足够小,保证联邦变分自编码主题模型可以更快的收敛,保证联邦变分自编码主题模型的质效均衡。
所述联邦变分自编码主题模型训练方法的应用实例具体包含有如下内容:
一、预处理
预处理工作共分为三步:如果训练使用的数据集是中文的,需要对其进行必要的文本预处理。即,各个客户端设备均需要对本地的语料库C(例如,客户端N2的语料库C2)进行清洗和分词之后才能继续进行后续的处理。
S1:数据清洗
在预处理过程中,可以采取以下步骤对训练用的语料库进行清洗:
(a)去掉停顿字;
(b)剔除频次<20的词;
(c)过滤掉网络地址URL(Uniform Resource Locator)、表情符号、井字标签(hashtag)和非中文字符;
(d)删除长度小于10的段落。
注意,缩略语、首字母缩写和俚语仍然用于后面的主题建模。
S2:分词处理
通过Jieba分词工具可以自定义的创建停用词字典,根据训练数据集的特性添加停用词词典,使Jieba分词工具在进行分词工作的时候帮助其识别在训练数据集场景所出现的词语。并使用词性标注功能,将形容词、副词等去除,主要保留名词这一些对于主题分辨能力较强的词语。Jieba分词工具是Python中文分词组件,是针对中文的自然语言处理的分词工具,其原理是利用一个中文词库(如前述的停用词词典),确定汉字之间的关联概率,并将汉字间概率大的组成词组以形成分词结果。
S3:数据转化
最后,各个客户端设备均将各自得到的数据集进行词袋转化,以得到各自对应的词汇集合V l ,其中,V l 表示第l个客户端发送的词汇集合,l=1、2…L;L表示客户端的总数。
二、联邦变分自编码主题模型训练
图7展示了本申请应用实例提出的联邦变分自编码主题模型的训练过程。
1、词汇共识阶段
词汇共识阶段可以具体包含有下述步骤S4-S6:
S4:服务器等待接收所有客户端N l 即图7中的客户端N1、客户端N2至客户端NL的词汇集合V l
S5:服务器将各个所述词汇集合V l 聚合成一个公共集合,即全局词汇库V;
S6:服务器将全局词汇库V和全局联邦变分自编码主题模型的初始权重W(0)分发给各个所述客户端设备,以使各个所述客户端设备后续根据该全局词汇库V初始化全局联邦变分自编码主题模型。
2、联邦平均阶段
联邦平均阶段可以具体包含有下述步骤S7-S12:
S7:所有客户端设备均从服务器接收全局词汇库V和初始权重W(0),并采用全局词汇库V分别初始化各自本地的带有初始权重W(0)的全局联邦变分自编码主题模型。
S8:在每个客户端设备上,客户端设备使用本地的小批量语料库数据(即本地的语料库C中的部分或全部数据)对所述全局联邦变分自编码主题模型进行一定次数的训练,分别得到各个所述客户端设备各自训练得到的局部联邦变分自编码主题模型及本轮训练对应的模型参数;例如第一轮训练对应的模型参数W l (1),包括:客户端设备N1的第一轮训练对应的模型参数W1 (1)、客户端设备N2的第一轮训练对应的模型参数W2 (1)和客户端设备NL的第一轮训练对应的模型参数WL (1);第二轮训练对应模型参数W l (2),包括:客户端设备N1的第二轮训练对应的模型参数W1 (2)、客户端设备N2的第二轮训练对应的模型参数W2 (2)和客户端设备NL的第二轮训练对应的模型参数WL (2)等。
S9:各个所述客户端设备分别将训练得到的局部联邦变分自编码主题模型及第一轮训练对应的模型参数W l (1)上传到服务器。
S10:服务器等待所有客户端发送其本地训练的局部联邦变分自编码主题模型,并对各个所述局部联邦变分自编码主题模型进行聚合得到第一轮训练对应的全局联邦变分自编码主题模型的新的全局模型参数,例如第一轮训练对应的全局模型参数W(1)、第二轮训练对应的全局模型参数W(2)等。
S11:服务器将该全局主题模型更新后的本轮训练对应的全局模型参数发送给所有客户端。
S12:所有客户端设备均从服务器接收本轮训练对应的模型参数,而返回重复执行步骤 S8至步骤S10,直到网络中权重的相对变化不再进行,或者直到达到预定义的迭代次数后,停止迭代。
上述训练过程举例可以参见表1所示的算法1。算法1显示了服务器和客户端节点的运行情况。在算法描述中,Agg(⋅)表示聚集函数,它有多个聚集选项。其中,最常见的是联邦平均,即将每个客户端的模型参数进行求和平均得到全局主题模型。
表1
三、联邦变分自编码主题模型的渐进式剪枝
在联邦变分自编码主题模型算法的实现中,在每一个全局模型训练轮次中,每一个参与方都需要给服务器发送完整的模型参数更新。由于现代的神经网络模型通常有数百万个参数,给协调方发送如此多的数值将会导致巨大的通信开销,并且这样的通信开销会随着参与方数量和迭代轮次的增加而增加。通信开销成为了联邦学习模型训练速度的主要瓶颈。除了通信瓶颈之外,在联邦学习的应用场景中,客户端设备往往是一些边缘计算设备(如手机),它们的计算和通信资源更为有限,难以用庞大模型进行推断。
基于此,本申请应用实例使用一种新的渐进式剪枝技术。在联邦主题模型训练过程中,每隔一定轮数客户端会将神经网络结点的权重和累计梯度发送到服务器端,然后服务器会据此对神经主题模型进行剪枝操作。
本申请应用实例通过模型剪枝操作可以有效压缩联邦变分自编码主题模型的参数数量,有效的减少通信和计算负担,加快模型的训练速度。
渐进式剪枝的具体说明如下:
定义:令k表示迭代总次数,g n (w(k))表示全局模型在目标模型参数w(k)处的随机梯度,在客户n上的全参数空间上计算。此外,令mw(k)表示一个掩码向量,如果w(k)未被剪枝,则该向量为1, 符号“*”表示元素之间的积。客户n可以指代前述的客户端N l l=1、2…L。
在上述步骤S7至步骤S12的联邦训练期间,本申请应用实例采用的渐进式剪枝过程可以与标准的联邦学习(FedAvg)过程一起执行以进一步实现自适应剪枝,在服务器接收到客户端(即在两轮迭代之间的边界处)的参数更新后进行对模型进行剪枝,此时剪枝间隔始终为每轮迭代次数的整数倍。
在每个剪枝过程中,渐进式剪枝找到一组最优的剩余模型参数。然后,对参数进行相应的剪枝或添加回来,使用得到的模型和掩码进行训练,直到下一个剪枝过程。
在剪枝中较常用的剪枝方法是震级剪枝,即通过神经元的权重绝对值大小来对神经元进行剪枝,神经元的权重越小代表这个神经元在组成模型的时候越没有显著的贡献。但考虑到一些神经元虽然初始权重很小,但在训练过程中可能会起到重要作用,且考虑到模训练后期模型较为稳定,因此需要适应性降低剪枝率。
基于此,本申请应用实例的剪枝策略是在客户端本地训练时累计记录神经元的累计梯度,且n=1、2…L,(例如客户端N2的梯度表示为Z2),累积梯度值较大表示该神经元在未来更可能起作用。
具体剪枝流程如图7所示,具体包含有如下内容:
S13:在剪枝轮客户端将包含有最新轮训练对应的模型参数和神经元的累计梯度Zn一起发送到服务器,例如,客户端N1发送包含有第二轮训练对应的模型参数W1 (2)和神经元的累计梯度Z1,客户端N2发送包含有第二轮训练对应的模型参数W2 (2)和神经元的累计梯度Z2,客户端NL发送包含有第二轮训练对应的模型参数WL (2)和神经元的累计梯度Z3
S14:服务器使用联邦平均算法先将接收自各个客户端的模型参数和梯度进行平均,得到未剪枝的全局模型和各神经元的平均梯度。
S15:服务器根据全局模型各神经元的权重对全局模型的神经元进行剪枝,使用mw(k)将相应位置赋值为0,然后再根据被剪枝的神经元的累积梯度,将累积梯度较大的神经元再恢复到模型当中,即将mw(k)相应位置赋值为1。
待模型剪枝过程完成之后,将全局模型进行w(k)* mw(k)的运算,例如对真正的实现剪枝操作,随后将模型转化为稀疏矩阵储存,再将转变后的稀疏矩阵作为新的全局模型的模型参数发送到各客户端进行新一轮的联邦学习。例如,全局模型的第二轮训练对应的全局模型参数W(2)经剪枝后,得到将转变后的稀疏矩阵作为新的全局模型的模型参数
四、目标剪枝率
在上述内容的基础上,由于在渐进式剪枝算法中是多次剪枝到达最终的目标剪枝率。那么为了达到目标剪枝率,本申请应用实例还可以采用两种方式来设定每一次剪枝的目标剪枝率来达到最终的目标剪枝率。
1、目标剪枝率的第一种设定
为了在剪枝的同时保留训练过程中尽可能多的信息,将目标剪枝率平均分配到整个训练过程。如设定目标剪枝率为50%,则在模型训练过程达到一半时,达到25%的目标剪枝率。在模型训练完成的时候达到最终的50%的目标剪枝率,本申请应用实例将使用这一种剪枝策略的联邦变分自编码机主题模型可以叫做Prune-FedAVITM。这种剪枝方式对于模型训练过程的加速有限,但是可以保证剪枝后的模型达到更高的精度。在模型推理过程中可以大幅减少模型推理时间。
2、目标剪枝率的第二种设定
为了加快模型训练速度,在模型训练初期就快速达到目标剪枝率,然后在达到目标剪枝率之后以较小的模型规模继续训练模型。这种方式可能丢失更多的有用信息,但是可以更快的完成模型的训练。
本申请应用实例将这一种剪枝策略的联邦变分自编码机主题模型叫做FastPrune-FedAVITM。这种剪枝方式可以大幅减少模型的训练时间,但是在剪枝过程中可能丢失过多的有用信息,模型最终的精度可能会受到影响。
两种方式只是每次剪枝的目标剪枝率不同,具体的剪枝过程举例均可以参见表2所示的算法2。
表2
综上所述,本申请应用实例提供的联邦变分自编码主题模型训练方法,具有如下有益效果:
1)通过提供一种应用模型剪枝的联邦学习主题模型方式,采用多方协作的方式,在保护本地数据隐私的前提下共同训练主题模型,使模型可以获得更全面的数据信息。并且利用模型剪枝技术,有效克服了联邦学习通信瓶颈和计算瓶颈。
2)本申请应用实例使用了一种新的渐进式剪枝技术。在联邦主题模型训练过程中,每隔一定轮数客户端会将神经网络结点的权重(即模型参数)和累计梯度发送到服务器端,然后服务器会据此对神经主题模型进行剪枝操作。经过剪枝操作可以大大减少网络上的通信开销、客户端本地训练的运算开销。在模型训练完成之后,经过剪枝的联邦主题模型可以大幅加快模型推理技术;
3)为了应对不同的需求提出两种不同的确定模型剪枝率的方法,第一种方法是在整个模型训练过程中缓慢进行剪枝,这种方式对于模型训练过程的加速有限,但是可以保证剪枝后的模型达到更高的精度。在模型推理过程中可以大幅减少模型推理时间。第二种策略是为了加快模型训练速度,在模型训练初期就快速达到目标剪枝率,然后在达到目标剪枝率之后以较小的模型规模继续训练模型。这种方式可能丢失更多的有用信息,但是可以更快的完成模型的训练。
本申请实施例还提供了一种电子设备,该电子设备可以包括处理器、存储器、接收器及发送器,处理器用于执行上述实施例提及的联邦变分自编码主题模型训练和/或文本主题预测方法,其中处理器和存储器可以通过总线或者其他方式连接,以通过总线连接为例。该接收器可通过有线或无线方式与处理器、存储器连接。
处理器可以为中央处理器(Central Processing Unit,CPU)。处理器还可以为其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-ProgrammableGate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等芯片,或者上述各类芯片的组合。
存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序、非暂态计算机可执行程序以及模块,如本申请实施例中的联邦变分自编码主题模型训练和/或文本主题预测方法对应的程序指令/模块。处理器通过运行存储在存储器中的非暂态软件程序、指令以及模块,从而执行处理器的各种功能应用以及数据处理,即实现上述方法实施例中的联邦变分自编码主题模型训练和/或文本主题预测方法。
存储器可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储处理器所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施例中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
所述一个或者多个模块存储在所述存储器中,当被所述处理器执行时,执行实施例中的联邦变分自编码主题模型训练和/或文本主题预测方法。
在本申请的一些实施例中,用户设备可以包括处理器、存储器和收发单元,该收发单元可包括接收器和发送器,处理器、存储器、接收器和发送器可通过总线系统连接,存储器用于存储计算机指令,处理器用于执行存储器中存储的计算机指令,以控制收发单元收发信号。
作为一种实现方式,本申请中接收器和发送器的功能可以考虑通过收发电路或者收发的专用芯片来实现,处理器可以考虑通过专用处理芯片、处理电路或通用芯片实现。
作为另一种实现方式,可以考虑使用通用计算机的方式来实现本申请实施例提供的服务器。即将实现处理器,接收器和发送器功能的程序代码存储在存储器中,通用处理器通过执行存储器中的代码来实现处理器,接收器和发送器的功能。
本申请实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述联邦变分自编码主题模型训练和/或文本主题预测方法的步骤。该计算机可读存储介质可以是有形存储介质,诸如随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、软盘、硬盘、可移动存储盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质。
本领域普通技术人员应该可以明白,结合本文中所公开的实施方式描述的各示例性的组成部分、系统和方法,能够以硬件、软件或者二者的结合来实现。具体究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。当以硬件方式实现时,其可以例如是电子电路、专用集成电路(ASIC)、适当的固件、插件、功能卡等等。当以软件方式实现时,本申请的元素是被用于执行所需任务的程序或者代码段。程序或者代码段可以存储在机器可读介质中,或者通过载波中携带的数据信号在传输介质或者通信链路上传送。
需要明确的是,本申请并不局限于上文所描述并在图中示出的特定配置和处理。为了简明起见,这里省略了对已知方法的详细描述。在上述实施例中,描述和示出了若干具体的步骤作为示例。但是,本申请的方法过程并不限于所描述和示出的具体步骤,本领域的技术人员可以在领会本申请的精神后,作出各种改变、修改和添加,或者改变步骤之间的顺序。
本申请中,针对一个实施方式描述和/或例示的特征,可以在一个或更多个其它实施方式中以相同方式或以类似方式使用,和/或与其他实施方式的特征相结合或代替其他实施方式的特征
以上所述仅为本申请的优选实施例,并不用于限制本申请,对于本领域的技术人员来说,本申请实施例可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (6)

1.一种联邦变分自编码主题模型训练方法,其特征在于,包括:
在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型;
基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型;
若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型;
所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型,包括:
根据当前的剪枝训练轮次对应的单次剪枝率,其中,所述单次剪枝率小于或等于预设的针对联邦变分自编码主题模型的目标剪枝率;以当前的剪枝训练轮次对应的单次剪枝率对所述目标变分自编码主题模型进行神经元剪枝处理以得到对应的剪枝后的目标变分自编码主题模型;
在被剪枝的神经元中查找是否包含有神经元累计梯度大于梯度阈值的神经元,若是,则在所述目标变分自编码主题模型中恢复该神经元累计梯度大于梯度阈值的神经元,以生成对应的全局变分自编码主题模型;
在所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理之前,还包括:
接收针对联邦变分自编码主题模型的目标剪枝率以及预设的渐进式剪枝策略;
根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率;
若所述渐进式剪枝策略为平均剪枝策略;则相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述平均剪枝策略,以相同的差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应;
若所述渐进式剪枝策略为快速剪枝策略;则相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述快速剪枝策略,以依次递减的各个差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
2.根据权利要求1所述的联邦变分自编码主题模型训练方法,其特征在于,在所述接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度之前,还包括:
根据预设的剪枝轮次间隔,将预设训练次数中的各个训练轮次分别划分为剪枝训练轮次和非剪枝训练轮次,并将对应的划分结果分别发送至联邦学习系统中的各个节点进行存储,以使各个所述节点在非剪枝训练轮次中仅发生各自训练得到的局部变分自编码主题模型的模型参数;
相对应的,所述联邦变分自编码主题模型训练方法还包括:
在当前的非剪枝训练轮次中,接收各个所述节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数,并对各个所述局部变分自编码主题模型的模型参数进行聚类以得到当前的全局变分自编码主题模型。
3.根据权利要求2所述的联邦变分自编码主题模型训练方法,其特征在于,在所述接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度之前,还包括:
接收联邦学习系统中的各个节点分别发送的词汇集合,其中,每个所述节点预先对本地的语料库进行预处理以得到各自对应的词汇集合;
对各个词汇集合进行聚合处理以形成对应的全局词汇库;
将所述全局词汇库和全局变分自编码主题模型的初始权重分别发送至联邦学习系统中的各个节点,以使各个所述节点根据所述全局词汇库和全局变分自编码主题模型的初始权重对本地的局部变分自编码主题模型进行初始化处理,而后基于在本地词汇集合中获取的文本训练数据对已初始化的局部变分自编码主题模型进行训练,得到局部变分自编码主题模型的模型参数和神经元累计梯度,若经判定当前的训练轮次为剪枝训练轮次,则发出本地的局部变分自编码主题模型的模型参数和神经元累计梯度;
相对应的,在所述得到当前的全局变分自编码主题模型之后,还包括:
若所述全局变分自编码主题模型当前未收敛或当前的剪枝训练轮次不为预设训练次数中的最后一次,则将该全局变分自编码主题模型的模型参数分别发送至各个所述节点,以使各个所述节点基于接收到的模型参数针对各自对应的局部变分自编码主题模型执行下一个所述训练轮次的模型训练。
4.一种文本主题预测方法,其特征在于,包括:
接收文本数据;
将所述文本数据输入预设的联邦变分自编码主题模型,以使该联邦变分自编码主题模型输出所述文本数据对应的主题类型,其中,所述联邦变分自编码主题模型预先基于权利要求1至3任一项所述的联邦变分自编码主题模型训练方法训练得到。
5.一种联邦变分自编码主题模型训练装置,其特征在于,包括:
联邦学习模块,用于在当前的剪枝训练轮次中,接收联邦学习系统中的各个节点各自采用本地的文本训练数据训练得到的局部变分自编码主题模型的模型参数和神经元累计梯度,并对各个所述局部变分自编码主题模型的模型参数进行聚类以生成当前的目标变分自编码主题模型;
模型剪枝模块,用于基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型;
模型生成模块,用于若所述全局变分自编码主题模型当前已收敛或当前的剪枝训练轮次为预设训练次数中的最后一次,则将该全局变分自编码主题模型作为用于根据输入的文本数据对应输出该文本数据所属主题类型的联邦变分自编码主题模型;
其中,所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理,得到当前的全局变分自编码主题模型,包括:
根据当前的剪枝训练轮次对应的单次剪枝率,其中,所述单次剪枝率小于或等于预设的针对联邦变分自编码主题模型的目标剪枝率;
以当前的剪枝训练轮次对应的单次剪枝率对所述目标变分自编码主题模型进行神经元剪枝处理以得到对应的剪枝后的目标变分自编码主题模型;
在被剪枝的神经元中查找是否包含有神经元累计梯度大于梯度阈值的神经元,若是,则在所述目标变分自编码主题模型中恢复该神经元累计梯度大于梯度阈值的神经元,以生成对应的全局变分自编码主题模型;
在所述基于各个所述局部变分自编码主题模型的神经元累计梯度对所述目标变分自编码主题模型进行神经元剪枝处理之前,还包括:
接收针对联邦变分自编码主题模型的目标剪枝率以及预设的渐进式剪枝策略;
根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率;
若所述渐进式剪枝策略包括:平均剪枝策略;则相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述平均剪枝策略,以相同的差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应;
若所述渐进式剪枝策略包括:快速剪枝策略;则相对应的,所述根据所述目标剪枝率以及所述渐进式剪枝策略分别设置预设训练次数中的各个剪枝训练轮次各自对应的单次剪枝率,包括:
基于所述快速剪枝策略,以依次递减的各个差值将所述目标剪枝率划分为百分比依次递增的各个单次剪枝率,且依次递增的各个所述单次剪枝率与依次执行的各个剪枝训练轮次之间一一对应。
6.一种电子设备,包括存储器、处理器及存储在存储器上并在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至3任一项所述的联邦变分自编码主题模型训练方法,或者,实现如权利要求4所述的文本主题预测方法。
CN202310826329.2A 2023-07-07 2023-07-07 联邦变分自编码主题模型训练方法、主题预测方法及装置 Active CN116578674B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310826329.2A CN116578674B (zh) 2023-07-07 2023-07-07 联邦变分自编码主题模型训练方法、主题预测方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310826329.2A CN116578674B (zh) 2023-07-07 2023-07-07 联邦变分自编码主题模型训练方法、主题预测方法及装置

Publications (2)

Publication Number Publication Date
CN116578674A CN116578674A (zh) 2023-08-11
CN116578674B true CN116578674B (zh) 2023-10-31

Family

ID=87536107

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310826329.2A Active CN116578674B (zh) 2023-07-07 2023-07-07 联邦变分自编码主题模型训练方法、主题预测方法及装置

Country Status (1)

Country Link
CN (1) CN116578674B (zh)

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021204040A1 (zh) * 2020-10-29 2021-10-14 平安科技(深圳)有限公司 联邦学习数据处理方法、装置、设备及存储介质
CN114492831A (zh) * 2021-12-23 2022-05-13 北京百度网讯科技有限公司 联邦学习模型的生成方法及其装置
WO2022105714A1 (zh) * 2020-11-23 2022-05-27 华为技术有限公司 数据处理方法、机器学习的训练方法及相关装置、设备
WO2022110720A1 (zh) * 2020-11-24 2022-06-02 平安科技(深圳)有限公司 基于选择性梯度更新的联邦建模方法及相关设备
CN114969312A (zh) * 2022-05-30 2022-08-30 特赞(上海)信息科技有限公司 基于变分自编码器的营销案例主题提取方法及系统
CN115238908A (zh) * 2022-03-15 2022-10-25 华东师范大学 基于变分自编码器、无监督聚类算法和联邦学习的数据生成方法
CN115391522A (zh) * 2022-08-02 2022-11-25 中国科学院计算技术研究所 一种基于社交平台元数据的文本主题建模方法及系统
CN115564062A (zh) * 2022-09-26 2023-01-03 南京理工大学 一种基于模型剪枝和传输压缩优化的联邦学习系统及方法
CN115829027A (zh) * 2022-10-31 2023-03-21 广东工业大学 一种基于对比学习的联邦学习稀疏训练方法及系统

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021204040A1 (zh) * 2020-10-29 2021-10-14 平安科技(深圳)有限公司 联邦学习数据处理方法、装置、设备及存储介质
WO2022105714A1 (zh) * 2020-11-23 2022-05-27 华为技术有限公司 数据处理方法、机器学习的训练方法及相关装置、设备
WO2022110720A1 (zh) * 2020-11-24 2022-06-02 平安科技(深圳)有限公司 基于选择性梯度更新的联邦建模方法及相关设备
CN114492831A (zh) * 2021-12-23 2022-05-13 北京百度网讯科技有限公司 联邦学习模型的生成方法及其装置
CN115238908A (zh) * 2022-03-15 2022-10-25 华东师范大学 基于变分自编码器、无监督聚类算法和联邦学习的数据生成方法
CN114969312A (zh) * 2022-05-30 2022-08-30 特赞(上海)信息科技有限公司 基于变分自编码器的营销案例主题提取方法及系统
CN115391522A (zh) * 2022-08-02 2022-11-25 中国科学院计算技术研究所 一种基于社交平台元数据的文本主题建模方法及系统
CN115564062A (zh) * 2022-09-26 2023-01-03 南京理工大学 一种基于模型剪枝和传输压缩优化的联邦学习系统及方法
CN115829027A (zh) * 2022-10-31 2023-03-21 广东工业大学 一种基于对比学习的联邦学习稀疏训练方法及系统

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Federation Learning for Intrusion Detection Methods by Parse Convolutional Neural Network;Jiechen Luo等;《2022 Second International Conference on Advances in Electrical, Computing, Communication and Sustainable Technologies (ICAECT)》;全文 *
基于联邦学习和卷积神经网络的入侵检测方法;王蓉等;信息网络安全(04);全文 *
机器学习隐私保护研究综述;谭作文等;软件学报(07);全文 *

Also Published As

Publication number Publication date
CN116578674A (zh) 2023-08-11

Similar Documents

Publication Publication Date Title
CN110334201B (zh) 一种意图识别方法、装置及系统
JP7383803B2 (ja) 不均一モデルタイプおよびアーキテクチャを使用した連合学習
US20240135191A1 (en) Method, apparatus, and system for generating neural network model, device, medium, and program product
US11423307B2 (en) Taxonomy construction via graph-based cross-domain knowledge transfer
Elbir et al. A hybrid architecture for federated and centralized learning
CN114091667A (zh) 一种面向非独立同分布数据的联邦互学习模型训练方法
CN114282678A (zh) 一种机器学习模型的训练的方法以及相关设备
CN113673260A (zh) 模型处理方法、装置、存储介质和处理器
Long et al. Fedsiam: Towards adaptive federated semi-supervised learning
Hsieh et al. Fl-hdc: Hyperdimensional computing design for the application of federated learning
Deng et al. Adaptive federated learning with negative inner product aggregation
Ju et al. Efficient convolutional neural networks on Raspberry Pi for image classification
KR20210096405A (ko) 사물 학습모델 생성 장치 및 방법
CN114595815A (zh) 一种面向传输友好的云-端协作训练神经网络模型方法
Saputra et al. Federated learning framework with straggling mitigation and privacy-awareness for AI-based mobile application services
CN114626550A (zh) 分布式模型协同训练方法和系统
CN116578674B (zh) 联邦变分自编码主题模型训练方法、主题预测方法及装置
CN106339072A (zh) 一种基于左右脑模型的分布式大数据实时处理系统及方法
CN116797850A (zh) 基于知识蒸馏和一致性正则化的类增量图像分类方法
CN117034008A (zh) 高效联邦大模型调节方法、系统及相关设备
CN116976461A (zh) 联邦学习方法、装置、设备及介质
Cao et al. Lstm network based traffic flow prediction for cellular networks
EP3767548A1 (en) Delivery of compressed neural networks
CN113033653A (zh) 一种边-云协同的深度神经网络模型训练方法
EP3683733A1 (en) A method, an apparatus and a computer program product for neural networks

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant