CN114648005A - 一种多任务联合学习的多片段机器阅读理解方法及装置 - Google Patents

一种多任务联合学习的多片段机器阅读理解方法及装置 Download PDF

Info

Publication number
CN114648005A
CN114648005A CN202210248775.5A CN202210248775A CN114648005A CN 114648005 A CN114648005 A CN 114648005A CN 202210248775 A CN202210248775 A CN 202210248775A CN 114648005 A CN114648005 A CN 114648005A
Authority
CN
China
Prior art keywords
segment
probability
answer
vector
fragment
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210248775.5A
Other languages
English (en)
Inventor
张虎
范越
王宇杰
李茹
梁吉业
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanxi University
Original Assignee
Shanxi University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanxi University filed Critical Shanxi University
Priority to CN202210248775.5A priority Critical patent/CN114648005A/zh
Publication of CN114648005A publication Critical patent/CN114648005A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/12Use of codes for handling textual entities
    • G06F40/126Character encoding
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/205Parsing
    • G06F40/216Parsing using statistical methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Machine Translation (AREA)

Abstract

本发明公开了一种多任务联合学习的多片段机器阅读理解方法及装置,属于自然语言处理技术领域。主要包括编码器模块、观点型和单片段抽取问题解答模块、多片段抽取问题解答模块、对抗学习模块。本发明是基于动态预测片段数量和序列标注的多任务联合学习,其中,动态预测片段数量可计算出每个问题所需的片段数量,基于此能较为准确地识别出多片段问题类型;而序列标注可以从输入文本中提取可变长度的片段,能够实现多个答案片段的有效定位。同时,在模型训练中构造了对抗训练方式,增强了模型的泛化能力。最后,本发明将多个任务进行联合优化学习,在多片段抽取问题解答以及观点型和单片段抽取的问题解答中取得了更好的效果。

Description

一种多任务联合学习的多片段机器阅读理解方法及装置
技术领域
本发明属于自然语言处理技术领域,具体涉及一种多任务联合学习的多片段机器阅读理解方法及装置。
背景技术
机器阅读理解(Machine Reading Comprehension,MRC)是使计算机理解文章语义并且回答相关问题的技术,是自然语言处理(Natural Language Processing,NLP)领域的一项重要研究任务,在搜索引擎、智能客服、智慧法律等应用领域具有重要作用。
近年来,随着机器学习特别是深度学习技术的快速发展,片段抽取式MRC任务的实验结果取得了较大提升,在SQuAD、DuReader等数据集中接近甚至超过人类表现。然而现有的片段抽取式阅读理解模型和相关数据集仍存在一定的不足,其答案通常限定为阅读材料中的某一个片段,这使得机器阅读理解在真实场景中的应用受到影响。目前,实际应用中的很多阅读理解问题的答案是由文本中多处不连续的片段组合而成,因此,多片段抽取阅读理解的研究可以扩大机器阅读理解的场景适用性。
多片段问题类型的答案是由文章中多处不连续的片段组成,需要模型在深入理解文章的基础上准确识别出多片段问题类型并且有效定位多个答案片段,这对模型提出了更高的要求,相关研究人员针对该任务开展了深入研究。MTMSN建立一种能够动态提取一个或多个片段的阅读理解模型,该模型首先预测答案数量,然后采用非极大值抑制算法(Non-maximum suppression,NMS)提取特定数量的非重叠片段。TASE提出一种简单的架构,尝试将研究任务转换为序列标记问题来回答多片段问题。该模型利用序列标注任务的特性,能够同时考虑预测答案及答案数,对于片段数较多的问题可以有效提取出相应片段。
已有大多数面向多片段抽取式阅读理解方法主要采用基于序列标注的多片段抽取方法,利用序列标注能够提取可变长度的跨度的特性,可提取一个或者多个片段。但基于该方法的多片段抽取方法在提取答案片段时,往往会导致其他类型问题(如观点型或单片段抽取型)也给出多个答案片段。
发明内容
针对目前多片段抽取式阅读理解方法给出答案不准确的问题,本发明提供了一种多任务联合学习的多片段机器阅读理解方法及装置。
为了达到上述目的,本发明采用了下列技术方案:
一种多任务联合学习的多片段机器阅读理解方法,包括以下步骤:
步骤1:通过编码器模块对问题和文章进行编码,得到问题和文章中每个字或词的向量表示;
步骤2:利用观点型和单片段抽取问题解答模块解答观点型和单片段抽取问题;
步骤3:利用多片段抽取类问题解答模块解答多片段问题,该模块首先通过预测问题的片段数量识别出多片段问题类型,再利用序列标注提取多个答案片段;
步骤4:利用对抗学习模块在模型训练中构造对抗训练方式,增强模型的鲁棒性和泛化能力。
进一步,所述步骤1通过编码器模块对问题和文章进行编码,得到问题和文章中每个字或词的向量表示的具体方法是:
采用预训练语言模型MacBERT作为编码器,对问题和文章进行编码,得到问题和文章中每个字或词的向量表示,计算公式如下所示:
input=[CLS]+question+[SEP]+context+[SEP]
Figure BDA0003546137900000031
其中,input表示模型的输入,question表示问题,context表示文章,[CLS]表示开始位置,[SEP]表示分隔符;Hi∈RS×D表示文章和问题的向量表示,S表示输入的序列长度512;D表示隐藏层维度,base版本是768,large版本为1024;L表示MacBERT的层数,base版本为12层,large为24层,R表示该向量所属的向量空间。
进一步,所述步骤2利用观点型和单片段抽取问题解答模块解答观点型和单片段抽取问题的具体方法是:
观点型问题包括“YES/NO”(是否)类问题和“Unknown”(不可回答)类问题;
其中“YES/NO”类问题:将通过步骤1编码器模块得到的文章和问题的向量其最后四层作为上下文向量u,然后对上下文向量u作自注意力计算,再通过W1∈R4D×2的全连接层进行二分类,W1表示该全连接层的可训练参数,得到问题答案是YES/NO的概率用pyes,pno来表示,具体计算过程如下:
u=Concat(HL-3,HL-2,HL-1,HL)
u'=SelfAttention(u)
{pyes,pno}=FFN(u')
其中,Concat表示向量拼接函数,HL表示BERT的最后一层向量,u∈RS×4D表示拼接BERT最后四层的向量,FFN表示全连接层,SelfAttention的计算过程如下:
α=u·W2+b,α∈RS
α'=softmax(α)
SelfAttention(u)=α'·u
其中,α是向量u经过线性运算得到的权重,W2是可学习的参数,偏置b∈RS,α'表示经过softmax归一化之后的权重,softmax表示softmax函数;
其中“Unknown”类问题:将BERT的最后一层向量HL经过最大池化得到一个向量表示,然后经过参数为W3的全连接层得到Unknown的答案概率punk,具体公式如下所示:
h=HL,h'=MaxPooling(h)
{punk}=FFN(h');
其中,h表示MacBERT的最后一层向量,h'表示经过池化后得到的向量,MaxPooling表示最大池化函数;
单片段抽取类问题:需要模型深入理解文章和问题,并在文章中标注出正确答案的开始位置和结束位置;
将上下文向量u经过参数为W4的全连接层进行二分类,得到文章中每个字符属于答案的起始位置和结束位置的概率,具体计算公式如下所示:
{pstart,pend}=FFN(u)
其中,pstart表示文章中每个字符属于正确答案的起始概率,pend表示文章中每个字符属于正确答案的结束概率,pstart,pend∈RS,S表示文本的序列长度512;
将观点类型概率pyes,pno,punk分别拼接到上述得到的开始和结束位置概率中,一起通过交叉熵计算损失,具体计算公式如下所示:
logitss=[pstart,pyes,pno,punk]
logitse=[pend,pyes,pno,punk]
Lstart=CrossEntropy(logitss,ys)
Lend=CrossEntropy(logitse,ye)
其中,logitss表示拼接观点型概率后的开始位置概率,logitse表示拼接观点型概率后的结束位置概率,logitss,logitse∈R515,CrossEntropy表示二元交叉熵损失函数,ys、ye分别是该条数据的真实开始位置和结束位置,YES类型设置其位置为512,NO类型为513,Unknown类型为514;
在预测阶段,分别遍历开始位置概率logitss和结束位置概率logitse,将符合1≤s≤e,e≤S条件的片段加入候选集合Φ,其中s表示答案的开始位置,e表示答案的结束位置;将开始位置和结束位置的概率之和作为单片段答案概率,符合条件的片段一般有多个,选取概率最大的作为span(单片段抽取)类型问题的答案;单片段抽取问题的答案是文章的某一个片段。
同时,将YES/NO、Unknown概率的2倍作为答案概率,同单片段问题一样,也加入候选集合Φ;最后从候选集合Φ中选择概率最大的作为最终答案。
进一步,所述步骤3利用多片段抽取类问题解答模块解答多片段问题的具体操作为,将预测片段数量建模为分类问题,分类数目n是超参数,(该超参数需要根据不同数据集实验选取不同的值)设置为3,对于片段数大于n的问题,随机选取n个片段作为该问题的答案,将[CLS]位置通过编码器得到的向量c经过参数为W5的全连接层得到回答该问题所需片段数量的概率pspan;序列标注层采用IO标记的方式,在数据预处理时根据正确答案将文章中的每一个字符标注标签‘I’或‘O’,如果文中某个字符属于正确答案则标注为‘I’,否则标注‘O’,即预测输入的每个位置是否属于输出的一部分,将MacBERT最后一层的向量作为上下文表示m,然后经过BiLSTM-CRF层,再经过参数为W6的全连接层为每个字符预测标签概率,具体计算公式如下所示:
pspan=FFN(c)
m=HL,m'=BiLSTM(m)
pIO=FFN(m')
其中,c∈RD,m∈RS×D,pIO∈RS×2;pIO表示经过W6的全连接层得到标签为‘I’或‘O’的概率,FFN表示全连接层;
对于片段数量预测采用交叉熵损失函数计算其损失,对于序列标注采用CRF最大化正确标签的对数概率作为损失,具体计算公式如下所示:
Lspan=CrossEntropy(pspan,yspan)
Figure BDA0003546137900000061
其中,yspan表示当前数据的实际片段数量;pi[Tj]表示第i个字符是标签Tj的概率,Tj表示IO标签;
将观点类型以及片段抽取类型的损失加权求和,一起进行梯度反向传播,联合优化,计算过程如下公式所示:
L=α·(Lstart+Lend)+β·Lspan+γ·LIO
其中α,β,γ表示三个任务的权重,设置α=1,β=1,γ=1;
在预测阶段,将片段数量概率最大值所在的索引作为该问题的片段数量,对序列标注得到的标签概率采用维特比(Viterbi)算法解码,得到IO标签序列Z,然后选择所有标签I连续的片段作为一个候选片段,根据其所在位置在文章中截取相应片段,得到候选片段集合IOspan,具体计算如下所示:
answer_num=argmax(pspan)
Z=Viterbi(pIO)
IOspan=Extract(Z)
在最终选取答案时,根据预测出的片段数量来确定是多片段答案或者其它类型的答案;当预测片段的数量大于1时,采用候选集合IOspan作为最终答案,否则采用步骤2得到的候选集合Φ中概率最大的作为最终答案。
进一步,所述步骤4利用对抗学习模块在模型训练中构造对抗训练方式,增强模型的鲁棒性和泛化能力的具体操作为:
在样本x的Embedding上增加一个扰动radv,得到对抗样本,
然后用对抗样本进行训练,输出分布和原始分布一致,训练采用交叉熵作为损失,其计算如下式所示:
loss=-log(p(y|x+radv;θ))
其中,loss表示损失,y表示样本的真实标签,x表示样本,是原始输入;θ是模型参数;
对输入序列input经过BERT后得到其编码E,如下式所示:
E={E[cls],EQ1,EQ2,...,EQn,E[SEP],EP1,EP2,...,EPm,E[SEP]}
首先对E进行前向传播并计算其损失L,然后采用FGM(Fast Gradient Method)算法构造对抗扰动,具体计算过程如下式所示:
Figure BDA0003546137900000071
Figure BDA0003546137900000072
其中,ε是超参数,||g||2是g的L2范数;
将计算出的对抗扰动radv加到原始样本E上,得到对抗样本Er,如下式所示:
Figure BDA0003546137900000081
其中,角标Q1…Qn表示问题字符,n为问题长度,P1…Pm表示文章字符,m为文章长度;
对Er进行前向传播得到损失Ladv,然后反向传播得到对抗梯度,将该梯度与原始梯度进行累加,对抗训练结束后将输入样本的Embedding恢复到原始状态E以进行下一轮训练,此时根据累加的梯度对参数进行更新。
一种多任务联合学习的多片段机器阅读理解装置,包括:
编码器模块,用于对问题和文章进行编码,得到问题和文章中每个字或词的向量表示;
观点型和单片段抽取问题解答模块,用于解决观点型和单片段抽取类问题;
多片段抽取问题解答模块,利用基于动态预测片段数量和序列标注的多任务学习,其中动态预测片段数量可计算出每个问题所需的片段数量,能较为准确地识别出多片段问题类型,序列标注可以从输入文本中提取可变长度的片段,能够实现多个答案片段的有效定位;
对抗学习模块,利用经典的对抗学习算法构造对抗训练,增强模型鲁棒性及泛化能力。
一种电子设备,包括至少一个处理器,以及至少一个与处理器通信连接的存储器,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行所述多任务联合学习的多片段机器阅读理解方法。
一种存储有计算机指令的非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行所述多任务联合学习的多片段机器阅读理解方法。
与现有技术相比本发明具有以下优点:
(1)本发明提出一种结合动态预测片段数量和序列标注的多任务联合学习方法及装置,其中,动态预测片段数量可计算出每个问题所需的片段数量,能较为准确地识别出多片段问题类型;序列标注可以从输入文本中提取可变长度的片段,能够实现多个答案片段的有效定位。
(2)本发明通过构造对抗训练,增强了模型的鲁棒性和泛化能力;
(3)本发明将多个任务进行联合优化学习,在各个问题类型上均取得了提升。
附图说明
图1为本发明实施例提供的一种多任务联合学习的多片段机器阅读理解方法整体架构图;
图2为本发明使用的数据样例;
图3为本发明的观点型和单片段抽取模块结构图;
图4为本发明的多片段抽取模块结构图;
图5为本发明的预测阶段答案解码算法;
图6为本发明实施例提供的一种多任务联合学习的多片段机器阅读理解装置的结构示意图;
图7为本发明实施例提供的电子设备整体结构示意图。
具体实施方式
下面结合实施例和附图对本发明做进一步详细描述,所举实施例只用于解释本发明,并非用于限定本发明的保护范围。
实施例1
图1为本发明实施例提供的一种多任务联合学习的多片段机器阅读理解方法整体架构图,主要包括编码器模块、观点型和单片段抽取问题解答模块、多片段抽取问题解答模块、对抗学习模块。各模块的具体内容如下:
编码器模块采用预训练语言模型对所述文章和问题进行编码,得到文章和问题中每个字或词的向量表示。
观点型和单片段抽取模块,利用自注意力机制获得观点型问题概率;对于单片段抽取问题,获得所述文章中每个字是正确答案的开始位置或结束位置的概率。
多片段抽取模块,利用动态预测片段数量和序列标注的多任务联合学习,获得所述问题所需的片段数量,基于此准确识别多片段问题类型;同时,获得回答所述多片段问题需在文章中标注出的每个字符属于标签I和O的概率。
对抗学习模块,采用快速梯度方法(Fast Gradient Method,FGM)在Embedding层构造对抗扰动,获得对抗样本以进行训练,对抗学习结束后恢复至原始状态。
实施例2
图2为2021中国法律智能技术评测(CAIL2021)阅读理解数据集的示例,其文章来自于中国裁判文书上的真实案例,问题是“在读学生在哪些地方进行治疗?”,该问题的答案为文章中三处不连续的片段组合而成。同时,CAIL2021数据集中保留观点型和单片段抽取问题。
1、首先利用编码器模块对所述文章和问题进行编码,得到问题和文章中每个字或词的向量表示。采用预训练语言模型MacBERT作为编码器,计算公式如下所示:
input=[CLS]+question+[SEP]+context+[SEP]
Figure BDA0003546137900000111
其中,input表示模型的输入,question表示问题,context表示文章,[CLS]表示开始位置,[SEP]表示分隔符。Hi∈RS×D表示文章和问题的向量表示,这里S表示输入的序列长度512;D表示隐藏层维度,base版本是768,large版本为1024;L表示BERT的层数,base版本为12层,large为24层。
2、利用所述观点型和单片段抽取问题解答模块解答观点型和单片段抽取问题,图3为该模块的结构图。
(1)所述观点型问题中的“YES/NO”类问题,将通过编码器模块得到的文章和问题向量拼接其最后四层作为上下文向量u,然后对其作自注意力计算,再通过W1∈R4D×2的全连接层进行二分类,W1表示该全连接层的可训练参数,得到问题答案是YES/NO的概率pyes,pno,具体计算过程见如下公式:
u=Concat(HL-3,HL-2,HL-1,HL)
u'=SelfAttention(u)
{pyes,pno}=FFN(u')
其中,Concat表示向量拼接函数,HL表示BERT的最后一层向量,u∈RS×4D是拼接BERT最后四层的向量,FFN表示全连接层,SelfAttention的计算如下公式所示:
α=u·W2+b,α∈RS
α'=softmax(α)
SelfAttention(u)=α'·u
这里α是向量u经过线性运算得到的权重,W2∈R4D×1是可学习的参数,偏置b∈RS,α'表示经过softmax归一化之后的权重,softmax表示softmax函数。
(2)所述观点型问题中的Unknown类问题,将BERT的最后一层向量HL经过最大池化得到一个向量表示,然后经过W3∈R4D×1的全连接层得到Unknown的答案概率punk,其中W3为该全连接层可训练的参数,具体公式如下所示:
h=HL,h'=MaxPooling(h)
{punk}=FFN(h')
其中,h表示MacBERT的最后一层向量,h'表示经过池化后得到的向量,MaxPooling表示最大池化函数。
(3)所述单片段抽取类问题,需要根据问题标出正确答案在文章中的开始位置和结束位置。将上下文向量u经过W4∈R4D×2的全连接层进行二分类,其中W4为该全连接层可训练的参数,得到每个token(字符)起始位置和结束位置的概率,具体计算公式如下所示:
{pstart,pend}=FFN(u)
其中,pstart表示文章中字符属于答案的起始位置概率,pend表示文章中字符属于答案的结束位置概率,pstart,pend∈RS
将观点类型概率pyes,pno,punk分别拼接到中,一起通过交叉熵计算损失,具体计算公式如下所示:
logitss=[pstart,pyes,pno,punk]
logitse=[pend,pyes,pno,punk]
Lstart=CrossEntropy(logitss,ys)
Lend=CrossEntropy(logitse,ye)
其中,logitss,logitse∈R515,CrossEntropy是二元交叉熵损失函数,ys、ye分别是该条数据的真实开始位置和结束位置,YES类型设置其位置为512,NO类型为513,Unknown类型为514。
在预测阶段,分别遍历开始位置概率logitss和结束位置概率logitse,将符合1≤s≤e,e≤S条件的片段加入候选集合Φ,其中s表示答案的开始位置,e表示答案的结束位置。将开始位置和结束位置的概率之和作为单片段答案概率,符合条件的片段一般有多个,选取概率最大的作为span类型问题的答案。同时,将YES/NO、Unknown概率的2倍作为答案概率,同单片段问题一样,也加入候选集合Φ。最后从候选集合Φ中选择概率最大的作为最终答案。
3、利用所述多片段抽取类问题解答模块解答多片段问题类型,需要根据问题在文章中抽取出多处不连续的片段组成最终答案。图4为多片段抽取模块结构图。
本发明将预测片段数量建模为分类问题,分类数目n是超参数,设置为3,对于片段数大于n的问题,随机选取n个片段作为该问题的答案,将[CLS]位置通过编码器得到的向量c经过W5∈RD×n的全连接层得到回答该问题所需片段数量的概率pspan,其中W5为该全连接层可训练的参数;序列标注层采用IO标记的方式,在数据预处理时根据正确答案将文章中的每一个字符标注标签‘I’或‘O’,如果文中某个字符属于正确答案则标注为‘I’,否则标注‘O’,即预测输入的每个位置是否属于输出的一部分,将MacBERT最后一层的向量作为上下文表示m,然后经过BiLSTM-CRF层,再经过W6∈RD×2的全连接层为每个token(字符)预测标签概率,其中W6为该全连接层可训练的参数。具体计算公式如下所示:
pspan=FFN(c)
m=HL,m'=BiLSTM(m)
pIO=FFN(m')
其中,c∈RD,m∈RS×D,pIO∈RS×2;pIO表示经过W6的全连接层得到标签为‘I’或‘O’的概率,FFN表示全连接层。
对于片段数量预测采用交叉熵损失函数计算其损失,对于序列标注采用CRF最大化正确标签的对数概率作为损失。具体计算公式如下所示:
Lspan=CrossEntropy(pspan,yspan)
Figure BDA0003546137900000141
其中,yspan表示当前数据的实际片段数量;pi[Tj]表示第i个token是标签Tj的概率,Tj表示IO标签。
将观点类型以及片段抽取类型的损失加权求和,一起进行梯度反向传播,联合优化,计算过程如下公式所示:
L=α·(Lstart+Lend)+β·Lspan+γ·LIO
其中α,β,γ表示三个任务的权重,设置α=1,β=1,γ=1。
在预测阶段,将片段数量概率最大值所在的索引作为该问题的片段数量,对序列标注得到的标签概率采用维特比(Viterbi)算法解码,得到IO标签序列Z,然后选择所有标签I连续的片段作为一个候选片段,根据其所在位置在文章中截取相应片段,得到候选片段集合IOspan。具体计算如下所示:
answer_num=argmax(pspan)
Z=Viterbi(pIO)
IOspan=Extract(Z)
在最终选取答案时,根据预测出的片段数量来确定是多片段答案或者其它类型的答案。当预测片段的数量大于1时,采用候选集合IOspan作为最终答案,否则采用步骤2得到的候选集合Φ中概率最大的作为最终答案。本发明在预测阶段的答案解码算法如图5所示。
4、利用所述对抗学习模块在模型训练中构造对抗训练方式,增强模型的鲁棒性和泛化能力。
在样本x的Embedding上增加一个扰动radv,得到对抗样本,然后用其进行训练,输出分布和原始分布一致。训练采用交叉熵作为损失,其计算如下式所示:
loss=-log(p(y|x+radv;θ))
其中,loss表示损失,y是样本的真实标签,x表示样本,是原始输入,θ是模型参数。
具体来说,对输入序列input经过BERT后得到其编码E,如下式所示:
E={E[cls],EQ1,EQ2,...,EQn,E[SEP],EP1,EP2,...,EPm,E[SEP]}
首先对E进行前向传播并计算其损失
Figure BDA0003546137900000151
然后采用FGM(Fast Gradient Method)算法构造对抗扰动,具体计算过程如下式所示:
Figure BDA0003546137900000152
Figure BDA0003546137900000153
其中,ε是超参数,||g||2是g的L2范数。
将计算出的对抗扰动radv加到原始样本E上,得到对抗样本Er,如下式所示:
Figure BDA0003546137900000154
对Er进行前向传播得到损失Ladv,然后反向传播得到对抗梯度,将该梯度与原始梯度进行累加,对抗训练结束后将输入样本的Embedding恢复到原始状态E以进行下一轮训练,此时根据累加的梯度对参数进行更新。
实施例3
图6为本发明实施例提供的一种多任务联合学习的多片段机器阅读理解装置的结构示意图,如图6所示,该多片段机器阅读理解装置包括:编码器模块、观点型和单片段抽取问题解答模块、多片段抽取问题解答模块和对抗学习模块,其中:
编码器模块,用于对问题和文章进行编码,得到问题和文章中每个字或词的向量表示;
观点型和单片段抽取问题解答模块,用于解决观点型和单片段抽取类问题;
多片段抽取问题解答模块,利用基于动态预测片段数量和序列标注的多任务学习,其中动态预测片段数量可计算出每个问题所需的片段数量,能较为准确地识别出多片段问题类型,序列标注可以从输入文本中提取可变长度的片段,能够实现多个答案片段的有效定位;
对抗学习模块,利用经典的对抗学习算法构造对抗训练,增强模型鲁棒性及泛化能力。
本发明实施例提供一种多任务联合学习的多片段机器阅读理解装置,具体执行上述一种多任务联合学习的多片段机器阅读理解方法实施例流程,具体请详见上述一种多任务联合学习的多片段机器阅读理解方法实施例的内容,在此不再赘述。
本实施例提供一种电子设备,图7为本发明实施例提供的电子设备整体结构示意图,该设备包括:处理器、存储器、通信总线和通信接口;其中,处理器,通信接口,存储器通过通信总线完成相互间的通信。存储器存储有可被处理器执行的程序指令,处理器调用程序指令能够执行上述各方法实施例所提供的方法,例如包括:通过预训练语言模型对文章和问题进行编码,输出文章和问题的语义表示;利用自注意力机制和全连接网络得到观点型和单片段问题的概率;基于动态预测片段数量和序列标注的多任务学习,得到多片段问题类型的概率以及相应片段;通过快速梯度方法FGM构造对抗训练,增强模型泛化能力。
此外,上述的存储器中的逻辑指令可以通过软件功能单元的形式实现,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实施例的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)等各种可以存储程序代码的介质。
本实施例提供一种非暂态计算机可读存储介质,非暂态计算机可读存储介质存储计算机指令,计算机指令使计算机执行上述各方法实施例所提供的方法,例如包括:通过预训练语言模型对文章和问题进行编码,输出文章和问题的语义表示;利用自注意力机制和全连接网络得到观点型和单片段问题的概率;基于动态预测片段数量和序列标注的多任务学习,得到多片段问题类型的概率以及相应片段;通过快速梯度方法FGM构造对抗训练,增强模型泛化能力。
本发明说明书中未作详细描述的内容属于本领域专业技术人员公知的现有技术。尽管上面对本发明说明性的具体实施方式进行了描述,以便于本技术领的技术人员理解本发明,但应该清楚,本发明不限于具体实施方式的范围,对本技术领域的普通技术人员来讲,只要各种变化在所附的权利要求限定和确定的本发明的精神和范围内,这些变化是显而易见的,一切利用本发明构思的发明创造均在保护之列。

Claims (8)

1.一种多任务联合学习的多片段机器阅读理解方法,其特征在于:包括以下步骤:
步骤1:通过编码器模块对问题和文章进行编码,得到问题和文章中每个字或词的向量表示;
步骤2:利用观点型和单片段抽取问题解答模块解答观点型和单片段抽取问题;
步骤3:利用多片段抽取类问题解答模块解答多片段问题,该模块首先通过预测问题的片段数量识别出多片段问题类型,再利用序列标注提取多个答案片段;
步骤4:利用对抗学习模块在模型训练中构造对抗训练方式,增强模型的鲁棒性和泛化能力。
2.根据权利要求1所述的一种多任务联合学习的多片段机器阅读理解方法,其特征在于:所述步骤1通过编码器模块对问题和文章进行编码,得到问题和文章中每个字或词的向量表示的具体方法是:
采用预训练语言模型MacBERT作为编码器,对问题和文章进行编码,得到问题和文章中每个字或词的向量表示,计算公式如下所示:
input=[CLS]+question+[SEP]+context+[SEP]
Figure FDA0003546137890000011
其中,input表示模型的输入,question表示问题,context表示文章,[CLS]表示开始位置,[SEP]表示分隔符;Hi∈RS×D表示文章和问题的向量表示,S表示输入的序列长度;D表示隐藏层维度;L表示MacBERT的层数,R表示该向量所属的向量空间。
3.根据权利要求1所述的一种多任务联合学习的多片段机器阅读理解方法,其特征在于:所述步骤2利用观点型和单片段抽取问题解答模块解答观点型和单片段抽取问题的具体方法是:
观点型问题包括“YES/NO”类问题和“Unknown”类问题;
其中“YES/NO”类问题:将通过步骤1编码器模块得到的文章和问题的向量其最后四层作为上下文向量u,然后对上下文向量u作自注意力计算,再通过参数为W1的全连接层进行二分类,得到问题答案是YES/NO的概率用pyes,pno来表示,具体计算过程如下:
u=Concat(HL-3,HL-2,HL-1,HL)
u'=SelfAttention(u)
{pyes,pno}=FFN(u')
其中,Concat表示向量拼接函数,HL表示BERT的最后一层向量,u∈RS×4D表示拼接BERT最后四层的向量,FFN表示全连接层,SelfAttention的计算过程如下:
α=u·W2+b,α∈RS
α'=softmax(α)
SelfAttention(u)=α'·u
其中,α是向量u经过线性运算得到的权重,W2是可学习的参数,偏置b∈RS,α'表示经过softmax归一化之后的权重,softmax表示softmax函数;
其中“Unknown”类问题:将BERT的最后一层向量HL经过最大池化得到一个向量表示,然后经过参数为W3的全连接层得到Unknown的答案概率punk,具体公式如下所示:
h=HL,h'=MaxPooling(h)
{punk}=FFN(h');
其中,h表示MacBERT的最后一层向量,h'表示经过池化后得到的向量,MaxPooling表示最大池化函数;
单片段抽取类问题:需要根据问题标出正确答案在文章中的开始位置和结束位置;
将上下文向量u经过参数为W4的全连接层进行二分类,得到文章中每个字符属于答案的起始位置和结束位置的概率,具体计算公式如下所示:
{pstart,pend}=FFN(u)
其中,pstart表示文章中字符属于答案的起始位置概率,pend表示文章中字符属于答案的结束位置概率;
将观点类型概率pyes,pno,punk分别拼接到上述得到的开始和结束位置概率中,一起通过交叉熵计算损失,具体计算公式如下所示:
logitss=[pstart,pyes,pno,punk]
logitse=[pend,pyes,pno,punk]
Lstart=CrossEntropy(logitss,ys)
Lend=CrossEntropy(logitse,ye)
其中,logitss表示拼接观点型概率后的开始位置概率,logitse表示拼接观点型概率后的结束位置概率,CrossEntropy表示二元交叉熵损失函数,ys、ye分别是该条数据的真实开始位置和结束位置;
在预测阶段,分别遍历开始位置概率logitss和结束位置概率logitse,将符合1≤s≤e,e≤S条件的片段加入候选集合Φ,其中s表示答案的开始位置,e表示答案的结束位置;将开始位置和结束位置的概率之和作为单片段答案概率,符合条件的片段一般有多个,选取概率最大的作为span单片段抽取类型问题的答案;
同时,将YES/NO、Unknown概率的2倍作为答案概率,同单片段问题一样,也加入候选集合Φ;最后从候选集合Φ中选择概率最大的作为最终答案。
4.根据权利要求1所述的一种多任务联合学习的多片段机器阅读理解方法,其特征在于:所述步骤3利用多片段抽取类问题解答模块解答多片段问题的具体操作为,将预测片段数量建模为分类问题,分类数目n是超参数,对于片段数大于n的问题,随机选取n个片段作为该问题的答案,将[CLS]位置通过编码器得到的向量c经过参数为W5的全连接层得到回答该问题所需片段数量的概率pspan;序列标注层采用IO标记的方式,在数据预处理时根据正确答案将文章中的每一个字符标注标签‘I’或‘O’,如果文中某个字符属于正确答案则标注为‘I’,否则标注‘O’,即预测输入的每个位置是否属于输出的一部分,将MacBERT最后一层的向量作为上下文表示m,然后经过BiLSTM-CRF层,再经过参数为W6的全连接层为每个字符预测标签概率,具体计算公式如下所示:
pspan=FFN(c)
m=HL,m'=BiLSTM(m)
pIO=FFN(m')
其中,c∈RD,m∈RS×D;pIO表示经过参数为W6的全连接层得到标签为‘I’或‘O’的概率,FFN表示全连接层;
对于片段数量预测采用交叉熵损失函数计算其损失,对于序列标注采用CRF最大化正确标签的对数概率作为损失,具体计算公式如下所示:
Lspan=CrossEntropy(pspan,yspan)
Figure FDA0003546137890000051
其中,yspan表示当前数据的实际片段数量;pi[Tj]表示第i个字符是标签Tj的概率,Tj表示IO标签;
将观点类型以及片段抽取类型的损失加权求和,一起进行梯度反向传播,联合优化,计算过程如下公式所示:
L=α·(Lstart+Lend)+β·Lspan+γ·LIO
其中α,β,γ表示三个任务的权重;
在预测阶段,将片段数量概率最大值所在的索引作为该问题的片段数量,对序列标注得到的标签概率采用维特比Viterbi算法解码,得到IO标签序列Z,然后选择所有标签I连续的片段作为一个候选片段,根据其所在位置在文章中截取相应片段,得到候选片段集合IOspan,具体计算如下所示:
answer_num=argmax(pspan)
Z=Viterbi(pIO)
IOspan=Extract(Z)
在最终选取答案时,根据预测出的片段数量来确定是多片段答案或者其它类型的答案;当预测片段的数量大于1时,采用候选集合IOspan作为最终答案,否则采用步骤2得到的候选集合Φ中概率最大的作为最终答案。
5.根据权利要求1所述的一种多任务联合学习的多片段机器阅读理解方法,其特征在于:所述步骤4利用对抗学习模块在模型训练中构造对抗训练方式,增强模型的鲁棒性和泛化能力的具体操作为:
在样本x的Embedding上增加一个扰动radv,得到对抗样本,
然后用对抗样本进行训练,输出分布和原始分布一致,训练采用交叉熵作为损失,其计算如下式所示:
loss=-log(p(y|x+radv;θ))
其中,loss表示损失,y表示样本的真实标签,x表示样本,是原始输入;θ是模型参数;
对输入序列input经过BERT后得到其编码E,如下式所示:
E={E[cls],EQ1,EQ2,...,EQn,E[SEP],EP1,EP2,...,EPm,E[SEP]}
其中,角标Q1…Qn表示问题字符,n为问题长度,P1…Pm表示文章字符,m为文章长度;
首先对E进行前向传播并计算其损失L,然后采用FGM算法构造对抗扰动,具体计算过程如下式所示:
Figure FDA0003546137890000061
Figure FDA0003546137890000062
其中,ε是超参数,||g||2是g的L2范数;
将计算出的对抗扰动radv加到原始样本E上,得到对抗样本Er,如下式所示:
Figure FDA0003546137890000063
对Er进行前向传播得到损失Ladv,然后反向传播得到对抗梯度,将对抗梯度与原始梯度进行累加,对抗训练结束后将输入样本的Embedding恢复到原始状态E以进行下一轮训练,此时根据累加的梯度对参数进行更新。
6.一种多任务联合学习的多片段机器阅读理解装置,其特征在于:包括:
编码器模块,用于对问题和文章进行编码,得到问题和文章中每个字或词的向量表示;
观点型和单片段抽取问题解答模块,用于解决观点型和单片段抽取类问题;
多片段抽取问题解答模块,利用基于动态预测片段数量和序列标注的多任务学习,其中动态预测片段数量可计算出每个问题所需的片段数量,能较为准确地识别出多片段问题类型,序列标注可以从输入文本中提取可变长度的片段,能够实现多个答案片段的有效定位;
对抗学习模块,利用经典的对抗学习算法构造对抗训练,增强模型鲁棒性及泛化能力。
7.一种电子设备,其特征在于:包括至少一个处理器,以及至少一个与处理器通信连接的存储器,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1~5任一项所述多任务联合学习的多片段机器阅读理解方法。
8.一种存储有计算机指令的非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行权利要求1~5任一项所述多任务联合学习的多片段机器阅读理解方法。
CN202210248775.5A 2022-03-14 2022-03-14 一种多任务联合学习的多片段机器阅读理解方法及装置 Pending CN114648005A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210248775.5A CN114648005A (zh) 2022-03-14 2022-03-14 一种多任务联合学习的多片段机器阅读理解方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210248775.5A CN114648005A (zh) 2022-03-14 2022-03-14 一种多任务联合学习的多片段机器阅读理解方法及装置

Publications (1)

Publication Number Publication Date
CN114648005A true CN114648005A (zh) 2022-06-21

Family

ID=81992731

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210248775.5A Pending CN114648005A (zh) 2022-03-14 2022-03-14 一种多任务联合学习的多片段机器阅读理解方法及装置

Country Status (1)

Country Link
CN (1) CN114648005A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115048906A (zh) * 2022-08-17 2022-09-13 北京汉仪创新科技股份有限公司 一种文档结构化方法、装置、电子设备和存储介质

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115048906A (zh) * 2022-08-17 2022-09-13 北京汉仪创新科技股份有限公司 一种文档结构化方法、装置、电子设备和存储介质

Similar Documents

Publication Publication Date Title
CN111783462A (zh) 基于双神经网络融合的中文命名实体识别模型及方法
CN110196980B (zh) 一种基于卷积网络在中文分词任务上的领域迁移
CN110134946B (zh) 一种针对复杂数据的机器阅读理解方法
CN111159485B (zh) 尾实体链接方法、装置、服务器及存储介质
CN113392209B (zh) 一种基于人工智能的文本聚类方法、相关设备及存储介质
CN111738169B (zh) 一种基于端对端网络模型的手写公式识别方法
CN110852089B (zh) 基于智能分词与深度学习的运维项目管理方法
CN113626589B (zh) 一种基于混合注意力机制的多标签文本分类方法
CN113204633B (zh) 一种语义匹配蒸馏方法及装置
CN112580328A (zh) 事件信息的抽取方法及装置、存储介质、电子设备
CN113392651A (zh) 训练词权重模型及提取核心词的方法、装置、设备和介质
CN113743119B (zh) 中文命名实体识别模块、方法、装置及电子设备
CN113988079A (zh) 一种面向低数据的动态增强多跳文本阅读识别处理方法
CN111145914A (zh) 一种确定肺癌临床病种库文本实体的方法及装置
CN115186147A (zh) 对话内容的生成方法及装置、存储介质、终端
CN111209362A (zh) 基于深度学习的地址数据解析方法
CN114648005A (zh) 一种多任务联合学习的多片段机器阅读理解方法及装置
CN114372454A (zh) 文本信息抽取方法、模型训练方法、装置及存储介质
CN116562286A (zh) 一种基于混合图注意力的智能配置事件抽取方法
CN116910190A (zh) 多任务感知模型获取方法、装置、设备及可读存储介质
CN113704466B (zh) 基于迭代网络的文本多标签分类方法、装置及电子设备
CN112800186B (zh) 阅读理解模型的训练方法及装置、阅读理解方法及装置
CN115203388A (zh) 机器阅读理解方法、装置、计算机设备和存储介质
CN114329005A (zh) 信息处理方法、装置、计算机设备及存储介质
CN112015891A (zh) 基于深度神经网络的网络问政平台留言分类的方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination