TWI752405B - 神經網路訓練及圖像生成方法、電子設備、儲存媒體 - Google Patents
神經網路訓練及圖像生成方法、電子設備、儲存媒體 Download PDFInfo
- Publication number
- TWI752405B TWI752405B TW109101220A TW109101220A TWI752405B TW I752405 B TWI752405 B TW I752405B TW 109101220 A TW109101220 A TW 109101220A TW 109101220 A TW109101220 A TW 109101220A TW I752405 B TWI752405 B TW I752405B
- Authority
- TW
- Taiwan
- Prior art keywords
- distribution
- network
- discriminant
- loss
- training
- Prior art date
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2148—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the process organisation or structure, e.g. boosting cascade
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N7/00—Computing arrangements based on specific mathematical models
- G06N7/01—Probabilistic graphical models, e.g. probabilistic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T9/00—Image coding
- G06T9/002—Image coding using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/762—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7747—Organisation of the process, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/776—Validation; Performance evaluation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20076—Probabilistic image processing
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Data Mining & Analysis (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Multimedia (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Probability & Statistics with Applications (AREA)
- Quality & Reliability (AREA)
- Algebra (AREA)
- Computational Mathematics (AREA)
- Mathematical Analysis (AREA)
- Mathematical Optimization (AREA)
- Pure & Applied Mathematics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本發明涉及一種神經網路訓練及圖像生成方法、電子設備、儲存媒體,所述方法包括:將第一隨機向量輸入生成網路,獲得第一生成圖像;將第一生成圖像和第一真實圖像輸入判別網路,獲得第一判別分布與第二判別分布;根據第一判別分布、第二判別分布、第一目標分布、第二目標分布,確定判別網路的第一網路損失;根據第一判別分布和第二判別分布,確定生成網路的第二網路損失;根據第一網路損失和第二網路損失,對抗訓練生成網路和判別網路。
Description
本發明涉及電腦技術領域,尤其涉及一種神經網路訓練及圖像生成方法、電子設備、儲存媒體。
在相關技術中,生成對抗網路(Generative Adversarial Networks,GAN)由兩個模組組成,分別爲判別網路(Discriminator) 和生成網路 (Generator)。受零和博弈(zero-sum game)的啓發,兩個網路通過互相對抗的方式達到最佳生成效果。在訓練過程中,判別器通過獎勵真目標和懲罰假目標來學習區分真實圖像數據和生成網路生成的仿真圖像,生成器則通過逐步縮小判別器對假目標的懲罰,使得判別器無法區分真實圖像與生成圖像,兩者互相博弈、進化,最終達到以假亂真的效果。
在相關技術中,生成對抗網路由判別網路輸出的一個單一標量來描述輸入圖片的真實性,再使用該標量計算網路的損失,進而訓練生成對抗網路。
本發明提出了一種神經網路訓練及圖像生成方法、電子設備、儲存媒體。
根據本發明的一方面,提供了一種神經網路訓練方法,包括:
將第一隨機向量輸入生成網路,獲得第一生成圖像;
將所述第一生成圖像和第一真實圖像分別輸入判別網路,分別獲得所述第一生成圖像的第一判別分布與第一真實圖像的第二判別分布,其中,所述第一判別分布表示所述第一生成圖像的真實程度的機率分布,所述第二判別分布表示所述第一真實圖像的真實程度的機率分布;
根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,其中,所述第一目標分布爲生成圖像的目標機率分布,所述第二目標分布爲真實圖像的目標機率分布;
根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失;
根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路。
根據本發明的實施例的神經網路訓練方法,判別網路可針對輸入圖像輸出判別分布,以機率分布的形式描述輸入圖像的真實性,可從顔色、紋理、比例、背景等維度描述輸入圖像爲真實圖像的機率,可從多個方面考量輸入圖像的真實性,減少訊息丟失,爲神經網路訓練提供更全面的監測訊息以及更準確的訓練方向,提高訓練精確度,最終提高生成圖像的品質,使得生成網路可適用於生成高清圖像。並且,預設了生成圖像的目標機率分布以及真實圖像的目標機率分布來指導訓練過程,在訓練過程中引導使真實圖像和生成圖像接近各自的目標機率分布,增大真實圖像和生成圖像的區分度,增强判別網路區分真實圖像和生成圖像的能力,進而提升生成網路生成的圖像的品質。
在一種可能的實現方式中,根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,包括:
根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失;
根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失;
根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失。
通過這種方式,預設了生成圖像的目標機率分布以及真實圖像的目標機率分布來指導訓練過程,並分別確定各自的分布損失,在訓練過程中引導使真實圖像和生成圖像接近各自的目標機率分布,增大真實圖像和生成圖像的區分度,爲判別網路提供了更準確的角度訊息,爲判別網路提供更準確的訓練方向,增强判別網路區分真實圖像和生成圖像的能力,進而提升生成網路生成的圖像的品質。
在一種可能的實現方式中,根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失,包括:
將所述第一判別分布映射到所述第一目標分布的支撑集,獲得第一映射分布;
確定所述第一映射分布與所述第一目標分布的第一相對熵;
根據所述第一相對熵,確定所述第一分布損失。
在一種可能的實現方式中,根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失,包括:
將所述第二判別分布映射到所述第二目標分布的支撑集,獲得第二映射分布;
確定所述第二映射分布與所述第二目標分布的第二相對熵;
根據所述第二相對熵,確定所述第二分布損失。
在一種可能的實現方式中,根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失,包括:
對所述第一分布損失和所述第二分布損失進行加權求和處理,獲得所述第一網路損失。
在一種可能的實現方式中,根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失,包括:
確定所述第一判別分布與所述第二判別分布的第三相對熵;
根據所述第三相對熵,確定所述第二網路損失。
通過這種方式,可通過減小第一判別分布與第二判別分布的差異的方式訓練生成網路,使得判別網路性能提高的同時,促進生成網路的性能提高,從而生成逼真程度較高的生成圖像,使得生成網路可適用於生成高清圖像。
在一種可能的實現方式中,根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路,包括:
根據所述第一網路損失,調整所述判別網路的網路參數;
根據所述第二網路損失,調整所述生成網路的網路參數;
在所述判別網路和所述生成網路滿足訓練條件的情況下,獲得訓練後的所述生成網路和所述判別網路。
在一種可能的實現方式中,根據所述第一網路損失,調整所述判別網路的網路參數,包括:
將第二隨機向量輸入生成網路,獲得第二生成圖像;
根據所述第二生成圖像對第二真實圖像進行插值處理,獲得插值圖像;
將所述插值圖像輸入所述判別網路,獲得所述插值圖像的第三判別分布;
根據所述第三判別分布,確定所述判別網路的網路參數的梯度;
在所述梯度大於或等於梯度閾值的情況下,根據所述第三判別分布確定梯度懲罰參數;
根據所述第一網路損失和所述梯度懲罰參數,調整所述判別網路的網路參數。
通過這種方式,可通過檢測判別網路的網路參數的梯度是否大於或等於梯度閾值,來限制判別網路在訓練中的梯度下降速度,從而限制判別網路的訓練進度,減少判別網路出現梯度消失的機率,從而可持續優化生成網路,提高生成網路的性能,使生成網路生成圖像的逼真程度較高,且適用於生成高清圖像。
在一種可能的實現方式中,根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路,包括:
將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像;
將與所述至少一個歷史訓練周期中輸入生成網路的第一隨機向量對應的第一生成圖像、至少一個所述第三生成圖像以及至少一個真實圖像分別輸入當前訓練周期的判別網路,分別獲得至少一個第一生成圖像的第四判別分布、至少一個第三生成圖像的第五判別分布和至少一個真實圖像的第六判別分布;
根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數;
在所述訓練進度參數小於或等於訓練進度閾值的情況下,停止調整所述判別網路的網路參數,僅調整所述生成網路的網路參數。
通過這種方式,可通過檢查判別網路和生成網路的訓練進度,來限制判別網路在訓練中的梯度下降速度,從而限制判別網路的訓練進度,減少判別網路出現梯度消失的機率,從而可持續優化生成網路,提高生成網路的性能,使生成網路生成圖像的逼真程度較高,且適用於生成高清圖像。
在一種可能的實現方式中,根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數,包括:
分別獲取至少一個所述第四判別分布的第一期望值、至少一個所述第五判別分布的第二期望值以及至少一個所述第六判別分布的第三期望值;
分別獲取所述至少一個所述第一期望值的第一平均值、至少一個所述第二期望值的第二平均值以及至少一個所述第三期望值的第三平均值;
確定所述第三平均值與所述第二平均值的第一差值以及所述第二平均值與所述第一平均值的第二差值;
將所述第一差值與所述第二差值的比值確定爲所述當前訓練周期的生成網路的訓練進度參數。
根據本發明的一方面,提供了一種圖像生成方法,包括:
獲取第三隨機向量;
將所述第三隨機向量輸入訓練後獲得的生成網路進行處理,獲得目標圖像。
根據本發明的一方面,提供了一種神經網路訓練裝置,包括:
生成模組,用於將第一隨機向量輸入生成網路,獲得第一生成圖像;
判別模組,用於將所述第一生成圖像和第一真實圖像分別輸入判別網路,分別獲得所述第一生成圖像的第一判別分布與第一真實圖像的第二判別分布,其中,所述第一判別分布表示所述第一生成圖像的真實程度的機率分布,所述第二判別分布表示所述第一真實圖像的真實程度的機率分布;
第一確定模組,用於根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,其中,所述第一目標分布爲生成圖像的目標機率分布,所述第二目標分布爲真實圖像的目標機率分布;
第二確定模組,用於根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失;
訓練模組,用於根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失;
根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失;
根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
將所述第一判別分布映射到所述第一目標分布的支撑集,獲得第一映射分布;
確定所述第一映射分布與所述第一目標分布的第一相對熵;
根據所述第一相對熵,確定所述第一分布損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
將所述第二判別分布映射到所述第二目標分布的支撑集,獲得第二映射分布;
確定所述第二映射分布與所述第二目標分布的第二相對熵;
根據所述第二相對熵,確定所述第二分布損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
對所述第一分布損失和所述第二分布損失進行加權求和處理,獲得所述第一網路損失。
在一種可能的實現方式中,所述第二確定模組被進一步配置爲:
確定所述第一判別分布與所述第二判別分布的第三相對熵;
根據所述第三相對熵,確定所述第二網路損失。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
根據所述第一網路損失,調整所述判別網路的網路參數;
根據所述第二網路損失,調整所述生成網路的網路參數;
在所述判別網路和所述生成網路滿足訓練條件的情況下,獲得訓練後的所述生成網路和所述判別網路。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
將第二隨機向量輸入生成網路,獲得第二生成圖像;
根據所述第二生成圖像對第二真實圖像進行插值處理,獲得插值圖像;
將所述插值圖像輸入所述判別網路,獲得所述插值圖像的第三判別分布;
根據所述第三判別分布,確定所述判別網路的網路參數的梯度;
在所述梯度大於或等於梯度閾值的情況下,根據所述第三判別分布確定梯度懲罰參數;
根據所述第一網路損失和所述梯度懲罰參數,調整所述判別網路的網路參數。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像;
將與所述至少一個歷史訓練周期中輸入生成網路的第一隨機向量對應的第一生成圖像、至少一個所述第三生成圖像以及至少一個真實圖像分別輸入當前訓練周期的判別網路,分別獲得至少一個第一生成圖像的第四判別分布、至少一個第三生成圖像的第五判別分布和至少一個真實圖像的第六判別分布;
根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數;
在所述訓練進度參數小於或等於訓練進度閾值的情況下,停止調整所述判別網路的網路參數,僅調整所述生成網路的網路參數。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
分別獲取至少一個所述第四判別分布的第一期望值、至少一個所述第五判別分布的第二期望值以及至少一個所述第六判別分布的第三期望值;
分別獲取所述至少一個所述第一期望值的第一平均值、至少一個所述第二期望值的第二平均值以及至少一個所述第三期望值的第三平均值;
確定所述第三平均值與所述第二平均值的第一差值以及所述第二平均值與所述第一平均值的第二差值;
將所述第一差值與所述第二差值的比值確定爲所述當前訓練周期的生成網路的訓練進度參數。
根據本發明的一方面,提供了一種圖像生成裝置,其中,包括:
獲取模組,用於獲取第三隨機向量;
獲得模組,用於將所述第三隨機向量輸入訓練後獲得的生成網路進行處理,獲得目標圖像。
根據本發明的一方面,提供了一種電子設備,包括:
處理器;
用於儲存處理器可執行指令的記憶體;
其中,所述處理器被配置爲:執行上述方法。
根據本發明的一方面,提供了一種電腦可讀儲存媒體,其上儲存有電腦程式指令,所述電腦程式指令被處理器執行時實現上述方法。
根據本發明的一方面,提供了一種電腦程式,包括電腦可讀代碼,當所述電腦可讀代碼在電子設備中運行時,所述電子設備中的處理器執行用於執行上述方法。
應當理解的是,以上的一般描述和後文的細節描述僅是示例性和解釋性的,而非限制本發明。
根據下面參考附圖對示例性實施例的詳細說明,本發明的其它特徵及方面將變得清楚。
以下將參考附圖詳細說明本發明的各種示例性實施例、特徵和方面。附圖中相同的附圖標記表示功能相同或相似的元件。儘管在附圖中示出了實施例的各種方面,但是除非特別指出,不必按比例繪製附圖。
在這裏專用的詞“示例性”意爲“用作例子、實施例或說明性”。這裏作爲“示例性”所說明的任何實施例不必解釋爲優於或好於其它實施例。
本文中術語“和/或”,僅僅是一種描述關聯對象的關聯關係,表示可以存在三種關係,例如,A和/或B,可以表示:單獨存在A,同時存在A和B,單獨存在B這三種情況。另外,本文中術語“至少一種”表示多種中的任意一種或多種中的至少兩種的任意組合,例如,包括A、B、C中的至少一種,可以表示包括從A、B和C構成的集合中選擇的任意一個或多個元素。
另外,爲了更好的說明本發明,在下文的具體實施方式中給出了眾多的具體細節。本領域技術人員應當理解,沒有某些具體細節,本發明同樣可以實施。在一些實例中,對於本領域技術人員熟知的方法、手段、元件和電路未作詳細描述,以便於凸顯本發明的主旨。
圖1示出根據本發明實施例的神經網路訓練方法的流程圖,如圖1所示,所述方法包括:
在步驟S11中,將第一隨機向量輸入生成網路,獲得第一生成圖像;
在步驟S12中,將所述第一生成圖像和第一真實圖像分別輸入判別網路,分別獲得所述第一生成圖像的第一判別分布與第一真實圖像的第二判別分布,其中,所述第一判別分布表示所述第一生成圖像的真實程度的機率分布,所述第二判別分布表示所述第一真實圖像的真實程度的機率分布;
在步驟S13中,根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,其中,所述第一目標分布爲生成圖像的目標機率分布,所述第二目標分布爲真實圖像的目標機率分布;
在步驟S14中,根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失;
在步驟S15中,根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路。
根據本發明的實施例的神經網路訓練方法,判別網路可針對輸入圖像輸出判別分布,以機率分布的形式描述輸入圖像的真實性,可從顔色、紋理、比例、背景等維度描述輸入圖像爲真實圖像的機率,可從多個方面考量輸入圖像的真實性,減少訊息丟失,爲神經網路訓練提供更全面的監測訊息以及更準確的訓練方向,提高訓練精確度,最終提高生成圖像的品質,使得生成網路可適用於生成高清圖像。並且,預設了生成圖像的目標機率分布以及真實圖像的目標機率分布來指導訓練過程,在訓練過程中引導使真實圖像和生成圖像接近各自的目標機率分布,增大真實圖像和生成圖像的區分度,增强判別網路區分真實圖像和生成圖像的能力,進而提升生成網路生成的圖像的品質。
在一種可能的實現方式中,所述神經網路訓練方法可以由終端設備或其它處理設備執行,其中,終端設備可以爲用戶設備(User Equipment,UE)、移動設備、用戶終端、終端、行動電話、無線電話、個人數位助理(Personal Digital Assistant,PDA)、手持設備、計算設備、車載設備、可穿戴設備等。其它處理設備可爲伺服器或雲端伺服器等。在一些可能的實現方式中,該神經網路訓練方法可以通過處理器調用記憶體中儲存的電腦可讀指令的方式來實現。
在一種可能的實現方式中,所述神經網路可以是由生成網路和判別網路組成的生成對抗網路。生成網路可以是卷積神經網路等深度學習神經網路,本發明對生成網路的類型和結構不做限制。判別網路可以是卷積神經網路等深度學習神經網路,本發明對判別網路的類型和結構不做限制。生成網路可對隨機向量進行處理,獲得生成圖像,隨機向量可以是各元素爲隨機數的向量,可通過隨機採樣等方式獲得。在步驟S11中,可通過隨機採樣等方式獲得第一隨機向量,生成網路可對第一隨機向量進行卷積等處理,獲得與第一隨機向量對應的第一生成圖像。第一隨機向量是隨機生成的向量,因此,第一生成圖像爲隨機圖像。
在一種可能的實現方式中,第一真實圖像可以是任意的真實圖像,例如,可以是圖像獲取裝置(例如,相機、攝影機等)拍攝到的真實圖像。在步驟S12中,可將第一真實圖像和第一生成圖像分別輸入判別網路,分別獲得第一生成圖像的第一判別分布和第一真實圖像的第二判別分布,第一判別分布和第二判別分布可以是向量形式的參數,例如,可以用向量的形式表示機率分布。第一判別分布可表示第一生成圖像的真實程度,即,可通過第一判別分布來描述第一生成圖像是真實圖像的機率。第二判別分布可表示第一真實圖像的真實程度,即,可通過第二判別分布來描述第一真實圖像是真實圖像的機率。以分布(如多維向量)的形式描述圖像的真實性,可從顔色、紋理、比例、背景等多個方面考量圖像的真實性,減少訊息丟失,爲訓練提供準確的訓練方向。
在一種可能的實現方式中,在步驟S13中,可預設真實圖像的目標機率分布(即,第二目標分布),以及生成圖像的目標機率分布(即,第一目標分布),在訓練過程中,可根據真實圖像的目標機率分布以及生成圖像的目標機率分布分別確定生成圖像對應的網路損失和真實圖像對應的網路損失,並分別利用生成圖像對應的網路損失和真實圖像對應的網路損失調整判別網路的參數,使真實圖像的第二判別分布接近第二目標分布,且與第一目標分布有顯著差別,並使生成圖像的第一判別分布接近第一目標分布,且與第二目標分布有顯著差別,可增大真實圖像和生成圖像的區分度,增强判別網路區分真實圖像和生成圖像的能力,進而提升生成網路生成的圖像的品質。
在示例中,可預設生成圖像的錨(anchor)分布(即,第一目標分布)和真實圖像的錨分布(即,第二目標分布),表示生成圖像的錨分布的向量與表示真實圖像的錨分布的向量具有顯著差異。例如,預設一個固定分布U,設定第一目標分布A1=U,第二目標分布A2=U+1。在訓練過程中,可通過調整判別網路的網路參數,使得第一判別分布與生成圖像的錨分布的差異縮小,在此過程中,第一判別分布與真實圖像的錨分布的差異將增大。訓練過程中,通過調整判別網路的網路參數,還使得第二判別分布與真實圖像的錨分布的差異縮小,在此過程中,第二判別分布與生成圖像的錨分布的差異將增大。即,對真實圖像和生成圖像分別預設錨分布,使真實圖像和生成圖像的分布差異增大,從而提升判別網路對真實圖像和生成圖像的區分能力。
在一種可能的實現方式中,步驟S13可包括:根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失;根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失;根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失。
在示例中,第一目標分布爲準確的機率分布,可確定第一目標分布和第一判別分布之間的差異,從而確定第一分布損失。
在一種可能的實現方式中,可根據第一判別分布和第一目標分布,確定第一生成圖像對應的網路損失(即,第一分布損失)。其中,根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失,包括:將所述第一判別分布映射到所述第一目標分布的支撑集,獲得第一映射分布;確定所述第一映射分布與所述第一目標分布的第一相對熵;根據所述第一相對熵,確定所述第一分布損失。
在一種可能的實現方式中,第一判別分布和第一目標分布的支撑集(所述支撑集爲表示機率分布的分布範圍的拓撲空間)可能不同,即,第一判別分布的分布範圍與第一目標分布的分布範圍不同。在分布範圍不同時,比較兩種機率分布的差異沒有意義,因此,可將第一判別分布映射到第一目標分布的支撑集,或將第一目標分布映射到第一判別分布的支撑集,又或者將第一判別分布和第一目標分布映射到同一個支撑集,即,使得第一判別分布的分布範圍與第一目標分布的分布範圍相同,可在相同的分布範圍中比較兩種機率分布的差異。在示例中,可通過線性變換等方式,例如利用投影矩陣對第一判別分布進行投影處理,將第一判別分布映射到第一目標分布的支撑集,即,可對第一判別分布的向量進行線性變換,變換後獲得的向量即爲映射到第一目標分布的支撑集後的第一映射分布。
在一種可能的實現方式中,可確定第一映射分布與第一目標分布的第一相對熵,即,KL(Kullback-Leibler)距離,所述第一相對熵可表示相同支撑集中的兩個機率分布的差異(即,第一映射分布與第一目標分布的差異)。當然,在其他實施方式中,也可以通過JS散度(Jensen-Shannon divergence)或Wasserstein距離等其他方式確定第一映射分布與第一目標分布的差異。
在一種可能的實現方式中,可根據第一相對熵,確定第一分布損失(即,生成圖像對應的網路損失)。在示例中,可將第一相對熵確定爲所述第一分布損失,或對第一相對熵進行運算處理,例如,對第一相對熵進行加權、取對數、取指數等處理,獲得所述第一分布損失。本發明對第一分布損失的確定方式不做限制。
在示例中,第二目標分布爲準確的機率分布,可確定第二目標分布和第二判別分布之間的差異,從而確定第二分布損失。
在一種可能的實現方式中,可根據第二判別分布和第二目標分布,確定第一真實圖像對應的網路損失(即,第二分布損失)。其中,根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失,包括:將所述第二判別分布映射到所述第二目標分布的支撑集,獲得第二映射分布;確定所述第二映射分布與所述第二目標分布的第二相對熵;根據所述第二相對熵,確定所述第二分布損失。
在一種可能的實現方式中,第二判別分布和第二目標分布的支撑集(所述支撑集爲表示機率分布的分布範圍的拓撲空間)可能不同,即,第二判別分布的分布範圍與第二目標分布的分布範圍不同。可將第二判別分布映射到第二目標分布的支撑集,或將第二目標分布映射到第二判別分布的支撑集,又或者將第二判別分布和第二目標分布映射到同一個支撑集,使得第二判別分布的分布範圍與第二目標分布的分布範圍相同,可在相同的分布範圍中比較兩種機率分布的差異。在示例中,可通過線性變換等方式,例如利用投影矩陣對第二判別分布進行投影處理,將第二判別分布映射到第二目標分布的支撑集,即,可對第二判別分布的向量進行線性變換,變換後獲得的向量即爲映射到第二目標分布的支撑集後的第二映射分布。
在一種可能的實現方式中,可確定第二映射分布與第二目標分布的第二相對熵,所述第二相對熵可表示相同支撑集中的兩個機率分布的差異(即,第二映射分布與第二目標分布的差異)。其中,第二相對熵的計算方法與第一相對熵類似,此處不再重複。當然,在其他實施方式中,也可以通過JS散度(Jensen-Shannon divergence)或Wasserstein距離等其他方式確定第二映射分布與第二目標分布的差異。
在一種可能的實現方式中,可根據第二相對熵,確定第二分布損失(即,生成圖像對應的網路損失)。在示例中,可將第二相對熵確定爲所述第二分布損失,或對第二相對熵進行運算處理,例如,對第二相對熵進行加權、取對數、取指數等處理,獲得所述第二分布損失。本發明對第二分布損失的確定方式不做限制。
在一種可能的實現方式中,可根據第一生成圖像的第一分布損失和第二生成圖像的第二分布損失來確定第一網路損失。其中,根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失,包括:對所述第一分布損失和所述第二分布損失進行加權求和處理,獲得所述第一網路損失。在示例中,第一分布損失和第二分布損失的權重可相同,即,將第一分布損失和第二分布損失直接求和,可獲得第一網路損失。或者,第一分布損失和第二分布損失的權重可不同,即,將第一分布損失和第二分布損失分別乘以各自的權重後再進行求和,可獲得第一網路損失。第一分布損失和第二分布損失的權重可以是預設的,本發明對第一分布損失和第二分布損失的權重不做限制。
通過這種方式,預設了生成圖像的目標機率分布以及真實圖像的目標機率分布來指導訓練過程,並分別確定各自的分布損失,在訓練過程中引導使真實圖像和生成圖像接近各自的目標機率分布,增大真實圖像和生成圖像的區分度,爲判別網路提供了更準確的角度訊息,爲判別網路提供更準確的訓練方向,增强判別網路區分真實圖像和生成圖像的能力,進而提升生成網路生成的圖像的品質。
在一種可能的實現方式中,還可確定生成網路的第二網路損失。在示例中,判別網路需要判別輸入圖像爲真實圖像還是生成圖像,因此,判別網路在訓練過程中可增强對真實圖像和生成圖像的區分能力,即,使真實圖像和生成圖像的判別分布接近各自的目標機率分布,從而增大真實圖像和生成圖像的區分度。然而,生成網路的目標爲使生成圖像接近真實圖像,即,使生成圖像足夠逼真,使得判別網路難以辨別出生成網路輸出的生成圖像。在對抗訓練達到平衡狀態時,判別網路和生成網路的性能都較强,即,判別網路的判別能力很强,能夠分辨出真實圖像和逼真程度較低的生成圖像,而生成網路生成的圖像逼真程度很高,使判別網路難以分辨出高品質的生成圖像。在對抗訓練中,判別網路性能提升可促進生成網路的性能提升,即,判別網路分辨真實圖像和生成圖像的能力越强,則會促使生成網路生成的圖像逼真程度越高。
生成網路的訓練目的爲提高生成圖像的逼真程度,即,使得生成圖像接近真實圖像。也就是說,生成網路的訓練可以使第一生成圖像的第一判別分布與第一真實圖像的第二判別分布接近,從而使得判別網路難以辨別。在一種可能的實現方式中,步驟S14可包括:確定所述第一判別分布與所述第二判別分布的第三相對熵;根據所述第三相對熵,確定所述第二網路損失。
在一種可能的實現方式中,可確定第一判別分布與第二判別分布的第三相對熵,所述第三相對熵表示相同支撑集中的兩個機率分布的差異(即,第三映射分布與第四映射分布的差異)。其中,第三相對熵的計算方法與第一相對熵類似,此處不再重複。當然,在其他實施方式中,也可以通過JS散度(Jensen-Shannon divergence)或Wasserstein距離等其他方式確定第一判別分布與第二判別分布的差異,以通過二者差異確定生成網路的網路損失。
在一種可能的實現方式中,可根據第三相對熵,確定第二網路損失。在示例中,可將第三相對熵確定爲第二網路損失,或對第三相對熵進行運算處理,例如,對第三相對熵進行加權、取對數、取指數等處理,獲得第二網路損失。本發明對第二網路損失的確定方式不做限制。
在一種可能的實現方式中,第一判別分布與第二判別分布的支撑集不同,即,第一判別分布與第二判別分布的分布範圍可不同。可經過線性變換使第一判別分布與第二判別分布的支撑集重合,例如,可將第一判別分布與第二判別分布映射到目標支撑集,使得第二判別分布的分布範圍與第一判別分布的分布範圍相同,可在相同的分布範圍中比較兩種機率分布的差異。
在示例中,所述目標支撑集是所述第一判別分布的支撑集或所述第二判別分布的支撑集。可通過線性變換等方式,將第二判別分布映射到第一判別分布的支撑集,即,可對第二判別分布的向量進行線性變換,變換後獲得的向量即爲映射到第一判別分布的支撑集後的第四映射分布,並將第一判別分布作爲所述第三映射分布。或者,可通過線性變換等方式,將第一判別分布映射到第二判別分布的支撑集,即,可對第一判別分布的向量進行線性變換,變換後獲得的向量即爲映射到第二判別分布的支撑集後的第三映射分布,並將第二判別分布作爲所述第四映射分布。
在示例中,所述目標支撑集也可以是其他支撑集,例如,可預設一支撑集,並將第一判別分布和第二判別分布均映射到該支撑集,分別獲得第三映射分布和第四映射分布。進一步地,可計算第三映射分布和第四映射分布的第三相對熵。本發明對目標支撑集不做限制。
通過這種方式,可通過減小第一判別分布與第二判別分布的差異的方式訓練生成網路,使得判別網路性能提高的同時,促進生成網路的性能提高,從而生成逼真程度較高的生成圖像,使得生成網路可適用於生成高清圖像。
在一種可能的實現方式中,可根據判別網路的第一網路損失和生成網路的第二網路損失,對抗訓練生成網路和判別網路。即,通過訓練,使生成網路和判別網路的性能同時提高,提高判別網路的分辨能力,且提高生成網路生成逼真度較高的生成圖像的能力,且使生成網路和判別網路達到平衡狀態。
可選地,步驟S15可包括:根據所述第一網路損失,調整所述判別網路的網路參數;根據所述第二網路損失,調整所述生成網路的網路參數;在所述判別網路和所述生成網路滿足訓練條件的情況下,獲得訓練後的所述生成網路和所述判別網路。
在訓練過程中,由於網路參數的複雜程度不同等因素,判別網路的訓練進度通常領先於生成網路,而如果判別網路進度較快,提前訓練完成,則無法爲生成網路提供反向傳播中的梯度,進而無法更新生成網路的參數,即,無法提升生成網路的性能。因此,生成網路生成的圖像的性能受到限制,不適用於生成高清的圖像,且逼真度較低。
在一種可能的實現方式中,可限制在判別網路的訓練過程中,用於調整判別網路的網路參數的梯度。其中,根據所述第一網路損失,調整所述判別網路的網路參數,包括:將第二隨機向量輸入生成網路,獲得第二生成圖像;根據所述第二生成圖像對第二真實圖像進行插值處理,獲得插值圖像;將所述插值圖像輸入所述判別網路,獲得所述插值圖像的第三判別分布;根據所述第三判別分布,確定所述判別網路的網路參數的梯度;在所述梯度大於梯度閾值的情況下,根據所述第三判別分布確定梯度懲罰參數;根據所述第一網路損失和所述梯度懲罰參數,調整所述判別網路的網路參數。
在一種可能的實現方式中,可通過隨機採樣等方式獲得第二隨機向量,並輸入生成網路,獲得第二生成圖像,即,獲得一張非真實圖像。也可通過其他方式獲得第二生成圖像,例如,可直接隨機生成一張非真實圖像。
在一種可能的實現方式中,可將第二生成圖像和第二真實圖像進行插值處理,獲得插值圖像,即,插值圖像爲真實圖像與非真實圖像的合成圖像,在插值圖像中,包括部分真實圖像,也包括部分非真實圖像。在示例中,可對第二真實圖像和第二生成圖像進行隨機非線性插值,獲得所述插值圖像,本發明對插值圖像的獲得方式不做限制。
在一種可能的實現方式中,可將插值圖像輸入判別網路,獲得插值圖像的第三判別分布,即,判別網路可針對該真實圖像與非真實圖像的合成圖像進行判別處理,獲得第三判別分布。
在一種可能的實現方式中,可利用第三判別分布來確定判別網路的網路參數的梯度,例如,可預設插值圖像的目標機率分布(例如,可表示插值圖像爲真實圖像的機率爲50%的目標機率分布),並利用第三判別分布和目標機率分布的相對熵來確定判別網路的網路參數的梯度。例如,可將第三判別分布和目標機率分布的相對熵進行反向傳播,計算該相對熵與判別網路的各網路參數的偏微分,從而獲得網路參數的梯度。當然,在其他可能的實現方式中,也可以利用第三判別分布和目標機率分布的JS散度等其他類型的差異,確定判別網路的參數梯度。
在一種可能的實現方式中,如果判別網路的網路參數的梯度大於或等於預設的梯度閾值,則可根據第三判別分布確定梯度懲罰參數。梯度閾值可以是對梯度進行限制的閾值,如果梯度較大,則在訓練過程中,梯度的下降速度可能較快(即,訓練步長較大,網路損失趨於最小值的速度較快),因此,可通過梯度閾值對梯度進行限制。在示例中,梯度閾值可設爲10、20等,本發明對梯度閾值不做限制。
在示例中,通過梯度懲罰參數對超過梯度閾值的網路參數的梯度進行調整,或對梯度下降速度進行限制,使得該參數的梯度較平緩,梯度下降速度減慢。例如,可根據第三判別分布的期望值確定梯度懲罰參數。梯度懲罰參數可以是對梯度下降的補償參數,例如,可通過梯度懲罰參數調整偏微分的乘數,或通過梯度懲罰參數改變梯度下降的方向,以對梯度進行限制,從而減小判別網路的網路參數的梯度下降速度,防止判別網路的梯度下降過快,造成判別網路過早收斂(即,過快訓練完成)。在示例中,第三判別分布爲機率分布,可計算該機率分布的期望值,並根據期望值確定所述梯度懲罰參數,例如,可將所述期望值確定爲網路參數的偏微分的乘數,即,將期望值確定爲梯度懲罰參數,並將梯度懲罰參數作爲梯度的乘數,本發明對梯度懲罰參數的確定方式不做限制。
在一種可能的實現方式中,可根據第一網路損失和梯度懲罰參數,調整判別網路的網路參數。即,在對第一網路損失進行反向傳播使得梯度下降的過程中,加入梯度懲罰參數,在調整判別網路的網路參數的同時,防止梯度下降過快,即,防止判別網路過早訓練完成。例如,可將梯度懲罰參數作爲偏微分的乘數,即梯度的乘數,以減緩梯度下降速度,防止判別網路過早訓練完成。
在一種可能的實現方式中,如果判別網路的網路參數的梯度小於預設的梯度閾值,則可根據第一網路損失調整判別網路的網路參數,即,對第一網路損失進行反向傳播使梯度下降,使得第一網路損失減小。
在一種可能的實現方式中,可在調整判別網路的網路參數時,對判別網路的梯度是否大於或等於梯度閾值進行檢驗,在判別網路的梯度大於或等於梯度閾值的情況下設置梯度懲罰參數。也可不檢驗判別網路的梯度,而通過其他方式控制判別網路的訓練進度(例如,暫停判別網路的網路參數的調整,僅調整生成網路的網路參數等)。
通過這種方式,可通過檢測判別網路的網路參數的梯度是否大於或等於梯度閾值,來限制判別網路在訓練中的梯度下降速度,從而限制判別網路的訓練進度,減少判別網路出現梯度消失的機率,從而可持續優化生成網路,提高生成網路的性能,使生成網路生成圖像的逼真程度較高,且適用於生成高清圖像。
在一種可能的實現方式中,可根據第二網路損失調整生成網路的網路參數,例如,對第二網路損失進行反向傳播使梯度下降,使得第二網路損失減小,以提升生成網路的性能。
在一種可能的實現方式中,可對抗訓練判別網路和生成網路,在通過第一網路損失調整判別網路的網路參數時,保持生成網路的網路參數保持不變,在通過第二網路損失調整生成網路的網路參數時,保持判別網路的網路參數保持不變。可疊代執行上述訓練過程,直到判別網路和生成網路滿足訓練條件,在示例中,所述訓練條件包括判別網路和生成網路達到平衡狀態,例如,判別網路和生成網路的網路損失均小於或等於預設閾值,或收斂於預設區間。或者,所述訓練條件包括以下兩個條件達到平衡狀態:第一,生成網路的網路損失小於或等於預設閾值或收斂於預設區間,第二,判別網路輸出的判別分布表示的輸入圖像爲真實圖像的機率最大化。此時,判別網路分辨真實圖像和生成圖像的能力較强,生成網路生成的圖像品質較高,逼真度較高。
在一種可能的實現方式中,除檢驗判別網路的梯度是否大於或等於梯度閾值之外,還可通過控制判別網路的訓練進度的方式,減小判別網路出現梯度消失的機率。
在一種可能的實現方式中,可在任意訓練周期結束後,檢查判別網路和生成網路的訓練進度。具體地,步驟S15可包括:將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像;將與所述至少一個歷史訓練周期中輸入生成網路的第一隨機向量對應的第一生成圖像、至少一個所述第三生成圖像以及至少一個真實圖像分別輸入當前訓練周期的判別網路,分別獲得至少一個第一生成圖像的第四判別分布、至少一個第三生成圖像的第五判別分布和至少一個真實圖像的第六判別分布;根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數;在所述訓練進度參數小於或等於訓練進度閾值的情況下,停止調整所述判別網路的網路參數,僅調整所述生成網路的網路參數。
在一種可能的實現方式中,可在訓練過程中開闢一個緩存區,例如,經驗緩存區(experience buffer),在該緩存區中,可保存至少一個(例如,M個,M爲正整數)歷史訓練周期的第一隨機向量以及上述M個歷史訓練周期中生成網路根據第一隨機向量生成的M個第一生成圖像,即,每個歷史訓練周期均可通過一個第一隨機向量生成一個第一生成圖像,在緩存區中,可保存M個歷史訓練周期的第一隨機向量,以及生成的M個第一生成圖像。隨著訓練的進行,在訓練周期數超過M時,可使用最新的訓練周期的第一隨機向量和第一生成圖像代替最早存入緩存區的第一隨機向量和第一生成圖像。
在一種可能的實現方式中,可將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像,例如,可將緩存區中的m(m小於或等於M,且m爲正整數)個第一隨機向量輸入當前訓練周期的生成網路,獲得m個第三生成圖像。
在一種可能的實現方式中,可通過當前訓練周期的判別網路分別對m個第三生成圖像進行判別處理,獲得m個第五判別分布。可通過當前訓練周期的判別網路分別對m個歷史訓練周期的第一生成圖像進行判別處理,獲得m個第四判別分布。並可從數據庫中隨機採樣得到m個真實圖像,並通過當前訓練周期的判別網路分別對m個真實圖像進行判別處理,獲得m個第六判別分布。
在一種可能的實現方式中,可根據m個第四判別分布、m個第五判別分布和m個第六判別分布來確定當前訓練周期的生成網路的訓練進度參數,即,確定判別網路的訓練進度是否顯著領先於生成網路,並在確定顯著領先的情況下,調整生成網路的訓練進度參數,以提高生成網路的訓練進度,降低判別網路和生成網路的訓練進度差異,即,暫停判別網路的訓練,單獨訓練生成網路,使生成網路的進度參數提高,進度加快。
在一種可能的實現方式中,根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數,包括:分別獲取至少一個所述第四判別分布的第一期望值、至少一個所述第五判別分布的第二期望值以及至少一個所述第六判別分布的第三期望值;分別獲取所述至少一個所述第一期望值的第一平均值、至少一個所述第二期望值的第二平均值以及至少一個所述第三期望值的第三平均值;確定所述第三平均值與所述第二平均值的第一差值以及所述第二平均值與所述第一平均值的第二差值;將所述第一差值與所述第二差值的比值確定爲所述當前訓練周期的生成網路的訓練進度參數。
在一種可能的實現方式中,可分別計算m個第四判別分布的期望值,獲得m個的第一期望值,可分別計算m個第五判別分布的期望值,獲得m個的第二期望值,並分別計算m個第六判別分布的期望值,獲得m個的第三期望值。進一步地,可對m個的第一期望值進行平均處理,獲得第一平均值SB
,可對m個的第二期望值進行平均處理,獲得第二平均值SG
,並可對m個的第三期望值進行平均處理,獲得第三平均值SR
。
在一種可能的實現方式中,可確定第三平均值與第二平均值的第一差值(SR
-SG
),並確定第二平均值與第一平均值的第二差值(SG
-SB
)。進一步地,可將第一差值與第二差值的比值(SR
-SG
)/(SG
-SB
)確定爲所述當前訓練周期的生成網路的訓練進度參數。在另一示例中,還可將預設訓練次數作爲生成網路的訓練進度參數,例如,可使生成網路和判別網路每共同訓練100次,暫停判別網路訓練,並單獨訓練生成網路50次,之後再使生成網路和判別網路每共同訓練100次……直到生成網路和判別網路滿足訓練條件。
在一種可能的實現方式中,可設定訓練進度閾值,所述訓練進度閾值爲確定生成網路訓練進度的閾值,如果訓練進度參數小於或等於訓練進度閾值,則表明判別網路的訓練進度顯著領先於生成網路,即,生成網路的訓練進度較慢,可暫停調整判別網路的網路參數,僅調整生成網路的網路參數。在示例中,可在接下來的訓練周期中,重複執行以上檢查判別網路和生成網路的訓練進度,直到訓練進度參數大於訓練進度閾值,則可同時調整判別網路和生成網路的網路參數,即,使判別網路的訓練暫停至少一個訓練周期,僅訓練生成網路(即,僅根據第三網路損失調整生成網路的網路參數,保持判別網路的網路參數不變),直到生成網路的訓練進度接近判別網路的訓練進度,再對抗訓練生成網路和判別網路。
在其他實現方式中,也可以在訓練進度參數小於或等於訓練進度閾值的情況下,降低判別網路的訓練速度,例如延長判別網路的訓練周期或降低判別網路的梯度下降速度等,直到訓練進度參數大於訓練進度閾值,則可恢復判別網路的訓練速度。
通過這種方式,可通過檢查判別網路和生成網路的訓練進度,來限制判別網路在訓練中的梯度下降速度,從而限制判別網路的訓練進度,減少判別網路出現梯度消失的機率,從而可持續優化生成網路,提高生成網路的性能,使生成網路生成圖像的逼真程度較高,且適用於生成高清圖像。
在一種可能的實現方式中,在生成網路和判別網路的對抗訓練完成後,即,生成網路和判別網路的性能較好時,可使用生成網路生成圖像,生成的圖像逼真度較高。
本發明還提供一種圖像生成方法,使用上述訓練完成的生成對抗網路生成圖像。
在本發明的一些實施例中,一種圖像生成方法包括:獲取第三隨機向量;將第三隨機向量輸入上述神經網路訓練方法訓練後獲得的生成網路進行處理,獲得目標圖像。
在示例中,可通過隨機採樣等方式獲得第三隨機向量,並將第三隨機向量輸入訓練後的生成網路。生成網路可輸出逼真度較高的目標圖像。在示例中,所述目標圖像可以是高清圖像,即,訓練後的生成網路可適用於生成逼真度較高的高清圖像。
根據本發明的實施例的神經網路訓練方法,判別網路可針對輸入圖像輸出判別分布,以分布的形式描述輸入圖像的真實性,從多個方面考量輸入圖像的真實性,減少訊息丟失,爲神經網路訓練提供更全面的監測訊息以及更準確的訓練方向,提高訓練精確度,提高生成圖像的品質,使得生成網路可適用於生成高清圖像。並且預設了生成圖像的目標機率分布以及真實圖像的目標機率分布來指導訓練過程,並分別確定各自的分布損失,在訓練過程中引導使真實圖像和生成圖像接近各自的目標機率分布,增大真實圖像和生成圖像的區分度,增强判別網路區分真實圖像和生成圖像的能力,並通過減小第一判別分布與第二判別分布的差異的方式訓練生成網路,使得判別網路性能提高的同時,促進生成網路的性能提高,從而生成逼真程度較高的生成圖像,使得生成網路可適用於生成高清圖像。進一步地,還可通過檢測判別網路的網路參數的梯度是否大於或等於梯度閾值,或檢查判別網路和生成網路的訓練進度,來限制判別網路在訓練中的梯度下降速度,從而限制判別網路的訓練進度,減少判別網路出現梯度消失的機率,從而可持續優化生成網路,提高生成網路的性能,使生成網路生成圖像的逼真程度較高,且適用於生成高清圖像。
圖2示出根據本發明實施例的神經網路訓練方法的應用示意圖,如圖2所示,可將第一隨機向量輸入生成網路,生成網路可輸出第一生成圖像。判別網路可將第一生成圖像和第一真實圖像分別進行判別處理,分別獲得第一生成圖像的第一判別分布和第一真實圖像的第二判別分布。
在一種可能的實現方式中,可預設生成圖像的錨分布(即,第一目標分布)和真實圖像的錨分布(即,第二目標分布)。可根據第一判別分布和第一目標分布,確定第一生成圖像對應的第一分布損失。並可根據第二判別分布和第二目標分布,確定第一真實圖像對應的第二分布損失。進一步地,可通過第一分布損失和第二分布損失確定判別網路的第一網路損失。
在一種可能的實現方式中,可通過第一判別分布和第二判別分布確定生成網路的第二網路損失。進一步地,可通過第一網路損失和第二網路損失對抗訓練生成網路和判別網路。即,通過第一網路損失調整判別網路的網路參數,以及通過第二網路損失調整生成網路的網路參數。
在一種可能的實現方式中,判別網路的訓練進度通常比生成網路更快,爲降低判別網路提前訓練完成導致梯度消失的機率,從而造成生成網路無法繼續優化。可通過檢測判別網路的梯度,來控制判別網路的訓練進度,在示例中,可對一張真實圖像和生成圖像進行插值,並通過判別網路來確定該插值圖像的第三判別分布,進而根據第三判別分布的期望值確定梯度懲罰參數,如果判別網路的梯度大於或等於預設的梯度閾值,爲防止判別網路的梯度下降過快,造成判別網路過快訓練完成,可在對第一網路損失進行反向傳播使得梯度下降的過程中,加入梯度懲罰參數,以限制判別網路的梯度下降速度。
在一種可能的實現方式中,還可檢查判別網路和生成網路的訓練進度,例如,可將M個歷史訓練周期中輸入生成網路的M個第一隨機向量輸入當前訓練周期的生成網路,獲得M個第三生成圖像。並根據M個歷史訓練周期中生成的第一生成圖像、M個第三生成圖像和M個真實圖像來確定當前訓練周期的生成網路的訓練進度參數。如果訓練進度參數小於或等於訓練進度閾值,則表明判別網路的訓練進度顯著領先於生成網路,可暫停調整判別網路的網路參數,僅調整生成網路的網路參數。並在接下來的訓練周期中,重複執行以上檢查判別網路和生成網路的訓練進度,直到訓練進度參數大於訓練進度閾值,方可同時調整判別網路和生成網路的網路參數,即,使判別網路的訓練暫停至少一個訓練周期,僅訓練生成網路。
在一種可能的實現方式中,在生成網路和判別網路的對抗訓練完成後,可使用生成網路生成目標圖像,目標圖像可以是逼真度較的高清圖像。
在一種可能的實現方式中,所述神經網路訓練方法可增强生成對抗的穩定性和生成圖像的品質和逼真度。可適用於遊戲中場景的生成或合成、圖像風格的遷移或轉換,以及圖像聚類等場景,本發明對所述神經網路訓練方法的使用場景不做限制。
圖3示出根據本發明實施例的神經網路訓練裝置的方塊圖,如圖3所示,所述裝置包括:
生成模組11,用於將第一隨機向量輸入生成網路,獲得第一生成圖像;
判別模組12,用於將所述第一生成圖像和第一真實圖像分別輸入判別網路,分別獲得所述第一生成圖像的第一判別分布與第一真實圖像的第二判別分布,其中,所述第一判別分布表示所述第一生成圖像的真實程度的機率分布,所述第二判別分布表示所述第一真實圖像的真實程度的機率分布;
第一確定模組13,用於根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,其中,所述第一目標分布爲生成圖像的目標機率分布,所述第二目標分布爲真實圖像的目標機率分布;
第二確定模組14,用於根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失;
訓練模組15,用於根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失;
根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失;
根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
將所述第一判別分布映射到所述第一目標分布的支撑集,獲得第一映射分布;
確定所述第一映射分布與所述第一目標分布的第一相對熵;
根據所述第一相對熵,確定所述第一分布損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
將所述第二判別分布映射到所述第二目標分布的支撑集,獲得第二映射分布;
確定所述第二映射分布與所述第二目標分布的第二相對熵;
根據所述第二相對熵,確定所述第二分布損失。
在一種可能的實現方式中,所述第一確定模組被進一步配置爲:
對所述第一分布損失和所述第二分布損失進行加權求和處理,獲得所述第一網路損失。
在一種可能的實現方式中,所述第二確定模組被進一步配置爲:
確定所述第一判別分布與所述第二判別分布的第三相對熵;
根據所述第三相對熵,確定所述第二網路損失。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
根據所述第一網路損失,調整所述判別網路的網路參數;
根據所述第二網路損失,調整所述生成網路的網路參數;
在所述判別網路和所述生成網路滿足訓練條件的情況下,獲得訓練後的所述生成網路和所述判別網路。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
將第二隨機向量輸入生成網路,獲得第二生成圖像;
根據所述第二生成圖像對第二真實圖像進行插值處理,獲得插值圖像;
將所述插值圖像輸入所述判別網路,獲得所述插值圖像的第三判別分布;
根據所述第三判別分布,確定所述判別網路的網路參數的梯度;
在所述梯度大於或等於梯度閾值的情況下,根據所述第三判別分布確定梯度懲罰參數;
根據所述第一網路損失和所述梯度懲罰參數,調整所述判別網路的網路參數。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像;
將與所述至少一個歷史訓練周期中輸入生成網路的第一隨機向量對應的第一生成圖像、至少一個所述第三生成圖像以及至少一個真實圖像分別輸入當前訓練周期的判別網路,分別獲得至少一個第一生成圖像的第四判別分布、至少一個第三生成圖像的第五判別分布和至少一個真實圖像的第六判別分布;
根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數;
在所述訓練進度參數小於或等於訓練進度閾值的情況下,停止調整所述判別網路的網路參數,僅調整所述生成網路的網路參數。
在一種可能的實現方式中,所述訓練模組被進一步配置爲:
分別獲取至少一個所述第四判別分布的第一期望值、至少一個所述第五判別分布的第二期望值以及至少一個所述第六判別分布的第三期望值;
分別獲取所述至少一個所述第一期望值的第一平均值、至少一個所述第二期望值的第二平均值以及至少一個所述第三期望值的第三平均值;
確定所述第三平均值與所述第二平均值的第一差值以及所述第二平均值與所述第一平均值的第二差值;
將所述第一差值與所述第二差值的比值確定爲所述當前訓練周期的生成網路的訓練進度參數。
本發明還提供一種圖像生成裝置,使用上述訓練完成的生成對抗網路生成圖像。
在本發明的一些實施例中,一種圖像生成裝置包括:
獲取模組,用於獲取第三隨機向量;
獲得模組,用於將所述第三隨機向量輸入訓練後獲得的生成網路進行處理,獲得目標圖像。
可以理解,本發明提及的上述各個方法實施例,在不違背原理邏輯的情況下,均可以彼此相互結合形成結合後的實施例,限於篇幅,本發明不再贅述。
此外,本發明還提供了神經網路訓練裝置、電子設備、電腦可讀儲存媒體、程式,上述均可用來實現本發明提供的任一種神經網路訓練方法,相應技術方案和描述和參見方法部分的相應記載,不再贅述。本領域技術人員可以理解,在具體實施方式的上述方法中,各步驟的撰寫順序並不意味著嚴格的執行順序而對實施過程構成任何限定,各步驟的具體執行順序應當以其功能和可能的內在邏輯確定。在一些實施例中,本發明實施例提供的裝置具有的功能或包含的模組可以用於執行上文方法實施例描述的方法,其具體實現可以參照上文方法實施例的描述,爲了簡潔,這裏不再贅述。
本發明實施例還提出一種電腦可讀儲存媒體,其上儲存有電腦程式指令,所述電腦程式指令被處理器執行時實現上述方法。電腦可讀儲存媒體可以是揮發性電腦可讀儲存媒體或非揮發性電腦可讀儲存媒體。
本發明實施例還提出一種電子設備,包括:處理器;用於儲存處理器可執行指令的記憶體;其中,所述處理器被配置爲上述方法。電子設備可以被提供爲終端、伺服器或其它形態的設備。
圖4是根據一示例性實施例示出的一種電子設備800的方塊圖。例如,電子設備800可以是行動電話,電腦,數位廣播終端,訊息收發設備,遊戲控制台,平板設備,醫療設備,健身設備,個人數位助理等終端。
參照圖4,電子設備800可以包括以下一個或多個組件:處理組件802,記憶體804,電源組件806,多媒體組件808,音訊組件810,輸入/輸出(I/O)的介面812,感測器組件814,以及通訊組件816。
處理組件802通常控制電子設備800的整體操作,諸如與顯示,電話呼叫,數據通訊,相機操作和記錄操作相關聯的操作。處理組件802可以包括一個或多個處理器820來執行指令,以完成上述的方法的全部或部分步驟。此外,處理組件802可以包括一個或多個模組,便於處理組件802和其他組件之間的交互。例如,處理組件802可以包括多媒體模組,以方便多媒體組件808和處理組件802之間的交互。
記憶體804被配置爲儲存各種類型的數據以支持在電子設備800的操作。這些數據的示例包括用於在電子設備800上操作的任何應用程式或方法的指令,連絡人數據,電話簿數據,訊息,圖片,視訊等。記憶體804可以由任何類型的揮發性或非揮發性儲存設備或者它們的組合實現,如靜態隨機存取記憶體(SRAM),電子可抹除可程式化唯讀記憶體(EEPROM),可抹除可程式化唯讀記憶體(EPROM),可程式化唯讀記憶體(PROM),唯讀記憶體(ROM),磁記憶體,快閃記憶體,磁碟或光碟。
電源組件806爲電子設備800的各種組件提供電力。電源組件806可以包括電源管理系統,一個或多個電源,及其他與爲電子設備800生成、管理和分配電力相關聯的組件。
多媒體組件808包括在所述電子設備800和用戶之間的提供一個輸出介面的螢幕。在一些實施例中,螢幕可以包括液晶顯示器(LCD)和觸控面板(TP)。如果螢幕包括觸控面板,螢幕可以被實現爲觸控螢幕,以接收來自用戶的輸入訊號。觸控面板包括一個或多個觸控感測器以感測觸控、滑動和觸控面板上的手勢。所述觸控感測器可以不僅感測觸控或滑動動作的邊界,而且還檢測與所述觸控或滑動操作相關的持續時間和壓力。在一些實施例中,多媒體組件808包括一個前置攝影機和/或後置攝影機。當電子設備800處於操作模式,如拍攝模式或視訊模式時,前置攝影機和/或後置攝影機可以接收外部的多媒體數據。每個前置攝影機和後置攝影機可以是一個固定的光學透鏡系統或具有焦距和光學變焦能力。
音訊組件810被配置爲輸出和/或輸入音訊訊號。例如,音訊組件810包括一個麥克風(MIC),當電子設備800處於操作模式,如呼叫模式、記錄模式和語音辨識模式時,麥克風被配置爲接收外部音訊訊號。所接收的音訊訊號可以被進一步儲存在記憶體804或經由通訊組件816發送。在一些實施例中,音訊組件810還包括一個揚聲器,用於輸出音訊訊號。
I/O介面812爲處理組件802和外圍介面模組之間提供介面,上述外圍介面模組可以是鍵盤,點擊輪,按鈕等。這些按鈕可包括但不限於:主頁按鈕、音量按鈕、啓動按鈕和鎖定按鈕。
感測器組件814包括一個或多個感測器,用於爲電子設備800提供各個方面的狀態評估。例如,感測器組件814可以檢測到電子設備800的打開/關閉狀態,組件的相對定位,例如所述組件爲電子設備800的顯示器和小鍵盤,感測器組件814還可以檢測電子設備800或電子設備800一個組件的位置改變,用戶與電子設備800接觸的存在或不存在,電子設備800方位或加速/減速和電子設備800的溫度變化。感測器組件814可以包括接近感測器,被配置用來在沒有任何的物理接觸時檢測附近物體的存在。感測器組件814還可以包括光感測器,如CMOS或CCD圖像感測器,用於在成像應用中使用。在一些實施例中,該感測器組件814還可以包括加速度感測器,陀螺儀感測器,磁感測器,壓力感測器或溫度感測器。
通訊組件816被配置爲便於電子設備800和其他設備之間有線或無線方式的通訊。電子設備800可以接入基於通訊標準的無線網路,如WiFi,2G或3G,或它們的組合。在一個示例性實施例中,通訊組件816經由廣播通道接收來自外部廣播管理系統的廣播訊號或廣播相關訊息。在一個示例性實施例中,所述通訊組件816還包括近場通訊(NFC)模組,以促進短程通訊。例如,在NFC模組可基於射頻辨識(RFID)技術,紅外數據協會(IrDA)技術,超寬帶(UWB)技術,藍牙(BT)技術和其他技術來實現。
在示例性實施例中,電子設備800可以被一個或多個應用專用集成電路(ASIC)、數位訊號處理器(DSP)、數位訊號處理設備(DSPD)、可程式化邏輯裝置(PLD)、現場可程式化邏輯閘陣列(FPGA)、控制器、微控制器、微處理器或其他電子元件實現,用於執行上述方法。
在示例性實施例中,還提供了一種非揮發性電腦可讀儲存媒體,例如包括電腦程式指令的記憶體804,上述電腦程式指令可由電子設備800的處理器820執行以完成上述方法。
本發明實施例還提供了一種電腦程式産品,包括電腦可讀代碼,當電腦可讀代碼在設備上運行時,設備中的處理器執行用於實現如上任一實施例提供的神經網路訓練方法的指令。
本發明實施例還提供了另一種電腦程式産品,用於儲存電腦可讀指令,指令被執行時使得電腦執行上述任一實施例提供的圖像生成方法的操作。
上述電腦程式産品可以具體通過硬體、軟體或其結合的方式實現。在一個可選實施例中,所述電腦程式産品具體體現爲電腦儲存媒體,在另一個可選實施例中,電腦程式産品具體體現爲軟體産品,例如軟件開發包(Software Development Kit,SDK)等等。
圖5是根據一示例性實施例示出的一種電子設備1900的方塊圖。例如,電子設備1900可以被提供爲一伺服器。參照圖5,電子設備1900包括處理組件1922,其進一步包括一個或多個處理器,以及由記憶體1932所代表的記憶體資源,用於儲存可由處理組件1922的執行的指令,例如應用程式。記憶體1932中儲存的應用程式可以包括一個或一個以上的每一個對應於一組指令的模組。此外,處理組件1922被配置爲執行指令,以執行上述方法。
電子設備1900還可以包括一個電源組件1926被配置爲執行電子設備1900的電源管理,一個有線或無線網路介面1950被配置爲將電子設備1900連接到網路,和一個輸入輸出(I/O)介面1958。電子設備1900可以操作基於儲存在記憶體1932的操作系統,例如Windows ServerTM,Mac OS XTM,UnixTM, LinuxTM,FreeBSDTM或類似。
在示例性實施例中,還提供了一種非揮發性電腦可讀儲存媒體,例如包括電腦程式指令的記憶體1932,上述電腦程式指令可由電子設備1900的處理組件1922執行以完成上述方法。
本發明可以是系統、方法和/或電腦程式産品。電腦程式産品可以包括電腦可讀儲存媒體,其上載有用於使處理器實現本發明的各個方面的電腦可讀程式指令。
電腦可讀儲存媒體可以是可以保持和儲存由指令執行設備使用的指令的有形設備。電腦可讀儲存媒體例如可以是――但不限於――電儲存設備、磁儲存設備、光儲存設備、電磁儲存設備、半導體儲存設備或者上述的任意合適的組合。電腦可讀儲存媒體的更具體的例子(非窮舉的列表)包括:便攜式電腦碟、硬碟、隨機存取記憶體(RAM)、唯讀記憶體(ROM)、可抹除可程式化唯讀記憶體(EPROM或閃存)、靜態隨機存取記憶體(SRAM)、便攜式壓縮磁碟唯讀記憶體(CD-ROM)、數位多功能影音光碟(DVD)、記憶卡、磁片、機械編碼設備、例如其上儲存有指令的打孔卡或凹槽內凸起結構、以及上述的任意合適的組合。這裏所使用的電腦可讀儲存媒體不被解釋爲瞬時訊號本身,諸如無線電波或者其他自由傳播的電磁波、通過波導或其他傳輸媒介傳播的電磁波(例如,通過光纖電纜的光脉衝)、或者通過電線傳輸的電訊號。
這裏所描述的電腦可讀程式指令可以從電腦可讀儲存媒體下載到各個計算/處理設備,或者通過網路、例如網際網路、區域網路、廣域網路和/或無線網路下載到外部電腦或外部儲存設備。網路可以包括銅傳輸電纜、光纖傳輸、無線傳輸、路由器、防火牆、交換機、網關電腦和/或邊緣伺服器。每個計算/處理設備中的網路介面卡或者網路介面從網路接收電腦可讀程式指令,並轉發該電腦可讀程式指令,以供儲存在各個計算/處理設備中的電腦可讀儲存媒體中。
用於執行本發明操作的電腦程式指令可以是彙編指令、指令集架構(ISA)指令、機器指令、機器相關指令、微代碼、韌體指令、狀態設置數據、或者以一種或多種程式化語言的任意組合編寫的源代碼或目標代碼,所述程式化語言包括面向對象的程式化語言—諸如Smalltalk、C++等,以及常規的過程式程式化語言—諸如“C”語言或類似的程式化語言。電腦可讀程式指令可以完全地在用戶電腦上執行、部分地在用戶電腦上執行、作爲一個獨立的套裝軟體執行、部分在用戶電腦上部分在遠端電腦上執行、或者完全在遠端電腦或伺服器上執行。在涉及遠端電腦的情形中,遠端電腦可以通過任意種類的網路—包括 區域網路(LAN)或廣域網路(WAN)—連接到用戶電腦,或者,可以連接到外部電腦(例如利用網際網路伺服提供商來通過網際網路連接)。在一些實施例中,通過利用電腦可讀程式指令的狀態訊息來個性化定制電子電路,例如可程式化邏輯電路、現場可程式化邏輯閘陣列(FPGA)或可程式化邏輯陣列(PLA),該電子電路可以執行電腦可讀程式指令,從而實現本發明的各個方面。
這裏參照根據本發明實施例的方法、裝置(系統)和電腦程式産品的流程圖和/或方塊圖描述了本發明的各個方面。應當理解,流程圖和/或方塊圖的每個方塊以及流程圖和/或方塊圖中各方塊的組合,都可以由電腦可讀程式指令實現。
這些電腦可讀程式指令可以提供給通用電腦、專用電腦或其它可程式化數據處理裝置的處理器,從而生産出一種機器,使得這些指令在通過電腦或其它可程式化數據處理裝置的處理器執行時,産生了實現流程圖和/或方塊圖中的一個或多個方塊中規定的功能/動作的裝置。也可以把這些電腦可讀程式指令儲存在電腦可讀儲存媒體中,這些指令使得電腦、可程式化數據處理裝置和/或其他設備以特定方式工作,從而,儲存有指令的電腦可讀媒體則包括一個製造品,其包括實現流程圖和/或方塊圖中的一個或多個方塊中規定的功能/動作的各個方面的指令。
也可以把電腦可讀程式指令加載到電腦、其它可程式化數據處理裝置、或其它設備上,使得在電腦、其它可程式化數據處理裝置或其它設備上執行一系列操作步驟,以産生電腦實現的過程,從而使得在電腦、其它可程式化數據處理裝置、或其它設備上執行的指令實現流程圖和/或方塊圖中的一個或多個方塊中規定的功能/動作。
附圖中的流程圖和方塊圖顯示了根據本發明的多個實施例的系統、方法和電腦程式産品的可能實現的體系架構、功能和操作。在這點上,流程圖或方塊圖中的每個方塊可以代表一個模組、程式段或指令的一部分,所述模組、程式段或指令的一部分包含一個或多個用於實現規定的邏輯功能的可執行指令。在有些作爲替換的實現中,方塊中所標注的功能也可以以不同於附圖中所標注的順序發生。例如,兩個連續的方塊實際上可以基本並行地執行,它們有時也可以按相反的順序執行,這依所涉及的功能而定。也要注意的是,方塊圖和/或流程圖中的每個方塊、以及方塊圖和/或流程圖中的方塊的組合,可以用執行規定的功能或動作的專用的基於硬體的系統來實現,或者可以用專用硬體與電腦指令的組合來實現。
以上已經描述了本發明的各實施例,上述說明是示例性的,並非窮盡性的,並且也不限於所披露的各實施例。在不偏離所說明的各實施例的範圍和精神的情況下,對於本技術領域的普通技術人員來說許多修改和變更都是顯而易見的。本文中所用術語的選擇,旨在最好地解釋各實施例的原理、實際應用或對市場中的技術的技術改進,或者使本技術領域的其它普通技術人員能理解本文披露的各實施例。
11:生成模組
12:判別模組
13:第一確定模組
14:第二確定模組
15:訓練模組
800:電子設備
802:處理組件
804:記憶體
806:電源組件
808:多媒體組件
810:音訊組件
812:輸入/輸出介面
814:感測器組件
816:通訊組件
820:處理器
1900:電子設備
1922:處理組件
1926:電源組件
1932:記憶體
1950:網路介面
1958:輸入輸出介面
此處的附圖被並入說明書中並構成本說明書的一部分,這些附圖示出了符合本發明的實施例,並與說明書一起用於說明本發明的技術方案。
圖1示出根據本發明實施例的神經網路訓練方法的流程圖;
圖2示出根據本發明實施例的神經網路訓練方法的應用示意圖;
圖3示出根據本發明實施例的神經網路訓練裝置的方塊圖;
圖4示出根據本發明實施例的電子裝置的方塊圖;
圖5示出根據本發明實施例的電子裝置的方塊圖。
Claims (13)
- 一種神經網路訓練方法,其中,包括: 將第一隨機向量輸入生成網路,獲得第一生成圖像; 將所述第一生成圖像和第一真實圖像分別輸入判別網路,分別獲得所述第一生成圖像的第一判別分布與第一真實圖像的第二判別分布,其中,所述第一判別分布表示所述第一生成圖像的真實程度的機率分布,所述第二判別分布表示所述第一真實圖像的真實程度的機率分布; 根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,其中,所述第一目標分布爲生成圖像的目標機率分布,所述第二目標分布爲真實圖像的目標機率分布; 根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失; 根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路。
- 如請求項1所述的方法,其中,根據所述第一判別分布、所述第二判別分布、預設的第一目標分布以及預設的第二目標分布,確定所述判別網路的第一網路損失,包括: 根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失; 根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失; 根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失。
- 如請求項2所述的方法,其中,根據所述第一判別分布和所述第一目標分布,確定所述第一生成圖像的第一分布損失,包括: 將所述第一判別分布映射到所述第一目標分布的支撑集,獲得第一映射分布; 確定所述第一映射分布與所述第一目標分布的第一相對熵; 根據所述第一相對熵,確定所述第一分布損失。
- 如請求項2所述的方法,其中,根據所述第二判別分布和所述第二目標分布,確定所述第一真實圖像的第二分布損失,包括: 將所述第二判別分布映射到所述第二目標分布的支撑集,獲得第二映射分布; 確定所述第二映射分布與所述第二目標分布的第二相對熵; 根據所述第二相對熵,確定所述第二分布損失。
- 如請求項2所述的方法,其中,根據所述第一分布損失和所述第二分布損失,確定所述第一網路損失,包括: 對所述第一分布損失和所述第二分布損失進行加權求和處理,獲得所述第一網路損失。
- 如請求項1所述的方法,其中,根據所述第一判別分布和所述第二判別分布,確定所述生成網路的第二網路損失,包括: 確定所述第一判別分布與所述第二判別分布的第三相對熵; 根據所述第三相對熵,確定所述第二網路損失。
- 如請求項1所述的方法,其中,根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路,包括: 根據所述第一網路損失,調整所述判別網路的網路參數; 根據所述第二網路損失,調整所述生成網路的網路參數; 在所述判別網路和所述生成網路滿足訓練條件的情況下,獲得訓練後的所述生成網路和所述判別網路。
- 如請求項7所述的方法,其中,根據所述第一網路損失,調整所述判別網路的網路參數,包括: 將第二隨機向量輸入生成網路,獲得第二生成圖像; 根據所述第二生成圖像對第二真實圖像進行插值處理,獲得插值圖像; 將所述插值圖像輸入所述判別網路,獲得所述插值圖像的第三判別分布; 根據所述第三判別分布,確定所述判別網路的網路參數的梯度; 在所述梯度大於或等於梯度閾值的情況下,根據所述第三判別分布確定梯度懲罰參數; 根據所述第一網路損失和所述梯度懲罰參數,調整所述判別網路的網路參數。
- 如請求項1所述的方法,其中,根據所述第一網路損失和所述第二網路損失,對抗訓練所述生成網路和所述判別網路,包括: 將至少一個歷史訓練周期中輸入生成網路的第一隨機向量輸入當前訓練周期的生成網路,獲得至少一個第三生成圖像; 將與所述至少一個歷史訓練周期中輸入生成網路的第一隨機向量對應的第一生成圖像、至少一個所述第三生成圖像以及至少一個真實圖像分別輸入當前訓練周期的判別網路,分別獲得至少一個第一生成圖像的第四判別分布、至少一個第三生成圖像的第五判別分布和至少一個真實圖像的第六判別分布; 根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數; 在所述訓練進度參數小於或等於訓練進度閾值的情況下,停止調整所述判別網路的網路參數,僅調整所述生成網路的網路參數。
- 如請求項9所述的方法,其中,根據所述第四判別分布、所述第五判別分布和所述第六判別分布確定當前訓練周期的生成網路的訓練進度參數,包括: 分別獲取至少一個所述第四判別分布的第一期望值、至少一個所述第五判別分布的第二期望值以及至少一個所述第六判別分布的第三期望值; 分別獲取所述至少一個所述第一期望值的第一平均值、至少一個所述第二期望值的第二平均值以及至少一個所述第三期望值的第三平均值; 確定所述第三平均值與所述第二平均值的第一差值以及所述第二平均值與所述第一平均值的第二差值; 將所述第一差值與所述第二差值的比值確定爲所述當前訓練周期的生成網路的訓練進度參數。
- 一種圖像生成方法,其中,包括: 獲取第三隨機向量; 將所述第三隨機向量輸入如請求項1-10其中任一項所述的方法訓練後獲得的生成網路進行處理,獲得目標圖像。
- 一種電子設備,其中,包括: 處理器; 用於儲存處理器可執行指令的記憶體; 其中,所述處理器被配置爲:執行如請求項1至11其中任意一項所述的方法。
- 一種電腦可讀儲存媒體,其上儲存有電腦程式指令,其中,所述電腦程式指令被處理器執行時實現如請求項1至11其中任意一項所述的方法。
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910927729.6A CN110634167B (zh) | 2019-09-27 | 2019-09-27 | 神经网络训练方法及装置和图像生成方法及装置 |
CN201910927729.6 | 2019-09-27 |
Publications (2)
Publication Number | Publication Date |
---|---|
TW202113752A TW202113752A (zh) | 2021-04-01 |
TWI752405B true TWI752405B (zh) | 2022-01-11 |
Family
ID=68973281
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
TW109101220A TWI752405B (zh) | 2019-09-27 | 2020-01-14 | 神經網路訓練及圖像生成方法、電子設備、儲存媒體 |
Country Status (7)
Country | Link |
---|---|
US (1) | US20210224607A1 (zh) |
JP (1) | JP7165818B2 (zh) |
KR (1) | KR20210055747A (zh) |
CN (1) | CN110634167B (zh) |
SG (1) | SG11202103479VA (zh) |
TW (1) | TWI752405B (zh) |
WO (1) | WO2021056843A1 (zh) |
Families Citing this family (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
GB2594070B (en) * | 2020-04-15 | 2023-02-08 | James Hoyle Benjamin | Signal processing system and method |
US11272097B2 (en) * | 2020-07-30 | 2022-03-08 | Steven Brian Demers | Aesthetic learning methods and apparatus for automating image capture device controls |
KR102352658B1 (ko) * | 2020-12-31 | 2022-01-19 | 주식회사 나인티나인 | 건설 사업 정보 관리 시스템 및 이의 제어 방법 |
CN112990211B (zh) * | 2021-01-29 | 2023-07-11 | 华为技术有限公司 | 一种神经网络的训练方法、图像处理方法以及装置 |
US20240127586A1 (en) | 2021-02-04 | 2024-04-18 | Deepmind Technologies Limited | Neural networks with adaptive gradient clipping |
EP4047524A1 (en) * | 2021-02-18 | 2022-08-24 | Robert Bosch GmbH | Device and method for training a machine learning system for generating images |
CN113159315A (zh) * | 2021-04-06 | 2021-07-23 | 华为技术有限公司 | 一种神经网络的训练方法、数据处理方法以及相关设备 |
TWI766690B (zh) * | 2021-05-18 | 2022-06-01 | 詮隼科技股份有限公司 | 封包產生方法及封包產生系統之設定方法 |
KR102636866B1 (ko) * | 2021-06-14 | 2024-02-14 | 아주대학교산학협력단 | 공간 분포를 이용한 휴먼 파싱 방법 및 장치 |
CN114501164A (zh) * | 2021-12-28 | 2022-05-13 | 海信视像科技股份有限公司 | 音视频数据的标注方法、装置及电子设备 |
CN114881884B (zh) * | 2022-05-24 | 2024-03-29 | 河南科技大学 | 一种基于生成对抗网络的红外目标样本增强方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108510435A (zh) * | 2018-03-28 | 2018-09-07 | 北京市商汤科技开发有限公司 | 图像处理方法及装置、电子设备和存储介质 |
CN109920016A (zh) * | 2019-03-18 | 2019-06-21 | 北京市商汤科技开发有限公司 | 图像生成方法及装置、电子设备和存储介质 |
US20190272890A1 (en) * | 2017-07-25 | 2019-09-05 | Insilico Medicine, Inc. | Deep proteome markers of human biological aging and methods of determining a biological aging clock |
Family Cites Families (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
KR100996209B1 (ko) * | 2008-12-23 | 2010-11-24 | 중앙대학교 산학협력단 | 변화값 템플릿을 이용한 객체 모델링 방법 및 그 시스템 |
US8520958B2 (en) * | 2009-12-21 | 2013-08-27 | Stmicroelectronics International N.V. | Parallelization of variable length decoding |
US20190228268A1 (en) * | 2016-09-14 | 2019-07-25 | Konica Minolta Laboratory U.S.A., Inc. | Method and system for cell image segmentation using multi-stage convolutional neural networks |
JP6318211B2 (ja) * | 2016-10-03 | 2018-04-25 | 株式会社Preferred Networks | データ圧縮装置、データ再現装置、データ圧縮方法、データ再現方法及びデータ転送方法 |
EP3336800B1 (de) * | 2016-12-19 | 2019-08-28 | Siemens Healthcare GmbH | Bestimmen einer trainingsfunktion zum generieren von annotierten trainingsbildern |
CN107293289B (zh) * | 2017-06-13 | 2020-05-29 | 南京医科大学 | 一种基于深度卷积生成对抗网络的语音生成方法 |
CN108495110B (zh) * | 2018-01-19 | 2020-03-17 | 天津大学 | 一种基于生成式对抗网络的虚拟视点图像生成方法 |
CN108615073B (zh) * | 2018-04-28 | 2020-11-03 | 京东数字科技控股有限公司 | 图像处理方法及装置、计算机可读存储介质、电子设备 |
CN109377448B (zh) * | 2018-05-20 | 2021-05-07 | 北京工业大学 | 一种基于生成对抗网络的人脸图像修复方法 |
CN108805833B (zh) * | 2018-05-29 | 2019-06-18 | 西安理工大学 | 基于条件对抗网络的字帖二值化背景噪声杂点去除方法 |
CN109377452B (zh) * | 2018-08-31 | 2020-08-04 | 西安电子科技大学 | 基于vae和生成式对抗网络的人脸图像修复方法 |
CN109933677A (zh) * | 2019-02-14 | 2019-06-25 | 厦门一品威客网络科技股份有限公司 | 图像生成方法和图像生成系统 |
CN109919921B (zh) * | 2019-02-25 | 2023-10-20 | 天津大学 | 基于生成对抗网络的环境影响程度建模方法 |
-
2019
- 2019-09-27 CN CN201910927729.6A patent/CN110634167B/zh active Active
- 2019-12-11 SG SG11202103479VA patent/SG11202103479VA/en unknown
- 2019-12-11 WO PCT/CN2019/124541 patent/WO2021056843A1/zh active Application Filing
- 2019-12-11 KR KR1020217010144A patent/KR20210055747A/ko not_active Application Discontinuation
- 2019-12-11 JP JP2021518079A patent/JP7165818B2/ja active Active
-
2020
- 2020-01-14 TW TW109101220A patent/TWI752405B/zh not_active IP Right Cessation
-
2021
- 2021-04-02 US US17/221,096 patent/US20210224607A1/en not_active Abandoned
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190272890A1 (en) * | 2017-07-25 | 2019-09-05 | Insilico Medicine, Inc. | Deep proteome markers of human biological aging and methods of determining a biological aging clock |
CN108510435A (zh) * | 2018-03-28 | 2018-09-07 | 北京市商汤科技开发有限公司 | 图像处理方法及装置、电子设备和存储介质 |
CN109920016A (zh) * | 2019-03-18 | 2019-06-21 | 北京市商汤科技开发有限公司 | 图像生成方法及装置、电子设备和存储介质 |
Also Published As
Publication number | Publication date |
---|---|
WO2021056843A1 (zh) | 2021-04-01 |
TW202113752A (zh) | 2021-04-01 |
JP2022504071A (ja) | 2022-01-13 |
CN110634167B (zh) | 2021-07-20 |
SG11202103479VA (en) | 2021-05-28 |
JP7165818B2 (ja) | 2022-11-04 |
US20210224607A1 (en) | 2021-07-22 |
KR20210055747A (ko) | 2021-05-17 |
CN110634167A (zh) | 2019-12-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
TWI752405B (zh) | 神經網路訓練及圖像生成方法、電子設備、儲存媒體 | |
TWI781359B (zh) | 人臉和人手關聯檢測方法及裝置、電子設備和電腦可讀儲存媒體 | |
TWI766286B (zh) | 圖像處理方法及圖像處理裝置、電子設備和電腦可讀儲存媒介 | |
TWI717923B (zh) | 面部識別方法及裝置、電子設備和儲存介質 | |
US20210012523A1 (en) | Pose Estimation Method and Device and Storage Medium | |
JP7106687B2 (ja) | 画像生成方法および装置、電子機器、並びに記憶媒体 | |
TWI755833B (zh) | 一種圖像處理方法、電子設備和儲存介質 | |
TWI736179B (zh) | 圖像處理方法、電子設備和電腦可讀儲存介質 | |
TW202113757A (zh) | 目標對象匹配方法及目標對象匹配裝置、電子設備和電腦可讀儲存媒介 | |
CN110837761B (zh) | 多模型知识蒸馏方法及装置、电子设备和存储介质 | |
CN110598504B (zh) | 图像识别方法及装置、电子设备和存储介质 | |
TWI759830B (zh) | 網路訓練方法、圖像生成方法、電子設備及電腦可讀儲存介質 | |
TW202105260A (zh) | 批量標準化資料的處理方法、圖像分類方法、圖像檢測方法、視訊處理方法 | |
CN105335684B (zh) | 人脸检测方法及装置 | |
TWI718631B (zh) | 人臉圖像的處理方法及裝置、電子設備和儲存介質 | |
CN110909815A (zh) | 神经网络训练、图像处理方法、装置及电子设备 | |
CN112598063A (zh) | 神经网络生成方法及装置、电子设备和存储介质 | |
EP3657497A1 (en) | Method and device for selecting target beam data from a plurality of beams | |
CN111259967A (zh) | 图像分类及神经网络训练方法、装置、设备及存储介质 | |
CN111523599B (zh) | 目标检测方法及装置、电子设备和存储介质 | |
CN110135349A (zh) | 识别方法、装置、设备及存储介质 | |
CN114154068A (zh) | 媒体内容推荐方法、装置、电子设备及存储介质 | |
CN109447258B (zh) | 神经网络模型的优化方法及装置、电子设备和存储介质 | |
TWI770531B (zh) | 人臉識別方法、電子設備和儲存介質 | |
CN111488964A (zh) | 图像处理方法及装置、神经网络训练方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
MM4A | Annulment or lapse of patent due to non-payment of fees |