CN116415170A - 基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 - Google Patents
基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 Download PDFInfo
- Publication number
- CN116415170A CN116415170A CN202310270334.XA CN202310270334A CN116415170A CN 116415170 A CN116415170 A CN 116415170A CN 202310270334 A CN202310270334 A CN 202310270334A CN 116415170 A CN116415170 A CN 116415170A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- classification
- sample
- prompt
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 120
- 238000000034 method Methods 0.000 title claims abstract description 67
- 238000013145 classification model Methods 0.000 claims abstract description 43
- 230000006870 function Effects 0.000 claims abstract description 28
- 238000007781 pre-processing Methods 0.000 claims abstract description 14
- 238000005457 optimization Methods 0.000 claims description 25
- 238000003860 storage Methods 0.000 claims description 11
- 238000012795 verification Methods 0.000 claims description 9
- 238000010276 construction Methods 0.000 claims description 8
- 238000004806 packaging method and process Methods 0.000 claims description 7
- 238000009826 distribution Methods 0.000 claims description 6
- 238000010801 machine learning Methods 0.000 claims description 6
- 239000011159 matrix material Substances 0.000 claims description 6
- 238000004590 computer program Methods 0.000 claims description 5
- 238000011478 gradient descent method Methods 0.000 claims description 3
- 230000004927 fusion Effects 0.000 claims description 2
- 238000010586 diagram Methods 0.000 description 5
- 238000003058 natural language processing Methods 0.000 description 5
- 238000013459 approach Methods 0.000 description 3
- 230000008451 emotion Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 230000000873 masking effect Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000001360 synchronised effect Effects 0.000 description 2
- 102100033814 Alanine aminotransferase 2 Human genes 0.000 description 1
- 101710096000 Alanine aminotransferase 2 Proteins 0.000 description 1
- 229940060587 alpha e Drugs 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000007786 learning performance Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000002360 preparation method Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Machine Translation (AREA)
Abstract
本发明公开了一种基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质,方法包括下述步骤,根据预先设立的提示模板预处理数据样本;将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;使用训练好的分类模型对新样本进行分类。本发明通过融合提示学习和传统微调的方法在小样本数据下训练分类模型,有效地学习到了下游任务特定知识并提高了分类精度。
Description
技术领域
本发明属于自然语言处理的技术领域,具体涉及一种基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质。
背景技术
近年来,人工智能相关技术迅猛发展,在自然语言处理领域,涌现出一系列预训练语言模型(BERT、RoBERTa、GPT、T5等),极大地推动了自然语言处理技术的发展。由于预训练语言模型强大的能力,其已成为解决许多自然语言处理任务的主要方法。
通常的做法是在语言模型头部添加一个线性分类器,然后进行全模型微调以适应下游任务,而最近的一种方法——提示学习,使用提示来执行各种下游任务,被认为能够释放语言模型的潜力。预训练模型在预训练阶段通过填词或者续写等方式来获取通用语言知识,而提示学习,通过构建输入模板,让语言模型进行填词,将填的词映射到分类任务具体的标签上,从而将分类任务建模成完形填空任务,这种方式减少了模型在预训练阶段和下游任务阶段的差距,取得了很好的效果,尤其是在小样本训练的场景下。
然而,最近的研究显示,预训练语言模型的填词预测是有偏差的,它倾向于预测在预训练阶段词频较高的词,从而导致不公平的预测,并且提示学习性能并不稳定(比较依赖于人工构建的模板和标签词)。此外,小样本场景下,提示学习更多地是运用了预训练语言模型在预训练阶段学习到的知识,而在下游任务中能学习到的知识较少,导致小样本训练的精度仍然明显低于全样本训练。因此,如何在小样本的场景下训练出一个无偏差、稳定且高精度的模型是一个亟待解决的难题。
发明内容
本发明的主要目的在于克服现有技术的缺点与不足,提供一种基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质,通过使用训练数据计算模型校准参数,缓解模型的预测偏差以及减少不同模板和标签词带来的性能差异,再通过融合提示学习和传统微调的方法在小样本数据下训练分类模型,有效地学习到了下游任务特定知识并提高了分类精度。
为了达到上述目的,本发明采用以下技术方案:
第一方面,本发明提供了一种基于预训练语言模型的提示学习小样本分类方法,包括下述步骤,
根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
使用训练好的分类模型对新样本进行分类。
作为优选的技术方案,所述数据样本包含N个类别,并对每个样本标注所属类别,且不同类别的样本数量为K个,构成K-way-N-shot的小样本数据集,总共包含K*N个训练样本。
作为优选的技术方案,所述利用预测词概率计算校准参数,具体包括下述步骤:
将预处理后的数据样本输入到预训练语言模型中,得到被遮蔽的词对应位置的标签词logits,标签词与标签一一对应,即获得分类标签的logits;
计算缩放系数λ,以将校准后的logits缩放回模型原本输出logits的大小规模:
其中,z是遮蔽处的各标签词的logits;N表示样本的类别,K表示不同类别样本的数量,M表示为预训练语言模型;Wv表示预测标签词的词嵌入,也被用于隐层状态上预测词;hmask为遮蔽处模型最后一层的隐层状态;diag函数是将向量拓展成对角矩阵的函数;i表示输入样本序号,j表示不同标签的索引,表示模型输出对第i个样本的第j个标签的logits,zi表示模型输出对第i个样本的logits向量。
作为优选的技术方案,所述分类模型中,使用校准参数计算校准后的提示学习的输出概率为:
pMLM=Softmax(Wz)
构建提示作为特征提取器,取预训练语言模型输出的遮蔽处隐层状态作为特征,并构建一个分类器f进行分类:
pCLS=Softmax(f(hmask))
并将两个输出加权融合:
p(y|xprompt)=α·pMLM+(1-α)·pCLS
α为平衡因子,上述Softmax函数表示为:
xc和xj为标签的索引c和j所对应的标签输出logits;由于pMLM和pCLS两种输出都使用了hmask,两种方式共享遮蔽处隐层状态,使用了隐式的多任务学习,将该分类任务拆分成了两个子分类任务,使分类模型拥有更好的泛化性能。
作为优选的技术方案,对分类模型进行训练包括下述步骤:
将每一批数据输入构建的分类模型中,获取各类预测概率分布pi;
使用交叉熵损失函数,基于每个样本对应的标签和概率分布计算损失值;
损失值用于反向传播并使用梯度下降法更新整个分类模型的参数;
其中,所述交叉熵损失函数表示为:
其中,p=[p0,…,pC-1]表示所有类别的预测概率,pi代表第i个类别的预测概率,y=[y0,···,yC-1]是样本类别的one-hot表示,当样本属于第i个类别时yi=1,否则yi=0;C是类别数量。
作为优选的技术方案,对分类模型进行训练包括下述步骤:
构建黑盒优化模型,具体为:
随机初始化投影矩阵A和本地待优化参数zl,l∈L,L为预训练模型层数,每一层需维护一组本地优化参数,并构建本地分类器f;
黑盒优化模型输出为:
p(y|xprompt)=α·pMLM+(1-α)·pCLS
对黑盒优化模型训练,由于分类器在本地构建,其梯度完全可见,可使用梯度和进化算法交替联合优化,优化过程如下:
a)训练数据以没mini-batch的形式输入模型并计算交叉熵损失,反向传播算法更新本地分类器f的参数;
b)全部训练数据输入模型并计算交叉熵损失,使用CMA-ES算法选取最优的zl保存并用于下一轮的CMA-ES算法;
c)迭代3个epoch,b)步骤迭代一轮,即模型层数每一层l使用CMA-ES算法优化一次,构成一次完整优化训练步骤;
将c)中的训练步骤迭代若干轮次.
作为优选的技术方案,分类器使用AdamW优化器对参数进行优化,并在迭代训练过程中每一次完整的优化训练在验证集上评估,选取验证集上精度最高的模型保存。
第二方面,本发明提供了一种基于预训练语言模型的提示学习小样本分类系统,应用于所述的基于同态加密和可信硬件的多方隐私保护机器学习方法,包括输入数据获取模块、模型校准和构建模块、模型训练模块以及模型分类预测模块;
所述输入数据获取模块,用于根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
所述模型校准和构建模块,用于将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
所述模型训练模块,用于构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
所述模型分类预测模块,用于使用训练好的分类模型对新样本进行分类。
第三方面,本发明提供了一种电子设备,所述电子设备包括:
至少一个处理器;以及,
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的计算机程序指令,所述计算机程序指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行所述的基于预训练语言模型的提示学习小样本分类方法。
第四方面,本发明提供了一种计算机可读存储介质,存储有程序,所述程序被处理器执行时,实现所述的基于预训练语言模型的提示学习小样本分类方法。
本发明与现有技术相比,具有如下优点和有益效果:
本发明针对提示学习应用于小样本学习场景下,初始模型具有较大预测偏差和不同人工构造模板带来较大的性能差异的问题,通过数据驱动的方法计算模型校准参数,有效地缓解了模型的预测偏差;可以减少不同模板和标签词带来的性能差异,并且通过融合了提示学习和传统微调的方法,使得模型在下游任务上学到更加丰富的任务特定知识、鲁棒性更强,从而提升了分类精度,减小了小样本和多样本学习的性能差异。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提示学习用于情感分析的例子;
图2是本发明实施例基于预训练语言模型的提示学习小样本分类方法的流程图;
图3是本发明实施例基于预训练语言模型的提示学习小样本分类方法的整体示意图;
图4是本发明实施例基于预训练语言模型的提示学习小样本分类方法应用于模型黑盒优化的整体示意图。
图5是本发明实施例基于预训练语言模型的提示学习小样本分类系统的结构示意图;
图6是实现本发明实施例基于预训练语言模型的提示学习小样本分类方法的电子设备的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
在本申请中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本申请所描述的实施例可以与其它实施例相结合。
大规模预训练语言模型在自然语言处理的各个领域上取得了巨大的成功,小样本学习(Few-shotLearning)是机器学习的一种范式,在极小的训练样本的情况下,对模型进行少量调优,拟合下游任务,得到精度较高可用的模型。人类可以仅通过一或几个示例就可以轻松地建立对新事物的认知,而机器学习算法通常需要大量的训练数据对模型进行训练从而获得较好的泛化能力。由于在机器学习领域,数据标注的成本较高,特别在某一些领域,标注数据十分稀缺且较难获得,人们希望机器学习模型可以在少量的样本下,即可训练出泛化能力强,精度高的模型。
GPT-2和GPT-3的出现使得在下游任务中,出现了提示微调(Prompt-basedTuning)的方法,且这一方法可以成功拓展到MLM模型(例如BERT、RoBERTa)。提示微调在小样本的场景下表现出比传统微调更加优异的性能,因为这一方法可以缩小模型在预训练阶段和下游任务阶段的差距。图1示出了情感分析的提示学习例子,如图一所示,可以构造模板“[X]It was[MASK].”,其中“[X]”是代表输入的样本,而[MASK]是被遮蔽的词,模型需要预测该位置的词来从而预测标签,例如在图1中,所选取的标签词为“great”和“terrible”,当模型预测的词为“great”时,则标签预测为积极,反之,则标签预测为消极。同理,这一方法以及本发明方法可以拓展到中文或其他语言,不再一一赘述。
但该方法的初始模型具有较大标签词预测偏差和不同人工构造模板和标签词带来较大的性能差异。基于此,如图2所示,本实施例提出一种基于预训练语言模型的提示学习小样本分类方法,包括下述步骤:
S1、根据提示模板预处理数据样本
S11、根据分类任务和分类样本的领域特点,选取模板和标签词;
可通过人工选取模板和标签词,借助人类自带的外部知识构建提示;另外LM-BFF是一项通过T5模型自动生成模板和标签词的先进技术,可以使用LM-BFF来自动生成以避免人工选取的偏差。此外,标签词可以不仅仅局限于单个词,每一类标签可以对应多个标签。
S12、用提示模板包装数据样本;
如图1所示,用提示模板包装数据样本,包装后的数据都应包含“[MASK]”用于模型的预测,并且对该样本进行分词以及将词转换为对应的词表索引,并构建attention mask为输入模型做准备。
其中,定义人工构建的提示模板和每一个标签对应的标签词(标签词也可以使用T5等模型自动生成),以情感分析为例,提示模板如下:
Hello,my dog is cute.It was[MASK].其中,“Hello,my dog is cute”为输入样本,[MASK]是被遮蔽的词,需要语言模型该位置的词来从而预测标签,标签词可为“great”和“terrible”,与分类标签“积极”和“消极”一一对应。
进一步的,小样本数据集进行预处理,通常来说,训练数据集是N-way-k-shot的,包含K*N个样本,N代表分类任务的标签数量,K表示每一类标签的样本数。步骤S1中定义的模板包装这些数据。
S2、使用数据样本和模型计算校准参数;
S21、将预处理后的少量数据集输入模型中,生成遮蔽处每一类标签的标签词logits:
z=M(x|prompt)=Wv(y)*hmask
S22、计算缩放系数λ,以将校准后的logits缩放回模型原本输出logits的大小规模:
S23、计算校准参数W:
其中,上述公式中z是遮蔽处的各标签词的logits;M表示为预训练语言模型;Wv表示预测标签词的词嵌入,也被用于隐层状态上预测词;hmask为遮蔽处模型最后一层的隐层状态;diag函数是将向量拓展成对角矩阵的函数;i表示输入样本序号,j表示不同标签的索引,表示模型输出对第i个样本的第j个标签的logits,zi表示模型输出对第i个样本的logits向量。
S3、构建模型并进行训练
S31、构建模型
提示学习可以较好地运用预训练阶段所学习到的知识,但对下游任务的特定知识学习得不够充分,如图3所示,通过将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且有效利用预训练学习到的知识。使用S2中校准参数计算校准后的提示学习的输出概率为:
pMLM=Softmax(Wz)
构建提示作为特征提取器,取预训练语言模型输出的遮蔽处隐层状态作为特征,并构建一个分类器f进行分类:
pCLS=Softmax(f(hmask))
并将这两个输出加权融合:
p(y|xprompt)=α·pMLM+(1-α)·pCLS
其中,α∈(0,1)是超参数用于平衡两者权重,对于比较简单的任务可以将α设为0.5,比较复杂的任务(例如句对任务)可以提高α值;
α为平衡因子,上述Softmax函数表示为:
xc和xj为标签的索引c和j所对应的标签输出logits;由于pMLM和pCLS两种输出都使用了hmask,两种方式共享遮蔽处隐层状态,使用了隐式的多任务学习,将该分类任务拆分成了两个子分类任务,可以使模型拥有更好的泛化性能。
S32、模型训练
将小样本数据集使用mini-batch的方式并使用交叉熵损失函数(Cross EntropyLoss)计算损失训练,具体为:
a)将每一批数据输入S31中的模型中,获取各类预测概率分布pi;
b)使用交叉熵函数,每个样本对应的标签和概率分布计算损失值;
c)损失值用于反向传播并使用梯度下降法更新整个模型的参数。
所述交叉熵损失函数表示为:
其中,p=[p0,…,pC-1]表示所有类别的预测概率,pi代表第i个类别的预测概率,y=[y0,···,yC-1]是样本类别的one-hot表示,当样本属于第i个类别时yi=1,否则yi=0;C是类别数量。
在本实施例的迭代训练过程中,模型使用AdamW优化器对参数进行优化,训练最大步数设为500,并在迭代训练过程中每50步在验证集上评估,选取验证集上精度最高的模型保存;批大小选为[2,4,8],学习率选为[1e-5,2e-5],α设为[0.5,0.7],以上超参数视情况增加超参数搜索空间,grid-search超参数,遍历所有超参数组合,选取验证集上表现最优的一组超参数训练的模型用于使用,训练使用线性衰减学习率逐步下降到0。
S4、将待预测数据样本按S1预处理,输入训练好的分类模型中,得到预测的分类结果。
在本发明的另一个实施例中,基于与上述实施例相同的思想,该方法可以在预训练语言模型的黑盒优化中使用;大型语言模型(Large Langage Model,LLM)通常由互联网厂商预训练,并通过API的形式需付费供消费者使用(如GPT3)。为了保护模型不被窃取,模型权重和梯度对于调用API的消费者是不可见的,而通过黑盒优化可以在模型权重和梯度不可见的情况下训练模型。其具体做法是通过进化算法来优化一组参数,这组参数通过随机投影到更高维度作为模型的前缀提示prefix-prompt,从而达到训练模型的目的。
如图4所示,本发明方法可以应用于该场景的黑盒优化,并且因模型大部分权重不可优化,本发明方法可以取得更好的校准效果,本实施例除下述构建模型并进行训练步骤(对应上述实施例中基于预训练语言模型的提示学习小样本分类方法步骤S3)外其他步骤同上述实施例,具体为:
S31、构建黑盒优化模型
随机初始化投影矩阵A和本地待优化参数zl l∈L,L为预训练模型层数,每一层需维护一组本地优化参数,并构建本地分类器f
同上述实施例,黑盒优化模型输出为:
p(y|xprompt)=α·pMLM+(1-α)·pCLS
S32、黑盒优化模型训练:
由于分类器在本地构建,其梯度完全可见,可使用梯度和进化算法交替联合优化:
a)训练数据以没mini-batch的形式输入模型并计算交叉熵损失,反向传播算法更新本地分类器f的参数
b)全部训练数据输入模型并计算交叉熵损失,使用CMA-ES算法选取最优的zl保存并用于下一轮的CMA-ES算法
c)a步骤迭代3个epoch,b步骤迭代一轮,即模型层数每一层l使用CMA-ES算法优化一次,构成一次完整优化训练步骤
d)将c中的训练步骤迭代若干轮次
在本实施例的迭代训练过程中,本地分类器使用AdamW优化器对参数进行优化,学习率选为1e-5,CMA-ES进化算法人口选为20,并在迭代训练过程中每一次完整的优化训练在验证集上评估,选取验证集上精度最高的模型保存。
基于与上述实施例中的基于预训练语言模型的提示学习小样本分类方法相同的思想,本发明还提供基于预训练语言模型的提示学习小样本分类系统,该系统可用于执行上述基于预训练语言模型的提示学习小样本分类方法。为了便于说明,基于预训练语言模型的提示学习小样本分类系统实施例的结构示意图中,仅仅示出了与本发明实施例相关的部分,本领域技术人员可以理解,图示结构并不构成对装置的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图5所示,本发明实施例提供了一种基于预训练语言模型的提示学习小样本分类系统100,包括输入数据获取模块101、模型校准和构建模块102、模型训练模块103以及模型分类预测模块104;
所述输入数据获取模块101,用于根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
所述模型校准和构建模块102,用于将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
所述模型训练模块103,用于构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
所述模型分类预测模块104,用于使用训练好的分类模型对新样本进行分类。
需要说明的是,本发明的基于预训练语言模型的提示学习小样本分类系统与本发明的基于预训练语言模型的提示学习小样本分类方法一一对应,在上述基于预训练语言模型的提示学习小样本分类方法的实施例阐述的技术特征及其有益效果均适用于基于预训练语言模型的提示学习小样本分类系统的实施例中,具体内容可参见本发明方法实施例中的叙述,此处不再赘述,特此声明。
此外,上述实施例的基于预训练语言模型的提示学习小样本分类系统的实施方式中,各程序模块的逻辑划分仅是举例说明,实际应用中可以根据需要,例如出于相应硬件的配置要求或者软件的实现的便利考虑,将上述功能分配由不同的程序模块完成,即将所述基于预训练语言模型的提示学习小样本分类系统的内部结构划分成不同的程序模块,以完成以上描述的全部或者部分功能。
如图6所示,在本申请的另一个实施例中,提供了一种基于预训练语言模型的提示学习小样本分类方法的电子设备200,所述电子设备200可以包括第一处理器201、第一存储器202和总线,还可以包括存储在所述第一存储器202中并可在所述第一处理器201上运行的计算机程序,如基于预训练语言模型的提示学习小样本分类程序203。
其中,所述第一存储器202至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、移动硬盘、多媒体卡、卡型存储器(例如:SD或DX存储器等)、磁性存储器、磁盘、光盘等。所述第一存储器202在一些实施例中可以是电子设备200的内部存储单元,例如该电子设备200的移动硬盘。所述第一存储器202在另一些实施例中也可以是电子设备200的外部存储设备,例如电子设备200上配备的插接式移动硬盘、智能存储卡(Smart Media Card,SMC)、安全数字(SecureDigital,SD)卡、闪存卡(Flash Card)等。进一步地,所述第一存储器202还可以既包括电子设备200的内部存储单元也包括外部存储设备。所述第一存储器202不仅可以用于存储安装于电子设备200的应用软件及各类数据,例如基于预训练语言模型的提示学习小样本分类程序203的代码等,还可以用于暂时地存储已经输出或者将要输出的数据。
所述第一处理器201在一些实施例中可以由集成电路组成,例如可以由单个封装的集成电路所组成,也可以是由多个相同功能或不同功能封装的集成电路所组成,包括一个或者多个中央处理器(Central Processing unit,CPU)、微处理器、数字处理芯片、图形处理器及各种控制芯片的组合等。所述第一处理器201是所述电子设备的控制核心(Control Unit),利用各种接口和线路连接整个电子设备的各个部件,通过运行或执行存储在所述第一存储器202内的程序或者模块,以及调用存储在所述第一存储器202内的数据,以执行电子设备200的各种功能和处理数据。
图6仅示出了具有部件的电子设备,本领域技术人员可以理解的是,图5示出的结构并不构成对所述电子设备200的限定,可以包括比图示更少或者更多的部件,或者组合某些部件,或者不同的部件布置。
所述电子设备200中的所述第一存储器202存储的基于预训练语言模型的提示学习小样本分类程序203是多个指令的组合,在所述第一处理器201中运行时,可以实现:
根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
使用训练好的分类模型对新样本进行分类。
进一步地,所述电子设备200集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个非易失性计算机可读取存储介质中。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-Only Memory)。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一非易失性计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
上述实施例为本发明较佳的实施方式,但本发明的实施方式并不受上述实施例的限制,其他的任何未背离本发明的精神实质与原理下所作的改变、修饰、替代、组合、简化,均应为等效的置换方式,都包含在本发明的保护范围之内。
Claims (10)
1.基于预训练语言模型的提示学习小样本分类方法,其特征在于,包括下述步骤,
根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
使用训练好的分类模型对新样本进行分类。
2.根据权利要求1所述基于预训练语言模型的提示学习小样本分类方法,其特征在于,所述数据样本包含N个类别,并对每个样本标注所属类别,且不同类别的样本数量为K个,构成K-way-N-shot的小样本数据集,总共包含K*N个训练样本。
3.根据权利要求1所述基于预训练语言模型的提示学习小样本分类方法,其特征在于,所述利用预测词概率计算校准参数,具体包括下述步骤:
将预处理后的数据样本输入到预训练语言模型中,得到被遮蔽的词对应位置的标签词logits,标签词与标签一一对应,即获得分类标签的logits;
计算缩放系数λ,以将校准后的logits缩放回模型原本输出logits的大小规模:
4.根据权利要求1所述基于预训练语言模型的提示学习小样本分类方法,其特征在于,所述分类模型中,使用校准参数计算校准后的提示学习的输出概率为:
pMLM=Softmax(Wz)
构建提示作为特征提取器,取预训练语言模型输出的遮蔽处隐层状态作为特征,并构建一个分类器f进行分类:
pCLS=Softmax(f(jmask))
并将两个输出加权融合:
p(y||xprompt)=α·pMLM+(1α)·pCLS
α为平衡因子,上述Softmax函数表示为:
xc和xj为标签的索引c和j所对应的标签输出logits;由于pMLM和pCLS两种输出都使用了hmask,两种方式共享遮蔽处隐层状态,使用了隐式的多任务学习,将该分类任务拆分成了两个子分类任务,使分类模型拥有更好的泛化性能。
6.根据权利要求1所述基于预训练语言模型的提示学习小样本分类方法,其特征在于,对分类模型进行训练包括下述步骤:
构建黑盒优化模型,具体为:
随机初始化投影矩阵A和本地待优化参数zl,l∈L,L为预训练模型层数,每一层需维护一组本地优化参数,并构建本地分类器f;
黑盒优化模型输出为:
p(y||xprompt)=α·pMLM+(1α)·pCLS
对黑盒优化模型训练,由于分类器在本地构建,其梯度完全可见,可使用梯度和进化算法交替联合优化,优化过程如下:
a)训练数据以没mini-batch的形式输入模型并计算交叉熵损失,反向传播算法更新本地分类器f的参数;
b)全部训练数据输入模型并计算交叉熵损失,使用CMA-ES算法选取最优的zl保存并用于下一轮的CMA-ES算法;
c)迭代3个epoch,b)步骤迭代一轮,即模型层数每一层l使用CMA-ES算法优化一次,构成一次完整优化训练步骤;
将c)中的训练步骤迭代若干轮次。
7.根据权利要求6所述基于预训练语言模型的提示学习小样本分类方法,其特征在于,分类器使用AdamW优化器对参数进行优化,并在迭代训练过程中每一次完整的优化训练在验证集上评估,选取验证集上精度最高的模型保存。
8.基于预训练语言模型的提示学习小样本分类系统,其特征在于,应用于权利要求1-7中任一项所述的基于同态加密和可信硬件的多方隐私保护机器学习方法,包括输入数据获取模块、模型校准和构建模块、模型训练模块以及模型分类预测模块;
所述输入数据获取模块,用于根据预先设立的提示模板预处理数据样本;所述提示模板包括输入样本和被遮蔽的词,所述预处理数据样本是指用所述提示模板包装数据样本,使得包装后的数据样本包含标签词;
所述模型校准和构建模块,用于将预处理后的数据样本输入到预训练语言模型中,获取遮蔽处的标签词预测概率,利用所述标签词预测概率计算校准参数;
所述模型训练模块,用于构建分类模型并进行训练,所述分类模型是将传统的微调方法和提示学习结合起来,构建新的模型来充分学习下游任务知识并且利用预训练学习到的知识;训练时使用交叉熵损失函数计算损失值,并利用所述损失值更新分类模型;
所述模型分类预测模块,用于使用训练好的分类模型对新样本进行分类。
9.一种电子设备,其特征在于,所述电子设备包括:
至少一个处理器;以及,
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的计算机程序指令,所述计算机程序指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行如权利要求1-7中任意一项所述的基于预训练语言模型的提示学习小样本分类方法。
10.一种计算机可读存储介质,存储有程序,其特征在于,所述程序被处理器执行时,实现权利要求1-7任一项所述的基于预训练语言模型的提示学习小样本分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310270334.XA CN116415170A (zh) | 2023-03-20 | 2023-03-20 | 基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310270334.XA CN116415170A (zh) | 2023-03-20 | 2023-03-20 | 基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116415170A true CN116415170A (zh) | 2023-07-11 |
Family
ID=87055770
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310270334.XA Pending CN116415170A (zh) | 2023-03-20 | 2023-03-20 | 基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116415170A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116610804A (zh) * | 2023-07-19 | 2023-08-18 | 深圳须弥云图空间科技有限公司 | 一种提升小样本类别识别的文本召回方法和系统 |
CN117057413A (zh) * | 2023-09-27 | 2023-11-14 | 珠高智能科技(深圳)有限公司 | 强化学习模型微调方法、装置、计算机设备及存储介质 |
CN117390497A (zh) * | 2023-12-08 | 2024-01-12 | 浙江口碑网络技术有限公司 | 基于大语言模型的类目预测方法、装置和设备 |
CN117574981A (zh) * | 2024-01-16 | 2024-02-20 | 城云科技(中国)有限公司 | 一种信息分析模型的训练方法及信息分析方法 |
-
2023
- 2023-03-20 CN CN202310270334.XA patent/CN116415170A/zh active Pending
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116610804A (zh) * | 2023-07-19 | 2023-08-18 | 深圳须弥云图空间科技有限公司 | 一种提升小样本类别识别的文本召回方法和系统 |
CN116610804B (zh) * | 2023-07-19 | 2024-01-05 | 深圳须弥云图空间科技有限公司 | 一种提升小样本类别识别的文本召回方法和系统 |
CN117057413A (zh) * | 2023-09-27 | 2023-11-14 | 珠高智能科技(深圳)有限公司 | 强化学习模型微调方法、装置、计算机设备及存储介质 |
CN117057413B (zh) * | 2023-09-27 | 2024-03-15 | 传申弘安智能(深圳)有限公司 | 强化学习模型微调方法、装置、计算机设备及存储介质 |
CN117390497A (zh) * | 2023-12-08 | 2024-01-12 | 浙江口碑网络技术有限公司 | 基于大语言模型的类目预测方法、装置和设备 |
CN117390497B (zh) * | 2023-12-08 | 2024-03-22 | 浙江口碑网络技术有限公司 | 基于大语言模型的类目预测方法、装置和设备 |
CN117574981A (zh) * | 2024-01-16 | 2024-02-20 | 城云科技(中国)有限公司 | 一种信息分析模型的训练方法及信息分析方法 |
CN117574981B (zh) * | 2024-01-16 | 2024-04-26 | 城云科技(中国)有限公司 | 一种信息分析模型的训练方法及信息分析方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116415170A (zh) | 基于预训练语言模型的提示学习小样本分类方法、系统、设备及介质 | |
US20210012199A1 (en) | Address information feature extraction method based on deep neural network model | |
CN109992779B (zh) | 一种基于cnn的情感分析方法、装置、设备及存储介质 | |
CN110807154A (zh) | 一种基于混合深度学习模型的推荐方法与系统 | |
CN110210032B (zh) | 文本处理方法及装置 | |
CN110516530A (zh) | 一种基于非对齐多视图特征增强的图像描述方法 | |
CN112000772B (zh) | 面向智能问答基于语义特征立方体的句子对语义匹配方法 | |
CN114860893B (zh) | 基于多模态数据融合与强化学习的智能决策方法及装置 | |
CN108665506A (zh) | 图像处理方法、装置、计算机存储介质及服务器 | |
Wu et al. | Centroid transformers: Learning to abstract with attention | |
CN111723914A (zh) | 一种基于卷积核预测的神经网络架构搜索方法 | |
CN112000770A (zh) | 面向智能问答的基于语义特征图的句子对语义匹配方法 | |
CN115222998B (zh) | 一种图像分类方法 | |
CN112348911A (zh) | 基于语义约束的堆叠文本生成细粒度图像方法及系统 | |
CN115861995B (zh) | 一种视觉问答方法、装置及电子设备和存储介质 | |
CN106021402A (zh) | 用于跨模态检索的多模态多类Boosting框架构建方法及装置 | |
WO2023160290A1 (zh) | 神经网络推理加速方法、目标检测方法、设备及存储介质 | |
Liu et al. | Convolutional neural networks-based locating relevant buggy code files for bug reports affected by data imbalance | |
CN116681810A (zh) | 虚拟对象动作生成方法、装置、计算机设备和存储介质 | |
CN112035689A (zh) | 一种基于视觉转语义网络的零样本图像哈希检索方法 | |
CN117197569A (zh) | 图像审核方法、图像审核模型训练方法、装置和设备 | |
CN113535902A (zh) | 一种融合对抗训练的生成式对话系统 | |
CN116661852B (zh) | 一种基于程序依赖图的代码搜索方法 | |
Xia | An overview of deep learning | |
CN111783688A (zh) | 一种基于卷积神经网络的遥感图像场景分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |