피드로 돌아가기
Dev.toAI/ML
원문 읽기
Symmetric Pooling으로 512K 컨텍스트 전방 패스 21배 가속 및 학습 시간 30% 단축
Lighthouse Attention: The Training-Time Hierarchy That Makes Quadratic Attention Practical Again
AI 요약
Context
FlashAttention 도입 후에도 Scaled Dot-Product Attention의 $\Theta(N^2)$ 연산 복잡도로 인한 Compute Bottleneck 지속. 기존 Sparse Attention 방식은 Query의 Full Resolution 유지로 인한 $O(NSd)$ 복잡도 유지 및 Custom Kernel 의존성에 따른 최적화 제약 발생.
Technical Solution
- Q, K, V 전체를 L-level Pyramid 구조로 Symmetric Pooling 하여 연산 복잡도를 $O(S^2d)$로 획기적 감소
- Attention Kernel 외부에서 Pyramid Pooling, $\ell_2$-norm Scoring, Chunked-bitonic Top-K Selection을 수행하는 4단계 파이프라인 설계
- Non-differentiable Top-K 선택을 통해 Gradient가 선택된 Q, K, V 엔트리에만 흐르게 하여 가중치 최적화 유도
- Chunked-bitonic 메커니즘을 적용하여 특정 구간으로의 Attention Collapse를 방지하고 층별 선택 분포의 다양성 확보
- 최하단 Pyramid Level 전체를 유지함으로써 모든 포지션에 최소 한 개의 기여자를 보장하는 구조적 안정성 확보
- Lighthouse 학습 후 Dense SDPA로 전환하는 2단계 학습 레시피를 통해 Inference 시의 Recoverability 보장
실천 포인트
1. 32K 이상의 Long Context Pretraining 시 Pyramid Pooling 기반의 QKV 동시 압축 검토
2. Sparse 학습 도입 시 추론 단계의 Dense 모델 전환을 위한 2단계 학습(Sparsified $\rightarrow$ Dense) 파이프라인 구축
3. 커스텀 커널 개발 비용을 줄이기 위해 Selection Logic을 표준 Attention Kernel 외부에 배치하는 구조 채택