CN116956014A - 分类模型的训练方法、装置、设备及存储介质 - Google Patents

分类模型的训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN116956014A
CN116956014A CN202310340644.4A CN202310340644A CN116956014A CN 116956014 A CN116956014 A CN 116956014A CN 202310340644 A CN202310340644 A CN 202310340644A CN 116956014 A CN116956014 A CN 116956014A
Authority
CN
China
Prior art keywords
training data
classification model
data
training
classification
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310340644.4A
Other languages
English (en)
Inventor
王登豹
李蓝青
赵沛霖
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202310340644.4A priority Critical patent/CN116956014A/zh
Publication of CN116956014A publication Critical patent/CN116956014A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请公开了一种分类模型的训练方法、装置、设备及存储介质,属于机器学习技术领域。方法包括:从训练集中获取n个训练数据,以及n个训练数据中每一个训练数据分别对应的标签信息;对n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据;使用分类模型对m个混合输入数据进行特征提取,得到m个混合特征信息;对m个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息;根据训练数据对应的分类标签和分类结果,对分类模型的参数进行调整,得到训练后的分类模型;对训练后的分类模型进行置信度校准,得到训练完成的分类模型。本方法既保证分类模型生成的分类结果的准确率,又使得分类结果具有较好的置信度表现。

Description

分类模型的训练方法、装置、设备及存储介质
技术领域
本申请涉及机器学习技术领域,特别涉及一种分类模型的训练方法、装置、设备及存储介质。
背景技术
分类模型用于确定输入数据所属的类别。
相关技术中,在分类模型的训练过程中,计算机设备将原始训练数据进行混合处理,得到混合数据;以及计算机设备对原始训练属于对应的分类标签进行混合,得到混合标签。计算机设备根据分类模型针对混合数据产生的预测结果和混合数据对应的混合标签,确定分类模型的训练损失,进而对分类模型的模型参数进行调整。在训练损失达到收敛的情况下,得到训练完成的分类模型。
然而,通过实验数据发现,这种方法训练得到的分类模型生成的分类结果的置信度较低。
发明内容
本申请提供了一种分类模型的训练方法、装置、设备及存储介质。所述技术方案如下:
根据本申请实施例的一个方面,提供了一种分类模型的训练方法,所述方法包括:
从训练集中获取n个训练数据,以及所述n个训练数据中每一个训练数据分别对应的标签信息,所述标签信息用于表征所述训练数据所属的类别,n为大于1的整数;
对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数;
使用分类模型对所述m个混合输入数据进行特征提取,得到m个混合特征信息;
对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息;
根据所述训练数据对应的分类标签和分类结果,对所述分类模型的参数进行调整,得到训练后的分类模型,所述训练数据对应的分类结果由所述分类模型中的分类预测层基于所述训练数据对应的特征信息得到;
对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型。
根据本申请实施例的一个方面,提供了一种分类模型的训练装置,所述装置包括:
数据获取模块,用于从训练集中获取n个训练数据,以及所述n个训练数据中每一个训练数据分别对应的标签信息,所述标签信息用于表征所述训练数据所属的类别,n为大于1的整数;
输入混合模块,用于对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数;
特征提取模块,用于使用分类模型对所述m个混合输入数据进行特征提取,得到m个混合特征信息;
特征解耦模块,用于对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息;
模型训练模块,用于根据所述训练数据对应的分类标签和分类结果,对所述分类模型的参数进行调整,得到训练后的分类模型,所述训练数据对应的分类结果由所述分类模型中的分类预测层基于所述训练数据对应的特征信息得到;
模型校准模型,用于对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型。
根据本申请实施例的一个方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机程序,所述计算机程序由所述处理器加载并执行以实现如上所述的分类模型的训练方法。
根据本申请实施例的一个方面,提供了一种计算机可读存储介质,所述存储介质中存储有计算机程序,所述计算机程序由处理器加载并执行以实现如上所述的分类模型的训练方法。
根据本申请实施例的一个方面,提供了一种计算机程序产品,所述计算机程序产品包括计算机程序,所述计算机程序存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行所述计算机程序,以实现如上所述的分类模型的训练方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
相比于训练数据进行混合产生混合训练数据,对训练数据对应的分类标签进行混合生成混合标签,并通过混合训练数据的分类结果和混合标签对分类模型的模型参数进行调整对分类模型的分类结果的置信度产生的影响。
一方面,本方法保留了原始数据混合方法中使用分类模型对混合模型模型输入进行处理的过程,丰富了分类模型处理的输入数据的类型,有助于避免训练后的分类模型生成的分类结果时伴随的过自信问题,提升分类模型预测得到的分类结果的准确性。
另一方面,本方法中通过对混合特征信息进行解耦,使得分类模型能够生成训练数据对应的分类结果,而不需要在标记端对训练数据分别对应的分类标签进行混合,混合标签也不会参与训练损失的计算过程。这样能够有效消除原有的数据混合方法对后置信度校准过程带来的负面影响,使得置信度校准后得到的训练完成的分类模型生成的分类结果具有较好的置信度表现。
附图说明
图1是本申请一个示例性实施例提供的方案实施环境的示意图;
图2是本申请一个示例性实施例提供的分类模型的训练方法的流程图;
图3是本申请一个示例性实施例提供的混合解耦方法的示意图;
图4是本申请另一个示例性实施例提供的混合解耦方法的示意图;
图5是本申请一个示例性实施例提供的分类模型的训练装置的框图;
图6是本申请一个示例性实施例提供的计算机设备的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
自然语言处理(Nature Language Processing,NLP):是计算机科学领域与人工智能领域中的一个重要方向。它研究能实现人与计算机之间用自然语言进行有效通信的各种理论和方法。自然语言处理是一门融语言学、计算机科学、数学于一体的科学。因此,这一领域的研究将涉及自然语言,即人们日常使用的语言,所以它与语言学的研究有着密切的联系。自然语言处理技术通常包括文本处理、语义理解、机器翻译、机器人问答、知识图谱等技术。
计算机视觉技术(Computer Vision,CV):是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
过自信(Over Confident):是指训练后的机器学习模型出现的过度拟合。过自信使得训练后的机器学习模型只能对于训练数据相似的输出数据具有较好的处理效果,也即训练后的机器学习模型的泛化能力较差。
欠自信(Under Confident):是指机器学习模型不能很好地从训练数据中,学习到有用的数据特征。训练后的机器学习模型针对训练数据和待预测的数据,均不能获得很好的预测效果。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如在计算机视觉,自然语言处理等方面,通过训练后的分类模型对输入数据的类别进行预测。相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
本申请实施例提供的方案涉及分类模型的训练方法。本方案中,通过对训练数据进行混合操作,得到混合输入。使用分类模型对混合输入进行特征提取得到混合特征;再对这些混合特征进行解耦得到各个训练数据分别对应的特征信息。基于训练数据的标签信息以及训练数据对应的特征信息确定出的分类预测,对分类模型进行有监督训练。通过这种方法有助于在保证模型的分类准确率的同时,提升分类模型生成的预测结果的置信度。通过这种方案训练得到的分类模型在智慧医疗、智能导航、图像识别等对模型准确度以及置信度要求较高的领域具有广泛的应用场景。
图1是本申请一个示例性实施例提供的方案实施环境的示意图。该方案实施环境可以包括:计算机设备10、终端设备20和服务器30。
计算机设备10包括但不限于个人计算机(Personal Computer,PC)、手机、平板电脑等运算和存储能力的电子设备。在一些实施例中,计算机设备10上设置有分类模型,计算机设备10通过分类模型对待分类数据进行处理,确定待分类数据对应的分类结果。可选地,分类模型的训练过程可以在计算机设备10上完成。可选地,分类模型的训练过程在除计算机设备10之外的其他设备上完成,其他设备将训练完成的分类模型发送给计算机设备10,使得计算机设备10获得分类模型。
终端设备20可以是诸如个人计算机、平板电脑、手机、可穿戴设备、智能家电、车载终端等电子设备。终端设备20上运行有目标应用程序的客户端。目标应用程序能够为用户提供待分类数据的分类功能。待分类数据的类型根据目标应用程序的功能确定,待分类数据的类型包括但不限于以下至少之一:图像数据、文本数据和音频数据。例如,目标应用程序用于辅助医生进行病情识别,待分类数据为就诊者的医疗图像。又例如,目标应用程序用于进行辅助驾驶或者自动驾驶,待分类数据为驾驶设备通过摄像设备采集到的路况图像。
此外,目标应用程序还可以是新闻类应用程序、购物类应用程序、社交类应用程序、互动娱乐类应用程序、浏览器应用程序、内容分享类应用程序、虚拟现实类应用程序、增强现实类应用程序等,本申请实施例对此不作限定。另外,对于不同的应用程序来说,其处理的待分类数据可以不同,且相应的功能也会有所不同,这都可以根据实际需求预先进行配置,本申请实施例对此不作限定。
服务器30用于为终端设备20中的目标应用程序的客户端提供后台服务。例如,服务器30可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务内容分发网络、(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器,但并不局限于此。
服务器30至少具有数据接收功能、存储和运算功能。服务器30用于获取终端设备20提供的待分类数据,并通过计算机设备10生成待分类数据的分类结果。服务器30将待分类数据的分类结果反馈给终端设备10。
在一个示例中,计算机设备10可以部署在服务器30上,计算机设备10也可以是除了服务器30之外的其他设备。在一个示例中,计算机设备10与终端设备20是同一台设备,在这种情况下,终端设备20无需与服务器进行通信,就能够自行确定待分类数据的分类结果。
图2是本申请一个示例性实施例提供的分类模型的训练方法的流程图。示例性地,该方法的执行主体可以是图1中的计算机设备10或者训练分类模型中使用的其他设备。为了叙述方便,下面将计算机设备作为执行主体,对分类模型的训练方法进行介绍说明。如图2所示,该方法可以包括如下几个步骤(210~260):
步骤210,从训练集中获取n个训练数据,以及n个训练数据中每一个训练数据分别对应的标签信息,标签信息用于表征训练数据所属的类别,n为大于1的整数。
在一些实施例中,训练集是指对分类模型训练过程中使用的数据集合。可选地,训练集中包括多个训练数据和每一个训练数据分别对应的标签信息。
在一些实施例中,训练数据是指分类模型训练过程中需要使用的数据,可选定,计算机设备将训练集中训练数据进行处理,生成分类模型的输入数据。计算机设备使用分类模型对输入数据进行处理,得到输入数据对应的分类结果。有关该过程的具体内容请参考下文实施例。
可选地,训练数据的数据类型包括以下至少之一:图像数据、文本数据和语音数据。训练数据的数据类型根据分类模型的分类功能确定,本申请在此不进行限定。
在一些实施例中,训练数据对应的标签信息用于表征训练数据所属的类别。可选地,训练集中包括至少两个属于不同类别的训练数据。
在一些实施例中,训练集中的多个训练数据对应于a个类别,某个训练数据对应a个类别中的至少一个。例如,训练集中的多个训练数据对应于4个类别,分别为类别1、类别2、类别3和类别4,训练数据a的标签信息为[0,0,0,1],则训练数据a属于类别4。当然,训练数据也可以属于a个类别中的多个。例如在多标签分类领域中,训练集中的多个训练数据对应于3个类别,分别为类别A1、类别A2和类别A3,某个训练数据b的标签信息为[0,1,1],则说明该训练数据b同时属于类别A2和类别A3。
在一些实施例中,训练集可以通过以下形式表示:
其中,Ttrain表示训练集,p表示训练集中包括的样本数量,训练集是p个训练样本xi和其对应标签信息yi的集合。标签信息yi的标签空间为Ctrain=c1,c2,…,cK,ci表示类别i,其中cK表示第K类别,K为训练集中多个训练集涉及的类别总数。
可选地,训练数据对应的标签信息通过人工标注或者机器分类的方式获得。在一个示例中,计算机设备通过网络下载现有的数据集作为分类模型的训练集。在另一个示例中,计算机设备收集工作人员标注的至少一个训练数据,得到训练集。
在一些实施例中,为了验证通过训练集对分类模型进行训练生成的训练后的分类模型的分类效果,还存在至少一个验证集,验证集中包括至少一个验证数据和验证数据分别对应的标签信息。可选地,训练集中包括的训练数据和验证集中包括的验证数据不完全重合的。
在一些实施例中,在分类模型的训练过程中,计算机设备从训练集中选择n个训练数据。可选地,n个训练数据中包括至少两个属于不同类别的训练数据。
例如,n个训练数据均属于不同的类别。假设,n等于3,n个训练数据中包括训练数据1、训练数据2和训练数据3;若训练数据1对应的标签信息为[0,0,0,1]、训练数据2对应的标签信息为[0,1,0,0]、训练数据3对应的标签信息为[1,0,0,0],则训练数据1、训练数据2和训练数据3分别属于不同的类别。
又例如,n等于4,n个训练数据中包括训练数据1、训练数据2、训练数据3和训练数据4;假设训练数据1对应的标签信息为[0,0,0,1]、训练数据2对应的标签信息为[0,1,0,0]、训练数据3对应的标签信息为[1,0,0,0]、练数据4对应的标签信息为[0,1,0,0],也即训练数据1、训练数据2和训练数据3分别属于不同的类别,训练数据2和训练数据4属于同一类别。
通过从训练集中挑选属于至少两个类别的训练数据,有助于在后续步骤基于n个训练数据进行混合后,得到包括多个类别的特征信息的混合数据(如下文中的混合模型输入)。通过这种方法有助于解决训练得到的分类模型的过自信问题,有助于提升训练后的分类模型的分类准确率。
在一个示例中,计算机设备根据训练集包括的多个训练数据分别对应的标签信息,确定n个训练数据。可选地,计算机设备从训练集中挑选标签信息不完全相同的n个训练数据。
在另一个示例中,计算机设备从训练集中随机挑选n个训练数据,由于训练集中包括属于不同类别的训练数据,通过这种方法也能大概率使得n个训练数据中包括至少两个属于不同类别的训练数据。
在一些实施例中,n个训练数据组成一个训练数据组,计算机设备使用分类模型对一个训练数据组进行同步处理,得到n个训练数据分别对应的特征信息,有关该过程的具体内容,请参考下文实施例。
可选地,计算机设备预先将训练集划分成为多个训练数据组,不同的训练数据组中包括的训练数据不重合,计算机设备每次使用分类模型确定一个训练数据组中包括的每一个训练数据分别对应的特征数据。通过这种方法,有助于提升计算机设备从训练集中确定n个训练样本的速度,简化计算机设备确定n个训练样本过程的处理逻辑。
在一个实施例中,n等于2,也即计算机设备从训练集中挑选两个训练数据x1和x2,该两个训练数据可以以训练样本对(x1,x2)的形式表示,计算机设备通过分类模型一次确定x1和x2分别对应的分类结果。
步骤220,对n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数。
在一些实施例中,训练数据对应的模型输入数据是指训练数据输入分类模型之前的表示形式。可选地,训练数据对应的模型输入数据用于以向量或者矩阵的形式表征训练数据的数据内容。例如训练数据对应的模型输入数据是指训练数据的嵌入表征(embedding)。
例如,训练数据属于文本类别,计算机设备通过对训练数据进行分词确定至少一个具有独立语义的独立字,并将每一个独立子使用该独立字对应的向量表征进行替换,得到该训练数据的模型输入数据。
又例如,训练数据属于图像类别,计算机设备通过嵌入生成模型对训练数据进行处理,得到训练数据对应的模型输入数据。嵌入生成模型的编码器用于将图像转化为编码信息,嵌入生成模型的解码器对编码信息进行解码,得到图像对应的模型输入数据。
需要说明的是,训练数据对应的模型输入数据的确定方法,根据实际需要进行确定,本申请在此不进行限定。
在一些实施例中,混合操作用于对多个信息进行混合得到混合信息(如下文中的混合输入数据或者第k层的混合输入特征)。可选地,混合操作包括线性混合,也即通过多个信息进行加权求和得到混合信息。在一些实施例中,混合操作称为(mixing)。
在一些实施例中,混合输入数据包括训练数据分别对应的模型输入数据分别具有的内容信息。可选地,混合输入数据通过对n个训练数据分别对应的模型输入数据进行加权相加得到。
在一些实施例中,计算机设备对n个训练数据分别对应的模型输入数据进行m次混合操作,得到m个混合输入数据。也即,计算机设备每对n个训练数据分别对应的模型输入数据进行1次混合操作,生成一个混合输入数据。
可选地,计算机设备通过混合参数对n个训练数据分别对应的模型输入数据进行处理,生成一个混合输入数据;其中,混合参数用于确定n个训练数据分别对应的模型输入数据在混合输入数据中所占的比例。有关该过程的具体内容请参考下文实施例。
在一些实施例中,m是大于或者等于n的正整数。在实际模型训练的过程中,可以设置m=n,通过这种方式能够减少计算机设备的计算工作量,提升分类模型的训练速度效率。
可选地,m个混合输入数据中任意两个混合输入数据之间不成比例。通过这种方法有助于避免m个混合输入数据中出现重复,影响后续的解耦过程。有关m个混合输入数据的生成方式请参考下文实施例。
在一些实施例中,训练数据对应的模型输入数据还可以理解成使用分类模型中某个隐藏层进行特征提取之前,训练数据的表达形式。在这种情况下,计算机设备可以不进行步骤220,可选地,在从训练集中确定n个训练数据之后,计算机设备将n个训练数据分别对应的模型输入数据传输给分类模型。也即在这个实施例中,计算机设备无需在将训练数据分别对应的模型输入数据进行混合。可选地,在分类模型包括的某个隐藏层开始工作之前,计算机设备将n训练数据在隐藏层对应的模型输入数据进行混合操作,得到多个混合输入数据(也可以称为混合输入特征)。有关该过程的具体内容请参考下文实施例。
步骤230,使用分类模型对m个混合输入数据进行特征提取,得到m个混合特征信息。
在一些实施例中,混合特征信息用于表征混合输入数据的特征信息。计算机设备通过分类模型中分类预测层之前的隐藏层对某个混合输入数据进行处理,生成该混合输入数据对应的混合特征信息。可选地,处理过程包括卷积、池化、全连接等操作,本申请在此不进行限定。
在一些实施例中,计算机设备通过分类模型依次对m个混合输入数据进行特征提取,生成m个混合特征信息。可选地,对于m个混合输入数据中的第j个混合输入数据,计算机设备将第j个混合输入数据传递给分类模型的输入端,通过分类模型中包括的至少一个隐藏层对第j个混合输入数据进行处理,得到第j个混合输入数据对应的混合特征信息,j为小于等于m的正整数。
也即上述过程中,计算机设备分别将m个混合输入数据传递给分类模型的输入端,分类模型进行特征提取后由分类模型的输出端产生的m个特征信息;其中,分类模型的输出端是指分类模型中分类预测层之前最后的一个隐藏层的输出端。
在一些实施例中,混合特征信息可以理解成,通过分类模型对混合输入数据进行神经网络前向传导所生成的隐含特征。
为了提升确定m个混合特征数据过程的效率,计算机设备分别将m个混合输入数据发送给多个分类模型,使用该多个分类模型确定出m个混合特征信息。可选地,多个分类模型具有完全相同的模型参数,且多个分类模型中的不同分类模型分别运行在系统资源互不冲突的线程(或者进程中)。
例如,多个分类模型包括2个分类模型,m等于4,计算机设备将第1个混合输入数据和第2个混合输入数据,发送给第1个分类模型;计算机设备将第3个混合输入数据和第4个混合输入数据,发送给第2个分类模型。在确定出第3个混合输入数据和第4个混合输入数据分别对应的混合特征信息之后,第2个分类模型将第3个混合输入数据和第4个混合输入数据分别对应的混合特征信息发送给第1个分类模型。通过这种方法提升计算机设备生成m个混合特征信息的速度。通过1个分类模型对m个混合输入数据进行处理需要消耗m个单位的时间。通过c个相同的分类模型对m个混合输入数据进行处理,可在(m/c+1)个单位的时间内完成m个特征信息的生成过程,对比发现使用多个分类模型进行同步处理,有助于加快分类模型的训练速度,缩短分类模型的训练周期。
步骤240,对m个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息。
在一些实施例中,解耦操作用于从混合特征信息中,确定训练数据对应的特征信息。计算机设备通过对m个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息。可选地,解耦操作称为decoupling。
在一些实施例中,计算机设备从m个混合特征信息中选择n个混合特征信息,通过对n个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息。
通过上文叙述可知,计算机设备在分类模型的输入端对n个训练数据分别对应的模型输入信息进行混合操作,得到m个混合输入数据。分类模型确定m个混合输入数据分别对应的混合特征信息。在混合过程中,训练数据对应的模型输入数据与混合输入数据之间存在第一对应关系(第一对应关系由混合操作的方法决定);训练数据对应的特征信息和由分类模型输出的混合特征信息之间也存在第一对应关系,因此计算机设备基于分类模型生成的m个混合特征信息和第一对应关系,能够确定n个训练数据分别对应的特征信息。根据多个混合特征信息确定训练数据对应的特征信息的过程即为解耦操作,有关该步骤的实现请参考下文实施例。
通过这种方法,虽然没有使用分类模型直接确定训练数据对应的特征信息,但是能够混合特征信息中从解耦得到n个训练数据分别对应的特征信息,后续计算机设备能够根据训练数据对应的特征信息,确定每一个训练数据对应的分类结果。相比于相关技术中,通过分类模型确定混合输入数据对应的分类结果,有助于消除在模型参数调整过程中,标签端的混合操作对分类模型置信度的影响,有助于提高分类模型的置信度。
步骤250,根据训练数据对应的分类标签和分类结果,对分类模型的参数进行调整,得到训练后的分类模型,训练数据对应的分类结果由分类模型中的分类预测层基于训练数据对应的特征信息得到。
在一些实施例中,训练数据对应的分类结果是指通过分类模型进行分类预测得到的预测结果。可选地,训练数据对应的分类结果包括该训练数据属于至少一个类别的概率。例如,训练数据集对应有4种类别,训练数据1对应的分类结果表示成[0.95,0.01,0,0],训练数据2对应的分类结果表示成[1,0,0,0]。
在一些实施例中,在进行上述解耦操作生成n个训练数据分别对应的特征信息之后,计算机设备通过分类模型的分类预测层对训练数据对应的特征信息进行处理,得到训练数据对应的分类结果。
可选地,分类预测层是指分类模型中生成分类结果的网络层,分类预测层中包括激活函数,激活函数用于对训练数据对应的特征数据进行处理,得到训练数据属于至少一个类别的概率。例如,分类预测层是指softmax层,用于将训练数据对应的特征信息映射到(0,1)分布中,得到训练数据对应的分类结果。
在一些实施例中,在分类模型生成n个训练数据分别对应的特征信息之后,通过分类模型中的分类预测层分别对训练数据对应的特征信息进行处理,生成n个训练数据分别对应的分类结果。
在一些实施例中,计算机设备根据训练数据对应的分类标签和分类结果,确定分类模型的训练损失;计算机设备根据训练损失对分类模型的模型参数进行调整,得到调整后的分类模型。之后,计算机设备从训练集中重新确定n个训练数据,并执行步骤220-步骤250。在分类模型的训练损失达到收敛的情况下,计算机设备得到训练后的分类模型。
可选地,计算机设备根据训练数据对应的分类标签和分类结果,确定分类模型的训练损失,包括:计算机设备使用训练数据对应的分类标签和分类结果计算交叉熵(CrossEntropy);计算机设备将交叉熵作为该训练批次模型分类模型的训练损失。
在一些实施例中,一个训练批次中包括多个训练数据组,也即计算机设备使用分类模型确定各个训练数据组分别包括的n个训练数据分别对应的分类结果,并根据多个训练数据组分别包括的n个训练数据分别对应的分类标签和分类结果,计算本批次训练过程中分类模型的训练损失。
在步骤230包括的一个实施例中,计算机设备使用不同线程中的多个分类模型共同确定m个混合特征信息,在这种情况下,计算机设备在确定出训练损失之后,根据训练损失同步对多个分类模型的模型参数进行调整,以便保证下一个批次的训练开始之前,多个分类模型的模型参数一致,避免在下一次生成m个混合特征信息的过程中,引入不必要的误差信息。
步骤260,对训练后的分类模型进行置信度校准,得到训练完成的分类模型。
在一些实施例中,为了提升训练后的分类模型的置信度,纠正训练后的分类模型的过自信程度,或者欠自信的程度,计算机设备对训练后的分类模型进行置信度校准,提升训练后的分类模型进行不确定性估计的能力。有关该步骤的具体内容请参考上文实施例。
综上所述,相比于训练数据进行混合产生混合训练数据,对训练数据对应的分类标签进行混合生成混合标签,并通过混合训练数据的分类结果和混合标签对分类模型的模型参数进行调整对分类模型的分类结果的置信度产生的影响。
一方面,本方法保留了原始数据混合方法中使用分类模型对混合模型模型输入进行处理的过程,丰富了分类模型处理的输入数据的类型,有助于避免训练后的分类模型生成的分类结果时伴随的过自信问题,提升分类模型预测得到的分类结果的准确性。
另一方面,本方法中通过对混合特征信息进行解耦,使得分类模型能够生成训练数据对应的分类结果,而不需要在标记端对训练数据分别对应的分类标签进行混合,混合标签也不会参与训练损失的计算过程。这样能够有效消除原有的数据混合方法对后置信度校准过程带来的负面影响,使得置信度校准后得到的训练完成的分类模型生成的分类结果具有较好的置信度表现。
下面通过几个实施例对混合操作的方法进行介绍说明。
在一些实施例中,计算机设备对n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,包括:计算机设备确定m组混合参数,每一组混合参数包括n个训练数据分别对应的混合系数;对于m组混合参数中的第i组混合参数,计算机设备基于第i组混合参数中包括的n个训练数据分别对应的混合系数,对n个训练数据分别对应的模型输入数据进行加权求和,得到第i个混合输入数据,i为小于或等于m的整数。
在一些实施例中,混合参数用于对n个训练数据分别对应的模型输入数据进行混合,生成混合输入数据。可选地,m个混合参数均不同。
在一些实施例中,混合操作属于线性混合。混合参数中包括n个训练数据分别对应的混合系数。可选地,训练数据分别对应的混合系数用于确定训练数据对应的模型输入信息在混合输入数据中所占的比重。
可选地,一组混合参数包括n个训练数据分别对应的混合系数均为小于1的正小数。
在一些实施例中,为了控制n个训练数据分别对应的模型输入数据在混合输入特征中的占比,混合参数中包括n个训练数据分别对应的混合系数之间具有数值关系。可选地,混合参数中包括n个训练数据分别对应的混合系数的加和等于1。例如,n等于2,m等于2,某混合参数中包括两个混合系数,分别为混合系数f和混合系数g,且f+g=1,也即混合系数g可以表示成为1-f。
下面通过几个实施例对计算机设备确定m个混合系数的过程进行介绍说明。
在一些实施例中,对于m组混合参数中的第i组混合参数,计算机设备从随机分布中选择z个混合系数,z为小于或者等于n的正整数,若z个混合系数满足系数选择条件,则计算机设备根据该z个混合系数,确定第i组混合参数。
在一些实施例中,随机分布用于生成至少一个随机数。可选地,随机分布为Beta分布,该类型分布中包括至少一个属于[0,1]区间的小数。
可选地,z等于n或者n-1。例如,在n个训练数据分别对应的混合系数之间不存在数值关系的情况下,z等于n,计算机设备从随机分布中任意选择n个混合系数,并将该n个混合系数确定为第i组混合系数。
又例如,在n个训练数据分别对应的混合系数之间存在数值关系(n个训练数据分别对应的混合系数之和等于1)的情况下,z等于n-1,计算机设备从随机分布中任意选择前n-1个混合系数,并计算n-1个混合系数之和;计算机设备通过1减去n-1个混合系数之和,得到第n个随机系数。
在一些实施例中,系数选择条件用于确定某个混合系数中包括的多个训练系数的合理性,以及避免多个混合系数之间存在等比例关系。可选地,系数选择条件包括以下至少之一:同一个混合参数中包括的n个混合系数中任意n-1个混合系数之和小于或者等于1、不同混合参数不相等。
可选地,系数选择条件还包括:某个训练数据在任意两个混合参数中分别对应的两个混合系数之差小于等于系数阈值。在一些实施例中,系数阈值是预设置的,例如,系数阈值等于0.5。为了方便举例,假设m=n=2,则计算机设备需要确定两个混合参数。其中,训练数据1在第1个混合参数中对应的混合系数为r,训练数据1在第1个混合参数中对应的混合系数为s,则计算机设备需要控制|r-s|大于或者等于0.5。通过控制同一个训练数据对应的不同混合系数之间的差距大于等于系数阈值,有助于减少通过混合系数对训练数据对应的模型输入数据进行加权求和得到的混合输入数据的多样性,通过这种方法有助于提升训练后的分类模型的鲁棒性。
图3是本申请一个示例性实施例提供混合解耦方法的示意图。
假设n等于2,m等于2,计算机设备从训练集中挑选出两个训练数据,分别为训练数据1和训练数据2。假设,训练数据1对应的模型输入数据表示为xa,训练数据2对应的模型输入数据表示为xb。计算机设备对训练数据1和训练数据2分别对应的模型输入数据进行两次混合操作,分别生成混合输入数据1和混合输入数据2;计算机设备确定2个混合参数,随机参数1中包括λ1和(1-λ1),随机参数2中包括λ2和(1-λ2)。
上述实施例中的混合输入数据1和混合输入数据2可以通过以下公式表示:
其中,表示混合输入数据1,/>表示混合输入数据2,公式中的其他参数的含义请参考上一段内容,在此不进行赘述。
可选地,生成混合输入数据的混合操作称为线性插值。
相比于仅使用训练集中的训练数据对分类模型进行训练,通过对训练数据分别对应的模型输入特征进行混合处理生成混合输入数据,使用混合输入数据参与分类模型的训练过程,丰富了分类模型处理数据的多样性。在分类模型处理混合输入数据,有助于避免训练后的分类模型出现过自信现象。提升分类模型在实际使用过程中产生的分类结果的准确性。
下面,通过几个实施例对解耦操作的进行方法进行介绍说明。
在一些实施例中,计算机设备对m个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息,包括:计算机设备根据m组混合参数和m个混合特征,确定n个训练数据分别对应的特征信息。
在一些实施例中,训练数据对应的特征信息可以理解成分类模型对训练数据对应的模型输入信息进行特征提取,得到的特征信息。
注意在本方法中,计算机设备并不是通过分类模型直接对训练数据对应的模型输入数据进行处理,得到训练数据对应的特征信息。由于训练数据对应的模型输入数据和混合输入数据之间存在对应关系,训练数据对应的特征信息和分类模型生成的混合特征信息也存在相同的对应关系,因此计算机设备可以通过对m个混合特征和m组混合参数进行解耦操作,得到n个训练数据分别对应的特征信息。
例如,若计算机设备对n个训练数据分别对应的模型输入数据进行线性插值,生成混合输入数据,则n个训练数据分别对应的特征信息与混合特征信之间也存在相同的线性插值关系。
假设n=m=2,若混合输入数据和训练数据对应的模型数据输入之间存在以下关系:
其中,表示混合输入数据1,/>表示混合输入数据2,xa表示一个训练数据,xb表示另一个训练数据,λ1和(1-λ1)为第1个混合参数中包括的两个混合系数,λ2和(1-λ2)为第2个混合参数中包括的两个混合系数。
则混合特征信息和训练数据对应的特征信息之间存在以下关系:
其中,表示分类模型对混合输入数据1进行特征提取得到的混合特征信息,表示分类模型对混合输入数据2进行特征提取得到的混合特征信息,/>表示分类模型对训练数据1对应的模型输入数据xa进行特征提取能够产生的特征信息,/>表示分类模型对训练数据2对应的模型输入数据xb进行特征提取能够产生的特征信息;λ1表示第1个混合参数中训练数据1对应的混合系数,(1-λ1)表示第1个混合参数中训练数据2对应的混合系数;λ2表示第2个混合参数中训练数据1对应的混合系数,(1-λ2)表示第2个混合参数中训练数据2对应的混合系数。
通过上述两个公式对训练数据1对应的特征信息和训练数据2对应的特征信息进行反表示可以得到:
有关上述公式的参数解释请参考上文,在此不进行赘述。由于上述两个公式中各个参数均为已知的,因此计算机设备可以得到确定的和/>以便后续根据/>训练数据1对应的分类标签和训练数据2对应的分类标签对分类模型的模型参数进行训练损失。
相关技术中,需要对n个训练数据分别对应分类标签进行混合处理,得到混合标签;并使用分类模型的分类预测层根据混合特征信息确定混合输入数据对应的分类结果,根据混合标签和混合输入数据对应的分类结果分类模型的模型参数进行调整。而本方法中无需在标签端对n个训练数据分别对应的分类标签进行混合操作;在生成混合特征信息之后,计算机设备对混合特征信息进行解耦生成n个训练数据分别对应的特征信息,并通过分类模型的分类预测层根据n个训练数据分别对应的特征信息,生成n个训练数据分别对应的分类结果。本方法中虽然使用的数据混合的方法,将混合数据输入到分类模型中,提升的分类模型处理的数据类型的准确度;但是通过进行解耦操作避免生成混合标签,以及混合模型输入对应的分类结果。通过使用个训练数据分别对应的特征信息和分类结果,对模型参数进行调整本方法中无需生成不正确的混合标签,有助于避免模型参数调整过程中,混合模型输入对应的混合标签以及分类结果对分类模型的置信度的损害。
在上文的实施例中,介绍了在分类模型的输入端进行混合操作,在分类模型的分类预测层之前输出端进行解耦操作的特征提取方法。本申请提供的混合-解耦方法同样可以应用与分类模型中某个(或者多个)隐藏层的输入端以及输出端。下面通过几个实施例对这种方法进行介绍说明。
在一些实施例中,分类模型中的第k个隐藏层用于对p个混合输入特征进行处理,输出p个混合输出特征;其中,p个隐层输入特征是对n个训练数据在第k个隐藏层的输入特征进行混合操作得到的,p个混合输出特征用于解耦得到n个训练数据在第k个隐藏层的输出特征,k为正整数,p为大于或等于n的整数。
在一些实施例中,隐藏层用于对输入数据进行提取。可选地,隐藏层是指分类模型中包括神经网络层。对于不同结构的分类模型,隐藏层属于的类型不完全相同。例如,若使用残差卷积网络(Residual Networks,ResNets)网络作为分类模型,隐藏层是残差卷积网中的一个或者多个Block(块)。又例如,若使用多层感知机(Multi-Layer Perceptron,MLP)作为分类模型,则隐藏层是指任一个用于进行特征提取的神经网络层。
在一些实施例中,隐藏层是指位于分类模型的分类预测层之前的任意一个特征提取层。隐藏层的类型与分类模型所属的模型类型有关,本申请在此不进行限定。
在一些实施例中,第k个隐藏层是指分类模型中包括的隐藏层中的任一个。例如,分类模型中包括10个隐藏层,k等于2、4、6、8、10。也即,计算机设备需要在第2、4、6、8、10个隐藏层分别对应的输入端对n个训练数据分别对应的输入特征进行混合操作,在第2、4、6、8、10个隐藏层分别对应的输入端对处理得到的p个混合特征数据进行解耦操作,生成n个训练数据在第2、4、6、8、10个隐藏层分别对应的输出特征。
在一些实施例中,计算机设备在分类模型中包括的每一个隐藏层的输入端进行混合操作,在每一个隐藏层的输出端进行解耦操作。也即,若分类模型中包括y个隐藏层,则k=1、2、3、…、y。需要说明的是k的取值根据实际需要进行设定,本申请在此不进行限定。
在一些实施例中,第k个隐藏层对应的混合输入特征是指需要输入第k个隐藏层进行特征提取的中间特征。可选地,计算机设备通过分类模型中的第k个隐藏层对p个混合输入特征分别进行处理,输出p个混合输出特征,包括:计算机设备将p个混合输入特征依次输入到的第k个隐藏层的输入端,并得到第k个隐藏层的输出端传输的p个混合输入特征。
在一些实施例中p是大于或者等于n的正整数。可选地,p等于n。
在一些实施例中,第k个隐藏中的p个混合输入特征中的任意一个混合输入特征,是通过n个训练数据在第k个隐藏层分别对应的输入特征进行混合操作生成的。
可选地,计算机设备确定p个隐藏混合参数,隐藏混合参数中包括n个训练数据分别对应的隐藏混合系数;对于p个隐藏混合参数中的任意一个隐藏混合参数,计算机设备使用n个训练数据分别对应的隐藏混合系数,对n个训练数据在第k个隐藏层分别对应的输入特征进行加权求和,得到第k个隐藏层的混合输入特征。该方法的具体过程与生成混合输入数据的过程相似,具体请参考上文实施例,在此不进行赘述。
图4是本申请另一个示例性实施例提供混合解耦方法的示意图。
在一些实施例中,若第k个隐藏层无需进行混合-解耦操作,则n个训练数据在第k个隐藏层分别对应的输入特征是指第k-1个隐藏层的输出端的输出特征;若第k个隐藏层进行了混合-解耦操作,则n个训练数据在第k个隐藏层分别对应的输入特征是由:计算机设备对第k-1个隐藏层输出的t个混合输出特征进行解耦的得到,t为大于或者等于n的正整数。
在一些实施例中,计算机设备对p个混合输出特征进行解耦操作,得到n个训练数据在第k个隐藏层的输出特征,包括:计算机设备从p个混合输入特征中选择n个混合输出特征,计算机设备对n个混合输出特征进行解耦操作,生成n个训练数据在第k个隐藏层的输出特征。
可选地,计算机设备根据第k个隐藏层对应的隐藏混合参数,以及p个混合输出特征中的n个混合输出特征,确定n个训练数据在第k个隐藏层的输出特征。该过程和对m个混合特征信息进行解耦处理,生成n个训练数据分别对应的特征信息的方法相同,具体请参考上文实施例,在此不进行赘述。
在一些实施例中,若k小于分类模型具有的隐藏层总数,且第k+1个隐藏层无需进行混合-解耦操作,则计算机设备将n个训练数据在第k个隐藏层的输出特征,作为n个训练数据在第k+1个隐藏层的输入特征;计算机设备通过第k+1个隐藏层对n个训练数据在第k+1个隐藏层的输入特征进行处理,分别得到n个训练数据在第k+1个隐藏层的输出特征。
在一些实施例中,若k小于分类模型包括的隐藏层总数,且第k+1个隐藏层需进行混合-解耦操作,则计算机设备将n个训练数据在第k个隐藏层的输出特征进行混合操作,生成第k+1个隐藏层的s个混合输入特征;计算机设备通过第k+1个隐藏层对s个混合输入特征进行处理,分别得到k+1个隐藏层的s个混合输出特征,计算机设备通过对s个混合输出特征进行解耦,得到n个训练数据在第k+1个隐藏层的输出特征。
在一些实施例中,若k等于分类模型包括的隐藏层总数,则计算机设备将n个训练数据在第k个隐藏层的输出特征,作为n个训练数据分别对应的特征信息;计算机设备通过分类模型的分类预测层对n个训练数据在第k个隐藏层的输出特征进行处理,得到n个训练数据分别对应的分类结果。
在一些实施例中,上述混合以及解耦的方式不参与分类模型的训练过程。例如上述混合以及解耦的过程通过线性差值的方式实现,隐藏层的输入端之前需进行混合操作,可以通过在隐藏层的输入端之前插入一段固定的程序实现。还可以通过在隐藏性层之前插入具有混合功能的程序实现(解耦操作实现方法同理),而无需对现有的分类模型的框架进行改造,因此,本方法能够广泛适用于不同结构的分类模型的训练过程。
通过上述方法,提升各个隐藏层所处理输入特征的丰富度,有助于避免分类模型出现过自信的情况,有助于提升分类模型生成的分类结果的分类准确性。
通过在至少一个隐藏层的输入端进行混合操作,并在该隐藏层的输出端进行解耦操作,使得分类模型的训练层次更加丰富。这种方法有助于实现对某个或者某些隐藏层进行针对性训练,细化了分类模型训练过程中的训练损失优化方向,有助于使得训练后的分类模型生成的分类结果更加准确。
在一些实施例中,计算机设备对训练后的分类模型进行置信度校准,得到训练完成的分类模型,包括:计算机设备基于验证集中包括的验证数据,对训练后的分类模型进行置信度校准,得到第一温度值,第一温度值用于调节分类预测层在多个类别分别对应的预测概率之间的差异;计算机设备将第一温度值应用于训练后的分类模型,得到训练完成的分类模型。
在一些实施例中,得到训练后的分类模型之后,为了提升训练后的分类模型的置信度,计算机设备需要对训练后的分类模型进行置信度校准。置信度校准用于提升训练后的分类模型的置信度。可选地,在得到训练后的分类模型之后进行置信度校准的过程成为后置信校准。置信度校准方法包括:温度调节(Temperature Scaling)。
分类模型的置信度可以理解成通过分类模型对输入数据进行处理生成的分类结果的可信程度。例如,分类模型确定某个输入数据属于分类a的概率为30%,则分类模型的置信度可以描述输入数据实际属于分类a的概率为30%的机率。
可选地,第一温度值能够使的训练后的分类模型在验证集校准误差达到最小。第一温度参数用于作为分类预测层(如Softmax层)的参数T。
在一些实施例中,计算机设备基于验证集中包括的验证数据,对训练后的分类模型进行置信度校准,得到第一温度值,包括:计算机设备确定候选的温度值,将候选的温度值应用于训练后的分类模型的分类预测层,得到调整后的分类模型;计算机设备确定调整后的分类模型在验证集上产生的置信度误差,置信度误差用于表征调整后的分类模型的置信度;计算机设备若置信度误差满足置信度条件,则将候选的温度值确定为第一温度值,置信度条件用筛选合适的温度值。
在一些实施例中,置信度条件用于从多个候选的温度值中选择是训练后的分类模型在验证集上的校准误差最小的温度值。
在一些实施例中,候选的温度值的取值范围为正数。在一个实例中,计算机设备从最小的候选温度值开始,逐个确定各个候选的温度值直到确定第一温度值。
例如,计算机设备将0.1确定为候选温度值,计算机设备将0.1作为分类预测层的参数T,得到调整后的分类模型;计算机设备使用调整后的分类模型对验证集中的至少一个验证数据进行处理,得到至少一个验证数据分别对应的分类结果;计算机设备根据至少一个验证数据分别对应的分类结果,确定调整后的分类模型在验证集上产生的置信度误差。若调整后的分类模型在验证集上产生的置信度误差不满足置信度条件,则计算机涉及将候选的温度值设置为0.2,并重复上述步骤,直到调整后的分类模型在验证集上产生的置信度误差不满足置信度条件为止。
可选地,调整后的分类模型在验证集上产生的置信度误差可以通过期望校准误差(Expected Calibration Error,ECE)的方法计算。首先,计算机设备根据分类结果将至少一个验证数据分组到M个相等的区间内然后计算所有区间内样本的准确率与置信度之差的加权平均值。其形式化计算如下:
其中,ECE表示调整后的分类模型在验证集上产生的置信度误差,Bm表示一个预测结果区间,acc(Bm)表示预测结果区间Bm中包括的全部验证数据的分类结果的准确率,avgConf(Bm)表示表示预测结果区间Bm中包括的全部验证数据分别对应的置信度平均值。
通过将本方法提供的混合-解耦的分类模型训练方法和后置信度方法结合,使得训练完成的分类模型在具有较高的准确性的同时,也保证了训练完成的分类模型具有较好的置信度表现,这使得训练后的分类模型能够适应于对泛化性能和不确定性估计表现均具有较高要求的应用场景。本方法对于基于深度神经网络的一些预测任务的落地具有重要意义,有助于扩展分类模型的应用领域。
在一些实施例中,分类模型用于针对以下任意一种领域中的样本数据执行分类任务:自动驾驶领域中的图像样本数据;医学辅助领域中的图像样本数据;医学辅助领域中的文本样本数据。
在一些实施例中,自动驾驶领域中的图像样本数据是指载具在行驶过程中通过拍摄设备得到的图像。医学辅助领域中的图像样本数据包括但不限于医疗仪器检测生成的显影图像,如X光图像、核磁共振图像等。医学辅助领域中的文本样本数据包括但不限于处方信息等。
在一个示例中,训练集以及校验集中包括至少一个自动驾驶领域中的图像样本数据,通过上述混合-解耦以及后校准过程得到的训练完成的分类模型能够在车辆自动驾驶过程中对车道线等路况画面进行识别,使得目标应用程序及时了解路况变化并生成相应的驾驶指令。
下面,对使用自动驾驶领域的图像训练样本对分类模型进行训练的过程进行简要说明。有关本示例中各个步骤的详细内容,请参考上文实施例。在本应用场景中,分类模型用于对图像数据中现实的物体图像进行分类,预测图像数据对应的分类结果,达到辅助感知车辆驾驶过程中的路况变化的目的。
在一些实施例中,计算机设备从训练集中获取n个训练数据,以及n个训练数据中每一个训练数据分别对应的标签信息。在本示例中,训练数据是指图像数据,训练集中包括至少一个图像数据。
可选地,图像数据包括:车辆驾驶过程中由车载摄像头拍摄采集到的路况画面。训练样本对应有分类标签,分类标签用于表征路况画面中的物体图像其所属的物体类别。物体类别包括但不限于:人物、动物、车辆、建筑、路标等。
为了方便后续模型输入数据的混合过程,计算机设备从训练集中选择n个分辨率之差小于阈值的路况画面,作为n个训练数据。例如,计算机设备从训练集中选择2个分辨率相同的路况画面,作为两个训练数据。
在一些实施例中,计算机设备对n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据。在本示例中,训练数据对应的模型输入数据是指图像数据(也即路况画面)对应的模型输入数据,可选地,该模型输入数据通过向量或者矩阵形式表示,图像数据对应的模型输入数据中包括路况画面中至少一个像素点分别对应的颜色信息,以及至少一个像素点分别对应的亮度信息。
混合输入数据是指:对n个路况画面分别对应的模型输入数据进行混合操作,得到的输入向量或者输入矩阵。可选地,混合输入数据中包括至少一个像素点在n个路况画面中分别对应颜色信息的混合值,以及亮度信息的混合值。
在一些实施例中,计算机设备使用分类模型对m个混合输入数据进行特征提取,得到m个混合特征信息。混合特征信息用于解耦得到n个路况画面分别对应的特征信息。
在一些实施例中,计算机设备对m个混合特征信息进行解耦操作,得到n个训练数据分别对应的特征信息。本示例中,训练数据对应的特征信息是指图像数据对应的特征信息。更具体地,训练数据对应的特征信息是将路况画面的对应的模型输入数据进行神经网络前向传导处理后得到的。
在一些实施例中,根据训练数据对应的分类标签和分类结果,对分类模型的参数进行调整,得到训练后的分类模型,训练数据对应的分类结果由分类模型中的分类预测层基于训练数据对应的特征信息得到。
在本示例中,训练数据对应的分类结果是指分类模型针对路况画面中显示的物体图像进行分类预测,得到的预测物体分类。
在一些实施例中,在使用训练集对分类模型进行校准之后,计算机使用验证集对训练后的分类模型进行置信度校准,得到训练完成的分类模型。在一些实施例中,基于验证集中包括的验证数据是指图像数据。训练完成的分类模型可以应用于图像数据进行分类,确定图像数据中包括的物体所属的类别。在另一个示例中,训练集以及校验集中包括至少一个自医学辅助领域中的图像样本数据,通过上述混合-解耦以及后校准过程得到的训练完成的分类模型能够在医生诊断病情的过程中,对医疗仪器检测生成的显影图像(也即检测图像)进行分类,确定病灶的位置或者病灶的属性,为医生的诊断过程提供辅助参考信息。
在本实施场景中,训练数据对应的模型输入数据用于表征检测图像,训练数据对应的模型输入数据是指检测图像对应的模型输入数据。混合输入数据是指通过对n个检测图像分别对应的模型输入数据进行混合操作得到的输入数据。训练数据对应的特征信息是指检测图像对应的特征信息。训练图像对应的分类结果是指通过分类模型针对检测图像对应的特征信息预测生成的,检测部位的病灶分类结果。可选地,检测图像所属的类型包括以下至少之一:检测部位良性、检测部位恶性、检测部位性质模糊等。在另一个示例中,训练集以及校验集中包括至少一个自医学辅助领域中的处方文本,通过上述混合-解耦以及后校准过程得到的训练完成的分类模型能够在医生诊断病情的过程中,根据就诊者提供的历史处方文本进行分类,生成历史处方文本对应的分类结果,历史处方文本对应的分类结果用于表征历史处方文本用于治疗的症状类型。在本示例中,训练数据是指文本数据,文本数据中包括处方文本,训练数据对应的模型输入数据是指文本数据对应的模型输入数据。可选地,计算机设备对文本数据进行分词,得到至少一个字符;计算机设备将至少一个字符分别对应的向量表示进行拼接,得到文本数据对应的模型输入数据。
训练数据对应的分类标签用于表征训练数据对应的症状类型。训练数据对应的分类结果,是指通过分类模型预测得到的症状类型。计算机设备根据训练数据对应的分类标签和分类结果,对分类模型的模型参数进行调整,得到训练后的分类模型。
此后,为了提升分类模型的置信度,计算机设备使用包括至少一个文本数据的校验集,对训练后的分类模型进行后置信度校验,得到训练完成的分类模型。训练完成的分类模型用于对就诊者提供的历史处方文本进行分类,确定就诊者的历史症状类型,从而为医生的诊断过程提供参考信息。
本申提供的分类模型的训练方法能够有效缓解相关技术中混合策略所带来的欠自信问题。进一步地,通过将混合-解耦策略和置信度校准结合,能达到提升置信度校准的优越性能。而且本方法还可以应用于神经网络的任意隐含层,进而同时提高预测的准确性。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
图5示出了本申请一个示例性实施例提供的拍摄分类模型的训练装置的框图。该装置可以通过软件、硬件或者两者的结合实现成为分类模型的训练设备的全部或一部分。该装置500可以包括:数据获取模块510、输入混合模块520、特征提取模块530、特征解耦模块540、模型训练模块550和模型校准模块560。
数据获取模块510,用于从训练集中获取n个训练数据,以及所述n个训练数据中每一个训练数据分别对应的标签信息,所述标签信息用于表征所述训练数据所属的类别,n为大于1的整数;
输入混合模块520,用于对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数;
特征提取模块530,用于使用分类模型对所述m个混合输入数据进行特征提取,得到m个混合特征信息;
特征解耦模块540,用于对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息;
模型训练模块550,用于根据所述训练数据对应的分类标签和分类结果,对所述分类模型的参数进行调整,得到训练后的分类模型,所述训练数据对应的分类结果由所述分类模型中的分类预测层基于所述训练数据对应的特征信息得到;
模型校准模块560,用于对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型。
在一些实施例中,所述输入混合模块520,包括:系数确定单元,用于确定m组混合参数,每一组混合参数包括所述n个训练数据分别对应的混合系数;数据生成单元,用于对于所述m组混合参数中的第i组混合参数,基于所述第i组混合参数中包括的所述n个训练数据分别对应的混合系数,对所述n个训练数据分别对应的模型输入数据进行加权求和,得到第i个混合输入数据,i为小于或等于m的整数。
在一些实施例中,所述数据生成单元,用于根据损失m组混合参数和所述m个混合特征信息,确定所述n个训练数据分别对应的特征信息。
在一些实施例中,所述分类模型中的第k个隐藏层用于对p个混合输入特征进行处理,输出p个混合输出特征;其中,所述p个隐层输入特征是对所述n个训练数据在所述第k个隐藏层的输入特征进行混合操作得到的,所述p个混合输出特征用于解耦得到所述n个训练数据在所述第k个隐藏层的输出特征,k为正整数,p为大于或等于n的整数。
在一些实施例中,所述模型校准模型560,包括:温度值确定单元,用于基于验证集中包括的验证数据,对所述训练后的分类模型进行置信度校准,得到第一温度值,所述第一温度值用于调节所述分类预测层在多个类别分别对应的预测概率之间的差异;温度应用单元,用于将所述第一温度值应用于所述训练后的分类模型,得到所述训练完成的分类模型。
在一些实施例中,所述温度值确定单元,用于确定候选的温度值,将所述候选的温度值应用于所述训练后的分类模型的分类预测层,得到调整后的分类模型;确定所述调整后的分类模型在所述验证集上产生的置信度误差,所述置信度误差用于表征所述调整后的分类模型的置信度;若所述置信度误差满足置信度条件,则将所述候选的温度值确定为所述第一温度值,所述置信度条件用筛选合适的温度值。
在一些实施例中,所述分类模型用于针对以下任意一种领域中的样本数据执行分类任务:自动驾驶领域中的图像样本数据;医学辅助领域中的图像样本数据;医学辅助领域中的文本样本数据。
需要说明的是,上述实施例提供的装置,在实现其功能时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内容结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的装置与方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。上述实施例提供的装置的有益效果请参考方法侧实施例的描述,这里也不再赘述。
图6示出了本申请一个示例性实施例提供的计算机设备的结构框图。该分类模型的训练设备600可以是上文介绍计算机设备。
通常,计算机设备600包括有:处理器601和存储器602。
处理器601可以包括一个或多个处理核心,比如4核心处理器、6核心处理器等。处理器601可以采用DSP(Digital Signal Processing,数字信号处理)、FPGA(FieldProgrammable Gate Array,现场可编程门阵列)、PLA(Programmable Logic Array,可编程逻辑阵列)中的至少一种硬件形式来实现。处理器601也可以包括主处理器和协处理器,主处理器是用于对在唤醒状态下的数据进行处理的处理器,也称CPU(Central ProcessingUnit,中央处理器);协处理器是用于对在待机状态下的数据进行处理的低功耗处理器。在一些实施例中,处理器601可以在集成有GPU(Graphics Processing Unit,图像处理器),GPU用于负责显示屏所需要显示的内容的渲染和绘制。一些实施例中,处理器601还可以包括AI(Artificial Intelligence,人工智能)处理器,该AI处理器用于处理有关机器学习的计算操作。
存储器602可以包括一个或多个计算机可读存储介质,该计算机可读存储介质可以是有形的和非暂态的。存储器602还可包括高速随机存取存储器,以及非易失性存储器,比如一个或多个磁盘存储设备、闪存存储设备。在一些实施例中,存储器602中的非暂态的计算机可读存储介质存储有至少一条指令、至少一段程序、代码集或指令集,该至少一条指令、至少一段程序、代码集或指令集由处理器601加载并执行以实现上述各方法实施例提供的分类模型的训练方法。
本申请实施例还提供一种计算机可读存储介质,该存储介质中存储有计算机程序,所述计算机程序由处理器加载并执行以实现上述各方法实施例提供的分类模型的训练方法。
该计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、EPROM(Erasable Programmable Read-Only Memory,可擦写可编程只读存储器)、EEPROM(Electrically Erasable Programmable Read-Only Memory,电可擦写可编程只读存储器)、闪存或其他固态存储技术,CD-ROM、DVD(Digital Video Disc,高密度数字视频光盘)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知该计算机存储介质不局限于上述几种。
本申请实施例还提供一种计算机程序产品,所述计算机程序产品包括计算机程序,所述计算机程序存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行所述计算机程序,以实现上述各方法实施例提供的分类模型的训练方法。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。
需要说明的是,本申请在收集用户的相关数据之前以及在收集用户的相关数据的过程中,都可以显示提示界面、弹窗或输出语音提示信息,该提示界面、弹窗或语音提示信息用于提示用户当前正在搜集其相关数据,使得本申请仅仅在获取到用户对该提示界面或者弹窗发出的确认操作后,才开始执行获取用户相关数据的相关步骤,否则(即未获取到用户对该提示界面或者弹窗发出的确认操作时),结束获取用户相关数据的相关步骤,即不获取用户的相关数据。换句话说,本申请所采集的所有用户数据(训练集中的训练数据),处理严格根据相关国家法律法规的要求,获取个人信息主体的知情同意或单独同意都是在用户同意并授权的情况下进行采集的,并在法律法规及个人信息主体的授权范围内,开展后续数据使用及处理行为且相关用户数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同切换、改进等,均应包含在本申请的保护范围之内。

Claims (11)

1.一种分类模型的训练方法,其特征在于,所述方法包括:
从训练集中获取n个训练数据,以及所述n个训练数据中每一个训练数据分别对应的标签信息,所述标签信息用于表征所述训练数据所属的类别,n为大于1的整数;
对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数;
使用分类模型对所述m个混合输入数据进行特征提取,得到m个混合特征信息;
对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息;
根据所述训练数据对应的分类标签和分类结果,对所述分类模型的参数进行调整,得到训练后的分类模型,所述训练数据对应的分类结果由所述分类模型中的分类预测层基于所述训练数据对应的特征信息得到;
对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型。
2.根据权利要求1所述的方法,其特征在于,所述对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,包括:
确定m组混合参数,每一组混合参数包括所述n个训练数据分别对应的混合系数;
对于所述m组混合参数中的第i组混合参数,基于所述第i组混合参数中包括的所述n个训练数据分别对应的混合系数,对所述n个训练数据分别对应的模型输入数据进行加权求和,得到第i个混合输入数据,i为小于或等于m的整数。
3.根据权利要求2所述的方法,其特征在于,所述对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息,包括:
根据所述m组混合参数和所述m个混合特征信息,确定所述n个训练数据分别对应的特征信息。
4.根据权利要求1所述的方法,其特征在于,所述分类模型中的第k个隐藏层用于对p个混合输入特征进行处理,输出p个混合输出特征;其中,所述p个隐层输入特征是对所述n个训练数据在所述第k个隐藏层的输入特征进行混合操作得到的,所述p个混合输出特征用于解耦得到所述n个训练数据在所述第k个隐藏层的输出特征,k为正整数,p为大于或等于n的整数。
5.根据权利要求1所述的方法,其特征在于,所述对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型,包括:
基于验证集中包括的验证数据,对所述训练后的分类模型进行置信度校准,得到第一温度值,所述第一温度值用于调节所述分类预测层在多个类别分别对应的预测概率之间的差异;
将所述第一温度值应用于所述训练后的分类模型,得到所述训练完成的分类模型。
6.根据权利要求5所述的方法,其特征在于,所述基于验证集中包括的验证数据,对所述训练后的分类模型进行置信度校准,得到第一温度值,包括:
确定候选的温度值,将所述候选的温度值应用于所述训练后的分类模型的分类预测层,得到调整后的分类模型;
确定所述调整后的分类模型在所述验证集上产生的置信度误差,所述置信度误差用于表征所述调整后的分类模型的置信度;
若所述置信度误差满足置信度条件,则将所述候选的温度值确定为所述第一温度值,所述置信度条件用筛选合适的温度值。
7.根据权利要求1至6任一项所述的方法,其特征在于,所述分类模型用于针对以下任意一种领域中的样本数据执行分类任务:
自动驾驶领域中的图像样本数据;
医学辅助领域中的图像样本数据;
医学辅助领域中的文本样本数据。
8.一种分类模型的训练装置,其特征在于,所述装置包括:
数据获取模块,用于从训练集中获取n个训练数据,以及所述n个训练数据中每一个训练数据分别对应的标签信息,所述标签信息用于表征所述训练数据所属的类别,n为大于1的整数;
输入混合模块,用于对所述n个训练数据分别对应的模型输入数据进行混合操作,生成m个混合输入数据,m为大于或等于n的整数;
特征提取模块,用于使用分类模型对所述m个混合输入数据进行特征提取,得到m个混合特征信息;
特征解耦模块,用于对所述m个混合特征信息进行解耦操作,得到所述n个训练数据分别对应的特征信息;
模型训练模块,用于根据所述训练数据对应的分类标签和分类结果,对所述分类模型的参数进行调整,得到训练后的分类模型,所述训练数据对应的分类结果由所述分类模型中的分类预测层基于所述训练数据对应的特征信息得到;
模型校准模型,用于对所述训练后的分类模型进行置信度校准,得到训练完成的分类模型。
9.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机程序,所述计算机程序由所述处理器加载并执行以实现如权利要求1至7任一项所述的分类模型的训练方法。
10.一种计算机可读存储介质,其特征在于,所述存储介质中存储有计算机程序,所述计算机程序由处理器加载并执行,以实现如权利要求1至7任一项所述的分类模型的训练方法。
11.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机程序,所述计算机程序存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行所述计算机程序,以实现如权利要求1至7任一项所述的分类模型的训练方法。
CN202310340644.4A 2023-03-24 2023-03-24 分类模型的训练方法、装置、设备及存储介质 Pending CN116956014A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310340644.4A CN116956014A (zh) 2023-03-24 2023-03-24 分类模型的训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310340644.4A CN116956014A (zh) 2023-03-24 2023-03-24 分类模型的训练方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN116956014A true CN116956014A (zh) 2023-10-27

Family

ID=88453696

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310340644.4A Pending CN116956014A (zh) 2023-03-24 2023-03-24 分类模型的训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN116956014A (zh)

Similar Documents

Publication Publication Date Title
EP3961484A1 (en) Medical image segmentation method and device, electronic device and storage medium
CN111754596B (zh) 编辑模型生成、人脸图像编辑方法、装置、设备及介质
CN111898696A (zh) 伪标签及标签预测模型的生成方法、装置、介质及设备
CN110490239B (zh) 图像质控网络的训练方法、质量分类方法、装置及设备
CN112949786A (zh) 数据分类识别方法、装置、设备及可读存储介质
CN110659723B (zh) 基于人工智能的数据处理方法、装置、介质及电子设备
CN110490242B (zh) 图像分类网络的训练方法、眼底图像分类方法及相关设备
CN111932529B (zh) 一种图像分类分割方法、装置及系统
CN111667459B (zh) 一种基于3d可变卷积和时序特征融合的医学征象检测方法、系统、终端及存储介质
CN113177559B (zh) 结合广度和密集卷积神经网络的图像识别方法、系统、设备及介质
CN112668608A (zh) 一种图像识别方法、装置、电子设备及存储介质
CN114612902A (zh) 图像语义分割方法、装置、设备、存储介质及程序产品
CN115880317A (zh) 一种基于多分支特征融合精炼的医学图像分割方法
CN116994021A (zh) 图像检测方法、装置、计算机可读介质及电子设备
CN117033609B (zh) 文本视觉问答方法、装置、计算机设备和存储介质
CN116258756B (zh) 一种自监督单目深度估计方法及系统
CN116975347A (zh) 图像生成模型训练方法及相关装置
CN116485943A (zh) 图像生成方法、电子设备及存储介质
CN116956014A (zh) 分类模型的训练方法、装置、设备及存储介质
CN114639132A (zh) 人脸识别场景下的特征提取模型处理方法、装置、设备
CN111582404A (zh) 内容分类方法、装置及可读存储介质
CN117058489B (zh) 多标签识别模型的训练方法、装置、设备及存储介质
CN113505866B (zh) 基于边缘素材数据增强的图像分析方法和装置
CN117012326A (zh) 一种生成医学报告的方法以及相关装置
CN116310357A (zh) 一种基于多级特征融合的视觉显著性预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication