CN113505210A - 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 - Google Patents
一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 Download PDFInfo
- Publication number
- CN113505210A CN113505210A CN202110782860.5A CN202110782860A CN113505210A CN 113505210 A CN113505210 A CN 113505210A CN 202110782860 A CN202110782860 A CN 202110782860A CN 113505210 A CN113505210 A CN 113505210A
- Authority
- CN
- China
- Prior art keywords
- network
- actor
- critic
- generator
- sub
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 claims abstract description 53
- 238000012549 training Methods 0.000 claims abstract description 52
- 238000007476 Maximum Likelihood Methods 0.000 claims abstract description 6
- 238000003745 diagnosis Methods 0.000 claims abstract description 5
- 230000007787 long-term memory Effects 0.000 claims abstract description 5
- 230000006403 short-term memory Effects 0.000 claims abstract description 5
- 230000006870 function Effects 0.000 claims description 65
- 238000013138 pruning Methods 0.000 claims description 24
- 230000009471 action Effects 0.000 claims description 19
- 230000008569 process Effects 0.000 claims description 19
- 238000005516 engineering process Methods 0.000 claims description 15
- 210000002569 neuron Anatomy 0.000 claims description 14
- 239000011159 matrix material Substances 0.000 claims description 12
- 238000005457 optimization Methods 0.000 claims description 12
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 claims description 11
- 238000000354 decomposition reaction Methods 0.000 claims description 10
- 239000003795 chemical substances by application Substances 0.000 claims description 8
- 230000004913 activation Effects 0.000 claims description 7
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000011217 control strategy Methods 0.000 claims description 3
- 238000013527 convolutional neural network Methods 0.000 claims description 3
- 230000000306 recurrent effect Effects 0.000 claims description 3
- 230000007704 transition Effects 0.000 claims description 3
- 238000009423 ventilation Methods 0.000 claims description 3
- 238000012545 processing Methods 0.000 abstract description 9
- 230000003042 antagnostic effect Effects 0.000 abstract description 2
- 230000002787 reinforcement Effects 0.000 description 9
- 238000004364 calculation method Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 4
- 238000010276 construction Methods 0.000 description 3
- 238000011160 research Methods 0.000 description 3
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000012937 correction Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000015654 memory Effects 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000008447 perception Effects 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 1
- 230000001186 cumulative effect Effects 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000002068 genetic effect Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 230000002265 prevention Effects 0.000 description 1
- 238000003908 quality control method Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/33—Querying
- G06F16/332—Query formulation
- G06F16/3329—Natural language query formulation or dialogue systems
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
- G06F16/353—Clustering; Classification into predefined classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H50/00—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
- G16H50/20—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for computer-aided diagnosis, e.g. based on medical expert systems
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Public Health (AREA)
- Probability & Statistics with Applications (AREA)
- Human Computer Interaction (AREA)
- Pathology (AREA)
- Epidemiology (AREA)
- Primary Health Care (AREA)
- Computer And Data Communications (AREA)
Abstract
本发明公开了一种基于轻量化Actor‑Critic生成式对抗网络的医疗问答生成系统,包括轻量化Actor‑Critic结构的生成器和判别器,系统输入用户提出的医疗问题后,生成器通过编码‑解码的方式生成医疗诊断方案;已知的医疗问答文本作为数据集,输入到生成器中并采用极大似然估计方法进行预训练,再把预训练好的生成器生成的数据分布作为假样本,已知的数据作为真样本,输入到判别器网络进行预训练。预训练生成器和判别器后,复用生成器为Actor网络并构建结构为长短期记忆网络的Critic网络,采用Actor‑Critic算法对生成器网络权值参数进行更新,并与判别器进行对抗训练,同时采用基于组MCP正则项多路径多层Actor和Critic网络轻量化方法对网络进行轻量化处理。
Description
技术领域
本发明涉及强化学习领域和自然语言处理领域,具体涉及一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统。
背景技术
目前,由于医疗资源短缺且分布不均衡等问题给医院带来严峻的运营压力、医患关系紧张等问题。随着移动互联网技术的发展,医疗行业信息化受到越来越多企业和国家重视。医疗问答系统被广泛应用于医疗行业,其通过网络整合不同地域间的医疗资源,以获得高质量医疗服务效率,同时缓解医生的工作压力。问答生成是问答系统的一种实现方式,是近年在人工智能和自然语言处理领域中具有广阔前景的研究方向,其能准确、简洁地生成用户用自然语言提出的问题的答案,可将其应用于实现医疗问答生成系统。目前,医疗领域缺乏有效的信息质量管控机制,且医疗问答信息数据量受限,因此,整合利用有限的知识库的问答数据,开发出能生成准确、专业医疗诊断方案的医疗生成系统显得尤为重要。
深度强化学习(Deep Reinforcement Learning)将强化学习的决策能力和深度学习的感知能力两者结合,可用于解决系统复杂的决策感知问题。近年来其在机器学习、自动控制和人工智能等领域受到广泛的关注和研究,展示其优越的适应性和学习能力,Actor-Critic结构算法是深度强化学习主流方法之一。其中,Actor网络负责逼近连续的策略空间,Critic网络负责评价Actor网络选择策略的好坏。Actor-Critic结构算法不仅可以解决连续空间控制问题,且可实现单步更新,效率更高,已被广泛应用于各重要领域。
生成式对抗网络(Generative Adversarial Networks,GAN)是一种基于零和博弈的深度学习模型,由生成器(Generator)和判别器(Discriminator)组成。判别器用于判别输入样本为真实样本的概率;生成器用于接收输入变量生成逼真的样本,但传统的GAN中的生成器仅适用于生成连续型的数据。Lantao Yu提出的SeqGAN模型(Sequence GenerativeAdversarial Nets)结合了GAN和策略梯度(Policy Gradient)的强化学习方法使GAN得以应用于离散数据,为医疗问答生成系统的实现提供一种新思路。但SeqGAN模型的问答生成质量远达不到预期要求。对此,选用Actor-Critic结构深度强化学习取代已有的策略梯度方法,用于缓解问答生成系统训练不稳定问题。
但Actor-Critic结构深度强化学习通过增加网络深度和神经元个数得到优异性能的同时,不可避免引入了庞大的参数量和计算量,导致其对存储资源和计算资源提出很高要求,难以部署到现有资源受限的移动设备上。因此,实现Actor-Critic结构的深度强化学习网络轻量化,并应用于医疗问答生成系统,使其能部署到计算资源有限的设备上,在深度强化学习领域中是一个迫切需要解决的问题。
发明内容
本发明的目的是提供一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,缓解医疗资源短缺且分布不均衡给医院带来严峻的运营压力、医患关系紧张等问题。
为了实现上述任务,本发明采用以下技术方案:
一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,包括:
轻量化Actor-Critic结构的生成器和判别器;
输入用户提出的医疗问题后,生成器通过编码-解码的方式生成医疗诊断方案作为系统的输出;
其中,所述轻量化Actor-Critic结构的生成器和判别器的训练过程为:
首先,构建生成器和判别器网络,并进行预训练;将已知的医疗问答文本构建的数据集,输入到生成器中并采用极大似然估计方法进行预训练,再把预训练好的生成器生成的问答样本作为假样本,已知的问答样本作为真样本,输入到判别器网络进行预训练;
其次,预训练生成器和判别器后,复用生成器网络为Actor,并构建结构为长短期记忆网络的Critic网络,训练生成器网络和Critic网络,同时,生成器网络与判别器网络进行对抗训练,并在生成器网络进行多次训练后更新一次判别器网络参数;
最后,对生成器网络和Critic网络进行剪枝处理。
进一步地,判别器网络为二分类器,结构为卷积神经网络;将生成器网络输出的问答样本作为假样本,判别器网络使用生成式对抗网络的损失函数计算判别器的损失函数,更新判别器参数;生成器网络采用Seq2Seq模型的问答结构,其网络结构由递归神经网络和全连接层组成,包括编码器和解码器,通过编码-解码过程生成医疗问答;其中,编码器将输入的自然语言医疗问题映射成词向量表示,解码器将映射成的词向量作为输入,然后通过全连接层输出词库里每个单词的概率值,最后通过的激活函数为Softmax。
进一步地,首先,分别将整体的Actor和Critic网络进行多路径多层结构化处理;其次,构建轻量化Actor-Critic网络目标函数:分别使用时间差分法和策略梯度法构建Critic网络Actor网络的目标函数,并在两者的目标函数中,对子路径间的权值参数采用非凸组MCP正则项进行组间稀疏约束,组内的权值参数采用L2范数进行特征组选择;然后,针对Actor和Critic网络目标函数中组MCP正则项的非凸函数优化难点,采用DC分解技术将非凸优化问题转换为凸优化问题进行求解,并使用Adam算法更新Critic网络和Actor网络权值参数;最后,分别对多路径多层结构化的Critic网络和Actor网络中整体权重值较小的子路径进行剪枝,从而缓解Actor-Critic网络的权值参数冗余问题,实现基于非凸组MCP正则项多路径多层Actor-Critic网络轻量化。
进一步地,所述将整体的Actor和Critic网络进行多路径多层结构化处理,包括:
分别将整体的Actor和Critic网络按每层隐藏层神经元个数均分成n条并行的子路径,每条子路径定义为一组;每条子路径的隐藏层神经元个数相等,为原来完整网络隐藏层神经元个数的输入层和输出层的神经元个数与原来整体网络相同;各子路径的输入为原来整体的Actor和Critic网络的输入,子路径的输出在最后一层聚合,并通过激活函数得到最终输出;在进行多次的迭代更新后,若多路径多层结构化的Critic或Actor网络的子路径数量大于某个阈值,即对其进行剪枝,若某条子路径权值参数的期望小于某个阈值,即移除该子路径,并更新Critic或Actor网络。
进一步地,所述Critic网络目标函数表示为:
其中,V(S,W)表示Critic网络,W表示Critic网络的权值参数,近似每个状态S下到最终状态的奖励期望为V(S,W),在当前网络下下一个状态S′到最终状态的奖励期望为V(S′,W),智能体在环境中的状态为S,在状态中执行动作A,获得环境给出该动作的奖励R,折扣率为γ。智能体转移到下一状态S′,再执行下一个动作A′;为组MCP正则项,||·||2为L2范数,Wl g为第g条子路径的第l层的权值矩阵,即为组MCP正则项对各组子路径间的权值参数进行稀疏约束,L为子路径的总层数,G为子路径的总数,参数λ>0,μ>1,β>0为正则项参数,E(·)代表期望,为组MCP正则项函数表达式。
进一步地,所述Actor网络目标函数表示为:
其中:
τ={S1,A1,R1,S2,A2,R2,…,St,At,Rt,…,Sk,Ak,Rk}
上式中,Actor网络表示为π(A|S,θ),S表示Actor网络在当前环境下的状态,A表示在状态中执行的动作,θ表示Actor网络的权值参数;将Actor的控制策略视为k步的策略过程,Actor网络在该环境下的策略轨迹为τ,策略轨迹下的累积奖励表示为R(τ),St表示策略轨迹第t步时生成器的状态,At表示策略轨迹第t步时智能体选择的策略,Rt表示第t步时生成器采取策略At后获得的奖励;在某个Actor网络参数下策略轨迹出现的概率为P(τ|θ),为第g条子路径中第l层的权值矩阵,即为组MCP正则项对各组子路径间权值参数的稀疏约束,L为子路径的总层数,G为子路径的总数,β>0为正则项参数。
与现有技术相比,本发明具有以下技术特点:
1.针对基于策略梯度方法的生成式对抗网络技术存在的不稳定问题,本发明设计了基于Actor-Critic生成式对抗网络技术。另外,针对基于Actor-Critic生成式对抗网络参数冗余的问题,本发明采用基于非凸组MCP正则项的多路径多层Actor-Critic网络轻量化方法对其进行剪枝处理。
2.本发明在Actor和Critic网络的目标函数中,对组间各子路径的权值参数采用非凸组MCP正则项进行组间稀疏约束,使属于同一路径内的权值参数同时趋向零点,组内的权值参数采用L2范数进行特征组选择。相较组Lasso正则项,非凸组MCP具有较强的稀疏性和无偏性,从而可获得更好的轻量化效果。
3.本发明针对非凸组MCP正则项带来的非凸优化难点,首先,采用DC分解技术将非凸组MCP正则项分解为两个凸函数相减的形式,从而将问题转化为凸问题进行求解;其次,采用Adam算法更新网络的权值参数;最后,在训练过程中对网络进行剪枝,缓解网络参数冗余问题,使其能部署到资源有限的设备上。
附图说明
图1为本发明基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统模型训练结构图;
图2为基于非凸组MCP正则项的多路径多层化Actor/Critic网络结构图;
图3为本发明所提出的基于非凸组MCP正则项的多路径多层Actor-Critic网络轻量化方法与现有Actor-Critic算法训练曲线对比。
具体实施方式
一、基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统
参见附图,本发明提出的一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,包括:
轻量化Actor-Critic结构的生成器和判别器;
输入用户提出的医疗问题后,生成器通过编码-解码的方式生成准确、专业医疗诊断方案,作为系统的输出。
首先,基于Actor-Critic生成式对抗网络构建生成器和判别器网络,并进行预训练。将已知的医疗问答文本构建的数据集,输入到生成器中并采用极大似然估计方法进行预训练,再把预训练好的生成器生成的问答样本作为假样本,已知的问答样本作为真样本,输入到判别器网络进行预训练。
其次,预训练生成器和判别器后,复用生成器网络为Actor,并构建结构为长短期记忆网络的Critic网络,采用基于非凸组MCP正则项多路径多层Actor-Critic网络更新方法训练生成器(Actor)网络和Critic网络,同时,生成器(Actor)网络与判别器网络进行对抗训练,并在生成器(Actor)网络进行多次训练后更新一次判别器网络参数。其中,定义生成器(Actor)网络在环境中的状态S为既有的句子,在状态中执行的动作A为待生成的下一个词,获得环境对该动作的奖励R为判别器网络的输出,拼接既有的句子和生成的下一个单词组成新的句子为下一状态S′,再得到下一个生成的单词为动作A′。
最后,采用基于非凸组MCP正则项多路径多层Actor-Critic网络轻量化方法对生成器(Actor)网络和Critic网络进行剪枝处理,缓解基于Actor-Critic生成式对抗网络参数冗余的问题,减少其对时间和存储资源的消耗,使其能部署到计算资源有限的设备。因此,本发明可实现基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统。
二、基于Actor-Critic生成式对抗网络
基于Actor-Critic生成式对抗网络由判别器网络和Actor-Critic结构的生成器网络组成。预训练生成器和判别器后,复用生成器网络作为Actor,构建Critic网络,采用Actor-Critic算法更新生成器(Actor)网络参数,并与判别器进行对抗训练,实现基于Actor-Critic生成式对抗网络。
判别器网络为二分类器,结构为卷积神经网络;将生成器网络输出的问答样本作为假样本,已知的问答样本为真样本,输入真、假样本到判别器网络中,分别得到判别器网络对这些真、假样本判断为真实样本的概率并视为奖励;判别器网络使用生成式对抗网络的损失函数计算判别器的损失函数,更新判别器参数。
生成器网络采用Seq2Seq模型的问答结构,其网络结构由递归神经网络和全连接层组成,通过编码-解码过程生成医疗问答;包括编码器和解码器,其中,编码器将输入的自然语言医疗问题映射成词向量表示,解码器将映射成的词向量作为输入,然后通过全连接层输出词库里每个单词的概率值,最后通过的激活函数为Softmax。
生成器网络的预训练采用极大似然估计方法,与真实问答样本计算交叉熵损失,利用损失值更新网络参数;判别器预训练时,把预训练好的生成器生成的问答样本作为假样本,已知的问答样本作为真样本,输入到判别器网络进行训练。
预训练生成器和判别器后,构造含有初始单词的词向量矩阵,复用生成器网络为Actor,构建结构为长短期记忆网络的Critic网络,采用Actor-Critic算法对生成器(Actor)网络权值参数进行更新。同时生成器(Actor)网络与判别器网络进行对抗训练,并在生成器(Actor)网络进行多次训练后更新一次判别器网络参数。
三、基于非凸组MCP正则项多路径多层Actor-Critic网络轻量化方法
Actor-Critic网络主要分成两个部分,通过训练Critic网络近似环境下的状态价值函数,Actor网络根据从Critic网络反馈的TD-error训练动作策略。Actor网络负责逼近连续的策略空间,Critic网络负责评价Actor网络选择策略的好坏。但网络受限于冗余权值参数,难以部署到计算资源有限的设备。因此,对Actor和Critic网络进行轻量化处理是一个迫切需要解决的问题。
剪枝技术因其高效简单的优点已逐渐成为深度神经网络轻量化的研究热点之一。在Actor-Critic网络的剪枝方法中,稀疏约束的选择决定着剪枝效果的优劣。现有流行的稀疏约束如L0范数,其不连续的特性导致其优化问题求解为NP-hard问题,在网络训练时使用贪婪算法求解将导致庞大的计算量,难以实现。对此,将L0范数凸松弛近似为L1范数,可减少网络训练的计算量,但其将导致弱稀疏性、过惩罚等问题,造成网络估测值偏差较大。Group Lasso为组形式的L1范数,可进行分组形式的稀疏约束,实现组间稀疏,但其也保留了L1范数弱稀疏性等问题。
为了克服上述现存技术的不足,本发明提出一种基于非凸组MCP正则项的多路径多层轻量化方法,用于对Actor-Critic网络进行剪枝,缓解网络参数冗余问题。
首先,分别将整体的Actor和Critic网络进行多路径多层结构化处理;其次,构建轻量化Actor-Critic网络目标函数:分别使用时间差分法和策略梯度法构建Critic网络Actor网络的目标函数,并在两者的目标函数中,对子路径间的权值参数采用非凸组MCP正则项进行组间稀疏约束,组内的权值参数采用L2范数进行特征组选择;然后,针对Actor和Critic网络目标函数中组MCP正则项的非凸函数优化难点,本发明采用DC分解技术将非凸优化问题转换为凸优化问题进行求解,并使用Adam算法更新Critic网络和Actor网络权值参数;最后,分别对多路径多层结构化的Critic网络和Actor网络中整体权重值较小的子路径进行剪枝,从而缓解Actor-Critic网络的权值参数冗余问题,实现基于非凸组MCP正则项多路径多层Actor-Critic网络轻量化。具体内容如下:
1.多路径多层结构化网络
分别将整体的Actor和Critic网络按每层隐藏层神经元个数均分成n条并行的子路径,每条子路径定义为一组。每条子路径的隐藏层神经元个数相等,为原来完整网络隐藏层神经元个数的输入层和输出层的神经元个数与原来整体网络相同,例如,一条结构为4-32-32-2的网络,可分为8条结构为4-4-4-2的子路径。各子路径的输入为原来完整网络的输入,子路径的输出在最后一层聚合,并通过激活函数得到最终输出。在进行一定次数的迭代更新后,若多路径多层结构化的Critic或Actor网络的子路径数量大于某个阈值,即对其进行剪枝,若某条子路径权值参数的期望小于某个阈值,即移除该子路径,并更新Critic或Actor网络。
2.轻量化Actor-Critic网络目标函数构建
2.1.Critic网络目标函数构建
定义在每一步中,智能体在环境中的状态为S,在状态中执行动作A,获得环境给出该动作的奖励R,折扣率为γ。智能体转移到下一状态S′,再执行下一个动作A′。构建Critic网络A(S,W),W表示Critic网络的权值参数,近似每个状态S下到最终状态的奖励期望为V(S,W),在当前网络下下一个状态S′到最终状态的奖励期望为V(S′,W)。因此,当前状态S的V(S,W)可换算为R+γV(S′,W),并以此作为更新的目标值,计算TD-error如下:
δ=R+γV(S′,W)-V(S,W)
其中δ代表TD-error,Critic网络通过最小化TD-error更新网络的权值参数,同时,采用非凸组MCP正则项对组间的权值参数进行稀疏约束,使属于同一路径内的权值参数同时趋向零点,子路径间实现组间稀疏,对组内的权值参数施加L2范数实现特征组选择。因此,基于非凸组MCP正则项的多路径多层Critic网络目标函数如下:
其中,为组MCP正则项,||·||2为L2范数,为第g条子路径的第l层的权值矩阵,即为组MCP正则项对各组子路径间的权值参数进行稀疏约束,L为子路径的总层数,G为子路径的总数,β>0为正则项参数,E[·]代表期望。其中,组MCP正则项函数表达式如下:
其中,参数λ>0,μ>1。
2.2.Actor网络目标函数构建
构建Actor网络π(A|S,θ),其中S表示Actor网络在当前环境下的状态,A表示在状态中执行的动作,θ表示Actor网络的权值参数,π(A|S,θ)近似在状态S下选择动作A的概率。将Actor的控制策略视为k步的策略过程,设Actor网络在该环境下的策略轨迹τ表示如下:
τ={S1,A1,R1,S2,A2,R2,…,St,At,Rt,…,Sk,Ak,Rk}
其中,St表示策略轨迹第t步时生成器的状态,At表示策略轨迹第t步时智能体选择的策略,Rt表示第t步时生成器采取策略At后获得的奖励;因此,获得该策略轨迹下的累积奖励表示如下:
设在某个Actor网络参数下该轨迹出现的概率为P(τ|θ),计算累计奖励的期望值,如下:
其中P(S1)和P(St+1|St)表示初始状态的概率和状态从St转移到St+1的概率,与网络的参数无关。Actor网络通过最大化累计奖励的期望值更新网络的权值参数。与Critic网络目标函数同理,采用非凸组MCP正则项对组间的权值参数进行稀疏约束,因此,基于非凸组MCP正则项的多路径多层Actor网络目标函数如下:
3.Actor-Critic网络更新和轻量化方案
3.1.Critic网络更新和轻量化方法
针对Critic网络中组MCP正则项的非凸函数优化难点,首先采用DC分解技术将目标函数分解为两个凸函数相减的形式,将原问题转化为凸函数进行求解;其次,使用Adam算法更新网络权值参数;最后,在训练过程中对整体权重值较小的子路径进行剪枝,具体如下:
(1)DC分解技术处理Critic网络目标函数
首先,将目标函数分解成两个凸函数g1(W)和g2(W)相减的形式,如下:
LCritic(W)=g1(W)-g2(W)
第二步,由以下目标函数计算权重W,其中<W,z>表示W与z的内积运算:
W∈arg minWg1(W)-<W,z>
代入g1(W)的表达式,得:
W∈arg minWJ(W)
(2)Adam算法更新Critic网络权值参数
为高效且稳定地更新Critic网络的权值参数,采用Adam算法对Critic网络进行训练。首先,在第t次迭代时,计算J(W)的梯度gt(W),并计算一阶矩估计mt和二阶矩估计vt,如下:
mt=β1mt-1+(1-β1)gt(W)
其中,β1和β2为一阶矩估计mt和二阶矩估计vt的衰减系数;
其中αW为学习率用于控制步长,ε表示数值计算稳定性参数,防止分母为0。
(3)Critic网络剪枝处理
最后,根据每条子路径的权值大小对多路径多层结构化的Critic网络进行剪枝。设限制最少子路径的阈值为TW和权值剪枝阈值为Tp。在进行一定次数的迭代更新后,当Critic网络的所有子路径的数量大于TW,则对网络进行剪枝。若第g组子路径的权值矩阵Wg的权值w的期望Ew[Wg]满足如下:
|Ew[Wg]|<Tp
即移除该子路径,并更新Critic网络。
3.2Actor网络更新和轻量化方法
同理,在Actor网络π(A|S,θ)的更新过程中,首先,采用DC分解技术对目标函数进行处理;其次,使用Adam算法更新网络权值参数;最后,在训练过程中对整体权重值较小的子路径进行剪枝。
(1)DC分解技术处理Actor网络目标函数
与上述Critic网络算法设计同理,采用DC分解技术处理Actor网络的目标函数LActor(θ),将其分解成两个凸函数g1(θ)和g2(θ)相减的形式,如下:
LActor(θ)=g1(θ)-g2(θ)
(2)Adam算法更新Actor网络权值参数
为实现单步更新且减小方差,把Critic网络中计算的δ即TD-error返回Actor网络,用δ代替上式中的∑τR(τ)-b,如下:
其中αθ>0表示网络更新的学习率,ε表示数值计算稳定性参数,防止分母为0。
(3)Actor网络剪枝处理
设限制最少子路径的阈值为Tθ和权重剪枝阈值为Tp,在进行一定次数的迭代更新后,当多路径多层结构化的Actor网络的所有子路径的数量大于Tθ,则对网络进行剪枝。若第g组子路径的权值矩阵θg的权值θ的期望Eθ[θg]满足如下:
|Eθ[θg]|<Tp
即移除该子路径,并更新Actor网络。
基于轻量化Actor-Critic生成式对抗网络的医疗问答系统工作具体步骤如下:
步骤1,已知的医疗问答文本作为数据集,输入到生成器网络中并采用极大似然估计方法,与真实问答样本计算交叉熵损失,利用损失值更新网络参数,从而预训练生成器网络。
步骤2,把预训练好的生成器生成的问答样本作为假样本,已知的问答样本作为真样本,输入到判别器网络并使用生成对抗网络的损失函数计算判别器的损失函数,更新判别器参数,从而预训练判别器网络。
步骤3,构造含有初始单词的词向量矩阵;
步骤4,给定最大步的范围,即完整句子所需生成单词的个数,从当前步下既有的句子利用生成器(Actor)网络生成对应的下一个单词为动作A,既有的句子为当前状态S,拼接既有的句子和生成的下一个单词为新的句子,即下一个状态S′;
步骤5,将生成器(Actor)网络预测完整的句子输入到判别器网络(生成器通过蒙特卡洛搜索将句子补全),得到当前动作A的奖励R;
步骤6,将当前状态S、下一个状态S′、奖励R输入到Critic网络,计算TD-error,同时更新Critic网络的权值参数;
步骤7,将Critic网络计算的TD-error反馈到生成器(Actor)网络,计算其梯度并更新权值参数;
步骤8,训练过程中采用非凸组MCP正则项多路径多层Actor-Critic网络轻量化方法对生成器(Actor)和Critic网络进行剪枝;
步骤9,生成器和判别器进行对抗训练,更新判别器网络参数,通常在进行多次生成器(Actor)网络训练后进行一次判别器参数更新。
附图1示出了本发明基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统模型训练结构图。参照附图1,模型包括轻量化Actor-Critic结构的生成器和判别器。判别器的训练将生成器(Actor)网络输出作为假样本,真实问答数据为真样本,输入真、假样本到判别器网络中,使用生成对抗网络的损失函数公式计算其损失函数,更新判别器参数,判别器输出真实样本的概率值并视为奖励。轻量化Actor-Critic结构的生成器训练主要分为两个部分:Critic网络更新部分和生成器(Actor)网络更新部分。其中,Critic网络更新部分使用时间差分法构建目标函数并采用Adam算法单步更新网络权值参数,并计算TD-error反馈到生成器(Actor)网络;生成器(Actor)网络更新部分采用策略梯度法构建目标函数,在更新参数时引入TD-error并采用Adam算法更新网络权重参数,训练生成器的策略π,即下一个生成的单词。同时,对两者的权值参数进行基于非凸组MCP正则项的多路径多层的稀疏化剪枝。
附图2示出了本发明基于非凸组MCP正则项的多路径多层化Actor/Critic网络结构图。在Actor和Critic网络训练的过程中,将Actor或Critic网络按每层隐藏层神经元个数均分成若干子路径,每条子路径定义为一组。每条子路径的输入为原整体网络的输入,子路径的输出在最后一层聚合,最后通过激活函数输出。同时,对各组子路径间的权值矩阵施加非凸组MCP正则项进行稀疏约束,使子路径间形成组间稀疏,对组内的权值矩阵施加L2范数实现特征组选择。
附图3示出了本发明基于非凸组MCP正则项的多路径多层Actor-Critic网络轻量化方法的实例。图中红色实线和绿色虚线分别表示本发明与原始Actor-Critic算法应用于环境CartPole-v1的训练曲线,本发明的训练性能优于原始Actor-Critic算法。此外,原始的Actor-Critic网络权重参数占用的内存为71.7KB,轻量化后的Actor-Critic网络权重参数占用内存大小为13.5KB。
Claims (7)
1.一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,包括:
轻量化Actor-Critic结构的生成器和判别器;
输入用户提出的医疗问题后,生成器通过编码-解码的方式生成医疗诊断方案作为系统的输出;
其中,所述轻量化Actor-Critic结构的生成器和判别器的训练过程为:
首先,构建生成器和判别器网络,并进行预训练;将已知的医疗问答文本构建的数据集,输入到生成器中并采用极大似然估计方法进行预训练,再把预训练好的生成器生成的问答样本作为假样本,已知的问答样本作为真样本,输入到判别器网络进行预训练;
其次,预训练生成器和判别器后,复用生成器网络为Actor,并构建结构为长短期记忆网络的Critic网络,训练生成器网络和Critic网络,同时,生成器网络与判别器网络进行对抗训练,并在生成器网络进行多次训练后更新一次判别器网络参数;
最后,对生成器网络和Critic网络进行剪枝处理。
2.根据权利要求1所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,所述判别器网络为二分类器,结构为卷积神经网络;将生成器网络输出的问答样本作为假样本,判别器网络使用生成式对抗网络的损失函数计算判别器的损失函数,更新判别器参数。
3.根据权利要求1所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,生成器网络采用Seq2Seq模型的问答结构,其网络结构由递归神经网络和全连接层组成,包括编码器和解码器,通过编码-解码过程生成医疗问答;其中,编码器将输入的自然语言医疗问题映射成词向量表示,解码器将映射成的词向量作为输入,然后通过全连接层输出词库里每个单词的概率值,最后通过的激活函数为Softmax。
4.根据权利要求1所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,首先,分别将整体的Actor和Critic网络进行多路径多层结构化处理;其次,构建轻量化Actor-Critic网络目标函数:分别使用时间差分法和策略梯度法构建Critic网络Actor网络的目标函数,并在两者的目标函数中,对子路径间的权值参数采用非凸组MCP正则项进行组间稀疏约束,组内的权值参数采用L2范数进行特征组选择;然后,针对Actor和Critic网络目标函数中组MCP正则项的非凸函数优化难点,采用DC分解技术将非凸优化问题转换为凸优化问题进行求解,并使用Adam算法更新Critic网络和Actor网络权值参数;最后,分别对多路径多层结构化的Critic网络和Actor网络中整体权重值较小的子路径进行剪枝,从而缓解Actor-Critic网络的权值参数冗余问题,实现基于非凸组MCP正则项多路径多层Actor-Critic网络轻量化。
5.根据权利要求4所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,所述将整体的Actor和Critic网络进行多路径多层结构化处理,包括:
6.根据权利要求1所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,所述Critic网络目标函数表示为:
7.根据权利要求1所述的基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统,其特征在于,所述Actor网络目标函数表示为:
其中:
τ={S1,A1,R1,S2,A2,R2,...,St,At,Rt,...,Sk,Ak,Rk}
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110782860.5A CN113505210B (zh) | 2021-07-12 | 2021-07-12 | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110782860.5A CN113505210B (zh) | 2021-07-12 | 2021-07-12 | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113505210A true CN113505210A (zh) | 2021-10-15 |
CN113505210B CN113505210B (zh) | 2022-06-14 |
Family
ID=78012261
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110782860.5A Active CN113505210B (zh) | 2021-07-12 | 2021-07-12 | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113505210B (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114372438A (zh) * | 2022-01-12 | 2022-04-19 | 广东工业大学 | 基于轻量化深度强化学习的芯片宏单元布局方法及系统 |
CN117114148A (zh) * | 2023-08-18 | 2023-11-24 | 湖南工商大学 | 一种轻量级联邦学习训练方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110727844A (zh) * | 2019-10-21 | 2020-01-24 | 东北林业大学 | 一种基于生成对抗网络的在线评论商品特征观点提取方法 |
CN111104595A (zh) * | 2019-12-16 | 2020-05-05 | 华中科技大学 | 一种基于文本信息的深度强化学习交互式推荐方法及系统 |
CN111159454A (zh) * | 2019-12-30 | 2020-05-15 | 浙江大学 | 基于Actor-Critic生成式对抗网络的图片描述生成方法及系统 |
-
2021
- 2021-07-12 CN CN202110782860.5A patent/CN113505210B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110727844A (zh) * | 2019-10-21 | 2020-01-24 | 东北林业大学 | 一种基于生成对抗网络的在线评论商品特征观点提取方法 |
CN111104595A (zh) * | 2019-12-16 | 2020-05-05 | 华中科技大学 | 一种基于文本信息的深度强化学习交互式推荐方法及系统 |
CN111159454A (zh) * | 2019-12-30 | 2020-05-15 | 浙江大学 | 基于Actor-Critic生成式对抗网络的图片描述生成方法及系统 |
Non-Patent Citations (2)
Title |
---|
DAVID PFAU ET AL.: "Connecting generative adversarial networks and actor-critic methods", 《ARXIV》, 31 December 2016 (2016-12-31), pages 1 - 10 * |
王嘉伟: "正则化生成对抗网络研究", 《中国优秀博硕士学位论文全文数据库(硕士)信息科技辑(月刊)》, 15 June 2021 (2021-06-15), pages 140 - 36 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114372438A (zh) * | 2022-01-12 | 2022-04-19 | 广东工业大学 | 基于轻量化深度强化学习的芯片宏单元布局方法及系统 |
CN114372438B (zh) * | 2022-01-12 | 2023-04-07 | 广东工业大学 | 基于轻量化深度强化学习的芯片宏单元布局方法及系统 |
CN117114148A (zh) * | 2023-08-18 | 2023-11-24 | 湖南工商大学 | 一种轻量级联邦学习训练方法 |
CN117114148B (zh) * | 2023-08-18 | 2024-04-09 | 湖南工商大学 | 一种轻量级联邦学习训练方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113505210B (zh) | 2022-06-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111581343B (zh) | 基于图卷积神经网络的强化学习知识图谱推理方法及装置 | |
JP7041281B2 (ja) | ディープニューラルネットワークモデルに基づくアドレス情報特徴抽出方法 | |
Arslan et al. | SMOTE and gaussian noise based sensor data augmentation | |
CN104077595B (zh) | 基于贝叶斯正则化的深度学习网络图像识别方法 | |
Stuhlmüller et al. | Learning stochastic inverses | |
US11151328B2 (en) | Using neural network and score weighing to incorporate contextual data in sentiment analysis | |
CN109614471B (zh) | 一种基于生成式对抗网络的开放式问题自动生成方法 | |
CN111542843A (zh) | 利用协作生成器积极开发 | |
CN107729999A (zh) | 考虑矩阵相关性的深度神经网络压缩方法 | |
CN110969251A (zh) | 基于无标签数据的神经网络模型量化方法及装置 | |
CN107679617A (zh) | 多次迭代的深度神经网络压缩方法 | |
Elhamifar et al. | Self-supervised multi-task procedure learning from instructional videos | |
CN113505210B (zh) | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 | |
CN116523079A (zh) | 一种基于强化学习联邦学习优化方法及系统 | |
CN112990385A (zh) | 一种基于半监督变分自编码器的主动众包图像学习方法 | |
CN112000772A (zh) | 面向智能问答基于语义特征立方体的句子对语义匹配方法 | |
CN114819143A (zh) | 一种适用于通信网现场维护的模型压缩方法 | |
CN104050505A (zh) | 一种基于带学习因子蜂群算法的多层感知器训练方法 | |
CN113239211A (zh) | 一种基于课程学习的强化学习知识图谱推理方法 | |
CN111832817A (zh) | 基于mcp罚函数的小世界回声状态网络时间序列预测方法 | |
CN117435715A (zh) | 一种基于辅助监督信号改进时序知识图谱的问答方法 | |
CN109558898B (zh) | 一种基于深度神经网络的高置信度的多选择学习方法 | |
Prangle | Distilling importance sampling | |
CN116682399A (zh) | 一种音乐生成方法、系统、电子设备及介质 | |
CN111382840B (zh) | 一种面向自然语言处理的基于循环学习单元的htm设计方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |