CN114925197A - 基于主题注意力的深度学习文本分类模型训练方法 - Google Patents
基于主题注意力的深度学习文本分类模型训练方法 Download PDFInfo
- Publication number
- CN114925197A CN114925197A CN202210312063.5A CN202210312063A CN114925197A CN 114925197 A CN114925197 A CN 114925197A CN 202210312063 A CN202210312063 A CN 202210312063A CN 114925197 A CN114925197 A CN 114925197A
- Authority
- CN
- China
- Prior art keywords
- attention
- matrix
- text
- vector
- topic
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 37
- 238000000034 method Methods 0.000 title claims abstract description 35
- 238000013135 deep learning Methods 0.000 title claims abstract description 20
- 238000013145 classification model Methods 0.000 title claims abstract description 19
- 239000013598 vector Substances 0.000 claims abstract description 131
- 239000011159 matrix material Substances 0.000 claims abstract description 81
- 230000000873 masking effect Effects 0.000 claims abstract description 16
- 238000012795 verification Methods 0.000 claims abstract description 9
- 230000006870 function Effects 0.000 claims description 10
- 238000013507 mapping Methods 0.000 claims description 9
- 238000010606 normalization Methods 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 abstract description 9
- 239000010410 layer Substances 0.000 description 18
- 238000013528 artificial neural network Methods 0.000 description 7
- 238000012545 processing Methods 0.000 description 4
- 238000012512 characterization method Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 208000037170 Delayed Emergence from Anesthesia Diseases 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000010365 information processing Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 239000002356 single layer Substances 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本公开实施例中提供了一种基于主题注意力的深度学习文本分类模型训练方法,属于计算技术领域,具体包括:根据原始文本构建文本数据集;得到文本的数字化表示、文本的掩盖序列、文本的数字标签;得到样本,并将样本分为训练集和验证集;初始化前向网络中变量;得到表征文本的一组词向量;得到原始注意力矩阵;得到目标注意力矩阵;根据目标注意力矩阵,得到概率矩阵;计算注意力头输出;得到注意力输出;计算主题输出;计算主题概率向量;计算交叉熵损失;计算前向网络变量的梯度;更新网络变量;迭代计算交叉熵损失以及梯度;当迭代达到预设次数或模型损失趋于稳定,迭代停止。通过本公开的方案提高了模型的并行性、稳定性、可视性和准确率。
Description
技术领域
本公开实施例涉及计算技术领域,尤其涉及一种基于主题注意力的深度学习文本分类模型训练方法。
背景技术
目前,计算机以及互联网行业蓬勃发展,网络用户迅速增长,促进互联网企业以及网络用户更多的内容制作以及内容输出,并产生了大量的互联网数据。互联网数据包含大量文本数据,表现为内容繁多,形式多样。随着文本数据规模日趋庞大,相关企业处理面临的挑战也日益严峻。
文本规模的迅速增长,对文本处理工作提出了较高的要求。与传统的数据相比,网络中的文本数据具有许多新的特点,如数据量大、高度重复、高度冗余等。完全依靠人工处理这些信息的代价过大。文本分类是文本处理一项最为基础的任务,使用计算机快速高效的完成文本分类,有利于缓解信息高速增长带来的信息处理问题。
文本分类经历了从专家系统到机器学习算法再到深度学习算法的跨越。深度学习是机器学习中一种基于对数据进行表征学习的方法,其侧重于利用深度的神经网络,将模型处理得更为复杂,从而使模型对数据的理解更加深入。
深度学习文本分类模型目前主要以人工神经网络、卷积神经网络、循环神经网络为基础。这些网络搭建的模型为黑箱模型,其参数的解释性不高,不利于网络的优化以及实际的使用。同时,基于传统神经网络的文本分类模型在并发性、稳定性、训练速度、准确率等方面还有改进空间。
可见,亟需一种并发性、可解释性、稳定性、训练速度和准确率更高的基于主题注意力的深度学习文本分类模型训练方法。
发明内容
有鉴于此,本公开实施例提供一种基于主题注意力的深度学习文本分类模型训练方法,至少部分解决现有技术中存在并发性、可解释性、稳定性、训练速度和准确率较差的问题。
本公开实施例提供了一种基于主题注意力的深度学习文本分类模型训练方法,包括:
步骤1,获取原始文本,并根据所述原始文本构建文本数据集;
步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;
步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;
步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;
步骤5,根据所述数字化表示,得到表征文本的一组词向量;
步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;
步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;
步骤8,根据所述目标注意力矩阵,得到概率矩阵;
步骤9,根据所述概率矩阵和值向量,计算注意力头输出;
步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;
步骤11,根据所述注意力输出,计算主题输出;
步骤12,根据所述主题输出和主题向量,计算主题概率向量;
步骤13,将所述数字标签转化为one-hot编码形式后根据所述主题概率向量,计算交叉熵损失;
步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;
步骤15,根据所述梯度,更新网络变量;
步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;
步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。
根据本公开实施例的一种具体实现方式,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块。
根据本公开实施例的一种具体实现方式,所述主题向量和查询向量之间、所述词向量和键向量之间、所述词向量和所述值向量之间,以及,所述注意力头输出和所述主题输出之间均设置有一个全连接层,主题输出和主题概率向量之间设置有多个全连接层。
根据本公开实施例的一种具体实现方式,所述步骤5具体包括:
将所述数字化表示中的数字序号依次取出,通过序号查询词嵌入表,取出序号对应行数的向量,将取出的向量按序拼接成矩阵,并根据所述矩阵得到所述词向量。
根据本公开实施例的一种具体实现方式,所述原始注意力矩阵Score,计算方法如下:
令Q为查询矩阵,K为键矩阵,V为值矩阵,n为类别数,l为文本最大长度,demb为词向量维度,则:
Q=(q1,q2,…,qn),K=(k1,k2,…,kl,),V=(v1,v2,…,vl,)
将Q矩阵和K的转置矩阵做矩阵乘法,并进行缩放,公式如下:
Scorei,j表示文本中第j个字符对第i个主题的贡献。
根据本公开实施例的一种具体实现方式,所述步骤7具体包括:
步骤7.1,将所述查询向量、键向量、值向量投影到低纬度上,计算每个头独立的注意力;
步骤7.2,根据所述掩盖序列和每个头独立的注意力掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵。
根据本公开实施例的一种具体实现方式,所述概率矩阵的计算公式如下:
Probi=(Probi,1,Probi,2,……,Probi,l)
Prob=Softmax(Score)=(Prob1,Prob2,……,Probn)。
根据本公开实施例的一种具体实现方式,所述注意力头输出的计算公式如下:
根据本公开实施例的一种具体实现方式,所述主题概率向量包含多个主题概率,其中,所述主题概率由主题向量和主题输出进行点积运算得出或者使用单节点全连接网络计算得出。
根据本公开实施例的一种具体实现方式,所述步骤13之前,所述方法还包括:
使用softmax函数对所述主题概率向量进行归一化。
本公开实施例中的基于主题注意力的深度学习文本分类模型训练方案,包括:步骤1,获取原始文本,并根据所述原始文本构建文本数据集;步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;步骤5,根据所述数字化表示,得到表征文本的一组词向量;步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;步骤8,根据所述目标注意力矩阵,得到概率矩阵;步骤9,根据所述概率矩阵和值向量,计算注意力头输出;步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;步骤11,根据所述注意力输出,计算主题输出;步骤12,根据所述主题输出和主题向量,计算主题概率向量;步骤13,将所述数字标签转化为one-hot编码形式后根据所述主题概率向量,计算交叉熵损失;步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;步骤15,根据所述梯度,更新网络变量;步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。
本公开实施例的有益效果为:通过本公开的方案,以神经网络和多头注意力为核心,通过设置主题向量,一定程度上克服了深度学习的黑箱问题,克服了注意力模型在进行长文本分类时的不适应性,提高了模型的并行性、稳定性、可视性和准确率。
附图说明
为了更清楚地说明本公开实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1为本公开实施例提供的一种基于主题注意力的深度学习文本分类模型训练方法的流程示意图;
图2为本公开实施例提供的一种基于主题注意力的深度学习文本分类模型训练方法的模型框架图;
图3为本公开实施例提供的一种前向网络网络结构图;
图4为本公开实施例提供的一种基于主题注意力的深度学习文本分类模型训练方法训练后模型对一篇体育新闻的关注度热力图。
具体实施方式
下面结合附图对本公开实施例进行详细描述。
以下通过特定的具体实例说明本公开的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本公开的其他优点与功效。显然,所描述的实施例仅仅是本公开一部分实施例,而不是全部的实施例。本公开还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本公开的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。基于本公开中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本公开保护的范围。
需要说明的是,下文描述在所附权利要求书的范围内的实施例的各种方面。应显而易见,本文中所描述的方面可体现于广泛多种形式中,且本文中所描述的任何特定结构及/或功能仅为说明性的。基于本公开,所属领域的技术人员应了解,本文中所描述的一个方面可与任何其它方面独立地实施,且可以各种方式组合这些方面中的两者或两者以上。举例来说,可使用本文中所阐述的任何数目个方面来实施设备及/或实践方法。另外,可使用除了本文中所阐述的方面中的一或多者之外的其它结构及/或功能性实施此设备及/或实践此方法。
还需要说明的是,以下实施例中所提供的图示仅以示意方式说明本公开的基本构想,图式中仅显示与本公开中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
另外,在以下描述中,提供具体细节是为了便于透彻理解实例。然而,所属领域的技术人员将理解,可在没有这些特定细节的情况下实践所述方面。
本公开实施例提供一种基于主题注意力的深度学习文本分类模型训练方法,所述方法可以应用于新闻、文档等各类长短文本分类场景的模型有监督学习过程中。
参见图1,为本公开实施例提供的一种基于主题注意力的深度学习文本分类模型训练方法的流程示意图。如图1和图2所示,所述方法主要包括以下步骤:
步骤1,获取原始文本,并根据所述原始文本构建文本数据集;
具体实施时,当需要对某个文本或者某类进文本行分类时,可以先将其作为所述原始文本,整理文本得到文本数据集。其中文本数据集至少包括有效的可分类的文本以及文本对应的标签。
步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;
例如,可以将文本数据集中文本的每个字符映射为数字,得到文本的数字化表示。该映射通过单词表实现。单词表记录全部或部分字符及其对应的唯一序号。单词表可根据实际任务需求新建立,也可以使用已经存在的单词表。
进一步地,实际映射中,需要保证文本长度一致。选取一个合适的值作为最大长度,对文本长度大于最大长度的文本进行截取,对文本长度小于最大长度的文本进行填充。
进一步的,向单词表添加了[“P”][“S”][“UNK”]等标记,这些标记同其它字符一样对应一个唯一数字。[“S”]为文本开始标记,该标记被加在文本的开头,[“P”]为填充标记,用以填充不足长度的部分,[“UNK”]为未知字符标记,遇到不在单词表中或者不需要的字符时,可使用该标记代替。假定单词表中,“今”的序号是200,“天”的序号是52,“是”的序号是18,“周”的序号是177,“末”不在单词表中,“!”的序号是6552,[“P”]标记的序号为0,[“S”]标记的序号为100,[“UNK”]标记的序号为1,设定最大长度为10,则“今天是周末!”这段文本的数字化表示为[100,200,52,18,177,1,6552,0,0,0]。
进一步的,为了网络层有效获取注意力,为每个文本建立一个掩盖序列。以上述“今天是周末!”为例,其掩盖序列长度为10,与数字化表示长度相同。掩盖序列只含0或1值,其中1代表序列对应位置非填充,0代表序列对应位置填充。“今天是周末!”的掩盖序列为[1,1,1,1,1,1,1,0,0,0],表示其对应的数字化表示的前7个为非填充部分,后3个为填充部分。
同样的,从0开始,为每个类别设立一个唯一序号,建立类别的one-hot表示。如,对于3个类别,正方形、三角形、圆形,设类别序号分别为0,1,2。其one-hot表示为[1,0,0],[0,1,0],[0,0,1]。
步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;
具体实施时,考虑到在大型数据集上,模型损失往往具有较大的波动性,将数据集分为训练样集合和验证集,使用训练集对模型进行训练,每迭代一定次数后,使用验证集对模型进行评估,以更准确评判模型的准确率等性能指标。
步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;
可选的,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块。
具体实施时,如图3所示,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块,所述前向网络的输入为特征工程所得到的文本数字化表示及掩盖序列,输出包括类别得分、分类结果、注意力、损失等。
可以初始化词嵌入表,词嵌入表行数与单词表长度相等,列数为词向量维度,词向量维度一般取64,128,256,512,768等,词向量维度越高,其潜在的表征能力越强,但对设备的要求越高。词嵌入表为所有词向量的集合,单词表中每一个字符都能在词嵌入表中找到一个向量与之对应,该向量为对应字符的分布式表示。
例如,所述前向网络中的参数可以使用He正态分布初始化器进行初始化。在本实施例中,本模型初始化默认为He初始化。
步骤5,根据所述数字化表示,得到表征文本的一组词向量;
进一步的,所述步骤5具体包括:
将所述数字化表示中的数字序号依次取出,通过序号查询词嵌入表,取出序号对应行数的向量,将取出的向量按序拼接成矩阵,并根据所述矩阵得到所述词向量。
具体实施时,若最大文本长度为10,查询词嵌入表,将序号对应的向量取出,一个独立的文档被表示为10个词向量堆叠的输入矩阵。如,对于[100,200,52,18,177,1,6552,0,0,0],网络将查询词嵌入表,分别取出表中第100行、200行、52行、18行、177行、1行、6552行、0行、0行、0行的向量。
特别说明,上述第0行是为了直观表示,“第0行“其实际指代的是表的第1行,同理,第1行实际指代的是表的第2行,以此类推。
步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;
进一步的,所述原始注意力矩阵Score,计算方法如下:
令Q为查询矩阵,K为键矩阵,V为值矩阵,n为类别数,l为文本最大长度,demb为词向量维度,则:
Q=(q1,q2,…,qn),K=(k1,k2,…,kl,),V=(v1,v2,…,vl,)
将Q矩阵和K的转置矩阵做矩阵乘法,并进行缩放,公式如下:
Scorei,j表示文本中第j个字符对第i个主题的贡献。
具体实施时,可以初始化矩阵WQ、WK、WV。通过WQ将主题向量映射,得到查询向量。
进一步的,通过WK、WV将输入矩阵中的向量进行线性映射,得到键向量ki、值向量vi。保证qi,ki维度相同。设i个位置的词向量为xi,则qi、ki、vi计算公式如下:
qi=WQxi,ki=WKxi,vi=WVxi
进一步的,计算各个查询向量和键向量之间的点积作为得分。
令Q=(q1,q2,…,qn)为查询矩阵(查询向量列表),K=(k1,k2,…,kl,)为键矩阵(键向量列表),V=(v1,v2,…,vl,)为值矩阵(值向量列表),n为类别数,l为文本最大长度,demb为词向量维度,则:
Q=(q1,q2,…,qn),K=(k1,k2,…,kl,),V=(v1,v2,…,vl,)
Q矩阵和K的转置矩阵做矩阵乘法,并进行缩放,公式如下:
Scorei,j表示文本中第j个字符对第i个主题的贡献。
步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;
在上述实施例的基础上,所述步骤7具体包括:
步骤7.1,将所述查询向量、键向量、值向量投影到低纬度上,计算每个头独立的注意力;
步骤7.2,根据所述掩盖序列和每个头独立的注意力掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵。
具体实施时,可以将查询向量、键向量、值向量投影到低纬度上,投影h次,公式如下:
每个头独立计算注意力,表达式如下。
然后对Score进行掩盖,使分类模型不关注填充部分,概率矩阵在填充部分取值趋近于0。。
将Scorei加上向量Mask完成掩盖。设Mexam为样本的掩盖序列,MaxInt为预先设置的极大值,One为值全为1且与Mexam维度相同的向量,Mask计算公式为:
Mask=(One-Mexam)*(-MaxInt)
Scorei更新公式为:
Scorei=Scorei+Mask
Score=(Score1,Score2……Scoren)。
步骤8,根据所述目标注意力矩阵,得到概率矩阵;
进一步的,所述概率矩阵的计算公式如下:
Probi=(Probi,1,Probi,2,……,Probi,l)
Prob=Softmax(Score)=(Prob1,Prob2,……,Probn)。
具体实施时,在得到所述目标注意力矩阵后,可以使用softmax函数将得分归一化,得到概率矩阵。公式如下:
Probi=(Probi,1,Probi,2,……,Probi,l)
Prob=Softmax(Score)=(Prob1,Prob2,……,Probn)。
步骤9,根据所述概率矩阵和值向量,计算注意力头输出;
进一步的,所述注意力头输出的计算公式如下:
具体实施时,在得到所述概率矩阵和所述值向量后,可以将概率矩阵与值向量列表矩阵相乘,得到注意力头输出,公式如下:
步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;
具体实施时,通过将不同头部的输出结果进行拼接,然后将拼接结果线性转化得到注意力输出,公式如下:
MultiHead(Q,K,V)=Concat(h ead1,h ead2,…,h eadh)WO。
步骤11,根据所述注意力输出,计算主题输出;
具体实施时,在得到所述注意力输出后,可以将注意力输出经过一个前馈网络增强其表征能力,得到主题输出。
步骤12,根据所述主题输出和主题向量,计算主题概率向量;
可选的,所述主题向量和查询向量之间、所述词向量和键向量之间、所述词向量和所述值向量之间,以及,所述注意力头输出和所述主题输出之间均设置有一个全连接层,主题输出和主题概率向量之间设置有多个全连接层。
进一步的,所述主题概率向量包含多个主题概率,其中,所述主题概率由主题向量和主题输出进行点积运算得出或者使用单节点全连接网络计算得出。
具体实施时,前馈网络为一个两层的全连接神经网络,第一层的激活函数为Relu,第二层不使用激活函数,网络表示如下:
FFN(x)=max(0,xW1+b1)W2+b2
进一步的,前馈网络层之后接入一个残差块,其一般表示为:
x=x+f(x)
进一步的,在残差结构基础上,加入层标准化(Layer Normalization)归一化输出。
进一步的,根据主题输出和主题向量,计算主题概率向量。
优选的,上述主题概率向量的计算是将对应主题输出和主题向量做点积运算。也可以在主题输出后通过一个只有1个神经元的全连接层,利用该层计算主题概率向量。需要说明的是,所述主题向量的数量的设定与类别总数n相同。若主题向量的个数与类别数目不同,使用一个单层全连接网络将主题概率向量映射到大小为n(类别数)的向量空间中。
步骤13,将所述数字标签转化为one-hot编码形式后根据所述主题概率向量,计算交叉熵损失;
可选的,所述步骤13之前,所述方法还包括:
使用softmax函数对所述主题概率向量进行归一化。
具体实施时,在得到所述数字标签和所述主题概率向量后,可以先将所述数字标签转化为one-hot编码形式,然后使用softmax函数对所述主题概率向量进行归一化,再计算主题概率向量与其对应标签的one-hot编码之间的交叉熵。交叉熵计算公式如下:
上式中,xi为第i个主题向量,P(xi)为文本属于第i类的实际概率。当文本具有确定的类别时,P(xi)=1,表示文本属于第i类,P(xi)=0,表示文本不属于第i类。Q(xi)为第i个主题向量经过前馈网络后输出的主题概率。设某样本共有3个类别,某文本属于第一个类别,其标签one-hot编码为[1,0,0],该文本经过网络计算后,得到的主题概率向量为[0.3,0.4,0.3],交叉熵计算过程如下:
cross entropy=-(1*log 0.3+0*log 0.4+0*log 0.4)=1.2。
然后设定损失函数,通过损失函数来评估模型的性能以及学习效果,不断调整参数,以此最小化损失。当损失越低,表明模型分类能力可能越好,然后并以交叉熵损失作为模型损失。
步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;
具体实施时,在得到所述交叉熵损失后,可以以交叉熵损失作为模型损失,进一步的,为最小化模型损失,然后根据链式法则计算可变参数梯度。
步骤15,根据所述梯度,更新网络变量;
具体实施时,可以设定初始学习率为0.0001,计算可变参数更新量,更新原有网络参数。
步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;
具体实施时,可以从训练样本中选取一定数目样本,送入前向网络中进行前向传播和反向传播,不断计算交叉熵损失以及梯度,更新网络变量,从而不断迭代。需要说明的是,可变参数更新量采用小批量随机梯度算法进行计算。根据使用设备的性能自行选择批大小。例如可以通过Adam算法对梯度进行优化。
步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。
具体实施时,直到模型收敛或达到一定迭代要求(如在训练数据集上完成20轮训练),结束迭代,完成训练,得到训练好的分类模型。
例如,考虑到在大型数据集上,模型损失往往具有较大的波动性,将数据集分为训练样集合和验证集,使用训练集对模型进行训练,每迭代一定次数后,使用验证集对模型进行评估,以更准确评判模型的准确率等性能指标。模型输出包括多头注意力模块中的概率矩阵。使用本方法训练模型,当训练次数达到29轮次时。使用一个体育类新闻进行测试,直接提取概率矩阵。如图4所示,可以看到模型分类时关注的重点字词。
本实施例提供的基于主题注意力的深度学习文本分类模型训练方法,通过以神经网络和多头注意力为核心,通过设置主题向量,一定程度上克服了深度学习的黑箱问题,克服了注意力模型在进行长文本分类时的不适应性,提高了模型在并行性、稳定性、可视性和准确率。
应当理解,本公开的各部分可以用硬件、软件、固件或它们的组合来实现。
以上所述,仅为本公开的具体实施方式,但本公开的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本公开揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本公开的保护范围之内。因此,本公开的保护范围应以权利要求的保护范围为准。
Claims (10)
1.一种基于主题注意力的深度学习文本分类模型训练方法,其特征在于,包括:
步骤1,获取原始文本,并根据所述原始文本构建文本数据集;
步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;
步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;
步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;
步骤5,根据所述数字化表示,得到表征文本的一组词向量;
步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;
步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;
步骤8,根据所述目标注意力矩阵,得到概率矩阵;
步骤9,根据所述概率矩阵和值向量,计算注意力头输出;
步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;
步骤11,根据所述注意力输出,计算主题输出;
步骤12,根据所述主题输出和主题向量,计算主题概率向量;
步骤13,将所述数字标签转化为one-hot编码形式后根据所述主题概率向量,计算交叉熵损失;
步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;
步骤15,根据所述梯度,更新网络变量;
步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;
步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。
2.根据权利要求1所述的方法,其特征在于,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块。
3.根据权利要求1所述的方法,其特征在于,所述主题向量和查询向量之间、所述词向量和键向量之间、所述词向量和所述值向量之间,以及,所述注意力头输出和所述主题输出之间均设置有一个全连接层,主题输出和主题概率向量之间设置有多个全连接层。
4.根据权利要求1所述的方法,其特征在于,所述步骤5具体包括:
将所述数字化表示中的数字序号依次取出,通过序号查询词嵌入表,取出序号对应行数的向量,将取出的向量按序拼接成矩阵,并根据所述矩阵得到所述词向量。
6.根据权利要求5所述的方法,其特征在于,所述步骤7具体包括:
步骤7.1,将所述查询向量、键向量、值向量投影到低纬度上,计算每个头独立的注意力;
步骤7.2,根据所述掩盖序列和每个头独立的注意力掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵。
9.根据权利要求1所述的方法,其特征在于,所述主题概率向量包含多个主题概率,其中,所述主题概率由主题向量和主题输出进行点积运算得出或者使用单节点全连接网络计算得出。
10.根据权利要求1所述的方法,其特征在于,所述步骤13之前,所述方法还包括:
使用softmax函数对所述主题概率向量进行归一化。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210312063.5A CN114925197B (zh) | 2022-03-28 | 2022-03-28 | 基于主题注意力的深度学习文本分类模型训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210312063.5A CN114925197B (zh) | 2022-03-28 | 2022-03-28 | 基于主题注意力的深度学习文本分类模型训练方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114925197A true CN114925197A (zh) | 2022-08-19 |
CN114925197B CN114925197B (zh) | 2024-06-11 |
Family
ID=82805083
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210312063.5A Active CN114925197B (zh) | 2022-03-28 | 2022-03-28 | 基于主题注意力的深度学习文本分类模型训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114925197B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115563508A (zh) * | 2022-11-08 | 2023-01-03 | 北京百度网讯科技有限公司 | 模型训练方法、装置以及设备 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020140633A1 (zh) * | 2019-01-04 | 2020-07-09 | 平安科技(深圳)有限公司 | 文本主题提取方法、装置、电子设备及存储介质 |
US20200249918A1 (en) * | 2019-02-02 | 2020-08-06 | Microsoft Technology Licensing, Llc. | Deep learning enhanced code completion system |
US20200356851A1 (en) * | 2019-05-10 | 2020-11-12 | Baidu Usa Llc | Systems and methods for large scale semantic indexing with deep level-wise extreme multi-label learning |
CN112231485A (zh) * | 2020-12-14 | 2021-01-15 | 平安科技(深圳)有限公司 | 文本推荐方法、装置、计算机设备及存储介质 |
WO2021179570A1 (zh) * | 2020-03-13 | 2021-09-16 | 平安科技(深圳)有限公司 | 序列标注方法、装置、计算机设备和存储介质 |
-
2022
- 2022-03-28 CN CN202210312063.5A patent/CN114925197B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020140633A1 (zh) * | 2019-01-04 | 2020-07-09 | 平安科技(深圳)有限公司 | 文本主题提取方法、装置、电子设备及存储介质 |
US20200249918A1 (en) * | 2019-02-02 | 2020-08-06 | Microsoft Technology Licensing, Llc. | Deep learning enhanced code completion system |
US20200356851A1 (en) * | 2019-05-10 | 2020-11-12 | Baidu Usa Llc | Systems and methods for large scale semantic indexing with deep level-wise extreme multi-label learning |
WO2021179570A1 (zh) * | 2020-03-13 | 2021-09-16 | 平安科技(深圳)有限公司 | 序列标注方法、装置、计算机设备和存储介质 |
CN112231485A (zh) * | 2020-12-14 | 2021-01-15 | 平安科技(深圳)有限公司 | 文本推荐方法、装置、计算机设备及存储介质 |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115563508A (zh) * | 2022-11-08 | 2023-01-03 | 北京百度网讯科技有限公司 | 模型训练方法、装置以及设备 |
Also Published As
Publication number | Publication date |
---|---|
CN114925197B (zh) | 2024-06-11 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111694924B (zh) | 一种事件抽取方法和系统 | |
CN110347835B (zh) | 文本聚类方法、电子装置及存储介质 | |
CN110413785B (zh) | 一种基于bert和特征融合的文本自动分类方法 | |
CN106502985B (zh) | 一种用于生成标题的神经网络建模方法及装置 | |
CN115393692A (zh) | 基于生成式预训练语言模型的联想文本到图像生成方法 | |
CN111143563A (zh) | 基于bert与lstm及cnn融合的文本分类方法 | |
CN112732864B (zh) | 一种基于稠密伪查询向量表示的文档检索方法 | |
CN111626041B (zh) | 一种基于深度学习的音乐评论生成方法 | |
CN112306494A (zh) | 一种基于卷积和循环神经网络的代码分类及聚类方法 | |
CN112528634A (zh) | 文本纠错模型训练、识别方法、装置、设备及存储介质 | |
CN113946677B (zh) | 基于双向循环神经网络和注意力机制的事件识别分类方法 | |
CN112528643A (zh) | 一种基于神经网络的文本信息提取方法及装置 | |
CN110866169B (zh) | 一种基于学习的物联网实体消息解析方法 | |
CN114925197B (zh) | 基于主题注意力的深度学习文本分类模型训练方法 | |
CN117634459A (zh) | 目标内容生成及模型训练方法、装置、系统、设备及介质 | |
CN117391079A (zh) | 一种推理文本生成大模型的方法 | |
CN111507101A (zh) | 一种基于多层次语义胶囊路由的反讽检测方法 | |
CN114662659B (zh) | 一种基于多阶段迁移学习策略综合的众包文本集成方法 | |
CN115840815A (zh) | 基于指针关键信息的自动摘要生成方法 | |
CN113849641B (zh) | 一种跨领域层次关系的知识蒸馏方法和系统 | |
CN113901820A (zh) | 一种基于bert模型的中文三元组抽取方法 | |
CN111859924B (zh) | 一种基于word2vec模型构建词网的方法和装置 | |
CN112256838B (zh) | 相似域名查找方法、装置及电子设备 | |
CN117113977B (zh) | 一种识别试卷中包含ai生成文字的方法、介质及系统 | |
CN116227428B (zh) | 一种基于迁移模式感知的文本风格迁移方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |