Make LLM Fine-tuning 2x faster with Unsloth and ๐Ÿค— TRL
Hugging Face Blog
AI/ML

Unsloth speeds up LLM fine-tuning by roughly 2x and cuts memory usage by about 40% through Triton-kernel-based optimizations


January 10, 2024 · 7 min · intermediate

Context

LLM ํŒŒ์ธํŠœ๋‹์€ ๊ณ„์‚ฐ ๋น„์šฉ์ด ๋†’๊ณ  ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆฌ๋Š” ๋ณ‘๋ชฉ ์ž‘์—…์ด๋‹ค. ํŠนํžˆ QLoRA ๊ธฐ๋ฐ˜ ํŒŒ์ธํŠœ๋‹๋„ ์—ฌ์ „ํžˆ ์ถฉ๋ถ„ํžˆ ๋น ๋ฅด์ง€ ์•Š์•„ ๊ฐœ๋ฐœ ์ƒ์‚ฐ์„ฑ์„ ์ œํ•œํ•œ๋‹ค.

Technical Solution

  • PyTorch modules rewritten as Triton kernels: the backward pass is derived by hand and converted into optimized operations, reducing memory usage and speeding up computation
  • FastLanguageModel.from_pretrained wrapper: optimized kernels are applied automatically at model-load time, while staying compatible with the existing transformers API
  • Automatic QLoRA adapter setup: FastLanguageModel.get_peft_model applies LoRA to the attention (q_proj, k_proj, v_proj, o_proj) and MLP (gate_proj, up_proj, down_proj) layers
  • Full TRL library integration: directly compatible with SFTTrainer, DPOTrainer, and PPOTrainer, so it can be dropped into existing pipelines immediately
  • Automatic RoPE scaling: positional encodings are extended automatically when the maximum sequence length is set, with no extra implementation required
  • 4-bit pre-quantized model support: with Transformers 4.36+, pre-quantized models load 4x faster and memory fragmentation drops by about 500 MB
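The loading and adapter steps above can be sketched as follows. This is a minimal sketch, not a definitive recipe: the checkpoint name, LoRA rank, and other hyperparameters are illustrative, and a CUDA GPU with `unsloth` installed is assumed.

```python
from unsloth import FastLanguageModel

max_seq_length = 2048  # RoPE scaling is handled automatically for longer contexts

# Load a 4-bit pre-quantized base model; Unsloth's optimized Triton
# kernels are patched in automatically at load time.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b-bnb-4bit",  # illustrative pre-quantized checkpoint
    max_seq_length=max_seq_length,
    dtype=None,          # auto-detect: bfloat16 on Ampere+, float16 otherwise
    load_in_4bit=True,
)

# Attach QLoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank (illustrative)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
)
```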

Impact

A100 40GB์—์„œ Code Llama 34b๋Š” 1.94๋ฐฐ ๋น ๋ฅด๊ณ  VRAM 22.7% ๊ฐ์†Œ, Llama-2 7b๋Š” 1.87๋ฐฐ ๋น ๋ฅด๊ณ  39.3% ๊ฐ์†Œ, Mistral 7b๋Š” 1.88๋ฐฐ ๋น ๋ฅด๊ณ  65.9% ๊ฐ์†Œ, Tiny Llama 1.1b๋Š” 2.74๋ฐฐ ๋น ๋ฅด๊ณ  57.8% ๊ฐ์†Œํ–ˆ๋‹ค. ๋ฌด๋ฃŒ Google Colab T4 ์ธ์Šคํ„ด์Šค์—์„œ Llama-2 7b๋Š” 1.95๋ฐฐ ๋น ๋ฅด๊ณ  43.3% ๋ฉ”๋ชจ๋ฆฌ ๊ฐ์†Œ, Tiny Llama 1.1b๋Š” 3.87๋ฐฐ ๋น ๋ฅด๊ณ  73.8% ๊ฐ์†Œํ–ˆ๋‹ค. ์ „์ฒด 59ํšŒ ๋ฒค์น˜๋งˆํฌ์—์„œ ์ตœ๋Œ€ 2.7๋ฐฐ ์†๋„ ํ–ฅ์ƒ๊ณผ ์ตœ๋Œ€ 74% ๋ฉ”๋ชจ๋ฆฌ ์ ˆ๊ฐ์„ ๋‹ฌ์„ฑํ–ˆ๋‹ค.

Key Takeaway

ํŒŒ์ธํŠœ๋‹ ์„ฑ๋Šฅ ํ–ฅ์ƒ์€ ๊ทผ์‚ฌ ์—†์ด ์ปค๋„ ์ˆ˜์ค€์˜ ์ •ํ™•ํ•œ ์ตœ์ ํ™”๋กœ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ธฐ์กด ์ƒํƒœ๊ณ„(HF Hub, transformers, TRL)์™€์˜ ์™„์ „ ํ˜ธํ™˜์„ ์œ ์ง€ํ•˜๋ฉด์„œ ํˆฌ๋ช…ํ•˜๊ฒŒ ๋„์ž… ๊ฐ€๋Šฅํ•˜๋‹ค๋Š” ์ ์ด ํ•ต์‹ฌ์ด๋‹ค.


Engineers fine-tuning Llama- or Mistral-based LLMs can simply wrap their existing SFTTrainer or DPOTrainer code with FastLanguageModel.from_pretrained and get_peft_model to speed up fine-tuning by 1.8–3.8x and cut memory usage by 40–74%, with no loss in accuracy.

์›๋ฌธ ์ฝ๊ธฐ