CN117057414A - 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 - Google Patents
一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 Download PDFInfo
- Publication number
- CN117057414A CN117057414A CN202311012488.5A CN202311012488A CN117057414A CN 117057414 A CN117057414 A CN 117057414A CN 202311012488 A CN202311012488 A CN 202311012488A CN 117057414 A CN117057414 A CN 117057414A
- Authority
- CN
- China
- Prior art keywords
- model
- text
- student
- generating
- sequence
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 51
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 35
- 238000009826 distribution Methods 0.000 claims abstract description 35
- 230000006870 function Effects 0.000 claims description 25
- 238000004364 calculation method Methods 0.000 claims description 8
- 238000009499 grossing Methods 0.000 claims description 8
- 230000008569 process Effects 0.000 claims description 8
- 238000012935 Averaging Methods 0.000 claims description 3
- 235000000332 black box Nutrition 0.000 claims description 3
- 238000004590 computer program Methods 0.000 claims description 3
- 230000005484 gravity Effects 0.000 claims description 3
- 238000005070 sampling Methods 0.000 claims description 3
- 238000012549 training Methods 0.000 abstract description 19
- 238000003058 natural language processing Methods 0.000 abstract description 2
- 230000001172 regenerating effect Effects 0.000 abstract 1
- 239000013598 vector Substances 0.000 description 8
- 238000004821 distillation Methods 0.000 description 6
- 238000010200 validation analysis Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 239000013604 expression vector Substances 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 239000013589 supplement Substances 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/20—Natural language analysis
- G06F40/205—Parsing
- G06F40/216—Parsing using statistical methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/30—Semantic analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0495—Quantised networks; Sparse networks; Compressed networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N7/00—Computing arrangements based on specific mathematical models
- G06N7/01—Probabilistic graphical models, e.g. probabilistic networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Algebra (AREA)
- Computational Mathematics (AREA)
- Mathematical Analysis (AREA)
- Mathematical Optimization (AREA)
- Pure & Applied Mathematics (AREA)
- Machine Translation (AREA)
Abstract
本发明一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统,涉及自然语言处理领域,为解决现有方法,无法获取模型的参数及结构、模型输出的概率分布及模型的训练数据的问题。包括如下过程对初始文本序列样本进行释义改写,再生成多个用于文本生成的prompt,与释义改写后的的序列样本相结合,得到教师模型的输出结果;生成一个用于文本生成的prompt,将初始序列样本与用于文本生成的prompt相结合输入学生模型,得到学生模型的输出结果;构建统计语言模型,分别计算教师模型和学生模型输出结果的概率分布;计算教师模型和学生模型输出结果的概率分布的差异损失,及学生模型在对应的目标文本上的损失,得到学生模型。本发明模型更具有较高的准确性。
Description
技术领域
本发明涉及自然语言处理技术领域,具体而言,涉及一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统。
背景技术
知识蒸馏致力于将知识从复杂的模型转移到较小模型,黑盒蒸馏是指在教师模型为黑盒(black-box)的情况下,即当教师模型的结构、参数不可见并且只能获得模型输出的最终结果而不是软标签时,通过知识蒸馏将知识传递给学生模型的训练方法。目前,黑盒知识蒸馏的方法主要包括:构建特殊输入样例、训练样例生成器和模拟教师输出分布。
构建特殊输入样例的核心思想是通过下游任务数据、域外数据(OOD,Out-Of-Domain data)或者按照一定规则构建的特殊数据作为输入样例,将其与教师模型的对应输出组成样例标签对,用于训练学生模型。训练样例生成器的核心思想是使用深度神经网络来生成训练和测试样例,该生成器在训练的过程中会不断地与教师模型或学生模型进行交互,从而获得针对性更强、特征更丰富的样例。模拟教师输出分布是当教师模型是一个黑盒时,只能获取教师模型的输出,而不能获取位于输出前一层的概率分布,因此就出现了许多关于模拟教师输出分布的研究。而现有的黑盒知识蒸馏方法,首先无法获取模型的参数及结构;其次,无法获取模型输出的概率分布;再者,无法获取模型的训练数据。
发明内容
本发明要解决的技术问题是:
现有的黑盒知识蒸馏方法,存在无法获取模型的参数及结构、模型输出的概率分布及模型的训练数据的问题。
本发明为解决上述技术问题所采用的技术方案:
本发明提供了一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,包括如下步骤:
S1、采用大语言模型生成多个用于释义改写的prompt,将初始文本序列样本进行释义改写,得到多个释义改写后的序列样本;
S2、采用大语言模型生成多个用于文本生成的prompt,将释义改写后的序列样本与用于文本生成的prompt相结合,输入到教师模型进行文本生成,得到教师模型输出结果;
S3、采用大语言模型生成一个用于文本生成的prompt,将初始序列样本与用于文本生成的prompt相结合输入学生模型,得到学生模型的输出结果;
S4、构建统计语言模型,采用所述统计语言模型对教师模型的多个输出结果和学生模型的一个输出结果分别建模,分别计算教师模型和学生模型输出结果的概率分布;
S5、以第一损失函数计算教师模型和学生模型输出结果的概率分布的差异损失,以第二损失函数计算学生模型在对应的目标文本上的损失,结合两个损失结果计算总损失,对学生模型的参数进行调整;
S6、重复执行S4到S5,至模型收敛或者达到预设迭代次数,得到训练后的学生模型。
进一步地,S1中采用大语言模型生成多个用于释义改写的prompt,记为其中=1,2,…;针对一个初始文本序列x=x1x2…xs,将输入序列与生成的个/>相结合,得到多个不同的序列x+pk,输入教师模型进行释义改写,得到多个释义改写后的序列样本xk;
进行全部次改写后,得到个不同的释义改写后的序列,即{x1,x2,…,xK}。
进一步地,S1中还包括使用释义判别模型对生成的个prompt两两进行相似度判断,以确保提示间语义的相似性。
进一步地,S2中采用大语言模型生成多个用于文本生成的prompt,记为其中=1,2,…K,将释义改写后的的序列样本{x1,x2,…,xK}与用于文本生成的prompt相结合,输入教师模型,得到文本生成序列样本y(k);
进行全部K次文本生成后,得到教师模型输出结果,即K个不同序列{y(1),y(2),…,y(K)}。
进一步地,S2中使用释义判别模型对个释义改写后的序列两两进行相似度判断,若基本保持语义一致,则全部送入下一轮进行文本生成;否则,对语义偏差较大的文本,重新进行释义改写,以使个释义文本间保持语义的一致性。
进一步地,S3中学生模型的文本生成过程中,模型采用贪心采样的策略,对每个位置采样时仅提取出现在当前位置的概率最大的词作为结果。
进一步地,S4中所述的统计语言模型的构建方法为:针对文本序列ω=ω1ω2…ωn,通过统计ω在整个文本语料库中出现的概率P(ω)实现机器对语言的识别,采用条件概率公式可得P(ω)为:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω1ω2)…P(ωn|ω1ω2…ωn-1)
其中,P(ωn|ω1ω2…ωn-1)表示在已知前n-1个词的前提下,第n个词ωn的出现概率;
采用基于马尔科夫假设的二元模型Bi-gram对计算公式进行简化,具体地,假设第ωn的出现概率仅与它的前一个词ωn-1有关,则:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω2)…P(ωn|ωn-1)
采用拉普拉斯平滑的方法对每个词的概率分布进行平滑处理,则概率分布为:
其中,C(ωn)为ωn在语料库中出现的次数,C(yn-1yn)为yn-1yn的bi-gram组合在语料库/>中出现的概率,/>为整个词汇表的大小;为常数,需根据具体的词汇表进行调整。
进一步地,S5中所述第一损失函数,首先采用KL散度计算教师模型和学生模型输出结果的概率分布间的差异,损失函数为:
其中,ypred_w为ypred序列中的第w个词,P(ypred_w|ypred)为词ypred_w在ypred映射到的词空间上的概率分布,P(ypred_w|y)为词ypred_w在y映射到的词空间上的概率分布,LMT为统计语言模型对教师模型的输出结果进行的建模,即:
LMT=Language Model(y1,y2,…,yn)
LMS为统计语言模型对学生模型的输出结果的建模,即:
LMS=Language Model(ypred_1,ypred_2,…,ypred_m)
将教师模型的次输出与学生模型的输出ypred依次计算KL散度后取平均,得到损失函数:
其中,|K|为调用教师模型对输入生成不同文本的次数,为语言模型对教师模型的第个输出文本的建模;
所述第二损失函数为计算学生模型在对应的目标文本上的负对数似然损失,损失函数为:
其中,n为目标序列的长度,为文本序列/>
结合两部分损失,得到总损失函数为:
LKD=(1-λ)LNLL+λLKL_avg
其中,λ是一个超参数,用于决定两类损失的比重。
一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏系统,该系统具有与上述技术方案任一项的步骤对应的程序模块,运行时执行上述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法中的步骤。
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序配置为由处理器调用时实现上述技术方案中任一项所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法的步骤。
相较于现有技术,本发明的有益效果是:
本发明一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统,引入了统计语言模型对教师模型和学生模型的输出进行建模以获取两个输出的概率分布,并使用KL散度计算分布间的差异,以及学生模型生成文本与真实目标文本之间的损失,以此作为蒸馏损失,使模型更具有较高的准确性。本发明应用大规模语言模型,生成多个具有相同释义的prompt。并将同一个输入结合不同的prompt,输入到大规模语言模型中进行改写的策略,可以实现对于同一个输入获得多个不同输出,利用这种多样性特点对教师模型的输出分布进行建模,以弥补无法获取黑盒模型真实输出概率的问题。同时,多样性的教师输出,也能为学生模型提供更丰富的文本特征,以提高知识蒸馏的效果。
附图说明
图1为本发明实施例中面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法流程图一;
图2为本发明实施例中面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法流程图二。
具体实施方式
在本发明的描述中,应当说明的是,在本发明的实施例中所提到的术语“第一”、“第二”、“第三”仅用于描述目的,并不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”、“第三”的特征可以明示或者隐含地包括一个或者多个该特征。
为使本发明的上述目的、特征和优点能够更为明显易懂,下面结合附图对本发明的具体实施例做详细的说明。
具体实施方案一:如图1和图2所示,本发明提供一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,包括如下步骤:
S1、采用大语言模型生成多个用于释义改写的prompt,将初始文本序列样本进行释义改写,得到多个释义改写后的序列样本;
S2、采用大语言模型生成多个用于文本生成的prompt,将释义改写后的序列样本与用于文本生成的prompt相结合,输入到教师模型进行文本生成,得到教师模型输出结果;
S3、采用大语言模型生成一个用于文本生成的prompt,将初始序列样本与用于文本生成的prompt相结合输入学生模型,得到学生模型的输出结果;
S4、构建统计语言模型,采用所述统计语言模型对教师模型的多个输出结果和学生模型的一个输出结果分别建模,分别计算教师模型和学生模型输出结果的概率分布;
S5、以第一损失函数计算教师模型和学生模型输出结果的概率分布的差异损失,以第二损失函数计算学生模型在对应的目标文本上的损失,结合两个损失结果计算总损失,对学生模型的参数进行调整;
S6、重复执行S4到S5,至模型收敛或者达到预设迭代次数,得到训练后的学生模型。
本实施方案中,S3中所述的学生模型对文本的编码过程为:
学生模型使用基于Transformer的预训练语言模型,在预训练阶段,给定一个文本序列ω=ω1ω…ω作为输入,模型首先会在输入层对ω中的每个词ω做词嵌入(WordEmbedding)并映射为向量:
其中,表示词ω的词嵌入(Token Embedding),/>表示词ω的位置嵌入(Position Embedding),v为第i个位置的单词ω经过学生模型输入层进行词嵌入后的输出;由于每个词在文本序列中的不同位置可能有不同的语义,并且Transformer在对词进行逐个编码时无法感知词的位置,所以此处增加位置嵌入以补充更多的位置信息。
文本序列ω=ω1ω…ω经过输入层编码为向量序列v=v1v…v,随后L个编码层对向量序列先编码后解码;编码过程中,在自注意力的机制下,每个编码层中的每个表示向量都能与之前位置中的向量相结合以获得更丰富的上下文信息;经过多层解码后,最后一个隐藏层中包含了单词层次化的组合式表示,L层Transformer的计算过程公式如下:
其中,表示第L层的表示向量序列,n为序列长度,d为模型隐藏层维度,L为模型总层数。
学生模型的文本生成过程为:
学生模型在生成文本前首先对表示向量进行解码,解码后的输出为每个位置上的条件概率,即在每个位置上每个词出现的概率,当前位置的条件概率基于第L层的隐藏状态h(L)和之前位置的预测结果;对于第一个位置,则结合句首标记<BOS>进行预测;其对应的计算公式为:
P(ωi|ω1ω2…ωi-1)=Softmax(Weh(L)+bout)
其中,为词向量矩阵,/>为词汇表大小,/>为偏置项;
模型的训练目标是最大化似然概率估计,即最小化似然概率损失,对于输入的文本序列,其对应的损失函数为:
其中,θ为模型参数;
在下游任务精调阶段采用相同的方式进行编码与解码;经过预训练的模型具备了一定的通用语言表示能力,在下游任务中根据具体的数据集和任务目标进行适配。
在下游任务精调阶段,给定一个源序列q=q1q2…qm和一个目标序列a=a1a2…an,采用提示学习的方法,在输入序列中添加提示(prompt),记为p,将提示与原输入序列q组合为形如q+p的带提示的序列作为输入。训练的目标为在给定的输入为q+p时最大化生成a的似然概率,对应的概率计算公式为:
其中,θ为模型的参数,a<t表示文本序列a1a2…at-1;在训练阶段,a<t采用的是训练目标中的序列,即a<t∈a;而在模型推理阶段,a<t均是由模型自行预测所得;在t=0的时刻,则仅依据输入序列q+p去计算输出序列第一个位置的词为a1的概率。
由于教师模型为黑盒模型,因此只调用教师模型生成结果,而不对教师模型进行训练。同样也不对大规模语言模型进行训练。通过单个输入得到多个不同的输出来捕获教师模型的输出分布规律,以弥补无法获取黑盒模型真实输出概率的问题。同时,多样性的教师输出,也能为学生模型提供更丰富的文本特征,以提高知识蒸馏的效果。学生模型的参数与结构都是可训练的,训练学生模型的目的是期望学生模型的最优输出能够在教师模型的指导下生成目标文本。
具体实施方案二:S1中采用大语言模型生成多个用于释义改写的prompt,记为其中k=1,2,…K;针对一个初始文本序列x=x1x2…xs,将输入序列与生成的个/>相结合,得到多个不同的序列x+pk,输入教师模型进行释义改写,得到多个释义改写后的序列样本xk;
进行全部次改写后,得到个不同的释义改写后的序列,即{x1,x2,…,xK}。本实施方案其它与具体实施方案一相同。
具体实施方案三:S1中还包括使用释义判别模型对生成的个prompt两两进行相似度判断,以确保提示间语义的相似性。本实施方案其它与具体实施方案二相同。
本实施方案中释义判别模型采用预训练语言模型BERT。
具体实施方案四:S2中采用大语言模型生成多个用于文本生成的prompt,记为其中k=1,2,…K,将释义改写后的的序列样本{x1,x2,…,xK}与用于文本生成的prompt相结合,输入教师模型,得到文本生成序列样本y(k);
进行全部次文本生成后,得到教师模型输出结果,即个不同序列{y(1),y(2),…,y(K)}。本实施方案其它与具体实施方案一相同。
具体实施方案五:S2中使用释义判别模型对个释义改写后的序列两两进行相似度判断,若基本保持语义一致,则全部送入下一轮进行文本生成;否则,对语义偏差较大的文本,重新进行释义改写,以使个释义文本间保持语义的一致性。本实施方案其它与具体实施方案四相同。
具体实施方案六:S3中学生模型的文本生成过程中,模型采用贪心采样的策略,对每个位置采样时仅提取出现在当前位置的概率最大的词作为结果。本实施方案其它与具体实施方案一相同。
具体实施方案七:S4中所述的统计语言模型的构建方法为:针对文本序列ω=ω1ω2…ωn,通过统计ω在整个文本语料库中出现的概率P(ω)实现机器对语言的识别,采用条件概率公式可得P(ω)为:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω1ω2)…P(ωn|ω1ω2…ωn-1)
其中,P(ωn|ω1ω2…ωn-1)表示在已知前n-1个词的前提下,第n个词ωn的出现概率;
采用基于马尔科夫假设的二元模型Bi-gram对计算公式进行简化,具体地,假设第ωn的出现概率仅与它的前一个词ωn-1有关,则:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω2)…P(ωn|ωn-1)
采用拉普拉斯平滑的方法对每个词的概率分布进行平滑处理,则概率分布为:
其中,C(ωn)为ωn在语料库中出现的次数,C(yn-1yn)为yn-1yn的bi-gram组合在语料库/>中出现的概率,/>为整个词汇表的大小;为常数,需根据具体的词汇表进行调整。本实施方案其它与具体实施方案一相同。
本实施方案中采用基于马尔科夫假设的二元模型Bi-gram对计算公式进行简化,以避免数据稀疏问题带来计算量巨大的问题。
本实施方案中统计语言模型是基于每个词n在整个语料库出现的条件概率所构造的,一旦在文本序列中出现未登录词OOV(Out Of Vocabulary),则会直接让个文本序列的概率归零或者造成数据稀疏,因此,本实施方案采用拉普拉斯平滑(LaplaceSmoothing)的方法对每个词的概率分布进行平滑处理,从而避免OOV造成的零概率以及数据稀疏的问题。
具体实施方案八:S5中所述第一损失函数,首先采用KL散度计算教师模型和学生模型输出结果的概率分布间的差异,损失函数为:
其中,ypred_w为ypred序列中的第个词,P(ypred_w|ypred)为词ypred_w在ypred映射到的词空间上的概率分布,P(ypred_w|)为词ypred_w在y映射到的词空间上的概率分布,lmT为统计语言模型对教师模型的输出结果进行的建模,即:
LMT=Language Model(y1,y2,…,yn)
LMs为统计语言模型对学生模型的输出结果的建模,即:
LMs=Language Model(ypred_1,ypred_2,…,ypred_m)
将教师模型的次输出与学生模型的输出ypred依次计算KL散度后取平均,得到损失函数:
其中,|K|为调用教师模型对输入x生成不同文本的次数,为语言模型对教师模型的第个输出文本的建模;
所述第二损失函数为计算学生模型在对应的目标文本上的负对数似然损失,损失函数为:
其中,n为目标序列的长度,为文本序列/>结合两部分损失,得到总损失函数为:
LKD=(1-λ)LNLL+λLKL_avg
其中,λ是一个超参数,用于决定两类损失的比重。本实施方案其它与具体实施方案一相同。
本实施方案中面向文本生成的基于语言模型的黑盒知识蒸馏模型的训练方法为:
具体实施方案九:一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏系统,该系统具有与上述实施方案一至八任一项的步骤对应的程序模块,运行时执行上述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法中的步骤。
具体实施方案十:一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序配置为由处理器调用时实现实施方案一至八中任一项所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法的步骤。
通过下述实施例验证本发明方法的有效性。
实施例1
数据集介绍
使用开源的Stanford Question Answering Dataset,SQuAD问答数据集构建demo,该数据集是一个阅读理解数据集,由众包工作者在一组维基百科文章上提出的问题组成。SQuAD的训练集包含87,599条数据,验证集包含10,570条数据。Demo在训练集上训练,并取验证集上的结果,在EM(exact match)和F1(F1-score)两个指标上进行比较。
模型介绍
以mT0-base为教师模型,mT0-small为学生模型验证方法的有效性。mT0模型为预训练语言模型mT5系列在多任务上精调后的变体,而mT5为T5模型的多语言变体。其中mT0-small模型包含300M参数,mT0-base包含580M参数。
实验结果
如表1所示为教师模型、学生模型、蒸馏后学生模型在SQuAD验证集上的实验结果如下:
表1
其中,mT0-small-KD为使用黑盒蒸馏算法精调的学生模型。
通过分析在demo上的实验结果可以看到,尽管模型的规模相比于大规模语言模型小很多,本发明所提出的蒸馏方法仍能够实现在教师模型为黑盒的限制条件下,仍然能够通过知识蒸馏的方法,把性能较强的教师模型的知识迁移到性能较弱的学生模型上,从而提升学生模型的性能,且在性能上超过了仅依靠学生模型单独训练时的性能,因此,证明本发明蒸馏方法有效性。
虽然本发明公开披露如上,但本发明公开的保护范围并非仅限于此。本发明领域技术人员在不脱离本发明公开的精神和范围的前提下,可进行各种变更与修改,这些变更与修改均将落入本发明的保护范围。
Claims (10)
1.一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,包括如下步骤:
S1、采用大语言模型生成多个用于释义改写的prompt,将初始文本序列样本进行释义改写,得到多个释义改写后的序列样本;
S2、采用大语言模型生成多个用于文本生成的prompt,将释义改写后的序列样本与用于文本生成的prompt相结合,输入到教师模型进行文本生成,得到教师模型输出结果;
S3、采用大语言模型生成一个用于文本生成的prompt,将初始序列样本与用于文本生成的prompt相结合输入学生模型,得到学生模型的输出结果;
S4、构建统计语言模型,采用所述统计语言模型对教师模型的多个输出结果和学生模型的一个输出结果分别建模,分别计算教师模型和学生模型输出结果的概率分布;
S5、以第一损失函数计算教师模型和学生模型输出结果的概率分布的差异损失,以第二损失函数计算学生模型在对应的目标文本上的损失,结合两个损失结果计算总损失,对学生模型的参数进行调整;
S6、重复执行S4到S5,至模型收敛或者达到预设迭代次数,得到训练后的学生模型。
2.根据权利要求1所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S1中采用大语言模型生成多个用于释义改写的prompt,记为其中k=1,2,...K;针对一个初始文本序列x=x1x2...xs,将输入序列x与生成的K个/>相结合,得到多个不同的序列x+pk,输入教师模型进行释义改写,得到多个释义改写后的序列样本xk;
进行全部K次改写后,得到K个不同的释义改写后的序列,即{x1,x2,...,xK}。
3.根据权利要求2所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S1中还包括使用释义判别模型对生成的K个prompt两两进行相似度判断,以确保提示间语义的相似性。
4.根据权利要求1所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S2中采用大语言模型生成多个用于文本生成的prompt,记为其中k=1,2,...K,将释义改写后的的序列样本{x1,x2,...,xK}与用于文本生成的prompt相结合,输入教师模型,得到文本生成序列样本y(k);
进行全部K次文本生成后,得到教师模型输出结果,即K个不同序列{y(1),y(2),...,y(K)}。
5.根据权利要求4所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S2中使用释义判别模型对K个释义改写后的序列两两进行相似度判断,若基本保持语义一致,则全部送入下一轮进行文本生成;否则,对语义偏差较大的文本,重新进行释义改写,以使K个释义文本间保持语义的一致性。
6.根据权利要求1所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S3中学生模型的文本生成过程中,模型采用贪心采样的策略,对每个位置采样时仅提取出现在当前位置的概率最大的词作为结果。
7.根据权利要求1所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S4中所述的统计语言模型的构建方法为:针对文本序列ω=ω1ω2...ωn,通过统计ω在整个文本语料库中出现的概率P(ω)实现机器对语言的识别,采用条件概率公式可得P(ω)为:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω1ω2)...P(ωn|ω1ω2...ωn-1)
其中,P(ωn|ω1ω2...ωn-1)表示在已知前n-1个词的前提下,第n个词ωn的出现概率;
采用基于马尔科夫假设的二元模型Bi-gram对计算公式进行简化,具体地,假设第ωn的出现概率仅与它的前一个词ωn-1有关,则:
P(ω)=P(ω1)P(ω2|ω1)P(ω3|ω2)...P(ωn|ωn-1)
采用拉普拉斯平滑的方法对每个词的概率分布进行平滑处理,则概率分布为:
其中,C(ωn)为ωn在语料库中出现的次数,C(yn-1yn)为yn-1yn的bi-gram组合在语料库/>中出现的概率,/>为整个词汇表的大小;k为常数,需根据具体的词汇表进行调整。
8.根据权利要求1所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法,其特征在于,S5中所述第一损失函数,首先采用KL散度计算教师模型和学生模型输出结果的概率分布间的差异,损失函数为:
其中,ypred_w为ypred序列中的第w个词,P(ypred_w|ypred)为词ypred_w在ypred映射到的词空间上的概率分布,P(ypred_w|y)为词ypred_w在y映射到的词空间上的概率分布,LMT为统计语言模型对教师模型的输出结果进行的建模,即:
LMT=Language Model(y1,y2,...,yn)
LMS为统计语言模型对学生模型的输出结果的建模,即:
LMS=Language Model(ypred_1,ypred_2,...,ypred_m)
将教师模型的K次输出与学生模型的输出ypred依次计算KL散度后取平均,得到损失函数:
其中,|K|为调用教师模型对输入x生成不同文本的次数,为语言模型对教师模型的第k个输出文本的建模;
所述第二损失函数为计算学生模型在对应的目标文本上的负对数似然损失,损失函数为:
其中,n为目标序列的长度,为文本序列/>
结合两部分损失,得到总损失函数为:
LKD=(1-λ)LNLL+λLKL_avg
其中,λ是一个超参数,用于决定两类损失的比重。
9.一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏系统,其特征在于,该系统具有与上述权利要求1~8任一项权利要求的步骤对应的程序模块,运行时执行上述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法中的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序配置为由处理器调用时实现权利要求1~8中任一项所述的面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311012488.5A CN117057414B (zh) | 2023-08-11 | 2023-08-11 | 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311012488.5A CN117057414B (zh) | 2023-08-11 | 2023-08-11 | 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117057414A true CN117057414A (zh) | 2023-11-14 |
CN117057414B CN117057414B (zh) | 2024-06-07 |
Family
ID=88667115
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311012488.5A Active CN117057414B (zh) | 2023-08-11 | 2023-08-11 | 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117057414B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521799A (zh) * | 2024-01-08 | 2024-02-06 | 徐州医科大学 | 一种基于提示学习的个性化知识图谱动态生成方法 |
Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190080688A1 (en) * | 2015-10-09 | 2019-03-14 | Mitsubishi Electric Corporation | Language model generating device, language model generating method, and recording medium |
WO2021243473A1 (en) * | 2020-06-05 | 2021-12-09 | Huawei Technologies Co., Ltd. | Improved knowledge distillation by utilizing backward pass knowledge in neural networks |
CN114254100A (zh) * | 2021-12-15 | 2022-03-29 | 科大讯飞股份有限公司 | 输入推荐方法、装置、电子设备和存储介质 |
CN114611670A (zh) * | 2022-03-15 | 2022-06-10 | 重庆理工大学 | 一种基于师生协同的知识蒸馏方法 |
CN114627331A (zh) * | 2022-03-07 | 2022-06-14 | 北京沃东天骏信息技术有限公司 | 模型训练方法和装置 |
CN114818891A (zh) * | 2022-04-14 | 2022-07-29 | 人民网股份有限公司 | 小样本多标签文本分类模型训练方法及文本分类方法 |
CN114925699A (zh) * | 2022-04-28 | 2022-08-19 | 电子科技大学 | 一种基于风格变换的高迁移性对抗文本生成方法 |
CN115114974A (zh) * | 2022-05-18 | 2022-09-27 | 腾讯科技(深圳)有限公司 | 一种模型蒸馏方法、装置、计算机设备和存储介质 |
CN115526332A (zh) * | 2022-08-17 | 2022-12-27 | 阿里巴巴(中国)有限公司 | 基于预训练语言模型的学生模型训练方法和文本分类系统 |
US20230031512A1 (en) * | 2020-10-14 | 2023-02-02 | Feedzai - Consultadoria E Inovação Tecnológica, S.A. | Surrogate hierarchical machine-learning model to provide concept explanations for a machine-learning classifier |
CN115964999A (zh) * | 2023-01-10 | 2023-04-14 | 阿里巴巴(中国)有限公司 | 模型训练和文本生成方法、装置、电子设备和存储介质 |
CN116186200A (zh) * | 2023-01-19 | 2023-05-30 | 北京百度网讯科技有限公司 | 模型训练方法、装置、电子设备和存储介质 |
CN116306868A (zh) * | 2023-03-01 | 2023-06-23 | 支付宝(杭州)信息技术有限公司 | 一种模型的处理方法、装置及设备 |
-
2023
- 2023-08-11 CN CN202311012488.5A patent/CN117057414B/zh active Active
Patent Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190080688A1 (en) * | 2015-10-09 | 2019-03-14 | Mitsubishi Electric Corporation | Language model generating device, language model generating method, and recording medium |
WO2021243473A1 (en) * | 2020-06-05 | 2021-12-09 | Huawei Technologies Co., Ltd. | Improved knowledge distillation by utilizing backward pass knowledge in neural networks |
US20230031512A1 (en) * | 2020-10-14 | 2023-02-02 | Feedzai - Consultadoria E Inovação Tecnológica, S.A. | Surrogate hierarchical machine-learning model to provide concept explanations for a machine-learning classifier |
CN114254100A (zh) * | 2021-12-15 | 2022-03-29 | 科大讯飞股份有限公司 | 输入推荐方法、装置、电子设备和存储介质 |
CN114627331A (zh) * | 2022-03-07 | 2022-06-14 | 北京沃东天骏信息技术有限公司 | 模型训练方法和装置 |
CN114611670A (zh) * | 2022-03-15 | 2022-06-10 | 重庆理工大学 | 一种基于师生协同的知识蒸馏方法 |
CN114818891A (zh) * | 2022-04-14 | 2022-07-29 | 人民网股份有限公司 | 小样本多标签文本分类模型训练方法及文本分类方法 |
CN114925699A (zh) * | 2022-04-28 | 2022-08-19 | 电子科技大学 | 一种基于风格变换的高迁移性对抗文本生成方法 |
CN115114974A (zh) * | 2022-05-18 | 2022-09-27 | 腾讯科技(深圳)有限公司 | 一种模型蒸馏方法、装置、计算机设备和存储介质 |
CN115526332A (zh) * | 2022-08-17 | 2022-12-27 | 阿里巴巴(中国)有限公司 | 基于预训练语言模型的学生模型训练方法和文本分类系统 |
CN115964999A (zh) * | 2023-01-10 | 2023-04-14 | 阿里巴巴(中国)有限公司 | 模型训练和文本生成方法、装置、电子设备和存储介质 |
CN116186200A (zh) * | 2023-01-19 | 2023-05-30 | 北京百度网讯科技有限公司 | 模型训练方法、装置、电子设备和存储介质 |
CN116306868A (zh) * | 2023-03-01 | 2023-06-23 | 支付宝(杭州)信息技术有限公司 | 一种模型的处理方法、装置及设备 |
Non-Patent Citations (3)
Title |
---|
CHUHAN WU等: "One Teacher is Enough? Pre-trained Language Model Distillation from Multiple Teachers", ARXIV:2106.01023, 2 June 2021 (2021-06-02) * |
DANG NGUYEN等: "Black-box Few-shot Knowledge Distillation", ARXIV:2207.12106, 25 July 2022 (2022-07-25) * |
张一珂;张鹏远;颜永红;: "基于对抗训练策略的语言模型数据增强技术", 自动化学报, no. 05, 18 April 2018 (2018-04-18) * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521799A (zh) * | 2024-01-08 | 2024-02-06 | 徐州医科大学 | 一种基于提示学习的个性化知识图谱动态生成方法 |
CN117521799B (zh) * | 2024-01-08 | 2024-03-08 | 徐州医科大学 | 一种基于提示学习的个性化知识图谱动态生成方法 |
Also Published As
Publication number | Publication date |
---|---|
CN117057414B (zh) | 2024-06-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Prakash et al. | Neural paraphrase generation with stacked residual LSTM networks | |
CN111078836B (zh) | 基于外部知识增强的机器阅读理解方法、系统、装置 | |
CN110210032B (zh) | 文本处理方法及装置 | |
CN112800203B (zh) | 一种融合文本和知识表征的问答匹配方法及系统 | |
CN111191002B (zh) | 一种基于分层嵌入的神经代码搜索方法及装置 | |
Chen et al. | Delving deeper into the decoder for video captioning | |
CN117057414B (zh) | 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统 | |
CN110807335A (zh) | 基于机器学习的翻译方法、装置、设备及存储介质 | |
CN118093834B (zh) | 一种基于aigc大模型的语言处理问答系统及方法 | |
CN113704393A (zh) | 关键词提取方法、装置、设备及介质 | |
CN117648950A (zh) | 神经网络模型的训练方法、装置、电子设备及存储介质 | |
CN111444328A (zh) | 一种带有解释生成的自然语言自动预测推断方法 | |
Zhou et al. | Scalable prompt generation for semi-supervised learning with language models | |
CN112732879B (zh) | 一种问答任务的下游任务处理方法及模型 | |
CN112364659B (zh) | 一种无监督的语义表示自动识别方法及装置 | |
Han et al. | Generative adversarial networks for open information extraction | |
CN116595189A (zh) | 基于两阶段的零样本关系三元组抽取方法及系统 | |
Rao | Are you asking the right questions? Teaching Machines to Ask Clarification Questions | |
CN115357712A (zh) | 方面级情感分析方法、装置、电子设备及存储介质 | |
CN116450783A (zh) | 面向篇章级的事件抽取方法、系统、存储介质和电子设备 | |
CN114896973A (zh) | 一种文本处理方法、装置及电子设备 | |
Dasgupta et al. | A Review of Generative AI from Historical Perspectives | |
CN114239555A (zh) | 一种关键词提取模型的训练方法及相关装置 | |
CN114610852B (zh) | 一种基于课程学习的细粒度中文句法分析方法及装置 | |
US20240256964A1 (en) | Pretraining Already-Pretrained Models for Diverse Downstream Tasks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
CP03 | Change of name, title or address |
Address after: No.18, Jiangwan 1st Road, Foshan, Guangdong 528011 Patentee after: Foshan University Country or region after: China Address before: No.18, Jiangwan 1st Road, Foshan, Guangdong 528011 Patentee before: FOSHAN University Country or region before: China |
|
CP03 | Change of name, title or address |