CN116628147A - 训练文本预测模型的方法、文本预测方法及装置 - Google Patents

训练文本预测模型的方法、文本预测方法及装置 Download PDF

Info

Publication number
CN116628147A
CN116628147A CN202310459343.3A CN202310459343A CN116628147A CN 116628147 A CN116628147 A CN 116628147A CN 202310459343 A CN202310459343 A CN 202310459343A CN 116628147 A CN116628147 A CN 116628147A
Authority
CN
China
Prior art keywords
matrix
network
text
value matrix
key matrix
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310459343.3A
Other languages
English (en)
Inventor
惠彬原
杨家玺
黎槟华
黄非
李永彬
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Alibaba Damo Institute Hangzhou Technology Co Ltd
Original Assignee
Alibaba Damo Institute Hangzhou Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Alibaba Damo Institute Hangzhou Technology Co Ltd filed Critical Alibaba Damo Institute Hangzhou Technology Co Ltd
Priority to CN202310459343.3A priority Critical patent/CN116628147A/zh
Publication of CN116628147A publication Critical patent/CN116628147A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying
    • G06F16/332Query formulation
    • G06F16/3329Natural language query formulation or dialogue systems
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/31Indexing; Data structures therefor; Storage structures
    • G06F16/313Selection or weighting of terms for indexing
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/30Semantic analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y04INFORMATION OR COMMUNICATION TECHNOLOGIES HAVING AN IMPACT ON OTHER TECHNOLOGY AREAS
    • Y04SSYSTEMS INTEGRATING TECHNOLOGIES RELATED TO POWER NETWORK OPERATION, COMMUNICATION OR INFORMATION TECHNOLOGIES FOR IMPROVING THE ELECTRICAL POWER GENERATION, TRANSMISSION, DISTRIBUTION, MANAGEMENT OR USAGE, i.e. SMART GRIDS
    • Y04S10/00Systems supporting electrical power generation, transmission or distribution
    • Y04S10/50Systems or methods supporting the power network operation or management, involving a certain degree of interaction with the load-side end user applications

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Human Computer Interaction (AREA)
  • Machine Translation (AREA)

Abstract

本申请实施例公开了一种训练文本预测模型的方法、文本预测方法及装置。本申请利用训练样本集对大型语言模型进行训练来得到文本预测模型,这种方式实质上利用了已标注的样本训练大型语言模型(LLM),在第二键矩阵和第二值矩阵的更新过程中,利用了上一轮迭代得到的第二键矩阵和第二值矩阵与当前输入特征矩阵产生的第一键矩阵和第一值矩阵,既保留了历史信息又保持了当前输入文本的信息,使得大型语言模型能够充分对已标注的样本进行理解和学习,从而提高情景学习场景下基于大型语言模型的文本预测效果。并且这种前向优化模型的方式,大大缩减了需要更新的模型参数,降低了模型训练的成本,提高了效率。

Description

训练文本预测模型的方法、文本预测方法及装置
技术领域
本申请涉及自然语言处理技术领域,特别是涉及一种训练文本预测模型的方法、文本预测方法及装置。
背景技术
NLP(Natural Language Processing,自然语言处理)研究的目标是让机器能够理解人类语言。其中LLM(Large Language Model,大型语言模型)是自然语言处理领域中的一个核心工具,指的是具有大规模参数(通常数以亿计或更多)的深度学习模型。
LLM因其具有极高的学习能力而被广泛地应用于文本预测领域。其中,情景学习(In Context Learning)是目前LLM采用的其中一种文本预测方式。所谓情景学习指的是,给定标注数据后,LLM进行观察和归纳,对无标签数据进行预测。由于情景学习通常没有训练过程,应用于文本预测时,预测效果不佳。因此亟需一种方式能够提高情景学习场景下基于LLM的文本预测效果。
发明内容
有鉴于此,本申请提供了一种训练文本预测模型的方法、文本预测方法及装置,以便于提高情景学习场景下基于LLM的文本预测效果。
本申请提供了如下方案:
第一方面,提供了一种训练文本预测模型的方法,所述方法包括:
获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签;
将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
根据本申请实施例中一可实现的方式,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
所述嵌入网络用以对所述文本序列进行嵌入处理;
若当前层Transformer网络为第一层Transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层Transformer网络输出的特征表示;
所述预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本样本对应的输出标签。
根据本申请实施例中一可实现的方式,利用所述第一键矩阵对所述第二键矩阵进行更新包括:利用所述第一键矩阵对所述当前层Transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新;
利用所述第一值矩阵对所述第二值矩阵进行更新包括:利用所述第一值矩阵对所述当前层Transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新。
根据本申请实施例中一可实现的方式,利用所述第一键矩阵对所述当前层Transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新包括:利用所述第一键矩阵和所述当前层Transformer网络在上一轮迭代得到的第二键矩阵进行逐元素求差,得到键矩阵梯度;利用所述当前层Transformer网络在上一轮迭代得到的键矩阵动量和所述键矩阵梯度进行加权求和,得到所述当前层Transformer网络在当前轮迭代得到的键矩阵动量;利用所述当前层Transformer网络在当前轮迭代得到的键矩阵动量和当前层Transformer网络在上一轮迭代得到的第二键矩阵,得到所述更新后的第二键矩阵;
利用所述第一值矩阵对所述当前层Transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新包括:利用所述第一值矩阵和所述当前层Transformer网络在上一轮迭代得到的第二值矩阵进行逐元素求差,得到值矩阵梯度;利用所述当前层Transformer网络在上一轮迭代得到的值矩阵动量和所述值矩阵梯度进行加权求和,得到所述当前层Transformer网络在当前轮迭代得到的值矩阵动量;利用所述当前层Transformer网络在当前轮迭代得到的值矩阵动量和当前层Transformer网络在上一轮迭代得到的第二值矩阵,得到所述更新后的第二值矩阵。
根据本申请实施例中一可实现的方式,每一轮迭代完成后,若确定达到预设的训练结束条件,则将各层Transformer网络当前迭代得到的第二键矩阵和第二值矩阵分别作为训练得到的各层Transformer网络的第二键矩阵和第二值矩阵进行存储。
第二方面,提供了一种文本预测方法,所述方法包括:
获取输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
根据本申请实施例中一可实现的方式,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
所述嵌入网络用以对所述文本序列进行嵌入处理;
若当前层Transformer网络为第一层Transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层Transformer网络输出的特征表示;
所述预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本对应的输出标签。
第三方面,提供了一种训练文本预测模型的方法,所述方法包括:
获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的情感类别标签;
将包含输入文本样本和该输入文本样本对应的情感类别标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的Transformer网络;
在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
第四方面,提供了一种情感分析方法,所述方法包括:
获取输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的情感类别标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
第五方面,提供了一种文本预测方法,由云端服务器执行,所述方法包括:
获取来自用户终端的输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;
基于所述输出标签确定对应的服务内容,将所述服务内容发送至所述用户终端;
其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
第六方面,提供了一种训练文本预测模型的装置,所述装置包括:
样本获取单元,被配置为获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签样本;
模型训练单元,被配置为将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
第七方面,提供了一种文本预测装置,所述装置包括:
文本获取单元,被配置为获取输入文本;
文本预测单元,被配置为将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;其中,各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
根据第八方面,提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述第一方面中任一项所述的方法的步骤。
根据第九方面,提供了一种电子设备,包括:
一个或多个处理器;以及
与所述一个或多个处理器关联的存储器,所述存储器用于存储程序指令,所述程序指令在被所述一个或多个处理器读取执行时,执行上述第一方面中任一项所述的方法的步骤。
根据本申请提供的具体实施例,本申请公开了以下技术效果:
1)本申请利用训练样本集对LLM进行训练来得到文本预测模型,这种方式实质上利用了已标注的样本训练LLM,在第二键矩阵和第二值矩阵的更新过程中,利用了上一轮迭代得到的第二键矩阵和第二值矩阵与当前输入特征矩阵产生的第一键矩阵和第一值矩阵,既保留了历史信息又保持了当前输入文本的信息,使得LLM能够充分对已标注的样本进行理解和学习,从而提高LLM的文本预测效果。
2)本申请在LLM的训练过程中,每一轮迭代过程中使用的第二键矩阵和第二值矩阵均由上一轮迭代得到,且进行更新后用于下一轮迭代使用。这种前向优化模型的方式,仅需要优化各Transformer网络的第二键矩阵和第二值矩阵即可,大大缩减了需要更新的模型参数,降低了模型训练的成本,提高了效率。
3)本申请实施例中采用动量梯度下降的方式更新各Transformer网络的第二键矩阵和第二值矩阵,能够加快梯度下降的速度,使得迭代效率更高,并且避免陷入局部最小值。
4)本申请在预测过程中使用了文本预测模型训练后得到的各Transformer网络的第二键矩阵和第二值矩阵,与普通的情景学习相比,各Transformer网络的第二键矩阵和第二值矩阵包含了文本预测模型对已标注的样本数据更好的观察和理解,能够显著提高预测准确性。
当然,实施本申请的任一产品并不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为传统基于LLM进行情景学习的原理示意图;
图2为本申请实施例所适用的系统架构图;
图3为本申请实施例提供的训练文本预测模型的方法的流程图;
图4为本申请实施例提供的文本预测模型的结构性示意图;
图5为本申请实施例提供的一个Transformer网络的原理性示意图;
图6为本申请实施例提供的文本预测方法流程图;
图7为本申请实施例提供的基于LLM进行情景学习的原理示意图;
图8为本申请实施例提供的模型训练装置的示意性框图;
图9为本申请实施例提供的文本预测装置的示意性框图;
图10为本申请实施例提供的电子设备的示意性框图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员所获得的所有其他实施例,都属于本申请保护的范围。
在本发明实施例中使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本发明。在本发明实施例和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。
应当理解,本文中使用的术语“和/或”仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。
取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”或“响应于检测”。类似地,取决于语境,短语“如果确定”或“如果检测(陈述的条件或事件)”可以被解释成为“当确定时”或“响应于确定”或“当检测(陈述的条件或事件)时”或“响应于检测(陈述的条件或事件)”。
传统基于LLM进行的情景学习过程如图1中所示,将已标注的样本数据和待预测的输入文本一起输入LLM,由LLM对已标注的样本数据进行观察和归纳,从而对输入文本对应的输出标签进行预测。该情景学习过程没有训练过程,文本预测效果不佳。
若采用传统反向传播的训练方式预先基于LLM训练文本预测模型,即利用已标注的样本数据训练LLM,通过损失函数产生梯度进行模型参数的更新。由于LLM的模型参数量巨大,传统采用反向传播的训练方式会造成成本极高且效率低下的问题。
有鉴于此,本申请提出了一种全新的思路,采用前向训练的方式预先基于LLM训练文本预测模型。为了方便对本申请的理解,首先对本申请所基于的系统架构进行简单描述。图2示出了可以应用本申请实施例的示例性系统架构。如图2中所示,该系统架构包括采用离线方式训练文本预测模型的模型训练装置,以及在线对输入文本进行预测的文本预测装置。
其中,模型训练装置在获取训练数据集后,可以采用本申请实施例提供的方法进行模型训练,得到文本预测模型。
文本预测装置利用已经训练得到的文本预测模型,对输入文本进行预测,得到的预测结果为输入文本对应的输出标签。例如对输入文本进行预测,得到诸如情感、意图等类别。
模型训练装置和文本预测装置可以分别设置为独立的服务器,也可以设置于同一个服务器或服务器群组,还可以设置于独立的或者同一云服务器。云服务器又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决传统物理主机与虚拟专用服务器(VPS,Virtual Private Server)服务中存在的管理难度大,服务扩展性弱的缺陷。模型训练装置和文本预测装置还可以设置于具有较强计算能力的计算机终端。
另外,文本预测装置除了在线进行文本预测之外,也可以离线进行文本预测。例如对数据库中的文本分别作为输入文本进行批量的文本预测,并将预测结果存储于数据库中供后续查询或调用。
应该理解,图2中的模型训练装置、文本预测装置以及文本预测模型的数目仅仅是示意性的。根据实现需要,可以具有任意数目的模型训练装置、文本预测装置以及文本预测模型。
图3为本申请实施例提供的训练文本预测模型的方法的流程图,该方法可以由图2所示系统中的模型训练装置执行。如图3中所示,该方法可以包括以下步骤:
步骤302:获取训练数据集,训练数据集包括输入文本样本以及该输入文本样本对应的输出标签样本。
步骤304:将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练文本预测模型;其中,文本预测模型采用LLM,LLM包括多层串连的Transformer(转换)网络。在训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用第一键矩阵对上述第二键矩阵进行更新,将更新后的第二键矩阵作为当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用第一值矩阵对上述第二值矩阵进行更新,将更新后的第二值矩阵作为当前层Transformer网络在当前轮迭代得到的第二值矩阵。
由上述流程可以看出,本申请利用训练样本集对LLM进行训练来得到文本预测模型,这种方式实质上利用了已标注的样本训练LLM,在第二键矩阵和第二值矩阵的更新过程中,利用了上一轮迭代得到的第二键矩阵和第二值矩阵与当前输入特征矩阵产生的第一键矩阵和第一值矩阵,既保留了历史信息又保持了当前输入文本的信息,使得LLM能够充分对已标注的样本进行理解和学习,从而提高LLM的文本预测效果。
并且,在LLM的训练过程中,每一轮迭代过程中使用的第二键矩阵和第二值矩阵均由上一轮迭代得到,且进行更新后用于下一轮迭代使用。这种前向优化模型的方式,仅需要优化各Transformer网络的第二键矩阵和第二值矩阵即可,大大缩减了需要更新的模型参数,降低了模型训练的成本,提高了效率。
需要说明的是,本公开中涉及的“第一”、“第二”等限定并不具备大小、顺序和数量等方面的限制,仅仅用以在名称上加以区分。例如“第一键矩阵”、“第二键矩阵”和“第三键矩阵”用以在名称上区分三个键矩阵。再例如“第一值矩阵”、“第二值矩阵”和“第三值矩阵”用以在名称上区分三个值矩阵。
下面分别对上述流程中的各步骤进行详细描述。首先结合实施例对上述步骤302即“获取训练数据集”进行详细描述。
本步骤中获取的训练数据集包括多个训练样本。训练样本包括输入文本样本以及该输入文本样本对应的输出标签样本,每个训练样本可以表示为(xi,yi),其中xi为第i个输入文本样本,yi为xi对应的输出标签样本,即为xi标注的输出标签。
上述输入文本样本可以是句子、段落、文章、短语等等。输出标签在不同应用场景下可以对应不同的内容,可以是类别标签,也可以是文本标签。例如,在情感分析场景下,输出标签可以是情感类型标签,即对输入文本样本标注的情感类型。再例如,在意图识别场景下,输出标签可以是意图类型标签,即对输入文本样本标注的意图类型。再例如,在自动问答场景下,输出标签可以是答案文本标签,即将输入文本样本作为问题,针对问题标注的答案文本。
上述输出标签样本可以采用人工的方式进行标注,也可以采用已有的挖掘方式自动进行标注,本申请对此不做限制。
下面结合实施例重点对上述步骤304即“将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练文本预测模型”进行详细描述。
本申请实施例中文本预测模型基于LLM实现,LLM为预训练的大型语言模型,本申请实施例在LLM的基础上进一步进行模型优化,以使得LLM能够快速理解已标注的样本集训练数据集。为了充分利用LLM已有的语言理解能力,在进行训练的过程中,利用各输入文本样本和该输入文本样本对应的输出标签分别构建文本序列,例如将输入文本样本xi与对应的输出标签yi进行拼接得到文本序列Ti输入文本预测模型即LLM。在拼接xi和yi时,可以增加对xi和yi的指示信息。例如,输入文本样本“食物太好吃了”,输出标签“积极”,可以构成文本序列“评价:食物太好吃了。情感:积极”。其中“评价”和“情感”是对输入文本样本和输出标签的指示信息。
本申请实施例所采用的文本预测模型的整体框架可以如图4中所示,主要包括特征提取网络和预测网络。
其中特征提取网络用以从输入的文本序列中提取特征表示。预测网络用以利用特征提取网络提取的特征表示预测输入文本样本的输出标签。除了该框架之外,也可以采用其他诸如编码器-解码器框架等,在此不做一一列举,但框架均基于LLM实现,例如特征提取网络基于LLM实现,编码器基于LLM实现等等,本申请实施例中仅以图4所示框架为例。
鉴于本申请的文本预测模型基于LLM实现,因此特征提取网络可以具体包括嵌入网络和多层串连的Transformer(转换)网络。
嵌入网络用以对输入的文本序列进行Embedding(嵌入)处理。更具体地,可以对输入的文本序列的各Token(元素)进行Embedding处理。文本序列的各Token指的是构成文本序列的元素。对于文本序列而言,将文本序列切分为字符或者词语序列,则文本序列中的字符或者词语、以及起始符、分隔符均为Token。
上述基于Token的Embedding处理至少包括:词Embedding和位置Embedding。词Embedding,即将各Token进行词向量编码,得到词向量表示。位置Embedding,即将各Token在待预测文本序列中的位置进行编码,得到位置的表示。
Transformer网络是一个采用自注意力机制对输入的各Token进行编码以转换为特征表示的模型。本申请实施例提供的前向学习的方式,仅需要对Transformer中的自注意力部分进行修改。
图5为本申请实施例提供的一个Transformer网络的原理性示意图,如图5中所示,每一个Transformer网络都包含“横向”对特征表示的处理以及“纵向”对模型参数的更新处理。需要说明的是,这里的“横向”和“纵向”是图中所示出的方向,仅为了方便描述而使用的名称,并不具有真实的方向含义。
如图5中的“横向”处理过程中,当前Transformer网络首先利用上一层网络输出的特征表示确定第一键矩阵KX、第一值矩阵VX和第一查询矩阵Q。其中,若当前Transformer网络为第一层Transformer网络,则上述/>为嵌入网络输出的特征表示。若当前Transformer网络并非第一层Transformer网络,则上述/>为上一层Transformer网络输出的特征表示。图中以当前层Transformer网络为第l层Transformer网络为例。
图中的下标t表示文本预测模型训练过程中的第t轮迭代。每一轮迭代可以从训练数据集中采集一个batch(批)的训练样本用来输入文本预测模型进行训练,该部分与传统模型训练类似,不做详述。
上述第一键矩阵KX、第一值矩阵VX和第一查询矩阵QX可以采用如下公式确定:
其中,WQ、Wk和Wv为权重矩阵,是LLM在预训练过程中已经学习到的参数。
然后将第一键矩阵KX与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵K,以及将第一值矩阵VX与当前层Transformer网络在上一轮迭代得到的第二值矩阵/>进行拼接得到第三值矩阵V。
其中,K和V可以表示为:
其中,||表示拼接的处理。
各Transformer网络的第二键矩阵和第二值矩阵是前向训练过程中需要学习的模型参数,在每一轮迭代中不断更新,每一个Transformer网络均对应有第二键矩阵和第二值矩阵。
再利用第三键矩阵K、第三值矩阵V和第一查询矩阵Q进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示其中自注意力机制的处理是注意力模块和FFN模块的处理,在此不做详述,可以表示为:
其中,Transformer()为注意力模块和FFN模块的处理函数。
上述“横向”处理过程中,将第一键矩阵KX与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接,以及将第一值矩阵VX与当前层Transformer网络在上一轮迭代得到的第二值矩阵/>进行拼接,可以看做是混合了“过去与现在”的信息从而得到第三键矩阵K、第三值矩阵V,符合情景学习的本质。
对于“纵向”处理过程,首先利用第一键矩阵KX对第二键矩阵进行更新,将更新后的第二键矩阵/>作为当前层Transformer网络在当前轮迭代得到的第二键矩阵,以供下一轮迭代使用。利用第一值矩阵VX对第二值矩阵/>进行更新,将更新后的第二值矩阵/>作为当前层Transformer网络在当前轮迭代得到的第二值矩阵,以供下一轮迭代使用。可以看出,该参数的更新过程是一个前向更新过程。
作为其中一种可实现的方式,上述第二键矩阵和第二值矩阵的更新可以采用动量梯度下降的方式。所谓动量梯度下降法就是在梯度下降中,每一次迭代中参数更新的方向与本次迭代采用的动量的方向相反。也可以认为每一次迭代中参数更新的方向是当前梯度的反方向与上一次参数更新方向加权组合而成。也就是说,参数更新的方向不只取决于当前梯度的方向,还取决于过去的参数更新方向。
以更新第二键矩阵为例,可以首先利用第一键矩阵KX和当前层Transformer网络在上一轮迭代得到的第二键矩阵进行逐元素求差,得到键矩阵梯度/>再利用当前层Transformer网络在上一轮迭代得到的键矩阵动量/>和键矩阵梯度/>进行加权求和,得到当前层Transformer网络在当前轮迭代得到的键矩阵动量/>然后利用当前层Transformer网络在当前轮迭代得到的键矩阵动量/>和当前层Transformer网络在上一轮迭代得到的第二键矩阵/>得到更新后的第二键矩阵/>整个过程可以如下公式所示:
其中,β和η分别为动量权重和更新步长,均为预设的超参数,可以采用经验值或者实验值,例如β取0.9,η取0.01。
第二值矩阵的更新过程为:首先利用第一值矩阵VX和当前层Transformer网络在上一轮迭代得到的第二值矩阵进行逐元素求差,得到值矩阵梯度/>再利用当前层Transformer网络在上一轮迭代得到的值矩阵动量/>和值矩阵梯度/>进行加权求和,得到当前层Transformer网络在当前轮迭代得到的值矩阵动量/>然后利用当前层Transformer网络在当前轮迭代得到的值矩阵动量/>和当前层Transformer网络在上一轮迭代得到的第二值矩阵/>得到更新后的第二值矩阵/>整个过程可以如下公式所示:
本申请实施例中在进行模型训练时,输入为输入文本样本和输出标签,Transformer网络通过自注意力处理来学习输入文本样本和输出标签之间的关系,上述过程中的各键矩阵和值矩阵是表达这些关系的关键载体,上述的键矩阵梯度和值矩阵梯度可以看做是两轮迭代学习到的输入文本样本和输出标签之间的关系差异,通过不断前向更新,使得得到的关系趋于一致。
上述更新过程实际上是将上一轮迭代得到的第二键矩阵和第二值矩阵与当前输入特征矩阵产生的第一键矩阵和第一值矩阵进行信息“混合”而产生更新后的第二键矩阵和第二值矩阵,即保留了历史信息又保持了当前输入文本的信息,实现了参数的平滑更新,提高了模型的文本预测能力。
从上述实施例中可以看出,整个模型的更新过程中仅仅涉及到第二键矩阵和第二值矩阵这两个主要参数的更新,而对于模型基本网络的参数,例如嵌入网络、自注意力模块、FFN模块和预测网络等涉及的参数均无需进行更新,大大缩减了需要更新的参数量,降低了计算成本,提高了效率。
每一轮迭代完成后,可以判断是否满足预设的训练结束条件,如果是,则停止迭代,训练结束;否则进行下一轮迭代。训练结束后,将各层Transformer网络当前迭代得到的第二键矩阵和第二值矩阵分别作为训练得到的各层Transformer网络的第二键矩阵和第二值矩阵/>进行存储。/>和/>用于进行文本预测过程中使用。
其中,训练结束条件可以包括但不限于:在验证集上的准确率达到预设准确率阈值,迭代次数达到预设的迭代次数阈值,等等。
本申请实施例中采用动量梯度下降的方式更新各Transformer网络的第二键矩阵和第二值矩阵,能够加快梯度下降的速度,使得迭代效率更高,并且避免陷入局部最小值。另外,除了上述动量梯度下降的方式之外,也可以采用其他的梯度下降方式,在此不做一一列举。
图6为本申请实施例提供的文本预测方法流程图,该方法可以由图2所示系统架构中的文本预测装置执行。如图6中所示,该方法可以包括以下步骤:
步骤602:获取输入文本。
本步骤中涉及的输入文本是没有标注对应输出标签的文本,可以是句子、段落、文章、短语等等。
步骤604:将包含输入文本的文本序列输入文本预测模型,获取文本预测模型预测得到的输入文本对应的输出标签;文本预测模型采用大型语言模型,大型语言模型包括多层串连的转换Transformer网络。其中,各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
可以仅将输入文本作为文本序列。也可以将输入文本以及输入文本的指示信息构成文本序列。还可以将输入文本与预设的提示文本进行拼接,构成文本序列。其中提示文本可以根据实际场景的需求预先设置,例如在进行情感识别时,可以采用诸如“该句子所表达的情感为【Mask】”。其中【Mask】为掩码部分,文本预测模型通过对文本序列的编码和预测,实际上是实现对【mask】部分内容的预测。
文本预测模型的结构可以参见之前实施例中的相关记载,若采用图4所示框架,则将文本序列输入嵌入网络,嵌入网络用以对输入的文本序列进行Embedding处理。然后经过串连的多个Transformer网络进行编码,得到文本序列的特征表示。最后由预测网络利用文本序列的特征表示预测输入文本对应的输出标签。
其中,每一个Transformer网络的处理过程可以如图7中所示,若上一层网络输出的特征表示为则首先利用/>确定第一键矩阵KX、第一值矩阵VX和第一查询矩阵Q。其中,若当前Transformer网络为第一层Transformer网络,则上述/>为嵌入网络输出的特征表示。若当前Transformer网络并非第一层Transformer网络,则上述/>为上一层Transformer网络输出的特征表示。图中以当前层Transformer网络为第l层Transformer网络为例。
上述第一键矩阵KX、第一值矩阵VX和第一查询矩阵QX可以采用如下公式确定:
然后将第一键矩阵KX与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵K,以及将第一值矩阵VX与预先训练得到的当前层Transformer网络的第二值矩阵/>进行拼接得到第三值矩阵V。
其中,K和V可以表示为:
再利用第三键矩阵K、第三值矩阵V和第一查询矩阵Q进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示其中自注意力机制的处理是注意力模块和FFN模块的处理,在此不做详述,可以表示为:
从上述预测过程可以看出,该预测过程使用了在模型训练后得到的和/>与普通的情景学习相比,/>和/>包含了文本预测模型对已标注的样本数据更好的观察和理解,能够显著提高预测准确性。
从上述训练和预测过程可以看出,本申请实施例采用的情景学习的原理如图7中所示,先利用训练样本集中的样本数据训练LLM,最终得到所有Transformer网络的第二键矩阵KT和第二值矩阵VT。在预测过程中,利用已经训练得到的LLM对输入文本进行预测,预测过程中各Transformer需要使用预先训练得到的KT和VT进行自注意力处理,具体过程参见之前实施例中的相关记载。与图1所示现有的情景学习存在显著区别。
本申请实施例提供的上述方法可以应用于多种应用场景,包括但不限于:情感分析、意图识别、问答处理等。下面以情感分析为例,对上述实施例提供的方法进行举例描述。
首先获取已标注的样本数据构建训练数据集,训练数据集包括多个训练样本,各训练样本包括输入文本样本以及该输入文本样本对应的情感类别标签。例如表1中所示:
表1
输入文本样本 情感类别标签
食物太好吃了 积极
食物太难吃了 消极
糟糕的菜 消极
这道菜汤汁浓郁 积极
香气扑鼻,鲜美无比 积极
…… ……
上述表1中提供的情感类别是以“积极”、“中性”和“消极”等为例,也可以划分为其他更粗粒度,或者更细粒度的情感类别。例如划分为“兴奋”、“高兴”、“平和”、“伤心”、“生气”、“愤怒”等等。
将包含输入文本样本和情感类别标签的文本序列作为文本预测模型的输入,训练文本预测模型。例如,构建文本序列“评价:食物太好吃了。情感类别:积极”。
该文本预测模型是在预训练得到的LLM的基础上进行模型优化得到的。文本预测模型的结构和训练过程可以参见上面方法实施例中的相关记载。其中的核心内容是,LLM包括多层串连的转换Transformer网络。在训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用第一键矩阵对上述第二键矩阵进行更新,将更新后的第二键矩阵作为当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用第一值矩阵对上述第二值矩阵进行更新,将更新后的第二值矩阵作为当前层Transformer网络在当前轮迭代得到的第二值矩阵。其他细节在此不做赘述。
在进行模型训练时,Transformer网络通过自注意力处理来学习输入文本样本和情感类别标签之间的关系,上述过程中的各键矩阵和值矩阵是表达这些关系的关键载体,上述的键矩阵梯度和值矩阵梯度可以看做是两轮迭代学习到的输入文本样本和情感类别标签之间的关系差异,通过不断前向更新,使得得到的关系趋于一致。
模型训练过程中,文本预测模型在每一轮迭代中均对历史信息进行了保留,从而对已标注的样本数据进行了充分的理解和学习。模型训练结束后,将训练得到的各层Transformer网络的第二键矩阵和第二值矩阵/>进行存储,用于进行情感分析过程中使用。
在利用已经训练得到的文本预测模型进行情感分析时,对于未标注的输入文本,例如“这道菜真让人爱不释口”,作为其中一种可实现的方式,可以将“评价:这道菜真让人爱不释口”作为文本序列输入文本预测模型,获取文本预测模型预测得到的输入文本对应的情感类别。
作为另一种可实现的方式,可以将输入文本与提示文本进行拼接得到的文本序列输入文本预测模型,获取文本预测模型预测得到的输入文本对应的情感类别。其中,提示文本可以采用诸如“该句子所表达的情感为【Mask】”。例如,将输入文本“这道菜汤汁浓郁”与提示文本进行拼接后,可以得到文本序列“这道菜汤汁浓郁,该句子所表达的情感为【Mask】”。其中【Mask】为掩码部分,文本预测模型通过对文本序列的编码和预测,实际上是实现对【mask】部分内容的预测,预测结果实际上是映射到情感类别空间上的具体情感类别上。
同样,文本预测模型采用大型语言模型,其结构和原理参见之前实施例中的相关记载,在此不做赘述。其核心内容是,大型语言模型包括多层串连的转换Transformer网络。其中,各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵/>进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
作为其中一种可实现的方式,上述训练文本预测模型的方法或文本预测方法可以由云端服务器执行。云服务器又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决传统物理主机与虚拟专用服务器(VPS,Virtual Private Server)服务中存在的管理难度大,服务扩展性弱的缺陷。
云端服务器获取到来自用户终端的输入文本后;将包含输入文本的文本序列输入文本预测模型,获取文本预测模型预测得到的输入文本对应的输出标签;基于输出标签确定对应的服务内容,将服务内容发送至所述用户终端。例如,获取用户在智能客服系统中输入的文本,对该文本进行情感分析。然后依据情感分析得到的情感类别向用户提供对应的服务内容。例如若识别出用户很生气,可以转接人工客服进行处理。若识别出用户很高兴,可以向用户发送推广信息,等等。
其中,所述文本预测模型的原理和结构可以参见之前方法实施例中的相关记载,在此不做赘述。
其中上述终端设备可以包括但不限于诸如:智能移动终端、智能家居设备、可穿戴式设备、PC(Personal Computer,个人计算机)等。其中智能移动设备可以包括诸如手机、平板电脑、笔记本电脑、PDA(Personal Digital Assistant,个人数字助理)、互联网汽车等。智能家居设备可以包括智能电视、智能音箱、智能冰箱等等。可穿戴式设备可以包括诸如智能手表、智能眼镜、虚拟现实设备、增强现实设备、混合现实设备(即可以支持虚拟现实和增强现实的设备)等等。
上述对本说明书特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
根据另一方面的实施例,提供了一种训练文本预测模型的装置,该装置对应于图2所示系统中的模型训练装置。图8示出根据一个实施例的该模型训练装置的示意性框图,如图8所示,该装置800可以包括:样本获取单元801和模型训练单元802。其中各单元的主要功能如下:
样本获取单元801,被配置为获取训练数据集,训练数据集包括输入文本样本以及该输入文本样本对应的输出标签样本。
模型训练单元802,被配置为将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练文本预测模型。其中,文本预测模型采用LLM,LLM包括多层串连的转换Transformer网络;在训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用第一键矩阵对第二键矩阵进行更新,将更新后的第二键矩阵作为当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用第一值矩阵对第二值矩阵进行更新,将更新后的第二值矩阵作为当前层Transformer网络在当前轮迭代得到的第二值矩阵。
更进一步地,上述LLM还可以包括嵌入网络,上述文本预测模型还包括预测网络。
嵌入网络用以对文本序列进行嵌入处理。
若当前层Transformer网络为第一层Transformer网络,则上述的上一层网络输出的特征表示为嵌入网络输出的特征表示;否则,上述的上一层网络输出的特征表示为上一层Transformer网络输出的特征表示。
预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本样本对应的输出标签。
作为其中一种可实现的方式,在当前层Transformer网络中,可以利用第一键矩阵对当前层Transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新;利用第一值矩阵对当前层Transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新。
作为其中一种可实现的方式,在当前层Transformer网络中,可以利用第一键矩阵和当前层Transformer网络在上一轮迭代得到的第二键矩阵进行逐元素求差,得到键矩阵梯度;利用当前层Transformer网络在上一轮迭代得到的键矩阵动量和键矩阵梯度进行加权求和,得到当前层Transformer网络在当前轮迭代得到的键矩阵动量;利用当前层Transformer网络在当前轮迭代得到的键矩阵动量和当前层Transformer网络在上一轮迭代得到的第二键矩阵,得到更新后的第二键矩阵。
可以利用第一值矩阵和当前层Transformer网络在上一轮迭代得到的第二值矩阵进行逐元素求差,得到值矩阵梯度;利用当前层Transformer网络在上一轮迭代得到的值矩阵动量和值矩阵梯度进行加权求和,得到当前层Transformer网络在当前轮迭代得到的值矩阵动量;利用当前层Transformer网络在当前轮迭代得到的值矩阵动量和当前层Transformer网络在上一轮迭代得到的第二值矩阵,得到更新后的第二值矩阵。
更进一步地,每一轮迭代完成后,若确定达到预设的训练结束条件,则模型训练单元802将各层Transformer网络当前迭代得到的第二键矩阵和第二值矩阵分别作为训练得到的各层Transformer网络的第二键矩阵和第二值矩阵进行存储。
根据再一方面的实施例,提供了一种文本预测装置。图9示出根据一个实施例的文本预测装置的示意性框图,如图9所示,该装置900可以包括:文本获取单元901和文本预测单元902。其中各单元的主要功能如下:
文本获取单元901,被配置为获取输入文本。
文本预测单元902,被配置为将包含输入文本的文本序列输入文本预测模型,获取文本预测模型预测得到的输入文本对应的输出标签;文本预测模型采用LLM,LLM包括多层串连的转换Transformer网络;其中,各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
更进一步地,LLM还包括嵌入网络,上述文本预测模型还包括预测网络。
嵌入网络用以对文本序列进行嵌入处理。
若当前层Transformer网络为第一层Transformer网络,则上述的上一层网络输出的特征表示为嵌入网络输出的特征表示;否则,上述的上一层网络输出的特征表示为上一层Transformer网络输出的特征表示。
预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本对应的输出标签。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
需要说明的是,本申请所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据,并且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准,并提供有相应的操作入口,供用户选择授权或者拒绝。
另外,本申请实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现前述方法实施例中任一项所述的方法的步骤。
以及一种电子设备,包括:
一个或多个处理器;以及
与所述一个或多个处理器关联的存储器,所述存储器用于存储程序指令,所述程序指令在被所述一个或多个处理器读取执行时,执行前述方法实施例中任一项所述的方法的步骤。
本申请还提供了一种计算机程序产品,包括计算机程序,该计算机程序在被处理器执行时实现前述方法实施例中任一项所述的方法的步骤。
其中,图10示例性的展示出了电子设备的架构,具体可以包括处理器1010,视频显示适配器1011,磁盘驱动器1012,输入/输出接口1013,网络接口1014,以及存储器1020。上述处理器1010、视频显示适配器1011、磁盘驱动器1012、输入/输出接口1013、网络接口1014,与存储器1020之间可以通过通信总线1030进行通信连接。
其中,处理器1010可以采用通用的CPU、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本申请所提供的技术方案。
存储器1020可以采用ROM(Read Only Memory,只读存储器)、RAM(RandomAccessMemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储用于控制电子设备1000运行的操作系统1021,用于控制电子设备1000的低级别操作的基本输入输出系统(BIOS)1022。另外,还可以存储网页浏览器1023,数据存储管理系统1024,以及模型训练装置/文本预测装置1025等等。上述模型训练装置/文本预测装置1025就可以是本申请实施例中具体实现前述各步骤操作的应用程序。总之,在通过软件或者固件来实现本申请所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。
输入/输出接口1013用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。
网络接口1014用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信。
总线1030包括一通路,在设备的各个组件(例如处理器1010、视频显示适配器1011、磁盘驱动器1012、输入/输出接口1013、网络接口1014,与存储器1020)之间传输信息。
需要说明的是,尽管上述设备仅示出了处理器1010、视频显示适配器1011、磁盘驱动器1012、输入/输出接口1013、网络接口1014,存储器1020,总线1030等,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本申请方案所必需的组件,而不必包含图中所示的全部组件。
通过以上的实施方式的描述可知,本领域的技术人员可以清楚地了解到本申请可借助软件加必需的通用硬件平台的方式来实现。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以计算机程序产品的形式体现出来,该计算机程序产品可以存储在存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例或者实施例的某些部分所述的方法。
以上对本申请所提供的技术方案进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处。综上所述,本说明书内容不应理解为对本申请的限制。

Claims (14)

1.一种训练文本预测模型的方法,其特征在于,所述方法包括:
获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签;
将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
2.根据权利要求1所述的方法,其特征在于,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
所述嵌入网络用以对所述文本序列进行嵌入处理;
若当前层Transformer网络为第一层Transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层Transformer网络输出的特征表示;
所述预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本样本对应的输出标签。
3.根据权利要求1所述的方法,其特征在于,利用所述第一键矩阵对所述第二键矩阵进行更新包括:利用所述第一键矩阵对所述当前层Transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新;
利用所述第一值矩阵对所述第二值矩阵进行更新包括:利用所述第一值矩阵对所述当前层Transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新。
4.根据权利要求3所述的方法,其特征在于,利用所述第一键矩阵对所述当前层Transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新包括:利用所述第一键矩阵和所述当前层Transformer网络在上一轮迭代得到的第二键矩阵进行逐元素求差,得到键矩阵梯度;利用所述当前层Transformer网络在上一轮迭代得到的键矩阵动量和所述键矩阵梯度进行加权求和,得到所述当前层Transformer网络在当前轮迭代得到的键矩阵动量;利用所述当前层Transformer网络在当前轮迭代得到的键矩阵动量和当前层Transformer网络在上一轮迭代得到的第二键矩阵,得到所述更新后的第二键矩阵;
利用所述第一值矩阵对所述当前层Transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新包括:利用所述第一值矩阵和所述当前层Transformer网络在上一轮迭代得到的第二值矩阵进行逐元素求差,得到值矩阵梯度;利用所述当前层Transformer网络在上一轮迭代得到的值矩阵动量和所述值矩阵梯度进行加权求和,得到所述当前层Transformer网络在当前轮迭代得到的值矩阵动量;利用所述当前层Transformer网络在当前轮迭代得到的值矩阵动量和当前层Transformer网络在上一轮迭代得到的第二值矩阵,得到所述更新后的第二值矩阵。
5.根据权利要求1至4中任一项所述的方法,其特征在于,每一轮迭代完成后,若确定达到预设的训练结束条件,则将各层Transformer网络当前迭代得到的第二键矩阵和第二值矩阵分别作为训练得到的各层Transformer网络的第二键矩阵和第二值矩阵进行存储。
6.一种文本预测方法,其特征在于,所述方法包括:
获取输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
7.根据权利要求6所述的方法,其特征在于,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
所述嵌入网络用以对所述文本序列进行嵌入处理;
若当前层Transformer网络为第一层Transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层Transformer网络输出的特征表示;
所述预测网络用以利用最后一层Transformer网络输出的特征表示预测输入文本对应的输出标签。
8.一种训练文本预测模型的方法,其特征在于,所述方法包括:
获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的情感类别标签;
将包含输入文本样本和该输入文本样本对应的情感类别标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的Transformer网络;
在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
9.一种情感分析方法,其特征在于,所述方法包括:
获取输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的情感类别标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
10.一种文本预测方法,由云端服务器执行,其特征在于,所述方法包括:
获取来自用户终端的输入文本;
将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;
基于所述输出标签确定对应的服务内容,将所述服务内容发送至所述用户终端;
其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;
各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
11.一种训练文本预测模型的装置,其特征在于,所述装置包括:
样本获取单元,被配置为获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签样本;
模型训练单元,被配置为将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;在所述训练中各Transformer网络分别作为当前层Transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层Transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层Transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层Transformer网络在当前轮迭代得到的第二值矩阵。
12.一种文本预测装置,其特征在于,所述装置包括:
文本获取单元,被配置为获取输入文本;
文本预测单元,被配置为将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换Transformer网络;其中,各Transformer网络分别作为当前层Transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层Transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层Transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层Transformer网络输出的特征表示。
13.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现权利要求1至10中任一项所述的方法的步骤。
14.一种电子设备,其特征在于,包括:
一个或多个处理器;以及
与所述一个或多个处理器关联的存储器,所述存储器用于存储程序指令,所述程序指令在被所述一个或多个处理器读取执行时,执行权利要求1至10中任一项所述的方法的步骤。
CN202310459343.3A 2023-04-23 2023-04-23 训练文本预测模型的方法、文本预测方法及装置 Pending CN116628147A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310459343.3A CN116628147A (zh) 2023-04-23 2023-04-23 训练文本预测模型的方法、文本预测方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310459343.3A CN116628147A (zh) 2023-04-23 2023-04-23 训练文本预测模型的方法、文本预测方法及装置

Publications (1)

Publication Number Publication Date
CN116628147A true CN116628147A (zh) 2023-08-22

Family

ID=87590995

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310459343.3A Pending CN116628147A (zh) 2023-04-23 2023-04-23 训练文本预测模型的方法、文本预测方法及装置

Country Status (1)

Country Link
CN (1) CN116628147A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118036754A (zh) * 2024-04-12 2024-05-14 清华大学 基于键值矩阵缓存的模型推理方法及装置、介质

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118036754A (zh) * 2024-04-12 2024-05-14 清华大学 基于键值矩阵缓存的模型推理方法及装置、介质

Similar Documents

Publication Publication Date Title
CN110781663B (zh) 文本分析模型的训练方法及装置、文本分析方法及装置
CN111461301B (zh) 序列化数据处理方法和装置、文本处理方法和装置
CN111177325B (zh) 一种自动生成答案的方法和系统
EP4113357A1 (en) Method and apparatus for recognizing entity, electronic device and storage medium
CN113987147A (zh) 样本处理方法及装置
CN110347802A (zh) 一种文本分析方法及装置
CN112699686A (zh) 基于任务型对话系统的语义理解方法、装置、设备及介质
CN115309877A (zh) 对话生成方法、对话模型训练方法及装置
CN111291187A (zh) 一种情感分析方法、装置、电子设备及存储介质
CN115114407B (zh) 意图识别方法、装置、计算机设备及存储介质
CN111859967A (zh) 实体识别方法、装置,电子设备
CN116245097A (zh) 训练实体识别模型的方法、实体识别方法及对应装置
CN116628147A (zh) 训练文本预测模型的方法、文本预测方法及装置
CN116050425A (zh) 建立预训练语言模型的方法、文本预测方法及装置
CN117114063A (zh) 用于训练生成式大语言模型和用于处理图像任务的方法
CN114817478A (zh) 基于文本的问答方法、装置、计算机设备及存储介质
CN114648032A (zh) 语义理解模型的训练方法、装置和计算机设备
CN113722436A (zh) 文本信息提取方法、装置、计算机设备及存储介质
CN114077655A (zh) 一种答案抽取模型的训练方法及装置
CN117556005A (zh) 质量评估模型的训练方法、多轮对话质量评估方法和装置
CN112906368A (zh) 行业文本增量方法、相关装置及计算机程序产品
CN111931503A (zh) 信息抽取方法及装置、设备、计算机可读存储介质
CN109002498B (zh) 人机对话方法、装置、设备及存储介质
CN116662496A (zh) 信息抽取方法、训练问答处理模型的方法及装置
CN114970666B (zh) 一种口语处理方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination