CN113344031B - 一种文本分类方法 - Google Patents

一种文本分类方法 Download PDF

Info

Publication number
CN113344031B
CN113344031B CN202110520242.3A CN202110520242A CN113344031B CN 113344031 B CN113344031 B CN 113344031B CN 202110520242 A CN202110520242 A CN 202110520242A CN 113344031 B CN113344031 B CN 113344031B
Authority
CN
China
Prior art keywords
text
classified
category
category label
negative
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110520242.3A
Other languages
English (en)
Other versions
CN113344031A (zh
Inventor
张雷
杨竞潮
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Original Assignee
Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University filed Critical Tsinghua University
Priority to CN202110520242.3A priority Critical patent/CN113344031B/zh
Publication of CN113344031A publication Critical patent/CN113344031A/zh
Application granted granted Critical
Publication of CN113344031B publication Critical patent/CN113344031B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明属于深度学习与算法领域,尤其涉及一种文本分类方法。本发明在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离、待分类文本与负类目标签之间的欧式距离和正类目标签与负类目标签之间的欧式距离,并得到三元组损失函数,同时加入了一个“粗筛‑精筛”的过程。本发明方法基于样本三元组,在优化待分类文本与正样本和负样本的距离之差的同时,加入一个系数同时优化正样本与负样本之间的距离,构造了两层的级联模型,相比于单层模型,有效提高了分类准确率。本发明的文本分类方法,应用简便,易于推广,除了文本分类任务也可以应用在计算机视觉等多种领域。

Description

一种文本分类方法
技术领域
本发明属于深度学习与算法领域,尤其涉及一种文本分类方法。
背景技术
在文本分类任务中,当已知文本类目的标签时,可以采用度量文本与标签距离的方法来衡量。这样和多分类模型相比,可以引入标签的文本信息,提升分类的准确率。
分类任务使用度量学习方法,计算待分类的文本与各个类目标签通过预训练模型后转化为向量的距离,将距离进行排序,距离最短的类目标签即为该文本所属的类目。
度量学习的损失函数当前主要采用三元组损失(Triplet Loss),将每一段待分类的文本作为锚点,它所属的类目标签作为正样本,在其他类目标签中随机选取一个作为负样本。让文本对应的向量尽可能靠近正样本类目标签对应的向量,并远离负样本类目标签对应的向量,通过这种方法对预训练模型进行微调。损失函数:
TripletLoss=(d(a,p)-d(a,n)+margin)+
但Triplet Loss损失函数的表达式仅仅考虑到文本与正样本距离尽可能近,与负样本尽可能远,实际上也希望同时满足于类目标签之间的距离尽可能远,并将其加入损失函数中。
根据这一要求,提出了四元组损失(Quadruplet Loss)的改进:
QuadrupletLoss=(d(a,p)-d(a,n1)+α)++(d(a,p)-d(n1,n2)+β)+
四元组损失加入了新的负样本n2,让两个负样本之间的距离尽可能远,另外弱推动项中也能使待分类文本的向量与正样本向量尽可能近。但四元组损失在一些数据集上表现不佳,分析后发现弱推动项的比重比强推动项还要大,即文本与正样本之间的距离往往比正负样本之间的距离更大,因此影响了强推动项的优化过程。
因此需要考虑一种方法,既能优化负样本之间的距离,也不影响强推动项的优化过程。
发明内容
本发明的目的提出一种文本分类方法,在优化三元组损失的过程中,同时优化正负样本之间的距离,达到提升文本分类任务的准确率的目的。
本发明提出的文本分类方法,在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离 d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n),三元组损失损失函数的表达式为:
Figure RE-GDA0003175674530000021
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
同时加入了一个“粗筛-精筛”的过程,先从多个类目中选出前m名,再从m个类目中选出1个作为文本分类结果。
本发明提出的文本分类方法,其特点及优点是:
1、本发明的文本分类方法,对三元组损失函数加入了一个受正负样本之间距离影响的系数
Figure RE-GDA0003175674530000022
当三元组损失为0时,系数值为1,无须优化,只有当三元组损失需要优化时,需要同时考虑将正负样本之间的距离拉远而提升分类效果,这种方法既能有效优化正负样本间的距离,当阈值选取合理的情况下,也不会影响Triplet Loss的优化效果。
2、本发明的文本分类方法,在类目数量较多时,相比于已有技术的直接找出文本所属的类目,本方法加入了一个“粗筛-精筛”的过程,先从很多类目中选出前m名,再从这m个类目中选出1个作为预测结果。这种方法构造了两层的级联模型,相比于单层模型,有效提高了分类准确率;
3、本发明的文本分类方法,应用简便,易于推广,除了文本分类任务也可以应用在计算机视觉等多种领域。
附图说明
图1为本发明方法与已有技术相比两种损失函数的区别示意图。
图2为本发明方法的流程框图。
具体实施方式
本发明提出的文本分类方法,在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离 d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n),三元组损失损失函数的表达式为:
Figure RE-GDA0003175674530000023
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
同时加入了一个“粗筛-精筛”的过程,先从多个类目中选出前m名,再从m个类目中选出1个作为文本分类结果。
上述的文本分类方法,其流程框图如图2所示,具体过程包括以下步骤:
(1)构建一个基本模型训练集,将由待分类文本、正样本与负样本组成的训练数据作为基本模型训练集中的一条数据,所述的正样本为待分类文本所属的类目标签,记为正类目标签,负样本为从基本模型训练集中随机抽取的多个类目标签,记为负类目标签;将待分类文本、正类目标签和多个负类目标签分别输入用于文本分类的预训练模型 (RoBERTa模型)中,预训练模型输出得到分别与待分类文本、正类目标签和多个负类目标签相对应的向量,对于同一个待分类文本,正类目标签和多个负类目标签的比例为1: (3~10),本发明的一个实施例中,为1∶5;
(2)分别计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n);
(3)根据步骤(2)的欧式距离,得到基于度量学习的三元组损失函数:
Triplet Loss=(d(a,p)-d(a,n)+margin)+
其中margin是一个超参数,要求待分类文本到负样本的距离与待分类文本到正样本的距离之差大于一定的阈值。
本发明对损失函数进行改进:
Figure RE-GDA0003175674530000031
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
三元组在向量空间的分布与两种损失函数的优化方向如图1所示。
(4)利用步骤(3)的损失函数Triangle Triplet Loss,对步骤(1)的预训练模型进行微调,使损失函数Triangle Triplet Loss最小化,得到一个微调后的预训练模型,记为基本模型;
(5)构建一个由待分类文本与相应的所有类目标签组成的基本模型测试集,将基本模型测试集中的待分类文本与相应的所有类目标签输入到步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签,实现文本分类。
本发明的文本分类方法,还可以包括以下步骤:
(6)从步骤(5)的排序后的欧式距离中,取出前m个类目标签,作为步骤(5)的测试集中待分类文本的类目标签候选集,记为二级模型测试集,其中m为3-10,本发明的一个实施例中,m取值为5;
(7)将步骤(1)的基本模型训练集中的待分类文本、正样本与负样本输入步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签;
(8)从步骤(7)的排序后的欧式距离中,取出前n个类目标签,作为步骤(7)的基本模型训练集中待分类文本的类目标签候选集,其中n为3-10,本发明的一个实施例中, n取值为5;将训练文本的负样本固定在前n名当中,记为二级模型训练集,对二级模型训练集中的类目标签进行判断,将前n名中的非正样本类目标签记为负样本类目标签;
(9)利用步骤(6)的二级模型测试集和步骤(8)的二级模型训练集,重复步骤(1) -步骤(4),对步骤(4)的基本模型进行微调,得到二级模型;重复步骤(5),实现最终文本分类。
本发明设计文本分类方法时,其关键技术在于对三元组损失函数加入了一个受正负样本之间距离影响的系数,进而优化深度度量学习的损失函数,并在模型的基础上,加入了二级级联模型的方法,将分类过程优化为一个“粗筛-精筛”的过程。
为使本发明的目的、技术方案和特点更加清楚明确,下面结合附图与实验所采用的数据集对具体实施方式进行详细说明与描述。
本发明采用kaggle的news-category-dataset公开的新闻文本分类数据集进行实验验证。该数据集包含从HuffPost获得的大约20万条从2012年至2018年的新闻文本,在该数据集上训练的模型可用于识别新闻文章的标签,共41个类目标签。服务器环境为python3.6, pytorch1.7.1,torchvision0.8.2,transformers4.1.1。验证实验过程如下:
步骤1:在训练数据中,确定每条新闻文本所属的类目标签,即为正样本,相对应的每条文本对应40个负样本。将新闻文本,正负样本的类目标签分别通过预训练模型得到它们的向量映射,对于同一个新闻文本,正负样本比例为1∶5,即构造的训练集数据约为 20万*5=100万条;
步骤2:用欧式距离表征待分类文本与正类目标签之间的距离d(a,p),待分类文本与负类目标签之间的距离d(a,n),正样本与负样本之间的距离d(p,n);
步骤3:基于度量学习损失函数三元组损失:
TripletLoss=(d(a,p)-d(a,n)+margin)+
其中margin是一个超参数,要求待分类文本到负样本的距离与待分类文本到正样本的距离之差大于阈值。
本发明对上述已有技术的对损失函数进行改进后:
Figure RE-GDA0003175674530000051
三元组在向量空间的分布与两种损失函数的优化方向如图1所示。
步骤4:基于改进后的Triangle Triplet Loss对预训练模型进行微调,得到一个基本模型。
步骤5:在测试数据中,将待分类的文本与它所对应的所有类目标签通过微调完成后的预训练模型,得到一个文本向量和若干个类目标签向量,计算文本向量与类目标签向量之间的距离,对距离进行排序,与文本向量距离最短的类目标签向量即是待分类文本所属的类目,使用Triplet Loss作为损失函数准确率约为0.674,使用Triangle TripletLoss作为损失函数准确率约为0.687;
步骤6:步骤5得到测试数据的类目标签的排序,取出41条类目标签的前5名作为测试文本新的类目候选集,但有部分文本的前5名并没有覆盖到它的标注类目标签,使用三元组损失Triplet Loss作为损失函数Top5覆盖率约为0.936,使用改进后的三元组损失Triangle Triplet Loss作为损失函数Top5覆盖率约为0.943;
步骤7:将训练数据的文本通过步骤5得到的模型,得到类目标签的前5名,将训练文本的负样本固定在前5名当中,构造新的训练集,将前5名中不属于正样本的类目标签全部作为负样本,正负样本比例约为1:5,二级模型训练数据约为20万*5=100万条。
步骤8:将步骤6与步骤7中得到的新的训练集与测试集重复步骤1-步骤5的过程,对预训练模型进行第二步微调,得到一个二级模型,并得到新的预测结果。对于二级模型,使用Triplet Loss作为损失函数准确率约为0.721,使用Triangle Triplet Loss作为损失函数准确率约为0.734,验证汇总结果如表1所示。
Figure RE-GDA0003175674530000052
表1
综上,本发明可有效提升文本分类的准确率。
以上实施例验证了本发明的正确性和实效性。以上所述仅为本发明具体应用于文本分类任务,并非用于限定本发明的保护范围。

Claims (1)

1.一种文本分类方法,其特征在于,包括以下步骤:
(1)构建一个基本模型训练集,将由待分类文本、正样本与负样本组成的训练数据作为基本模型训练集中的一条数据,所述的正样本为待分类文本所属的类目标签,记为正类目标签,负样本为从基本模型训练集中随机抽取的多个类目标签,记为负类目标签;将待分类文本、正类目标签和多个负类目标签分别输入用于文本分类的预训练模型中,预训练模型输出得到分别与待分类文本、正类目标签和多个负类目标签相对应的向量,对于同一个待分类文本,正类目标签和多个负类目标签的比例为1:(3~10);
(2)分别计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n);
(3)根据步骤(2)的欧式距离,得到基于度量学习的三元组损失函数:
Figure FDA0003804282270000011
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
(4)利用步骤(3)的损失函数Triangle Triplet Loss,对步骤(1)的预训练模型进行微调,使损失函数Triangle Triplet Loss最小化,得到一个微调后的预训练模型,记为基本模型;
(5)构建一个由待分类文本与相应的所有类目标签组成的基本模型测试集,将基本模型测试集中的待分类文本与相应的所有类目标签输入到步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签,实现文本分类;
(6)从步骤(5)的排序后的欧式距离中,取出前m个类目标签,作为步骤(5)的测试集中待分类文本的类目标签候选集,记为二级模型测试集,其中m为3-10;
(7)将步骤(1)的基本模型训练集中的待分类文本、正样本与负样本输入步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签;
(8)从步骤(7)的排序后的欧式距离中,取出前n个类目标签,作为步骤(7)的基本模型训练集中待分类文本的类目标签候选集,其中n为3-10,将训练文本的负样本固定在前n名当中,记为二级模型训练集,对二级模型训练集中的类目标签进行判断,将前n名中的非正样本类目标签记为负样本类目标签;
(9)利用步骤(6)的二级模型测试集和步骤(8)的二级模型训练集,重复步骤(1)-步骤(4),对步骤(4)的基本模型进行微调,得到二级模型;重复步骤(5),实现最终文本分类。
CN202110520242.3A 2021-05-13 2021-05-13 一种文本分类方法 Active CN113344031B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110520242.3A CN113344031B (zh) 2021-05-13 2021-05-13 一种文本分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110520242.3A CN113344031B (zh) 2021-05-13 2021-05-13 一种文本分类方法

Publications (2)

Publication Number Publication Date
CN113344031A CN113344031A (zh) 2021-09-03
CN113344031B true CN113344031B (zh) 2022-12-27

Family

ID=77468444

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110520242.3A Active CN113344031B (zh) 2021-05-13 2021-05-13 一种文本分类方法

Country Status (1)

Country Link
CN (1) CN113344031B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113849653B (zh) * 2021-10-14 2023-04-07 鼎富智能科技有限公司 一种文本分类方法及装置

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109858552A (zh) * 2019-01-31 2019-06-07 深兰科技(上海)有限公司 一种用于细粒度分类的目标检测方法及设备
CN109948160A (zh) * 2019-03-15 2019-06-28 智者四海(北京)技术有限公司 短文本分类方法及装置
WO2019128367A1 (zh) * 2017-12-26 2019-07-04 广州广电运通金融电子股份有限公司 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质
CN111858843A (zh) * 2019-04-30 2020-10-30 北京嘀嘀无限科技发展有限公司 一种文本分类方法及装置
CN112749268A (zh) * 2021-01-30 2021-05-04 云知声智能科技股份有限公司 基于混合策略的faq系统排序方法、装置及系统

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019128367A1 (zh) * 2017-12-26 2019-07-04 广州广电运通金融电子股份有限公司 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质
CN109858552A (zh) * 2019-01-31 2019-06-07 深兰科技(上海)有限公司 一种用于细粒度分类的目标检测方法及设备
CN109948160A (zh) * 2019-03-15 2019-06-28 智者四海(北京)技术有限公司 短文本分类方法及装置
CN111858843A (zh) * 2019-04-30 2020-10-30 北京嘀嘀无限科技发展有限公司 一种文本分类方法及装置
CN112749268A (zh) * 2021-01-30 2021-05-04 云知声智能科技股份有限公司 基于混合策略的faq系统排序方法、装置及系统

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
A Two-Stage Triplet Network Training Framework for Image Retrieval;Weiqing Min;《 IEEE Transactions on Multimedia》;20200220;全文 *
基于全路径相似度的大规模层次分类算法;朱建林等;《计算机工程与设计》;20190515(第05期);全文 *

Also Published As

Publication number Publication date
CN113344031A (zh) 2021-09-03

Similar Documents

Publication Publication Date Title
CN108628971B (zh) 不均衡数据集的文本分类方法、文本分类器及存储介质
CN111368920A (zh) 基于量子孪生神经网络的二分类方法及其人脸识别方法
CN106651057A (zh) 一种基于安装包序列表的移动端用户年龄预测方法
CN109993236A (zh) 基于one-shot Siamese卷积神经网络的少样本满文匹配方法
CN102262642B (zh) 一种Web图像搜索引擎及其实现方法
CN108959474B (zh) 实体关系提取方法
CN112417132B (zh) 一种利用谓宾信息筛选负样本的新意图识别方法
CN111860671A (zh) 分类模型训练方法、装置、终端设备和可读存储介质
CN113344031B (zh) 一种文本分类方法
Lumauag et al. An enhanced recommendation algorithm based on modified user-based collaborative filtering
CN110515836B (zh) 一种面向软件缺陷预测的加权朴素贝叶斯方法
CN108681532A (zh) 一种面向中文微博的情感分析方法
CN111144462A (zh) 一种雷达信号的未知个体识别方法及装置
CN104281569A (zh) 构建装置和方法、分类装置和方法以及电子设备
CN111984790B (zh) 一种实体关系抽取方法
CN112836731A (zh) 基于决策树准确率和相关性度量的信号随机森林分类方法、系统及装置
CN108268458B (zh) 一种基于knn算法的半结构化数据分类方法及装置
CN114328921B (zh) 一种基于分布校准的小样本实体关系抽取方法
CN116153299A (zh) 训练样本的处理方法、语音质检方法及装置
CN113010687B (zh) 一种习题标签预测方法、装置、存储介质以及计算机设备
CN114067165A (zh) 一种含噪声标记分布的图像筛选和学习方法与装置
CN110162629B (zh) 一种基于多基模型框架的文本分类方法
CN111783788B (zh) 一种面向标记噪声的多标记分类方法
CN114547264A (zh) 一种基于马氏距离和对比学习的新意图数据识别方法
CN113780463A (zh) 一种基于深度神经网络的多头归一化长尾分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant