CN113111968A - 图像识别模型训练方法、装置、电子设备和可读存储介质 - Google Patents

图像识别模型训练方法、装置、电子设备和可读存储介质 Download PDF

Info

Publication number
CN113111968A
CN113111968A CN202110482330.9A CN202110482330A CN113111968A CN 113111968 A CN113111968 A CN 113111968A CN 202110482330 A CN202110482330 A CN 202110482330A CN 113111968 A CN113111968 A CN 113111968A
Authority
CN
China
Prior art keywords
model
training
network model
image
sample data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110482330.9A
Other languages
English (en)
Other versions
CN113111968B (zh
Inventor
张洪路
周佳
包英泽
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Dami Technology Co Ltd
Original Assignee
Beijing Dami Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Dami Technology Co Ltd filed Critical Beijing Dami Technology Co Ltd
Priority to CN202110482330.9A priority Critical patent/CN113111968B/zh
Publication of CN113111968A publication Critical patent/CN113111968A/zh
Application granted granted Critical
Publication of CN113111968B publication Critical patent/CN113111968B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本发明实施例公开了一种图像识别模型训练方法、装置、电子设备和可读存储介质,所述方法包括获取样本数据,根据样本数据训练确定预设教师网络模型对应的训练辅助模型,再基于样本数据和训练辅助模型确定预设学生网络模型对应的图像识别模型,并将图像识别模型用于识别图像中的不良信息,其中,图像识别模型的损失函数是根据训练辅助模型的输出结果和学生网络模型的输出结果确定。由此,使得训练得到的图像识别模型的计算复杂度降低,识别检测速度更快,识别检测精确度更高。

Description

图像识别模型训练方法、装置、电子设备和可读存储介质
技术领域
本发明涉及计算机技术领域,具体涉及一种图像识别模型训练方法、装置、电子设备和可读存储介质。
背景技术
为避免带有不良信息(例如:涉黄、涉政和涉暴)的视频图像进入公众视野,需要对视频图像进行信息识别,以将符合规范的视频图像进行播放和传播。
现有的不良信息检测常依赖于各种图像识别模型实现,虽识别检测精度高,但模型规模庞大,计算复杂度高,计算速度慢。或者,虽规模小,计算速度快,但识别检测精度有待提高。因此,在实现图像不良信息检测的同时,图像识别模型仍有待改善。
发明内容
有鉴于此,本发明实施例提供一种图像识别模型训练方法、装置、电子设备和可读存储介质,以改善图像识别模型的模型规模、计算速度和识别检测精确度。
第一方面,本发明实施例提供一种图像识别模型训练方法,所述方法包括:
获取样本数据,所述样本数据包括样本图像和图像标注,所述图像标注用于表征所述样本图像中的不良信息分类;
将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型;
基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型,其中,所述图像识别模型的损失函数根据所述训练辅助模型的输出结果和学生网络模型的输出结果确定,所述图像识别模型用于识别图像中的不良信息。
进一步地,所述基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型包括:
将所述样本数据输入至训练辅助模型,确定第一预测结果;
将所述样本数据输入至预设的学生网络模型,确定第二预测结果;
根据所述第一预测结果和第二预测结果调整所述学生网络模型对应的损失函数;
将调整完成的损失函数对应的学生网络模型确定为所述图像识别模型。
进一步地,所述根据所述第一预测结果和第二预测结果调整所述学生网络模型对应的损失函数包括:
根据所述第一预测结果和第二预测结果确定所述训练辅助模型和学生网络模型输出结果之间的第一交叉熵;
根据所述第二预测结果确定所述学生网络模型输出结果与图像标注之间的第二交叉熵;
根据所述第一交叉熵和第二交叉熵确定所述学生网络模型对应的损失函数。
进一步地,所述根据所述第一交叉熵和第二交叉熵确定所述学生网络模型对应的损失函数包括:
将所述第一交叉熵和第二交叉熵与对应的权重系数做加权求和运算,并将所述加权运算的运算结果确定为所述学生网络模型对应的损失函数。
进一步地,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型包括:
响应于所述不良信息为第一类信息,将所述样本数据输入预设的教师网络模型进行训练,以确定第一训练辅助模型;
其中,所述第一类信息用于表征涉黄类型的不良信息。
进一步地,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型包括:
响应于所述不良信息为第二类信息,将所述样本数据输入预设的教师网络模型进行训练,以确定第二训练辅助模型;
其中,所述第二类信息用于表征涉暴和涉政类型的不良信息。
进一步地,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型之前,所述方法还包括:
对所述样本图像进行归一化处理。
进一步地,所述教师网络模型采用分类神经网络。
进一步地,所述学生网络模型采用多目标输出网络或目标检测网络。
第二方面,本发明实施例提供一种图像识别模型训练装置,所述装置包括:
获取单元,用于获取样本数据,所述样本数据包括样本图像和图像标注,所述图像标注用于表征所述样本图像中的不良信息分类;
辅助训练单元,用于将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型;
图像识别单元,用于基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型,其中,所述图像识别模型的损失函数根据所述训练辅助模型的输出结果和学生网络模型的输出结果确定,所述图像识别模型用于识别图像中的不良信息。
第三方面,本发明实施例提供一种计算机程序产品,所述计算机程序产品包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现如上任一项所述的方法。
第四方面,本发明实施例提供一种电子设备,包括存储器和处理器,所述存储器用于存储一条或多条计算机程序指令,其中,所述一条或多条计算机程序指令被所述处理器执行以实现如上任一项所述的方法。
第五方面,本发明实施例提供一种可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现如上任一项所述的方法步骤。
本发明实施例的技术方案通过获取样本数据,通过样本数据训练确定预设教师网络模型对应的训练辅助模型,再基于样本数据和训练辅助模型确定预设学生网络模型对应的图像识别模型,并将图像识别模型用于识别图像中的不良信息,实现图像中不良信息的检测识别。由于图像识别模型的损失函数是根据训练辅助模型的输出结果和学生网络模型的输出结果确定,使得图像识别模型在识别精度上与教师网络模型对应的训练辅助模型相媲美,相比于其他规模小且计算速度快的模型,识别精度更高。同时,相比于教师网络模型和对应的训练辅助模型,学生网络模型对应的图像识别模型的模型规模小,计算复杂度低,计算速度快。
附图说明
通过以下参照附图对本发明实施例的描述,本发明的上述以及其它目的、特征和优点将更为清楚,在附图中:
图1是图像识别模型训练方法的流程图;
图2是确定图像识别模型的流程图;
图3是确定学生网络模型损失函数的流程图;
图4是调整学生网络模型损失函数的流程图;
图5是图像识别模型训练方法的另一个流程图;
图6是学生网络模型的示意图;
图7是图像识别模型训练方法的另一个流程图;
图8是图像识别模型训练装置的示意图;
图9是图像识别模型训练装置的另一个示意图;
图10是本发明实施例的电子设备的示意图。
具体实施方式
以下基于实施例对本发明进行描述,但是本发明并不仅仅限于这些实施例。在下文对本发明的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本发明。为了避免混淆本发明的实质,公知的方法、过程、流程、元件和电路并没有详细叙述。
此外,本领域普通技术人员应当理解,在此提供的附图都是为了说明的目的,并且附图不一定是按比例绘制的。
除非上下文明确要求,否则在说明书的“包括”、“包含”等类似词语应当解释为包含的含义而不是排他或穷举的含义;也就是说,是“包括但不限于”的含义。
在本发明的描述中,需要理解的是,术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性。此外,在本发明的描述中,除非另有说明,“多个”的含义是两个或两个以上。
通过构建图像信息识别模型对图像内容进行识别检测是保证视频图像规范播放和传播的基础手段。然而,现有的图像识别模型或识别精度高但计算复杂度高,或计算速度快但识别精度差。基于此,本发明实施例旨在提供一种图像识别模型训练方法、装置、电子设备和可读存储介质,以在保证图像中信息识别检测精度的同时,提高图像识别检测的计算速度。
本实施例中,以在线教育场景下产生图像中的不良信息的识别为例对图像识别模型训练方法做进一步说明。应理解,本实施例中的图像识别模型训练方法能够应用于各种需要确定图像识别模型的方案中,此处并不多做其他限定。
图1是图像识别模型训练方法的流程图。如图1所示,本实施例的图像识别模型训练方法包括如下步骤:
在步骤S100,获取样本数据。
本实施例中,样本数据包括样本图像和图像标注。其中,图像标注用于表征样本图像中的不良信息分类。
可选地,本实施例中的图像通过人工的方式对样本图像中的不良信息进行图像标注,以尽量筛选出样本图像中涉及的不良信息。进一步地,当不良信息分类不同时,可以采用不同形式的标签对不良信息进行标注。
在步骤S200,将样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型。
在步骤S300,基于样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型。其中,图像识别模型的损失函数根据训练辅助模型的输出结果和学生网络模型的输出结果确定。图像识别模型用于识别图像中的不良信息。
本实施例的技术方案通过获取样本数据,将样本数据输入至预设的教师网络模型进行训练,确定训练辅助模型;再基于样本数据和训练辅助模型对预设的学生网络模型进行训练,确定用于识别图像中不良信息的图像识别模型。再者,由于图像识别模型的损失函数是根据训练辅助模型的输出结果和学生网络模型的输出结果确定的,能够使得训练得到的图像识别模型能够用于识别图像中不良信息的同时,改善现有图像识别模型的规模、计算速度和识别检测精确度。
图2是确定图像识别模型的流程图。如图2所示,在确定图像识别模型时,包括以下步骤:
在步骤S210,将样本数据输入至训练辅助模型,确定第一预测结果。
本实施例中的第一预测结果为训练辅助模型中输入样本数据后对应的输出结果。
在步骤S220,将样本数据输入至预设的学生网络模型,确定第二预测结果。
本实施例中的第二预测结果为学生网络模型中输入样本数据后对应的输出结果。
在步骤230,根据第一预测结果和第二预测结果调整学生网络模型对应的损失函数。
损失函数(loss)是用于描述模型输出的预测值与真实值之间差距大小的函数。本实施例中,通过根据第一预测结果和第二预测结果调整学生网络模型对应的损失函数,以使得训练完成后对应的学生网络模型的输出结果能够与训练辅助模型以及真实值(本实施例中可以理解为图像标注信息)之间的差距尽可能缩小。
在步骤S240,将调整完成的损失函数对应的学生网络模型确定为图像识别模型。
本实施例中,根据第一预测结果和第二预测结果调整学生网络模型对应的损失函数,直至学生网络模型输出的预测结果与训练辅助模型输出的预测结果的误差满足预设条件,则认为学生网络模型已调整完成,同时结束学生网络模型训练,并将误差满足预设条件时的损失函数对应的学生网络模型确定为图像识别模型。
可选地,本实施例中的预设条件可以设置为当学生网络模型输出的预测结果与训练辅助模型输出的预测结果之间的误差小于预设误差阈值即认为误差满足预设条件。同时,预设误差的阈值可以根据实际信息识别检测场景下要求的识别精度进行设置。
本实施例的技术方案通过样本数据输入下训练辅助模型输出的第一预测结果,以及样本数据输入下学生网络模型输出的第二预测结果调整学生网络模型对应的损失函数,并将调整完成的损失函数对应的学生网络模型确定为图像识别模型。由此,基于知识蒸馏方法,通过训练样本数据对预先设置的教师网络模型进行训练,得到预测结果准确,但规模庞大,计算复杂度高的辅助训练网络,再基于辅助训练网络生成规模和计算复杂度远低于辅助训练网络,但预测结果准确度匹配于辅助训练网络的图像识别模型。最后,将训练得到的图像识别模型应用于实际应用场景中的图像不良信息预测过程,识别出其中的不良信息,避免带有不良信息的视频图像的播放和传播。
进一步地,本实施例在根据第一预测结果和第二预测结果调整学生网络模型对应的损失函数时,采用loss算法对训练中的学生网络模型的损失函数进行调整。
常见的loss算法种类有均值平方差(MSE)和交叉熵。其中,MSE主要针对的是回归问题。交叉熵算法一般针对于分类问题。熵表示信息量的期望值或平均信息量,而信息量与概率成反比。交叉熵为相对熵(KL散度)的简化,用于表征不同概率分布之间的距离以及预测输入样本属于某一类的概率。交叉熵的值越小,则表明训练得到的模型精度越高。
可选地,本实施例中的损失函数算法采用交叉熵损失函数(crossentropy)。模型之间交叉熵的值越小,则表明训练得到的模型精度越高。
图3是确定学生网络模型损失函数的流程图。如图3所示,本实施例在调整学生网络模型的损失函数时,包括以下步骤:
在步骤S310,根据第一预测结果和第二预测结果确定训练辅助模型和学生网络模型输出结果之间的第一交叉熵。
本实施例中,第一交叉熵用于表征正在训练的学生网络模型与训练好的训练辅助模型之间的差距,也即正在训练的学生网络模型输出结果与预先训练好的训练辅助模型输出结果之间的差值。第一交叉熵的值越小,表明训练得到的学生网络模型的识别精确度与训练好的训练辅助模型越接近。
可选地,本实施例中可以通过预先设置第一交叉熵阈值的方式来训练学生网络模型,并确定具有与第一交叉熵阈值对应识别检测精度的学生网络模型。
在步骤S320,根据第二预测结果确定学生网络模型输出结果与图像标注之间的第二交叉熵。
本实施例中,第二交叉熵用于表征正在训练的学生网络模型输出结果与用于表征不良信息分类的样本数据中的图像标注之间的差值。第二交叉熵的值越小,表明训练得到的学生网络模型的识别结果与真实不良信息之间的差距越小,学生网络模型的不良信息识别检测精度更高。
可选地,本实施例中可以通过预先设置第二交叉熵阈值的方式来进一步训练学生网络模型,并确定具有与第二交叉熵阈值对应识别检测精度的学生网络模型。
在步骤S330,根据第一交叉熵和第二交叉熵确定学生网络模型对应的损失函数。
本实施例中,通过第一交叉熵和第二交叉熵共同指导学生网络模型的训练,以使得最终得到的学生网络模型(也即图像识别模型)的输出结果与训练辅助模型输出结果相靠近,并尽可能地与样本图像中预先标注的图像标注信息一致,进而在降低图像识别模型规模和复杂度,提高图像识别检测速度的同时,使得图像识别模型的识别检测精确度更高,使用性能更强。
可选地,本实施例中的第一交叉熵和第二交叉熵均设置有对应的权重系数,将第一交叉熵和第二交叉熵与对应的权重系数做加权求和运算,并将加权求和运算的运算结果确定为学生网络模型对应的损失函数。
进一步地,本实施例中第一交叉熵对应的权重系数以及第二交叉熵对应的权重系数均为预先设置的。同时,权重系数的大小可以根据实际识别检测需要进行调整。由此,通过第一交叉熵和第二交叉熵与各自对应的权重系数的加权求和运算确定学生网络模型对应的损失函数。并且,通过调整第一交叉熵和第二交叉熵对应权重系数的大小,使得学生网络模型的损失函数更加贴合实际,进一步提高最终训练得到的图像识别模型的使用性能。
图4是调整学生网络模型损失函数的流程图。如图4所示,图中的“输入”表征输入至教师网络模型和学生网络模型的样本数据。本实施例中,教师网络模型为基于样本数据对预设的教师网络模型进行训练得到的能够用于识别图像中不良信息的训练辅助模型,模型复杂度高,性能优越。学生网络模型为待训练的或正在训练中的预设的学生网络模型,模型精简,模型复杂度低,计算速度快。
“soft target”表征教师网络模型输出的预测结果,包含不同类别之间关系的信息。
“hard target”表征样本图像中的不良信息标签(也可以是样本图像中预先标注好的不良信息对应的图像标注信息/数据)。相比于soft target,hard target包含的信息熵较低。
“L(soft)”表征教师网络模型输出的预测结果与学生网络模型输出的预测结果之间的第一交叉熵。λ表征第一交叉熵对应的权重系数。
“L(hard)”表征学生网络模型输出的预测结果与真实的图像标注之间的第二交叉熵。(1-λ)表征第二交叉熵对应的权重系数。
“total loss”表征调整之后学生网络模型对应的损失函数,也是第一交叉熵与第二交叉熵对应的交叉熵损失函数。
“softmax-1”表征教师网络模型输出的检测结果对应的概率分布。
“softmax-21”表征学生网络模型输出结果中与教师网络模型输出结果相关联部分的概率分布。
“softmax-22”表征学生网络模型输出结果中与实际的图像标注信息相关联部分的概率分布。
具体地,在进行图像识别模型训练时,将样本数据分别输入至预设的教师网络模型和学生网络模型。教师网络模型输出softmax-1对应的预测结果soft target。学生网络模型分别输出softmax-21对应的预测结果和softmax-22对应的预测结果。通过教师网络模型输出的预测结果softtarget和学生网络模型输出的softmax-21对应的预测结果确定第一交叉熵L(soft),并根据学生网络模型输出的softmax-22对应的预测结果和图像标注对应的不良信息标签hard target确定第二交叉熵L(hard)。最后,通过对第一交叉熵L(soft)及其对应的权重系数λ,以及第二交叉熵及其对应的权重系数(1-λ)进行加权求和运算,确定调整后的学生网络模型对应的损失函数total loss。也即:total loss=λL(soft)+(1-λ)L(hard)。由此,在学生网络模型训练过程中,通过交叉熵损失函数学习样本数据中不同类别之间的关联性信息,实现知识迁移,进而训练得到能够用于识别检测不同类型不良信息的图像识别模型,并使得训练得到的图像识别模型能够同时结合教师网络模型性能优越,识别精度高和学生网络模型模型精简,识别速度快的优点。
下面,以不同类型的不良信息对应的图像识别模型的训练方法进行说明。应理解,本实施例中的图像识别模型训练方法可以与一种或多种不良信息类型相对应。具体地,本实施例中以不良信息为第一类信息以及不良信息为第二类信息时的图像识别模型训练为例进行介绍。其中,第一类信息用于表征涉黄类型的不良信息。第二类信息用于表征涉暴和涉政类型的不良信息。
图5是图像识别模型训练方法的另一个流程图。如图5所示,本实施例的图像识别模型为用于识别第一类信息的图像识别模型,对应的图像识别模型训练方法包括以下步骤:
在步骤S510,获取样本数据。其中,样本数据包括样本图像和图像标注。
本实施例中的图像标注用于表征样本图像中涉及的第一类信息。其中,第一类信息用于表征涉黄类型的不良信息。可选地,本实施例中的第一类信息包括多个不同特征的不良信息,例如:“裸露上身”、“裸露下身”、“裸露全身”和“深V”等。同时,在进行图像标注时,基于人工的方式对全部类型的第一类信息进行标注,标注可以通过标签的形式体现。
在步骤S520,响应于不良信息为第一类信息,将样本数据输入预设的教师网络模型进行训练,以确定第一训练辅助模型。
可选地,本实施例中的教师网络模型采用分类神经网络进行训练。例如,可以采用resnet50作为预设的教师网络模型。
resnet50是一种包含50层信息的残差网络,具体包括输入层image、1个独立卷积层conv1、1个最大池化层maxpool、4种卷积残差模块(分别为conv2_x、conv3_x、conv4_x和conv5_x)、1个平均池化层avgpool和1个软最大输出层。其中,resnet50的输入为大小为224×224×3的三维数据,第一个卷积层是独立卷积层,使用64个大小为7×7、步长为2的卷积核,输出的大小为112×112。之后的最大池化层maxpool中,池化窗口和步长分别为3×3和2。接着是4种不同的卷积残差模块,3个conv2_x、4个conv3_x、6个conv4_x和3个conv5_x。每个卷积残差模块有2-3个卷积层和跨越它们的连接组成。同时,在resnet50网络的最后,由平均池化层avgpool和10000维的全连接软最大输出层区分不同类别。
在步骤S530,将样本数据输入第一训练辅助模型,确定对应的第一预测结果。
在步骤S540,将样本数据输入至预设的学生网络模型,确定对应的第二预测结果。
可选地,本实施例中的学生网络模型采用多目标输出网络。例如,可以采用mobilenet作为预设的学生网络模型。
mobilenet网络为基于深度可分离卷积思想构建的轻量级的深层神经网络。其中,深度可分离卷积是指将普通卷积拆分成一个深度卷积(depthwise convolution)和一个逐点卷积(point convolution)。深度卷积是depth级别的操作,针对每个输入通道采用不同的卷积核,也即一个卷积核对应一个输入通道。逐点卷积为采用尺寸为1×1的卷积核的一般卷积操作。因此,相比于其它网络模型,mobilenet网络首先采用深度卷积对不同输入通道分别进行卷积,然后采用逐点卷积将上层的输出再进行结合,使得网络整体实现的效果与一般的标准卷积相媲美,但计算量和模型参数量大大减少,也即模型轻量化程度高。
进一步地,本实施例中的学生网络模型采用mobilenetv2网络。mobilenetv2网络是一种典型的mobilenet网络,其主干网络包括连接Relu激活的1×1的卷积层、深度可分离卷积层和未连接Relu激活的1×1的卷积层。
图6是学生网络模型的示意图。如图6所示,本实施例的学生网络模型以mobilenetv2网络为基础网络,mobilenetv2网络为基础网络之后连接有全连接层FC。通过mobilenetv2网络执行多次卷积操作,提取样本数据中的特征,并由全连接层FC对mobilenetv2提取特征后输出的高度抽象化的特征进行整合,确定不同类型的第一类信息,进而实现第一类信息的识别检测。
在步骤S550,根据第一预测结果和第二预测结果确定第一训练辅助模型和学生网络模型输出结果之间的第一交叉熵。
在步骤S560,根据第二预测结果确定学生网络模型输出结果与图像标注之间的第二交叉熵。
在步骤S570,根据第一交叉熵和第二交叉熵确定学生网络模型对应的损失函数。
在步骤S580,将调整完成的损失函数对应的学生网络模型确定为对应的图像识别模型。
本实施例中,根据第一交叉熵和第二交叉熵调整学生网络模型对应的损失函数,直至学生网络模型输出的预测结果与第一训练辅助模型输出的预测结果的误差满足预设条件,结束学生网络模型训练,并将误差满足预设条件时的损失函数对应的学生网络模型确定为第一类信息的图像识别模型。
本实施例的技术方案通过获取样本数据对应的样本图像和样本图像对应的图像标准信息,在不良信息为第一类信息时,将样本数据输入预设的resnet50网络进行训练,确定第一训练辅助模型。之后将样本图像输入第一训练辅助模型,确定对应的第一预测结果;将样本图像输入至预设的mobilenetv2网络模型,确定对应的第二预测结果,并分别根据第一预测结果和第二预测结果确定第一训练辅助模型和学生网络模型输出结果之间的第一交叉熵,根据第二预测结果确定学生网络模型输出结果与图像标注之间的第二交叉熵,再根据第一交叉熵和第二交叉熵确定学生网络模型对应的损失函数,直至调整后的学生网络模型输出的预测结果与第一训练辅助模型输出的预测结果逼近或一致,表明学生网络模型已调整完成,并将此时确定的损失函数对应的学生网络模型确定为第一类信息的图像识别模型。
图7是图像识别模型训练方法的另一个流程图。如图7所示,本实施例的图像识别模型为用于识别第二类信息的图像识别模型,对应的图像识别模型训练方法包括以下步骤:
在步骤S710,获取样本数据。其中,样本数据包括样本图像和图像标注。
本实施例中的图像标注用于表征样本图像中涉及的第二类信息。其中,第二类信息用于表征涉暴和涉政类型的不良信息,例如:国旗、地图、武器、血腥元素、国家货币等不良信息元素。同时,在进行图像标注时,本实施例中使用矩形框对样本图像中的第二类信息包括的不良信息元素进行标注。
在步骤S720,响应于不良信息为第二类信息,将样本数据输入预设的教师网络模型进行训练,以确定第二训练辅助模型。
可选地,本实施例中的教师网络模型采用分类神经网络进行训练。具体地,可以采用resnet50作为预设的教师网络模型。其中,resnet50的网络结构在前面部分已经介绍,此处不再赘述。
在步骤S730,将样本数据输入第二训练辅助模型,确定对应的第一预测结果。
在步骤S740,将样本数据输入至预设的学生网络模型,确定对应的第二预测结果。
可选地,本实施例中的学生网络模型采用目标检测网络。例如,可以采用yolov5作为预设的学生网络模型。
yolov5是一种目标检测网络,主要包括依次Backbone层、Neck层和Head层。其中,Backbone层,也即跨阶段局部网络,包括多层卷积神经网路,用于在从样本图像中提取图像特征。Neck层,也即路径聚合网络,用于生成特征金字塔,增强模型对应不同缩放尺度对象的检测,识别不同大小和尺度的同一物体。Head层,也即通用检测层,用于对图像特征进行预测,生成带有类概率、对象得分和边界框的输出向量(也即预测类别)。相比于其他形式的目标检测网络,yolov5的模型在模型尺寸和推理速度上均具有很强的优越性。
在步骤S750,根据第一预测结果和第二预测结果确定第一训练辅助模型和学生网络模型输出结果之间的第一交叉熵。
在步骤S760,根据第二预测结果确定学生网络模型输出结果与图像标注之间的第二交叉熵。
在步骤S770,根据第一交叉熵和第二交叉熵确定学生网络模型对应的损失函数。
在步骤S780,将调整完成的损失函数对应的学生网络模型确定为对应的图像识别模型。
本实施例的技术方案通过获取样本数据对应的样本图像和样本图像对应的图像标准信息,在不良信息为第二类信息时,将样本数据输入预设的resnet50网络进行训练,确定第二训练辅助模型。之后将样本图像输入第二训练辅助模型,确定对应的第一预测结果;将样本图像输入至预设的yolov5网络模型,确定对应的第二预测结果,并分别根据第一预测结果和第二预测结果确定第二训练辅助模型和学生网络模型输出结果之间的第一交叉熵,根据第二预测结果确定学生网络模型输出结果与图像标注之间的第二交叉熵,再根据第一交叉熵和第二交叉熵确定学生网络模型对应的损失函数,直至调整后的学生网络模型输出的预测结果与第二训练辅助模型输出的预测结果逼近或一致,表明学生网络模型已调整完成,并将此时确定的损失函数对应的学生网络模型确定为第二类信息的图像识别模型。
图8是图像识别模型训练装置的示意图。如图8所示,本实施例的图像识别模型训练装置8包括获取单元81、训练辅助单元82和图像识别单元83。其中,获取单元81用于获取样本数据。样本数据包括样本图像和图像标注,图像标注用于表征样本图像中的不良信息分类。辅助训练单元82用于将样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型。图像识别单元83用于基于样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型。其中,图像识别模型的损失函数根据训练辅助模型的输出结果和学生网络模型的输出结果确定,图像识别模型用于识别图像中的不良信息。
本实施例的技术方案通过获取单元获取样本数据,训练辅助单元将样本数据输入至预设的教师网络模型进行训练,确定训练辅助模型;再由图像识别单元基于样本数据和训练辅助模型对预设的学生网络模型进行训练,确定用于识别图像中不良信息的图像识别模型。再者,由于图像识别模型的损失函数是根据训练辅助模型的输出结果和学生网络模型的输出结果确定的,能够使得训练得到的图像识别模型能够用于识别图像中不良信息的同时,改善现有图像识别模型的规模、计算速度和识别检测精确度。
可选地,如图9所示,本实施例的图像识别训练装置8还包括除获取单元81、训练辅助单元82和图像识别单元83之外的处理单元84。其中,处理单元84用于对样本图像进行归一化处理,将样本图像尺寸调整为256×256大小。
可选地,如图9所示,本实施例中的训练辅助单元82包括第一辅助子单元821和第二辅助子单元822。其中,第一辅助子单元821用于响应于不良信息为第一类信息,将样本数据输入预设的教师网络模型进行训练,以确定第一训练辅助模型。第二辅助子单元822用于响应于不良信息为第二类信息,将样本数据输入预设的教师网络模型进行训练,以确定第二训练辅助模型。
进一步地,本实施例中的第一类信息用于表征涉黄类型的不良信息。第二类信息用于表征涉暴和涉政类型的不良信息。
可选地,如图9所示,本实施例中的图像识别单元83包括第一预测子单元831、第二预测子单元832、调整子单元833和确定子单元834。其中,第一预测子单元831用于将样本数据输入至训练辅助模型,确定第一预测结果。第二预测子单元832用于将样本数据输入至预设的学生网络模型,确定第二预测结果。调整子单元833用于根据第一预测结果和第二预测结果调整学生网络模型对应的损失函数。确定子单元834用于将调整完成的损失函数对应的学生网络模型确定为图像识别模型。
可选地,如图9所示,本实施例的调整子单元833包括计算模块8331。其中,计算单元8331用于将第一交叉熵和第二交叉熵与对应的权重系数做加权求和运算,并将加权运算的运算结果确定为学生网络模型对应的损失函数。
图10是本发明实施例的电子设备的示意图。如图10所示,本实施例的电子设备为通用的数据处理装置,包括通用的计算机硬件结构,其至少包括处理器101和存储器102。处理器101和存储器102通过总线103连接。存储器102适于存储处理器101可执行的指令或程序。处理器101可以是独立的微处理器,也可以是一个或者多个微处理器集合。由此,处理器101通过执行存储器102所存储的指令,从而执行如上所述的本发明实施例的方法流程实现对于数据的处理和对于其它装置的控制。总线103将上述多个组件连接在一起,同时将上述组件连接到显示控制器104、显示装置以及输入/输出(I/O)装置105。输入/输出(I/O)装置105可以是鼠标、键盘、调制解调器、网络接口、触控输入装置、体感输入装置、打印机以及本领域公知的其他装置。典型地,输入/输出装置105通过输入/输出(I/O)控制器106与系统相连。
其中,存储器102可以存储软件组件,例如操作系统、通信模块、交互模块以及应用程序。以上所述的每个模块和应用程序都对应于完成一个或多个功能和在发明实施例中描述的方法的一组可执行程序指令。
本领域的技术人员应明白,本申请的实施例可提供为方法、装置(设备)或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可读存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品。
本申请是参照根据本申请实施例的方法、装置(设备)和计算机程序产品的流程图来描述的。应理解可由计算机程序指令实现流程图中的每一流程。
本发明的另一实施例涉及一种计算机程序产品,包括计算机程序/指令,计算机程序程序/指令用于在被处理器执行时实现上述部分或全部的方法实施例中的部分或全部步骤。这些计算机程序/指令可以存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的程序/指令产生包括指令装置的制造品,该指令装置实现流程图一个流程或多个流程中指定的功能。也可提供这些计算机程序/指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程中指定的功能的装置。
本发明的另一实施例涉及一种计算机可读存储介质,可以是非易失性存储介质,用于存储计算机可读程序,所述计算机可读程序用于供计算机执行上述部分或全部的方法实施例。
即,本领域技术人员可以理解,实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,该程序存储在一个存储介质中,包括若干指令用以使得一个设备(可以是单片机,芯片等)或处理器(processor)执行本申请各实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本发明的优选实施例,并不用于限制本发明,对于本领域技术人员而言,本发明可以有各种改动和变化。凡在本发明的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (13)

1.一种图像识别模型训练方法,其特征在于,所述方法包括:
获取样本数据,所述样本数据包括样本图像和图像标注,所述图像标注用于表征所述样本图像中的不良信息分类;
将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型;
基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型,其中,所述图像识别模型的损失函数根据所述训练辅助模型的输出结果和学生网络模型的输出结果确定,所述图像识别模型用于识别图像中的不良信息。
2.根据权利要求1所述的方法,其特征在于,所述基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型包括:
将所述样本数据输入至训练辅助模型,确定第一预测结果;
将所述样本数据输入至预设的学生网络模型,确定第二预测结果;
根据所述第一预测结果和第二预测结果调整所述学生网络模型对应的损失函数;
将调整完成的损失函数对应的学生网络模型确定为所述图像识别模型。
3.根据权利要求2所述的方法,其特征在于,所述根据所述第一预测结果和第二预测结果调整所述学生网络模型对应的损失函数包括:
根据所述第一预测结果和第二预测结果确定所述训练辅助模型和学生网络模型输出结果之间的第一交叉熵;
根据所述第二预测结果确定所述学生网络模型输出结果与图像标注之间的第二交叉熵;
根据所述第一交叉熵和第二交叉熵确定所述学生网络模型对应的损失函数。
4.根据权利要求3所述的方法,其特征在于,所述根据所述第一交叉熵和第二交叉熵确定所述学生网络模型对应的损失函数包括:
将所述第一交叉熵和第二交叉熵与对应的权重系数做加权求和运算,并将所述加权运算的运算结果确定为所述学生网络模型对应的损失函数。
5.根据权利要求1所述的方法,其特征在于,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型包括:
响应于所述不良信息为第一类信息,将所述样本数据输入预设的教师网络模型进行训练,以确定第一训练辅助模型;
其中,所述第一类信息用于表征涉黄类型的不良信息。
6.根据权利要求1所述的方法,其特征在于,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型包括:
响应于所述不良信息为第二类信息,将所述样本数据输入预设的教师网络模型进行训练,以确定第二训练辅助模型;
其中,所述第二类信息用于表征涉暴和涉政类型的不良信息。
7.根据权利要求1所述的方法,其特征在于,所述将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型之前,所述方法还包括:
对所述样本图像进行归一化处理。
8.根据权利要求1所述的方法,其特征在于,所述教师网络模型采用分类神经网络。
9.根据权利要求1所述的方法,其特征在于,所述学生网络模型采用多目标输出网络或目标检测网络。
10.一种图像识别模型训练装置,其特征在于,所述装置包括:
获取单元,用于获取样本数据,所述样本数据包括样本图像和图像标注,所述图像标注用于表征所述样本图像中的不良信息分类;
辅助训练单元,用于将所述样本数据输入至预设的教师网络模型进行训练,以确定训练辅助模型;
图像识别单元,用于基于所述样本数据和训练辅助模型对预设的学生网络模型进行训练,以确定图像识别模型,其中,所述图像识别模型的损失函数根据所述训练辅助模型的输出结果和学生网络模型的输出结果确定,所述图像识别模型用于识别图像中的不良信息。
11.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现如权利要求1-9中任一项所述的方法。
12.一种电子设备,包括存储器和处理器,其特征在于,所述存储器用于存储一条或多条计算机程序指令,其中,所述一条或多条计算机程序指令被所述处理器执行以实现如权利要求1-9中任一项所述的方法。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-9中任一项所述的方法步骤。
CN202110482330.9A 2021-04-30 2021-04-30 图像识别模型训练方法、装置、电子设备和可读存储介质 Active CN113111968B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110482330.9A CN113111968B (zh) 2021-04-30 2021-04-30 图像识别模型训练方法、装置、电子设备和可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110482330.9A CN113111968B (zh) 2021-04-30 2021-04-30 图像识别模型训练方法、装置、电子设备和可读存储介质

Publications (2)

Publication Number Publication Date
CN113111968A true CN113111968A (zh) 2021-07-13
CN113111968B CN113111968B (zh) 2024-03-22

Family

ID=76720740

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110482330.9A Active CN113111968B (zh) 2021-04-30 2021-04-30 图像识别模型训练方法、装置、电子设备和可读存储介质

Country Status (1)

Country Link
CN (1) CN113111968B (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114067183A (zh) * 2021-11-24 2022-02-18 北京百度网讯科技有限公司 神经网络模型训练方法、图像处理方法、装置和设备
CN114092918A (zh) * 2022-01-11 2022-02-25 深圳佑驾创新科技有限公司 模型训练方法、装置、设备及存储介质
CN114299442A (zh) * 2021-11-15 2022-04-08 苏州浪潮智能科技有限公司 一种行人重识别方法、系统、电子设备及存储介质
CN114494800A (zh) * 2022-02-17 2022-05-13 平安科技(深圳)有限公司 预测模型训练方法、装置、电子设备及存储介质
CN116311102A (zh) * 2023-03-30 2023-06-23 哈尔滨市科佳通用机电股份有限公司 基于改进的知识蒸馏的铁路货车故障检测方法及系统
CN114494800B (zh) * 2022-02-17 2024-05-10 平安科技(深圳)有限公司 预测模型训练方法、装置、电子设备及存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107291737A (zh) * 2016-04-01 2017-10-24 腾讯科技(深圳)有限公司 敏感图像识别方法及装置
JP2017224027A (ja) * 2016-06-13 2017-12-21 三菱電機インフォメーションシステムズ株式会社 データのラベリングモデルに係る機械学習方法、コンピュータおよびプログラム
CN111160474A (zh) * 2019-12-30 2020-05-15 合肥工业大学 一种基于深度课程学习的图像识别方法
CN111476309A (zh) * 2020-04-13 2020-07-31 北京字节跳动网络技术有限公司 图像处理方法、模型训练方法、装置、设备及可读介质
CN111639710A (zh) * 2020-05-29 2020-09-08 北京百度网讯科技有限公司 图像识别模型训练方法、装置、设备以及存储介质
CN111814689A (zh) * 2020-07-09 2020-10-23 浙江大华技术股份有限公司 火灾识别网络模型的训练方法、火灾识别方法及相关设备
CN112001364A (zh) * 2020-09-22 2020-11-27 上海商汤临港智能科技有限公司 图像识别方法及装置、电子设备和存储介质

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107291737A (zh) * 2016-04-01 2017-10-24 腾讯科技(深圳)有限公司 敏感图像识别方法及装置
JP2017224027A (ja) * 2016-06-13 2017-12-21 三菱電機インフォメーションシステムズ株式会社 データのラベリングモデルに係る機械学習方法、コンピュータおよびプログラム
CN111160474A (zh) * 2019-12-30 2020-05-15 合肥工业大学 一种基于深度课程学习的图像识别方法
CN111476309A (zh) * 2020-04-13 2020-07-31 北京字节跳动网络技术有限公司 图像处理方法、模型训练方法、装置、设备及可读介质
CN111639710A (zh) * 2020-05-29 2020-09-08 北京百度网讯科技有限公司 图像识别模型训练方法、装置、设备以及存储介质
CN111814689A (zh) * 2020-07-09 2020-10-23 浙江大华技术股份有限公司 火灾识别网络模型的训练方法、火灾识别方法及相关设备
CN112001364A (zh) * 2020-09-22 2020-11-27 上海商汤临港智能科技有限公司 图像识别方法及装置、电子设备和存储介质

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114299442A (zh) * 2021-11-15 2022-04-08 苏州浪潮智能科技有限公司 一种行人重识别方法、系统、电子设备及存储介质
CN114067183A (zh) * 2021-11-24 2022-02-18 北京百度网讯科技有限公司 神经网络模型训练方法、图像处理方法、装置和设备
CN114067183B (zh) * 2021-11-24 2022-10-28 北京百度网讯科技有限公司 神经网络模型训练方法、图像处理方法、装置和设备
CN114092918A (zh) * 2022-01-11 2022-02-25 深圳佑驾创新科技有限公司 模型训练方法、装置、设备及存储介质
CN114494800A (zh) * 2022-02-17 2022-05-13 平安科技(深圳)有限公司 预测模型训练方法、装置、电子设备及存储介质
CN114494800B (zh) * 2022-02-17 2024-05-10 平安科技(深圳)有限公司 预测模型训练方法、装置、电子设备及存储介质
CN116311102A (zh) * 2023-03-30 2023-06-23 哈尔滨市科佳通用机电股份有限公司 基于改进的知识蒸馏的铁路货车故障检测方法及系统
CN116311102B (zh) * 2023-03-30 2023-12-15 哈尔滨市科佳通用机电股份有限公司 基于改进的知识蒸馏的铁路货车故障检测方法及系统

Also Published As

Publication number Publication date
CN113111968B (zh) 2024-03-22

Similar Documents

Publication Publication Date Title
WO2020221298A1 (zh) 文本检测模型训练方法、文本区域、内容确定方法和装置
CN106951825B (zh) 一种人脸图像质量评估系统以及实现方法
CN111259625B (zh) 意图识别方法、装置、设备及计算机可读存储介质
CN113111968A (zh) 图像识别模型训练方法、装置、电子设备和可读存储介质
US20170177972A1 (en) Method for analysing media content
CN110796199B (zh) 一种图像处理方法、装置以及电子医疗设备
CN111401201A (zh) 一种基于空间金字塔注意力驱动的航拍图像多尺度目标检测方法
CN112150821B (zh) 轻量化车辆检测模型构建方法、系统及装置
KR20190113119A (ko) 합성곱 신경망을 위한 주의집중 값 계산 방법
CN110929622A (zh) 视频分类方法、模型训练方法、装置、设备及存储介质
CN110851641B (zh) 跨模态检索方法、装置和可读存储介质
CN111783576A (zh) 基于改进型YOLOv3网络和特征融合的行人重识别方法
CN104616005A (zh) 一种领域自适应的人脸表情分析方法
CN112861917A (zh) 基于图像属性学习的弱监督目标检测方法
CN111079374A (zh) 字体生成方法、装置和存储介质
CN110738132A (zh) 一种具备判别性感知能力的目标检测质量盲评价方法
CN110533184B (zh) 一种网络模型的训练方法及装置
CN111967399A (zh) 一种基于改进的Faster RCNN行为识别方法
JP2019153092A (ja) 位置特定装置、位置特定方法及びコンピュータプログラム
CN113378919B (zh) 融合视觉常识和增强多层全局特征的图像描述生成方法
CN113780145A (zh) 精子形态检测方法、装置、计算机设备和存储介质
CN116363712B (zh) 一种基于模态信息度评估策略的掌纹掌静脉识别方法
CN112380861A (zh) 模型训练方法、装置及意图识别方法、装置
CN111582404B (zh) 内容分类方法、装置及可读存储介质
CN114445716A (zh) 关键点检测方法、装置、计算机设备、介质及程序产品

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant