- Flow Matching์ ๋๋คํ ๋ ธ์ด์ฆ๋ฅผ โ์กฐ๊ธ์ฉ ์ด๋โ์์ผ ์ํ๋ ๋ฐ์ดํฐ๋ก ๋ง๋๋ ์์ฑ ๋ชจ๋ธ ๊ธฐ๋ฒ
- ๊ฐ ์์น์์ โ์ด๋ ๋ฐฉํฅ์ผ๋ก ๊ฐ์ผ ํ๋์งโ ์๋ ค์ฃผ๋ ํ์ดํ ์ง๋(์๋ ๋ฒกํฐ ํ๋)๋ฅผ ํ์ตํ๋ ๋ฐฉ์
- ๊ทธ ํ์ดํ๋ฅผ ๋ฐ๋ผ ์ฌ๋ฌ ๋ฒ ์ด๋ํ๋ ๋ฐ๋ณต ๊ณ์ฐ(ODE ํ์ด) ๊ธฐ๋ฐ ์์ฑ ํ๋ ์์ํฌ
- diffusion ๋๋น ํจ์ฌ ์ ์ ํ์๋ก ๋น ๋ฅด๊ฒ ๊ฒฐ๊ณผ๋ฅผ ๋ง๋๋ ํน์ฑ
ํด๋น ๊ฐ๋ ์ด ํ์ํ ์ด์
- diffusion model์ ๋ ธ์ด์ฆ๋ฅผ ์ง์ฐ๋ ๋ฐ ์์ญ~์๋ฐฑ ๋ฒ ๋ฐ๋ณต์ด ํ์ํด ๋๋ฆฐ ์ถ๋ก
- ์ค์๊ฐ TTSยท์ด๋ฏธ์ง ์์ฑ์๋ ๋น ๋ฅธ ์์ฑ์ด ํ์
- ๊ธฐ์กด normalizing flow๋ ๋ชจ๋ธ ๊ตฌ์กฐ์ ๊น๋ค๋ก์ด ์ํ ์ ์ฝ(Jacobian)์ ์๊ตฌํ๋ ์ค๊ณ ๋ถ๋ด
ํ ๋ฌธ์ฅ ๋น์ : โํ์ดํ๋ฅผ ๋ฐ๋ผ๊ฐ๋ ๋ด๋น๊ฒ์ด์ โ
์ง๋ ์ ๋ชจ๋ ์ง์ ๋ง๋ค โ์ฌ๊ธฐ์๋ ์ด ๋ฐฉํฅ์ผ๋ก ๊ฐ๋ผโ๋ ํ์ดํ๊ฐ ๊ทธ๋ ค์ ธ ์๋ค๊ณ ์์ํ์. ์ถ๋ฐ์ ์ ์์ ํ์ดํ๊ฐ ๊ฐ๋ฆฌํค๋ ๋๋ก ํ ๊ฑธ์ ๊ฐ๊ณ , ์ ์์น์ ํ์ดํ๋ฅผ ๋ณด๊ณ ๋ ํ ๊ฑธ์ ๊ฐ๊ณ โฆ ์ด๊ฑธ ๋ฐ๋ณตํ๋ฉด ๋ชฉ์ ์ง์ ๋์ฐฉํ๋ค.
- ์ถ๋ฐ์ = ๋๋ค ๋ ธ์ด์ฆ (์๋ฏธ ์๋ ๊ฐ)
- ๋ชฉ์ ์ง = ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๋ฐ์ดํฐ (์: ์์ฑ์ ์ ์ฌ ํํ)
- ํ์ดํ ์ง๋ = ๋ชจ๋ธ์ด ํ์ตํ๋ ๊ฒ (= ์๋ ๋ฒกํฐ ํ๋)
- ๊ฑธ์ด๊ฐ๊ธฐ = ODE๋ฅผ ํธ๋ ๊ฒ (๊ทธ๋ฅ โ์กฐ๊ธ์ฉ ์ด๋ ๋ฐ๋ณตโ)
๊ฐ๋ฌผ์ ๋์ด ์ข ์ด๋ฐฐ๋ฅผ ๋ ์ฌ๋ ค๋ ๋๋ค. ๊ฐ ์ง์ ์ ๋ฌผ์ด(=ํ์ดํ)์ ๋ฐ๋ผ ํ๋ฌ๊ฐ๋ค ๋ณด๋ฉด ์์ฐ์ค๋ฝ๊ฒ ํ๊ตฌ(๋ฐ์ดํฐ)์ ๋๋ฌํ๋ค.
ํต์ฌ ์ฉ์ด 4๊ฐ โ ๋น์ ๋ก ์ดํดํ๊ธฐ
| ์ด๋ ค์ด ๋ง | ์ฌ์ด ๋ง | ๋น์ |
|---|---|---|
| Probability path | ๋ ธ์ด์ฆ โ ๋ฐ์ดํฐ๋ก ๋ชจ์์ด ๋ณํด๊ฐ๋ ์ค๊ฐ ๋จ๊ณ๋ค | ์ ๋๋ฉ์ด์ ํ๋ ์ (๋ ธ์ด์ฆ๊ฐ ์ ์ ์์ฑ์ด ๋๋ ์์) |
| Vector field (์๋ ๋ฒกํฐ ํ๋) | โ์ด ์์น์์ ์ด ๋ฐฉํฅยท์ด ์๋๋กโ ํ์ดํ ์ง๋ | ๋ฐ๋์ฅ(weather map) / ๊ฐ๋ฌผ์ ๋ฌผ์ด |
| ODE solver | ํ์ดํ ๋ฐ๋ผ ์กฐ๊ธ์ฉ ์ด๋ํ๋ ๋ฐ๋ณต | ๊ฒ์์์ ๋งค ํ๋ ์ ์บ๋ฆญํฐ ์์น ๊ฐฑ์ |
| Conditional Flow Matching | ์ ๋ต์ ์ ๋ ํ์ดํ๋ฅผ ๊ฐ๋ฅด์น๋ ํ์ต๋ฒ | ์ถ๋ฐโ๋์ฐฉ์ ์ง์ ์ผ๋ก ์๊ณ ๋ฐฉํฅ ์๋ ค์ฃผ๊ธฐ |
์ฝ๋๋ก ๋ณด๋ ์์ฑ ๊ณผ์
์ค์ ๋ก โ์์ฑโ์ ๊ฒ์ ๋ฌผ๋ฆฌ ๋ฃจํ(์์น += ์๋ ร ์๊ฐ)์ ๋๊ฐ๋ค. CS 2ํ๋
์ด๋ฉด ์ต์ํ for๋ฌธ์ด๋ค.
# ์์ฑ(์ถ๋ก ): ๋
ธ์ด์ฆ์์ ์ถ๋ฐํด ํ์ดํ(์๋)๋ฅผ ๋ฐ๋ผ ์กฐ๊ธ์ฉ ์ด๋
x = random_noise() # ์ถ๋ฐ์ : ์๋ฏธ ์๋ ๋๋ค๊ฐ
dt = 1.0 / num_steps # ํ ๋ฒ์ ๊ฐ ๊ฑฐ๋ฆฌ (์คํ
์ด ๋ง์์๋ก ์๊ฒ ์ด๋)
for t in range(num_steps): # ์: 8๋ฒ ๋ฐ๋ณต
v = model(x, t) # "์ง๊ธ ์์น x์์ ์ด๋๋ก?" โ ํ์ต๋ ์๋(ํ์ดํ)
x = x + v * dt # ๊ทธ ๋ฐฉํฅ์ผ๋ก ํ ๊ฑธ์ ์ด๋
# ๋ฐ๋ณต์ด ๋๋๋ฉด x๋ ์ง์ง ๋ฐ์ดํฐ (์ฌ๊ธฐ์ ์์ฑ์ ์ ์ฌ ํํ)num_steps๊ฐ ํด์๋ก ์๊ฒ ๋๋ ์์ง์ฌ ์ ๊ตํด์ง๊ณ (ํ์งโ), ์์์๋ก ์ฑํผ์ฑํผ ๊ฐ์ ๋นจ๋ผ์ง๋ค(์๋โ). Supertonic์ total_steps(5~12)๊ฐ ๋ฐ๋ก ์ด ๊ฐ์ด๋ค.
ํ์ต์ ์ด๋ป๊ฒ? โ โ์ ๋ต์ ์ ๋ ์ง์ ์ผ๋ก ๊ฐ๋ฅด์น๋คโ
ํ์ดํ ์ง๋๋ ์ด๋ป๊ฒ ๋ง๋ค๊น? ํ์ต ๋๋ ์ถ๋ฐ์ (๋ ธ์ด์ฆ)๊ณผ ๋์ฐฉ์ (์ง์ง ๋ฐ์ดํฐ)์ ๋ ๋ค ์๊ณ ์๋ค. ๊ทธ๋์ ๋์ ์ง์ ์ผ๋ก ์๊ณ , โ์ด ์ง์ ๋ฐฉํฅ์ด ์ ๋ต ์๋์ผโ๋ผ๊ณ ๋ชจ๋ธ์๊ฒ ๊ฐ๋ฅด์น๋ค. ์ด๊ฒ์ด Conditional Flow Matching์ด๋ค.
# ํ์ต: ์ ๋ต์ ์๊ณ ์์ผ๋ ์ง์ ๊ฒฝ๋ก์ "๋ฐฉํฅ"์ ๊ฐ๋ฅด์น๋ค
noise = random_noise()
data = real_sample() # ์ง์ง ๋ฐ์ดํฐ (์ ๋ต)
t = random(0, 1) # ์ง์ ์ ์์์ ํ ์ง์
x_t = (1 - t) * noise + t * data # ๋
ธ์ด์ฆ์ ๋ฐ์ดํฐ๋ฅผ ์ง์ ์ผ๋ก ๋ณด๊ฐ
target_v = data - noise # ์ด ์ง์ ์ ๋ฐฉํฅ(์๋) = ์ ๋ต ํ์ดํ
pred_v = model(x_t, t) # ๋ชจ๋ธ์ด ์์ธกํ ํ์ดํ
loss = mse(pred_v, target_v) # ์ ๋ต ํ์ดํ์ ๊ฐ๊น์์ง๋๋ก ํ์ต์๋ง์ (๋ ธ์ด์ฆ, ๋ฐ์ดํฐ) ์์ผ๋ก ์ด ํ์ต์ ๋ฐ๋ณตํ๋ฉด, ๋ชจ๋ธ์ โ์ด๋ค ์์น์์๋ ๋ฐ์ดํฐ ์ชฝ์ผ๋ก ๊ฐ๋ ํ์ดํโ๋ฅผ ๊ทธ๋ฆด ์ค ์๊ฒ ๋๋ค. ์์ฑํ ๋๋ ๋ ธ์ด์ฆ๋ง ์ฃผ๊ณ ์ด ํ์ดํ๋ฅผ ๋ฐ๋ผ๊ฐ๊ฒ ํ๋ฉด ๋๋ค.
diffusion๊ณผ ๋น๊ต: โ์๊ฐ ๋ ๋ฏธ๋กโ vs โ๋ปฅ ๋ซ๋ฆฐ ๊ณ ์๋๋กโ
flowchart LR subgraph DIFF["Diffusion โ ์๊ฐ ๋ ๋ฏธ๋ก"] A1[๋ ธ์ด์ฆ] --> A2[ํ ๋ฐ์ง] --> A3[ํ ๋ฐ์ง] --> A4[...์์ญ~์๋ฐฑ ๋ฒ] --> A5[๋ฐ์ดํฐ] end subgraph FM["Flow Matching โ ๊ณ ์๋๋ก"] B1[๋ ธ์ด์ฆ] --> B2[์ฑํผ] --> B3[์ฑํผ 5~12๋ฒ] --> B4[๋ฐ์ดํฐ] end
| ํญ๋ชฉ | Diffusion | Flow Matching |
|---|---|---|
| ๊ฒฝ๋ก | ๋ ธ์ด์ฆ๋ฅผ ๋๋ฌ์ผ๋ฉฐ ์กฐ๊ธ์ฉ ์ ๊ฑฐ (๊ตฌ๋ถ๊ตฌ๋ถ) | ๊ฑฐ์ ์ง์ ๊ฒฝ๋ก๋ฅผ ๋ฐ๋ผ ์ด๋ |
| ์คํ ์ | ์์ญ~์๋ฐฑ | ์~์์ญ (์ ์) |
| ์๋ | ๋๋ฆผ | ๋น ๋ฆ (์ค์๊ฐ ์ ํฉ) |
| ๋น์ | ์๊ฐ ์ ํ ๋ฐ์ง์ฉ ๊ธธ์ฐพ๊ธฐ | ์ง์ ๋๋ก ๋ช ๋ฒ์ ๋์ฐฉ |
Supertonic์์์ ์ญํ
- Supertonic์ text-to-latent ๋ชจ๋์ด flow matching์ผ๋ก ํ ์คํธ๋ฅผ ์ํฅ ์ ์ฌ ํํ์ผ๋ก ๋ณํํ๋ค.
- ์ฌ์ฉ๋ฒ์
total_steps(5~12)๊ฐ ๋ฐ๋ก ์ for๋ฌธ์ **๋ฐ๋ณต ํ์(num_steps)**๋ค. ํด์๋ก ํ์งโยท์๋โ. - โ์ ์ ์คํ ์ผ๋ก ๋น ๋ฅธ ์์ฑโ์ด๋ผ๋ ํน์ฑ ๋๋ถ์ GPU ์์ด CPU๋ง์ผ๋ก๋ ์ค์๊ฐ์ ๊ฐ๊น์ด TTS๊ฐ ๊ฐ๋ฅํ๋ค.
- Supertonic์ ํ์ต ๋ฐ์ดํฐ์ ๋ ธ์ด์ฆ(๋ถ์ ํํ ๋ผ๋ฒจ)๊ฐ ์์ด๋ ํ๋ค๋ฆฌ์ง ์๋๋ก Self-Purifying Flow Matching ๊ธฐ๋ฒ์ ๋ํ๋ค.