CN117453878A - 对话处理方法、装置及电子设备 - Google Patents
对话处理方法、装置及电子设备 Download PDFInfo
- Publication number
- CN117453878A CN117453878A CN202311425880.2A CN202311425880A CN117453878A CN 117453878 A CN117453878 A CN 117453878A CN 202311425880 A CN202311425880 A CN 202311425880A CN 117453878 A CN117453878 A CN 117453878A
- Authority
- CN
- China
- Prior art keywords
- pruning
- model
- dialogue
- channels
- network layer
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000003672 processing method Methods 0.000 title claims abstract description 21
- 238000013138 pruning Methods 0.000 claims abstract description 294
- 238000000034 method Methods 0.000 claims abstract description 51
- 238000010801 machine learning Methods 0.000 claims abstract description 38
- 238000012545 processing Methods 0.000 claims abstract description 21
- 230000006870 function Effects 0.000 claims description 34
- 238000010606 normalization Methods 0.000 claims description 8
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 230000008569 process Effects 0.000 description 17
- 238000012549 training Methods 0.000 description 17
- 238000013145 classification model Methods 0.000 description 11
- 230000001133 acceleration Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 6
- 230000003044 adaptive effect Effects 0.000 description 5
- 230000000694 effects Effects 0.000 description 5
- 230000006835 compression Effects 0.000 description 4
- 238000007906 compression Methods 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000007667 floating Methods 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 2
- 238000004590 computer program Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 230000006978 adaptation Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000001174 ascending effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 230000035945 sensitivity Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/33—Querying
- G06F16/332—Query formulation
- G06F16/3329—Natural language query formulation or dialogue systems
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q30/00—Commerce
- G06Q30/01—Customer relationship services
- G06Q30/015—Providing customer assistance, e.g. assisting a customer within a business location or via helpdesk
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Business, Economics & Management (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Molecular Biology (AREA)
- Medical Informatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- General Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Human Computer Interaction (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Accounting & Taxation (AREA)
- Development Economics (AREA)
- Economics (AREA)
- Finance (AREA)
- Marketing (AREA)
- Strategic Management (AREA)
- General Business, Economics & Management (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种对话处理方法、装置及电子设备。涉及人工智能领域,该方法包括:基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,第一对话模型中包括多个网络层中,每一个网络层包括多个通道,多个通道分别对应有重要性数值;基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型。本发明解决了相关中的对话模型结构复杂,存在大量冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低的技术问题。
Description
技术领域
本发明涉及人工智能领域,具体而言,涉及一种对话处理方法、装置及电子设备。
背景技术
目前,智能在线客服领域中的对话模型能够智能化地理解用户意图,根据用户需求进行实时交互,针对用户提出的问题,会给出一定的对话文本回复。但是相关技术中的对话模型往往结构复杂,模型中存在一定程度上的参数冗余和/或冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低。
针对上述的问题,目前尚未提出有效的解决方案。
发明内容
本发明实施例提供了一种对话处理方法、装置及电子设备,以至少解决相关中的对话模型结构复杂,存在大量冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低的技术问题。
根据本发明实施例的一个方面,提供了一种对话处理方法,包括:基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,所述第一对话模型中包括多个网络层中,每一个网络层包括多个通道,所述多个通道分别对应有重要性数值,其中,所述重要性数值用于指示对应通道的重要性程度,所述历史对话数据集包括多组历史输入文本,以及所述多组历史输入文本分别对应的历史回复文本;基于所述每一个网络层中包括的多个通道分别对应的所述重要性数值,确定所述每一个网络层中的冗余通道;按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,所述预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,所述剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
可选的,所述基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,包括:确定初始损失函数;在所述初始损失函数的基础上加入关于目标超参数的L1正则化项,得到目标损失函数,其中,所述目标超参数位于所述每一个网络层中的批标准化层中;基于所述历史对话数据集,以及所述目标损失函数,对所述初始对话模型进行机器学习,得到所述第一对话模型。
可选的,所述基于所述历史对话数据集,以及所述目标损失函数,对所述初始对话模型进行机器学习,得到所述第一对话模型,包括:基于所述历史对话数据集和所述目标损失函数,对所述初始对话模型进行机器学习;在所述每一个网络层中包括的所述多个通道中,预定数量的通道对应的所述重要性数值在预设区间范围内的情况下,输出所述第一对话模型。
可选的,所述按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,包括:按照所述预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝;在所述第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件的情况下,输出所述目标对话模型。
可选的,所述按照所述预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,包括:按照预设剪枝顺序,依次以所述第一对话模型中的每一个网络层作为当前网络层,并循环执行以下操作,直至所述第一对话模型中的所有网络层均剪枝完毕,并且满足所述预设剪枝约束条件:基于所述预设剪枝规则,确定所述当前网络层对应的初始剪枝率;按照所述初始剪枝率对所述当前网络层进行剪枝处理,得到剪枝后的对话模型;基于所述剪枝后的对话模型的模型精度,以及所述第一对话模型对应的模型精度,得到模型精度损失;根据所述模型精度损失对所述初始剪枝率进行调整,得到新的剪枝率;基于所述新的剪枝率,按照与上述操作相同的处理方式,继续对所述剪枝后的对话模型进行剪枝处理。
可选的,所述根据所述模型精度损失对所述初始剪枝率进行调整,得到新的剪枝率,包括:在所述模型精度损失小于或等于预设精度损失容忍值的情况下,按照预设比例增大所述初始剪枝率,得到所述新的剪枝率;在所述模型精度损失大于所述预设精度损失容忍值的情况下,将所述初始剪枝率作为所述新的剪枝率。
可选的,所述预设剪枝约束条件为:
其中,Si表示所述多个网络层中,任意一个网络层的参数量,S表示所述第一对话模型中的总参数量,L表示所述多个网络层的数量,i表示所述多个网络层中的任意一个网络层,pg为预设的全局剪枝率。
可选的,所述方法还包括:获取待输入文本;基于所述待输入文本,采用所述目标对话模型,得到目标回复文本。
根据本发明实施例的另一方面,还提供了一种对话处理装置,包括:机器学习模块,用于基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,所述第一对话模型中包括多个网络层中,每一个网络层包括多个通道,所述多个通道分别对应有重要性数值,其中,所述重要性数值用于指示对应通道的重要性程度,所述历史对话数据集包括多组历史输入文本,以及所述多组历史输入文本分别对应的历史回复文本;确定模块,用于基于所述每一个网络层中包括的所述多个通道分别对应的所述重要性数值,确定所述每一个网络层中的冗余通道;剪枝模块,用于按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,所述预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,所述剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
根据本发明实施例的另一方面,还提供了一种电子设备,包括一个或多个处理器和存储器,所述存储器用于存储一个或多个程序,其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现任意一项所述的对话处理方法。
在本发明实施例中,通过基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,所述第一对话模型中包括多个网络层中,每一个网络层包括多个通道,所述多个通道分别对应有重要性数值,其中,所述重要性数值用于指示对应通道的重要性程度,所述历史对话数据集包括多组历史输入文本,以及所述多组历史输入文本分别对应的历史回复文本;基于所述每一个网络层中包括的多个通道分别对应的所述重要性数值,确定所述每一个网络层中的冗余通道;按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,所述预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,所述剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例,达到了按照预设剪枝规则,根据每一轮剪枝后的模型对应的模型精度自适应的确定每一个网络层中的最佳通道剪枝率,对训练好的对话模型进行剪枝优化的目的,从而实现了自适应优化和精简模型结构,进而提升对话预测效率和预测准确性的技术效果,进而解决了相关中的对话模型结构复杂,存在大量冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低的技术问题。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是根据本发明实施例的一种对话处理方法的示意图;
图2是根据本发明实施例的一种可选的对话处理方法的示意图;
图3是根据本发明实施例的另一种可选的对话处理方法的示意图;
图4是根据本发明实施例的另一种可选的对话处理方法的示意图;
图5是根据本发明实施例的一种对话处理装置的示意图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
根据本发明实施例,提供了一种对话处理的方法实施例,需要说明的是,在附图的流程图示出的步骤可以在诸如一组计算机可执行指令的计算机系统中执行,并且,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。
图1是根据本发明实施例的对话处理方法的流程图,如图1所示,该方法包括如下步骤:
步骤S102,基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,第一对话模型中包括多个网络层中,每一个网络层包括多个通道,多个通道分别对应有重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本。
可选的,可以基于历史对话数据集,对初始对话模型进行稀疏训练,得到第一对话模型;该初始对话模型可以为TextCNN文本分类模型,TextCNN文本分类模型中包括多个网络层,每一个网络层后面接有一个批标准化层,每一个通道中的重要性数值可以是基于批标准化层中的目标超参数γ确定的。通过稀疏训练方式,批标准化层BN层的目标超参数在稀疏训练过程中逐步向0靠拢,解决常规训练的模型BN层权重不会太接近于0的情况,以便后续进行通道的筛选。
在一种可选的实施例中,基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,包括:确定初始损失函数;在初始损失函数的基础上加入关于目标超参数的L1正则化项,得到目标损失函数,其中,目标超参数位于每一个网络层中的批标准化层中;基于历史对话数据集,以及目标损失函数,对初始对话模型进行机器学习,得到第一对话模型。
可选的,根据每一个网络层中的批标准化层中的目标超参数,确定每一个网络层中包括的多个通道分别对应的重要性数值。初始损失函数可以为交叉熵损失函数,使用BN层的目标超参数γ(即缩放系数)作为衡量通道重要性的参数,目标超参数γ越接近于0,则表明该通道对应的重要性越小,反之则重要性越大。在初始损失函数基础上加入关于γ的L1正则化项,加入L1正则化项后,再次进行训练直至损失收敛,输出第一对话模型,并将训练结束时第一对话模型对应的BN层的γ值确定为各通道的重要性数值,以便后续进行通道的筛选。
在一种可选的实施例中,基于历史对话数据集,以及目标损失函数,对初始对话模型进行机器学习,得到第一对话模型,包括:基于历史对话数据集和目标损失函数,对初始对话模型进行机器学习;在每一个网络层中包括的多个通道中,预定数量的通道对应的重要性数值在预设区间范围内的情况下,输出第一对话模型。
可选的,上述预设区间范围可以为0附近的某个区间范围,即在批标准化层BN层的目标超参数在稀疏训练过程中逐步向0靠拢的情况下,输出第一对话模型,解决常规训练的模型BN层权重不会太接近于0的情况。
步骤S104,基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道。
可选的,重要性数值越小且越接近于0,则表明该通道在模型中越不重要,可以针对该通道中的参数进行适当的剪除,以精简模型结构。可以将重要性数值小于预设重要性阈值(如接近于0)的通道作为冗余通道,进行参数的适当剪除,以精简模型结构。
步骤S106,按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
可选的,上述预设剪枝规则中规定了剪枝率动态调整的方法,即在对第一网络模型中的每一层网络进行迭代剪枝的过程中,剪枝率随着每一轮剪枝得到的模型对应的模型精度动态变化。通过以上方式,在模型剪枝的过程中,随着模型精度的变化动态调整剪枝率,由此达到提升模型剪枝效率和剪枝效果的目的。
需要说明的是,相关技术中在进行模型剪枝时,需要人为干预剪枝率,通过事先人工指定网络通道的重要程度阈值,那些低于设定阈值的通道将被剔除。然而,凭经验设定通道重要性阈值往往需要耗费大量人力,并且不同网络层对于剪枝的敏感程度不同,因此阈值设置过大可能会造成过剪枝的问题,导致显著的精度损失;阈值设置过小可能会造成欠剪枝的问题,导致得到的剪枝结果往往是次优的,无法完全剪去冗余参数,影响模型压缩效率。基于此,本发明实施例首先对每一层网络中包括的通道的重要性程度进行精确评估,并在此技术上剔除根据每一轮剪枝后的模型对应的模型精度动态确定每一通道的最佳剪枝率,无需人为干预,免去现有方法人工手动设置阈值的繁琐过程,剔除冗余参数,一定程度上规避了现有方法存在的过剪枝和欠剪枝的问题,在保持精度的同时大幅度压缩第一对话模型的参数量和浮点计算量,实现推理加速,提高解决用户问题的效率。
在一种可选的实施例中,按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,包括:按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝;在第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件的情况下,输出目标对话模型。
可选的,预设剪枝约束条件为:
其中,Si表示多个网络层中,任意一个网络层的参数量,S表示第一对话模型中的总参数量,L表示多个网络层的数量,i表示多个网络层中的任意一个网络层,pg为预设的全局剪枝率。
可选的,按照预设剪枝规则,对每一个网络层中通道的剪枝率进行自适应调整,剪除每一个通道中的冗余通道,当每一层网络中的冗余通道均剪枝完毕,并且剪枝后的模型满足预设剪枝条件的情况下,输出最终的目标对话模型。通过上述约束条件的设置,既可以使得剪枝后的模型更加符合实际场景需要,同时能够保证模型压缩幅度最大,精简程度最佳。
在一种可选的实施例中,按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,包括:按照预设剪枝顺序,依次以第一对话模型中的每一个网络层作为当前网络层,并循环执行以下操作,直至第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件:基于预设剪枝规则,确定当前网络层对应的初始剪枝率;按照初始剪枝率对当前网络层进行剪枝处理,得到剪枝后的对话模型;基于剪枝后的对话模型的模型精度,以及第一对话模型对应的模型精度,得到模型精度损失;根据模型精度损失对初始剪枝率进行调整,得到新的剪枝率;基于新的剪枝率,按照与上述操作相同的处理方式,继续对剪枝后的对话模型进行剪枝处理。
可选的,按照预设剪枝规则,对每个网络层的冗余通道进行迭代剪枝。具体的,对于第一对话模型中的某一个网络层,按照设定的初始剪枝率对该网络层的通道进行剪枝,并对每次剪枝之后的模型精度进行测试,得出模型精度损失。根据模型精度损失确定初始剪枝率的调整方向进行剪枝率调整,根据调整后得到的新的剪枝率进行模型剪枝。通过以上方式,根据模型精度损失进行剪枝率的自适应调整,在保证模型剪枝效率的同时,提升模型压缩精度。
在一种可选的实施例中,根据模型精度损失对初始剪枝率进行调整,得到新的剪枝率,包括:在模型精度损失小于或等于预设精度损失容忍值的情况下,按照预设比例增大初始剪枝率,得到新的剪枝率;在模型精度损失大于预设精度损失容忍值的情况下,将初始剪枝率作为新的剪枝率。
可选的,模型精度损失如果在设定的精度损失容忍值之内,即:
acc(M)-acc(M')≤η
其中,acc(M)表示未剪枝的网络模型精度,即第一对话模型的模型精度;acc(M')表示自适应通道剪枝后的网络模型精度,即剪枝后的对话模型的模型精度;η表示精度损失容忍值,初始化为0.005。则根据设定的比例因子逐渐增大剪枝率,实现自适应调整。具体地,剪枝率更新过程,即自适应调整过程如下所示:
pi=pi+αpi
其中,pi表示第i层网络的通道剪枝率,初始化值为0.01,α为设定的比例因子,初始化值为0.1。
上述表达式使得该网络层的通道剪枝率将小幅度逐渐增大,直到达到最优,即实现剪枝之后网络模型精度损失最小。
可选的,模型精度损失如果大于设定的精度损失容忍数值,即:
acc(M)-acc(M')>η
则表示当前剪枝率为该网络层的最优剪枝率,实施该网络层的剪枝过程,具体表达式如下:
γi=γi-piγi
其中,γi表示第i层网络的通道重要性参数。之后针对其他未剪枝过的网络层采用上述自适应通道剪枝方法进行迭代剪枝,直到第一对话模型中的所有网络层都已完成剪枝流程。
在一种可选的实施例中,该方法还包括:获取待输入文本;基于待输入文本,采用目标对话模型,得到目标回复文本。
可选的,在获取到目标对话模型之后,可以将该模型部署至实际应用场景中,用户首先提出想要咨询的问题作为待输入文本,例如“为什么充值的话费还没有到账?”、“代金券什么时候返还”、“能否更改商品的收货地址”等问题。将该问题输入至目标对话模型中,针对用户提出的以上三个问题,通过目标对话模型能够更加快速地回答用户咨询的问题,及时做出“您好,充值的话费需要一定时间,请耐心等待。”、“您好,代金券经过核实将在三个工作日内返还,请耐心等待。”、“您好,您的商品已发货,暂时不可修改收货地址,感谢理解。”的应答,减少用户等待时间,一定程度上能够提高解决问题的效率,进一步节约人力成本,相比于原始未压缩的模型能够很好地达成推理加速的目的。
通过上述步骤S102至步骤S106,可以达到按照预设剪枝规则,根据每一轮剪枝后的模型对应的模型精度自适应的确定每一个网络层中的最佳通道剪枝率,对训练好的对话模型进行剪枝优化的目的,从而实现自适应优化和精简模型结构,进而提升对话预测效率和预测准确性的技术效果,进而解决相关中的对话模型结构复杂,存在大量冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低的技术问题。
基于上述实施例和可选实施例,本发明提出一种可选实施方式,图2是根据本发明实施例的一种可选的对话处理方法的流程图,图3是根据本发明实施例的另一种可选的对话处理方法的流程图,图4是根据本发明实施例的另一种可选的对话处理方法的流程图,如图2至图4所示,该方法包括:
步骤S1,TextCNN文本分类模型中的参数冗余往往存在于骨干网络之中,首先提取TextCNN文本分类模型作为初始对话模型,该模型的网络层数为L,第i层网络可以表示为Li,其中的通道可以表示为C,包含通道的个数表示为ni,则整体可以用来表示。后续模型压缩与加速针对初始对话模型展开。
步骤S2,采用稀疏化训练的方法训练初始对话模型,构建通道重要性评估方法得到第一对话模型(即训练后的对话模型)中每一层的通道重要性数值γ,其中,使用BN层的缩放系数γ作为衡量通道重要性的参数,在原有的目标函数基础上加入关于γ的L1正则化项,加入L1正则化项后,再次进行训练直至损失收敛,称此过程为稀疏训练。具体步骤如图3所示:
步骤S21,在第一对话模型的常规损失函数的基础上,对批标准化BN层中超参数γ施加L1正则化损失,目标超参数γ表示BN层中的乘法因子,最终的损失函数表达式如下:
y'=fD(x,W)
其中,Loss表示模型损失,l表示分类任务中常用的交叉熵损失函数;x表示输入数据,即多组历史输入文本中的任意一组历史输入文本;y表示真实标签,即任意一组历史输入文本对应的历史回复文本;y'表示在历史对话数据集D上的预测结果,即任意一组历史输入文本对应的预测结果;W表示可训练参数,λ用于平衡训练过程中网络的稀疏程度。
步骤S22,批标准化层BN层的目标超参数γ在稀疏训练过程中逐步向0靠拢,解决常规训练的模型BN层权重不会太接近于0的情况。
步骤S23,初始对话模型的准确性和BN层的稀疏性逐渐达到平衡。
步骤S24,得到第一对话模型,以及第一对话模型每一层通道的通道重要性数值γ,其中,BN层中对应输出最接近于0的通道将被视为冗余通道,成为下一步自适应通道剪枝的目标。
步骤S3,构建自适应通道剪枝方法,自动确定初始对话模型每一层通道的最优通道剪枝率,具体步骤如图4所示:
步骤S31,将步骤S2得到的通道重要性数值按照升序排序,该数值反映了网络层通道的重要性程度,为后续自适应通道剪枝提供决策依据。
步骤S32,对每层的冗余通道进行迭代式剪枝。具体地,对于第一对话模型中的某一个网络层,按照设定的初始剪枝率对该网络层的通道进行剪枝,并对每次剪枝之后的模型精度进行测试,得出模型精度损失。
步骤S33,模型精度损失如果在设定的精度损失容忍值之内,即:
acc(M)-acc(M')≤η
其中,acc(M)表示未剪枝的网络模型精度,即第一对话模型的模型精度;acc(M')表示自适应通道剪枝后的网络模型精度,即剪枝后的对话模型的模型精度;η表示精度损失容忍值,初始化为0.005。则根据设定的比例因子逐渐增大剪枝率,实现自适应调整。具体地,剪枝率更新过程,即自适应调整过程如下所示:
pi=pi+αpi
其中,pi表示锁个网络层中,第i层网络的通道剪枝率,初始化值为0.01,α为设定的比例因子,初始化值为0.1。
上述表达式使得该网络层的通道剪枝率将小幅度逐渐增大,直到达到最优,即实现剪枝之后网络模型精度损失最小。
步骤S34,模型精度损失如果大于设定的精度损失容忍数值,即:
acc(M)-acc(M')>η
则表示当前剪枝率为该网络层的最优剪枝率,实施该网络层的剪枝过程,具体表达式如下:
γi=γi-piγi
其中,γi表示第i层网络的通道重要性参数。之后针对其他未剪枝过的网络层采用上述自适应通道剪枝方法进行迭代剪枝,直到第一对话模型中的所有网络层都已完成剪枝流程,并且满足预设剪枝条件,其中,预设剪枝条件为:
其中,Si表示多个网络层中,任意一个网络层的参数量,S表示第一对话模型中的总参数量,L表示多个网络层的数量,pg为预设的全局剪枝率。
本发明提出的剪枝方法优势在于无需人为干预,很好地解决了现有方法中人为设置阈值存在的过剪枝和欠剪枝的问题,通过以上自适应通道剪枝流程能够保证最大化剪去冗余参数的同时保持较高的模型精度。
步骤S4,将步骤S3剪枝后的第一对话模型进行微调,进一步恢复所得轻量化剪枝后的第一对话模型的精度至剪枝前的第一对话模型的精度水平,得到目标对话模型。
步骤S5,得到最终的目标对话模型即为轻量化TextCNN文本分类模型,实现在保持精度的同时大幅度压缩TextCNN文本分类模型的参数量和浮点计算量,达成推理加速的目的。
步骤S6,在用户与智能数字人客服交互场景中,用户首先提出想要咨询的问题,例如:“为什么充值的话费还没有到账?”、“代金券什么时候返还”、“能否更改商品的收货地址”等问题。将该问题输入至目标对话模型中,针对用户提出的以上三个问题,通过目标对话模型能够更加快速地回答用户咨询的问题,及时做出“您好,充值的话费需要一定时间,请耐心等待。”、“您好,代金券经过核实将在三个工作日内返还,请耐心等待。”、“您好,您的商品已发货,暂时不可修改收货地址,感谢理解。”的问题回复结果,减少用户等待时间,一定程度上能够提高解决问题的效率,进一步节约人力成本,相比于原始未压缩的模型能够很好地达成推理加速的目的。
本发明实施例中,通过对BN层中超参数施加L1正则化损失,结合稀疏训练的训练模式实现通道重要性程度的精准评估,为后续TextCNN文本分类模型的自适应通道剪枝提供决策依据。并且,本发明实施例通过引入一种精度损失容忍机制来实现剪枝率的自适应调整,进而得出TextCNN文本分类模型每一层的最佳通道剪枝率,保证最大化剔除冗余参数的同时保持较高的模型精度。
本发明实施例可以实现如下效果中的至少之一:1)本发明实施例提出的通道重要性评估方法能够对TextCNN文本分类模型每一层的通道重要性程度进行精准评估,并且通用化程度更高。2)发明实施例提出的自适应通道剪枝方法优势在于自动确定TextCNN文本分类模型每一层的最佳通道剪枝率,无需人为干预,免去现有方法人工手动设置阈值的繁琐过程,剔除冗余参数,一定程度上规避了现有方法存在的过剪枝和欠剪枝的问题,在保持精度的同时大幅度压缩TextCNN文本分类模型的参数量和浮点计算量,实现推理加速,提高解决用户问题的效率,进一步节约人力成本,有效推动TextCNN文本分类模型的实际应用落地。
在本实施例中还提供了一种对话处理装置,该装置用于实现上述实施例及优选实施方式,已经进行过说明的不再赘述。如以下所使用的,术语“模块”“装置”可以实现预定功能的软件和/或硬件的组合。尽管以下实施例所描述的装置较佳地以软件来实现,但是硬件,或者软件和硬件的组合的实现也是可能并被构想的。
根据本发明实施例,还提供了一种用于实施上述对话处理方法的装置实施例,图5是根据本发明实施例的一种对话处理装置的结构示意图,如图5所示,上述对话处理装置,包括:机器学习模块500、确定模块502、剪枝模块504,其中:
机器学习模块500,用于基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,第一对话模型中包括多个网络层中,每一个网络层包括多个通道,多个通道分别对应有重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本;
确定模块502,连接于机器学习模块500,用于基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;
剪枝模块504,连接于确定模块502,用于按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
在本发明实施例中,通过设置机器学习模块500,用于基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,第一对话模型中包括多个网络层中,每一个网络层包括多个通道,多个通道分别对应有重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本;确定模块502,连接于机器学习模块500,用于基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;剪枝模块504,连接于确定模块502,用于按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例,达到了按照预设剪枝规则,根据每一轮剪枝后的模型对应的模型精度自适应的确定每一个网络层中的最佳通道剪枝率,对训练好的对话模型进行剪枝优化的目的,从而实现了自适应优化和精简模型结构,进而提升对话预测效率和预测准确性的技术效果,进而解决了相关中的对话模型结构复杂,存在大量冗余通道,并且在进行模型剪枝时人为干预方式进行剪枝率的确定,导致对话模型确定不准确,进而导致对话预测效率低且准确性低的技术问题。
在一种可选的实施例中,机器学习模块,包括:第一确定子模块,用于确定初始损失函数;第一获取子模块,用于在初始损失函数的基础上加入关于目标超参数的L1正则化项,得到目标损失函数,其中,目标超参数位于每一个网络层中的批标准化层中;第一机器学习子模块,用于基于历史对话数据集,以及目标损失函数,对初始对话模型进行机器学习,得到第一对话模型;第二确定子模块,用于根据每一个网络层中的批标准化层中的目标超参数,确定每一个网络层中包括的多个通道分别对应的重要性数值。
在一种可选的实施例中,第一机器学习子模块,包括:第二机器学习子模块,用于基于历史对话数据集和目标损失函数,对初始对话模型进行机器学习;第一输出子模块,用于在每一个网络层中包括的多个通道中,预定数量的通道对应的重要性数值在预设区间范围内的情况下,输出第一对话模型。
在一种可选的实施例中,剪枝模块,包括:第一剪枝子模块,用于按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝;第二输出子模块,用于在第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件的情况下,输出目标对话模型。
在一种可选的实施例中,第二输出子模块,包括:按照预设剪枝顺序,依次以第一对话模型中的每一个网络层作为当前网络层,并循环执行以下操作,直至第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件:基于预设剪枝规则,确定当前网络层对应的初始剪枝率;按照初始剪枝率对当前网络层进行剪枝处理,得到剪枝后的对话模型;基于剪枝后的对话模型的模型精度,以及第一对话模型对应的模型精度,得到模型精度损失;根据模型精度损失对初始剪枝率进行调整,得到新的剪枝率;基于新的剪枝率,按照与上述操作相同的处理方式,继续对剪枝后的对话模型进行剪枝处理。
在一种可选的实施例中,根据模型精度损失对初始剪枝率进行调整,得到新的剪枝率,包括:第二剪枝子模块,用于在模型精度损失小于或等于预设精度损失容忍值的情况下,按照预设比例增大初始剪枝率,得到新的剪枝率;剪枝率更新子模块,用于在模型精度损失大于预设精度损失容忍值的情况下,将初始剪枝率作为新的剪枝率。
在一种可选的实施例中,预设剪枝约束条件为:
/>
其中,Si表示多个网络层中,任意一个网络层的参数量,S表示第一对话模型中的总参数量,L表示多个网络层的数量,i表示多个网络层中的任意一个网络层,pg为预设的全局剪枝率。
在一种可选的实施例中,方法还包括:第二获取子模块,用于获取待输入文本;预测子模块,用于基于待输入文本,采用目标对话模型,得到目标回复文本。
需要说明的是,上述各个模块是可以通过软件或硬件来实现的,例如,对于后者,可以通过以下方式实现:上述各个模块可以位于同一处理器中;或者,上述各个模块以任意组合的方式位于不同的处理器中。
此处需要说明的是,上述机器学习模块500、确定模块502、剪枝模块504对应于实施例中的步骤S102至步骤S106,上述模块与对应的步骤所实现的实例和应用场景相同,但不限于上述实施例所公开的内容。需要说明的是,上述模块作为装置的一部分可以运行在计算机终端中。
需要说明的是,本实施例的可选或优选实施方式可以参见实施例中的相关描述,此处不再赘述。
上述的对话处理装置还可以包括处理器和存储器,上述机器学习模块500、确定模块502、剪枝模块504等均作为程序模块存储在存储器中,由处理器执行存储在存储器中的上述程序模块来实现相应的功能。
处理器中包含内核,由内核去存储器中调取相应的程序模块,上述内核可以设置一个或以上。存储器可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM),存储器包括至少一个存储芯片。
根据本申请实施例,还提供了一种非易失性存储介质的实施例。可选的,在本实施例中,上述非易失性存储介质包括存储的程序,其中,在上述程序运行时控制上述非易失性存储介质所在设备执行上述任意一种对话处理方法。
可选的,在本实施例中,上述非易失性存储介质可以位于计算机网络中计算机终端群中的任意一个计算机终端中,或者位于移动终端群中的任意一个移动终端中,上述非易失性存储介质包括存储的程序。
可选的,在程序运行时控制非易失性存储介质所在设备执行以下功能:基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,以及第一对话模型中包括的多个网络层中,每一个网络层中包括的多个通道分别对应的重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本;基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
根据本申请实施例,还提供了一种处理器的实施例。可选的,在本实施例中,上述处理器用于运行程序,其中,上述程序运行时执行上述任意一种对话处理方法。
根据本申请实施例,还提供了一种计算机程序产品的实施例,当在数据处理设备上执行时,适于执行初始化有上述任意一种的对话处理方法步骤的程序。
可选的,上述计算机程序产品,当在数据处理设备上执行时,适于执行初始化有如下方法步骤的程序:基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,以及第一对话模型中包括的多个网络层中,每一个网络层中包括的多个通道分别对应的重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本;基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
本发明实施例提供了一种电子设备,该电子设备包括处理器、存储器及存储在存储器上并可在处理器上运行的程序,处理器执行程序时实现以下步骤:基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,以及第一对话模型中包括的多个网络层中,每一个网络层中包括的多个通道分别对应的重要性数值,其中,重要性数值用于指示对应通道的重要性程度,历史对话数据集包括多组历史输入文本,以及多组历史输入文本分别对应的历史回复文本;基于每一个网络层中包括的多个通道分别对应的重要性数值,确定每一个网络层中的冗余通道;按照预设剪枝规则对每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
上述本发明实施例顺序仅仅为了描述,不代表实施例的优劣。
在本发明的上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见其他实施例的相关描述。
在本申请所提供的几个实施例中,应该理解到,所揭露的技术内容,可通过其它的方式实现。其中,以上所描述的装置实施例仅仅是示意性的,例如上述模块的划分,可以为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个模块或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,模块或模块的间接耦合或通信连接,可以是电性或其它的形式。
上述作为分离部件说明的模块可以是或者也可以不是物理上分开的,作为模块显示的部件可以是或者也可以不是物理模块,即可以位于一个地方,或者也可以分布到多个模块上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能模块可以集成在一个处理模块中,也可以是各个模块单独物理存在,也可以两个或两个以上模块集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。
上述集成的模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取非易失性存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个非易失性存储介质中,包括若干指令用以使得一台计算机设备(可为个人计算机、服务器或者网络设备等)执行本发明各个实施例方法的全部或部分步骤。而前述的非易失性存储介质包括:U盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、移动硬盘、磁碟或者光盘等各种可以存储程序代码的介质。
以上仅是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。
Claims (10)
1.一种对话处理方法,其特征在于,包括:
基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,所述第一对话模型中包括多个网络层中,每一个网络层包括多个通道,所述多个通道分别对应有重要性数值,其中,所述重要性数值用于指示对应通道的重要性程度,所述历史对话数据集包括多组历史输入文本,以及所述多组历史输入文本分别对应的历史回复文本;
基于所述每一个网络层中包括的所述多个通道分别对应的所述重要性数值,确定所述每一个网络层中的冗余通道;
按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,所述预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,所述剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
2.根据权利要求1所述的方法,其特征在于,所述基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,包括:
确定初始损失函数;
在所述初始损失函数的基础上加入关于目标超参数的L1正则化项,得到目标损失函数,其中,所述目标超参数位于所述每一个网络层中的批标准化层中;
基于所述历史对话数据集,以及所述目标损失函数,对所述初始对话模型进行机器学习,得到所述第一对话模型。
3.根据权利要求2所述的方法,其特征在于,所述基于所述历史对话数据集,以及所述目标损失函数,对所述初始对话模型进行机器学习,得到所述第一对话模型,包括:
基于所述历史对话数据集和所述目标损失函数,对所述初始对话模型进行机器学习;
在所述每一个网络层中包括的所述多个通道中,预定数量的通道对应的所述重要性数值在预设区间范围内的情况下,输出所述第一对话模型。
4.根据权利要求1所述的方法,其特征在于,所述按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,包括:
按照所述预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝;
在所述第一对话模型中的所有网络层均剪枝完毕,并且满足预设剪枝约束条件的情况下,输出所述目标对话模型。
5.根据权利要求4所述的方法,其特征在于,所述按照所述预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,包括:
按照预设剪枝顺序,依次以所述第一对话模型中的每一个网络层作为当前网络层,并循环执行以下操作,直至所述第一对话模型中的所有网络层均剪枝完毕,并且满足所述预设剪枝约束条件:
基于所述预设剪枝规则,确定所述当前网络层对应的初始剪枝率;按照所述初始剪枝率对所述当前网络层进行剪枝处理,得到剪枝后的对话模型;基于所述剪枝后的对话模型的模型精度,以及所述第一对话模型对应的模型精度,得到模型精度损失;根据所述模型精度损失对所述初始剪枝率进行调整,得到新的剪枝率;基于所述新的剪枝率,按照与上述操作相同的处理方式,继续对所述剪枝后的对话模型进行剪枝处理。
6.根据权利要求5所述的方法,其特征在于,所述根据所述模型精度损失对所述初始剪枝率进行调整,得到新的剪枝率,包括:
在所述模型精度损失小于或等于预设精度损失容忍值的情况下,按照预设比例增大所述初始剪枝率,得到所述新的剪枝率;
在所述模型精度损失大于所述预设精度损失容忍值的情况下,将所述初始剪枝率作为所述新的剪枝率。
7.根据权利要求4所述的方法,其特征在于,所述预设剪枝约束条件为:
其中,Si表示所述多个网络层中,任意一个网络层的参数量,S表示所述第一对话模型中的总参数量,L表示所述多个网络层的数量,i表示所述多个网络层中的任意一个网络层,pg为预设的全局剪枝率。
8.根据权利要求1至7中任意一项所述的方法,其特征在于,所述方法还包括:
获取待输入文本;
基于所述待输入文本,采用所述目标对话模型,得到目标回复文本。
9.一种对话处理装置,其特征在于,包括:
机器学习模块,用于基于历史对话数据集,对初始对话模型进行机器学习,得到第一对话模型,其中,所述第一对话模型中包括多个网络层中,每一个网络层包括多个通道,所述多个通道分别对应有重要性数值,其中,所述重要性数值用于指示对应通道的重要性程度,所述历史对话数据集包括多组历史输入文本,以及所述多组历史输入文本分别对应的历史回复文本;
确定模块,用于基于所述每一个网络层中包括的所述多个通道分别对应的所述重要性数值,确定所述每一个网络层中的冗余通道;
剪枝模块,用于按照预设剪枝规则对所述每一个网络层中的冗余通道进行迭代剪枝,得到目标对话模型,其中,所述预设剪枝规则用于指示每一轮剪枝后的模型对应的模型精度与剪枝率之间的对应关系,所述剪枝率用于指示被剪参数的数量占对应通道中参数总量的比例。
10.一种电子设备,其特征在于,包括一个或多个处理器和存储器,所述存储器用于存储一个或多个程序,其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现权利要求1至8中任意一项所述的对话处理方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311425880.2A CN117453878A (zh) | 2023-10-30 | 2023-10-30 | 对话处理方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311425880.2A CN117453878A (zh) | 2023-10-30 | 2023-10-30 | 对话处理方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117453878A true CN117453878A (zh) | 2024-01-26 |
Family
ID=89592503
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311425880.2A Pending CN117453878A (zh) | 2023-10-30 | 2023-10-30 | 对话处理方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117453878A (zh) |
-
2023
- 2023-10-30 CN CN202311425880.2A patent/CN117453878A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110619385B (zh) | 基于多级剪枝的结构化网络模型压缩加速方法 | |
US10552737B2 (en) | Artificial neural network class-based pruning | |
CN111950723B (zh) | 神经网络模型训练方法、图像处理方法、装置及终端设备 | |
CN112163628A (zh) | 一种适用于嵌入式设备的改进目标实时识别网络结构的方法 | |
CN110956202B (zh) | 基于分布式学习的图像训练方法、系统、介质及智能设备 | |
CN111738098A (zh) | 一种车辆识别方法、装置、设备及存储介质 | |
EP4087239A1 (en) | Image compression method and apparatus | |
CN111368887B (zh) | 雷雨天气预测模型的训练方法及雷雨天气预测方法 | |
CN108960314B (zh) | 基于难样本的训练方法、装置及电子设备 | |
CN113159276A (zh) | 模型优化部署方法、系统、设备及存储介质 | |
CN113488023B (zh) | 一种语种识别模型构建方法、语种识别方法 | |
CN111563161B (zh) | 一种语句识别方法、语句识别装置及智能设备 | |
CN114429208A (zh) | 基于残差结构剪枝的模型压缩方法、装置、设备及介质 | |
CN113420651A (zh) | 深度卷积神经网络的轻量化方法、系统及目标检测方法 | |
CN111860405A (zh) | 图像识别模型的量化方法、装置、计算机设备及存储介质 | |
CN117453878A (zh) | 对话处理方法、装置及电子设备 | |
KR102002549B1 (ko) | 다단계 분류모델 생성 방법 및 그 장치 | |
CN112287950A (zh) | 特征提取模块压缩方法、图像处理方法、装置、介质 | |
CN115063673B (zh) | 模型压缩方法、图像处理方法、装置和云设备 | |
CN112132207A (zh) | 基于多分支特征映射目标检测神经网络构建方法 | |
CN115170902B (zh) | 图像处理模型的训练方法 | |
CN115146775B (zh) | 边缘设备推理加速方法、装置和数据处理系统 | |
CN116468102A (zh) | 刀具图像分类模型剪枝方法、装置、计算机设备 | |
CN112200275B (zh) | 人工神经网络的量化方法及装置 | |
CN114565080A (zh) | 神经网络压缩方法及装置、计算机可读介质、电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |