CN112698933A - 在多任务数据流中持续学习的方法及装置 - Google Patents
在多任务数据流中持续学习的方法及装置 Download PDFInfo
- Publication number
- CN112698933A CN112698933A CN202110312417.1A CN202110312417A CN112698933A CN 112698933 A CN112698933 A CN 112698933A CN 202110312417 A CN202110312417 A CN 202110312417A CN 112698933 A CN112698933 A CN 112698933A
- Authority
- CN
- China
- Prior art keywords
- task
- learning
- stage
- clustering
- tuple
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 58
- 230000008859 change Effects 0.000 claims abstract description 8
- 230000009471 action Effects 0.000 claims description 27
- 230000009467 reduction Effects 0.000 claims description 27
- 239000003795 chemical substances by application Substances 0.000 claims description 20
- 238000004364 calculation method Methods 0.000 claims description 17
- 239000000126 substance Substances 0.000 claims description 12
- 230000008901 benefit Effects 0.000 claims description 6
- 239000011159 matrix material Substances 0.000 claims description 4
- 230000002265 prevention Effects 0.000 claims description 3
- 238000005457 optimization Methods 0.000 description 7
- 230000008569 process Effects 0.000 description 6
- 238000012549 training Methods 0.000 description 4
- 238000002955 isolation Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000002787 reinforcement Effects 0.000 description 3
- 230000004048 modification Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 238000000926 separation method Methods 0.000 description 2
- 241000764238 Isis Species 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000006870 function Effects 0.000 description 1
- 239000000463 material Substances 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000004044 response Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F9/00—Arrangements for program control, e.g. control units
- G06F9/06—Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
- G06F9/46—Multiprogramming arrangements
- G06F9/48—Program initiating; Program switching, e.g. by interrupt
- G06F9/4806—Task transfer initiation or dispatching
- G06F9/4843—Task transfer initiation or dispatching by program, e.g. task dispatcher, supervisor, operating system
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
- G06F18/23213—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/004—Artificial life, i.e. computing arrangements simulating life
- G06N3/008—Artificial life, i.e. computing arrangements simulating life based on physical entities controlled by simulated intelligence so as to replicate intelligent life forms, e.g. based on robots replicating pets or humans in their appearance or behaviour
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Probability & Statistics with Applications (AREA)
- Robotics (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明提供通用的在多任务数据流中持续学习的方法和装置,包括:智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段,如果判定任务发生了改变,则学习下一个任务,在学习下一个任务时,如果改变所述DQN网络的参数,将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
Description
技术领域
本申请涉及多智能体、在线学习领域,尤其涉及可以在多任务数据流中持续学习的方法和装置。
背景技术
一个智能体在离线训练后可以在当前任务下做出正确的决策,然而将它置于新的环境或者任务下,它可能会继续学习已完成新的任务但是它将遗忘先前的知识。在持续学习的研究中,人们把这种现象称作灾难性遗忘。对于灾难性遗忘问题的研究有许多,且已经取得了不俗的进展。现在较主流的方法有I经验回放方法,这类方法通过重播历史数据或者生成虚拟的历史数据等解决灾难性遗忘. II正则化方法,比如蒸馏旧知识或者巩固参数以防止灾难性遗忘。III 参数隔离方法,比如Expert Gate方法,通过衡量任务间关系度实现任务增量设定下的持续学习。
上述工作均停留在任务边界可知的假设上,例如在训练时明确知道当前训练的任务,而在现实中得到的数据通常是一个任务边界不可知的数据流。比如人眼的数据或者摄像机的数据,这要求agent必须能够在这样的数据流中进行学习并作出正确的决策。基于现有灾难性遗忘问题的研究,也有一些工作进行了在包含边界不可知的数据流中进行训练决策的探索。Task-Free Continual Learning一文过监测loss值得变化来检测任务边界的变化,决定何时使用MAS方法(一种正则化方法)巩固现有的知识。Continual ReinforcementLearning in 3D Non-stationary Environments关注了任务环境发生改变时如何进行持续学习。该算法通过衡量reward的差异实现任务环境边界的检测,然后在环境改变时通过EWC(一种正则化方法)巩固知识。
当我们将环境扩展为任务边界不可知的强化学习环境时,基础的灾难性遗忘问题的研究会暴露严重的弊端。强化学习通常拥有庞大的观测空间。这个观测空间由于其多变性与随机性难以保存与生成,因此经验回放方法在强化学习环境下无法发挥其优势。而正则化与参数隔离方法依赖于明确的任务边界。它在任务边界不可知的情况下难以在合适的时间巩固知识或者隔离参数。相比正则化方法,参数隔离方法通常会对网络模型进行增量的修改,因此更加难以在数据流中进行学习与决策。
现有的两种处理任务边界不可知的算法均存在一定的弊端。Task-FreeContinual Learning创作性的实现在任务边界不可知的人脸信息流中进行在线学习。但是这类方法局限于Loss平稳的假设,泛化能力较差。Continual Reinforcement Learning in3D Non-stationary Environments依赖于不同任务间reward的设定,如果两个任务的reward设定难以在任务切换时被检测出来,这种方法将不起作用。
授权公告号CN 106507398 B公开了一种基于持续学习的网络自优化方法;包含持续学习过程和网络优化过程;本发明提供的网络自优化方法可以大大减少人力物力的投入,节约成本,缩短优化流程,提高优化效率,同时解决上述发明优化时间冗长,可能不是最佳优化策略的缺陷;快速地发现网络中出现的问题,并能够缩短网络故障的持续时间,及时恢复网络正常的工作状态,达到优化网络性能的目的。
申请公布号CN 110705689 A公开了一种可区分特征的持续学习方法及装置,其中,方法包括以下步骤:确定当前分类任务,并将目标函数加入学习模型的angularloss项;在当前分类任务为新任务时,初始化学习模型的最后一层的参数,以使其相互正交,并在更新过程中暂时固定;训练预设时间后,将最后一层的参数参与更新,其中,最后一层的参数作为正交化的约束,以使不同的分类任务不会互相干扰。该方法可以在处理不同类型的分类任务时,都能够很好地对该任务的目标类别进行精准分类,有效解决了目前持续学习中不同任务中的类别在特征空间中会发生重叠、导致模型效果不佳的问题。
发明内容
有鉴于此,本发明第一方面提供一种在多任务数据流中持续学习的方法,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列;
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0;
S3:对原始数据集中的X降维,得到降维特征Z;
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
S5:应用所述概率分布求得任务学习阶段的聚类准确性;
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1;
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性;
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
优选地,所述方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络参数,将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘。
优选地,所述额外的惩罚项为,
其中,
F i :累加公式中第i个Fisher信息矩阵;
优选地,所述Q值的计算方法为:
其中
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
T:输入数据状态s的数据流总时长。
优选地,所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
优选地,所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数为0.5。
优选地,所述求得任务学习阶段的聚类准确性的具体方法为:
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1;
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
本发明第二方面还提供了一种在多任务数据流中持续学习的装置,包括:DQN网络和任务识别器; DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
本申请实施例提供的上述技术方案与现有技术相比具有如下优点:
本申请实施例提供的该方法,通过聚类方法进行任务边界的检测识别,解决了在任务边界不可知的数据流中进行决策的问题。
附图说明
图1为本发明实施例提供的可以在多任务数据流中持续学习的方法流程图;
图2为本发明另一实施例提供的在多任务数据流中持续学习的方法流程图。
具体实施方式
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本发明相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本发明的一些方面相一致的装置和方法的例子。
参见图1,本申请实施例提供的一种在多任务数据流中持续学习的方法,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列。
所述Q值的计算方法为:
其中
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
T:输入数据状态s的数据流总时长。
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0。
S3:对原始数据集中的X降维,得到降维特征Z。
所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布。
所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数为0-1,优选的为0.5。
S5:应用所述概率分布求得任务学习阶段的聚类准确性。
所述求得任务学习阶段的聚类准确性的具体方法为:
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1;
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1。
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性。
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
在一些实施例中,如图2所示,上述在多任务数据流中持续学习的方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络参数,将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘。
所述额外的惩罚项为,
其中,
F i :累加公式中第i个Fisher信息矩阵;
基于同一发明构思,本申请实施例提供的一种在多任务数据流中持续学习的装置,包括:DQN网络和任务识别器; DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
具体实施例
DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;
所述Q值的计算方法为:
其中
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
T:输入数据状态s的数据流总时长。
当输入数据状态队列的缓存长度大于一定长度τ时,τ优选的为1024,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),τ’优选的为64,取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);τ’’优选的为63,将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;由自动编码器的两层全连接层组成编码层对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数,具体为0.5;
应用所述概率分布求得任务学习阶段的聚类准确性;
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1;
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’, 自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,所述第二阈值为,优选为0.7,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘;
所述额外的惩罚项为,
其中,
F i :累加公式中第i个Fisher信息矩阵;
应当理解,尽管在本发明可能采用术语第一、第二、第三等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本发明范围的情况下,第一信息也可以被称为第二信息,类似地,第二信息也可以被称为第一信息。取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”。
虽然本说明书包含许多具体实施细节,但是这些不应被解释为限制任何发明的范围或所要求保护的范围,而是主要用于描述特定发明的具体实施例的特征。本说明书内在多个实施例中描述的某些特征也可以在单个实施例中被组合实施。另一方面,在单个实施例中描述的各种特征也可以在多个实施例中分开实施或以任何合适的子组合来实施。此外,虽然特征可以如上所述在某些组合中起作用并且甚至最初如此要求保护,但是来自所要求保护的组合中的一个或多个特征在一些情况下可以从该组合中去除,并且所要求保护的组合可以指向子组合或子组合的变型。
类似地,虽然在附图中以特定顺序描绘了操作,但是这不应被理解为要求这些操作以所示的特定顺序执行或顺次执行、或者要求所有例示的操作被执行,以实现期望的结果。在某些情况下,多任务和并行处理可能是有利的。此外,上述实施例中的各种系统模块和组件的分离不应被理解为在所有实施例中均需要这样的分离,并且应当理解,所描述的程序组件和系统通常可以一起集成在单个软件产品中,或者封装成多个软件产品。
由此,主题的特定实施例已被描述。其他实施例在所附权利要求书的范围以内。在某些情况下,权利要求书中记载的动作可以以不同的顺序执行并且仍实现期望的结果。此外,附图中描绘的处理并非必需所示的特定顺序或顺次顺序,以实现期望的结果。在某些实现中,多任务和并行处理可能是有利的。
以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明保护的范围之内。
Claims (10)
1.在多任务数据流中持续学习的方法,其特征在于,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列;
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0;
S3:对原始数据集中的X降维,得到降维特征Z;
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
S5:应用所述概率分布求得任务学习阶段的聚类准确性;
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1;
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性;
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
2.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络的参数,将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
5.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
10.在多任务数据流中持续学习的装置,其特征在于,包括:DQN网络和任务识别器;DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110312417.1A CN112698933A (zh) | 2021-03-24 | 2021-03-24 | 在多任务数据流中持续学习的方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110312417.1A CN112698933A (zh) | 2021-03-24 | 2021-03-24 | 在多任务数据流中持续学习的方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112698933A true CN112698933A (zh) | 2021-04-23 |
Family
ID=75515565
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110312417.1A Pending CN112698933A (zh) | 2021-03-24 | 2021-03-24 | 在多任务数据流中持续学习的方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112698933A (zh) |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018017546A1 (en) * | 2016-07-18 | 2018-01-25 | Google Llc | Training machine learning models on multiple machine learning tasks |
CN108647789A (zh) * | 2018-05-15 | 2018-10-12 | 浙江大学 | 一种基于状态分布感知采样的智能体深度价值函数学习方法 |
CN109348707A (zh) * | 2016-04-27 | 2019-02-15 | 纽拉拉股份有限公司 | 针对基于深度神经网络的q学习修剪经验存储器的方法和装置 |
CN110728694A (zh) * | 2019-10-10 | 2020-01-24 | 北京工业大学 | 一种基于持续学习的长时视觉目标跟踪方法 |
CN111199241A (zh) * | 2019-12-17 | 2020-05-26 | 清华大学 | 任务不可知连续学习场景的元空间聚类学习方法及装置 |
US20200193226A1 (en) * | 2018-12-17 | 2020-06-18 | King Fahd University Of Petroleum And Minerals | Enhanced deep reinforcement learning deep q-network models |
CN112084330A (zh) * | 2020-08-12 | 2020-12-15 | 东南大学 | 一种基于课程规划元学习的增量关系抽取方法 |
-
2021
- 2021-03-24 CN CN202110312417.1A patent/CN112698933A/zh active Pending
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109348707A (zh) * | 2016-04-27 | 2019-02-15 | 纽拉拉股份有限公司 | 针对基于深度神经网络的q学习修剪经验存储器的方法和装置 |
WO2018017546A1 (en) * | 2016-07-18 | 2018-01-25 | Google Llc | Training machine learning models on multiple machine learning tasks |
CN108647789A (zh) * | 2018-05-15 | 2018-10-12 | 浙江大学 | 一种基于状态分布感知采样的智能体深度价值函数学习方法 |
US20200193226A1 (en) * | 2018-12-17 | 2020-06-18 | King Fahd University Of Petroleum And Minerals | Enhanced deep reinforcement learning deep q-network models |
CN110728694A (zh) * | 2019-10-10 | 2020-01-24 | 北京工业大学 | 一种基于持续学习的长时视觉目标跟踪方法 |
CN111199241A (zh) * | 2019-12-17 | 2020-05-26 | 清华大学 | 任务不可知连续学习场景的元空间聚类学习方法及装置 |
CN112084330A (zh) * | 2020-08-12 | 2020-12-15 | 东南大学 | 一种基于课程规划元学习的增量关系抽取方法 |
Non-Patent Citations (1)
Title |
---|
GUANXIONG ZENG等: "Continual learning of context-dependent processing in neural networks", 《NATURE MACHINE INTELLIGENCE》 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110287942B (zh) | 年龄估计模型的训练方法、年龄估计方法以及对应的装置 | |
CN108694502B (zh) | 一种基于XGBoost算法的机器人制造单元自适应调度方法 | |
CN108564136B (zh) | 一种基于模糊推理的空域运行态势评估分类方法 | |
CN103544499A (zh) | 一种基于机器视觉的表面瑕疵检测的纹理特征降维方法 | |
CN109886342A (zh) | 基于机器学习的模型训练方法和装置 | |
CN112465001B (zh) | 一种基于逻辑回归的分类方法及装置 | |
CN111931826A (zh) | 基于多尺度卷积迁移模型的滚动轴承故障诊断方法和系统 | |
CN112017204A (zh) | 一种基于边缘标记图神经网络的刀具状态图像分类方法 | |
CN115452376A (zh) | 基于改进轻量级深度卷积神经网络的轴承故障诊断方法 | |
Liu et al. | A dual-branch balance saliency model based on discriminative feature for fabric defect detection | |
CN112698933A (zh) | 在多任务数据流中持续学习的方法及装置 | |
CN110320802B (zh) | 基于数据可视化的复杂系统信号时序识别方法 | |
CN112183469A (zh) | 一种公共交通的拥挤度识别及自适应调整方法、系统、设备及计算机可读存储介质 | |
CN116452950A (zh) | 一种基于改进YOLOv5模型的多目标垃圾检测方法 | |
CN115687948A (zh) | 一种基于负荷曲线的电力专变用户无监督分类方法 | |
CN114528906A (zh) | 一种旋转机械的故障诊断方法、装置、设备和介质 | |
CN115542279A (zh) | 一种气象雷达杂波分类识别方法及装置 | |
CN115169660A (zh) | 基于多尺度时空特征融合神经网络的刀具磨损预测方法 | |
CN109978038A (zh) | 一种集群异常判定方法及装置 | |
CN115640335B (zh) | 基于企业画像的企业分析方法、系统及云平台 | |
CN117067042B (zh) | 一种研磨机及其控制方法 | |
CN109784477B (zh) | 一种用于对比神经网络训练的采样的方法及系统 | |
CN115187386A (zh) | 在线信用卡风险信息检测方法及装置 | |
CN115018066A (zh) | 一种边端模式下的深度神经网络本地化训练方法 | |
CN113269311A (zh) | 一种元素级重激活不重要神经元的神经网络训练方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20210423 |