CN112100387B - 用于文本分类的神经网络系统的训练方法及装置 - Google Patents
用于文本分类的神经网络系统的训练方法及装置 Download PDFInfo
- Publication number
- CN112100387B CN112100387B CN202011269071.3A CN202011269071A CN112100387B CN 112100387 B CN112100387 B CN 112100387B CN 202011269071 A CN202011269071 A CN 202011269071A CN 112100387 B CN112100387 B CN 112100387B
- Authority
- CN
- China
- Prior art keywords
- training
- loss
- classification
- feature extraction
- vectors
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 213
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 119
- 238000000034 method Methods 0.000 title claims abstract description 96
- 239000013598 vector Substances 0.000 claims abstract description 321
- 238000000605 extraction Methods 0.000 claims abstract description 152
- 238000012545 processing Methods 0.000 claims abstract description 62
- 238000012512 characterization method Methods 0.000 claims abstract description 30
- 239000011159 matrix material Substances 0.000 claims description 123
- 238000011176 pooling Methods 0.000 claims description 84
- 230000000875 corresponding effect Effects 0.000 claims description 55
- 238000012795 verification Methods 0.000 claims description 41
- 230000008569 process Effects 0.000 claims description 38
- 230000002596 correlated effect Effects 0.000 claims description 16
- 230000004927 fusion Effects 0.000 claims description 14
- 238000007499 fusion processing Methods 0.000 claims description 14
- 238000004590 computer program Methods 0.000 claims description 4
- 230000006870 function Effects 0.000 description 11
- 238000010586 diagram Methods 0.000 description 9
- 238000013145 classification model Methods 0.000 description 7
- 238000013507 mapping Methods 0.000 description 5
- 230000008447 perception Effects 0.000 description 5
- 238000013527 convolutional neural network Methods 0.000 description 4
- 238000002372 labelling Methods 0.000 description 4
- 238000012360 testing method Methods 0.000 description 4
- 239000002131 composite material Substances 0.000 description 3
- 230000002708 enhancing effect Effects 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 208000025174 PANDAS Diseases 0.000 description 1
- 208000021155 Paediatric autoimmune neuropsychiatric disorders associated with streptococcal infection Diseases 0.000 description 1
- 240000000220 Panda oleosa Species 0.000 description 1
- 235000016496 Panda oleosa Nutrition 0.000 description 1
- 241000282376 Panthera tigris Species 0.000 description 1
- 238000007796 conventional method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Biophysics (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本说明书实施例提供一种用于文本分类的神经网络系统的训练方法,该神经网络系统包括文本表征网络、特征提取层和分类网络。该训练方法包括:首先,获取训练文本集,该训练文本集对应K个类别;接着,针对该训练文本集中任一的第一训练文本,利用上述文本表征网络对其进行处理,得到第一文本向量;然后,利用上述特征提取层,将该第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;再接着,基于该K个特征提取向量和上述分类网络,确定分类预测结果;再然后,基于该分类预测结果和上述第一训练文本的类别标签,训练上述神经网络系统。
Description
技术领域
本说明书一个或多个实施例自然语言处理技术领域,尤其涉及一种用于文本分类的神经网络系统的训练方法及装置,一种用于样本分类的神经网络系统的训练方法及装置。
背景技术
机器学习已经成为当下研究的热点,各行各业纷纷将机器学习技术应用于本行业的业务处理。比如说,在文本处理领域,通过构建文本分类模型,进行文本的分类处理,具体如判别新闻稿属于娱乐新闻还是社会新闻。又比如,在图像识别领域,利用目标识别模型,识别图像中包含的目标,如熊猫、老虎等。
然而,出于多种原因,导致采用机器学习模型得到的业务处理结果准确度有限,难以满足实际应用需求。比如说,在监督学习的场景下,模型的效果依赖于训练样本的质量和数量,但实际往往难以获取到足够的高质样本;又比如,目前通常全盘套用比较成熟的模型结构,使得预测性能受限。
因此,需要一种方案,可以有效提高模型预测的准确度,包括文本分类的准确度。
发明内容
本说明书一个或多个实施例描述了种用于文本分类的神经网络系统的训练方法及装置,可以有效提高文本分类的准确度。
根据第一方面,提供一种用于文本分类的神经网络系统的训练方法,所述神经网络系统包括文本表征网络、特征提取层和分类网络,所述方法包括:获取训练文本集,该训练文本集对应K个类别;针对所述训练文本集中任一的第一训练文本,利用所述文本表征网络对其进行处理,得到第一文本向量;利用所述特征提取层,将所述第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;基于所述K个特征提取向量和所述分类网络,确定分类预测结果;基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统。
在一个实施例中,所述神经网络系统还包括特征池化层;其中,在确定分类预测结果之前,所述方法还包括:利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,确定分类预测结果,包括:基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果。
在一个实施例中,所述分类网络包括第一全连接层、第二全连接层和输出层;基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果,包括:将所述K个特征提取向量输入所述第一全连接层,得到第一处理向量;将所述特征池化向量输入所述第二全连接层,得到第二处理向量;对所述第一处理向量和第二处理向量进行融合处理,得到融合向量;将所述融合向量输入所述输出层,得到所述分类预测结果。
在一个具体的实施例中,所述融合处理包括加和处理、对位相乘处理,或拼接处理。
在一个实施例中,基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统,包括:基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;基于所述第一损失和第二损失,训练所述神经网络系统。
在一个具体的实施例中,在确定第二损失之前,所述方法还包括:获取多个验证文本,并利用所述神经网络系统,确定该多个验证文本对应的多个分类结果;基于所述多个分类结果和各个验证文本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示所述多个验证文本中,第i个类别的文本被错误分类为第j个类别的文本数量;其中,基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,包括:基于所述相似度和所述混淆方阵,确定所述第二损失,该第二损失还与所述混淆方阵中的非对角线元素正相关。
在一个更具体的实施例中,上述基于所述相似度和所述混淆方阵,确定所述第二损失,包括:确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;基于所述相似度方阵和所述混淆方阵,确定所述第二损失。
在一个例子中,上述基于所述相似度方阵和所述混淆方阵,确定所述第二损失,包括:将所述混淆方阵中的对角线元素置零,得到去对角化方阵;对所述去对角化方阵和所述相似度方阵进行对位相乘处理,得到对位相乘方阵;基于所述对位相乘方阵,确定所述第二损失。
在一个具体的例子中,上述基于所述对位相乘方阵,确定所述第二损失,包括:将所述对位相乘方阵中元素的平均值,确定为所述第二损失。
在一个实施例中,所述神经网络系统还包括特征池化层;其中,在基于所述K个特征提取向量和所述分类网络,确定第一分类结果之前,所述方法还包括:利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统,包括:基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;基于所述特征池化向量与所述K个特征提取向量中任一向量之间的相似度,确定第三损失;基于所述第一损失和第三损失,训练所述神经网络系统。
根据第二方面,提供一种用于样本分类的神经网络系统的训练方法,所述神经网络系统包括样本表征网络、特征提取层和分类网络,所述方法包括:获取训练样本集,该训练样本集对应K个类别;针对所述训练样本集中任一的第一训练样本,利用所述样本表征网络对其进行处理,得到第一样本向量;利用所述特征提取层,将所述第一样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;基于所述K个特征提取向量和所述分类网络,确定分类预测结果;基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统。
在一个实施例中,所述第一训练样本属于文本或图片或音频,所述第一训练样本涉及的业务对象为用户、商户、商品或事件。
根据第三方面,提供一种神经网络系统,用于预测K个类别的文本,所述神经网络系统包括:输入层,用于获取目标文本;文本表征网络,用于对所述目标文本进行处理,得到目标文本向量;特征提取层,用于将所述目标文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类网络,用于利用所述K个特征提取向量,确定分类预测结果。
根据第四方面,提供一种神经网络系统,用于预测K个类别的样本,所述神经网络系统包括:输入层,用于获取目标样本;样本表征网络,用于对所述目标样本进行处理,得到目标样本向量;特征提取层,用于将所述目标样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类网络,用于利用所述K个特征提取向量,确定分类预测结果。
根据第五方面,提供一种用于文本分类的神经网络系统的训练装置,所述神经网络系统包括文本表征网络、特征提取层和分类网络,所述装置包括:文本获取单元,配置为获取训练文本集,该训练文本集对应K个类别;文本表征单元,配置为针对所述训练文本集中任一的第一训练文本,利用所述文本表征网络对其进行处理,得到第一文本向量;特征提取单元,配置为利用所述特征提取层,将所述第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类预测单元,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;训练单元,配置为基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统。
根据第六方面,提供一种用于样本分类的神经网络系统的训练装置,所述神经网络系统包括样本表征网络、特征提取层和分类网络,所述装置包括:样本获取单元,配置为获取训练样本集,该训练样本集对应K个类别;样本表征单元,配置为针对所述训练样本集中任一的第一训练样本,利用所述样本表征网络对其进行处理,得到第一样本向量;特征提取单元,配置为利用所述特征提取层,将所述第一样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类预测单元,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;训练单元,配置为基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统。
根据第七方面,提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行上述第一方面或第二方面的方法。
根据第八方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现上述第一方面或第二方面的方法。
综上,采用本说明书实施例提供的上述训练方法及装置,通过引入表征类别私有特征的类别特征向量,提高文本分类模型对类别之间的差异性的感知,从而提高分类结果的准确度,同时,类别特征向量显式地建模了类别的私有特征,增强了模型的可解释性;进一步,还引入表征类别公有特征的特征池化向量,在训练过程中可以通过对私有特征和公有特征进行区分,来降低数据噪声,提高模型性能,同理,特征池化向量显式地建模了类别公有特征,增强了模型的可解释性;此外,还可以通过在模型训练的损失函数中显式地引入基于验证文本集确定的混淆矩阵,有效缓解训练集和验证集数据分布不一致的问题,提升模型在测试集上的表现。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1示出传统的文本分类流程;
图2示出根据一个实施例的文本分类实施架构图;
图3示出根据一个实施例的用于文本分类的神经网络系统的训练方法流程示意图;
图4示出根据一个实施例的用于样本分类的神经网络系统的训练方法流程示意图;
图5示出根据一个实施例的用于文本分类的神经网络系统的训练装置结构示意图;
图6示出根据一个实施例的用于样本分类的神经网络系统的训练装置结构示意图;
图7示出根据一个实施例的神经网络系统的结构图示;
图8示出根据另一个实施例的神经网络系统的结构图示。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
如前所述,目前文本分类的准确度有限。在传统方式中,如图1所示,输入文本x首先经过映射函数f1映射为固定长度的向量z,然后通过映射函数f2映射成对应的类别表示y。对此,发明人发现,文本分类模型本身缺乏对类别之间差异的明显感知,导致训练后模型的准确率指标的提升,也因此而收到一定限制。
进而,发明人提出一种用于文本分类的神经网络系统的训练方法,图2示出根据一个实施例的文本分类实施架构图。如图2所示,相较于图1示出的传统方式,在得到固定长度的向量z后,不是直接通过f2映射成对应的类别表示y,而是将向量z扩张成与类别特征矩阵C维度相同的文本特征矩阵Z,其中类别特征矩阵包括各个类别(共k个)对应的特征向量(如c1等),用于提取向量z中分属各个类别的特征表示;然后,对类别特征矩阵C和文本特征矩阵Z进行对位相乘处理,得到特征提取矩阵G,再根据映射函数f′确定类别表示y。如此,可以提高文本分类模型对类别之间的差异性的感知,从而提高分类结果的准确度,同时,类别特征矩阵C显式地建模了类别间私有特征,具有良好地模型可解释性。
下面,结合具体的实施例,对上述训练方法的实施步骤进行介绍。
图3示出根据一个实施例的用于文本分类的神经网络系统的训练方法流程示意图,其中神经网络系统包括文本表征网络、特征提取层和分类网络,所述方法的执行主体可以为任意的具有计算、处理能力的平台、装置、服务器或设备集群。如图3所示,所述方法包括以下步骤:
步骤S310,获取训练文本集,该训练文本集对应K个类别;步骤S320,针对上述训练文本集中任一的第一训练文本,利用上述文本表征网络对其进行处理,得到第一文本向量;步骤S330,利用上述特征提取层,将上述第一文本向量与对应上述K个类别的K个类别特征向量分别与进行组合操作,得到K个特征提取向量;步骤S340,基于上述K个特征提取向量和上述分类网络,确定分类预测结果;步骤S350,基于上述分类预测结果和上述第一训练文本的类别标签,训练上述神经网络系统。
针对以上步骤,首先需要说明,上述“第一训练文本”、“第一文本向量”中的“第一”,以及后文“第二”等类似用语,均用于区分同类事物,不具有排序等其他限定作用。
以上步骤具体如下:
首先,在步骤S310,获取训练文本集,该训练文本集对应K个类别。其中K为大于1的正整数。
许多领域或场景涉及文本分类,比如说,在客服服务领域,需要根据用户的咨询文本,确定用户的咨询意图,相应在一个实施例中,训练文本集中的任一训练文本(以下简称第一训练文本)可以是用户会话文本,或用户与客服之间的多轮对话文本,相应地,K个类别可以是用户意图类别,例如,“如何使用绑定银行卡”、“如何撤销投诉”、“如何开通新业务”等。又比如说,在内容推荐领域,需要对待展示的内容文本进行分类,以将其展示在对应类别的界面板块或APP中,相应在一个实施例中,第一训练文本可以是板块新闻、论坛文章等,相应地,K个类别可以包括社会新闻、娱乐新闻、科技新闻、生活趣事,等等。
基于以上获取的训练文本集,在步骤S320,针对其中任一的第一训练文本,利用上述文本表征网络对其进行处理,得到第一文本向量。在一个实施例中,文本表征网络中可以包括嵌入层,嵌入层用于对文本分词进行词嵌入处理,得到对应的词向量。目前已经有许多词嵌入算法实现基于大量文本语料训练词向量,由此,嵌入层可以直接通过查阅训练好的词向量表,确定第一训练文本所包含分词对应的词向量。在一个实施例中,文本表征网络中还可以包括其他神经网络层或表征模型,用于对嵌入层输出的词向量做进一步处理,从而得到上述第一文本向量。在一个具体的实施例中,其中其他神经网络层可以是卷积神经网络CNN层,或循环神经网络RNN层,表征模型可以是Bert模型。
以上,对第一训练文本进行表征,得到第一文本向量。然后,在步骤S330,利用上述特征提取层,将上述第一文本向量与对应上述K个类别的K个类别特征向量分别与进行组合操作,得到K个特征提取向量。
需理解,K个类别特征向量中的向量元素会在训练过程中被不断调整。而对于K个类别特征向量的初始化,在一个实施例中,可以将向量元素初始化为预定数值,如每个元素均为1。在另一个实施例中,可以通过随机初始化而得到。
上述各个类别特征向量用于表征对应类别的私有特征,用于提取文本向量中包含的对应类别特征。并且,类别特征向量的设计,实现了对各个类别的私有特征的显式建模,可以增强模型的可解释性。
在一个实施例中,上述组合操作可以为对位相乘或相加。在一个实施例中,可以以矩阵形式实现组合操作,实现步骤S330。具体,K个类别特征向量具有相同的维度P,文本向量的维度同样为P,由此对文本向量进行复制,将K个第一文本向量组成K*P维的文本特征矩阵,进一步,利用其与K个类别特征向量组成的K*P维的类别特征矩阵,进行组合操作,如此得到的K*P维矩阵中的K行,对应上述K个特征提取向量。另一方面,在一个实施例中,本步骤可以包括:先对K个类别特征向量中各个向量元素进行归一化映射处理,映射为[0,1]中的数值,再用于第一文本向量的特征提取。例如,可以利用sigmoid函数实现。如此,可以降低后续计算的数量级。
以上,可以得到对应K个类别的K个特征提取向量,然后,在步骤S340,基于上述K个特征提取向量和上述分类网络,确定分类预测结果。
在一种实施方式中,分类网络中包括第一全连接层和输出层,其中第一全连接层可以是一个或多个全连接层。在一个实施例中,对K个特征提取向量进行拼接,然后将拼接得到的维数为K*P的向量输入第一全连接层,再将第一全连接层输出的K维第一处理向量输入输出层,由输出层输出分类预测结果。在另一个实施例中,将K个特征提取向量分别输入第一全连接层,对应得到的K个数值,再将该K个数值(相当于K维向量)共同输入输出层,由输出层输出分类预测结果。在一个实施例中,输出层中利用softmax函数对第一处理向量进行运算,得到分类预测结果。在一个实施例中,分类预测结果包括第一训练文本被分类至K个类别的K个概率。在另一个实施例中,分类预测结果中包括该K个概率中最大概率所对应的类别。
在另一种实施方式中,发明人还提出,在分类过程中,除了考虑类别私有特征,还可以考虑类别的公共特征,不同类别之间可能具有共性,通过提取类别的公共特征,也有助于降低数据噪声,提升分类结果的准确性。基于此,上述神经网络系统中还可以包括特征池化层,用于对K个特征提取向量进行池化,得到特征池化向量,相应地,本步骤中可以利用该特征池化向量、上述K个特征提取向量和分类网络,得到上述分类预测结果。同时,特征池化向量的设计,还实现了对各个类别的公有特征的显式建模,可以增强模型的可解释性。
在一个实施例中,其中池化的方式可以是最大池化或平均池化,其中最大池化是指,对于K个特征提取向量中相同位置的元素,选取其中的最大值作为特征池化向量中对应位置的元素;其中平均池化是指,对于K个特征提取向量中相同位置的元素,求取平均值作为特征池化向量中对应位置的元素。
在一个实施例中,得到上述分类预测结果,可以包括:对特征池化向量和上述K个特征提取向量进行拼接处理,输入分类网络。在另一个实施例中,分类网络中包括第一全连接层、第二全连接层和输出层,其中对第一全连接层的介绍可以参见前述相关内容,利用第一全连接层对K个特征提取向量进行处理,可以得到K维的第一处理向量;其中第二全连接层对特征池化向量进行处理,可以得到K维的第二处理向量。进一步,将第一处理向量和第二处理向量进行融合处理,得到融合向量,在将融合向量输入输出层,得到上述分类预测结果。在一个具体的实施例中,其中融合处理可以包括加和处理、求平均处理、对位相乘处理,或拼接处理。
以上,可以确定出针对第一训练文本的分类预测结果。然后,在步骤S350,基于上述分类预测结果和第一训练文本的类别标签,训练上述神经网络系统。需理解,上述K个类别特征向量是特征提取层中的参数,训练神经网络系统的过程中,该K个类别特征向量中的元素相应得到调整。
在一个实施例中,本步骤可以包括:基于上述分类预测结果和第一训练文本的类别标签,确定第一损失,再基于该第一损失,训练上述神经网络系统。在一个具体的实施例中,其中第一损失可以是交叉熵(cross entropy)损失或铰链损失等。在一个具体的实施例种,可以基于第一损失,利用已有的反向传播法调整神经网络系统的参数。根据一个例子,第一损失的计算可以基于以下公式实现:
在一个实施例中,为了使上述对应K个类别的K个特征提取向量各自更多反映对应类别的私有特征,同时,减少分类过程中对不同类别的混淆。可以基于K个特征提取向量中任意两个向量之间的相似度,引入第二损失,进而基于上述第一损失和该第二损失,训练上述神经网络系统。在一个具体的实施例中,向量间的相似度可以通过计算点乘、或余弦相似度、或欧式距离等实现。在一个具体的实施例中,通过计算任意两个向量之间的相似度,共可以得到个相似度,然后求取这些相似度的平均值,作为第二损失。
进一步,在一个实施例中,发明人考虑到,目前因为标注费用高昂,不同标注人员打标时存在差异,导致获取标注的文本数据的数据量较小,由其划分而来的训练文本集和验证集的分类很可能存在差异,导致训练后的模型在测试集上的效果与训练集上不一致的问题较为严重,也就影响了模型的预测准确度。因此,发明人提出,还可以引入混淆矩阵(或称困惑矩阵),来计算上述用于减少分类过程中对不同类别混淆的第二损失。具体,在确定第二损失之前,上述方法还可以包括:先获取多个验证文本,并利用上述神经网络系统,确定该多个验证文本对应的多个分类结果;然后,基于该多个分类结果和各个验证文本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示该多个验证文本中,第i个类别的文本被错误分类为第j个类别的文本数量。在一个具体的实施例中,包含多个验证文本的验证文本集和上述训练文本集,可以基于标注文本总集进行划分而得到,具体的划分方式可以采用已有技术实现,不作赘述。在一个具体的实施例中,多个验证文本所对应的多个分类结果的确定,与上述第一训练文本所对应分类预测结果的确定方式一致,故不作赘述。
在以上确定混淆方阵后,可以基于该混淆方阵和上述计算出的相似度,确定第二损失,并且,该第二损失还与混淆方阵中的非对角线元素正相关。这是因为,混淆方阵中的对角线元素是指某个类别的文本被分类为该某个类别的文本的数量,而分类正确不能被认为是错误分类。在一个具体的实施例中,在将混淆方阵用于计算第二损失之前,还可以对其进行归一化处理。在一个具体的实施例中,可以求取上述个相似度和混淆方阵中个非对角线元素的平均值,作为上述第二损失。
另一方面,在一个具体的实施例中,确定第二损失可以包括:确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;然后,基于该相似度方阵和上述混淆方阵,确定第二损失。
在一个更具体的实施例中,先将该混淆方阵中的对角线元素置零,得到去对角化方阵,再对该去对角化方阵和相似度方阵进行对位相乘处理,得到对位相乘方阵,进而基于该对位相乘方阵,确定上述第二损失。在一个例子中,可以将对位相乘方阵中元素的平均值,确定为第二损失。在另一个例子中,可以将对位相乘方阵中元素和值的算术平方根,确定为第二损失。
在另一个更具体地实施例中,可以计算相似度方阵和混淆方阵之间的和矩阵,再求取该和矩阵中非对角元素的平均值,作为第二损失。
以上,可以确定第二损失,进而根据第一损失和第二损失,确定综合损失,并基于该综合损失,训练上述神经网络系统。在一个例子中,该综合损失可以表达为:
在另一个实施例中,发明人考虑到,还可以将类别间的公有特征和类别私有特征进行区分,以同时实现对公有特征和类别私有特征的降噪处理,进而提高文本分类准确度。具体,通过计算上述反映类别间公有特征的特征池化向量与上述K个特征提取向量中各向量之间的相似度,引入第三损失,进而基于上述第一损失和第三损失,训练上述神经网络系统。在一个具体的实施例中,对于计算出的K个相似度,可以求取其平均值,作为第三损失。
如此,可以确定第三损失,进而根据第一损失和第三损失,确定综合损失,并基于该综合损失,训练上述神经网络系统。在一个例子中,该综合损失可以表达为:
根据一个具体的例子,还可以根据上述第一损失、第二损失和第三损失,确定综合损失,再基于该综合损失,训练上述神经网络系统。需理解,第一损失、第二损失和第三损失,均与该综合损失正相关。在一个例子中,该综合损失可以表达为:
以上,可以实现对神经网络系统的训练。
综上,采用本说明书实施例披露的用于文本分类的神经网络系统的训练方法,通过引入表征类别私有特征的类别特征向量,提高文本分类模型对类别之间的差异性的感知,从而提高分类结果的准确度,同时,类别特征向量显式地建模了类别的私有特征,增强了模型的可解释性;进一步,还引入表征类别公有特征的特征池化向量,在训练过程中可以通过对私有特征和公有特征进行区分,来降低数据噪声,提高模型性能,同理,特征池化向量显式地建模了类别公有特征,增强了模型的可解释性;此外,还可以通过在模型训练的损失函数中显式地引入基于验证文本集确定的混淆矩阵,有效缓解训练集和验证集数据分布不一致的问题,提升模型在测试集上的表现。
以上,主要介绍一种用于文本分类的神经网络系统的训练方法。实际上,上述训练方法还可以应用于文本分类以外的其他领域,如图片目标识别,用户分类、事件(如登录事件、访问事件、交易事件等)分类。基于此,本说明书实施例还披露一种用于样本分类的神经网络系统的训练方法,该神经网络系统包括样本表征网络、特征提取层和分类网络,所述方法的执行主体可以为任何具有计算、处理能力的装置、服务器和设备集群。如图4所示,所述方法包括以下步骤:
步骤S410,获取训练样本集,该训练样本集对应K个类别;步骤S420,针对上述训练样本集中任一的第一训练样本,利用上述样本表征网络对其进行处理,得到第一样本向量;步骤S430,利用上述特征提取层,将上述第一样本向量分别与对应上述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;步骤S440,基于上述K个特征提取向量和上述分类网络,确定分类预测结果;步骤S450,基于上述分类预测结果和上述第一训练样本的类别标签,训练上述神经网络系统。
以上步骤如下:
首先,在步骤S410,获取训练样本集,该训练样本集对应K个类别。在一个实施例中,训练样本涉及的业务对象可以为用户、商品或事件。在一个具体的实施例中,训练样本可以是用户样本,对应的用户类别可以是风险类别(如高风险、低风险等)或人群类别(如低消费人群、高消费人群等)。在另一个具体的实施例中,训练样本可以是商品样本,对应的商品类别可以是兴趣等级类别(如热销商品和滞销商品等)。在又一个具体的实施例中,训练样本可以是事件样本,对应的事件类别可以是风险类别(如高风险、低风险等)。另一方面,在一个实施例中,上述训练样本可以属于样本或图片或音频。
以上,可以获取训练样本集,接着,在步骤S420,针对上述训练样本集中任一的第一训练样本,利用上述样本表征网络对其进行处理,得到第一样本向量。在一个实施例中,其中样本表征网络可以基于深度神经网络DNN或卷积神经网络CNN实现。
然后,在步骤S430,利用上述特征提取层,将第一样本向量分别与对应上述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量。在一个实施例中,其中组合操作可以是对位相乘。在另一个实施例中,其中组合操作可以是相加操作。
接着,在步骤S440,基于上述K个特征提取向量和上述分类网络,确定分类预测结果。
在一个实施例中,上述神经网络系统还包括特征池化层;其中,在确定分类预测结果之前,上述方法还包括:利用上述特征池化层对上述K个特征提取向量进行池化,得到特征池化向量;其中,确定分类预测结果,包括:基于上述K个特征提取向量、上述特征池化向量和上述分类网络,确定分类预测结果。
在一个具体的实施例中,上述分类网络包括第一全连接层、第二全连接层和输出层;基于上述K个特征提取向量、上述特征池化向量和上述分类网络,确定分类预测结果,包括:将上述K个特征提取向量输入上述第一全连接层,得到第一处理向量;将上述特征池化向量输入上述第二全连接层,得到第二处理向量;对上述第一处理向量和第二处理向量进行融合处理,得到融合向量;将上述融合向量输入上述输出层,得到上述分类预测结果。
在一个更具体的实施例中,上述融合处理包括加和处理、对位相乘处理,或拼接处理。
再然后,在步骤S450,基于上述分类预测结果和上述第一训练样本的类别标签,训练上述神经网络系统。
在一个实施例中,基于上述分类预测结果和上述第一训练样本的类别标签,训练上述神经网络系统,包括:基于上述分类预测结果和上述第一训练样本的类别标签,确定第一损失;基于上述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与上述相似度正相关;基于上述第一损失和第二损失,训练上述神经网络系统。
在一个具体的实施例中,在确定第二损失之前,上述方法还包括:获取多个验证样本,并利用上述神经网络系统,确定该多个验证样本对应的多个分类结果;基于上述多个分类结果和各个验证样本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示上述多个验证样本中,第i个类别的样本被错误分类为第j个类别的样本数量;其中,基于上述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,包括:基于上述相似度和上述混淆方阵,确定上述第二损失,该第二损失还与上述混淆方阵中的非对角线元素正相关。
在一个更具体的实施例中,基于上述相似度和上述混淆方阵,确定上述第二损失,包括:确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;基于上述相似度方阵和上述混淆方阵,确定上述第二损失。进一步,在一个例子中,基于上述相似度方阵和上述混淆方阵,确定上述第二损失,包括:将上述混淆方阵中的对角线元素置零,得到去对角化方阵;对上述去对角化方阵和上述相似度方阵进行对位相乘处理,得到对位相乘方阵;基于上述对位相乘方阵,确定上述第二损失。更进一步,在一个具体的例子中,基于上述对位相乘方阵,确定上述第二损失,包括:将上述对位相乘方阵中元素的平均值,确定为上述第二损失。
在一个实施例中,上述神经网络系统还包括特征池化层;其中,在基于上述K个特征提取向量和上述分类网络,确定第一分类结果之前,上述方法还包括:利用上述特征池化层对上述K个特征提取向量进行池化,得到特征池化向量;其中,基于上述分类预测结果和上述第一训练样本的类别标签,训练上述神经网络系统,包括:基于上述分类预测结果和上述第一训练样本的类别标签,确定第一损失;基于上述特征池化向量与上述K个特征提取向量中任一向量之间的相似度,确定第三损失;基于上述第一损失和第三损失,训练上述神经网络系统。
需要说明,对于图4所示出步骤的描述,还可以参见对图3示出步骤的描述。
综上,采用本说明书实施例披露的用于样本分类的神经网络系统的训练方法,通过引入表征类别私有特征的类别特征向量,提高样本分类模型对类别之间的差异性的感知,从而提高分类结果的准确度,同时,类别特征向量显式地建模了类别的私有特征,增强了模型的可解释性;进一步,还引入表征类别公有特征的特征池化向量,在训练过程中可以通过对私有特征和公有特征进行区分,来降低数据噪声,提高模型性能,同理,特征池化向量显式地建模了类别公有特征,增强了模型的可解释性;此外,还可以通过在模型训练的损失函数中显式地引入基于验证样本集确定的混淆矩阵,有效缓解训练集和验证集数据分布不一致的问题,提升模型在测试集上的表现。
与上述训练方法相对应的,本说明书实施例还披露训练装置。具体如下:
图5示出根据一个实施例的用于文本分类的神经网络系统的训练装置结构示意图,其中神经网络系统包括文本表征网络、特征提取层和分类网络。如图5所示,所述装置500包括:
文本获取单元510,配置为获取训练文本集,该训练文本集对应K个类别;文本表征单元520,配置为针对所述训练文本集中任一的第一训练文本,利用所述文本表征网络对其进行处理,得到第一文本向量;特征提取单元530,配置为利用所述特征提取层,将所述第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类预测单元540,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;训练单元550,配置为基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统。
在一个实施例中,所述神经网络系统还包括特征池化层;所述装置500还包括:特征池化单元560,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,所述分类预测单元540配置为:基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果。
在一个具体的实施例中,所述分类网络包括第一全连接层、第二全连接层和输出层;所述分类预测单元具体配置为:将所述K个特征提取向量输入所述第一全连接层,得到第一处理向量;将所述特征池化向量输入所述第二全连接层,得到第二处理向量;对所述第一处理向量和第二处理向量进行融合处理,得到融合向量;将所述融合向量输入所述输出层,得到所述分类预测结果。
在一个更具体的实施例中,所述融合处理包括加和处理、对位相乘处理,或拼接处理。
在一个实施例中,所述训练单元550包括:第一损失确定模块551,配置为基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;第二损失确定模块552,配置为基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;训练模块553,配置为基于所述第一损失和第二损失,训练所述神经网络系统。
在一个具体的实施例中,所述装置550还包括:验证结果确定单元570,配置为获取多个验证文本,并利用所述神经网络系统,确定该多个验证文本对应的多个分类结果;混淆方阵确定单元580,配置为基于所述多个分类结果和各个验证文本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示所述多个验证文本中,第i个类别的文本被错误分类为第j个类别的文本数量;其中,所述第二损失确定模块552具体配置为:基于所述相似度和所述混淆方阵,确定所述第二损失,该第二损失还与所述混淆方阵中的非对角线元素正相关。
在一个更具体的实施例中,所述第二损失确定模块552具体配置为:确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;基于所述相似度方阵和所述混淆方阵,确定所述第二损失。进一步,在一个例子中,所述第二损失确定模块552配置为基于所述相似度方阵和所述混淆方阵,确定所述第二损失,具体包括:将所述混淆方阵中的对角线元素置零,得到去对角化方阵;对所述去对角化方阵和所述相似度方阵进行对位相乘处理,得到对位相乘方阵;基于所述对位相乘方阵,确定所述第二损失。在一个具体的例子中,所述第二损失确定模块552配置为基于所述对位相乘方阵,确定所述第二损失,具体包括:将所述对位相乘方阵中元素的平均值,确定为所述第二损失。
在一个实施例中,所述神经网络系统还包括特征池化层;所述装置500还包括:特征池化单元560,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,所述训练单元550具体配置为:基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;基于所述特征池化向量与所述K个特征提取向量中任一向量之间的相似度,确定第三损失;基于所述第一损失和第三损失,训练所述神经网络系统。
图6示出根据一个实施例的用于样本分类的神经网络系统的训练装置结构示意图,其中神经网络系统包括样本表征网络、特征提取层和分类网络,所述装置600包括:
样本获取单元610,配置为获取训练样本集,该训练样本集对应K个类别;样本表征单元620,配置为针对所述训练样本集中任一的第一训练样本,利用所述样本表征网络对其进行处理,得到第一样本向量;特征提取单元630,配置为利用所述特征提取层,将所述第一样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类预测单元640,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;训练单元650,配置为基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统。
在一个实施例中,所述第一训练样本属于文本或图片或音频,所述第一训练样本涉及的业务对象为用户、商户、商品或事件。
在一个实施例中,所述神经网络系统还包括特征池化层;所述装置500还包括:特征池化单元660,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,所述分类预测单元640配置为:基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果。
在一个具体的实施例中,所述分类网络包括第一全连接层、第二全连接层和输出层;所述分类预测单元具体配置为:将所述K个特征提取向量输入所述第一全连接层,得到第一处理向量;将所述特征池化向量输入所述第二全连接层,得到第二处理向量;对所述第一处理向量和第二处理向量进行融合处理,得到融合向量;将所述融合向量输入所述输出层,得到所述分类预测结果。
在一个更具体的实施例中,所述融合处理包括加和处理、对位相乘处理,或拼接处理。
在一个实施例中,所述训练单元650包括:第一损失确定模块651,配置为基于所述分类预测结果和所述第一训练样本的类别标签,确定第一损失;第二损失确定模块652,配置为基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;训练模块653,配置为基于所述第一损失和第二损失,训练所述神经网络系统。
在一个具体的实施例中,所述装置550还包括:验证结果确定单元670,配置为获取多个验证样本,并利用所述神经网络系统,确定该多个验证样本对应的多个分类结果;混淆方阵确定单元680,配置为基于所述多个分类结果和各个验证样本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示所述多个验证样本中,第i个类别的样本被错误分类为第j个类别的样本数量;其中,所述第二损失确定模块652具体配置为:基于所述相似度和所述混淆方阵,确定所述第二损失,该第二损失还与所述混淆方阵中的非对角线元素正相关。
在一个更具体的实施例中,所述第二损失确定模块652具体配置为:确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;基于所述相似度方阵和所述混淆方阵,确定所述第二损失。进一步,在一个例子中,所述第二损失确定模块652配置为基于所述相似度方阵和所述混淆方阵,确定所述第二损失,具体包括:将所述混淆方阵中的对角线元素置零,得到去对角化方阵;对所述去对角化方阵和所述相似度方阵进行对位相乘处理,得到对位相乘方阵;基于所述对位相乘方阵,确定所述第二损失。在一个具体的例子中,所述第二损失确定模块652配置为基于所述对位相乘方阵,确定所述第二损失,具体包括:将所述对位相乘方阵中元素的平均值,确定为所述第二损失。
在一个实施例中,所述神经网络系统还包括特征池化层;所述装置500还包括:特征池化单元660,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;其中,所述训练单元650具体配置为:基于所述分类预测结果和所述第一训练样本的类别标签,确定第一损失;基于所述特征池化向量与所述K个特征提取向量中任一向量之间的相似度,确定第三损失;基于所述第一损失和第三损失,训练所述神经网络系统。
根据另一方面的实施例,图7示出根据一个实施例的神经网络系统的结构图示,其中神经网络系统用于预测K个类别的文本。如图7所示,所述神经网络系统700包括:
输入层710,用于获取目标文本;文本表征网络720,用于对所述目标文本进行处理,得到目标文本向量;特征提取层730,用于将所述目标文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类网络740,用于利用所述K个特征提取向量,确定分类预测结果。
图8示出根据一个实施例的神经网络系统的结构图示,其中神经网络系统用于预测K个类别的样本。如图8所示,所述神经网络系统800包括:输入层810,用于获取目标样本;样本表征网络820,用于对所述目标样本进行处理,得到目标样本向量;特征提取层830,用于将所述目标样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;分类网络840,用于利用所述K个特征提取向量,确定分类预测结果。
根据另一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行结合图3或图4所描述的方法。
根据再一方面的实施例,还提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现结合图3或图4所述的方法。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (24)
1.一种用于文本分类的神经网络系统的训练方法,所述神经网络系统包括文本表征网络、特征提取层和分类网络,所述方法包括:
获取训练文本集,该训练文本集对应K个类别;
针对所述训练文本集中任一的第一训练文本,利用所述文本表征网络对其进行处理,得到第一文本向量;
利用所述特征提取层,将所述第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;
基于所述K个特征提取向量和所述分类网络,确定分类预测结果;
基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统;
其中,基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统,包括:
基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;
基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;
基于所述第一损失和第二损失,训练所述神经网络系统。
2.根据权利要求1所述的方法,其中,所述神经网络系统还包括特征池化层;其中,在确定分类预测结果之前,所述方法还包括:
利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;
其中,确定分类预测结果,包括:
基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果。
3.根据权利要求2所述的方法,其中,所述分类网络包括第一全连接层、第二全连接层和输出层;基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果,包括:
将所述K个特征提取向量输入所述第一全连接层,得到第一处理向量;
将所述特征池化向量输入所述第二全连接层,得到第二处理向量;
对所述第一处理向量和第二处理向量进行融合处理,得到融合向量;
将所述融合向量输入所述输出层,得到所述分类预测结果。
4.根据权利要求3所述的方法,其中,所述融合处理包括加和处理、对位相乘处理,或拼接处理。
5.根据权利要求1所述的方法,其中,在确定第二损失之前,所述方法还包括:
获取多个验证文本,并利用所述神经网络系统,确定该多个验证文本对应的多个分类结果;
基于所述多个分类结果和各个验证文本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示所述多个验证文本中,第i个类别的文本被错误分类为第j个类别的文本数量;
其中,基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,包括:
基于所述相似度和所述混淆方阵,确定所述第二损失,该第二损失还与所述混淆方阵中的非对角线元素正相关。
6.根据权利要求5所述的方法,其中,基于所述相似度和所述混淆方阵,确定所述第二损失,包括:
确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;
基于所述相似度方阵和所述混淆方阵,确定所述第二损失。
7.根据权利要求6所述的方法,其中,基于所述相似度方阵和所述混淆方阵,确定所述第二损失,包括:
将所述混淆方阵中的对角线元素置零,得到去对角化方阵;
对所述去对角化方阵和所述相似度方阵进行对位相乘处理,得到对位相乘方阵;
基于所述对位相乘方阵,确定所述第二损失。
8.根据权利要求7所述的方法,其中,基于所述对位相乘方阵,确定所述第二损失,包括:
将所述对位相乘方阵中元素的平均值,确定为所述第二损失。
9.根据权利要求1所述的方法,其中,所述神经网络系统还包括特征池化层;其中,在基于所述K个特征提取向量和所述分类网络,确定第一分类结果之前,所述方法还包括:
利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;
其中,基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统,还包括:
基于所述特征池化向量与所述K个特征提取向量中任一向量之间的相似度,确定第三损失;
基于所述第三损失,训练所述神经网络系统。
10.一种用于样本分类的神经网络系统的训练方法,所述神经网络系统包括样本表征网络、特征提取层和分类网络,所述方法包括:
获取训练样本集,该训练样本集对应K个类别;
针对所述训练样本集中任一的第一训练样本,利用所述样本表征网络对其进行处理,得到第一样本向量;
利用所述特征提取层,将所述第一样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;
基于所述K个特征提取向量和所述分类网络,确定分类预测结果;
基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统;
其中,基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统,包括:
基于所述分类预测结果和所述第一训练样本的类别标签,确定第一损失;
基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;
基于所述第一损失和第二损失,训练所述神经网络系统。
11.根据权利要求10所述的方法,其中,所述第一训练样本属于文本或图片或音频,所述第一训练样本涉及的业务对象为用户、商户、商品或事件。
12.一种用于文本分类的神经网络系统的训练装置,所述神经网络系统包括文本表征网络、特征提取层和分类网络,所述装置包括:
文本获取单元,配置为获取训练文本集,该训练文本集对应K个类别;
文本表征单元,配置为针对所述训练文本集中任一的第一训练文本,利用所述文本表征网络对其进行处理,得到第一文本向量;
特征提取单元,配置为利用所述特征提取层,将所述第一文本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;
分类预测单元,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;
训练单元,配置为基于所述分类预测结果和所述第一训练文本的类别标签,训练所述神经网络系统;
其中,所述训练单元包括:
第一损失确定模块,配置为基于所述分类预测结果和所述第一训练文本的类别标签,确定第一损失;
第二损失确定模块,配置为基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;
训练模块,配置为基于所述第一损失和第二损失,训练所述神经网络系统。
13.根据权利要求12所述的装置,其中,所述神经网络系统还包括特征池化层;所述装置还包括:
特征池化单元,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;
其中,所述分类预测单元配置为:
基于所述K个特征提取向量、所述特征池化向量和所述分类网络,确定分类预测结果。
14.根据权利要求13所述的装置,其中,所述分类网络包括第一全连接层、第二全连接层和输出层;所述分类预测单元具体配置为:
将所述K个特征提取向量输入所述第一全连接层,得到第一处理向量;
将所述特征池化向量输入所述第二全连接层,得到第二处理向量;
对所述第一处理向量和第二处理向量进行融合处理,得到融合向量;
将所述融合向量输入所述输出层,得到所述分类预测结果。
15.根据权利要求14所述的装置,其中,所述融合处理包括加和处理、对位相乘处理,或拼接处理。
16.根据权利要求12所述的装置,其中,所述装置还包括:
验证结果确定单元,配置为获取多个验证文本,并利用所述神经网络系统,确定该多个验证文本对应的多个分类结果;
混淆方阵确定单元,配置为基于所述多个分类结果和各个验证文本对应的类别标签,确定混淆方阵,其中第i行第j个元素指示所述多个验证文本中,第i个类别的文本被错误分类为第j个类别的文本数量;
其中,所述第二损失确定模块具体配置为:
基于所述相似度和所述混淆方阵,确定所述第二损失,该第二损失还与所述混淆方阵中的非对角线元素正相关。
17.根据权利要求16所述的装置,其中,所述第二损失确定模块具体配置为:
确定相似度方阵,其中第s行第t个元素指示第s个特征提取向量和第t个特征提取向量之间的相似度;
基于所述相似度方阵和所述混淆方阵,确定所述第二损失。
18.根据权利要求17所述的装置,其中,所述第二损失确定模块配置为基于所述相似度方阵和所述混淆方阵,确定所述第二损失,具体包括:
将所述混淆方阵中的对角线元素置零,得到去对角化方阵;
对所述去对角化方阵和所述相似度方阵进行对位相乘处理,得到对位相乘方阵;
基于所述对位相乘方阵,确定所述第二损失。
19.根据权利要求18所述的装置,其中,所述第二损失确定模块配置为基于所述对位相乘方阵,确定所述第二损失,具体包括:
将所述对位相乘方阵中元素的平均值,确定为所述第二损失。
20.根据权利要求12所述的装置,其中,所述神经网络系统还包括特征池化层;所述装置还包括:
特征池化单元,配置为利用所述特征池化层对所述K个特征提取向量进行池化,得到特征池化向量;
其中,所述训练单元还包括:
第三损失确定模块,配置为基于所述特征池化向量与所述K个特征提取向量中任一向量之间的相似度,确定第三损失;
所述训练模块还配置为,基于所述第三损失,训练所述神经网络系统。
21.一种用于样本分类的神经网络系统的训练装置,所述神经网络系统包括样本表征网络、特征提取层和分类网络,所述装置包括:
样本获取单元,配置为获取训练样本集,该训练样本集对应K个类别;
样本表征单元,配置为针对所述训练样本集中任一的第一训练样本,利用所述样本表征网络对其进行处理,得到第一样本向量;
特征提取单元,配置为利用所述特征提取层,将所述第一样本向量分别与对应所述K个类别的K个类别特征向量进行组合操作,得到K个特征提取向量;
分类预测单元,配置为基于所述K个特征提取向量和所述分类网络,确定分类预测结果;
训练单元,配置为基于所述分类预测结果和所述第一训练样本的类别标签,训练所述神经网络系统;
其中,所述训练单元包括:
第一损失确定模块,配置为基于所述分类预测结果和所述第一训练样本的类别标签,确定第一损失;
第二损失确定模块,配置为基于所述K个特征提取向量中任意两个向量之间的相似度,确定第二损失,该第二损失与所述相似度正相关;
训练模块,配置为基于所述第一损失和第二损失,训练所述神经网络系统。
22.根据权利要求21所述的装置,其中,所述第一训练样本属于文本或图片或音频,所述第一训练样本涉及的业务对象为用户、商户、商品或事件。
23.一种计算机可读存储介质,其上存储有计算机程序,其中,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-11中任一项的所述的方法。
24.一种计算设备,包括存储器和处理器,其中,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-11中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011269071.3A CN112100387B (zh) | 2020-11-13 | 2020-11-13 | 用于文本分类的神经网络系统的训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011269071.3A CN112100387B (zh) | 2020-11-13 | 2020-11-13 | 用于文本分类的神经网络系统的训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112100387A CN112100387A (zh) | 2020-12-18 |
CN112100387B true CN112100387B (zh) | 2021-02-19 |
Family
ID=73784572
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011269071.3A Active CN112100387B (zh) | 2020-11-13 | 2020-11-13 | 用于文本分类的神经网络系统的训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112100387B (zh) |
Families Citing this family (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112766500B (zh) * | 2021-02-07 | 2022-05-17 | 支付宝(杭州)信息技术有限公司 | 图神经网络的训练方法及装置 |
CN112988186B (zh) * | 2021-02-19 | 2022-07-19 | 支付宝(杭州)信息技术有限公司 | 异常检测系统的更新方法及装置 |
CN112989045B (zh) * | 2021-03-17 | 2023-07-25 | 中国平安人寿保险股份有限公司 | 神经网络训练方法、装置、电子设备及存储介质 |
CN113033579B (zh) * | 2021-03-31 | 2023-03-21 | 北京有竹居网络技术有限公司 | 图像处理方法、装置、存储介质及电子设备 |
CN113139053B (zh) * | 2021-04-15 | 2024-03-05 | 广东工业大学 | 一种基于自监督对比学习的文本分类方法 |
CN113177482A (zh) * | 2021-04-30 | 2021-07-27 | 中国科学技术大学 | 一种基于最小类别混淆的跨个体脑电信号分类方法 |
CN113468324A (zh) * | 2021-06-03 | 2021-10-01 | 上海交通大学 | 基于bert预训练模型和卷积网络的文本分类方法和系统 |
CN113255566B (zh) * | 2021-06-11 | 2022-12-06 | 支付宝(杭州)信息技术有限公司 | 表格图像识别方法及装置 |
CN114143040B (zh) * | 2021-11-08 | 2024-03-22 | 浙江工业大学 | 一种基于多通道特征重构的对抗信号检测方法 |
CN114240495A (zh) * | 2021-12-16 | 2022-03-25 | 成都新潮传媒集团有限公司 | 商机转化概率的预测方法、装置及计算机可读存储介质 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109101235B (zh) * | 2018-06-05 | 2021-03-19 | 北京航空航天大学 | 一种软件程序的智能解析方法 |
CN109101932B (zh) * | 2018-08-17 | 2020-07-24 | 佛山市顺德区中山大学研究院 | 基于目标检测的多任务及临近信息融合的深度学习方法 |
CN110597983B (zh) * | 2019-07-25 | 2023-09-15 | 华北电力大学 | 一种基于类别嵌入的层次化文本分类计算方法 |
CN111737474B (zh) * | 2020-07-17 | 2021-01-12 | 支付宝(杭州)信息技术有限公司 | 业务模型的训练和确定文本分类类别的方法及装置 |
-
2020
- 2020-11-13 CN CN202011269071.3A patent/CN112100387B/zh active Active
Also Published As
Publication number | Publication date |
---|---|
CN112100387A (zh) | 2020-12-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112100387B (zh) | 用于文本分类的神经网络系统的训练方法及装置 | |
CN112559784B (zh) | 基于增量学习的图像分类方法及系统 | |
CN104951965B (zh) | 广告投放方法及装置 | |
CN113139628B (zh) | 样本图像的识别方法、装置、设备及可读存储介质 | |
CN109034203B (zh) | 表情推荐模型的训练、表情推荐方法、装置、设备及介质 | |
CN109189921B (zh) | 评论评估模型的训练方法和装置 | |
Yang et al. | Active matting | |
CN111461164B (zh) | 样本数据集的扩容方法及模型的训练方法 | |
CN111881722B (zh) | 一种跨年龄人脸识别方法、系统、装置及存储介质 | |
CN111242358A (zh) | 一种双层结构的企业情报流失预测方法 | |
CN113722583A (zh) | 推荐方法、推荐模型训练方法及相关产品 | |
CN115170449B (zh) | 一种多模态融合场景图生成方法、系统、设备和介质 | |
CN111105013A (zh) | 对抗网络架构的优化方法、图像描述生成方法和系统 | |
CN113592593A (zh) | 序列推荐模型的训练及应用方法、装置、设备及存储介质 | |
CN116764236A (zh) | 游戏道具推荐方法、装置、计算机设备和存储介质 | |
CN117635769A (zh) | 基于关联感知跨模态注意网络的社交网中服饰推荐方法 | |
Dong et al. | A supervised dictionary learning and discriminative weighting model for action recognition | |
CN115688742B (zh) | 基于人工智能的用户数据分析方法及ai系统 | |
CN116340635A (zh) | 物品推荐方法、模型训练方法、装置及设备 | |
CN116503127A (zh) | 模型训练方法、检索方法及相关装置 | |
CN110717037A (zh) | 对用户分类的方法和装置 | |
Kuang et al. | Multi-label image classification with multi-layered multi-perspective dynamic semantic representation | |
CN114818900A (zh) | 一种半监督特征提取方法及用户信用风险评估方法 | |
CN114239569A (zh) | 评估文本的分析方法及其装置、计算机可读存储介质 | |
CN115393914A (zh) | 多任务模型训练方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |