CN110659727B - A Sketch-Based Image Generation Method - Google Patents
A Sketch-Based Image Generation Method Download PDFInfo
- Publication number
- CN110659727B CN110659727B CN201910909387.5A CN201910909387A CN110659727B CN 110659727 B CN110659727 B CN 110659727B CN 201910909387 A CN201910909387 A CN 201910909387A CN 110659727 B CN110659727 B CN 110659727B
- Authority
- CN
- China
- Prior art keywords
- sketch
- module
- training sample
- map
- attention
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 41
- 238000012549 training Methods 0.000 claims abstract description 69
- 238000005070 sampling Methods 0.000 claims abstract description 13
- 238000013507 mapping Methods 0.000 claims description 28
- 230000004044 response Effects 0.000 claims description 21
- 238000012545 processing Methods 0.000 claims description 10
- 230000008569 process Effects 0.000 claims description 4
- 230000006870 function Effects 0.000 description 31
- 230000009286 beneficial effect Effects 0.000 description 4
- 238000010586 diagram Methods 0.000 description 3
- 239000011159 matrix material Substances 0.000 description 3
- 238000013519 translation Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T3/00—Geometric image transformations in the plane of the image
- G06T3/04—Context-preserving transformations, e.g. by using an importance map
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
一种基于草图的图像生成方法,包括:S1,构建对抗生成模型,对抗生成模型包括生成器和判别器,其中,生成器包括下采样掩码残差模块、残差模块、上采样掩码残差模块和条件自注意力模块,判别器包括一个以上不同深度的子判别网络;S2,将一个及以上训练样本草图输入生成器,以生成每一训练样本草图对应的生成图像;S3,将训练样本草图对应的真实图像以及生成图像输入判别器,以计算出损失函数,并根据损失函数计算训练目标函数;S4,根据训练目标函数训练生成器和判别器的参数,以使得训练目标函数降到最小;S5,利用训练后的生成器来生成目标草图对应的目标生成图像。该方法确保生成的人脸图像具有真实的局部纹理以及完整的人脸结构。
A sketch-based image generation method, comprising: S1, constructing an adversarial generative model, and the adversarial generative model includes a generator and a discriminator, wherein the generator includes a down-sampling mask residual module, a residual module, and an up-sampling mask residual module Difference module and conditional self-attention module, the discriminator includes more than one sub-discriminatory network of different depths; S2, input one or more training sample sketches into the generator to generate the generated image corresponding to each training sample sketch; S3, train the training sample sketch The real image corresponding to the sample sketch and the generated image are input to the discriminator to calculate the loss function, and the training objective function is calculated according to the loss function; S4, the parameters of the generator and the discriminator are trained according to the training objective function, so that the training objective function is reduced to Minimum; S5, use the trained generator to generate the target generated image corresponding to the target sketch. This method ensures that the generated face images have real local textures and complete face structures.
Description
技术领域technical field
本公开涉及图像处理技术领域,具体地,涉及一种基于草图的图像生成方法。The present disclosure relates to the technical field of image processing, and in particular, to a sketch-based image generation method.
背景技术Background technique
基于草图的图像生成是一种图像到图像翻译的特例,在计算机图形学中有重要作用。早期基于草图的图像生成算法采用搜索融合的方法,先从大规模的图像数据库中搜索草图相关的图像块,再对图像块进行融合。近年来随着深度学习的迅猛发展,生成对抗网络越来越多地被应用在图像到图像翻译中。Isola等人提出基于有监督训练的条件生成对抗网络的图像到图像翻译通用模型,这个模型只针对稠密图像,对于输入图像是草图这类稀疏图像的生成效果不能令人满意。Sketch-based image generation is a special case of image-to-image translation that plays an important role in computer graphics. Early sketch-based image generation algorithms used the method of search fusion, which first searched for sketch-related image patches from a large-scale image database, and then fused the image patches. With the rapid development of deep learning in recent years, generative adversarial networks have been increasingly used in image-to-image translation. Isola et al. proposed a general model for image-to-image translation based on supervised training conditional generative adversarial networks. This model is only for dense images, and the generation effect of sparse images such as sketches is not satisfactory.
由于草图具有多样性、抽象性、稀疏性等特点,基于草图的人脸图像生成面临极大的挑战性。现有方法仍然不能生成理想的人脸图像,尤其是在当输入的草图没有包含完整的人脸结构(眼睛、鼻子和嘴巴等)的时候,生成的人脸图像往往也对应地缺失部分人脸结构。这是由于现有方法的网络模型都是基于卷积层,而卷积层的感知野十分有限,需要通过多层的卷积层才能达到全局的感知野,但是多层的卷积层可能使得网络不能很容易地学习到全局的结构信息;此外,现有方法的生成对抗网络使用的真实感判别器是对图像块进行局部判别,只能确保生成图像的局部真实感,并不能直接判别人脸图像的结构完整性。Due to the diversity, abstraction, and sparseness of sketches, sketch-based face image generation faces great challenges. Existing methods still cannot generate ideal face images, especially when the input sketches do not contain complete face structures (eyes, nose, mouth, etc.), the generated face images often have correspondingly missing parts of the face structure. This is because the network models of the existing methods are all based on convolutional layers, and the perceptual field of the convolutional layer is very limited, and the global perceptual field needs to be achieved through multi-layered convolutional layers, but the multi-layered convolutional layers may make The network cannot easily learn the global structural information; in addition, the realism discriminator used by the generative adversarial network of the existing method is to locally discriminate the image blocks, which can only ensure the local realism of the generated image, and cannot directly discriminate others. Structural integrity of face images.
发明内容SUMMARY OF THE INVENTION
(一)要解决的技术问题(1) Technical problems to be solved
本公开鉴于上述问题,提供了一种基于草图的图像生成方法,通过在条件生成对抗网络中引入自注意力机制以及多尺度判别器,以解决以上技术问题。In view of the above problems, the present disclosure provides a sketch-based image generation method, which solves the above technical problems by introducing a self-attention mechanism and a multi-scale discriminator into a conditional generative adversarial network.
(二)技术方案(2) Technical solutions
本公开提供了一种基于草图的图像生成方法,包括:S1,构建对抗生成模型,所述对抗生成模型包括生成器和判别器,其中,所述生成器包括下采样掩码残差模块、残差模块、上采样掩码残差模块和条件自注意力模块,所述判别器包括一个以上不同深度的子判别网络;S2,将一个及以上训练样本草图输入所述生成器,以生成每一所述训练样本草图对应的生成图像;S3,将所述训练样本草图及其对应的真实图像和生成图像输入所述判别器,以计算出损失函数,并根据所述损失函数计算训练目标函数;S4,根据所述训练目标函数训练所述生成器和判别器的参数,以使得所述训练目标函数降到最小;S5,利用所述训练后的生成器来生成目标草图对应的目标生成图像。The present disclosure provides a sketch-based image generation method, including: S1, constructing an adversarial generative model, where the adversarial generative model includes a generator and a discriminator, wherein the generator includes a downsampling mask residual module, a residual A difference module, an upsampling mask residual module and a conditional self-attention module, the discriminator includes more than one sub-discriminatory network of different depths; S2, one or more training sample sketches are input into the generator to generate each The generated image corresponding to the training sample sketch; S3, input the training sample sketch and its corresponding real image and generated image to the discriminator to calculate a loss function, and calculate a training objective function according to the loss function; S4, the parameters of the generator and the discriminator are trained according to the training objective function, so that the training objective function is minimized; S5, the trained generator is used to generate the target generation image corresponding to the target sketch.
可选地,所述下采样掩码残差模块与所述上采样掩码残差模块的数量均为N,所述下采样掩码残差模块、残差模块、前N-1个上采样掩码残差模块、条件自注意力模块以及最后一个上采样掩码残差模块依次相连,第一个至第N-1个下采样掩码残差模块的输出还分别连接至第N个至第二个上采样掩码残差模块,所述步骤S2包括:将所述训练样本草图输入第一个下采样掩码残差模块,依次经所述下采样掩码残差模块、残差模块、前N-1个上采样掩码残差模块、条件自注意力模块以及最后一个上采样掩码残差模块处理后,最后一个上采样掩码残差模块输出每一所述训练样本草图对应的生成图像。Optionally, the number of the downsampling mask residual module and the upsampling mask residual module is N, the downsampling mask residual module, the residual module, the first N-1 upsampling modules. The mask residual module, the conditional self-attention module, and the last upsampling mask residual module are connected in sequence, and the outputs of the first to N-1th downsampling mask residual modules are also connected to the Nth to Nth ones, respectively. For the second up-sampling mask residual module, the step S2 includes: inputting the training sample sketch into the first down-sampling mask residual module, and sequentially passing through the down-sampling mask residual module and the residual module , the first N-1 upsampling mask residual modules, the conditional self-attention module and the last upsampling mask residual module are processed, and the last upsampling mask residual module outputs the corresponding training sample sketch. generated image.
可选地,所述条件自注意力模块的处理包括:将接收到的输入特征图与所述训练样本草图串联,得到条件特征;根据三个映射矩阵分别对所述条件特征进行映射,以得到三个映射特征图,所述映射矩阵由可训练参数构成;对所述三个映射特征图进行处理,以得到响应图;将所述响应图与所述输入特征图相加,以得到输出特征图。Optionally, the processing of the conditional self-attention module includes: concatenating the received input feature map with the training sample sketch to obtain conditional features; respectively mapping the conditional features according to three mapping matrices to obtain three mapping feature maps, the mapping matrix is composed of trainable parameters; processing the three mapping feature maps to obtain response maps; adding the response maps to the input feature maps to obtain output features picture.
可选地,所述映射特征图分别为:f([a,x])=Wf[a,x]、g([a,x])=Wg[a,x]、h([a,x])=Wh[a,x],其中,a为所述输入特征图,x为所述训练样本草图按特征图同等分辨率的缩放图,f([a,x])、g([a,x])、h([a,x])分别为所述三个映射特征图,Wf、Wg、Wh是所述映射矩阵,a∈RC×H×W,x∈R1×H×W,Wf∈RD×(C+1),Wg∈RD×(C+1),Wh∈RC×(C+11),D=C/8,C、H、W分别为所述输入特征图的通道数、高和宽。Optionally, the mapping feature maps are respectively: f ([a,x])=Wf[a,x], g ([a,x])=Wg[a,x], h([a , x])=W h [a, x], where a is the input feature map, x is the scaled image of the training sample sketch at the same resolution as the feature map, f([a, x]), g ([a, x]), h([a, x]) are the three mapping feature maps respectively, W f , W g , W h are the mapping matrices, a∈R C×H×W , x ∈R 1×H×W , W f ∈R D×(C+1) , W g ∈R D×(C+1) , W h ∈R C×(C+11) , D=C/8, C, H, and W are the channel number, height, and width of the input feature map, respectively.
可选地,所述对所述三个映射特征图进行处理,以得到响应图包括:对所述映射特征图f([a,x])与映射特征图g([a,x])进行处理,以得到注意力图;对所述注意力图与映射特征图h([a,x])进行处理,以得到所述响应图。Optionally, the processing of the three mapping feature maps to obtain a response map includes: performing the mapping feature map f([a, x]) and the mapping feature map g([a, x]) on process to obtain the attention map; process the attention map and the mapping feature map h([a, x]) to obtain the response map.
可选地,所述响应图为:r=(r1,r2,……,rN)∈RC×N,其中,r为所述响应图,N=H×W, si,j=f([a,x])Tg([a,x])。Optionally, the response graph is: r=(r 1 , r 2 ,...,r N )∈R C×N , where r is the response graph, N=H×W, s i,j = f([a, x]) T g([a, x]).
可选地,所述输出特征图为:oj=γri+aj,其中,γ为可训练权重参数,初始值为0,oj为所述输出特征图的第j个像素,rj为所述响应图的第j个像素,aj为所述条件自注意力模块接收到的输入特征图的第j个像素。Optionally, the output feature map is: o j =γr i +a j , where γ is a trainable weight parameter, the initial value is 0, o j is the jth pixel of the output feature map, r j is the jth pixel of the response map, and a j is the jth pixel of the input feature map received by the conditional self-attention module.
可选地,所述判别器由一层以上的卷积层组成,不同子判别网络的卷积层数不同,各个子判别网络的超参数相同。Optionally, the discriminator is composed of more than one convolutional layer, the number of convolutional layers of different sub-discriminant networks is different, and the hyperparameters of each sub-discriminant network are the same.
可选地,所述损失函数包括对抗损失函数Ladv(G;D)、重建损失函数LL1(G)和特征匹配损失函数Lfm(G),其中:Optionally, the loss function includes an adversarial loss function L adv (G; D), a reconstruction loss function L L1 (G) and a feature matching loss function L fm (G), wherein:
所述训练目标函数为其中,x为所述训练样本草图,y为所述真实图像,ND为所述子判别网络的个数,G(x)为所述训练样本草图对应的生成图像,Dk(x,y)为第k个子判别网络根据训练样本草图和真实图像得到的输出,Dk(x,G(x))为第k个子判别网络根据训练样本草图和生成图像得到的输出,为在(x,y)的数据分布上求期望,为在x的数据分布上求期望,Q为选出来的子判别网络的特征层的集合,NQ为每个子判别网络的元素个数,nq为第q层特征层的元素个数,为第k个子判别网络根据生成图像的第q层的中间输出特征图,为第k个子判别网络根据真实图像的第q层的中间输出特征图,λ为重建损失函数LL1(G)的权重,μ为特征匹配损失函数Lfm(G)的权重。The training objective function is Wherein, x is the sketch of the training sample, y is the real image, N D is the number of the sub-discriminatory networks, G(x) is the generated image corresponding to the sketch of the training sample, D k (x, y ) is the output obtained by the kth sub-discriminant network according to the training sample sketch and the real image, D k (x, G(x)) is the output obtained by the kth sub-discriminant network according to the training sample sketch and the generated image, To find the expectation on the data distribution of (x, y), In order to find the expectation on the data distribution of x, Q is the set of feature layers of the selected sub-discriminant network, N Q is the number of elements of each sub-discriminant network, n q is the number of elements of the qth layer feature layer, is the intermediate output feature map of the qth layer of the generated image for the kth sub-discriminant network, is the intermediate output feature map of the kth sub-discriminant network according to the qth layer of the real image, λ is the weight of the reconstruction loss function L L1 (G), and μ is the weight of the feature matching loss function L fm (G).
可选地,所述步骤S4包括:根据所述训练目标函数训练所述对抗生成模型中自注意力模块以外的参数;固定所述对抗生成模型中自注意力模块以外的参数,训练所述自注意力模块的参数;同时训练所述对抗生成模型中的参数,以使得所述目标训练函数降到最小。Optionally, the step S4 includes: training parameters other than the self-attention module in the confrontation generation model according to the training objective function; fixing parameters other than the self-attention module in the confrontation generation model, training the self-attention module. The parameters of the attention module; meanwhile, the parameters in the adversarial generative model are trained to minimize the target training function.
(三)有益效果(3) Beneficial effects
本公开提供的基于草图的图像生成方法,具有以下有益效果:The sketch-based image generation method provided by the present disclosure has the following beneficial effects:
(1)通过在条件生成对抗网络中引入条件自注意力模块,可以直接学习到图像的长距离依赖;(1) By introducing a conditional self-attention module into the conditional generative adversarial network, the long-distance dependencies of images can be directly learned;
(2)通过在条件生成对抗网络中引入多尺度判别器,可以在不同尺度上对生成图像的真实感进行判别,以确保生成图像的局部纹理和细节具有真实感,并当草图不完整时确保生成图像的结构完整;(2) By introducing a multi-scale discriminator into the conditional generative adversarial network, the authenticity of the generated image can be discriminated at different scales to ensure that the local texture and details of the generated image are realistic, and when the sketch is incomplete The structure of the generated image is complete;
(3)通过将各子判别网络共享前面几层的参数,可以减少网络的参数个数,并且有利于网络收敛。(3) By sharing the parameters of the previous layers in each sub-discrimination network, the number of parameters of the network can be reduced, and it is beneficial to the network convergence.
附图说明Description of drawings
图1示意性示出了本公开实施例提供的基于草图的图像生成方法的流程图;FIG. 1 schematically shows a flowchart of a sketch-based image generation method provided by an embodiment of the present disclosure;
图2A示意性示出了本公开实施例提供的基于草图的图像生成方法中构建对抗生成模型的生成器的结构示意图;2A schematically shows a schematic structural diagram of a generator for constructing an adversarial generation model in the sketch-based image generation method provided by an embodiment of the present disclosure;
图2B示意性示出了本公开实施例提供的基于草图的图像生成方法中构建对抗生成模型的判别器的结构示意图;FIG. 2B schematically shows a schematic structural diagram of a discriminator for constructing an adversarial generation model in the sketch-based image generation method provided by an embodiment of the present disclosure;
图3示意性示出了图2A所示生成器中条件自注意力模块的示意图。Figure 3 schematically shows a schematic diagram of the conditional self-attention module in the generator shown in Figure 2A.
具体实施方式Detailed ways
为使本公开的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本公开进一步详细说明。In order to make the objectives, technical solutions and advantages of the present disclosure clearer, the present disclosure will be further described in detail below with reference to the specific embodiments and the accompanying drawings.
本实施例提供了一种基于草图的图像生成方法,参阅图1,结合图2A、图2B和图3,对图1所示方法进行详细说明,该方法包括以下操作。This embodiment provides a sketch-based image generation method. Referring to FIG. 1 , the method shown in FIG. 1 is described in detail with reference to FIGS. 2A , 2B and 3 , and the method includes the following operations.
S1,构建对抗生成模型,对抗生成模型包括生成器和判别器,其中,生成器包括下采样掩码残差模块、残差模块、上采样掩码残差模块和条件自注意力模块,判别器包括一个以上不同深度的子判别网络。S1, build an adversarial generative model. The adversarial generative model includes a generator and a discriminator. The generator includes a downsampling mask residual module, a residual module, an upsampling mask residual module and a conditional self-attention module. The discriminator Include more than one sub-discriminative network of different depths.
本实施例中,生成器是一个“编码-解码”结构的网络。编码部分包括N个下采样掩码残差模块,解码部分包括N个上采样掩码残差模块,最后一个上采样掩码残差模块之前插入有条件自注意力模块。条件自注意力模块可以有效学习到输入特征图中的长距离依赖信息,帮助网络学习到图像的结构信息。编码部分和解码部分之间还可以加入若干个残差模块,以增强网络的容量和拟合能力。In this embodiment, the generator is a network with an "encoding-decoding" structure. The encoding part includes N downsampling mask residual modules, the decoding part includes N upsampling mask residual modules, and a conditional self-attention module is inserted before the last upsampling mask residual module. The conditional self-attention module can effectively learn the long-distance dependency information in the input feature map and help the network learn the structural information of the image. Several residual modules can also be added between the encoding part and the decoding part to enhance the capacity and fitting ability of the network.
具体地,参阅图2A,下采样掩码残差模块与上采样掩码残差模块的数量相同,均为N,下采样掩码残差模块、残差模块、前N-1个上采样掩码残差模块、条件自注意力模块以及最后一个上采样掩码残差模块依次相连,第一个至第N-1个下采样掩码残差模块的输出还分别连接至第N个至第二个上采样掩码残差模块,即解码部分和编码部分采用跳转连接,以确保低层信息(即下采样模块输出信息)可以直接传递到高层(即上采样模块)。本公开实施例中,依次相连是指前一个模块的输出连接至后一个模块的输入,即将前一个模块的输出特征图作为后一个模块的输入特征图。Specifically, referring to FIG. 2A, the number of downsampling mask residual modules and upsampling mask residual modules is the same, both are N, the downsampling mask residual module, the residual module, and the first N-1 upsampling mask residual modules. The code residual module, the conditional self-attention module and the last upsampling mask residual module are connected in turn, and the outputs of the first to N-1th downsampling mask residual modules are also connected to the Nth to Nth The two upsampling mask residual modules, namely the decoding part and the encoding part, are connected by jumps to ensure that the lower layer information (ie the output information of the downsampling module) can be directly transferred to the high layer (ie the upsampling module). In the embodiment of the present disclosure, being connected in sequence means that the output of the previous module is connected to the input of the next module, that is, the output feature map of the previous module is used as the input feature map of the latter module.
可以理解的是,本公开实施例的对抗生成模型中,也可以不包括残差模块,即残差模块的数量为0。残差模块的数量越多,生成器的拟合能力越强,但是残差模块过多时会影响对抗生成模型的训练收敛速度和计算速度,因此,优选地,将残差模块的数量设置在1-8之内。It can be understood that, the confrontation generation model in the embodiment of the present disclosure may also not include residual modules, that is, the number of residual modules is 0. The greater the number of residual modules, the stronger the fitting ability of the generator. However, when there are too many residual modules, it will affect the training convergence speed and calculation speed of the adversarial generation model. Therefore, it is preferable to set the number of residual modules to 1 within -8.
本公开实施例中,以N=6,残差模块的数量为8为例,示意性说明生成器的参数,以训练样本草图的尺寸为256×256×1为例,示意性说明生成器每一模块的输出尺寸,如表1所示。同时,本领域技术人员可以根据本实施例的描述得到其它数量及参数的下采样掩码残差模块、上采样掩码残差模块、残差模块、条件自注意力模块组成的生成器。In the embodiment of the present disclosure, taking N=6 and the number of residual modules as 8 as an example, the parameters of the generator are schematically illustrated. Taking the size of the training sample sketch as 256×256×1 as an example, it is schematically illustrated that every The output size of a module is shown in Table 1. Meanwhile, those skilled in the art can obtain generators composed of down-sampling mask residual modules, up-sampling mask residual modules, residual modules, and conditional self-attention modules with other numbers and parameters according to the description of this embodiment.
表1Table 1
判别器的输入是训练样本草图对应的生成图像和真实图像,输出是对输入图像的真实感的判别。参阅图2B,根据本公开的实施例,判别器由多个子判别网络组成,每个子判别网络有三个以上的卷积层,各个子判别网络中前三个卷积层的权值共享。各个子判别网络的超参数相同,仅卷积层数不同。超参数为人为预设的参数,例如卷积核尺寸、步长等。由于各子判别网络的深度不同,所以其输出特征图每个像素对应输入图像的感知野不同,因此不同深度的子判别网络分别判别生成图像在不同尺度下的真实感;而且由于各子判别网络的底层特征是一致的,各子判别网络共享前面几层的参数,这样可以减少网络的参数个数,并且有利于网络收敛。The input of the discriminator is the generated image and the real image corresponding to the training sample sketch, and the output is the judgment of the realism of the input image. Referring to FIG. 2B , according to an embodiment of the present disclosure, the discriminator is composed of a plurality of sub-discriminatory networks, each of which has more than three convolutional layers, and the weights of the first three convolutional layers in each sub-discriminative network are shared. The hyperparameters of each sub-discriminative network are the same, and only the number of convolutional layers is different. Hyperparameters are preset parameters, such as convolution kernel size, stride, etc. Since the depths of each sub-discriminant network are different, each pixel of the output feature map corresponds to a different perceptual field of the input image, so the sub-discriminatory networks of different depths respectively discriminate the realism of the generated image at different scales; The underlying features of the network are consistent, and each sub-discriminatory network shares the parameters of the previous layers, which can reduce the number of network parameters and is conducive to network convergence.
本公开实施例中,判别器是多尺度判别器,其各个子判别网络的结构一致,深度不同。以最大深度的子判别网络包含8个卷积模块为例,示意性说明多尺度判别器的结构,如表2所示,表2中仅示出最大深度的子判别网络。表2所示实施例中卷积模块1、卷积模块2、卷积模块3的参数被各个子判别网络共享,在此基础上添加其他卷积模块以形成子判别网络。实际应用中,可根据需求选择子判别网络的个数,例如仅设置两个子判别网络,分别为卷积模块1、2、3、4构成的子判别网络和卷积模块1、2、3、4、5、6、7、8构成的子判别网络。In the embodiment of the present disclosure, the discriminator is a multi-scale discriminator, and each sub-discrimination network has the same structure and different depths. Taking the sub-discriminant network with the maximum depth as an example, which includes 8 convolution modules, the structure of the multi-scale discriminator is schematically illustrated, as shown in Table 2, which only shows the sub-discriminatory network with the maximum depth. In the embodiment shown in Table 2, the parameters of convolution module 1, convolution module 2, and convolution module 3 are shared by each sub-discriminant network, and on this basis, other convolution modules are added to form a sub-discriminant network. In practical applications, the number of sub-discriminant networks can be selected according to requirements. For example, only two sub-discriminant networks are set up, which are the sub-discriminatory networks composed of convolution modules 1, 2, 3, and 4, and the convolution modules 1, 2, 3, and 3 respectively. The sub-discriminant network composed of 4, 5, 6, 7, and 8.
表2Table 2
S2,将一个及以上训练样本草图输入生成器,以生成每一训练样本草图对应的生成图像。S2, inputting one or more training sample sketches into the generator to generate a generated image corresponding to each training sample sketch.
具体地,将训练样本草图输入第一个下采样掩码残差模块,依次经下采样掩码残差模块、残差模块、前N-1个上采样掩码残差模块、条件自注意力模块以及最后一个上采样掩码残差模块处理后,最后一个上采样掩码残差模块输出每一训练样本草图对应的生成图像。Specifically, the training sample sketch is input into the first downsampling mask residual module, followed by the downsampling mask residual module, the residual module, the first N-1 upsampling mask residual modules, and the conditional self-attention After processing by the module and the last upsampling mask residual module, the last upsampling mask residual module outputs the generated image corresponding to each training sample sketch.
参阅图3,条件自注意力模块的处理过程为:将接收到的输入特征图与训练样本草图串联,得到条件特征;根据三个映射矩阵对条件特征进行映射,以得到三个映射特征图,该映射矩阵由大量可训练参数构成;对三个映射特征图进行处理,以得到响应图;将响应图逐元素添加至输入特征图,以得到输出特征图。Referring to Fig. 3, the processing process of the conditional self-attention module is: concatenate the received input feature map with the training sample sketch to obtain the conditional feature; map the conditional feature according to the three mapping matrices to obtain three mapping feature maps, The mapping matrix consists of a large number of trainable parameters; the three mapping feature maps are processed to obtain a response map; the response map is added element-wise to the input feature map to obtain an output feature map.
具体地,条件自注意力模块接收到的输入特征图a为第N-1个上采样模块的输出特征图以及第一个下采样模块的输出特征图,a∈RC×H×W。首先,需要修改训练样本草图的尺寸,以使得训练样本草图x的尺寸为x∈R1×H×W,并在通道方向串联,以得到条件特征[a,x]。Specifically, the input feature map a received by the conditional self-attention module is the output feature map of the N-1th upsampling module and the output feature map of the first downsampling module, a∈R C×H×W . First, the size of the training sample sketch needs to be modified so that the size of the training sample sketch x is x∈R 1×H×W , and concatenated in the channel direction to obtain the conditional features [a, x].
根据三个映射矩阵对条件特征进行映射,以将条件特征映射到三个新的特征空间上,得到三个映射特征图,分别为:The conditional features are mapped according to the three mapping matrices to map the conditional features to three new feature spaces, and three mapped feature maps are obtained, which are:
f([a,x])=Wf[a,x]f([a,x])=W f [a,x]
g([a,x])=Wg[a,x]g([a,x])=W g [a,x]
h([a,x])=Wh[a,x]h([a,x])=W h [a,x]
其中,a为输入特征图,x为训练样本草图按特征图同等分辨率的缩放图,f([a,x])、g([a,x])、h([a,x])分别为三个映射特征图,Wf、Wg、Wh是映射矩阵,a∈RC×H×W,x∈R1 ×H×W,Wf∈RD×(C+1),Wg∈RD×(C+1),Wh∈RC×(C+1),D=C/8,C、H、W分别为输入特征图的通道数、高和宽。本公开实施例中,不是显式使用矩阵来实现Wf、Wg、Wh,而是分别使用1×1的卷积层来实现Wf、Wg、Wh。Among them, a is the input feature map, x is the scaled image of the training sample sketch at the same resolution as the feature map, f([a, x]), g([a, x]), h([a, x]) respectively are three mapping feature maps, W f , W g , W h are mapping matrices, a∈R C×H×W , x∈R 1 ×H×W , W f ∈R D×(C+1) , W g ∈ R D×(C+1) , W h ∈ R C×(C+1) , D=C/8, C, H, and W are the number of channels, height and width of the input feature map, respectively. In the embodiment of the present disclosure, instead of explicitly using a matrix to implement W f , W g , and W h , a 1×1 convolution layer is used to implement W f , W g , and W h , respectively.
进一步地,对映射特征图f([a,x])与映射特征图g([a,x])进行处理后,得到注意力图,并对注意力图和映射特征图g([a,x])进行处理后,得到响应图。Further, after processing the mapped feature map f([a, x]) and the mapped feature map g([a, x]), the attention map is obtained, and the attention map and the mapped feature map g([a, x] ) to get the response graph.
具体地,将映射特征图f([a,x])转置后与映射特征图g([a,x])相乘,并进行归一化处理以得到注意力图,并将注意力图与映射特征图h([a,x])相乘,得到响应图r。Specifically, the mapped feature map f([a, x]) is transposed and multiplied by the mapped feature map g([a, x]), and normalized to obtain the attention map, and the attention map is combined with the map The feature map h([a, x]) is multiplied to get the response map r.
r=(r1,r2,……,rN)∈RC×N r=(r 1 , r 2 ,...,r N )∈R C×N
其中,N=H×W, si,j=f([a,x])Tg([a,x])。令B=RN×N为注意力图,B的每个元素记为bi,j,表示在合成响应图的第j个像素的时候,当前映射特征图h([a,x])第i个个像素的权重。Among them, N=H×W, s i,j = f([a, x]) T g([a, x]). Let B=R N×N be the attention map, and each element of B is denoted as bi ,j , indicating that when the jth pixel of the response map is synthesized, the current mapping feature map h([a, x]) The ith pixel The weight of each pixel.
进一步地,将响应图逐元素添加至输入特征图上,构成一个残差的结构,得到条件自注意力模块的输出特征图:Further, the response map is added to the input feature map element by element to form a residual structure, and the output feature map of the conditional self-attention module is obtained:
oj=γrj+dj o j =γr j +d j
其中,γ为可训练权重参数,初始值为0,oj为输出特征图的第j个像素,rj为响应图的第j个像素,aj为条件自注意力模块接收到的输入特征图的第j个像素。Among them, γ is the trainable weight parameter, the initial value is 0, o j is the j-th pixel of the output feature map, r j is the j-th pixel of the response map, and a j is the input feature received by the conditional self-attention module The jth pixel of the graph.
通过上述条件自注意力模块,生成器即可逐步学习到图像的长距离依赖。Through the above conditional self-attention module, the generator can gradually learn the long-range dependencies of images.
S3,将训练样本草图及其对应的真实图像和生成图像输入判别器,以计算出损失函数,并根据损失函数计算训练目标函数。S3, input the training sample sketch and its corresponding real image and generated image into the discriminator to calculate the loss function, and calculate the training objective function according to the loss function.
本公开实施例中,损失函数包括对抗损失函数Ladv(G;D)、重建损失函数LL1(G)和特征匹配损失函数Lfm(G),其中:In the embodiment of the present disclosure, the loss function includes an adversarial loss function L adv (G; D), a reconstruction loss function L L1 (G), and a feature matching loss function L fm (G), wherein:
x为训练样本草图,y为真实图像,ND为子判别网络的个数,G(x)为训练样本草图对应的生成图像,Dk(x,y)为第k个子判别网络根据训练样本草图和真实图像得到的输出,Dk(x,G(x))为第k个子判别网络根据训练样本草图和生成图像得到的输出,为在(x,y)的数据分布上求期望,为在x的数据分布上求期望,Q为选出来的子判别网络的特征层的集合,NQ为每个子判别网络的元素个数,nq为第q层特征层的元素个数,为第k个子判别网络根据生成图像的第q层的中间输出特征图,为第k个子判别网络根据真实图像的第q层的中间输出特征图。x is the training sample sketch, y is the real image, N D is the number of sub-discriminant networks, G(x) is the generated image corresponding to the training sample sketch, D k (x, y) is the kth sub-discriminant network according to the training sample The output obtained from the sketch and the real image, D k (x, G(x)) is the output obtained by the kth sub-discriminant network based on the training sample sketch and the generated image, To find the expectation on the data distribution of (x, y), In order to find the expectation on the data distribution of x, Q is the set of feature layers of the selected sub-discriminant network, N Q is the number of elements of each sub-discriminant network, n q is the number of elements of the qth layer feature layer, is the intermediate output feature map of the qth layer of the generated image for the kth sub-discriminant network, is the intermediate output feature map of the qth layer of the real image for the kth sub-discriminative network.
对抗损失函数Ladv(G;D)用于使得生成对抗网络(即对抗生成模型)完成对抗训练,以确保生成图像的真实感。重建损失函数LL1(G)用于使得生成图像向真实图像靠近。特征匹配损失函数Lfm(G)用于使得生成图像在特征空间上向真实图像靠近。The adversarial loss function La adv (G; D) is used to make the generative adversarial network (ie, the adversarial generative model) complete the adversarial training to ensure the realism of the generated images. The reconstruction loss function L L1 (G) is used to make the generated image closer to the real image. The feature matching loss function L fm (G) is used to make the generated image close to the real image in the feature space.
进一步地,还可以根据Ladv(G;D)、LL1(G)和Lfm(G)计算对抗生成模型的训练目标函数:Further, the training objective function of the adversarial generative model can also be calculated according to La adv (G; D), L L1 (G) and L fm (G):
其中,λ和μ分别为重建损失函数LL1(G)和特征匹配损失函数Lfm(G)的权重,实验中例如选取λ=100.0,μ=1.0。Among them, λ and μ are the weights of the reconstruction loss function L L1 (G) and the feature matching loss function L fm (G) respectively. For example, λ=100.0 and μ=1.0 are selected in the experiment.
S4,根据训练目标函数训练生成器和判别器的参数,以使得训练目标函数降到最小。S4, train the parameters of the generator and the discriminator according to the training objective function, so as to minimize the training objective function.
交替训练生成器和判别器的参数,即先固定判别器参数,训练生成器一个迭代;在固定生成器,训练判别器一个迭代,交替反复进行。The parameters of the generator and the discriminator are alternately trained, that is, the parameters of the discriminator are fixed first, and the generator is trained for one iteration; when the generator is fixed, the discriminator is trained for one iteration, which is repeated alternately.
进一步地,为了对抗生成模型更好地收敛,将操作S4分为三个阶段,分别包括以下操作:Further, in order to better converge against the generative model, operation S4 is divided into three stages, including the following operations:
根据训练目标函数训练对抗生成模型中自注意力模块以外的参数,即训练条件自注意力模块以外的模型;固定对抗生成模型中自注意力模块以外的参数,训练自注意力模块的参数,即只训练条件自注意力模块;同时训练对抗生成模型中的参数,以使得训练目标函数降到最小,即同时训练模型中的所有可训练参数。According to the training objective function, the parameters other than the self-attention module in the adversarial generative model are trained, that is, the model other than the conditional self-attention module is trained; the parameters other than the self-attention module in the adversarial generative model are fixed, and the parameters of the training self-attention module are Only the conditional self-attention module is trained; the parameters in the adversarial generative model are also trained to minimize the training objective function, that is, all trainable parameters in the model are simultaneously trained.
S5,利用训练后的生成器来生成目标草图对应的目标生成图像。S5, using the trained generator to generate the target generated image corresponding to the target sketch.
根据上述操作S1-S4完成本公开实施例中对抗生成模型的训练,即可根据该训练好的对抗生成模型中的生成器生成目标草图对应的目标生成图像。The training of the adversarial generative model in the embodiment of the present disclosure is completed according to the above-mentioned operations S1-S4, that is, the target generative image corresponding to the target sketch can be generated according to the generator in the trained adversarial generative model.
以下通过将本公开基于草图的图像生成方法与现有具有较好性能的方法pix2pix、SketchyGAN进行比较,说明本公开方法的优越性。The advantages of the method of the present disclosure will be explained below by comparing the sketch-based image generation method of the present disclosure with the existing methods pix2pix and SketchyGAN with better performance.
IS、FID、KID是三个常用的评价生成图像质量的指标。IS分数越高,生成图像的真实感和多样性越高;FID和KID数值越低,生成图像的真实感越高。表3客观地示出了本公开的方法与pix2pix、SketchyGAN的IS、FID、KID指标,可以看出本公开方法即使不使用条件自注意力模块或多尺度判别器时,其性能仍优于pix2pix、SketchyGAN,当同时使用条件自注意力模块和多尺度判别器时,其性能远远优于pix2pix、SketchyGAN。由此,也说明了本公开中条件自注意力模块和多尺度判别器均可以有效增强对抗生成模型的生成效果。IS, FID, KID are three commonly used indicators to evaluate the quality of generated images. The higher the IS score, the higher the realism and variety of the generated images; the lower the FID and KID values, the higher the realism of the generated images. Table 3 objectively shows the IS, FID, and KID indicators of the disclosed method and pix2pix and SketchyGAN. It can be seen that the disclosed method still performs better than pix2pix even if the conditional self-attention module or multi-scale discriminator is not used. , SketchyGAN, when using both conditional self-attention module and multi-scale discriminator, its performance is far better than pix2pix, SketchyGAN. Thus, it is also illustrated that both the conditional self-attention module and the multi-scale discriminator in the present disclosure can effectively enhance the generation effect of the adversarial generative model.
表3table 3
综上所述,本公开提供的基于草图的图像生成方法,在条件生成对抗网络中引入条件自注意力模块以及多尺度判别器,条件自注意力模块使得对抗生成模型可以直接学习到图像的长距离依赖信息,多尺度判别器确保了生成图像的局部纹理的真实性和全局结构的完整性,使得对抗生成模型具有较强的鲁棒性。In summary, the sketch-based image generation method provided by the present disclosure introduces a conditional self-attention module and a multi-scale discriminator into the conditional generative adversarial network. The conditional self-attention module enables the adversarial generative model to directly learn the length of the image. Distance-dependent information, the multi-scale discriminator ensures the authenticity of the local texture and the integrity of the global structure of the generated image, making the adversarial generative model more robust.
以上所述的具体实施例,对本公开的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本公开的具体实施例而已,并不用于限制本公开,凡在本公开的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本公开的保护范围之内。The specific embodiments described above further describe the purpose, technical solutions and beneficial effects of the present disclosure in detail. It should be understood that the above-mentioned specific embodiments are only specific embodiments of the present disclosure, and are not intended to limit the present disclosure. Any modification, equivalent replacement, improvement, etc. made within the spirit and principle of the present disclosure should be included within the protection scope of the present disclosure.
Claims (10)
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910909387.5A CN110659727B (en) | 2019-09-24 | 2019-09-24 | A Sketch-Based Image Generation Method |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910909387.5A CN110659727B (en) | 2019-09-24 | 2019-09-24 | A Sketch-Based Image Generation Method |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110659727A CN110659727A (en) | 2020-01-07 |
CN110659727B true CN110659727B (en) | 2022-05-13 |
Family
ID=69039033
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910909387.5A Active CN110659727B (en) | 2019-09-24 | 2019-09-24 | A Sketch-Based Image Generation Method |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110659727B (en) |
Families Citing this family (21)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113313133A (en) * | 2020-02-25 | 2021-08-27 | 武汉Tcl集团工业研究院有限公司 | Training method for generating countermeasure network and animation image generation method |
CN111428761B (en) * | 2020-03-11 | 2023-03-28 | 深圳先进技术研究院 | Image feature visualization method, image feature visualization device and electronic equipment |
CN111382845B (en) * | 2020-03-12 | 2022-09-02 | 成都信息工程大学 | Template reconstruction method based on self-attention mechanism |
CN111489405B (en) * | 2020-03-21 | 2022-09-16 | 复旦大学 | Face Sketch Synthesis System Based on Conditional Augmented Generative Adversarial Networks |
CN111489287B (en) * | 2020-04-10 | 2024-02-09 | 腾讯科技(深圳)有限公司 | Image conversion method, device, computer equipment and storage medium |
CN113592724B (en) * | 2020-04-30 | 2025-06-10 | 北京金山云网络技术有限公司 | Method and device for repairing target face image |
CN111508069B (en) * | 2020-05-22 | 2023-03-21 | 南京大学 | Three-dimensional face reconstruction method based on single hand-drawn sketch |
CN112132172B (en) * | 2020-08-04 | 2025-01-17 | 绍兴埃瓦科技有限公司 | Model training method, device, equipment and medium based on image processing |
CN112070658B (en) * | 2020-08-25 | 2024-04-16 | 西安理工大学 | Deep learning-based Chinese character font style migration method |
CN112149802B (en) * | 2020-09-17 | 2022-08-09 | 广西大学 | Image content conversion method with consistent semantic structure |
CN112862110B (en) * | 2021-02-11 | 2024-01-30 | 脸萌有限公司 | Model generation method and device and electronic equipment |
CN112949553A (en) * | 2021-03-22 | 2021-06-11 | 陈懋宁 | Face image restoration method based on self-attention cascade generation countermeasure network |
CN112837215B (en) | 2021-03-31 | 2022-10-18 | 电子科技大学 | Image shape transformation method based on generation countermeasure network |
CN113205521A (en) * | 2021-04-23 | 2021-08-03 | 复旦大学 | Image segmentation method of medical image data |
CN113269256B (en) * | 2021-05-26 | 2024-08-27 | 广州密码营地信息科技有限公司 | Construction method and application of MiSrc-GAN medical image model |
CN113823296B (en) * | 2021-06-15 | 2024-11-22 | 腾讯科技(深圳)有限公司 | Voice data processing method, device, computer equipment and storage medium |
CN114140739A (en) * | 2021-09-04 | 2022-03-04 | 重庆大学 | Video deblurring method based on patch matching and synthesis |
CN113902130A (en) * | 2021-10-29 | 2022-01-07 | 中国工商银行股份有限公司 | Model training method, device and computer equipment |
CN114299218A (en) * | 2021-12-13 | 2022-04-08 | 吉林大学 | System for searching real human face based on hand-drawing sketch |
CN114399668A (en) * | 2021-12-27 | 2022-04-26 | 中山大学 | Method and device for generating natural images based on hand-drawn sketches and image sample constraints |
CN115147513A (en) * | 2022-07-14 | 2022-10-04 | 中山大学 | Sketch image translation method and system without network training |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109145992A (en) * | 2018-08-27 | 2019-01-04 | 西安电子科技大学 | Cooperation generates confrontation network and sky composes united hyperspectral image classification method |
CN109978165A (en) * | 2019-04-04 | 2019-07-05 | 重庆大学 | A kind of generation confrontation network method merged from attention mechanism |
-
2019
- 2019-09-24 CN CN201910909387.5A patent/CN110659727B/en active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109145992A (en) * | 2018-08-27 | 2019-01-04 | 西安电子科技大学 | Cooperation generates confrontation network and sky composes united hyperspectral image classification method |
CN109978165A (en) * | 2019-04-04 | 2019-07-05 | 重庆大学 | A kind of generation confrontation network method merged from attention mechanism |
Non-Patent Citations (2)
Title |
---|
Self-Attention Generative Adversarial Networks;Han Zhang et al.;《arXiv》;20190114;第1-10页 * |
SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis;Wengling Chen et al.;《arXiv》;20180412;第1-19页 * |
Also Published As
Publication number | Publication date |
---|---|
CN110659727A (en) | 2020-01-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110659727B (en) | A Sketch-Based Image Generation Method | |
CN113409191B (en) | Lightweight image super-resolution method and system based on attention feedback mechanism | |
CN106462724B (en) | Method and system based on normalized images verification face-image | |
CN110570522B (en) | Multi-view three-dimensional reconstruction method | |
CN110544297B (en) | Three-dimensional model reconstruction method for single image | |
WO2024060395A1 (en) | Deep learning-based high-precision point cloud completion method and apparatus | |
Xu et al. | Multi-view 3D shape recognition via correspondence-aware deep learning | |
CN113096239B (en) | Three-dimensional point cloud reconstruction method based on deep learning | |
CN114758152B (en) | A feature matching method based on attention mechanism and neighborhood consistency | |
CN115222998B (en) | Image classification method | |
CN114612902B (en) | Image semantic segmentation method, device, equipment, storage medium and program product | |
CN112347932B (en) | A 3D model recognition method based on point cloud-multi-view fusion | |
CN115512368B (en) | A cross-modal semantic image generation model and method | |
CN109447897B (en) | Real scene image synthesis method and system | |
CN112634438A (en) | Single-frame depth image three-dimensional model reconstruction method and device based on countermeasure network | |
CN113658322A (en) | Visual transform-based three-dimensional voxel reconstruction method | |
CN118136155A (en) | Drug target affinity prediction method based on multi-modal information fusion and interaction | |
CN113449612A (en) | Three-dimensional target point cloud identification method based on sub-flow sparse convolution | |
CN114693951A (en) | An RGB-D Saliency Object Detection Method Based on Global Context Information Exploration | |
Li et al. | Multi-view convolutional vision transformer for 3D object recognition | |
CN115601498A (en) | Single image three-dimensional reconstruction method based on RealPoin3D | |
CN118521482A (en) | Depth image guided super-resolution reconstruction network model | |
Liu et al. | Multilevel receptive field expansion network for small object detection | |
CN104036242A (en) | Object recognition method based on convolutional restricted Boltzmann machine combining Centering Trick | |
CN112529057A (en) | Graph similarity calculation method and device based on graph convolution network |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |