CN117521774A - 模型优化方法、装置、电子设备、存储介质及产品 - Google Patents
模型优化方法、装置、电子设备、存储介质及产品 Download PDFInfo
- Publication number
- CN117521774A CN117521774A CN202311270480.9A CN202311270480A CN117521774A CN 117521774 A CN117521774 A CN 117521774A CN 202311270480 A CN202311270480 A CN 202311270480A CN 117521774 A CN117521774 A CN 117521774A
- Authority
- CN
- China
- Prior art keywords
- model
- domain data
- pseudo
- training
- preset
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 60
- 238000005457 optimization Methods 0.000 title claims abstract description 43
- 238000012549 training Methods 0.000 claims abstract description 116
- 238000012545 processing Methods 0.000 claims abstract description 23
- 238000004590 computer program Methods 0.000 claims description 15
- 230000006870 function Effects 0.000 claims description 14
- 230000006978 adaptation Effects 0.000 claims description 13
- 230000003044 adaptive effect Effects 0.000 description 8
- 238000004821 distillation Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000013140 knowledge distillation Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 5
- 230000008569 process Effects 0.000 description 5
- 230000007123 defense Effects 0.000 description 4
- 238000001514 detection method Methods 0.000 description 4
- 238000009826 distribution Methods 0.000 description 3
- 230000011218 segmentation Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000007781 pre-processing Methods 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 238000003745 diagnosis Methods 0.000 description 1
- 238000000691 measurement method Methods 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 231100000331 toxic Toxicity 0.000 description 1
- 230000002588 toxic effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明提供一种模型优化方法、装置、电子设备、存储介质及产品,涉及数据处理技术领域,包括:将未携带标签的目标域数据输入教师模型,得到教师模型的第一输出;其中,教师模型是基于携带有标签的源域数据训练得到的源域模型;将伪对抗样本和目标域数据输入预设学生模型,得到预设学生模型的第二输出;其中,伪对抗样本是根据未携带标签的目标域数据和预设学生模型确定的;基于第一输出和第二输出,确定预设学生模型的训练目标;基于目标域数据、伪对抗样本和训练目标对预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;基于训练好的学生模型和源域模型进行模型自适应处理。
Description
技术领域
本发明涉及数据处理技术领域,尤其涉及一种模型优化方法、装置、电子设备、存储介质及产品。
背景技术
随着深度学习技术的发展,深度神经网络在计算机视觉、自然语言处理和许多其他应用领域取得了重大进展,并广泛应用于自动驾驶、图片生成、医学诊疗等多种场景中。由于部署期间的无标签目标域数据通常与用于训练的有标签源域数据分布不一致,源域提供的模型在部署时性能大幅度降低,因此模型自适应任务被广泛研究以提高跨域的模型性能。
模型自适应任务限制目标域用户在不接触源域数据的条件下仅使用预训练模型来提升其在目标域的性能。与之前被广泛研究的无监督领域适应相比,模型自适应作为一种新的范式,保护了源域数据的隐私,并且节省了数据传输与存储的开销,甚至可以取得与无监督领域适应相近的性能,并得到越来越多的关注与研究,并在各种任务中广泛部署,例如图像分类、语义分割、目标检测、多模态学习等,在训练数据规模不断增加的时代,是一种高效实用的方法。
在模型自适应任务现存的方法中,目标用户通常是无条件信任源域提供的预训练模型,但是这是一个很危险的行为。用户使用源域预训练模型进行自适应得到的目标域模型很容易被攻击。因为源域提供者可以在训练模型时掌握的先验信息向以其为初始化的目标域模型发起攻击,常见的模型攻击方式有通用对抗扰动与后门攻击等。通用对抗扰动在已有的模型上,计算出一个可以使得数据集中多数图像的预测发生改变的扰动,作为通用对抗扰动,在目标域模型发布后,源域提供者可以利用叠加该扰动使得目标域模型预测错误。后门攻击则是在模型训练时在训练数据中掺入一部分含有特定触发模式和指定预测的有毒样本,由于深度神经网络的强大的容量与过拟合的特性,模型在测试过程中遇到含有对应触发模式的样本时会输出其指定的与样本内容无关的预测。因此,目标域用户由于接触不到源域数据以及训练过程,无条件地相信源域预训练模型进行微调,其模型就存在被源域提供者利用先验信息进行攻击的风险,这使得模型自适应范式在保护了源域隐私的同时,却增加了目标域用户被攻击的风险,是一个严重的安全隐患。
由于模型在迁移时的安全问题在实际部署中十分重要,所以近年来陆续一些防御方法被提出。针对对抗攻击,经典的防御方法利用训练集的对抗样本增强训练集进行训练,提高了模型的鲁棒性。针对后门攻击,也有一些基于模型剪枝、模型微调、知识蒸馏等方法被提出,或者进行后门样本检测从而拒绝一些查询。但是上述的方法都是在有监督的框架下进行防御的,在其部署过程中需要一些有标注的训练数据,因此不直接适用于模型自适应任务。
发明内容
本发明提供一种模型优化方法、装置、电子设备、存储介质及产品,用以解决现有技术中都是在有监督的框架下进行防御的,在其部署过程中需要一些有标注的训练数据,因此不直接适用于模型自适应任务的缺陷。
本发明提供一种模型优化方法,包括:
将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
根据本发明提供的一种模型优化方法,在所述将伪对抗样本和所述目标域数据输入所述预设学生模型,得到所述预设学生模型的第二输出的步骤之前,还包括:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
根据本发明提供的一种模型优化方法,根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本,包括:
利用伪标签计算所述目标域数据对于所述学生模型的交叉熵损失函数;
最大化所述交叉熵损失函数,并在各个所述目标域数据中找到使所述交叉熵损失函数最大化的目标域数据作为伪对抗样本。
根据本发明提供的一种模型优化方法,基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型,包括:
将所述目标域数据输入到所述预设学生模型,输出所述目标域数据对应的模型输出;
分别计算所述模型输出和所述伪对抗样本与所述训练目标之间的第一Kullback–Leibler散度与第二Kullback–Leibler散度,优化所述预设学生模型,使所述第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化;
在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,停止训练,得到训练好的学生模型。
根据本发明提供的一种模型优化方法,基于所述训练好的学生模型和所述源域模型进行模型自适应处理,包括:
将所述训练好的学生模型作为所述源域模型自适应的初始化模型,进行模型自适应处理,得到自适应处理后的源域模型。
本发明还提供一种模型优化装置,包括:
第一输入模块,用于将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
第二输入模块,用于将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
确定模块,用于基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
训练模块,用于基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
自适应模块,用于基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
根据本发明提供的一种模型优化装置,所述装置还用于:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
根据本发明提供的一种模型优化装置,所述装置还用于:
利用伪标签计算所述目标域数据对于所述学生模型的交叉熵损失函数;
最大化所述交叉熵损失函数,并在各个所述目标域数据中找到使所述交叉熵损失函数最大化的目标域数据作为伪对抗样本。
根据本发明提供的一种模型优化装置,所述装置还用于:
将所述目标域数据输入到所述预设学生模型,输出所述目标域数据对应的模型输出;
分别计算所述模型输出和所述伪对抗样本与所述训练目标之间的第一Kullback–Leibler散度与第二Kullback–Leibler散度,优化所述预设学生模型,使所述第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化;
在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,停止训练,得到训练好的学生模型。
根据本发明提供的一种模型优化装置,所述装置还用于:
将所述训练好的学生模型作为所述源域模型自适应的初始化模型,进行模型自适应处理,得到自适应处理后的源域模型。
本发明还提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述任一种所述模型优化方法。
本发明还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如上述任一种所述模型优化方法。
本发明还提供一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上述任一种所述模型优化方法。
本发明提供的模型优化方法、装置、电子设备、存储介质及产品,基于不携带标签的目标域数据生成预设学生模型的伪对抗样本,并将已有的源域模型作为教师模型,通过知识蒸馏和学生模型的自蒸馏,将目标域数据输入教师模型得到的第一输出,以及伪对抗样本和所述目标域数据输入预设学生模型得到的第二输出,基于第一输出和第二输出,作为预设学生模型的优化目标,最终得到优化后的学生模型,通过该学生模型进行模型自适应处理后,可以同时防御通用对抗扰动和后门攻击,无需为特定攻击方式设计防御,无需提前获取攻击的信息,更加符合实际场景中的需求,并且不需要接触源域训练数据,无需引入标注,并且保证模型在原有任务上的性能基本不受影响。
附图说明
为了更清楚地说明本发明或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的模型优化方法流程示意图;
图2为本申请实施例提供的处理示意图;
图3为本申请实施例提供的模型优化装置结构示意图;
图4是本发明提供的电子设备的结构示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将结合本发明中的附图,对本发明中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图1为本申请实施例提供的模型优化方法流程示意图,如图1所示,包括:
步骤110,将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
在本申请实施例中,目标域数据具体可以是未携带标签的目标域数据,其具体可以是图像分类、语义分割、目标检测、多模态学习等数据。
在本申请实施例中,源域数据是携带有标签的源域数据,其具体可以是图像分类、语义分割、目标检测、多模态学习等数据。
在一个可选地实施例中,用于同一模型优化的源域数据与目标域数据可以是同一类别的数据。
在本申请实施例中,可以通过携带有标签的源域数据对预设模型进行训练,在该模型收敛后,得到完成训练的源域模型,该源域模型是一个具备良好表现的模型。
在本申请实施例中,可以进一步将源域模型作为教师模型,它可以作为一个参考或者导师来指导学生模型的训练。教师模型通常具有较高的准确度和性能,在一些特定任务上表现出色。
本申请实施例中,可以进一步将目标域数据输入教师模型后,得到第一输出,此时可以通过该第一输出来指导学生模型的训练。
步骤120,将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
在本申请实施例中,伪对抗样本具体可以是利用学生模型和目标域数据计算得到伪标签,然后基于伪标签计算出该样本在学生模型上的伪对抗样本。
在本申请实施例中,预设学生模型可以通过自蒸馏的方式来进行调整,其具体即是可以通过利用其自身的输出来训练和改进自己。与传统的知识蒸馏方法不同,自蒸馏不需要一个外部的教师模型,而是使用学生模型自己的预测结果作为目标进行训练。
在本申请实施例中,可以进一步通过伪对抗样本和所述目标域数据输入输入学生模型后得到的第二输出,来进一步优化和训练学生模型。
步骤130,基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
在本申请实施例中,可以充分结合学习蒸馏得到的第一输出,以及自蒸馏得到的第二输出,来共同确定学生模型的训练目标,具体可以是将其加权求和。
在本申请实施例中,可以将第一输出和所述第二输出进行加权求和,最终得到训练目标,具体加权求和的权值可以是预设的。
也可以是将第一输出和所述第二输出的平均数,作为预设学生模型的训练目标。
步骤140,基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
在本申请实施例中,可以将目标域数据和伪对抗样本输入到预设学生模型后,可以根据模型的输出,分别计算模型输出和所述伪对抗样本与所述训练目标之间的Kullback–Leibler散度,并且不断优化学生模型,使得两者之和最小化。
在本申请实施例中,预设训练条件,具体可以是两者Kullback–Leibler散度小于预设阈值,或者是训练次数超过预设阈值,还可以是训练时间超过预设时间,则对应此时会认为已经满足预设训练条件,完成预设学生模型的训练,得到训练好的学生模型。
步骤150,基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
在本申请实施例中,在得到训练好的学生模型之后,可以将其作为已有的模型自适应方法的初始化模型,进行模型自适应
在本申请实施例中,通过自适应处理,可以与目前已有的模型自适应算法相结合,为其提供一个预处理的环节,不会干扰后续自适应流程或者产生与后续流程重复、冲突的损失函数或方法,并且有效提升了后续自适应任务的鲁棒性。
在本申请实施例中,基于不携带标签的目标域数据生成预设学生模型的伪对抗样本,并将已有的源域模型作为教师模型,通过知识蒸馏和学生模型的自蒸馏,将目标域数据输入教师模型得到的第一输出,以及伪对抗样本和所述目标域数据输入预设学生模型得到的第二输出,基于第一输出和第二输出,作为预设学生模型的优化目标,最终得到优化后的学生模型,通过该学生模型进行模型自适应处理后,可以同时防御通用对抗扰动和后门攻击,无需为特定攻击方式设计防御,无需提前获取攻击的信息,更加符合实际场景中的需求,并且不需要接触源域训练数据,无需引入标注,并且保证模型在原有任务上的性能基本不受影响。
可选地,在所述将伪对抗样本和所述目标域数据输入所述预设学生模型,得到所述预设学生模型的第二输出的步骤之前,还包括:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
在本申请实施例中,使用学生模型对未标注的目标域数据进行预测,并根据预测结果计算伪标签。
在本申请实施例中,可以进一步,根据伪标签及对应的目标域数据再次输入预设学生模型进行分析,选择一些有较高置信度的样本进行对抗攻击,以产生预设学生模型具有挑战性的对抗样本,最终得到伪对抗样本。
在本申请实施例中,通过伪对抗样本的生成,可以进一步有效扩散训练样本的数据量,有效提高后续模型训练的准确性。
可选地,根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本,包括:
利用伪标签计算所述目标域数据对于所述学生模型的交叉熵损失函数;
最大化所述交叉熵损失函数,并在各个所述目标域数据中找到使所述交叉熵损失函数最大化的目标域数据作为伪对抗样本。
在本申请实施例中,将目标域数据及其对应的伪标签输入学生模型,计算交叉熵损失,遍历目标域数据集,计算每个样本的交叉熵损失。选择使损失最大化的样本作为伪对抗样本。
在本申请实施例中,优化目标是找到具有挑战性的样本,因此选择损失最大化的样本作为伪对抗样本。
在本申请实施例中,可以进一步对目标域数据进行适当的筛选和约束。例如,可以设置一个阈值来限制选择伪对抗样本的损失,或者通过进一步优化算法来探索具有较高挑战性的样本。
在本申请实施例中,通过将伪标签和目标域数据输入学生模型得到的交叉熵损失进行最大化,可以进一步找到具备挑战性的伪对抗样本,进一步丰富模型训练的训练样本。
可选地,基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型,包括:
将所述目标域数据输入到所述预设学生模型,输出所述目标域数据对应的模型输出;
分别计算所述模型输出和所述伪对抗样本与所述训练目标之间的第一Kullback–Leibler散度与第二Kullback–Leibler散度,优化所述预设学生模型,使所述第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化;
在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,停止训练,得到训练好的学生模型。
在本申请实施例中,Kullback-Leibler(KL)散度是一种衡量两个概率分布之间差异的度量方法,通过分别计算模型输出与训练目标之间的第一Kullback–Leibler散度,以及伪对抗样本与训练目标之间的第二Kullback–Leibler散度,可以有效比较他们之间的概率分布的差异。
在本申请实施例中,计算第一Kullback–Leibler散度与第二Kullback–Leibler散度之和,将其最小化作为模型优化的目标。
在本申请实施例中,可以不断根据目标域数据和伪对抗样本进行模型优化,在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,完成模型的训练,得到训练好的学生模型。
在本申请实施例中,通过第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化为目标,进行模型优化,可以有效保证模型训练的有效性,可以有效应于无监督场景并且不需要接触源域训练数据,无需引入标注,并且保证模型在原有任务上的性能基本不受影响。
可选地,基于所述训练好的学生模型和所述源域模型进行模型自适应处理,包括:
将所述训练好的学生模型作为所述源域模型自适应的初始化模型,进行模型自适应处理,得到自适应处理后的源域模型。
在本申请实施例中,可以使用优化后的学生网络作为初始化模型,结合已有的模型自适应方法进行后续提升。
在本申请实施例中,通过与目前已有的模型自适应算法相结合,为其提供一个预处理的环节,不会干扰后续自适应流程或者产生与后续流程重复、冲突的损失函数或方法,并且有效提升了后续自适应任务的鲁棒性。
图2为本申请实施例提供的处理示意图,如图2所示,包括:
假设该方法有一个使用源域数据训练得到与目标域具有相同类别空间的源模型;
将源域模型作为教师模型,学生模型使用ImageNet-1K预训练模型作为初始化,计算无标签目标域数据在两个模型上的输出。
利用学生模型输出计算伪标签,并基于伪标签计算出样本在学生模型上的伪对抗样本。
计算学生模型原样本和伪对抗样本的输出分别与教师模型原样本输出的Kullback–Leibler散度,优化学生模型使二者之和使其最小化。
得到学生模型,并将其作为已有的模型自适应方法的初始化模型,进行模型自适应。
在本申请实施例中,通过知识蒸馏避免直接使用有风险的预训练参数,并利用调整半径下的伪对抗样本来增强鲁棒性。本方法是一个即插即用的模块,既不需要强大的预训练模型,也不需要对以下模型自适应算法进行任何更改。广泛结果验证了本方法可以有效防御通用攻击,同时在目标域中保持干净的准确性。
下面对本发明提供的模型优化装置进行描述,下文描述的模型优化装置与上文描述的模型优化方法可相互对应参照。
图3为本申请实施例提供的模型优化装置结构示意图,如图3所示,包括:
第一输入模块310用于将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
第二输入模块320用于将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
确定模块330用于基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
训练模块340用于基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
自适应模块350用于基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
所述装置还用于:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
所述装置还用于:
将所述目标域数据输入到所述预设学生模型,输出所述目标域数据对应的模型输出;
分别计算所述模型输出和所述伪对抗样本与所述训练目标之间的第一Kullback–Leibler散度与第二Kullback–Leibler散度,优化所述预设学生模型,使所述第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化;
在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,停止训练,得到训练好的学生模型。
所述装置还用于:
将所述训练好的学生模型作为所述源域模型自适应的初始化模型,进行模型自适应处理,得到自适应处理后的源域模型。
在本申请实施例中,基于不携带标签的目标域数据生成预设学生模型的伪对抗样本,并将已有的源域模型作为教师模型,通过知识蒸馏和学生模型的自蒸馏,将目标域数据输入教师模型得到的第一输出,以及伪对抗样本和所述目标域数据输入预设学生模型得到的第二输出,基于第一输出和第二输出,作为预设学生模型的优化目标,最终得到优化后的学生模型,通过该学生模型进行模型自适应处理后,可以同时防御通用对抗扰动和后门攻击,无需为特定攻击方式设计防御,无需提前获取攻击的信息,更加符合实际场景中的需求,并且不需要接触源域训练数据,无需引入标注,并且保证模型在原有任务上的性能基本不受影响。
图4是本发明提供的电子设备的结构示意图,如图4所示,该电子设备可以包括:处理器(processor)410、通信接口(Communications Interface)420、存储器(memory)430和通信总线440,其中,处理器410,通信接口420,存储器430通过通信总线440完成相互间的通信。处理器410可以调用存储器430中的逻辑指令,以执行模型优化方法,该方法包括:将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
此外,上述的存储器430中的逻辑指令可以通过软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
另一方面,本发明还提供一种计算机程序产品,所述计算机程序产品包括计算机程序,计算机程序可存储在非暂态计算机可读存储介质上,所述计算机程序被处理器执行时,计算机能够执行上述各方法所提供的模型优化方法,该方法包括:将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
基于所述训练好的学生模型和所述源域模型进行模型自适应处理。又一方面,本发明还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现以执行上述各方法提供的模型优化方法,该方法包括:将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性的劳动的情况下,即可以理解并实施。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到各实施方式可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件。基于这样的理解,上述技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在计算机可读存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行各个实施例或者实施例的某些部分所述的方法。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (10)
1.一种模型优化方法,其特征在于,包括:
将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
2.根据权利要求1所述的模型优化方法,其特征在于,在所述将伪对抗样本和所述目标域数据输入所述预设学生模型,得到所述预设学生模型的第二输出的步骤之前,还包括:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
3.根据权利要求2所述的模型优化方法,其特征在于,根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本,包括:
利用伪标签计算所述目标域数据对于所述学生模型的交叉熵损失函数;
最大化所述交叉熵损失函数,并在各个所述目标域数据中找到使所述交叉熵损失函数最大化的目标域数据作为伪对抗样本。
4.根据权利要求1所述的模型优化方法,其特征在于,基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型,包括:
将所述目标域数据输入到所述预设学生模型,输出所述目标域数据对应的模型输出;
分别计算所述模型输出和所述伪对抗样本与所述训练目标之间的第一Kullback–Leibler散度与第二Kullback–Leibler散度,优化所述预设学生模型,使所述第一Kullback–Leibler散度与第二Kullback–Leibler散度之和最小化;
在第一Kullback–Leibler散度与第二Kullback–Leibler散度之和小于预设阈值的情况下,停止训练,得到训练好的学生模型。
5.根据权利要求1所述的模型优化方法,其特征在于,基于所述训练好的学生模型和所述源域模型进行模型自适应处理,包括:
将所述训练好的学生模型作为所述源域模型自适应的初始化模型,进行模型自适应处理,得到自适应处理后的源域模型。
6.一种模型优化装置,其特征在于,包括:
第一输入模块,用于将未携带标签的目标域数据输入教师模型,得到所述教师模型的第一输出;其中,所述教师模型是基于携带有标签的源域数据训练得到的源域模型;
第二输入模块,用于将伪对抗样本和所述目标域数据输入预设学生模型,得到所述预设学生模型的第二输出;其中,所述伪对抗样本是根据未携带标签的目标域数据和所述预设学生模型确定的;
确定模块,用于基于所述第一输出和所述第二输出,确定所述预设学生模型的训练目标;
训练模块,用于基于所述目标域数据、所述伪对抗样本和所述训练目标对所述预设学生模型进行训练,在满足预设训练条件的情况下,得到训练好的学生模型;
自适应模块,用于基于所述训练好的学生模型和所述源域模型进行模型自适应处理。
7.根据权利要求6所述的模型优化装置,其特征在于,所述装置还用于:
利用预设学生模型对所述目标域数据进行分析,计算得到所述目标域数据对应伪标签;
根据所述伪标签计算所述目标域数据在所述预设学生模型上的伪对抗样本。
8.一种电子设备,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1至5任一项所述模型优化方法。
9.一种非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至5任一项所述模型优化方法。
10.一种计算机程序产品,包括计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至5任一项所述模型优化方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311270480.9A CN117521774A (zh) | 2023-09-27 | 2023-09-27 | 模型优化方法、装置、电子设备、存储介质及产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311270480.9A CN117521774A (zh) | 2023-09-27 | 2023-09-27 | 模型优化方法、装置、电子设备、存储介质及产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117521774A true CN117521774A (zh) | 2024-02-06 |
Family
ID=89748446
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311270480.9A Pending CN117521774A (zh) | 2023-09-27 | 2023-09-27 | 模型优化方法、装置、电子设备、存储介质及产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117521774A (zh) |
-
2023
- 2023-09-27 CN CN202311270480.9A patent/CN117521774A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Yamin et al. | Weaponized AI for cyber attacks | |
CN108549940B (zh) | 基于多种对抗样例攻击的智能防御算法推荐方法及系统 | |
EP4235523A1 (en) | Identifying and correcting vulnerabilities in machine learning models | |
EP3916597A1 (en) | Detecting malware with deep generative models | |
CN117461032A (zh) | 异常检测系统及方法 | |
Hussain et al. | CNN-Fusion: An effective and lightweight phishing detection method based on multi-variant ConvNet | |
Fang et al. | Backdoor attacks on the DNN interpretation system | |
Karanam et al. | Intrusion detection mechanism for large scale networks using CNN-LSTM | |
Ji et al. | Programmable neural network trojan for pre-trained feature extractor | |
Tuna et al. | Closeness and uncertainty aware adversarial examples detection in adversarial machine learning | |
Bountakas et al. | Defense strategies for adversarial machine learning: A survey | |
Chivukula et al. | Adversarial Machine Learning: Attack Surfaces, Defence Mechanisms, Learning Theories in Artificial Intelligence | |
Bouke et al. | An empirical study of pattern leakage impact during data preprocessing on machine learning-based intrusion detection models reliability | |
Dalle Pezze et al. | A multi-label continual learning framework to scale deep learning approaches for packaging equipment monitoring | |
He et al. | Image-based zero-day malware detection in iomt devices: A hybrid ai-enabled method | |
Liao et al. | Server-based manipulation attacks against machine learning models | |
Şeker | Use of Artificial Intelligence Techniques/Applications in Cyber Defense | |
CN116543240A (zh) | 一种面向机器学习对抗攻击的防御方法 | |
CN113918936A (zh) | Sql注入攻击检测的方法以及装置 | |
CN117521774A (zh) | 模型优化方法、装置、电子设备、存储介质及产品 | |
CN114021136A (zh) | 针对人工智能模型的后门攻击防御系统 | |
Zhu et al. | Gradient shaping: Enhancing backdoor attack against reverse engineering | |
Shahrasbi et al. | On detecting data pollution attacks on recommender systems using sequential gans | |
Preethi et al. | Leveraging network vulnerability detection using improved import vector machine and Cuckoo search based Grey Wolf Optimizer | |
Yılmaz | Malware classification with using deep learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |