CN116403026A - 一种基于多类别风格提取的风格迁移方法 - Google Patents

一种基于多类别风格提取的风格迁移方法 Download PDF

Info

Publication number
CN116403026A
CN116403026A CN202310229019.2A CN202310229019A CN116403026A CN 116403026 A CN116403026 A CN 116403026A CN 202310229019 A CN202310229019 A CN 202310229019A CN 116403026 A CN116403026 A CN 116403026A
Authority
CN
China
Prior art keywords
style
adaptive
image
encoder
category
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310229019.2A
Other languages
English (en)
Inventor
胡志豪
徐蔚鸿
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Changsha University of Science and Technology
Original Assignee
Changsha University of Science and Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Changsha University of Science and Technology filed Critical Changsha University of Science and Technology
Priority to CN202310229019.2A priority Critical patent/CN116403026A/zh
Publication of CN116403026A publication Critical patent/CN116403026A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • Databases & Information Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于多类别风格提取的风格迁移方法,包括,内容编码器提取源数据域的内容特征;多类别风格提取编码器提取目的数据域不同类别的风格特征,并将多类别的风格特征进行统计;解码器针对内容特征和多类别风格特征进行特征融合,以此生成与目的数据域相似的样本;并利用对比学习方法使神经网络在风格迁移过程中,不会对源数据域的内容特征进行改变。并在判别器中添加谱归一化操作,以此减少网络的训练。本发明方法能够有效的使用非配对图像输入到神经网络进行训练,以此得到较佳的风格迁移效果。

Description

一种基于多类别风格提取的风格迁移方法
技术领域
本发明涉及风格迁移领域,尤其涉及一种基于多类别风格提取的风格迁移方法。
背景技术
Image-to-Image Translation(I2IT)是一类计算机视觉任务,其旨在将输入图像转换为输出图像。这种任务可以应用于许多实际场景中,例如图像修复、图像翻译、图像风格转换等。在转换过程中,输出结果应保持源域图像的内容特征不变,同时使网络能够学习目标域的风格特征,这是I2IT的重要目标。
Pix2Pix是一个基于生成对抗网络(GAN)的有监督图像到图像转换模型,最早由Isola等人在2017年提出。Pix2Pix的生成器采用了U-Net的结构,判别器采用了PatchGAN的结构,判别器会对图像的每个小区域进行判别,这可以提高判别器的效率和精度。但Pix2Pix模型对于输入图像的质量非常敏感,如果输入图像存在噪声、模糊或失真等问题,模型可能会产生不准确的输出。
MUNIT(Multimodal Unsupervised Image-to-image Translation)是一种用于图像转换的无监督模型,它的主要思想是将图像的内容和风格分开处理,并利用对抗损失(Adversarial Loss)和循环一致性损失函数(Cycle-consistency Loss)来训练模型。MUNIT包括三个主要模块:内容编码器网络、样式编码器和解码器网络。
在训练过程中,MUNIT使用对抗损失来优化编码器和解码器网络,使其能够学习到如何从随机噪声中生成具有特定视觉风格的图像。同时,利用循环一致性损失函数来保留原始图像的内容信息,并在不同的视觉风格之间进行转换。但是该模型在训练时需要消耗较多的计算资源和内存资源。
Taesung Park等人提出CUT模型,与其他无配对图像翻译模型相比,CUT的目标是学习两个图像之间的块映射,而不是像素到像素的映射。其主要由两个主要组件组成:生成器和判别器。生成器从一个域生成图像,并通过对比学习将其转换为另一个域。具体来说,生成器包括两个子网络:编码器和解码器。编码器将输入图像编码为一组特征,解码器将这些特征转换为另一个域中的图像。对比学习使用了一个判别器网络,它被训练来区分翻译图像和真实图像。但这可能会导致生成的图像过于相似,缺乏真实样本中的多样性和复杂性。
综上所述,I2IT任务虽然得到的广泛的研究,但是现存的I2IT方法还存在迁移效果不佳等问题。因此,该发明提出一种基于多类别风格提取方法并将其应用到I2IT任务中,以此提高风格迁移效果。
发明内容
本发明目的是为了解决风格迁移过程中,迁移效果不佳问题。提出了一种基于多类别风格提取的风格迁移方法,以此提升图像的迁移效果。
为实现上述目的,本发明提供如下技术方案:提出了一种基于多类别风格提取的风格迁移方法,包括如下步骤:
1)从风格迁移官方数据集中加载源域A和目标域B的图像数据;
2)将源域图像A输入到内容编码器(Content_Encoder)中,得到关于源域图像A的内容编码Encode_A;
3)将目标域图像B输入到风格编码器(Style_Encoder)中,得到关于目标域图像的风格编码style_B;
4)将源域图像A的内容编码Encode_A和关于目标域的风格编码style_B输入到解码器(Decoder)中,得到生成的目标图像Fake_B;
5)将Fake_B图像从计算图上分离,传入带有Spectral Normalization的判别器中,输出得到判别结果,将结果结合对抗损失得到生成图像的损失值;
6)将目标域的真实图像B传入带有Spectral Normalization的判别器中,输出得到判别结果,将结果结合对抗损失得到真实图像的损失值;
7)将生成图像的损失值和真实图像的损失值相加得到判别器的损失值;
8)反向传播判别器的损失值,更新判别器的梯度;
9)将生成图像Fake_B和真实图像A传入Content_Encoder得到对比损失PatchNCEloss;
10)将判别器的判别结果和PatchNCE loss相加得到生成器损失值;
11)将判别器和生成器的损失值相加得到总损失值;
12)设置判别器的梯度为不可以更新状态,利用生成器的损失值计算得到生成器的梯度更新值,反向传播生成器的梯度,更新生成器;
13)通过不断地迭代训练得到最终的模型。
所述步骤5)中,得到的对抗损失,具体如下:
Figure BDA0004119519640000033
式(1)中G代表域A→B的生成器,D代表判别器;
式(1)中a~A,b~B分别代表A域和B域中的训练样本;
上述步骤10)中,得到的对比损失PatchNCE loss,具体计算如下:
Figure BDA0004119519640000031
式(2)中v代表生成图像Fake_B中将要进行查询的样本点,v+代表源域图像A中对应位置的样本点,v-代表源域图像A中其它位置的样本点,τ是超参数,代表温度系数,将它设置为0.07;
所述步骤11)中,得到的总损失,具体计算如下:
L(G,D,A,B)=LGAN(G,D,A,B)+λ*PatchNCE (3)
式(4)中,λ代表的是编码的相关性;
所述步骤15)中,迭代训练,具体优化目标为:
Figure BDA0004119519640000032
式(4)中,G*和D*为最优情况下生成器和判别器;
本发明的有益效果是:
一、本发明基于风格迁移问题,提出一种基于多类别风格提取的风格迁移方法,该方法能够有效提升风格迁移效果。
二、本发明采用一种注意力机制,对输入图像的重要特征进行增强。
三、本发明采用一种全新的多种风格提取方法,对迁移结果进行提升。
四、本发明采用了谱归一化操作对判别器网络进行约束,从而缓解网络的振荡。
附图说明
图1是本发明网络模型图;
图2是本发明实例多样式提取的生成器网络模型图;
图3是本发明实例注意力机制模型图;
图4是本发明实例判别器网络模型图。
具体实施方式
为了使本发明的目的及技术方案更加明显,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施仅仅用于解释本发明,但不用于限定本发明。此外,下面所描述的本发明各个实施方式中涉及到的技术特征只要彼此之间未构成冲突就可以相互结合。
本发明提供一种基于多类别风格提取的风格迁移方法,网络模型的生成器采用多类别风格提取的Style Encoder和带注意力机制的Content Encoder,本网络模型的判别器使用了谱归一化操作,具体步骤如下:
步骤S1,使用face2anime数据集,其中包括源域图像和目标域图像,并将图像输入网络模型中,网络模型如图1所示。
步骤S2,如图2所示,对Style_Encoder进行改进,提出了一种新的提取多种风格的编码器Mutli-Style Extraction Encoder,目标域图像在经过多次卷积层之后得到多类别的目标域风格编码,源域图像经过内容编码器得到内容编码。
步骤S2中,提取多种风格的编码器的具体步骤为:
S21,给定一个特征图作为输入张量:
Finput=(fijk)H×W×c s.t.i=1,2,,,H;j=1,2,,,W;k=1,2,,,C;
其中Finput是指输入张量,(H,W,C)指输入张量的形状,fijk代表Finput中位置为(i,j,k)的值。
S22,将Finput输入到卷积层(Conv1ijnm)7*7*C*64,得到卷积之后的结果(Conv1_outijk)H*W*64
其中,
Figure BDA0004119519640000041
卷积层Conv1需要先对特征图进行Padding为3的填充操作,卷积核有64个维度为7*7*C的算子构成,/>
Figure BDA00041195196400000515
表示卷积操作。
Conv1_outijk指第k个卷积核卷积fijk之后的结果。
S22,将(Conv1_outijk)H*W*64进行ReLu激活,得到(ReLu1_outijk)H*W*64
其中,ReLU具体操作为
Figure BDA0004119519640000051
S23,将(ReLU1_outijk)H*W*64输入到卷积层(Conv2ijnm)4*4*64*128,得到卷积之后的结果
Figure BDA0004119519640000052
其中,
Figure BDA0004119519640000053
卷积层Conv2需要先对特征图进行Padding为1的填充操作,卷积核有128个维度为4*4*64的算子构成,卷积的Stride为2,/>
Figure BDA0004119519640000054
表示卷积操作。
Conv2_outijk指第k个卷积核卷积ReLU1_outijk之后的结果。
S24,将
Figure BDA0004119519640000055
进行ReLu激活,得到/>
Figure BDA0004119519640000056
具体操作与S22一致;
S25,将
Figure BDA0004119519640000057
输入到卷积层(Conv3ijnm)4*4*128*256,得到卷积之后的结果/>
Figure BDA00041195196400000516
其中,
Figure BDA0004119519640000058
卷积层Conv3需要先对特征图进行Padding为1的填充操作,卷积核有256个维度为4*4*128的算子构成,卷积的Stride为2,/>
Figure BDA0004119519640000059
表示卷积操作。
Conv3_outijk指第k个卷积核卷积ReLU2_outijk之后的结果。
S26,将
Figure BDA00041195196400000510
进行ReLU激活,得到/>
Figure BDA00041195196400000511
具体操作与S22一致;
S27,将
Figure BDA00041195196400000512
输入到卷积层(Conv4ijnm)4*4*256*256,得到卷积之后的结果/>
Figure BDA00041195196400000513
其中,
Figure BDA00041195196400000514
卷积层Conv4需要先对特征图进行Padding为1的填充操作,卷积核有256个维度为4*4*256的算子构成,卷积的Stride为2,/>
Figure BDA0004119519640000061
表示卷积操作。
Conv4_outijk指第k个卷积核卷积ReLU3_outijk之后的结果。
S28,将
Figure BDA0004119519640000062
进行ReLU激活,得到/>
Figure BDA0004119519640000063
具体操作与S22一致;
S29,将
Figure BDA0004119519640000064
输入到卷积层(Conv5ijnm)4*4*256*256,得到卷积之后的结果/>
Figure BDA0004119519640000065
其中,
Figure BDA0004119519640000066
卷积层Conv5需要先对特征图进行Padding为1的填充操作,卷积核有256个维度为4*4*256的算子构成,卷积的Stride为2,/>
Figure BDA0004119519640000067
表示卷积操作。
Conv5_outijk指第k个卷积核卷积ReLU4_outijk之后的结果。
S210,将
Figure BDA0004119519640000068
进行ReLU激活,得到/>
Figure BDA0004119519640000069
具体操作与S22一致;
S211,将
Figure BDA00041195196400000610
进行通道维度上的自适应平均池化AdaptiveAvgPooling,得到自适应平均池化矩阵(AdaAvgPooling_outijk)1*1*256
其中,
Figure BDA00041195196400000611
Figure BDA00041195196400000612
S212,将
Figure BDA00041195196400000613
进行通道维度上的自适应最大池化AdaptiveMaxPooling,得到自适应最大池化矩阵(AdaMaxPooling_outijk)1*1*256
其中,
Figure BDA00041195196400000614
Figure BDA00041195196400000615
S213,将
Figure BDA00041195196400000616
进行通道维度上的自适应方差池化AdaptiveStdPooling,得到自适应方差池化矩阵(AdaStdPooling_outijk)1*1*256
其中,
Figure BDA00041195196400000617
Figure BDA00041195196400000618
S214,将AdaAvgPooling_out1*1*256、AdaMaxPooling_out1*1*256和AdaStdPooling_out1*1*256在通道维度上进行连接得到多种类风格矩阵(Cat_outijk)1*1*768
S215,将(Cat_outijk)1*1*768输入到卷积层(Conv6ijnm)1*1*768*8,得到卷积之后的结果(MultiStyle_featureijk)1*1*8
其中,
Figure BDA0004119519640000071
,卷积层Conv6有8个维度为1*1*768的算子构成,/>
Figure BDA0004119519640000072
表示卷积操作;
MultiStyle_featureijk指第k个卷积核卷积Cat_outijk之后的结果。
S216,将提取到的多种类风格(MultiStyle_featureijk)1*1*8注入到编码器中。
步骤S3,将风格编码注入到注意力机制的解码器中,以此实现目标域风格特征和源域内容特征的结合,注意力机制结构如图3所示。
步骤S4,如图4所示,对判别器网络进行改进,通过增加Spectral Normalization,以此约束原始判别器的卷积操作,从而稳定网络的训练过程。
步骤S4中,Spectral Normalization的具体操作为:
S41,给定卷积操作的权重矩阵为Wn*m
S42,初始化一个随机的向量v1*m
S43,计算向量
Figure BDA0004119519640000073
S44,计算向量
Figure BDA0004119519640000074
S45,计算谱范数
Figure BDA0004119519640000076
S46,用谱范数σ(W)进行归一化操作,
Figure BDA0004119519640000075
步骤S5,对S2和S3中改进的模型进行训练,得到一个全新的风格迁移模型MSE-GAN。
本领域的技术人员容易理解,以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内做任何修改,等同替换和改进等,均应包含在本发明的保护范围之内。

Claims (4)

1.一种基于多类别风格提取的图像风格迁移方法,其特征在于,包括如下步骤:
采用公开的风格迁移数据集训练MSE-GAN,训练集中样本为待风格迁移的源域图像和有目标风格的目的域图像,将训练样本输入到神经网络中,反复迭代,以此得到训练好的风格迁移网络;将测试集的源域图像输入到训练好的风格迁移网络中,得到与目的域相似的全新图像,在转换过程中,源域图像的几何特征并不会发生改变;
所述MSE-GAN包含Multi-Style Extraction Encoder、Content Encoder、Decoder和带有Spectral Normalization的Discriminator;Multi-Style Extraction Encoder包括ConvBlock、自适应平均池化AdaptiveAvgPooling、自适应最大池化AdaptiveMaxPooling和自适应方差池化AdaptiveStdPooling,其中ConvBlock表示卷积和实例归一化操作,激活函数是Relu;Content Encoder包括Downsample、Statistical Attention和Residual Block。Decoder包括AdaIN Residual Block、Statistical Attention和Upsample;SN-Discriminator包含3个带有Spectral Normalization的ConvBlock。
2.如权利要求1所述方法,其特征在于,多类别风格提取编码器Multi-StyleExtraction Encoder的结构包括:
大小为H*W*C的输入特征图在通道维度上分别进行自适应平均池化、自适应最大池化和自适应方差池化,得到三个大小为1*1*C的特征图矩阵,拼接三个特征图矩阵输入到卷积层,卷积层中卷积核的个数为C,卷积层最终输出得到多类别的风格特征矩阵。
3.如权利要求1所述,其特征在于,新的风格迁移模型MSE-GAN的改进包括:
将Multi-Style Extraction Encoder与Content Encoder、Decoder和带有SpectralNormalization的Discriminator进行结合,得到新的风格迁移模型MSE-GAN。
4.如权利要求1所述,其特征在于,训练MSE-GAN模型:
根据MSE-GAN的结构,结合对比学习方法,以此保证图像几何特征的一致性,得到新的风格迁移模型MSE-GAN。
CN202310229019.2A 2023-03-10 2023-03-10 一种基于多类别风格提取的风格迁移方法 Pending CN116403026A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310229019.2A CN116403026A (zh) 2023-03-10 2023-03-10 一种基于多类别风格提取的风格迁移方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310229019.2A CN116403026A (zh) 2023-03-10 2023-03-10 一种基于多类别风格提取的风格迁移方法

Publications (1)

Publication Number Publication Date
CN116403026A true CN116403026A (zh) 2023-07-07

Family

ID=87009388

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310229019.2A Pending CN116403026A (zh) 2023-03-10 2023-03-10 一种基于多类别风格提取的风格迁移方法

Country Status (1)

Country Link
CN (1) CN116403026A (zh)

Similar Documents

Publication Publication Date Title
KR101880907B1 (ko) 비정상 세션 감지 방법
CN107798381B (zh) 一种基于卷积神经网络的图像识别方法
CN111292330A (zh) 基于编解码器的图像语义分割方法及装置
CN111079532A (zh) 一种基于文本自编码器的视频内容描述方法
CN110570433B (zh) 基于生成对抗网络的图像语义分割模型构建方法和装置
CN111582483A (zh) 基于空间和通道联合注意力机制的无监督学习光流估计方法
CN112418292B (zh) 一种图像质量评价的方法、装置、计算机设备及存储介质
CN112489164B (zh) 基于改进深度可分离卷积神经网络的图像着色方法
CN114511576B (zh) 尺度自适应特征增强深度神经网络的图像分割方法与系统
CN114418030A (zh) 图像分类方法、图像分类模型的训练方法及装置
CN111931779A (zh) 一种基于条件可预测参数的图像信息提取与生成方法
CN116740422A (zh) 基于多模态注意力融合技术的遥感图像分类方法及装置
CN112149526A (zh) 一种基于长距离信息融合的车道线检测方法及系统
Zhou et al. MSAR‐DefogNet: Lightweight cloud removal network for high resolution remote sensing images based on multi scale convolution
CN114494387A (zh) 一种生成数据集网络模型及雾图生成方法
Li et al. Underwater Imaging Formation Model‐Embedded Multiscale Deep Neural Network for Underwater Image Enhancement
CN107729885B (zh) 一种基于多重残差学习的人脸增强方法
CN112686830B (zh) 基于图像分解的单一深度图的超分辨率方法
CN112766099B (zh) 一种从局部到全局上下文信息提取的高光谱影像分类方法
CN113763268A (zh) 人脸图像盲修复方法及系统
CN112560719A (zh) 基于多尺度卷积-多核池化的高分辨率影像水体提取方法
CN116977694A (zh) 一种基于不变特征提取的高光谱对抗样本防御方法
CN111179171A (zh) 基于残差模块和注意力机制的图像超分辨率重建方法
CN113962332B (zh) 基于自优化融合反馈的显著目标识别方法
CN116091893A (zh) 一种基于U-net网络的地震图像反褶积方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination