CN116821764A - 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法 - Google Patents

一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法 Download PDF

Info

Publication number
CN116821764A
CN116821764A CN202310802378.2A CN202310802378A CN116821764A CN 116821764 A CN116821764 A CN 116821764A CN 202310802378 A CN202310802378 A CN 202310802378A CN 116821764 A CN116821764 A CN 116821764A
Authority
CN
China
Prior art keywords
domain
model
teacher
source
sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310802378.2A
Other languages
English (en)
Inventor
郑浩浩
曾虹
乐淑萍
潘登
曾涛
徐非凡
欧阳瑜
贾哲
钱东官
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hangzhou Dianzi University
Original Assignee
Hangzhou Dianzi University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hangzhou Dianzi University filed Critical Hangzhou Dianzi University
Priority to CN202310802378.2A priority Critical patent/CN116821764A/zh
Publication of CN116821764A publication Critical patent/CN116821764A/zh
Pending legal-status Critical Current

Links

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法。首先获取数据进行带通滤波,并使用独立成分分析技术去除伪迹。其次通过差分熵方法进行脑电特征提取,将三维脑电时间序列转换为二维样本矩阵。然后在两种任务场景下各自划定训练集和测试集,确保它们不重合。本发明采用基于边际采样的伪标签三元组损失结合最大均值差异。本发明从不同的源领域中学习知识,以最大程度地利用多个单源模型,并以更少的时间消耗实现更强大的模型。最后利用分类准确率对模型进行两种任务场景下的性能评估。本发明将三元组损失和最大均值差异结合,不仅可以在域级别上实现每对源领域和目标领域之间的无偏置对齐,还考虑了数据对级别的相关性。

Description

一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法
技术领域
本发明属于生物特征识别领域中的脑电信号(EEG)情绪状态识别领域,具体涉及一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法。
背景技术
情感识别在人机交互中发挥着重要作用。近年来,随着计算能力的提高,基于深度学习的情感识别方法越来越受到关注。这些方法通过深入挖掘用户潜在的客观情感特征,做出反映人类情感的决策。
情感型脑机接口(aBCIs)是情感识别的一个重要应用。通过测量周围和中枢神经系统的信号,它们提取与用户情感状态相关的特征,并利用这些特征来调整人机交互(HCI)。它们在康复和交流方面显示出潜力。
一般来说,情绪识别可以分为两类:基于非生理信号的方法,如面部表情图像、身体手势和语音信号;以及基于生理信号的方法,如脑电图(EEG)、肌电图(EMG)和心电图(ECG)。然而,与非生理信号相比,生理信号可以直接接触到个体的内部情绪状态,使其不容易受到有意识或无意识的操纵。在各种基于生理信号的情绪识别方法中,脑电图是最常用的方法之一,因为它是直接从大脑皮层采集的,对反映人的心理状态很有价值。随着脑电图采集技术和处理方法的快速发展,基于脑电图的情绪识别近年来受到越来越多的关注。
然而,由于低信噪比(SNR)以及不同时间和不同受试者之间的显著个体差异,构建高效、稳健的基于EEG的情绪识别深度学习模型仍是一个巨大的挑战。此外,在基于EEG的BCI中使用现有的标记数据来分析新的未标记的数据是至关重要的。为此,域适应在研究工作中被广泛使用,通过从源数据分布中学习,训练一个在相关但不同的目标数据分布上表现良好的模型。然而,在实践中,通常有多个源域,这使得多源域适应成为域适应的强大扩展。尽管如此,在多源域适应中用于域对齐的技术通常是最大均值差异(MMD),它只考虑域级别的适应,而缺乏数据对级别的适应。这种限制可能会导致缺乏判别能力。此外,在大多数多源域适应框架中,只是用多个单源域模型的平均预测结果作为最终的结果,并未充分的利用好这些单源域模型。
发明内容
为了解决上述现有技术存在的缺陷以及更好的利用多个单源模型的优势,本发明提出了一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法(MS-KTF)。
本发明采用的技术方案是:
本发明以微分熵(DE)特征作为使用的EEG信号的频域特征,略微修改EEGNet模型作为特征提取器,以单层的线性层作为分类器,对EEG信号进行分析,实现跨被试和跨时段两种情形下的情绪状态识别任务。
本发明将训练过程分为三个步骤:训练过程分为三个步骤:(1)基于每个有标签的源域对每个教师模型进行预训练;(2)基于相应的有标签源域和无标签的目标域,利用源域分类损失(SCL)、目标域分类损失(TCL)、最大均值差异(MMD)和伪标签三元组损失,对每个教师模型进行域适应;(3)将多个单源域教师的知识转移给学生模型。此外,在步骤(2)中,为了提高伪标签三元组损失的有效性,我们采用了基于边际的抽样策略来过滤原始特征,只选择那些边际分数高于预设阈值的特征作为计算伪标签三元组损失的嵌入特征。
本发明实施方式包括如下步骤:
步骤S1.数据处理:
以情绪数据集SEED为例进行分析,对脑电采集设备采集到的原始EEG数据的处理步骤如下:
S1-1:数据去噪
本发明用来验证模型性能的数据集来源于SEED。首先对数据集中采集到的EEG原始信号降采样到200HZ,然后进行0-75HZ的带通滤波和ICA技术去除信号中的眼电伪迹,最后使用传统的移动平均和线性动态系统(LDS)方法进一步平滑特征。
S1-2:DE特征提取
对伪迹去除后的EEG数据进行DE特征提取,针对每个被试个体用1s的非重叠滑动窗口进行数据分割,得到3394个数据样本。对于每个数据样本xi,脑电数据采集通道数为62;且提取了δ(1-3HZ)、θ(4-7HZ)、α(8-13HZ)、β(14-30HZ)和γ(31-50HZ)五个频段的频域特征。
步骤S2.数据定义和数据集划分:
情绪状态分类存在两种测试情境:跨被试和跨时段,这两种情况下的模型测试有着各自不同的数据定义和数据集划分,下面各自具体说明。
假定存在N个被试,每个被试进行了D次不同(会话)时段实验。整个样本集合表示为其中i表示被试序号,j表示会话(时段)序号,Xi表示被试i的样本集合,对应的标签集合为Yi
对于跨时段的情感状态分类任务,同样采用留一法对数据集进行交叉验证。具体来说,在每一个被试i中,将最新一次会话中所有被试的15次情感试验的数据作为测试集;将其余D-1个会话以会话为单位,每个会话作为训练集中的一个源域,最终获得D-1个源域作为训练集。总共进行N次实验并计算平均准确率。
对于跨被试的情感状态分类任务,采用留一法对数据集进行交叉验证。具体来说,在一个会话(时段)中,迭代的取出一个被试的所有15次情感试验的数据并假设其情绪状态标签未知,作为测试集;从其余N-1个被试中,随机不重复的将R个被试组成一组作为训练集中的一个源域,最终获得(向下取整)个源域作为训练集。总共进行D×N次实验并计算平均准确率。
步骤S3.MS-KTF模型的构建与训练:
神经网络MS-KTF模型中的主要参数包括:
①特征嵌入空间的维度大小df,即特征被特征提取器提取后进入嵌入空间的维度,与嵌入特征的表征能力密切相关。
②进行边缘采样的阈值Threshold,即决定该特征是否应该被采样的关键参数,对获得分数高于阈值的特征进行采样。
③进行蒸馏学习的温度Temperature,即进行softmax操作中的温度系数,主要是用来调整预测分布情况的平滑程度。
S3-1:初始化
MS-KFT由两部分组成:基于单源域的教师模型和作用于目标域的学生模型。不论是教师模型还是学生模型,都由特定领域特征提取器Nf和标签分类器Ny这两个模块组成。将基于多源域的多个单源域教师模型和一个目标域学生的参数进行初始化操作;
S3-2:预训练多个单源域教师模型
基于多个源域样本集合,分别使用单个有标签的源域样本来预训练一个领域特定教师模型的特征提取器Nf和标签分类器Ny,使之在各自的源域下有一定的模式识别能力。
S3-3:对多个单源域教师模型的特征提取器进行域适应
每个有标签源域样本和无标签目标域样本形成一个分支,在每个分支中,使用对应领域特定教师模型的特征提取器Nf对各自源域样本和目标域样本进行特征提取,将特征从原始的特征空间提取到嵌入空间中。
然后在特征空间中基于最大均值差异对嵌入特征进行领域层面的对齐;基于边缘采样的伪标签三元组损失对嵌入特征进行数据组层面的对齐。
通过最小化最大均值差异和伪标签三元组损失,训练多个单源域教师模型的特征提取器Nf,使其能够在源域和目标域进行域不变特征提取。
S3-4:训练多个单源域教师模型的标签分类器Ny
在各个单源域教师模型中,将提取出来的源域特征信息通过标签分类器Ny获得预测情绪并计算预测情绪/>与实际样本中对应标签YS的交叉熵;同样,计算目标域特征信息的预测情绪/>与其所生成的伪标签/>的交叉熵。
通过最小化两个获得的交叉熵,训练多个单元域教师模型的标签分类器Ny,使其在各自的源域和目标域拥有良好的情绪分类能力。
S3-5:合并多个单源域教师模型的知识;
针对教师模型性能之间平衡性,使用两种不同的合并策略:
①基于投票的形式来筛选合并的教师模型
这种方式更适用于教师性能平衡性较差的情况。使用无标签的目标域样本,首先通过教师模型的特征提取器和标签分类器获得的情绪预测结果生成对应的独热编码结果/>基于各个教师生成的独热编码结果/>进行投票,将投票结果作为决策变量/>如果该教师模型的情绪预测结果/>与决策变量/>相同,则该教师被选中进行知识合并。
计算所有被选中的教师模型的情绪预测结果的平均值作为所有教师模型的合并知识。
②使用平均的形式来合并教师模型
这种方式更适用于教师性能平衡性较强的情况。这种情况下所有的教师模型均拥有同样的权重,计算所有被选中的教师模型的预测结果的平均值作为所有教师模型的合并知识。
S3-6:将教师模型的合并知识教导给学生模型
使用无标签的目标域样本数据,通过学生模型的特征提取器和标签分类器,得到学生模型的预测结果
基于预设定的蒸馏温度Temperature,将教师模型的合并知识和学生模型的预测结果/>进行平滑化处理。使用KL散度(Kullback-Leibler Divergence)评估两个预测结果之间的差异。
通过最小化教师模型与学生模型的KL散度使学生模型学习教师模型的知识,获得在目标域相较于教师更具泛化性的特征提取和标签分类能力。
步骤S4.跨时段和跨被试两种情境下的模型性能评估:
本发明在SEED数据集上具体验证模型性能。
将目标域样本集在训练收敛的学生模型预测出的情绪状态和真实状态YT进行比对,获得准确率结果,并评估模型性能。准确率为模型测试时正确分类的样本个数占总测试样本的个数,模型准确度的计算公式如下:
其中,TP为被模型预测为正类的正类样本,TN为被模型预测为负类的负类样本,FP为被模型预测为正类的负类样本,FN为被模型预测为负类的正类样本。
一种基于知识蒸馏的多源域适应的EEG情绪状态分类系统,包括预训练模块、教师模型域适应模块、学生模型训练模块;其中预训练模块基于每个有标签的源域对每个教师模型进行预训练;教师模型域适应模块基于相应的有标签源域和无标签的目标域,利用源域分类损失(SCL)、目标域分类损失(TCL)、最大均值差异(MMD)和伪标签三元组损失,对每个教师模型进行域适应;学生模型训练模块将多个单源域教师的知识转移给学生模型。
此外,在教师模型域适应模块中,为了提高伪标签三元组损失的有效性,我们采用了基于边际的抽样策略来过滤原始特征,只选择那些边际分数高于预设阈值的特征作为计算伪标签三元组损失的嵌入特征。
本发明有益效果如下:
本发明通过利用伪标签三联体损失,解决了多源域适应中最大平均差(MMD)技术的盲目估计问题。此外,还采用了基于不确定性测量的基于边际的抽样策略来提高其有效性;同时,引入了知识提炼技术,通过教授多个老师模型的知识训练一个更稳健的学生模型,以最大限度地利用多源域知识。通过在公开情绪数据集SEED上的实验验证,相较之前的方法该发明获得了显著的提升。
附图说明
图1为本发明流程示意图;
图2为本发现模型特征提取器的具体结构图;
图3为MS-KTF模型架构图;
图4为MS-KTF模型数据划分与构建图;
图5为教师模型结构图;
图6为学生模型结构图;
具体实施方式
下面结合附图对本发明的较佳实施例进行详细阐述,以使本发明的优点和特征能更易于被本领域技术人员理解,从而对本发明的保护范围做出更为清楚明确的界定。
多源领域适应(MSDA)旨在将多个源领域的知识转移到一个未标记的目标领域,这对于跨会话和跨受试者的脑电情感识别非常适用。然而,现有的MSDA模型仅仅考虑了源领域和目标领域之间每对特征关系的域级别,而很少考虑两个领域之间数据对级别的相关性,导致鲁棒性较差。
本发明公开了一种用于脑电情感识别的多源领域知识转移框架(MS-KTF)。首先,获取数据进行带通滤波,并使用独立成分分析(ICA)技术去除伪迹。其次,通过差分熵(DE)方法进行脑电特征提取,将三维脑电时间序列转换为二维样本矩阵。然后,在两种任务场景下各自划定训练集和测试集,确保它们不重合。对于这些样本,MS-KTF采用基于边际采样的伪标签三元组损失结合最大均值差异(MMD)。这种方法不仅可以在域级别上实现每对源领域和目标领域之间的无偏置对齐,还考虑了数据对级别的相关性。具体而言,该框架从不同的源领域中学习知识,以最大程度地利用多个单源模型,并以更少的时间消耗实现更强大的模型。最后利用分类准确率对模型进行两种任务场景下的性能评估。本发明将三元组损失和最大均值差异结合,一定程度上解决了脑电信号分布差异对齐不充分的问题,训练出了高精度的跨时段和跨被试的情绪状态分类器,具有时间复杂度小、计算效率高、泛化能力强等优势,以期在实际的脑机交互中有着广泛的应用前景。
本发明具体实现请参阅图1、图2、图3、图4、图5、图6,本发明实施方式包括如下步骤:
步骤S1.数据处理:
以情绪数据集为例进行分析,对脑电采集设备采集到的原始EEG数据的处理步骤如下:
S1-1:数据去噪
本发明用来验证模型性能的数据集来源于SEED,具体可参考论文《InvestigatingCritical Frequency Bands and Channels for EEG-Based Emotion Recognition withDeep Neural Networks》。首先对数据集中采集到的EEG原始信号降采样到200HZ,然后进行0.3-50HZ的带通滤波,最后采用ICA技术去除信号中的眼电伪迹。
S1-2:DE特征提取
对伪迹去除后的EEG数据进行DE特征提取,每个被试会观看15个能引起被试明显情绪变换的视频,把同一个视频播放时长内采集到EEG数据认为是一个情感试验,每个被试有15个情感试验。针对每个被试个体用1s的非重叠滑动窗口进行数据分割,得到3394个数据样本。对于每个样本xi,其中脑电数据采集通道数为62;提取了δ(1-3HZ)、θ(4-7HZ)、α(8-13HZ)、β(14-30HZ)和γ(31-50HZ)五个频段的频域特征。
步骤S2.数据定义和数据集划分:
情绪状态分类存在两种测试情境:跨被试和跨时段,这两种情况下的模型测试有着各自不同的数据定义和数据集划分,下面各自具体说明。
假定存在N个被试,每个被试进行了D次不同(会话)时段实验。整个样本集合表示为其中i表示被试序号,j表示会话(时段)序号,Xi表示被试i的样本集合,对应的标签集合为Yi
对于跨时段的情感状态分类任务,同样采用留一法对数据集进行交叉验证。具体来说,在每一个被试中,将最新一次会话中所有被试的15次情感试验的数据作为测试集;将其余D-1个会话以会话为单位,每个会话作为训练集中的一个源域,最终获得D-1个源域作为训练集。总共进行N次实验并计算平均准确率。
对于跨被试的情感状态分类任务,采用留一法对数据集进行交叉验证。具体来说,在一个会话(时段)中,迭代的取出一个被试的所有15次情感试验的数据并假设其情绪状态标签未知,作为测试集;从其余N-1个被试中,随机不重复的将R个被试一组作为训练集中的一个源域,最终获得(向下取整)个源域作为训练集。最后分别在N个被试的测试集上验证模型性能,总共进行D×N次实验并计算平均准确率。
步骤S3:MS-KTF模型的构建与训练
神经网络MS-KTF模型中的主要参数包括:
1)特征嵌入空间的维度大小df,即特征被特征提取器提取后进入嵌入空间的维度,与嵌入特征的表征能力密切相关。
2)进行边缘采样的阈值Threshold,即决定该特征是否应该被采样的关键参数,对获得分数高于阈值的特征进行采样。
3)进行蒸馏学习的温度Temperature,即进行softmax操作中的温度系数,主要是用来调整预测分布情况的平滑程度。
S3-1:模型的具体数据划分与输入
S3-1-1:模型的数据划分
模型的数据划分与构建如图4所示,具体的划分情况描述如下:
对于跨被试的情境:模型的目标域样本集合为UT={Xi},其中Xi表示第i个被试的特征数据集合;xj表示Xi中的第j个样本,n表示Xi中样本的总数;模型的多源域样本集合为/> 其中[N]\i表示去掉第i个被试数据后的所有的被试序号集合,Pj表示第j个源域中包含的被试序号集合(*所有的跨被试情境下的数据都来源于相同的会话)。
对于跨时段的情境:模型的目标域样本集合为其中Xi表示第i个被试的特征数据集合;j表示第j个会话(时段)。模型的多源域样本集合为其中[D]/j表示去掉第j个会话后的所有会话序号集合。
S3-1-2:模型的数据输入
如图3左半部分,即图4所示,每个有标签的源域样本集合US和无标签的目标域样本UT集合组成一个分支,用于后续的单源域教师模型的训练。而对于学生模型,则只使用无标签的目标域样本集合UT
S3-2:模型的初始化
如图3右半部分所示,MS-KFT由两部分组成:基于单源域的教师模型(图3右上角部分)和只作用于目标域的学生模型(图3右下角部分)。不论是教师模型还是学生模型,都由特定领域特征提取器Nf和标签分类器Ny这两个模块组成,其中特征提取器Nf的具体结构详见图2,标签分类器Ny由单层的线性层和softmax函数组成。
S3-3:单源域教师模型的预训练
如图5所示,为多个单源域教师模型的结构图,每个特定领域的特征提取器Nf和特定领域标签分类器Ny组成一个单源域教师模型。
基于多个源域样本集合US,分别使用单个有标签的源域样本集合来预训练一个领域特定教师模型的特征提取器Nf和标签分类器Ny,使之在各自的源域下有一定的模式识别能力(优化目标同下述公式(5)中的SCL,此处不予赘述)。
S3-4:单源域教师模型特征提取器的训练
在经过领域特定教师模型的特征提取器Nf后,对应的源领域数据US和目标领域数据UT将被提取出它们各自的低维特征FS,FT。为了确保提取出的特征的无偏域自适应性,本专利使用了领域级别分布对齐和数据对级别分布对齐两种方法。
S3-4-1:领域级别的分布对齐
对应图5中无偏移的分布对齐,本专利基于伪标签的三元组损失和最大均值差异两种技术,进行教师模型的域自适应。
最大均值差异(MMD)是一种概率度量空间中的距离度量,广泛被应用于机器学习和非参数测试中。该距离度量基于将概率嵌入到再生核希尔伯特空间(RKHS)的思想,它旨在减少源域和目标域之间的分布差异,同时保留它们的特定判别信息。在训练过程中,通过最小化MMD损失来减小特征空间中源域和目标域之间的距离,从而实现领域级别对齐。具体的公式如下所示:
其中和/>分别表示源领域的第i个样本和目标领域的第j个样本被提取出的低维特征;NS和NT表示源域样本和目标域样本的数量。
S3-4-2:数据对级别的分布对齐
由于MMD盲目地估计参数以考虑统计信息及其关系,因此特征可辨别性可能会降低,类内距离和类间距离之间的关系可能会受到影响,因为其中一个距离值下降而另一个距离值上升。三元组损失能够缩小类内距离,增大类间距离,是解决这个问题的一种方式,然而,在域适应中,目标领域通常是未标记的。因此,本专利使用基于边缘采样的三元组损失进行数据对级别的分布对齐。
本专利使用每个样本预测结果的边缘成绩,作为决定该样本是否被采样的依据,该方法可以用如下公式表示:
Xselected={xj|margin(xj)≥Threshold,xj∈X} (3)
其中x是输入样本,gθ是标签分类器的抽象函数,i*是预测结果中拥有最高预测概率的类别,k是所有类别的数量,[k]\i*表示除i*外所有的类别集合,Threshold是预先设定的边缘采样的阈值。
使用三元组损失需要以三元组的形式进行采样,其中/>(锚点样本)和/>(正例样本)是来自第i个三元组的相同类别不同样本,/>是第i个三元组中与锚点样本不同类别的任意样本。三元组损失的目的是保证正样本对/>嵌入特征之间的距离加上固定的边际值(margin)小于负样本对/>嵌入特征之间的距离。形式上,对于一个小批量样本集,三元组损失定义为:
其中N是Xselected中包含的样本数量,α是预先设定用于引导可分性的边际值,d(·)是计算正则化嵌入特征对之间欧拉距离的函数,fθ(·)是特征提取的抽象函数。
S3-5:单源域教师模型标签分类器的训练
本专利使用交叉熵(CE)损失作为标签分类器在源域和目标域的分类结果的评估指标,在源域中具体使用SCL作为分类损失,在目标域中具体使用TCL作为分类损失。
在源域中,具有真实标签,所以SCL使用真实标签和标签分类器的分类结果作为比对对象,具体的公式如下所示:
其中xi是第i个源域输入样本,是第i个源域输入样本的真实标签,/>是标签分类器对第i个源域输入样本的预测结果,fθ(·)是特征提取的抽象函数,gθ(·)是标签分类器的抽象函数。
在目标域中,样本缺乏真实标签,所以相对应的TCL使用生成的伪标签和标签分类器的分类结果作为比对对象,具体的公式如下所示:
其中xi是第i个目标域输入样本,是第i个目标域输入样本生成的伪标签,/>是标签分类器对第i个目标域输入样本的预测结果,fθ(·)是特征提取的抽象函数,gθ(·)是标签分类器的抽象函数。
S3-6:单源域教师模型的优化目标和训练
总结S3-4和3-5,在教师模型的域适应阶段,最终的优化目标如下公式所示:
其中β,γ,σ是用于平衡损失函数的权重因子。
使用随机梯度(SGD)优化器并结合mini-batch的训练方式,通过最小化公式(7)中的MMD损失和三元组损失可以在领域级别和数据对级别为每对源领域和目标领域获得域不变特征。在源域和目标域中最小化分类损失/>将获得更优越的分类器,在不牺牲对目标域样本判别能力的情况下准确地预测源域样本。
S3-7:学生模型的训练
具体学生模型的结构可见图6,在教师模型领域自适应后,获得了多个单源域模型,这些模型能够有效地提取对于分类任务是有区分度的但在不同领域之间是可转移的深层EEG模式表示。为了最大化地利用这些单源模型,知识蒸馏被用以传递从多源领域学习的知识并训练一个更加强大的学生模型。
为了更好的合并教师模型的知识,本专利使用基于投票的方式来选择待合并的教师模型知识,可以表示为如下公式:
其中xi是第i个输入样本,Nt是教师模型的数量,mode(·)是用于寻找众数/多重众数的函数,*是点乘函数,是第j个教师模型对第i个输入样本的预测结果,/>是用于生成第i个输入样本教师模型掩码的决策标签集合。
在获得多个单源域教师模型的合并知识后,使用Kullback-Leibler(KL)散度,评估教师模型预测结果和学生模型预测结果的差异,公式如下所示:
其中X是输入样本集合,merge是经过合并的教师知识集合,T是预设置的用于控制softmax函数平滑度的温度系数,KLD[p,q]是用于度量分布p和分布q之间KL散度的评估函数。
使用Adam优化器并结合mini-batch的训练方式,最小化公式(9)中的KL损失,使学生模型充分学习教师模型的合并知识,获得在目标域上更优越的表现。
S4:跨时段和跨被试两种情境下的模型性能评估:
本发明在SEED数据集和SEED-IV数据集上具体验证模型性能。
对收敛完成的学生模型在目标域上获得的预测结果ypred和目标域中的真实标签yT使用混淆矩阵进行对比并获得对比结果,评估模型性能。准确率为模型测试时正确分类的样本个数占总测试样本的个数,模型准确度的计算公式如下:
其中,TP为被模型预测为正类的正类样本,TN为被模型预测为负类的负类样本,FP为被模型预测为正类的负类样本,FN为被模型预测为负类的正类样本。SEED数据中包括15个被试,每个被试做三场试验,共进行45次试验。则15个被试前2次试验的平均准确率如下所示:
结果的均方差公式如下所示:
跨时段和跨被试两种情境下的数据集划分具体见S3-1-1。对于跨被试的情境,本发明提出的模型在15个被试共1次试验的EEG数据上进行测试;对于跨时段的情境,本发明提出的模型在1个测试共15个被试的EEG数据上进行测试。最终的测试结果与现有技术(SVM、DGCNN和RGNN)的对比情况如下表所示:
表格1在SEED数据集上的分类器性能比较情况表
分类器 DDC DAN MS-MDA 本发明
跨时段准确率 81.53/6.83 79.93/7.06 88.56/7.80 97.58/1.46
跨被试准确率 68.99/3.23 65.84/2.25 89.63/6.79 91.73/10.48
表格2在SEED-IV数据集上的分类器性能比较情况表
分类器 DDC DAN MS-MDA 本发明
跨时段准确率 57.63/11.28 55.14/12.79 61.43/15.71 77.70/13.49
跨被试准确率 37.71/6.36 32.44/9.02 59.34/5.48 74.19/15.84
从上表的结果可以看出,本发明提出的方法比DDC、DAN和MS-MDA在跨时段和跨被试的情况下都得到了更高的准确率。本发明不仅适用于情绪状态识别的研究,还适用于任何基于EEG的跨时段跨被试分类预测任务,一定程度上解决了脑电个体差异性问题。

Claims (10)

1.一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于,该方法以微分熵特征作为使用的EEG信号的频域特征,改进的EEGNet模型作为特征提取器,以单层的线性层作为分类器,对EEG信号进行分析,实现跨被试和跨时段两种情形下的情绪状态识别任务,具体过程包括:
I.基于每个有标签的源域对每个教师模型进行预训练;
II.基于相应的有标签源域和无标签的目标域,利用源域分类损失、目标域分类损失、最大均值差异和伪标签三元组损失,对每个教师模型进行域适应;其中,为提高伪标签三元组损失的有效性,采用基于边际的抽样策略来过滤原始特征,只选择边际分数高于预设阈值的特征作为计算伪标签三元组损失的嵌入特征;
III.将多个单源域教师的知识转移给学生模型。
2.根据权利要求1所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于,该方法中情绪状态分类存在两种测试情境:跨被试和跨时段,这两种情况下的模型测试有着各自不同的数据定义和数据集划分,具体数据定义和数据集划分如下:
假定存在N个被试,每个被试进行了D次不同会话;整个样本集合表示为其中i表示被试序号,j表示会话序号,Xi表示被试i的样本集合,对应的标签集合为Yi
对于跨时段的情感状态分类任务:采用留一法对数据集进行交叉验证;具体来说,在每一个被试i中,将最新一次会话中所有被试的15次情感试验的数据作为测试集;将其余D-1个会话以会话为单位,每个会话作为训练集中的一个源域,最终获得D-1个源域作为训练集;总共进行N次实验并计算平均准确率;
对于跨被试的情感状态分类任务:采用留一法对数据集进行交叉验证;具体来说,在一个会话中,迭代的取出一个被试的所有15次情感试验的数据并假设其情绪状态标签未知,作为测试集;从其余N-1个被试中,随机不重复的将R个被试组成一组作为训练集中的一个源域,最终获得个源域作为训练集,总共进行D×N次实验并计算平均准确率。
3.根据权利要求1或2所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于,该方法实现过程中MS-KTF模型的构建与训练如下:
S3-1:初始化;
MS-KFT模型由两部分组成:基于单源域的教师模型和作用于目标域的学生模型,教师模型和学生模型都由特定领域特征提取器Nf和标签分类器Ny这两个模块组成;将基于多源域的多个单源域教师模型和一个目标域学生的参数进行初始化操作;
S3-2:预训练多个单源域教师模型;
基于多个源域样本集合,分别使用单个有标签的源域样本来预训练一个领域特定教师模型的特征提取器Nf和标签分类器Ny,使之在各自的源域下有一定的模式识别能力;
S3-3:对多个单源域教师模型的特征提取器进行域适应;
每个有标签源域样本和无标签目标域样本形成一个分支,在每个分支中,使用对应领域特定教师模型的特征提取器Nf对各自源域样本和目标域样本进行特征提取,将特征从原始的特征空间提取到嵌入空间中;
然后在特征空间中基于最大均值差异对嵌入特征进行领域层面的对齐;基于边缘采样的伪标签三元组损失对嵌入特征进行数据组层面的对齐;
通过最小化最大均值差异和伪标签三元组损失,训练多个单源域教师模型的特征提取器Nf,使其能够在源域和目标域进行域不变特征提取;
S3-4:训练多个单源域教师模型的标签分类器Ny
在各个单源域教师模型中,将提取出来的源域特征信息通过标签分类器Ny获得预测情绪并计算预测情绪/>与实际样本中对应标签YS的交叉熵;同样,计算目标域特征信息的预测情绪/>与其所生成的伪标签/>的交叉熵;
通过最小化两个获得的交叉熵,训练多个单元域教师模型的标签分类器Ny,使其在各自的源域和目标域拥有良好的情绪分类能力;
S3-5:合并多个单源域教师模型的知识;
S3-6:将教师模型的合并知识教导给学生模型。
4.根据权利要求3所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于针对教师模型性能之间平衡性,使用两种不同的合并策略:
①基于投票的形式来筛选合并的教师模型;
这种方式更适用于教师性能平衡性较差的情况;使用无标签的目标域样本,首先通过教师模型的特征提取器和标签分类器获得的情绪预测结果生成对应的独热编码结果/>基于各个教师生成的独热编码结果/>进行投票,将投票结果作为决策变量如果该教师模型的情绪预测结果/>与决策变量/>相同,则该教师被选中进行知识合并;
计算所有被选中的教师模型的情绪预测结果的平均值作为所有教师模型的合并知识;
②使用平均的形式来合并教师模型;
这种方式更适用于教师性能平衡性较强的情况;这种情况下所有的教师模型均拥有同样的权重,计算所有被选中的教师模型的预测结果的平均值作为所有教师模型的合并知识。
5.根据权利要求3所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于所述的将教师模型的合并知识教导给学生模型,具体如下:
使用无标签的目标域样本数据,通过学生模型的特征提取器和标签分类器,得到学生模型的预测结果
基于预设定的蒸馏温度Temperature,将教师模型的合并知识和学生模型的预测结果/>进行平滑化处理;使用KL散度评估两个预测结果之间的差异;
通过最小化教师模型与学生模型的KL散度使学生模型学习教师模型的知识,获得在目标域相较于教师更具泛化性的特征提取和标签分类能力。
6.根据权利要求3或4或5所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于所述初始化具体实现如下:
S3-1-1:模型的数据划分,具体的划分情况描述如下:
对于跨被试的情境:模型的目标域样本集合为UT={Xi},其中Xi表示第i个被试的特征数据集合;xj表示Xi中的第j个样本,n表示Xi中样本的总数;模型的多源域样本集合为/> 其中[N]\i表示去掉第i个被试数据后的所有的被试序号集合,Pj表示第j个源域中包含的被试序号集合;
对于跨时段的情境:模型的目标域样本集合为其中Xi表示第i个被试的特征数据集合;j表示第j个会话(时段);模型的多源域样本集合为其中p]/j表示去掉第j个会话后的所有会话序号集合;
S3-1-2:模型的数据输入
每个有标签的源域样本集合US和无标签的目标域样本UT集合组成一个分支,用于后续的单源域教师模型的训练;而对于学生模型,则只使用无标签的目标域样本集合UT
7.根据权利要求3或4或5所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于对多个单源域教师模型的特征提取器进行域适应,具体如下:
在经过领域特定教师模型的特征提取器Nf后,对应的源领域数据US和目标领域数据UT将被提取出它们各自的低维特征FS、FT;为了确保提取出的特征的无偏域自适应性,使用了领域级别分布对齐和数据对级别分布对齐两种方法;
S3-4-1:领域级别的分布对齐
基于伪标签的三元组损失和最大均值差异两种技术,进行教师模型的域自适应;在训练过程中,通过最小化MMD损失来减小特征空间中源域和目标域之间的距离,从而实现领域级别对齐,具体的公式如下所示:
其中和/>分别表示源领域的第i个样本和目标领域的第j个样本被提取出的低维特征;NS和NT表示源域样本和目标域样本的数量;
S3-4-2:数据对级别的分布对齐
使用基于边缘采样的三元组损失进行数据对级别的分布对齐,使用每个样本预测结果的边缘成绩,作为决定该样本是否被采样的依据,该方法用如下公式表示:
Xselected={xj|margin(xj)≥Threshold,xj∈X} (3)
其中x是输入样本,gθ是标签分类器的抽象函数,i*是预测结果中拥有最高预测概率的类别,k是所有类别的数量,[k]\i*表示除i*外所有的类别集合,Threshold是预先设定的边缘采样的阈值;
使用三元组损失需要以三元组的形式进行采样,其中锚点样本/>和正例样本/>是来自第i个三元组的相同类别不同样本,/>是第i个三元组中与锚点样本/>不同类别的任意样本;三元组损失的目的是保证正样本对/>嵌入特征之间的距离加上固定的边际值margin小于负样本对/>嵌入特征之间的距离;对于一个小批量样本集,三元组损失定义为:
其中N是Xselected中包含的样本数量,α是预先设定用于引导可分性的边际值,d(·)是计算正则化嵌入特征对之间欧拉距离的函数,fθ(·)是特征提取的抽象函数。
8.根据权利要求3或4或5所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于使用交叉熵损失作为标签分类器在源域和目标域的分类结果的评估指标,在源域中具体使用SCL作为分类损失,在目标域中具体使用TCL作为分类损失;
在源域中,具有真实标签,所以SCL使用真实标签和标签分类器的分类结果作为比对对象,具体的公式如下所示:
其中xi是第i个源域输入样本,是第i个源域输入样本的真实标签,/>是标签分类器对第i个源域输入样本的预测结果,fθ(·)是特征提取的抽象函数,gθ(·)是标签分类器的抽象函数;
在目标域中,样本缺乏真实标签,所以相对应的TCL使用生成的伪标签和标签分类器的分类结果作为比对对象,具体的公式如下所示:
其中xi是第i个目标域输入样本,是第i个目标域输入样本生成的伪标签,/>是标签分类器对第i个目标域输入样本的预测结果,fθ(·)是特征提取的抽象函数,gθ(·)是标签分类器的抽象函数。
9.根据权利要求3或4或5所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于单源域教师模型的优化目标和训练,具体如下:
在教师模型的域适应阶段,最终的优化目标如下公式所示:
其中β,γ,σ是用于平衡损失函数的权重因子;
使用随机梯度优化器并结合mini-batch的训练方式,通过最小化公式(7)中的MMD损失和三元组损失在领域级别和数据对级别为每对源领域和目标领域获得域不变特征;在源域和目标域中最小化分类损失/>将获得更优越的分类器,在不牺牲对目标域样本判别能力的情况下准确地预测源域样本。
10.根据权利要求3或4或5所述的一种基于知识蒸馏的多源域适应的EEG情绪状态分类方法,其特征在于学生模型的训练具体如下:
在教师模型领域自适应后,获得了多个单源域模型,为了更好的合并教师模型的知识,使用基于投票的方式来选择待合并的教师模型知识,表示为如下公式:
其中xi是第i个输入样本,Nt是教师模型的数量,mode(·)是用于寻找众数/多重众数的函数,*是点乘函数,是第j个教师模型对第i个输入样本的预测结果,/>是用于生成第i个输入样本教师模型掩码的决策标签集合;
在获得多个单源域教师模型的合并知识后,使用KL散度,评估教师模型预测结果和学生模型预测结果的差异,公式如下所示:
其中X是输入样本集合,merge是经过合并的教师知识集合,T是预设置的用于控制softmax函数平滑度的温度系数,KLD[p,q]是用于度量分布p和分布q之间KL散度的评估函数;
使用Adam优化器并结合mini-batch的训练方式,最小化公式(9)中的KL损失,使学生模型充分学习教师模型的合并知识,获得在目标域上更优越的表现。
CN202310802378.2A 2023-06-30 2023-06-30 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法 Pending CN116821764A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310802378.2A CN116821764A (zh) 2023-06-30 2023-06-30 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310802378.2A CN116821764A (zh) 2023-06-30 2023-06-30 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法

Publications (1)

Publication Number Publication Date
CN116821764A true CN116821764A (zh) 2023-09-29

Family

ID=88123857

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310802378.2A Pending CN116821764A (zh) 2023-06-30 2023-06-30 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法

Country Status (1)

Country Link
CN (1) CN116821764A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117034124A (zh) * 2023-10-07 2023-11-10 中孚信息股份有限公司 基于小样本学习的恶意流量分类方法、系统、设备及介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117034124A (zh) * 2023-10-07 2023-11-10 中孚信息股份有限公司 基于小样本学习的恶意流量分类方法、系统、设备及介质
CN117034124B (zh) * 2023-10-07 2024-02-23 中孚信息股份有限公司 基于小样本学习的恶意流量分类方法、系统、设备及介质

Similar Documents

Publication Publication Date Title
Li et al. A novel neural network model based on cerebral hemispheric asymmetry for EEG emotion recognition.
CN106886792B (zh) 一种基于分层机制构建多分类器融合模型的脑电情感识别方法
Li et al. A novel transferability attention neural network model for EEG emotion recognition
Jin et al. EEG-based emotion recognition using domain adaptation network
Zhang et al. Cross-subject EEG-based emotion recognition with deep domain confusion
Mensch et al. Learning neural representations of human cognition across many fMRI studies
CN112800998B (zh) 融合注意力机制和dmcca的多模态情感识别方法及系统
CN113392733B (zh) 基于标签对齐的多源域自适应跨被试eeg认知状态评估方法
Ning et al. Cross-subject EEG emotion recognition using domain adaptive few-shot learning networks
Ma et al. Depersonalized cross-subject vigilance estimation with adversarial domain generalization
Singhal et al. Detection of alcoholism using EEG signals and a CNN-LSTM-ATTN network
Zhou et al. PR-PL: A novel transfer learning framework with prototypical representation based pairwise learning for EEG-based emotion recognition
CN116821764A (zh) 一种基于知识蒸馏的多源域适应的eeg情绪状态分类方法
Quan et al. EEG-based cross-subject emotion recognition using multi-source domain transfer learning
CN114239652A (zh) 基于聚类的对抗部分域适应跨被试eeg情绪识别方法
Li et al. Continuous dynamic gesture recognition using surface EMG signals based on blockchain-enabled internet of medical things
Jiang et al. Analytical comparison of two emotion classification models based on convolutional neural networks
CN113951883B (zh) 基于脑电信号情绪识别的性别差异性检测方法
CN114330559A (zh) 一种基于度量迁移学习的脑电信号识别方法
Fan et al. DC-tCNN: a deep model for EEG-based detection of dim targets
CN112084935B (zh) 一种基于扩充高质量脑电样本的情绪识别方法
CN117493955A (zh) 一种癫痫患者的脑电信号分类模型的训练方法
Liu et al. Automated Machine Learning for Epileptic Seizure Detection Based on EEG Signals.
CN115758118A (zh) 一种基于脑电互信息的多源流形嵌入特征选择方法
CN115969392A (zh) 基于张量化频空注意力域适应网络的跨时段脑纹识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination