Q, K, V行列
106日前原文(arpitbhayani.me)
概要
- LLMのアテンション機構の中心はQuery、Key、Valueの3つの行列
- Q、K、V行列の役割と構築方法を具体例とともに解説
- TransformerがRNNと異なる点やアテンションの直感的理解
- Q、K、V行列の重み行列の意味や次元選択の影響
- 実際のアテンション計算の流れとQKV行列の役割を整理
LLMのアテンション機構とQ、K、V行列の基礎
- Query、Key、Value行列は、Transformerが入力系列内の異なる単語間に注意を向けるための基盤
- 人間が「The cat sat on the mat because it was comfortable.」の“it”が“the mat”を指すと理解するのと同様の文脈把握機構
- 従来のRNNはトークンを一つずつ処理し、隠れ状態で情報を伝搬
- 例:「The」→h1、「cat」→h2(h1とcatの情報)、…と逐次処理
- Transformerはアテンションにより、全単語を同時に処理し、各単語が他の単語すべてを直接参照可能
- 例:「sat」は「cat」に60%の強い注意、「on」に15%の注意など、並列的な関係把握
- この仕組みにより、学習効率向上と遠距離依存関係の捕捉が可能
アテンションの直感とQ、K、Vの意味
- アテンションはデータベース検索に例えられる
- Query(Q): 何を探すか
- Key(K): 何を持っているか
- Value(V): 実際に保持する情報
- 各入力位置ごとに「何に注意すべきか?」というクエリを作成し、全てのキーと比較、最適な値を取得
アテンション処理の全体フロー
- 入力 → 線形変換 → Q、K、V生成 → アテンションスコア計算 → Softmax → 重み付き和 → 出力
- これにより、各単語が系列内の重要な単語に動的に注意を向ける
Q、K、V行列の具体的な構築例
-
例文:「Cat eats fish」
- 各単語を4次元ベクトルで表現(実際は埋め込みベクトルを利用)
- cat = [1.0, 0.0, 0.5, 0.2] など
- 3単語をまとめて入力行列X(3, 4)を作成
- 各単語を4次元ベクトルで表現(実際は埋め込みベクトルを利用)
-
Q、K、V行列の生成には重み行列Wq、Wk、Wv(各4×3)を使用
- これらは学習により最適化されるパラメータ
- Q = X @ Wq、K = X @ Wk、V = X @ Wv(各3×3行列)
-
各行の意味
- Qの各行:各単語が「何に注意すべきか?」というクエリ
- Kの各行:各単語が「どんな情報を持っているか?」というキー
- Vの各行:各単語が「実際に持つ情報」バリュー
コード例(擬似コード)
- 入力ベクトルと重み行列からQ、K、Vを生成する関数
- 入力:input_embeddings(seq_len, d_model)、d_k(出力次元)、seed
- 出力:Q、K、V、重み行列
- 例:「Cat eats fish」の3単語を使い、Q、K、Vを計算
Q、K、Vの重み行列を分ける理由
- 各重み行列の役割が異なるため
- Wq:クエリ(質問)生成
- Wk:キー(検索インデックス)生成
- Wv:バリュー(実際の内容)生成
- もし同じ重みを使うと、機能的な分離が損なわれ、最適な学習ができない
- 検索エンジンの「検索ワード」「インデックス」「本文」の違いに相当
投影次元(d_k)の影響
- d_kが小さい場合
- 計算・メモリ効率が高い
- 複雑な関係性の把握が難しい
- シンプルなタスクやマルチヘッドアテンションの一部として有用
- d_kが大きい場合
- 複雑な関係性も捉えやすい
- パラメータ数・計算コスト増加
- 実運用モデル(BERT等)ではd_k=64×12ヘッド=768次元などを採用
アテンション計算におけるQ、K、Vの役割
- アテンションスコアの計算
- attention_scores = Q @ K^T / sqrt(d_k)
- 各行:単語iが単語jにどれだけ注意を向けるかを示す
- スコアをsoftmaxで正規化→Vの重み付き和を計算→出力ベクトルへ
- Q、K、Vはアテンション処理の最初の一歩
- 以降、スコア計算、softmax、重み付き和、出力投影と続く
まとめ:Q、K、V行列の意義
- Query、Key、Value行列はTransformerの文脈理解の核
- 入力埋め込みを3種の重みで別々に射影し、検索・被検索・内容伝達の役割を分担
- この設計により、モデルは入力内の関連部分に動的に注意を向け、高度な言語理解を実現