Mamba: Liner-Time Sequence Modeling With Selective State Space
우선 Mamba를 이해하기 위해선 많은 사전지식이 필요합니다. 관련 자료들을 아래 링크에 첨부해두겠습니다.
- Sequence Modeling with State Space Models(필독)
- A Visual Guide to Mamba and State Space Models(대부분이 참고된 글)
- Annotated-s4
What Problem does it attempt to Solve?
딥러닝에서 흥미로운 응용 프로그램 대부분을 구동하는 기반 모델(Foundation Model)은 거의 Transformer 아키텍처와 그 핵심인 attention 모듈을 기반으로 합니다. 하지만 Transformer는 긴 시퀀스를 처리할 때 계산 효율성이 떨어진다는 문제가 있습니다. 이를 해결하기 위해 선형 attention, gated convolution, recurrent 모델, 구조화된 상태 공간 모델(SSM) 등 계산 복잡도가 낮은 다양한 아키텍처가 개발되었지만, 언어와 같은 중요한 분야에서 attention 만큼의 성능을 보여주지 못했습니다.
심지어 S4(Structured State Space Model)조차도 언어 모델링 및 생성에서 중요한 특정 작업에서는 성능이 떨어집니다.
반면 Mamba는 선택적 복사(Selective Copying) 및 유도 헤드(Induction heads)라는 두 가지 합성 작업으로 설명할 수 있습니다.
Selective Copying 작업에서, SSM의 목표는 입력의 일부를 복사하여 순서대로 출력하는 것입니다.
그러나 SSM은 LTI(선형 시간 불변성)때문에 이 작업에서 성능이 떨어집니다. 행렬 A, B, C는 SSM이 생성하는 모든 토큰에 대해 동일하기 때문입니다.
결과적으로, SSM은 내용 인식 추론(프롬프트에 대한 추론?)을 수행할 수 없습니다.
Induction heads 작업은 입력에서 발견된 패턴을 재현하는 것이 목표입니다.
위 예시에서는 기본적으로 한 번의 프롬프팅을 수행하고 있습니다. 여기에서 모델에게 “Q:” 후에 “A:” 응답을 제공하도록 “가르치려고” 시도합니다. 그러나 SSM이 LTI(시간 불변성)을 가지고 있기 때문에, 이전 토큰 중 어떤 것을 기억할지 선택할 수 없습니다.
다시 말해 입력 x가 무엇이든 관계없이 행렬 B는 정확히 동일하게 유지되므로 행렬B는 x와 독립적입니다. 마찬가지로, A와 C도 입력에 관계없이 고정되어 있습니다.
비교적으로, 이러한 작업들은 Transformer에게는 상대적으로 쉽습니다. Transformer는 입력 시퀀스에 따라Attention을 통해 동적으로 변경할 수 있습니다. SSM이 이러한 작업에서 성능이 떨어지는 것은 LTI SSM의 근본적인 문제, 즉 A, B, C 행렬의 정적인 특성으로 인한 내용 인식 문제(content-awareness)를 의미합니다.
Selectively Retain Information
SSM의 재귀적 표현은 전체 Context를 압축하는 매우 효율적인 작은 상태를 생성합니다. 그러나 Attention 행렬을 통해 Context를 전혀 압축하지 않는 트랜스포머 모델과 비교하면 훨씬 덜 강력합니다(대신 연산 효율성이 좋다?).
Mamba는 이 두가지 장점을 모두 달성합니다(데이터를 선택적으로 압축한다).
입력 문장을 가지고 있을 때, 종종 중요한 의미를 가지지 않는 정보, 예를 들어 정지 단어(stop words)가 종종 있습니다.
정보를 선택적으로 압축하려면 입력에 따라 매개변수가 달라져야 합니다. 이를 위해, 훈련 중 SSM의 입력과 출력의 차원을 먼저 살펴봅시다.
Structured State Space Model (S4)에서, 행렬 A, B, C는 입력과 독립적입니다. 차원 N과 D는 정적이며 변하지 않습니다.
대신, Mamba는 행렬 B, C 그리고 스텝 크기 ∆를 입력받도록해서 시퀀스 길이와 배치 크기를 포함합니다.
이제 모든 입력 토큰에 대해 서로 다른 B와 C 행렬을 가지고 있다는 것을 의미하고 content-awareness를 해결합니다.
참고: 행렬 A는 동일하게 유지됩니다. 위 수식과 같이 입력에 직접적으로 영향을 받는 B와 C에 대해 동적으로 영향을 받기 위함.
따라서 B와 C는 입력에 의존하여 hidden state에서 무엇을 유지하고 무엇을 무시할지 선택적(selectively)으로 선택합니다.
더 작은 스텝 크기 ∆는 특정 단어를 무시하고 대신 이전 맥락을 더 사용하는 데 초점을 맞추지만 더 큰 스텝 크기 ∆는 맥락보다 입력 단어에 더 집중합니다.
수도코드로 다시 보면 다음과 같습니다.
The Scan Operation
이러한 행렬이 이제 동적이기 때문에, 고정된 커널을 가정하는 컨볼루션 표현을 사용하여 계산할 수 없습니다. 따라서 컨볼루션의 병렬화를 잃고 재귀적 표현만을 사용할 수 있습니다(SSM의 장점 삭제).
재귀적으로 계산을 출력하는 과정을 보면 이전 상태를 가지고 있어야만 각 상태를 계산할 수 있으므로 불가능해 보입니다. 그러나 맘바는 병렬 스캔 알고리즘을 통해 이것을 가능하게 합니다. 우선 연산 순서가 중요하지 않다고 가정합니다.
Hardware-aware Parallel Scan Algorithm(병렬 스캔 알고리즘)은 Hidden State를 메모리에 저장하지 않고 병렬적으로 Scan 연산을 수행하는 방식입니다.
이 방식을 사용하기 위해서 저자들은 kernel fusion 기법을 활용하게 됩니다. kernel fusion은 입력과 파라미터를 GPU HBM에서 읽어오고 SRAM에 로드하는 것입니다. 결과적으로, 시퀀스를 부분적으로 계산하고 반복적으로 결합할 수 있습니다.
따라서 다음과 같은 구조가 됩니다.
중간 상태들은 저장되지 않지만, 그래디언트를 계산하기 위해 역방향 패스에서 필요합니다. 대신, 저자들은 역방향 패스 동안 이러한 중간 상태들을 재계산합니다. 이것이 비효율적으로 보일 수 있지만, 상대적으로 느린 DRAM에서 모든 중간 상태를 읽는 것보다 훨씬 덜 비용이 많이 듭니다.
이 아키텍처는 종종 selective SSM 또는 S6 모델로 언급되며, 이는 본질적으로 선택적 스캔 알고리즘으로 계산된 S4 모델입니다.
The Mamba Block
Mamba의 경우 H3 요소와 트랜스포머에서 자주 사용하는 Gated MLP를 섞은 형식으로 완성됩니다. 이때 σ는 Swish 활성화함수를 사용합니다. 기존에 사용하던 H3가 연속적인 신호에서 성능이 괜찮았지만 언어모델과 같은 이산형 시퀀스모델에서는 성능이 좋지 않았기 때문입니다.
H3Attention을 만들때 Query, Key, value를 만드는 것처럼, Shift SSM과 Diag SSM 이렇게 2개의 SSM을 쌓는 구조입니다.
따라서 mamba는 다음과 같은 구조가 됩니다.
위에서 본 selective SSM은 Transformer의 디코더 블록에서 Self-Attention을 나타내는 것과 같은 방식으로 블록으로 구현될 수 있습니다.
또한 여러 mamba 블록을 쌓을 수 있으며, 그 출력을 mamba 맘바 블록의 입력으로 사용할 수 있습니다.
따라서 selective SSM은 다음과 같은 특성을 가집니다:
- 이산화를 통해 생성된 재귀적 SSM
- 장거리 의존성을 포착하기 위한 HiPPO 초기화가 있는 행렬 A
- 정보를 선택적으로 압축하기 위한 선택적 스캔 알고리즘
- 계산 속도를 높이기 위한 하드웨어 인식 알고리즘
결론적으로는 빠른 추론과 훈련뿐만 아니라 무한한 맥락도 얻을 수 있습니다.
실제 성능을 비교하면 Selective Mechanism을 도입한 SSM과 그렇지 않은 SSM을 비교하면 S6(S4 + Selection)를 도입한 방식이 기존 S4보다 훨씬 더 좋은 성능을 보이고 있습니다. 따라서 SSM내부에서 Selective Mechanism이 훨씬 성능이 좋다는걸 알 수 있습니다.Induction Heads의 경우에도 기존에 멀티헤드 어텐션을 사용한 모델들보다도 Mamba가 훨씬 잘 유지하는 것을 성능으로 잘 보여주었습니다.
또한 성능 평가 지표를 보면 Mamba의 경우 RNN계열 혹은 GPT-2 계열 모델들과 비교했을 때 성능이 훨씬 좋은 모습이 나타납니다.
또한 Efficiency Benchmarks에서도 좋은 성능을 보여줍니다.
Mamba의 핵심 연산인 병렬 스캔(parallel scan) 이 표준 구현 방식보다 40배 빠른 속도를 보여줍니다. 이는 맘바가 GPU 메모리 계층 구조를 효율적으로 활용하는 하드웨어 인식 알고리즘을 사용하기 때문입니다.
또한 추론시 Mamba는 recurrent 모델로서 Transformer보다 5배 높은 처리량(throughput) 을 달성합니다. Transformer는 attention 계산을 위해 이전 토큰 정보를 저장하는 KV 캐시(cache)가 필요하지만, Mamba는 이러한 캐시 없이 recurrent하게 계산되므로 더 높은 배치 크기를 사용할 수 있습니다. 따라서 동일한 하드웨어에서 더 많은 토큰을 처리할 수 있어 추론 속도가 훨씬 빠릅니다.
결론적으로 본 논문에서는 구조화된 SSM에 Selection 메커니즘을 도입하여, 시퀀스 길이에 따라 선형적으로 확장되는 동시에 맥락 의존적 추론을 수행할 수 있도록 했습니다. Attention을 사용하지 않는 아키텍처에 통합된 Mamba는 다양한 도메인에서 최고의 결과를 달성하며, 강력한 Transformer 모델의 성능과 동등하거나 뛰어넘습니다.