CN114358206A - 二值神经网络模型训练方法及系统、图像处理方法及系统 - Google Patents
二值神经网络模型训练方法及系统、图像处理方法及系统 Download PDFInfo
- Publication number
- CN114358206A CN114358206A CN202210033086.2A CN202210033086A CN114358206A CN 114358206 A CN114358206 A CN 114358206A CN 202210033086 A CN202210033086 A CN 202210033086A CN 114358206 A CN114358206 A CN 114358206A
- Authority
- CN
- China
- Prior art keywords
- neural network
- network model
- theta
- binary
- initial
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000003062 neural network model Methods 0.000 title claims abstract description 164
- 238000012549 training Methods 0.000 title claims abstract description 107
- 238000000034 method Methods 0.000 title claims abstract description 47
- 238000003672 processing method Methods 0.000 title claims abstract description 12
- 238000013528 artificial neural network Methods 0.000 claims abstract description 205
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 28
- 238000004821 distillation Methods 0.000 claims abstract description 14
- 238000012545 processing Methods 0.000 claims description 36
- 230000004913 activation Effects 0.000 claims description 24
- 238000004088 simulation Methods 0.000 claims description 10
- 230000008569 process Effects 0.000 claims description 8
- 238000010276 construction Methods 0.000 claims description 7
- 238000004458 analytical method Methods 0.000 claims description 3
- 239000000865 liniment Substances 0.000 claims description 2
- 239000012528 membrane Substances 0.000 claims description 2
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 230000006870 function Effects 0.000 description 35
- 238000001994 activation Methods 0.000 description 18
- 238000004364 calculation method Methods 0.000 description 4
- 230000006872 improvement Effects 0.000 description 4
- 238000010586 diagram Methods 0.000 description 3
- 238000013139 quantization Methods 0.000 description 3
- 230000009467 reduction Effects 0.000 description 3
- 238000002474 experimental method Methods 0.000 description 2
- 238000007667 floating Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 230000001360 synchronised effect Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 239000000470 constituent Substances 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000009499 grossing Methods 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000008092 positive effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000011002 quantification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了二值神经网络模型训练方法及系统、图像处理方法及系统,属于人工智能技术领域,其中训练方法具体包括:构建在线知识蒸馏增强的二值神经网络训练框架,其中在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型以及初始辅助神经网络模型,学生网络为初始二值神经网络模型;对这三个网络模型使用在线蒸馏方法进行训练,从而提升二值神经网络的性能,同时,利用本发明二值神经网络模型对待处理图像进行图像分类处理,从而提高了图像分类的准确性。
Description
技术领域
本发明涉及人工智能技术领域,更具体的说是涉及二值神经网络模型训练方法及系统、图像处理方法及系统。
背景技术
深度神经网络在计算机视觉任务,如图像分类和目标检测上,取得了巨大的成功,然而深度神经网络模型通常有数百万个参数,需要消耗大量的内存和运算资源来解决复杂的计算问题。在实际中,因为计算资源的限制,将深度神经网络部署在嵌入式平台和移动设备上会遇到许多挑战。为了解决这个限制,许多方法通过压缩网络结构来减少内存使用和计算开销。
在现有技术中,二值神经网络通过将浮点输入和网络权重转化为二值形式来压缩深度神经网络。为了减少二值神经网络和实值神经网络之间的性能差距,一些经典的网络结构被提出,如:XNOR-Net网络,其利用对应的二值化参数和比例因子来重建全精度的权重和激活值,以此提高二值神经网络的性能;ABC-Net,运用多个二进制基的线性组合来近似全精度的权重和激活值。
但是,上述提到的二值神经网络仍有以下几个限制:
(1)由于极端的二值化按位操作很可能会导致实值神经网络和二值神经网络的信息流间产生巨大的差异,因此在正向传播与反向传播时产生的量化误差和梯度错配通常会导致实值神经网络和二值神经网络的性能差距巨大,造成二值神经网络模型在具体计算机视觉任务上,如图像分类任务的类别预测准确度相较于实值神经网络大幅度降低,从而限制图像分类等计算机视觉任务在资源受限的平台(如嵌入式设备等)上的部署。
(2)根据第(1)点,巨大的性能差距会导致实值神经网络的准确率损失,这会影响到实值神经网络对二值神经网络的训练。而现有技术中并未有减小网络间的性能差距的问题。
(3)对于知识蒸馏,学生网络常通过离线的方式,由预训练好的教师网络进行训练,这使得教师网络无法获得学生网络的反馈。换句话说,知识是单向从教师网络传递给学生网络。这会给二值神经网络的知识蒸馏带来更多的障碍。
综上,如何提供一种二值神经网络模型训练方法及系统、图像处理方法及系统是本领域技术人员亟需解决的问题。
发明内容
有鉴于此,本发明提供了一种二值神经网络模型训练方法及系统、图像处理方法及系统,使用在线蒸馏技术来联合训练二值神经网络和实值神经网络,提升了网络间知识的相互交流,同时使得实值神经网络能够依据二值神经网络的反馈更好的指导二值神经网络的训练,并且,通过本发明提出的辅助神经网络来桥接实值神经网络和二值神经网络之间的知识迁移,进一步提升性能,并将基于在线知识蒸馏的二值神经网络训练框架扩展成三个网络集成的结构,进一步的缩小了教师网络和学生网络之间的性能差异,提高二值神经网络的性能,从而提高了图像分类的准确性。
为了实现上述目的,本发明提供如下技术方案:
一方面,本发明提供一种二值神经网络模型的训练方法,所述训练方法包括:
S100:构建在线知识蒸馏增强的二值神经网络训练框架,其中,所述在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
S200:利用所述在线蒸馏方法,对所述初始实值神经网络模型ΘR、所述初始辅助神经网络模型ΘA以及所述初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
S300:获取待训练图像,将所述待训练图像输入至所述实值神经网络模型ΘR j、所述辅助神经网络模型ΘA j以及所述二值神经网络模型ΘB j中,得到图像的类别预测值以及图像类别标签;
S400:基于图像的类别预测值以及图像类别标签,计算得到目标损失函数值,并根据所述目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
S500:当满足预设训练条件时,将所述二值神经网络ΘB j+1作为目标二值神经网络模型。
优选的,所述S100包括初始二值神经网络模型ΘB的构建:
其中,sign(.)是符号函数,Ab为激活值,Wb为实值权重;
优选的,所述S100还包括对初始辅助神经网络模型ΘA的构建:
优选的,所述S400包括:
S410:基于图像的类别预测值以及值图像类别标签,计算得到目标损失函数值:
LΘB=Lce(y,PB)+Lm(ΘB);
LΘA=Lce(y,PA)+Lm(ΘA);
LΘR=Lce(y,PR)+Lm(ΘR);
其中,y是图像类别标签,PB是初始二值神经网络模型ΘB对输入图片的类别预测值,PA是初始辅助神经网络模型ΘA对输入图片的类别预测值,PR是初始实值神经网络模型ΘR对输入图片的类别预测值;是初始二值神经网络模型ΘB的整体损失函数,是初始辅助神经网络模型ΘA的整体损失函数,是初始实值神经网络模型ΘR的整体损失函数;
S410:根据目标损失函数值进行j+1次训练,并进行参数更新,得到更新后的实值神经网络模型ΘR j+1、辅助神经网络模型ΘA j+1以及二值神经网络模型ΘB j+1。
优选的,所述目标损失函数值包括模拟损失项Lm(·),所述模拟损失项Lm(·)由两个模拟损失子项Lm(.,.)组成,其计算公式为:
Lm(ΘB)=αRBLm(PR,PB)+βABLm(PA,PB);
Lm(ΘA)=αRALm(PR,PA)+βBALm(PB,PA);
Lm(ΘR)=αARLm(PA,PR)+βBRLm(PB,PR);
其中,PA是初始辅助神经网络模型ΘA对于输入图片的类别预测值,PR是初始实值神经网络模型ΘR对于输入图片的类别预测值,PB是初始二值神经网络模型ΘB对于输入图片的类别预测值,αRB、αRA、αAB、βAB、βBA、βBR分别为模拟因子;
模拟损失子项Lm(.,.)的计算公式为:
优选的,所述目标损失函数值还包括交叉熵损失项Lce(·,·),其计算公式为:
其中,y是图像类别标签,pi是输入到网络的训练样本中的第i个样本的类别预测值,N为训练样本的大小。
优选的,所述S500包括:共对所述实值神经网络模型、所述辅助神经网络模型以及所述初始二值神经网络模型进行K次训练,对于第j+1次训练为1=<j+1<=K,其中,j为正整数;当j+1=K时,将二值神经网络ΘB j+1作为目标二值神经网络,否则令j=j+1,并返回步骤S200进行重复训练。
另一方面,本发明提供一种二值神经网络模型的训练系统,包括:
构建模块,构建在线知识蒸馏增强的二值神经网络训练框架,其中,所述在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
训练模块:与所述构建模块连接,利用所述在线蒸馏方法,对所述初始实值神经网络模型ΘR、所述初始辅助神经网络模型ΘA以及所述初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
处理模块,与所述训练模块连接,获取待训练数据集,将所述待训练数据集输入至所述实值神经网络模型ΘR j、所述辅助神经网络模型ΘA j以及所述二值神经网络模型ΘB j中,得到数据集中图片的类别预测值以及数据集类别标签;
更新模块,与所述处理模块连接,基于数据集的类别预测以及数据集类别标签,计算得到目标损失函数值,并根据所述目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
判断模块,与所述更新模块连接,用于当满足训练预设条件时,将所述二值神经网络ΘB j+1作为目标二值神经网络模型。
另一方面,本发明提供一种图像处理方法,应用是上述得到的目标二值神经网络模型,所述图像处理方法包括:
S10:获取待处理图像;
S20:利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理;
S30:得到分类处理结果并输出。
再一方面,本发明还提供了一种图像处理系统,包括:
获取模块:用于获取待处理图像;
分类处理模块,与所述获取模块连接:用于利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理;
输出模块,与所述分析处理模块连接,用于获取待处理图像,利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理,得到分类处理结果并输出。
经由上述的技术方案可知,与现有技术相比,本发明公开提供了二值神经网络模型训练方法及系统、图像处理方法及系统,所构建的在线知识蒸馏增强的二值神经网络训练框架,实现教师网络和学生网络之间知识的交互,通过辅助神经网络,帮助建立实值神经网络和二值神经网络之间的联系,并将基于在线知识蒸馏的二值神经网络训练框架扩展成三个网络的集成结构。减小了教师网络和学生网络之间的性能差异,进一步提升网络的性能,从而提高了图像分类的准确度。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本发明提供的二值神经网络模型的训练方法流程示意图;
图2为本实施例1提供的在线知识蒸馏增强的二值神经网络训练框架结构示意图;
图3为本实施例1提供的二值神经网络模型的训练系统结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
实施例1
一方面,参见附图1所示,本发明实施例1公开了一种二值神经网络模型的训练方法,包括:
S100:构建在线知识蒸馏增强的二值神经网络训练框架,其中,在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
S200:利用在线蒸馏方法,对实值神经网络模型ΘR、初始辅助神经网络模型ΘA以及初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
S300:获取待训练数据集,将待训练数据集输入至训练后的实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j中,得到数据集中图片的类别预测值以及数据集类别标签;
S400:基于数据集中图片的类别预测值以及数据集类别标签,计算得到目标损失函数值,并根据目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
S500:当满足训练预设训练条件时,将二值神经网络ΘB j+1作为目标二值神经网络模型;
具体的,当目标二值神经网络模型应用于图像处理时,待训练数据集为待训练图像数据集。
在一个具体实施例中,二值神经网络是一种有效的神经网络压缩方法,其通过将浮点输入和全精度网络权重进行二值化来压缩网络结构。使用二值化操作对实值神经网络进行压缩后,网络中的权重和激活都可以用1位数值(如+1或-1)表示,而不会占用太多内存。
在公式(1)中,sign(.)是符号函数,函数输入若是正值,输出为1,负值则为-1,其导数为脉冲函数。同时,通过使用Straight-forward(直接前向传递)方法在反向传播过程中估计sign函数的梯度,权重平均值用来估计激活函数的梯度。
经过上述技术方案,得到了初始实值神经网络模型ΘR所对应的初始二值神经网络模型ΘB。
然而,直接对实值神经网络的激活值和权重进行二值化,会在参数正向传播与梯度反向传播时产生量化误差和梯度错配,导致二值神经网络相较于全精度实值神经网络,性能急剧下降。
在一个具体实施例中,为了解决二值神经网络性能急剧下降的问题,本发明基于在线知识蒸馏,提出了一种在线知识蒸馏增强的二值神经网络,即Online Distilling-Enhanced Binary Neural Networks,缩写为ODE-BNN。通过ODE-BNN对压缩后的二值神经网络参数进行训练。通过在线知识蒸馏,使用性能更好的全精度实值神经网络对二值神经网络的训练进行指导,可以使二值神经网络的性能获得了极大的提高。然而,由于正向和反向传播中产生的量化误差和梯度错配,这种提高被实值神经网络和二值神经网络间的性能差距限制了。因此,仅使用实值神经网络对二值神经网络进行在线知识蒸馏,并不能给二值神经网络提供足够好的指导。进一步的,本发明还提出了构建软化的辅助神经网络来解决上述问题,辅助神经网络就像一座桥一样联系实值神经网络和二值神经网络。软化方法可以平滑量化步骤、避免梯度错配。一方面,辅助神经网络的精度介于实值神经网络和二值神经网络之间,其有助于实现实值神经网络和二值神经网络的信息交换,帮助提升二值神经网络性能。另一方面,辅助神经网络可以和实值神经网络一起提供对二值神经网络训练的指导。
参见附图2所示,本发明实施例提供了在线知识蒸馏增强的二值神经网络训练框架结构示意图,在一个具体实施例中,将初始实值神经网络ΘR、初始二值神经网络ΘB和初始辅助神经网络ΘA集成为在线蒸馏增强的二值神经网络训练框架。通过在线蒸馏的方式,使用实值神经网络和辅助神经网络对二值神经网络的参数优化过程进行指导。其中在线蒸馏框架中的教师网络为初始实值神经网络ΘR和初始辅助神经网络ΘA,学生网络为初始二值神经网络ΘB。
对于图像分类任务,基于上述在线蒸馏框架对二值神经网络进行K次训练,对于第j+1次训练(1=<j+1<=K),将训练图像输入到在线蒸馏框架下的每个神经网络中,即实值神经网络ΘR j、二值神经网络ΘB j和辅助神经网络ΘA j中,其中ΘR j、ΘB j和ΘA j是基于第j次训练得到的。每个神经网络分别对图片进行处理,得到网络对于该次训练输入图片的类别预测值。
之后,基于上述图像的类别预测值和图像类别标签,通过下面的目标函数公式(6)计算得到该次训练过程的损失函数值,并基于该目标损失函数值来更新每个神经网络模型的参数。该损失函数由模拟损失项Lm(·)和交叉熵损失项Lce(·,·)构成。其中模拟损失项用于描述框架中任意一个神经网络(如二值神经网络ΘB)与框架中另外两个神经网络(如实值神经网络ΘR和辅助神经网络ΘA)对于第j+1次训练输入图像的类别预测值之间的差异。交叉熵损失项用于描述框架中任意网络对于第j+1次训练输入图像的输出类别预测值和图像的真实类别标签之间的差异。
其中y是图像类别标签,PB是二值神经网络ΘB的类别预测值,PA是辅助神经网络ΘA的类别预测值,PR是实值神经网络ΘR的类别预测值;是二值神经网络ΘB的整体损失函数,是辅助神经网络ΘA的整体损失函数,是实值神经网络ΘR的整体损失函数。
经过上述的第j+1次训练,我们同步训练框架内的三个神经网络并进行参数更新,得到了实值神经网络ΘR j+1、二值神经网络ΘB j+1和辅助神经网络ΘA j+1。此时若满足预设条件后(如j+1=K,即当前训练次数为预设训练次数),我们即可将上述框架中训练得到二值神经网络ΘB j+1作为目标二值神经网络,否则令j=j+1,继续上述训练。
在一个具体实施例中,模拟损失项Lm(·)和交叉熵损失项Lce(·,·)的具体计算过程为:
(1)模拟损失项Lm(·)由两个模拟损失子项Lm(.,.)组成,每个模拟损失子项描述在线蒸馏框架中任意两个网络的输出类别预测值之间的差异,通过最小化Lm(.,.)使一个网络能够尽可能地学习另一个网络的输出。如二值神经网络的模拟损失项Lm(ΘB),其由二值神经网络和实值神经网络间的模拟损失子项Lm(PR,PB)和二值神经网络与辅助神经网络间的模拟损失子项Lm(PA,PB)构成。二值神经网络通过模拟损失项向教师网络(即实值神经网络和辅助神经网络)学习,可以使得训练得到的目标二值神经网络在图片类别的预测结果上更接近教师网络,进而提升二值神经网络的预测准确度。下式为框架中各网络所对应的模拟损失项Lm(·):
其中PA是辅助神经网络ΘA对于输入图片的类别预测值,PR是实值神经网络ΘR对于输入图片的类别预测值,PB是二值神经网络ΘB对于输入图片的类别预测值,α**和β**为模拟因子,用于平衡两个模拟损失的大小。在实现中,αRB设置为0.5,βAB设置为0.5,αRA设置为0.7,βBA,αAR和βBR设置为1。同时,模拟损失子项Lm(.,.)的具体计算公式如下:
从上述模拟损失项可以看出,二值神经网络通过模拟损失项来学习实值神经网络输出类别预测值的分布,实值神经网络同时也通过模拟损失接收二值神经网络的反馈并给整个训练过程提供更好的指导。同时,二值神经网络也通过模拟损失项来学习辅助神经网络输出类别预测值的分布,由于辅助神经网络的性能介于实值神经网络与二值神经网络之间,其可以弥补实值神经网络和二值神经网络间的巨大差异,有助于实现实值神经网络和二值神经网络的信息交换,帮助提升二值神经网络性能。
(2)交叉熵损失Lce(·)可由下式得到,该损失项通过比较框架中的神经网络类别预测值和图像标签间差异,让网络能够学习到数据的正确分布,从而提升模型的预测准确度
其中y是图像类别标签,pi是输入到网络的训练样本中的第i个样本的类别预测值,N是该批样本的大小。
通过上述技术方案,本发明使用在线知识蒸馏网络框架,通过联合训练实值神经网络和二值神经网络大幅提升二值神经网络性能。同时,该框架也构建了软化辅助神经网络,在训练过程中平滑量化步骤、减少梯度错配,弥合实值神经网络和二值神经网络间的巨大差异,进一步提升二值神经网络性能。在多个公共数据集上的大量实验也验证了该方法的有效性。
另一方面,参见附图3所示,本发明实施例1还提供了一种二值神经网络模型的训练系统,包括:
构建模块,构建在线知识蒸馏增强的二值神经网络训练框架,其中,在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
训练模块:与构建模块连接,利用在线蒸馏方法,对实值神经网络模型ΘR、初始辅助神经网络模型ΘA以及初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
处理模块,与训练模块连接,获取待训练数据集,将待训练数据集输入至训练后的实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j中,得到数据集中图片的类别预测值以及数据集类别标签;
更新模块,与处理模块连接,基于数据集的类别预测以及数据集类别标签,计算得到目标损失函数值,并根据目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
判断模块,与更新模块连接,用于当满足训练预设条件时,将二值神经网络ΘB j+1作为目标二值神经网络模型。
另一方面,本实施例1还提供一种图像处理方法,应用是上述得到的目标二值神经网络模型,图像处理方法包括:
S10:获取待处理图像;
S20:利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理;
S30:得到分类处理结果并输出。
再一方面,本实施例1还提供了一种图像处理系统,包括
获取模块:用于获取待处理图像;
分类处理模块,与获取模块连接:用于利用目标二值神经网络模型对待处理图像进行图像分类处理;
输出模块,与分析处理模块连接,用于获取待处理图像,利用目标二值神经网络模型对待处理图像进行图像分类处理,得到分类处理结果并输出。
经由上述的技术方案可知,与现有技术相比,本发明公开提供了二值神经网络模型训练方法及系统、图像处理方法及系统,所构建的在线知识蒸馏增强的二值神经网络训练框架,实现教师网络和学生网络之间的知识交互,通过辅助神经网络,帮助建立实值神经网络和二值神经网络之间的联系,并将基于在线知识蒸馏的二值神经网络训练框架扩展成集成了三个网络的结构。减小了教师网络和学生网络之间的性能差异,并进一步提升网络的性能,从而提高了图像分类的准确度。
实施例2
为了验证上述方法的有效性,在三个公共基准数据集上进行了大量实验,实验结果证明了本发明对二值神经网络性能具有明显的提升效果,其在CIFAR10和CIFAR100数据集上分别获得最高3.15%和6.67%的准确度提升。同时也验证了辅助神经网络对缩小教师网络和学生网络间差距的积极作用,辅助神经网络可帮助ODE-BNN在CIFAR10和CIFAR100数据集上分别获得最高0.87%及3.48%的准确度提升。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。
Claims (10)
1.一种二值神经网络模型的训练方法,其特征在于,所述训练方法包括:
S100:构建在线知识蒸馏增强的二值神经网络训练框架,其中,所述在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
S200:利用所述在线蒸馏方法,对所述初始实值神经网络模型ΘR、所述初始辅助神经网络模型ΘA以及所述初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
S300:获取待训练数据集,将所述待训练数据集输入至所述实值神经网络模型ΘR j、所述辅助神经网络模型ΘA j以及所述二值神经网络模型ΘB j中,得到数据集中图片的类别预测值以及数据集类别标签;
S400:基于数据集中图片的类别预测值以及数据集类别标签,计算得到目标损失函数值,并根据所述目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
S500:当满足预设训练条件时,将所述二值神经网络ΘB j+1作为目标二值神经网络模型。
4.根据权利要求1所述的一种二值神经网络模型的训练方法,其特征在于,所述S400包括:
S410:基于图像的类别预测值以及图像类别标签,计算得到目标损失函数值:
LΘB=Lce(y,PB)+Lm(ΘB);
LΘA=Lce(y,PA)+Lm(ΘA);
LΘR=Lce(y,PR)+Lm(ΘR);
其中,y是图像类别标签,PB是初始二值神经网络模型ΘB对输入图片的类别预测值,PA是初始辅助神经网络模型ΘA对输入图片的类别预测值,PR是初始实值神经网络模型ΘR对输入图片的类别预测值;是初始二值神经网络模型ΘB的整体损失函数,是初始辅助神经网络模型ΘA的整体损失函数,是初始实值神经网络模型ΘR的整体损失函数;
S420:根据目标损失函数值进行j+1次训练,并进行参数更新,得到更新后的实值神经网络模型ΘR j+1、辅助神经网络模型ΘA j+1以及二值神经网络模型ΘB j+1。
5.根据权利要求4所述的一种二值神经网络模型的训练方法,其特征在于,所述目标损失函数值包括模拟损失项Lm(·),所述模拟损失项Lm(·)由两个模拟损失子项Lm(.,.)组成,其计算公式为:
Lm(ΘB)=αRBLm(PR,PB)+βABLm(PA,PB);
Lm(ΘA)=αRALm(PR,PA)+βBALm(PB,PA);
Lm(ΘR)=αARLm(PA,PR)+βBRLm(PB,PR);
其中,PA是初始辅助神经网络模型ΘA对于输入图片的类别预测值,PR是初始实值神经网络模型ΘR对于输入图片的类别预测值,PB是初始二值神经网络模型ΘB对于输入图片的类别预测值,αRB、αRA、αAB、βAB、βBA、βBR分别为模拟因子;
模拟损失子项Lm(.,.)的计算公式为:
7.根据权利要求1所述的一种二值神经网络模型的训练方法,其特征在于,所述S500包括:共对所述实值神经网络模型、所述辅助神经网络模型以及所述初始二值神经网络模型进行K次训练,对于第j+1次训练为1=<j+1<=K,其中,j为正整数;当j+1=K时,将二值神经网络ΘB j+1作为目标二值神经网络,否则令j=j+1,并返回步骤S200进行重复训练。
8.一种二值神经网络模型的训练系统,其特征在于,包括:
构建模块,构建在线知识蒸馏增强的二值神经网络训练框架,其中,所述在线知识蒸馏增强的二值神经网络训练框架中教师网络为初始实值神经网络模型ΘR以及初始辅助神经网络模型ΘA,学生网络为初始二值神经网络模型ΘB;
训练模块:与所述构建模块连接,利用所述在线蒸馏方法,对所述初始实值神经网络模型ΘR、所述初始辅助神经网络模型ΘA以及所述初始二值神经网络模型ΘB进行j次训练,得到实值神经网络模型ΘR j、辅助神经网络模型ΘA j以及二值神经网络模型ΘB j;
处理模块,与所述训练模块连接,获取待训练数据集,将所述待训练数据集输入至所述实值神经网络模型ΘR j、所述辅助神经网络模型ΘA j以及所述二值神经网络模型ΘB j中,得到数据集中图片的类别预测值以及数据集类别标签;
更新模块,与所述处理模块连接,基于数据集中图片的类别预测值以及数据集类别标签,计算得到目标损失函数值,并根据所述目标损失函数值进行参数更新,得到更新后的实值神经网络ΘR j+1、辅助神经网络ΘA j+1以及二值神经网络ΘB j+1;
判断模块,与所述更新模块连接,用于当满足训练预设条件时,将所述二值神经网络ΘB j+1作为目标二值神经网络模型。
9.一种图像处理方法,应用权利要求1-7任一项所述得到的目标二值神经网络模型,其特征在于,所述图像处理方法包括:
S10:获取待处理图像;
S20:利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理;
S30:得到分类处理结果并输出。
10.一种图像处理系统,其特征在于,包括:
获取模块:用于获取待处理图像;
分类处理模块,与所述获取模块连接:用于利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理;
输出模块,与所述分析处理模块连接,用于获取待处理图像,利用所述目标二值神经网络模型对所述待处理图像进行图像分类处理,得到分类处理结果并输出。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210033086.2A CN114358206B (zh) | 2022-01-12 | 2022-01-12 | 二值神经网络模型训练方法及系统、图像处理方法及系统 |
US18/080,777 US20230222325A1 (en) | 2022-01-12 | 2022-12-14 | Binary neural network model training method and system, and image processing method and system |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210033086.2A CN114358206B (zh) | 2022-01-12 | 2022-01-12 | 二值神经网络模型训练方法及系统、图像处理方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114358206A true CN114358206A (zh) | 2022-04-15 |
CN114358206B CN114358206B (zh) | 2022-11-01 |
Family
ID=81109566
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210033086.2A Active CN114358206B (zh) | 2022-01-12 | 2022-01-12 | 二值神经网络模型训练方法及系统、图像处理方法及系统 |
Country Status (2)
Country | Link |
---|---|
US (1) | US20230222325A1 (zh) |
CN (1) | CN114358206B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114822510A (zh) * | 2022-06-28 | 2022-07-29 | 中科南京智能技术研究院 | 一种基于二值卷积神经网络的语音唤醒方法及系统 |
CN115660046A (zh) * | 2022-10-24 | 2023-01-31 | 中电金信软件有限公司 | 二值神经网络的梯度重构方法、装置、设备及存储介质 |
CN116664958A (zh) * | 2023-07-27 | 2023-08-29 | 鹏城实验室 | 基于二值神经网络模型的图像分类方法以及相关设备 |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118233222B (zh) * | 2024-05-24 | 2024-09-10 | 浙江大学 | 一种基于知识蒸馏的工控网络入侵检测方法及装置 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110379506A (zh) * | 2019-06-14 | 2019-10-25 | 杭州电子科技大学 | 针对心电图数据使用二值化神经网络的心律不齐检测方法 |
CN110880036A (zh) * | 2019-11-20 | 2020-03-13 | 腾讯科技(深圳)有限公司 | 神经网络压缩方法、装置、计算机设备及存储介质 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN112116030A (zh) * | 2020-10-13 | 2020-12-22 | 浙江大学 | 一种基于向量标准化和知识蒸馏的图像分类方法 |
WO2021042857A1 (zh) * | 2019-09-02 | 2021-03-11 | 华为技术有限公司 | 图像分割模型的处理方法和处理装置 |
CN112508169A (zh) * | 2020-11-13 | 2021-03-16 | 华为技术有限公司 | 知识蒸馏方法和系统 |
CN113191489A (zh) * | 2021-04-30 | 2021-07-30 | 华为技术有限公司 | 二值神经网络模型的训练方法、图像处理方法和装置 |
CN113569882A (zh) * | 2020-04-28 | 2021-10-29 | 上海舜瞳科技有限公司 | 一种基于知识蒸馏的快速行人检测方法 |
CN113591978A (zh) * | 2021-07-30 | 2021-11-02 | 山东大学 | 一种基于置信惩罚正则化的自我知识蒸馏的图像分类方法、设备及存储介质 |
-
2022
- 2022-01-12 CN CN202210033086.2A patent/CN114358206B/zh active Active
- 2022-12-14 US US18/080,777 patent/US20230222325A1/en active Pending
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110379506A (zh) * | 2019-06-14 | 2019-10-25 | 杭州电子科技大学 | 针对心电图数据使用二值化神经网络的心律不齐检测方法 |
WO2021042857A1 (zh) * | 2019-09-02 | 2021-03-11 | 华为技术有限公司 | 图像分割模型的处理方法和处理装置 |
CN110880036A (zh) * | 2019-11-20 | 2020-03-13 | 腾讯科技(深圳)有限公司 | 神经网络压缩方法、装置、计算机设备及存储介质 |
CN113569882A (zh) * | 2020-04-28 | 2021-10-29 | 上海舜瞳科技有限公司 | 一种基于知识蒸馏的快速行人检测方法 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN112116030A (zh) * | 2020-10-13 | 2020-12-22 | 浙江大学 | 一种基于向量标准化和知识蒸馏的图像分类方法 |
CN112508169A (zh) * | 2020-11-13 | 2021-03-16 | 华为技术有限公司 | 知识蒸馏方法和系统 |
CN113191489A (zh) * | 2021-04-30 | 2021-07-30 | 华为技术有限公司 | 二值神经网络模型的训练方法、图像处理方法和装置 |
CN113591978A (zh) * | 2021-07-30 | 2021-11-02 | 山东大学 | 一种基于置信惩罚正则化的自我知识蒸馏的图像分类方法、设备及存储介质 |
Non-Patent Citations (5)
Title |
---|
GEOFFREY HINTON ET AL: "distilling the knowledge in a neural network", 《ARXIV》 * |
YUAN L,TAY F E H,LI G,ET AL: "Revisiting Knowledge Distillation via Label Smoothing Regularization", 《2020 IEEE/CVF CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)》 * |
刘峡壁等: "《人工智能 机器学习与神经网络》", 31 August 2020 * |
耿丽丽等: "深度神经网络模型压缩综述", 《计算机科学与探索》 * |
赖叶静等: "深度神经网络模型压缩方法与进展", 《华东师范大学学报(自然科学版)》 * |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114822510A (zh) * | 2022-06-28 | 2022-07-29 | 中科南京智能技术研究院 | 一种基于二值卷积神经网络的语音唤醒方法及系统 |
CN114822510B (zh) * | 2022-06-28 | 2022-10-04 | 中科南京智能技术研究院 | 一种基于二值卷积神经网络的语音唤醒方法及系统 |
CN115660046A (zh) * | 2022-10-24 | 2023-01-31 | 中电金信软件有限公司 | 二值神经网络的梯度重构方法、装置、设备及存储介质 |
CN116664958A (zh) * | 2023-07-27 | 2023-08-29 | 鹏城实验室 | 基于二值神经网络模型的图像分类方法以及相关设备 |
CN116664958B (zh) * | 2023-07-27 | 2023-11-14 | 鹏城实验室 | 基于二值神经网络模型的图像分类方法以及相关设备 |
Also Published As
Publication number | Publication date |
---|---|
CN114358206B (zh) | 2022-11-01 |
US20230222325A1 (en) | 2023-07-13 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114358206B (zh) | 二值神经网络模型训练方法及系统、图像处理方法及系统 | |
Wang et al. | A mesh-free method for interface problems using the deep learning approach | |
WO2021037113A1 (zh) | 一种图像描述的方法及装置、计算设备和存储介质 | |
CN109214001A (zh) | 一种中文语义匹配系统及方法 | |
CN107636691A (zh) | 用于识别图像中的文本的方法和设备 | |
CN115239638A (zh) | 一种工业缺陷检测方法、装置、设备及可读存储介质 | |
CN107292382A (zh) | 一种神经网络声学模型激活函数定点量化方法 | |
CN110275928B (zh) | 迭代式实体关系抽取方法 | |
CN107092594B (zh) | 基于图的双语递归自编码器 | |
CN111832637B (zh) | 基于交替方向乘子法admm的分布式深度学习分类方法 | |
CN113516133A (zh) | 一种多模态图像分类方法及系统 | |
CN113505206A (zh) | 基于自然语言推理的信息处理方法、装置和电子设备 | |
CN110738314B (zh) | 一种基于深度迁移网络的点击率预测方法及装置 | |
CN117455011A (zh) | 一种多模态交通大模型设计方法及多模态交通大模型 | |
CN112200255B (zh) | 一种针对样本集的信息去冗余方法 | |
CN116975686A (zh) | 训练学生模型的方法、行为预测方法和装置 | |
CN113538485B (zh) | 学习生物视觉通路的轮廓检测方法 | |
JP2021039220A (ja) | 音声認識装置、学習装置、音声認識方法、学習方法、音声認識プログラムおよび学習プログラム | |
CN110263352A (zh) | 用于训练深层神经机器翻译模型的方法及装置 | |
CN110365583A (zh) | 一种基于桥接域迁移学习的符号预测方法及系统 | |
CN115936801A (zh) | 基于神经网络的产品推荐方法、装置、设备和存储介质 | |
CN114429121A (zh) | 一种面向试题语料情感与原因句子对的抽取方法 | |
CN114548382A (zh) | 迁移训练方法、装置、设备、存储介质及程序产品 | |
CN111709275A (zh) | 一种用于Affordance推理的深度网络构建方法 | |
Rezk et al. | MOHAQ: Multi-Objective Hardware-Aware Quantization of recurrent neural networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |