CN116977763A - 模型训练方法、装置、计算机可读存储介质及计算机设备 - Google Patents
模型训练方法、装置、计算机可读存储介质及计算机设备 Download PDFInfo
- Publication number
- CN116977763A CN116977763A CN202211701705.7A CN202211701705A CN116977763A CN 116977763 A CN116977763 A CN 116977763A CN 202211701705 A CN202211701705 A CN 202211701705A CN 116977763 A CN116977763 A CN 116977763A
- Authority
- CN
- China
- Prior art keywords
- image
- feature
- text
- loss
- mapping
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 198
- 238000000034 method Methods 0.000 title claims abstract description 132
- 238000003860 storage Methods 0.000 title claims abstract description 31
- 238000013507 mapping Methods 0.000 claims abstract description 183
- 238000003062 neural network model Methods 0.000 claims abstract description 164
- 230000000873 masking effect Effects 0.000 claims abstract description 57
- 238000012545 processing Methods 0.000 claims description 54
- 238000010606 normalization Methods 0.000 claims description 53
- 238000004364 calculation method Methods 0.000 claims description 14
- 238000004590 computer program Methods 0.000 claims description 12
- 238000009826 distribution Methods 0.000 claims description 12
- 239000013598 vector Substances 0.000 claims description 10
- 230000000007 visual effect Effects 0.000 description 36
- 230000008569 process Effects 0.000 description 26
- 238000010586 diagram Methods 0.000 description 13
- 230000006870 function Effects 0.000 description 10
- 230000000694 effects Effects 0.000 description 9
- 238000000605 extraction Methods 0.000 description 5
- 238000001514 detection method Methods 0.000 description 4
- 238000013473 artificial intelligence Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 238000007726 management method Methods 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000000670 limiting effect Effects 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 241000282324 Felis Species 0.000 description 1
- 241001465754 Metazoa Species 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 230000009977 dual effect Effects 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 230000036961 partial effect Effects 0.000 description 1
- 230000002829 reductive effect Effects 0.000 description 1
- 230000000452 restraining effect Effects 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 238000010187 selection method Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001360 synchronised effect Effects 0.000 description 1
- 230000007704 transition Effects 0.000 description 1
- 238000009966 trimming Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本申请公开了一种模型训练方法、装置、计算机可读存储介质及计算机设备。方法通过获取训练样本数据;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。该方法可以提升训练得到的模型的准确性和可迁移性。
Description
技术领域
本申请涉及人工智能技术领域,具体涉及一种模型训练方法、装置、计算机可读存储介质及计算机设备。
背景技术
计算机视觉技术(Computer Vision,CV)是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、OCR、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
计算机视觉技术的核心是视觉模型,为提升视觉模型的准确性,对视觉模型的训练一般需要采用大量的样本数据进行训练。而大量样本数据进行在线训练会消耗大量时间,导致模型的训练效率下降,因此本领域技术人员一般采用大量样本数据先对视觉模型进行离线的预训练,然后在具体的下游任务中再采用少量的样本数据对预训练后的视觉模型进行微调,从而得到模型精度和训练效率上的双重提升。
然而,目前对视觉模型的进行预训练的方法,训练得到的视觉模型精度还不高。
发明内容
本申请实施例提供一种模型训练方法、装置、计算机可读存储介质及计算机设备,该方法可以大大提升神经网络模型的模型精度。
本申请第一方面提供一种模型训练方法,方法包括:
获取训练样本数据,所述训练样本数据包括样本图像以及与所述样本图像对应的样本文本;
将所述样本图像输入至神经网络模型,得到所述神经网络模型输出的图像特征,并基于所述图像特征与所述样本文本的文本特征计算第一损失;
对所述样本图像进行掩码处理,得到掩码图像,并将所述掩码图像输入至所述神经网络模型,得到所述神经网络模型输出的掩码特征;
对所述掩码特征进行解码,得到预测图像特征,并基于所述文本特征将所述图像特征与所述预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;
基于所述第一映射特征与所述第二映射特征计算第二损失,并根据所述第一损失和所述第二损失对所述神经网络模型的参数进行更新。
相应的,本申请第二方面提供一种模型训练装置,装置包括:
第一获取单元,用于获取训练样本数据,所述训练样本数据包括样本图像以及与所述样本图像对应的样本文本;
第一计算单元,用于将所述样本图像输入至神经网络模型,得到所述神经网络模型输出的图像特征,并基于所述图像特征与所述样本文本的文本特征计算第一损失;
处理单元,用于对所述样本图像进行掩码处理,得到掩码图像,并将所述掩码图像输入至所述神经网络模型,得到所述神经网络模型输出的掩码特征;
映射单元,用于对所述掩码特征进行解码,得到预测图像特征,并基于所述文本特征将所述图像特征与所述预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;
更新单元,用于基于所述第一映射特征与所述第二映射特征计算第二损失,并根据所述第一损失和所述第二损失对所述神经网络模型的参数进行更新。
可选地,在一些实施例中,映射单元,包括:
归一化子单元,用于对所述图像特征以及所述预测图像特征映射到同一图像特征空间,并对映射得到的特征进行归一化处理,得到所述图像特征对应的第一归一化特征以及所述预测图像特征对应的第二归一化特征;
第一映射子单元,用于基于所述文本特征将所述第一归一化特征与所述第二归一化特征映射到文本空间,得到所述图像特征对应的第一映射特征以及所述预测图像特征对应的第二映射特征。
可选地,在一些实施例中,映射子单元,包括:
第一映射模块,用于以所述文本特征作为基向量,将所述第一归一化特征映射为在文本空间上的概率分布,得到第一映射特征;
第二映射模块,用于以所述文本特征作为基向量,将所述第二归一化特征映射为在所述文本空间上的概率分布,得到第二映射特征。
可选地,在一些实施例中,第一计算单元,包括:
第二映射子单元,用于将所述图像特征与所述样本文本的文本特征映射到同一特征空间,得到图像映射特征以及文本映射特征;
计算子单元,用于基于所述图像映射特征以及所述文本映射特征计算第一损失。
可选地,在一些实施例中,计算子单元,包括:
第一处理模块,用于对所述图像映射特征进行归一化处理,得到第三归一化特征;
第二处理模块,用于对所述文本映射特征进行归一化处理,得到第四归一化特征;
计算模块,用于根据所述第三归一化特征以及所述第四归一化特征计算第一损失。
可选地,在一些实施例中,计算模块,包括:
第一计算子模块,用于根据所述第三归一化特征与所述第四归一化特征计算所述样本图像对所述样本文本的第一对比学习损失;
第二计算子模块,用于根据所述第三归一化特征与所述第四归一化特征计算所述样本文本对所述样本图像的第二对比学习损失;
第三计算子模块,用于计算所述第一对比学习损失与所述第二对比学习损失的均值,得到第一损失。
可选地,在一些实施例中,处理单元,包括:
掩码子单元,用于对所述样本图像进行随机掩码,得到掩码图像;
编码子单元,用于基于所述神经网络模型对所述掩码图像进行图像编码,得到所述样本图像的掩码特征。
可选地,在一些实施例中,更新单元,包括:
获取子单元,用于获取所述第一损失的第一权重系数以及所述第二损失的第二权重系数;
处理子单元,用于基于所述第一权重系数与所述第二权重系数对所述第一损失和所述第二损失进程加权处理,得到目标损失;
更新子单元,用于基于所述目标损失对所述神经网络模型的参数进行更新。
可选地,在一些实施例中,更新子单元,包括:
确定模块,用于基于所述目标损失确定反传梯度;
第三处理模块,用于根据所述反传梯度进行梯度反传处理,以对所述神经网络模型的参数进行更新。
本申请第三方面提供一种模型训练方法,方法包括:
获取目标任务对应的目标训练样本数据,所述目标训练样本数据包括目标样本图像以及所述目标样本图像对应的标签数据;
将所述目标样本图像输入至神经网络模型中进行预测,得到预测数据,所述神经网络模型为根据第一方面提供的模型训练方法训练得到的神经网络模型;
根据所述预测数据与所述标签数据计算预测损失;
基于所述预测损失对所述神经网络模型的参数进行调整,得到所述目标任务对应的目标神经网络模型。
相应的,本申请第四方面提供了一种模型训练装置,装置包括:
第二获取单元,用于获取目标任务对应的目标训练样本数据,所述目标训练样本数据包括目标样本图像以及所述目标样本图像对应的标签数据;
预测单元,用于将所述目标样本图像输入至神经网络模型中进行预测,得到预测数据,所述神经网络模型为第一方面提供的模型训练方法训练得到的神经网络模型;
第二计算单元,用于根据所述预测数据与所述标签数据计算预测损失;
调整单元,用于基于所述预测损失对所述神经网络模型的参数进行调整,得到所述目标任务对应的目标神经网络模型。
本申请第五方面还提供一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行本申请第一方面或第三方面所提供的模型训练方法中的步骤。
本申请第六方面提供一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可以在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本申请第一方面或第三方面所提供的模型训练方法中的步骤。
本申请第七方面提供一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现第一方面或第三方面所提供的模型训练方法中的步骤。
本申请实施例提供的模型训练方法,通过获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
以此,本申请提供的模型训练方法,通过计算图像文本对比学习的第一损失,以及计算在语义空间中图像掩码学习重建的第二损失,然后基于第一损失和第二损失对神经网络模型的训练过程进行约束,从而实现掩码图像重建学习和图像文本对比学习两种对视觉模型进行预训练的范式进行有效结合。如此可以大大提升对视觉模型的预训练效果,提升预训练得到的视觉模型的精度和可迁移性,进而可以大大提升预训练后的模型应用到下游任务中的准确性。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请中模型训练的一个场景示意图;
图2为图像自监督对比学习方法对应的训练过程示意图;
图3为图像文本对比学习方法的训练过程示意图;
图4为为使用掩码图像模型进行预训练的流程示意图;
图5是本申请提供的模型训练方法的一个流程示意图;
图6是本申请提供的模型训练方法的另一个流程示意图;
图7是本申请提供的模型训练方法的框架流程示意图;
图8是本申请提供的模型训练方法的又一流程示意图;
图9是本申请提供的模型训练装置的结构示意图;
图10是本申请提供的模型训练装置的另一结构示意图;
图11是本申请提供的计算机设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例提供一种模型训练方法、装置、计算机可读存储介质及计算机设备。其中,该模型训练方法可以使用于模型训练装置中。该模型训练装置可以集成在计算机设备中,该计算机设备可以是终端也可以是服务器。其中,终端可以为手机、平板电脑、笔记本电脑、智能电视、穿戴式智能设备、个人计算机(PC,Personal Computer)以及车载终端等设备。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、网络加速服务(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。其中,服务器可以为区块链中的节点。
请参阅图1,为本申请提供的模型训练方法的一场景示意图。如图所示,计算机设备A获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
需要说明的是,图1所示的模型训练场景示意图仅仅是一个示例,本申请实施例描述的模型训练场景是为了更加清楚地说明本申请的技术方案,并不构成对于本申请提供的技术方案的限定。本领域普通技术人员可知,随着模型训练场景演变和新业务场景的出现,本申请提供的技术方案对于类似的技术问题,同样适用。
基于上述实施场景以下分别进行详细说明。
在相关技术中,如何学习具有可迁移能力的视觉表征是视觉模型预训练的研究重点。可迁移的视觉表征往往意味着模型可以在下游任务上进行快速的微调。目前被广泛证明具有较好的可迁移及可拓展性的预训练方法主要有两种:第一种为使用自然语言监督的图像文本对比学习方法;第二种为使用掩码图像模型进行预训练的图像自监督预训练方法。图像文本对比学习预训练因为使用了自然语言作为视觉特征的监督信号,所学习得到的视觉特征天然地具有和文本特征对齐的性质,可以进行零样本的分类任务;使用掩码图像模型进行预训练的方法,自然的具备着比较好的在下游任务上的可迁移性。
其中,如图2所示,为图像自监督对比学习方法对应的训练过程示意图。如图所示,首先将一张样本图像进行两种类型的样本增强,得到增强图像1和增强图像2,然后采用神经网络模型分别提取这两张增强图像的图像特征,得到图像特征1和图像特征2。然后通过增大同一图像的不同增强图像的图像特征之间的相似度,并减小不同图像的图像特征之间的相似度,来构建对比学习损失,从而使得神经网络模型可以获得较好的图像特征的提取能力。进而在下游任务中展现出较好的特征可判别性。图像自监督对比学习的方法主要有:SimCLR,MoCo,BYOL,Dino等。
其中,如图3所示,为图像文本对比学习方法的训练过程示意图。如图所示,训练样本为图像文本对,没对图像文本对包括样本图像以及与样本图像对应的样本文本。采用文本编码器对样本文本进行文本特征提取,再采用待训练的神经网络模型对样本图像进行图像特征的提取。然后将提取到的图像特征和文本特征在同一特征空间中进行特征对齐,分别得到图像特征对应的对齐特征1以及文本特征对应的对齐特征2。然后通过增大对应的样本文本与样本图像之间的对齐特征的相似度,降低不对应的样本文本与样本图像之间的对齐特征的相似度,来构建对比学习损失,进而对神经网络模型进行训练。受益于自然语言的监督信号,图像编码器和文本编码器提取的图像和文本特征在同一个特征空间中进行了对齐,使得该网络可以天生的进行零样本的视觉理解,同时对少样本的场景具备着很好的可迁移性和鲁棒性。图像文本对比学习的视觉预训练的代表方法有:CLIP,ALIGN等。
如图4所示,为使用掩码图像模型进行预训练的流程示意图。如图所示,该方法遵循“先遮罩,再重建”的逻辑范式,具体而言,先对样本图像进行随机掩码,得到掩码图像。即通过随机选取的方法,选取部分图像的子块进行随机遮罩,遮罩之后的掩码图像输入至待训练的神经网络模型中进行特征提取,得到掩码图像特征。然后再采用解码器对掩码图像特征进行解码以重建输入图像,得到重建图像。进一步地,可以通过增大重建图像和样本图像之间的相似度来构建重建损失,并基于重建损失来对神经网络模型进行训练。模型在预训练的过程中对图像子块之间的关系进行建模,从而得到了比较好的特征提取能力,在下游任务的上展现出了较好的可迁移性。图像掩码建模预训练的代表方法有:MAE,SimMIM,BEiT,iBOT等。
目前,上述方案在各自领域都能展现出一定的预训练效果,而本申请将在此基础上进一步优化,提出一种将图像文本对比学习方法和掩码图像模型训练方法进行有效融合,以进一步提升视觉模型的预训练效果的模型训练方法。下面将对该方法进行详细描述。
实施例一
本申请实施例将从模型训练装置的角度进行描述,该模型训练装置可以集成在计算机设备中。其中,计算机设备可以是终端也可以是服务器。其中,终端可以为手机、平板电脑、笔记本电脑、智能电视、穿戴式智能设备、个人计算机(PC,Personal Computer)以及车载终端等设备。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、网络加速服务(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。如图5所示,为本申请提供的模型训练方法的流程示意图,该方法包括:
步骤101,获取训练样本数据。
其中,在本申请实施例中,用于对视觉模型进行预训练的样本数据可以为从网络上爬取的大量的图像文本对。其中,此处视觉模型便可以为本申请中需要进行训练的神经网络模型。其中,图像文本对中包括了样本图像以及与样本图像对应的样本文本。其中,与样本图像对应的样本文本具体可以为样本图像的标签数据。例如,当样本图像为一只猫时,样本文本可以为动物、猫科、蓝猫等等。当然,样本文本还可以为对样本图像的更为细致的描述数据。
获取训练样本数据,可以为获取一组训练样本数据,即获取前述一组图像文本对,该一组图像文本对包括一张样本图像以及与该样本图像对应的样本文本。或者,也可以为获取一批训练样本数据,即获取多组图像文本对。一批训练样本数据中包括了多张样本图像以及与每张训练样本图像对应的样本文本。在采用训练样本数据对神经网络模型进行训练时,可以逐组获取图像文本对来对神经网络模型进行训练;也可以逐批获取图像文本对来对神经网络模型进行训练。
步骤102,将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失。
其中,在本申请提供的模型训练方法中,在对模型进行训练的过程中涉及到三个网络模块,分别为图像编码模块、文本编码模块以及图像解码模块。其中,图像编码模块为本申请提供的模型训练方法所需要训练的网络模块,即本申请提供的模型训练方法即是为了对图像编码模块进行预训练,然后将预训练后的图像编码模块应用到下游任务中。文本编码模块和图像解码模块只参与训练过程,而不会应用到下游任务中。
其中,本申请中待训练的神经网络模型(图像编码模块)、图像解码模块以及文本编码模块都可以为由一些列的transformer模块堆叠而成。其中,transformer模块具体为一个基于自注意力机制的深度学习模型。在人工智能领域中,transformer已经是一个使用较为普遍的模型结构,此处不再对此进行详细介绍。
当获取到样本图像以及样本图像对应的样本文本后,便可以先采用文本编码模块对样本文本进行编码,得到文本特征。同时可以采用图像编码模块对样本图像进行图像编码,得到图像特征。在提取到文本特征和图像特征后,可以先基于提取到的文本特征和图像特征计算第一损失,其中此处第一损失具体可以为前述介绍的图像文本对比学习的对比损失。
其中,在一些实施例中,基于图像特征与样本文本的文本特征计算第一损失,包括:
1、将图像特征与样本文本的文本特征映射到同一特征空间,得到图像映射特征以及文本映射特征;
2、基于图像映射特征以及文本映射特征计算第一损失。
其中,在本申请实施例中,在采用图像编码模块对样本图像进行编码得到图像特征,以及采用文本编码模块对样本文本进行编码得到文本特征后。可以进一步对图像特征和文本特征进行特征映射,以将图像特征和文本特征映射到同一特征空间中,分别得到图像映射特征以及文本映射特征。
其中,将图像特征和文本特征映射到同一特征空间中,是为了将图像特征和文本特征在同一特征空间中进行对齐,从而方便进行对比学习。具体地,将图像特征和文本特征映射到同一特征空间中,可以为将图像特征映射到文本特征对应的特征空间中,也可以为将文本特征映射到图像特征对应的空间中,又或者可以为将图像特征和文本特征都映射到一个特定的特征空间中,该特定的特征空间可以既不是文本特征对应的特征空间,也不是图像特征对应的特征空间。
在将图像特征和文本特征在同一特征空间中进行特征对齐,得到图像映射特征和文本映射特征后,便可以进一步基于图像映射特征和文本映射特征来计算对比损失,即计算前述第一损失。
可选地,在一些实施例中,基于图像映射特征以及文本映射特征计算第一损失,包括:
2.1、对图像映射特征进行归一化处理,得到第三归一化特征;
2.2、对文本映射特征进行归一化处理,得到第四归一化特征;
2.3、根据第三归一化特征以及第四归一化特征计算第一损失。
其中,在对图像特征和文本特征在同一特征空间进行特征对齐,得到图像映射特征和文本映射特征后,可以进一步对图像映射特征和文本映射特征进行归一化处理,得到第三归一化特征和第四归一化特征。然后再根据第三归一化特征和第四归一化特征来计算对比损失,如此可以进一步提升计算得到的对比损失的准确性。
其中,可选地,在一些实施例中,根据第三归一化特征以及第四归一化特征计算第一损失,包括:
2.3.1、根据第三归一化特征与第四归一化特征计算样本图像对样本文本的第一对比学习损失;
2.3.2、根据第三归一化特征与第四归一化特征计算样本文本对样本图像的第二对比学习损失;
2.3.3、计算第一对比学习损失与第二对比学习损失的均值,得到第一损失。
其中,在本申请实施例中,样本图像和样本文本对应的归一化特征之间的对比损失具体可以包括两种:图像对文本的对比学习损失以及文本对图像之间的对比学习损失。即可以根据第三归一化特征和第四归一化特征计算样本图像对样本文本的第一对比学习损失;还可以根据第三归一化特征和第四归一化特征计算样本文本对样本图像的第二对比学习损失。
进一步地,在计算得到第一对比学习损失和第二对比学习损失后,可以对第一对比学习损失和第二对比学习损失进行求均值的方法来计算得到最终的对比损失,即得到前述第一损失。如此可以得到更为准确的图像文本对比学习损失。
步骤103,对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征。
其中,在本申请实施例中,在获取到样本图像以及样本文本后,除了基于样本图像和样本文本进行特征提取以及对比损失计算外,还可以进一步对样本图像进行掩码处理,得到掩码图像。然后对掩码图像进行特征提取,得到掩码特征。
其中,在一些实施例中,对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征,包括:
1、对样本图像进行随机掩码,得到掩码图像;
2、基于神经网络模型对掩码图像进行图像编码,得到样本图像的掩码特征。
其中,对样本图像的掩码处理,具体可以为随机掩码处理。例如随机对样本图像中的图像子块进行随机遮罩,得到掩码图像。掩码图像可以包括未被遮罩的图像子块。然后,再采用神经网络模型对掩码图像进行图像编码,得到样本图像的掩码特征。其中,此处神经网络模型具体为前述图像编码器,与对样本图像进行编码的图像编码器为同一图像编码器,即都为本申请提供的模型训练方法中需要进行训练的神经网络模型、
步骤104,对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征。
其中,在对样本图像进行随机掩码得到掩码图像,并采用图像编码器对掩码图像进行特征提取得到掩码特征后,便可以进一步采用前述图像解码器对该掩码特征进行解码,得到预测图像特征。
具体地,图像解码器的输入可以为前述图像编码器输出的掩码特征以及可学习的掩码令牌,图像解码器的解码过程具体可以为预测被随机遮罩的部分图像子块的特征。然后得到重建图像特征,即上述预测图像特征。
在本申请实施例中,在得到图像解码器输出的预测图像特征后,便可以将图像编码器对样本图像进行图像编码得到的图像特征以及此处的预测图像特征一同映射到文本空间中,从而实现在文本控件中的掩码重建学习。通过将重建的空间由视觉空间迁移至文本语义空间,避免了学习过渡聚焦于低级的视觉信息,从而可以大大提升对视觉模型的预训练效果。具体地,可以基于文本特征将图像特征映射到文本空间中,得到第一映射特征;以及基于文本特征将预测图像特征映射到文本空间中,得到第二映射特征。
其中,在一些实施例中,基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征,包括:
1、对图像特征以及预测图像特征映射到同一图像特征空间,并对映射得到的特征进行归一化处理,得到图像特征对应的第一归一化特征以及预测图像特征对应的第二归一化特征;
2、基于文本特征将第一归一化特征与第二归一化特征映射到文本空间,得到图像特征对应的第一映射特征以及预测图像特征对应的第二映射特征。
其中,在本申请实施例中,将图像特征和预测图像特征映射到文本空间中的具体过程,可以为先将图像特征和预测图像特征映射到同一图像特征空间中,并对映射得到的特征进行归一化处理,得到图像特征对应的第一归一化特征以及预测图像对应的第二归一化特征。
然后,再基于文本特征将第一归一化特征和第二归一化特征映射到文本空间中,得到图像特征对应的第一映射特征以及预测图像对应的第二映射特征。
其中,在一些实施例中,基于文本特征将第一归一化特征与第二归一化特征映射到文本空间,得到图像特征对应的第一映射特征以及预测图像特征对应的第二映射特征,包括:
2.1、以文本特征作为基向量,将第一归一化特征映射为在文本空间上的概率分布,得到第一映射特征;
2.2、以文本特征作为基向量,将第二归一化特征映射为在文本空间上的概率分布,得到第二映射特征。
其中,在本申请实施例中,基于文本特征将第一归一化特征和第二归一化特征映射到文本空间中的具体过程,可以为以文本特征作为基向量,将第一归一化特征和第二归一化特征硬核为在文本控件上的概率分布,从而得到图像特征对应的第一映射特征以及预测图像特征对应的第二映射特征。
步骤105,基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
其中,在将图像特征和预测图像特征映射到文本空间中得到第一映射特征以及第二映射特征后,便可以基于第一映射特征与第二映射特征计算第二损失。此处第二损失为在文本空间中的掩码重建损失。然后,基于前述第一损失,即图像文本对比学习损失,以及此处在文本空间中的掩码重建损失来对神经网络模型的训练过程进行约束,从而对神经网络模型的参数进行更新调整。
具体地,在一些实施例中,根据第一损失和第二损失对神经网络模型的参数进行更新,包括:
1、获取第一损失的第一权重系数以及第二损失的第二权重系数;
2、基于第一权重系数与第二权重系数对第一损失和第二损失进程加权处理,得到目标损失;
3、基于目标损失对神经网络模型的参数进行更新。
其中,在本申请实施例中,基于第一损失和第二损失对神经网络模型进行更新的过程,具体可以先根据第一损失和第二损失确定对神经网络模型的训练过程进行约束的最终的目标损失。具体地,可以根据图像文本对比学习损失和文本空间的掩码图像重建损失的重要程度来确定两个损失的权重系数,然后再根据两者的权重系数进行加权得到目标损失。
具体地,可以先获取第一损失的第一权重系数以及获取第二损失的第二权重系数。然后基于第一权重系数和第二权重系数对第一损失和第二损失进行加权处理,得到目标损失。进一步地,可以基于目标损失对神经网络模型的参数进行更新。
进一步地,在一些实施例中,基于目标损失对神经网络模型的参数进行更新,包括:
3.1、基于目标损失确定反传梯度;
3.2、根据反传梯度进行梯度反传处理,以对神经网络模型的参数进行更新。
其中,在本申请实施例中,基于目标损失对神经网络模型的参数进行更新的具体过程,可以为采用梯度下降法对神经网络模型的参数进行更新。具体可以先根据目标损失确定反传梯度,然后根据确定的反传梯度进行梯度反传处理,以对神经网络模型的参数进行更新。
其中,本申请的方案具体可以为采用一组图像文本对对模型参数进行的一次更新。在一些实施例中,在对模型参数进行更新后,可以进一步获取下一组图像文本对,或者获取下一个批次的多组图像文本对来对神经网络模型的模型参数进行循环的迭代更新,直到神经网络模型的模型参数收敛,得到训练后的神经网络模型,即完成对视觉模型的预训练。
根据上述描述可知,本申请实施例提供的模型训练方法,方法通过获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
以此,本申请提供的模型训练方法,通过计算图像文本对比学习的第一损失,以及计算在语义空间中图像掩码学习重建的第二损失,然后基于第一损失和第二损失对神经网络模型的训练过程进行约束,从而实现掩码图像重建学习和图像文本对比学习两种对视觉模型进行预训练的范式进行有效结合。如此可以大大提升对视觉模型的预训练效果,提升预训练得到的视觉模型的精度和可迁移性,进而可以大大提升预训练后的模型应用到下游任务中的准确性。
实施例二
本申请还提供了一种模型训练方法,该方法可以使用于计算机设备中,该计算机设备可以为终端或服务器中。如图6所示,为本申请提供的模型训练方法的另一流程示意图,方法具体包括:
步骤201,计算机设备获取样本图像以及样本文本。
其中,计算机设备获取的样本图像与样本文本可以为存在对应关系的样本图像和样本文本。样本图像和样本文本可以为从网络上爬取的样本图像和样本文本,也可以为从训练数据集中获取的样本图像和样本文本。样本图像和样本文本的数量可以为一组也可以为一个批次(batch)。
步骤202,计算机设备对样本图像进行随机掩码,得到掩码图像。
对于任意一张样本图像I,可以生成对应的掩码图像具体地,可以采用随机掩码的方式对样本图像进行掩码操作,生成对应的掩码图像。随机掩码操作,可以为对图像中包括的多个子块进行随机遮罩,未被遮罩的子块便构成了掩码图像。
步骤203,计算机设备将掩码图像与样本图像分别输入至待训练的视觉模型,得到视觉模型输出的掩码特征和图像特征。
在本申请实施例中,提供了图像编码器(Vision Encoder,V-Enc)、图像解码器(Vision Decoder,V-Dec)以及文本编码器(Language Encoder,L-Enc)。图像编码器、图像解码器以及文本编码器都可以为多个transformer堆叠而成的结构。此处图像编码器便为本申请提供的模型训练方法中待训练的视觉模型。
在对样本图像进行掩码操作得到掩码图像后,便可以采用待训练的视觉模型(图像编码器)对掩码图像以及样本图像进行图像编码。具体可以如下表示:
V-Enc(I)={fk|k∈[1,N]},
其中,k表示图像子块(patch)的索引,N表示图像子块的总数。图像特征便为多个图像子块特征的集合,因此也可以称为图像patch特征。M表示被随机遮罩的图像子块索引的集合。
步骤204,计算机设备将样本文本输入至文本编码器,得到输出的文本特征。
其中,计算机设备在将样本图像和掩码图像输入至图像编码器进行编码的同时,可以将样本文本T也输入至文本编码器L-Enc中进行编码,得到文本特征。具体表示如下:
L-Enc(T)=h。
步骤205,计算机设备基于图像特征和文本特征计算对比损失。
其中,计算机设备在获得样本图像的图像特征,即前述图像patch特征和文本特征后,便可以进行图像文本对比学习损失的计算。具体地,可以先将图像特征和文本特征映射到同一特征空间中,并对映射后的特征进行归一化处理。具体表示如下:
zT=||φ(h)||
其中,θ表示图像特征的特征映射器,对图像特征进行空间变换,将图像特征映射到特定空间中。φ表示文本特征的特征映射器,将文本特征映射到特定空间中。表示样本图像经过图像编码器后的多个图像patch特征进行全局平均的图像全局表征。||.||表示归一化操作,具体表示对特征进行L2归一化操作。
进一步地,可以基于特征映射并归一化处理后的图像特征和文本特征进行对比损失的计算,具体表示如下:
其中,B表示批(batch)的大小,σ为一个可学习的网络权重,<.>表示计算图像和文本特征之间的余弦相似度。表示图像对文本的对比学习损失,/>表示文本对图像的对比学习损失。
最终的图像文本对比学习损失表示如下:
步骤206,计算机设备将掩码特征输入至图像解码器进行解码预测,得到预测图像特征。
进一步地,计算机设备可以将图像编码器对掩码图像进行编码得到的掩码特征输入至图像解码器V-Dec中进行解码预测,得到预测图像特征。其中预测图像特征中同样包括了多个图像子块的特征,具体可以表示为
步骤207,计算机设备将预测图像特征与图像特征映射到文本空间中,得到图像映射特征和预测图像映射特征。
进一步地,计算机设备可以将图像特征和预测图像特征先映射到同一图像空间中,然后再进行归一化处理。具体可以表示如下:
其中,θ表示图像特征的映射器,对图像特征进行空间变换,将图像特征映射到特定空间中。||.||表示归一化操作,具体表示对特征进行L2归一化操作。k表示在一张图像中的子块索引。
进一步地,可以将特征映射并归一化后的图像特征以及预测图像特征再映射到文本空间中。具体地,可以以一个batch内的文本特征作为基向量,将图像特征和预测图像特征映射为在文本空间上的概率分布,具体表示如下:
其中,τ1和τ2为可调节的超参数。在本申请实施例中,具体分别可以设置为0.04和0.1。
步骤208,计算机设备基于图像映射特征和预测图像映射特征计算重建损失。
通过上述方式,图像的特征被转化为了在文本语义控件下的概率分布,重建的目标为最小化p和q之间的分布差异。具体可以采用相对熵(KL散度)方式来构建重建损失,公式如下:
其中,sg表示梯度截断操作。
步骤209,计算机设备根据重建损失和对比损失计算目标损失。
其中,在构建得到重建损失和对比损失后,可以进一步根据重建损失和对比损失计算目标损失。具体地,可以采用如下公式计算目标损失:
其中,λ1和λ2表示损失函数的权重,具体可以设置为1.0和0.5。
步骤210,计算机设备根据目标损失对视觉模型进行训练,得到训练后的视觉模型。
其中,在构建好目标损失后,便可以基于目标损失对视觉模型进行训练,得到训练后模型。
如图7所示,为本申请提供的模型训练方法的流程示意图。如图所示,当获取到样本图像以及样本图像对应的样本文本后,可以先采用文本编码器对样本文本进行编码得到文本特征。然后可以对样本图像进行图像掩码操作,得到掩码图像,再采用图像编码器对样本图像以及掩码图像进行编码,得到图像特征以及掩码特征。然后将掩码特征输入至图像解码器进行解码得到预测图像特征。在得到上述特征后,可以将文本特征和图像特征映射到同一特征空间中并进行归一化处理,然后基于归一化处理后得到的特征计算对比损失。另一方面,可以将图像特征和预测图像特征映射到同一图像空间中并进行归一化处理,然后将归一化处理后的特征进一步映射到文本空间中,并基于映射到文本空间中的特征计算重建损失。然后基于对比损失和重建损失构建目标损失,再基于目标损失对图像编码器的参数进行迭代更新,实现对视觉模型的预训练。
如下表1所示,为在某一具体下游任务(分类任务)中,本申请提供的模型训练方法训练得到的预训练模型与现有技术中的基线模型训练方法在某一特定的数据集(例如LAION-20M)中进行预训练得到的模型的精度对比结果。
方法 | 预训练数据集 | 预训练迭代数 | 线性探测 | 端到端微调 |
SimCLR | LAION-20M | 25 | 51.7 | 81.3 |
MAE | LAION-20M | 25 | 44.3 | 82.1 |
CLIP | LAION-20M | 25 | 67.8 | 82.7 |
SLIP | LAION-20M | 25 | 70.1 | 82.6 |
MAE+CLIP | LAION-20M | 25 | 64.5 | 82.9 |
本申请模型 | LAION-20M | 25 | 71.5 | 83.3 |
表1线性探测与端到端微调精度对比表
如表1所示,本申请提供的模型训练方法预训练得到的视觉模型在线性探测和端到端微调中都获得了相对于基线模型更高的精度。
根据上述描述可知,本申请实施例提供的模型训练方法,方法通过获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
以此,本申请提供的模型训练方法,通过计算图像文本对比学习的第一损失,以及计算在语义空间中图像掩码学习重建的第二损失,然后基于第一损失和第二损失对神经网络模型的训练过程进行约束,从而实现掩码图像重建学习和图像文本对比学习两种对视觉模型进行预训练的范式进行有效结合。如此可以大大提升对视觉模型的预训练效果,提升预训练得到的视觉模型的精度和可迁移性,进而可以大大提升预训练后的模型应用到下游任务中的准确性。
实施例三
本申请实施例还提供了一种模型训练方法,如图8所示,为本申请提供的模型训练方法的又一流程示意图,方法包括:
步骤301,获取目标任务对应的目标训练样本数据。
其中,目标任务可以为图像分类任务、图像识别任务、图像处理任务等等任意任务。目标任务的训练样本数据包括目标样本图像以及目标样本图像对应的标签数据。
步骤302,将目标样本图像输入至神经网络模型中进行预测,得到预测数据。
其中,本实施例中的神经网络模型具体可以为实施例一或者实施例二中提供的模型训练方法训练得到的神经网络模型,即实施例以或者实施例二中的模型训练方法是对视觉模型的预训练,而本申请将在实施例一或者实施例二提供的模型训练方法预训练得到的视觉模型的基础上采用目标任务对应的目标训练样本数据来对预训练得到的视觉模型进行进一步的微调以适用到具体的下游任务,例如图像分类任务、图像识别任务中。
因此,在获取到目标样本图像和对应的标签数据后,可以将目标样本图像输入到神经网络模型中进行预测,得到预测数据。
步骤303,根据预测数据与标签数据计算预测损失。
其中,在将目标样本图像输入到神经网络模型中得到输出的预测数据后,便可以进一步基于预测数据与标签数据之间的差异计算出预测损失。
步骤304,基于预测损失对神经网络模型的参数进行调整,得到目标任务对应的目标神经网络模型。
进一步地,可以基于预测损失对神经网络模型的参数进行进一步的微调,得到目标任务对应的目标神经网络模型。由于实施例一或实施例二提供的神经网络模型进行预训练得到的视觉模型具有更好的模型精度,因此基于该预训练模型进行微调得到的下游任务模型也可以具有更好的模型效果。
实施例四
为了更好地实施以上模型训练方法,本申请实施例还提供一种模型训练装置,该模型训练装置可以集成在终端或服务器中。
例如,如图9所示,为本申请实施例提供的模型训练装置的结构示意图,装置可以包括第一获取单元401、第一计算单元402、处理单元403、映射单元404以及更新单元405,如下:
第一获取单元401,用于获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;
第一计算单元402,用于将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;
处理单元403,用于对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;
映射单元404,用于对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;
更新单元405,用于基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
可选地,在一些实施例中,映射单元,包括:
归一化子单元,用于对图像特征以及预测图像特征映射到同一图像特征空间,并对映射得到的特征进行归一化处理,得到图像特征对应的第一归一化特征以及预测图像特征对应的第二归一化特征;
第一映射子单元,用于基于文本特征将第一归一化特征与第二归一化特征映射到文本空间,得到图像特征对应的第一映射特征以及预测图像特征对应的第二映射特征。
可选地,在一些实施例中,映射子单元,包括:
第一映射模块,用于以文本特征作为基向量,将第一归一化特征映射为在文本空间上的概率分布,得到第一映射特征;
第二映射模块,用于以文本特征作为基向量,将第二归一化特征映射为在文本空间上的概率分布,得到第二映射特征。
可选地,在一些实施例中,第一计算单元,包括:
第二映射子单元,用于将图像特征与样本文本的文本特征映射到同一特征空间,得到图像映射特征以及文本映射特征;
计算子单元,用于基于图像映射特征以及文本映射特征计算第一损失。
可选地,在一些实施例中,计算子单元,包括:
第一处理模块,用于对图像映射特征进行归一化处理,得到第三归一化特征;
第二处理模块,用于对文本映射特征进行归一化处理,得到第四归一化特征;
计算模块,用于根据第三归一化特征以及第四归一化特征计算第一损失。
可选地,在一些实施例中,计算模块,包括:
第一计算子模块,用于根据第三归一化特征与第四归一化特征计算样本图像对样本文本的第一对比学习损失;
第二计算子模块,用于根据第三归一化特征与第四归一化特征计算样本文本对样本图像的第二对比学习损失;
第三计算子模块,用于计算第一对比学习损失与第二对比学习损失的均值,得到第一损失。
可选地,在一些实施例中,处理单元,包括:
掩码子单元,用于对样本图像进行随机掩码,得到掩码图像;
编码子单元,用于基于神经网络模型对掩码图像进行图像编码,得到样本图像的掩码特征。
可选地,在一些实施例中,更新单元,包括:
获取子单元,用于获取第一损失的第一权重系数以及第二损失的第二权重系数;
处理子单元,用于基于第一权重系数与第二权重系数对第一损失和第二损失进程加权处理,得到目标损失;
更新子单元,用于基于目标损失对神经网络模型的参数进行更新。
可选地,在一些实施例中,更新子单元,包括:
确定模块,用于基于目标损失确定反传梯度;
第三处理模块,用于根据反传梯度进行梯度反传处理,以对神经网络模型的参数进行更新。
具体实施时,以上各个单元可以作为独立的实体来实现,也可以进行任意组合,作为同一或若干个实体来实现,以上各个单元的具体实施可参见前面的方法实施例,在此不再赘述。
根据上述描述可知,本申请实施例提供的模型训练装置,通过第一获取单元401获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;第一计算单元402将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;处理单元403对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;映射单元404对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;更新单元405基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
以此,本申请提供的模型训练装置,通过计算图像文本对比学习的第一损失,以及计算在语义空间中图像掩码学习重建的第二损失,然后基于第一损失和第二损失对神经网络模型的训练过程进行约束,从而实现掩码图像重建学习和图像文本对比学习两种对视觉模型进行预训练的范式进行有效结合。如此可以大大提升对视觉模型的预训练效果,提升预训练得到的视觉模型的精度和可迁移性,进而可以大大提升预训练后的模型应用到下游任务中的准确性。
实施例五
为了更好地实施以上模型训练方法,本申请实施例还提供一种模型训练装置,该模型训练装置可以集成在终端或服务器中。
例如,如图10所示,为本申请实施例提供的模型训练装置的结构示意图,装置可以包括第二获取单元501、预测单元502、第二计算单元503以及调整单元504,如下;
第二获取单元501,用于获取目标任务对应的目标训练样本数据,目标训练样本数据包括目标样本图像以及目标样本图像对应的标签数据;
预测单元502,用于将目标样本图像输入至神经网络模型中进行预测,得到预测数据,神经网络模型为第一方面提供的模型训练方法训练得到的神经网络模型;
第二计算单元503,用于根据预测数据与标签数据计算预测损失;
调整单元504,用于基于预测损失对神经网络模型的参数进行调整,得到目标任务对应的目标神经网络模型。
本申请实施例还提供一种计算机设备,该计算机设备可以为终端或服务器,如图11所示,为本申请提供的计算机设备的结构示意图。具体来讲:
该计算机设备可以包括一个或者一个以上处理核心的处理单元601、一个或一个以上存储介质的存储单元602、电源模块603和输入模块604等部件。本领域技术人员可以理解,图11中示出的计算机设备结构并不构成对计算机设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。其中:
处理单元601是该计算机设备的控制中心,利用各种接口和线路连接整个计算机设备的各个部分,通过运行或执行存储在存储单元602内的软件程序和/或模块,以及调用存储在存储单元602内的数据,执行计算机设备的各种功能和处理数据。可选的,处理单元601可包括一个或多个处理核心;优选的,处理单元601可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、对象界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理单元601中。
存储单元602可用于存储软件程序以及模块,处理单元601通过运行存储在存储单元602的软件程序以及模块,从而执行各种功能应用以及模型训练。存储单元602可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能以及网页访问等)等;存储数据区可存储根据计算机设备的使用所创建的数据等。此外,存储单元602可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储单元602还可以包括存储器控制器,以提供处理单元601对存储单元602的访问。
计算机设备还包括给各个部件供电的电源模块603,优选的,电源模块603可以通过电源管理系统与处理单元601逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源模块603还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该计算机设备还可包括输入模块604,该输入模块604可用于接收输入的数字或字符信息,以及产生与对象设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
尽管未示出,计算机设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,计算机设备中的处理单元601会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储单元602中,并由处理单元601来运行存储在存储单元602中的应用程序,从而实现各种功能,如下:
获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
或者,获取目标任务对应的目标训练样本数据,目标训练样本数据包括目标样本图像以及目标样本图像对应的标签数据;将目标样本图像输入至神经网络模型中进行预测,得到预测数据,神经网络模型为根据第一方面提供的模型训练方法训练得到的神经网络模型;根据预测数据与标签数据计算预测损失;基于预测损失对神经网络模型的参数进行调整,得到目标任务对应的目标神经网络模型。
应当说明的是,本申请实施例提供的计算机设备与上文实施例中的方法属于同一构思,以上各个操作的具体实施可参见前面的实施例,在此不作赘述。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过指令来完成,或通过指令控制相关的硬件来完成,该指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。
为此,本发明实施例提供一种计算机可读存储介质,其中存储有多条指令,该指令能够被处理器进行加载,以执行本发明实施例所提供的任一种方法中的步骤。例如,该指令可以执行如下步骤:
获取训练样本数据,训练样本数据包括样本图像以及与样本图像对应的样本文本;将样本图像输入至神经网络模型,得到神经网络模型输出的图像特征,并基于图像特征与样本文本的文本特征计算第一损失;对样本图像进行掩码处理,得到掩码图像,并将掩码图像输入至神经网络模型,得到神经网络模型输出的掩码特征;对掩码特征进行解码,得到预测图像特征,并基于文本特征将图像特征与预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;基于第一映射特征与第二映射特征计算第二损失,并根据第一损失和第二损失对神经网络模型的参数进行更新。
或者,获取目标任务对应的目标训练样本数据,目标训练样本数据包括目标样本图像以及目标样本图像对应的标签数据;将目标样本图像输入至神经网络模型中进行预测,得到预测数据,神经网络模型为根据第一方面提供的模型训练方法训练得到的神经网络模型;根据预测数据与标签数据计算预测损失;基于预测损失对神经网络模型的参数进行调整,得到目标任务对应的目标神经网络模型。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
其中,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、磁盘或光盘等。
由于该计算机可读存储介质中所存储的指令,可以执行本发明实施例所提供的任一种方法中的步骤,因此,可以实现本发明实施例所提供的任一种方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
其中,根据本申请的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在存储介质中。计算机设备的处理器从存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述模型训练方法中各种可选实现方式中提供的方法。
以上对本发明实施例所提供的模型训练方法、装置、计算机可读存储介质及计算机设备进行了详细介绍,本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处,综上,本说明书内容不应理解为对本发明的限制。
Claims (15)
1.一种模型训练方法,其特征在于,所述方法包括:
获取训练样本数据,所述训练样本数据包括样本图像以及与所述样本图像对应的样本文本;
将所述样本图像输入至神经网络模型,得到所述神经网络模型输出的图像特征,并基于所述图像特征与所述样本文本的文本特征计算第一损失;
对所述样本图像进行掩码处理,得到掩码图像,并将所述掩码图像输入至所述神经网络模型,得到所述神经网络模型输出的掩码特征;
对所述掩码特征进行解码,得到预测图像特征,并基于所述文本特征将所述图像特征与所述预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;
基于所述第一映射特征与所述第二映射特征计算第二损失,并根据所述第一损失和所述第二损失对所述神经网络模型的参数进行更新。
2.根据权利要求1所述的方法,其特征在于,所述基于所述文本特征将所述图像特征与所述预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征,包括:
对所述图像特征以及所述预测图像特征映射到同一图像特征空间,并对映射得到的特征进行归一化处理,得到所述图像特征对应的第一归一化特征以及所述预测图像特征对应的第二归一化特征;
基于所述文本特征将所述第一归一化特征与所述第二归一化特征映射到文本空间,得到所述图像特征对应的第一映射特征以及所述预测图像特征对应的第二映射特征。
3.根据权利要求2所述的方法,其特征在于,所述基于所述文本特征将所述第一归一化特征与所述第二归一化特征映射到文本空间,得到所述图像特征对应的第一映射特征以及所述预测图像特征对应的第二映射特征,包括:
以所述文本特征作为基向量,将所述第一归一化特征映射为在文本空间上的概率分布,得到第一映射特征;
以所述文本特征作为基向量,将所述第二归一化特征映射为在所述文本空间上的概率分布,得到第二映射特征。
4.根据权利要求1所述的方法,其特征在于,所述基于所述图像特征与所述样本文本的文本特征计算第一损失,包括:
将所述图像特征与所述样本文本的文本特征映射到同一特征空间,得到图像映射特征以及文本映射特征;
基于所述图像映射特征以及所述文本映射特征计算第一损失。
5.根据权利要求4所述的方法,其特征在于,所述基于所述图像映射特征以及所述文本映射特征计算第一损失,包括:
对所述图像映射特征进行归一化处理,得到第三归一化特征;
对所述文本映射特征进行归一化处理,得到第四归一化特征;
根据所述第三归一化特征以及所述第四归一化特征计算第一损失。
6.根据权利要求5所述的方法,其特征在于,所述根据所述第三归一化特征以及所述第四归一化特征计算第一损失,包括:
根据所述第三归一化特征与所述第四归一化特征计算所述样本图像对所述样本文本的第一对比学习损失;
根据所述第三归一化特征与所述第四归一化特征计算所述样本文本对所述样本图像的第二对比学习损失;
计算所述第一对比学习损失与所述第二对比学习损失的均值,得到第一损失。
7.根据权利要求1所述的方法,其特征在于,所述对所述样本图像进行掩码处理,得到掩码图像,并将所述掩码图像输入至所述神经网络模型,得到所述神经网络模型输出的掩码特征,包括:
对所述样本图像进行随机掩码,得到掩码图像;
基于所述神经网络模型对所述掩码图像进行图像编码,得到所述样本图像的掩码特征。
8.根据权利要求1所述的方法,其特征在于,所述根据所述第一损失和所述第二损失对所述神经网络模型的参数进行更新,包括:
获取所述第一损失的第一权重系数以及所述第二损失的第二权重系数;
基于所述第一权重系数与所述第二权重系数对所述第一损失和所述第二损失进程加权处理,得到目标损失;
基于所述目标损失对所述神经网络模型的参数进行更新。
9.根据权利要求8所述的方法,其特征在于,所述基于所述目标损失对所述神经网络模型的参数进行更新,包括:
基于所述目标损失确定反传梯度;
根据所述反传梯度进行梯度反传处理,以对所述神经网络模型的参数进行更新。
10.一种模型训练方法,其特征在于,所述方法包括:
获取目标任务对应的目标训练样本数据,所述目标训练样本数据包括目标样本图像以及所述目标样本图像对应的标签数据;
将所述目标样本图像输入至神经网络模型中进行预测,得到预测数据,所述神经网络模型为根据权利要求1至9中任一项所述的模型训练方法训练得到的神经网络模型;
根据所述预测数据与所述标签数据计算预测损失;
基于所述预测损失对所述神经网络模型的参数进行调整,得到所述目标任务对应的目标神经网络模型。
11.一种模型训练装置,其特征在于,所述装置包括:
第一获取单元,用于获取训练样本数据,所述训练样本数据包括样本图像以及与所述样本图像对应的样本文本;
第一计算单元,用于将所述样本图像输入至神经网络模型,得到所述神经网络模型输出的图像特征,并基于所述图像特征与所述样本文本的文本特征计算第一损失;
处理单元,用于对所述样本图像进行掩码处理,得到掩码图像,并将所述掩码图像输入至所述神经网络模型,得到所述神经网络模型输出的掩码特征;
映射单元,用于对所述掩码特征进行解码,得到预测图像特征,并基于所述文本特征将所述图像特征与所述预测图像特征映射到文本空间,得到第一映射特征以及第二映射特征;
更新单元,用于基于所述第一映射特征与所述第二映射特征计算第二损失,并根据所述第一损失和所述第二损失对所述神经网络模型的参数进行更新。
12.一种模型训练装置,其特征在于,所述装置包括:
第二获取单元,用于获取目标任务对应的目标训练样本数据,所述目标训练样本数据包括目标样本图像以及所述目标样本图像对应的标签数据;
预测单元,用于将所述目标样本图像输入至神经网络模型中进行预测,得到预测数据,所述神经网络模型为根据权利要求1至9中任一项所述的模型训练方法训练得到的神经网络模型;
第二计算单元,用于根据所述预测数据与所述标签数据计算预测损失;
调整单元,用于基于所述预测损失对所述神经网络模型的参数进行调整,得到所述目标任务对应的目标神经网络模型。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行权利要求1至10中任一项所述的模型训练方法中的步骤。
14.一种计算机设备,其特征在于,包括存储器、处理器以及存储在所述存储器中并可以在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现权利要求1至10中任一项所述的模型训练方法中的步骤。
15.一种计算机程序产品,包括计算机程序/指令,其特征在于,所述计算机程序/指令被处理器执行时实现权利要求1至10中任一项所述的模型训练方法中的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211701705.7A CN116977763A (zh) | 2022-12-28 | 2022-12-28 | 模型训练方法、装置、计算机可读存储介质及计算机设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211701705.7A CN116977763A (zh) | 2022-12-28 | 2022-12-28 | 模型训练方法、装置、计算机可读存储介质及计算机设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116977763A true CN116977763A (zh) | 2023-10-31 |
Family
ID=88473730
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211701705.7A Pending CN116977763A (zh) | 2022-12-28 | 2022-12-28 | 模型训练方法、装置、计算机可读存储介质及计算机设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116977763A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117591901A (zh) * | 2024-01-17 | 2024-02-23 | 合肥中科类脑智能技术有限公司 | 绝缘子破损检测方法、装置、存储介质和电子设备 |
-
2022
- 2022-12-28 CN CN202211701705.7A patent/CN116977763A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117591901A (zh) * | 2024-01-17 | 2024-02-23 | 合肥中科类脑智能技术有限公司 | 绝缘子破损检测方法、装置、存储介质和电子设备 |
CN117591901B (zh) * | 2024-01-17 | 2024-05-03 | 合肥中科类脑智能技术有限公司 | 绝缘子破损检测方法、装置、存储介质和电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11983850B2 (en) | Image processing method and apparatus, device, and storage medium | |
CN111260653B (zh) | 一种图像分割方法、装置、存储介质和电子设备 | |
CN108920720A (zh) | 基于深度哈希和gpu加速的大规模图像检索方法 | |
CN111754532B (zh) | 图像分割模型搜索方法、装置、计算机设备及存储介质 | |
CN106383891A (zh) | 一种基于深度哈希的医学图像分布式检索方法 | |
CN113298197B (zh) | 数据聚类方法、装置、设备及可读存储介质 | |
CN114329029B (zh) | 对象检索方法、装置、设备及计算机存储介质 | |
CN114282059A (zh) | 视频检索的方法、装置、设备及存储介质 | |
CN114328988A (zh) | 多媒体数据的特征提取方法、多媒体数据检索方法及装置 | |
CN116977763A (zh) | 模型训练方法、装置、计算机可读存储介质及计算机设备 | |
CN113204674A (zh) | 基于局部-整体图推理网络的视频-段落检索方法及系统 | |
CN114358109A (zh) | 特征提取模型训练、样本检索方法、装置和计算机设备 | |
CN117437317A (zh) | 图像生成方法、装置、电子设备、存储介质和程序产品 | |
CN111260074B (zh) | 一种超参数确定的方法、相关装置、设备及存储介质 | |
CN114782209B (zh) | 一种基于社交网络拓扑图的关联用户身份识别方法 | |
CN113824989B (zh) | 一种视频处理方法、装置和计算机可读存储介质 | |
CN115457638A (zh) | 模型训练方法、数据检索方法、装置、设备及存储介质 | |
CN113361510B (zh) | 超分网络模型训练方法、装置、电子设备以及存储介质 | |
CN117036368A (zh) | 图像数据处理方法、装置、计算机设备和存储介质 | |
CN116415624A (zh) | 模型训练方法及装置、内容推荐方法及装置 | |
CN114677535A (zh) | 域适应图像分类网络的训练方法、图像分类方法及装置 | |
CN113822291A (zh) | 一种图像处理方法、装置、设备及存储介质 | |
CN114298961A (zh) | 图像处理方法、装置、设备及存储介质 | |
CN113010772A (zh) | 一种数据处理方法、相关设备及计算机可读存储介质 | |
CN117556275B (zh) | 相关度模型数据处理方法、装置、计算机设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |