CN114048780A - 基于联邦学习的脑电信号分类模型训练方法及装置 - Google Patents
基于联邦学习的脑电信号分类模型训练方法及装置 Download PDFInfo
- Publication number
- CN114048780A CN114048780A CN202111347340.8A CN202111347340A CN114048780A CN 114048780 A CN114048780 A CN 114048780A CN 202111347340 A CN202111347340 A CN 202111347340A CN 114048780 A CN114048780 A CN 114048780A
- Authority
- CN
- China
- Prior art keywords
- user
- classification model
- electroencephalogram
- local
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F2218/00—Aspects of pattern recognition specially adapted for signal processing
- G06F2218/12—Classification; Matching
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F3/00—Input arrangements for transferring data to be processed into a form capable of being handled by the computer; Output arrangements for transferring data from processing unit to output unit, e.g. interface arrangements
- G06F3/01—Input arrangements or combined input and output arrangements for interaction between user and computer
- G06F3/011—Arrangements for interaction with the human body, e.g. for user immersion in virtual reality
- G06F3/015—Input arrangements based on nervous system activity detection, e.g. brain waves [EEG] detection, electromyograms [EMG] detection, electrodermal response detection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F2218/00—Aspects of pattern recognition specially adapted for signal processing
- G06F2218/08—Feature extraction
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Signal Processing (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Dermatology (AREA)
- General Health & Medical Sciences (AREA)
- Neurology (AREA)
- Neurosurgery (AREA)
- Human Computer Interaction (AREA)
- Measurement And Recording Of Electrical Phenomena And Electrical Characteristics Of The Living Body (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请适用于生物信息技术领域,提供了一种基于联邦学习的脑电信号分类模型训练方法及装置,该方法包括:将服务器端的脑电信号分类模型发送给K个用户端;接收每个用户端发送的本地模型梯度;根据用户端的本地模型梯度,获取用户端的重要性评估值;根据K个用户端的重要性评估值,从K个用户端中确定多个目标用户端;根据目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数;若服务器端的脑电信号分类模型未收敛,则返回将服务器端的脑电信号分类模型发送给K个用户端的步骤,直至服务器端的脑电信号分类模型收敛。本申请能在充分利用所有用户的有效信息情况下,提升脑电信号分类模型的精度及收敛速度。
Description
技术领域
本申请属于生物信息技术领域,尤其涉及一种基于联邦学习的脑电信号分类模型训练方法及装置。
背景技术
基于情感识别的脑机接口(BCI,Brain Computer Interface)通过在情感交互实验中采集用户的脑电信号,并对脑电信号进行特征提取和解码,可以识别用户真正的情感状态和意图,从而实现用户和设备间的友好通信及交互。基于脑电信号的情感分析有广泛的应用场景,例如情感障碍疾病的辅助诊断和抑郁症等心理治疗干预等。
基于深度学习的情感识别模型往往是数据驱动,要求有大量的训练数据。然而由于脑电图(EEG,Electroencephalographic)信号的采集过程繁琐及个体间差异性巨大的特点,EEG数据往往以多个小数据集的形式分散存在于各个用户。为了构建高精度的情感识别模型,现有方法致力于通过共享不同用户之间的数据,利用知识迁移和领域自适应等技术来有效利用其他用户的有用信息和提升目标用户的情感识别率。但在数据共享的过程中,如果包含了人的身份特征及思想情感等私密信息的脑电信号,一旦被滥用或者非法阅读传播,将造成个人隐私的泄露。
目前脑电信号分类模型主要有:基于EEGNet(EEGNet是为专门一般的脑电图识别任务而设计的通用紧凑的卷积神经网络)的脑电信号分类模型,和基于联邦迁移学习(FTL,Federate Transfer Learning)的脑电信号分类模型。其中EEGNet以原脑电信号作为输入,为每个用户训练出端到端的有竞争力的情感识别网络。但是由于用户的脑电信号个体差异性大,直接用上所有用户的数据训练一个统一的网络往往导致共享模型(即脑电信号分类模型)的精度低,为此基于EEGNet训练的网络只能利用每个用户本地的数据单独训练情感识别网络,忽略了其他用户的数据及所能提供的有效信息,造成了数据浪费的问题。而基于FTL的方法尽管利用了联邦学习来有效利用其他用户的数据信息,同时也满足了用户本地数据不共享的需求。但是该方法是以脑电信号的空间协方差矩阵作为输入,损失了原脑电信号的部分有效信息。另外,FTL依赖于联邦平均算法,该算法在联合训练的过程中,随机选择部分本地模型的梯度,通过无区别的简单平均聚合来更新服务器的梯度,忽略了不同用户的数据质量和重要性,这将导致每次更新服务器模型的梯度变化不稳定,不利于共享模型(即脑电信号分类模型)的精度,而且往往收敛速度慢,给模型训练造成一定难度。
发明内容
本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法及装置,可以解决脑电信号分类模型的精度低、且收敛速度慢的问题。
第一方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括:
将所述服务器端的脑电信号分类模型发送给K个用户端;
接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:
按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:
对每个所述目标用户端的重要性评估值进行归一化处理;
根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
其中,所述根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度的步骤,包括:
其中,所述对每个所述目标用户端的重要性评估值进行归一化处理的步骤,包括:
其中,所述根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的步骤,包括:
通过公式μk=αk×βk,计算第k个用户端的重要性评估值;
其中,μk表示第k个用户端的重要性评估值,αk=nk/n,nk表示第k个用户端的本地训练集所包含的本地样本量,n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量, 表示第t-1轮更新时所述服务器端的全局梯度,表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,所述方法还包括:
在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
第二方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该装置包括:
发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,上述更新模块305包括:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
其中,上述获取模块303,具体用于通过公式μk=αk×βk,计算第k个用户端的重要性评估值;
其中,μk表示第k个用户端的重要性评估值,αk=nk/n,nk表示第k个用户端的本地训练集所包含的本地样本量,n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量, 表示第t-1轮更新时所述服务器端的全局梯度,表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述脑电信号分类模型训练装置还包括:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
第三方面,本申请实施例提供了一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的方法。
第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述的方法。
第五方面,本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行上述第一方面中任一项所述的方法。
本申请实施例与现有技术相比存在的有益效果是:
在本申请的实施例中,通过基于联邦学习框架,在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练方法的流程图;
图2是本申请一实施例提供的步骤15的流程图;
图3是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练装置的结构示意图;
图4是本申请一实施例提供的服务器的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
目前脑电信号分类模型主要有基于EEGNet的脑电信号分类模型和基于FTL的脑电信号分类模型。但基于EEGNet的脑电信号分类模型的精度低,而基于FTL的脑电信号分类模型的收敛速度慢、且精度不理想。
针对上述问题,本申请实施例基于联邦学习框架,通过在分布式训练中,将服务器端的脑电信号分类模型发送给K个用户端,使各用户端利用本地训练集对接收到的脑电信号分类模型进行训练,并将训练得到的本地模型梯度发送给服务器端,以进行联合训练,从而在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现了联合训练及其分布式训练,达到了在充分利用所有用户的有效信息的情况下,提升脑电信号分类模型的精度的效果。
同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练方法进行示例性的说明。
如图1所示,本申请的实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括如下步骤:
步骤11,将所述服务器端的脑电信号分类模型发送给K个用户端。
在本申请的一些实施例中,上述K个用户端为与上述服务器端参与联邦学习的用户端。需要说明的是,为确保最终得到的脑电信号分类模型是基于用户端的有用户的有效信息得到的,在执行上述训练方法的步骤之前,服务器端可以初始化一个脑电信号分类模型(即上述步骤11中的脑电信号分类模型)。具体的,可以将模型权重初始化为0,也可以采用其他常见的初始化方案,例如高斯、Xavier初始化(Xavier初始化是一种神经网络初始化方法)。
其中,上述脑电信号分类模型可以为EEGNet模型,当然也可以是其他的深度学习网络,例如卷积神经网络(ConvNet)等脑电信号分类神经网络。
步骤12,接收每个所述用户端发送的本地模型梯度。
在本申请的一些实施例中,上述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的。
即,在本申请的一些实施例中,对于参与联邦学习的每个用户端,在接收到服务器端下发的脑电信号分类模型后,会利用用户端的本地训练集对接收到的脑电信号分类模型进行训练,并在该脑电信号分类模型收敛时得到本地模型梯度。
步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值。
在本申请的一些实施例中,上述重要性评估值主要用于表征用户端的重要性程度,以便后续按照重要性从高至低的顺序,从K个用户端中选择出对共享模型(即服务器端的脑电信号分类模型)贡献大的目标用户端进行联合训练,从提升脑电信号分类模型的精度和收敛速度。
步骤14,根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端。
在本申请的一些实施例中,可按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端,从而从K个用户端中筛选出重要性程度高的目标用户端。其中,上述预设比例的具体数值可根据实际情况进行设定。
可见,在本申请的一些实施例中,上述目标用户端的重要性程度高于K个用户端中其他用户端的重要性程度,即,目标用户端对共享模型(即服务器端的脑电信号分类模型)贡献大于其他用户端对共享模型的贡献,后续利用这些目标用户端进行联合训练,能提升脑电信号分类模型的精度和收敛速度。
步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数。
在本申请的一些实施例中,在联合训练中,通过根据目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,能提升服务器端的脑电信号分类模型的精度和收敛速度。
步骤16,若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
在本申请的一些实施例中,上述步骤16中收敛的脑电信号分类模型是共享模型,可用于对任一用户的脑电信号进行分类。
在本申请的一些实施例中,在执行完步骤15后,若服务器端的脑电信号分类模型未收敛,则返回步骤11,以再次更新服务器端的脑电信号分类模型的网络参数,直至所述服务器端的脑电信号分类模型收敛。
需要说明的是,每次更新服务器端的脑电信号分类模型的网络参数后,都需要判断更新网络参数后的脑电信号分类模型是否收敛,若收敛,则更新后的脑电信号分类模型即为共享模型,否则,将更新网络参数后的脑电信号分类模型下发给K个用户端,使K个用户端分别利用自身的本地训练集对接收到的脑电信号分类模型进行训练,得到的本地模型梯度,以再次更新服务器端的脑电信号分类模型的网络参数。
值得一提的是,在本申请的一些实施例中,在联合训练中,不直接使用用户端的本地训练集数据,而是使用用户端的本地模型梯度来共同训练服务器端的脑电信号分类模型,从而保障了用户端本地数据的隐私和使用安全性,在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到了充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。
同时在联合训练中,由于不是随机挑选用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
在本申请的实施例中,在执行完上述步骤16后,上述方法还包括如下步骤:在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
需要说明的是,用户端在接收到该脑电信号分类模型后,可利用自身的用本地训练集对该脑电信号分类模型进行训练,以对该脑电信号分类模型的模型参数进行微调,得到更适合该用户端的脑电信号分类模型,后续该用户端可利用微调后的脑电信号分类模型对该用户端的用户数据进行分类,提高分类准确性。
接下来,结合具体实施例对用户端利用本地训练集对脑电信号分类模型进行训练的过程的进行示例性的说明。
在本申请的一些实施例中,用户端的本地训练集可来源于上海交大情感脑电数据集(SEED)。在该数据集的实验中,15个筛选过的中国电影片段被选取为实验中的情感刺激源,标签包括正面、中性和负面情绪。该数据集一共采集了15名中国受试者(包括7名男生和8名女生),其中每个受试者分别进行3次实验。该数据集中的每个样本包含62个电极通道,下采样到200Hz,并且应用了0-75Hz的带通频率滤波器。为了扩展数据量,我们将每个数据按照1s的数据窗口进行不重叠切割,最终获取3394个样本。在采集的62个通道中,本申请实施例选择与情感相关的32个通道,分别对应Fp1,AF3,F3,F7,FC5,FC1,C3,T7,CP5,CP1,P3,P7,PO3,O1,Oz,Pz,Fp2,AF4,Fz,F4,F8,FC6,FC2,Cz,C4,T8,CP6,CP2,P4,P8,PO4,O2。为此,每个样本的大小为32×200。需要说明的是,在本申请的一些实施例中,可将15名受试者中任一受试者的32个通道的数据作为一用户端的本地训练集。为提升脑电信号分类模型的精度,用户端每次均可利用本地训练集中的所有数据对脑电信号分类模型进行训练。需要进一步说明的是,每个用户端对应的本地训练集均不相同。
作为一个优选的示例,根据输入的原始EEG信号的时空属性,上述脑电信号分类模型采用EEGNet模型,用于提取脑电信号的特征表示及分类。其中本申请中特征提取器和分类器模型参数如表1所示。当然可以理解的是,卷积层数目、卷积核大小、池化方法以及激活函数均可根据实际情况进行设定。
表1
其中,在用户端利用本地训练集对脑电信号分类模型进行训练时,可采用交叉熵(cross entropy)损失函数评估训练结果,其中第k个用户端的训练损失函数如下:其中,nk表示第k个用户端的本地训练集所包含的本地样本量,yi为训练样本(即本地训练集中的本地样本)的真实标签,为预测标签。需要说明的是,上述训练损失函数为常用损失函数,因此在此,不对该训练损失函数的原理进行过多赘述。
接下来,结合具体实施例对获取重要性评估值以及更新网络参数的过程的进行示例性的说明。
在本申请的一些实施例中,上述步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的具体实现方式可以为:通过公式μk=αk×βk,计算第k个用户端的重要性评估值。
其中,μk表示第k个用户端的重要性评估值,αk=nk/n,nk表示第k个用户端的本地训练集所包含的本地样本量,n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量, 表示第t-1轮更新时所述服务器端的全局梯度,表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
在本申请的一些实施例中,除了通过上述公式计算用户端的重要性评估值外,还可以通过其他的相似性度量学习方法或者注意力机制算法度量用户端的重要性。
在本申请的一些实施例中,如图2所示,上述步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的具体实现方式包括如下步骤:
步骤21,对每个所述目标用户端的重要性评估值进行归一化处理。
步骤22,根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度。
步骤23,根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
在本申请的一些实施例中,可采用基于随机梯度下降(SGD,Stochastic GradientDescent)的随机梯度下降法求解网络参数。需要说明的是,服务器端在初始化脑电信号分类模型的时候,服务器端的全局梯度也会初始化为0。
综上,本申请实施例提供的基于联邦学习的脑电信号分类模型训练方法具备如下效果:
一、脑电信号分类模型采用EEGNet模型,将其应用到情感脑电信号的分类任务中,无需手工提取信号特征,能端到端进行情感脑电信号的特征提取和分类;
二、将EEGNet模型应用到情感脑电识别网络,利用深度学习自动提取情感脑电信号的可判别性特征,提升用户端单个的脑电信号分类模型的准确率;
三、无需对脑电信号做繁杂的预处理,直接利用脑电信号对脑电信号分类模型进行训练,便可有效进行脑电信号的特征提取和分类;
四、能在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果;
五、通过各用户端的重要性,选择对共享模型贡献大的目标用户端进行联合训练,从而提升脑电信号分类模型的精度及收敛速度。
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练装置进行示例性的说明。
如图3所示,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该脑电信号分类模型训练装置300包括:
发送模块301,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
接收模块302,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
获取模块303,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
第一确定模块304,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
更新模块305,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
第二确定模块306,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,上述更新模块305包括:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
其中,上述获取模块303,具体用于通过公式μk=αk×βk,计算第k个用户端的重要性评估值;
其中,μk表示第k个用户端的重要性评估值,αk=nk/n,nk表示第k个用户端的本地训练集所包含的本地样本量,n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量, 表示第t-1轮更新时所述服务器端的全局梯度,表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述脑电信号分类模型训练装置还包括:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
如图4所示,本申请的实施例提供了一种服务器,如图4所示,该实施例的服务器D10包括:至少一个处理器D100(图4中仅示出一个处理器)、存储器D101以及存储在所述存储器D101中并可在所述至少一个处理器D100上运行的计算机程序D102,所述处理器D100执行所述计算机程序D102时实现上述任意各个方法实施例中的步骤。
所称处理器D100可以是中央处理单元(CPU,Central Processing Unit),该处理器D100还可以是其他通用处理器、数字信号处理器(DSP,Digital Signal Processor)、专用集成电路(ASIC,Application Specific Integrated Circuit)、现成可编程门阵列(FPGA,Field-Programmable Gate Array)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
所述存储器D101在一些实施例中可以是所述服务器D10的内部存储单元,例如服务器D10的硬盘或内存。所述存储器D101在另一些实施例中也可以是所述服务器D10的外部存储设备,例如所述服务器D10上配备的插接式硬盘,智能存储卡(SMC,Smart MediaCard),安全数字(SD,Secure Digital)卡,闪存卡(Flash Card)等。进一步地,所述存储器D101还可以既包括所述服务器D10的内部存储单元也包括外部存储设备。所述存储器D101用于存储操作系统、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器D101还可以用于暂时地存储已经输出或者将要输出的数据。
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本申请实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现可实现上述各个方法实施例中的步骤。
本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行时实现可实现上述各个方法实施例中的步骤。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质至少可以包括:能够将计算机程序代码携带到脑电信号分类模型训练装置/终端设备的任何实体或装置、记录介质、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置/网络设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/网络设备实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
以上所述实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。
Claims (10)
1.一种基于联邦学习的脑电信号分类模型训练方法,其特征在于,应用于服务器端,所述方法包括:
将所述服务器端的脑电信号分类模型发送给K个用户端;
接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
2.根据权利要求1所述的方法,其特征在于,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:
按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
3.根据权利要求2所述的方法,其特征在于,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:
对每个所述目标用户端的重要性评估值进行归一化处理;
根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
7.根据权利要求1所述的方法,其特征在于,所述方法还包括:
在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
8.一种基于联邦学习的脑电信号分类模型训练装置,其特征在于,应用于服务器端,所述装置包括:
发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
9.一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的方法。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111347340.8A CN114048780A (zh) | 2021-11-15 | 2021-11-15 | 基于联邦学习的脑电信号分类模型训练方法及装置 |
PCT/CN2021/138013 WO2023082406A1 (zh) | 2021-11-15 | 2021-12-14 | 基于联邦学习的脑电信号分类模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111347340.8A CN114048780A (zh) | 2021-11-15 | 2021-11-15 | 基于联邦学习的脑电信号分类模型训练方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114048780A true CN114048780A (zh) | 2022-02-15 |
Family
ID=80208990
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111347340.8A Pending CN114048780A (zh) | 2021-11-15 | 2021-11-15 | 基于联邦学习的脑电信号分类模型训练方法及装置 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN114048780A (zh) |
WO (1) | WO2023082406A1 (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114664434A (zh) * | 2022-03-28 | 2022-06-24 | 上海韶脑传感技术有限公司 | 面向不同医疗机构的脑卒中康复训练系统及其训练方法 |
CN117708681A (zh) * | 2024-02-06 | 2024-03-15 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20150324690A1 (en) * | 2014-05-08 | 2015-11-12 | Microsoft Corporation | Deep Learning Training System |
CN111814985B (zh) * | 2020-06-30 | 2023-08-29 | 平安科技(深圳)有限公司 | 联邦学习网络下的模型训练方法及其相关设备 |
CN112181666B (zh) * | 2020-10-26 | 2023-09-01 | 华侨大学 | 一种基于边缘智能的设备评估和联邦学习重要性聚合方法 |
CN112633146B (zh) * | 2020-12-21 | 2024-03-26 | 杭州趣链科技有限公司 | 多姿态人脸性别检测训练优化方法、装置及相关设备 |
CN113158241A (zh) * | 2021-04-06 | 2021-07-23 | 深圳市洞见智慧科技有限公司 | 基于联邦学习的岗位推荐方法及装置 |
-
2021
- 2021-11-15 CN CN202111347340.8A patent/CN114048780A/zh active Pending
- 2021-12-14 WO PCT/CN2021/138013 patent/WO2023082406A1/zh unknown
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114664434A (zh) * | 2022-03-28 | 2022-06-24 | 上海韶脑传感技术有限公司 | 面向不同医疗机构的脑卒中康复训练系统及其训练方法 |
CN117708681A (zh) * | 2024-02-06 | 2024-03-15 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
CN117708681B (zh) * | 2024-02-06 | 2024-04-26 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
WO2023082406A1 (zh) | 2023-05-19 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111190939B (zh) | 一种用户画像构建方法及装置 | |
WO2019200781A1 (zh) | 票据识别方法、装置及存储介质 | |
CN114048780A (zh) | 基于联邦学习的脑电信号分类模型训练方法及装置 | |
CN107958230B (zh) | 人脸表情识别方法及装置 | |
CN110503082B (zh) | 一种基于深度学习的模型训练方法以及相关装置 | |
CN108197592B (zh) | 信息获取方法和装置 | |
CN107092874A (zh) | 基于心电和指纹融合特征的身份识别方法、装置及系统 | |
CN108549276B (zh) | 一种智能交互控制制水设备的方法及系统 | |
CN112656431A (zh) | 基于脑电的注意力识别方法、装置、终端设备和存储介质 | |
CN108256579A (zh) | 一种基于先验知识的多模态民族认同感量化测量方法 | |
CN104951807A (zh) | 股市情绪的确定方法和装置 | |
CN113133769A (zh) | 基于运动想象脑电信号的设备控制方法、装置及终端 | |
CN111671420A (zh) | 一种从静息态脑电数据中提取特征的方法及终端设备 | |
CN105631283B (zh) | 一种基于生物特征自学习方法及移动终端 | |
CN113143295A (zh) | 基于运动想象脑电信号的设备控制方法及终端 | |
CN105844204B (zh) | 人体行为识别方法和装置 | |
Liong et al. | Automatic traditional Chinese painting classification: A benchmarking analysis | |
CN115098777A (zh) | 一种基于数据分析的用户个性化推荐方法和系统 | |
CN104679967A (zh) | 一种判断心理测试可靠性的方法 | |
Saha et al. | Common spatial pattern in frequency domain for feature extraction and classification of multichannel EEG signals | |
CN113014881A (zh) | 一种神经外科患者日常监护方法及系统 | |
Kong et al. | Task-free brainprint recognition based on degree of brain networks | |
CN116502261A (zh) | 保留数据特性的数据脱敏方法及装置 | |
CN109948718B (zh) | 一种基于多算法融合的系统及方法 | |
CN116340825A (zh) | 一种基于迁移学习的跨被试rsvp脑电信号的分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |