성능 최적화를 위한 Flash-Attention 2

Taewan Cho
13 min readMar 13, 2024

--

google에서 새로운 오픈소스 모델인 gemma를 공개했습니다.

공개한 성능도 좋았지만, 직접 사용해보면서 다른 7b모델보다 확실히 더 좋은 성능을 보여주는 것을 알 수 있었습니다.

모델을 튜닝시켜서 활용해보려고 했는데 모델 설명에 flash attention 2을 지원한다는 문구가 적혀있었습니다.

그래서 이번 글에서는 flash attention 2에 대해서 알아보려고 합니다.

우선 기본적인 attention의 원리에 대해서 이해해야 합니다.

Attention

어텐션의 기본 아이디어는 디코더에서 출력 단어를 예측하는 매 시점(time step)마다, 인코더에서의 전체 입력 문장을 다시 한 번 참고한다는 점입니다. 단, 전체 입력 문장을 전부 다 동일한 비율로 참고하는 것이 아니라, 해당 시점에서 예측해야할 단어와 연관이 있는 입력 단어 부분을 좀 더 집중(attention)해서 보게 됩니다.

https://wikidocs.net/22893

어텐션 함수는 주어진 ‘쿼리(Query)’에 대해서 모든 ‘키(Key)’와의 유사도를 각각 구합니다. 그리고 구해낸 이 유사도를 키와 맵핑되어있는 각각의 ‘값(Value)’에 반영해줍니다. 그리고 유사도가 반영된 ‘값(Value)’을 모두 더해서 리턴합니다. 여기서는 이를 어텐션 값(Attention Value)이라고 하겠습니다. Attention Value를 수식으로 표현하면 다음과 같습니다.

Q = Query : t 시점의 디코더 셀에서의 은닉 상태
K = Keys : 모든 시점의 인코더 셀의 은닉 상태들
V = Values : 모든 시점의 인코더 셀의 은닉 상태들
Attention(Q, K, V) = Attention Value

예시로, Dot-Product Attention에서는 인코더의 소프트맥스 함수에서 attention score를 계산합니다.

https://wikidocs.net/22893

Attention score란 현재 디코더의 시점 t에서 단어를 예측하기 위해, 인코더의 모든 은닉 상태 각각이 디코더의 현 시점의 은닉 상태 st와 얼마나 유사한지를 판단하는 값입니다. (위 그림에서는 디코더에 있는 오른쪽 방향으로 이동되는 화살표가 st를 의미합니다.)

디코더의 t시점의 은닉상태인 st를 전치하고, 인코더의 i번째 은닉 상태를 내적하면 어텐션 스코어를 구할 수 있습니다(attention score를 구하는 방법은 다양합니다.).

내적(dot product)를 하게되면 두 벡터간 유사도를 알 수 있습니다. 따라서 attention score가 높다는 것은 t시점에 디코더가 집중해야될 인코더 시점을 알 수 있습니다.

t시점의 st와 인코더의 모든 은닉 상태의 attention score의 모음을 et라고 정의하면 다음과 같은 수식으로 표현할 수 있습니다.

https://wikidocs.net/22893

이후 소프트맥스(softmax) 함수를 통해 어텐션 분포(Attention Distribution)를 구합니다.

https://wikidocs.net/22893

et에 소프트맥스 함수를 적용하여, 모든 값을 합하면 1이 되는 확률 분포를 얻어냅니다. 이를 어텐션 분포(Attention Distribution)라고 하며, 각각의 값은 어텐션 가중치(Attention Weight)라고 합니다.

그 다음에 각 인코더의 attention weight와 은닉 상태를 가중합(곱하고 모두 더한다)해서 attention value를 구합니다. 다시 attenion함수 Attention(Q, K, V)를 생각해보면 t시점일 때 t시점의 Q는 디코더의 은닉층, K는 모든 인코더의 은닉층, V는 attention weight와 은닉 상태를 곱한 것이 V가 됩니다. 이후 V를 모두 더해 attention value를 구하는 것입니다.

https://wikidocs.net/22893

이렇게 구한 at(attention value)는 인코더의 문맥을 포함하고 있다고해서, context vector라고도 불립니다. 기본 seq2seq에서는 인코더의 마지막 시점의 은닉 상태를 context vector로 넘겨주는데, attention 메커니즘은 디코더의 모든 step에서 다른 context vector를 사용하기 때문에 더 긴 문장에서도 문맥을 기억할 수 있습니다.

마지막으로, attention value와 디코더의 t시점의 은닉 상태를 concatenate(연결)하여 하나의 벡터로 만들고, 이 벡터를 최종 연산의 입력으로 사용하는 것이 attention 메커니즘의 핵심입니다.

https://wikidocs.net/22893

Self attention

트렌스포머에서 사용된 셀프 어텐션은 어텐션의 한 유형이며, 모든 어텐션 메커니즘의 기본 원리를 공유합니다. 셀프 어텐션은 특히 같은 입력 내에서의 요소 간 상호작용에 초점을 맞춥니다.

Attention의 디코더 셀의 은닉 상태가 Q이고 인코더 셀의 은닉 상태가 K라는 점에서 Q와 K가 서로 다른 값을 가지고 있었습니다. 그런데 셀프 어텐션에서는 Q, K, V가 전부 동일합니다. 트랜스포머의 셀프 어텐션에서의 Q, K, V는 아래와 같습니다.

Q : 입력 문장의 모든 단어 벡터들
K : 입력 문장의 모든 단어 벡터들
V : 입력 문장의 모든 단어 벡터들
https://ratsgo.github.io/nlpbook/docs/language_model/tr_self_attention/

위의 그림은 트랜스포머에 대한 구글 AI 블로그 포스트에서 가져왔습니다. 위의 예시 문장을 번역하면 ‘그 동물은 길을 건너지 않았다. 왜냐하면 그것은 너무 피곤하였기 때문이다.’라는 의미가 됩니다. 셀프 어텐션은 입력 문장 내의 단어들끼리 유사도를 구하기 때문에, 그것(it)이 동물(animal)과 연관되었을 확률이 높다는 것을 찾아냅니다.

“I am a student”라는 문장으로 self attention을 하기 위해서는 각 단어 벡터로부터 Q, K, V 벡터를 얻는 작업을 거칩니다.

https://wikidocs.net/22893

기존의 벡터로부터 더 작은 벡터는 가중치 행렬을 곱하므로서 완성됩니다. 위의 그림은 단어 벡터 중 student 벡터로부터 Q, K, V 벡터를 얻어내는 모습을 보여줍니다. 모든 단어 벡터에 위와 같은 과정을 거치면 각 토큰은 각각의 Q, K, V 벡터를 얻습니다.

https://wikidocs.net/22893

각 Q벡터는 모든 K벡터에 대해서 어텐션 스코어를 구하고, 어텐션 분포를 구한 뒤에 이를 사용하여 모든 V벡터를 가중합하여 어텐션 값 또는 컨텍스트 벡터를 구하게 됩니다. 그리고 이를 모든 Q벡터에 대해서 반복합니다. 트렌스포머에서는 위의 attention의 예시와 다르게 특정 값을 나누어서 스케일링 하는 attention score 함수를를 사용하는 Scaled dot-product Attention을 사용합니다.

https://wikidocs.net/22893

이제 어텐션 스코어에 소프트맥스 함수를 사용하여 어텐션 분포(Attention Distribution)을 구하고, 각 V벡터와 가중합하여 어텐션 값(Attention Value)을 구합니다. 이를 단어 I에 대한 어텐션 값 또는 단어 I에 대한 컨텍스트 벡터(context vector)라고도 할 수 있습니다. am에 대한 Q벡터, a에 대 Q벡터, student에 대한 Q벡터에 대해서도 모두 동일한 과정을 반복하여 각각에 대한 어텐션 값을 구합니다.

각 단어에 대한 Q, K, V 벡터를 구하고 스케일드 닷-프로덕트 어텐션을 수행하였던 위의 과정들은 벡터 연산이 아니라 행렬 연산을 사용하면 일괄 계산이 가능합니다.

https://wikidocs.net/22893

Q와 K를 내적하여 다음 그림과 같은 행렬을 구할 수 있습니다.

https://wikidocs.net/22893

이 행렬에 스케일링을 적용하고 V를 곱해주면 attention 행렬을 구할 수 있습니다.

https://wikidocs.net/22893

그리고 다양한 시점에서 문장을 이해하기 위해 multi-head attention(MHA)을 적용했습니다.

https://wikidocs.net/22893

Flash Attention

https://www.wonbeomjang.kr/blog/2023/fastattention/

Flash Attention은 대규모 어텐션 연산을 효율적으로 처리하기 위해 설계된 기법 중 하나입니다. 이는 특히 메모리 사용량을 줄이고, 처리 속도를 향상시키기 위해 “Tiling”과 “Recomputation”과 같은 전략을 사용합니다.

위 그림과 같이 HBM에 접근하는 횟수를 줄이는 것이 핵심입니다.

기본적인 attention은 다음과 같은 알고리즘으로 작동합니다.

  1. Q (Query)와 K (Key)를 HBM에서 불러와서 S (Score)를 계산합니다. 이것은 Query와 모든 Key의 유사도를 나타내며, S는 이어서 HBM에 저장됩니다.
  2. 계산된 Score S를 HBM에서 읽고, 이를 softmax 함수에 통과시켜 Probability 행렬 P를 계산합니다. 그리고 P도 HBM에 저장됩니다.
  3. P (Probability)와 V (Value)를 HBM에서 불러와서 최종 Output O를 계산합니다. O는 P와 V의 가중합으로, 이 결과도 HBM에 저장됩니다.
  4. 계산된 최종 Output O를 반환합니다.

이 알고리즘에 따르면 HBM에 6번은 접근해야 계산이 가능하다는 것을 알 수 있습니다.

Tiling

‘Tiling’은 대규모 행렬 연산을 수행할 때 전체 행렬을 한 번에 메모리에 로드하는 대신, 작은 블록(타일)으로 분할하여 각각을 순차적으로 처리하는 방식을 의미합니다. 이 접근 방식은 특히 GPU와 같은 하드웨어 가속기에서 계산할 때 유리하며, 메모리 사용을 최적화하고 병렬 처리를 향상시킬 수 있습니다.

Recomputation

메모리 사용을 줄이기 위해, 순전파(forward pass)에서 계산된 중간 결과를 메모리에 저장하지 않고, 역전파(backward pass)에서 필요할 때 다시 계산하는 방법을 말합니다.

forward pass에서는 Q×KT를 계산하여 어텐션 스코어를 구한 다음, 소프트맥스를 적용합니다. 이 때, 소프트맥스를 적용하기 전의 스코어는 보통 저장하지 않고 버립니다. Flash Attention에서는 이 스코어를 버리는 대신, 소프트맥스 정규화 통계를 메모리에 저장합니다. 이 때 저장된 소프트맥스 정규화 통계를 사용하여 원래의 어텐션 스코어를 다시 계산(recompute)합니다. 이 방법은 추가적인 계산을 필요로 하므로 FLOPs가 증가하지만, 고대역폭 메모리(HBM)에서 데이터를 다시 읽어야 하는 횟수가 줄어들기 때문에 전체적인 속도는 향상됩니다.

Kernel Fusion

Flash Attention은 기존의 PyTorch 구현에 비해 상당한 성능 향상을 보여줍니다. Tiling을 사용함으로써, GPT-2 모델의 어텐션 연산에 필요한 여러 단계들을 효과적으로 결합할 수 있었습니다. 특히, 하나의 HBM 로드로 많은 작업을 수행할 수 있습니다. 이러한 연산들을 하나의 커널로 결합하는 과정을 통해 연산 속도를 크게 향상시켰습니다.

Flash Attention 2

https://www.wonbeomjang.kr/blog/2023/flashattention-2/

FlashAttention은 기본적으로 non-matmul FLOPs를 줄입니다. 예를들어 Nvidia의 A100 GPU는 FP16/BF16의 matmul 연산은 이론적으로 312 TFLOPs/s의 연산량을 가지지만 non-matmul 연산은 19.5 TFLOPs/s의 연산량을 가집니다. 즉 non-matmul 연산이 matmul 연산보다 16배 느려 non-matmul 연산이 전체 연산의 일부를 차지하더라도 이를 최적화 시켜야합니다. 기존 flash attention에서 recaling의 횟수를 줄이고 Memorization m, l 대신 L을 저장해서 조금 더 최적화 시켰습니다.

추가로 K와 V를 split하는 것이 아니라 Q를 split하고, K와 V를 공유할 수 있도록 수정해서 중간 계산 결과를 계속 동기화를 하지 않아도 되도록 개선시켰습니다.

FlashAttention-2는 기존 FlashAttention, xFormer 대비 2배의 속도를 보여줬고, Triton으로 구현된 FlashAttention보다 1.3~1.5배의 빨라진 속도를 보여줬습니다. 놀라운 것은 pytorch에서 naive하게 implementation한 것 대비 10배의 속도차이를 보여줍니다. 이로인해 기존의 large model에서도 더 빠른 연산속도를 보여줍니다.

결론

Attention을 최적화 하기 위한 연구가 많이 진행중입니다. 새로운 메커니즘이 등장하지 않는 한 transformer의 논문 이름 “Attention is all you need”처럼 Attention 메커니즘을 이해하고, 최적화하는 쪽으로 발전할 것이라고 생각합니다.

--

--

Responses (1)