CN116935188B - 模型训练方法、图像识别方法、装置、设备及介质 - Google Patents

模型训练方法、图像识别方法、装置、设备及介质 Download PDF

Info

Publication number
CN116935188B
CN116935188B CN202311193895.0A CN202311193895A CN116935188B CN 116935188 B CN116935188 B CN 116935188B CN 202311193895 A CN202311193895 A CN 202311193895A CN 116935188 B CN116935188 B CN 116935188B
Authority
CN
China
Prior art keywords
image
training
model
teacher model
source domain
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311193895.0A
Other languages
English (en)
Other versions
CN116935188A (zh
Inventor
王强
鄢科
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202311193895.0A priority Critical patent/CN116935188B/zh
Publication of CN116935188A publication Critical patent/CN116935188A/zh
Application granted granted Critical
Publication of CN116935188B publication Critical patent/CN116935188B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/042Knowledge-based neural networks; Logical representations of neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/094Adversarial learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/762Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Multimedia (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例提供了一种模型训练方法、图像识别方法、装置、设备及介质,用于降低训练成本,并提升图像识别模型的识别鲁棒性,可应用于人工智能、云技术、智慧交通、辅助驾驶等各种场景。方法包括:获取具有相同的网络结构的初始学生模型和初始教师模型以及第一训练图像;基于第一训练图像对初始学生模型进行训练得到具有第一网络参数的学生模型;基于第二训练图像对初始教师模型进行迭代训练得到教师模型,同时在迭代训练过程中利用指数滑动平均根据教师模型的网络参数更新学生模型的网络参数;直到迭代训练达到教师模型的收敛条件,输出目标学生模型。

Description

模型训练方法、图像识别方法、装置、设备及介质
技术领域
本申请涉及人工智能领域,尤其涉及一种模型训练方法、图像识别方法、装置、设备及介质。
背景技术
人工智能(Artificial Intelligence,AI)涵盖计算机视觉(Computer Vision,CV),CV技术也是AI的主要研究领域。CV技术可应用于图像分类任务,图像识别任务以及图像检索任务等,其中,在图像识别任务中,可利用训练好的分类模型判定它所属类别。
在图像识别任务中,增强模型识别鲁棒性的方案包括如下两种:方案1、在训练过程中对源域图像进行随机图像增强,分类模型同时从源域和对抗域图像学习识别能力;方案2、对该分类模型采用知识蒸馏的方案,即从具有更好图像识别能力的大模型中学习得到相应的图像识别能力,从而提高模型的识别鲁棒性。
但是上述两种方法均存在其对应的缺点:方案1会显著影响分类模型在源域的识别能力;方案2中大模型的训练成本较大。因此目前急需一种更合适的图像识别模型。
发明内容
本申请实施例提供了一种模型训练方法、图像识别方法、装置、设备及介质,用于降低训练成本,并提升图像识别模型的识别鲁棒性。
有鉴于此,本申请一方面提供一种模型训练方法,包括:获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
本申请另一方面提供一种模型训练装置,包括:获取模块,用于获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;
处理模块,用于基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;
输出模块,用于在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该处理模块,用于基于第一训练子集对该初始教师模型进行训练得到第一教师模型,该第一教师模型具有第二网络参数,该第一训练子集包含于该对抗域训练图像;
利用指数滑动平均根据该第一网络参数更新该学生模型的第二网络参数,以得到该学生模型的第三网络参数,该第二网络参数为基于该源域训练图像训练该初始学生模型得到;
基于第二训练子集对该第一教师模型进行训练得到第二教师模型,该第二教师模型具有第四网络参数,该第二训练子集包含于该对抗域训练图像;
利用指数滑动平均根据该第四网络参数更新该第三网络参数得到该学生模型的第五网络参数;
重复上述操作,在训练损失满足收敛条件,得到该教师模型和该目标学生模型。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,处理模块,用于获取该第一训练子集以及第一类中心矩阵,该第一训练子集包括第一样本图像和该第一样本图像对应的第一图像标签,该第一类中心矩阵用于指示该源域训练图像中对应的各个类别的特征中心;
调用该初始教师模型对该第一训练子集进行图像识别,以得到第一图像特征以及第一预测图像标签;
根据该第一预测图像标签与该第一图像标签进行损失计算得到第一损失值,并根据该第一图像特征与第一类中心向量进行距离度量得到第二损失值,该第一类中心向量为该第一预测图像标签所处类别的类中心向量,该第一类中心向量包含于该第一类中心矩阵;
根据该第一损失值和该第二损失值反向梯度传播更新该初始教师模型的网络参数,以得到该第一教师模型。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一图像特征和该第一类中心向量利用指数滑动平均更新该第一类中心矩阵,以得到第二类中心矩阵。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,处理模块,用于获取该第二训练子集以及该第二类中心矩阵,该第二训练子集包括第二样本图像和该第二样本图像对应的第二图像标签;
调用该第一教师模型对该第二训练子集进行图像识别,以得到第二图像特征以及第二预测图像标签;
根据该第二预测图像标签与该第二图像标签进行损失计算得到第三损失值,并根据该第二图像特征与第二类中心向量进行距离度量得到第四损失值,该第二类中心向量为该第二预测图像标签所处类别的类中心向量,该第二类中心向量包含于该第二类中心矩阵;
根据该第三损失值和该第四损失值反向梯度传播更新该第一教师模型的网络参数,以得到该第二教师模型。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,处理模块,用于对该源域训练图像进行图像增强处理生成该对抗域训练图像;
从该对抗域训练图像中进行采样得到该第一训练子集;
或者,
处理模块,用于从该源域训练图像中进行采样得到第一源域训练子集;
对该第一源域训练子集进行图像增强处理得到该第一训练子集。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,处理模块,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
根据该图像特征计算分布概率,以得到该源域训练图像的N个类别,该N为正整数;
获取该N个类别的N个特征中心向量;
根据该N个特征中心向量生成该第一类中心矩阵。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,处理模块,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
对该图像特征进行聚类计算,以得到该源域训练图像的N个类别,该N为正整数;
获取该N个类别的N个特征中心向量;
将该N个特征中心向量作为该第一类中心矩阵。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据该第一预测图像标签与该第一图像标签进行逻辑回归损失计算得到第一损失值。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该处理模块,用于根据该第一图像特征与第一类中心向量进行均方误差MSE损失计算得到第二损失值;
或者,
根据该第一图像特征与第一类中心向量进行平均绝对值误差L1损失计算得到第二损失值;
或者,
根据该第一图像特征与第一类中心向量进行L1-smooth损失计算得到第二损失值。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该处理模块,用于基于源域训练图像对该初始学生模型进行全监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行半监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行弱监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行无监督训练得到该学生模型。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该初始学生模型与该初始教师模型的网络结构为残差神经网络ResNet、ResNeSt、ResNeXt、RegNet、VGG、AlexNet、Transformer或者ViT。
在一种可能的设计中,在本申请实施例的另一方面的另一种实现方式中,该教师模型与该学生模型采用相同的训练方式。
本申请另一方面提供一种图像识别方法,包括:获取待处理图像;
调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别,该图像识别模型为采用上述任一项该的方法训练得到的目标学生模型;
输出该待处理图像的图像类别。
本申请的另一方面提供一种图像识别装置,包括:获取模块,用于获取待处理图像;
处理模块,用于调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别,该图像识别模型为上述任一项该的目标学生模型;
输出模块,用于输出该待处理图像的图像类别。
本申请另一方面提供一种计算机设备,包括:存储器、处理器以及总线系统;
其中,存储器用于存储程序;
处理器用于执行存储器中的程序,处理器用于根据程序代码中的指令执行上述各方面的方法;
总线系统用于连接存储器以及处理器,以使存储器以及处理器进行通信。
本申请的另一方面提供了一种计算机可读存储介质,计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述各方面的方法。
本申请的另一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述各方面所提供的方法。
从以上技术方案可以看出,本申请实施例具有以下优点:提供一组网络结构相同的学生模型和教师模型,其中,学生模型只在源域进行训练,从而获取较好的源域识别能力;而教师模型只在对抗域训练,从而获取对抗识别能力;然后根据教师模型的网络参数通过指数平滑平均的方式更新该学生模型的网络参数,使得学生模型可以不断积累对抗识别能力,同时保留了源域的识别能力,最终使得学生模型具有较高的识别鲁棒性。同时,该学生模型与该教师模型采用相同的网络结构,不需要大模型训练和知识蒸馏过程,减少了模型训练复杂度,从而降低训练成本。
附图说明
图1为本申请实施例中模型训练系统的一个架构示意图;
图2为本申请实施例中模型训练方法的一个实施例示意图;
图3为本申请实施例中模型训练方法的另一个实施例示意图;
图4为本申请实施例中图像识别方法的一个实施例示意图;
图5为本申请实施例中模型训练装置的一个实施例示意图;
图6为本申请实施例中图像识别装置的一个实施例示意图;
图7为本申请实施例中模型训练装置或者图像识别装置的另一个实施例示意图;
图8为本申请实施例中模型训练装置或者图像识别装置的另一个实施例示意图。
实施方式
本申请实施例提供了一种模型训练方法、图像识别方法、装置、设备及介质,用于降低训练成本,并提升图像识别模型的识别鲁棒性。
本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例例如能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“对应于”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
AI涵盖计算机视觉CV,CV技术也是AI的主要研究领域。CV技术可应用于图像识别任务,图像识别任务以及图像检索任务等,其中,在图像识别任务中,可利用训练好的分类模型判定它所属类别。在图像识别任务中,增强分类模型识别鲁棒性的方案包括如下两种:方案1、在训练过程中对源域图像进行随机图像增强,分类模型同时从源域和对抗域图像学习识别能力;方案2、对该分类模型采用知识蒸馏的方案,即从具有更好图像识别能力的大模型中学习得到相应的图像识别能力,从而提高分类模型的识别鲁棒性。但是上述两种方法均存在其对应的缺点:方案1会显著影响分类模型在源域的识别能力;方案2中大模型的训练成本较大。因此目前急需一种更合适的图像识别模型。
为了解决这一技术问题,本申请提供如下技术方案:获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。这样提供一组网络结构相同的学生模型和教师模型,其中,学生模型只在源域进行训练,从而获取较好的源域识别能力;而教师模型只在对抗域训练,从而获取对抗识别能力;然后根据教师模型的网络参数通过指数平滑平均的方式更新该学生模型的网络参数,使得学生模型可以不断积累对抗识别能力,同时保留了源域的识别能力,最终使得学生模型具有较高的识别鲁棒性。同时,该学生模型与该教师模型采用相同的网络结构,不需要大模型训练和知识蒸馏过程,减少了模型训练复杂度,从而降低训练成本。
本申请各可选实施例的模型训练方法以及图像识别方法是基于人工智能技术实现的。人工智能是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习、自动驾驶、智慧交通等几大方向。
计算机视觉(Computer Vision,CV)是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、光学字符识别(Optical Character Recognition,OCR)、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、三维(Three Dimensional,3D)技术、虚拟现实、增强现实、同步定位与地图构建、自动驾驶、智慧交通等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
本申请还涉及到云技术。其中,云技术(cloud technoolgy)是指在广域网或局域网内将硬件、软件、网络等系统资源统一起来,实现数据的计算、储存、处理和共享的一种托管技术。
云技术基于云计算商业模式应用的网络技术、信息技术、整合技术、管理平台技术、应用技术等的总称,可以组成资源池,按需所用,灵活便利。云计算技术将变成重要支撑。技术网络系统的后台服务需要大量的计算、存储资源,如视频网站、图片类网站和更多的门户网站。伴随着互联网行为的高度发展和应用,将来每个物品都有可能存在自己的识别标志,都需要传输到后台系统进行逻辑处理,不同程度级别的数据将会分开处理,各类行业数据皆需要强大的系统后盾支撑,只能通过云计算来实现。本申请中所涉及到的云技术主要指终端设备或者服务器之间可能通过“云”进行图像识别等等。
本申请还涉及到智慧交通、辅助驾驶等技术场景。智慧交通也可以称为智能交通系统(Intelligent Traffic System,ITS),是将先进的科学技术(信息技术、计算机技术、数据通信技术、传感器技术、电子控制技术、自动控制理论、运筹学、人工智能等)有效地综合运用于交通运输、服务控制和车辆制造,加强车辆、道路、使用者三者之间的联系,从而形成一种保障安全、提高效率、改善环境、节约能源的综合运输系统。
智能车路协同系统(Intelligent Vehicle Infrastructure CooperativeSystems,IVICS),简称车路协同系统,是智能交通系统(ITS)的一个发展方向。车路协同系统是采用先进的无线通信和新一代互联网等技术,全方位实施车车、车路动态实时信息交互,并在全时空动态交通信息采集与融合的基础上开展车辆主动安全控制和道路协同管理,充分实现人车路的有效协同,保证交通安全,提高通行效率,从而形成的安全、高效和环保的道路交通系统。本申请中所涉及到的智慧交通或者辅助驾驶主要指终端设备或者服务器可能通过图像识别等操作识别道路交通标志,从而实现人车路的有效协同。
为了方便理解,下面对本申请中的部分名词进行说明。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、式教学习等技术。随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如,常见的智能家居、智能穿戴设备、虚拟助理、智能音箱、智能营销、无人驾驶、自动驾驶、无人机、机器人、智能医疗、智能客服等,相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
神经网络:人工神经网络(Artificial Neural Networks,ANN),是由众多的神经元可调的连接权值连接而成,具有大规模并行处理、分布式信息存储、良好的自组织自学习能力等特点。
卷积层(Convolutional layer,Conv)是指卷积神经网络层中由若干卷积单元组成的层状结构,卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,卷积神经网络中包括至少两个神经网络层,其中,每一个神经网络层包含若干个神经元,各个神经元分层排列,同一层的神经元之间没有互相连接,层间信息的传送只沿一个方向进行。
反向传播:前向传播是指模型的前馈处理过程,反向传播与前向传播相反,指根据模型输出的结果对模型各个层的权重参数进行更新。例如,模型包括输入层、隐藏层和输出层,则前向传播是指按照输入层-隐藏层-输出层的顺序进行处理,反向传播是指按照输出层-隐藏层-输入层的顺序,依次更新各个层的权重参数。
ResNet网络:其网络结构通常是进行一个大尺度的卷积,再接一个池化层;随后接上连续几个子模块(DenseBlock和TransitinLayer);最后接上一个池化和全连接。以ResNet101为例,对其网络结构进行说明:ResNet101的层数为3+4+23+3=33个buildingblock,每个block为3层,所以有33x3=99层,再加上第一层的卷积conv1,以及最后的全连接层(用于分类),一共是99+1+1=101层。
ResNeSt网络:该网络结构是ResNet的一个变体。因此其网络结构一开始与ResNet类似,但是其引入了分离注意力模块(Split-Attention Networks),它可以跨特征图组实现信息交互。分离注意力模块是计算单元,由特征图组合分离主注意力操作组成。
ResNeXt网络:该网络结构是ResNet网络和Inception网络的结合体。其中,而Inception结构,首先通过1x1卷积来降低通道数把信息聚集一下,再进行不同尺度的特征提取以及池化,得到多个尺度的信息,最后将特征进行叠加输出。因此在结合了ResNet之后的ResNeXt每一个分支都采用相同的拓扑结构。ResNeXt的本质是分组卷积(GroupConvolution),通过变量基数(Cardinality)来控制组的数量。
RegNet网络:该网络结构的设计思路是去关注网络的设计空间,然后按照该思路可以一步步缩小设计空间至包含一组简单,常规,优秀的网络结构,称其为RegNet(Reg是指Regular)。RegNet网络的网络参数化思想十分简单,即优秀网络的宽度(width)和深度(depth)可以用一个量化线性函数来解释。具体来说,在设计之初,先设计一个AnyNet,其包括三个部分:1、Stem 一个简单的网络输入头;2、body 网络中主要的运算量都在这里;3、head 用于预测分类的输出头。然后将stem和head固定下来,并专注于网络body设计。因为body部分的参数量最多,运算量也多,这部分是决定网络准确性的关键。而Body结构,通常包含4个stage,每个stage都会进行降采样。而1个stage是由多个block进行堆叠得到的。按照这一思路,逐步缩小网络设计空间,从而得到RegNet 搜索空间。
VGG网络:该网络结构是卷积神经网络CNN的一种,其采用的是一种预训练(Pre-training)的方式,即先训练浅层的简单网络VGG11,再复用VGG11 的权重来初始化VGG13,如此反复训练并初始化 VGG19,能够使训练时收敛的速度更快。整个网络都使用卷积核尺寸为 3×3 和最大池化尺寸 2×2。比较常用的VGG-16的16指的是卷积层加全连接层(conv+fc)的总层数是16,是不包括最大池化层(max pool)的层数。
AlexNet网络:该网络结构是卷积神经网络CNN的一种,其整体的网络结构包括:1个输入层(input layer)、5个卷积层(C1、C2、C3、C4、C5)、2个全连接层(FC6、FC7)和1个输出层(output layer)。其中AlexNet的输入层的输入为RGB三通道的224×224×3大小的图像(也可填充为227×227×3)。AlexNet的5个卷积层中包含3个池化层,其中,每个卷积层都包含卷积核、偏置项、ReLU激活函数和局部响应归一化(LRN)模块。卷积层C1、C2、C5后面都跟着一个最大池化层,卷积层C3、C4、C5互相连接,中间没有接入池化层或归一化层。最终输出层为softmax,将网络输出转化为概率值,用于预测图像的类别。
Transformer网络:变形器,一种深度神经网络模型,由多头自注意力机制(multi-head self-attention,MHSA)(其包含多个self-Attention网络)和前馈神经网络(feed-forward network,FNN)交替堆叠而成。其中,该FNN可以由两层全连接层构成,激活函数是GELU(即Gaussian Error Linerar Units)。
ViT网络:视觉变形器,用于处理图像,是Transformer在计算机视觉上的变体。
监督学习:机器学习中的一种训练方式,是指利用一组已知类别的样本调整分类器的参数,使其达到所要求性能的过程,也称为监督训练或有教师学习,是从标记的训练数据来推断一个功能的机器学习任务。而监督学习的模型通常有两类。第一种是按模型形式分类:概率模型(ProbabilisticModel)和非概率模型(Non-probabilisticModel);第二种是按是否对观测变量的分布建模分类:判别模型(DiscriminativeModel)和生成模型(GenerativeModel)。
半监督学习:机器学习中的一种训练方式,其基本思路为在已标记的数据上训练,然后对未标注数据进行预测,取预测置信度最高的样本直接对其进行标签定义,然后将这类样本纳入当前训练样本中继续训练,直到模型的预测结果不再发生变化。
弱监督学习:机器学习中的一种训练方式,与传统的监督学习相比,其使用有限的、含有噪声的或者标注不准确的数据来进行模型参数的训练。
无监督学习:机器学习中的一种训练方式,本质上是一个统计手段,在没有标签的数据里可以发现潜在的一些结构的一种训练方式。它主要具备3个特点:1、无监督学习没有明确的目的;2、无监督学习不需要给数据打标签;3、无监督学习无法量化效果。而无监督学习通常可以应用于发现异常或者用户细分以及推荐系统等应用场景。而在学习过程中通常使用聚类或者降维算法。
指数滑动平均(Exponential Moving Average, EMA) :一种趋向类指标,指数滑动平均值是以指数形式递减加权的滑动平均。
类中心:是指在欧式空间中,分类数据集每一个类别所有图像的深度特征的中心(一般为算术平均)。
源域、对抗域:源域(Source Domain)是指原始的分类图像集;对抗域(Adversarial Domain)是指源域图像经过了随机图像增强后生成的类别不变、但图像内容变化的图像。
学生网络、教师网络:即本申请中的学生模型和教师模型。本申请中的学生网络与教师网络的网络结构是一组具有相同结构、不同初始化的孪生网络,区别于知识蒸馏(Knowledge Distillation, KD)中学生网络和教师网络的概念。其中学生网络具有源域的识别能力,教师网络负责学习对抗识别能力,并传递给学生网络,从而使学生网络在获取对抗识别能力的同时,不产生源域识别能力的灾难性遗忘。
图像增强:也可以理解为数据增强,是一种为了提高训练数据多样性同时不需要显示收集新数据的策略。也就是说,数据增强是不借助于收集更多新数据,而通过其他方式去提高数据多样性。不收集更多新数据就意味着节省大量人力标注的成本,更容易迁移到到更多任务或者领域。提高训练数据多样性就能让模型学习到更多丰富的数据模式,从而训练的到更佳鲁棒以及强大的模型(单一模式的训练数据是会损害模型性能的)。数据增强旨在提供一种收集到更多数据的选择(而不是传统的人工标注),理想中的数据增强应该同时兼顾容易扩充和提升模型性能。更多的,通过数据增强扩充的数据的分布既不应该跟原始的数据分布太过相似,那样就会导致明显的同质性,缺乏多样性,容易导致模型过拟合,也不应该跟原始的数据分布差的太多,那样会导致扩充数据没法代表领域,从而导致模型精度受损。而图像增强的方式包括但不限于对源域图像进行随机地旋转、平移、裁剪、缩放、擦除、色彩空间变换、对比度变换、锐度变换、高斯模糊等。
本申请实施例的方案适用于提升图像识别模型的识别鲁棒性,而图像识别模型可以应用于诸多计算机视觉领域,包括图像识别(例如人脸识别),目标检测(objectdetection)和语义分割(semantic segmentation),所以本申请实施例能够广泛加速多种应用场景的计算。
1、为内容平台提供内容审核服务
当利用本申请实施例提供的模型训练方法和图像识别方法为内容平台提供内容审核服务时,内容审核可以实现成为独立的内容审核程序,并安装在计算设备或者提供内容审核服务的后台服务器中。
在该场景下,内容平台的服务器接收用户发布的各种信息(例如XX软件的A用户发布的一篇笔记),该服务器利用图像识别模型识别出各种信息中的异常结果,然后将异常结果反馈至内容平台,由该内容平台确定对该异常结果进行相应的处理。
2、为安全部门提供图像检索服务
当利用本申请实施例提供的模型训练方法和图像识别方法为用户提供图像检索服务时,图像检索方法可以实现成为独立的信息检索程序,并安装在计算机设备或者提供信息检索服务的后台服务器中。
在该场景下,安全部门将想要查询的信息(例如目标人物的面部图像等)输入计算机设备,计算机设备根据检索信息,利用图像识别模型从海量的图像确定出具有目标人物的成像区域的图像,或者将检索信息发送至后台服务器,由后台服务器确定出具有目标人物的图像返回至天眼查询界面。
3、辅助医生进行疾病预测和治疗
当利用本申请实施例提供的模型训练方法和图像识别方法帮助用户进行疾病预测时,该方法可以实现成为独立的线上诊断类应用程序或健康类应用程序,安装在用户使用的计算机设备或者提供医学文本搜索服务的后台服务器中,方便用户使用该程序对疾病进行查询。
在该场景下,医生在应用程序界面输入患者的医疗图像,例如B超、彩超等,计算机设备将医疗图像输入图像识别模型,得到图像识别结果,并将该结果返回对应的应用程序界面,提示用户可能患有的疾病。
4、辅助驾驶员/自动驾驶车辆进行道路安全预警
当利用本申请实施例提供的模型训练方法和图像识别方法帮助驾驶员/自动驾驶车辆进行疾病预测时,该方法可以实现成为独立的导航类应用程序或自动驾驶类应用程序,并安装在用户使用的车载终端设备或者提供导航服务/自动驾驶服务的后台服务器中,帮助车辆安全行驶。
在该场景下,车载摄像头采集车辆前方的道路图像,将道路图像传输至车载终端或后台服务器,车载终端或后台服务器将道路图像输入图像识别模型根据道路图像识别出路上的行人等影响车辆正常行驶的对象,车载终端或后台服务器推送提示信息,提示驾驶员或控制车辆避让。
当然,除了应用于上述场景外,本申请实施例提供方法还可以应用于其他需要图像识别的场景,本申请实施例并不对具体的应用场景进行限定。
本申请实施例提供的一种模型训练方法、图像识别方法、装置、设备及介质,能够降低模型训练成本,同时提高模型识别的鲁棒性。下面说明本申请实施例提供的电子设备的示例性应用,本申请实施例提供的电子设备可以实施为各种类型的用户终端,也可以实施为服务器。在一种可能的实施方式中,本申请实施例提供的图像识别方法或者模型训练方法可以实现成为应用程序或应用程序的一部分,并被安装到终端中,使终端具备根据图像进行分类识别的功能以及进行模型训练和更新的功能;本申请实施例提供的图像识别方法可以应用于应用程序的后台服务器中,从而使服务器具备根据图像进行分类识别的功能以及进行模型训练和更新的功能。
参见图1,图1是本申请实施例提供的模型训练方案的一个应用场景下的一个可选的架构示意图,为实现支撑一个模型训练方案,终端设备100通过网络200连接服务器300,服务器300连接数据库400,网络200可以是广域网或者局域网,又或者是二者的组合。其中用于实现模型训练方案的客户端部署于终端设备100上,其中,客户端可以通过浏览器的形式运行于终端设备100上,也可以通过独立的应用程序(application,APP)的形式运行于终端设备100上等,对于客户端的具体展现形式,此处不做限定。本申请涉及的服务器300可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。终端设备100可以是智能手机、平板电脑、笔记本电脑、掌上电脑、个人电脑、智能电视、智能手表、车载设备、可穿戴设备、智能语音交互设备、智能家电、飞行器等等,但并不局限于此。终端设备100以及服务器300可以通过有线或无线通信方式通过网络200进行直接或间接地连接,本申请在此不做限制。服务器300和终端设备100的数量也不做限制。本申请提供的方案可以由终端设备100独立完成,也可以由服务器300独立完成,还可以由终端设备100与服务器300配合完成,对此,本申请并不做具体限定。其中,数据库400,简而言之可视为电子化的文件柜——存储电子文件的处所,用户可以对文件中的数据进行新增、查询、更新、删除等操作。所谓“数据库”是以一定方式储存在一起、能与多个用户共享、具有尽可能小的冗余度、与应用程序彼此独立的数据集合。数据库管理系统(Database Management System,DBMS)是为管理数据库而设计的电脑软件系统,一般具有存储、截取、安全保障、备份等基础功能。数据库管理系统可以依据它所支持的数据库模型来作分类,例如关系式、可扩展标记语言(Extensible Markup Language,XML);或依据所支持的计算机类型来作分类,例如服务器群集、移动电话;或依据所用查询语言来作分类,例如结构化查询语言(Structured Query Language,SQL)、XQuery;或依据性能冲量重点来作分类,例如最大规模、最高运行速度;亦或其他的分类方式。不论使用哪种分类方式,一些DBMS能够跨类别,例如,同时支持多种查询语言。在本申请中,数据库400可以用于存储源域训练图像、对抗域训练图像或者待处理图像,当然,源域训练图像、对抗域训练图像或者待处理图像的存储位置并不限于数据库,例如还可以存储于终端设备100、区块链或者服务器300的分布式文件系统中等。
在一些实施例中,服务器300和该终端设备100均可以执行本申请实施例提供的模型训练方法以及图像识别方法。
本实施例中,该其具体流程可以如下:该终端设备100获取源域训练图像、初始学生模型和该初始教师模型;然后该终端设备100将该源域训练图像可以存储于该数据库400或存储于该终端设备100的存储器;该服务器300从该数据库400或者该终端设备100获取该源域训练图像;然后根据该源域训练图像训练该初始学生模型得到该学生模型;然后根据该源域训练图像得到对抗域训练图像,并基于该对抗域训练图像训练该初始教师模型得到教师模型以及该教师模型的网络参数,并利用EMA根据该教师模型的网络参数更新该学生模型的网络参数,从而得到目标学生模型,该目标学生模型将作为该图像识别模型;最后该服务器300可以将该图像识别模型部署至该终端设备100,以实现该终端设备100可以调用该图像识别模型实现对待处理图像的图像识别;或者,该服务器300将该图像识别模型部署于该图像处理的服务器上,使得该图像处理的服务器调用该图像识别模型实现对待处理图像的图像识别。
在另一个实施例中,该终端设备100独立执行本申请实施例提供的模型训练方法,本实施例中,该其具体流程可以如下:该终端设备100获取源域训练图像、初始学生模型和该初始教师模型;然后该终端设备100将该源域训练图像可以存储于该数据库400或存储于该终端设备100的存储器;该终端设备100从该数据库400或者该终端设备100获取该源域训练图像;然后根据该源域训练图像训练该初始学生模型得到该学生模型;然后根据该源域训练图像得到对抗域训练图像,并基于该对抗域训练图像训练该初始教师模型得到教师模型以及该教师模型的网络参数,并利用EMA根据该教师模型的网络参数更新该学生模型的网络参数,从而得到目标学生模型,该目标学生模型将作为该图像识别模型;最后该终端设备100可以将该图像识别模型部署至该终端设备100,以实现该终端设备100可以调用该图像识别模型实现对待处理图像的图像识别;或者,终端设备100将该图像识别模型部署于该图像处理的服务器上,使得该图像处理的服务器调用该图像识别模型实现对待处理图像的图像识别。
结合上述介绍,以图2所示的训练框架示意图对本申请中模型训练方法进行介绍,本申请实施例中模型训练方法的一个实施例包括:
在训练之前,设定该源域数据集如下:类别集合,类别/>的训练图像/>和测试集/>,其中N、/>和/>分别表示类别的数量、类别/>的训练图像数量和类别/>的测试图像数量。
第一部分、对初始学生模型进行训练,以得到学生模型。
在该第一部分中,具体可以包括如下流程:
步骤1、对源域训练图像进行标准化处理,以得到该初始学生模型的输入图像I。一个示例性方案中,该标准化处理的具体流程可以如下:在迭代训练过程中,每一次迭代可以从该源域训练图像中选择一个批次的图像,并将图像通过放缩或者裁剪达到预设尺寸(比如固定尺寸224像素*224像素);然后对预设尺寸的图像进行归一化处理,比如对RGB图像的每一个像素值减去RGB通道的均值,然后除以RGB通道的标准差,从而将预设尺寸的图像中的每一个像素值放缩到-1至1之间。
步骤2、训练初始学生模型得到学生模型。学生模型接收步骤1中的输入图像I,通过前向计算获得该输入图像I的深度特征f,然后通过分类器(示例性方案中,该分类器可以是全连接层)获得分类结果。最后通过分类损失,比如交叉熵损失函数对该初始学生模型的网络参数进行反向训练直到损失收敛,以得到该学生模型。
应理解的是,本申请中,在训练该初始学生模型以得到该学生模型的过程中,可以采用全监督学习(即需要对输入图像I标注其对应的图像标签)、半监督学习(即可以对输入图像I中的部分图像标注其对应的图像标签)、弱监督学习(即可以对输入图像I中的部分图像标注其对应的图像标签)或者无监督学习(即无需对输入图像I标注其对应的图像标签)方式进行训练,具体此处不做限定。
第二部分、对初始教师模型进行训练,以得到教师模型,并基于教师模型的网络参数更新第一部分训练得到的学生模型的网络参数,以得到目标学生模型。
步骤1、对整体训练过程进行初始化。其中,该学生模型采用第一部分训练得到的网络参数,且在第二部分的训练过程中,该学生模型的网络参数并不受训练损失的影响(也可以理解为该学生模型的网络参数不受初始教师模型在训练过程中根据损失值进行反向梯度传播的影响);同时该初始教师模型的网络参数采用随机初始化,在训练过程中,该初始教师模型的网络参数根据损失值进行反向梯度传播进行更新。
步骤2、初始化类中心矩阵。定义该源域训练图像中有N个类别,假设学生模型和教师模型进行深度学习得到的图像深度特征的特征维度为d(以ResNet为例,d=2048),设置类中心矩阵/>,C的维度为N×d,即为每一个类别都设置一个类中心向量。比如,该源域训练图像中存在3个类别,即第一类别、第二类别以及第三类别,其中,第一类别包括30个图像,第二类别包括40个图像以及第三类别包括20个图像。通过学生模型的前向计算,第一类别得到30个深度特征,第二类别得到40个深度特征以及第三类别得到20个深度特征,分别对每个类别的深度特征求均值得到其对应的类中心向量f1、f2以及f3,此时该类中心矩阵包括f1、f2以及f3
在本实施例中,该类中心矩阵的初始化可以如下:利用该第一部分训练得到的学生模型对该源域训练图像进行前向计算,得到该源域训练图像对应的深度特征;并针对该深度特征进行分类,得到N个类别;然后再根据上述对类中心矩阵的定义以及该N个类别对应的深度特征进行计算得到N个类中心向量,将该N个类中心向量作为该类中心矩阵。
应理解的是,在对该深度特征进行分类时,根据不同的学习方式,在根据该深度特征进行分类得到N个类别时,可以采用不同的方式。在使用监督学习方式、半监督学习或者弱监督学习方式时,可以通过计算分布概率得到该深度特征所处的类别;在使用无监督学习时,可以通过聚类计算得到该深度所处的类别。具体此处不做限定。
步骤3、基于对抗域训练图像对该初始教师模型进行训练得到教师模型,并根据该教师模型的网络参数利用EMA更新该学生模型的网络参数。
下面以训练过程中的第次迭代为例(应理解的是,该/>),在每一轮迭代中分别进行如下操作:
1、从源域训练图像采样图像I,其对应的图像标签为y。应理解的是,本实施例中,以全监督学习方式为例进行说明,因此从源域训练图像中获取到的训练图像均具有其对应的真实图像标签。若在训练过程中,采用半监督学习方式或者弱监督学习方式,则该模型训练装置可以通过伪标签方式来为该训练图像标注图像标签。其中,该伪标签方式是一种同时从未标记数据和标记数据中学习的监督范式,具体思路是将具有最大预测概率的类作为伪标签。若在训练过程中,采用无监督学习方式,则该模型训练装置可以通过聚类计算为该训练图标注图像标签。本实施例中,只要可以为该训练图像进行图像标签标注即可,具体此处不做限定。
2、通过随机图像增强获得对抗域图像(在此次迭代过程中可以理解为本申请中的第一训练子集),其对应的图像标签依然是y。本实施例中,该图像增强方式包括但不限于对源域图像进行随机地旋转、平移、裁剪、缩放、擦除、色彩空间变换、对比度变换、锐度变换、高斯模糊等。
应理解的是,在本实施例中,上述1和2中获取的对抗域图像还可以是其他方式,比如,先将源域训练图像通过随机图像增强获取对抗域训练图像,然后再从该对抗域训练图像中进行采样得到该对抗域图像/>。只要可以得到基于该源域训练图像的对抗域图像即可,具体方式此处不作限定。
3、对进行输入标准化(参考第一部分的步骤1),得到教师模型T的输入图像。本实施例中,为了降低训练成本,该教师模型与该学生模型的训练方式应相同,同时输入图像的标准化处理也应相同。
4、教师模型对进行前向计算,得到图像特征ft,其维度为d。
5、通过分类器获得教师模型的预测分布。本实施例中,该分类器可以是全连接层也可以是其他分类网络。应理解的是,在该教师模型的预测分布可以是通过计算分布概率得到,也可以是通过聚类方式得到,具体此处不做限定。
6、计算交叉熵分类损失:
/>
应理解的是,此处该模型训练装置可以计算上述交叉熵分类损失,也可以计算其他损失,比如,相对熵分类损失(即KL散度( Kullback-Leibler divergence)损失函数)以及逻辑回归损失(即softmax损失函数)。其中,该KL散度损失是一种非对称度量方法,常用于度量两个概率分布之间的距离。KL散度也可以衡量两个随机分布之间的距离,两个随机分布的相似度越高的,它们的KL散度越小,当两个随机分布的差别增大时,它们的KL散度也会增大,因此KL散度可以用于比较文本标签或图像的相似性。而该softmax损失函数的本质是将一个k维的任意实数向量x映射成另一个k维的实数向量,其中,输出向量中的每个元素的取值范围都是(0,1),即softmax损失函数输出每个类别的预测概率。由于softmax损失函数具有类间可分性,被广泛用于分类、分割、人脸识别、图像自动标注和人脸验证等问题中,其特点是类间距离的优化效果非常好,但类内距离的优化效果比较差。
7、计算类中心约束损失,即计算图像深度特征ft与对应的上一次迭代的类中心向量/>(可以理解为本申请中的第一类中心矩阵中的第一类中心向量)之间的L1距离(即平均绝对值误差,也可以称为曼哈顿距离):
应理解的是,此处该模型训练装置可以计算上述L1距离进行距离度量,也可以计算其他距离进行距离度量,比如均方误差MSE或者L2距离(又称为欧氏距离)或者L1-smooth误差。具体此处不做限定。其中,该均方误差用于度量样本点到回归曲线的距离,通过最小化平方损失使样本点可以更好地拟合回归曲线。均方误差损失函数(MSE)的值越小,表示预测模型描述的样本数据具有越好的精确度。L2距离是一种常用的距离度量方法,通常用于度量数据点之间的相似度。由于L2损失具有凸性和可微性,且在独立、同分布的高斯噪声情况下,它能提供最大似然估计,使得它成为回归问题、模式识别、图像处理中最常使用的损失函数。L1-smooth是基于L1距离的一种损失函数,其主要用于在目标检测中防止梯度爆炸。
8、计算教师模型的完整损失函数
9、通过反向梯度传播更新教师模型的网络参数:
这里表示学习率,/>表示上一次迭代中该教师模型的网络参数,/>用于表示对该损失函数进行梯度计算。
10、更新类中心矩阵C:利用ft及其所对应的标签y,更新y对应类别的类中心。具体来说,在当前迭代次数为时,第/>个类别对应的类中心向量为/>(可以理解为本申请中的第二类中心矩阵),对于新增的来自对抗域的图像所生成的深度特征ft,使用EMA的方式更新类中心向量:
其中表示保留多少程度的原有类中心向量的特征信息,/>通常需要设置的很大(高于0.996)。应理解的是,本实施例中为了避免训练过程中对该类中心矩阵的影响较大,因此将该/>值设定的较大,若为了实现类中心矩阵的大范围波动,则可以将该/>值设定的较小。即可以根据实际需要对该/>值进行相应的设定。
11、更新学生模型的网络参数:通过EMA,将教师模型的网络参数更新到学生模型:
表示要保留多少比例的源域识别能力(即确定该教师模型的网络参数对该学生模型的网络参数的影响力大小);/>通常需要设置的很大(高于0.996),因此在每一次迭代更新时,都会尽可能保留学生模型原有的源域识别的能力,而同时,又可以按照/>的比例学习到教师模型的对抗域识别能力。应理解的是,本实施例中为了保留源域识别的能力较多,因此将该/>值设定的较大,若为了学习到更多的对抗识别能力,则可以将该/>值设定的较小。即可以根据实际需要对该/>值进行相应的设定。
12、完成本轮计算,进入下一次迭代,直到训练损失收敛。应理解的是,该损失收敛为该教师模型在训练过程中的训练损失收敛。
结合上述介绍,下面将对本申请中模型训练方法进行介绍,请参阅图3,本申请实施例中模型训练方法的一个实施例包括:
301、获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构。
本申请中,该初始学生模型与该初始教师模型具有相同的网络结构,而该网络结构可以是常用的分类网络。比如ResNet、ResNeSt、ResNeXt、RegNet、VGG、AlexNet、Transformer或者ViT等等,具体此处不做限定。
而该源域训练图像可以是该模型训练装置接收到的第三方训练图像,也可以是该模型训练装置通过其自带的摄像头获取到的历史图像,具体此处不做限定。
302、基于该源域训练图像对该初始学生模型进行训练,以得到学生模型。
本实施例中,该模型训练装置对该初始学生模型进行训练的具体过程可以参阅上述第一部分的描述,具体此处不再赘述。
303、基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到。
本实施例中,该模型训练装置对该初始教师模型进行训练的具体过程可以参阅上述第二部分的描述,具体此处不再赘述。
304、在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
在该训练损失满足收敛条件时,确定该教师模型已训练完成,同时根据该教师模型的网络参数迭代更新该学生模型的网络参数的过程也已经完成,此时该学生模型的网络参数确定为最终的目标学生模型的网络参数,并输出该目标学生模型作为图像识别中的模型。
应理解的是,该目标学生模型可以部署于终端设备也可以部署于服务器,从而实现图像识别的功能。其具体部分位置,此处不做限定。
结合上述介绍,下面对图像识别方法进行说明,请参阅图4所示,本申请中图像识别方法的一个实施例包括:
401、获取待处理图像。
本实施例中,该图像识别装置可以通过其自带的图像采集设备获取到实时图像,并将该实时图像作为该待处理图像;或者该图像识别装置接收第三方图像采集设备发送的实时图像,并将该实时图像作为该待处理图像。
应理解的是,根据上述描述,该待处理图像可以是智慧交通系统中由该车辆的传感器或者摄像头拍摄到的道路实时图像,或者是患者的医疗图像或者进行安全识别的面部图像、指纹图像等等。
402、调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别。
本实施例中,在获取到该待处理图像之后,还可以对该待处理图像进行标准化处理(具体的标准化处理可以参阅前述第一部分的步骤1,此处不再赘述)得到该图像识别模型的输入图像;然后调用该图像识别模型对该输入图像进行前向计算得到该输入图像的深度特征;对该深度特征采用分类网络进行识别,得到该待处理图像的图像类别。
403、输出该待处理图像的图像类别。
本实施例中,在得到该图像类别之后,该图像识别装置可以将该图像类别以文字或者图形的方式进行输出,具体此处不做限定。
下面对本申请中的模型训练装置进行详细描述,请参阅图5,图5为本申请实施例中模型训练装置的一个实施例示意图,模型训练装置20包括:
获取模块201,用于获取初始学生模型、初始教师模型以及源域训练图像,其中,该初始学生模型与该初始教师模型具有相同的网络结构;
处理模块202,用于基于该源域训练图像对该初始学生模型进行训练得到学生模型;基于对抗域训练图像对该初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据该教师模型的网络参数迭代更新该学生模型的网络参数,以得到目标学生模型,该对抗域训练图像为该源域训练图像进行图像增强处理得到;
输出模块203,用于在该教师模型的训练损失满足收敛条件时,输出该目标学生模型。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供一组网络结构相同的学生模型和教师模型,其中,学生模型只在源域进行训练,从而获取较好的源域识别能力;而教师模型只在对抗域训练,从而获取对抗识别能力;然后将教师模型的参数通过指数平滑平均的方式更新至该学生网络,使得学生网络可以不断积累对抗识别能力,同时保留了大部分的固定源域的识别能力,最终实现较高的识别鲁棒性。同时,该学生模型与该教师模型采用相同的网络结构,不需要大模型训练和知识蒸馏过程,减少了模型训练复杂度,从而降低训练成本。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该处理模块202,用于基于第一训练子集对该初始教师模型进行训练得到第一教师模型,该第一教师模型具有第二网络参数,该第一训练子集包含于该对抗域训练图像;
利用指数滑动平均根据该第二网络参数更新该第一网络参数得到该学生模型的第三网络参数;
基于第二训练子集对该第一教师模型进行训练得到第二教师模型,该第二教师模型具有第四网络参数,该第二训练子集包含于该对抗域训练图像;
利用指数滑动平均根据该第一网络参数更新该学生模型的第二网络参数,以得到该学生模型的第三网络参数,该第二网络参数为基于该源域训练图像训练该初始学生模型得到;
重复上述操作,在训练损失满足收敛条件,得到该教师模型和该目标学生模型。
本申请实施例中,提供了一种模型训练装置。采用上述装置,在对抗域内对该初始教师模型进行训练,从而获取对抗识别能力;然后将教师模型的参数通过指数平滑平均的方式更新至该学生网络,使得学生网络可以不断积累对抗识别能力,同时保留了大部分的固定源域的识别能力,最终实现较高的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,处理模块202,用于获取该第一训练子集以及第一类中心矩阵,该第一训练子集包括第一样本图像和该第一样本图像对应的第一图像标签,该第一类中心矩阵用于指示该源域训练图像中对应的各个类别的特征中心;
调用该初始教师模型对该第一训练子集进行图像识别,以得到第一图像特征以及第一预测图像标签;
根据该第一预测图像标签与该第一图像标签进行损失计算得到第一损失值,并根据该第一图像特征与第一类中心向量进行距离度量得到第二损失值,该第一类中心向量为该第一预测图像标签所处类别的类中心向量,该第一类中心向量包含于该第一类中心矩阵;
根据该第一损失值和该第二损失值反向梯度传播更新该初始教师模型的网络参数,以得到该第一教师模型。
本申请实施例中,提供了一种模型训练装置。采用上述装置,在对抗域内对该初始教师模型进行训练,在训练过程中通过真实图像标签与预测图像标签计算损失值,同时度量图像特征与类中心向量的距离计算另一个损失值,同时使用更新后的类中心向量,这样可以约束新增图像的特征与对应类别的类中心之间的距离,从而增强教师模型对对抗图像的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该处理模块202,用于根据该第一图像特征和该第一类中心向量利用指数滑动平均更新该第一类中心矩阵,以得到第二类中心矩阵。
本申请实施例中,提供了一种模型训练装置。采用上述装置,迭代更新类中心矩阵中的类中心向量,这样可以有效的约束图像特征与对应类别的类中心之间的距离,从而增强教师模型对对抗图像的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
处理模块202,用于获取该第二训练子集以及该第二类中心矩阵,该第二训练子集包括第二样本图像和该第二样本图像对应的第二图像标签;
调用该第一教师模型对该第二训练子集进行图像识别,以得到第二图像特征以及第二预测图像标签;
根据该第二预测图像标签与该第二图像标签进行损失计算得到第三损失值,并根据该第二图像特征与第二类中心向量进行距离度量得到第四损失值,该第二类中心向量为该第二预测图像标签所处类别的类中心向量,该第二类中心向量包含于该第二类中心矩阵;
根据该第三损失值和该第四损失值反向梯度传播更新该第一教师模型的网络参数,以得到该第二教师模型。
本申请实施例中,提供了一种模型训练装置。采用上述装置,在对抗域内对该初始教师模型进行训练,在训练过程中通过真实图像标签与预测图像标签计算损失值,同时度量图像特征与类中心向量的距离计算另一个损失值,同时使用更新后的类中心向量,这样可以约束新增图像的特征与对应类别的类中心之间的距离,从而增强教师模型对对抗图像的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,处理模块202,用于对该源域训练图像进行图像增强处理生成该对抗域训练图像;
从该对抗域训练图像中进行采样得到该第一训练子集;
或者,
处理模块202,用于从该源域训练图像中进行采样得到第一源域训练子集;
对该第一源域训练子集进行图像增强处理得到该第一训练子集。
本申请实施例中,提供了一种模型训练装置。采用上述装置,可以将第一训练图像通过图像增强处理,从而得到该第二训练图像,这样该第二训练图像是基于原始图像的增强图像,使得教师模型可以学习到对抗图像识别能力。同时提供多种方案,可以提高方案的可实行性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,处理模块202,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
根据该图像特征计算分布概率,以得到该源域训练图像的N个类别,该N为正整数;
获取该N个类别的N个特征中心向量;
根据该N个特征中心向量生成该第一类中心矩阵。
本申请实施例中,提供了一种模型训练装置。采用上述装置,利用已训练好的学生模型对源域中的训练图像进行前向计算,学习到各个训练图像的图像特征;并根据图像特征进行分布概率计算,从而分类得到多个类别,并对于不同类别的图像特征进行平均得到统计意义上的类别中心,这样可以约束新增图像的特征与对应类别的类中心之间的距离,从而增强教师模型对对抗图像的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
处理模块202,用于利用该学生模型对该源域训练图像进行前向计算,以得到该源域训练图像对应的图像特征;
对该图像特征进行聚类计算,以得到该源域训练图像的N个类别,该N为正整数;
获取该N个类别的N个特征中心向量;
将该N个特征中心向量作为该第一类中心矩阵。
本申请实施例中,提供了一种模型训练装置。采用上述装置,利用已训练好的学生模型对源域中的训练图像进行前向计算,学习到各个训练图像的图像特征;并根据图像特征进行聚类计算,从而分类得到多个类别,并对于不同类别的图像特征进行平均得到统计意义上的类别中心,这样可以约束新增图像的特征与对应类别的类中心之间的距离,从而增强教师模型对对抗图像的识别鲁棒性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该处理模块202,用于根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据该第一预测图像标签与该第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据该第一预测图像标签与该第一图像标签进行逻辑回归损失计算得到第一损失值。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供多种损失值的计算方式,提高了方案的可实行性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该处理模块202,用于根据该第一图像特征与第一类中心向量进行均方误差MSE损失计算得到第二损失值;
或者,
根据该第一图像特征与第一类中心向量进行平均绝对值误差L1损失计算得到第二损失值;
或者,
根据该第一图像特征与第一类中心向量进行L1-smooth损失计算得到第二损失值。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供多种损失值的计算方式,提高了方案的可实行性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该处理模块202,用于基于源域训练图像对该初始学生模型进行全监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行半监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行弱监督训练得到该学生模型;
或者,
基于源域训练图像对该初始学生模型进行无监督训练得到该学生模型。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供多种训练方式,从而增加学生模型的泛化性以及方案的可实行性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,
该初始学生模型与该初始教师模型的网络结构为残差神经网络ResNet、ResNeSt、ResNeXt、RegNet、VGG、AlexNet、Transformer或者ViT。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供多种网络结构,从而增加学生模型的可应用场景,同时也增加方案的可实行性。
可选地,在上述图5所对应的实施例的基础上,本申请实施例提供的模型训练装置20的另一实施例中,该教师模型与该学生模型采用相同的训练方式。
本申请实施例中,提供了一种模型训练装置。采用上述装置,学生模型与教师模型采用相同的训练方式进行训练,这样可以减少模型训练的复杂度,从而进一步的降低模型训练成本。
下面对本申请中的图像识别装置进行详细描述,请参阅图6,图6为本申请实施例中模型训练装置的一个实施例示意图,图像识别装置60包括:
获取模块601,用于获取待处理图像;
处理模块602,用于调用图像识别模型对该待处理图像进行识别处理,以得到该待处理图像的图像类别,该图像识别模型为上述任一项该的目标学生模型;
输出模块603,用于输出该待处理图像的图像类别。
本申请实施例中,提供了一种模型训练装置。采用上述装置,提供一组网络结构相同的学生模型和教师模型,其中,学生模型只在源域进行训练,从而获取较好的源域识别能力;而教师模型只在对抗域训练,从而获取对抗识别能力;然后将教师模型的参数通过指数平滑平均的方式更新至该学生网络,使得学生网络可以不断积累对抗识别能力,同时保留了大部分的固定源域的识别能力,最终实现较高的识别鲁棒性。同时,该学生模型与该教师模型采用相同的网络结构,不需要大模型训练和知识蒸馏过程,减少了模型训练复杂度,从而降低训练成本。
本申请提供的模型训练装置以及图像识别装置可用于服务器,请参阅图7,图7是本申请实施例提供的一种服务器结构示意图,该服务器300可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上中央处理器(central processing units,CPU)322(例如,一个或一个以上处理器)和存储器332,一个或一个以上存储应用程序342或数据344的存储介质330(例如一个或一个以上海量存储设备)。其中,存储器332和存储介质330可以是短暂存储或持久存储。存储在存储介质330的程序可以包括一个或一个以上模块(图示没标出),每个模块可以包括对服务器中的一系列指令操作。更进一步地,中央处理器322可以设置为与存储介质330通信,在服务器300上执行存储介质330中的一系列指令操作。
服务器300还可以包括一个或一个以上电源326,一个或一个以上有线或无线网络接口350,一个或一个以上输入输出接口358,和/或,一个或一个以上操作系统341,例如Windows ServerTM,Mac OS XTM,UnixTM, LinuxTM,FreeBSDTM等等。
上述实施例中由服务器所执行的步骤可以基于该图7所示的服务器结构。
本申请提供的模型训练装置或者图像识别装置可用于终端设备,请参阅图8,为了便于说明,仅示出了与本申请实施例相关的部分,具体技术细节未揭示的,请参照本申请实施例方法部分。在本申请实施例中,以终端设备为智能手机为例进行说明:
图8示出的是与本申请实施例提供的终端设备相关的智能手机的部分结构的框图。参考图8,智能手机包括:射频(radio frequency,RF)电路410、存储器420、输入单元430、显示单元440、传感器450、音频电路460、无线保真(wireless fidelity,WiFi)模块470、处理器480、以及电源490等部件。本领域技术人员可以理解,图8中示出的智能手机结构并不构成对智能手机的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
下面结合图8对智能手机的各个构成部件进行具体的介绍:
RF电路410可用于收发信息或通话过程中,信号的接收和发送,特别地,将基站的下行信息接收后,给处理器480处理;另外,将设计上行的数据发送给基站。通常,RF电路410包括但不限于天线、至少一个放大器、收发信机、耦合器、低噪声放大器(low noiseamplifier,LNA)、双工器等。此外,RF电路410还可以通过无线通信与网络和其他设备通信。上述无线通信可以使用任一通信标准或协议,包括但不限于全球移动通讯系统 (globalsystem of mobile communication,GSM)、通用分组无线服务(general packet radioservice,GPRS)、码分多址(code division multiple access,CDMA)、宽带码分多址(wideband code division multiple access, WCDMA)、长期演进 (long termevolution,LTE)、电子邮件、短消息服务(short messaging service,SMS)等。
存储器420可用于存储软件程序以及模块,处理器480通过运行存储在存储器420的软件程序以及模块,从而执行智能手机的各种功能应用以及数据处理。存储器420可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据智能手机的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器420可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
输入单元430可用于接收输入的数字或字符信息,以及产生与智能手机的用户设置以及功能控制有关的键信号输入。具体地,输入单元430可包括触控面板431以及其他输入设备432。触控面板431,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板431上或在触控面板431附近的操作),并根据预先设定的程式驱动相应的连接装置。可选的,触控面板431可包括触摸检测装置和触摸控制器两个部分。其中,触摸检测装置检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测装置上接收触摸信息,并将它转换成触点坐标,再送给处理器480,并能接收处理器480发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板431。除了触控面板431,输入单元430还可以包括其他输入设备432。具体地,其他输入设备432可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆等中的一种或多种。
显示单元440可用于显示由用户输入的信息或提供给用户的信息以及智能手机的各种菜单。显示单元440可包括显示面板441,可选的,可以采用液晶显示器(liquidcrystal display,LCD)、有机发光二极管(organic light-emitting diode,OLED)等形式来配置显示面板441。进一步的,触控面板431可覆盖显示面板441,当触控面板431检测到在其上或附近的触摸操作后,传送给处理器480以确定触摸事件的类型,随后处理器480根据触摸事件的类型在显示面板441上提供相应的视觉输出。虽然在图8中,触控面板431与显示面板441是作为两个独立的部件来实现智能手机的输入和输入功能,但是在某些实施例中,可以将触控面板431与显示面板441集成而实现智能手机的输入和输出功能。
智能手机还可包括至少一种传感器450,比如光传感器、运动传感器以及其他传感器。具体地,光传感器可包括环境光传感器及接近传感器,其中,环境光传感器可根据环境光线的明暗来调节显示面板441的亮度,接近传感器可在智能手机移动到耳边时,关闭显示面板441和/或背光。作为运动传感器的一种,加速计传感器可检测各个方向上(一般为三轴)加速度的大小,静止时可检测出重力的大小及方向,可用于识别智能手机姿态的应用(比如横竖屏切换、相关游戏、磁力计姿态校准)、振动识别相关功能(比如计步器、敲击)等;至于智能手机还可配置的陀螺仪、气压计、湿度计、温度计、红外线传感器等其他传感器,在此不再赘述。
音频电路460、扬声器461,传声器462可提供用户与智能手机之间的音频接口。音频电路460可将接收到的音频数据转换后的电信号,传输到扬声器461,由扬声器461转换为声音信号输出;另一方面,传声器462将收集的声音信号转换为电信号,由音频电路460接收后转换为音频数据,再将音频数据输出处理器480处理后,经RF电路410以发送给比如另一智能手机,或者将音频数据输出至存储器420以便进一步处理。
WiFi属于短距离无线传输技术,智能手机通过WiFi模块470可以帮助用户收发电子邮件、浏览网页和访问流式媒体等,它为用户提供了无线的宽带互联网访问。虽然图8示出了WiFi模块470,但是可以理解的是,其并不属于智能手机的必须构成,完全可以根据需要在不改变发明的本质的范围内而省略。
处理器480是智能手机的控制中心,利用各种接口和线路连接整个智能手机的各个部分,通过运行或执行存储在存储器420内的软件程序和/或模块,以及调用存储在存储器420内的数据,执行智能手机的各种功能和处理数据,从而对智能手机进行整体监测。可选的,处理器480可包括一个或多个处理单元;可选的,处理器480可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器480中。
智能手机还包括给各个部件供电的电源490(比如电池),可选的,电源可以通过电源管理系统与处理器480逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。
尽管未示出,智能手机还可以包括摄像头、蓝牙模块等,在此不再赘述。
上述实施例中由终端设备所执行的步骤可以基于该图8所示的终端设备结构。
本申请实施例中还提供一种计算机可读存储介质,该计算机可读存储介质中存储有计算机程序,当其在计算机上运行时,使得计算机执行如前述各个实施例描述的方法。
本申请实施例中还提供一种包括程序的计算机程序产品,当其在计算机上运行时,使得计算机执行前述各个实施例描述的方法。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(read-only memory,ROM)、随机存取存储器(random access memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围。

Claims (20)

1.一种模型训练方法,其特征在于,包括:
获取初始学生模型、初始教师模型以及源域训练图像,其中,所述初始学生模型与所述初始教师模型具有相同的网络结构;
基于所述源域训练图像对所述初始学生模型进行训练得到学生模型;
基于对抗域训练图像对所述初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据所述教师模型的网络参数迭代更新所述学生模型的网络参数,以得到目标学生模型,所述对抗域训练图像为所述源域训练图像进行图像增强处理得到;
在所述教师模型的训练损失满足收敛条件时,输出所述目标学生模型;
其中,所述基于对抗域训练图像对所述初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据所述教师模型的网络参数迭代更新所述学生模型的网络参数,以得到目标学生模型包括:
基于第一训练子集对初始教师模型进行训练得到第一教师模型,所述第一教师模型具有第一网络参数,所述第一训练子集包含于对抗域训练图像,所述第一训练子集为从对源域训练图像进行图像增强处理生成的对抗域训练图像中采集得到的子集,或所述第一训练子集为对从源域训练图像中进行采样得到的第一源域训练子集进行图像增强处理得到的子集;
利用指数滑动平均根据所述第一网络参数更新所述学生模型的第二网络参数,以得到所述学生模型的第三网络参数,所述第二网络参数为基于所述源域训练图像训练所述初始学生模型得到;
基于第二训练子集对所述第一教师模型进行训练得到第二教师模型,所述第二教师模型具有第四网络参数,所述第二训练子集包含于所述对抗域训练图像;
利用指数滑动平均根据所述第四网络参数更新所述第三网络参数得到所述学生模型的第五网络参数;
重复上述操作,在训练损失满足收敛条件,得到教师模型和目标学生模型。
2.根据权利要求1所述的方法,其特征在于,所述基于第一训练子集对所述初始教师模型进行训练得到第一教师模型包括:
获取所述第一训练子集以及第一类中心矩阵,所述第一训练子集包括第一样本图像和所述第一样本图像对应的第一图像标签,所述第一类中心矩阵用于指示所述源域训练图像中对应的各个类别的特征中心;
调用所述初始教师模型对所述第一训练子集进行图像识别,以得到第一图像特征以及第一预测图像标签;
根据所述第一预测图像标签与所述第一图像标签进行损失计算得到第一损失值,并根据所述第一图像特征与第一类中心向量进行距离度量得到第二损失值,所述第一类中心向量为所述第一预测图像标签所处类别的类中心向量,所述第一类中心向量包含于所述第一类中心矩阵;
根据所述第一损失值和所述第二损失值反向梯度传播更新所述初始教师模型的网络参数,以得到所述第一教师模型。
3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
根据所述第一图像特征和所述第一类中心向量利用指数滑动平均更新所述第一类中心矩阵,以得到第二类中心矩阵。
4.根据权利要求3所述的方法,其特征在于,所述基于第二训练子集对所述第一教师模型进行训练得到第二教师模型包括:
获取所述第二训练子集以及所述第二类中心矩阵,所述第二训练子集包括第二样本图像和所述第二样本图像对应的第二图像标签;
调用所述第一教师模型对所述第二训练子集进行图像识别,以得到第二图像特征以及第二预测图像标签;
根据所述第二预测图像标签与所述第二图像标签进行损失计算得到第三损失值,并根据所述第二图像特征与第二类中心向量进行距离度量得到第四损失值,所述第二类中心向量为所述第二预测图像标签所处类别的类中心向量,所述第二类中心向量包含于所述第二类中心矩阵;
根据所述第三损失值和所述第四损失值反向梯度传播更新所述第一教师模型的网络参数,以得到所述第二教师模型。
5.根据权利要求2所述的方法,其特征在于,获取第一类中心矩阵包括:
利用所述学生模型对所述源域训练图像进行前向计算,以得到所述源域训练图像对应的图像特征;
根据所述图像特征计算分布概率,以得到所述源域训练图像的N个类别,所述N为正整数;
获取所述N个类别的N个特征中心向量;
根据所述N个特征中心向量生成所述第一类中心矩阵。
6.根据权利要求2所述的方法,其特征在于,获取第一类中心矩阵包括:
利用所述学生模型对所述源域训练图像进行前向计算,以得到所述源域训练图像对应的图像特征;
对所述图像特征进行聚类计算,以得到所述源域训练图像的N个类别,所述N为正整数;
获取所述N个类别的N个特征中心向量;
将所述N个特征中心向量作为所述第一类中心矩阵。
7.根据权利要求2所述的方法,其特征在于,所述根据所述第一预测图像标签与所述第一图像标签进行损失计算得到第一损失值包括:
根据所述第一预测图像标签与所述第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据所述第一预测图像标签与所述第一图像标签进行相对熵分类损失计算得到第一损失值;
或者,
根据所述第一预测图像标签与所述第一图像标签进行逻辑回归损失计算得到第一损失值。
8.根据权利要求2所述的方法,其特征在于,所述根据所述第一图像特征与第一类中心向量进行距离度量得到第二损失值包括:
根据所述第一图像特征与第一类中心向量进行均方误差MSE损失计算得到第二损失值;
或者,
根据所述第一图像特征与第一类中心向量进行平均绝对值误差L1损失计算得到第二损失值;
或者,
根据所述第一图像特征与第一类中心向量进行L1-smooth损失计算得到第二损失值。
9.一种图像识别方法,其特征在于,包括:
获取待处理图像;
调用图像识别模型对所述待处理图像进行识别处理,以得到所述待处理图像的图像类别,所述图像识别模型为采用上述权利要求1至8中任一项所述的方法训练得到的目标学生模型;
输出所述待处理图像的图像类别。
10.一种图像识别模型的训练装置,其特征在于,包括:
获取模块,用于获取初始学生模型、初始教师模型以及源域训练图像,其中,所述初始学生模型与所述初始教师模型具有相同的网络结构;
处理模块,用于基于所述源域训练图像对所述初始学生模型进行训练得到学生模型;基于对抗域训练图像对所述初始教师模型进行迭代训练得到教师模型,并利用指数滑动平均根据所述教师模型的网络参数迭代更新所述学生模型的网络参数,以得到目标学生模型,所述对抗域训练图像为所述源域训练图像进行图像增强处理得到;
输出模块,用于在所述教师模型的训练损失满足收敛条件时,输出所述目标学生模型;
其中,所述处理模块,具体用于:
基于第一训练子集对初始教师模型进行训练得到第一教师模型,所述第一教师模型具有第一网络参数,所述第一训练子集包含于对抗域训练图像,所述第一训练子集为从对源域训练图像进行图像增强处理生成的对抗域训练图像中采集得到的子集,或所述第一训练子集为对从源域训练图像中进行采样得到的第一源域训练子集进行图像增强处理得到的子集;
利用指数滑动平均根据所述第一网络参数更新所述学生模型的第二网络参数,以得到所述学生模型的第三网络参数,所述第二网络参数为基于所述源域训练图像训练所述初始学生模型得到;
基于第二训练子集对所述第一教师模型进行训练得到第二教师模型,所述第二教师模型具有第四网络参数,所述第二训练子集包含于所述对抗域训练图像;
利用指数滑动平均根据所述第四网络参数更新所述第三网络参数得到所述学生模型的第五网络参数;
重复上述操作,在训练损失满足收敛条件,得到教师模型和目标学生模型。
11.根据权利要求10所述的装置,其特征在于,所述处理模块,具体用于:
获取所述第一训练子集以及第一类中心矩阵,所述第一训练子集包括第一样本图像和所述第一样本图像对应的第一图像标签,所述第一类中心矩阵用于指示所述源域训练图像中对应的各个类别的特征中心;
调用所述初始教师模型对所述第一训练子集进行图像识别,以得到第一图像特征以及第一预测图像标签;
根据所述第一预测图像标签与所述第一图像标签进行损失计算得到第一损失值,并根据所述第一图像特征与第一类中心向量进行距离度量得到第二损失值,所述第一类中心向量为所述第一预测图像标签所处类别的类中心向量,所述第一类中心向量包含于所述第一类中心矩阵;
根据所述第一损失值和所述第二损失值反向梯度传播更新所述初始教师模型的网络参数,以得到所述第一教师模型。
12.根据权利要求11所述的装置,其特征在于,所述处理模块,具体用于:
根据所述第一图像特征和所述第一类中心向量利用指数滑动平均更新所述第一类中心矩阵,以得到第二类中心矩阵。
13.根据权利要求12所述的装置,其特征在于,所述处理模块,具体用于:
获取所述第二训练子集以及所述第二类中心矩阵,所述第二训练子集包括第二样本图像和所述第二样本图像对应的第二图像标签;
调用所述第一教师模型对所述第二训练子集进行图像识别,以得到第二图像特征以及第二预测图像标签;
根据所述第二预测图像标签与所述第二图像标签进行损失计算得到第三损失值,并根据所述第二图像特征与第二类中心向量进行距离度量得到第四损失值,所述第二类中心向量为所述第二预测图像标签所处类别的类中心向量,所述第二类中心向量包含于所述第二类中心矩阵;
根据所述第三损失值和所述第四损失值反向梯度传播更新所述第一教师模型的网络参数,以得到所述第二教师模型。
14.根据权利要求11所述的装置,其特征在于,所述处理模块,具体用于:
利用所述学生模型对所述源域训练图像进行前向计算,以得到所述源域训练图像对应的图像特征;
根据所述图像特征计算分布概率,以得到所述源域训练图像的N个类别,所述N为正整数;
获取所述N个类别的N个特征中心向量;
根据所述N个特征中心向量生成所述第一类中心矩阵。
15.根据权利要求11所述的装置,其特征在于,所述处理模块,具体用于:
利用所述学生模型对所述源域训练图像进行前向计算,以得到所述源域训练图像对应的图像特征;
对所述图像特征进行聚类计算,以得到所述源域训练图像的N个类别,所述N为正整数;
获取所述N个类别的N个特征中心向量;
将所述N个特征中心向量作为所述第一类中心矩阵。
16.根据权利要求11所述的装置,其特征在于,所述处理模块,具体用于:
根据所述第一预测图像标签与所述第一图像标签进行交叉熵分类损失计算得到第一损失值;
或者,
根据所述第一预测图像标签与所述第一图像标签进行相对熵分类损失计算得到第一损失值;
或者,
根据所述第一预测图像标签与所述第一图像标签进行逻辑回归损失计算得到第一损失值。
17.根据权利要求11所述的装置,其特征在于,所述处理模块,具体用于:
根据所述第一图像特征与第一类中心向量进行均方误差MSE损失计算得到第二损失值;
或者,
根据所述第一图像特征与第一类中心向量进行平均绝对值误差L1损失计算得到第二损失值;
或者,
根据所述第一图像特征与第一类中心向量进行L1-smooth损失计算得到第二损失值。
18.一种图像识别装置,其特征在于,包括:
获取模块,用于获取待处理图像;
处理模块,用于调用图像识别模型对所述待处理图像进行识别处理,以得到所述待处理图像的图像类别,所述图像识别模型为采用上述权利要求1至8中任一项所述的方法训练得到的目标学生模型;
输出模块,用于输出所述待处理图像的图像类别。
19.一种计算机设备,其特征在于,包括:存储器、处理器以及总线系统;
其中,所述存储器用于存储程序;
所述处理器用于执行所述存储器中的程序,所述处理器用于根据程序代码中的指令执行权利要求1至8中任一项所述的方法;
或者,
所述处理器用于根据程序代码中的指令执行权利要求9所述的方法;
所述总线系统用于连接所述存储器以及所述处理器,以使所述存储器以及所述处理器进行通信。
20.一种计算机可读存储介质,包括指令,当其在计算机上运行时,使得计算机执行如权利要求1至8中任一项所述的方法;
或者,使得计算机执行如权利要求9所述的方法。
CN202311193895.0A 2023-09-15 2023-09-15 模型训练方法、图像识别方法、装置、设备及介质 Active CN116935188B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311193895.0A CN116935188B (zh) 2023-09-15 2023-09-15 模型训练方法、图像识别方法、装置、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311193895.0A CN116935188B (zh) 2023-09-15 2023-09-15 模型训练方法、图像识别方法、装置、设备及介质

Publications (2)

Publication Number Publication Date
CN116935188A CN116935188A (zh) 2023-10-24
CN116935188B true CN116935188B (zh) 2023-12-26

Family

ID=88382944

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311193895.0A Active CN116935188B (zh) 2023-09-15 2023-09-15 模型训练方法、图像识别方法、装置、设备及介质

Country Status (1)

Country Link
CN (1) CN116935188B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117132430B (zh) * 2023-10-26 2024-03-05 中电科大数据研究院有限公司 一种基于大数据和物联网的校园管理方法及装置
CN117576535B (zh) * 2024-01-15 2024-06-25 腾讯科技(深圳)有限公司 一种图像识别方法、装置、设备以及存储介质

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114997175A (zh) * 2022-05-16 2022-09-02 电子科技大学 一种基于领域对抗训练的情感分析方法
CN115019106A (zh) * 2022-06-27 2022-09-06 中山大学 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111105008A (zh) * 2018-10-29 2020-05-05 富士通株式会社 模型训练方法、数据识别方法和数据识别装置
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114997175A (zh) * 2022-05-16 2022-09-02 电子科技大学 一种基于领域对抗训练的情感分析方法
CN115019106A (zh) * 2022-06-27 2022-09-06 中山大学 基于对抗蒸馏的鲁棒无监督域自适应图像分类方法及装置

Also Published As

Publication number Publication date
CN116935188A (zh) 2023-10-24

Similar Documents

Publication Publication Date Title
EP3940638B1 (en) Image region positioning method, model training method, and related apparatus
JP7185039B2 (ja) 画像分類モデルの訓練方法、画像処理方法及びその装置、並びにコンピュータプログラム
Wan et al. Deep learning models for real-time human activity recognition with smartphones
CN116935188B (zh) 模型训练方法、图像识别方法、装置、设备及介质
CN110599557A (zh) 图像描述生成方法、模型训练方法、设备和存储介质
CN112101329B (zh) 一种基于视频的文本识别方法、模型训练的方法及装置
CN111709398A (zh) 一种图像识别的方法、图像识别模型的训练方法及装置
CN113723378B (zh) 一种模型训练的方法、装置、计算机设备和存储介质
CN113807399A (zh) 一种神经网络训练方法、检测方法以及装置
CN112419326B (zh) 图像分割数据处理方法、装置、设备及存储介质
CN114722937A (zh) 一种异常数据检测方法、装置、电子设备和存储介质
WO2023072175A1 (zh) 点云数据的处理方法、神经网络的训练方法以及相关设备
CN114328906A (zh) 一种多级类目的确定方法、模型训练的方法以及相关装置
CN113822427A (zh) 一种模型训练的方法、图像匹配的方法、装置及存储介质
CN117576535B (zh) 一种图像识别方法、装置、设备以及存储介质
CN116975295B (zh) 一种文本分类方法、装置及相关产品
CN116958624A (zh) 指定材质的识别方法、装置、设备、介质及程序产品
CN116955707A (zh) 内容标签的确定方法、装置、设备、介质及程序产品
Shi et al. Cloud-assisted mood fatigue detection system
CN113762046A (zh) 图像识别方法、装置、设备以及存储介质
CN113569043A (zh) 一种文本类别确定方法和相关装置
CN118035945B (zh) 一种标签识别模型的处理方法和相关装置
CN117854156B (zh) 一种特征提取模型的训练方法和相关装置
CN117011650B (zh) 一种图像编码器的确定方法及相关装置
CN117373093A (zh) 基于人工智能的图像识别方法、装置、设备以及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant