CN113344031A - 一种文本分类方法 - Google Patents
一种文本分类方法 Download PDFInfo
- Publication number
- CN113344031A CN113344031A CN202110520242.3A CN202110520242A CN113344031A CN 113344031 A CN113344031 A CN 113344031A CN 202110520242 A CN202110520242 A CN 202110520242A CN 113344031 A CN113344031 A CN 113344031A
- Authority
- CN
- China
- Prior art keywords
- text
- classified
- category label
- category
- negative
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明属于深度学习与算法领域,尤其涉及一种文本分类方法。本发明在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离、待分类文本与负类目标签之间的欧式距离和正类目标签与负类目标签之间的欧式距离,并得到三元组损失函数,同时加入了一个“粗筛‑精筛”的过程。本发明方法基于样本三元组,在优化待分类文本与正样本和负样本的距离之差的同时,加入一个系数同时优化正样本与负样本之间的距离,构造了两层的级联模型,相比于单层模型,有效提高了分类准确率。本发明的文本分类方法,应用简便,易于推广,除了文本分类任务也可以应用在计算机视觉等多种领域。
Description
技术领域
本发明属于深度学习与算法领域,尤其涉及一种文本分类方法。
背景技术
在文本分类任务中,当已知文本类目的标签时,可以采用度量文本与标签距离的方法来衡量。这样和多分类模型相比,可以引入标签的文本信息,提升分类的准确率。
分类任务使用度量学习方法,计算待分类的文本与各个类目标签通过预训练模型后转化为向量的距离,将距离进行排序,距离最短的类目标签即为该文本所属的类目。
度量学习的损失函数当前主要采用三元组损失(Triplet Loss),将每一段待分类的文本作为锚点,它所属的类目标签作为正样本,在其他类目标签中随机选取一个作为负样本。让文本对应的向量尽可能靠近正样本类目标签对应的向量,并远离负样本类目标签对应的向量,通过这种方法对预训练模型进行微调。损失函数:
TripletLoss=(d(a,p)-d(a,n)+margin)+
但Triplet Loss损失函数的表达式仅仅考虑到文本与正样本距离尽可能近,与负样本尽可能远,实际上也希望同时满足于类目标签之间的距离尽可能远,并将其加入损失函数中。
根据这一要求,提出了四元组损失(Quadruplet Loss)的改进:
QuadrupletLoss=(d(a,p)-d(a,n1)+α)++(d(a,p)-d(n1,n2)+β)+
四元组损失加入了新的负样本n2,让两个负样本之间的距离尽可能远,另外弱推动项中也能使待分类文本的向量与正样本向量尽可能近。但四元组损失在一些数据集上表现不佳,分析后发现弱推动项的比重比强推动项还要大,即文本与正样本之间的距离往往比正负样本之间的距离更大,因此影响了强推动项的优化过程。
因此需要考虑一种方法,既能优化负样本之间的距离,也不影响强推动项的优化过程。
发明内容
本发明的目的提出一种文本分类方法,在优化三元组损失的过程中,同时优化正负样本之间的距离,达到提升文本分类任务的准确率的目的。
本发明提出的文本分类方法,在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离 d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n),三元组损失损失函数的表达式为:
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
同时加入了一个“粗筛-精筛”的过程,先从多个类目中选出前m名,再从m个类目中选出1个作为文本分类结果。
本发明提出的文本分类方法,其特点及优点是:
1、本发明的文本分类方法,对三元组损失函数加入了一个受正负样本之间距离影响的系数当三元组损失为0时,系数值为1,无须优化,只有当三元组损失需要优化时,需要同时考虑将正负样本之间的距离拉远而提升分类效果,这种方法既能有效优化正负样本间的距离,当阈值选取合理的情况下,也不会影响Triplet Loss的优化效果。
2、本发明的文本分类方法,在类目数量较多时,相比于已有技术的直接找出文本所属的类目,本方法加入了一个“粗筛-精筛”的过程,先从很多类目中选出前m名,再从这m个类目中选出1个作为预测结果。这种方法构造了两层的级联模型,相比于单层模型,有效提高了分类准确率;
3、本发明的文本分类方法,应用简便,易于推广,除了文本分类任务也可以应用在计算机视觉等多种领域。
附图说明
图1为本发明方法与已有技术相比两种损失函数的区别示意图。
图2为本发明方法的流程框图。
具体实施方式
本发明提出的文本分类方法,在构建基于度量学习的三元组损失损失函数时,计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离 d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n),三元组损失损失函数的表达式为:
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
同时加入了一个“粗筛-精筛”的过程,先从多个类目中选出前m名,再从m个类目中选出1个作为文本分类结果。
上述的文本分类方法,其流程框图如图2所示,具体过程包括以下步骤:
(1)构建一个基本模型训练集,将由待分类文本、正样本与负样本组成的训练数据作为基本模型训练集中的一条数据,所述的正样本为待分类文本所属的类目标签,记为正类目标签,负样本为从基本模型训练集中随机抽取的多个类目标签,记为负类目标签;将待分类文本、正类目标签和多个负类目标签分别输入用于文本分类的预训练模型 (RoBERTa模型)中,预训练模型输出得到分别与待分类文本、正类目标签和多个负类目标签相对应的向量,对于同一个待分类文本,正类目标签和多个负类目标签的比例为1: (3~10),本发明的一个实施例中,为1∶5;
(2)分别计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n);
(3)根据步骤(2)的欧式距离,得到基于度量学习的三元组损失函数:
Triplet Loss=(d(a,p)-d(a,n)+margin)+
其中margin是一个超参数,要求待分类文本到负样本的距离与待分类文本到正样本的距离之差大于一定的阈值。
本发明对损失函数进行改进:
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
三元组在向量空间的分布与两种损失函数的优化方向如图1所示。
(4)利用步骤(3)的损失函数Triangle Triplet Loss,对步骤(1)的预训练模型进行微调,使损失函数Triangle Triplet Loss最小化,得到一个微调后的预训练模型,记为基本模型;
(5)构建一个由待分类文本与相应的所有类目标签组成的基本模型测试集,将基本模型测试集中的待分类文本与相应的所有类目标签输入到步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签,实现文本分类。
本发明的文本分类方法,还可以包括以下步骤:
(6)从步骤(5)的排序后的欧式距离中,取出前m个类目标签,作为步骤(5)的测试集中待分类文本的类目标签候选集,记为二级模型测试集,其中m为3-10,本发明的一个实施例中,m取值为5;
(7)将步骤(1)的基本模型训练集中的待分类文本、正样本与负样本输入步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签;
(8)从步骤(7)的排序后的欧式距离中,取出前n个类目标签,作为步骤(7)的基本模型训练集中待分类文本的类目标签候选集,其中n为3-10,本发明的一个实施例中, n取值为5;将训练文本的负样本固定在前n名当中,记为二级模型训练集,对二级模型训练集中的类目标签进行判断,将前n名中的非正样本类目标签记为负样本类目标签;
(9)利用步骤(6)的二级模型测试集和步骤(8)的二级模型训练集,重复步骤(1) -步骤(4),对步骤(4)的基本模型进行微调,得到二级模型;重复步骤(5),实现最终文本分类。
本发明设计文本分类方法时,其关键技术在于对三元组损失函数加入了一个受正负样本之间距离影响的系数,进而优化深度度量学习的损失函数,并在模型的基础上,加入了二级级联模型的方法,将分类过程优化为一个“粗筛-精筛”的过程。
为使本发明的目的、技术方案和特点更加清楚明确,下面结合附图与实验所采用的数据集对具体实施方式进行详细说明与描述。
本发明采用kaggle的news-category-dataset公开的新闻文本分类数据集进行实验验证。该数据集包含从HuffPost获得的大约20万条从2012年至2018年的新闻文本,在该数据集上训练的模型可用于识别新闻文章的标签,共41个类目标签。服务器环境为python3.6, pytorch1.7.1,torchvision0.8.2,transformers4.1.1。验证实验过程如下:
步骤1:在训练数据中,确定每条新闻文本所属的类目标签,即为正样本,相对应的每条文本对应40个负样本。将新闻文本,正负样本的类目标签分别通过预训练模型得到它们的向量映射,对于同一个新闻文本,正负样本比例为1∶5,即构造的训练集数据约为 20万*5=100万条;
步骤2:用欧式距离表征待分类文本与正类目标签之间的距离d(a,p),待分类文本与负类目标签之间的距离d(a,n),正样本与负样本之间的距离d(p,n);
步骤3:基于度量学习损失函数三元组损失:
TripletLoss=(d(a,p)-d(a,n)+margin)+
其中margin是一个超参数,要求待分类文本到负样本的距离与待分类文本到正样本的距离之差大于阈值。
本发明对上述已有技术的对损失函数进行改进后:
三元组在向量空间的分布与两种损失函数的优化方向如图1所示。
步骤4:基于改进后的Triangle Triplet Loss对预训练模型进行微调,得到一个基本模型。
步骤5:在测试数据中,将待分类的文本与它所对应的所有类目标签通过微调完成后的预训练模型,得到一个文本向量和若干个类目标签向量,计算文本向量与类目标签向量之间的距离,对距离进行排序,与文本向量距离最短的类目标签向量即是待分类文本所属的类目,使用Triplet Loss作为损失函数准确率约为0.674,使用Triangle TripletLoss作为损失函数准确率约为0.687;
步骤6:步骤5得到测试数据的类目标签的排序,取出41条类目标签的前5名作为测试文本新的类目候选集,但有部分文本的前5名并没有覆盖到它的标注类目标签,使用三元组损失Triplet Loss作为损失函数Top5覆盖率约为0.936,使用改进后的三元组损失Triangle Triplet Loss作为损失函数Top5覆盖率约为0.943;
步骤7:将训练数据的文本通过步骤5得到的模型,得到类目标签的前5名,将训练文本的负样本固定在前5名当中,构造新的训练集,将前5名中不属于正样本的类目标签全部作为负样本,正负样本比例约为1:5,二级模型训练数据约为20万*5=100万条。
步骤8:将步骤6与步骤7中得到的新的训练集与测试集重复步骤1-步骤5的过程,对预训练模型进行第二步微调,得到一个二级模型,并得到新的预测结果。对于二级模型,使用Triplet Loss作为损失函数准确率约为0.721,使用Triangle Triplet Loss作为损失函数准确率约为0.734,验证汇总结果如表1所示。
表1
综上,本发明可有效提升文本分类的准确率。
以上实施例验证了本发明的正确性和实效性。以上所述仅为本发明具体应用于文本分类任务,并非用于限定本发明的保护范围。
Claims (3)
2.一种如权利要求1所述的文本分类方法,其特征在于具体过程包括以下步骤:
(1)构建一个基本模型训练集,将由待分类文本、正样本与负样本组成的训练数据作为基本模型训练集中的一条数据,所述的正样本为待分类文本所属的类目标签,记为正类目标签,负样本为从基本模型训练集中随机抽取的多个类目标签,记为负类目标签;将待分类文本、正类目标签和多个负类目标签分别输入用于文本分类的预训练模型(RoBERTa模型)中,预训练模型输出得到分别与待分类文本、正类目标签和多个负类目标签相对应的向量,对于同一个待分类文本,正类目标签和多个负类目标签的比例为1:(3~10);
(2)分别计算待分类文本与正类目标签之间的欧式距离d(a,p)、待分类文本与负类目标签之间的欧式距离d(a,n)和正类目标签与负类目标签之间的欧式距离d(p,n);
(3)根据步骤(2)的欧式距离,得到基于度量学习的三元组损失函数:
其中,margin是一个超参数,要求待分类文本到负类目标签的欧式距离d(a,n)与待分类文本到正类目标签的欧式距离d(a,p)之差大于该超参数,margin的取值为1;
(4)利用步骤(3)的损失函数Triangle Triplet Loss,对步骤(1)的预训练模型进行微调,使损失函数Triangle Triplet Loss最小化,得到一个微调后的预训练模型,记为基本模型;
(5)构建一个由待分类文本与相应的所有类目标签组成的基本模型测试集,将基本模型测试集中的待分类文本与相应的所有类目标签输入到步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签,实现文本分类。
3.如权利要求1所述的文本分类方法,其特征在于还包括以下步骤:
(6)从步骤(5)的排序后的欧式距离中,取出前m个类目标签,作为步骤(5)的测试集中待分类文本的类目标签候选集,记为二级模型测试集,其中m为3-10;
(7)将步骤(1)的基本模型训练集中的待分类文本、正样本与负样本输入步骤(4)的基本模型中,得到一个文本向量和多个相应的类目标签向量,分别计算文本向量与多个类目标签向量之间的欧式距离,对欧式距离进行从小到大的排序,排序后的欧式距离中,与文本向量的欧式距离最短的类目标签向量为相应的待分类文本的类目标签;
(8)从步骤(7)的排序后的欧式距离中,取出前n个类目标签,作为步骤(7)的基本模型训练集中待分类文本的类目标签候选集,其中n为3-10,将训练文本的负样本固定在前n名当中,记为二级模型训练集,对二级模型训练集中的类目标签进行判断,将前n名中的非正样本类目标签记为负样本类目标签;
(9)利用步骤(6)的二级模型测试集和步骤(8)的二级模型训练集,重复步骤(1)-步骤(4),对步骤(4)的基本模型进行微调,得到二级模型;重复步骤(5),实现最终文本分类。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110520242.3A CN113344031B (zh) | 2021-05-13 | 2021-05-13 | 一种文本分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110520242.3A CN113344031B (zh) | 2021-05-13 | 2021-05-13 | 一种文本分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113344031A true CN113344031A (zh) | 2021-09-03 |
CN113344031B CN113344031B (zh) | 2022-12-27 |
Family
ID=77468444
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110520242.3A Active CN113344031B (zh) | 2021-05-13 | 2021-05-13 | 一种文本分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113344031B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113849653A (zh) * | 2021-10-14 | 2021-12-28 | 鼎富智能科技有限公司 | 一种文本分类方法及装置 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109858552A (zh) * | 2019-01-31 | 2019-06-07 | 深兰科技(上海)有限公司 | 一种用于细粒度分类的目标检测方法及设备 |
CN109948160A (zh) * | 2019-03-15 | 2019-06-28 | 智者四海(北京)技术有限公司 | 短文本分类方法及装置 |
WO2019128367A1 (zh) * | 2017-12-26 | 2019-07-04 | 广州广电运通金融电子股份有限公司 | 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质 |
CN111858843A (zh) * | 2019-04-30 | 2020-10-30 | 北京嘀嘀无限科技发展有限公司 | 一种文本分类方法及装置 |
CN112749268A (zh) * | 2021-01-30 | 2021-05-04 | 云知声智能科技股份有限公司 | 基于混合策略的faq系统排序方法、装置及系统 |
-
2021
- 2021-05-13 CN CN202110520242.3A patent/CN113344031B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019128367A1 (zh) * | 2017-12-26 | 2019-07-04 | 广州广电运通金融电子股份有限公司 | 基于Triplet Loss的人脸认证方法、装置、计算机设备和存储介质 |
CN109858552A (zh) * | 2019-01-31 | 2019-06-07 | 深兰科技(上海)有限公司 | 一种用于细粒度分类的目标检测方法及设备 |
CN109948160A (zh) * | 2019-03-15 | 2019-06-28 | 智者四海(北京)技术有限公司 | 短文本分类方法及装置 |
CN111858843A (zh) * | 2019-04-30 | 2020-10-30 | 北京嘀嘀无限科技发展有限公司 | 一种文本分类方法及装置 |
CN112749268A (zh) * | 2021-01-30 | 2021-05-04 | 云知声智能科技股份有限公司 | 基于混合策略的faq系统排序方法、装置及系统 |
Non-Patent Citations (2)
Title |
---|
WEIQING MIN: "A Two-Stage Triplet Network Training Framework for Image Retrieval", 《 IEEE TRANSACTIONS ON MULTIMEDIA》 * |
朱建林等: "基于全路径相似度的大规模层次分类算法", 《计算机工程与设计》 * |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113849653A (zh) * | 2021-10-14 | 2021-12-28 | 鼎富智能科技有限公司 | 一种文本分类方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN113344031B (zh) | 2022-12-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN105260356B (zh) | 基于多任务学习的中文交互文本情感与话题识别方法 | |
CN111368920B (zh) | 基于量子孪生神经网络的二分类方法及其人脸识别方法 | |
CN103605990B (zh) | 基于图聚类标签传播的集成多分类器融合分类方法和系统 | |
CN104750844A (zh) | 基于tf-igm的文本特征向量生成方法和装置及文本分类方法和装置 | |
CN106651057A (zh) | 一种基于安装包序列表的移动端用户年龄预测方法 | |
CN108959474B (zh) | 实体关系提取方法 | |
CN114998602B (zh) | 基于低置信度样本对比损失的域适应学习方法及系统 | |
CN112417132B (zh) | 一种利用谓宾信息筛选负样本的新意图识别方法 | |
CN113344031B (zh) | 一种文本分类方法 | |
CN108681532A (zh) | 一种面向中文微博的情感分析方法 | |
CN111144462A (zh) | 一种雷达信号的未知个体识别方法及装置 | |
CN108229565B (zh) | 一种基于认知的图像理解方法 | |
CN117516937A (zh) | 基于多模态特征融合增强的滚动轴承未知故障检测方法 | |
CN112836731A (zh) | 基于决策树准确率和相关性度量的信号随机森林分类方法、系统及装置 | |
CN111242131B (zh) | 一种智能阅卷中图像识别的方法、存储介质及装置 | |
CN106202045B (zh) | 基于车联网的专项语音识别方法 | |
CN111984790A (zh) | 一种实体关系抽取方法 | |
CN116153299A (zh) | 训练样本的处理方法、语音质检方法及装置 | |
CN106057196A (zh) | 车载语音数据解析识别方法 | |
CN103207893B (zh) | 基于向量组映射的两类文本的分类方法 | |
CN111783788B (zh) | 一种面向标记噪声的多标记分类方法 | |
CN114067165A (zh) | 一种含噪声标记分布的图像筛选和学习方法与装置 | |
CN110162629B (zh) | 一种基于多基模型框架的文本分类方法 | |
CN114359568A (zh) | 一种基于多粒度特征的多标签场景图生成方法 | |
CN114817537A (zh) | 一种基于政策文件数据的分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |