CN116910571A - 一种基于原型对比学习的开集域适应方法及系统 - Google Patents
一种基于原型对比学习的开集域适应方法及系统 Download PDFInfo
- Publication number
- CN116910571A CN116910571A CN202311176914.9A CN202311176914A CN116910571A CN 116910571 A CN116910571 A CN 116910571A CN 202311176914 A CN202311176914 A CN 202311176914A CN 116910571 A CN116910571 A CN 116910571A
- Authority
- CN
- China
- Prior art keywords
- prototype
- class
- domain
- sample
- samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 115
- 230000006978 adaptation Effects 0.000 title claims abstract description 24
- 230000006870 function Effects 0.000 claims abstract description 43
- 238000012549 training Methods 0.000 claims abstract description 22
- 238000013145 classification model Methods 0.000 claims abstract description 17
- 238000013528 artificial neural network Methods 0.000 claims abstract description 9
- 230000008569 process Effects 0.000 claims description 38
- 238000010276 construction Methods 0.000 claims description 16
- 238000013507 mapping Methods 0.000 claims description 13
- 230000003044 adaptive effect Effects 0.000 claims description 4
- 210000002569 neuron Anatomy 0.000 claims description 4
- 239000013598 vector Substances 0.000 claims description 4
- 238000006243 chemical reaction Methods 0.000 claims description 3
- 238000013508 migration Methods 0.000 abstract description 9
- 230000005012 migration Effects 0.000 abstract description 9
- 238000009826 distribution Methods 0.000 abstract description 8
- 238000005065 mining Methods 0.000 abstract description 4
- 239000004973 liquid crystal related substance Substances 0.000 description 4
- 238000002474 experimental method Methods 0.000 description 3
- 238000007670 refining Methods 0.000 description 3
- 101000992383 Homo sapiens Oxysterol-binding protein 1 Proteins 0.000 description 2
- 102100032163 Oxysterol-binding protein 1 Human genes 0.000 description 2
- 238000007635 classification algorithm Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000000052 comparative effect Effects 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 239000007788 liquid Substances 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F17/00—Digital computing or data processing equipment or methods, specially adapted for specific functions
- G06F17/10—Complex mathematical operations
- G06F17/18—Complex mathematical operations for evaluating statistical data, e.g. average values, frequency distributions, probability functions, regression analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Mathematical Analysis (AREA)
- Biophysics (AREA)
- Computing Systems (AREA)
- Pure & Applied Mathematics (AREA)
- Molecular Biology (AREA)
- Mathematical Optimization (AREA)
- General Health & Medical Sciences (AREA)
- Computational Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computational Linguistics (AREA)
- Operations Research (AREA)
- Databases & Information Systems (AREA)
- Algebra (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于原型对比学习的开集域适应方法及系统,所述方法包括:基于深度神经网络构建分类模型并进行预训练,初始化类别原型并构建原型记忆库;基于类别原型的类间距离自适应区分目标域中未知类别样本;基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;扩充原型记忆库,进行类别扩展。本发明基于原型对比学习实现域间共享知识挖掘和迁移,以及未知类识别和分类,更适用于域间类分布失配下知识迁移,同时扩展了目标域开放类别的识别能力。
Description
技术领域
本发明涉及域适应技术领域,尤其涉及一种基于原型对比学习的开集域适应方法及系统。
背景技术
随着深度神经网络等复杂模型的发展,对标记数据的需求越来越高。然而数据标注需要专业人工参与,耗时耗力,代价高昂,数据标注稀缺已然成为制约机器学习发展的瓶颈之一;与此同时,机器学习模型在现实任务中常面临新场景,模型的迁移泛化能力亟需关注。域适应学习(Domain Adaptation, DA)将源域知识迁移至目标域,以应对目标域遇到的标注稀缺、新场景等问题,是提升学习模型泛化能力的一种有效手段。而随着学习任务越来越面临开放类环境,除特征分布外,域间类先验分布也将发生漂移。因此,面向开放类环境的开集域适应学习(Open-Set DA, OSDA)是一项重要的研究内容。
现有技术大多基于分布差异最小化或对抗方式对齐域间特征分布(Weikai Li,Songcan Chen, Partial Domain Adaptation without Domain Alignment. IEEETransactions on Pattern Analysis and Machine Intelligence, 2022. doi:10.1109/TPAMI.2022.3228937),但是目标域中的开放类别给域间特征分布对齐带来了挑战,错误类匹配将导致模式坍塌或负迁移问题。
申请号为202210927707.1的中国专利公开了一种基于自监督对比学习的跨域遥感场景分类与检索方法,对于目标域数据,分别进行数据强增强和弱增强,进行特征自监督对比学习,并在输出空间约束强弱增强样本的预测一致性。该方案是通过最大类预测概率与预设定阈值间比较区分已知和未知类别。申请号为202210253606.0的中国专利公开了一种基于文物图像开集识别的分类算法,在训练阶段通过基于特征迁移先验误差的文物图像开集识别算法,利用迁移学习前后模型所提取特征之间的差异,在测试阶段进一步提升网络对已知类文物样本和开集文物样本的判断能力。该方案是通过利用激活特征值与预设定阈值间的比较来区分已知和未知类别。上述方法通过预设定阈值的方法来识别目标域中未知类别,但阈值是数据依赖的,很难提前设定。且现有技术仅致力于对未知类别进行识别,无法对新类别有效分类。
发明内容
发明目的:本发明旨在提供一种能够避免域间分布误匹配、自适应识别目标域中未知类样本的基于原型对比学习的开集域适应方法及系统。
技术方案:本发明所述的一种基于原型对比学习的开集域适应方法,包括:
(1)基于深度神经网络构建分类模型并进行预训练,初始化类别原型,构建原型记忆库;
(2)基于类别原型的类间距离自适应区分目标域中未知类别样本;
(3)基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;
(4)扩充原型记忆库,进行类别扩展。
优选地,步骤(1)中,所述分类模型包括特征提取器和类别分类器,采用源域样本对分类模型进行预训练,分类损失函数为:
式中,表示第i个样本,表示对应的样本类别标签,表示源域样本集,表示
交叉熵损失,表示样本在分类器中的概率输出;其中,
表示维分类器的概率输出,表示源域已知类别的个数,和分别表示特征提取器和
类别分类器,表示softmax函数。
优选地,步骤(1)中,所述初始化类别原型包括:
式中,表示第类的类别原型,表示样本的特征,为样本对应的类别标
签;表示第类中包含样本数。
优选地,步骤(1)中,所述构建原型记忆库包括:
式中,、和分别表示第类样本的原型、对应的类别标签以及原型总数,此
时。
优选地,步骤(2)包括:构造目标域样本和原型之间的相似度向量,
其中为样本特征与原型之间的相似度,表示为:
式中,当时,目标域样本属于未知类;表示类间距离,为自适应阈
值,由类别原型间相似度的均值计算。
优选地,步骤(3)中包括:
(3.1)构建源域原型对比损失函数;
将样本和原型分别经由非线性转换G映射为,并基于映射表示构建源域
原型对比损失函数:
式中,为温度参数,为指示函数,表示源域样本集,、、分别为样本、、经由非线性转换后的特征表示,、、分别为样本、、的类别标签,为对
应类别原型的非线性映射表示,为其它类别原型的非线性映射表示,为余弦相似
度;
(3.2)构建目标域原型对比损失函数;
对于目标域已知类样本,根据类别分类器前个神经元的概率输出确定其伪标签,则目标域原型对比损失函数为:
式中,表示目标域已知类集合;表示样本的伪标签,为对应伪标签类
别原型的非线性映射表示;
(3.3)构建目标域增强对比损失函数;
对于目标域样本,通过数据增强构建正样本,则目标域增强对比损失函数
为:
式中,表示目标域样本集;和分别表示样本和其增强样本对应的非线
性映射表示;、分别为样本、经由非线性转换后的特征表示;
(3.4)构建分类损失函数;
式中,表示未知类目标域的集合。
优选地,步骤(3)中所述目标函数为:
式中,表示平衡参数。
优选地,步骤(3)中所述基于特征提取器更新类别原型包括:
每一个小批量训练后,同时使用源域原型和目标域原型更新类别原型,
式中,为原型权重参数,为源域第k类的原型,为目标域第k类的原型。
优选地,步骤(4)中所述扩充原型记忆库包括:根据步骤(2)区分已知类和未知类,在原型记忆库中增加新类别原型的存储,对未知类进一步识别。
本发明所述的一种基于原型对比学习的开集域适应系统,包括:
分类模型构建模块,用于基于深度神经网络构建分类模型并进行预训练,并初始化类别原型和构建原型记忆库;
自适应分类模块,用于基于类别原型的类间距离自适应区分目标域中未知类别样本;
类别原型更新模块,用于基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;
类别扩展模块,用于扩充原型记忆库,进行类别扩展。
有益效果:与现有技术相比,本发明具有如下显著优点:采用原型对比学习,实现域间知识迁移,施加对比约束实现域间和域内同类近、不同类远,从而在挖掘域间类共性知识同时,避免域间分布误匹配问题;基于类别原型,在特征层面自动获取类间距离用于自适应识别未知类样本,以缓解未知类别与已知类中误分样本的混淆问题,且无需预先设定阈值;通过原型记忆库实现对新类别扩展分类。本发明基于原型对比学习实现域间共享知识挖掘和迁移,以及未知类识别和分类,更适用于域间类分布失配下知识迁移,同时扩展了目标域开放类别的识别能力。
附图说明
图1为本发明的方法流程图;
图2为本发明的模型训练流程图;
图3为本发明的方法与其他方法对比结果图。
具体实施方式
下面结合附图对本发明的技术方案作进一步说明。
如图1-2所示,本发明所述的一种基于原型对比学习的开集域适应方法,包括以下步骤:
(1)预处理阶段:对目标域数据进行数据增强预处理,基于源域样本预训练深度神经网络分类模型,并初始化类别原型。
(1.1)基于深度神经网络构建分类模型,包括特征提取器、类别分类器两个部分并对采用源域样本对分类模型进行预训练,分类损失函数如下所示,
式中,表示第i个样本;表示对应的样本类别标签;表示源域样本集;表示交
叉熵损失;表示样本在分类器的概率输出;具体地,表
示 维分类器的概率输出,表示源域已知类别的个数,和分别表示特征提取器和
类别分类器,表示softmax函数。
通过有标签的源域样本进行有监督学习,最小化优化特征提取器和类别分类
器参数。
(1.2)初始化类别原型;
所述的原型是指特定类别样本的特征中心,原型初始化公式如下所示,
式中,表示类别的原型,表示样本特征,为对应的类别标记,表示类别
为的样本数。
(1.3)构建原型记忆库:
式中,、和分别表示第类样本的原型、对应的类别标签和原型总数,此时。
(2)迭代训练阶段:首先基于类别原型的类间距离自适应区分目标域中未知类别样本;然后基于原型对比损失学习网络参数,挖掘能迁移域间共享类中共性知识,同时保持类内近与类间远特性的高判别性特征;最后基于特征提取器更新类别原型。具体包括以下步骤:
(2.1)构造目标域样本和原型之间的相似度向量,其中为样本特征与原型之间的相似度,表示为:
当时,目标域样本属于未知类;表示类间距离,为自适应阈值,由类
别原型间相似度的均值计算。
(2.2)基于原型对比损失学习网络参数。
(2.2.1)构建源域原型对比损失函数;
将样本和原型分别经由非线性映射G,并基于映射后的表示构建源域原型对比损
失函数:
式中,是温度参数,为指示函数,表示源域样本集,、、分别是样本、、经由非线性转换后的特征表示,、、分别是样本、、的类别标签,为对
应类别原型的非线性映射表示,为其它类别原型的非线性映射表示,表示余弦相
似度;
(2.2.2)构建目标域原型对比损失,对于每个目标域样本,首先根据步骤(2. 1)挑选已
知类,然后根据类别分类器前个神经元的概率输出确定其伪标签,目
标域原型对比损失函数为:
式中,表示目标域已知类集合;表示样本的伪标签,为对应伪标签类
别原型的非线性映射表示。基于该损失函数拉近域间距离。
(2.2.3)构建目标域增强对比损失,对于无标签的目标域样本,采用基于数据增强
的对比损失进行聚类,最大化同一类别中的样本互信息。给定任一样本,其正样本为自身
进行数据增强后的视图,构建目标域样本的对比损失函数为,
式中,表示目标域样本集;和分别表示样本和其增强样本对应的非线
性映射表示;、分别为样本、经由非线性转换后的特征表示。
(2.2.4)构建分类损失函数,定义为,
式中,表示未知类目标域的集合;训练时源域样本按真实标签分到前维,
根据步骤(2.1)选择有未知类标记的目标域样本分到第维。
(2.2.5)结合上述损失函数,构建目标函数如下表示,
式中,表示平衡参数。
(2.3)基于特征提取器更新类别原型。
每一个小批量训练后,同时使用源域原型和目标域原型更新类别原型,建立源域和目标域间的稳定联系,更新过程为,
式中,为原型权重参数,为源域第k类的原型,为目标域第k类的原型。
(3)分类已知类样本,并根据类间距离自适应识别未知类样本,扩充原型记忆库对未知类样本细化分类,实现新类别扩展。
(3.1)根据步骤(2.1)分类已知类样本,并自适应识别未知类样本;对已知类通过类别分类器输出类别;
(3.2)基于专家标注的未知类样本,扩充原型记忆库。未知类别的目标域样本在模型训练完毕后持续输入,在不更新网络参数的情况下可以使用原型记忆库记录未知类样本原型,进而分类。
根据步骤(2.1)思想,识别出目标域中未知类别后,计算未知类样本与各未知类原型间相似度,以确定未知类样本的类别,对其进一步细化分类。可通过主动查询方式,给予人工标签,用于扩充原型记忆库,以对未知新类别进行分类。
本发明所述的一种基于原型对比学习的开集域适应系统,包括:
分类模型构建模块,用于基于深度神经网络构建分类模型并进行预训练,并初始化类别原型和构建原型记忆库;
自适应分类模块,用于基于类别原型的类间距离自适应区分目标域中未知类别样本;
类别原型更新模块,用于基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;
类别扩展模块,用于扩充原型记忆库,进行类别扩展。
为了进一步说明本发明的方法,以图像分类Office-31数据集和在ImageNet上预训练的ResNet网络为例进行实验。
(1)预训练阶段
基于Office-31中AMAZON(A)和DSLR(D)构建跨域分类任务,其中A为源域,D为目标域,选择0-9类作为已知类别,10-19类作为未知类别。
(1.1)构建分类模型并进行预训练。
选择在ImageNet数据集上预训练的ResNet网络的特征提取部分作为特征提取器,类别分类器的输出维度调整为K+1维,本实施例取11维。输入源域样本优化交叉熵损失对模型参数预训练,损失函数如下所示,
式中,表示第i个样本,表示对应的样本类别标签,表示源域样本集,表示
交叉熵损失,表示样本在分类器中的概率输出;其中,
表示维分类器的概率输出。
实验使用Pytorch框架,批大小设置为64,使用动量0.9、学习率为0.001的SGD优化器。
(1.2)初始化类别原型,类别原型初始化公式如下所示,
式中,表示第类的类别原型,表示样本的特征,为样本对应的类别标
签;表示第类中包含样本数。
(1.3)构建原型记忆库:
式中,、和分别表示第类样本的原型、对应的类别标签以及原型总数,此
时。
(2)迭代训练阶段
(2.1)基于类别原型的类间距离自适应区分目标域中未知类别样本
构造目标域样本和原型之间的相似度向量,其中为:
当时,目标域样本属于未知类,是类间距离,可由类别原型间相似
度的均值计算,即
(2.2)基于原型对比损失学习网络参数,挖掘能迁移域间共享类中共性知识,同时保持类内近与类间远特性的高判别性特征。
(2.2.1)将样本和原型分别经由非线性转换G映射为,并基于映射表示构
建源域原型对比损失函数:
式中,为温度参数,为指示函数,表示源域样本集,、、分别为样本、、经由非线性转换后的特征表示,、、分别为样本、、的类别标签,为对
应类别原型的非线性映射表示,为其它类别原型的非线性映射表示,为余弦相似
度。
(2.2.2)构建目标域原型对比损失,对于每个目标域样本,首先根据步骤(2.1)挑选已知
类,然后根据类别分类器前10个神经元的概率输出确定其伪标签,目标
域原型对比损失函数为:
式中,表示目标域已知类集合,表示样本的伪标签,为对应伪标签类
别原型的非线性映射表示。
(2.2.3)构建目标域对比损失,对于无标签的目标域样本,采用基于数据增强的对比损失进行聚类,最大化同一类别中的样本互信息。所述数据增强包括随机裁剪、随机颜色失真和随机高斯模糊等随机处理。
给定任一样本,其正样本为自身进行数据增强后的视图,构建目标域样本的
对比损失函数为,
式中,表示目标域样本集;和分别表示样本和其增强样本对应的非线
性映射表示。
(2.2.4)构建分类损失,定义为,
式中,表示未知类目标域的集合;训练时源域样本按真实标签分到前维,
根据步骤(2.1)选择有未知类标记的目标域样本分到第维。
(2.2.5)结合上述损失函数,最终的目标函数如下表示,
式中,表示平衡参数。
(2.3)基于特征提取器更新类别原型
每一个小批量训练后,同时使用源域原型和目标域原型更新类别原型,建立源域和目标域间的稳定联系,更新过程为,
式中,为原型权重参数,为源域第k类的原型,为目标域第k类的原型。
(3)预测阶段
对测试样本进行分类,基于类间距离自适应区分已知和未知类样本,若样本属于已知类,则由类别分类器输出类别,由原型记忆库进行进一步细化分类。
原型记忆库由专家标注的未知类样本进行扩充,在不更新网络参数的情况下可以使用原型记忆库记录未知类样本原型,进而分类。根据步骤(2.1)思想,识别出目标域中未知类别后,计算未知类样本与各未知类原型间相似度,以确定未知类样本的类别,对其进一步细化分类。本实施例中,取200个D域样本(其中10个共有类,3个私有类),识别私有类后,基于主动学习方式请专家给3个样本赋类别标记,取其特征平均作为原型扩充原型存储记忆库。而后对其余未知类样本进一步细化分类,分类精度可达70%左右,以上实验结果证明本发明不仅能有效识别未知类别,更能通过扩充原型记忆库,对未知样本进一步细化分类。
为了进一步验证本发明的效果,采用不同的方法进行对比实验,如图3所示。其中,OSBP是基于对抗训练的OSDA方法,通过训练分类器和特征生成器,并基于分类器和预设的阈值,区分已知类别和未知类别样本;UAN是一种通用的域适应算法,综合了领域相似性和预测不确定性对样本加权;DANCE是一种新的邻域聚类技术,以自监督的方式学习目标域的结构,基于熵区分已知未知类;DCC是基于循环一致性匹配设计出领域共识得分指标来匹配类别;OURS即为本发明提出的方法。
采用不同的指标评估各方法预测的准确率。其中,ACC_kn表示已知类分类准确率,ACC_unk表示未知类分类准确率,HOS表示ACC_kn和ACC_unk的调和平均值。
实验结果如下表1所示。从表中可以看出,本发明的方法预测分类准确率均高于其他方法,其中对未知类的识别性能提高了2.3%,而已知类和未知类的总体识别性能也提高了2.1%。
表1:Office-31数据集在A->D任务上的分类准确率对比(单位:%)
方法 | ACC_kn | ACC_unk | HOS |
OSBP | 90.5 | 75.5 | 82.3 |
UAN | 87.5 | 52.9 | 65.9 |
DANCE | 90.6 | 81.1 | 85.6 |
DCC | 93.3 | 79.9 | 86.1 |
OURS | 93.6 | 83.4 | 88.2 |
Claims (10)
1.一种基于原型对比学习的开集域适应方法,其特征在于,包括:
(1)基于深度神经网络构建分类模型并进行预训练,初始化类别原型,构建原型记忆库;
(2)基于类别原型的类间距离自适应区分目标域中未知类别样本;
(3)基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;
(4)扩充原型记忆库,进行类别扩展。
2.根据权利要求1所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(1)中,所述分类模型包括特征提取器和类别分类器,采用源域样本对分类模型进行预训练,分类损失函数为:
式中,表示第i个样本,/>表示对应的样本类别标签,/>表示源域样本集,/>表示交叉熵损失,/>表示样本/>在分类器中的概率输出;其中,/>表示/>维分类器的概率输出,/>表示源域已知类别的个数,/>和/>分别表示特征提取器和类别分类器,/>表示softmax函数。
3.根据权利要求2所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(1)中,所述初始化类别原型包括:
式中,表示第/>类的类别原型,/>表示样本/>的特征,/>为样本/>对应的类别标签;/>表示第/>类中包含样本数。
4.根据权利要求3所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(1)中,所述构建原型记忆库包括:
式中,、/>和/>分别表示第/>类样本的原型、对应的类别标签以及原型总数,此时。
5.根据权利要求4所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(2)包括:构造目标域样本和原型之间的相似度向量,其中/>为样本特征/>与原型/>之间的相似度,表示为:
式中,当时,目标域样本属于未知类;/>表示类间距离,为自适应阈值,由类别原型间相似度的均值计算。
6.根据权利要求5所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(3)包括:
(3.1)构建源域原型对比损失函数;
将样本和原型分别经由非线性转换G映射为,并基于映射表示构建源域原型对比损失函数/>:
式中,为温度参数,/>为指示函数,/>表示源域样本集,/>、/>、/>分别为样本/>、/>、/>经由非线性转换后的特征表示,/>、/>、/>分别为样本/>、/>、/>的类别标签,/>为/>对应类别原型的非线性映射表示,/>为其它类别原型的非线性映射表示,/>为余弦相似度;
(3.2)构建目标域原型对比损失函数;
对于目标域已知类样本,根据类别分类器前个神经元的概率输出确定其伪标签,则目标域原型对比损失函数/>为:
式中,表示目标域已知类集合;/>表示样本/>的伪标签,/>为对应伪标签类别原型的非线性映射表示;
(3.3)构建目标域增强对比损失函数;
对于目标域样本,通过数据增强构建正样本/>,则目标域增强对比损失函数/>为:
式中,表示目标域样本集;/>和/>分别表示样本/>和其增强样本/>对应的非线性映射表示;/>、/>分别为样本/>、/>经由非线性转换后的特征表示;
(3.4)构建分类损失函数;
式中,表示未知类目标域的集合。
7.根据权利要求6所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(3)中所述目标函数为:
式中,表示平衡参数。
8.根据权利要求7所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(3)中所述基于特征提取器更新类别原型包括:
每一个小批量训练后,同时使用源域原型和目标域原型更新类别原型,
式中,为原型权重参数,/>为源域第k类的原型,/>为目标域第k类的原型。
9.根据权利要求8所述的基于原型对比学习的开集域适应方法,其特征在于,步骤(4)中所述扩充原型记忆库包括:根据步骤(2)区分已知类和未知类,在原型记忆库中增加新类别原型的存储,对未知类进一步识别。
10.一种基于原型对比学习的开集域适应系统,其特征在于,包括:
分类模型构建模块,用于基于深度神经网络构建分类模型并进行预训练,并初始化类别原型和构建原型记忆库;
自适应分类模块,用于基于类别原型的类间距离自适应区分目标域中未知类别样本;
类别原型更新模块,用于基于原型对比损失构建目标函数进行网络参数学习,基于特征提取器更新类别原型;
类别扩展模块,用于扩充原型记忆库,进行类别扩展。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311176914.9A CN116910571B (zh) | 2023-09-13 | 2023-09-13 | 一种基于原型对比学习的开集域适应方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311176914.9A CN116910571B (zh) | 2023-09-13 | 2023-09-13 | 一种基于原型对比学习的开集域适应方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116910571A true CN116910571A (zh) | 2023-10-20 |
CN116910571B CN116910571B (zh) | 2023-12-08 |
Family
ID=88351514
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311176914.9A Active CN116910571B (zh) | 2023-09-13 | 2023-09-13 | 一种基于原型对比学习的开集域适应方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116910571B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117408330A (zh) * | 2023-12-14 | 2024-01-16 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110750665A (zh) * | 2019-10-12 | 2020-02-04 | 南京邮电大学 | 基于熵最小化的开集域适应方法及系统 |
CN113128620A (zh) * | 2021-05-11 | 2021-07-16 | 北京理工大学 | 一种基于层次关系的半监督领域自适应图片分类方法 |
US20210390355A1 (en) * | 2020-06-13 | 2021-12-16 | Zhejiang University | Image classification method based on reliable weighted optimal transport (rwot) |
CN113988126A (zh) * | 2021-10-26 | 2022-01-28 | 哈尔滨理工大学 | 一种基于少标签数据特征迁移的滚动轴承故障诊断方法 |
CN114611617A (zh) * | 2022-03-16 | 2022-06-10 | 西安理工大学 | 基于原型网络的深度领域自适应图像分类方法 |
CN115410088A (zh) * | 2022-10-10 | 2022-11-29 | 中国矿业大学 | 一种基于虚拟分类器的高光谱图像领域自适应方法 |
CN115908892A (zh) * | 2022-10-09 | 2023-04-04 | 浙江大学 | 一种基于原型对比自训练的跨域图像分类方法 |
CN115984621A (zh) * | 2023-01-09 | 2023-04-18 | 宁波拾烨智能科技有限公司 | 一种基于限制性原型对比网络的小样本遥感图像分类方法 |
CN116337447A (zh) * | 2022-12-19 | 2023-06-27 | 苏州大学 | 一种非平稳工况下轨道车辆轮对轴承故障诊断方法及设备 |
CN116468991A (zh) * | 2023-02-24 | 2023-07-21 | 西安电子科技大学 | 基于渐进校准的类增量无监督域自适应图像识别方法 |
WO2023137889A1 (zh) * | 2022-01-20 | 2023-07-27 | 北京邮电大学 | 基于嵌入增强和自适应的小样本图像增量分类方法及装置 |
CN116503676A (zh) * | 2023-06-27 | 2023-07-28 | 南京大数据集团有限公司 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
-
2023
- 2023-09-13 CN CN202311176914.9A patent/CN116910571B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110750665A (zh) * | 2019-10-12 | 2020-02-04 | 南京邮电大学 | 基于熵最小化的开集域适应方法及系统 |
US20210390355A1 (en) * | 2020-06-13 | 2021-12-16 | Zhejiang University | Image classification method based on reliable weighted optimal transport (rwot) |
CN113128620A (zh) * | 2021-05-11 | 2021-07-16 | 北京理工大学 | 一种基于层次关系的半监督领域自适应图片分类方法 |
CN113988126A (zh) * | 2021-10-26 | 2022-01-28 | 哈尔滨理工大学 | 一种基于少标签数据特征迁移的滚动轴承故障诊断方法 |
WO2023137889A1 (zh) * | 2022-01-20 | 2023-07-27 | 北京邮电大学 | 基于嵌入增强和自适应的小样本图像增量分类方法及装置 |
CN114611617A (zh) * | 2022-03-16 | 2022-06-10 | 西安理工大学 | 基于原型网络的深度领域自适应图像分类方法 |
CN115908892A (zh) * | 2022-10-09 | 2023-04-04 | 浙江大学 | 一种基于原型对比自训练的跨域图像分类方法 |
CN115410088A (zh) * | 2022-10-10 | 2022-11-29 | 中国矿业大学 | 一种基于虚拟分类器的高光谱图像领域自适应方法 |
CN116337447A (zh) * | 2022-12-19 | 2023-06-27 | 苏州大学 | 一种非平稳工况下轨道车辆轮对轴承故障诊断方法及设备 |
CN115984621A (zh) * | 2023-01-09 | 2023-04-18 | 宁波拾烨智能科技有限公司 | 一种基于限制性原型对比网络的小样本遥感图像分类方法 |
CN116468991A (zh) * | 2023-02-24 | 2023-07-21 | 西安电子科技大学 | 基于渐进校准的类增量无监督域自适应图像识别方法 |
CN116503676A (zh) * | 2023-06-27 | 2023-07-28 | 南京大数据集团有限公司 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
Non-Patent Citations (5)
Title |
---|
RAKSHIT S 等: "Multi-source open-set deep adversarial domain adaptation", 《COMPUTER VISION–ECCV 2020》, pages 735 - 750 * |
RAKSHIT S 等: "Open-Set Domain Adaptation Under Few Source-Domain Labeled Samples", 《2022 IEEE/CVF CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION WORKSHOPS(CVPRW)》, pages 4029 - 4038 * |
XU Y 等: "Open set domain adaptation with soft unknown-class rejection", 《IEEE TRANSACTIONS ON NEURAL NETWORKS AND LEARNING SYSTEMS》, vol. 34, no. 3, pages 1601 - 1612 * |
宋闯 等: "面向智能感知的小样本学习研究综述", 《航空学报》, no. 1, pages 15 - 28 * |
张雪梅 等: "对抗式域适配迁移学习研究", 《计算机科学与应用》, vol. 11, no. 12, pages 2853 - 2861 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117408330A (zh) * | 2023-12-14 | 2024-01-16 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
CN117408330B (zh) * | 2023-12-14 | 2024-03-15 | 合肥高维数据技术有限公司 | 面向非独立同分布数据的联邦知识蒸馏方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN116910571B (zh) | 2023-12-08 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113378632B (zh) | 一种基于伪标签优化的无监督域适应行人重识别方法 | |
CN110909820B (zh) | 基于自监督学习的图像分类方法及系统 | |
CN111967294A (zh) | 一种无监督域自适应的行人重识别方法 | |
CN106469560B (zh) | 一种基于无监督域适应的语音情感识别方法 | |
CN111738172B (zh) | 基于特征对抗学习和自相似性聚类的跨域目标重识别方法 | |
CN109492750B (zh) | 基于卷积神经网络和因素空间的零样本图像分类方法 | |
CN110619059B (zh) | 一种基于迁移学习的建筑物标定方法 | |
CN108537168B (zh) | 基于迁移学习技术的面部表情识别方法 | |
CN111832511A (zh) | 一种增强样本数据的无监督行人重识别方法 | |
CN111967325A (zh) | 一种基于增量优化的无监督跨域行人重识别方法 | |
CN108345866B (zh) | 一种基于深度特征学习的行人再识别方法 | |
CN116910571B (zh) | 一种基于原型对比学习的开集域适应方法及系统 | |
NL2029214A (en) | Target re-indentification method and system based on non-supervised pyramid similarity learning | |
CN111079847A (zh) | 一种基于深度学习的遥感影像自动标注方法 | |
CN114692732A (zh) | 一种在线标签更新的方法、系统、装置及存储介质 | |
CN114579794A (zh) | 特征一致性建议的多尺度融合地标图像检索方法及系统 | |
CN113095229B (zh) | 一种无监督域自适应行人重识别系统及方法 | |
CN117152459A (zh) | 图像检测方法、装置、计算机可读介质及电子设备 | |
CN107993311B (zh) | 一种用于半监督人脸识别门禁系统的代价敏感隐语义回归方法 | |
CN112750128A (zh) | 图像语义分割方法、装置、终端及可读存储介质 | |
CN116580272A (zh) | 一种基于模型融合推理的雷达目标分类方法及系统 | |
CN110807467A (zh) | 一种基于支持点学习的开集类别发掘方法与装置 | |
Khadempir et al. | Domain adaptation based on incremental adversarial learning | |
CN113673555B (zh) | 一种基于记忆体的无监督域适应图片分类方法 | |
CN114972904B (zh) | 一种基于对抗三元组损失的零样本知识蒸馏方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |