CN111753519A - 一种模型训练和识别方法、装置、电子设备及存储介质 - Google Patents

一种模型训练和识别方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN111753519A
CN111753519A CN202010615855.0A CN202010615855A CN111753519A CN 111753519 A CN111753519 A CN 111753519A CN 202010615855 A CN202010615855 A CN 202010615855A CN 111753519 A CN111753519 A CN 111753519A
Authority
CN
China
Prior art keywords
network model
training
countermeasure network
accuracy
loss value
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010615855.0A
Other languages
English (en)
Other versions
CN111753519B (zh
Inventor
秦海宁
单培
张瑞飞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Dingfu Intelligent Technology Co Ltd
Original Assignee
Dingfu Intelligent Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Dingfu Intelligent Technology Co Ltd filed Critical Dingfu Intelligent Technology Co Ltd
Priority to CN202010615855.0A priority Critical patent/CN111753519B/zh
Priority claimed from CN202010615855.0A external-priority patent/CN111753519B/zh
Publication of CN111753519A publication Critical patent/CN111753519A/zh
Application granted granted Critical
Publication of CN111753519B publication Critical patent/CN111753519B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computational Linguistics (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供一种模型训练和识别方法、装置、电子设备及存储介质,用于改善对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。该方法包括:获得文本数据和文本数据对应的文本类别;以文本数据为训练数据,以文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,生成对抗网络模型包括:生成器和判别器;其中,对生成对抗网络进行训练,包括:获得状态推测矩阵,状态推测矩阵表征生成器的重要程度;对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,状态观测矩阵表征判别器的重要程度;根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值。

Description

一种模型训练和识别方法、装置、电子设备及存储介质
技术领域
本申请涉及人工智能和机器学习的技术领域,具体而言,涉及一种模型训练和识别方法、装置、电子设备及存储介质。
背景技术
模型训练,是指根据训练数据对目标模型进行训练,具体的训练方式根据训练数据的情况可以包括:监督式学习和无监督学习等方式。
监督式学习(Supervised learning),又被称为监督式训练,是机器学习的一种方法,可以由训练资料中学到或建立一个学习模式(learning model)或学习函数,并依此模式推测新的实例。
无监督学习(unsupervised learning),又被称为无监督式训练,是指机器学习的一种方法,没有给定事先标记过的训练示例,自动对输入的数据进行分类或分群;无监督学习的主要包括:聚类分析(cluster analysis)、关系规则(association rule)、维度缩减(dimensionality reduce)和对抗学习(AdversarialLearning)等。
目前,在对生成对抗网络进行训练时,为了提高对生成对抗网络进行训练的稳定性,通常是直接将预测标签和训练标签之间的损失值在一个恒定的区间范围截断。这种方法极大地限制了生成对抗网络模型的表现能力,生成对抗网络难以模拟出复杂的函数,在具体的实践过程中发现,使用这种方法对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。
发明内容
本申请实施例的目的在于提供一种模型训练和识别方法、装置、电子设备及存储介质,用于改善对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。
本申请实施例提供了一种模型训练方法,包括:获得文本数据和文本数据对应的文本类别;以文本数据为训练数据,以文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,生成对抗网络模型包括:生成器和判别器;其中,对生成对抗网络进行训练,包括:获得状态推测矩阵,状态推测矩阵表征生成器的重要程度;对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,状态观测矩阵表征判别器的重要程度;根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值。在上述的实现过程中,对生成对抗网络模型进行训练时,对状态观测矩阵和获得的状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,并根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值;也就是说,在对生成对抗网络进行训练的过程中动态地根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值,从而尽快地找到损失值的阈值范围进行截断,使得生成对抗网络模型更快更稳定地收敛,有效地改善了对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。
可选地,在本申请实施例中,根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值,包括:判断生成对抗网络模型的准确率是否逐渐收敛;若是,则将生成对抗网络模型的损失值向第一方向重置;若否,则将生成对抗网络模型的损失值向第二方向重置,第一方向与第二方向相反。
在上述的实现过程中,若生成对抗网络模型的准确率逐渐收敛,则将生成对抗网络模型的损失值向第一方向重置;若生成对抗网络模型的准确率没有逐渐收敛,则将生成对抗网络模型的损失值向与第一方向相反的第二方向重置;也就是说,根据生成对抗网络模型的准确率动态地调整生成对抗网络模型的损失值,从而尽快地找到损失值的阈值范围进行截断,使得生成对抗网络模型更快更稳定地收敛。
可选地,在本申请实施例中,获得状态推测矩阵,包括:获得生成对抗网络模型的准确率;根据生成对抗网络模型的准确率计算状态推测矩阵。在上述的实现过程中,通过获得生成对抗网络模型的准确率;根据生成对抗网络模型的准确率计算状态推测矩阵;从而有效地提高了获得状态推测矩阵的速度。
可选地,在本申请实施例中,获得生成对抗网络模型的准确率,包括:使用生成对抗网络模型对文本数据进行预测,获得预测标签;根据预测标签和训练标签计算生成对抗网络模型的准确率。在上述的实现过程中,通过使用生成对抗网络模型对文本数据进行预测,获得预测标签;根据预测标签和训练标签计算生成对抗网络模型的准确率;从而有效地提高了获得生成对抗网络模型的准确率的速度。
本申请实施例还提供了一种识别方法,包括:获得文本内容;使用生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别。在上述的实现过程中,通过获得文本内容;使用训练后的生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别;从而有效地提高了获得文本内容对应的类别的速度。
本申请实施例还提供了一种模型训练装置,包括:数据类别获得模块,用于获得文本数据和文本数据对应的文本类别;网络模型训练模块,用于以文本数据为训练数据,以文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,生成对抗网络模型包括:生成器和判别器;其中,网络模型训练模块,包括:推测矩阵获得模块,用于获得状态推测矩阵,状态推测矩阵表征生成器的重要程度;卡尔曼滤波模块,用于对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,状态观测矩阵表征判别器的重要程度;损失值调整模块,用于根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值。
可选地,在本申请实施例中,损失值调整模块,包括:逐渐收敛判断模块,用于判断生成对抗网络模型的准确率是否逐渐收敛;第一方向重置模块,用于若生成对抗网络模型的准确率逐渐收敛,则将生成对抗网络模型的损失值向第一方向重置;第二方向重置模块,用于若生成对抗网络模型的准确率不逐渐收敛,则将生成对抗网络模型的损失值向第二方向重置,第一方向与第二方向相反。
可选地,在本申请实施例中,推测矩阵获得模块,包括:准确率获得模块,用于获得生成对抗网络模型的准确率;推测矩阵计算模块,用于根据生成对抗网络模型的准确率计算状态推测矩阵。
可选地,在本申请实施例中,准确率获得模块,包括:预测标签获得模块,用于使用生成对抗网络模型对文本数据进行预测,获得预测标签;准确率计算模块,用于根据预测标签和训练标签计算生成对抗网络模型的准确率。
本申请实施例还提供了一种识别装置,包括:文本内容获得模块,用于获得文本内容;识别类别获得模块,用于使用生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别。
本申请实施例还提供了一种电子设备,包括:处理器和存储器,存储器存储有处理器可执行的机器可读指令,机器可读指令被处理器执行时执行如上面描述的方法。
本申请实施例还提供了一种存储介质,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如上面描述的方法。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请实施例的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出的本申请实施例提供的模型训练方法的示意图;
图2示出的本申请实施例提供的识别方法的示意图;
图3示出的本申请实施例提供的模型训练装置的结构示意图;
图4示出的本申请实施例提供的识别装置的结构示意图;
图5示出的本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整的描述。
在介绍本申请实施例提供的模型训练和识别方法之前,先介绍本申请实施例所涉及的一些概念,本申请实施例所涉及的一些概念如下:
人工智能(Artificial Intelligence,AI),是指研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学;人工智能是计算机科学的一个分支,人工智能企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。
机器学习,是指人工智能领域中研究人类学习行为的一个分支。借鉴认知科学、生物学、哲学、统计学、信息论、控制论、计算复杂性等学科或理论的观点,通过归纳、一般化、特殊化、类比等基本方法探索人类的认识规律和学习过程,建立各种能通过经验自动改进的算法,使计算机系统能够具有自动学习特定知识和技能的能力;机器学习的主要方法包括:决策树、贝叶斯学习、基于实例的学习、遗传算法、规则学习、基于解释的学习等。
生成对抗网络(Generative Adversarial Network,GAN),又被称为生成对抗式网络,是机器学习中的无监督式学习的一种方法,通过让两个神经网络相互博弈的方式进行学习。
梯度消失,是指在神经网络中,当前面隐藏层的学习速率低于后面隐藏层的学习速率,即随着隐藏层数目的增加,分类准确率反而下降了;这种现象叫梯度消失。
梯度爆炸,是指在神经网络中,当前面隐藏层的学习速率低于后面隐藏层的学习速率,即随着隐藏层数目的增加,分类准确率反而下降了;这种现象叫梯度爆炸。
损失函数(loss function),又被称为成本函数,是指一种将一个事件(在一个样本空间中的一个元素)映射到一个表达与其事件相关的经济成本或机会成本的实数上的一种函数,借此直观表示的一些"成本"与事件的关联;损失函数可以决定训练过程如何来“惩罚”网络的预测结果和真实结果之间的差异,各种不同的损失函数适用于不同类型的任务,具体例如:Softmax交叉熵损失函数常常被用于在多个类别中选出一个,而Sigmoid交叉熵损失函数常常用于多个独立的二分类问题,欧几里德损失函数常常用于结果取值范围为任意实数的问题。
服务器是指通过网络提供计算服务的设备,服务器例如:x86服务器以及非x86服务器,非x86服务器包括:大型机、小型机和UNIX服务器。当然在具体的实施过程中,上述的服务器可以具体选择大型机或者小型机,这里的小型机是指采用精简指令集计算(ReducedInstruction Set Computing,RISC)、单字长定点指令平均执行速度(MillionInstructions Per Second,MIPS)等专用处理器,主要支持UNIX操作系统的封闭且专用的提供计算服务的设备;这里的大型机,又名大型主机,是指使用专用的处理器指令集、操作系统和应用软件来提供计算服务的设备。
在介绍本申请实施例提供的模型训练和识别方法之前,先分析对照实施例中在训练生成对抗网络时出现梯度爆炸或者梯度消失的原因,具体地,对照实施例在训练生成对抗网络时直接将预测标签和训练标签之间的损失值在一个恒定的区间范围截断,这里的恒定区间范围具体例如:[-0.01,0.01]的范围区间,这里的截断也可以理解为权重剪裁(weight clipping),将损失值在一个恒定的区间范围截断或剪裁,极大地限制了生成对抗网络模型的表现能力,导致生成对抗网络很难模拟出复杂的函数,在经过多层网络传递后,就很容易出现梯度消失或者梯度爆炸的问题。这里的梯度消失或者梯度爆炸的问题出现的原因在于,GAN的判别器是一个多层网络;若将剪裁阈值(clipping threshold)设得稍微小了一点,每经过一层网络,梯度就变小一点,经过多层网络之后就会出现梯度消失的现象,这里的梯度消失又被称为指数衰减;反之,如果将剪裁阈值设得稍微大了一点,每经过一层网络,梯度变大一点,经过多层网络之后就会出现梯度爆炸的现象,这里的梯度爆炸又被称为指数爆炸。也就是说,梯度消失或者梯度爆炸的问题出现的原因在于,将生成对抗网络模型的损失值截断在一个静态且恒定的区间范围。
而在本申请实施例提供的对生成对抗网络进行训练的方法中,根据生成对抗网络模型的准确率动态地调整生成对抗网络模型的损失值,从而尽快地找到损失值的阈值范围进行截断,使得生成对抗网络模型更快更稳定地收敛,有效地改善了对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。
需要说明的是,本申请实施例提供的模型训练和识别方法可以被电子设备执行,这里的电子设备是指具有执行计算机程序功能的设备终端或者上述的服务器,设备终端例如:智能手机、个人电脑(personal computer,PC)、平板电脑、个人数字助理(personaldigital assistant,PDA)、移动上网设备(mobile Internet device,MID)、网络交换机或网络路由器等。
在介绍本申请实施例提供的模型训练和识别方法之前,先介绍该模型训练和识别方法适用的应用场景,这里的应用场景包括但不限于:机器学习中的对文本数据进行分类,获得文本数据的类别,这里的类别具体例如:文本的情感倾向、题材分类和主体思想等。
请参见图1示出的本申请实施例提供的模型训练方法的示意图;该模型训练方法可以包括如下步骤:
步骤S100:获得文本数据和文本数据对应的文本类别。
文本数据,是指语料库中存放的是在语言的实际使用中真实出现过的语言材料;其中,语料库是以电子计算机为载体承载语言知识的基础资源;这里的文本数据具体例如:网络上的文章、教材书中的课文、专利文献或者专利文献等文本信息,这里的文本信息是一种最为常见的非结构化数据,文本信息蕴含着大量的潜在信息。
文本类别,是指上述文本数据的具体类别,按照不同的分类方式,文本数据的类别也有所不同,具体例如:可以按照文本数据的情感类别来分类,也可以按照文本数据所属的话题或者主题分类,主题分类的列表可以包括:法律、时政或社会等等。
上述步骤S100中的文本数据和文本数据对应的文本类别可以分开获取,具体例如:人工收集文本数据,并人工地识别文本数据对应的文本类别;文本数据和文本类别也可以一起获取,例如获取由文本数据和文本数据对应的文本类别打包形成的训练数据包,该训练数据包的获得方式包括:第一种方式,获取预先存储的训练数据包,从文件系统中获取训练数据包,或者从数据库中获取训练数据包;第二种方式,从其他终端设备接收获得训练数据包;第三种方式,使用浏览器等软件获取互联网上的训练数据包,或者使用其它应用程序访问互联网获得训练数据包。
步骤S200:以文本数据为训练数据,以文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型。
需要说明的是,在本申请实施例中,为了便于区分,将训练后的神经网络称为神经网络模型,神经网络模型例如生成对抗网络模型,而没有经过训练的神经网络均称为某网络,例如生成对抗网络;实际上,训练前的神经网络和训练后的神经网络模型均具有相同的网络结构,即生成对抗网络和生成对抗网络模型均具有相同的网络结构;下面介绍生成对抗网络模型的网络结构:
上述的生成对抗网络模型由一个生成器(generator)与一个由多层网络构成的判别器(discriminator)组成;生成器从潜在空间(latent space)中随机取样作为输入数据,生成器的输出结果需要尽量模仿训练集中的真实样本;判别器的输入数据则为真实样本或生成器的输出数据(即生成器的输出结果),其目的是将生成器的输出数据和真实样本中尽可能分别出来;而生成器则要尽可能地欺骗判别器,即尽可能让判别器分辨不出生成器的输出数据和真实样本,生成器和判别器相互对抗从而不断调整参数,最终目的是使判别器无法判断生成器的输出结果是否真实。
其中,上述的步骤S200中的对生成对抗网络进行训练时,可以将训练数据和训练标签分为多批次进行训练,这里的训练标签包括上述的文本类别,每批次的训练数据和训练标签的数量也可以根据具体情况进行调整,使用每批次的训练数据和训练标签对生成对抗网络进行训练的实施方式可以包括如下:
步骤S210:获得状态推测矩阵。
状态推测矩阵,是指表征生成器的重要程度的矩阵,在公式中可以使用P来表示,其中,在公式中可能有不同的表示形式,具体例如:Pk表示在训练数据和训练标签分为多批次中的第k次训练时的后验状态推测矩阵,而
Figure BDA0002561129640000091
表示在训练数据和训练标签分为多批次中的第k次训练时的先验状态推测矩阵,这里的先验和后验的区别在于,先验状态推测矩阵是在不知道本批次(例如是第k次)的生成对抗网络模型的准确率的情况下计算获得的,而后验状态推测矩阵是已经知道本批次(例如是第k次)的生成对抗网络模型的准确率的情况下,根据本批次(例如是第k次)的生成对抗网络模型的准确率进行计算获得的,这里的具体计算过程将在下面详细的说明。
上述步骤S210中的获得状态推测矩阵的实施方式可以包括:
步骤S211:获得生成对抗网络模型的准确率。
生成对抗网络模型的准确率(accuracy rate),是在生成对抗网络进行训练的过程中,将训练数据输入生成对抗网络后,获得预测标签,这里预测标签为训练标签的正确概率。
上述步骤S211中的获得生成对抗网络模型的准确率的实施方式例如:使用生成对抗网络模型对文本数据进行预测,获得预测标签;根据预测标签和训练标签计算生成对抗网络模型的准确率;具体例如:在将训练数据和训练标签分为多批次进行训练的过程中,每个批次均一共有10个文本数据和10个文本标签,这里的文本标签又被称为训练标签,将这10个文本数据作为训练数据输入生成对抗网络,生成对抗网络输出10个预测标签,若10个预测标签的具体值和10个文本标签的具体值相等,那么上述生成对抗网络模型的准确率为100%;若10个预测标签的具体值只有5个和10个文本标签中的具体值相等,那么上述生成对抗网络模型的准确率为50%。在上述的实现过程中,通过使用生成对抗网络模型对文本数据进行预测,获得预测标签;根据预测标签和训练标签计算生成对抗网络模型的准确率;从而有效地提高了获得生成对抗网络模型的准确率的速度。
步骤S212:根据生成对抗网络模型的准确率计算状态推测矩阵。
上述步骤S212中的根据生成对抗网络模型的准确率计算状态推测矩阵的实施方式例如:在初始状态下,即在使用训练数据和训练标签分为多批次中的第一批次训练对抗网络模型时,可以直接根据生成对抗网络模型的准确率确定状态推测矩阵,具体例如将状态推测矩阵中的每个值均设置为生成对抗网络模型的准确率;在使用训练数据和训练标签分为多批次中的第一批次之后的批次训练对抗网络模型时,即使用多批次中的第二批次、第三批次直到所有批次的训练数据和训练标签用完,需要结合生成对抗网络模型的准确率收敛情况,使用生成对抗网络模型的准确率进行卡尔曼滤波运算,获得状态推测矩阵,这里的卡尔曼滤波运算是一个迭代运算的过程,因此,具体的卡尔曼滤波运算过程将在下面详细的描述。
步骤S220:对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值。
状态观测矩阵,是指表征判别器的重要程度的矩阵,在公式中可以使用R来表示,其中,上述的状态观测矩阵的具体获得方式例如:根据生成对抗网络模型的准确率确定状态观测矩阵,具体例如:在将训练数据和训练标签分为多批次进行训练的过程中,使用生成对抗网络模型中的判别器对每个批次进行预测,获得多个预测标签,再将多个预测标签乘以每个批次对抗网络模型的准确率,获得一维标签向量,再将一维标签向量按照状态观测矩阵的矩阵格式转换为状态观测矩阵。
当然在具体实施过程中,也可以根据生成对抗网络模型的准确率的变化值确定状态观测矩阵,具体例如:根据上述的准确率变化值可以获知准确率的变化情况;若准确率变化值为负值,即说明准确率下降,则将上一次的状态观测矩阵向下调整,即将上一次的状态观测矩阵按照准确率变化值与准确率的比例减小(具体例如:将准确率变化值除以准确率,获得变化比例;然后将上一次的状态观测矩阵乘以变化比例,获得变化矩阵;最后将上一次的状态观测矩阵减去变化矩阵,获得本批次的状态观测矩阵),获得本次的状态观测矩阵;若准确率变化值为正值,即说明准确率上升,则将上一次的状态观测矩阵向上调整,即将上一次的状态观测矩阵按照准确率变化值与准确率的比例增大(具体增大方式与减小方式类似),获得本次的状态观测矩阵。
卡尔曼滤波(Kalman filter)是一种高效率的递归滤波器(自回归滤波器),卡尔曼滤波能够从一系列的不完全及包含噪声的测量中,估计动态系统的状态。应用在本实施例中,这里的测量可以理解为上述计算对抗生成网络模型的正确率的过程,这里的动态系统可以理解为训练对抗生成网络模型的过程,这里的估计动态系统的状态则可以理解为预测对抗生成网络模型的损失值,整个训练对抗生成网络模型的过程就是要使损失值最小,但是在每一个批次的训练过程中,并不知道损失值的变化情况;在训练的过程中存在很多干扰噪声,具体例如:错误的训练标签、获得损失值的方法不合理或者设置的训练超参数不合理等等,这些干扰噪声都能够影响对抗生成网络模型的损失值;卡尔曼滤波会根据各测量在不同时间下的值,考虑各时间下的联合分布,再产生对未知变数的估计,因此会比只以单一测量为基础的估计方式要准确;也就是说,根据每一个批次训练时计算获得的正确率值,考虑各个批次下预测的正确率存在预测误差的分布情况,联合实际计算获得的正确率存在干扰误差的分布,动态地预测对抗生成网络模型的损失值,这种动态地预测损失值的方式会比仅考虑实际计算获得的正确率来预测损失值的方式更为准确。
生成对抗网络模型的损失值(Loss),是指决定在生成对抗网络模型的训练过程如何来“惩罚”网络的预测结果和真实结果之间的差异的值,也可以理解为根据生成对抗网络模型的损失函数计算预测标签和训练标签之间的差异值。
上述的计算预测标签和训练标签之间的差异值有很多表示方式包括:
第一种方式,使用KL散度(Kullback-Leibler divergence,KLD)表征预测标签和训练标签之间的差异值,这里的KL散度在信息系统中称为相对熵(relative entropy),在连续时间序列中称为随机性(randomness),在统计模型推断中称为信息增益(informationgain),也称信息散度(information divergence)。
第二种方式,使用JS散度(Jensen Shannon divergence,JSD)表征预测标签和训练标签之间的差异值,这里的JS散度是指度量两个概率分布的相似度,基于KL散度的变体,解决了KL散度非对称的问题。
第三种方式,使用Wasserstein距离表征预测标签和训练标签之间的差异值,这里的Wasserstein距离是度量两个概率分布之间的距离。
上述步骤S220中的对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算的实施方式可以包括:
根据
Figure BDA0002561129640000131
对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算;
其中,k表征训练数据和训练标签分为多批次中的第k批次,
Figure BDA0002561129640000132
Figure BDA0002561129640000133
分别表示在第k次和第k-1次训练时的先验正确率,
Figure BDA0002561129640000134
表示在第k次训练时的后验正确率,这里的先验和后验的区别在上面的描述过程中已经提到,A表示在没有噪声干扰的情况下第k-1批次的正确率与第k批次的正确率的关联程度,这里的A在每个批次训练的过程中是可以改变的,B表示控制输入参数与对抗生成网络模型的正确率之间的关联程度;
Figure BDA0002561129640000141
表示在第k次训练时的先验状态推测矩阵,Pk和Pk-1分别表示在第k次和第k-1次训练时的后验状态推测矩阵,Q表示干扰噪声协方差矩阵,Kk表示在第k次训练时的卡尔曼滤波的系数,H表示对抗生成网络模型的损失值与对抗生成网络模型的正确率之间的关联程度,R表示状态观测矩阵,zk=Hxk+vk表示对抗生成网络模型的损失值,这里的vk表示在获得对抗生成网络模型的损失值的过程中的干扰噪声,I表示单位矩阵。
可以理解的是,上述的卡尔曼滤波的系数又被称为卡尔曼系数,卡尔曼系数的作用包括:权衡状态推测矩阵P和状态观测矩阵R的大小,来决定是相信生成器多一点还是判别器多一点,具体地,在公式
Figure BDA0002561129640000142
中,若状态观测矩阵R趋近于0,那么卡尔曼系数K获得的残差权重越大,对应地,若在第k次训练时的先验状态推测矩阵
Figure BDA0002561129640000143
趋近于0,那么卡尔曼系数K获得的残差权重越小;这里的残差权重是指卡尔曼系数在公式中的重要程度,即影响GAN模型在第k次训练时的先验正确率是否接近后验正确率的权重,这里的先验正确率与后验正确率之间的差值即可理解为上面的残差;如果相信预测模型多一点,这个生成对抗网络模型的残差权重就会小一点,如果相信观察模型多一点,残差权重就会大一点。
步骤S230:根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值。
上述的步骤S230的实施方式可以包括:
步骤S231:判断生成对抗网络模型的准确率是否逐渐收敛。
上述步骤S231中的判断生成对抗网络模型的准确率是否逐渐收敛的实施方式例如:根据生成对抗网络模型的准确率的历史数据来判断生成对抗网络模型的准确率是否逐渐收敛,具体可以包括如下方式:
第一种方式,可以根据准确率的历史数据的斜率来判断生成对抗网络模型的准确率是否逐渐收敛;具体例如:若准确率的历史数据的斜率小于预设阈值,则确定生成对抗网络模型的准确率逐渐收敛;若准确率的历史数据的斜率大于或等于预设阈值,则确定生成对抗网络模型的准确率没有逐渐收敛,这里的预设阈值可以根据具体情况进行设置。
第二种方式,可以根据准确率的历史数据在预设周期内的变化情况来判断生成对抗网络模型的准确率是否逐渐收敛;具体例如:若准确率的历史数据在预设周期内的变化率大于预设比例,则确定生成对抗网络模型的准确率逐渐收敛;若准确率的历史数据在预设周期内的变化率小于或等于预设比例,则确定生成对抗网络模型的准确率没有逐渐收敛,这里的预设比例可以根据具体情况进行设置。
步骤S232:若生成对抗网络模型的准确率逐渐收敛,则将生成对抗网络模型的损失值向第一方向重置。
上述步骤S232中的实施方式例如:若生成对抗网络模型的准确率逐渐稳定收敛,则将生成对抗网络模型的损失值尽可能地增大,具体例如:若生成对抗网络模型的准确率逐渐稳定收敛,则将生成对抗网络模型的损失值设置乘以1.1或者乘以1.01;当然在具体的实施过程中,还可以将生成对抗网络模型的损失值乘以其它大于1的数字。
步骤S233:若生成对抗网络模型的准确率没有逐渐收敛,则将生成对抗网络模型的损失值向第二方向重置,第一方向与第二方向相反。
上述步骤S233中的实施方式例如:若生成对抗网络模型的准确率没有逐渐收敛,则将生成对抗网络模型的损失值尽可能地减小,具体例如:若生成对抗网络模型的准确率没有逐渐收敛,则将生成对抗网络模型的损失值乘以0.99或者乘以0.999;当然在具体的实施过程中,还可以将生成对抗网络模型的损失值乘以其它小于1的数字。在上述的实现过程中,若生成对抗网络模型的准确率逐渐收敛,则将生成对抗网络模型的损失值向第一方向重置;若生成对抗网络模型的准确率没有逐渐收敛,则将生成对抗网络模型的损失值向与第一方向相反的第二方向重置;也就是说,根据生成对抗网络模型的准确率动态地调整生成对抗网络模型的损失值,从而尽快地找到损失值的阈值范围进行截断,使得生成对抗网络模型更快更稳定地收敛。
在上述的实现过程中,对生成对抗网络模型进行训练时,对状态观测矩阵和获得的状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,并根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值;也就是说,在对生成对抗网络进行训练的过程中动态地根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值,从而尽快地找到损失值的阈值范围进行截断,使得生成对抗网络模型更快更稳定地收敛,有效地改善了对生成对抗网络进行训练时出现梯度爆炸或者梯度消失的问题。
请参见图2,本申请实施例还提供了一种识别方法,在训练好生成对抗网络模型之后,可以应用该生成对抗网络模型对文本内容的类别进行识别,即在步骤S200之后,还可以包括如下步骤:
步骤S300:获得文本内容。
文本内容,是指以文本方式存储的信息内容,这里的文本内容具体例如:网络上的文章、教材书中的课文、专利文献或者专利文献等文本信息,这里的文本信息是一种最为常见的非结构化数据,文本信息蕴含着大量的潜在信息。
上述步骤S300中的获得文本内容的实施方式包括:第一种方式,获取预先存储的文本内容,如从文件系统中获取文本内容,或者从数据库中获取文本内容;第二种方式,从其他终端设备接收获得文本内容;第三种方式,使用浏览器等软件获取互联网上的文本内容,或者使用其它应用程序访问互联网获得文本内容。
步骤S400:使用生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别。
上述步骤S400的实施方式例如:使用生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别;这里的生成对抗网络模型具体可以包括:GAN模型、WGAN(Wasserstein GAN)模型或者WGAN-GP(Wasserstein GAN-gradient penalty)模型。在上述的实现过程中,通过获得文本内容;使用训练后的生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别;从而有效地提高了获得文本内容对应的类别的速度。
请参见图3示出的本申请实施例提供的模型训练装置的结构示意图;本申请实施例提供了一种模型训练装置500,包括:
数据类别获得模块510,用于获得文本数据和文本数据对应的文本类别。
网络模型训练模块520,用于以文本数据为训练数据,以文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,生成对抗网络模型包括:生成器和判别器。
其中,网络模型训练模块520,包括:
推测矩阵获得模块521,用获得状态推测矩阵,状态推测矩阵表征生成器的重要程度。
卡尔曼滤波模块522,用于对状态观测矩阵和状态推测矩阵进行卡尔曼滤波运算,获得生成对抗网络模型的损失值,状态观测矩阵表征判别器的重要程度。
损失值调整模块523,用于根据生成对抗网络模型的准确率调整生成对抗网络模型的损失值。
可选地,在本申请实施例中,损失值调整模块,包括:
逐渐收敛判断模块,用于判断生成对抗网络模型的准确率是否逐渐收敛。
第一方向重置模块,用于若生成对抗网络模型的准确率逐渐收敛,则将生成对抗网络模型的损失值向第一方向重置。
第二方向重置模块,用于若生成对抗网络模型的准确率不逐渐收敛,则将生成对抗网络模型的损失值向第二方向重置,第一方向与第二方向相反。
可选地,在本申请实施例中,推测矩阵获得模块,包括:
准确率获得模块,用于获得生成对抗网络模型的准确率。
推测矩阵计算模块,用于根据生成对抗网络模型的准确率计算状态推测矩阵。
可选地,在本申请实施例中,准确率获得模块,包括:
预测标签获得模块,用于使用生成对抗网络模型对文本数据进行预测,获得预测标签。
准确率计算模块,用于根据预测标签和训练标签计算生成对抗网络模型的准确率。
请参见图4示出的本申请实施例提供的识别装置的结构示意图;本申请实施例提供了一种识别装置600,包括:
文本内容获得模块610,用于获得文本内容。
识别类别获得模块620,用于使用生成对抗网络模型对文本内容的类别进行识别,获得文本内容对应的类别。
应理解的是,该装置与上述的模型训练和识别方法实施例对应,能够执行上述方法实施例涉及的各个步骤,该装置具体的功能可以参见上文中的描述,为避免重复,此处适当省略详细描述。该装置包括至少一个能以软件或固件(firmware)的形式存储于存储器中或固化在装置的操作系统(operating system,OS)中的软件功能模块。
请参见图5示出的本申请实施例提供的电子设备的结构示意图。本申请实施例提供的一种电子设备700,包括:处理器710和存储器720,存储器720存储有处理器710可执行的机器可读指令,机器可读指令被处理器710执行时执行如上的方法。
本申请实施例还提供了一种存储介质730,该存储介质730上存储有计算机程序,该计算机程序被处理器710运行时执行如上的模型训练和识别方法。
其中,存储介质730可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,如静态随机存取存储器(Static Random Access Memory,简称SRAM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-Only Memory,简称EEPROM),可擦除可编程只读存储器(Erasable Programmable Read Only Memory,简称EPROM),可编程只读存储器(Programmable Red-Only Memory,简称PROM),只读存储器(Read-Only Memory,简称ROM),磁存储器,快闪存储器,磁盘或光盘。
本申请实施例所提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其他的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本申请实施例的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现方式中,方框中所标注的功能也可以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
另外,在本申请实施例各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。
在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。
以上的描述,仅为本申请实施例的可选实施方式,但本申请实施例的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请实施例揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请实施例的保护范围之内。

Claims (10)

1.一种模型训练方法,其特征在于,包括:
获得文本数据和文本数据对应的文本类别;
以所述文本数据为训练数据,以所述文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,所述生成对抗网络模型包括:生成器和判别器;
其中,所述对生成对抗网络进行训练,包括:
获得状态推测矩阵,所述状态推测矩阵表征所述判别器的重要程度;
对状态观测矩阵和所述状态推测矩阵进行卡尔曼滤波运算,获得所述生成对抗网络模型的损失值,所述状态观测矩阵表征所述生成器的重要程度;
根据所述生成对抗网络模型的准确率调整所述生成对抗网络模型的损失值。
2.根据权利要求1所述的方法,其特征在于,所述根据所述生成对抗网络模型的准确率调整所述生成对抗网络模型的损失值,包括:
判断所述生成对抗网络模型的准确率是否逐渐收敛;
若是,则将所述生成对抗网络模型的损失值向第一方向重置;
若否,则将所述生成对抗网络模型的损失值向第二方向重置,所述第一方向与所述第二方向相反。
3.根据权利要求1所述的方法,其特征在于,所述获得状态推测矩阵,包括:
获得所述生成对抗网络模型的准确率;
根据所述生成对抗网络模型的准确率计算所述状态推测矩阵。
4.根据权利要求3所述的方法,其特征在于,所述获得所述生成对抗网络模型的准确率,包括:
使用所述生成对抗网络模型对所述文本数据进行预测,获得预测标签;
根据所述预测标签和所述训练标签计算所述生成对抗网络模型的准确率。
5.根据权利要求1-4任一所述的方法,其特征在于,所述生成对抗网络模型为WGAN-GP模型。
6.一种识别方法,其特征在于,包括:
获得文本内容;
使用如权利要求1-5任一所述生成对抗网络模型对所述文本内容的类别进行识别,获得所述文本内容对应的类别。
7.一种模型训练装置,其特征在于,包括:
数据类别获得模块,用于获得文本数据和文本数据对应的文本类别;
网络模型训练模块,用于以所述文本数据为训练数据,以所述文本类别为训练标签,对生成对抗网络进行训练,获得生成对抗网络模型,所述生成对抗网络模型包括:生成器和判别器;
其中,所述网络模型训练模块,包括:
推测矩阵获得模块,用于获得状态推测矩阵,所述状态推测矩阵表征所述判别器的重要程度;
卡尔曼滤波模块,用于对状态观测矩阵和所述状态推测矩阵进行卡尔曼滤波运算,获得所述生成对抗网络模型的损失值,所述状态观测矩阵表征所述生成器的重要程度;
损失值调整模块,用于根据所述生成对抗网络模型的准确率调整所述生成对抗网络模型的损失值。
8.根据权利要求7所述的装置,其特征在于,所述损失值调整模块,包括:
逐渐收敛判断模块,用于判断所述生成对抗网络模型的准确率是否逐渐收敛;
第一方向重置模块,用于若所述生成对抗网络模型的准确率逐渐收敛,则将所述生成对抗网络模型的损失值向第一方向重置;
第二方向重置模块,用于若所述生成对抗网络模型的准确率不逐渐收敛,则将所述生成对抗网络模型的损失值向第二方向重置,所述第一方向与所述第二方向相反。
9.一种电子设备,其特征在于,包括:处理器和存储器,所述存储器存储有所述处理器可执行的机器可读指令,所述机器可读指令被所述处理器执行时执行如权利要求1-6任一所述的方法。
10.一种存储介质,其特征在于,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如权利要求1-6任一所述的方法。
CN202010615855.0A 2020-06-29 一种模型训练和识别方法、装置、电子设备及存储介质 Active CN111753519B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010615855.0A CN111753519B (zh) 2020-06-29 一种模型训练和识别方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010615855.0A CN111753519B (zh) 2020-06-29 一种模型训练和识别方法、装置、电子设备及存储介质

Publications (2)

Publication Number Publication Date
CN111753519A true CN111753519A (zh) 2020-10-09
CN111753519B CN111753519B (zh) 2024-05-28

Family

ID=

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112800828A (zh) * 2020-12-18 2021-05-14 零八一电子集团有限公司 地面栅格占有概率目标轨迹方法

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112800828A (zh) * 2020-12-18 2021-05-14 零八一电子集团有限公司 地面栅格占有概率目标轨迹方法

Similar Documents

Publication Publication Date Title
KR20180125905A (ko) 딥 뉴럴 네트워크(Deep Neural Network)를 이용하여 문장이 속하는 클래스(class)를 분류하는 방법 및 장치
CN111738436B (zh) 一种模型蒸馏方法、装置、电子设备及存储介质
KR20200022739A (ko) 데이터 증강에 기초한 인식 모델 트레이닝 방법 및 장치, 이미지 인식 방법 및 장치
CN111612134B (zh) 神经网络结构搜索方法、装置、电子设备及存储介质
US20200167593A1 (en) Dynamic reconfiguration training computer architecture
CN112131578A (zh) 攻击信息预测模型的训练方法、装置、电子设备及存储介质
EP3769270A1 (en) A method, an apparatus and a computer program product for an interpretable neural network representation
CN115511069A (zh) 神经网络的训练方法、数据处理方法、设备及存储介质
CN112488316A (zh) 事件意图推理方法、装置、设备及存储介质
CN111191722B (zh) 通过计算机训练预测模型的方法及装置
CN115062709A (zh) 模型优化方法、装置、设备、存储介质及程序产品
CN114662601A (zh) 基于正负样本的意图分类模型训练方法及装置
CN111652320B (zh) 一种样本分类方法、装置、电子设备及存储介质
CN111340150B (zh) 用于对第一分类模型进行训练的方法及装置
Catania et al. Deep convolutional neural networks for DGA detection
CN111488950A (zh) 分类模型信息输出方法及装置
CN111753519B (zh) 一种模型训练和识别方法、装置、电子设备及存储介质
CN111753519A (zh) 一种模型训练和识别方法、装置、电子设备及存储介质
CN112861601A (zh) 生成对抗样本的方法及相关设备
CN114091555A (zh) 图像识别模型的训练方法、装置、电子设备及存储介质
De Oliveira et al. Inference from aging information
Dey et al. Analysis of machine learning algorithms by developing a phishing email and website detection model
CN116912921B (zh) 表情识别方法、装置、电子设备及可读存储介质
CN113779396B (zh) 题目推荐方法和装置、电子设备、存储介质
CN117932073B (zh) 一种基于提示工程的弱监督文本分类方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant