JP2022003423A - 学習方法、学習装置及びプログラム - Google Patents

学習方法、学習装置及びプログラム Download PDF

Info

Publication number
JP2022003423A
JP2022003423A JP2018150539A JP2018150539A JP2022003423A JP 2022003423 A JP2022003423 A JP 2022003423A JP 2018150539 A JP2018150539 A JP 2018150539A JP 2018150539 A JP2018150539 A JP 2018150539A JP 2022003423 A JP2022003423 A JP 2022003423A
Authority
JP
Japan
Prior art keywords
generator
learning
loss function
classifier
parameters
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
JP2018150539A
Other languages
English (en)
Inventor
正一朗 山口
Seiichiro Yamaguchi
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Preferred Networks Inc
Original Assignee
Preferred Networks Inc
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Preferred Networks Inc filed Critical Preferred Networks Inc
Priority to JP2018150539A priority Critical patent/JP2022003423A/ja
Priority to PCT/JP2019/029977 priority patent/WO2020031802A1/ja
Publication of JP2022003423A publication Critical patent/JP2022003423A/ja
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Machine Translation (AREA)

Abstract

【課題】敵対的生成ネットワークにおけるモード崩壊を軽減するための技術を提供することである。【解決手段】本開示の一態様は、プロセッサにより実行されるステップからなる学習方法であって、敵対的生成ネットワークに従って生成器と識別器とを学習するステップを有し、前記学習するステップは、前記生成器がサンプルしうる領域における、前記生成器の損失関数を凹化するように、前記識別器のパラメータを更新するステップを含む学習方法に関する。【選択図】図6

Description

本開示は、機械学習に関する。
敵対的生成ネットワーク(以下、GANs (Generative Adversarial Networks)と称する)は、画像生成及び動画生成の分野において驚くべき結果を残している一方、学習が困難であることが知られている。GANsの学習を困難にする現象としてモード崩壊("mode collapse")が知られている。
モード崩壊は、モデル分布から生成されるサンプルの多様性が小さくなってしまう現象である。例えば、手書き文字データセットMNISTにあるような手書き文字を生成する際、モデル分布は"0"〜"9"の10個のモードを有する分布になっていると考えられる。しかしながら、GANsの学習の結果として、モデル分布が特定の数字のみサンプルして失敗することがある。
"Generative Adversarial Nets", Ian J. Goodfellow, et. al., In NIPS 2014. "Conditional Generative Adversarial Nets", Mehdi Mirza, et. al., arXiv: 1411.1784, Nov. 6, 2014. "Temporal Generative Adversarial Nets with Singular Value Clipping", Masaki Saito, et. al., arXiv: 1611.06624, Aug. 18, 2017.
モード崩壊を回避する様々な手法が提案されている。例えば、spectral normalizationはGANsの学習不安定性を劇的に改善し、モード崩壊を大きく改善した。
しかしながら、spectral normalizationを用いた場合でも、生成されるサンプルの多様性を測る指標として用いられるinception score及びFID (Frechet Inspection Distance)は、学習に用いたデータのものを有意に下回っている。すなわち、GANsによって学習された生成器は依然として学習データが有する多様性を表現できていないことが分かる。
上述した問題点を鑑み、本開示の課題は、GANsにおけるモード崩壊を軽減するための技術を提供することである。
上記課題を解決するため、本開示の一態様は、プロセッサにより実行されるステップからなる学習方法であって、敵対的生成ネットワークに従って生成器と識別器とを学習するステップを有し、前記学習するステップは、前記生成器がサンプルしうる領域における、前記生成器の損失関数を凹化するように、前記識別器のパラメータを更新するステップを含む学習方法に関する。
本開示によると、GANsにおけるモード崩壊を軽減するための技術を提供することができる。
損失関数の表面上の凸部分への勾配ベクトル及び生成器の分布を示す概略図である。 本開示の一実施例による凸部分の凹化による勾配ベクトル及び生成器の分布を示す概略図である。 本開示の一実施例による学習システムを示す概略図である。 本開示の一実施例による学習装置のハードウェア構成を示すブロック図である。 本開示の一実施例によるGANsによる学習処理を示すフローチャートである。 本開示の一実施例による凸部分の凹化を示す概略図である。
以下の実施例では、GANsによる学習装置及び方法が開示される。
本開示による学習装置及び方法を概略すると、GANsにおける生成器がサンプルしうる領域における、生成器の損失関数を凹化(concavify)又は正則化(regularize)するように識別器のパラメータが更新される。
具体的には、図1に示されるように、生成器の負の損失関数-Lgの表面上で生成器がサンプルしうる領域には、凸な部分が発生する可能性があり、確率勾配法が学習処理に適用される場合、図1(a)に示されるように、生成器の損失関数の勾配ベクトルは当該凸領域に移動することなる。この結果、図1(b)に示されるように、生成器の生成分布は凸領域に集中することになり、特定のデータのみが生成されるモード崩壊が発生する。
本開示の学習装置及び方法によると、このような凸領域を凹化又は正則化し、図2に示されるように、生成器の損失関数の表面をスムース化する。この結果、図2(a)に示されるように、生成器の損失関数の勾配ベクトルは拡散され、図2(b)に示されるように、生成器の生成分布は拡散され、モード崩壊が軽減又は解消される。
以下の説明において、凹凸は数学的な定義に基づくものであり、具体的には、関数fが凸であるとは、区間内の任意の異なる2点x, yと開区間(0, 1)内の任意のtに対して、
f(tx + (1-t)y) ≦ tf(x) + (1-t)f(y)
を満たすと定義される。また、-fが凸関数のとき、fを凹関数と呼ぶ。凸関数を「下に凸な関数」、凹関数を「上に凸な関数」と称することもある。
まず、図3及び4を参照して、本開示の一実施例によるGANsによる学習装置を説明する。図3は、本開示の一実施例による学習システムを示す概略図である。
図3に示されるように、本開示の一実施例による学習システム10は、データベース(DB)50及び学習装置100を有する。
DB50は、学習装置100により利用される訓練データを格納する。具体的には、DB50は、学習装置100における生成器による生成対象であると共に、識別器による判別対象であるデータを格納する。例えば、学習装置100により学習される生成モデルの性能をシミュレートする場合、DB50には、MNIST, CIFAR-10, CIFAR-100などのシミュレーション用の画像データセットが格納されてもよい。
学習装置100は、GANsにおける生成器及び識別器と呼ばれる2つのニューラルネットワークを有する。本開示による生成器及び識別器には、任意のニューラルネットワークが適用されてもよい。生成器及び識別器のニューラルネットワークは、学習処理の開始時には何れか適切な初期状態に設定され、学習処理が進捗するに従って、生成器及び識別器の各ニューラルネットワークの各種パラメータが、例えば、以下で詳細に説明されるように順次更新される。
一実施例のGANsによる学習処理では、まず乱数などの入力データzが生成器に入力され、生成器によって出力データが生成される。次に、生成器によって生成された出力データ又はDB50における訓練データが入力データxとして識別器に入力され、識別器によって入力データxが生成器による出力データ又はDB50からの訓練データの何れであるかを示す判別結果が出力される。例えば、生成器による出力データである場合には0が出力され、DB50からの訓練データである場合には1が出力される。当該判別結果に応じて、識別器が正しい判別結果を出力するように、例えば、確率勾配法に基づくバックプロパゲーションに従って識別器のニューラルネットワークのパラメータが更新される。また、生成器の出力データが識別器によって訓練データと判別されるように、例えば、確率勾配法に基づくバックプロパゲーションに従って生成器のニューラルネットワークのパラメータが更新される。
すなわち、GANsでは、
Figure 2022003423
となるように学習処理が実行される。ここで、gは生成器であり、fは識別器であり、xは入力データであり、Lgは生成器の活性化関数であり、Lrは識別器の活性化関数である。V(g,f)はベースライン目的関数として参照されうる。
また、f,gをそれぞれφ,θによってパラメータ化すると、GANsによる学習処理では、
Figure 2022003423
に従って生成器及び識別器のパラメータが更新されていく。ここで、zは乱数又はノイズであり、αは学習率である。
また、上記の生成器のパラメータθの更新式の第2項について、
Figure 2022003423
により書き換え可能である。ターゲット分布
Figure 2022003423
と共に(ただし、全ての可測集合Aに対して、
Figure 2022003423
である)、輸送関数を
Figure 2022003423
として定義する。ここで、
Figure 2022003423
はシード変数zに依存する分布である。このとき、
Figure 2022003423
となる。
このことは、
Figure 2022003423
とg(z)とのL2距離の平方が減少するようにgが更新され続けることを意味し、すなわち、生成器の更新は、
Figure 2022003423
Figure 2022003423
に向かって移動させることを意味する。すなわち、上述したGANsによる学習処理は、関数勾配の観点から以下のように記述できる。
Figure 2022003423
本開示によると、上述したfの目的関数(critic's objective)が、
Figure 2022003423
により置き換えられ、当該目的関数により識別器のパラメータが更新される。ここで、V(g,f)は上述したベースライン目的関数であり、εは0から1の範囲内の値であり、αは定数である。また、Lregは、生成器の分布から2点x1, x2を独立にサンプリングし、x1, x2の間の生成器の損失関数Lgの表面における凹凸を
Figure 2022003423
に従って評価することによって決定される。すなわち、上述したfの目的関数は、生成器の負の損失関数の表面上で生成器がサンプルしうる領域において当該損失関数を凹化するように、識別器のパラメータに正則化を加える。
所定の終了条件が充足されるまで、上述した生成器及び識別器のパラメータが更新され続け、所定の終了条件が充足されると、最終的な生成器が学習済み生成モデルとして取得される。しかしながら、本開示による学習処理は、これに限定されず、他の何れか適切なGANsに基づく学習処理が適用されてもよい。
ここで、学習装置100は、例えば、図4に示されるように、CPU (Central Processing unit)、GPU (Graphics Processing Unit)などのプロセッサ101、RAM (Random Access Memory)、フラッシュメモリなどのメモリ102、ハードディスク103及び入出力(I/O)インタフェース104によるハードウェア構成を有してもよい。
プロセッサ101は、学習装置100の各種処理を実行し、上述したGANsによる生成器及び識別器に対する学習処理、生成器及び識別器の実行、生成器、識別器及びDB50の間のデータの入出力を含む、学習装置100の全体制御などの各種処理を実行する。
メモリ102は、学習装置100における各種データ及びプログラムを格納し、特に作業用データ、実行中のプログラムなどのためのワーキングメモリとして機能する。具体的には、メモリ102は、ハードディスク103からロードされた生成器及び識別器における学習処理を実行及び制御するためのプログラムを格納し、プロセッサ101によるプログラムの実行中にワーキングメモリとして機能する。
ハードディスク103は、学習装置100における各種データ及びプログラムを格納し、生成器及び識別器における処理を実行及び制御するための各種データ及び/又はプログラムを格納する。
I/Oインタフェース104は、DB50などの外部装置との間でデータを入出力するためのインタフェースであり、例えば、USB (Universal Serial Bus)、通信回線、キーボード、マウス、ディスプレイなどのデータを入出力するためのデバイスである。
しかしながら、本開示による学習装置100は、上述したハードウェア構成に限定されず、他の何れか適切なハードウェア構成を有してもよい。例えば、上述した学習装置100による学習処理は、これを実現するよう配線化された処理回路又は電子回路により実現されてもよい。
次に、図5及び6を参照して、本開示の一実施例によるGANsによる画像生成モデルの学習処理を説明する。図5は、本開示の一実施例によるGANsによる学習処理を示すフローチャートである。
図5に示されるように、ステップS101において、プロセッサ101は、乱数を生成器に入力する。プロセッサ101は、何れか適切な擬似乱数発生ルーチンを実行することによって、あるいは、学習装置100に搭載された乱数発生器を利用することによって乱数を生成し、生成した乱数を生成器に入力してもよい。
ステップS102において、プロセッサ101は、入力された乱数から生成器によって生成された画像を取得する。例えば、生成器は、何れか適切な構造を有するニューラルネットワークであってもよい。
ステップS103において、プロセッサ101は、生成器によって生成された画像又はDB50に格納されている訓練画像を識別器に入力する。
ステップS104において、プロセッサ101は、入力画像が生成器の出力画像であるか、あるいは、訓練画像であるか識別器に判別させる。例えば、識別器は、何れか適切な構造を有するニューラルネットワークであってもよい。
ステップS105において、プロセッサ101は、識別器による判別結果に応じて識別器及び生成器のパラメータを更新する。すなわち、プロセッサ101は、識別器が入力画像を正しく判別するように、確率勾配法に基づくバックプロパゲーションに従って識別器のパラメータを更新し、識別器が生成器によって生成された画像を訓練画像であると判別するように、確率勾配法に基づくバックプロパゲーションに従って生成器のパラメータを更新する。
具体的には、プロセッサ101は、生成器がサンプルしうる領域における、生成器の損失関数を凹化又は正則化するように、識別器のパラメータを更新する。例えば、プロセッサ101は、上述したように、識別器の目的関数が
Figure 2022003423
となるように、生成器がサンプルしうる領域において損失関数を凹化又は正則化してもよい。すなわち、プロセッサ101は、図6に示されるように、生成器の損失関数の表面上の2点間の線分上の点の当該損失関数の値が2点の損失関数の各値の線形結合になるように、損失関数を凹化又は正則化してもよい。例えば、図6(a)に示されるように、生成器の損失関数の表面上に凸領域がある場合、すなわち、
Figure 2022003423
が正値である場合、プロセッサ101は、
Figure 2022003423
に従って生成器がサンプルしうる領域において損失関数を凹化し、図6(b)に示されるように、損失関数の表面がスムース化されるように正則化を加えながら識別器のパラメータを更新する。
上述したように、このような凸領域は、図1(a)に示されるように、生成器の損失関数の勾配ベクトルを凸領域に向かって誘導させ、この結果、図1(b)に示されるように、モード崩壊を発生させる。一方、本開示によると、生成器がサンプルしうる領域において損失関数を凹化することによって、図2(a)に示されるように、生成器の損失関数の勾配ベクトルが拡散され、図2(b)に示されるように、生成器の生成モデルにおける分布が拡散され、モード崩壊の発生を回避できる。
その後、プロセッサ101は、上述したステップS101〜S105を繰り返し、所定の終了条件が充足されると、当該学習処理を終了する。例えば、所定の終了条件は、所定の回数の繰り返しを終了したこと、生成器及び/又は識別器の精度が所定の閾値を超えたこと、生成器及び/又は識別器の精度が収束したことなどであってもよい。
なお、上述した実施例では、画像データに対して生成器及び識別器が学習されたが、本開示による学習処理は、これに限定されず、動画データ、音響データなどの他の任意のタイプのデータにも適用可能である。
以上、本発明の実施例について詳述したが、本発明は上述した特定の実施形態に限定されるものではなく、特許請求の範囲に記載された本発明の要旨の範囲内において、種々の変形・変更が可能である。
50 データベース(DB)
100 学習装置

Claims (8)

  1. プロセッサにより実行されるステップからなる学習方法であって、
    敵対的生成ネットワークに従って生成器と識別器とを学習するステップを有し、
    前記学習するステップは、前記生成器がサンプルしうる領域における、前記生成器の損失関数を凹化するように、前記識別器のパラメータを更新するステップを含む学習方法。
  2. 前記更新するステップは、前記損失関数の表面上の2点間の線分上の点の前記損失関数の値が前記2点の損失関数の各値の線形結合になるように、前記損失関数を凹化する、請求項1記載の学習方法。
  3. 前記損失関数の凹化は、前記損失関数の勾配ベクトルを拡散させる、請求項1又は2記載の学習方法。
  4. 前記学習するステップは、
    前記生成器によって、乱数から画像を生成するステップと、
    前記識別器によって、入力画像が前記生成された画像又は訓練画像の何れであるか判別するステップと、
    判別結果に応じて前記生成器と前記識別器とのパラメータを更新するステップと、
    所定の終了条件が充足されるまで前記生成するステップ、前記判別するステップ及び前記更新するステップを繰り返すステップと、
    を含む、請求項1乃至3何れか一項記載の学習方法。
  5. 前記生成器のパラメータは、前記識別器が前記生成された画像を前記訓練画像であると判別するように更新され、
    前記識別器のパラメータは、前記識別器が前記入力画像を正しく判別するように更新される、請求項4記載の学習方法。
  6. 前記生成器及び前記識別器は、ニューラルネットワークである、請求項1乃至5何れか一項記載の学習方法。
  7. メモリと、
    前記メモリに結合されるプロセッサと、
    を有し、
    前記プロセッサは、
    敵対的生成ネットワークに従って生成器と識別器とを学習し、
    前記プロセッサは、前記生成器がサンプルしうる領域における、前記生成器の損失関数を凹化するように、前記識別器のパラメータを更新する学習装置。
  8. 敵対的生成ネットワークに従って生成器と識別器とを学習する処理をプロセッサに実行させ、
    前記学習する処理は、前記生成器がサンプルしうる領域における、前記生成器の損失関数を凹化するように、前記識別器のパラメータを更新する処理を含むプログラム。
JP2018150539A 2018-08-09 2018-08-09 学習方法、学習装置及びプログラム Pending JP2022003423A (ja)

Priority Applications (2)

Application Number Priority Date Filing Date Title
JP2018150539A JP2022003423A (ja) 2018-08-09 2018-08-09 学習方法、学習装置及びプログラム
PCT/JP2019/029977 WO2020031802A1 (ja) 2018-08-09 2019-07-31 学習方法、学習装置、モデル生成方法及びプログラム

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
JP2018150539A JP2022003423A (ja) 2018-08-09 2018-08-09 学習方法、学習装置及びプログラム

Publications (1)

Publication Number Publication Date
JP2022003423A true JP2022003423A (ja) 2022-01-11

Family

ID=69414197

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2018150539A Pending JP2022003423A (ja) 2018-08-09 2018-08-09 学習方法、学習装置及びプログラム

Country Status (2)

Country Link
JP (1) JP2022003423A (ja)
WO (1) WO2020031802A1 (ja)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20210071130A (ko) * 2019-12-05 2021-06-16 삼성전자주식회사 컴퓨팅 장치, 컴퓨팅 장치의 동작 방법, 그리고 저장 매체
CN112837396B (zh) * 2021-01-29 2024-05-07 深圳市天耀创想网络科技有限公司 一种基于机器学习的线稿生成方法及装置
WO2023171755A1 (ja) * 2022-03-09 2023-09-14 ソニーセミコンダクタソリューションズ株式会社 情報処理装置、情報処理方法、記録媒体、情報処理システム

Also Published As

Publication number Publication date
WO2020031802A1 (ja) 2020-02-13

Similar Documents

Publication Publication Date Title
US20240046106A1 (en) Multi-task neural networks with task-specific paths
EP3955204A1 (en) Data processing method and apparatus, electronic device and storage medium
US20190087730A1 (en) Non-transitory computer-readable storage medium storing improved generative adversarial network implementation program, improved generative adversarial network implementation apparatus, and learned model generation method
EP3602419B1 (en) Neural network optimizer search
JP2022003423A (ja) 学習方法、学習装置及びプログラム
JP6187977B2 (ja) 解析装置、解析方法及びプログラム
US10635078B2 (en) Simulation system, simulation method, and simulation program
CN112488183A (zh) 一种模型优化方法、装置、计算机设备及存储介质
KR102093080B1 (ko) 레이블 데이터 및 비레이블 데이터를 이용한 생성적 적대 신경망 기반의 분류 시스템 및 방법
CN111178082A (zh) 一种句向量生成方法、装置及电子设备
US8700686B1 (en) Robust estimation of time varying parameters
US20230316094A1 (en) Systems and methods for heuristic algorithms with variable effort parameters
WO2024001108A1 (zh) 一种文本答案的确定方法、装置、设备和介质
CN117093684A (zh) 企业服务领域预训练对话式大语言模型的构建方法及系统
US7933449B2 (en) Pattern recognition method
CN116361657A (zh) 用于对灰样本标签进行消歧的方法、系统和存储介质
Martino et al. Smelly parallel MCMC chains
JP2020030674A (ja) 情報処理装置、情報処理方法及びプログラム
US20200320393A1 (en) Data processing method and data processing device
CN112488319B (zh) 一种具有自适应配置生成器的调参方法和系统
KR20220134627A (ko) 하드웨어-최적화된 신경 아키텍처 검색
CN114067415A (zh) 回归模型的训练方法、对象评估方法、装置、设备和介质
CN114445656A (zh) 多标签模型处理方法、装置、电子设备及存储介质
CN110110853B (zh) 一种深度神经网络压缩方法、装置及计算机可读介质
CN111310794A (zh) 目标对象的分类方法、装置和电子设备