CN112786030A - 一种基于元学习的对抗采样训练方法及装置 - Google Patents

一种基于元学习的对抗采样训练方法及装置 Download PDF

Info

Publication number
CN112786030A
CN112786030A CN202011642701.7A CN202011642701A CN112786030A CN 112786030 A CN112786030 A CN 112786030A CN 202011642701 A CN202011642701 A CN 202011642701A CN 112786030 A CN112786030 A CN 112786030A
Authority
CN
China
Prior art keywords
training
query
sampling
language
task
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202011642701.7A
Other languages
English (en)
Other versions
CN112786030B (zh
Inventor
肖雨蓓
郑国林
聂琳
梁小丹
王青
林倞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sun Yat Sen University
Original Assignee
Sun Yat Sen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sun Yat Sen University filed Critical Sun Yat Sen University
Priority to CN202011642701.7A priority Critical patent/CN112786030B/zh
Publication of CN112786030A publication Critical patent/CN112786030A/zh
Application granted granted Critical
Publication of CN112786030B publication Critical patent/CN112786030B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G10MUSICAL INSTRUMENTS; ACOUSTICS
    • G10LSPEECH ANALYSIS TECHNIQUES OR SPEECH SYNTHESIS; SPEECH RECOGNITION; SPEECH OR VOICE PROCESSING TECHNIQUES; SPEECH OR AUDIO CODING OR DECODING
    • G10L15/00Speech recognition
    • G10L15/08Speech classification or search
    • G10L15/16Speech classification or search using artificial neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/049Temporal neural networks, e.g. delay elements, oscillating neurons or pulsed inputs
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Human Computer Interaction (AREA)
  • Acoustics & Sound (AREA)
  • Multimedia (AREA)
  • Machine Translation (AREA)

Abstract

本发明公开了一种基于元学习的对抗采样训练方法及装置,所述方法:根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure DDA0002872045210000011
其中,
Figure DDA0002872045210000012
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure DDA0002872045210000013
所述查询集根据查询所述更新参数
Figure DDA0002872045210000014
的效果获得查询损失向量
Figure DDA0002872045210000015
所述查询损失向量
Figure DDA0002872045210000016
用于对所述初始化参数θ寻优,获得最优的模型参数。以多语种元学习语音识别框架为基础,引入策略网络形成对抗训练,解决解决低资源语种识别不均衡的问题,提升训练的效果。

Description

一种基于元学习的对抗采样训练方法及装置
技术领域
本发明涉及语音识别技术领域,尤其涉及一种基于元学习的对抗采样训练方法及装置。
背景技术
随着深度学习理论和相关技术的蓬勃发展,语音识别领域取得了巨大的进展。然而构造一个端到端的深层语音识别模型经常需要大量的有标注的数据,而这些数据对于许多低资源语种是非常难以获取的。为了解决上述问题,有许多工作利用无监督预训练和半监督学习的方法去利用大量无标注数据来帮助低资源目标语种提升识别效果,但是这些方法依然需要大量目标语种的无标注数据,对于部分小语种来说,无标注数据也是很少量的。
因此,迁移学习被引入解决低资源语种识别问题,迁移学习通过其他语种的数据来帮助目标低资源语种来提升识别效果。同时还有多语种迁移学习方法用多个其他源语种预训练模型初始化参数,于是只需要少量低资源目标语种数据在预训练模型基础上训练就可以得到较好的目标模型。但迁移学习的方法学习到的模型参数比较容易倾向于源语种而无法很好地进行迁移。除此之外,元学习的方法也被引入到低资源语音识别问题中。元学习的方法通过一系列训练任务来元学习得到模型初始化参数,以便能够快速地适应到只有少量数据的新任务上,这种方法十分适用于低资源的场景。
然而,现有应用低资源语音识别和低资源语码转换的语音识别中,都忽略了在真实场景中任务不均衡的问题,现有技术平等地利用每个语种的元信息,从而导致了效果的损失。
发明内容
本发明目的在于,提供一种基于元学习的对抗采样训练方法及装置,以多语种元学习语音识别框架为基础,引入策略网络形成对抗训练,提升多语种低资源语音识别训练的效果。
为实现上述目的,本发明实施例提供一种基于元学习的对抗采样训练方法,包括:
根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure BDA0002872045190000011
其中,
Figure BDA0002872045190000012
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure BDA0002872045190000021
所述查询集根据查询所述更新参数
Figure BDA0002872045190000022
的效果获得查询损失向量
Figure BDA0002872045190000023
所述查询损失向量
Figure BDA0002872045190000024
用于对所述初始化参数θ寻优,获得最优的模型参数。
优选地,所述根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure BDA0002872045190000025
其中,
Figure BDA0002872045190000026
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集,包括:
所述策略网络包括前馈注意力层和LSTM层;
所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。
优选地,所述查询集根据查询所述更新参数
Figure BDA0002872045190000027
的效果获得查询损失向量
Figure BDA0002872045190000028
所述查询损失向量
Figure BDA0002872045190000029
用于对所述初始化参数θ寻优,获得最优的模型参数,包括:
每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1
优选地,所述查询集根据查询所述更新参数
Figure BDA00028720451900000210
的效果获得查询损失向量
Figure BDA00028720451900000211
所述查询损失向量
Figure BDA00028720451900000212
用于对所述初始化参数θ寻优,获得最优的模型参数,还包括:
将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure BDA00028720451900000213
根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量
Figure BDA00028720451900000214
并用于对所述初始化参数θ寻优,获得最优的模型参数。
本发明实施例还提供了一种基于元学习的对抗采样训练装置,包括:
训练模块,根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure BDA0002872045190000031
其中,
Figure BDA0002872045190000032
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
第一更新模块,所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure BDA0002872045190000033
第二更新模块,所述查询集根据查询所述更新参数
Figure BDA0002872045190000034
的效果获得查询损失向量
Figure BDA0002872045190000035
所述查询损失向量
Figure BDA0002872045190000036
用于对所述初始化参数θ寻优,获得最优的模型参数。
优选地,所述训练模块,包括:
所述策略网络包括前馈注意力层和LSTM层;
所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。
优选地,所述第二更新模块,包括:
每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1
优选地,所述第二更新模块,还包括:
将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure BDA0002872045190000037
根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量
Figure BDA0002872045190000038
并用于对所述初始化参数θ寻优,获得最优的模型参数。
本发明实施例还提供一种计算机终端设备,包括一个或多个处理器和存储器。存储器与所述处理器耦接,用于存储一个或多个程序;当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如上述任一实施例所述的基于元学习的对抗采样训练方法。
本发明实施例还提供一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如上述任一实施例所述的基于元学习的对抗采样训练方法。
本发明实施例在元学习的基础上引入策略网络,策略网络由注意力机制与LSTM层融合而成,通过策略网络输出概率向量与元学习生成的查询损失向量进行更新迭代,寻找最优的模型参数,每一次训练任务中,语音识别网络朝着查询损失值尽可能小的方向优化,而策略网络则是朝着尽可能大的方向去优化,形成对抗训练,促进了语音识别网络的有效训练。
附图说明
为了更清楚地说明本发明的技术方案,下面将对实施方式中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明某一实施例提供的基于元学习的对抗采样训练方法的流程示意图;
图2是本发明另一实施例提供的基于元学习的对抗采样训练方法的流程示意图;
图3是本发明某一实施例提供的基于元学习的对抗采样训练装置的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,文中所使用的步骤编号仅是为了方便描述,不对作为对步骤执行先后顺序的限定。
应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
请参阅图1,本发明实施例提供一种基于元学习的对抗采样训练方法,包括:
S101、根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure BDA0002872045190000051
其中,
Figure BDA0002872045190000052
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
元学习的方法被引入到低资源语音识别问题中,通过一系列训练任务来元学习得到模型初始化参数,以便能够快速地适应到只有少量数据的新任务上,这类方法十分适用于低资源的场景。
基本的多语种元学习语音识别框架,采用K个源语种
Figure BDA0002872045190000053
对于第k个语种
Figure BDA0002872045190000054
通过采样部分数据
Figure BDA0002872045190000055
构成一个识别任务
Figure BDA0002872045190000056
然后进行划分
Figure BDA0002872045190000057
为支持集Dsupport和查询集Dquery,先用支持集Dsupport对语音识别模型初始化参数θ进行一次梯度下降得到特定任务更新参数
Figure BDA0002872045190000058
再用查询集Dquery来衡量该参数的效果,得到查询损失向量
Figure BDA0002872045190000059
采用
Figure BDA00028720451900000510
来指导整个模型参数θ的更新。在每个训练步中,都会采样出多个
Figure BDA00028720451900000511
并收集所有
Figure BDA00028720451900000512
来更新模型参数θ。那么θ就会更新到最优参数,使得任意任务Ti,只需在其Dsupport进行少量的更新,就可以在Dquery取得很好的效果。从而让预训练得到的参数θ拥有低资源迁移的能力。
在每一次训练中,需要从任务集
Figure BDA00028720451900000513
(其中
Figure BDA00028720451900000514
即第k个语种的任务集,nk为第k个语种任务集的任务总量)中采样M个任务Ti进行训练。而每种语言的任务数量nk不同,造成了任务数量不均衡,同时每个语种任务集中的任务识别难度不一致,也造成了任务难度的不均衡,以下统称这两种不均衡为多语种低资源语音识别中的任务不均衡问题(task-imbalance problem)。
策略网络(policy network)由前馈注意力层和LSTM层融合而成,在每一训练步进行采样时,策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,采样任务根据所述采样概率获得所述训练任务集,采样出最合适的任务,促进网络的学习,解决任务不均衡的问题,该策略网络可以和语音识别网络一起端到端地训练而无需额外的手动调节的参数。
在每一步的训练中,策略网络从K种不同语种任务集组成的大任务集T中采样M个语种任务集
Figure BDA00028720451900000515
并从这每个任务集
Figure BDA00028720451900000516
中每个采样一个任务
Figure BDA0002872045190000061
构成训练任务集
Figure BDA0002872045190000062
用该训练任务集划分查询集和支持集,其中,支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure BDA0002872045190000063
查询集根据查询更新参数
Figure BDA0002872045190000064
的效果获得查询损失向量
Figure BDA0002872045190000065
查询损失向量
Figure BDA0002872045190000066
用于对所述初始化参数θ寻优,获得最优的模型参数。在语音识别模型进行了元学习训练后,对于每一个训练任务
Figure BDA0002872045190000067
可以得到其对应的查询损失组成
Figure BDA0002872045190000068
语音识别网络朝着希望
Figure BDA0002872045190000069
尽可能小的方向去进行优化,而策略网络朝着希望
Figure BDA00028720451900000610
尽可能大的方向去进行优化。
具体的,每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1,将全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure BDA00028720451900000611
根据概率向量,选取前M个概率最大的语种,根据M个概率最大语种中每个语种采样一个任务构成训练任务集,获取查询损失向量
Figure BDA00028720451900000612
并用于对初始化参数θ寻优,获得最优的模型参数。
请参照图2,每一步的训练如下,包括但不限于以下迭代次数:
S-1、采用K维向量
Figure BDA00028720451900000613
定义当前步的查询损失值,策略网络输出一个K维向量
Figure BDA00028720451900000614
其中,初始的采样概率为每一维中值均等,以下步骤的采样概率为网络输出,初始LSTM层的隐藏状态hs-1为0向量。
S、将S-1步的查询损失值向量与采样概率向量输入当前的策略网络,并把两个向量合并起来计算前馈注意力,通过全连接层输出cs,之后当前的LSTM层以S-1层的隐藏状态hs-1和由全连接层输出的cs作为输入,生成当前LSTM层的输出ys与当前隐藏状态hs,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure BDA00028720451900000615
根据概率向量选择前M个最大概率的语种并采样其任务来训练,然后从每个语种任务集中采样一个任务作为当前步的训练任务集,那么M个任务就会产生M个新的查询损失值,采用这M个查询损失值更新S-1步的查询损失向量
Figure BDA0002872045190000071
以获得第S步新的查询损失向量
Figure BDA0002872045190000072
S+1、把S步中获得的查询损失向量
Figure BDA0002872045190000073
与概率向量
Figure BDA0002872045190000074
输入当前的策略网络,跟S步一样进行迭代更新,获得当前的概率向量
Figure BDA0002872045190000075
与查询损失向量
Figure BDA0002872045190000076
选取前M个概率最大的语种并进行采样其任务来进行训练。
本发明考虑到语种任务集的查询损失可以很好的衡量两种不均衡的问题,若某种语言的任务由于其任务的难度或者数量而导致其未被充分采样,那么其查询损失就会很大,就应该更多地被采样训练,而本发明引入策略网络与语音识别网络的结合,形成对抗网络,促进了语音识别网络更有效的训练。
S102、所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure BDA0002872045190000077
具体的,元学习可以将采样数据划分为支持集和查询集,其中,先采用支持集对语音识别模型初始化参数进行一次梯度下降得到特定任务更新参数,这个参数是为了查询集衡量更新效果。
S103、所述查询集根据查询所述更新参数
Figure BDA0002872045190000078
的效果获得查询损失向量
Figure BDA0002872045190000079
所述查询损失向量
Figure BDA00028720451900000710
用于对所述初始化参数θ寻优,获得最优的模型参数。
具体的,在采样M个语种的情况下,每一次训练中都会产生M个查询集,查询集通过支持集的更新参数形成查询损失向量,该查询损失向量通过对初始化参数θ寻优,获得最优的模型参数。
本发明采用joint attention-CTC语音识别模型作为语言识别基础模型,编码器包含6层的VGG特征提取和5层双向LSTM层,注意力层采用300维的Location-awareattention,解码器是320维的一层LSTM。在预测的时候,使用贪婪搜索解码来获得最佳的预测结果。采用80维的对数梅尔滤波器系数作为输入向量,谷歌的SentencePiece工具用来处理音频的文本。所有的文本用来基于BPE算法训练sub-word模型,文本通过训练好的sub-word模型处理成token序列。策略网络主要包含一个前馈注意力层和一层LSTM。采用Adam优化器,初始学习率为0.035来训练策略网络,对于采样任务数M,实验采用M的值为2,3,4,5,7,9,最终选择了3。在多个多语言数据集上进行了实验,实验结果均表明本发明取得了最优的效果。首先采用Common Voice数据集,该数据集是一个开源多语言数据集,其包含40多种语种。将其划分成三个不同的数据集Diversity11、Indo12还有Indo9。为了构建Diversity11,随机选择了11种不同地区的语种,每个语种的种类和数量都不一致,并将其划分为9种源语种和两种目标语种。这个Diversity11数据集模拟了真实场景下任务不均衡的情况。为了Indo12数据集,随机地从相同语系(印欧语系)中选择了11种语种,并且选择了一个亚非语系的语言作为目标语种。为了测试本发明在更少源语种上的效果,在Indo11中移去了三个语种(Russian,Swedish,and Welsh)构成了Indo9。除此之外,还在IARPA BABEL数据集上进行了实验,选取了六种源语种(Bengali,Tagalog,Zulu,Turkish,Lithuanian,Guarani),三种目标语种(Vietnamese,Swahili,Tamil),在Diversity11数据集上,对于所有目标语种,本发明都优于之前的低资源语音识别方法的效果。首先Monolingualtraining的方法由于没有其他语种的帮助,只用目标语种很少量的训练数据训练得到的模型效果很差。第二,由于元学习有低资源迁移学习的能力,相对于多语种迁移学习,元学习在其上能够有6%的WER下降。而再在其上使用本发明对抗训练采样方式,解决任务不均衡问题,使得语音识别模型充分有效训练,进一步将WER降低7%。除了在多语言元学习上有效果,本发明在多语言迁移学习上也有很好的效果。
请参照表1,Diversity11数据集上的WER效果
表1 Diversity11数据集上的WER效果
Figure BDA0002872045190000081
除了在Diversity11上表现很好的效果外,本发明在Indo12数据集上也表现出了最优的效果,说明了在任务不均衡情况没那么严重的情况下,本发明采样方法研究可以促进有效地训练和减缓任务难度不均衡的问题,使得在Indo12上效果最优。而在缩减了3个源语种的情况下,在Indo9数据集上,本发明的方法也还是能够帮助到目标语种的提升。除此之外,本发明还在BABEL公开数据集上进行了验证,BABEL数据集的任务不均衡问题相对来说也没有那么严重,不过本发明的方法依旧在其上取得了最好的效果。
现有的多语种低资源语音识别算法如多语言迁移学习、多语言元学习并没有考虑到多语种训练时的任务不均衡问题,普遍采用均匀采样的方式来采样任务进行训练。均匀采样方式既忽略了语种任务的数量不均衡又忽略了语种任务的难度不均衡,会导致学习的参数存在偏差,不能很快速很好地迁移到目标语种上。而按照任务数量比例进行采样的方式又忽略了语种的难度不均衡,很可能导致学习到的参数偏向于语种数量大的语种上,也不能很快速地迁移到目标语种上。
在本实施例中,采用对抗采样训练的方法以解决多语种低资源语音识别问题中的任务不均衡问题。引入了一个策略网络,该网络由注意力机制和LSTM融合构成。在每一训练步进行采样时,策略网络可以通过LSTM中存储的长短期记忆信息和当前查询损失自动地决定应该采样的任务,采样出最合适的任务,促进网络的学习,解决任务不均衡的问题。并且所采用的策略网络可以和语音识别网络一起端到端地训练而无需额外的手动调节的参数。
本发明考虑到语种任务集的查询损失可以很好地衡量两种不均衡,因为如果某种语言的任务由于其任务难度或者任务数量而导致其未被充分采样,那么其查询损失就会很大,就应该更多地被采样被训练。因此采用的策略网络是尽可能希望采样出查询损失大的任务。而语音识别网络是希望查询损失尽量小,因此形成了对抗训练,同时也更促进了语音识别网络的有效训练,采用一种对抗的训练方法来自适应地采样更好的任务进行学习,以提升多语种低资源语音识别训练的效果。
请参阅图3,本发明实施例提供一种基于元学习的对抗采样训练装置,包括:
训练模块11,根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure BDA0002872045190000091
其中,
Figure BDA0002872045190000092
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
第一更新模块12,所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure BDA0002872045190000093
第二更新模块13,所述查询集根据查询所述更新参数
Figure BDA0002872045190000094
的效果获得查询损失向量
Figure BDA0002872045190000095
所述查询损失向量
Figure BDA0002872045190000096
用于对所述初始化参数θ寻优,获得最优的模型参数。
关于基于元学习的对抗采样训练装置的具体限定可以参见上文中对于的限定,在此不再赘述。上述基于元学习的对抗采样训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
本发明实施例提供一种计算机终端设备,包括一个或多个处理器和存储器。存储器与所述处理器耦接,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如上述任意一个实施例中的基于元学习的对抗采样训练方法。
处理器用于控制该计算机终端设备的整体操作,以完成上述的基于元学习的对抗采样训练方法的全部或部分步骤。存储器用于存储各种类型的数据以支持在该计算机终端设备的操作,这些数据例如可以包括用于在该计算机终端设备上操作的任何应用程序或方法的指令,以及应用程序相关的数据。该存储器可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,例如静态随机存取存储器(Static Random Access Memory,简称SRAM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-OnlyMemory,简称EEPROM),可擦除可编程只读存储器(Erasable Programmable Read-OnlyMemory,简称EPROM),可编程只读存储器(Programmable Read-Only Memory,简称PROM),只读存储器(Read-Only Memory,简称ROM),磁存储器,快闪存储器,磁盘或光盘。
在一示例性实施例中,计算机终端设备可以被一个或多个应用专用集成电路(Application Specific 1ntegrated Circuit,简称AS1C)、数字信号处理器(DigitalSignal Processor,简称DSP)、数字信号处理设备(Digital Signal Processing Device,简称DSPD)、可编程逻辑器件(Programmable Logic Device,简称PLD)、现场可编程门阵列(Field Programmable Gate Array,简称FPGA)、控制器、微控制器、微处理器或其他电子元件实现,用于执行上述的基于元学习的对抗采样训练方法,并达到如上述方法一致的技术效果。
在另一示例性实施例中,还提供了一种包括程序指令的计算机可读存储介质,该程序指令被处理器执行时实现上述任意一个实施例中的基于元学习的对抗采样训练方法的步骤。例如,该计算机可读存储介质可以为上述包括程序指令的存储器,上述程序指令可由计算机终端设备的处理器执行以完成上述的基于元学习的对抗采样训练方法,并达到如上述方法一致的技术效果。
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也视为本发明的保护范围。

Claims (10)

1.一种基于元学习的对抗采样训练方法,其特征在于,包括:
根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure FDA0002872045180000011
其中,
Figure FDA0002872045180000012
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure FDA0002872045180000013
所述查询集根据查询所述更新参数
Figure FDA0002872045180000014
的效果获得查询损失向量
Figure FDA0002872045180000015
所述查询损失向量
Figure FDA0002872045180000016
用于对所述初始化参数θ寻优,获得最优的模型参数。
2.根据权利要求1所述的基于元学习的对抗采样训练方法,其特征在于,所述根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure FDA0002872045180000017
其中,
Figure FDA0002872045180000018
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集,包括:
所述策略网络包括前馈注意力层和LSTM层;
所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。
3.根据权利要求1所述的基于元学习的对抗采样训练方法,其特征在于,所述查询集根据查询所述更新参数
Figure FDA0002872045180000019
的效果获得查询损失向量
Figure FDA00028720451800000110
所述查询损失向量
Figure FDA00028720451800000111
用于对所述初始化参数θ寻优,获得最优的模型参数,包括:
每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1
4.根据权利要求3所述的基于元学习的对抗采样训练方法,其特征在于,所述查询集根据查询所述更新参数
Figure FDA0002872045180000021
的效果获得查询损失向量
Figure FDA0002872045180000022
所述查询损失向量
Figure FDA0002872045180000023
用于对所述初始化参数θ寻优,获得最优的模型参数,还包括:
将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure FDA0002872045180000024
根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量
Figure FDA0002872045180000025
并用于对所述初始化参数θ寻优,获得最优的模型参数。
5.一种基于元学习的对抗采样训练装置,其特征在于,包括:
训练模块,根据策略网络从K个语种构成的大任务集T中输出K维概率向量
Figure FDA0002872045180000026
其中,
Figure FDA0002872045180000027
为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
第一更新模块,所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
Figure FDA0002872045180000028
第二更新模块,所述查询集根据查询所述更新参数
Figure FDA0002872045180000029
的效果获得查询损失向量
Figure FDA00028720451800000210
所述查询损失向量
Figure FDA00028720451800000211
用于对所述初始化参数θ寻优,获得最优的模型参数。
6.根据权利要求5所述的基于元学习的对抗采样训练装置,其特征在于,所述训练模块,包括:
所述策略网络包括前馈注意力层和LSTM层;
所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。
7.根据权利要求5所述的基于元学习的对抗采样训练装置,其特征在于,所述第二更新模块,包括:
每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1
8.根据权利要求7所述的基于元学习的对抗采样训练装置,其特征在于,所述第二更新模块,还包括:
将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量
Figure FDA0002872045180000031
根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量
Figure FDA0002872045180000032
并用于对所述初始化参数θ寻优,获得最优的模型参数。
9.一种计算机终端设备,其特征在于,包括:
一个或多个处理器;
存储器,与所述处理器耦接,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如权利要求1至4任一项所述的基于元学习的对抗采样训练方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至4任一项所述的基于元学习的对抗采样训练方法。
CN202011642701.7A 2020-12-30 2020-12-30 一种基于元学习的对抗采样训练方法及装置 Active CN112786030B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011642701.7A CN112786030B (zh) 2020-12-30 2020-12-30 一种基于元学习的对抗采样训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011642701.7A CN112786030B (zh) 2020-12-30 2020-12-30 一种基于元学习的对抗采样训练方法及装置

Publications (2)

Publication Number Publication Date
CN112786030A true CN112786030A (zh) 2021-05-11
CN112786030B CN112786030B (zh) 2022-04-29

Family

ID=75755158

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011642701.7A Active CN112786030B (zh) 2020-12-30 2020-12-30 一种基于元学习的对抗采样训练方法及装置

Country Status (1)

Country Link
CN (1) CN112786030B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113627249A (zh) * 2021-07-05 2021-11-09 中山大学·深圳 基于对抗对比学习的导航系统训练方法、装置及导航系统
CN114399669A (zh) * 2022-03-25 2022-04-26 江苏智云天工科技有限公司 目标检测方法和装置
CN114743074A (zh) * 2022-06-13 2022-07-12 浙江华是科技股份有限公司 一种基于强弱对抗训练的船舶检测模型训练方法及系统
CN115730300A (zh) * 2022-12-12 2023-03-03 西南大学 基于混合式对抗元学习算法的程序安全模型构建方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109948648A (zh) * 2019-01-31 2019-06-28 中山大学 一种基于元对抗学习的多目标域适应迁移方法及系统
CN110852447A (zh) * 2019-11-15 2020-02-28 腾讯云计算(北京)有限责任公司 元学习方法和装置、初始化方法、计算设备和存储介质
WO2020158217A1 (ja) * 2019-02-01 2020-08-06 ソニー株式会社 情報処理装置、情報処理方法及び情報処理プログラム
CN111858991A (zh) * 2020-08-06 2020-10-30 南京大学 一种基于协方差度量的小样本学习算法
CN111881997A (zh) * 2020-08-03 2020-11-03 天津大学 一种基于显著性的多模态小样本学习方法
CN111898739A (zh) * 2020-07-30 2020-11-06 平安科技(深圳)有限公司 基于元学习的数据筛选模型构建方法、数据筛选方法、装置、计算机设备及存储介质

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109948648A (zh) * 2019-01-31 2019-06-28 中山大学 一种基于元对抗学习的多目标域适应迁移方法及系统
WO2020158217A1 (ja) * 2019-02-01 2020-08-06 ソニー株式会社 情報処理装置、情報処理方法及び情報処理プログラム
CN110852447A (zh) * 2019-11-15 2020-02-28 腾讯云计算(北京)有限责任公司 元学习方法和装置、初始化方法、计算设备和存储介质
CN111898739A (zh) * 2020-07-30 2020-11-06 平安科技(深圳)有限公司 基于元学习的数据筛选模型构建方法、数据筛选方法、装置、计算机设备及存储介质
CN111881997A (zh) * 2020-08-03 2020-11-03 天津大学 一种基于显著性的多模态小样本学习方法
CN111858991A (zh) * 2020-08-06 2020-10-30 南京大学 一种基于协方差度量的小样本学习算法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
NITHIN RAO KOLUGURI ET AL.: "META-LEARNING FOR ROBUST CHILD-ADULT CLASSIFICATION FROM SPEECH", 《ICASSP 2020》 *
王璐等: "基于元学习的语音识别探究", 《云南大学学报(自然科学版)》 *

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113627249A (zh) * 2021-07-05 2021-11-09 中山大学·深圳 基于对抗对比学习的导航系统训练方法、装置及导航系统
CN114399669A (zh) * 2022-03-25 2022-04-26 江苏智云天工科技有限公司 目标检测方法和装置
CN114743074A (zh) * 2022-06-13 2022-07-12 浙江华是科技股份有限公司 一种基于强弱对抗训练的船舶检测模型训练方法及系统
CN114743074B (zh) * 2022-06-13 2022-09-09 浙江华是科技股份有限公司 一种基于强弱对抗训练的船舶检测模型训练方法及系统
CN115730300A (zh) * 2022-12-12 2023-03-03 西南大学 基于混合式对抗元学习算法的程序安全模型构建方法

Also Published As

Publication number Publication date
CN112786030B (zh) 2022-04-29

Similar Documents

Publication Publication Date Title
CN112786030B (zh) 一种基于元学习的对抗采样训练方法及装置
CN109408731A (zh) 一种多目标推荐方法、多目标推荐模型生成方法以及装置
CN109635273A (zh) 文本关键词提取方法、装置、设备及存储介质
US20200265315A1 (en) Neural architecture search
CN111368514B (zh) 模型训练及古诗生成方法、古诗生成装置、设备和介质
JP2019537096A (ja) ニューラル機械翻訳システム
WO2018153806A1 (en) Training machine learning models
US11010664B2 (en) Augmenting neural networks with hierarchical external memory
JP2022050379A (ja) 意味検索方法、装置、電子機器、記憶媒体およびコンピュータプログラム
EP3563302A1 (en) Processing sequential data using recurrent neural networks
CN113901799A (zh) 模型训练、文本预测方法、装置、电子设备及介质
CN112925926B (zh) 多媒体推荐模型的训练方法、装置、服务器以及存储介质
CN111476038A (zh) 长文本生成方法、装置、计算机设备和存储介质
WO2020052061A1 (zh) 用于处理信息的方法和装置
CN113626610A (zh) 知识图谱嵌入方法、装置、计算机设备和存储介质
WO2022251719A1 (en) Granular neural network architecture search over low-level primitives
CN112182281A (zh) 一种音频推荐方法、装置及存储介质
JP2018084627A (ja) 言語モデル学習装置およびそのプログラム
CN112132281B (zh) 一种基于人工智能的模型训练方法、装置、服务器及介质
CN114328814A (zh) 文本摘要模型的训练方法、装置、电子设备及存储介质
CN115374252B (zh) 一种基于原生Bert架构的文本分级方法及装置
CN111507218A (zh) 语音与人脸图像的匹配方法、装置、存储介质及电子设备
CN115860009A (zh) 一种引入辅助样本进行对比学习的句子嵌入方法及系统
CN111797621B (zh) 一种术语替换方法及系统
CN113807106A (zh) 翻译模型的训练方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant