Hugging Face Transformers๊ฐ Hub์์ ๋ค์ด๋ก๋ ๊ฐ๋ฅํ ์ปค์คํ ์ปค๋๊ณผ MXFP4 ์์ํ๋ฅผ ํตํฉํด GPT-OSS ๋ชจ๋ธ์ ๋ก๋ฉยท์ถ๋ก ยทํ์ธํ๋ ์ฑ๋ฅ์ 2~10๋ฐฐ ํฅ์
Tricks from OpenAI gpt-oss YOU ๐ซต can use with transformers
AI ์์ฝ
Context
์ปค๋ฎค๋ํฐ์์ ๊ฐ๋ฐ๋ Flash Attention, Liger RMSNorm, MegaBlocks MoE ๋ฑ์ ์ปค์คํ ์ปค๋๋ค์ด ์๋ก ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฐ์ฌ๋์ด ์์ด ์์กด์ฑ ์ฆ๊ฐ์ CUDA/C++ ์ปดํ์ผ ์๊ตฌ์ฌํญ์ด ๋ฐ์ํ๋ค. ๊ฐ ๋ชจ๋ธ ํตํฉ ์๋ง๋ค ์๋ก์ด ์ปค๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ถ๊ฐํด์ผ ํ๋ ๊ตฌ์กฐ๋ก ์ธํด ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ณต์ก๋๊ฐ ์ฆ๊ฐํ๊ณ ์์๋ค.
Technical Solution
- Zero-build Kernels ํจํค์ง ๋์
: Hub์์ ์ฌ์ ์ปดํ์ผ๋ ์ปค๋ ๋ฐ์ด๋๋ฆฌ๋ฅผ ๋ค์ด๋ก๋ํ๊ณ
@use_kernel_forward_from_hub()๋ฐ์ฝ๋ ์ดํฐ๋ก ์๋ ์ ํํ๋ ๊ตฌ์กฐ๋ก ๋ณ๊ฒฝ - Liger RMSNorm ์ปค๋ ํตํฉ:
@use_kernel_forward_from_hub("RMSNorm")๋ฐ์ฝ๋ ์ดํฐ๋ก ์ ๊ทํ ์ฐ์ฐ ์ต์ ํ - MegaBlocks MoE ์ปค๋ ํตํฉ:
@use_kernel_forward_from_hub("MegaBlocksMoeMLP")๋ฐ์ฝ๋ ์ดํฐ๋ก Mixture of Experts ์ฐ์ฐ ๊ฐ์ - Flash Attention 3 ํตํฉ: Attention Sinks๋ฅผ ์ง์ํ๋ Flash Attention 3 ์ปค๋์ Hopper ์ํคํ ์ฒ ๋์์ผ๋ก ์ถ๊ฐ
- MXFP4 ์์ํ ์ปค๋ ์ถ๊ฐ: Triton ๊ธฐ๋ฐ MXFP4 ์์ํ ์ฐ์ฐ์ ์ปค์คํ ์ปค๋๋ก ์ ๊ณต
- ๋๋ฐ์ด์ค ์๋ ๋ก๋ฉ ์ต์ ํ:
device_map="auto"๋๋ Tensor Parallel ์คํ ์ ๋ฉํฐ GPU ๋ก๋ฉ ์๋ ๊ฐ์ - ์ปค๋ฎค๋ํฐ ์ปค๋ ์๋ ์ ํ: CUDA/ROCm ์ฌ๋ถ ๋ฐ ํ๋ จ/์ถ๋ก ๋ชจ๋์ ๋ฐ๋ผ ํธํ ์ปค๋์ ์๋ ์ ํ
Impact
- PyTorch 2.0์ torch.compile๊ณผ TorchInductor ๋ฐฑ์๋๋ 2~10๋ฐฐ ์ฑ๋ฅ ํฅ์ ์ ๊ณต
- ์ปค์คํ ์ปค๋ ์ฌ์ฉ ์ ๋ ํฐ ๋ฐฐ์น ํฌ๊ธฐ์์ ์ต์ ์ฑ๋ฅ ๋ฌ์ฑ (Figure 1 ๋ฒค์น๋งํฌ ๊ฒฐ๊ณผ)
Key Takeaway
์ปค์คํ ์ปค๋์ ์ค์ ๋ฆฌํฌ์งํ ๋ฆฌ(Hub)์์ ์ฌ์ ์ปดํ์ผ ๋ฐ์ด๋๋ฆฌ๋ก ๋ฐฐํฌํ๊ณ ๋ฐ์ฝ๋ ์ดํฐ ํจํด์ผ๋ก ์ถ์ํํ๋ฉด, ์์กด์ฑ ์ฆ๊ฐ์ ์ปดํ์ผ ์ค๋ฒํค๋๋ฅผ ์ ๊ฑฐํ๋ฉด์๋ ์ฌ๋ฌ ๋ชจ๋ธ์์ ์ฌ์ฌ์ฉ ๊ฐ๋ฅํ ์ต์ ํ ๊ธฐ๋ฒ์ ํ์ฐํ ์ ์๋ค. ์ด๋ ์ปค๋ฎค๋ํฐ ๊ธฐ์ฌ ์ปค๋์ ์ฐธ์กฐ ๊ตฌํ์ผ๋ก ์ ๊ณตํจ์ผ๋ก์จ MLX, llama.cpp, vLLM ๊ฐ์ ๋ค๋ฅธ ํ๋ ์์ํฌ์ ํ์ต ์๋ฃ๋ก๋ ํ์ฉ๋๋ค.
์ค์ฒ ํฌ์ธํธ
GPT-OSS ๊ฐ์ ๋๊ท๋ชจ ์ธ์ด๋ชจ๋ธ์ ์ด์ํ๋ ํ์์ `AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True)`๋ก ๋ก๋ฉํ๋ฉด ์ถ๊ฐ ์์กด์ฑ ์ค์น ์์ด Liger RMSNorm, MegaBlocks MoE, Flash Attention 3 ๋ฑ์ ์ปค์คํ ์ปค๋์ด ์๋ ๋ค์ด๋ก๋ยท์ ์ฉ๋์ด ๋ฐฐ์น ํฌ๊ธฐ์ ๋ฐ๋ผ ์ถ๋ก ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์๋ค. ๋ค๋ง MXFP4 ์์ํ ์ปค๋ ์ฌ์ฉ ์์๋ bfloat16 ํ์ ์ถ๋ก ์ผ๋ก ์ ํ๋๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ์ ์ฒ๋ฆฌ๋ ํธ๋ ์ด๋์คํ๋ฅผ ๋ฒค์น๋งํฌํด์ผ ํ๋ค.