Hugging Face Diffusers๊ฐ Flax ์ง์์ ์ถ๊ฐํด Stable Diffusion์ Google TPU์์ 8๊ฐ ๋ณ๋ ฌ ์ฅ์น๋ฅผ ํ์ฉํ ์ถ๋ก ์ผ๋ก ๋จ์ผ ์นฉ ์๋๋ก 8๋ฐฐ ์ด๋ฏธ์ง ์์ฑ ๊ฐ๋ฅ
๐งจ Stable Diffusion in JAX / Flax !
AI ์์ฝ
Context
Stable Diffusion ์ถ๋ก ์ ๊ณ์ฐ๋์ด ๋ง์ GPU ํ๊ฒฝ์์๋ ์๊ฐ์ด ์์๋๋ค. Google TPU ์๋ฒ๋ 8๊ฐ์ TPU ๊ฐ์๊ธฐ๋ฅผ ๋ณ๋ ฌ๋ก ๊ตฌ๋ํ ์ ์์ผ๋, ๊ธฐ์กด Diffusers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ์ด๋ฅผ ํ์ฉํ์ง ๋ชปํ๋ค.
Technical Solution
- Flax๋ฅผ ์ฌ์ฉํ ์ํ๋น์ ์ฅ ๋ชจ๋ธ ๊ตฌ์กฐ ๋์ : ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ชจ๋ธ ์ธ๋ถ์ ์ ์ฅํ๊ณ ๋ณ๋ ๊ด๋ฆฌ
- JAX์ pmap(parallel map) ํจ์๋ก ํ์ดํ๋ผ์ธ์ _generate ๋ฉ์๋๋ฅผ 8๊ฐ TPU ๋๋ฐ์ด์ค์ ๋ณ๋ ฌํ: ๊ฐ ๋๋ฐ์ด์ค๊ฐ ๋ค๋ฅธ ์ ๋ ฅ ๋ฐฐ์น๋ฅผ ์ฒ๋ฆฌํ๋๋ก ์๋ ์ค๋ฉ
- ํ๋กฌํํธ ๋ณต์ ๋ฐ ์ ๋ ฅ ์ค๋ฉ: ๋จ์ผ ํ๋กฌํํธ๋ฅผ 8๋ฐฐ๋ก ๋ณต์ ํ ํ flax.jax_utils.replicate์ shard๋ฅผ ์ฌ์ฉํด ๊ฐ ๋๋ฐ์ด์ค์ ๋ถ์ฐ
- bfloat16 ๋ฐ์ ๋ฐ๋ ์ฐ์ฐ ํ์ฉ: ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ์ํด bf16 ๊ฐ์ค์น ๋ฒ์ ๋ก๋
- ๋น๋๊ธฐ ๋์คํจ์น ์ฒ๋ฆฌ: JAX์ ๋น๋๊ธฐ ํน์ฑ์ ํ์ฉํ๋ block_until_ready()๋ก ์ ํํ ์ธก์ ์๊ฐ ํ๋ณด
Impact
๋จ์ผ ์นฉ์ด 1๊ฐ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ์๊ฐ์ 8๊ฐ ๋๋ฐ์ด์ค๊ฐ ๋์์ 8๊ฐ ์ด๋ฏธ์ง ์์ฑ (์ด ์คํ ์๊ฐ 6.82์ด๋ก ์ธก์ ๋จ).
Key Takeaway
Functional ํ๋ก๊ทธ๋๋ฐ ํจ๋ฌ๋ค์(Flax)๊ณผ ๋ถ์ฐ ์ปดํ์ผ ์ถ์ํ(pmap)๋ฅผ ์กฐํฉํ๋ฉด, ๊ธฐ์ ํจ์๋ฅผ ์์ ์์ด ๋ณ๋ ฌ ์คํ ์ฝ๋๋ก ๋ณํ ๊ฐ๋ฅํ๋ฏ๋ก, ๋ค์ค ๊ฐ์๊ธฐ ํ๊ฒฝ์์ ๊ฐ๋ฐ์์ ๋ณต์ก๋ ์ฆ๊ฐ ์์ด ํ์ฅ์ฑ์ ํ๋ณดํ ์ ์๋ค.
์ค์ฒ ํฌ์ธํธ
Google Colab์ด๋ Google Cloud Platform์ TPU ํ๊ฒฝ์ ์ฌ์ฉํ๋ ์ด๋ฏธ์ง ์์ฑ ์๋น์ค์์ Flax ๊ธฐ๋ฐ์ Stable Diffusion ํ์ดํ๋ผ์ธ์ ๋์ ํ๊ณ pmap์ผ๋ก _generate ๋ฉ์๋๋ฅผ ๋ณ๋ ฌํํ๋ฉด, ๊ธฐ์กด ๋ฐฐ์น ์ฒ๋ฆฌ ๋ก์ง ๋ณ๊ฒฝ ์์ด ๋์์ 8๋ฐฐ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์๋ค.