피드로 돌아가기
KV Cache from scratch in nanoVLM
Hugging Face BlogHugging Face Blog
Backend

nanoVLM에서 KV Cache를 구현해 자동회귀 생성 중 불필요한 재연산을 제거함으로써 38% 속도 향상

KV Cache from scratch in nanoVLM

2025년 6월 4일12intermediate

Context

자동회귀 언어 모델은 각 새로운 토큰을 생성할 때마다 전체 시퀀스에 대해 Transformer의 모든 레이어를 통과시켜야 합니다. 이 과정에서 이전에 이미 계산된 Key와 Value 행렬이 변경되지 않음에도 불구하고 매번 처음부터 재연산되어, 시퀀스 길이에 따른 이차 복잡도의 메모리와 연산 오버헤드가 발생합니다.

Technical Solution

  • Self-Attention 메커니즘에서 K와 V 재연산 방지: 초기 프롬프트 처리 후 각 레이어의 Key와 Value를 캐시에 저장하고, 새로운 토큰 생성 시에만 해당 토큰의 K, V를 계산해 캐시에 추가
  • 생성 과정을 두 단계로 분리: PREFILL PHASE에서 전체 프롬프트를 인코딩하고 초기 캐시 구축, DECODE PHASE에서는 새 토큰의 Q만 계산하고 캐시된 K, V와 결합
  • 레이어별 KV 캐시 구조화: 각 레이어마다 "key"와 "value" 딕셔너리를 유지하며, 형태는 (batch_size, num_heads, seq_len_cached, head_dim)
  • LanguageModelGroupedAttention.forward 수정: Q, K, V의 전체 시퀀스 재연산 대신 캐시를 사용하고 업데이트하도록 변경
  • LanguageModel.forward 확장: start_pos 파라미터를 도입해 위치 인코딩 정확성을 보장하면서 상태 추적 기능 추가

Impact

  • 생성 단계의 시간 성능: 38% 속도 향상
  • 토큰당 연산 복잡도: O(sequence_length²)에서 O(sequence_length)로 감소

Key Takeaway

KV Cache는 자동회귀 모델의 생성 과정에서 PREFILL과 DECODE 두 단계로 명확히 분리함으로써 캐시 불일치 문제를 원천 차단하고, 단순한 메모리 트레이드오프(메모리 사용량 증가)로 선형 시간 생성을 달성할 수 있는 설계 패턴입니다. 다만 빔 서치 같은 고급 추론 기법의 적용이 제한되는 한계가 있습니다.


PyTorch로 Transformer 기반 언어 모델이나 멀티모달 모델을 구현하는 팀에서 KV Cache를 도입하면, 토큰당 연산량을 시퀀스 길이의 이차식에서 선형식으로 줄여 장문 생성 시 추론 지연을 크게 단축할 수 있으며, 이는 특히 소비자 하드웨어에서 모델을 실행해야 하는 상황에서 실행 가능성을 확보하는 핵심 최적화입니다.

원문 읽기