피드로 돌아가기
“Llama 3.2 in Keras”
Hugging Face BlogHugging Face Blog
AI/ML

Keras가 Llama 3.2 모델을 Hugging Face 체크포인트에서 즉시 로드 및 실행 지원으로 별도 변환 작업 제거

“Llama 3.2 in Keras”

2024년 10월 21일10intermediate

Context

Llama 3.2가 Hugging Face Transformers에 출시되었을 때, 다른 프레임워크에서 모델을 사용하려면 추가적인 변환 작업이 필요했다. Keras 사용자들도 새로운 모델 버전을 사용하기까지 기다려야 한다는 우려가 있었다.

Technical Solution

  • Hugging Face safetensors 체크포인트에서 직접 로드: Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct")를 통해 필요시 온-더-플라이 변환 수행
  • Multi-backend 지원: 환경 변수 KERAS_BACKEND를 통해 JAX, PyTorch, TensorFlow 간 동적 전환 가능
  • keras-hub 모델 라이브러리 제공: Llama3, Gemma, StableDiffusion, Segment Anything 등 사전학습 모델의 정식 Keras 구현 제공
  • 토크나이저 포함 완전 구현: model.generate("string")model.fit(strings) 직접 지원으로 문자열 기반 학습 가능
  • 분산 모델 병렬화: 8B 파라미터 모델을 여러 가속기에 샤딩하여 로드, get_layout_map(device_mesh) API로 기본 샤딩 전략 제공

Impact

  • 8B Llama 모델 파인튜닝을 Google TPU v5e에서 약 8분 내에 완료
  • 커스텀 레이아웃맵 적용 시 epoch당 처리 시간 62초에서 54초로 단축 (약 13% 개선)

Key Takeaway

Keras의 다중 백엔드 아키텍처와 Hugging Face 생태계 통합은 모델 프레임워크 간 호환성 문제를 제거하며, JAX의 XLA 컴파일러를 통한 분산 병렬화는 대규모 모델의 빠른 학습을 가능하게 한다.


LLM 파인튜닝이 필요한 엔지니어는 Keras의 `from_preset()` API를 통해 Hugging Face 체크포인트를 프레임워크 변환 없이 직접 로드하고, `KERAS_BACKEND` 환경 변수로 JAX 백엔드를 선택하여 XLA 컴파일 최적화를 활용할 수 있다. 또한 `get_layout_map(device_mesh)`로 자동 생성된 모델 병렬화 설정을 기반으로 커스텀 레이아웃맵을 작성하면 TPU/GPU 클러스터에서 메모리 초과 없이 8B 이상 모델을 학습할 수 있다.

원문 읽기