CN111310823A - 目标分类方法、装置和电子系统 - Google Patents

目标分类方法、装置和电子系统 Download PDF

Info

Publication number
CN111310823A
CN111310823A CN202010089737.0A CN202010089737A CN111310823A CN 111310823 A CN111310823 A CN 111310823A CN 202010089737 A CN202010089737 A CN 202010089737A CN 111310823 A CN111310823 A CN 111310823A
Authority
CN
China
Prior art keywords
preset
network model
auxiliary
loss value
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010089737.0A
Other languages
English (en)
Other versions
CN111310823B (zh
Inventor
黄昊明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Megvii Technology Co Ltd
Original Assignee
Beijing Megvii Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Megvii Technology Co Ltd filed Critical Beijing Megvii Technology Co Ltd
Priority to CN202010089737.0A priority Critical patent/CN111310823B/zh
Publication of CN111310823A publication Critical patent/CN111310823A/zh
Application granted granted Critical
Publication of CN111310823B publication Critical patent/CN111310823B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提供了一种目标分类方法、装置和电子系统,首先获取待处理数据;将该待处理数据输入至预先训练完成的网络模型,得到该待处理数据中待分类目标的分类结果;该网络模型通过预设的辅助模型和损失函数训练得到,在训练网络模型和辅助模型的过程中,损失函数可根据辅助模型和网络模型分别基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。由于该方式在训练网络模型的过程中同时训练辅助模型,因而辅助模型在训练过程中也在不断提升性能,且在辅助模型性能不断提升的同时,通过损失函数可以将辅助模型的知识传递至网络模型,使得网络模型的性能也不断提升,从而网络模型的性能不再受到辅助模型在初始状态下的性能的限制。

Description

目标分类方法、装置和电子系统
技术领域
本发明涉及神经网络技术领域,尤其是涉及一种目标分类方法、装置和电子系统。
背景技术
神经网络的应用越来越广泛,为了使神经网络可以完成复杂的信息处理任务,神经网络的深度或宽度不断增加,使得神经网络的参数量越来越庞大,尽管更深或者更宽的神经网络的性能更好,但是由于参数量庞大,导致其计算量较大,难以部署在资源有限的设备(例如,手机、平板、车载设备等)上。
相关技术中,通常采用模型蒸馏算法将训练好的参数量庞大的大网络的知识传递给小网络,以使小网络在结构简单参数量小的同时,具有大网络的性能,但是,该方式采用的是将大网络的知识单向传递至小网络的方式,小网络的性能提升空间受到大网络性能的限制,进而影响了小网络的总体性能。
发明内容
有鉴于此,本发明的目的在于提供一种目标分类方法、装置和电子系统,以提高网络性能提升的灵活性。
第一方面,本发明实施例提供了一种目标分类方法,该方法包括:获取待处理数据;该待处理数据中包括待分类目标;将该待处理数据输入至预先训练完成的网络模型中,得到待分类目标的分类结果;其中,网络模型通过预设的辅助模型和预设的损失函数训练得到;该损失函数用于:在训练网络模型和辅助模型的过程中,根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。
在本发明较佳的实施例中,上述损失函数具体通过下述方式确定第一损失值:计算网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,根据该相对熵确定第一损失值。
在本发明较佳的实施例中,上述计算网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,根据相对熵确定第一损失值的步骤,包括:计算网络模型基于预设样本输出的预设类别的分类结果对应的第一概率分布;计算辅助模型基于预设样本输出的预设类别的分类结果对应的第二概率分布;根据第一概率分布和第二概率分布,计算第二概率分布相对第一概率分布的第一相对熵;将该第一相对熵确定为第一损失值。
在本发明较佳的实施例中,上述计算网络模型基于预设样本输出的预设类别的分类结果对应的第一概率分布的步骤,包括:计算网络模型基于第i个预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000021
组合每个概率
Figure BDA0002383225390000022
得到第一概率分布p1;上述计算辅助模型基于预设样本输出的预设类别的分类结果对应的第二概率分布的步骤,包括:计算辅助模型基于第i个预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000023
组合每个概率
Figure BDA0002383225390000024
得到第二概率分布p2
在本发明较佳的实施例中,上述辅助模型包括多个;上述计算网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,根据相对熵确定第一损失值的步骤,包括:计算网络模型基于预设样本输出的预设类别的分类结果对应的第三概率分布;针对每个辅助模型,执行下述操作:计算当前辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布;计算多个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵,根据第二相对熵确定第一损失值。
在本发明较佳的实施例中,上述计算多个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵的步骤,包括:计算每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵;上述根据第二相对熵确定第一损失值的步骤,包括:计算每个辅助模型对应的第二相对熵的平均值;将该平均值确定为第一损失值。
在本发明较佳的实施例中,上述计算多个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵,根据第二相对熵确定第一损失值的步骤,包括:计算每个辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布的均值概率分布;计算均值概率分布相对第三概率分布的第二相对熵,将第二相对熵确定为第一损失值。
在本发明较佳的实施例中,上述计算网络模型基于预设样本输出的预设类别的分类结果对应的第三概率分布的步骤,包括:计算网络模型基于第i个预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000031
组合每个概率
Figure BDA0002383225390000032
得到第三概率分布pk;上述计算当前辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布的步骤,包括:计算多个辅助模型中第l个辅助模型基于第i个预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000033
组合每个概率
Figure BDA0002383225390000034
得到第l个辅助模型对应的第四概率分布pl
在本发明较佳的实施例中,上述损失函数还用于:根据网络模型基于预设样本输出的预设类别的分类结果,以及预设样本携带的类别标签,确定第二损失值。
在本发明较佳的实施例中,上述根据网络模型基于预设样本输出的预设类别的分类结果,以及预设样本携带的类别标签,确定第二损失值的步骤,包括:计算网络模型基于预设样本输出的预设类别的分类结果,与预设样本携带的类别标签的交叉熵;将该交叉熵确定为第二损失值。
在本发明较佳的实施例中,上述网络模型,具体通过下述方式训练得到:确定样本集合;该样本集合中的每个样本携带有类别标签;将该样本集合分别输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果;通过损失函数,确定第一损失值和第二损失值;根据第一损失值和第二损失值,训练网络模型和辅助模型;继续执行确定样本集合的步骤,直到第一损失值和第二损失值收敛,得到训练后的网络模型。
在本发明较佳的实施例中,上述根据第一损失值和第二损失值,训练网络模型和辅助模型的步骤,包括:根据第一损失值和第二损失值,调整预设第一网络的参数;其中,该第一网络为网络模型或辅助模型;将样本集合中的样本输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和每个辅助模型输出的预设类别的分类结果;通过损失函数,确定第三损失值和第四损失值;根据第三损失值和第四损失值,调整预设第二网络的参数;其中,当第一网络为网络模型时,该第二网络为辅助模型;当第一网络为辅助模型时,该第二网络为网络模型。
第二方面,本发明实施例还提供一种目标分类装置,该装置包括:数据获取模块,用于获取待处理数据;该待处理数据中包括待分类目标;数据处理模块,用于将待处理数据输入至预先训练完成的网络模型中,得到待分类目标的分类结果;其中,该网络模型通过预设的辅助模型和预设的损失函数训练得到:该损失函数用于:在训练网络模型和辅助模型的过程中,根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。
第三方面,本发明实施例还提供一种电子系统,该电子系统包括:处理设备和存储装置;存储装置上存储有计算机程序,计算机程序在被处理设备运行时执行上述目标分类方法。
第四方面,本发明实施例还提供一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序,该计算机程序被处理设备运行时执行上述目标分类方法的步骤。
本发明实施例带来了以下有益效果:
本发明提供了一种目标分类方法、装置和电子系统,用于分类目标的网络模型,通过预设的辅助模型和预设的损失函数训练得到;在训练网络模型的过程中,同时训练辅助模型;在训练网络模型和辅助模型的过程中,损失函数根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。由于该方式在训练网络模型的过程中同时训练辅助模型,因而辅助模型在训练过程中也在不断提升性能,损失函数基于辅助模型输出的分类结果,确定网络模型的损失值,因此,在辅助模型性能不断提升的同时,通过该损失函数可以将辅助模型的知识传递至网络模型,使得网络模型的性能也不断提升,从而网络模型的性能不再受到辅助模型在初始状态下的性能的限制,进一步提高了网络模型的性能。
本发明的其他特征和优点将在随后的说明书中阐述,或者,部分特征和优点可以从说明书推知或毫无疑义地确定,或者通过实施本发明的上述技术即可得知。
为使本发明的上述目的、特征和优点能更明显易懂,下文特举较佳实施方式,并配合所附附图,作详细说明如下。
附图说明
为了更清楚地说明本发明具体实施方式或现有技术中的技术方案,下面将对具体实施方式或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种电子系统的结构示意图;
图2为本发明实施例提供的一种目标分类方法的流程图;
图3为本发明实施例提供的另一种目标分类方法的流程图;
图4为本发明实施例提供的另一种目标分类方法的流程图;
图5为本发明实施例提供的另一种目标分类方法的流程图;
图6为本发明实施例提供的一种目标分类装置的结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合附图对本发明的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
相关技术中,为了解决集成模型计算量大与性能提升导致性价比下降的问题,研究人员通常设计更精巧的网络结构(例如,MobileNet和ShuffleNet),或者采用网络压缩、剪枝、二值化、模型蒸馏等方法降低网络参数。目前比较常用的模型蒸馏算法中,通常将训练好的参数量庞大的大网络的知识传递给小网络,以使小网络在结构简单参数量小的同时,具有大网络的性能,但是,该方式采用的是将大网络的知识单向传递至小网络的方式,小网络的性能提升空间受到大网络性能的限制,进而影响了小网络的总体性能。
基于此,本发明实施例提供了一种目标分类方法、装置和电子系统,该技术可以应用于信息处理场景中,尤其是图像数据、文本数据等的分类场景中,同时该技术可采用相应的软件和硬件实现,以下对本发明实施例进行详细介绍。
实施例一:
首先,参照图1来描述用于实现本发明实施例的目标分类方法、装置和电子系统的示例电子系统100。
如图1所示的一种电子系统的结构示意图,电子系统100包括一个或多个处理设备102、一个或多个存储装置104、输入装置106以及输出装置108,这些组件通过总线系统110和/或其它形式的连接机构(未示出)互连。应当注意,图1所示的电子系统100的组件和结构只是示例性的,而非限制性的,根据需要,所述电子系统也可以具有其他组件和结构。
所述处理设备102可以是网关,也可以为智能终端,或者是包含中央处理单元(CPU)或者具有数据处理能力和/或指令执行能力的其它形式的处理单元的设备,可以对所述电子系统100中的其它组件的数据进行处理,还可以控制所述电子系统100中的其它组件以执行期望的功能。
所述存储装置104可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理设备102可以运行所述程序指令,以实现下文所述的本发明实施例中(由处理设备实现)的客户端功能以及/或者其它期望的功能。在所述计算机可读存储介质中还可以存储各种应用程序和各种数据,例如所述应用程序使用和/或产生的各种数据等。
所述输入装置106可以是用户用来输入指令的装置,并且可以包括键盘、鼠标、麦克风和触摸屏等中的一个或多个。
所述输出装置108可以向外部(例如,用户)输出各种信息(例如,图像、文本或声音),并且可以包括显示器、扬声器等中的一个或多个。
示例性地,用于实现根据本发明实施例的网络结构的确定方法、装置和电子系统的示例电子系统中的各器件可以集成设置,也可以分散设置,诸如将处理设备102、存储装置104、输入装置106和输出装置108集成设置于一体。当上述电子系统中的各器件集成设置时,该电子系统可以被实现为诸如智能手机、平板电脑、计算机等智能终端。
实施例二:
本实施例提供了一种目标分类方法,该方法由上述电子系统中的处理设备执行;该处理设备可以是具有数据处理能力的任何设备或芯片。该处理设备可以独立对接收到的信息进行处理,也可以与服务器相连,共同对信息进行分析处理,并将处理结果上传至云端;如图2所示,该方法包括如下步骤:所述方法包括:
步骤S202,获取待处理数据;该待处理数据中包括待分类目标。
上述待处理数据可以是图像数据、文本数据或者其他类型的数据,该待处理数据可以是用户通过终端设备输入的数据,也可以是电子设备从监控场景中获取到的数据。该待处理数据中包含有待分类目标,该待分类目标可以是文本中的词语、句子、段落等,也可以是图像中的人物、动物、建筑物等。
步骤S204,将上述待处理数据输入至预先训练完成的网络模型中,得到待分类目标的分类结果;其中,该网络模型通过预设的辅助模型和预设的损失函数训练得到;该损失函数用于:在训练网络模型和辅助模型的过程中,根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。
上述网络模型模型可以是深度学习模型、神经网络模型等,上述辅助模型也可以是深度学习模型、神经网络模型等;其中,网络模型与辅助模型的网络结构可以相同也可以不同,也可以理解为网络模型与辅助模型可实现的功能可以相同也可以不同。在具体实现时,上述网络模型和辅助模型可以是参数量少的网络结构,也可是参数量大的网络结构;在网络模型为参数量少的网络结构时,辅助模型可以为参数量少的网路结构,也可以是参数量大的网络结构。
在对网络模型训练的过程中,需要从预设的样本集合中选择预设样本,再将该预设样本分别输入至网络模型和预设的辅助模型中,得到网络模型输出的预设样本对应的预设类别的分类结果,以及辅助模型输出的预设样本对应的预设类别的分类结果;进而基于预设的损失函数计算网络模型的损失值,继续执行从预设的样本集合中选择预设样本的步骤,直到该损失值收敛。
上述预设的样本集合中通常包含有大量的样本,每个样本均含有待分类目标,以及该待分类目标对应的预设类别的类别标签,该预设类别对应的类别,以及类别总数是预先设置好的,且该类别标签可以用数字、字母或者矩形框等标注。例如,当待处理数据为图像时,可以是把图像中的每个像元或区域划归为若干个预设类别中的某一种,其中,待分类目标可以是上述每个像元或区域。
在具体实现时,在训练网络模型的过程中,也可以训练辅助模型。在训练辅助模型的过程中可以将辅助模型当成网络模型,将网络模型作为辅助模型的辅助模型,从而利用上述训练网络模型的方式训练辅助模型。随着模型的不断训练,网络模型和辅助模型的模型参数不断调整,也可以理解为在继续执行从预设的样本集合中选取预设样本的步骤之后,将该预设样本输入至参数调整后的网络模型和参数调整后的辅助模型中,以使网络模型和辅助模型相辅相成,共同训练,直到损失值收敛。当网络模型与辅助模型的网络结构不同时,输出的分类结果也不同,也即网络模型和辅助模型关注的点不同,因此网络之间互相学习,共同训练可以提升每个模型的性能。
上述预设的损失函数可以根据网络模型输出的分类结果和辅助模型输出的分类结果,计算第一损失值,实现模型之间互相学习的目的;也可以仅根据网络模型模型输出的分类结果,计算第二损失值,实现模型自学习的目的。在网络模型的训练过程中,可以仅在第一损失值收敛时,得到训练后的网络模型,也可以在第一损失值和第二损失值均收敛时,得到训练后的网络模型。
本发明提供了一种目标分类方法,首先获取待处理数据;将该待处理数据输入至预先训练完成的网络模型中,得到该待处理数据中待分类目标的分类结果;该网络模型通过预设的辅助模型和损失函数训练得到,在训练网络模型和辅助模型的过程中,损失函数可根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。由于该方式在训练网络模型的过程中同时训练辅助模型,因而辅助模型在训练过程中也在不断提升性能,损失函数基于辅助模型输出的分类结果,确定网络模型的损失值,因此,在辅助模型性能不断提升的同时,通过该损失函数可以将辅助模型的知识传递至网络模型,使得网络模型的性能也不断提升,从而网络模型的性能不再受到辅助模型在初始状态下的性能的限制,进一步提高了网络模型的性能。
实施例三:
本发明实施例还提供另一种目标分类方法,该方法在上述实施例所述方法的基础上实现;该方法重点描述获取待处理数据之前,训练网络模型的具体过程(通过下述步骤S302-S310实现);如图3所示,该方法包括如下具体步骤:
步骤S302,确定预设样本;该预设样本携带有类别标签。
上述预设样本通常是从预设的样本集合中随机选择的。上述预设样本携带的类别标签可以用具有特殊含义的数字表示,例如,用1表示预设类别1、用2表示预设类别2等。
步骤S304,将上述预设样本输入至网络模型和预设的辅助模型中。
步骤S306,通过预设的损失函数,计算上述网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,根据该相对熵确定第一损失值。
相对熵又可称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度,通常是两个随机分布间差异的非对称性度量,它可以衡量两个随机分布之间的距离,当两个随机分布相同时,它们的相对熵为零,当两个随机分布的差别增大时,它们的相对熵也会增大。本方案中通过相对熵可以衡量网络模型对应的分类结果与辅助模型对应的分类结果的相似度,并将该相似度确定为第一损失值。
在具体实现时,上述步骤S306可以通过下述步骤10-13实现:
步骤10,计算网络模型基于预设样本输出的预设类别的分类结果对应的第一概率分布。
将网络模型基于预设样本输出的预设类别的分类结果输入至softmax激活函数,可得到网络模型对应的第一概率分布;该softmax激活函数可以用于多分类过程,它可将模型输出的结果映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类。
在具体实现时,在确定预设样本时,可以一次确定多个预设样本,将多个预设样本均输入至网络模型和辅助模型中,得到网络模型基于多个预设样本输出的预设类别的分类结果和辅助模型基于多个预设样本输出预设类别的分类结果,从而通过下述方式计算第一概率分布:
首先计算网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000121
然后组合每个概率
Figure BDA0002383225390000122
得到第一概率分布p1;其中,xi为多个预设样本中的第i个预设样本,M为预设类别的类别总数,
Figure BDA0002383225390000123
分别为网络模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
上述第一概率分布的计算方式也可理解为通过softmax激活函数可得到每个预设样本在每个预设类别下的分类结果对应的概率,将其进行排列组合,可以得到第一概率分布,该概率分布通常为矩阵形式,其中,矩阵的每一行可以代表一个预设样本的M个预设类别对应的概率,也可以每一行代表多个预设样本在某一预设类别下对应的概率。
步骤11,计算辅助模型基于预设样本输出的预设类别的分类结果对应的第二概率分布。
计算辅助模型对应的第二概率分布的方式与第一概率分布的计算方式相同,可以通过下述方式计算第二概率分布:
首先计算辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000124
组合每个概率
Figure BDA0002383225390000125
得到第二概率分布p2;其中,xi为多个预设样本中的第i个预设样本,M为预设类别的类别总数,
Figure BDA0002383225390000126
分别为辅助模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
步骤12,根据上述第一概率分布和第二概率分布,计算第二概率分布相对第一概率分布的第一相对熵
Figure BDA0002383225390000131
其中,p1为第一概率分布,p2为第二概率分布;N为预设样本的样本总数,xi为第i个预设样本,M为预设类别的类别总数;
Figure BDA0002383225390000132
为第一概率分布中,网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,
Figure BDA0002383225390000133
为第二概率分布中,辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,log为以2为底的对数运算。
步骤13,将上述第一相对熵确定为第一损失值。
步骤S308,根据网络模型基于预设样本输出的预设类别的分类结果,以及预设样本携带的类别标签,确定第二损失值。
上述预设样本携带的类别标签可以是人工标注的预设样本中的待分类目标对应的准确的预设类别,该类别标签可以用数字表示,且该数字可以与预设类别的种类号相同,例如,类别标签中的数字1可以表示预设类别1,2可以表示预设类别2等;通常类别标签对应的数字的种类与预设类别的种类数一致,也即是预设类别的种类数为1~M,那么类别标签为1~M。
根据网络模型基于预设样本输出的预设类别的分类结果与预设样本携带的类别标签的差距,可以得到第二损失值,通常差距越大,第二损失值越大,差距越小第二损失值越小。在具体实现时,上述步骤S308,可以通过下述步骤20-21实现:
步骤20,计算网络模型基于预设样本输出的预设类别的分类结果,与预设样本携带的类别标签的交叉熵
Figure BDA0002383225390000134
其中,
Figure BDA0002383225390000135
N为预设样本的样本总数,xi为第i个预设样本,yi为第i个预设样本的预设类别标签,M为预设类别的类别总数,
Figure BDA0002383225390000141
Figure BDA0002383225390000142
为网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的分类概率,
Figure BDA0002383225390000143
分别为网络模型第i个预设样本输出的第m个、第j个预设类别的分类结果,log为以2为底的对数运算。上述分类概率通常是将网络模型基于预设样本输出的预设类别的分类结果输入至softmax激活函数得到的。
上述交叉熵通常用于度量两个概率分布间的差异性信息,通过该交叉熵可以衡量网络模型基于预设样本输出的预设类别的分类结果与预设样本携带的类别标签的相似度;在本方案中,可以将分类结果与类别标签的相似度确定为第一损失值,通常相似度越高,第一损失值越小。
步骤21,将上述交叉熵确定为第二损失值。
步骤S310,根据上述第一损失值和第二损失值,训练网络模型和辅助模型;继续执行确定预设样本的步骤,直到第一损失值和第二损失值收敛,得到训练后的网络模型。
在具体实现时,在训练网络模型的过程中,也可以训练辅助模型。在训练辅助模型的过程中可以将辅助模型当成网络模型,将网络模型作为辅助模型的辅助模型,以训练辅助模型。在具体实现时,可以将第一损失值和第二损失值的和确定为模型损失值L=L1+L2,基于该模型损失值,下面主要介绍通过步骤30-33训练网络模型的具体方式:
步骤30,计算模型损失值对网络模型中待更新参数的导数
Figure BDA0002383225390000144
其中,L为模型损失值;W为待更新参数;该待更新参数可以为网络模型中的所有参数,也可以为随机从网络模型中确定的部分参数;其中,该待更新参数也可以称为网络模型中各层网络的权值。通常可以根据反向传播算法求解待更新参数的导数;如果模型损失值较大,则说明当前的网络模型的输出与期望输出结果不符。
步骤31,更新待更新参数,得到更新后的待更新参数
Figure BDA0002383225390000151
其中,α为预设系数,也即是学习率。该过程也可以称为梯度下降算法;各个待更新参数的导数也可以理解为相对于当前参数,模型损失值下降最快的方向,通过该方向调整参数,可以使模型损失值快速降低,使该待更新参数收敛速度也加快。
步骤32,判断更新后的网络模型的参数是否均收敛,如果均收敛,执行确定预设样本的步骤;否则执行步骤S33。
如果更新后的网络模型的参数不是均收敛,则继续执行确定预设样本的步骤,直到更新后的网络模型的参数均收敛。
步骤33,将参数更新后的网络模型确定为训练后的网络模型。
另外,当网络模型经一次训练后,得到模型损失值,此时可以从网络模型的各个参数中随机选择一个或多个参数进行上述的更新过程,该方式的模型训练时间较短,算法较快;当然也可以对网络模型中所有参数进行上述的更新过程,该方式的模型训练更加准确。
步骤S312,如果获取到待处理数据,将该待处理数据输入至训练后的网络模型中,得到待处理数据中待分类目标的分类结果。
上述目标分类方法,首先确定预设样本,再将该预设样本输入至网络模型和预设的辅助模型中;进而通过预设的损失函数,计算上述网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,将相对熵确定为第一损失值;然后根据网络模型基于预设样本输出的预设类别的分类结果以及预设样本携带的类别标签确定第二损失值;再根据第一损失值和第二损失值,训练网络模型和辅助模型,直到第一损失值和第二损失值收敛,得到训练后的网络模型;如果获取到待处理数据,将该待处理数据输入至训练后的网络模型中,得到待处理数据中待分类目标的分类结果。该方式可以互相学习的方式训练网络模型和辅助模型,提高了模型的泛化性能,而且该方式不仅可以用于训练高效的小网络对应的网络模型,也可以进一步提升大网络对应的网络模型的性能,且容易扩展到多网络学习及半监督学习场景中,同时该方式训练得到的网络模型可以准确快速地分类出待处理数据中待分类目标的类别,具有非常重要的实用价值。
实施例四:
本发明实施例还提供另一种目标分类方法,该方法在上述实施例所述方法的基础上实现;该方法重点描述当辅助模型为多个时,训练网络模型的具体过程(通过下述步骤S402-S410实现);如图4所示,该方法包括如下具体步骤:
步骤S402,确定预设样本;该预设样本携带有类别标签。
步骤S404,将上述预设样本输入至网络模型和多个辅助模型中。
步骤S406,计算网络模型基于预设样本输出的预设类别的分类结果对应的第三概率分布。
将网络模型基于预设样本输出的预设类别的分类结果输入至softmax激活函数,可得到网络模型对应的第三率分布。在具体实现时,在确定预设样本时,可以一次确定多个预设样本,例如预设样本的样本总数为N,基于此,上述步骤S406,可以通过下述步骤40-41实现:
步骤40,计算网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000161
步骤41,组合每个上述概率
Figure BDA0002383225390000171
得到第三概率分布pk;其中,xi为第i个所述预设样本,M为预设类别的类别总数,
Figure BDA0002383225390000172
分别为网络模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
上述第三概率分布的计算方式也可理解为通过softmax激活函数可得到每个预设样本在每个预设类别下的分类结果对应的概率,将其进行排列组合,可以得到第三概率分布,该概率分布通常为矩阵形式,其中,矩阵的每一行可以代表一个预设样本的M个预设类别对应的概率,也可以每一行代表多个预设样本在某一预设类别下对应的概率。
步骤S408,针对每个辅助模型,执行下述操作:计算当前辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布。
在具体实现时,需要计算多个辅助模型中每个辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布,计算每个辅助模型对应的第四概率分布的方式与上述计算第三概率分布的方式相同。在一些实施例中可以通过下述步骤50-51计算多个所述辅助模型中第l个辅助模型对应的第四概率分布(相当于上述当前辅助模型):
步骤50,计算多个辅助模型中第l个辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000173
步骤51,组合每个上述概率
Figure BDA0002383225390000174
得到第l个辅助模型对应的第四概率分布pl;其中,xi为第i个预设样本,M为预设类别的类别总数,
Figure BDA0002383225390000175
分别为第l个辅助模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
步骤S410,计算多个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵,根据第二相对熵确定第一损失值。
根据多个辅助模型中每个辅助模型对应的第四概率分布,相对于第三概率分布的相对熵,可以得到第二相对熵,将该第二相对熵确定为第一损失值,该方式得到的第一损失值与每个辅助模型输出的分类结果有关,在后续模型训练的过程中,可以确保网络模型与所有辅助模型相互学习。
在一些实施例中,上述步骤S410可以通过下述步骤60-62实现:
步骤60,计算每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵
Figure BDA0002383225390000181
其中,pk为第三概率分布,pl为多个辅助模型中第l个辅助模型对应的第四概率分布,K-1为辅助模型的总个数;N为预设样本的样本总数,xi为第i个预设样本,M为预设类别的类别总数;
Figure BDA0002383225390000182
为第三概率分布中,网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,
Figure BDA0002383225390000183
为多个辅助模型中的第l个辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,log为以2为底的对数运算。
步骤61,计算每个辅助模型对应的第二相对熵的平均值。
步骤62,将上述平均值确定为第一损失值
Figure BDA0002383225390000184
由上述第一损失值的公式可知,该网络模型可以辅助模型中的一个。在具体实现时,可以根据K-1个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的相对熵的均值作为模型训练的优化目标,也即是可以通过K-1个辅助模型训练网络模型,也可以将网络模型作为辅助模型,基于上述方式训练K-1个辅助模型中的任意辅助模型。
在另一些实施例中,上述步骤S410还可以通过下述步骤70-71实现:
步骤70,计算每个辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布的均值概率分布
Figure BDA0002383225390000185
其中,pk为第三概率分布,pavg为均值概率分布,K-1为辅助模型的总个数,pl为多个辅助模型中第l个辅助模型对应的第四概率分布。
步骤71,计算上述均值概率分布相对第三概率分布的第二相对熵,将第二相对熵确定为第一损失值L1=DKL(pavg||pk);该第一损失值由K-1个辅助模型的平均概率分布相对于第三概率分布的相对熵计算得到,在具体实现时,可以根据K-1个辅助模型对应的第四概率分布的均值相对于第三概率分布的相对熵作为模型训练的优化目标,也即是可以通过K-1个辅助模型训练网络模型;通常也可以将网络模型作为辅助模型,并基于上述方式训练K-1个辅助模型中的任意辅助模型。
步骤S412,根据网络模型基于预设样本输出的预设类别的分类结果,以及预设样本携带的类别标签,确定第二损失值。
步骤S414,根据上述第一损失值和第二损失值,训练网络模型和辅助模型;继续执行确定预设样本的步骤,直到第一损失值和第二损失值收敛,得到训练后的网络模型。
步骤S416,如果获取到待处理数据,将该待处理数据输入至训练后的网络模型中,得到待处理数据中待分类目标的分类结果。
上述目标分类方法,可通过多个辅助模型训练网络模型,以实现模型间的互相学习,从而既可提升网络模型的性能,又可提高辅助模型的性能,同时该方式训练得到的网络模型可以准确快速地分类出待处理数据中待分类目标的类别。
实施例五:
本发明实施例还提供另一种目标分类方法,该方法在上述实施例所述方法的基础上实现;该方法重点训练网络模型的具体过程(通过下述步骤S502-S508实现);如图5所示,该方法包括如下具体步骤:
步骤S502,确定样本集合;该样本集合中的每个样本携带有类别标签。
上述样本集合通常是从预设的训练集合中确定的,该训练集合中包含有大量的样本,且每个样本携带有类别标签。在具体实现时,可以从训练集合中随机选择预设数量的样本,将选择的预设数量的样本组合为样本集合。
步骤S504,将上述样本集合分别输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果。
在训练网络模型和辅助模型之前需要对模型进行初始化,可以采用不同的初始化条件对网络模型和辅助模型进行随机初始化,也即是可以分别随机初始化网络模型和辅助模型中的参数和网络结构等。
在具体实现时,需要将样本集合中的每个样本均输入至网络模型或者辅助模型中,如果辅助模型的数量为多个,需要将样本集合中的每个样本输入至多个辅助模型中。
步骤S506,通过上述损失函数,确定第一损失值和第二损失值。
根据预设的损失函数,可以得到与辅助模型输出的预设类别的分类结果有关的第一损失值,也可以得到网络模型输出的预设类别的分类结果与样本携带的类别标签对应的第二损失值。在具体实现时,可以通过上述步骤S306或者步骤S406-S410计算第一损失值,通过上述步骤S308计算第二损失值。
步骤S508,根据上述第一损失值和第二损失值,训练网络模型和辅助模型;继续执行确定样本集合的步骤,直到第一损失值和第二损失值收敛,得到训练后的网络模型。
在训练网络模型和辅助模型的过程中,需要根据第一损失值和第二损失值,调整网络模型的参数;如果第一损失值和第二损失值不收敛,需要继续从预设样的训练集中确定新的样本集合;并将新的样本集合输入至网络模型和辅助模型中,继续训练网络模型和辅助模型,直到第一损失值和第二损失值收敛。
在具体实现时,上述步骤S508中根据第一损失值和第二损失值,训练网络模型和辅助模型的步骤,可以通过下述步骤80-83实现:
步骤80,根据第一损失值和第二损失值,调整预设第一网络的参数;其中,第一网络为网络模型或辅助模型。
在网络模型和辅助模型进行初始化的过程中,需要固定一个模型,训练另一个模型,也可固定多个模型训练,训练另一个模型;也可以理解为当辅助模型为多个时,固定多个辅助模型,训练网络模型;当训练多个辅助模型中的一个辅助模型时,需要将网络模型看作辅助模型,固定该网络模型,以及多个辅助模型中其余的辅助模型,训练另外的辅助模型。
在具体实现时,首先将样本集合中的样本输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果;再根据损失函数计算的第一网络模型对应的第一损失值和第二损失值,调整第一网络的参数;其中,第一网络可以为网络模型,可以是为辅助模型。
步骤81,将样本集合中的样本输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果。
再次将样本集合中的样本输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果。如果辅助模型为多个,需要将样本集合中的样本输入至多个辅助模型中。
步骤82,通过损失函数,确定第三损失值和第四损失值。
基于步骤82中网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果,得到第二网络模型对应的第三损失值和第四损失值,其中,可以将辅助模型作为网络模型,将网络模型作为辅助模型的辅助模型,通过上述步骤S306或者步骤S406-S410计算第三损失值,通过上述步骤S308计算第四损失值。
步骤83,根据上述第三损失值和第四损失值,调整预设第二网络的参数;其中,当第一网络为网络模型时,该第二网络为辅助模型;当第一网络为辅助模型时,第二网络为网络模型。
在根据第三损失值和第四损失值,调整第二网络的参数,该第二网络可以是辅助模型或者网络模型,但是当第一网络为网络模型时,第二网络为辅助模型;当第一网络为辅助模型时,第二网络为网络模型,从而交替更新网络模型和辅助模型的参数,也即是交替更新网络模型和辅助模型。
步骤S510,获取待处理数据;该待处理数据中包括待分类目标。
步骤S512,将上述待处理数据输入至训练后的网络模型中,得到待处理数据中待分类目标的分类结果。
上述目标分类方法,在训练网络模型和辅助模型的过程中,随着网络模型性能的提升,辅助模型的性能也随之提升,而且该方式中的模型训练方式可以训练参数量较小的网络,达到参数量大的网络的性能,从而可以训练得到的模型部署在资源条件有限的环境中,以通过该模型准确快速地分类出待处理数据中待分类目标的类别。
实施例六:
对应于上述目标分类方法实施例,本发明实施例提供了一种目标分类装置,如图6所示,该装置包括:
数据获取模块60,用于获取待处理数据;该待处理数据中包括待分类目标。
数据处理模块61,用于将上述待处理数据输入至预先训练完成的网络模型中,得到待分类目标的分类结果;其中,网络模型通过预设的辅助模型和预设的损失函数训练得到;损失函数用于:在训练网络模型和辅助模型的过程中,根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。
具体地,上述装置包括第一损失值确定模块,用于:计算网络模型基于预设样本输出的预设类别的分类结果,与辅助模型基于预设样本输出的预设类别的分类结果的相对熵,根据该相对熵确定第一损失值。
进一步地,上述第一损失值确定模块,包括:第一概率计算单元,用于计算网络模型基于预设样本输出的预设类别的分类结果对应的第一概率分布;第二概率计算单元,用于计算辅助模型基于预设样本输出的预设类别的分类结果对应的第二概率分布;相对熵确定单元,用于根据第一概率分布和第二概率分布,计算第二概率分布相对第一概率分布的第一相对熵
Figure BDA0002383225390000231
将该第一相对熵确定为第一损失值;其中,p1为第一概率分布,p2为第二概率分布;N为预设样本的样本总数,xi为第i个预设样本,M为预设类别的类别总数;
Figure BDA0002383225390000232
为第一概率分布中,网络模型基于第i个所述预设样本输出的第m个预设类别的分类结果对应的概率,
Figure BDA0002383225390000233
为第二概率分布中,辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,log为以2为底的对数运算。
进一步地,上述第一概率计算单元,还用于:计算网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000234
组合每个概率
Figure BDA0002383225390000235
得到第一概率分布p1;上述第二概率计算单元,还用于:计算辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000236
组合每个概率
Figure BDA0002383225390000237
得到第二概率分布p2;其中,
Figure BDA0002383225390000241
分别为网络模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果;分别为辅助模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
进一步地,上述第一损失值确定模块,还包括:第三概率计算单元,用于计算网络模型基于预设样本输出的预设类别的分类结果对应的第三概率分布;第四概率计算单元,用于针对每个辅助模型,执行下述操作:计算当前辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布;相对熵计算单元,用于:计算多个辅助模型中每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵,根据第二相对熵确定第一损失值。
进一步地,上述相对熵计算单元,还用于:计算每个辅助模型对应的第四概率分布相对第三概率分布的第二相对熵
Figure BDA0002383225390000243
计算每个辅助模型对应的第二相对熵的平均值;将该平均值确定为第一损失值
Figure BDA0002383225390000244
其中,pk为第三概率分布,pl为多个辅助模型中第l个辅助模型对应的第四概率分布,K-1为辅助模型的总个数;N为预设样本的样本总数,xi为第i个预设样本,M为预设类别的类别总数;
Figure BDA0002383225390000245
为第三概率分布中,网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,
Figure BDA0002383225390000246
为多个辅助模型中的第l个辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率,log为以2为底的对数运算。
进一步地,上述相对熵计算单元,还用于:计算每个辅助模型基于预设样本输出的预设类别的分类结果对应的第四概率分布的均值概率分布
Figure BDA0002383225390000247
计算均值概率分布相对第三概率分布的第二相对熵,将第二相对熵确定为第一损失值L1=DKL(pavg||pk);其中,pk为第三概率分布,pavg为均值概率分布,K-1为辅助模型的总个数,pl为多个辅助模型中第l个辅助模型对应的第四概率分布。
进一步地,上述第三概率计算单元,还用于:计算网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000251
组合每个概率
Figure BDA0002383225390000252
得到第三概率分布pk;上述第四概率计算单元,还用于:计算多个辅助模型中第l个辅助模型基于第i个预设样本输出的第m个预设类别的分类结果对应的概率
Figure BDA0002383225390000253
组合每个概率
Figure BDA0002383225390000254
得到第l个辅助模型对应的第四概率分布pl;其中,xi为第i个预设样本,
Figure BDA0002383225390000255
分别为网络模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果;
Figure BDA0002383225390000256
分别为第l个辅助模型基于第i个预设样本输出的第m个、第j个预设类别的分类结果。
进一步地,上述装置还包括第二损失值确定模块,用于:根据网络模型基于预设样本输出的预设类别的分类结果,以及预设样本携带的类别标签,确定第二损失值。
具体地,上述第二损失值确定模块,还用于:计算网络模型基于预设样本输出的预设类别的分类结果,与预设样本携带的类别标签的交叉熵
Figure BDA0002383225390000257
将该交叉熵确定为第二损失值;其中,
Figure BDA0002383225390000258
N为预设样本的样本总数,xi为第i个预设样本,yi为第i个预设样本的预设类别标签,M为预设类别的类别总数,
Figure BDA0002383225390000259
Figure BDA0002383225390000261
为网络模型基于第i个预设样本输出的第m个预设类别的分类结果对应的分类概率,
Figure BDA0002383225390000262
分别为网络模型第i个预设样本输出的第m个、第j个预设类别的分类结果,log为以2为底的对数运算。
进一步地,上述装置还包括模型训练模块,用于:确定样本集合;该样本集合中的每个样本携带有类别标签;将样本集合分别输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和辅助模型输出的预设类别的分类结果;通过损失函数,确定第一损失值和第二损失值;根据第一损失值和第二损失值,训练网络模型和辅助模型;继续执行确定样本集合的步骤,直到第一损失值和第二损失值收敛,得到训练后的网络模型。
具体地,上述模型训练模块,还用于:根据第一损失值和第二损失值,调整预设第一网络的参数;其中,第一网络为网络模型或辅助模型;将样本集合中的样本输入至网络模型和辅助模型中,得到网络模型输出的预设类别的分类结果,和每个辅助模型输出的预设类别的分类结果;通过损失函数,确定第三损失值和第四损失值;根据第三损失值和第四损失值,调整预设第二网络的参数;其中,当第一网络为网络模型时,第二网络为辅助模型;当第一网络为辅助模型时,第二网络为所述网络模型。
上述目标分类装置,首先获取待处理数据;将该待处理数据输入至预先训练完成的网络模型中,得到该待处理数据中待分类目标的分类结果;该网络模型通过预设的辅助模型和损失函数训练得到,在训练网络模型和辅助模型的过程中,损失函数可根据辅助模型基于预设样本输出的预设类别的分类结果,以及网络模型基于预设样本输出的预设类别的分类结果,确定网络模型的第一损失值。由于该方式在训练网络模型的过程中同时训练辅助模型,因而辅助模型在训练过程中也在不断提升性能,损失函数基于辅助模型输出的分类结果,确定网络模型的损失值,因此,在辅助模型性能不断提升的同时,通过该损失函数可以将辅助模型的知识传递至网络模型,使得网络模型的性能也不断提升,从而网络模型的性能不再受到辅助模型在初始状态下的性能的限制,进一步提高了网络模型的性能。
实施例六:
本发明实施例提供了一种电子系统,该电子系统包括:处理设备和存储装置;该存储装置上存储有计算机程序,该计算机程序在被处理设备运行时执行上述目标分类方法。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的电子系统的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
进一步,本实施例还提供了一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序,该计算机程序被处理设备运行时执行上述目标分类方法。
本发明实施例所提供的一种目标分类方法、装置和电子系统的计算机程序产品,包括存储了程序代码的计算机可读存储介质,所述程序代码包括的指令可用于执行前面方法实施例中所述的方法,具体实现可参见方法实施例,在此不再赘述。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统和/或装置的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
最后应说明的是:以上各实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述各实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的范围。

Claims (15)

1.一种目标分类方法,其特征在于,所述方法包括:
获取待处理数据;所述待处理数据中包括待分类目标;
将所述待处理数据输入至预先训练完成的网络模型中,得到所述待分类目标的分类结果;
其中,所述网络模型通过预设的辅助模型和预设的损失函数训练得到:所述损失函数用于:在训练所述网络模型和所述辅助模型的过程中,根据所述辅助模型基于预设样本输出的预设类别的分类结果,以及所述网络模型基于所述预设样本输出的所述预设类别的分类结果,确定所述网络模型的第一损失值。
2.根据权利要求1所述的方法,其特征在于,所述损失函数具体通过下述方式确定第一损失值:
计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果,与所述辅助模型基于所述预设样本输出的所述预设类别的分类结果的相对熵,根据所述相对熵确定所述第一损失值。
3.根据权利要求2所述的方法,其特征在于,所述计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果,与所述辅助模型基于所述预设样本输出的所述预设类别的分类结果的相对熵,根据所述相对熵确定所述第一损失值的步骤,包括:
计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果对应的第一概率分布;
计算所述辅助模型基于所述预设样本输出的所述预设类别的分类结果对应的第二概率分布;
根据所述第一概率分布和所述第二概率分布,计算所述第二概率分布相对所述第一概率分布的第一相对熵;
将所述第一相对熵确定为所述第一损失值。
4.根据权利要求3所述的方法,其特征在于,计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果对应的第一概率分布的步骤,包括:
计算所述网络模型基于第i个所述预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure FDA0002383225380000021
组合每个所述概率
Figure FDA0002383225380000022
得到第一概率分布p1
所述计算所述辅助模型基于所述预设样本输出的所述预设类别的分类结果对应的第二概率分布的步骤,包括:
计算所述辅助模型基于第i个所述预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure FDA0002383225380000023
组合每个所述概率
Figure FDA0002383225380000024
得到第二概率分布p2
5.根据权利要求2所述的方法,其特征在于,所述辅助模型包括多个;所述计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果,与所述辅助模型基于所述预设样本输出的所述预设类别的分类结果的相对熵,根据所述相对熵确定所述第一损失值的步骤,包括:
计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果对应的第三概率分布;
针对每个所述辅助模型,执行下述操作:计算当前辅助模型基于所述预设样本输出的所述预设类别的分类结果对应的第四概率分布;
计算多个所述辅助模型中每个所述辅助模型对应的第四概率分布相对所述第三概率分布的第二相对熵,根据所述第二相对熵确定所述第一损失值。
6.根据权利要求5所述的方法,其特征在于,所述计算多个所述辅助模型中每个所述辅助模型对应的第四概率分布相对所述第三概率分布的第二相对熵的步骤,包括:
计算每个所述辅助模型对应的第四概率分布相对所述第三概率分布的第二相对熵;
所述根据所述第二相对熵确定所述第一损失值的步骤,包括:
计算每个所述辅助模型对应的所述第二相对熵的平均值;
将所述平均值确定为所述第一损失值。
7.根据权利要求5所述的方法,其特征在于,所述计算多个所述辅助模型中每个所述辅助模型对应的第四概率分布相对所述第三概率分布的第二相对熵,根据所述第二相对熵确定所述第一损失值的步骤,包括:
计算每个所述辅助模型基于所述预设样本输出的所述预设类别的分类结果对应的第四概率分布的均值概率分布;
计算所述均值概率分布相对所述第三概率分布的第二相对熵,将所述第二相对熵确定为第一损失值。
8.根据权利要求5所述的方法,其特征在于,所述计算所述网络模型基于所述预设样本输出的所述预设类别的分类结果对应的第三概率分布的步骤,包括:
计算所述网络模型基于第i个所述预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure FDA0002383225380000031
组合每个所述概率
Figure FDA0002383225380000032
得到第三概率分布pk
所述计算当前辅助模型基于所述预设样本输出的所述预设类别的分类结果对应的第四概率分布的步骤,包括:
计算多个所述辅助模型中第l个辅助模型基于第i个预设样本xi输出的第m个预设类别的分类结果对应的概率
Figure FDA0002383225380000033
组合每个所述概率
Figure FDA0002383225380000034
得到第l个所述辅助模型对应的第四概率分布pl
9.根据权利要求1所述的方法,其特征在于,所述损失函数还用于:
根据所述网络模型基于预设样本输出的所述预设类别的分类结果,以及所述预设样本携带的类别标签,确定第二损失值。
10.根据权利要求9所述的方法,其特征在于,所述根据所述网络模型基于预设样本输出的所述预设类别的分类结果,以及所述预设样本携带的类别标签,确定第二损失值的步骤,包括:
计算所述网络模型基于预设样本输出的所述预设类别的分类结果,与所述预设样本携带的类别标签的交叉熵;
将所述交叉熵确定为第二损失值。
11.根据权利要求9所述的方法,其特征在于,所述网络模型,具体通过下述方式训练得到:
确定样本集合;所述样本集合中的每个样本携带有类别标签;
将所述样本集合分别输入至所述网络模型和所述辅助模型中,得到所述网络模型输出的所述预设类别的分类结果,和所述辅助模型输出的所述预设类别的分类结果;
通过所述损失函数,确定所述第一损失值和所述第二损失值;
根据所述第一损失值和所述第二损失值,训练所述网络模型和所述辅助模型;继续执行确定样本集合的步骤,直到所述第一损失值和所述第二损失值收敛,得到训练后的所述网络模型。
12.根据权利要求11所述的方法,其特征在于,根据所述第一损失值和所述第二损失值,训练所述网络模型和所述辅助模型的步骤,包括:
根据所述第一损失值和所述第二损失值,调整预设第一网络的参数;其中,所述第一网络为所述网络模型或所述辅助模型;
将所述样本集合中的样本输入至所述网络模型和所述辅助模型中,得到所述网络模型输出的所述预设类别的分类结果,和每个所述辅助模型输出的所述预设类别的分类结果;
通过所述损失函数,确定第三损失值和第四损失值;
根据所述第三损失值和所述第四损失值,调整预设第二网络的参数;其中,当所述第一网络为所述网络模型时,所述第二网络为所述辅助模型;当所述第一网络为所述辅助模型时,所述第二网络为所述网络模型。
13.一种目标分类装置,其特征在于,所述装置包括:
数据获取模块,用于获取待处理数据;所述待处理数据中包括待分类目标;
数据处理模块,用于将所述待处理数据输入至预先训练完成的网络模型中,得到所述待分类目标的分类结果;其中,所述网络模型通过预设的辅助模型和预设的损失函数训练得到:所述损失函数用于:在训练所述网络模型和所述辅助模型的过程中,根据所述辅助模型基于预设样本输出的预设类别的分类结果,以及所述网络模型基于所述预设样本输出的所述预设类别的分类结果,确定所述网络模型的第一损失值。
14.一种电子系统,其特征在于,所述电子系统包括:处理设备和存储装置;
所述存储装置上存储有计算机程序,所述计算机程序在被所述处理设备运行时执行如权利要求1至12任一项所述的目标分类方法。
15.一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,其特征在于,所述计算机程序被处理设备运行时执行如权利要求1至12任一项所述的目标分类方法的步骤。
CN202010089737.0A 2020-02-12 2020-02-12 目标分类方法、装置和电子系统 Active CN111310823B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010089737.0A CN111310823B (zh) 2020-02-12 2020-02-12 目标分类方法、装置和电子系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010089737.0A CN111310823B (zh) 2020-02-12 2020-02-12 目标分类方法、装置和电子系统

Publications (2)

Publication Number Publication Date
CN111310823A true CN111310823A (zh) 2020-06-19
CN111310823B CN111310823B (zh) 2024-03-29

Family

ID=71147054

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010089737.0A Active CN111310823B (zh) 2020-02-12 2020-02-12 目标分类方法、装置和电子系统

Country Status (1)

Country Link
CN (1) CN111310823B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112733539A (zh) * 2020-12-30 2021-04-30 平安科技(深圳)有限公司 面试实体识别模型训练、面试信息实体提取方法及装置

Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108229652A (zh) * 2017-11-28 2018-06-29 北京市商汤科技开发有限公司 神经网络模型迁移方法和系统、电子设备、程序和介质
WO2018212584A2 (ko) * 2017-05-16 2018-11-22 삼성전자 주식회사 딥 뉴럴 네트워크를 이용하여 문장이 속하는 클래스를 분류하는 방법 및 장치
CN109784537A (zh) * 2018-12-14 2019-05-21 北京达佳互联信息技术有限公司 广告点击率的预估方法、装置及服务器和存储介质
WO2019105157A1 (zh) * 2017-11-30 2019-06-06 腾讯科技(深圳)有限公司 摘要描述生成方法、摘要描述模型训练方法和计算机设备
CN110309922A (zh) * 2019-06-18 2019-10-08 北京奇艺世纪科技有限公司 一种网络模型训练方法和装置
CN110348563A (zh) * 2019-05-30 2019-10-18 平安科技(深圳)有限公司 神经网络半监督训练方法、装置、服务器及存储介质
CN110427466A (zh) * 2019-06-12 2019-11-08 阿里巴巴集团控股有限公司 用于问答匹配的神经网络模型的训练方法和装置
CN110659665A (zh) * 2019-08-02 2020-01-07 深圳力维智联技术有限公司 一种异维特征的模型构建方法及图像识别方法、装置

Patent Citations (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018212584A2 (ko) * 2017-05-16 2018-11-22 삼성전자 주식회사 딥 뉴럴 네트워크를 이용하여 문장이 속하는 클래스를 분류하는 방법 및 장치
CN108229652A (zh) * 2017-11-28 2018-06-29 北京市商汤科技开发有限公司 神经网络模型迁移方法和系统、电子设备、程序和介质
WO2019105157A1 (zh) * 2017-11-30 2019-06-06 腾讯科技(深圳)有限公司 摘要描述生成方法、摘要描述模型训练方法和计算机设备
CN109784537A (zh) * 2018-12-14 2019-05-21 北京达佳互联信息技术有限公司 广告点击率的预估方法、装置及服务器和存储介质
CN110348563A (zh) * 2019-05-30 2019-10-18 平安科技(深圳)有限公司 神经网络半监督训练方法、装置、服务器及存储介质
CN110427466A (zh) * 2019-06-12 2019-11-08 阿里巴巴集团控股有限公司 用于问答匹配的神经网络模型的训练方法和装置
CN110309922A (zh) * 2019-06-18 2019-10-08 北京奇艺世纪科技有限公司 一种网络模型训练方法和装置
CN110659665A (zh) * 2019-08-02 2020-01-07 深圳力维智联技术有限公司 一种异维特征的模型构建方法及图像识别方法、装置

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
陈鹏飞;应自炉;朱健菲;商丽娟;: "面向手写汉字识别的残差深度可分离卷积算法" *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112733539A (zh) * 2020-12-30 2021-04-30 平安科技(深圳)有限公司 面试实体识别模型训练、面试信息实体提取方法及装置

Also Published As

Publication number Publication date
CN111310823B (zh) 2024-03-29

Similar Documents

Publication Publication Date Title
CN110880036B (zh) 神经网络压缩方法、装置、计算机设备及存储介质
US11893781B2 (en) Dual deep learning architecture for machine-learning systems
US11468262B2 (en) Deep network embedding with adversarial regularization
KR102180994B1 (ko) 적응적 인공 신경 네트워크 선택 기법들
CN110147700B (zh) 视频分类方法、装置、存储介质以及设备
CN111414987B (zh) 神经网络的训练方法、训练装置和电子设备
WO2021238262A1 (zh) 一种车辆识别方法、装置、设备及存储介质
JP7266674B2 (ja) 画像分類モデルの訓練方法、画像処理方法及び装置
US20220215259A1 (en) Neural network training method, data processing method, and related apparatus
CN110633745A (zh) 一种基于人工智能的图像分类训练方法、装置及存储介质
WO2020238353A1 (zh) 数据处理方法和装置、存储介质及电子装置
CN106778910B (zh) 基于本地训练的深度学习系统和方法
CN111401521A (zh) 神经网络模型训练方法及装置、图像识别方法及装置
WO2018220700A1 (ja) 新規学習データセット生成方法、新規学習データセット生成装置および生成された学習データセットを用いた学習方法
CN108492301A (zh) 一种场景分割方法、终端及存储介质
US11941867B2 (en) Neural network training using the soft nearest neighbor loss
CN111564179A (zh) 一种基于三元组神经网络的物种生物学分类方法及系统
CN111275780B (zh) 人物图像的生成方法及装置
CN113782093B (zh) 一种基因表达填充数据的获取方法及装置、存储介质
CN111310823B (zh) 目标分类方法、装置和电子系统
CN112000803B (zh) 文本分类方法及装置、电子设备及计算机可读存储介质
CN113011532A (zh) 分类模型训练方法、装置、计算设备及存储介质
TWI803243B (zh) 圖像擴增方法、電腦設備及儲存介質
EP1837807A1 (en) Pattern recognition method
CN113312445B (zh) 数据处理方法、模型构建方法、分类方法及计算设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant