CN116720498A - 一种文本相似度检测模型的训练方法、装置及其相关介质 - Google Patents
一种文本相似度检测模型的训练方法、装置及其相关介质 Download PDFInfo
- Publication number
- CN116720498A CN116720498A CN202310804728.9A CN202310804728A CN116720498A CN 116720498 A CN116720498 A CN 116720498A CN 202310804728 A CN202310804728 A CN 202310804728A CN 116720498 A CN116720498 A CN 116720498A
- Authority
- CN
- China
- Prior art keywords
- text
- domain
- loss
- feature vector
- detection model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 81
- 238000012549 training Methods 0.000 title claims abstract description 62
- 238000000034 method Methods 0.000 title claims abstract description 50
- 239000013598 vector Substances 0.000 claims abstract description 122
- 238000004364 calculation method Methods 0.000 claims abstract description 17
- 230000006870 function Effects 0.000 claims description 34
- 238000005070 sampling Methods 0.000 claims description 24
- 238000002372 labelling Methods 0.000 claims description 20
- 238000004590 computer program Methods 0.000 claims description 10
- 238000012216 screening Methods 0.000 claims description 5
- 230000014509 gene expression Effects 0.000 claims description 4
- 230000000052 comparative effect Effects 0.000 claims description 2
- 230000000694 effects Effects 0.000 abstract description 7
- 230000008569 process Effects 0.000 description 12
- 238000012423 maintenance Methods 0.000 description 7
- 230000006978 adaptation Effects 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 238000013519 translation Methods 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000009826 distribution Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 239000000463 material Substances 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 2
- 238000011156 evaluation Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 239000012634 fragment Substances 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000011524 similarity measure Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/10—Text processing
- G06F40/194—Calculation of difference between files
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Abstract
本发明公开了一种文本相似度检测模型的训练方法、装置及其相关介质,该方法包括获取公开文本数据集和私有文本数据集,并进行文本相似度标注和混合,得到混合数据集;利用预训练的文本相似度检测模型对所述混合数据集进行文本推理,得到文本特征向量和域标签;根据所述文本特征向量和域标签计算分别得到对比损失结果、域差异损失结果和域分类损失结果;并相加后进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。本发明通过对计算得到的对比损失结果、域差异损失结果和域分类损失结果进行相加和反向传播,得到最终的文本相似度检测模型,如此,提高了文本之间的区分度,优化了文本相似度检测效果。
Description
技术领域
本发明涉及自然语言处理技术领域,特别涉及一种文本相似度检测模型的训练方法、装置及其相关介质。
背景技术
目前,关于文本相似度检测技术大多数都使用预训练语言模型,将预训练语言模型的推理输出作为句向量,或者针对预训练语言模型,在有标注的语料上进行微调;但是,在单一领域上的语料微调得到的模型在另外一领域上的表现通常较差。针对上述存在的问题,现有技术中已经有解决方法,例如,现有技术(CN113672718A)中实现域自适应的方法是使用LDA变换,能够将原始句向量投影到另一个空间,使得在这个空间内域间差异最大,域内差异最小,实现了域自适应转换;但是上述现有技术依旧存在问题,如直接使用模型提取出的句向量中,高频词会占更大的权重,这将导致文本之间区分度较差;微调训练的语料和实际使用的语料之间存在较大的领域差异,这将导致跨领域的文本相似度检测效果变差。
发明内容
本发明实施例提供了一种文本相似度检测模型的训练方法、装置及其相关介质,旨在解决现有技术中相似文本之间区分度较差,导致检测效果较差的问题。
第一方面,本发明实施例提供了一种文本相似度检测模型的训练方法,包括:
获取预训练的文本相似度检测模型;
获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
根据所述域标签计算域差异损失,得到域差异损失结果;
根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
第二方面,本发明实施例提供了一种文本相似度检测模型的训练装置,包括:
第一获取单元,用于获取预训练的文本相似度检测模型;
第二获取单元,用于获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
第三获取单元,用于获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
数据混合单元,用于将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
模型推理单元,用于利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
第一损失单元,用于根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
第二损失单元,用于根据所述域标签计算域差异损失,得到域差异损失结果;
第三损失单元,用于根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
模型输出单元,用于将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
第三方面,本发明实施例提供了一种计算机设备,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现所述第一方面的文本相似度检测模型的训练方法。
第四方面,本发明实施例提供了一种计算机可读存储介质,其中,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现所述第一方面的文本相似度检测模型的训练方法。
本发明实施例提供一种文本相似度检测模型的训练方法,包括获取公开文本数据集和私有文本数据集,并进行文本相似度标注和混合,得到混合数据集;利用预训练的文本相似度检测模型对所述混合数据集进行文本推理,得到文本特征向量和域标签;根据所述文本特征向量和域标签计算分别得到对比损失结果、域差异损失结果和域分类损失结果;并相加后进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。本发明通过对计算得到的对比损失结果、域差异损失结果和域分类损失结果进行相加和反向传播,得到最终的文本相似度检测模型,如此,提高了文本之间的区分度,优化了文本相似度检测效果。
本发明实施例还提供一种文本相似度检测模型的训练装置、计算机设备和存储介质,同样具有上述有益效果。
附图说明
为了更清楚地说明本发明实施例技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种文本相似度检测模型的训练方法的流程示意图;
图2为本发明实施例提供的一种文本相似度检测模型的训练方法的另一流程示意图;
图3为本发明实施例提供的一种文本相似度检测模型的训练装置的示意性框图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在此本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
下面请参见图1,图1为本发明实施例提供的一种文本相似度检测模型的训练方法的流程示意图,具体包括:步骤S101~S109。
S101、获取预训练的文本相似度检测模型;
S102、获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
S103、获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
S104、将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
S105、利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
S106、根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
S107、根据所述域标签计算域差异损失,得到域差异损失结果;
S108、根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
S109、将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
结合图2所示,在步骤S101中,获取预训练的文本相似度检测模型,例如Bert、roBERTa等,这些文本相似度检测模型是在通用的文本数据上进行训练的,具有很强的语义理解能力。以Bert为例,Bert是一种广泛应用的预训练语言模型,训练过程主要包括两个阶段:首先,收集大量的公开语料库,例如维基百科、书籍文本等,以作为训练数据,这些训练数据具有广泛的领域覆盖,可以帮助语言模型学习通用的语言知识;接下来,在基于Transformer结构的模型中,Bert采用了一种称为"Masked Language Model"(MLM)的训练目标,在这个训练目标中,Bert会随机屏蔽输入文本的一部分字句,然后通过语言模型预测这些被屏蔽的字句,这个过程可以帮助Bert学习上下文信息和词汇语义。另外,Bert还采用了"Next Sentence Prediction"(NSP)任务进行训练,在NSP任务中,Bert会接收两个句子作为输入,并判断它们是否是连续的,NSP任务有助于Bert学习句子之间的关联和语义连贯性;通过预训练过程,Bert可以学习到丰富的语义信息,并生成对应输入文本的句向量表示,句向量可以用于文本相似度检测任务,通过计算向量之间的相似度来判断文本的相似程度。在实际应用中,开发者可以使用预训练好的Bert模型,或者进行进一步的微调和训练,以适应具体的文本相似度检测任务和领域需求,如此,预训练的语言模型为跨领域的文本相似度检测提供了一种有效的方法,并且在自然语言处理领域取得了广泛的应用。
在步骤S102中,需要获取公开的文本数据集并对其进行文本相似度标注,以获得第一文本数据,在实践中,很多公开的文本数据集可供使用,其中一些专注于文本相似度任务的公开数据集如下示例:
STS(Semantic Textual Similarity)数据集,这是一个广泛使用的文本相似度数据集系列,包括STS-B(基于句子对比较的二元分类任务)和STS-Multi(基于句子对比较的回归任务)等,这些数据集包含了从不同领域和来源收集的句子对,并给出了相似度评分;
Quora Question Pairs数据集,包含了从Quora问答社区收集的问题对,其中每个问题对都被标注为相似或不相似,主要用于问题匹配和文本相似度任务;
SICK(Sentences Involving Compositional Knowledge)数据集,包含了从多个来源收集的句子对,用于评估文本相似度,每个句子对都被标注为相似或不相似;
Microsoft Research Paraphrase Corpus数据集,包含了从互联网收集的句子对,其中每个句子对都被标注为是否是意义上的同义句,主要用于文本相似度和句子重述任务。
在选择适合需求的数据集后,可以进行文本相似度标注,标注的过程需要人工判断两个文本是否相似,并为其分配相应的标签,例如相似或不相似;通常情况下,除了标注为相似的文本对,还需要选择一些互为负样本的文本对,以提供对比和训练的数据;完成文本相似度标注后,可以得到第一文本数据,即包含句子对和相应标签的数据集,可以用于训练文本相似度检测模型,提高模型在跨领域上的效果,并用于后续的模型评估和改进。
在步骤S103中,需要获取私有的物业设备设施维修领域的无标注文本数据,并对其进行文本相似度标注,以获得第二文本数据;获取私有文本数据集可以涉及到物业设备设施维修领域的相关文档、报告、工单等,可以是企业内部的数据或特定机构的数据,具有特定领域的知识和信息;接下来,使用翻译程序将中文语料(私有文本数据集)翻译成外文再回译成中文,这种方法通常称为"回译增强"(Back Translation)技术,可以扩充数据集并增加样本的多样性,具体步骤为:将私有的中文语料进行机器翻译,将其翻译成外文(例如英文);将翻译后的外文语料再次进行机器翻译,将其回译成中文。
通过这个过程,可以得到经过翻译和回译的中文语料,这些中文语料可以被认为是与原始中文语料相似但具有一定变化的句子;在进行文本相似度标注时,将翻译和回译后的中文语料与原始的中文语料进行配对,将配对的句对标记为正样本,除了翻译和回译的句对,还可以选择一些互为负样本的句对,这些句对可以是随机选择的其他不相似的句子对,如此,可以提供对比和训练的数据。完成文本相似度标注后,可以得到第二文本数据,即包含经过翻译和回译处理的句子对和相应标签的数据集,可以用于训练针对物业设备设施维修领域的文本相似度检测模型,提高模型在该特定领域上的效果,并用于后续的模型评估和改进。
在步骤S104中,将第一文本数据和第二文本数据进行语料混合,获得混合数据集,以提高语言模型的泛化能力。混合数据集的构建可以按照以下步骤进行:将第一文本数据和第二文本数据合并为一个混合数据集,对合并后的数据集进行随机洗牌,以打乱数据的顺序;根据需求对混合数据集进行划分,例如划分为训练集、验证集和测试集;混合数据集即包含了第一文本数据和第二文本数据的样本,每个样本都由句子对和相应的标签组成。混合数据集的目的是提供更广泛的数据样本,增加模型的训练样本多样性,有助于语言模型更好的捕捉不同领域、语境和句子表达方式的文本相似性。
在步骤S105中,可以利用预训练的文本相似度检测模型对所述混合数据集进行文本推理,以获取第一文本的特征向量和第二文本的特征向量,并且每个特征向量都对应的域标签;使用预训练的文本相似度检测模型(如Bert)来对混合数据集中的句子对进行推理,可以将输入的句子对转化为相应的特征向量;如对于句子对(句子i和句子j),可以使用该预训练模型计算其特征向量,这些特征向量通常为模型的最后一层输出结果或其他层的输出结果,具体取决于实际应用需求;经过模型推理后,可以获得第一文本特征向量(记为xi)和第二文本特征向量(记为xj),特征向量可以用来表示相应句子在模型中的语义表征。
此外,每个特征向量都对应一个域标签,域标签表示特征向量所属的域或来源,对于混合数据集中的句子对,可以根据句子所属的原始文本数据集来标记域标签;例如,对于第一文本数据和第二文本数据,可以分别将其特征向量标记为对应的域标签(如Di和Dj);通过将域标签与特征向量关联起来,可以在模型的训练和判别过程中进行利用,即域标签可以作为附加特征,帮助模型更好的理解不同来源的句子对,并更好的捕捉不同领域和语境的文本相似性。
在一实施例中,所述步骤S105,包括:
判断所述第一文本特征向量和第二文本特征向量是否语义相似,并标记句对标签;若是,则将所述句对标签标记为1;若否,则将所述句对标签标记为0。
在本实施例中,需要判断第一文本特征向量和第二文本特征向量是否语义相似,并为句子对标签进行标记,如果语义相似,将句子对标签标记为1;如果不相似,将句子对标签标记为0;判断特征向量的语义相似性具体步骤如下:计算第一文本特征向量和第二文本特征向量之间的相似性分数,可以使用余弦相似度作为相似性度量,计算第一文本特征向量和第二文本特征向量之间的相似度分值;根据相似性分数判断语义相似性,设置一个阈值来判断相似性分数的界限,如果相似性分数高于阈值,则认为两个特征向量语义相似;如果相似性分数低于阈值,则认为两个特征向量语义不相似;根据判断结果,为句子对标签进行标记,如果特征向量语义相似,则将句子对标签标记为1;如果特征向量语义不相似,则将句子对标签标记为0。
在步骤S106中,对比损失是一种常用的损失函数,用于训练文本相似度模型,对比损失衡量了两个特征向量之间的距离或差异,并根据相似性标签(0或1)对模型进行优化。首先,定义对比损失函数,常见的对比损失函数包括余弦对比损失(Cosine ContrastiveLoss)和三元组对比损失(Triplet Contrastive Loss)等,这些损失函数根据特征向量之间的相似性标签来计算损失值;然后使用定义好的对比损失函数,将第一文本特征向量和第二文本特征向量作为输入,计算对比损失的结果,具体的计算方式取决于所选择的对比损失函数;最后根据计算得到的对比损失结果,进行模型的优化,优化的目标是通过调整模型的参数,使得对比损失最小化或最优化。
在一实施例中,所述步骤S106,包括:
根据所述第一文本特征向量和第二文本特征向量计算得到句对之间的余弦相似度;利用所述余弦相似度分别计算得到句对之间的相似度差异值和标签差异值;根据所述标签差异值筛选得到负样本,并基于所述负样本和所述相似度差异值利用指数函数计算负样本采样概率;根据所述负样本采样概率对所述相似度差异值进行采样,得到采样结果;根据所述采样结果进行对比损失计算,得到所述对比损失结果。
在本实施例中,使用第一文本特征向量和第二文本特征向量,通过计算它们的余弦相似度来衡量句对之间的相似度,余弦相似度的取值范围在[-1,1]之间,值越接近1表示句对越相似,值越接近-1表示句对越不相似,具体的余弦相似度计算公式如下:
其中,si,j表示句对i,j的余弦相似度;xi表示所述第一文本特征向量;xj表示所述第二文本特征向量。
利用余弦相似度来计算句对之间的相似度差异值,相似度差异值可通过取差的绝对值来表示,同时,计算句对标签之间的差异值,用于表示预测的相似度差异;通常情况下,相似度差异值为预测相似度减去人工标注的相似度(公式表达为d(i,j,k,l)=si,j-sk,l),标签差异值为预测标签减去人工标注的标签(公式达为y(i,j,k,l)=Li,j-Lk,l)。根据标签差异值进行筛选,选择标签差异值为-1的样本作为负样本,即预测的相似度小于人工标注的相似度的样本;基于负样本和相似度差异值,利用指数函数计算负样本的采样概率,指数函数可以将相似度差异值映射到[0,1]之间的概率值,较大的差异值对应较小的概率(公式表达为psample=exp(γ·(si,j-sk,l)),其中,γ表示人为设置的超参数);根据负样本采样概率,对相似度差异值进行采样,得到采样结果,采样的数量可以根据需求和超参数进行设定。利用采样结果和相似度差异值,进行对比损失的计算,具体的对比损失函数根据任务和模型的选择而有所不同,可以采用余弦对比损失、三元组对比损失等。对比损失的计算将根据采样结果对正样本和负样本进行比较,并优化模型的参数,使得正样本更接近、负样本更远离。
在一实施例中,所述根据所述采样结果进行对比损失计算,得到所述对比损失结果,包括:
按如下公式计算得到所述对比损失结果:
losscontrast=log(∑exp(si,j-sk,l)+1)
其中,losscontrast表示所述对比损失结果;exp()表示指数函数;si,j表示句对i,j;sk,l表示句对k,l。
在本实施例中,训练时最小化对比损失losscontrast等价于最小化(si,j-sk,l),又等价于最小化si,j的同时最大化sk,l,将目标训练放在一个损失函数中,可以通过一个统一的优化过程来更新模型参数;当训练收敛后,si,j将近似为0而sk,l近似为1,即完成了训练目标(句对i,j不相似而句对k,l相似)。
在步骤S107中,根据第一文本特征向量和第二文本特征向量的域标签,确定文本特征向量来自的域,域标签可以是表示数据来源的标识,比如"公开数据集"和"私有数据集";使用域标签计算域差异损失。域差异损失可以通过比较两个文本的域标签的差异来衡量域之间的差异性。可以使用交叉熵损失函数或其他适当的损失函数来计算域差异损失;根据计算可以得到域差异损失结果,域差异损失结果将反映出文本之间在域上的相似性或差异性。
在一实施例中,所述步骤S107,包括:
根据所述域标签计算得到句对之间的最大均值差异;利用所述最大均值差异进行域差异损失计算,得到所述域差异损失结果。
在本实施例中,针对每个句对提取其对应的域标签,并计算句对之间的MMD距离(即最大均值差异,下同),MMD距离衡量了两个分布之间的差异,可以用于比较句对在域上的差异,MMD距离计算公式如下:
其中,k表示核函数;σ表示人为设置的超参数;m,n表示选用域的样本数量;
需要注意的是,m,n中的域指的是语料来源的邻域,例如金融领域、生物领域或者设备维修领域,不同的领域之间存在着差异,这里的域差异损失两两进行计算,两个域可以指称金融领域和设备维修领域,由于训练是抽取全部语料的一个批次(Batch),因此在这个Batch中可能有一部分语料来自金融领域,有一部分语料来自设备维修邻域;m,n可以指称来自这两个领域的样本数量,由于这里的域损失差异没有方向性,即“金融领域和设备维修领域”的差异和“设备维修领域和金融领域”的差异是相等的,因此m,n可以互相表示来自这两个域的样本数量,这里不做区分。
利用最大均值差异作为域差异损失的一部分,域差异损失可以通过最大均值差异来衡量句对之间的域差异,可以使用适当的损失函数(如均方误差损失或交叉熵损失)将最大均值差异作为损失项加入到模型的训练中;根据域差异损失的计算,得到域差异损失的结果,域差异损失结果将反映出句对之间在域上的差异度。
在一实施例中,所述利用所述最大均值差异进行域差异损失计算,得到所述域差异损失结果,包括:
按如下公式计算得到所述域差异损失结果:
其中,lossshift表示所述域差异损失结果;c表示句对比较次数;xi表示所述第一文本特征向量;xj表示所述第二文本特征向量。
在本实施例中,域差异损失函数是一种用于度量不同域之间差异的损失函数,在训练模型时被广泛应用于领域自适应和域间特征对齐任务中,域差异损失函数的目标是通过最小化不同域之间的差异来增强模型的泛化能力,使得模型能够在目标域上获得更好的性能。本实施例中采用的域差异损失函数为最大均值差异(Maximum Mean Discrepancy,MMD),用于度量两个域之间的分布差异,MMD损失的计算基于核方法,通过在特征空间中比较不同域的样本分布来测量距离。具体的,给定源域和目标域的样本特征表示,MMD损失的计算分为以下几个步骤:计算源域和目标域的样本特征的均值向量;计算源域和目标域样本特征的协方差矩阵,使用核函数计算源域和目标域样本特征之间的距离,通常使用高斯核函数;将距离转化为MMD损失,通过最小化该损失来减小源域和目标域之间的差异。域差异损失函数不仅可以通过MMD进行度量,还可以使用其他方法,如领域对抗神经网络(Domain Adversarial Neural Network,DANN)中的领域分类器,DANN通过训练一个领域分类器来最小化源域和目标域之间的分类误差,从而减小两个域之间的差异。
在步骤S108中,将第一文本特征向量输入到域分类神经网络层中,在域分类神经网络层中,经过线性分类层的计算,得到属于该域的概率值;根据域标签,即第一文本特征向量所属的域类别标签,计算域分类损失。域分类损失的计算可以使用交叉熵损失函数,交叉熵损失函数常用于多分类任务,可以将域标签表示为one-hot编码,然后将one-hot编码与对应域的概率值输入到交叉熵损失函数中,计算损失值。域分类损失的目标是使模型能够准确预测输入文本特征向量所属的域类别,从而实现域分类的目标,通过最小化域分类损失,模型能够更好学习不同域之间的特征差异,从而提高模型在跨域场景下的泛化能力。需要注意的是,域分类损失只是模型训练过程中的一部分,通常与其他损失函数(如对比损失、域差异损失等)结合使用,共同训练模型以实现多个目标。
在一实施例中,所述步骤S108,包括:
按如下公式计算得到所述域分类损失结果:
其中,losscls表示所述域差异损失结果;N表示样本数量;pi表示域分类概率;yD表示域类别标签。
在本实施例中,域分类损失函数用于域分类任务,旨在训练模型使其能够准确预测输入样本所属的域类别,域分类损失函数的目标是最小化预测与真实标签之间的差异,从而使模型能够更准确预测样本的域类别,还可以通过反向传播算法和梯度下降优化算法,可以更新模型的参数,使其逐渐减小域分类损失,提高模型的域分类性能。
在步骤S109中,将对比损失结果、域差异损失结果和域分类损失结果,进行损失值相加得到总损失(公式可以表示为loss=losscontrast+lossshift+losscls);使用总损失进行反向传播,根据反向传播算法更新模型的参数,可以利用梯度下降的方法,沿着损失函数的梯度方向对参数进行调整,以减小总损失;通过多次迭代训练模型,直到达到收敛状态(即模型的性能不再显著改善或损失函数收敛到最小值),最终可以得到训练完成的(即最终的文本相似度检测模型)。在实际应用中即可使用最终的文本相似度检测模型,应用在物业设备设施维修数据库查询模块中,实现根据查询语句,返回与之最相似的文档片段。
综上所述,本发明通过一种交叉余弦对比损失,来优化微调后的句向量的区分度,提高文本相似度模型的检测效果,并结合对抗学习的思想,提出一种新的域自适应训练流程,以克服不同域语料之间的差异,在完成源域的监督学习的同时,也能完成目标域的半监督学习。
结合图3所示,图3为本发明实施例提供的一种文本相似度检测模型的训练装置的示意性框图,文本相似度检测模型的训练装置300包括:
第一获取单元301,用于获取预训练的文本相似度检测模型;
第二获取单元302,用于获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
第三获取单元303,用于获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
数据混合单元304,用于将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
模型推理单元305,用于利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
第一损失单元306,用于根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
第二损失单元307,用于根据所述域标签计算域差异损失,得到域差异损失结果;
第三损失单元308,用于根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
模型输出单元309,用于将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
在本实施例中,第一获取单元301用于获取预训练的文本相似度检测模型;第二获取单元302用于获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;第三获取单元303用于获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;数据混合单元304用于将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;模型推理单元305用于利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;第一损失单元306用于根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;第二损失单元307用于根据所述域标签计算域差异损失,得到域差异损失结果;第三损失单元308用于根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;模型输出单元309用于将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
在一实施例中,所述模型推理单元305,包括:
判断单元,用于判断所述第一文本特征向量和第二文本特征向量是否语义相似,并标记句对标签;若是,则将所述句对标签标记为1;若否,则将所述句对标签标记为0。
在一实施例中,所述第一损失单元306,包括:
余弦单元,用于根据所述第一文本特征向量和第二文本特征向量计算得到句对之间的余弦相似度;
相似单元,用于利用所述余弦相似度分别计算得到句对之间的相似度差异值和标签差异值;
筛选单元,用于根据所述标签差异值筛选得到负样本,并基于所述负样本和所述相似度差异值利用指数函数计算负样本采样概率;
采样单元,用于根据所述负样本采样概率对所述相似度差异值进行采样,得到采样结果;
对比单元,用于根据所述采样结果进行对比损失计算,得到所述对比损失结果。
在一实施例中,所述对比单元,包括:
指数单元,用于按如下公式计算得到所述对比损失结果:
losscontrast=log(∑exp(si,j-sk,l)+1)
其中,losscontrast表示所述对比损失结果;exp()表示指数函数;si,j表示句对i,j;sk,l表示句对k,l。
在一实施例中,所述第二损失单元307,包括:
差异单元,用于根据所述域标签计算得到句对之间的最大均值差异;
均值单元,用于利用所述最大均值差异进行域差异损失计算,得到所述域差异损失结果。
在一实施例中,所述均值单元,包括:
域差单元,用于按如下公式计算得到所述域差异损失结果:
其中,lossshift表示所述域差异损失结果;c表示句对比较次数;xi表示所述第一文本特征向量;xj表示所述第二文本特征向量。
在一实施例中,所述第三损失单元308,包括:
域类单元,用于按如下公式计算得到所述域分类损失结果:
其中,losscls表示所述域差异损失结果;N表示样本数量;pi表示域分类概率;yD表示域类别标签。
由于装置部分的实施例与方法部分的实施例相互对应,因此装置部分的实施例请参见方法部分的实施例的描述,这里暂不赘述。
本发明实施例还提供了一种计算机可读存储介质,其上存有计算机程序,该计算机程序被执行时可以实现上述实施例所提供的步骤。该存储介质可以包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
本发明实施例还提供了一种计算机设备,可以包括存储器和处理器,存储器中存有计算机程序,处理器调用存储器中的计算机程序时,可以实现上述实施例所提供的步骤。当然计算机设备还可以包括各种网络接口,电源等组件。
说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的系统而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。应当指出,对于本技术领域的普通技术人员来说,在不脱离本申请原理的前提下,还可以对本申请进行若干改进和修饰,这些改进和修饰也落入本申请权利要求的保护范围内。
还需要说明的是,在本说明书中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的状况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
Claims (10)
1.一种文本相似度检测模型的训练方法,其特征在于,包括:
获取预训练的文本相似度检测模型;
获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
根据所述域标签计算域差异损失,得到域差异损失结果;
根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
2.根据权利要求1所述的文本相似度检测模型的训练方法,其特征在于,所述利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量,包括:
判断所述第一文本特征向量和第二文本特征向量是否语义相似,并标记句对标签;若是,则将所述句对标签标记为1;若否,则将所述句对标签标记为0。
3.根据权利要求1所述的文本相似度检测模型的训练方法,其特征在于,所述根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果,包括:
根据所述第一文本特征向量和第二文本特征向量计算得到句对之间的余弦相似度;
利用所述余弦相似度分别计算得到句对之间的相似度差异值和标签差异值;
根据所述标签差异值筛选得到负样本,并基于所述负样本和所述相似度差异值利用指数函数计算负样本采样概率;
根据所述负样本采样概率对所述相似度差异值进行采样,得到采样结果;
根据所述采样结果进行对比损失计算,得到所述对比损失结果。
4.根据权利要求3所述的文本相似度检测模型的训练方法,其特征在于,所述根据所述采样结果进行对比损失计算,得到所述对比损失结果,包括:
按如下公式计算得到所述对比损失结果:
losscontrast=log(∑exp(si,j-sk,l)+1)
其中,losscontrast表示所述对比损失结果;exp()表示指数函数;si,j表示句对i,j;sk,l表示句对k,l。
5.根据权利要求1所述的文本相似度检测模型的训练方法,其特征在于,所述根据所述域标签计算域差异损失,得到域差异损失结果,包括:
根据所述域标签计算得到句对之间的最大均值差异;
利用所述最大均值差异进行域差异损失计算,得到所述域差异损失结果。
6.根据权利要求5所述的文本相似度检测模型的训练方法,其特征在于,所述利用所述最大均值差异进行域差异损失计算,得到所述域差异损失结果,包括:
按如下公式计算得到所述域差异损失结果:
其中,lossshift表示所述域差异损失结果;c表示句对比较次数;xi表示所述第一文本特征向量;xj表示所述第二文本特征向量。
7.根据权利要求1所述的文本相似度检测模型的训练方法,其特征在于,所述根据所述第一文本特征向量计算域分类损失,得到域分类损失结果,包括:
按如下公式计算得到所述域分类损失结果:
其中,losscls表示所述域差异损失结果;N表示样本数量;pi表示域分类概率;yD表示域类别标签。
8.一种文本相似度检测模型的训练装置,其特征在于,包括:
第一获取单元,用于获取预训练的文本相似度检测模型;
第二获取单元,用于获取公开文本数据集,并对所述公开文本数据集进行文本相似度标注,得到第一文本数据;
第三获取单元,用于获取私有文本数据集,并对所述私有文本数据集进行文本相似度标注,得到第二文本数据;
数据混合单元,用于将所述第一文本数据和第二文本数据进行语料混合,得到混合数据集;
模型推理单元,用于利用所述预训练的文本相似度检测模型对所述混合数据集进行文本推理,分别得到第一文本特征向量和第二文本特征向量;其中,所述第一文本特征向量和第二文本特征向量均对应有域标签;
第一损失单元,用于根据所述第一文本特征向量和第二文本特征向量计算对比损失,得到对比损失结果;
第二损失单元,用于根据所述域标签计算域差异损失,得到域差异损失结果;
第三损失单元,用于根据所述第一文本特征向量计算域分类损失,得到域分类损失结果;
模型输出单元,用于将所述对比损失结果、域差异损失结果和域分类损失结果进行相加,并进行反向传播所述预训练的文本相似度检测模型的参数,得到最终的文本相似度检测模型。
9.一种计算机设备,其特征在于,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的文本相似度检测模型的训练方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的文本相似度检测模型的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310804728.9A CN116720498A (zh) | 2023-07-03 | 2023-07-03 | 一种文本相似度检测模型的训练方法、装置及其相关介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310804728.9A CN116720498A (zh) | 2023-07-03 | 2023-07-03 | 一种文本相似度检测模型的训练方法、装置及其相关介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116720498A true CN116720498A (zh) | 2023-09-08 |
Family
ID=87869726
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310804728.9A Pending CN116720498A (zh) | 2023-07-03 | 2023-07-03 | 一种文本相似度检测模型的训练方法、装置及其相关介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116720498A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117573815A (zh) * | 2024-01-17 | 2024-02-20 | 之江实验室 | 一种基于向量相似度匹配优化的检索增强生成方法 |
-
2023
- 2023-07-03 CN CN202310804728.9A patent/CN116720498A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117573815A (zh) * | 2024-01-17 | 2024-02-20 | 之江实验室 | 一种基于向量相似度匹配优化的检索增强生成方法 |
CN117573815B (zh) * | 2024-01-17 | 2024-04-30 | 之江实验室 | 一种基于向量相似度匹配优化的检索增强生成方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109992782B (zh) | 法律文书命名实体识别方法、装置及计算机设备 | |
CN110188358B (zh) | 自然语言处理模型的训练方法及装置 | |
CN111738003B (zh) | 命名实体识别模型训练方法、命名实体识别方法和介质 | |
CN114492363B (zh) | 一种小样本微调方法、系统及相关装置 | |
CN113392209B (zh) | 一种基于人工智能的文本聚类方法、相关设备及存储介质 | |
CN112883193A (zh) | 一种文本分类模型的训练方法、装置、设备以及可读介质 | |
CN115309910B (zh) | 语篇要素和要素关系联合抽取方法、知识图谱构建方法 | |
CN116992007B (zh) | 基于问题意图理解的限定问答系统 | |
CN115952292B (zh) | 多标签分类方法、装置及计算机可读介质 | |
CN110852089A (zh) | 基于智能分词与深度学习的运维项目管理方法 | |
Jiang et al. | Impact of OCR quality on BERT embeddings in the domain classification of book excerpts | |
CN114781651A (zh) | 基于对比学习的小样本学习鲁棒性提升方法 | |
CN116720498A (zh) | 一种文本相似度检测模型的训练方法、装置及其相关介质 | |
CN111581365B (zh) | 一种谓词抽取方法 | |
CN117033961A (zh) | 一种上下文语境感知的多模态图文分类方法 | |
CN114841148A (zh) | 文本识别模型训练方法、模型训练装置、电子设备 | |
CN115146021A (zh) | 文本检索匹配模型的训练方法、装置、电子设备及介质 | |
CN111767388B (zh) | 一种候选池生成方法 | |
CN114548117A (zh) | 一种基于bert语义增强的因果关系抽取方法 | |
CN113535928A (zh) | 基于注意力机制下长短期记忆网络的服务发现方法及系统 | |
Kreyssig | Deep learning for user simulation in a dialogue system | |
Yang et al. | Hierarchical dialog state tracking with unknown slot values | |
Üveges | Comprehensibility and Automation: Plain Language in the Era of Digitalization | |
CN111104478A (zh) | 一种领域概念语义漂移探究方法 | |
CN113051886B (zh) | 一种试题查重方法、装置、存储介质及设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |