CN116720214A - 一种用于隐私保护的模型训练方法及装置 - Google Patents
一种用于隐私保护的模型训练方法及装置 Download PDFInfo
- Publication number
- CN116720214A CN116720214A CN202310581293.6A CN202310581293A CN116720214A CN 116720214 A CN116720214 A CN 116720214A CN 202310581293 A CN202310581293 A CN 202310581293A CN 116720214 A CN116720214 A CN 116720214A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- sample
- loss information
- preset
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 277
- 238000000034 method Methods 0.000 title claims abstract description 103
- 238000012545 processing Methods 0.000 claims abstract description 66
- 238000004821 distillation Methods 0.000 claims abstract description 55
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 42
- 230000008569 process Effects 0.000 claims abstract description 37
- 230000002265 prevention Effects 0.000 claims description 81
- 230000000875 corresponding effect Effects 0.000 claims description 59
- 230000006870 function Effects 0.000 claims description 38
- 238000004364 calculation method Methods 0.000 claims description 6
- 230000002596 correlated effect Effects 0.000 claims description 3
- 238000010200 validation analysis Methods 0.000 claims description 2
- 238000010586 diagram Methods 0.000 description 19
- 238000003860 storage Methods 0.000 description 13
- 238000005516 engineering process Methods 0.000 description 11
- 238000013145 classification model Methods 0.000 description 10
- 238000010801 machine learning Methods 0.000 description 10
- 238000004590 computer program Methods 0.000 description 9
- 238000012804 iterative process Methods 0.000 description 7
- 238000013528 artificial neural network Methods 0.000 description 4
- 230000009286 beneficial effect Effects 0.000 description 3
- 238000010606 normalization Methods 0.000 description 3
- 238000012795 verification Methods 0.000 description 3
- 230000005540 biological transmission Effects 0.000 description 2
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000009826 distribution Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000002085 persistent effect Effects 0.000 description 2
- 230000001052 transient effect Effects 0.000 description 2
- 241000282472 Canis lupus familiaris Species 0.000 description 1
- 241000282326 Felis catus Species 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000005055 memory storage Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
- 238000013139 quantization Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Bioethics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Computer Security & Cryptography (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本说明书一个或多个实施例公开了一种用于隐私保护的模型训练方法。该方法包括:当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型;然后,以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息;其次,根据第一损失信息和第二损失信息确定第一逐样本梯度;最后,对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,最终得到训练后的目标模型。
Description
技术领域
本文件涉及隐私保护机器学习技术领域,尤其涉及一种用于隐私保护的模型训练方法及装置。
背景技术
大数据和人工智能技术的产生与发展,极大地促进了科技的进步与人类生活水平的提高,这其中离不开大数据技术背后广泛而优质的数据。然而,这些数据中所暗藏的大量用户隐私信息,有可能随着智能服务看似合法的访问而暴露出来,随着用户对隐私数据越来越重视,隐私保护机器学习技术应运而生。
差分隐私机器学习,是目前最为常用的隐私保护机器学习技术。目前的差分隐私机器学习,通常是对模型梯度进行无偏挠动,而且受限于隐私的限制,使得模型无法充分训练,导致最终模型的精度相比于非隐私机器学习的精度差异较大。因此,需要提供一种能够提高模型训练精度的用于隐私保护的模型训练方法。
发明内容
一方面,本说明书一个或多个实施例提供一种用于隐私保护的模型训练方法,包括:当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
另一方面,本说明书一个或多个实施例提供一种用于隐私保护的模型训练方法,包括:
当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取所述风险防控模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;
以多个所述中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的风险防控模型的输出结果确定;
根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一历史交易事件的特征数据输入到所述风险防控模型中得到的输出结果与所述第一历史交易事件的特征数据的标签确定;
对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
再一方面,本说明书一个或多个实施例提供一种用于隐私保护的模型训练装置,包括:中间模型获取模块,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;蒸馏训练模块,以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;第一梯度计算模块,根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;模型参数更新模块,对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
再一方面,本说明书一个或多个实施例提供一种电子设备,包括:处理器;以及被安排成存储计算机可执行指令的存储器,在所述可执行指令被执行时,能够使得所述处理器:当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
附图说明
为了更清楚地说明本说明书一个或多个实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书一个或多个实施例中记载的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是根据本说明书一实施例的一种用于隐私保护的模型训练方法的示意性流程图;
图2是根据本说明书一实施例的一种用于隐私保护的模型训练方法的示意性原理图;
图3是根据本说明书一实施例的另一种用于隐私保护的模型训练方法的示意性流程图;
图4是根据本说明书一实施例的又一种用于隐私保护的模型训练方法的示意性流程图;
图5A是根据本说明书一实施例的一种用于隐私保护的模型训练方法的示意性流程图;
图5B是根据本说明书一实施例的另一种用于隐私保护的模型训练方法的示意性流程图;
图6是根据本说明书一实施例的一种用于隐私保护的模型训练装置的示意性框图;
图7是根据本说明书一实施例的一种电子设备的示意性框图。
具体实施方式
本说明书一个或多个实施例提供一种用于隐私保护的模型训练方法及装置,以解决目前的经过模型训练所获取的最终模型的精度较低的问题。
为了使本技术领域的人员更好地理解本说明书一个或多个实施例中的技术方案,下面将结合本说明书一个或多个实施例中的附图,对本说明书一个或多个实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本说明书一部分实施例,而不是全部的实施例。基于本说明书一个或多个实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都应当属于本文件保护的范围。
隐私保护机器学习是在使用用户数据进行大数据训练的同时,仍然可以保护用户隐私数据的机器学习方法,通过对数据的挠动或者加密实现数据的可用不可见,主要的方式有通过差分隐私技术保护机器学习中用户的隐私数据,具有移植性强、使用简单等优点。差分隐私训练方法通常是在模型训练过程的每一次迭代中,先计算当前批次内每一个样本的梯度,然后根据固定的阈值C对每一个样本进行裁剪(即对二范数大于C的梯度会裁剪至二范数小于等于C)。之后再对这些梯度求和后添加适当噪声。而由于裁剪和噪声的存在,对梯度进行了无偏挠动,使得隐私训练和非隐私训练之间存在精度差异。另一方面,由于隐私预算(隐私预算评估了该训练方法最多泄漏的隐私的多少)的限制,使得模型无法充分训练,这也是精度存在差异的一个重要原因。基于上述技术问题,本说明书实施例提供一种用于隐私保护的模型训练方法及装置,将中间模型知识蒸馏技术引入差分隐私训练,从而有效地利用模型训练过程中产生的中间模型提升模型训练的精度,下面进行详细说明。
图1是根据本说明书一实施例的一种用于隐私保护的模型训练方法用于隐私保护的模型训练方法的示意性流程图。图2是根据本说明书一实施例的一种用于隐私保护的模型训练方法的示意性原理图。下面结合图1和图2详细说明本说明书实施例提供的一种用于隐私保护的模型训练方法。如图1所示,该方法可以包括:
S102,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型,其中,预设迭代次数小于预设的最大迭代次数。
其中,目标模型为当前待训练模型,目标模型可以是各种用于对私有数据进行保护的神经网络,如卷积神经网络等,本说明书实施例对此不做限定。在实际应用中,目标模型可以是某项业务中使用的模型,例如,目标模型可以是金融领域中风险防控业务中的风险防控模型,或者,目标模型也可以是信息推荐业务中用于向用户推荐指定类型的信息的模型等。
本说明书实施例的中间模型又称检查点(Check Point),是模型训练过程中所产生的模型,比如:模型训练需要进行10000次迭代,那么模型训练中的第K(K<10000)次迭代的模型结果即称为中间模型或者检查点。由于在基于差分隐私的隐私保护过程中,通常默认攻击者是了解全部的中间模型的,也就是这部分中间模型在隐私设置里通常认为是可以被攻击者访问到的,因此,本说明实施例中获取中间模型并不损害用户隐私。
需要说明的是,图2中只是示例性地示出了中间模型的数量,通常中间模型的数量大于2,且为奇数,从而确保在后续利用中间模型的输出结果确定教师模型的最终输出结果时便于计算。中间模型的数量过少,蒸馏训练的效果相对较差,中间模型的数量过多,会增加计算量,从而增加资源占用,实际应用中可以根据目标模型选择合适的数量,比如可以选择5个中间模型。
需要说明的是,针对获取的目标模型在迭代过程中产生的多个中间模型,可以获取目标模型在迭代过程中产生的不连续的预设数量的任意中间模型,也可以获取与当前目标模型输出时间最近的预设数量的中间模型,本说明书实施例对此不做限定。
可选地,预设迭代次数可以为预设的最大迭代次数的一半,当对目标模型进行模型训练的迭代次数超过最大迭代次数的一半时,即开始获取中间模型。该预设迭代次数在模型训练中使用起来更加方便。
S104,以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,其中,第一损失信息由多个中间模型的输出结果与当前训练的目标模型的输出结果确定。
知识蒸馏(即:knowledge distillation)是模型压缩的一种常用的方法,不同于模型压缩中的剪枝和量化,知识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练该小模型,以期达到更好的性能和精度。目前业内的知识蒸馏方法也使用相同架构的网络互相进行知识蒸馏,以提高单一网络的精度。
本说明书实施例中的第一损失信息,可以采用软标签分类损失函数获取,具体可以采用KL(Kullback–Leibler)散度损失函数获取、也可以采用L1-coss(平均绝对误差损失函数)或L2-loss(均方误差损失函数)获取,采用KL散度损失函数能够使归一化后的概率空间中的概率分布更接近,有利于提高模型训练的精度。
S106,根据第一损失信息和第二损失信息确定第一逐样本梯度,其中,第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定。
本说明书实施例中的第二损失信息,可以采用分类模型损失函数获取,具体可以采用交叉熵损失函数获取,也可以采用BCE-loss(二分类交叉熵损失函数)、focal-loss(焦点损失函数)等。采用交叉熵损失函数获取第二损失信息,可以应用于十分类、二分类等不同的分类模型中,应用更方便,有利于进一步提高模型训练的精度和训练效率。第一训练样本可以根据目标模型的不同而设置,例如,目标模型为风险防控模型,则第一训练样本可以是与历史交易事件相关的数据,具体如,交易双方的账号、交易地点、交易时间、交易金额、交易的商品信息、商品的交接方式等,目标模型为用于进行信息推荐的模型,则第一训练样本可以是向用户推荐信息时产生的相关数据等,具体如,推荐信息的地点、时间、推荐的信息内容、推荐的信息类型、推荐的信息的来源等,具体可以根据实际情况设定,本说明书实施例对此不做限定。
可选地,根据第一损失信息和第二损失信息确定第一逐样本梯度的方法,可以是将第一损失信息和第二损失信息相加,得到求和结果后进行向后传播,得到第一逐样本梯度。
S108,对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
在实施中,获取到第一逐样本梯度后,对其使用预设的裁剪阈值进行裁剪,并求和,对求和结果加入适量噪声,获取到差分隐私处理后的第一逐样本梯度,利用该差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,获取到新的目标模型,然后返回步骤S102进入下一次迭代,直到达到预设的模型训练终止条件为止,即可获取到最终的训练后的目标模型。其中,预设的模型训练终止条件可以是达到预设的最大迭代次数,也可以是目标模型收敛。
采用本说明书一个或多个实施例的技术方案,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,通过获取目标模型在迭代过程中产生的多个中间模型,并以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息。并利用第一损失信息和第二损失信息确定第一逐样本梯度,从而对梯度进行更新,并利用更新后的梯度更新目标模型的模型参数,最终获取到训练后的目标模型。这种利用中间模型进行蒸馏训练的方式,能够在不损害用户隐私数据的基础上,将多个中间模型的信息传递给目标模型,从而提升模型训练的精度。另一方面,通过将多个中间模型的信息传递给目标模型,由于中间模型给予目标模型的标签为软标签,该软标签里面暗含了标签之间的关系,该标签之间的关系可以通过知识蒸馏让目标模型提前学习到,因此这种方式也有利于加快模型训练进程。而在有限的隐私预算中,每个样本使用的次数时有限的,加快模型训练进程也就相当于提升模型训练的精度。
在一个实施例中,以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,包括如下步骤:
S1041:将第二训练样本分别输入到多个中间模型中,得到每个中间模型对应的输出数据。
本说明书实施例的应用场景包括支付宝-蚂蚁森林,支付风控等。本说明书实施例中的第一训练样本和第二训练样本,根据实际应用中隐私保护的具体场景中的目标模型而定。以支付风控场景为例,目标模型通常为风险防控模型,则第一训练样本和第二训练样本可以为:交易双方的账号、交易地点、交易时间、交易金额、交易的商品信息、商品的交接方式、交易金额、UID(User Interface Description,用户界面身份识别)、用户画像的打分等。以支付宝-蚂蚁森林应用场景为例,目标模型通常为用于进行信息推荐的模型,第一训练样本和第二训练样本可以是向用户推荐信息时产生的相关数据等,如:推荐信息的地点、时间、推荐的信息内容、推荐的信息类似、推荐的信息来源等。
S1042:基于多个中间模型对应的输出数据,确定教师模型对应的输出结果。
S1043:将第二训练样本输入到目标模型中,得到学生模型对应的输出结果。
S1044:基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息。
可选地,为简化输出结果,以提高模型训练效率,可以对输出结果进行归一化处理。具体地,在S1041中每个中间模型对应的输出数据为归一化处理前数据,即:softmax前的输出数据;在S1042和S1043中获取到输出结果后,分别进行归一化处理,即:对教师模型对应的输出结果进行归一化处理,获取教师模型对应的softmax输出数据,对学习模型对应的输出结果进行归一化处理,获取学生模型对应的softmax输出数据。然后经由S1044,利用教师模型对应的softmax输出数据和学生模型对应的softmax输出数据确定第一损失信息。
在一个实施例中,基于多个中间模型对应的输出数据,确定教师模型对应的输出结果,可以计算多个中间模型的输出数据的平均值,以平均值作为教师模型的输出结果。也可以获取多个中间模型的输出数据中熵最小的输出数据,以获取的熵最小的输出数据作为教师模型的输出结果。
在差分隐私训练中,由于隐私的限制,通常模型无法充分训练。而通常来说,模型在训练过程中可以初步分为两个阶段:第一个阶段模型会去学习一些简单的知识,比如哪个标签更符合当前图片。第二个阶段为模型后期,该阶段模型需要学习各个标签之间的依赖关系和联系。第二个阶段模型学习到的是硬标签,该硬标签很难快速让模型学习到,只能通过使用尽可能大的迭代次数慢慢教会模型。而检查点知识蒸馏正是一种可以很好地解决这个问题的办法。在该实施例中使用知识蒸馏技术时,基于多个中间模型对应的输出数据,确定一个教师模型对应的输出结果,相当于对检查点进行了集成学习,用多个检查点ensemble后的结果来作为教师模型教学生模型,实际熵是一种高精度网络教低精度网络,可以更好地提升网络准确性。通常来说,对于不同阶段的检查点,其识别能力是不同的,有些检查点模型可能更能识别到猫,有些可能对狗敏感,通过这种集成学习的方式,能够很好地对相关专家知识进行整合,从而进一步提高模型训练的精度。
在一个实施例中,基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息,可执行如下步骤A1-A3:
步骤A1,计算教师模型对应的输出结果与预设的第一温度参数之间的比值。
步骤A2,计算学生模型对应的输出结果与预设的第二温度参数之间的比值。
步骤A3,根据计算得到的两个比值确定第一损失信息。
根据以上步骤A1-A3,实际应用中以采用KL散度损失函数获取第一损失信息为例,确定第一损失信息的过程为:
首先,计算所有存储下来的检查点对当前训练样本(即第二训练样本)的输出结果(即softmax前的输出数据);然后,对上述输出结果在检查点维度上求取均值,并除以温度T,之后在softmax输出数据p=[p_0,p_1,…,p_N];其次,计算当前训练样本在当前待更新的模型上除以温度T后的logit(即softmax后的输出数据)q=[q_0,q_1,…,q_N];最后,针对输出结果p=[p_0,p_1,…,p_N]和q=[q_0,q_1,…,q_N],计算其KL散度,KL散度损失函数如下:
通过上述KL散度损失函数计算得出的KL散度值作为第一损失信息。
需要注意的是,该实施例中第一温度参数与第二温度参数可以取相同值,也可以取不同值,本说明书实施例对此不做限定。第一温度参数和第二温度参数为可调节的超参数,其取值越高,经过蒸馏训练所获取的标签关系越准确,但是取值过高会导致参数无法更新,第一温度参数和第二温度参数可以取值2-3。
而且,第一温度参数和第二温度参数的大小与模型验证集的准确率正相关。即:实际应用中,第一温度参数和第二温度参数的大小取决于当前目标模型在验证集上的性能,目前模型在验证集上准确率要求越高,相应地温度参数的取值就要设置的越高一些。
图3是根据本说明书一实施例的另一种用于隐私保护的模型训练方法用于隐私保护的模型训练方法的示意性流程图。如图3所示,该方法可以包括:
S202:判断对目标模型进行模型训练的迭代次数是否超过预设迭代次数。
如果对目标模型进行模型训练的迭代次数超过预设迭代次数,执行步骤S204:获取目标模型在迭代过程中产生的多个中间模型。
S206:以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的目标模型的输出结果确定。
S208:根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定。
S210:对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
如果对目标模型进行模型训练的迭代次数小于预设迭代次数,执行步骤S212:利用第二损失信息确定第二逐样本梯度。
第二损失信息可以采用分类模型损失函数获取,具体可以采用交叉熵损失函数获取,也可以采用BCE-loss(二分类交叉熵损失函数)、focal-loss(焦点损失函数)等。
S210:对第二逐样本梯度做差分隐私处理,并利用差分隐私处理后的第二逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
由以上步骤S202-S212可知,在该实施例中增加了对目标模型进行模型训练的迭代次数小于预设迭代次数的情况。无论对目标模型进行模型训练的次数与预设迭代次数相比的结果是多少,获取到相应的第一逐样本梯度或第二逐样本梯度之后,都先做差分隐私处理,然后利用差分隐私处理后的逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
在一种实施例中,在判断目标模型进行模型训练的迭代次数是否超过预设迭代次数之前,还包括步骤B1和B2:
步骤B1:通过设置用于模型训练的超参数,对目标模型进行初始化。其中,超参数包括:批次数、模型最大迭代次数、隐私预算、裁剪阈值中的一项或多项;
步骤B2:设置多个模型序列,用于存储模型训练过程中产生的中间模型,模型序列的数量与多个中间模型的数量一致。
由以上步骤B1和B2可知,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,以及,对目标模型进行模型训练的迭代次数小于预设迭代次数时,都可以先通过设置超参数对目标模型进行初始化,从而提高模型训练的效率和精度。通过设置多个模型序列,用于存储部分中间模型,便于后续对中间模型进行处理。
示例性地,图4示出了根据本说明书一实施例的又一种用于隐私保护的模型训练方法的示意性流程图。由图4可知,本说明书一个或多个实施例提供的一种用于隐私保护的模型训练方法在实际应用中的训练流程如下:
1)设置超参数,初始化模型。
超参数包括且不限于如下参数中的一项或多项:批次数B、模型最大迭代次数T、隐私预算(\epslion,\detla),裁剪阈值C等。
先将模型进行初始化,例如:可以设置迭代次数t=0,初始化模型x_0。同时准备五个模型序列,用于存储迭代过程中产生的中间模型。
2)判断迭代次数是否到达预设的模型最大迭代次数T,如果到达,则退出循环,如果没到达,则进入下一步。
3)判断迭代次数是否达到预设的最大迭代次数的一半,如果达到,则进入步骤5),否则,进入步骤4)。
4)使用普通的标签计算样本经过神经网络后的前向输出与标签之间的CE_loss(Cross Entropy Loss,交叉熵损失),并反向传播计算第二逐样本梯度,然后进入步骤6)。
5)将最新的五个模型存入检查点中,并计算他们ensemble(模型融合)后的输出数据与当前模型输出数据之间的KL-loss(Kullback–Leibler Divergence,KL散度损失,又称相对熵损失)。使用普通的标签计算样本经过神经网络后的前向输出与标签之间的CE_loss。然后对这两个损失求和后进行反向传播,计算第一逐样本梯度,然后进入步骤6)。
6)对第一逐样本梯度/第二逐样本梯度使用裁剪阈值C进行裁剪,然后求和。对求和后的结果加入适量噪声,然后使用带噪声的梯度作为用于更新模型的梯度进行更新,然后返回步骤2)。
综上,已经对本主题的特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作可以按照不同的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序,以实现期望的结果。在某些实施方式中,多任务处理和并行处理可以是有利的。
图5A是根据本说明书一实施例的一种用于隐私保护的模型训练方法的示意性流程图。如图5A所示,该方法可以包括:
S302:当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取风险防控模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数。
其中,风险防控模型可以通过神经网络构建,具体如通过卷积神经网络构建等,本说明书实施例对此不做限定。
可选地,预设迭代次数可以为预设的最大迭代次数的一半,当对风险防控模型进行模型训练的迭代次数超过最大迭代次数的一半时,即开始获取中间模型。该预设迭代次数在模型训练中使用起来更加方便。
S304:以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的风险防控模型的输出结果确定。
本说明书实施例中的第一损失信息,可以采用软标签分类损失函数获取,具体可以采用KL散度损失函数获取、也可以采用L1-coss(平均绝对误差损失函数)或L2-loss(均方误差损失函数)获取,采用KL散度损失函数能够使归一化后的概率空间中的概率分布更接近,有利于提高模型训练的精度。
S306:根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定。
本说明书实施例中的第二损失信息,可以采用分类模型损失函数获取,具体可以采用交叉熵损失函数获取,也可以采用BCE-loss(二分类交叉熵损失函数)、focal-loss(焦点损失函数)等。采用交叉熵损失函数获取第二损失信息,可以应用于十分类、二分类等不同的分类模型中。第一历史交易事件的特征数据可以包括:交易双方的账号、交易地点、交易时间、交易金额、交易的商品信息、商品的交接方式中的一项或多项。
S308:对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
在一个实施例中,以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,包括如下步骤:
S3041:将第二历史交易事件的特征数据分别输入到多个中间模型中,得到每个中间模型对应的输出数据。
第二历史交易事件的特征数据可以为:交易金额、UID、交易时间、用户画像打分等。
S3042:基于多个中间模型对应的输出数据,确定教师模型对应的输出结果。
S3043:将第二历史交易事件的特征数据输入到风险防控模型中,得到学生模型对应的输出结果。
S3044:基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息。
可选地,本说明书实施例可以对输出结果进行归一化处理。
在一个实施例中,基于多个中间模型对应的输出数据,确定教师模型对应的输出结果,可以计算多个中间模型的输出数据的平均值,以平均值作为教师模型的输出结果。也可以获取多个中间模型的输出数据中熵最小的输出数据,以获取的熵最小的输出数据作为教师模型的输出结果。
在一个实施例中,基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息,可执行如下步骤C1-C3:
步骤C1,计算教师模型对应的输出结果与预设的第一温度参数之间的比值。
步骤C2,计算学生模型对应的输出结果与预设的第二温度参数之间的比值。
步骤C3,根据计算得到的两个比值确定第一损失信息。
而且,第一温度参数和第二温度参数为可调节的超参数,第一温度参数和第二温度参数的大小与模型验证集的准确率正相关。
图5B是根据本说明书一实施例的另一种用于隐私保护的模型训练方法的示意性流程图。如图5B所示,该方法可以包括:
S402:判断对风险防控模型进行模型训练的迭代次数是否超过预设迭代次数。
如果对风险防控模型进行模型训练的迭代次数超过预设迭代次数,执行步骤S204:获取风险防控模型在迭代过程中产生的多个中间模型。
S406:以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的风险防控模型的输出结果确定。
S408:根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定。
S410:对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
如果对风险防控模型进行模型训练的迭代次数小于预设迭代次数,执行步骤S412:利用第二损失信息确定第二逐样本梯度。
第二损失信息可以采用分类模型损失函数获取,具体可以采用交叉熵损失函数获取,也可以采用BCE-loss(二分类交叉熵损失函数)、focal-loss(焦点损失函数)等。
S410:对第二逐样本梯度做差分隐私处理,并利用差分隐私处理后的第二逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
由以上步骤S402-S412可知,在该实施例中增加了对风险防控模型进行模型训练的迭代次数小于预设迭代次数的情况。无论对风险防控模型进行模型训练的次数与预设迭代次数相比的结果是多少,获取到相应的第一逐样本梯度或第二逐样本梯度之后,都先做差分隐私处理,然后利用差分隐私处理后的逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
在一种实施例中,在判断风险防控模型进行模型训练的迭代次数是否超过预设迭代次数之前,还包括步骤D1和D2:
步骤D1:通过设置用于模型训练的超参数,对风险防控模型进行初始化。其中,超参数包括:批次数、模型最大迭代次数、隐私预算、裁剪阈值中的一项或多项;
步骤D2:设置多个模型序列,用于存储模型训练过程中产生的中间模型,模型序列的数量与多个中间模型的数量一致。
需要说明的是,为避免重复,图5A和图5B所示的实施例中,相关步骤的具体实现与图1至图4所示实施例类似,不再另外举例说明,未详细描述的部分可以参见图1-图4所示的实施例。
以上为本说明书一个或多个实施例提供的一种用于隐私保护的模型训练方法,基于同样的思路,本说明书一个或多个实施例还提供一种用于隐私保护的模型训练装置。
图6为根据本说明书一实施例的一种用于隐私保护的模型训练装置的示意图。由图6可知,该装置可以包括:
中间模型获取模块510,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
蒸馏训练模块520,以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的目标模型的输出结果确定;
第一梯度计算模块530,根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定;
模型参数更新模块540,对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
在一个实施例中,蒸馏训练模块520包括:
中间模型输出数据获取单元,用于将第二训练样本分别输入到多个中间模型中,得到每个中间模型对应的输出数据;
教师模型输出结果确定单元,用于基于多个中间模型对应的输出数据,确定教师模型对应的输出结果;
学生模型输出结果确定单元,用于将第二训练样本输入到目标模型中,得到学生模型对应的输出结果;
第一损失信息确定单元,用于基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息。
在一个实施例中,本说明书一实施例的一种用于隐私保护的模型训练装置中还包括有:第二损失信息确定模块,用于当对目标模型进行模型训练的迭代次数小于预设迭代次数时,利用第二损失信息确定第二逐样本梯度。
模型参数更新模块540,还用于对第二逐样本梯度做差分隐私处理,并利用差分隐私处理后的第二逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
在一个实施例中,本说明书一实施例的一种用于隐私保护的模型训练装置中还包括有:初始化模块,用于通过设置用于模型训练的超参数,对目标模型进行初始化,超参数包括:批次数、模型最大迭代次数、隐私预算、裁剪阈值中的一项或多项,以及,设置多个模型序列,用于存储模型训练过程中产生的中间模型,其中,模型序列的数量与多个中间模型的数量一致。
采用本说明书一个或多个实施例的装置,通过检测对目标模型进行模型训练的迭代次数,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型,由于中间模型给予目标模型的是标签是软标签,软标签中暗含有标签之间的关系,从而能够在不损害隐私数据的前提下,充分利用中间模型提高模型训练的精度。然后以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息。采用知识蒸馏方式,能够很好地将中间模型的优势充分利用起来,将多个中间模型的信息通过知识蒸馏技术传递给当前训练的目标模型,有利于提高模型训练的精度。而且,所采用的是多教师模型的知识蒸馏方式,在使用知识蒸馏技术的同时,对多个中间模型进行集成学习,从而可以更好地对不同中间模型的专家知识进行整合,再传递给当前训练的目标模型,有利于进一步提高模型训练的精度。对多个中间模型的再根据第一损失信息和第二损失信息确定第一逐样本梯度,且第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定。最后对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。通过对第一逐样本梯度进行更新,从而更新模型参数,进而达到更新模型的目的。
在一种实施例中,根据本说明书一实施例的又一种用于隐私保护的模型训练装置可以包括:
中间模型获取模块,当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取风险防控模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
蒸馏训练模块,以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的风险防控模型的输出结果确定;
第一梯度计算模块,根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定;
模型参数更新模块,对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
在一个实施例中,蒸馏训练模块包括:
中间模型输出数据获取单元,用于将第二历史交易事件的特征数据分别输入到多个中间模型中,得到每个中间模型对应的输出数据;
教师模型输出结果确定单元,用于基于多个中间模型对应的输出数据,确定教师模型对应的输出结果;
学生模型输出结果确定单元,用于将第二历史交易事件的特征数据输入到风险防控模型中,得到学生模型对应的输出结果;
第一损失信息确定单元,用于基于教师模型对应的输出结果和学生模型对应的输出结果,确定第一损失信息。
在一个实施例中,本说明书一实施例的一种用于隐私保护的模型训练装置中还包括有:第二损失信息确定模块,用于当对风险防控模型进行模型训练的迭代次数小于预设迭代次数时,利用第二损失信息确定第二逐样本梯度。
模型参数更新模块,还用于对第二逐样本梯度做差分隐私处理,并利用差分隐私处理后的第二逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
在一个实施例中,本说明书一实施例的又一种用于隐私保护的模型训练装置中还包括有:初始化模块,用于通过设置用于模型训练的超参数,对风险防控模型进行初始化,超参数包括:批次数、模型最大迭代次数、隐私预算、裁剪阈值中的一项或多项,以及,设置多个模型序列,用于存储模型训练过程中产生的中间模型,其中,模型序列的数量与多个中间模型的数量一致。
采用本说明书一个或多个实施例的装置,通过检测对风险防控模型进行模型训练的迭代次数,当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取风险防控模型在迭代过程中产生的多个中间模型,由于中间模型给予风险防控模型的是标签是软标签,软标签中暗含有标签之间的关系,从而能够在不损害隐私数据的前提下,充分利用中间模型提高模型训练的精度。然后以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息。采用知识蒸馏方式,能够很好地将中间模型的优势充分利用起来,将多个中间模型的信息通过知识蒸馏技术传递给当前训练的风险防控模型,有利于提高模型训练的精度。而且,所采用的是多教师模型的知识蒸馏方式,在使用知识蒸馏技术的同时,对多个中间模型进行集成学习,从而可以更好地对不同中间模型的专家知识进行整合,再传递给当前训练的风险防控模型,有利于进一步提高模型训练的精度。对多个中间模型的再根据第一损失信息和第二损失信息确定第一逐样本梯度,且第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定。最后对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。通过对第一逐样本梯度进行更新,从而更新模型参数,进而达到更新模型的目的。
本领域的技术人员应可理解,上述用于隐私保护的模型训练装置能够用来实现前文所述的用于隐私保护的模型训练方法,其中的细节描述应与前文方法部分描述类似,为避免繁琐,此处不另赘述。
基于同样的思路,本说明书一个或多个实施例还提供一种电子设备,如图7所示。电子设备可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上的处理器601和存储器602,存储器602中可以存储有一个或一个以上存储应用程序或数据。其中,存储器602可以是短暂存储或持久存储。存储在存储器602的应用程序可以包括一个或一个以上模块(图示未示出),每个模块可以包括对电子设备中的一系列计算机可执行指令。更进一步地,处理器601可以设置为与存储器602通信,在电子设备上执行存储器602中的一系列计算机可执行指令。电子设备还可以包括一个或一个以上电源603,一个或一个以上有线或无线网络接口604,一个或一个以上输入输出接口605,一个或一个以上键盘606。
具体在本实施例中,电子设备包括有存储器,以及一个或一个以上的程序,其中一个或者一个以上程序存储于存储器中,且一个或者一个以上程序可以包括一个或一个以上模块,且每个模块可以包括对电子设备中的一系列计算机可执行指令,且经配置以由一个或者一个以上处理器执行该一个或者一个以上程序包含用于进行以下计算机可执行指令:
当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的目标模型的输出结果确定;
根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定;
对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
本说明书一个或多个实施例还提供另一种电子设备。电子设备可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上的处理器和存储器,存储器中可以存储有一个或一个以上存储应用程序或数据。其中,存储器可以是短暂存储或持久存储。存储在存储器的应用程序可以包括一个或一个以上模块,每个模块可以包括对电子设备中的一系列计算机可执行指令。更进一步地,处理器可以设置为与存储器通信,在电子设备上执行存储器中的一系列计算机可执行指令。电子设备还可以包括一个或一个以上电源,一个或一个以上有线或无线网络接口,一个或一个以上输入输出接口,一个或一个以上键盘。
具体在本实施例中,电子设备包括有存储器,以及一个或一个以上的程序,其中一个或者一个以上程序存储于存储器中,且一个或者一个以上程序可以包括一个或一个以上模块,且每个模块可以包括对电子设备中的一系列计算机可执行指令,且经配置以由一个或者一个以上处理器执行该一个或者一个以上程序包含用于进行以下计算机可执行指令:
当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取风险防控模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的风险防控模型的输出结果确定;
根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定;
对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
本说明书一个或多个实施例还给出了一种存储介质,用于存储计算机程序,该计算机程序能够被处理器执行以实现以下流程:
当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取目标模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
以多个中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的目标模型的输出结果确定;
根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一训练样本输入到目标模型中得到的输出结果与第一训练样本的标签确定;
对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
本说明书一个或多个实施例还给出了另一种存储介质,用于存储计算机程序,该计算机程序能够被处理器执行以实现以下流程:
当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取风险防控模型在迭代过程中产生的多个中间模型,预设迭代次数小于预设的最大迭代次数;
以多个中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过教师模型,采用知识蒸馏的方式对学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,第一损失信息由多个中间模型的输出结果与当前训练的风险防控模型的输出结果确定;
根据第一损失信息和第二损失信息确定第一逐样本梯度,第二损失信息由第一历史交易事件的特征数据输入到风险防控模型中得到的输出结果与第一历史交易事件的特征数据的标签确定;
对第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
上述实施例阐明的装置、模块或单元,具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现。一种典型的实现设备为计算机。具体的,计算机例如可以为个人计算机、膝上型计算机、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
为了描述的方便,描述以上装置时以功能分为各种单元分别描述。当然,在实施本说明书一个或多个实施例时可以把各单元的功能在同一个或多个软件和/或硬件中实现。
本领域内的技术人员应明白,本说明书一个或多个实施例可提供为方法、系统、或计算机程序产品。因此,本说明书一个或多个实施例可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本说明书一个或多个实施例可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本说明书一个或多个实施例是参照根据本说明书实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
在一个典型的配置中,计算设备包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
内存可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。内存是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁待,磁待磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、商品或者设备中还存在另外的相同要素。
本说明书一个或多个实施例可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本说明书,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本说明书一个或多个实施例而已,并不用于限制本申请。对于本领域技术人员来说,本说明书一个或多个实施例可以有各种更改和变化。凡在本说明书一个或多个实施例的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本说明书一个或多个实施例的权利要求范围之内。
Claims (12)
1.一种用于隐私保护的模型训练方法,包括:
当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;
以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;
根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;
对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
2.根据权利要求1所述的方法,所述以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,包括:
将第二训练样本分别输入到多个所述中间模型中,得到每个中间模型对应的输出数据;
基于多个所述中间模型对应的输出数据,确定所述教师模型对应的输出结果;
将所述第二训练样本输入到所述目标模型中,得到所述学生模型对应的输出结果;
基于所述教师模型对应的输出结果和所述学生模型对应的输出结果,确定第一损失信息。
3.根据权利要求2所述的方法,所述基于所述教师模型对应的输出结果和所述学生模型对应的输出结果,确定第一损失信息,包括:
计算所述教师模型对应的输出结果与预设的第一温度参数之间的比值;
计算所述学生模型对应的输出结果与预设的第二温度参数之间的比值;
根据计算得到的两个比值确定第一损失信息。
4.根据权利要求2所述的方法,所述基于多个所述中间模型对应的输出数据,确定所述教师模型对应的输出结果,包括:
计算多个所述中间模型的输出数据的平均值,以所述平均值作为教师模型的输出结果,或者,获取多个所述中间模型的输出数据中熵最小的输出数据,以获取的熵最小的输出数据作为教师模型的输出结果。
5.根据权利要求3所述的方法,所述第一温度参数和第二温度参数为可调节的超参数,所述第一温度参数和第二温度参数的大小与模型验证集的准确率正相关。
6.根据权利要求1所述的方法,所述第一损失信息利用KL散度损失函数获取,所述第二损失信息利用交叉熵损失函数获取。
7.根据权利要求1所述的方法,所述预设迭代次数为预设的最大迭代次数的一半。
8.根据权利要求1所述的方法,所述方法还包括:
当对目标模型进行模型训练的迭代次数小于预设迭代次数时,利用第二损失信息确定第二逐样本梯度;
对所述第二逐样本梯度做差分隐私处理,并利用差分隐私处理后的第二逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
9.根据权利要求1-8中任一所述的方法,所述方法还包括:
通过设置用于模型训练的超参数,对目标模型进行初始化,所述超参数包括:批次数、模型最大迭代次数、隐私预算、裁剪阈值中的一项或多项;
设置多个模型序列,用于存储模型训练过程中产生的中间模型,所述模型序列的数量与多个中间模型的数量一致。
10.一种用于隐私保护的模型训练方法,包括:
当对风险防控模型进行模型训练的迭代次数超过预设迭代次数时,获取所述风险防控模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;
以多个所述中间模型作为教师模型,以当前训练的风险防控模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的风险防控模型的输出结果确定;
根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一历史交易事件的特征数据输入到所述风险防控模型中得到的输出结果与所述第一历史交易事件的特征数据的标签确定;
对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的风险防控模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的风险防控模型。
11.一种用于隐私保护的模型训练装置,包括:
中间模型获取模块,当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;
蒸馏训练模块,以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;
第一梯度计算模块,根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;
模型参数更新模块,对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
12.一种电子设备,包括:
处理器;以及
被安排成存储计算机可执行指令的存储器,在所述可执行指令被执行时,能够使得所述处理器:
当对目标模型进行模型训练的迭代次数超过预设迭代次数时,获取所述目标模型在迭代过程中产生的多个中间模型,所述预设迭代次数小于预设的最大迭代次数;
以多个所述中间模型作为教师模型,以当前训练的目标模型作为学生模型,通过所述教师模型,采用知识蒸馏的方式对所述学生模型进行蒸馏训练,并确定蒸馏训练中的第一损失信息,所述第一损失信息由多个所述中间模型的输出结果与当前训练的目标模型的输出结果确定;
根据所述第一损失信息和第二损失信息确定第一逐样本梯度,所述第二损失信息由第一训练样本输入到所述目标模型中得到的输出结果与所述第一训练样本的标签确定;
对所述第一逐样本梯度做差分隐私处理,并利用差分隐私处理后的第一逐样本梯度更新当前训练的目标模型的模型参数,直到达到预设的模型训练终止条件为止,得到训练后的目标模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310581293.6A CN116720214A (zh) | 2023-05-22 | 2023-05-22 | 一种用于隐私保护的模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310581293.6A CN116720214A (zh) | 2023-05-22 | 2023-05-22 | 一种用于隐私保护的模型训练方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116720214A true CN116720214A (zh) | 2023-09-08 |
Family
ID=87870656
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310581293.6A Pending CN116720214A (zh) | 2023-05-22 | 2023-05-22 | 一种用于隐私保护的模型训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116720214A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117972436A (zh) * | 2024-03-29 | 2024-05-03 | 蚂蚁科技集团股份有限公司 | 大语言模型的训练方法、训练装置、存储介质及电子设备 |
-
2023
- 2023-05-22 CN CN202310581293.6A patent/CN116720214A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117972436A (zh) * | 2024-03-29 | 2024-05-03 | 蚂蚁科技集团股份有限公司 | 大语言模型的训练方法、训练装置、存储介质及电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108920654B (zh) | 一种问答文本语义匹配的方法和装置 | |
US20230222353A1 (en) | Method and system for training a neural network model using adversarial learning and knowledge distillation | |
CN111415015B (zh) | 业务模型训练方法、装置、系统及电子设备 | |
CN114429222A (zh) | 一种模型的训练方法、装置及设备 | |
CN113435585A (zh) | 一种业务处理方法、装置及设备 | |
CN116720214A (zh) | 一种用于隐私保护的模型训练方法及装置 | |
CN114548300B (zh) | 解释业务处理模型的业务处理结果的方法和装置 | |
CN114090401B (zh) | 处理用户行为序列的方法及装置 | |
CN111353514A (zh) | 模型训练方法、图像识别方法、装置及终端设备 | |
US20230162018A1 (en) | Dimensional reduction of correlated vectors | |
CN111709415A (zh) | 目标检测方法、装置、计算机设备和存储介质 | |
CN117349899B (zh) | 基于遗忘模型的敏感数据处理方法、系统及存储介质 | |
CN113221717A (zh) | 一种基于隐私保护的模型构建方法、装置及设备 | |
CN116152542A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
CN115018608A (zh) | 风险预测方法、装置、计算机设备 | |
CN114186039A (zh) | 一种视觉问答方法、装置及电子设备 | |
CN114387480A (zh) | 一种人像加扰的方法及装置 | |
US12112524B2 (en) | Image augmentation method, electronic device and readable storage medium | |
US20230019779A1 (en) | Trainable differential privacy for machine learning | |
CN118312873B (zh) | 基于文本识别的灾害预测方法及系统 | |
CN114168799B (zh) | 图数据结构中节点邻接关系的特征获取方法、装置及介质 | |
CN113379062B (zh) | 用于训练模型的方法和装置 | |
US11983152B1 (en) | Systems and methods for processing environmental, social and governance data | |
CN112819177B (zh) | 一种个性化的隐私保护学习方法、装置以及设备 | |
US20240303548A1 (en) | Method for collaborative machine learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |