CN113298135A - 基于深度学习的模型训练方法、装置、存储介质及设备 - Google Patents
基于深度学习的模型训练方法、装置、存储介质及设备 Download PDFInfo
- Publication number
- CN113298135A CN113298135A CN202110555154.7A CN202110555154A CN113298135A CN 113298135 A CN113298135 A CN 113298135A CN 202110555154 A CN202110555154 A CN 202110555154A CN 113298135 A CN113298135 A CN 113298135A
- Authority
- CN
- China
- Prior art keywords
- image data
- model
- model training
- loss
- deep learning
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Biophysics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Probability & Statistics with Applications (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Evolutionary Biology (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本发明公开一种基于深度学习的模型训练方法,所述方法包括:步骤1,采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;步骤2,设置所述第一图像数据集中每个图像数据被选取的概率为1/M;步骤3,从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;步骤4,对所述第k批次数据进行模型训练;步骤5,模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;步骤6,将权重分配后的N个图像数据释放到第一图像数据集;重复步骤3至步骤6,直到所述模型的损失函数计算结果小于等于阈值。本发明方法能加速模型在新场景的收敛速度。
Description
技术领域
本发明涉及图像识别技术领域,尤其涉及一种基于深度学习的模型训练方法、装置、存储介质及设备。
背景技术
在图像领域中,传统基于深度学习的目标检测模型训练流程大致分为如下几步骤:数据标注(训练集)、设计模型结构、基于训练集训练模型,当模型收敛时得到最终需要的模型参数。而因为不同场景存在着差异性,往往在一个新的场景中,我们会使用一个已经训练好的模型(又称为预训练模型),在新场景中标注了新的数据,然后和之前的训练集进行融合,再重新训练模型,让模型能在新的场景中达到比较好的效果,但是随着场景的增多,训练集会越来越大,模型的训练周期会越来越长。
发明内容
本发明提供一种基于深度学习的模型训练方法、装置、存储介质及设备,用于解决因场景的增多,训练集会越来越大,导致模型的训练周期会越来越长的技术问题。
所述技术方案如下:
步骤1,采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
步骤2,设置所述第一图像数据集中每个图像数据被选取的概率为1/M;
步骤3,从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
步骤4,对所述第k批次数据进行模型训练;
步骤5,模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
步骤6,将权重分配后的N个图像数据释放到第一图像数据集;
重复步骤3至步骤6,直到所述模型的损失函数计算结果小于等于阈值。
在一种可能的实现方式中,所述权重分配具体包括:使用概率重置函数进行权重分配,N个图像数据的损失函数分别为loss(1),loss(2)....loss(N),权重分配后每个图像数据被选取概率为:
在一种可能的实现方式中,所述概率和S(k)的值随k改变而改变,且当k=1时,S(1)=N/M。
在一种可能的实现方式中,所述模型的损失函数为所述第k批次数据中的N个图像数据损失函数之和,所述图像数据的损失函数计算结果越大,表示该图像数据与所述模型的相关性越小,损失函数计算结果越小,表示该图像数据与所述模型相关性越大。
在一种可能的实现方式中,所述方法该包括:当k=1时,从所述第一图像数据集中随机选择N个图像数据,当k>1时,将所述第一图像数据集中的所有图像数据按照损失函数计算结果从大到小排序,选择前N个图像数据组成第k批次数据。
在一种可能的实现方式中,所述模型应用于目标检测,通过目标图像数据的坐标和类型计算所述目标图像数据的损失函数。
一种基于深度学习的模型训练装置,所述装置包括:
采集模块,用于采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
初始化模块,用于设置所述第一图像数据集中每个图像数据被选取的概率为1/M;
选择模块,用于从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
训练模块,用于对所述第k批次数据进行模型训练;
权重分配模块,用于模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
判断模块,用于判断所述模型的损失函数计算结果小于等于阈值。
在一种可能的实现方式中,所述权重分配模块用于以下述方式权重分配:
使用概率重置函数进行权重分配,N个图像数据的损失函数分别为loss(1),loss(2)....loss(N),其中所述N个图像数据的概率和为S(k),权重分配后每个图像数据被选取概率为:
一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1至6任一所述的基于深度学习的模型训练方法。
一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述指令由所述处理器加载并执行以实现如权利要求1至6任一所述的基于深度学习的模型训练方法。
和传统方法相比,本发明方法会在训练过程中重新计算每个数据对于模型的重要性,让重要性高的数据在一个重新训练的时候更加容易被使用到,能加速模型在新场景的收敛速度。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例描述中所需要使用的附图作简单的介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1是传统模型训练方法流程图;
图2是本发明实施例基于深度学习的模型训练方法流程图;
图3是本发明实施例提供基于深度学习的模型训练装置框图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本发明的保护范围。
本申请可以用于目标检测的任务,找出图像中所有感兴趣的目标(物体),确定它们的类别和位置。通常会根据模型预测结果的类别和坐标两个信息共同判断模型的好坏,当模型输出的信息与我们标注信息越接近,说明这张图片对模型来说越没有学习价值。而当模型输出的信息与标注信息越接近时,loss越低。
图1示出了传统的训练方式的流程图。
请参考图2,其示出了本申请一个实施例提供的基于深度学习的模型训练方法的方法流程图,该基于深度学习的模型训练方法可以应用于计算机设备中。该基于深度学习的模型训练方法,可以包括:
步骤1,采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
其中,采集设备中可以配置有摄像头,比如,采集设备可以是枪机摄像头。本实施例不作限定。采集设备可以对新场景进行拍摄,获取图像数据,与原有场景的数据进行融合,形成Cinderella数据集。
步骤2,设置所述第一图像数据集中每个图像数据被选取的概率Q(i)=1/M;
初始化所有数据集,假设有M个图像数据,初始化设置每个图像数据被选取的概率均为1/M。
步骤3,从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
从数据集中进行随机采样N个数据组成一个batch。
步骤4,对所述第k批次数据进行模型训练;
使用一个batch中的数据进行模型训练。
步骤5,模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
基于损失函数loss优化模型参数,判断模式是否收敛,若当前模型不满足收敛条件,则把这个N个数据通过概率重置函数,重新设置选取概率。
重置函数如下:对于数据集总数为M,假设一次取N个样本,在这N个样本中的loss为loss(1),loss(2)....loss(N),概率和为S(k),每个图片的新的选取概率为:
步骤6,将权重分配后的N个图像数据释放到第一图像数据集;
重复步骤3至步骤6,直到所述模型的损失函数计算结果小于等于阈值。
下面我们结合具体的训练进行测试,验证本发明方法。
在人脸检测任务的实验中,定义测试集精度达到90%时认为模型达到预期指标,我们用同一个模型采用传统方法和本发明的方法进行再次迭代(finetune),观察模型达到预期需要的进行多少轮batch。
如表1,本次人脸检测任务实验中,原始数据集有12000中训练图片,从新场景中获取6800张图片,每次迭代训练所选取批次数据有32个图片,对比发现,达到相同测试集精度所需的batch数量分别为11765和4234。本方法主要是通过重新分配每次训练时难例被选择的概率来使得模型可以更快速的收敛,也即,在新场景中finetune的时候能更快的收敛。
表1
本发明提出的基于深度学习的模型训练方法,会在每次模型训练迭代过程中重新计算每个数据对于模型的重要性,让重要性高的数据在一个重新训练的时候更加容易被使用到,相比于传统模型训练方法每次迭代过程中都是采用随机选择数据,导致模型的训练周期长。因此,本发明方法能加速模型在新场景的收敛速度,缩短模型训练周期。
结合图示以及本发明的以上实施例,本发明还可以以装置、计算机可读取介质的方式实施。
本申请一个实施例提供了基于深度学习的模型训练装置100,结合图3。
采集模块110,用于采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
初始化模块120,用于设置所述第一图像数据集中每个图像数据被选取的概率为1/M;
选择模块130,用于从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
训练模块140,用于对所述第k批次数据进行模型训练;
权重分配模块150,用于模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
判断模块160,用于判断所述模型的损失函数计算结果小于等于阈值。
权重分配模块150用于以下述方式权重分配:
使用概率重置函数进行权重分配,N个图像数据的损失函数分别为loss(1),loss(2)....loss(N),其中所述N个图像数据的概率和为S(k),权重分配后每个图像数据被选取概率为:
本申请一个实施例提供了一种计算机可读取介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现包括前述实施例的基于深度学习的模型训练方法的过程,尤其是图1实施例所示的过程。
本申请一个实施例提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述指令由所述处理器加载并执行以实现前述实施例的基于深度学习的模型训练方法。
需要说明的是:上述实施例提供的基于深度学习的模型训练装置在进行基于深度学习的模型训练时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将基于深度学习的模型训练装置的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的基于深度学习的模型训练装置与基于深度学习的模型训练方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上所述并不用以限制本申请实施例,凡在本申请实施例的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请实施例的保护范围之内。
Claims (10)
1.一种基于深度学习的模型训练方法,其特征在于,所述方法包括步骤:
步骤1,采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
步骤2,设置所述第一图像数据集中每个图像数据被选取的概率为1/M;
步骤3,从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
步骤4,对所述第k批次数据进行模型训练;
步骤5,模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
步骤6,将权重分配后的N个图像数据释放到第一图像数据集;
重复步骤3至步骤6,直到所述模型的损失函数计算结果小于等于阈值。
3.根据权利要求2所述的基于深度学习的模型训练方法,其特征在于,所述概率和S(k)的值随k改变而改变,且当k=1时,S(1)=N/M。
4.根据权利要求1所述的基于深度学习的模型训练方法,其特征在于,所述模型的损失函数为所述第k批次数据中的N个图像数据损失函数之和,所述图像数据的损失函数计算结果越大,表示该图像数据与所述模型的相关性越小,损失函数计算结果越小,表示该图像数据与所述模型相关性越大。
5.根据权利要求1所述的基于深度学习的模型训练方法,其特征在于,所述方法该包括:当k=1时,从所述第一图像数据集中随机选择N个图像数据,当k>1时,将所述第一图像数据集中的所有图像数据按照损失函数计算结果从大到小排序,选择前N个图像数据组成第k批次数据。
6.根据权利要求1所述的基于深度学习的模型训练方法,其特征在于,所述模型应用于目标检测,通过目标图像数据的坐标和类型计算所述目标图像数据的损失函数。
7.一种基于深度学习的模型训练装置,其特征在于,所述装置包括:
采集模块,用于采集图像数据,并将所述图像数据与预训练模型进行融合,得到第一图像数据集;
初始化模块,用于设置所述第一图像数据集中每个图像数据被选取的概率为1/M;
选择模块,用于从所述第一图像数据集中选择N个图像数据组成第k批次数据,其中k是大于0的整数,M>N;
训练模块,用于对所述第k批次数据进行模型训练;
权重分配模块,用于模型训练后,若所述模型的损失函数计算结果大于阈值,则对所述第k批次数据中的N个图像数据进行权重分配;
判断模块,用于判断所述模型的损失函数计算结果小于等于阈值。
9.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1至6任一项所述的基于深度学习的模型训练方法。
10.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述指令由所述处理器加载并执行以实现如权利要求1至6任一项所述的基于深度学习的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110555154.7A CN113298135B (zh) | 2021-05-21 | 2021-05-21 | 基于深度学习的模型训练方法、装置、存储介质及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110555154.7A CN113298135B (zh) | 2021-05-21 | 2021-05-21 | 基于深度学习的模型训练方法、装置、存储介质及设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113298135A true CN113298135A (zh) | 2021-08-24 |
CN113298135B CN113298135B (zh) | 2023-04-18 |
Family
ID=77323383
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110555154.7A Active CN113298135B (zh) | 2021-05-21 | 2021-05-21 | 基于深度学习的模型训练方法、装置、存储介质及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113298135B (zh) |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110163234A (zh) * | 2018-10-10 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 一种模型训练方法、装置和存储介质 |
CN111860840A (zh) * | 2020-07-28 | 2020-10-30 | 上海联影医疗科技有限公司 | 深度学习模型训练方法、装置、计算机设备及存储介质 |
CN112215248A (zh) * | 2019-07-11 | 2021-01-12 | 深圳先进技术研究院 | 深度学习模型训练方法、装置、电子设备及存储介质 |
-
2021
- 2021-05-21 CN CN202110555154.7A patent/CN113298135B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110163234A (zh) * | 2018-10-10 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 一种模型训练方法、装置和存储介质 |
CN112215248A (zh) * | 2019-07-11 | 2021-01-12 | 深圳先进技术研究院 | 深度学习模型训练方法、装置、电子设备及存储介质 |
CN111860840A (zh) * | 2020-07-28 | 2020-10-30 | 上海联影医疗科技有限公司 | 深度学习模型训练方法、装置、计算机设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
张延安等: "基于深度卷积神经网络与中心损失的人脸识别", 《科学技术与工程》 * |
Also Published As
Publication number | Publication date |
---|---|
CN113298135B (zh) | 2023-04-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109948478B (zh) | 基于神经网络的大规模非均衡数据的人脸识别方法、系统 | |
CN111242217A (zh) | 图像识别模型的训练方法、装置、电子设备及存储介质 | |
CN111444346B (zh) | 一种用于文本分类的词向量对抗样本生成方法及装置 | |
CN108304328B (zh) | 一种众包测试报告的文本描述生成方法、系统及装置 | |
CN113850300A (zh) | 训练分类模型的方法和装置 | |
CN113902944A (zh) | 模型的训练及场景识别方法、装置、设备及介质 | |
CN111783812A (zh) | 违禁图像识别方法、装置和计算机可读存储介质 | |
CN111310743B (zh) | 人脸识别方法、装置、电子设备及可读存储介质 | |
Zhou et al. | Test-time domain generalization for face anti-spoofing | |
CN113762382B (zh) | 模型的训练及场景识别方法、装置、设备及介质 | |
CN106920255B (zh) | 一种针对图像序列的运动目标提取方法及装置 | |
CN112446428B (zh) | 一种图像数据处理方法及装置 | |
CN116630367B (zh) | 目标跟踪方法、装置、电子设备及存储介质 | |
CN113298135B (zh) | 基于深度学习的模型训练方法、装置、存储介质及设备 | |
CN117113174A (zh) | 一种模型训练的方法、装置、存储介质及电子设备 | |
CN112084936A (zh) | 一种人脸图像预处理方法、装置、设备及存储介质 | |
CN109886185B (zh) | 一种目标识别方法、装置、电子设备和计算机存储介质 | |
CN114066766A (zh) | 图数据处理方法及相关装置、电子设备和存储介质 | |
CN114742170A (zh) | 对抗样本生成方法、模型训练方法、图像识别方法及装置 | |
CN106296568A (zh) | 一种镜头类型的确定方法、装置及客户端 | |
CN110008803B (zh) | 行人检测、训练检测器的方法、装置及设备 | |
CN110751197A (zh) | 图片分类方法、图片模型训练方法及设备 | |
CN113469204A (zh) | 数据处理方法、装置、设备和计算机存储介质 | |
CN114677444B (zh) | 一种优化的视觉slam方法 | |
CN116881175B (zh) | 应用兼容性能评估方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
CB02 | Change of applicant information |
Address after: 210000 Longmian Avenue 568, High-tech Park, Jiangning District, Nanjing City, Jiangsu Province Applicant after: Xiaoshi Technology (Jiangsu) Co.,Ltd. Address before: 210000 Longmian Avenue 568, High-tech Park, Jiangning District, Nanjing City, Jiangsu Province Applicant before: NANJING ZHENSHI INTELLIGENT TECHNOLOGY Co.,Ltd. |
|
CB02 | Change of applicant information | ||
GR01 | Patent grant | ||
GR01 | Patent grant |