CN116578400A - 多任务数据处理方法和装置 - Google Patents

多任务数据处理方法和装置 Download PDF

Info

Publication number
CN116578400A
CN116578400A CN202310535445.9A CN202310535445A CN116578400A CN 116578400 A CN116578400 A CN 116578400A CN 202310535445 A CN202310535445 A CN 202310535445A CN 116578400 A CN116578400 A CN 116578400A
Authority
CN
China
Prior art keywords
model
multitasking
student
teacher
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310535445.9A
Other languages
English (en)
Inventor
李昂
周俊
张晓露
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Alipay Hangzhou Information Technology Co Ltd
Original Assignee
Alipay Hangzhou Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Alipay Hangzhou Information Technology Co Ltd filed Critical Alipay Hangzhou Information Technology Co Ltd
Priority to CN202310535445.9A priority Critical patent/CN116578400A/zh
Publication of CN116578400A publication Critical patent/CN116578400A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F9/00Arrangements for program control, e.g. control units
    • G06F9/06Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
    • G06F9/46Multiprogramming arrangements
    • G06F9/48Program initiating; Program switching, e.g. by interrupt
    • G06F9/4806Task transfer initiation or dispatching
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

多任务数据处理方法和装置。一种利用多任务教师模型对多任务学生模型进行训练的方法,多任务教师模型和多任务学生模型均包括共享模型和位于共享模型下游的、用于分别执行多个任务的多个子任务模型,该方法包括:针对每个样本,确定多任务教师模型的共享模型的输出与多任务学生模型的共享模型的输出之间的相似度;利用相似度对蒸馏损失函数进行加权,蒸馏损失函数用于表征多任务教师模型的多个子任务模型各自输出与多任务学生模型的多个子任务模型各自输出之间的差异;根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整。

Description

多任务数据处理方法和装置
技术领域
本申请一般涉及互联网领域,尤其涉及互联网中的多任务数据处理方和装置。
背景技术
随着互联网技术的发展,人们越来越频繁地浏览网络平台推送的信息。用户对网络平台推送信息会进行点击、转化、收藏、关注等行为。需要对用户的这些行为的相关数据进行处理以提升网络推送的效率。
在信息推送中,需要考虑多个目标(例如,点击率、转化率等)。现有技术针对多个目标分开处理数据,导致计算设备的处理量很大,浪费了设备的处理资源。
因此亟需高效地针对多个目标处理数据以进行信息推送的方案。
发明内容
为解决上述技术问题,本公开提供了一种利用多任务教师模型对多任务学生模型进行训练的方法,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述方法包括:针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度;利用所述相似度对蒸馏损失函数进行加权,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整。
可任选地,所述方法进一步包括:使用所述多任务教师模型的多个子任务模型的多个输出、所述多任务学生模型的多个子任务模型的多个输出和所述相似度来确定总损失函数;以及根据所述总损失函数来对所述多任务学生模型的参数进行调整。
可任选地,所述总损失函数进一步根据交叉熵损失函数来确定,其中所述交叉熵损失函数基于所述多任务学生模型针对所述多个任务的多个输出和对应于所述训练样本集的硬标签向量集来确定。
可任选地,所述总损失函数L如下确定:
L=LCE+λLKD,
其中
其中LCE为交叉熵损失函数,LKD为所述蒸馏损失函数,yij是所述训练样本集中的第i个训练样本针对第j个任务的硬标签,Toutij是所述多任务教师模型关于第i个训练样本针对第j个任务的输出,Soutij是所述多任务学生模型关于第i个训练样本针对第j个任务的输出,simi是所述多任务教师模型和所述多任务学生模型关于第i个训练样本的输出的相似度,M是所述多任务教师模型和所述多任务学生模型的任务数目,N是所述训练样本集中的样本数目。
可任选地,所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出是相同维度的向量;所述相似度是所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的余弦相似度。
可任选地,所述多任务教师模型的共享模型和所述多任务学生模型的共享模型是深度神经网络(DNN),并且所述多任务教师模型的共享模型和所述多任务学生模型的共享模型的层数相同。
可任选地,所述多个任务包括预测业务的点击率和预测广告的转化率,所述训练样本集包括关于所述业务的用户侧特征和物品侧特征。
本公开的另一方面提供了一种利用多任务教师模型对多任务学生模型进行训练的装置,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述装置包括:用于针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度的模块;用于利用所述相似度对蒸馏损失函数进行加权的模块,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;用于根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整的模块。
可任选地,该装置进一步包括:用于使用所述多任务教师模型的多个子任务模型的多个输出、所述多任务学生模型的多个子任务模型的多个输出和所述相似度来确定总损失函数的模块;以及用于根据所述总损失函数来对所述多任务学生模型的参数进行调整的模块。
可任选地,所述总损失函数进一步根据交叉熵损失函数来确定,其中所述交叉熵损失函数基于所述多任务学生模型针对所述多个任务的多个输出和对应于所述训练样本集的硬标签向量集来确定。
可任选地,所述总损失函数L如下确定:
L=LCE+λLKD,
其中
其中LCE为交叉熵损失函数,LKD为所述蒸馏损失函数,yij是所述训练样本集中的第i个训练样本针对第j个任务的硬标签,Toutij是所述多任务教师模型关于第i个训练样本针对第j个任务的输出,Soutij是所述多任务学生模型关于第i个训练样本针对第j个任务的输出,simi是所述多任务教师模型和所述多任务学生模型关于第i个训练样本的输出的相似度,M是所述多任务教师模型和所述多任务学生模型的任务数目,N是所述训练样本集中的样本数目。
可任选地,所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出是相同维度的向量;所述相似度是所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的余弦相似度。
可任选地,所述多任务教师模型的共享模型和所述多任务学生模型的共享模型是深度神经网络(DNN),并且所述多任务教师模型的共享模型和所述多任务学生模型的共享模型的层数相同。
可任选地,所述多个任务包括预测业务的点击率和预测广告的转化率,所述训练样本集包括关于所述业务的用户侧特征和物品侧特征。
本公开的又一方面提供了一种利用多任务教师模型对多任务学生模型进行训练的装置,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述装置包括:处理器;以及被安排成存储计算机可执行指令的存储器,所述可执行指令在被执行时使所述处理器执行以下操作:针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度;利用所述相似度对蒸馏损失函数进行加权,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整。
本公开的技术方案通过使用多任务教师模型对多任务学生模型进行知识蒸馏来训练多任务学生模型。相比于现有技术中针对多个任务使用分开的模型进行处理,本公开的技术方案能够有效地节省电子设备的计算资源。进一步,本公开的技术方案可以并行地对多个业务(例如,点击率、转化率等)进行预测,能够提高用户的使用体验。
附图说明
图1是根据本公开的各方面的用于多任务数据处理的装置的示图。
图2是根据本公开的各方面的教师模型和学生模型的示图。
图3是根据本公开的各方面的训练模块的示图。
图4是根据本公开的各方面的训练多任务模型的流程图。
图5是根据本公开的各方面的利用多任务教师模型对多任务学生模型进行训练的方法的流程图。
具体实施方式
为让本公开的上述目的、特征和优点能更明显易懂,以下结合附图对本公开的具体实施方式作详细说明。
在下面的描述中阐述了很多具体细节以便于充分理解本公开,但是本公开还可以采用其它不同于在此描述的其它方式来实施,因此本公开不受下面公开的具体实施例的限制。
神经网络模型可被用于对用户对网络推送信息(例如,广告)的行为进行预测。例如,神经网络模型可被用于确定业务(例如,广告)的点击率(click-through rate,CTR)和转化率(conversion rate,CVR),其中点击率是指在一个统计周期内,特定广告被用户点击的次数除以该广告被展示的总数,即,点击率=(点击量/展示量)×100%;转化率是指在一个统计周期内,广告相关商品被购买的次数除以该广告被点击的总数,即,点击转化率=(投保量/点击量)×100%。
为了方便理解,本公开以预测点击率和转化率作为任务的示例进行解说,但本领域技术人员将领会,其他的任务也在本公开的构想中,例如,完成特定操作的时间等等。
本公开提出了一种使用自蒸馏方式来训练多任务(多目标)神经网络模型的方案,其中每个任务拟合对应的一个目标。多任务神经网络模型可以同时处理多个任务。例如,该多个任务可包括预测点击率、转化率等等。
知识蒸馏指的是将预先训练好的教师模型的知识通过蒸馏的方式迁移至学生模型。一般来说,教师模型会比学生模型网络容量更大,模型结构更复杂。学生模型的精度接近于教师模型,但结构简单,可被用于后续的预测操作,由此可以简化计算复杂度,提高计算机的处理效率。在学生模型的训练过程中可以使用训练好的教师模型所生成的软标签。自蒸馏是知识蒸馏的一种,在自蒸馏中,在教师模型和学生模型中分别设置结构相同的神经网络模型。
在现有技术中,通过自蒸馏方式,一般一个教师/学生模型仅用于一个任务。而本申请将自蒸馏方式应用于多任务模型,将一个教师/学生模型用于多个任务,使得模型同时拟合多个目标,进一步节省了计算机设备(例如,处理器)的计算资源。
进一步,在本公开中,在教师模型和学生模型中分别设置结构相同的神经网络模型,获取这两个结构相同的神经网络模型针对相同训练样本的输出向量之间的相似度,并且将该相似度应用于损失函数以用于学生模型的训练,由此使得学生模型的精度更接近于教师模型。
图1是根据本公开的各方面的用于多任务数据处理的装置的示图。
如图1所示,用于多任务数据处理的装置100可包括教师模型102和学生模型104。教师模型102和学生模型104均是多任务(多目标)模型,具有针对多个任务的多个输出。
在一方面,教师模型102和学生模型104接收相同的样本数据(例如,训练样本数据)。
样本数据可包括用户(user)侧特征和物品(item)侧特征。例如,在样本数据表征用户对特定广告的点击和转化的情形中,用户侧特征可包括用户的相关特征,例如,出生日期、收益范围、所在地区等等;item侧特征可包括广告的标识符(ID)、广告的商家的标识符(ID)、广告所属行业等等。
进一步,教师模型102和学生模型104可对样本数据进行处理,生成分别对应于多个任务的多个输出。例如,教师模型102和学生模型104可输出相同维度的输出向量,输出向量中的每个元素是对应于一个任务的输出。在图1的示例中,针对每个输入数据样本,教师模型102输出向量[Tout1,Tout 2,…ToutM],其中Touti是教师模型102针对第一i个任务的输出。学生模型104输出向量[Sout1,Sout 2,…SoutM],其中Souti是学生模型104针对第一i个任务的输出。
教师模型102可经过预先训练。教师模型102可以比学生模型104的网络容量更大,模型结构更复杂。
作为一个示例,装置100可以处理两个任务:第一任务CTR和第二任务CVR。在该示例中,教师模型102可以具有两个输出Tout1和Tout2,其中Tout1对应于CTR,而Tout2对应于CVR。同样,学生模型104可以具有两个输出Sout1和Sout2,其中Sout1对应于CTR,而Sout2对应于CVR。
教师模型102可包括教师共享模型106,并且学生模型104可包括学生共享模型108。教师共享模型106和学生共享模型108的结构相同。举例而言,教师共享模型106和学生共享模型108可以均为DNN(深度神经网络)模型,其层数、每层的节点个数、以及各个节点之间的连接的结构都相同。
教师共享模型106和学生共享模型108的输出为相同维度的向量。
装置100还可以包括相似度模块110,其接收教师共享模型106针对一输入样本的输出向量TM和学生共享模型108针对同一输入样本的输出向量SM,并确定两者之间的相似度sim。在一个示例中,该相似度sim可以是教师共享模型106的输出向量TM和学生共享模型108的输出向量SM之间的余弦相似度。
训练模块112可接收来自教师模型102的输出向量[Tout1,Tout2,…ToutM]、来自学生模型104的输出向量[Sout1,Sout2,…SoutM]、来自相似度模块110的相似度sim、以及与输入样本相对应的标签向量Y。标签向量Y=[y1,y2,...yM]为M维向量(维度与教师模型102和学生模型104的输出相同),其中每个元素对应于一个任务,例如,yi可对应于第i个任务。在本公开中,标签向量Y可被称为硬标签或真实标签,教师模型102的输出向量[Tout1,Tout 2,…ToutM]可被称为软标签。
在第一任务为CTR并且第二任务为CVR的示例中,每个真实标签向量可以是二维向量[y1,y2],其中y1对应于第一任务CTR,例如,如果样本的用户点击了广告,则y1=1,否则y1=0;y2对应于第二任务CVR,例如,如果样本的用户转化了广告(例如,购买了广告的商品),则y2=1,否则y2=0。
请注意,虽然图1中仅示出了教师模型102包括教师共享模型106,但教师模型102还可包括其他模块,例如,embedding(嵌入)层以及其他神经网络模型。同样,学生模型104除了学生共享模型108之外也可包括其他模块,例如,embedding层以及其他神经网络模型。
图2是根据本公开的各方面的教师模型202和学生模型204的示图。在图2的示例中,教师模型202和学生模型204被用于两个任务。
教师模型202可包括embedding层、教师共享模型和分别用于两个任务的DNN,其中DNNT1用于第一任务,DNNT2用于第二任务。教师共享模型可从embedding层接收输入,并将输出向量TM提供给DNNT1和DNNT2,以及提供给相似度模块210。
同样,学生模型204可包括embedding层、学生共享模型和分别用于两个任务的DNN,其中DNNS1用于第一任务,DNNS2用于第二任务。学生共享模型可从embedding层接收输入,并将输出向量SM提供给DNNS1和DNNS2,以及提供给相似度模块210。
回到图1,训练模块112使用损失函数来确定学生模型104的输出与预期输出(例如,硬标签、软标签)之间的差异(或即,损失值),并使用损失值通过反向传播来更新学生模型104的各个参数(例如,DNN模型的阈值和权重),使得学生模型104的输出向预期输出收敛,从而达到训练学生模型的目的。
如图3所示,本公开的训练模块112可包括蒸馏损失信息模块304、交叉熵损失信息模块302和总损失信息模块306。
交叉熵损失信息模块302可接收来自学生模型的多个输出[Sout1,Sout2,…SoutM]和真实标签向量Y=[y1,y2,...yM],并且可确定交叉熵损失值。
例如,交叉熵损失值可以如下确定:
其中yij是训练样本集中的第i个训练样本针对第j个任务的真实标签(硬标签),Soutij是学生模型关于第i个训练样本针对第j个任务的输出,M是教师模型和学生模型的任务数目,N是训练样本集中的样本数目。
蒸馏损失信息模块304可确定学生模型与教师模型的输出之间的损失值。
例如,蒸馏损失值LKD可以使用下式确定:
优选地,在本公开中,蒸馏损失信息模块304还可以使用学生模型和教师模型的中间特征(例如,如上所述的TM和SM)之间的相似度对学生模型与教师模型之间的损失值加权。
具体而言,蒸馏损失信息模块304可接收来自学生模型104的多个输出[Sout1,Sout2,…SoutM]、来自教师模型102的多个输出[Tout1,Tout2,…ToutM]、以及来自相似度模块110的相似度sim,并且可确定蒸馏损失值。
例如,蒸馏损失值LKD可以如下确定:
其中,N是训练样本集中的样本数目,simi是教师模型和学生模型关于第i个训练样本的输出的相似度,Toutij是教师模型关于第i个训练样本针对第j个任务的输出,Soutij是学生模型关于第i个训练样本针对第j个任务的输出,M是教师模型和学生模型的任务数目。
在本公开的一方面,可以利用教师模型和学生模型的中间特征的相似度sim对学生模型与教师模型之间的损失值加权。具体而言,可以对中间特征向量相似度高的样本的蒸馏损失分量(即,式(1)中的)增大权重,而对中间特征向量相似度低的样本的蒸馏损失分量减小权重,从而提高了使用知识蒸馏训练学生模型的性能,由此提高了经训练的学生模型在多任务预测中的准确度。
具体而言,如果教师模型和学生模型关于样本1的中间特征的相似度sim较低,则说明样本1的教师模型能给学生模型带来的知识较少,期望样本1的蒸馏损失分量被给予较小的权重,由此反向传播带来的梯度也会较小,对学生模型更新的影响也就会小。相反,如果教师模型和学生模型关于样本2的中间特征的相似度sim较高,则说明样本2的教师模型能给学生模型带来的知识较多,期望样本2的蒸馏损失分量被给予较高的权重,对学生模型更新的影响也就会大,并且在对学生模型的训练过程中需要令蒸馏损失分量变得更小才能达到训练目标,由此提高了学生模型的鲁棒性。
总损失信息模块306可根据上述蒸馏损失值和交叉熵损失值来确定总损失值。
具体而言,总损失值可如下确定:
L=LCE+λLKD(4)
其中λ的值可以根据历史经验值来确定。
训练模块112可以调整学生模型104的参数,迭代地确定总损失值,以使得总损失值收敛,从而达到训练学生模型104的目标。
图4是根据本公开的各方面的训练多任务模型的流程图。其中教师模型和学生模型都是多任务模型,具有针对多个任务的输出。
如图4所示,在步骤402,可以预先训练教师模型。
首先可以使用有监督训练来训练教师模型。例如,使用训练样本集以及对应于多个任务的标签向量集(硬标签向量集)来训练教师模型。
在步骤404,可以获取训练样本集及其硬(真实)标签向量集。
训练样本集中的每个样本与硬标签向量集中的一个标签向量相对应。标签向量的维度与模型的任务数目相等,标签向量的每个元素是对应样本关于一个任务的标签。
举例而言,如果教师模型和学生模型处理两个任务:第一任务CTR和第二任务CVR,则每个训练样本可对应于一个二维硬标签向量:[y1,y2],其中y1对应于第一任务CTR,例如,如果该样本的用户点击了广告,则y1=1,否则y1=0;y2对应于第二任务CVR,例如,如果该样本的用户转化了广告(例如,购买了广告的商品),则y2=1,否则y2=0。
步骤404中所使用的训练样本集及其硬标签向量集可以与步骤402中用于训练教师模型的训练样本集及其硬标签向量集相同或不同。
训练样本集可被输入学生模型和经预先训练的教师模型,硬标签向量集也可被输入以供在确定损失值时使用。
在步骤406,可以确定教师模型和学生模型的中间特征的相似度。
具体而言,教师模型和学生模型可包括相同的共享模型,例如,如图1中所示的教师共享模型106和学生共享模型108。
针对每个输入的训练样本,可确定教师共享模型106和学生共享模型108输出的中间特征向量(例如,TM和SM)之间的相似度sim。
教师共享模型106和学生共享模型108输出的中间特征向量的维度是相同的,并且它们之间的相似度可以是两个向量之间的余弦相似度。
请注意,步骤406以虚线示出,表示该步骤是可任选的。
在步骤408,可以获取教师模型和学生模型的输出。
教师模型和学生模型可分别具有多个输出或即包括多个元素的输出向量,其中每个输出或输出向量对应于一个任务。
在步骤410,可以确定教师模型和学生模型之间的总损失值。
可以根据教师模型的输出、学习模型的输出、真实标签、以及可任选的相似度来确定教师模型和学生模型之间的总损失值。
具体而言,总损失值L可以根据交叉熵损失值LCE和蒸馏损失值LKD来确定。例如,总损失值L可以是交叉熵损失值LCE和蒸馏损失值LKD的加权和,如上所述的式(4)所示。交叉熵损失值LCE可以表示学生模型的输出与真实标签之间的差异(损失值)。例如,交叉熵损失值LCE可以使用如上所述的式(1)来确定。
蒸馏损失值LKD可以表示学生模型的多个输出与教师模型的多个输出之间的差异(损失值)。差异越大,蒸馏损失函数越大。例如,蒸馏损失值LKD可以使用如上所述的式(2)来确定。
优选地,可使用学生模型和教师模型的中间特征向量之间的相似度sim对学生模型的多个输出与教师模型的多个输出之间的损失值进行加权。例如,蒸馏损失值LKD可以使用如上所述的式(3)来确定。
具体而言,使用中间特征向量的相似度sim对蒸馏值分量进行加权。具体而言,对中间特征向量相似度高的样本的蒸馏损失分量(即,式(3)中的)增大权重,而对中间特征向量相似度低的样本的蒸馏损失分量减小权重。由此使得中间特征向量相似度高的教师模型输出能够给学生模型带来更多的知识(给予更大的权重),而中间特征向量相似度低的教师模型输出能够给学生模型带来较少的知识(给以较小的权重)。
在步骤412,可以确定是否达到训练完成目标。
在一方面,可以确定在步骤410所确定的损失值是否小于目标损失值。如果小于目标损失值,则说明达到训练完成目标。
在另一方面,可以确定训练次数是否达到目标训练次数。如果达到目标训练次数,则说明达到训练完成目标。
如果在步骤412,确定达到训练完成目标,则行进至步骤414,学生模型的训练完成。
如果在步骤412,确定未达到训练完成目标,则在步骤416调整学生模型的参数。
例如,可以使用梯度下降法来调整学生模型的参数。
虽然可以返回步骤404,继续获取训练样本集及其真实标签向量集,进行新一轮的训练,直至达到训练完成目标。
图5是根据本公开的各方面的利用多任务教师模型对多任务学生模型进行训练的流程图。多任务教师模型和多任务学生模型均包括共享模型和位于共享模型下游的、用于分别执行多个任务的多个子任务模型,如图2所示。
如图5所示,在步骤502,可以针对训练样本集中的每个样本,确定多任务教师模型的共享模型的输出与多任务学生模型的共享模型的输出之间的相似度。
在步骤504,可以利用所述相似度对蒸馏损失函数进行加权,蒸馏损失函数用于表征多任务教师模型的多个子任务模型各自输出与多任务学生模型的多个子任务模型各自输出之间的差异,该差异越大,蒸馏损失函数越大。
在步骤506,可以根据加权后的所述蒸馏损失函数对多任务学生模型的参数进行调整。
在一方面,该方法可以进一步包括:使用多任务教师模型的多个子任务模型的多个输出、多任务学生模型的多个子任务模型的多个输出和该相似度来确定总损失函数;以及根据总损失函数来对多任务学生模型的参数进行调整。
在一方面,总损失函数可以进一步根据交叉熵损失函数来确定,其中交叉熵损失函数基于多任务学生模型针对该多个任务的多个输出和对应于该训练样本集的硬标签向量集来确定。
在一方面,总损失函数L可以如下确定:
L=LCE+λLKD,
其中
其中LCE为交叉熵损失函数,LKD为蒸馏损失函数,yij是训练样本集中的第i个训练样本针对第j个任务的硬标签,Toutij是多任务教师模型关于第i个训练样本针对第j个任务的输出,Soutij是多任务学生模型关于第i个训练样本针对第j个任务的输出,simi是多任务教师模型和多任务学生模型关于第i个训练样本的输出的相似度,M是多任务教师模型和多任务学生模型的任务数目,N是训练样本集中的样本数目。
在一方面,多任务教师模型的共享模型的输出与多任务学生模型的共享模型的输出是相同维度的向量;相似度是多任务教师模型的共享模型的输出与多任务学生模型的共享模型的输出之间的余弦相似度。
在一方面,多任务教师模型的共享模型和多任务学生模型的共享模型是深度神经网络(DNN),并且多任务教师模型的共享模型和多任务学生模型的共享模型的层数相同。
在一方面,该多个任务包括预测业务的点击率和预测广告的转化率,训练样本集包括关于该业务的用户侧特征和物品侧特征。
本文结合附图阐述的说明描述了示例配置而不代表可被实现或者落在权利要求的范围内的所有示例。本文所使用的术语“示例性”意指“用作示例、实例或解说”,而并不意指“优于”或“胜过其他示例”。本详细描述包括具体细节以提供对所描述的技术的理解。然而,可以在没有这些具体细节的情况下实践这些技术。在一些实例中,众所周知的结构和设备以框图形式示出以避免模糊所描述的示例的概念。
在附图中,类似组件或特征可具有相同的附图标记。此外,相同类型的各个组件可通过在附图标记后跟随短划线以及在类似组件之间进行区分的第二标记来加以区分。如果在说明书中仅使用第一附图标记,则该描述可应用于具有相同的第一附图标记的类似组件中的任何一个组件而不论第二附图标记如何。
结合本文中的公开描述的各种解说性框以及模块可以用设计成执行本文中描述的功能的通用处理器、DSP、ASIC、FPGA或其他可编程逻辑器件、分立的门或晶体管逻辑、分立的硬件组件、或其任何组合来实现或执行。通用处理器可以是微处理器,但在替换方案中,处理器可以是任何常规的处理器、控制器、微控制器、或状态机。处理器还可被实现为计算设备的组合(例如,DSP与微处理器的组合、多个微处理器、与DSP核心协同的一个或多个微处理器,或者任何其他此类配置)。
本文中所描述的功能可以在硬件、由处理器执行的软件、固件、或其任何组合中实现。如果在由处理器执行的软件中实现,则各功能可以作为一条或多条指令或代码存储在计算机可读介质上或藉其进行传送。其他示例和实现落在本公开及所附权利要求的范围内。例如,由于软件的本质,以上描述的功能可使用由处理器执行的软件、硬件、固件、硬连线或其任何组合来实现。实现功能的特征也可物理地位于各种位置,包括被分布以使得功能的各部分在不同的物理位置处实现。另外,如本文(包括权利要求中)所使用的,在项目列举(例如,以附有诸如“中的至少一个”或“中的一个或多个”之类的措辞的项目列举)中使用的“或”指示包含性列举,以使得例如A、B或C中的至少一个的列举意指A或B或C或AB或AC或BC或ABC(即,A和B和C)。同样,如本文所使用的,短语“基于”不应被解读为引述封闭条件集。例如,被描述为“基于条件A”的示例性步骤可基于条件A和条件B两者而不脱离本公开的范围。换言之,如本文所使用的,短语“基于”应当以与短语“至少部分地基于”相同的方式来解读。
计算机可读介质包括非瞬态计算机存储介质和通信介质两者,其包括促成计算机程序从一地向另一地转移的任何介质。非瞬态存储介质可以是能被通用或专用计算机访问的任何可用介质。作为示例而非限定,非瞬态计算机可读介质可包括RAM、ROM、电可擦除可编程只读存储器(EEPROM)、压缩盘(CD)ROM或其他光盘存储、磁盘存储或其他磁存储设备、或能被用来携带或存储指令或数据结构形式的期望程序代码手段且能被通用或专用计算机、或者通用或专用处理器访问的任何其他非瞬态介质。任何连接也被正当地称为计算机可读介质。例如,如果软件是使用同轴电缆、光纤电缆、双绞线、数字订户线(DSL)、或诸如红外、无线电、以及微波之类的无线技术从web网站、服务器、或其它远程源传送而来的,则该同轴电缆、光纤电缆、双绞线、数字订户线(DSL)、或诸如红外、无线电、以及微波之类的无线技术就被包括在介质的定义之中。如本文所使用的盘(disk)和碟(disc)包括CD、激光碟、光碟、数字通用碟(DVD)、软盘和蓝光碟,其中盘常常磁性地再现数据而碟用激光来光学地再现数据。以上介质的组合也被包括在计算机可读介质的范围内。
提供本文的描述是为了使得本领域技术人员能够制作或使用本公开。对本公开的各种修改对于本领域技术人员将是显而易见的,并且本文中定义的普适原理可被应用于其他变形而不会脱离本公开的范围。由此,本公开并非被限定于本文所描述的示例和设计,而是应被授予与本文所公开的原理和新颖特征相一致的最广范围。

Claims (15)

1.一种利用多任务教师模型对多任务学生模型进行训练的方法,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述方法包括:
针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度;
利用所述相似度对蒸馏损失函数进行加权,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;
根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整。
2.如权利要求1所述的方法,其中所述方法进一步包括:
使用所述多任务教师模型的多个子任务模型的多个输出、所述多任务学生模型的多个子任务模型的多个输出和所述相似度来确定总损失函数;以及
根据所述总损失函数来对所述多任务学生模型的参数进行调整。
3.如权利要求2所述的方法,其中所述总损失函数进一步根据交叉熵损失函数来确定,其中所述交叉熵损失函数基于所述多任务学生模型针对所述多个任务的多个输出和对应于所述训练样本集的硬标签向量集来确定。
4.如权利要求3所述的方法,其中所述总损失函数L如下确定:
L=LCE+λLKD,
其中
其中LCE为交叉熵损失函数,LKD为所述蒸馏损失函数,yij是所述训练样本集中的第i个训练样本针对第j个任务的硬标签,Toutij是所述多任务教师模型关于第i个训练样本针对第j个任务的输出,Soutij是所述多任务学生模型关于第i个训练样本针对第j个任务的输出,simi是所述多任务教师模型和所述多任务学生模型关于第i个训练样本的输出的相似度,M是所述多任务教师模型和所述多任务学生模型的任务数目,N是所述训练样本集中的样本数目。
5.如权利要求4所述的方法,其中所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出是相同维度的向量;
所述相似度是所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的余弦相似度。
6.如权利要求1所述的方法,其中所述多任务教师模型的共享模型和所述多任务学生模型的共享模型是深度神经网络(DNN),并且所述多任务教师模型的共享模型和所述多任务学生模型的共享模型的层数相同。
7.如权利要求1所述的方法,其中所述多个任务包括预测业务的点击率和预测广告的转化率,所述训练样本集包括关于所述业务的用户侧特征和物品侧特征。
8.一种利用多任务教师模型对多任务学生模型进行训练的装置,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述装置包括:
用于针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度的模块;
用于利用所述相似度对蒸馏损失函数进行加权的模块,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;
用于根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整的模块。
9.如权利要求8所述的装置,进一步包括:
用于使用所述多任务教师模型的多个子任务模型的多个输出、所述多任务学生模型的多个子任务模型的多个输出和所述相似度来确定总损失函数的模块;以及
用于根据所述总损失函数来对所述多任务学生模型的参数进行调整的模块。
10.如权利要求9所述的装置,其中所述总损失函数进一步根据交叉熵损失函数来确定,其中所述交叉熵损失函数基于所述多任务学生模型针对所述多个任务的多个输出和对应于所述训练样本集的硬标签向量集来确定。
11.如权利要求10所述的装置,其中所述总损失函数L如下确定:
L=LCE+λLKD,
其中
其中LCE为交叉熵损失函数,LKD为所述蒸馏损失函数,yij是所述训练样本集中的第i个训练样本针对第j个任务的硬标签,Toutij是所述多任务教师模型关于第i个训练样本针对第j个任务的输出,Soutij是所述多任务学生模型关于第i个训练样本针对第j个任务的输出,simi是所述多任务教师模型和所述多任务学生模型关于第i个训练样本的输出的相似度,M是所述多任务教师模型和所述多任务学生模型的任务数目,N是所述训练样本集中的样本数目。
12.如权利要求11所述的装置,其中所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出是相同维度的向量;
所述相似度是所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的余弦相似度。
13.如权利要求8所述的装置,其中所述多任务教师模型的共享模型和所述多任务学生模型的共享模型是深度神经网络(DNN),并且所述多任务教师模型的共享模型和所述多任务学生模型的共享模型的层数相同。
14.如权利要求8所述的装置,其中所述多个任务包括预测业务的点击率和预测广告的转化率,所述训练样本集包括关于所述业务的用户侧特征和物品侧特征。
15.一种利用多任务教师模型对多任务学生模型进行训练的装置,所述多任务教师模型和所述多任务学生模型均包括共享模型和位于所述共享模型下游的、用于分别执行多个任务的多个子任务模型,所述装置包括:
处理器;以及
被安排成存储计算机可执行指令的存储器,所述可执行指令在被执行时使所述处理器执行以下操作:
针对训练样本集中的每个样本,确定所述多任务教师模型的共享模型的输出与所述多任务学生模型的共享模型的输出之间的相似度;
利用所述相似度对蒸馏损失函数进行加权,所述蒸馏损失函数用于表征所述多任务教师模型的所述多个子任务模型各自输出与所述多任务学生模型的所述多个子任务模型各自输出之间的差异,该差异越大,所述蒸馏损失函数越大;
根据加权后的所述蒸馏损失函数对所述多任务学生模型的参数进行调整。
CN202310535445.9A 2023-05-09 2023-05-09 多任务数据处理方法和装置 Pending CN116578400A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310535445.9A CN116578400A (zh) 2023-05-09 2023-05-09 多任务数据处理方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310535445.9A CN116578400A (zh) 2023-05-09 2023-05-09 多任务数据处理方法和装置

Publications (1)

Publication Number Publication Date
CN116578400A true CN116578400A (zh) 2023-08-11

Family

ID=87539088

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310535445.9A Pending CN116578400A (zh) 2023-05-09 2023-05-09 多任务数据处理方法和装置

Country Status (1)

Country Link
CN (1) CN116578400A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117574179A (zh) * 2024-01-16 2024-02-20 北京趋动智能科技有限公司 多任务学习模型构建方法及装置

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117574179A (zh) * 2024-01-16 2024-02-20 北京趋动智能科技有限公司 多任务学习模型构建方法及装置
CN117574179B (zh) * 2024-01-16 2024-05-28 北京趋动智能科技有限公司 多任务学习模型构建方法及装置

Similar Documents

Publication Publication Date Title
US11042898B2 (en) Clickstream purchase prediction using Hidden Markov Models
WO2023040494A9 (zh) 资源推荐方法、多目标融合模型的训练方法及装置
CN110111139B (zh) 行为预估模型生成方法、装置、电子设备及可读介质
US20200034776A1 (en) Managing skills as clusters using machine learning and domain knowledge expert
US11288709B2 (en) Training and utilizing multi-phase learning models to provide digital content to client devices in a real-time digital bidding environment
US20120253945A1 (en) Bid traffic estimation
US20220366295A1 (en) Pre-search content recommendations
US10825071B2 (en) Adaptive multi-perceptual similarity detection and resolution
CN113344647B (zh) 一种信息推荐的方法及装置
US10678821B2 (en) Evaluating theses using tree structures
CN115564517A (zh) 商品推荐方法、预测模型训练方法和相关设备
CN115222433A (zh) 一种信息推荐方法、装置及存储介质
Wang et al. Webpage depth viewability prediction using deep sequential neural networks
CN114595323B (zh) 画像构建、推荐、模型训练方法、装置、设备及存储介质
CN116578400A (zh) 多任务数据处理方法和装置
Haridasan et al. Arithmetic Optimization with Deep Learning Enabled Churn Prediction Model for Telecommunication Industries.
CN114330837A (zh) 对象处理方法、装置、计算机设备和存储介质
CN113792952A (zh) 用于生成模型的方法和装置
CN116975686A (zh) 训练学生模型的方法、行为预测方法和装置
CN113836390A (zh) 资源推荐方法、装置、计算机设备及存储介质
Chashmi et al. Predicting customer turnover using recursive neural networks
WO2023221359A1 (zh) 基于多阶段时序多任务的用户安全等级识别方法及装置
CN115618079A (zh) 会话推荐方法、装置、电子设备及存储介质
CN112200602B (zh) 用于广告推荐的神经网络模型训练方法及装置
CN110580261B (zh) 针对高科技公司的深度技术追踪方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination