CN109902798A - 深度神经网络的训练方法和装置 - Google Patents
深度神经网络的训练方法和装置 Download PDFInfo
- Publication number
- CN109902798A CN109902798A CN201810554459.4A CN201810554459A CN109902798A CN 109902798 A CN109902798 A CN 109902798A CN 201810554459 A CN201810554459 A CN 201810554459A CN 109902798 A CN109902798 A CN 109902798A
- Authority
- CN
- China
- Prior art keywords
- data
- sample data
- domain
- target domain
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 326
- 238000000034 method Methods 0.000 title claims description 117
- 238000013528 artificial neural network Methods 0.000 title claims description 22
- 230000006870 function Effects 0.000 claims abstract description 108
- 230000003044 adaptive effect Effects 0.000 claims abstract description 33
- 230000002708 enhancing effect Effects 0.000 claims abstract description 21
- 230000000875 corresponding effect Effects 0.000 claims description 249
- 238000000605 extraction Methods 0.000 claims description 42
- 238000013527 convolutional neural network Methods 0.000 claims description 27
- 238000003860 storage Methods 0.000 claims description 17
- 238000009826 distribution Methods 0.000 claims description 16
- 238000004590 computer program Methods 0.000 claims description 7
- 230000002596 correlated effect Effects 0.000 claims description 6
- 235000013399 edible fruits Nutrition 0.000 claims description 4
- 230000001276 controlling effect Effects 0.000 claims 1
- 230000008569 process Effects 0.000 description 25
- 239000000284 extract Substances 0.000 description 20
- 238000013473 artificial intelligence Methods 0.000 description 16
- 238000010586 diagram Methods 0.000 description 16
- 239000011159 matrix material Substances 0.000 description 16
- 238000012360 testing method Methods 0.000 description 14
- 238000001514 detection method Methods 0.000 description 13
- 238000013145 classification model Methods 0.000 description 12
- 238000012545 processing Methods 0.000 description 10
- 238000004891 communication Methods 0.000 description 9
- 238000004422 calculation algorithm Methods 0.000 description 7
- 238000004364 calculation method Methods 0.000 description 7
- 238000013526 transfer learning Methods 0.000 description 7
- 241000208340 Araliaceae Species 0.000 description 4
- MHABMANUFPZXEB-UHFFFAOYSA-N O-demethyl-aloesaponarin I Natural products O=C1C2=CC=CC(O)=C2C(=O)C2=C1C=C(O)C(C(O)=O)=C2C MHABMANUFPZXEB-UHFFFAOYSA-N 0.000 description 4
- 235000005035 Panax pseudoginseng ssp. pseudoginseng Nutrition 0.000 description 4
- 235000003140 Panax quinquefolius Nutrition 0.000 description 4
- 238000013500 data storage Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 235000008434 ginseng Nutrition 0.000 description 4
- 230000001537 neural effect Effects 0.000 description 4
- 238000010606 normalization Methods 0.000 description 4
- 238000005457 optimization Methods 0.000 description 4
- 230000008447 perception Effects 0.000 description 4
- 241001269238 Data Species 0.000 description 3
- 238000013459 approach Methods 0.000 description 3
- 238000004925 denaturation Methods 0.000 description 3
- 230000036425 denaturation Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 238000013508 migration Methods 0.000 description 3
- 230000005012 migration Effects 0.000 description 3
- 230000011218 segmentation Effects 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 230000006978 adaptation Effects 0.000 description 2
- 230000004069 differentiation Effects 0.000 description 2
- 230000004927 fusion Effects 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000003475 lamination Methods 0.000 description 2
- 210000002569 neuron Anatomy 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000000151 deposition Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 238000005538 encapsulation Methods 0.000 description 1
- 238000000802 evaporation-induced self-assembly Methods 0.000 description 1
- 239000004744 fabric Substances 0.000 description 1
- 230000008570 general process Effects 0.000 description 1
- 238000007689 inspection Methods 0.000 description 1
- 239000007788 liquid Substances 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000013011 mating Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 210000005036 nerve Anatomy 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000011017 operating method Methods 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000003252 repetitive effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/211—Selection of the most significant subset of features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
本发明提出了一种协同对抗网络,在协同对抗网络的低层设置损失函数,用于学习域区分性特征,并且与协同对抗网络的最后一层(即高层)设置的域不变性损失函数形成协同对抗目标函数,实现同时学习域区分性特征与域不变性特征。进一步地,提出了一种增强协同对抗网络,在协同对抗网络的基础上,将目标领域的数据加入协同对抗网络的训练中,以及根据任务模型的精度,设置自适应阈值对目标领域的训练样本进行选择,并根据域区分网络的置信度,对目标领域的训练样本设置权重。通过该协同对抗网络训练的任务模型能够提高应用在目标领域中的预测精度。
Description
技术领域
本发明涉及机器学习领域,特别涉及迁移学习领域中基于对抗网络的训练方法和装置。
背景技术
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式作出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。人工智能领域的研究包括机器人,自然语言处理,计算机视觉,决策与推理,人机交互,推荐与搜索,AI基础理论等。
深度学习是近年来人工智能领域发展的一个关键推动力,尤其是在计算机视觉的多种任务方面,如目标分类/检测/识别/分割中,取得了令人瞩目的效果;但是,深度学习的成功需要依赖于大量的已标注的数据。然而,标注大量的数据,是一项极其耗时耗力的工作。目前针对相同或相似的任务,可以将依据源领域中公开的数据集或已标注的数据训练好的任务模型直接应用到目标领域的任务预测,目标领域是相对于源领域而言的,目标领域一般没有已标注的数据或者没有足够的已标注的数据,源领域中公开的数据集和已标注的数据可以称作源领域数据,相应的,目标领域中未标注的数据可以称作目标领域数据。由于目标领域数据与源领域数数据的分布不相同,直接使用依据源领域数据训练好的模型的效果不佳。
非监督域适应(unsupervised domain adaption)是一种典型的迁移学习方法,可用于解决上述问题。与直接将依据源领域数据训练好的模型用于目标领域的任务预测不同,非监督域适应方法不仅利用源领域数据进行训练,同时将未标注的目标领域数据融合到训练当中,使训练的模型在目标领域数据上有较好的预测效果。目前,现有技术中性能比较好的非监督域适应方法是基于领域对抗的非监督域适应方法,如图1所示的一种基于领域对抗的非监督域适应训练图像分类器的方法,其特点是在学习图像分类任务的同时使用领域区分器(英文全称:Domain Discriminator)和梯度方向(Gradient Reversal)方法学习域不变性特征。主要步骤是:(1)使用卷积神经网络特征提取器(Convolutional NeuralNetwork Feature Extractor,CNN Feature Extractor)提取的特征除了输入到图像分类器中,还用于建立一个领域区分器,领域区分器可以对输入的特征可以输出领域类别;(2)使用梯度反向方法,在反向传播过程中修改梯度方向,从而使得卷积神经网络特征提取器学习的特征具有域不变性;(3)将以上得到卷积神经网络特征提取器和得到的分类器,用于目标领域的图像分类预测。
发明内容
为了解决基于领域对抗的非监督域适应方法存在的丢失具有域区分性的低层特征的问题。本申请提供了一种基于协同对抗网络的训练方法,能够保留具有域区分性的低层特征,从而提高任务模型的精度。进一步提供了一种增加协同领域对抗的方法,将目标领域中的数据用于训练任务模型,提高训练出的任务模型在目标领域的适配性。
第一方面,本申请提供了一种深度神经网络的训练方法,该训练方法应用于迁移学习领域,具体是将根据源领域数据训练的任务模型应用到目标领域数据的预测,该训练方法包括:提取输入该深度神经网络的源领域数据和目标领域数据中各样本数据所对应的低层特征和高层特征,其中,目标领域数据与源领域数据存在差异,也就是说两者的数据分布不一致;基于源领域数据和目标领域数据中各样本数据的高层特征和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失;基于源领域数据和目标领域数据中各样本数据的低层特征和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失;基于源领域数据中的样本数据的高层特征和对应的样本标签,通过第三损失函数计算源领域数据中的样本数据对应的第三损失;根据上述得到的第一损失、第二损失和第三损失更新目标深度神经网络中各模块的参数。更新是通过损失反向传播对参数进行更新,在反向传播中,第一损失的梯度需要经过梯度反向操作,梯度反向操作的目的实现反向传导梯度使损失变大。通过在高层特征和低层特征分别设置第一损失函数和第二损失函数,可以使得高层特征具有不变性的同时使得低层特征具有域区分性,提高训练得到的模型应用于到目标领域的预测的精度。
第一方面的一种可能的实现方式,该目标深度神经网络包括特征提取模块、任务模块、域不变性特征模块和域区分性特征模块,特征提取模块包括至少一个低层特征网络层和高层特征网络层,至少一个低层特征网络层中的任一个低层特征网络层可用于提取低层特征,高层特征网络层用于提取高层特征,域不变性特征模块用于增强特征提取模块提取的高层特征的领域不变性,域区分性特征模块用于增强特征提取模块提取的低层特征的领域区分性;
其中,上述根据第一损失、第二损失和第三损失更新目标深度神经网络的参数包括:首先根据第一损失、第二损失和第三损失计算总损失;再根据总损失更新特征提取模块的参数、任务模块的参数、域不变性特征模块的参数和域区分性特征模块的参数,需要注意的是,总损失可以是一个样本数据的第一损失、第二损失和第三损失的总和,也可以是多个样本数据的多个第一损失、多个第二损失和多个第三损失的总和。各损失具体在反向传播过程中作于目标神经网络中相应的模块的参数,具体的是第一损失通过反向传播对对域不变性特征模块和特征提取模块的参数进行更新,第二损失通过反向传播对域区分性特征模块和特征提取模块的参数进行更新。第三损失通过反向传播对任务模块和特征提取模块的参数进行更新。损失一般是进一步得到相应的梯度在进行反向传播进行更新相对模块的参数。
第一方面的另一种可能的实现方式,上述基于源领域数据和目标领域数据中各样本数据的高层特征和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失,包括:将源领域数据和目标领域数据中的各样本数据的高层特征输入域不变性特征模块得到各样本数据对应的第一结果;根据源领域数据和目标领域数据中的各样本数据对应的第一结果和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失。
上述基于源领域数据和目标领域数据中各样本数据的低层特征和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失,包括:将源领域数据和目标领域数据中的各样本数据的低层特征输入域区分性特征模块得到各样本数据对应的第二结果;根据源领域数据和目标领域数据中的各样本数据对应的第二结果和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失。
上述基于源领域数据中的样本数据的高层特征和对应的样本标签,通过第三损失函数计算源领域数据中的样本数据对应的第三损失,包括:将源领域数据中的样本数据的高层特征输入任务模块得到源领域数据中的样本数据对应的第三结果;基于源领域数据中的样本数据对应的第三结果和对应的样本标签,通过第三损失函数计算源领域数据中的样本数据对应的第三损失。
第一方面的另一种可能的实现方式,域不变性特征模块还包括:梯度反向模块;该训练方法还包括:通过该梯度反向模块对第一损失的梯度进行梯度反向。梯度方向可以实现反向传导第一损失的梯度使得第一损失函数的计算的损失变大,使得高层特征具有域不变性特征,
第一方面的另一种可能的实现方式,该训练方法还包括:将目标领域数据中样本数据的高层特征输入任务模块,得到对应的预测样本标签和对应的置信度;根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据,目标领域训练样本数据为目标领域数据中对应的置信度满足预设条件的样本数据。使用目标领域数据用于训练任务模型,能够进一步提高任务模型在目标领域的数据上的分类精度。
第一方面的另一种可能的实现方式,该训练方法还包括:根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重。当目标领域训练样本数据不易被领域区分器区分时,则目标领域训练样本数据的分布比较接近于源领域图像数据与目标领域图像数据之间,对图像分类模型的训练更有帮助,因此根据第一结果设置权重能将上述描述的不易被领域区分的目标领域训练样本数据在训练中占较大的权重。
第一方面的另一种可能的实现方式,根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重包括:根据目标领域训练样本数据对应的第一结果与领域标签的相似度,设置目标领域训练样本数据的权重,相似度表示第一结果与领域标签的差值大小。
第一方面的另一种可能的实现方式,上述根据目标领域训练样本数据对应的第一结果与领域标签的相似度,设置目标领域训练样本数据的权重包括:计算目标领域训练样本数据对应的第一结果与源领域的领域标签的第一差值,以及目标领域训练样本数据对应的第一结果与目标领域的领域标签的第二差值;若第一差值的绝对值大于第二差值的绝对值,则设置目标领域训练样本数据的权重为较小的值,例如小于0.5的值;否则,设置目标领域训练样本数据的权重为较大的值,例如大于0.5的值。
第一方面的另一种可能的实现方式,若目标领域训练样本数据对应的第一结果为第一领域标签值至第二领域标签值取值范围中的中间值,则设置目标领域训练样本数据的权重为最大值(例如1)。关于中间值的示例,例如第一领域标签值为0,第二领域标签值为1,中间值是指0.5或者为0.5上下浮动区间中的值。其中第一领域标签值为源领域的领域标签对应的值,第二领域标签值为目标领域的领域标签对应的值。
第一方面的另一种可能的实现方式,在上述根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据之前,该训练方法还包括:根据任务模型的精度设置自适应阈值,任务模型包括特征提取模块和任务模块,自适应阈值与任务模型的精度正相关;其中,预设条件为置信度大于或等于自适应阈值。
第一方面的另一种可能的实现方式,自适应阈值通过下面逻辑函数计算:
其中,Tc为自适应阈值,A为任务模型的精度,λc为用于控制逻辑函数的倾斜度的超参数。
第一方面的另一种可能的实现方式,训练方法还包括:通过特征提取模块提取目标领域训练样本数据的低层特征和高层特征;基于目标领域训练样本数据的高层特征和对应的领域标签,通过第一损失函数计算目标领域训练样本数据对应的第一损失;基于目标领域训练样本数据的低层特征和对应的领域标签,通过第二损失函数计算目标领域训练样本数据对应的第二损失;基于目标领域训练样本数据的高层特征和对应的预测样本标签,通过第三损失函数计算目标领域训练样本数据对应的第三损失;根据目标领域训练样本数据对应的第一损失、第二损失和第三损失计算目标领域训练样本数据对应的总损失,其中,目标领域训练样本数据对应的第一损失的梯度经过梯度反向;根据目标领域训练样本数据对应的总损失和目标领域训练样本数据的权重,更新特征提取模块的参数、任务模块的参数、域不变性特征模块的参数和域区分性特征模块的参数。
第一方面的另一种可能的实现方式,上述基于目标领域训练样本数据的高层特征和对应的领域标签,通过第一损失函数计算目标领域训练样本数据对应的第一损失包括:将目标领域训练样本数据的高层特征输入域不变性特征模块得到目标领域训练样本数据对应的第一结果;根据目标领域训练样本数据对应的第一结果和对应的领域标签,通过第一损失函数计算目标领域训练样本数据对应的第一损失;
上述基于目标领域训练样本数据的低层特征和对应的领域标签,通过第二损失函数计算目标领域训练样本数据对应的第二损失包括:将目标领域训练样本数据的低层特征输入域区分性特征模块得到目标领域训练样本数据对应的第二结果;根据目标领域训练样本数据对应的第二结果和对应的领域标签,通过第二损失函数计算目标领域训练样本数据对应的第二损失;
基于目标领域训练样本数据的高层特征和对应的预测样本标签,通过第三损失函数计算目标领域训练样本数据对应的第三损失,包括:将目标领域训练样本数据的高层特征输入任务模块得到目标领域训练样本数据对应的第三结果;基于目标领域训练样本数据对应的第三结果和对应的预测样本标签,通过第三损失函数计算目标领域训练样本数据对应的第三损失。
第二方面,本申请提供了一种训练设备,该训练设备包括存储器及与存储器耦合的处理器;存储器用于存储指令,处理器用于执行指令;其中,处理器执行指令时执行上述第一方面和第一方面的可能的实现方式中描述的方法。
第三方面,本申请提供了一种计算机可读存储介质,该计算机可读存储有计算机程序,该计算机程序被处理器执行时实现上述第一方面和第一方面的可能的实现方式中描述的方法。
第四方面,本申请提供了一种计算机程序产品,该计算机程序产品包括用于执行上述第一方面和第一方面的可能的实现方式中描述的方法的代码。
第五方面,本申请提供了一种训练装置,该训练装置包括用于执行上述第一方面和第一方面的可能的实现方式中描述的方法的功能单元。
第六方面,本申请提供了一种基于卷积神经网络CNN构建的增强协同对抗网络,该增强协同对抗网络包括:用于提取源领域数据和目标领域数据中各样本数据的低层特征和高层特征的特征提取模块,目标领域数据与源领域数据的数据分布不同;用于接收特征提取模块输出的高层特征且通过第三损失函数分别计算各样本数据对应的第三损失的任务模块,第三损失用于更新特征提取模块和任务模块的参数;用于接收特征提取模块输出的高层特征且通过第一损失函数分别计算各样本数据对应的第一损失的域不变性模块,第一损失用于更新特征提取模块和域不变性模块的参数,使得特征提取模块输出的高层特征具有域不变性;用于接收特征提取模块输出的低层特征且通过第二损失函数分别计算各样本数据对应的第二损失的域区分性模块,第二损失用于更新特征提取模块和域区分性模块的参数,使得特征提取模块输出的低层特征具有域区分性。
第六方面的一种可能的实现方式,该增强协同对抗网络还包括:用于根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据的样本数据选择模块,目标领域数据中样本数据对应的置信度通过将目标领域数据中样本数据的高层特征输入任务模块得到,目标领域训练样本数据为目标领域数据中对应的置信度满足预设条件的样本数据。
第六方面的另一种可能的实现方式,上述样本数据选择模块还用于根据任务模型的精度设置自适应阈值,任务模型包括特征提取模块和任务模块,自适应阈值与任务模型的精度正相关;其中,预设条件为置信度大于或等于自适应阈值。
第六方面的另一种可能的实现方式,该增强协同对抗网络还包括用于根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重的权重设置模块。
第六方面的另一种可能的实现方式,上述权重设置模块具体用于根据目标领域训练样本数据对应的第一结果与领域标签的相似度,设置目标领域训练样本数据的权重;相似度表示第一结果与领域标签的差值大小。
第六方面的另一种可能的实现方式,上述权重设置模块具体用于计算目标领域训练样本数据对应的第一结果与源领域的领域标签的第一差值,以及目标领域训练样本数据对应的第一结果与目标领域的领域标签的第二差值;若第一差值的绝对值大于第二差值的绝对值,则设置目标领域训练样本数据的权重为较小的值,否则,设置目标领域训练样本数据的权重为较大的值。
第六方面的另一种可能的实现方式,上述权重设置模块具体用于:若目标领域训练样本数据对应的第一结果为第一领域标签值至第二领域标签值取值范围中的中间值,则设置目标领域训练样本数据的权重为最大值,例如1,第一领域标签值为源领域的领域标签对应的值,第二领域标签值为目标领域的领域标签对应的值。中间值的说明可以参见第一方面的相关描述,此处不再赘述。
第七方面,本申请提供了一种基于协同对抗网络的训练数据权重设置方法,该协同对抗网络至少包括特征提取模块、任务模块、域不变性模块,还可以包括域区分性模块,关于各模块可以参考上面第六方面的相关描述,此处不再赘述。该权重设置方法包括:将目标领域数据中样本数据的高层特征输入任务模块得到对应的预测样本标签和对应的置信度;根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据,目标领域训练样本数据为目标领域数据中对应的置信度满足预设条件的样本数据;将将目标领域数据中样本数据的高层特征输入域不变性模块得到目标领域训练样本数据对应的第一结果;根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重。
第七方面的一种可能的实现方式,上述根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重具体包括:根据目标领域训练样本数据对应的第一结果与领域标签的相似度,设置目标领域训练样本数据的权重,相似度表示第一结果与领域标签的差值大小。
第七方面的另一种可能的实现方式,上述根据目标领域训练样本数据对应的第一结果与领域标签的相似度,设置目标领域训练样本数据的权重包括:计算目标领域训练样本数据对应的第一结果与源领域的领域标签的第一差值,以及目标领域训练样本数据对应的第一结果与目标领域的领域标签的第二差值;若第一差值的绝对值大于第二差值的绝对值,则设置目标领域训练样本数据的权重为较小的值,例如小于0.5的值;否则,设置目标领域训练样本数据的权重为较大的值,例如大于0.5的值。
第七方面的另一种可能的实现方式,若目标领域训练样本数据对应的第一结果为第一领域标签值至第二领域标签值取值范围中的中间值,则设置目标领域训练样本数据的权重为最大值(例如1)。关于中间值的示例,例如第一领域标签值为0,第二领域标签值为1,中间值是指0.5或者为0.5上下浮动区间中的值。其中第一领域标签值为源领域的领域标签对应的值,第二领域标签值为目标领域的领域标签对应的值。
第七方面的另一种可能的实现方式,在上述根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据之前,该权重设置方法还包括:根据任务模型的精度设置自适应阈值,任务模型包括特征提取模块和任务模块,自适应阈值与任务模型的精度正相关;其中,预设条件为置信度大于或等于自适应阈值。
上述自适应阈值通过下面逻辑函数计算:
其中,Tc为自适应阈值,A为任务模型的精度,λc为用于控制逻辑函数的倾斜度的超参数。
第八方面,本申请提供了一种设备,该设备包括存储器及与存储器耦合的处理器;存储器用于存储指令,处理器用于执行指令;其中,处理器执行指令时执行上述第七方面和第七方面的可能的实现方式中描述的方法。
第九方面,本申请提供了一种计算机可读存储介质,该计算机可读存储有计算机程序,该计算机程序被处理器执行时实现上述第七方面和第七方面的可能的实现方式中描述的方法。
第十方面,本申请提供了一种计算机程序产品,该计算机程序产品包括用于执行上述第七方面和第七方面的可能的实现方式中描述的方法的代码。
第十一方面,本申请提供了一种权重设置装置,该权重设置装置包括用于执行上述第七方面和第七方面的可能的实现方式中描述的方法的功能单元。
本申请实施例提供的训练方法基于高层特征和低层特征分别建立了域不变性损失函数和域区分性损失函数,在保证高层特征的域不变性特征的同时保留了低层特征中的域区分性特征,能够提高训练得到的任务模型应用到目标领域中进行预测的精度。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种基于非监督域适应训练图像分类器的方法示意图;
图2为本发明实施例提供的一种人工智能主体框架示意图;
图3为本发明实施例提供的不同城市的人车图像数据对照示意图;
图4为本发明实施例提供的不同地域的人脸图像数据对照示意图;
图5为本发明实施例提供的一种训练系统架构示意图;
图6为本发明实施例提供的一种特征提取单元的示意图;
图7为本发明实施例提供的一种特征提取CNN的示意图;
图8为本发明实施例提供的一种域不变性特征单元的示意图;
图9为本发明实施例提供的一种训练装置的结构示意图
图10为本发明实施例提供的另一种训练装置的结构示意图;
图11为本发明实施例提供的一种云-端系统架构示意图;
图12为本发明实施例提供的一种训练方法的流程图;
图13为本发明实施例提供的一种基于协同对抗网络的训练方法示意图;
图14为本发明实施例提供的权重设置曲线示意图;
图15为本发明实施例提供的一种芯片硬件结构示意图;
图16为本发明实施例提供的一种训练设备结构示意图;
图17A为本发明实施例提供的在Office-31上的测试结果;
图17B为本发明实施例提供的在ImageCLEF-DA上的测试结果。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图2示出一种人工智能主体框架示意图,该主体框架描述了人工智能系统总体工作流程,适用于通用的人工智能领域需求。
下面从“智能信息链”(水平轴)和“IT价值链”(垂直轴)两个维度对上述人工智能主题框架进行阐述。
“智能信息链”反映从数据的获取到处理的一列过程。举例来说,可以是智能信息感知、智能信息表示与形成、智能推理、智能决策、智能执行与输出的一般过程。在这个过程中,数据经历了“数据—信息—知识—智慧”的凝练过程。
“IT价值链”从人智能的低层基础设施、信息(提供和处理技术实现)到系统的产业生态过程,反映人工智能为信息技术产业带来的价值。
(1)基础设施:
基础设施为人工智能系统提供计算能力支持,实现与外部世界的沟通,并通过基础平台实现支撑。通过传感器与外部沟通;计算能力由智能芯片(CPU、NPU、GPU、ASIC、FPGA等硬件加速芯片)提供;基础平台包括分布式计算框架及网络等相关的平台保障和支持,可以包括云存储和计算、互联互通网络等。举例来说,传感器和外部沟通获取数据,这些数据提供给基础平台提供的分布式计算系统中的智能芯片进行计算。
(2)数据
基础设施的上一层的数据用于表示人工智能领域的数据来源。数据涉及到图形、图像、语音、文本,还涉及到传统设备的物联网数据,包括已有系统的业务数据以及力、位移、液位、温度、湿度等感知数据。
(3)数据处理
数据处理通常包括数据训练,机器学习,深度学习,搜索,推理,决策等方式。
其中,机器学习和深度学习可以对数据进行符号化和形式化的智能信息建模、抽取、预处理、训练等。
推理是指在计算机或智能系统中,模拟人类的智能推理方式,依据推理控制策略,利用形式化的信息进行机器思维和求解问题的过程,典型的功能是搜索与匹配。
决策是指智能信息经过推理后进行决策的过程,通常提供分类、排序、预测等功能。
(4)通用能力
对数据经过上面提到的数据处理后,进一步基于数据处理的结果可以形成一些通用的能力,比如可以是算法或者一个通用系统,例如,翻译,文本的分析,计算机视觉的处理,语音识别,图像的识别等等。
(5)智能产品及行业应用
智能产品及行业应用指人工智能系统在各领域的产品和应用,是对人工智能整体解决方案的封装,将智能信息决策产品化、实现落地应用,其应用领域主要包括:智能制造、智能交通、智能家居、智能医疗、智能安防、自动驾驶,平安城市,智能终端等。
本申请中涉及的重要概念的相关说明
非监督域适应,是迁移学习的一种典型方法,依据源领域与目标领域的数据进行任务模型的训练,通过训练好的任务模型来实现对目标领域中物体的识别/分类/分割/检测等,其中源领域的数据有标签,而目标领域的数据无标签,并且两种领域数据的分布不相同。需要注意的,在本申请中“源领域的数据”与“源领域数据”,“目标领域的数据”与“目标领域数据”通常上具有相同的含义。
域不变性特征:是指不同领域数据的通用特征,从不同领域数据中提取的特征具有一致的分布。
域区分性特征:指对于特定领域数据中的特征,对于不同领域数据中提取的特征具有不相同的分布。
本申请描述了一种神经网络的训练方法,该训练方法应用于迁移学习领域的任务/预测模型(以下简称为任务模型)的训练。具体地,可应用于训练基于深度神经网络构建的各种任务模型,包括但不限于分类模型、识别模型、分割模型、检测模型。通过本申请描述的训练方法得到的任务模型可广泛应用到AI拍照、自动驾驶、平安城市等多种具体应用场景,以实现应用场景的智能化。
以自动驾驶应用场景中的人车检测为例,人车检测是自动驾驶感知系统里面的一个基本单元。人车检测的准确程度关系到自动驾驶车辆的安全,能否准确地检测出车辆周围的行人和行车,关键是用于人车检测的检测模型是否具有高精度,然而高精度的检测模型依赖于大量的已标注的人车图像/视频数据。标注数据又是一项庞大的工程,为了达到自动驾驶的精度要求,几乎需要针对不同的城市标注不同数据,这是难以实现的。为了提高训练效率,人车检测模型的迁移是最常用的方法,即直接将依据区域A的已标注的人车图像/视频数据训练的检测模型,应用到没有或没有足够的已标注的人车图像/视频数据的区域B场景中的人车检测,这里的区域A为源领域,区域B为目标领域,区域A的数据为有标签的源领域数据,区域B的数据为无标签的目标领域数据。然而,以城市为例,不同的城市的人种、生活习惯、建筑风格、气候环境、交通设施等以及数据采集设备可能存在很大的差异,即在数据的分布不同,很难保证自动驾驶的精度要求的。如图3所示,左面的四张图像欧洲某一城市的采集设备采集到图像数据,右面的四张图像是亚洲某城市采集设备采集到的图像数据,可以看出,行人皮肤、穿着、姿态存在明显的差异,城市建筑和行车外观也存在很明显的差异。如果将依据图3中一个城市的图像/视频数据训练的检测模型应用到图3中的另一个城市场景,那么检测模型的精度必然会大幅降低。本申请描述的训练方法利用已标注的数据和未标注的数据共同训练任务模型,即利用区域A的已标注的人车图像/视频数据和区域B的为标注的人车图像/视频数据共同训练用于人车检测的检测模型,能够大幅提高依据区域A的人车图像/视频数据训练的检测模型应用到区域B场景中人车检测的精度。
再以人脸识别应用场景为例,人脸识别往往涉及到不同国家、地域的人的识别,不同国家、地域的人的人脸数据会有较大的分布差异。如图4所示,假如欧洲白种人的人脸数据有训练标签作为源领域数据,即已标注的人脸数据;非洲黑种人的人脸数据无训练标签作为目标领域数据,即未标注的人脸数据。由于白种人和黑种人的肤色、脸部轮廓等存在很大的差异,致使人脸数据分布不同;不过,即使黑种人的人脸数据是未标注数据,通过本申请描述的训练方法得到的人脸识别模型也能够提高黑种人的人脸识别准确率。
本发明实施例提供了一种深度神经网络训练系统架构100。如图5所示,系统架构100至少包括训练装置110、数据库120,还包括数据采集设备130、客户设备140和数据存储系统150。
数据采集设备130用于采集数据并将采集到的数据(例如:图片/视频/音频等)存入数据库120作为训练数据。数据库120用于维护和存储训练数据,数据库120存储的训练数据包括源领域数据和目标领域数据,源领域数据可以理解为已标注的数据,目标领域数据可以理解为未标注的数据,源领域与目标领域是迁移学习领域的相对概念,具体的,可参见图3和图4对应描述理解源领域、目标领域、源领域数据和目标领域数据,上述概念是本技术领域人员能够理解的。训练装置110与数据库120交互,从数据库120中获取需要的训练数据,用来训练任务模型,任务模型包括特征提取模块和任务模块,特征提取模块可以是特征提取单元111,也可以是利用训练后的特征提取单元111的参数构建的深度神经网络;同样地,任务模块可以是任务单元112,也可以是利用训练后的任务单元112的参数构建的模型,例如函数模型、神经网络模型等。训练装置110通过训练得到的任务模型可以应用在客户设备140中,也可以响应客户设备140请求输出预测结果。例如,客户设备140是自动驾驶车辆,训练装置110根据数据库120中的训练数据训练好人车检测模型,当自动驾驶车辆需要执行人车检测时,可以由训练装置110得到的人车检测模型完成人车检车并反馈给自动驾驶车辆,训练好的人车检测模型可以布置在自动驾驶车辆上,也可以是布置在云端,具体形式不做限制。客户设备140在需要的情况下,也可以作为数据库120的数据采集设备,以扩充数据库。
训练装置110包括特征提取单元111、任务单元112、域不变性特征单元113、域区分性特征单元114和I/O接口115,/O接口115用于训练设备110与外界设备进行交互。
特征提取单元111用于提取输入数据的低层特征和高层特征,如图6所示,特征提取单曲单元111包括低层特征提取子单元1111和高层特征提取子单元1112,低层特征提取子单元1111用于提输入数据的低层特征,高层特征提取子单元1112用于提取输入数据的高层特征。具体的,数据输入低层特征提取子单元1111后得到表示低层特征的数据,表示低层特征的数据再输入高层特征提取子单元1112后得到表示高层特征的数据,也就是说高层特征是基于低层特征进一步处理得到的特征。
特征提取单元111可以由软件、硬件(例如电路)或软件与硬件(例如处理器调用代码)结合实现。常用的是通过神经网络实现特征提取单元111的功能,可选的,特征提取单元111的功能由卷积神经网络(Convosutionas Neuras Network,CNN)实现,如图7所示,特征提取CNN包括多个卷积层,通过卷积计算可实现输入数据的特征提取,多个卷积层的最后一层卷积层可以称为高层卷积层,作为高层特征提取子单元1112用于提取高层特征;其他卷积层可称为低层卷积层,作为低层特征提取子单元1111用于提取低层特征。每一个低层卷积层可以均可以输出一个低层特征,即一个数据输入作为特征提取单元111的CNN后,可以输出一个高层特征和至少一个低层特征,低层特征的数量可根据实际训练需求设置,制定具体的输出用于作为低层特征提取子单元1111输出低层特征的低层卷积层。
卷积神经网络(Convosutionas Neuras Network,CNN)是一种带有卷积结构的深度神经网络。卷积神经网络包含了由卷积层和子采样层构成的特征抽取器。该特征抽取器可以看作是滤波器,卷积过程可以看作是使用一个可训练的滤波器与一个输入的图像或者卷积特征平面(feature map)做卷积。卷积层是指卷积神经网络中对输入信号进行卷积处理的神经元层。在卷积神经网络的卷积层中,一个神经元可以只与部分邻层神经元连接。一个卷积层中,通常包含若干个特征平面,每个特征平面可以由一些矩形排列的神经单元组成。同一特征平面的神经单元共享权重,这里共享的权重就是卷积核。共享权重可以理解为提取图像信息的方式与位置无关。这其中隐含的原理是:图像的某一部分的统计信息与其他部分是一样的。即意味着在某一部分学习的图像信息也能用在另一部分上。所以对于图像上的所有位置,我们都能使用同样的学习得到的图像信息。在同一卷积层中,可以使用多个卷积核来提取不同的图像信息,一般地,卷积核数量越多,卷积操作反映的图像信息越丰富。
卷积核可以以随机大小的矩阵的形式初始化,在卷积神经网络的训练过程中卷积核可以通过学习得到合理的权重。另外,共享权重带来的直接好处是减少卷积神经网络各层之间的连接,同时又降低了过拟合的风险。
卷积神经网络可以采用误差反向传播(back propagation,BP)算法在训练过程中修正初始的超分辨率模型中参数的大小,使得超分辨率模型的重建误差损失越来越小。具体地,前向传递输入信号直至输出会产生误差损失,通过反向传播误差损失信息来更新初始的超分辨率模型中参数,从而使误差损失收敛。反向传播算法是以误差损失为主导的反向传播运动,旨在得到最优的超分辨率模型的参数,例如权重矩阵。
任务单元112的输入是高层特征提取子单元1112输出的高层特征,具体是已标注的源领域数据经过特征提取单元111输出的高层特征,输出是标签。训练后的任务单元112和特征提取单元111可以作为任务模型,任务模型可以用于目标领域的预测任务。
域不变性特征单元113的输入是高层特征提取子单元1112输出的高层特征,输出是对应数据所属的领域(源领域或目标领域)标签。如图8所示,域不变性特征单元113包括域区分特征子单元1131和梯度反向子单元1132,梯度反向子单元1132可以对反向传播的梯度进行梯度反向,使得域区分特征子单元1131输出的领域标签与真实领域标签的误差(即损失)变大。域不变性特征单元113能够实现特征提取单元111输出的高层特征具有领域不变性,也就是降低通过特征提取单元111输出的高层特征较难或无法对领域进行区分。
域区分性特征单元114的输入是低层特征提取子单元1111输出的低层特征,输出是对应数据所属的领域标签。域区分性特征单元114能够使得特征提取单元111输出的低层特征容易对领域进行区分,从而具有域区分性。
需要注意的,域区分性特征单元114与域区分特征子单元1131都可以针对输入特征输出所属的领域,域不变性特征单元113和域区分性特征单元114的主要区别在于域不变性特征单元113还包括梯度反向子单元1132。域区分性特征单元114和特征提取单元111可以构成一个领域区分模型,同样地,忽略梯度反向子单元1132,域不变性特征单元113中的域区分特征子单元1131和特征提取单元111也可以构成一个领域区分模型。
可选的,训练装置110为图9所示的结构,训练装置110包括特征提取单元111、任务单元112、域区分性特征单元113'、梯度反向单元114'和I/O接口115。域区分性特征单元113'和梯度反向单元114'相当于图5中训练装置110的域不变性特征单元113和域区分性特征单元114。
任务单元112、域不变性特征单元113和域区分性特征单元114以及域区分性特征单元113'、梯度反向单元114'可以由软件、硬件(例如电路)或软件与硬件(例如处理器调用代码)结合实现,可以由向量矩阵、函数、神经网络等具体实现,不做限定。任务单元112、域不变性特征单元113和域区分性特征单元114均包括损失函数用于计算输出值与真实值的损失,损失用于更新各单元中的参数,具体更新细节是本技术领域的技术人员所能理解的,不做赘述。
训练装置110包括域不变性特征单元113和域区分性特征单元114,通过源领域数据和目标领域数据的训练,能够得到的特征提取单元111输出的低层特征具有域区分性,而输出的高层特征具有域不变性,高层特征是基于低层特征进一步得到的,使得高层特征仍能很好的保留具有域区分性的特征,进一步地用于任务模型可以提高预测精度。
如图10所示,训练装置110还包括样本数据选择单元116,样本数据选择单元116用于从目标领域数据中选择满足条件的数据作为训练样本数据用于训练装置110进行的训练。样本数据选择单元116具体包括选择子单元1161和权重设置子单元1162。选择子单元1161用于根据任务模型的精度从目标领域数据中选择出满足条件的数据并添加相应的标签作为训练样本数据。权重设置子单元1162用于给选定的作为训练样本数据的目标领域数据设置权重,通过权重设置以明确作为训练样本数据的目标领域数据对任务模型训练的影响程度。具体如何选择和设置权重,将在下面进行详细描述,此处不再赘述。需要说明的,图10中的其他单元包括图5中的特征提取单元111、任务单元112、域不变性特征单元113、域区分性特征单元114和I/O接口115,或者,特征提取单元111、任务单元112、域区分性特征单元113'、梯度反向单元114'和I/O接口115。
本发明实施例提供了一种云-端系统架构200,如图11所示,执行设备210由一个或多个服务器实现,可选的,与其它计算设备配合,例如:数据存储、路由器、负载均衡器等设备;执行设备210可以布置在一个物理站点上,或者分布在多个物理站点上。可选的,执行设备210可以使用数据存储系统220中的数据,或者调用数据存储系统220中的程序代码实现训练装置110的所有功能;具体地,执行设备210可以根据数据库120中的训练数据训练任务模型,以及根据本地设备231(232)的请求完成目标领域的任务预测。可选的,执行设备210不具备训练装置110的训练功能,但是可以根据训练装置110训练好的任务模型完成预测;具体的,执行设备210配置有训练装置110训练好任务模型后,在接收到本地设备231(232)的请求后完成预测并反馈结果给本地设备231(232)。
用户可以操作各自的用户设备(例如本地设备231和本地设备232)与执行设备210进行交互。每个本地设备可以表示任何计算设备,例如个人计算机、计算机工作站、智能手机、平板电脑、智能摄像头、智能汽车或其他类型蜂窝电话、媒体消费设备、可穿戴设备、机顶盒、游戏机等。
每个用户的本地设备可以通过任何通信机制/通信标准的通信网络与执行设备210进行交互,通信网络可以是广域网、局域网、点对点连接等方式,或它们的任意组合。
在另一种实现中,执行设备210的一个方面或多个方面可以由每个本地设备实现,例如,本地设备301可以为执行设备210提供本地数据或反馈计算结果。
需要注意的,执行设备210的所有功能也可以由本地设备实现。例如,本地设备231实现执行设备210的的功能(例如:训练或预测)并为自己的用户提供服务,或者为本地设备232的用户提供服务。
本申请实施例提供了一种目标深度神经网络的训练方法,该目标深度神经网络是一个系统架构的统称,具体地,包括特征提取模块(对应特征提取单元111)、任务模块(对应任务单元112)、域不变性特征模块(对应域不变性特征单元113)和域区分性特征模块(对应域区分性特征单元114或者域区分性特征单元113'),特征提取模块包括至少一个低层特征网络层(对应低层特征提取子单元1111)和高层特征网络层(对应高层特征提取子单元1112),至少一个低层特征网络层中的任一个低层特征网络层可用于提取低层特征,高层特征网络层用于提取高层特征,域不变性特征模块用于增强特征提取模块提取的高层特征的领域不变性,域区分性特征模块用于增强特征提取模块提取的低层特征的领域区分性。如图12所示,该训练方法的具体步骤为:
S101,提取源领域数据和目标领域数据中各样本数据的低层特征和高层特征,目标领域数据与源领域数据在数据分布上不同;
具体地,利用低层特征网络层提取源领域数据和目标领域数据中各样本数据对应的低层特征,利用高层特征网络层提取提取源领域数据和目标领域数据中各样本数据对应的高层特征。
S102,基于源领域数据和目标领域数据中各样本数据的高层特征和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失;具体地,将源领域数据和目标领域数据中的各样本数据的高层特征输入域不变性特征模块得到各样本数据对应的第一结果;根据源领域数据和目标领域数据中的各样本数据对应的第一结果和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失。
进一步地,上述域不变性特征模块还包括:梯度反向模块(对应梯度反向子单元);该训练方法还包括:通过梯度反向模块对第一损失的梯度进行梯度反向处理,梯度反向的可以使用任一的现有技术,例如Gradient Reversal Layer(GRL)。
S103,基于源领域数据和目标领域数据中各样本数据的低层特征和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失;
具体地,将源领域数据和目标领域数据中的各样本数据的低层特征输入域区分性特征模块得到各样本数据对应的第二结果;根据源领域数据和目标领域数据中的各样本数据对应的第二结果和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失。
S104,基于源领域数据中的样本数据的高层特征和对应的样本标签,通过第三损失函数计算源领域数据中的样本数据对应的第三损失;
具体地,将源领域数据中的样本数据的高层特征输入任务模块得到源领域数据中的样本数据对应的第三结果;基于源领域数据中的样本数据对应的第三结果和对应的样本标签,通过第三损失函数计算源领域数据中的样本数据对应的第三损失。
S105,根据第一损失、第二损失和第三损失更新目标深度神经网络的参数,其中第一损失的梯度经过梯度反向,梯度反向可实现反向传导梯度使损失变大;
具体地,根据第一损失、第二损失和第三损失计算总损失;
根据总损失更新特征提取模块的参数、任务模块的参数、域不变性特征模块的参数和域区分性特征模块的参数。
训练后的特征提取模块和任务模块作为任务模型,用于目标领域的预测任务,当然也可以用源领域的预测任务。
进一步地,该训练方法还包括以下步骤:
S106,将目标领域数据中样本数据的高层特征输入任务模块,得到对应的预测样本标签和对应的置信度。
S107,根据目标领域数据中样本数据对应的置信度从目标领域数据中选定目标领域训练样本数据,目标领域训练样本数据是指目标领域数据中对应的置信度满足预设条件的样本数据;
具体的,根据任务模型的精度设置自适应阈值,任务模型包括特征提取模块和任务模块,自适应阈值与任务模型的精度正相关;其中,预设条件是指置信度大于或等于自适应阈值。
可选的,自适应阈值通过下面逻辑函数计算:
其中,Tc为自适应阈值,A为任务模型的精度,λc为用于控制逻辑函数的倾斜度的超参数。
S108,根据目标领域训练样本数据对应的第一结果设置目标领域训练样本数据的权重。
具体地,根据域区分特征子单元1131输出的预测值(对应第一结果),判断其与源领域数据或者目标领域数据分布的相似度,并根据相似度设置目标域样本的权重。相似度可以用预测值与领域标签的差值表示。具体地,预先给源领域标签和目标领域标签各设定一个值,例如,设定源领域的领域标签(可简称源领域标签)为a,设定目标领域的领域标签(可简称目标领域标签)为b,则预测值x的取值范围在a和b之间,可以根据|x-a|与|x-b|的大小来判断相似程度,差值的绝对值越小说明相似程度越大(即更接近)。权重设置可以有两种方案:(1)当预测值更接近源领域领域签的值时,设置较小权重;若预测值在源领域标签的值与目标领域标签的值中间,设置较大权重。(2)当预测值更接近源领域标签的值时,设置较小权重;若输出值与目标领域标签的值更接近时,设置较大权重。上述较小权重和较大权重是相对而言的,可以根据实际设定确定具体数值。权重大小与相似度的关系,可以简单概括为:预测值更倾向于源领域标签值,则相应权重倾向于较小值。也就是,根据预测值判定对应的目标领域训练样本数据是源领域的数据的可能性更大,则设置该目标领域训练样本数据权重较小值,反之可以设置较大值。关于取值设置还可以可参见图14对应实施例的相关描述。
根据步骤S106-S108选定的目标领域训练样本数据除了具有领域标签,还包含预测样本标签和权重,选定的目标领域训练样本数据可用于训练,即相当于源领域数据,重新经过步骤S101-S105,该训练方法还包括针对目标领域训练样本数据的步骤,如下:
1)通过特征提取模块提取目标领域训练样本数据的低层特征和高层特征。
2)基于目标领域训练样本数据的高层特征和对应的领域标签,通过第一损失函数计算目标领域训练样本数据对应的第一损失;具体地,将目标领域训练样本数据的高层特征输入域不变性特征模块得到目标领域训练样本数据对应的第一结果;根据目标领域训练样本数据对应的第一结果和对应的领域标签,通过第一损失函数计算目标领域训练样本数据对应的第一损失。
3)基于目标领域训练样本数据的低层特征和对应的领域标签,通过第二损失函数计算目标领域训练样本数据对应的第二损失;具体地,将目标领域训练样本数据的低层特征输入域区分性特征模块得到目标领域训练样本数据对应的第二结果;根据目标领域训练样本数据对应的第二结果和对应的领域标签,通过第二损失函数计算目标领域训练样本数据对应的第二损失
4)基于目标领域训练样本数据的高层特征和对应的预测样本标签,通过第三损失函数计算目标领域训练样本数据对应的第三损失;具体地,将目标领域训练样本数据的高层特征输入任务模块得到目标领域训练样本数据对应的第三结果;基于目标领域训练样本数据对应的第三结果和对应的预测样本标签,通过第三损失函数计算目标领域训练样本数据对应的第三损失。
5)根据目标领域训练样本数据对应的第一损失、第二损失和第三损失计算目标领域训练样本数据对应的总损失,其中,目标领域训练样本数据对应的第一损失的梯度经过梯度反向;
6)根据目标领域训练样本数据对应的总损失和目标领域训练样本数据的权重,更新特征提取模块的参数、任务模块的参数、域不变性特征模块的参数和域区分性特征模块的参数。
图12对应的实施例中描述的所有步骤可以由训练装置110或执行设备210单独执行,也可以由多个装置或设备执行,每个装置或设备执行图12对应的实施例中描述的部分步骤。例如图12对应的实施例中描述的所有步骤有训练装置110执行,可理解地,选定的目标领域训练样本数据作为已标注的训练数据(包含样本标签和领域标签),再次输入训练装置110时的训练装置110中各单元的参数已经与得到目标领域训练样本数据的预测标签时的参数是不完全相同的,此时的训练装置110中各单元的参数可能经过至少一次的更新。
本申请实施例提供的训练方法实际上同时训练了任务模型和领域区分模型。任务模型包括特征提取模块和任务模块,针对特定任务的模型。领域区分模型包括特征提取模块和域区分性特征模块,用于区分所属领域,即针对输入的数据给出该数据所属的领域(源领域或目标领域),领域区分模型训练使用的标签是领域标签,例如设置源领域数据的领域标签为0,设置目标领域数据的领域标签为1。需要注意的是,领域区分模型中的域区分性特征模块可以是域区分性特征单元114或者域区分性特征单元113'。
需要注意的,上述步骤编号并不是指定按照编号顺序执行各步骤,编号为了方便阅读,各步骤之间具有逻辑顺序,可根据技术方案具体确定,因此,编号并不是对方法流程的限定。同样地,图12中的编号也不是对方法流程的限定。
本申请实施例提供的训练方法是基于增强协同对抗网络实现的,如图13所示的基于CNN构建的增强协同对抗网络。协同对抗网络是指基于低层特征和高层特征分别建立域区分性损失函数和域不变性损失函数形成的网络,可选的,域区分性损失函数配置在域区分性特征单元114,域不变性损失函数配置在域不变性特征单元113。增强协同对抗网络是在协同对抗网络的基础上增加了从目标领域数据中选择训练数据并设置权重用于训练的过程。下面以图像分类器为例描述本申请实施例提供的训练方法。
如图13所示,输入源领域图像数据301和目标领域图像数据302。源领域图像数据301是标注有类别标签的图像数据,目标领域图像数据302是未标注有类别标签的图像数据,类别标签用于指示图像数据的类别,训练后的图像分类器用于预测图像数据的类别。图像数据可以是图片或视频流,也可以是其他图像数据形式。源领域图像数据301和目标领域图像数据302分别对应各自的领域标签,领域标签用于指示图像数据所属的领域。源领域图像数据301与目标领域图像数据302存在差异(例如上面应用场景实施例给出的示例),体现在数学表达上则是数据分布不同。
低层特征提取303部分
源领域图像数据301和目标领域图像数据302均经过低层特征提取303得到各数据对应的低层特征。低层特征提取303对应低层特征提取子单元1111,可利用CNN进行卷积预算提取图像数据中的低层特征。
具体地,低层特征提取303的输入数据包括源领域图像数据301,可以表示为其中为源领域图像数据中的第i个,为其类别标签,Ns为源领域图像数据中样本的数量。相应地,目标领域图像数据301可以表示为没有类别标签。低层特征提取303可以使用一系列卷积层、规范层、下采样层实现,用Fk(xi;θk)表示,其中k为低层特征提取303的层数,θk为低层特征提取303的参数。
高层特征提取304部分
高层特征提取304是在低层特征提取303的基础上对低层特征的进一步的处理,可选的,高层特征提取304对应高层特征提取子单元1112,可以利用CNN进行卷积预算提取图像数据中高层特征,与低层特征提取303一样具体地可以使用一系列卷积层、规范层、下采样层实现,可以用Fm(xi;θm)表示,其中m即为特征提取层的总层数。
图像分类305针对层特征提取304输入的高层特征,输出预测的类别信息,可以表示为C:f→yi,也可以表示为一个图像分类器C(F(xi;ΘF),c),其中c为图像分类器的参数。图像分类可以扩展到多种计算机视觉任务,包括检测、识别、分割等。另外,根据图像分类305的输出与图像数据的类别标签(对应图13中的源数据类别标签)定义分类损失函数(对应第三损失函数),以对图像分类305中的参数进行优化。这个分类损失函数可以定义为图像分类305输出与对应类别标签的交叉熵。由于源领域图像数据301已有类别标签,可以定义源领域图像数据301的分类损失函数为Lsrc(C(F(xi;ΘF),c),yi s)。通过迭代优化图像分类305的从参数使得该分类损失函数最小化,得到图像分类器。需要注意是:这里的图像分类器不包含特征提取部分的,在实际中,该图像分类器需要配合特征提取(低层特征提取303和高层特征提取304)使用,训练的过程中实际是对图像分类305(图像分类器)、低层特征提取303和高层特征提取304三者的参数进行更新优化。
域不变性306部分
为使得在源领域图像数据301上训练的图像分类器/模型能够在目标领域图像数据302上同样有较好的分类精度,图像分类器所利用的图像的高层特征应当具有域不变性。为了实现这样的目的,域不变性306能够使得高层特征无法对领域进行区分,从而具有域不变性。具体地,域不变性306包括针对高层特征提取304设置的领域区分器,可以表示为D(F(xi;ΘF),w),其中w为领域区分器的参数。类似于图像分类器,也可以根据域不变性306的输出与领域标签定义一个域不变性损失函数LD(D(F(xi;ΘF),w),di)(对应第一损失函数)。与分类损失函数不同的是,为了使源领域图像数据301与目标领域图像数据302之间的高层特征不具有区分性,域不变性306通过梯度反向方法使得域不变性损失函数不是趋于最小化,而是的损失变大。梯度反向方法可以使用任一现有技术实现,此处不对梯度反向的具体方法做任何限制。与图像分类器一样,需要注意的是:这里的领域区分器不包含特征提取部分的,在实际中,该领域区分器需要配合特征提取(低层特征提取303和高层特征提取304)使用,训练的过程中实际是对域不变性305中的领域区分器、低层特征提取303和高层特征提取304三者的参数进行更新优化。
值得注意的,上面的域不变性损失函数与分类损失函数需要同时优化,在训练的过程中组成一个对抗网络,并使用多任务优化方法来解。
域区分性307部分
一般而言,图像的低层特征包括图像的边缘、角点等,这些特征往往是跟领域有较大关系,可以用于领域区分。若在训练中只强调域不变性特征,使得在源领域图像数据301与目标领域图像数据302之间的高层特征分布类似,从而在源领域图像数据上训练得到的图像分类模型在目标领域的图像数据也有较好的效果,则同样使得低层特征也具有了域不变性,丢失了大量域区分性特征。为此可以针对低层特征提取303,根据域区分性307的输出与领域标签定义一个领域区分性损失函数(对应第二损失函数),使得提取到的低层特征具有域区分性。具体而言,域区分性损失函数可以表示为LD(D(F(xi;θk),wk),di),其中k为所加损失函数的层数。
该域区分性损失函数与域不变性损失函数组在一起,则构成协同对抗网络,总体损失函数可以表示为:
其中为对于某一层的领域区分目标,λk为对k层损失函数的权重,为λm为对m层损失函数的权重,且λm取负值。在目标函数中,通过权重对特征的域区分性与域不变性进行平衡,并且使用基于梯度的方法在网络训练过程中对参数进行优化,从而提高网络的性能。
样本数据选择308部分
为进一步提高训练的图像分类模型在目标领域的图像数据上的分类精度,可以使用目标领域的图像数据用于图像分类模型的训练。由于目标领域图像数据302原本没有类别标签,可以将目标领域图像数据302通过低层特征提取303、高层特征提取304得到的高层特征,输入图像分类305的输出作为目标领域图像数据302的标签。也就是使用上面描述的方法训练后的图像分类模型在目标领域图像数据302上的输出作为其类别标签,再将拥有类别标签的目标领域图像数据作为新的训练数据加入之后的迭代训练过程,具体的参见图12对应实施例中1)-6)。但是并不是所有的通过图像分类模型获得类别标签的目标领域图像数据都可以作为目标领域训练样本数据。图像分类模型对于样本数据的输出包括类别信息和置信度,当输出的置信度高时,输出类别信息正确的可能性更大,因此,可以选择置信度高的目标领域图像数据作为目标领域训练样本数据。具体地,首先设置一个阈值;再从目标领域图像数据302中选择根据置信度大于该阈值的图像数据作为目标领域训练样本数据。另外,考虑到在训练的过程中,图像分类模型的精度较低。随着训练次数的增加,分类精度会上升,故该阈值的设置与模型的精度有关,即根据当前得到图像分类模型的精度设置自适应的阈值。具体阈值设置可以参见图12对应实施例的相关描述,在此不再赘述。
权重设置309部分
根据域不变性306中域领域区分器的输出,对已选择的目标领域训练样本数据设置权重。当目标领域训练样本数据不易被领域区分器区分时,则目标领域训练样本数据的分布比较接近于源领域图像数据与目标领域图像数据之间,对图像分类模型的训练更有帮助,可以给较大权重。若目标领域训练样本数据很容易被领域区分器区分开,则该目标领域训练样本数据对于图像分类模型的训练价值较小,可以减小它在损失函数的权重。如图14所示,其中领域区分器输出为0.5的样本权重最大,两边的权重依次减小,当达到一定值时,权重为0。该权重可以使用如下式表示:
其中z为一个可以学习的参数,α是一个常数。基于这个公式,对样本的权重可以表示为
可选的,对靠近目标领域图像数据的目标领域训练样本数据的权重取较大值。可以采用多种方法设置此类权重,例如对上式中若则将权重设置为所对应的权重值:
通过目标领域训练样本数据选择与权重设置之后,可以针对目标领域训练样本数据建立分类损失函数,可以表示为
其中为经过之前训练后的图像分类器在目标领域训练样本数据上的输出。从而,基于增强协同对抗网络的总体损失函数由三部分构成,即在源领域图像数据上的分类损失函数、在低层特征与高层特征上的协同对抗损失函数以及在目标领域训练样本数据上的分类损失函数,可以表示为:
该总体损失函数可以使用基于随机梯度的反向传播方法进行优化,从而更新增强协同对抗网络中各部分的参数,训练图像分类模型,利用该图像分类模型用于目标领域图像数据的类别预测。在训练过程中,可以先使用源领域图像数据及类别标签,训练一个初始的协同对抗网络,在通过自适应目标领域训练样本数据选择308和权重设置309选择样本和设置权重后,与源领域图像数据共同再训练该初始的协同对抗网络。
需要注意的,图13中的低层特征提取303、高层特征提取304、图像分类305、域不变性306、域区分性307、样本数据选择308和权重设置309可以看作是增强协同对抗网络的组成模块,也可以看作是基于增强协同对抗网络的训练方法中操作步骤。
本申请实施例提供了一种芯片硬件结构,如图15所示,上面本申请实施例中描述的基于卷积神经网络的算法/方法(图12对应的实施例和图13对应的实施例中涉及的算法/方法)可以全部或部分在图15所示的NPU芯片中实现。
神经网络处理器NPU 50NPU作为协处理器挂载到主CPU(Host CPU)上,由Host CPU分配任务。NPU的核心部分为运算电路50,通过控制器504控制运算电路503提取存储器中的矩阵数据并进行乘法运算。
在一些实现中,运算电路503内部包括多个处理单元(Process Engine,PE)。在一些实现中,运算电路503是二维脉动阵列。运算电路503还可以是一维脉动阵列或者能够执行例如乘法和加法这样的数学运算的其它电子线路。在一些实现中,运算电路503是通用的矩阵处理器。
举例来说,假设有输入矩阵A,权重矩阵B,输出矩阵C。运算电路从权重存储器502中取矩阵B相应的数据,并缓存在运算电路中每一个PE上。运算电路从输入存储器501中取矩阵A数据与矩阵B进行矩阵运算,得到的矩阵的部分结果或最终结果,保存在累加器508accumulator中。
统一存储器506用于存放输入数据以及输出数据。权重数据直接通过存储单元访问控制器505Direct Memory Access Controller,DMAC被搬运到权重存储器502中。输入数据也通过DMAC被搬运到统一存储器506中。
BIU为Bus Interface Unit即,总线接口单元510,用于AXI总线与DMAC和取指存储器509Instruction Fetch Buffer的交互。
总线接口单元510(Bus Interface Unit,简称BIU),用于取指存储器509从外部存储器获取指令,还用于存储单元访问控制器505从外部存储器获取输入矩阵A或者权重矩阵B的原数据。
DMAC主要用于将外部存储器DDR中的输入数据搬运到统一存储器506或将权重数据搬运到权重存储器502中或将输入数据数据搬运到输入存储器501中。
向量计算单元507多个运算处理单元,在需要的情况下,对运算电路的输出做进一步处理,如向量乘,向量加,指数运算,对数运算,大小比较等等。主要用于神经网络中非卷积/FC层网络计算,如Pooling(池化),Batch Normalization(批归一化),Local ResponseNormalization(局部响应归一化)等。
在一些实现种,向量计算单元能507将经处理的输出的向量存储到统一缓存器506。例如,向量计算单元507可以将非线性函数应用到运算电路503的输出,例如累加值的向量,用以生成激活值。在一些实现中,向量计算单元507生成归一化的值、合并值,或二者均有。在一些实现中,处理过的输出的向量能够用作到运算电路503的激活输入,例如用于在神经网络中的后续层中的使用。
控制器504连接的取指存储器(instruction fetch buffer)509,用于存储控制器504使用的指令;
统一存储器506,输入存储器501,权重存储器502以及取指存储器509均为On-Chip存储器。外部存储器私有于该NPU硬件架构。
其中,卷积神经网络中各层的运算可以由矩阵计算单元212或向量计算单元507执行。
本申请实施例提供了一种训练设备410,如图16所示包括:处理器412、通信接口413、存储器411。可选地,训练设备410还可以包括总线414。其中,通信接口413、处理器412以及存储器411可以通过总线414相互连接;总线414可以是外设部件互连标准(英文:Peripheral Component Interconnect,简称PCI)总线或扩展工业标准结构(英文:Extended Industry Standard Architecture,简称EISA)总线等。上述总线414可以分为地址总线、数据总线、控制总线等。为便于表示,图16中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
上述图16所示的训练设备可以用于替代训练装置110以执行上面方法实施例中描述的方法,具体实现还可以对应参照上面方法实施例的相应描述,此处不再赘述。
结合本发明实施例公开内容所描述的方法或者算法的步骤可以硬件的方式来实现,也可以是由处理器执行软件指令的方式来实现。软件指令可以由相应的软件模块组成,软件模块可以被存放于随机存取存储器(英文:Random Access Memory,RAM)、闪存、只读存储器(英文:Read Only Memory,ROM)、可擦除可编程只读存储器(英文:ErasableProgrammable ROM,EPROM)、电可擦可编程只读存储器(英文:Electrically EPROM,EEPROM)、寄存器、硬盘、移动硬盘、只读光盘(CD-ROM)或者本领域熟知的任何其它形式的存储介质中。一种示例性的存储介质耦合至处理器,从而使处理器能够从该存储介质读取信息,且可向该存储介质写入信息。当然,存储介质也可以是处理器的组成部分。处理器和存储介质可以位于ASIC中。另外,该ASIC可以位于网络设备中。当然,处理器和存储介质也可以作为分立组件存在于终端设备中。
按照本申请实施例提供的训练方法,在公开的标准数据集Office-31与ImageCLEF-DA上做迁移学习的测试。Office-31是物体识别的一个标准数据集,共包含4110张图片,其中有31个类别的物体。它包含四个领域的数据Amazon(A),Webcam(W),和Dlsr(D)。这里测试从其中任一领域迁移到另外一个领域的学习过程,评估迁移学习的精度。
ImageCLEF-DA是CLEF 2014年挑战赛的数据集,其中包含了三个领域的数据,即ImageNet ILSVRC2012(I),Bing(B),与Pascal VOC 2012(P)。每一个领域的数据都包含12个类别的数据,每个类别有50张图片。同样,这里测试从一个领域迁移到另外一个领域的识别精度,共6种迁移方式。
图17A和图17B给出了基于本申请实施例提供的方法与另外几种方法,如ResNet50、DANN、JAN的方法等的测试精度,并同时给出了平均迁移学习精度。可以看到,基于协同对抗网络的算法(CAN)获得了除JAN外最好的效果,而增强协同对抗网络(本发明)获得了最优效果,平均迁移精度比当前最好方法JAN高2~3个百分点。
因此,本申请实施例提供的基于增强协同对抗网络的训练方法基于高层特征提取和低层特征提取分别建立了域不变性损失函数和域区分性损失函数,在保证高层特征的域不变性特征的同时保留了低层特征中的域区分性特征,能够提高图像分类器应用到目标领域的图像分类预测的精度。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,上述的程序可存储于计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。而前述的存储介质包括:ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本发明的几个实施例,本领域的技术人员依据申请文件公开的可以对本发明进行各种改动或变型而不脱离本发明的精神和范围。
Claims (22)
1.一种深度神经网络的训练方法,其特征在于,包括:
提取源领域数据和目标领域数据中各样本数据的低层特征和高层特征,所述目标领域数据与所述源领域数据的数据分布不同;
基于所述源领域数据和所述目标领域数据中各样本数据的高层特征和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失;
基于所述源领域数据和所述目标领域数据中各样本数据的低层特征和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失;
基于所述源领域数据中的样本数据的高层特征和对应的样本标签,通过第三损失函数计算所述源领域数据中的样本数据对应的第三损失;
根据所述第一损失、所述第二损失和所述第三损失更新目标深度神经网络的参数,其中所述第一损失的梯度经过梯度反向,所述梯度反向可实现反向传导梯度使损失变大。
2.根据权利要求1所述的训练方法,其特征在于,所述目标深度神经网络包括特征提取模块、任务模块、域不变性特征模块和域区分性特征模块,所述特征提取模块包括至少一个低层特征网络层和高层特征网络层,所述至少一个低层特征网络层中的任一个低层特征网络层可用于提取低层特征,所述高层特征网络层用于提取高层特征,所述域不变性特征模块用于增强所述特征提取模块提取的高层特征的领域不变性,所述域区分性特征模块用于增强所述特征提取模块提取的低层特征的领域区分性;
其中,所述根据所述第一损失、所述第二损失和所述第三损失更新目标深度神经网络的参数包括:
根据所述第一损失、所述第二损失和所述第三损失计算总损失;
根据所述总损失更新所述特征提取模块的参数、所述任务模块的参数、所述域不变性特征模块的参数和所述域区分性特征模块的参数。
3.根据权利要求2所述的训练方法,其特征在于,所述基于所述源领域数据和所述目标领域数据中各样本数据的高层特征和对应的领域标签,通过第一损失函数分别计算各样本数据对应的第一损失,包括:将所述源领域数据和所述目标领域数据中的各样本数据的高层特征输入所述域不变性特征模块得到各样本数据对应的第一结果;根据所述源领域数据和所述目标领域数据中的各样本数据对应的第一结果和对应的领域标签,通过所述第一损失函数分别计算各样本数据对应的第一损失;
所述基于所述源领域数据和所述目标领域数据中各样本数据的低层特征和对应的领域标签,通过第二损失函数分别计算各样本数据对应的第二损失,包括:将所述源领域数据和所述目标领域数据中的各样本数据的低层特征输入所述域区分性特征模块得到各样本数据对应的第二结果;根据所述源领域数据和所述目标领域数据中的各样本数据对应的第二结果和对应的领域标签,通过所述第二损失函数分别计算各样本数据对应的第二损失;
所述基于所述源领域数据中的样本数据的高层特征和对应的样本标签,通过第三损失函数计算所述源领域数据中的样本数据对应的第三损失,包括:将所述源领域数据中的样本数据的高层特征输入所述任务模块得到所述源领域数据中的样本数据对应的第三结果;基于所述源领域数据中的样本数据对应的第三结果和对应的样本标签,通过第三损失函数计算所述源领域数据中的样本数据对应的第三损失。
4.根据权利要求2或3所述的训练方法,其特征在于,所述域不变性特征模块还包括:梯度反向模块;
所述训练方法还包括:
通过所述梯度反向模块对所述第一损失的梯度进行所述梯度反向。
5.根据权利要求3或4所述的训练方法,其特征在于,还包括:
将所述目标领域数据中样本数据的高层特征输入所述任务模块,得到对应的预测样本标签和对应的置信度;
根据所述目标领域数据中样本数据对应的置信度从所述目标领域数据中选定目标领域训练样本数据,所述目标领域训练样本数据为所述目标领域数据中对应的置信度满足预设条件的样本数据。
6.根据权利要求5所述的训练方法,其特征在于,还包括:
根据所述目标领域训练样本数据对应的第一结果设置所述目标领域训练样本数据的权重。
7.根据权利要求6所述的训练方法,其特征在于,所述根据所述目标领域训练样本数据对应的第一结果设置所述目标领域训练样本数据的权重包括:
根据所述目标领域训练样本数据对应的第一结果与领域标签的相似度,设置所述目标领域训练样本数据的权重,所述相似度表示第一结果与领域标签的差值大小。
8.根据权利要求7所述的训练方法,其特征在于,所述根据所述目标领域训练样本数据对应的第一结果与领域标签的相似度,设置所述目标领域训练样本数据的权重包括:
计算所述目标领域训练样本数据对应的第一结果与源领域的领域标签的第一差值,以及所述目标领域训练样本数据对应的第一结果与目标领域的领域标签的第二差值;
若所述第一差值的绝对值大于所述第二差值的绝对值,则设置所述目标领域训练样本数据的权重为较小的值,否则,设置所述目标领域训练样本数据的权重为较大的值。
9.根据权利要求7所述的训练方法,其特征在于,若所述目标领域训练样本数据对应的第一结果为第一领域标签值至第二领域标签值取值范围中的中间值,则设置所述目标领域训练样本数据的权重为最大值,所述第一领域标签值为源领域的领域标签对应的值,所述第二领域标签值为目标领域的领域标签对应的值。
10.根据权利要求5-9任选一所述的训练方法,其特征在于,在所述根据所述目标领域数据中样本数据对应的置信度从所述目标领域数据中选定目标领域训练样本数据之前,还包括:
根据任务模型的精度设置自适应阈值,所述任务模型包括所述特征提取模块和所述任务模块,所述自适应阈值与所述任务模型的精度正相关;
其中,所述预设条件为置信度大于或等于所述自适应阈值。
11.根据权利要求10所述的训练方法,其特征在于,所述自适应阈值通过下面逻辑函数计算:
其中,所述Tc为所述自适应阈值,所述A为所述任务模型的精度,λc为用于控制所述逻辑函数的倾斜度的超参数。
12.根据权利要求5-11任选一所述的训练方法,其特征在于,所述训练方法还包括:
通过所述特征提取模块提取所述目标领域训练样本数据的低层特征和高层特征;
基于所述目标领域训练样本数据的高层特征和对应的领域标签,通过所述第一损失函数计算所述目标领域训练样本数据对应的第一损失;
基于所述目标领域训练样本数据的低层特征和对应的领域标签,通过所述第二损失函数计算所述目标领域训练样本数据对应的第二损失;
基于所述目标领域训练样本数据的高层特征和对应的预测样本标签,通过所述第三损失函数计算所述目标领域训练样本数据对应的第三损失;
根据所述目标领域训练样本数据对应的第一损失、第二损失和第三损失计算所述目标领域训练样本数据对应的总损失,其中,所述目标领域训练样本数据对应的第一损失的梯度经过梯度反向;
根据所述目标领域训练样本数据对应的总损失和所述目标领域训练样本数据的权重,更新所述特征提取模块的参数、所述任务模块的参数、所述域不变性特征模块的参数和所述域区分性特征模块的参数。
13.根据权利要求12所述的训练方法,其特征在于,所述基于所述目标领域训练样本数据的高层特征和对应的领域标签,通过所述第一损失函数计算所述目标领域训练样本数据对应的第一损失包括:将所述目标领域训练样本数据的高层特征输入所述域不变性特征模块得到所述目标领域训练样本数据对应的第一结果;根据所述目标领域训练样本数据对应的第一结果和对应的领域标签,通过所述第一损失函数计算所述目标领域训练样本数据对应的第一损失;
所述基于所述目标领域训练样本数据的低层特征和对应的领域标签,通过所述第二损失函数计算所述目标领域训练样本数据对应的第二损失包括:将所述目标领域训练样本数据的低层特征输入所述域区分性特征模块得到所述目标领域训练样本数据对应的第二结果;根据所述目标领域训练样本数据对应的第二结果和对应的领域标签,通过所述第二损失函数计算所述目标领域训练样本数据对应的第二损失;
所述基于所述目标领域训练样本数据的高层特征和对应的预测样本标签,通过第三损失函数计算所述目标领域训练样本数据对应的第三损失,包括:将所述目标领域训练样本数据的高层特征输入所述任务模块得到所述目标领域训练样本数据对应的第三结果;基于所述目标领域训练样本数据对应的第三结果和对应的预测样本标签,通过所述第三损失函数计算所述目标领域训练样本数据对应的第三损失。
14.一种训练设备,其特征在于,包括存储器及与所述存储器耦合的处理器;所述存储器用于存储指令,所述处理器用于执行所述指令;其中,所述处理器执行所述指令时执行如上权利要求1至13中任一项所述的方法。
15.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至13任一项所述方法。
16.一种增强协同对抗网络,其特征在于,所述增强协同对抗网络基于卷积神经网络CNN构建,包括:
特征提取模块,用于提取源领域数据和目标领域数据中各样本数据的低层特征和高层特征,所述目标领域数据与所述源领域数据的数据分布不同;
任务模块,用于接收所述特征提取模块输出的高层特征且通过第三损失函数分别计算各样本数据对应的第三损失,所述第三损失用于更新所述特征提取模块和所述任务模块的参数;
域不变性模块,用于接收所述特征提取模块输出的高层特征且通过第一损失函数分别计算各样本数据对应的第一损失,所述第一损失用于更新所述特征提取模块和所述域不变性模块的参数,使得所述特征提取模块输出的高层特征具有域不变性;
域区分性模块,用于接收所述特征提取模块输出的低层特征且通过第二损失函数分别计算各样本数据对应的第二损失,所述第二损失用于更新所述特征提取模块和所述域区分性模块的参数,使得所述特征提取模块输出的低层特征具有域区分性。
17.根据权利要求16所述的增强协同对抗网络,其特征在于,还包括:样本数据选择模块,用于根据所述目标领域数据中样本数据对应的置信度从所述目标领域数据中选定目标领域训练样本数据,所述目标领域数据中样本数据对应的置信度通过将所述目标领域数据中样本数据的高层特征输入所述任务模块得到,所述目标领域训练样本数据为所述目标领域数据中对应的置信度满足预设条件的样本数据。
18.根据权利要求17所述的增强协同对抗网络,其特征在于,所述样本数据选择模块还用于:根据任务模型的精度设置自适应阈值,所述任务模型包括所述特征提取模块和所述任务模块,所述自适应阈值与所述任务模型的精度正相关;其中,所述预设条件为置信度大于或等于所述自适应阈值。
19.根据权利要求17或18所述的增强协同对抗网络,其特征在于,还包括权重设置模块,用于根据所述目标领域训练样本数据对应的第一结果设置所述目标领域训练样本数据的权重。
20.根据权利要求19所述的增强协同对抗网络,其特征在于,所述权重设置模块具体用于:根据所述目标领域训练样本数据对应的第一结果与领域标签的相似度,设置所述目标领域训练样本数据的权重,所述相似度表示第一结果与领域标签的差值大小。
21.根据权利要求20所述的增强协同对抗网络,其特征在于,所述权重设置模块具体用于:计算所述目标领域训练样本数据对应的第一结果与源领域的领域标签的第一差值,以及所述目标领域训练样本数据对应的第一结果与目标领域的领域标签的第二差值;若第一差值的绝对值大于所述第二差值的绝对值,则设置所述目标领域训练样本数据的权重为较小的值,否则,设置所述目标领域训练样本数据的权重为较大的值。
22.根据权利要求20所述的增强协同对抗网络,其特征在于,所述权重设置模块具体用于:若所述目标领域训练样本数据对应的第一结果为第一领域标签值至第二领域标签值取值范围中的中间值,则设置所述目标领域训练样本数据的权重为最大值,所述第一领域标签值为源领域的领域标签对应的值,所述第二领域标签值为目标领域的领域标签对应的值。
Priority Applications (4)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810554459.4A CN109902798A (zh) | 2018-05-31 | 2018-05-31 | 深度神经网络的训练方法和装置 |
EP19812148.5A EP3757905A4 (en) | 2018-05-31 | 2019-05-28 | DEEP NEURONAL NETWORK TRAINING PROCESS AND APPARATUS |
PCT/CN2019/088846 WO2019228358A1 (zh) | 2018-05-31 | 2019-05-28 | 深度神经网络的训练方法和装置 |
US17/033,316 US20210012198A1 (en) | 2018-05-31 | 2020-09-25 | Method for training deep neural network and apparatus |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810554459.4A CN109902798A (zh) | 2018-05-31 | 2018-05-31 | 深度神经网络的训练方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN109902798A true CN109902798A (zh) | 2019-06-18 |
Family
ID=66943222
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201810554459.4A Pending CN109902798A (zh) | 2018-05-31 | 2018-05-31 | 深度神经网络的训练方法和装置 |
Country Status (4)
Country | Link |
---|---|
US (1) | US20210012198A1 (zh) |
EP (1) | EP3757905A4 (zh) |
CN (1) | CN109902798A (zh) |
WO (1) | WO2019228358A1 (zh) |
Cited By (30)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110674648A (zh) * | 2019-09-29 | 2020-01-10 | 厦门大学 | 基于迭代式双向迁移的神经网络机器翻译模型 |
CN111178401A (zh) * | 2019-12-16 | 2020-05-19 | 上海航天控制技术研究所 | 一种基于多层对抗网络的空间目标分类方法 |
CN111239137A (zh) * | 2020-01-09 | 2020-06-05 | 江南大学 | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 |
CN111442926A (zh) * | 2020-01-11 | 2020-07-24 | 哈尔滨理工大学 | 一种基于深层特征迁移的变负载下不同型号滚动轴承故障诊断方法 |
CN111444958A (zh) * | 2020-03-25 | 2020-07-24 | 北京百度网讯科技有限公司 | 一种模型迁移训练方法、装置、设备及存储介质 |
CN111523649A (zh) * | 2020-05-09 | 2020-08-11 | 支付宝(杭州)信息技术有限公司 | 针对业务模型进行数据预处理的方法及装置 |
CN111598124A (zh) * | 2020-04-07 | 2020-08-28 | 深圳市商汤科技有限公司 | 图像处理及装置、处理器、电子设备、存储介质 |
CN111723691A (zh) * | 2020-06-03 | 2020-09-29 | 北京的卢深视科技有限公司 | 一种三维人脸识别方法、装置、电子设备及存储介质 |
CN111783844A (zh) * | 2020-06-10 | 2020-10-16 | 东莞正扬电子机械有限公司 | 基于深度学习的目标检测模型训练方法、设备及存储介质 |
CN112052818A (zh) * | 2020-09-15 | 2020-12-08 | 浙江智慧视频安防创新中心有限公司 | 无监督域适应的行人检测方法、系统及存储介质 |
CN112426161A (zh) * | 2020-11-17 | 2021-03-02 | 浙江大学 | 一种基于领域自适应的时变脑电特征提取方法 |
CN112528631A (zh) * | 2020-12-03 | 2021-03-19 | 上海谷均教育科技有限公司 | 一种基于深度学习算法的智能伴奏系统 |
CN112580733A (zh) * | 2020-12-25 | 2021-03-30 | 北京百度网讯科技有限公司 | 分类模型的训练方法、装置、设备以及存储介质 |
CN112633459A (zh) * | 2019-09-24 | 2021-04-09 | 华为技术有限公司 | 训练神经网络的方法、数据处理方法和相关装置 |
CN112989702A (zh) * | 2021-03-25 | 2021-06-18 | 河北工业大学 | 一种装备性能分析与预测的自学习方法 |
CN112990298A (zh) * | 2021-03-11 | 2021-06-18 | 北京中科虹霸科技有限公司 | 关键点检测模型训练方法、关键点检测方法及装置 |
CN113031437A (zh) * | 2021-02-26 | 2021-06-25 | 同济大学 | 一种基于动态模型强化学习的倒水服务机器人控制方法 |
CN113239975A (zh) * | 2021-04-21 | 2021-08-10 | 洛阳青鸟网络科技有限公司 | 一种基于神经网络的目标检测方法和装置 |
WO2021169366A1 (zh) * | 2020-02-25 | 2021-09-02 | 华为技术有限公司 | 数据增强方法和装置 |
CN113673570A (zh) * | 2021-07-21 | 2021-11-19 | 南京旭锐软件科技有限公司 | 电子器件图片分类模型的训练方法、装置及设备 |
US11200883B2 (en) | 2020-01-10 | 2021-12-14 | International Business Machines Corporation | Implementing a domain adaptive semantic role labeler |
CN113807183A (zh) * | 2021-08-17 | 2021-12-17 | 华为技术有限公司 | 模型训练方法及相关设备 |
WO2021255569A1 (en) * | 2020-06-18 | 2021-12-23 | International Business Machines Corporation | Drift regularization to counteract variation in drift coefficients for analog accelerators |
CN114726394A (zh) * | 2022-03-01 | 2022-07-08 | 深圳前海梵天通信技术有限公司 | 一种智能通信系统的训练方法及智能通信系统 |
WO2022151553A1 (zh) * | 2021-01-12 | 2022-07-21 | 之江实验室 | 一种基于域-不变特征的元-知识微调方法及平台 |
GB2608344A (en) * | 2021-01-12 | 2022-12-28 | Zhejiang Lab | Domain-invariant feature-based meta-knowledge fine-tuning method and platform |
CN116578924A (zh) * | 2023-07-12 | 2023-08-11 | 太极计算机股份有限公司 | 一种用于机器学习分类的网络任务优化方法及系统 |
CN116737607A (zh) * | 2023-08-16 | 2023-09-12 | 之江实验室 | 样本数据缓存方法、系统、计算机设备和存储介质 |
CN116882486A (zh) * | 2023-09-05 | 2023-10-13 | 浙江大华技术股份有限公司 | 一种迁移学习权重的构建方法和装置及设备 |
WO2023207228A1 (zh) * | 2022-04-28 | 2023-11-02 | 重庆长安汽车股份有限公司 | 一种基于隐私数据保护的智能网联汽车数据训练方法、电子设备及计算机可读存储介质 |
Families Citing this family (56)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11087142B2 (en) * | 2018-09-13 | 2021-08-10 | Nec Corporation | Recognizing fine-grained objects in surveillance camera images |
US11222210B2 (en) * | 2018-11-13 | 2022-01-11 | Nec Corporation | Attention and warping based domain adaptation for videos |
GB201819434D0 (en) * | 2018-11-29 | 2019-01-16 | Kheiron Medical Tech Ltd | Domain adaptation |
KR102039138B1 (ko) * | 2019-04-02 | 2019-10-31 | 주식회사 루닛 | 적대적 학습에 기반한 도메인 어댑테이션 방법 및 그 장치 |
KR20210074748A (ko) * | 2019-12-12 | 2021-06-22 | 삼성전자주식회사 | 도메인 적응에 기반한 네트워크의 트레이닝 방법, 동작 방법 및 동작 장치 |
US11537901B2 (en) * | 2019-12-31 | 2022-12-27 | Robert Bosch Gmbh | System and method for unsupervised domain adaptation with mixup training |
WO2021136947A1 (en) | 2020-01-03 | 2021-07-08 | Tractable Ltd | Vehicle damage state determination method |
CN110852450B (zh) * | 2020-01-15 | 2020-04-14 | 支付宝(杭州)信息技术有限公司 | 识别对抗样本以保护模型安全的方法及装置 |
CN111461191B (zh) * | 2020-03-25 | 2024-01-23 | 杭州跨视科技有限公司 | 为模型训练确定图像样本集的方法、装置和电子设备 |
CN111832605B (zh) * | 2020-05-22 | 2023-12-08 | 北京嘀嘀无限科技发展有限公司 | 无监督图像分类模型的训练方法、装置和电子设备 |
CN111680754B (zh) * | 2020-06-11 | 2023-09-19 | 抖音视界有限公司 | 图像分类方法、装置、电子设备及计算机可读存储介质 |
CN111914912B (zh) * | 2020-07-16 | 2023-06-13 | 天津大学 | 一种基于孪生条件对抗网络的跨域多视目标识别方法 |
CN112115976B (zh) * | 2020-08-20 | 2023-12-08 | 北京嘀嘀无限科技发展有限公司 | 模型训练方法、模型训练装置、存储介质和电子设备 |
CN112001398B (zh) * | 2020-08-26 | 2024-04-12 | 科大讯飞股份有限公司 | 域适应方法、装置、设备、图像处理方法及存储介质 |
US20220101068A1 (en) * | 2020-09-30 | 2022-03-31 | International Business Machines Corporation | Outlier detection in a deep neural network using t-way feature combinations |
CN112241452B (zh) * | 2020-10-16 | 2024-01-05 | 百度(中国)有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN112364860A (zh) * | 2020-11-05 | 2021-02-12 | 北京字跳网络技术有限公司 | 字符识别模型的训练方法、装置和电子设备 |
CN112633579B (zh) * | 2020-12-24 | 2024-01-12 | 中国科学技术大学 | 一种基于域对抗的交通流迁移预测方法 |
CN112634048B (zh) * | 2020-12-30 | 2023-06-13 | 第四范式(北京)技术有限公司 | 一种反洗钱模型的训练方法及装置 |
CN112749758B (zh) * | 2021-01-21 | 2023-08-11 | 北京百度网讯科技有限公司 | 图像处理方法、神经网络的训练方法、装置、设备和介质 |
CN112784776B (zh) * | 2021-01-26 | 2022-07-08 | 山西三友和智慧信息技术股份有限公司 | 一种基于改进残差网络的bpd面部情绪识别方法 |
CN112818833B (zh) * | 2021-01-29 | 2024-04-12 | 中能国际建筑投资集团有限公司 | 基于深度学习的人脸多任务检测方法、系统、装置及介质 |
CN112861977B (zh) * | 2021-02-19 | 2024-01-26 | 中国人民武装警察部队工程大学 | 迁移学习数据处理方法、系统、介质、设备、终端及应用 |
CN113065633A (zh) * | 2021-02-26 | 2021-07-02 | 华为技术有限公司 | 一种模型训练方法及其相关联设备 |
CN112884147B (zh) * | 2021-02-26 | 2023-11-28 | 上海商汤智能科技有限公司 | 神经网络训练方法、图像处理方法、装置及电子设备 |
CN113052295B (zh) * | 2021-02-27 | 2024-04-12 | 华为技术有限公司 | 一种神经网络的训练方法、物体检测方法、装置及设备 |
CN112966345B (zh) * | 2021-03-03 | 2022-06-07 | 北京航空航天大学 | 基于对抗训练和迁移学习的旋转机械剩余寿命预测混合收缩方法 |
CN113033549B (zh) * | 2021-03-09 | 2022-09-20 | 北京百度网讯科技有限公司 | 定位图获取模型的训练方法和装置 |
CN113076834B (zh) * | 2021-03-25 | 2022-05-13 | 华中科技大学 | 旋转机械故障信息处理方法、处理系统、处理终端、介质 |
CN113158364B (zh) * | 2021-04-02 | 2024-03-22 | 中国农业大学 | 循环泵轴承故障检测方法及系统 |
CN113111776B (zh) * | 2021-04-12 | 2024-04-16 | 京东科技控股股份有限公司 | 对抗样本的生成方法、装置、设备及存储介质 |
CN113132931B (zh) * | 2021-04-16 | 2022-01-28 | 电子科技大学 | 一种基于参数预测的深度迁移室内定位方法 |
CN113286311B (zh) * | 2021-04-29 | 2024-04-12 | 沈阳工业大学 | 基于多传感器融合的分布式周界安防环境感知系统 |
CN113128478B (zh) * | 2021-05-18 | 2023-07-14 | 电子科技大学中山学院 | 模型训练方法、行人分析方法、装置、设备及存储介质 |
CN113158985A (zh) * | 2021-05-18 | 2021-07-23 | 深圳市创智链科技有限公司 | 一种分类识别的方法和设备 |
CN113269261B (zh) * | 2021-05-31 | 2024-03-12 | 国网福建省电力有限公司电力科学研究院 | 一种配网波形智能分类方法 |
WO2023275603A1 (en) * | 2021-06-28 | 2023-01-05 | Sensetime International Pte. Ltd. | Methods, apparatuses, devices and storage media for training object detection network and for detecting object |
AU2021240261A1 (en) * | 2021-06-28 | 2023-01-19 | Sensetime International Pte. Ltd. | Methods, apparatuses, devices and storage media for training object detection network and for detecting object |
CN113344119A (zh) * | 2021-06-28 | 2021-09-03 | 南京邮电大学 | 工业物联网复杂环境下的小样本烟雾监测方法 |
CN113505834A (zh) * | 2021-07-13 | 2021-10-15 | 阿波罗智能技术(北京)有限公司 | 训练检测模型、确定图像更新信息和更新高精地图的方法 |
CN113657651A (zh) * | 2021-07-27 | 2021-11-16 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 基于深度迁移学习的柴油车排放预测方法、介质及设备 |
CN113792576B (zh) * | 2021-07-27 | 2023-07-18 | 北京邮电大学 | 基于有监督域适应的人体行为识别方法、电子设备 |
CN113591736A (zh) * | 2021-08-03 | 2021-11-02 | 北京百度网讯科技有限公司 | 特征提取网络、活体检测模型的训练方法和活体检测方法 |
CN113610219A (zh) * | 2021-08-16 | 2021-11-05 | 中国石油大学(华东) | 一种基于动态残差的多源域自适应方法 |
CN113948093B (zh) * | 2021-10-19 | 2024-03-26 | 南京航空航天大学 | 一种基于无监督场景适应的说话人识别方法及系统 |
CN113989595A (zh) * | 2021-11-05 | 2022-01-28 | 西安交通大学 | 一种基于阴影模型的联邦多源域适应方法及系统 |
CN114048568B (zh) * | 2021-11-17 | 2024-04-09 | 大连理工大学 | 一种基于多源迁移融合收缩框架的旋转机械故障诊断方法 |
CN114202028B (zh) * | 2021-12-13 | 2023-04-28 | 四川大学 | 基于mamtl的滚动轴承寿命阶段识别方法 |
CN114354195A (zh) * | 2021-12-31 | 2022-04-15 | 南京工业大学 | 一种深度域自适应卷积网络的滚动轴承故障诊断方法 |
CN115049627B (zh) * | 2022-06-21 | 2023-06-20 | 江南大学 | 基于域自适应深度迁移网络的钢表面缺陷检测方法及系统 |
CN114998602B (zh) * | 2022-08-08 | 2022-12-30 | 中国科学技术大学 | 基于低置信度样本对比损失的域适应学习方法及系统 |
CN116468096B (zh) * | 2023-03-30 | 2024-01-02 | 之江实验室 | 一种模型训练方法、装置、设备及可读存储介质 |
CN117093929B (zh) * | 2023-07-06 | 2024-03-29 | 珠海市伊特高科技有限公司 | 基于无监督域自适应网络的截流过电压预测方法及装置 |
CN116630630B (zh) * | 2023-07-24 | 2023-12-15 | 深圳思谋信息科技有限公司 | 语义分割方法、装置、计算机设备及计算机可读存储介质 |
CN117152563A (zh) * | 2023-10-16 | 2023-12-01 | 华南师范大学 | 混合目标域自适应模型的训练方法、装置及计算机设备 |
CN117435916B (zh) * | 2023-12-18 | 2024-03-12 | 四川云实信息技术有限公司 | 航片ai解译中的自适应迁移学习方法 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170220951A1 (en) * | 2016-02-02 | 2017-08-03 | Xerox Corporation | Adapting multiple source classifiers in a target domain |
CN107633242A (zh) * | 2017-10-23 | 2018-01-26 | 广州视源电子科技股份有限公司 | 网络模型的训练方法、装置、设备和存储介质 |
CN107958287A (zh) * | 2017-11-23 | 2018-04-24 | 清华大学 | 面向跨界大数据分析的对抗迁移学习方法及系统 |
CN108009633A (zh) * | 2017-12-15 | 2018-05-08 | 清华大学 | 一种面向跨领域智能分析的多网络对抗学习方法和系统 |
-
2018
- 2018-05-31 CN CN201810554459.4A patent/CN109902798A/zh active Pending
-
2019
- 2019-05-28 EP EP19812148.5A patent/EP3757905A4/en active Pending
- 2019-05-28 WO PCT/CN2019/088846 patent/WO2019228358A1/zh unknown
-
2020
- 2020-09-25 US US17/033,316 patent/US20210012198A1/en active Pending
Cited By (48)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112633459A (zh) * | 2019-09-24 | 2021-04-09 | 华为技术有限公司 | 训练神经网络的方法、数据处理方法和相关装置 |
CN110674648A (zh) * | 2019-09-29 | 2020-01-10 | 厦门大学 | 基于迭代式双向迁移的神经网络机器翻译模型 |
CN110674648B (zh) * | 2019-09-29 | 2021-04-27 | 厦门大学 | 基于迭代式双向迁移的神经网络机器翻译模型 |
CN111178401A (zh) * | 2019-12-16 | 2020-05-19 | 上海航天控制技术研究所 | 一种基于多层对抗网络的空间目标分类方法 |
CN111178401B (zh) * | 2019-12-16 | 2023-09-12 | 上海航天控制技术研究所 | 一种基于多层对抗网络的空间目标分类方法 |
CN111239137A (zh) * | 2020-01-09 | 2020-06-05 | 江南大学 | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 |
CN111239137B (zh) * | 2020-01-09 | 2021-09-10 | 江南大学 | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 |
US11200883B2 (en) | 2020-01-10 | 2021-12-14 | International Business Machines Corporation | Implementing a domain adaptive semantic role labeler |
CN111442926A (zh) * | 2020-01-11 | 2020-07-24 | 哈尔滨理工大学 | 一种基于深层特征迁移的变负载下不同型号滚动轴承故障诊断方法 |
WO2021169366A1 (zh) * | 2020-02-25 | 2021-09-02 | 华为技术有限公司 | 数据增强方法和装置 |
CN111444958A (zh) * | 2020-03-25 | 2020-07-24 | 北京百度网讯科技有限公司 | 一种模型迁移训练方法、装置、设备及存储介质 |
CN111444958B (zh) * | 2020-03-25 | 2024-02-13 | 北京百度网讯科技有限公司 | 一种模型迁移训练方法、装置、设备及存储介质 |
CN111598124B (zh) * | 2020-04-07 | 2022-11-11 | 深圳市商汤科技有限公司 | 图像处理及装置、处理器、电子设备、存储介质 |
JP2022531763A (ja) * | 2020-04-07 | 2022-07-11 | シェンチェン センスタイム テクノロジー カンパニー リミテッド | 画像処理方法及び装置、プロセッサ、電子機器並びに記憶媒体 |
CN111598124A (zh) * | 2020-04-07 | 2020-08-28 | 深圳市商汤科技有限公司 | 图像处理及装置、处理器、电子设备、存储介质 |
WO2021203882A1 (zh) * | 2020-04-07 | 2021-10-14 | 深圳市商汤科技有限公司 | 姿态检测及视频处理方法、装置、电子设备和存储介质 |
CN111523649A (zh) * | 2020-05-09 | 2020-08-11 | 支付宝(杭州)信息技术有限公司 | 针对业务模型进行数据预处理的方法及装置 |
CN111723691B (zh) * | 2020-06-03 | 2023-10-17 | 合肥的卢深视科技有限公司 | 一种三维人脸识别方法、装置、电子设备及存储介质 |
CN111723691A (zh) * | 2020-06-03 | 2020-09-29 | 北京的卢深视科技有限公司 | 一种三维人脸识别方法、装置、电子设备及存储介质 |
CN111783844A (zh) * | 2020-06-10 | 2020-10-16 | 东莞正扬电子机械有限公司 | 基于深度学习的目标检测模型训练方法、设备及存储介质 |
GB2611681A (en) * | 2020-06-18 | 2023-04-12 | Ibm | Drift regularization to counteract variation in drift coefficients for analog accelerators |
WO2021255569A1 (en) * | 2020-06-18 | 2021-12-23 | International Business Machines Corporation | Drift regularization to counteract variation in drift coefficients for analog accelerators |
CN112052818A (zh) * | 2020-09-15 | 2020-12-08 | 浙江智慧视频安防创新中心有限公司 | 无监督域适应的行人检测方法、系统及存储介质 |
CN112052818B (zh) * | 2020-09-15 | 2024-03-22 | 浙江智慧视频安防创新中心有限公司 | 无监督域适应的行人检测方法、系统及存储介质 |
CN112426161A (zh) * | 2020-11-17 | 2021-03-02 | 浙江大学 | 一种基于领域自适应的时变脑电特征提取方法 |
CN112528631A (zh) * | 2020-12-03 | 2021-03-19 | 上海谷均教育科技有限公司 | 一种基于深度学习算法的智能伴奏系统 |
CN112528631B (zh) * | 2020-12-03 | 2022-08-09 | 上海谷均教育科技有限公司 | 一种基于深度学习算法的智能伴奏系统 |
CN112580733B (zh) * | 2020-12-25 | 2024-03-05 | 北京百度网讯科技有限公司 | 分类模型的训练方法、装置、设备以及存储介质 |
CN112580733A (zh) * | 2020-12-25 | 2021-03-30 | 北京百度网讯科技有限公司 | 分类模型的训练方法、装置、设备以及存储介质 |
GB2608344A (en) * | 2021-01-12 | 2022-12-28 | Zhejiang Lab | Domain-invariant feature-based meta-knowledge fine-tuning method and platform |
US11669741B2 (en) | 2021-01-12 | 2023-06-06 | Zhejiang Lab | Method and platform for meta-knowledge fine-tuning based on domain-invariant features |
WO2022151553A1 (zh) * | 2021-01-12 | 2022-07-21 | 之江实验室 | 一种基于域-不变特征的元-知识微调方法及平台 |
CN113031437B (zh) * | 2021-02-26 | 2022-10-25 | 同济大学 | 一种基于动态模型强化学习的倒水服务机器人控制方法 |
CN113031437A (zh) * | 2021-02-26 | 2021-06-25 | 同济大学 | 一种基于动态模型强化学习的倒水服务机器人控制方法 |
CN112990298A (zh) * | 2021-03-11 | 2021-06-18 | 北京中科虹霸科技有限公司 | 关键点检测模型训练方法、关键点检测方法及装置 |
CN112990298B (zh) * | 2021-03-11 | 2023-11-24 | 北京中科虹霸科技有限公司 | 关键点检测模型训练方法、关键点检测方法及装置 |
CN112989702A (zh) * | 2021-03-25 | 2021-06-18 | 河北工业大学 | 一种装备性能分析与预测的自学习方法 |
CN113239975A (zh) * | 2021-04-21 | 2021-08-10 | 洛阳青鸟网络科技有限公司 | 一种基于神经网络的目标检测方法和装置 |
CN113673570A (zh) * | 2021-07-21 | 2021-11-19 | 南京旭锐软件科技有限公司 | 电子器件图片分类模型的训练方法、装置及设备 |
CN113807183A (zh) * | 2021-08-17 | 2021-12-17 | 华为技术有限公司 | 模型训练方法及相关设备 |
CN114726394B (zh) * | 2022-03-01 | 2022-09-02 | 深圳前海梵天通信技术有限公司 | 一种智能通信系统的训练方法及智能通信系统 |
CN114726394A (zh) * | 2022-03-01 | 2022-07-08 | 深圳前海梵天通信技术有限公司 | 一种智能通信系统的训练方法及智能通信系统 |
WO2023207228A1 (zh) * | 2022-04-28 | 2023-11-02 | 重庆长安汽车股份有限公司 | 一种基于隐私数据保护的智能网联汽车数据训练方法、电子设备及计算机可读存储介质 |
CN116578924A (zh) * | 2023-07-12 | 2023-08-11 | 太极计算机股份有限公司 | 一种用于机器学习分类的网络任务优化方法及系统 |
CN116737607B (zh) * | 2023-08-16 | 2023-11-21 | 之江实验室 | 样本数据缓存方法、系统、计算机设备和存储介质 |
CN116737607A (zh) * | 2023-08-16 | 2023-09-12 | 之江实验室 | 样本数据缓存方法、系统、计算机设备和存储介质 |
CN116882486A (zh) * | 2023-09-05 | 2023-10-13 | 浙江大华技术股份有限公司 | 一种迁移学习权重的构建方法和装置及设备 |
CN116882486B (zh) * | 2023-09-05 | 2023-11-14 | 浙江大华技术股份有限公司 | 一种迁移学习权重的构建方法和装置及设备 |
Also Published As
Publication number | Publication date |
---|---|
WO2019228358A1 (zh) | 2019-12-05 |
EP3757905A4 (en) | 2021-04-28 |
US20210012198A1 (en) | 2021-01-14 |
EP3757905A1 (en) | 2020-12-30 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109902798A (zh) | 深度神经网络的训练方法和装置 | |
WO2021227726A1 (zh) | 面部检测、图像检测神经网络训练方法、装置和设备 | |
CN106970615B (zh) | 一种深度强化学习的实时在线路径规划方法 | |
CN109685819B (zh) | 一种基于特征增强的三维医学图像分割方法 | |
CN110298361A (zh) | 一种rgb-d图像的语义分割方法和系统 | |
CN108510194A (zh) | 风控模型训练方法、风险识别方法、装置、设备及介质 | |
CN105205453B (zh) | 基于深度自编码器的人眼检测和定位方法 | |
CN110532859A (zh) | 基于深度进化剪枝卷积网的遥感图像目标检测方法 | |
CN110222140A (zh) | 一种基于对抗学习和非对称哈希的跨模态检索方法 | |
WO2021022521A1 (zh) | 数据处理的方法、训练神经网络模型的方法及设备 | |
CN111291809B (zh) | 一种处理装置、方法及存储介质 | |
CN110096933A (zh) | 目标检测的方法、装置及系统 | |
CN106909924A (zh) | 一种基于深度显著性的遥感影像快速检索方法 | |
CN107833183A (zh) | 一种基于多任务深度神经网络的卫星图像同时超分辨和着色的方法 | |
CN107818302A (zh) | 基于卷积神经网络的非刚性多尺度物体检测方法 | |
CN109559300A (zh) | 图像处理方法、电子设备及计算机可读存储介质 | |
CN110263833A (zh) | 基于编码-解码结构的图像语义分割方法 | |
CN110134774A (zh) | 一种基于注意力决策的图像视觉问答模型、方法和系统 | |
CN109934115A (zh) | 人脸识别模型的构建方法、人脸识别方法及电子设备 | |
CN109817276A (zh) | 一种基于深度神经网络的蛋白质二级结构预测方法 | |
CN109464803A (zh) | 虚拟对象控制、模型训练方法、装置、存储介质和设备 | |
CN110222718B (zh) | 图像处理的方法及装置 | |
CN106909938A (zh) | 基于深度学习网络的视角无关性行为识别方法 | |
CN107292352A (zh) | 基于卷积神经网络的图像分类方法和装置 | |
CN111681178A (zh) | 一种基于知识蒸馏的图像去雾方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20190618 |