CN116611477B - 数据剪枝方法和序列模型的训练方法、装置、设备和介质 - Google Patents
数据剪枝方法和序列模型的训练方法、装置、设备和介质 Download PDFInfo
- Publication number
- CN116611477B CN116611477B CN202310638785.4A CN202310638785A CN116611477B CN 116611477 B CN116611477 B CN 116611477B CN 202310638785 A CN202310638785 A CN 202310638785A CN 116611477 B CN116611477 B CN 116611477B
- Authority
- CN
- China
- Prior art keywords
- sequence
- unit
- data
- pruned
- input
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000013138 pruning Methods 0.000 title claims abstract description 106
- 238000000034 method Methods 0.000 title claims abstract description 91
- 238000012549 training Methods 0.000 title claims abstract description 57
- 238000004364 calculation method Methods 0.000 claims abstract description 173
- 239000003550 marker Substances 0.000 claims abstract description 112
- 239000011159 matrix material Substances 0.000 claims abstract description 27
- 238000012545 processing Methods 0.000 claims description 95
- 238000012805 post-processing Methods 0.000 claims description 18
- 230000008569 process Effects 0.000 claims description 17
- 238000004590 computer program Methods 0.000 claims description 12
- 230000007246 mechanism Effects 0.000 claims description 7
- 238000003058 natural language processing Methods 0.000 abstract description 6
- 238000013473 artificial intelligence Methods 0.000 abstract description 5
- 238000000605 extraction Methods 0.000 abstract description 5
- 238000013135 deep learning Methods 0.000 abstract description 2
- 238000010586 diagram Methods 0.000 description 19
- 230000006870 function Effects 0.000 description 15
- 238000004891 communication Methods 0.000 description 8
- 238000001514 detection method Methods 0.000 description 5
- 230000001133 acceleration Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000013136 deep learning model Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 230000011218 segmentation Effects 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 238000007792 addition Methods 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种数据剪枝方法和序列模型的训练方法、装置、设备和介质,涉及人工智能领域,具体涉及计算机视觉、自然语言处理和深度学习等技术领域,可应用于图像分类、OCR、文本抽取和问答系统等场景。数据剪枝方法包括:获取针对目标计算单元的输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;根据设置于目标计算单元之前的注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列;以及组合计算后标记序列和被剪枝标记,得到设置于目标计算单元之后的在后计算单元的输入数据。
Description
技术领域
本公开涉及人工智能领域,具体涉及计算机视觉、自然语言处理和深度学习等技术领域,可应用于图像分类、OCR、文本抽取和问答系统等场景。
背景技术
随着计算机技术和网络技术的发展,深度学习模型的应用越来越广泛,且深度学习模型在各个领域也都取得了突破性的进展。例如,以Transformer为代表的大模型在自然语言处理(Nature Language Processing,NLP)和计算机视觉(Computer Vision,CV)领域大放异彩。但大模型的参数量和计算量需要消耗极大的计算资源,会引起高昂的成本开销,并因此限制了大模型的推广普及。为此,大模型的压缩加速技术应运而生,数据剪枝技术为压缩加速技术中的其中之一。
发明内容
本公开旨在提供一种数据剪枝方法和序列模型的训练方法、装置、设备和介质,以在对模型进行压缩的同时,减小模型精度的损耗。
根据本公开的第一个方面,提供了一种数据剪枝方法,包括:获取输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征,输入标记序列是针对序列模型所包括的目标计算单元的;序列模型还包括设置于目标计算单元之前的注意力单元;根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列;以及组合计算后标记序列和被剪枝标记,得到序列模型中设置于目标计算单元之后的在后计算单元的输入数据。
根据本公开的第二个方面,提供了一种序列模型的训练方法,包括:采用序列模型对作为样本的多媒体数据进行处理,得到预测处理结果;其中,作为样本的多媒体数据具有指示真实处理结果的标签;以及根据预测处理结果和真实处理结果,对序列模型进行训练,其中,序列模型包括目标计算单元、设置于目标计算单元之前的注意力单元和设置于目标计算单元之后的在后计算单元;采用序列模型对多媒体数据进行处理的过程包括:基于多媒体数据,获取针对目标计算单元的输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列;以及组合计算后标记序列和被剪枝标记,得到在后计算单元的输入数据。
根据本公开的第三个方面,提供了一种数据剪枝装置,包括:输入序列获取模块,用于获取输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;输入标记序列是针对序列模型所包括的目标计算单元的;序列模型还包括设置于目标计算单元之前的注意力单元;序列剪枝模块,用于根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;数据计算模块,用于将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列;以及数据组合模块,用于组合计算后标记序列和被剪枝标记,得到序列模型中设置于目标计算单元之后的在后计算单元的输入数据。
根据本公开的第四个方面,提供了一种序列模型的训练装置,包括:预测模块,用于采用序列模型对作为样本的多媒体数据进行处理,得到预测处理结果;其中,作为样本的多媒体数据具有指示真实处理结果的标签;以及训练模块,用于根据预测处理结果和真实处理结果,对序列模型进行训练,其中,序列模型包括目标计算单元、设置于目标计算单元之前的注意力单元和设置于目标计算单元之后的在后计算单元;预测模块包括:输入序列获取子模块,用于基于多媒体数据,获取针对目标计算单元的输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;序列剪枝子模块,用于根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;数据计算子模块,用于将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列;以及数据组合子模块,用于组合计算后标记序列和被剪枝标记,得到在后计算单元的输入数据。
根据本公开的第五个方面,提供了一种电子设备,包括:至少一个处理器;以及与至少一个处理器通信连接的存储器;其中,存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行本公开提供的数据剪枝方法和/或序列模型的训练方法。
根据本公开的第六个方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,计算机指令用于使计算机执行本公开提供的数据剪枝方法和/或序列模型的训练方法。
根据本公开的第七个方面,提供了一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令存储于可读存储介质和电子设备其中至少之一上,所述计算机程序/指令在被处理器执行时实现本公开提供的数据剪枝方法和/或序列模型的训练方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开实施例的数据剪枝方法和序列模型的训练方法、装置的应用场景示意图;
图2是根据本公开实施例的数据剪枝方法的流程示意图;
图3是根据本公开实施例的对输入标记序列进行剪枝处理的原理示意图;
图4是根据本公开实施例的组合计算后标记序列和被剪枝标记的原理示意图;
图5是根据本公开实施例的序列模型的训练方法的流程示意图;
图6是根据本公开实施例的训练序列模型的原理示意图;
图7是根据本公开实施例的数据剪枝装置的结构框图;
图8是根据本公开实施例的序列模型的训练装置的结构框图;以及
图9是用来实施本公开实施例的数据剪枝方法和/或序列模型的训练方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
数据剪枝技术是一种针对大模型的压缩加速技术。例如,在NLP领域和CV领域中,可以对生成的标记(token)序列进行剪枝,以减少模型需要处理的数据量,加快模型的处理效率。其中,在NLP领域,一个token代表文本中的一个单词或词语。在CV领域,一个token代表图像的一小块局部区域,例如可以代表划分图像得到的多个图像块中的一个图像块。在模型的前向计算中,通过对标记序列进行剪枝处理,可以移除冗余token,减小参与计算的token序列的长度,从而减小计算复杂度,减小模型运行的耗时和所需的内存。
但对于具有序列性质的任务或者输出密集预测(Dense Prediction)结果的任务而言,对标记序列进行剪枝处理会破坏序列的完整性,并因此会影响输出结果的精度,即降低模型处理精度。其中,输出密集预测结果的任务为密集预测任务,密集预测任务例如可以为需要对输入的文本中的每个词进行预测的任务,或者为需要对输入的图像中的每个像素进行预测的任务。
为了解决上述问题,本公开提供了一种数据剪枝方法和序列模型的训练方法、装置、设备和介质。以下先结合图1对本公开提供的方法和装置的应用场景进行描述。
图1是根据本公开实施例的数据剪枝方法和序列模型的训练方法、装置的应用场景示意图。
如图1所示,该应用场景100中包括电子设备110。电子设备110可以为具有处理功能的智能手机、平板电脑、膝上型便携计算机或台式计算机等设备。
例如,该电子设备110可以对输入的多媒体数据120进行处理,得到预测处理结果130。其中,多媒体数据120可以为文本或图像,预测处理结果130可以根据多媒体数据120的类型和处理任务来确定。例如,若多媒体数据120为文本,处理任务为文本抽取任务,则预测处理结果130可以为抽取的文本中的关键信息等。若多媒体数据120为图像,处理任务为文本识别任务,则预测处理结果130可以为识别得到的图像中的文本。若多媒体数据120为图像,处理任务为目标检测任务,则预测处理结果130可以为检测得到的目标物体的类别等。
在该场景中,电子设备110例如可以采用序列模型来对多媒体数据120进行处理。其中,序列模型例如可以为能够对序列数据进行处理的任意深度学习模型,例如可以为循环神经网络模型、基于注意力机制构建的模型等。例如,序列模型可以为基于Transformer架构构建的模型。在采用序列模型对多媒体数据进行处理时,例如可以对输入序列模型中一个或多个计算单元的标记序列进行剪枝处理,以提高序列模型得到预测处理结果130的效率。
如图1所示,该应用场景100中还可以包括服务器140。电子设备110可以通过网络与服务器140通信连接。在该实施例中,电子设备110中可以安装有图像处理类应用、文本处理类应用、即时通信类应用等客户端应用,服务器140可以为向电子设备110中安装的客户端应用的运行提供支持的后台管理服务器等。
在一实施例中,服务器140可以采用作为样本的多媒体数据,根据处理任务来对序列模型进行训练,得到训练好的序列模型150。服务器140例如可以响应于电子设备110的请求,向电子设备110发送训练好的序列模型150,以供电子设备110根据接收到的序列模型150来对多媒体数据120进行处理。
在一实施例中,电子设备110也可以将多媒体数据120经由网络发送给服务器140,由服务器140对该多媒体数据120进行处理,从而得到预测处理结果130。
可以理解的是,本公开提供的数据剪枝方法可以由电子设备110执行,也可以由服务器140执行。相应地,本公开提供的数据剪枝装置可以设置在电子设备110上,也可以设置在服务器140上。本公开提供的序列模型的训练方法可以由服务器140执行。相应地,本公开提供的序列模型的训练装置可以设置在服务器140上。
需要说明的是,图1中的多媒体数据、电子设备和服务器的数目和类型仅作为示例以利于理解本公开,本公开对此不做限定。
以下将结合图2~图4对本公开提供的数据剪枝方法进行详细描述。
图2是根据本公开实施例的数据剪枝方法的流程示意图。
如图2所示,该实施例的数据剪枝方法200可以包括操作S210~操作S240。
在操作S210,获取输入标记序列。其中,输入标记序列是针对序列模型所包括的目标计算单元的。
根据本公开的实施例,目标计算单元可以为序列模型包括的任意一个网络层中的一个计算单元或至少两个计算单元。根据实际需求,目标计算单元还可以包括序列模型包括的相邻两个网络层中,一个网络层的第一部分计算单元和另一个网络层的第二部分计算单元,且该第一部分计算单元的输出数据为第二部分计算单元的输入数据。其中,序列模型包括的网络层例如可以包括编码层和/或解码层。编码层例如可以包括卷积计算单元和激活单元,或者可以包括注意力单元和全连接单元等。解码层例如可以包括转置卷积计算单元,或者可以包括注意力单元和全连接单元等。
在一实施例中,序列模型可以包括依次连接的、基于注意力机制构建的多个计算层。每个计算层可以包括注意力单元和后处理单元。例如,序列模型可以为基于Transformer架构构建的模型,多个计算层可以包括多个Transformer Encoder,还可以包括多个Transformer Decoder,注意力单元可以为多头自注意力单元或者多头交叉注意力单元,后处理单元可以包括全连接前馈网络单元、残差连接和层归一化单元等。在该实施例中,目标计算单元可以包括多个计算层中指定计算层所包括的后处理单元。或者,目标计算单元也可以包括该指定计算层所包括的后处理单元和该指定计算层的在后计算层所包括的注意力单元,以此尽可能地减小序列模型的计算量,减小序列模型的运行对计算资源的占用量,利于在计算能力有限的计算设备上部署序列模型。其中,指定计算层可以为多个计算层中的任意一个计算层,也可以根据实际需求,从多个计算层中选择一个或至少两个计算层中的每个计算层为指定计算层。可以理解的是,目标计算单元可以从序列模型中灵活地选择,以此使得数据的剪枝更为符合实际场景的需求,具体可以根据需要部署序列模型的计算设备的计算能力,灵活地选择目标计算单元,从而利于序列模型在更多不同计算能力的计算设备上的部署。
该实施例中,可以将需要处理的多媒体数据输入序列模型,由序列模型中目标计算单元的在前计算单元输出针对目标计算单元的输入标记序列。其中,多媒体数据可以为文本数据,也可以为图像数据等。例如,若多媒体数据为文本数据,该实施例可以先对文本数据进行分词处理或分字处理,得到词序列或字序列。随后,将字序列或词序列输入序列模型,由在前计算单元输出针对目标计算单元的输入标记序列。若多媒体数据为图像数据,该实施例可以先根据预定尺寸对图像数据进行切分,得到由多个图像数据块构成的图像块序列。随后,将图像块序列输入序列模型,由在前计算单元输出针对目标计算单元的输入标记序列。可以理解的是,输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征,例如,若多媒体数据为文本数据,则每个标记指示一个字或一个词的特征。若多媒体数据为图像数据,则每个标记指示一个图像块的特征。
例如,若序列模型为检测变压器(Detection Transformer,DETR)模型,多媒体数据为图像数据,目标计算单元为序列模型包括的第一个Transformer Encoder层包括的后处理单元,则可以将图像块序列输入DETR模型,图像块序列经由DETR模型的卷积神经网络(Convolutional Neural Network,CNN)和第一个Transformer Encoder层中的注意力单元处理,由第一个Transformer Encoder层中的注意力单元输出针对目标计算单元的输入标记序列。
在操作S220,根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记。
在该实施例中,序列模型可以包括设置于目标计算单元之前的注意力单元。例如,序列模型包括前述的计算层,目标计算单元包括该计算层中的后处理单元,该注意力单元为该计算层中包括的注意力单元。可以理解的是,若序列模型中包括多个注意力单元,则可以采用设置于目标计算单元之前,且最为靠近目标计算单元的注意力单元作为该操作S220中所提及的注意力单元所产生的注意力矩阵来进行剪枝处理。
例如,该实施例可以先根据该注意力矩阵来确定与输入标记序列中的每个标记相对应的注意力权重。随后,根据确定的注意力权重,对输入标记序列进行剪枝处理,以剪枝掉权重较低的标记。
例如,设定输入序列模型的字序列、词序列或图像块序列的尺寸为1×N,输入标记序列的尺寸为1×N,注意力矩阵的尺寸为N×N,且该注意力矩阵是按行归一化的矩阵。则注意力矩阵中第i行第j列的元素表示对于输入序列模型的序列中的第i个元素而言,第j个元素的重要性。该实施例中,可以将注意力矩阵中第j列元素的和或者该第j列元素的平均值等,作为与输入标记序列中第j个标记对应的注意力权重。可以理解的是,上述确定注意力权重的方式仅作为示例以利于理解本公开,N为大于1的自然数,本公开对此不做限定。
在得到与每个标记相对应的注意力权重后,该实施例可以从输入标记序列中剔除注意力权重较小的预定数量个标记,从而得到剪枝后标记序列。其中,预定数量可以根据实际需求进行设定,例如,若预先设定了裁剪率为a%,则预定数量例如可以为a%*N向上取整或向下取整得到的值。其中,a为根据实际需求设定的小于1的任意值,本公开对此不做限定。在该实施例中,被剔除的标记即为被剪枝标记。
在操作S230,将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列。
根据本公开的实施例,可以将被剪枝标记序列作为目标计算单元的输入,被剪枝标记序列经由目标计算单元处理后,可以得到目标计算单元输出的计算后标记序列。可以理解的是,计算后标记序列与被剪枝标记序列中包括的标记个数相同。例如,若目标计算单元为序列模型中第一个Transformer Encoder层包括的后处理单元,则计算后标记序列为该第一个Transformer Encoder层输出的序列。
在操作S240,组合计算后标记序列和被剪枝标记,得到序列模型中设置于目标计算单元之后的在后计算单元的输入数据。
例如,在通过操作S220得到被剪枝标记之后,可以将该被剪枝标记存储至预定存储空间。在得到计算后标记序列之后,或者可以在得到计算后标记序列的同时,从预定存储空间中读取被剪枝标记,并在得到计算后标记序列之后,将读取的被剪枝标记与计算后标记序列组合,从而得到在后计算单元的输入数据。例如,若输入标记序列的尺寸为1×N,剪枝后标记序列的尺寸为1×M,则被剪枝标记的个数为N-M,通过组合被剪枝标记和计算后标记序列,可以得到尺寸为1×N的输入数据。可以理解的是,在多媒体数据为文本时,该尺寸为1×N的输入数据中的N个数据与前述的尺寸为1×N的字序列/词序列中的N个字/词一一对应。在多媒体数据为图像时,该尺寸为1×N的输入数据中的N个数据与前述的尺寸为1×N的图像块序列中的N个图像块一一对应。
例如,可以将被剪枝标记添加至计算后标记序列中的任意位置,例如添加至被剪枝标记序列中第一个标记之前或者添加至被剪枝标记序列中最后一个标记之后,从而得到针对在后计算单元的输入数据。可以理解的是,该输入数据实质上为序列模型中与目标计算单元相邻、且在数据处理顺序上位于目标计算单元之后的计算单元。例如,若目标计算单元为序列模型中第一个Transformer Encoder层包括的后处理单元,则在后计算单元可以为序列模型中第二个Transformer Encoder层的注意力单元。
该实施例的数据剪枝方法,在对剪枝后标记序列进行处理后,将被剪枝标记与处理剪枝后标记序列得到的计算后标记序列相组合,从而得到在后计算单元的输入数据,可以使得在后计算单元的输入数据能够完整的表达多媒体数据,从而可以在减小序列模型计算量、降低对运行序列模型的计算资源的占用量的同时,保证序列模型的精度,降低因剪枝处理而带来的精度损耗。再者,该实施例中,通过采用位于目标计算单元之前的注意力单元生成的注意力矩阵来进行剪枝处理,可以动态地选择需要被剪枝掉的token,从而可以使得剪枝处理的结果更为适合各个多媒体数据,有效地提高了剪枝精度,并因此可以进一步提高序列模型的精度。
以下将结合图3对上述对输入标记序列进行剪枝处理的步骤的实现原理进行进一步扩展和限定。
图3是根据本公开实施例的对输入标记序列进行剪枝处理的原理示意图。
如图3所示,在实施例300中,在对输入标记序列进行剪枝处理时,可以先根据设置于目标计算单元310之前的注意力单元320所生成的注意力矩阵301,来确定与输入标记序列302中的每个标记相对应的注意力权重303。随后,根据与该每个标记相对应的注意力权重303,来对输入标记序列302中的标记进行重排序,得到重排序后序列304。最后,对重排序后序列304进行剪枝处理,从而得到剪枝后标记序列305和被剪枝标记306。如此,剪枝后标记序列305即为目标计算单元310的输入数据。
其中,根据注意力矩阵得到注意力权重的原理与上文描述的原理类似,在此不再赘述。在得到与每个标记相对应的注意力权重后,可以将输入标记序列中的标记根据对应的注意力权重自大至小,或者自小至大的顺序进行重新排序,从而得到重排序后序列304。随后,该实施例300可以根据预定剪枝率或者前述的预定数量,对重排序后序列进行剪枝处理。
通过对输入标记序列中的标记进行重排序,可以提高剪枝处理的效率,降低剪枝处理所需耗费的计算时长和剪枝处理对计算资源的占用量。例如,若重排序后序列根据对应的注意力权重自大至小排序,则可以通过剔除重排序后序列中排在靠后位置的预定数量个标记,即可完成剪枝处理。若根据自小至大排序,则通过剔除重排序后序列中排在靠前位置的预定数量个标记,即可完成剪枝处理。
可以理解的是,在确定了预定剪枝率之后,例如可以采用上文描述的方式来根据预定剪枝率和输入标记序列中标记的个数,来确定需要剔除的标记的数量(即预定数量)。
以下将结合图4对组合计算后标记序列和被剪枝标记的原理进行进一步的扩展和限定。
图4是根据本公开实施例的组合计算后标记序列和被剪枝标记的原理示意图。
在一实施例中,在剪枝得到被剪枝标记之后,除了将被剪枝标记存储至预定存储空间之外,还可以将指示每个标记在输入标记序列中的位置的索引信息一并存储。随后,可以根据该索引信息对计算后标记序列和被剪枝标记进行组合,以使得组合得到的输入数据中标记所对应的数据单元的排列顺序与输入标记序列中标记所对应的数据单元的排列顺序一致,从而可以使得在后计算单元的输入数据可以更好地表达多媒体数据,利于进一步降低剪枝处理带来的模型精度的损耗,提高序列模型的精度。
具体地,如图4所示,该实施例400可以在通过上文描述的操作S220得到被剪枝标记401之后,例如可以将该被剪枝标记401及输入标记序列402中每个标记的索引信息403存储至预定存储空间410中。可以理解的是,存储的索引信息403例如可以指示每个标记在输入标记序列中的位置,以及每个标记与其在输入标记序列中的位置之间的映射关系。由于被剪枝标记401为输入标记序列中的标记,且上述操作S230得到的计算后标记序列中的标记与输入标记序列中的标记之间具有对应关系。则在组合计算后标记序列和被剪枝标记时,可以先从预定存储空间410存储的索引信息403中,确定与计算后标记序列404中各标记对应的第一索引信息405,以及与被剪枝标记401对应的第二索引信息406。随后,根据该第一索引信息405和第二索引信息406来组合计算后标记序列404和被剪枝标记401,从而得到在后计算单元的输入数据407。
例如,若被剪枝标记401在输入标记序列402中位于首位,则在后计算单元的输入数据407中,该被剪枝标记401也位于首位。若计算后标记序列中的第k个标记与输入标记序列中位于第l个位置的标记相对应,则在后计算单元的输入数据407中,该第k个标记位于第l个位置。
可以理解的是,在采用图3所示的原理对输入标记序列中的标记进行重排序,从而得到剪枝后标记序列的情况下,计算后标记序列404中标记彼此之间的排列顺序与该计算后标记序列404中的标记在输入数据407中的排列顺序可能会不一致。但输入数据407中位于第l个位置的标记与输入标记序列402中位于第l个位置的标记表征同一个数据单元的特征。
例如,在根据第一索引信息和第二索引信息组合计算后标记序列和被剪枝标记时,可以先根据计算后标记序列中的标记的第一索引信息所指示位置的先后顺序,对计算后标记序列中的标记进行重排序。随后,根据第二索引信息所指示的位置,将被剪枝标记插入重排序后的计算后标记序列中,从而得到输入数据。
例如,在根据第一索引信息和第二索引信息组合计算后标记序列和被剪枝标记时,也可以根据第一索引信息将计算后标记序列中的标记填入预定空序列,并根据第二索引信息将被剪枝标记填入该预定空序列,从而得到在后计算单元的输入数据。例如,可以根据第一索引信息指示的第一位置,将计算后标记序列中的标记填入预定空序列中的第一位置处,根据第二索引信息指示的第二位置,将被剪枝标记填入预定空序列中的第二位置处,从而得到输入数据。通过该方式,可以提高组合得到输入数据的效率。需要说明的是,预定空序列的尺寸例如可以与输入标记序列的尺寸相等。
在一实施例中,预定存储空间例如可以为运行序列模型的处理器或图形处理器等中的内存,本公开对此不做限定。例如,若运行序列模型的为AI芯片,则预定存储空间还可以为该AI芯片中所设置的内存。
在一实施例中,可以将被剪枝标记及输入标记序列每个标记的索引信息以固定的数据结构存储至预定存储空间。其中,固定的数据结构例如可以采用以下结构。其中,“Tokenpruning_meta”表示剪枝后需要存储的数据,“pruned_tokens”表示被剪枝标记。“pruned_index”表示指示被剪枝标记在输入标记序列中的位置的索引信息;“kept_index”表示指示剪枝后标记序列中各标记在输入标记序列中的位置的索引信息。可以理解的是,以下结构仅作为示例以利于理解本公开,本公开对此不做限定。
基于本公开提供的数据剪枝方法,本公开还提供了一种序列模型的训练方法,以下将结合图5对该训练方法进行详细描述。
图5是根据本公开实施例的序列模型的训练方法的流程示意图。
如图5所示,该实施例的序列模型的训练方法500可以包括操作S510~操作S520。序列模型至少包括目标计算单元、设置于目标计算单元之前的注意力单元和设置于目标计算单元之后的在后计算单元。需要说明的是,此处的之前、之后是指序列模型处理多媒体数据的过程中,数据流的前后。
在操作S510,采用序列模型对作为样本的多媒体数据进行处理,得到预测处理结果。
根据本公开的实施例,可以将作为样本的多媒体数据输入序列模型,经由序列模型处理后,由序列模型输出预测处理结果。其中,作为样本的多媒体数据具有指示真实处理结果的标签。例如,对于目标检测任务,序列模型输出的预测处理结果可以包括多媒体数据中包括的目标对象的包围框的预测位置和与该预测位置对应的目标对象属于多个预定类别中每个类别的概率值。真实处理结果可以包括目标对象的真实类别和目标对象的真实位置。对于文本抽取任务,序列模型的预测处理结果可以包括多媒体数据中包括的文本的预测关键信息,真实处理结果可以包括多媒体数据中包括的文本的真实关键信息。例如,对于文本识别任务,序列模型输出的预测处理结果可以包括预测得到的多媒体数据中的文本(即预测文本),真实处理结果可以包括多媒体数据中的文本(即真实文本)。
在一实施例中,在序列模型对多媒体数据进行处理的过程中,可以采用上文描述的数据剪枝方法来剪枝输入目标计算单元的输入标记序列,并经过组合得到在后计算单元的输入数据。即,在对序列模型进行训练的过程中引入数据剪枝方法,以此使得训练得到的序列模型的处理性能更为贴合线下实际场景的需求。
具体地,如图5所示,该实施例中,操作S510可以包括操作S511~操作S514。其中,例如可以在序列模型中设置一个或至少两个目标计算单元,设置的每个目标计算单元可以由序列模型中的一个或相邻的至少两个计算单元构成。可以理解的是,根据实际需求,可以在序列模型中灵活地设置目标计算单元,以对输入目标计算单元的标记序列进行剪枝,并组合目标计算单元输出的计算后标记序列和剪枝过程中被剪枝掉的标记,从而得到位于目标计算单元之后的在后计算单元的输入数据。
在操作S511,基于多媒体数据,获取针对目标计算单元的输入标记序列。其中,输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征。该操作S511的实现原理与上文描述的操作S210的实现原理类似,在此不再赘述。
在操作S512,根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记。该操作S512的实现原理与上文描述的操作S220的实现原理类似,在此不再赘述。
在操作S513,将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列。该操作S513的实现原理与上文描述的操作S230的实现原理类似,在此不再赘述。
在操作S514,组合计算后标记序列和被剪枝标记,得到在后计算单元的输入数据。该操作S514的实现原理与上文描述的操作S240的实现原理类似,在此不再赘述。
在操作S520,根据预测处理结果和真实处理结果,对序列模型进行训练。
根据本公开的实施例,可以先根据预测处理结果和真实处理结果,确定序列模型的预测损失值。随后,以最小化该预测损失值为目标,采用反向传播算法调整序列模型中的网络参数,从而实现对序列模型的训练。
其中,可以根据序列模型可以处理的任务来确定损失函数,并采用损失函数来确定预测损失值。例如,对于文本抽取任务而言,损失函数例如可以为基于余弦距离等构建的损失函数。对于文本识别任务而言,损失函数例如可以为多分类交叉熵损失函数等。对于目标检测任务而言,损失函数例如可以包括用于计算分类损失的交叉熵损失函数和用于计算回归损失的L1损失函数或者L2损失函数等。可以理解的是,上述与处理的任务对应的损失函数仅作为示例以利于理解本公开,本公开对此不做限定。
图6是根据本公开实施例的训练序列模型的原理示意图。
根据本公开的实施例,在训练序列模型的过程中,除了将被剪枝标记和索引信息存入预定存储空间外,例如还可以将针对序列模型的目标计算图的拓扑信息存储至预定存储空间。其中,该目标计算图为未对输入标记序列进行剪枝处理的计算图。
为了方便理解,该实施例采用简单的计算图来说明需要存储至预定存储空间的目标计算图的拓扑信息。如图6所示,在该实施例600中,以表示表达式e=(a+b)*(b+1)的计算图作为简单的计算图。该表达式有三个操作:两个加法和一个乘法。为了帮助讨论,引入两个中间变量c和d,以便每个函数的输出都有一个变量,则c=a+b;d=b+1;e=c*d。为了创建计算图,我们将这些操作连同输入变量一起放到节点中,当一个节点的值是另一个节点的输入时,箭头从一个指向另一个。可以通过将输入变量设置为特定值并通过图计算节点来求解表达式。如图6所示,针对该表达时构建的计算图中,包括图计算节点a 601、图计算节点b 602、图计算节点c=a+b 603、图计算节点d=b+1604和图计算节点e=c*d 605。其中,图计算节点a 601和图计算节点b 602指向图计算节点c=a+b 603,图计算节点b 602指向图计算节点d=b+1604,图计算节点c=a+b 603和图计算节点d=b+1604指向图计算节点e=c*d 605。
若将变量a作为本公开涉及的输入标记序列,且输入标记序列的尺寸为1×N;则该实施例中,对于未对输入标记序列进行剪枝处理的计算图而言,图计算节点a 601可以采用序列{a1,a2,…,aN}来表示。通过对该输入标记序列进行剪枝处理,得到的剪枝后标记序列的尺寸例如可以为1×M,则对输入标记序列进行剪枝处理之后的计算图,图计算节点a 601可以采用序列{a1,a2,…,aM}来表示。在该实施例中,需要存储至预定存储空间的拓扑信息为图计算节点a 601采用序列{a1,a2,…,aN}表示的计算图的拓扑信息。
其中,计算图的拓扑信息可以包括计算图中的每个图计算节点的信息,及指向每个图计算节点的其他图计算节点的信息。可以理解的是,目标计算图实质上指不涉及剪枝处理的、与序列模型相对应的计算图。
在该实施例中,在需要对序列模型进行训练时,除了采用上文描述的方式确定序列模型的预测损失值外,在采用反向传播算法调整序列模型中的网络参数时,例如可以根据确定的预测损失值和预定存储空间中存储的目标计算图的拓扑信息来执行反向传播运算,以此在反向传播过程中,可以复原被剪枝的标记,使得训练过程支持该被剪枝标记的梯度的反向传播,可以通过反向传播恢复得到该被剪枝标记,提高对序列模型的训练精度。
基于本公开提供的数据剪枝方法,本公开还提供了一种数据剪枝装置。以下将结合图7对该装置进行详细描述。
图7是根据本公开实施例的数据剪枝装置的结构框图。
如图7所示,该实施例的数据剪枝装置700可以包括输入序列获取模块710、序列剪枝模块720、数据计算模块730和数据组合模块740。
输入序列获取模块710用于获取目标计算单元的输入标记序列。其中,输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征,输入标记序列是针对序列模型所包括的目标计算单元的,且序列模型还包括设置于目标计算单元之前的注意力单元。在一实施例中,输入序列获取模块710可以用于执行上文描述的操作S210,在此不再赘述。
序列剪枝模块720用于根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记。在一实施例中,序列剪枝模块720可以用于执行上文描述的操作S220,在此不再赘述。
数据计算模块730用于将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列。在一实施例中,数据计算模块730可以用于执行上文描述的操作S230,在此不再赘述。
数据组合模块740用于组合计算后标记序列和被剪枝标记,得到序列模型中设置于目标计算单元之后的在后计算单元的输入数据。在一实施例中,数据组合模块740可以用于执行上文描述的操作S240,在此不再赘述。
根据本公开的实施例,上述数据剪枝装置700还可以包括存储模块,用于将被剪枝标记及输入标记序列中每个标记的索引信息存储至预定存储空间;索引信息指示标记在输入标记序列中的位置。上述数据组合模块740可以包括索引确定子模块和组合子模块。索引确定子模块用于确定预定存储空间存储的索引信息中与计算后标记序列中的各标记对应的第一索引信息,及与被剪枝标记对应的第二索引信息。组合子模块用于根据第一索引信息和第二索引信息组合计算后标记序列和被剪枝标记,得到在后计算单元的输入数据。
根据本公开的实施例,上述组合子模块具体可以用于分别根据第一索引信息和第二索引信息,将计算后标记序列中的标记和被剪枝标记填入预定空序列,得到在后计算单元的输入数据。
根据本公开的实施例,上述序列剪枝模块720可以包括权重确定子模块、重排序子模块和剪枝子模块。权重确定子模块用于根据注意力矩阵,确定与输入标记序列中的每个标记相对应的注意力权重。重排序子模块用于根据与每个标记相对应的注意力权重,对输入标记序列中的标记进行重排序,得到重排序后序列。剪枝子模块用于对重排序后序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记。
根据本公开的实施例,序列模型包括依次连接的、基于注意力机制构建的多个计算层;每个计算层包括注意力单元和后处理单元。目标计算单元包括:多个计算层中指定计算层所包括的后处理单元。
根据本公开的实施例,目标计算单元还包括:指定计算层的在后计算层所包括的注意力单元。
根据本公开的实施例,多媒体数据包括文本数据,数据单元为文本数据中包括的字或词;多媒体数据包括图像数据,数据单元为对图像数据进行分块处理得到的图像块。
基于本公开提供的序列模型的训练方法,本公开还提供了一种序列模型的训练装置。以下将结合图8对该装置进行详细描述。
图8是根据本公开实施例的序列模型的训练装置的结构框图。
如图8所示,该实施例的序列模型的训练装置800可以包括预测模块810和训练模块820。
预测模块810用于采用序列模型对作为样本的多媒体数据进行处理,得到预测处理结果。其中,作为样本的多媒体数据具有指示真实处理结果的标签。在一实施例中,预测模块810可以用于执行上文描述的操作S610,在此不再赘述。
示例性地,序列模型可以包括目标计算单元、设置于目标计算单元之前的注意力单元和设置于目标计算单元之后的在后计算单元。上述预测模块810可以包括输入序列获取子模块811、序列剪枝子模块812、数据计算子模块813和数据组合子模块814。输入序列获取子模块811用于基于多媒体数据,获取针对目标计算单元的输入标记序列;输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征。序列剪枝子模块812用于根据注意力单元生成的注意力矩阵,对输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记。数据计算子模块813用于将剪枝后标记序列输入目标计算单元,得到目标计算单元输出的计算后标记序列。数据组合子模块814用于组合计算后标记序列和被剪枝标记,得到在后计算单元的输入数据。在一实施例中,输入序列获取子模块811、序列剪枝子模块812、数据计算子模块813和数据组合子模块814可以分别用于执行上文描述的操作S611~操作S614,在此不再赘述。
训练模块820用于根据预测处理结果和真实处理结果,对序列模型进行训练。在一实施例中,训练模块820可以用于执行上文描述的操作S620,在此不再赘述。
根据本公开的实施例,上述序列模型的训练装置800还可以包括存储模块,用于将针对序列模型的目标计算图的拓扑信息存储至预定存储空间;目标计算图为未对输入标记序列进行剪枝处理的计算图。上述训练模块820可以包括:损失确定子模块,用于根据预测处理结果和真实处理结果,确定序列模型的预测损失值;以及模型训练子模块,用于根据预测损失值和预定存储空间中存储的目标计算图的拓扑信息执行反向传播运算,以对序列模型进行训练。
需要说明的是,本公开的技术方案中,所涉及的用户个人信息的收集、存储、使用、加工、传输、提供、公开和应用等处理,均符合相关法律法规的规定,采取了必要保密措施,且不违背公序良俗。在本公开的技术方案中,在获取或采集用户个人信息之前,均获取了用户的授权或同意。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图9示出了可以用来实施本公开实施例的数据剪枝方法和/或序列模型的训练方法的示例电子设备900的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图9所示,电子设备900包括计算单元901,其可以根据存储在只读存储器(ROM)902中的计算机程序或者从存储单元908加载到随机访问存储器(RAM)903中的计算机程序,来执行各种适当的动作和处理。在RAM 903中,还可存储电子设备900操作所需的各种程序和数据。计算单元901、ROM 902以及RAM 903通过总线904彼此相连。输入/输出(I/O)接口905也连接至总线904。
电子设备900中的多个部件连接至I/O接口905,包括:输入单元906,例如键盘、鼠标等;输出单元907,例如各种类型的显示器、扬声器等;存储单元908,例如磁盘、光盘等;以及通信单元909,例如网卡、调制解调器、无线通信收发机等。通信单元909允许电子设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元901执行上文所描述的各个方法和处理,例如数据剪枝方法和/或序列模型的训练方法。例如,在一些实施例中,数据剪枝方法和/或序列模型的训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由ROM 902和/或通信单元909而被载入和/或安装到电子设备900上。当计算机程序加载到RAM 903并由计算单元901执行时,可以执行上文描述的数据剪枝方法和/或序列模型的训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元901可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行数据剪枝方法和/或序列模型的训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-二ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。其中,服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务(″Virtual Private Server″,或简称″VPS″)中,存在的管理难度大,业务扩展性弱的缺陷。服务器也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (19)
1.一种数据剪枝方法,包括:
获取输入标记序列;所述输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;所述输入标记序列是针对序列模型所包括的目标计算单元的;所述序列模型还包括设置于所述目标计算单元之前的注意力单元;
根据所述注意力单元生成的注意力矩阵,对所述输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;
将所述剪枝后标记序列输入所述目标计算单元,得到所述目标计算单元输出的计算后标记序列;以及
组合所述计算后标记序列和所述被剪枝标记,得到所述序列模型中设置于所述目标计算单元之后的在后计算单元的输入数据;
其中,所述序列模型包括依次连接的、基于注意力机制构建的多个计算层;每个计算层包括注意力单元和后处理单元;
所述目标计算单元包括:所述多个计算层中指定计算层所包括的后处理单元;设置于所述目标计算单元之前的注意力单元为所述指定计算层的注意力单元。
2.根据权利要求1所述的方法,还包括:
将所述被剪枝标记及所述输入标记序列中每个标记的索引信息存储至预定存储空间;所述索引信息指示标记在所述输入标记序列中的位置;
其中,所述组合所述计算后标记序列和所述被剪枝标记,得到所述序列模型中设置于所述目标计算单元之后的在后计算单元的输入数据包括:
确定所述预定存储空间存储的索引信息中与所述计算后标记序列中的各标记对应的第一索引信息,及与所述被剪枝标记对应的第二索引信息;以及
根据所述第一索引信息和所述第二索引信息组合所述计算后标记序列和所述被剪枝标记,得到所述在后计算单元的输入数据。
3.根据权利要求2所述的方法,其中,根据所述第一索引信息和所述第二索引信息组合所述计算后标记序列和所述被剪枝标记包括:
分别根据所述第一索引信息和所述第二索引信息,将所述计算后标记序列中的标记和所述被剪枝标记填入预定空序列,得到所述在后计算单元的输入数据。
4.根据权利要求1所述的方法,其中,所述根据所述注意力单元产生的注意力矩阵,对所述输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记包括:
根据所述注意力矩阵,确定与所述输入标记序列中的每个标记相对应的注意力权重;
根据与所述每个标记相对应的注意力权重,对所述输入标记序列中的标记进行重排序,得到重排序后序列;以及
对所述重排序后序列进行剪枝处理,得到所述剪枝后标记序列和所述被剪枝标记。
5.根据权利要求1所述的方法,其中:
所述目标计算单元还包括:所述指定计算层的在后计算层所包括的注意力单元。
6.根据权利要求1所述的方法,其中:
所述多媒体数据包括文本数据,所述数据单元为所述文本数据中包括的字或词;
所述多媒体数据包括图像数据,所述数据单元为对所述图像数据进行分块处理得到的图像块。
7.一种序列模型的训练方法,包括:
采用所述序列模型对作为样本的多媒体数据进行处理,得到预测处理结果;其中,作为样本的所述多媒体数据具有指示真实处理结果的标签;以及
根据所述预测处理结果和所述真实处理结果,对所述序列模型进行训练,
其中,所述序列模型包括目标计算单元、设置于所述目标计算单元之前的注意力单元和设置于所述目标计算单元之后的在后计算单元;所述采用所述序列模型对所述多媒体数据进行处理的过程包括:
基于所述多媒体数据,获取针对所述目标计算单元的输入标记序列;所述输入标记序列中的每个标记指示所述多媒体数据中一个数据单元的特征;
根据所述注意力单元生成的注意力矩阵,对所述输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;
将所述剪枝后标记序列输入所述目标计算单元,得到所述目标计算单元输出的计算后标记序列;以及
组合所述计算后标记序列和所述被剪枝标记,得到所述在后计算单元的输入数据;
其中,所述序列模型包括依次连接的、基于注意力机制构建的多个计算层;每个计算层包括注意力单元和后处理单元;
所述目标计算单元包括:所述多个计算层中指定计算层所包括的后处理单元;设置于所述目标计算单元之前的注意力单元为所述指定计算层的注意力单元。
8.根据权利要求7所述的方法,还包括:
将针对所述序列模型的目标计算图的拓扑信息存储至预定存储空间;所述目标计算图为未对所述输入标记序列进行剪枝处理的计算图;
其中,所述根据所述预测处理结果和所述真实处理结果,对所述序列模型进行训练包括:
根据所述预测处理结果和所述真实处理结果,确定所述序列模型的预测损失值;以及
根据所述预测损失值和所述预定存储空间中存储的目标计算图的拓扑信息执行反向传播运算,以对所述序列模型进行训练。
9.一种数据剪枝装置,包括:
输入序列获取模块,用于获取输入标记序列;所述输入标记序列中的每个标记指示多媒体数据中一个数据单元的特征;所述输入标记序列是针对序列模型所包括的目标计算单元的;所述序列模型还包括设置于所述目标计算单元之前的注意力单元;
序列剪枝模块,用于根据所述注意力单元生成的注意力矩阵,对所述输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;
数据计算模块,用于将所述剪枝后标记序列输入所述目标计算单元,得到所述目标计算单元输出的计算后标记序列;以及
数据组合模块,用于组合所述计算后标记序列和所述被剪枝标记,得到所述序列模型中设置于所述目标计算单元之后的在后计算单元的输入数据;
其中,所述序列模型包括依次连接的、基于注意力机制构建的多个计算层;每个计算层包括注意力单元和后处理单元;
所述目标计算单元包括:所述多个计算层中指定计算层所包括的后处理单元;设置于所述目标计算单元之前的注意力单元为所述指定计算层的注意力单元。
10.根据权利要求9所述的装置,还包括:
存储模块,用于将所述被剪枝标记及所述输入标记序列中每个标记的索引信息存储至预定存储空间;所述索引信息指示标记在所述输入标记序列中的位置;
其中,所述数据组合模块包括:
索引确定子模块,用于确定所述预定存储空间存储的索引信息中与所述计算后标记序列中的各标记对应的第一索引信息,及与所述被剪枝标记对应的第二索引信息;以及
组合子模块,用于根据所述第一索引信息和所述第二索引信息组合所述计算后标记序列和所述被剪枝标记,得到所述在后计算单元的输入数据。
11.根据权利要求10所述的装置,其中,所述组合子模块用于:
分别根据所述第一索引信息和所述第二索引信息,将所述计算后标记序列中的标记和所述被剪枝标记填入预定空序列,得到所述在后计算单元的输入数据。
12.根据权利要求9所述的装置,其中,所述序列剪枝模块包括:
权重确定子模块,用于根据所述注意力矩阵,确定与所述输入标记序列中的每个标记相对应的注意力权重;
重排序子模块,用于根据与所述每个标记相对应的注意力权重,对所述输入标记序列中的标记进行重排序,得到重排序后序列;以及
剪枝子模块,用于对所述重排序后序列进行剪枝处理,得到所述剪枝后标记序列和所述被剪枝标记。
13.根据权利要求9所述的装置,其中:
所述目标计算单元还包括:所述指定计算层的在后计算层所包括的注意力单元。
14.根据权利要求9所述的装置,其中:
所述多媒体数据包括文本数据,所述数据单元为所述文本数据中包括的字或词;
所述多媒体数据包括图像数据,所述数据单元为对所述图像数据进行分块处理得到的图像块。
15.一种序列模型的训练装置,包括:
预测模块,用于采用所述序列模型对作为样本的多媒体数据进行处理,得到预测处理结果;其中,作为样本的所述多媒体数据具有指示真实处理结果的标签;以及
训练模块,用于根据所述预测处理结果和所述真实处理结果,对所述序列模型进行训练,
其中,所述序列模型包括目标计算单元、设置于所述目标计算单元之前的注意力单元和设置于所述目标计算单元之后的在后计算单元;所述预测模块包括:
输入序列获取子模块,用于基于所述多媒体数据,获取针对所述目标计算单元的输入标记序列;所述输入标记序列中的每个标记指示所述多媒体数据中一个数据单元的特征;
序列剪枝子模块,用于根据所述注意力单元生成的注意力矩阵,对所述输入标记序列进行剪枝处理,得到剪枝后标记序列和被剪枝标记;
数据计算子模块,用于将所述剪枝后标记序列输入所述目标计算单元,得到所述目标计算单元输出的计算后标记序列;以及
数据组合子模块,用于组合所述计算后标记序列和所述被剪枝标记,得到所述在后计算单元的输入数据;
其中,所述序列模型包括依次连接的、基于注意力机制构建的多个计算层;每个计算层包括注意力单元和后处理单元;
所述目标计算单元包括:所述多个计算层中指定计算层所包括的后处理单元;设置于所述目标计算单元之前的注意力单元为所述指定计算层的注意力单元。
16.根据权利要求15所述的装置,还包括:
存储模块,用于将针对所述序列模型的目标计算图的拓扑信息存储至预定存储空间;所述目标计算图为未对所述输入标记序列进行剪枝处理的计算图;
所述训练模块包括:
损失确定子模块,用于根据所述预测处理结果和所述真实处理结果,确定所述序列模型的预测损失值;以及
模型训练子模块,用于根据所述预测损失值和所述预定存储空间中存储的目标计算图的拓扑信息执行反向传播运算,以对所述序列模型进行训练。
17.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1~8中任一项所述的方法。
18.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1~8中任一项所述的方法。
19.一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令存储于可读存储介质和电子设备其中至少之一上,所述计算机程序/指令在被处理器执行时实现根据权利要求1~8中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310638785.4A CN116611477B (zh) | 2023-05-31 | 2023-05-31 | 数据剪枝方法和序列模型的训练方法、装置、设备和介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310638785.4A CN116611477B (zh) | 2023-05-31 | 2023-05-31 | 数据剪枝方法和序列模型的训练方法、装置、设备和介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116611477A CN116611477A (zh) | 2023-08-18 |
CN116611477B true CN116611477B (zh) | 2024-05-17 |
Family
ID=87678036
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310638785.4A Active CN116611477B (zh) | 2023-05-31 | 2023-05-31 | 数据剪枝方法和序列模型的训练方法、装置、设备和介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116611477B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113901904A (zh) * | 2021-09-29 | 2022-01-07 | 北京百度网讯科技有限公司 | 图像处理方法、人脸识别模型训练方法、装置及设备 |
CN114037074A (zh) * | 2021-11-09 | 2022-02-11 | 北京百度网讯科技有限公司 | 一种模型剪枝方法、装置、电子设备及存储介质 |
CN115374777A (zh) * | 2021-05-20 | 2022-11-22 | 三星电子株式会社 | 用于自然语言处理的方法和装置 |
CN116129330A (zh) * | 2023-03-14 | 2023-05-16 | 阿里巴巴(中国)有限公司 | 基于视频的图像处理、行为识别、分割、检测方法及设备 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10191942B2 (en) * | 2016-10-14 | 2019-01-29 | Sap Se | Reducing comparisons for token-based entity resolution |
-
2023
- 2023-05-31 CN CN202310638785.4A patent/CN116611477B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115374777A (zh) * | 2021-05-20 | 2022-11-22 | 三星电子株式会社 | 用于自然语言处理的方法和装置 |
CN113901904A (zh) * | 2021-09-29 | 2022-01-07 | 北京百度网讯科技有限公司 | 图像处理方法、人脸识别模型训练方法、装置及设备 |
CN114037074A (zh) * | 2021-11-09 | 2022-02-11 | 北京百度网讯科技有限公司 | 一种模型剪枝方法、装置、电子设备及存储介质 |
CN116129330A (zh) * | 2023-03-14 | 2023-05-16 | 阿里巴巴(中国)有限公司 | 基于视频的图像处理、行为识别、分割、检测方法及设备 |
Also Published As
Publication number | Publication date |
---|---|
CN116611477A (zh) | 2023-08-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113313022B (zh) | 文字识别模型的训练方法和识别图像中文字的方法 | |
CN114942984B (zh) | 视觉场景文本融合模型的预训练和图文检索方法及装置 | |
CN113392253B (zh) | 视觉问答模型训练及视觉问答方法、装置、设备及介质 | |
CN113705628B (zh) | 预训练模型的确定方法、装置、电子设备以及存储介质 | |
CN113393371B (zh) | 一种图像处理方法、装置及电子设备 | |
CN112989970A (zh) | 文档版面分析方法、装置、电子设备及可读存储介质 | |
CN112560985A (zh) | 神经网络的搜索方法、装置及电子设备 | |
CN116152833B (zh) | 基于图像的表格还原模型的训练方法及表格还原方法 | |
CN114715145B (zh) | 一种轨迹预测方法、装置、设备及自动驾驶车辆 | |
CN113887615A (zh) | 图像处理方法、装置、设备和介质 | |
CN112632227A (zh) | 简历匹配方法、装置、电子设备、存储介质和程序产品 | |
CN116363459A (zh) | 目标检测方法、模型训练方法、装置、电子设备及介质 | |
CN113723077B (zh) | 基于双向表征模型的句向量生成方法、装置及计算机设备 | |
CN114861758A (zh) | 多模态数据处理方法、装置、电子设备及可读存储介质 | |
CN117971487A (zh) | 一种高性能算子生成方法、装置、设备及存储介质 | |
CN114549904A (zh) | 视觉处理及模型训练方法、设备、存储介质及程序产品 | |
CN113657468A (zh) | 预训练模型的生成方法、装置、电子设备和存储介质 | |
CN115577106B (zh) | 基于人工智能的文本分类方法、装置、设备和介质 | |
CN116611477B (zh) | 数据剪枝方法和序列模型的训练方法、装置、设备和介质 | |
CN114792097B (zh) | 预训练模型提示向量的确定方法、装置及电子设备 | |
CN116433899A (zh) | 图像分割方法、训练图像分割模型的方法及装置 | |
CN113361621B (zh) | 用于训练模型的方法和装置 | |
CN115546844A (zh) | 跨模态行人重识别模型生成方法、识别方法、装置及设备 | |
CN113610856A (zh) | 训练图像分割模型和图像分割的方法和装置 | |
CN113869202B (zh) | 图像识别方法、装置、设备、存储介质及程序产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |