CN113837308B - 基于知识蒸馏的模型训练方法、装置、电子设备 - Google Patents

基于知识蒸馏的模型训练方法、装置、电子设备 Download PDF

Info

Publication number
CN113837308B
CN113837308B CN202111155110.1A CN202111155110A CN113837308B CN 113837308 B CN113837308 B CN 113837308B CN 202111155110 A CN202111155110 A CN 202111155110A CN 113837308 B CN113837308 B CN 113837308B
Authority
CN
China
Prior art keywords
model
coding layer
distillation
feature vector
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202111155110.1A
Other languages
English (en)
Other versions
CN113837308A (zh
Inventor
李建伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202111155110.1A priority Critical patent/CN113837308B/zh
Publication of CN113837308A publication Critical patent/CN113837308A/zh
Priority to PCT/CN2022/083065 priority patent/WO2023050738A1/zh
Priority to JP2023510414A priority patent/JP2023547010A/ja
Application granted granted Critical
Publication of CN113837308B publication Critical patent/CN113837308B/zh
Priority to US18/151,639 priority patent/US20230162477A1/en
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/211Selection of the most significant subset of features
    • G06F18/2113Selection of the most significant subset of features by ranking or filtering the set of features, e.g. using a measure of variance or of feature cross-correlation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/50Extraction of image or video features by performing operations within image blocks; by using histograms, e.g. histogram of oriented gradients [HoG]; by summing image-intensity values; Projection analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/72Data preparation, e.g. statistical preprocessing of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/771Feature selection, e.g. selecting representative features from a multi-dimensional feature space
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Data Mining & Analysis (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Molecular Biology (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供了一种基于知识蒸馏的模型训练方法、装置、电子设备及存储介质,涉及计算机领域,尤其涉及计算机视觉、NLP等人工智能技术领域。具体实现方案为:将基于训练样本得到的特征向量分别输入第一编码层和第二编码层,其中,该第一编码层属于第一模型,该第二编码层属于第二模型;对该第一编码层输出的结果进行汇聚处理,得到第一特征向量;根据该第二编码层的输出确定第二特征向量;对该第一特征向量和该第二特征向量做蒸馏处理,得到更新后的第一特征向量。该方案用于模型压缩蒸馏训练,可以灵活地用于模型的任一层中,压缩效果好。压缩后的模型可用于图像识别,且可以被部署到各种计算能力有限的设备上。

Description

基于知识蒸馏的模型训练方法、装置、电子设备
技术领域
本公开涉及计算机技术领域,尤其涉及计算机视觉、NLP(Natural LanguageProcessing,自然语言处理)等人工智能技术领域,具体涉及一种基于知识蒸馏的模型训练方法、装置、电子设备及存储介质。
背景技术
随着信息技术的发展,神经网络模型被广泛用于注入计算机视觉、信息检索、信息识别等机器学习任务中。但是,为了更好的学习效果,神经网络模型往往具有海量的参数,一般需要耗费巨大的算例进行推断和部署,即,在训练和推断阶段会占用大量的计算资源,因此在一些资源受限的设备上无法对此类大型神经网络模型进行相应的部署。即在保证性能优异的同时,由于模型规模大、数据量大,大型神经网络模型往往对部署环境有着较高的要求,极大地限制了该类模型的使用。
发明内容
本公开提供了一种基于知识蒸馏的模型训练方法、装置、电子设备以及存储介质。
根据本公开的一方面,提供了一种基于知识蒸馏的模型训练方法,包括:
将基于训练的图像样本得到的特征向量分别输入第一编码层和第二编码层,其中,该第一编码层属于第一模型,该第二编码层属于第二模型;对该第一编码层输出的结果进行汇聚处理,得到第一特征向量;根据该第二编码层的输出确定第二特征向量;对该第一特征向量和该第二特征向量做蒸馏处理,更新所述第一特征向量;基于更新后的该第一特征向量进行分类,完成所述第一模型的训练。
根据本公开的另一方面,提供了一种图像识别的方法,包括:将待识别图像输入训练后的识别模型,该训练后的识别模型是利用基于知识蒸馏的模型训练方法训练获得;根据该训练后的识别模型,对该待识别图像进行识别处理。
根据本公开的另一方面,提供了一种基于知识蒸馏的模型训练装置,包括:输入模块,用于将基于训练的图像样本得到的特征向量分别输入第一编码层和第二编码层,其中,该第一编码层属于第一模型,该第二编码层属于第二模型;汇聚模块,用于对该第一编码层输出的结果进行汇聚处理,得到第一特征向量;确定模块,用于根据该第二编码层的输出确定第二特征向量;蒸馏模块,用于对该第一特征向量和该第二特征向量做蒸馏处理,更新该第一特征向量;分类模块,用于基于更新后的该第一特征向量进行分类,完成该第一模型的训练。
根据本公开的另一方面,提供了一种图像识别的装置,包括:模型输入模块,用于将待识别图像输入训练后的识别模型,该训练后的识别模型根据基于知识蒸馏的模型训练装置得到的;识别模块,用于根据该训练后的识别模型,对该待识别图像进行识别处理。
根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与该至少一个处理器通信连接的存储器;其中,该存储器存储有可被该至少一个处理器执行的指令,该指令被该至少一个处理器执行,以使该至少一个处理器能够执行本公开任一实施例中的方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,该计算机指令用于使计算机执行本公开任一实施例中的方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序/指令,其特征在于,该计算机程序/指令被处理器执行时实现本公开任一实施例中的方法。
本公开的技术,可以用于模型压缩蒸馏训练,汇聚后再进行蒸馏,可以灵活用于模型的任一层中,训练好模型的计算量大幅减小,压缩效果好,从而可以将训练好的模型部署到各种计算能力有限的设备上。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开一实施例的基于知识蒸馏的模型训练方法的流程示意图;
图2是根据本公开另一实施例的基于知识蒸馏的模型训练方法的流程示意图;
图3是根据本公开一实施例的计算机视觉领域的Transformer模型结构示意图;
图4是根据本公开一实施例的模型蒸馏示意图;
图5是根据本公开另一实施例的模型蒸馏示意图;
图6是根据本公开一实施例的一种图像识别方法的流程示意图;
图7是根据本公开一实施例的基于知识蒸馏的模型训练装置的结构示意图;
图8是根据本公开一实施例分类模块的结构示意图;
图9是根据本公开一实施例的一种图像识别装置的结构示意图;
图10是用来实现本公开实施例的知识蒸馏的训练方法或图像识别方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
现有技术中,Transformer模型是由某著名互联网公司开发的一种新型人工智能模型,近期,该模型被频频用于计算机视觉领域(CV领域),已被证实可以取得极佳的效果。但是,相比较于其余模型(如卷积神经网络模型),Transformer具有海量的参数,一般需要耗费巨大的算例进行推断和部署,即,在训练和推断阶段会占用大量的计算资源,因此在一些资源受限的设备上无法对Transformer进行相应的部署。
根据本公开的实施例,提供了一种基于知识蒸馏的模型训练方法,图1是根据本公开一实施例的基于知识蒸馏的模型训练方法的流程示意图,具体包括:
S101:将基于训练的图像样本得到的特征向量分别输入第一编码层和第二编码层,其中,该第一编码层属于第一模型,该第二编码层属于第二模型;
一示例中,第二编码层属于的第二模型是原始模型或训练好的模型,第一编码层属于的第一模型是新建模型或基于训练好的模型要生成的新模型。第一模型具体可以是学生模型,第二模型可以是教师模型。
一示例中,第一编码层和第二编码层是不同模型中相互对应的层,比如第一编码层是所属模型中的第三层,第二编码层是所属模型中与第一编码层对应的层,比如也可以是第三层。
一示例中,虽然理论上可以选取第一模型中的任意一层作为第一编码层,但是因为对模型的最后一层做蒸馏处理后实质没有降低计算量,因此不建议将最后一层作为第一编码层。一般选取模型中非最后一层的任一编码层作为第一编码层。
一示例中,该图像样本可以是图形图像。具体地,可以将多张大小相等的图片经过转换处理,生成维度相同的多个特征向量,其中,图片的张数等于生成特征向量的个数;比如,将某一要输入模型的图像切割成均等的多个小块(patch),每个小块图像的大小需要相等,图像内容可以有重叠;经过图像预处理以及特征向量转换后,生成维度相同的多个特征向量,每个小块对应一个特征向量。将由图像小块生成的多个特征向量并行输入该第一编码层和该第二编码层。如前所述,可以采用上述蒸馏方法对视觉领域Transformer模型进行压缩蒸馏。将待识别图像切分成多个小块,可以对每一小块中的图像内容进行细致地分类;将图像小块并行输入,可以通过并行处理提高整体的效率;图像小块可以存在重叠,降低了由于切分导致的遗漏部分特征的可能性。
S102:对该第一编码层输出的结果进行汇聚处理,得到第一特征向量;
一示例中,第一编码层输入的特征向量个数等于其输出的特征向量个数,汇聚处理是在该第一编码层输出的特征向量中提取特征,缩减特征向量的个数,这个过程也被称为剪枝。比如,第一编码层输出9个特征向量,经过汇聚后,获得5个特征向量。具体地,该汇聚操作可以是卷积操作。卷积可以高效滤出特征向量中的有用特征,起到高效浓缩的效果。
S103:根据该第二编码层的输出确定第二特征向量;
一示例中,可以将第二编码层输出的特征向量按照重要性重新排序后得到第二特征向量,也可以做一些特征加强处理后,再将处理后的特征向量按照重要性排序后得到第二特征向量。
S104:对该第一特征向量和该第二特征向量做蒸馏处理,更新所述第一特征向量。
一示例中,因为第一特征向量是经过至少一次汇聚处理的,所以第一特征向量中特征向量的个数小于第二特征向量中特征向量的个数,或者说,第一特征向量的大小要小于第二特征向量的大小。此时,需要从第二特征向量中抽取出与第一特征向量大小相等的特征向量用于之后的蒸馏,即可以从已经经过排序处理的第二特征向量中,抽取排序靠前的特征向量,或是抽取排序靠后的特征向量,此处不做具体限定,但是,抽取出的特征向量的大小必须和第一特征向量的大小相等,蒸馏后,得到更新后的第一特征向量,更新后的该第一特征向量学习到了第二模型中对应特征向量的一些特征。该蒸馏过程可以被称为汇聚蒸馏或剪枝蒸馏。通过抽取排序靠前的特征向量,可以让第一模型优先学习到第二模型的某些特征,而这些特征可以通过排序的规则灵活指定。比如,按照重要性排序后,抽取排序靠前的特征向量,即抽取训练好的模型中重要的特征向量用于被训模型的学习,可以大大提升模型蒸馏学习的效率。
一示例中,同一个模型中,可以选出多个不同的编码层作为第一编码层,即进行多次剪枝蒸馏。
S105:基于更新后的该第一特征向量进行分类,完成该第一模型的训练。
将更新后的第一特征向量输入下一层编码层,在从最后一层编码层得到输出后,基于该最后一层编码层的输出进行分类,完成该第一模型的训练。
采用上述实施例,可以挑选出任一层进行剪枝,并且在训练时对对应层输出的特征向量进行排序,然后对齐剪枝后的特征向量和排序后的特征向量,进行连接知识蒸馏。本公开提出一种压缩知识蒸馏方案,用于模型压缩蒸馏训练,上述剪枝-蒸馏的技术方案可以灵活用于模型的任一层中,训练好模型的计算量大幅减小,压缩效果好,从而可以将训练好的模型部署到各种计算能力有限的设备上。
根据本公开的实施例,提供了另一种基于知识蒸馏的模型训练方法,图2是根据本公开另一实施例的基于知识蒸馏的模型训练方法的流程示意图,具体包括:
S201:将更新后的该第一特征向量输入第三编码层,该第三编码层属于该第一模型;
一示例中,在经过至少一次蒸馏处理后,将更新后的第一特征向量再输入第三编码层,其中,第三编码层和第一编码层属于同一模型。
S202:将蒸馏处理后得到的更新后的第二特征向量输入第四编码层,该第四编码层属于第二模型;
一示例中,在经过至少一次蒸馏处理后,将更新后的第二特征向量再输入第四编码层,其中,第四编码层和第二编码层属于同一模型。
S203:对该第三编码层与该第四编码层的输出结果做再次蒸馏处理,得到优化结果。
一示例中,将第三编码层与第四编码层输出的结果做再次蒸馏处理,其中,因为第三编码层输入的是已经经过汇聚的特征向量,因此第三编码层输出的特征向量的个数会小于第四编码层输出的特征向量的个数,可以从第四编码层的输出中根据预设条件选取与第三编码层的输出数量相等的特征向量,再将二者进行蒸馏处理。其中,预设条件可以是选取重要性排序靠前的特征向量,也可以是其余排序方式,此处不做限定。经过再次蒸馏后,得到优化结果。该蒸馏方式被称为直接蒸馏,一示例中,可以对模型中的特征向量进行多次直接蒸馏。
S204:基于该优化结果进行分类,完成该第一模型的训练。
一示例中,基于优化结果进行分类,即是在从最后一层编码层得到输出后,基于该最后一层编码层的输出进行分类,完成该第一模型的训练。
采用上述示例,可以在剪枝蒸馏的基础上,选择没有剪枝蒸馏的编码层进行直接蒸馏。因为蒸馏过程其实是两个模型互相学习的过程,所以用上述直接蒸馏方式搭配剪枝蒸馏,可以更好地让受训模型无限接近初始模型,可以让第一模型更快、更好地接近第二模型,提升了训练过程的效率。
根据本公开的实施例,还包括:根据该第一模型中最后一个编码层得到的特征向量,获得分类结果;在该蒸馏处理中的蒸馏损失值小于固定阈值的情况下,根据该分类的结果得到分类正确率。
一示例中,经过所有的编码层后,将第一模型最终输出的优化后的特征向量输入分类器,得到分类结果。分类结果是经过多层编码层处理后训练的图像样本(以下简称训练样本)的分类结果,比如训练样本属于A类的概率是90%,属于B类的概率是10%。在训练过程中,特征向量肯定会经过至少一次蒸馏,可以基于蒸馏操作得到蒸馏的损失值(蒸馏loss),当蒸馏损失值小于某一固定阈值时,即认为训练充足,基于得到的分类结果,结合实际结果,得到分类的正确率。
一示例中,当模型训练充足以后,需要使用测试集验证模型是否有好的表现,或者是否需要继续训练。测试集是若干个测试样本组成的集合,训练时使用的是训练集,训练集是由训练样本组成的集合,比如对于图像识别任务,测试集可以有5000个测试样本(也就可以认为是5000张图片),训练集由10000个训练样本(10000张图片)组成。某个训练样本或者测试样本属于哪个类是根据这个样本的对应某个类别概率决定的。一般是选取最大概率值对应的类别作为该样本的预测类别,如果针对某一张图的预测类别和样本本身类别相同,则该样本预测正确。而分类正确率是预测正确的样本数量除以样本总数得到,比如对于测试集的分类正确率是这样:预测正确类别的有4500张,总数5000张,正确率是90%(4500/5000*100%)。
一示例中,可以用相同的样本训练多次,或用不同的样本训练多次。每次训练,都可以基于最终输出获得一个分类结果。在多次训练的后期,在蒸馏损失值均小于某一阈值,或蒸馏损失值越来越趋于稳定的情况下,则认为训练充足。此时,可以统计多次的分类结果,得到分类正确率。
上述分类正确率,可以表征训练好的模型最终的分类正确性,当分类正确性符合某一预设目标时,则预示着模型训练完成,可以投入使用。
根据本公开的实施例,在分类正确率不符合预设目标的情况下,可以继续重复训练。
一示例中,在该第一模型有多个编码层且该分类正确率不符合预设目标的情况下,在该多个编码层中选择除第一编码层以外的任一编码层的输出作为汇聚处理的输入,继续训练,具体地,在分类正确率不符合预设目标且被训模型中包括多个编码层的情况下,可以重新在该模型中选择第一编码层,但该第一编码层不能和之前的第一编码层一样。采用本示例中的方案,在仅仅通过重复训练无法得到理想的受训结果时,可以变更汇聚的相关编码层。比如之前在第二层编码层进行降维,发现剪枝率过高,导致训练分类正确率无法达到预期,可以调整剪枝的位置,使剪枝率降低。即更换第一编码层之后,再重新训练,可以提高训练效率。
应用示例:
应用本申请实施例一处理流程包括如下内容:
在训练之前,会先得到一个训练好的模型,该训练好的模型可以是用于计算机视觉领域的Transformer模型(也被称为vision transformer模型或视觉转换器模型),如图3所示。该模型包括一层图像向量转换层(Linear projection or flattened patches)和多层编码层(transformer layer)。其中,图像向量转换层主要是对输入的图像做线性变换和/或像素压平排列,将输入的图像重塑成一个向量;每个编码层都由多个编码器组成,该编码器依次由标准模块、多头注意力模块(Multi-Head Attention)、标准模块多层感知模块(MLP,也叫多层感知机,一般有两层)组成。每层中编码器的个数由输入的特征向量个数决定。每个特征向量都会被输入一个编码器,然后输出一个处理后的特征向量。编码层不会改变输入特征向量的个数。
在实际使用场景中,一般将图像分成多个小块(patch),每个小块的大小相等,每个小块对应模型的一个输入位置,经过图像向量转换层之后,生成与小块个数相等的特征向量,然后该特征向量依次进入多个编码层,每个编码层中的一个编码器处理一个特征向量。最后一个编码层输出的特征向量会被输入分类器,然后得到分类结果。该分类结果可以是一个概率值,比如识别出该输入图像是一只狗的概率是90%,是一只猫的概率是10%。
上文中描述的vision transformer模型可以同时处理多张输入图像,因此计算量较大,占用计算资源多,耗时较长。具体地,公式(1)-(4)是用于推导模型中一个编码器计算量的公式,其中,公式(1)-(3)分别估算的是编码器计算过程中三个主要步骤的计算量,公式(4)代表着整个编码器的计算量。其中,N代表着输入patch的个数或输入特征向量的个数,D是一个嵌入维度的大小(embedding size/embedding dim),是训练过程中特征向量中头(head,也被称为自注意力头、单个自注意力计算头)的个数以及每个特征向量的维度(dim,也称为特征向量的长度)的乘积,[N,D]代表维度是(N,D)的矩阵,[D,D]代表维度是(D,D)的矩阵,[N,D]、[N,N]类似,此处不再赘述。
4×([N,D]×[D,D])=>4ND2(1)
[N,D]×[D,N]+[N,N]×[N,D]=>2N2 D(2)
[N,D]×[D,4D]+[N,4D]×[4D,D]=>8ND2(3)
12ND2+2N2 D(4)
现有技术中如果要对模型进行压缩,方法主要包括两类:一类是将新模型(也叫学生模型)的层数缩小,即如果训练好的模型(也叫教师模型)有N层,即设定新模型有M层,其中M<N,这样可以减少运算量,起到压缩的效果。在知识蒸馏的过程中,只需要将新模型与训练好的模型选择一种连接方式即可,比如间隔层连接。
另一类是新模型的层数仍然保持和训练好模型的层数相同,利用上述公式,可以得知,此时需要压缩D,具体地,要么压缩head,要么压缩dim.
基于上一段的描述,压缩模型的两类方案基本上是从模型的层数和embeddingdim(也被称为feature dim)入手。本公开提出另外一类方案,由公式(1)-(4)可以看出要使最终的计算量降低,除上述两类方案,还可以从图片被分成的patch数(该patch数对应着训练过程中特征向量的个数,也可以使用sequence或者token数表示)下手。即对学生模型每一层进行剪枝,并且在训练时根据教师模型注意力层的值对教师模型的每层的特征向量在序列维度进行排序,然后对齐学生模型的前N个patch进行连接知识蒸馏。
本实施例中,教师模型和学生模型有着相同的编码层个数,相同的编码层结构,即每层中都包含相同的编码器。但是对应层编码器的初始参数与一定相同,具体地可以根据实际应用设置生成。
具体蒸馏方式如图4所示,左边是学生模型,右边是训练好的教师模型。训练样本是N小块图像,经过转换后得到N个特征向量,分别输入属于学生模型的第一编码层和属于教师模型的第二编码层,第一编码层输出N个特征向量后,将这N个特征向量输入一个汇聚层,得到压缩后的M个特征向量,其中,M<N;第二编码层输出N个特征后,进行排序,具体可以根据注意力机制(Attention Mechanism)进行排序,排序后选出前M个与学生模型中的M个特征向量进行蒸馏处理。其中,计算机视觉中所说的注意力机制是一种可以帮助模型对输入的X每个部分赋予不同的权重,抽取出更加关键及重要的信息,使模型做出更加准确的判断。注意力机制的本质就是利用相关特征图学习权重分布,再用学出来的权重施加在原特征图之上最后进行加权求和。Softmax函数(归一化函数)一般用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类。上述蒸馏方式也被称作汇聚蒸馏或剪枝蒸馏。
一示例中,排序具体可以根据注意力机制中的cls token的注意力值排序。
一示例中,教师模型利用注意力机制和softmax对特征向量的重要性进行排序,具体包括以下步骤:
利用模型每层中计算各特征向量间相互注意力值权重,该权重可以使用归一化函数(softmax)计算,也可以使用其他确定注意力值函数计算,获取各特征向量间相互注意力概率,概率越大说明某个特征向量对分类越重要。通过上述概率值进行排序。
另,蒸馏损失函数可以有多种,以均方差损失(MSE loss)为例,假如学生模型经过降维后,在某层令牌特征个数为n,对于教师模型,某层令牌特征得到相互注意力概率,依据上述概率值排序后,选取前n个令牌特征,与学生模型的n个令牌特征做均方差损失计算。
参考图4,对于学生模型L(i)层模型的输入维度是[B,N,D](B是batch size(一批样本的数量),N是特征向量的个数,D是embedding dim,使用卷积(conv1d)操作(或者其他汇聚操作),得到[B,M,D](M<N)。对于教师模型L(i)层模型,输入维度是[B,N,D],通过训练获取注意力排序后的值[B,H,N,N](H是多头注意力模块中头的个数,其中D=H*d,d是单个头的大小),由于注意力的值是softmax后的结果,softmax结果是特征向量的重要性概率,根据该概率值,对教师模型的特性向量进行排序,截取最重要的前M个进行蒸馏连接,这样即实现了模型训练的裁剪蒸馏过程。
可以看出,上述实施例中介绍了一种模型蒸馏方式,即将学生模型中某一层的输出进行汇聚操作,将教师模型中对应层进行排序操作,然后将对应的特征向量进行蒸馏。将特征向量汇聚的操作也叫剪枝。因为每层中编码器的个数由输入的特征向量个数决定,因此特征向量减少后每层中的编码器会相应减少,起到了压缩学生模型的效果。
除此之外,还有另一种蒸馏方式,可以被称为直接蒸馏,如图4所示,在经过至少一次汇聚处理之后,也可以对编码层的输出直接进行蒸馏处理。此时,学生模型中第三编码层输出的仍然有M个特征向量,从教师模型第四编码层中选出M个特征向量与学生模型进行蒸馏,选取过程与上一种蒸馏方法相同,此处不再赘述。
基于最后一层编码层输出的结果,可以得到分类结果,基于该分类结果可以得到分类正确率(也叫分类指标)。分类指标指:假如测试集有1000张图片,不同类别,模型对这些图做了分类判断,假如判断对了800张,那么分类指标就是80%。在目标测试集数据上,当训练充分时,分类指标会趋于稳定,不再上升,在这时,一般蒸馏loss(蒸馏损失值)也会趋于稳定。因此,可以以分类指标趋于稳定或蒸馏损失值趋于稳定时,即可认为模型的训练完成。
需要强调的是,上述两种蒸馏,均可以用于模型中任意一层,且可以多次重复使用。在学生模型中多次使用蒸馏可具体参考图5。可以看到,教师模型和学生模型均有9层编码层,在第3层(L3)和第6层(L6)之后用了剪枝蒸馏;在第9层(L9)用了直接蒸馏。
另,要剪枝压缩的模型一般是固定的,即学生模型的层数是训练前就被定下的。如果重复训练多次之后,正确率仍然无法满足预设要求,一般会调整降维蒸馏的地方,即调整汇聚的地方,比如之前在第2层进行降维,发现剪枝率过高,导致训练正确率无法达到预期,可以调整剪枝的位置,使剪枝率降低。
如图6所示,本公开的实施例中提供一种图像识别的方法,该方法包括:
S601:将待识别图像输入训练后的识别模型,所述训练后的识别模型根据上述基于知识蒸馏的模型训练方法训练获得;
S602:根据所述训练后的识别模型,对所述待识别图像进行识别处理。
一示例中,“基于知识蒸馏的模型训练方法”是上文中公开的训练方法,此处不再赘述。将待识别图像输入该识别模型,具体地,在输入前需要根据模型的具体需求对待识别图像进行处理,比如切分成多个小块后,将多个小块并行输入该模型。该训练好的识别模型是压缩好的模型,该模型具有运算量小、占用资源空间小的优势,可以灵活部署到各种计算能力有限的设备上。
需要强调的是,该图像识别方法与上述训练方法的执行主体可以是同一主体,也可以是不同的主体。即,可以在同一设备上训练好模型,再在同一设备上利用该训练好的模型实施该识别方法,也可以在不同的设备上分别进行模型的训练和应用。
一示例中,上述图像识别的方法还可扩展用于图像物体检测,图像分割等场景中,图像物体检测即是在识别出图像中物体类型的基础上,还得到该物体的具体位置;图像分割是在得到被识别出的物体类型、位置基础上,再进一步精确识别出物体的边缘,并沿着边缘切割。总之,上述图像识别方法还可用于多种以图像识别为基础的应用场景中,此处不做限定。
如图7所示,本公开的实施例中提供一种基于知识蒸馏的模型训练装置700,该装置包括:
输入模块701,用于将基于训练样本得到的特征向量分别输入第一编码层和第二编码层,其中,该第一编码层属于第一模型,该第二编码层属于第二模型;
汇聚模块702,用于对该第一编码层输出的结果进行汇聚处理,得到第一特征向量;
确定模块703,用于根据该第二编码层的输出确定第二特征向量;
蒸馏模块704,用于对该第一特征向量和该第二特征向量做蒸馏处理,更新所述第一特征向量;
分类模块705,用于基于更新后的该第一特征向量进行分类,完成所述第一模型的训练。
如图8所示,本公开的实施例中提供的分类模块705,该模块包括:
第一输入单元801,用于将更新后的该第一特征向量输入第三编码层,该第三编码层属于该第一模型;
第二输入单元802,用于将该蒸馏处理后得到的更新后的第二特征向量输入第四编码层,该第四编码层属于该第二模型;
蒸馏单元803,用于对该第三编码层与该第四编码层的输出结果做再次蒸馏处理,得到优化结果;
分类单元804,用于基于该优化结果进行分类,完成该第一模型的训练。
一示例中,上述蒸馏模块用于:对该第一特征向量和该第二特征向量中排序靠前的特征向量做蒸馏处理,其中,该第一特征向量的大小和该第二特征向量中排序靠前的特征向量的大小相等。
一示例中,上述任一装置还包括:
分类结果获取模块,用于根据该第一模型中最后一个编码层输出的特征向量,得到分类结果;
分类正确率获取模块,用于在该蒸馏处理中的蒸馏损失值小于固定阈值的情况下,根据该分类的结果得到分类正确率。
一示例中,上述装置还包括:
重新选择模块,用于在该第一模型有多个编码层且该分类正确率不符合预设目标的情况下,在该多个编码层中选择除第一编码层以外的任一编码层的输出作为汇聚处理的输入,继续训练该第一模型。
一示例中,该汇聚模块用于:对该第一编码层输出的结果进行卷积处理。
一示例中,该输入模块用于:
将多张大小相等的图片经过转换处理,生成维度相同的多个特征向量,其中,该图片的张数等于生成特征向量的个数;
将该多个特征向量并行输入该第一编码层和该第二编码层。
如图9所示,本公开的实施例中提供一种图像识别装置900,该装置包括:
模型输入模块901,用于将待识别图像输入训练后的识别模型,所述训练后的识别模型上述任一实施例中的基于知识蒸馏的模型训练装置获得;
识别模块902,用于根据所述训练后的识别模型,对所述待识别图像进行识别处理。
本公开实施例各装置中的各模块的功能可以参见上述方法中的对应描述,在此不再赘述。
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图10示出了可以用来实施本公开的实施例的示例电子设备1000的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图10所示,设备1000包括计算单元1001,其可以根据存储在只读存储器(ROM)1002中的计算机程序或者从存储单元1008加载到随机访问存储器(RAM)1003中的计算机程序,来执行各种适当的动作和处理。在RAM 1003中,还可存储设备1000操作所需的各种程序和数据。计算单元1001、ROM 1002以及RAM 1003通过总线1004彼此相连。输入/输出(I/O)接口1005也连接至总线1004。
设备1000中的多个部件连接至I/O接口1005,包括:输入单元1006,例如键盘、鼠标等;输出单元1007,例如各种类型的显示器、扬声器等;存储单元1008,例如磁盘、光盘等;以及通信单元1009,例如网卡、调制解调器、无线通信收发机等。通信单元1009允许设备1000通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元1001可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元1001的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元1001执行上文所描述的各个方法和处理,例如基于知识蒸馏的模型训练方法,或一种图像识别方法。例如,在一些实施例中,图像识别方法、基于知识蒸馏的模型训练方法均可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元1008。在一些实施例中,计算机程序的部分或者全部可以经由ROM 1002和/或通信单元1009而被载入和/或安装到设备1000上。当计算机程序加载到RAM 1003并由计算单元1001执行时,可以执行上文描述的基于知识蒸馏的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元1001可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行基于知识蒸馏的模型训练方法或图形图像识别方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (16)

1.一种基于知识蒸馏的模型训练方法,包括:
将多张大小相等的图片经过转换处理,生成维度相同的多个特征向量,所述图片的张数等于生成特征向量的个数;
将所述多个特征向量并行输入第一编码层和第二编码层,其中,所述第一编码层属于第一模型,所述第二编码层属于第二模型,所述第二模型是原始模型或训练好的模型,所述第一模型是新建模型或基于训练好的模型待生成的新模型,所述第一编码层为非最后一层的任一编码层;
对所述第一编码层输出的结果进行汇聚处理,得到第一特征向量,所述汇聚处理是在所述第一编码层输出的特征向量中提取特征,缩减特征向量的个数;
根据所述第二编码层的输出确定第二特征向量;
对所述第一特征向量和所述第二特征向量中抽取的特征向量做蒸馏处理,更新所述第一特征向量,其中,所述第一特征向量的大小和所述第二特征向量中抽取的特征向量的大小相等;
基于更新后的所述第一特征向量进行分类,完成所述第一模型的训练。
2.根据权利要求1所述的方法,其中,所述基于更新后的所述第一特征向量进行分类,完成所述第一模型的训练,包括:
将更新后的所述第一特征向量输入第三编码层,所述第三编码层属于所述第一模型;
将所述蒸馏处理后得到的更新后的第二特征向量输入第四编码层,所述第四编码层属于所述第二模型;
对所述第三编码层与所述第四编码层的输出结果做再次蒸馏处理,得到优化结果;
基于所述优化结果进行分类,完成所述第一模型的训练。
3.根据权利要求1所述的方法,其中,所述对所述第一特征向量和所述第二特征向量中抽取的特征向量做蒸馏处理,包括:
对所述第一特征向量和所述第二特征向量中排序靠前的特征向量做蒸馏处理,其中,所述第一特征向量的大小和所述第二特征向量中排序靠前的特征向量的大小相等。
4.根据权利要求1所述的方法,还包括:
在所述蒸馏处理中的蒸馏损失值小于固定阈值的情况下,根据所述分类的结果得到分类正确率。
5.根据权利要求4所述的方法,还包括:
在所述第一模型有多个编码层且所述分类正确率不符合预设目标的情况下,在所述多个编码层中选择除第一编码层以外的任一编码层的输出作为汇聚处理的输入,继续训练所述第一模型。
6.根据权利要求1所述的方法,其中,所述对所述第一编码层输出的结果进行汇聚处理,包括:
对所述第一编码层输出的结果进行卷积处理。
7.一种图像识别的方法,包括:
将待识别图像输入训练后的识别模型,所述训练后的识别模型根据权利要求1-6中任一项所述的基于知识蒸馏的模型训练方法训练获得;
根据所述训练后的识别模型,对所述待识别图像进行识别处理。
8.一种基于知识蒸馏的模型训练装置,包括:
输入模块,用于将多张大小相等的图片经过转换处理,生成维度相同的多个特征向量,所述图片的张数等于生成特征向量的个数; 将所述多个特征向量并行输入第一编码层和第二编码层,其中,所述第一编码层属于第一模型,所述第二编码层属于第二模型,所述第二模型是原始模型或训练好的模型,所述第一模型是新建模型或基于训练好的模型待生成的新模型,所述第一编码层为非最后一层的任一编码层;
汇聚模块,用于对所述第一编码层输出的结果进行汇聚处理,得到第一特征向量,所述汇聚处理是在所述第一编码层输出的特征向量中提取特征,缩减特征向量的个数;
确定模块,用于根据所述第二编码层的输出确定第二特征向量;
蒸馏模块,用于对所述第一特征向量和所述第二特征向量中抽取的特征向量做蒸馏处理,更新所述第一特征向量,其中,所述第一特征向量的大小和所述第二特征向量中抽取的特征向量的大小相等;
分类模块,用于基于更新后的所述第一特征向量进行分类,完成所述第一模型的训练。
9.根据权利要求8所述的装置,其中,所述分类模块包括:
第一输入单元,用于将更新后的所述第一特征向量输入第三编码层,所述第三编码层属于所述第一模型;
第二输入单元,用于将所述蒸馏处理后得到的更新后的第二特征向量输入第四编码层,所述第四编码层属于所述第二模型;
蒸馏单元,用于对所述第三编码层与所述第四编码层的输出结果做再次蒸馏处理,得到优化结果;
分类单元,用于基于所述优化结果进行分类,完成所述第一模型的训练。
10.根据权利要求9所述的装置,其中,所述蒸馏模块用于:
对所述第一特征向量和所述第二特征向量中排序靠前的特征向量做蒸馏处理,其中,所述第一特征向量的大小和所述第二特征向量中排序靠前的特征向量的大小相等。
11.根据权利要求8至10中任一项所述的装置,还包括:
分类结果获取模块,用于根据所述第一模型中最后一个编码层输出的特征向量,得到分类结果;
分类正确率获取模块,用于在所述蒸馏处理中的蒸馏损失值小于固定阈值的情况下,根据所述分类的结果得到分类正确率。
12.根据权利要求11所述的装置,还包括:
重新选择模块,用于在所述第一模型有多个编码层且所述分类正确率不符合预设目标的情况下,在所述多个编码层中选择除第一编码层以外的任一编码层的输出作为汇聚处理的输入,继续训练所述第一模型。
13.根据权利要求8所述的装置,其中,所述汇聚模块用于:
对所述第一编码层输出的结果进行卷积处理。
14.一种图像识别的装置,包括:
模型输入模块,用于将待识别图像输入训练后的识别模型,所述训练后的识别模型根据权利要求8-13中任一项所述的基于知识蒸馏的模型训练装置获得;
识别模块,用于根据所述训练后的识别模型,对所述待识别图像进行识别处理。
15.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-7中任一项所述的方法。
16.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-7中任一项所述的方法。
CN202111155110.1A 2021-09-29 2021-09-29 基于知识蒸馏的模型训练方法、装置、电子设备 Active CN113837308B (zh)

Priority Applications (4)

Application Number Priority Date Filing Date Title
CN202111155110.1A CN113837308B (zh) 2021-09-29 2021-09-29 基于知识蒸馏的模型训练方法、装置、电子设备
PCT/CN2022/083065 WO2023050738A1 (zh) 2021-09-29 2022-03-25 基于知识蒸馏的模型训练方法、装置、电子设备
JP2023510414A JP2023547010A (ja) 2021-09-29 2022-03-25 知識の蒸留に基づくモデルトレーニング方法、装置、電子機器
US18/151,639 US20230162477A1 (en) 2021-09-29 2023-01-09 Method for training model based on knowledge distillation, and electronic device

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111155110.1A CN113837308B (zh) 2021-09-29 2021-09-29 基于知识蒸馏的模型训练方法、装置、电子设备

Publications (2)

Publication Number Publication Date
CN113837308A CN113837308A (zh) 2021-12-24
CN113837308B true CN113837308B (zh) 2022-08-05

Family

ID=78967643

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111155110.1A Active CN113837308B (zh) 2021-09-29 2021-09-29 基于知识蒸馏的模型训练方法、装置、电子设备

Country Status (4)

Country Link
US (1) US20230162477A1 (zh)
JP (1) JP2023547010A (zh)
CN (1) CN113837308B (zh)
WO (1) WO2023050738A1 (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113837308B (zh) * 2021-09-29 2022-08-05 北京百度网讯科技有限公司 基于知识蒸馏的模型训练方法、装置、电子设备
CN114758360B (zh) * 2022-04-24 2023-04-18 北京医准智能科技有限公司 一种多模态图像分类模型训练方法、装置及电子设备
CN117058437B (zh) * 2023-06-16 2024-03-08 江苏大学 一种基于知识蒸馏的花卉分类方法、系统、设备及介质
CN116797611B (zh) * 2023-08-17 2024-04-30 深圳市资福医疗技术有限公司 一种息肉病灶分割方法、设备及存储介质

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110175628A (zh) * 2019-04-25 2019-08-27 北京大学 一种基于自动搜索与知识蒸馏的神经网络剪枝的压缩算法
CN113159173A (zh) * 2021-04-20 2021-07-23 北京邮电大学 一种结合剪枝与知识蒸馏的卷积神经网络模型压缩方法

Family Cites Families (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108334934B (zh) * 2017-06-07 2021-04-13 赛灵思公司 基于剪枝和蒸馏的卷积神经网络压缩方法
US11410029B2 (en) * 2018-01-02 2022-08-09 International Business Machines Corporation Soft label generation for knowledge distillation
CN108830813B (zh) * 2018-06-12 2021-11-09 福建帝视信息科技有限公司 一种基于知识蒸馏的图像超分辨率增强方法
CN110837761B (zh) * 2018-08-17 2023-04-07 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
EP3748545A1 (en) * 2019-06-07 2020-12-09 Tata Consultancy Services Limited Sparsity constraints and knowledge distillation based learning of sparser and compressed neural networks
CN110852426B (zh) * 2019-11-19 2023-03-24 成都晓多科技有限公司 基于知识蒸馏的预训练模型集成加速方法及装置
CN112070207A (zh) * 2020-07-31 2020-12-11 华为技术有限公司 一种模型训练方法及装置
CN112116030B (zh) * 2020-10-13 2022-08-30 浙江大学 一种基于向量标准化和知识蒸馏的图像分类方法
CN112699958A (zh) * 2021-01-11 2021-04-23 重庆邮电大学 一种基于剪枝和知识蒸馏的目标检测模型压缩与加速方法
CN113159073B (zh) * 2021-04-23 2022-11-18 上海芯翌智能科技有限公司 知识蒸馏方法及装置、存储介质、终端
CN113837308B (zh) * 2021-09-29 2022-08-05 北京百度网讯科技有限公司 基于知识蒸馏的模型训练方法、装置、电子设备

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110175628A (zh) * 2019-04-25 2019-08-27 北京大学 一种基于自动搜索与知识蒸馏的神经网络剪枝的压缩算法
CN113159173A (zh) * 2021-04-20 2021-07-23 北京邮电大学 一种结合剪枝与知识蒸馏的卷积神经网络模型压缩方法

Also Published As

Publication number Publication date
CN113837308A (zh) 2021-12-24
US20230162477A1 (en) 2023-05-25
JP2023547010A (ja) 2023-11-09
WO2023050738A1 (zh) 2023-04-06

Similar Documents

Publication Publication Date Title
CN113837308B (zh) 基于知识蒸馏的模型训练方法、装置、电子设备
WO2023065545A1 (zh) 风险预测方法、装置、设备及存储介质
CN110084216B (zh) 人脸识别模型训练和人脸识别方法、系统、设备及介质
CN115082920B (zh) 深度学习模型的训练方法、图像处理方法和装置
CN114494784A (zh) 深度学习模型的训练方法、图像处理方法和对象识别方法
CN115482395B (zh) 模型训练方法、图像分类方法、装置、电子设备和介质
CN116152833B (zh) 基于图像的表格还原模型的训练方法及表格还原方法
CN112561060A (zh) 神经网络训练方法及装置、图像识别方法及装置和设备
CN113902010A (zh) 分类模型的训练方法和图像分类方法、装置、设备和介质
CN112418291A (zh) 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
CN114821063A (zh) 语义分割模型的生成方法及装置、图像的处理方法
CN114861758A (zh) 多模态数据处理方法、装置、电子设备及可读存储介质
CN113947700A (zh) 模型确定方法、装置、电子设备和存储器
CN115294405B (zh) 农作物病害分类模型的构建方法、装置、设备及介质
CN110717577A (zh) 一种注意区域信息相似性的时间序列预测模型构建方法
CN113361522B (zh) 用于确定字符序列的方法、装置和电子设备
CN113947195A (zh) 模型确定方法、装置、电子设备和存储器
CN114882388A (zh) 多任务模型的训练及预测方法、装置、设备和介质
CN114419327A (zh) 图像检测方法和图像检测模型的训练方法、装置
CN114330576A (zh) 模型处理方法、装置、图像识别方法及装置
CN113989845A (zh) 姿态分类方法和姿态分类模型的训练方法、装置
CN116468985B (zh) 模型训练方法、质量检测方法、装置、电子设备及介质
CN116977021B (zh) 基于大数据的系统对接自动推单方法
CN111274216B (zh) 无线局域网的识别方法、识别装置、存储介质及电子设备
CN115205733A (zh) 视频识别方法、装置、设备、系统及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant