Image Data Augmentation Network Training Method, Training Device, and Storage Medium
Technical Field
The present invention belongs to the technical field of image processing, and in particular relates to a training method for an image data augmentation network, a training device thereof, and a computer-readable storage medium.
Background Art
In deep learning, many neural networks have a large number of parameters and therefore require large amounts of training data to effectively prevent overfitting. A high-quality dataset should contain enough categories, exhibit a certain diversity, and fully express the characteristics of the data.
However, in many practical situations it is very difficult to obtain large amounts of high-quality data. Specifically: 1) the data available for training is scarce and hard to obtain, requiring substantial manual effort; 2) the data is imbalanced across categories; 3) the data contains sensitive or personally private information and cannot be used publicly. These limitations are particularly evident in the field of medical image processing. Methods commonly used in deep learning, such as fine-tuning, are of little help when training on small samples that lack diversity. To improve training accuracy and effectively prevent overfitting, data augmentation is currently the most widely used approach in deep learning. Traditional image data augmentation methods mainly include translation, rotation, flipping, scaling, cropping, and adding noise. These operations are simple, fast, and reproducible, but the images they produce are strongly correlated with the originals, that is, little new effective information is added, so for complex images they do not adequately solve the problems caused by small samples.
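As an illustration only (a minimal numpy sketch, not part of the disclosed method), the traditional augmentations listed above can be written as:

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3))  # toy 32x32 RGB image, CIFAR-10 sized

flipped = img[:, ::-1, :]                   # horizontal flip
rotated = np.rot90(img, k=1, axes=(0, 1))   # 90-degree rotation
shifted = np.roll(img, shift=4, axis=1)     # crude 4-pixel translation
cropped = img[4:28, 4:28, :]                # central crop
noisy = np.clip(img + rng.normal(0.0, 0.05, img.shape), 0.0, 1.0)  # additive Gaussian noise
```

Each output is a deterministic (or near-deterministic) transform of the same source image, which is exactly why such samples are strongly correlated with the original.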
Generative adversarial networks (GANs) have shown great potential for image synthesis in recent years.
The original GAN structure is based on the multilayer perceptron (MLP) and consists of two neural networks: a generator G and a discriminator D. The input z of the generator G is drawn from a known distribution p(z), usually chosen as a Gaussian (normal) distribution. The generator produces outputs x_g that follow a distribution p_g(x), with the goal of achieving p_g(x) = p_r(x), where p_r(x) is the distribution of the real samples x_r. The discriminator outputs the probability that a sample is real. That is, the generator with parameters θ_g outputs the generated sample x_g = G(z; θ_g), and the discriminator with parameters θ_d outputs y = D(x; θ_d). The generator G and the discriminator D optimize a loss function through adversarial training so that (G, D) reaches a Nash equilibrium.
The loss function of the GAN is:

min_G max_D V(D, G) = E_{x~p_r(x)}[log D(x)] + E_{z~p(z)}[log(1 − D(G(z)))]

where E denotes the mathematical expectation. Analysis of the loss function of the original GAN structure shows that when the discriminator has been optimized to:

D*(x) = p_r(x) / (p_r(x) + p_g(x))
the above loss function is equivalent to minimizing the Jensen-Shannon divergence (JSD) between the real data distribution and the generated data distribution:

min_G V(D*, G) = 2 · JSD(p_r ‖ p_g) − 2 log 2
However, when the overlap between the supports of the two distributions is negligible, the JSD is a constant, so the generator receives no useful gradient and training cannot proceed. In practice, the distribution produced by a randomly initialized generator rarely has a non-negligible overlap with the real distribution, which leads to the problems of vanishing modes or mode collapse.
To address this problem, the optimal transport (OT) method measures the distance between two distributions by finding the minimum cost of transporting one onto the other, and this distance is well defined whether or not the supports of the two distributions overlap. This theory provides a way to remedy the deficiency of the loss function of the original GAN structure.
From the OT perspective, a GAN can be regarded as realizing the OT mapping through the generator and determining the distance between the real data distribution and the generated data distribution through the discriminator. The distance between the two distributions can be defined as:

W(p_r, p_g) = inf_{π ∈ Π(p_r, p_g)} E_{(x_r, x_g)~π}[c(x_r, x_g)]
Among existing techniques, Wasserstein GAN (WGAN) made a breakthrough in using OT to improve GANs. WGAN chooses the cost c(x_r, x_g) in the above definition to be the Euclidean distance, so that the distance between the two distributions is defined as:

W(p_r, p_g) = inf_{π ∈ Π(p_r, p_g)} E_{(x_r, x_g)~π}[‖x_r − x_g‖]

that is, the Wasserstein distance. The input of the WGAN generator is a noise sample z that follows a normal distribution on [-1, 1]; new data is synthesized through the OT mapping and by optimizing the Wasserstein distance, and the images generated by WGAN can be used for data augmentation.
However, the Wasserstein distance used in WGAN, which applies OT theory to the GAN structure, is defined in terms of the Euclidean distance. The Euclidean distance is sensitive to scale and to outliers, and the model is therefore sensitive to the influence of noise.
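A toy numpy comparison (illustrative only, not from the disclosure) makes this sensitivity concrete: uniformly rescaling a pair of vectors, e.g. a global intensity change, scales their Euclidean distance proportionally but leaves their cosine distance unchanged:

```python
import numpy as np

def euclidean(a, b):
    return float(np.linalg.norm(a - b))

def cosine_distance(a, b):
    # d(x, y) = 1 - <x, y> / (|x| |y|)
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

x = np.array([1.0, 2.0, 3.0])
y = np.array([1.1, 1.9, 3.2])

print(euclidean(x, y), euclidean(10 * x, 10 * y))              # grows 10x with scale
print(cosine_distance(x, y), cosine_distance(10 * x, 10 * y))  # unchanged
```

This scale invariance is the intuition behind replacing the Euclidean cost with a cosine cost.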
Summary of the Invention
(1) Technical Problem to Be Solved by the Invention
The technical problem solved by the present invention is how to improve the robustness of the model to the influence of noise while also solving the technical problem of unstable training in existing adversarial networks.
(2) Technical Solution Adopted by the Invention
A training method for an image data augmentation network, the training method comprising:
obtaining noise samples and real data samples to be augmented;
inputting the noise samples into an image data augmentation network to obtain generated data samples;
inputting the real data samples and the generated data samples into the image data augmentation network to obtain multiple sets of cosine distance values, the image data augmentation network computing a loss function from the multiple sets of cosine distance values; and
updating network parameters of the image data augmentation network according to the loss function.
Optionally, the image data augmentation network includes a generator and a discriminator, wherein:
the method of inputting the noise samples into the image data augmentation network to obtain the generated data samples is: inputting the noise samples into the generator, the generator outputting the generated data samples; and
the method of computing the multiple sets of cosine distance values from the real data samples and the generated data samples is: inputting the real data samples and the generated data samples into the discriminator, the discriminator outputting the multiple sets of cosine distance values, the multiple sets of cosine distance values including a first cosine distance value, a second cosine distance value, and a third cosine distance value.
Optionally, the real data samples include a first real sub-sample x_r and a second real sub-sample x_r′ that follow the same distribution, and the generated data samples include a first generated sub-sample x_g and a second generated sub-sample x_g′ that follow the same distribution. The first cosine distance value d(x_r, x_g), the second cosine distance value d(x_r, x_r′), and the third cosine distance value d(x_g, x_g′) are computed as follows (shown for d(x_r, x_g); the other two values are computed analogously):

d(x_r, x_g) = 1 − (x_r · x_g) / (‖x_r‖ ‖x_g‖)
Optionally, the method of computing the loss function from the multiple sets of cosine distance values is that the discriminator computes the loss function L according to the following formula:

L = 2 · E[d(x_r, x_g)] − E[d(x_r, x_r′)] − E[d(x_g, x_g′)]

where E denotes the mathematical expectation and L is the loss function.
Optionally, the method of updating the network parameters of the image data augmentation network according to the loss function is:
performing a backward pass on the image data augmentation network according to the loss function, and updating the network parameters of the discriminator N times according to the stochastic gradient descent method; and
performing a backward pass on the image data augmentation network according to the loss function, and updating the network parameters of the generator once according to the stochastic gradient descent method.
The present application also discloses a training device for an image data augmentation network, the training device comprising:
an acquisition module, configured to acquire noise samples and real data samples to be augmented;
a first input module, configured to input the noise samples into an image data augmentation network to obtain generated data samples;
a second input module, configured to input the real data samples and the generated data samples into the image data augmentation network to obtain multiple sets of cosine distance values, the image data augmentation network computing a loss function from the multiple sets of cosine distance values; and
an update module, configured to update network parameters of the image data augmentation network according to the loss function.
Optionally, the image data augmentation network includes a generator and a discriminator, wherein:
the first input module is configured to input the noise samples into the generator, the generator outputting the generated data samples; and
the second input module is configured to input the real data samples and the generated data samples into the discriminator, the discriminator outputting multiple sets of cosine distance values, the multiple sets of cosine distance values including a first cosine distance value, a second cosine distance value, and a third cosine distance value.
Optionally, when updating the network parameters of the image data augmentation network according to the loss function, the update module is specifically configured to:
perform a backward pass on the image data augmentation network using the loss function, and update the network parameters of the discriminator N times according to the stochastic gradient descent method; and
perform a backward pass on the image data augmentation network using the loss function, and update the network parameters of the generator once according to the stochastic gradient descent method.
The present application also discloses a computer-readable storage medium storing a training program for an image data augmentation network, the training program, when executed by a processor, implementing the training method for an image data augmentation network described above.
(3) Beneficial Effects
The present invention discloses a training method for an image data augmentation network which, compared with conventional training methods, has the following technical effects:
(1) On the basis of combining OT theory with GANs, the cosine distance is used to define the distance between the real data distribution and the generated data distribution, thereby improving the stability of the network structure and the quality of the generated data and reducing the influence of noise on the network. The method proposed in this embodiment can perform data augmentation on small-sample data, and the generated results have high diversity, a large IS score, and a small FID score, which solves the problem of the high correlation of augmented data in traditional data augmentation methods.
Brief Description of the Drawings
Fig. 1 is a flowchart of the training method for an image data augmentation network according to Embodiment 1 of the present invention;
Fig. 2 is a schematic diagram of the training device for an image data augmentation network according to Embodiment 2 of the present invention;
Fig. 3 is a comparison of images generated by different network models according to embodiments of the present invention;
Fig. 4 is a schematic diagram of a computer device according to Embodiment 3 of the present invention.
Detailed Description
In order to make the objectives, technical solutions, and advantages of the present invention clearer, the present invention is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described here are intended only to explain the present invention, not to limit it.
Before describing the embodiments of the present application in detail, the inventive concept of the present application is briefly described. In the prior art, OT theory is applied to adversarial-network training using the Euclidean distance, which is sensitive to noise and outliers. The present application computes multiple sets of cosine distance values between the real data samples and the generated data samples, constructs a loss function from these values, and updates the network parameters of the image data augmentation network accordingly, which enhances robustness to noise and outliers.
The OT theory measures the distance between two distributions by finding the minimum cost of transporting one onto the other, whether or not the supports of the two distributions overlap, and thereby provides a way to remedy the deficiency of the loss function of the original GAN structure. OT is defined as finding the optimal coupling π between the two distributions p_g(x) and p_r(x) under a cost function c(x_r, x_g):

W(p_r, p_g) = inf_{π ∈ Π(p_r, p_g)} E_{(x_r, x_g)~π}[c(x_r, x_g)]

where Π(p_r, p_g) is the set of all joint distributions π(x_r, x_g) with marginals p_r and p_g. From the OT perspective, a GAN can be regarded as realizing the OT mapping through the generator and determining the distance between the real data distribution and the generated data distribution through the discriminator; the distance between the two distributions is defined as above.
Among existing techniques, Wasserstein GAN (WGAN) made a breakthrough in using OT to improve GANs. WGAN chooses the cost c(x_r, x_g) in the above definition to be the Euclidean distance, so that the distance between the two distributions is defined as:

W(p_r, p_g) = inf_{π ∈ Π(p_r, p_g)} E_{(x_r, x_g)~π}[‖x_r − x_g‖]

that is, the Wasserstein distance. The input of the WGAN generator is a noise sample z that follows a normal distribution on [-1, 1]; new data is synthesized through the OT mapping and by optimizing the Wasserstein distance, and the images generated by WGAN can be used for data augmentation.
Embodiment 1
Specifically, as shown in Fig. 1, the training method for an image data augmentation network of Embodiment 1 includes the following steps:
Step S10: obtain noise samples and real data samples to be augmented.
Step S20: input the noise samples into the image data augmentation network to obtain generated data samples.
Step S30: input the real data samples and the generated data samples into the image data augmentation network to obtain multiple sets of cosine distance values; the image data augmentation network computes a loss function from the multiple sets of cosine distance values.
Step S40: update the network parameters of the image data augmentation network according to the loss function.
Specifically, the image data augmentation network of this embodiment includes a generator G and a discriminator D, both of which are convolutional neural networks. In step S20, the noise samples are input into the generator G, and the generator G outputs generated data samples. In step S30, the real data samples and the generated data samples are input into the discriminator D, and the discriminator D outputs multiple sets of cosine distance values, including a first cosine distance value, a second cosine distance value, and a third cosine distance value.
Further, the real data samples include a first real sub-sample x_r and a second real sub-sample x_r′ that follow the same distribution, and the generated data samples include a first generated sub-sample x_g and a second generated sub-sample x_g′ that follow the same distribution. Specifically, the real data samples follow the same distribution, and the first real sub-sample x_r and the second real sub-sample x_r′ are obtained by random sampling. The input noise samples are fixed; after passing through the generator G, generated data samples following a certain distribution are produced, for example a normal distribution, and the first generated sub-sample x_g and the second generated sub-sample x_g′ are randomly sampled from this distribution.
The first cosine distance value d(x_r, x_g), the second cosine distance value d(x_r, x_r′), and the third cosine distance value d(x_g, x_g′) are computed as follows (shown for d(x_r, x_g); the other two values are computed analogously):

d(x_r, x_g) = 1 − (x_r · x_g) / (‖x_r‖ ‖x_g‖)
Further, the discriminator D obtains the loss function L according to the following formula:

L = 2 · E[d(x_r, x_g)] − E[d(x_r, x_r′)] − E[d(x_g, x_g′)]

where E denotes the mathematical expectation and L is the loss function.
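A minimal numpy sketch of this computation (illustrative only: the batch shapes are assumed, and the 2x weighting of the cross term is an assumed energy-distance-style combination of the three values, which may differ from the exact formula of the disclosure):

```python
import numpy as np

def cosine_distance(a, b):
    # row-wise cosine distance d(x, y) = 1 - <x, y> / (|x| |y|)
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-12
    return 1.0 - num / den

rng = np.random.default_rng(0)
x_r, x_r2 = rng.normal(size=(64, 128)), rng.normal(size=(64, 128))  # real sub-samples
x_g, x_g2 = rng.normal(size=(64, 128)), rng.normal(size=(64, 128))  # generated sub-samples

d_rg = cosine_distance(x_r, x_g).mean()   # first cosine distance value (batch expectation)
d_rr = cosine_distance(x_r, x_r2).mean()  # second cosine distance value
d_gg = cosine_distance(x_g, x_g2).mean()  # third cosine distance value
L = 2.0 * d_rg - d_rr - d_gg              # assumed combination of the three values
print(L)
```

In practice the distances would be computed on discriminator features rather than raw pixels; the batch mean plays the role of the expectation E.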
Further, a backward pass is performed on the image data augmentation network according to the loss function L, the network parameters of the discriminator are updated N times according to the stochastic gradient descent method, and the network parameters of the generator are updated once according to the stochastic gradient descent method. The above steps are repeated until the generator and the discriminator reach equilibrium, completing the training of the image data augmentation network.
Specifically, the input of the image data augmentation network is the small-sample data to be augmented, i.e., the real data. The system is also provided with a training step size α, a batch size N, and the number of discriminator parameter updates n_c performed before each generator parameter update. The initial parameters of the discriminator are ω_0 and the initial parameters of the generator are θ_0. The input of the generator is a noise sample z that follows a normal distribution on [-1, 1].

While the generator parameters θ have not converged, the distance between the real data distribution p_r(x) and the generated data distribution p_g(x) is computed using the loss function formula above. Before each update of the generator parameters, the discriminator parameters ω are updated n_c times using the stochastic gradient descent method:

ω ← ω + α · ∇_ω L

after which the generator parameters are updated once using the stochastic gradient descent method:

θ ← θ − α · ∇_θ L

The above training steps are repeated until the generator parameters θ converge.
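The alternating schedule above can be sketched end to end. The following toy numpy loop is illustrative only: it assumes a linear generator and a linear discriminator feature map, uses finite-difference gradients in place of backpropagation, and uses an assumed energy-distance-style combination of the three cosine distances as the loss. It shows the n_c discriminator updates per generator update:

```python
import numpy as np

rng = np.random.default_rng(0)

def cosine_distance(a, b):
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-12
    return 1.0 - num / den

def loss(theta, omega, z1, z2, xr1, xr2):
    # generator: x_g = z @ theta ; discriminator feature map: f(x) = x @ omega
    fg1, fg2 = z1 @ theta @ omega, z2 @ theta @ omega
    fr1, fr2 = xr1 @ omega, xr2 @ omega
    return (2.0 * cosine_distance(fr1, fg1).mean()
            - cosine_distance(fr1, fr2).mean()
            - cosine_distance(fg1, fg2).mean())

def num_grad(f, p, eps=1e-5):
    # finite-difference gradient, standing in for backprop
    g = np.zeros_like(p)
    it = np.nditer(p, flags=["multi_index"])
    for _ in it:
        i = it.multi_index
        p[i] += eps; hi = f(p)
        p[i] -= 2 * eps; lo = f(p)
        p[i] += eps
        g[i] = (hi - lo) / (2 * eps)
    return g

alpha, n_c, N, dim = 0.05, 5, 64, 2
theta = rng.normal(size=(dim, dim))           # generator parameters (theta_0)
omega = rng.normal(size=(dim, dim))           # discriminator parameters (omega_0)
real = rng.normal(loc=3.0, size=(1000, dim))  # stand-in "real" small-sample data

for step in range(20):
    for _ in range(n_c):  # n_c discriminator updates per generator update
        z1, z2 = rng.uniform(-1, 1, (N, dim)), rng.uniform(-1, 1, (N, dim))
        xr1, xr2 = real[rng.choice(1000, N)], real[rng.choice(1000, N)]
        omega += alpha * num_grad(lambda w: loss(theta, w, z1, z2, xr1, xr2), omega)
    z1, z2 = rng.uniform(-1, 1, (N, dim)), rng.uniform(-1, 1, (N, dim))
    xr1, xr2 = real[rng.choice(1000, N)], real[rng.choice(1000, N)]
    theta -= alpha * num_grad(lambda t: loss(t, omega, z1, z2, xr1, xr2), theta)
```

The discriminator ascends the estimated distance while the generator descends it, mirroring the two update rules given above; a real implementation would replace the finite differences with automatic differentiation over convolutional networks.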
Further, the Inception Score (IS) and the Fréchet inception distance (FID) are used as evaluation metrics for data augmentation. Their formulas are:

IS = exp( E_X [ KL( p(l|X) ‖ p(l) ) ] )

FID = ‖m_r − m_g‖² + Tr( C_r + C_g − 2 (C_r C_g)^{1/2} )

where p(l|X) is the conditional label distribution of a generated sample X, KL is the Kullback-Leibler divergence, N is the number of samples in a batch over which the expectation is estimated, and m, C, and Tr denote the mean, the covariance, and the trace, respectively, with subscripts r and g referring to the real and generated images. A larger IS and a smaller FID indicate better quality and diversity of the generated images.
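The IS can be illustrated with a small numpy sketch (illustrative only; in practice the conditional distributions p(l|X) come from a pretrained Inception classifier):

```python
import numpy as np

def inception_score(p_lx):
    # p_lx: (num_samples, num_classes) conditional label distributions p(l|X)
    p_l = p_lx.mean(axis=0, keepdims=True)  # marginal p(l)
    kl = np.sum(p_lx * (np.log(p_lx + 1e-12) - np.log(p_l + 1e-12)), axis=1)
    return float(np.exp(kl.mean()))         # exp of the mean KL(p(l|X) || p(l))

# confident, evenly spread predictions -> high IS; uniform predictions -> IS near 1
print(inception_score(np.tile(np.eye(10), (50, 1))))  # ~10: high quality and diversity
print(inception_score(np.full((500, 10), 0.1)))       # ~1: no usable information
```

The score is maximal when each sample is classified confidently (sharp p(l|X)) while the samples cover all classes evenly (uniform p(l)), which is why a larger IS indicates both quality and diversity.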
After training of the image data augmentation network is complete, noise samples are input into the generator, which outputs synthetic data; the synthetic data and the real data samples together serve as training samples for subsequent model training, thereby achieving data augmentation.
On the basis of combining OT theory with GANs, the training method for an image data augmentation network disclosed in this embodiment uses the cosine distance to define the distance between the real data distribution and the generated data distribution, thereby improving the stability of the network structure and the quality of the generated data and reducing the influence of noise on the network. The method proposed in this embodiment can perform data augmentation on small-sample data, and the generated results have high diversity, a large IS score, and a small FID score, which solves the problem of the high correlation of augmented data in traditional data augmentation methods. It is of reference value for deep learning in different fields and can be used for training on small-sample datasets in various domains.
Embodiment 2
As shown in Fig. 2, the training device for an image data augmentation network of Embodiment 2 includes an acquisition module 100, a first input module 200, a second input module 300, and an update module 400. The acquisition module 100 is configured to acquire noise samples and real data samples to be augmented; the first input module 200 is configured to input the noise samples into the image data augmentation network to obtain generated data samples; the second input module 300 is configured to input the real data samples and the generated data samples into the image data augmentation network to obtain multiple sets of cosine distance values, the image data augmentation network computing a loss function from the multiple sets of cosine distance values; and the update module 400 is configured to update the network parameters of the image data augmentation network according to the loss function.
Further, the image data augmentation network includes a generator G and a discriminator D. The first input module 200 is configured to input the noise samples into the generator G, the generator G outputting generated data samples; the second input module 300 is configured to input the real data samples and the generated data samples into the discriminator D, the discriminator D outputting multiple sets of cosine distance values, including a first cosine distance value, a second cosine distance value, and a third cosine distance value. For the specific process by which the discriminator D computes the multiple sets of cosine distance values and the loss function, refer to Embodiment 1; it is not repeated here.
Further, when updating the network parameters of the image data augmentation network according to the loss function, the update module 400 is specifically configured to: perform a backward pass on the image data augmentation network using the loss function, and update the network parameters of the discriminator N times according to the stochastic gradient descent method; and perform a backward pass on the image data augmentation network using the loss function, and update the network parameters of the generator once according to the stochastic gradient descent method. For the updating procedure of the update module 400, refer to Embodiment 1; it is not repeated here.
Further, to demonstrate more intuitively the advantages of the image data augmentation network obtained with the training method of this embodiment, the applicant carried out experimental verification.
Specifically, the CIFAR-10 dataset was used for the experiments and verification. CIFAR-10 contains 60,000 32×32 color images in 10 classes, with 6,000 images per class.
All experiments were carried out with the Chainer-GAN-lib library. To better show the superiority of the proposed data augmentation system, the following existing network models were selected for comparison: GAN-OTD (an OT-based improvement of the original MLP-based GAN) and WGAN-GP (WGAN enhanced with a gradient penalty; the loss function still uses the Euclidean distance of the WGAN structure). The network structure of this embodiment is CNN-GAN-OTD. The default parameters of Chainer-GAN-lib were used for all experiments: the batch size is 64 and the maximum number of training iterations is 100,000. 5,000 randomly sampled generated images were used to compute the IS score; 50,000 randomly sampled real images and 10,000 randomly sampled generated images were used to compute the FID score.
(1) Analysis of the generated images:
With identical training parameters, images were synthesized from the CIFAR-10 data using the different methods; the IS and FID results are shown in Table 1. The present method achieves the best IS and FID results among the listed methods, which verifies the superiority of the image data augmentation network trained with the training method of this embodiment in terms of the quality and diversity of the generated images.
Table 1. Comparison of the quality of images generated by different methods on the CIFAR-10 dataset
(2) Analysis of the influence of noise:
Gaussian noise with zero mean and successively larger standard deviations was added to the CIFAR-10 dataset, and with identical training parameters the different network models were used to synthesize images from the noise-corrupted CIFAR-10 data. The IS and FID results are shown in Table 2, and the generated images are compared in Fig. 3. The maximum standard deviation was empirically chosen as 20.
The present method achieves the best IS and FID results among the listed methods, which verifies its robustness to the influence of noise.
In Fig. 3, (a), (e), and (i) are the original images with added Gaussian noise of standard deviation 2, 5, and 20, respectively; (b), (c), and (d) are images synthesized from (a) by WGAN-GP, DRAGAN, and CNN-GAN-OTD, respectively; (f), (g), and (h) are images synthesized from (e) by WGAN-GP, DRAGAN, and CNN-GAN-OTD, respectively; and (j), (k), and (l) are images synthesized from (i) by WGAN-GP, DRAGAN, and CNN-GAN-OTD, respectively.
Table 2. Comparison of the quality of images generated by different methods on the noise-corrupted CIFAR-10 dataset
Embodiment 3 further discloses a computer-readable storage medium storing a training program for an image data augmentation network; when the training program is executed by a processor, the training method for an image data augmentation network described above is implemented.
Embodiment 4 further discloses a computer device. At the hardware level, as shown in Fig. 4, the terminal includes a processor 12, an internal bus 13, a network interface 14, and a computer-readable storage medium 11. The processor 12 reads the corresponding computer program from the computer-readable storage medium and runs it, forming a request processing apparatus at the logical level. Of course, in addition to software implementations, one or more embodiments of this specification do not exclude other implementations, such as logic devices or a combination of software and hardware; that is, the execution subject of the following processing flow is not limited to logic units and may also be hardware or logic devices. The computer-readable storage medium 11 stores a training program for an image data augmentation network; when the training program is executed by a processor, the training method for an image data augmentation network described above is implemented.
Computer-readable storage media include permanent and non-permanent, removable and non-removable media, and may implement information storage by any method or technology. The information may be computer-readable instructions, data structures, program modules, or other data. Examples of computer-readable storage media include, but are not limited to, phase-change memory (PRAM), static random access memory (SRAM), dynamic random access memory (DRAM), other types of random access memory (RAM), read-only memory (ROM), electrically erasable programmable read-only memory (EEPROM), flash memory or other memory technologies, compact disc read-only memory (CD-ROM), digital versatile discs (DVD) or other optical storage, magnetic cassettes, magnetic disk storage, quantum memory, graphene-based storage media, or other magnetic storage devices, or any other non-transmission medium that can be used to store information accessible by a computing device.
The specific embodiments of the present invention have been described in detail above. Although some embodiments have been shown and described, those skilled in the art should understand that these embodiments may be modified and refined without departing from the principle and spirit of the present invention, whose scope is defined by the claims and their equivalents; such modifications and refinements also fall within the protection scope of the present invention.