Responsive Advertisement

Mamba/S4의 핵심 병목: '순환 관계'를 '병렬 연관 스캔'으로 해결하는 방법

최근 LLM(거대 언어 모델) 분야에서 Mamba와 같은 상태 공간 모델(SSM)은 트랜스포머의 $O(N^2)$ 계산 복잡도를 $O(N)$으로 혁신적으로 낮추며 큰 주목을 받고 있습니다.

하지만 이 놀라운 성능은 단순한 이론에 그치지 않습니다. 실제 하드웨어, 특히 GPU에서 이 성능을 구현하려면 모델의 순차적인 '순환 관계'가 병렬 처리에 심각한 병목이 되는 문제를 해결해야 합니다.

이 글은 Mamba의 핵심 병목을 해결하는 '병렬 결합 스캔(Parallel Associative Scan)'의 수학적 원리를 파헤치고, 이를 CUDA 및 Triton으로 구현하여 하드웨어의 한계까지 성능을 끌어올리는 전문화된 커널 최적화 전략을 깊이 있게 다룹니다.

1. Mamba/S4의 핵심 병목: 순환 관계

Mamba의 핵심은 선택적 SSM입니다.

이 모델은 다음과 같은 이산 상태 공간 방정식으로 정의되는 선형 순환 관계를 포함합니다.

$$h_t = \mathbf{A} h_{t-1} + \mathbf{B} x_t$$ $$y_t = \mathbf{C} h_t + \mathbf{D} x_t$$

여기서 $h_t$는 상태 벡터, $x_t$는 입력, $\mathbf{A}$와 $\mathbf{B}$는 시퀀스의 길이에 따라 동적으로 변하는(선택적) 매개변수입니다.

트랜스포머와 달리, 이 계산은 본질적으로 순차적(Sequential)입니다. 즉, $h_t$를 계산하려면 반드시 $h_{t-1}$이 필요합니다.

GPU는 병렬 처리에 최적화되어 있으므로, 이 순차적인 순환 관계는 GPU 활용률을 저해하는 심각한 병목이 됩니다. 우리는 이 순차적 계산을 병렬 연산으로 변환해야 합니다.

2. 병렬 결합 스캔 (Parallel Associative Scan)의 원리

순환 관계를 병렬화하기 위한 핵심 기법이 바로 병렬 결합 스캔입니다.

결합 스캔(Associative Scan)은 흔히 접두사 합(Prefix Sum) 연산의 일반화된 형태로, 연산자 $\oplus$가 결합 법칙($a \oplus (b \oplus c) = (a \oplus b) \oplus c$)을 만족한다면, 순차적인 계산을 병렬로 수행할 수 있음을 의미합니다.

SSM의 순환식은 다음과 같은 형태로 변환될 수 있습니다.

각 시간 단계 $t$에서, 새로운 쌍 $(A_t, B_t x_t)$를 정의하고, 이들을 순차적으로 '결합'하는 연산 $\oplus$를 다음과 같이 정의합니다.

$$ (A_i, B_i) \oplus (A_j, B_j) = (A_j A_i, B_j + A_j B_i) $$

이 매트릭스 결합 연산은 결합 법칙을 만족하며, 전체 시퀀스에 대한 상태 $h_T$는 이 결합 연산의 접두사 합(Prefix Sum)을 통해 계산됩니다.

따라서 $T$ 길이의 시퀀스 계산을 $O(\log T)$ 시간 복잡도로 병렬 처리할 수 있게 됩니다. 이것이 Mamba/S4의 고속 시퀀스 처리의 수학적 근간입니다.

힐리스-스테일 알고리즘(Hillis-Steele) 기반 구현

GPU 상에서 병렬 결합 스캔을 구현하는 표준 방식 중 하나는 힐리스-스테일(Hillis-Steele) 또는 래드너(Radix) 알고리즘을 변형하여 사용하는 것입니다.

힐리스-스테일 방식은 다음 의사 코드와 같이 구현됩니다. 여기서 각 스레드는 시퀀스의 한 요소를 담당합니다.

// P: 시퀀스 길이, N: 스레드 개수 (N = P)
// V[i] = (A_i, B_i)

for (d = 1; d < P; d *= 2) {
    if (i >= d) {
        // 병렬 결합 연산 수행
        V[i] = V[i - d] ⊕ V[i];
    }
    __syncthreads();
}

실제 Mamba 커널에서는 $\mathbf{A}$와 $\mathbf{B}$가 고정된 $D_{state}$ 차원을 가지므로, 이 연산은 작은 행렬 곱셈과 벡터 덧셈으로 구성됩니다.

효율적인 구현을 위해서는 이 연산을 GPU의 레지스터나 공유 메모리(Shared Memory) 내에서 완전 병렬로 처리해야 합니다.

3. 커널 최적화와 특수화 (Kernel Specialization)

Mamba SSM 블록의 커널 퓨전

Mamba 커널 최적화의 핵심은 커널 퓨전(Kernel Fusion)입니다.

전통적인 LLM 파이프라인은 여러 개의 작은 연산(LayerNorm, Activation, MatMul)을 연속적인 CUDA 커널로 실행했습니다. 이 과정에서 발생하는 GPU 전역 메모리 접근(Global Memory Access) 오버헤드는 성능을 크게 저하시킵니다.

Mamba SSM 블록에서는 이 순환 계산과 관련된 모든 요소(DPLR 구조의 행렬 초기화, 입력 변수 선택, 순환 계산, 출력 프로젝션)를 하나의 거대한 커널로 합쳐야 합니다.

특히 Triton을 사용할 경우, 이 퓨전은 상대적으로 쉽습니다. Triton은 파이썬 코드를 기반으로 GPU 코드를 생성하며, 컴파일러가 자동으로 인접한 연산을 하나의 커널 내에서 레지스터 또는 공유 메모리를 사용하여 처리하도록 최적화합니다.

이를 통해 전역 메모리 접근을 최소화하고, 산술 강도(Arithmetic Intensity)를 극대화합니다.

메모리 계층 구조 최적화 (Shared Memory, Registers)

병렬 결합 스캔 커널에서 $D_{state}$ 차원(일반적으로 16 또는 64)은 비교적 작습니다. 이는 큰 이점입니다.

우리는 이 작은 행렬/벡터 연산을 GPU의 가장 빠른 메모리 계층인 레지스터(Registers)와 공유 메모리(Shared Memory)에 고정하여 처리해야 합니다.

Register Blocking: $A$와 $B$의 데이터를 워프(Warp) 내 스레드들이 레지스터에 분배하여 캐시 미스 없이 즉시 접근하도록 합니다.

Shared Memory Caching: 시퀀스 길이 $T$가 매우 길 경우, 전체 시퀀스를 한 번에 처리하기 어렵습니다. 시퀀스를 청크(Chunk)로 나누고, 각 청크의 최종 상태(Carry-out state)를 다음 청크의 초기 상태(Carry-in state)로 전달할 때, 이 경계 상태(Boundary State)를 공유 메모리에 저장하여 빠르게 로드해야 합니다.

4. CUDA와 Triton을 활용한 구현 전략

Triton을 이용한 고수준 병렬화

최근 Mamba 구현체들은 성능과 개발 용이성을 위해 C++ 기반 CUDA 대신 Triton을 적극적으로 사용하고 있습니다.

Triton은 병렬 스캔 커널 구현에 강력한 도구입니다. Triton의 @triton.jit 데코레이터를 사용하면, 개발자는 메모리 로드/스토어와 계산 로직에 집중하고, 블록 및 워프 스케줄링 같은 저수준 세부 사항은 Triton 컴파일러에 맡길 수 있습니다.

특히, Triton은 tl.associative_scan과 같은 내장 함수를 제공하지 않더라도, 개발자가 Section 2에서 설명한 힐리스-스테일 스타일의 병렬 루프를 명시적으로 작성할 수 있게 합니다.

Triton은 이를 GPU 하드웨어에 최적화된 형태로 변환하여, CUDA C++로 직접 구현하는 것과 유사하거나 때로는 더 나은 성능을 제공합니다.

상태 공간 차원(D_state)에 대한 병렬 처리

Mamba 커널의 병렬화는 두 가지 차원에서 이루어집니다:

1. 시퀀스 차원 (Sequence Length, T): 병렬 결합 스캔을 통해 $T$ 차원을 $O(\log T)$로 가속합니다.

2. 상태 차원 (State Dimension, $D_{state}$): $A$ 행렬과 $B$ 벡터의 연산은 $D_{state}$ 차원에 대한 독립적인 연산으로 간주될 수 있습니다. 여러 개의 워프를 할당하여 $D_{state}$ 차원에 걸쳐 병렬로 연산을 분배합니다. 예를 들어, $D_{state}=64$인 경우, 여러 스레드(혹은 워프)가 64개의 요소를 나누어 처리하여 행렬 곱셈을 병렬화합니다.

성공적인 Mamba 커널은 이 두 가지 차원의 병렬화를 완벽하게 오버랩(Overlap)하고 커널 퓨전을 통해 메모리 접근을 최소화하는 데 달려 있습니다. 이를 통해 시퀀스 처리 시의 메모리 대역폭(Bandwidth) 제한을 극복합니다.

결론: 차세대 모델 가속화의 청사진

Mamba와 S4의 등장은 LLM 최적화의 패러다임을 메모리 대역폭 의존적인 GEMM 연산에서, 산술 연산 의존적인 병렬 결합 스캔 커널로 전환시켰습니다.

더 이상 전문화된 커널 최적화는 선택 사항이 아닙니다.

고성능 AI 모델 구현을 위해서는 아키텍처의 수학적 특성(여기서는 결합 법칙)을 깊이 이해하고, Triton과 같은 유연한 도구를 활용하여 GPU의 레지스터 수준까지 성능을 끌어내리는 정교한 커스터마이징이 필수적입니다.

이 기술은 Mamba뿐만 아니라, 향후 등장할 시퀀스 처리 기반의 모든 새로운 신경망 아키텍처에 적용될 수 있는 가속화의 핵심 청사진을 제시합니다.

결국, 개발자들은 이제 순환적 병목을 해결하는 독창적인 커널 설계 능력으로 차세대 AI 모델의 효율성을 결정하게 될 것입니다.