CN112232445B - 多标签分类任务网络的训练方法和装置 - Google Patents
多标签分类任务网络的训练方法和装置 Download PDFInfo
- Publication number
- CN112232445B CN112232445B CN202011441233.7A CN202011441233A CN112232445B CN 112232445 B CN112232445 B CN 112232445B CN 202011441233 A CN202011441233 A CN 202011441233A CN 112232445 B CN112232445 B CN 112232445B
- Authority
- CN
- China
- Prior art keywords
- network
- label
- layer
- feature extraction
- weight
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/243—Classification techniques relating to the number of classes
- G06F18/2431—Multiple classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请提出多标签分类任务网络的训练方法和装置,其中方法包括:训练用于构建多标签分类任务网络的种子网络;采用种子网络创建多标签分类任务网络的搜索空间,其中,搜索空间中包括各个标签类别对应的搜索路径;优化搜索空间的网络权重和路径权重;根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构;训练多标签分类任务网络。本申请实施例提供了高效的多标签分类网络,既能减少算法的计算量,又能大大提高算法的精度,提高算法的竞争力,并增加算法的应用场景。
Description
技术领域
本申请涉及深度学习技术领域,尤其涉及多标签分类任务网络的训练方法和装置。
背景技术
随着深度学习的发展,出现了越来越多的多标签分类方法。例如,对人脸属性的分析、面部运动单元的检测、人体属性的分析等。采用多标签分类方法,可以通过多标签对实体进行多方位的标签分析,得出更加科学的结论。
目前,对于多标签分类主要采用两种方法解决。第一种,针对输入分类网络的图像,由全部的分类任务通过共享权重提取图像特征,最后采用全联接层输出最终结果。第二种,先采用共享权重提取图像特征,再设计分支结构提取每个标签类别的独特特征,最终通过分支输出各自的结果。
上述方法一只能提取全部的特征,无法分析每个类别需要的特有信息。方法二可以提取每个标签类别需要的特有信息,但是设定分支的位置是依靠人的先验知识来设定的,无法权衡位置是否是全局和局部信息提取的最佳位置,也无法确定哪些分支是否有关联性、是否可以合并来减少计算量等。
发明内容
本申请实施例提供一种多标签分类任务网络的训练方法、装置、电子设备及存储介质,以解决相关技术存在的问题,技术方案如下:
第一方面,本申请实施例提供了一种多标签分类任务网络的训练方法,包括:
训练用于构建多标签分类任务网络的种子网络;
采用种子网络创建多标签分类任务网络的搜索空间,其中,搜索空间中包括各个标签类别对应的搜索路径;
优化搜索空间的网络权重和路径权重;
根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构;
训练多标签分类任务网络。
在一种实施方式中,种子网络创建多标签分类任务网络的搜索空间,包括:
将所述种子网络的网络权重载入所述多标签分类任务网络的共享特征提取网络中,并冻结所述共享特征提取网络中的网络权重;
生成m个分离网络,所述分离网络的网络结构与所述种子网络的网络结构相同;其中,所述m为标签类别的个数;
将共享特征提取网络分别与各个分离网络组成多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个标签类别。
在一种实施方式中,子网络对应的搜索路径的每一层是一个分支选择单元;
分支选择单元包括共享特征提取网络的对应层和分离网络的对应层。
在一种实施方式中,采用种子网络创建多标签分类任务网络的搜索空间,还包括:
针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、共享特征提取网络的第l-1层和分离网络的第l-1层,确定标签类别所对应的子网络的第l层的输入内容;其中,l为自然数。
在一种实施方式中,优化搜索空间的网络权重和路径权重,包括:
固定路径权重,在训练集上优化网络权重;
固定网络权重,在验证集上优化路径权重;
判断网络损失是否收敛,如果未收敛,则重复执行固定路径权重、在训练集上优化网络权重的步骤;如果收敛,则结束优化过程。
在一种实施方式中,根据优化后的路径权重及多标签分类任务网络的搜索空间,确定所述多标签分类任务网络的最终结构,包括:
根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
在一种实施方式中,根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层,包括:
针对各个标签类别的各层:
在所述层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层。
在一种实施方式中,根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构,还包括:
针对各个标签类别,以分离网络和共享特征提取网络的第一个叉节点为起点,搭建标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
在一种实施方式中,训练多标签分类任务网络,包括:
冻结多标签分类任务网络中的共享特征提取网络的网络权重;
在训练集和验证集上,训练多标签分类任务网络中的分离网络的网络权重。
第二方面,本申请实施例提供了一种多标签分类任务网络的训练装置,包括:
种子网络训练模块,用于训练用于构建多标签分类任务网络的种子网络;
搜索空间创建模块,用于采用种子网络创建多标签分类任务网络的搜索空间,其中,搜索空间中包括各个标签类别对应的搜索路径;
优化模块,用于优化搜索空间的网络权重和路径权重;
确定模块,用于根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构;
分类网络训练模块,用于训练多标签分类任务网络。
在一种实施方式中,搜索空间创建模块,包括:
载入子模块,用于将种子网络的网络权重载入多标签分类任务网络的共享特征提取网络中,并冻结共享特征提取网络中的网络权重;
生成子模块,用于生成m个分离网络,分离网络的网络结构与种子网络的网络结构相同;其中,m为标签类别的个数;
搭建子模块,用于将共享特征提取网络分别与各个分离网络组成多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个标签类别。
在一种实施方式中,子网络对应的搜索路径的每一层是一个分支选择单元;
分支选择单元包括共享特征提取网络的对应层和分离网络的对应层。
在一种实施方式中,搜索空间创建模块,还包括:
连续化子模块,用于针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、共享特征提取网络的第l-1层和分离网络的第l-1层,确定标签类别所对应的子网络的第l层的输入内容;其中,l为自然数。
在一种实施方式中,优化模块,包括:
网络权重优化子模块,用于固定路径权重,在训练集上优化网络权重;
路径权重优化子模块,用于固定网络权重,在验证集上优化路径权重;
收敛判断子模块,用于判断网络损失是否收敛,如果未收敛,则指示网络权重优化子模块重新优化网络权重;如果收敛,则结束优化过程。
在一种实施方式中,确定模块,用于根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
在一种实施方式中,确定模块用于:
针对各个标签类别的各层,在所述层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层。
在一种实施方式中,确定模块还用于:
针对各个标签类别,以分离网络和共享特征提取网络的第一个叉节点为起点,搭建标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
在一种实施方式中,分类网络训练模块,包括:
冻结子模块,用于冻结多标签分类任务网络中的共享特征提取网络的网络权重;
训练子模块,用于在训练集和验证集上,训练多标签分类任务网络中的分离网络的网络权重。
第三方面,本申请实施例提供了一种电子设备,该电子设备包括:存储器和处理器。其中,该存储器和该处理器通过内部连接通路互相通信,该存储器用于存储指令,该处理器用于执行该存储器存储的指令,并且当该处理器执行该存储器存储的指令时,使得该处理器执行上述各方面任一种实施方式中的方法。
第四方面,本申请实施例提供了一种计算机可读存储介质,计算机可读存储介质存储计算机程序,当计算机程序在计算机上运行时,上述各方面任一种实施方式中的方法被执行。
上述技术方案中的优点或有益效果至少包括:本申请实施例提供了一种克服上述问题或者至少部分解决上述问题的高效的多标签分类任务网络训练方法,基于网络搜索方法和本申请实施例设计的搜索策略,设计出一种高效的多标签分类网络,既能减少算法的计算量,又能大大提高算法的精度,提高算法的竞争力,并增加算法的应用场景。
上述概述仅仅是为了说明书的目的,并不意图以任何方式进行限制。除上述描述的示意性的方面、实施方式和特征之外,通过参考附图和以下的详细描述,本申请进一步的方面、实施方式和特征将会是容易明白的。
附图说明
在附图中,除非另外规定,否则贯穿多个附图相同的附图标记表示相同或相似的部件或元素。这些附图不一定是按照比例绘制的。应该理解,这些附图仅描绘了根据本申请公开的一些实施方式,而不应将其视为是对本申请范围的限制。
图1为本申请实施例提出的一种多标签分类任务网络的训练方法的实现流程示意图;
图2为本申请实施例提出的一种多标签分类任务网络的训练方法中,步骤S102的实现流程示意图;
图3A为本申请实施例构建多标签分类任务网络的过程中,构建的共享特征提取网络示意图;
图3B为本申请实施例构建多标签分类任务网络的过程中,构建的各个子网络对应的搜索路径示意图;
图3C为本申请实施例构建多标签分类任务网络的过程中,经训练后的每个标签类别对应的搜索路径示意图;
图3D为本申请实施例构建多标签分类任务网络的过程中,共享特征提取网络的最终结构示意图;
图4为本申请实施例提出的一种多标签分类任务网络的训练方式的具体实现流程示意图;
图5为本申请实施例的一种多标签分类任务网络的训练装置500的结构示意图
图6为本申请实施例的一种多标签分类任务网络的训练装置600的结构示意图;
图7为本申请实施例的一种电子设备结构示意图。
具体实施方式
在下文中,仅简单地描述了某些示例性实施例。正如本领域技术人员可认识到的那样,在不脱离本申请的精神或范围的情况下,可通过各种不同方式修改所描述的实施例。因此,附图和描述被认为本质上是示例性的而非限制性的。
本申请实施例提出一种多标签分类任务网络的训练方法。图1为本申请实施例提出的一种多标签分类任务网络的训练方法的实现流程示意图,包括:
步骤S101:训练用于构建多标签分类任务网络的种子网络;
步骤S102:采用种子网络创建多标签分类任务网络的搜索空间,其中,搜索空间中包括各个标签类别对应的搜索路径;
步骤S103:优化搜索空间的网络权重和路径权重;
步骤S104:根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构;
步骤S105:训练多标签分类任务网络。
可选地,本申请实施例可以将所有数据划分为训练集、验证集和测试集,例如,按照3:1:1的比例分成训练集、验证集和测试集。本申请实施例使用的数据可以是图像,例如人脸图像等,相应的,训练完成的多标签分类任务网络用于实现对图像进行多标签分类。步骤S101中,可以采用训练集和验证集训练种子网络,使种子网络模型收敛。
图2为本申请实施例提出的一种多标签分类任务网络的训练方法中,步骤S102的实现流程示意图。如图2所示,上述步骤S102可以包括:
步骤S201:将种子网络的网络权重载入多标签分类任务网络的共享特征提取网络中,并冻结共享特征提取网络中的网络权重。
步骤S202:生成m个分离网络,该分离网络的网络结构与种子网络的网络结构相同;其中,m为标签类别的个数。这一过程可以认为是复制种子网络,即分离网络的网络结构与种子网络的网络结构相同,但网络权重不同。
步骤S203:将共享特征提取网络分别与分离网络组成多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个标签类别。
例如,步骤S202生成m个分离网络,每个分离网络对应一个标签类别,各个分离网络的编号分别为1、2、…、m。步骤S203中将共享特征提取网络分别与各个分离网络组成m个子网络对应的搜索路径,每个子网络对应一个标签类别,每个子网络的搜索路径包括共享特征提取网络和一个对应的分离网络。可见,共享特征提取网络是所有子网络的搜索路径所共有的,分离网络是每个子网络的搜索路径所独有的。
可选地,上述子网络对应的搜索路径的每一层是一个分支选择单元,每个分支选择单元包括共享特征提取网络的对应层和分离网络的对应层。
图3A至图3D显示了本申请实施例构建多标签分类任务网络的过程。其中,图3A为本申请实施例构建多标签分类任务网络的过程中,构建的共享特征提取网络示意图,图3B为本申请实施例构建多标签分类任务网络的过程中,构建的各个子网络对应的搜索路径示意图。如图3A和3B所示,在构建共享特征提取网络之后,根据标签类别的个数(图3A至3D示例中,标签类别的个数为4)复制4个分离网络。图3B中,中间的黑色圆形组成的网络表示共享特征提取网络,两边的4个由不同填充图案的圆形组成的网络表示4个分离网络。例如,图3B中共享特征提取网络与左侧第一列的分离网络组成标签类别1对应的子网络(简称子网络1)的搜索路径;该搜索路径的每一层对应一个分支选择单元,以第二层为例,该分支选择单元包括共享特征提取网络的第二层(由图3B中上数第二个黑色节点表示)和该分离网络的第二层(由图3B中左侧第一列网络的上数第一个节点表示)。图3B中共享特征提取网络与左侧第二列的分离网络组成标签类别2对应的子网络(简称子网络2)的搜索路径,图3B中共享特征提取网络与右侧第一列的分离网络组成标签类别3对应的子网络(简称子网络3)的搜索路径,图3B中共享特征提取网络与右侧第二列的分离网络组成标签类别4对应的子网络(简称子网络4)的搜索路径。子网络2、子网络3及子网络4的搜索路径与子网络1的搜索路径的结构类似,在此不再赘述。
后续地,对各个子网络的搜索路径的路径权重进行优化,根据优化结果,在每个分支选择单元中选择共享特征提取网络的对应层、或选择分离网络的对应层,得到经训练后的每个标签类别对应的搜索路径,并对网络进行重建,得到多标签分类任务网络的最终结构。
图3C为本申请实施例构建多标签分类任务网络的过程中,经训练后的每个标签类别对应的搜索路径示意图。如图3C所示,从每个标签类别对应的搜索路径的各个分支选择单元中,选择共享特征提取网络中的节点或分离网络中的节点,组成经训练后的搜索路径,图3C中的实线箭头连接了每个标签类别对应的经训练后的搜索路径。针对图3C中的每个标签类别对应的经训练后的搜索路径,本申请实施例可以以第一个叉节点为起点,将共享特征提取网络中的节点复制出来,分别搭建分支网络,得到共享特征提取网络的最终结构。图3D为本申请实施例构建多标签分类任务网络的过程中,共享特征提取网络的最终结构示意图。如图3D所示,标签类别1对应的搜索路径中,上数第三个黑色圆形代表的节点为第一个叉节点,以该第一叉节点为起点,将共享特征提取网络中的节点复制出来,构成标签类别1的分支网络。
为了实现对搜索路径的路径权重进行优化,本申请实施例可用使搜索空间连续,将路径选择变为一个混合操作,具体地,上述步骤S102可以进一步包括:
针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、共享特征提取网络的第l-1层和分离网络的第l-1层,确定该标签类别所对应的子网络的第l层的输入内容;其中,l为自然数。
其中,为l-1层选择路径权重,是通过计算出来的,()是指通过共享特征网络或自身的子网络对第j类标签的第(l-1)层网络的输入进行特征提取(即i为0或1),为第j类标签的第(l-1)层网络的输入,exp为以自然常数e为底的指数函数。
可选地,上述步骤S103中,优化搜索空间的网络权重和路径权重,包括:
固定路径权重,在训练集上优化网络权重;
固定网络权重,在验证集上优化路径权重;
判断网络损失是否收敛,如果未收敛,则重复执行固定路径权重、使用训练集优化网络权重的步骤;如果收敛,则结束优化过程,开始执行步骤S104。
上述步骤S104可以包括:根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
所有子网络的组合即为最终确定的多标签分类任务网络的结构。后续即可对该多标签分类任务网络进行训练,即执行上述步骤S105。
可选地,上述步骤S105包括:
冻结多标签分类任务网络中的共享特征提取网络的网络权重;
在训练集和验证集上,训练多标签分类任务网络中的分离网络的网络权重。
以下结合附图4,介绍一个具体的实施方式。图4为本申请实施例提出的一种多标签分类任务网络的训练方式的具体实现流程示意图,包括以下步骤:
S410:将所有数据按照3:1:1的比例分成训练集、验证集和测试集。
S420:以训练集和验证集为基础数据,训练一个种子网络(如resnet18,mobilenetv2等),按照二分类交叉熵(BCE,Binary Cross Entropy)损失,使模型收敛。如采用下式计算BCE损失函数:
S430:将已训练好种子网络的权重按照对应层载入到多标签分类任务网络(或称超网络)的共享特征提取网络(简称共享特征网络)中,并冻结这部分网络权重。
S440:创建多标签分类任务网络的搜索空间和参数化搜索路径。具体包括步骤S441和S442。
S441:将种子网络作为超网络中共享特征提取网络,并根据标签分类的个数m,复制m个分离网络。
上述复制可以是将分离网络的网络结构设置为与种子网络的网络结构相同,但二者的网络参数存在差异。每个分离网络包含n个卷积层,n越大,搜索的计算量越大,但搜索精度越高;一般会权衡计算量和精度来配置卷积层的层数。在本实施例中,选择n=3。每个分离网络的卷积层数也可以不同,如果不同标签类别的分离网络的卷积层数不同,则可以按照网络底部输出对齐。按照网络底部输出对齐,依次将分离网络的每层搭建在共享特征提取网络上(搭建方式就是普通的连接,以组成分支),组成超网络。
如图3A至图3D显示了本申请实施例构建多标签分类任务网络的过程。其中,图3A表示共享特征提取网络示意图,共享特征提取网络由种子网络构成。图3B表示超网络,超网络包括1个共享特征提取网络(加载种子网络得到的)和m个分离网络;每个分离网络对应一个标签类别,其中,每个分离网络和共享特征提取网络的对应层组成一个分支选择单元。图3C表示经训练后的每个标签类别的搜索路径。图3D表示经上述步骤S490步骤处理后的示意图,即图3C中的每个标签类别对应的搜索路径,以第一个叉节点为起点,将共享特征提取网络中的节点复制出来,分别搭建分支网络。
S442:在网络搜索中,每个标签类别的每个分支选择单元有两种选择,一个是通过共享特征网络的对应层提取特征,另一个是通过自身的分离网络的对应层提取特征。鉴于此,为了使搜索空间连续,本申请实施例将路径选择变为一个混合操作,例如:将第j类标签的第l层网络的输入设置为 ,采用以下式子确定:
其中,为l-1层选择路径权重,是通过计算出来的,()是指通过共享特征网络或自身的子网络对第j类标签的第(l-1)层网络的输入进行特征提取(即i为0或1),为第j类标签的第(l-1)层网络的输入,exp为以自然常数e为底的指数函数。
其中,表示采用梯度下降算法进行优化;
L表示损失值,例如步骤S1中计算得到的损失值;
train表示训练集。
val表示验证集。
S470:判断网络损失是否收敛,如果未收敛,则重复执行步骤S450和S460;如果收敛,则执行步骤S480。
S480:根据路径权重,按照规则选择每个类别j、每层l对应的操作,组成最终网络,规则如下;
由上述式子可见,在本申请实施例中,根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。具体可以包括:
针对各个标签类别的各层:
在该层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为该标签类别对应的子网络的该层;
在该层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述该层;
在该层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的该层。
举例来说:若标签类别j的第l层对应的两条分支(即分离网络的对应层和共享特征提取网络的对应层)中自身分离网络的路径权重大于预设阈值,则通过自身的分离网络的对应层提取特征;如果是共享特征提取网络的路径权重大于预设阈值,则通过共享特征提取网络的对应层提取特征。如果两条分支的路径权重均不大于预设阈值,则通过共享特征提取网络的对应层提取特征。
S490:根据最终搜索的网络,每个标签类别按照分离网络以第一个叉节点为起点,搭建分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。即,针对各个标签类别,以分离网络和共享特征提取网络的第一个叉节点为起点,搭建该标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
例如,图3D中标签类别3对应的分支网络中,共享特征提取网络的下数第二个节点的网络权重被复制到了分离网络对应的位置。
S491:冻结共享特征提取网络部分的网络权重,在训练集和验证集上,重新训练分离网络部分的网络权重,使网络收敛。例如,可以按照BCE损失使网络收敛。
本申请实施例还提出一种多标签分类任务网络的训练装置,图5为本申请实施例的一种多标签分类任务网络的训练装置500的结构示意图,包括:
种子网络训练模块510,用于训练用于构建多标签分类任务网络的种子网络;
搜索空间创建模块520,用于采用种子网络创建多标签分类任务网络的搜索空间,其中,搜索空间中包括各个标签类别对应的搜索路径;
优化模块530,用于优化搜索空间的网络权重和路径权重;
确定模块540,用于根据优化后的路径权重及多标签分类任务网络的搜索空间,确定多标签分类任务网络的最终结构;
分类网络训练模块550,用于训练多标签分类任务网络。
可选地,上述搜索空间创建模块520,包括:
载入子模块521,用于将种子网络的网络权重载入多标签分类任务网络的共享特征提取网络中,并冻结共享特征提取网络中的网络权重;
生成子模块522,用于生成m个分离网络,分离网络的网络结构与种子网络的网络结构相同;其中,m为标签类别的个数;
搭建子模块523,用于将共享特征提取网络分别与各个分离网络组成多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个标签类别。
可选地,上述子网络对应的搜索路径的每一层是一个分支选择单元;
分支选择单元包括共享特征提取网络的对应层和分离网络的对应层。
可选地,上述搜索空间创建模块520,还包括:
连续化子模块524,用于针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、共享特征提取网络的第l-1层和分离网络的第l-1层,确定标签类别所对应的子网络的第l层的输入内容;其中,l为自然数。
可选地,上述优化模块530,包括:
网络权重优化子模块531,用于固定路径权重,在训练集上优化网络权重;
路径权重优化子模块532,用于固定网络权重,在验证集上优化路径权重;
收敛判断子模块533,用于判断网络损失是否收敛,如果未收敛,则指示网络权重优化子模块重新优化网络权重;如果收敛,则结束优化过程。
可选地,上述确定模块540,用于根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
可选地,上述确定模块540用于:
针对各个标签类别的各层,在所述层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为标签类别对应的子网络的所述层。
可选地,上述确定模块540还用于:
针对各个标签类别,以分离网络和共享特征提取网络的第一个叉节点为起点,搭建标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
可选地,上述分类网络训练模块550,包括:
冻结子模块551,用于冻结多标签分类任务网络中的共享特征提取网络的网络权重;
训练子模块552,用于在训练集和验证集上,训练多标签分类任务网络中的分离网络的网络权重。
本发明实施例各装置中的各模块的功能可以参见上述多标签分类任务网络的训练方法中的对应描述,在此不再赘述。
图7为本申请实施例的一种电子设备结构示意图,包括:存储器710和处理器720,存储器710内存储有可在处理器720上运行的计算机程序。处理器720执行该计算机程序时实现上述实施例中的汉字拼音转换方法或汉字拼音转换模型的训练方法。存储器710和处理器720的数量可以为一个或多个。
该自动评分设备还包括:
通信接口730,用于与外界设备进行通信,进行数据交互传输。
如果存储器710、处理器720和通信接口730独立实现,则存储器710、处理器720和通信接口730可以通过总线相互连接并完成相互间的通信。该总线可以是工业标准体系结构(Industry Standard Architecture,ISA)总线、外部设备互连(PeripheralComponentInterconnect,PCI)总线或扩展工业标准体系结构(Extended IndustryStandard Architecture,EISA)总线等。该总线可以分为地址总线、数据总线、控制总线等。为便于表示,图7中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
可选的,在具体实现上,如果存储器710、处理器720及通信接口730集成在一块芯片上,则存储器710、处理器720及通信接口730可以通过内部接口完成相互间的通信。
本发明实施例提供了一种计算机可读存储介质,其存储有计算机程序,该程序被处理器执行时实现本申请实施例中提供的方法。
本申请实施例还提供了一种芯片,包括:输入接口、输出接口、处理器和存储器,输入接口、输出接口、处理器以及存储器之间通过内部连接通路相连,处理器用于执行存储器中的代码,当代码被执行时,处理器用于执行申请实施例提供的方法。
应理解的是,上述处理器可以是中央处理器(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(digital signal processing,DSP)、专用集成电路(application specific integrated circuit,ASIC)、现场可编程门阵列(fieldprogrammablegate array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者是任何常规的处理器等。值得说明的是,处理器可以是支持进阶精简指令集机器(advanced RISC machines,ARM)架构的处理器。
进一步地,可选的,上述存储器可以包括只读存储器和随机存取存储器,还可以包括非易失性随机存取存储器。该存储器可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以包括只读存储器(read-onlymemory,ROM)、可编程只读存储器(programmable ROM,PROM)、可擦除可编程只读存储器(erasable PROM,EPROM)、电可擦除可编程只读存储器(electrically EPROM,EEPROM)或闪存。易失性存储器可以包括随机存取存储器(random access memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用。例如,静态随机存取存储器(static RAM,SRAM)、动态随机存取存储器(dynamic random access memory ,DRAM) 、同步动态随机存取存储器(synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(double data date SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(enhancedSDRAM,ESDRAM)、同步连接动态随机存取存储器(synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(direct rambus RAM,DR RAM)。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行计算机程序指令时,全部或部分地产生按照本申请的流程或功能。计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包括于本申请的至少一个实施例或示例中。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或隐含地包括至少一个该特征。在本申请的描述中,“多个”的含义是两个或两个以上,除非另有明确具体的限定。
流程图中或在此以其他方式描述的任何过程或方法描述可以被理解为,表示包括一个或更多个用于实现特定逻辑功能或过程的步骤的可执行指令的代码的模块、片段或部分。并且本申请的优选实施方式的范围包括另外的实现,其中可以不按所示出或讨论的顺序,包括根据所涉及的功能按基本同时的方式或按相反的顺序,来执行功能。
在流程图中表示或在此以其他方式描述的逻辑和/或步骤,例如,可以被认为是用于实现逻辑功能的可执行指令的定序列表,可以具体实现在任何计算机可读介质中,以供指令执行系统、装置或设备(如基于计算机的系统、包括处理器的系统或其他可以从指令执行系统、装置或设备取指令并执行指令的系统)使用,或结合这些指令执行系统、装置或设备而使用。
应理解的是,本申请的各部分可以用硬件、软件、固件或它们的组合来实现。在上述实施方式中,多个步骤或方法可以用存储在存储器中且由合适的指令执行系统执行的软件或固件来实现。上述实施例方法的全部或部分步骤是可以通过程序来指令相关的硬件完成,该程序可以存储于一种计算机可读存储介质中,该程序在执行时,包括方法实施例的步骤之一或其组合。
此外,在本申请各个实施例中的各功能单元可以集成在一个处理模块中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。上述集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读存储介质中。该存储介质可以是只读存储器,磁盘或光盘等。
以上,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到其各种变化或替换,这些都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。
Claims (18)
1.一种多标签分类任务网络的训练方法,其特征在于,包括:
训练用于构建多标签分类任务网络的种子网络;
采用所述种子网络创建多标签分类任务网络的搜索空间,其中,所述搜索空间中包括各个标签类别对应的搜索路径;
优化所述搜索空间的网络权重和路径权重;
根据优化后的路径权重及所述多标签分类任务网络的搜索空间,确定所述多标签分类任务网络的最终结构;
训练所述多标签分类任务网络;
其中,所述采用所述种子网络创建多标签分类任务网络的搜索空间,包括:
将所述种子网络的网络权重载入所述多标签分类任务网络的共享特征提取网络中,并冻结所述共享特征提取网络中的网络权重;
生成m个分离网络,所述分离网络的网络结构与所述种子网络的网络结构相同;其中,所述m为标签类别的个数;
将所述共享特征提取网络分别与所述各个分离网络组成所述多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个所述标签类别。
2.根据权利要求1所述的方法,其特征在于,所述子网络对应的搜索路径的每一层是一个分支选择单元;
所述分支选择单元包括所述共享特征提取网络的对应层和所述分离网络的对应层。
3.根据权利要求2所述的方法,其特征在于,所述采用所述种子网络创建多标签分类任务网络的搜索空间,还包括:
针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、所述共享特征提取网络的第l-1层和所述分离网络的第l-1层,确定所述标签类别所对应的子网络的第l层的输入内容;其中,所述l为自然数。
4.根据权利要求1或2所述的方法,其特征在于,所述优化所述搜索空间的网络权重和路径权重,包括:
固定所述路径权重,在训练集上优化所述网络权重;
固定所述网络权重,在验证集上优化所述路径权重;
判断网络损失是否收敛,如果未收敛,则重复执行固定路径权重、在训练集上优化网络权重的步骤;如果收敛,则结束优化过程。
5.根据权利要求2所述的方法,其特征在于,所述根据优化后的路径权重及所述多标签分类任务网络的搜索空间,确定所述多标签分类任务网络的最终结构,包括:
根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
6.根据权利要求5所述的方法,其特征在于,所述根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层,包括:
针对各个标签类别的各层:
在所述层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为所述标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为所述标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为所述标签类别对应的子网络的所述层。
7.根据权利要求5或6所述的方法,其特征在于,所述根据优化后的路径权重及所述多标签分类任务网络的搜索空间,确定所述多标签分类任务网络的最终结构,还包括:
针对各个标签类别,以所述分离网络和所述共享特征提取网络的第一个叉节点为起点,搭建所述标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
8.根据权利要求1或2所述的方法,其特征在于,所述训练所述多标签分类任务网络,包括:
冻结所述多标签分类任务网络中的所述共享特征提取网络的网络权重;
在训练集和验证集上,训练所述多标签分类任务网络中的所述分离网络的网络权重。
9.一种多标签分类任务网络的训练装置,其特征在于,包括:
种子网络训练模块,用于训练用于构建多标签分类任务网络的种子网络;
搜索空间创建模块,用于采用所述种子网络创建多标签分类任务网络的搜索空间,其中,所述搜索空间中包括各个标签类别对应的搜索路径;
优化模块,用于优化所述搜索空间的网络权重和路径权重;
确定模块,用于根据优化后的路径权重及所述多标签分类任务网络的搜索空间,确定所述多标签分类任务网络的最终结构;
分类网络训练模块,用于训练所述多标签分类任务网络;
其中,所述搜索空间创建模块,包括:
载入子模块,用于将所述种子网络的网络权重载入所述多标签分类任务网络的共享特征提取网络中,并冻结所述共享特征提取网络中的网络权重;
生成子模块,用于生成m个分离网络,所述分离网络的网络结构与所述种子网络的网络结构相同;其中,所述m为标签类别的个数;
搭建子模块,用于将所述共享特征提取网络分别与所述各个分离网络组成所述多标签分类任务网络的m个子网络对应的搜索路径,每个子网络对应一个所述标签类别。
10.根据权利要求9所述的装置,其特征在于,所述子网络对应的搜索路径的每一层是一个分支选择单元;
所述分支选择单元包括所述共享特征提取网络的对应层和所述分离网络的对应层。
11.根据权利要求9或10所述的装置,其特征在于,所述搜索空间创建模块,还包括:
连续化子模块,用于针对任意标签类别,根据第l-1层的输入内容、第l-1层的路径权重、所述共享特征提取网络的第l-1层和所述分离网络的第l-1层,确定所述标签类别所对应的子网络的第l层的输入内容;其中,所述l为自然数。
12.根据权利要求9或10所述的装置,其特征在于,所述优化模块,包括:
网络权重优化子模块,用于固定所述路径权重,在训练集上优化所述网络权重;
路径权重优化子模块,用于固定所述网络权重,在验证集上优化所述路径权重;
收敛判断子模块,用于判断网络损失是否收敛,如果未收敛,则指示所述网络权重优化子模块重新优化网络权重;如果收敛,则结束优化过程。
13.根据权利要求10所述的装置,其特征在于,所述确定模块,用于根据优化后的路径权重及预设规则,针对各个标签类别对应的子网络,选择子网络的各层为共享特征提取网络的对应层或分离网络的对应层。
14.根据权利要求13所述的装置,其特征在于,所述确定模块用于:
针对各个标签类别的各层,在所述层的优化后的路径权重中,针对分离网络的路径权重大于预设阈值的情况下,选择分离网络的对应层作为所述标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对共享特征提取网络的路径权重大于预设阈值的情况下,选择共享特征提取网络的对应层作为所述标签类别对应的子网络的所述层;
在所述层的优化后的路径权重中,针对分离网络的路径权重和针对共享特征提取网络的路径权重均不大于预设阈值的情况下,选择共享特征提取网络的对应层作为所述标签类别对应的子网络的所述层。
15.根据权利要求13或14所述的装置,其特征在于,所述确定模块还用于:
针对各个标签类别,以所述分离网络和所述共享特征提取网络的第一个叉节点为起点,搭建所述标签类别对应的分支网络,并将通过共享特征提取网络的节点的网络权重复制到分离网络对应的位置上。
16.根据权利要求9或10所述的装置,其特征在于,所述分类网络训练模块,包括:
冻结子模块,用于冻结所述多标签分类任务网络中的所述共享特征提取网络的网络权重;
训练子模块,用于在训练集和验证集上,训练所述多标签分类任务网络中的所述分离网络的网络权重。
17.一种电子设备,其特征在于,包括:包括处理器和存储器,所述存储器中存储指令,所述指令由处理器加载并执行,以实现如权利要求1至8任一项所述的方法。
18.一种计算机可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至8中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011441233.7A CN112232445B (zh) | 2020-12-11 | 2020-12-11 | 多标签分类任务网络的训练方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011441233.7A CN112232445B (zh) | 2020-12-11 | 2020-12-11 | 多标签分类任务网络的训练方法和装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112232445A CN112232445A (zh) | 2021-01-15 |
CN112232445B true CN112232445B (zh) | 2021-05-11 |
Family
ID=74124595
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011441233.7A Active CN112232445B (zh) | 2020-12-11 | 2020-12-11 | 多标签分类任务网络的训练方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112232445B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113298774B (zh) * | 2021-05-20 | 2022-10-18 | 复旦大学 | 一种基于对偶条件相容神经网络的图像分割方法、装置 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111666763A (zh) * | 2020-05-28 | 2020-09-15 | 平安科技(深圳)有限公司 | 用于多任务场景的网络结构构建方法和装置 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP2221805B1 (en) * | 2009-02-20 | 2014-06-25 | Nuance Communications, Inc. | Method for automated training of a plurality of artificial neural networks |
CN105046323B (zh) * | 2015-04-29 | 2017-03-22 | 西北大学 | 一种正则化rbf网络多标签分类方法 |
CN108985250A (zh) * | 2018-07-27 | 2018-12-11 | 大连理工大学 | 一种基于多任务网络的交通场景解析方法 |
CN110443189B (zh) * | 2019-07-31 | 2021-08-03 | 厦门大学 | 基于多任务多标签学习卷积神经网络的人脸属性识别方法 |
CN111723910A (zh) * | 2020-06-17 | 2020-09-29 | 腾讯科技(北京)有限公司 | 构建多任务学习模型的方法、装置、电子设备及存储介质 |
-
2020
- 2020-12-11 CN CN202011441233.7A patent/CN112232445B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111666763A (zh) * | 2020-05-28 | 2020-09-15 | 平安科技(深圳)有限公司 | 用于多任务场景的网络结构构建方法和装置 |
Also Published As
Publication number | Publication date |
---|---|
CN112232445A (zh) | 2021-01-15 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113033811B (zh) | 两量子比特逻辑门的处理方法及装置 | |
CN112487168B (zh) | 知识图谱的语义问答方法、装置、计算机设备及存储介质 | |
CN109284860A (zh) | 一种基于正交反向樽海鞘优化算法的预测方法 | |
CN111507459B (zh) | 降低神经网络的注解费用的方法和装置 | |
JP7332238B2 (ja) | タスク固有のデータ利用のための物理学により誘導されたディープマルチモーダル埋め込みのための方法及び装置 | |
CN115240786A (zh) | 反应物分子的预测方法、训练方法、装置以及电子设备 | |
CN111553466B (zh) | 信息处理方法、装置及设备 | |
CN112232445B (zh) | 多标签分类任务网络的训练方法和装置 | |
Peng et al. | Hierarchical visual-textual knowledge distillation for life-long correlation learning | |
CN114896395A (zh) | 语言模型微调方法、文本分类方法、装置及设备 | |
JP6230987B2 (ja) | 言語モデル作成装置、言語モデル作成方法、プログラム、および記録媒体 | |
JP6325762B1 (ja) | 情報処理装置、情報処理方法、および情報処理プログラム | |
Song et al. | Few-shot open-set recognition using background as unknowns | |
CN114492601A (zh) | 资源分类模型的训练方法、装置、电子设备及存储介质 | |
CN116756536B (zh) | 数据识别方法、模型训练方法、装置、设备及存储介质 | |
CN112232360A (zh) | 图像检索模型优化方法、图像检索方法、装置及存储介质 | |
CN116738983A (zh) | 模型进行金融领域任务处理的词嵌入方法、装置、设备 | |
WO2023197460A1 (zh) | 一种图像识别方法、装置及电子设备和存储介质 | |
CN115858725A (zh) | 一种基于无监督式图神经网络的文本噪声筛选方法及系统 | |
CN113516125B (zh) | 模型训练方法、使用方法、装置、设备及存储介质 | |
JP7041239B2 (ja) | 深層距離学習方法およびシステム | |
CN114547349A (zh) | 模型调整与业务处理的方法、装置、设备及存储介质 | |
CN113240087B (zh) | 图像生成模型构建方法、装置、介质及设备 | |
CN117829242B (zh) | 模型处理方法及相关设备 | |
CN114840764B (zh) | 服务于用户兴趣分析的大数据挖掘方法及云端ai部署系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |