피드로 돌아가기
Improving Hugging Face Training Efficiency Through Packing with Flash Attention 2
Hugging Face BlogHugging Face Blog
Backend

Hugging Face가 DataCollatorWithFlattening과 Flash Attention 2를 결합해 패딩 제거 시퀀스 학습에서 2배 처리량 향상 달성

Improving Hugging Face Training Efficiency Through Packing with Flash Attention 2

2024년 8월 21일9intermediate

Context

미니배치 학습 시 입력 시퀀스 길이를 맞추기 위해 패딩을 추가하는 방식은 불필요한 패딩 토큰으로 인한 계산 낭비를 초래한다. 기존 패킹(packing) 구현은 Flash Attention 2 사용 시 시퀀스 경계 정보를 고려하지 않아 의도하지 않은 크로스-예제 어텐션이 발생하고 수렴 품질이 저하되었다.

Technical Solution

  • 시퀀스 경계 인식 기능 추가: DataCollatorWithFlattening을 통해 패킹 중 시퀀스 경계 정보를 유지하면서 Flash Attention 2와 호환성 확보
  • 누적 시퀀스 길이 계산: flash_attn_varlen_func를 사용해 미니배치별 누적 시퀀스 길이(cu_seqlens)를 계산하여 경계 인식 어텐션 구현
  • position_ids 활용: 모델에서 노출된 position_ids를 Flash Attention 2에 전달하여 패딩 제거 후에도 정확한 위치 정보 유지
  • Transformers 및 TRL 라이브러리 지원: Trainer 사용자는 DataCollatorWithFlattening 적용, SFTTrainer 사용자는 padding_free=True 플래그 설정으로 간편 활용
  • 모델 호환성: Llama 2/3, Mistral, Mixtral, Granite, DBRX, Falcon, Gemma, OLMo, Phi 1/2/3, Qwen 2, StableLM, StarCoder 2 등 14개 모델 지원

Impact

  • FLAN 데이터셋 처리량: Llama2-7B, Mistral-7B, Granite-8B-Code에서 2배 증가(A100-80 GPU 8대 기준)
  • OrcaMath 데이터셋 처리량: 1.4배 증가(예제 길이 분산이 낮은 경우)
  • 메모리 사용량: FLAN 데이터셋에서 20% 감소, OrcaMath 데이터셋에서 6% 감소
  • 수렴 품질: 패딩 방식과 동일한 검증 손실(validation loss) 유지, 학습 수렴에 영향 없음

Key Takeaway

시퀀스 경계 인식을 유지하면서 패킹을 구현하면 계산 효율성 손실 없이 처리량과 메모리 사용량을 동시에 개선할 수 있다. 예제 길이 분산이 클수록 패킹의 이점이 두드러지므로 데이터 특성에 따라 이득을 예측할 수 있다.


Hugging Face Transformers나 TRL을 사용해 대규모 언어 모델을 학습하는 엔지니어는 DataCollatorWithFlattening 또는 padding_free=True 옵션을 적용하면, 특히 시퀀스 길이 분산이 높은 instruction tuning 데이터셋에서 처리량을 2배까지 증가시키면서 수렴 품질은 유지할 수 있다.

원문 읽기