CN113919499A - 模型训练方法与模型训练系统 - Google Patents

模型训练方法与模型训练系统 Download PDF

Info

Publication number
CN113919499A
CN113919499A CN202111403181.9A CN202111403181A CN113919499A CN 113919499 A CN113919499 A CN 113919499A CN 202111403181 A CN202111403181 A CN 202111403181A CN 113919499 A CN113919499 A CN 113919499A
Authority
CN
China
Prior art keywords
model
hourglass
distillation
loss value
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111403181.9A
Other languages
English (en)
Inventor
李蛟
张朝晋
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Via Technologies Inc
Original Assignee
Via Technologies Inc
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Via Technologies Inc filed Critical Via Technologies Inc
Priority to CN202111403181.9A priority Critical patent/CN113919499A/zh
Priority to TW111100068A priority patent/TWI793951B/zh
Publication of CN113919499A publication Critical patent/CN113919499A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提出一种模型训练方法与模型训练系统,该方法包含:将一未标注数据输入至具有一深度神经网络架构的一模型,其中该模型至少包含一堆叠沙漏网络,该堆叠沙漏网络包含多个沙漏网络;针对该未标注数据,分别得到该多个沙漏网络的多个第一蒸馏特征层输出;根据一第一损失函数与该多个第一蒸馏特征层输出,来计算一第一损失值;以及依据该第一损失值来进行反向传播,以调整该模型的参数。由此,可以有效地减少训练样本的标注需求,节省大量人力与时间。

Description

模型训练方法与模型训练系统
技术领域
本发明有关于模型训练,尤指一种可通过未标注数据及蒸馏损失来对模型中的堆叠沙漏网络进行半监督训练以及可基于硬件平台的模型部署需求来适应性地对模型中的堆叠沙漏网络进行裁剪的方法与系统。
背景技术
关于采用人工智能技术来进行分类的应用(例如影像辨识),由于实际环境中光照、相机、遮挡等因素的影响和分类的多样性,导致实际使用的模型进行训练所需要的数据量会十分庞大,因而加重人工进行样本标注的困难。另外,不同的硬件平台会有不同的模型部署需求(例如存储器容量、运算能力以及演算即时性),故不同的硬件平台往往需要分别部署不同大小的模型。
因此,需要一种采用半监督学习机制且能够根据硬件平台的模型部署需求来适应性地进行裁剪的模型以及相关的模型训练方法与系统。
发明内容
因此,本发明的目的之一在于提出一种可通过未标注数据及蒸馏损失来对模型中的堆叠沙漏网络进行半监督训练以及可基于硬件平台的模型部署需求来适应性地对模型中的堆叠沙漏网络进行裁剪的方法与系统。
在本发明的一实施例中,揭露一种模型训练方法。该模型训练方法包含:将一未标注数据输入至具有一深度神经网络架构的一模型,其中该模型至少包含一堆叠沙漏网络,该堆叠沙漏网络包含多个沙漏网络;针对该未标注数据,分别得到该多个沙漏网络的多个第一蒸馏特征层输出;根据一第一损失函数与该多个第一蒸馏特征层输出,来计算一第一损失值;以及依据该第一损失值来进行反向传播,以调整该模型的参数。
在本发明的另一实施例中,揭露一种模型训练系统。该模型训练系统包含一储存装置以及一处理器。该储存装置用以储存一程序码。该处理器用以载入并执行该程序码,以执行以下操作:将一未标注数据输入至具有一深度神经网络架构的一模型,其中该模型至少包含一堆叠沙漏网络,该堆叠沙漏网络包含多个沙漏网络;针对该未标注数据,分别得到该多个沙漏网络的多个第一蒸馏特征层输出;根据一第一损失函数与该多个第一蒸馏特征层输出,来计算一第一损失值;以及依据该第一损失值来进行反向传播,以调整该模型的参数。
本发明模型训练方法与系统所采用的半监督学习机制可使用老师-学生知识蒸馏的机制,因此可以有效地减少训练样本的标注需求,节省大量人力与时间。另外,本发明模型训练方法与系统可根据硬件平台的模型部署需求,适应性地裁剪所训练的模型中的堆叠沙漏网络以产生部署至硬件平台的固化模型,而不需要重新训练。
附图说明
图1为根据本发明一实施例的模型训练系统的示意图。
图2为根据本发明一实施例的模型的示意图。
图3为根据本发明一实施例的CBL模块的示意图。
图4为根据本发明一实施例的沙漏网络结构的示意图。
其中,附图中符号的简单说明如下:
100:模型训练系统;102:处理器;104:储存装置;110:硬件平台;202:调整网络;204:堆叠沙漏网络;206_1、206_2、206_3、206_4:沙漏网络;208:Maxout网络;210:逐元素最大值函数层;300:CBL模块;302:卷积层;304:批标准化层;306:带泄漏修正线性单元层;402、404、406、408、412、414、416、418、421、422、423、424、425、426:模块;Code_TF:程序码;MD:模型;MD_F:固化模型;D_IN:训练数据;D_L:标注数据;D_UL:未标注数据;A1、A2、A3、A4:蒸馏特征层输出;L1、L2、L3、L4:分类输出;L_OUT:最终分类输出。
具体实施方式
图1为根据本发明一实施例的模型训练系统的示意图。如图1所示,模型训练系统100包含(但不限于)一处理器102以及一储存装置104。储存装置104用以储存一程序码Code_TF,例如储存装置104可以是传统硬盘、固态硬盘、存储器等等,但本发明并不以此为限。处理器102可载入并执行程序码Code_TF,来实现本发明模型训练方法。根据本发明模型训练方法,会对具有一深度神经网络(deep neural network)架构的一模型MD进行训练,并且于模型MD训练达到一定程度之后,对模型MD进行固化(freeze)来产生一固化模型(frozen model)MD_F,以便后续将固化模型MD_F部署至一硬件平台110,例如硬件平台110可以是手机、边缘装置(edge device)等等。本实施例中,模型训练方法会采用半监督学习机制,换言之,模型MD会根据训练数据D_IN(例如是影像数据)来进行半监督训练,因此,训练数据D_IN中少量数据是以人工方式进行标注的标注数据(labeled data)D_L,而训练数据D_IN中大量数据则是未标注数据(unlabeled data)D_UL。此外,模型训练方法还会参照硬件平台110的模型部署需求(例如存储器容量、运算能力及/或演算即时性),适应性地(adaptively)裁剪所训练的模型MD以产生部署至硬件平台110的固化模型MD_F。
图2为根据本发明一实施例的模型的示意图。本实施例中,利用训练数据D_IN来进行半监督训练的模型MD可采用图2所示的网络架构。如图2所示,模型MD包含(但不限于)一调整网络(resize network)202、一堆叠沙漏网络(stacked hourglass network)204、一Maxout网络208以及一逐元素最大值(element-wise maximum)函数层210。调整网络202是用来进行维度调整,例如降低调整网络202最后输出的每一通道(channel)的特征图(feature map)的大小,假设输入影像的尺寸为W*H(单位是像素),则可通过调整网络202来将输入变成(W/s)*(H/s),例如s=4。
本实施例中,调整网络202可根据设计需求而由一个或多个CBL模块构成。图3为根据本发明一实施例的CBL模块的示意图。CBL模块300包含(但不限于)一卷积(convolution)层302、一批标准化(batch normalization,BN)层304以及一带泄漏修正线性单元(LeakyRectified Linear Unit,Leaky ReLU)层306。举例来说,卷积层302可采用步长(stride)为2,假若输入影像的尺寸为W*H,则每经过一个卷积层302的处理,便可分别让W与H减半(亦即W/2与H/2)。批标准化层304可以加快模型的训练速度,并让训练更加稳定。带泄漏修正线性单元层306是启动函数层,用以保留正值,并将负值替换为0。
调整网络202可以根据输入的大小和网络的输出个数来调节CBL模块的个数。假若输入为640*360以及检测目标为8*8的大小,则调整网络202可以使用串接的3个CBL模块300(卷积层302的步长为2),来得到80x44的特征图大小。假若使用更小的输入320*180,此时检测目标的大小仅有4*4,则调整网络202可以使用串接的2个CBL模块300(卷积层302的步长为2),来得到80x44的特征图大小。
堆叠沙漏网络204由多个沙漏网络串接构成,如图2所示,本实施例所采用的堆叠沙漏网络204包含4个沙漏网络206_1、206_2、206_3、206_4,其中每个沙漏网络具有相同的网络结构。请注意,堆叠沙漏网络204由4个沙漏网络构成仅作为范例说明之用,并非作为本发明的限制,实际上,堆叠沙漏网络204可由K个沙漏网络构成,其中K可以是任何不小于2的正整数(亦即K≧2)。图4为根据本发明一实施例的沙漏网络结构的示意图。于一实施方式中,每一个沙漏网络206_1、206_2、206_3、206_4可采用图4所示的网络结构。多个模块402、404、406、408中的每一个模块代表对输入特征图进行降采样(down-sample),使得输出特征图的大小会小于输入特征图的大小。多个模块412、414、416、418中的每一个模块代表对输入特征图进行升采样(up-sample),使得输出特征图的大小会大于输入特征图的大小。另外,多个模块421、422、423、424、425、426中的每一个模块则代表输出特征图与输入特征图会具有相同大小。
通过调整网络202可以得到需要的特征图大小,亦即通过CBL模块的适当设计,每一通道(每一类别)的特征图会具有所要的大小,为了得到更好的语义信息,沙漏网络可以继续降维并且将语义信息较好的低维特征融合到最终的输出特征图,如图4所示,沙漏网络的输出特征图与输入特征图会具有相同大小,因此理论上可以无限堆叠,这里的堆叠并不影响输出的物理意义,但是通过堆叠可以提高模型输出的准确度。由于堆叠沙漏网络204中串接的沙漏网络越多,则特征侦测的效果会越好,因此,本发明模型训练方法所采用的半监督学习机制可使用老师-学生知识蒸馏(teacher-student knowledge distillation),将最后一个沙漏网络206_4作为老师,来指导前面多个沙漏网络206_1、206_2、206_3的学习,换言之,最后一个沙漏网络206_4的输出结果可作为前面多个沙漏网络206_1、206_2、206_3的目标以进行训练。
沙漏网络206_1、206_2、206_3、206_4会分别产生蒸馏特征层输出A1、A2、A3、A4,其中每个蒸馏特征层输出包含多个通道(多个类别)的相对应特征图。Maxout网络208会对蒸馏特征层输出A1、A2、A3、A4进行处理来分别产生分类输出L1、L2、L3、L4,Maxout网络208可视为启动函数层,它的输出是一组输入的最大值,亦即,蒸馏特征层输出A1经由启动函数(亦即Maxout函数)而得到分类输出L1,蒸馏特征层输出A2经由启动函数(亦即Maxout函数)而得到分类输出L2,蒸馏特征层输出A3经由启动函数(亦即Maxout函数)而得到分类输出L3,以及蒸馏特征层输出A4经由启动函数(亦即Maxout函数)而得到分类输出L4。假设每个蒸馏特征层输出包含m个通道(m个类别)的相对应特征图,通过特征图的Maxout函数处理,则每一个分类输出会包含分别对应至m个通道(m个类别)的预测概率。逐元素最大值函数层210则是用来针对分类输出L1、L2、L3、L4进行逐元素取最大值的操作,以产生一最终分类输出L_OUT,举例来说,最终分类输出L_OUT中对应至第1个通道(第1个类别)的预测概率①是取分类输出L1、L2、L3、L4中对应至第1个通道(第1个类别)的所有预测概率①中的最大值,最终分类输出L_OUT中对应至第2个通道(第2个类别)的预测概率②是取分类输出L1、L2、L3、L4中对应至第2个通道(第2个类别)的所有预测概率②中的最大值,最终分类输出L_OUT中对应至第3个通道(第3个类别)的预测概率③是取分类输出L1、L2、L3、L4中对应至第3个通道(第3个类别)的所有预测概率③中的最大值,以此类推。在一实施例中,可以采用最终分类输出L_OUT中预测概率最大的类别作为输入影像的分类结果。
如前所述,本发明模型训练方法采用半监督学习机制来对模型MD进行训练,因此,会先使用人工建立的少量标注数据D_L来对模型MD进行训练,当模型MD训练到一定准确率之后,再通过大量的未标注数据D_L来对模型MD进行训练。当利用标注数据D_L来对模型MD进行训练时,本发明模型训练方法采用两种损失函数(loss function)来计算反向传播(back propagation)所要使用的损失值L,于本实施例中,反向传播所要使用损失值L主要由两个部分构成,分别是分类损失值Lclassify以及特征学习损失值Ldistillation
计算分类损失值Lclassify的损失函数是采用交叉熵(cross entropy)损失函数,并基于分类输出Li(例如图2所示的分类输出L1、L2、L3、L4)来决定分类损失值Lclassify,如下所示:
Figure BDA0003371805980000061
其中M为类别的数量(通道的数量);N为沙漏网络的个数,yic为样本i在c类别上的标签(label),Pic为样本i属于类别c的预测概率。预测概率就是网络学习的分类分数,例如得分在0.5以上就可以认为是某个类别。
计算特征学习损失值Ldistillation的损失函数是采用蒸馏损失函数(例如L2损失函数),并基于蒸馏特征层输出Ai(例如图2所示的蒸馏特征层输出A1、A2、A3、A4)来决定特征学习损失值Ldistillation,如下所示:
Figure BDA0003371805980000062
Figure BDA0003371805980000071
其中D表示平方和函数;An表示n个沙漏网络的蒸馏特征层输出,例如N=4时,Ldistillation=[F(A1)-F(A4)]2+[F(A2)-F(A4)]2+[F(A3)-F(A4)]2;Ani为An的第i个通道;S为spatial softmax函数。所有的操作均为逐元素操作,亦即,M个通道输出的绝对值平方之后相加,并进行空间唯独上的softmax运算,举例来说,通过spatial softmax函数,可以得到同一蒸馏特征层中每个通道的预测概率。
最终损失值L是由分类损失值Lclassify以及特征学习损失值Ldistillation所组成,如下所示:
L=Lclassify+γLdistillation
其中γ为权重系数,一开始可以先将权重系数γ设为1,等到模型MD充分训练后,可以再根据分类损失值Lclassify以及特征学习损失值Ldistillation之间的比值来调整权重系数γ,使得分类损失值Lclassify以及特征学习损失值Ldistillation能维持在同一数量级,例如权重系数γ最终可设为0.2。
当改用未标注数据D_UL来对模型MD进行训练时,由于缺乏人工建立的标签,本发明模型训练方法仅采用单一损失函数来计算反向传播所需使用的损失值L,本实施例中,反向传播所需使用的损失值L仅由特征学习损失值Ldistillation构成,而不包含分类损失值Lclassify,此时,通过最后一个沙漏网络206_4的输出结果来作为真实标签,并采用特征学习损失值Ldistillation(亦即以最后一个沙漏网络206_4的输出结果作为前面多个沙漏网络206_1、206_2、206_3的目标来计算损失),来指导前面多个沙漏网络206_1、206_2、206_3的学习,达到半监督训练/学习的效果。
每次训练时,本发明模型训练方法可采用Adam优化演算法(Adam optimization)来进行反向传播,以调整模型MD的参数(例如每一层的权重值),举例来说,当使用未标注数据D_UL来对模型MD进行训练时,可判断特征学习损失值Ldistillation与一阈值(threshold)的关系,如果大于阈值,则将特征学习损失值Ldistillation的大小回传至模型MD中以调整F(An)值的大小,如此不断回圈,直到特征学习损失值Ldistillation小于阈值为止。
堆叠沙漏网络204是由多个沙漏网络206_1、206_2、206_3、206_4串接而构成,本实施例中,这些沙漏网络206_1、206_2、206_3、206_4是由相同的损失函数来进行训练,因此,本发明模型训练方法可直接通过剪裁来获得不同大小的模型,例如控制CBL模块及/或沙漏网络的数量,自我调整模型大小,进而适应具有不同模型部署需求的硬件平台。
假设硬件平台110具有第一模型部署需求(例如存储器容量M1、运算能力P1及/或演算即时性R1),则在对模型MD进行固化来产生部署至硬件平台110的固化模型MD_F时,可以将沙漏网络206_2、206_3、206_4去掉,只取通过沙漏网络206_1所得到的分类输出L1,请注意,由于仅有一个分类输出L1,因此,固化模型MD_F还可省略逐元素最大值函数层210,故硬件平台110实际执行固化模型MD_F所定义的分类网络时,分类输出L1会直接作为最终分类输出L_OUT。
假设硬件平台110具有第二模型部署需求(例如存储器容量M2(M2>M1)、运算能力P2(P2>P1)以/或演算即时性R2(R2<R1)),则在对模型MD进行固化来产生部署至硬件平台110的固化模型MD_F时,可以将沙漏网络206_3、206_4去掉,只取通过沙漏网络206_1、206_2所得到的分类输出L1、L2,硬件平台110实际执行固化模型MD_F所定义的分类网络时,分类输出L1、L2后续可通过逐元素最大值函数层210来得到最终分类输出L_OUT。
假设硬件平台110具有第三模型部署需求(例如存储器容量M3(M3>M2)、运算能力P3(P3>P2)及/或演算即时性R3(R3<R2)),则在对模型MD进行固化来产生部署至硬件平台110的固化模型MD_F时,可以将沙漏网络206_4去掉,只取通过沙漏网络206_1、206_2、206_3所得到的分类输出L1、L2、L3,硬件平台110实际执行固化模型MD_F所定义的分类网络时,分类输出L1、L2、L3后续可通过逐元素最大值函数层210来得到最终分类输出L_OUT。
综上所述,本发明模型训练方法与系统所采用的半监督学习机制可使用老师-学生知识蒸馏的机制,因此可以有效地减少训练样本的标注需求,节省大量人力与时间。另外,本发明模型训练方法与系统可根据硬件平台的模型部署需求,适应性地裁剪所训练的模型中的堆叠沙漏网络以产生部署至硬件平台的固化模型,而不需要重新训练。
以上所述仅为本发明较佳实施例,然其并非用以限定本发明的范围,任何熟悉本项技术的人员,在不脱离本发明的精神和范围内,可在此基础上做进一步的改进和变化,因此本发明的保护范围当以本申请的权利要求书所界定的范围为准。

Claims (10)

1.一种模型训练方法,其特征在于,包含:
将未标注数据输入至具有深度神经网络架构的模型,其中该模型至少包含堆叠沙漏网络,该堆叠沙漏网络包含多个沙漏网络;
针对该未标注数据,分别得到该多个沙漏网络的多个第一蒸馏特征层输出;
根据第一损失函数与该多个第一蒸馏特征层输出,来计算第一损失值;以及
依据该第一损失值来进行反向传播,以调整该模型的参数。
2.如权利要求1所述的模型训练方法,其中该第一损失函数采用蒸馏损失函数,其以该多个沙漏网络中最后一个沙漏网络所产生的第一蒸馏特征层输出作为该多个沙漏网络中其它沙漏网络所产生的第一蒸馏特征层输出的目标来计算该第一损失值。
3.如权利要求1所述的模型训练方法,还包含:
于该未标注数据输入至该模型之前:
将标注数据输入至该模型;
针对该标注数据,分别得到该多个沙漏网络的多个第二蒸馏特征层输出;
通过该多个第二蒸馏特征层输出,来分别得到多个分类输出;
根据该第一损失函数与该多个第二蒸馏特征层输出,来计算第二损失值;
根据第二损失函数与该多个分类输出,来计算第三损失值,其中该第二损失函数不同于该第一损失函数;
依据该第二损失值与该第三损失值来计算第四损失值;以及
依据该第四损失值来进行反向传播,以调整该模型的参数。
4.如权利要求3所述的模型训练方法,其中该第二损失函数采用交叉熵损失函数。
5.如权利要求1所述的模型训练方法,还包含:
参照硬件平台的模型部署需求,适应性地裁剪所训练的该模型中的该堆叠沙漏网络以产生部署至该硬件平台的固化模型;
其中该固化模型所包含的沙漏网络的个数取决于该硬件平台的模型部署需求。
6.一种模型训练系统,其特征在于,包含:
储存装置,用以储存程序码;以及
处理器,用以载入并执行该程序码,以执行以下操作:
将未标注数据输入至具有深度神经网络架构的模型,其中该模型至少包含堆叠沙漏网络,该堆叠沙漏网络包含多个沙漏网络;
针对该未标注数据,分别得到该多个沙漏网络的多个第一蒸馏特征层输出;
根据第一损失函数与该多个第一蒸馏特征层输出,来计算第一损失值;以及
依据该第一损失值来进行反向传播,以调整该模型的参数。
7.如权利要求6所述的模型训练系统,其中该第一损失函数采用蒸馏损失函数,其以该多个沙漏网络中最后一个沙漏网络所产生的第一蒸馏特征层输出作为该多个沙漏网络中其它沙漏网络所产生的第一蒸馏特征层输出的目标来计算该第一损失值。
8.如权利要求6所述的模型训练系统,其中该处理器还执行该程序码,以执行以下操作:
于该未标注数据输入至该模型之前:
将标注数据输入至该模型;
针对该标注数据,分别得到该多个沙漏网络的多个第二蒸馏特征层输出;
通过该多个第二蒸馏特征层输出,来分别得到多个分类输出;
根据该第一损失函数与该多个第二蒸馏特征层输出,来计算第二损失值;
根据第二损失函数与该多个分类输出,来计算第三损失值,其中该第二损失函数不同于该第一损失函数;
依据该第二损失值与该第三损失值来计算第四损失值;以及
依据该第四损失值来进行反向传播,以调整该模型的参数。
9.如权利要求8所述的模型训练系统,其中该第二损失函数采用交叉熵损失函数。
10.如权利要求6所述的模型训练系统,其中该处理器还执行该程序码,以执行以下操作:
参照硬件平台的模型部署需求,适应性地裁剪所训练的该模型中的该堆叠沙漏网络以产生部署至该硬件平台的固化模型;
其中该固化模型所包含的沙漏网络的个数取决于该硬件平台的模型部署需求。
CN202111403181.9A 2021-11-24 2021-11-24 模型训练方法与模型训练系统 Pending CN113919499A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202111403181.9A CN113919499A (zh) 2021-11-24 2021-11-24 模型训练方法与模型训练系统
TW111100068A TWI793951B (zh) 2021-11-24 2022-01-03 模型訓練方法與模型訓練系統

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111403181.9A CN113919499A (zh) 2021-11-24 2021-11-24 模型训练方法与模型训练系统

Publications (1)

Publication Number Publication Date
CN113919499A true CN113919499A (zh) 2022-01-11

Family

ID=79247901

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111403181.9A Pending CN113919499A (zh) 2021-11-24 2021-11-24 模型训练方法与模型训练系统

Country Status (2)

Country Link
CN (1) CN113919499A (zh)
TW (1) TWI793951B (zh)

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110298437B (zh) * 2019-06-28 2021-06-01 Oppo广东移动通信有限公司 神经网络的分割计算方法、装置、存储介质及移动终端
CN111339846B (zh) * 2020-02-12 2022-08-12 深圳市商汤科技有限公司 图像识别方法及装置、电子设备和存储介质
CN112579777B (zh) * 2020-12-23 2023-09-19 华南理工大学 一种未标注文本的半监督分类方法
CN112949786B (zh) * 2021-05-17 2021-08-06 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质

Also Published As

Publication number Publication date
TW202321993A (zh) 2023-06-01
TWI793951B (zh) 2023-02-21

Similar Documents

Publication Publication Date Title
CN109711481B (zh) 用于画作多标签识别的神经网络、相关方法、介质和设备
US11270187B2 (en) Method and apparatus for learning low-precision neural network that combines weight quantization and activation quantization
US10460230B2 (en) Reducing computations in a neural network
CN107977707B (zh) 一种对抗蒸馏神经网络模型的方法及计算设备
US11657254B2 (en) Computation method and device used in a convolutional neural network
CN111626330A (zh) 基于多尺度特征图重构和知识蒸馏的目标检测方法与系统
JP2019032808A (ja) 機械学習方法および装置
CN110298394B (zh) 一种图像识别方法和相关装置
CN113610232A (zh) 网络模型量化方法、装置、计算机设备以及存储介质
WO2022217853A1 (en) Methods, devices and media for improving knowledge distillation using intermediate representations
WO2021042857A1 (zh) 图像分割模型的处理方法和处理装置
US20230169332A1 (en) Method and system for machine learning from imbalanced data with noisy labels
CN112598062A (zh) 一种图像识别方法和装置
CN114463727A (zh) 一种地铁驾驶员行为识别方法
CN114299304A (zh) 一种图像处理方法及相关设备
CN112270334B (zh) 一种基于异常点暴露的少样本图像分类方法及系统
Wong et al. Real-time adaptive hand motion recognition using a sparse bayesian classifier
CN113919499A (zh) 模型训练方法与模型训练系统
CN113378866B (zh) 图像分类方法、系统、存储介质及电子设备
US20200372363A1 (en) Method of Training Artificial Neural Network Using Sparse Connectivity Learning
WO2022098307A1 (en) Context-aware pruning for semantic segmentation
CN115080699A (zh) 基于模态特异自适应缩放与注意力网络的跨模态检索方法
Lee et al. MPQ-YOLACT: Mixed-Precision Quantization for Lightweight YOLACT
CN112560760A (zh) 一种注意力辅助的无监督视频摘要系统
CN114782779B (zh) 基于特征分布迁移的小样本图像特征学习方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination