- Muon์ ๊ทธ๋๋์ธํธ์ ์ด๋ ํ๊ท (momentum)์ ์ง๊ตํํ์ฌ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ๋ ์ตํฐ๋ง์ด์
- ๋ชจ๋ ํน์๊ฐ์ 1์ ๊ฐ๊น๊ฒ ๋ง๋๋ Newton-Schulz ๋ฐ๋ณต ๊ธฐ๋ฐ ์ง๊ตํ ์๊ณ ๋ฆฌ์ฆ
- AdamW ๋๋น ์ฝ 2๋ฐฐ์ ๊ณ์ฐ ํจ์จ์ ๋ณด์ด๋ฉฐ, NanoGPT ์๋ ๊ฒฝ์์์ ์๋์ ์ฑ๊ณผ๋ฅผ ๊ธฐ๋กํ ์ฐจ์ธ๋ ์ตํฐ๋ง์ด์
ํด๋น ๊ฐ๋ ์ด ํ์ํ ์ด์
- SGD-momentum๊ณผ Adam์ ์ ๋ฐ์ดํธ๋ Transformer 2D ํ๋ผ๋ฏธํฐ์์ condition number๊ฐ ๋งค์ฐ ๋์ (๊ฑฐ์ low-rank)
- ์ด๋ ์์์ ๋ ธ์ด์ฆ ๋ฐฉํฅ์ด ์ต์ ํ๋ฅผ ์ง๋ฐฐํ๊ฒ ๋ง๋๋ ๋ฌธ์
- ์ง๊ตํ๋ โํฌ๊ธฐ๋ ์์ง๋ง ํ์ต์ ์ค์ํ ํฌ์ ๋ฐฉํฅโ์ ์ค์ผ์ผ์ ํจ๊ณผ์ ์ผ๋ก ํค์์ค
AS-IS: Adam/SGD-momentum์ ์ ๋ฐ์ดํธ
gradient โ momentum โ update
์ ๋ฐ์ดํธ ํ๋ ฌ์ ํน์๊ฐ ๋ถํฌ๊ฐ ๊ทน์ฌํ๊ฒ ํธ์ค๋จ. ๋ช ๊ฐ์ ํฐ ํน์๊ฐ์ด ์ ๋ฐ์ดํธ๋ฅผ ์ง๋ฐฐํ๊ณ , ๋๋จธ์ง ๋ฐฉํฅ์ ๋ฌด์๋จ.
TO-BE: Muon์ ์ ๋ฐ์ดํธ
gradient โ momentum โ Newton-Schulz ์ง๊ตํ โ update
(๋ชจ๋ singular value โ 1)
๋ฐฉํฅ ์ ๋ณด๋ ๋ณด์กดํ๋ฉด์ ํฌ๊ธฐ๋ฅผ ์ ๊ทํ. ๋ชจ๋ ๋ฐฉํฅ์ด ๋๋ฑํ ์ค์ผ์ผ๋ก ์ ๋ฐ์ดํธ์ ๊ธฐ์ฌ.
ํต์ฌ ์๋ ์๋ฆฌ
- ๊ทธ๋๋์ธํธ ๊ณ์ฐ
- Momentum ์ ์ฉ (์ด๋ ํ๊ท )
- Newton-Schulz ๋ฐ๋ณต์ผ๋ก momentum ํ๋ ฌ ์ง๊ตํ โ ๋ชจ๋ singular value๋ฅผ 1์ ๊ฐ๊น๊ฒ
- ์ง๊ตํ๋ ์ ๋ฐ์ดํธ๋ก ํ๋ผ๋ฏธํฐ ๊ฐฑ์
๊ธฐ์กด Orthogonal-SGDM(Tuddenham 2022)์ ์ง๊ตํ ํ momentum์ ์ ์ฉํ์ง๋ง, Muon์ momentum ํ ์ง๊ตํ ์์๋ก ๋ฐ๊พธ์ด ๊ฒฝํ์ ์ผ๋ก ๋ ์ข์ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ค. ๋ํ SVD ๋์ Newton-Schulz ๋ฐ๋ณต์ ์ฌ์ฉํด ๊ณ์ฐ ๋น์ฉ์ ์ ๊ฐํ๋ค.
autoresearch์์์ MuonAdamW ํ์ด๋ธ๋ฆฌ๋
autoresearch์ train.py๋ ํ๋ผ๋ฏธํฐ ์ ํ๋ณ๋ก ์ตํฐ๋ง์ด์ ๋ฅผ ๋ถ๋ฆฌํ๋ค:
| ํ๋ผ๋ฏธํฐ ์ ํ | ์ตํฐ๋ง์ด์ | Learning Rate |
|---|---|---|
| Embedding ๋ ์ด์ด | AdamW | 0.6 |
| Unembedding (lm_head) | AdamW | 0.004 |
| ๋ ์ด์ด๋ณ ์ค์นผ๋ผ | AdamW | 0.5 |
| 2D ํ๋ ฌ (์ดํ ์ /MLP) | Muon | orthogonalization ๊ธฐ๋ฐ |
์ฌ๊ธฐ์ Cautious weight decay๋ฅผ ์ถ๊ฐ: ๊ทธ๋๋์ธํธ์ ํ๋ผ๋ฏธํฐ์ ๊ณฑ์ด โฅ 0์ผ ๋๋ง weight decay๋ฅผ ์ ์ฉํ์ฌ ๋ถํ์ํ ์ ๊ทํ๋ฅผ ๋ฐฉ์งํ๋ค.
์ค์ผ์ผ๋ง๊ณผ ์ต์ ๋ํฅ
- GLM-4.5 (355B), KIMI Moonshot (1T+) ๋ฑ ์ด๋ํ ๋ชจ๋ธ ํ์ต์ ์ด๋ฏธ ์ค์ ๋ฐฐ์น
- Moonlight: 3B/16B MoE ๋ชจ๋ธ์ 5.7T ํ ํฐ์ผ๋ก ํ์ต, ๊ธฐ์กด ๋๋น ํจ์ฌ ์ ์ FLOPs๋ก ๋๋ฑ ์ฑ๋ฅ
- Turbo-Muon: spectral preconditioning์ผ๋ก Newton-Schulz ๋จ๊ณ ๊ณ์ฐ ๋น์ฉ ์ ๊ฐ
- AdaMuon: element-wise adaptivity + ์ง๊ต ์ ๋ฐ์ดํธ ๊ฒฐํฉ, ๋๊ท๋ชจ์์ Adam ๋๋น 40%+ ํจ์จ ํฅ์
- Block-wise Orthogonalization (ICML 2025): ํ๋ ฌ์ ๋ ๋ฆฝ ํ์ผ๋ก ๋ถํ ํ ๊ฐ๋ณ ์ง๊ตํ, 16ร tensor parallel ๊ฐ๋ฅ