CN116484881A - 对话生成模型的训练方法、装置、存储介质及计算机设备 - Google Patents
对话生成模型的训练方法、装置、存储介质及计算机设备 Download PDFInfo
- Publication number
- CN116484881A CN116484881A CN202310469117.3A CN202310469117A CN116484881A CN 116484881 A CN116484881 A CN 116484881A CN 202310469117 A CN202310469117 A CN 202310469117A CN 116484881 A CN116484881 A CN 116484881A
- Authority
- CN
- China
- Prior art keywords
- training
- scene
- batch
- groups
- data set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 352
- 238000000034 method Methods 0.000 title claims abstract description 54
- 238000003062 neural network model Methods 0.000 claims abstract description 24
- 238000012163 sequencing technique Methods 0.000 claims abstract description 11
- 238000002372 labelling Methods 0.000 claims abstract description 7
- 238000012795 verification Methods 0.000 claims description 17
- 238000004590 computer program Methods 0.000 claims description 8
- 125000004122 cyclic group Chemical group 0.000 claims description 4
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000003745 diagnosis Methods 0.000 description 18
- 238000004891 communication Methods 0.000 description 5
- 238000010586 diagram Methods 0.000 description 5
- 230000000694 effects Effects 0.000 description 4
- 238000001514 detection method Methods 0.000 description 3
- 230000010365 information processing Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000002405 diagnostic procedure Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 206010027175 memory impairment Diseases 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/30—Semantic analysis
- G06F40/35—Discourse or dialogue representation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H80/00—ICT specially adapted for facilitating communication between medical practitioners or patients, e.g. for collaborative diagnosis, therapy or health monitoring
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Software Systems (AREA)
- Biophysics (AREA)
- Primary Health Care (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Public Health (AREA)
- Epidemiology (AREA)
- Pathology (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种对话生成模型的训练方法、装置、存储介质及计算机设备,涉及人工智能及智慧医疗技术领域。其中方法包括:获取对话生成模型的训练数据集,其中,训练数据集包括多个预设有场景标签的训练样本,场景标签用于标注训练样本适用的场景;基于场景标签将训练样本划分为多个批次组,其中,每个批次组包含属于同一场景的预设数量的训练样本;对多个批次组进行随机排序,生成用于训练对话生成模型的目标训练数据集;根据目标训练数据集,对预设的神经网络模型进行训练,得到对话生成模型。上述方法能够从历史的对话样本中,生成按照适用场景随机排序的批次组,并基于该训练数据集对对话生成模型进行训练,提升模型的性能。
Description
技术领域
本发明涉及人工智能及智慧医疗技术领域,尤其是涉及一种对话生成模型的训练方法、装置、存储介质及计算机设备。
背景技术
随着神经网络模型技术的发展,Transformer(编码解码模型)类预训练模型越来越受到各方的关注,其使得部署一个对话生成模型来应对多个不同的对话场景成为可能。特别是在医疗交互领域,对话生成模型可以接收来自不同医疗场景的对话信息,生成适用于不同医疗场景的回复信息。
当前,对话生成模型进行训练的方式多为随机的选取各场景的历史对话数据作为训练数据集中每个批次层面上的训练数据,但基于该种方式对模型进行训练的过程会因训练批次层面上的训练数据过于分散,导致对话生成模型收敛速度较慢。此外,将多个场景下的训练数据按照场景顺序对神经网络模型进行训练,会造成模型学习新知识后,几乎彻底遗忘掉之前学习的内容,导致在对模型的训练过程中会出现灾难性遗忘的情况,进而导致模型训练的效率大幅降低。
发明内容
有鉴于此,本申请提供了一种对话生成模型的训练方法、装置、存储介质及计算机设备,主要目的在于解决模型训练效率偏低的技术问题。
根据本发明的第一个方面,提供了一种对话生成模型的训练方法,该方法包括:
获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
根据本发明的第二个方面,提供了一种对话生成模型的训练装置,该装置包括:
数据获取模块,同于获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
样本分组模块,用于基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
数据生成模块,用于对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
模型训练模块,用于根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
根据本发明的第三个方面,提供了一种存储介质,其上存储有计算机程序,所述程序被处理器执行时实现上述对话生成模型的训练方法。
根据本发明的第四个方面,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述对话生成模型的训练方法。
本发明提供的一种对话生成模型的训练方法、装置、存储介质及计算机设备,能够对训练数据集中适用于不用场景的训练样本点进行分类,将适用于同一场景的训练样本在训练批次层面中进行分场景汇聚,使模型能够在适用于同一场景的训练样本下实现快速的收敛,加快模型的训练速度。同时,对批次组进行随机排序得到目标训练数据集,使基于目标训练数据集对对话生成模型进行训练时,不会出现因为训练数据集中某个场景的训练数据离训练数据集的末端太远,而导致的灾难性遗忘的情况发生,进而有效提高了对对话生成模型的训练效果,提升了对话生成模型的性能。
上述说明仅是本申请技术方案的概述,为了能够更清楚了解本申请的技术手段,而可依照说明书的内容予以实施,并且为了让本申请的上述和其它目的、特征和优点能够更明显易懂,以下特举本申请的具体实施方式。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1示出了本发明实施例提供的一种对话生成模型的训练方法的流程示意图;
图2示出了本发明实施例提供的一种训练数据集的结构示意图之一;
图3示出了本发明实施例提供的一种训练数据集的结构示意图之二;
图4示出了本发明实施例提供的一种目标训练数据集的结构示意图;
图5示出了本发明实施例提供的一种对话生成模型的训练装置的结构示意图;
图6示出了本发明实施例提供的另一种对话生成模型的训练装置的结构示意图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本发明。需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。
现有的对话生成模型进行训练的方式多为随机的选取医疗领域各场景的历史对话数据作为训练数据集中每个批次层面上的训练数据,但基于该种方式对模型进行训练的过程会因训练批次层面上的训练数据过于分散,导致对话生成模型收敛速度较慢,同时,将多个场景下的训练数据按照场景顺序对神经网络模型进行训练,会造成模型学习新知识后,几乎彻底遗忘掉之前学习的内容,导致在对模型的训练过程中会出现灾难性遗忘的情况,进而导致模型训练的质量和效率大幅降低。
针对上述问题,在一个实施例中,如图1所示,提供了一种对话生成模型的训练方法,以该方法应用于计算机设备为例进行说明,包括以下步骤:
101、获取对话生成模型的训练数据集。
其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景。
具体的,可以获取不同用户之间的历史对话数据作为训练数据集,例如,在医疗场景下,可以获取历史上医生和患者的诊断过程的对话数据,并按对话数据的应用场景进行分类,比如可以将医疗场景下的对话场景分为诊断场景、医生问询场景和治疗建议场景等。在获取训练数据集时,可以分别在诊断场景、医生问询场景和治疗建议场景中获取多组对话数据,并在每个对话数据上标注用于区分对话数据所属场景的场景标签,比如在诊断场景中获取的对话数据中标注诊断场景标签,在医生问询场景中获取的对话数据中标注医生问询场景标签,在治疗建议场景中获取的对话数据中标注治疗建议场景标签,得到训练样本。进一步的,如图2所示,可以将多个诊断场景下的训练样本A、多个医生问询场景下的训练样本B和多个治疗建议场景下的训练样本C组合成训练数据集10。应当注意的是,上述场景只是在医疗健康领域训练对话生成模型的常用场景,其他场景同样适用于本实施例。
进一步的,训练样本可以分为样本对话和标签对话,样本对话为对话数据中医生为得到适用于特定场景的结论与用户在进行沟通的对话,而标签对话为医生在得到特定场景的结论的对话。作为示例,若当前训练样本对应的场景为诊断场景,则样本对话为医生为了得到诊断结论与用户进行的对话沟通,如身体状况问询对话,而标签对话为医生针对与用户进行的对话沟通,得到的诊断结果。其他场景的样本对话和标签对话的内容可以参考上述示例,这里不再赘述。
102、基于所述场景标签将所述训练样本划分为多个批次组。
其中,每个所述批次组(batch)可以为一次训练所选取的训练样本,包含属于同一场景的预设数量的训练样本。
具体的,如图3所示,例如,在训练数据集10中,可以将带有诊断标签的训练样本A划分到多个批次组11中,每个批次组11中带有相同数量的诊断场景下的训练样本A。再如,可以将带有医生问询标签的训练样本B划分到多个批次组11中,每个批次组11中带有相同数量的医生问询场景下的训练样本B。再如,可以将带有治疗建议标签的训练样本C划分到多个批次组11中,每个批次组11中带有相同数量的治疗建议场景下的训练样本C。
103、对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集。
具体的,可以基于随机排布算法,对全部批次组进行随机排序,使适用于同一场景下的批次组不会连续的排列,再将随机排布后的批次组组合成目标训练数据集(epoch)。作为示例,如图4所示,可以将分别将包括诊断场景下的训练样本A的多个批次组11、包括医生问询场景下的训练样本B的多个批次组11以及包括治疗建议场景下的训练样本C的多个批次组11随机进行排布,得到目标训练数据集30,进而保证适用于同一场景的批次组离散的排列在目标训练数据集中。
104、根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
具体的,可以将目标训练数据集中每个训练样本的样本对话作为神经网络模型的输入,将标签对话作为神经网络模型的输出,对神经网络模型进行训练,得到训练后的对话生成模型。在本实施例中神经网络模型可以为预训练模型,具体的,预训练模型可以为GPT(Generative Pre-Training)模型,使用其他的预训练模型进行训练的到对话生成模型的方式同样适用于本实施例。
本实施例提供的对话生成模型的训练方法,能够对训练数据集中适用于不用场景的训练样本点进行分类,将适用于同一场景的训练样本在训练批次层面中进行分场景汇聚,使模型能够在适用于同一场景的训练样本下实现快速的收敛,加快模型的训练速度。同时,对批次组进行随机排序得到目标训练数据集,使基于目标训练数据集对对话生成模型进行训练时,不会出现因为训练数据集中某个场景的训练数据离训练数据集的末端太远,而导致的灾难性遗忘的情况发生,进而有效提高了对对话生成模型的训练效果,提升了对话生成模型的训练性能。
在一个实施例中,步骤102的具体实现方法可以为:
201、根据所述场景标签,将所述训练样本划分为多个场景组。
其中,每个所述场景组包含属于同一场景的全部训练样本。
具体的,若训练数据集中全部的训练样本对应的场景标签包括诊断场景、医生问询场景和治疗建议场景,则将对应有诊断场景标签的训练样本划分到诊断场景组中,将对应有医生问询场景标签的训练样本划分到医生问询场景组中,将对应有治疗建议场景标签的训练样本划分到治疗建议场景组中。进一步的,多个场景组可以根据课表学习的方式按照训练样本预设的场景类型顺序进行排列,作为示例,若训练数据集包含的训练样本对应有诊断场景、医生问询场景和治疗建议场景,而预设的场景类型顺序为诊断场景、医生问询场景、治疗建议场景,则将多个场景组按照诊断场景组、医生问询场景组、治疗建议场景组的顺序排列。其中,场景类型顺序可以根据实际情况确定。
202、将所述场景组内的多个所述训练样本划分为多个批次组。其中,每个所述批次组包括预设数量的所述训练样本。
具体的,可以在每个场景组中选出预设数量的训练样本作为批次组,每个场景组中的训练样本可以被划分到多个批次组中。在本申请的实施例中,可以将带有相同场景标签的训练样本划分到同一场景组,在将场景组内的训练样本划分到多个批次组中,保证每个批次组内包含的训练样本都适用于同一场景,更高效的对预训练模型进行训练。
在一个实施例中,步骤103的具体实现方法可以为:首先,执行循环过程直至满足预设条件,其中,所述循环过程包括:从每个所述场景组内选出一个所述批次组,并将选出的多个所述批次组随机组合成综合组。其中,所述综合组内的训练样本适用的场景包含所述训练数据集中全部训练样本适用的场景。作为示例,若场景组包括诊断场景组、医生问询场景组和治疗建议场景组,每个场景组内包含多个批次组,则在诊断场景组挑选一个批次组,在医生问询场景组中挑选一个批次组,在治疗建议场景组内挑选一个批次组,进一步的,可以将上述三个批次组按照预设顺序或随机排列组合成综合组。
随后,在上述每个场景组内挑选出一个没有被组合成综合组的批次组,将挑选出的批次组组成综合组,直到存在至少一个所述场景组中的全部所述批次组被组成所述综合组。具体的,当生成一个综合组后,若诊断场景组、医生问询场景组或治疗建议场景组中已经不存在批次组,则跳出循环过程。最后,将全部所述综合组与每个所述场景组内未被组成所述综合组的批次组进行随机排列,或对全部所述综合组进行随机排列,得到所述目标训练数据集。具体的,若全部的批次组都被组合成综合组,可以对全部综合组进行随机排列,得到目标训练数据集。若存在未被组合成综合组的批次组,则可以将得到的综合组和多个场景组中未被组成综合组的批次组进行随机排列,得到目标训练数据集。在本申请的实施例中,先在每个场景组中挑选出一个批次组得到综合组,在无法得到综合组后,将现有的综合组与未被组合成综合组的批次组随机排列成目标训练数据集,能最大限度的保证目标训练数据集中各场景下的批次组能够均匀分布,避免在模型训练过程中出现灾难性遗忘的情况。
在一个实施例中,步骤202的具体实现方法可以为:首先,将所述场景组内的训练样本排列成样本队列。作为示例,若诊断场景组中包含12个训练样本,则将12个训练样本排列成样本队列。然后,以所述样本队列的一个端点为起始点,依次获取所述预设数量的训练样本组成一个批次组,得到多个批次组。其中,预设数量可以为预设的批次大小(batchsize),将预设的批次大小数量的训练样本确定为批次组,用于对神经网络模型进行一次训练。预设数量的数值可以根据实际情况确定。作为示例,在本申请的实施例中,将可以预设数量的确定为4。进一步的,可以在样本队列的头端或末端选取相邻的4个训练样本确定为批次组,随后,将相邻的每4个训练样本确定为批次组,直到样本队列中所有的训练样本都被确定为批次组。在本申请的实施例中,可以在同一场景组中选取出多个批次组,进而保证每个批次组内包含的训练样本都适用于同一场景,能更高效的对预训练模型进行训练。
在一个实施例中,步骤104之前,本实施例提供的对话生成模型的训练方法还包括:首先,获取每个所述批次组内的场景标签。具体的,可以获取目标训练数据集中所有批次组中训练样本的场景标签。然后,判断同一个所述批次组内的场景标签是否对应同一个场景。具体的,分别判断每个批次组内的训练样本的场景标签,判断是否存在包含不同场景标签的训练样本被划分到同一个批次组中。若存在所述批次组内的场景标签未对应同一个场景,则发出报警提示信息。在本申请的实施例中,可以在基于目标训练数据集对神经网络模型进行训练之前,判断是否出现因先前步骤的计算错误,导致将对应不同场景的训练样本划分到同一批次组中,若存在上述情况,则向工作人员发出提示信息,待工作人员了解情况后做进一步处理,进而保证对模型的训练性能。
在一个实施例中,步骤104之前,本实施例提供的对话生成模型的训练方法还包括:首先,确定所述目标训练数据集内全部所述批次组的排列顺序和每个所述批次组内训练样本对应的场景标签。其中,批次组的排列顺序可以理解为目标训练数据集内的批次组对神经网络模型执行训练的顺序。最初对神经网络模型进行一次训练的批次组位于目标训练数据集的首端,相应的,最后对神经网络模型进行一次训练的批次组位于目标训练数据集的末端。基于目标训练数据集内的批次组对神经网络模型执行训练的顺序,可以得到批次组的排列顺序。然后,将所述目标训练数据集的末端作为起始点,选取预设数量的多个所述批次组确定为验证样本组,并基于所述场景标签确定每个所述验证样本组内的训练样本对应的场景。其中,为了使模型不会对某类场景出现灾难性遗忘,需要使每类场景对应的批次组不会距离目标训练数据集的末端过远。进一步的,只要保证处于目标训练数据集的末端特定数量的批次组包含所有类别的场景,则会确保每类场景对应的批次组都不会距离目标训练数据集的末端过远,防止模型发生灾难性遗忘。具体的,可以基于预先的测试,得到当特定数量在何值以内时,可以保证不会发生灾难性遗忘的情况,并将测试得到的特定数量的值确定为预设数量。同时,可以基于验证样本组内的训练样本的场景标签,确定每个验证样本组对应的场景。
进一步的,依次判断每个所述验证样本组内的训练样本对应的场景是否包含所述训练数据集中全部训练样本对应的场景。作为示例,若训练数据集中全部训练样本对应的场景包括诊断场景、医生问询场景和治疗建议场景,且验证样本组的数量为3个,则判断该3个验证样本组对应的场景是否包括诊断场景、医生问询场景和治疗建议场景。若全部所述验证样本组内的训练样本对应的场景未包含所述训练数据集中全部训练样本对应的场景,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。作为示例,若判断该3个验证样本组对应的场景皆为诊断场景,则重新对多个批次组进行随机排序,生成目标训练数据集。在本申请的实施例中,可以避免出现任意一个场景的批次组能离目标训练数据集的末端间距过多的与其他场景对应的批次组,导致的灾难性遗忘的情况发生,保证模型训练的性能。
在一个实施例中,在一个实施例中,步骤104之前,本实施例提供的对话生成模型的训练方法还包括:首先,基于所述场景标签确定每个所述批次组对应的场景。具体的,获取每个批次组中训练样本的场景标签,基于训练样本的场景标签,确定每个批次组对应的场景。然后,从所述目标训练数据集的全部所述批次组中,定位出对应有相同场景的批次组,并判断相邻的所述对应有相同场景的批次组之间的批次组的数量是否大于预设阈值。具体的,为了避免出现灾难性遗忘,应当使各场景的批次组应均匀排布,相同场景的批次组之间不能间距过多的其他场景对应的批次组,若可以确定某个场景对应的批次组与相邻的相同场景的批次组之间批次组的数量在特定数量以内,可以确保模型不会出现灾难性遗忘,则可以将特定数量确定为预设阈值。若相邻的所述对应有相同场景的批次组之间的批次组的数量大于所述预设阈值,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。作为示例,若预设阈值为5,某两个医生问询场景对应的批次组之间有6个其他场景下的批次组,则重新对多个批次组进行随机排序,生成目标训练数据集。在本申请的实施例中,可以避免目标训练数据集中各场景下的批次组分布不均匀,出现任意两个相同场景对应的批次组之间间距过多的其他批次组,进而导致的灾难性遗忘的情况发生,以保证模型训练的性能。
本实施例提供的对话生成模型的训练方法,能够识别每个训练样本对应的场景,并对训练数据集中适用于不用场景的训练样本点进行分类,将适用于同一场景的训练样本在训练批次层面中进行分场景汇聚,使模型能够在适用于同一场景的训练样本下实现快速的收敛,加快模型的训练速度。进一步的,对批次组进行随机排序得到目标训练数据集,并在目标训练数据集内各场景下的训练样本未均匀排布时重新对批次组随机排序,进而使基于目标训练数据集对对话生成模型进行训练时,不会出现因为训练数据集中某个场景的训练数据离训练数据集的末端太远,而导致的灾难性遗忘的情况发生,进而有效提高了对对话生成模型的训练效果,提升了对话生成模型的训练性能。
进一步的,作为图1所示方法的具体实现,本实施例提供了一种对话生成模型的训练装置,如图5所示,该装置包括:数据获取模块51、样本分组模块52、数据生成模块53和模型训练模块54。
数据获取模块51,可用于获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
样本分组模块52,可用于基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
数据生成模块53,可用于对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
模型训练模块54,可用于根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
在具体的应用场景中,所述样本分组模块52,具体可用于根据所述场景标签,将所述训练样本划分为多个场景组,其中,每个所述场景组包含属于同一场景的全部训练样本;将所述场景组内的多个所述训练样本划分为多个批次组,其中,每个所述批次组包括预设数量的所述训练样本。
在具体的应用场景中,所述数据生成模块53,具体可用于执行循环过程直至满足预设条件,其中,所述循环过程包括:从每个所述场景组内选出一个所述批次组,并将选出的多个所述批次组随机组合成综合组,其中,所述综合组内的训练样本适用的场景包含所述训练数据集中全部训练样本适用的场景;所述预设条件为:存在至少一个所述场景组中的全部所述批次组被组成所述综合组;将全部所述综合组与每个所述场景组内未被组成所述综合组的批次组进行随机排列,或对全部所述综合组进行随机排列,得到所述目标训练数据集。
在具体的应用场景中,所述样本分组模块52,还可具体用于将所述场景组内的训练样本排列成样本队列;以所述样本队列的一个端点为起始点,依次获取所述预设数量的训练样本组成一个批次组,得到多个批次组。
在具体的应用场景中,如图6所示,本装置还包括报警提示模块64,所述报警提示模块64具体可用于获取每个所述批次组内的场景标签;判断同一个所述批次组内的场景标签是否对应同一个场景;若存在所述批次组内的场景标签未对应同一个场景,则发出报警提示信息;
在具体的应用场景中,如图6所示,本装置还包括顺序检测模块65,所述顺序检测模块65具体可用于确定所述目标训练数据集内全部所述批次组的排列顺序和每个所述批次组内训练样本对应的场景标签;将所述目标训练数据集的末端作为起始点,选取预设数量的多个所述批次组确定为验证样本组,并基于所述场景标签确定每个所述验证样本组内的训练样本对应的场景;依次判断每个所述验证样本组内的训练样本对应的场景是否包含所述训练数据集中全部训练样本对应的场景;若全部所述验证样本组内的训练样本对应的场景未包含所述训练数据集中全部训练样本对应的场景,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。
在具体的应用场景中,所述顺序检测模块65,还可具体用于基于所述场景标签确定每个所述批次组对应的场景;从所述目标训练数据集的全部所述批次组中,定位出对应有相同场景的批次组,并判断相邻的所述对应有相同场景的批次组之间的批次组的数量是否大于预设阈值;若相邻的所述对应有相同场景的批次组的数量大于所述预设阈值,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。
需要说明的是,本实施例提供的一种对话生成模型的训练装置所涉及各功能单元的其它相应描述,可以参考图1中的对应描述,在此不再赘述。
基于上述如图1所示方法,相应的,本实施例还提供了一种存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述如图1所示的对话生成模型的训练方法。
基于这样的理解,本申请的技术方案可以以软件产品的形式体现出来,该待识别软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施场景所述的方法。
基于上述如图1所示的方法,以及图5和图6所示的对话生成模型的训练装置实施例,为了实现上述目的,本实施例还提供了一种对话生成模型的训练的实体设备,具体可以为个人计算机、服务器、智能手机、平板电脑、智能手表、或者其它网络设备等,该实体设备包括存储介质和处理器;存储介质,用于存储计算机程序;处理器,用于执行计算机程序以实现上述如图1所示的方法。
可选的,该实体设备还可以包括用户接口、网络接口、摄像头、射频(RadioFrequency,RF)电路,传感器、音频电路、WI-FI模块等等。用户接口可以包括显示屏(Display)、输入单元比如键盘(Keyboard)等,可选用户接口还可以包括USB接口、读卡器接口等。网络接口可选的可以包括标准的有线接口、无线接口(如WI-FI接口)等。
本领域技术人员可以理解,本实施例提供的一种对话生成模型的训练的实体设备结构并不构成对该实体设备的限定,可以包括更多或更少的部件,或者组合某些部件,或者不同的部件布置。
存储介质中还可以包括操作系统、网络通信模块。操作系统是管理上述实体设备硬件和待识别软件资源的程序,支持信息处理程序以及其它待识别软件和/或程序的运行。网络通信模块用于实现存储介质内部各组件之间的通信,以及与信息处理实体设备中其它硬件和软件之间通信。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到本申请可以借助软件加必要的通用硬件平台的方式来实现,也可以通过硬件实现。通过应用本申请的技术方案,首先,获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;然后,基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;再后,对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;最后,根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。与现有技术相比,能够提高对对话生成模型的训练效果,提升对话生成模型的训练性能。
本领域技术人员可以理解附图只是一个优选实施场景的示意图,附图中的模块或流程并不一定是实施本申请所必须的。本领域技术人员可以理解实施场景中的装置中的模块可以按照实施场景描述进行分布于实施场景的装置中,也可以进行相应变化位于不同于本实施场景的一个或多个装置中。上述实施场景的模块可以合并为一个模块,也可以进一步拆分成多个子模块。
上述本申请序号仅仅为了描述,不代表实施场景的优劣。以上公开的仅为本申请的几个具体实施场景,但是,本申请并非局限于此,任何本领域的技术人员能思之的变化都应落入本申请的保护范围。
Claims (10)
1.一种对话生成模型的训练方法,其特征在于,所述方法包括:
获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
2.根据权利要求1所述的方法,其特征在于,所述基于所述场景标签将所述训练样本划分为多个批次组,包括:
根据所述场景标签,将所述训练样本划分为多个场景组,其中,每个所述场景组包含属于同一场景的全部训练样本;
将所述场景组内的多个所述训练样本划分为多个批次组,其中,每个所述批次组包括预设数量的所述训练样本。
3.根据权利要求2所述的方法,其特征在于,所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集,包括:
执行循环过程直至满足预设条件,其中,所述循环过程包括:
从每个所述场景组内选出一个所述批次组,并将选出的多个所述批次组随机组合成综合组,其中,所述综合组内的训练样本适用的场景包含所述训练数据集中全部训练样本适用的场景;
所述预设条件为:存在至少一个所述场景组中的全部所述批次组被组成所述综合组;
将全部所述综合组与每个所述场景组内未被组成所述综合组的批次组进行随机排列,或对全部所述综合组进行随机排列,得到所述目标训练数据集。
4.根据权利要求2所述的方法,其特征在于,所述将所述场景组内的多个所述训练样本划分为多个批次组,包括:
将所述场景组内的训练样本排列成样本队列;
以所述样本队列的一个端点为起始点,依次获取所述预设数量的训练样本组成一个批次组,得到多个批次组。
5.根据权利要求1-4任一项所述的方法,其特征在于,所述根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型之前,所述方法还包括:
获取每个所述批次组内的场景标签;
判断同一个所述批次组内的场景标签是否对应同一个场景;
若存在所述批次组内的场景标签未对应同一个场景,则发出报警提示信息。
6.根据权利要求1-4任一项所述的方法,其特征在于,所述根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型之前,所述方法还包括:
确定所述目标训练数据集内全部所述批次组的排列顺序和每个所述批次组内训练样本对应的场景标签;
将所述目标训练数据集的末端作为起始点,选取预设数量的多个所述批次组确定为验证样本组,并基于所述场景标签确定每个所述验证样本组内的训练样本对应的场景;
依次判断每个所述验证样本组内的训练样本对应的场景是否包含所述训练数据集中全部训练样本对应的场景;
若全部所述验证样本组内的训练样本对应的场景未包含所述训练数据集中全部训练样本对应的场景,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。
7.根据权利要求1-4任一项所述的方法,其特征在于,所述根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型之前,所述方法还包括:
基于所述场景标签确定每个所述批次组对应的场景;
从所述目标训练数据集的全部所述批次组中,定位出对应有相同场景的批次组,并判断相邻的所述对应有相同场景的批次组之间的批次组的数量是否大于预设阈值;
若相邻的所述对应有相同场景的批次组之间的批次组的数量大于所述预设阈值,则重新执行所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集的步骤。
8.一种对话生成模型的训练装置,其特征在于,所述装置包括:
数据获取模块,用于获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
样本分组模块,用于基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
数据生成模块,用于对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
模型训练模块,用于根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
9.一种存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
10.一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310469117.3A CN116484881A (zh) | 2023-04-24 | 2023-04-24 | 对话生成模型的训练方法、装置、存储介质及计算机设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310469117.3A CN116484881A (zh) | 2023-04-24 | 2023-04-24 | 对话生成模型的训练方法、装置、存储介质及计算机设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116484881A true CN116484881A (zh) | 2023-07-25 |
Family
ID=87217550
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310469117.3A Pending CN116484881A (zh) | 2023-04-24 | 2023-04-24 | 对话生成模型的训练方法、装置、存储介质及计算机设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116484881A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117391094A (zh) * | 2023-10-16 | 2024-01-12 | 百度在线网络技术(北京)有限公司 | 智能客服模型的训练方法、基于模型的对话方法、设备 |
-
2023
- 2023-04-24 CN CN202310469117.3A patent/CN116484881A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117391094A (zh) * | 2023-10-16 | 2024-01-12 | 百度在线网络技术(北京)有限公司 | 智能客服模型的训练方法、基于模型的对话方法、设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110288049B (zh) | 用于生成图像识别模型的方法和装置 | |
CN108509411B (zh) | 语义分析方法和装置 | |
CN108960316B (zh) | 用于生成模型的方法和装置 | |
CN109509021B (zh) | 基于行为轨迹的异常识别方法、装置、服务器及存储介质 | |
CN109447156B (zh) | 用于生成模型的方法和装置 | |
CN111414946B (zh) | 基于人工智能的医疗影像的噪声数据识别方法和相关装置 | |
CN109408821B (zh) | 一种语料生成方法、装置、计算设备及存储介质 | |
US20220222540A1 (en) | Predicting brain data using machine learning models | |
CN114334169A (zh) | 医疗对象的类别决策方法、装置、电子设备及存储介质 | |
CN112149754B (zh) | 一种信息的分类方法、装置、设备及存储介质 | |
CN116484881A (zh) | 对话生成模型的训练方法、装置、存储介质及计算机设备 | |
CN115222443A (zh) | 客户群体划分方法、装置、设备及存储介质 | |
Dinkelberg et al. | Detecting opinion-based groups and polarization in survey-based attitude networks and estimating question relevance | |
CN108921138B (zh) | 用于生成信息的方法和装置 | |
CN117194772B (zh) | 一种基于用户标签的内容推送方法及装置 | |
CN111949530B (zh) | 测试结果的预测方法、装置、计算机设备及存储介质 | |
CN113782093A (zh) | 一种基因表达填充数据的获取方法及装置、存储介质 | |
CN111177388B (zh) | 一种处理方法及计算机设备 | |
EP3955177B1 (en) | Search method and information processing system | |
CN113190444A (zh) | 一种测试方法、装置及存储介质 | |
CN110716778B (zh) | 应用兼容性测试方法、装置及系统 | |
CN117671553A (zh) | 一种目标识别方法、系统及相关装置 | |
CN115037790B (zh) | 异常注册识别方法、装置、设备及存储介质 | |
KR102413588B1 (ko) | 학습 데이터에 따른 객체 인식 모델 추천 방법, 시스템 및 컴퓨터 프로그램 | |
US11710068B2 (en) | Labeling a dataset |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |