CN110097084A - 通过投射特征训练多任务学生网络的知识融合方法 - Google Patents
通过投射特征训练多任务学生网络的知识融合方法 Download PDFInfo
- Publication number
- CN110097084A CN110097084A CN201910264911.8A CN201910264911A CN110097084A CN 110097084 A CN110097084 A CN 110097084A CN 201910264911 A CN201910264911 A CN 201910264911A CN 110097084 A CN110097084 A CN 110097084A
- Authority
- CN
- China
- Prior art keywords
- network
- teacher
- targetnet
- block
- obtains
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Abstract
通过投射特征训练多任务学生网络的知识融合方法,由以下步骤组成:首先初始化TargetNet(目标学生网络)的结构与教师网络相同,通过通道编码,将TargetNet中融合的特征图投射为对应任务的特征图;逐个训练TargetNet中与教师网络对应的block,得到融合的特征图;确定TargetNet中不同任务开始分支的各自位置;将教师网络中的对应block加入学生网络,作为不同任务的分支,并移除TargetNet中末端的block,得到TargetNet最终结构;最后调优学生网络。本方法能够使用无标签数据集,融合多个不同任务的教师网络,得到性能优越的轻量级学生网络。
Description
技术领域
本发明涉及学生网络的知识融合方法。
背景技术
场景语义分割是对图像进行像素级别的标签分类。目前场景语义分割的主流方法为使用卷积神经网络。现有的深度网络模型主要有PSPNet,RefineNet,FinerNet,SegNet。其中PSPNet使用金字塔池化操作获得多尺度特征;RefineNet使用多通路的网络结构,融合低层特征与高层语义特征;FinerNet通过级联一系列的网络,得到不同粒度的语义图;SegNet则使用编码器-解码器结构。其中SegNet网络鲁棒性强,性能先进,故使用该网络为本专利技术的基本网络结构。
早期深度估计方法使用手工定义的特征和图模型,如将深度问题转化为马尔可夫条件随机场问题,这些方法性能不佳。目前方法主要采用卷积神经网络,自动学习不同特征。如使用多尺度深度网络,预测粗粒度深度,随后细化。另有其他方法将深度估计问题与场景语义分割、表面法向量预测任务结合,进行多任务预测。深度估计问题与场景语义分割问题的主要区别在于,前者的输出为连续的正数,后者输出为离散的标签。本专利技术中,将深度估计问题转化为分类问题,将深度划分为N个范围,预测落在各个范围中心部分的概率,计算得到连续的深度值,得到深度估计教师网络。
表面法线预测是对图像进行逐像素的表面法向量预测。表面法线经常在计算机图形学中用于计算光照。现有的法线预测神经网络模型中使用RGB图像或RGB-D图像作为输入。
知识蒸馏技术能够学习事先训练好的深度网络教师模型,通过训练软目标得到一个精简的低复杂度学生网络。该学生网络能够达到与教师网络相近,甚至更高的性能。知识蒸馏技术能够有效利用现有的深度网络模型,一定程度上减轻深度学习领域中标签数据不足的问题。该技术应用于计算机视觉领域的分类问题时主要有两种方式:一种使用单个教师网络,或一组同类别分类的教师网络,得到低复杂度学生网络;另一种通过学习多个分类不同类别的教师网络,得到能够处理复杂分类任务的学生网络。该技术还可应用于目标检测、深度估计以及自然语言处理的序列模型等,可以达到超越教师网络的性能。目前该技术目前的局限性在于只能学习单个教师网络,或一组同任务类型的教师网络,得到的学生网络无法处理多任务。
发明内容
本发明要克服传统知识蒸馏只能学习单个任务的缺陷,以及多任务视觉应用场景中计算资源不足的的不足,在使用无标签数据集、保证学生网络规模不大的基础上,提供一种通过投射融合特征,训练得到多功能高精度学生网络的办法,能够融合多个不同任务的教师网络。
本发明是一种使用针对不同任务的多个教师网络,通过投射融合特征的训练紧凑多功能学生网络的知识融合方法。本发明的通过投射特征训练多任务学生网络的知识融合方法,包括如下步骤:
1)初始化目标网络TargetNet结构,与教师网络相同;
网络使用编码器-解码器结构,编码器中block由卷积层和池化层组成,解码器中block由卷积层和上采样层构成。为TargetNet的第n个block输出的融合特征图,使用不同的通道编码将转化为不同任务域中的特征Fus和Fud。
2)逐个训练TargetNet的block,得到融合特征图
对双任务教师网络,将无标记样本输入教师网络SegNet和DepthNet,得到不同任务的特征图和将教师网络的对应block的和分别替换为Fus和Fud,替换后通过SegNet得到预测分割DepthNet得到预测深度最后对于预测值和教师网络原有预测结果S,D,建立损失函数对于多任务教师网络(加入NormNet为例),一种方法使用不同通道编码映射并通过NormNet得到预测法向量建立损失函数另一种用已训练好的分割与深度估计学生网络TargetNet-2和NormNet作为教师网络,根据步骤1.2)为学生网络TargetNet-3引入通道编码U-Channel Coding,映射为M-ChannelCoding映射为随后将根据步骤1.2)和步骤2.1)得到将限据步骤2.1)得到建立损失函数
3)确定TargetNet中不同任务的各自分支位置;
根据步骤2)中每个block的最终loss, 对不同任务选择分支点p:p=arg minnLn
4)使用教师网络中的对应分支作为学生网络的分支;
确定不同任务的分支位置后,移除TargetNet中从靠后的分支点到网络末尾之间的所有block。不同任务的分支使用对应教师网络中block,得到最终的TargetNet结构。使用步骤2)中的损失函数,利用梯度下降调优TargetNet。
本发明具有的有益效果是:与现存的只能学习单个教师网络,或多个同任务教师网络的知识蒸馏方法相比,能够融合不同任务的教师网络知识,得到轻量级、高性能、多任务的学生网络;在需要部署多任务神经网络的应用场景中,能够大幅度减少机器计算资源、内存空间的消耗,同时能够保证每个任务的高精确度。
附图说明
图1为本发明实施例中的双教师网络知识融合的神经网络示意图。
图2为本发明实施例中的双教师网络知识融合方法融合学习学生网络特征的示意图。
图3为本发明实施例中的多教师网络知识融合方法二的通道编码示意图。
图4为本发明实施例中的学生网络结果与真值、教师网络的效果对比图。
具体实施方式
下面结合附图进一步说明本发明的技术方案。
本发明的一种使用针对不同任务的多个教师网络,通过投射特征训练多任务学生网络的知识融合方法,包括如下步骤:
1.初始化目标网络TargetNet结构,与教师网络相同;
为了保证学生网络足够小,并同时能够拥有与教师网络相近的性能,设置目标网络TargetNet结构具体包括:
1.1.初始化学生网络的结构为与教师网络相同的编码器-解码器结构。编码器中的每个block由两到三个卷积核大小为3x3的卷积层和一个2x2不重叠的最大池化层构成。解码器中的每个block由两到三个卷积核大小为3x3的卷积层和一个上采样层构成。
1.2.为TargetNet的第n个block输出的融合特征图,编码了多个任务的特征。对每个教师网络,均引入通道编码。使用不同的通道编码,转化为不同任务域中的特征。对于场景分割(SegNet)和深度估计(DepthNet)双任务教师网络,通过S-ChannelCoding映射为分割任务的特征Fus,通过D-Channel Coding映射为深度估计任务的特征Fud。
2.逐个训练TargetNet的block,得到融合特征图
2.1.对于双任务教师网络,将无标记样本输入教师网络SegNet和DepthNet,分别对第n个block得到分割任务特征图和深度估计任务特征图融合和时,一种直观思路为采用欧氏距离作为损失函数。这种方法严重浪费时间与计算力。为了减少繁琐的融合过程,使用一种与教师网络关系密切的训练方法,具体包括:首先将通过步骤1.2)的通道编码分别得到Fus和Fud;其次将教师网络的对应block的和分别替换为Fus和Fud,替换后通过SegNet得到预测分割DepthNet得到预测深度最后对于预测值和教师网络原有预测结果S,D,建立损失函数
其中λ1,λ2为定值权重,Lseg,Ldepth分别为SegNet,DepthNet的损失函数。逐block进行梯度下降。
2.2.对于多任务教师网络(以加入表面法向量估计NormNet为例),有两种方法:一种根据步骤1.2)引入NormNet的通道编码M-Channel Coding,映射并通过NormNet得到预测法向量建立损失函数
其中λ1,λ2,λ3为定值权重,Lnorm为NormNet的损失函数,逐block进行梯度下降。另一种使用已训练好的分割与深度估计学生网络TargetNet-2和NormNet作为教师网络,根据步骤1.2)为学生网络TargetNet-3引入通道编码U-Channel Coding,映射为M-Channel Coding映射为随后将根据步骤1.2)和步骤2.1)得到将根据步骤2.1)得到建立损失函数
其中λ1,λ2为定值权重,Lu2为步骤2.1)中损失函数。
3.确定TargetNet中不同任务的各自分支位置
根据步骤2.1)获取每个block的最终loss, 对不同任务选择分支点pseg,pdepth,(pnorm):
其中使所有分支点处于解码器结构内。
4.使用教师网络中的对应分支作为学生网络的分支;
根据步骤3.确定pseg,pdepth后,移除TargetNet中从靠后的分支点到网络末尾之间的所有block。pseg,pdepth之后的block使用对应教师网络中block作为分支,得到最终的TargetNet结构。使用步骤2中的损失函数,利用梯度下降调优TargetNet。
通过上述步骤,可以利用多个不同任务的教师网络得到一个性能更优,规模较小的多任务学生网络。除上述的场景分割、深度估计、表面法向量预测任务之外,还可以应用于其他计算机视觉任务。
本说明书实施例所述的内容仅仅是对发明构思的实现形式的列举,本发明的保护范围的不应当被视为仅限于实施例所陈述的具体形式。相反,本发明涵盖任何由权利要求定义的在本发明的精髓和范围上做的替代、修改、等效方法以及方案。
Claims (1)
1.一种通过投射特征训练多任务学生网络的知识融合方法,包括下列步骤:
1)初始化目标网络TargetNet结构,与教师网络相同;
网络使用编码器-解码器结构,编码器中block由卷积层和池化层组成,解码器中block由卷积层和上采样层构成;为TargetNet的第n个block输出的融合特征图,使用不同的通道编码将转化为不同任务域中的特征Fus和Fud;
2)逐个训练TargetNet的block,得到融合特征图
对双任务教师网络,将无标记样本输入教师网络SegNet和DepthNet,得到不同任务的特征图和将教师网络的对应block的和分别替换为Fus和Fud,替换后通过SegNet得到预测分割DepthNet得到预测深度最后对于预测值 和教师网络原有预测结果S,D,建立损失函数对于多任务教师网络,一种方法使用不同通道编码映射并通过NormNet得到预测法向量建立损失函数 另一种用已训练好的分割与深度估计学生网络TargetNet-2和NormNet作为教师网络,根据步骤1.2)为学生网络TargetNet-3引入通道编码U-Channel Coding,映射为M-Channel Coding映射为随后将根据步骤1.2)和步骤2.1)得到将根据步骤2.1)得到建立损失函数
3)确定TargetNet中不同任务的各自分支位置;
根据步骤2)中每个block的最终loss, 对不同任务选择分支点p:p=arg minnLn;
4)使用教师网络中的对应分支作为学生网络的分支;
确定不同任务的分支位置后,移除TargetNet中从靠后的分支点到网络末尾之间的所有block;不同任务的分支使用对应教师网络中block,得到最终的TargetNet结构;使用步骤2)中的损失函数,利用梯度下降调优TargetNet。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910264911.8A CN110097084B (zh) | 2019-04-03 | 2019-04-03 | 通过投射特征训练多任务学生网络的知识融合方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910264911.8A CN110097084B (zh) | 2019-04-03 | 2019-04-03 | 通过投射特征训练多任务学生网络的知识融合方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110097084A true CN110097084A (zh) | 2019-08-06 |
CN110097084B CN110097084B (zh) | 2021-08-31 |
Family
ID=67444289
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910264911.8A Active CN110097084B (zh) | 2019-04-03 | 2019-04-03 | 通过投射特征训练多任务学生网络的知识融合方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110097084B (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110930408A (zh) * | 2019-10-15 | 2020-03-27 | 浙江大学 | 基于知识重组的语义图像压缩方法 |
CN112200062A (zh) * | 2020-09-30 | 2021-01-08 | 广州云从人工智能技术有限公司 | 一种基于神经网络的目标检测方法、装置、机器可读介质及设备 |
CN113343796A (zh) * | 2021-05-25 | 2021-09-03 | 哈尔滨工程大学 | 一种基于知识蒸馏的雷达信号调制方式识别方法 |
CN113505719A (zh) * | 2021-07-21 | 2021-10-15 | 山东科技大学 | 基于局部-整体联合知识蒸馏算法的步态识别模型压缩系统及方法 |
CN113610118A (zh) * | 2021-07-19 | 2021-11-05 | 中南大学 | 一种基于多任务课程式学习的眼底图像分类方法、装置、设备及介质 |
CN113888538A (zh) * | 2021-12-06 | 2022-01-04 | 成都考拉悠然科技有限公司 | 一种基于内存分块模型的工业异常检测方法 |
CN115578353A (zh) * | 2022-10-18 | 2023-01-06 | 中科(黑龙江)数字经济研究院有限公司 | 一种基于图流蒸馏的多模态医学影像分割方法及装置 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106875373A (zh) * | 2016-12-14 | 2017-06-20 | 浙江大学 | 基于卷积神经网络剪枝算法的手机屏幕mura缺陷检测方法 |
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
WO2018126213A1 (en) * | 2016-12-30 | 2018-07-05 | Google Llc | Multi-task learning using knowledge distillation |
CN108334934A (zh) * | 2017-06-07 | 2018-07-27 | 北京深鉴智能科技有限公司 | 基于剪枝和蒸馏的卷积神经网络压缩方法 |
CN108665496A (zh) * | 2018-03-21 | 2018-10-16 | 浙江大学 | 一种基于深度学习的端到端的语义即时定位与建图方法 |
US20180307894A1 (en) * | 2017-04-21 | 2018-10-25 | General Electric Company | Neural network systems |
CN108960419A (zh) * | 2017-05-18 | 2018-12-07 | 三星电子株式会社 | 用于使用知识桥的学生-教师迁移学习网络的装置和方法 |
CN108985250A (zh) * | 2018-07-27 | 2018-12-11 | 大连理工大学 | 一种基于多任务网络的交通场景解析方法 |
CN109493407A (zh) * | 2018-11-19 | 2019-03-19 | 腾讯科技(深圳)有限公司 | 实现激光点云稠密化的方法、装置及计算机设备 |
-
2019
- 2019-04-03 CN CN201910264911.8A patent/CN110097084B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106875373A (zh) * | 2016-12-14 | 2017-06-20 | 浙江大学 | 基于卷积神经网络剪枝算法的手机屏幕mura缺陷检测方法 |
WO2018126213A1 (en) * | 2016-12-30 | 2018-07-05 | Google Llc | Multi-task learning using knowledge distillation |
US20180307894A1 (en) * | 2017-04-21 | 2018-10-25 | General Electric Company | Neural network systems |
CN108960419A (zh) * | 2017-05-18 | 2018-12-07 | 三星电子株式会社 | 用于使用知识桥的学生-教师迁移学习网络的装置和方法 |
CN108334934A (zh) * | 2017-06-07 | 2018-07-27 | 北京深鉴智能科技有限公司 | 基于剪枝和蒸馏的卷积神经网络压缩方法 |
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN108665496A (zh) * | 2018-03-21 | 2018-10-16 | 浙江大学 | 一种基于深度学习的端到端的语义即时定位与建图方法 |
CN108985250A (zh) * | 2018-07-27 | 2018-12-11 | 大连理工大学 | 一种基于多任务网络的交通场景解析方法 |
CN109493407A (zh) * | 2018-11-19 | 2019-03-19 | 腾讯科技(深圳)有限公司 | 实现激光点云稠密化的方法、装置及计算机设备 |
Non-Patent Citations (1)
Title |
---|
廖祥文 等: "基于多任务迭代学习的论辩挖掘方法", 《计算机学报》 * |
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110930408A (zh) * | 2019-10-15 | 2020-03-27 | 浙江大学 | 基于知识重组的语义图像压缩方法 |
CN110930408B (zh) * | 2019-10-15 | 2021-06-18 | 浙江大学 | 基于知识重组的语义图像压缩方法 |
CN112200062A (zh) * | 2020-09-30 | 2021-01-08 | 广州云从人工智能技术有限公司 | 一种基于神经网络的目标检测方法、装置、机器可读介质及设备 |
CN112200062B (zh) * | 2020-09-30 | 2021-09-28 | 广州云从人工智能技术有限公司 | 一种基于神经网络的目标检测方法、装置、机器可读介质及设备 |
CN113343796A (zh) * | 2021-05-25 | 2021-09-03 | 哈尔滨工程大学 | 一种基于知识蒸馏的雷达信号调制方式识别方法 |
CN113343796B (zh) * | 2021-05-25 | 2022-04-05 | 哈尔滨工程大学 | 一种基于知识蒸馏的雷达信号调制方式识别方法 |
CN113610118A (zh) * | 2021-07-19 | 2021-11-05 | 中南大学 | 一种基于多任务课程式学习的眼底图像分类方法、装置、设备及介质 |
CN113610118B (zh) * | 2021-07-19 | 2023-12-12 | 中南大学 | 一种基于多任务课程式学习的青光眼诊断方法、装置、设备及方法 |
CN113505719A (zh) * | 2021-07-21 | 2021-10-15 | 山东科技大学 | 基于局部-整体联合知识蒸馏算法的步态识别模型压缩系统及方法 |
CN113505719B (zh) * | 2021-07-21 | 2023-11-24 | 山东科技大学 | 基于局部-整体联合知识蒸馏算法的步态识别模型压缩系统及方法 |
CN113888538A (zh) * | 2021-12-06 | 2022-01-04 | 成都考拉悠然科技有限公司 | 一种基于内存分块模型的工业异常检测方法 |
CN115578353A (zh) * | 2022-10-18 | 2023-01-06 | 中科(黑龙江)数字经济研究院有限公司 | 一种基于图流蒸馏的多模态医学影像分割方法及装置 |
CN115578353B (zh) * | 2022-10-18 | 2024-04-05 | 中科(黑龙江)数字经济研究院有限公司 | 一种基于图流蒸馏的多模态医学影像分割方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN110097084B (zh) | 2021-08-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110097084A (zh) | 通过投射特征训练多任务学生网络的知识融合方法 | |
CN109299274B (zh) | 一种基于全卷积神经网络的自然场景文本检测方法 | |
CN107577651B (zh) | 基于对抗网络的汉字字体迁移系统 | |
US11593615B2 (en) | Image stylization based on learning network | |
CA3043621C (en) | Method and system for color representation generation | |
CN110738697A (zh) | 基于深度学习的单目深度估计方法 | |
CN109919209B (zh) | 一种领域自适应深度学习方法及可读存储介质 | |
CN110246181B (zh) | 基于锚点的姿态估计模型训练方法、姿态估计方法和系统 | |
CN109934154B (zh) | 一种遥感影像变化检测方法及检测装置 | |
CN106599863A (zh) | 一种基于迁移学习技术的深度人脸识别方法 | |
CN113505792B (zh) | 面向非均衡遥感图像的多尺度语义分割方法及模型 | |
CN107169508B (zh) | 一种基于融合特征的旗袍图像情感语义识别方法 | |
CN113066025B (zh) | 一种基于增量学习与特征、注意力传递的图像去雾方法 | |
CN112508079B (zh) | 海洋锋面的精细化识别方法、系统、设备、终端及应用 | |
CN111127360A (zh) | 一种基于自动编码器的灰度图像迁移学习方法 | |
CN114048822A (zh) | 一种图像的注意力机制特征融合分割方法 | |
CN107169498B (zh) | 一种融合局部和全局稀疏的图像显著性检测方法 | |
Al-Amaren et al. | RHN: A residual holistic neural network for edge detection | |
CN115409157A (zh) | 一种基于学生反馈的无数据知识蒸馏方法 | |
CN112767277B (zh) | 一种基于参考图像的深度特征排序去模糊方法 | |
CN110209981A (zh) | 通过投射特征训练多任务学生网络的知识融合方法 | |
Wang et al. | A De-raining semantic segmentation network for real-time foreground segmentation | |
CN117011515A (zh) | 基于注意力机制的交互式图像分割模型及其分割方法 | |
CN114463614A (zh) | 使用生成式参数的层次性显著建模的显著性目标检测方法 | |
CN113807354B (zh) | 图像语义分割方法、装置、设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |