CN112698933A - 在多任务数据流中持续学习的方法及装置 - Google Patents

在多任务数据流中持续学习的方法及装置 Download PDF

Info

Publication number
CN112698933A
CN112698933A CN202110312417.1A CN202110312417A CN112698933A CN 112698933 A CN112698933 A CN 112698933A CN 202110312417 A CN202110312417 A CN 202110312417A CN 112698933 A CN112698933 A CN 112698933A
Authority
CN
China
Prior art keywords
task
learning
stage
clustering
tuple
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110312417.1A
Other languages
English (en)
Inventor
张俊格
李庆明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Institute of Automation of Chinese Academy of Science
Original Assignee
Institute of Automation of Chinese Academy of Science
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Institute of Automation of Chinese Academy of Science filed Critical Institute of Automation of Chinese Academy of Science
Priority to CN202110312417.1A priority Critical patent/CN112698933A/zh
Publication of CN112698933A publication Critical patent/CN112698933A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F9/00Arrangements for program control, e.g. control units
    • G06F9/06Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
    • G06F9/46Multiprogramming arrangements
    • G06F9/48Program initiating; Program switching, e.g. by interrupt
    • G06F9/4806Task transfer initiation or dispatching
    • G06F9/4843Task transfer initiation or dispatching by program, e.g. task dispatcher, supervisor, operating system
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • G06F18/232Non-hierarchical techniques
    • G06F18/2321Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
    • G06F18/23213Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/004Artificial life, i.e. computing arrangements simulating life
    • G06N3/008Artificial life, i.e. computing arrangements simulating life based on physical entities controlled by simulated intelligence so as to replicate intelligent life forms, e.g. based on robots replicating pets or humans in their appearance or behaviour
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Probability & Statistics with Applications (AREA)
  • Robotics (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提供通用的在多任务数据流中持续学习的方法和装置,包括:智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段,如果判定任务发生了改变,则学习下一个任务,在学习下一个任务时,如果改变所述DQN网络的参数,将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。

Description

在多任务数据流中持续学习的方法及装置
技术领域
本申请涉及多智能体、在线学习领域,尤其涉及可以在多任务数据流中持续学习的方法和装置。
背景技术
一个智能体在离线训练后可以在当前任务下做出正确的决策,然而将它置于新的环境或者任务下,它可能会继续学习已完成新的任务但是它将遗忘先前的知识。在持续学习的研究中,人们把这种现象称作灾难性遗忘。对于灾难性遗忘问题的研究有许多,且已经取得了不俗的进展。现在较主流的方法有I经验回放方法,这类方法通过重播历史数据或者生成虚拟的历史数据等解决灾难性遗忘. II正则化方法,比如蒸馏旧知识或者巩固参数以防止灾难性遗忘。III 参数隔离方法,比如Expert Gate方法,通过衡量任务间关系度实现任务增量设定下的持续学习。
上述工作均停留在任务边界可知的假设上,例如在训练时明确知道当前训练的任务,而在现实中得到的数据通常是一个任务边界不可知的数据流。比如人眼的数据或者摄像机的数据,这要求agent必须能够在这样的数据流中进行学习并作出正确的决策。基于现有灾难性遗忘问题的研究,也有一些工作进行了在包含边界不可知的数据流中进行训练决策的探索。Task-Free Continual Learning一文过监测loss值得变化来检测任务边界的变化,决定何时使用MAS方法(一种正则化方法)巩固现有的知识。Continual ReinforcementLearning in 3D Non-stationary Environments关注了任务环境发生改变时如何进行持续学习。该算法通过衡量reward的差异实现任务环境边界的检测,然后在环境改变时通过EWC(一种正则化方法)巩固知识。
当我们将环境扩展为任务边界不可知的强化学习环境时,基础的灾难性遗忘问题的研究会暴露严重的弊端。强化学习通常拥有庞大的观测空间。这个观测空间由于其多变性与随机性难以保存与生成,因此经验回放方法在强化学习环境下无法发挥其优势。而正则化与参数隔离方法依赖于明确的任务边界。它在任务边界不可知的情况下难以在合适的时间巩固知识或者隔离参数。相比正则化方法,参数隔离方法通常会对网络模型进行增量的修改,因此更加难以在数据流中进行学习与决策。
现有的两种处理任务边界不可知的算法均存在一定的弊端。Task-FreeContinual Learning创作性的实现在任务边界不可知的人脸信息流中进行在线学习。但是这类方法局限于Loss平稳的假设,泛化能力较差。Continual Reinforcement Learning in3D Non-stationary Environments依赖于不同任务间reward的设定,如果两个任务的reward设定难以在任务切换时被检测出来,这种方法将不起作用。
授权公告号CN 106507398 B公开了一种基于持续学习的网络自优化方法;包含持续学习过程和网络优化过程;本发明提供的网络自优化方法可以大大减少人力物力的投入,节约成本,缩短优化流程,提高优化效率,同时解决上述发明优化时间冗长,可能不是最佳优化策略的缺陷;快速地发现网络中出现的问题,并能够缩短网络故障的持续时间,及时恢复网络正常的工作状态,达到优化网络性能的目的。
申请公布号CN 110705689 A公开了一种可区分特征的持续学习方法及装置,其中,方法包括以下步骤:确定当前分类任务,并将目标函数加入学习模型的angularloss项;在当前分类任务为新任务时,初始化学习模型的最后一层的参数,以使其相互正交,并在更新过程中暂时固定;训练预设时间后,将最后一层的参数参与更新,其中,最后一层的参数作为正交化的约束,以使不同的分类任务不会互相干扰。该方法可以在处理不同类型的分类任务时,都能够很好地对该任务的目标类别进行精准分类,有效解决了目前持续学习中不同任务中的类别在特征空间中会发生重叠、导致模型效果不佳的问题。
发明内容
有鉴于此,本发明第一方面提供一种在多任务数据流中持续学习的方法,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列;
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0;
S3:对原始数据集中的X降维,得到降维特征Z;
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
S5:应用所述概率分布求得任务学习阶段的聚类准确性;
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1;
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性;
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
优选地,所述方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络参数,将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘。
优选地,所述额外的惩罚项为,
Figure 276506DEST_PATH_IMAGE001
其中,
Figure 100105DEST_PATH_IMAGE002
:DQN网络在任务B时的损失值;
F i :累加公式中第i个Fisher信息矩阵;
Figure DEST_PATH_IMAGE003
:DQN网络在任务B时的参数;
Figure 488361DEST_PATH_IMAGE004
:DQN网络在任务A时训练完成的参数;
Figure DEST_PATH_IMAGE005
:超参数>0。
优选地,所述Q值的计算方法为:
Figure 261276DEST_PATH_IMAGE006
Figure DEST_PATH_IMAGE007
其中
Figure 443996DEST_PATH_IMAGE008
:期望计算;
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
Figure DEST_PATH_IMAGE009
:输入数据状态s进策略网络后输出一个动作a的q值分布;
Figure 754892DEST_PATH_IMAGE010
:代表折扣系数为范围为0-1;
Figure DEST_PATH_IMAGE011
:t’时刻智能体执行动作a带来的收益;
T:输入数据状态s的数据流总时长。
优选地,所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
优选地,所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
Figure 681259DEST_PATH_IMAGE012
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数为0.5。
优选地,所述求得任务学习阶段的聚类准确性的具体方法为:
Figure DEST_PATH_IMAGE013
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
优选地,所述第一阈值为
Figure 308681DEST_PATH_IMAGE014
范围为0.5-1。
优选地,所述第二阈值为
Figure DEST_PATH_IMAGE015
范围为0.5-1。
本发明第二方面还提供了一种在多任务数据流中持续学习的装置,包括:DQN网络和任务识别器; DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
本申请实施例提供的上述技术方案与现有技术相比具有如下优点:
本申请实施例提供的该方法,通过聚类方法进行任务边界的检测识别,解决了在任务边界不可知的数据流中进行决策的问题。
附图说明
图1为本发明实施例提供的可以在多任务数据流中持续学习的方法流程图;
图2为本发明另一实施例提供的在多任务数据流中持续学习的方法流程图。
具体实施方式
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本发明相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本发明的一些方面相一致的装置和方法的例子。
参见图1,本申请实施例提供的一种在多任务数据流中持续学习的方法,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列。
所述Q值的计算方法为:
Figure 927881DEST_PATH_IMAGE006
Figure 929335DEST_PATH_IMAGE007
其中
Figure 659394DEST_PATH_IMAGE008
:期望计算;
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
Figure 610163DEST_PATH_IMAGE009
:输入数据状态s进策略网络后输出一个动作a的q值分布;
Figure 603527DEST_PATH_IMAGE010
:代表折扣系数,具体范围为0-1;
Figure 154594DEST_PATH_IMAGE011
:t’时刻智能体执行动作a带来的收益;
T:输入数据状态s的数据流总时长。
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0。
S3:对原始数据集中的X降维,得到降维特征Z。
所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布。
所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
Figure 626027DEST_PATH_IMAGE012
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数为0-1,优选的为0.5。
S5:应用所述概率分布求得任务学习阶段的聚类准确性。
所述求得任务学习阶段的聚类准确性的具体方法为:
Figure 680571DEST_PATH_IMAGE016
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1。
所述第一阈值为
Figure 907153DEST_PATH_IMAGE014
为0.5-1,优选为0.9。
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性。
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
所述第二阈值为
Figure 617620DEST_PATH_IMAGE017
范围为0.5-1,优选的为0.7。
在一些实施例中,如图2所示,上述在多任务数据流中持续学习的方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络参数,将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘。
所述额外的惩罚项为,
Figure 440213DEST_PATH_IMAGE001
其中,
Figure 880422DEST_PATH_IMAGE002
:DQN网络在任务B时的损失值;
F i :累加公式中第i个Fisher信息矩阵;
Figure 950009DEST_PATH_IMAGE003
:DQN网络在任务B时的参数;
Figure 210089DEST_PATH_IMAGE004
:DQN网络在任务A时训练完成的参数;
Figure 85641DEST_PATH_IMAGE005
:超参数范围为>0,优选的为15。
基于同一发明构思,本申请实施例提供的一种在多任务数据流中持续学习的装置,包括:DQN网络和任务识别器; DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
具体实施例
DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;
所述Q值的计算方法为:
Figure 318039DEST_PATH_IMAGE006
Figure 371577DEST_PATH_IMAGE007
其中
Figure 322216DEST_PATH_IMAGE008
:期望计算;
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
Figure 735879DEST_PATH_IMAGE009
:输入数据状态s进DQN网络后输出一个动作a的q值分布;
Figure 885101DEST_PATH_IMAGE010
:代表折扣系数,具体范围为0-1,优选的为0.5;
Figure 562070DEST_PATH_IMAGE011
:t’时刻智能体执行动作a带来的收益;
T:输入数据状态s的数据流总时长。
当输入数据状态队列的缓存长度大于一定长度τ时,τ优选的为1024,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),τ’优选的为64,取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);τ’’优选的为63,将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;由自动编码器的两层全连接层组成编码层对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
Figure 531163DEST_PATH_IMAGE012
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数,具体为0.5;
应用所述概率分布求得任务学习阶段的聚类准确性;
Figure 499250DEST_PATH_IMAGE016
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
判断当任务学习阶段的聚类准确性大于第一阈值时,所述第一阈值为
Figure 706241DEST_PATH_IMAGE014
Figure 350849DEST_PATH_IMAGE014
优选为0.9,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’, 自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,所述第二阈值为
Figure 10500DEST_PATH_IMAGE015
Figure 765966DEST_PATH_IMAGE015
优选为0.7,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络参数在下个任务学习中变化实现对当前任务的不遗忘;
所述额外的惩罚项为,
Figure 889780DEST_PATH_IMAGE001
其中,
Figure 908552DEST_PATH_IMAGE002
:DQN网络在任务B时的损失值;
F i :累加公式中第i个Fisher信息矩阵;
Figure 886127DEST_PATH_IMAGE003
:DQN网络在任务B时的参数;
Figure 179705DEST_PATH_IMAGE004
:DQN网络在任务A时训练完成的参数;
Figure 95708DEST_PATH_IMAGE005
:超参数,具体优选为15。
应当理解,尽管在本发明可能采用术语第一、第二、第三等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本发明范围的情况下,第一信息也可以被称为第二信息,类似地,第二信息也可以被称为第一信息。取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”。
虽然本说明书包含许多具体实施细节,但是这些不应被解释为限制任何发明的范围或所要求保护的范围,而是主要用于描述特定发明的具体实施例的特征。本说明书内在多个实施例中描述的某些特征也可以在单个实施例中被组合实施。另一方面,在单个实施例中描述的各种特征也可以在多个实施例中分开实施或以任何合适的子组合来实施。此外,虽然特征可以如上所述在某些组合中起作用并且甚至最初如此要求保护,但是来自所要求保护的组合中的一个或多个特征在一些情况下可以从该组合中去除,并且所要求保护的组合可以指向子组合或子组合的变型。
类似地,虽然在附图中以特定顺序描绘了操作,但是这不应被理解为要求这些操作以所示的特定顺序执行或顺次执行、或者要求所有例示的操作被执行,以实现期望的结果。在某些情况下,多任务和并行处理可能是有利的。此外,上述实施例中的各种系统模块和组件的分离不应被理解为在所有实施例中均需要这样的分离,并且应当理解,所描述的程序组件和系统通常可以一起集成在单个软件产品中,或者封装成多个软件产品。
由此,主题的特定实施例已被描述。其他实施例在所附权利要求书的范围以内。在某些情况下,权利要求书中记载的动作可以以不同的顺序执行并且仍实现期望的结果。此外,附图中描绘的处理并非必需所示的特定顺序或顺次顺序,以实现期望的结果。在某些实现中,多任务和并行处理可能是有利的。
以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明保护的范围之内。

Claims (10)

1.在多任务数据流中持续学习的方法,其特征在于,包括:
智能体采用DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;具体方法如下:
S1:通过输入数据状态s和智能体给出的动作a,计算Q值,并由输入数据状态队列S和智能体给出的动作队列计算Q值队列;
S2:对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1,Y1对应的是所述第二元组中的0;
S3:对原始数据集中的X降维,得到降维特征Z;
S4:应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;
S5:应用所述概率分布求得任务学习阶段的聚类准确性;
S6:当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,返回步骤S1;
S7:采集任务识别数据状态队列S’,将输入数据状态队列S替换为任务识别数据状态队列S’,重复步骤S1-S5得到任务识别阶段的聚类准确性;
S8:当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务。
2.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述方法还包括:
S9:在学习下一个任务时,如果改变所述DQN网络的参数,将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
3.根据权利要求2所述的在多任务数据流中持续学习的方法,其特征在于,所述额外的惩罚项为,
Figure 332606DEST_PATH_IMAGE001
其中,
Figure 37257DEST_PATH_IMAGE002
:DQN网络在任务B时的损失值;
F i :累加公式中第i个Fisher信息矩阵;
Figure 782359DEST_PATH_IMAGE003
:DQN网络在任务B时的参数;
Figure 547053DEST_PATH_IMAGE004
:DQN网络在任务A时训练完成的参数;
Figure 377606DEST_PATH_IMAGE005
:超参数>0。
4.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述Q值的计算方法为:
Figure 636680DEST_PATH_IMAGE006
Figure 33026DEST_PATH_IMAGE007
其中
Figure 906304DEST_PATH_IMAGE008
:期望计算;
s t :期望计算公式中输入数据变量,输入数据状态s为数值;
a t :期望计算公式中策略网络输出动作变量,动作a为数值;
Figure 286470DEST_PATH_IMAGE009
:输入数据状态s进策略网络后输出一个动作a的q值分布;
Figure 67344DEST_PATH_IMAGE010
:代表折扣系数为范围为0-1;
Figure 787039DEST_PATH_IMAGE011
:t’时刻智能体执行动作a带来的收益;
T:输入数据状态s的数据流总时长。
5.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述对原始数据集中的X降维,得到降维特征Z的具体方法为:
由两层全连接层将原始数据集中的X降维到所述降维特征Z。
6.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述每个降维特征Z距离聚类中心的概率分布的计算方法为:
Figure 644267DEST_PATH_IMAGE012
其中,
zi:所述降维特征Z中第i个样本;
μj:所述聚类中心;
α:超参数为0.5。
7.根据权利要求6所述的在多任务数据流中持续学习的方法,其特征在于,所述求得任务学习阶段的聚类准确性的具体方法为:
Figure 246150DEST_PATH_IMAGE013
其中,
qi:代表第i个样本聚类求得的概率分布,具体为qi0和qi1
Yi:聚类结果的标签Y中第i个样本对应的真实标签;
m:max,取qi0和qi1中的最大值,来判断样本属于哪个类,确定第i个样本对应的真实标签Yi是否与判断一致;
τ’:所述第一元组(Smax,1)中,元素Smax的维数;
τ’’:所述第二元组(Smin,0)中,元素Smin的维数。
8.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述第一阈值为
Figure 33977DEST_PATH_IMAGE014
范围为0.5-1。
9.根据权利要求1所述的在多任务数据流中持续学习的方法,其特征在于,所述第二阈值为
Figure 404916DEST_PATH_IMAGE015
范围为0.5-1。
10.在多任务数据流中持续学习的装置,其特征在于,包括:DQN网络和任务识别器;DQN网络进行策略学习,由任务识别器检测任务边界,防止发生灾难性遗忘;任务识别器采集数据状态,采集到一定长度后进入任务学习阶段,通过任务学习阶段的聚类准确性判断任务学习是否完成,如果完成进入任务识别阶段;
所述任务识别器包括:自动编码器与任务数据缓存;输入数据状态队列S和智能体给出的动作队列计算Q值队列存储在任务数据缓存中;当输入数据状态队列的缓存长度大于一定长度τ时,任务识别器将会进入任务学习阶段;自动编码器,对所述Q值队列进行排序,取数值最高的τ’个Q值对应的输入数据状态并形成第一元组(Smax,1),取数值最低的τ’’个Q值对应的输入数据状态并形成第二元组(Smin,0);将第一元组(Smax,1)和第二元组(Smin,0)合并组成原始数据集 (X,Y) ,所述Y为聚类结果的标签(Y0,Y1),Y0对应的是所述第一元组中的1和Y1对应的是所述第二元组中的0;对原始数据集中的X降维,得到降维特征Z;应用k-mean对所述Z进行聚类,得到聚类中心μ0和μ1;然后求得每个降维特征Z距离聚类中心的概率分布;应用所述概率分布求得任务学习阶段的聚类准确性;判断当任务学习阶段的聚类准确性大于第一阈值时,进入任务识别阶段;当任务学习阶段的聚类准确性小于一定阈值时,任务数据缓存继续收集数据;
在任务识别阶段,任务数据缓存采集任务识别数据状态队列S’,自动编码器应用任务识别数据状态队列S’计算任务识别阶段的聚类准确性,当任务识别阶段的聚类准确性小于第二阈值时,判定任务发生了改变,学习下一个任务;
在学习下一个任务时,如果改变所述DQN网络的参数,自动编码器将会启动额外的惩罚项,通过约束所述DQN网络的参数在下个任务学习中变化实现对当前任务的不遗忘。
CN202110312417.1A 2021-03-24 2021-03-24 在多任务数据流中持续学习的方法及装置 Pending CN112698933A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110312417.1A CN112698933A (zh) 2021-03-24 2021-03-24 在多任务数据流中持续学习的方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110312417.1A CN112698933A (zh) 2021-03-24 2021-03-24 在多任务数据流中持续学习的方法及装置

Publications (1)

Publication Number Publication Date
CN112698933A true CN112698933A (zh) 2021-04-23

Family

ID=75515565

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110312417.1A Pending CN112698933A (zh) 2021-03-24 2021-03-24 在多任务数据流中持续学习的方法及装置

Country Status (1)

Country Link
CN (1) CN112698933A (zh)

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018017546A1 (en) * 2016-07-18 2018-01-25 Google Llc Training machine learning models on multiple machine learning tasks
CN108647789A (zh) * 2018-05-15 2018-10-12 浙江大学 一种基于状态分布感知采样的智能体深度价值函数学习方法
CN109348707A (zh) * 2016-04-27 2019-02-15 纽拉拉股份有限公司 针对基于深度神经网络的q学习修剪经验存储器的方法和装置
CN110728694A (zh) * 2019-10-10 2020-01-24 北京工业大学 一种基于持续学习的长时视觉目标跟踪方法
CN111199241A (zh) * 2019-12-17 2020-05-26 清华大学 任务不可知连续学习场景的元空间聚类学习方法及装置
US20200193226A1 (en) * 2018-12-17 2020-06-18 King Fahd University Of Petroleum And Minerals Enhanced deep reinforcement learning deep q-network models
CN112084330A (zh) * 2020-08-12 2020-12-15 东南大学 一种基于课程规划元学习的增量关系抽取方法

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109348707A (zh) * 2016-04-27 2019-02-15 纽拉拉股份有限公司 针对基于深度神经网络的q学习修剪经验存储器的方法和装置
WO2018017546A1 (en) * 2016-07-18 2018-01-25 Google Llc Training machine learning models on multiple machine learning tasks
CN108647789A (zh) * 2018-05-15 2018-10-12 浙江大学 一种基于状态分布感知采样的智能体深度价值函数学习方法
US20200193226A1 (en) * 2018-12-17 2020-06-18 King Fahd University Of Petroleum And Minerals Enhanced deep reinforcement learning deep q-network models
CN110728694A (zh) * 2019-10-10 2020-01-24 北京工业大学 一种基于持续学习的长时视觉目标跟踪方法
CN111199241A (zh) * 2019-12-17 2020-05-26 清华大学 任务不可知连续学习场景的元空间聚类学习方法及装置
CN112084330A (zh) * 2020-08-12 2020-12-15 东南大学 一种基于课程规划元学习的增量关系抽取方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
GUANXIONG ZENG等: "Continual learning of context-dependent processing in neural networks", 《NATURE MACHINE INTELLIGENCE》 *

Similar Documents

Publication Publication Date Title
CN110287942B (zh) 年龄估计模型的训练方法、年龄估计方法以及对应的装置
CN108694502B (zh) 一种基于XGBoost算法的机器人制造单元自适应调度方法
CN108564136B (zh) 一种基于模糊推理的空域运行态势评估分类方法
CN103544499A (zh) 一种基于机器视觉的表面瑕疵检测的纹理特征降维方法
CN109886342A (zh) 基于机器学习的模型训练方法和装置
CN112465001B (zh) 一种基于逻辑回归的分类方法及装置
CN111931826A (zh) 基于多尺度卷积迁移模型的滚动轴承故障诊断方法和系统
CN112017204A (zh) 一种基于边缘标记图神经网络的刀具状态图像分类方法
CN115452376A (zh) 基于改进轻量级深度卷积神经网络的轴承故障诊断方法
Liu et al. A dual-branch balance saliency model based on discriminative feature for fabric defect detection
CN112698933A (zh) 在多任务数据流中持续学习的方法及装置
CN110320802B (zh) 基于数据可视化的复杂系统信号时序识别方法
CN112183469A (zh) 一种公共交通的拥挤度识别及自适应调整方法、系统、设备及计算机可读存储介质
CN116452950A (zh) 一种基于改进YOLOv5模型的多目标垃圾检测方法
CN115687948A (zh) 一种基于负荷曲线的电力专变用户无监督分类方法
CN114528906A (zh) 一种旋转机械的故障诊断方法、装置、设备和介质
CN115542279A (zh) 一种气象雷达杂波分类识别方法及装置
CN115169660A (zh) 基于多尺度时空特征融合神经网络的刀具磨损预测方法
CN109978038A (zh) 一种集群异常判定方法及装置
CN115640335B (zh) 基于企业画像的企业分析方法、系统及云平台
CN117067042B (zh) 一种研磨机及其控制方法
CN109784477B (zh) 一种用于对比神经网络训练的采样的方法及系统
CN115187386A (zh) 在线信用卡风险信息检测方法及装置
CN115018066A (zh) 一种边端模式下的深度神经网络本地化训练方法
CN113269311A (zh) 一种元素级重激活不重要神经元的神经网络训练方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20210423