KV Cache is an inference optimization technique that stores and reuses the Key and Value tensors of previous tokens during Transformer inference. It spends extra memory to save computation.

  • A caching mechanism that eliminates redundant Self-Attention computation, cutting the attention work per generated token from O(n²) to O(n) (n = current sequence length)
  • A per-layer Key-Value store, maintained independently in each Transformer block

Why this concept is needed

  • An autoregressive model generates one token at a time, and without caching it recomputes Self-Attention over all previous tokens at every step
  • Without KV Cache, the per-step computation grows quadratically with sequence length (polynomially, not exponentially), so long sequences quickly become expensive
  • With KV Cache, inference ran about 5× faster on the same hardware (T4 GPU benchmark)
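A minimal sketch of the cost difference described above. It counts only two kinds of work in a single attention layer (K/V projections and Q·K dot products) and assumes the uncached model rebuilds the full t × t score matrix at every step. The function name `step_cost` is hypothetical.

```python
def step_cost(t: int, use_cache: bool) -> tuple[int, int]:
    """Work for one decode step when the context holds t tokens.

    Returns (kv_projections, qk_dot_products) for a single attention layer.
    """
    if use_cache:
        # Only the newest token is projected; its one query is scored against t cached keys.
        return 1, t
    # No cache: all t tokens are re-projected and the full t x t score matrix is rebuilt.
    return t, t * t


for t in (4, 64, 1024):
    print(t, "no cache:", step_cost(t, False), "with cache:", step_cost(t, True))
```

The dot-product count grows as t² without the cache and as t with it, which is the per-step O(n²) → O(n) reduction.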

What is a Tensor?

A multidimensional container for data. It goes by different names depending on the number of dimensions:

| Dimensions | Name | Example |
|---|---|---|
| 0-D | Scalar | 42, a single number |
| 1-D | Vector | [1, 2, 3], a list of numbers |
| 2-D | Matrix | a grid of rows and columns (like a spreadsheet) |
| 3-D+ | Tensor | several matrices stacked together |

Q) “What is the tensor that the KV Cache documentation talks about?”

  • In KV Cache, the Key and Value tensors are 4-D tensors: [batch_size, num_heads, seq_len, head_dim]
  • Intuitively: a four-level block of numbers, “batch × number of attention heads × sequence length × head dimension”
  • Example: [1, 32, 512, 128] → batch of 1, 32 heads, 512 tokens, head size 128
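A quick way to see what the example shape means in memory, assuming fp16/bf16 storage (2 bytes per element; the dtype is not stated in the text):

```python
import math

# One K (or V) tensor for one layer: [batch_size, num_heads, seq_len, head_dim]
shape = (1, 32, 512, 128)
num_elements = math.prod(shape)      # 1 * 32 * 512 * 128
bytes_fp16 = num_elements * 2        # assumption: fp16/bf16, 2 bytes per element

print(num_elements)                  # 2097152
print(bytes_fp16 / 2**20, "MiB")     # 4.0 MiB per tensor, per layer
```

K and V together double this figure, and the total is multiplied again by the number of layers.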

What is Self-Attention?

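A mechanism in which each word in a sentence computes its relationship to the other words in order to capture context. (In an autoregressive decoder, a causal mask limits each word to itself and the words before it.)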

Library analogy:

  • Query (Q): what the current word is looking for (“I'm looking for this kind of information”)
  • Key (K): what each word can offer (“I have this kind of information”)
  • Value (V): each word's actual information (“Here is my actual content”)
```mermaid
sequenceDiagram
    autonumber
    participant Token as Input token
    participant QKV as Q, K, V projection
    participant Score as Similarity scoring
    participant Soft as Softmax
    participant Out as Output

    Token->>QKV: Each token → Q, K, V vectors
    Note over QKV: Input × 3 weight matrices<br/>(Wq, Wk, Wv)
    QKV->>Score: Q × Kᵀ / √d (scaled dot product)
    Note over Score: "Which words hold information (K)<br/>similar to what I (Q) am looking for?"
    Score->>Soft: Scores → probability distribution
    Note over Soft: Normalized so the weights sum to 1
    Soft->>Out: Probabilities × V = weighted average
    Note over Out: Information (V) from more relevant<br/>words contributes more
```

Key formula: Attention(Q, K, V) = softmax(Q × Kᵀ / √d) × V, where d is the per-head Key dimension (head_dim).
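The formula above can be sketched in plain Python for a single query vector. This is a minimal illustration with hypothetical helper names (`softmax`, `attention`), not an optimized or batched implementation:

```python
import math


def softmax(scores):
    """Turn raw scores into weights that sum to 1 (max subtracted for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, keys, values):
    """softmax(q · Kᵀ / √d) × V for a single query vector q of length d.

    keys and values are lists with one vector per token.
    Returns (output_vector, attention_weights).
    """
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    # Weighted average of the value vectors, one output dimension at a time.
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return out, weights


# The query matches the first key more closely, so the first value gets more weight.
out, weights = attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[10.0, 0.0], [0.0, 10.0]])
print(weights, out)
```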

Connection to KV Cache: in Self-Attention, Q must be computed for each new token, but under the causal mask earlier tokens never see later ones, so their K and V do not change and can be reused. That is the principle behind KV Cache.
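A small numerical check of this idea, assuming a single attention layer with random hypothetical weights `Wq`, `Wk`, `Wv` and made-up token embeddings. It computes causal attention for every position from scratch, then computes only the newest token's output from cached K and V, and compares the last position:

```python
import math
import random


def matvec(W, x):
    """Multiply a d_out x d_in matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention for one query over the given keys/values."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]


random.seed(0)
d = 4
Wq, Wk, Wv = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
tokens = [[random.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # hypothetical embeddings

# Without a cache: project every token again and compute causal attention for every
# position (position i may only attend to positions 0..i).
K_all = [matvec(Wk, x) for x in tokens]
V_all = [matvec(Wv, x) for x in tokens]
outputs = [attend(matvec(Wq, tokens[i]), K_all[: i + 1], V_all[: i + 1])
           for i in range(len(tokens))]
full = outputs[-1]  # only the last position is needed to predict the next token

# With a cache: K and V of tokens 1..5 were stored on earlier steps ...
k_cache = [matvec(Wk, x) for x in tokens[:-1]]
v_cache = [matvec(Wv, x) for x in tokens[:-1]]
# ... so this step projects only the new token and appends its K and V.
new = tokens[-1]
k_cache.append(matvec(Wk, new))
v_cache.append(matvec(Wv, new))
cached = attend(matvec(Wq, new), k_cache, v_cache)

print(all(abs(a - b) < 1e-12 for a, b in zip(full, cached)))  # True
```

In a multi-layer model the same argument applies layer by layer, since each layer's cached K and V for earlier positions also never depend on later tokens.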

AS-IS

```mermaid
sequenceDiagram
    autonumber
    participant Input as Input sequence
    participant Attn as Self-Attention
    participant Out as Output

    Note over Input: "The red cat was"
    Input->>Attn: Tokens 1~4 → compute all Q, K, V
    Attn->>Out: Generate "sitting"

    Note over Input: "The red cat was sitting"
    Input->>Attn: Tokens 1~5 → recompute all Q, K, V
    Note over Attn: K, V of tokens 1~4<br/>recomputed from scratch (wasted)
    Attn->>Out: Generate "on"

    Note over Input: "The red cat was sitting on"
    Input->>Attn: Tokens 1~6 → recompute all Q, K, V
    Note over Attn: Full recomputation<br/>repeated at every step
    Attn->>Out: Generate "the"
```

TO-BE

```mermaid
sequenceDiagram
    autonumber
    participant Input as Input tokens
    participant Cache as KV Cache
    participant Attn as Self-Attention
    participant Out as Output

    Note over Cache: Prefill Phase
    Input->>Cache: Tokens 1~4 → store K₁₋₄, V₁₋₄
    Input->>Attn: Q₁₋₄ + K₁₋₄ + V₁₋₄
    Attn->>Out: Generate "sitting"

    Note over Cache: Decode Phase
    Input->>Cache: Token 5 → append K₅, V₅
    Note over Cache: Cache: [K₁..K₅], [V₁..V₅]
    Input->>Attn: Q₅ + Cache(K₁₋₅, V₁₋₅)
    Note over Attn: Only the new token's<br/>Q, K, V are computed
    Attn->>Out: Generate "on"

    Input->>Cache: Token 6 → append K₆, V₆
    Input->>Attn: Q₆ + Cache(K₁₋₆, V₁₋₆)
    Attn->>Out: Generate "the"
```
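One way to picture the cache in this diagram is a hypothetical `KVCache` class that keeps separate K and V lists per layer and grows by one position along `seq_len` for every token. Batch size is fixed at 1 and the K/V contents are placeholders; only the shape bookkeeping is shown:

```python
class KVCache:
    """Hypothetical per-layer KV cache (batch size fixed at 1).

    keys[layer][pos] holds one token's K for that layer: num_heads vectors of length head_dim.
    """

    def __init__(self, num_layers: int, num_heads: int, head_dim: int):
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Each layer (Transformer block) keeps its own, independent K and V lists.
        self.keys = [[] for _ in range(num_layers)]
        self.values = [[] for _ in range(num_layers)]

    def append(self, layer: int, k, v) -> None:
        """Add one token's K and V for one layer (grows seq_len by 1)."""
        for t in (k, v):
            if len(t) != self.num_heads or any(len(h) != self.head_dim for h in t):
                raise ValueError("expected num_heads vectors of length head_dim")
        self.keys[layer].append(k)
        self.values[layer].append(v)

    def shape(self, layer: int) -> tuple:
        """Logical shape [batch_size, num_heads, seq_len, head_dim] of this layer's K (or V)."""
        return (1, self.num_heads, len(self.keys[layer]), self.head_dim)


def placeholder(num_heads: int = 2, head_dim: int = 3):
    """Dummy K or V for one token; a real model would compute these with Wk / Wv."""
    return [[0.0] * head_dim for _ in range(num_heads)]


cache = KVCache(num_layers=2, num_heads=2, head_dim=3)
for layer in range(2):          # prefill: tokens 1~4 go in together
    for _ in range(4):
        cache.append(layer, placeholder(), placeholder())
print(cache.shape(0))           # (1, 2, 4, 3)

for token in (5, 6):            # decode: one token per step
    for layer in range(2):
        cache.append(layer, placeholder(), placeholder())
    print(token, cache.shape(0))  # 5 (1, 2, 5, 3), then 6 (1, 2, 6, 3)
```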

Prefill and Decode: the two phases of inference (in detail)

Prefill Phase (the fill-in phase)

The phase that processes the user's entire prompt in parallel, in a single pass.

Input: "오늘 날씨 어때?" ("How's the weather today?", 5 tokens)
→ All 5 tokens are processed on the GPU at the same time
→ The K and V of all 5 tokens are written to the cache in one go
→ The first output token is generated (here, "맑", "clear")
  • Characteristic: compute-intensive; the many matrix multiplications keep the GPU cores fully busy
  • Bottleneck: arithmetic throughput (FLOPS)
  • Analogy: spreading 10 reference books open on a library desk all at once
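A minimal sketch of why prefill is one large parallel operation: all prompt tokens are projected to K and V with a single matrix multiply and written to the cache together. The sizes, random weights, and the `matmul` helper are hypothetical; on a GPU this multiply is the part that runs in parallel:

```python
import random


def matmul(A, B):
    """(n x k) @ (k x m) using plain lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


random.seed(1)
d_model, d_head, prompt_len = 8, 4, 5   # hypothetical sizes; 5 prompt tokens as in the example
X = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(prompt_len)]
Wk = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(d_model)]
Wv = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(d_model)]

# Prefill: one (5 x d_model) @ (d_model x d_head) multiply covers every prompt token at once.
K = matmul(X, Wk)
V = matmul(X, Wv)
k_cache, v_cache = list(K), list(V)     # the cache now holds K1..K5 and V1..V5
print(len(k_cache), len(k_cache[0]))    # 5 4
```

Each row of `K` is exactly what projecting that token alone would give; batching them is what turns prefill into a compute-bound matrix multiply.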

Decode Phase (the generation phase)

The phase that generates tokens sequentially, one at a time.

Step 1: append K₆, V₆ of "맑"; Q("맑") × Cache(K₁₋₆, V₁₋₆) → generate "고"
Step 2: append K₇, V₇ of "고"; Q("고") × Cache(K₁₋₇, V₁₋₇) → generate "화창"
Step 3: append K₈, V₈ of "화창"; Q("화창") × Cache(K₁₋₈, V₁₋₈) → generate "합니다"
Result: "맑고 화창합니다" ("It's clear and sunny.")
  • Characteristic: memory-bandwidth-bound; the speed of reading K and V from the cache (along with the model weights) is the bottleneck
  • Bottleneck: memory read speed (GB/s)
  • Analogy: writing the answer line by line while reading from the books already open on the desk (the cache)
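A toy trace of the three decode steps above, assuming the next tokens are scripted rather than produced by a model and that K and V are just labels. It shows that each step first appends the current token's K and V, then scores the new Q against the entire cache, so the amount of cache read per step keeps growing:

```python
# Hypothetical toy trace: K/V are labels and the next tokens are scripted, not model outputs.
prompt_len = 5                                          # "오늘 날씨 어때?" as 5 tokens
k_cache = [f"K{i}" for i in range(1, prompt_len + 1)]   # K1..K5, written during prefill
v_cache = [f"V{i}" for i in range(1, prompt_len + 1)]   # V1..V5

current = "맑"                     # first output token, produced at the end of prefill
scripted = ["고", "화창", "합니다"]  # stand-ins for what the model would generate
reads = []                         # how many cached K entries each step's Q is scored against

for step, nxt in enumerate(scripted, start=1):
    pos = len(k_cache) + 1
    k_cache.append(f"K{pos}")      # append the current token's K and V first ...
    v_cache.append(f"V{pos}")
    reads.append(len(k_cache))     # ... then its Q attends over the whole cache
    print(f"step {step}: Q({current}) x Cache(K1-{pos}, V1-{pos}) -> {nxt}")
    current = nxt

print(reads)  # [6, 7, 8]
```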

Key differences between the two phases

| | Prefill | Decode |
|---|---|---|
| Processing | Parallel (all tokens at once) | Sequential (one token at a time) |
| Bottleneck | GPU compute (compute-bound) | Memory bandwidth (memory-bound) |
| KV Cache | Builds the cache | Reads the cache and appends to it |
| Runs | Once | Repeated for each output token |

Memory cost

Tensors stored by each Transformer block (each new token adds one position along seq_len):

  • K tensor: [batch_size, num_heads, seq_len, head_dim]
  • V tensor: [batch_size, num_heads, seq_len, head_dim]

Example: LLaMA-2 13B with an fp16 cache needs about 0.8 MB per token (roughly 1 MB), so about 3.4 GB per sequence at a 4K context
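The arithmetic behind this example, assuming the published LLaMA-2 13B configuration (40 layers, 40 attention heads, head_dim 128, no grouped-query attention) and a 2-byte fp16/bf16 cache:

```python
# Assumed LLaMA-2 13B config: 40 layers, 40 heads, head_dim 128 (hidden size 5120).
num_layers, num_heads, head_dim = 40, 40, 128
bytes_per_elem = 2  # assumption: fp16/bf16 cache

# Factor 2 = one K and one V per layer.
per_token = 2 * num_layers * num_heads * head_dim * bytes_per_elem
per_seq_4k = per_token * 4096

print(per_token, per_token / 1e6)              # 819200 bytes, about 0.82 MB
print(per_seq_4k / 1e9, per_seq_4k / 2**30)    # about 3.36 GB, i.e. 3.125 GiB
```

Batch size multiplies this again, which is why the KV cache, not the weights, often limits how many sequences fit on one GPU.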

Performance comparison (T4 GPU benchmark)

| Method | Time |
|---|---|
| Without KV Cache | 1 min 1 s (61 s) |
| With KV Cache | 11.7 s |
| Speedup | ~5.2× |
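The speedup figure follows directly from the two measured times:

```python
without_cache_s = 61.0   # 1 min 1 s
with_cache_s = 11.7
speedup = without_cache_s / with_cache_s
print(round(speedup, 1))  # 5.2
```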

References