CN115456166A - 一种无源域数据的神经网络分类模型知识蒸馏方法 - Google Patents
一种无源域数据的神经网络分类模型知识蒸馏方法 Download PDFInfo
- Publication number
- CN115456166A CN115456166A CN202211040839.9A CN202211040839A CN115456166A CN 115456166 A CN115456166 A CN 115456166A CN 202211040839 A CN202211040839 A CN 202211040839A CN 115456166 A CN115456166 A CN 115456166A
- Authority
- CN
- China
- Prior art keywords
- model
- samples
- teacher model
- teacher
- knowledge distillation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 48
- 238000000034 method Methods 0.000 title claims abstract description 38
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 24
- 238000013145 classification model Methods 0.000 title claims abstract description 21
- 239000011159 matrix material Substances 0.000 claims abstract description 27
- 238000009826 distribution Methods 0.000 claims abstract description 19
- 238000004821 distillation Methods 0.000 claims abstract description 15
- 230000009193 crawling Effects 0.000 claims abstract description 9
- 238000007781 pre-processing Methods 0.000 claims abstract description 8
- 239000013598 vector Substances 0.000 claims description 4
- 230000004913 activation Effects 0.000 claims description 3
- 238000004140 cleaning Methods 0.000 claims description 3
- 238000012549 training Methods 0.000 abstract description 39
- 238000012217 deletion Methods 0.000 abstract description 2
- 230000037430 deletion Effects 0.000 abstract description 2
- 238000012163 sequencing technique Methods 0.000 abstract description 2
- 238000004590 computer program Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 7
- 230000006870 function Effects 0.000 description 5
- 238000012986 modification Methods 0.000 description 5
- 230000004048 modification Effects 0.000 description 5
- 238000003062 neural network model Methods 0.000 description 5
- 230000000694 effects Effects 0.000 description 4
- 238000012545 processing Methods 0.000 description 4
- 238000003860 storage Methods 0.000 description 4
- 235000011034 Rubus glaucus Nutrition 0.000 description 3
- 235000009122 Rubus idaeus Nutrition 0.000 description 3
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 240000007651 Rubus glaucus Species 0.000 description 2
- 238000005520 cutting process Methods 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 241000282472 Canis lupus familiaris Species 0.000 description 1
- 241000282994 Cervidae Species 0.000 description 1
- 241000196324 Embryophyta Species 0.000 description 1
- 241000282326 Felis catus Species 0.000 description 1
- 241001465754 Metazoa Species 0.000 description 1
- 244000235659 Rubus idaeus Species 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 210000004556 brain Anatomy 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000007418 data mining Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000003704 image resize Methods 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012423 maintenance Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000002184 metal Substances 0.000 description 1
- 230000001617 migratory effect Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种无源域数据的神经网络分类模型知识蒸馏方法,爬取数据,对候选数据集的所有样本预处理,作为候选样本;基于教师模型的分类层权重构建类别相似度矩阵,基于教师模型计算所有候选样本的logits值并进行SoftMaxt操作,计算所有候选样本与教师模型领域分布的差异度;根据差异度将所有样本从小到大排序,选取前M个样本作为知识蒸馏数据集,对教师模型知识蒸馏得到学生模型。本发明避免不合理的高置信度领域外样本被错误选中,不完全需要源域数据,具有更高的分类准确率与更广的模型适用性,更快的删选速度,能够在资源较少的情况下对预训练模型进行知识蒸馏;基于知识蒸馏的各种下游任务都能够解决。
Description
技术领域
本发明涉及计算;推算或计数的技术领域,特别涉及一种机器学习和数据挖掘领域的无源域数据的神经网络分类模型知识蒸馏方法。
背景技术
近年来,深度神经网络在视觉分类领域取得了非常不错的应用效果,被广泛地运用在各个行业。神经网络表现出卓越性能的一个前提是测试数据与训练数据服从独立同分布,然而,在现实世界中部署训练好的模型在各种平台上时,随着业务的发展,我们通常要对该模型进行升级维护,例如需要将该模型进行压缩以便部署到边缘设备,或是需要增加模型的分类能力(增量学习)等。高性能的网络往往伴随着庞大的网络架构,这意味着要想获得高性能就必须得付出高计算、高内存占用的代价,这使得大多数庞大的高性能网络无法部署在内存资源较低或者要求响应速度快的终端设备,例如交通摄像头和自动驾驶大脑中,而对模型的升级维护也面临较大的困难。
近年来,在该技术问题上取得了巨大的进展,尤其是基于知识蒸馏的模型压缩。当我们可以直接访问源域训练数据集时,现有的很多知识蒸馏方法对于训练紧凑的深度模型非常有效。然而,知识蒸馏的效果基于源域数据可用的前提,目前对于数据隐私与安全问题的重视,使得模型训练完后其训练集应该是保密不可泄露的,这就导致了很多基于大模型的下游任务无法进行,而知识蒸馏的模型压缩也正遇到了该困难,当然,还有源域数据集过大不便运输存储等问题,使得无法访问源域数据;现实中往往只能得到一个训练好的深度神经网络模型。
现有技术中,基于最大置信度或者最小熵的方法来衡量一个样本是否属于给定的模型的学习领域分布,神经网络模型在实际运行时,输入与训练数据集分布不一致的数据时很有可能使得分类模型失效,因为它有可能对领域外的样本给出很高置信度的预测,这将导致通过删选的仿领域样本中存在较多领域外样本。
当给定一个预训练的大型神经网络模型,如果可以访问模型源领域数据,将很容易通过对该模型进行知识蒸馏达到模型知识迁移的目的,但现实中往往就仅有一个预训练模型,无法有效访问模型源领域数据。
发明内容
本发明解决了现有技术中存在的问题,提供了一种无源域数据的神经网络分类模型知识蒸馏方法,通过包括但不限于预训练模型选择符合模型训练集领域分布的图像数据进行神经网络知识蒸馏,考虑样本预测值内部类别信息以避免不合理的高置信度的领域外样本被错误选中,以删选后的接近模型领域内分布的数据对预训练模型进行知识蒸馏从而达到迁移预训练模型知识的目的;提高预训练模型在下游任务的可迁移性或预训练模型后续维护的扩展性,可以应用于基于知识蒸馏的模型压缩、增量学习、迁移学习等诸多深度学习领域。
本发明所采用的技术方案是,一种无源域数据的神经网络分类模型知识蒸馏方法,所述方法包括以下步骤:
步骤1:爬取数据,获得候选数据集;
步骤2:对候选数据集的所有样本进行预处理,作为候选样本;
步骤3:基于教师模型的分类层权重构建类别相似度矩阵C;
步骤4:基于教师模型计算所有候选样本的logits值并进行SoftMaxt操作,最后计算所有候选样本与教师模型领域分布的差异度;
步骤5:根据差异度将所有样本从小到大进行排序;
步骤6:选取前M个样本作为知识蒸馏数据集,M≥100;
步骤7:基于知识蒸馏数据集对教师模型进行知识蒸馏,得到学生模型。
优选地,所述步骤1中,若已知教师模型所分类别,则直接爬取类别关键词对应的数据,否则,判断已知教师模型是否为对某个超类进行分类,若是,则将超类名称作为关键词爬取相关数据,否则爬取任意类别的相关类型数据或者使用已公开的数据集。
优选地,所述预处理包括数据清洗、将输入的候选样本调整至统一格式。
优选地,所述步骤3包括以下步骤:
步骤3.1:通过教师模型的分类层权重计算各个类别的相关性,得到式(1),
其中,C(i,j)表示类别i与类别j之间的相似度,wi和wj分别为类别i与类别j的权重,i≠j;
步骤3.2:重复步骤3.1直至计算所有类别间的相似度,得到矩阵C′;
步骤3.3:将矩阵C′归一化处理,使得矩阵每行之和为1,得到最终的类别相似度矩阵C。
优选地,所述步骤4包括以下步骤:
步骤4.1:对于任一输入的候选样本x,以教师模型对其进行预测,得到输出的logits值T(x,θT),其中T表示教师模型,θT表示教师模型的参数;
优选地,所述步骤4.1中,对于任一候选样本x,预测为类别k的logits值为式(3),
Ti(x,θT)=[g(x),1]Twk (3)
其中,g(x)表示教师模型的倒数第二层激活的特征向量,该层与“1”连接以解释偏差,wk是教师模型第k个类的权重与偏置。
优选地,所述步骤4中,差异度为式(4),
其中,XW为候选样本数据集,mc(.)表示的是预测样本的最大置信度,Xsub为选择的知识蒸馏数据集,Ck为类别相似度矩阵的第k行,DKL为KL散度,S表示SoftMax操作。
优选地,将选中的领域数据集输入到教师模型和学生模型,计算蒸馏损失LkD,固定教师模型参数,使用LkD损失更新学生模型,蒸馏损失LkD满足式(5),
L>D=αLsoft+βLhard (5)
其中,α,β≥0且α+β=1,Lsoft是以软标签计算教师模型与学生模型的损失,Lhard是学生模型在真实标签下的交叉熵损失。
本发明涉及一种无源域数据的神经网络分类模型知识蒸馏方法,爬取数据,获得候选数据集后,对候选数据集的所有样本进行预处理,作为候选样本;基于教师模型的分类层权重构建类别相似度矩阵C,基于教师模型计算所有候选样本的logits值并进行SoftMaxt操作,计算所有候选样本与教师模型领域分布的差异度;根据差异度将所有样本从小到大进行排序,选取前M个样本作为知识蒸馏数据集,基于知识蒸馏数据集对教师模型进行知识蒸馏,得到学生模型。
本发明的有益效果在于:
(1)充分利用教师分类器模型所提供的类别之间的信息,避免不合理的高置信度领域外样本被错误选中为满足预训练模型领域分布;
(2)即使某些神经网络模型仅能够提供教师模型分类器层的权重,也可以通过本方法删选候选数据来进行知识蒸馏,并不完全需要源域数据;
(3)相比于传统采用最大置信度或者最小熵删选候选数据集的方法,本发明具有更高的分类准确率与更广的模型适用性;
(4)相比于其他无源域数据知识蒸馏技术,本发明具有更快的删选速度,且只需要更少的计算资源,能够在资源较少的情况下,对预训练模型进行知识蒸馏;
(5)有效性在cifar10与cifar100数据集上得到验证,其知识蒸馏学生模型在cifar10验证集上的准确率到达94.68%,在cifar100验证集上的准确率到达72.33%;基于知识蒸馏的各种下游任务都能够解决,包括基于知识蒸馏的增量学习、迁移学习等。
附图说明
图1为本发明的无源域数据知识蒸馏示意图;
图2为本发明的方法流程图。
具体实施方式
下面结合实施例对本发明做进一步的详细描述,但本发明的保护范围并不限于此。
本发明涉及一种无源域数据的神经网络分类模型知识蒸馏方法,通过已训练好的网络模型从互联网获取与预训练模型(教师模型)训练集领域分布差异小于阈值的数据集的方法,并通过知识蒸馏方法来迁移预训练模型的知识。
所述方法包括以下步骤:
步骤1:爬取数据,获得候选数据集;
所述步骤1中,若已知教师模型所分类别,则直接爬取类别关键词对应的数据,否则,判断已知教师模型是否为对某个超类进行分类,若是,则将超类名称作为关键词爬取相关数据,否则爬取任意类别的相关类型数据或者使用已公开的数据集。
本发明中,为了获取与教师模型源域接近的数据集,首先需要收集大量数据,此处以图像数据为例,一般来说,从互联网上爬取数据或者使用公共大规模数据集作为候选数据集。以在谷歌爬取数据为例,如果已知教师模型分类器具体类别,则直接通过谷歌爬取类别关键词的图像,例如数据集CIFAR-10中的汽车、猫、狗等;不清楚具体类别但已知教师模型是对某个超类进行分类的,则将超类名称作为关键词通过谷歌爬取相关关键词图像,例如动物、数字、建筑、花草等;当仅有教师模型、除此之外没有任何的分类器类别信息时,则直接通过谷歌爬取任意类别的图像或者直接使用ImageNet数据集作为候选数据集。
本发明中,候选数据集的大小由教师模型大小和类别数量等决定,往往候选数据集越大效果越好,但是所需计算资源也将越大,因此可根据计算资源和所需的效果决定候选数据量的大小。一般而言,候选数据集的数据量不小于教师模型训练集的20倍,若无法得知教师模型训练集大小,至少将候选数据设置为100万数据量大小。
步骤2:对候选数据集的所有样本进行预处理,作为候选样本;
所述预处理包括数据清洗、将输入的候选样本调整至统一格式。
本发明中,数据清洗为清洗掉有明显格式错误、数据缺失的数据;以样本为图像为例,将输入的候选样本调整至统一格式是指调整图像至同一尺寸,使得每张图像的长和宽相等,并将图片格式统一转化为png或者jpeg等格式。
步骤3:基于教师模型的分类层权重构建类别相似度矩阵C;
在已经训练好的预训练模型中,最后一层是具有softmax非线性的完全连接层,经过教师编码器编码得到的低维特征将由分类器的权重决定将其预测为各类的概率,分类器某个类别的权重其实就是预训练模型分类器对该类样本的低维特征的建模。
所述步骤3包括以下步骤:
步骤3.1:通过教师模型的分类层权重计算各个类别的相关性,得到式(1),
其中,C(i,j)表示类别i与类别j之间的相似度,wi和wj分别为类别i与类别j的权重,i≠j;
步骤3.2:重复步骤3.1直至计算所有类别间的相似度,得到矩阵C′;
步骤3.3:将矩阵C′归一化处理,使得矩阵每行之和为1,得到最终的类别相似度矩阵C。
本发明中,决定任一样本x被预测为第k类的关键是分类器前的特征与第k类的模板wk(分类器第k个类的权重与偏置)的对齐程度;当编码特征与权重一致时,则预测为第k类的softmax值最大,与此相反,当编码特征与权重相反时则预测为第k类的softmax值最小;基于此,很容易通过分类器权重计算各个类别的相关性。
本发明中,通过计算各个类别之间的相似度得到类别相似度矩阵,除了相似度最高的是类别本身之间之外,举例来说,卡车与汽车的相似度显然比汽车与鸟或者鹿的相似度高。
本发明中,为了能够体现预测向量的类别之间的相似度,需要对相似度矩阵进行Softmax归一化,并给定一个温度使得归一化后的概率分布更加平滑,将类别相似性矩阵按照行进行归一化,使得每一行概率的总和为1。
步骤4:基于教师模型计算所有候选样本的logits值并进行SoftMaxt操作,最后计算所有候选样本与教师模型领域分布的差异度;
所述步骤4包括以下步骤:
步骤4.1:对于任一输入的候选样本x,以教师模型对其进行预测,得到输出的logits值T(x,θT),其中T表示教师模型,θT表示教师模型的参数;
所述步骤4.1中,对于任一候选样本x,预测为类别k的logits值为式(3),
Ti(x,θT)=[g(x),1]Tw> (3)
其中,g(x)表示教师模型的倒数第二层激活的特征向量,该层与“1”连接以解释偏差,wk是教师模型第k个类的权重与偏置。
所述步骤4中,差异度为式(4),
其中,XW为候选样本数据集,mc(.)表示的是预测样本的最大置信度,Xsub为选择的知识蒸馏数据集,Ck为类别相似度矩阵的第k行,DKL为KL散度,S表示SoftMax操作。
步骤5:根据差异度将所有样本从小到大进行排序;差异度越小则说明该样本越符合预训练模型源领域分布,差异度越大则说明该样本更偏离预训练模型领域分布。
步骤6:选取前M个样本作为知识蒸馏数据集,M≥100;从排序好的样本中挑选前面M个样本作为选中的接近预训练模型领域分布的数据集,M越小则该数据越接近预训练模型领域分布,M越大则其数据分布与预训练模型领域分布偏差越大;M作为超参数,将视下游任务而决定。
步骤7:基于知识蒸馏数据集对教师模型进行知识蒸馏,得到学生模型。
将选中的领域数据集输入到教师模型和学生模型,计算蒸馏损失LkD,固定教师模型参数,使用LkD损失更新学生模型,蒸馏损失LkD满足式(5),
LkD=αLsofZ+βLhard (5)
其中,α,β≥0且α+β=1,LsofZ是以软标签计算教师模型与学生模型的损失,Lhard是学生模型在真实标签下的交叉熵损失。
本发明中,α一般设置为0.7,β一般设置为0.3,T通常设置为3。
本发明中,给出一具体实施例:
步骤1:令仅有在CIFAR-10上的预训练模型,除此之外没有任何的分类器类别信息,故直接使用ImageNet2012数据集作为候选数据集;候选数据集包含1281167张格式为jpeg的RGB图片,包含1000个类别;
步骤2:删除由于收集过程存在乱码而得到的明显错误数据、删除缺少格式错误数据,将非jpeg格式图片格式统一改为jpeg格式;图像resize后,将CIFAR-10数据集中图片一致调整为32×32像素的彩色图片,或根据需要将所有图片裁剪到相同大小的尺寸,用于预训练模型的预测;
可以通过使用pytorch、tensorflow等深度学习框架对数据集执行以上操作,pytorch中的torchvision包中已经提供对图片进行裁剪与一定数据增强即图片归一化等操作,完成后,通过提供的框架接口,将图片输入载入DataLoader中,供后续训练使用;
步骤3:如果已经提供了教师模型可以直接使用,也可以从互联网上下载预训练的教师模型;选取resnet34网络模型作为教师模型,根据交叉熵损失函数,使用pytorch提供的随机梯度下降优化器对教师模型进行训练,通过200次迭代,以最后一次迭代的模型作为教师模型;
对于教师模型,保留其分类层的权重矩阵W,维度为512×10,512为特征的长度,10表示教师模型分类器的类别个数;计算得到类别i和类别j之间的的相似度,wk是分类器第k个类的权重与偏置,得到的类别相似度矩阵的维度为10×10,将权重相关性矩阵映射到(0,1)的概率空间上,得到一个每行和为1的概率类别相关性矩阵C,维度为10×10;
步骤4:将候选数据就ImageNet2012依次输入到教师模型中进行预测,得到logits输出,维度为128×10,128为batch size的大小,10表示类别个数,进行SoftMax操作,得到概率预测值,维度为128×10;
计算各个样本与预训练模型领域分布的差异度;
步骤5:通过python内置的sort()排序方法对所有样本按照差异度从小到大排序,保存在txt文本文件中,在txt文本文件中的数据结构为字典类型,即每一行代表一个样本,第一列表示该图片的索引位置,第二列表示其计算的差异度;
步骤6:从排序好的样本中挑选前面M个样本作为选中的接近预训练模型领域分布的数据集;本实施例中CIFAR-10的预训练模型的训练集为50000张图片,设置M为100000;通过python代码读取步骤6中保存的前100000个图片索引读取选中的图片作为知识蒸馏数据集;
步骤7:使用源域数据集训练ResNet34架构的神经网络作为教师模型,使用随机参数化的ResNet18架构的神经网络作为学生网络,教师网络与学生网络都使用随机梯度Nesterov Accelerated Gradient(NAG)优化,weight decay和momentum分别设置为0.0005和0.9,共计训练200个epoch;超参数温度T设置为3,学习率初始值为0.1,随着epoch余弦衰减为0;α设置为0.7,β设置为0.3;
对于CIFAR-10,从ImageNet2012数据集中各自挑选了100000张图片作为蒸馏数据集,在知识蒸馏之前,我们将挑选的数据依次进行resize成32*32大小图片,使用4像素填充图片边缘并裁剪成32*32的图片,随机水平翻转,标准化等一系列处理;最后到训练收敛后,保存训练好的学生模型即可。
为了实现上述内容,在得到学生模型后,将得到的学生模型部署到嵌入式设备实现应用。
举例来说,树莓派是一种运行Raspbian OS(官方系统)的微型电脑,具有易于二次开发、集成度高等特点,因此将其作为部署神经网络模型演示实际图像分类任务的硬件平台。实施中,采用的树莓派型号为Raspberry Pi 4 Model B,处理器为博通BCM2711(四核Cortex-A72),主频1.5GHz,内存大小为2GB,外接32GB闪迪MicroSD存储卡;树莓派端的深度学习框架为针对arm7l指令集版本的Pytorch,版本号为1.6。此微型电脑执行中实现上述无源域数据的神经网络分类模型知识蒸馏方法,从而解决现有技术中无法有效访问模型源领域数据导致无法进行知识蒸馏的问题。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
尽管已描述了本发明的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本发明范围的所有变更和修改。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
Claims (10)
1.一种无源域数据的神经网络分类模型知识蒸馏方法,其特征在于:所述方法包括以下步骤:
步骤1:爬取数据,获得候选数据集;
步骤2:对候选数据集的所有样本进行预处理,作为候选样本;
步骤3:基于教师模型的分类层权重构建类别相似度矩阵C;
步骤4:基于教师模型计算所有候选样本的logits值并进行SoftMaxt操作,最后计算所有候选样本与教师模型领域分布的差异度;
步骤5:根据差异度将所有样本从小到大进行排序;
步骤6:选取前M个样本作为知识蒸馏数据集,M≥100;
步骤7:基于知识蒸馏数据集对教师模型进行知识蒸馏,得到学生模型。
2.根据权利要求1所述的一种无源域数据的神经网络分类模型知识蒸馏方法,其特征在于:所述步骤1中,若已知教师模型所分类别,则直接爬取类别关键词对应的数据,否则,判断已知教师模型是否为对某个超类进行分类,若是,则将超类名称作为关键词爬取相关数据,否则爬取任意类别的相关类型数据或者使用已公开的数据集。
3.根据权利要求1所述的一种无源域数据的神经网络分类模型知识蒸馏方法,其特征在于:所述预处理包括数据清洗、将输入的候选样本调整至统一格式。
6.根据权利要求5所述的一种无源域数据的神经网络分类模型知识蒸馏方法,其特征在于:所述步骤4.1中,对于任一候选样本x,预测为类别k的logits值为式(3),
Ti(x,θT)=[g(x),1]Twk (3)
其中,g(x)表示教师模型的倒数第二层激活的特征向量,该层与“1”连接以解释偏差,wk是教师模型第k个类的权重与偏置。
8.根据权利要求1所述的一种无源域数据的神经网络分类模型知识蒸馏方法,其特征在于:将选中的领域数据集输入到教师模型和学生模型,计算蒸馏损失LkD,固定教师模型参数,使用LkD损失更新学生模型,蒸馏损失LkD满足式(5),
LkD=αLsoft+βLhard (5)
其中,α,β≥0且α+β=1,Lsoft是以软标签计算教师模型与学生模型的损失,Lhard是学生模型在真实标签下的交叉熵损失。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211040839.9A CN115456166A (zh) | 2022-08-29 | 2022-08-29 | 一种无源域数据的神经网络分类模型知识蒸馏方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211040839.9A CN115456166A (zh) | 2022-08-29 | 2022-08-29 | 一种无源域数据的神经网络分类模型知识蒸馏方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115456166A true CN115456166A (zh) | 2022-12-09 |
Family
ID=84301327
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211040839.9A Pending CN115456166A (zh) | 2022-08-29 | 2022-08-29 | 一种无源域数据的神经网络分类模型知识蒸馏方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115456166A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116543237A (zh) * | 2023-06-27 | 2023-08-04 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 无源域无监督域适应的图像分类方法、系统、设备及介质 |
CN116861302A (zh) * | 2023-09-05 | 2023-10-10 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
-
2022
- 2022-08-29 CN CN202211040839.9A patent/CN115456166A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116543237A (zh) * | 2023-06-27 | 2023-08-04 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 无源域无监督域适应的图像分类方法、系统、设备及介质 |
CN116543237B (zh) * | 2023-06-27 | 2023-11-28 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 无源域无监督域适应的图像分类方法、系统、设备及介质 |
CN116861302A (zh) * | 2023-09-05 | 2023-10-10 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
CN116861302B (zh) * | 2023-09-05 | 2024-01-23 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108334605B (zh) | 文本分类方法、装置、计算机设备及存储介质 | |
CN111738436B (zh) | 一种模型蒸馏方法、装置、电子设备及存储介质 | |
CN110362723B (zh) | 一种题目特征表示方法、装置及存储介质 | |
CN115456166A (zh) | 一种无源域数据的神经网络分类模型知识蒸馏方法 | |
CN114912612A (zh) | 鸟类识别方法、装置、计算机设备及存储介质 | |
CN110188195B (zh) | 一种基于深度学习的文本意图识别方法、装置及设备 | |
CN110929524A (zh) | 数据筛选方法、装置、设备及计算机可读存储介质 | |
CN114491039B (zh) | 基于梯度改进的元学习少样本文本分类方法 | |
CN110310012B (zh) | 数据分析方法、装置、设备及计算机可读存储介质 | |
US20200218932A1 (en) | Method and system for classification of data | |
CN115034315B (zh) | 基于人工智能的业务处理方法、装置、计算机设备及介质 | |
CN110232128A (zh) | 题目文本分类方法及装置 | |
CN114528835A (zh) | 基于区间判别的半监督专业术语抽取方法、介质及设备 | |
CN113742733A (zh) | 阅读理解漏洞事件触发词抽取和漏洞类型识别方法及装置 | |
CN115858388A (zh) | 基于变异模型映射图的测试用例优先级排序方法和装置 | |
CN114818707A (zh) | 一种基于知识图谱的自动驾驶决策方法和系统 | |
US20230063686A1 (en) | Fine-grained stochastic neural architecture search | |
CN115952438A (zh) | 社交平台用户属性预测方法、系统、移动设备及存储介质 | |
CN115730221A (zh) | 基于溯因推理的虚假新闻识别方法、装置、设备及介质 | |
CN117011577A (zh) | 图像分类方法、装置、计算机设备和存储介质 | |
CN108920492A (zh) | 一种网页分类方法、系统、终端及存储介质 | |
CN111400413B (zh) | 一种确定知识库中知识点类目的方法及系统 | |
CN113821571A (zh) | 基于bert和改进pcnn的食品安全关系抽取方法 | |
CN113297376A (zh) | 基于元学习的法律案件风险点识别方法及系统 | |
CN112698977A (zh) | 服务器故障定位方法方法、装置、设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |