发明内容
为了解决上述背景技术中存在的技术问题,本发明提供一种基于迭代自训练域适应的肌电手势识别方法及系统,即采用基于迭代自训练域适应的算法(Iterative Self-Training based DomainAdaptation,STDA),该方法仅仅利用目标域极少标记数据和少量未标记数据,实现源域向目标域的有效迁移。STDA主要由基于差异的域适应(Discrepancy-based DomainAdaptation,DDA)和伪标签迭代更新(Pseudo-label Iterative Update,PIU)两部分组成。DDA使用基于高斯核的距离约束将现有用户的数据和新用户的未标记数据进行对齐。PIU迭代地不断更新伪标签,以生成具有类别平衡的新用户的更准确的标记数据。
为了实现上述目的,本发明采用如下技术方案:
本发明的第一个方面提供一种基于迭代自训练域适应的肌电手势识别方法。
一种基于迭代自训练域适应的肌电手势识别方法,包括:
获取历史用户的肌电数据和新用户的肌电数据;
基于历史用户的肌电数据,提取源域时频特征,基于新用户的肌电数据,提取目标域时频特征;
采用基于差异的域适应方法,对齐源域时频特征和目标域时频特征;
基于对齐后的目标域时频特征,训练模型,判断模型是否达到迭代次数要求,若是,对新用户的肌电数据的目标域时频特征进行预测,得到手势识别结果;否则,给目标域时频特征无标签数据打上伪标签,选取伪标签数量少于一定值的标签类进行上采样以平衡所有类别,重复对齐源域时频特征和目标域时频特征至迭代次数判断的过程,直到达到设定的迭代次数。
进一步地,所述采用基于差异的域适应方法,对齐源域时频特征和目标域时频特征的过程包括:将源域时频特征和目标域时频特征映射到同一个共享空间,测量源域时频特征和目标域时频特征两个分布之间的距离;根据所述距离优化目标域时频特征。
更进一步地,所述共享空间为再生核希尔伯特空间。
更进一步地,所述测量源域时频特征和目标域时频特征两个分布之间的距离采用以下公式:
其中,φ表示一个将原始数据映射到H的函数,φ(x)=k(·,x),且k一般取高斯核函数。
进一步地,在所述训练模型的过程,还包括:训练含有少量标记数据的模型;利用训练好的模型预测未标记样本的类标签;使用阈值筛选出置信度满足条件的伪标签;采用有标记和伪标记的数据联合训练模型,直到模型收敛,得到训练好的模型。
更进一步地,所述模型收敛根据损失函数判断,所述损失函数为:
其中,Y表示实际的标签,表示预测输出,C表示分类的总数,pi,k表示第i个样本被预测为第k类的概率。
更进一步地,所述置信度满足条件的伪标签计算公式为:
其中,对于一个C分类问题中的每个样本,
本发明的第二个方面提供一种基于迭代自训练域适应的肌电手势识别系统。
一种基于迭代自训练域适应的肌电手势识别系统,包括:
数据获取模块,其被配置为:获取历史用户的肌电数据和新用户的肌电数据;
特征提取模块,其被配置为:基于历史用户的肌电数据,提取源域时频特征,基于新用户的肌电数据,提取目标域时频特征;
对齐模块,其被配置为:采用基于差异的域适应方法,对齐源域时频特征和目标域时频特征;
训练识别模块,其被配置为:基于对齐后的目标域时频特征,训练模型,判断模型是否达到迭代次数要求,若是,对新用户的肌电数据的目标域时频特征进行预测,得到手势识别结果;否则,给目标域时频特征无标签数据打上伪标签,选取伪标签数量少于一定值的标签类进行上采样以平衡所有类别,重复对齐源域时频特征和目标域时频特征至迭代次数判断的过程,直到达到设定的迭代次数。
本发明的第三个方面提供一种计算机可读存储介质。
一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上述第一个方面所述的基于迭代自训练域适应的肌电手势识别方法中的步骤。
本发明的第四个方面提供一种计算机设备。
一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述第一个方面所述的基于迭代自训练域适应的肌电手势识别方法中的步骤。
与现有技术相比,本发明的有益效果是:
本发明提供了基于迭代自训练域适应的肌电手势识别方法及系统,采用基于迭代自训练的域适应方法STDA。STDA主要由基于差异的域适应(DDA)和伪标签迭代更新(PIU)两部分组成。DDA使用基于高斯核的距离约束将现有用户的数据和新用户的未标记数据进行对齐。PIU迭代地不断更新伪标签,以生成具有类别平衡的新用户的更准确的标签数据。由自训练生成的伪标签来监督特征解耦过程,以实现源域到目标域的有效迁移,从而提高肌电手势识别的准确性。
具体实施方式
下面结合附图与实施例对本发明作进一步说明。
应该指出,以下详细说明都是例示性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,当在本说明书中使用术语“包含”和/或“包括”时,其指明存在特征、步骤、操作、器件、组件和/或它们的组合。
需要注意的是,附图中的流程图和框图示出了根据本公开的各种实施例的方法和系统的可能实现的体系架构、功能和操作。应当注意,流程图或框图中的每个方框可以代表一个模块、程序段、或代码的一部分,所述模块、程序段、或代码的一部分可以包括一个或多个用于实现各个实施例中所规定的逻辑功能的可执行指令。也应当注意,在有些作为备选的实现中,方框中所标注的功能也可以按照不同于附图中所标注的顺序发生。例如,两个接连地表示的方框实际上可以基本并行地执行,或者它们有时也可以按照相反的顺序执行,这取决于所涉及的功能。同样应当注意的是,流程图和/或框图中的每个方框、以及流程图和/或框图中的方框的组合,可以使用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以使用专用硬件与计算机指令的组合来实现。
实施例一
本实施例提供了一种基于迭代自训练域适应的肌电手势识别方法,本实施例以该方法应用于服务器进行举例说明,可以理解的是,该方法也可以应用于终端,还可以应用于包括终端和服务器和系统,并通过终端和服务器的交互实现。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务器、云通信、中间件服务、域名服务、安全服务CDN、以及大数据和人工智能平台等基础云计算服务的云服务器。终端可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。本实施例中,该方法包括以下步骤:
获取历史用户的肌电数据和新用户的肌电数据;
基于历史用户的肌电数据,提取源域时频特征,基于新用户的肌电数据,提取目标域时频特征;
采用基于差异的域适应方法,对齐源域时频特征和目标域时频特征;
基于对齐后的目标域时频特征,训练模型,判断模型是否达到迭代次数要求,若是,对新用户的肌电数据的目标域时频特征进行预测,得到手势识别结果;否则,给目标域时频特征无标签数据打上伪标签,选取伪标签数量少于一定值的标签类进行上采样以平衡所有类别,重复对齐源域时频特征和目标域时频特征至迭代次数判断的过程,直到达到设定的迭代次数。
基于迭代自训练域适应的肌电手势识别方法的流程图,如图1所示:
1)开始
2)提取源域时频特征和目标域时频特征
3)判断是否满足预训练轮数,是则4),否则5)
4)对齐源域和目标域特征
5)训练分类器
6)判断是否达到设定训练轮数,是则7),否则8)
7)对目标域测试数据进行预测,转10)
8)给目标域无标签数据打上伪标签
9)少数类进行上采样以平衡所有类别,转4)
10)结束。
1、问题定义
已有用户的数据构成源域其中Ns表示源域中样本的数量,表示源域中第i个样本,表示源域中第i个样本对应的标签。目标用户的数据构成目标域其中表示目标域中的少量标记数据(对于一个C(C∈N+)分类问题,它只有C个样本),表示目标域中相对大量的未标记数据,表示目标域的测试数据,且满足其中Nl表示目标域中标记样本的数量。其中Nu表示目标域中未标记样本的数量。其中Nte表示目标域中测试样本的数量。在跨用户场景中,存在域偏移,即源域和目标域的分布不一致以上的xs,xt是四维的数据向量,ys,yt是类别标签。跨用户肌电手势识别的目标是使用现有用户的数据在目标域上的少量的标记数据和相对大量的未标记数据获得准确的测试数据的标签,即
2、基于迭代自训练域适应的方法
基于迭代自训练域适应的方法(Iterative Self-Training based DomainAdaptation,STDA)是一种域适应方法,通过迭代自训练生成的伪标签监督特征解耦的过程,实现源域到目标域的有效迁移。STDA方法的总体框架图如图2所示,由基于差异的域适应(Discrepancy-based DomainAdaptation,DDA)和伪标签迭代更新(Pseudo-labelIterative Update,PIU)两部分组成,分别用于对齐特征和生成更准确的目标域标签。
DDA使用最大均值差异(Maximum Mean Discrepancy,MMD)作为距离度量,将源域和目标域特征在再生核希尔伯特空间(Reproducing Kernel Hilbert Space,RKHS)进行对齐。PIU主要由两部分组成,一通过迭代自训练不断生成更准确的伪标签,二通过数据层面的上采样使得生成的伪标签更平衡。STDA的基本思路是通过不断迭代生成的更准确的目标域伪标签来监督特征解耦的过程以实现源域到目标域的有效迁移。
(1)基于差异的域适应
源域表示为目标域表示为源域和目标域的输入空间和标签空间相同,但两者的概率分布不一致,即PS≠Pt。特征对齐的目的是学习一个好的映射(f),同时将源域和目标域映射到一个共享空间(H),使得两者之间的距离相对接近基于最大均值差异(Maximum Mean Discrepancy,MMD)的核学习方法有效地度量了分布间的差异。最大均值差异测量的是再生核希尔伯特空间(Reproducing Kernel Hilbert Space,RKHS)中两个分布之间的距离。两个分布Ps和Pt之间的距离可以定义为:
其中,表示在RKHS下的一个函数集,Ex~.表示在源域或目标域下的期望。当源域和目标域的分布接近时,距离D接近于0。源域和目标域之间的MMD可以计算为:
其中,φ表示一个将原始数据映射到H的函数,φ(x)=k(·,x),且k一般取高斯核函数。在这个高维空间中,源域和目标域样本之间的分布差异可以用这个距离来测量。因此,MMD距离通常被认为是一种损失,嵌入到深度学习的网络高层,然后进行优化。
(2)伪标签迭代更新
伪标签迭代更新主要分为两个部分,一是迭代自训练,二是类别再平衡。本实施例中的迭代自训练主要分为如下四步:1)训练一个含有少量标记数据的模型;2)利用训练好的模型来预测未标记样本的类标签;3)使用阈值筛选出置信度满足条件的伪标签;4)用有标记和伪标记的数据联合训练模型,并重复1)-4),直到模型收敛。对于多分类问题,使用交叉熵,计算公式如下:
其中,Y表示实际的标签,表示预测输出,C表示分类的总数,pi,k表示第i个样本被预测为第k类的概率。对于一个C分类问题中的每个样本,基于softmax置信度的伪标签计算如下:
本实施例中的类别再平衡是为了防止不平衡类别样本对模型学习的误导,采用了数据级方法过采样,使少数类别与大多数类别样本具有可比性。为了使类别达到平衡,定义了平衡损失。假设一组手势类别的数量是x1、x2、···、xn,则平衡损失的计算公式如下:
其中,所述模型如图2所示,模型主要包括两个部分,基于差异的域适应(DDA)和伪标签迭代更新(PIU)。首先,将源域已有用户的肌电数据提取时频特征,将目标域新用户也提取时频特征;其次,将源域和目标域通过卷积神经网络(CNN)和自注意(Self-attention)提取特征;然后使用源域特征和目标域特征的最大均值差异(MMD)距离作为模型优化的损失,同时用目标域模型给目标域未标记数据打上伪标签,并且采用上采样平衡(CategoryRebalance)目标域上的所有类;迭代直到模型收敛。
3、实验评估
(1)数据集
表1常用公开数据集
NinaPro:NinaPro DB-1和DB-5是由稀疏电极感知设备采集的肌电数据,是肌电手势识别最常用的公开数据集之一,目前NinaPro数据集发展到DB10,其中包含了健康受试者和残肢患者的数据。DB-1子数据集包括27名健康的受试者,总共有52个手势,其中基本手指动作12个,基本手部动作8个,基本手腕动作9个,抓握和功能动作23个。DB1数据集的表面肌电数据采用10通道(差分电极)的OttoBock MyoBock采集,采样频率为100Hz。DB-5包括10名受试者,52个手势,16个通道,采样频率为200Hz。
CapgMyo:CapgMyo是由高密度电极阵列感知设备采集的肌电数据,是肌电手势识别最常用的公开数据集之一,CapgMyo是浙江大学耿卫东教授团队采用自研的设备采集的23名健康受试者的128通道的高密度表面肌电数据,采样频率为1000Hz,包含3个子数据集,DB-a,DB-b和DB-c。DB-a包含18名受试者的8种手指手势,DB-b包含10名受试者在2个不同时间段采集的8种手势,DB-c包含10名受试者的12种基本手指手势。
CapgMyo中的8种手势、12种手势分别如图3和图4。
(2)对比方法
为了验证STDA方法的效果,本实施例选择七种方法作为对比方法,包括:
·STDA的一种变体,它只使用源域来训练模型(Only-Source);
●STDA的一种变体,它只使用目标域来训练模型(Only-Target);
●首先分解通道,其次再融合特征的基于微调的方法(Multi-Stream’);
●一种双流监督域适应框架(MDSDA);
●一种基于核空间距离的域自适应方法(SGAS);
●一种具有伪组对比机制的无监督域适应方法(Self-Tuning);
●一种循环自训练领域适应方法(CST)。
其中,Only-Source和Only-Target是两种基准方法,Multi-Stream’是一种基于深度学习的微调方法,MDSDA和SGAS是两种监督域适应方法,Self-Tuning和CST是两种无监督域适应方法。实验环境为Linux 125GB,开发环境为Python3.8.3,主要依赖的第三方库为pytorch 1.10.2+cu113。实验过程中,主要参数设置如下,置信度参数设为0.99,预训练轮数设置为400,学习率设置为0.001。
(3)对比实验结果
表2对比实验结果
对比实验结果如表2所示。根据表2所示实验结果,在公开基准肌电手势识别数据集上,本实施例提出的STDA方法明显优于其他方法。与基线方法相比,改善了25%以上;与微调方法相比,除DB-1数据集外,在其余数据集上提高了8%以上;与监督域适应方法相比,提高了5%以上;与无监督域适应方法相比,提高了24%以上。
(4)混淆矩阵分析
混淆矩阵分析结果如图5(a)-图5(e)所示。从DB-1和DB-5数据集上的混淆矩阵中可以看出,DB-1上的第28个和第51个手势比其他手势获得更高的准确性;DB-5上的第1、4、16、24、27个手势比其他手势具有更高的准确率,说明在手势识别系统的构建中,复杂的手势集设计也至关重要。此外,一些手势也很容易被混淆。例如,在DB-1数据集上,第50个手势有大约24%的概率被误判为第49个手势。同样地,第14个手势也很容易被错误地判为第13个手势。在DB-5数据集上也出现了类似的现象。例如,第8个手势很容易被误判为第10个手势,而第13个手势很容易被误判为第14个手势。假阳性率高达30%。在CapgMyo的三个子数据集DB-a、DB-b和DB-c上也得到了类似的结论。
从CapgMyo数据集的两个八分类数据集DB-a和DB-b可以看出,在DB-a中第五个手势准确率最高,为71.6%,在DB-b中第五个手势的准确率达到第二高,非常接近DB-b中81.7%的最高准确率。同时,第三个手势在DB-a和DB-b中的准确率最低,DB-a上准确率为35.2%,DB-b上准确率为57.5%。一个可能的原因是,第五个手势与其他七个手势相对不同,而第三个手势与其余七个手势中的几个非常相似。
(5)消融实验分析
表3消融实验结果
消融实验结果如表3所示。为了探究各部分的贡献,我们进行了消融实验。SDTA方法主要由两个模块组成。特征空间的对齐简称为mmd,伪标签的迭代更新简称为self-traing。总体来说,将特征对齐和自训练相结合的STDA方法具有最好的性能,证明了该方法的有效性。同时,实验结果表明,在DB-5、DB-a、DB-b和DB-c数据集上,自训练策略的贡献最大。在DB-1数据集上,主要的贡献是特征空间的对齐。
(6)参数敏感性分析
参数敏感性分析结果如图6(a)-图6(b)所示。SDTA方法主要对两个参数敏感,即预训练的轮数(简称“epoch”)和迭代自训练的置信度阈值(简称“thres”)。为了评估参数对STDA方法性能的影响,我们使用了一种单变量方法,换句话说,就是改变一个变量,同时保持另一个变量不变。参数“epoch”的范围设置为{50、100、200、400、600、800};参数“thres”的范围设置为{0.7、0.8、0.9、0.95、0.99}。图中三角形是最优值。可以看出,当“epoch”为100时,STDA方法的性能最好,当“epoch”为600时,STDA方法的性能最差。这表明,适当的预训练有利于模型学习,过度的预训练可能导致模型的过拟合。当“thres”为0.95时,STDA方法的性能最好,当“thres”为0.8时,STDA方法的性能最差。这说明,过低的置信度会导致产生大量错误标签,误导模型的学习,过高的置信度也会对模型产生不利影响。
实施例二
本实施例提供了一种基于迭代自训练域适应的肌电手势识别系统。
一种基于迭代自训练域适应的肌电手势识别系统,包括:
数据获取模块,其被配置为:获取历史用户的肌电数据和新用户的肌电数据;
特征提取模块,其被配置为:基于历史用户的肌电数据,提取源域时频特征,基于新用户的肌电数据,提取目标域时频特征;
对齐模块,其被配置为:采用基于差异的域适应方法,对齐源域时频特征和目标域时频特征;
训练识别模块,其被配置为:基于对齐后的目标域时频特征,训练模型,判断模型是否达到迭代次数要求,若是,对新用户的肌电数据的目标域时频特征进行预测,得到手势识别结果;否则,给目标域时频特征无标签数据打上伪标签,选取伪标签数量少于一定值的标签类进行上采样以平衡所有类别,重复对齐源域时频特征和目标域时频特征至迭代次数判断的过程,直到达到设定的迭代次数。
此处需要说明的是,上述数据获取模块、特征提取模块、对齐模块和训练识别模块与实施例一中的步骤所实现的示例和应用场景相同,但不限于上述实施例一所公开的内容。需要说明的是,上述模块作为系统的一部分可以在诸如一组计算机可执行指令的计算机系统中执行。
实施例三
本实施例提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上述实施例一所述的基于迭代自训练域适应的肌电手势识别方法中的步骤。
实施例四
本实施例提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述实施例一所述的基于迭代自训练域适应的肌电手势识别方法中的步骤。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用硬件实施例、软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器和光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存储记忆体(RandomAccessMemory,RAM)等。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。