Generating Long Sequences with Sparse Transformers(2019)
April 15, 2023TransformerのQKV注意機構に入力するベクトルを限定し、長さ\(n\)の系列をQKV注意機構に入力したときの空間計算量を\(\mathcal{O}(n\sqrt{n})\)まで減らした研究である。 Transformerであれば、系列の要素は、要素自体の位置と以前の要素すべてを注意し、時間と空間計算量は\(\mathcal{O}(n^2)\)になる。 Sparse Transformerは、\(p\)個のパターンを用意し、パターンに該当する要素のみを各注意機構に入力し、\(p\)個の注意を生成する。 そして、\(p\)個の注意を合成し、1つの注意に変換する。 パターンは、画像やテキストなど、入力するデータの種類によって決めておく規則であり、たとえば、直近にある一定数の要素や等間隔に離れた要素を指定するパターンがありえある。 パターンが\(p\)であれば、計算量は\(\mathcal{O}\sqrt[p]{n}\)になる。実験の設定は\(p=2\)である。
パターンに該当する要素のみを注意機構に入力する以外は、Transformerの注意機構とおなじである。 パターン\(m\)が\(i\)番目位置にある要素が注意する位置の集合を\(A^{(m)}_{i}\)とする。 このとき、パターン\(A^{(m)}\)の注意機構が再構成するクエリ\(\text{Attend}(X,A^{(m)})=\left(a({\rm x}_i,A^{(m)}_i)\right)_{i\in\{1,\dots,n\}}\)は、
$$ \begin{align} a({\rm x}_i,A^{(m)}_i)&=\text{softmax} \left(\frac{(W_q{\rm x}_i)K^T_{A^{(m)}_i}}{\sqrt{d}}\right)V_{A^{(m)}_i}\\ K_{A^{(m)}_i}&=\left(W_k{\rm x}_j\right)_{j\in A^{(m)}_i}\\ V_{A^{(m)}_i}&=\left(W_v{\rm x}_j\right)_{j\in A^{(m)}_i} \end{align} $$ である。\(W_q, W_k,W_v\)は\({\rm x}_i\)をクエリ、キー、バリューに変換する重みであり、\(d\)はクエリとキーの次元である。 実験では、パターンの数は2である。 パターンを形式的に定義するなら、たとえば、\(A_i^{(1)}=\{t,t+1,\dots,i\}\ \text{for}\ t = \max(0, i-l) \), \(A^{(2)}_i=\{j:(i-j)\ \mod l=0\}\)と示すことができる。
注意機構がパターン数\(p\)個あるマルチヘッド注意機構であるから、個別の注意機構で生成された注意を合成する。 2つの合成方法を例示しており、そのうちの一つは、重みを\(W_p\)として $$ \text{attention}(X)=W_p\cdot \text{attend}\left(X,\bigcup^p_{m=1}A^{(m)}\right) $$ を計算し、注意を生成する。
雑記
sparse factorizationをみて行列分解を想像したが、factorizationは単にQKV注意機構へ入力する要素を限定することだけを意味しているように読めた。 \(|A^{(m)}_i|\propto\sqrt[p]{n}\)になるようにパターンを選ぶことで、空間計算量を\(\mathcal{O}n\sqrt[p]{n}\)にできるとあるが、\(n\sqrt[p]{n}\)を目標にする必然性が読みとれない。