发明内容
因此,本发明提供的一种跨媒体数据关联分析模型训练、数据关联分析方法及系统,克服现有技术中跨模态数据关联性的预测结果与查询请求的关联数据不符,检索准确度不高的缺陷。
为达到上述目的,本发明提供如下技术方案:
第一方面,本发明实施例提供一种跨媒体数据关联分析模型训练方法,所述关联分析模型包括:特征提取层、判别模型层、生成模型层及强化学习层,所述跨媒体数据关联分析模型训练方法包括:
特征提取层基于用户查询请求实时获取待识别多媒体数据库中多模态数据的原始特征;
对判别模型层进行预训练产生奖励值;
利用生成模型层对多模态数据的原始特征进行多模态哈希特征提取,生成器根据提取的多模态哈希特征、强化学习层提供的动作值,更新多模态数据元组、强化学习层的状态值及生成模型参数,其中,强化学习层的动作值用于指导生成器选择与用户查询请求关联性最大的多模态数据元组,其是根据奖励值、状态值的更新获取,奖励值用于表征多模态数据元组和多模态流形元组之间的相似性,状态值用于表征当前的生成模型层中多模态哈希特征的输入状态;
判别模型利用多模态数据的原始特征,生成多模态流形关联图,基于多模态流形关联图产生多模态数据流形元组,提取多模态数据流形元组的多模态哈希特征;判别器根据判别模型生成的多模态哈希特征及生成模型更新后的多模态数据元组,更新判别模型参数及奖励值,直至判别网络参数值收敛。
在一实施例中,对判别模型进行预训练产生奖励值的步骤,包括:
判别模型通过对多模态数据的原始特征构建多模态流形关联图;
基于多模态流形关联图产生多模态数据流形元组;
根据多模态数据流形元组产生判别边界,判别器对判别边界进行多模态哈希特征提取,产生奖励值。
在一实施例中,产生判别边界的多模态流形元组包括:与用户请求流形相同的流形元组、与用户请求流形不同的非流形元组。
在一实施例中,采用无监督跨模态哈希学习提取多模态哈希特征。
在一实施例中,多模态数据的原始特征包括:图像原始特征、音频原始特征、文本原始特征及视频原始特征。
第二方面,本发明实施例提供一种跨媒体数据关联分析方法,包括:
获取用户的查询请求队列;
将所述用户的查询请求队列输入本发明实施例第一方面的跨媒体数据关联分析模型训练方法,生成的跨媒体数据关联分析模型的判别模型层中,得到与用户查询请求队列相关性分数大于预设期望值时的多模态数据元组。
第三方面,本发明实施例提供一种跨媒体数据关联分析模型训练系统,包括:
多模态数据的原始特征提取模块,用于基于用户查询请求,实时获取待识别多媒体数据库中多模态数据的原始特征;
判别模型预训练模块,用于对判别模型层进行预训练产生奖励值;
生成模型指导模块,利用生成模型层对多模态数据的原始特征进行多模态哈希特征提取,生成器根据提取的多模态哈希特征、强化学习层提供的动作值,更新多模态数据元组、强化学习层的状态值及生成模型参数,其中,强化学习层的动作值用于指导生成器选择与用户查询请求关联性最大的多模态数据元组,其是根据奖励值、状态值的更新获取,奖励值用于表征多模态数据元组和多模态流形元组之间的相似性,状态值用于表征当前的生成模型层中多模态哈希特征的输入状态;
判别模型输出模块,用于利用多模态数据的原始特征,生成多模态流形关联图,基于多模态流形关联图产生多模态数据流形元组,提取多模态数据流形元组的多模态哈希特征;判别器根据判别模型生成的多模态哈希特征及生成模型更新后的多模态数据元组,更新判别模型参数及奖励值,直至判别网络参数值收敛。
第四方面,本发明实施例提供一种跨媒体数据关联分析系统,包括:
用户的查询请求队列获取模块,用于获取用户的查询请求队列;
跨媒体数据关联分析模块,用于将用户的查询请求队列输入本发明实施例第一方面的所述的跨媒体数据关联分析模型训练方法,生成的跨媒体数据关联分析模型的判别模型层中,得到与用户查询请求队列相关性分数大于预设期望值时的多模态数据元组。
第五方面,本发明实施例提供一种终端,包括:至少一个处理器,以及与所述至少一个处理器通信连接的存储器,其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器执行本发明实施例第一方面所述的跨媒体数据关联分析模型训练方法或本发明实施例第二方面所述的跨媒体数据关联分析方法。
第六方面,本发明实施例提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使所述计算机执行本发明实施例第一方面所述的跨媒体数据关联分析模型训练方法或本发明实施例第二方面所述的跨媒体数据关联分析方法。
本发明技术方案,具有如下优点:
1、本发明提供的跨媒体数据关联分析模型训练、数据关联分析方法及系统,将多模态流形关联图间潜在的多模态数据流形元组考虑在内,充分挖掘跨模态数据间的关联性,在判别模型层的无监督哈希学习中将潜在的多模态数据流形元组考虑在内,充分挖掘了跨模态数据间的关联性,在生成模型中去拟合这种流形分布,生成拟合后的流形元组供判别器判断。利用判别模型层、生成模型层组成的对抗网络,提高了判别器判断多模态数据元组与用户查询请求相关性的能力,同时,结合强化学习层对解决对抗网络面对的离散数据梯度传播问题,提高了在线查询请求的检索系统的速度,显著提高了预测与查询请求的关联数据的能力。
2、本发明提供的跨媒体数据关联分析模型训练、数据关联分析方法及系统,采用无监督跨模态哈希学习提取多模态哈希特征,利用无监督跨模态哈希学习有效的保留跨模态数据间的语义关联信息,减少有限长度哈希码的储存成本,提高检索效率,利用无监督哈希学习不仅考虑了同构和异构模态间的关联性,同时考虑多模态数据间共同存在的信息,这些共存信息一般流形相同,既可以充分保留多模态数据间的语义,也能够节约人工标注的成本。
具体实施方式
下面将结合附图对本发明的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在本发明的描述中,需要说明的是,术语“中心”、“上”、“下”、“左”、“右”、“竖直”、“水平”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。
在本发明的描述中,需要说明的是,除非另有明确的规定和限定,术语“安装”、“相连”、“连接”应做广义理解,例如,可以是固定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,还可以是两个元件内部的连通,可以是无线连接,也可以是有线连接。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。
此外,下面所描述的本发明不同实施方式中所涉及的技术特征只要彼此之间未构成冲突就可以相互结合。
实施例1
本发明实施例提供的一种跨媒体数据关联分析模型训练方法,所述关联分析模型包括:特征提取层、判别模型层、生成模型层及强化学习层,利用判别模型层、生成模型层组成的对抗网络,同时结合强化学习层对解决对抗网络面对的离散数据梯度传播问题,如图1所示,包括如下步骤:
步骤S1:特征提取层基于用户查询请求实时获取待识别多媒体数据库中多模态数据的原始特征。
在本发明实施例中,如图2所示,特征提取层用于提取多模态数据的原始语义特征,特征提取基于词袋模型和全连接网络,提取的多模态数据的原始特征包括:图像原始特征、音频原始特征、文本原始特征及视频原始特征,仅以此举例,但是不以此为限,在实际应用中根据不同需求做相应选择;图像、音频、视频及文本的特征提取基于卷积神经网络,其中音频语义需要对音频数据进行声学特征提取,视频语义需要获取视频的关键帧再进行底层特征提取,仅以此举例,但是不以此为限,在实际应用中根据实际需求做相应处理。
步骤S2:对判别模型层进行预训练产生奖励值。
在本发明实施例中,对判别模型进行预训练产生奖励值的步骤,包括:判别模型通过对多模态数据的原始特征构建多模态流形关联图,基于多模态流形关联图产生多模态数据流形元组。
判别模型首先接受来自特征提取层的多模态数据的原始特征和查询请求,然后对多模态数据的原始特征利用K近邻算法构造单模态数据关联图,模态间距离的计算为测地距离,测地距离是根据最短路径算法计算出模态间的最短路径,而不是欧拉距离(模态间的直线距离)。虽然测地距离大于等于欧拉距离,但是这更能反映模态间的真实相关性。先计算所有数据点的最近邻,把数据点到最近邻的距离设置为1,数据点到非K近邻的点设置为无穷大,然后通过最短路径算法(例如Dijkstra算法)更新数据点之间的距离,使得特征数据点之间的距离为测地距离,有效的测量了单模态数据点间的距离,保留了语义信息。
获得单模态关联图后,判别模型层对于所有的单模态关联图,如果两个不同模态的数据节点存在共同信息,则把这两个数据节点融合,融合前数据节点的点边关系依然存在,所有的数据点的根据共存信息融合完成时,形成了多模态流形关联图,多模态流形关联图能够跨不同模态捕获潜在的流形结构,所以模态不同但流形相同的数据有极低的汉明距离,产生与查询请求关联的多模态数据流形元组能够引导判别模型的训练,能够提高查询关联性最高数据的准确性和速度。
根据多模态数据流形元组产生判别边界,包括:与用户请求流形相同的流形元组、与用户请求流形不同的非流形元组,判别器对判别边界进行多模态哈希特征提取,其多模态哈希特征提取在哈希学习层获取二进制哈希码,将选取的两个元组都要获得二进制哈希码再输入判别器,产生奖励值,使得判别模型有基本的判别能力。
在本发明实施例中,判别模型根据用户查询请求和多模态数据元组的相关系分数来表示,表示公式为:
fφ(XG,q)=max(0,m+||h(q)-h(XM)||2-||h(q)-h(XG)||2)
其中φ为判别模型的参数,||·||2是用户查询请求q与多模态数据元组XG或多模态数据流形元组XM内各模态数据的平均距离,m为防止产生非正数分数的偏置值,(q)、(XM)、(XG)分别为查询请求q,多模态数据流形元组XM、多模态数据元组XG的哈希码实数值。根据fφ(XG,q)表达式可知,fφ(XG,q)值越小,查询请求和多模态数据元组的关联性越大。
判别模型层中的判别器是一个全连接的深度神经网络,其目的是区分多模态数据元组和多模态数据流形元组哪才是与查询请求关联性最大,根据相关性分数确定,相关性分数转化为损失函数,通过神经网络后向传播,更新判别模型参数。由于本发明中生成器输入的多模态数据流形元组是在多模态流形关联图选取的,与查询请求的关联性最大,所以判别模型训练初期能够清晰的区分多模态数据元组和多模态数据流形元组,当初始化判别模型后,判别器还没有判断多模态数据元组和多模态数据流形元组的能力,所以要进行预训练。
判别模型预训练时,判别器在多模态流形关联图中选择一个与用户查询请求q流形相同的流形元组,以及流形不同的元组进行判别模型的预训练,以确定判别边界,计算训练过程中损失函数的计算公式为:
Loss(q,XM,XN)=max(0,m+||h(q)-h(XM)||2-||h(q)-h(XN)||2)
其中,XN表示多模态流形关联图中与查询请求不相关的元组,可以发现Loss(q,XM,XN)=fφ(XN,q),训练过程中利用损失函数在判别网络反向传播更新判别模型参数φ,判别模型的预训练中,输入与查询请求的关联性最大流形元组和关联性低的非流形元组,判别器对多模态流形元组给的分数接近于0,但对多模态非流形元组给的分数接近于1,因此训练了判别模型的判定边界,产生初始的奖励值,使得判别模型有基本的判别能力。
步骤S3:利用生成模型层对多模态数据的原始特征进行多模态哈希特征提取,生成器根据提取的多模态哈希特征、强化学习层提供的动作值,更新多模态数据元组、强化学习层的状态值及生成模型参数,其中,强化学习层的动作值用于指导生成器选择与用户查询请求关联性最大的多模态数据元组,其是根据奖励值、状态值的更新获取,奖励值用于表征多模态数据元组和多模态流形元组之间的相似性,状态值用于表征当前的生成模型层中多模态哈希特征的输入状态;
在本发明实施例中,生成器是一个全连接的深度神经网络,对多模态数据的原始特征进行,即向生成器输入在哈希学习层提取多模态哈希特征的多模态数据哈希码和用户的查询请求时,生成器根据强化层中智能体的动作值预测出与用户查询请求相似性最大的多模态数据哈希码,根据预测结果组成生成多模态数据元组供判别器判断。
在本发明实施例中,生成器根据用户查询请求和判别模型反馈的奖励,值,预测出和用户查询请求相关性最大的多模态数据元组,根据生成模型参数相关参数计算,计算生成模型预测概率的公式为:
其中,θ为生成网络参数,i表示单模态数据索引值,
表示根据用户查询请求q和奖励值r生成的单模态数据,组合全部生成的单模态数据即为数据多媒体元组。
在本发明实施例中,当训练生成模型时,判别模型固定,可以用判别模型给出的奖励值
强化学习来指导生成模型训练,奖励值为
计算生成模型进行训练的公式为:
其中,θ
*表示生成模型的生成网络参数为最优参数,φ
*为判别模型已训练的判别网络参数值为最优参数值,j为查找请求索引值,
为根据用户查询请求q
j预测数据多模态数据元组X
G的可能性。当表达式最小化训练时,
越大越好,反映了生成模型的训练目的,最大化查询请求与数据多媒体元组的相关性,最小化判别模型的判断能力。
步骤S4:判别模型利用多模态数据的原始特征,生成多模态流形关联图,基于多模态流形关联图产生多模态数据流形元组,提取多模态数据流形元组的多模态哈希特征;判别器根据判别模型生成的多模态哈希特征及生成模型更新后的多模态数据元组,更新判别模型参数及奖励值,直至判别网络参数值收敛。
在本发明实施例中,判别器是一个全连接的深度神经网络,其目的是区分与数据多媒体元组和流形元组哪个才是与查询请求关联性最大的元组,更新判别模型参数及奖励值,直至判别网络参数值收敛。当训练判别模型时,生成模型已经生成了多模态元组,同时判别器会在关联图中选择与查询请求相似性最大的流形元组,然后判别器判断多媒体数据元组和多媒体数据流形元组是否与查询请求真实相关。
判别模型可以利用查询请求q和多模态数据元组XG的相关性分数fφ(XG,q)来预测数据多媒体元组的概率,概率计算公式为:
其中,f
φ(X
G,q)来表示判别模型,在判别模型的神经网络中,输出层的激活函数为
即可预测出与查询请求q相关性最大的元组。
在本发明实施例中,训练判别模型时,生成模型固定,计算训练判别模型的判别网络参数值公式为:
其中,
表示x是判别器根据
在多模态流形关联图上选择的单模态数据,所有选择的单模态数据
组合为多模态数据流形元组X
M。
表示在多模态流形关联图上选择与查询q
j相关联单模态数据
的概率,计算公式为:
判别模型的最优值为φ
*,即为判别网络参数值,根据判别模型参数计算,训练结束时,判别网络参数值收敛,所以输出到强化学习层的奖励值为
判别模型训练表达式最大化训练时,
越大越好,这也反映了判别模型的目的,最大化流形元组与用户查询请求的相关性,即最大化判断模型的判断能力。
本发明实施例中提供的跨媒体数据关联分析模型训练方法,将多模态流形关联图间潜在的多模态数据流形元组考虑在内,充分挖掘了跨模态数据间的关联性,利用判别模型层、生成模型层组成的对抗网络,同时,结合强化学习层对解决对抗网络面对的离散数据梯度传播问题;通过对抗网络提高了判别器判断多模态数据元组与用户查询请求相关性的能力,显著提高了预测与用户查询请求的关联数据的能力,对于在线用户查询请求的检索系统可以显著提高其检索速度。
本发明实施例还提供一种跨媒体数据关联分析方法,包括:获取用户查询请求队列;将所述用户查询请求队列输入上述的跨媒体数据关联分析模型训练方法,生成的跨媒体数据关联分析模型的判别模型层中,得到与用户查询请求队列相关性分数大于预设期望值时的多模态数据元组。
本发明实施例中提供的跨媒体数据关联分析方法,通过对抗网络提高了判别器判断多模态数据元组与用户查询请求队列相关性的能力,显著提高了预测与用户查询请求的关联数据的能力,对于在线用户查询请求的检索系统可以显著提高其检索速度。
实施例2
本发明实施例提供一种跨媒体数据关联分析模型训练系统,如图3所示,包括:
多模态数据的原始特征提取模块1,用于基于用户查询请求,实时获取待识别多媒体数据库中多模态数据的原始特征;此模块执行实施例1中的步骤S1所描述的方法,在此不再赘述。
判别模型预训练模块2,用于对判别模型层进行预训练产生奖励值;此模块执行实施例1中的步骤S2所描述的方法,在此不再赘述。
生成模型指导模块3,利用生成模型层对多模态数据的原始特征进行多模态哈希特征提取,生成器根据提取的多模态哈希特征、强化学习层提供的动作值,更新多模态数据元组、强化学习层的状态值及生成模型参数,其中,强化学习层的动作值用于指导生成器选择与用户查询请求关联性最大的多模态数据元组,其是根据奖励值、状态值的更新获取,奖励值用于表征多模态数据元组和多模态流形元组之间的相似性,状态值用于表征当前的生成模型层中多模态哈希特征的输入状态;此模块执行实施例1中的步骤S3所描述的方法,在此不再赘述。
判别模型输出模块4,用于利用多模态数据的原始特征,生成多模态流形关联图,基于多模态流形关联图产生多模态数据流形元组,提取多模态数据流形元组的多模态哈希特征;判别器根据判别模型生成的多模态哈希特征及生成模型更新后的多模态数据元组,更新判别模型参数及奖励值,直至判别网络参数值收敛;此模块执行实施例1中的步骤S4所描述的方法,在此不再赘述。
本发明实施例提供一种跨媒体数据关联分析模型训练系统,将多模态流形关联图间潜在的多模态数据流形元组考虑在内,充分挖掘了跨模态数据间的关联性,利用判别模型层、生成模型层组成的对抗网络,同时结合强化学习层对解决对抗网络面对的离散数据梯度传播问题;通过对抗网络提高了判别器判断多模态数据元组与用户查询请求相关性的能力,显著提高了预测与用户查询请求的关联数据的能力,对于在线用户查询请求的检索系统可以显著提高其检索速度。
本发明实施例还提供一种跨媒体数据关联分析系统,用户的查询请求队列获取模块,用于获取用户的查询请求队列;跨媒体数据关联分析模块,用于将用户的查询请求队列输入实施例1中的跨媒体数据关联分析模型训练方法,生成的跨媒体数据关联分析模型的判别模型层中,得到与用户查询请求队列相关性分数大于预设期望值时的多模态数据元组。
本发明实施例中提供的跨媒体数据关联分析系统,通过对抗网络提高了判别器判断多模态数据元组与用户查询请求队列相关性的能力,显著提高了预测与用户查询请求的关联数据的能力,对于在线用户查询请求的检索系统可以显著提高其检索速度。
实施例3
本发明实施例提供一种终端,如图4所示,包括:至少一个处理器401,例如CPU(Central Processing Unit,中央处理器),至少一个通信接口403,存储器404,至少一个通信总线402。其中,通信总线402用于实现这些组件之间的连接通信。其中,通信接口403可以包括显示屏(Display)、键盘(Keyboard),可选通信接口403还可以包括标准的有线接口、无线接口。存储器404可以是高速RAM存储器(Ramdom Access Memory,易挥发性随机存取存储器),也可以是非不稳定的存储器(non-volatile memory),例如至少一个磁盘存储器。存储器404可选的还可以是至少一个位于远离前述处理器401的存储装置。其中处理器401可以执行实施例1中的跨媒体数据关联分析模型训练方法或实施例1中的跨媒体数据关联分析方法。存储器404中存储一组程序代码,且处理器401调用存储器404中存储的程序代码,以用于执行实施例1中的跨媒体数据关联分析模型训练方法或实施例1中的跨媒体数据关联分析方法。
其中,通信总线402可以是外设部件互连标准(peripheral componentinterconnect,简称PCI)总线或扩展工业标准结构(extended industry standardarchitecture,简称EISA)总线等。通信总线402可以分为地址总线、数据总线、控制总线等。为便于表示,图4中仅用一条线表示,但并不表示仅有一根总线或一种类型的总线。其中,存储器404可以包括易失性存储器(英文:volatile memory),例如随机存取存储器(英文:random-access memory,缩写:RAM);存储器也可以包括非易失性存储器(英文:non-volatile memory),例如快闪存储器(英文:flash memory),硬盘(英文:hard disk drive,缩写:HDD)或固降硬盘(英文:solid-state drive,缩写:SSD);存储器404还可以包括上述种类的存储器的组合。其中,处理器401可以是中央处理器(英文:central processingunit,缩写:CPU),网络处理器(英文:network processor,缩写:NP)或者CPU和NP的组合。
其中,存储器404可以包括易失性存储器(英文:volatile memory),例如随机存取存储器(英文:random-access memory,缩写:RAM);存储器也可以包括非易失性存储器(英文:non-volatile memory),例如快闪存储器(英文:flash memory),硬盘(英文:hard diskdrive,缩写:HDD)或固态硬盘(英文:solid-state drive,缩写:SSD);存储器404还可以包括上述种类的存储器的组合。
其中,处理器401可以是中央处理器(英文:central processing unit,缩写:CPU),网络处理器(英文:network processor,缩写:NP)或者CPU和NP的组合。
其中,处理器401还可以进一步包括硬件芯片。上述硬件芯片可以是专用集成电路(英文:application-specific integrated circuit,缩写:ASIC),可编程逻辑器件(英文:programmable logic device,缩写:PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(英文:complex programmable logic device,缩写:CPLD),现场可编程逻辑门阵列(英文:field-programmable gate array,缩写:FPGA),通用阵列逻辑(英文:generic arraylogic,缩写:GAL)或其任意组合。
可选地,存储器404还用于存储程序指令。处理器401可以调用程序指令,实现如本申请执行实施例1中的跨媒体数据关联分析模型训练方法或实施例1中的跨媒体数据关联分析方法。
本发明实施例还提供一种计算机可读存储介质,计算机可读存储介质上存储有计算机可执行指令,该计算机可执行指令可执行实施例1中的跨媒体数据关联分析模型训练方法或实施例1中的跨媒体数据关联分析方法。其中,所述存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)、随机存储记忆体(Random Access Memory,RAM)、快闪存储器(Flash Memory)、硬盘(Hard Disk Drive,缩写:HDD)或固态硬盘(Solid-StateDrive,SSD)等;所述存储介质还可以包括上述种类的存储器的组合。
显然,上述实施例仅仅是为清楚地说明所作的举例,而并非对实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。而由此所引申出的显而易见的变化或变动仍处于本发明创造的保护范围之中。