Hugging Face๊ฐ PyTorch์ meta device์ ๋์ ๊ฐ์ค์น ๋ก๋ฉ์ผ๋ก 176์ต ๊ฐ ํ๋ผ๋ฏธํฐ ๋ชจ๋ธ์ Colab ๋ฌด๋ฃ ์ธ์คํด์ค์์ ์คํ ๊ฐ๋ฅํ๊ฒ ๊ตฌํ
How ๐ค Accelerate runs very large models thanks to PyTorch
AI ์์ฝ
Context
๊ธฐ์กด PyTorch ๋ชจ๋ธ ๋ก๋ฉ ํ์ดํ๋ผ์ธ์ ๋ชจ๋ธ ์์ฑ โ ๋ฉ๋ชจ๋ฆฌ ๋ก๋ โ ๊ฐ์ค์น ์ฃผ์ โ ๋๋ฐ์ด์ค ์ด๋์ ์์ฐจ ๊ณผ์ ์ ๊ฑฐ์น๋ค. 6.7B ํ๋ผ๋ฏธํฐ ๋ชจ๋ธ(OPT-6.7B)์ float32 ๊ธฐ๋ณธ ์ ๋ฐ๋์์ 26.8GB RAM์ด ํ์ํ๊ณ , 176B ํ๋ผ๋ฏธํฐ ๋ชจ๋ธ(BLOOM, OPT-176B)์ 1.4TB CPU RAM์ด ํ์ํด ์ผ๋ฐ ์๋น์ ํ๋์จ์ด์์ ์คํ ๋ถ๊ฐ๋ฅํ๋ค.
Technical Solution
- PyTorch 1.9์ meta device๋ฅผ ํ์ฉํ ๋น ๋ชจ๋ธ ์์ฑ: ์ค์ ๋ฐ์ดํฐ ์์ด ํ ์์ ํํ(shape)๋ง ๊ฐ์ง๊ณ ๋ฉ๋ชจ๋ฆฌ ํ ๋น ์์ด ๋ชจ๋ธ ์ธ์คํด์ค ์์ฑ
- init_empty_weights() ์ปจํ ์คํธ ๋งค๋์ ๊ฐ๋ฐ: ๊ธฐ์กด Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ 150๊ฐ ๋ชจ๋ธ ์ฝ๋ ์์ ์์ด ๋น ๋ชจ๋ธ ์๋ ์์ฑ
- ๋ฉํ ๋๋ฐ์ด์ค์ ํํ ์ ๋ณด๋ก device_map ์๋ ๊ณ์ฐ: ๊ฐ ๊ฐ์ค์น์ ํํ์ ๋ฐ์ดํฐ ํ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ ์๋น๋ ์ฌ์ ๊ณ์ฐ ํ CPU/GPU/๋์คํฌ ๋ฐฐ์น ๊ฒฐ์
- offload_folder์ offload_state_dict ํ๋ผ๋ฏธํฐ๋ฅผ ํตํ ๋์คํฌ ์คํ๋ก๋ฉ: ๋ก๋ํ ์ ์๋ ๊ฐ์ค์น๋ฅผ ๋์คํฌ์ ์ ์ฅํ๊ณ ํ์์์๋ง ๋ก๋
- dispatch_model ํจ์์ forward ์ /ํ ํ (hook) ์ถ๊ฐ: ๊ฐ ๋ชจ๋ ์คํ ์ ์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ ๋๋ฐ์ด์ค๋ก ์ด๋, ์คํ ํ CPU/๋์คํฌ๋ก ๋ณต์
Impact
์ํฐํด์์ ์ ๋์ ์ฑ๋ฅ ์งํ(์๋ ๊ฐ์ , ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ %)๊ฐ ๋ช ์๋์ง ์์.
Key Takeaway
๊ทน๋๋ก ํฐ ๋ชจ๋ธ์ ๋ก๋ํ ๋ ์ ์ฒด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ์ ์ฌ๋ฆฌ๋ ๋์ , ๊ณ์ฐ ๊ทธ๋ํ์ ํํ ์ ๋ณด๋ก ๋ฏธ๋ฆฌ ๋ฐฐ์น๋ฅผ ๊ฒฐ์ ํ๊ณ ๋จ๊ณ๋ณ๋ก ํ์ํ ๋ถ๋ถ๋ง ๋์ ์ผ๋ก ๋ก๋ํ๋ ์ง์ฐ ๋ก๋ฉ ํจํด์ด ํต์ฌ์ด๋ค. ์ด๋ ๋จ์ผ ๋์ฉ๋ ๋ฉ๋ชจ๋ฆฌ ๋์ ์ฌ๋ฌ ์ด๊ธฐ์ข ์คํ ๋ฆฌ์ง(GPU, CPU RAM, ๋์คํฌ)๋ฅผ ์์ฐจ์ ์ผ๋ก ํ์ฉํ๋ ์ค๊ณ ์์น์ ๋ณด์ฌ์ค๋ค.
์ค์ฒ ํฌ์ธํธ
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ ํ๋ ๋ฆฌ์์ค ํ๊ฒฝ(Colab, ๊ฐ์ธ GPU)์์ ์ถ๋ก ํด์ผ ํ ๋, PyTorch์ meta device๋ก ๋ชจ๋ธ ํํ๋ง ๋จผ์ ํ์ ํ๊ณ device_map์ผ๋ก ๊ฐ ๋ ์ด์ด์ ๋ฐฐ์น๋ฅผ ์ฌ์ ๊ฒฐ์ ํ ํ, ๋์ ํ ์ ํตํด ํฌ์๋ ํจ์ค ์ง์ ์๋ง ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์๋ฌ ์์ด ์คํํ ์ ์๋ค.