CN112633503B - 基于神经网络的工具变量生成与手写数字识别方法及装置 - Google Patents
基于神经网络的工具变量生成与手写数字识别方法及装置 Download PDFInfo
- Publication number
- CN112633503B CN112633503B CN202011493947.2A CN202011493947A CN112633503B CN 112633503 B CN112633503 B CN 112633503B CN 202011493947 A CN202011493947 A CN 202011493947A CN 112633503 B CN112633503 B CN 112633503B
- Authority
- CN
- China
- Prior art keywords
- network
- variable
- phi
- tool
- constraint
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- Data Mining & Analysis (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于神经网络的工具变量生成与反事实推理方法及装置。针对之前的基于工具变量的反事实推理(如手写数字识别)方法需要预先定义和可获取的工具变量的问题,本发明直接从可观测变量中学习和解耦出工具变量,大大提升了因果推断效率,节省了时间和成本。本发明首次自动地从可观测变量中提取出工具变量,在算法和运用上有独创性和独特性。将本发明应用于现有的基于工具变量的反事实预测方法,与使用真实工具变量的方法相比性能因果推断有明显提升。本发明着重于从可观测变量中解耦出工具变量的表征,解决了基于工具变量的反事实预测技术需要预先使用先验知识和高昂成本获取工具变量数据的难题,提升了手写数字识别等领域精度。
Description
技术领域
本发明涉及因果推断领域,尤其涉及一种自动的工具变量解耦方法,实现可直接从可观测变量中提取出工具变量的反事实预测方法,从而提升手写数字识别的效率和精度。
背景技术
因果推断致力于对干预产生的反事实结果进行估计,辅助决策者进行选择,以达到使得结果最优化的目标。因果推断的黄金方法是使用随机控制实验随机分配干预值进行因果推断,但是此类方法的成本过高甚至无法实现。一些方法通过加权、匹配的方式来对影响因果推断的混淆变量进行约束的目的,但是此类方法仅仅只能在混淆完全可观测的情形下使用,当混淆存在不可观测的情况下该类方法仍然存在较大缺陷。
工具变量提出用来解决不可观测的混淆问题,它和干预变量相关同时和结果变量条件独立。当下的基于工具变量的因果推断方法都需要一个预先定义的工具变量,但是这在现实情况下往往并不实用。如何直接从所有可观测变量中解耦出工具变量,并自动地进行因果推断是一个亟待解决的问题。
手写数字识别作为因果推断的一个典型应用领域,其存在同样的技术问题。针对手写数字的识别,如何通过自动的工具变量解耦,获取仅仅和标签条件相关的工具变量信息,从而辅助手写数字识别以达到最大的精度,是本发明需要解决的主要技术问题。
发明内容
本发明的目的是解决当下基于工具变量的因果推断技术手写数字识别需要预先定义的工具变量这个问题,提出一种基于神经网络的工具变量生成与反事实推理方法及装置,它能够直接从可观测的变量中解耦出工具变量,实现自动工具变量解耦和因果推断从而提升手写数字识别的效率和精度。
本发明具体采用的技术方案如下:
一种基于神经网络的工具变量生成与反事实推理方法,其包括如下步骤:
S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集;
S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习;
S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化;
S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型;
S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写数字的识别结果。
作为优选,步骤S1中,所述反事实预测数据集表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,其中样本的可观测变量以该样本对应的手写数字图片本身代替,N为样本总数。
进一步的,所述的步骤S2具体包括以下子步骤:
S201:基于神经网络构建以可观测变量V为输入以工具变量Z为输出的第一表征模型φZ(·),同时基于神经网络构建以可观测变量V为输入以其他协变量C为输出的第二表征模型φC(·);
S202:基于神经网络构建以工具变量Z为输入以干预变量X为输出的第一约束网络fZX(·),设定第一约束网络的损失函数为:
其中:为第一约束网络fZX(·)中以φZ(vi)为输入去预测xi时得到的变分分布;φZ(vi)为第一表征模型φZ(·)中输入vi时得到的输出结果;log表示对数似然函数;
另外,针对第一约束网络设定互信息最大化损失函数为:
S203:基于神经网络构建以工具变量Z为输入以结果变量Y为输出的第二约束网络fZY(·,设定第二约束网络的损失函数为:
其中:为第二约束网络fZY(·中以φz(vi)为输入去预测yi时得到的变分分布;
另外,针对第二约束网络设定互信息最大化损失函数为:
其中:ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重;
S204:基于神经网络构建以其他协变量C为输入以干预变量X为输出的第三约束网络fCX(·,设定第三约束网络的损失函数为:
其中:为第三约束网络fCX(·中以φC(vi)为输入去预测xi时得到的变分分布;φC(vi)表示第二表征模型φC(·中输入vi时得到的输出结果;
另外,针对第三约束网络设定互信息最大化损失函数为:
S205:基于神经网络构建以其他协变量C为输入以结果变量Y为输出的第四约束网络fCY(·,设定第四约束网络的损失函数为:
其中:为第四约束网络fCY(·中以φC(vi)为输入去预测yi时得到的变分分布;
另外,针对第四约束网络设定互信息最大化损失函数为:
S206:基于神经网络构建以工具变量Z为输入以其他协变量C为输出的第五约束网络fZC(·,设定第五约束网络的损失函数为:
其中:为第五约束网络fZC(·中以φZ(vi)为输入去预测φC(vi)时得到的变分分布;
另外,针对第五约束网络设定互信息最大化损失函数为:
进一步的,步骤S203中,所述权重ωij通过RBF核函数计算,公式如下:
其中σ是一个用于调节的超参数。
进一步的,所述的步骤S3具体包括以下子步骤:
S301:基于神经网络构建以工具变量Z的表征φZ(vi)和其他协变量C的表征φC(vi)为输入以干预变量X为输出的第一阶段回归网络fX(·,并设定第一阶段回归网络的损失函数为:
其中l(…)表示计算平方误差;
S302:基于神经网络构建以和其他协变量C的表征φC(vi)为输入以结果变量Y为输出的第二阶段回归网络fY(·,并设定第二阶段回归网络的损失函数为:
其中:femb(·为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·输出的干预变量X估计值,/>
进一步的,所述的步骤S4具体包括以下步骤:
S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:
利用所述反事实预测数据集对五个约束网络进行训练,通过最小化所述综合损失函数分别优化各约束网络中的网络参数;
S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:
其中:α、β、∈、η是权重超参数;
利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述综合互信息损失函数分别优化两个表征模型中的网络参数;
S403:利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;
S404:利用所述反事实预测数据集继续对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;
S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参数,直至迭代终止,得到参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)。
进一步的,所述的步骤S5具体包括以下步骤:
S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)中,得到工具变量Z的表征和其他协变量C的表征;
S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。
进一步的,所述基于工具变量的反事实预测模型为2SLS、Deep IV、Kernel IV或DeepGMM模型。
另一方面,本发明提供了一种基于深度网络的工具变量解耦与手写数字识别装置,其包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于当执行所述计算机程序时,实现如前述任一方案所述的基于神经网络的工具变量生成与反事实推理方法。
本发明使用表征学习技术进行自动的工具变量解耦。针对之前的基于工具变量的反事实预测方法需要预先定义和可获取的工具变量的问题,本发明直接从可观测变量中学习和解耦出工具变量,大大提升了因果推断的效率,节省了大量时间和成本。本发明首次自动地从可观测变量中提取出工具变量,在算法和运用上有自己的独创性和独特性。将本发明应用于现有的基于工具变量的反事实预测方法,并自动地进行它的性能与假设使用真实工具变量的该方法相比因果推断,可以达到相当、甚至更好的性能表现。
附图说明
图1为基于神经网络的工具变量生成与反事实推理方法流程图
图2为基于神经网络的工具变量生成与反事实推理结构示意图。
具体实施方式
下面结合附图和具体实施方式对本发明做进一步阐述和说明。
如图1所示,一种基于神经网络的工具变量生成与反事实推理方法,该实施方式中的反事实推理用于实现手写数字识别,其包括如下步骤:
S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集;
S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习;
S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化;
S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型;
S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写数字的识别结果。
在上述S1~S5步骤中,具体实现方式如下:
本发明中步骤S1具体如下:每一组手写数字图片及其对应的数字标签作为一组样本,构建成反事实预测数据集,表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,N为样本总数。其中对于手写数字图片而言,由于其本身难以提取可观测变量,因此本发明中实际将样本的可观测变量vi也直接以该样本对应的手写数字图片本身代替,即vi=xi。
参见图2所示,在S1中,假设干预变量X(手写数字图片)和结果变量Y(手写数字图片对应的标签)之间的数据关系为:
Y=g(X)+e
其中g(·)是一个未知的因果反馈函数(结构函数),它可能是非线性的连续函数。e是一个误差项,它包含了同时和X、Y都有关的不可观测的混淆。其中e满足零期望和有限方差的要求,即且/>这里允许e和X相关,即/>使得X成为了一个内生性变量同时/>
工具变量Z用于解决内生性干预变量问题,它需要满足干预相关和结果排除两个条件。干预相关指X直接和Z相关,即使得结果排除指Z对Y仅仅只能通过X施加影响,即使得/>除此之外,Z应该是无混淆的,即需要使得/>基于工具变量的反事实预测的目的就是对真实的反馈函数进行预测。
如果存在其他外生性的变量C,可以直接将其合并入工具变量和干预变量,即X=(X′,C)和Z=(Z′,C),其中X′和Z′是真实的干预变量和工具变量。由于C是严格外生的,即它和无关观测的误差e无关,因此这样的操作并不会对结果产生影响。
假设可获取的可观测变量是V、干预变量是X、结果变量是Y,可获取N个样本,即本发明的目标就是使用这N个样本,获取工具变量Z的解耦表征。
本发明中步骤S2具体包括以下子步骤:
S201:基于神经网络构建第一表征模型φZ(·),其中第一表征模型φZ(·)以可观测变量V为输入,以工具变量Z为输出。同样的,基于神经网络构建第二表征模型φC(·),其中第二表征模型φC(·)以可观测变量V为输入,以其他协变量C为输出。
本步骤中,使用神经网络构建工具变量Z和其他协变量C的表征,即φZ(·)和φC(·),使得φZ(·)和X相关、和Y关于X条件独立,也使得φC(·)同时和X、Y相关。同时可以通过使得φZ(·)和φC(·)尽可能独立,来对进入Z和C的信息进行正则约束。
S202:基于神经网络构建第一约束网络fZX(·),其中第一约束网络fZX(·)以工具变量Z为输入,以干预变量X为输出,同时设定第一约束网络的损失函数为:
其中:为第一约束网络fZX(·)中以φZ(vi)为输入去预测xi时得到的变分分布;φZ(vi)为第一表征模型φZ(·)中输入vi时得到的输出结果;log表示对数似然函数。
设置本步骤是由于首先工具变量Z需要满足干预相关条件,即 因此需要鼓励可观测变量V中和X相关的信息能够进入Z的表征中。由于互信息需要使用的是条件分布信息,而数据是基于样本的,因此首先使用变分分布/>近似真实的条件分布/>后续通过最小化损失函数/>就可以获得最优的变分近似。
另外,为了增加Z和X的关联性,针对第一约束网络设定互信息最大化损失函数为:
其中是正样本对(vi,xi)的条件似然,/>是负样本对(vi,xj)的条件似然。后续通过最小化/>即可增大正负样本对之间的差异,以此来优化工具变量的表征φZ(V)。
S203:基于神经网络构建第二约束网络fZY(·),其中第二约束网络fZY(·)以工具变量Z为输入,以结果变量Y为输出。同时设定第二约束网络的损失函数为:
其中:为第二约束网络fZY(·)中以φZ(vi)为输入去预测yi时得到的变分分布。
工具变量Z还需要满足结果排除条件,即因此需要对Z和Y的条件互信息进行最小化。由于X是连续的变量,因此此处通过让正样本和负样本的似然期望相近来使得Z和Y条件独立。
针对第二约束网络设定互信息最大化损失函数为:
其中:其中是正样本对(vi,yi)的条件似然,是负样本对(vi,yj)的条件似然;ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重。此处权重ωij通过RBF核函数计算,公式如下:
其中σ是一个用于调节的超参数。如果正负样本的xi和xj相接近,则它们的权重增大,也就是本发明着重于解决具有相近X的样本对。
S204:协变量C的表征φC(V)需要首先和X相关,因此基于神经网络构建第三约束网络fCX(·),其中第三约束网络fCX(·)以其他协变量C为输入,以干预变量X为输出。同时设定第三约束网络的损失函数为:
其中:为第三约束网络fCX(·)中以φC(vi)为输入去预测xi时得到的变分分布;φC(vi)表示第二表征模型φC(·)中输入vi时得到的输出结果;
另外,针对第三约束网络设定互信息最大化损失函数为:
S205:同时需要使得协变量C的表征φC(V)需要和Y相关,因此基于神经网络构建第四约束网络fCY(·,其中第四约束网络fCY(·以其他协变量C为输入,以结果变量Y为输出。同时设定第四约束网络的损失函数为:
其中:为第四约束网络fCY(…中以φC(vi)为输入去预测yi时得到的变分分布;
另外,针对第四约束网络设定互信息最大化损失函数为:
S206:基于神经网络构建第五约束网络fZC(·,其中第五约束网络fZC(·以工具变量Z为输入,以其他协变量C为输出。同时设定第五约束网络的损失函数为:
其中:为第五约束网络fZC(·中以φZ(vi)为输入去预测φC(vi)时得到的变分分布。
本步骤中,如果协变量C的信息进入工具变量Z中,会破坏Z的结果排除条件。同时如果Z的信息进入C中,则会对反事实预测带来一定的偏差。因此通过最小化Z和C的互信息来对它们进行约束,针对第五约束网络设定互信息最大化损失函数为:
在本发明中,步骤S3具体包括以下子步骤:
S301:第一阶段(干预)首先使用工具变量Z和其他协变量C的表征去回归干预变量X。具体而言,基于神经网络构建第一阶段回归网络fX(·,其中第一阶段回归网络fX(·以工具变量Z的表征φZ(vi)和其他协变量C的表征φC(vi)为输入,以干预变量X为输出。同时,设定第一阶段回归网络的损失函数为:
其中l(·)表示计算平方误差;
S302:第二阶段(结果)进一步使用预测出来的来回归Y。具体而言,基于神经网络构建第二阶段回归网络fY(·),其中第二阶段回归网络fY(·)以/>和其他协变量C的表征φC(vi)为输入,以结果变量Y为输出。同时,设定第二阶段回归网络的损失函数为:
其中:femb(·)为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·)输出的干预变量X估计值,/>
在本发明中,步骤S4具体包括以下步骤:
S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:
利用S1中的反事实预测数据集对五个约束网络进行训练,通过最小化综合损失函数分别优化各约束网络中的网络参数。该损失函数的各个部分会优化各自的参数,互相之间不会干扰,因此不需要超参数。
S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:
其中:α、β、∈、η是权重超参数。
利用S1中的反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化综合互信息损失函数分别优化第一表征模型φZ(·)和第二表征模型φC(·)中的网络参数。
S403:利用S1中的反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;
S404:利用S1中的反事实预测数据集继续对第一表征模型φZ(·和第二表征模型φC(·进行训练,通过最小化第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;
S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参数,直至迭代终止,得到参数优化后的第一表征模型φ′Z(·和第二表征模型φ′C(·。
当上述解耦模型完成了优化,可以直接将其作为基于工具变量的方法的输入,将其用于反事实预测,获得更准确的反事实预测精度。
在本发明中,步骤S5具体包括以下步骤:
S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·和第二表征模型φ′C(·中,得到工具变量Z的表征和其他协变量C的表征;
S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。
在本发明中,基于工具变量的反事实预测模型可以是任何能够通过工具变量实现预测的模型,例如可选的为2SLS、Deep IV、Kernel IV或DeepGMM模型。
上述方法的各步骤中的具体参数可以根据实际进行调整。
本发明的关键技术在于基于表征学习进行自动的工具变量解耦,获得有效的工具变量的表征,并将其应用于基于工具变量的反事实预测方法,使得这些方法可以在无法获得工具变量的场景中得以较好的应用,达到相当甚至更好的反事实预测精度。
另外,在另一实施例中,本发明提供了一种基于神经网络的工具变量生成与反事实推理方法及装置,它包括存储器和处理器;
其中存储器,用于存储计算机程序;
处理器,用于当执行所述计算机程序时,实现前述实施例中的基于神经网络的工具变量生成与反事实推理方法及装置。
上述S1~S5的方法具体可以通过计算机程序来实现,举例而言,计算机程序中的模块可以按照功能划分如下:
采样模块,对干干预变量、结果变量、可观测变量进行采样,约束可观测变量严格外生;
互信息约束模块,对工具变量和协变量的表征通过互信息约束它们与干预变量和结果变量之间的关系;
两阶段反事实预测模块,分别对干预变量和结果变量进行预测,两次预测的偏差用于进一步优化初步学习到的表征;
反事实预测模块,交替优化表征,应用学习到的表征到现有的反事实预测方法进行反事实预测,提升反事实预测的精度。
其中,采样模块包括:
干预变量采样模块,用于从原始数据中采样干预变量,对其进行控制来进行反事实推断;
结果变量采样模块,用于从原始数据中采样结果变量,结果变量是对干预变量变化的反映;
可观测变量采样模块,可观测变量反映每个样本的特征,我们使得它严格外生,用于工具变量的解耦。
其中,互信息约束模块包括:
工具变量约束模块,对工具变量的表征进行互信息约束,使得它与干预变量相关,同时和结果变量条件独立;
协变量约束模块,对协变量的表征进行互信息约束,使得它和干预变量、结果变量都相关;
表征正交模块,对工具变量和协变量的表征进行正交约束,使得工具变量和协变量的表征尽可能独立。
其中,两阶段反事实预测模块包括:
干预变量预测模块,将初步解耦到的工具变量、协变量表征用于干预变量的预测,获取干预变量回归值;
结果变量预测模块,将干预变量回归值和协变量用于结果变量的预测,得到反事实结果预测值。
其中,反事实预测模块包括:
表征优化模块,综合以上的互信息约束模块和两阶段反事实预测模块,通过交替优化的方式获取最优的表征;
反事实预测模块,将得到的最优表征用于现有的方法中进行反事实预测,提升反事实预测的精度。
当然,以上具体的功能模块的设计可以根据实际需要调整,以满足功能实现为准。
需要注意的是,存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital Signal Processing,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。当然,还装置中还应当具有实现程序运行的必要组件,例如电源、通信总线等等。
在另一实施例中,本发明提供了一种计算机可读存储介质,该存储介质上存储有计算机程序,当所述计算机程序被处理器执行时,实现前述实施例中的基于神经网络的工具变量生成与反事实推理方法及装置。
下面利用前述的基于神经网络的工具变量生成与反事实推理方法及装置,通过一个具体的应用实例来展示本发明分类方法的具体效果。具体的方法步骤如前所述,不再赘述,下面仅展示其具体效果。
实施例
本实施例在手写数字图片和仿真数据集上进行测试。该方法主要针对手写数字图片和对应的标签之间的关系,通过自动的工具变量解耦,获取仅仅和标签条件相关的工具变量信息,从而辅助手写数字识别以达到最大的精度。
我们给定手写数字图片X,手写数字图片对应的标签Y之间的关系为:
Y=g(X)+e+σ
其中为不可观测的混淆变量,/>为误差项,g是手写数字图片X和手写数字图片对应的标签Y之间真实的潜在关系(非线性映射函数),此处我们假设他们之间的关系为g(X)=-X。同时手写数字图片受到潜在的工具变量Z~Unif([-3,3]2)、不可观测的混淆变量e和误差项/>的影响:
X=Z1+e+γ
算法训练和测试中,分别采样500个样本用于训练、验证、测试。每个样本都包含了手写数字图片、对应的标签和其他相关的混合数据。为了展示该方法解耦出的工具变量的性能,使用了辅助的基于工具变量的反事实预测模型来进行手写数字图片预测。本实施例中所采用的基于工具变量的反事实预测模型包括五种,分别为2SLS(van)、2SLS(poly)、2SLS(NN)、DeepIV、KernelIV、DeepGMM。这些模型算法均属于现有技术,不再赘述。若需了解其具体的实现算法,可实现参见以下现有技术文献:
2SLS(van):Angrist J D,Pischke J S.Mostly harmless econometrics:Anempiricist's companion[M].Princeton university press,2008.
2SLS(poly):Darolles S,Fan Y,Florens J P,et al.Nonparametricinstrumental regression[J].Econometrica,2011,79(5):1541-1565.
2SLS(NN):Darolles S,Fan Y,Florens J P,et al.Nonparametricinstrumental regression[J].Econometrica,2011,79(5):1541-1565.
DeepIV:Hartford J,Lewis G,Leyton-Brown K,et al.Deep IV:A flexibleapproach for counterfactual prediction[C]//International Conference onMachine Learning.2017:1414-1423.
KernelIV:Singh R,Sahani M,Gretton A.Kernel instrumental variableregression[C]//Advances in Neural Information Processing Systems.2019:4593-4605.
DeepGMM:Bennett A,Kallus N,Schnabel T.Deep generalized method ofmoments for instrumental variable analysis[C]//Advances in Neural InformationProcessing Systems.2019:3564-3574.
为了客观评估本算法的性能,使用手写数字图片的预测结果与真实的结果的均方误差(MSE)对该方法进行评价。
所得实验结果如表1所示,结果表明,本发明的方法具有极高的手写数字图片识别精度,从而能够显著提升手写数字识别的效率和准确性。
表1不同辅助方法下手写数字识别的均方误差及其标准差
2SLS(van) | 2SLS(poly) | 2SLS(NN) | DeepIV | KernelIV | DeepGMM |
0.00(0.00) | 0.00(0.00) | 0.14(0.03) | 0.09(0.03) | 0.11(0.04) | 0.01(0.01) |
以上所述的实施例只是本发明的一种较佳的方案,然其并非用以限制本发明。有关技术领域的普通技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变化和变型。因此凡采取等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。
Claims (5)
1.一种基于神经网络的工具变量生成与手写数字识别方法,其特征在于,包括如下步骤:
S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集,所述反事实预测数据集表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,其中样本的可观测变量以该样本对应的手写数字图片本身代替,N为样本总数;
S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习,具体步骤如下:
S201:基于神经网络构建以可观测变量V为输入以工具变量Z为输出的第一表征模型φZ(·),同时基于神经网络构建以可观测变量V为输入以其他协变量C为输出的第二表征模型φC(·);
S202:基于神经网络构建以工具变量Z为输入以干预变量X为输出的第一约束网络fZX(·),设定第一约束网络的损失函数为:
其中:为第一约束网络fZX(·)中以φZ(vi)为输入去预测xi时得到的变分分布;φZ(vi)为第一表征模型φZ(·)中输入vi时得到的输出结果;log表示对数似然函数;
另外,针对第一约束网络设定互信息最大化损失函数为:
S203:基于神经网络构建以工具变量Z为输入以结果变量Y为输出的第二约束网络fZY(·),设定第二约束网络的损失函数为:
其中:为第二约束网络fZY(·)中以φZ(vi)为输入去预测yi时得到的变分分布;
另外,针对第二约束网络设定互信息最大化损失函数为:
其中:ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重;
S204:基于神经网络构建以其他协变量C为输入以干预变量X为输出的第三约束网络fCX(·),设定第三约束网络的损失函数为:
其中:为第三约束网络fCX(·)中以φC(vi)为输入去预测xi时得到的变分分布;φC(vi)表示第二表征模型φC(·)中输入vi时得到的输出结果;
另外,针对第三约束网络设定互信息最大化损失函数为:
S205:基于神经网络构建以其他协变量C为输入以结果变量Y为输出的第四约束网络fCY(·),设定第四约束网络的损失函数为:
其中:为第四约束网络fCY(·)中以φC(vi)为输入去预测yi时得到的变分分布;
另外,针对第四约束网络设定互信息最大化损失函数为:
S206:基于神经网络构建以工具变量Z为输入以其他协变量C为输出的第五约束网络fZC(·),设定第五约束网络的损失函数为:
其中:为第五约束网络fZC(·)中以φZ(vi)为输入去预测φC(vi)时得到的变分分布;
另外,针对第五约束网络设定互信息最大化损失函数为:
S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化,具体步骤如下:
S301:基于神经网络构建以工具变量Z的表征φZ(vi)和其他协变量C的表征φC(vi)为输入以干预变量X为输出的第一阶段回归网络fX(·),并设定第一阶段回归网络的损失函数为:
其中l(·)表示计算平方误差;
S302:基于神经网络构建以和其他协变量C的表征φC(vi)为输入以结果变量Y为输出的第二阶段回归网络fY(·),并设定第二阶段回归网络的损失函数为:
其中:femb(·)为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·)输出的干预变量X估计值,/>
S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型,具体步骤如下:
S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:
利用所述反事实预测数据集对五个约束网络进行训练,通过最小化所述综合损失函数分别优化各约束网络中的网络参数;
S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:
其中:α、β、∈、η是权重超参数;
利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述综合互信息损失函数分别优化两个表征模型中的网络参数;
S403:利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;
S404:利用所述反事实预测数据集继续对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;
S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参数,直至迭代终止,得到参数优化后的第一表征模型φ′φ(·)和第二表征模型φ′C(·);
S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写数字的识别结果。
2.如权利要求1所述的基于神经网络的工具变量生成与手写数字识别方法,其特征在于,步骤S203中,所述权重ωij通过RBF核函数计算,公式如下:
其中σ是一个用于调节的超参数。
3.如权利要求1所述的基于神经网络的工具变量生成与手写数字识别方法,其特征在于,所述的步骤S5具体包括以下步骤:
S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)中,得到工具变量Z的表征和其他协变量C的表征;
S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。
4.如权利要求1所述的基于神经网络的工具变量生成与手写数字识别方法,其特征在于,所述基于工具变量的反事实预测模型为2SLS、Deep IV、Kernel IV或DeepGMM模型。
5.一种基于神经网络的工具变量生成与手写数字识别装置,其特征在于,包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于当执行所述计算机程序时,实现如权利要求1~4任一项所述的基于神经网络的工具变量生成与手写数字识别方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011493947.2A CN112633503B (zh) | 2020-12-16 | 2020-12-16 | 基于神经网络的工具变量生成与手写数字识别方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011493947.2A CN112633503B (zh) | 2020-12-16 | 2020-12-16 | 基于神经网络的工具变量生成与手写数字识别方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112633503A CN112633503A (zh) | 2021-04-09 |
CN112633503B true CN112633503B (zh) | 2023-08-22 |
Family
ID=75316672
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011493947.2A Active CN112633503B (zh) | 2020-12-16 | 2020-12-16 | 基于神经网络的工具变量生成与手写数字识别方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112633503B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113409901B (zh) * | 2021-06-29 | 2023-09-29 | 南华大学 | 一种级联医疗观测数据的因果推断方法及系统 |
CN113744805A (zh) * | 2021-09-30 | 2021-12-03 | 山东大学 | 基于bert框架的dna甲基化预测方法及系统 |
CN114186096A (zh) * | 2021-12-10 | 2022-03-15 | 北京达佳互联信息技术有限公司 | 信息处理方法及装置 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104850837A (zh) * | 2015-05-18 | 2015-08-19 | 西南交通大学 | 手写文字的识别方法 |
CN110766044A (zh) * | 2019-09-11 | 2020-02-07 | 浙江大学 | 一种基于高斯过程先验指导的神经网络训练方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11580392B2 (en) * | 2019-05-30 | 2023-02-14 | Samsung Electronics Co., Ltd. | Apparatus for deep representation learning and method thereof |
-
2020
- 2020-12-16 CN CN202011493947.2A patent/CN112633503B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104850837A (zh) * | 2015-05-18 | 2015-08-19 | 西南交通大学 | 手写文字的识别方法 |
CN110766044A (zh) * | 2019-09-11 | 2020-02-07 | 浙江大学 | 一种基于高斯过程先验指导的神经网络训练方法 |
Non-Patent Citations (1)
Title |
---|
深度学习的可解释性;吴飞等;航空兵器;第26卷(第1期);第40-44页 * |
Also Published As
Publication number | Publication date |
---|---|
CN112633503A (zh) | 2021-04-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112633503B (zh) | 基于神经网络的工具变量生成与手写数字识别方法及装置 | |
Arora et al. | On exact computation with an infinitely wide neural net | |
Cremer et al. | Inference suboptimality in variational autoencoders | |
Ergen et al. | Efficient online learning algorithms based on LSTM neural networks | |
Egilmez et al. | Graph learning from data under Laplacian and structural constraints | |
WO2021007812A1 (zh) | 一种深度神经网络超参数优化方法、电子设备及存储介质 | |
CN110287983B (zh) | 基于最大相关熵深度神经网络单分类器异常检测方法 | |
US11574198B2 (en) | Apparatus and method with neural network implementation of domain adaptation | |
Riquelme et al. | Online active linear regression via thresholding | |
CN112232397A (zh) | 图像分类模型的知识蒸馏方法、装置和计算机设备 | |
WO2020091919A1 (en) | Computer architecture for multiplier-less machine learning | |
Oymak et al. | Generalization guarantees for neural architecture search with train-validation split | |
Gordon et al. | Source identification for mixtures of product distributions | |
CN113011531B (zh) | 分类模型训练方法、装置、终端设备及存储介质 | |
Dandi et al. | The Benefits of Reusing Batches for Gradient Descent in Two-Layer Networks: Breaking the Curse of Information and Leap Exponents | |
Meng et al. | Learning Regions of Attraction in Unknown Dynamical Systems via Zubov-Koopman Lifting: Regularities and Convergence | |
WO2022142026A1 (zh) | 分类网络构建方法以及基于分类网络的分类方法 | |
Pavlenko et al. | Methods For Black–Box Diagnostics Using Volterra Kernels | |
Wu et al. | Approximation by random weighting method for M-test in linear models | |
Tamás et al. | Recursive estimation of conditional kernel mean embeddings | |
US20230024743A1 (en) | Efficient second order pruning of computer-implemented neural networks | |
Fornasier et al. | Approximation Theory, Computing, and Deep Learning on the Wasserstein Space | |
US20220121960A1 (en) | Generation of simplified computer-implemented neural networks | |
Beretta et al. | The stochastic complexity of spin models: How simple are simple spin models | |
US20220383103A1 (en) | Hardware accelerator method and device |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |