CN115240250A - 模型训练方法、装置、计算机设备及可读存储介质 - Google Patents
模型训练方法、装置、计算机设备及可读存储介质 Download PDFInfo
- Publication number
- CN115240250A CN115240250A CN202210805013.0A CN202210805013A CN115240250A CN 115240250 A CN115240250 A CN 115240250A CN 202210805013 A CN202210805013 A CN 202210805013A CN 115240250 A CN115240250 A CN 115240250A
- Authority
- CN
- China
- Prior art keywords
- network
- loss function
- picture
- characteristic matrix
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 85
- 238000012549 training Methods 0.000 title claims abstract description 52
- 239000011159 matrix material Substances 0.000 claims abstract description 105
- 238000012545 processing Methods 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 99
- 238000004590 computer program Methods 0.000 claims description 9
- 238000004821 distillation Methods 0.000 claims description 7
- 238000005457 optimization Methods 0.000 claims description 2
- 238000013140 knowledge distillation Methods 0.000 abstract description 13
- 238000000605 extraction Methods 0.000 description 10
- 238000004364 calculation method Methods 0.000 description 5
- 238000010586 diagram Methods 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 239000011521 glass Substances 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 3
- 230000001815 facial effect Effects 0.000 description 3
- 238000003062 neural network model Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000009795 derivation Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 210000004709 eyebrow Anatomy 0.000 description 1
- 230000005484 gravity Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000010079 rubber tapping Methods 0.000 description 1
- 230000005236 sound signal Effects 0.000 description 1
- 238000010897 surface acoustic wave method Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
- G06V40/16—Human faces, e.g. facial parts, sketches or expressions
- G06V40/168—Feature extraction; Face representation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
- G06V10/443—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components by matching or filtering
- G06V10/449—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters
- G06V10/451—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters with interaction between the filter responses, e.g. cortical complex cells
- G06V10/454—Integrating the filters into a hierarchical structure, e.g. convolutional neural networks [CNN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
- G06V40/16—Human faces, e.g. facial parts, sketches or expressions
- G06V40/172—Classification, e.g. identification
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Health & Medical Sciences (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Oral & Maxillofacial Surgery (AREA)
- Software Systems (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Molecular Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Human Computer Interaction (AREA)
- Biomedical Technology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Biodiversity & Conservation Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例提供了一种模型训练方法、装置、计算机设备及可读存储介质,利用具备多任务特征处理功能的教师网络训练学生网络,所述模型训练方法包括:通过预设的教师网络和学生网络分别处理样本图片,以得到教师网络的中间层输出的第一图片特征矩阵,以及,学生网络的中间层输出的第二图片特征矩阵;使用第二图片特征矩阵拟合所述第一图片特征矩阵,构建用于优化学生网络模型参数的第一损失函数。本申请通过使用教师网络的图片特征矩阵拟合学生网络的图片特征矩阵的方式,能够有效节省多任务网络知识蒸馏的时间。
Description
技术领域
本申请涉及深度学习领域,尤其涉及一种模型训练方法、装置、计算机设备及可读存储介质。
背景技术
在工业部署中因为计算力等问题,需要用到小模型骨干做回归和分类任务。小模型的表现效果往往很差,知识蒸馏是一种大模型帮助小模型辅助提升准确率的方法,被广泛的用于模型压缩和迁移学习当中。传统的知识蒸馏都是对于单分类任务而言,即将小模型中分类任务中的one-hot编码改变成作为教师网络训练出来的soft-target编码。但对于多任务网络,需要用一个主干网络来预测多个任务,任务中还包含着分类任务和回归任务。如果用传统知识蒸馏则要同时逐个训练多个教师网络。对于回归任务因为没有one-hot编码无法进行知识蒸馏。
可见,对于多任务网络要进行知识蒸馏时,要训练多个教师网络进行知识蒸馏存在浪费时间的问题。
发明内容
为了解决上述技术问题,本申请实施例提供了一种模型训练方法、装置、计算机设备及可读存储介质。
第一方面,本申请实施例提供了一种模型训练方法,利用具备多任务特征处理功能的教师网络训练学生网络,所述方法包括:
利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
利用所述第一损失函数优化所述学生网络的模型参数。
根据本申请实施例的一种具体实施方式,所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述方法还包括:
获取所述教师网络的分类层输出的第三图片特征矩阵,以及,所述学生网络的分类层输出的第四图片特征矩阵;
利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数;
所述利用所述第一损失函数优化所述学生网络的模型参数的步骤,包括:
利用所述第一损失函数和所述第二损失函数优化所述学生网络的模型参数。
根据本申请实施例的一种具体实施方式,多任务的数量为N,多任务包括分类任务和回归任务;
所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述方法还包括:
获取所述教师网络针对N个任务输出的教师结果真值,以及,所述学生网络针对N个任务输出的学生结果真值;
利用各任务对应的学生结果真值拟合对应的教师结果真值,构建对应各任务的第三损失函数;
所述利用所述第一损失函数和所述第二损失函数优化所述学生网络的模型参数的步骤,包括:
利用所述第一损失函数、所述第二损失函数和全部任务对应的第三损失函数优化所述学生网络的模型参数。
根据本申请实施例的一种具体实施方式,所述获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵的步骤之前,所述方法还包括:
使用注意力特征约束操作提取所述样本图片的注意力焦点区域的特征。
根据本申请实施例的一种具体实施方式,注意力特征约束操作采用使用CA-attention block技术。
根据本申请实施例的一种具体实施方式,所述利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数的步骤,包括:
采用MSEloss对所述第二图片特征矩阵和所述第一图片特征矩阵进行特征蒸馏,构建所述第一损失函数。
根据本申请实施例的一种具体实施方式,所述利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数的步骤,包括:
采用SmoothL1loss对所述第四图片特征矩阵和所述第三图片特征矩阵进行特征蒸馏,构建所述第二损失函数。
第二方面,本申请实施例提供了一种模型训练装置,利用具备多任务特征处理功能的教师网络训练学生网络,所述装置包括:
获取模块,用于利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
拟合模块,用于利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
优化模块,用于利用所述第一损失函数优化所述学生网络的模型参数。
第三方面,本申请实施例提供了一种计算机设备,所述计算机设备包括存储器以及处理器,所述存储器用于存储计算机程序,所述计算机程序在所述处理器运行时执行第一方面及第一方面任一实施方式提供的模型训练方法。
第四方面,本申请实施例提供了一种计算机可读存储介质,其存储有计算机程序,所述计算机程序在处理器上运行时执行第一方面及第一方面任一实施方式提供的模型训练方法。
上述本申请提供的模型训练方法、装置、计算机设备及可读存储介质,利用具备多任务特征处理功能的教师网络训练学生网络,所述模型训练方法包括:通过预设的教师网络和学生网络分别处理样本图片,以得到教师网络的中间层输出的第一图片特征矩阵,以及,学生网络的中间层输出的第二图片特征矩阵;使用第二图片特征矩阵拟合所述第一图片特征矩阵,构建用于优化学生网络模型参数的第一损失函数。本申请通过使用教师网络的图片特征矩阵拟合学生网络的图片特征矩阵的方式,能够有效节省多任务网络知识蒸馏的时间。
附图说明
为了更清楚地说明本申请的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对本申请保护范围的限定。在各个附图中,类似的构成部分采用类似的编号。
图1示出了本申请实施例提供的一种模型训练方法的方法流程示意图之一;
图2示出了本申请实施例提供的一种模型训练方法的方法流程示意图之二;
图3示出了本申请实施例提供的一种模型训练方法的方法流程示意图之三;
图4示出了本申请实施例提供的一种模型训练装置的装置模块示意图;
图5示出了本申请实施例提供的一种计算机设备的结构连接示意图。
具体实施方式
下面将结合本申请实施例中附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。
通常在此处附图中描述和示出的本申请实施例的组件可以以各种不同的配置来布置和设计。因此,以下对在附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的范围,而是仅仅表示本申请的选定实施例。基于本申请的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
在下文中,可在本申请的各种实施例中使用的术语“包括”、“具有”及其同源词仅意在表示特定特征、数字、步骤、操作、元件、组件或前述项的组合,并且不应被理解为首先排除一个或更多个其它特征、数字、步骤、操作、元件、组件或前述项的组合的存在或增加一个或更多个特征、数字、步骤、操作、元件、组件或前述项的组合的可能性。
此外,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
除非另有限定,否则在这里使用的所有术语(包括技术术语和科学术语)具有与本申请的各种实施例所属领域普通技术人员通常理解的含义相同的含义。所述术语(诸如在一般使用的词典中限定的术语)将被解释为具有与在相关技术领域中的语境含义相同的含义并且将不被解释为具有理想化的含义或过于正式的含义,除非在本申请的各种实施例中被清楚地限定。
实施例1
参考图1,为本公开实施例提供的一种模型训练方法的方法流程示意图之一,本实施例提供的模型训练方法,利用具备多任务特征处理功能的教师网络训练学生网络,如图1所示,所述模型训练方法包括:
步骤S101,利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
步骤S104,利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
步骤S107,利用所述第一损失函数优化所述学生网络的模型参数。
本实施例提出的模型训练方法,是一种在教师网络和学生网络之间进行多任务特征知识蒸馏的神经网络模型训练方法。在本实施例的模型训练方法中,可以在多任务网络之间应用知识蒸馏(Knowledge Distillation)方法来训练学生网络的神经网络模型。
多任务网络是一种具备多个待检测任务的深度学习卷积神经网络(Convolutional Neural Networks,简称CNN)。对于多任务网络,任务中往往包括分类任务和回归任务,其中,分类任务是一种输出变量为有限个离散变量的预测任务,回归任务是一种输入变量和输出变量均为连续变量的预测任务。
多任务网络中还包括一个预设的特征提取模型,所述特征提取模型识别样本图片后,能够根据预先设置好的特征提取条件提取预设数量的任务。具体的,所述特征提取模型的结构可以根据实际应用场景进行自适应替换。例如,当所述多任务网络被应用于人脸关键点识别时,能够自动提取人脸部的鼻子、眼睛、嘴巴和眉毛等预测任务。
本实施例中的教师网络和学生网络均为一种多任务网络,能够根据样本图片获取多个任务,并根据训练好的神经网络模型对任务进行预测。
提前训练好具备多层神经网络的教师网络和具备多层神经网络的学生网络,用户将样本图片输入至所述教师网络和所述学生网络后,会得到教师网络各中间层输出的第一图片特征矩阵,以及学生网络各中间层输出的第二图片特征矩阵。
在实际应用过程中,教师网络和学生网络的神经网络结构均包括特征提取层、中间层以及分类层,其中,所述特征提取层作为第一层,用于获取所述样本图片,并将样本图片转换为图片特征矩阵输出至所述中间层;所述中间层对于第一层的处理结果进行进一步处理,以得到更贴合任务特征的图片特征矩阵;所述分类层作为最后一层与全连接层(Fully Connect Layer)连接,所述分类层用于输出最终进行预测的图片特征矩阵。
根据本申请实施例的一种具体实施方式,所述获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵的步骤之前,所述方法还包括:
使用注意力特征约束操作提取所述样本图片的注意力焦点区域的特征。
在具体实施例中,学生网络和教师网络的非线性能力并不相同,如果直接使学生网络的中间层输出的第二图片特征矩阵拟合教师网络的中间层输出的第一图片特征矩阵,可能会使得学生网络在学习过程中鲁棒性不够。
本实施例还采用注意力特征约束方法来获取对应样本图片的注意力特征图。具体的,所述注意力特征图,是对样本图像中的注意力交点区域投入更多注意力资源得到的特征矩阵。所述注意力特征图能够获取更多需要关注的目标的细节信息,从而抑制其它无用信息。
具体的,本实施例中的注意力特征约束操作可以采用神经网络训练中任意一种注意力约束技术。
根据本申请实施例的一种具体实施方式,注意力特征约束操作采用使用CA-attention block技术。
在训练学生网络时,本实施例采用所述CA-attention block技术对教师网络中间层输出的第一图片特征矩阵进行蒸馏,从而能够让学生网络的中间层在训练过程中更加灵活的模仿教师网络中需要注意的特征。
根据本申请实施例的一种具体实施方式,所述利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数的步骤,包括:
采用MSEloss对所述第二图片特征矩阵和所述第一图片特征矩阵进行特征蒸馏,构建所述第一损失函数。
具体的,使用学生网络的第二图片特征矩阵对教师网络的第一图片矩阵进行知识蒸馏,以实现利用所述第二图片特征矩阵拟合所述第一图片特征矩阵。
并基于均方损失函数(MSEloss)的计算规则处理拟合后的所述第一图片矩阵和所述第二图片矩阵,从而得到对应的第一损失函数Lattention map。
具体的,第一损失函数为对应注意力特征约束的损失函数。
在得到所述第一损失函数后,可以将所述第一损失函数应用在学生网络的模型训练过程中,通过反向链式求导的方式进行反向传播,以利用教师网络中的模型参数优化所述学生网络的模型参数。
如图2所示,根据本申请实施例的一种具体实施方式,所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述方法还包括:
步骤S102,获取所述教师网络的分类层输出的第三图片特征矩阵,以及,所述学生网络的分类层输出的第四图片特征矩阵;
步骤S105,利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数;
步骤S108,利用所述第一损失函数和所述第二损失函数优化所述学生网络的模型参数。
在具体实施例中,当所述教师网络和所述学生网络的分类层接收到中间层发送过来的第一图片特征矩阵和第二图片特征矩阵后,会处理得到对应的第三图片特征矩阵和第四图片特征矩阵。
让学生网络的分类层输出的第四图片特征矩阵尽可能的拟合教师网络的分类层输出的第三图片特征矩阵,能够最大程度的使学生网络蒸馏得到所述教师网络中的知识。
根据本申请实施例的一种具体实施方式,所述利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数的步骤,包括:
采用SmoothL1loss对所述第四图片特征矩阵和所述第三图片特征矩阵进行特征蒸馏,构建所述第二损失函数。
在具体实施例中,以教师网络分类层输出的第三图片特征矩阵作为知识,让学生网络分类层输出的第四图片特征矩阵拟合所述第三图片特征矩阵,能够有效提升学生网络的预测效果。
采用Smooth L1损失函数(SmoothL1loss)的计算规则处理拟合后的第三图片特征矩阵和第四图片特征矩阵,可以得到对应特征约束的第二损失函数Lfeature map。
具体的,Smooth L1损失函数是一种更为平滑的损失函数,能够让损失函数的离群点更加鲁棒。
需知的,第二损失函数和第一损失函数也可以采用其它损失函数计算规则进行计算,损失函数的计算规则根据实际应用场景进行决定。
在得到所述第二损失函数后,将所述第一损失函数和所述第二损失函数均应用于学生网络的训练过程中,具体应用过程可以参考上述实施例中第一损失函数的应用过程,此处不再赘述。
参考图3,根据本申请实施例的一种具体实施方式,多任务的数量为N,多任务包括分类任务和回归任务;
所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述模型训练方法还包括:
步骤S103,获取所述教师网络针对N个任务输出的教师结果真值,以及,所述学生网络针对N个任务输出的学生结果真值;
步骤S106,利用各任务对应的学生结果真值拟合对应的教师结果真值,构建对应各任务的第三损失函数;
步骤S109,利用所述第一损失函数、所述第二损失函数和全部任务对应的第三损失函数优化所述学生网络的模型参数。
在具体实施例中,可以通过预设的特征提取网络对样本图片进行识别,进而得到N个不同的任务。
在实际应用中,教师网络和学生网络的特征提取网络可以根据实际应用需求进行自适应构建,以提取训练模型所需的各项任务。
具体的,任务的数量N与特征提取网络的结构和样本图片的类型相关。
举例来说,在进行人脸关键点检测(A Pratical Facial Landmark Detector,简称PFLD)时,对任一样本图片进行识别后,能够得到包括人脸姿态、关键点、人脸质量、年龄、性别、微笑、眼睛、口罩和颜值,9个任务。
基于输入变量与输出变量均为连续变量的预测问题是回归问题,输出变量为有限个离散变量的预测问题为分类问题的划分规则,对上述9个任务进行划分,能够将人脸姿态、关键点、微笑、人脸质量、年龄和颜值划分为回归问题。将性别、眼镜和口罩划分为分类任务。
使用教师网络和学生网络分别对上述9个任务进行预测,可以得到对应9个任务的9个教师结果真值和9个学生结果真值。具体的,教师结果真值为教师网络对任务的预测值,学生结果真值为学生网络对任务的预测值。
在一种实施例中,采用PFLD 1x作为教师网络,采用PFLD 0.5x作为学生网络,对所述教师网络和所述学生网络进行特征知识蒸馏,能够得到对应各任务的预测值。
教师网络输出对应各任务的预测值包括:人脸姿态6.22、关键点19.58、年龄3.5、微笑4.7、人脸质量9.0、颜值4.0、性别95.74%、眼镜97.28%和口罩99.87%。
未使用本实施例提供的模型训练方法得到的学生网络输出各对应各任务预测值包括:人脸姿态6.56、关键点21.56、年龄3.77、微笑5.05、人脸质量9.4、颜值4.31、性别94.56%、眼镜97.16%和口罩99.79%。
使用本实施例提供的模型训练方法得到的学生网络输出对应各任务的预测值包括:人脸姿态6.22、关键点20.53、年龄3.67、微笑4.73、人脸质量9.02、颜值4.12、性别95.65%、眼镜97.26%和口罩99.82%。
使用本实施例提供的模型训练方法得到的学生网络处理各任务得到的预测值,与使用教师网络处理各任务得到的预测值要更加接近。也就是说,使用本实施例提供的模型训练方法能够得到一个识别更加准确的学生网络。
本实施例提供的模型训练方法,在训练得到多个图片特征矩阵后,将学生网络各中间层的特征与教师网络各中间层的特征进行拟合,并根据拟合后的结果计算第一损失函数;将学生网络最后一层的特征与教师网络最后一层的特征进行拟合,并根据拟合后的结果计算第二损失函数;对于教师网络来说,在输入样本图片后,会根据预设的特征提取网络对图片中的任务进行分类提取,并将各任务划分为回归任务和分类任务。对各回归任务和各分类任务进行损失函数计算,以得到对应的第三损失函数。
在得到所述第一损失函数、所述第二损失函数和所述第三损失函数后,将三种损失函数输入至所述学生网络的训练过程中,所述学生网络根据这三种损失函数进行反向传播后,将基于教师网络得到的各类损失函数添加至学生网络的训练过程,从而通过知识蒸馏的方法,优化了所述学生网络的组成,有效提升了学生网络的预测精度。
实施例2
参考图4,为本公开实施例提供了一种模型训练装置400的装置模块示意图,利用具备多任务特征处理功能的教师网络训练学生网络,如图4所示,所述模型训练装置400包括:
获取模块401,用于利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
拟合模块402,用于利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
优化模块403,用于利用所述第一损失函数优化所述学生网络的模型参数。
本实施例提供模型训练装置400的具体实施过程可以参考上述方法实施例1中的具体实施过程,为避免重复,在此不再赘述。
实施例3
此外,本公开实施例提供了一种计算机设备,所述计算机设备包括存储器以及处理器,所述存储器存储有计算机程序,所述计算机程序在所述处理器上运行时执行上述方法实施例1所提供的模型训练方法。
具体的,如图5所示,本实施例提供的计算机设备500包括:
射频单元501、网络模块502、音频输出单元503、输入单元504、传感器505、显示单元506、用户输入单元507、接口单元508、存储器509、处理器510、以及电源511等部件。本领域技术人员可以理解,图5中示出的计算机设备结构并不构成对计算机设备的限定,计算机设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。在本申请实施例中,计算机设备包括但不限于手机、平板电脑、笔记本电脑、掌上电脑、车载终端、可穿戴设备、以及计步器等。
应理解的是,本申请实施例中,射频单元501可用于收发信息或通话过程中,信号的接收和发送,具体的,将来自基站的下行数据接收后,给处理器510处理;另外,将上行的数据发送给基站。通常,射频单元501包括但不限于天线、至少一个放大器、收发信机、耦合器、低噪声放大器、双工器等。此外,射频单元501还可以通过无线通信系统与网络和其他设备通信。
计算机设备通过网络模块502为用户提供了无线的宽带互联网访问,如帮助用户收发电子邮件、浏览网页和访问流式媒体等。
音频输出单元503可以将射频单元501或网络模块502接收的或者在存储器509中存储的音频数据转换成音频信号并且输出为声音。而且,音频输出单元503还可以提供与计算机设备500执行的特定功能相关的音频输出(例如,呼叫信号接收声音、消息接收声音等等)。音频输出单元503包括扬声器、蜂鸣器以及受话器等。
输入单元504用于接收音频或视频信号。输入单元504可以包括图形处理器(Graphics Processing Unit,GPU)5041和麦克风5042,图形处理器5041对在视频捕获模式或图像捕获模式中由图像捕获终端(如摄像头)获得的静态图片或视频的图像数据进行处理。处理后的图像帧可以视频播放在显示单元506上。经图形处理器5041处理后的图像帧可以存储在存储器509(或其它存储介质)中或者经由射频单元501或网络模块502进行发送。麦克风5042可以接收声音,并且能够将这样的声音处理为音频数据。处理后的音频数据可以在电话通话模式的情况下转换为可经由射频单元501发送到移动通信基站的格式输出。
计算机设备500还包括至少一种传感器505,比如光传感器、运动传感器以及其他传感器。具体地,光传感器包括环境光传感器及接近传感器,其中,环境光传感器可根据环境光线的明暗来调节显示面板5061的亮度,接近传感器可在计算机设备500移动到耳边时,关闭显示面板5061和/或背光。作为运动传感器的一种,加速计传感器可检测各个方向上(一般为三轴)加速度的大小,静止时可检测出重力的大小及方向,可用于识别计算机设备姿态(比如横竖屏切换、相关游戏、磁力计姿态校准)、振动识别相关功能(比如计步器、敲击)等;传感器505还可以包括指纹传感器、压力传感器、虹膜传感器、分子传感器、陀螺仪、气压计、湿度计、温度计、红外线传感器等,在此不再赘述。
显示单元506用于视频播放由用户输入的信息或提供给用户的信息。显示单元506可包括显示面板5061,可以采用液晶视频播放器(Liquid Crystal Display,LCD)、有机发光二极管(Organic Light-Emitting Diode,OLED)等形式来配置显示面板5061。
用户输入单元507可用于接收输入的数字或字符信息,以及产生与计算机设备的用户设置以及功能控制有关的键信号输入。具体地,用户输入单元507包括触控面板5071以及其他输入设备5072。触控面板5071,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板5071上或在触控面板5071附近的操作)。触控面板5071可包括触摸检测计算机设备和触摸控制器两个部分。其中,触摸检测计算机设备检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测计算机设备上接收触摸信息,并将它转换成触点坐标,再送给处理器510,接收处理器510发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板5071。除了触控面板5071,用户输入单元507还可以包括其他输入设备5072。具体地,其他输入设备5072可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆,在此不再赘述。
进一步的,触控面板5071可覆盖在显示面板5061上,当触控面板5071检测到在其上或附近的触摸操作后,传送给处理器510以确定触摸事件的类型,随后处理器510根据触摸事件的类型在显示面板5061上提供相应的视觉输出。虽然在图5中,触控面板5071与显示面板5061是作为两个独立的部件来实现计算机设备的输入和输出功能,但是在某些实施例中,可以将触控面板5071与显示面板5061集成而实现计算机设备的输入和输出功能,具体此处不做限定。
接口单元508为外部计算机设备与计算机设备500连接的接口。例如,外部计算机设备可以包括有线或无线头戴式耳机端口、外部电源(或电池充电器)端口、有线或无线数据端口、存储卡端口、用于连接具有识别模块的计算机设备的端口、音频输入/输出(I/O)端口、视频I/O端口、耳机端口等等。接口单元508可以用于接收来自外部计算机设备的输入(例如,数据信息、电力等等)并且将接收到的输入传输到计算机设备500内的一个或多个元件或者可以用于在计算机设备500和外部计算机设备之间传输数据。
存储器509可用于存储软件程序以及各种数据。存储器509可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据手机的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器509可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
处理器510是计算机设备的控制中心,利用各种接口和线路连接整个计算机设备的各个部分,通过运行或执行存储在存储器509内的软件程序和/或模块,以及调用存储在存储器509内的数据,执行计算机设备的各种功能和处理数据,从而对计算机设备进行整体监控。处理器510可包括一个或多个处理单元;优选的,处理器510可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器510中。
计算机设备500还可以包括给各个部件供电的电源511(比如电池),优选的,电源511可以通过电源管理系统与处理器510逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。
另外,计算机设备500包括一些未示出的功能模块,在此不再赘述。
实施例4
本申请还提供一种计算机可读存储介质,所述计算机可读存储介质上存储计算机程序,所述计算机程序在处理器上运行时执行实施例1提供的模型训练方法。
在本实施例中,计算机可读存储介质可以为只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等。
本实施例提供的计算机可读存储介质可以实施例1所示的模型训练方法,为避免重复,在此不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者终端不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者终端所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者终端中还存在另外的相同要素。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。
上面结合附图对本申请的实施例进行了描述,但是本申请并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本申请的启示下,在不脱离本申请宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本申请的保护之内。
Claims (10)
1.一种模型训练方法,其特征在于,利用具备多任务特征处理功能的教师网络训练学生网络,所述方法包括:
利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
利用所述第一损失函数优化所述学生网络的模型参数。
2.根据权利要求1所述的方法,其特征在于,所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述方法还包括:
获取所述教师网络的分类层输出的第三图片特征矩阵,以及,所述学生网络的分类层输出的第四图片特征矩阵;
利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数;
所述利用所述第一损失函数优化所述学生网络的模型参数的步骤,包括:
利用所述第一损失函数和所述第二损失函数优化所述学生网络的模型参数。
3.根据权利要求2所述的方法,其特征在于,多任务的数量为N,多任务包括分类任务和回归任务;
所述利用样本图片分别输入教师网络和学生网络的步骤之后,所述方法还包括:
获取所述教师网络针对N个任务输出的教师结果真值,以及,所述学生网络针对N个任务输出的学生结果真值;
利用各任务对应的学生结果真值拟合对应的教师结果真值,构建对应各任务的第三损失函数;
所述利用所述第一损失函数和所述第二损失函数优化所述学生网络的模型参数的步骤,包括:
利用所述第一损失函数、所述第二损失函数和全部任务对应的第三损失函数优化所述学生网络的模型参数。
4.根据权利要求1所述的方法,其特征在于,所述获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵的步骤之前,所述方法还包括:
使用注意力特征约束操作提取所述样本图片的注意力焦点区域的特征。
5.根据权利要求4所述的方法,其特征在于,注意力特征约束操作采用使用CA-attention block技术。
6.根据权利要求3所述的方法,其特征在于,所述利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数的步骤,包括:
采用MSEloss对所述第二图片特征矩阵和所述第一图片特征矩阵进行特征蒸馏,构建所述第一损失函数。
7.根据权利要求3所述的方法,其特征在于,所述利用所述第四图片特征矩阵拟合所述第三图片特征矩阵,以构建第二损失函数的步骤,包括:
采用SmoothL1loss对所述第四图片特征矩阵和所述第三图片特征矩阵进行特征蒸馏,构建所述第二损失函数。
8.一种模型训练装置,其特征在于,利用具备多任务特征处理功能的教师网络训练学生网络,所述装置包括:
获取模块,用于利用样本图片分别输入教师网络和学生网络,获取所述教师网络的中间层输出的第一图片特征矩阵,以及,所述学生网络的中间层输出的第二图片特征矩阵;
拟合模块,用于利用所述第二图片特征矩阵拟合所述第一图片特征矩阵,以构建第一损失函数;
优化模块,用于利用所述第一损失函数优化所述学生网络的模型参数。
9.一种计算机设备,其特征在于,所述计算机设备包括存储器以及处理器,所述存储器存储有计算机程序,所述计算机程序在所述处理器运行时执行权利要求1至7中任一项所述的模型训练方法。
10.一种计算机可读存储介质,其特征在于,其存储有计算机程序,所述计算机程序在处理器上运行时执行权利要求1至7中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210805013.0A CN115240250A (zh) | 2022-07-08 | 2022-07-08 | 模型训练方法、装置、计算机设备及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210805013.0A CN115240250A (zh) | 2022-07-08 | 2022-07-08 | 模型训练方法、装置、计算机设备及可读存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115240250A true CN115240250A (zh) | 2022-10-25 |
Family
ID=83672257
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210805013.0A Pending CN115240250A (zh) | 2022-07-08 | 2022-07-08 | 模型训练方法、装置、计算机设备及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115240250A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117351299A (zh) * | 2023-09-13 | 2024-01-05 | 北京百度网讯科技有限公司 | 图像生成及模型训练方法、装置、设备和存储介质 |
-
2022
- 2022-07-08 CN CN202210805013.0A patent/CN115240250A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117351299A (zh) * | 2023-09-13 | 2024-01-05 | 北京百度网讯科技有限公司 | 图像生成及模型训练方法、装置、设备和存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109151180B (zh) | 一种对象识别方法及移动终端 | |
CN108184050B (zh) | 一种拍照方法、移动终端 | |
CN108427873B (zh) | 一种生物特征识别方法及移动终端 | |
CN109240577B (zh) | 一种截屏方法及终端 | |
CN111401463B (zh) | 检测结果输出的方法、电子设备及介质 | |
CN108984066B (zh) | 一种应用程序图标显示方法及移动终端 | |
CN107749046B (zh) | 一种图像处理方法及移动终端 | |
CN109495616B (zh) | 一种拍照方法及终端设备 | |
CN108257104B (zh) | 一种图像处理方法及移动终端 | |
CN111522613B (zh) | 截屏方法及电子设备 | |
CN110505660B (zh) | 一种网络速率调整方法及终端设备 | |
CN109858447B (zh) | 一种信息处理方法及终端 | |
CN109286726B (zh) | 一种内容显示方法及终端设备 | |
CN108628534B (zh) | 一种字符展示方法及移动终端 | |
CN112464831B (zh) | 视频分类方法、视频分类模型的训练方法及相关设备 | |
CN108536513B (zh) | 一种图片显示方向调整方法及移动终端 | |
CN107967086B (zh) | 一种移动终端的图标排列方法及装置、移动终端 | |
CN115240250A (zh) | 模型训练方法、装置、计算机设备及可读存储介质 | |
CN109819331B (zh) | 一种视频通话方法、装置、移动终端 | |
CN111753047B (zh) | 一种文本处理方法及装置 | |
CN111145083B (zh) | 一种图像处理方法、电子设备及计算机可读存储介质 | |
CN110865859B (zh) | 图片显示方法、装置、电子设备及介质 | |
EP3846426B1 (en) | Speech processing method and mobile terminal | |
CN109358792B (zh) | 一种显示对象的选取方法及终端 | |
CN109829167B (zh) | 一种分词处理方法和移动终端 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |