Sequence Modeling with State Space Models

Taewan Cho
16 min readMay 14, 2024

--

https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1244/slides/cs224n-2024-lecture18-deployment-and-efficiency.pdf

딥 시퀀스 모델은 입력 시퀀스를 처리하여 출력 시퀀스를 생성하는 모델로, 자연어 처리, 음성 인식, 시계열 예측 등 다양한 분야에서 활용되고 있습니다. RNN, CNN, Transformer 등이 대표적인 딥 시퀀스 모델이며, 각 모델은 뚜렷한 장단점을 가지고 있습니다.

예를 들어 RNN은 순차 데이터 처리에 효과적이지만 학습 속도가 느리고, Transformer는 장거리 의존성을 모델링하는 데 뛰어나지만 계산 비용이 많이 듭니다.

딥 시퀀스 모델이 직면한 주요 과제는 범용성, 계산 효율성, 그리고 장거리 의존성입니다. 즉, 다양한 작업과 도메인에 적용 가능하면서도 효율적으로 계산되고, 긴 시퀀스 내의 멀리 떨어진 요소 간의 관계를 효과적으로 모델링할 수 있는 모델을 개발하는 것이 중요합니다.

지금 가장 많이 사용되는 Transformer 매커니즘의 핵심은 Attention입니다.(정확히는 self-attention) 이 매커니즘은 성능은 엄청나게 좋지만 연산량도 엄청나게 많다는 문제가 있습니다.

다음 토큰을 생성할 때, 이미 일부 토큰을 생성했더라도 전체 시퀀스에 대한 attention score를 다시 계산해야 합니다.

그래서 Attention을 바로 사용하는 것이 아니라 근사화 시키려는 시도도 많이 있었지만 성능이 그리 좋지 못했습니다. 그래서 연구자들은 새로운 아키텍쳐를 제안합니다.

전통적인 Classic RNN은 Vanishing gradient 문제가 있었습니다. Vanishing gradient는 보통 Activation Function에서 생깁니다. Activation Function이 없이 Linear한 연산만 반복하더라도 S4에서는 모델의 성능, 시간복잡도 & 메모리 효율이 보장될 수 있다고 합니다.(transformer에 대항하는 새로운 아키텍쳐!!)

반면 Transformer는 많은 연결을 가지고 있어서 강력한 성능을 보여주지만 공간 효율적이지도 못하고, 시간복잡도가 O(n²)입니다. 이러한 문제점들을 개선하고자 새로운 모델이 제안됐습니다.

CTM, RNN, CNN의 각 장점을 결합하려고 시도한 모델이 State Space Models (SSM)입니다. 병렬화 가능한 훈련(Convolution을 활용하여 RNN의 단점 극복)과 빠른 추론(RNN의 장점)이 가능합니다. 또한 100만 토큰 이상을 처리할 수 있고, Transformer의 성능과 비슷합니다.

https://stacks.stanford.edu/file/druid:mb976vf9362/gu_dissertation-augmented.pdf

연속 모델(Continuous Model) 표현 (왼쪽):

  • SSSM은 SSM (A, B, C)을 시간 척도 매개변수 ∆로 이산화(연속적인 데이터를 불연속적인 데이터로 변환)하여 정의되는 (uk) ∈ Rᴸ → (yt) ∈ Rᴸ의 맵핑입니다.
  • 암시적 연속 모델로서, SSSM은 다른 단계 크기 ∆로 이산화하여 불규칙적으로 샘플링된 데이터를 처리하는 등의 기능을 획득합니다.
  • 연속 모델 표현은 불규칙적으로 샘플링된 데이터를 처리할 수 있습니다.

순환 모델(Recurrent Model) 표현 (가운데):

  • 순환 모델로서, 선형 순환(Linear Recurrence)을 펼쳐서 레이어를 시간 방향으로(Time-wise) 계산함으로써 (즉, 한 번에 하나의 수직 슬라이스 (ut, xt, yt), (ut+1, xt+1, yt+1), …) 효율적으로 추론을 수행할 수 있습니다.
  • 순환 모델 표현은 무한한 컨텍스트를 다룰 수 있고 효율적인 추론이 가능합니다.

컨볼루션 모델(Convolutional Model) 표현 (오른쪽):

  • 컨볼루션 모델로서, 특정 필터와 컨볼루션하여 레이어를 깊이 방향으로(Depth-wise) 병렬 계산함으로써 (즉, 한 번에 하나의 수평 슬라이스 (ut), (yt), …) 효율적으로 학습을 수행할 수 있습니다.
  • 컨볼루션 모델 표현은 지역 정보를 활용할 수 있고 병렬화된 학습이 가능합니다.

그렇다면 이 구조에 대해서 조금 더 자세하게 이해해봅시다.

State Space Sequence Models

SSM은 1960년대 칼만 필터(Kalman Filter)부터 시작하여 많은 과학 분야에서 광범위하게 사용되어 왔습니다. 그러나 여기서 다루는 SSM을 시퀀스 모델(Sequence Model)로 취급하는 방식은 기존의 방식과는 상당히 다릅니다.(SSM은 State Space Sequence Models, SSSM을 지칭하는 데 사용합니다.

Vanilla SSM(State Space Models)은 연속 시퀀스를 처리합니다.

https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state

상태 공간은 시스템은 가능한 상태들을 정의함으로써 문제를 수학적으로 표현하는 방법입니다. 따라서 “상태 공간”은 모든 가능한 위치(상태)의 지도입니다. 각 지점은 미로에서의 고유한 위치를 나타내며, 출구까지 얼마나 떨어져 있는지와 같은 구체적인 세부 사항을 포함합니다.

전통적으로, 시간 t에서 SSM은:

  • 입력 시퀀스 x(t): (ex 미로에서 왼쪽으로, 아래로 움직임)
  • 잠재 상태 표현 h(t): (ex 출구까지의 거리와 x/y 좌표)
  • 예측된 출력 시퀀스 y(t): (ex 출구에 더 빨리 도달하기 위해 다시 왼쪽으로 이동)

하지만, 한 번에 왼쪽으로 이동하는 것과 같은 이산 시퀀스를 사용하는 대신, 연속 시퀀스를 입력으로 받아 출력 시퀀스를 예측합니다.

SSM은 동적 시스템, 시간 t에서의 상태를 통해 두 방정식을 통해 예측될 수 있습니다.

상태 방정식(State equation)은 상태가 어떻게 변하는지(행렬 A를 통해)와 입력이 상태에 어떻게 영향을 미치는지(행렬 B를 통해)를 설명합니다.

출력 방정식(Output equation)은 상태가 어떻게 출력으로 변환되는지(행렬 C를 통해)와 입력이 출력에 어떻게 영향을 미치는지(행렬 D를 통해)를 설명합니다.

행렬 A, B, C, D는 배울 수 있는 매개변수로서도 일반적으로 언급됩니다.

이러한 두 방정식을 시각화 하면 다음과 같습니다.

이 과정을 단계별로 조금 더 자세하게 보면 상태 방정식(State equation)은 다음과 같이 표현할 수 있습니다.

State Representation 부분이 일반적인 신경망의 hidden state와 비슷합니다(지식을 표현하는 잠재 공간).

이후 출력 방정식을 출력 방정식(Output equation)을 적용하면 아래와 같은 그림이 됩니다.

행렬 D를 사용하여 입력에서 출력으로 직접 신호를 제공할 수 있습니다. 이는 스킵 커넥션(skip-connection)이라고도 불립니다. 따라서 SSM은 종종 스킵 커넥션 없이 아래와 같이 간주됩니다.

위 수식에 따라, SSSM은 입력 시퀀스 u ∈ Rᴸ×ᴹ에서 출력 시퀀스 y ∈ Rᴸ×ᴹ로의 시퀀스 간 변환(Sequence-to-Sequence Transformation)으로 정의됩니다. 이는 (A, B, C)와 추가 매개변수 ∆로 매개변수화됩니다.

일반적으로 SSM의 동역학(Dynamics)은 시간에 따라 변할 수 있습니다. 즉, 행렬 A ∈ Rᴺ×ᴺ와 벡터 B ∈ Rᴺ×¹, C ∈ R¹×ᴺ이 시간 t의 함수일 수 있다는 것입니다. 이는 위 방정식에서 이들이 t에 대한 함수로 표현되는 것을 의미합니다. 이러한 SSM을 Time-Varying SSM이라고 합니다.

하지만, 행렬 A와 벡터 B, C가 상수(Constant)일 때, 즉 시간에 따라 변하지 않을 때, SSM의 동역학은 시간에 대해 불변(Invariant)합니다. 이러한 SSM을 선형 시불변(Linear Time-Invariant, LTI) 시스템이라고 합니다.

LTI SSM은 오디오나 이미지와 같은 “연속적인(Continuous)” 도메인에서 잘 작동하지만, 텍스트 데이터에는 적합하지 않습니다. 여기서 다루는 SSM은 LTI SSM을 의미합니다(다음과 같은 수식으로 u(t)→y(t)를 할 수 있다).

The Continuous Representation (Discretization)

대부분의 실제 데이터는 연속적이지 않고 이산적이기 때문에 이산화(Discretization)하는 작업이 필요합니다.

따라서 연속 함수 u(t) 대신 입력 시퀀스 u = (u0, u1, u2, …)에 적용하기 위해서는 위 방정식을 이산화해야 합니다. 이를 위해서는 입력의 해상도를 나타내는 추가 단계 크기(Step Size) 매개변수 ∆가 필요합니다. 개념적으로, 입력 uk는 내재된 연속 신호 u(t)에서 균일한 간격으로 샘플링된 값으로 볼 수 있습니다. 여기서 uk = u(k∆)입니다. 이를 시간 척도(Timescale)라고도 합니다.

이산 시간 SSM은 순환(Recurrence) 또는 이산 컨볼루션(Discrete Convolution)으로 계산할 수 있습니다. 이산화를 설명하기 위해, 가장 간단한 방법은 잘 알려진 오일러 방법(Euler’s Method)입니다. 이 방법은 ODE x’(t) = f(x(t))를 1차 근사(First-Order Approximation) xk = xk-1 + ∆f(xk-1)로 변환합니다. 위 방정식에 대입해보면 다음과 같습니다.

여기서 Ā := I + ∆A이고 B̄ := ∆B는 이산화된 상태 매개변수입니다. 그러나 오일러 방법은 불안정할 수 있기 때문에, 영차 홀드(Zero-Order Hold, ZOH) 또는 쌍선형 변환(Bilinear Transform, Tustin’s Method)과 같은 더 정확한 방법이 일반적으로 사용됩니다. 이러한 방법들은 Ā와 B̄에 대한 대체 공식을 제공하며, 서로 교환하여 사용할 수 있습니다.(결과적으로 연속데이터를 이산화 시키는 것이 핵심입니다.)

이산화는 SSSM의 Forward Pass 계산 그래프에서 첫 번째 단계로 생각할 수 있습니다. 이산화는 연속 매개변수 (∆, A, B)를 이산 매개변수 (Ā, B̄)로 변환하여, 하위 계산에서는 (Ā, B̄)만 사용하게 됩니다. 이 변환은 Ā = fA(∆, A)와 B̄ = fB(∆, A, B)라는 공식으로 정의되며, 여기서 (fA, fB) 쌍은 이산화 규칙입니다.

The Recurrent Representation (Efficient Inference)

위와 같이 이산화 후, discrete SSM은 다음과 같이 정의됩니다.

여기서 이산화된 매개변수 Ā, B̄, C, uk, xk, yk의 모양(Shape)은 원래의 연속 시간 모델과 동일합니다. 위 방정식은 이제 함수에서 함수로의 맵핑 u(t) → y(t) 대신 시퀀스에서 시퀀스로의 맵핑 (uk) → (yk)을 나타냅니다. 또한, 상태 방정식은 이제 xk에 대한 순환(Recurrence)이 되었습니다. 이는 discrete SSM을 한 번에 한 단계씩 펼칠 수 있음을 의미하며, 기존의 RNN과 유사한 방식으로 계산할 수 있게 됩니다(활성화 함수를 사용하지 않지만 유사한 성능을 보여준다). RNN의 언어로 표현하면, xk ∈ Rᴺ은 전치 행렬(Transition Matrix) Ā를 가진 은닉 상태(Hidden State)로 볼 수 있습니다.

이러한 구조는 이전 상태만 알고 있으면 되기 때문에 CNN이나 Transformer와 같은 모델과 달리, 시간 단계당 일정한 계산과 공간만 사용하면서 (잠재적으로 무한한) 입력 시퀀스를 처리할 수 있습니다.

The Convolutional Representation (Efficient Training)

recurrent SSM는 순차적인 특성 때문에 현대 하드웨어에서 학습하기에 실용적이지 않습니다. 대신, LTI SSM과 continuous convolutions 사이의 잘 알려진 연결성이 있습니다. 따라서 discrete convolution으로 작성될 수 있습니다. 수식을 다음과 같이 작성할 수 있습니다.

y를 정리해서 다시 작성하면 다음과 같은 수식이 됩니다.

이는 커널에 대한 명시적 공식을 가진 단일 컨볼루션으로 벡터화할 수 있습니다.

위 방정식은 single (non-circular) convolution입니다. K를 SSM (convolution)필터 또는 커널, 또는 간단히 상태 공간 커널(State Space Kernel, SSK)이라고 부릅니다(convolution이기 때문에 병렬학습이 가능함).

SSM은 CNN의 선형 컨볼루션 레이어와 유사하게 해석될 수 있지만, SSM 커널은 무한 길이입니다.따라서 입력 길이 L에 맞추어 잘릴 수 있습니다.

Summary of SSM Representations

상태 공간 모델(SSM)은 연속적인 SSM을 이산화하는 단계가 필요합니다. 이는 두 가지 모드로 표현할 수 있고, 각각 “RNN 모드”와 “CNN 모드”라고 부르기도 합니다. 하지만 SSSM 레이어는 실제로는 신경망이 아니며, 단순히 선형 시퀀스 변환입니다. 따라서 RNN과 CNN이 아닙니다. SSM을 활용하여 근사화시킨 이 개념을 실제로 활용하기 위해 다양한 연구들이 진행됩니다.

HiPPO(NeurIPS 2020)

기본적인 SSM은 실제로는 성능이 좋지 않습니다. 그 주요 원인 중 하나는 시퀀스 길이에 따라 그레이디언트가 기하급수적으로 변하는 문제(즉, 그레이디언트 소실/폭발 문제)입니다. 이를 해결하기 위해, HiPPO 이론이 개발되었습니다.

행렬 A를 사용하여 최근 토큰을 잘 포착하고 오래된 토큰을 감소시키는 상태 표현을 구축합니다. (장기 기억을 잘 유지하기 위함)

HiPPO 이론은 연속 시간 메모리화를 위해 특정 행렬 A를 사용하여 상태 x(t)가 입력 u(t)의 과거 기록을 기억할 수 있도록 합니다. 이 행렬 중 가장 중요한 것이 HiPPO 행렬입니다.

위 수식은 상삼각 행렬 형태로, 대각선과 상삼각 요소로 구성됩니다. n과 k의 관계에 따라 다른 값을 가지며, 이는 입력의 과거 기록을 효과적으로 저장하는 역할을 합니다. f(t)를 c(t)에 대응시키는 함수를 hippo라고 했을 때, 이 hippo 함수는 SSM의 구조를 띠게 됩니다. 당시의 Transformer보다 성능이 좋았습니다.

S4(ICLR 2022)

이산 시간 SSM을 계산하는 데 있어서, 근본적인 병목 현상은 반복적인 행렬 곱셈을 포함한다는 점입니다. 저자는 DPLR(Diagonal plus low-rank)를 도입하여, Kernel K를 빠르게 구하는 과정을 유도합니다. DPLR 구조는 상태 공간 모델의 계산 효율성을 높이기 위한 특별한 행렬 구조입니다.

HiPPO 행렬은 현재 형태에서 DPLR이 아니지만, Normal Plus Low-Rank (NPLR) 구조를 가지고 있습니다. 이는 Unitary 행렬 V에 의해 대각화될 수 있습니다(Unitary 행렬은 길이 보존, 직교성, 스펙트럼 성질 등의 중요한 성질을 가지는 복소수 행렬입니다.). NPLR 행렬은 SSM 모델 관점에서 DPLR 행렬과 사실상 동일합니다.

HiPPO 행렬을 DPLR 형태로 변환하기 위해, 먼저 HiPPO 행렬을 NPLR 형태로 표현합니다. 그런 다음, 이를 대각화하여 DPLR 행렬을 추출합니다. https://srush.github.io/annotated-s4/

S4를 기점으로 SSM은 sequence-modeling에 적용되기 시작헀으며, 오늘날의 모델은 대부분 S4의 변형입니다.

이러한 모델을은 이후에 Mamba의 핵심 구조가 됩니다.

--

--