CN112819099B - 网络模型的训练方法、数据处理方法、装置、介质及设备 - Google Patents
网络模型的训练方法、数据处理方法、装置、介质及设备 Download PDFInfo
- Publication number
- CN112819099B CN112819099B CN202110220979.3A CN202110220979A CN112819099B CN 112819099 B CN112819099 B CN 112819099B CN 202110220979 A CN202110220979 A CN 202110220979A CN 112819099 B CN112819099 B CN 112819099B
- Authority
- CN
- China
- Prior art keywords
- data
- network
- training
- pseudo
- label
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 409
- 238000000034 method Methods 0.000 title claims abstract description 88
- 238000003672 processing method Methods 0.000 title abstract description 11
- 238000012545 processing Methods 0.000 claims description 50
- 230000009466 transformation Effects 0.000 claims description 22
- 238000012952 Resampling Methods 0.000 claims description 5
- 230000002708 enhancing effect Effects 0.000 claims description 5
- 238000004590 computer program Methods 0.000 claims description 4
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 230000008569 process Effects 0.000 description 13
- 238000010586 diagram Methods 0.000 description 7
- 238000004458 analytical method Methods 0.000 description 4
- 238000006243 chemical reaction Methods 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 4
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000000694 effects Effects 0.000 description 2
- 230000006870 function Effects 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 239000013307 optical fiber Substances 0.000 description 2
- 230000000644 propagated effect Effects 0.000 description 2
- 238000012216 screening Methods 0.000 description 2
- 238000000638 solvent extraction Methods 0.000 description 2
- 230000007704 transition Effects 0.000 description 2
- 239000006002 Pepper Substances 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000007418 data mining Methods 0.000 description 1
- 238000013501 data transformation Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000012217 deletion Methods 0.000 description 1
- 230000037430 deletion Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000003780 insertion Methods 0.000 description 1
- 230000037431 insertion Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000013140 knowledge distillation Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000002035 prolonged effect Effects 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000001308 synthesis method Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本公开涉及网络模型的训练方法及装置、数据处理方法及装置、存储介质及电子设备,涉及人工智能技术领域。包括:获取目标任务所在领域的通用数据和目标任务的训练数据,该训练数据包括无标签数据和有标签数据;将通用数据分别输入第一网络和第二网络进行无监督训练;分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据;基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。本发明提高了网络模型的性能。
Description
技术领域
本发明的实施方式涉及人工智能技术领域,更具体地,本发明的实施方式涉及一种网络模型的训练方法、数据处理方法、网络模型的训练装置、数据处理装置、计算机可读存储介质及电子设备。
背景技术
本部分旨在为权利要求中陈述的本发明的实施方式提供背景或上下文,此处的描述不因为包括在本部分中就承认是现有技术。
随着深度学习技术的不断进步和计算机算力的不断提升,数据分类技术在各个领域,如语音分析、图像识别、自然语言处理等技术领域中均取得了巨大的进展。以图像识别技术领域为例,一般可以使用大规模的带有标签数据的训练样本作为训练集,应用相应的神经网络来训练分类器,使其可以学习图像的全局或局部特征,将该全局或局部特征与已经学习的特征进行比对,确定每个图像中对象的类别。
发明内容
然而,现有的数据分类技术依赖于人为标注的标签数据,不仅需要耗费较高的人力成本,而且受限于标签数据的准确性和数据规模等因素,网络模型往往存在模型泛化性能不足、模型过拟合等问题,严重影响了网络模型的训练效果。
为此,非常需要一种网络模型的训练方法,以提高网络模型的泛化性能。
在本上下文中,本发明的实施方式期望提供一种网络模型的训练方法、数据处理方法、网络模型的训练装置、数据处理装置、计算机可读存储介质及电子设备。
根据本发明实施方式的第一方面,提供一种网络模型的训练方法,包括:获取目标任务所在领域的通用数据和所述目标任务的训练数据,其中,所述训练数据包括无标签数据和有标签数据;将所述通用数据分别输入第一网络和第二网络进行无监督训练;分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练,并通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据;基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型。
在一种可选的实施方式中,所述将所述通用数据分别输入第一网络和第二网络进行无监督训练,包括:分别通过所述第一网络和所述第二网络对所述通用数据进行无监督训练,得到所述第一网络的第一网络原始参数和所述第二网络的第二网络原始参数。
在一种可选的实施方式中,所述分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络,包括:以所述第一网络原始参数为初始参数,采用所述第一网络对所述无标签数据进行无监督训练,确定所述第一网络的第一网络更新参数,以生成所述第一训练网络;以所述第二网络原始参数为初始参数,采用所述第二网络对所述无标签数据进行无监督训练,确定所述第二网络的第二网络更新参数,以生成所述第二训练网络。
在一种可选的实施方式中,在根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练时,所述方法包括:将所述有标签数据输入至所述第一训练网络,以更新所述第一网络的第一网络更新参数,得到所述第一网络的第一网络训练参数;所述通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据,包括:以所述第一网络训练参数为初始参数,采用所述第一网络对所述无标签数据进行预测,生成所述伪标签数据。
在一种可选的实施方式中,所述基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型,包括:以所述第二网络更新参数为初始参数,对所述伪标签数据进行监督训练,以生成所述目标网络模型。
在一种可选的实施方式中,在生成所述目标任务的目标网络模型时,所述方法还包括:根据所述有标签数据中的标签数据对所述目标网络模型进行监督训练,并调整所述目标网络模型的网络参数。
在一种可选的实施方式中,在对所述伪标签数据进行监督训练时,所述方法还包括:确定所述伪标签数据的标签置信度,并根据所述标签置信度对所述伪标签数据进行数据选择。
在一种可选的实施方式中,所述根据所述标签置信度对所述伪标签数据进行数据选择,包括:根据所述伪标签数据中各标签的标签置信度,从所述伪标签数据中筛选出所述标签置信度大于预设阈值的候选标签数据;在所述候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据所述数据量分布对所述候选标签数据进行重采样。
在一种可选的实施方式中,在对所述伪标签数据进行监督训练时,所述方法还包括:按照所述伪标签数据中各标签的标签置信度,将所述伪标签数据划分为多个类别;确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理。
在一种可选的实施方式中,所述确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理,包括:按照所述伪标签数据的数据属性对所述各类别对应的伪标签数据进行统计,以确定所述各类别对应的伪标签数据的关键数据的数据属性分布;根据所述关键数据的数据属性分布确定所述各类别对应的伪标签数据的变换规则;按照所述变换规则对所述各类别对应的伪标签数据进行变换处理。
在一种可选的实施方式中,所述通用数据和所述训练数据包括图像,所述数据增强策略包括以下任意一种或多种:对所述图像进行裁剪;对所述图像进行旋转;调整所述图像的亮度和/或对比度;在所述图像中添加孤立像素点,以进行加噪处理。
根据本发明实施方式的第二方面,提供一种数据处理方法,所述方法包括:获取待处理数据;采用训练后的目标网络模型,对所述待处理数据进行分类处理,得到所述待处理数据的分类结果;其中,所述训练后的目标网络模型为采用如上述任意一项所述的网络模型的训练方法获得的目标网络模型。
根据本发明实施方式的第三方面,提供一种网络模型的训练装置,包括:获取模块,用于获取目标任务所在领域的通用数据和所述目标任务的训练数据,其中,所述训练数据包括无标签数据和有标签数据;第一训练模块,用于将所述通用数据分别输入第一网络和第二网络进行无监督训练;第二训练模块,用于分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;第三训练模块,用于根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练,并通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据;生成模块,用于基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型。
在一种可选的实施方式中,所述第一训练模块,被配置为分别通过所述第一网络和所述第二网络对所述通用数据进行无监督训练,得到所述第一网络的第一网络原始参数和所述第二网络的第二网络原始参数。
在一种可选的实施方式中,所述第二训练模块,被配置为以所述第一网络原始参数为初始参数,采用所述第一网络对所述无标签数据进行无监督训练,确定所述第一网络的第一网络更新参数,以生成所述第一训练网络,以所述第二网络原始参数为初始参数,采用所述第二网络对所述无标签数据进行无监督训练,确定所述第二网络的第二网络更新参数,以生成所述第二训练网络。
在一种可选的实施方式中,在根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练时,所述第三训练模块,被配置为将所述有标签数据输入至所述第一训练网络,以更新所述第一网络的第一网络更新参数,得到所述第一网络的第一网络训练参数,以所述第一网络训练参数为初始参数,采用所述第一网络对所述无标签数据进行预测,生成所述伪标签数据。
在一种可选的实施方式中,所述生成模块,被配置为以所述第二网络更新参数为初始参数,对所述伪标签数据进行监督训练,以生成所述目标网络模型。
在一种可选的实施方式中,在生成所述目标任务的目标网络模型时,所述生成模块,被配置为根据所述有标签数据中的标签数据对所述目标网络模型进行监督训练,并调整所述目标网络模型的网络参数。
在一种可选的实施方式中,在对所述伪标签数据进行监督训练时,所述生成模块,被配置为确定所述伪标签数据的标签置信度,并根据所述标签置信度对所述伪标签数据进行数据选择。
在一种可选的实施方式中,所述生成模块,被配置为根据所述伪标签数据中各标签的标签置信度,从所述伪标签数据中筛选出所述标签置信度大于预设阈值的候选标签数据,在所述候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据所述数据量分布对所述候选标签数据进行重采样。
在一种可选的实施方式中,在对所述伪标签数据进行监督训练时,所述生成模块,被配置为按照所述伪标签数据中各标签的标签置信度,将所述伪标签数据划分为多个类别,确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理。
在一种可选的实施方式中,所述生成模块,被配置为按照所述伪标签数据的数据属性对所述各类别对应的伪标签数据进行统计,以确定所述各类别对应的伪标签数据的关键数据的数据属性分布;根据所述关键数据的数据属性分布确定所述各类别对应的伪标签数据的变换规则;按照所述变换规则对所述各类别对应的伪标签数据进行变换处理。
在一种可选的实施方式中,所述通用数据和所述训练数据包括图像,所述数据增强策略包括以下任意一种或多种:对所述图像进行裁剪;对所述图像进行旋转;调整所述图像的亮度和/或对比度;在所述图像中添加孤立像素点,以进行加噪处理。
根据本发明实施方式的第四方面,提供一种数据处理装置,所述装置包括:获取模块,用于获取待处理数据;处理模块,用于采用训练后的目标网络模型,对所述待处理数据进行分类处理,得到所述待处理数据的分类结果;其中,所述训练后的目标网络模型为采用如上述任意一项所述的网络模型的训练方法获得的目标网络模型。
根据本发明实施方式的第五方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任意一种网络模型的训练方法。
根据本发明实施方式的第六方面,提供一种电子设备,包括:处理器;以及存储器,用于存储所述处理器的可执行指令;其中,所述处理器配置为经由执行所述可执行指令来执行上述任意一种网络模型的训练方法。
根据本发明实施方式的网络模型的训练方法、装置、数据处理方法、装置、存储介质及电子设备,可以获取目标任务所在领域的通用数据和目标任务的训练数据,该训练数据包括无标签数据和有标签数据,将通用数据分别输入至第一网络和第二网络进行无监督训练,然后分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络,根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据,最后基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。一方面,本方案通过使用第一训练网络生成伪标签数据,来对第二训练网络进行监督学习,实现了目标网络模型的参数压缩,提高了目标网络模型的泛化性能。另一方面,借助无标签的通用数据和训练数据中的无标签数据,减少了生成目标网络模型时对于有标签数据的数据量的需求,提高了目标网络模型从大量无标签数据中学习广泛数据特征的能力,提升了网络模型在多种任务类型上的泛化性能。再一方面,本方案可以根据目标任务的任务类型和内容,生成与目标任务相匹配的目标网络模型,可以适用于各种任务场景,并且在确定好中间环节的参数后,可以实现自动化、流程化的模型训练,提升了训练网络模型的效率。
附图说明
通过参考附图阅读下文的详细描述,本发明示例性实施方式的上述以及其他目的、特征和优点将变得易于理解。在附图中,以示例性而非限制性的方式示出了本发明的若干实施方式,其中:
图1示出了根据本发明实施方式的一种系统架构的示意图;
图2示出了根据本发明实施方式的一种网络模型的训练方法的流程图;
图3示出了根据本发明实施方式的一种训练方法的子流程图;
图4示出了根据本发明实施方式的一种数据选择的流程图;
图5示出了根据本发明实施方式的一种数据增强处理的流程图;
图6示出了根据本发明实施方式的一种数据变换处理的流程图;
图7示出了根据本发明实施方式的一种网络模型的训练方法的示意图;
图8示出了根据本发明实施方式的一种数据处理方法的流程图;
图9示出了根据本发明实施方式的一种网络模型的训练装置的结构图;
图10示出了根据本发明实施方式的一种数据处理装置的结构图;以及
图11示出了根据本发明实施方式的一种电子设备的结构图。
在附图中,相同或对应的标号表示相同或对应的部分。
具体实施方式
下面将参考若干示例性实施方式来描述本发明的原理和精神。应当理解,给出这些实施方式仅仅是为了使本领域技术人员能够更好地理解进而实现本发明,而并非以任何方式限制本发明的范围。相反,提供这些实施方式是为了使本发明更加透彻和完整,并且能够将本发明的范围完整地传达给本领域的技术人员。
本领域技术人员知道,本发明的实施方式可以实现为一种系统、装置、设备、方法或计算机程序产品。因此,本发明可以具体实现为以下形式,即:完全的硬件、完全的软件(包括固件、驻留软件、微代码等),或者硬件和软件结合的形式。
根据本发明的实施方式,提供一种网络模型的训练方法、数据处理方法、网络模型的训练装置、数据处理装置、计算机可读存储介质及电子设备。
在本文中,附图中的任何元素数量均用于示例而非限制,以及任何命名都仅用于区分,而不具有任何限制含义。
下面参考本发明的若干代表性实施方式,详细阐述本发明的原理和精神。
发明概述
本发明人发现,现有的数据分类技术主要是根据人为标注的标签数据确定待分类数据与具有标签数据的数据关联性实现数据分类的,但是,受限于人工标注标签数据准确性和数据规模等因素,网络模型往往存在模型泛化性能不足、模型过拟合等问题,严重影响了网络模型的训练效果,也需要耗费较高的人力成本。
鉴于上述内容,本发明的基本思想在于:提供一种网络模型的训练方法、数据处理方法、网络模型的训练装置、数据处理装置、计算机可读存储介质及电子设备,可以获取目标任务所在领域的通用数据和目标任务的训练数据,该训练数据包括无标签数据和有标签数据,将通用数据分别输入至第一网络和第二网络进行无监督训练,然后分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络,根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据,最后基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。一方面,本方案通过使用第一训练网络生成伪标签数据,来对第二训练网络进行监督学习,实现了目标网络模型的参数压缩,提高了目标网络模型的泛化性能。另一方面,借助无标签的通用数据和训练数据中的无标签数据,减少了生成目标网络模型时对于有标签数据的数据量的需求,提高了目标网络模型从大量无标签数据中学习广泛数据特征的能力,提升了网络模型在多种任务类型上的泛化性能。再一方面,本方案可以根据目标任务的任务类型和内容,生成与目标任务相匹配的目标网络模型,可以适用于各种任务场景,并且在确定好中间环节的参数后,可以实现自动化、流程化的模型训练,提升了训练网络模型的效率。
在介绍了本发明的基本原理之后,下面具体介绍本发明的各种非限制性实施方式。
应用场景总览
需要注意的是,下述应用场景仅是为了便于理解本发明的精神和原理而示出,本发明的实施方式在此方面不受任何限制。相反,本发明的实施方式可以应用于适用的任何场景。
针对各种应用场景中的目标任务,可以通过通用数据和训练数据中的无标签数据,使目标网络模型具备学习广泛数据特征的能力,提升网络模型在多种任务类型上的泛化性能,减少生成目标网络模型时对于有标签数据的数据量的需求;同时,通过使用第一训练网络生成伪标签数据,来对第二训练网络进行监督学习,实现目标网络模型的参数压缩,提高目标网络模型的泛化性能。
示例性方法
本发明的示例性实施方式首先提供一种网络模型的训练方法。图1示意性示出了该方法运行环境的系统架构图。如图1所示,该系统架构100可以包括:终端设备110和服务器120。终端设备110可以是智能手机、平板电脑、个人电脑、游戏机等,且终端设备110上可以安装有各种客户端应用,如图像处理类应用、语音服务类应用、文本处理类应用等;服务器120可以是独立的服务器设备,也可以是由多个服务器设备组成的服务器群组,或者也可以是提供云计算服务的云服务器。终端设备110和服务器120之间可以通过网络进行信息交互。本示例性实施方式所提供的网络模型的训练方法可以应用于终端设备110,也可以应用于服务器120,或者也可以由终端设备110和服务器120共同实现,本示例性实施方式对此不做特殊限定。举例而言,可以将基于本示例性实施方式训练得到的网络模型配置在图1所示的终端设备110上,或者也可以配置在服务器120上,在用户通过终端设备110上传数据时,可以通过上述网络模型对上传数据进行分类,从而将用户上传的数据自动划分为相应的类别。
需要说明的是,本示例性实施方式对于图1中各设备的数量不做限制,例如可以根据实现需要而设置任意数量的终端设备110,服务器120可以是由多台服务器形成的集群。
图2示出了由上述终端设备110和/或服务器120所执行的网络模型的训练方法的示例性流程,可以包括:
步骤S210,获取目标任务所在领域的通用数据和目标任务的训练数据,该训练数据包括无标签数据和有标签数据;
步骤S220,将通用数据分别输入第一网络和第二网络进行无监督训练;
步骤S230,分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;
步骤S240,根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据;
步骤S250,基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。
下面分别对图2中的每个步骤做具体说明。
步骤S210中,获取目标任务所在领域的通用数据和目标任务的训练数据,该训练数据包括无标签数据和有标签数据。
目标任务可以是实际应用中需要通过训练后的网络模型解决的问题。对于不同的领域,目标任务的内容也有所不同。例如,当目标任务所在领域为图像处理领域时,目标任务可以是图像识别、图像分类等;当目标任务所在领域为音频处理领域时,目标任务可以是音频识别、语音转换等;当目标任务所在领域为文本处理领域时,目标任务可以是文本识别等。通用数据可以是目标任务所在领域的公用数据,该数据可以体现目标领域所在领域的广泛特征。例如,对于图像识别任务,通用数据可以是其所在领域,即图像处理领域中的公用数据集,如可以是ImageNet数据集(一种公开的图像数据集),该数据涵盖了多种拍摄对象的图像数据;对于音频识别任务,通用数据可以是音频处理领域中一种广泛使用的语音数据集;对于文本识别任务,通用数据可以是文本处理领域中一种公开的文本数据集,如可以是维基百科全书中的词汇数据等。在获取目标任务所在领域的通用数据时,可以按照目标任务的具体任务类型和内容确定所要获取数据的数据类目等级,如一级类目、二级类目等等,然后按照对应的类目等级获取通用数据,举例而言,对于人脸识别这一任务而言,获取数据的一级类目为图像数据,二级类目为人脸数据,此时,可以按照实际需求获取对应类目的通用数据,例如,可以获取一级类目数据,如ImageNet数据集,或者也可以只获取二级类目数据,如任意一种人脸数据集。需要说明的是,为了提高网络模型提取通用特征的能力,通用数据可以设置为数据内容丰富、数据量较大,能够充分体现目标任务所在领域的数据特征的数据。
训练数据是数据挖掘过程中用于训练网络模型的数据,可以提高网络模型处理目标任务的能力。与通用数据相比,训练数据可以是过去一段时间内收集的与目标任务相关的数据,例如,在视频类应用程序中,训练数据可以是过去一段时间收集的用户上传的视频数据。训练数据可以包括无标签数据和有标签数据,其中,无标签数据是指训练数据中不带有标签的数据,有标签数据是指训练数据中带有标签的数据,该标签可以由人工预先标注生成。
本示例性实施方式中,可以通过特定的数据库或数据平台获取预先收集的通用数据和训练数据,或者也可以通过爬虫等技术收集和整理特定应用平台的数据,以生成通用数据和训练数据。
从数据层面来看,通过上述步骤S210,可以提高数据的丰富度,减少网络在训练过程中对有标签数据的数据量的需求,在目标任务的训练数据较少时,也可以基于通用数据和训练数据的配合训练提升网络的网络性能。
步骤S220中,将通用数据分别输入第一网络和第二网络进行无监督训练。
本示例性实施方式中,可以采用知识蒸馏的方法训练网络模型,即通过引入复杂模型得到的“软目标”作为目标,用转化后的训练集训练简单模型,其中,软目标为由复杂模型对输入样本进行处理,输出的概率向量。经由这种方式,可以将一个网络的知识迁移到另一个网络。由此,第一网络可以是知识迁移过程中用以为第二网络提供更加准确的监督信息的高性能网络;相比第一网络,第二网络可以是计算速度快,但性能较差的网络,其网络结构通常比第一网络简单,具有更大的运算吞吐量和更少的模型参数。第一网络和第二网络可以采用相同的算法模型,也可以采用不同的算法模型。在确定第一网络和第二网络时,可以在预先设置的网络模型的集合中选取一个与第二网络实现的功能相同且性能优良的网络作为第一网络。
将通用数据分别输入第一网络和第二网络,可以通过第一网络和第二网络分别对通用数据进行无监督训练,使得第一网络和第二网络能够学习到广泛的特征描述。
在一种可选的实施方式中,在将通用数据分别输入第一网络和第二网络后,可以分别通过第一网络和第二网络对通用数据进行无监督训练,得到第一网络的第一网络原始参数和第二网络的第二网络原始参数。例如,可以在确定第一网络和第二网络的网络结构后,基于自监督SwAV算法(一种无监督对比学习方法),通过第一网络对通用数据进行无监督训练,当第一网络在通用数据上能够取得较好的划分能力时,确定此时第一网络的网络参数,作为第一网络的第一网络原始参数,记为WT1;相应的,可以基于SwAV算法,通过第二网络对通用数据进行无监督训练,当第二网络也能在通用数据上取得较好的划分能力时,确定第二网络的网络参数,即得到第二网络原始参数,记为WS1。
通过步骤S220,可以对通用数据进行无监督训练,使网络模型获取到通用数据域的数据分布信息,获得较好的通用特征提取能力,同时具备较好的泛化性能。
步骤S230中,分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络。
在对通用数据进行无监督训练后,可以采用无监督训练后的第一网络对目标任务的无标签数据进行第二次无监督训练,使得通过上述无监督训练后的第一网络将无标签数据划分为多个类别,在第一网络的划分误差达到最小或者训练次数达到预设次数时,根据此时得到的第一网络及其网络参数生成第一训练网络。相对应的,可以按照相同的处理方法,采用无监督训练后的第二网络对无标签数据进行无监督训练,可以得到第二网络对应的第二训练网络。
在一种可选的实施方式中,参考图3所示,可以通过以下步骤S310~S320生成第一训练网络和第二训练网络:
步骤S310中,以第一网络原始参数为初始参数,采用第一网络对无标签数据进行无监督训练,确定第一网络的第一网络更新参数,以生成第一训练网络。
步骤S320中,以第二网络原始参数为初始参数,采用第二网络对无标签数据进行无监督训练,确定第二网络的第二网络更新参数,以生成第二训练网络。
具体而言,为了不断优化第一网络和第二网络对于目标任务的处理能力,可以将第一网络的第一网络原始参数,即WT1作为首次迭代时的网络参数,采用第一网络对无标签数据进行无监督训练,不断更新第一网络的网络参数,直至训练完成时,将得到的网络参数确定为第一网络的第一网络更新参数,记为WT2,并生成第一训练网络;相对应的,可以将第二网络的第二网络原始参数,即WS1作为首次迭代时的网络参数,采用第二网络对无标签数据进行无监督训练,以更新第二网络的网络参数,得到第二网络的第二网络更新参数,记为WS2,并生成第二训练网络。
通过以上步骤S230,可以通过目标任务的无标签数据,进一步训练得到第一网络和第二网络的更新后的网络参数。在训练过程中,第一网络和第二网络可以以第一阶段无监督训练后得到的学习能力为基础,获取得到目标任务数据域的数据分布信息,实现学习能力从目标任务所在领域到目标任务本身的迁移。同时,由于在实际应用中,通用数据与目标任务的训练数据往往存在较大的数据差异,步骤S230中的无监督训练过程可以作为通用数据和目标任务训练数据之间的过渡,使第一网络和第二网络在具备学习通用特征的基础上,进一步提升对无标签数据的学习能力,避免数据域之间的差异对模型训练产生的影响,提升第一网络和第二网络的特征表达能力和分析能力。
步骤S240中,根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据。
其中,伪标签数据也可以称为软伪标签数据,表示伪标签数据中的标签数据时根据已有标签预测出来的,也就是说,伪标签数据中的标签数据并非真实的标签数据,而是基于已有标签近似得到的标签数据。
为了提升第一训练网络的网络性能,可以通过目标任务的有标签数据对第一训练网络进行监督训练,使其根据有标签数据中数据与标签之间的映射关系预测得到无标签数据的标签,从而生成无标签数据的伪标签数据。
具体的,在一种可选的实施方式中,在根据有标签数据中的标签数据对第一训练网络进行监督训练时,可以将有标签数据输入至第一训练网络,以更新第一网络的第一网络更新参数,得到第一网络的第一网络训练参数。例如,可以以第一网络更新参数WT2为初始参数,采用第一训练网络,即第一网络对有标签数据进行监督训练,更新WT2,得到第一网络的第一网络训练参数,记为WT3。
由此,在生成无标签数据的伪标签数据时,可以通过以第一网络训练参数WT3为初始参数,采用第一训练网络对无标签数据进行预测,生成伪标签数据。例如,可以通过计算有标签数据与伪标签数据之间的相似度,将具有最高相似度的有标签数据的标签确定为对应伪标签数据的标签,生成无标签数据的伪标签数据。
通过步骤S240,可以通过第一训练网络预测得到无标签数据的标签,生成伪标签数据,而伪标签数据可以作为后续生成目标网络模型的输入数据。
步骤S250中,基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。
本示例性实施方式中,目标网络模型可以用于对输入的数据进行分类,并输出分类结果,该分类结果可以包括目标任务中目标对象所属的类别,或是目标对象可能所属的类别,以及目标对象所属某一类别的概率和/或得分等。
为了使第二训练网络具有与第一网络相当,甚至更高的网络性能,在一种可选的实施方式中,可以通过以下方法生成目标网络模型:
以第二网络更新参数为初始参数,对伪标签数据进行监督训练,以生成目标网络模型。例如,可以以WS2为首次训练时的网络参数,采用第二训练网络对伪标签数据进行监督训练,同时不断更新第二训练网络的网络参数,直至第二训练网络在伪标签数据上的误差最小,或者迭代次数达到一定次数为止,将最终得到的第二训练网络作为目标网络模型,此时,第二训练网络的网络参数可以记为WS3。
进一步的,由于伪标签数据的标签并非是真实标签,为了提高目标网络模型的准确率,在一种可选的实施方式中,在生成所述目标任务的目标网络模型时,还可以通过以下方法对目标网络模型进行微调:
根据有标签数据中的标签数据对目标网络模型进行监督训练,并调整目标网络模型的网络参数。例如,可以目标网络模型的网络参数,即WS3作为首次训练时的网络参数,将有标签数据输入至目标网络模型进行监督训练,然后计算目标网络模型的误差,使得目标网络模型的准确率等性能指标达到一定程度,确定更新后的网络参数WS4,以得到最终的目标网络模型。
进一步的,由于伪标签数据的标签并不是完全准确的,为了避免因标签不准确而对网络模型的训练性能造成影响,在一种可选的实施方式中,在对伪标签数据进行监督训练时,可以通过以下方法对伪标签数据进行处理:
确定伪标签数据的标签置信度,并根据标签置信度对伪标签数据进行数据选择。其中,标签置信度也可以称为标签可靠度,或者标签置信水平、置信系数等,可以用于衡量标签的真实值有一定概率落在测量结果周围的程度。
本示例性实施方式中,可以通过置信学习等方法确定伪标签数据中各标签的标签置信度,从而筛选出标签置信度较高的数据。通过这一方式,可以从伪标签数据中筛选出标签可靠性较高的数据,提高得到的目标网络模型的网络性能。
在一种可选的实施方式中,参考图4所示,上述根据标签置信度对伪标签数据进行数据选择的方法可以包括以下步骤S410~S420:
步骤S410中,根据伪标签数据中各标签的标签置信度,从伪标签数据中筛选出标签置信度大于预设阈值的候选标签数据。其中,预设阈值可以根据目标任务的实际需求进行设置,例如,可以设置为0.8、0.9等,或者也可以设置为一定的阈值区间。例如,可以首先确定伪标签数据中各标签的标签置信度,然后为每个标签所属的类别设置相应的概率阈值区间,如高概率阈值区间Ti_high=[Ti1,Ti2]和次高概率阈值区间Ti_mid=[Ti3,Ti4],其中,i表示标签类别的顺序,且Ti1>Ti2>Ti3>Ti4,从而将标签的概率值落在概率阈值区间T内的数据筛选出来,得到伪标签数据的候选标签数据。
步骤S420中,在候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据数据量分布对候选标签数据进行重采样。
为了使候选标签数据中各类别数据的数据量分布达到均衡,可以预先确定候选标签数据中各标签对应的候选标签数据的数据量分布,然后依据该数据量分布对候选标签数据进行重采样,具体的,可以采用上采样或下采样的方法处理候选标签数据,以上采样为例,可以通过计算任意两个样本数据之间的插值,来生成新的样本数据,从而增加数据量较少的标签类别所属的候选标签数据量,使得整个候选标签数据中各标签类别对应数据的数据量达到均衡。
在一种可选的实施方式中,也可以通过数据合成的方法均衡候选标签数据中各标签类别对应数据的数据量,例如,可以通过SMOTE(Synthetic Minority OversamplingTechnique,合成少数类过采样算法)利用小众样本在特征空间的相似性,从最接近的样本中选择一个样本点来生成新的小众样本。
通过根据标签置信度对伪标签数据进行数据选择,可以从伪标签数据中筛选出置信度较高的数据,保证第二训练网络输入数据的准确性,同时维持输入数据中各个标签类别对应数据的数据量的平衡,提高生成目标网络模型的训练性能。
进一步的,为了提高目标网络模型的表达能力,在一种可选的实施方式中,参考图5所示,在对伪标签数据进行监督训练时,还可以通过以下步骤S510~S520对伪标签数据进行增强处理:
步骤S510中,按照伪标签数据中各标签的标签置信度,将伪标签数据划分为多个类别。以两类为例,可以将伪标签数据划分为标签置信度较高的数据集Uhigh,以及标签置信度次高的数据集Umid。
步骤S520中,确定多个类别中各类别对应的伪标签数据的数据增强策略,并根据数据增强策略对各类别对应的伪标签数据进行增强处理。例如,对于上述标签置信度较高的数据集Uhigh和标签置信度次高的数据集Umid,可以为Uhigh和Umid分别分配相应的数据增强策略,在训练过程中按照各自对应的数据增强策略对对应的数据集进行增强处理。
具体而言,在一种可选的实施方式中,参考图6所示,步骤S520可以包括以下步骤S610~S630:
步骤S610中,按照伪标签数据的数据属性对各类别对应的伪标签数据进行统计,以确定各类别对应的伪标签数据的关键数据的数据属性分布。
其中,伪标签数据的数据属性可以根据伪标签数据的数据内容包括不同的属性内容。例如,对于图像数据,伪标签数据的数据属性可以包括图像的尺寸、亮度、像素数值、模糊程度等等;对于文本数据,伪标签数据的数据属性可以包括文本数据的语言类型、文本数据中各字或词出现的频次、出现的位置,以及相连接的字或词出现的频次、位置等。伪标签数据的关键数据可以是影响目标任务处理结果的数据属性数据,例如,对于图像数据而言,伪标签数据的关键数据可以包括图像中包含目标对象的位置区域、目标对象的颜色分布等;对于文本数据,伪标签数据的关键数据可以包括文本数据的语言类型、文本数据中关键字或词出现的频次、位置等。
步骤S620中,根据关键数据的数据属性分布确定各类别对应的伪标签数据的变换规则。其中,变换规则是指伪标签数据中各类别数据的数据处理规则。
步骤S630中,按照变换规则对各类别对应的伪标签数据进行变换处理。
在确定伪标签数据中关键数据的数据属性分布之后,可以按照该数据属性分布确定对应的伪标签数据的变换规则,从而可以按照对应的变换规则对伪标签数据进行变换处理,来增强伪标签数据中影响目标任务结果的数据的影响力,提高目标网络模型对于目标任务的泛化性能。举例而言,假设上述数据集Uhigh和Umid均为图像数据,可以分别统计Uhigh和Umid中每张图像的图像尺寸、像素数值、图像模糊程度等数据属性的分布情况,来确定Uhigh和Umid中关键数据,如关键区域位置分布、颜色分布等,确定对上述数据集Uhigh和Umid的变换规则,然后按照对应的变换规则对上述数据集Uhigh和Umid分别进行变换处理。
在一种可选的实施方式中,上述通用数据和训练数据均可以是图像,此时,数据增强策略可以包括对图像进行裁剪、对图像进行旋转、调整图像的亮度和/或对比度、在图像中添加孤立像素点,以进行加噪处理中的任意一种或多种。例如,在人脸识别的任务内容中,可以将包含人脸图像的图像进行裁剪,以去除图像中人脸区域以外的图像,并对人脸区域内的图像进行亮度或对比度的调整,来增加人脸区域的区分度,或者也可以对图像进行旋转,以使其沿某一固定方向显示,再或者也可以对图像进行加噪处理,例如,可以在图像中添加高斯噪声或椒盐噪声等,本示例性实施方式对此不做具体限定。
在一种可选的实施方式中,上述通用数据和训练数据也可以是文本数据,此时,数据增强策略可以包括两个方面,即句子层面增强和词层面增强。其中,句子层面增强是指在保持句子语义不变的情况下,变换文本的表达形式,例如回译、文本复述、变换句子位置等;词层面增强是指按照某种策略对文本局部进行调整,例如同义词替换、随机删除、随机交换、随机插入等。
在一种可选的实施方式中,上述通用数据和训练数据也可以是语音数据。由此,音频数据的数据增强策略可以包括音频时间延长、音调转换、音高偏移,以及添加噪声等中的任意一种或多种。其中,音频时间延长也就是放慢或加快音频样本采样的频率;音调转换可以是提高或降低音频样本的音高(同时保持持续时间不变);音高偏移是指将音频样本的音调高度偏移一定单位,如半音等。
通过对伪标签数据进行数据增强处理,可以为第二训练网络的训练过程增加适当学习难度,提升生成的目标网络模型的特征表达和分析能力。
实际上,从训练阶段来看,参考图7所示,本示例性实施方式中网络模型的训练方法可以划分为三个阶段,即通用数据上的无监督训练、目标任务无标签数据上的无监督训练和目标任务无标签数据和有标签数据上的半监督训练。下面对这三个阶段分别进行说明:
第一阶段:在通用数据上的无监督训练。
具体的,可以按照图2中步骤S220所示,将通用数据分别输入至第一网络和第二网络进行无监督训练。由于通用数据是目标任务所在领域的公用数据,所以通过将通用数据分别输入第一网络和第二网络进行无监督训练,可以使第一网络和第二网络具备提取目标任务所在领域通用特征的能力。
第二阶段:在目标任务无标签数据上的无监督训练。
为了提高第一网络和第二网络处理目标任务的能力,可以按照步骤S230所示,将目标任务的无标签数据分别输入至第一次无监督训练后得到的第一网络和第二网络中,以对无监督训练后的第一网络和无监督训练后的第二网络进行优化,并生成第一训练网络和第二训练网络。
在这一阶段,可以将第一网络和第二网络在目标任务所在领域的学习能力迁移到目标任务本身,使得第一网络和第二网络具备较好的处理目标任务的能力。
第三阶段:在目标任务无标签数据和有标签数据上的半监督训练。
半监督训练是一种自训练模式,指的是通过已有标签的数据训练网络,然后通过训练的网络模型给无标签数据打上伪标签的方法。本示例性实施方式中,第三阶段可以通过步骤S240完成,即根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据。
最后,可以基于第二阶段生成的第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。通过这一过程,可以生成针对目标任务的目标网络模型,使目标网络模型具备较高的网络性能。
本示例性实施方式还提供了一种数据处理方法,参考图8所示,该方法可以包括步骤S810~S820:
步骤S810中,获取待处理数据。
根据目标任务及其所在领域的类型,待处理数据可以是与目标任务相关的需要进行分类处理的数据。
步骤S820中,采用训练后的目标网络模型,对待处理数据进行分类处理,得到待处理数据的分类结果。
其中,训练后的目标网络模型为采用上述网络模型的训练方法获得的目标网络模型。
通过上述网络模型的训练方法生成的目标网络模型对输入的待处理数据进行处理,可以输出待处理数据所属的类别。以下结合应用场景及目标任务的具体内容,示出了几种目标网络模型的使用方法:
(1)图像分类
在图像分类任务中,可以通过获取通用数据,如ImageNet数据集和图像分类任务的训练数据,执行本示例性实施方式中的网络模型的训练方法,生成目标网络模型。
将待检测图像作为待处理数据,并输入目标网络模型,可以由目标网络模型输出待检测图像中包括的目标对象所属的分类结果。
由于在训练过程中,目标网络模型集成了第一网络的学习能力,具有较高的特征分析能力,更加突出待检测图像中关键区域的影响,因而图像分类的分类结果更加准确、可靠。
(2)语音识别
在语音识别任务中,可以首先获取语音识别所在领域的通用数据和语音识别任务的训练数据,该训练数据可以包括有标注的语音数据和无标注的语音数据,其中,有标注的语音数据可以由语音数据和其对应的文本数据构成,然后按照本示例性实施方式中得网络模型的训练方法生成目标网络模型。
在生成目标网络模型后,可以将待识别语音数据作为待处理数据,输入至目标网络模型中,输出待识别语音数据的文本信息。
(3)文本识别
在文本识别任务中,待处理文本可以是语言文本或图像文本等。通过获取通用文本数据,如通用预料数据集和文本识别任务的训练数据,即有标签数据和无标签数据,例如,有标签数据可以是与指定文本相似的文本数据,无标签数据可以是维基百科全书中的词汇数据等。采用上述网络模型的训练方法可以生成针对文本识别任务的目标网络模型。将待处理文本作为待处理数据输入至目标网络模型,可以得到文本识别模型对待处理文本的识别结果,
需要说明的是,以上对目标网络模型的使用方法的说明仅为示例性说明,本示例性实施方式中的目标网络模型可以适用于任意一种执行分类任务或与分类任务相关的应用场景中。
示例性装置
本发明示例性实施方式还提供一种网络模型的训练装置。参考图9所示,该网络模型的训练装置900可以包括:
获取模块910,用于获取目标任务所在领域的通用数据和目标任务的训练数据,其中,训练数据包括无标签数据和有标签数据;
第一训练模块920,用于将通用数据分别输入第一网络和第二网络进行无监督训练;
第二训练模块930,用于分别采用无监督训练后的第一网络和无监督训练后的第二网络对无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;
第三训练模块940,用于根据有标签数据中的标签数据对第一训练网络进行监督训练,并通过监督训练后的第一训练网络对无标签数据进行预测,生成无标签数据的伪标签数据;
生成模块950,用于基于第二训练网络,对伪标签数据进行监督训练,生成目标任务的目标网络模型。
在一种可选的实施方式中,第一训练模块920,被配置为:
分别通过第一网络和第二网络对通用数据进行无监督训练,得到第一网络的第一网络原始参数和第二网络的第二网络原始参数。
在一种可选的实施方式中,第二训练模块930,被配置为:
以第一网络原始参数为初始参数,采用第一网络对无标签数据进行无监督训练,确定第一网络的第一网络更新参数,以生成第一训练网络;
以第二网络原始参数为初始参数,采用第二网络对无标签数据进行无监督训练,确定第二网络的第二网络更新参数,以生成第二训练网络。
在一种可选的实施方式中,在根据有标签数据中的标签数据对第一训练网络进行监督训练时,第三训练模块940,被配置为:
将有标签数据输入至第一训练网络,以更新第一网络的第一网络更新参数,得到第一网络的第一网络训练参数;
以第一网络训练参数为初始参数,采用第一网络对无标签数据进行预测,生成伪标签数据。
在一种可选的实施方式中,生成模块950,被配置为:
以第二网络更新参数为初始参数,对伪标签数据进行监督训练,以生成目标网络模型。
在一种可选的实施方式中,在生成目标任务的目标网络模型时,生成模块950,被配置为:
根据有标签数据中的标签数据对目标网络模型进行监督训练,并调整目标网络模型的网络参数。
在一种可选的实施方式中,在对伪标签数据进行监督训练时,生成模块950,被配置为:
确定伪标签数据的标签置信度,并根据标签置信度对伪标签数据进行数据选择。
在一种可选的实施方式中,生成模块950,被配置为:
根据伪标签数据中各标签的标签置信度,从伪标签数据中筛选出标签置信度大于预设阈值的候选标签数据;
在候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据数据量分布对候选标签数据进行重采样。
在一种可选的实施方式中,在对伪标签数据进行监督训练时,生成模块950,被配置为:
按照伪标签数据中各标签的标签置信度,将伪标签数据划分为多个类别;
确定多个类别中各类别对应的伪标签数据的数据增强策略,并根据数据增强策略对各类别对应的伪标签数据进行增强处理。
在一种可选的实施方式中,生成模块950,被配置为:
按照伪标签数据的数据属性对各类别对应的伪标签数据进行统计,以确定各类别对应的伪标签数据的关键数据的数据属性分布;
根据关键数据的数据属性分布确定各类别对应的伪标签数据的变换规则;
按照变换规则对各类别对应的伪标签数据进行变换处理。
在一种可选的实施方式中,通用数据和训练数据包括图像,数据增强策略包括以下任意一种或多种:
对图像进行裁剪;
对图像进行旋转;
调整图像的亮度和/或对比度;
在图像中添加孤立像素点,以进行加噪处理。
本发明示例性实施方式还提供一种数据处理装置。参考图10所示,该数据处理装置1000可以包括:
获取模块1010,用于获取待处理数据;
处理模块1020,用于采用训练后的目标网络模型,对所述待处理数据进行分类处理,得到所述待处理数据的分类结果;其中,训练后的目标网络模型为采用上述网络模型的训练方法获得的目标网络模型。
此外,本发明实施方式的其他具体细节在上述方法的发明实施方式中已经详细说明,在此不再赘述。
示例性存储介质
下面对本发明示例性实施方式的存储介质进行说明。
本示例性实施方式中,可以通过程序产品实现上述方法,如可以采用便携式紧凑盘只读存储器(CD-ROM)并包括程序代码,并可以在设备,例如个人电脑上运行。然而,本发明的程序产品不限于此,在本文件中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
该程序产品可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以为但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了可读程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。可读信号介质还可以是可读存储介质以外的任何可读介质,该可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。
可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于无线、有线、光缆、RE等等,或者上述的任意合适的组合。
可以以一种或多种程序设计语言的任意组合来编写用于执行本发明操作的程序代码,程序设计语言包括面向对象的程序设计语言-诸如Java、C++等,还包括常规的过程式程序设计语言-诸如"C"语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。
示例性电子设备
参考图11对本发明示例性实施方式的电子设备进行说明。该电子设备可以是上述服务器或终端设备。
图11显示的电子设备1100仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
如图11所示,电子设备1100以通用计算设备的形式表现。电子设备1100的组件可以包括但不限于:至少一个处理单元1110、至少一个存储单元1120、连接不同系统组件(包括存储单元1120和处理单元1110)的总线1130、显示单元1140。
其中,存储单元存储有程序代码,程序代码可以被处理单元1110执行,使得处理单元1110执行本说明书上述"示例性方法"部分中描述的根据本发明各种示例性实施方式的步骤。例如,处理单元1110可以执行如图2至图8所示的方法步骤等。
存储单元1120可以包括易失性存储单元,例如随机存取存储单元(RAM)1121和/或高速缓存存储单元1122,还可以进一步包括只读存储单元(ROM)1123。
存储单元1120还可以包括具有一组(至少一个)程序模块1125的程序/实用工具1124,这样的程序模块1125包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
总线1130可以包括数据总线、地址总线和控制总线。
电子设备1100也可以与一个或多个外部设备1200(例如键盘、指向设备、蓝牙设备等)通信,这种通信可以通过输入/输出(I/O)接口1150进行。电子设备1100还包括显示单元1140,其连接到输入/输出(I/O)接口1150,用于进行显示。并且,电子设备1100还可以通过网络适配器1160与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器1160通过总线1130与电子设备1100的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备1100使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。
应当注意,尽管在上文详细描述中提及了装置的若干模块或子模块,但是这种划分仅仅是示例性的并非强制性的。实际上,根据本发明的实施方式,上文描述的两个或更多单元/模块的特征和功能可以在一个单元/模块中具体化。反之,上文描述的一个单元/模块的特征和功能可以进一步划分为由多个单元/模块来具体化。
此外,尽管在附图中以特定顺序描述了本发明方法的操作,但是,这并非要求或者暗示必须按照该特定顺序来执行这些操作,或是必须执行全部所示的操作才能实现期望的结果。附加地或备选地,可以省略某些步骤,将多个步骤合并为一个步骤执行,和/或将一个步骤分解为多个步骤执行。
虽然已经参考若干具体实施方式描述了本发明的精神和原理,但是应该理解,本发明并不限于所公开的具体实施方式,对各方面的划分也不意味着这些方面中的特征不能组合以进行受益,这种划分仅是为了表述的方便。本发明旨在涵盖所附权利要求的精神和范围内所包括的各种修改和等同布置。
Claims (26)
1.一种网络模型的训练方法,其特征在于,所述方法包括:
获取目标任务所在领域的通用数据和所述目标任务的训练数据,其中,所述训练数据包括无标签数据和有标签数据;
将所述通用数据分别输入第一网络和第二网络进行无监督训练;
分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;
根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练,并通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据;
基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型。
2.根据权利要求1所述的方法,其特征在于,所述将所述通用数据分别输入第一网络和第二网络进行无监督训练,包括:
分别通过所述第一网络和所述第二网络对所述通用数据进行无监督训练,得到所述第一网络的第一网络原始参数和所述第二网络的第二网络原始参数。
3.根据权利要求2所述的方法,其特征在于,所述分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络,包括:
以所述第一网络原始参数为初始参数,采用所述第一网络对所述无标签数据进行无监督训练,确定所述第一网络的第一网络更新参数,以生成所述第一训练网络;
以所述第二网络原始参数为初始参数,采用所述第二网络对所述无标签数据进行无监督训练,确定所述第二网络的第二网络更新参数,以生成所述第二训练网络。
4.根据权利要求3所述的方法,其特征在于,在根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练时,所述方法包括:
将所述有标签数据输入至所述第一训练网络,以更新所述第一网络的第一网络更新参数,得到所述第一网络的第一网络训练参数;
所述通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据,包括:
以所述第一网络训练参数为初始参数,采用所述第一网络对所述无标签数据进行预测,生成所述伪标签数据。
5.根据权利要求3所述的方法,其特征在于,所述基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型,包括:
以所述第二网络更新参数为初始参数,对所述伪标签数据进行监督训练,以生成所述目标网络模型。
6.根据权利要求1所述的方法,其特征在于,在生成所述目标任务的目标网络模型时,所述方法还包括:
根据所述有标签数据中的标签数据对所述目标网络模型进行监督训练,并调整所述目标网络模型的网络参数。
7.根据权利要求1所述的方法,其特征在于,在对所述伪标签数据进行监督训练时,所述方法还包括:
确定所述伪标签数据的标签置信度,并根据所述标签置信度对所述伪标签数据进行数据选择。
8.根据权利要求7所述的方法,其特征在于,所述根据所述标签置信度对所述伪标签数据进行数据选择,包括:
根据所述伪标签数据中各标签的标签置信度,从所述伪标签数据中筛选出所述标签置信度大于预设阈值的候选标签数据;
在所述候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据所述数据量分布对所述候选标签数据进行重采样。
9.根据权利要求1所述的方法,其特征在于,在对所述伪标签数据进行监督训练时,所述方法还包括:
按照所述伪标签数据中各标签的标签置信度,将所述伪标签数据划分为多个类别;
确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理。
10.根据权利要求9所述的方法,其特征在于,所述确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理,包括:
按照所述伪标签数据的数据属性对所述各类别对应的伪标签数据进行统计,以确定所述各类别对应的伪标签数据的关键数据的数据属性分布;
根据所述关键数据的数据属性分布确定所述各类别对应的伪标签数据的变换规则;
按照所述变换规则对所述各类别对应的伪标签数据进行变换处理。
11.根据权利要求10所述的方法,其特征在于,所述通用数据和所述训练数据包括图像,所述数据增强策略包括以下任意一种或多种:
对所述图像进行裁剪;
对所述图像进行旋转;
调整所述图像的亮度和/或对比度;
在所述图像中添加孤立像素点,以进行加噪处理。
12.一种数据处理方法,其特征在于,所述方法包括:
获取待处理数据;
采用训练后的目标网络模型,对所述待处理数据进行分类处理,得到所述待处理数据的分类结果;
其中,所述训练后的目标网络模型为采用如权利要求1-11任意一项所述的网络模型的训练方法获得的目标网络模型。
13.一种网络模型的训练装置,其特征在于,所述装置包括:
获取模块,用于获取目标任务所在领域的通用数据和所述目标任务的训练数据,其中,所述训练数据包括无标签数据和有标签数据;
第一训练模块,用于将所述通用数据分别输入第一网络和第二网络进行无监督训练;
第二训练模块,用于分别采用无监督训练后的所述第一网络和无监督训练后的所述第二网络对所述无标签数据进行无监督训练,以生成第一训练网络和第二训练网络;
第三训练模块,用于根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练,并通过监督训练后的所述第一训练网络对所述无标签数据进行预测,生成所述无标签数据的伪标签数据;
生成模块,用于基于所述第二训练网络,对所述伪标签数据进行监督训练,生成所述目标任务的目标网络模型。
14.根据权利要求13所述的装置,其特征在于,所述第一训练模块,被配置为:
分别通过所述第一网络和所述第二网络对所述通用数据进行无监督训练,得到所述第一网络的第一网络原始参数和所述第二网络的第二网络原始参数。
15.根据权利要求14所述的装置,其特征在于,所述第二训练模块,被配置为:
以所述第一网络原始参数为初始参数,采用所述第一网络对所述无标签数据进行无监督训练,确定所述第一网络的第一网络更新参数,以生成所述第一训练网络;
以所述第二网络原始参数为初始参数,采用所述第二网络对所述无标签数据进行无监督训练,确定所述第二网络的第二网络更新参数,以生成所述第二训练网络。
16.根据权利要求15所述的装置,其特征在于,在根据所述有标签数据中的标签数据对所述第一训练网络进行监督训练时,所述第三训练模块,被配置为:
将所述有标签数据输入至所述第一训练网络,以更新所述第一网络的第一网络更新参数,得到所述第一网络的第一网络训练参数;
以所述第一网络训练参数为初始参数,采用所述第一网络对所述无标签数据进行预测,生成所述伪标签数据。
17.根据权利要求15所述的装置,其特征在于,所述生成模块,被配置为:
以所述第二网络更新参数为初始参数,对所述伪标签数据进行监督训练,以生成所述目标网络模型。
18.根据权利要求13所述的装置,其特征在于,在生成所述目标任务的目标网络模型时,所述生成模块,被配置为:
根据所述有标签数据中的标签数据对所述目标网络模型进行监督训练,并调整所述目标网络模型的网络参数。
19.根据权利要求13所述的装置,其特征在于,在对所述伪标签数据进行监督训练时,所述生成模块,被配置为:
确定所述伪标签数据的标签置信度,并根据所述标签置信度对所述伪标签数据进行数据选择。
20.根据权利要求19所述的装置,其特征在于,所述生成模块,被配置为:
根据所述伪标签数据中各标签的标签置信度,从所述伪标签数据中筛选出所述标签置信度大于预设阈值的候选标签数据;
在所述候选标签数据中,确定各标签对应的候选标签数据的数据量分布,并根据所述数据量分布对所述候选标签数据进行重采样。
21.根据权利要求13所述的装置,其特征在于,在对所述伪标签数据进行监督训练时,所述生成模块,被配置为:
按照所述伪标签数据中各标签的标签置信度,将所述伪标签数据划分为多个类别;
确定所述多个类别中各类别对应的伪标签数据的数据增强策略,并根据所述数据增强策略对所述各类别对应的伪标签数据进行增强处理。
22.根据权利要求21所述的装置,其特征在于,所述生成模块,被配置为:
按照所述伪标签数据的数据属性对所述各类别对应的伪标签数据进行统计,以确定所述各类别对应的伪标签数据的关键数据的数据属性分布;
根据所述关键数据的数据属性分布确定所述各类别对应的伪标签数据的变换规则;
按照所述变换规则对所述各类别对应的伪标签数据进行变换处理。
23.根据权利要求22所述的装置,其特征在于,所述通用数据和所述训练数据包括图像,所述数据增强策略包括以下任意一种或多种:
对所述图像进行裁剪;
对所述图像进行旋转;
调整所述图像的亮度和/或对比度;
在所述图像中添加孤立像素点,以进行加噪处理。
24.一种数据处理装置,其特征在于,所述装置包括:
获取模块,用于获取待处理数据;
处理模块,用于采用训练后的目标网络模型,对所述待处理数据进行分类处理,得到所述待处理数据的分类结果;
其中,所述训练后的目标网络模型为采用如权利要求1-11任意一项所述的网络模型的训练方法获得的目标网络模型。
25.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-12任一项所述的方法。
26.一种电子设备,其特征在于,包括:
处理器;以及
存储器,用于存储所述处理器的可执行指令;
其中,所述处理器配置为经由执行所述可执行指令来执行权利要求1-12任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110220979.3A CN112819099B (zh) | 2021-02-26 | 2021-02-26 | 网络模型的训练方法、数据处理方法、装置、介质及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110220979.3A CN112819099B (zh) | 2021-02-26 | 2021-02-26 | 网络模型的训练方法、数据处理方法、装置、介质及设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112819099A CN112819099A (zh) | 2021-05-18 |
CN112819099B true CN112819099B (zh) | 2023-12-22 |
Family
ID=75862307
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110220979.3A Active CN112819099B (zh) | 2021-02-26 | 2021-02-26 | 网络模型的训练方法、数据处理方法、装置、介质及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112819099B (zh) |
Families Citing this family (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113705629B (zh) * | 2021-08-09 | 2023-04-18 | 北京三快在线科技有限公司 | 一种训练样本生成方法、装置、存储介质及电子设备 |
CN114282721B (zh) * | 2021-12-22 | 2022-12-20 | 中科三清科技有限公司 | 污染物预报模型训练方法、装置、电子设备及存储介质 |
CN114821247B (zh) * | 2022-06-30 | 2022-11-01 | 杭州闪马智擎科技有限公司 | 一种模型的训练方法、装置、存储介质及电子装置 |
CN114973684B (zh) * | 2022-07-25 | 2022-10-14 | 深圳联和智慧科技有限公司 | 一种建筑工地定点监控方法及系统 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP3252671A1 (en) * | 2016-05-31 | 2017-12-06 | Siemens Healthcare GmbH | Method of training a deep neural network |
CN109977918A (zh) * | 2019-04-09 | 2019-07-05 | 华南理工大学 | 一种基于无监督域适应的目标检测定位优化方法 |
WO2019233297A1 (zh) * | 2018-06-08 | 2019-12-12 | Oppo广东移动通信有限公司 | 数据集的构建方法、移动终端、可读存储介质 |
CN111062495A (zh) * | 2019-11-28 | 2020-04-24 | 深圳市华尊科技股份有限公司 | 机器学习方法及相关装置 |
CN111476284A (zh) * | 2020-04-01 | 2020-07-31 | 网易(杭州)网络有限公司 | 图像识别模型训练及图像识别方法、装置、电子设备 |
CN111898696A (zh) * | 2020-08-10 | 2020-11-06 | 腾讯云计算(长沙)有限责任公司 | 伪标签及标签预测模型的生成方法、装置、介质及设备 |
CN112101020A (zh) * | 2020-08-27 | 2020-12-18 | 北京百度网讯科技有限公司 | 训练关键短语标识模型的方法、装置、设备和存储介质 |
CN112115995A (zh) * | 2020-09-11 | 2020-12-22 | 北京邮电大学 | 一种基于半监督学习的图像多标签分类方法 |
CN112232416A (zh) * | 2020-10-16 | 2021-01-15 | 浙江大学 | 一种基于伪标签加权的半监督学习方法 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20150032589A1 (en) * | 2014-08-08 | 2015-01-29 | Brighterion, Inc. | Artificial intelligence fraud management solution |
US20190294973A1 (en) * | 2018-03-23 | 2019-09-26 | Google Llc | Conversational turn analysis neural networks |
EP3874417A1 (en) * | 2018-10-29 | 2021-09-08 | HRL Laboratories, LLC | Systems and methods for few-shot transfer learning |
-
2021
- 2021-02-26 CN CN202110220979.3A patent/CN112819099B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP3252671A1 (en) * | 2016-05-31 | 2017-12-06 | Siemens Healthcare GmbH | Method of training a deep neural network |
WO2019233297A1 (zh) * | 2018-06-08 | 2019-12-12 | Oppo广东移动通信有限公司 | 数据集的构建方法、移动终端、可读存储介质 |
CN109977918A (zh) * | 2019-04-09 | 2019-07-05 | 华南理工大学 | 一种基于无监督域适应的目标检测定位优化方法 |
CN111062495A (zh) * | 2019-11-28 | 2020-04-24 | 深圳市华尊科技股份有限公司 | 机器学习方法及相关装置 |
CN111476284A (zh) * | 2020-04-01 | 2020-07-31 | 网易(杭州)网络有限公司 | 图像识别模型训练及图像识别方法、装置、电子设备 |
CN111898696A (zh) * | 2020-08-10 | 2020-11-06 | 腾讯云计算(长沙)有限责任公司 | 伪标签及标签预测模型的生成方法、装置、介质及设备 |
CN112101020A (zh) * | 2020-08-27 | 2020-12-18 | 北京百度网讯科技有限公司 | 训练关键短语标识模型的方法、装置、设备和存储介质 |
CN112115995A (zh) * | 2020-09-11 | 2020-12-22 | 北京邮电大学 | 一种基于半监督学习的图像多标签分类方法 |
CN112232416A (zh) * | 2020-10-16 | 2021-01-15 | 浙江大学 | 一种基于伪标签加权的半监督学习方法 |
Also Published As
Publication number | Publication date |
---|---|
CN112819099A (zh) | 2021-05-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112819099B (zh) | 网络模型的训练方法、数据处理方法、装置、介质及设备 | |
US11120801B2 (en) | Generating dialogue responses utilizing an independent context-dependent additive recurrent neural network | |
US20230237088A1 (en) | Automatically detecting user-requested objects in digital images | |
US11210470B2 (en) | Automatic text segmentation based on relevant context | |
CN107481717B (zh) | 一种声学模型训练方法及系统 | |
AU2019239454B2 (en) | Method and system for retrieving video temporal segments | |
US20220130499A1 (en) | Medical visual question answering | |
US11682415B2 (en) | Automatic video tagging | |
US20220292805A1 (en) | Image processing method and apparatus, and device, storage medium, and image segmentation method | |
US20240114158A1 (en) | Hierarchical Video Encoders | |
CN114443899A (zh) | 视频分类方法、装置、设备及介质 | |
CN114661951A (zh) | 一种视频处理方法、装置、计算机设备以及存储介质 | |
US20220284343A1 (en) | Machine teaching complex concepts assisted by computer vision and knowledge reasoning | |
US10909473B2 (en) | Method to determine columns that contain location data in a data set | |
CN111460224B (zh) | 评论数据的质量标注方法、装置、设备及存储介质 | |
US20210166016A1 (en) | Product baseline information extraction | |
CN115129922A (zh) | 搜索词生成方法、模型训练方法、介质、装置和设备 | |
CN117011737A (zh) | 一种视频分类方法、装置、电子设备和存储介质 | |
US11989626B2 (en) | Generating performance predictions with uncertainty intervals | |
US11710098B2 (en) | Process flow diagram prediction utilizing a process flow diagram embedding | |
US20220180123A1 (en) | System, method and apparatus for training a machine learning model | |
US20210142188A1 (en) | Detecting scenes in instructional video | |
Cheng et al. | Video reasoning for conflict events through feature extraction | |
US20220382806A1 (en) | Music analysis and recommendation engine | |
CN116631433A (zh) | 音频分离方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
TA01 | Transfer of patent application right | ||
TA01 | Transfer of patent application right |
Effective date of registration: 20210928 Address after: 310000 Room 408, building 3, No. 399, Wangshang Road, Changhe street, Binjiang District, Hangzhou City, Zhejiang Province Applicant after: Hangzhou Netease Zhiqi Technology Co.,Ltd. Address before: 310052 Building No. 599, Changhe Street Network Business Road, Binjiang District, Hangzhou City, Zhejiang Province, 4, 7 stories Applicant before: NETEASE (HANGZHOU) NETWORK Co.,Ltd. |
|
GR01 | Patent grant | ||
GR01 | Patent grant |