CN114418111A - 标签预测模型训练及样本筛选方法、装置、存储介质 - Google Patents

标签预测模型训练及样本筛选方法、装置、存储介质 Download PDF

Info

Publication number
CN114418111A
CN114418111A CN202111602781.8A CN202111602781A CN114418111A CN 114418111 A CN114418111 A CN 114418111A CN 202111602781 A CN202111602781 A CN 202111602781A CN 114418111 A CN114418111 A CN 114418111A
Authority
CN
China
Prior art keywords
data
label
unlabeled
representing
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111602781.8A
Other languages
English (en)
Inventor
孟嘉琪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sichuan Yuncong Tianfu Artificial Intelligence Technology Co ltd
Original Assignee
Sichuan Yuncong Tianfu Artificial Intelligence Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sichuan Yuncong Tianfu Artificial Intelligence Technology Co ltd filed Critical Sichuan Yuncong Tianfu Artificial Intelligence Technology Co ltd
Priority to CN202111602781.8A priority Critical patent/CN114418111A/zh
Publication of CN114418111A publication Critical patent/CN114418111A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请提供一种标签预测模型训练及样本筛选方法、装置、存储介质,包括:根据标注数据、标注数据的真实标签、未标注数据,训练生成器,利用生成器输出的标注数据、未标注数据各自的特征提取结果、标注数据的真实标签,未标注数据的参考标签,训练判别器,并基于训练好的生成器和判别器,生成训练好的标签预测模型。借此,本申请可提高模型训练的迭代效率,并降低训练数据的筛选成本。

Description

标签预测模型训练及样本筛选方法、装置、存储介质
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种标签预测模型训练及样本筛选方法、装置及计算机存储介质。
背景技术
从海量数据中获取对模型训练具有价值的样本,使得模型在不同场景下,甚至在未训练过的场景下,均能表现良好,是有监督学习算法的一个难点问题。
然而,投入大量人力盲目标注的标签样本的迭代训练结果,不一定能带来模型训练精度提升,因此,如何通过算法自动筛选出对应于不同应用场景下的,且可针对模型训练提供较高增益价值的训练样本,是目前亟待解决的一个重要课题。
发明内容
鉴于上述问题,本申请提供一种标签预测模型训练及样本筛选方法、装置、存储介质,可至少部分地解决上述技术问题。
根据本申请的第一方面,提供一种标签预测模型训练方法,包括:根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器;利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器;基于训练好的所述生成器和所述判别器,生成训练好的标签预测模型。
根据本申请的第二方面,提供一种样本筛选方法,包括:获取多个候选样本;利用标签预测模型针对每一个所述候选样本执行标签预测,获得每一个所述候选样本的标签预测值;其中,所述标签预测模型为利用如第一方面所述的方法所训练得到的;根据每一个所述候选样本的标签预测值,确定所述候选样本中的训练样本。
根据本申请的第三方面,提供一种标签预测模型训练装置,包括:生成器训练模块,用于根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器;判别器训练模块,用于利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器;生成模块,用于基于训练好的所述生成器和所述判别器,生成训练好的标签预测模型。
根据本申请的第四方面,提供一种样本筛选装置,包括:获取模块,用于获取多个候选样本;标注模块,用于利用标签预测模型针对每一个所述候选样本执行标签预测,获得每一个所述候选样本的标签预测值;其中,所述标签预测模型为利用如第三方面所述的装置所训练得到的;筛选模块,用于根据每一个所述候选样本的标签预测值,确定所述候选样本中的训练样本。
根据本申请的第五方面,提供一种存储有计算机指令的计算机可读存储介质,其中,所述计算机指令用于使计算机执行如第一方面所述的标签预测模型训练方法,或执行如第二方面所述的样本筛选方法。
综上所述,本申请各实施例提供的标签预测模型训练及样本筛选方案,可自动预测样本标签,以有效降低人工标注样本的标注成本。
再者,利用本申请所训练的标签预测模型可自动筛选出模型未训练的缺失场景数据,减少数据筛选标注成本,并提高模型迭代开发效率。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请实施例中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1为本申请示例性实施例的标签预测模型训练方法的流程示意图。
图2为本申请另一示例性实施例的标签预测模型训练方法的流程示意图。
图3为本申请示例性实施例的标签预测模型的架构示意图。
图4为本申请另一示例性实施例的标签预测模型训练方法的流程示意图。
图5为本申请另一示例性实施例的标签预测模型训练方法的流程示意图。
图6为本申请示例性实施例的样本筛选方法的流程示意图。
图7为本申请示例性实施例的标签预测模型训练装置的结构示意图。
图8为本申请示例性实施例的样本筛选装置的结构示意图。
具体实施方式
为了使本领域的人员更好地理解本申请实施例中的技术方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请实施例一部分实施例,而不是全部的实施例。基于本申请实施例中的实施例,本领域普通技术人员所获得的所有其他实施例,都应当属于本申请实施例保护的范围。
通常在模型的训练过程中,若训练数据所对应的场景越为丰富,则模型的鲁棒性就越强。
目前的样本搜索方法主要包括以下两种:
第一种方法为:针对无标签数据底库使用模型进行预测,并获取其中置信度较低的样本进行标注后加入迭代训练,以利用这些低置信度样本将缺失的场景数据加入模型的迭代训练,从而提高模型的鲁棒性。
第二种方法为:整理验证集测试集中的错误样本,使用模型分别对错误样本和无标签数据底库进行预测,通过保留全连接层前的特征向量,并使用余弦相似度对两份数据进行搜索匹配,以保留余弦相似度值大于某一阈值的样本进行模型迭代训练。借以改善模型在误识别场景下的准确率。
然而,无论是利用低置信度排序搜索,还是对错误样本的特征向量做余弦相似度搜索,都无法保证所搜索的样本是远离训练集样本空间的,也就是模型未训练过的场景样本。
此外,实际模型训练过程中,会出现后几轮的模型迭代训练性能基本保持不变,这是因为搜索出的样本基本都是训练集中出现过的场景或十分接近训练集样本空间的数据,因此,对于模型在新场景下的识别增益效果有限。
有鉴于此,如何丰富模型的训练场景,以提高模型的鲁棒性,并减少人工标注样本导致的训练成本过高等问题,即为本申请待解决的技术课题。
图1为本申请示例性实施例的标签预测模型训练方法的处理流程图。如图所示,本实施例主要包括以下步骤:
步骤S102,根据标注数据、标注数据的真实标签、未标注数据,训练生成器。
可选地,可构建有监督学习器和无监督重建器,以协同配合完成生成器的训练任务。
可选地,生成器可包括编码器,有监督学习器可包括分类器,无监督重建器可包括解码器。
可选地,生成器可包括CNN模型(或称为卷积神经网络模型)和FC模型(或称为全连接模型)。
步骤S104,利用生成器输出的标注数据、未标注数据各自的特征提取结果、标注数据的真实标签,未标注数据的参考标签,训练判别器。
可选地,可利用判别器,根据生成器针对标注数据、未标注数据输出的特征向量,获得标注数据、未标注数据各自的标注预测标签,并根据标注数据的真实标签,未标注数据的参考标签进行训练。
步骤S106,基于训练好的生成器和判别器,生成训练好的标签预测模型。
具体地,当判断生成器和判别器均训练完成时,即可确定训练好的标签预测模型。
综上所述,本实施例提供的标签预测模型训练方法,通过结合利用标注数据和未标注数据来执行模型训练,不仅可降低数据标注成本,亦可获得较佳的模型训练效果,以提高模型预测的正确性。
图2为本申请另一示例性实施例的标签预测模型训练方法的流程示意图。本实施例为上述步骤S102的具体实施方案。如图所示,本实施例主要包括以下步骤:
步骤S202,利用生成器,针对标注数据、未标注数据执行卷积处理和隐变量学习,获得标注数据、未标注数据各自的特征向量和隐变量。
可选地,可利用生成器的CNN模型,针对标注数据、未标注数据执行编码处理(例如卷积处理),获得标注数据、未标注数据各自的特征向量。
可选地,可利用生成器的FC模型,针对标注数据、未标注数据执行隐变量学习,获得标注数据、未标注数据各自的隐变量(latent variables)。
步骤S204,利用有监督学习器,根据标注数据的隐变量执行分类预测,获得有监督学习器的损失函数。
参考图3,可选地,可利用有监督学习器,根据标注数据的隐变量执行分类预测,获得标注数据的预测标签,并根据标注数据的真实标签、预测标签,获得有监督学习器的损失函数。
例如,有监督学习器可为已训练好的分类器模型,其可根据标注数据的隐变量,执行分类预测,输出一个分类概率向量,例如[0.9,0.1],再根据所输出的分类概率向量与标注数据的真实标签,计算有监督学习器的交叉熵损失(CE Loss)。
于本实施例中,有监督学习器的损失函数表示为:
Figure BDA0003432351850000061
其中,STL表示有监督学习器,
Figure BDA0003432351850000062
表示有监督学习器的损失函数,zL表示标注数据的隐变量,xL表示标注数据,yL表示标注数据的真实标签,DKL表示生成器输出的隐变量的KL散度,pφ表示φ参数化解码器,qθ表示θ参数化编码器。
步骤S206,利用无监督重建器,根据标注数据、未标注数据各自的特征向量、隐变量,执行反卷积处理,获得无监督重建器的损失函数。
于本实施例中,可利用作为编码器的生成器和作为解码器的无监督重建器构成一个变分自编码模型(VAE模型),其基于高斯先验学习低维隐空间,由于变分自编码模型的编码预测与解码预测均无需标签,且重建目标即为输入本身,因此,可同时利用标注数据和未标注数据执行训练任务。
如图3所示,可利用无监督重建器,根据标注数据的隐变量、标注数据的特征向量,执行反卷积处理(解码),获得标注数据的标注还原预测,并根据标注数据和标注还原预测,获得有监督学习器的标注损失子函数。
于本实施例中,无监督重建器的标注损失子函数可表示为:
Figure BDA0003432351850000063
其中,xL表示标注数据,zL表示标注数据的隐变量,pφ表示φ参数化解码器,qθ表示θ参数化编码器,DKL表示生成器输出的隐变量的KL散度。
如图3所示,可利用无监督重建器,根据未标注数据的隐变量、未标注数据的特征向量,执行反卷积处理,获得未标注数据的未标注还原预测,并根据未标注数据和未标注还原预测,获得无监督重建器的未标注损失子函数。
于本实施例中,无监督重建器的未标注损失子函数可表示为:
Figure BDA0003432351850000071
其中,xU表示未标注数据,zU表示未标注数据的隐变量,pφ表示φ参数化解码器,qθ表示θ参数化编码器,DKL表示生成器输出的隐变量的KL散度。
可选地,可根据标注损失子函数、未标注损失子函数,获得无监督重建器的损失函数。
于本实施例中,无监督重建器的损失函数表示为:
Figure BDA0003432351850000072
其中,
Figure BDA0003432351850000073
表示无监督重建器的损失函数,
Figure BDA0003432351850000074
表示无监督重建器的未标注损失子函数,
Figure BDA0003432351850000075
表示无监督重建器的标注损失子函数。
步骤S208,根据有监督学习器的损失函数、无监督重建器的损失函数,训练生成器。
于本实施例中,当有监督学习器的损失函数和无监督重建器的损失函数均满足预设收敛条件时,代表生成器的训练完成。
可选地,可当预设批次的标注数据、未标注数据均完成训练时,获得有监督学习器的损失函数和无监督重建器的损失函数均满足预设收敛条件的判断结果。
综上所述,本实施例的标签预测模型训练方法,通过有监督学习器的分类预测和无监督重建器的还原预测,协助完成生成器的训练任务,可以提高生成器的训练效果,以提高生成器的预测结果准确性。
图4示出了本申请另一示例性实施例的标签预测模型训练方法的处理流程,本实施例为上述步骤S104的具体实施方案。如图所示,本实施例主要包括以下步骤:
步骤S402,利用判别器,根据生成器针对标注数据输出的特征向量,获得标注数据的标注预测标签,并根据标注预测标签、标注数据的真实标签,获得判别器的标注损失子函数。
于本实施例中,标注数据的标注预测标签可表示为:
pL=Dc(xL)log(Dc(xL))
其中,pL表示标注预测标签,xL表示标注数据,Dc表示判别器。
于本实施例中,判别器的标注损失子函数可表示为:
mse(pL,yL)
pL表示标注数据的标注预测标签,yL表示标注数据的真实标签。
于本实施例中,可将标注数据的真实标签yL设置为0。
步骤S404,利用判别器,根据生成器针对未标注数据输出的特征向量,获得未标注数据的未标注预测标签,并根据未标注预测标签、未标注数据的参考标签,获得判别器的未标注损失子函数。
于本实施例中,未标注数据的未标注预测标签可表示为:
pU=Dc(xU)log(Dc(xU))
其中,xU表示未标注数据,Dc表示判别器。
于本实施例中,判别器的未标注损失子函数可表示为:
mse(pU,yU)
其中,pU表示未标注数据的未标注预测标签,yU表示未标注数据的参考标签。
可选地,可利用标签指标器,获取未标注数据的参考标签。
可选地,标签指标器可包括训练好的分类器。
于本实施例中,可利用标签指标器,获取未标注数据的多个类别概率,并根据各类别概率的交叉熵,确定无标签样本的参考标签。
于本实施例中,参考标签可为介于0至1之间的一个不确定分值。
于本实施例中,无标签样本的参考标签表示为:
yU=Ic(xU)log(Ic(xU))
其中,yU表示未标注数据的参考标签,xU表示未标注数据,Ic表示标签指标器。
于本实施例中,未标注数据的参考标签yU即为标签指标器所输出的不确定分值。
可选地,可生成未标注数据的假标签,并根据判别器输出的未标注预测标签、假标签,训练判别器。
例如,可首先针对未标注数据打上一个假标签,并输入到判别器中训练一遍,再将假标签更新为参考标签(由标签指标器所生成),并输入至判别器中再训练一遍。
步骤S406,根据标注损失子函数、未标注损失子函数,获得判别器的损失函数。
于本实施例中,判别器的损失函数为标注损失子函数、未标注损失子函数的均方差之和。
具体地,判别器的损失函数表示为:
Figure BDA0003432351850000091
其中,
Figure BDA0003432351850000092
表示判别器的损失函数。
综上所述,本实施例利用对抗性学习原理训练判别器,可以提高判别器的训练效果,以提升判别器预测结果的准确性。
图5示出了本申请另一示例性实施例的标签预测模型训练方法的处理流程。如图所示,本实施例主要包括以下步骤:
步骤S502,根据有监督学习器的损失函数、无监督重建器的损失函数、判别器的损失函数,获得样本标签预测模型的损失函数。
于本实施例中,标签预测模型的损失函数表示为:
Figure BDA0003432351850000093
其中,
Figure BDA0003432351850000094
表示无监督重建器的损失函数,
Figure BDA0003432351850000095
表示有监督学习器的损失函数,
Figure BDA0003432351850000096
表示判别器的损失函数,λ1、λ2、λ3为权重参数。
于本实施例中,可根据实际训练需求设置λ1、λ2、λ3,本申请对此不作限制。
步骤S504,根据标签预测模型的损失函数训练标签预测模型,直至标签预测模型的损失函数满足预设收敛条件。
于本实施例中,可当预设批次的标注数据、未标注数据均完成训练时,获得标签预测模型的损失函数满足预设收敛条件的判断结果。
图6示出了本申请示例性实施例的样本筛选方法的处理流程。如图所示,本实施例主要包括以下步骤:
步骤S602,获取多个候选样本。
于本实施例中,候选样本可为无标签样本。
步骤S604,利用标签预测模型针对每一个候选样本执行标签预测,获得每一个候选样本的标签预测值。
于本实施例中,标签预测模型可为利用任意一个实施例所述的标签预测模型训练方法所训练得到的。
于本实施例中,若候选样本的标签预测值越接近0,则代表此候选样本越可能是训练集中已经存在的样本,反之,若候选样本的标签预测值越接近1,则代表此候选样本越可能是训练集中不存在的样本。
步骤S606,根据每一个候选样本的标签预测值,确定候选样本中的训练样本。
于本实施例中,可根据各候选样本对应的各标签预测值,将满足预设标签阈值的候选样本确定为训练样本。
具体地,可针对各标签预测值进行排序,以根据排序结果,将topk的候选样本加入至模型迭代训练中。
综上所述,本实施例提供的样本筛选方法,可补充训练集中未出现过的场景数据,以保证训练集中的场景数据尽可能的丰富而不是局限于某一部分场景,从而提高模型的鲁棒性。
此外,本实施例提供的样本筛选方法,可在底库数据足够庞大且场景足够丰富的情况下,自动搜索模型为训练过的高价值样本,而人工筛选标注,可以节约项目开发时间和成本。
图7示出了本申请示例性实施例的标签预测模型训练装置的结构示意图。如图所示,本实施例的标签预测模型训练装置700主要包括:
生成器训练模块702,用于根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器。
判别器训练模块704,用于利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器。
生成模块706,用于基于训练好的所述生成器和所述判别器,生成训练好的标签预测模型。
可选地,生成器训练模块702还用于:利用所述生成器,针对所述标注数据、所述未标注数据执行卷积处理和隐变量学习,获得所述标注数据、所述未标注数据各自的特征向量和隐变量;利用有监督学习器,根据所述标注数据的隐变量执行分类预测,获得所述有监督学习器的损失函数;利用无监督重建器,根据所述标注数据、所述未标注数据各自的特征向量、隐变量,执行反卷积处理,获得所述无监督重建器的损失函数;根据所述有监督学习器的损失函数、所述无监督重建器的损失函数,训练所述生成器。
可选地,生成器训练模块702还用于:利用所述生成器的CNN模型,针对所述标注数据、所述未标注数据执行卷积处理,获得所述标注数据、所述未标注数据各自的特征向量;利用所述生成器的FC模型,针对所述标注数据、所述未标注数据执行隐变量学习,获得所述标注数据、所述未标注数据各自的隐变量。
可选地,生成器训练模块702还用于:利用所述有监督学习器,根据所述标注数据的隐变量执行分类预测,获得所述标注数据的预测标签;根据所述标注数据的真实标签、所述预测标签,获得所述有监督学习器的损失函数;所述有监督学习器的损失函数表示为:
Figure BDA0003432351850000111
其中,所述STL表示所述有监督学习器,所述
Figure BDA0003432351850000112
表示所述有监督学习器的损失函数,所述zL表示所述标注数据的隐变量,所述xL表示所述标注数据,所述yL表示所述标注数据的真实标签,所述DKL表示所述生成器输出的所述隐变量的KL散度,所述pφ表示φ参数化解码器,所述qθ表示θ参数化编码器。
可选地,生成器训练模块702还用于:利用所述无监督重建器,根据所述标注数据的隐变量、所述标注数据的特征向量,执行反卷积处理,获得所述标注数据的标注还原预测,并根据所述标注数据和所述标注还原预测,获得所述有监督学习器的标注损失子函数;利用所述无监督重建器,根据所述未标注数据的隐变量、所述未标注数据的特征向量,执行反卷积处理,获得所述未标注数据的未标注还原预测,并根据所述未标注数据和所述未标注还原预测,获得所述无监督重建器的未标注损失子函数;根据所述标注损失子函数、所述未标注损失子函数,获得所述无监督重建器的损失函数;所述无监督重建器的损失函数表示为:
Figure BDA0003432351850000121
其中,所述
Figure BDA0003432351850000122
表示所述无监督重建器的损失函数,所述
Figure BDA0003432351850000123
表示所述未标注损失子函数,所述
Figure BDA0003432351850000124
表示所述标注损失子函数;
所述未标注损失子函数表示为:
Figure BDA0003432351850000125
所述标注损失子函数表示为:
Figure BDA0003432351850000126
其中,所述xU表示所述未标注数据,所述zU表示所述未标注数据的隐变量;所述xL表示所述标注数据,所述zL表示所述标注数据的隐变量,所述pφ表示φ参数化解码器,所述qθ表示θ参数化编码器。
可选地,判别器训练模块704还用于:利用所述判别器,根据所述生成器针对所述标注数据输出的特征向量,获得所述标注数据的标注预测标签;利用所述判别器,根据所述生成器针对所述未标注数据输出的特征向量,获得所述未标注数据的未标注预测标签,并根据所述未标注预测标签、所述未标注数据的参考标签,获得所述判别器的未标注损失子函数;根据所述标注损失子函数、所述未标注损失子函数,获得所述判别器的损失函数;所述标注数据的标注预测标签表示为:
pL=Dc(xL)log(Dc(xL))
所述未标注数据的未标注预测标签表示为:
pU=Dc(xU)log(Dc(xU))
其中,所述xL表示所述标注数据,所述xU表示所述未标注数据,所述Dc表示所述判别器;
所述判别器的损失函数表示为:
Figure BDA0003432351850000131
其中,所述
Figure BDA0003432351850000132
表示所述判别器的损失函数,所述mse(pL,yL)表示所述标注损失子函数,所述pL表示所述标注数据的标注预测标签,所述yL表示所述标注数据的真实标签,所述mse(pU,yU)表示所述未标注损失子函数,所述pU表示所述未标注数据的未标注预测标签,所述yU表示所述未标注数据的参考标签;其中,所述yL的取值为0。
可选地,判别器训练模块704还用于:利用标签指标器,获取所述未标注数据的多个类别概率;根据各所述类别概率的交叉熵,确定所述无标签样本的参考标签;所述无标签样本的参考标签表示为:
yU=Ic(xU)log(Ic(xU))
其中,所述yU表示所述未标注数据的参考标签,所述xU表示所述未标注数据,所述Ic表示所述标签指标器。
可选地,判别器训练模块704还用于:生成所述未标注数据的假标签,并根据所述判别器输出的所述未标注预测标签、所述假标签,训练所述判别器。
可选地,生成模块706还用于:根据所述有监督学习器的损失函数、所述无监督重建器的损失函数、所述判别器的损失函数,获得所述样本标签预测模型的损失函数;根据所述标签预测模型的损失函数训练所述标签预测模型,直至所述标签预测模型的损失函数满足预设收敛条件;所述标签预测模型的损失函数表示为:
Figure BDA0003432351850000141
其中,所述
Figure BDA0003432351850000142
表示所述无监督重建器的损失函数,所述
Figure BDA0003432351850000143
表示所述有监督学习器的损失函数,所述
Figure BDA0003432351850000144
表示所述判别器的损失函数,所述λ1、λ2、λ3为权重参数。
图8为本申请示例性实施例的样本筛选装置的架构示意图。如图所示,本实施例的样本筛选装置800主要包括:
获取模块802,用于获取多个候选样本。
标注模块804,用于利用标签预测模型针对每一个所述候选样本执行标签预测,获得每一个所述候选样本的标签预测值。
于本实施例中,所述标签预测模型为上述样本筛选装置所训练得到的。
筛选模块806,用于根据每一个所述候选样本的标签预测值,确定所述候选样本中的训练样本。
可选地,筛选模块806还用于:根据每一个所述候选样本的标签预测值,将满足预设标签阈值的所述候选样本确定为所述训练样本。
此外,本申请示例性实施例还提供一种存储有计算机指令的计算机可读存储介质,所述计算机指令用于使计算机执行各实施例所述的标签预测模型训练方法,或执行各实施例所述的样本筛选方法。
综上所述,本申请各实施例提供的标签预测模型训练及样本筛选方法、装置、存储介质,通过算法找出训练集中未出现的高价值样本,可有助于提高模型在多场景下的鲁棒性,同时加快模型的迭代速度,可实现在庞大的底库中快速搜索所需要的样本。
此外,本申请可快速找出训练集中未出现场景数据,而无需人工筛选标注,可以有效节约项目开发时间和成本。
最后应说明的是:以上实施例仅用以说明本申请实施例的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围。

Claims (14)

1.一种标签预测模型训练方法,包括:
根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器;
利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器;
基于训练好的所述生成器和所述判别器,生成训练好的标签预测模型。
2.根据权利要求1所述的标签预测模型训练方法,其特征在于,所述根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器,包括:
利用所述生成器,针对所述标注数据、所述未标注数据执行卷积处理和隐变量学习,获得所述标注数据、所述未标注数据各自的特征向量和隐变量;
利用有监督学习器,根据所述标注数据的隐变量执行分类预测,获得所述有监督学习器的损失函数;
利用无监督重建器,根据所述标注数据、所述未标注数据各自的特征向量、隐变量,执行反卷积处理,获得所述无监督重建器的损失函数;
根据所述有监督学习器的损失函数、所述无监督重建器的损失函数,训练所述生成器。
3.根据权利要求2所述的标签预测模型训练方法,其特征在于,所述利用所述生成器,针对所述标注数据、所述未标注数据执行卷积处理和隐变量学习,获得所述标注数据、所述未标注数据各自的特征向量和隐变量,包括:
利用所述生成器的CNN模型,针对所述标注数据、所述未标注数据执行卷积处理,获得所述标注数据、所述未标注数据各自的特征向量;
利用所述生成器的FC模型,针对所述标注数据、所述未标注数据执行隐变量学习,获得所述标注数据、所述未标注数据各自的隐变量。
4.根据权利要求2所述的标签预测模型训练方法,其特征在于,所述利用有监督学习器,根据所述标注数据的隐变量执行分类预测,获得所述有监督学习器的损失函数,包括:
利用所述有监督学习器,根据所述标注数据的隐变量执行分类预测,获得所述标注数据的预测标签;
根据所述标注数据的真实标签、所述预测标签,获得所述有监督学习器的损失函数;
所述有监督学习器的损失函数表示为:
Figure FDA0003432351840000021
其中,所述STL表示所述有监督学习器,所述
Figure FDA0003432351840000022
表示所述有监督学习器的损失函数,所述zL表示所述标注数据的隐变量,所述xL表示所述标注数据,所述yL表示所述标注数据的真实标签,所述DKL表示所述生成器输出的所述隐变量的KL散度,所述pφ表示φ参数化解码器,所述qθ表示θ参数化编码器。
5.根据权利要求2所述的标签预测模型训练方法,其特征在于,所述利用无监督重建器,根据所述标注数据、所述未标注数据各自的特征向量、隐变量,执行反卷积处理,获得所述无监督重建器的损失函数,包括:
利用所述无监督重建器,根据所述标注数据的隐变量、所述标注数据的特征向量,执行反卷积处理,获得所述标注数据的标注还原预测,并根据所述标注数据和所述标注还原预测,获得所述有监督学习器的标注损失子函数;
利用所述无监督重建器,根据所述未标注数据的隐变量、所述未标注数据的特征向量,执行反卷积处理,获得所述未标注数据的未标注还原预测,并根据所述未标注数据和所述未标注还原预测,获得所述无监督重建器的未标注损失子函数;
根据所述标注损失子函数、所述未标注损失子函数,获得所述无监督重建器的损失函数;
所述无监督重建器的损失函数表示为:
Figure FDA0003432351840000023
其中,所述
Figure FDA0003432351840000024
表示所述无监督重建器的损失函数,所述
Figure FDA0003432351840000025
表示所述未标注损失子函数,所述
Figure FDA0003432351840000031
表示所述标注损失子函数;
所述未标注损失子函数表示为:
Figure FDA0003432351840000032
所述标注损失子函数表示为:
Figure FDA0003432351840000033
其中,所述xU表示所述未标注数据,所述zU表示所述未标注数据的隐变量;所述xL表示所述标注数据,所述zL表示所述标注数据的隐变量,所述pφ表示φ参数化解码器,所述qθ表示θ参数化编码器。
6.根据权利要求2所述的标签预测模型训练方法,其特征在于,所述利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器,包括:
利用所述判别器,根据所述生成器针对所述标注数据输出的特征向量,获得所述标注数据的标注预测标签;
利用所述判别器,根据所述生成器针对所述未标注数据输出的特征向量,获得所述未标注数据的未标注预测标签,并根据所述未标注预测标签、所述未标注数据的参考标签,获得所述判别器的未标注损失子函数;
根据所述标注损失子函数、所述未标注损失子函数,获得所述判别器的损失函数;
所述标注数据的标注预测标签表示为:
pL=Dc(xL)log(Dc(xL))
所述未标注数据的未标注预测标签表示为:
pU=Dc(xU)log(Dc(xU))
其中,所述xL表示所述标注数据,所述xU表示所述未标注数据,所述Dc表示所述判别器;
所述判别器的损失函数表示为:
Figure FDA0003432351840000034
其中,所述
Figure FDA0003432351840000035
表示所述判别器的损失函数,所述mse(pL,yL)表示所述标注损失子函数,所述pL表示所述标注数据的标注预测标签,所述yL表示所述标注数据的真实标签,所述mse(pU,yU)表示所述未标注损失子函数,所述pU表示所述未标注数据的未标注预测标签,所述yU表示所述未标注数据的参考标签;
其中,所述yL的取值为0。
7.根据权利要求6所述的标签预测模型训练方法,其特征在于,所述方法还包括:
利用标签指标器,获取所述未标注数据的多个类别概率;
根据各所述类别概率的交叉熵,确定所述无标签样本的参考标签;
所述无标签样本的参考标签表示为:
yU=Ic(xU)log(Ic(xU))
其中,所述yU表示所述未标注数据的参考标签,所述xU表示所述未标注数据,所述Ic表示所述标签指标器。
8.根据权利要求6所述的标签预测模型训练方法,其特征在于,所述方法还包括:
生成所述未标注数据的假标签,并根据所述判别器输出的所述未标注预测标签、所述假标签,训练所述判别器。
9.根据权利要求2所述的标签预测模型训练方法,其特征在于,所述方法还包括:
根据所述有监督学习器的损失函数、所述无监督重建器的损失函数、所述判别器的损失函数,获得所述样本标签预测模型的损失函数;
根据所述标签预测模型的损失函数训练所述标签预测模型,直至所述标签预测模型的损失函数满足预设收敛条件;
所述标签预测模型的损失函数表示为:
Figure FDA0003432351840000041
其中,所述
Figure FDA0003432351840000042
表示所述无监督重建器的损失函数,所述
Figure FDA0003432351840000043
表示所述有监督学习器的损失函数,所述
Figure FDA0003432351840000044
表示所述判别器的损失函数,所述λ1、λ2、λ3为权重参数。
10.一种样本筛选方法,包括:
获取多个候选样本;
利用标签预测模型针对每一个所述候选样本执行标签预测,获得每一个所述候选样本的标签预测值;其中,所述标签预测模型为利用如权利要求1至9项中任一项所述的方法所训练得到的;
根据每一个所述候选样本的标签预测值,确定所述候选样本中的训练样本。
11.根据权利要求10所述的样本筛选方法,其特征在于,所述根据每一个所述候选样本的标签预测值执行筛选,以确定各所述候选样本中的训练样本,包括:
根据每一个所述候选样本的标签预测值,将满足预设标签阈值的所述候选样本确定为所述训练样本。
12.一种标签预测模型训练装置,包括:
生成器训练模块,用于根据标注数据、所述标注数据的真实标签、未标注数据,训练生成器;
判别器训练模块,用于利用所述生成器输出的所述标注数据、所述未标注数据各自的特征提取结果、所述标注数据的真实标签,所述未标注数据的参考标签,训练判别器;
生成模块,用于基于训练好的所述生成器和所述判别器,生成训练好的标签预测模型。
13.一种样本筛选装置,包括:
获取模块,用于获取多个候选样本;
标注模块,用于利用标签预测模型针对每一个所述候选样本执行标签预测,获得每一个所述候选样本的标签预测值;其中,所述标签预测模型为利用如权利要求12所述的装置所训练得到的;
筛选模块,用于根据每一个所述候选样本的标签预测值,确定所述候选样本中的训练样本。
14.一种存储有计算机指令的计算机可读存储介质,其中,所述计算机指令用于使计算机执行如权利要求1至9中任一项所述的标签预测模型训练方法,或执行如权利要求10或11所述的样本筛选方法。
CN202111602781.8A 2021-12-24 2021-12-24 标签预测模型训练及样本筛选方法、装置、存储介质 Pending CN114418111A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111602781.8A CN114418111A (zh) 2021-12-24 2021-12-24 标签预测模型训练及样本筛选方法、装置、存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111602781.8A CN114418111A (zh) 2021-12-24 2021-12-24 标签预测模型训练及样本筛选方法、装置、存储介质

Publications (1)

Publication Number Publication Date
CN114418111A true CN114418111A (zh) 2022-04-29

Family

ID=81269465

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111602781.8A Pending CN114418111A (zh) 2021-12-24 2021-12-24 标签预测模型训练及样本筛选方法、装置、存储介质

Country Status (1)

Country Link
CN (1) CN114418111A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112420205A (zh) * 2020-12-08 2021-02-26 医惠科技有限公司 实体识别模型生成方法、装置及计算机可读存储介质

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112420205A (zh) * 2020-12-08 2021-02-26 医惠科技有限公司 实体识别模型生成方法、装置及计算机可读存储介质

Similar Documents

Publication Publication Date Title
CN112084337B (zh) 文本分类模型的训练方法、文本分类方法及设备
CN113177132B (zh) 基于联合语义矩阵的深度跨模态哈希的图像检索方法
CN114241282A (zh) 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN108537119B (zh) 一种小样本视频识别方法
Duong et al. Shrinkteanet: Million-scale lightweight face recognition via shrinking teacher-student networks
CN112199501B (zh) 一种科技信息文本分类方法
CN113392717B (zh) 一种基于时序特征金字塔的视频密集描述生成方法
Lee et al. Large scale video representation learning via relational graph clustering
CN113887211A (zh) 基于关系导向的实体关系联合抽取方法及系统
CN114329031B (zh) 一种基于图神经网络和深度哈希的细粒度鸟类图像检索方法
CN112966088B (zh) 未知意图的识别方法、装置、设备及存储介质
CN118171149B (zh) 标签分类方法、装置、设备、存储介质和计算机程序产品
CN117516937A (zh) 基于多模态特征融合增强的滚动轴承未知故障检测方法
CN117217277A (zh) 语言模型的预训练方法、装置、设备、存储介质及产品
CN116958622A (zh) 数据的分类方法、装置、设备、介质及程序产品
CN115186085A (zh) 回复内容处理方法以及媒体内容互动内容的交互方法
CN114418111A (zh) 标签预测模型训练及样本筛选方法、装置、存储介质
CN112101154B (zh) 视频分类方法、装置、计算机设备和存储介质
KR102334388B1 (ko) 순차적 특징 데이터 이용한 행동 인식 방법 및 그를 위한 장치
CN117390454A (zh) 基于多域自适应数据闭环的数据标注方法及系统
CN117218477A (zh) 图像识别及模型训练方法、装置、设备及存储介质
CN115688775A (zh) 一种基于注意力机制的电网运检领域命名实体识别方法
CN115035455A (zh) 一种基于对抗多模态领域自适应的跨类别视频时间定位方法、系统和存储介质
CN115705756A (zh) 动作检测方法、装置、计算机设备和存储介质
CN113886602A (zh) 一种基于多粒度认知的领域知识库实体识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination