CN113222123A - 模型训练方法、装置、设备及计算机存储介质 - Google Patents
模型训练方法、装置、设备及计算机存储介质 Download PDFInfo
- Publication number
- CN113222123A CN113222123A CN202110660998.8A CN202110660998A CN113222123A CN 113222123 A CN113222123 A CN 113222123A CN 202110660998 A CN202110660998 A CN 202110660998A CN 113222123 A CN113222123 A CN 113222123A
- Authority
- CN
- China
- Prior art keywords
- model
- teacher
- student
- loss value
- teacher model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 164
- 238000000034 method Methods 0.000 title claims abstract description 113
- 238000003860 storage Methods 0.000 title claims abstract description 24
- 230000008569 process Effects 0.000 claims description 39
- 238000013507 mapping Methods 0.000 claims description 37
- 238000004364 calculation method Methods 0.000 claims description 34
- 230000006870 function Effects 0.000 claims description 28
- 230000004044 response Effects 0.000 claims description 10
- 238000004590 computer program Methods 0.000 claims description 7
- 238000005070 sampling Methods 0.000 claims description 6
- 239000013589 supplement Substances 0.000 claims description 3
- 230000000694 effects Effects 0.000 description 16
- 238000012360 testing method Methods 0.000 description 13
- 238000009826 distribution Methods 0.000 description 9
- 238000013528 artificial neural network Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 6
- 238000013508 migration Methods 0.000 description 5
- 230000005012 migration Effects 0.000 description 5
- 239000000047 product Substances 0.000 description 5
- 238000012546 transfer Methods 0.000 description 5
- 238000004891 communication Methods 0.000 description 4
- 238000004140 cleaning Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 238000013140 knowledge distillation Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000009467 reduction Effects 0.000 description 3
- 230000001502 supplementing effect Effects 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000002790 cross-validation Methods 0.000 description 2
- 230000002950 deficient Effects 0.000 description 2
- 238000011156 evaluation Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 241000238631 Hexapoda Species 0.000 description 1
- 241000282414 Homo sapiens Species 0.000 description 1
- 241000607479 Yersinia pestis Species 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000010923 batch production Methods 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 201000010099 disease Diseases 0.000 description 1
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T3/00—Geometric image transformations in the plane of the image
- G06T3/40—Scaling of whole images or parts thereof, e.g. expanding or contracting
- G06T3/4007—Scaling of whole images or parts thereof, e.g. expanding or contracting based on interpolation, e.g. bilinear interpolation
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Molecular Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种模型训练方法、装置、设备及计算机存储介质,其中,所述模型训练方法包括:获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
Description
技术领域
本申请涉及但不限于计算机视觉领域,尤其涉及一种模型训练方法、装置、设备及计算机存储介质。
背景技术
相关技术采用的有限规模的数据集实现预训练模型的训练,基于有限规模的训练数据,泛化能力较差;教师模型的输出特性维度与学生模型的输出维度不同,在使用该教师模型监督训练学生模型时,存在训练效率低、迁移的知识有限的问题。
发明内容
本申请实施例提供一种模型训练方法、装置、设备及计算机存储介质。
第一方面,提供一种模型训练方法,所述方法包括:获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
在一些实施方式中,所述在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的的差异程度和所述学生模型的预测结果,确定第一损失值,包括:获取所述识别任务对应的训练数据集;将所述训练数据集输入所述学生模型完成前向计算,得到所述学生模型的预测结果;将所述训练数据集输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值。
这样,计算学生模型的预测结果和教师模型的第一预测结果可以得到第一损失值,以使用第一损失值对学生模型的参数进行更新。
在一些实施方式中,所述基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值,包括:基于所述学生模型的预测结果,确定第二损失值;基于所述学生模型的预测结果和所述教师模型的第一预测结果之间的差异程度,确定第三损失值;基于所述第二损失值和所述第三损失值,确定所述第一损失值。
这样,根据学生模型的预测结果确定第二损失值;根据学生模型的预测结果和教师模型的第一预测结果之间的差异程度,确定第三损失值;这样,基于预设的比例根据第二损失值和第三损失值可以确定第一损失值。
在一些实施方式中,所述基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛,包括:基于所述第一损失值对所述学生模型进行反向传播,得到所述学生模型的更新梯度;基于所述学生模型的更新梯度更新所述学生模型的参数;迭代训练所述学生模型和所述教师模型,响应于所述第二损失值小于等于第一损失值阈值,确定所述学生模型收敛。
这样,基于学生模型的更新梯度,实现更新学生模型的参数,以获得收敛的学生模型。
在一些实施方式中,所述方法还包括:基于所述学生模型的预测结果所对应的类别维度和所述教师模型的预测结果所对应的类别维度,确定映射层的参数;将所述映射层添加至所述教师模型,以使得添加后的所述教师模型的预测结果所对应的类别维度被映射为所述学生模型的预测结果所对应的类别维度。
这样,在教师模型中添加映射层,可以实现对齐教师模型和学生模型可以识别的类别维度,完成知识迁移的效果。
在一些实施方式中,所述基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛,还包括:基于所述第一损失值更新所述教师模型的映射层的更新梯度;基于所述教师模型的映射层的更新梯度更新所述教师模型的映射层的参数。
这样,基于教师模型的映射层的更新梯度,可以确定教师模型的映射层的参数。
在一些实施方式中,所述方法还包括:基于所述大规模数据集,按照批次采样数据,输入所述教师模型中进行前向计算,得到所述教师模型的第二预测结果;利用交叉熵损失函数计算所述教师模型的第二预测结果与所述第一数据集中图像标签的差异,得到第四损失值;基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
这样,基于大规模数据集训练得到的教师模型具有更好的泛化能力和表达能力,能够对下游任务带来更好的提升效果。
在一些实施方式中,所述基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型,包括:基于所述第四损失值对所述教师模型进行反向传播,得到所述教师模型的更新梯度;基于所述教师模型的更新梯度和预设的学习率,确定所述教师模型的更新幅度;利用所述教师模型的更新幅度更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
这样,首先,基于第四损失值对教师模型进行反向传播,得到教师模型的更新梯度,然后,基于更新梯度和预设的学习率,确定更新幅度,最后,利用更新幅度更新教师模型的参数,响应于教师模型的损失值小于第二损失阈值,得到收敛的教师模型。迭代对教师模型的参数进行更新,最终可以得到收敛的教师模型。
在一些实施方式中,所述方法还包括:对所述大规模数据集中每一图像的缺失信息进行差值补充;将未能补全的所述图像从所述大规模数据集中移除,得到清洗后的大规模数据集,以完成对所述教师模型的训练。
这样,对大规模数据集进行清洗整理,可以删除重复信息、纠正存在的错误,并确保得到的大规模数据集中数据的一致性。
第二方面,提供一种模型训练装置,包括:获取模块,用于获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;第一确定模块,用于在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;第一更新模块,用于基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
第三方面,提供一种计算机设备,包括:存储器和处理器,所述存储器存储有可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述方法中的步骤。
第四方面,提供一种计算机存储介质,所述计算机存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实现上述方法中的步骤。
在本申请实施例中,首先,获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型,其次,基于与所述识别任务对应的训练数据集,训练学生模型和教师模型的过程中,确定第一损失值,其中,教师模型对应的大规模数据集的类别维度被映射为学生模型对应的所述识别任务的类别维度,最后,基于第一损失值,更新学生模型的参数,直至学生模型收敛。这样,基于大数据集训练得到的教师模型具有更好的泛化能力,能够支持多种下游任务的调用;创建学生模型,在指定下游任务中进行训练,同时设置教师模型的类别维度与学生模型的类别维度相同,进一步监督学生小模型的训练。通过知识蒸馏的方式进行教师模型到学生模型的知识迁移,在不增加部署的成本的同时提升了模型的精度和迁移效果,达到降本增效的能力。
附图说明
图1为本申请实施例提供的一种模型训练方法的实现流程示意图;
图2为本申请实施例提供的一种模型训练方法的实现流程示意图;
图3为本申请实施例提供的一种模型训练方法的实现流程示意图;
图4为本申请实施例提供的一种模型训练方法的实现流程示意图;
图5为本申请实施例提供的一种模型训练装置的组成结构示意图;
图6为本申请实施例提供的一种计算机设备的硬件实体示意图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对发明的具体技术方案做进一步详细描述。以下实施例用于说明本申请,但不用来限制本申请的范围。
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。
应当理解,此处所描述的一些实施例仅仅用以解释本申请的技术方案,并不用于限定本申请的技术范围。
本实施例提出一种模型训练方法,应用于计算机设备,该方法所实现的功能可以通过计算机设备中的处理器调用程序代码来实现,当然程序代码可以保存在计算机存储介质中,可见,该计算机设备至少包括处理器和存储介质。
图1为本申请实施例提供的一种模型训练方法的实现流程示意图,应用于计算机设备,如图1所示,该方法包括:
步骤S101、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
大规模数据集是从数据的量级描述的数据集,一般来说,用于模型训练的大规模数据集中包括的数据量在千万级,图像的类别数量在万级。
在一些实施例中,可以选取ImageNet-22k作为大规模数据集,其中包含两万多个类别,共一千多万张图片信息,当然也可以使用MNIST、CIFAR-10作为大规模数据集。
在一些实施例中,可以选取大模型作为教师模型,例如ResNet152、ResNeXt101-32x8d、ResNeXt101-64x8d等,在实施过程中,可以设置教师模型类别数量为ImageNet-22k类别数,并进行教师模型的参数的初始化。
在一些实施例中,可以对学生模型进行随机初始化或按照一定规则进行初始化,将类别数设定为识别任务的类别数。这里,一定规则指的是初始化方式,包括:正态分布初始化、凯明(Kaiming)初始化等。
在一些实施例中,选取神经网络大模型作为教师模型,目的是训练得到高精度的参数用于监督学生模型的训练。相较于学生模型,教师模型具有更多的参数量,表示能力更高;而学生模型是用于识别特定任务的模型,学生模型的参数量较少,更有利于实际应用场景的部署。本申请提供的方法旨在应用于各种应用场景,例如,识别行人是否摔倒,识别违章停车,识别植物病虫害,识别农作物的生长阶段,识别驾驶员是否戴安全帽,识别驾驶员是否系安全带等等。
步骤S102、在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
在一些实施例中,需要先获取与识别任务对应的训练数据集,这里的训练数据集可以是比大规模数据集的数据量小的数据集,是包括了识别任务的类别的数据集。
在实施过程中,基于训练数据集,在训练学生模型和教师模型的过程中,首先确定学生模型的预测结果和教师模型的第一预测结果,然后根据学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值。
步骤S103、基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
在一些实施例中,可以基于计算得到的第一损失值,迭代更新学生模型和教师模型,直至学生模型的损失值到一定范围以下,模型收敛。
本申请实施例中,首先,获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型,其次,基于与所述识别任务对应的训练数据集,训练学生模型和教师模型的过程中,确定第一损失值,其中,教师模型对应的大规模数据集的类别维度被映射为学生模型对应的所述识别任务的类别维度,最后,基于第一损失值,更新学生模型的参数,直至学生模型收敛。这样,基于大数据集训练得到的教师模型具有更好的泛化能力,能够支持多种下游任务的调用;创建学生模型,在指定下游任务中进行训练,同时设置教师模型的类别维度与学生模型的类别维度相同,进一步监督学生小模型的训练,通过知识蒸馏的方式进行教师模型到学生模型的知识迁移,在不增加部署的成本的同时提升了模型的精度和迁移效果,达到降本增效的能力。
本申请实施例提供的一种模型训练方法,该方法包括:
步骤S111、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S112、获取所述识别任务对应的训练数据集;
这里,识别任务对应的训练数据集是用于训练学生模型的数据集,包含的数据量小于大规模数据集。可以根据学生模型的识别任务确定训练数据集。
步骤S113、将所述训练数据集输入所述学生模型完成前向计算,得到所述学生模型的预测结果;
这里,神经网络的前向计算,就是给定一组输入,计算输出的过程。每次前向计算得到样本的预测类别概率分布,概率最大的类别即为预测类别。输出预测结果就是概率最大的类别,即模型的预测结果。
在一些实施例中,可以从训练数据集中按照批次采样训练数据,输入学生模型进行前向计算,得到学生模型的预测结果。这里,批次是指Batch的大小,即每轮迭代神经网络所能处理的图像数量。Batch大小是一个超参数,用于定义在更新内部模型参数之前要处理的样本数。将批处理视为循环迭代一个或多个样本并进行预测。训练数据集可以分为一个或多个Batch。
步骤S114、将所述训练数据集输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;
在一些实施例中,可以从训练数据集中按照批次采样训练数据,输入教师模型进行前向计算,得到教师模型的第一预测结果。
步骤S115、基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值;
步骤S116、基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
本申请实施例中,计算学生模型的预测结果和教师模型的第一预测结果可以得到第一损失值,以使用第一损失值对学生模型的参数进行更新。
本申请实施例提供的一种模型训练方法,该方法包括:
步骤S121、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S122、获取所述识别任务对应的训练数据;
步骤S123、将所述训练数据输入所述学生模型完成前向计算,得到所述学生模型的预测结果;
步骤S124、将所述训练数据输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;
步骤S125、基于所述学生模型的预测结果确定第二损失值;
分类问题中可以使用交叉熵作为损失函数来计算确定损失值。交叉熵损失函数用于计算损失值,输入为模型输出的预测类别(即前向计算的输出预测结果)和实际的真实类别,输出为损失值。数据集本身有每张图片所对应的真实类别,交叉熵计算预测类别和真实类别的差异大小。交叉熵能够衡量同一个随机变量中的两个不同概率分布的差异程度,在机器学习中就表示为真实概率分布与预测概率分布之间的差异。交叉熵的值越小,模型预测效果就越好。
在一些实施例中,可以利用交叉熵损失函数根据学生模型的预测结果和图像标签的差异确定第二损失值。这里,图像标签可以是输入学生模型的图像标识出的图像中的物体类型,即,人工识别过的图像中正确的物体类别。
步骤S126、基于所述学生模型的预测结果和所述教师模型的第一预测结果之间的差异程度,确定第三损失值;
在一些实施例中,可以利用欧式距离基于学生模型的预测结果和教师模型的第一预测结果之间的差异程度,确定第三损失值。这里,欧氏距离是一个通常采用的距离定义,是在多维空间中两个点之间的真实距离,在二维和三维空间中的欧式距离的就是两点之间的距离。在实施过程中,欧氏距离即两项间的差是每个变量值差的平方和再平方根,目的是计算其间的整体距离即不相似性。
步骤S127、基于所述第二损失值和所述第三损失值,确定所述第一损失值;
在一些实施例中,可以基于预设的比例,根据第二损失值和第三损失值确定第一损失值。在实施过程中,预设的比例可以是根据实际需求设置的,需要基于预设的比例对第二损失值和第三损失值进行加权求和得到第一损失值,这里,加权是为了使得加权后第二损失值和第三损失值在同样的量级。
步骤S128、基于所述第一损失值对所述学生模型进行反向传播,得到所述学生模型的更新梯度;
梯度概念是建立在偏导数与方向导数概念基础上的,梯度反映的是空间变量变化趋势的最大值和方向。偏导数,可以是对于一个多元函数,选定一个自变量并让其他自变量保持不变,只考察因变量与选定自变量的变化关系。反向传播(back propagation)的作用是快速算出所有参数的偏导数。更新梯度是指根据损失值进行反向传播计算,可以得到网络中参数所需要更新的梯度大小,之后将梯度乘以调整学习率更新于参数上,是更新梯度的过程。
在实施过程中,根据第一损失值对学生模型进行反向传播计算,可以得到学生模型的更新梯度的大小,之后将梯度乘以调整学习率更新学生模型的参数。误差的反向传播会计算各层的梯度。通过往梯度下降的方向调整参数,逐步减小损失函数的值,从而得到训练好的模型。
步骤S129、基于所述学生模型的更新梯度更新所述学生模型的参数;
在实施过程中,可以设置所述学生模型的权重衰减小于等于权重衰减阈值;这里,权值衰减最终目的是防止过拟合。在损失函数中,权值衰减是放在正则项前面的一个系数,正则项一般指示模型的复杂度,所以权值衰减的作用是调节模型复杂度对损失函数的影响,若权值衰减很大,则复杂的模型损失函数的值也就大,相反若权值衰减很小,则复杂的模型损失函数的值也就小。
在一些实施例中,可以设置学生模型的权值衰减为0,让学生模型能够更好地拟合教师模型。
在实施过程中,基于获取到的学生模型的更新梯度学生模型的参数。
步骤S130、迭代训练所述学生模型和所述教师模型,响应于所述第二损失值小于等于损失值阈值,确定所述学生模型收敛。
本申请实施例中,利用交叉熵损失函数根据学生模型的预测结果确定第二损失值;利用欧式距离根据学生模型的预测结果和教师模型的第一预测结果确定第三损失值;这样,基于预设的比例根据第二损失值和第三损失值可以确定第一损失值。
本申请实施例中,基于学生模型的更新梯度,实现更新学生模型的参数,以获得收敛的学生模型。
本申请实施例提供的一种模型训练方法,该方法包括:
步骤S141、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S142、基于所述学生模型的预测结果所对应的类别维度和所述教师模型的第一预测结果所对应的类别维度,确定映射层的参数;
在实施过程中,可以基于学生模型的预测结果所对应的类别维度和所述教师模型的第一预测结果所对应的类别维度,确定用于添加至教师模型中的映射层的参数。
步骤S143、将所述映射层添加至所述教师模型,以使得添加后的所述教师模型的预测结果所对应的类别维度被映射为所述学生模型的预测结果所对应的类别维度;
在一些实施例中,映射层可以以全连接层的形式添加至教师模型。
步骤S144、获取所述识别任务对应的训练数据;
步骤S145、将所述训练数据输入所述学生模型完成前向计算,得到所述学生模型的预测结果;
步骤S146、将所述训练数据输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;
步骤S147、基于所述学生模型的预测结果确定第二损失值;
在一些实施例中,分类问题中可以使用交叉熵作为损失函数来计算确定损失值,即,可以利用交叉熵损失函数根据学生模型的预测结果和图像标签的差异确定第二损失值。
步骤S148、基于所述学生模型的预测结果和所述教师模型的第一预测结果之间的差异程度,确定第三损失值;
步骤S149、基于所述第二损失值和所述第三损失值,确定所述第一损失值;
步骤S150、基于所述第一损失值对所述学生模型进行反向传播,得到所述学生模型的更新梯度;
步骤S151、基于所述第一损失值确定所述教师模型的映射层的更新梯度;
在一些实施例中,需要保持教师模型的主干网络的参数不更新,只更新对应的映射层,所以需要基于第一损失值对教师模型的映射层进行反向传播,更新教师模型的映射层的梯度。在实施过程中,教师模型的主干网络可以包括ResNet152、ResNeXt101-32x8d、ResNeXt101-64x8d。
步骤S152、基于所述学生模型的更新梯度更新所述学生模型的参数;
步骤S153、基于所述教师模型的映射层的更新梯度更新所述教师模型的映射层的参数;
在实施过程中,可以基于教师模型的映射层对应的更新梯度,确定教师模型的映射层的参数。
步骤S154、迭代训练所述学生模型和所述教师模型,响应于所述第二损失值小于等于损失值阈值,确定所述学生模型收敛。
本申请实施例中,利用交叉熵损失函数根据学生模型的预测结果确定第二损失值;利用欧式距离根据学生模型的预测结果和教师模型的第一预测结果确定第三损失值;这样,基于预设的比例根据第二损失值和第三损失值可以确定第一损失值。
本申请实施例中,在教师模型中添加映射层,可以实现对齐教师模型和学生模型的识别维度,进行知识迁移的效果。
本申请实施例中,基于教师模型的映射层的更新梯度和学生模型的更新梯度,确定教师模型的映射层的参数和学生模型的参数,以获得收敛的学生模型。
图2为本申请实施例提供的一种模型训练方法的实现流程示意图,应用于计算机设备,如图2所示,该方法包括:
步骤S201、对大规模数据集中每一图像的缺失信息进行差值补充;
在一些实施例中,可以选取ImageNet-22k作为大规模数据训练集,其中包含两万多个类别,共一千多万张图片信息。首先对数据进行数据清洗,目的在于删除重复信息、纠正存在的错误,并提供数据一致性。
在一些实施例中,可以对大规模数据集中每一图像完成缺失值处理,对图像中的缺失信息进行差值补充,对于不能补全的图片,需要从数据集中进行移除。
在一些实施例中,例如图像存在部分区域缺失,可以使用双线性插值的方式根据周围临近像素进行补全。
步骤S202、将未能补全的所述图像从所述大规模数据集中移除,得到清洗后的大规模数据集,以完成对所述教师模型的训练;
在一些实施例中,将任一不能补全的图像从大规模数据集中移除去掉,即去掉重复数据和噪声数据。
在一些实施例中,可以按照预设的格式命名所述大规模数据集中的图像;对图像的格式与内容进行处理,按照统一的格式进行图片名称的命名;生成所述大规模数据集的索引信息,完成所述大规模数据集的清洗。
在一些实施例中,按照一定比例对数据集进行训练集和测试集的划分,使用训练集进行模型的训练,使用测试集作为评测。这里,可以按照十折交叉验证的方式进行划分,即将数据划分为十份,九份用来训练模型,一份用来测试模型的性能。
步骤S203、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S204、在基于与所述识别任务对应的训练数据,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的所述大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
步骤S205、基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
本申请实施例中,对大规模数据集进行清洗整理,可以删除重复信息、纠正存在的错误,并确保得到的大规模数据集中数据的一致性。
图3为本申请实施例提供的一种模型训练方法的实现流程示意图,应用于计算机设备,如图3所示,该方法包括:
步骤S301、基于大规模数据集,按照批次采样数据,输入教师模型中进行前向计算,得到所述教师模型的第二预测结果;
在一些实施例中,可以选取大模型作为教师模型,例如ResNet152、ResNeXt101-32x8d、ResNeXt101-64x8d等,设置模型类别数量为ImageNet-22k类别数,并进行模型参数的初始化。从训练集中按照批次采样数据,输入到教师模型中进行前向计算,教师模型的第二预测结果。
步骤S302、利用交叉熵损失函数计算所述教师模型的第二预测结果与所述第一数据集中图像标签的差异,得到第四损失值;
交叉熵能够衡量同一个随机变量中的两个不同概率分布的差异程度,在机器学习中就表示为真实类别概率分布与预测类别概率分布之间的差异。交叉熵的值越小,模型预测效果就越好。交叉熵在分类问题中可以联合回归算法(softmax)将输出的结果进行处理,使其多个分类的预测值和为1,再通过交叉熵来计算损失值。在分类问题中常常使用交叉熵作为损失函数。
这里,教师模型的第二预测结果可以是预测概率分布,大规模数据集本身有每张图片所对应的真实类别,即图像标签。根据交叉熵损失函数可以计算教师模型的第二预测结果与图像标签的差异,得到第四损失值。
步骤S303、基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型;
在一些实施例中,将第四损失值反向传播得到针对教师模型参数的梯度。将梯度乘以设定好的学习率得到更新幅度,对教师模型的参数进行更新。重复确定第四损失值和利用第四损失值更新教师模型的参数的过程,直至得到收敛的教师模型。
步骤S304、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S305、在基于与所述识别任务对应的训练数据,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型的对应的所述大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
步骤S306、基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
本申请实施例中,基于大规模数据集训练得到的教师模型具有更好的泛化能力和表达能力,能够对下游任务带来更好的提升效果。
本申请实施例中,对大规模数据集进行清洗整理,可以删除重复信息、纠正存在的错误,并确保得到的大规模数据集中数据的一致性。
本申请实施例提供的一种模型训练方法的实现流程示意图,该方法包括:
步骤S311、基于大规模数据集,按照批次采样数据,输入教师模型中进行前向计算,得到所述教师模型的第二预测结果;
步骤S312、利用交叉熵损失函数计算所述教师模型的第二预测结果与所述第一数据集中图像标签的差异,得到第四损失值;
步骤S313、基于所述第四损失值对所述教师模型进行反向传播,得到所述教师模型的更新梯度;
根据损失值进行反向传播计算,可以得到网络中参数所需要更新的梯度大小,之后将梯度乘以调整学习率更新于参数上,就是更新梯度的过程。
在实施过程中,基于第四损失值反向推到梯度更新,得到教师模型的参数的梯度。
步骤S314、基于所述教师模型的更新梯度和预设的学习率,确定所述教师模型的更新幅度;
学习率是指每次梯度调整的幅度,合适的学习率需要试错的方式进行获取,或者按照一些模型的经验进行设置。
在实施过程中,将获取到的梯度乘以预设的学习率可以得到教师模型的参数的更新幅度。
步骤S315、利用所述教师模型的更新幅度更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型;
重复步骤S311至步骤S315,直至教师模型的损失值到一定范围以下,模型收敛。
步骤S316、获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
步骤S317、在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的所述大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
步骤S318、基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
本申请实施例中,首先基于第四损失值反向传播,得到教师模型的参数的梯度,然后将梯度乘以预设的学习率得到更新幅度,最后利用更新幅度更新教师模型的参数,响应于教师模型的损失值小于第二损失阈值,得到收敛的教师模型。这样,迭代对教师模型的参数进行更新,最终可以得到收敛的教师模型。
深度学习技术在人工智能领域取得了巨大的成功,在这其中,神经网络起到了重要的作用。基于神经网络的模型在计算机视觉、自然语言处理、语音识别等领域取得了超越人类的识别能力,具有巨大的商业应用价值。神经网络模型的性能往往与模型中的参数密切相关,能否优化得到好的参数决定了模型最终的性能。
在实际应用场景中,相较于从头训练得到模型的参数,工程研究人员更倾向于使用预先训练好的模型作为起点,之后通过微调的方式在指定的任务上进行训练,能够取得更好的收敛效果和最终性能,这被称作预训练模型(Pre-Trained Model)。显然,基于预训练模型的方法更快、更省力。
本申请提出使用更大规模的数据进行神经网络的训练,通过大规模数据学习到更为通用的特征表示方法,能够为下游任务带来更好的迁移效果。
为了进一步的提升模型的表示能力,本申请引入教师模型,在大数据集中训练之后,得到收敛的教师模型,再利用教师模型监督学生模型的训练过程,进一步提升预训练模型的效果。
图4为本申请实施例提供的一种模型训练方法的实现流程示意图,应用于计算机设备,如图4所示,该方法包括:
步骤S401、整理清洗大规模数据集并划分训练集和测试集;
在实施过程中,对超大规模数据集进行清洗整理(千万量级训练数据),去除缺损图片,按照比例划分训练集和测试集。
在一些实施例中,可以选取ImageNet-22k作为大规模数据训练集,其中包含两万多个类别,共一千多万张图片信息。首先对数据进行数据清洗,目的在于删除重复信息、纠正存在的错误,并提供数据一致性。
具体的操作见上述实施例中的步骤S201至步骤S202,在此不赘述。
步骤S402、选定包含较多参数量,具有较强表示能力的模型作为教师模型,并在大规模数据集上进行训练,得到收敛的教师模型;
在一些实施例中,创建教师模型,并在清洗完成的大规模数据集上进行训练,解决大模型和大规模数据的训练难点,得到收敛的教师模型的模型参数。
在一些实施例中,在大规模数据集上训练收敛的教师模型可以包括上述实施例中的步骤S311至步骤S315,在此不赘述。
步骤S403、根据识别任务创建学生模型;
在一些实施例中,创建学生模型,在指定下游任务中进行训练,对学生模型进行随机初始化或按照一定规则进行初始化,并将类别数设定为当前任务的类别数。这里,下游任务即学生模型的识别任务。
步骤S404、对教师模型使用大规模预训练参数作为初始化,并添加映射层到教师模型;
在一些实施例中,对教师模型读取大规模数据集上训练得到的参数作为初始化参数,并添加额外的映射层将大规模数据集的类别数映射为下游任务的类别数,即,选取教师模型提取特征进行映射到相同特征维度,进一步监督学生模型的训练。
步骤S405、在下游任务中使用教师模型和学生模型进行前向计算,得到对应的损失函数;以使得学生模型的输出模拟教师模型,并更新学生模型的参数;
首先,可以从下游任务对应的数据集中按照批次采样训练数据,输入学生模型进行前向计算,得到学生模型的预测结果(F1);同时输入教师模型进行前向计算得到教师模型的第一预测结果(F2)。
其次,根据交叉熵损失函数计算学生模型的损失值,得到对应的学生模型的第二损失值(L2)。根据欧氏距离计算学生模型的预测结果和教师模型的第一预测结果之间的差异程度,得到第三损失值(L1)。并按照一定的比例进行加权求和得到第一损失值(L1)。
然后,使用加权求和得到的第一损失值对教师模型和学生模型进行反向传播,计算参数所对应的更新梯度。保持教师模型的主干网络参数不更新,只更新对应的映射层,对学生网络的全部参数进行更新。同时设定学生网络的权值衰减为0,让学生网络能够更好地拟合教师网络。
最后,重复上述过程,直至学生模型的损失值到一定范围以下,模型收敛。
步骤S406、在下游任务的测试集上验证学生模型的性能,如达到要求则整个过程结束。
在一些实施例中,可以按照一定比例对数据集进行训练集和测试集的划分,使用训练集进行模型的训练,使用测试集作为评测。一般按照十折交叉验证的方式进行划分,即将数据划分为十份,九份用来训练模型,一份用来测试模型的性能。
在实施过程中,可以选择下游任务的测试集,然后利用测试集验证学生模型的模型性能,测试学生模型效果。
本申请实施例中,对超大规模数据集进行清洗整理(千万量级训练数据),去除缺损图片,按照比例划分训练集和测试集。解决了基于有限规模的训练数据得到的教师模型,泛化能力较差的问题。基于大数据集训练得到的教师模型具有更好的泛化能力,能够支持多种下游任务的调用,对下游任务带来更好的提升效果。
本申请实施例中,创建教师大模型,并在清洗完成的大规模数据集上进行训练,解决大模型和大规模数据的训练难点,得到收敛的模型参数。其中,大模型训练难点包括:节省GPU显存占用、训练策略的改进促进收敛;大规模数据的训练难点包括:分布式训练效率提升。
本申请实施例中,创建学生小模型,在指定下游任务中进行训练,同时选取教师大模型提取特征进行映射到相同类别维度,进一步监督学生小模型的训练。基于跨数据集教师模型监督学生模型,使用大数据集训练的教师模型监督应用场景中的学生模型,使得学生模型在达到教师模型精度的同时具有更好的部署能力。解决了相关技术未使用教师模型进行监督,或使用同一数据集训练的教师模型进行监督,迁移的知识有限的问题。基于高精度的教师模型能够降低下游任务的训练难度,同时提升下游任务上的训练性能,取得更大的收益。通过知识蒸馏的方式进行教师模型到学生模型的知识迁移,在不增加部署的成本的同时提升了模型的精度,达到降本增效的能力。
本申请实施例中,在教师模型中引入映射层对齐预训练模型和下游模型进行知识迁移。
本申请实施例中,在实际业务场景中调用学生模型进行使用,或者引入教师模型进行微调学习,能够有效提升训练数据较少情况下的特征提取能力,针对小样本学习和长尾问题,达到事半功倍的效果。
基于前述的实施例,本申请实施例提供一种模型训练装置,该装置包括所包括的各模块、以及各模块所包括的各子模块,各子模块包括的各单元,可以通过计算机设备中的处理器来实现;当然也可通过具体的逻辑电路实现。
图5为本申请实施例提供的一种模型训练装置的组成结构示意图,如图5所示,装置500包括:
获取模块501,用于获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
第一确定模块502,用于在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的所述大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
第一更新模块503,用于基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
在一些实施例中,所述第一确定模块包括获取子模块、第一前向计算子模块、第二前向计算子模块和第一确定子模块,其中,所述获取子模块,用于获取所述识别任务对应的训练数据集;所述第一前向计算子模块,用于将所述训练数据集输入所述学生模型完成前向计算,得到所述学生模型的预测结果;所述第二前向计算子模块,用于将所述训练数据集输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;所述第一确定子模块,用于基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值。
在一些实施例中,所述第一确定子模块包括第一确定单元、第二确定单元和第三确定单元,其中,所述第一确定单元,用于基于所述学生模型的预测结果,确定第二损失值;所述第二确定单元,用于基于所述学生模型的预测结果和所述教师模型的第一预测结果之间的差异程度,确定第三损失值;所述第三确定单元,用于基于所述第二损失值和所述第三损失值,确定所述第一损失值。
在一些实施例中,所述第一更新模块包括第二确定子模块、第一更新子模块和第三确定子模块,其中,所述第二确定子模块,用于基于所述第一损失值对所述学生模型进行反向传播,得到所述学生模型的更新梯度;所述第一更新子模块,用于基于所述学生模型的更新梯度更新所述学生模型的参数;所述第三确定子模块,用于迭代训练所述学生模型和所述教师模型,响应于所述第二损失值小于等于第一损失值阈值,确定所述学生模型收敛。
在一些实施例中,所述装置500还包括第二确定模块和添加模块,其中,所述第二确定模块,用于基于所述学生模型的预测结果所对应的类别维度和所述教师模型的第一预测结果所对应的类别维度,确定映射层的参数;所述添加模块,用于将所述映射层添加至所述教师模型,以使得添加后的所述教师模型的预测结果所对应的类别维度被映射为所述学生模型的预测结果所对应的类别维度。
在一些实施例中,第一更新模块还包括第四确定子模块、第二更新子模块,其中,所述第四确定子模块,用于基于所述第一损失值确定所述教师模型的映射层的更新梯度;所述第二更新子模块,用于基于所述教师模型的映射层的更新梯度更新所述教师模型的映射层的参数。
在一些实施例中,所述装置500还包括前向计算模块、交叉熵损失函数计算模块、第二更新模块。其中,所述前向计算模块,用于基于所述大规模数据集,按照批次采样数据,输入所述教师模型中进行前向计算,得到所述教师模型的第二预测结果;所述交叉熵损失函数计算模块,用于利用交叉熵损失函数计算所述教师模型的第二预测结果与所述第一数据集中图像标签的差异,得到第四损失值;所述第二更新模块,用于基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
在一些实施例中,所述第二更新模块包括反向传播子模块、第五确定子模块、第三更新子模块,其中,所述反向传播子模块,用于基于所述第四损失值对所述教师模型进行反向传播,得到所述教师模型的更新梯度;所述第五确定子模块,用于基于所述教师模型的更新梯度和预设的学习率,确定所述教师模型的更新幅度;所述第三更新子模块,用于利用所述教师模型的更新幅度更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
在一些实施例中,所述装置500还包括差值补充模块和移除模块,其中,所述差值补充模块,用于对所述大规模数据集中每一图像的缺失信息进行差值补充;所述移除模块,用于将未能补全的所述图像从所述大规模数据集中移除,得到清洗后的大规模数据集,以完成对所述教师模型的训练。
以上装置实施例的描述,与上述方法实施例的描述是类似的,具有同方法实施例相似的有益效果。对于本申请装置实施例中未披露的技术细节,请参照本申请方法实施例的描述而理解。需要说明的是,本申请实施例中,如果以软件功能模块的形式实现上述的模式控制方法,并作为独立的产品销售或使用时,也可以存储在一个计算机存储介质中。基于这样的理解,本申请实施例的技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个计算机存储介质中,包括若干指令用以使得一台计算机设备执行本申请各个实施例所述方法的全部或部分。
需要说明的是,本申请实施例中,如果以软件功能模块的形式实现上述的模型训练方法,并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实施例的技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得计算机设备(可以是手机、平板电脑、笔记本电脑、台式计算机、机器人、服务器等)执行本申请各个实施例所述方法的全部或部分。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。这样,本申请实施例不限制于任何特定的硬件和软件结合。
对应地,本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述实施例中提供的模型训练方法中的步骤。
对应地,本申请实施例提供一种计算机设备,图6为本申请实施例的一种硬件实体示意图,如图6所示,该设备600的硬件实体包括:包括存储器601和处理器602,所述存储器601存储有可在处理器602上运行的计算机程序,所述处理器602执行所述程序时实现上述实施例中提供的方法中的步骤。
存储器601配置为存储由处理器602可执行的指令和应用,还可以缓存待处理器602以及计算机设备600中各模块待处理或已经处理的数据(例如,图像数据、音频数据、语音通信数据和视频通信数据),可以通过闪存(FLASH)或随机访问存储器(Random AccessMemory,RAM)实现。
这里需要指出的是:以上存储介质和设备实施例的描述,与上述方法实施例的描述是类似的,具有同方法实施例相似的有益效果。对于本申请存储介质和设备实施例中未披露的技术细节,请参照本申请方法实施例的描述而理解。
应理解,说明书通篇中提到的“一个实施例”或“一实施例”意味着与实施例有关的特定特征、结构或特性包括在本申请的至少一个实施例中。因此,在整个说明书各处出现的“在一个实施例中”或“在一实施例中”未必一定指相同的实施例。此外,这些特定的特征、结构或特性可以任意适合的方式结合在一个或多个实施例中。应理解,在本申请的各种实施例中,上述各过程的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
在本申请所提供的几个实施例中,应该理解到,所揭露的设备和方法,可以通过其它的方式实现。以上所描述的设备实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,如:多个单元或组件可以结合,或可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的各组成部分相互之间的耦合、或直接耦合、或通信连接可以是通过一些接口,设备或单元的间接耦合或通信连接,可以是电性的、机械的或其它形式的。
上述作为分离部件说明的单元可以是、或也可以不是物理上分开的,作为单元显示的部件可以是、或也可以不是物理单元;既可以位于一个地方,也可以分布到多个网络单元上;可以根据实际的需要选择其中的部分或全部单元来实现本实施例方案的目的。
另外,在本申请各实施例中的各功能单元可以全部集成在一个处理单元中,也可以是各单元分别单独作为一个单元,也可以两个或两个以上单元集成在一个单元中;上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。
本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述的程序可以存储于计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质包括:移动存储设备、只读存储器(Read Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。
或者,本申请上述集成的单元如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实施例的技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得计算机设备(可以是手机、平板电脑、笔记本电脑、台式计算机、机器人、服务器等)执行本申请各个实施例所述方法的全部或部分。而前述的存储介质包括:移动存储设备、ROM、磁碟或者光盘等各种可以存储程序代码的介质。
本申请所提供的几个方法实施例中所揭露的方法,在不冲突的情况下可以任意组合,得到新的方法实施例。
本申请所提供的几个产品实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的产品实施例。
本申请所提供的几个方法或设备实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的方法实施例或设备实施例。
以上所述,仅为本申请的实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。
Claims (12)
1.一种模型训练方法,其特征在于,所述方法包括:
获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的所述大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
2.如权利要求1所述的方法,其特征在于,所述在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,包括:
获取所述识别任务对应的训练数据集;
将所述训练数据集输入所述学生模型完成前向计算,得到所述学生模型的预测结果;
将所述训练数据集输入所述教师模型完成前向计算,得到所述教师模型的第一预测结果;
基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值。
3.如权利要求2所述的方法,其特征在于,所述基于所述学生模型的预测结果和所述教师模型的第一预测结果确定所述第一损失值,包括:
基于所述学生模型的预测结果,确定第二损失值;
基于所述学生模型的预测结果和所述教师模型的第一预测结果之间的差异程度,确定第三损失值;
基于所述第二损失值和所述第三损失值,确定所述第一损失值。
4.如权利要求1至3任一项所述的方法,其特征在于,所述基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛,包括:
基于所述第一损失值对所述学生模型进行反向传播,得到所述学生模型的更新梯度;
基于所述学生模型的更新梯度更新所述学生模型的参数;
迭代训练所述学生模型和所述教师模型,响应于所述第二损失值小于等于第一损失值阈值,确定所述学生模型收敛。
5.如权利要求1至4任一项所述的方法,其特在于,所述方法还包括:
基于所述学生模型的预测结果所对应的类别维度和所述教师模型的第一预测结果所对应的类别维度,确定映射层的参数;
将所述映射层添加至所述教师模型,以使得添加后的所述教师模型的预测结果所对应的类别维度被映射为所述学生模型的预测结果所对应的类别维度。
6.如权利要求5所述的方法,其特征在于,所述基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛,还包括:
基于所述第一损失值确定所述教师模型的映射层的更新梯度;
基于所述教师模型的映射层的更新梯度更新所述教师模型的映射层的参数。
7.如权利要求1至6任一项所述的方法,其特征在于,所述方法还包括:
基于所述大规模数据集,按照批次采样数据,输入所述教师模型中进行前向计算,得到所述教师模型的第二预测结果;
利用交叉熵损失函数计算所述教师模型的第二预测结果与所述第一数据集中图像标签的差异,得到第四损失值;
基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
8.如权利要求7所述的方法,其特征在于,所述基于所述第四损失值更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型,包括:
基于所述第四损失值对所述教师模型进行反向传播,得到所述教师模型的更新梯度;
基于所述教师模型的更新梯度和预设的学习率,确定所述教师模型的更新幅度;
利用所述教师模型的更新幅度更新所述教师模型的参数,响应于所述教师模型的损失值小于第二损失阈值,得到收敛的教师模型。
9.如权利要求1至8任一项所述的方法,其特征在于,所述方法还包括:
对所述大规模数据集中每一图像的缺失信息进行差值补充;
将未能补全的所述图像从所述大规模数据集中移除,得到清洗后的大规模数据集,以完成对所述教师模型的训练。
10.一种模型训练装置,其特征在于,包括:
获取模块,用于获取使用已经在大规模数据集上训练收敛的教师模型和应用于识别任务的学生模型;
第一确定模块,用于在基于与所述识别任务对应的训练数据集,训练所述学生模型和所述教师模型的过程中,基于所述学生模型的预测结果与所述教师模型的第一预测结果之间的差异程度和所述学生模型的预测结果,确定第一损失值,其中,所述教师模型对应的大规模数据集的类别维度被映射为所述学生模型对应的所述识别任务的类别维度;
第一更新模块,用于基于所述第一损失值,对所述学生模型的参数进行更新,直至所述学生模型收敛。
11.一种计算机设备,其特征在于,包括:存储器和处理器,所述存储器存储有可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现权利要求1至9中任一项所述方法中的步骤。
12.一种计算机存储介质,其特征在于,所述计算机存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实现权利要求1至9中任一项所述方法中的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110660998.8A CN113222123B (zh) | 2021-06-15 | 2021-06-15 | 模型训练方法、装置、设备及计算机存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110660998.8A CN113222123B (zh) | 2021-06-15 | 2021-06-15 | 模型训练方法、装置、设备及计算机存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113222123A true CN113222123A (zh) | 2021-08-06 |
CN113222123B CN113222123B (zh) | 2024-08-09 |
Family
ID=77080401
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110660998.8A Active CN113222123B (zh) | 2021-06-15 | 2021-06-15 | 模型训练方法、装置、设备及计算机存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113222123B (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114065834A (zh) * | 2021-09-30 | 2022-02-18 | 中国科学院深圳先进技术研究院 | 一种模型训练方法、终端设备及计算机存储介质 |
CN114372618A (zh) * | 2021-12-27 | 2022-04-19 | 北京北明数科信息技术有限公司 | 一种学生成绩的预测方法、系统、计算机设备及存储介质 |
CN114565807A (zh) * | 2022-03-03 | 2022-05-31 | 腾讯科技(深圳)有限公司 | 训练目标图像检索模型的方法和装置 |
CN114596468A (zh) * | 2022-03-14 | 2022-06-07 | 瀚云科技有限公司 | 病虫害识别及模型训练方法、装置、电子设备及存储介质 |
CN116594349A (zh) * | 2023-07-18 | 2023-08-15 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
CN116863278A (zh) * | 2023-08-25 | 2023-10-10 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180268292A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
CN111160409A (zh) * | 2019-12-11 | 2020-05-15 | 浙江大学 | 一种基于共同特征学习的异构神经网络知识重组方法 |
CN111160474A (zh) * | 2019-12-30 | 2020-05-15 | 合肥工业大学 | 一种基于深度课程学习的图像识别方法 |
CN111242297A (zh) * | 2019-12-19 | 2020-06-05 | 北京迈格威科技有限公司 | 基于知识蒸馏的模型训练方法、图像处理方法及装置 |
CN111709476A (zh) * | 2020-06-17 | 2020-09-25 | 浪潮集团有限公司 | 一种基于知识蒸馏的小分类模型训练方法及装置 |
CN111767711A (zh) * | 2020-09-02 | 2020-10-13 | 之江实验室 | 基于知识蒸馏的预训练语言模型的压缩方法及平台 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN112199535A (zh) * | 2020-09-30 | 2021-01-08 | 浙江大学 | 一种基于集成知识蒸馏的图像分类方法 |
CN112784964A (zh) * | 2021-01-27 | 2021-05-11 | 西安电子科技大学 | 基于桥接知识蒸馏卷积神经网络的图像分类方法 |
-
2021
- 2021-06-15 CN CN202110660998.8A patent/CN113222123B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180268292A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
CN111160409A (zh) * | 2019-12-11 | 2020-05-15 | 浙江大学 | 一种基于共同特征学习的异构神经网络知识重组方法 |
CN111242297A (zh) * | 2019-12-19 | 2020-06-05 | 北京迈格威科技有限公司 | 基于知识蒸馏的模型训练方法、图像处理方法及装置 |
CN111160474A (zh) * | 2019-12-30 | 2020-05-15 | 合肥工业大学 | 一种基于深度课程学习的图像识别方法 |
CN111709476A (zh) * | 2020-06-17 | 2020-09-25 | 浪潮集团有限公司 | 一种基于知识蒸馏的小分类模型训练方法及装置 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN111767711A (zh) * | 2020-09-02 | 2020-10-13 | 之江实验室 | 基于知识蒸馏的预训练语言模型的压缩方法及平台 |
CN112199535A (zh) * | 2020-09-30 | 2021-01-08 | 浙江大学 | 一种基于集成知识蒸馏的图像分类方法 |
CN112784964A (zh) * | 2021-01-27 | 2021-05-11 | 西安电子科技大学 | 基于桥接知识蒸馏卷积神经网络的图像分类方法 |
Non-Patent Citations (2)
Title |
---|
GUOBIN CHEN ET AL: "Learning Efficient Object Detection Models with knowledge distillation", 《PROCEEDINGS OF THE 31ST INTERNATIONAL CONFERENCE ON NEURAL INFORMATION PROCESSING SYSTEMS》, 4 December 2017 (2017-12-04), pages 742 - 751 * |
倪建功 等: "基于知识蒸馏的胡萝卜外观品质等级智能检测", 《农业工程学报》, 30 September 2020 (2020-09-30), pages 181 - 187 * |
Cited By (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN113505797B (zh) * | 2021-09-09 | 2021-12-14 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114065834A (zh) * | 2021-09-30 | 2022-02-18 | 中国科学院深圳先进技术研究院 | 一种模型训练方法、终端设备及计算机存储介质 |
CN114372618A (zh) * | 2021-12-27 | 2022-04-19 | 北京北明数科信息技术有限公司 | 一种学生成绩的预测方法、系统、计算机设备及存储介质 |
CN114565807A (zh) * | 2022-03-03 | 2022-05-31 | 腾讯科技(深圳)有限公司 | 训练目标图像检索模型的方法和装置 |
CN114596468A (zh) * | 2022-03-14 | 2022-06-07 | 瀚云科技有限公司 | 病虫害识别及模型训练方法、装置、电子设备及存储介质 |
CN116594349A (zh) * | 2023-07-18 | 2023-08-15 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
CN116594349B (zh) * | 2023-07-18 | 2023-10-03 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
CN116863278A (zh) * | 2023-08-25 | 2023-10-10 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
CN116863278B (zh) * | 2023-08-25 | 2024-01-26 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN113222123B (zh) | 2024-08-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113222123B (zh) | 模型训练方法、装置、设备及计算机存储介质 | |
CN108764292B (zh) | 基于弱监督信息的深度学习图像目标映射及定位方法 | |
CN111967971B (zh) | 银行客户数据处理方法及装置 | |
CN111507768B (zh) | 一种潜在用户的确定方法及相关装置 | |
EP3602419B1 (en) | Neural network optimizer search | |
CN108780519A (zh) | 卷积神经网络中的结构学习 | |
CN111639755B (zh) | 一种网络模型训练方法、装置、电子设备及存储介质 | |
US11954755B2 (en) | Image processing device and operation method thereof | |
CN110210493B (zh) | 基于非经典感受野调制神经网络的轮廓检测方法及系统 | |
CN109447096B (zh) | 一种基于机器学习的扫视路径预测方法和装置 | |
CN113987236B (zh) | 基于图卷积网络的视觉检索模型的无监督训练方法和装置 | |
CN110826581A (zh) | 一种动物数量识别方法、装置、介质及电子设备 | |
US20230020112A1 (en) | Relating complex data | |
EP4433990A1 (en) | Method and system for analysing medical images to generate a medical report | |
CN112749737A (zh) | 图像分类方法及装置、电子设备、存储介质 | |
CN114723989A (zh) | 多任务学习方法、装置及电子设备 | |
KR20230068941A (ko) | 딥러닝 학습 기법을 이용하는 유사도 기반 클러스터링 장치 및 그 방법 | |
CN111967973B (zh) | 银行客户数据处理方法及装置 | |
CN117788629A (zh) | 一种具有风格个性化的图像生成方法、装置及存储介质 | |
CN109934352B (zh) | 智能模型的自动进化方法 | |
CN113780394B (zh) | 一种强分类器模型的训练方法、装置及设备 | |
CN113066094B (zh) | 一种基于生成对抗网络的地理栅格智能化局部脱敏方法 | |
CN115358374A (zh) | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 | |
US20220343134A1 (en) | Convolutional neural network architectures based on synaptic connectivity | |
CN113822293A (zh) | 用于图数据的模型处理方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |