CN114708467B - 基于知识蒸馏的不良场景识别方法及系统及设备 - Google Patents
基于知识蒸馏的不良场景识别方法及系统及设备 Download PDFInfo
- Publication number
- CN114708467B CN114708467B CN202210101442.XA CN202210101442A CN114708467B CN 114708467 B CN114708467 B CN 114708467B CN 202210101442 A CN202210101442 A CN 202210101442A CN 114708467 B CN114708467 B CN 114708467B
- Authority
- CN
- China
- Prior art keywords
- model
- picture
- training
- bad
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
基于知识蒸馏的不良场景识别方法及系统及设备,包括以下步骤:步骤1,不良场景图片采集及数据集构建步骤2,不平衡数据增强操作步骤3,不良场景图片识别模型建立;步骤4,模型规模压缩及吞吐率提升;步骤5,不良场景图片识别:对于需要识别的图片p,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片。本发明利用数据增强、分设权重等方式处理数据不平衡问题,基于图像特征信息提取提高模型对不同不良场景图片类别的识别能力,并基于知识蒸馏提高模型的吞吐率,具有信息挖掘充分、性能稳健、识别效率高等优点,使得本发明较其他不良场景识别的方法具有明显的优势。
Description
技术领域
本发明涉及不良场景图片识别领域,具体涉及基于知识蒸馏的不良场景识别方法及系统及设备。
背景技术
近年来,随着网络社交媒体的普及,网络图片的监管逐渐成为社会安全领域的一大挑战。能够及时有效地发现识别涉及不良场景的图片,是应对该挑战的现实需要。现有的识别不良场景的方法主要分为两类,一类是人工审核的方式,一类是图片识别模型结合人工审核的方式。其中,人工审核的方式存在着识别精度低、效率低下、成本较高等缺陷。其次,在不良场景识别的领域,图片识别模型往往对数据有着很强的依赖性,而现实中不同场景的图片获取途径比较困难,数量分布十分不均,模型的识别性能也会因此受到影响;此外,图片识别模型的性能和其结构复杂程度是正相关的,模型的吞吐率与性能之间存在着冲突。所以,亟需一种新的可以应对不平衡数据且有着足够吞吐率的不良场景识别方法。在数据挖掘领域,有着许多处理样本数据分布不平衡的方法。同时,也存在着一些压缩模型规模的方式。通过利用这些方法,图片识别模型的性能和效率得以提升,可以更高效地识别更多的不良场景图片。
现在存在着大量有关图片内容检测的工作。
现有技术1提出了一种针对新闻场景的景别识别方法,主要包括:首先,构建新闻场景的景别识别数据集、场景识别数据集和目标检测数据集;然后分别训练场景识别网络和目标检测网络;最后,将图像进行编码,输入到训练后的模型中进行识别。
现有技术2提出了一种场景识别方法,方法主要包括:调用场景特征提取网络和场景预测网络,基于第一驾驶场景的第一场景序列进行场景预测,得到第二场景序列;基于第二场景序列和第一驾驶场景的第三场景序列,训练场景特征提取网络和场景预测网络;调用训练后的场景特征提取网络和场景分类网络,基于第二驾驶场景的场景序列进行场景分类,得到预测类别标签;基于第二驾驶场景的场景类别标签和预测类别标签,训练场景分类网络;获取场景识别模型,场景识别模型包括训练后的场景特征提取网络和训练后的场景分类网络。
上述基于知识蒸馏的场景识别方法都利用了有监督的图像识别模型,没有考虑数据不平衡的情况,可能会导致模型在某个类别的性能较差。此外,上述方法也没有考虑模型的规模和吞吐率,可能会导致模型在某些场景下难以适用。
发明内容
本发明的目的在于提供基于知识蒸馏的不良场景识别方法及系统及设备,以解决上述问题。
为实现上述目的,本发明采用以下技术方案:
基于知识蒸馏的不良场景识别方法,包括以下步骤:
步骤1,不良场景图片采集及数据集构建:以网络社交媒体网站为数据源,分别对不良场景进行图片爬取,并同时构建正常图片数据集,得到总数据集
步骤2,不平衡数据增强:对于样本数量小于100的类别,对其训练集中的图片分别进行增强操作,生成与其他不良场景类别数目近似的增强样本,扩充到总数聚集;
步骤3,不良场景图片识别模型建立:从步骤2所构建的数据集中抽取训练样本,构建和训练有监督的不良图片识别模型;
步骤4,模型规模压缩及吞吐率提升:利用知识蒸馏的方式,对训练好的模型进行模型压缩,提升模型的吞吐率;
步骤5,不良场景图片识别:对于需要识别的图片p,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片。
进一步的,步骤1中利用网络爬虫或网络平台提供的应用程序接口分别对不良场景进行图片爬取,不良场景包括吸烟、酗酒、吸毒和赌博。
进一步的,步骤2中不平衡数据增强,包括:首先对各个类别的样本进行随机抽取,按照8:2的比例划分训练集和验证集;之后,对于样本数量小于100的类别,对其训练集中的图片分别进行水平翻转、垂直翻转、添加噪音、旋转随机角度、模糊操作,生成与其他不良场景类别数目近似的增强样本,并加入到原来的训练集中,得到新的训练集。
进一步的,步骤3不良场景图片识别模型建立中,根据步骤2所构建的数据集得到训练样本数据集,利用基于交叉熵的损失函数和正则化项构建基于知识蒸馏的有监督图像分类模型,使用Y表示样本数据的标签信息,其中,对中图片pi,Yi=j表示样本pi属于第j个类别,j=0,1,2,3,4分别对应吸烟、酗酒、吸毒、赌博、正常五个类别的图片;对于每一张图片,首先将其分辨率转化为224×224,再对每一个像素值进行归一化处理;使用X表示训练数据的数据矩阵;选取ResNet152作为图像特征提取模型;在获取了每张图片的特征信息之后,将这些信息输入到分类模型中,最终得到每张图片的类别信息;将ResNet152模型的参数标记为W152,交叉熵函数为CE(·),则训练的目的是得到将数据矩阵X映射到标注信息矩阵Y的W152,训练方式为:
式中α为正则化项参数,‖·‖1为矩阵的1范数。
进一步的,模型具体的训练过程为:
(1)读入图片和标签信息,转换图片的分辨率并将像素点归一化,得到训练集的数据矩阵X;
(2)将数据矩阵输入ResNet152网络,得到每张图片的类别,即:使用带权的交叉熵函数得到模型的训练误差,即:/>其中样本较少的类别具有较高的权重;
(3)通过训练误差更新模型参数,直至训练误差收敛,保存参数矩阵W152。
进一步的,步骤4模型规模压缩及吞吐率提升中,主要包括:通过步骤3得到训练好的ResNet152模型,设置为教师模型;选用较小的ResNet18或ResNet34作为学生模型;并且设置一个具有三层神经网络的鉴别器;记教师模型的参数为Wt,学生模型的参数为Ws,鉴别器的参数为Wd,Y∈{Yt,Ys}分别表示概率来自于教师模型或学生模型,KL散度记为KL(·),二元交叉熵函数记为BCE(·),则训练方式为:
式中,是鉴别器误差的权重,σ表示sigmoid函数。
进一步的,知识蒸馏的具体训练过程为:
(1)将训练集的数据矩阵X输入到教师模型中,得到输出XWt;
(2)将XWt与学生模型的输出概率XWs进行比较,将两个概率输入到KL散度中,得到KL误差;
(3)鉴别器以输出概率作为输入,鉴别输入来自于哪一个模型;
(4)通过训练误差和鉴别器误差更新学生模型的参数,通过鉴别器误差更新鉴别器的参数,直至两个误差收敛,保留学生模型的系数矩阵Ws。
进一步的,步骤5不良场景图片识别中,对于需要识别的图片p,在预处理后,输入到步骤4中训练好的轻量级学生模型中,判断其是否是不良场景的图片;之后,通过有监督的不良场景识别模型,得到目标图片p的预测标签y=j,当j∈{0,1,2,3}时,则判定该图片为不良场景图片;否则,该图片为正常图片。
进一步的,基于知识蒸馏的不良场景图片识别系统,包括:
不良场景图片采集及数据集构建模块,用于以网络社交媒体网站为数据源,分别对不良场景进行图片爬取,并同时构建正常图片数据集,得到总数据集
数据增强模块,用于对于样本数量比较小的类别,对其训练集中的图片分别进行增强操作,生成与其他不良场景类别数目近似的增强样本,扩充到总数聚集;
不良场景图片识别模型建立模块,用于从所构建的数据集中抽取训练样本,构建和训练有监督的不良图片识别模型;
模型规模压缩模块,用于利用知识蒸馏的方式,对训练好的模型进行模型压缩,提升模型的吞吐率;
不良场景图片识别模块,用于对于需要识别的图片p,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片。
进一步的,一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现基于知识蒸馏的不良场景图片识别方法的步骤。
与现有技术相比,本发明有以下技术效果:
通过获取网络平台图片中丰富的特征信息来识别不良场景图片,并使用知识蒸馏的方法提升模型的吞吐率。首先,对吸烟、酗酒、吸毒、赌博四个不良场景在各个网络平台上进行采集,构建不良场景和正常的数据集;其次,利用数据增强的方式,扩充不平衡的不良场景类别图片;之后,利用所构建的数据集训练有监督的不良场景图片识别模型;然后,利用知识蒸馏的方式压缩模型的规模,提升模型的吞吐率;最后,利用得到的不良场景图片识别模型对未知的图片进行类别的识别。本发明利用数据增强、分设权重等方式处理数据不平衡问题,基于图像特征信息提取提高模型对不同不良场景图片类别的识别能力,并基于知识蒸馏提高模型的吞吐率,具有信息挖掘充分、性能稳健、识别效率高等优点,使得本发明较其他不良场景识别的方法具有明显的优势。
本发明可以在不需要人工监管的情况下使用,节约了人力物力成本,并提高了审查效率;通过数据增强和调整误差权重的方式可以改善这种情况,提升模型的总体识别性能。通过知识蒸馏的方式可以在较少地降低性能的同时显著地压缩模型的规模,从而提高模型的适应能力,减少所需使用成本。
附图说明
图1是本发明基于知识蒸馏的不良场景识别方法框图。
图2是数据采集过程的流程图。
图3是不平衡数据增强过程的流程图。
图4是识别模型训练过程流程图。
图5是模型压缩过程的流程图。
图6是不良场景图片识别的流程图。
具体实施方式
以下结合附图及实施例对本发明的实施方式进行详细说明。需要说明的是,此处描述的实施例只用以解释本发明,并不用于限定本发明。此外,在不冲突的情况下,本发明中的实施例涉及的技术特征可以相互结合。
本发明的目的是提供一种基于知识蒸馏的不良场景识别方法,通过获取网络平台图片中丰富的特征信息来识别不良场景图片,并使用知识蒸馏的方法提升模型的吞吐率。首先,对吸烟、酗酒、吸毒、赌博四个不良场景在各个网络平台上进行采集,构建不良场景和正常的数据集;其次,利用数据增强的方式,扩充不平衡的不良场景类别图片;之后,利用所构建的数据集训练有监督的不良场景图片识别模型;然后,利用知识蒸馏的方式压缩模型的规模,提升模型的吞吐率;最后,利用得到的不良场景图片识别模型对未知的图片进行类别的识别。本发明利用数据增强、分设权重等方式处理数据不平衡问题,基于图像特征信息提取提高模型对不同不良场景图片类别的识别能力,并基于知识蒸馏提高模型的吞吐率,具有信息挖掘充分、性能稳健、识别效率高等优点,使得本发明较其他不良场景识别的方法具有明显的优势。
本发明的具体实施过程包括数据采集过程、数据增强过程、模型建立过程、模型压缩过程、不良场景识别过程。图1是本发明基于知识蒸馏的不良场景图片识别方法框图。
1.数据采集过程
数据获取的具体过程如下:
(1)通过爬虫技术,根据不良场景类别的相关关键字进行图片爬取。在爬取时,可以使用如“smoking cigarette”、“taking drug”、“play mahjong”等不良场景相关标签进行爬取;在爬取正常图片时,可以通过随机的方式对目标网页进行爬取。
(2)对于不良场景的图片,分别对吸烟、酗酒、吸毒、赌博四个典型的不良场景进行图片的爬取,并对爬取到的图片进行去重处理。得到四个类别的数据集合
(3)对于正常的图片,需要去除涉及不良场景的图片。并且,为了模拟真实世界,需要保证正常图片的数目远大于不良场景图片的数目。最终得到正常图片集合
以上的步骤流程如图2所示,从而得到图片集合
2.不平衡数据增强过程
对数据采集过程所构建的数据集进行数据分析。首先对各个类别的样本进行随机抽取,按照8:2的比例划分训练集和验证集。之后,对于样本数量比较小的类别,对其训练集中的图片分别进行水平翻转、垂直翻转、添加噪音、旋转随机角度、模糊等操作,生成与其他不良场景类别数目近似的增强样本,并加入到原来的训练集中,得到新的训练集。最终,将样本数量较少的类别图片数量扩充至原来的4-8倍,使得不良场景的各个类别数目接近一致。并且,记录样本较少的类别,在后续模型训练时加大对应的误差权重。该过程的流程图如图3所示。
3.识别模型建立过程
根据数据增强后的数据集得到训练样本数据集,利用基于交叉熵的损失函数和正则化项构建基于知识蒸馏的有监督的图像分类模型。使用Y表示样本数据的标签信息,其中,对中图片pi,Yi=j表示样本pi属于第j个类别,j=0,1,2,3,4分别对应吸烟、酗酒、吸毒、赌博、正常五个类别的图片。对于每一张图片,首先将其分辨率转化为224×224,再对每一个像素值进行归一化处理。使用X表示训练数据的数据矩阵。选取ResNet(ResidualNetwork,ResNet)作为图像特征提取模型。考虑到网络层数越深,特征表达能力越强,选用ResNet152模型。在获取了每张图片的特征信息之后,将这些信息输入到分类模型中,最终得到每张图片的类别信息。将ResNet152模型的参数标记为W152,交叉熵函数为CE(·),则训练的目的是得到可以将数据矩阵X映射到标注信息矩阵Y的W152,训练方式为:
式中α为正则化项参数,‖·‖1为矩阵的1范数。模型具体的训练过程为:
(1)读入图片和标签信息,转换图片的分辨率并将像素点归一化,得到训练集的数据矩阵X;
(2)将数据矩阵输入ResNet152网络,得到每张图片的类别,即:使用带权的交叉熵函数得到模型的训练误差,即:/>其中样本较少的类别具有较高的权重;
(3)通过训练误差更新模型参数,直至训练误差收敛,保存参数矩阵W152。
上述识别模型的训练过程如图4所示。
4.模型压缩过程
通过步骤3得到训练好的ResNet152模型,设置为老师模型;选用较小的ResNet18或ResNet34作为学生模型;并且设置一个具有三层神经网络的鉴别器。记教师模型的参数为Wt(即步骤3中得到的W152,不参与训练),学生模型的参数为Ws,鉴别器的参数为Wd,Y∈{Yt,Ys}分别表示概率来自于教师模型或学生模型,KL散度记为KL(·),二元交叉熵函数记为BCE(·),则训练方式为:
式中,是鉴别器误差的权重,σ表示sigmoid函数。知识蒸馏的具体训练过程为:
(1)将训练集的数据矩阵X输入到教师模型中,得到输出XWt;
(2)将XWt与学生模型的输出概率XWs进行比较。将两个概率输入到KL散度中,得到KL误差;
(3)鉴别器以输出概率作为输入,鉴别输入来自于哪一个模型;
(4)通过训练误差和鉴别器误差更新学生模型的参数,通过鉴别器误差更新鉴别器的参数,直至两个误差收敛,保留学生模型的系数矩阵Ws。
以上模型蒸馏过程的训练流程图如图5所示。
5.不良场景识别过程
对于需要识别的图片p,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片。通过有监督的不良场景识别模型系数矩阵Ws,可以得到目标图片p的预测标签y=j,当j∈{0,1,2,3}时,则判定该图片为不良场景图片;否则,该图片为正常图片。该识别过程如图6所示。
本发明再一实施例中,提供一种基于知识蒸馏的不良场景识别系统,能够用于实现上述的基于知识蒸馏的不良场景识别方法,具体的,该基于知识蒸馏的不良场景识别系统包括:
不良场景图片采集及数据集构建模块,用于以网络社交媒体网站为数据源,分别对不良场景进行图片爬取,并同时构建正常图片数据集,得到总数据集
数据增强模块,用于对于样本数量比较小的类别,对其训练集中的图片分别进行增强操作,生成与其他不良场景类别数目近似的增强样本,扩充到总数聚集;
不良场景图片识别模型建立模块,用于从所构建的数据集中抽取训练样本,构建和训练有监督的不良图片识别模型;
模型规模压缩模块,用于利用知识蒸馏的方式,对训练好的模型进行模型压缩,提升模型的吞吐率;
不良场景图片识别模块,用于对于需要识别的图片p,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片。
本发明再一个实施例中,提供了一种计算机设备,该计算机设备包括处理器以及存储器,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器用于执行所述计算机存储介质存储的程序指令。处理器可能是中央处理单元(CentralProcessing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital SignalProcessor、DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable GateArray,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其是终端的计算核心以及控制核心,其适于实现一条或一条以上指令,具体适于加载并执行计算机存储介质内一条或一条以上指令从而实现相应方法流程或相应功能;本发明实施例所述的处理器可以用于知识蒸馏的不良场景识别方法的操作。
Claims (6)
1.基于知识蒸馏的不良场景识别方法,其特征在于,包括以下步骤:
步骤1,不良场景图片采集及数据集构建:以网络社交媒体网站为数据源,分别对不良场景进行图片爬取,并同时构建正常图片数据集,得到总数据集;
步骤2,不平衡数据增强:对于样本数量小于100的类别,对其训练集中的图片分别进行增强操作,生成与其他不良场景类别数目近似的增强样本,扩充到总数聚集;
步骤3,不良场景图片识别模型建立:从步骤2所构建的数据集中抽取训练样本,构建和训练有监督的不良图片识别模型;
步骤4,模型规模压缩及吞吐率提升:利用知识蒸馏的方式,对训练好的模型进行模型压缩,提升模型的吞吐率;
步骤5,不良场景图片识别:对于需要识别的图片,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片;
步骤3不良场景图片识别模型建立中,根据步骤2所构建的数据集得到训练样本数据集,利用基于交叉熵的损失函数和正则化项构建基于知识蒸馏的有监督图像分类模型,使用表示样本数据的标签信息,其中,对/>中图片/>,/>表示样本/>属于第/>个类别,分别对应吸烟、酗酒、吸毒、赌博、正常五个类别的图片;对于每一张图片,首先将其分辨率转化为224×224,再对每一个像素值进行归一化处理;使用/>表示训练数据的数据矩阵;选取ResNet152作为图像特征提取模型;在获取了每张图片的特征信息之后,将这些信息输入到分类模型中,最终得到每张图片的类别信息;将ResNet152模型的参数标记为/>,交叉熵函数为/>,则训练的目的是得到将数据矩阵/>映射到标注信息矩阵/>的/>,训练方式为:
式中为正则化项参数,/>为矩阵的1范数;
模型具体的训练过程为:
(1)读入图片和标签信息,转换图片的分辨率并将像素点归一化,得到训练集的数据矩阵;
(2)将数据矩阵输入ResNet152网络,得到每张图片的类别,即:;使用带权的交叉熵函数得到模型的训练误差,即:/>,其中样本较少的类别具有较高的权重;
(3)通过训练误差更新模型参数,直至训练误差收敛,保存参数矩阵;
步骤4模型规模压缩及吞吐率提升中,主要包括:通过步骤3得到训练好的ResNet152模型,设置为教师模型;选用较小的ResNet18或ResNet34作为学生模型;并且设置一个具有三层神经网络的鉴别器;记教师模型的参数为,学生模型的参数为/>,鉴别器的参数为,/>分别表示概率来自于教师模型或学生模型,/>散度记为/>,二元交叉熵函数记为/>,则训练方式为:
式中,是鉴别器误差的权重,/>表示sigmoid函数;
知识蒸馏的具体训练过程为:
(1)将训练集的数据矩阵输入到教师模型中,得到输出/>;
(2)将与学生模型的输出概率/>进行比较,将两个概率输入到KL散度中,得到误差;
(3)鉴别器以输出概率作为输入,鉴别输入来自于哪一个模型;
(4)通过训练误差和鉴别器误差更新学生模型的参数,通过鉴别器误差更新鉴别器的参数,直至两个误差收敛,保留学生模型的系数矩阵。
2.根据权利要求1中所述的基于知识蒸馏的不良场景图片识别方法,其特征在于,步骤1中利用网络爬虫或网络平台提供的应用程序接口分别对不良场景进行图片爬取,不良场景包括吸烟、酗酒、吸毒和赌博。
3.根据权利要求1中所述的基于知识蒸馏的不良场景图片识别方法,其特征在于,步骤2中不平衡数据增强,包括:首先对各个类别的样本进行随机抽取,按照8:2的比例划分训练集和验证集;之后,对于样本数量小于100的类别,对其训练集中的图片分别进行水平翻转、垂直翻转、添加噪音、旋转随机角度、模糊操作,生成与其他不良场景类别数目近似的增强样本,并加入到原来的训练集中,得到新的训练集。
4.根据权利要求1中所述的基于知识蒸馏的不良场景图片识别方法,其特征在于,步骤5不良场景图片识别中,对于需要识别的图片,在预处理后,输入到步骤4中训练好的轻量级学生模型中,判断其是否是不良场景的图片;之后,通过有监督的不良场景识别模型,得到目标图片/>的预测标签/>,当/>时,则判定该图片为不良场景图片;否则,该图片为正常图片。
5.基于知识蒸馏的不良场景图片识别系统,其特征在于,包括:
不良场景图片采集及数据集构建模块,用于以网络社交媒体网站为数据源,分别对不良场景进行图片爬取,并同时构建正常图片数据集,得到总数据集;
数据增强模块,用于对于样本数量比较小的类别,对其训练集中的图片分别进行增强操作,生成与其他不良场景类别数目近似的增强样本,扩充到总数聚集;
不良场景图片识别模型建立模块,用于从所构建的数据集中抽取训练样本,构建和训练有监督的不良图片识别模型;
模型规模压缩模块,用于利用知识蒸馏的方式,对训练好的模型进行模型压缩,提升模型的吞吐率;
不良场景图片识别模块,用于对于需要识别的图片,在预处理后,输入到步骤4中训练好的识别模型中,判断其是否是不良场景的图片;
不良场景图片识别模型建立中,根据步骤2所构建的数据集得到训练样本数据集,利用基于交叉熵的损失函数和正则化项构建基于知识蒸馏的有监督图像分类模型,使用表示样本数据的标签信息,其中,对/>中图片/>,/>表示样本/>属于第/>个类别,分别对应吸烟、酗酒、吸毒、赌博、正常五个类别的图片;对于每一张图片,首先将其分辨率转化为224×224,再对每一个像素值进行归一化处理;使用/>表示训练数据的数据矩阵;选取ResNet152作为图像特征提取模型;在获取了每张图片的特征信息之后,将这些信息输入到分类模型中,最终得到每张图片的类别信息;将ResNet152模型的参数标记为/>,交叉熵函数为/>,则训练的目的是得到将数据矩阵/>映射到标注信息矩阵/>的/>,训练方式为:
式中为正则化项参数,/>为矩阵的1范数;
模型具体的训练过程为:
(1)读入图片和标签信息,转换图片的分辨率并将像素点归一化,得到训练集的数据矩阵;
(2)将数据矩阵输入ResNet152网络,得到每张图片的类别,即:;使用带权的交叉熵函数得到模型的训练误差,即:/>,其中样本较少的类别具有较高的权重;
(3)通过训练误差更新模型参数,直至训练误差收敛,保存参数矩阵;
模型规模压缩及吞吐率提升中,主要包括:通过得到训练好的ResNet152模型,设置为教师模型;选用较小的ResNet18或ResNet34作为学生模型;并且设置一个具有三层神经网络的鉴别器;记教师模型的参数为,学生模型的参数为/>,鉴别器的参数为/>,分别表示概率来自于教师模型或学生模型,/>散度记为/>,二元交叉熵函数记为/>,则训练方式为:
式中,是鉴别器误差的权重,/>表示sigmoid函数;
知识蒸馏的具体训练过程为:
(1)将训练集的数据矩阵输入到教师模型中,得到输出/>;
(2)将与学生模型的输出概率/>进行比较,将两个概率输入到KL散度中,得到误差;
(3)鉴别器以输出概率作为输入,鉴别输入来自于哪一个模型;
(4)通过训练误差和鉴别器误差更新学生模型的参数,通过鉴别器误差更新鉴别器的参数,直至两个误差收敛,保留学生模型的系数矩阵。
6.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至4任一项所述基于知识蒸馏的不良场景图片识别方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210101442.XA CN114708467B (zh) | 2022-01-27 | 2022-01-27 | 基于知识蒸馏的不良场景识别方法及系统及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210101442.XA CN114708467B (zh) | 2022-01-27 | 2022-01-27 | 基于知识蒸馏的不良场景识别方法及系统及设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114708467A CN114708467A (zh) | 2022-07-05 |
CN114708467B true CN114708467B (zh) | 2023-10-13 |
Family
ID=82166821
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210101442.XA Active CN114708467B (zh) | 2022-01-27 | 2022-01-27 | 基于知识蒸馏的不良场景识别方法及系统及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114708467B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111709476A (zh) * | 2020-06-17 | 2020-09-25 | 浪潮集团有限公司 | 一种基于知识蒸馏的小分类模型训练方法及装置 |
WO2020248471A1 (zh) * | 2019-06-14 | 2020-12-17 | 华南理工大学 | 一种基于集聚交叉熵损失函数的序列识别方法 |
CN113592007A (zh) * | 2021-08-05 | 2021-11-02 | 哈尔滨理工大学 | 一种基于知识蒸馏的不良图片识别系统、方法、计算机及存储介质 |
WO2021248868A1 (zh) * | 2020-09-02 | 2021-12-16 | 之江实验室 | 基于知识蒸馏的预训练语言模型的压缩方法及平台 |
-
2022
- 2022-01-27 CN CN202210101442.XA patent/CN114708467B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020248471A1 (zh) * | 2019-06-14 | 2020-12-17 | 华南理工大学 | 一种基于集聚交叉熵损失函数的序列识别方法 |
CN111709476A (zh) * | 2020-06-17 | 2020-09-25 | 浪潮集团有限公司 | 一种基于知识蒸馏的小分类模型训练方法及装置 |
WO2021248868A1 (zh) * | 2020-09-02 | 2021-12-16 | 之江实验室 | 基于知识蒸馏的预训练语言模型的压缩方法及平台 |
CN113592007A (zh) * | 2021-08-05 | 2021-11-02 | 哈尔滨理工大学 | 一种基于知识蒸馏的不良图片识别系统、方法、计算机及存储介质 |
Non-Patent Citations (2)
Title |
---|
余胜 ; 陈敬东 ; 王新余 ; .基于深度学习的复杂场景下车辆识别方法.计算机与数字工程.2018,(第09期),全文. * |
凌弘毅 ; .基于知识蒸馏方法的行人属性识别研究.计算机应用与软件.2018,(第10期),全文. * |
Also Published As
Publication number | Publication date |
---|---|
CN114708467A (zh) | 2022-07-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112580439B (zh) | 小样本条件下的大幅面遥感图像舰船目标检测方法及系统 | |
CN112101190B (zh) | 一种遥感图像分类方法、存储介质及计算设备 | |
CN108664996B (zh) | 一种基于深度学习的古文字识别方法及系统 | |
CN112734775B (zh) | 图像标注、图像语义分割、模型训练方法及装置 | |
Kadam et al. | Detection and localization of multiple image splicing using MobileNet V1 | |
CN112418292B (zh) | 一种图像质量评价的方法、装置、计算机设备及存储介质 | |
CN111797326A (zh) | 一种融合多尺度视觉信息的虚假新闻检测方法及系统 | |
CN114549913B (zh) | 一种语义分割方法、装置、计算机设备和存储介质 | |
WO2024041479A1 (zh) | 一种数据处理方法及其装置 | |
TWI803243B (zh) | 圖像擴增方法、電腦設備及儲存介質 | |
CN114429577B (zh) | 一种基于高置信标注策略的旗帜检测方法及系统及设备 | |
CN117036843A (zh) | 目标检测模型训练方法、目标检测方法和装置 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
CN115292538A (zh) | 一种基于深度学习的地图线要素提取方法 | |
CN111445545B (zh) | 一种文本转贴图方法、装置、存储介质及电子设备 | |
CN117152438A (zh) | 一种基于改进DeepLabV3+网络的轻量级街景图像语义分割方法 | |
CN117349402A (zh) | 一种基于机器阅读理解的情绪原因对识别方法及系统 | |
CN114708467B (zh) | 基于知识蒸馏的不良场景识别方法及系统及设备 | |
KR102026280B1 (ko) | 딥 러닝을 이용한 씬 텍스트 검출 방법 및 시스템 | |
CN116257609A (zh) | 基于多尺度文本对齐的跨模态检索方法及系统 | |
CN114913382A (zh) | 一种基于CBAM-AlexNet卷积神经网络的航拍场景分类方法 | |
CN112801153B (zh) | 一种嵌入lbp特征的图的半监督图像分类方法及系统 | |
CN114896594A (zh) | 基于图像特征多注意力学习的恶意代码检测装置及方法 | |
CN109146058B (zh) | 具有变换不变能力且表达一致的卷积神经网络 | |
CN112686277A (zh) | 模型训练的方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |