CN114282741A - 任务决策方法、装置、设备及存储介质 - Google Patents
任务决策方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN114282741A CN114282741A CN202111107294.4A CN202111107294A CN114282741A CN 114282741 A CN114282741 A CN 114282741A CN 202111107294 A CN202111107294 A CN 202111107294A CN 114282741 A CN114282741 A CN 114282741A
- Authority
- CN
- China
- Prior art keywords
- task
- training
- training sample
- network
- vector
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 88
- 238000003860 storage Methods 0.000 title claims abstract description 30
- 238000012549 training Methods 0.000 claims abstract description 534
- 239000013598 vector Substances 0.000 claims abstract description 268
- 230000009471 action Effects 0.000 claims abstract description 118
- 238000005070 sampling Methods 0.000 claims abstract description 14
- 238000012545 processing Methods 0.000 claims description 41
- 238000006243 chemical reaction Methods 0.000 claims description 38
- 238000005259 measurement Methods 0.000 claims description 32
- 230000008569 process Effects 0.000 claims description 22
- 238000011156 evaluation Methods 0.000 claims description 21
- 230000006870 function Effects 0.000 claims description 21
- 230000006399 behavior Effects 0.000 claims description 13
- 230000002708 enhancing effect Effects 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 6
- 230000001186 cumulative effect Effects 0.000 claims description 4
- 238000007405 data analysis Methods 0.000 claims description 4
- 238000012935 Averaging Methods 0.000 claims description 2
- 238000013473 artificial intelligence Methods 0.000 abstract description 15
- 239000013604 expression vector Substances 0.000 abstract description 6
- 238000005516 engineering process Methods 0.000 description 24
- 230000002787 reinforcement Effects 0.000 description 18
- 238000010801 machine learning Methods 0.000 description 11
- 238000009826 distribution Methods 0.000 description 10
- 238000004422 calculation algorithm Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 7
- 241000282414 Homo sapiens Species 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 5
- 239000003795 chemical substances by application Substances 0.000 description 4
- 238000011160 research Methods 0.000 description 4
- 230000003993 interaction Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 2
- 239000002131 composite material Substances 0.000 description 2
- 238000011217 control strategy Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000012015 optical character recognition Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 230000007704 transition Effects 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013434 data augmentation Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000014509 gene expression Effects 0.000 description 1
- 230000001939 inductive effect Effects 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000008447 perception Effects 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
- 238000005728 strengthening Methods 0.000 description 1
- 230000001360 synchronised effect Effects 0.000 description 1
- 230000002194 synthesizing effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Images
Landscapes
- Image Analysis (AREA)
Abstract
本申请实施例公开了一种任务决策方法、装置、设备及存储介质,属于人工智能技术领域。所述方法包括:从离线数据池中采样获取多个任务分别对应的训练样本集;采用数据增强的方法,得到每个任务分别对应的多个训练样本集;通过任务推断网络生成训练样本集的任务表示向量;采用对比学习的方法,确定任务推断网络的训练损失;基于各个训练样本的状态向量、动作向量和任务表示向量,确定策略网络的训练损失和评判网络的训练损失;基于上述训练损失对任务决策模型进行训练。本申请通过结合对比学习对任务决策模型的任务推断网络进行训练,实现了任务决策模型的实用性和泛化性能的提高。本申请可适用于机器人控制、自动驾驶、智慧农业等场景中。
Description
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种任务决策方法、装置、设备及存储介质。
背景技术
随着人工智能技术的发展,离线强化学习在机器人控制、自动驾驶、游戏智能、智慧农业等场景中得到了应用。其中,任务表示的学习对离线强化学习有着重要影响。
在相关技术中,通过基于不同任务对应的任务表示向量之间的距离度量,对任务决策模型中的任务推断网络进行参数调整。例如,先通过任务推断网络获取不同任务对应的任务表示向量,再采用余弦相似度计算不同任务对应的任务表示向量之间的余弦相似度,最后通过最大化不同任务对应的任务表示向量之间的余弦相似度,对任务推断网络进行参数调整,以使得不同任务对应的任务表示向量之间尽可能地分散。其中,距离度量用于表征任务表示向量之间的差异程度,任务表示向量用于表征任务。
然而,对于同一个任务,一些微小的差异(例如一些不相关特征)即可使得相关技术中的任务推断网络输出不同的任务表示向量,相关技术中的任务推断网络不够鲁棒。
发明内容
本申请实施例提供了一种任务决策方法、装置、设备及存储介质,能够使得任务决策模型学习到更鲁棒的任务表示向量,实现任务决策模型的实用性和泛化性能的提高。技术方案如下:
根据本申请实施例的一个方面,提供了一种任务决策方法,任务决策模型包括任务推断网络、策略网络和评判网络,所述方法包括:
从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于所述任务的多个训练样本;
对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集;
通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失;
基于各个所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的训练损失和所述评判网络的训练损失;
基于所述任务推断网络的训练损失、所述策略网络的训练损失和所述评判网络的训练损失,对所述任务决策模型的参数进行调整,得到完成训练的任务决策模型,所述完成训练的任务决策模型用于多任务决策。
根据本申请实施例的一个方面,提供了一种任务决策装置,任务决策模型包括任务推断网络、策略网络和评判网络,所述装置包括:
样本集获取模块,用于从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于所述任务的多个训练样本;
样本集增强模块,用于对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集;
表示向量生成模块,用于通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
推断损失确定模块,用于根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失;
策略损失确定模块,用于基于各个所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的训练损失和所述评判网络的训练损失;
网络参数调整模块,用于基于所述任务推断网络的训练损失、所述策略网络的训练损失和所述评判网络的训练损失,对所述任务决策模型的参数进行调整,得到完成训练的任务决策模型,所述完成训练的任务决策模型用于多任务决策。
根据本申请实施例的一个方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现上述任务决策方法。
根据本申请实施例的一个方面,提供了一种计算机可读存储介质,所述可读存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现上述任务决策方法。
根据本申请实施例的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述任务决策方法。
本申请实施例提供的技术方案可以带来如下有益效果:
通过根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失,再基于任务推断网络的训练损失对任务推断网络的参数进行调整,使得任务推断网络可以学习到每个任务对应的通用知识,减少了无关知识对任务表示学习的影响,从而使得任务推断网络可以学习到更鲁棒的任务表示向量,进而使得任务决策模型可以更鲁棒地进行多任务决策,以及更鲁棒地迁移至相似任务的决策中,拓展了任务决策模型的使用场景,提高了任务决策模型的实用性和泛化性能。同时,通过采用本申请提供的技术方案,可以有效改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
另外,本申请还结合了数据增强来辅助任务决策模型的训练,通过对训练样本集进行数据增强,为每个任务增加扰动(如无关知识),使得任务推断网络可以学习到更鲁棒的任务表示向量,从而进一步提高了任务决策模型的实用性和泛化性能。同时,通过数据增强,还可以进一步改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个实施例提供的任务决策模型的模型架构图;
图2是本申请一个实施例提供的任务决策方法的流程图;
图3是本申请一个实施例提供的训练样本集针对任务推断网络的子训练损失的确定方法的流程图;
图4是本申请一个实施例提供的任务决策装置的框图;
图5是本申请一个实施例提供的计算机设备的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
计算机视觉技术(Computer Vision,CV)是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别、跟踪和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、OCR(Optical Character Recognition,光学字符识别)、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如常见的智能家居、智能穿戴设备、虚拟助理、智能音箱、智能营销、无人驾驶、自动驾驶、无人机、机器人、智能医疗、智能客服等,相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
本申请实施例提供的方案涉及人工智能的计算机视觉技术和机器学习技术,利用计算机视觉技术和机器学习技术可以从训练样本(如图像数据、传感器数据等)中学习到任务对应的任务表示向量,基于各个任务表示向量对任务决策模型的任务推断网络进行训练,再通过利用机器学习技术基于训练样本的状态向量、动作向量和任务表示向量,对任务决策模型的策略网络和评判网络进行训练。
本申请实施例提供的方法,各步骤的执行主体可以是计算机设备,该计算机设备是指具备数据计算、处理和存储能力的电子设备。该计算机设备可以是诸如PC(PersonalComputer,个人计算机)、平板电脑、智能手机、可穿戴设备、智能机器人、车载设备、智能种植设备等终端;也可以是服务器。其中,服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云计算服务的云服务器。
本申请实施例提供的技术方案,可以被使用在任何需要任务决策(如多任务决策、时序任务决策等)的产品或系统中,诸如机器人、机器人控制系统、自动驾驶系统、游戏智能系统、智慧农业系统等。本申请实施例提供的技术方案能够使得任务决策模型可以更鲁棒地进行多任务决策,以及更鲁棒地迁移至相似任务的决策中,提高任务决策模型的实用性和泛化性能。
示例性地,在机器人控制场景下,任务决策模型用于为机器人提供操作控制策略。通过任务决策模型基于机器人所捕捉的实时图像数据,确定机器人正在执行的任务对应的任务表示向量,进而基于该任务表示向量和实时图像信息,输出针对机器人的输出动作向量,该输出动作向量包括各个可用动作的概率分布,机器人进而可以选择执行最优的动作(即概率值最大的动作),以应对当前环境。
在自动驾驶场景下,任务决策模型用于为自动驾驶车辆提供车辆控制策略。通过任务决策模型基于自动驾驶车辆所捕捉的实时图像数据,确定自动驾驶车辆正在执行的任务对应的任务表示向量,进而基于该任务表示向量和实时图像信息,输出针对自动驾驶车辆的输出动作向量,该输出动作向量包括各个可用动作的概率分布,自动驾驶车辆进而可以选择执行最优的动作(即概率值最大的动作),以应对当前环境。
在智慧农业场景下,任务决策模型用于为智慧农业设备提供种植策略。通过任务决策模型基于智慧农业设备所捕捉的针对于农作物的实时图像数据,确定智慧农业设备正在执行的任务对应的任务表示向量,进而基于该任务表示向量和实时图像信息,输出针对农作物的输出动作向量,该输出动作向量包括各个可用动作的概率分布,智慧农业设备进而可以选择执行最优的动作(即概率值最大的动作)对农作物进行处理。
可选地,对于多个不同任务,可通过同一任务决策模型进行任务决策。例如,在机器人控制场景下,可以通过同一任务决策模型对移动、抓取、操作物体等任务进行任务决策。对于相似的任务,也可以通过同一任务决策模型进行任务决策。例如,在机器人控制场景下,可以将训练用于抓取碗的任务决策模型,对抓取盘子进行任务决策。
在对本申请技术方案进行介绍说明之前,先对本申请涉及的一些名词及术语进行解释说明。
1、离线强化学习(Off-line Reinforcement Learning)
离线强化学习是一类完全从离线的数据中进行学习的强化学习方法,不与环境交互采样,通常这类方法使用动作约束(behavior regularization)来控制在线测试时数据分布与离线数据分布的差异。
2、经验回放
强化学习中离线策略算法的使用的一个技巧,保持一个经验池储存智能体与环境交互的数据,训练策略时,从经验池中采样数据来训练策略网络。经验回放的方式可以重复利用已获得的数据样本,使得离线策略算法的数据利用效率高于在线策略算法。
3、元强化学习(Meta Reinforcement Learning)
元强化学习的目标是学习到具有较强泛化性能的强化学习模型。通过在服从一定分布的多个任务上进行训练,模型能够利用较少的数据量和时间成本,适应或推广到在训练期间从未遇到过的新任务和新环境。
4、离线元强化学习
离线元强化学习作为一种新颖的范式,结合了离线强化学习和元强化学习两大前沿方法的优点,一方面可以完全不依赖与实际环境的交互,并高效、重复地利用已有数据进行训练;同时具备优秀的迁移能力,可以让智能体快速适应新的未知任务,极大地提升了强化学习算法在真实世界中的应用范围和价值。
5、对比学习(Contrastive Learning)
机器学习中用于解决表示学习问题的一种方法。通过对原始数据进行数据增强并构建正负样本集,鼓励神经网络学习同类样本之间的共同特征。对比学习也能够作为强化机器学习问题的附加目标,以引导决策算法学习更好的表示。
6、数据增强(Data Augmentation)
通过对现有数据进行小幅修改或从现有数据中合成新数据来增加数据量的技术。在训练机器学习模型时,它可以充当正则器(Regularizer),帮助减少过拟合(Over-fittinng)现象。在对比学习中,可以用于生成正负样本集合。
请参考图1,其示出了本申请一个实施例提供的任务决策模型的模型架构图。如图1所示,在本申请实施例中,任务决策模型可以包括任务推断网络10、策略网络20和评判网络30。
离线数据池40用于提供多个任务分别对应的训练样本,可以从离线数据池40中采样获取多个任务分别对应的训练样本集,每一个任务对应的训练样本集中可以包括属于该任务的多个训练样本。
任务推断网络10用于生成训练样本集的任务表示向量,该任务表示向量用于表征训练样本集所属的任务。
策略网络20用于基于训练样本的状态向量和任务表示向量,生成该训练样本的动作向量。
评判网络30用于基于训练样本的状态向量、动作向量和任务表示向量,生成相应的评分(即奖励)。
任务推断网络10、策略网络20和评判网络30的训练损失,分别以Lcontrastive、Lactor和Lcritic表示。在模型训练过程中,通过计算上述3个训练损失,据此分别对任务推断网络10、策略网络20和评判网络30的参数进行调整,例如以最小化Lcontrastive、最大化Lactor和Lcritic为目标,进行反向传播梯度下降训练,不断优化各网络的参数,以达到优化训练整个模型的目的。
请参考图2,其示出了本申请一个实施例提供的任务决策方法的流程图,该方法各步骤的执行主体可以计算机设备(如终端或服务器),该方法可以包括如下几个步骤(201~206):
步骤201,从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于该任务的多个训练样本。
离线数据池是指不进行在线数据实时更新的数据池,其可以用于存储多个任务分别对应的训练样本。例如,在机器人场景下,可以将机器人捕捉到的图像数据和该图像数据对应的动作信息进行对应存储。还可以将机器人的传感器收集到的传感器数据和传感器数据对应的动作信息进行对应存储,本申请实施例在此不做限定。本申请实施例对训练样本的形式不做限定,其可以是图像数据,也可以是传感器数据,还可以是文字数据等。上述多个任务可以是指从一定分布的多个任务。
可选地,可以从训练样本中提取样本对象的状态信息(如状态向量),该状态信息可用于表征样本对象的状态。样本对象可以是指任务的执行主体(如机器人),也可以是指任务的执行对象(如农作物)。示例性地,在机器人场景下,机器人实时捕捉到的图像数据可以反应出机器人的状态。例如,捕捉到的图像数据可以反应出机器人正在处于抓起物体的状态。其中,该抓起动作即可为该状态对应的动作信息。
在一个示例中,对于每个任务,可以从任务对应的训练样本中随机抽取出多个训练样本,以形成任务对应的训练样本集。也即在每个批次的训练样本数据集的采样过程中,无需按照训练样本的时间顺序进行采样,如此可以将同一任务对应的训练样本之间的关联性进行分割,使得训练样本之间相对独立,降低训练样本所包含的噪声,从而可以提高任务推断网络的泛化性能。同时,可以对训练样本进行重复采样,从而提高了训练样本的利用效率。
可选地,在从离线数据池中采样获取多个任务分别对应的训练样本集之前,还可以对离线数据池中的训练样本进行数据增强,以扩充离线数据池。示例性地,以训练样本为图像样本为例,其具体过程可以如下:对离线数据集中的第一图像样本进行转换处理,得到处理后的第一图像样本;将处理后的第一图像样本与第一图像样本对应的动作信息进行对应存储,得到第一图像样本对应的处理后的离线数据;其中,转换处理包括以下至少一项:裁剪、翻转、换色、缩放、移位、增加高斯噪声。可选地,考虑到维护离线数据池中的数据的时序性,对于时间上相近的训练样本,可以进行同样的转换处理。
步骤202,对于每一个任务,对任务对应的训练样本集进行数据增强,得到任务对应的多个训练样本集。
对于每一个任务,获取任务对应的训练样本集中包含的训练样本;对训练样本进行转换处理,得到处理后的训练样本;其中,训练样本和处理后的训练样本具有相同的语义特征;基于处理后的训练样本,得到任务对应的增强后训练样本集;其中,任务对应的多个训练样本集,包括任务对应的至少一个增强后训练样本集。相同的语义特征是指任务对应的通用特征(即通用知识),也即虽然训练样本和处理后的训练样本之间具有微小的差异,但还是可以从训练样本或处理后的训练样本从提取出相同的语义特征。
可选地,在只进行一次数据增强的情况下,任务对应的多个训练样本集可以包括任务对应的训练样本集(即原始训练样本集)和一个增强后训练样本集,也即,原始训练样本集和一个增强后训练样本集同为该任务的正样本集。在进行多次数据增强的情况下,任务对应的多个训练样本集可以包括任务对应的训练样本集(即原始训练样本集)和多个增强后训练样本集,也即原始训练样本集和多个增强后训练样本集同为该任务的正样本集。任务对应的多个训练样本集也可以只包括多个增强后训练样本集,也即多个增强后训练样本集为该任务的正样本集。
在一个示例中,在训练样本为图像样本的情况下,该转换处理过程可以如下:对图像样本进行转换处理,得到处理后的图像样本;其中,转换处理包括以下至少一项:裁剪、翻转、换色、缩放、移位、增加高斯噪声。示例性地,在只进行一次数据增强的情况下,对于每一个任务,采用同一转换处理方式对任务对应的训练样本集中的每个图像样本进行转换处理,得到任务对应的训练样本集中的每个图像样本对应的处理后的图像样本,将任务对应的训练样本集中的每个图像样本对应的处理后的图像样本,组合成任务对应的增强后训练样本集。
在进行多次数据增强的情况下,先从同一族转换处理方式中采样得到多个参数不同的转换处理方式,再通过多个参数不同的转换处理方式,分别获取任务对应的多个增强后训练样本集。例如,若转换处理方式为裁剪处理,则多个参数不同的转换方式可以为裁剪尺度不同的裁剪处理。
步骤203,通过任务推断网络生成训练样本集的任务表示向量,该训练样本集的任务表示向量用于表征训练样本集所属的任务。
在本申请实施例中,任务推断网络用于对训练样本对应的任务进行推断,生成训练样本对应的任务表示向量。对于不同的任务,任务推断网络采用不同的任务表示向量进行表示,以实现对不同任务的区分。
任务推断网络的网络结构可以是卷积神经网络,其可以对训练样本(如图像样本、传感器数据等)进行编码,将训练样本映射成该训练样本对应的任务表示向量。可选地,对于同一个任务,其对应的训练样本对应的任务表示向量即为该任务对应的任务表示向量。
可选地,步骤203中的训练样本集是指上述多个任务分别对应的所有训练样本集,其包括上述多个任务分别对应的所有的原训练样本集和增强后训练样本集。
步骤204,根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失。
在本申请实施例中,距离度量用于表征任务表示向量之间的差异程度。本申请实施例对距离度量的计算方式不作限定,任何适用于计算距离度量的方法都可适用,如欧式距离、余弦相似度、曼哈顿距离、切比雪夫距离等算法等。
任务推断网络的训练损失用于表征任务推断网络的表现性能,诸如任务推断网络的推断鲁棒性和准确性。示例性地,假设任务推断网络的训练损失越小,则表明任务推断网络的表现性能越优;那么,在模型训练过程中,通过不断调整任务推断网络的参数,以最小化其训练损失,以达到使得任务推断网络不断优化的目的。
在一个示例中,任务推断网络的训练损失的确定过程可以如下:
1、对于多个任务中的目标任务对应的第一训练样本集,计算第一训练样本集的任务表示向量和目标任务对应的其它训练样本集的任务表示向量之间的距离度量,得到第一距离度量。
可选地,目标任务可以是指多个任务中的任一任务,第一训练样本集可以是指目标任务对应的多个训练样本集中的任一训练样本集。目标任务对应的其它训练样本集是指目标任务对应的多个训练样本集中除去第一训练样本集之后剩余的训练样本集。目标任务对应的多个训练样本集即为目标任务对应的多个正样本集,下文中的其它任务对应的多个训练样本集集为目标任务对应的多个负样本集。
第一距离度量用于表征第一训练样本集对应的任务表示向量与目标任务对应的其它训练样本集的任务表示向量之间的综合差异程度。第一距离度量越小,第一训练样本集对应的任务表示向量与目标任务对应的其它训练样本集的任务表示向量之间的综合差异程度越小。可选地,可以基于第一训练样本集对应的任务表示向量分别与目标任务对应的其它训练样本集的任务表示向量之间的距离度量的和值(或平均值),得到第一距离度量。
2、计算第一训练样本集的任务表示向量和其它任务对应的训练样本集的任务表示向量之间的距离度量,得到第二距离度量。
可选地,分别计算第一训练样本集的任务表示向量与其它任务对应的多个训练样本集的任务表示向量之间的距离度量,再分别对该多个距离度量进行转化(如先对距离度量进行相同的线性调整,再输入相同的指数函数以进转换),得到多个转化后的距离度量,将该多个转化后的距离度量的和值确定为第二距离度量。
其中,其它任务是指上述多个任务中除去目标任务之后剩余的任务。第二距离度量用于表征第一训练样本集的任务表示向量与其它任务对应的训练样本集的任务表示向量之间的综合差异程度。第二距离度量越大,第一训练样本集的任务表示向量与其它任务对应的训练样本集的任务表示向量之间的综合差异程度越大。
3、根据第一距离度量和第二距离度量,确定第一训练样本集针对任务推断网络的子训练损失。
可选地,第一训练样本集针对任务推断网络的子训练损失可以表示如下:l=-log(第一距离度量/第二距离度量),通过最小化第一训练样本集针对任务推断网络的子训练损失,可以将目标任务对应的多个任务表示向量拉近,以使得目标任务最终只对应于一个任务表示向量,从而获取更加鲁棒的任务表示向量,提高了任务推断模型的鲁棒性。同时,还可以将目标任务对应的任务表示向量与其它任务对应的任务表示向量拉远,从而提高不同任务对应的任务表示向量之间的区别程度。
4、根据多个任务的各个训练样本集分别针对任务推断网络的子训练损失,计算得到任务推断网络的训练损失。
在一个示例中,对于目标任务,先根据目标任务对应的多个训练样本集分别针对任务推断网络的子训练损失,计算得到目标任务针对任务推断网络的训练损失。可选地,将目标任务对应的多个训练样本集分别针对任务推断网络的子训练损失的和值,确定为目标任务针对任务推断网络的训练损失。
采用同样的方法,获取多个任务分别针对任务推断网络的训练损失,再根据多个任务分别针对任务推断网络的训练损失,计算得到任务推断网络的训练损失。可选地,可以将多个任务分别针对任务推断网络的训练损失求平均,得到任务推断网络的训练损失。
在一个示例性实施例中,在任务对应的多个训练样本集包括任务对应的多个增强后训练样本集的情况下,任务推断网络的训练损失的获取过程还可以如下:根据同一任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失,也即任务对应的原始训练样本集不参与任务推断网络的训练损失的确定过程。
步骤205,基于各个训练样本的状态向量、动作向量和任务表示向量,确定策略网络的训练损失和评判网络的训练损失。
其中,动作向量用于表征训练样本对应的动作信息。训练样本的任务表示向量是指训练样本对应的任务的任务表示向量。
策略网络用于生成训练样本对应的输出动作向量,该输出动作向量用于指示在该训练样本对应的状态下,可用动作的概率分布。评判网络用于对策略网络的输出(即输出动作向量)进行评判,以确定对于训练样本,训练样本对应的输出动作向量的好坏。
策略网络的训练损失表征策略网络的表现性能,诸如决策的准确性、适用性等。在一个示例中,策略网络的训练损失的获取过程可以如下:
1、对于目标训练样本,基于目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,计算得到目标动作向量对应的期望累计奖励,该目标动作向量对应的期望累计奖励用于表征目标动作向量对应的动作对于目标训练样本所属的任务的价值。
其中,目标训练样本可以是各个训练样本中的任一训练样本。期望累计奖励可用于表征动作向量对应的动作对于训练样本所属的任务的价值。
可选地,强化学习通常可以表示为马尔科夫决策过程(Markov DecisionProcess,MDP),MDP包含了五元组(S,A,R,P,γ),其中,S代表状态空间,A代表动作空间,R代表奖励函数,P代表状态转移概率矩阵,γ代表折扣因子。智能体每个时刻观测到状态st,根据状态执行动作at,环境接收到动作后转移到下一个状态st+1并反馈奖励rt,强化学习优化的目标是最大化累积奖励值智能体根据策略π(at|st)选择动作,动作值函数Q(st,at)代表在状态st执行动作at后的期望累积奖励,其中,E为数学期望。
本申请实施例将任务表示向量加入到策略网络和评判网络的学习过程中,则期望累计奖励可更新表示为Q(s,a,z),也即在任务z下的状态s执行动作a后的期望累计奖励。任务推断网络所学习到的策略可以更新表示为π(s,a,z),也即在任务z下的状态s对应的动作a所服从的策略。其中,z为任务表示向量。
2、通过策略网络,基于目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,获取第一输出动作向量,第一输出动作向量是指策略网络对目标训练样本在执行目标动作向量对应的动作之后的状态,进行动作决策生成的动作向量。
3、获取第一输出动作向量对应的约束期望累计奖励,约束期望累计奖励是指策略网络在行为策略约束下生成的动作向量对应的期望累计奖励,行为策略用于约束策略网络从训练样本中学习策略。
在本申请实施例中,通过引入动作约束来限制行为策略与策略网络所学到的策略之间的差异,由此控制两者采样的数据分布(数据分布与策略相关)。
示例性地,约束期望累计奖励的获取过程可以如下:获取第一输出动作向量对应的期望累计奖励;计算行为策略与策略网络所学到的策略之间的距离度量,得到第三距离度量;基于第三距离度量和第一输出动作向量对应的期望累计奖励,获取第一输出动作向量对应的约束期望累计奖励。
其中,第三距离度量用于表征行为策略与策略网络所学到的策略之间的差异程度。可选地,第一输出动作向量对应的约束期望累计奖励可以表示如下:
QD(s′,a′,z)=Q(s′,a′,z)-αD(πθ(·|s′,z),πb(·|s′,z));
其中,α为可调参数,πθ(·|s′,z)为任务z下的状态s′对应的策略网络所学到的策略,πb(·|s′,z)为任务z下的状态s′对应的行为策略,D为距离度量的计算函数,s′为第一输出动作向量对应的状态向量,a′为第一输出动作向量,z为第一输出动作向量对应的任务表示向量。
4、根据目标动作向量对应的期望累计奖励和第一输出动作向量对应的约束期望累计奖励,确定目标训练样本针对策略网络的训练损失。
可选地,目标训练样本针对策略网络的训练损失可以表示如下:
其中,E为数学期望,s为目标状态向量、a为目标动作向量、r表示奖励,γ为折扣因子,B为离线数据池,Q(s,a,z)为目标动作向量对应的期望累计奖励。
5、根据各个训练样本分别针对策略网络的训练损失,计算得到策略网络的训练损失。
可选地,可以将各个训练样本分别针对策略网络的训练损失的和值(或平均值),确定为策略网络的训练损失。
评判网络的训练损失用于表征评判网络的表现性能,诸如评判的准确性、合理性等。在一个示例中,评判网络的训练损失的获取过程可以如下:
1、对于目标训练样本,通过策略网络,基于目标训练样本的目标状态向量和目标任务表示向量,获取第二输出动作向量,第二输出动作向量是指策略网络对目标训练样本进行动作决策生成的动作向量。
2、基于目标训练样本的目标状态向量和目标任务表示向量,以及第二输出动作向量,计算得到第二输出动作向量对应的期望累计奖励。
3、基于第三距离度量和第二输出动作向量对应的期望累计奖励,确定目标训练样本针对评判网络的训练损失,第三距离度量是指行为策略与策略网络所学到的策略之间的距离度量。
可选地,目标训练样本针对评判网络的训练损失可以表示如下:
其中,a″为第二输出动作向量,Q(s,a″,z)为二输出动作向量对应的期望累计奖励,πθ(·|s,z)为任务z下的状态s对应的策略网络所学到的策略,πb(·|s,z)为任务z下的状态s对应的行为策略。
4、根据各个训练样本分别针对评判网络的训练损失,计算得到评判网络的训练损失。
可选地,可以将各个训练样本分别针对评判网络的训练损失的和值(或平均值),确定为评判网络的训练损失。
步骤206,基于任务推断网络的训练损失、策略网络的训练损失和评判网络的训练损失,对任务决策模型的参数进行调整,得到完成训练的任务决策模型,完成训练的任务决策模型用于多任务决策。
可选地,以最小化任务推断网络的训练损失为目标,进行反向传播梯度下降训练,不断优化任务推断网络的参数,以达到优化训练任务推断网络的目的。以最大化策略网络的训练损失为目标,进行反向传播梯度下降训练,不断优化策略网络的参数,以达到优化训练策略网络的目的。以最大化评判网络的训练损失为目标,进行反向传播梯度下降训练,不断优化评判网络的参数,以达到优化训练评判网络的目的。通过对任务推断网络、策略网络和评判网络进行优化训练,以达到化训练任务决策模型的目的。
综上所述,本申请实施例提供的技术方案,通过根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失,再基于任务推断网络的训练损失对任务推断网络的参数进行调整,使得任务推断网络可以学习到每个任务对应的通用知识,减少了无关知识对任务表示学习的影响,从而使得任务推断网络可以学习到更鲁棒的任务表示向量,进而使得任务决策模型可以更鲁棒地进行多任务决策,以及更鲁棒地迁移至相似任务的决策中,拓展了任务决策模型的使用场景,提高了任务决策模型的实用性和泛化性能。同时,通过采用本申请提供的技术方案,可以有效改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
另外,本申请还结合了数据增强来辅助任务决策模型的训练,通过对训练样本集进行数据增强,为每个任务增加扰动(如无关知识),使得任务推断网络可以学习到更鲁棒的任务表示向量,从而进一步提高了任务决策模型的实用性和泛化性能。同时,通过数据增强,还可以进一步改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
另外,通过对离线数据池中的训练样本进行数据增强,不仅可以扩充离线数据池的数据量,还可以避免数据集过拟合的问题,从而有助于提高任务推断网络的鲁棒性。
另外,本申请通过将任务表示向量加入到策略网络和评判网络的学习中,能够将元强化学习的部分可观测马尔科夫决策过程变为完全可测的马尔科夫决策过程,从而提高了策略网络和评判网络学习的稳定性。
另外,本申请通过将任务推断网络设置为卷积神经网络,并结合对比学习对任务推断网络进行训练,使得任务推断网络可以针对不同表示形式(如图像数据、传感器数据、文字数据等)的输入数据进行任务推断,从而扩展了任务决策模型的使用场景,提高了任务决策模型的适用性和泛化性能。
在一个示例性实施例中,在一个任务对应于两个训练样本集的情况下,可以按照如下公式计算得到任务推断网络的训练损失:
其中, N为任务总数量,sim()为余弦相似度的计算公式,exp()为指数函数,z2k为第2k个训练样本集对应的任务表示向量,z2k-1为第2k-1个训练样本集对应的任务表示向量,zm为第m个训练样本集对应的任务表示向量,τ为可调参数,1为指示函数(若m≠2k-1,2k,则指示函数的输出为1,否则指示函数的输出为0)。
示例性地,该任务推断网络的训练损失的计算过程可以如下:
1、从离线数据池中采样获取N个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于该任务的多个训练样本,N为大于1的整数。
2、对于每一个任务,对任务对应的训练样本集进行数据增强,得到任务对应的两个训练样本集。
可选地,在只进行一次数据增强的情况下,该两个训练样本集包括原训练样本集和原训练样本集对应的增强后训练样本集。在进行两次数据增强的情况下,该两个训练样本集包括原训练样本集对应的两个增强后训练样本集。下文将以两个训练样本集包括原训练样本集对应的两个增强后训练样本集为例进行说明,其具体内容可以如下:
从同一族转换处理方式中采样得到第一转换方式和第二转换方式,第一转换方式和第二转换方式是指参数不同的同族转换处理方式。对于目标训练样本集,分别通过第一转换方式和第二转换方式对其进行转换处理,得到目标训练样本集对应的两个增强后训练样本集,并将该两个增强后训练样本集确定为目标训练样本集对应的正样本集。然后依次获取其它训练样本集分别对应的两个增强后训练样本集,并将所有其它训练样本集分别对应的两个增强后训练样本集确定为目标训练样本集对应的负样本集。
示例性地,参考图3,T为同一族转换处理方式的分布,从T中分别采样得到第一转换处理方式t和第二转换处理方式t′,对于目标训练样本集B,分别通过第一转换方式t和第二转换方式t′对其进行转换处理,该过程转换处理过程可以表示如下:
其中,为B在第一转换处理方式下的增强后训练样本集,为B在第二转换处理方式下的增强后训练样本集。和为目标训练样本集对应的正样本集,N-1个其它原始训练样本集对应的2N-2个增强后训练样本集为目标训练样本集对应的负样本集。
3、通过任务推断网络生成所有训练样本集(包括N个原始训练样本集和2N个增强后训练样本集)的任务表示向量,该任务表示向量用于表征训练样本集所属的任务。例如,对于其对应的任务表示向量可以表示如下:其中,φi为任务推断网络的网络参数,E为编码器。通过同样的方法得到对应的任务表示向量zj。
4、根据同一任务对应的2个正样本集的任务表示向量之间的距离度量,以及每个正样本集分别与2N-2个负样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失。
示例性地,对于目标任务对应的目标训练样本集B,先获取针对于任务推断网络的子训练损失和针对于任务推断网络的子训练损失,再基于针对于任务推断网络的子训练损失和针对于任务推断网络的子训练损失,确定目标任务针对于任务推断网络的训练损失。
可选地,将1(j,i)和1(j,i)的和值,确定为目标任务针对于任务推断网络的训练损失。
获取N个任务分别针对任务推断网络的训练损失,并将N个任务分别针对任务推断网络的训练损失的平均值,确定为任务推断网络的训练损失。该任务推断网络的训练损失可以表示为:N个任务分别针对任务推断网络的训练损失的和值/2N。
在一个可行的示例中,将1(j,i)和1(j,i)的平均值,确定为目标任务针对于任务推断网络的训练损失,则任务推断网络的训练损失可以表示为:N个任务分别针对任务推断网络的训练损失的和值/N。
综上所述,本申请实施例提供的技术方案,通过根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失,再基于任务推断网络的训练损失对任务推断网络的参数进行调整,使得任务推断网络可以学习到每个任务对应的通用知识,减少了无关知识对任务表示学习的影响,从而使得任务推断网络可以学习到更鲁棒的任务表示向量,进而使得任务决策模型可以更鲁棒地进行多任务决策,以及更鲁棒地迁移至相似任务的决策中,拓展了任务决策模型的使用场景,提高了任务决策模型的实用性和泛化性能。同时,通过采用本申请提供的技术方案,可以有效改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
另外,本申请还结合了数据增强来辅助任务决策模型的训练,通过对训练样本集进行数据增强,为每个任务增加扰动(如无关知识),使得任务推断网络可以学习到更鲁棒的任务表示向量,从而进一步提高了任务决策模型的实用性和泛化性能。同时,通过数据增强,还可以进一步改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图4,其示出了本申请一个实施例提供的信息获取装置的框图。该装置具有实现上述方法示例的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是计算机设备,也可以设置在计算机设备中。该装置400可以包括:样本集获取模块401、样本集增强模块402、表示向量生成模块403、推断损失确定模块404、策略损失确定模块405和网络参数调整模块406。
样本集获取模块401,用于从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于所述任务的多个训练样本。
样本集增强模块402,用于对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集。
表示向量生成模块403,用于通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务。
推断损失确定模块404,用于根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失。
策略损失确定模块405,用于基于各个所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的训练损失和所述评判网络的训练损失。
网络参数调整模块406,用于基于所述任务推断网络的训练损失、所述策略网络的训练损失和所述评判网络的训练损失,对所述任务决策模型的参数进行调整,得到完成训练的任务决策模型,所述完成训练的任务决策模型用于多任务决策。
在一个示例性实施例中,所述推断损失确定模块404,用于:
对于所述多个任务中的目标任务对应的第一训练样本集,计算所述第一训练样本集的任务表示向量和所述目标任务对应的其它训练样本集的任务表示向量之间的距离度量,得到第一距离度量;
计算所述第一训练样本集的任务表示向量和其它任务对应的训练样本集的任务表示向量之间的距离度量,得到第二距离度量;
根据所述第一距离度量和所述第二距离度量,确定所述第一训练样本集针对所述任务推断网络的子训练损失;
根据所述多个任务的各个训练样本集分别针对所述任务推断网络的子训练损失,计算得到所述任务推断网络的训练损失。
在一个示例性实施例中,所述推断损失确定模块404,还用于:
根据所述目标任务对应的多个训练样本集分别针对所述任务推断网络的子训练损失,计算得到所述目标任务针对所述任务推断网络的训练损失;
根据所述多个任务分别针对所述任务推断网络的训练损失,计算得到所述任务推断网络的训练损失。
在一个示例性实施例中,所述推断损失确定模块404还用于将所述多个任务分别针对所述任务推断网络的训练损失求平均,得到所述任务推断网络的训练损失。
在一个示例性实施例中,同一个任务对应于两个训练样本集;所述推断损失确定模块404还用于按照如下公式计算得到所述任务推断网络的训练损失:
按照如下公式计算得到所述任务推断网络的训练损失:
其中, N为任务总数量,sim()为余弦相似度的计算公式,exp()为指数函数,z2k为第2k个所述训练样本集对应的任务表示向量,z2k-1为第2k-1个所述训练样本集对应的任务表示向量,zm为第m个所述训练样本集对应的任务表示向量,τ为可调参数,1为指示函数。
在一个示例性实施例中,所述样本集增强模块402,用于:
对于每一个任务,获取所述任务对应的训练样本集中包含的训练样本;
对所述训练样本进行转换处理,得到处理后的训练样本;其中,所述训练样本和所述处理后的训练样本具有相同的语义特征;
基于所述处理后的训练样本,得到所述任务对应的增强后训练样本集;
其中,所述任务对应的多个训练样本集,包括所述任务对应的至少一个增强后训练样本集。
在一个示例性实施例中,所述任务对应的多个训练样本集,包括所述任务对应的多个增强后训练样本集;所述样本集增强模块402,还用于根据同一任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失。
在一个示例性实施例中,所述训练样本为图像样本;所述样本集增强模块402还用于对所述图像样本进行转换处理,得到处理后的图像样本;其中,所述转换处理包括以下至少一项:裁剪、翻转、换色、缩放、移位、增加高斯噪声。
在一个示例性实施例中,所述策略损失确定模块405,用于:
对于目标训练样本,基于所述目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,计算得到所述目标动作向量对应的期望累计奖励,所述目标动作向量对应的期望累计奖励用于表征所述目标动作向量对应的动作对于所述目标训练样本所属的任务的价值;
通过所述策略网络,基于所述目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,获取第一输出动作向量,所述第一输出动作向量是指所述策略网络对所述目标训练样本在执行所述目标动作向量对应的动作之后的状态,进行动作决策生成的动作向量;
获取所述第一输出动作向量对应的约束期望累计奖励,所述约束期望累计奖励是指所述策略网络在行为策略约束下生成的动作向量对应的期望累计奖励,所述行为策略用于约束所述策略网络从所述训练样本中学习策略;
根据所述目标动作向量对应的期望累计奖励和所述第一输出动作向量对应的约束期望累计奖励,确定所述目标训练样本针对所述策略网络的训练损失;
根据所述各个训练样本分别针对所述策略网络的训练损失,计算得到所述策略网络的训练损失。
在一个示例性实施例中,所述策略损失确定模块405,还用于:
获取所述第一输出动作向量对应的期望累计奖励;
计算所述行为策略与所述策略网络所学到的策略之间的距离度量,得到第三距离度量;
基于所述第三距离度量和所述第一输出动作向量对应的期望累计奖励,获取所述第一输出动作向量对应的约束期望累计奖励。
在一个示例性实施例中,所述策略损失确定模块405,还用于:
对于目标训练样本,通过所述策略网络,基于所述目标训练样本的目标状态向量和目标任务表示向量,获取第二输出动作向量,所述第二输出动作向量是指所述策略网络对所述目标训练样本进行动作决策生成的动作向量;
基于所述目标训练样本的目标状态向量和目标任务表示向量,以及所述第二输出动作向量,计算得到所述第二输出动作向量对应的期望累计奖励;
基于第三距离度量和所述第二输出动作向量对应的期望累计奖励,确定所述目标训练样本针对所述评判网络的训练损失,所述第三距离度量是指行为策略与所述策略网络所学到的策略之间的距离度量;
根据所述各个训练样本分别针对所述评判网络的训练损失,计算得到所述评判网络的训练损失。
综上所述,本申请实施例提供的技术方案,通过根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定任务推断网络的训练损失,再基于任务推断网络的训练损失对任务推断网络的参数进行调整,使得任务推断网络可以学习到每个任务对应的通用知识,减少了无关知识对任务表示学习的影响,从而使得任务推断网络可以学习到更鲁棒的任务表示向量,进而使得任务决策模型可以更鲁棒地进行多任务决策,以及更鲁棒地迁移至相似任务的决策中,拓展了任务决策模型的使用场景,提高了任务决策模型的实用性和泛化性能。同时,通过采用本申请提供的技术方案,可以有效改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
另外,本申请还结合了数据增强来辅助任务决策模型的训练,通过对训练样本集进行数据增强,为每个任务增加扰动(如无关知识),使得任务推断网络可以学习到更鲁棒的任务表示向量,从而进一步提高了任务决策模型的实用性和泛化性能。同时,通过数据增强,还可以进一步改善数据集过拟合的问题,从而进一步提高了任务决策模型的泛化性能。
需要说明的是,上述实施例提供的装置,在实现其功能时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内容结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的装置与方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
请参考图5,其示出了本申请一个实施例提供的计算机设备的结构框图。该计算机设备可以用于实施上述实施例中提供的任务决策方法。具体来讲:
该计算机设备500包括中央处理单元(如CPU(Central Processing Unit,中央处理器)、GPU(Graphics Processing Unit,图形处理器)和FPGA(Field Programmable GateArray,现场可编程逻辑门阵列)等)501、包括RAM(Random-Access Memory,随机存取存储器)502和ROM(Read-Only Memory,只读存储器)503的系统存储器504,以及连接系统存储器504和中央处理单元501的系统总线505。该计算机设备500还包括帮助服务器内的各个器件之间传输信息的基本输入/输出系统(Input Output System,I/O系统)506,和用于存储操作系统513、应用程序514和其他程序模块515的大容量存储设备507。
该基本输入/输出系统506包括有用于显示信息的显示器508和用于用户输入信息的诸如鼠标、键盘之类的输入设备509。其中,该显示器508和输入设备509都通过连接到系统总线505的输入输出控制器510连接到中央处理单元501。该基本输入/输出系统506还可以包括输入输出控制器510以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器510还提供输出到显示屏、打印机或其他类型的输出设备。
该大容量存储设备507通过连接到系统总线505的大容量存储控制器(未示出)连接到中央处理单元501。该大容量存储设备507及其相关联的计算机可读介质为计算机设备500提供非易失性存储。也就是说,该大容量存储设备507可以包括诸如硬盘或者CD-ROM(Compact Disc Read-Only Memory,只读光盘)驱动器之类的计算机可读介质(未示出)。
不失一般性,该计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、EPROM(Erasable Programmable Read-Only Memory,可擦写可编程只读存储器)、EEPROM(Electrically Erasable Programmable Read-Only Memory,电可擦写可编程只读存储器)、闪存或其他固态存储其技术,CD-ROM、DVD(Digital Video Disc,高密度数字视频光盘)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知该计算机存储介质不局限于上述几种。上述的系统存储器504和大容量存储设备507可以统称为存储器。
根据本申请实施例,该计算机设备500还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即计算机设备500可以通过连接在该系统总线505上的网络接口单元511连接到网络512,或者说,也可以使用网络接口单元511来连接到其他类型的网络或远程计算机系统(未示出)。
所述存储器还包括至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集存储于存储器中,且经配置以由一个或者一个以上处理器执行,以实现上述任务决策方法。
在一个示例性实施例中,还提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集在被处理器执行时以实现上述任务决策方法。
可选地,该计算机可读存储介质可以包括:ROM(Read-Only Memory,只读存储器)、RAM(Random-Access Memory,随机存储器)、SSD(Solid State Drives,固态硬盘)或光盘等。其中,随机存取记忆体可以包括ReRAM(Resistance Random Access Memory,电阻式随机存取记忆体)和DRAM(Dynamic Random Access Memory,动态随机存取存储器)。
在一个示例性实施例中,还提供了一种计算机程序产品或计算机程序,所述计算机程序产品或计算机程序包括计算机指令,所述计算机指令存储在计算机可读存储介质中。计算机设备的处理器从所述计算机可读存储介质中读取所述计算机指令,所述处理器执行所述计算机指令,使得所述计算机设备执行上述任务决策方法。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。另外,本文中描述的步骤编号,仅示例性示出了步骤间的一种可能的执行先后顺序,在一些其它实施例中,上述步骤也可以不按照编号顺序来执行,如两个不同编号的步骤同时执行,或者两个不同编号的步骤按照与图示相反的顺序执行,本申请实施例对此不作限定。
以上所述仅为本申请的示例性实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (15)
1.一种任务决策方法,其特征在于,任务决策模型包括任务推断网络、策略网络和评判网络,所述方法包括:
从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于所述任务的多个训练样本;
对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集;
通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失;
基于各个所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的训练损失和所述评判网络的训练损失;
基于所述任务推断网络的训练损失、所述策略网络的训练损失和所述评判网络的训练损失,对所述任务决策模型的参数进行调整,得到完成训练的任务决策模型,所述完成训练的任务决策模型用于多任务决策。
2.根据权利要求1所述的方法,其特征在于,所述根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失,包括:
对于所述多个任务中的目标任务对应的第一训练样本集,计算所述第一训练样本集的任务表示向量和所述目标任务对应的其它训练样本集的任务表示向量之间的距离度量,得到第一距离度量;
计算所述第一训练样本集的任务表示向量和其它任务对应的训练样本集的任务表示向量之间的距离度量,得到第二距离度量;
根据所述第一距离度量和所述第二距离度量,确定所述第一训练样本集针对所述任务推断网络的子训练损失;
根据所述多个任务的各个训练样本集分别针对所述任务推断网络的子训练损失,计算得到所述任务推断网络的训练损失。
3.根据权利要求2所述的方法,其特征在于,所述根据所述多个任务的各个训练样本集分别针对所述任务推断网络的子训练损失,计算得到所述任务推断网络的训练损失,包括:
根据所述目标任务对应的多个训练样本集分别针对所述任务推断网络的子训练损失,计算得到所述目标任务针对所述任务推断网络的训练损失;
根据所述多个任务分别针对所述任务推断网络的训练损失,计算得到所述任务推断网络的训练损失。
4.根据权利要求3所述的方法,其特征在于,所述根据所述多个任务分别针对所述任务推断网络的训练损失,计算得到所述任务推断网络的训练损失,包括:
将所述多个任务分别针对所述任务推断网络的训练损失求平均,得到所述任务推断网络的训练损失。
6.根据权利要求1所述的方法,其特征在于,所述对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集,包括:
对于每一个任务,获取所述任务对应的训练样本集中包含的训练样本;
对所述训练样本进行转换处理,得到处理后的训练样本;其中,所述训练样本和所述处理后的训练样本具有相同的语义特征;
基于所述处理后的训练样本,得到所述任务对应的增强后训练样本集;
其中,所述任务对应的多个训练样本集,包括所述任务对应的至少一个增强后训练样本集。
7.根据权利要求6所述的方法,其特征在于,所述任务对应的多个训练样本集,包括所述任务对应的多个增强后训练样本集;
所述根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失,包括:
根据同一任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个增强后训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失。
8.根据权利要求6所述的方法,其特征在于,所述训练样本为图像样本;
所述对所述训练样本进行转换处理,得到处理后的训练样本,包括:
对所述图像样本进行转换处理,得到处理后的图像样本;
其中,所述转换处理包括以下至少一项:裁剪、翻转、换色、缩放、移位、增加高斯噪声。
9.根据权利要求1所述的方法,其特征在于,所述确定所述策略网络的训练损失,包括:
对于目标训练样本,基于所述目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,计算得到所述目标动作向量对应的期望累计奖励,所述目标动作向量对应的期望累计奖励用于表征所述目标动作向量对应的动作对于所述目标训练样本所属的任务的价值;
通过所述策略网络,基于所述目标训练样本的目标状态向量、目标动作向量和目标任务表示向量,获取第一输出动作向量,所述第一输出动作向量是指所述策略网络对所述目标训练样本在执行所述目标动作向量对应的动作之后的状态,进行动作决策生成的动作向量;
获取所述第一输出动作向量对应的约束期望累计奖励,所述约束期望累计奖励是指所述策略网络在行为策略约束下生成的动作向量对应的期望累计奖励,所述行为策略用于约束所述策略网络从所述训练样本中学习策略;
根据所述目标动作向量对应的期望累计奖励和所述第一输出动作向量对应的约束期望累计奖励,确定所述目标训练样本针对所述策略网络的训练损失;
根据所述各个训练样本分别针对所述策略网络的训练损失,计算得到所述策略网络的训练损失。
10.根据权利要求9所述的方法,其特征在于,所述获取所述第一输出动作向量对应的约束期望累计奖励,包括:
获取所述第一输出动作向量对应的期望累计奖励;
计算所述行为策略与所述策略网络所学到的策略之间的距离度量,得到第三距离度量;
基于所述第三距离度量和所述第一输出动作向量对应的期望累计奖励,获取所述第一输出动作向量对应的约束期望累计奖励。
11.根据权利要求1所述的方法,其特征在于,所述确定所述评判网络的训练损失,包括:
对于目标训练样本,通过所述策略网络,基于所述目标训练样本的目标状态向量和目标任务表示向量,获取第二输出动作向量,所述第二输出动作向量是指所述策略网络对所述目标训练样本进行动作决策生成的动作向量;
基于所述目标训练样本的目标状态向量和目标任务表示向量,以及所述第二输出动作向量,计算得到所述第二输出动作向量对应的期望累计奖励;
基于第三距离度量和所述第二输出动作向量对应的期望累计奖励,确定所述目标训练样本针对所述评判网络的训练损失,所述第三距离度量是指行为策略与所述策略网络所学到的策略之间的距离度量;
根据所述各个训练样本分别针对所述评判网络的训练损失,计算得到所述评判网络的训练损失。
12.一种任务决策装置,其特征在于,任务决策模型包括任务推断网络、策略网络和评判网络,所述装置包括:
样本集获取模块,用于从离线数据池中采样获取多个任务分别对应的训练样本集,其中,每一个任务对应的训练样本集中包括属于所述任务的多个训练样本;
样本集增强模块,用于对于每一个任务,对所述任务对应的训练样本集进行数据增强,得到所述任务对应的多个训练样本集;
表示向量生成模块,用于通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
推断损失确定模块,用于根据同一任务对应的多个训练样本集的任务表示向量之间的距离度量,以及不同任务对应的多个训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的训练损失;
策略损失确定模块,用于基于各个所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的训练损失和所述评判网络的训练损失;
网络参数调整模块,用于基于所述任务推断网络的训练损失、所述策略网络的训练损失和所述评判网络的训练损失,对所述任务决策模型的参数进行调整,得到完成训练的任务决策模型,所述完成训练的任务决策模型用于多任务决策。
13.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如权利要求1至11任一项所述的任务决策方法。
14.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1至11任一项所述的任务决策方法。
15.一种计算机程序产品或计算机程序,其特征在于,所述计算机程序产品或计算机程序包括计算机指令,所述计算机指令存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行所述计算机指令,以实现如权利要求1至11任一项所述的任务决策方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111107294.4A CN114282741A (zh) | 2021-09-22 | 2021-09-22 | 任务决策方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111107294.4A CN114282741A (zh) | 2021-09-22 | 2021-09-22 | 任务决策方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114282741A true CN114282741A (zh) | 2022-04-05 |
Family
ID=80868548
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111107294.4A Pending CN114282741A (zh) | 2021-09-22 | 2021-09-22 | 任务决策方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114282741A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116983656A (zh) * | 2023-09-28 | 2023-11-03 | 腾讯科技(深圳)有限公司 | 决策模型的训练方法、装置、设备及存储介质 |
-
2021
- 2021-09-22 CN CN202111107294.4A patent/CN114282741A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116983656A (zh) * | 2023-09-28 | 2023-11-03 | 腾讯科技(深圳)有限公司 | 决策模型的训练方法、装置、设备及存储介质 |
CN116983656B (zh) * | 2023-09-28 | 2023-12-26 | 腾讯科技(深圳)有限公司 | 决策模型的训练方法、装置、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Segu et al. | Batch normalization embeddings for deep domain generalization | |
Arulkumaran et al. | Deep reinforcement learning: A brief survey | |
CN110383299B (zh) | 记忆增强的生成时间模型 | |
CN110796111B (zh) | 图像处理方法、装置、设备及存储介质 | |
CN112348113B (zh) | 离线元强化学习模型的训练方法、装置、设备及存储介质 | |
CN111507378A (zh) | 训练图像处理模型的方法和装置 | |
Ma et al. | PID controller-guided attention neural network learning for fast and effective real photographs denoising | |
CN111462010A (zh) | 图像处理模型的训练方法、图像处理方法、装置及设备 | |
CN112116090A (zh) | 神经网络结构搜索方法、装置、计算机设备及存储介质 | |
CN113538441A (zh) | 图像分割模型的处理方法、图像处理方法及装置 | |
CN112258625B (zh) | 基于注意力机制的单幅图像到三维点云模型重建方法及系统 | |
CN116958324A (zh) | 图像生成模型的训练方法、装置、设备及存储介质 | |
CN111282272A (zh) | 信息处理方法、计算机可读介质及电子设备 | |
CN114358250A (zh) | 数据处理方法、装置、计算机设备、介质及程序产品 | |
CN114282741A (zh) | 任务决策方法、装置、设备及存储介质 | |
CN111611852A (zh) | 一种表情识别模型的训练方法、装置及设备 | |
CN110533749B (zh) | 一种动态纹理视频生成方法、装置、服务器及存储介质 | |
WO2022127603A1 (zh) | 一种模型处理方法及相关装置 | |
CN113096206B (zh) | 基于注意力机制网络的人脸生成方法、装置、设备及介质 | |
CN115212549A (zh) | 一种对抗场景下的对手模型构建方法及存储介质 | |
CN113821615A (zh) | 自助对话方法、装置、设备及存储介质 | |
CN116563450A (zh) | 表情迁移方法、模型训练方法和装置 | |
CN113706650A (zh) | 一种基于注意力机制和流模型的图像生成方法 | |
CN113822293A (zh) | 用于图数据的模型处理方法、装置、设备及存储介质 | |
CN116612341B (zh) | 用于对象计数的图像处理方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
REG | Reference to a national code |
Ref country code: HK Ref legal event code: DE Ref document number: 40067081 Country of ref document: HK |
|
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |