ํ”ผ๋“œ๋กœ ๋Œ์•„๊ฐ€๊ธฐ
๐Ÿงจ Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e
Hugging Face BlogHugging Face Blog
AI/ML

Hugging Face์™€ Google Cloud๊ฐ€ JAX + Cloud TPU v5e ์กฐํ•ฉ์œผ๋กœ Stable Diffusion XL ์ถ”๋ก ์„ 2.4๋ฐฐ ๋น„์šฉ ํšจ์œจ๋กœ ๊ฐœ์„ ํ•˜๊ณ  4์ดˆ ๋‚ด 1024ร—1024 ์ด๋ฏธ์ง€ 4์žฅ ์ƒ์„ฑ

๐Ÿงจ Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e

2023๋…„ 10์›” 3์ผ12๋ถ„intermediate

Context

Stable Diffusion XL์€ ์ด์ „ ๋ฒ„์ „ ๋Œ€๋น„ UNet ์ปดํฌ๋„ŒํŠธ๊ฐ€ 3๋ฐฐ ์ปค์ ธ์„œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ์‚ฌํ•ญ๊ณผ ์ถ”๋ก  ์‹œ๊ฐ„์ด ๊ธ‰์ฆํ–ˆ๋‹ค. ํ”„๋กœ๋•์…˜ ๋ฐฐํฌ ์‹œ ๋†’์€ ๊ณ„์‚ฐ ๋น„์šฉ๊ณผ ์ง€์—ฐ ์‹œ๊ฐ„์ด ์ฃผ์š” ๊ณผ์ œ์˜€๋‹ค.

Technical Solution

  • JAX์˜ JIT(Just-In-Time) ์ปดํŒŒ์ผ ๋„์ž…: ์ •์  ๋ชจ์–‘(static shapes)์ด ํ•„์š”ํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ž‘์—…์˜ ํŠน์„ฑ์„ ํ™œ์šฉํ•ด ์ฒซ ์‹คํ–‰ ์‹œ ์ตœ์ ํ™”๋œ TPU ๋ฐ”์ด๋„ˆ๋ฆฌ๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ์ดํ›„ ํ˜ธ์ถœ์—์„œ ์žฌ์‚ฌ์šฉ
  • XLA ์ปดํŒŒ์ผ๋Ÿฌ ๊ธฐ๋ฐ˜ ๋ณ‘๋ ฌํ™”: JAX์˜ pmap์„ ์‚ฌ์šฉํ•ด SPMD(Single-Program Multiple-Data) ํ”„๋กœ๊ทธ๋žจ ๊ตฌํ˜„์œผ๋กœ ์—ฌ๋Ÿฌ TPU ์นฉ์—์„œ ๋™์‹œ ์‹คํ–‰ (์˜ˆ: 8์นฉ TPU์—์„œ 1์นฉ์ด ์ด๋ฏธ์ง€ 1๊ฐœ ์ƒ์„ฑํ•  ์‹œ๊ฐ„์— ์ด๋ฏธ์ง€ 8๊ฐœ ์ƒ์„ฑ)
  • ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ •๋ฐ€๋„ ์ตœ์ ํ™”: 32๋น„ํŠธ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ bfloat16์œผ๋กœ ๋ณ€ํ™˜ํ•ด ๋ฉ”๋ชจ๋ฆฌ ์ ˆ๊ฐ ๋ฐ ๊ณ„์‚ฐ ์†๋„ ํ–ฅ์ƒ (๋‹จ, ์Šค์ผ€์ค„๋Ÿฌ ์ƒํƒœ๋Š” float32 ์œ ์ง€)
  • Cloud TPU v5e ์ธ์Šคํ„ด์Šค ํ™œ์šฉ: 1์นฉ, 4์นฉ, 8์นฉ, 256์นฉ ๊ทœ๋ชจ์˜ ๋‹ค์–‘ํ•œ ๊ตฌ์„ฑ์œผ๋กœ ์‚ฌ์šฉ ์‚ฌ๋ก€์— ๋งž์ถ˜ ํ™•์žฅ์„ฑ ์ œ๊ณต
  • Hugging Face Diffusers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ํ†ตํ•ฉ: FlaxStableDiffusionXLPipeline์„ ํ†ตํ•ด JAX ๊ธฐ๋ฐ˜ ์ถ”๋ก  ํŒŒ์ดํ”„๋ผ์ธ ๊ฐ„์†Œํ™”

Impact

  • Cloud TPU v5e๋Š” TPU v4 ๋Œ€๋น„ 2.4๋ฐฐ ๋†’์€ ์„ฑ๋Šฅ/๋‹ฌ๋Ÿฌ ๋‹ฌ์„ฑ (TPU v5e-4: ๋ฐฐ์น˜ ํฌ๊ธฐ 4์ผ ๋•Œ 21.46 vs TPU v4-8: 9.05)
  • SDXL ์ด๋ฏธ์ง€ 4์žฅ(1024ร—1024) ์ƒ์„ฑ ์‹œ๊ฐ„: ์ „์ฒด 4์ดˆ (์‹ค์ œ ์ƒ์„ฑ ์‹œ๊ฐ„ 2.3์ดˆ, ํ˜•์‹ ๋ณ€ํ™˜ ๋ฐ ํ†ต์‹  ํฌํ•จ)
  • ๋ฐฐ์น˜ ํฌ๊ธฐ 4 ๊ธฐ์ค€ ์ง€์—ฐ์‹œ๊ฐ„: TPU v5e-4์—์„œ 2.33์ดˆ (TPU v4-8: 2.16์ดˆ)
  • Cloud TPU v5e๋Š” TPU v4 ๋Œ€๋น„ ์ ˆ๋ฐ˜ ์ดํ•˜์˜ ๋น„์šฉ

Key Takeaway

์ƒ์„ฑํ˜• AI ๋ชจ๋ธ ๋ฐฐํฌ์—์„œ ๊ณ ์ • ์ž…์ถœ๋ ฅ ๋ชจ์–‘ ํŠน์„ฑ์„ ํ™œ์šฉํ•œ JIT ์ปดํŒŒ์ผ๊ณผ ์ „๋ฌธํ™”๋œ ํ•˜๋“œ์›จ์–ด(TPU) + ์ตœ์ ํ™”๋œ ์†Œํ”„ํŠธ์›จ์–ด ์Šคํƒ(JAX) ์กฐํ•ฉ์€ ๋‹จ์ˆœํ•œ ์„ฑ๋Šฅ ํ–ฅ์ƒ์„ ๋„˜์–ด ๋น„์šฉ ํšจ์œจ์„ฑ๊นŒ์ง€ ๋Œ€ํญ ๊ฐœ์„ ํ•  ์ˆ˜ ์žˆ๋‹ค. ์ •์  ๊ตฌ์กฐ์˜ ์›Œํฌ๋กœ๋“œ ํŠน์„ฑ ํŒŒ์•…๊ณผ ๊ทธ์— ๋งž๋Š” ๊ธฐ์ˆ  ์Šคํƒ ์„ ํƒ์ด ์„ฑ๋Šฅ๊ณผ ๊ฒฝ์ œ์„ฑ์˜ ํ•ต์‹ฌ์ด๋‹ค.


๋Œ€๊ทœ๋ชจ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์„œ๋น„์Šค๋ฅผ ๊ตฌ์ถ•ํ•  ๋•Œ ๊ณ ์ • ํฌ๊ธฐ ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ๊ฐ€ ๊ฐ€๋Šฅํ•œ ๊ตฌ์กฐ๋ผ๋ฉด, JAX์˜ pmap์„ ํ™œ์šฉํ•ด ๋‹จ์ผ ์ฝ”๋“œ๋กœ ์—ฌ๋Ÿฌ TPU ์นฉ ๊ฐ„ ๋ณ‘๋ ฌ ์‹คํ–‰์„ ๊ตฌํ˜„ํ•˜๊ณ  bfloat16 ์ •๋ฐ€๋„ ๋ณ€ํ™˜์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ ˆ๊ฐํ•˜๋ฉด, ๊ธฐ์กด ๋ฐฉ์‹ ๋Œ€๋น„ 2๋ฐฐ ์ด์ƒ์˜ ๋น„์šฉ ํšจ์œจ์„ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค.

์›๋ฌธ ์ฝ๊ธฐ