CN111737401A - 一种基于Seq2set2seq框架的关键词组预测方法 - Google Patents
一种基于Seq2set2seq框架的关键词组预测方法 Download PDFInfo
- Publication number
- CN111737401A CN111737401A CN202010576549.0A CN202010576549A CN111737401A CN 111737401 A CN111737401 A CN 111737401A CN 202010576549 A CN202010576549 A CN 202010576549A CN 111737401 A CN111737401 A CN 111737401A
- Authority
- CN
- China
- Prior art keywords
- question
- keywords
- capsule
- keyword
- vector
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/30—Semantic analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/31—Indexing; Data structures therefor; Storage structures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/33—Querying
- G06F16/3331—Query processing
- G06F16/334—Query execution
- G06F16/3344—Query execution using natural language analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/049—Temporal neural networks, e.g. delay elements, oscillating neurons or pulsed inputs
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Databases & Information Systems (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Machine Translation (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明涉及自然语言处理技术领域,尤其涉及一种基于Seq2set2seq框架的关键词组预测方法。包括以下步骤:将当前问句输入预先训练的多标签分类器,输出多维向量;预先训练的多标签分类器是基于训练集中问句和问句回复所对应的关键词,采用胶囊网络进行训练得到的模型;S2、获取多维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到多个指导中心词;S3、将当前问句和多个指导中心词输入预先训练的解码器,对应输出多组预测的关键词组。本发明提供的预测方法解决了现有方法中关键词组预测差异性差、生成关键词组少、性能波动的技术问题。
Description
技术领域
本发明涉及自然语言处理技术领域,尤其涉及一种基于Seq2set2seq框架的关键词组预测方法。
背景技术
开放域对话系统(也称聊天机器人)是指对用户提出的任何开放域查询(或称问题)都能给出合适响应(也称回复)的人工智能应用系统。相对于传统信息服务形式(如问答系统、搜索引擎等),开放域对话系统提供更丰富的语义内容和更有效的交互模式,这些特性使得其在可预见的未来生活中越来越普及,因此开放域对话系统相关技术的研究具有极大的经济效应和社会价值。
开放域对话生成以其广阔的应用前景成为自然语言处理领域的一个研究热点,包括聊天机器人、虚拟个人助理等。尽管已经有很多系统从多个方面提出了提高生成响应质量的方案,大多数主要建立在seq2seq架构上,但该架构存在“安全”响应问题。
目前,针对关键词一对一训练的模型,包括增强波束搜索、基于变分自编码器的模型等,这些方法中将多个响应视为独立的响应,无法对多个响应的差异进行建模。针对关键词一对多的模型,将不同组关键词使用特殊符号进行拼接,之后将其建模为序列到序列的任务,但这种模型因为顺序性问题存在生成关键词组较少,性能波动的问题。
发明内容
(一)要解决的技术问题
鉴于现有技术的上述缺点、不足,本发明提供一种基于Seq2set2seq框架的关键词组预测方法,其解决了现有关键词组预测差异性差、生成关键词组少、性能波动的技术问题。
(二)技术方案
为了达到上述目的,本发明采用的主要技术方案包括:
本发明实施例提供一种基于Seq2set2seq框架的关键词组预测方法,Seq2set2seq框架包括预测指导中心词和预测关键词组两部分,包括以下步骤:
S1、将当前问句输入预先训练的多标签分类器,输出多维向量;
多维向量的每一维均对应一个标签,一个标签对应一个关键词,多维向量的每一维的值表示对应标签作为当前问句回复中关键词的概率;
预先训练的多标签分类器是基于训练集中问句和问句回复所对应的关键词,采用胶囊网络进行训练得到的模型;
S2、获取多维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到多个指导中心词;
S3、将当前问句和多个指导中心词输入预先训练的解码器,对应输出多组预测的关键词组;
预先训练的解码器是基于训练集中问句和对应指导中心词、关键词组,采用长短期记忆网络LSTM进行训练得到的模型。
本发明实施例提出的基于Seq2set2seq框架的关键词组预测方法,采用seq2set2seq的两阶段关键词组生成模型,通过考虑不同响应的差异性,利用指导中心词指导建模一对多的对话关键词组预测任务,相对于现有技术而言,其可以具有良好的可解释性、可控性、准确性以及多样性。
可选地,多标签分类器的训练过程包括:
101、向胶囊网络输入问句,提取多个特征,并将多个特征按照最后一个维度进行拼接后输入初级胶囊网络层;
102、将初级胶囊网络层中的初级胶囊转化为路由胶囊;
103、将路由胶囊转化为向量形式,输出多维向量。
可选地,步骤102包括:通过线性变换将初级胶囊转换为浓缩胶囊,通过动态路由将浓缩胶囊转化为路由胶囊。
可选地,浓缩胶囊满足以下公式:
可选地,路由胶囊中元素值满足以下公式:
式中,uj|i为路由胶囊中元素值,Wc为线性变换矩阵。
可选地,路由胶囊满足以下公式:
式中,vj为路由胶囊,cij为耦合系数,通过聚类方法迭代更新,i为胶囊内部的标量的索引。
可选地,步骤S2中,获多维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到5个指导中心词。
可选地,步骤S2包括:
S21、使用步骤S1中多标签分类器中预测的前100维向量所对应的关键词作为关键词集合,对同时出现在预测的100个关键词和当前问句中的N个关键词进行抽取作为指导中心词,N≤5;
S22、使用行列式点过程从剩余的100-N个关键词中采样5-N个关键词作为指导中心词。
可选地,步骤S22具体为:选用100-N个关键词和问句的主题相似度构成一个100-N维的相似度列表,将100-N个关键词的词与词之间的相似度构成一个[100-N,100-N]的矩阵,使用行列式点过程从中采样5-N个既与当前问句相似度大于或等于0.2并且不同关键词之间相似度小于或等于0.06的的关键词作为指导中心词。
可选地,解码器的训练过程包括:
201、将预先获得的指导中心词与真实关键词组使用Familia进行主题相似度计算,构造解码器的训练集,相似度计算公式如下:
式中,Similarity(w,c)为指导中心词和真实关键词组的主题相似度,w为指导中心词,c为真实关键词组,k为主题的编号,zk为第k个主题,p(w|zk)为指导中心词属于第k个主题的概率,p(zk|c)为真实关键词组的主题分布;
202、将训练集中的指导中心词进行LSTM编码得到指导中心词经过LSTM编码后的向量,将训练集中的问句进行LSTM编码得到问句经过LSTM编码后的向量;
203、将指导中心词经过LSTM编码后的向量和问句经过LSTM编码后的向量进行拼接输入到解码器,输出预测的关键词组;
其中,将解码器的生成结果与真实关键词组进行损失计算,从而不断地更新网络的参数,完成解码器的训练;
损失计算公式如下:
式中,L为损失函数,Ly为需要预测的关键词组的长度,Yt为第t时刻需要预测的关键词组,yt-1为第t-1时刻预测出的关键词组,e为问句经过LSTM编码后的向量,r为指导中心词经过LSTM编码后的向量,θ为训练参数,-logP为负对数似然。
(三)有益效果
本发明的有益效果是:本发明提供的基于Seq2set2seq框架的关键词组预测方法,采用seq2set2seq的两阶段关键词组生成模型,通过考虑不同响应的差异性,利用指导中心词指导建模一对多的对话关键词组预测任务,相对于现有技术而言,其可以具有良好的可解释性、可控性、准确性以及多样性。
附图说明
图1为本发明提供的基于Seq2set2seq框架的关键词组预测方法的流程图;
图2为本发明中Seq2set2seq框架的关键词组生成架构图;
图3为本发明提供的基于Seq2set2seq框架的关键词组预测方法的总体思想图。
具体实施方式
为了更好的解释本发明,以便于理解,下面结合附图,通过具体实施方式,对本发明作详细描述。
一般来说,一个对话系统可以分为两类:任务型对话系统和开放域系统。本发明的第一相关实施例中将特定任务系统转化为部分可观察的马尔可夫决策过程,通过强化学习进行学习,该架构下的对话状态跟踪和响应生成等模块可以分别进行训练,然后插入到一个完全部署的系统中。本发明的第二相关实施例中提出了一种基于端到端神经网络的模块联合训练模型。自Seq2Seq模型提出以来,开放域系统得到了广泛的研究。基于此,本发明的第三相关实施例中提出添加上下文向量作为解码器的额外输入,本发明的第四相关实施例中提出了目标序列与输入序列对齐的注意机制,将这些模型引入到对话生成领域,并进一步提出了一个混合模型。本发明的第五相关实施例中考虑到消息和响应之间的复制现象,提出了一种复制机制,将消息的单词复制到响应中。本发明的第六相关实施例中使用分层的方法对话语和交互结构进行建模,建立多轮对话系统。此外,还可以使用角色信息来生成更一致的响应。
尽管基于seq2seq的模型在构建开放域对话系统方面取得了成功,但是这些模型生成的响应往往是通用的。为了解决这个问题,本发明的第七相关实施例中引入最大互信息作为目标函数。本发明的第八相关实施例中在生成过程中加入一个潜在的随机变量。本发明的第九相关实施例中将响应生成转换为一个强化学习问题,并通过手动定义奖励函数来惩罚一般的响应。本发明的第十相关实施例中提出了一个奖励函数来探索用户反馈中隐含的反馈。然而,奖励功能很难设计来涵盖理想反应的关键方面。本发明的第十一相关实施例中提出了一种对抗学习模型,该模型使用一个鉴别器来计算奖励。基于此,本发明的第十二相关实施例中使用检索到的响应候选项增强了基于对抗的神经响应生成模型的性能,将关键字合并到响应中也可以解决一般的响应问题。本发明的第十三相关实施例中提出了基于单个关键字的方法来从关键字开始生成响应,并随后生成其余的以前和将来的单词。本发明的第十四相关实施例中提出了基于多关键字的方法,该方法使用一个额外的粗编码器对从消息中提取的关键字进行编码。本发明的第十五相关实施例中提出基于多关键词的模型,用主题词扩充解码器,主题词由注意力机制线性组合。在生成过程中加入关键字是获得更多信息响应的有效方法。但是,对关键词的生成还没有得到充分的讨论。
基于上述,本发明实施例提出的基于Seq2set2seq框架的关键词组预测方法,采用seq2set2seq的两阶段关键词组生成模型,通过考虑不同响应的差异性,利用指导中心词指导建模一对多的对话关键词组预测任务,相对于现有技术而言,其可以具有良好的可解释性、可控性、准确性以及多样性。
为了更好的理解上述技术方案,下面将参照附图更详细地描述本发明的示例性实施例。虽然附图中显示了本发明的示例性实施例,然而应当理解,可以以各种形式实现本发明而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更清楚、透彻地理解本发明,并且能够将本发明的范围完整的传达给本领域的技术人员。
实施例1
如图1所示,为本实施例提供的基于Seq2set2seq框架的关键词组预测方法的流程图,如图2所示,为本实施例提供的基于Seq2set2seq框架的关键词组预测方法的总体架构图,包括以下步骤:
S1、将当前问句输入预先训练的多标签分类器,输出10000维向量。10000维向量的每一维均对应一个标签,一个标签对应一个关键词,10000维向量的每一维的值表示对应标签作为当前问句回复中关键词的概率。预先训练的多标签分类器是基于训练集中问句和问句回复所对应的关键词,采用胶囊网络进行训练得到的模型。
胶囊网络的训练过程包括:
101、向胶囊网络输入问句,提取多个特征,并将多个特征按照最后一个维度进行拼接后输入初级胶囊网络层。具体地:
101a、令X∈Rl×v作为问句词嵌入后的矩阵,其中,R为向量,l为问句的长度,v为词嵌入的维度。令Wa∈Rl×k作为一个滤波器,k为滤波器宽度。使用滤波器Wa对问句词嵌入后的向量Rl×v进行卷积操作得到特征mi:f为RELU激活函数。
101b、将所有的特征mi收集到特征矩阵中,为了得到不同角度的特征,将三个不同窗口大小的过滤器(2,3,4)提取多个特征,并将多个特征拼接后输入初级胶囊网络层。
102、将初级胶囊网络层中的初级胶囊转化为路由胶囊。
具体地,在初级胶囊网络层,使用组卷积操作将特征矩阵转化到初级胶囊中:使用1*1的过滤器将特征矩阵中的标量映射为胶囊,即一个d维的向量:
每个胶囊的模长代表该胶囊的概率,其值在0到1之间,其中非线性函数g的计算公式如下:
式中,g(x)为模长,x为每个胶囊的向量表示。
通过线性转换同时去除异常值将大量的初级胶囊转换为浓缩胶囊,计算公式为:
式中,ui为浓缩胶囊,bj为可训练的参数,pj为初级胶囊,j为胶囊的编号。
通过动态路由将浓缩胶囊转化为路由胶囊(更高层次的胶囊)。路由胶囊中元素值满足以下公式:
式中,uj|i为路由胶囊中元素值,Wc为线性变换矩阵。
路由胶囊满足以下公式:
式中,vj为路由胶囊,cij为耦合系数,通过聚类方法迭代更新,i为胶囊内部的标量的索引。
103、将路由胶囊转化为向量形式,输出多维向量。
S2、获取10000维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到5个指导中心词。
由于对话中的回复范围是开放的,关键词的回复也因此是宽泛的,这也会造成一种数据噪声,即对话中不包含的关键词也可能是对的,从而导致模型在训练时错误地降低了可能正确的标签的概率;同时步骤S1中的标签呈现了极端的长尾现象,这两方面对关键词的预测造成了极端的影响:直接使用预测的前5个关键词作为结果,准确率较低,效果不太理想。
具体地,步骤S2包括:
S21、使用步骤S1中多标签分类器中预测前100维向量所对应的关键词作为关键词集合从而保证召回率,再对同时出现在预测的前100个关键词和当前问句中的N个词进行抽取作为指导中心词,从而保证采样得到的指导中心词尽可能多样和丰富,尽可能覆盖当前问句的回复方向。其中N≤5,可以直观理解为当前的问句的部分回复的方向有可能已经存在于当前问句当中。
S22、使用行列式点过程从剩余的100-N个关键词中采样5-N个关键词作为指导中心词。具体地,选用100-N个关键词和问句的主题相似度构成一个100-N维的相似度列表,将100-N个关键词的词与词之间的相似度构成一个[100-N,100-N]的矩阵,使用行列式点过程从中采样5-N个既与当前问句相似度大于或等于0.2并且不同关键词之间相似度小于或等于0.06的关键词作为指导中心词。
S3、将当前问句和5个指导中心词输入预先训练的解码器,输出5组预测的关键词组。其中,预先训练的解码器是基于训练集中问句和对应指导中心词、关键词组,采用长短期记忆网络(Long Short-Term Memory,简称为LSTM)进行训练得到的模型。
解码器的训练过程包括:
201、将预先获得的指导中心词与真实关键词组使用Familia进行主题相似度计算,构造解码器的训练集,相似度计算公式如下:
式中,Similarity(w,c)为指导中心词和真实关键词组的主题相似度,w为指导中心词,c为真实关键词组,k为主题的编号,zk为第k个主题,p(w|zk)为指导中心词属于第k个主题的概率,p(zk|c)为真实关键词组的主题分布。
202、将训练集中的指导中心词进行LSTM编码得到指导中心词经过LSTM编码后的向量,将训练集中的问句进行LSTM编码得到问句经过LSTM编码后的向量。
203、将指导中心词经过LSTM编码后的向量和问句经过LSTM编码后的向量进行拼接输入到解码器,输出预测的关键词组。
其中,将解码器的生成结果与真实关键词组进行损失计算,从而不断地更新网络的参数,完成解码器的训练。
损失计算公式如下:
式中,L为损失函数,Ly为需要预测的关键词组的长度,Yt为第t时刻需要预测的关键词组,yt-1为第t-1时刻预测出的关键词组,e为问句经过LSTM编码后的向量,r为指导中心词经过LSTM编码后的向量,θ为训练参数,-logP为负对数似然。
进一步地,若没有匹配到相似度大于0.2的指导中心词,将真实关键词中随机抽取一个真实关键词作为指导中心词来防止噪声的产生。通过此项匹配构成训练数据,即通过问句和指导中心词生成对应的关键词组。如表1所示,为匹配后形成的训练数据样例。
表1匹配后形成的训练数据样例
问句 | 指导中心词 | 关键词组 |
这张照片抓拍得太好了! | 照片 | 照片 |
这张照片抓拍得太好了! | 漂亮 | 拍摄画面 |
这张照片抓拍得太好了! | 手机 | 抓拍摄像机 |
这张照片抓拍得太好了! | 微博 | 曝光 |
这张照片抓拍得太好了! | 狗狗 | 喜欢视频 |
综上所述,本发明的基于Seq2set2seq框架的关键词组预测方法,采用seq2set2seq的两阶段关键词组生成模型,通过考虑不同响应的差异性,利用指导中心词指导建模一对多的对话关键词组预测任务,相对于现有技术而言,其可以具有良好的可解释性、可控性、准确性以及多样性。具体地,指导中心词为回复的生成提供可控的指导方向,相比于从分布中采样潜在变量的随机性,此方法具有更好的可控性和可解释性;另外,指导中心词作为指导方向可以解决将关键词组拼接之后使用seq2seq生成的顺序性问题,即不同组关键词组之间不存在顺序问题。
本发明提供的方法中,由于在对话中训练了一个多标签分类模型来预测一个问句所有可能的回复关键词,此种模型的特点在于,较少数据中的不确定性,即当提供一个问句和一个指导中心词后,指导中心词将能够提供给模型一个具体的回复方向,并且指导中心词的维度较高,足以影响模型进行学习。同时在测试阶段,本发明使用一个多标签分类的模型预测出前100的关键词,之后使用主题模型和行列式点过程选出5个既相关又多样的指导中心词作为指导进行关键词组的生成。实验结果表明,本发明提供的方法可以超越现有的竞争神经模型下的自动评价指标,这表明了该方法的有效性。
进一步地,如图3所示,为本实施例提供的基于Seq2set2seq框架的关键词组预测方法的总体思想图,即将问句首先输入到多标签分类器中预测概率最大的前100个关键词,之后使用规则和行列式点过程预测5个指导中心词,接着将预测的5个指导中心词和问句输入到解码器中,预测得到关键词组。
具体地:如将问句“工作太久,早点退休”输入到多标签分类器中,在多标签分类器内部,使用不同的卷积核大小的滤波器(2,3,4)对问句进行卷积,将使用不同滤波器所卷积后的结果按照最后一维进行拼接作为卷积层(Conv)得到卷积特征,经过初级胶囊层(PrimCap)将卷积特征转化为初级胶囊,将初级胶囊经过压缩(Compression)即经过一个线性变换转换为浓缩胶囊,将浓缩胶囊经过聚合(Aggregation)转化为路由胶囊,最后将路由胶囊转化为10000维的向量,即概率分布(Representation),完成多标签分类器的输出。概率分布代表10000个关键词可以作为当前问句回复中关键词的概率。在得到多标签分类器的10000维的概率分布之后,对此概率分布中概率最大的前100个词,经过拷贝(copy)和行列式点过程(Determinantal Point Process)的操作产生5个指导中心词如“工作、感觉、回家、赛跑、上班”,之后将每个指导中心词和问句“工作太久,早点退休”输入到解码器端,最终会得到5组预测的关键词组如“工作清闲、神仙日子、退休心情、赛跑、上班激情”。
实施例2
由于本实施例的任务是预测关键词组,而在此之前是没有包含关键词组的对话数据集,一般的数据只有单纯的问句和回复,因此如果想要预测关键词组的话,需要构造出有关键词组的对话数据集,所以在本实施例中采集并标注了约5万句微博对话数据的关键词组,同时使用ERNIE模型和CRF模型训练了一个关键词组的序列标注模型,以此对微博400万数据进行关键词组抽取。微博数据是一对多的单轮对话数据,这些回复中的句子质量参差不齐,为了过滤掉与问句毫不相关的回复,本实施例中对每一个问句和其所有的回复进行主题的相似度计算,去除相似度为0的回复,之后保留最多5个关键词组并且这5个关键词组中的关键词均完全不同。最终形成了10万条左右一对多的问句和多个关键词组的配对数据,随机分别各采样1000组作为验证集和测试集,其余作为训练集。
基于上述数据集,分别采用本发明提供的基于Seq2set2seq框架的关键词组预测方法和基于其他模型的方法进行关键词组的预测。
本发明关键词组提取方法中的关键词组生成模块和transformer模型以及seq2seq模型都是通过OpenNMT实现的,OpenNMT是一个用于构建基于seq2seq模型的开源框架。Cvae的潜在变量大小为500,同时所有模型编码器和解码器中的隐藏单元数均为500。使用Adam算法更新参数,并将学习率初始化为0.001。批处理大小设置为64。transformer模型的编码和解码层数均为6,在推理过程中,本实施例中使用波束搜索,将波束大小设置为10。
词汇表由最常见的50,000个单词组成。表中未包含的任何单词都将由表示未知单词的特殊标记“UNK”替换;单词的词向量均为随机初始化得到的,维度为500。
以下为使用不同方法得到的实验结果。
表2不同模型的自动评价结果
表3不使用copy和不使用行列式点过程的自动评价结果
表1和表2中均包括四类评价指标:F-p、F-r、F-f和ARI。其中:
F-p:将预测的一组关键词组与真实的每组关键词组进行比对,当前预测的关键词组相对于真实关键词组的F-p值即为所有比对F-p值中最大的,即最匹配的F-p值,之后依次完成所有预测组的关键词组的F-p值计算,此项指标可以评估模型预测的准确性。
F-r:将真实的一组关键词组与预测的每组关键词组进行比对,当前真实的关键词组相对于预测关键词组的F-r值即为所有比对F-r值中最大的,即最匹配的F-r值,之后依次完成所有真实组的关键词组的F-r值计算,此项指标可以评估模型预测的召回性或范围性。
F-f:F-f=2*F-p*F-r/(F-p+F-r),此项指标可以评估模型的整体性能。
ARI:调整兰德系数(Adjusted Rand index),此项指标可以评估聚类的性能,即评测每一类内部之间的相关性和不同簇之间的分离性,其范围为[-1,1],值越大性能越好。
将本发明提供的方法中的关键词组个数设为3,接近于关键词组个数的平均值(3.5)。在四类指标中,本发明所提出的方法在F-f和ARI上优于所有其他方法。
进一步地,one2one模式的模型在F-p和F-r的评价指标上基本相差不大,这种情况是正常的,是由于one2one模式的seq2seq和transformer使用beamsearch选取3个best的关键词组结果,生成的结果虽然语义上比较相似但内容基本上是完全不同的,因此其F-p值和F-r值较为相近;而one2one模式的cvae则是由于潜在变量的原因导致了关键词组生成的结果具有随机性,可控性较差,关键词组生成的准确度较低。
one2many模式的模型F-p与F-r的指标相差较大,且one2many模式的模型预测的关键词组组数较少,平均预测组数为2,同时生成的不同组关键词组之间存在重复情况,这也是导致one2many模式F-p值较高,F-r值较低的原因,即预测出来的较准,但预测的少,覆盖性较差,可以认为其根本原因是由于不同关键词组之间的顺序性造成的。
本发明提供的方法在F-f的综合评价指标和调整兰德系数上超过所有基线模型,证明了本发明提供的方法具有良好的预测准确性和范围覆盖性,同时调整兰德系数也说明了本发明的模型可以尽可能的将具有相关性的关键词组组合在一起,而相关性差的关键词组则分离开来。
如表4所示,为不同模型关键词组的生成样例。
表4不同模型关键词组的生成样例
由表4可知,本发明提供的方法可以生成更多的适合于问句且多样的关键词组,一对一训练的模式(One2one-seq2seq、One2one-transformer、Cvae)则容易生成一些通用的词,如“喜欢”,同时其生成的结果与问句的相关性较差,而一对多训练的模式(One2many-seq2seq、One2many-transformer)生成的关键词组数偏少,这有可能是关键词组的顺序性导致的,即不同关键词组之间在现实中并不存在顺序性,但一对多训练的模式则是将不同关键词组使用特殊符号如“;”进行拼接,将自动引入顺序性问题,从而导致其生成的结果偏少。
本发明中不使用copy的方法则显示出其结果相对与完整模型,关键词组的丰富度上表现稍差,这说明了copy当前语句中的一些内容作为指导中心词来生成关键词组,有可能降低匹配过程中的噪声,而不使用行列式点过程生成的关键词组往往语义上都比较相近,这也证明了将copy和行列式点过程结合的有效性。同时从整体角度来说,本发明提供的方法可以为关键词组的生成提供可控、可解释、内容丰富,准确性较高的生成结果。
本发明提供的方法避免了一对一模式的seq2seq采用beamsearch后生成的关键词组较为相似且容易生成高频词的问题,同时也解决了cvae通过采样变量的方法导致可控性较差的问题,且避免了一对多模式所自动引入顺序性的问题。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例,或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。
应当注意的是,在权利要求中,不应将位于括号之间的任何附图标记理解成对权利要求的限制。词语“包含”不排除存在未列在权利要求中的部件或步骤。位于部件之前的词语“一”或“一个”不排除存在多个这样的部件。本发明可以借助于包括有若干不同部件的硬件以及借助于适当编程的计算机来实现。在列举了若干装置的权利要求中,这些装置中的若干个可以是通过同一个硬件来具体体现。词语第一、第二、第三等的使用,仅是为了表述方便,而不表示任何顺序。可将这些词语理解为部件名称的一部分。
此外,需要说明的是,在本说明书的描述中,术语“一个实施例”、“一些实施例”、“实施例”、“示例”、“具体示例”或“一些示例”等的描述,是指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
尽管已描述了本发明的优选实施例,但本领域的技术人员在得知了基本创造性概念后,则可对这些实施例作出另外的变更和修改。所以,权利要求应该解释为包括优选实施例以及落入本发明范围的所有变更和修改。
显然,本领域的技术人员可以对本发明进行各种修改和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也应该包含这些修改和变型在内。
Claims (10)
1.一种基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述Seq2set2seq框架包括预测指导中心词和预测关键词组两部分,具体包括以下步骤:
S1、将当前问句输入预先训练的多标签分类器,输出多维向量;
所述多维向量的每一维均对应一个标签,一个标签对应一个关键词,多维向量的每一维的值表示对应标签作为当前问句回复中关键词的概率;
所述预先训练的多标签分类器是基于训练集中的问句和问句回复所对应的关键词,采用胶囊网络进行训练得到的模型;
S2、获取多维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到多个指导中心词;
S3、将当前问句和多个指导中心词输入预先训练的解码器,对应输出多组预测的关键词组;
所述预先训练的解码器是基于训练集中问句和对应指导中心词、关键词组,采用长短期记忆网络LSTM进行训练得到的模型。
2.如权利要求1所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述多标签分类器的训练过程包括:
101、向胶囊网络输入问句,提取多个特征,并将多个特征按照最后一个维度进行拼接后输入初级胶囊网络层;
102、将初级胶囊网络层中的初级胶囊转化为路由胶囊;
103、将路由胶囊转化为向量形式,输出多维向量。
3.如权利要求2所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述步骤102包括:通过线性变换将初级胶囊转换为浓缩胶囊,通过动态路由将浓缩胶囊转化为路由胶囊。
7.如权利要求1所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述步骤S2中,获取多维向量中前100维向量所对应的关键词,进行行列式点过程采样,得到5个指导中心词。
8.如权利要求7所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述步骤S2包括:
S21、使用步骤S1中多标签分类器中预测的前100维向量所对应的关键词作为关键词集合,对同时出现在预测的100个关键词和当前问句中的N个关键词进行抽取作为指导中心词,N≤5;
S22、使用行列式点过程从剩余的100-N个关键词中采样5-N个关键词作为指导中心词。
9.如权利要求8所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述步骤S22具体为:选用100-N个关键词和问句的主题相似度构成一个100-N维的相似度列表,将100-N个关键词的词与词之间的相似度构成一个[100-N,100-N]的矩阵,使用行列式点过程从中采样5-N个既与当前问句相似度大于或等于0.2并且不同关键词之间相似度小于或等于0.06的关键词作为指导中心词。
10.如权利要求9所述的基于Seq2set2seq框架的关键词组预测方法,其特征在于,所述解码器的训练过程包括:
201、将预先获得的指导中心词与真实关键词组使用Familia进行主题相似度计算,构造解码器的训练集,相似度计算公式如下:
式中,Similarity(w,c)为指导中心词和真实关键词组的主题相似度,w为指导中心词,c为真实关键词组,k为主题的编号,zk为第k个主题,p(w|zk)为指导中心词属于第k个主题的概率,p(zk|c)为真实关键词组的主题分布;
202、将训练集中的指导中心词进行LSTM编码得到指导中心词经过LSTM编码后的向量,将训练集中的问句进行LSTM编码得到问句经过LSTM编码后的向量;
203、将指导中心词经过LSTM编码后的向量和问句经过LSTM编码后的向量进行拼接输入到解码器,输出预测的关键词组;
其中,将解码器的生成结果与真实关键词组进行损失计算,从而不断地更新网络的参数,完成解码器的训练;
损失计算公式如下:
式中,L为损失函数,Ly为需要预测的关键词组的长度,Yt为第t时刻需要预测的关键词组,yt-1为第t-1时刻预测出的关键词组,e为问句经过LSTM编码后的向量,r为指导中心词经过LSTM编码后的向量,θ为训练参数,-logP为负对数似然。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010576549.0A CN111737401B (zh) | 2020-06-22 | 2020-06-22 | 一种基于Seq2set2seq框架的关键词组预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010576549.0A CN111737401B (zh) | 2020-06-22 | 2020-06-22 | 一种基于Seq2set2seq框架的关键词组预测方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111737401A true CN111737401A (zh) | 2020-10-02 |
CN111737401B CN111737401B (zh) | 2023-03-24 |
Family
ID=72650457
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010576549.0A Active CN111737401B (zh) | 2020-06-22 | 2020-06-22 | 一种基于Seq2set2seq框架的关键词组预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111737401B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113312897A (zh) * | 2021-06-21 | 2021-08-27 | 复旦大学 | 一种文本总结方法、电子设备及存储介质 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106021272A (zh) * | 2016-04-04 | 2016-10-12 | 上海大学 | 基于分布式表达词向量计算的关键词自动提取方法 |
CN108376131A (zh) * | 2018-03-14 | 2018-08-07 | 中山大学 | 基于seq2seq深度神经网络模型的关键词抽取方法 |
WO2018149326A1 (zh) * | 2017-02-16 | 2018-08-23 | 阿里巴巴集团控股有限公司 | 一种自然语言问句答案的生成方法、装置及服务器 |
CN109299273A (zh) * | 2018-11-02 | 2019-02-01 | 广州语义科技有限公司 | 基于改进seq2seq模型的多源多标签文本分类方法及其系统 |
CN110119765A (zh) * | 2019-04-18 | 2019-08-13 | 浙江工业大学 | 一种基于Seq2seq框架的关键词提取方法 |
WO2019153613A1 (zh) * | 2018-02-09 | 2019-08-15 | 平安科技(深圳)有限公司 | 聊天应答方法、电子装置及存储介质 |
WO2020024951A1 (zh) * | 2018-08-01 | 2020-02-06 | 北京三快在线科技有限公司 | 多义词词义学习以及搜索结果显示 |
CN110825850A (zh) * | 2019-11-07 | 2020-02-21 | 哈尔滨工业大学(深圳) | 一种自然语言主题分类方法及装置 |
CN111209386A (zh) * | 2020-01-07 | 2020-05-29 | 重庆邮电大学 | 一种基于深度学习的个性化文本推荐方法 |
-
2020
- 2020-06-22 CN CN202010576549.0A patent/CN111737401B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106021272A (zh) * | 2016-04-04 | 2016-10-12 | 上海大学 | 基于分布式表达词向量计算的关键词自动提取方法 |
WO2018149326A1 (zh) * | 2017-02-16 | 2018-08-23 | 阿里巴巴集团控股有限公司 | 一种自然语言问句答案的生成方法、装置及服务器 |
WO2019153613A1 (zh) * | 2018-02-09 | 2019-08-15 | 平安科技(深圳)有限公司 | 聊天应答方法、电子装置及存储介质 |
CN108376131A (zh) * | 2018-03-14 | 2018-08-07 | 中山大学 | 基于seq2seq深度神经网络模型的关键词抽取方法 |
WO2020024951A1 (zh) * | 2018-08-01 | 2020-02-06 | 北京三快在线科技有限公司 | 多义词词义学习以及搜索结果显示 |
CN109299273A (zh) * | 2018-11-02 | 2019-02-01 | 广州语义科技有限公司 | 基于改进seq2seq模型的多源多标签文本分类方法及其系统 |
CN110119765A (zh) * | 2019-04-18 | 2019-08-13 | 浙江工业大学 | 一种基于Seq2seq框架的关键词提取方法 |
CN110825850A (zh) * | 2019-11-07 | 2020-02-21 | 哈尔滨工业大学(深圳) | 一种自然语言主题分类方法及装置 |
CN111209386A (zh) * | 2020-01-07 | 2020-05-29 | 重庆邮电大学 | 一种基于深度学习的个性化文本推荐方法 |
Non-Patent Citations (2)
Title |
---|
SANGWOO CHO等: "Improving the Similarity Measure of Determinantal Point Processes for Extractive Multi-Document Summarization", 《ARXIV.ORG》 * |
侯丽微: "基于序列到序列模型的中文生成式自动文摘研究", 《中国优秀硕士学位论文全文数据库 信息科技辑》 * |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113312897A (zh) * | 2021-06-21 | 2021-08-27 | 复旦大学 | 一种文本总结方法、电子设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN111737401B (zh) | 2023-03-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Shen et al. | Dialogxl: All-in-one xlnet for multi-party conversation emotion recognition | |
CN108681610B (zh) | 生成式多轮闲聊对话方法、系统及计算机可读存储介质 | |
CN108549658B (zh) | 一种基于语法分析树上注意力机制的深度学习视频问答方法及系统 | |
CN110222163A (zh) | 一种融合cnn与双向lstm的智能问答方法及系统 | |
CN109992669B (zh) | 一种基于语言模型和强化学习的关键词问答方法 | |
CN114722838A (zh) | 基于常识感知和层次化多任务学习的对话情感识别方法 | |
CN113297364A (zh) | 一种面向对话系统中的自然语言理解方法及装置 | |
CN110705298B (zh) | 一种改进的前缀树与循环神经网络结合的领域分类方法 | |
CN112417894A (zh) | 一种基于多任务学习的对话意图识别方法及识别系统 | |
CN113255366B (zh) | 一种基于异构图神经网络的方面级文本情感分析方法 | |
CN115794999A (zh) | 一种基于扩散模型的专利文档查询方法及计算机设备 | |
CN116150335A (zh) | 一种军事场景下文本语义检索方法 | |
CN110597968A (zh) | 一种回复选择方法及装置 | |
CN109033294A (zh) | 一种融入内容信息的混合推荐方法 | |
CN113254604A (zh) | 一种基于参考规范的专业文本生成方法及装置 | |
CN112818106A (zh) | 一种生成式问答的评价方法 | |
CN113361278A (zh) | 一种基于数据增强与主动学习的小样本命名实体识别方法 | |
CN111737401B (zh) | 一种基于Seq2set2seq框架的关键词组预测方法 | |
CN113177113B (zh) | 任务型对话模型预训练方法、装置、设备及存储介质 | |
CN114817307A (zh) | 一种基于半监督学习和元学习的少样本nl2sql方法 | |
CN111813907A (zh) | 一种自然语言问答技术中的问句意图识别方法 | |
CN117033423A (zh) | 一种注入最优模式项和历史交互信息的sql生成方法 | |
CN115810351A (zh) | 一种基于视听融合的管制员语音识别方法及装置 | |
CN110929006A (zh) | 一种数据型问答系统 | |
CN113239678B (zh) | 一种面向答案选择的多角度注意力特征匹配方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
TA01 | Transfer of patent application right |
Effective date of registration: 20230227 Address after: 100144 Beijing City, Shijingshan District Jin Yuan Zhuang Road No. 5 Applicant after: NORTH CHINA University OF TECHNOLOGY Address before: 100048 No. 105 West Third Ring Road North, Beijing, Haidian District Applicant before: Capital Normal University |
|
TA01 | Transfer of patent application right | ||
GR01 | Patent grant | ||
GR01 | Patent grant |