Abstract Syntax Networks for Code Generation and Semantic Parsing
July 1, 2023Abstract Syntax Networksは、非構造的な文章などの入力から、抽象構文木(AST)にしたがう系列を生成できるencoder decoderである。 decoderは、ASTの生成規則にある記号に対応するモジュールを再帰的に構成したネットワークである。 モジュールは、右辺のどの生成規則を選択すべきか推定する。 そして、選択した規則のモジュールをさらに再帰的に選択することで、ASTにしたがう出力を生成する。
Abstract Syntax Networksは、複数の系列を入力に受けとれる。 実験では、カードゲームHEARTSTONEのカードに書かれた名前と説明文から、カードの効果を実装したコードを生成した。 系列を個別の双方向LSTMに入力し、順方向と逆方向の隠れ状態を連結したベクトルを系列ごとに用意する。
生成規則の右辺が非終端記号の選択である場合、モジュールは、そのうちの一つ\(\text{C}\)を推定する。 系列のベクトルの連結\(\textbf{v}\)と順伝搬ネットワーク\(f_\text{T}\)に与えることで、以下のように生成規則を求める。 $$ p({\rm C}| \text{T}, \textbf{v}) = [\text{softmax}(f_\text{T}(\textbf{v}))]_{\text{C}} $$
生成規則の右辺が非終端記号列であれば、\(\textbf{v}\)と注意\(\textbf{c}\)から、系列にある各記号の中間状態を計算する。 入力系列中のトークン\(t\)の埋め込みベクトルを\(\textbf{e}^\top_t\), 入力系列の種類\(g\)の重みを\(\textbf{w}_g\)とすると、注意\(\textbf{c}\)を次のように計算する。 $$ \begin{align} q_t&=\textbf{e}^\top_t\textbf{W}\textbf{v}+\textbf{w}^\top_c{\rm v}\\ \textbf{a}&=\text{softmax}(\textbf{q})\\ \textbf{c}&=\sum_ta_t\textbf{e}_t \end{align} $$ 順伝搬ネットワーク\(f_\text{C}\)と生成規則の右辺にある記号\(\text{F}\)の埋め込みベクトル\(\textbf{e}_\text{F}\)から、ノード\(u\)における\(\text{F}\)の隠れ状態\(\textbf{v}_{u,\text{F}}\)を計算する。 ノード\(u\)から再帰的に呼び出すモジュールには、\(\textbf{v}\)にかわって\(\textbf{v}_{u,\text{F}}\)を渡す。 \(\text{LSTM}(\textbf{h}, \textbf{x})\)は、隠れ状態に\(\textbf{h}\), 入力に\(\textbf{x}\)を与えることを意味する。 $$ \textbf{v}_{u, \text{F}}=\text{LSTM}(\textbf{v}, f_{\text{C}}(\textbf{e}_\text{F}, \textbf{c})) $$
記号が終端記号\(\text{T}\)であれば、記号の隠れ状態\(\textbf{v}_{u, \text{F}}\)を順伝搬ネットワーク\(f\)に入力し、出力\(y\)を決定する。 $$ p(y|\text{T}, \textbf{v}_{u,\text{F}})=[\text{softmax}(f(\textbf{v}_{u,\text{F}}))]_y $$
記号が別の記号オプションであれば、順伝搬ネットワーク\(f^{\text{gen}}_{\text{F}}\)で記号を生成すべきか推定する。 $$ p(z_{\text{F}}=1|\textbf{v}_{u,\text{F}})=\text{sigmoid}(f^{\text{gen}}_{\text{F}}(\textbf{v}_{u,\text{F}})) $$
雑記
学習をはじめたばかりでは、再帰しつづけるような生成規則の展開が起きてしまわないだろうか。