• Flash Attention์€ GPU ๋ฉ”๋ชจ๋ฆฌ ๊ณ„์ธต ๊ตฌ์กฐ๋ฅผ ํ™œ์šฉํ•œ IO-aware ์ •ํ™• ์–ดํ…์…˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜ (by Tri Dao, NeurIPS 2022)
  • Nร—N ์–ดํ…์…˜ ํ–‰๋ ฌ์„ HBM์— ์‹ค์ฒดํ™”ํ•˜์ง€ ์•Š๊ณ , ํƒ€์ผ๋ง(tiling)์œผ๋กœ SRAM์—์„œ ๋ธ”๋ก ๋‹จ์œ„ ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰
  • ๊ทผ์‚ฌ๊ฐ€ ์•„๋‹Œ ์ •ํ™•ํ•œ(exact) ์–ดํ…์…˜์ด๋ฉด์„œ ์ตœ๋Œ€ 7.6๋ฐฐ ๋น ๋ฅด๊ณ  ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ O(N)์œผ๋กœ ๊ฐ์†Œ

ํ•ด๋‹น ๊ฐœ๋…์ด ํ•„์š”ํ•œ ์ด์œ 

  • ํ‘œ์ค€ ์–ดํ…์…˜์€ ์‹œํ€€์Šค ๊ธธ์ด์— ๋Œ€ํ•ด O(Nยฒ) ์‹œ๊ฐ„/๋ฉ”๋ชจ๋ฆฌ ๋ณต์žก๋„
  • ๊ทผ์‚ฌ ์–ดํ…์…˜(Reformer, Performer ๋“ฑ)์€ FLOP๋Š” ์ค„์ด์ง€๋งŒ ์‹ค์ œ wall-clock ์†๋„ ํ–ฅ์ƒ ๋ฏธ๋ฏธ
  • ํ•ต์‹ฌ ๋ณ‘๋ชฉ์€ ์—ฐ์‚ฐ๋Ÿ‰์ด ์•„๋‹ˆ๋ผ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ(IO) โ€” GPU HBM โ†” SRAM ๊ฐ„ ๋ฐ์ดํ„ฐ ์ด๋™

AS-IS: ํ‘œ์ค€ ์–ดํ…์…˜

sequenceDiagram
    autonumber
    participant HBM as GPU HBM
    participant SRAM as GPU SRAM

    HBM->>SRAM: Q, K ์ „์ฒด ๋กœ๋“œ
    SRAM->>SRAM: S = Q ร— K^T (Nร—N ํ–‰๋ ฌ ๊ณ„์‚ฐ)
    SRAM->>HBM: S๋ฅผ HBM์— ์ €์žฅ (Nร—N)
    HBM->>SRAM: S ๋‹ค์‹œ ๋กœ๋“œ
    SRAM->>SRAM: P = softmax(S)
    SRAM->>HBM: P๋ฅผ HBM์— ์ €์žฅ (Nร—N)
    HBM->>SRAM: P, V ๋กœ๋“œ
    SRAM->>SRAM: O = P ร— V
    SRAM->>HBM: O ์ €์žฅ
    Note over HBM,SRAM: Nร—N ์ค‘๊ฐ„ ํ–‰๋ ฌ์ด HBM์„ ์™•๋ณต โ†’ IO ๋ณ‘๋ชฉ
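The AS-IS flow above corresponds to a naive implementation like the following NumPy sketch, in which S and P are full N×N intermediates (on a GPU, these are exactly the tensors that round-trip through HBM):

```python
import numpy as np

# Naive attention: the score matrix S and probability matrix P are
# fully materialized N x N intermediates.
def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])       # N x N scores
    S = S - S.max(axis=-1, keepdims=True)    # numerically stable softmax
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)    # N x N probabilities
    return P @ V                             # N x d output

rng = np.random.default_rng(0)
N, d = 128, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = naive_attention(Q, K, V)
print(O.shape)  # (128, 64)
```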

TO-BE: Flash Attention

sequenceDiagram
    autonumber
    participant HBM as GPU HBM (๋А๋ฆผ, ๋Œ€์šฉ๋Ÿ‰)
    participant SRAM as GPU SRAM (๋น ๋ฆ„, ์†Œ์šฉ๋Ÿ‰)
    participant Compute as ์—ฐ์‚ฐ ์œ ๋‹›

    HBM->>SRAM: K, V ๋ธ”๋ก ๋กœ๋“œ
    loop Q ๋ธ”๋ก ์ˆœํšŒ
        HBM->>SRAM: Q ๋ธ”๋ก ๋กœ๋“œ
        SRAM->>Compute: ๋ธ”๋ก ๋‹จ์œ„ ์–ดํ…์…˜ ๊ณ„์‚ฐ
        Compute->>SRAM: ๋ถ€๋ถ„ ๊ฒฐ๊ณผ + ์ •๊ทœํ™” ํŒฉํ„ฐ
    end
    SRAM->>HBM: ์ตœ์ข… ๊ฒฐ๊ณผ ๊ธฐ๋ก
    Note over HBM,Compute: Nร—N ํ–‰๋ ฌ์ด HBM์— ์‹ค์ฒดํ™”๋˜์ง€ ์•Š์Œ

ํƒ€์ผ๋ง(Tiling) ํ•ต์‹ฌ ๋ฉ”์ปค๋‹ˆ์ฆ˜

์ž…๋ ฅ Q, K, V๋ฅผ ๋ธ”๋ก์œผ๋กœ ๋ถ„ํ• ํ•œ ๋’ค, ๋А๋ฆฐ HBM์—์„œ ๋น ๋ฅธ SRAM์œผ๋กœ ๋ธ”๋ก ๋‹จ์œ„๋กœ ๋กœ๋“œํ•œ๋‹ค. ๊ฐ ๋ธ”๋ก์—์„œ ์–ดํ…์…˜ ์ถœ๋ ฅ์„ ๊ณ„์‚ฐํ•˜๊ณ , ์ ์ ˆํ•œ ์ •๊ทœํ™” ํŒฉํ„ฐ๋กœ ์Šค์ผ€์ผ๋งํ•˜์—ฌ ํ•ฉ์‚ฐํ•˜๋ฉด ์ •ํ™•ํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป๋Š”๋‹ค.

ํƒ€์ผ๋ง ๋•๋ถ„์— recomputation์œผ๋กœ ์ธํ•ด FLOP๊ฐ€ ์ฆ๊ฐ€ํ•˜๋”๋ผ๋„, HBM ์ ‘๊ทผ์ด ๋Œ€ํญ ์ค„์–ด๋“ค์–ด ์‹ค์ œ ์†๋„๋Š” ๋” ๋น ๋ฅด๋‹ค.

IO ๋ณต์žก๋„ ๋น„๊ต

๋ฐฉ์‹HBM ์ ‘๊ทผ ํšŸ์ˆ˜
ํ‘œ์ค€ ์–ดํ…์…˜ฮ˜(Nd + Nยฒ)
Flash AttentionO(NยฒdยฒMโปยน)

d: head ์ฐจ์› (64-128), M: SRAM ํฌ๊ธฐ (100KB). ์ผ๋ฐ˜์ ์œผ๋กœ dยฒ โ‰ช M์ด๋ฏ€๋กœ ์ˆ˜9๋ฐฐ ์ ์€ ๋ฉ”๋ชจ๋ฆฌ ์ ‘๊ทผ.

ํ•ต์‹ฌ ์ธ์‚ฌ์ดํŠธ: FLOP โ‰  ์†๋„

๊ทผ์‚ฌ ์–ดํ…์…˜์€ FLOP ์ ˆ๊ฐ์— ์ง‘์ค‘ํ–ˆ์ง€๋งŒ, FLOP ์ ˆ๊ฐ์ด ๋ฐ˜๋“œ์‹œ wall-clock ์†๋„ ํ–ฅ์ƒ์œผ๋กœ ์ด์–ด์ง€์ง€ ์•Š๋Š”๋‹ค. Flash Attention์€ ์˜คํžˆ๋ ค recomputation์œผ๋กœ FLOP๋ฅผ ์ฆ๊ฐ€์‹œํ‚ค๋ฉด์„œ๋„, IO๋ฅผ ์ค„์—ฌ ์‹ค์ œ ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œ์ผฐ๋‹ค. ์ด๋Š” ํ•˜๋“œ์›จ์–ด ์ธ์‹ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ์„ค๊ณ„์˜ ์ค‘์š”์„ฑ์„ ๋ณด์—ฌ์ฃผ๋Š” ๋Œ€ํ‘œ ์‚ฌ๋ก€๋‹ค.

์ฐธ๊ณ  ๋ฌธ์„œ