CN114494776A - 一种模型训练方法、装置、设备以及存储介质 - Google Patents
一种模型训练方法、装置、设备以及存储介质 Download PDFInfo
- Publication number
- CN114494776A CN114494776A CN202210082301.8A CN202210082301A CN114494776A CN 114494776 A CN114494776 A CN 114494776A CN 202210082301 A CN202210082301 A CN 202210082301A CN 114494776 A CN114494776 A CN 114494776A
- Authority
- CN
- China
- Prior art keywords
- feature extraction
- network
- sample
- feature
- different
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 117
- 238000000034 method Methods 0.000 title claims abstract description 66
- 238000000605 extraction Methods 0.000 claims abstract description 338
- 238000004821 distillation Methods 0.000 claims abstract description 112
- 238000010586 diagram Methods 0.000 claims abstract description 53
- 238000001514 detection method Methods 0.000 claims abstract description 26
- 238000013528 artificial neural network Methods 0.000 claims description 26
- 238000004590 computer program Methods 0.000 claims description 9
- 238000012545 processing Methods 0.000 abstract description 12
- 238000013473 artificial intelligence Methods 0.000 abstract description 8
- 238000013135 deep learning Methods 0.000 abstract description 3
- 230000006870 function Effects 0.000 description 31
- 238000005516 engineering process Methods 0.000 description 23
- 238000004891 communication Methods 0.000 description 8
- 238000013140 knowledge distillation Methods 0.000 description 7
- 239000000284 extract Substances 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000007499 fusion processing Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 239000011159 matrix material Substances 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 101000926525 Homo sapiens eIF-2-alpha kinase GCN2 Proteins 0.000 description 1
- 101000862627 Homo sapiens eIF-2-alpha kinase activator GCN1 Proteins 0.000 description 1
- 241001465754 Metazoa Species 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 102100034175 eIF-2-alpha kinase GCN2 Human genes 0.000 description 1
- 102100030495 eIF-2-alpha kinase activator GCN1 Human genes 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000007726 management method Methods 0.000 description 1
- 230000003924 mental process Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种模型训练方法、装置、设备以及存储介质,涉及人工智能技术领域,尤其涉及深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景领域。具体实现方案为:将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;根据所述第一蒸馏损失,对所述学生特征提取网络进行训练。能够提高对学生特征提取网络训练的精准性。
Description
技术领域
本公开涉及人工智能技术领域,尤其涉及深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景。
背景技术
随着人工智能技术的发展,知识蒸馏技术在模型训练过程中的应用越来越广泛。其中,知识蒸馏是一种采用预先训练好的结构复杂的老师模型(Teacher Model)来训练结构简单的学生模型(Student Model),以实现将老师模型的功能赋予学生模型的技术,那么,如何基于知识蒸馏技术,高精度的训练学生模型至关重要。
发明内容
本公开提供了一种模型训练方法、装置、设备以及存储介质。
根据本公开的一方面,提供了一种模型训练方法,包括:
将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图;其中,特征提取网络包括老师特征提取网络和学生特征提取网络;
根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;
根据第一蒸馏损失,对学生特征提取网络进行训练。
根据本公开的另一方面,提供了一种电子设备,该电子设备包括:
至少一个处理器;以及
与至少一个处理器通信连接的存储器;其中,
存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行本公开任一实施例的模型训练方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,计算机指令用于使计算机执行本公开任一实施例的模型训练方法。
本公开实施例的方案,给出了一种基于蒸馏技术进行模型训练的优选方案,极大的提高了对学生特征提取网络训练的精准性。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开实施例提供的一种模型训练方法的流程图;
图2是根据本公开实施例提供的一种模型训练方法的流程图;
图3是根据本公开实施例提供的一种模型训练方法的流程图;
图4是根据本公开实施例提供的一种模型训练方法的流程图;
图5是根据本公开实施例提供的一种模型训练方法的流程图;
图6是根据本公开实施例提供的一种模型训练方法的流程图;
图7是根据本公开实施例提供的一种模型训练的原理框图;
图8是根据本公开实施例提供的一种模型训练装置的结构示意图;
图9是用来实现本公开实施例的模型训练方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
图1是根据本公开实施例提供的一种模型训练方法的流程图;本公开实施例适用于基于蒸馏技术进行模型训练的情况。尤其适用于基于蒸馏技术,训练目标检测场景中的特征提取网络的情况。该方法可以由模型训练装置来执行,该装置可以采用软件和/或硬件的方式实现。如图1所示,本实施例提供的模型训练方法可以包括:
S101,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,样本图像可以是模型训练时使用的包含有至少两个目标对象的图像。该目标对象可以是样本图像中的前景对象。可选的,样本图像中包含的至少两个目标对象可以属于同一类别,也可以属于不同类别。例如,至少两个目标对象可以都是人,也可以包含人、动物和植物等。样本特征图可以是对样本图像进行特征提取得到的,其可以通过矩阵的形式来表示。
特征提取网络可以是用于执行图像特征提取任务的神经网络,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。其中,老师特征提取网络是预先训练好的结构复杂的特征提取网络,学生特征提取网络是未经训练的结构简单的特征提取网络。可选的,采用知识蒸馏技术,对学生特征提取网络进行训练,最终可使得学生特征提取网络的特征提取功能尽可能的接近老师特征提取网络。优选的,本实施例训练后的学生特征提取网络可运用到多类别的目标检测场景中。
可选的,在本实施例中,可以分别将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S102,根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失。
本实施例的样本图像中包含有至少两个目标对象,且至少两个目标对象所属类别可能相同也可能不同。相应的,至少两个目标对象之间的类别关系可以表征至少两个目标对象中,任意两目标对象是否属于同一类别的关系。
可选的,本实施例根据老师特征提取网络输出的样本特征图(即第一样本特征图),提取其中包含的至少两个目标对象,并分析两两目标对象是否属于同一类别,作为第一样本特征图中至少两个目标对象之间的类别关系(即第一类别关系)。例如,本实施例可以针对第一样本特征图中的两个目标对象,分析两者属于同一类别的概率,进而构建关系矩阵来表征第一类别关系。同理,根据学生特征提取网络输出的样本特征图(即第二样本特征图),确定该第二样本特征图中至少两个目标对象的类别关系(即第二类别关系)。进而根据第一类别关系和第二类别关系,确定第一蒸馏损失。
具体的,本实施例可以直接将第一类别关系和第二类别关系输入到预先设定的损失函数中,得到第一蒸馏损失;其中,损失函数可以是交叉熵损失函数,或者平方损失函数(即L2损失函数)等。本实施例还可以对第一类别关系和第二类别关系进行进一步处理(如与目标对象在特征图中对应的特征值相结合)后再输入到预先设定的损失函数中,得到第一蒸馏损失等,对此不进行限定。
S103,根据第一蒸馏损失,对学生特征提取网络进行训练。
可选的,本实施例可以采用第一蒸馏损失,对学生特征提取网络进行训练,不断优化学生特征提取网络中的网络参数。具体的,本实施例需要基于多组样本图像,基于上述方法对学生特征提取网络进行多次迭代训练,直至达到预设的训练停止条件,则停止调整学生特征提取网络的网络参数,得到经训练的学生特征提取网络。训练停止条件可以包括:训练次数达到预设次数,或者第一蒸馏损失收敛等。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,进而根据两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,并基于第一蒸馏损失来训练学生特征提取网络。本方案在基于知识蒸馏技术训练学生特征提取网络时,所用的蒸馏损失是基于不同特征提取网络对应的样本特征图中不同目标对象之间的类别关系确定的,使得基于该蒸馏损失训练后的学生特征提取网络在提取图像特征时,能够更好的体现目标对象之间的类别关系,提高了特征提取的精准性。此外,若将本实施例训练的学生特征提取网络应用到目标检测场景,能够更精准的完成目标检测任务。
图2是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例在上述实施例的基础上,进一步对如何根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失进行详细解释说明,如图2所示,本实施例提供的模型训练方法可以包括:
S201,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。
具体的,将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S202,根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示。
其中,目标对象的特征值可以是目标对象所属区域在样本特征图中对应的数值。对象关系表示可以表征样本特征图中不同目标对象特征之间的关系。
可选的,本实施例可以是根据老师特征提取网络对应的样本特征图(即第一样本特征图)中至少两个目标对象的特征值,以及该第一样本特征图中至少两个目标对象之间的类别关系(即第一类别关系),确定老师特征提取网络对应的对象关系表示(即第一对象关系表示);根据学生特征提取网络对应的样本特征图(即第二样本特征图)中至少两个目标对象的特征值,以及该第二样本特征图中至少两个目标对象之间的类别关系(即第二类别关系),确定学生特征提取网络对应的对象关系表示(即第二对象关系表示)。
具体的,本实施例针对每一特征提取网络(即老师特征提取网络或学生特征提取网络),确定其对应的对象关系表示的方式可以是:针对该特征提取网络对应的样本特征图中的每一目标对象,将该目标对象与其他目标对象之间的类别关系进行整合,并将整合结果作为该目标对象的类别关系权重,再结合目标对象的特征值,确定该目标对象在对象关系表示中对应的数值。例如,针对该特征提取网络对应的样本特征图中的每一目标对象,若该目标对象与其他目标对象之间的类别关系为该目标对象与其他目标对象属于同一类别的概率值,则此时可以将该目标对象与其他目标对象属于同一类别的概率值进行求和,并将求和结果作为该目标对象的类别关系权重,与该目标对象的特征值相乘,得到该目标对象在对象关系表示中对应的数值。
S203,根据不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。
具体的,本实施例可以直接将老师特征提取网络对应的对象关系表示(即第一对象关系表示),和学生特征提取网络对应的对象关系表示(即第二对象关系表示)输入到预先设定的损失函数中,得到第一蒸馏损失;其中,损失函数可以是交叉熵损失函数,或者平方损失函数(即L2损失函数)等。
S204,根据第一蒸馏损失,对学生特征提取网络进行训练。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,并根据两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,以及至少两个目标对象的特征值,确定不同特征提取网络对应的对象关系表示,进而基于不同特征提取网络对应的对象关系表示,确定第一蒸馏损失,基于第一蒸馏损失来训练学生特征提取网络。本方案在确定蒸馏损失时,同时考虑了不同特征提取网络对应的样本特征图中不同目标对象之间的类别关系,以及不同目标对象的特征值,使得确定的蒸馏损失不但能够体现目标对象之间的类别关系,还能够体现目标对象的自身特征值。基于该蒸馏损失对学生特征提取网络进行训练,极大的提高了学生特征提取网络的训练精度。
图3是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例在上述实施例的基础上,进一步对如何根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示进行详细解释说明,如图3所示,本实施例提供的模型训练方法可以包括:
S301,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。
具体的,将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S302,根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定不同特征提取网络对应的样本特征图中每一目标对象的目标关系。
其中,目标关系可以是从样本特征图中至少两个目标对象之间的所有类别关系中抽取出来的一部分类别关系。
可选的,本实施例中,可以是根据老师特征提取网络对应的第一样本特征图中至少两个目标对象之间的类别关系(即第一类别关系),确定老师特征提取网络对应的第一样本特征图中每一目标对象的目标关系(即第一目标关系);根据学生特征提取网络对应的第二样本特征图中至少两个目标对象之间的类别关系(即第二类别关系),确定学生特征提取网络对应的第二样本特征图中每一目标对象的目标关系(即第二目标关系)。
具体的,本实施例针对每一特征提取网络(即老师特征提取网络或学生特征提取网络),确定其对应的样本特征图中每一目标对象的目标关系的方式可以是:针对该特征提取网络对应的样本特征图中的每一目标对象,分析该目标对象与其他目标对象之间的类别关系,并从中抽取出表征与该目标对象类别相近的预设个数的类别关系作为目标关系。例如,针对该特征提取网络对应的样本特征图中的每一目标对象,若该目标对象与其他目标对象之间的类别关系为该目标对象与其他目标对象属于同一类别的概率值,则可以将该目标对象与其他目标对象属于同一类别的概率值按照从高到底的顺序进行排序,并选择排名靠前的预设个数的概率值作为目标关系。
S303,根据不同特征提取网络对应的样本特征图中每一目标对象的特征值,以及该目标对象的目标关系,确定不同特征提取网络对应的对象关系表示。
可选的,本实施例可以是根据老师特征提取网络对应的第一样本特征图中至少两个目标对象的特征值,以及该第一样本特征图中至少两个目标对象之间的第一目标关系,确定老师特征提取网络对应的对象关系表示(即第一对象关系表示);根据学生特征提取网络对应的第二样本特征图中至少两个目标对象的特征值,以及该第二样本特征图中至少两个目标对象之间的第二目标关系,确定学生特征提取网络对应的对象关系表示(即第二对象关系表示)。
具体的,本实施例针对每一特征提取网络(即老师特征提取网络或学生特征提取网络),确定其对应的对象关系表示的方式可以是:针对该特征提取网络对应的样本特征图中的每一目标对象,将该目标对象关联的各目标关系进行整合,并将整合结果作为该目标对象的类别关系权重,结合目标对象的特征值,确定该目标对象在对象关系表示中对应的数值。例如,针对该特征提取网络对应的样本特征图中的每一目标对象,将该目标对象对应的数值排名较高的预设个数的概率值进行求和,并将求和结果作为该目标对象的类别关系权重,与该目标对象的特征值相乘,得到该目标对象在对象关系表示中对应的数值。
S304,根据不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。
S305,根据第一蒸馏损失,对学生特征提取网络进行训练。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,并分别从两种不同特征提取网络对应的样本特征图中不同目标对象之间的类别关系中抽取出每一目标对象的目标关系,再结合每一目标对象的特征值,确定不同特征提取网络对应的对象关系表示,进而基于不同特征提取网络对应的对象关系表示,确定第一蒸馏损失,基于第一蒸馏损失来训练学生特征提取网络。本方案在确定蒸馏损失时,针对每一特征提取网络对应的样本特征图中的每一目标对象,从其对应的多个类别关系中抽取出一部分与该目标对象类别相近的目标关系来计算蒸馏损失,并非使用所有的类别关系,在目标对象数量较多的情况下,极大的降低了蒸馏损失的计算量,进而提高学生特征提取网络的训练效率。
图4是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例在上述实施例的基础上,进一步对如何确定不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系进行详细解释说明,如图4所示,本实施例提供的模型训练方法可以包括:
S401,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。
具体的,将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S402,采用类别关系预测网络确定不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。
其中,类别关系预测网络可以是用于预测不同目标对象之间的类别关系的神经网络。
可选的,本实施例可以是采用类别关系预测网络确定老师特征提取网络对应的第一样本特征图中不同目标对象属于同一类别的概率值,作为老师特征提取网络对应的第一样本特征图中至少两个目标对象之间的类别关系(即第一类别关系);再采用类别关系预测网络确定学生特征提取网络对应的第二样本特征图中不同目标对象属于同一类别的概率值,作为学生特征提取网络对应的第二样本特征图中至少两个目标对象之间的类别关系(即第二类别关系)。需要说明的是,针对老师特征提取网络和学生特征提取网络,预测目标对象属于同一类别的概率值时使用的类别关系预测网络可以是同一个网络,也可以是两个不同的网络。
具体的,针对每一特征提取网络(即老师特征提取网络或学生特征提取网络),可以先从该特征提取网络对应的样本特征图中的提取各个目标对象所属特征区域,然后将各个目标对象所属特征区域输入到类别关系预测网络中,该类别关系预测网络可以基于预设函数(如归一化指数函数softmax)针对每一目标对象,结合其所属特征区域与其他各目标对象所属特征区域,预测该目标对象与其他各目标对象属于同一类别的概率,作为该目标对象与其他各目标对象之间的类别关系。针对各目标对象中的两两目标对象都执行完上述操作后,即可得到该特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。
S403,根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失。
S404,根据第一蒸馏损失,对学生特征提取网络进行训练。
本公开实施例的方案,先分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,再采用类别预测网络分别确定两种不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,进而基于两种不同特征提取网络对应的该类别关系,确定第一蒸馏损失,并基于第一蒸馏损失来训练学生特征提取网络,本方案采用类别关系预测网络预测不同目标对象属于同一类别的概率值作为不同目标对象之间的类别关系,提高了目标对象之间的类别关系确定的准确性,为后续基于该类别关系确定损失函数,精准训练学生特征提取网络提供了保障。
可选的,在本公开实施例的基础上,还包括:根据第一蒸馏损失,对类别关系预测网络进行训练。具体的,实施例在基于第一蒸馏损失训练学生特征提取网络的同时,还可以基于该第一蒸馏损失函数训练类别关系预测网络,即基于第一蒸馏损失联合训练学生特征提取网络和类别关系预测网络,不断优化学生特征提取网络和类别关系预测网络中的网络参数,直到达到训练停止条件,如训练次数达到预设次数,或者第一蒸馏损失收敛等。通过将学生特征提取网络与类别关系预测网络联合训练,进一步提高了模型训练的准确性。
图5是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例在上述实施例的基础上,进行了进一步的优化,如图5所示,本实施例提供的模型训练方法可以包括:
S501,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。
具体的,将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S502,根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失。
S503,根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失。
本实施例中样本特征图中不同像素点之间的像素关系可以表征样本特征图中两两像素点是否同属于前景区域或背景区域。
可选的,本实施例根据老师特征提取网络输出的样本特征图(即第一样本特征图),分析其中两两像素点是否同属于前景区域或背景区域,并根据分析结果,确定第一样本特征图中不同像素点之间的像素关系(即第一像素关系)。同理,根据学生特征提取网络输出的样本特征图(即第二样本特征图),确定该第二样本特征图中不同像素点之间的像素关系(即第二像素关系)。进而根据第一像素关系和第二像素关系,确定第二蒸馏损失。
具体的,本实施例可以直接将第一像素关系和第二像素关系输入到预先设定的损失函数中,得到第二蒸馏损失;其中,损失函数可以是交叉熵损失函数,或者平方损失函数(即L2损失函数)等。本实施例还可以对第一像素关系和第二像素关系进行进一步处理后在输入到预先设定的损失函数中,得到第二蒸馏损失等,对此不进行限定。
S504,根据第一蒸馏损失和第二蒸馏损失,对学生特征提取网络进行训练。
可选的,本实施例可以分别采用第一蒸馏损失和第二蒸馏损失对学生特征提取网络进行训练,不断优化学生特征提取网络中的网络参数。还可以是将第一蒸馏损失和第二蒸馏损失进行融合处理,如进行求和处理,或加权求和处理等,基于融合处理后的蒸馏损失来对学生特征提取网络进行训练,不断优化学生特征提取网络中的网络参数。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,进而根据两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,根据两种不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失,并同时基于第一蒸馏损失和第二蒸馏损失来训练学生特征提取网络。本方案引入样本特征图中不同像素点之间的像素关系确定第二蒸馏损失,该蒸馏损失主要用于训练学生特征提取网络进行特征提取时消除背景信息与前景信息之间的相互干扰,使得训练后的学生特征提取网络在提取图像特征时,能够更精准的区分出特征图中的背景区域和前景区域,进而提高图像特征提取的精准性。
图6是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例在上述实施例的基础上,进一步对如何确定不同特征提取网络对应的样本特征图中不同像素点之间的像素关系进行详细解释说明,如图6所示,本实施例提供的模型训练方法可以包括:
S601,将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图。
其中,本实施例的特征提取网络包括老师特征提取网络和学生特征提取网络。
具体的,将样本图像输入到老师特征提取网络和学生特征提取网络中,得到老师特征提取网络对应输出的样本特征图(即第一样本特征图)和学生特征提取网络对应输出的样本特征图(即第二样本特征图)。
S602,根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失。
S603,采用图神经网络确定不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为不同特征提取网络对应的样本特征图中不同像素点之间的像素关系。
其中,图神经网络可以是用于预测样本图像中不同像素点之间的特征相似度的神经网络。
可选的,由于前景像素点之间,或者背景像素点之间的特征相似度较高,背景像素点与前景像素点之间的特征相似度较低,所以本实施例可以是采用图神经网络确定老师特征提取网络对应的第一样本特征图中不同像素点之间的特征相似度,来表征老师特征提取网络对应的样本特征图中不同像素点之间的像素关系(即第一像素关系);采用图神经网络确定学生特征提取网络对应的第二样本特征图中不同像素点之间的特征相似度,来表征学生特征提取网络对应的样本特征图中不同像素点之间的像素关系(即第二像素关系)。需要说明的是,针对老师特征提取网络和学生特征提取网络,预测不同像素点之间的像素关系时使用的图神经网络可以是同一个网络,也可以是两个不同的网络。
具体的,针对每一特征提取网络(即老师特征提取网络或学生特征提取网络),可以将该特征提取网络对应的样本特征图输入到图神经网络中,该图神经网络可以针对样本特征图中的每一像素点,预测其与其他各像素点之间的特征相似度,例如,通过计算两像素点特征值间的余弦距离,作为两像素点之间的特征相似度。若两像素点之间的特征相似度高于预设数值,则说明这两个像素点之间的像素关系为:同属于前景区域或背景区域,否则说明这两个像素点之间的像素关系为:一个属于前景区域,一个属于背景区域。针对样本特征图中的两两像素点都执行完上述操作后,即可得到该特征提取网络对应的样本特征图中不同像素点之间的像素关系。
S604,根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失。
S605,根据第一蒸馏损失和第二蒸馏损失,对学生特征提取网络进行训练。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,进而根据两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,采用图神经网络确定两种不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为两种不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,根据两种不同特征提取网络对应的像素关系,确定第二蒸馏损失,并同时基于第一蒸馏损失和第二蒸馏损失来训练学生特征提取网络。本方案采用图神经网络预测样本图像中不同像素点之间的相似度来表征不同像素点之间的像素关系,提高了不同像素点之间的像素关系的准确性,为后续基于该像素关系确定损失函数,精准训练学生特征提取网络提供了保障。
可选的,在本公开实施例的基础上,还包括:根据第二蒸馏损失,对图神经网络进行训练。具体的,实施例在基于第一蒸馏损失和第二蒸馏损失训练学生特征提取网络的同时,还可以基于该第二蒸馏损失函数训练图神经网络,即基于第二蒸馏损失联合训练学生特征提取网络和图神经网络,不断优化学生特征提取网络和图神经网络中的网络参数,直到达到训练停止条件,如训练次数达到预设次数,或者第一蒸馏损失收敛等。通过将学生特征提取网络与图神经网络联合训练,进一步提高了模型训练的准确性。
图7是根据本公开实施例提供的一种模型训练的原理框图。如图7所示,本实施例可以先分别将样本图像输入到老师特征提取网络和学生特征提取网中,得到老师特征提取网络输出的第一样本特征图ft,以及学生特征提取网络输出的第二样本特征图fs;再将第一样本特征图ft输入到像素级(即pixel-wise)关系模块1中,该像素级关系模块1基于图神经网络(GCN),按照公式pt=GCN1(ft),预测第一像素关系pt;将第二样本特征图fs输入到像素级(即pixel-wise)关系模块2中,该像素级关系模块1基于图神经网络(GCN)中,按照公式ps=GCN2(fs),预测第二像素关系ps。
接着将第一样本特征图ft输入到物体级(instance-wise)关系模块1中,该物体级关系模块1提取出各目标对象所属区域后,基于类别关系预测网络按照公式Mti=∑jsoftmax(oi*oj)*oi,计算每一目标对象对应的对象关系表示值,进而得到第一样本特征图ft对应的第一对象关系表示Mt。其中,oi和oj分别属于第一样本特征图ft中的第i个目标对象和第j个目标对象所属特征值,softmax()函数用于预测oi和oj属于同一类别的概率;Mti为第一对象关系表示第i个目标对象对应的对象关系表示值。将第二样本特征图fs输入到物体级(instance-wise)关系模块2中,同理,得到第二对象关系表示Ms。
最后,根据第一对象关系表示Mt和第二对象关系表示Ms,计算第一蒸馏损失L2(Mt,Ms),基于第一像素关系pt和第二像素关系ps,计算第二蒸馏损失L2(pt,ps),并基于第一蒸馏损失和第二蒸馏损失(即loss=L2(Mt,Ms)+L2(pt,ps))对学生特征提取网络、物体级关系模块中的类别关系预测网络,以及像素级关系模块中的图神经网络模块进行训练。
优选的,采用本公开上述实施例的模型训练方法训练后的学生特征提取网络可以应用在目标检测场景中,具体的,若目标检测场景使用的是能够执行目标检测任务的检测模型,则本实施例训练好的学生特征提取网络可以属于检测模型中的网络。该检测模型中还至少包括:用于预测目标对象所属类别的分类网络,以及标注目标对象所在位置的回归网络等。
相应的,该学生特征提取网络可以通过上述实施例的方式进行训练,而检测模型中的其他网络,如分类网络和回归网络可以通过如下方式来训练:将样本图像输入至训练后的学生特征提取网络,得到目标特征图;根据目标特征图对检测模型中的其他网络进行训练。具体的,可以是将训练后的学生特征提取网络基于样本图像输出的目标特征图分别输入到其他网络,得到其他网络输出的结果,进而根据其他网络输出的结果和预先标注的监督数据计算损失函数,来对其他网络进行训练。例如,若其他网络为分类网络和回归网络,则可以将目标特征图分别输入到分类网络和回归网络,得到分类网络输出的目标对象的预测类别,以及回归网络输出的目标对象的预测位置,进而基于目标对象的预测类别和预测位置,以及目标对象的真实类别和真实位置(即监督数据),计算损失函数,来对分类网络和回归网络进行训练。
需要说明的是,现有的知识蒸馏技术,通常直接将老师特征提取网络输出的第一样本特征图作为监督数据来与学生特征提取网络输出的第二样本特征图计算蒸馏损失,以对学生特征提取网络进行训练。该训练方式训练后的学生特征提取网络在用于多任务检测时,如应用到包含分类任务和回归任务的检测模型时,准确性较低。而本方案训练后的学生特征提取网络在提取图像特征时,能够更好的体现目标对象之间的类别关系,以及不同像素点之间的像素关系,极大的提高了特征提取的精准性,从而使得该方式训练的学生特征提取网络更好的应用于多任务构成的目标检测模型中。
图8是根据本公开实施例提供的一种模型训练装置的结构示意图。本公开实施例适用于基于蒸馏技术进行模型训练的情况。尤其适用于基于蒸馏技术,训练目标检测场景中的特征提取网络的情况。该装置可以采用软件和/或硬件来实现,该装置可以实现本公开任意实施例的模型训练方法。如图8所示,该模型训练装置800包括:
特征提取模块801,用于将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图;其中,特征提取网络包括老师特征提取网络和学生特征提取网络;
第一损失确定模块802,用于根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;
网络训练模块803,用于根据第一蒸馏损失,对学生特征提取网络进行训练。
本公开实施例的方案,通过分别将样本图像输入到老师特征提取网络和学生特征提取网络,得到老师特征提取网络输出的样本特征图和学生特征提取网络输出的样本特征图,进而根据两种不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,并基于第一蒸馏损失来训练学生特征提取网络。本方案在基于知识蒸馏技术训练学生特征提取网络时,所用的蒸馏损失是基于不同特征提取网络对应的样本特征图中不同目标对象之间的类别关系确定的,使得基于该蒸馏损失训练后的学生特征提取网络在提取图像特征时,能够更好的体现目标对象之间的类别关系,提高了特征提取的精准性。此外,若将本实施例训练的学生特征提取网络应用到目标检测场景,能够更精准的完成目标检测任务。
进一步的,第一损失确定模块802,包括:
关系表示确定单元,用于根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示;
第一损失确定单元,用于根据不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。
进一步的,关系表示确定单元具体用于:
根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定不同特征提取网络对应的样本特征图中每一目标对象的目标关系;
根据不同特征提取网络对应的样本特征图中每一目标对象的特征值,以及该目标对象的目标关系,确定不同特征提取网络对应的对象关系表示。
进一步的,模型训练装置800,还包括:
类别关系确定模块,用于采用类别关系预测网络确定不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。
进一步的,网络训练模块803还用于:
根据第一蒸馏损失,对类别关系预测网络进行训练。
进一步的,模型训练装置800,还包括:
第二损失确定模块,用于根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失;
相应的,网络训练模块803用于:
根据第一蒸馏损失和第二蒸馏损失,对学生特征提取网络进行训练。
进一步的,模型训练装置800,还包括:
像素关系确定模块,用于采用图神经网络确定不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为不同特征提取网络对应的样本特征图中不同像素点之间的像素关系。
进一步的,网络训练模块803还用于:
根据第二蒸馏损失,对图神经网络进行训练。
进一步的,学生特征提取网络属于检测模型中的网络;
相应的,特征提取模块801,还用于将样本图像输入至训练后的学生特征提取网络,得到目标特征图;其中,所述训练后的学生特征提取网络采用本公开任一实施例所述的模型训练方法训练得到;
网络训练模块803,还用于根据目标特征图对检测模型中的其他网络进行训练;其中,其他网络至少包括分类网络和回归网络。
上述产品可执行本公开任意实施例所提供的方法,具备执行方法相应的功能模块和有益效果。
本公开的技术方案中,所涉及的任一样本图像以及相关特征等的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图9示出了可以用来实施本公开的实施例的示例电子设备900的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字助理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图9所示,设备900包括计算单元901,其可以根据存储在只读存储器(ROM)902中的计算机程序或者从存储单元908加载到随机访问存储器(RAM)903中的计算机程序,来执行各种适当的动作和处理。在RAM 903中,还可存储设备900操作所需的各种程序和数据。计算单元901、ROM 902以及RAM 903通过总线904彼此相连。输入/输出(I/O)接口905也连接至总线904。
设备900中的多个部件连接至I/O接口905,包括:输入单元906,例如键盘、鼠标等;输出单元907,例如各种类型的显示器、扬声器等;存储单元908,例如磁盘、光盘等;以及通信单元909,例如网卡、调制解调器、无线通信收发机等。通信单元909允许设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元901执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由ROM 902和/或通信单元909而被载入和/或安装到设备900上。当计算机程序加载到RAM 903并由计算单元901执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元901可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)、区块链网络和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务中,存在的管理难度大,业务扩展性弱的缺陷。服务器也可以为分布式系统的服务器,或者是结合了区块链的服务器。
人工智能是研究使计算机来模拟人的某些思维过程和智能行为(如学习、推理、思考、规划等)的学科,既有硬件层面的技术也有软件层面的技术。人工智能硬件技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理等技术;人工智能软件技术主要包括计算机视觉技术、语音识别技术、自然语言处理技术及机器学习/深度学习技术、大数据处理技术、知识图谱技术等几大方向。
云计算(cloud computing),指的是通过网络接入弹性可扩展的共享物理或虚拟资源池,资源可以包括服务器、操作系统、网络、软件、应用和存储设备等,并可以按需、自服务的方式对资源进行部署和管理的技术体系。通过云计算技术,可以为人工智能、区块链等技术应用、模型训练提供高效强大的数据处理能力。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (21)
1.一种模型训练方法,包括:
将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;
根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;
根据所述第一蒸馏损失,对所述学生特征提取网络进行训练。
2.根据权利要求1所述的方法,其中,所述根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,包括:
根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及所述至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示;
根据所述不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。
3.根据权利要求2所述的方法,其中,所述根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及所述至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示,包括:
根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定不同特征提取网络对应的样本特征图中每一目标对象的目标关系;
根据不同特征提取网络对应的样本特征图中每一目标对象的特征值,以及该目标对象的目标关系,确定不同特征提取网络对应的对象关系表示。
4.根据权利要求1所述的方法,还包括:
采用类别关系预测网络确定不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为所述不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。
5.根据权利要求4所述的方法,还包括:
根据所述第一蒸馏损失,对所述类别关系预测网络进行训练。
6.根据权利要求1所述的方法,还包括:
根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失;
相应的,根据所述第一蒸馏损失,对所述学生特征提取网络进行训练,包括:
根据所述第一蒸馏损失和所述第二蒸馏损失,对所述学生特征提取网络进行训练。
7.根据权利要求6所述的方法,还包括:
采用图神经网络确定不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为所述不同特征提取网络对应的样本特征图中不同像素点之间的像素关系。
8.根据权利要求7所述的方法,还包括:
根据所述第二蒸馏损失,对所述图神经网络进行训练。
9.根据权利要求1-8中任一项所述的方法,其中,所述学生特征提取网络属于检测模型中的网络;
相应的,所述方法还包括:
将所述样本图像输入至训练后的学生特征提取网络,得到目标特征图;其中,所述训练后的学生特征提取网络采用权利要求1-8中任一所述的模型训练方法训练得到;
根据所述目标特征图对所述检测模型中的其他网络进行训练;其中,所述其他网络至少包括分类网络和回归网络。
10.一种模型训练装置,包括:
特征提取模块,用于将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;
第一损失确定模块,用于根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;
网络训练模块,用于根据所述第一蒸馏损失,对所述学生特征提取网络进行训练。
11.根据权利要求10所述的装置,其中,所述第一损失确定模块,包括:
关系表示确定单元,用于根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及所述至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示;
第一损失确定单元,用于根据所述不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。
12.根据权利要求11所述的装置,其中,所述关系表示确定单元具体用于:
根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定不同特征提取网络对应的样本特征图中每一目标对象的目标关系;
根据不同特征提取网络对应的样本特征图中每一目标对象的特征值,以及该目标对象的目标关系,确定不同特征提取网络对应的对象关系表示。
13.根据权利要求10所述的装置,还包括:
类别关系确定模块,用于采用类别关系预测网络确定不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为所述不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。
14.根据权利要求13所述的装置,其中,所述网络训练模块还用于:
根据所述第一蒸馏损失,对所述类别关系预测网络进行训练。
15.根据权利要求10所述的装置,还包括:
第二损失确定模块,用于根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失;
相应的,所述网络训练模块用于:
根据所述第一蒸馏损失和所述第二蒸馏损失,对所述学生特征提取网络进行训练。
16.根据权利要求15所述的装置,还包括:
像素关系确定模块,用于采用图神经网络确定不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为所述不同特征提取网络对应的样本特征图中不同像素点之间的像素关系。
17.根据权利要求16所述的装置,其中,所述网络训练模块还用于:
根据所述第二蒸馏损失,对所述图神经网络进行训练。
18.根据权利要求10-17中任一项所述的装置,其中,所述学生特征提取网络属于检测模型中的网络;
相应的,所述特征提取模块,还用于将所述样本图像输入至训练后的学生特征提取网络,得到目标特征图;其中,所述训练后的学生特征提取网络采用权利要求10-17中任一所述的模型训练装置训练得到;
所述网络训练模块,还用于根据所述目标特征图对所述检测模型中的其他网络进行训练;其中,所述其他网络至少包括分类网络和回归网络。
19.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-9中任一项所述的模型训练方法。
20.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-9中任一项所述的模型训练方法。
21.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-9中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210082301.8A CN114494776A (zh) | 2022-01-24 | 2022-01-24 | 一种模型训练方法、装置、设备以及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210082301.8A CN114494776A (zh) | 2022-01-24 | 2022-01-24 | 一种模型训练方法、装置、设备以及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114494776A true CN114494776A (zh) | 2022-05-13 |
Family
ID=81474588
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210082301.8A Pending CN114494776A (zh) | 2022-01-24 | 2022-01-24 | 一种模型训练方法、装置、设备以及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114494776A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115578613A (zh) * | 2022-10-18 | 2023-01-06 | 北京百度网讯科技有限公司 | 目标再识别模型的训练方法和目标再识别方法 |
CN115879446A (zh) * | 2022-12-30 | 2023-03-31 | 北京百度网讯科技有限公司 | 文本处理方法、深度学习模型训练方法、装置以及设备 |
CN116563642A (zh) * | 2023-05-30 | 2023-08-08 | 智慧眼科技股份有限公司 | 图像分类模型可信训练及图像分类方法、装置、设备 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200302230A1 (en) * | 2019-03-21 | 2020-09-24 | International Business Machines Corporation | Method of incremental learning for object detection |
CN113255701A (zh) * | 2021-06-24 | 2021-08-13 | 军事科学院系统工程研究院网络信息研究所 | 一种基于绝对-相对学习架构的小样本学习方法和系统 |
CN113379718A (zh) * | 2021-06-28 | 2021-09-10 | 北京百度网讯科技有限公司 | 一种目标检测方法、装置、电子设备以及可读存储介质 |
CN113486957A (zh) * | 2021-07-07 | 2021-10-08 | 西安商汤智能科技有限公司 | 神经网络训练和图像处理方法及装置 |
CN113610126A (zh) * | 2021-07-23 | 2021-11-05 | 武汉工程大学 | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 |
-
2022
- 2022-01-24 CN CN202210082301.8A patent/CN114494776A/zh active Pending
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200302230A1 (en) * | 2019-03-21 | 2020-09-24 | International Business Machines Corporation | Method of incremental learning for object detection |
CN113255701A (zh) * | 2021-06-24 | 2021-08-13 | 军事科学院系统工程研究院网络信息研究所 | 一种基于绝对-相对学习架构的小样本学习方法和系统 |
CN113379718A (zh) * | 2021-06-28 | 2021-09-10 | 北京百度网讯科技有限公司 | 一种目标检测方法、装置、电子设备以及可读存储介质 |
CN113486957A (zh) * | 2021-07-07 | 2021-10-08 | 西安商汤智能科技有限公司 | 神经网络训练和图像处理方法及装置 |
CN113610126A (zh) * | 2021-07-23 | 2021-11-05 | 武汉工程大学 | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 |
Non-Patent Citations (2)
Title |
---|
JINGUO ZHU: "Complementary Relation Contrastive Distillation", 《CVPR 2021》, 31 December 2021 (2021-12-31) * |
赖叶静;郝珊锋;黄定江;: "深度神经网络模型压缩方法与进展", 华东师范大学学报(自然科学版), no. 05, 25 September 2020 (2020-09-25) * |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115578613A (zh) * | 2022-10-18 | 2023-01-06 | 北京百度网讯科技有限公司 | 目标再识别模型的训练方法和目标再识别方法 |
CN115578613B (zh) * | 2022-10-18 | 2024-03-08 | 北京百度网讯科技有限公司 | 目标再识别模型的训练方法和目标再识别方法 |
CN115879446A (zh) * | 2022-12-30 | 2023-03-31 | 北京百度网讯科技有限公司 | 文本处理方法、深度学习模型训练方法、装置以及设备 |
CN115879446B (zh) * | 2022-12-30 | 2024-01-12 | 北京百度网讯科技有限公司 | 文本处理方法、深度学习模型训练方法、装置以及设备 |
CN116563642A (zh) * | 2023-05-30 | 2023-08-08 | 智慧眼科技股份有限公司 | 图像分类模型可信训练及图像分类方法、装置、设备 |
CN116563642B (zh) * | 2023-05-30 | 2024-02-27 | 智慧眼科技股份有限公司 | 图像分类模型可信训练及图像分类方法、装置、设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112801164A (zh) | 目标检测模型的训练方法、装置、设备及存储介质 | |
CN113657465A (zh) | 预训练模型的生成方法、装置、电子设备和存储介质 | |
CN114494776A (zh) | 一种模型训练方法、装置、设备以及存储介质 | |
CN113392253B (zh) | 视觉问答模型训练及视觉问答方法、装置、设备及介质 | |
CN113642431A (zh) | 目标检测模型的训练方法及装置、电子设备和存储介质 | |
CN113361578A (zh) | 图像处理模型的训练方法、装置、电子设备及存储介质 | |
US20230186607A1 (en) | Multi-task identification method, training method, electronic device, and storage medium | |
CN113705628B (zh) | 预训练模型的确定方法、装置、电子设备以及存储介质 | |
CN114494784A (zh) | 深度学习模型的训练方法、图像处理方法和对象识别方法 | |
CN113947188A (zh) | 目标检测网络的训练方法和车辆检测方法 | |
CN114187459A (zh) | 目标检测模型的训练方法、装置、电子设备以及存储介质 | |
CN114648676A (zh) | 点云处理模型的训练和点云实例分割方法及装置 | |
CN112560985A (zh) | 神经网络的搜索方法、装置及电子设备 | |
CN114715145B (zh) | 一种轨迹预测方法、装置、设备及自动驾驶车辆 | |
CN112784732A (zh) | 地物类型变化的识别、模型训练方法、装置、设备及介质 | |
CN113592932A (zh) | 深度补全网络的训练方法、装置、电子设备及存储介质 | |
CN114417118A (zh) | 一种异常数据处理方法、装置、设备以及存储介质 | |
CN113961765B (zh) | 基于神经网络模型的搜索方法、装置、设备和介质 | |
CN114581732A (zh) | 一种图像处理及模型训练方法、装置、设备和存储介质 | |
CN114547252A (zh) | 文本识别方法、装置、电子设备和介质 | |
CN112560480B (zh) | 任务社区发现方法、装置、设备和存储介质 | |
CN115273148A (zh) | 行人重识别模型训练方法、装置、电子设备及存储介质 | |
CN114912541A (zh) | 分类方法、装置、电子设备和存储介质 | |
CN114330576A (zh) | 模型处理方法、装置、图像识别方法及装置 | |
CN114417029A (zh) | 模型训练方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |