KV Cache๋ Transformer ์ถ๋ก ์ ์ด์ ํ ํฐ์ KeyยทValue ํ ์๋ฅผ ์ ์ฅยท์ฌ์ฌ์ฉํ๋ ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ ๊ธฐ๋ฒ
- Self-Attention์ ์ค๋ณต ์ฐ์ฐ์ ์ ๊ฑฐํ์ฌ ์ฐ์ฐ๋์ O(nยฒ)์์ O(n)์ผ๋ก ์ค์ด๋ ์บ์ฑ ๋ฉ์ปค๋์ฆ
- ๊ฐ Transformer ๋ธ๋ก๋ง๋ค ๋ ๋ฆฝ์ ์ผ๋ก ์ ์ง๋๋ ๋ ์ด์ด๋ณ Key-Value ์ ์ฅ์
ํด๋น ๊ฐ๋ ์ด ํ์ํ ์ด์
- Autoregressive ๋ชจ๋ธ์ ํ ํฐ์ ํ๋์ฉ ์์ฑํ ๋๋ง๋ค ์ด์ ๋ชจ๋ ํ ํฐ์ ๋ํด Self-Attention์ ์ฌ๊ณ์ฐ
- KV Cache ์์ด๋ ์ํ์ค ๊ธธ์ด๊ฐ ๊ธธ์ด์ง์๋ก ์ฐ์ฐ๋์ด ๊ธฐํ๊ธ์์ ์ผ๋ก ์ฆ๊ฐ
- KV Cache๋ฅผ ์ฌ์ฉํ๋ฉด ๋์ผ ํ๋์จ์ด์์ ์ฝ 5๋ฐฐ ๋น ๋ฅธ ์ถ๋ก ๊ฐ๋ฅ (T4 GPU ๋ฒค์น๋งํฌ ๊ธฐ์ค)
Tensor๋?
๋ฐ์ดํฐ๋ฅผ ๋ด๋ ๋ค์ฐจ์ ์ปจํ ์ด๋. ์ฐจ์ ์์ ๋ฐ๋ผ ์ด๋ฆ์ด ๋ฌ๋ผ์ง๋ค:
| ์ฐจ์ | ์ด๋ฆ | ์์ |
|---|---|---|
| 0์ฐจ์ | Scalar | 42 โ ์ซ์ ํ๋ |
| 1์ฐจ์ | Vector | [1, 2, 3] โ ์ซ์์ ๋์ด |
| 2์ฐจ์ | Matrix | ํ๊ณผ ์ด๋ก ๊ตฌ์ฑ๋ ํ (์์ ์ํธ) |
| 3์ฐจ์+ | Tensor | ํ๋ ฌ์ ์ฌ๋ฌ ์ฅ ์์ ๊ฒ |
Q) โKV Cache ๋ฌธ์์์ ๋งํ๋ ํ ์๊ฐ ๋ญ๊ฐ์?โ
- KV Cache์์ Key, Value ํ
์๋ 4์ฐจ์ ํ
์:
[batch_size, num_heads, seq_len, head_dim] - ์ง๊ด์ ์ผ๋ก โ โ๋ฐฐ์น ร ์ดํ ์ ํค๋ ์ ร ์ํ์ค ๊ธธ์ด ร ํค๋ ์ฐจ์โ์ 4๊ฒน ์ซ์ ๋ฉ์ด๋ฆฌ
- ์:
[1, 32, 512, 128]โ ๋ฐฐ์น 1๊ฐ, ํค๋ 32๊ฐ, ํ ํฐ 512๊ฐ, ๊ฐ ํค๋ ํฌ๊ธฐ 128
Self-Attention์ด๋?
๋ฌธ์ฅ ์ ๊ฐ ๋จ์ด๊ฐ ๋ค๋ฅธ ๋ชจ๋ ๋จ์ด์์ ๊ด๊ณ๋ฅผ ๊ณ์ฐํ์ฌ ๋ฌธ๋งฅ์ ํ์ ํ๋ ๋ฉ์ปค๋์ฆ.
๋์๊ด ๋น์ :
- Query (Q): โ๋๋ ์ด๋ฐ ์ ๋ณด๋ฅผ ์ฐพ๊ณ ์์ดโ โ ํ์ฌ ๋จ์ด๊ฐ ์ํ๋ ๊ฒ
- Key (K): โ๋๋ ์ด๋ฐ ์ ๋ณด๋ฅผ ๊ฐ๊ณ ์์ดโ โ ๊ฐ ๋จ์ด๊ฐ ์ ๊ณตํ ์ ์๋ ๊ฒ
- Value (V): โ๋ด ์ค์ ๋ด์ฉ์ ์ด๊ฑฐ์ผโ โ ๊ฐ ๋จ์ด์ ์ค์ ์ ๋ณด
sequenceDiagram autonumber participant Token as ์ ๋ ฅ ํ ํฐ participant QKV as Q, K, V ์์ฑ participant Score as ์ ์ฌ๋ ๊ณ์ฐ participant Soft as Softmax participant Out as ์ถ๋ ฅ Token->>QKV: ๊ฐ ํ ํฐ โ Q, K, V ๋ฒกํฐ ๋ณํ Note over QKV: ์ ๋ ฅ ร ๊ฐ์ค์น ํ๋ ฌ 3๊ฐ<br/>(Wq, Wk, Wv) QKV->>Score: Q ร Kแต (๋ด์ ) Note over Score: "๋(Q)์ ๋น์ทํ<br/>์ ๋ณด(K)๋ฅผ ๊ฐ์ง ๋จ์ด๋?" Score->>Soft: ์ ์ โ ํ๋ฅ ๋ถํฌ ๋ณํ Note over Soft: ํฉ์ด 1์ด ๋๋๋ก ์ ๊ทํ Soft->>Out: ํ๋ฅ ร V = ๊ฐ์ค ํ๊ท Note over Out: ๊ด๋ จ๋ ๋์ ๋จ์ด์<br/>์ ๋ณด(V)๋ฅผ ๋ ๋ง์ด ๋ฐ์
ํต์ฌ ์์: Attention(Q, K, V) = softmax(Q ร Kแต / โd) ร V
KV Cache์์ ์ฐ๊ฒฐ: Self-Attention์์ Q๋ ๋งค๋ฒ ์๋ก ๊ณ์ฐํ์ง๋ง, K์ V๋ ์ด์ ํ ํฐ์ ๊ฒ์ ์ฌํ์ฉํ ์ ์๋ค. ์ด๊ฒ์ด ๋ฐ๋ก KV Cache์ ์๋ฆฌ.
AS-IS
sequenceDiagram autonumber participant Input as ์ ๋ ฅ ์ํ์ค participant Attn as Self-Attention participant Out as ์ถ๋ ฅ Note over Input: "The red cat was" Input->>Attn: Token 1~4 โ Q, K, V ์ ์ฒด ๊ณ์ฐ Attn->>Out: "sitting" ์์ฑ Note over Input: "The red cat was sitting" Input->>Attn: Token 1~5 โ Q, K, V ์ ์ฒด ์ฌ๊ณ์ฐ Note over Attn: Token 1~4์ K,V๋ฅผ<br/>์ฒ์๋ถํฐ ๋ค์ ๊ณ์ฐ (๋ญ๋น) Attn->>Out: "on" ์์ฑ Note over Input: "The red cat was sitting on" Input->>Attn: Token 1~6 โ Q, K, V ์ ์ฒด ์ฌ๊ณ์ฐ Note over Attn: ๋งค ์คํ ๋ง๋ค<br/>์ ์ฒด ์ฌ๊ณ์ฐ ๋ฐ๋ณต Attn->>Out: "the" ์์ฑ
TO-BE
sequenceDiagram autonumber participant Input as ์ ํ ํฐ participant Cache as KV Cache participant Attn as Self-Attention participant Out as ์ถ๋ ฅ Note over Cache: Prefill Phase Input->>Cache: Token 1~4 โ Kโโโ, Vโโโ ์ ์ฅ Input->>Attn: Qโโโ + Kโโโ + Vโโโ Attn->>Out: "sitting" ์์ฑ Note over Cache: Decode Phase Input->>Cache: Token 5 โ Kโ , Vโ ์ถ๊ฐ ์ ์ฅ Note over Cache: Cache: [Kโ..Kโ ], [Vโ..Vโ ] Input->>Attn: Qโ + Cache(Kโโโ , Vโโโ ) Note over Attn: ์ ํ ํฐ์ Q๋ง ๊ณ์ฐ Attn->>Out: "on" ์์ฑ Input->>Cache: Token 6 โ Kโ, Vโ ์ถ๊ฐ ์ ์ฅ Input->>Attn: Qโ + Cache(Kโโโ, Vโโโ) Attn->>Out: "the" ์์ฑ
Prefill๊ณผ Decode โ ๋ ๋จ๊ณ์ ์ถ๋ก (์์ธ)
Prefill Phase (์ฑ์ฐ๊ธฐ ๋จ๊ณ)
์ฌ์ฉ์๊ฐ ์ ๋ ฅํ ํ๋กฌํํธ ์ ์ฒด๋ฅผ ํ ๋ฒ์ ๋ณ๋ ฌ ์ฒ๋ฆฌํ๋ ๋จ๊ณ.
์
๋ ฅ: "์ค๋ ๋ ์จ ์ด๋?" (5ํ ํฐ)
โ 5๊ฐ ํ ํฐ์ GPU์์ ๋์์ ์ฒ๋ฆฌ
โ 5๊ฐ ํ ํฐ ๊ฐ๊ฐ์ K, V๋ฅผ ํ๊บผ๋ฒ์ Cache์ ์ ์ฅ
โ ์ฒซ ๋ฒ์งธ ์ถ๋ ฅ ํ ํฐ ์์ฑ
- ํน์ง: GPU ์ฐ์ฐ(compute) ์ง์ฝ์ โ ํ๋ ฌ๊ณฑ์ด ๋ง์ GPU ์ฝ์ด๋ฅผ ์ต๋ํ ํ์ฉ
- ๋ณ๋ชฉ: ์ฐ์ฐ๋ (FLOPS)
- ๋น์ : ๋์๊ด์์ ์ฐธ๊ณ ๋์ 10๊ถ์ ํ๊บผ๋ฒ์ ์ฑ ์ ์์ ํผ์ณ๋๋ ๊ฒ
Decode Phase (์์ฑ ๋จ๊ณ)
ํ ํฐ์ ํ๋์ฉ ์์ฐจ์ ์ผ๋ก ์์ฑํ๋ ๋จ๊ณ.
1๋จ๊ณ: Q("๋ง") ร Cache(Kโโโ
, Vโโโ
) โ "๊ณ " ์์ฑ, KโVโ ์บ์ ์ถ๊ฐ
2๋จ๊ณ: Q("๊ณ ") ร Cache(Kโโโ, Vโโโ) โ "ํ์ฐฝ" ์์ฑ, KโVโ ์บ์ ์ถ๊ฐ
3๋จ๊ณ: Q("์ฐฝ") ร Cache(Kโโโ, Vโโโ) โ "ํฉ๋๋ค" ์์ฑ, KโVโ ์บ์ ์ถ๊ฐ
- ํน์ง: ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ(memory bandwidth) ์์กด์ โ ์บ์์์ K,V๋ฅผ ์ฝ์ด์ค๋ ์๋๊ฐ ๋ณ๋ชฉ
- ๋ณ๋ชฉ: ๋ฉ๋ชจ๋ฆฌ ์ฝ๊ธฐ ์๋ (GB/s)
- ๋น์ : ํผ์ณ๋์ ์ฐธ๊ณ ๋์(Cache)์์ ํ ์ค์ฉ ์ฝ์ผ๋ฉฐ ๋ต์์ ์ ๋ ๊ฒ
๋ ๋จ๊ณ์ ํต์ฌ ์ฐจ์ด
| Prefill | Decode | |
|---|---|---|
| ์ฒ๋ฆฌ ๋ฐฉ์ | ๋ณ๋ ฌ (์ ์ฒด ํ ํฐ ๋์) | ์์ฐจ (ํ ํฐ 1๊ฐ์ฉ) |
| ๋ณ๋ชฉ | GPU ์ฐ์ฐ๋ (compute-bound) | ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ (memory-bound) |
| KV Cache | ์บ์ ์์ฑ | ์บ์ ์ฝ๊ธฐ + ์ถ๊ฐ |
| ์คํ ํ์ | 1๋ฒ | ์ถ๋ ฅ ํ ํฐ ์๋งํผ ๋ฐ๋ณต |
๋ฉ๋ชจ๋ฆฌ ๋น์ฉ
๊ฐ Transformer ๋ธ๋ก์์ ํ ํฐ๋น ์ ์ฅํ๋ ํ ์:
- K ํ
์:
[batch_size, num_heads, seq_len, head_dim] - V ํ
์:
[batch_size, num_heads, seq_len, head_dim]
์์: LLaMA-2 13B ๊ธฐ์ค ํ ํฐ๋น ~1MB, 4K ์ปจํ ์คํธ์์ ์ํ์ค๋น ~4GB
์ฑ๋ฅ ๋น๊ต (T4 GPU ๋ฒค์น๋งํฌ)
| ๋ฐฉ์ | ์์ ์๊ฐ |
|---|---|
| KV Cache ๋ฏธ์ฌ์ฉ | 1๋ถ 1์ด |
| KV Cache ์ฌ์ฉ | 11.7์ด |
| ์๋ ํฅ์ | ~5.2๋ฐฐ |