피드로 돌아가기
Hugging Face on PyTorch / XLA TPUs
Hugging Face BlogHugging Face Blog
AI/ML

Hugging Face가 PyTorch/XLA를 Trainer 모듈에 통합해 Cloud TPU에서 트랜스포머 학습을 네이티브 지원

Hugging Face on PyTorch / XLA TPUs

2021년 2월 9일12intermediate

Context

PyTorch 사용자들은 기존에 Cloud TPU에서 모델을 훈련하기 위해 별도의 복잡한 구현이 필요했다. Hugging Face와 Google TPU 팀의 협력으로 PyTorch-TPU 프로젝트가 2019년 공식 출시되었지만, 실제 통합은 아직 미흡한 상태였다.

Technical Solution

  • XLA 디바이스 타입 추가: PyTorch에 xm.xla_device() API를 통해 TPU/CPU/GPU를 추상화된 단일 인터페이스로 제공
  • TrainingArguments._setup_devices() 메서드 수정: XLA:TPU 감지 시 자동으로 TPU 디바이스를 반환하도록 구현
  • 옵티마이저 스텝 최적화: xm.optimizer_step(optimizer)를 통해 다중 TPU 코어 간 그래디언트 동기화 및 병렬 처리
  • 입력 파이프라인 최적화: MpDeviceLoader를 사용해 모델 그래프 추적(step n+1)과 현재 스텝 실행(step n)을 겹쳐서 처리
  • 체크포인트 저장/로드 메커니즘: xm.save() API로 XLA 텐서를 CPU 디바이스로 변환한 후 저장하고, 마스터 프로세스만 쓰기 수행

Impact

bert-large-uncased를 v3-8 Cloud TPU(4개 TPUv3 칩)에서 WikiText103으로 훈련: FP32 정밀도에서 178.4분 소요, BF16 정밀도에서 106.4분 소요 (약 40% 단축)

Key Takeaway

PyTorch 사용자들이 기존 Trainer 인터페이스를 유지하면서 TPU 하드웨어의 병렬 처리 특성을 활용하려면, 디바이스 추상화 계층(xla_device)과 그래디언트 동기화(optimizer_step), 입력 파이프라인 겹침(pipelining) 세 가지 핵심 메커니즘을 프레임워크 수준에서 통합해야 한다.


PyTorch 기반 트랜스포머 훈련을 수행하는 팀에서 Cloud TPU 인프라를 도입할 때, Hugging Face Trainer의 TrainingArguments에 `--use_xpu` 플래그만 추가하면 자동으로 XLA TPU 디바이스로 전환되고 그래디언트 동기화까지 처리되므로, 기존 코드 수정 없이 BF16 정밀도로 약 40% 훈련 시간을 단축할 수 있다.

원문 읽기