CN113222100A - 神经网络模型的训练方法和装置 - Google Patents
神经网络模型的训练方法和装置 Download PDFInfo
- Publication number
- CN113222100A CN113222100A CN202010080441.2A CN202010080441A CN113222100A CN 113222100 A CN113222100 A CN 113222100A CN 202010080441 A CN202010080441 A CN 202010080441A CN 113222100 A CN113222100 A CN 113222100A
- Authority
- CN
- China
- Prior art keywords
- input image
- training
- image sample
- tth
- image samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本公开涉及一种一种用于图像识别的神经网络模型的训练方法,包括:第一训练阶段,计算对应于第一组输入图像样本的第一损失函数,并利用所述第一损失函数训练所述神经网络模型,其中,所述第一损失函数的计算包括:(1)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;以及(2)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构误差项。
Description
技术领域
本公开涉及人工智能领域中的连续学习场景。更具体地,本公开涉及一种用于图像识别的神经网络模型的训练方法和装置。
背景技术
传统的机器学习是针对固定的任务进行的,也就是说,用于训练学习模型的数据集包含具有固定分布的训练数据。当输入新的数据集(即,包含具有与所述固定分布不同的新分布的训练数据的数据集)时,一般需要对学习模型进行重新训练。经过重新训练之后,所述学习模型只能对新的数据集给出响应,而无法对原数据集(即,包含所述固定类别的数据的数据集)给出响应。这个问题被称为机器学习中的“灾难性遗忘(CatastrophicForgetting)”。事实上,所述“灾难性遗忘”是机器学习所面临的“稳定性-可塑性困境(Stability-Plasticity Dilemma)”的结果,其中,稳定性指的是在学习新知识的同时保持原有知识的能力,而可塑性指的是学习新知识的能力。
连续学习(Continual Learning)是在一个学习模型上针对由多个不同的任务组成的连续序列进行训练。连续学习旨在解决上述“灾难性遗忘”的问题,更具体地,其在基于新的输入数据训练学习模型来适应新任务的同时,也维持所述学习模型在完成历史任务上的表现。连续学习是使一个学习模型适应学习任务的快速变化的关键,因此对于实现人工智能在现实场景中的应用十分关键。
连续学习包括任务增量学习(Task-Incremental Learning,Task-IL)、域增量学习(Domain-Incremental Learning,Domain-IL)、以及,类别增量学习(Class-IncrementalLearning,Class-IL)。其中,(i)在任务增量学习的场景中,学习模型能够得知当前的输入来自于哪一个任务类型。每一个任务类型都有独立的输出层,同时网络结构的其他部分不随任务类型不同而改变;(ii)在域增量学习的场景中,学习模型无需判断当前任务的任务类型。每个任务所使用的网络结构相同。虽然每个任务的输入分布不同,但输出分布均是相同的;(iii)在类别增量学习的场景中,学习模型需要判断当前任务的任务类型。每个任务所使用的网络结构相同。每个任务的输入分布不同,输出分布也不同。
在本发明中,我们仅就类别增量学习中的“灾难性遗忘”问题进行讨论。
针对类别增量学习中的“灾难性遗忘”问题,目前有如下两种主要的解决方法:
第一种方法是权重正则方法(Regularization-based Method)。对于已针对先前任务进行训练并达到较好训练效果的学习模型,权重正则方法将估计所述学习模型中的每个参数对于先前任务的重要性,并基于所述重要性来针对每个参数生成权重正则项,并将所述权重正则项添加到损失函数中。在针对新任务进行训练时,使用权重正则方法的学习模型将允许相对不重要的参数有较大的变化来学习新任务,并将保持相对重要的参数的变化程度尽量较小。
第二种方法是表达正则方法(Replay-based Method)。这一类方法使用一个较小的存储空间用于存储属于先前任务的输入数据样本,并在针对新任务进行学习的同时,在所述输入数据样本上进行训练。在针对新任务进行训练时,使用表达正则方法的学习模型将保持针对所述输入数据样本所提取的特征与没有针对新任务进行训练时有尽量相似的特征表达,从而促进针对新任务进行训练之后的模型在完成先前任务的表现上尽量靠近针对新任务进行训练先前的表现。
通过使用上述两种方法,类别增量学习中“灾难性遗忘”的问题可以被缓解,但仍存在局限性。具体而言,在新任务的输入数据分布与先前任务的输入数据分布有较大不同的情况下,对于针对新任务训练学习模型而言,根据先前任务的输入数据分布得到的模型参数可用性较低,因此所述学习模型需要从先前任务的输入数据中提取新任务所需要的特征。但是,对于权重正则方法而言,先前任务的输入数据已经无法获得;对于表达正则方法而言,虽然可获得少量属于先前任务的输入数据样本,但无法从少量的输入数据样本中提取足够的特征用于新任务的学习。因此,无法全面地提取特征以满足针对新任务的学习的需要是目前类别增量学习的局限性之一。
发明内容
鉴于上述问题而提出了本公开。本公开提供了一种用于图像识别的神经网络模型的训练方法和装置、以及电子设备和图像识别系统。
根据本公开的一个方面,提供了一种用于图像识别的神经网络模型的训练方法,包括:第一训练阶段,计算对应于第一组输入图像样本的第一损失函数,并利用所述第一损失函数训练所述神经网络模型,其中,所述第一损失函数的计算包括:(1)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;以及(2)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构误差项。
此外,根据本公开一方面的训练方法,进一步包括:在执行所述第一训练阶段之后,循环地执行第t训练阶段,其中,t>1,所述第t训练阶段计算对应于第t组输入图像样本的第t损失函数,并利用所述第t损失函数训练所述神经网络模型,其中,所述第t损失函数的计算包括:(1)针对所述第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;(2)针对所述第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项;以及(3)针对第t训练阶段,计算正则项。
此外,根据本公开一方面的训练方法,进一步包括:在执行所述第一训练阶段之后,循环地执行第t训练阶段,其中,t>1,所述第t训练阶段计算对应于第t组输入图像样本的第t损失函数,并利用所述第t损失函数训练所述神经网络模型,其中,在每次循环地执行所述第t训练阶段之前,存储当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分,并将所存储的先前所有训练阶段的输入图像样本中的一部分与所述第t组输入图像样本一起输入到所述神经网络模型,其中,所述第t损失函数的计算基于:
(1)针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本针对第t训练阶段相较于先前训练阶段新增的分类的预测分类标签与真实分类标签之间的差异的分类项;(2)针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项;以及(3)针对第t训练阶段,计算正则项。
此外,根据本公开一方面的训练方法,其中,所述重构误差项是由重构误差项计算器基于输入图像样本与重构图像样本之间的差异来计算的,其中,所述重构图像样本是由解码器基于所述输入图像样本的特征来生成的,所述输入图像样本的特征是由特征提取器基于所接收的所述输入图像样本来提取的。
此外,根据本公开一方面的训练方法,其中,第t训练阶段的正则项由正则项计算器计算,所述计算包括:针对所述神经网络模型的每一个模型参数,计算在当前第t训练阶段中所述模型参数的当前值与在先前第t-1训练阶段中所述模型参数的先前值之间的差异;针对所述神经网络模型的每一个模型参数,计算在先前第t-1训练阶段中所述模型参数的权重值;以及基于所述差异和所述权重值,计算所述模型参数在当前第t训练阶段中与在先前第t-1训练阶段中的所述差异的加权和。
此外,根据本公开一方面的训练方法,其中,在先前第t-1训练阶段中模型参数的权重值的计算包括:(1)针对在先前第t-1训练阶段中的第t-1组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;(2)针对在先前第t-1训练阶段中的第t-1组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项。
此外,根据本公开一方面的训练方法,其中,所述模型参数包括特征提取器参数、线性分类器参数、以及解码器参数。
此外,根据本公开一方面的训练方法,其中,第t训练阶段的正则项由正则项计算器计算,所述计算包括:针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本针对截至t-1训练阶段的所有训练阶段的分类,在第t训练阶段的预测分类标签与在第t-1训练阶段的预测分类标签之间的差异,针对所述所存储的先前所有训练阶段的输入图像样本中的一部分,计算所述差异的和。
根据本公开的另一个方面,提供了一种用于图像识别的神经网络模型的训练装置,其中,所述装置执行如上所述的训练方法。
根据本公开的另一个方面,提供了一种电子设备,包括:处理器;存储器,用于存储计算机程序指令;其中,当所述计算机程序指令由所述处理器加载并运行时,所述处理器执行如上所述的训练方法。
如以下将详细描述的,根据本公开的用于图像识别的神经网络模型的训练方法和装置、以及电子设备和图像识别系统,更全面地学习输入数据的特征,使得在新任务的输入数据分布与先前任务的输入数据分布有较大不同时,仍然能够在不获取先前任务输入数据、或者仅少量获取先前任务输入数据的情况下,提取出新任务的学习所需要的特征,从而使学习模型能够更好地完成新任务的学习,提高分类精度。
要理解的是,前面的一般描述和下面的详细描述两者都是示例性的,并且意图在于提供要求保护的技术的进一步说明。
附图说明
通过结合附图对本公开实施例进行更详细的描述,本公开的上述以及其它目的、特征和优势将变得更加明显。附图用来提供对本公开实施例的进一步理解,并且构成说明书的一部分,与本公开实施例一起用于解释本公开,并不构成对本公开的限制。在附图中,相同的参考标号通常代表相同部件或阶段。
图1示出了根据本公开实施例的使用EWC的学习模型在第一阶段训练的训练装置的示意图;
图2示出了根据本公开实施例的使用EWC的学习模型在第一阶段训练的训练方法的流程图;
图3示出了根据本公开实施例的使用EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;
图4示出了根据本公开实施例的使用EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图;
图5示出了根据本公开实施例的使用改进的EWC的学习模型在第一阶段阶段训练的训练装置的示意图;
图6示出了根据本公开实施例的使用改进的EWC的学习模型在第一阶段阶段训练的训练方法的流程图;
图7示出了根据本公开实施例的使用改进的EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;
图8示出了根据本公开实施例的使用改进的EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图;
图9示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;
图10示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图;
图11示出了根据本公开实施例的使用改进的iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;
图12示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图;
图13是图示根据本公开实施例的电子设备的硬件框图;以及
图14是图示根据本公开的实施例的计算机可读存储介质的示意图。
具体实施例
为了使得本公开的目的、技术方案和优点更为明显,下面将参照附图详细描述根据本公开的示例实施例。显然,所描述的实施例仅仅是本公开的一部分实施例,而不是本公开的全部实施例,应理解,本公开不受这里描述的示例实施例的限制。
本申请实施例提供的方案涉及人工智能领域中基于神经网络的分类技术,具体通过如下实施例进行说明。需要说明的是,如下实施例虽然是在图像识别场景中的分类任务的场景下进行描述,本发明的应用场景并不限于此,也可以应用于例如语音识别等任何适当的场景。
I.(a)权重正则方法
可塑权重巩固法(Elastic Weight Consolidation,EWC)属于权重正则方法的一种。第一实施例是关于使用EWC的学习模型。
图1示出了根据本公开实施例的使用EWC的学习模型在第一阶段训练的训练装置的示意图;图2示出了根据本公开实施例的使用EWC的学习模型在第一阶段训练的训练方法的流程图。
以下,结合图1和图2,我们将对使用EWC的学习模型第一阶段训练进行说明。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第一阶段训练中,第一组输入图像样本被输入到所述使用EWC的学习模型。对于第一阶段训练为例,t=1,x(1)表示第一组输入图像样本。
在步骤S101中,输入图像样本x(1)被输入到特征提取器中;
在步骤S102中,特征提取器提取输入图像样本x(1)的特征z(1),并将所述特征z(1)输出到线性分类器;
其中,x(1,i)表示第一组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第一组输入图像样本的样本总数;表示第一组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第一组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签;其中,j=1,2,...,K,K表示分类总数,真实分类标签是独热编码(One-Hot Code)的,预测分类标签可以取大于等于0且小于等于1的任意值,是针对x(1,i)的分类项,所述分类项对应于x(1,i)的预测分类标签和x(1,i)的真实分类标签y(1,i)之间的差异。
图3示出了根据本公开实施例的使用EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;图4示出了根据本公开实施例的使用EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图。以下,结合图3和图4,我们将对使用EWC的学习模型第t阶段训练进行说明。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第t阶段训练中,第t组输入图像样本被输入到所述使用EWC的学习模型。对于第t阶段训练为例,t>1,x(t)表示第t组输入图像样本。
在步骤S301中,输入图像样本x(t)被输入到特征提取器中;
在步骤S302中,特征提取器提取输入图像样本x(t)的特征z(t),并将所述特征z(t)输出到线性分类器;
其中,x(t,i)表示第t输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t输入图像样本中的输入图像样本的总数;表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签,其中,j表示分类数量,假设总共有K种分类,则j=1,2,...,K,真实分类标签是独热编码(One-Hot Code)的,预测分类标签可以取大于等于0且小于等于1的任意值;是针对x(t,i)的分类项,所述分类项对应于x(t,i)的预测分类标签和x(t,i)的真实分类标签y(t,i)之间的差异;
其中,R(t)是所述使用EWC的学习模型在第t阶段训练中的正则项,所述正则项由正则项计算器计算。所述正则项R(t)如表达式3a所示:
其中,λ是可调整的超参数(Hyper Parameter),可以根据实际情况和经验进行手动调节;参数集θ(t)是所述使用EWC的学习模型在第t阶段训练后的所有参数的集合,在本实施中包括特征提取器中的所有参数以及线性分类器中的所有参数,所述参数集合包含的参数的总数为|θ(t)|,其中,第p个参数在当前第t阶段训练中的当前值记为第p个参数在先前第t-1阶段训练中的先前值记为 表示用于衡量第p个参数在第t-1阶段训练中的重要性的权重值。
其中,表示第p个模型参数在先前第t-1训练阶段中的权重值,表示第p个模型参数在第t-1训练阶段中的所述先前值;表示第t-1组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第t-1组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签;其中,j=1,2,...,K,K表示分类总数。
所述使用EWC的学习模型在第t阶段训练中的目标是:(1)使线性分类器输出的分类标签尽量靠近真实分类标签以提高学习模型的分类表现;(2)通过在损失函数的计算中引入正则项R(t),解决在该学习模型在第t阶段训练中的“灾难性遗忘”问题。
所述正则项R(t)贡献于“灾难性遗忘”问题的解决。具体而言,在针对新任务进行训练时,R(t)允许相对不重要的参数有较大的变化来学习新任务,并将保持相对重要的参数的变化程度尽量较小。也就是说,对于在第t-1阶段训练中较为重要的参数的值较大,所述使用EWC的学习模型在第t阶段训练中不允许过分地远离而对于在第t-1阶段训练中较为不重要参数的值较小,所述使用EWC的学习模型在第t阶段训练中允许可以极大地不同于
改进现有类别增量学习的方向之一就是如何使其更全面地学习输入数据的特征,使得在新任务的输入数据分布与先前任务的输入数据分布有较大不同时,仍然能够在不获取先前任务输入数据(对于权重正则方法而言)的情况下,提取出新任务的学习所需要的特征,从而使学习模型能够更好地完成新任务的学习。
I.(b)改进的权重正则方法
第二实施例是关于使用改进的EWC的学习模型。
图5示出了根据本公开实施例的使用改进的EWC的学习模型在第一阶段阶段训练的训练装置的示意图;图6示出了根据本公开实施例的使用改进的EWC的学习模型在第一阶段阶段训练的训练方法的流程图。
以下,结合图5和图6,我们将对使用改进的EWC的学习模型第一阶段训练进行说明。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第一阶段训练中,第一组输入图像样本被输入到所述改进的使用EWC的学习模型。对于第一阶段训练为例,t=1,x(1)表示第一组输入图像样本。
在步骤S501中,输入图像样本x(1)被输入到特征提取器中;
在步骤S502中,特征提取器提取输入图像样本x(1)的特征z(1),并将所述特征z(1)输出到线性分类器;
其中,x(1,i)表示第一组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第一组输入图像样本的样本总数;表示第一组输入图像样本中的第i个输入图像样本的重构图像样本;是针对x(1,i)的重构误差项,所述重构误差项对应于所述输入图像样本x(1,i)和其重构图像样本之间的差异;γ是可调整的超参数(Hyper Parameter),可以根据实际情况和经验进行手动调节;表示第一组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第一组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签;其中,j=1,2,...,K,K表示分类总数,预测分类标签可以取大于等于0且小于等于1的任意值;真实分类标签是独热编码(One-Hot Code)的; 是针对x(1,i)的分类项,所述分类项对应于x(1,i)的预测分类标签和x(1,i)的真实分类标签y(1,i)之间的差异。
所述使用改进的EWC的学习模型在第一阶段训练中的目标是:(1)使线性分类器输出的分类标签尽量靠近真实分类标签y(1)以提高学习模型的分类表现;(2)使得重构图像样本尽量靠近输入图像样本x(1),以使特征提取器对输入图像样本x(1)的特征进行更加全面的提取,有利于提升连续学习,特别是跨类别连续学习的分类表现。
相较于I.(a)中描述的使用EWC的学习模型,I.(b)中描述的使用改进的EWC的学习模型的区别在于:在损失函数中考虑了重构误差项,从而可以使得重构图像样本尽量靠近输入图像样本x(1),以使特征提取器对输入图像样本x(1)的特征进行更加全面的提取,有利于提升连续学习,特别是跨类别连续学习的分类表现,提升分类精度。
图7示出了根据本公开实施例的使用改进的EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;图8示出了根据本公开实施例的使用改进的EWC的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图。
以下,结合图7和图8,我们将对使用改进的EWC的学习模型第t阶段训练进行说明。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第t阶段训练中,第t组输入图像样本被输入到所述使用改进的EWC的学习模型。对于第t阶段训练为例,t>1,x(t)表示第t组输入图像样本。
在步骤S701中,输入图像样本x(t)被输入到特征提取器中;
在步骤S702中,特征提取器提取输入图像样本x(t)的特征z(t),并将所述特征z(t)输出到线性分类器;
其中,x(t,i)表示第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本的样本总数;表示第t组输入图像样本中的第i个输入图像样本的重构图像样本;是针对x(t,i)的重构误差项,所述重构误差项对应于所述输入图像样本x(t,i)和其重构图像样本之间的差异;γ是可调整的超参数,可以根据实际情况和经验进行手动调节;表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的真实节类标签,表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签;其中,j=1,2,...,K,K表示分类总数;预测分类标签可以取大于等于0且小于等于1的任意值;真实分类标签是独热编码(One-HotCode)的;针对x(1,i)的分类项,所述分类项对应于x(t,i)的预测分类标签和x(t,i)的真实分类标签y(t,i)之间的差异;R(t)表示第t训练阶段的正则项。
其中,R(t)是所述使用改进的EWC的学习模型在第t阶段训练中的正则项,所述正则项由正则项计算器计算。所述正则项R(t)如表达式3b所示:
其中,λ是可调整的超参数(Hyper Parameter),可以根据实际情况和经验进行手动调节;参数集θ(t)是所述使用EWC的学习模型在经过第t阶段训练后的所有参数的集合,在本实施中包括特征提取器中的所有参数、线性分类器中的所有参数以及解码器中的所有参数,所述参数集合包含的参数的总数为|θ(t)|,其中,表示第p个模型参数在先前第t-1训练阶段中的权重值,表示第p个模型参数在当前第t训练阶段中的所述当前值,表示第p个模型参数在第t-1训练阶段中的所述先前值。
其中,表示第p个模型参数在先前第t-1训练阶段中的权重值,表示第p个模型参数在第t-1训练阶段中的所述先前值;x(t-1,i)表示第t-1组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本的样本总数;表示第t-1组输入图像样本中的第i个输入图像样本的重构图像样本;γ是可调节的超参;表示第t-1组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第t-1组输入图像样本中的第i个输入图像样本的关于第j个分类的预测分类标签;其中,j=1,2,...,K,K表示分类总数。
所述使用EWC的学习模型在第t阶段训练中的目标是:(1)使线性分类器输出的分类标签尽量靠近真实分类标签以提高学习模型的分类表现;(2)通过在损失函数的计算中引入正则项R(t),解决在该学习模型在第t阶段训练中的“灾难性遗忘”问题;(3)使得重构图像样本尽量靠近输入图像样本x(t),以使特征提取器对输入图像样本x(t)的特征进行更加全面的提取,有利于提升连续学习,特别是跨类别连续学习的分类变现。
所述正则项R(t)贡献于“灾难性遗忘”问题的解决。具体而言,在针对新任务进行训练时,R(t)允许相对不重要的参数有较大的变化来学习新任务,并将保持相对重要的参数的变化程度尽量较小。也就是说,对于在第t-1阶段训练中较为重要的参数的值较大,所述使用EWC的学习模型在第t阶段训练中不允许过分地远离而对于在第t-1阶段训练中较为不重要参数的值较小,所述使用EWC的学习模型在第t阶段训练中允许可以极大地不同于
相较于I.(a)中描述的使用EWC的学习模型,I.(b)中描述的使用改进的EWC的学习模型的区别在于:在损失函数中考虑了重构误差项,从而可以使得重构图像样本尽量靠近输入图像样本x(t),以使特征提取器对输入图像样本x(t)的特征进行更加全面的提取,有利于提升连续学习,特别是跨类别连续学习的分类变现,提升分类精度。
II.(a)表达正则方法
增量分类器和表达学习法(incremental Classfier and RepresentationLearning,iCaRL)属于表达正则方法的一种。在本实施例中,将以iCaRL为例对使用表达正则方法的学习模型进行说明。
使用iCaRL的学习模型在第一阶段训练的训练装置的结构、以及训练方法的流程与如图1和图2所示出的使用EWC的学习模型在第一阶段训练的训练装置的结构、以及训练方法的流程是相同的。在此不再赘述。
图9示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;图10示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图。
以下,结合图9和图10,我们将对使用iCaRL的学习模型第t阶段训练进行说明。
在执行所述第一训练阶段之后,循环地执行第t训练阶段。在每次循环地执行所述第t训练阶段之前,存储当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分,并将所存储的先前所有训练阶段的输入图像样本中的一部分与所述第t组输入图像样本一起输入到所述神经网络模型。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第t阶段训练中,第t组输入图像样本、以及所述所存储的先前所有训练阶段的输入图像样本中的一部分被输入到所述使用iCaRL的学习模型。
在步骤S903中,基于特征z(t)和针对在第t训练阶段相较于先前训练阶段的新增的分类,线性分类器分别生成输入图像样本x(t)的预测分类标签以及输入图像样本的预测分类标签并将所述预测分类标签和输出到损失函数计算器;
在步骤S904中,基于x(t)的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和真实分类标签y(t)、的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和真实分类标签以及正则项R(t),损失函数计算器计算第t阶段训练的损失函数所述损失函数如表达式5a所示:
其中,x(t,i)表示第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本的样本总数;表示第t组输入图像样本中的第i个输入图像样本在第t训练阶段的重构图像样本;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本,其中,i′=1,2,...,No,No是所存储的先前所有训练阶段的输入图像样本的一部分中所包含的输入图像样本总数,表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t训练阶段的重构样本,γ是可调节的超参;表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第t组输入图像样本中的第i个输入图像样本在第t训练阶段的关于第j个分类的预测分类标签;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本的关于第j个分类的真实分类标签;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t训练阶段的关于第j个分类的预测分类标签,其中,j=Kt-1+1,Kt-1+2,...,Kt,Kt-1表示截止到第t-1阶段的分类总数,Kt表示截止到第t阶段的分类总数;真实分类标签y(t)和是独热编码(One-Hot Code)的,预测分类标签和可以取大于等于0且小于等于1的任意值;
是第t训练阶段的分类项,所述分类项对应于x(t)的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和x(t)的真实分类标签y(t)之间的差异、以及的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和x(t)的真实分类标签之间的差异。
其中,R(t)是所述使用iCaRL的学习模型在第t阶段训练中的正则项,所述正则项由正则项计算器计算。所述正则项R(t)如表达式6a所示:
其中,表示第t组输入图像样本中的第i个输入图像样本在第t-1训练阶段中关于第j′个分类的预测分类标签,其中i=1,2,...,N,N是第t组输入图像样本的样本总数,j′=1,2,...,Kt-1,Kt-1表示截止到第t-1训练阶段的分类总数;表示第t组输入图像样本中的第i个输入图像样本在第t训练阶段中关于第j′个分类的预测分类标签;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t-1训练阶段中关于第j′个分类的预测分类标签,表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t训练阶段中关于第j′个分类的预测分类标签,其中i′=1,2,...,No,No是所存储的先前所有训练阶段的输入图像样本的一部分中所包含的输入图像样本总数;λ是可调节的超参;所述正则项是针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算的对应于所述输入图像样本针对截至第t-1训练阶段的所有训练阶段的分类,第t训练阶段的预测分类标签和与第t-1训练阶段的预测分类标签和之间的差异。
所述使用iCaRL的学习模型在第t阶段训练中的目标是:(1)使针对在第t训练阶段相较于先前训练阶段的新增的分类输出的预测分类标签和分别尽量靠近真实分类标签y(t)和以提高学习模型在第t训练阶段的分类表现:(2)通过在损失函数的计算中引入正则项R(t),使针对第t-1训练阶段以及之前所有训练阶段的分类的预测分类标签和分别尽量靠近预测分类标签和以解决在该学习模型在第t阶段训练中的“灾难性遗忘”问题。
改进现有类别增量学习的方向之一就是如何使其更全面地学习输入数据的特征,使得在新任务的输入数据分布与先前任务的输入数据分布有较大不同时,仍然能够在仅少量获取先前任务输入数据(对于表达正则方法而言)的情况下,提取出新任务的学习所需要的特征,从而使学习模型能够更好地完成新任务的学习,提高分类精度。
II.(b)改进的表达正则方法
使用改进的iCaRL的学习模型在第一阶段训练的训练装置的结构、以及训练方法的流程与如图5和图6所示出的使用改进的EWC的学习模型在第一阶段训练的训练装置的结构、以及训练方法的流程是相同的。在此不再赘述。
图11示出了根据本公开实施例的使用改进的iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练装置的示意图;图12示出了根据本公开实施例的使用iCaRL的学习模型在第t阶段阶段(其中,t>1)训练的训练方法的流程图。
以下,结合图11和图12,我们将对使用改进的iCaRL的学习模型第t阶段训练进行说明。
在执行所述第一训练阶段之后,循环地执行第t训练阶段。在每次循环地执行所述第t训练阶段之前,存储当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分,并将所存储的先前所有训练阶段的输入图像样本中的一部分与所述第t组输入图像样本一起输入到所述神经网络模型。
假设以x(t)表示第t阶段训练的第t组输入图像样本,以x(t,i)表示第t阶段训练的第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本中的输入图像样本总数。
在第t阶段训练中,第t组输入图像样本、以及所述所存储的先前所有训练阶段的输入图像样本中的一部分被输入到所述使用改进的iCaRL的学习模型。
在步骤S1103中,基于特征z(t)和针对在第t训练阶段相较于先前训练阶段的新增的类别,线性分类器分别生成输入图像样本x(t)的预测分类标签以及输入图像样本均预测分类标签并将所述预测分类标签和输出到损失函数计算器;
在步骤S1105中,基于所述输入图像样本x(t)和所述重构图像样本重构项计算器生成重构误差项E(t);基于所述输入图像样本和所述重构图像样本重构项计算器生成重构误差项所述重构误差项E(t)对应于所述输入图像样本x(t)和所述重构图像样本之间的差异,且所述重构误差项对应于所述输入图像样本和所述重构图像样本之间的差异;
在步骤S1106中,基于针对在第t训练阶段相较于先前训练阶段的新增的分类预测分类标签和真实分类标签y(t)和所述重构误差项E(t)和以及正则项R(t),损失函数计算器计算第t训练阶段的损失函数所述输入图像样本x(t)的损失函数如表达式5b所示:
其中,x(t,i)表示第t组输入图像样本中的第i个输入图像样本,其中,i=1,2,...,N,N是第t组输入图像样本的样本总数;表示第t组输入图像样本中的第i个输入图像样本在第t训练阶段的重构图像样本;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本,其中,i′=1,2,...,No,No是所存储的先前所有训练阶段的输入图像样本的一部分中所包含的输入图像样本总数,表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t训练阶段的重构样本,γ是可调节的超参;表示第t组输入图像样本中的第i个输入图像样本的关于第j个分类的真实分类标签,表示第t组输入图像样本中的第i个输入图像样本在第t训练阶段的关于第j个分类的预测分类标签;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本的关于第j个分类的真实分类标签;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i′个输入图像样本在第t训练阶段的关于第j个分类的预测分类标签,其中,j=Kt-1+1,Kt-1+2,...,Kt,Kt-1表示截止到第t-1阶段的分类总数,Kt表示截止到第t阶段的分类总数;真实分类标签y(t)和是独热编码(One-Hot Code)的.预测分类标签和可以取大于等于0且小于等于1的任意值;
是第t训练阶段的分类项,所述分类项对应于x(t)的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和x(t)的真实分类标签y(t)之间的差异、以及的针对在第t训练阶段相较于先前训练阶段的新增的分类的预测分类标签和x(t)的真实分类标签之间的差异;是第t训练阶段的重构误差项,所述重构误差项对应于所述输入图像样本x(t)和所述重构图像样本的差异、以及所述输入图像样本和所述重构图像样本的差异;
其中,R(t)是所述使用改进的iCaRL的学习模型在第t阶段训练中的正则项,所述正则项由正则项计算机计算。所述正则项R(t)如表达式6b所示:
其中,表示第t组输入图像样本中的第i个输入图像样本在第t-1阶段的模型下关于第j′个分类的预测分类标签,其中i=1,2,...,N,N是第t组输入图像样本的样本总数,j′=1,2,...,Kt-1,Kt-1表示截止到第t-1阶段的分类总数;表示所存储的当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分中的第i个输入图像样本在第t-1阶段的模型下关于第j′个分类的预测分类标签,其中i′=1,2,...,No,No是所存储的先前所有训练阶段的输入图像样本的一部分中所包含的输入图像样本总数;λ是可调节的超参;所述正则项是针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算的对应于所述输入图像样本针对截至第t-1训练阶段的所有训练阶段的分类,第t训练阶段的预测分类标签和与第t-1训练阶段的预测分类标签和之间的差异。
所述使用改进的iCaRL的学习模型在第t阶段训练中的目标是:(1)使针对在第t训练阶段相较于先前训练阶段的新增的分类输出的预测分类标签和分别尽量靠近真实分类标签y(t)和以提高学习模型在第t训练阶段的分类表现;(2)通过在损失函数的计算中引入正则项R(t),使针对第t-1训练阶段以及之前所有训练阶段的分类的预测分类标签和分别尽量靠近预测分类标签和以解决在该学习模型在第t阶段训练中的“灾难性遗忘”问题;(3)使得重构图像样本尽量靠近输入图像样本x(t),以使特征提取器对输入图像样本x(t)的特征进行更加全面的提取,有利于提升连续学习,特别是跨类别连续学习的分类变现。
需要指出的是,本发明不仅仅可以只在EWC和iCaRL两种模型上进行改进,对于任何类别递增模型,只要有无法学到跨组别的类别之间的分类界面的问题,都可以用本发明提出的方法,即,在现有类别递增学习模型的基础上加解码器和重构误差计算器,并在误差函数中加上项。
III.改进的权重正则方法与表达正则方法对分类精度的提升
为了示出所述改进方案在EWC场景以及iCaRL场景中的分类精度优势,以下将使用公开数据集,将使用所述改进方案与不使用所述改进方案的模型进行比较。
表1示出了在公开数据集MNIST上,在EWC场景下,使用了本发明提出的解码器的模型和不使用解码器的模型的平均分类精度(各阶段结束后的分类精度的平均值):
表1
可以看到,使用解码器之后显著提高了平均分类精度。
具体而言,MNIST数据集包含由多人完成的手写数字图片,每一张图片都是黑白的,对应从0到9中的一个数字。我们把这些数字分为5个不同批次,即,0和1为第一批次,2和3为第二批次,以此类推。
首先,所述网络模型获取获取第一批次的所有数据,即,所有对应0和1的图片,将所述第一批次作为输入图像样本进行训练。在训练完成之后,对于任何一张测试图片,所述模型都可以判断其为0、还是1。
具体而言,在第一批次的训练中,所述网络模型包括第一训练步骤,用于计算对应于第一批次输入图像样本的第一损失函数,并利用所述第一损失函数训练所述网络模型,其中,所述第一训练步骤包括:
-由特征提取器接收第一批次输入图像样本、提取所述第一批次输入图像样本的特征、并将所述特征分别输出到线性分类器和解码器;
-由所述线性分类器基于特征生成所述第一批次输入图像样本的预测分类标签,并将所述预测分类标签输出到损失函数计算器;
-由所述解码器基于所述特征生成所述第一批次输入图像样本的重构图像样本,将所述重构图像样本输出到重构误差计算器;
-由所述重构误差计算器生成重构误差项,并将所述重构误差项输出到所述损失函数计算器,其中,所述重构误差项对应于所述第一批次输入图像样本和其重构图像样本的差异,
-由所述损失函数计算器计算所述第一损失函数,其中所述第一损失函数是基于所述第一批次输入图像样本的预测分类标签与真实分类标签之间的差异、以及所述第一批次输入图像样本与其重构图像样本之间的差异来计算的。
然后,网络模型获取获取第二批次的所有数据,即,所有对应2和3的图片,将所述第二批次的图片样本作为输入图像样本进行训练,目标是使所述模型能够对0、1、2和3这四类图片进行分类。
具体而言,在第二批次的训练中,所述网络模型包括第二训练步骤,用于计算对应于第二批次输入图像样本的第二损失函数,并利用所述第二损失函数训练所述网络模型,其中,所述第二训练步骤包括:
-由特征提取器接收第二批次输入图像样本、提取所述第二批次输入图像样本的特征、并将所述特征分别输出到线性分类器和解码器;
-由所述线性分类器基于特征生成所述第二批次输入图像样本的预测分类标签,并将所述预测分类标签输出到损失函数计算器;
-由所述解码器基于所述特征生成所述第二批次输入图像样本的重构图像样本,将所述重构图像样本输出到重构误差计算器;
-由所述重构误差计算器生成重构误差项,并将所述重构误差项输出到所述损失函数计算器,其中,所述重构误差项对应于所述第二批次输入图像样本和其重构图像样本的差异,
-由正则项计算机计算正则项,并将所述正则项输出到所述损失函数计算器;
-由所述损失函数计算器计算所述第二损失函数,其中所述第二损失函数是基于所述第二批次输入图像样本的预测分类标签与真实分类标签之间的差异、所述第二批次输入图像样本与其重构图像样本之间的差异、以及所述正则项来计算的。
重复这样的过程,最终能够得到一个训练完成的模型,该模型能够对对应于从0到9的图片进行分类。
表2示出了在公开数据集Fashion-MNIST上,在iCaRL情形使用不同存储空间的情况下,使用了本发明提出的解码器后的模型和不使用解码器的模型的平均分类精度:
表2
可以看到,在各种存储空间条件下,使用解码器之后均显著提高了平均分类精度。
具体而言,Fashion-MNIST的训练集包含60000张黑白手绘图片,测试集包括10000张黑白手绘图片,其包括上衣、长裤等10种不同的类别,每种类别的图片数量是相同的。同样地,我们将这10类分为5个批次(Batch),即前两类为第一批次,第三第四类为第二批次,以此类推。将所比较的几个模型依次在5个批次上进行分类任务的训练。在该测试中,我们使用存储空间大小,即,可以存储共计10、20或者40张图片的存储空间。
首先,所述网络模型获取获取第一批次的所有数据,即,所有对应前两类的图片(T恤和长裤),将所述第一批次作为输入图像样本进行训练。在训练完成之后,对于任何一张测试图片,所述模型都可以判断其为T恤还是长裤。在第一批次的训练完成后,随机选择第一批次中的一部分T恤和长裤的图片样本存储在所述存储空间中。
具体而言,在第一批次的训练中,所述网络模型包括第一训练步骤,用于计算对应于第一批次输入图像样本的第一损失函数,并利用所述第一损失函数训练所述网络模型,其中,所述第一训练步骤包括:
-由特征提取器接收第一批次输入图像样本、提取所述第一批次输入图像样本的特征、并将所述特征分别输出到线性分类器和解码器;
-由所述线性分类器基于特征生成所述第一批次输入图像样本的预测分类标签,并将所述预测分类标签输出到损失函数计算器;
-由所述解码器基于所述特征生成所述第一批次输入图像样本的重构图像样本,将所述重构图像样本输出到重构误差计算器;
-由所述重构误差计算器生成重构误差项,并将所述重构误差项输出到所述损失函数计算器,其中,所述重构误差项对应于所述第一批次输入图像样本和其重构图像样本的差异,
-由所述损失函数计算器计算所述第一损失函数,其中所述第一损失函数是基于所述第一批次输入图像样本的预测分类标签与真实分类标签之间的差异、以及所述第一批次输入图像样本与其重构图像样本之间的差异来计算的。
然后,在存储空间中存储第一批次中的一部分T恤和长裤的图片样本。
然后,所述网络模型获取获取第二批次的所有数据,即,所有对应套衫和连衣裙的图片,将所述第二批次输入图像样本、以及所述存储空间中的之前存储的第一批次中的一部分T恤和长裤的图片样本一起作为第二训练阶段输入图像样本进行训练,目标是使所述模型能够对T恤、长裤、套衫和连衣裙这四类图片进行分类。
具体而言,在第二批次的训练中,所述网络模型包括第二训练步骤,用于计算对应于第二批次输入图像样本的第二损失函数,并利用所述第二损失函数训练所述网络模型,其中,所述第二训练步骤包括:
-由特征提取器接收第二训练阶段输入图像样本、提取所述第二训练阶段输入图像样本的特征、并将所述特征分别输出到线性分类器和解码器;
-由所述线性分类器基于特征生成第二训练阶段输入图像样本的预测分类标签,并将所述预测分类标签输出到损失函数计算器;
-由所述解码器基于所述特征生成第二训练阶段输入图像样本的重构图像样本,将所述重构图像样本输出到重构误差计算器;
-由所述重构误差计算器生成重构误差项,并将所述重构误差项输出到所述损失函数计算器,其中,所述重构误差项对应于所述第二训练阶段输入图像样本和其重构图像样本的差异,
-由正则项计算机计算正则项,并将所述正则项输出到所述损失函数计算器;
-由所述损失函数计算器计算所述第二损失函数,其中所述第二损失函数是基于所述第二训练阶段输入图像样本的针对第二批次的分类(即,套衫和连衣裙)的预测分类标签与真实分类标签之间的差异、所述第二训练阶段输入图像样本与其重构图像样本之间的差异、以及所述正则项来计算的。所述正则项对应于基于所述第二训练阶段输入图像样本的针对第一批次的分类(即,T恤、长裤)在第一训练阶段的预测分类标签与第二训练阶段的预测分类标签的差异。
然后,随机清空所述存储空间中所存储的一些T恤和长裤的图片样本,旨在腾出所述存储空间的一部分来存储第二批次中的一部分套衫和连衣裙的图片样本以用于后续训练使用。
重复这样的过程,最终能够得到一个训练完成的模型,该模型能够对对应于从所有十类图片进行分类。
可见,改进的权重正则方法与表达正则方法对分类精度有提升效果。
本发明是针对现有类别增量学习的改进,即,如何使其更全面地学习输入数据的特征,使得在新任务的输入数据分布与先前任务的输入数据分布有较大不同时,仍然能够在不获取先前任务输入数据(对于权重正则方法而言)、或者仅少量获取先前任务输入数据(对于表达正则方法而言)的情况下,提取出新任务的学习所需要的特征,从而使学习模型能够更好地完成新任务的学习。
图13是图示根据本公开实施例的电子设备的硬件框图。根据本公开实施例的电子设备至少包括处理器;以及存储器,用于存储计算机程序指令。当计算机程序指令由处理器加载并运行时,所述处理器执行如上所述的神经网络模型的训练方法和图像处理方法。
图13所示的电子设备1300具体地包括:中央处理单元(CPU)1301、图形处理单元(GPU)1302和主存储器1303。这些单元通过总线1304互相连接。中央处理单元(CPU)1301和/或图形处理单元(GPU)1302可以用作上述处理器,主存储器1303可以用作上述存储计算机程序指令的存储器。此外,电子设备1300还可以包括通信单元1305、存储单元1306、输出单元1307、输入单元1308和外部设备1306,这些单元也连接到总线1304。
图14是图示根据本公开的实施例的计算机可读存储介质的示意图。如图14所示,根据本公开实施例的计算机可读存储介质1400其上存储有计算机程序指令1401。当所述计算机程序指令1401由处理器运行时,执行参照以上附图描述的根据本公开实施例的神经网络模型的训练方法和图像识别方法。所述计算机可读存储介质包括但不限于例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存、光盘、磁盘等。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
以上结合具体实施例描述了本公开的基本原理,但是,需要指出的是,在本公开中提及的优点、优势、效果等仅是示例而非限制,不能认为这些优点、优势、效果等是本公开的各个实施例必须具备的。另外,上述公开的具体细节仅是为了示例的作用和便于理解的作用,而非限制,上述细节并不限制本公开为必须采用上述具体的细节来实现。
本公开中涉及的器件、装置、设备、系统的方框图仅作为例示性的例子并且不意图要求或暗示必须按照方框图示出的方式进行连接、布置、配置。如本领域技术人员将认识到的,可以按任意方式连接、布置、配置这些器件、装置、设备、系统。诸如“包括”、“包含”、“具有”等等的词语是开放性词汇,指“包括但不限于”,且可与其互换使用。这里所使用的词汇“或”和“和”指词汇“和/或”,且可与其互换使用,除非上下文明确指示不是如此。这里所使用的词汇“诸如”指词组“诸如但不限于”,且可与其互换使用。
还需要指出的是,在本公开的系统和方法中,各部件或各步骤是可以分解和/或重新组合的。这些分解和/或重新组合应视为本公开的等效方案。
可以不脱离由所附权利要求定义的教导的技术而进行对在此所述的技术的各种改变、替换和更改。此外,本公开的权利要求的范围不限于以上所述的处理、机器、制造、事件的组成、手段、方法和动作的具体方面。可以利用与在此所述的相应方面进行基本相同的功能或者实现基本相同的结果的当前存在的或者稍后要开发的处理、机器、制造、事件的组成、手段、方法或动作。因而,所附权利要求包括在其范围内的这样的处理、机器、制造、事件的组成、手段、方法或动作。
提供所公开的方面的以上描述以使本领域的任何技术人员能够做出或者使用本公开。对这些方面的各种修改对于本领域技术人员而言是非常显而易见的,并且在此定义的一般原理可以应用于其他方面而不脱离本公开的范围。因此,本公开不意图被限制到在此示出的方面,而是按照与在此公开的原理和新颖的特征一致的最宽范围。
为了例示和描述的目的已经给出了以上描述。此外,此描述不意图将本公开的实施例限制到在此公开的形式。尽管以上已经讨论了多个示例方面和实施例,但是本领域技术人员将认识到其某些变型、修改、改变、添加和子组合。
Claims (10)
1.一种用于图像识别的神经网络模型的训练方法,包括:
第一训练阶段,计算对应于第一组输入图像样本的第一损失函数,并利用所述第一损失函数训练所述神经网络模型,
其中,所述第一损失函数的计算包括:(1)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;以及(2)针对所述第一组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构误差项。
2.如权利要求1所述的训练方法,进一步包括:
在执行所述第一训练阶段之后,循环地执行第t训练阶段,其中,t>1,所述第t训练阶段计算对应于第t组输入图像样本的第t损失函数,并利用所述第t损失函数训练所述神经网络模型,
其中,所述第t损失函数的计算包括:(1)针对所述第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;(2)针对所述第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项;以及(3)针对第t训练阶段,计算正则项。
3.如权利要求1所述的训练方法,进一步包括:
在执行所述第一训练阶段之后,循环地执行第t训练阶段,其中,t>1,所述第t训练阶段计算对应于第t组输入图像样本的第t损失函数,并利用所述第t损失函数训练所述神经网络模型,
其中,在每次循环地执行所述第t训练阶段之前,存储当前第t训练阶段的先前所有训练阶段的输入图像样本中的一部分,并将所存储的先前所有训练阶段的输入图像样本中的一部分与所述第t组输入图像样本一起输入到所述神经网络模型,
其中,所述第t损失函数的计算基于:(1)针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本针对第t训练阶段相较于先前训练阶段新增的分类的预测分类标签与真实分类标签之间的差异的分类项;(2)针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项;以及(3)针对第t训练阶段,计算正则项。
4.如权利要求1-3中任一项所述的训练方法,其中,
所述重构误差项是由重构误差项计算器基于输入图像样本与重构图像样本之间的差异来计算的,其中,所述重构图像样本是由解码器基于所述输入图像样本的特征来生成的,所述输入图像样本的特征是由特征提取器基于所接收的所述输入图像样本来提取的。
5.如权利要求2所述的训练方法,其中,第t训练阶段的正则项由正则项计算器计算,所述计算包括:
针对所述神经网络模型的每一个模型参数,计算在当前第t训练阶段中所述模型参数的当前值与在先前第t-1训练阶段中所述模型参数的先前值之间的差异;
针对所述神经网络模型的每一个模型参数,计算在先前第t-1训练阶段中所述模型参数的权重值;以及
基于所述差异和所述权重值,计算所述模型参数在当前第t训练阶段中与在先前第t-1训练阶段中的所述差异的加权和。
6.如权利要求5所述的训练方法,其中,在先前第t-1训练阶段中模型参数的权重值的计算包括:(1)针对在先前第t-1训练阶段中的第t-1组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本的预测分类标签与真实分类标签之间的差异的分类项;(2)针对在先前第t-1训练阶段中的第t-1组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本与其重构图像样本之间的差异的重构项。
7.如权利要求5所述的训练方法,其中,所述模型参数包括特征提取器参数、线性分类器参数、以及解码器参数。
8.如权利要求3所述的训练方法,其中,第t训练阶段的正则项由正则项计算器计算,所述计算包括:
针对所存储的先前所有训练阶段的输入图像样本的一部分、以及第t组输入图像样本中的每一个输入图像样本,计算对应于所述输入图像样本针对截至t-1训练阶段的所有训练阶段的分类,在第t训练阶段的预测分类标签与在第t-1训练阶段的预测分类标签之间的差异,
针对所述所存储的先前所有训练阶段的输入图像样本中的一部分,计算所述差异的和。
9.一种用于图像识别的神经网络模型的训练装置,其中,所述装置执行如权利要求1-8的任一项所述的训练方法。
10.一种电子设备,包括:
处理器;
存储器,用于存储计算机程序指令;
其中,当所述计算机程序指令由所述处理器加载并运行时,所述处理器执行如权利要求1-8的任一项所述的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010080441.2A CN113222100A (zh) | 2020-02-05 | 2020-02-05 | 神经网络模型的训练方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010080441.2A CN113222100A (zh) | 2020-02-05 | 2020-02-05 | 神经网络模型的训练方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113222100A true CN113222100A (zh) | 2021-08-06 |
Family
ID=77085571
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010080441.2A Pending CN113222100A (zh) | 2020-02-05 | 2020-02-05 | 神经网络模型的训练方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113222100A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113627598A (zh) * | 2021-08-16 | 2021-11-09 | 重庆大学 | 一种用于加速推荐的孪生自编码器神经网络算法及系统 |
-
2020
- 2020-02-05 CN CN202010080441.2A patent/CN113222100A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113627598A (zh) * | 2021-08-16 | 2021-11-09 | 重庆大学 | 一种用于加速推荐的孪生自编码器神经网络算法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Song et al. | Dual Conditional GANs for Face Aging and Rejuvenation. | |
CN109033095B (zh) | 基于注意力机制的目标变换方法 | |
Schulz et al. | Deep learning: Layer-wise learning of feature hierarchies | |
CN109711426B (zh) | 一种基于gan和迁移学习的病理图片分类装置及方法 | |
US20230081346A1 (en) | Generating realistic synthetic data with adversarial nets | |
EP3963516B1 (en) | Teaching gan (generative adversarial networks) to generate per-pixel annotation | |
CN111292330A (zh) | 基于编解码器的图像语义分割方法及装置 | |
CN111767979A (zh) | 神经网络的训练方法、图像处理方法、图像处理装置 | |
CN112132739A (zh) | 3d重建以及人脸姿态归一化方法、装置、存储介质及设备 | |
US20230153965A1 (en) | Image processing method and related device | |
Ibragimovich et al. | Effective recognition of pollen grains based on parametric adaptation of the image identification model | |
Gogoi et al. | Image classification using deep autoencoders | |
CN114463605A (zh) | 基于深度学习的持续学习图像分类方法及装置 | |
CN114330736A (zh) | 具有噪声对比先验的潜在变量生成性模型 | |
Pieters et al. | Comparing generative adversarial network techniques for image creation and modification | |
CN115222998A (zh) | 一种图像分类方法 | |
Roy et al. | Tips: Text-induced pose synthesis | |
CN112801029B (zh) | 基于注意力机制的多任务学习方法 | |
CN113222100A (zh) | 神经网络模型的训练方法和装置 | |
CN112801107A (zh) | 一种图像分割方法和电子设备 | |
CN116797850A (zh) | 基于知识蒸馏和一致性正则化的类增量图像分类方法 | |
EP3588441B1 (en) | Imagification of multivariate data sequences | |
KR102105951B1 (ko) | 추론을 위한 제한된 볼츠만 머신 구축 방법 및 추론을 위한 제한된 볼츠만 머신을 탑재한 컴퓨터 장치 | |
CN113344189B (zh) | 一种神经网络的训练方法、装置、计算机设备及存储介质 | |
CN113987170A (zh) | 基于卷积神经网络的多标签文本分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |