CN113609965B - 文字识别模型的训练方法及装置、存储介质、电子设备 - Google Patents
文字识别模型的训练方法及装置、存储介质、电子设备 Download PDFInfo
- Publication number
- CN113609965B CN113609965B CN202110886478.9A CN202110886478A CN113609965B CN 113609965 B CN113609965 B CN 113609965B CN 202110886478 A CN202110886478 A CN 202110886478A CN 113609965 B CN113609965 B CN 113609965B
- Authority
- CN
- China
- Prior art keywords
- model
- label
- training
- loss function
- teacher
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 126
- 238000000034 method Methods 0.000 title claims abstract description 85
- 230000006870 function Effects 0.000 claims abstract description 105
- 238000004821 distillation Methods 0.000 claims abstract description 41
- 239000011159 matrix material Substances 0.000 claims description 64
- 230000004044 response Effects 0.000 claims description 41
- 238000004364 calculation method Methods 0.000 claims description 19
- 230000015654 memory Effects 0.000 claims description 15
- 230000003213 activating effect Effects 0.000 claims description 14
- 238000010276 construction Methods 0.000 claims description 5
- 238000002864 sequence alignment Methods 0.000 claims description 5
- 238000004590 computer program Methods 0.000 claims description 2
- 238000010801 machine learning Methods 0.000 abstract description 2
- 230000006835 compression Effects 0.000 description 16
- 238000007906 compression Methods 0.000 description 16
- 230000008569 process Effects 0.000 description 14
- 238000012545 processing Methods 0.000 description 11
- 230000004913 activation Effects 0.000 description 9
- 238000010586 diagram Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 6
- 238000012795 verification Methods 0.000 description 6
- 238000013140 knowledge distillation Methods 0.000 description 4
- 238000013528 artificial neural network Methods 0.000 description 3
- 230000004069 differentiation Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000010354 integration Effects 0.000 description 3
- 238000004519 manufacturing process Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000002123 temporal effect Effects 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 230000006978 adaptation Effects 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 239000013307 optical fiber Substances 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 230000000306 recurrent effect Effects 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 238000012935 Averaging Methods 0.000 description 1
- 230000009471 action Effects 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000002457 bidirectional effect Effects 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000001944 continuous distillation Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 238000007639 printing Methods 0.000 description 1
- 230000003252 repetitive effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000006403 short-term memory Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Character Discrimination (AREA)
- Image Analysis (AREA)
- Machine Translation (AREA)
Abstract
本公开是关于一种文字识别模型的训练方法及装置、存储介质、电子设备,涉及机器学习技术领域,该方法包括:根据历史图像以及历史图像的真实文字标签,构建数据集,并将数据集中的历史图像输入至训练完成的教师模型中,得到历史图像的软目标标签;将数据集中的历史图像输入至与教师模型具有相同输出层的学生模型中,得到历史图像的软预测标签以及实际预测标签;根据软目标标签以及软预测标签构建第一损失函数,并根据真实文字标签以及实际预测标签构建第二损失函数;根据第一损失函数以及第二损失函数对学生模型进行蒸馏训练,得到训练完成的文字识别模型。本公开降低了模型的计算量。
Description
技术领域
本公开实施例涉及机器学习技术领域,具体而言,涉及一种文字识别模型的训练方法、文字识别模型的训练装置、计算机可读存储介质以及电子设备。
背景技术
文字识别技术已经较为成熟,在信件分拣、票据识别、证件识别、自动驾驶、内容安全审核等领域中取得非常显著的应用效果。
目前,较为成熟的文字识别技术,大多采用深度学习的方法,特别是以联结时序分类(Connectionist Temporal Classification,CTC)方法构造的CRNN文字识别模型,已经在工业界成熟落地。为了取得较好的效果,在设计文字识别模型时,一般选用大的骨架网络(Backbone)和复杂的序列模型,用于提取图片中的文字信息、提升文字识别模型在各种复杂场景中的鲁棒性。
但是,使用大的骨架网络和复杂的序列模型,直接导致模型的计算量大幅增加,不利于文字识别模型部署在嵌入式设备上或一些对实时性要求比较高的场景(如自动驾驶),在一定程度上限制了文字识别模型的应用与推广。
因此,如何在保证模型精度的基础上,对模型进行压缩以降低模型的计算量,已经成为文字识别技术应用在生产环境中急需解决的难题。
需要说明的是,在上述背景技术部分发明的信息仅用于加强对本公开的背景的理解,因此可以包括不构成对本领域普通技术人员已知的现有技术的信息。
发明内容
本公开的目的在于提供一种文字识别模型的训练方法、文字识别模型的训练装置、计算机可读存储介质以及电子设备,进而至少在一定程度上克服由于相关技术的限制和缺陷而导致的无法对模型进行压缩进而使得模型的计算量较大的问题。
根据本公开的一个方面,提供一种文字识别模型的训练方法,包括:
根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签;
根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数;
根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
在本公开的一种示例性实施例中,将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签,包括:
将所述数据集中的历史图像输入至训练完成的教师模型中,得到第一响应特征矩阵;
通过所述训练完成的教师模型中包括的带有温度系数的第一Softmax-T层对所述第一响应特征矩阵进行激活处理,得到第一字符后验概率矩阵。
在本公开的一种示例性实施例中,所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签,包括:
将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到第二响应特征矩阵;
通过所述学生模型中包括的带有温度系数的第二Softmax-T层对所述第二响应特征矩阵进行激活处理,第二字符后验概率矩阵;
通过所述学生模型中包括的第三Softmax层对所述第二响应特征矩阵进行激活处理,得到所述第三字符后验概率矩阵。
在本公开的一种示例性实施例中,根据所述软目标标签以及软预测标签构建第一损失函数,包括:
对所述软目标标签以及软预测标签进行时间序列对齐;
计算对齐后的软目标标签与所述软预测标签之间的交叉熵,并根据所述交叉熵构建所述第一损失函数。
在本公开的一种示例性实施例中,根据所述真实文字标签以及实际预测标签构建第二损失函数,包括:
计算所述实际预测标签中所包括的第三字符后验概率矩阵中,构成所述真实文字标签的所有路径的路径集合;
计算所述路径集合中每一条路径中所包括的每一个元素属于所述真实文字标签中对应位置上的元素的概率的乘积;
对所述路径集合中所包括的所有路径的概率的乘积进行求和,并以求和结果为最大值为目标,构建所述第二损失函数。
在本公开的一种示例性实施例中,根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型,包括:
对所述第一损失函数以及第二损失函数进行加权平均,得到目标损失函数,并利用所述目标损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
在本公开的一种示例性实施例中,所述文字识别模型的训练方法还包括:
利用所述数据集对教师模型进行训练,得到所述训练完成的教师模型;
其中,所述教师模型包括单模式教师模型或多模型教师模型,当所述教师模型为多模式教师模型时,所述教师模型的模式包括异域同构模式、同域异构模型以及异域异构模式中的至少一种。
在本公开的一种示例性实施例中,当所述教师模型的模式为多模式教师模型时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:
利用所述数据集中包括的多个不同字体类别的历史图像对一种结构的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的同一种字体类别的历史图像对多种不同结果的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的多个不同字体类别的历史图像,对多种不同结果的教师模型进行训练,得到训练完成的教师模型。
根据本公开的一个方面,提供一种文字识别模型的训练装置,包括:
第一标签计算模块,用于根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
第二标签计算模块,用于将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签;
损失函数构建模块,用于根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数;
蒸馏训练模块,用于根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
根据本公开的一个方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任意一项所述的文字识别模型的训练方法。
根据本公开的一个方面,提供一种电子设备,包括:
处理器;以及
存储器,用于存储所述处理器的可执行指令;
其中,所述处理器配置为经由执行所述可执行指令来执行上述任意一项所述的文字识别模型的训练方法。
本公开实施例提供的一种文字识别模型的训练方法,一方面,由于可以通过软目标标签以及软预测标签构建第一损失函数,并根据真实文字标签以及实际预测标签构建第二损失函数;根据第一损失函数以及第二损失函数对学生模型进行蒸馏训练,得到训练完成的文字识别模型,进而在保证了学生模型的精度(可以通过真实文字标签以及实际预测标签构建的第二损失函数确保学生模型的精度)的基础上,实现了对学生模型的蒸馏压缩(通过第一损失函数实现蒸馏压缩),进而实现了在保证模型精度的基础上,对模型进行压缩以降低模型的计算量;另一方面,由于可以通过第一损失函数以及第二损失函数对学生模型进行蒸馏训练,得到训练完成的文字识别模型,进而可以将训练完成的文字识别模型部署至嵌入式设备上或一些对实时性要求比较高的场景,扩大了文字识别模型的应用场景。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示意性示出根据本公开示例实施例的一种文字识别模型的训练方法的流程图。
图2示意性示出根据本公开示例实施例的一种文字识别模型的训练方法的结构流程图。
图3示意性示出根据本公开示例实施例的一种异域同构模式的教师模型的训练方法的结构流程图。
图4示意性示出根据本公开示例实施例的一种同域异构模式的教师模型的训练方法的结构流程图。
图5示意性示出根据本公开示例实施例的一种异域异构模式的教师模型的训练方法的结构流程图。
图6示意性示出根据本公开示例实施例的一种将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签的方法流程图。
图7示意性示出根据本公开示例实施例的一种文字识别模型的蒸馏训练的结构框架图。
图8示意性示出根据本公开示例实施例的一种将数据集中的历史图像输入至与教师模型具有相同输出层的学生模型中,得到历史图像的软预测标签以及实际预测标签的方法流程图。
图9示意性示出根据本公开示例实施例的一种文字识别模型的训练装置的框图。
图10示意性示出根据本公开示例实施例的一种用于实现上述文字识别模型的训练方法的电子设备。
具体实施方式
现在将参考附图更全面地描述示例实施方式。然而,示例实施方式能够以多种形式实施,且不应被理解为限于在此阐述的范例;相反,提供这些实施方式使得本公开将更加全面和完整,并将示例实施方式的构思全面地传达给本领域的技术人员。所描述的特征、结构或特性可以以任何合适的方式结合在一个或更多实施方式中。在下面的描述中,提供许多具体细节从而给出对本公开的实施方式的充分理解。然而,本领域技术人员将意识到,可以实践本公开的技术方案而省略所述特定细节中的一个或更多,或者可以采用其它的方法、组元、装置、步骤等。在其它情况下,不详细示出或描述公知技术方案以避免喧宾夺主而使得本公开的各方面变得模糊。
此外,附图仅为本公开的示意性图解,并非一定是按比例绘制。图中相同的附图标记表示相同或类似的部分,因而将省略对它们的重复描述。附图中所示的一些方框图是功能实体,不一定必须与物理或逻辑上独立的实体相对应。可以采用软件形式来实现这些功能实体,或在一个或多个硬件模块或集成电路中实现这些功能实体,或在不同网络和/或处理器装置和/或微控制器装置中实现这些功能实体。
本示例实施方式中首先提供了一种文字识别模型的训练方法,该方法可以运行于服务器、服务器集群或云服务器等;当然,本领域技术人员也可以根据需求在其他平台运行本公开的方法,本示例性实施例中对此不做特殊限定。参考图1所示,该文字识别模型的训练方法可以包括以下步骤:
步骤S110.根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
步骤S120.将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签;
步骤S130.根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数;
步骤S140.根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
上述文字识别模型的训练方法中,一方面,由于可以通过软目标标签以及软预测标签构建第一损失函数,并根据真实文字标签以及实际预测标签构建第二损失函数;根据第一损失函数以及第二损失函数对学生模型进行蒸馏训练,得到训练完成的文字识别模型,进而在保证了学生模型的精度(可以通过真实文字标签以及实际预测标签构建的第二损失函数确保学生模型的精度)的基础上,实现了对学生模型的蒸馏压缩(通过第一损失函数实现蒸馏压缩),进而实现了在保证模型精度的基础上,对模型进行压缩以降低模型的计算量;另一方面,由于可以通过第一损失函数以及第二损失函数对学生模型进行蒸馏训练,得到训练完成的文字识别模型,进而可以将训练完成的文字识别模型部署至嵌入式设备上或一些对实时性要求比较高的场景,扩大了文字识别模型的应用场景。
以下,将结合附图对本公开示例实施例文字识别模型的训练方法进行详细的解释以及说明。
首先,对本公开示例实施例的发明目的进行解释以及说明。具体的,本公开提出了一种基于特征响应的知识蒸馏方法(文字识别模型的训练方法),可以对文字识别模型进行压缩,该方法适用于以CTC(Connectionist Temporal Classification,联结时序分类)建模为主的文字识别模型。本公开可以在最小精度损失的情况下,最大程度对模型进行压缩,灵活调整目标网络的计算量,轻松实现文字识别速度和精度的平衡。
其次,对本公开示例实施例的具体流程进行解释以及说明。本公开所提供的文字识别模型的训练方法,可以包括两个阶段。具体的,参考图2所示,第一阶段为:选择或训练高精度的文字识别模型。这个模型称为教师模型。教师模型的特点是文字识别精度高、泛化能力强,设计时一般会使用大型的特征提取网络和序列特征对齐网络组合,如深层的卷积神经网络和深层的双向LSTM(Long Short-Term Memory,长短期记忆网络)网络组合,也可以是其它对文字识别精度高的神经网络;并且,使用数据训练一个或多个教师模型,也可以选择开源的预训练模型,这个模型将提供给第二阶段使用。
第二个阶段为:根据业务场景的要求,设计对应的压缩模型(目标模型,也即学生模型),这个模型的特点是模型体量小,符合生产部署的计算量要求。在蒸馏训练时,将目标模型设置为学生模型,使用训练好的教师模型生成软目标标签(Soft target)对学生模型进行蒸馏学习;软目标标签可以使用一个教师模型生成,也可以使用多个教师模型通过集成方法产生,模型蒸馏训练结束后,输出目标模型(训练完成的文字识别模型),完成对文字识别模型的压缩。
具体的,在图2中,可以通过数据集210对一个或者多个教师模型(CRNN(Convolutional Recurrent Neural Network,卷积递归神经网络)和/或SwinTransformer和/或其他模型)220进行训练,进而得到教师模型1、教师模型2、…、教师模型N,然后基于训练完成的教师模型1、教师模型2、…、教师模型N,对学生模型230进行蒸馏压缩,进而得到训练完成的文字识别模型240,也即目标模型。
以下,结合图2对本公开示例实施例的文字识别模型的训练方法中所包括的各步骤进行详细的解释以及说明。
在本公开示例实施例提供的一种文字识别模型的训练方法中:
在步骤S110中,根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签。
在本示例实施例中,首先,需要根据历史图像以及历史图像的真实文字标签构建数据集,其中,该历史图像可以包括具有多个不同字体类别的历史图像,字体类别可以包括印刷体、手写体、仿宋体等等,本示例对此不做特殊限制。具体的,在构建数据集的过程中,可以根据各字体类别的不同,分别构建不同的子数据集,进而根据各子数据集,构建上述数据集。其中,子数据集可以包括:印刷体数据集、手写体数据集以及仿宋体数据集等等,当然也可以包括其他子数据集,本示例对此不做特殊限制。
其次,当数据集构建完成后,为了可以通过训练完成的教师模型得到历史图像的软目标标签,首先需要对教师模型进行训练;同时,由于本公开示例实施例所使用的知识蒸馏方法是基于模型输出响应的方法,因而,教师模型的选择和设计时,只需要模型输出层和目标模型的输出一致即可;当然,教师模型可以选择开源的预训练模型,也可以通过自行设计训练生成,在使用多个教师模型集成时,各个教师模型之间应具有较大的差异。本公开示例实施例设计了三种方法训练差异化的教师模型组合,分别是异域同构、同域异构、异域异构。
具体的,该文字识别模型的训练方法还可以包括:利用所述数据集对教师模型进行训练,得到所述训练完成的教师模型;其中,所述教师模型包括单模式教师模型或者多模式教师模型当所述教师模型为多模式教师模型时,所述教师模型的模式包括异域同构模式、同域异构模型以及异域异构模式等等。
一方面,当所述教师模型的模式为异域同构模式时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:利用所述数据集中包括的多个不同字体类别的历史图像对一种结构的教师模型进行训练,得到训练完成的教师模型。
举例来说,参考图3所示,当对异域同构模式的教师模型进行训练时,可以通过多个不同领域(不同字体类别,例如印刷体、手写体)的数据集,分别对同一个结构的文字识别模型进行训练,最后得到一组基于数据差异化的模型实例组合。模型的训练过程可以包括:选择一种结构的文字识别模型(如CRNN),根据具体的模型结构准备多个不同领域的数据集,并将其划分成训练集和验证集;使用不同域的训练集数据作为模型的输入,分别对同一个结构的模型进行训练,当相应的验证集测试指标达到预期设置效果时,停止模型训练,输出教师模型。
另一方面,当所述教师模型的模型为同域异构模式时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:利用所述数据集中包括的同一种字体类别的历史图像对多种不同结果的教师模型进行训练,得到训练完成的教师模型。
举例来说,参考图4所示,当对同域异构模式的教师模型进行训练时,可以使用同一个数据集(同一个字体类别,例如印刷体构成的数据集),分别训练不同结构的文字识别模型,最后得到一组基于模型结构差异化的模型实例组合。模型的训练过程可以包括:准备训练模型使用的数据集,并将其划分成训练集和验证集;选择多个不同结构的文字识别模型并将训练集数据作为各个模型的输入,对模型进行训练,当相应的验证集测试指标达到预期设置效果时,停止模型训练,输出教师模型。
再一方面,当所述教师模型的模型为异域异构模式时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:利用所述数据集中包括的多个不同字体类别的历史图像,对多种不同结果的教师模型进行训练,得到训练完成的教师模型。
举例来说,参考图5所示,当对异域异构模式的教师模型进行训练时,异域异构方法:如图4所示,可以使用不同的数据集,分别训练一个特定结构的文字识别模型,最后得到一组基于数据和模型结构差异化的模型实例组合;例如,通过印刷体训练CRNN,通过手写体训练Swin Transformer等等。具体的模型训练过程可以包括:选择多个不同结构的文字识别模型,每个结构的模型匹配一个单独的数据集,并将数据集划分为训练集和验证集;使用各自模型匹配的训练集数据作为模型的输入,分别对各个结构的模型进行训练;当相应的验证集测试指标达到预期设置效果时,停止模型训练,输出教师模型。此处需要补充说明的是,在为模型选择相应字体类别的数据集时,可以根据实际需要自行匹配,此处仅仅起到示例性作用,并不做其他限制;同时,通过设计和使用多种模式的教师模型实例的训练方法,可以有效提高教师模型的知识输出能力,进而起到提升压缩模型的精度和泛化能力的作用。
进一步的,当得到训练完成的教师模型以后,即可将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签。具体的,参考图6所示,可以包括以下步骤:
步骤S610,将所述数据集中的历史图像输入至训练完成的教师模型中,得到第一响应特征矩阵;
步骤S620,通过所述训练完成的教师模型中包括的带有温度系数的第一Softmax-T层对所述第一响应特征矩阵进行激活处理,得到第一字符后验概率矩阵。
以下,结合图7对步骤S610-步骤S620进行解释以及说明。具体的,如图7所示的文字识别模型蒸馏学习的框架,学生模型使用的标签信息来自于训练完成的教师模型(教师模型1、教师模型2、…、教师模型N)生成的软目标标签和真实文字标签的信息;同时,教师模型的输出响应使用带温度系数的第一Softmax-T进行激活,得到上述第一字符后验概率矩阵,也即软目标标签。具体的:
首先,将所述数据集中的历史图像输入至训练完成的教师模型中,得到第一响应特征矩阵,其中,第一响应特征矩阵X可以如下公式(1)所示:
其中,m表示文字识别模型(训练完成的教师模型)的字符输出长度(时间轴),n表示模型识别的文字类别数。
其次,第一响应特征矩阵X经过第一softmax-T激活处理后,得到第一字符后验概率矩阵Q:
其中,具体的softmax-T激活计算方式,参数T是蒸馏续联过程设置的温度系数,用户可以根据具体任务进行设置,参考范围为(0,+∞)。同时,为了保证模型能够顺利收敛,学生模型和教师模型使用的温度系数参数应保持相同,具体的激活处理可以参考如下公式(3)所示。
进一步的,蒸馏训练时,训练完成的教师模型主要用于计算输入图片(历史图片)产生的软目标标签,在使用单个教师模型时,可以直接使用模型的后验概率矩阵输出,当使用多个教师模型时,可以通过集成各个教师模型的第一字符后验概率矩阵Q(1),Q(2),...,Q(N)产生,具体可以如下公式(4)所示:
QE=f(Q(1),Q(2),...,Q(N)); 公式(4)
其中,集成函数f可以是求平均值的方式,也可以使用其他方法集成,例如为各第一字符后验概率矩阵Q(1),Q(2),...,Q(N)配置不同的权重,进而使得各权重之和为1等等,本示例对此不做特殊限制。具体的,当通过求平均值的方式进行计算时,具体的计算方法可以如下公式(5)所示:
在步骤S120中,将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签。
在本示例实施例中,参考图8所示,将数据集中的历史图像输入至与教师模型具有相同输出层的学生模型中,得到历史图像的软预测标签以及实际预测标签,可以包括以下步骤:
步骤S810,将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到第二响应特征矩阵;
步骤S820,通过所述学生模型中包括的带有温度系数的第二Softmax-T层对所述第二响应特征矩阵进行激活处理,得到第二字符后验概率矩阵;
步骤S830,通过所述学生模型中包括的第三Softmax层对所述第二响应特征矩阵进行激活处理,得到第三字符后验概率矩阵。
以下,结合图7对步骤S810-步骤S830进行解释以及说明。具体的,将历史图像输入至学生模型中,得到第二响应特征矩阵;其中,第二响应特征矩阵通过不同的激活方式产生两个分支,一部分是经过第二Softmax-T激活产生的soft-probability(软预测标签);另一个分支经过第三Softmax激活,得到原始概率输出probability(实际预测标签)。
其中,第二Softmax-T层对第二响应特征矩阵进行激活处理具体的处理过程与前述第一Softmax-T层对第一响应特征矩阵进行激活处理过程类似,此处不再赘述;
同时,第二响应特征矩阵X经过第三softmax激活后,得到第三字符后验概率矩阵P,进而将该第三字符后验概率矩阵作为实际预测标签。其中,第三字符后验概率矩阵具体可以如下公式(6)所示:
其中,具体的第三softmax的具体计算过程如下公式(7)所示:
在步骤S130中,根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数。
在本示例实施例中,首先,根据所述软目标标签以及软预测标签构建第一损失函数,具体的可以包括:对所述软目标标签以及软预测标签进行时间序列对齐;计算对齐后的软目标标签与所述软预测标签之间的交叉熵,并根据所述交叉熵构建所述第一损失函数。
其次,根据所述真实文字标签以及实际预测标签构建第二损失函数,具体的可以包括:计算所述实际预测标签中所包括的第三字符后验概率矩阵中,构成所述真实文字标签的所有路径的路径集合;计算所述路径集合中每一条路径中所包括的每一个元素属于所述真实文字标签中对应位置上的元素的概率的乘积;对所述路径集合中所包括的所有路径的概率的乘积进行求和,并以求和结果为最大值为目标,构建所述第二损失函数。
以下,继续结合图7对第一损失函数以及第二损失函数的具体构建过程进行解释以及说明。具体的,参考图7所示,由于学生模型所输出的第二响应特征矩阵通过不同的激活方式产生两个分支,一个分支是:Softmax-T激活产生的soft-probability(软预测标签),这部分将和教师模型产生的soft-target进行交叉熵计算得到蒸馏损失(KnowledgeDistillation Loss,KD Loss),也即第一损失函数,另一个分支经过Softmax激活,得到原始概率输出probability(实际预测标签),这部分和真实文字标签label计算得到CTC损失(CTC Loss),也即第二损失函数。通过最小化蒸馏损失与CTC损失之和,完成对学生模型的蒸馏训练,输出压缩后的文字识别模型(学生模型)。
进一步的,第一损失函数KD Loss为教师模型集成的软目标标签QE与学生模型的软预测标签QS之间的交叉熵,由于文字识别模型的输出是一个序列,因而,每个时间步位置预测的结果需要进行严格对齐,进而根据对齐后的软目标标签以及软预测标签构建该第一损失函数。其中,第一损失函数LKD具体可以如下公式(8)所示:
其中,m为字符输出长度,n为文字的类别数,为第j个字符属于第i个类别的软目标标签,/>为第j个字符属于第i个类别的软预测标签。
同时,对于第二损失函数CTC Loss来说,对于给定的实际预测标签序列向量X和真实文字标签序列Y,其中B-1是Y全部路径的集合,也即实际预测标签中所包括的第三字符后验概率矩阵中,构成真实文字标签的所有路径的路径集合,π是其中的一个路径。具体的,第二损失函数LCTC的具体计算方法可以如下公式(9)所示:
其中,子路径π的概率值由如下公式(10)计算得到:
此处需要补充说明的是,通过构建CTC损失函数,并通过CTC函数对学生模型进行蒸馏训练,进而避免了现有技术中由于未采用CTC函数对学生模型进行蒸馏训练导致的序列特征经过CTC解码后,会得到一个不固定长度的离散序列,无法直接将传统的基于末端响应的蒸馏学习方法,应用于文字识别模型压缩方案中,在一定程度上限制了文字识别模型的压缩途径和方法应用的问题。
在步骤S140中,根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
在本示例实施例中,可以对所述第一损失函数以及第二损失函数进行加权平均,得到目标损失函数,并利用所述目标损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
具体的,模型最终的目标损失函数Loss是对KD Loss和CTC Loss的加权平均,具体可以如下公式(11)所示。其中,参数λ∈(0,1)为KD Loss的权重,标识教师模型对学生模型影响的重要程度,可以根据实际情况进行灵活调整。
L=λLKD+(1-λ)LCTC。 公式(11)
至此,已经完成了文字识别模型的蒸馏训练。在实际应用过程中,可以将训练完成的文字识别模型部署至嵌入式设备上,或者一些对实时性要求较高的场景,例如自动驾驶场景等等。
基于上述记载的内容可以得知,本公开示例实施例所提供的文字识别模型的压缩方法,至少具有以下优点:
一方面,本公开所提出的文字识别模型的压缩方法,解决了现有技术中由于文字识别模型的输出响应是一个多维的序列特征,相邻节点的预测结果相互依赖,且序列特征经过CTC解码后,会得到一个不固定长度的离散序列,因此无法直接将传统的基于末端响应的蒸馏学习方法,应用于文字识别模型压缩方案中,在一定程度上限制了文字识别模型的压缩途径和方法应用的问题;另一方面,还解决了传统基于末端响应的蒸馏学习方法,无法直接应用于文字识别模型压缩的问题;
再一方面,本公开所提供的基于特征响应的蒸馏学习方法和框架,对CTC解码前的输出特征设置教师监督信息,使得文字识别模型能够像普通的分类模型一样,能够进行蒸馏学习,进而达到对文字识别模型压缩的目的;
进一步的,通过设计基于特征响应的文字识别模型知识蒸馏框架,实现对文字识别模型的蒸馏压缩,该框架可以灵活的调整文字识别模型的精度和计算量,提高了文字识别模型在生产环境中的部署效率;并且,设计和使用多种模式的教师模型实例的训练方法,可以有效提高教师模型的知识输出能力,间接提升压缩模型的精度和泛化能力。
本公开还提供了一种文字识别模型的训练装置。参考图9所示,该文字识别模型的训练装置可以包括第一标签计算模块910、第二标签计算模块920、损失函数构建模块930以及蒸馏训练模块940。其中:
第一标签计算模块910可以用于根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
第二标签计算模块920可以用于将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签;
损失函数构建模块930可以用于根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数;
蒸馏训练模块940可以用于根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
在本公开的一种示例性实施例中,将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签,包括:
将所述数据集中的历史图像输入至训练完成的教师模型中,得到第一响应特征矩阵;
通过所述训练完成的教师模型中包括的带有温度系数的第一Softmax-T层对所述第一响应特征矩阵进行激活处理,得到第一字符后验概率矩阵。
在本公开的一种示例性实施例中,所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签,包括:
将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到第二响应特征矩阵;
通过所述学生模型中包括的带有温度系数的第二Softmax-T层对所述第二响应特征矩阵进行激活处理,第二字符后验概率矩阵;
通过所述学生模型中包括的第三Softmax层对所述第二响应特征矩阵进行激活处理,得到所述第三字符后验概率矩阵。
在本公开的一种示例性实施例中,根据所述软目标标签以及软预测标签构建第一损失函数,包括:
对所述软目标标签以及软预测标签进行时间序列对齐;
计算对齐后的软目标标签与所述软预测标签之间的交叉熵,并根据所述交叉熵构建所述第一损失函数。
在本公开的一种示例性实施例中,根据所述真实文字标签以及实际预测标签构建第二损失函数,包括:
计算所述实际预测标签中所包括的第三字符后验概率矩阵中,构成所述真实文字标签的所有路径的路径集合;
计算所述路径集合中每一条路径中所包括的每一个元素属于所述真实文字标签中对应位置上的元素的概率的乘积;
对所述路径集合中所包括的所有路径的概率的乘积进行求和,并以求和结果为最大值为目标,构建所述第二损失函数。
在本公开的一种示例性实施例中,根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型,包括:
对所述第一损失函数以及第二损失函数进行加权平均,得到目标损失函数,并利用所述目标损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
在本公开的一种示例性实施例中,所述文字识别模型的训练装置还可以包括:
教师模型训练模块,可以用于利用所述数据集对教师模型进行训练,得到所述训练完成的教师模型;
其中,所述教师模型包括单模式教师模型或多模型教师模型,当所述教师模型为多模式教师模型时,所述教师模型的模式包括异域同构模式、同域异构模型以及异域异构模式中的至少一种。
在本公开的一种示例性实施例中,当所述教师模型的模式为多模式教师模型时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:
利用所述数据集中包括的多个不同字体类别的历史图像对一种结构的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的同一种字体类别的历史图像对多种不同结果的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的多个不同字体类别的历史图像,对多种不同结果的教师模型进行训练,得到训练完成的教师模型。
上述文字识别模型的训练装置中各模块的具体细节已经在对应的文字识别模型的训练方法中进行了详细的描述,因此此处不再赘述。
应当注意,尽管在上文详细描述中提及了用于动作执行的设备的若干模块或者单元,但是这种划分并非强制性的。实际上,根据本公开的实施方式,上文描述的两个或更多模块或者单元的特征和功能可以在一个模块或者单元中具体化。反之,上文描述的一个模块或者单元的特征和功能可以进一步划分为由多个模块或者单元来具体化。
此外,尽管在附图中以特定顺序描述了本公开中方法的各个步骤,但是,这并非要求或者暗示必须按照该特定顺序来执行这些步骤,或是必须执行全部所示的步骤才能实现期望的结果。附加的或备选的,可以省略某些步骤,将多个步骤合并为一个步骤执行,以及/或者将一个步骤分解为多个步骤执行等。
在本公开的示例性实施例中,还提供了一种能够实现上述方法的电子设备。
所属技术领域的技术人员能够理解,本公开的各个方面可以实现为系统、方法或程序产品。因此,本公开的各个方面可以具体实现为以下形式,即:完全的硬件实施方式、完全的软件实施方式(包括固件、微代码等),或硬件和软件方面结合的实施方式,这里可以统称为“电路”、“模块”或“系统”。
下面参照图10来描述根据本公开的这种实施方式的电子设备1000。图10显示的电子设备1000仅仅是一个示例,不应对本公开实施例的功能和使用范围带来任何限制。
如图10所示,电子设备1000以通用计算设备的形式表现。电子设备1000的组件可以包括但不限于:上述至少一个处理单元1010、上述至少一个存储单元1020、连接不同系统组件(包括存储单元1020和处理单元1010)的总线1030以及显示单元1040。
其中,所述存储单元存储有程序代码,所述程序代码可以被所述处理单元1010执行,使得所述处理单元1010执行本说明书上述“示例性方法”部分中描述的根据本公开各种示例性实施方式的步骤。例如,所述处理单元1010可以执行如图1中所示的步骤S110:根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;步骤S120:将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到所述历史图像的软预测标签以及实际预测标签;步骤S130:根据所述软目标标签以及软预测标签构建第一损失函数,并根据所述真实文字标签以及实际预测标签构建第二损失函数;步骤S140:根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
存储单元1020可以包括易失性存储单元形式的可读介质,例如随机存取存储单元(RAM)10201和/或高速缓存存储单元10202,还可以进一步包括只读存储单元(ROM)10203。
存储单元1020还可以包括具有一组(至少一个)程序模块10205的程序/实用工具10204,这样的程序模块10205包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
总线1030可以为表示几类总线结构中的一种或多种,包括存储单元总线或者存储单元控制器、外围总线、图形加速端口、处理单元或者使用多种总线结构中的任意总线结构的局域总线。
电子设备1000也可以与一个或多个外部设备1100(例如键盘、指向设备、蓝牙设备等)通信,还可与一个或者多个使得用户能与该电子设备1000交互的设备通信,和/或与使得该电子设备1000能与一个或多个其它计算设备进行通信的任何设备(例如路由器、调制解调器等等)通信。这种通信可以通过输入/输出(I/O)接口1050进行。并且,电子设备1000还可以通过网络适配器1060与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器1060通过总线1030与电子设备1000的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备1000使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。
通过以上的实施方式的描述,本领域的技术人员易于理解,这里描述的示例实施方式可以通过软件实现,也可以通过软件结合必要的硬件的方式来实现。因此,根据本公开实施方式的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中或网络上,包括若干指令以使得一台计算设备(可以是个人计算机、服务器、终端装置、或者网络设备等)执行根据本公开实施方式的方法。
在本公开的示例性实施例中,还提供了一种计算机可读存储介质,其上存储有能够实现本说明书上述方法的程序产品。在一些可能的实施方式中,本公开的各个方面还可以实现为一种程序产品的形式,其包括程序代码,当所述程序产品在终端设备上运行时,所述程序代码用于使所述终端设备执行本说明书上述“示例性方法”部分中描述的根据本公开各种示例性实施方式的步骤。
根据本公开的实施方式的用于实现上述方法的程序产品,其可以采用便携式紧凑盘只读存储器(CD-ROM)并包括程序代码,并可以在终端设备,例如个人电脑上运行。然而,本公开的程序产品不限于此,在本文件中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
所述程序产品可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以为但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了可读程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。可读信号介质还可以是可读存储介质以外的任何可读介质,该可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。
可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于无线、有线、光缆、RF等等,或者上述的任意合适的组合。
可以以一种或多种程序设计语言的任意组合来编写用于执行本公开操作的程序代码,所述程序设计语言包括面向对象的程序设计语言—诸如Java、C++等,还包括常规的过程式程序设计语言—诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。
此外,上述附图仅是根据本公开示例性实施例的方法所包括的处理的示意性说明,而不是限制目的。易于理解,上述附图所示的处理并不表明或限制这些处理的时间顺序。另外,也易于理解,这些处理可以是例如在多个模块中同步或异步执行的。
本领域技术人员在考虑说明书及实践这里发明的发明后,将容易想到本公开的其他实施例。本申请旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未发明的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的真正范围和精神由权利要求指出。
Claims (8)
1.一种文字识别模型的训练方法,其特征在于,包括:
根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到第二响应特征矩阵;通过所述学生模型中包括的带有温度系数的第二Softmax-T层对所述第二响应特征矩阵进行激活处理,得到第二字符后验概率矩阵;其中,所述第二字符后验概率矩阵为软预测标签;通过所述学生模型中包括的第三Softmax层对所述第二响应特征矩阵进行激活处理,得到第三字符后验概率矩阵;其中,所述第三字符后验概率矩阵为实际预测标签;
对所述软目标标签以及软预测标签进行时间序列对齐;计算对齐后的软目标标签与所述软预测标签之间的交叉熵,并根据所述交叉熵构建第一损失函数,并计算所述实际预测标签中所包括的第三字符后验概率矩阵中,构成所述真实文字标签的所有路径的路径集合;计算所述路径集合中每一条路径中所包括的每一个元素属于所述真实文字标签中对应位置上的元素的概率的乘积;对所述路径集合中所包括的所有路径的概率的乘积进行求和,并以求和结果为最大值为目标,构建第二损失函数;
根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
2.根据权利要求1所述的文字识别模型的训练方法,其特征在于,将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签,包括:
将所述数据集中的历史图像输入至训练完成的教师模型中,得到第一响应特征矩阵;
通过所述训练完成的教师模型中包括的带有温度系数的第一Softmax-T层对所述第一响应特征矩阵进行激活处理,得到第一字符后验概率矩阵。
3.根据权利要求1所述的文字识别模型的训练方法,其特征在于,根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型,包括:
对所述第一损失函数以及第二损失函数进行加权平均,得到目标损失函数,并利用所述目标损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
4.根据权利要求2所述的文字识别模型的训练方法,其特征在于,所述文字识别模型的训练方法还包括:
利用所述数据集对教师模型进行训练,得到所述训练完成的教师模型;
其中,所述教师模型包括单模式教师模型或多模型教师模型,当所述教师模型为多模式教师模型时,所述教师模型的模式包括异域同构模式、同域异构模型以及异域异构模式中的至少一种。
5.根据权利要求4所述的文字识别模型的训练方法,其特征在于,当所述教师模型的模式为多模式教师模型时,利用所述数据集对教师模型进行训练,得到训练完成的教师模型,包括:
利用所述数据集中包括的多个不同字体类别的历史图像对一种结构的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的同一种字体类别的历史图像对多种不同结果的教师模型进行训练,得到训练完成的教师模型;或者
利用所述数据集中包括的多个不同字体类别的历史图像,对多种不同结果的教师模型进行训练,得到训练完成的教师模型。
6.一种文字识别模型的训练装置,其特征在于,包括:
第一标签计算模块,用于根据历史图像以及所述历史图像的真实文字标签,构建数据集,并将所述数据集中的历史图像输入至训练完成的教师模型中,得到所述历史图像的软目标标签;
第二标签计算模块,用于将所述数据集中的历史图像输入至与所述教师模型具有相同输出层的学生模型中,得到第二响应特征矩阵;通过所述学生模型中包括的带有温度系数的第二Softmax-T层对所述第二响应特征矩阵进行激活处理,得到第二字符后验概率矩阵;其中,所述第二字符后验概率矩阵为软预测标签;通过所述学生模型中包括的第三Softmax层对所述第二响应特征矩阵进行激活处理,得到第三字符后验概率矩阵;其中,所述第三字符后验概率矩阵为实际预测标签;
损失函数构建模块,用于对所述软目标标签以及软预测标签进行时间序列对齐;计算对齐后的软目标标签与所述软预测标签之间的交叉熵,并根据所述交叉熵构建第一损失函数,并计算所述实际预测标签中所包括的第三字符后验概率矩阵中,构成所述真实文字标签的所有路径的路径集合;计算所述路径集合中每一条路径中所包括的每一个元素属于所述真实文字标签中对应位置上的元素的概率的乘积;对所述路径集合中所包括的所有路径的概率的乘积进行求和,并以求和结果为最大值为目标,构建第二损失函数;
蒸馏训练模块,用于根据所述第一损失函数以及第二损失函数对所述学生模型进行蒸馏训练,得到训练完成的文字识别模型。
7.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-5任一项所述的文字识别模型的训练方法。
8.一种电子设备,其特征在于,包括:
处理器;以及
存储器,用于存储所述处理器的可执行指令;
其中,所述处理器配置为经由执行所述可执行指令来执行权利要求1-5任一项所述的文字识别模型的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110886478.9A CN113609965B (zh) | 2021-08-03 | 2021-08-03 | 文字识别模型的训练方法及装置、存储介质、电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110886478.9A CN113609965B (zh) | 2021-08-03 | 2021-08-03 | 文字识别模型的训练方法及装置、存储介质、电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113609965A CN113609965A (zh) | 2021-11-05 |
CN113609965B true CN113609965B (zh) | 2024-02-13 |
Family
ID=78339303
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110886478.9A Active CN113609965B (zh) | 2021-08-03 | 2021-08-03 | 文字识别模型的训练方法及装置、存储介质、电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113609965B (zh) |
Families Citing this family (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114186097A (zh) * | 2021-12-10 | 2022-03-15 | 北京百度网讯科技有限公司 | 用于训练模型的方法和装置 |
CN114724168A (zh) * | 2022-05-10 | 2022-07-08 | 北京百度网讯科技有限公司 | 深度学习模型的训练方法、文本识别方法、装置和设备 |
CN115330898B (zh) * | 2022-08-24 | 2023-06-06 | 晋城市大锐金马工程设计咨询有限公司 | 一种基于改进Swin Transformer的杂志广告嵌入方法 |
CN116051935B (zh) * | 2023-03-03 | 2024-03-22 | 北京百度网讯科技有限公司 | 图像检测方法、深度学习模型的训练方法及装置 |
CN116468112B (zh) * | 2023-04-06 | 2024-03-12 | 北京百度网讯科技有限公司 | 目标检测模型的训练方法、装置、电子设备和存储介质 |
Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106548190A (zh) * | 2015-09-18 | 2017-03-29 | 三星电子株式会社 | 模型训练方法和设备以及数据识别方法 |
CN109902678A (zh) * | 2019-02-12 | 2019-06-18 | 北京奇艺世纪科技有限公司 | 模型训练方法、文字识别方法、装置、电子设备及计算机可读介质 |
CN110246487A (zh) * | 2019-06-13 | 2019-09-17 | 苏州思必驰信息科技有限公司 | 用于单通道的语音识别模型的优化方法及系统 |
WO2019240964A1 (en) * | 2018-06-12 | 2019-12-19 | Siemens Aktiengesellschaft | Teacher and student based deep neural network training |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
CN112287920A (zh) * | 2020-09-17 | 2021-01-29 | 昆明理工大学 | 基于知识蒸馏的缅甸语ocr方法 |
CN112465111A (zh) * | 2020-11-17 | 2021-03-09 | 大连理工大学 | 一种基于知识蒸馏和对抗训练的三维体素图像分割方法 |
CN112613303A (zh) * | 2021-01-07 | 2021-04-06 | 福州大学 | 一种基于知识蒸馏的跨模态图像美学质量评价方法 |
CN112906747A (zh) * | 2021-01-25 | 2021-06-04 | 北京工业大学 | 一种基于知识蒸馏的图像分类方法 |
CN113095475A (zh) * | 2021-03-02 | 2021-07-09 | 华为技术有限公司 | 一种神经网络的训练方法、图像处理方法以及相关设备 |
JP2021103386A (ja) * | 2019-12-24 | 2021-07-15 | 株式会社Mobility Technologies | 学習モデルの生成方法、コンピュータプログラム、情報処理装置、及び情報処理方法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11586930B2 (en) * | 2019-04-16 | 2023-02-21 | Microsoft Technology Licensing, Llc | Conditional teacher-student learning for model training |
US11302309B2 (en) * | 2019-09-13 | 2022-04-12 | International Business Machines Corporation | Aligning spike timing of models for maching learning |
-
2021
- 2021-08-03 CN CN202110886478.9A patent/CN113609965B/zh active Active
Patent Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106548190A (zh) * | 2015-09-18 | 2017-03-29 | 三星电子株式会社 | 模型训练方法和设备以及数据识别方法 |
WO2019240964A1 (en) * | 2018-06-12 | 2019-12-19 | Siemens Aktiengesellschaft | Teacher and student based deep neural network training |
CN109902678A (zh) * | 2019-02-12 | 2019-06-18 | 北京奇艺世纪科技有限公司 | 模型训练方法、文字识别方法、装置、电子设备及计算机可读介质 |
CN110246487A (zh) * | 2019-06-13 | 2019-09-17 | 苏州思必驰信息科技有限公司 | 用于单通道的语音识别模型的优化方法及系统 |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
JP2021103386A (ja) * | 2019-12-24 | 2021-07-15 | 株式会社Mobility Technologies | 学習モデルの生成方法、コンピュータプログラム、情報処理装置、及び情報処理方法 |
CN112287920A (zh) * | 2020-09-17 | 2021-01-29 | 昆明理工大学 | 基于知识蒸馏的缅甸语ocr方法 |
CN112465111A (zh) * | 2020-11-17 | 2021-03-09 | 大连理工大学 | 一种基于知识蒸馏和对抗训练的三维体素图像分割方法 |
CN112613303A (zh) * | 2021-01-07 | 2021-04-06 | 福州大学 | 一种基于知识蒸馏的跨模态图像美学质量评价方法 |
CN112906747A (zh) * | 2021-01-25 | 2021-06-04 | 北京工业大学 | 一种基于知识蒸馏的图像分类方法 |
CN113095475A (zh) * | 2021-03-02 | 2021-07-09 | 华为技术有限公司 | 一种神经网络的训练方法、图像处理方法以及相关设备 |
Non-Patent Citations (3)
Title |
---|
Digging Deeper into CRNN Model in Chinese Text Images Recognition;Kunhong Yu et al.;arXiv:2011.08505;第1-3节 * |
Distilling the Knowledge in a Neural Network;Geoffrey Hinton et al.;arXiv:1503.02531;全文 * |
生成对抗网络图像类别标签跨模态识别系统设计;刘尚争;刘斌;;现代电子技术(第08期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN113609965A (zh) | 2021-11-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113609965B (zh) | 文字识别模型的训练方法及装置、存储介质、电子设备 | |
CN112487182B (zh) | 文本处理模型的训练方法、文本处理方法及装置 | |
CN111444340B (zh) | 文本分类方法、装置、设备及存储介质 | |
CN107481717B (zh) | 一种声学模型训练方法及系统 | |
CN108959246A (zh) | 基于改进的注意力机制的答案选择方法、装置和电子设备 | |
CN109902293A (zh) | 一种基于局部与全局互注意力机制的文本分类方法 | |
CN111930992A (zh) | 神经网络训练方法、装置及电子设备 | |
WO2023160472A1 (zh) | 一种模型训练方法及相关设备 | |
CN116415654A (zh) | 一种数据处理方法及相关设备 | |
WO2021238333A1 (zh) | 一种文本处理网络、神经网络训练的方法以及相关设备 | |
CN112733550A (zh) | 基于知识蒸馏的语言模型训练方法、文本分类方法及装置 | |
US11423307B2 (en) | Taxonomy construction via graph-based cross-domain knowledge transfer | |
US11250838B2 (en) | Cross-modal sequence distillation | |
CN111653274B (zh) | 唤醒词识别的方法、装置及存储介质 | |
CN113836866B (zh) | 文本编码方法、装置、计算机可读介质及电子设备 | |
Le et al. | An overview of deep learning in industry | |
CN112825114A (zh) | 语义识别方法、装置、电子设备及存储介质 | |
CN117121015A (zh) | 利用冻结语言模型的多模态少发式学习 | |
CN113434683A (zh) | 文本分类方法、装置、介质及电子设备 | |
CN115687934A (zh) | 意图识别方法、装置、计算机设备及存储介质 | |
CN116541492A (zh) | 一种数据处理方法及相关设备 | |
Wang et al. | Bilateral attention network for semantic segmentation | |
CN112434746B (zh) | 基于层次化迁移学习的预标注方法及其相关设备 | |
CN111666375B (zh) | 文本相似度的匹配方法、电子设备和计算机可读介质 | |
CN112905750A (zh) | 一种优化模型的生成方法和设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |