CN115860113A - 一种自对抗神经网络模型的训练方法及相关装置 - Google Patents

一种自对抗神经网络模型的训练方法及相关装置 Download PDF

Info

Publication number
CN115860113A
CN115860113A CN202310196878.6A CN202310196878A CN115860113A CN 115860113 A CN115860113 A CN 115860113A CN 202310196878 A CN202310196878 A CN 202310196878A CN 115860113 A CN115860113 A CN 115860113A
Authority
CN
China
Prior art keywords
attention
generator
loss value
image
loss
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202310196878.6A
Other languages
English (en)
Other versions
CN115860113B (zh
Inventor
乐康
张耀
曹保桂
张滨
徐大鹏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Seichitech Technology Co ltd
Original Assignee
Shenzhen Seichitech Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Seichitech Technology Co ltd filed Critical Shenzhen Seichitech Technology Co ltd
Priority to CN202310196878.6A priority Critical patent/CN115860113B/zh
Publication of CN115860113A publication Critical patent/CN115860113A/zh
Application granted granted Critical
Publication of CN115860113B publication Critical patent/CN115860113B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Image Analysis (AREA)

Abstract

本申请公开了一种自对抗神经网络模型的训练方法及相关装置,用于提高卷积神经网络模型的训练速度和准确性。本申请包括:获取卷积神经网络模型;将正态分布采样数据输入生成器,生成模拟图像;将真实图像和模拟图像输入反向器中,生成第一和第二特征空间数据;根据真实图像、模拟图像、第一、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;根据数据分布损失和图像像素损失计算生成器损失值和反向器损失值;判断生成器损失值和反向器损失值是否满足预设条件;若是,则确定训练完成;若否,根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。

Description

一种自对抗神经网络模型的训练方法及相关装置
技术领域
本申请实施例涉及卷积神经网络模型领域,尤其涉及一种自对抗神经网络模型的训练方法及相关装置。
背景技术
近年来,随着计算机的不断发展,使得卷积神经网络模型的应用范围快速扩大,涉及制造业、日常生活等。分析图像的类型是卷积神经网络模型的主要功能之一,可以应用在识别物品的缺陷,例如:在制造PCB板过程中识别PCB板上存在的缺陷。利用了卷积神经网络模型可以对某一图像进行学习训练的能力,提高卷积神经网络模型对该图像存在的特征的识别能力。
目前,生成模式作为深度学习技术的一大分支在图像领域有了较大发展,生成模式通过学习图像的数据统计分布特征,使用卷积神经网络训练拟合这种统计分布特征,然后通过在得到的分布特征中随机采样,重构生成与原图同分布不同采样的新数据,即生成和原图“同类但不一样”的新图像。可以大量生成各种不同的新图像,以此扩充数据集。而基于生成对抗(GAN)的深度学习方法在生成模式中独树一帜,其独特的生成器和反向器对抗方式,极大提升了生成图像的质量。
但是,上述的生成对抗网络自身存在不可忽视的缺陷。生成器和反向器为两个不同的神经网络,两者通过损失函数加以联系,但训练中这两个模块很难协同配合,会出现生成器或反向器其一训练较好,另一个训练较差。使得生成器和反向器无法协同训练,各自提升,无法同时进步,即当前应用于生成图像领域的卷积神经网络模型中生成器和反向器难以配合,使得卷积神经网络模型训练速度和准确性下降。
发明内容
本申请公开了一种自对抗神经网络模型的训练方法及相关装置,用于提高卷积神经网络模型的训练速度和准确性。
本申请第一方面提供了一种自对抗神经网络模型的训练方法,包括:
获取卷积神经网络模型,所述卷积神经网络模型中包括生成器和反向器;
将一组正态分布采样数据输入所述生成器,生成模拟图像;
将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据;
根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值;
判断所述生成器损失值和所述反向器损失值是否满足预设条件;
若是,则确定所述卷积神经网络模型训练完成;
若否,根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
可选的,所述生成器包括至少一个Generator单元,所述Generator单元包括注意力Dropout模块ADO、至少一个Attention/Conv_t块和区域像素注意力模块RPA,所述Attention/Conv_t块包括通道注意力模块Attention和反卷积模块;
所述将一组正态分布采样数据输入所述生成器,生成模拟图像,包括:
通过第一注意力Dropout模块ADO给所述正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
通过第一通道注意力模块Attention为所述第一中间特征生成通道向量;
通过第一通道注意力模块Attention结合所述通道向量输出一个维度与所述第一中间特征通道数相同的归一化一维向量;
通过第一通道注意力模块Attention并根据所述归一化一维向量将所述第一中间特征按通道对应相乘,生成第二中间特征;
通过反卷积模块将所述第二中间特征进行卷积处理和通道叠加处理;
通过第一区域像素注意力模块RPA对所述第二中间特征进行区域像素值权重生成处理,生成模拟图像。
可选的,所述反向器包括至少一个 Reverse单元,所述 Reverse单元包括区域像素注意力模块RPA、至少一个Attention/Conv块和、注意力Dropout模块ADO,所述Attention/Conv块包括通道注意力模块Attention和卷积模块
所述将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据,包括:
通过第二区域像素注意力模块RPA对所述模拟图像进行区域像素值权重生成处理,生成第三中间特征;
通过卷积模块将所述第三中间特征进行卷积处理和通道叠加处理;
通过第二通道注意力模块Attention为所述第三中间特征生成通道向量;
通过第二通道注意力模块Attention结合所述通道向量输出一个维度与所述第三中间特征通道数相同的归一化一维向量;
通过第二通道注意力模块Attention并根据所述归一化一维向量将所述第三中间特征按通道对应相乘,生成第四中间特征;
通过第二注意力Dropout模块ADO给所述第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据;
根据上述方法为所述真实图像生成第二特征空间数据。
可选的,根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失,包括:
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第一数据分布损失值;
使用均方误差法并根据所述模拟图像和所述真实图像计算图像像素损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第二数据分布损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第二特征空间数据和所述正态分布采样数据计算第三数据分布损失值。
可选的,所述根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值,包括:
根据所述第一数据分布损失值和图像像素损失值计算所述生成器的生成器损失值;
根据第二数据分布损失值和第三数据分布损失值计算所述反向器的反向器损失值。
可选的,所述根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练,包括:
通过小批量梯度下降法并根据所述生成器损失值和所述生成器的权重值生成新生成器权重值;
通过小批量梯度下降法并根据所述反向器损失值和所述反向器的权重值生成新反向器权重值;
将所述新生成器权重值和所述新反向器权重值进行加权平均处理,生成共享更新权重值;
将所述共享更新权重值作为所述生成器和所述反向器更新后的权重值;
对所述更新后的卷积神经网络模型重复进行迭代训练。
本申请第二方面提供了一种自对抗神经网络模型的训练装置,包括:
第一获取单元,用于获取卷积神经网络模型,所述卷积神经网络模型中包括生成器和反向器;
第一生成单元,用于将一组正态分布采样数据输入所述生成器,生成模拟图像;
第二生成单元,用于将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据;
第一计算单元,用于根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
第二计算单元,用于根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值;
判断单元,用于判断所述生成器损失值和所述反向器损失值是否满足预设条件;
确定单元,用于当所述判断单元确定所述生成器损失值和所述反向器损失值满足预设条件时,则确定所述卷积神经网络模型训练完成;
迭代单元,用于当所述判断单元确定所述生成器损失值和所述反向器损失值不满足预设条件时,根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
可选的,所述生成器包括至少一个Generator单元,所述Generator单元包括注意力Dropout模块ADO、至少一个Attention/Conv_t块和区域像素注意力模块RPA,所述Attention/Conv_t块包括通道注意力模块Attention和反卷积模块;
所述第一生成单元,包括:
通过第一注意力Dropout模块ADO给所述正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
通过第一通道注意力模块Attention为所述第一中间特征生成通道向量;
通过第一通道注意力模块Attention结合所述通道向量输出一个维度与所述第一中间特征通道数相同的归一化一维向量;
通过第一通道注意力模块Attention并根据所述归一化一维向量将所述第一中间特征按通道对应相乘,生成第二中间特征;
通过反卷积模块将所述第二中间特征进行卷积处理和通道叠加处理;
通过第一区域像素注意力模块RPA对所述第二中间特征进行区域像素值权重生成处理,生成模拟图像。
可选的,所述反向器包括至少一个 Reverse单元,所述 Reverse单元包括区域像素注意力模块RPA、至少一个Attention/Conv块和、注意力Dropout模块ADO,所述Attention/Conv块包括通道注意力模块Attention和卷积模块;
所述第二生成单元,包括:
通过第二区域像素注意力模块RPA对所述模拟图像进行区域像素值权重生成处理,生成第三中间特征;
通过卷积模块将所述第三中间特征进行卷积处理和通道叠加处理;
通过第二通道注意力模块Attention为所述第三中间特征生成通道向量;
通过第二通道注意力模块Attention结合所述通道向量输出一个维度与所述第三中间特征通道数相同的归一化一维向量;
通过第二通道注意力模块Attention并根据所述归一化一维向量将所述第三中间特征按通道对应相乘,生成第四中间特征;
通过第二注意力Dropout模块ADO给所述第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据;
根据上述方法为所述真实图像生成第二特征空间数据。
可选的,第一计算单元,包括:
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第一数据分布损失值;
使用均方误差法并根据所述模拟图像和所述真实图像计算图像像素损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第二数据分布损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第二特征空间数据和所述正态分布采样数据计算第三数据分布损失值。
可选的,所述第二计算单元,包括:
根据所述第一数据分布损失值和图像像素损失值计算所述生成器的生成器损失值;
根据第二数据分布损失值和第三数据分布损失值计算所述反向器的反向器损失值。
可选的,所述迭代单元,包括:
通过小批量梯度下降法并根据所述生成器损失值和所述生成器的权重值生成新生成器权重值;
通过小批量梯度下降法并根据所述反向器损失值和所述反向器的权重值生成新反向器权重值;
将所述新生成器权重值和所述新反向器权重值进行加权平均处理,生成共享更新权重值;
将所述共享更新权重值作为所述生成器和所述反向器更新后的权重值;
对所述更新后的卷积神经网络模型重复进行迭代训练。
从以上技术方案可以看出,本申请实施例具有以下优点:
本申请中,首先获取卷积神经网络模型,其中,卷积神经网络模型中包括生成器和反向器。通过将一组正态分布采样数据输入生成器,生成模拟图像。再将真实图像和模拟图像输入反向器中,生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据,第一特征空间数据和第二特征空间数据分别表征了模拟图像和真实图像的真假判别程度。接下来,根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失。根据数据分布损失和图像像素损失计算生成器的生成器损失值和反向器的反向器损失值。判断生成器损失值和反向器损失值是否满足预设条件。若是,则确定卷积神经网络模型训练完成。若否,根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。可以看出,在迭代训练过程中,通过将数据分布损失和图像像素损失进行拟合计算,生成对应的损失值,并且根据对应的损失值进行权重值的计算,最后还通过各自的权重值进行拟合计算,生成共享更新权重值,并且通过共享更新权重值进行后续的更新迭代,使得生成器和反向器两个模块进行协同配合,同时进步,使得卷积神经网络模型训练速度和准确性提高。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本申请自对抗神经网络模型的训练方法的一个实施例示意图;
图2-a为本申请自对抗神经网络模型的训练方法的第一阶段的一个实施例示意图;
图2-b为本申请自对抗神经网络模型的训练方法的第二阶段的一个实施例示意图;
图2-c为本申请自对抗神经网络模型的训练方法的第三阶段的一个实施例示意图;
图2-d为本申请自对抗神经网络模型的训练方法的第四阶段的一个实施例示意图;
图3为本申请实施例中卷积神经网络模型网络层的一个实施例流程示意图;
图4为本申请实施例中卷积神经网络模型网络层的另一个实施例结构示意图;
图5为本申请实施例中卷积神经网络模型网络层的另一个实施例结构示意图;
图6为本申请自对抗神经网络模型的训练装置的一个实施例示意图;
图7为本申请电子设备的一个实施例示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
在现有技术中,生成模式作为深度学习技术的一大分支在图像领域有了较大发展,生成模式通过学习图像的数据统计分布特征,使用卷积神经网络训练拟合这种统计分布特征,然后通过在得到的分布特征中随机采样,重构生成与原图同分布不同采样的新数据,即生成和原图“同类但不一样”的新图像。可以大量生成各种不同的新图像,以此扩充数据集。而基于生成对抗(GAN)的深度学习方法在生成模式中独树一帜,其独特的生成器和反向器对抗方式,极大提升了生成图像的质量。
但是,上述的生成对抗网络自身存在不可忽视的缺陷。生成器和反向器为两个不同的神经网络,两者通过损失函数加以联系,但训练中这两个模块很难协同配合,会出现生成器或反向器其一训练较好,另一个训练较差。使得生成器和反向器无法协同训练,各自提升,无法同时进步,即当前应用于生成图像领域的卷积神经网络模型中生成器和反向器难以配合,使得卷积神经网络模型训练速度和准确性下降。
基于此,本申请公开了一种自对抗神经网络模型的训练方法及相关装置,用于提高卷积神经网络模型的训练速度和准确性。
下面将结合本申请实施例中的附图,对本申请中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请的方法可以应用于服务器、设备、终端或者其它具备逻辑处理能力的设备,对此,本申请不作限定。为方便描述,下面以执行主体为终端为例进行描述。
请参阅图1,本申请提供了一种自对抗神经网络模型的训练方法的一个实施例,包括:
101、获取卷积神经网络模型,卷积神经网络模型中包括生成器和反向器;
本实施例使用的基础卷积神经网络模型是一种对抗生成网络(GAN),具体的,本实施例自主设计一套AI神经网络体系结构名为SAN,包括生成器和反向器,生成器用于生成多张和真实图像分布相同的模拟图像。必要时,生成器需要利用到真实图像的部分特征进行生成。
真实图像是指某一类真实存在的物品通过摄像装置进行拍摄得到的具有物品真实特征的图像,模拟图像则是按照预设的生成规则,生成的具有物品特征的图像。
102、将一组正态分布采样数据输入生成器,生成模拟图像;
在传统的生成器中,通过一组采样数据,即可按照预设的生成规则生成模拟图像。
103、将真实图像和模拟图像输入反向器中,生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据;
本实施例中,终端生成模拟图像后,会将真实图像和模拟图像输入反向器中,通过反向器生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据,该步骤中反向器的作用是将图片还原成数据分布。
104、根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
当终端将真实图像和模拟图像输入反向器中,生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据之后,终端根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失。
105、根据数据分布损失和图像像素损失计算生成器的生成器损失值和反向器的反向器损失值;
当终端根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失之后,终端根据数据分布损失和图像像素损失计算生成器的生成器损失值和反向器的反向器损失值。
106、判断生成器损失值和反向器损失值是否满足预设条件;
终端判断生成器损失值和反向器损失值是否满足预设条件,可以是先统计生成器损失值,得到生成器损失值变化集合,判断最近10000次训练生成的生成器损失值是否均满足收敛,如果满足,则生成器训练完成。再或者是生成器在满足收敛的情况下,损失值还要均小于预设值,整体训练次数也达到了100万次,如果满足,则确定生成器训练完成。
反向器也可以根据上述方式为依据确定各自是否完成训练。
但是会出现二者不同时训练完成的情况,这时可以根据需要进行下一步。可以是生成器训练完成为满足预设条件,也可以是生成器和反向器均完成训练为满足预设条件,此处不作限定。
107、若是,则确定卷积神经网络模型训练完成;
当满足上述106的条件时,可以确定卷积神经网络模型训练完成,可以将生成器、反向器分别取出进行对应领域的运用,生成器用于生成模拟图像,反向器用于识别模拟图像的真伪。
108、若否,根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
当不满足训练条件时,则需要根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练,通过共享更新权重值来对生成器和反向器进行协同配合,提高二者的训练效率。
本实施例中,终端首先获取卷积神经网络模型,其中,卷积神经网络模型中包括生成器和反向器。终端通过将一组正态分布采样数据输入生成器,生成模拟图像。再将真实图像和模拟图像输入反向器中,生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据,第一特征空间数据和第二特征空间数据分别表征了模拟图像和真实图像的真假判别程度。接下来,终端根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失。终端根据数据分布损失和图像像素损失计算生成器的生成器损失值和反向器的反向器损失值。终端判断生成器损失值和反向器损失值是否满足预设条件。若是,则终端确定卷积神经网络模型训练完成。若否,终端根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。可以看出,在迭代训练过程中,通过将数据分布损失和图像像素损失进行拟合计算,生成对应的损失值,并且根据对应的损失值进行权重值的计算,最后还通过各自的权重值进行拟合计算,生成共享更新权重值,并且通过共享更新权重值进行后续的更新迭代,使得生成器和反向器两个模块进行协同配合,同时进步,使得卷积神经网络模型训练速度和准确性提高。
请参阅图2-a、图2-b、图2-c和图2-d,本申请提供了一种自对抗神经网络模型的训练方法的一个实施例,包括:
201、获取卷积神经网络模型,卷积神经网络模型中包括生成器和反向器;
本实施例中的步骤201与前述实施例中步骤101类似,此处不再赘述。
202、通过第一注意力Dropout模块ADO给正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
本实施例中,注意力Dropout模块ADO有多种组合方式,其中包括一个BatchNorm-2*2Conv-LeakyReLU加一个BatchNorm-2*2Conv-SigMiod。也可以是BatchNorm-2*2Conv-LeakyReLU加BatchNorm-2*2Conv-LeakySigMiod。
BatchNorm-2*2Conv-PixelNormReLU+BatchNorm-2*2Conv-PixelNormSigMiod也是一个选择。
基于注意力的Dropout方法,不同于一般Dropout使用的随机方式,本发明中利用注意力保留更重要的特征信息,使得卷积神经网络模型的性能和泛化性更好。
本实施例中,先将输入的正态分布采样数据Z输入到一个BatchNorm-2*2Conv-LeakyReLU中进行处理,然后再将其输出输入到一个BatchNorm-2*2Conv-SigMiod中,生成和正态分布采样数据Z形同尺寸的注意力矩阵,根据注意力矩阵的值,将注意力小于第一预设阈值的原特征矩阵对应位置神经元置零,输出第一中间特征。
203、通过第一通道注意力模块Attention为第一中间特征生成通道向量;
204、通过第一通道注意力模块Attention结合通道向量输出一个维度与第一中间特征通道数相同的归一化一维向量;
205、通过第一通道注意力模块Attention并根据归一化一维向量将第一中间特征按通道对应相乘,生成第二中间特征;
通道注意力模块Attention具体包括一个全局平均池化层、一个1*1Conv-LeakyReLU和一个1*1Conv-Sigmoid,下面详细描述通道注意力模块Attention的运行原理。
具体的,第一中间特征先经过第一通道注意力模块Attention的全局平均池化层(Global Pooling)生成通道向量,再经过1×1卷积核、LeakyReLU激活函数进行通道压缩,再经过1×1卷积核以及Sigmoid激活函数,输出一个维度等于输入特征通道数的归一化一维向量,这就是各个特征通道的注意力权重,将其输入特征各个通道相乘,生成第二中间特征。
深度学习注意力(attention)机制是对人类视觉注意力机制的仿生,本质上是一种资源分配机制。生理原理就是人类视觉注意力能够以高分辨率接收于图片上的某个区域,并且以低分辨率感知其周边区域,并且视点能够随着时间而改变。换而言之,就是人眼通过快速扫描全局图像,找到需要关注的目标区域,然后对这个区域分配更多注意,目的在于获取更多细节信息和抑制其他无用信息。提高representation的高效性。
在神经网络中,attention机制可以它认为是一种资源分配的机制,可以理解为对于原本平均分配的资源根据attention对象的重要程度重新分配资源,重要的单位就多分一点,不重要或者不好的单位就少分一点,在深度神经网络的结构设计中,attention所要分配的资源基本上就是权重了。
206、通过反卷积模块将第二中间特征进行卷积处理和通道叠加处理;
具体的首先将第二中间特征进行卷积处理,生成的数据和第二卷积特征再进行通道叠加处理。目的是使用反卷积模块将输入的第二中间特征重构,增加特征长宽。
207、通过第一区域像素注意力模块RPA对第二中间特征进行区域像素值权重生成处理,生成模拟图像;
本步骤的区域像素注意力模块RPA,包含一个Conv-PixelNorm-LReLU、一个Conv-PixelNorm、一个SigMoid函数模块和一个双线性插值模块(或者是上采样模块)。Conv-PixelNorm-LReLU、Conv-PixelNorm、SigMoid函数模块和双线性插值模块依次串联。这里的Conv-PixelNorm-LReLU层和Conv-PixelNorm层都属于卷积神经网络中常用的特征处理层,SigMoid函数为已知函数,双线性插值运算方法也是已知算法。
区域像素注意力模块RPA作为第一重注意力机制,由于给第一采样特征的每块区域像素分配一个权重,使得神经网络对于第一采样特征明显的区域更加关注。
具体的,假设输入的原始图像的张数为B,通道数量为C,分辨率为W*H,则第一采样特征记为(B,C,H,W),(B,C,H,W)需要先经过区域像素注意力模块RPA的Conv-PixelNorm-LReLU层进行通道压缩为(B,C*r,H/2,W/2),其中r<1。再经过一个Conv-PixelNorm层还原成(B,C,H/4,W/4),再通过SigMoid函数模块生成每个像素值的权重,最后使用双线性插值还原成新的(B,C,H,W),和原始图像的(B,C,H,W)一对一相乘。
208、通过第二区域像素注意力模块RPA对模拟图像进行区域像素值权重生成处理,生成第三中间特征;
本实施例中的步骤208中第二区域像素注意力模块RPA的工作原理与前述实施例中步骤2207的第一区域像素注意力模块RPA类似,此处不再赘述。
209、通过卷积模块将第三中间特征进行卷积处理和通道叠加处理;
本实施例中的步骤209中卷积模块的工作原理与前述实施例中步骤206的反卷积模块类似,此处不再赘述。
210、通过第二通道注意力模块Attention为第三中间特征生成通道向量;
211、通过第二通道注意力模块Attention结合通道向量输出一个维度与第三中间特征通道数相同的归一化一维向量;
212、通过第二通道注意力模块Attention并根据归一化一维向量将第三中间特征按通道对应相乘,生成第四中间特征;
本实施例中的步骤210至212中第二通道注意力模块Attention的工作原理与前述实施例中步骤203至205的第一通道注意力模块Attention类似,此处不再赘述。
213、通过第二注意力Dropout模块ADO给第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据;
本实施例中的步骤213中第二注意力Dropout模块ADO的工作原理与前述实施例中步骤202的第一注意力Dropout模块ADO类似,此处不再赘述。
214、根据上述方法为真实图像生成第二特征空间数据;
终端根据上述方法为真实图像生成第二特征空间数据。
215、使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第一数据分布损失值;
216、使用均方误差法并根据模拟图像和真实图像计算图像像素损失值;
217、使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第二数据分布损失值;
218、使用Wasserstein距离+梯度惩罚法并根据第二特征空间数据和正态分布采样数据计算第三数据分布损失值;
本实施例生成器输出模拟图像后,反向器输出模拟数据分布,因此使用Wasserstein distance+梯度惩罚计算数据分布损失,使用均方误差MSE计算图像像素损失。
Wasserstein距离度量两个概率分布之间的距离,定义如下:
Figure SMS_1
Figure SMS_3
是P1和P2的分布组合起来的所有可能的联合分布的集合。对于每一个可能的联合分布/>
Figure SMS_5
,可以从中采样/>
Figure SMS_7
得到一个样本x和y,并计算出这对样本的距离||x−y||,所以可以计算该联合分布/>
Figure SMS_4
下,样本对距离的期望值/>
Figure SMS_6
。在所有可能的联合分布中能够对这个期望值取到的下界/>
Figure SMS_8
就是Wasserstein距离。直观上可以把/>
Figure SMS_9
理解为在/>
Figure SMS_2
这个路径规划下把土堆P1挪到土堆P2所需要的消耗。而Wasserstein距离就是在最优路径规划下的最小消耗。所以Wesserstein距离又叫Earth-Mover距离。Wessertein距离相比KL散度和JS散度的优势在于:即使两个分布的支撑集没有重叠或者重叠非常少,仍然能反映两个分布的远近。而JS散度在此情况下是常量,KL散度可能无意义。
均方误差(MSE)是最常用的回归损失函数。MSE是目标变量与预测值之间距离平方之和,本专利使用经过反向器恢复后的特征张量和采样得到的Z对应元素进行MSE,得到误差值,再进行反向梯度计算,更新神经网络的权重值。
本实施例中的SAN卷积神经网络模型包括一个生成器和一个反向器,分别对应4个损失:
1、生成器的输出的模拟图像G(Z)在反向器上的第一特征空间数据R(G(Z))和正态分布采样数据Z,通过Wasserstein距离+梯度惩罚法计算得到的第一数据分布损失值LGD=WD(R(G(Z)),Z)。
2、生成器的输出的模拟图像G(Z)和真实图像X使用均方误差法计算得到的图像像素损失值LGI=MSE(G(Z),X)。
3、生成器的输出的模拟图像G(Z)在反向器上的第一特征空间数据R(G(Z))和正态分布采样数据Z通过Wasserstein距离+梯度惩罚法计算得到的第二数据分布损失值LRZ=WD(R(G(Z)),Z)。
4、真实图像X在反向器上的输出特征R(X)和正态分布采样数据Z通过Wasserstein距离+梯度惩罚法计算得到的第三数据分布损失值LRX=WD(R(X),Z)。
219、根据第一数据分布损失值和图像像素损失值计算生成器的生成器损失值;
220、根据第二数据分布损失值和第三数据分布损失值计算反向器的反向器损失值;
本实施例中,生成器损失值为LGD+LGI,反向器损失值为LRZ+LRX。
221、判断生成器损失值和反向器损失值是否满足预设条件;
222、若是,则确定卷积神经网络模型训练完成;
本实施例中的步骤224至225与前述实施例中步骤106和107类似,此处不再赘述。
223、若否,通过小批量梯度下降法并根据生成器损失值和生成器的权重值生成新生成器权重值;
224、通过小批量梯度下降法并根据反向器损失值和反向器的权重值生成新反向器权重值;
当终端确定生成器损失值和反向器损失值未满足预设条件时,通过小批量梯度下降法并根据生成器损失值和生成器的权重值生成新生成器权重值,通过小批量梯度下降法并根据反向器损失值和反向器的权重值生成新反向器权重值。
对卷积神经网络模型的权重更新可以是多种方式,本实施例中,以小批量随机梯度下降法更新卷积神经网络模型为例,其中批训练的梯度更新方式的公式为:
Figure SMS_10
n是批量大小(batchsize),
Figure SMS_11
是学习率(learning rate),/>
Figure SMS_12
是当前权值,/>
Figure SMS_13
为更新权值,/>
Figure SMS_14
为权值更新子函数,x为预设值。
使用反向梯度求导,请参考图3,图3为一个卷积神经网络模型网络层示意图。
左侧为第一层,也是输入层,输入层包含两个神经元a和b。中间为第二层,也是隐含层,隐含层包含两个神经元c和d。右侧为第三层,也是输出层,输出层包含e和f,每条线上标的
Figure SMS_15
是层与层之间连接的权重。
Figure SMS_16
代表第l层第j个神经元,与上一层(l-1)第k个神经元输出相对应的权重。
Figure SMS_17
代表第l层第j个神经元输出。
Figure SMS_18
代表第l层第j个神经元输入。
Figure SMS_19
代表第l层第j个神经元偏置。
W代表权重矩阵,Z代表输入矩阵,A代表输出矩阵,Y代表标准答案。
L代表卷积神经网络模型的层数。
Figure SMS_20
向前传播的方法,即将输入层的信号传输至隐藏层,以隐藏层节点c为例,站在节点c上往后看(输入层的方向),可以看到有两个箭头指向节点c,因此a,b节点的信息将传递给c,同时每个箭头有一定的权重,因此对于c节点来说,输入信号为:
Figure SMS_21
同理,节点d的输入信号为:
Figure SMS_22
由于终端善于做带有循环的任务,因此可以用矩阵相乘来表示:
Figure SMS_23
所以,隐藏层节点经过非线性变换后的输出表示如下:
Figure SMS_24
同理,输出层的输入信号表示为权重矩阵乘以上一层的输出:
Figure SMS_25
同样,输出层节点经过非线性映射后的最终输出表示为:
Figure SMS_26
/>
输入信号在权重矩阵们的帮助下,得到每一层的输出,最终到达输出层。可见,权重矩阵在前向传播信号的过程中扮演着运输兵的作用,起到承上启下的功能。
请参考图4,图4为一个卷积神经网络模型网络层示意图。向后传播的方法,既然梯度下降需要每一层都有明确的误差才能更新参数,所以接下来的重点是如何将输出层的误差反向传播给隐藏层。
其中输出层、隐藏层节点的误差如图所示,输出层误差已知,接下来对隐藏层第一个节点c作误差分析。还是站在节点c上,不同的是这次是往前看(输出层的方向),可以看到指向c节点的两个蓝色粗箭头是从节点e和节点f开始的,因此对于节点c的误差肯定是和输出层的节点e和f有关。输出层的节点e有箭头分别指向了隐藏层的节点c和d,因此对于隐藏节点e的误差不能被隐藏节点c霸为己有,而是要服从按劳分配的原则(按权重分配),同理节点f的误差也需服从这样的原则,因此对于隐藏层节点c的误差为:
Figure SMS_27
其中,
Figure SMS_28
和/>
Figure SMS_29
为输出层反向传播系数,同理,对于隐藏层节点d的误差为:
Figure SMS_30
其中,
Figure SMS_31
和/>
Figure SMS_32
为隐藏层反向传播系数,为了减少工作量,可写成矩阵相乘的形式:
Figure SMS_33
该矩阵比较繁琐,可简化到前向传播的形式,不破坏它们的比例,因此我们可以忽略掉分母部分,所以重新成矩阵形式为:
Figure SMS_34
该权重矩阵,其实是前向传播时权重矩阵w的转置,因此简写形式如下:
Figure SMS_35
输出层误差在转置权重矩阵的帮助下,传递到了隐藏层,这样我们就可以利用间接误差来更新与隐藏层相连的权重矩阵。可见,权重矩阵在反向传播的过程中同样扮演着运输兵的作用,只不过这次是搬运的输出误差,而不是输入信号。
请参考图5,图5为一个卷积神经网络模型网络层示意图。接下来需要进行链式求导,上面介绍了输入信息的前向传播与输出误差的后向传播,接下来就根据求得的误差来更新参数。
首先对隐藏层的w11进行参数更新,更新之前让我们从后往前推导,直到预见w11为止,计算方式如下:
Figure SMS_36
Figure SMS_37
Figure SMS_38
因此误差对w11求偏导如下:
Figure SMS_39
求导得如下公式(所有值已知):
Figure SMS_40
同理,误差对于w12的偏导如下:
Figure SMS_41
同样,求导得w12的求值公式:
Figure SMS_42
同理,误差对于偏置求偏导如下:
Figure SMS_43
同理,误差对于偏置求偏导如下:
Figure SMS_44
接着对输入层的w11进行参数更新,更新之前我们依然从后往前推导,直到预见第一层的w11为止:
Figure SMS_45
/>
Figure SMS_46
因此误差对输入层的w11求偏导如下:
Figure SMS_47
求导得如下公式:
Figure SMS_48
同理,输入层的其他三个参数按照同样的方法即可求出各自的偏导,此处不做赘述。在每个参数偏导数明确的情况下,带入梯度下降公式即可:
Figure SMS_49
至此,利用链式法则来对每层参数进行更新的任务已经完成。
在更新了卷积神经网络模型的权重之后,保留一份卷积神经网络模型,以使得在后续训练过程中出现泛化、过拟合等问题时,还可以使用原先保存下来的卷积神经网络模型。
当卷积神经网络模型更新完成后,可以选择原始样本重新输入卷积神经网络模型训练,也可以是从重新合成新的原始样本输入卷积神经网络模型训练。
225、将新生成器权重值和新反向器权重值进行加权平均处理,生成共享更新权重值;
终端将新生成器权重值和新反向器权重值进行加权平均处理,生成共享更新权重值。终端可以通过下述公式进行加权平均:
Figure SMS_50
其中,
Figure SMS_51
为反向器权重值,/>
Figure SMS_52
为生成器权重值,/>
Figure SMS_53
为共享更新权重值,/>
Figure SMS_54
为0至1之间的预设值。除了上述公式外,还可以使用多个公式进行计算,此处不作限定。
226、将共享更新权重值作为生成器和反向器更新后的权重值;
227、对更新后的卷积神经网络模型重复进行迭代训练。
终端将共享更新权重值作为生成器和反向器更新后的权重值之后,对更新后的卷积神经网络模型重复进行迭代训练,直到卷积神经网络模型训练完成。
本实施例中,终端首先获取卷积神经网络模型,其中,卷积神经网络模型中包括生成器和反向器。终端通过第一注意力Dropout模块ADO给正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征。终端通过第一通道注意力模块Attention为第一中间特征生成通道向量。通过第一通道注意力模块Attention结合通道向量输出一个维度与第一中间特征通道数相同的归一化一维向量。终端通过第一通道注意力模块Attention并根据归一化一维向量将第一中间特征按通道对应相乘,生成第二中间特征。终端通过反卷积模块将第二中间特征进行卷积处理和通道叠加处理。终端通过第一区域像素注意力模块RPA对第二中间特征进行区域像素值权重生成处理,生成模拟图像。终端通过第二区域像素注意力模块RPA对模拟图像进行区域像素值权重生成处理,生成第三中间特征。终端通过卷积模块将第三中间特征进行卷积处理和通道叠加处理。终端通过第二通道注意力模块Attention为第三中间特征生成通道向量。通过第二通道注意力模块Attention结合通道向量输出一个维度与第三中间特征通道数相同的归一化一维向量。终端通过第二通道注意力模块Attention并根据归一化一维向量将第三中间特征按通道对应相乘,生成第四中间特征。终端通过第二注意力Dropout模块ADO给第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据。终端根据上述方法为真实图像生成第二特征空间数据。终端使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第一数据分布损失值。使用均方误差法并根据模拟图像和真实图像计算图像像素损失值。使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第二数据分布损失值。终端使用Wasserstein距离+梯度惩罚法并根据第二特征空间数据和正态分布采样数据计算第三数据分布损失值。接下来,终端根据第一数据分布损失值和图像像素损失值计算生成器的生成器损失值,根据第二数据分布损失值和第三数据分布损失值计算反向器的反向器损失值,若否,通过小批量梯度下降法并根据生成器损失值和生成器的权重值生成新生成器权重值,并且通过小批量梯度下降法并根据反向器损失值和反向器的权重值生成新反向器权重值。将新生成器权重值和新反向器权重值进行加权平均处理,生成共享更新权重值。终端将共享更新权重值作为生成器和反向器更新后的权重值,终端对更新后的卷积神经网络模型重复进行迭代训练。可以看出,在迭代训练过程中,通过将数据分布损失和图像像素损失进行拟合计算,生成对应的损失值,并且根据对应的损失值进行权重值的计算,最后还通过各自的权重值进行拟合计算,生成共享更新权重值,并且通过共享更新权重值进行后续的更新迭代,使得生成器和反向器两个模块进行协同配合,同时进步,使得卷积神经网络模型训练速度和准确性提高。
其次,本实施例创新了一种新的自对抗形式的图像生成模式,从图像和映射函数两个方向同时互相逼近,提高了映射的准确性和训练速度。
其次,本实施例创新地引入共享卷积权重参数方式,使得生成器G和反向器R通过直接改变卷积权重的方式进行优化,提升互相映射的准确性,改善了传统GAN生成器和判别器难以协同训练的问题。
其次,生成器和反向器对抗训练,能更好的找到数据分布特征空间和图像之间的最优映射函数,即使某一模块陷入局部最优,另一个模块也可以通过权值共享使其跳出局部最优点。
其次,使用Wasserstent Distence、梯度惩罚以及PixelNorm,提高训练稳定性。
其次,所有的理论都认为GAN应该在纳什均衡(Nashequilibrium)上有卓越的表现,但梯度下降只有在凸函数的情况下才能保证实现纳什均衡。当博弈双方都由神经网络表示时,在没有实际达到均衡的情况下,让它们永远保持对自己策略的调整是可能的,即难以收敛以及难以确定收敛条件,而本实施例的子对抗卷积神经网络的训练方法改善了传统的生成对抗网络GAN难以训练的问题。
其次,GAN模型被定义为极小极大问题,没有损失函数,在训练过程中很难区分是否正在取得进展。GAN的学习过程可能发生崩溃问题(collapseproblem),生成器开始退化,总是生成同样的样本点,无法继续学习。当生成模型崩溃时,判别模型也会对相似的样本点指向相似的方向,训练无法继续,本实施例的子对抗卷积神经网络的训练方法改善了无法收敛等问题。
并且,本实施例为一种基于自对抗的深度学习生成模式,发明了一种新的自对抗模式,包括生成器G和反向器R。生成器G将从标准正态分布中随机采样的正态分布采样数据Z使用基于反卷积和卷积操作的卷积神经网络模型的生成器生成模拟图像G(Z),而反向器则使用生成器输出模拟图像G(Z)和真实图像X,使用基于卷积神经网络模型编码成数据分布特征空间R(G(Z))和R(X)。两者进行互逆操作,生成器G将数据分布特征空间映射到图像,而反向器R将图像映射回数据分布特征空间,两个模块共享关键的卷积权重参数,即生成器G中的对应的反卷积权重参数和反向器R的卷积权重参数相同,两个模块各自从相对的角度同向逼近,相互对抗,由于共享关键权重参数,也可以和自身进行对抗,即自对抗生成模式。一般的卷积神经网络模型的生成器的生成模式都是单向映射模式,即从数据分布特征空间映射到图像,这种方式较难拟合出比较匹配的映射函数,训练时比较容易陷入局部最优值。本实施例通过两个模块对抗训练进行双向映射,采用独创的自对抗方式共享权重参数,能更好的找到数据分布特征空间和图像之间的最优映射函数,即使某一模块陷入局部最优,另一个模块也可以通过权值共享使其跳出局部最优点。传统生成对抗网络的两个模块生成器G和判别器D是完全无关的神经网络,而本专利的生成器G和反向器R属于互逆模型,且共享卷积权重参数,能更好地进行互相对抗和互相监督。本实施例的卷积神经网络模型属于计算机视觉、图像生成领域,涉及一种基于多层卷积特征提取、注意力分配、反卷积图像重构的图像处理方法。包括采集训练图像,进行图像预处理。此外本实施例还设计了一种基于自对抗的神经网络模型称为SAN,包括一个生成器G和一个反向器R。生成器G用于通过在标准正态分布中随机采样Z,使用生成器得到期望的假图像;反向器R对真实图像X和由生成器G生成的假图像G(Z)进行反向映射成数据分布特征空间。本实施例相对于传统技术,增加了L2正则化用于防止神经网络过拟合。增加了Resnet技术增加前后特征层的数据交互,最大限度保留浅层的特征,消除梯度消失现象;加入数据并行(DP)模式用于减少显存消耗和提升训练速度。使用神经网络对训练数据集进行深度学习,将从标准正态分布里随机采样数据输入已完成训练的生成器G进行推理,生产能出新的图像。
通过上述内容可以将本实施例简要为下列内容:SAN训练和生成过程,整个神经网络训练分为正向推理和反向传播:S1.正向推理:正态分布采样数据Z,生成器G产生模拟图像G(Z);将G(Z)以及真实图像X分别输入反向器R,得到R(G(Z))和R(X);S2.计算得到损失函数LGI、LGD、LRZ和LRX;S3.反向传播:优化器算法,将损失值反向传播到神经网络的各个阈值参数上,使用LGI、LGD更新生成器权重值,使用LRZ和LRX更新反向器权重值;S4.共享关键卷积权重,将Generator单元和Reverse单元里各个对应的卷积和反卷积的权重值取加权平均作为新的权重值;S5.反复进行1-4,不停更新卷积神经网络模型中的权重值(阈值参数),使得正向推理得到的生成器和判别器的损失值达到要求,停止训练;S6.训练完成后,生成时只使用生成器G从正态分布中随机采样的数据,自动生成新图像。
并且,在本实施例中还可以进行AI模型检测及部署。在经过神经网络训练,模型达到要求的检测精度时,将训练得到的生成器模型阈值参数文件加载到生成器中去。训练完成的模型需要部署上线才能工业使用,模型部署有3种方式:1.在PC上直接安装调试AI坏境,包括AI相关底层安装库、Python文件包等,使用pycharm软件调用训练完成的模型文件进行测试,这种方式安装方便,每一批测试数据测试时需要手动方式启动检测;2.生成模型可执行文件与host进行通信,通过host调用AI的可执行文件进行检测,这种方式需要修改host与AI可执行文件进行通信,可控性好,可以在host进行任何处理,可以做成自动化;3.将AI模型通过pytorch自带的C++转换工具Libtorch转换成可被C++调用文件,编写专门的软件,嵌入转换后的模型文件,独立测试,需要编写独立的软件,由于是通过C++调用模型文件,因此检测速度快。本实施例使用后两种部署方式:通过pyinstaller软件将AI训练完成的模型转换成.exe可执行文件,在客户现场通过软件调用AI可执行文件来进行检测;使用C#编写AI软件界面,使用Libtorch框架将AI模型转换成torch script格式,封装成C++对外接口,通过C#软件调用AI模型的C++接口进行检测。
请参阅图6,本申请提供了一种自对抗神经网络模型的训练装置的一个实施例,包括:
第一获取单元601,用于获取卷积神经网络模型,卷积神经网络模型中包括生成器和反向器;
第一生成单元602,用于将一组正态分布采样数据输入生成器,生成模拟图像;
可选的,生成器包括至少一个Generator单元,Generator单元包括注意力Dropout模块ADO、至少一个Attention/Conv_t块和区域像素注意力模块RPA,Attention/Conv_t块包括通道注意力模块Attention和反卷积模块;
第一生成单元602,包括:
通过第一注意力Dropout模块ADO给正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
通过第一通道注意力模块Attention为第一中间特征生成通道向量;
通过第一通道注意力模块Attention结合通道向量输出一个维度与第一中间特征通道数相同的归一化一维向量;
通过第一通道注意力模块Attention并根据归一化一维向量将第一中间特征按通道对应相乘,生成第二中间特征;
通过反卷积模块将第二中间特征进行卷积处理和通道叠加处理;
通过第一区域像素注意力模块RPA对第二中间特征进行区域像素值权重生成处理,生成模拟图像。
第二生成单元603,用于将真实图像和模拟图像输入反向器中,生成模拟图像的第一特征空间数据和真实图像的第二特征空间数据;
可选的,反向器包括至少一个 Reverse单元, Reverse单元包括区域像素注意力模块RPA、至少一个Attention/Conv块和、注意力Dropout模块ADO,Attention/Conv块包括通道注意力模块Attention和卷积模块;
第二生成单元603,包括:
通过第二区域像素注意力模块RPA对模拟图像进行区域像素值权重生成处理,生成第三中间特征;
通过卷积模块将第三中间特征进行卷积处理和通道叠加处理;
通过第二通道注意力模块Attention为第三中间特征生成通道向量;
通过第二通道注意力模块Attention结合通道向量输出一个维度与第三中间特征通道数相同的归一化一维向量;
通过第二通道注意力模块Attention并根据归一化一维向量将第三中间特征按通道对应相乘,生成第四中间特征;
通过第二注意力Dropout模块ADO给第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据;
根据上述方法为真实图像生成第二特征空间数据。
第一计算单元604,用于根据真实图像、模拟图像、第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
可选的,第一计算单元604,包括:
使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第一数据分布损失值;
使用均方误差法并根据模拟图像和真实图像计算图像像素损失值;
使用Wasserstein距离+梯度惩罚法并根据第一特征空间数据和正态分布采样数据计算第二数据分布损失值;
使用Wasserstein距离+梯度惩罚法并根据第二特征空间数据和正态分布采样数据计算第三数据分布损失值。
第二计算单元605,用于根据数据分布损失和图像像素损失计算生成器的生成器损失值和反向器的反向器损失值;
可选的,第二计算单元605,包括:
根据第一数据分布损失值和图像像素损失值计算生成器的生成器损失值;
根据第二数据分布损失值和第三数据分布损失值计算反向器的反向器损失值。
判断单元606,用于判断生成器损失值和反向器损失值是否满足预设条件;
确定单元607,用于当判断单元确定生成器损失值和反向器损失值满足预设条件时,则确定卷积神经网络模型训练完成;
迭代单元608,用于当判断单元确定生成器损失值和反向器损失值不满足预设条件时,根据生成器损失值和反向器损失值对生成器和反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
可选的,迭代单元609,包括:
通过小批量梯度下降法并根据生成器损失值和生成器的权重值生成新生成器权重值;
通过小批量梯度下降法并根据反向器损失值和反向器的权重值生成新反向器权重值;
将新生成器权重值和新反向器权重值进行加权平均处理,生成共享更新权重值;
将共享更新权重值作为生成器和反向器更新后的权重值;
对更新后的卷积神经网络模型重复进行迭代训练。
请参阅图7,本申请提供了一种电子设备,包括:
处理器701、存储器703、输入输出单元702以及总线704。
处理器701与存储器703、输入输出单元702以及总线704相连。
存储器703保存有程序,处理器701调用程序以执行如图1、图2-a、图2-b、图2-c、图2-d中的训练方法。
本申请提供了一种计算机可读存储介质,计算机可读存储介质上保存有程序,程序在计算机上执行时执行如图1、图2-a、图2-b、图2-c、图2-d中的训练方法。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,read-onlymemory)、随机存取存储器(RAM,randomaccess memory)、磁碟或者光盘等各种可以存储程序代码的介质。

Claims (10)

1.一种自对抗神经网络模型的训练方法,其特征在于,包括:
获取卷积神经网络模型,所述卷积神经网络模型中包括生成器和反向器;
将一组正态分布采样数据输入所述生成器,生成模拟图像;
将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据;
根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值;
判断所述生成器损失值和所述反向器损失值是否满足预设条件;
若是,则确定所述卷积神经网络模型训练完成;
若否,根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
2.根据权利要求1所述的训练方法,其特征在于,所述生成器包括至少一个Generator单元,所述Generator单元包括注意力Dropout模块ADO、至少一个Attention/Conv_t块和区域像素注意力模块RPA,所述Attention/Conv_t块包括通道注意力模块Attention和反卷积模块;
所述将一组正态分布采样数据输入所述生成器,生成模拟图像,包括:
通过第一注意力Dropout模块ADO给所述正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
通过第一通道注意力模块Attention为所述第一中间特征生成通道向量;
通过第一通道注意力模块Attention结合所述通道向量输出一个维度与所述第一中间特征通道数相同的归一化一维向量;
通过第一通道注意力模块Attention并根据所述归一化一维向量将所述第一中间特征按通道对应相乘,生成第二中间特征;
通过反卷积模块将所述第二中间特征进行卷积处理和通道叠加处理;
通过第一区域像素注意力模块RPA对所述第二中间特征进行区域像素值权重生成处理,生成模拟图像。
3.根据权利要求1所述的训练方法,其特征在于,所述反向器包括至少一个 Reverse单元,所述 Reverse单元包括区域像素注意力模块RPA、至少一个Attention/Conv块和、注意力Dropout模块ADO,所述Attention/Conv块包括通道注意力模块Attention和卷积模块
所述将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据,包括:
通过第二区域像素注意力模块RPA对所述模拟图像进行区域像素值权重生成处理,生成第三中间特征;
通过卷积模块将所述第三中间特征进行卷积处理和通道叠加处理;
通过第二通道注意力模块Attention为所述第三中间特征生成通道向量;
通过第二通道注意力模块Attention结合所述通道向量输出一个维度与所述第三中间特征通道数相同的归一化一维向量;
通过第二通道注意力模块Attention并根据所述归一化一维向量将所述第三中间特征按通道对应相乘,生成第四中间特征;
通过第二注意力Dropout模块ADO给所述第四中间特征对应的每个神经元分配注意力,并将注意力小于第二预设阈值的神经元进行置零,生成第一特征空间数据;
根据上述方法为所述真实图像生成第二特征空间数据。
4.根据权利要求1所述的训练方法,其特征在于,根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失,包括:
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第一数据分布损失值;
使用均方误差法并根据所述模拟图像和所述真实图像计算图像像素损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第一特征空间数据和正态分布采样数据计算第二数据分布损失值;
使用Wasserstein距离+梯度惩罚法并根据所述第二特征空间数据和所述正态分布采样数据计算第三数据分布损失值。
5.根据权利要求4所述的训练方法,其特征在于,所述根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值,包括:
根据所述第一数据分布损失值和图像像素损失值计算所述生成器的生成器损失值;
根据第二数据分布损失值和第三数据分布损失值计算所述反向器的反向器损失值。
6.根据权利要求1至5中任一项所述的训练方法,其特征在于,所述根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练,包括:
通过小批量梯度下降法并根据所述生成器损失值和所述生成器的权重值生成新生成器权重值;
通过小批量梯度下降法并根据所述反向器损失值和所述反向器的权重值生成新反向器权重值;
将所述新生成器权重值和所述新反向器权重值进行加权平均处理,生成共享更新权重值;
将所述共享更新权重值作为所述生成器和所述反向器更新后的权重值;
对所述更新后的卷积神经网络模型重复进行迭代训练。
7.一种自对抗神经网络模型的训练装置,其特征在于,包括:
第一获取单元,用于获取卷积神经网络模型,所述卷积神经网络模型中包括生成器和反向器;
第一生成单元,用于将一组正态分布采样数据输入所述生成器,生成模拟图像;
第二生成单元,用于将真实图像和所述模拟图像输入所述反向器中,生成所述模拟图像的第一特征空间数据和所述真实图像的第二特征空间数据;
第一计算单元,用于根据所述真实图像、所述模拟图像、所述第一特征空间数据、第二特征空间数据和正态分布采样数据计算数据分布损失和图像像素损失;
第二计算单元,用于根据所述数据分布损失和所述图像像素损失计算所述生成器的生成器损失值和所述反向器的反向器损失值;
判断单元,用于判断所述生成器损失值和所述反向器损失值是否满足预设条件;
确定单元,用于当所述判断单元确定所述生成器损失值和所述反向器损失值满足预设条件时,则确定所述卷积神经网络模型训练完成;
迭代单元,用于当所述判断单元确定所述生成器损失值和所述反向器损失值不满足预设条件时,根据所述生成器损失值和所述反向器损失值对所述生成器和所述反向器的权重值进行拟合,生成共享更新权重值,并重复进行迭代训练。
8.根据权利要求7所述的训练装置,其特征在于,所述生成器包括至少一个Generator单元,所述Generator单元包括注意力Dropout模块ADO、至少一个Attention/Conv_t块和区域像素注意力模块RPA,所述Attention/Conv_t块包括通道注意力模块Attention和反卷积模块;
所述第一生成单元,包括:
通过第一注意力Dropout模块ADO给所述正态分布采样数据对应的每个神经元分配注意力,并将注意力小于第一预设阈值的神经元进行置零,生成第一中间特征;
通过第一通道注意力模块Attention为所述第一中间特征生成通道向量;
通过第一通道注意力模块Attention结合所述通道向量输出一个维度与所述第一中间特征通道数相同的归一化一维向量;
通过第一通道注意力模块Attention并根据所述归一化一维向量将所述第一中间特征按通道对应相乘,生成第二中间特征;
通过反卷积模块将所述第二中间特征进行卷积处理和通道叠加处理;
通过第一区域像素注意力模块RPA对所述第二中间特征进行区域像素值权重生成处理,生成模拟图像。
9.一种电子设备,其特征在于,包括:
处理器、存储器、输入输出单元以及总线;
所述处理器与所述存储器、所述输入输出单元以及所述总线相连;
所述存储器保存有程序,所述处理器调用所述程序以执行如权利要求1至6任意一项所述的训练方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上保存有程序,所述程序在计算机上执行时执行如权利要求1至6中任一项所述的训练方法。
CN202310196878.6A 2023-03-03 2023-03-03 一种自对抗神经网络模型的训练方法及相关装置 Active CN115860113B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310196878.6A CN115860113B (zh) 2023-03-03 2023-03-03 一种自对抗神经网络模型的训练方法及相关装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310196878.6A CN115860113B (zh) 2023-03-03 2023-03-03 一种自对抗神经网络模型的训练方法及相关装置

Publications (2)

Publication Number Publication Date
CN115860113A true CN115860113A (zh) 2023-03-28
CN115860113B CN115860113B (zh) 2023-07-25

Family

ID=85659902

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310196878.6A Active CN115860113B (zh) 2023-03-03 2023-03-03 一种自对抗神经网络模型的训练方法及相关装置

Country Status (1)

Country Link
CN (1) CN115860113B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116663619A (zh) * 2023-07-31 2023-08-29 山东科技大学 基于gan网络的数据增强方法、设备以及介质

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190295302A1 (en) * 2018-03-22 2019-09-26 Northeastern University Segmentation Guided Image Generation With Adversarial Networks
CN112164122A (zh) * 2020-10-30 2021-01-01 哈尔滨理工大学 一种基于深度残差生成对抗网络的快速cs-mri重建方法
CN112489154A (zh) * 2020-12-07 2021-03-12 重庆邮电大学 基于局部优化生成对抗网络的mri运动伪影校正方法
CN112990318A (zh) * 2021-03-18 2021-06-18 中国科学院深圳先进技术研究院 持续学习方法、装置、终端及存储介质
CN113011567A (zh) * 2021-03-31 2021-06-22 深圳精智达技术股份有限公司 一种卷积神经网络模型的训练方法及装置
CN113642621A (zh) * 2021-08-03 2021-11-12 南京邮电大学 基于生成对抗网络的零样本图像分类方法
CN114038055A (zh) * 2021-10-27 2022-02-11 电子科技大学长三角研究院(衢州) 一种基于对比学习和生成对抗网络的图像生成方法
CN114298997A (zh) * 2021-12-23 2022-04-08 北京瑞莱智慧科技有限公司 一种伪造图片检测方法、装置及存储介质
CN114444013A (zh) * 2020-10-19 2022-05-06 中国石油化工股份有限公司 一种基于对抗博弈的配电网大数据修复方法
CN114549312A (zh) * 2022-02-18 2022-05-27 南京国电南自电网自动化有限公司 一种基于srgan的提高图像传输效率的方法
CN114972130A (zh) * 2022-08-02 2022-08-30 深圳精智达技术股份有限公司 一种去噪神经网络的训练方法、装置及训练设备
CN115526891A (zh) * 2022-11-28 2022-12-27 深圳精智达技术股份有限公司 一种缺陷数据集的生成模型的训练方法及相关装置

Patent Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190295302A1 (en) * 2018-03-22 2019-09-26 Northeastern University Segmentation Guided Image Generation With Adversarial Networks
CN114444013A (zh) * 2020-10-19 2022-05-06 中国石油化工股份有限公司 一种基于对抗博弈的配电网大数据修复方法
CN112164122A (zh) * 2020-10-30 2021-01-01 哈尔滨理工大学 一种基于深度残差生成对抗网络的快速cs-mri重建方法
CN112489154A (zh) * 2020-12-07 2021-03-12 重庆邮电大学 基于局部优化生成对抗网络的mri运动伪影校正方法
CN112990318A (zh) * 2021-03-18 2021-06-18 中国科学院深圳先进技术研究院 持续学习方法、装置、终端及存储介质
CN113011567A (zh) * 2021-03-31 2021-06-22 深圳精智达技术股份有限公司 一种卷积神经网络模型的训练方法及装置
CN113642621A (zh) * 2021-08-03 2021-11-12 南京邮电大学 基于生成对抗网络的零样本图像分类方法
CN114038055A (zh) * 2021-10-27 2022-02-11 电子科技大学长三角研究院(衢州) 一种基于对比学习和生成对抗网络的图像生成方法
CN114298997A (zh) * 2021-12-23 2022-04-08 北京瑞莱智慧科技有限公司 一种伪造图片检测方法、装置及存储介质
CN114549312A (zh) * 2022-02-18 2022-05-27 南京国电南自电网自动化有限公司 一种基于srgan的提高图像传输效率的方法
CN114972130A (zh) * 2022-08-02 2022-08-30 深圳精智达技术股份有限公司 一种去噪神经网络的训练方法、装置及训练设备
CN115526891A (zh) * 2022-11-28 2022-12-27 深圳精智达技术股份有限公司 一种缺陷数据集的生成模型的训练方法及相关装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
ISHAAN GULRAJANI 等: "Improved Training of Wasserstein GANs" *
ZACHARY C LIPTON 等: "PRECISE RECOVERY OF LATENT VECTORS FROM GENERATIVE ADVERSARIAL NETWORKS" *
熊鹰飞: "基于生成对抗网络的多源跨区域遥感图像超分辨" *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116663619A (zh) * 2023-07-31 2023-08-29 山东科技大学 基于gan网络的数据增强方法、设备以及介质
CN116663619B (zh) * 2023-07-31 2023-10-13 山东科技大学 基于gan网络的数据增强方法、设备以及介质

Also Published As

Publication number Publication date
CN115860113B (zh) 2023-07-25

Similar Documents

Publication Publication Date Title
Goroshin et al. Learning to linearize under uncertainty
KR102039397B1 (ko) 추론 과정 설명이 가능한 시각 질의 응답 장치 및 방법
CN111027576B (zh) 基于协同显著性生成式对抗网络的协同显著性检测方法
CN111814626B (zh) 一种基于自注意力机制的动态手势识别方法和系统
CN111178545B (zh) 一种动态强化学习决策训练系统
CN111209215B (zh) 应用程序的测试方法、装置、计算机设备及存储介质
CN113095254B (zh) 一种人体部位关键点的定位方法及系统
Bontempi et al. Local learning for iterated time series prediction
CN115526891B (zh) 一种缺陷数据集的生成模型的训练方法及相关装置
CN115393231B (zh) 一种缺陷图像的生成方法、装置、电子设备和存储介质
CN117079098A (zh) 一种基于位置编码的空间小目标检测方法
CN115860113A (zh) 一种自对抗神经网络模型的训练方法及相关装置
CN116188684A (zh) 基于视频序列的三维人体重建方法及相关设备
CN116486244A (zh) 基于细节增强的水下目标检测方法
CN110675311A (zh) 一种素描序约束下的素描生成的方法、装置及存储介质
Bock et al. Gray-scale ALIAS
EP3660742B1 (en) Method and system for generating image data
KR102110316B1 (ko) 뉴럴 네트워크를 이용한 변분 추론 방법 및 장치
CN111275751A (zh) 一种无监督绝对尺度计算方法及系统
CN111833395B (zh) 一种基于神经网络模型的测向体制单目标定位方法和装置
CN115457365A (zh) 一种模型的解释方法、装置、电子设备及存储介质
CN114399628A (zh) 复杂空间环境下的绝缘子高效检测系统
EP4075343A1 (en) Device and method for realizing data synchronization in neural network inference
CN112989952B (zh) 一种基于遮罩引导的人群密度估计方法及装置
CN117635418B (zh) 生成对抗网络的训练方法、双向图像风格转换方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant