CN117892841A - 基于渐进式联想学习的自蒸馏方法及系统 - Google Patents
基于渐进式联想学习的自蒸馏方法及系统 Download PDFInfo
- Publication number
- CN117892841A CN117892841A CN202410288255.6A CN202410288255A CN117892841A CN 117892841 A CN117892841 A CN 117892841A CN 202410288255 A CN202410288255 A CN 202410288255A CN 117892841 A CN117892841 A CN 117892841A
- Authority
- CN
- China
- Prior art keywords
- sample
- network
- self
- distillation
- learning
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 124
- 238000004821 distillation Methods 0.000 title claims abstract description 78
- 230000013016 learning Effects 0.000 title claims abstract description 43
- 230000000750 progressive effect Effects 0.000 title claims abstract description 41
- 238000012549 training Methods 0.000 claims abstract description 60
- 238000009826 distribution Methods 0.000 claims abstract description 44
- 230000008569 process Effects 0.000 claims abstract description 43
- 238000007781 pre-processing Methods 0.000 claims abstract description 14
- 230000005012 migration Effects 0.000 claims abstract description 3
- 238000013508 migration Methods 0.000 claims abstract description 3
- 230000006870 function Effects 0.000 claims description 41
- 230000035045 associative learning Effects 0.000 claims description 22
- 238000002156 mixing Methods 0.000 claims description 19
- 238000000605 extraction Methods 0.000 claims description 12
- 238000004364 calculation method Methods 0.000 claims description 7
- 238000012935 Averaging Methods 0.000 claims description 3
- 239000000203 mixture Substances 0.000 claims description 3
- 238000012546 transfer Methods 0.000 claims description 3
- 238000010276 construction Methods 0.000 claims description 2
- 239000000284 extract Substances 0.000 abstract description 9
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000013459 approach Methods 0.000 description 14
- 230000000875 corresponding effect Effects 0.000 description 13
- 238000013140 knowledge distillation Methods 0.000 description 12
- 241000282414 Homo sapiens Species 0.000 description 6
- 238000012545 processing Methods 0.000 description 6
- 238000013528 artificial neural network Methods 0.000 description 5
- 230000008901 benefit Effects 0.000 description 4
- 230000006835 compression Effects 0.000 description 4
- 238000007906 compression Methods 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 3
- 238000010606 normalization Methods 0.000 description 3
- 238000013461 design Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- 241001465754 Metazoa Species 0.000 description 1
- 235000009499 Vanilla fragrans Nutrition 0.000 description 1
- 244000263375 Vanilla tahitensis Species 0.000 description 1
- 235000012036 Vanilla tahitensis Nutrition 0.000 description 1
- 238000007792 addition Methods 0.000 description 1
- 230000008033 biological extinction Effects 0.000 description 1
- 210000004556 brain Anatomy 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 230000002596 correlated effect Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 239000004744 fabric Substances 0.000 description 1
- 238000003709 image segmentation Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000007670 refining Methods 0.000 description 1
- 230000001105 regulatory effect Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000011144 upstream manufacturing Methods 0.000 description 1
Landscapes
- Image Analysis (AREA)
Abstract
本发明属于人工智能技术领域,公开了基于渐进式联想学习的自蒸馏方法及系统,所述方法包括步骤1、数据预处理;步骤2、构建关联样本;步骤3、学习类别特征和类间关系;将关联样本输入学生网络,利用关联样本训练学生网络,学生网络从关联样本中学习类别特征和类间关系,输出概率分布;步骤4、自蒸馏阶段;学生网络利用在步骤3学习到的类别特征和类间关系的知识,来模拟教师网络的输出概率分布,并将这些知识进行知识迁移以指导学生网络的自我学习过程。本发明通过自蒸馏的方式实现自我学习,减少对复杂教师网络的依赖;通过本发明提取更丰富的类别间关系知识。
Description
技术领域
本发明属于人工智能技术领域,特别涉及基于渐进式联想学习的自蒸馏方法及系统。
背景技术
自蒸馏(Self-Knowledge Distillation, Self-KD)是一种新兴的知识蒸馏方法,它允许深度神经网络通过自我蒸馏的方式提升其性能,而不依赖于传统知识蒸馏方法中的教师网络。这种方法在训练过程中让网络自己生成软标签,通过不断迭代进行自蒸馏,从而达到了与传统知识蒸馏相近的效果。尽管自蒸馏展示出了提高学生网络性能的潜力,但其也带来了额外的计算和参数开销。这主要体现在一下三个方面:
1)对复杂教师网络的依赖:传统的知识蒸馏方法依赖于预先训练的、复杂的教师网络来指导学生网络的训练。这种依赖限制了知识蒸馏的应用,尤其是在资源受限或无法访问预训练模型的情况下。这种方法的基本假设是教师网络能够提供更高质量的知识来指导学生网络,但这也意味着需要额外的时间和资源来训练和维护这些复杂的教师网络。在没有可用的预训练教师网络或资源受限的情况下,实现有效的知识蒸馏变得具有挑战性。
2)额外的计算和参数需求:尽管自蒸馏方法尽管通过辅助架构或数据增强来捕获额外的暗知识,可以有效提高网络性能,但这增加了模型的计算和参数需求。此方法通常涉及更复杂的网络结构或更高级的数据处理技术,从而增加了计算负担和内存需求。这限制了模型在计算资源受限的环境中的应用,尤其是在移动设备或边缘计算设备上。
3)忽略类别间关系的知识:大多数现有的自我知识蒸馏方法仅从单个输入样本中提取软标签作为监督信号,忽略了类别间的关系知识。此方法的局限性在于它只考虑了单个样本的信息,而没有利用数据集中不同样本之间的潜在关联。这导致了模型在理解和学习类别间复杂关系方面存在不足,进而限制了其性能和泛化能力。
发明内容
针对现有技术存在的不足,本发明提供基于渐进式联想学习的自蒸馏方法及系统,通过模拟人类的联想学习过程,使得学生网络能够在没有复杂教师网络的指导下自我提取和集成知识,它可以和多种分类框架进行结合,解决了传统知识蒸馏方法中对复杂教师网络过度依赖的问题,同时减少了额外的计算和参数需求。通过在蒸馏过程中考虑样本间的关系,基于渐进式联想学习的自蒸馏方法能够提取更丰富的类别间关系知识,特别是在处理复杂的分类任务时。
为了解决上述技术问题,本发明采用的技术方案是:
基于渐进式联想学习的自蒸馏方法,包括以下步骤:
步骤1、数据预处理;
输入数据是不同类型的图像,对输入的数据进行预处理,生成训练数据集样本,训练数据集样本包括原始样本及其对应的标签/>;从训练数据集中随机抽取一批样本,批量大小为B;
步骤2、通过样本混合方法构建关联样本;
将原始样本和在基于渐进式联想学习方法生成的生成样本/>通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本/>;
步骤3、学习类别特征和类间关系;
首先构建包括教师网络和学生网络的网络模型,然后将关联样本输入学生网络,利用关联样本/>训练学生网络,学生网络从关联样本/>中学习类别特征和类间关系,输出概率分布,用于自蒸馏阶段的训练;
步骤4、自蒸馏阶段;
学生网络充当教师网络,利用在步骤3学习到的类别特征和类间关系的知识,来模拟教师网络的输出概率分布,并将这些知识进行知识迁移以指导学生网络的自我学习过程;具体来说,学生网络的自我学习过程通过以下方式实现:
步骤4-1、利用教师网络生成的概率分布来监督关联样本的概率分布:具体来说,教师网络使用原始样本/>来生成概率分布,这些分布随后被用来指导学生网络的训练过程,使关联样本/>产生相近的概率分布;
步骤4-2、在使用关联样本进行训练的过程中,使用原始样本/>和基于渐进式联想学习方法生成的生成样本/>的概率分布来监督关联样本/>的概率分布;
步骤4-3、优化损失函数:整个训练过程是通过优化总损失函数来实现的,总损失为关联交叉熵损失和关联蒸馏损失的加权和;
循环重复执行步骤1-步骤4,直到学生网络的参数稳定下来,达到收敛状态。
进一步的,步骤2中,样本混合方法是通过加权平均的方式进行的,其中权重由一个预先定义的参数控制,参数/>决定每个原始样本/>在关联样本/>中的贡献程度,表示如下:
。
进一步的,步骤3中,引入温度参数T来调节softmax函数的输出,计算关联样本的损失;表示如下:
;
其中是指学生网络最终输出每个类别的概率,大小为B×C,C是指类别的总个数;/>是学生网络第i个类别的logits输出;/>是学生网络是当前类别的logits输出,logits输出指的是网络中最后一层softmax函数激活前的输出,/>为参数。
进一步的,引入改进的交叉熵损失函数,同时考虑原始样本和生成样本/>的信息,具体是,最小化关联交叉熵损失函数LMCE,记为/>,公式如下:
;
这个函数是两个标准交叉熵损失函数LCE的加权和,其中和/>分别是原始样本和生成样本/>的对应标签,权重/>用于调节这两个损失函数的相对贡献,由参数/>确定的网络模型输入关联样本/>得到每个类别的概率,再与原始样本/>的对应标签/>做交叉熵损失,/>由参数/>确定的网络模型输入关联样本/>得到每个类别的概率/>,再与生成样本/>的对应标签/>做交叉熵损失。
进一步的,总损失LALSD具体公式为:
;
在这个公式中,λ是关联样本交叉熵损失的权重,而β是关联样本蒸馏损失的权重;其中,针对关联样本,设计关联蒸馏损失函数LCls,衡量类间相似性,记为,表示如下:
;
这个函数是两个KL散度项的加权和,其中、/>和分别是以参数为/>的网络模型输入数据为原始样本/>的输出概率、以参数为/>的网络模型输入数据为生成样本/>为的输出概率和以参数/>的网络模型输入数据为关联样本/>为的输出概率,大小均为B×C,/>是采用了KL散度来衡量原始样本/>和关联样本/>之间的概率分布差异,是采用了KL 散度来衡量生成样本/>和关联样本/>之间的概率分布差异,/>是由贝塔分布生成的图像混合比例系数,用于平衡这两个项的贡献,而/>是参数/>的固定副本,在训练过程中学生网络和教师网络权重共享。
本发明还提供一种基于渐进式联想学习的自蒸馏系统,用于实现如前所述的基于渐进式联想学习的自蒸馏方法,所述系统包括数据预处理模块、构建关联样本模块、由教师网络和学生网络组成的网络模型、损失计算模块,
所述数据预处理模块用于获取输入的不同类型的图像数据,并进行预处理生成训练数据集样本,包括原始样本及其对应的标签;
所述构建关联样本模块用于将原始样本和在基于渐进式联想学习方法生成的生成样本通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本;
所述网络模型包括教师网络特征提取模块和学生网络特征提取模块,通过学生网络特征提取模块提取关联样本中的类别特征和类间关系,输出概率分布,用于模型自蒸馏训练;通过教师网络特征提取模块提取原始样本中的特征,输出概率分布,用来指导学生网络的自我学习过程;
所述损失计算模块用于计算关联交叉熵损失和关联蒸馏损失的加权和。
与现有技术相比,本发明优点在于:
1)减少对复杂教师网络的依赖:本发明模拟人类的联想学习过程,并通过自蒸馏的方式,允许学生网络在没有复杂教师网络的情况下自我提取知识和集成知识,从而简化了训练过程。这种方法减少了对预训练教师网络的依赖,使得知识蒸馏过程更加高效和灵活。
2)提取更丰富的类别间关系知识:与传统方法相比,基于渐进式联想学习的自蒸馏方法能够从不同小批次中的样本间关系中提取更丰富的类别间关系知识。这种方法的优势在于它不仅考虑了单个样本的信息,而且利用了数据集中不同样本之间的潜在关联,从而提高了模型的性能和泛化能力,特别是在处理复杂的分类任务时。
3)减少额外的计算和参数需求:基于渐进式联想学习的自蒸馏方法避免了通过辅助架构或数据增强来捕获额外暗知识所带来的额外计算和参数需求。这使得该方法在计算资源受限的环境中更具优势,特别是适用于移动设备或边缘计算设备。
4)本发明可以和多种分类框架进行结合,无论是传统分类还是细粒度分类,这种方法不仅优化了模型在处理图像识别和分类方面的效能,还通过模型压缩技术减少了对计算资源的需求,使得模型更适合在资源受限的环境中部署和运行,同时保持高精度和快速响应的特性。本发明适用于各种计算机视觉任务,实现深度神经网络在计算资源受限的嵌入式设备(例如无人机、手机等)上进行快速部署及应用。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明的系统结构图;
图2为本发明的渐进式联想学习关系图;
图3为本发明的方法流程图。
具体实施方式
下面结合附图及具体实施例对本发明作进一步的说明。
本实施例提供一种基于渐进式联想学习的自蒸馏方法,模拟人类的联想学习过程,使得学生网络能够在没有复杂教师网络的指导下自我提取和集成知识。本发明可以和多种分类框架进行结合,在本发明中,选择广泛使用的ResNet网络作为学生网络的基础架构。ResNet通过其独特的残差连接解决了深层网络中的梯度消失问题,使得网络能够有效地学习深层特征。
本发明的基于渐进式联想学习的自蒸馏方法旨在通过在训练过程中使用关联样本来构建和提取类别间的关系,从而提升自蒸馏的效果。此方法有两个主要目标。首先,它旨在弥补在构建关联样本过程中可能发生的类特定特性的丢失。由于这些关联样本是通过组合原始样本和生成样本构建的,存在原始特征丢失的风险,这可能会降低网络对原始类别的输出置信度。其次,这种方法允许网络通过自蒸馏过程更有效地吸收从关联样本中获得的类别间信息,并将这些信息应用于提升整体性能。
结合图1-图3所示,基于渐进式联想学习的自蒸馏方法包括以下步骤:
步骤1、数据预处理;
输入数据是不同类型的图像,对输入的数据进行预处理,生成训练数据集样本,训练数据集样本包括原始样本及其对应的标签/>;从训练数据集中随机抽取一批样本,批量大小为B。
本发明从所选的训练数据集中随机抽取一批样本,批量大小为B。这个过程是实施自蒸馏策略的起始点,对后续的训练至关重要。在这一步骤中,目标是确保抽取的样本代表了数据集的多样性和复杂性。
原始样本是视觉内容的集合,可能包括不同类型的图片,如动物、物体或场景等。这些图像是网络学习视觉特征的基础。相应的标签/>则为每个原始样本提供了分类信息,这些信息对于监督学习过程是必不可少的。
在样本采集过程中,随机性在训练网络时发挥着重要作用,因为它确保了训练数据的多样性,避免了模型对特定样本的过拟合。此外,这一步骤中还包括对样本数据的预处理操作,如调整大小、归一化或颜色标准化,以确保样本数据适合网络处理。
步骤2、通过样本混合方法构建关联样本;
将原始样本和在基于渐进式联想学习方法生成的生成样本/>通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本/>。
其中,基于渐进式联想学习方法结合图2所示,受人类联想学习过程的启发,本发明旨在模仿人类如何逐渐从有限的知识和对各种概念间关系的初步理解,通过不断接触新信息来学习并建立更加全面的关系图。因此,生成样本是根据原始样本/>在训练数据集大小为B一批样本中,根据置信度由简单到困难,渐进式地加入到训练过程,从而使得生成样本/>由最接近原始样本/>的类别,逐渐扩展关系图,直到置信度最困难的生成样本/>加入到训练在人类学习过程中,一开始知识是有限的,但随着新信息的不断接触和处理,人们开始理解并学习不同类别之间的联系,逐渐在大脑中形成一个全面的关系网络。类似地,基于渐进式联想学习的自蒸馏方法旨在通过自蒸馏来引导神经网络学习和理解不同类别之间的复杂关系。这种方法鼓励网络不仅学习单个类别的特征,还要理解这些类别如何相互关联,从而实现更深层次的学习和更强大的推理能力。通过这种联想学习过程,本发明提出的基于渐进式联想学习的自蒸馏方法使得网络能够建立起一个更加丰富和细致的知识体系,从而提高其处理和理解新信息的能力。
本发明的样本混合方法是一种数据增强技术,用来构建关联样本。样本混合方法的核心在于将两个不同的样本以像素级别混合,从而构建一个新的样本。具体地,这种结合是通过加权平均的方式进行的,其中权重由一个预先定义的参数控制,参数/>决定每个原始样本/>在关联样本/>中的贡献程度,表示如下:
;
这种方法不仅简单地叠加了两个样本的特征,而且在一定程度上融合了它们的信息。
步骤3、学习类别特征和类间关系;
首先构建包括教师网络和学生网络的网络模型,然后本步骤中使用步骤2生成的关联样本来作为模型的输入数据,输入的数据批量大小是B。将关联样本/>输入学生网络,利用关联样本/>训练学生网络,学生网络从关联样本/>中学习类别特征和类间关系,输出与之相关的概率分布,用于自蒸馏阶段的训练。
作为一个优选的实施方式,本发明中,使用由参数确定的网络模型的logits输出/>,并引入温度参数T来调节softmax函数的输出,计算关联样本的损失;表示如下:
;
其中是指学生网络最终输出每个类别的概率,大小为B×C,C是指类别的总个数;/>是学生网络第i个类别的logits输出;/>是学生网络是当前类别的logits输出,logits输出指的是网络中最后一层softmax函数激活前的输出。
这个做法借鉴了统计力学中的玻尔兹曼分布原理。根据这一原理,可以证明当温度参数T接近0时,softmax函数的输出会趋向于一个接近硬目标。而当温度参数T增加到很高时,softmax输出会变得更加“软化”,即输出的分布变得更加平滑。因此,可以通过调高温度T的值来产生足够平滑的分布,从而在训练学生模型时提供更为细致的概率信息。这样做使得学生模型的softmax输出更接近于教师模型的输出,帮助学生模型学习到从硬目标(即直接的类别标签)中无法获取的细微和隐含的知识。
作为一个优选的实施方式,为了有效地学习类别间的关系,本发明最小化关联交叉熵损失函数LMCE,这有助于网络精确地捕捉和区分不同类别间的细微差异。
引入改进的交叉熵损失函数,同时考虑原始样本和生成样本/>的信息,具体是,最小化关联交叉熵损失函数LMCE,记为/>,公式如下:
;
这个函数是两个标准交叉熵损失函数LCE的加权和,其中和/>分别是原始样本和生成样本/>的对应标签,权重/>用于调节这两个损失函数的相对贡献,由参数/>确定的网络模型输入关联样本/>得到每个类别的概率,再与原始样本/>的对应标签/>做交叉熵损失,/>由参数/>确定的网络模型输入关联样本/>得到每个类别的概率/>,再与生成样本/>的对应标签/>做交叉熵损失。
引入这种改进的交叉熵损失函数的目的是使得网络能够更加全面地理解和利用由关联样本引入的附加信息。这在处理具有高度多样性或复杂性的数据集时尤其有价值,因为这些数据集通常需要模型捕捉更细微的模式和关系。通过这种方法,可以同时考虑到原始样本/>和生成样本/>的信息,使得网络能够在学习过程中充分利用关联样本/>带来的额外信息,能够促进网络学习到更为全面和细致的特征表达,尤其是在处理复杂或多样化的数据时。
通过最小化关联交叉熵损失,学生网络能够更有效地学习和理解不同类别之间的关系。这种损失最小化策略是学习类特征和类别关系的关键环节,它为学生网络提供了一种机制,使得网络不仅能学习到单个类别的特征,还能把握类别间的相互关系和差异。这一步骤为下面自蒸馏打下了坚实的基础,使得在后续步骤中,学生网络能够有效地利用在此步骤中学到的知识。
步骤4、自蒸馏阶段;
本发明中,没有一个独立的、预先训练好的教师网络,而是学生网络自己通过在步骤3学习的知识来充当教师网络。利用在步骤3学习到的类别特征和类间关系的知识,来模拟教师网络的输出概率分布,并将这些知识自我转移和内化,以指导学生网络的自我学习过程。
这一步骤是自我学习过程,其中学生网络利用在步骤3学到的知识来进一步提升自身的性能。此时,学生网络本身充当了教师网络的角色。网络从原始样本和生成样本/>中进行自学,利用已经获得的知识来指导学习过程。
具体来说,学生网络的自我学习过程通过以下方式实现:
步骤4-1、利用教师网络生成的概率分布来监督关联样本的概率分布:具体来说,教师网络使用原始样本/>来生成概率分布,这些分布随后被用来指导学生网络的训练过程,使关联样本/>产生相近的概率分布。
步骤4-2、在使用关联样本进行训练的过程中,使用原始样本/>和基于渐进式联想学习方法生成的生成样本/>的概率分布来监督关联样本/>的概率分布,这样做的目的是迫使学生网络学习原始样本的特征,并将这些知识自我转移和内化。
本发明中,在使用关联样本进行训练的过程中,这些样本被用作网络的输入,输入的批次大小依然是B。目的是让网络能够学习到不同类别之间的关系。然而,使用关联样本/>可能会导致原始特征的部分丢失。为了克服这个问题,采取了一种策略,即使用原始样本/>和基于渐进式联想学习方法生成的生成样本/>的概率分布来监督关联样本/>的概率分布。这样做的目的是迫使学生网络学习原始样本的特征,并将这些知识自我转移和内化。
步骤4-3、优化损失函数:
在自蒸馏过程中的关键是,学生网络通过模拟教师网络的输出分布,来提升其对数据的理解和处理能力。由于教师网络是以原始样本为基础进行训练的,其输出分布被认为是更准确和可靠的。学生网络通过自蒸馏,即利用这些分布作为参考,可以有效地提升自身的性能。由于传统的蒸馏损失函数(vanilla distillation loss)不适用于关联样本作为输入的情况。因此,本发明针对关联样本/>,设计关联蒸馏损失函数,这个损失函数能够使模型更全面地学习原始样本/>和基于渐进式联想学习方法生成的生成样本/>的特征。
具体来说,针对关联样本,设计关联蒸馏损失函数LCls,衡量类间相似性,记为,表示如下:
;
这个函数是两个KL散度项的加权和,其中、/>和分别是以参数为/>的网络模型输入数据为原始样本/>的输出概率、以参数为/>的网络模型输入数据为生成样本/>为的输出概率和以参数/>的网络模型输入数据为关联样本/>为的输出概率,大小均为B×C,/>是采用了KL散度来衡量原始样本/>和关联样本/>之间的概率分布差异,是采用了KL 散度来衡量生成样本/>和关联样本/>之间的概率分布差异,/>是由贝塔分布生成的图像混合比例系数,用于平衡这两个项的贡献,而/>是参数/>的固定副本,在训练过程中学生网络和教师网络权重共享。其中本实施例的图1中,输出的概率分布可以通过坐标图表示,坐标图的横坐标为类别(生成样本的类别),纵坐标为不同类别对应的输出概率。
通过这种方法,模型不仅能够从原始样本和生成样本/>中学习特征,还能有效地从关联样本/>中吸收更丰富的信息。
综上,整个训练过程是通过优化总损失函数来实现的,总损失为关联交叉熵损失LMCE和关联蒸馏损失LCls的加权和,总体损失函数表示为LALSD,具体公式为:
;
在这个公式中,λ是关联样本交叉熵损失的权重,而β是关联样本蒸馏损失的权重。
通过这种方式,模型的训练不仅考虑了通过关联样本学习分类任务的需求(通过LMCE),同时也考虑了从关联样本中蒸馏知识的需求(通过LCls)。这种损失函数的设计旨在充分利用关联样本中包含的信息,以提高模型的性能和泛化能力。此外,温度参数T的平方项在LCls中起到调节作用,影响蒸馏过程中知识的软化程度,从而使蒸馏过程更加有效。
在训练时,循环重复执行步骤1-步骤4,直到学生网络的参数稳定下来,达到收敛状态,完成训练。具体来说,步骤3是通过关联样本学习类别间的关系,而步骤4是学生网络利用步骤3学习到的知识进行自蒸馏,从而进一步细化其学习。利用这些学到的关系来优化模型性能。这个迭代过程是一个动态的学习过程,其中学生网络不断通过新的数据和反馈调整其内部参数。每完成一次循环,模型就会根据累积的学习成果进一步细化和改进其参数。这个过程持续进行,直到模型的性能不再有显著提升,或者参数变化趋于微小,这表明模型已经接近最优状态,即参数收敛。
训练过程的目标是优化学生网络的参数,最终输出一个经过训练的、参数化的轻量级模型。为了开始训练,首先需要对学生网络的参数进行初始化。此外,还需要设定一些超参数,这些超参数将指导训练过程,对于传统的分类任务,选择了批处理大小为128,总训练周期数为200。而对于细粒度分类任务,将批处理大小和总周期数分别设置为32和200。在本发明的自蒸馏方法中,温度参数T被设定为4,以调节软标签的平滑程度。此外,损失权重λ和β分别被设置为0.1和1,用以平衡训练过程中不同损失函数的影响,能够取得较好的技术效果。通过对这些参数的精心选择和调整,可以确保学生网络有效地学习并从输入的样本数据中提取出关键特征,并准确地预测其对应的标签。
作为本发明的一个应用,通过前面的方法训练好的网络模型,可以用于图像分类(其中图像分类方法非本发明设计要点,此处不再赘述),并作为骨干网络用于目标检测、图像分割等上游任务,在网络模型参数量实现10倍压缩率的情况下,其精度与大模型保持相当,因此训练好的模型可以很便捷地部署于资源受限的嵌入式设备。即训练好的一个图像分类模型,输入图像先经过归一化、降噪等预处理后输入模型,通过模型提取图像特征,并进行分类,输出分类结果。
如图1所示,作为本发明的另一实施例,提供基于渐进式联想学习的自蒸馏系统,用于实现如前面所述的基于渐进式联想学习的自蒸馏方法。
所述系统包括数据预处理模块、构建关联样本模块、由教师网络和学生网络组成的网络模型、损失计算模块。
所述数据预处理模块用于获取输入的不同类型的图像数据,并进行预处理生成训练数据集样本,包括原始样本及其对应的标签。
所述构建关联样本模块用于将原始样本和在基于渐进式联想学习方法生成的生成样本通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本。
所述网络模型包括教师网络特征提取模块和学生网络特征提取模块,通过学生网络特征提取模块提取关联样本中的类别特征和类间关系,输出概率分布,用于模型自蒸馏训练;通过教师网络特征提取模块提取原始样本中的特征,输出概率分布,用来指导学生网络的自我学习过程。
所述损失计算模块用于计算关联交叉熵损失和关联蒸馏损失的加权和。
其中各个模块的详细功能及数据处理过程及整个基于渐进式联想学习的自蒸馏方法的具体步骤,可参见前面的记载,此处不再赘述。
综上所述,本发明的基于渐进式联想学习的自蒸馏方法通过其独特的自蒸馏策略,有效地解决了传统知识蒸馏方法中的多个问题。本发明解决了传统知识蒸馏方法中对复杂教师网络过度依赖的问题,同时减少了额外的计算和参数需求。通过在蒸馏过程中考虑样本间的关系,基于渐进式联想学习的自蒸馏方法能够提取更丰富的类别间关系知识,这在现有的自蒸馏方法中往往被忽略。这种方法的改进使得网络能够更有效地学习和泛化,特别是在处理复杂的分类任务时。
无论是传统分类还是细粒度分类,本发明不仅优化了模型在处理图像识别和分类方面的效能,还通过模型压缩技术减少了对计算资源的需求,使得模型更适合在资源受限的环境中部署和运行,同时保持高精度和快速响应的特性。
本发明提供了一种更高效、更全面的深度神经网络压缩方法,能够适用于各种计算机视觉任务,实现深度神经网络在计算资源受限的嵌入式设备(例如无人机、手机等)上进行快速部署及应用。这种方法的提出,不仅在技术上具有创新性,而且在实际应用中具有显著的优势。
当然,上述说明并非是对本发明的限制,本发明也并不限于上述举例,本技术领域的普通技术人员,在本发明的实质范围内,做出的变化、改型、添加或替换,都应属于本发明的保护范围。
Claims (6)
1.基于渐进式联想学习的自蒸馏方法,其特征在于,包括以下步骤:
步骤1、数据预处理;
输入数据是不同类型的图像,对输入的数据进行预处理,生成训练数据集样本,训练数据集样本包括原始样本及其对应的标签/>;从训练数据集中随机抽取一批样本,批量大小为B;
步骤2、通过样本混合方法构建关联样本;
将原始样本和在基于渐进式联想学习方法生成的生成样本/>通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本/>;
步骤3、学习类别特征和类间关系;
首先构建包括教师网络和学生网络的网络模型,然后将关联样本输入学生网络,利用关联样本/>训练学生网络,学生网络从关联样本/>中学习类别特征和类间关系,输出概率分布,用于自蒸馏阶段的训练;
步骤4、自蒸馏阶段;
学生网络充当教师网络,利用在步骤3学习到的类别特征和类间关系的知识,来模拟教师网络的输出概率分布,并将这些知识进行知识迁移以指导学生网络的自我学习过程;具体来说,学生网络的自我学习过程通过以下方式实现:
步骤4-1、利用教师网络生成的概率分布来监督关联样本的概率分布:具体来说,教师网络使用原始样本/>来生成概率分布,这些分布随后被用来指导学生网络的训练过程,使关联样本/>产生相近的概率分布;
步骤4-2、在使用关联样本进行训练的过程中,使用原始样本/>和基于渐进式联想学习方法生成的生成样本/>的概率分布来监督关联样本/>的概率分布,这样做的目的是迫使学生网络学习原始样本的特征,并将这些知识自我转移和内化;
步骤4-3、优化损失函数:整个训练过程是通过优化总损失函数来实现的,总损失为关联交叉熵损失和关联蒸馏损失的加权和;
循环重复执行步骤1-步骤4,直到学生网络的参数稳定下来,达到收敛状态。
2.根据权利要求1所述的基于渐进式联想学习的自蒸馏方法,其特征在于,步骤2中,样本混合方法是通过加权平均的方式进行的,其中权重由一个预先定义的参数控制,参数决定每个原始样本/>在关联样本/>中的贡献程度,表示如下:
。
3.根据权利要求1所述的基于渐进式联想学习的自蒸馏方法,其特征在于,步骤3中,引入温度参数T来调节softmax函数的输出,计算关联样本的损失;表示如下:
;
其中是指学生网络最终输出每个类别的概率,大小为B×C,C是指类别的总个数;/>是学生网络第i个类别的logits输出;/>是学生网络是当前类别的logits输出,logits输出指的是网络中最后一层softmax函数激活前的输出,/>为参数。
4.根据权利要求3所述的基于渐进式联想学习的自蒸馏方法,其特征在于,引入改进的交叉熵损失函数,同时考虑原始样本和生成样本/>的信息,具体是,最小化关联交叉熵损失函数LMCE,记为/>,公式如下:
;
这个函数是两个标准交叉熵损失函数LCE的加权和,其中和/>分别是原始样本/>和生成样本/>的对应标签,权重/>用于调节这两个损失函数的相对贡献,/>由参数/>确定的网络模型输入关联样本/>得到每个类别的概率/>,再与原始样本/>的对应标签/>做交叉熵损失,/>由参数/>确定的网络模型输入关联样本得到每个类别的概率/>,再与生成样本/>的对应标签/>做交叉熵损失。
5.根据权利要求4所述的基于渐进式联想学习的自蒸馏方法,其特征在于,总损失LALSD具体公式为:
;
在这个公式中,λ是关联样本交叉熵损失的权重,而β是关联样本蒸馏损失的权重;其中,针对关联样本,设计关联蒸馏损失函数LCls,衡量类间相似性,记为,表示如下:
;
这个函数是两个KL散度项的加权和,其中、/>和分别是以参数为/>的网络模型输入数据为原始样本/>的输出概率、以参数为/>的网络模型输入数据为生成样本/>为的输出概率和以参数/>的网络模型输入数据为关联样本/>为的输出概率,大小均为B×C,/>是采用了KL散度来衡量原始样本/>和关联样本/>之间的概率分布差异,是采用了KL 散度来衡量生成样本/>和关联样本/>之间的概率分布差异,/>是由贝塔分布生成的图像混合比例系数,用于平衡这两个项的贡献,而/>是参数/>的固定副本,在训练过程中学生网络和教师网络权重共享。
6.基于渐进式联想学习的自蒸馏系统,其特征在于,用于实现如权利要求1-5任一项所述的基于渐进式联想学习的自蒸馏方法,所述系统包括数据预处理模块、构建关联样本模块、由教师网络和学生网络组成的网络模型、损失计算模块,
所述数据预处理模块用于获取输入的不同类型的图像数据,并进行预处理生成训练数据集样本,包括原始样本及其对应的标签;
所述构建关联样本模块用于将原始样本和在基于渐进式联想学习方法生成的生成样本通过样本混合方法结合,将两个样本以像素级别混合,构建关联样本;
所述网络模型包括教师网络特征提取模块和学生网络特征提取模块,通过学生网络特征提取模块提取关联样本中的类别特征和类间关系,输出概率分布,用于模型自蒸馏训练;通过教师网络特征提取模块提取原始样本中的特征,输出概率分布,用来指导学生网络的自我学习过程;
所述损失计算模块用于计算关联交叉熵损失和关联蒸馏损失的加权和。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410288255.6A CN117892841B (zh) | 2024-03-14 | 2024-03-14 | 基于渐进式联想学习的自蒸馏方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410288255.6A CN117892841B (zh) | 2024-03-14 | 2024-03-14 | 基于渐进式联想学习的自蒸馏方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117892841A true CN117892841A (zh) | 2024-04-16 |
CN117892841B CN117892841B (zh) | 2024-05-31 |
Family
ID=90645052
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410288255.6A Active CN117892841B (zh) | 2024-03-14 | 2024-03-14 | 基于渐进式联想学习的自蒸馏方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117892841B (zh) |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
KR102225579B1 (ko) * | 2020-05-14 | 2021-03-10 | 아주대학교산학협력단 | 학습성능이 향상된 지식 증류법 기반 의미론적 영상 분할 방법 |
CN114049513A (zh) * | 2021-09-24 | 2022-02-15 | 中国科学院信息工程研究所 | 一种基于多学生讨论的知识蒸馏方法和系统 |
CN115995018A (zh) * | 2022-12-09 | 2023-04-21 | 厦门大学 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
US20230214719A1 (en) * | 2021-12-31 | 2023-07-06 | Research & Business Foundation Sungkyunkwan University | Method for performing continual learning using representation learning and apparatus thereof |
US20230222353A1 (en) * | 2020-09-09 | 2023-07-13 | Vasileios LIOUTAS | Method and system for training a neural network model using adversarial learning and knowledge distillation |
CN116913504A (zh) * | 2023-07-13 | 2023-10-20 | 重庆理工大学 | 用于单导联心律失常诊断的自监督多视图知识蒸馏方法 |
CN116994015A (zh) * | 2022-04-21 | 2023-11-03 | 北京工业大学 | 一种基于递进式知识传递的自蒸馏分类方法 |
KR20230156461A (ko) * | 2022-05-06 | 2023-11-14 | 아주대학교산학협력단 | 지식 증류를 활용한 그룹 기반의 학습을 수행하는 전자장치 및 방법 |
CN117494780A (zh) * | 2023-08-30 | 2024-02-02 | 中国科学院信息工程研究所 | 一种混合学习中知识蒸馏的学生网络训练方法 |
-
2024
- 2024-03-14 CN CN202410288255.6A patent/CN117892841B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
KR102225579B1 (ko) * | 2020-05-14 | 2021-03-10 | 아주대학교산학협력단 | 학습성능이 향상된 지식 증류법 기반 의미론적 영상 분할 방법 |
US20230222353A1 (en) * | 2020-09-09 | 2023-07-13 | Vasileios LIOUTAS | Method and system for training a neural network model using adversarial learning and knowledge distillation |
CN114049513A (zh) * | 2021-09-24 | 2022-02-15 | 中国科学院信息工程研究所 | 一种基于多学生讨论的知识蒸馏方法和系统 |
US20230214719A1 (en) * | 2021-12-31 | 2023-07-06 | Research & Business Foundation Sungkyunkwan University | Method for performing continual learning using representation learning and apparatus thereof |
CN116994015A (zh) * | 2022-04-21 | 2023-11-03 | 北京工业大学 | 一种基于递进式知识传递的自蒸馏分类方法 |
KR20230156461A (ko) * | 2022-05-06 | 2023-11-14 | 아주대학교산학협력단 | 지식 증류를 활용한 그룹 기반의 학습을 수행하는 전자장치 및 방법 |
CN115995018A (zh) * | 2022-12-09 | 2023-04-21 | 厦门大学 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
CN116913504A (zh) * | 2023-07-13 | 2023-10-20 | 重庆理工大学 | 用于单导联心律失常诊断的自监督多视图知识蒸馏方法 |
CN117494780A (zh) * | 2023-08-30 | 2024-02-02 | 中国科学院信息工程研究所 | 一种混合学习中知识蒸馏的学生网络训练方法 |
Non-Patent Citations (4)
Title |
---|
HAORAN ZHAO: "Knowledge Distillation via Instance-level Sequence Learning", ARXIV, 30 June 2021 (2021-06-30) * |
凌弘毅;: "基于知识蒸馏方法的行人属性识别研究", 计算机应用与软件, no. 10, 12 October 2018 (2018-10-12) * |
葛仕明;赵胜伟;刘文瑜;李晨钰;: "基于深度特征蒸馏的人脸识别", 北京交通大学学报, no. 06, 15 December 2017 (2017-12-15) * |
赵胜伟;葛仕明;叶奇挺;罗朝;李强;: "基于增强监督知识蒸馏的交通标识分类", 中国科技论文, no. 20, 23 October 2017 (2017-10-23) * |
Also Published As
Publication number | Publication date |
---|---|
CN117892841B (zh) | 2024-05-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109086658B (zh) | 一种基于生成对抗网络的传感器数据生成方法与系统 | |
CN110377710B (zh) | 一种基于多模态融合的视觉问答融合增强方法 | |
Chen et al. | Relation attention for temporal action localization | |
CN110263912B (zh) | 一种基于多目标关联深度推理的图像问答方法 | |
Cheng et al. | Facial expression recognition method based on improved VGG convolutional neural network | |
CN111414461B (zh) | 一种融合知识库与用户建模的智能问答方法及系统 | |
CN110609891A (zh) | 一种基于上下文感知图神经网络的视觉对话生成方法 | |
CN109977893B (zh) | 基于层次显著性通道学习的深度多任务行人再识别方法 | |
Zhao et al. | Disentangled representation learning and residual GAN for age-invariant face verification | |
CN113657115B (zh) | 一种基于讽刺识别和细粒度特征融合的多模态蒙古文情感分析方法 | |
CN113255602A (zh) | 基于多模态数据的动态手势识别方法 | |
CN112883931A (zh) | 基于长短期记忆网络的实时真假运动判断方法 | |
Gogate et al. | Real time emotion recognition and gender classification | |
CN113988079A (zh) | 一种面向低数据的动态增强多跳文本阅读识别处理方法 | |
CN115408603A (zh) | 一种基于多头自注意力机制的在线问答社区专家推荐方法 | |
CN115827954A (zh) | 动态加权的跨模态融合网络检索方法、系统、电子设备 | |
CN116561614A (zh) | 一种基于元学习的小样本数据处理系统 | |
CN114417975A (zh) | 基于深度pu学习与类别先验估计的数据分类方法及系统 | |
Jiang et al. | Cross-level reinforced attention network for person re-identification | |
Ling et al. | A facial expression recognition system for smart learning based on YOLO and vision transformer | |
CN117892841B (zh) | 基于渐进式联想学习的自蒸馏方法及系统 | |
CN113626537B (zh) | 一种面向知识图谱构建的实体关系抽取方法及系统 | |
Wu et al. | Boundaryface: A mining framework with noise label self-correction for face recognition | |
CN115439791A (zh) | 跨域视频动作识别方法、装置、设备和计算机可存储介质 | |
CN114357166A (zh) | 一种基于深度学习的文本分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |