Unsloth๊ฐ Triton ์ปค๋ ๊ธฐ๋ฐ ์ต์ ํ๋ก LLM ํ์ธํ๋ ์๋ 2๋ฐฐ ํฅ์ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ 40% ๊ฐ์
Make LLM Fine-tuning 2x faster with Unsloth and ๐ค TRL
AI ์์ฝ
Context
LLM ํ์ธํ๋์ ๊ณ์ฐ ๋น์ฉ์ด ๋๊ณ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ๋ ๋ณ๋ชฉ ์์ ์ด๋ค. ํนํ QLoRA ๊ธฐ๋ฐ ํ์ธํ๋๋ ์ฌ์ ํ ์ถฉ๋ถํ ๋น ๋ฅด์ง ์์ ๊ฐ๋ฐ ์์ฐ์ฑ์ ์ ํํ๋ค.
Technical Solution
- Pytorch ๋ชจ๋์ Triton ์ปค๋๋ก ์ฌ์์ฑ: ์๋์ผ๋ก ์ญ์ ํ ๋จ๊ณ๋ฅผ ๋์ถํ๊ณ ์ต์ ํ๋ ์ฐ์ฐ์ผ๋ก ๋ณํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๊ฐ์ ๋ฐ ์ฐ์ฐ ์๋ ํฅ์
- FastLanguageModel.from_pretrained ๋ํผ ์ ๊ณต: ๋ชจ๋ธ ๋ก๋ฉ ์ ์๋์ผ๋ก ์ต์ ํ๋ ์ฐ์ฐ์ด ์ ์ฉ๋๋ฉฐ, ๊ธฐ์กด transformers API์ ํธํ
- QLoRA ์ด๋ํฐ ์๋ ๊ตฌ์ฑ: FastLanguageModel.get_peft_model์ผ๋ก ์ดํ ์ (q_proj, k_proj, v_proj, o_proj)๊ณผ MLP(gate_proj, up_proj, down_proj) ๋ ์ด์ด์ LoRA ์ ์ฉ
- TRL ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์์ ํตํฉ: SFTTrainer, DPOTrainer, PPOTrainer์ ์ง์ ํธํ๋์ด ๊ธฐ์กด ํ์ดํ๋ผ์ธ์์ ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅ
- RoPE ์ค์ผ์ผ๋ง ์๋ ์ฒ๋ฆฌ: ์ต๋ ์ํ์ค ๊ธธ์ด ์ค์ ์ ์๋์ผ๋ก ์์น ์ธ์ฝ๋ฉ์ด ํ์ฅ๋์ด ์ถ๊ฐ ๊ตฌํ ๋ถํ์
- 4๋นํธ ์ฌ์ ์์ํ ๋ชจ๋ธ ์ง์: Transformers 4.36+ ์์ ์ฌ์ ์์ํ๋ ๋ชจ๋ธ์ 4๋ฐฐ ๋น ๋ฅด๊ฒ ๋ก๋ํ๊ณ ๋ฉ๋ชจ๋ฆฌ ๋จํธํ 500MB ๊ฐ์
Impact
A100 40GB์์ Code Llama 34b๋ 1.94๋ฐฐ ๋น ๋ฅด๊ณ VRAM 22.7% ๊ฐ์, Llama-2 7b๋ 1.87๋ฐฐ ๋น ๋ฅด๊ณ 39.3% ๊ฐ์, Mistral 7b๋ 1.88๋ฐฐ ๋น ๋ฅด๊ณ 65.9% ๊ฐ์, Tiny Llama 1.1b๋ 2.74๋ฐฐ ๋น ๋ฅด๊ณ 57.8% ๊ฐ์ํ๋ค. ๋ฌด๋ฃ Google Colab T4 ์ธ์คํด์ค์์ Llama-2 7b๋ 1.95๋ฐฐ ๋น ๋ฅด๊ณ 43.3% ๋ฉ๋ชจ๋ฆฌ ๊ฐ์, Tiny Llama 1.1b๋ 3.87๋ฐฐ ๋น ๋ฅด๊ณ 73.8% ๊ฐ์ํ๋ค. ์ ์ฒด 59ํ ๋ฒค์น๋งํฌ์์ ์ต๋ 2.7๋ฐฐ ์๋ ํฅ์๊ณผ ์ต๋ 74% ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ์ ๋ฌ์ฑํ๋ค.
Key Takeaway
ํ์ธํ๋ ์ฑ๋ฅ ํฅ์์ ๊ทผ์ฌ ์์ด ์ปค๋ ์์ค์ ์ ํํ ์ต์ ํ๋ก ๋ฌ์ฑํ ์ ์์ผ๋ฉฐ, ๊ธฐ์กด ์ํ๊ณ(HF Hub, transformers, TRL)์์ ์์ ํธํ์ ์ ์งํ๋ฉด์ ํฌ๋ช ํ๊ฒ ๋์ ๊ฐ๋ฅํ๋ค๋ ์ ์ด ํต์ฌ์ด๋ค.
์ค์ฒ ํฌ์ธํธ
Llama, Mistral ๊ธฐ๋ฐ LLM์ ํ์ธํ๋ํ๋ ์์ง๋์ด๋ ๊ธฐ์กด SFTTrainer ๋๋ DPOTrainer ์ฝ๋๋ฅผ FastLanguageModel.from_pretrained์ get_peft_model๋ก ๊ฐ์ธ๊ธฐ๋ง ํ๋ฉด ์ ํ๋ ์์ค ์์ด ํ์ธํ๋ ์๋๋ฅผ
1.8~
3.8๋ฐฐ ํฅ์์ํค๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ 40~74% ์ค์ผ ์ ์๋ค.