CN115359321A - 一种模型训练方法、装置、电子设备及存储介质 - Google Patents
一种模型训练方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN115359321A CN115359321A CN202211064160.3A CN202211064160A CN115359321A CN 115359321 A CN115359321 A CN 115359321A CN 202211064160 A CN202211064160 A CN 202211064160A CN 115359321 A CN115359321 A CN 115359321A
- Authority
- CN
- China
- Prior art keywords
- model
- modules
- teacher model
- teacher
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/60—Type of objects
- G06V20/62—Text, e.g. of license plates, overlay texts or captions on TV images
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V30/00—Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
- G06V30/10—Character recognition
- G06V30/19—Recognition using electronic means
- G06V30/191—Design or setup of recognition systems or techniques; Extraction of features in feature space; Clustering techniques; Blind source separation
- G06V30/19147—Obtaining sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
- G06V40/16—Human faces, e.g. facial parts, sketches or expressions
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Oral & Maxillofacial Surgery (AREA)
- Human Computer Interaction (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- Databases & Information Systems (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明的实施例提供了一种模型训练方法、装置、电子设备及存储介质,方法包括:通过确定教师模型和学生模型,确定初始训练样本数据,逐步将学生模型中的第二模块将教师模型中的第一模块进行替换,每次替换后均进行训练,得到新的教师模型,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。
Description
技术领域
本发明涉及模型训练技术领域,具体而言,涉及一种模型训练方法、装置、电子设备及存储介质。
背景技术
随着人工智能技术的发展,知识蒸馏技术在模型训练过程中的应用越来越广泛。其中,知识蒸馏是一种采用预先训练好的结构复杂的教师模型(Teacher Model)来训练结构简单的学生模型(Student Model),以实现将教师模型功能赋予学生模型的技术。那么,如何基于知识蒸馏技术,高精度的训练学生模型至关重要。
发明内容
本发明的目的在于提供一种模型训练方法、装置、电子设备及存储介质,能够提高训练学生模型的精度。
为了实现上述目的,本发明实施例采用的技术方案如下:
第一方面,本发明实施例提供了一种模型训练方法,所述方法包括:
确定教师模型和学生模型;
确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
在可选的实施方式中,所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型的步骤,包括:
基于伯努利分布方式,控制将所述教师模型中的第一模块替换为与所述学生模型中对应的第二模块的替换概率;
基于所述替换概率,将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型。
在可选的实施方式中,所述伯努利分布方式满足以下公式:
pd=min(1,θ(t))=min(1,kt=b);
其中,b是初始替换率,k是大于0的系数,t是替换次数。
在可选的实施方式中,所述方法还包括:
将待检测数据输入至所述目标模型,得到预测数据;
将所述预测数据进行清洗,得到第一训练数据;
基于所述第一训练数据对所述目标模型进行训练。
在可选的实施方式中,所述基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型的步骤,包括:
基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失;
基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型;
返回执行所述基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失至所述基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型的步骤,直至达到预设训练次数,得到新的教师模型。
在可选的实施方式中,所述交叉熵损失函数满足以下公式:
L=-∑j∈|X|∑c∈C[[zj=c]·log P(zj=c|xj)];
其中xj∈X为第j个初始训练样本,X为初始训练样本集合,zj为初始训练样本的真实标签,c为初始样本的类标签,C为初始训练样本集合的类标签集合,P为初始训练样本的真实标签与预测标签的概率差值。
在可选的实施方式中,所述将所述预测数据进行清洗,得到第一训练数据的步骤,包括:
确定所述预测数据的置信度值;
将置信度小于阈值的第一预测数据进行人工审核;
接收人工审核后的第一预测数据;
将人工审核后的所述第一预测数据,作为第一训练数据。
第二方面,本发明实施例提供了一种模型训练装置,所述装置包括:
第一确定模块,用于确定教师模型和学生模型;
第二确定模块,用于确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
替换模块,用于将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
训练模块,用于基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
执行模块,用于返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
第三方面,本发明实施例提供了一种电子设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现所述模型训练方法的步骤。
第四方面,本发明实施例提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现所述模型训练方法的步骤。
本发明具有以下有益效果:
本发明通过确定教师模型和学生模型,确定初始训练样本数据,将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型,基于初始训练样本数据对更新后的教师模型进行训练,得到新的教师模型,返回执行将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型至基于初始训练样本数据对更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本发明的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1为本发明实施例提供的电子设备的方框示意图;
图2为本发明实施例提供的一种模型训练方法的步骤流程图之一;
图3为本发明实施例提供的一种模型训练方法的步骤流程图之二;
图4为本发明实施例提供的一种模型训练方法的步骤流程图之三;
图5为本发明实施例提供的一种模型训练方法的步骤流程图之四;
图6为本发明实施例提供的一种模型训练装置的结构框图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本发明实施例的组件可以以各种不同的配置来布置和设计。
因此,以下对在附图中提供的本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。
在本发明的描述中,需要说明的是,若出现术语“上”、“下”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,或者是该发明产品使用时惯常摆放的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。
此外,若出现术语“第一”、“第二”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本发明的描述中,还需要说明的是,除非另有明确的规定和限定,术语“设置”、“安装”、“相连”、“连接”应做广义理解,例如,可以是固定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,可以是两个元件内部的连通。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。
经过发明人大量研究发现,随着人工智能技术的发展,知识蒸馏技术在模型训练过程中的应用越来越广泛。其中,知识蒸馏是一种采用预先训练好的结构复杂的教师模型(Teacher Model)来训练结构简单的学生模型(Student Model),以实现将教师模型功能赋予学生模型的技术。那么,如何基于知识蒸馏技术,高精度的训练学生模型至关重要。
有鉴于对上述问题的发现,本实施例提供了一种模型训练方法、装置、电子设备及存储介质,能够逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度,下面对本实施例提供的方案进行详细阐述。
本实施例提供一种可以对模型进行训练的电子设备。在一种可能的实现方式中,所述电子设备可以为用户终端,例如,电子设备可以是,但不限于,服务器、智能手机、个人电脑(PersonalComputer,PC)、平板电脑、个人数字助理(Personal Digital Assistant,PDA)、移动上网设备(Mobile Internet Device,MID)等。
请参照图1,图1是本发明实施例提供的电子设备100的结构示意图。所述电子设备100还可包括比图1中所示更多或者更少的组件,或者具有与图1所示不同的配置。图1中所示的各组件可以采用硬件、软件或其组合实现。
所述电子设备100包括模型训练装置110、存储器120及处理器130。
所述存储器120及处理器130各元件相互之间直接或间接地电性连接,以实现数据的传输或交互。例如,这些元件相互之间可通过一条或多条通讯总线或信号线实现电性连接。所述模型训练装置110包括至少一个可以软件或固件(firmware)的形式存储于所述存储器120中或固化在所述电子设备100的操作系统(operating system,OS)中的软件功能模块。所述处理器130用于执行所述存储器120中存储的可执行模块,例如所述模型训练装置110所包括的软件功能模块及计算机程序等。
其中,所述存储器120可以是,但不限于,随机存取存储器(RandomAccess Memory,RAM),只读存储器(Read Only Memory,ROM),可编程只读存储器(Programmable Read-OnlyMemory,PROM),可擦除只读存储器(Erasable ProgrammableRead-Only Memory,EPROM),电可擦除只读存储器(Electric Erasable ProgrammableRead-Only Memory,EEPROM)等。其中,存储器120用于存储程序,所述处理器130在接收到执行指令后,执行所述程序。
请参照图2,图2为应用于图1的电子设备100的一种模型训练方法的流程图,以下将方法包括各个步骤进行详细阐述。
步骤201:确定教师模型和学生模型。
步骤202:确定初始训练样本数据。
其中,初始训练样本数据为训练教师模型所使用的训练样本数据。
步骤203:将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型。
其中,教师模型包括多个第一模块,学生模型包括多个第二模块。
步骤204:基于初始训练样本数据对更新后的教师模型进行训练,得到新的教师模型。
步骤205:返回执行将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型至基于初始训练样本数据对更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型。
其中,目标模型中的模块为学生模型中的第二模块。
重量化的模型称之为教师模型,轻量化的模型称之为学生模型。
教师模型中包含多个第一模块,学生模型中包含多个第二模块,第一模块与第二模块之间存在对应关系。
在一示例中,预设数量的第一模块与一个第二模块存在对应关系。例如:三个第一模块与一个第一模块存在对应关系,且三个第一模块的功能与一个第一模块的功能相同。
将教师模型中的第一模块逐渐替换为学生模型中的第二模块,示例性的:
假设教师模型中包括6个第一模块,学生模型中包括两个第二模块,其中,6个第一模块分为两组,其中每组包含三个第一模块,且每组实现的功能不同,学生模型中的两个第二模块与教师模型中两组模块的功能对应,将第一组中的第一模块替换为学生模型中对应的第二模块,得到更新后的教师模型,基于对未替换教师模型的原始训练样本数据对更换第一组第一模块的教师模型进行训练,得到新的教师模型,将新的教师模型中第二组中的第二模块替换为学生模型中对应的第二模块,基于对未更换教师模型的原始训练样本数据对再次更换第二组第一模块的的教师模型进行训练,得到最新的教师模型,将最新的教师模型作为目标模型,且最新的教师模型中的模块为学生模型中的第二模块。
本发明通过确定教师模型和学生模型,确定初始训练样本数据,将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型,基于初始训练样本数据对更新后的教师模型进行训练,得到新的教师模型,返回执行将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型至基于初始训练样本数据对更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的第二模块替换掉教师模型中所有的第一模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。
将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型的方式多种,在一示例中,如图3所示,提供了一种模型训练方法,具体包括如下步骤:
步骤203-1:基于伯努利分布方式,控制将教师模型中的第一模块替换为与学生模型中对应的第二模块的替换概率。
步骤203-2:基于替换概率,将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型。
伯努利分布方式满足以下公式:
pd=min(1,θ(t))=min(1,kt=b);
其中,b是初始替换率,k是大于0的系数,t是替换次数。
随着替换次数的增大,学生模型中的第二模块逐渐将教师模型中的第一模块全部替换完成。
基于初始训练样本数据对更新后的教师模型进行训练的方式有多种,在一示例中,如图4所示,提供了一种模型训练方法,具体包括如下步骤:
步骤204-1:基于交叉熵损失函数,确定初始训练样本数据的真实标签与预测标签的损失。
步骤204-2:基于损失对更新后的教师模型的参数进行调整,以获得新的教师模型。
步骤204-3:返回执行基于交叉熵损失函数,确定初始训练样本数据的真实标签与预测标签的损失至基于损失对更新后的教师模型的参数进行调整,以获得新的教师模型的步骤,直至达到预设训练次数,得到新的教师模型。
在一示例中,当初始训练样本数据为人脸图像,且教师模型用于检测人脸图像时,初始训练样本数据中的真实标签为人脸的真实位置的人脸框,初始训练样本数据的预测标签为基于新的教师模型预测得到的人脸框,基于交叉熵损失函数,确定初始训练样本数据的真实人脸框和预测人脸框之间的损失,基于损失对更新后的教师模型的参数进行调整,以获得新的教师模型,基于新的教师模型确定初始训练样本数据的更新后的预测标签,重复执行基于交叉熵损失函数计算更新后的预测标签与真实标签的损失,并基于损失对更新后的教师模型的参数进行调整,直到达到预设训练次数或者直到计算得到的损失收敛为止,更新后的教师模型训练完成,得到新的教师模型。
在另一示例中,当初始训练样本数据为文字信息时,且教师模型用于检测文字信息中的关键词信息,初始训练样本数据中的真实标签为文字信息的真实关键词,初始训练样本数据的预测标签为基于新的教师模型预测得到的文字信息中的预测关键词,基于交叉熵损失函数,计算预测关键词和真实关键词的损失,基于损失对更新后的教师模型的参数进行调整,以获得新的教师模型,基于新的教师模型再次对初始训练样本数据进行预测,得到更新后的预测关键词,重复执行基于交叉熵损失函数计算更新后的预测关键词与真实关键词的损失,并基于损失对更新后的教师模型的参数进行调整,直到达到预测训练次数或者直到计算得到的损失收敛为止,更新后的教师模型训练完成,得到新的教师模型。
在基于损失调整更新后的教师模型的参数时,在反向传播中,教师模型的嵌入层和输出层的权重值都是冻结的,因此,从替换的学生模型中的第二模块中获取嵌入层和输出层的权重值,通过这种方式,使得教师模型和学生模型进行更深入的交互。
交叉熵损失函数满足以下公式:
L=-∑j∈|X|∑c∈C[[zj=c]·log P(zj=c|xj)]
其中xj∈X为第j个初始训练样本,X为初始训练样本集合,zj为初始训练样本的真实标签,c为初始样本的类标签,C为初始训练样本集合的类标签集合,P为初始训练样本的真实标签与预测标签的概率差值。
为了提升目标模型的准确度,在一示例中,如图5所示,提供了一种模型训练方法,具体包括如下步骤:
步骤301:将待检测数据输入至目标模型,得到预测数据。
步骤302:将预测数据进行清洗,得到第一训练数据。
步骤303:基于第一训练数据对目标模型进行训练。
对预测数据进行清洗得到第一训练数据的方式有多种,在一示例中,对预测数据进行清洗的方式可以为:
确定预测数据的置信度值,将置信度小于阈值的第一预测数据进行人工审核,接收人工审核后的第一预测数据,将人工审核后的第一预测数据,作为第一训练数据。
需要说明的是,待检测数据为未标注的数据,将未标注的数据输入至目标模型中,目标模型输出待检测数据的预测数据,且每个预测数据均携带有置信度,置信度即为预测数据预测的概率。
将预测数据按照置信度大小进行排序,将置信度小于预设阈值的第一预测数据进行人工审核,即将第一预测数据进行人工标注,将标注后的第一预测数据,作为第一训练数据,对目标模型进行训练。
人工复审后的第一预测数据可理解为模型难辩别的数据,这部分数据为模型自适应调整的依据。
再将第一预测数据输入至目标模型进行训练之前,可以将第一预测数据进行数据预处理,以文本数据为例,预处理方式可以为:
去除文本中的停用词、特殊符号,将文本进行分句,若为中文文本则需要进一步分词处理,若为英文文本则直接以空格分词,利用Bert进行词向量表示,以句子为单位,n为一条句子包含的词数量,每个词向量记为xi,i=1,2,……,n。将预处理后的第一预测数据直接输入至目标模型进行训练。
本发明提供的实施例,将未标注的待检测数据输入至替换完成的目标模型,得到预测数据的置信度,根据置信度大小进行人工复审,人工复审后第一预测数据累积并反馈到目标模型进行迭代,可以再次提高目标模型的精度。
请参照图6,本发明实施例还提供了一种应用于图1所述电子设备100的模型训练装置110,所述模型训练装置110包括:
第一确定模块111,用于确定教师模型和学生模型;
第二确定模块112,用于确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
替换模块113,用于将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
训练模块114,用于基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
执行模块115,用于返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
本发明还提供一种电子设备100,电子设备100包括处理器130以及存储器120。存储器120存储有计算机可执行指令,计算机可执行指令被处理器130执行时,实现该模型训练方法。
本发明实施例还提供一种计算机可读存储介质,存储介质存储有计算机程序,计算机程序被处理器130执行时,实现该模型训练方法。
需要说明的是,本实施例所提供的模型训练装置,其基本原理及产生的技术效果和上述方法实施例相同,为简要描述,本实施例部分未提及之处,可参考上述的方法实施例中相应内容。
综上,本发明通过确定教师模型和学生模型,确定初始训练样本数据,将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型,基于初始训练样本数据对更新后的教师模型进行训练,得到新的教师模型,返回执行将教师模型中的部分第一模块替换为与学生模型中对应的第二模块,得到更新后的教师模型至基于初始训练样本数据对更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被学生模型中的第二模块替换,得到训练好的目标模型,实现逐步用学生模型的模块替换掉教师模型中的模块,并训练替换模块后的教师模型,从而实现学生模型学习迁移来自教师模型的监督信息,有效降低学生模型学习所需要的训练数据数量,减少训练时间并且提高学生模型的精度。
在本发明所提供的实施例中,应该理解到,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本发明的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
另外,在本发明各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。所述功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上所述,仅为本发明的各种实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应所述以权利要求的保护范围为准。
Claims (10)
1.一种模型训练方法,其特征在于,所述方法包括:
确定教师模型和学生模型;
确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
2.根据权利要求1所述的方法,其特征在于,所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型的步骤,包括:
基于伯努利分布方式,控制将所述教师模型中的第一模块替换为与所述学生模型中对应的第二模块的替换概率;
基于所述替换概率,将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型。
3.根据权利要求2所述的方法,其特征在于,所述伯努利分布方式满足以下公式:
pd=min(1,θ(t))=min(1,kt=b);
其中,b是初始替换率,k是大于0的系数,t是替换次数。
4.根据权利要求1所述的方法,其特征在于,所述方法还包括:
将待检测数据输入至所述目标模型,得到预测数据;
将所述预测数据进行清洗,得到第一训练数据;
基于所述第一训练数据对所述目标模型进行训练。
5.根据权利要求1所述的方法,其特征在于,所述基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型的步骤,包括:
基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失;
基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型;
返回执行所述基于交叉熵损失函数,确定所述初始训练样本数据的真实标签与预测标签的损失至所述基于所述损失对所述更新后的教师模型的参数进行调整,以获得新的教师模型的步骤,直至达到预设训练次数,得到新的教师模型。
6.根据权利要求5所述的方法,其特征在于,所述交叉熵损失函数满足以下公式:
L=-∑j∈|X|∑c∈C[[zj=c]·logP(zj=c∣xj)];
其中xj∈X为第j个初始训练样本,X为初始训练样本集合,zj为初始训练样本的真实标签,c为初始样本的类标签,C为初始训练样本集合的类标签集合,P为初始训练样本的真实标签与预测标签的概率差值。
7.根据权利要求4所述的方法,其特征在于,所述将所述预测数据进行清洗,得到第一训练数据的步骤,包括:
确定所述预测数据的置信度值;
将置信度小于阈值的第一预测数据进行人工审核;
接收人工审核后的第一预测数据;
将人工审核后的所述第一预测数据,作为第一训练数据。
8.一种模型训练装置,其特征在于,所述装置包括:
第一确定模块,用于确定教师模型和学生模型;
第二确定模块,用于确定初始训练样本数据,其中,所述初始训练样本数据为训练所述教师模型所使用的训练样本数据;
替换模块,用于将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型,其中,所述教师模型包括多个第一模块,所述学生模型包括多个第二模块;
训练模块,用于基于所述初始训练样本数据对所述更新后的教师模型进行训练,得到新的教师模型;
执行模块,用于返回执行所述将所述教师模型中的部分第一模块替换为与所述学生模型中对应的第二模块,得到更新后的教师模型至所述基于所述初始训练样本数据对所述更新后的教师模型进行训练的步骤,直到最新得到的新的教师模型中的第一模块均被所述学生模型中的第二模块替换,得到训练好的目标模型,其中,所述目标模型中的模块为所述学生模型中的第二模块。
9.一种电子设备,其特征在于,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现权利要求1-7任一项所述方法的步骤。
10.一种存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现权利要求1-7中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211064160.3A CN115359321A (zh) | 2022-09-01 | 2022-09-01 | 一种模型训练方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211064160.3A CN115359321A (zh) | 2022-09-01 | 2022-09-01 | 一种模型训练方法、装置、电子设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115359321A true CN115359321A (zh) | 2022-11-18 |
Family
ID=84005600
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211064160.3A Pending CN115359321A (zh) | 2022-09-01 | 2022-09-01 | 一种模型训练方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115359321A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116030418A (zh) * | 2023-02-14 | 2023-04-28 | 北京建工集团有限责任公司 | 一种汽车吊运行状态监测系统及方法 |
CN116070697A (zh) * | 2023-01-17 | 2023-05-05 | 北京理工大学 | 一种可替换式的便捷知识蒸馏方法及系统 |
-
2022
- 2022-09-01 CN CN202211064160.3A patent/CN115359321A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116070697A (zh) * | 2023-01-17 | 2023-05-05 | 北京理工大学 | 一种可替换式的便捷知识蒸馏方法及系统 |
CN116030418A (zh) * | 2023-02-14 | 2023-04-28 | 北京建工集团有限责任公司 | 一种汽车吊运行状态监测系统及方法 |
CN116030418B (zh) * | 2023-02-14 | 2023-09-12 | 北京建工集团有限责任公司 | 一种汽车吊运行状态监测系统及方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110750959B (zh) | 文本信息处理的方法、模型训练的方法以及相关装置 | |
CN108062388B (zh) | 人机对话的回复生成方法和装置 | |
CN110765775B (zh) | 一种融合语义和标签差异的命名实体识别领域自适应的方法 | |
CN115359321A (zh) | 一种模型训练方法、装置、电子设备及存储介质 | |
CN109214006B (zh) | 图像增强的层次化语义表示的自然语言推理方法 | |
CN109271493A (zh) | 一种语言文本处理方法、装置和存储介质 | |
CN111985239A (zh) | 实体识别方法、装置、电子设备及存储介质 | |
CN113392209B (zh) | 一种基于人工智能的文本聚类方法、相关设备及存储介质 | |
CN111178036B (zh) | 一种知识蒸馏的文本相似度匹配模型压缩方法及系统 | |
CN110427629A (zh) | 半监督文本简化模型训练方法和系统 | |
CN113705196A (zh) | 基于图神经网络的中文开放信息抽取方法和装置 | |
CN116861929A (zh) | 基于深度学习的机器翻译系统 | |
CN116402352A (zh) | 一种企业风险预测方法、装置、电子设备及介质 | |
CN117236335B (zh) | 基于提示学习的两阶段命名实体识别方法 | |
CN110569355A (zh) | 一种基于词块的观点目标抽取和目标情感分类联合方法及系统 | |
CN109979461A (zh) | 一种语音翻译方法及装置 | |
CN116680575B (zh) | 模型处理方法、装置、设备及存储介质 | |
CN116720519B (zh) | 一种苗医药命名实体识别方法 | |
CN112906398A (zh) | 句子语义匹配方法、系统、存储介质和电子设备 | |
CN110852066B (zh) | 一种基于对抗训练机制的多语言实体关系抽取方法及系统 | |
CN116208399A (zh) | 一种基于元图的网络恶意行为检测方法及设备 | |
CN113849634B (zh) | 用于提升深度模型推荐方案可解释性的方法 | |
CN113051607B (zh) | 一种隐私政策信息提取方法 | |
CN114398482A (zh) | 一种词典构造方法、装置、电子设备及存储介质 | |
CN114330238A (zh) | 文本处理方法、文本处理装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |