CN113257361B - 自适应蛋白质预测框架的实现方法、装置及设备 - Google Patents

自适应蛋白质预测框架的实现方法、装置及设备 Download PDF

Info

Publication number
CN113257361B
CN113257361B CN202110600871.7A CN202110600871A CN113257361B CN 113257361 B CN113257361 B CN 113257361B CN 202110600871 A CN202110600871 A CN 202110600871A CN 113257361 B CN113257361 B CN 113257361B
Authority
CN
China
Prior art keywords
model
layer
student
loss
teacher
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110600871.7A
Other languages
English (en)
Other versions
CN113257361A (zh
Inventor
陈磊
杨敏
原发杰
李成明
姜青山
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Institute of Advanced Technology of CAS
Original Assignee
Shenzhen Institute of Advanced Technology of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Institute of Advanced Technology of CAS filed Critical Shenzhen Institute of Advanced Technology of CAS
Priority to CN202110600871.7A priority Critical patent/CN113257361B/zh
Publication of CN113257361A publication Critical patent/CN113257361A/zh
Application granted granted Critical
Publication of CN113257361B publication Critical patent/CN113257361B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16BBIOINFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR GENETIC OR PROTEIN-RELATED DATA PROCESSING IN COMPUTATIONAL MOLECULAR BIOLOGY
    • G16B40/00ICT specially adapted for biostatistics; ICT specially adapted for bioinformatics-related machine learning or data mining, e.g. knowledge discovery or pattern finding
    • G16B40/30Unsupervised data analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16BBIOINFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR GENETIC OR PROTEIN-RELATED DATA PROCESSING IN COMPUTATIONAL MOLECULAR BIOLOGY
    • G16B20/00ICT specially adapted for functional genomics or proteomics, e.g. genotype-phenotype associations

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Biophysics (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Spectroscopy & Molecular Physics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Biology (AREA)
  • Biotechnology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Proteomics, Peptides & Aminoacids (AREA)
  • Genetics & Genomics (AREA)
  • Analytical Chemistry (AREA)
  • Chemical & Material Sciences (AREA)
  • Bioethics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Epidemiology (AREA)
  • Public Health (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请公开了一种自适应蛋白质预测框架的实现方法、装置及设备,该方法包括:基于BERT式掩盖语言模型处理源数据,得到训练样本集;对Transformer模型进行无监督预训练,得到教师模型;固定教师模型的参数,对教师模型和学生模型进行协同训练,并仅优化学生模型的参数,以将教师模型的知识蒸馏到学生模型中,在知识蒸馏过程中,利用搬土距离算法自适应地学习预训练教师模型的中间隐藏层和学生模型的中间隐藏层之间多对多的映射关系;利用经训练的学生模型进行不同的蛋白质预测任务预测,输出预测结果。通过上述方式,能够显著中间缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题。

Description

自适应蛋白质预测框架的实现方法、装置及设备
技术领域
本申请涉及计算机技术领域,特别是涉及一种自适应蛋白质预测框架的实现方法、装置及设备。
背景技术
蛋白质预测是近几年来发展十分繁荣的领域,因其广阔的应用场景以及巨大的商业价值而备受瞩目,蛋白质是由氨基酸数据以“脱水缩合”的方式组成的多肽链经过盘曲折叠形成的具有一定空间结构的物质,其基本结构为氨基酸数据序列,通过对氨基酸数据序列进行表征学习,可以应用于一系列蛋白质预测任务,如二级结构预测、接触预测、远程同源性检测、荧光度检测以及稳定性预测等,具有十分重要的现实意义。
发明内容
本申请主要解决的技术问题是提供一种自适应蛋白质预测框架的实现方法、装置及设备,能够显著中间缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题。
为解决上述技术问题,本申请采用的一个技术方案是:提供一种自适应蛋白质预测框架的实现方法,包括以下步骤:基于BERT式掩盖语言模型处理源数据,得到训练样本集,训练样本集包括源数据和与源数据对应的目标数据;以源数据为输入,以目标数据为验证,对Transformer模型进行无监督预训练,得到教师模型;固定教师模型的参数,对教师模型和学生模型进行协同训练,并仅优化学生模型的参数,以将教师模型的知识蒸馏到学生模型中,其中,在知识蒸馏过程中,利用搬土距离算法自适应地学习预训练教师模型的中间隐藏层和学生模型的中间隐藏层之间多对多的映射关系;利用经训练的学生模型进行不同的蛋白质预测任务预测,输出预测结果。
为解决上述技术问题,本申请采用的另一个技术方案是:提供一种存储装置,其上存储有程序,其中,该程序被处理器执行时实现根据前述方法的步骤。
为解决上述技术问题,本申请采用的又一个技术方案是:提供一种电子设备,包括存储器和处理器,在存储器上存储有能够在处理器上运行的程序,处理器执行程序时实现前述方法的步骤。
与现有技术相比,本申请的优点在于,所提供的自适应蛋白质预测框架的实现方法,将一个大模型(教师模型)的知识很好地蒸馏到小模型(学生模型)当中,并且中间隐藏层的蒸馏过程中利用搬土距离算法来较好地衡量出教师模型和学生模型之间的差异,自适应地完成中间隐藏层之间多对多的映射,从而能够显著中间缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题。
附图说明
图1是本申请自适应蛋白质预测框架的实现方法一实施例的流程图;
图2是图1中步骤S20的流程图;
图3是本申请电子设备的结构示意图。
具体实施方式
蛋白质预测是近几年来发展十分繁荣的领域,因其广阔的应用场景以及巨大的商业价值而备受瞩目,蛋白质是由氨基酸数据以“脱水缩合”的方式组成的多肽链经过盘曲折叠形成的具有一定空间结构的物质,其基本结构为氨基酸数据序列,通过对氨基酸数据序列进行表征学习,可以应用于一系列蛋白质预测任务,如二级结构预测、接触预测、远程同源性检测、荧光度检测以及稳定性预测等,具有十分重要的现实意义。
本申请发明人在长期研发过程中,发现在现阶段用于蛋白质预测任务的基本框架都是基于Transformer模型构建,利用海量的未标记氨基酸序列数据,运用BERT式MaskedLanguage Model对其进行无监督预训练,得到一个表征能力强大的蛋白质预训练模型,再在下游任务中利用下游任务数据对其进行微调训练,使其适配于下游任务,表现出良好性能。但现有蛋白质预测框架预训练-微调范式仍存在以下两个主要缺点:(1)BERT式预训练模型庞大,参数量繁多,直接使用需要耗费大量计算资源,训练以及推断时间过长,难以满足现实要求;(2)统一的预训练模型通过微调应用于下游任务,使用同样的模型结构不够灵活,无法根据下游任务自身特点进行模型结构搜索,缺乏可扩展性和自适应性。
基于上述缺点,本申请通过提出一个基于知识蒸馏及可微神经架构搜索的自适应蛋白质预测框架来很好地缓解以上两个问题。
其中,知识蒸馏用于模型压缩,缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题;可微神经架构搜索用于自适应模型结构搜索,可以根据下游任务自身特点搜索出自适应模型结构,缓解模型结构固定带来的缺乏可扩展性和自适应性问题。具体地,在知识蒸馏过程中,将在海量未标记氨基酸序列数据进行预训练得到的庞大BERT式预训练模型当作教师模型,将其中有用的知识迁移到学生模型中,整个知识蒸馏过程分为输入嵌入层蒸馏、中间隐藏层蒸馏以及输出预测层蒸馏三部分,其中在中间隐藏层蒸馏过程中,利用搬土距离算法(Earth Mover’s Distance,EMD)自动地完成教师模型与学生模型中间隐藏层多对多的映射,充分迁移教师模型中的有效知识。不同于传统知识蒸馏技术,学生模型有着固定的模型结构,本申请将可微神经架构搜索应用于学生模型结构搜索中,每个学生模型的结构可以根据下游任务自身特点自适应地搜索出来,不再受限于固定的模型结构,其中从教师模型中迁移的知识为整个搜索过程提供搜索指导,同时设计一个效率感知损失来限制搜索出模型结构的大小,实现蛋白质预测效果以及效率的最佳平衡。
下面结合附图和实施方式对本申请进行详细说明。
请参阅图1,图1是本申请自适应蛋白质预测框架的实现方法一实施例的流程图。需注意的是,若有实质上相同的结果,本申请的方法并不以图1所示的流程顺序为限。如图1所示,该方法包括如下步骤:
S10:基于BERT式掩盖语言模型处理源数据,得到训练样本集,训练样本集包括源数据和与源数据对应的目标数据。
S20:以源数据为输入,以目标数据为验证,对Transformer模型进行无监督预训练,得到教师模型。
具体而言,基于Transformer模型架构,利用海量的未标记氨基酸序列数据(以下简称氨基酸数据),运用BERT式Masked Language Model对其进行无监督预训练,最终得到一个表征能力强大的蛋白质预训练模型(即教师模型),以达到良好的充当“教师”的作用。
S30:固定教师模型的参数,对教师模型和学生模型进行协同训练,并仅优化学生模型的参数,以将教师模型的知识蒸馏到学生模型中,其中,在知识蒸馏过程中,利用搬土距离算法自适应地学习预训练教师模型的中间隐藏层和学生模型的中间隐藏层之间多对多的映射关系。
S40:利用经训练的学生模型进行不同的蛋白质预测任务预测,输出预测结果。
具体而言,单独抽离出已经训练好的为不同下游任务“量身定做”的学生模型,快速准确地完成下游的蛋白质预测任务。
本申请可以将训练好的为不同下游任务“量身定做”的学生模型单独抽离出来,包括模型结构以及模型参数,用于处理不同的蛋白质预测下游任务,包括二级结构预测、接触预测、远程同源性检测、荧光度检测以及稳定性预测等。
与现有技术相比,本申请的优点在于,所提供的自适应蛋白质预测框架的实现方法,将一个大模型(教师模型)的知识很好地蒸馏到小模型(学生模型)当中,并且中间隐藏层的蒸馏过程中利用搬土距离算法来较好地衡量出教师模型和学生模型之间的差异,自适应地完成中间隐藏层之间多对多的映射,从而能够显著中间缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题。
需要说明的是,本申请的氨基酸数据为氨基酸序列数据。
在一实施例中,源数据为n个氨基酸数据组成的序列X={x1,x2,…,xn}。
在一实施例中,步骤S10包括:
S11:基于BERT式掩盖语言模型,执行掩盖策略,随机掩盖n个氨基酸数据中的k个氨基酸数据。
其中,目标数据为被掩盖的k个氨基酸数据{xΔ1,xΔ2,…,xΔk},训练样本集为X’={x1,x2,…,xn\xΔ1,xΔ2,…,xΔk},目标数据的联合概率分布为:
Figure 82543DEST_PATH_IMAGE001
在一实施例中,掩盖策略包括:被掩盖的k个氨基酸数据占n个氨基酸数据中的30~40%氨基酸数据,其中,被掩盖的k个氨基酸数据中的80%被直接掩盖,被掩盖的k个氨基酸数据中的另10%被替换为其它蛋白,被掩盖的k个氨基酸数据中的其余10%保持不变。
在一实施例中,源数据为n个氨基酸数据组成的序列X={x1,x2,…,xn},目标数据为被掩盖的k个氨基酸数据{xΔ1,xΔ2,…,xΔk},训练样本集为X’={x1,x2,…,xn/xΔ1,xΔ2,…,xΔk}。
Transformer模型包括依次连接的输入嵌入层、中间隐藏层和输出预测层,中间隐藏层由N个Transformer模块组成,每个Transformer模块均包括依次连接的多头注意力层、第一Dropout层、第一Add & Norm层、前馈层、第二Dropout层以及第二Add & Norm层。
其中,Add表示残差连接(Residual Connection),用于防止网络退化;Norm表示Layer Normalization,用于对每一层的激活值进行归一化。
具体而言,每一个Transformer模块中都包含两个子层,多头注意力层(Multi-head Attention)以及前馈层(Feed-forward),在每一个子层后都接有Dropout、残差连接(Add)以及层归一化(Norm)操作,用以学习掩盖氨基酸序列的特征表示,在经过所有Transformer模块后,Transformer模型已经充分学习到掩盖氨基酸序列的高维特征表示,最后,将学习到的特征表示输入至输出预测层,预测出掩盖位置的氨基酸。
请参阅图2,图2是图1中步骤S20的流程图。在一实施例中,步骤S20包括:
S21:将训练样本集输入Transformer模型的输入嵌入层。
S22:通过Transformer模型的输入嵌入层对训练样本集进行嵌入处理。
S23:将嵌入处理后的训练样本集输入Transformer模型的中间隐藏层。
S24:通过Transformer模型的中间隐藏层学习嵌入处理后的嵌入处理后的训练样本集的特征表示。
S25:通过Transformer模型的输出预测层输出学习到的特征表示。
其中,整个Transformer模型通过最大化对数似然来进行优化,如以下公式所示:
Figure 661424DEST_PATH_IMAGE002
;
训练直至Transformer模型收敛,整个无监督预训练过程完成,最终得到一个表征能力强大的蛋白质预训练Transformer模型,可以达到良好的充当“教师”的作用。
上述教师模型即为经无监督预训练的Transformer模型。
在一实施例中,教师模型和学生模型均包括:依次连接的输入嵌入层、中间隐藏层和输出预测层,且教师模型相对于学生模型包含较多的中间隐藏层。
在一实施例中,整个知识蒸馏过程的知识蒸馏总损失为输入嵌入层的知识蒸馏损失、中间隐藏层的知识蒸馏损失以及输出预测层的知识蒸馏损失之和。即LKD=Lemb+Lhidden+Lpred,其中,LKD为知识蒸馏总损失,Lemb为输入嵌入层的知识蒸馏损失,Lhidden为中间隐藏层的知识蒸馏损失,Lpred为输出预测层的知识蒸馏损失。
在一实施例中,知识蒸馏输入嵌入层的目的是:完成教师模型嵌入矩阵ET对于学生模型嵌入矩阵ES的知识蒸馏。
输入嵌入层的知识蒸馏过程包括以下步骤:
S101:将源数据X={x1,x2,…,xn}表示为嵌入矩阵E=[e1,e2,…,en],其中,矩阵中的每一列ei代表相应项的嵌入向量。
S102:将教师模型的输入嵌入层表示为嵌入矩阵ET,并将学生模型的输入嵌入层表示为嵌入矩阵ES
S103:在协同训练的过程中,以设定的输入嵌入层的知识蒸馏损失为目标,学习教师模型的嵌入矩阵ET和学生模型的嵌入矩阵ES之间的线性映射矩阵。
其中,通过最小化均方误差表示输入嵌入层的知识蒸馏损失:Lemb=MSE(ET,ESWe),
其中Lemb表示输入嵌入层的知识蒸馏损失,MSE(·)表示均方误差计算,We表示可学习的线性映射矩阵。
在一实施例中,知识蒸馏中间隐藏层的目的是:让学生模型更好地学习到教师模型的行为。中间隐藏层的知识蒸馏过程包括:
S201:将教师模型的中间隐藏层不同层间的输出表示为HT={HT 1,…,HT N },将学生模型的中间隐藏层不同层间的输出表示为HS={HS 1,…,HS K },其中,N和K分别表示教师模型和学生模型堆叠的空洞卷积残差块数量,HT j表示教师模型第j个中间隐藏层的输出向量,HS i表示学生模型第i个中间隐藏层的输出向量。
可选地,N可以取值为12,K可以取值为4。
S202:定义地面距离矩阵
Figure 82041DEST_PATH_IMAGE003
Figure 289031DEST_PATH_IMAGE004
表示从教师模型第j个中间隐藏层的输出向量Hs j转移到学生模型第i个中间隐藏层的输出向量Hs i的映射转移量,其中,
Figure 714065DEST_PATH_IMAGE005
,其中KL(·)表示KL散度计算,Wh表示可学习的线性映射矩阵。
S203:通过求解教师模型的中间隐藏层间与学生模型的中间隐藏层之间的整体转移损失获得最佳映射转移矩阵
Figure 639296DEST_PATH_IMAGE006
,其中,整体转移损失的计算表示为:
Figure 925920DEST_PATH_IMAGE007
,其中,
Figure 987417DEST_PATH_IMAGE008
表示从教师模型第j个中间隐藏层的输出向量HT j转移到学生模型第i个中间隐藏层的输出向量HS i的映射转移量。
定义搬土距离
Figure 84818DEST_PATH_IMAGE009
S204:通过优化教师模型的中间隐藏层输出矩阵HT与学生模型的中间隐藏层输出矩阵HS之间的搬土距离获得教师模型的中间隐藏层和学生模型的中间隐藏层之间多对多的映射关系。
具体而言,通过优化搬土距离,可以较好地衡量出教师模型和学生模型之间的差异,从而自适应地完成中间隐藏层之间多对多的映射,避免了人为指定层次映射关系带来的信息丢失以及信息误导,充分迁移教师模型中的有效知识。
其中,中间隐藏层的知识蒸馏损失表示为Lhidden,其中,Lhidden=EMD(HS,HT)。
在一实施例中,知识蒸馏输出预测层的目的是:使得学生模型的最终预测概率分布与教师模型的最终预测概率分布相接近,从而学习到教师模型的预测行为。
输出预测层的知识蒸馏过程包括:
S301:通过最小化学生模型的最终预测概率分布与教师模型的最终预测概率分布之间的KL散度优化输出预测层的蒸馏过程,其中,输出预测层的知识蒸馏损失表示为Lpred,其中,Lpred=KL(zT,zS),其中zT表示教师模型经过输出预测层后的输出向量,zS表示学生模型经过输出预测层后的输出向量。
在一实施例中,进行步骤S30的同时,该方法还包括:
S50:采用可微神经架构搜索策略对学生模型进行模型结构搜索,在模型结构搜索过程中,从搜索空间中搜索候选操作以组成学生模型的基础搜索块,并通过堆叠K个相同的基础搜索块以组成整个学生模型。
具体而言,进行知识蒸馏教师模型,完成输入嵌入层、中间隐藏层以及输出预测层三部分的蒸馏,同时采用可微神经架构搜索策略搜索学生模型结构,进行协同训练,直至模型收敛,完成整个训练过程。
由于教师模型表征能力强大,包含知识丰富,但其模型庞大,参数量繁多,直接使用需要耗费大量计算资源,训练以及推断时间过长,难以满足现实要求,因此,本申请对其进行知识蒸馏,将教师模型学到的知识很好地蒸馏到学生模型当中,减小模型规模,加速推断,并且不降低模型准确率,较好地完成蛋白质预测系列任务。同时,本申请将可微神经架构搜索应用于学生模型的结构设计中,每个学生模型的结构可以根据下游任务自身特点自适应地搜索出来,不再受限于固定的模型结构,其中从教师模型中迁移的知识为整个搜索过程提供搜索指导,同时设计一个效率感知损失来限制搜索出模型结构的大小,实现蛋白质预测效果以及效率的最佳平衡。
本申请采取可微神经架构搜索策略从搜索空间中搜索候选操作来组成学生模型的基础搜索块,然后堆叠K个相同的基础搜索块来组成整个学生模型,每个基础搜索块由输入节点、输出节点以及内部隐状态节点组成有向无环图,搜索对象为有向无环图节点之间的边,即神经网络中的操作,例如线性映射、卷积、池化、残差连接等。在本申请中,可以选取卷积神经网络操作作为候选操作集,因为卷积神经网络的高效性以及高度可并行性,在蛋白质预测任务中可以发挥出优异的性能,是神经架构搜索的最佳选择,候选操作集中有不同大小卷积核的标准卷积、不同大小卷积核的空洞卷积、最大池化、平均池化、残差连接、无连接等。
在一实施例中,在模型结构搜索过程中,采用知识蒸馏总损失为整个搜索过程提供搜索指导,采用效率感知损失限制搜索出的候选操作的大小,采用交叉熵损失为学生模型的训练过程提供指导。
其中,效率感知损失表示为LE,其中,
Figure 497344DEST_PATH_IMAGE010
其中COST(·)表示搜索空间O内搜索出的候选操作oi,j的归一化参数量以及候选操作浮点数之和。
其中,交叉熵损失表示为LCE,其中,
Figure 525343DEST_PATH_IMAGE011
其中
Figure 769243DEST_PATH_IMAGE012
表示正确蛋白质标签,yi表示模型预测蛋白质标签,C表示训练样本总数。
其中,整个自适应蛋白质预测框架的总损失表示为Lall,知识蒸馏总损失表示为LKD,效率感知损失表示为LE,交叉熵损失表示为LCE,其中,Lall=(1-γ)LCE+γLKD+βLE,其中γ和β用于平衡知识蒸馏总损失、效率感知损失以及交叉熵损失之间的权重。
在一实施例中,通过梯度下降法对自适应蛋白质预测框架的总损失进行优化,以将每个候选操作oi,j建模为离散变量Θo |o|,且该离散变量
Θo |o|符合离散变量概率分布
Figure 958916DEST_PATH_IMAGE013
,再利用Gumbel Softmax策略将候选操作的离散变量松弛为连续变量,松弛后的连续变量可表示为yo,在前向传播时使用离散变量argmax(yo),在反向传播时使用松弛后的连续变量yo,来完成学生模型的训练过程。
其中,
Figure 858738DEST_PATH_IMAGE014
其中gi表示从Gumbel(0,1)分布中随机采集的变量,τ表示温度系数,其中,温度系数用于控制输出连续变量yo的离散化程度,温度系数初始化为1且随着训练过程逐渐退化至0。
本申请还提出一种存储装置,其上存储有程序,其中,该程序被处理器执行时实现上述实施例中的各步骤。本申请还提出一种存储装置,其上存储有程序,其中,该程序被处理器执行时实现根据前述方法的步骤。
本申请还提出一种电子设备,包括存储器和处理器,在存储器上存储有能够在处理器上运行的程序,处理器执行程序时实现上述实施例中的各步骤。
图3为本申请实施例提供的一种电子设备的结构示意图,该电子设备包括处理器301等物理器件,其中,处理器301可以是一个中央处理单元(CentralProcessing Unit,CPU)、微处理器、专用集成电路、可编程逻辑电路、大规模集成电路、或者为数字处理单元等等。
该电子设备还可以包括存储器302用于存储处理器301执行的软件指令,当然还可以存储电子设备需要的一些其他数据,如电子设备的标识信息、电子设备的加密信息、用户数据等。存储器302可以是易失性存储器(Volatile Memory),例如随机存取存储器(Random-Access Memory,RAM);存储器302也可以是非易失性存储器(Non-VolatileMemory),例如只读存储器(Read-Only Memory,ROM),快闪存储器(FlashMemory),硬盘(Hard Disk Drive,HDD)或固态硬盘(Solid-State Drive,SSD)、或者存储器302是能够用于携带或存储具有指令或数据结构形式的期望的程序代码并能够由计算机存取的任何其他介质,但不限于此。存储器302可以是上述存储器的组合。
本申请实施例中不限定上述处理器301、存储器302之间的具体连接介质。本申请实施例在图3中仅以存储器302与处理器301之间通过总线303连接为例进行说明,总线在图3中以粗线表示,其它部件之间的连接方式,仅是进行示意性说明,并不引以为限。所述总线可以分为地址总线、数据总线、控制总线等。为便于表示,图3中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
处理器301可以是专用硬件或运行软件的处理器,当处理器301可以运行软件时,处理器301读取存储器302存储的软件指令,并在所述软件指令的驱动下,执行前述实施例中涉及的自适应蛋白质预测框架的实现方法。
与现有技术相比,本申请的优点在于,所提供的自适应蛋白质预测框架的实现方法,将一个大模型(教师模型)的知识很好地蒸馏到小模型(学生模型)当中,并且中间隐藏层的蒸馏过程中利用搬土距离算法来较好地衡量出教师模型和学生模型之间的差异,自适应地完成中间隐藏层之间多对多的映射,从而能够显著中间缓解模型庞大带来的计算资源不足以及训练、推断时间过长的问题。
此外,本申请通过知识蒸馏教师模型,完成输入嵌入层、中间隐藏层以及输出预测层三部分的蒸馏,将教师模型学到的知识很好地蒸馏到学生模型当中,减小模型规模,加速推断,并且不降低模型准确率,较好地完成蛋白质预测系列任务。同时,本申请将可微神经架构搜索应用于学生模型的结构设计中,每个学生模型的结构可以根据下游任务自身特点自适应地搜索出来,不再受限于固定的模型结构,其中从教师模型中迁移的知识为整个搜索过程提供搜索指导,同时设计一个效率感知损失来限制搜索出模型结构的大小,实现蛋白质预测效果以及效率的最佳平衡。
在本申请所提供的几个实施例中应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施方式仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施方式方案的目的。
另外,在本申请各个实施方式中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台电子设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本申请各个实施方式所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本申请的实施方式,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

Claims (14)

1.一种自适应蛋白质预测框架的实现方法,其特征在于,包括以下步骤:
基于BERT式掩盖语言模型处理源数据,得到训练样本集,所述训练样本集包括所述源数据和与所述源数据对应的目标数据;
以所述源数据为输入,以所述目标数据为验证,对Transformer模型进行无监督预训练,得到教师模型;
固定所述教师模型的参数,对所述教师模型和学生模型进行协同训练,并仅优化所述学生模型的参数,以将所述教师模型的知识蒸馏到所述学生模型中,其中,在知识蒸馏过程中,利用搬土距离算法自适应地学习所述教师模型的中间隐藏层和所述学生模型的中间隐藏层之间多对多的映射关系;
利用经训练的所述学生模型进行不同的蛋白质预测任务预测,输出预测结果;
其中,所述源数据为n个氨基酸数据组成的序列X={x1,x2,…,xn},所述目标数据为被掩盖的k个氨基酸数据{xΔ1,xΔ2,…,xΔk},所述训练样本集为X’={ x1,x2,…,xn/xΔ1,xΔ2,…,xΔk };
所述Transformer模型包括依次连接的输入嵌入层、中间隐藏层和输出预测层,所述中间隐藏层由N个Transformer模块组成,每个所述Transformer模块均包括依次连接的多头注意力层、第一Dropout层、第一Add & Norm层、前馈层、第二Dropout层以及第二Add &Norm层;
所述以所述源数据为输入,以所述目标数据为验证,对Transformer模型进行无监督预训练,得到教师模型的步骤,包括:
将所述训练样本集输入所述Transformer模型的输入嵌入层;
通过所述Transformer模型的输入嵌入层对所述训练样本集进行嵌入处理;
将嵌入处理后的所述训练样本集输入所述Transformer模型的中间隐藏层;
通过所述Transformer模型的中间隐藏层学习嵌入处理后的所述训练样本集的特征表示;
通过所述Transformer模型的输出预测层输出学习到的所述特征表示;
其中,整个所述Transformer模型通过最大化对数似然来进行优化,如以下公式所示:
Figure 970421DEST_PATH_IMAGE001
其中,xΔk表示第k个被掩盖的氨基酸数据;所述教师模型为经无监督预训练的所述Transformer模型。
2.根据权利要求1所述的方法,其特征在于,
所述基于BERT式掩盖语言模型处理源数据,得到训练样本集,所述训练样本集包括源数据和与所述源数据对应的目标数据的步骤,包括:
基于BERT式掩盖语言模型,执行掩盖策略,随机掩盖所述n个氨基酸数据中的k个氨基酸数据,其中,所述目标数据为被掩盖的k个氨基酸数据{xΔ1,xΔ2,…,xΔk},所述训练样本集为X’={x1,x2,…,xn/xΔ1,xΔ2,…,xΔk},所述目标数据的联合概率分布为:
Figure 193592DEST_PATH_IMAGE002
3.根据权利要求2所述的方法,其特征在于,
所述掩盖策略包括:所述被掩盖的k个氨基酸数据占所述n个氨基酸数据中的30~40%氨基酸数据,其中,所述被掩盖的k个氨基酸数据中的80%被直接掩盖,所述被掩盖的k个氨基酸数据中的另10%被替换为其它蛋白,所述被掩盖的k个氨基酸数据中的其余10%保持不变。
4.根据权利要求1所述的方法,其特征在于,所述教师模型和所述学生模型均包括:依次连接的输入嵌入层、中间隐藏层和输出预测层,且所述教师模型相对于所述学生模型包含较多的中间隐藏层;
其中,整个知识蒸馏过程的知识蒸馏总损失为所述输入嵌入层的知识蒸馏损失、所述中间隐藏层的知识蒸馏损失以及所述输出预测层的知识蒸馏损失之和。
5.根据权利要求4所述的方法,其特征在于,
所述输入嵌入层的知识蒸馏过程包括:
将所述源数据X={x1,x2,…,xn}表示为嵌入矩阵E=[e1,e2,…,en],其中,矩阵中的每一列ed代表相应项的嵌入向量;
将所述教师模型的输入嵌入层表示为嵌入矩阵ET,并将所述学生模型的输入嵌入层表示为嵌入矩阵ES
在协同训练的过程中,以设定的所述输入嵌入层的知识蒸馏损失为目标,学习所述教师模型的嵌入矩阵ET和所述学生模型的嵌入矩阵ES之间的线性映射矩阵;
其中,通过最小化均方误差表示所述输入嵌入层的知识蒸馏损失:
Lemb=MSE(ET,ESWe),
其中Lemb表示所述输入嵌入层的知识蒸馏损失,MSE(·)表示均方误差计算,We表示可学习的线性映射矩阵。
6.根据权利要求4所述的方法,其特征在于,
所述中间隐藏层的知识蒸馏过程包括:
将所述教师模型的中间隐藏层不同层间的输出表示为HT={HT 1,…,HT N },将所述学生模型的中间隐藏层不同层间的输出表示为HS={HS 1,…,HS K },其中,N和K分别表示所述教师模型和所述学生模型堆叠的空洞卷积残差块数量,HT j表示所述教师模型第j个中间隐藏层的输出向量,
Figure 196183DEST_PATH_IMAGE003
表示所述学生模型第b个中间隐藏层的输出向量;
定义地面距离矩阵
Figure 188410DEST_PATH_IMAGE004
Figure 336363DEST_PATH_IMAGE005
表示从所述教师模型第j个中间隐藏层的输出向量HT j转移到所述学生模型第b个中间隐藏层的输出向量
Figure 781251DEST_PATH_IMAGE006
的映射转移量,其中,
Figure 525216DEST_PATH_IMAGE007
,其中KL(·)表示KL散度计算,Wh表示可学习的线性映射矩阵;
通过求解所述教师模型的中间隐藏层间与所述学生模型的中间隐藏层之间的整体转移损失获得最佳映射转移矩阵
Figure 434266DEST_PATH_IMAGE008
,其中,整体转移损失的计算表示为:
Figure 238274DEST_PATH_IMAGE009
,其中,
Figure 921191DEST_PATH_IMAGE010
表示从所述教师模型第j个中间隐藏层的输出向量HT j转移到所述学生模型第b个中间隐藏层的输出向量
Figure 468847DEST_PATH_IMAGE011
的映射转移量;
定义搬土距离为
Figure 232403DEST_PATH_IMAGE012
通过优化所述教师模型的中间隐藏层输出矩阵HT与所述学生模型的中间隐藏层输出矩阵HS之间的搬土距离获得所述教师模型的中间隐藏层和所述学生模型的中间隐藏层之间多对多的映射关系;
其中,所述中间隐藏层的知识蒸馏损失表示为Lhidden,其中,
Lhidden=EMD(HS,HT)。
7.根据权利要求4所述的方法,其特征在于,
所述输出预测层的知识蒸馏过程包括:
通过最小化所述学生模型的最终预测概率分布与所述教师模型的最终预测概率分布之间的KL散度优化所述输出预测层的蒸馏过程,其中,所述输出预测层的知识蒸馏损失表示为Lpred,其中,
Lpred=KL(zT,zS),
其中zT表示所述教师模型经过输出预测层后的输出向量,zS表示所述学生模型经过输出预测层后的输出向量。
8.根据权利要求4所述的方法,其特征在于,在进行所述固定所述教师模型的参数,对所述教师模型和学生模型进行协同训练,并仅优化所述学生模型的参数,以将所述教师模型的知识蒸馏到所述学生模型中的步骤的同时,所述方法还包括:
采用可微神经架构搜索策略对所述学生模型进行模型结构搜索,在模型结构搜索过程中,从搜索空间中搜索候选操作以组成所述学生模型的基础搜索块,并通过堆叠K个相同的所述基础搜索块以组成整个所述学生模型;
其中,在模型结构搜索过程中,采用所述知识蒸馏总损失为整个搜索过程提供搜索指导,采用效率感知损失限制搜索出的所述候选操作的大小,采用交叉熵损失为所述学生模型的训练过程提供指导。
9.根据权利要求8所述的方法,其特征在于,
所述效率感知损失表示为LE,其中,
Figure 207312DEST_PATH_IMAGE013
其中COST(·)表示所述搜索空间O内搜索出的候选操作of,h的归一化参数量以及候选操作浮点数之和。
10.根据权利要求8所述的方法,其特征在于,
所述交叉熵损失表示为LCE,其中,
Figure 626792DEST_PATH_IMAGE014
其中
Figure 774877DEST_PATH_IMAGE015
表示正确蛋白质标签,
Figure 579891DEST_PATH_IMAGE016
表示模型预测蛋白质标签,C表示训练样本总数。
11.根据权利要求8所述的方法,其特征在于,
整个所述自适应蛋白质预测框架的总损失表示为Lall,所述知识蒸馏总损失表示为LKD,所述效率感知损失表示为LE,所述交叉熵损失表示为LCE,其中,
Lall=(1-γ)LCE+γLKD+βLE
其中γ和β为用于平衡所述知识蒸馏总损失、所述效率感知损失以及所述交叉熵损失之间的权重。
12.根据权利要求11所述的方法,其特征在于,
通过梯度下降法对所述自适应蛋白质预测框架的总损失进行优化,以将每个所述候选操作of,h建模为离散变量Θo |o|,且离散变量Θo |o|符合离散变量概率分布
Figure 788018DEST_PATH_IMAGE017
,再利用Gumbel Softmax策略将所述候选操作的离散变量松弛为连续变量,松弛后的所述连续变量可表示为yo,在前向传播时使用离散变量argmax(yo),在反向传播时使用松弛后的连续变量yo
其中,
Figure 694794DEST_PATH_IMAGE018
其中
Figure 318674DEST_PATH_IMAGE019
表示从Gumbel(0,1)分布中随机采集的变量,τ表示温度系数,其中,所述温度系数用于控制输出连续变量yo的离散化程度,所述温度系数初始化为1且随着训练过程逐渐退化至0。
13.一种存储装置,其上存储有程序,其特征在于,该程序被处理器执行时实现权利要求1至12中任一项所述方法的步骤。
14.一种电子设备,包括存储器和处理器,在所述存储器上存储有能够在处理器上运行的程序,其特征在于,所述处理器执行所程序时实现权利要求1至12中任一项所述的方法的步骤。
CN202110600871.7A 2021-05-31 2021-05-31 自适应蛋白质预测框架的实现方法、装置及设备 Active CN113257361B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110600871.7A CN113257361B (zh) 2021-05-31 2021-05-31 自适应蛋白质预测框架的实现方法、装置及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110600871.7A CN113257361B (zh) 2021-05-31 2021-05-31 自适应蛋白质预测框架的实现方法、装置及设备

Publications (2)

Publication Number Publication Date
CN113257361A CN113257361A (zh) 2021-08-13
CN113257361B true CN113257361B (zh) 2021-11-23

Family

ID=77185469

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110600871.7A Active CN113257361B (zh) 2021-05-31 2021-05-31 自适应蛋白质预测框架的实现方法、装置及设备

Country Status (1)

Country Link
CN (1) CN113257361B (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114283878A (zh) * 2021-08-27 2022-04-05 腾讯科技(深圳)有限公司 训练匹配模型、预测氨基酸序列和设计药物的方法与装置
CN113807214B (zh) * 2021-08-31 2024-01-05 中国科学院上海微系统与信息技术研究所 基于deit附属网络知识蒸馏的小目标人脸识别方法
CN115965964B (zh) * 2023-01-29 2024-01-23 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP6595555B2 (ja) * 2017-10-23 2019-10-23 ファナック株式会社 仕分けシステム
EP3924971A1 (en) * 2019-02-11 2021-12-22 Flagship Pioneering Innovations VI, LLC Machine learning guided polypeptide analysis
US11922303B2 (en) * 2019-11-18 2024-03-05 Salesforce, Inc. Systems and methods for distilled BERT-based training model for text classification
CN111159416B (zh) * 2020-04-02 2020-07-17 腾讯科技(深圳)有限公司 语言任务模型训练方法、装置、电子设备及存储介质
CN112507209B (zh) * 2020-11-10 2022-07-05 中国科学院深圳先进技术研究院 一种基于陆地移动距离进行知识蒸馏的序列推荐方法
CN112614538A (zh) * 2020-12-17 2021-04-06 厦门大学 一种基于蛋白质预训练表征学习的抗菌肽预测方法和装置

Also Published As

Publication number Publication date
CN113257361A (zh) 2021-08-13

Similar Documents

Publication Publication Date Title
CN113257361B (zh) 自适应蛋白质预测框架的实现方法、装置及设备
Phan et al. Stable low-rank tensor decomposition for compression of convolutional neural network
CN110347932B (zh) 一种基于深度学习的跨网络用户对齐方法
CN109635204A (zh) 基于协同过滤和长短记忆网络的在线推荐系统
CN109902222A (zh) 一种推荐方法及装置
CN110674323B (zh) 基于虚拟标签回归的无监督跨模态哈希检索方法及系统
Dawid et al. Modern applications of machine learning in quantum sciences
CN116134454A (zh) 用于使用知识蒸馏训练神经网络模型的方法和系统
CN116415654A (zh) 一种数据处理方法及相关设备
WO2022105108A1 (zh) 一种网络数据分类方法、装置、设备及可读存储介质
CN114186084B (zh) 在线多模态哈希检索方法、系统、存储介质及设备
CN114974397A (zh) 蛋白质结构预测模型的训练方法和蛋白质结构预测方法
CN113826117A (zh) 来自神经网络的高效二元表示
CN112256971A (zh) 一种序列推荐方法及计算机可读存储介质
CN114579892A (zh) 一种基于跨城市兴趣点匹配的用户异地访问位置预测方法
WO2020195940A1 (ja) ニューラルネットワークのモデル縮約装置
US20230237337A1 (en) Large model emulation by knowledge distillation based nas
CN113609337A (zh) 图神经网络的预训练方法、训练方法、装置、设备及介质
CN116910210A (zh) 基于文档的智能问答模型训练方法、装置及其应用
Sun et al. Dynamic adjustment of hidden layer structure for convex incremental extreme learning machine
CN116805384A (zh) 自动搜索方法、自动搜索的性能预测模型训练方法及装置
CN115392594B (zh) 一种基于神经网络和特征筛选的用电负荷模型训练方法
Zhang et al. Online kernel classification with adjustable bandwidth using control-based learning approach
Yau et al. Dap-bert: Differentiable architecture pruning of bert
Liu et al. A hybrid deep model with cumulative learning for few-shot learning

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant