CN117350304B - 一种多轮对话上下文向量增强方法及系统 - Google Patents

一种多轮对话上下文向量增强方法及系统 Download PDF

Info

Publication number
CN117350304B
CN117350304B CN202311639567.9A CN202311639567A CN117350304B CN 117350304 B CN117350304 B CN 117350304B CN 202311639567 A CN202311639567 A CN 202311639567A CN 117350304 B CN117350304 B CN 117350304B
Authority
CN
China
Prior art keywords
vector
dialogue
model
sub
loss
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311639567.9A
Other languages
English (en)
Other versions
CN117350304A (zh
Inventor
金玉赫
郑梅
徐青伟
严长春
裴非
范娥媚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Xinghe Zhiyuan Technology Co ltd
Original Assignee
Zhiguagua Tianjin Big Data Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhiguagua Tianjin Big Data Technology Co ltd filed Critical Zhiguagua Tianjin Big Data Technology Co ltd
Priority to CN202311639567.9A priority Critical patent/CN117350304B/zh
Publication of CN117350304A publication Critical patent/CN117350304A/zh
Application granted granted Critical
Publication of CN117350304B publication Critical patent/CN117350304B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/30Semantic analysis
    • G06F40/35Discourse or dialogue representation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/205Parsing
    • G06F40/216Parsing using statistical methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0499Feedforward networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Machine Translation (AREA)

Abstract

本申请公开了一种多轮对话上下文向量增强方法及系统,方法包括通过Ernie3模型编码,智能化地从对话内容中提取出语义信息并形成初始上下文向量;接收预处理过的文本作为输入并通过各个下游任务进行学习,从而提取出不同的任务特性并形成增强向量;利用预训练的解码器评估增强向量的效果,随后引入了PPO策略与KL散度计算,确保增强向量与初始向量在分布上的相似性。本申请为多轮对话上下文提供了一个全新、高效的向量表示和增强方法,进一步提升了对话理解的准确性,为下游任务提供更为精准的信息。

Description

一种多轮对话上下文向量增强方法及系统
技术领域
本申请涉及计算机技术领域,尤其涉及一种多轮对话上下文向量增强方法及系统。
背景技术
随着对话系统技术的快速进展,多轮对话的理解与用户意图的识别逐渐成为研究的焦点。为了更准确地解决这些问题,现有技术尝试了各种方法和策略。
在众多的研究中,一部分方法专注于使用深度学习技术,特别是神经网络结构来捕获和表示对话中的语义和上下文信息。这些技术如循环神经网络(RNN)、长短期记忆网络(LSTM)等,常常与注意力机制相结合,为意图识别和对话理解提供了强大的编码能力。例如,通过将语言内容进行数学化处理,编码对话,再结合外部存储的历史对话信息,可以捕获到更丰富的对话上下文。
此外,为了进一步提高对话理解的精度,有些方法还结合了其他高级技术,如记忆网络和异构图神经网络。记忆网络特别适用于处理和检索大量的历史对话信息,而异构图神经网络则有助于从多个数据源中综合信息,从而更准确地识别用户的真实意图。
然而现有技术主要针对单一任务进行优化,导致上下文信息的片面性,如仅关注意图识别或情感分析;并且在处理不同轮数的对话时,往往忽略了某些关键的上下文信息,导致对话理解不够准确。
发明内容
本申请提供一种多轮对话上下文向量增强方法及系统,将各种任务信息,如用户意图识别、情感分析和关键信息抽取,整合到统一的上下文表示向量中。这不仅增强了上下文向量的表达能力,也确保了信息的完整性和多维度特性。
第一方面,一种多轮对话上下文向量增强方法,所述方法包括:
获取多轮对话数据集,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将所述子对话数据集通过Ernie3模型进行处理生成向量表示集;其中,每个子对话数据生成一个向量表示;
构建多个下游任务模型,将所述子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集;其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成;
建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器;
将增强向量表示集中的向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算子对话数据集和增强向量表示集中的向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于所述PPO损失和所述KL散度确定总损失,并基于所述总损失通过反向传播优化多任务模型得到目标多任务模型。
可选地,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,包括:
根据预设窗口大小,对多轮对话数据集中的对话内容进行截断,并使用滑动窗口策略增加数据量;
其中,当存在对话轮次小于预设窗口大小时,使用填充策略将当前对话轮次补齐至预设窗口大小。
可选地,将所述子对话数据集通过Ernie3模型进行处理生成向量表示集,包括:
子对话数据集被送入Ernie3模型的Transformer层进行处理;其中,每轮对话通过多个Transformer块的处理后,都会生成一个向量表示;
并使用自注意力机制将各个向量相互关联起来,并为各个向量生成对应权重,进行计算各个向量的加权平均值作为输出;
具体的公式为:
其中,vi是第i个向量,f是前馈神经网络,N是总的向量数量,αi是第i个向量的权重;
其中,Cinit为生成的向量表示。
可选地,将所述子对话数据集输入至各个下游任务模型进行学习,包括:
将各个下游任务模型的损失函数进行加权求和得到总损失函数。
可选地,在构建多个下游任务模型过程中,不同下游任务至少包括用户意图识别任务、情感分析任务以及关键信息抽取任务。
可选地,建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器,包括:
将向量表示集作为初始上下文向量输入到解码器中;其中,损失函数为交叉熵损失,对比真实的下一轮对话答案和解码器的预测来进行训练,公式包括:
其中,t指的是序列中的位置,是yt是在时间步t的真实标签的one-hot编码,pt是模型在时间步t预测的概率分布,LCE是交叉熵损失。
可选地,基于所述PPO损失和所述KL散度确定总损失,包括:
通过公式
确定总损失Lfine_tune,其中,β表示权重系数,LPPO表示PPO损失,LKL表示KL散度。
第二方面,一种多轮对话上下文向量增强系统,系统包括:
对话内容处理模块,用于获取多轮对话数据集,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将所述子对话数据集通过Ernie3模型进行处理生成向量表示集;其中,每个子对话数据生成一个向量表示;
向量提取模块,用于构建多个下游任务模型,将所述子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集;其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成;
预训练模块,用于建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器;
向量优化模块,用于将增强向量表示集中的向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算子对话数据集和增强向量表示集中的向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于所述PPO损失和所述KL散度确定总损失,并基于所述总损失通过反向传播优化多任务模型得到目标多任务模型。
第三方面,提供了一种电子设备,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现上述第一方面任一所述的多轮对话上下文向量增强方法。
第四方面,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述第一方面任一所述的多轮对话上下文向量增强方法。
相比现有技术,本申请至少具有以下有益效果:
(1)将各种任务信息,如用户意图识别、情感分析和关键信息抽取,整合到统一的上下文表示向量中。这不仅增强了上下文向量的表达能力,也确保了信息的完整性和多维度特性。
(2)通过滑动窗口策略,本发明确保了对于每一轮的对话,都能够固定长度的上下文信息;结合Ernie3模型的编码,进一步提取出这部分上下文中的关键语义信息,并确保输入维度的一致性。
(3)本申请采用了KL散度来量化增强向量与初始向量之间的差异,确保了它们在信息内容上的一致性。结合交叉熵损失和PPO策略,进一步评估和优化增强向量。不仅提供了一个关于增强向量预测准确性的标准,还引入了一种针对多任务模型的策略性微调方法,允许在保持原有语义信息的同时,进行持续的向量质量提升。
附图说明
图1为本申请实施例提供的整体流程图;
图2为本申请实施例提供的向量表示集生成过程示意图;
图3为本申请实施例提供的单一模型训练过程示意图;
图4为本申请实施例提供的预训练解码器流程示意图;
图5为本申请实施例提供的使用PPO策略评估流程示意图;
图6为本申请多轮对话上下文向量增强系统的模块架构框图;
图7为一个实施例中电子设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
在本申请的描述中:术语“包括”、“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包括了一系列步骤或单元的过程、方法、系统、产品或设备不必限于已明确列出的那些步骤或单元,而是还可包含虽然并未明确列出的但对于这些过程、方法、产品或设备固有的其它步骤或单元,或者基于本发明构思进一步的优化方案所增加的步骤或单元。
本发明提出了一种基于Ernie3模型和多任务学习的对话上下文增强方法及系统。首先,输入多轮对话数据,进行预处理和滑动窗口策略处理后,利用Ernie3模型为每轮对话生成语义向量,并采用自注意力池化形成初始上下文向量。
接着,构建与Ernie3模型同构的下游任务模型,如用户意图识别、情感分析和关键信息抽取等。每个任务模型输出的中间层向量被进行加权平均,形成了增强上下文向量,能更全面地反映对话的语境。
为了评估和优化增强向量,首先设计一个解码模型预测下一轮对话回答,并基于初始和增强向量进行训练。评估中,计算增强向量与初始向量的KL散度,确保向量的一致性。再利用交叉熵损失评估增强向量在预测对话回答的效果,并结合PPO策略进行损失修正。结合这两方面的评估,通过反向传播进一步优化多任务学习模型,确保生成的增强向量更为准确。
本发明结合了Ernie3模型、多任务学习和细致的评估策略,为多轮对话上下文提供了一个全新、高效的向量表示和增强方法,进一步提升了对话理解的准确性,为下游任务提供更为精准的信息。
在一个实施例中,如图1所示,提供了一种多轮对话上下文向量增强方法,该方法可以应用于服务器中,方法包括:
S1对话内容处理与Ernie3模型的编码,获取多轮对话数据集,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将所述子对话数据集通过Ernie3模型进行处理生成向量表示集。
其中,每个子对话数据生成一个向量表示。
在本步骤中,主要实现了对话内容处理与Ernie3模型的编码,如图2给出了向量表示集生成过程示意图,在S1中包括了:
S11.输入多轮对话数据集。
S12.对输入数据进行预处理,剔除无关信息与噪音,只保留有意义的文本。
S13.根据预设窗口大小W,对S12处理后的对话内容进行截断,使用滑动窗口策略增加数据量,保证输入维度一致性。
S14. Ernie3模型接收S13处理后的数据。它为每轮对话生成一个向量表示,反映了这轮对话的语义信息。对得到的若干向量使用自注意力池化策略,形成了初始上下文向量Cinit,(Cinit也即C_init)。
以下给出上述步骤的具体实现过程:
S11、S12:输入多轮对话数据集。对输入数据进行预处理,剔除无关信息与噪音,只保留有意义的文本。
数据集规模:本发明所采用的数据集是一个大规模的多轮对话数据集,该数据集包含超过100万对话会话,其中每个会话平均包含5-10轮对话交互。
数据来源:此数据集来源于多个渠道,主要包括:
在线客服记录:搜集企业内部的客服聊天记录,并进行了匿名处理。
社交媒体交互:来自于一些热门的社交媒体平台的公开对话数据。
论坛和问答网站:部分交互数据来源于各种在线论坛和问答网站。
合成数据:通过先进的对话生成模型如chatgpt,合成了一部分数据以增加数据的多样性。
数据质量:
清洗:为了确保数据的质量,本申请进行了多轮的数据清洗,剔除了那些包含敏感信息、广告或是无意义重复的内容。
注解:除了原始对话内容,本申请还对部分数据进行注解,这些注解包括意图标签、情感标签和关键信息。
多样性:数据集尽量涵盖了各种对话场景,从日常生活咨询、技术问题解答,到复杂的商务咨询等,以保证模型的广泛适用性。
S13:根据预设窗口大小W,对S12处理后的对话内容进行截断,使用滑动窗口策略增加数据量,保证输入维度一致性。
窗口大小的选择:窗口大小W代表Ernie3模型考虑的对话轮次。这主要基于模型的能力和数据集中对话的平均轮次。经过初步的统计分析,本申请发现大部分对话在5-10轮左右,因此本申请可以选择W=5作为一个合理的窗口大小。
滑动窗口策略具体包括窗口移动:给定一个长度为L的对话和窗口大小W,窗口每次移动一轮对话。因此,对于一个10轮的对话和W=5,会得到以下子对话:
对话轮次1-5、对话轮次2-6、对话轮次3-7、对话轮次4-8、对话轮次5-9、对话轮次6-10。
填充策略具体包括:填充标记:对于轮次少于W的对话,本申请会使用特定的填充对话轮次,如一个包含[PAD]的默认对话,来补齐到W轮。
头部填充:为了保持对话的原始顺序,本申请会在对话的开始处进行填充。
处理效果:通过这种方式,本申请确保了输入到Ernie3模型的每个对话都具有W轮的对话内容。并且,使用滑动窗口策略,原始数据集可以得到大量的子对话,从而增加了数据的规模。这种方法尤其对于少量的长对话而言是非常有效的,因为它允许模型从多个角度观察相同的对话内容。这为模型的批量处理设定了统一的输入格式,并帮助模型更好地捕获多轮对话中的上下文关系。
S14:Ernie3模型接收S13处理后的数据。它为每轮对话生成一个向量表示,反映了这轮对话的语义信息。对得到的若干向量使用自注意力池化策略,形成了初始上下文向量Cinit
S2多任务学习与向量优化,构建多个下游任务模型,将所述子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集。
其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成,不同下游任务至少包括用户意图识别任务、情感分析任务以及关键信息抽取任务。
在本步骤中,主要实现了多任务学习与向量优化,如图3给出了单一模型训练过程示意图,在S2中包括了:
S21.设计下游任务模型,与Ernie3模型同构,但具有不同的分类头,如用户意图识别、情感分析、关键信息抽取等。
S22.将S13处理后的数据作为输入,送入各个下游任务模型进行学习。
S23. 从每个任务模型分类头的前一层提取输出。这些输出被进行加权平均,形成增强向量Cenhanced,确保Cenhanced与Cinit具有相同的维度。
具体地,加一个输入是每个子对话数据。每个子对话数据是一个n轮的对话,然后将n轮对话中的每一轮对话输入到一个模型中(多任务模型m个中其中一个),从分类头的前一层提取到一个输出,这个输出是一个向量。然后使用上面同样的那个自注意力机制进行多轮的一个融合,把n论对话得到的n个向量融合成一个。因为有m个模型,所以再把m个模型的m个向量融合成一个,得到增强向量。
在本申请实施例中,Ernie3的基本构建和特点:
模型结构: ERNIE (Enhanced Representation through kNowledgeIntEgration) 是百度提出的预训练模型,目标是更好地处理各种NLP任务。ERNIE与BERT不同的地方在于它对知识增强表示的方法进行了创新。它尝试结合了大量的结构化知识,如实体关系、语义关系等,来改进模型的语义表示。
多任务预训练策略: 在多任务学习环境中,Ernie3已经进行了大量的预训练,使其具有捕获深层次语义关系的能力。
对话内容的编码过程:
输入嵌入: 每轮对话首先被转化为嵌入向量,这一步使用Ernie3的嵌入层完成。
向量表示: 嵌入的对话内容被送入Ernie3的Transformer层进行处理。每轮对话通过多个Transformer块的处理后,都会生成一个向量表示。这个向量是该轮对话的语义抽象。
自注意力池化:
使用自注意力机制将5个向量相互关联起来。注意力机制会为每个向量生成一个权重,这可以确保重要的向量得到更高的权重。
使用这些权重,计算加权平均的向量作为输出。
具体的公式为:
其中,vi是第i个向量,f是前馈神经网络,N是总的向量数量,αi是第i个向量的权重;
其中,Cinit为生成的向量表示。此方法可以确保输出向量充分融合了输入向量中的所有信息,输入是一个子对话数据集,输出是一个向量。
以下给出上述过程的具体实现过程:
S21:设计下游任务模型,与Ernie3模型同构,但具有不同的分类头,如用户意图识别、情感分析、关键信息抽取等。
目的: 为了从输入的对话数据中获取更丰富和特定的信息,设计了多个下游任务模型。这些模型是为了捕获多轮对话中的各种语义属性和细节,如用户的意图、情感状态、涉及的关键信息等。
模型结构:
基础架构: 每一个下游任务模型与Ernie3模型同构,即模型主体部分结构和初始参数一致(不包括分类头和嵌入层)。
分类头: 虽然所有下游任务模型在基本架构上与Ernie3相同,但它们的输出层或“分类头”是不同的。这是因为每个任务都有其独特的目标。例如:
用户意图识别: 该模型的输出层设计为多分类层,每个类别代表一个可能的用户意图。比如“查询账户余额”、“更改密码”等。
情感分析: 这通常是一个三分类问题(正面、负面、中性),但也可以更复杂,如考虑多种情感。
关键信息抽取: 这是一个序列标注任务,输出层将为输入中的每个单词或实体提供一个标签,表示其类别或重要性。
训练数据:
来源: 数据集中已标记的任务标签。
S22:将Cinit作为输入,送入各个下游任务模型进行学习。
在此阶段,将使用从S13中得到的多轮对话窗口数据,将其作为输入传入多个下游任务模型。由于每个任务模型都针对一个特定的目标(例如意图识别、情感分析等),每个模型都会为其特定的任务生成一个输出,并伴随一个损失函数来量化模型输出与真实标签之间的差异。
用户意图识别: 此任务通常是一个分类任务。因此,可以使用交叉熵损失函数表示意图识别损失:
其中,yi是真实标签的独热编码,pi是模型的预测概率。
情感分析: 同样是一个分类任务,本申请中使用交叉熵损失函数表示情感分析损失:
其中,yi是真实标签的独热编码,pi是模型的预测概率。
关键信息抽取: 这是一个序列标注任务,使用每个token标签的交叉熵损失表示信息抽取损失:
其中,yi是真实标签的独热编码,pi是模型的预测概率。
总损失函数: 由于正在执行多任务学习,需要一个总的损失函数来联合优化所有任务。一个常见的方法是简单地将所有任务的损失加权求和:
其中,α, β 和 γ 是权重系数,它们可以根据任务的重要性或难度进行调整。
模型优化上使用标准的优化方法Adam,来最小化总损失Lmulti_mask
S23:从每个任务模型分类头的前一层提取输出。这些输出被合并,形成增强向量Cenhanced
在多任务学习中,常常需要结合各任务的知识以构建一个更强大、更具泛化能力的模型。一种常见的策略是将各任务的模型中间层的输出融合,从而获得一个综合性的表示。
从每个下游任务模型中,本申请选择分类头的前一层,也就是模型的最后一个Transformer层来提取输出。因为它通常包含了整个输入数据的高层次、综合性的表示。这层的输出应该包含了足够的信息来完成指定的任务,同时也是一个较为压缩和抽象的表示。
以下给出一个具体合并策略的示例:
以Transformer模型为例,每个Transformer层的输出都是一个batch×sequencelength×dmodel的张量,其中dmodel是模型的隐藏层大小。
加权平均:对于每个任务的输出,先对sequence length维度取平均,得到一个batch×dmodel的张量。然后,对所有任务的张量进行加权平均。
其中,ω123是超参数,代表每个任务输出的权重。
S3解码器的预训练,建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器。
在本步骤中,主要实现了解码器的预训练,如图4给出了预训练解码器的处理过程示意图,在S3中包括了:
S31.设计解码模型,预测下一轮的对话回答。
S32.使用Cinit为输入,根据真实的下一轮对话答案训练解码器。
S33.经过训练,确保解码器具有较好的基于Cinit生成下一轮回答的能力。
S34.进一步使用Cenhanced进行微调解码器。
本步骤中,具体包括了:
设计解码模型,预测下一轮的对话回答。使用Cinit为输入,根据真实的下一轮对话答案训练解码器。经过训练,确保解码器具有较好的基于Cinit生成下一轮回答的能力。
设计解码模型:
解码模型的目的是从给定的向量中生成下一轮的对话回答。本发明采用预训练好的GPT2来作为基础模型。
使用Cinit为输入,训练解码器:
首先将Cinit作为初始上下文向量输入到GPT-2模型中。模型的任务是预测下一轮对话的答案。损失函数为交叉熵损失,对比真实的下一轮对话答案和GPT-2模型的预测来进行训练。公式如下:
这里的时间步t指的是序列中的位置,因为在对话生成的上下文中为每个位置预测下一个单词。yt是在时间步t的真实标签的one-hot编码。pt是模型在时间步t预测的概率分布。LCE是交叉熵损失。
微调解码器:
在初步使用Cinit进行训练之后,为了确保GPT-2模型能更好地解码Cenhanced,本申请中会进一步使用Cenhanced进行微调。同样,此时的损失函数还是交叉熵损失,对比真实的下一轮对话答案和模型的预测。
这种两步训练策略,确保了模型在预测时能够充分利用Cenhanced中的丰富信息,从而更好地预测下一轮的对话答案。
S4增强向量的评估与优化,将增强向量表示集中的向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算子对话数据集和增强向量表示集中的向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于所述PPO损失和所述KL散度确定总损失,并基于所述总损失通过反向传播优化多任务模型得到目标多任务模型。
其中,本步骤所得到的目标多任务模型属于可训练状态,目标多任务模型中基础Ernie3模型和各自不同的分类头组成的下游任务模型,以及自注意机制所对应的前馈神经网络。
在本步骤中主要实现了增强向量表示集中的向量的评估与优化,如图5给出了使用PPO策略评估和增强处理过程示意图,在S4中具体包括了:
S41.使用预训练的解码器,将Cenhanced作为输入(Cenhanced即C_enhanced),预测下一轮的对话回答。
S42.使用真实的下一轮对话与预测回答计算交叉熵损失LCE
S43.计算Cenhanced与Cinit之间的KL散度LKL
S44.应用PPO策略中的策略比率clipping机制对LCE进行修正:
其中,是新策略与旧策略的比率,新策略对应于Cenhanced生成的回答,旧策略对应于Cinit生成的回答,ϵ是一个预设的小值。
S45.总损失是PPO损失与KL散度的加权和:
确定总损失Lfine_tune(Lfine_tune即L_finetune),其中,β表示权重系数,LPPO表示PPO损失,LKL表示KL散度。
S46.基于Lfine_tune,通过反向传播,优化多任务学习模型,进一步提炼Cenhanced
以下给出上述步骤的具体实现过程:
S41:使用预训练的解码器,将Cenhanced作为输入,预测下一轮的对话回答。
在S41中,使用之前训练好的解码器,并将经过多任务学习优化的Cenhanced作为输入来生成下一轮对话的回答。
输入表示:
Cenhanced是一个固定维度的向量,它包含了多轮对话的综合信息。在将Cenhanced输入到GPT-2解码器之前,需要将其转化为适合GPT-2输入格式的表示。常见的做法是将Cenhanced与一个特殊的开始标记<s>进行连接,然后开始文本生成过程。
文本生成策略:
为了得到流畅并与上下文相关的回答,使用Top-p采样生成策略,即累积概率超过p的词来进行随机采样。
S42:使用真实的下一轮对话与预测回答计算交叉熵损失LCE
计算预测的概率分布:
GPT-2解码器为下一轮对话的每个时间步t输出一个词概率分布Ppred(wt),其中wt是该时间步的预测词。
真实的词标签:
对于下一轮对话的每个时间步t,有一个真实的词标签wtrue,t
计算交叉熵损失:
对于每个时间步t,交叉熵损失计算为:
对所有时间步求和,得到整个序列的平均交叉熵损失:
其中T是序列的长度。
S43:计算Cenhanced与Cinit之间的KL散度LKL
KL散度,即Kullback-Leibler散度,是衡量两个概率分布相似性的常用方法。在这里,将其用于测量Cenhanced和Cinit之间的差异。
计算Cenhanced和Cinit之间的KL散度的公式为:
其中,i表示向量中的元素索引。
S44:应用PPO策略中的策略比率clipping机制对LCE进行修正:
定义新策略的概率:即基于Cenhanced的回答预测的概率:Pnew
定义旧策略的概率,即基于Cinit的回答预测的概率:Pold
具体来说是,使用Cinit和Cenhanced作为输入预训练的解码器生成下一轮的对话回答。这时得到的预测概率分布标记为Pnew和Pold
定义策略比例rt(θ)为新策略与旧策略的比率。具体来说,就是Pnew除以Pold:
假设句子 "I like it" 包含三个位置,那么对于这三个位置,都会得到一个概率分布。考虑以下概率分布:
对于位置1(即第一个词 "I"):
Pold(1)=[0.8,0.05,0.05,0.05,0.05,0,0,0]
Pnew(1)=[0.85,0.03,0.03,0.03,0.06,0,0,0]
对于位置2(即第二个词 "like"):
Pold(2)=[0,0.7,0.1,0,0.1,0.05,0,0.05]
Pnew(2)=[0,0.65,0.15,0,0.12,0.03,0,0.05]
对于位置3(即第三个词 "it"):
Pold(3)=[0,0,0.9,0,0,0.05,0.03,0.02]
Pnew(3)=[0,0,0.88,0,0,0.06,0.04,0.02]
对于每个位置,计算策略比率:
和 />分别代表新策略和旧策略在该位置的概率分布。
之后,基于每个位置的策略比率 ,可以为每个位置计算 PPO 损失。
为了得到整个句子的 PPO 损失,对所有位置的损失进行求和:
N代表句子中的词的数量,ϵ是一个预设的小值。
S45、S46:计算总损失,并基于Ltotal,通过反向传播,优化多任务学习模型,进一步提炼Cenhanced
损失计算:
总损失是PPO损失与KL散度的加权和:
确定总损失Lfine_tune,其中,β表示权重系数,LPPO表示PPO损失,LKL表示KL散度。
反向传播:
基于Lfine_tune,使用小学习率及Adam优化器对多任务学习中的每个模型(不同分类头)进行微调,只调整其主干部分(不包括嵌入层和分类头)进一步提炼和优化 Cenhanced
综上可以看出,本发明公开了以下技术点:
本发明提出了一种基于多轮对话数据集,通过Ernie3模型编码,智能化地从对话内容中提取出语义信息并形成初始上下文向量Cinit。这一技术手段的突破在于滑动窗口策略和自注意力池化策略的结合,确保了输入维度的一致性和对话内容的精确表示。
本发明设计了一种基于多任务的学习模型,该模型接收预处理过的文本作为输入并通过各个下游任务进行学习,从而提取出不同的任务特性并形成增强向量Cenhanced。这种方法实现了多任务信息的有效融合,并确保增强向量与初始向量具有相同的维度。
本发明进一步提出了一个解码模型的预训练策略,用来预测下一轮的对话回答。关键之处在于,模型不仅能够理解基于Cinit的对话内容,还经过训练确保其具备基于Cinit生成下一轮回答的能力。
本发明还设计了一套增强向量的评估与优化机制。首先利用预训练的解码器评估Cenhanced的效果,随后引入了PPO策略与KL散度计算,确保增强向量表示集中的向量与初始向量在分布上的相似性。此方法的核心在于结合策略比率clipping机制与交叉熵损失,达到微调并进一步提炼Cenhanced的目的。
本发明设计了一种对多任务学习模型的持续优化方法,创新性地采用了基于PPO的损失来微调多任务模型,确保在加强Cenhanced的同时保持了模型的稳定性与效果的持续提升。
在一个实施例中,如图6所示,提供了一种多轮对话上下文向量增强系统,系统包括:
对话内容处理模块,用于获取多轮对话数据集,将多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将子对话数据集通过Ernie3模型进行处理生成向量表示集;其中,每个子对话数据生成一个向量表示;
向量提取模块,用于构建多个下游任务模型,将子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集;其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成,不同下游任务至少包括用户意图识别、情感分析以及关键信息抽取;
预训练模块,用于建立用于预测下一轮对话回答的解码器,利用向量表示集作为输入预训练解码器;
向量优化模块,用于将增强向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算子对话数据集和增强向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于PPO损失和KL散度确定总损失,并基于总损失通过反向传播优化多任务模型得到目标多任务模型;其中,目标多任务模型中基础Ernie3模型和各自不同的分类头组成的下游任务模型,以及自注意机制所对应的前馈神经网络。
在一个实施例中,提供了一种电子设备,该电子设备可以是服务器,其内部结构图可以如图7所示。该电子设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该电子设备的处理器用于提供计算和控制能力,网络接口用于与外部的终端通过网络连接通信,该电子设备通过加载运行计算机程序以实现上述一种专利多领域知识抽取方法。
本领域技术人员可以理解,图7中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的电子设备的限定,具体的电子设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,还提供了一种计算机可读存储介质,其上存储有计算机程序,涉及上述实施例方法中的全部或部分流程。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。

Claims (10)

1.一种多轮对话上下文向量增强方法,其特征在于,所述方法包括:
获取多轮对话数据集,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将所述子对话数据集通过Ernie3模型进行处理生成向量表示集;其中,每个子对话数据生成一个向量表示;
构建多个下游任务模型,将所述子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集;其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成;
建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器;
将增强向量表示集中的向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算向量表示集中的向量和增强向量表示集中的向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于所述PPO损失和所述KL散度确定总损失,并基于所述总损失通过反向传播优化多任务模型得到目标多任务模型。
2.根据权利要求1所述的方法,其特征在于,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,包括:
根据预设窗口大小,对多轮对话数据集中的对话内容进行截断,并使用滑动窗口策略增加数据量;
其中,当存在对话轮次小于预设窗口大小时,使用填充策略将当前对话轮次补齐至预设窗口大小。
3.根据权利要求1所述的方法,其特征在于,将所述子对话数据集通过Ernie3模型进行处理生成向量表示集,包括:
子对话数据集被送入Ernie3模型的Transformer层进行处理;其中,每轮对话通过多个Transformer块的处理后,都会生成一个向量表示;
并使用自注意力机制将各个向量相互关联起来,并为各个向量生成对应权重,进行计算各个向量的加权平均值作为输出;
具体的公式为:
其中,vi是第i个向量,f是前馈神经网络,N是总的向量数量,αi是第i个向量的权重;
其中,Cinit为生成的向量表示。
4.根据权利要求1所述的方法,其特征在于,将所述子对话数据集输入至各个下游任务模型进行学习,包括:
将各个下游任务模型的损失函数进行加权求和得到总损失函数。
5.根据权利要求1所述的方法,其特征在于,在构建多个下游任务模型过程中,不同下游任务至少包括用户意图识别任务、情感分析任务以及关键信息抽取任务。
6.根据权利要求1所述的方法,其特征在于,建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器,包括:
将向量表示集作为初始上下文向量输入到解码器中;其中,损失函数为交叉熵损失,对比真实的下一轮对话答案和解码器的预测来进行训练,公式包括:
其中,t指的是序列中的位置,是yt是在时间步t的真实标签的one-hot编码,pt是模型在时间步t预测的概率分布,LCE是交叉熵损失。
7.根据权利要求1所述的方法,其特征在于,基于所述PPO损失和所述KL散度确定总损失,包括:
通过公式
确定总损失Lfine_tune,其中,β表示权重系数,LPPO表示PPO损失,LKL表示KL散度。
8.一种多轮对话上下文向量增强系统,其特征在于,系统包括:
对话内容处理模块,用于获取多轮对话数据集,将所述多轮对话数据集进行预处理后根据预设大小的窗口对各个对话数据集中对话内容进行截断得到子对话数据集,并将所述子对话数据集通过Ernie3模型进行处理生成向量表示集;其中,每个子对话数据生成一个向量表示;
向量提取模块,用于构建多个下游任务模型,将所述子对话数据集输入至各个下游任务模型进行学习,在训练完成后,将子对话数据集输入训练好的多任务模型,从每个任务模型分类头的前一层提取输出,并将输出进行加权平均形成增强向量表示集;其中,每个下游任务模型均由基础Ernie3模型和各自不同的分类头组成;
预训练模块,用于建立用于预测下一轮对话回答的解码器,利用所述向量表示集作为输入预训练解码器;
向量优化模块,用于将增强向量表示集中的向量输入到预训练后的解码器中得到预测回答,通过预测回答与子对话数据集中的真实回答计算交叉熵损失,同时计算子对话数据集和增强向量表示集中的向量的KL散度;应用PPO策略中的策略比率clipping机制对交叉熵损失进行修正得到PPO损失,基于所述PPO损失和所述KL散度确定总损失,并基于所述总损失通过反向传播优化多任务模型得到目标多任务模型。
9.一种电子设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述方法的步骤。
CN202311639567.9A 2023-12-04 2023-12-04 一种多轮对话上下文向量增强方法及系统 Active CN117350304B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311639567.9A CN117350304B (zh) 2023-12-04 2023-12-04 一种多轮对话上下文向量增强方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311639567.9A CN117350304B (zh) 2023-12-04 2023-12-04 一种多轮对话上下文向量增强方法及系统

Publications (2)

Publication Number Publication Date
CN117350304A CN117350304A (zh) 2024-01-05
CN117350304B true CN117350304B (zh) 2024-02-02

Family

ID=89365238

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311639567.9A Active CN117350304B (zh) 2023-12-04 2023-12-04 一种多轮对话上下文向量增强方法及系统

Country Status (1)

Country Link
CN (1) CN117350304B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
TWI746214B (zh) * 2020-10-19 2021-11-11 財團法人資訊工業策進會 機器閱讀理解方法、機器閱讀理解裝置及非暫態電腦可讀取媒體

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115956261A (zh) * 2021-04-06 2023-04-11 辉达公司 神经网络中分布外输入数据的识别技术
CN116415154A (zh) * 2023-06-12 2023-07-11 江西五十铃汽车有限公司 一种基于gpt的车辆故障解决方案生成方法及装置
CN116501861A (zh) * 2023-06-25 2023-07-28 知呱呱(天津)大数据技术有限公司 基于层级bert模型与标签迁移的长文本摘要生成方法
CN116992833A (zh) * 2023-06-30 2023-11-03 平安科技(深圳)有限公司 对话生成模型的训练方法、装置、设备及介质
WO2023222887A1 (en) * 2022-05-19 2023-11-23 Deepmind Technologies Limited Intra-agent speech to facilitate task learning

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115956261A (zh) * 2021-04-06 2023-04-11 辉达公司 神经网络中分布外输入数据的识别技术
WO2023222887A1 (en) * 2022-05-19 2023-11-23 Deepmind Technologies Limited Intra-agent speech to facilitate task learning
CN116415154A (zh) * 2023-06-12 2023-07-11 江西五十铃汽车有限公司 一种基于gpt的车辆故障解决方案生成方法及装置
CN116501861A (zh) * 2023-06-25 2023-07-28 知呱呱(天津)大数据技术有限公司 基于层级bert模型与标签迁移的长文本摘要生成方法
CN116992833A (zh) * 2023-06-30 2023-11-03 平安科技(深圳)有限公司 对话生成模型的训练方法、装置、设备及介质

Non-Patent Citations (4)

* Cited by examiner, † Cited by third party
Title
ERNIE 3.0 Tiny: Frustratingly Simple Method to Improve Task-Agnostic Distillation Generalization;Weixin Liu etc.;AghaarXiv: 2301.03416v1 [cs.CL];全文 *
ERNIE 3.0: LARGE-SCALE KNOWLEDGE ENHANCED PRE-TRAINING FOR LANGUAGE UNDERSTANDING AND GENERATION;Yu Sun etc.;arXiv:2107.02137v1 [cs.CL];全文 *
Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training;Hong Liu etc.;arXiv:2305.14342v3 [cs.LG];全文 *
一种基于多任务学习的多模态情感识别方法;林子杰 ,等;北京大学学报(自然科学版);第57卷(第1期);全文 *

Also Published As

Publication number Publication date
CN117350304A (zh) 2024-01-05

Similar Documents

Publication Publication Date Title
CN110046221B (zh) 一种机器对话方法、装置、计算机设备及存储介质
US20230028944A1 (en) Dialogue generation method and network training method and apparatus, storage medium, and device
CN108427771B (zh) 摘要文本生成方法、装置和计算机设备
CN106448670B (zh) 基于深度学习和强化学习的自动回复对话系统
CN110032630B (zh) 话术推荐设备、方法及模型训练设备
CN109977201B (zh) 带情感的机器聊天方法、装置、计算机设备及存储介质
CN111966800B (zh) 情感对话生成方法、装置及情感对话模型训练方法、装置
CN117350304B (zh) 一种多轮对话上下文向量增强方法及系统
CN106682387A (zh) 用于输出信息的方法和装置
CN112115246A (zh) 基于对话的内容推荐方法、装置、计算机设备及存储介质
CN112307168A (zh) 基于人工智能的问诊会话处理方法、装置和计算机设备
CN113988086A (zh) 对话处理方法及装置
CN114168707A (zh) 一种面向推荐的情绪型对话方法
Lee et al. Deep representation learning for affective speech signal analysis and processing: Preventing unwanted signal disparities
CN110955765A (zh) 智能助理的语料构建方法、装置、计算机设备和存储介质
CN115269836A (zh) 意图识别方法及装置
CN113806564A (zh) 多模态信息性推文检测方法及系统
CN109727091A (zh) 基于对话机器人的产品推荐方法、装置、介质及服务器
CN113656542A (zh) 一种基于信息检索与排序的话术推荐方法
CN117271745A (zh) 一种信息处理方法、装置及计算设备、存储介质
CN117494762A (zh) 学生模型的训练方法、素材处理方法、装置及电子设备
US11941508B2 (en) Dialog system with adaptive recurrent hopping and dual context encoding
CN113849641B (zh) 一种跨领域层次关系的知识蒸馏方法和系统
CN111078854B (zh) 问答预测模型的训练方法及装置、问答预测方法及装置
CN114547276A (zh) 基于三通道图神经网络的会话推荐方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant
TR01 Transfer of patent right

Effective date of registration: 20240226

Address after: No. 401-1, 4th floor, podium, building 3 and 4, No. 11, Changchun Bridge Road, Haidian District, Beijing 100089

Patentee after: Beijing Zhiguagua Technology Co.,Ltd.

Country or region after: China

Address before: 806A, Building 1, Sixin Building, South Side of Heiniucheng Road, Hexi District, Tianjin, 300221

Patentee before: Zhiguagua (Tianjin) Big Data Technology Co.,Ltd.

Country or region before: China

TR01 Transfer of patent right
CP03 Change of name, title or address

Address after: No. 401-1, 4th floor, podium, building 3 and 4, No. 11, Changchun Bridge Road, Haidian District, Beijing 100089

Patentee after: Beijing Xinghe Zhiyuan Technology Co.,Ltd.

Country or region after: China

Address before: No. 401-1, 4th floor, podium, building 3 and 4, No. 11, Changchun Bridge Road, Haidian District, Beijing 100089

Patentee before: Beijing Zhiguagua Technology Co.,Ltd.

Country or region before: China

CP03 Change of name, title or address