CN112818159B - 一种基于生成对抗网络的图像描述文本生成方法 - Google Patents

一种基于生成对抗网络的图像描述文本生成方法 Download PDF

Info

Publication number
CN112818159B
CN112818159B CN202110206288.8A CN202110206288A CN112818159B CN 112818159 B CN112818159 B CN 112818159B CN 202110206288 A CN202110206288 A CN 202110206288A CN 112818159 B CN112818159 B CN 112818159B
Authority
CN
China
Prior art keywords
text
image
word
generator
vector
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110206288.8A
Other languages
English (en)
Other versions
CN112818159A (zh
Inventor
陆佳妮
程帆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Jiaotong University
Original Assignee
Shanghai Jiaotong University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Jiaotong University filed Critical Shanghai Jiaotong University
Priority to CN202110206288.8A priority Critical patent/CN112818159B/zh
Publication of CN112818159A publication Critical patent/CN112818159A/zh
Application granted granted Critical
Publication of CN112818159B publication Critical patent/CN112818159B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/58Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually
    • G06F16/583Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content
    • G06F16/5846Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content using extracted text
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/12Use of codes for handling textual entities
    • G06F40/126Character encoding
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/194Calculation of difference between files
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/237Lexical tools
    • G06F40/242Dictionaries
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Library & Information Science (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种基于生成对抗网络的图像描述文本生成方法,包括以下步骤:1)构建用以实现对图像进行特征提取的编码器;2)对文本进行词嵌入,并构建用以生成图像描述文本的解码器;3)根据极大似然估计对由编码器和解码器共同构成的生成器进行预训练;4)构建基于卷积神经网络的判别器并进行训练;5)共同训练生成器与判别器;6)将待生成描述文本的测试图像数据输入训练好的生成器中,输出生成的描述文本。与现有技术相比,本发明具有提高生成的文本的客观评测得分、可解释性好和多样性等优点。

Description

一种基于生成对抗网络的图像描述文本生成方法
技术领域
本发明涉及人工智能方向中的计算机视觉和自然语言处理领域,尤其是涉及一种基于生成对抗网络的图像描述文本生成方法。
背景技术
随着人工智能技术的成熟,计算机视觉、自然语言处理等领域都有了飞速发展,图像描述任务要求机器可以自动为图像生成描述性的语句,因此图像描述模型需要同时具备图像理解能力和自然语言理解能力,这依赖于模型对图像表示和文本表示的获取与处理。
现有主流的图像描述方法包括以下步骤:
1)利用编码器提取出图像特征;
2)利用解码器和注意力机制,解码输入的特征,生成文本;
3)用REINFORCE这一强化学习算法进一步优化生成器。
上述图像描述生成方法框架较为简单,在步骤1)中使用的编码器通常只是一个简单的卷积神经网络,输入图像后,输出一个完整的图像特征,在这个完整特征上使用注意力机制,相当于将图像按大小相同的网格进行了划分,而一个物体被网格切分后,可能是不完整的,由此生成的描述文本是不精确的;在步骤2)中的注意力机制只用于图像特征上,没有利用好文本自身的特征,在自然语言中,句子中通常有一些连接词,它们的生成与图像本身无关;在步骤3)中只用到了强化学习算法来优化生成器,而生成对抗网络可以进一步优化生成器。图像描述生成方法单纯依靠编码器-解码器的架构和全局的注意力机制,在生成文本描述时,仍有许多不足:用词不够准确,在客观评价指标上的评分较低,提升不明显。
发明内容
本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种基于生成对抗网络的图像描述文本生成方法。
本发明的目的可以通过以下技术方案来实现:
一种基于生成对抗网络的图像描述文本生成方法,包括以下步骤:
1)构建用以实现对图像进行特征提取的编码器;
2)对文本进行词嵌入,并构建用以生成图像描述文本的解码器;
3)根据极大似然估计对由编码器和解码器共同构成的生成器进行预训练;
4)构建基于卷积神经网络的判别器并进行训练;
5)共同训练生成器与判别器;
6)将待生成描述文本的测试图像数据输入训练好的生成器中,输出生成的描述文本。
所述的步骤1)中,所述的编码器为基于ResNet-101的Faster R-CNN模型,对于给定的一张图像,编码器从该图像中检测到的n个物体,将图像编码为n个目标区域的特征集合V={v1,…vi…,vn},其中,vi为目标选区i经过平均池化层后的特征向量。
所述的步骤2)中,采用GloVe模型对文本进行词嵌入,得到词嵌入后的文本向量,即词嵌入表示。
所述的步骤2)中,解码器由一个双层的长短期记忆神经网络模型和两个注意力模块组成,具体包括作为第一层的注意力生成LSTM层、作为第二层的语言生成LSTM层以及设置在注意力生成LSTM层与语言生成LSTM层之间的两个用以生成视觉哨兵向量的自适应注意力模块。
所述的注意力生成LSTM层以图像I的特征表示
Figure BDA0002950853730000021
输入词wt的词嵌入表示WeΠt、语言生成LSTM层在t-1步的隐藏层状态
Figure BDA0002950853730000022
为输入,输出为第一视觉哨兵向量
Figure BDA0002950853730000023
第二视觉哨兵向量
Figure BDA0002950853730000024
以及注意力生成LSTM层在第t步的隐藏层状态
Figure BDA0002950853730000025
所述的自适应注意力模块包括用以生成语境向量ct的第一自适应注意力模块以及用以生成目标区域集合转移信号δt的第二自适应注意力模块,所述的第一自适应注意力模块以第一视觉哨兵向量
Figure BDA0002950853730000026
和特征集合
Figure BDA0002950853730000027
为输入,所述的第二自适应注意力模块以第二视觉哨兵向量
Figure BDA0002950853730000028
和特征集合
Figure BDA0002950853730000029
为输入,所述的语言生成LSTM层以注意力生成LSTM层当前的隐藏层状态
Figure BDA00029508537300000210
和语境向量ct为输入,输出为生成词yt的概率分布,其中,图像I的特征表示
Figure BDA00029508537300000211
具体为编码器输出的特征集合V中元素的均值,We为GloVe的模型在词典Σ上的词嵌入矩阵,Πt为输入词wt的独热编码。
所述的特征集合vt具体为目标区域集合rt的特征集合,rt为集合R={r1,…,rN}中的指针在第t步指向的元素,该指针由目标区域集合转移信号δt控制,则有:
Figure BDA0002950853730000031
其中,k为解码器的步数,第0步的目标区域集合转移信号δ0默认值为0,N为集合R的大小,即包含目标区域集合ri的个数。
所述的步骤3)中,采用计划采样的方法对生成器进行预训练,在训练过程中,生成器的预训练目标为最小化损失函数,生成器的损失函数LG(θ)由预测出的生成词yt与真实的词
Figure BDA0002950853730000032
之间的交叉熵损失Lw(θ)以及预测出的目标区域集合转移信号δi与真实值
Figure BDA0002950853730000033
之间的交叉熵损失Lδ(θ)构成。
所述的步骤4)中,以生成器生成的文本、真实的文本以及真实的图像的拼接向量作为判别器的输入,并引入高速网络提升性能。
所述的步骤5)具体包括以下步骤:
51)根据训练集中给出的图像I以及预训练后的生成器Gθ生成的文本y1:T构成图像文本对{(I,y1:T)};
52)采用预训练后的判别器Dφ对生成的文本进行评分,并采用评分p∈[0,1]表示这些文本是真实文本的概率;
53)通过客观指标评价模块得到对生成的文本的评分s,所述的客观指标评价模块采用CIDEr-D作为客观评价指标;
54)结合评分p和评分s给出奖励值r=λ·p+(1-λ)·s,λ为可调的超参数;
55)采用REINFORCE强化学习算法更新生成器的参数θ,并采用奖励值r作为收益,基线算法选择用贪婪算法生成的文本序列;
56)根据训练集中给出的图像I,更新参数后的生成器Gθ重新生成文本y1:T
57)根据训练集中给出的图像I,判别器的损失函数考虑三类文本,即与图像I相关的正确真实文本
Figure BDA0002950853730000034
生成器生成的文本y1:T以及与I无关的错误真实文本
Figure BDA0002950853730000035
更新判别器的参数φ;
58)返回步骤51),继续下一次生成对抗网络的过程,直到生成器和判别器收敛。
在测试过程中采用集束搜索,且集束大小取值为5。
与现有技术相比,本发明具有以下优点:
一、本发明将强大的目标检测模型Faster-RCNN作为编码器,输出检测到的目标区域的特征作为图像的特征,使得生成器在生成描述性文本时可以更好地关注到物体本身,提升了编码器的编码效果。
二、本发明在解码器部分用到了双层的LSTM模型,并且加入了两个自适应的注意力模块,一个注意力模块用于决定模型当前应该生成“可视词”还是“文本词”,另一个注意力模块用于决定模型是否已经描述完当前的目标区域集合,是否应该描述下一个目标区域集合,使得解码器生成的文本更加流畅,并且具有良好的可解释性。
三、本发明不仅使用REINFORCE强化学习算法优化生成器,还引入了基于CNN的判别器,用到了生成对抗网络的训练过程,以CIDEr-D这一客观评价指标作为优化目标,提升了生成器生成的文本在BLEU、ROUGE-L、METEOR、CIDEr、SPICE等一系列客观评价指标上的评分,使得文本更加精准。
附图说明
图1为本发明的方法流程示意图。
图2为解码器结构示意图。
图3为与编码器共同训练时生成器的参数更新示意图。
具体实施方式
下面结合附图和具体实施例对本发明实施例中的技术方案进行清楚、完整的描述。
实施例
本方法主要采用Pytorch实现,如图1所示,本发明提供一种基于生成对抗网络的图像描述文本生成方法,包括以下步骤:
1)将目标检测模型作为编码器,提取出图像的特征。编码器是目标检测模型Faster R-CNN,图像数据经过Faster R-CNN模型得到一个区域特征集合、包围盒的集合以及每个区域的类别Softmax概率分布。
Faster R-CNN模型搭建于ResNet-101上,ResNet-101是在ImageNet数据集上进行分类训练的预训练模型,将Faster R-CNN在Visual Genome数据集上进行训练,在对目标分类时用到了1600个类别标签和1个背景标签,共计1601类,对于候选区域的非极大值抑制算法,区域面积重叠率(Intersetction Over Union,IOU)阈值设定为0.7,选择区域时的类别检测置信度阈值为0.5,给定一张图像I,Faster R-CNN从I中检测到的n个物体,将图像编码为n个目标区域的特征集合V={v1,…,vn},
Figure BDA0002950853730000051
对于每个特定的目标选区i,vi为该区域经过平均池化层后的特征向量,维度D是2048维。
2)对文本进行词嵌入,将包含注意力模块的长短期记忆神经网络作为解码器,根据图像生成文本描述。
词嵌入步骤主要包括清洗文本、建立词典Σ和词嵌入三步,限定句子的最大长度为20,去除词频低于5次的词,建立词典,并且在词典中引入四个特殊的符号:开始符号<bos>,结束符号<eos>,未知符号<unk>和填充符号<pad>。开始符号<bos>用于标记一句句子的开始;结束符号<eos>用于标记一句句子的结束;未知符号<unk>用于标记没有在词典中出现过的词;填充符号<pad>在小批次(mini-batch)训练时,将同一批次的所有句子按照该批次中最长的句子补齐成同一长度。用GloVe模型对单词进行词嵌入,词嵌入后的向量大小为300维。
解码器是一个双层的长短期记忆神经网络(Long Short-Term Memory,LSTM)模型,两层LSTM层之间包含两个自适应的注意力模块,整体结构如图2所示。第一层是注意力生成LSTM层,为两个注意力模块产生视觉哨兵向量,它的输出连接到两个注意力模块和第二层LSTM层;第二层是语言生成LSTM层,用于生成词yt的概率分布。两个LSTM层的隐藏层大小都为1000。用
Figure BDA0002950853730000052
分别表示注意力生成LSTM层、语言生成LSTM层在第t步的隐藏层状态。
注意力生成LSTM层的输入
Figure BDA0002950853730000053
有三个:图像I的特征表示
Figure BDA0002950853730000054
输入词wt的词嵌入表示WeΠt、语言生成LSTM层在前一步的隐藏层状态
Figure BDA0002950853730000055
图像I的特征表示
Figure BDA0002950853730000056
为编码器输出的特征集合V中元素的均值;输入词wt的词嵌入表示是GloVe的模型在词典Σ上的词嵌入矩阵
Figure BDA0002950853730000057
和输入词wt的独热编码Πt的乘积WeΠt。因此,
Figure BDA0002950853730000058
在第t步时,注意力生成LSTM层的隐藏层状态更新公式为:
Figure BDA0002950853730000059
两个注意力模块都是自适应注意力模块,一个模块负责产生语境向量ct,另一个模块负责产生目标区域集合转移信号δt,自适应注意力模块可以看作是一个单层的全连接神经网络,输入512维,输出1维,后面连接了一个Softmax函数。两个注意力模块的输入都有一个特征集合
Figure BDA0002950853730000061
V为Faster-RCNN检测到的所有目标区域的特征集合,vt为目标区域集合rt的特征集合,rt为集合R={r1,…,rN}中的指针在第t步指向的元素,指针由目标区域集合转移信号δt控制,则rt的表达式为:
Figure BDA0002950853730000062
负责产生语境向量ct的注意力模块的输入为特征集合vt以及注意力生成LSTM层输出的视觉哨兵向量
Figure BDA0002950853730000063
视觉哨兵向量
Figure BDA0002950853730000064
由注意力生成LSTM层根据输入向量
Figure BDA0002950853730000065
前一步的隐藏层状态
Figure BDA0002950853730000066
以及此刻内部的存储单元状态
Figure BDA0002950853730000067
计算得到:
Figure BDA0002950853730000068
Figure BDA0002950853730000069
其中,Wic和Whc为模型需要学习的参数,⊙表示元素乘积,σ(·)表示Sigmoid函数,αt为在特征集合vt上的注意力权重分布,则有:
Figure BDA00029508537300000610
Figure BDA00029508537300000611
Figure BDA00029508537300000612
这里的
Figure BDA00029508537300000613
是一个行向量,它和Wsr、Wsc、Wg都是模型需要学习的参数,
Figure BDA00029508537300000614
是元素值全为1的向量,k为rt中目标区域的数量,语境向量ct表示此时模型应该关注的区域的特征表示,作为语言生成LSTM层的输入之一。
负责产生目标区域集合转移信号δt的注意力模块的输入为特征集合vt以及注意力生成LSTM层输出的视觉哨兵向量
Figure BDA00029508537300000615
的生成与
Figure BDA00029508537300000616
的生成相类似:
Figure BDA00029508537300000617
Figure BDA00029508537300000618
其中,W和W是模型需要学习的另一组权重参数,从一个目标区域集合转移到另一个目标区域集合(δt=1)的概率可以被定义为在视觉哨兵向量
Figure BDA00029508537300000619
和目标区域集合rt上注意到
Figure BDA00029508537300000620
的概率:
Figure BDA00029508537300000621
Figure BDA0002950853730000071
Figure BDA0002950853730000072
表示是向量
Figure BDA0002950853730000073
中的第i个元素,W*是模型需要学习的权重参数。
语言生成LSTM层在第t步的隐藏层状态更新为:
Figure BDA0002950853730000074
Figure BDA0002950853730000075
最终,语言生成LSTM层输出词yt的概率分布:
Figure BDA0002950853730000076
y1:t-1表示y1,…,yt-1,Wo为模型需要学习的权重参数。
3)根据极大似然估计,对生成器进行预训练。生成器是步骤1)的编码器和步骤2)的解码器的组合。生成器的损失函数LG(θ)由两部分组成:一部分是预测出的词yt与真实的词
Figure BDA0002950853730000077
的交叉熵损失Lw(θ),另一部分是预测出的目标区域集合转移信号δi与真实值
Figure BDA0002950853730000078
的交叉熵损失Lδ(θ),权重取值为λw=1,λδ=4:
Figure BDA0002950853730000079
Figure BDA00029508537300000710
LG(θ)=λwLw(θ)+λδLδ(θ)
生成器的预训练目标是最小化损失函数。预训练时用到了计划采样的方法,在第t步预测时,解码器的输入词wt有p的概率选择前一个真实的词
Figure BDA00029508537300000711
有1-p的概率选择前一步预测的词yt。p初始时为1,进行线性衰减,每三次完整训练后,衰减0.05,最终不小于0.5。用Adam作为优化器,初始学习率为5×10-4,每三次完整训练后,学习率衰减0.8,总共预训练25次。
4)将生成器生成的文本、真实的文本、真实的图像输入判别器,对判别器进行预训练。判别器基于卷积神经网络(Convolution Neural Network,CNN),输入为图像I的特征表示
Figure BDA00029508537300000712
和完整的描述语句{w1,…,wT}的词嵌入的拼接ε:
Figure BDA00029508537300000713
Figure BDA00029508537300000714
为水平拼接操作,
Figure BDA00029508537300000715
为编码器输出的特征集合V中元素的均值,
Figure BDA00029508537300000716
是普通的词嵌入矩阵,Πi为输入词wi的独热编码。最终生成的矩阵大小为
Figure BDA00029508537300000717
d选择为2048,使用大小为d×l的卷积核
Figure BDA00029508537300000718
后得到特征向量:c=[c1,c2,…,cT-l+2],其中ci=ReLU(κ*εi:i+l-1+b),使用基于时间的最大池化层得到
Figure BDA0002950853730000081
在最终的全连接层前加入高速网络结构:
Figure BDA0002950853730000082
Figure BDA0002950853730000083
Figure BDA0002950853730000084
其中,WT、WH是高速网络的权重,bT、bH是高速网络的偏差,⊙是分段乘积操作。最终,使用一个全连接层和Sigmoid操作得到概率值p,表示给定图像I的情况下,一段话是正确文本的概率:
Figure BDA0002950853730000085
Wo和bo分别是输出层的权重和方差。
对于一张图像I,判别器的损失函数考虑三类文本:与I相关的正确真实文本
Figure BDA0002950853730000086
生成器生成的文本y1:T、与I无关的错误真实文本
Figure BDA0002950853730000087
它们与图像I构成三个样本对集合:
Figure BDA0002950853730000088
判别器的损失函数LD(φ)由三部分构成:
Figure BDA0002950853730000089
判别器的预训练目标为最小化损失函数。使用Adam作为优化器,初始学习率为1×10-3,预训练10次。
5)共同训练生成器与判别器。
生成器与判别器共同训练时,生成器的参数更新如图3所示。生成器与判别器共同训练的具体过程如下:
51)根据训练集中给出的图像I,预训练后的生成器Gθ生成文本y1:T,构成图像文本对{(I,y1:T)};
52)预训练后的判别器Dφ对生成的文本进行评分,用p∈[0,1]表示这些文本是真实文本的概率;
53)客观指标评价模块对生成的文本给出评分s;客观指标是CIDEr-D.
54)综合52)的评分和53)给出的评分,给出奖励值r=λ·p+(1-λ)·s,λ为可调的超参数;这里设置λ为0.3.
55)用REINFORCE这一强化学习算法来更新生成器的参数θ;REINFORCE将生成文本序列的过程看作一个强化学习问题:解码器(智能体)根据当前模型的参数θ(策略pθ),与图像特征、当前的文本特征、区域集合(环境)交互,生成下一个词(动作),采用步骤54)的奖励值r作为收益,记作rC(·),强化学习优化的目标为最小化负的收益函数的期望,梯度函数写为:
Figure BDA0002950853730000091
Figure BDA0002950853730000092
是抽样出的一个样本,用于近似y1:T
Figure BDA0002950853730000093
是抽样出的一个样本,用于近似δ1:T。基线收益函数b选择的是贪婪算法生成的文本序列
Figure BDA0002950853730000094
的收益
Figure BDA0002950853730000095
56)根据训练集中给出的图像I,更新参数后的生成器Gθ重新生成文本y1:T
57)根据训练集中给出的图像I,判别器的损失函数考虑三类文本:与I相关的正确真实文本
Figure BDA0002950853730000096
生成器生成的文本y1:T、与I无关的错误真实文本
Figure BDA0002950853730000097
更新判别器的参数φ,判别器参数更新时的损失函数仍然为:
Figure BDA0002950853730000098
58)返回步骤51),继续下一次生成对抗网络的过程,直到生成器和判别器收敛。
6)将测试的图像数据输入训练好的生成器中,输出生成的文本。测试过程中用到了集束搜索这个启发式的图搜索算法,集束大小取值为5。
以上,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的工作人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (3)

1.一种基于生成对抗网络的图像描述文本生成方法,其特征在于,包括以下步骤:
1)构建用以实现对图像进行特征提取的编码器,所述的编码器为基于ResNet-101的Faster R-CNN模型,对于给定的一张图像,编码器从该图像中检测到的n个物体,将图像编码为n个目标区域的特征集合V={v1,…vi…,vn},其中,vi为目标选区i经过平均池化层后的特征向量;
2)对文本进行词嵌入,并构建用以生成图像描述文本的解码器,解码器由一个双层的长短期记忆神经网络模型和两个注意力模块组成,具体包括作为第一层的注意力生成LSTM层、作为第二层的语言生成LSTM层以及设置在注意力生成LSTM层与语言生成LSTM层之间的两个用以生成视觉哨兵向量的自适应注意力模块,
所述的注意力生成LSTM层的输入向量
Figure FDA0003710204190000011
包括图像I的特征表示
Figure FDA00037102041900000115
输入词wt的词嵌入表示WeΠt以及语言生成LSTM层在t-1步的隐藏层状态
Figure FDA0003710204190000012
图像I的特征表示
Figure FDA0003710204190000013
为编码器输出的特征集合V中元素的均值,输入词wt的词嵌入表示WeΠt为GloVe的模型在词典Σ上的词嵌入矩阵We和输入词wt的独热编码Πt的乘积WeΠt,则有
Figure FDA0003710204190000014
在第t步时,注意力生成LSTM层的隐藏层状态更新公式为
Figure FDA0003710204190000015
输出为第一视觉哨兵向量
Figure FDA0003710204190000016
第二视觉哨兵向量
Figure FDA0003710204190000017
以及注意力生成LSTM层在第t步的隐藏层状态
Figure FDA0003710204190000018
所述的自适应注意力模块包括用以生成语境向量ct的第一自适应注意力模块以及用以生成目标区域集合转移信号δt的第二自适应注意力模块,所述的第一自适应注意力模块以第一视觉哨兵向量
Figure FDA00037102041900000114
和特征集合
Figure FDA0003710204190000019
为输入,所述的第二自适应注意力模块以第二视觉哨兵向量
Figure FDA00037102041900000110
和特征集合
Figure FDA00037102041900000111
Figure FDA00037102041900000112
为输入,所述的特征集合vt为目标区域集合rt的特征集合,目标区域集合rt为集合R={r1,…,rN}中的指针在第t步指向的元素,指针由目标区域集合转移信号δt控制,则目标区域集合rt的表达式为:
Figure FDA00037102041900000113
其中,k为解码器的步数,第0步的目标区域集合转移信号δ0默认值为0,N为集合R的大小,即包含目标区域集合ri的个数;
第一视觉哨兵向量
Figure FDA0003710204190000021
由注意力生成LSTM层根据输入向量
Figure FDA0003710204190000022
前一步的隐藏层状态
Figure FDA0003710204190000023
以及此刻内部的存储单元状态
Figure FDA0003710204190000024
计算得到,则有:
Figure FDA0003710204190000025
其中,Wic和Whc为模型需要学习的参数,⊙表示元素乘积,σ(·)表示Sigmoid函数,αt为在特征集合vt上的注意力权重分布,则有:
Figure FDA0003710204190000026
Figure FDA0003710204190000027
Figure FDA0003710204190000028
其中,
Figure FDA0003710204190000029
为一个行向量,其与Wsr、Wsc、Wg均为模型需要学习的参数,
Figure FDA00037102041900000210
是元素值全为1的向量,k为目标区域集合rt中目标区域的数量,语境向量ct表示此时模型应该关注的区域的特征表示,作为语言生成LSTM层的输入之一;
第二自适应注意力模块的输入为特征集合vt以及注意力生成LSTM层输出的视觉哨兵向量
Figure FDA00037102041900000211
则有:
Figure FDA00037102041900000212
其中,W和W为模型需要学习的权重参数,从一个目标区域集合转移到另一个目标区域集合(δt=1)的概率被定义为在视觉哨兵向量
Figure FDA00037102041900000213
和目标区域集合rt上注意到
Figure FDA00037102041900000214
的概率,则有:
Figure FDA00037102041900000215
Figure FDA00037102041900000216
其中,
Figure FDA00037102041900000217
为向量
Figure FDA00037102041900000218
中的第i个元素,W*为模型需要学习的权重参数;
所述的语言生成LSTM层以注意力生成LSTM层当前的隐藏层状态
Figure FDA00037102041900000219
和语境向量ct为输入,输出为生成词yt的概率分布,语言生成LSTM层在第t步的隐藏层状态更新为:
Figure FDA00037102041900000220
Figure FDA00037102041900000221
则语言生成LSTM层输出词yt的概率分布为:
Figure FDA0003710204190000031
其中,y1:t-1表示y1,…,yt-1,Wo为模型需要学习的权重参数;
3)根据极大似然估计对由编码器和解码器共同构成的生成器进行预训练,采用计划采样的方法对生成器进行预训练,在训练过程中,生成器的预训练目标为最小化损失函数,生成器的损失函数LG(θ)由预测出的生成词yt与真实的词
Figure FDA0003710204190000032
之间的交叉熵损失Lw(θ)以及预测出的目标区域集合转移信号δi与真实值
Figure FDA0003710204190000033
之间的交叉熵损失Lδ(θ)这两部分构成,则有:
LG(θ)=λwLw(θ)+λδLδ(θ)
Figure FDA0003710204190000034
Figure FDA0003710204190000035
其中,λw、λδ为权重取值;
4)构建基于卷积神经网络的判别器并进行训练,以生成器生成的文本、真实的文本以及真实的图像的拼接向量作为判别器的输入,并引入高速网络提升性能具体为:
所述的判别器基于卷积神经网络,输入为图像I的特征表示
Figure FDA00037102041900000316
和完整的描述语句{w1,…,wT}的词嵌入的拼接ε,则有:
Figure FDA0003710204190000036
其中,
Figure FDA0003710204190000037
为水平拼接操作,
Figure FDA0003710204190000038
为编码器输出的特征集合V中元素的均值,
Figure FDA0003710204190000039
为普通的词嵌入矩阵,Πi为输入词wi的独热编码,采用大小为d×l的卷积核
Figure FDA00037102041900000310
后得到特征向量c=[c1,c2,…,cT-l+2],其中ci=ReLU(κ*εi:i+l-1+b),使用基于时间的最大池化层得到
Figure FDA00037102041900000311
在最终的全连接层前加入高速网络结构,则有:
Figure FDA00037102041900000312
Figure FDA00037102041900000313
Figure FDA00037102041900000314
其中,WT、WH为高速网络的权重,bT、bH为高速网络的偏差,⊙为分段乘积操作,最终,使用一个全连接层和Sigmoid操作得到概率值p,表示给定图像I的情况下,一段话是正确文本的概率p,则有:
Figure FDA00037102041900000315
其中,Wo和bo分别为输出层的权重和方差;
5)共同训练生成器与判别器,具体包括以下步骤:
51)根据训练集中给出的图像I以及预训练后的生成器Gθ生成的文本y1:T构成图像文本对{(I,y1:T)};
52)采用预训练后的判别器Dφ对生成的文本进行评分,并采用评分p∈[0,1]表示这些文本是真实文本的概率;
53)通过客观指标评价模块得到对生成的文本的评分s,所述的客观指标评价模块采用CIDEr-D作为客观评价指标;
54)结合评分p和评分s给出奖励值r=λ·p+(1-λ)·s,λ为可调的超参数;
55)采用REINFORCE强化学习算法更新生成器的参数θ,并采用奖励值r作为收益,基线算法选择用贪婪算法生成的文本序列,具体为:
解码器根据当前模型的参数θ,与图像特征、当前的文本特征、区域集合交互,生成下一个词,采用步骤54)的奖励值r作为收益,记作rC(·),强化学习优化的目标为最小化负的收益函数的期望,梯度函数为:
Figure FDA0003710204190000041
其中,
Figure FDA0003710204190000042
为抽样出的一个样本,用于近似y1:T
Figure FDA0003710204190000043
为抽样出的一个样本,用于近似δ1:T,基线收益函数b选择的是贪婪算法生成的文本序列
Figure FDA0003710204190000044
的收益
Figure FDA0003710204190000045
56)根据训练集中给出的图像I,更新参数后的生成器Gθ重新生成文本y1:T
57)根据训练集中给出的图像I,判别器的损失函数考虑三类文本,即与图像I相关的正确真实文本
Figure FDA0003710204190000046
生成器生成的文本y1:T以及与I无关的错误真实文本
Figure FDA0003710204190000047
更新判别器的参数φ,判别器参数更新时的损失函数为:
Figure FDA0003710204190000048
58)返回步骤51),继续下一次生成对抗网络的过程,直到生成器和判别器收敛;
6)将待生成描述文本的测试图像数据输入训练好的生成器中,输出生成的描述文本。
2.根据权利要求1所述的一种基于生成对抗网络的图像描述文本生成方法,其特征在于,所述的步骤2)中,采用GloVe模型对文本进行词嵌入,得到词嵌入后的文本向量,即词嵌入表示。
3.根据权利要求1所述的一种基于生成对抗网络的图像描述文本生成方法,其特征在于,在测试过程中采用集束搜索,且集束大小取值为5。
CN202110206288.8A 2021-02-24 2021-02-24 一种基于生成对抗网络的图像描述文本生成方法 Active CN112818159B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110206288.8A CN112818159B (zh) 2021-02-24 2021-02-24 一种基于生成对抗网络的图像描述文本生成方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110206288.8A CN112818159B (zh) 2021-02-24 2021-02-24 一种基于生成对抗网络的图像描述文本生成方法

Publications (2)

Publication Number Publication Date
CN112818159A CN112818159A (zh) 2021-05-18
CN112818159B true CN112818159B (zh) 2022-10-18

Family

ID=75865383

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110206288.8A Active CN112818159B (zh) 2021-02-24 2021-02-24 一种基于生成对抗网络的图像描述文本生成方法

Country Status (1)

Country Link
CN (1) CN112818159B (zh)

Families Citing this family (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113220891B (zh) * 2021-06-15 2022-10-18 北京邮电大学 基于无监督的概念到句子的生成对抗网络图像描述方法
CN113362416B (zh) * 2021-07-01 2024-05-17 中国科学技术大学 基于目标检测的文本生成图像的方法
CN113254604B (zh) * 2021-07-15 2021-10-01 山东大学 一种基于参考规范的专业文本生成方法及装置
CN113673525A (zh) * 2021-07-20 2021-11-19 广东技术师范大学 一种基于图像生成中文文本的方法、系统及装置
CN113673349B (zh) * 2021-07-20 2022-03-11 广东技术师范大学 基于反馈机制的图像生成中文文本方法、系统及装置
CN113468871A (zh) * 2021-08-16 2021-10-01 北京北大方正电子有限公司 文本纠错方法、装置及存储介质
CN113554040B (zh) * 2021-09-07 2024-02-02 西安交通大学 一种基于条件生成对抗网络的图像描述方法、装置设备
CN114022687B (zh) * 2021-09-24 2024-05-10 之江实验室 一种基于增强学习的图像描述对抗生成方法
CN113781598B (zh) * 2021-10-25 2023-06-30 北京邮电大学 图像生成模型的训练方法和设备以及图像生成方法
CN114006752A (zh) * 2021-10-29 2022-02-01 中电福富信息科技有限公司 基于gan压缩算法的dga域名威胁检测系统及其训练方法
CN115049899B (zh) * 2022-08-16 2022-11-11 粤港澳大湾区数字经济研究院(福田) 模型训练方法、指代表达式生成方法及相关设备
CN115953779B (zh) * 2023-03-03 2023-06-16 中国科学技术大学 基于文本对抗生成网络的无监督图像描述生成方法
CN116385597B (zh) * 2023-03-03 2024-02-02 阿里巴巴(中国)有限公司 文本配图方法和装置
CN117648921B (zh) * 2024-01-29 2024-05-03 山东财经大学 基于成对双层对抗对齐的跨主题作文自动测评方法及系统

Family Cites Families (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107330444A (zh) * 2017-05-27 2017-11-07 苏州科技大学 一种基于生成对抗网络的图像自动文本标注方法
CN107886169B (zh) * 2017-11-14 2021-02-12 华南理工大学 一种基于文本-图像生成对抗网络模型的多尺度卷积核方法
CN109543165B (zh) * 2018-11-21 2022-09-23 中国人民解放军战略支援部队信息工程大学 基于循环卷积注意力模型的文本生成方法及装置
CN109948776A (zh) * 2019-02-26 2019-06-28 华南农业大学 一种基于lbp的对抗网络模型图片标签生成方法
CN111159454A (zh) * 2019-12-30 2020-05-15 浙江大学 基于Actor-Critic生成式对抗网络的图片描述生成方法及系统
CN111754446A (zh) * 2020-06-22 2020-10-09 怀光智能科技(武汉)有限公司 一种基于生成对抗网络的图像融合方法、系统及存储介质
CN112052906B (zh) * 2020-09-14 2024-02-02 南京大学 一种基于指针网络的图像描述优化方法

Also Published As

Publication number Publication date
CN112818159A (zh) 2021-05-18

Similar Documents

Publication Publication Date Title
CN112818159B (zh) 一种基于生成对抗网络的图像描述文本生成方法
CN111832501B (zh) 一种面向卫星在轨应用的遥感影像文本智能描述方法
CN110188358B (zh) 自然语言处理模型的训练方法及装置
CN111858931B (zh) 一种基于深度学习的文本生成方法
CN110046252B (zh) 一种基于注意力机制神经网络与知识图谱的医疗文本分级方法
CN109214006B (zh) 图像增强的层次化语义表示的自然语言推理方法
CN113535953B (zh) 一种基于元学习的少样本分类方法
CN111046178B (zh) 一种文本序列生成方法及其系统
CN111597340A (zh) 一种文本分类方法及装置、可读存储介质
CN117475038B (zh) 一种图像生成方法、装置、设备及计算机可读存储介质
CN116110022B (zh) 基于响应知识蒸馏的轻量化交通标志检测方法及系统
CN114781375A (zh) 一种基于bert与注意力机制的军事装备关系抽取方法
CN111242059B (zh) 基于递归记忆网络的无监督图像描述模型的生成方法
CN114692732A (zh) 一种在线标签更新的方法、系统、装置及存储介质
CN112116685A (zh) 基于多粒度奖励机制的多注意力融合网络的图像字幕生成方法
CN112560438A (zh) 一种基于生成对抗网络的文本生成方法
CN111428518B (zh) 一种低频词翻译方法及装置
CN114332565A (zh) 一种基于分布估计的条件生成对抗网络文本生成图像方法
CN112926655B (zh) 一种图像内容理解与视觉问答vqa方法、存储介质和终端
Gangadhar et al. Analysis of optimization algorithms for stability and convergence for natural language processing using deep learning algorithms
CN116958548A (zh) 基于类别统计驱动的伪标签自蒸馏语义分割方法
CN114328921B (zh) 一种基于分布校准的小样本实体关系抽取方法
CN114936723A (zh) 一种基于数据增强的社交网络用户属性预测方法及系统
CN114925658A (zh) 开放性文本生成方法以及存储介质
KR20240034804A (ko) 자동 회귀 언어 모델 신경망을 사용하여 출력 시퀀스 평가

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant