FLAT-Attention

FLAT-Attention vs. FlashAttention
FLAT: An Optimized Dataflow for Mitigating Attention Bottlenecks, ASPLOS'23

Summary of FLAT-Attention

FLAT (Fused Logit and Attend Tiling)

The quadratic complexity of the Logit and Attend operators in the attention layer causes two major challenges:

  1. Low performance due to memory-boundedness
  2. A large on-chip buffer requirement for staging intermediate activations

FLAT fuses the Logit and Attend operators (see figure below) and optimizes their tiling and scheduling:

  1. Increases operational intensity → ameliorates the memory-boundedness (see the sketch after this list)
  2. Reduces the on-chip buffer requirement for data staging
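
To see why fusion helps with the first point, here is a back-of-the-envelope estimate of operational intensity (FLOPs per byte of off-chip traffic) for the Logit + Attend pair, unfused vs. fused. This is an illustrative sketch, not the paper's performance model; the sequence length, head dimension, and fp16 element size are assumptions, and softmax's own traffic and FLOPs are ignored.

```python
def op_intensity(n=4096, d=64, bytes_per_elem=2):
    """Rough FLOPs-per-byte estimate for Logit (QK^T) + Attend (PV)."""
    flops = 2 * 2 * n * n * d              # two n x n x d matmuls

    # Unfused: the n x n logits/probabilities round-trip through off-chip memory.
    unfused_bytes = bytes_per_elem * (
        3 * n * d      # read Q, K, V
        + n * d        # write O
        + 2 * n * n    # write logits, read probabilities back
    )

    # Fused (FLAT-style): the n x n intermediate stays in on-chip buffers.
    fused_bytes = bytes_per_elem * (3 * n * d + n * d)

    return flops / unfused_bytes, flops / fused_bytes

unfused, fused = op_intensity()
print(f"unfused: {unfused:.1f} FLOPs/byte, fused: {fused:.1f} FLOPs/byte")
```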

FLAT delivered

  1. 1.5x speedup (on average) in our performance model across models & platforms
  2. 1.7x extra speedup (on average) on top of ELSA and Sanger
  3. 1.5x speedup (on average) in our Jax prototype while enabling 8x larger batch size or 32x larger sequence length



The canonical architecture of attention-based models


Key Ideas

Tiling (stream the inputs and compute a partial output at a time)


Cross-operator Scheduling (fuse the Logit and Attend operators; see the formulation below)

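Taken together, the two ideas amount to computing attention one block of query rows at a time, with the softmax applied over complete rows so no cross-tile correction is needed. A sketch of the per-block computation, using illustrative notation (block size $t$, sequence length $N$, head dimension $d$) rather than the paper's:

$$
O_i = \mathrm{softmax}\!\left(\frac{Q_i K^{\top}}{\sqrt{d}}\right) V, \qquad i = 1, \dots, \left\lceil N / t \right\rceil,
$$

where $Q_i \in \mathbb{R}^{t \times d}$ is the $i$-th block of query rows. Only the $t \times N$ tile of logits is staged on-chip at a time, instead of the full $N \times N$ matrix.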

FLAT-Attention Demo

FLAT-Attention prototype with a simple for-loop implementation, sketched below.
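
The snippet below is a minimal NumPy sketch of the fused, row-tiled dataflow, written for clarity rather than performance. The block size and shapes are illustrative, and this is not the Jax prototype itself; it only demonstrates that computing attention one query-row block at a time reproduces the unfused result.

```python
import numpy as np

def flat_attention(q, k, v, block=64):
    """Fused Logit + Attend, tiled over query rows (illustrative sketch).

    Only a (block, n) tile of logits is materialized at a time, standing in
    for the on-chip buffer; the full n x n matrix never exists at once.
    """
    n, d = q.shape
    out = np.empty_like(q)
    for start in range(0, n, block):                    # stream one query row block at a time
        q_blk = q[start:start + block]                  # (block, d)
        logits = q_blk @ k.T / np.sqrt(d)               # Logit: (block, n), stays "on chip"
        logits -= logits.max(axis=-1, keepdims=True)    # numerically stable softmax
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[start:start + block] = probs @ v            # Attend: partial output for this block
    return out

# Check against a reference (unfused) attention on random data.
rng = np.random.default_rng(0)
n, d = 256, 64
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
s = q @ k.T / np.sqrt(d)
p = np.exp(s - s.max(-1, keepdims=True))
ref = (p / p.sum(-1, keepdims=True)) @ v
assert np.allclose(flat_attention(q, k, v), ref, atol=1e-6)
print("tiled output matches reference attention")
```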

FLAT-Attention vs. FlashAttention

(Coming soon)

Talk

ASPLOS'23