CN110443372B - 一种基于熵最小化的迁移学习方法及系统 - Google Patents
一种基于熵最小化的迁移学习方法及系统 Download PDFInfo
- Publication number
- CN110443372B CN110443372B CN201910623670.1A CN201910623670A CN110443372B CN 110443372 B CN110443372 B CN 110443372B CN 201910623670 A CN201910623670 A CN 201910623670A CN 110443372 B CN110443372 B CN 110443372B
- Authority
- CN
- China
- Prior art keywords
- network
- transfer learning
- loss function
- learning
- sample set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明提供了一种基于熵最小化的迁移学习方法及系统,涉及到深度学习,迁移学习,卷积神经网络等技术,所述方法包括:根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;提供了CPEM算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签,本发明提供了一种基于熵最小化的迁移学习方法及系统,在模型的损失函数上提出创新,相比一些现有的基于迁移学习的图像分类方法,分类精度得到了显著的提高。
Description
技术领域
本发明涉及一种计算机学习系统,具体的说是一种基于熵最小化的迁移学习方法及系统,属于计算机技术领域。
背景技术
数据集偏移是机器学习领域中一个不容忽视的问题。数据集是描述现实世界物体的片面化表述,在描述同一个物体集合的数据集上训练相同结构的模型,泛化能力往往存在偏差,效果不够理想。数据集偏移降低了模型在同一类物体上的泛化能力。对于真实世界而言,数据集偏移可以理解为模型在数据集上发生了过拟合问题。迁移学习尝试解决数据集偏移问题,基于源领域和目标领域间的相似性、差异性提高模型在目标领域数据上的表现。深度学习的发展使得深层次的学习模型具有更多需要学习的参数,也意味着需要大量的样本训练模型。另一方面寻找到数量足够的有标签样本支撑模型训练是极为困难的。
随着大规模数据的不断产生和依靠人力进行信息标注的困难,迁移学习方法逐渐成为机器学习领域中一项非常重要的研究课题。迁移学习方法旨在适配不同领域数据间的特征分布,提升不同领域间分类器迁移后的性能表现,解决目标域数据缺乏标注信息的难题。
近年来,大量的迁移学习方法使用熵最小化作为正则化技术,熵最小化也表现出了对于端对端迁移训练的简单有效性。但已有研究表明,对于无监督迁移学习,熵最小化仅仅是一个必要条件而非充分条件。没有其它辅助技术的协助,简单地使用熵最小化很有可能得到一些平凡解的结果。
发明内容
本发明的目的是提供一种基于熵最小化的迁移学习方法及系统,在模型的损失函数上提出创新,相比一些现有的基于迁移学习的图像分类方法,分类精度得到了显著的提高。
本发明的目的是这样实现的:一种基于熵最小化的迁移学习方法,包括如下步骤:
a.根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;
b.提供CPEM(category-preserved entropy minimization)算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
c.保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签。
作为本发明的进一步限定,所述方法还包括:
基于特征提取器与分类器,构建所述迁移学习网络;
基于预设的损失函数,对所述迁移学习网络进行学习。
作为本发明的进一步限定,所述预设的损失函数包括源领域样本集分类错误率损失函数、目标领域样本集分类结果的条件熵损失函数以及目标领域样本集分类结果类别分布的对称KL散度损失函数,所述基于预设的损失函数,对所述迁移学习网络进行学习,包括:
基于预设的损失函数,构建所述迁移学习网络的目标函数,以对所述迁移学习网络进行学习,其中,所述特征提取器和分类器的学习目标为使上述损失函数最小;
当所述迁移学习网络收敛或达到预设的学习次数后,结束对所述迁移学习网络的学习。
作为本发明的进一步限定,所述基于预设的损失函数,构建所述迁移学习网络的目标函数,以对所述迁移学习网络进行学习,包括:
基于所述预设的损失函数包括源领域样本集分类错误率损失函数、目标领域样本集分类结果的条件熵损失函数以及目标领域样本集分类结果类别分布的对称KL散度损失函数,联合构建所述迁移学习网络的目标函数;
利用反向传播算法对所述目标函数进行学习,以更新所述特征提取器与所述分类器。
作为本发明的进一步限定,所述目标领域样本集分类结果类别分布的对称 KL散度损失函数为:
其中T表示目标领域样本集任一batch样本,Lc(T)表示所述目标领域样本集batch分类结果类别分布的对称KL散度损失函数,dKL(·||·)表示KL散度损失函数,q表示目标域数据的类别分布,表示目标领域样本集batch分类结果的类别分布。
一种基于熵最小化的迁移学习系统,包括:
网络构造模块,根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;
训练模块,提供了CPEM(category-preserved entropy minimization)算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
图像分类模块,保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签。
本发明采用以上技术方案与现有技术相比,具有以下技术效果:本发明利用熵最小化作为无监督迁移学习的正则化方法,不需使用对抗学习技术,具有收敛速度快的优点。本发明通过迫使网络预测结果接近目标域的类别分布以得到鲁棒性较高的迁移学习网络,具有分类准确率高的优点,本发明可以用于预测无标签的目标域样本。
附图说明
图1是本发明迁移学习方法的具体实施例流程图。
图2是本发明迁移学习方法具体实施例网络结构图。
具体实施方式
下面结合附图对本发明的技术方案做进一步的详细说明:
如图1所示的一种基于熵最小化的迁移学习方法流程图,包括以下步骤:
步骤1,根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;
基于特征提取器与分类器,构建所述迁移学习网络;
可以理解的是,本发明实施例提供的迁移学习网络是由特征提取器、分类器两部分构成,所述特征提取器用于提取输入样本集的特征,所述分类器用于对输入样本集的标签样本进行预测分类。
具体的,以在ImageCLEF-DA和Office-31数据集上的迁移学习任务为例,使用图2所示的网络结构作为迁移学习网络。ResNet-50模型构成的子网络即作为本发明实施例迁移学习网络的特征提取器,特征提取器后接的两个全连接层作为分类器。
进一步的,网络的输入是一个张量,通常是具有RGB三通道的彩色图像。首先,对于所有的输入图片做一定的数据增强处理并使得网络的输入为 224×224×3的张量,这使得训练的时候网络参数更容易收敛并一定程度防止过拟合。
进一步的,初始化网络超参数。
具体的,对于SVHN数据集到MNIST数据集的迁移学习任务,我们设置学习率为0.001;对于在ImageCLEF-DA数据集和Office-31数据集上的迁移学习任务,我们设置Dropout率为0.5,初始学习率η0为0.005,牛顿动量为0.9, batch大小为32。在训练过程中,学习率ηp动态变化如下:
其中参数p随着训练进行线性地从0增加到1,参数μ=10,ν=0.75。
步骤2,提供了CPEM(category-preserved entropy minimization)算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
基于预设的损失函数,对所述迁移学习网络进行学习;
在上述实施例的基础上,所述预设的损失函数包括源领域样本集分类错误率损失函数、目标领域样本集分类结果的条件熵损失函数以及目标领域样本集分类结果类别分布的对称KL散度损失函数,所述基于预设的损失函数,对所述迁移学习网络进行学习,包括:
基于预设的损失函数,构建所述迁移学习网络的目标函数,以对所述迁移学习网络进行学习,其中,所述特征提取器和分类器的学习目标为使上述损失函数最小;
当所述迁移学习网络收敛或达到预设的学习次数后,结束对所述迁移学习网络的学习;
根据上述损失函数,即可构建出本发明实施例提供的迁移学习网络的目标函数以及优化目标:
其中θ表示网络参数,表示具有ns个有标签样本的源领域样本集,表示具有nt个无标签样本的目标领域样本集;Ls(·)表示源领域样本集分类错误率损失函数,Le(·)表示目标领域样本集分类结果的条件熵损失函数,Lc(·)表示目标领域样本集分类结果类别分布的对称KL散度损失函数;λ和β是可调整的权衡参数;
可以理解的是,学习过程为一个不断更新参数的过程,当目标神经网络收敛或者达到预设的学习次数后,学习停止;
在上述实施例的基础上,所述基于预设的损失函数,构建所述迁移学习网络的目标函数,利用反向传播算法对所述目标函数进行学习,以更新所述特征提取器与所述分类器;
具体的,所述源领域样本集的分类错误率损失函数为:
其中,Ls(·)表示源领域样本集分类错误率损失函数,S表示源领域样本集任一batch样本,|S|表示源领域batch样本的基数,l(.)表示交叉熵损失函数, y表示源领域batch样本标签,f(.)表示迁移学习网络的函数模型;
进一步的,考虑到源领域样本集类别分布不均匀的情况,使用带有权重的损失函数更为适合,尤其是在小数据集进行迁移学习时;
所述目标领域分类结果的条件熵损失函数为:
其中,Le(·)表示目标领域样本集分类结果的条件熵损失函数,T表示源领域样本集任一batch样本,f(.)表示迁移学习网络的函数模型,f(xt)表示分类器对样本xt的预测概率;
可以理解的是,在神经网络总的目标函数中添加对目标领域分类结果的条件熵损失函数,是为了进一步提高神经网络在缺乏标签的目标领域的分类准确率;
所述目标领域样本集分类结果类别分布的对称KL散度损失函数为:
其中T表示目标领域样本集任一batch样本,Lc(T)表示所述目标领域样本集batch分类结果类别分布的对称KL散度损失函数,dKL(·||·)表示KL散度损失函数,q表示目标域数据的类别分布,表示目标领域样本集batch分类结果的类别分布;
进一步的,u表示如下:
其中,P(cls(xt)=K)表示分类器对样本属于第K类的预测概率;
进一步的,因为目标领域的真实类别分布是未知情况的,本发明使用一个均匀分布代替它;这种代替对于类别平衡分布的数据集上进行的迁移任务是很有效的;
可以理解的是,本发明实施例通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
进一步的,分布计算出源领域样本集的分类错误率损失函数Ls(S)、目标领域样本集分类结果的条件熵损失函数Le(T)以及目标领域样本集分类结果类别分布的对称KL散度损失函数Lc(T),然后使用基于mini-batch的随机梯度下降法进行整个网络的训练,根据误差反传原则完成网络参数的更新,直至模型收敛或达到最大迭代次数时停止训练:
其中,μ表示学习率,λ和β是可调整的权衡参数。
步骤3,保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签;
经过上述学习过程后,能够得到泛化性能较好的深度神经网络,保存网络最终模型以及训练结果后,将未标注的目标领域样本集引入该网络模型,得到较为准确的目标领域样本集标签。训练完成的网络可以用于预测目标领域无标记的样本,代替人工以较高的准确率标记未知数据。
本发明实施例还提供一种基于迁移学习的图像分类系统,包括如下模块:
网络构造模块,根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;
训练模块,提供了CPEM(category-preserved entropy minimization)算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
图像分类模块,保存网络模型以及训练结果,将目标领域数据集引入该网络模型,得到最后的目标领域标签。
以上所述,仅为本发明中的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉该技术的人在本发明所揭露的技术范围内,可理解想到的变换或替换,都应涵盖在本发明的包含范围之内,因此,本发明的保护范围应该以权利要求书的保护范围为准。
Claims (3)
1.一种基于熵最小化的迁移学习方法,其特征在于,包括如下步骤:
a.根据不同的迁移学习任务,基于特征提取器与分类器,构建迁移学习网络并初始化网络超参数;网络的输入是一个张量,是具有RGB三通道的彩色图像;
b.提供category-preserved entropy minimization算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
基于预设的损失函数,对所述迁移学习网络进行学习;
所述预设的损失函数包括源领域样本集分类错误率损失函数、目标领域样本集分类结果的条件熵损失函数以及目标领域样本集分类结果类别分布的对称KL散度损失函数;
所述基于预设的损失函数,对所述迁移学习网络进行学习,包括:
基于预设的损失函数,构建所述迁移学习网络的目标函数,以对所述迁移学习网络进行学习,其中,所述特征提取器和分类器的学习目标为使损失函数最小;
当所述迁移学习网络收敛或达到预设的学习次数后,结束对所述迁移学习网络的学习;
所述基于预设的损失函数,构建所述迁移学习网络的目标函数,以对所述迁移学习网络进行学习,包括:
基于所述预设的损失函数包括源领域样本集分类错误率损失函数、目标领域样本集分类结果的条件熵损失函数以及目标领域样本集分类结果类别分布的对称KL散度损失函数,联合构建所述迁移学习网络的目标函数;
利用反向传播算法对所述目标函数进行学习,以更新所述特征提取器与所述分类器;
c.保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签;
彩色图像经过步骤a-步骤c的处理,得到未分类的图片标签。
3.一种基于熵最小化的迁移学习系统,用以实现权利要求1所述基于熵最小化的迁移学习方法,其特征在于,包括:
网络构造模块,根据不同的迁移学习任务,构建迁移学习网络并初始化网络超参数;
训练模块,提供了category-preserved entropy minimization算法,通过迫使网络预测结果接近目标域的真实类别分布以得到鲁棒性较高的迁移学习网络;
图像分类模块,保存网络模型以及训练结果,将目标域数据集引入该网络模型,得到最后的目标域标签,具体为:将彩色图像引入该网络模型,得到未分类的图片标签。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910623670.1A CN110443372B (zh) | 2019-07-11 | 2019-07-11 | 一种基于熵最小化的迁移学习方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910623670.1A CN110443372B (zh) | 2019-07-11 | 2019-07-11 | 一种基于熵最小化的迁移学习方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110443372A CN110443372A (zh) | 2019-11-12 |
CN110443372B true CN110443372B (zh) | 2022-08-30 |
Family
ID=68430152
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910623670.1A Active CN110443372B (zh) | 2019-07-11 | 2019-07-11 | 一种基于熵最小化的迁移学习方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110443372B (zh) |
Families Citing this family (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110414400B (zh) * | 2019-07-22 | 2021-12-21 | 中国电建集团成都勘测设计研究院有限公司 | 一种施工现场安全帽穿戴自动检测方法及系统 |
CN112819019B (zh) * | 2019-11-15 | 2023-06-20 | 财团法人资讯工业策进会 | 分类模型生成装置及其分类模型生成方法 |
CN111239137B (zh) * | 2020-01-09 | 2021-09-10 | 江南大学 | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 |
CN111368977B (zh) * | 2020-02-28 | 2023-05-02 | 交叉信息核心技术研究院(西安)有限公司 | 一种提高卷积神经网络精确性和鲁棒性的增强数据增强方法 |
CN111428874A (zh) * | 2020-02-29 | 2020-07-17 | 平安科技(深圳)有限公司 | 风控方法、电子装置及计算机可读存储介质 |
CN112861616B (zh) * | 2020-12-31 | 2022-10-11 | 电子科技大学 | 一种无源领域自适应目标检测方法 |
CN112861679B (zh) * | 2021-01-29 | 2023-01-20 | 中国科学院计算技术研究所 | 面向行为识别的迁移学习方法及系统 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11127062B2 (en) * | 2017-01-23 | 2021-09-21 | Walmart Apollp, Llc | Systems and methods for promoting products in product search results using transfer learning with active sampling |
CN108053030A (zh) * | 2017-12-15 | 2018-05-18 | 清华大学 | 一种开放领域的迁移学习方法及系统 |
CN109492765A (zh) * | 2018-11-01 | 2019-03-19 | 浙江工业大学 | 一种基于迁移模型的图像增量学习方法 |
-
2019
- 2019-07-11 CN CN201910623670.1A patent/CN110443372B/zh active Active
Also Published As
Publication number | Publication date |
---|---|
CN110443372A (zh) | 2019-11-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110443372B (zh) | 一种基于熵最小化的迁移学习方法及系统 | |
CN110580496A (zh) | 一种基于熵最小化的深度迁移学习系统及方法 | |
CN109308318B (zh) | 跨领域文本情感分类模型的训练方法、装置、设备及介质 | |
Dong et al. | Automatic age estimation based on deep learning algorithm | |
CN110750665A (zh) | 基于熵最小化的开集域适应方法及系统 | |
US20200097818A1 (en) | Method and system for training binary quantized weight and activation function for deep neural networks | |
CN109325231B (zh) | 一种多任务模型生成词向量的方法 | |
CN111275092B (zh) | 一种基于无监督域适应的图像分类方法 | |
Peng et al. | Accelerating minibatch stochastic gradient descent using typicality sampling | |
CN114241282A (zh) | 一种基于知识蒸馏的边缘设备场景识别方法及装置 | |
CN110196980A (zh) | 一种基于卷积网络在中文分词任务上的领域迁移 | |
CN112699247A (zh) | 一种基于多类交叉熵对比补全编码的知识表示学习框架 | |
CN113159072B (zh) | 基于一致正则化的在线超限学习机目标识别方法及系统 | |
CN111639186A (zh) | 动态嵌入投影门控的多类别多标签文本分类模型及装置 | |
CN113469186A (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
Liu et al. | Comparison and evaluation of activation functions in term of gradient instability in deep neural networks | |
Su et al. | Low-rank deep convolutional neural network for multitask learning | |
CN108388918B (zh) | 具有结构保持特性的数据特征选择方法 | |
CN111753995A (zh) | 一种基于梯度提升树的局部可解释方法 | |
CN110580289A (zh) | 一种基于堆叠自动编码器和引文网络的科技论文分类方法 | |
CN115797642A (zh) | 基于一致性正则化与半监督领域自适应图像语义分割算法 | |
Passos Júnior et al. | Deep boltzmann machines using adaptive temperatures | |
Bastidas | Tiny imagenet image classification | |
Ming et al. | Dynamic Deep Multi-task Learning for Caricature-Visual Face Recognition | |
Suzuki et al. | Image classification by transfer learning based on the predictive ability of each attribute |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |