CN114780722B - 一种结合领域通用型语言模型的领域泛化方法 - Google Patents
一种结合领域通用型语言模型的领域泛化方法 Download PDFInfo
- Publication number
- CN114780722B CN114780722B CN202210342805.9A CN202210342805A CN114780722B CN 114780722 B CN114780722 B CN 114780722B CN 202210342805 A CN202210342805 A CN 202210342805A CN 114780722 B CN114780722 B CN 114780722B
- Authority
- CN
- China
- Prior art keywords
- model
- domain
- training
- data
- language model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 28
- 238000012549 training Methods 0.000 claims abstract description 59
- 230000000694 effects Effects 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 13
- 238000012360 testing method Methods 0.000 claims description 8
- 238000012795 verification Methods 0.000 claims description 6
- 238000011156 evaluation Methods 0.000 claims description 5
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000004364 calculation method Methods 0.000 claims description 4
- 230000004913 activation Effects 0.000 claims description 3
- 230000001174 ascending effect Effects 0.000 claims description 3
- 238000011478 gradient descent method Methods 0.000 claims description 2
- 230000008569 process Effects 0.000 claims description 2
- 238000003058 natural language processing Methods 0.000 abstract description 3
- 238000005516 engineering process Methods 0.000 abstract description 2
- 238000013508 migration Methods 0.000 abstract description 2
- 230000005012 migration Effects 0.000 abstract description 2
- 238000011056 performance test Methods 0.000 abstract description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000000717 retained effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Biophysics (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Machine Translation (AREA)
Abstract
本发明涉及一种结合领域通用型语言模型的领域泛化方法,属于人工智能迁移学习技术领域。本方法综合了预训练语言模型和模型裁剪技术。首先对预训练语言模型微调,利用多个源域数据对预训练语言模型进行训练。基于微调后得到的模型,计算模型中参数的域不变分数,对域不变分数低的参数进行裁剪。最后对裁剪后的语言模型进行重训练,将训练得到领域通用型语言模型在不同数据上进行泛化性能测试。本方法解决了过度参数化的预训练语言模型中学习方差大的问题。基于本发明的领域通用型语言模型明显优于相应的基线模型,在自然语言处理领域泛化任务上取得了良好效果。
Description
技术领域
本发明涉及一种结合领域通用型语言模型的领域泛化方法,属于人工智能迁移学习技术领域。
背景技术
在人工智能领域,领域泛化(Domain Generalization,DG)技术,是从若干个具有不同数据分布的数据集(领域)中学习一个泛化能力强的模型,以便在未知的测试集上取得较好的效果。例如,给定一个由餐厅、购物等领域的评论文本组成的训练集,要求训练一个良好的机器学习模型能够在对图书领域的数据集上进行分类时具有最小的预测误差。
现有的领域泛化方法,主要从数据操作、表示学习和学习策略三方面进行研究。其中,数据操作主要利用数据增强和数据生成两种技术来帮助学习一般的表示;表示学习主要通过学习域不变表示学习或利用特征解耦来得到域共享表示,提高模型泛化性能;学习策略侧重于利用通用学习策略促进泛化能力,主要包括集成学习、元学习和梯度操作。
领域泛化任务中的目标域数据不可访问性,这使得领域泛化更具挑战性和实用性。例如,在自然语言处理领域,由于预训练语言模型过度参数化的特性能够带来更小的学习误差,研究人员将其用于领域泛化任务,利用多个源域的数据对BERT等预训练语言模型进行微调,再将其在不同的目标域数据集上进行测试。结果发现,预训练语言模型能够比传统的模型拥有更强的泛化能力。但是,由于预训练语言模型中存在一些对于特定领域有效的参数,使得模型内部的神经元会在某些领域数据上激活而其他领域不激活。这种参数的领域不一致性,使得预训练语言模型的领域泛化能力下降。
发明内容
本发明的目的是为了解决因预训练语言模型存在特定域有效参数导致的泛化性能下降的技术问题,创造性地提出了一种结合领域通用型语言模型的领域泛化方法。本方法综合了预训练语言模型和模型裁剪技术。
对于领域泛化任务而言,泛化模型需要在所有域上都具备较好的泛化性能。为了去除预训练语言模型中的只对特定领域有效的参数来提高模型的泛化能力,本方法定义了域不变分数来识别特定域有效参数,通过去除特定域有效参数,保留域不变参数,得到领域通用型语言模型。然后,将领域通用型语言模型用于领域泛化,显著提高了模型的领域泛化性能。
首先,对预训练语言模型微调,利用多个源域数据对预训练语言模型进行训练。基于微调后得到的模型,计算模型中参数的域不变分数,对域不变分数低的参数进行裁剪。最后,对裁剪后的语言模型进行重训练,将训练得到领域通用型语言模型在不同数据上进行泛化性能测试。
本发明采用的技术方式如下:
一种结合领域通用型语言模型的领域泛化方法,包括以下步骤:
步骤1:预训练语言模型微调。
使用预训练语言模型(例如BERT)在给定的源域数据进行训练,利用使用多层感知器(MLP)微调预训练语言模型。其中,多层感知器包含四层:全连接层、双曲正切函数(Tanh)激活函数层、随机丢弃层(dropout层)和全连接层。
利用训练好的预训练语言模型,在目标域数据上获得经过全连接层的输出表示并送至软最大化标准化层(softmax层),对目标域数据进行相应标签预测。
步骤2:计算参数域不变分数。
本发明中,仅对预训练语言中的多头注意力MHA和前馈神经网络FFN模块进行裁剪。
对于待裁剪的参数,当数据集中只有一种领域的数据时,其对应的参数重要程度分数I如下式所示:
其中,IFFN分别表示第i个MHA模块和FFN模块对应的参数重要程度。(x,y)表示数据点,x、y分别表示模型输入和对应的真实标签;/>是损失函数;ξ(i)、υ分别为MHA、FFN对应的裁剪变量。/>表示偏导数。D表示领域集合。
在基于参数重要程度分数I的基础上,将跨领域的参数重要程度分数的期望与方差纳入考虑范围,提出了参数域不变分数I′。对于待裁剪的参数,其对应的参数域不变分数I′如下式所示:
其中,(x,y)是指领域d中的数据点,D是指领域集合。V表示方差,E表示期望。参数域不变分数对将跨领域的参数重要程度分数的均值与方差进行平衡,参数λ用以权衡二者之间的关系。
对参数而言,该参数的域不变分数越大,说明该参数的在各领域上的泛化能力更优。反之,若域不变分数越小,说明该参数领域泛化性弱,仅在某些领域上有效。
步骤3:参数裁剪。
如步骤2所述,对于每个参与域不变分数计算的参数,都有对应的裁剪变量,用以表示该参数是否被裁剪。
在对参数进行域不变分数计算后,根据域不变分数对参数进行升序排列,并优先对域不变分数低的参数进行裁剪。
具体地,当ξ(i)=0,其对应的注意力头Head会被裁剪,反之该参数会被保留。当υ=0,其对应的前馈神经网络FFN会被裁剪,反之该参数会被保留。通过设置裁剪率(例如10%),本发明将域不变分数最低的参数进行裁剪,即将其对应的裁剪变量置为0。
步骤4:对裁剪后的模型重训练。
对参数进行裁剪后,将裁剪后的模型进行重训练。其中,重训练需要将裁剪后的模型置为步骤1的初始状态,再让裁剪后的模型在给定的多个源域数据进行训练。然后来对目标域数据进行相应标签预测。
通过设置不同的裁剪率,得到领域泛化效果最好的裁剪后的模型。至此,即获得了领域通用型语言模型。
步骤5:利用领域通用型语言模型,对训练领域数据以外的其他领域数据进行分类预测。
有益效果
本方法,对比现有技术,具有以下优点:
1.本发明解决了过度参数化的预训练语言模型中学习方差大的问题,基于域不变分数对预训练语言中特定域有效参数进行裁剪,保留下对领域泛化更有用的通用域有效参数。
2.本发明提出的领域通用型语言模型,在亚马逊评论数据集和多类型自然语言推理数据库上的明显优于相应的基线模型,具体表现为领域通用型语言模型的准确率得分比相应基线模型平均提高1.5个百分点。领域通用型语言模型在自然语言处理领域泛化任务上取得了最新的最好效果。
附图说明
图1是本发明的整体流程图。
图2是构建领域通用型语言模型的模型结构图。
具体实施方式
下面结合附图和实施例对本发明进一步详细描述。
实施例
如图1、图2所示,一种结合领域通用型语言模型的领域泛化方法,包括以下步骤:
步骤1:预训练语言模型微调。
具体地,包括以下步骤:
步骤1.1:加载多领域评论语料集,数据集分为训练集、验证集和测试集,并构造成批数据形式。
步骤1.2:加载预训练语言模型M,初始化后保存。其中,预训练语言模型,可以是BERT-base或BERT-large模型。
步骤1.3:模型训练。
批数据再经过BERT结构后获得句子向量表示。本方法使用多层感知器(MLP)来微调预训练语言模型。本方法中的多层感知器MLP包含四层:全连接层、ReLU(线性整流函数)激活函数层、dropout层(随机丢弃层)和全连接层。最后,将经过全连接层的输出表示送至softmax层(软最大化标准化层)以预测相应的标签。其中,模型训练的目标函数为交叉熵函数,具体表示形式如下:
其中,m是标签类别的数量,c表示m的某一类别,N为训练样本个数,yic为样本i为类别c的真实概率,为样本i为类别c的预测概率;
模型训练为达到最小交叉熵损失,采用随机梯度下降法对其进行优化。在模型训练过程中,每一次训练后,用验证集数据对模型进行效果评价,此处采用的评价指标为各领域的平均准确率。在每轮验证后,保存效果最优的模型M′。
步骤1.4:效果评价。
利用测试集数据对步骤1.3获得的模型M′进行效果评价。首先加载最优模型M′,将测试集数据作为模型的输入,预测步骤与步骤1.3相同,此处使用的评价指标与步骤1.3相同。
步骤2:计算参数域不变分数。
具体地,包括以下步骤:
步骤2.1:加载步骤1.3保存的最优模型M′和训练集数据。
步骤2.2:将训练集数据输入模型M′,对参数计算其在各个领域的重要程度分数,此处参数特指MHA(多头注意力)和FFN(前馈网络),计算公式如下所示:
其中,(x,y)是指数据点,是损失函数。ξ(i)和υ是MHA和FFN对应的裁剪变量。
步骤2.3:基于各个领域的参数重要程度分数,计算参数的域不变分数,计算公式如下所示:
其中,(x,y)是指领域d中的数据点,D是指领域集合。参数域不变分数对将跨领域的参数重要程度分数的均值与方差进行平衡,用参数λ用以权衡二者之间的关系。
步骤2.4:基于参数的域不变分数对参数进行升序排列并保存。
步骤3:参数裁剪。
步骤3.1:设置参数裁剪的比例。
步骤3.2:加载步骤1.2保存的预训练模型M。
步骤3.3:根据裁剪比例计算出待裁剪参数数量n,将步骤2.4保存的参数序列中的前n个参数在模型M中进行裁剪,具体操作为将参数对应的裁剪变量的值设为0。对裁剪后模型M″进行保存。
步骤4:裁剪后模型重训练。
具体地,包括以下步骤:
步骤4.1:加载多领域评论语料集和裁剪后模型M″。
步骤4.2:模型训练。具体方法同步骤1.3。训练后得到领域通用型模型。
步骤4.3:效果评价。对领域通用型模型进行效果评价,具体方法同步骤1.4。
步骤5:利用领域通用型语言模型,对训练领域数据以外的其他领域数据进行分类预测。
领域通用型语言模型由于去除了特定域有效的参数,在其他领域数据上预测的效果会优于原有的预训练模型。例如,基于领域为“电影”,“软件”和“自动汽车”的训练集数据,利用本方法获得的领域通用型语言模型在预测“工业”领域数据的标签时,由于其具有领域通用性,模型的预测效果会远超于原有的预训练模型。
以上所述为本发明的较佳实施例,本发明不应该局限于该实施例和附图所公开的内容。凡是不脱离本发明所公开的精神下完成的等效或修改,都落入本发明保护的范围。
Claims (2)
1.一种结合领域通用型语言模型的领域泛化方法,其特征在于,包括以下步骤:
步骤1:预训练语言模型微调;
使用预训练语言模型在给定的源域数据进行训练,利用使用多层感知器微调预训练语言模型;其中,多层感知器包含四层:全连接层、双曲正切函数激活函数层、随机丢弃层和全连接层;
利用训练好的预训练语言模型,在目标域数据上获得经过全连接层的输出表示并送至软最大化标准化层,对目标域数据进行相应标签预测;
步骤2:计算参数域不变分数;
对预训练语言中的多头注意力MHA和前馈神经网络FFN模块进行裁剪;
对于待裁剪的参数,当数据集中只有一种领域的数据时,其对应的参数重要程度分数I如下式所示:
其中,IFFN分别表示第i个MHA模块和FFN模块对应的参数重要程度;(x,y)是指数据点,其中x、y分别表示模型输入和对应的真实标签;/>是损失函数;ξ(i)和υ是MHA和FFN对应的裁剪变量;/>表示偏导数;D表示领域集合;
在基于参数重要程度分数I的基础上,提出参数域不变分数I′,对于待裁剪的参数,其对应的参数域不变分数I′如下式所示:
其中,(x,y)是指领域d中的数据点,D是指领域集合;V表示方差,E表示期望;参数域不变分数对将跨领域的参数重要程度分数的均值与方差进行平衡,参数λ用以权衡二者之间的关系;
步骤3:参数裁剪;
对于每个参与域不变分数计算的参数,都有对应的裁剪变量,用以表示该参数是否被裁剪;在对参数进行域不变分数计算后,根据域不变分数对参数进行升序排列,并优先对域不变分数低的参数进行裁剪;
当ξ(i)=0,其对应的注意力头Head会被裁剪,反之该参数会被保留;当υ=0,其对应的前馈神经网络FFN会被裁剪,反之该参数会被保留;通过设置裁剪率,将域不变分数最低的参数进行裁剪,即将其对应的裁剪变量置为0;
步骤4:对裁剪后的模型重训练;
对参数进行裁剪后,将裁剪后的模型进行重训练;其中,重训练需要将裁剪后的模型置为步骤1的初始状态,再让裁剪后的模型在给定的源域数据进行训练,然后对目标域数据进行相应标签预测;
通过设置不同的裁剪率,得到领域泛化效果最好的裁剪后的模型;
步骤5:利用领域通用型预训练模型,对目标领域数据进行分类预测。
2.如权利要求1所述的一种结合领域通用型语言模型的领域泛化方法,其特征在于,步骤1包括以下步骤:
步骤1.1:加载多领域评论语料集,数据集分为训练集、验证集和测试集,并构造成批数据形式;
步骤1.2:加载预训练语言模型M,初始化后保存;
步骤1.3:模型训练;
批数据再经过预训练语言模型结构后获得句子向量表示;使用多层感知器微调预训练语言模型;其中,模型训练的目标函数为交叉熵函数,具体表示形式如下:
其中,m是标签类别的数量,c表示m的某一类别,N为训练样本个数,yic为样本i为类别c的真实概率,为样本i为类别c的预测概率;
模型训练为达到最小交叉熵损失,采用随机梯度下降法对其进行优化;在模型训练过程中,每一次训练后,用验证集数据对模型进行效果评价,此处采用的评价指标为各领域的平均准确率;在每轮验证后,保存效果最优的模型M′;
步骤1.4:效果评价;
利用测试集数据对步骤1.3获得的模型M′进行效果评价;首先加载最优模型M′,将测试集数据作为模型的输入,预测步骤与步骤1.3相同,此处使用的评价指标与步骤1.3相同。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210342805.9A CN114780722B (zh) | 2022-03-31 | 2022-03-31 | 一种结合领域通用型语言模型的领域泛化方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210342805.9A CN114780722B (zh) | 2022-03-31 | 2022-03-31 | 一种结合领域通用型语言模型的领域泛化方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114780722A CN114780722A (zh) | 2022-07-22 |
CN114780722B true CN114780722B (zh) | 2024-05-14 |
Family
ID=82427281
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210342805.9A Active CN114780722B (zh) | 2022-03-31 | 2022-03-31 | 一种结合领域通用型语言模型的领域泛化方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114780722B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US8756175B1 (en) * | 2012-02-22 | 2014-06-17 | Google Inc. | Robust and fast model fitting by adaptive sampling |
CN112100383A (zh) * | 2020-11-02 | 2020-12-18 | 之江实验室 | 一种面向多任务语言模型的元-知识微调方法及平台 |
AU2020103905A4 (en) * | 2020-12-04 | 2021-02-11 | Chongqing Normal University | Unsupervised cross-domain self-adaptive medical image segmentation method based on deep adversarial learning |
CN112417877A (zh) * | 2020-11-24 | 2021-02-26 | 广州平云信息科技有限公司 | 一种基于改进bert的文本蕴含关系识别方法 |
CN112613273A (zh) * | 2020-12-16 | 2021-04-06 | 上海交通大学 | 多语言bert序列标注模型的压缩方法及系统 |
-
2022
- 2022-03-31 CN CN202210342805.9A patent/CN114780722B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US8756175B1 (en) * | 2012-02-22 | 2014-06-17 | Google Inc. | Robust and fast model fitting by adaptive sampling |
CN112100383A (zh) * | 2020-11-02 | 2020-12-18 | 之江实验室 | 一种面向多任务语言模型的元-知识微调方法及平台 |
CN112417877A (zh) * | 2020-11-24 | 2021-02-26 | 广州平云信息科技有限公司 | 一种基于改进bert的文本蕴含关系识别方法 |
AU2020103905A4 (en) * | 2020-12-04 | 2021-02-11 | Chongqing Normal University | Unsupervised cross-domain self-adaptive medical image segmentation method based on deep adversarial learning |
CN112613273A (zh) * | 2020-12-16 | 2021-04-06 | 上海交通大学 | 多语言bert序列标注模型的压缩方法及系统 |
Non-Patent Citations (1)
Title |
---|
用于文本分类的多探测任务语言模型微调;傅群超;王枞;;北京邮电大学学报;20191215(第06期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN114780722A (zh) | 2022-07-22 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11531900B2 (en) | Imitation learning for machine learning systems with synthetic data generators | |
Li et al. | Confidence-based active learning | |
Amari | A universal theorem on learning curves | |
WO2019067960A1 (en) | AGGRESSIVE DEVELOPMENT USING COOPERATIVE GENERATORS | |
Hassan et al. | A hybrid of multiobjective Evolutionary Algorithm and HMM-Fuzzy model for time series prediction | |
Bastos | Credit scoring with boosted decision trees | |
Hamoud et al. | Student’s success prediction model based on artificial neural networks (ANN) and a combination of feature selection methods | |
CN110766138A (zh) | 基于脑发育机制的自适应神经网络模型的构建方法及系统 | |
CN113705715B (zh) | 一种基于lstm和多尺度fcn的时间序列分类方法 | |
Bama et al. | Efficient classification using average weighted pattern score with attribute rank based feature selection | |
Phan et al. | Efficiency enhancement of evolutionary neural architecture search via training-free initialization | |
Urgun et al. | Composite power system reliability evaluation using importance sampling and convolutional neural networks | |
CN114780722B (zh) | 一种结合领域通用型语言模型的领域泛化方法 | |
Fouladvand et al. | Distribution estimation based negative selection algorithm (DENSA) | |
Abusnaina et al. | Enhanced MWO training algorithm to improve classification accuracy of artificial neural networks | |
Vaghela et al. | Boost a weak learner to a strong learner using ensemble system approach | |
CN114898777A (zh) | 基于深度直推式迁移网络的跨库语音情感识别方法及装置 | |
CN113408602A (zh) | 一种树突神经网络初始化方法 | |
Lunga et al. | Online forecasting of stock market movement direction using the improved incremental algorithm | |
KR20220014744A (ko) | 강화 학습을 기반으로 한 데이터 전처리 시스템 및 방법 | |
CN112465054A (zh) | 一种基于fcn的多变量时间序列数据分类方法 | |
Cernazanu-Glavan et al. | A model for determining the number of negative examples used in training a MLP | |
Mácha et al. | Deeptoppush: Simple and scalable method for accuracy at the top | |
Torgo | Predictive Analytics | |
Daqi et al. | Adaptive RBF neural networks for pattern classifications |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |