Hugging Face์ Google Cloud๊ฐ JAX + Cloud TPU v5e ์กฐํฉ์ผ๋ก Stable Diffusion XL ์ถ๋ก ์ 2.4๋ฐฐ ๋น์ฉ ํจ์จ๋ก ๊ฐ์ ํ๊ณ 4์ด ๋ด 1024ร1024 ์ด๋ฏธ์ง 4์ฅ ์์ฑ
๐งจ Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e
AI ์์ฝ
Context
Stable Diffusion XL์ ์ด์ ๋ฒ์ ๋๋น UNet ์ปดํฌ๋ํธ๊ฐ 3๋ฐฐ ์ปค์ ธ์ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ์ฌํญ๊ณผ ์ถ๋ก ์๊ฐ์ด ๊ธ์ฆํ๋ค. ํ๋ก๋์ ๋ฐฐํฌ ์ ๋์ ๊ณ์ฐ ๋น์ฉ๊ณผ ์ง์ฐ ์๊ฐ์ด ์ฃผ์ ๊ณผ์ ์๋ค.
Technical Solution
- JAX์ JIT(Just-In-Time) ์ปดํ์ผ ๋์ : ์ ์ ๋ชจ์(static shapes)์ด ํ์ํ ์ด๋ฏธ์ง ์์ฑ ์์ ์ ํน์ฑ์ ํ์ฉํด ์ฒซ ์คํ ์ ์ต์ ํ๋ TPU ๋ฐ์ด๋๋ฆฌ๋ฅผ ์์ฑํ๊ณ ์ดํ ํธ์ถ์์ ์ฌ์ฌ์ฉ
- XLA ์ปดํ์ผ๋ฌ ๊ธฐ๋ฐ ๋ณ๋ ฌํ: JAX์ pmap์ ์ฌ์ฉํด SPMD(Single-Program Multiple-Data) ํ๋ก๊ทธ๋จ ๊ตฌํ์ผ๋ก ์ฌ๋ฌ TPU ์นฉ์์ ๋์ ์คํ (์: 8์นฉ TPU์์ 1์นฉ์ด ์ด๋ฏธ์ง 1๊ฐ ์์ฑํ ์๊ฐ์ ์ด๋ฏธ์ง 8๊ฐ ์์ฑ)
- ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ์ ๋ฐ๋ ์ต์ ํ: 32๋นํธ ํ๋ผ๋ฏธํฐ๋ฅผ bfloat16์ผ๋ก ๋ณํํด ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ ๋ฐ ๊ณ์ฐ ์๋ ํฅ์ (๋จ, ์ค์ผ์ค๋ฌ ์ํ๋ float32 ์ ์ง)
- Cloud TPU v5e ์ธ์คํด์ค ํ์ฉ: 1์นฉ, 4์นฉ, 8์นฉ, 256์นฉ ๊ท๋ชจ์ ๋ค์ํ ๊ตฌ์ฑ์ผ๋ก ์ฌ์ฉ ์ฌ๋ก์ ๋ง์ถ ํ์ฅ์ฑ ์ ๊ณต
- Hugging Face Diffusers ๋ผ์ด๋ธ๋ฌ๋ฆฌ ํตํฉ: FlaxStableDiffusionXLPipeline์ ํตํด JAX ๊ธฐ๋ฐ ์ถ๋ก ํ์ดํ๋ผ์ธ ๊ฐ์ํ
Impact
- Cloud TPU v5e๋ TPU v4 ๋๋น 2.4๋ฐฐ ๋์ ์ฑ๋ฅ/๋ฌ๋ฌ ๋ฌ์ฑ (TPU v5e-4: ๋ฐฐ์น ํฌ๊ธฐ 4์ผ ๋ 21.46 vs TPU v4-8: 9.05)
- SDXL ์ด๋ฏธ์ง 4์ฅ(1024ร1024) ์์ฑ ์๊ฐ: ์ ์ฒด 4์ด (์ค์ ์์ฑ ์๊ฐ 2.3์ด, ํ์ ๋ณํ ๋ฐ ํต์ ํฌํจ)
- ๋ฐฐ์น ํฌ๊ธฐ 4 ๊ธฐ์ค ์ง์ฐ์๊ฐ: TPU v5e-4์์ 2.33์ด (TPU v4-8: 2.16์ด)
- Cloud TPU v5e๋ TPU v4 ๋๋น ์ ๋ฐ ์ดํ์ ๋น์ฉ
Key Takeaway
์์ฑํ AI ๋ชจ๋ธ ๋ฐฐํฌ์์ ๊ณ ์ ์ ์ถ๋ ฅ ๋ชจ์ ํน์ฑ์ ํ์ฉํ JIT ์ปดํ์ผ๊ณผ ์ ๋ฌธํ๋ ํ๋์จ์ด(TPU) + ์ต์ ํ๋ ์ํํธ์จ์ด ์คํ(JAX) ์กฐํฉ์ ๋จ์ํ ์ฑ๋ฅ ํฅ์์ ๋์ด ๋น์ฉ ํจ์จ์ฑ๊น์ง ๋ํญ ๊ฐ์ ํ ์ ์๋ค. ์ ์ ๊ตฌ์กฐ์ ์ํฌ๋ก๋ ํน์ฑ ํ์ ๊ณผ ๊ทธ์ ๋ง๋ ๊ธฐ์ ์คํ ์ ํ์ด ์ฑ๋ฅ๊ณผ ๊ฒฝ์ ์ฑ์ ํต์ฌ์ด๋ค.
์ค์ฒ ํฌ์ธํธ
๋๊ท๋ชจ ์ด๋ฏธ์ง ์์ฑ ์๋น์ค๋ฅผ ๊ตฌ์ถํ ๋ ๊ณ ์ ํฌ๊ธฐ ๋ฐฐ์น ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํ ๊ตฌ์กฐ๋ผ๋ฉด, JAX์ pmap์ ํ์ฉํด ๋จ์ผ ์ฝ๋๋ก ์ฌ๋ฌ TPU ์นฉ ๊ฐ ๋ณ๋ ฌ ์คํ์ ๊ตฌํํ๊ณ bfloat16 ์ ๋ฐ๋ ๋ณํ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ๊ฐํ๋ฉด, ๊ธฐ์กด ๋ฐฉ์ ๋๋น 2๋ฐฐ ์ด์์ ๋น์ฉ ํจ์จ์ ๋ฌ์ฑํ ์ ์๋ค.