CN112508178A - 神经网络结构搜索方法、装置、电子设备及存储介质 - Google Patents

神经网络结构搜索方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN112508178A
CN112508178A CN202011471982.4A CN202011471982A CN112508178A CN 112508178 A CN112508178 A CN 112508178A CN 202011471982 A CN202011471982 A CN 202011471982A CN 112508178 A CN112508178 A CN 112508178A
Authority
CN
China
Prior art keywords
model
searched
loss value
cross entropy
neural network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011471982.4A
Other languages
English (en)
Inventor
李健铨
刘小康
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Dingfu Intelligent Technology Co Ltd
Original Assignee
Dingfu Intelligent Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Dingfu Intelligent Technology Co Ltd filed Critical Dingfu Intelligent Technology Co Ltd
Priority to CN202011471982.4A priority Critical patent/CN112508178A/zh
Publication of CN112508178A publication Critical patent/CN112508178A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/903Querying
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/086Learning methods using evolutionary algorithms, e.g. genetic algorithms or genetic programming

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Databases & Information Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Physiology (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供一种神经网络结构搜索方法、装置、电子设备及存储介质,用于改善搜索到有效神经网络结构模型的速度比较慢的问题。该方法包括:计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型;分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,并根据第一交叉熵和第二交叉熵计算交叉熵损失值;根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,获得神经网络模型。

Description

神经网络结构搜索方法、装置、电子设备及存储介质
技术领域
本申请涉及机器学习和深度学习的技术领域,具体而言,涉及一种神经网络结构搜索方法、装置、电子设备及存储介质。
背景技术
网络结构搜索(Network Architecture Search,NAS),又被称为神经网络搜索或者神经网络结构搜索,是指自动生成神经网络结构的方法过程。
目前的网络结构搜索的方法包括:基于强化学习的结构搜索算法和基于进化算法的结构搜索方法,这两种网络结构搜索方法都是将网络结构搜索的过程看成对黑箱进行优化的过程,使用强化学习或者变异遗传的思路来找到较优的神经网络结构模型,然而在实现过程中发现,使用上述方法搜索到有效神经网络结构模型的速度比较慢。
发明内容
本申请实施例的目的在于提供一种神经网络结构搜索方法、装置、电子设备及存储介质,用于改善搜索到有效神经网络结构模型的速度比较慢的问题。
本申请实施例提供了一种神经网络结构搜索方法,包括:计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型;分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,并根据第一交叉熵和第二交叉熵计算交叉熵损失值;根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,获得搜索后的神经网络模型。在上述的实现过程中,通过在对待搜索结构进行可微分网络结构搜索的过程中,使用推土机距离(EMD)来量化待搜索结构的多个隐含层与指导模型的多个隐含层在结构上的差异,并在对待搜索模型进行训练的过程中,也加入了表征推土机距离的推土机距离损失值,有效地使用推土机距离量化网络结构搜索过程和模型训练过程的进度,减少了使用强化学习或者变异遗传的思路解决黑箱优化问题的不确定性,从而加快了搜索到有效神经网络结构模型的速度。
可选地,在本申请实施例中,计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值,包括:计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的均方误差;计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的转移矩阵;根据均方误差和转移矩阵计算推土机距离损失值。在上述的实现过程中,通过根据待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的均方误差和转移矩阵计算推土机距离损失值;从而有效地提高了计算推土机距离损失值的准确率。
可选地,在本申请实施例中,根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型,包括:若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数,隐含层包括多个节点,每个节点包括多个神经网络基础单元,结构参数表征神经网络基础单元之间的连接权重;从每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得待搜索模型。在上述的实现过程中,若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数;从而使用推土机距离(EMD)来量化待搜索结构的多个隐含层与指导模型的多个隐含层在结构上的差异,减少了使用强化学习或者变异遗传的思路解决黑箱优化问题的不确定性,从而加快了搜索到有效神经网络结构模型的速度。
可选地,在本申请实施例中,分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,包括:获得训练标签和训练标签对应的训练数据;使用待搜索模型对训练数据进行预测,获得训练数据对应的第一结果,并使用指导模型训练数据进行预测,获得训练数据对应的第二结果;分别计算训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵。在上述的实现过程中,通过训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵,来计算获得结合软目标和硬目标的交叉熵损失值;从而避免了直接只使用硬目标来获得交叉熵损失值,提高了计算交叉熵损失值的准确率。
可选地,在本申请实施例中,根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,包括:对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值;根据总损失值对待搜索模型进行训练。在上述的实现过程中,通过对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值;根据总损失值对待搜索模型进行训练;从而避免了只根据交叉熵损失值对待搜索模型进行训练,进一步提高了搜索到神经网络结构模型的速度和准确率。
可选地,在本申请实施例中,根据总损失值对待搜索模型进行训练,包括:根据总损失值更新待搜索模型的网络参数,网络参数表征待搜索模型的输入数据的权重。
本申请实施例还提供了一种神经网络结构搜索装置,包括:第一损失计算模块,用于计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;网络结构搜索模块,用于根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型;第二损失计算模块,用于分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,并根据第一交叉熵和第二交叉熵计算交叉熵损失值;搜索模型获得模块,用于根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,获得搜索后的神经网络模型。
可选地,在本申请实施例中,第一损失计算模块,包括:均方误差计算模块,用于计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的均方误差;转移矩阵计算模块,用于计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的转移矩阵;距离损失计算模块,用于根据均方误差和转移矩阵计算推土机距离损失值。
可选地,在本申请实施例中,网络结构搜索模块,包括:结构参数获取模块,用于若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数,隐含层包括多个节点,每个节点包括多个神经网络基础单元,结构参数表征神经网络基础单元之间的连接权重;结构参数筛选模块,用于从每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得待搜索模型。
可选地,在本申请实施例中,第二损失计算模块,包括:标签数据获得模块,用于获得训练标签和训练标签对应的训练数据;数据标签预测模块,用于使用待搜索模型对训练数据进行预测,获得训练数据对应的第一结果,并使用指导模型对训练数据进行预测,获得训练数据对应的第二结果;交叉熵计算子模块,用于分别计算训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵。
可选地,在本申请实施例中,搜索模型获得模块,包括:总损失值获得模块,用于对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值;搜索模型训练模块,用于根据总损失值对待搜索模型进行训练。
可选地,在本申请实施例中,搜索模型训练模块,包括:网络参数更新模块,用于根据总损失值更新待搜索模型的网络参数,网络参数表征待搜索模型的输入数据的权重。
本申请实施例还提供了一种电子设备,包括:处理器和存储器,存储器存储有处理器可执行的机器可读指令,机器可读指令被处理器执行时执行如上面描述的方法。
本申请实施例还提供了一种存储介质,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如上面描述的方法。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出的本申请实施例提供的神经网络结构搜索方法的流程示意图;
图2示出的本申请实施例提供的待搜索结构与指导模型进行网络结构搜索的过程示意图;
图3示出的本申请实施例提供的根据软目标和硬目标获得交叉熵损失值的过程示意图;
图4示出的本申请实施例提供的神经网络结构搜索装置的结构示意图;
图5示出的本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整的描述。
在介绍本申请实施例提供的神经网络结构搜索方法之前,先介绍本申请实施例中所涉及的一些概念:
可微分结构搜索(Differentiable Architecture Search,DARTS),是指将网络空间表示为一个有向无环图,其关键是将节点连接和激活函数通过一种巧妙的表示组合成了一个矩阵,其中每个元素代表了连接和激活函数的权重,在搜索时使用了归一化指数函数,这样就将搜索空间变成了连续空间,目标函数成为了可微函数。在搜索时,DARTS会遍历全部节点,使用节点上全部连接的加权进行计算,同时优化结构权重和网络权重。搜索结束后,选择权重最大的连接和激活函数,形成最终的网络。
推土机距离(Earth Mover Distance,EMD),又被称为EMD距离或者Wasserstein距离,是指度量两个概率分布之间的距离,可以用于描述两个多维分布之间相似度的度量,Π(P1,P2)是P1和P2分布组合起来的所有可能的联合分布的集合;对于每一个可能的联合分布γ,可以从中采样(x,y)~γ得到一个样本x和y,并计算出这对样本的距离||x-y||,所以可以计算该联合分布γ下,样本对距离的期望值E(x,y)~γ[||x-y||]。
知识蒸馏(Knowledge Distillation),又被称为模型蒸馏、暗知识提取、蒸馏训练或蒸馏学习,是指将知识从一个复杂的机器学习模型迁移到另一个简化的机器学习模型,模型蒸馏采用的是迁移学习,通过采用预先训练好且复杂的老师模型(Teacher model)的输出作为监督信号去训练另外一个简单的学生模型(Student model)。
自然语言处理(Natural Language Processing,NLP),是指由于理解(understanding)自然语言,需要关于外在世界的广泛知识以及运用操作这些知识的能力,而研究自然语言认知的相关问题,即自然语言认知同时也被视为一个人工智能完备(AI-complete)的相关问题,这里的自然语言处理也是机器学习中的一个重要组成部分。
需要说明的是,本申请实施例提供的神经网络结构搜索方法可以被电子设备执行,这里的电子设备是指具有执行计算机程序功能的设备终端或者服务器。
在介绍本申请实施例提供的神经网络结构搜索方法之前,先介绍该神经网络结构搜索方法适用的应用场景,这里的应用场景包括但不限于:基于深度学习的图像识别、自然语言处理和声音识别等等场景,例如:针对具体的任务使用该神经网络结构搜索方法获得搜索后的神经网络模型,并使用神经网络模型完成该任务等。
请参见图1示出的本申请实施例提供的神经网络结构搜索方法的流程示意图;该神经网络结构搜索方法的主要思路是,通过在对待搜索结构进行可微分网络结构搜索的过程中,使用推土机距离(EMD)来量化待搜索结构的多个隐含层与指导模型的多个隐含层在结构上的差异,并在对待搜索模型进行训练的过程中,也加入了表征推土机距离的推土机距离损失值,有效地使用推土机距离量化网络结构搜索过程和模型训练过程的进度,减少了使用强化学习或者变异遗传的思路解决黑箱优化问题的不确定性,从而加快了搜索到有效神经网络结构模型的速度;上述的神经网络结构搜索方法可以包括:
步骤S110:计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值。
请参见图2示出的本申请实施例提供的待搜索结构与指导模型进行网络结构搜索的过程示意图;待搜索结构是指基于梯度的神经网络结构搜索方法获得的神经网络结构,即需要被网络结构搜索(NAS)出来的神经网络结构,在知识蒸馏的过程中又可以被称为学生模型(Student model),此处的上述的基于梯度的神经网络结构搜索方法可以是DARTS方法,是指将网络结构搜索转化为连续空间的优化问题,该优化问题是采用梯度下降法求解的方法来解决的。上述待搜索结构通常包括:一个输入层(Input Layer)、多个隐含层(Hidden Layer)和一个输出层(output Layer);输入层与多个隐含层连接,多个隐含层内部相互连接,且多个隐含层中远离输入层的隐含层与输出层连接;其中,隐含层可以包括多个节点,每个节点包括多个神经网络基础单元,一个神经网络基础单元即一个神经网络的相关基本操作,神经网络基础单元具体例如:卷积神经网络(Convolutional NeuralNetworks,CNN)、循环神经网络(Recurrent Neural Network,RNN)和注意力机制(Attention)等操作。
上述的多个隐含层可以通过节点连接,节点之间也可以通过神经网络基础单元连接,具体是否连接需要由每个神经网络基础单元、节点或者隐含层所对应的结构参数来确定。上述的结构参数表征神经网络基础单元、节点或者隐含层之间的连接概率,也可以理解为结构参数是表征神经网络基础单元、节点或者隐含层之间的连接权重;此处的连接概率具体计算方式例如:根据
Figure BDA0002831402360000081
计算神经网络基础单元、节点或者隐含层之间的连接概率;其中,i表示第i个神经网络基础单元、节点或者隐含层,j表示第j个神经网络基础单元、节点或者隐含层,O表示每个节点中的操作集合,即每个节点中的神经网络基础单元集合,o表示每个节点的操作集合中的具体操作,α表示对应的结构参数。
可以理解的是,上述的待搜索结构中的每个神经网络基础单元的输出是对所有中间节点中的神经网络基础单元做加和获得的,即每个节点的输入都来自上个节点的输出,那么每个中间节点使用公式可以表示为:
Figure BDA0002831402360000082
其中,x表示输入的数据,i表示第i个神经网络基础单元,j表示第j个节点,o表示每个节点的操作集合中的操作,即神经网络基础单元,oi,j表示第i个节点中的第j个神经网络基础单元。
上述网络结构搜索获得的具体效果大致可以从搜索空间、搜索策略以及结构性能的质量三个方面来评价:搜索空间指的是网络结构搜索中结构的候选集合;通常来说,搜索空间越大,搜索过程越慢;可以通过外部先验知识指导搜索空间的设计,从而减小搜索空间,简化搜索过程,但是由于人类认知的限制,这样做也可能对发现新的网络结构造成限制。搜索策略是指如何在庞大的搜索空间中进行有效快速的搜索;在搜索的过程中,需要考虑如何快速的搜索到最优的网络结构,同时需要避免在搜索过程中搜到局部最优的网络结构。
指导模型,在知识蒸馏的过程中又可以被称为老师模型(Teacher model),是指网络结构比待搜索模型更为复杂的神经网络模型,可以用于根据指导模型对待搜索模型进行蒸馏学习;此处的指导模型可以是预训练语言模型,预训练语言模型包括:自回归(AutoRegressive)语言模型或者自编码(Auto Encoding)语言模型;其中,此处的预训练语言模型又被简称为预训练模型,是指将大量的文本语料作为训练数据,使用训练数据对神经网络进行半监督机器学习,获得的神经网络模型,这里的预训练模型蕴含着语言模型中的文本结构关系,可以使用的预训练语义模型例如:GloVe、word2vec和fastText等模型。
上述步骤S110中的计算推土机距离损失值的实施方式可以包括:
步骤S111:计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的均方误差。
上述步骤S111的实施方式例如:根据
Figure BDA0002831402360000091
计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的均方误差;其中,
Figure BDA0002831402360000092
表示指导模型中的第j个隐含层的输出,
Figure BDA0002831402360000093
表示待搜索模型中的第i个隐含层的输出,MSE表示计算均方误差,
Figure BDA0002831402360000094
表示指导模型的第j个隐含层的输出和待搜索模型的第i个隐含层的输出之间的均方误差。
步骤S112:计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的转移矩阵。
上述步骤S112的实施方式例如:使用动态规划算法计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的转移矩阵;其中,此处的动态规划(DynamicProgramming,DP)是运筹学的一个分支,是求解决策过程最优化的过程。
步骤S113:根据均方误差和转移矩阵计算推土机距离损失值。
上述步骤S113的实施方式例如:根据
Figure BDA0002831402360000101
计算推土机距离损失值;其中,HT是指导模型,HS是待搜索模型,EMD(HT,HS)是推土机距离损失值,fij是计算出来的转移矩阵,dij是指导模型的第j隐含层输出和待搜索模型的第i隐含层输出之间的均方误差。
在步骤S110之后,执行步骤S120:根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型。
上述步骤S120的实施方式可以包括:
步骤S121:若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数。
上述步骤S121的实施方式例如:对待搜索结构进行可微分网络结构搜索(DARTS)处理,并实时获取待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数;其中,此处的预设阈值可以根据具体情况进行设置,具体例如:将预设阈值设置为2或者30等等。
步骤S122:从每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得待搜索模型。
上述步骤S122的实施方式根据具体情况选择目标对象的个数可以有很多种,该实施包括但不限于:第一种实施方式,只选择一个结构参数最大的目标对象,具体例如:从每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得待搜索模型。第二种实施方式,根据结构参数从大到小排列只选择前两个目标对象,具体例如:从每个隐含层对应的多个节点中筛选出结构参数最大的两个节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的两个神经网络基础单元,获得待搜索模型。以此类推,还可以根据结构参数从大到小排列只选择前三个、四个或者四个以上目标对象,从而推理出有更多的实施方式。
在上述的实现过程中,若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数;从而使用推土机距离(EMD)来量化待搜索结构的多个隐含层与指导模型的多个隐含层在结构上的差异,减少了使用强化学习或者变异遗传的思路解决黑箱优化问题的不确定性,从而加快了搜索到有效神经网络结构模型的速度。
在步骤S120之后,执行步骤S130:分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,并根据第一交叉熵和第二交叉熵计算交叉熵损失值。
训练标签,是指对待搜索模型进行训练时所使用的训练数据集中的训练标签,该训练数据集中还包括训练数据,其中,训练数据和训练标签是对应的,具体例如:假设待搜索模型是用于对文本内容进行情感分类的神经网络模型,那么可以使用文本内容和该文本内容对应的分类标签(例如:积极的文章或消极的文章)训练该神经网络模型,此处的文本内容就是训练数据,此处的分类标签就是训练标签,为了方便存储和压缩传输,可以将很多文本内容和该文本内容对应的分类标签作为训练数据集,将训练数据集作为整体一起存储和压缩传输。
上述步骤S130的实施方式可以包括:
步骤S131:获得训练标签和训练标签对应的训练数据。
上述步骤S131的实施方式例如:上述的训练数据和训练标签可以分开获取,具体例如:人工搜集训练数据,并人工地识别训练数据对应的训练标签;当然,也可以将多个训练数据和训练数据对应的训练标签打包为数据压缩包一起获取,即训练数据和训练标签是在训练数据集中相互对应的,一个训练数据对应一个训练标签,这里以数据压缩包一起获取为例进行说明;数据压缩包的获得方式包括:第一种获得方式,接收其它终端设备发送的数据压缩包,将数据压缩包存储至文件系统、数据库或移动存储设备中;第二种获得方式,获取预先存储的数据压缩包,具体例如:从文件系统中获取数据压缩包,或者从数据库中获取数据压缩包,或者从移动存储设备中获取数据压缩包;第三种获得方式,使用浏览器等软件获取互联网上的数据压缩包,或者使用其它应用程序访问互联网获得数据压缩包。
步骤S132:使用待搜索模型对训练数据进行预测,获得训练数据对应的第一结果,并使用指导模型对训练数据进行预测,获得训练数据对应的第二结果。
上述步骤S132的实施方式例如:假设指导模型是自回归语言模型或者自编码语言模型,那么可以使用待搜索模型预测训练数据对应的第一结果,并使用自回归语言模型或者自编码语言模型预测训练数据对应的第二结果;其中,自回归语言模型具体可以是ELMo模型、GPT模型或者GPT-2模型,自编码语言模型具体可以是双向编码表示编码器(Bidirectional Encoder Representations from Transformers,BERT)等等。
步骤S133:分别计算训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵。
上述步骤S133的实施方式例如:分别计算训练标签与第一结果之间的第一交叉熵(cross entropy,CE),以及第一结果与第二结果之间的第二交叉熵;其中,交叉熵是指描述两个近似概率分布的差异程度;在自然语言处理研究中,交叉熵常被用来评价和对比统计语言模型,用来衡量统计语言模型是否反映了语言数据的真实分布。
步骤S134:对第一交叉熵和第二交叉熵进行加权融合,获得交叉熵损失值。
请参见图3示出的本申请实施例提供的根据软目标和硬目标获得交叉熵损失值的过程示意图;为了提高模型搜索和训练准确性,上述步骤S130还可以结合软目标和硬目标来获得待搜索模型的交叉熵损失值。当然在具体的实施过程中,可以使用可调整参数来调整软目标和硬目标的比重,从而避免了直接只使用硬目标来获得交叉熵损失值,使得获得的交叉熵损失值更加准确。
上述步骤S134的实施方式例如:根据LKD=(1-α)CE(p,y)+aCE(p,q)对第一交叉熵和第二交叉熵进行加权融合,获得交叉熵损失值;其中,LKD是交叉熵损失值,α是第一可调整参数,p是第一结果,q是第二结果,y是训练标签,CE(p,y)是训练标签与第一结果之间的第一交叉熵,由于训练标签通常是人工识别并设置的整数标签,此处的整数标签具体例如:1代表是某种动物类别,0代表不是某种动物类别;由训练标签确定的第一交叉熵也可以理解为硬目标(hard target);CE(p,q)是第一结果与第二结果之间的第二交叉熵,而第一结果和第二结果均是模型输出的小数标签,此处的小数标签具体例如:0.1、0.5和0.1等等表示是某种动物类别的概率,由第一结果和第二结果确定的第二交叉熵也可以理解为软目标(soft target)。
在上述的实现过程中,通过对训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵进行加权融合,从而获得结合软目标和硬目标的交叉熵损失值;从而避免了直接只使用硬目标来获得交叉熵损失值,提高了计算交叉熵损失值的准确率。
在步骤S130之后,执行步骤S140:根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,获得搜索后的神经网络模型。
上述步骤S140中的实施方式例如:对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值,根据总损失值对待搜索模型进行训练;也就是说,先根据Loss=LKD+β·EMD(HT,HS)计算出总损失值,然后再根据该总损失值对待搜索模型进行训练;其中,EMD(HT,HS)是推土机距离损失值,LKD是交叉熵损失值,β为第二可调整参数,Loss为总损失值。在上述的实现过程中,通过对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值;根据总损失值对待搜索模型进行训练;从而避免了只根据交叉熵损失值对待搜索模型进行训练,进一步提高了搜索到有效神经网络结构模型的速度和准确率。
上述根据该总损失值对待搜索模型进行训练的具体实施方式例如:根据总损失值进行反向传播运算,获得待搜索模型的每个网络参数的梯度(gradient),然后根据梯度和学习率来更新待搜索模型的网络参数;其中,网络参数表征待搜索模型的输入数据的权重。
在上述的实现过程中,通过在对待搜索结构进行可微分网络结构搜索的过程中,使用推土机距离(EMD)来量化待搜索结构的多个隐含层与指导模型的多个隐含层在结构上的差异,并在对待搜索模型进行训练的过程中,也加入了表征推土机距离的推土机距离损失值,有效地使用推土机距离量化网络结构搜索过程和模型训练过程的进度,减少了使用强化学习或者变异遗传的思路解决黑箱优化问题的不确定性,从而加快了搜索到有效神经网络结构模型的速度。
可选地,在获得搜索后的神经网络模型之后,还可以根据具体的任务类型对搜索后的神经网络模型进行微调,对搜索后的神经网络模型进行微调的具体过程可以包括:
步骤S143:获得预设任务对应的训练数据。
预设任务,是指根据具体情况设置的任务,具体可以是自然语言处理(NLP)相关的任务,常见的自然语言处理任务例如:依存句法分析、指代消解、命名实体识别和词性标注等等。
上述步骤S143的实施方式包括:第一种方式,收集训练数据,并对训练数据进行人工识别获得训练标签;第二种方式,使用浏览器等软件获取互联网上的训练数据,或者使用其它应用程序访问互联网获得训练数据;第三种方式,获取预先存储的训练数据,从文件系统中获取训练数据,或者从数据库中获取训练数据。
步骤S144:使用预设任务对应的训练数据对搜索后的神经网络模型进行微调,获得微调后的神经网络模型。
上述步骤S144的实施方式包括:使用预设任务对应的训练数据对搜索后的神经网络模型进行微调(fine-tuning),获得微调后的神经网络模型;这里的微调是指针对具体的深度学习任务,在对搜索后的神经网络模型进行微调时,可以保留之前训练的大多数参数,从而达到快速训练收敛的效果;具体例如:保留BERT模型中特征提取部分的各个卷积层,只重构卷积层后的全连接层和/或softmax网络层,具体可以将原来输出二维的全连接层替换为输出一维的全连接层,或者,将原来输出10个分类的softmax网络层替换为输出3个分类的softmax网络层。
在上述的实现过程中,通过获得预设任务对应的训练数据;使用预设任务对应的训练数据对搜索后的神经网络模型进行微调,获得微调后的神经网络模型;从而有效地提高了针对预设任务搜索到神经网络结构模型的准确率。
请参见图4示出的本申请实施例提供的神经网络结构搜索装置的结构示意图;本申请实施例提供了一种神经网络结构搜索装置200,包括:
第一损失计算模块210,用于计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值。
网络结构搜索模块220,用于根据推土机距离损失值对待搜索结构进行可微分网络结构搜索,获得待搜索模型。
第二损失计算模块230,用于分别计算训练标签和待搜索模型输出的第一结果之间的第一交叉熵,以及计算待搜索模型输出的第一结果和指导模型输出的第二结果之间的第二交叉熵,并根据第一交叉熵和第二交叉熵计算交叉熵损失值。
搜索模型获得模块240,用于根据推土机距离损失值和交叉熵损失值对待搜索模型进行训练,获得搜索后的神经网络模型。
可选地,在本申请实施例中,第一损失计算模块,包括:
均方误差计算模块,用于计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的均方误差。
转移矩阵计算模块,用于计算指导模型的每个隐含层输出和待搜索结构的每个隐含层输出之间的转移矩阵。
距离损失计算模块,用于根据均方误差和转移矩阵计算推土机距离损失值。
可选地,在本申请实施例中,网络结构搜索模块,包括:
结构参数获取模块,用于若推土机距离损失值小于预设阈值,则获取待搜索结构的多个隐含层中的每个隐含层对应的结构参数,隐含层包括多个节点,每个节点包括多个神经网络基础单元,结构参数表征神经网络基础单元之间的连接权重。
结构参数筛选模块,用于从每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得待搜索模型。
可选地,在本申请实施例中,第二损失计算模块,包括:
标签数据获得模块,用于获得训练标签和训练标签对应的训练数据。
数据标签预测模块,用于使用待搜索模型对训练数据进行预测,获得训练数据对应的第一结果,并使用指导模型对训练数据进行预测,获得训练数据对应的第二结果。
交叉熵计算子模块,用于分别计算训练标签与第一结果之间的第一交叉熵,以及第一结果与第二结果之间的第二交叉熵。
第一加权融合模块,用于对第一交叉熵和第二交叉熵进行加权融合,获得交叉熵损失值。
可选地,在本申请实施例中,搜索模型获得模块,包括:
总损失值获得模块,用于对推土机距离损失值和交叉熵损失值进行加权融合,获得总损失值。
搜索模型训练模块,用于根据总损失值对待搜索模型进行训练。
可选地,在本申请实施例中,搜索模型训练模块,包括:
网络参数更新模块,用于根据总损失值更新待搜索模型的网络参数,网络参数表征待搜索模型的输入数据的权重。
可选地,在本申请实施例中,神经网络结构搜索装置,还包括:
训练数据获得模块,用于获得预设任务对应的训练数据。
网络模型微调模块,用于使用预设任务对应的训练数据对搜索后的神经网络模型进行微调,获得微调后的神经网络模型。
应理解的是,该装置与上述的神经网络结构搜索方法实施例对应,能够执行上述方法实施例涉及的各个步骤,该装置具体的功能可以参见上文中的描述,为避免重复,此处适当省略详细描述。该装置包括至少一个能以软件或固件(firmware)的形式存储于存储器中或固化在装置的操作系统(operating system,OS)中的软件功能模块。
请参见图5示出的本申请实施例提供的电子设备的结构示意图。本申请实施例提供的一种电子设备300,包括:处理器310和存储器320,存储器320存储有处理器310可执行的机器可读指令,机器可读指令被处理器310执行时执行如上的方法。
本申请实施例还提供了一种存储介质330,该存储介质330上存储有计算机程序,该计算机程序被处理器310运行时执行如上的方法。
其中,存储介质330可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,如静态随机存取存储器(Static Random Access Memory,简称SRAM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-Only Memory,简称EEPROM),可擦除可编程只读存储器(Erasable Programmable Read Only Memory,简称EPROM),可编程只读存储器(Programmable Red-Only Memory,简称PROM),只读存储器(Read-Only Memory,简称ROM),磁存储器,快闪存储器,磁盘或光盘。
本申请实施例提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其他的方式实现。以上所描述的装置实施例仅是示意性的,例如,附图中的流程图和框图显示了根据本申请实施例的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。
在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。
以上的描述,仅为本申请实施例的可选实施方式,但本申请实施例的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请实施例揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请实施例的保护范围之内。

Claims (10)

1.一种神经网络结构搜索方法,其特征在于,包括:
计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;
根据所述推土机距离损失值对所述待搜索结构进行可微分网络结构搜索,获得待搜索模型;
分别计算训练标签和所述待搜索模型输出的第一结果之间的第一交叉熵,以及计算所述待搜索模型输出的第一结果和所述指导模型输出的第二结果之间的第二交叉熵,并根据所述第一交叉熵和所述第二交叉熵计算交叉熵损失值;
根据所述推土机距离损失值和所述交叉熵损失值对所述待搜索模型进行训练,获得搜索后的神经网络模型。
2.根据权利要求1所述的方法,其特征在于,所述计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值,包括:
计算所述指导模型的每个隐含层输出和所述待搜索结构每个隐含层输出之间的均方误差;
计算所述指导模型的每个隐含层输出和所述待搜索结构的每个隐含层输出之间的转移矩阵;
根据所述均方误差和所述转移矩阵计算所述推土机距离损失值。
3.根据权利要求1所述的方法,其特征在于,所述根据所述推土机距离损失值对所述待搜索结构进行可微分网络结构搜索,获得待搜索模型,包括:
若所述推土机距离损失值小于预设阈值,则获取所述待搜索结构的多个隐含层中的每个隐含层对应的结构参数,所述隐含层包括多个节点,每个所述节点包括多个神经网络基础单元,所述结构参数表征神经网络基础单元之间的连接权重;
从所述每个隐含层对应的多个节点中筛选出结构参数最大的节点,并从所述结构参数最大的节点对应的多个神经网络基础单元中筛选出结构参数最大的神经网络基础单元,获得所述待搜索模型。
4.根据权利要求1所述的方法,其特征在于,所述分别计算训练标签和所述待搜索模型输出的第一结果之间的第一交叉熵,以及计算所述待搜索模型输出的第一结果和所述指导模型输出的第二结果之间的第二交叉熵,包括:
获得所述训练标签和所述训练标签对应的训练数据;
使用所述待搜索模型对所述训练数据进行预测,获得所述训练数据对应的所述第一结果,并使用所述指导模型对所述训练数据进行预测,获得所述训练数据对应的所述第二结果;
分别计算所述训练标签与所述第一结果之间的第一交叉熵,以及所述第一结果与所述第二结果之间的第二交叉熵。
5.根据权利要求1所述的方法,其特征在于,所述根据所述推土机距离损失值和所述交叉熵损失值对所述待搜索模型进行训练,包括:
对所述推土机距离损失值和所述交叉熵损失值进行加权融合,获得总损失值;
根据所述总损失值对所述待搜索模型进行训练。
6.根据权利要求5所述的方法,其特征在于,所述根据所述总损失值对所述待搜索模型进行训练,包括:
根据所述总损失值更新所述待搜索模型的网络参数,所述网络参数表征所述待搜索模型的输入数据的权重。
7.根据权利要求1-6任一所述的方法,其特征在于,在所述获得搜索后的神经网络模型之后,还包括:
获得预设任务对应的训练数据;
使用所述预设任务对应的训练数据对所述搜索后的神经网络模型进行微调,获得微调后的神经网络模型。
8.一种神经网络结构搜索装置,其特征在于,包括:
第一损失计算模块,用于计算待搜索结构的多个隐含层输出与指导模型的多个隐含层输出之间的推土机距离损失值;
网络结构搜索模块,用于根据所述推土机距离损失值对所述待搜索结构进行可微分网络结构搜索,获得待搜索模型;
第二损失计算模块,用于分别计算训练标签和所述待搜索模型输出的第一结果之间的第一交叉熵,以及计算所述待搜索模型输出的第一结果和所述指导模型输出的第二结果之间的第二交叉熵,并根据所述第一交叉熵和所述第二交叉熵计算交叉熵损失值;
搜索模型获得模块,用于根据所述推土机距离损失值和所述交叉熵损失值对所述待搜索模型进行训练,获得搜索后的神经网络模型。
9.一种电子设备,其特征在于,包括:处理器和存储器,所述存储器存储有所述处理器可执行的机器可读指令,所述机器可读指令被所述处理器执行时执行如权利要求1至7任一所述的方法。
10.一种存储介质,其特征在于,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如权利要求1至7任一所述的方法。
CN202011471982.4A 2020-12-11 2020-12-11 神经网络结构搜索方法、装置、电子设备及存储介质 Pending CN112508178A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011471982.4A CN112508178A (zh) 2020-12-11 2020-12-11 神经网络结构搜索方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011471982.4A CN112508178A (zh) 2020-12-11 2020-12-11 神经网络结构搜索方法、装置、电子设备及存储介质

Publications (1)

Publication Number Publication Date
CN112508178A true CN112508178A (zh) 2021-03-16

Family

ID=74973242

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011471982.4A Pending CN112508178A (zh) 2020-12-11 2020-12-11 神经网络结构搜索方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN112508178A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949832A (zh) * 2021-03-25 2021-06-11 鼎富智能科技有限公司 一种网络结构搜索方法、装置、电子设备及存储介质
CN115795125A (zh) * 2023-01-18 2023-03-14 北京东方瑞丰航空技术有限公司 应用于项目管理软件的搜索方法、装置、设备及介质

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949832A (zh) * 2021-03-25 2021-06-11 鼎富智能科技有限公司 一种网络结构搜索方法、装置、电子设备及存储介质
CN112949832B (zh) * 2021-03-25 2024-04-16 鼎富智能科技有限公司 一种网络结构搜索方法、装置、电子设备及存储介质
CN115795125A (zh) * 2023-01-18 2023-03-14 北京东方瑞丰航空技术有限公司 应用于项目管理软件的搜索方法、装置、设备及介质

Similar Documents

Publication Publication Date Title
CN109408731B (zh) 一种多目标推荐方法、多目标推荐模型生成方法以及装置
CN112270379B (zh) 分类模型的训练方法、样本分类方法、装置和设备
CN111741330B (zh) 一种视频内容评估方法、装置、存储介质及计算机设备
CN109271539B (zh) 一种基于深度学习的图像自动标注方法及装置
US11922281B2 (en) Training machine learning models using teacher annealing
CN111369299B (zh) 识别的方法、装置、设备及计算机可读存储介质
CN111612134A (zh) 神经网络结构搜索方法、装置、电子设备及存储介质
US10621137B2 (en) Architecture for predicting network access probability of data files accessible over a computer network
CN114780831A (zh) 基于Transformer的序列推荐方法及系统
CN112508177A (zh) 一种网络结构搜索方法、装置、电子设备及存储介质
US20210248425A1 (en) Reinforced text representation learning
CN112508178A (zh) 神经网络结构搜索方法、装置、电子设备及存储介质
KR20220047228A (ko) 이미지 분류 모델 생성 방법 및 장치, 전자 기기, 저장 매체, 컴퓨터 프로그램, 노변 장치 및 클라우드 제어 플랫폼
CN113343092A (zh) 基于大数据挖掘的内容源推荐更新方法及云计算服务系统
CN113515589A (zh) 数据推荐方法、装置、设备以及介质
CN117851909B (zh) 一种基于跳跃连接的多循环决策意图识别系统及方法
CN111161238A (zh) 图像质量评价方法及装置、电子设备、存储介质
CN115168720A (zh) 内容交互预测方法以及相关设备
CN117036834B (zh) 基于人工智能的数据分类方法、装置及电子设备
CN113836934A (zh) 基于标签信息增强的文本分类方法和系统
KR102474436B1 (ko) 캡션 데이터를 기반으로 자연어 검색을 수행하는 영상 이미지 검색 장치 및 그 동작 방법
CN116127376A (zh) 模型训练方法、数据分类分级方法、装置、设备及介质
CN112949832B (zh) 一种网络结构搜索方法、装置、电子设备及存储介质
CN114580533A (zh) 特征提取模型的训练方法、装置、设备、介质及程序产品
CN114898184A (zh) 模型训练方法、数据处理方法、装置及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20210316