CN113449878B - 数据分布式的增量学习方法、系统、设备及存储介质 - Google Patents
数据分布式的增量学习方法、系统、设备及存储介质 Download PDFInfo
- Publication number
- CN113449878B CN113449878B CN202110706288.4A CN202110706288A CN113449878B CN 113449878 B CN113449878 B CN 113449878B CN 202110706288 A CN202110706288 A CN 202110706288A CN 113449878 B CN113449878 B CN 113449878B
- Authority
- CN
- China
- Prior art keywords
- data distribution
- data
- model
- distribution node
- data set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 60
- 238000003860 storage Methods 0.000 title claims abstract description 11
- 238000009826 distribution Methods 0.000 claims abstract description 181
- 238000004220 aggregation Methods 0.000 claims abstract description 12
- 230000002776 aggregation Effects 0.000 claims abstract description 12
- 238000012549 training Methods 0.000 claims description 40
- 230000006870 function Effects 0.000 claims description 24
- 238000004821 distillation Methods 0.000 claims description 18
- 238000013140 knowledge distillation Methods 0.000 claims description 14
- 238000004590 computer program Methods 0.000 claims description 10
- 238000005070 sampling Methods 0.000 claims description 9
- 238000004364 calculation method Methods 0.000 claims description 5
- 230000004931 aggregating effect Effects 0.000 claims description 3
- 238000010276 construction Methods 0.000 claims description 3
- 238000000605 extraction Methods 0.000 claims description 3
- 238000006116 polymerization reaction Methods 0.000 claims 1
- 238000012360 testing method Methods 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 239000002131 composite material Substances 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000007774 longterm Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 206010027175 memory impairment Diseases 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Image Analysis (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种数据分布式的增量学习方法、系统、设备及存储介质,包括以下步骤:确定各增量学习阶段的类别,建立各数据分布节点的数据集合;得各数据分布点模型;形成共享数据集;得各数据分布节点的模型参数;将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;将M个数据分布节点模型在共享数据集上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集上对该集成输出logit值进行学习,得全局共享模型的模型参数;将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,该方法、系统、设备及存储介质能够有效提高模型的学习能力。
Description
技术领域
本发明属于大数据智能分析技术领域,涉及一种数据分布式的增量学习方法、系统、设备及存储介质。
背景技术
深度模型在人工智能的广泛研究领域取得了巨大的成功。然而,事实证明它们容易出现灾难性的遗忘问题。灾难性遗忘指在对新数据进行模型学习时,深度模型对旧数据的性能严重下降的现象。增量学习旨在减轻模型对旧数据的遗忘的同时学习新数据,成为了深度学习的一个重要研究课题。
目前的增量学习框架要求深度模型以集中的方式处理连续的信息流。尽管它取得了成功,但我们认为这种集中的设置通常是不可能或不切实际的。越来越多的数据从“孤岛”中产生并存在,这些“孤岛”可能受到各种规范化或隐私方面的要求。它并不总是允许移动数据和使用数据所有者之外的数据。此外,持续的数据流入将导致位于不同存储库中的大量数据,在将它们合并到一个存储库中进行学习时,可能会造成巨大的通信和计算负担。
因此,在数据位于不同位置的场景中部署学习模型是至关重要的,学习过程可以跨时间执行,同时数据分散在多个分布节点上。然而,现有的机器学习方案都无法处理如此复杂的场景,因此对学习的实现带来了巨大的挑战。
发明内容
本发明的目的在于克服上述现有技术的缺点,提供了一种数据分布式的增量学习方法、系统、设备及存储介质,该方法、系统、设备及存储介质能够有效提高模型的学习能力。
为达到上述目的,本发明所述的数据分布式的增量学习方法包括以下步骤:
1)确定数据分布节点的数目及增量学习阶段的数目;
2)建立训练数据集;
3)确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
4)向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
5)各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
6)各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集S(t)上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
7)将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
8)将M个数据分布节点模型在共享数据集S(t)上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集S(t)上对该集成输出logit值进行学习,得全局共享模型的模型参数;
9)将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
步骤2)的具体操作为:
建立训练数据集D={(x,y)∣x∈X,y∈L},其中,X为训练样本集,L为对应的数据标签集合L={1,…,C},,C为类别总数。
步骤3)的具体操作过程为:
确定各增量学习阶段的类别,将训练数据集D划分为T个独立的数据集合{D(1),D(2),…,D(T)},其中,一个数据集合对应一个增量学习阶段,在第t个增量学习阶段,在各数据分布节点上的数据集合
步骤4)中增量学习训练的具体过程为:
41)定义数据分布节点m上的旧类锚点集合各锚点为由特征提取模型计算得到的旧类代表样本,则第k个样本的获得方式为:
其中,Xc为c类数据样本集合,μm,c为数据分布节点m上c类数据的特征中心向量,φ(·;θ)为θ模型参数的特征提取器;
42)建立数据分布节点m上旧类锚点的损失函数
43)建立数据分布节点m上新类知识学习的损失函数
44)建立数据分布节点m上分布式增量蒸馏学习的损失函数
45)各数据分布节点m通过分布式增量蒸馏学习的损失函数进行训练,以更新模型参数θ(t-1),得各数据分布节点模型
步骤6)的具体操作过程为:
61)各数据分布节点模型对共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)为预测输出模型;
62)对各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值[z(t)]0:
其中,为第t个增量学习阶段时,数据分布节点m上数据集样本数目,N(t)为第t个增量学习阶段时,所有数据分布节点上的数据集样本数目;
63)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]0对各数据分布节点模型计算得到的预测输出logit值进行分布式合作知识蒸馏/>
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=L(t)|为当前增量学习阶段新类别的数目;
64)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数进行训练更新。
步骤7)的具体操作过程为:
71)获取M个数据分布节点模型的模型参数
72)对取M个数据分布节点模型的模型参数进行加权平均,得初步的全局共享模型参数/>
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上的数据集样本数目。
步骤8)的具体操作过程为:
81)各数据分布节点模型在共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)代表预测输出模型;
82)对各数据分布节点模型的预测输出logit值进行集成,得集成预测输出logit值[z(t)]1:
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上数据集样本数目;
83)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]1对初步的全局共享模型在共享数据集S(t)的预测输出值进行分布式聚合知识蒸馏lAD:
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=|L(t)|为当前增量学习阶段新类别的数目;
84)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数lAD进行训练,得全局共享模型的模型参数。
一种数据分布式的增量学习系统,包括:
确定模块,用于确定数据分布节点的数目及增量学习阶段的数目;
建立模块,用于建立训练数据集;
划分模块,用于确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
模型构建模块,用于向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
共享数据集形成模块,用于各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
参数获取模块,用于各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集S(t)上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
加权聚合模块,用于将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
学习模块,用于将个数据分布节点模型在共享数据集S(t)上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集S(t)上对该集成输出logit值进行学习,得全局共享模型的模型参数;
更新模块,用于将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现所述数据分布式的增量学习方法的步骤。
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述数据分布式的增量学习方法的步骤。
本发明具有以下有益效果:
本发明所述的数据分布式的增量学习方法、系统、设备及存储介质在具体操作时,将数据分散于不同的分布数据节点进行增量学习,使得学习过程更符合实际应用场景,实用性极强,同时集合模型参数集合及知识蒸馏的方式,以提高全局共享模型的学习能力,继而应对复杂的场景。
附图说明
图1为本发明的流程图;
图2为本发明的可视化图;
图3为本发明的的结果图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,不是全部的实施例,而并非要限制本发明公开的范围。此外,在以下说明中,省略了对公知结构和技术的描述,以避免不必要的混淆本发明公开的概念。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
在附图中示出了根据本发明公开实施例的结构示意图。这些图并非是按比例绘制的,其中为了清楚表达的目的,放大了某些细节,并且可能省略了某些细节。图中所示出的各种区域、层的形状及它们之间的相对大小、位置关系仅是示例性的,实际中可能由于制造公差或技术限制而有所偏差,并且本领域技术人员根据实际所需可以另外设计具有不同形状、大小、相对位置的区域/层。
实施例一
参考图1及图2,本发明所述的数据分布式的增量学习方法包括以下步骤:
1)确定数据分布节点的数目及增量学习阶段的数目;
2)建立训练数据集;
3)确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
4)向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
5)各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
6)各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集S(t)上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
7)将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
8)将M个数据分布节点模型在共享数据集S(t)上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集S(t)上对该集成输出logit值进行学习,得全局共享模型的模型参数;
9)将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
步骤2)的具体操作为:
建立训练数据集D={(x,y)∣x∈X,y∈L},其中,X为训练样本集,L为对应的数据标签集合L={1,…,C},,C为类别总数。
步骤3)的具体操作过程为:
确定各增量学习阶段的类别,将训练数据集D划分为T个独立的数据集合{D(1),D(2),…,D(T)},其中,一个数据集合对应一个增量学习阶段,在第t个增量学习阶段,在各数据分布节点上的数据集合
步骤4)中增量学习训练的具体过程为:
41)定义数据分布节点m上的旧类锚点集合各锚点为由特征提取模型计算得到的旧类代表样本,则第k个样本的获得方式为:
其中,Xc为c类数据样本集合,μm,c为数据分布节点m上c类数据的特征中心向量,φ(·;θ)为θ模型参数的特征提取器;
42)建立数据分布节点m上旧类锚点的损失函数
43)建立数据分布节点m上新类知识学习的损失函数
44)建立数据分布节点m上分布式增量蒸馏学习的损失函数
45)各数据分布节点m通过分布式增量蒸馏学习的损失函数进行训练,以更新模型参数θ(t-1),得各数据分布节点模型
步骤6)的具体操作过程为:
61)各数据分布节点模型对共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)为预测输出模型;
62)对各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值[z(t)]0:
其中,为第t个增量学习阶段时,数据分布节点m上数据集样本数目,N(t)为第t个增量学习阶段时,所有数据分布节点上的数据集样本数目;
63)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]0对各数据分布节点模型计算得到的预测输出logit值进行分布式合作知识蒸馏/>
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=|L(t)|为当前增量学习阶段新类别的数目;
64)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数进行训练更新。
步骤7)的具体操作过程为:
71)获取M个数据分布节点模型的模型参数
72)对取M个数据分布节点模型的模型参数进行加权平均,得初步的全局共享模型参数/>
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上的数据集样本数目。
步骤8)的具体操作过程为:
81)各数据分布节点模型在共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)代表预测输出模型;
82)对各数据分布节点模型的预测输出logit值进行集成,得集成预测输出logit值[z(t)]1:
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上数据集样本数目;
83)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]1对初步的全局共享模型在共享数据集S(t)的预测输出值进行分布式聚合知识蒸馏lAD:
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=|L(t)|为当前增量学习阶段新类别的数目;
84)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数lAD进行训练,得全局共享模型的模型参数。
为使具体实施方式完整、清晰,步骤4)至步骤9)的具体过程如表1所示:
表1
实施例二一种数据分布式的增量学习系统,包括:
确定模块,用于确定数据分布节点的数目及增量学习阶段的数目;
建立模块,用于建立训练数据集;
划分模块,用于确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
模型构建模块,用于向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
共享数据集形成模块,用于各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
参数获取模块,用于各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集S(t)上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
加权聚合模块,用于将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
学习模块,用于将个数据分布节点模型在共享数据集S(t)上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集S(t)上对该集成输出logit值进行学习,得全局共享模型的模型参数;
更新模块,用于将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
实施例三
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现所述数据分布式的增量学习方法的步骤。
实施例四
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述数据分布式的增量学习方法的步骤。
仿真试验
考虑数据分布节点为5个,数据分布为独立同分布的,数据集为CIFAR100(Krizhevsky and Hinton 2009)和subImageNet(ImageNet随机抽取100个类别)的类增量学习。增量方法为iCARL,LUCIR和TPCIL,基类个数为50,增量阶段数为5、10及25。
CIFAR100数据集包含60000个RGB图像,每张图的大小为32×32像素,包含100类,subImageNet数据集包含13000个RGB图像,每张图的大小为224×224像素。
增量学习场景对比:本发明是在数据分布在不同节点的情况下进行增量学习的,这种设置方式符合实际应用场景,故具有更强的应用价值。本发明提出的基于复合知识蒸馏的数据分布式的增量学习框架相较于简单的将增量学习和分布式学习结合的极限方法在测试集准确度的对比上具有明显的提升,实验结果如图3所示。
图3中虚线部分代表基线方法,实线部分代表本发明,根据图3可知,在所有数据集上使用5、10和25个增量学习阶段的设置进行训练,本发明在每个增量学习阶段上的性能大大超过基线方法,特别是在subImageNet上。并且在学习完所有的增量学习阶段后,本发明的优越性更加明显,说明在数据分布下长期增量学习的有效性。
在CIFAR100数据集上,在设置5个增量学习阶段时,本发明所有增量学习阶段的均值比使用iCARL、LUCIR及TPCIL的基线方法分别高出1.96%、1.02%及0.91%,在设置10个增量学习阶段时,本发明也比使用iCARL、LUCIR及TPCIL的基线方法分别高出1.8%、1.01%及0.91%;25个增量学习阶段设置对应的提高分别为1.58%、0.72%及1.13%。
在subImageNet数据集上,在设置5个增量学习阶段时,本发明在增量学习阶段最终值比使用iCARL、LUCIR及TPCIL的基线方法分别高出5.7%、7.78%及7.08%;在设置10个增量学习阶段时,本发明也比使用iCARL、LUCIR及TPCIL的基线方法分别高出5.67%、7.14%及5.92%;25个增量学习阶段设置对应的提高分别为4.4%、6.82%及7.49%。
上述方案,仅为本申请较佳的几个实施方式的描述,但本申请的保护范围不仅限于此,任何熟悉该技术的人能在本申请描述的范围内轻易实现,而不改变权利要求涉及基本原理的变化或替换,都应涵盖在本申请的保护范围之内,即本申请保护范围应以权利要求保护范围为准。
Claims (10)
1.一种数据分布式的增量学习方法,其特征在于,用于图像处理领域,包括以下步骤:
1)确定数据分布节点的数目及增量学习阶段的数目;
2)建立训练数据集,所述训练数据集包括若干RGB图像;
3)确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
4)向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
5)各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
6)各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
7)将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
8)将M个数据分布节点模型在共享数据集上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集上对该集成输出logit值进行学习,得全局共享模型的模型参数;
9)将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
2.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤2)的具体操作为:
建立训练数据集D={(x,y)∣x∈X,y∈L},其中,X为训练样本集,L为对应的数据标签集合L={1,…,C},,C为类别总数。
3.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤3)的具体操作过程为:
确定各增量学习阶段的类别,将训练数据集D划分为T个独立的数据集合{D(1),D(2),…,D(T)},其中,一个数据集合对应一个增量学习阶段,在第t个增量学习阶段,在各数据分布节点上的数据集合
4.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤4)中增量学习训练的具体过程为:
41)定义数据分布节点m上的旧类锚点集合各锚点为由特征提取模型计算得到的旧类代表样本,则第k个样本的获得方式为:
其中,Xc为c类数据样本集合,μm,E为数据分布节点m上c类数据的特征中心向量,φ(·;θ)为θ模型参数的特征提取器;
42)建立数据分布节点m上旧类锚点的损失函数
43)建立数据分布节点m上新类知识学习的损失函数
44)建立数据分布节点m上分布式增量蒸馏学习的损失函数
45)各数据分布节点m通过分布式增量蒸馏学习的损失函数进行训练,以更新模型参数θ(t-1),得各数据分布节点模型
5.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤6)的具体操作过程为:
61)各数据分布节点模型对共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)为预测输出模型;
62)对各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值[z(t)]0:
其中,为第t个增量学习阶段时,数据分布节点m上数据集样本数目,N(t)为第t个增量学习阶段时,所有数据分布节点上的数据集样本数目;
63)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]0对各数据分布节点模型计算得到的预测输出logit值进行分布式合作知识蒸馏/>
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=|L(t)|为当前增量学习阶段新类别的数目;
64)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数进行训练更新。
6.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤7)的具体操作过程为:
71)获取M个数据分布节点模型的模型参数
72)对取M个数据分布节点模型的模型参数进行加权平均,得初步的全局共享模型参数/>
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上的数据集样本数目。
7.根据权利要求1所述的数据分布式的增量学习方法,其特征在于,步骤8)的具体操作过程为:
81)各数据分布节点模型在共享数据集S(t)计算得到预测输出logit值,其中,第m个数据分布节点模型计算得到的预测输出logit值/>为:
其中,x为共享数据集S(t)中的某一样本,f(·,x)代表预测输出模型;
82)对各数据分布节点模型的预测输出logit值进行集成,得集成预测输出logit值[z(t)]1:
其中,为第t个增量学习阶段,数据分布节点m上的数据集样本数目,N(t)为第t个增量学习阶段,所有数据分布节点上数据集样本数目;
83)利用知识蒸馏方法,通过集成预测输出logit值[z(t)]1对初步的全局共享模型在共享数据集S(t)的预测输出值进行分布式聚合知识蒸馏/>
其中,DKL为KL散度距离,τ1为知识蒸馏的温度参数,n=|L(t)|为当前增量学习阶段新类别的数目;
84)将各数据分布节点模型通过分布式合作蒸馏学习的损失函数进行训练,得全局共享模型的模型参数。
8.一种数据分布式的增量学习系统,其特征在于,用于图像处理领域,包括:
确定模块,用于确定数据分布节点的数目及增量学习阶段的数目;
建立模块,用于建立训练数据集,所述训练数据集包括若干RGB图像;
划分模块,用于确定各增量学习阶段的类别,将训练数据集划分为T个独立的数据集合,其中一个增量学习阶段对应一个数据集合,再在当前增量学习阶段,根据当增量学习阶段对应的数据集合建立各数据分布节点的数据集合;
模型构建模块,用于向各数据分布节点输入上一增量学习阶段的全局共享模型参数及当前增量学习阶段中各数据分布节点的数据集合,再在增量学习损失函数的约束下进行增量学习训练,得各数据分布点模型;
共享数据集形成模块,用于各数据分布节点从其数据集合中进行随机采样,再将采样结果进行聚合形成共享数据集;
参数获取模块,用于各数据分布节点模型对共享数据集计算得到预测输出logit值,再将各数据分布节点模型计算得到的预测输出logit值进行集成,得集成预测输出logit值,将各数据分布节点模型在共享数据集S(t)上对该集成预测输出logit值进行学习,得各数据分布节点的模型参数;
加权聚合模块,用于将各数据分布节点的模型参数进行加权聚合,得初步的全局共享模型;
学习模块,用于将M个数据分布节点模型在共享数据集S(t)上计算得到的预测输出logit值进行集成,得集成输出logit值,将初步的全局共享模型在共享数据集S(t)上对该集成输出logit值进行学习,得全局共享模型的模型参数;
更新模块,用于将全局共享模型的模型参数下发至各数据分布节点,对各个数据分布节点上的全局共享模型进行更新,完成数据分布式的增量学习。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1-7任一项所述数据分布式的增量学习方法的步骤。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1-7任一项所述数据分布式的增量学习方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110706288.4A CN113449878B (zh) | 2021-06-24 | 2021-06-24 | 数据分布式的增量学习方法、系统、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110706288.4A CN113449878B (zh) | 2021-06-24 | 2021-06-24 | 数据分布式的增量学习方法、系统、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113449878A CN113449878A (zh) | 2021-09-28 |
CN113449878B true CN113449878B (zh) | 2024-04-02 |
Family
ID=77812554
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110706288.4A Active CN113449878B (zh) | 2021-06-24 | 2021-06-24 | 数据分布式的增量学习方法、系统、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113449878B (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114491168B (zh) * | 2022-01-27 | 2022-12-13 | 中国电力科学研究院有限公司 | 调控云样本数据共享方法、系统、计算机设备及存储介质 |
CN117133039B (zh) * | 2023-09-01 | 2024-03-15 | 中国科学院自动化研究所 | 图像鉴伪模型训练方法、图像鉴伪方法、装置及电子设备 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104376120A (zh) * | 2014-12-04 | 2015-02-25 | 浙江大学 | 一种信息检索方法及系统 |
WO2018213205A1 (en) * | 2017-05-14 | 2018-11-22 | Digital Reasoning Systems, Inc. | Systems and methods for rapidly building, managing, and sharing machine learning models |
CN112990280A (zh) * | 2021-03-01 | 2021-06-18 | 华南理工大学 | 面向图像大数据的类增量分类方法、系统、装置及介质 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11295171B2 (en) * | 2019-10-18 | 2022-04-05 | Google Llc | Framework for training machine-learned models on extremely large datasets |
-
2021
- 2021-06-24 CN CN202110706288.4A patent/CN113449878B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104376120A (zh) * | 2014-12-04 | 2015-02-25 | 浙江大学 | 一种信息检索方法及系统 |
WO2018213205A1 (en) * | 2017-05-14 | 2018-11-22 | Digital Reasoning Systems, Inc. | Systems and methods for rapidly building, managing, and sharing machine learning models |
CN112990280A (zh) * | 2021-03-01 | 2021-06-18 | 华南理工大学 | 面向图像大数据的类增量分类方法、系统、装置及介质 |
Non-Patent Citations (2)
Title |
---|
基于异质信息融合的网络图像半监督学习方法;杜友田;李谦;周亚东;吴陈鹤;;自动化学报(第12期);全文 * |
集成学习分布式异常检测方法;周绪川;钟勇;;计算机工程与应用(第18期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN113449878A (zh) | 2021-09-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN107833183B (zh) | 一种基于多任务深度神经网络的卫星图像同时超分辨和着色的方法 | |
CN113449878B (zh) | 数据分布式的增量学习方法、系统、设备及存储介质 | |
CN106845530A (zh) | 字符检测方法和装置 | |
CN111626184B (zh) | 一种人群密度估计方法及系统 | |
CN110222760A (zh) | 一种基于winograd算法的快速图像处理方法 | |
CN111915555B (zh) | 一种3d网络模型预训练方法、系统、终端及存储介质 | |
CN112347970A (zh) | 一种基于图卷积神经网络的遥感影像地物识别方法 | |
CN112101364B (zh) | 基于参数重要性增量学习的语义分割方法 | |
CN112541584A (zh) | 深度神经网络模型并行模式选择方法 | |
CN113569852A (zh) | 语义分割模型的训练方法、装置、电子设备及存储介质 | |
CN113420827A (zh) | 语义分割网络训练和图像语义分割方法、装置及设备 | |
CN111282281B (zh) | 图像处理方法及装置、电子设备和计算机可读存储介质 | |
CN115018039A (zh) | 一种神经网络蒸馏方法、目标检测方法以及装置 | |
CN113554653A (zh) | 基于互信息校准点云数据长尾分布的语义分割方法 | |
CN115578624A (zh) | 农业病虫害模型构建方法、检测方法及装置 | |
CN113705402B (zh) | 视频行为预测方法、系统、电子设备及存储介质 | |
CN107122472A (zh) | 大规模非结构化数据提取方法、其系统、分布式数据管理平台 | |
EP3736749A1 (de) | Verfahren und vorrichtung zur ansteuerung eines geräts mit einem datensatz | |
CN115544307A (zh) | 基于关联矩阵的有向图数据特征提取与表达方法和系统 | |
CN114627085A (zh) | 目标图像的识别方法和装置、存储介质及电子设备 | |
CN118552136B (zh) | 基于大数据的供应链智能库存管理系统及方法 | |
CN116597419B (zh) | 一种基于参数化互近邻的车辆限高场景识别方法 | |
CN117647855B (zh) | 一种基于序列长度的短临降水预报方法、装置及设备 | |
CN118552136A (zh) | 基于大数据的供应链智能库存管理系统及方法 | |
CN114155555B (zh) | 人体行为人工智能判断系统和方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |