CN109840531A - 训练多标签分类模型的方法和装置 - Google Patents

训练多标签分类模型的方法和装置 Download PDF

Info

Publication number
CN109840531A
CN109840531A CN201711187818.9A CN201711187818A CN109840531A CN 109840531 A CN109840531 A CN 109840531A CN 201711187818 A CN201711187818 A CN 201711187818A CN 109840531 A CN109840531 A CN 109840531A
Authority
CN
China
Prior art keywords
matrix
label
mapping
samples
network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN201711187818.9A
Other languages
English (en)
Other versions
CN109840531B (zh
Inventor
刘晓阳
胡晓林
王月红
曹忆南
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Huawei Technologies Co Ltd
Original Assignee
Tsinghua University
Huawei Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University, Huawei Technologies Co Ltd filed Critical Tsinghua University
Priority to CN201711187818.9A priority Critical patent/CN109840531B/zh
Priority to PCT/CN2018/094309 priority patent/WO2019100723A1/zh
Publication of CN109840531A publication Critical patent/CN109840531A/zh
Application granted granted Critical
Publication of CN109840531B publication Critical patent/CN109840531B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请提供了一种训练多标签分类模型的方法和装置,能够动态学习图像特征,使特征提取网络更适应任务需求,并且多标签分类效果好。该方法包括:从训练数据集中确定n个样本和与所述n个样本对应的标签矩阵Yc*n,所述标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,c表示与样本相关的标签的个数;利用特征提取网络提取所述n个样本的特征矩阵Xd*n;利用第一映射网络获取所述特征矩阵Xd*n的预测标签矩阵利用第二映射网络获取所述标签矩阵Yc*n的低秩标签矩阵根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,训练所述多标签分类模型。

Description

训练多标签分类模型的方法和装置
技术领域
本申请涉及计算机领域,并且更具体的,涉及计算机领域中的训练多标签分类模型的方法和装置。
背景技术
随着智能手机的处理性能的提升,越来越多的应用对图像的识别提出了要求。比如,在用手机拍照的过程中,如果智能手机能够精确的识别出拍摄范围内的物体,就能对其颜色,形状进行针对性的运算,从而提高拍摄效果。而在智能系统的机器学习中,对图像中的物体进行识别的训练也就成了一个非常重要的方面。通常来说,机器学习是为大量的已有图像针对其中包含的物体设置标签,然后通过计算机自我演进不断调整识别参数,来逐渐提高对物体的识别准确率。
由于客观物体本身的复杂性和多义性,现实生活中的很多对象可能同时与多个类别标签相关。为了更好的体现出实际对象所具有的多语义性,常使用一个适当的标签子集(包含多个相关的语义标签)描述该对象,这就形成了所谓的多标签分类问题。这时,每个样本都对应一个由多个标签构成的相关标签子集合,学习的目标就是为未知样本预测其相应的标签子集。
在多标签分类的实际问题中,标签子集中的标签并不是相互独立的,而是语义相关的。比如羊和草在一幅图片中出现的可能性很大,山和天空一起出现的可能性也很大,而羊和办公室一起出现的可能性很小,因此这种相关性可以用来提高多标签分类的准确性。多标签分类中计算标签相关性有多种方法,其中一种是通过学习一个低秩标签相关性矩阵来计算标签之间的相关性,并通过最小化多标签分类的损失函数,计算低秩标签相关性矩阵来提高多标签分类的性能。但是这种方法需要先提取图像的特征,然后根据图像的特征计算特征映射矩阵和低秩标签相关性矩阵。在提取了图像的特征之后,该图像的特征就是固定的,因而不能够动态地根据标签学习输入图像的特征信息。
发明内容
本申请提供一种训练多标签分类模型的方法和装置,能够动态学习图像特征,使特征提取网络更适应任务需求,并且多标签分类效果好。
第一方面,提供了一种训练多标签分类模型的方法,包括:
从训练数据集中确定n个样本和与所述n个样本对应的标签矩阵Yc*n,所述标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,c表示与所述训练数据集中的样本相关的标签的个数。
利用特征提取网络提取所述n个样本的特征矩阵Xd*n,其中,所述特征提取网络具有权值参数Z,d表示所述特征矩阵Xd*n的特征维度。
这里,特征提取网络可以是任意一种能够提取图像特征的神经网络,例如可以为卷积神经网络或多层感知机等,本申请实施例对此不限定。其中,特征提取网络的权值可以表示为Z,具体的,Z可以包含多个权值矩阵。权值矩阵的参数可以随机初始化生成,或者可以采用预训练的模型参数。这里,预训练的模型参数指的是已经训练好的模型的参数,如vgg16网络在ImageNet数据集上训练好的模型参数。
利用第一映射网络获取所述特征矩阵Xd*n的预测标签矩阵所述预测标签矩阵中的元素表示第i个样本包含第j个标签指示的对象的置信度,其中,所述第一映射网络的权值矩阵为特征映射矩阵Mc*d,Mc*d可以表示多标签分类模型中的特征属性与类别标签之间的相关权重,其初始值可以随机生成。
具体的,第一映射网络可以表示为FCM。特征提取网络输出的特征矩阵Xd*n可以输入至FCM,再由FCM将输入的特征矩阵Xd*n映射到预测标签空间,得到预测标签矩阵即有:
这里,预测标签矩阵可以为包含更丰富标签信息的标签矩阵,其中的每个元素表示第i个样本包含第j个标签指示的对象的置信度。
利用第二映射网络获取所述标签矩阵Yc*n的低秩标签矩阵其中,所述第二映射网络的权值矩阵为低秩标签相关性矩阵S,所述低秩标签相关性矩阵S用于描述所述c个标签之间的关系。即有:
这里,中很可能会包含更丰富的标签信息,因此中的每个元素可以表示第i个样本包含第j个标签指示的对象的置信度。
根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,训练所述多标签分类模型。
其中,n、c、i、j和d均为正整数,且i的取值范围为1至n,j的取值范围为1至c。
因此,本申请实施例所提供的该神经网络系统可以从输入数据直接训练出模型,而不需要额外的中间步骤,即该神经网络系统为一个端到端的的神经系统。这里,端到端的优点是特征提取、特征映射矩阵和低秩标签相关性矩阵可以同时优化,也就是说,本申请实施例可以动态学习图像特征,使特征提取网络更适应任务需求,并且多标签分类效果好。
可选的,所述第二映射网络包括第一子映射网络和第二子映射网络,所述第二映射网络、所述第一子映射网络和所述第二子映射网络具有以下关系:
其中,为所述第一子映射网络的权值矩阵,Hc*r为所述第二子映射网络的权值矩阵,r为小于或等于c的正整数。
具体的,第一子映射网络可以为权值矩阵为的全连接层,第二子映射网络可以为权值矩阵为Hc*r的全连接层,和Hc*r的初始值可以随机生成。由于两个矩阵相乘后得到的矩阵的秩小于两个矩阵中的任意一个矩阵的秩,因此可以通过设置r的大小(即r≤c)使和Hc*r低秩,进而使得低秩,即使得标签相关性矩阵S低秩,并且r可以通过多次训练取最优值。
可选的,根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,包括:
将所述预测标签矩阵和所述低秩标签矩阵之间的欧氏距离损失函数确定为的第一损失函数,第一损失函数的表达式如下:
将所述标签矩阵Yc*n和所述低秩标签矩阵之间的欧氏距离损失函数确定为第二损失函数,第二损失函数的表达式如下:
根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新。
可选的,根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新,包括:
将所述第一损失函数、所述第二损失函数与正则项之和,确定为所述n个样本的优化函数Ln,其中,所述正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d,Ln的表达式如下:
其中,优化函数Ln的第一项为上述第一损失函数第二项为上述第二损失函数第三项为正则项,该正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d,防止过拟合。
可以利用误差反向传播算法,最小化该损失函数Ln,将所述优化函数的取值最小时所对应的权值参数Z作为更新后的权值参数Z,将所述优化函数的取值最小时所对应的特征映射矩阵Mc*d作为更新后的特征映射矩阵Mc*d,将所述优化函数的取值最小时所对应的第一子映射网络的权值矩阵作为更新后的第一子映射网络的权值矩阵将所述优化函数的取值最小时所对应的第二子映射网络的权值矩阵Hc*r作为更新后的第二子映射网络的权值矩阵Hc*r
然后,判断是否达到停止条件。
这里,停止条件为:Ln不再下降,或下降幅度小于预设的阈值,或达到最大训练次数。如没达到则重复训练,直到达到停止条件。本申请实施例中,把所有图片都输入一遍算作训练一轮,通常需要训练若干轮。
可选的,所述从训练数据集中确定n个样本和所述n个样本的标签矩阵Yc*n,包括:
确定训练数据集,所述训练数据集中包括D个样本和与所述D个样本中每个样本的标签向量,其中,所述每个样本的标签向量中的元素yj表示所述每个样本是否包含第j个标签指示的对象,其中,D为大于n的正整数;
从所述训练数据集中随机抽取n个样本,并生成所述n个样本的标签矩阵Yc*n,所述标签矩阵Yc*n包括所述n个样本中的每个样本对应的标签向量。
因此,本申请实施例中,不必一次性输入整个训练数据集进行计算,而只需要分批次的输入图片进行计算,因此本申请实施例可以分批次地输入整个数据集进行训练。也就是说,本申请实施例中,可以通过多批次地输入数据集中的部分数据对模型进行训练,其中,每次输入的数据可以是从数据集中未输入的图片样本中随机抽取的。由于训练数据集通常包括大量的样本,因此本申请实施例通过分批次输入训练数据集可以减小训练模型的过程中对资源的占用,大大降低了训练模型的过程中对内存资源的需求,可以有效解决大规模数据下低秩标签相关性矩阵的计算问题。
可选的,还包括:利用所述特征提取网络提取第一样本的第一特征矩阵,其中,所述第一样本不属于所述n个样本;
利用所述第一映射网络获取所述第一特征矩阵的第一预测标签矩阵,所述第一预测标签矩阵中的元素表示所述第一样本包含第j个标签指示的对象的置信度。
具体的,训练完成后,在测试阶段,只需将测试图片输入至该神经网络模型中的特征提取网络,利用所述特征提取网络提取该测试图片的第一特征矩阵,并将该第一特征矩阵输入至FCM,利用FCM获取并输出所述第一特征矩阵的预测标签矩阵,所述预测标签矩阵中的元素表示所述测试包含第j个标签指示的对象的置信度。这里,测试图片可以为一个或多个图片,且可以不属于训练数据集。
第二方面,提供一种训练多标签分类模型的装置,所述装置用于执行上述第一方面或第一方面的任一可能的实现方式中的方法。具体地,所述装置可以包括用于执行第一方面或第一方面的任一可能的实现方式中的方法的模块。
第三方面,提供一种训练多标签分类模型的装置,所述装置包括存储器和处理器,所述存储器用于存储指令,所述处理器用于执行所述存储器存储的指令,并且对所述存储器中存储的指令的执行使得所述处理器执行第一方面或第一方面的任一可能的实现方式中的方法。
第四方面,提供一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当所述指令在计算机上运行时,使得计算机执行第一方面或第一方面的任一可能的实现方式中的方法。
第五方面,提供一种包含指令的计算机程序产品,当该计算机程序产品在计算机上运行时,使得计算机执行第一方面或第一方面的任一可能的实现方式中的方法。
附图说明
图1示出了单标签分类和多标签分类问题的示意图。
图2示出了本申请实施例提供的一种训练多标签分类模型的方法的示意性流程图。
图3示出了本申请实施例提供的一种多标签分类模型的示意图。
图4示出了本申请实施例中的一种补全标签的构造示意图。
图5示出了本申请实施例提供的一种多标签分类模型的示意图。
图6示出了本申请实施例提供的一种训练多标签分类模型的装置的示意性框图。
图7示出了本申请实施例提供的另一种训练多标签分类模型的装置的示意性框图。
具体实施方式
下面将结合附图,对本申请中的技术方案进行描述。
图1示出了单标签分类和多标签分类问题的示意图。如图1中(a)所示,单标签分类往往假设样本仅对应于一个类别标签,即具有唯一的语义意义。然后这种假设在许多实际情况下可能并不成立,特别考虑到客观对象本身所存在的语义多样性,物体很可能同时与多个不同的类别标签相关。因此在多标签问题中,如图1中(b)所示,常使用多个相关的类别标签来描述每个对象所对应的语义信息,例如,每幅图像可能同时对应多个语义标签,如“草地”,“天空”和“大海”,每首音乐片段可能会含有多种情绪,如“愉悦”和“轻松”。
多标签分类问题中,首先会给定一系列训练数据,这里该一系列训练数据组成的集合可以称为训练数据集。通过学习给定的训练数据,可以为未知样本预测其相应的标签子集。这里,训练数据集可以对应一个标签集合,该标签集合中可以包括与该训练数据相关的c个不同类别的标签,c为正整数。训练数据集可以包括D个样本和每个样本对应的标签子集,其中,D为正整数。可理解,这里的标签子集即为该标签集合的一个子集。也就是说,通过学习给定的训练数据集中的多个样本和每个样本对应的标签子集,可以预测未知样本的标签子集。
本申请实施例中,标签子集可以表示为标签向量。换句话说,样本的标签向量可以表示样本具有哪些标签或属于哪一些种类。例如,一幅图像的标签向量为[0 1 0 0 1 0],则表明共有6种类别,其中该标签向量中的每个元素代表一种类别或一个标签,0表示图像中没有这一类或这一标签,1表示图像中有这一类或这一标签。由于该标签向量有两个1标签,则表示该图像中有两种物体,分别属于第二类和第五类。这样,训练数据集中的D个样本中的每个样本可以对应一个标签向量yj,表示该样本是否包含第j个标签指示的对象,这里j的取值范围为1至c。应理解,本申请实施例中,样本是否包含第j个标签指示的对象即样本是否包含第j个标签。
这样,训练数据集中的全部或部分样本的标签向量就会组成一个标签矩阵Y:
另外,预测标签向量是多标签分类器的输出,代表多标签分类器对该图像所属类别的预测,其维度与标签向量相同。预测标签向量的元素的值为实值,如果该实值超过给定的一个阈值,那么该元素对应的位置就属于相应类别,否则不属于该类别。例如,预测标签向量为[0.7 0.2 0.1 0.8 1.0 0.0],阈值为0.5,将每一位上的数与阈值进行比较,大于阈值则相当于属于该类别。这样所预测的类别为第一类、第四类和第五类。如果该预测标签向量对应的标签向量为[1 0 0 1 0 1 0],则该预测标签向量完全正确。
在实际问题中,特别是数据中涉及大量类别标签的情况下,为数据中的每个样本都提供其对应的完整标签信息往往非常困难。因此,训练数据集中的样本所对应的标签信息很可能是不完全的。也就是说,在数据的标签矩阵中,样本不包含某标签并不代表实际情况下样本与该标签不相关。因此,需要通过训练数据集中已有的数据,学习标签之间的相关性,进而利用标签相关性获得一个包含更丰富标签信息的标签矩阵,然后通过该包含更丰富标签信息的标签矩阵可以更加准确的预测未知样本的标签信息。
现有技术在学习给定的训练数据时,需要先提取图像的特征,然后根据图像的特征计算特征映射矩阵和低秩标签相关性矩阵。在提取了图像的特征之后,该图像的特征就是固定的,因而不能够动态地根据标签学习输入图像的特征信息。基于此,本申请实施例设计了一种用于多标签分类的神经网络,能够通过学习特征映射矩阵、低秩标签相关性矩阵以及优化特征提取网络来实现多标签分类算法。
神经网络系统是一种智能化的识别系统,其通过反复训练的方式累计训练结果,来提高对各种目标物体或声音的识别能力。卷积神经网络是神经网络发展的主流方向之一。卷积神经网络一般包括卷积层(Convolutional Layer),修正线性单元(RectifiedLinear Units,ReLU)层、池化(Pooling)层以及全连接(Fully Connect,FC)层。其中,卷积层,ReLU层和Pooling层可能会交替多次重复设置。
卷积层可以被视为卷积神经网络的核心,在用于图像识别时,其输入端接收图像数据,用于通过滤波器对图像进行鉴定。这里的图像数据可以是摄像机拍到的图像转化结果,也可以是卷积层之前层的处理结果。通常图像数据是三维的图像阵列,比如32x32x3,其中,32x32是图像数据代表的图像的二维尺寸,即宽和高,这里的深度值3则是因为图像通常分为绿,红,蓝三个数据通道。卷积层中设有多个滤波器,不同的滤波器对应不同的图像特征(边界,颜色,形状等)对输入的图像数据按照一定的步长进行扫描。不同的滤波器中设置有不同的权重矩阵,所述权重矩阵为神经网络在学习过程中针对特定图像特征生成的。每一个滤波器每一拍扫描图像的一个区域,会得到一个三维的输入矩阵(MxNx3,M和N决定了扫描区域的尺寸),卷积网络将输入矩阵和权重矩阵作点积,得到一个结果值,然后会以特定步长扫描下一个区域,比如,横移两格。当一个滤波器按照特定步长扫描完所有区域后,结果值会构成一个二维矩阵;而当所有滤波器完成扫描后,结果值就会构成一个三维矩阵作为当前卷积层的输出,这个三维矩阵的不同深度层分别对应一个滤波器的扫描结果(即每个滤波器扫描后构成的二维矩阵)。
卷积层的输出会再送往ReLU层做处理(通过max(0,x)函数对输出的数值范围进行限定),以及送到Pooling层通过下采样缩减尺寸,在送往FC层之前,图像数据可能还会经过多个卷积层,以对图像特征进行深层次的鉴定(比如第一次卷积层仅对图像的轮廓特征进行鉴定,第二次卷积层开始识别图案等),最终输入FC层。与卷积层类似但稍有不同,FC层也是通过多个滤波器对输入数据作权重处理,但是FC层得每个滤波器并不像卷积层的滤波器那样通过每一拍的移位来扫描不同区域,而是一次性的扫描输入的图像数据的所有区域,然后与权重矩阵进行运算得到一个结果值。最终FC层输出的是一个1x1xN的矩阵,其实也就是一个数据序列,这个数据序列的每一位对应不同的目标物体,其上的值可以被视作这些物体目标存在的分值。在卷积层和FC层中,都会用到权重矩阵,神经网络可以通过自训练维护多种权重矩阵。
下文将结合图2和图3详细介绍本申请实施例的训练多标签分类模型的方法。
图2示出了本申请实施例提供的一种训练多标签分类模型的方法的示意性流程图。应理解,图2示出了训练多标签分类模型的方法的步骤或操作,但这些步骤或操作仅是示例,本申请实施例还可以执行其他操作或者图2中的各个操作的变形。此外,图2中的各个步骤可以按照与图2呈现的不同的顺序来执行,并且有可能并非要执行图2中的全部操作。
图3示出了本申请实施例提供的一种多标签分类模型300的示意图。该多标签分类模型300具体为神经网络系统。该多标签分类模型300包括特征提取网络301、FCM 302、映射网络31和处理单元305,其中,映射网络31可以包括FCW 303和FCH 304。应理解,图3示出的多标签分类模型300仅是示例,本申请实施例还可以包括其他模块或单元或者图3中的各个模块或单元的变形。
应注意,本申请实施例中多标签分类方法可以应用于图像标注、图像识别、声音识别、文本分类等多个领域,具体的,对应的训练数据集中的样本可以为图像、声音、文档等,本申请实施例对此不限定。为了描述方便,下文将以使用图像样本进行图像识别为例进行描述,但这并不会对本申请实施例的方案构成限制。
210,初始化多标签分类模型200的权值。
初始化多标签分类模型200的权值即初始化系统中的特征提取网络301、FCM 302、映射网络31(即FCW303以及FCH 304)的权值。
这里,特征提取网络301可以是任意一种能够提取图像特征的神经网络,例如可以为卷积神经网络或多层感知机等,本申请实施例对此不限定。其中,特征提取网络301的权值可以表示为Z,具体的,Z可以包含多个权值矩阵。权值矩阵的参数可以随机初始化生成,或者可以采用预训练的模型参数。这里,预训练的模型参数指的是已经训练好的模型的参数,如vgg16网络在ImageNet数据集上训练好的模型参数。
另外,FCM表示权值矩阵为特征映射矩阵Mc*d的全连接层,其中Mc*d可以表示多标签分类模型中的特征属性与类别标签之间的相关权重,其初始值可以随机生成。FCW 303表示权值矩阵为的全连接层,FCH 304表示权值矩阵为Hc*r的全连接层,和Hc*r的初始值可以随机生成。这里,r为自行设置的值,需要满足r≤c。
220,输入n幅图片。
由于神经网络的特性,不必一次性输入整个训练数据集进行计算,而只需要分批次的输入图片进行计算,因此本申请实施例可以分批次地输入整个数据集进行训练。也就是说,本申请实施例中,可以通过多批次地输入数据集中的部分数据对模型进行训练,其中,每次输入的数据可以是从数据集中未输入的图片样本中随机抽取的。由于训练数据集通常包括大量的样本,因此本申请实施例通过分批次输入训练数据集可以减小训练模型的过程中对资源的占用。
这时,一个批次输入至多标签分类模型300的样本的个数可以为n个。当样本为图片时,该n个样本可以表示为image_n,并且更具体的,image_n可以为从训练数据集的D个样本中随机抽取的n个图片,并且,n的取值可以远小于D。具体而言,n的大小可以根据该多标签分类模型300的能力确定。例如,如果该多标签分类模型300的数据处理能力较强,则n可以设置的比较大,以缩短训练模型的时间。又例如,如果该多标签分类模型300的数据处理能力较弱,则n可以设置的比较小,以降低训练模型所消耗的资源。这样,本申请实施例能够灵活地根据多标签分类模型300的数据处理能力设置n的取值。
并且,该n个样本所对应的标签矩阵可以表示为Yc*n,标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,这里i的取值范围为1至n,j的取值范围为1至c。具体的,标签矩阵的描述可以参见上文的描述,为避免重复,这里不再赘述。
本申请实施例中,可以将训练数据输入至图3中所示的多标签分类模型300。具体的,可以将训练数据集中的n个图片和该n个图片的标签矩阵Yc*n分别输入至该多标签分类模型300。
230,提取图片的特征,并根据图片的特征计算图片的预测标签矩阵。
具体而言,可以将n个图片输入至特征提取网络301,特征提取网络301经过卷积层、激活函数层、Pooling层、全连接层、Batchnorm层的作用,可以提取该n个图片的特征,并输出特征矩阵Xd*n。其中,d为正整数且表示所述特征矩阵Xd*n的特征维度。
然后,特征提取网络301输出的特征矩阵Xd*n可以输入至FCM 302。由于FCM表示权值矩阵为特征映射矩阵Mc*d的全连接层,且Mc*d可以表示多标签分类模型中的特征属性与类别标签之间的相关权重,因此FCM 302可以将输入的特征矩阵Xd*n映射到预测标签空间,得到预测标签矩阵即有:
这里,预测标签矩阵可以为包含更丰富标签信息的标签矩阵,其中的每个元素表示第i个样本包含第j个标签指示的对象的置信度。
240,根据图片的标签矩阵计算图片的低秩标签矩阵。
本申请实施例中,可以将n个的标签矩阵Yc*n输入至映射网络31,该映射网络31的输出为该标签矩阵Yc*n的含有标签相关性的低秩标签矩阵其中,该映射网络31的权值矩阵为标签相关性矩阵S,所述标签相关性矩阵S用于描述c个标签之间的关系,即有:
当矩阵的元素之间有相关性时,该矩阵是低秩的。由此可知,由于标签相关性矩阵S中的每个元素用于描述两个个标签之间的关系,因此标签相关性矩阵S为低秩矩阵。具体而言,低秩矩阵的秩小于该矩阵的行数或列数。此时,可以根据矩阵的低秩结构来恢复矩阵的缺失元素,这个恢复的过程可以称为矩阵补全,因此可以将称为补全标签矩阵,中很可能会包含更丰富的标签信息。中的每个元素可以表示第i个样本包含第j个标签指示的对象的置信度。
图4示出了本申请实施例中的一种补全标签的构造示意图。假设图片1在原始不完全标签矩阵Y中已知仅包含标签“鱼”,然后通过(2)式中的方式,利用标签相关性构造得到补全标签矩阵在构造过程中发现标签“鱼”和“海洋”之间存在非常强的相关性,因此导致补全标签矩阵中预测图片1中包含“海洋”标签的可能性也比较大。而考虑到“鱼”和“天空”之间仅存在一种弱依赖性,所以在补全标签矩阵中预测图片1包含“天空”的可能性比较小,因此通过该补全标签矩阵对样本图片1从原本仅包含部分标签信息“鱼”扩展到有非常大的可能性同时对应“鱼”和“海洋”两个标签,使得可利用的标签信息变得更加丰富。同样的,对图片2的原始不完全标签信息进行补全后,图片2的标签与“天空”更加相关。
在一种可能的实现方式中,映射网络31具体可以包括FCW 303和FCH 304。其中,由于两个矩阵相乘后得到的矩阵的秩小于两个矩阵中的任意一个矩阵的秩,因此可以通过设置r的大小(即r≤c)使和Hc*r低秩,进而使得低秩,即使得标签相关性矩阵S低秩,并且r可以通过多次训练取最优值。此时,有:
具体的,FCW 303的输入为image_n对应的标签矩阵Yc*n,FCW 303的输出可以表示为Pr*n,Pr*n可以直接输入至FCH 304的,最终由FCH 304输出低秩标签矩阵即有:
250,计算优化函数。
然后,处理单元305可以根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,以训练所述多标签分类模型300。
具体的,处理单元305可以将预测标签矩阵和所述低秩标签矩阵之间的欧氏距离损失函数确定为第一损失函数作用是约束使之与相近,并且,第一损失函数的表达式如下:
这里,为了便于描述,省略了Mc*d、Xd*nHc*r和Yc*n的上标和下标。其中,是矩阵的Frobenius范数,矩阵Am*n的Frobenius范数定义为:
其中,Aij为矩阵A的元素,即欧氏距离损失函数。
另外,处理单元305还可以将所述标签矩阵Yc*n和所述低秩标签矩阵之间的欧氏距离损失函数确定为第二损失函数并且,第二损失函数的表达式如下:
同样的,公式(7)中也省略了Yc*n、Hc*r和Yc*n的上标和下标。这里,·2,1为矩阵的l2,1范数,矩阵Am*n的l2,1范数定义为:
进一步地,可以将所述第一损失函数、所述第二损失函数与正则项之和,确定为所述n个样本的损失函数Ln。这里,损失函数Ln也可以称为优化函数Ln,具体的,Ln的表达式如下:
其中,优化函数Ln的第一项为上述第一损失函数第二项为上述第二损失函数第三项为正则项,该正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d,防止过拟合。
260,利用误差反向算法更新权值参数。
误差反向传播算法是一种用于多层神经网络训练的方法,其以梯度下降方法为基础,通过优化损失函数,对神经网络每层的权值进行学习更新。
具体的,可以利用误差反向传播算法,最小化该损失函数Ln,并将所述优化函数的取值最小时所对应的权值参数Z作为更新后的权值参数Z,将所述优化函数的取值最小时所对应的特征映射矩阵Mc*d作为更新后的特征映射矩阵Mc*d,将所述优化函数的取值最小时所对应的权值矩阵S作为更新后的权值矩阵S。
时,则有:将所述优化函数的取值最小时所对应的权值矩阵作为更新后的权值矩阵将所述优化函数的取值最小时所对应的权值矩阵Hc*r作为更新后的权值矩阵Hc*r
为使用误差反向传播算法,下面对(9)式中的变量进行求导。以输入一幅图片、正则项采用l2范数为例。
记L1为一幅图片的优化函数,则有:
其中,矩阵的Frobenius范数的平方对应向量的l2范数的平方,矩阵的l2,1范数对应向量的l2范数。
下面对Mc*d、Hc*r的每一个元素求导得:
其中,mji为矩阵Mc*d的元素,hkj为矩阵Hc*r的元素,xi为向量xd的向量,wji为矩阵的元素,pj为向量yc的元素,为向量的元素,为向量的元素,yj为向量yc的元素,xd、pryc分别为矩阵Xd*n、Pr*nYc*n的列向量。对特征提取网络权值Z的误差反向求导可通过Mc*d传递得到。则Mc*d、Hc*r的元素更新为:
其中,是本次更新得到的值,是上次更新得到的值,hji和wji与之类似,η1、η2、η3分别是Mc*d、Hc*r的学习率,用于控制更新速率。特征提取网络部分权值Z的更新与此类似。
这样就可以学习到特征提取网络的权值Z、低秩标签相关性矩阵特征映射矩阵Mc*d,进而提升多标签分类的能力,同时也可以利用标签相关性补足缺失标签。
270,判断是否达到停止条件。
这里,停止条件为:Ln不再下降,或下降幅度小于预设的阈值,或达到最大训练次数。如没达到则重复步骤220至260,直到达到停止条件。本申请实施例中,把所有图片都输入一遍算作训练一轮,通常需要训练若干轮。
训练完成后,在测试阶段,只需执行220和230,即将测试图片输入至该神经网络模型中的特征提取网络,利用所述特征提取网络提取该测试图片的第一特征矩阵,并将该第一特征矩阵输入至FCM,利用FCM获取并输出所述第一特征矩阵的预测标签矩阵,所述预测标签矩阵中的元素表示所述测试包含第j个标签指示的对象的置信度。这里,测试图片可以为一个或多个图片,且可以不属于训练数据集。
并且具体的,对预测标签矩阵的单个预测向量来看,通过对做处理即可得到该图片所属的一个或多个类别,例如的某一个或一些元素值大于预设的阈值即表示该图片在该一个元素或多个元素相应位置有类别标签,该图片属于这一类或者几类。这里,预设的阈值可以为0.5,或者其他数值,本申请实施例对此不限定。
因此,本申请实施例所提供的该神经网络系统可以从输入数据直接训练出模型,而不需要额外的中间步骤,即该神经网络系统为一个端到端的的神经系统。这里,端到端的优点是特征提取、特征映射矩阵和低秩标签相关性矩阵可以同时优化,也就是说,本申请实施例可以动态学习图像特征,使特征提取网络更适应任务需求,多标签分类效果好。
另外,本申请实施例可以分批次地利用图片样本的图像特征计算低秩标签相关性矩阵和特征映射矩阵,而不必一次性用整个数据集的图像特征作为输入进行计算,即无须一次性用全部样本的图像特征进行训练,大大降低了训练模型的过程中对内存资源的需求,可以有效解决大规模数据下低秩标签相关性矩阵的计算问题。
图5示出了本申请实施例提供的一种多标签分类模型500的示意图。该模型500的特征提取网络部分采用VGG16网络,并且将VGG16网络的倒数第二个全连接层后的Dropout层的输出作为特征矩阵X。另外,特征提取网络的权值参数Z采用在ImageNet数据集上训练好的权值参数,然后对其微调(微调指固定前面几层的权值或者只进行很小的调整,完全训练最后一层或两层网络)。权值矩阵M、H和W的初始值可以采用高斯分布进行初始化,且M、H和W的值要完全训练。正则项可以采用Frobenius范数。
具体的,在训练时,特征提取网络VGG16(除去最后一个全连接层)的权值采用在ImageNet数据集上预训练的权值。
将n幅像素大小为224*224的RGB三通道图片image_n输入到VGG16网络中,这里1≤n≤N,N为训练集中图片的数量,图片大小可以表示为n*C*h*w或h*w*C*n等四维矩阵,其中,C为通道数(RGB图像为3),h为图片高度(224像素),w为图片宽度(224像素)。图片经过多次卷积、激活、Pooling等操作后,再经过两个全连接层以及Dropout层得到图像特征矩阵X4096*n
X4096*n再经过一个权值矩阵为Mc*4096的全连接层(FCM 502),得到预测标签矩阵
Yc*n经过两个权值矩阵分别为和Hc*r的全连接层(FCW 503和FCH 504),得到低秩标签相关性矩阵和含有标签相关性的低秩标签矩阵
处理单元505根据标签矩阵Yc*n、预测标签矩阵低秩标签矩阵得到优化函数:
然后,利用误差反向传播算法,最小化上述优化函数,更新权值参数Z,特征映射矩阵Mc*d、权值矩阵和Hc*r。具体的优化过程可以参见上文中的描述,为避免重复,这里不再赘述。
在更新了权值参数Z,特征映射矩阵Mc*d、权值矩阵和Hc*r之后,判断是否达到停止条件,如没达到则重复步骤,直到达到停止条件。具体的,停止条件可以参见上文中的描述,为避免重复,这里不再赘述。
在训练完成后,可以将测试图片输入至特征提取网络501,并将特征提取网络提取的图片的特征输入至FCM 502,通过FCM 502得到预测标签矩阵。
应注意,本申请实施例中,特征提取网络的结构可以采用其他网络代替,如AlexNet、GoogleNet、ResNet以及自定义网络等,本申请实施例对此不限定。特征输出的层可以采用上述网络的某一层的输出,也可以在上述基础上增减若干卷积层或全连接层。另外,本申请实施例还可以采用不同的正则化项。
因此,本申请实施例所提供的该神经网络系统可以从输入数据直接训练出模型,而不需要额外的中间步骤,即该神经网络系统为一个端到端的的神经系统。这里,端到端的优点是特征提取、特征映射矩阵和低秩标签相关性矩阵可以同时优化,也就是说,本申请实施例可以动态学习图像特征,使特征提取网络更适应任务需求,多标签分类效果好。
另外,本申请实施例可以分批次地利用图片样本的图像特征计算低秩标签相关性矩阵和特征映射矩阵,而不必一次性用整个数据集的图像特征作为输入进行计算,即无须一次性用全部样本的图像特征进行训练,大大降低了训练模型的过程中对内存资源的需求,可以有效解决大规模数据下低秩标签相关性矩阵的计算问题。
应注意,本申请实施例不限定专门的产品形态,本申请实施例的多标签分类的方法可以部署在通用的计算机节点上。初步构建的多标签分类模型可以被存储在硬盘存储器中,通过处理器和内存运行算法,对已有的训练数据集进行学习,得到该多标签分类模型。通该多标签分类模型可以对未知样本的标签进行预测,将预测结果存入硬盘存储器,实现对已有的标签集进行补全,并对未知样本对应的标签进行预测。
图6示出了本申请实施例提供的一种训练多标签分类模型的装置600的示意性框图。装置600包括确定单元:
确定单元610,用于从训练数据集中确定n个样本和与所述n个样本对应的标签矩阵Yc*n,所述标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,c表示与所述训练数据集中的样本相关的标签的个数;
提取单元620,用于利用特征提取网络提取所述n个样本的特征矩阵Xd*n,其中,所述特征提取网络具有权值参数Z,d表示所述特征矩阵Xd*n的特征维度;
第一获取单元630,用于利用第一映射网络获取所述特征矩阵Xd*n的预测标签矩阵所述预测标签矩阵中的元素表示第i个样本包含第j个标签指示的对象的置信度,其中,所述第一映射网络的权值矩阵为特征映射矩阵Mc*d
第二获取单元640,用于利用第二映射网络获取所述标签矩阵Yc*n的低秩标签矩阵其中,所述第二映射网络的权值矩阵为低秩标签相关性矩阵S,所述低秩标签相关性矩阵S用于描述所述c个标签之间的关系;
更新单元650,用于根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,训练所述多标签分类模型;
其中,n、c、i、j和d均为正整数,且i的取值范围为1至n,j的取值范围为1至c。
因此,本申请实施例所提供的该神经网络系统可以从输入数据直接训练出模型,而不需要额外的中间步骤,即该神经网络系统为一个端到端的神经系统。这里,端到端的优点是特征提取、特征映射矩阵和低秩标签相关性矩阵可以同时优化,也就是说,本申请实施例可以动态学习图像特征,使特征提取网络更适应任务需求,并且多标签分类效果好。
可选的,所述第二映射网络包括第一子映射网络和第二子映射网络,所述第二映射网络、所述第一子映射网络和所述第二子映射网络具有以下关系:
其中,为所述第一子映射网络的权值矩阵,Hc*r为所述第二子映射网络的权值矩阵,r为小于或等于c的正整数。
可选的,所述更新单元650具体用于:
将所述预测标签矩阵和所述低秩标签矩阵之间的欧氏距离损失函数确定为的第一损失函数;
将所述标签矩阵Yc*n和所述低秩标签矩阵之间的欧氏距离损失函数确定为第二损失函数;
根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新。
可选的,所述更新单元650具体还用于:
将所述第一损失函数、所述第二损失函数与正则项之和,确定为所述n个样本的优化函数,其中,所述正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d
将所述优化函数的取值最小时所对应的权值参数Z作为更新后的权值参数Z,将所述优化函数的取值最小时所对应的特征映射矩阵Mc*d作为更新后的特征映射矩阵Mc*d,将所述优化函数的取值最小时所对应的第一子映射网络的权值矩阵作为更新后的第一子映射网络的权值矩阵将所述优化函数的取值最小时所对应的第二子映射网络的权值矩阵Hc*r作为更新后的第二子映射网络的权值矩阵Hc*r
可选的,所述确定单元610具体用于:
确定训练数据集,所述训练数据集中包括D个样本和与所述D个样本中每个样本的标签向量,其中,所述每个样本的标签向量中的元素yj表示所述每个样本是否包含第j个标签指示的对象,其中,D为大于n的正整数;
从所述训练数据集中随机抽取n个样本,并生成所述n个样本的标签矩阵Yc*n,所述标签矩阵Yc*n包括所述n个样本中的每个样本对应的标签向量。
因此,本申请实施例中,不必一次性输入整个训练数据集进行计算,而只需要分批次的输入图片进行计算,因此本申请实施例可以分批次地输入整个数据集进行训练。由于训练数据集通常包括大量的样本,因此本申请实施例通过分批次输入训练数据集可以减小训练模型的过程中对资源的占用,大大降低了训练模型的过程中对内存资源的需求,可以有效解决大规模数据下低秩标签相关性矩阵的计算问题。
可选的,还包括:所述提取单元620还用于利用所述特征提取网络提取第一样本的第一特征矩阵,其中,所述第一样本不属于所述n个样本;
所述第一获取单元630还用于利用所述第一映射网络获取所述第一特征矩阵的第一预测标签矩阵,所述第一预测标签矩阵中的元素表示所述第一样本包含第j个标签指示的对象的置信度。
应注意,本发明实施例中,确定单元610、提取单元620、第一获取单元630、第二获取单元640和更新单元650可以由处理器实现。如图7所示,训练多标签分类模型的装置700可以包括处理器710、存储器720和通信接口730。其中,存储器720可以用于存储处理器710执行的指令或代码等。当该指令或代码被执行时,该处理器710用于执行上述方法实施例提供的方法,处理器710还用于控制通信接口730与外界进行通信。
在实现过程中,上述方法的各步骤可以通过处理器710中的硬件的集成逻辑电路或者软件形式的指令完成。结合本发明实施例所公开的方法的步骤可以直接体现为硬件处理器执行完成,或者用处理器中的硬件及软件模块组合执行完成。软件模块可以位于随机存储器,闪存、只读存储器,可编程只读存储器或者电可擦写可编程存储器、寄存器等本领域成熟的存储介质中。该存储介质位于存储器720,处理器710读取存储器720中的信息,结合其硬件完成上述方法的步骤。为避免重复,这里不再详细描述。
图6所示的训练多标签分类模型的装置600或图7所示的训练多标签分类模型的装置700能够实现前述方法实施例对应的各个过程,具体的,该训练多标签分类模型的装置600或训练多标签分类模型的装置700可以参见上文中的描述,为避免重复,这里不再赘述。
应理解,在本申请的各种实施例中,上述各过程的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。
本申请实施例还提供一种计算机可读存储介质,该计算机可读存储介质包括计算机程序,当其在计算机上运行时,使得该计算机执行上述方法实施例提供的方法。
本申请实施例还提供一种包含指令的计算机程序产品,当该计算机程序产品在计算机上运行时,使得该计算机执行上述方法实施例提供的方法。
应理解,本发明实施例中提及的处理器可以是中央处理单元(CentralProcessing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
还应理解,本发明实施例中提及的存储器可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read-Only Memory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(Erasable PROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double DataRate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(Direct Rambus RAM,DR RAM)。
需要说明的是,当处理器为通用处理器、DSP、ASIC、FPGA或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件时,存储器(存储模块)集成在处理器中。
应注意,本文描述的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
还应理解,本文中涉及的第一、第二以及各种数字编号仅为描述方便进行的区分,并不用来限制本申请的范围。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统、装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统、装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。

Claims (14)

1.一种训练多标签分类模型的方法,其特征在于,包括:
从训练数据集中确定n个样本和与所述n个样本对应的标签矩阵Yc*n,所述标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,c表示与所述训练数据集中的样本相关的标签的个数;
利用特征提取网络提取所述n个样本的特征矩阵Xd*n,其中,所述特征提取网络具有权值参数Z,d表示所述特征矩阵Xd*n的特征维度;
利用第一映射网络获取所述特征矩阵Xd*n的预测标签矩阵所述预测标签矩阵中的元素表示第i个样本包含第j个标签指示的对象的置信度,其中,所述第一映射网络的权值矩阵为特征映射矩阵Mc*d
利用第二映射网络获取所述标签矩阵Yc*n的低秩标签矩阵其中,所述第二映射网络的权值矩阵为低秩标签相关性矩阵S,所述低秩标签相关性矩阵S用于描述所述c个标签之间的关系;
根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,训练所述多标签分类模型;
其中,n、c、i、j和d均为正整数,且i的取值范围为1至n,j的取值范围为1至c。
2.根据权利要求1所述的方法,其特征在于,所述第二映射网络包括第一子映射网络和第二子映射网络,所述第二映射网络、所述第一子映射网络和所述第二子映射网络具有以下关系:
其中,为所述第一子映射网络的权值矩阵,Hc*r为所述第二子映射网络的权值矩阵,r为小于或等于c的正整数。
3.根据权利要求2所述的方法,其特征在于,根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,包括:
将所述预测标签矩阵和所述低秩标签矩阵之间的欧氏距离损失函数确定为的第一损失函数;
将所述标签矩阵Yc*n和所述低秩标签矩阵之间的欧氏距离损失函数确定为第二损失函数;
根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新。
4.根据权利要求3所述的方法,其特征在于,根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新,包括:
将所述第一损失函数、所述第二损失函数与正则项之和,确定为所述n个样本的优化函数,其中,所述正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d
将所述优化函数的取值最小时所对应的权值参数Z作为更新后的权值参数Z,将所述优化函数的取值最小时所对应的特征映射矩阵Mc*d作为更新后的特征映射矩阵Mc*d,将所述优化函数的取值最小时所对应的第一子映射网络的权值矩阵作为更新后的第一子映射网络的权值矩阵将所述优化函数的取值最小时所对应的第二子映射网络的权值矩阵Hc*r作为更新后的第二子映射网络的权值矩阵Hc*r
5.根据权利要求1-4任一项所述的方法,其特征在于,所述从训练数据集中确定n个样本和所述n个样本的标签矩阵Yc*n,包括:
确定训练数据集,所述训练数据集中包括D个样本和与所述D个样本中每个样本的标签向量,其中,所述每个样本的标签向量中的元素yj表示所述每个样本是否包含第j个标签指示的对象,其中,D为大于n的正整数;
从所述训练数据集中随机抽取n个样本,并生成所述n个样本的标签矩阵Yc*n,所述标签矩阵Yc*n包括所述n个样本中的每个样本对应的标签向量。
6.根据权利要求1-5任一项所述的方法,其特征在于,还包括:
利用所述特征提取网络提取第一样本的第一特征矩阵,其中,所述第一样本不属于所述n个样本;
利用所述第一映射网络获取所述第一特征矩阵的第一预测标签矩阵,所述第一预测标签矩阵中的元素表示所述第一样本包含第j个标签指示的对象的置信度。
7.一种训练多标签分类模型的装置,其特征在于,包括:
确定单元,用于从训练数据集中确定n个样本和与所述n个样本对应的标签矩阵Yc*n,所述标签矩阵Yc*n中的元素yi*j表示第i个样本是否包含第j个标签指示的对象,c表示与所述训练数据集中的样本相关的标签的个数;
提取单元,用于利用特征提取网络提取所述n个样本的特征矩阵Xd*n,其中,所述特征提取网络具有权值参数Z,d表示所述特征矩阵Xd*n的特征维度;
第一获取单元,用于利用第一映射网络获取所述特征矩阵Xd*n的预测标签矩阵所述预测标签矩阵中的元素表示第i个样本包含第j个标签指示的对象的置信度,其中,所述第一映射网络的权值矩阵为特征映射矩阵Mc*d
第二获取单元,用于利用第二映射网络获取所述标签矩阵Yc*n的低秩标签矩阵其中,所述第二映射网络的权值矩阵为低秩标签相关性矩阵S,所述低秩标签相关性矩阵S用于描述所述c个标签之间的关系;
更新单元,用于根据所述标签矩阵Yc*n、所述预测标签矩阵和所述低秩标签矩阵对所述权值参数Z、所述特征映射矩阵Mc*d和所述低秩标签相关性矩阵S进行更新,训练所述多标签分类模型;
其中,n、c、i、j和d均为正整数,且i的取值范围为1至n,j的取值范围为1至c。
8.根据权利要求7所述的装置,其特征在于,所述第二映射网络包括第一子映射网络和第二子映射网络,所述第二映射网络、所述第一子映射网络和所述第二子映射网络具有以下关系:
其中,为所述第一子映射网络的权值矩阵,Hc*r为所述第二子映射网络的权值矩阵,r为小于或等于c的正整数。
9.根据权利要求8所述的装置,其特征在于,所述更新单元具体用于:
将所述预测标签矩阵和所述低秩标签矩阵之间的欧氏距离损失函数确定为的第一损失函数;
将所述标签矩阵Yc*n和所述低秩标签矩阵之间的欧氏距离损失函数确定为第二损失函数;
根据所述第一损失函数和所述第二损失函数,对所述权值参数Z、所述特征映射矩阵Mc*d、所述第一子映射网络的权值矩阵和所述第二子映射网络的权值矩阵Hc*r进行更新。
10.根据权利要求9所述的装置,其特征在于,所述更新单元具体还用于:
将所述第一损失函数、所述第二损失函数与正则项之和,确定为所述n个样本的优化函数,其中,所述正则项用于约束所述权值参数Z和所述特征映射矩阵Mc*d
将所述优化函数的取值最小时所对应的权值参数Z作为更新后的权值参数Z,将所述优化函数的取值最小时所对应的特征映射矩阵Mc*d作为更新后的特征映射矩阵Mc*d,将所述优化函数的取值最小时所对应的第一子映射网络的权值矩阵作为更新后的第一子映射网络的权值矩阵将所述优化函数的取值最小时所对应的第二子映射网络的权值矩阵Hc*r作为更新后的第二子映射网络的权值矩阵Hc*r
11.根据权利要求7-10任一项所述的装置,其特征在于,所述确定单元具体用于:
确定训练数据集,所述训练数据集中包括D个样本和与所述D个样本中每个样本的标签向量,其中,所述每个样本的标签向量中的元素yj表示所述每个样本是否包含第j个标签指示的对象,其中,D为大于n的正整数;
从所述训练数据集中随机抽取n个样本,并生成所述n个样本的标签矩阵Yc*n,所述标签矩阵Yc*n包括所述n个样本中的每个样本对应的标签向量。
12.根据权利要求7-11任一项所述的装置,其特征在于,还包括:
所述提取单元还用于利用所述特征提取网络提取第一样本的第一特征矩阵,其中,所述第一样本不属于所述n个样本;
所述第一获取单元还用于利用所述第一映射网络获取所述第一特征矩阵的第一预测标签矩阵,所述第一预测标签矩阵中的元素表示所述第一样本包含第j个标签指示的对象的置信度。
13.一种计算机可读存储介质,其特征在于,包括计算机程序,当所述计算机程序在计算机上运行时,使得所述计算机执行如权利要求1-6中任一项所述的方法。
14.一种包含指令的计算机程序产品,其特征在于,当所述计算机程序产品在计算机上运行时,使得所述计算机执行如权利要求1-6中任一项所述的方法。
CN201711187818.9A 2017-11-24 2017-11-24 训练多标签分类模型的方法和装置 Active CN109840531B (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN201711187818.9A CN109840531B (zh) 2017-11-24 2017-11-24 训练多标签分类模型的方法和装置
PCT/CN2018/094309 WO2019100723A1 (zh) 2017-11-24 2018-07-03 训练多标签分类模型的方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201711187818.9A CN109840531B (zh) 2017-11-24 2017-11-24 训练多标签分类模型的方法和装置

Publications (2)

Publication Number Publication Date
CN109840531A true CN109840531A (zh) 2019-06-04
CN109840531B CN109840531B (zh) 2023-08-25

Family

ID=66631376

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201711187818.9A Active CN109840531B (zh) 2017-11-24 2017-11-24 训练多标签分类模型的方法和装置

Country Status (2)

Country Link
CN (1) CN109840531B (zh)
WO (1) WO2019100723A1 (zh)

Cited By (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110929785A (zh) * 2019-11-21 2020-03-27 中国科学院深圳先进技术研究院 数据分类方法、装置、终端设备及可读存储介质
CN111524187A (zh) * 2020-04-22 2020-08-11 北京三快在线科技有限公司 一种视觉定位模型的训练方法及装置
CN111667399A (zh) * 2020-05-14 2020-09-15 华为技术有限公司 风格迁移模型的训练方法、视频风格迁移的方法以及装置
CN111797910A (zh) * 2020-06-22 2020-10-20 浙江大学 一种基于平均偏汉明损失的多维标签预测方法
CN112215795A (zh) * 2020-09-02 2021-01-12 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112353402A (zh) * 2020-10-22 2021-02-12 平安科技(深圳)有限公司 心电信号分类模型的训练方法、心电信号分类方法及装置
CN112465126A (zh) * 2020-07-27 2021-03-09 国电内蒙古东胜热电有限公司 用于跑冒滴漏检测的加载预训练卷积网络检测方法及装置
CN113076426A (zh) * 2021-06-07 2021-07-06 腾讯科技(深圳)有限公司 多标签文本分类及模型训练方法、装置、设备及存储介质
CN113222068A (zh) * 2021-06-03 2021-08-06 西安电子科技大学 基于邻接矩阵指导标签嵌入的遥感图像多标签分类方法
CN113269215A (zh) * 2020-02-17 2021-08-17 百度在线网络技术(北京)有限公司 一种训练集的构建方法、装置、设备和存储介质
CN113496232A (zh) * 2020-03-18 2021-10-12 杭州海康威视数字技术股份有限公司 标签校验方法和设备
CN115205011A (zh) * 2022-06-15 2022-10-18 海南大学 基于lsf-fc算法的银行用户画像模型生成方法
CN115841596A (zh) * 2022-12-16 2023-03-24 华院计算技术(上海)股份有限公司 多标签图像分类方法及其模型的训练方法、装置
CN113496232B (zh) * 2020-03-18 2024-05-28 杭州海康威视数字技术股份有限公司 标签校验方法和设备

Families Citing this family (22)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10878296B2 (en) * 2018-04-12 2020-12-29 Discovery Communications, Llc Feature extraction and machine learning for automated metadata analysis
CN110334186B (zh) * 2019-07-08 2021-09-28 北京三快在线科技有限公司 数据查询方法、装置、计算机设备及计算机可读存储介质
CN111626279B (zh) * 2019-10-15 2023-06-02 西安网算数据科技有限公司 一种负样本标注训练方法及高度自动化的票据识别方法
CN111199244B (zh) * 2019-12-19 2024-04-09 北京航天测控技术有限公司 一种数据的分类方法、装置、存储介质及电子装置
CN111275183B (zh) * 2020-01-14 2023-06-16 北京迈格威科技有限公司 视觉任务的处理方法、装置和电子系统
CN112785441B (zh) * 2020-04-20 2023-12-05 招商证券股份有限公司 数据处理方法、装置、终端设备及存储介质
CN111783831B (zh) * 2020-05-29 2022-08-05 河海大学 基于多源多标签共享子空间学习的复杂图像精确分类方法
CN112132188B (zh) * 2020-08-31 2024-04-16 浙江工业大学 一种基于网络属性的电商用户分类方法
CN112365931B (zh) * 2020-09-18 2024-04-09 昆明理工大学 一种用于预测蛋白质功能的数据多标签分类方法
CN112182214B (zh) * 2020-09-27 2024-03-19 中国建设银行股份有限公司 一种数据分类方法、装置、设备及介质
CN112598076B (zh) * 2020-12-29 2023-09-19 北京易华录信息技术股份有限公司 一种机动车属性识别方法及系统
CN112766383A (zh) * 2021-01-22 2021-05-07 浙江工商大学 一种基于特征聚类和标签相似性的标签增强方法
CN113034406B (zh) * 2021-04-27 2024-05-14 中国平安人寿保险股份有限公司 扭曲文档恢复方法、装置、设备及介质
CN113469225B (zh) * 2021-06-16 2024-03-22 浙江工业大学 基于跨域特征相关性分析的图像转换方法
CN113326698B (zh) * 2021-06-18 2023-05-09 深圳前海微众银行股份有限公司 检测实体关系的方法、模型训练方法及电子设备
CN113920210B (zh) * 2021-06-21 2024-03-08 西北工业大学 基于自适应图学习主成分分析方法的图像低秩重构方法
CN113327666B (zh) * 2021-06-21 2022-08-12 青岛科技大学 一种胸片疾病多分类网络的多标签局部至全局学习方法
CN114401205B (zh) * 2022-01-21 2024-01-16 中国人民解放军国防科技大学 无标注多源网络流量数据漂移检测方法和装置
CN114648635A (zh) * 2022-03-15 2022-06-21 安徽工业大学 一种融合标签间强相关性的多标签图像分类方法
CN115797709B (zh) * 2023-01-19 2023-04-25 苏州浪潮智能科技有限公司 一种图像分类方法、装置、设备和计算机可读存储介质
CN115934809B (zh) * 2023-03-08 2023-07-18 北京嘀嘀无限科技发展有限公司 一种数据处理方法、装置和电子设备
CN117876797A (zh) * 2024-03-11 2024-04-12 中国地质大学(武汉) 图像多标签分类方法、装置及存储介质

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150039613A1 (en) * 2013-07-31 2015-02-05 Linkedln Corporation Framework for large-scale multi-label classification
CN104899596A (zh) * 2015-03-16 2015-09-09 景德镇陶瓷学院 一种多标签分类方法及其装置
CN105320967A (zh) * 2015-11-04 2016-02-10 中科院成都信息技术股份有限公司 基于标签相关性的多标签AdaBoost集成方法
US20160140451A1 (en) * 2014-11-17 2016-05-19 Yahoo! Inc. System and method for large-scale multi-label learning using incomplete label assignments

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US8379994B2 (en) * 2010-10-13 2013-02-19 Sony Corporation Digital image analysis utilizing multiple human labels
CN105825502B (zh) * 2016-03-12 2018-06-15 浙江大学 一种基于显著性指导的词典学习的弱监督图像解析方法
CN107292322B (zh) * 2016-03-31 2020-12-04 华为技术有限公司 一种图像分类方法、深度学习模型及计算机系统

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150039613A1 (en) * 2013-07-31 2015-02-05 Linkedln Corporation Framework for large-scale multi-label classification
US20160140451A1 (en) * 2014-11-17 2016-05-19 Yahoo! Inc. System and method for large-scale multi-label learning using incomplete label assignments
CN104899596A (zh) * 2015-03-16 2015-09-09 景德镇陶瓷学院 一种多标签分类方法及其装置
CN105320967A (zh) * 2015-11-04 2016-02-10 中科院成都信息技术股份有限公司 基于标签相关性的多标签AdaBoost集成方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
姚红革等: "基于小波分析和BP神经网络的图像特征提取", 《西安工业大学学报》 *
王臻: "基于学习标签相关性的多标签分类算法", 《中国科学技术大学 硕士论文》 *

Cited By (24)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110929785B (zh) * 2019-11-21 2023-12-05 中国科学院深圳先进技术研究院 数据分类方法、装置、终端设备及可读存储介质
CN110929785A (zh) * 2019-11-21 2020-03-27 中国科学院深圳先进技术研究院 数据分类方法、装置、终端设备及可读存储介质
CN113269215B (zh) * 2020-02-17 2023-08-01 百度在线网络技术(北京)有限公司 一种训练集的构建方法、装置、设备和存储介质
CN113269215A (zh) * 2020-02-17 2021-08-17 百度在线网络技术(北京)有限公司 一种训练集的构建方法、装置、设备和存储介质
CN113496232A (zh) * 2020-03-18 2021-10-12 杭州海康威视数字技术股份有限公司 标签校验方法和设备
CN113496232B (zh) * 2020-03-18 2024-05-28 杭州海康威视数字技术股份有限公司 标签校验方法和设备
CN111524187A (zh) * 2020-04-22 2020-08-11 北京三快在线科技有限公司 一种视觉定位模型的训练方法及装置
CN111667399A (zh) * 2020-05-14 2020-09-15 华为技术有限公司 风格迁移模型的训练方法、视频风格迁移的方法以及装置
CN111667399B (zh) * 2020-05-14 2023-08-25 华为技术有限公司 风格迁移模型的训练方法、视频风格迁移的方法以及装置
CN111797910A (zh) * 2020-06-22 2020-10-20 浙江大学 一种基于平均偏汉明损失的多维标签预测方法
CN111797910B (zh) * 2020-06-22 2023-04-07 浙江大学 一种基于平均偏汉明损失的多维标签预测方法
CN112465126A (zh) * 2020-07-27 2021-03-09 国电内蒙古东胜热电有限公司 用于跑冒滴漏检测的加载预训练卷积网络检测方法及装置
CN112465126B (zh) * 2020-07-27 2023-12-19 国电内蒙古东胜热电有限公司 用于跑冒滴漏检测的加载预训练卷积网络检测方法及装置
CN112215795B (zh) * 2020-09-02 2024-04-09 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112215795A (zh) * 2020-09-02 2021-01-12 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112353402A (zh) * 2020-10-22 2021-02-12 平安科技(深圳)有限公司 心电信号分类模型的训练方法、心电信号分类方法及装置
CN113222068B (zh) * 2021-06-03 2022-12-27 西安电子科技大学 基于邻接矩阵指导标签嵌入的遥感图像多标签分类方法
CN113222068A (zh) * 2021-06-03 2021-08-06 西安电子科技大学 基于邻接矩阵指导标签嵌入的遥感图像多标签分类方法
CN113076426B (zh) * 2021-06-07 2021-08-13 腾讯科技(深圳)有限公司 多标签文本分类及模型训练方法、装置、设备及存储介质
CN113076426A (zh) * 2021-06-07 2021-07-06 腾讯科技(深圳)有限公司 多标签文本分类及模型训练方法、装置、设备及存储介质
CN115205011A (zh) * 2022-06-15 2022-10-18 海南大学 基于lsf-fc算法的银行用户画像模型生成方法
CN115205011B (zh) * 2022-06-15 2023-08-08 海南大学 基于lsf-fc算法的银行用户画像模型生成方法
CN115841596A (zh) * 2022-12-16 2023-03-24 华院计算技术(上海)股份有限公司 多标签图像分类方法及其模型的训练方法、装置
CN115841596B (zh) * 2022-12-16 2023-09-15 华院计算技术(上海)股份有限公司 多标签图像分类方法及其模型的训练方法、装置

Also Published As

Publication number Publication date
WO2019100723A1 (zh) 2019-05-31
CN109840531B (zh) 2023-08-25

Similar Documents

Publication Publication Date Title
CN109840531B (zh) 训练多标签分类模型的方法和装置
WO2019100724A1 (zh) 训练多标签分类模型的方法和装置
US11875268B2 (en) Object recognition with reduced neural network weight precision
US11256960B2 (en) Panoptic segmentation
CN111950453B (zh) 一种基于选择性注意力机制的任意形状文本识别方法
US11776092B2 (en) Color restoration method and apparatus
US10275719B2 (en) Hyper-parameter selection for deep convolutional networks
CN112446270B (zh) 行人再识别网络的训练方法、行人再识别方法和装置
US20210295089A1 (en) Neural network for automatically tagging input image, computer-implemented method for automatically tagging input image, apparatus for automatically tagging input image, and computer-program product
CN111523621A (zh) 图像识别方法、装置、计算机设备和存储介质
CN112639828A (zh) 数据处理的方法、训练神经网络模型的方法及设备
CN113326930B (zh) 数据处理方法、神经网络的训练方法及相关装置、设备
CN110222718B (zh) 图像处理的方法及装置
CN113095370A (zh) 图像识别方法、装置、电子设备及存储介质
JP2010157118A (ja) パターン識別装置及びパターン識別装置の学習方法ならびにコンピュータプログラム
KR20210093875A (ko) 비디오 분석 방법 및 연관된 모델 훈련 방법, 기기, 장치
CN112749737A (zh) 图像分类方法及装置、电子设备、存储介质
CN115018039A (zh) 一种神经网络蒸馏方法、目标检测方法以及装置
CN116863194A (zh) 一种足溃疡图像分类方法、系统、设备及介质
CN116844032A (zh) 一种海洋环境下目标检测识别方法、装置、设备及介质
US11842540B2 (en) Adaptive use of video models for holistic video understanding
CN113516182B (zh) 视觉问答模型训练、视觉问答方法和装置
CN112396069B (zh) 基于联合学习的语义边缘检测方法、装置、系统及介质
CN117372798A (zh) 一种模型训练方法及相关装置
Vedachalam Pixelwise Classification of Agricultural Crops in Aerial Imagery Using Deep Learning Methods

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant