JP2020068000A - 訓練装置、訓練方法、予測装置、予測方法及びプログラム - Google Patents

訓練装置、訓練方法、予測装置、予測方法及びプログラム Download PDF

Info

Publication number
JP2020068000A
JP2020068000A JP2018227477A JP2018227477A JP2020068000A JP 2020068000 A JP2020068000 A JP 2020068000A JP 2018227477 A JP2018227477 A JP 2018227477A JP 2018227477 A JP2018227477 A JP 2018227477A JP 2020068000 A JP2020068000 A JP 2020068000A
Authority
JP
Japan
Prior art keywords
hidden vector
graph
network
node
nodes
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
JP2018227477A
Other languages
English (en)
Inventor
勝彦 石黒
Katsuhiko Ishiguro
勝彦 石黒
新一 前田
Shinichi Maeda
新一 前田
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Preferred Networks Inc
Original Assignee
Preferred Networks Inc
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Preferred Networks Inc filed Critical Preferred Networks Inc
Priority to US16/657,389 priority Critical patent/US20200125958A1/en
Publication of JP2020068000A publication Critical patent/JP2020068000A/ja
Pending legal-status Critical Current

Links

Images

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

【課題】グラフ全体の隠れベクトルを予測する。【解決手段】グラフのデータを入力すると前記グラフの特徴を予測するネットワークを訓練する訓練装置は、グラフを構成する第1ノードの第1隠れベクトル、第1ノード間の接続情報、及び、第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、第1隠れベクトルと第2隠れベクトルとをマージする、マージ部、を各層に備え、マージ結果に基づいて、第1隠れベクトル及び第2ベクトルを更新する、第1ネットワークと、第1ネットワークが出力した第1隠れベクトル及び第2隠れベクトルに基づいて、グラフの特徴を抽出する、第2ネットワークと、第2ネットワークが出力したグラフの特徴の損失を算出する、算出部と、損失に基づいて、少なくとも第1ネットワーク及び第2ネットワークのいずれか一方を更新する、ネットワーク更新部と、を備える。【選択図】図1

Description

本発明は、訓練装置、訓練方法、予測装置、予測方法及びプログラムに関する。
分子、化合物データのモデリング等に広く応用されている機械学習の技術として、グラフデータを入力するとグラフ内の各ノードあるいはエッジの隠れベクトル表現を推定するGCN(Graph Convolution Networks)と呼ばれるものがある。GCNを用いると、離散的なグラフデータを連続値の隠れベクトルの集合に変換することが可能であるため、グラフの特性予測や化合物グラフの毒性識別等、多彩なタスクを表現することができ、その応用分野も広い。GCNモデルは、グラフ内の隣接ノードにおいてメッセージパッシングを開始、近傍ノードの情報を参照して各ノードの隠れベクトルを修正していく。レイヤを重ねるごとに遠くのノードの情報が緩やかに伝播していくことで、最上位のレイヤにおいて得られる各ノードの隠れベクトルは、グラフ全体の情報をある程度考慮したものとして取得することが可能となる。
しかしながら、近接ノードからの緩やかな情報伝播モデルであるため、現実的に利用されるレイヤ数において離れたノード間での情報伝達が十分に達成される可能性は低い。グラフ全体の隠れベクトルの表現を計算するノードを想定していても、グラフ直径やノード数のようなグラフ全体に関する情報量を観測量としてモデルに取り込むことがされず、また、実ノードとこのような仮想ノードとの間のメッセージのバランスを取る手段がないため、最適な隠れベクトル表現には直接的には影響を及ぼしているとは言えない。
米国出願公開第2016/0196672号明細書
P. Velickovic, et.al., "Graph Attention Networks," Proceedings of the ICLR, 2018 Anonymous, "Related Graph Attention Networks," https://openreview.net/pdf?id=Bklzkh0qFm, [インターネット](2018.12.3確認) Y. Li, et.al., "Gated Graph Sequence Neural Networks," Proceedings of the ICLR, 2016 T. N. Kipt, et.al., "Semi-supervised Classification with graph Convolutional Networks," Proceedings of the ICLR, 2017
本発明の実施形態は、グラフ全体の特徴量を予測するネットワークの訓練装置、訓練方法及びプログラム、並びに、当該訓練装置により取得される予測装置、予測方法を提供する。
一実施形態による訓練装置は、グラフのデータを入力すると、前記グラフの特徴を予測するネットワークを訓練する訓練装置であって、前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージする、マージ部、を各層に備え、前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2ベクトルを更新する、第1ネットワークと、前記第1ネットワークが出力した前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、第2ネットワークと、前記第2ネットワークが出力した前記グラフの特徴の損失を算出する、算出部と、前記損失に基づいて、少なくとも前記第1ネットワーク及び前記第2ネットワークのいずれか一方を更新する、ネットワーク更新部と、を備える。
一実施形態に係る予測装置を示す概略図 一実施形態に係る訓練装置の機能を示すブロック図。 一実施形態に係る前処理部の機能を示すブロック図。 一実施形態に係る訓練モードにおける演算ネットワークの機能を示すブロック図。 一実施形態に係る第1ネットワークの機能を示すブロック図。 一実施形態に係る訓練における処理を示すフローチャート。 一実施形態に係る第1ネットワークのデータフローを示す図。 一実施形態に係る予測モードにおける演算ネットワークの機能を示すブロック図。 一実施形態に係る装置のハードウェア実装例を示す図。
図1は、一実施形態に係る予測装置の機能を示す概略図である。予測装置は、グラフのデータを入力すると、入力されたグラフの特徴を予測して出力する。例えば、化合物の化学式(構造式)が入力されると、当該化合物の毒性等の特徴を示す量を出力する。グラフの他の例としては、回路図、間取り図、種々の設計図、又は、文章(言語)の構成等が挙げられるがこれに限られるものではなく、特徴を抽出したいグラフであればどのようなものであってもよい。
入力されたグラフは、グラフを構成するノード(第1ノード)ごとに、それぞれの第1ノードが有する特徴、及び、各第1ノード間の接続情報が抽出される。例えばグラフが化学式である場合、第1ノードの特徴としては、各第1ノードを表す分子又は原子をインデクスで表したもの、第1ノードに接続されるエッジの数等から第1ノード特徴量が抽出される。
接続情報として、例えば、隣接行列が抽出される。隣接行列は、多重結合の結合数ごとに別の行列として抽出されてもよいし、結合数を行列の要素として示すものであってもよい。さらに、結合の種類ごとに別の行列として抽出してもよい。例えば、グラフが化合物を示す場合、π結合、σ結合等の結合の種類ごとに別の行列として抽出してもよい。また、有向グラフである場合には、一方向のみに接続状態を示す隣接行列を生成してもよい。上記には限られず、接続情報として第1ノード間の接続を適切に表現する他のテンソルを抽出してもよい。以下において、隣接行列を用いるものとして説明するが、このテンソルを用いて言い換えられることに理解されたい。
上記の特徴に加え、全ての第1ノードと接続するスーパーノード(第2ノード)を仮想的に生成し、この第2ノードの特徴量が抽出される。第2ノードの特徴量は、例えばグラフが化学式である場合、グラフのノード数、グラフ直径、原子間の結合種類の数といった量が抽出される。
以下、説明のため、第1ノードをローカルノードと、第2ノードをスーパーノードと、さらに、第1ノードと第2ノードの双方を示すものとして単にノードと記載することがある。
抽出されたローカルノードの特徴量は、ローカルノードの隠れベクトル(第1隠れベクトル)へと変換され、スーパーノードの特徴量は、スーパーノードの隠れベクトル(第2隠れベクトル)と変換され、第1ネットワークへと入力される。第1ネットワークは、L(≧1)層のネットワークを備えて構成される。各層はそれぞれ、メッセージ部と、マージ部と、リカレント部と、を備える。
メッセージ部は、あるノードに接続されるそれぞれのノードの前の層における隠れベクトル及び当該ノードの前の層における隠れベクトルに基づいて、現在の層におけるノードの隠れベクトルを更新するためのメッセージを生成する。ノードの接続関係は、上記の隣接行列を参照する。
メッセージは、ローカルノードからローカルノードへの第1メッセージ、ローカルノードからスーパーノードへの第2メッセージ、スーパーノードからローカルノードへの第3メッセージ、スーパーノードからスーパーノードへの第4メッセージのそれぞれが生成される。メッセージ部は、各層において、それぞれローカルノード数分の第1メッセージと第3メッセージを生成し、それぞれ1つの第2メッセージと第4メッセージを生成する。
マージ部は、メッセージ部が生成した各メッセージに基づいて、各ローカルノードとスーパーノードの隠れベクトルを更新するための更新隠れベクトルを生成する。より具体的には、第1メッセージ及び第3メッセージを用いてローカル更新隠れベクトル(第1更新隠れベクトル)を生成し、第2メッセージ及び第4メッセージを用いてスーパー更新隠れベクトル(第2更新隠れベクトル)を生成する。
第1メッセージと第3メッセージのマージにおける重み付け、及び、第2メッセージと第4メッセージのマージにおける重み付けは、訓練中に適応的に更新される。この重みの更新により、マージ部は、その性質が異なるローカルノードの隠れベクトルと、スーパーノードの隠れベクトルとを適応的に混合するゲートとして動作する。
リカレント部は、マージ部が生成した第1更新隠れベクトルと、前の層においてリカレント部から出力された第1隠れベクトルとを自己回帰的にゲーティングさせて第1隠れベクトルを更新し、同様に、第2更新隠れベクトルと、前の層において出力された第2隠れベクトルとをゲーティングさせて第2隠れベクトルを更新する。
リカレント部が出力した第1隠れベクトル及び第2隠れベクトルは、それぞれ、次の層のメッセージ部へと出力される。そして、メッセージ部、マージ部、リカレント部において処理を行って、さらに次の層へと出力する。
第1ネットワークは、最終層におけるリカレント部において更新された第1隠れベクトル及び第2隠れベクトルを第2ネットワークへと出力する。第2ネットワークは、第1ネットワークが出力した第1隠れベクトル及び第2隠れベクトルに基づいて、グラフが有する特徴を示す特徴ベクトルを出力する。
なお、第1ネットワークは、メッセージ部、マージ部、リカレント部を備えるとしたが、これには限られず、メッセージ部及びリカレント部の少なくとも1つが備えられない構成であってもよい。例えば、メッセージ部を省略し、マージ部がローカルノード、スーパーノード間の何らかの特徴を示すテンソル同士をマージして出力して各ノードの更新隠れベクトルを出力し、リカレント部が隠れベクトルを更新するネットワークを形成してもよい。別の例として、リカレント部を省略し、マージ部がマージした結果を隠れベクトルとして出力するネットワークを形成してもよい。
以下、各構成要素について、具体例を示しながら詳しく説明する。
図2は、本実施形態に係る予測装置、及び、予測装置に備えられるネットワークを訓練する訓練装置の機能を示すブロック図である。
訓練装置1は、予測装置におけるネットワークの訓練を行う。前処理部10は、演算ネットワーク11に入力されるデータを生成する前処理を行う。訓練装置1は、この前処理部10を利用して演算ネットワーク11の訓練を行う装置であり、訓練制御部12と、訓練データ記憶部13と、損失算出部14と、勾配算出部15と、を備える。
訓練制御部12は、訓練データ記憶部13からグラフデータを取得し、前処理部10においてグラフデータをネットワークの入力データへと変換する。前処理部10は、演算ネットワークに変換したデータを入力し、演算ネットワーク11を順伝播させる。演算ネットワーク11は、演算されたグラフの特徴ベクトルを損失算出部14へと出力し、正解データと比較を行い、損失を算出する。勾配算出部15は、損失算出部14が算出した損失に基づいて勾配を算出し、演算ネットワーク11を更新する。損失算出部14と勾配算出部15は、合わせてネットワーク更新部として機能してもよい。
予測装置2は、前処理部10と、訓練装置1により訓練された演算ネットワーク11と、予測制御部22と、予測データ記憶部23と、予測部24と、を備え、グラフデータが入力されると、当該グラフが有する特徴を出力する。
予測制御部22は、予測データ記憶部23に記憶されているグラフデータを、前処理部10を介して演算ネットワーク11へ入力する。予測部24は、演算ネットワーク11が出力した特徴ベクトルから、入力されたグラフの特徴を予測して出力する。
予測装置2は、1つのタスクを処理する演算ネットワーク11を備えていてもよいし、複数のタスクを処理する演算ネットワーク11をそれぞれ備えていてもよい。複数のタスクに対するそれぞれの演算ネットワーク11が備えられている場合、予測制御部22は、ユーザが入力したタスクに応じて、演算ネットワーク11を切り替えてもよい。この場合、ユーザがタスクを指定してもよいし、入力されたグラフを自動的に予測制御部22が判断して、適した演算ネットワーク11を用いるようにしてもよい。
また、訓練装置1と予測装置2は、別個に存在している必要は無く、同一の装置内に備えられていてもよい。装置は、訓練装置1を用いる訓練モードと、予測装置2を用いる予測モードと、を切り替えることにより、訓練及び予測を行う。このような場合、予測装置2が予測した結果をユーザが確認した後に、訓練装置1の訓練データとして訓練データ記憶部13に記憶し、さらにネットワークを更新できるようにしてもよい。
訓練装置1と予測装置2の各構成の動作について説明する。
訓練制御部12は、ユーザから指定された訓練データ、学習方法、又は各種設定を受信し、所望の学習を実行する。訓練に必要となる訓練データは、訓練データ記憶部13に記憶される。格納されている訓練データは、必要なタイミングにおいて訓練制御部12、あるいは、損失算出部14により参照される。訓練制御部12は、前処理部10を介してグラフに関するデータを演算ネットワーク11へと入力した後、演算ネットワーク11の出力に基づき、損失算出部14及び勾配算出部15を制御し、演算ネットワーク11を構成する各種パラメータを更新し、適宜演算ネットワーク11に更新されたパラメータを記憶させる。また、学習が完了した場合には、学習が完了した旨をユーザに通知してもよい。
訓練データ記憶部13は、ユーザが指定するタスクに基づいた訓練データを格納する。訓練データとは、例えば、グラフそのもののデータと、当該グラフが有する特徴に関するデータである。タスクとは、例えば、グラフで示される物質の毒性を識別する識別器、グラフがあるターゲットに対してどの程度親和性があるかを回帰する回帰装置、といった機能のことを示す。訓練データ記憶部13は、訓練制御部12の要求に応じて、1又は複数の学習データを取り出して、前処理部10、演算ネットワーク11、損失算出部14、勾配算出部15等に送信する。
前処理部10は、各記憶部に格納された値を受信し、それらを演算ネットワーク11の設計に合わせた表現へと変換する。図3は、前処理部10の機能を示すブロック図である。前処理部10は、ローカルノード特徴量取得部100と、隣接行列取得部101と、スーパーノード特徴量取得部102と、を備える。
ローカルノード特徴量取得部100は、入力されたグラフのデータからローカルノード及びグラフ全体としての特徴量を抽出する。例えば、化合物の構造式がグラフとして入力されると、上述したように、各第1ノードの特徴量として、分子又は原子をインデクスで表したもの、第1ノードに接続されるエッジの数等からローカルノード特徴量が抽出される。
隣接行列取得部101は、一例として、グラフにおけるローカルノードの接続情報として、グラフの隣接行列を取得する。例えば、上述したように、各ノード間に存在するエッジの情報を抽出し、当該エッジの情報から、隣接行列を取得する。隣接行列は、1つであるとは限られず、上述したように、各エッジの種類、ノード間の接続状況に基づいて、複数生成されてもよい。
スーパーノード特徴量取得部102は、グラフ全体の特徴量を表すスーパーノード特徴量を抽出する。例えば、化合物グラフの場合、上述したように、グラフのノード数、グラフ直径、原子間の結合種類の数等の量を利用できる。
前処理部10は、各構成要素が取得した特徴量を演算ネットワーク11へと出力する。各特徴量が入力された演算ネットワーク11は、訓練制御部12により学習が行われ、ネットワークを形成する各パラメータが決定される。
図4は、訓練制御部12により制御される、すなわち、学習時における演算ネットワーク11の機能を示すブロック図である。演算ネットワーク11は、定数記憶部110と、モデルパラメータ記憶部111と、隠れベクトル初期化部112と、第1ネットワーク113と、第2ネットワーク114と、を備える。
定数記憶部110は、ネットワーク全体の構成及び定数を記憶する。例えば、ローカルノード及びスーパーノードの隠れベクトル(第1隠れベクトル及び第2隠れベクトル)の字数、第1ネットワーク、第2ネットワークのレイヤ数、その他ハイパーパラメータといった、学習の最適化の対象とならない定数の情報を記憶する。定数記憶部110は、モデルの種類等に対応して、これらの定数を複数種類記憶してもよい。
モデルパラメータ記憶部111は、隠れベクトル初期化部112、第1ネットワーク113、第2ネットワーク114における内部の各関数、ニューラルネットワークのパラメータ等の学習最適化の対象となる情報を記憶する。学習において、順伝播のフェーズにおいては、モデルパラメータ記憶部111に記憶されているパラメータに基づいて各構成のパラメータを設定して隠れベクトルの計算を行う。逆伝播のフェーズにおいては、各構成において更新されたパラメータをモデルパラメータ記憶部111が記憶する。
隠れベクトル初期化部112は、前処理部10で取得された各特徴量を第1ネットワーク113及び第2ネットワーク114における計算に適したベクトルへと変換し、変換された各ベクトルを仮想的に第0層の隠れベクトルとして出力する。以下、ローカルノードの数をn、第1ネットワーク113のレイヤ数をLとする。
例えば、ローカルノードiのローカルノード特徴量として原子のインデクスを用いる場合には、当該インデクスを利用したワンホットベクトル及びこれらに対する任意の関数による変換結果をローカルノードiの隠れベクトルh(0,i)として定義できる。ここで、h(l,i)は、第lレイヤにおけるi番目のローカルノードの隠れベクトル(第1隠れベクトル)を表す。全てのノードiについて、隠れベクトル初期化部112は、h(0,i)を初期化する。同様に、隣接行列も任意に変換されて利用される。なお、隣接行列は、変換せずにネットワークに入力されてもよい。
スーパーノード特徴量も同様に変換される。スーパーノードの第lレイヤにおける隠れベクトル(第2隠れベクトル)は、g(l)と表され、隠れベクトル初期化部112により、スーパーノード特徴量から任意の変換により生成される。このようにスーパーノードの隠れベクトルの初期値として、乱数を代入するのではなく、グラフ全体の特徴量を利用することにより、学習の効率化を図ることが可能となる。例えば、ローカルノードの数とエッジの数とを並べた2次元の特徴量をスーパーノード特徴量として用いることができ、この2次元ベクトルを線形、又は、非線形変換方式により変換したベクトルをg(0)として用いてもよい。
隠れベクトル初期化部112により初期化された特徴量は、第1ネットワーク113へと入力され、各層において隠れベクトルが更新される。図5は、第1ネットワーク113の機能を示すブロック図である。
第1ネットワーク113は、第1更新レイヤ113A、・・・、第L更新レイヤ113LのL層からなるネットワークを備える。各更新レイヤは、メッセージ部115と、マージ部116と、リカレント部117と、を備える。メッセージ部115には、前のレイヤの出力が入力される。第1更新レイヤ113Aにおいては、隠れベクトル初期化部112により初期化された第1隠れベクトル及び第2隠れベクトルが入力される。
メッセージ部115は、入力された隠れベクトルに基づいて、接続されている互いのノード同士に対するメッセージを生成し、マージ部116へと出力する。ノード間の接続は、例えば、隠れベクトル初期化部112が出力した隣接行列又はこれに類する量に基づいて参照される。
マージ部116は、メッセージ部115から入力されたメッセージに基づいて、各隠れベクトルを更新するためのベクトルを、各レイヤにおいて入力された隠れベクトルに基づいた重みを算出し、按分してマージすることにより生成する。すなわち、マージ部116は、第1隠れベクトルと、第2隠れベクトルと、に基づいて、第1隠れベクトルに対する重みと第2隠れベクトルに対する重みとを算出し、算出された重みにしたがってマージして、第1隠れベクトル及び第2隠れベクトルを更新するための第1更新隠れベクトル及び第2更新隠れベクトルを生成する。生成された更新隠れベクトルは、リカレント部117へと出力される。
リカレント部117は、各層における第1隠れベクトルと、第2隠れベクトル、そして、マージ部116が出力した第1更新隠れベクトルと、第2更新隠れベクトルとに基づいて、次の更新レイヤの入力となる隠れベクトルを出力する。次のレイヤにおいては、リカレント部117が出力した隠れベクトルがメッセージ部115へと入力され、同様の隠れベクトルの更新が第Lレイヤまで繰り返される。最終層(第Lレイヤ)においては、リカレント部117は、第2ネットワーク114へと隠れベクトルを出力する。
このように、第1ネットワーク113は、入力された第1隠れベクトル及び第2隠れベクトルからそれぞれの隠れベクトルを更新するための更新隠れベクトルを按分することにより算出し、当該更新隠れベクトルに基づいて、それぞれの隠れベクトルを更新する。
図4に戻り、第2ネットワーク114は、第1ネットワーク113により更新された第1隠れベクトル及び第2隠れベクトルと、隣接行列とに基づいて、グラフ全体の表現ベクトルを生成する。第2ネットワーク114は、生成されたn個の第1隠れベクトルh(L,0:n)に基づいてマージされたベクトルh(merged)を算出し、h(merged)と、g(L)を任意の関数により変換し、最終的な読み出しベクトルrを出力する。
損失算出部14は、第2ネットワーク114により求められたrと、訓練データとして与えられている出力とを博して、ロス(損失)を算出する。ロス関数は、ユーザのタスクの目的に合わせて任意の関数とすることもできる。例えば、識別問題であれば、クロスエントロピー、回帰問題であれば2乗誤差を用いることができる。さらには、ロス関数としては既存の多くのDNN学習フレームワークに備え付けの関数等を流用してもよい。
勾配算出部15は、損失算出部14が出力した結果に基づいて、モデルパラメータ記憶部111に記憶されている変数の更新に必要な勾配を算出する。算出された勾配を用いて、モデルパラメータ記憶部111内の各パラメータの値を更新する。勾配は、ロス関数を各パラメータで微分した値を用いることが一般的である。勾配の計算方法、及び、計算された勾配をスケールさせる学習率等の実装は、既存のDNN学習フレームワークの関数、設定等を流用してもよい。
このように、訓練制御部12は、各々の構成、主に、前処理部10、モデルパラメータ記憶部111、損失算出部14、勾配算出部15を制御することにより、第1ネットワーク113及び第2ネットワーク114を訓練する。
図6は、上述した訓練の流れを示すフローチャートである。
まず、グラフのデータを入力する(S100)。グラフデータの入力は、例えば、1つ1つのグラフを入力するものではなく、上述したように、訓練データ記憶部13にグラフデータを蓄積し、訓練制御部12により随時入力する。
次に、前処理部10は、データの前処理を行う(S102)。上述したように、入力されたグラフデータから各特徴量を抽出し、データの前処理を行う。
次に、隠れベクトル初期化部112は、前処理部10が前処理したデータを用いて第1隠れベクトル及び第2隠れベクトルを初期化する(S104)。隣接行列に対して変換を行う場合も、隠れベクトル初期化部112が処理を行ってもよい。
次に、第1ネットワーク113及び第2ネットワーク114へと初期化された隠れベクトルを入力し、ネットワークを順伝播させる(S106)。この処理は、訓練制御部12により行ってもよいし、別途隠れベクトル計算部及びレイヤ更新計算部を備え、これらの計算部により行ってもよい。
次に、損失算出部14は、第2ネットワーク114から出力された結果と、訓練データ記憶部13に格納されている結果のデータとを比較して損失を算出する(S108)。
次に、訓練制御部12は、算出された損失に基づいて、学習を終了するか否かを判断する(S110)。学習を終了する場合(S110:YES)、処理を終了する。学習の終了は、損失算出部14により算出された損失値に基づいてもよいが、これには限られない。他の判断手法で判断する場合、S110は、S108の前に処理されてもよい。他の判断手法とは、例えば、所定のエポック数分の処理が終了、又は、交差検証値が所定のしきい値より低くなった等、一般的に用いられている手法である。
学習を終了しない場合(S110:NO)、学習を続行する。勾配算出部15は、損失算出部14が算出した損失の各パラメータに対する微分値を求めて各パラメータに対する勾配を算出する(S110)。勾配の算出は、単純な微分を用いるだけではなく、種々に考案されている一般的な手法を用いてもよい。
次に、訓練制御部12は、算出された勾配を逆伝播させ、ネットワークのパラメータを更新する(S114)。更新されたパラメータは、モデルパラメータ記憶部111に記憶される。そして、この更新されたパラメータを用いて順伝播からの作業を繰り返し、学習の終了条件を満たすまで訓練が行われる。このように、訓練制御部12は、ネットワークを更新するネットワーク更新部として機能してもよいし、訓練装置1は、ネットワーク更新部を別途備えていてもよい。これらの学習は、ミニバッチ等により効率化されていてもよい。以下においては、ネットワーク更新部は、第1ネットワーク113及び第2ネットワーク114の双方を更新するものとして説明するが、これには限られず、いずれか一方だけを更新できるものであってもよい。すなわち、第2ネットワーク114のパラメータを更新せずに、第1ネットワーク113のネットワークのパラメータを更新するものであってもよいし、第1ネットワーク113のパラメータを更新せずに、第2ネットワーク114のパラメータを更新するものであってもよい。
次に、第1ネットワーク113の内部における処理について詳しく説明する。図7は、第lレイヤにおけるデータの流れを示す図である。以下の説明における数式等は、一例として示しているものであり、この他の数式が本実施形態のフレームワーク内で用いられることがないということを意味しているわけではない。
メッセージ部115は、まず、第l−1レイヤの出力した第1隠れベクトルと、第2隠れベクトルとを取得する。第1レイヤである場合には、隠れベクトル初期化部112が初期化した各隠れベクトルを取得する。
メッセージ部115は、接続されているノード間において、各隠れベクトルがどの程度更新に影響を与えるかを示すパラメータであるメッセージを生成する。すなわち、ローカルノードからローカルノードへのn個の第1メッセージ、ローカルノードからスーパーノードへの第2メッセージ、スーパーノードからローカルノードへのn個の第3メッセージ、スーパーノードからスーパーノードへの第4メッセージの4種類のメッセージをメッセージ部115が生成する。
第1メッセージは、各ローカルノードに対して、接続されるローカルノードからの影響度を数値化するものである。なお、ここで、接続されるノードとは、自ノードをも含む概念である。まず、k=1,・・・,Kとし、K種類のヘッドを準備する。このヘッドごとにローカルノードのメッセージを生成する。ヘッドとは、それぞれ異なる種類の情報を取得するものである。このヘッドを複数用いることにより、1つの隠れベクトルに対して複数種類の影響度を算出することが可能となり、特徴量の抽出の性能を向上させる。各ヘッドkにおいて、例えば、異なるパラメータ、異なる関数を用いて計算が行われる。
まず、h(l−1,i)と、他のローカルノードとの関連の強さを表す重み(アテンションウェイト)を計算する。第lレイヤ、ローカルノードi、ヘッドkにおけるローカルノードjからローカルノードiにむけてのアテンションウェイトは、h(l−1,i)とh(l−1,j)とを入力とする線形又は非線形の任意の関数により計算される。この計算においては、ローカルノードiとローカルノードj間の接続を隣接行列から抽出し、当該接続(エッジ)の種類により異なる演算を行う。第1メッセージを生成するタイミングにおいて、各ノード間の接続状況を抽出して計算するのではなく、エッジが存在しない場合には、このアテンションウェイトを0とすることにより計算してもよい。
第lレイヤ、ヘッドkにおけるローカルノードjからローカルノードiへのアテンションウェイトは、0以上の実数値であり、例えば、以下の式により計算される。ただし、softmax()は、ソフトマックス関数であり、ベクトル又は行列の右上のTは転置を表す。
Figure 2020068000
ここで、Aは、学習により更新されるパラメータである。
全てのjについてαを求めた後、正規化してノードiに対するアテンションウェイトの和が1となるように正規化する。求められたアテンションウェイトを用いて、以下に示すように、ノードiに接続される全てのノードjの隠れベクトルh(l−1,j)の重み付き和を計算する。この計算は、任意の関数を利用する。
Figure 2020068000
ここで、UとVは、学習により更新されるパラメータである。またNは、ノードiに接続されるノードを示す。接続されていないノード間のアテンションウェイトαを0とする場合には、Nは必ずしも設定しなくともよく、全てのノード間について計算を行ってもよい。メモリの使用量と、計算機のコストに鑑みて、自由に設計することが可能である。
全てのkについて[数2]に示すh〜i,j,kを計算した後、これらを統合したベクトルを計算して第1メッセージを生成する。統合方法は、単純にK個の重み付き和を結合してもよいし、さらには任意の関数で変換してもよい。メッセージ部115は、例えば、以下の式に基づいて統合したベクトルh〜l,iを計算し、第1メッセージとして出力する。
Figure 2020068000
ここで、tanh()は、ハイパボリックタンジェントを、concatkは、kごとに求められたベクトルの結合を意味する。また、Wは、学習により更新されるパラメータである。第1メッセージは、ノードiごとに、計n個のメッセージが生成される。
第2メッセージは、各ローカルノードからスーパーノードへと渡されるパラメータである。例えば、第l−1レイヤにおける全てのローカルノードの隠れベクトルに基づいて、スーパーノードに対する第2メッセージが生成される。第2メッセージも同様に、K種類のヘッドにおいて、ヘッドごとに異なるパラメータ、関数を用いて求められる。
まず、g(l−1)とそれぞれのh(l−1,i)との間の関連の強さを示すアテンションウェイトβを計算する。上記のαと同様に、βは、線形又は非線形の任意の関数により求められる実数値である。αと同様に、βもiに対する和が1となるように正規化される。メッセージ部115は、例えば、以下の式に基づいてアテンションウェイトβを計算する。
Figure 2020068000
ここで、Bは、学習により更新されるパラメータである。
全てのiについてβを求めた後、全てのiに対して隠れベクトルh(l−1,i)の重み付き和を計算する。
Figure 2020068000
ここで、V(S)は、学習により更新されるパラメータである。
全てのkについて[数5]に示すh〜l,super,kを計算した後、これらを統合したベクトルを計算して第2メッセージを生成する。統合方法は、上述と同様である。メッセージ部115は、例えば、以下の式に基づいて統合したベクトルh〜l,superを計算し、第2メッセージとして出力する。
Figure 2020068000
ここで、W(S)は、学習により更新されるパラメータである。
第3メッセージは、スーパーノードから各ローカルノードへと渡されるパラメータである。例えば、第l−1レイヤにおけるスーパーノードの隠れベクトルに基づいて、各ローカルノードに対する第3メッセージが生成される。メッセージ部115は、ローカルノードiに対して、g(l−1)を線形又は非線形の任意の関数で変換して第3メッセージを生成する。
メッセージ部115は、例えば、以下の式に基づいてn個のg〜l,iを計算し、第3メッセージとして出力する。
Figure 2020068000
ここで、Fは、学習により更新されるパラメータである。
第4メッセージは、スーパーノードからスーパーノードへと渡されるパラメータである。例えば、第l−1レイヤにおけるスーパーノードの隠れベクトルに基づいて第4メッセージが生成される。メッセージ部115は、g(l−1)を線形又は非線形の任意の関数で変換して第4メッセージを生成する。
メッセージ部115は、例えば、以下の式に基づいてg〜l,superを計算し、第4メッセメッセージ出力する。
Figure 2020068000
ここで、F(S)は、学習により更新されるパラメータである。
このように、メッセージ部115は、第1隠れベクトルと第2隠れベクトルとから、それぞれ接続されているノード同士でどの程度影響を与えるかのパラメータを算出する。
メッセージ部115が生成した第1から第4の各メッセージは、マージ部116へと入力される。マージ部116は、各メッセージを統合して、第1隠れベクトル及び第2隠れベクトルの更新案となるベクトルである、第1更新隠れベクトル及び第2更新隠れベクトルを生成して出力する。
マージ部116は、第1メッセージ及び第3メッセージから書くローカルノードiの隠れベクトルの更新案となるベクトルを出力する。すなわち、各ローカルノードに対応するn個の第1隠れベクトルの更新案となるn個の第1更新隠れベクトルを出力する。各ローカルノードiに対して、同様の処理が行われる。
まず、第1メッセージと、第3メッセージとの按分の重みであるゲートウェイトを計算する。マージ部116は、第lレイヤにおけるローカルノードiのゲートウェイトを、第1メッセージと第3メッセージを線形又は非線形の任意の関数で変換して生成する。ゲートウェイトは、各要素が0以上1以下の実数値を取るベクトルとして表される。マージ部116は、例えば、以下のようにゲートウェイトを計算する。
Figure 2020068000
ここで、σは、例えば、シグモイド関数である。Gは、学習により更新されるパラメータである。
このゲートウェイトは、第1メッセージ及び第3メッセージから生成されるものであり、自動的かつ適応的にメッセージ同士をマージすることを可能とする。計算されたゲートウェイトを按分比として、マージ部116は、ローカルノードiに対する第1更新隠れベクトルを生成する。なお、この按分は、各要素の単純な線形重み付き和でもよいし、さらに複雑な任意の関数を利用して求めるものであってもよい。マージ部116は、例えば、以下のようにマージして第1更新隠れベクトルを生成する。
Figure 2020068000
同様に、第2メッセージと、第4メッセージとの按分の重みであるゲートウェイトを計算し、第2更新隠れベクトルを生成する。
Figure 2020068000
Figure 2020068000
このように、マージ部116は、互いに性質の異なる種類のデータを按分して統合するゲートとして機能する。マージ部116により生成された第1更新隠れベクトルh^l,iと、第2更新隠れベクトルg^は、リカレント部117へと入力される。
リカレント部117は、第lレイヤにおける全ての第1隠れベクトル、第2隠れベクトル、全ての第1更新隠れベクトル、及び、第2更新隠れベクトルを用いて、第lレイヤの出力である第1隠れベクトルと第2隠れベクトルとを生成して出力する。
第1隠れベクトルは、各ローカルノードiにおいて計算され、その全てが第lレイヤの第1隠れベクトルとして出力される。この計算には、一般的なLSTM(Long-Short Term Memory)、GRU(Grated Recurrent Unit)等のゲーティング機能を有するリカレントネットワークが利用される。例えば、リカレント部117は、以下のように第l−1レイヤの第1隠れベクトル及び生成された第1更新隠れベクトルを用いて、GRUにより第lレイヤの第1隠れベクトルを生成する。
Figure 2020068000
同様に、リカレント部117は、第2隠れベクトルと第2更新隠れベクトルとを用いて第2隠れベクトルを更新する。例えば、リカレント部117は、以下のように第l−1レイヤの第2隠れベクトル及び生成された第2更新隠れベクトルを用いて、GRUにより第lレイヤの第2隠れベクトルを生成する。
Figure 2020068000
リカレント部117の出力した第lレイヤの第1隠れベクトル及び第2隠れベクトルは、第l+1レイヤの入力となり、第Lレイヤまで隠れベクトルの更新が繰り返される。第Lレイヤにおいて出力された隠れベクトルは、第1ネットワーク113の出力となる。
第1ネットワーク113が出力した第1隠れベクトル及び第2隠れベクトルは、第2ネットワークへと入力される。第2ネットワーク114は、計算、更新された第1隠れベクトル、第2隠れベクトル、及び、隣接行列を用いて、グラフ全体の表現ベクトルを計算する。
まず、グラフデータごとにベクトルの数、すなわち、ローカルノードの数nが異なるので、これらのn個のベクトルを1つの固定長ベクトルに縮約する。第1ネットワーク113から出力されたn個の第1隠れベクトルは、任意の縮約関数、例えば、単純平均、DNN(Deep Neural Network)、Setout関数等の関数に入力され、固定長の単一ベクトルh(merged)に変換される。
次に、第2ネットワーク114は、h(merged)とg(L)とを任意の関数に入力し、読み出しベクトルrを出力する。例えば、以下のようにDNNにより計算される。
Figure 2020068000
損失算出部14は、算出されたrを用いて損失を計算する。損失の計算は、上述したように、一般的に用いられているロス関数を用いてもよいし、適切に損失の計算ができるものであれば、新規な線形又は非線形の任意の関数を用いてもよい。
勾配算出部15は、損失算出部14が算出した損失に基づいて、各パラメータに対する勾配を求める。訓練制御部12は、この勾配を第1ネットワーク及び第2ネットワークについて逆伝播させることにより、各ネットワークを構成するパラメータを更新する。
メッセージ部115、マージ部116、リカレント部117における内部のパラメータは、レイヤごとに異なるものである。すなわち、各レイヤにおいてそれぞれが適切な按分比率でマージを行い、第1隠れベクトルと第2隠れベクトルとが更新される。訓練においても同様であり、勾配算出部15は、それぞれのレイヤにおけるパラメータに対する勾配を算出する。このそれぞれのレイヤにおいて算出された各パラメータの勾配に基づいて、訓練制御部12は、逆伝播させてパラメータの更新を行う。
なお、シグモイド関数等は、適切に0と1との間で値を取り、微分できるものであれば、どのような関数であってもよい。また、微分できない関数であっても、勾配算出部15により適切に勾配が求められるような関数であってもよい。
図8は、予測モードにおける演算ネットワークの機能を示すブロック図である。予測モード、又は、予測装置2は、前述された訓練装置1により最適化されたパラメータを用いた演算ネットワーク11を備える。予測モードにおいては、ユーザが指定したグラフの種類に基づいて、定数記憶部110及びモデルパラメータ記憶部111から予測制御部22が適切なパラメータを選択し、第1ネットワーク113及び第2ネットワーク114を形成する。これには限られず、入力されたグラフデータを自動的にどのような種類のデータであるかを判別し、予測装置2の予測制御部22が自動的にモデルパラメータ等を取得してネットワークを形成するようにしてもよい。
予測制御部22は、予測データ記憶部23に格納されているデータ、又は、ユーザが入力した予測対象となるデータを前処理部10が処理したデータを演算ネットワーク11へと入力する。入力されたデータは、隠れベクトル初期化部112において第1隠れベクトル及び第2隠れベクトルへと変換され、第1ネットワーク113及び第2ネットワーク114へと入力される。
第1ネットワークは、各ローカルノードにおける第1隠れベクトル及び第2隠れベクトルを算出し、第2ネットワークへと出力する。各隠れベクトルが入力された第2ネットワーク114は、読み出しベクトルrを生成し、予測部24へと出力する。
予測部24は、入力された読み出しベクトルrを適切に処理して、ユーザに理解できる形として出力、又は、適切なデータベース等に出力する。
以上のように、本実施形態によれば、グラフデータが入力されるとグラフ全体の特徴を出力するネットワークを学習する訓練装置1及び当該訓練装置1により生成されたネットワークを有する予測装置2を実現することが可能である。このグラフの処理は、グラフのノードの全てと接続されるスーパーノードを設定し、かつ、このスーパーノードの隠れベクトルを定義し、グラフのノードと、スーパーノードとのそれぞれの隠れベクトルを適応的にマージすることにより、ノード又はエッジの持つ個々の特徴と、グラフ全体が有する特徴とを適切に統合することが可能である。
また、スーパーノードはその初期値として観測値が入力されるため、グラフ全体としての特徴をネットワークに対して反映させることができる。このように、個々のノードと、グラフ全体の特徴という異なる性質の隠れベクトルを適応的に統合することにより、グラフ全体の特徴を高精度に出力することが可能となる。また、アテンションウェイトの導入により、柔軟なネットワークの生成を可能としている。
なお、前述した実施形態に限られず、種々の変形例が考えられることを理解されたい。例えば、メッセージ部115は、アテンションウェイトの計算を簡略化することが可能である。第1メッセージのアテンションウェイトの演算においては、各ローカルノード間のエッジの種類によらず同じ関数又は同じパラメータを用いて計算してもよい。アテンションウェイトを計算する関数を、入力がない関数として、全てのローカルノードにおいて事前に固定したルールにより重みを与えてもよい。
メッセージ部115はまた、ヘッドを考慮しなくてもよい。すなわち、K=1として、ヘッドが1つしかないものとしてローカルノードからのメッセージを生成してもよい。
メッセージ部115は、一部又は全ての関数又はパラメータをレイヤごとに共有してもよい。すなわち、レイヤによらず、共有した関数、パラメータは、同じ形、同じ値を有するものであってもよい。
マージ部116は、ゲートウェイトを逐次的に計算する代わりに、入力によらない固定のパラメータを有する行列により、線形結合や単純な算術平均によるメッセージの統合を行ってもよい。また、マージ部116についてもメッセージ部115と同様に、レイヤによらず共有した関数又はパラメータを有するものであってもよい。
リカレント部117として、ゲーティング機能を持たないリカレントユニットを利用してもよい。例えば、第1隠れベクトルと第1更新隠れベクトル、第2隠れベクトルと第2更新隠れベクトルとを、それぞれ線形結合の関数等を用いて結合してもよい。
前述した実施形態における訓練装置1及び予測装置2において、各機能は、アナログ回路、デジタル回路又はアナログ・デジタル混合回路で構成された回路であってもよい。また、各機能の制御を行う制御回路を備えていてもよい。各回路の実装は、ASIC(Application Specific Integrated Circuit)、FPGA(Field Programmable Gate Array)等によるものであってもよい。
上記の全ての記載において、訓練装置1及び予測装置2の少なくとも一部はハードウェアで構成されていてもよいし、ソフトウェアで構成され、ソフトウェアの情報処理によりCPU(Central Processing Unit)等が実施をしてもよい。ソフトウェアで構成される場合には、訓練装置1、予測装置2及びその少なくとも一部の機能を実現するプログラムをフレキシブルディスクやCD−ROM等の記憶媒体に収納し、コンピュータに読み込ませて実行させるものであってもよい。記憶媒体は、磁気ディスクや光ディスク等の着脱可能なものに限定されず、ハードディスク装置やメモリなどの固定型の記憶媒体であってもよい。すなわち、ソフトウェアによる情報処理がハードウェア資源を用いて具体的に実装されるものであってもよい。さらに、ソフトウェアによる処理は、FPGA等の回路に実装され、ハードウェアが実行するものであってもよい。ジョブの実行は、例えば、GPU(Graphics Processing Unit)等のアクセラレータを使用して行ってもよい。
例えば、コンピュータが読み取り可能な記憶媒体に記憶された専用のソフトウェアをコンピュータが読み出すことにより、コンピュータを上記の実施形態の装置とすることができる。記憶媒体の種類は特に限定されるものではない。また、通信ネットワークを介してダウンロードされた専用のソフトウェアをコンピュータがインストールすることにより、コンピュータを上記の実施形態の装置とすることができる。こうして、ソフトウェアによる情報処理が、ハードウェア資源を用いて、具体的に実装される。
図8は、本発明の一実施形態におけるハードウェア構成の一例を示すブロック図である。訓練装置1及び予測装置2は、プロセッサ71と、主記憶装置72と、補助記憶装置73と、ネットワークインタフェース74と、デバイスインタフェース75と、を備え、これらがバス76を介して接続されたコンピュータ装置7として実現できる。
なお、図8のコンピュータ装置7は、各構成要素を一つ備えているが、同じ構成要素を複数備えていてもよい。また、1台のコンピュータ装置7が示されているが、ソフトウェアが複数のコンピュータ装置にインストールされて、当該複数のコンピュータ装置それぞれがソフトウェアの異なる一部の処理を実行してもよい。
プロセッサ71は、コンピュータの制御装置および演算装置を含む電子回路(処理回路、Processing circuit、Processing circuitry)である。プロセッサ71は、コンピュータ装置7の内部構成の各装置などから入力されたデータやプログラムに基づいて演算処理を行い、演算結果や制御信号を各装置などに出力する。具体的には、プロセッサ71は、コンピュータ装置7のOS(Operating System)や、アプリケーションなどを実行することにより、コンピュータ装置7を構成する各構成要素を制御する。プロセッサ71は、上記の処理を行うことができれば特に限られるものではない。訓練装置1、予測装置2及びそれらの各構成要素は、プロセッサ71により実現される。ここで、処理回路とは、1チップ上に配置された1又は複数の電気回路を指してもよいし、2つ以上のチップあるいはデバイス上に配置された1又は複数の電気回路を指してもよい。
主記憶装置72は、プロセッサ71が実行する命令および各種データなどを記憶する記憶装置であり、主記憶装置72に記憶された情報がプロセッサ71により直接読み出される。補助記憶装置73は、主記憶装置72以外の記憶装置である。なお、これらの記憶装置は、電子情報を格納可能な任意の電子部品を意味するものとし、メモリでもストレージでもよい。また、メモリには、揮発性メモリと、不揮発性メモリがあるが、いずれでもよい。訓練装置1及び予測装置2内において各種データを保存するためのメモリは、主記憶装置72または補助記憶装置73により実現されてもよい。例えば、前述した各記憶部の少なくとも一部は、この主記憶装置72又は補助記憶装置73に実装されていてもよい。別の例として、アクセラレータが備えられている場合には、前述した各記憶部の少なくとも一部は、当該アクセラレータに備えられているメモリ内に実装されていてもよい。
ネットワークインタフェース74は、無線または有線により、通信ネットワーク8に接続するためのインタフェースである。ネットワークインタフェース74は、既存の通信規格に適合したものを用いればよい。ネットワークインタフェース74により、通信ネットワーク8を介して通信接続された外部装置9Aと情報のやり取りが行われてもよい。
外部装置9Aは、例えば、カメラ、モーションキャプチャ、出力先デバイス、外部のセンサ、入力元デバイスなどが含まれる。また、外部装置9Aは、訓練装置1及び予測装置2の構成要素の一部の機能を有する装置でもよい。そして、コンピュータ装置7は、訓練装置1及び予測装置2の処理結果の一部を、クラウドサービスのように通信ネットワーク8を介して受け取ってもよい。
デバイスインタフェース75は、外部装置9Bと直接接続するUSB(Universal Serial Bus)などのインタフェースである。外部装置9Bは、外部記憶媒体でもよいし、ストレージ装置でもよい。各記憶部は、外部装置9Bにより実現されてもよい。
外部装置9Bは出力装置でもよい。出力装置は、例えば、画像を表示するための表示装置でもよいし、音声などを出力する装置などでもよい。例えば、LCD(Liquid Crystal Display)、CRT(Cathode Ray Tube)、PDP(Plasma Display Panel)、スピーカなどがあるが、これらに限られるものではない。
なお、外部装置9Bは入力装置でもよい。入力装置は、キーボード、マウス、タッチパネルなどのデバイスを備え、これらのデバイスにより入力された情報をコンピュータ装置7に与える。入力装置からの信号はプロセッサ71に出力される。
上記の全ての記載に基づいて、本発明の追加、効果又は種々の変形を当業者であれば想到できるかもしれないが、本発明の態様は、上記した個々の実施形態に限定されるものではない。特許請求の範囲に規定された内容及びその均等物から導き出される本発明の概念的な思想と趣旨を逸脱しない範囲において種々の追加、変更及び部分的削除が可能である。例えば、前述した全ての実施形態において、説明に用いた数値は、一例として示したものであり、これらに限られるものではない。
前述した実施形態において、各実施形態における計算は、ローカルノード数に相当した隠れベクトル列を入力とし、それぞれのベクトルに対応したメッセージベクトルに相当する量を出力する、GCNあるいは任意の計算モデルに利用できるものである。例えば、前述した説明において用いているGAT(Graph Attention Networks)、RGAT(Relational Graph Attention Networks)のようにアテンション技術を用いるものにはそのまま適用することが可能である。この他にも、GGNN(Gated Graph Sequence Neural Network)のようにゲート関数を用いるもの、RSGCN(Renormalized Spectral Graph Convolutional Network)のようにそのいずれも利用しないものにも適用することができる。
アテンションを利用しない場合、例えば、GGNNを利用する場合におけるゲート関数を用いた例について説明する。まず、前層の第l−1レイヤの出力と隣接行列の情報を参照して、第lレイヤにおけるローカルノードiの一時変数ベクトルal、iを計算する。
Figure 2020068000
ここで、行列Aは、ローカルノード数をn、ローカルノードの隠れベクトルをDとすると2D×nD次元となる。第i番目のローカルノードに対して、隣接行列の情報を参照して第j番目のローカルノードとの間にエッジが存在しない場合、Aの第D×(j−1)+1列からD×j列の値は、ゼロとする。上式右辺のHl−1は、N個のローカルノード隠れベクトルを連結したnD次元のベクトルであり、al,iは、2D次元のベクトルである。また、bは、バイアスを表す2D次元のベクトルである。
出力において、更新量に相当するベクトルを計算する。この計算におけるゲートrと更新ベクトルh^l,iは、以下のように計算する。
Figure 2020068000
Figure 2020068000
最終的な出力にもゲートを用いる。ここで、odotはベクトルの要素ごとの積を表す。
Figure 2020068000
メッセージ部115は、例えば、以下の式に基づいてゲートで結合したベクトルh〜l,iを計算し、第1メッセージとして出力する。
Figure 2020068000
この第1メッセージを用いて、上記の処理を行うことにより、アテンションを用いずにゲート関数を用いた場合に対応することが可能となる。
メッセージ部115は、ゲート関数を用いる場合には、一時変数ベクトルの計算は、線形計算だけではなく、非線形関数を重畳してもよい。また、ゲート関数は、例えば、シグモイド関数であるが、これには限られず別の非線形関数としてもよい。
別の例として、アテンションもゲート関数も用いない例について説明する。この例では、第lレイヤのローカルノードの隠れベクトルを並べた行列をXと記載する。Xの次元は、n×Dとなる。
Figure 2020068000
Aは、隣接行列に対角項の単位行列を加算したものを表す。Θは、Xと同じサイズのパラメータ行列である。Dは、対角行列であり、以下のように示される。
Figure 2020068000
この場合、出力するベクトルh〜l,iは、Xの第i行目のベクトルを利用する。このようなメッセージを用いることにより、RSGCNのようにアテンションもゲート関数も用いない場合に対応することが可能となる。
メッセージ部115は、このようにアテンションもゲート関数も用いない場合には、さらに、非線形変換を重畳してもよい。
以上のように、第1メッセージの計算には、既存のGCN、あるいは、より一般のDNNを対象とするネットワークにしたがい適切に入れ替えて利用することができる。どのような計算モデルを用いた場合においても、スーパーノード及びマージ部116、リカレント部117の処理により、グラフデータの解析性能の向上を図ることが可能である。
1:訓練装置、10:前処理部、100:ローカルノード特徴量取得部、101:隣接行列取得部、102:スーパーノード特徴量取得部、11:演算ネットワーク、110:定数記憶部、111:モデルパラメータ記憶部、112:隠れベクトル初期化部、113:第1ネットワーク、114:第2ネットワーク、115:メッセージ部、116:マージ部、117:リカレント部、12:訓練制御部、13:訓練データ記憶部、14:損失算出部、15:勾配算出部、2:予測装置、22:予測制御部、23:予測データ記憶部、24:予測部

Claims (21)

  1. グラフのデータを入力すると、前記グラフの特徴を予測するネットワークを訓練する訓練装置であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージする、マージ部、
    を各層に備え、前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新する、第1ネットワークと、
    前記第1ネットワークが出力した前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、第2ネットワークと、
    前記第2ネットワークが出力した前記グラフの特徴の損失を算出する、算出部と、
    前記損失に基づいて、少なくとも前記第1ネットワーク及び前記第2ネットワークのいずれか一方を更新する、ネットワーク更新部と、
    を備える訓練装置。
  2. 前記マージ部は、前記第1隠れベクトルに対する重み及び前記第2隠れベクトルに対する重みを算出し、算出された重みに基づいて前記第1隠れベクトルと前記第2隠れベクトルとをマージして、第1更新隠れベクトル及び第2更新隠れベクトルを生成し、
    前記第1ネットワークは、前記第1更新隠れベクトルに基づいて前記第1隠れベクトルを更新し、前記第2更新隠れベクトルに基づいて前記第2隠れベクトルを更新する、
    請求項1に記載訓練装置。
  3. 前記ネットワーク更新部は、前記マージ部における前記第1隠れベクトルと、前記第2隠れベクトルとをマージする比率を適応的に更新する、請求項1又は請求項2に記載の訓練装置。
  4. 前記マージ部は、ゲートとして動作する、請求項1乃至請求項3のいずれかに記載の訓練装置。
  5. 前の層で更新された前記第1ノードそれぞれから、接続される前記第1ノードへと伝達するパラメータである第1メッセージと、
    前の層で更新された前記第1ノードそれぞれから、前記第2ノードへと伝達する第2メッセージと、
    前の層で更新された前記第2ノードから、前記第1ノードへと伝達する第3メッセージと、
    前の層で更新された前記第2ノードから、前記第2ノードへと伝達する第4メッセージと、
    を生成する、メッセージング部をさらに備え、
    前記マージ部は、前記第1メッセージ及び前記第3メッセージに基づいて、及び、前記第2メッセージ及び前記第4メッセージに基づいて、前記第1隠れベクトル及び前記第2隠れベクトルをマージする、
    請求項1乃至請求項4のいずれかに記載の訓練装置。
  6. 前記メッセージング部は、互いに接続されている前記第1ノード間を接続するエッジの種類に基づいて、前記第1メッセージを生成する、請求項5に記載の訓練装置。
  7. 前記メッセージング部は、互いに接続されている前記第1ノードについてゲート関数を用いて前記第1メッセージを生成する、請求項5に記載の訓練装置。
  8. 前記メッセージング部は、自ノードを含む互いに接続されている前記第1ノードについてパラメータ行列を乗じて前記第1メッセージを生成する、請求項5に記載の訓練装置。
  9. 前の層で更新された前記第1隠れベクトル、及び、前記マージ結果に基づいて、前記第1隠れベクトルを更新し、
    前の層で更新された前記第2隠れベクトル、及び、前記マージ結果に基づいて、前記第2隠れベクトルを更新する、
    リカレント部を備える請求項1乃至請求項8のいずれかに記載の訓練装置。
  10. 前記リカレント部は、ゲートとして動作する、請求項9に記載の訓練装置。
  11. 前記第1ノードの特徴量から前記第1隠れベクトルを算出し、前記グラフから前記第1ノード間の接続情報を抽出し、前記第2ノードの特徴量から前記第2隠れベクトルを算出する、前処理部、を備える請求項1乃至請求項10のいずれかに記載の訓練装置。
  12. 前記前処理部は、前記グラフに関する観測情報を前記第2ノードの特徴量として抽出して前記第2隠れベクトルを初期化する、請求項11に記載の訓練装置。
  13. 前記前処理部は、前記グラフを構成する前記第1ノードの数、前記第1ノードの種類数、前記第1ノードを相互に接続するエッジの種類数、前記グラフの直径のうち、少なくとも1つを前記第2ノードの特徴量として抽出する、請求項12に記載の訓練装置。
  14. グラフのデータを入力すると、前記グラフの特徴を予測する予測装置であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージし、前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新する、第1ネットワークと、
    前記第1ネットワークが出力した前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、第2ネットワークと、
    を備える予測装置。
  15. 前記第1ノードの特徴量から前記第1隠れベクトルを算出し、前記グラフから前記第1ノード間の接続情報を抽出し、前記第2ノードの特徴量から前記第2隠れベクトルを算出する、前処理部、を備える請求項14に記載の予測装置。
  16. 前記前処理部は、前記グラフに関する観測情報を前記第2ノードの特徴量として抽出して前記第2隠れベクトルを初期化する、請求項15に記載の予測装置。
  17. 前記前処理部は、前記グラフを構成する前記第1ノードの数、前記第1ノードの種類数、前記第1ノードを相互に接続するエッジの種類数、前記グラフの直径のうち、少なくとも1つを前記第2ノードの特徴量として抽出する、請求項16記載の予測装置。
  18. グラフのデータを入力すると、前記グラフの特徴を予測するネットワークを訓練する訓練方法であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージし、
    前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新し、
    更新された前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出し、
    抽出された前記グラフの特徴の損失を算出し、
    前記損失に基づいて、少なくとも前記ネットワークの一部を更新する、
    訓練方法。
  19. コンピュータに、
    グラフのデータを入力すると、前記グラフの特徴を予測するネットワークを訓練する手段であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージする、マージ手段、
    前記マージ手段を各層に備え、前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新する、第1ネットワーク、
    前記第1ネットワークが出力した前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、第2ネットワーク、
    前記第2ネットワークが出力した前記グラフの特徴の損失を算出する、算出手段、
    前記損失に基づいて、少なくとも前記第1ネットワーク及び前記第2ネットワークのいずれか一方を更新する、ネットワーク更新手段、
    として機能させるプログラム。
  20. グラフのデータを入力すると、前記グラフの特徴を予測する予測方法であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージし、
    前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新し、
    更新された前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、
    予測方法。
  21. コンピュータに、
    グラフのデータを入力すると、前記グラフの特徴を予測する手段であって、
    前記グラフを構成する第1ノードの第1隠れベクトル、前記第1ノード間の接続情報、及び、前記第1ノードのそれぞれと接続される第2ノードの第2隠れベクトルに基づいて、前記第1隠れベクトルと前記第2隠れベクトルとをマージする、マージ手段、
    前記マージ手段を各層に備え、前記マージ結果に基づいて、前記第1隠れベクトル及び前記第2隠れベクトルを更新する、第1ネットワーク、
    前記第1ネットワークが出力した前記第1隠れベクトル及び前記第2隠れベクトルに基づいて、前記グラフの特徴を抽出する、第2ネットワーク、
    として機能させるプログラム。
JP2018227477A 2018-10-19 2018-12-04 訓練装置、訓練方法、予測装置、予測方法及びプログラム Pending JP2020068000A (ja)

Priority Applications (1)

Application Number Priority Date Filing Date Title
US16/657,389 US20200125958A1 (en) 2018-10-19 2019-10-18 Training apparatus, training method, inference apparatus, inference method, and non-transitory computer readable medium

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
JP2018197712 2018-10-19
JP2018197712 2018-10-19

Publications (1)

Publication Number Publication Date
JP2020068000A true JP2020068000A (ja) 2020-04-30

Family

ID=70388524

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2018227477A Pending JP2020068000A (ja) 2018-10-19 2018-12-04 訓練装置、訓練方法、予測装置、予測方法及びプログラム

Country Status (1)

Country Link
JP (1) JP2020068000A (ja)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113627977A (zh) * 2021-07-30 2021-11-09 北京航空航天大学 一种基于异构图的房屋价值预测方法
CN115512460A (zh) * 2022-09-29 2022-12-23 北京交通大学 一种基于图注意力模型的高速列车轴温长时预测方法
CN117407697A (zh) * 2023-12-14 2024-01-16 南昌科晨电力试验研究有限公司 基于自动编码器和注意力机制的图异常检测方法及系统

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113627977A (zh) * 2021-07-30 2021-11-09 北京航空航天大学 一种基于异构图的房屋价值预测方法
CN115512460A (zh) * 2022-09-29 2022-12-23 北京交通大学 一种基于图注意力模型的高速列车轴温长时预测方法
CN115512460B (zh) * 2022-09-29 2024-04-16 北京交通大学 一种基于图注意力模型的高速列车轴温长时预测方法
CN117407697A (zh) * 2023-12-14 2024-01-16 南昌科晨电力试验研究有限公司 基于自动编码器和注意力机制的图异常检测方法及系统
CN117407697B (zh) * 2023-12-14 2024-04-02 南昌科晨电力试验研究有限公司 基于自动编码器和注意力机制的图异常检测方法及系统

Similar Documents

Publication Publication Date Title
Abd Elaziz et al. A Grunwald–Letnikov based Manta ray foraging optimizer for global optimization and image segmentation
Périaux et al. Combining game theory and genetic algorithms with application to DDM-nozzle optimization problems
CN103548027B (zh) 用于实现建筑系统的系统和方法
Pozna et al. New results in modelling derived from Bayesian filtering
Andrianakis et al. History matching of a complex epidemiological model of human immunodeficiency virus transmission by using variance emulation
CN115456159A (zh) 一种数据处理方法和数据处理设备
US20200125958A1 (en) Training apparatus, training method, inference apparatus, inference method, and non-transitory computer readable medium
Rahaman et al. An efficient multilevel thresholding based satellite image segmentation approach using a new adaptive cuckoo search algorithm
JP2020068000A (ja) 訓練装置、訓練方法、予測装置、予測方法及びプログラム
KR102190103B1 (ko) 인공 신경망의 상용화 서비스 제공 방법
Andrianakis et al. Efficient history matching of a high dimensional individual-based HIV transmission model
WO2021054402A1 (ja) 推定装置、訓練装置、推定方法及び訓練方法
CN109711401A (zh) 一种基于Faster Rcnn的自然场景图像中的文本检测方法
Lagaros et al. Multi-objective design optimization using cascade evolutionary computations
Hofmeyer et al. Automated design studies: topology versus one-step evolutionary structural optimisation
Bodini et al. Underdetection in a stochastic SIR model for the analysis of the COVID-19 Italian epidemic
Xing et al. Elite levy spreading differential evolution via ABC shrink-wrap for multi-threshold segmentation of breast cancer images
JP6819758B1 (ja) 点群データ同一性推定装置及び点群データ同一性推定システム
Estep et al. Fast and reliable methods for determining the evolution of uncertain parameters in differential equations
Gorelova et al. Strategy of complex systems development based on the synthesis of foresight and cognitive modelling methodologies
Radev et al. Bayesflow: Amortized bayesian workflows with neural networks
Patel et al. Smart adaptive mesh refinement with NEMoSys
Weatherill et al. Capturing Directivity in Probabilistic Seismic Hazard Analysis for New Zealand: Challenges, Implications, and a Machine Learning Approach for Implementation
Guo et al. Hybrid iterative reconstruction method for imaging problems in ECT
Yilmaz Artificial neural networks pruning approach for geodetic velocity field determination

Legal Events

Date Code Title Description
A80 Written request to apply exceptions to lack of novelty of invention

Free format text: JAPANESE INTERMEDIATE CODE: A80

Effective date: 20181217

AA64 Notification of invalidation of claim of internal priority (with term)

Free format text: JAPANESE INTERMEDIATE CODE: A241764

Effective date: 20190129

A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20190201