CN114817307A - 一种基于半监督学习和元学习的少样本nl2sql方法 - Google Patents

一种基于半监督学习和元学习的少样本nl2sql方法 Download PDF

Info

Publication number
CN114817307A
CN114817307A CN202210147772.2A CN202210147772A CN114817307A CN 114817307 A CN114817307 A CN 114817307A CN 202210147772 A CN202210147772 A CN 202210147772A CN 114817307 A CN114817307 A CN 114817307A
Authority
CN
China
Prior art keywords
learning
model
column
training
sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210147772.2A
Other languages
English (en)
Other versions
CN114817307B (zh
Inventor
郭心南
陈永锐
漆桂林
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Southeast University
Original Assignee
Southeast University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Southeast University filed Critical Southeast University
Priority to CN202210147772.2A priority Critical patent/CN114817307B/zh
Publication of CN114817307A publication Critical patent/CN114817307A/zh
Application granted granted Critical
Publication of CN114817307B publication Critical patent/CN114817307B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/20Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
    • G06F16/24Querying
    • G06F16/242Query formulation
    • G06F16/2433Query languages
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/20Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
    • G06F16/24Querying
    • G06F16/245Query processing
    • G06F16/2452Query translation
    • G06F16/24522Translation of natural language queries to structured queries
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Databases & Information Systems (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Machine Translation (AREA)

Abstract

本专利公开了一种基于半监督学习和元学习的少样本NL2SQL方法。本方法能在仅拥有少量标注数据的场景下,通过自训练框架的辅助对模型进行迭代训练,在这过程中逐步优化模型以及伪标签。首先对基础模型利用已有的少量标注数据进行热启动训练后,将其用于大量无标注数据的伪标签以及置信度预测,并使其与标签数据结合使对模型进行半监督学习。在半监督学习的过程中,同时引入元学习算法,它会在训练过程中进行任务采样,利用其特有任务训练机制来提升模型的快速学习以及迁移学习能力。最终得到的NL2SQL模型具有接近使用大量标注数据在有监督条件下训练的模型的准确率,同时针对新数据具有强大的少样本快速学习与微调的能力。

Description

一种基于半监督学习和元学习的少样本NL2SQL方法
技术领域
本发明涉及一种基于半监督学习和元学习的文本转结构化查询语言(NL2SQL, NaturalLanguagetoStructureQueryLanguage)的方法,属于信息处理技术领域。
背景技术
随着互联网的发展,海量数据以爆炸式的速度产生与增长。数据库则成为了人们数据存储的常用工具。目前各行各业,无论是医疗、金融、化工还是电力等,都会产生很多业务数据以及知识数据,而这些数据都会被选择存储在数据库中;同时在软件与平台开发过程中,数据库也成为了首选的后端数据存储容器。数据库的中存储的数据无论是修改还是查询,都需要特定的查询语言,那就是SQL。但SQL语言在具有灵活的查询功能的同时,其语法本身也较为复杂和难懂,只能由具备一定专业知识的人来进行操作,而对于普通用户来说,他们难以直接使用SQL来查询数据库。
NL2SQL是为了解决查询问题而诞生的技术,它的核心目标是将描述查询的自然语言转化为SQL语句。这项技术可以允许普通用户使用非常口语化的自然语言来对数据库进行查询并直接得到答案。这种检索技术与返回大量相关网页或者内容的传统搜索引擎相比更具有准确性和高效性,也因此,目前它已经被用于很多问答的领域,如智能客服,智能助手等应用。
NL2SQL本身是一个复杂的任务,与传统的序列生成任务不同,SQL语句的生成需要遵循严格的语法规则,同时要根据自然语句识别出查询目标,聚合函数,限定条件等。目前基于深度学习的NL2SQL方法虽然能取得较高的准确率,但是这些方法需要大量的标注数据作为支持。由于标注本身需要自然语言,表格与SQL互相对应,导致难以从互联网的海量数据中自动获取,且由于SQL的复杂语法导致人工标注成本很大,因此缺乏带标注的监督数据来训练模型是目前阻碍NL2SQL从学术界到工业界发展的一个重要挑战。半监督学习可以使用大量无标注数据结合已有的少量标注数据来对模型进行训练,以解决缺乏监督数据的问题。同时元学习可以通过其特有的训练机制来提升模型的迁移学习能力,使其能仅通过少量样本学习新的任务。这启发了我们将半监督学习和元学习的技术用于 NL2SQL,以解决标注数据量过少的问题。
发明内容
本发明正是针对现有NL2SQL技术中存在的技术问题,提供了一种结合半监督学习和元学习的NL2SQL方法,通过半监督学习引入大量无标注数据配合少量标注数据进行训练,弥补监督数据不足的情况,同时利用元学习的任务学习机制来提升模型的迁移学习能力,使模型能够通过少量样本快速学习新任务。
为了实现上述目的,本发明的技术方案如下:
步骤1)构建NL2SQL模型,并进行参数初始化;
步骤2)利用标注数据进行热启动训练,直至准确率达到阈值;
步骤3)利用自学习框架对NL2SQL模型参数进行多轮训练和更新,直至模型参数收敛,其中每一轮自训练过程包括:
步骤3.1)使用模型为大量无标签数据预测伪标签以及置信度,作为伪标签数据集。
步骤3.2)从标签数据集和伪标签数据集的混合数据中采样任务集合,并使用基于列特异性的元学习算法对模型进行训练和参数更新。
步骤3.3)对伪标签数据按照一定比例进行随机采样,得到的采样数据与标签数据混合,采用批训练策略将混合数据分为等大的batch依次对模型进行训练和参数更新。
其中,步骤1)中的NL2SQL模型具体来说是一个细粒度输入的多任务模型。它包含编码器和多子任务解码器两部分。在编码器端,对于一个问题和其查询的表格,采用多组(列,问题)粒度的输入方式。编码器采用了预训练语言模型RoBERTa。输入格式为:(列类型,列名,多个当前列的值,问题),其中列类型包括了日期类型(Date),文本类型(Text)和数值类型(Number),多个当前列的值是从列在表格中的所有值中筛选出和问题的文本相似度打分最高的前k个值。以上述格式进行分词后输入到RoBERTa,它会输出编码后的矩阵表示。而在多子任务解码器端,SQL的生成被划分六个子任务,分别是:SELECT部分的选列预测(SC),SELECT部分的聚合函数预测(SA),WHERE部分条件个数的预测 (WN),WHERE部分每个条件的列(WC),操作符(WO)和值(WV)。其中子任务的个数和定义可以根据SQL所需覆盖的语法进行增删。其中聚合函数包含{NONE,MAX,MIN, SUM,COUNT,AVERAGE},操作符包含{>,<,=}。在此模型的基础上进行参数的随机初始化。这里采用这样的设计的原因为:(1)目前在NL2SQL的端到端方法中,列粒度输入的模型在纯模型对比中相对于传统的问句粒度输入的模型在效果上有非常明显的提升。(2) 列粒度输入模型更加方便使用本专利设计的基于列特异性的元学习算法来进行参数更新。
步骤2)用于热启动模型,由于自训练框架需要模型在预测伪标签时具备一定程度的可靠性,因此在进入自学习阶段前先使用已有的数据量较少的标注数据,采用批训练策略进行多轮训练。每一轮中将全量标注数据分为等大的batch依次对模型进行训练和参数更新。每一轮结束会使用验证集对模型进行效果评估,如果模型表现达到事先设定的阈值,则可以进入下一个步骤,否则重复进行下一轮热启动训练。这里设计热启动是为了让模型在后续的自训练中具有较好的初始参数,由于每一轮自训练需要模型对无标注数据进行伪标签预测,如果直接使用随机初始化的模型,则伪标签的可行度极低,包含的噪音极高。因此这里本专利首先利用标注数据对模型进行热启动训练,使其具备一定的标签预测能力,然后再进入自训练过程,以保证伪标签预测具有一定的可靠性。
步骤3)是框架的核心部分,即包含了半监督学习和元学习的自训练阶段。自训练的过程中模型和数据会进行迭代,其中每一轮自训练的内容如下:
步骤3.1)在标注数据以外,一般来说实际场景会存在大量没有标注的数据,它们可以被半监督学习使用以提升模型的效果。使用当前最优的模型参数对这些无标注数据进行SQL 预测,使这些数据成为伪标签数据。同时在预测过程中,综合每个样本输入到NL2SQL模型中各个子任务输出时的概率计算出每个样本的置信度。当一个伪标签样本用于后续的模型更新时,它所计算出的梯度需要乘上置信度。已有的标签数据的置信度均设置为1.0。由于伪标签中包含了一部分预测正确的标签以及预测错误的标签,而预测错误的标签对于模型来说会成为噪音。因此我们在这里根据模型预测时的概率计算得到每个伪标签的置信度,从而用于调整不同伪标签的更新权重。
步骤3.2)首先将全量的标签数据和伪标签数据混合为一个采样池。同时每个原样本按照(问题,列)的组合拆分为多个子样本。由于选列是子任务中最为基础和重要的任务,并且选列任务也特别容易受到表格变换的影响,因此这里对模型选列任务进行元学习训练。在当前任务中,输入问题和列,模型需要根据SC和WC两个任务的输出概率预测该列是否应该在问题对应的SQL的SELECT部分中,还是在WHERE部分中,还是与问题无关。
对于每个样本,这里还会计算一个列特异性的分数。由于一些列比较泛用,它可以出现在不同主题的表中,而有些列比较少见,可能仅出现在特定内容的一两张表上。因此定义泛用的列特异性较低,而少见的列特异性较高。由于模型期望在训练中能学到更多普适的列表示,因此特异性高的列得分会相对较低,反正特异性低的表得分会较高,这个分数同样也会用作样本计算损失和梯度时的权重。
与前面的batch形式不同,元学习以任务(Task)的形式来训练模型。每个任务会从采样池中按照元学习的标准进行随机采样,采样获得的大量任务组成一个任务集合。其中每个任务包含支持集(SupportSet)和验证集(Query Set)。在元学习中,模型首先在支持集上以一种学习率进行损失计算和参数更新,此后在验证集上获取当前参数的损失,将支持集和验证集上的损失按一定比例进行加权后以另一种学习率更新参数。在此过程中,每个样本计算的损失和梯度需要乘上其本身的置信度分数和列特异性分数。
当前步骤设计的动机是:(1)可以通过元学习的训练机制强化模型对于SC和WC这两个表格内容敏感的选列任务的效果,从而提升模型的快速学习和迁移学习的能力。(2)列特异性的加入可以保证模型学习到更多通用列的表示,同时可以规避特殊列的伪标签所可能包含的噪音。
步骤3.3)是对模型在生成SQL上的半监督训练。虽然元学习有助于提升模型的迁移学习能力和在少样本下的表现,但是其训练机制同时带来的问题是模型的参数会较为不稳定,并且元学习仅训练了模型的选列任务。而当前步骤则是以(问题,表格)为粒度训练模型的全部子任务进行SQL预测。由于场景中伪标签数据的数量可能是标签数据的几倍甚至几十倍,首先需要对伪标签数据针对实际的两者间的比例按照一定百分比进行随机采样。采样得到的伪标签数据与标签数据进行混合,作为半监督学习的训练集。这里同样采用批训练策略,以一定的学习率将batch依次输入模型计算损失并更新参数。这里每个样本的损失和梯度计算需要乘上样本自身的置信度。
这一步骤设计的主要目的,一方面是前面的元学习步骤仅训练了SC和WC相关联的部分参数,而当前步骤训练的是模型的整体参数;另一方面是元学习的Task训练会带来参数的不稳定性,后续进行半监督数据的批训练可以更好的引入伪标签数据以及稳定参数更新。
重复步骤3.1)到步骤3.3)之间的操作。在模型的参数更新收敛完成,在验证集上的效果不再提升时,整个自训练框架的模型训练结束。
相对于现有技术,本专利的优点如下:
1.现有的技术仅考虑基于端到端模型在标注数据上进行全监督训练,但实际场景中标注数据比较少,特别是NL2SQL任务的标注成本极高。在实际应用场景中,无标注数据非常易于获得,它可以同源的无标注业务数据,也可以是不同源的无标注外部数据。本专利引入了半监督学习来引入无标注数据,用其结合已有的标注数据共同优化模型,解决少样本的问题。
2.现有的技术仅使用传统的批训练方法,而本专利引入了元学习,针对SC和WC两个表格内容敏感的子任务以Task的形式进行训练,以提升模型的快速学习和迁移学习的能力,从而更加适配少样本环境。
3.本专利设计了一种结合表格内容的列粒度输入模型,相较于现有的端到端模型来说减少了池化操作,同时表格内容也会丰富表格的向量表示。
为了验证本专利的有效性,我们在中文电力领域的NL2SQL数据集ESQL以及英文开放域百科NL2SQL数据集WikiSQL上进行了少样本设置下的实验。
表1
方法/逻辑准确率/数据集 ESQL WikiSQL
SQLova(具有代表性的现有方面) 22.3% 23.3%
HydraNet(具有代表性的现有方面) 43.6% 64.2%
基础模型 45.3% 69.6%
基础模型+半监督学习 51.2% 75.8%
基础模型+半监督学习+元学习 55.3% 78.4%
由实验结果可以看出,对比一些现有的代表性方法,本专利的基础模型在端到端层面相较现有方法就具有一定的提升,同时本专利提出的半监督学习和元学习两组消融实验中也分别取得了较为显著的提升,由此可以验证本专利方案的有效性。
附图说明
图1是本发明的总体训练流程示意图;
图2是本发明的列粒度NL2SQL模型结构图。
具体实施方式:
以下结合实施例和说明书附图,详细说明本发明的实施过程。
实施例1:参见图1、图2,一种基于半监督学习和元学习的少样本NL2SQL方法,总体训练流程示意图如图1所示,其中主要包含三个步骤:
步骤1):随机初始化NL2SQL模型的参数θ
步骤2):将模型基于已有的少量标注数据L采用批(batch)训练策略进行热启动训练,将训练集分为等大的batch依次进行训练,每个训练步(step)利用一个batch的标注数据对NL2SQL模型进行监督。每一轮(epoch)基于L进行训练后,模型会在验证集Dv上评估当前的参数θ的准确率。如果准确率低于阈值λ,则继续进行下一轮热启动训练,如果高于λ,则保存当前参数,记为θw,并进入到步骤3)。
步骤3):对模型进行多轮的,结合半监督学习和元学习的训练,其中每一轮:
步骤3-1):基于当前模型参数θw对大量的未标注数据U进行伪SQL标签的预测,由此得到伪标签数据集UP,同时每个标签都会在预测时基于模型在多个子任务上的输出概率进行置信度的打分,具体的计算方式如下:
Figure RE-GDA0003655821630000051
其中,z属于子任务的集合,τ作为开平方根的超参数,ζ作为最低的阈值。所有标注数据的置信度设置为1.0。这个置信度会作为后续伪标签样本计算梯度时的权重。
步骤3-2):将标注数据集L和伪标签数据集UP结合,并将每个SQL预测样本按照(Q,Hi) 拆分为多个子样本,每个子样本作为一个Hi的三分类任务:属于SELECT部分,属于WHERE部分以及句子无关。由于选列是子任务中对表格变化最敏感的任务,同时也是最基础和重要的任务,因此这样的子样本设计结合元学习能大幅提升模型在这方面的能力。
同时本发明提出了列特异性的概念,在数据库中,有些列比较通用和常见,它可能出现在很多不同主题的表格中;而反之有些列比较少见和特殊,这些列可能只在特定领域特定主题的表格中才会出现。在后续的元学习中,更希望模型能学习到更通用的列的知识,减少一些少见列的学习,其原因有两点:首先,通用列会更加常用,且易于学习,反之特殊列可能带来不必要的扰动;其次,由于和半监督学习结合,伪标签中对于特殊列的预测出错的概率会大大增加,因此它们容易成为噪音。这里本发明基于数据库中每个列的出现频率来量化列特异性,它的分数计算方式如下所示:
Figure RE-GDA0003655821630000061
其中,其中Ntotal表示数据库中所有列的个数,Ndistinct是对Ntotal按照列名进行去重后获得的列的总个数,Nhi表示当前列的列名在数据库中重复出现的次数。公式主要是将数据库平均每个列名出现的频率比上当前列名出现的频率,以此来衡量特异性。
在此基础上,从混合的选列子样本集合中采样任务集合Task={task1,task2,...,taskn},对于每个任务
Figure RE-GDA0003655821630000066
表示支持集(SupportSet),
Figure RE-GDA0003655821630000062
表示验证集(QuerySet)。在NL2SQL场景中,以每个表格对应的所有样本为一类样本,对于taski,首先从0所有的样本类别中随机采样nw个类别,在每个类别中随机采样ks个样本作为Supi, 在每个类别中随机采样kq个样本作为Qryi,两者的样本不重合。按照这种方法在标注数据集L和伪标签数据集UP全量和混合后的集合中采样nt个任务组成任务集合。
在元学习的过程中,每个任务依次用于模型的训练,直到Task集合的每个任务都经过迭代。其中对于每个taski,首先用当前模型参数θw在Supi上计算损失
Figure RE-GDA0003655821630000065
Figure RE-GDA0003655821630000063
其中每个L表示一个样本的损失,y'j表示由SQL标签yj得到的选列标签,每个样本的损失都乘上各自的置信度与特异性得分的比值作为权重。将计算得到的
Figure RE-GDA0003655821630000064
以学习率γ进行梯度更新,得到参数θ'w。之后基于模型参数θ'w预测Qryi并以相同的方式计算损失LQryi,将两个损失按照一定比例进行加权获得总损失:
Li=ηLSupi+(1-η)LQryi
其中η是一个表示权重的超参数。在得到Li后将其在θ'w的基础上以学习率υ进行梯度更新,得到新的θw。以这种方式依次更新所有采样得到的任务,最终得到参数θm
步骤3-3):从伪标签数据集UP中按一定比例σ随机采样一部分数据与标注数据集L进行混合,并采用批训练策略将混合数据分为等大的batch依次进行训练,每个batch可以结合置信度计算出损失L'i
Figure RE-GDA0003655821630000071
其中yj是SQL标签,每个样本的损失会各自乘上置信度再相加作为一整个batch的损失。将L'i用于对θm进行参数更新,得到当前自训练轮的最终参数θ',使用验证集Dv上评估当前参数的准确率,如果与前几轮几乎相同则判断为收敛,停止训练,否则它将作为新的参数θw用于下一轮的自训练中。
其中,本专利所设计的基础NL2SQL模型,将自然语言转换到SQL包含以下步骤:
步骤(1):将自然语言问句Q进行分词,得到Q={x1,x2,...,xn},其中x表示每个词。同时获取到Q所对应的目标表格T={H1,H2,...,Hm},而
Figure RE-GDA0003655821630000072
其中Hi表示第i个列,hi表示其列名,Ci表示其包含的多个值。同时对于每个Hi识别它的类型ti,它分为文本类型,日期类型和数字类型(Date,Text,Number)。
步骤(2):在每个列Hi下,
Figure RE-GDA0003655821630000073
其中
Figure RE-GDA0003655821630000074
表示Ci下第j行的值。在输入到编码器前,每个
Figure RE-GDA0003655821630000075
首先会分词后与Q的单词进行滑窗式文本相似度匹配,在滑窗过程中的最高一次匹配得分作为每个
Figure RE-GDA0003655821630000076
的分数,取得分最高的前k个作为剩下的集合
Figure RE-GDA0003655821630000077
步骤(3):将分词后的自然语言问句Q和其对应表格T中的每个列Hi的信息依次输入到模型的编码器RoBERTa中,具体来说Hi的信息包含(ti,hi,C'i),在同样进行分词后与Q的分词结果一起输入到编码器中,输入格式如下:
Figure RE-GDA0003655821630000078
其中,m表示hi分词后的长度,lk表示
Figure RE-GDA0003655821630000079
分词后的长度。在经过RoBERTa进行编码后,这些单词会被向量化,如下:
Figure RE-GDA0003655821630000081
其中,每个h表示一个向量,这部分向量表示会用到后续的解码预测部分。
步骤(4):利用步骤(3)得到的向量表示来进行六个子任务的预测,这六个子任务分别是:SELECT部分的选列预测(SC),SELECT部分的聚合函数预测(SA),WHERE部分条件个数的预测(WN),WHERE部分每个条件的列(WC),操作符(WO)和值(WV)。其中子任务的个数和定义可以根据SQL所需覆盖的语法进行增删。针对每一组列粒度输入 (Q,Hi),首先对Hi作为SC和WC结果的概率进行打分,计算方式如下:
Figure RE-GDA0003655821630000082
Figure RE-GDA0003655821630000083
其中WSC和WWC均为可训练参数矩阵。P表示所得到的概率。此后会进行WN的结果进行预测,对每个Hi分别预测的条件数量概率分布以及Hi本身与Q的关联度得分,具体的计算过程如下:
Figure RE-GDA0003655821630000084
Figure RE-GDA0003655821630000085
其中nj表示某一个条件数,WWN和Ww均为可训练参数矩阵。同时对于每个Hi,还会预测当其作为SC时它的聚合函数,以及当其作为某个条件的列时所对应的操作符和值,具体的计算过程如下:
Figure RE-GDA0003655821630000086
Figure RE-GDA0003655821630000087
Figure RE-GDA0003655821630000088
Figure RE-GDA0003655821630000089
其中aj表示某一个聚合函数,oj表示某一个操作符,s和e分别表示值在问句中的开始与结束的索引,WSA,WWO
Figure RE-GDA00036558216300000810
均为可训练参数矩阵。
步骤(5):综合每一组(Q,Hi)的预测结果,基于Q和查询目标表格T来最终预测完整的SQL结果。首先对于SC,直接取PSC得分最高的列作为结果;其次综合每个列对WN的预测,并利用关联度得分作为权重来进行带权相加,最终取概率分布中得分最高的数量作为WN的结果,记作n':
Figure RE-GDA0003655821630000091
之后取前n'个PWC得分最高的列作为WC的结果;与此同时,对于SC的列,取该列所预测的聚合函数中概率最大者作为SA的结果;对于WC的列,取这些列各自所预测的操作符和值区间的概率最大者作为各自所在的条件中的操作符和值,即WO和WV的结果。由此结合预定义的SQL骨架,用子任务的结果进行填充最终形成完整的SQL:
SELECT{SA}{SC}WHERE({WC}{WO}{WV})*,
其中,{}表示一个待填充的槽位,*表示括号中的部分可能出现0次或多次,由此最终的SQL 就预测完成了。
本实施例使用了如下指标进行评估:
LF(LogicalFormAccuracy):在NL2SQL任务上,指代模型预测得到的SQL与黄金标注的 SQL在文本层面上完全一致的样本在所有样本的比例。
EX(ExecuteAccuracy):在NL2SQL任务上,指代模型预测得到的SQL与黄金标注的SQL 在实际在数据库中运行后获得的查询结果完全一致的样本在所有样本的比例。
需要说明的是上述实施例,并非用来限定本发明的保护范围,在上述技术方案的基础上所作出的等同变换或替代均落入本发明权利要求所保护的范围。

Claims (10)

1.一种基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,该方法包括以下步骤:
步骤1)构建NL2SQL模型,并进行参数初始化;
步骤2)利用标注数据进行热启动训练,直至准确率达到阈值;
步骤3)利用自学习框架对NL2SQL模型参数进行多轮训练和更新,直至模型参数收敛。
2.根据权利要求1所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤3)
其中每一轮自训练过程包括:
步骤3.1)使用模型为大量无标签数据预测伪标签以及置信度,作为伪标签数据集;
步骤3.2)从标签数据集和伪标签数据集的混合数据中采样任务集合,并使用基于列特异性的元学习算法对模型进行训练和参数更新;
步骤3.3)对伪标签数据按照一定比例进行随机采样,得到的采样数据与标签数据混合,采用批训练策略将混合数据分为等大的batch依次对模型进行训练和参数更新。
3.根据权利要求1所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤1)构建NL2SQL模型,具体如下,
步骤(1-1):将自然语言问句Q进行分词,得到Q={x1,x2,...,xn},其中x表示每个词,同时获取到Q所对应的目标表格T={H1,H2,...,Hm},而Hi=(hi,Ci),其中Hi表示第i个列,hi表示其列名,Ci表示其包含的多个值,同时对于每个Hi识别它的类型ti,它分为文本类型,日期类型和数字类型(Date,Text,Number);
步骤(1-2):在每个列Hi下,
Figure RE-FDA0003655821620000011
其中
Figure RE-FDA0003655821620000012
表示Ci下第j行的值,在输入到编码器前,每个
Figure RE-FDA0003655821620000013
首先会分词后与Q的单词进行滑窗式文本相似度匹配,在滑窗过程中的最高一次匹配得分作为每个
Figure RE-FDA0003655821620000014
的分数,取得分最高的前k作为剩下的集合
Figure RE-FDA0003655821620000015
步骤(1-3):将分词后的自然语言问句Q和其对应表格T中的每个列Hi的信息依次输入到模型的编码器RoBERTa中,具体来说Hi的信息包含(ti,hi,C'i),在同样进行分词后与Q的分词结果一起输入到编码器中,
步骤(1-4):利用步骤(1-3)得到的向量表示来进行六个子任务的预测,这六个子任务分别是:SELECT部分的选列预测(SC),SELECT部分的聚合函数预测(SA),WHERE 部分条件个数的预测(WN),WHERE部分每个条件的列(WC),操作符(WO)和值(WV),
步骤(1-5):综合每一组(Q,Hi)的预测结果,基于Q和查询目标表格T来最终预测完整的SQL结果,首先对于SC,直接取PSC得分最高的列作为结果;其次综合每个列对WN的预测,并利用关联度得分作为权重来进行带权相加,最终取概率分布中得分最高的数量作为WN的结果,记作n':
Figure RE-FDA0003655821620000021
之后取前n'个PWC得分最高的列作为WC的结果;与此同时,对于SC的列,取该列所预测的聚合函数中概率最大者作为SA的结果;对于WC的列,取这些列各自所预测的操作符和值区间的概率最大者作为各自所在的条件中的操作符和值,即WO和WV的结果;由此结合预定义的SQL骨架,用子任务的结果进行填充最终形成完整的SQL:
SELECT{SA}{SC}WHERE({WC}{WO}{WV})*,
其中,{}表示一个待填充的槽位,*表示括号中的部分可能出现0次或多次,由此最终的SQL预测完成。
4.根据权利要求3所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤(1-3)中在同样进行分词后与Q的分词结果一起输入到编码器中,输入格式如下:,输入格式如下:
Figure RE-FDA0003655821620000022
其中,m表示hi分词后的长度,lk表示
Figure RE-FDA0003655821620000023
分词后的长度,在经过RoBERTa进行编码后,这些单词会被向量化,如下:
Figure RE-FDA0003655821620000024
其中,每个h表示一个向量,这部分向量表示会用到后续的解码预测部分。
5.根据权利要求4所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤(1-4)中,
其中子任务的个数和定义根据SQL所需覆盖的语法进行增删,针对每一组列粒度输入(Q,Hi),首先对Hi作为SC和WC结果的概率进行打分,计算方式如下:
Figure RE-FDA0003655821620000025
Figure RE-FDA0003655821620000026
其中WSC和WWC均为可训练参数矩阵,P表示所得到的概率,此后会进行WN的结果进行预测,对每个Hi分别预测的条件数量概率分布以及Hi本身与Q的关联度得分,具体的计算过程如下:
Figure RE-FDA0003655821620000031
Figure RE-FDA0003655821620000032
其中nj表示某一个条件数,WWN和Ww均为可训练参数矩阵,同时对于每个Hi,还会预测当其作为SC时它的聚合函数,以及当其作为某个条件的列时所对应的操作符和值,具体的计算过程如下:
Figure RE-FDA0003655821620000033
Figure RE-FDA0003655821620000034
Figure RE-FDA0003655821620000035
Figure RE-FDA0003655821620000036
其中aj表示某一个聚合函数,oj表示某一个操作符,s和e分别表示值在问句中的开始与结束的索引,WSA,WWO
Figure RE-FDA0003655821620000037
均为可训练参数矩阵。
6.根据权利要求5所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤2)中的热启动训练,具体如下,
将模型基于已有的少量标注数据L采用批(batch)训练策略进行热启动训练,将训练集分为等大的batch依次进行训练,每个训练步(step)利用一个batch的标注数据对NL2SQL模型进行监督训练,每一轮(epoch)基于L进行训练后,模型会在验证集Dv上评估当前参数θ的准确率,如果准确率低于阈值λ,则继续进行下一轮热启动训练,如果高于λ,则保存当前参数,记为θw,并进入下一步骤。
7.根据权利要求6所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤3.1)中的伪标签以及置信度预测,具体如下,
基于当前模型参数θw对大量的未标注数据U进行伪SQL标签的预测,由此得到伪标签数据集UP,同时每个样本都会在预测时基于模型在多个子任务上的输出概率进行置信度的打分,具体的计算方式如下:
Figure RE-FDA0003655821620000041
其中,z属于子任务的集合,τ作为开平方根的超参数,ζ作为最低的阈值,所有标注数据的置信度设置为1.0,这个置信度会作为后续伪标签样本计算梯度时的权重。
8.根据权利要求7所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤3.2)中的任务采样以及基于列特异性的元学习算法,具体如下,
将标注数据集L和伪标签数据集UP结合,并将每个SQL预测样本按照(Q,Hi)拆分为多个子样本,每个子样本作为一个Hi的三分类任务:属于SELECT部分,属于WHERE部分以及句子无关,由于选列是子任务中对表格变化最敏感的任务,同时也是最基础和重要的任务,因此这样的子样本设计结合元学习能大幅提升模型在这方面的能力。
9.根据权利要求8所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,数据库中每个列的出现频率来量化列特异性,它的分数计算方式如下所示:
Figure RE-FDA0003655821620000042
其中,其中Ntotal表示数据库中所有列的个数,Ndistinct是对Ntotal按照列名进行去重后获得的列的总个数,
Figure RE-FDA0003655821620000043
表示当前列的列名在数据库中重复出现的次数,公式主要是将数据库平均每个列名出现的频率比上当前列名出现的频率,以此来衡量特异性;
在此基础上,从混合的选列子样本集合中采样任务集合Task={task1,task2,...,taskn},对于每个任务
Figure RE-FDA0003655821620000046
Supi表示支持集(SupportSet),Qryi表示验证集(QuerySet);在NL2SQL场景中,以每个表格对应的所有样本为一类样本,对于taski,首先从0所有的样本类别中随机采样nw个类别,在每个类别中随机采样ks个样本作为Supi,在每个类别中随机采样kq个样本作为Qryi,两者的样本不重合,按照这种方法在标注数据集L和伪标签数据集UP全量和混合后的集合中采样nt个任务组成任务集合;
在元学习的过程中,每个任务依次用于模型的训练,直到Task集合的每个任务都经过迭代,其中对于每个taski,首先用当前模型参数θw在Supi上计算损失
Figure RE-FDA0003655821620000044
Figure RE-FDA0003655821620000045
其中每个L表示一个样本的损失,y'j表示由SQL标签yj得到的选列标签,每个样本的损失都乘上各自的置信度与特异性得分的比值作为权重,将计算得到的
Figure RE-FDA0003655821620000051
以学习率γ进行梯度更新,得到参数θ'w,之后基于模型参数θ'w预测Qryi并以相同的方式计算损失
Figure RE-FDA0003655821620000052
将两个损失按照一定比例进行加权获得总损失:
Figure RE-FDA0003655821620000053
其中η是一个表示权重的超参数,在得到Li后将其在θ'w的基础上以学习率υ进行梯度更新,得到新的θw,以这种方式依次更新所有采样得到的任务,最终得到参数θm
10.根据权利要求9所述的基于半监督学习和元学习的少样本NL2SQL方法,其特征在于,步骤3.3)对伪标签数据按照一定比例进行随机采样,具体如下,
从伪标签数据集UP中按一定比例σ随机采样一部分数据与标注数据集L进行混合,并采用批训练策略将混合数据分为等大的batch依次进行训练,每个batch结合置信度计算出损失L'i
Figure RE-FDA0003655821620000054
其中yj是SQL标签,每个样本的损失会各自乘上置信度再相加作为一整个batch的损失,将L'i用于对θm进行参数更新,得到当前自训练轮的最终参数θ',使用验证集Dv上评估当前参数的准确率,如果与前几轮几乎相同则判断为收敛,停止训练,否则它将作为新的参数θw用于下一轮的自训练中。
CN202210147772.2A 2022-02-17 2022-02-17 一种基于半监督学习和元学习的少样本nl2sql方法 Active CN114817307B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210147772.2A CN114817307B (zh) 2022-02-17 2022-02-17 一种基于半监督学习和元学习的少样本nl2sql方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210147772.2A CN114817307B (zh) 2022-02-17 2022-02-17 一种基于半监督学习和元学习的少样本nl2sql方法

Publications (2)

Publication Number Publication Date
CN114817307A true CN114817307A (zh) 2022-07-29
CN114817307B CN114817307B (zh) 2024-08-13

Family

ID=82527844

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210147772.2A Active CN114817307B (zh) 2022-02-17 2022-02-17 一种基于半监督学习和元学习的少样本nl2sql方法

Country Status (1)

Country Link
CN (1) CN114817307B (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115080748A (zh) * 2022-08-16 2022-09-20 之江实验室 一种基于带噪标签学习的弱监督文本分类方法和装置
CN115984653A (zh) * 2023-02-14 2023-04-18 中南大学 一种动态智能货柜商品识别模型的构建方法
CN117995173A (zh) * 2024-01-31 2024-05-07 三六零数字安全科技集团有限公司 一种语言模型生成方法、装置、存储介质及电子设备

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN113254599A (zh) * 2021-06-28 2021-08-13 浙江大学 一种基于半监督学习的多标签微博文本分类方法
WO2021243706A1 (zh) * 2020-06-05 2021-12-09 中山大学 一种跨语言生成提问的方法和装置

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021243706A1 (zh) * 2020-06-05 2021-12-09 中山大学 一种跨语言生成提问的方法和装置
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN113254599A (zh) * 2021-06-28 2021-08-13 浙江大学 一种基于半监督学习的多标签微博文本分类方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
杨灿;: "一种结合GAN和伪标签的深度半监督模型研究", 中国科技信息, no. 17, 1 September 2020 (2020-09-01), pages 83 - 87 *

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115080748A (zh) * 2022-08-16 2022-09-20 之江实验室 一种基于带噪标签学习的弱监督文本分类方法和装置
CN115080748B (zh) * 2022-08-16 2022-11-11 之江实验室 一种基于带噪标签学习的弱监督文本分类方法和装置
CN115984653A (zh) * 2023-02-14 2023-04-18 中南大学 一种动态智能货柜商品识别模型的构建方法
CN115984653B (zh) * 2023-02-14 2023-08-01 中南大学 一种动态智能货柜商品识别模型的构建方法
CN117995173A (zh) * 2024-01-31 2024-05-07 三六零数字安全科技集团有限公司 一种语言模型生成方法、装置、存储介质及电子设备

Also Published As

Publication number Publication date
CN114817307B (zh) 2024-08-13

Similar Documents

Publication Publication Date Title
CN117033608B (zh) 一种基于大语言模型的知识图谱生成式问答方法及系统
CN109614471B (zh) 一种基于生成式对抗网络的开放式问题自动生成方法
CN112417894B (zh) 一种基于多任务学习的对话意图识别方法及识别系统
CN108932342A (zh) 一种语义匹配的方法、模型的学习方法及服务器
CN114817307A (zh) 一种基于半监督学习和元学习的少样本nl2sql方法
CN109902159A (zh) 一种基于自然语言处理的智能运维语句相似度匹配方法
CN112232087B (zh) 一种基于Transformer的多粒度注意力模型的特定方面情感分析方法
CN112328800A (zh) 自动生成编程规范问题答案的系统及方法
WO2022048194A1 (zh) 事件主体识别模型优化方法、装置、设备及可读存储介质
CN111274790A (zh) 基于句法依存图的篇章级事件嵌入方法及装置
CN118093834B (zh) 一种基于aigc大模型的语言处理问答系统及方法
CN111222318A (zh) 基于双通道双向lstm-crf网络的触发词识别方法
CN107273352A (zh) 一种基于Zolu函数的词嵌入学习模型及训练方法
CN110298044A (zh) 一种实体关系识别方法
CN110807069A (zh) 一种基于强化学习算法的实体关系联合抽取模型构建方法
CN111966810A (zh) 一种用于问答系统的问答对排序方法
CN115062070A (zh) 一种基于问答的文本表格数据查询方法
CN116561251A (zh) 一种自然语言处理方法
CN115658846A (zh) 一种适用于开源软件供应链的智能搜索方法及装置
CN111666374A (zh) 一种在深度语言模型中融入额外知识信息的方法
CN112989803B (zh) 一种基于主题向量学习的实体链接预测方法
CN117828024A (zh) 一种插件检索方法、装置、存储介质及设备
CN112926323A (zh) 基于多级残差卷积与注意力机制的中文命名实体识别方法
CN112765985A (zh) 一种面向特定领域专利实施例的命名实体识别方法
CN117094325A (zh) 水稻病虫害领域命名实体识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant