FLAT-Attention
===

[**FLAT-Attention vs. FlashAttention**](https://hackmd.io/Hgl4V-BtQsyJn17_lO4X4A?view)

*FLAT: An Optimized Dataflow for Mitigating Attention Bottlenecks*, [ASPLOS'23](https://arxiv.org/abs/2107.06419)

# Summary of FLAT-Attention

FLAT (Fused Logit and Attend Tiling)

#### The quadratic complexity of the Logit and Attend operators in the attention layer causes two major challenges:
1. Low performance due to memory-boundedness
2. Large on-chip buffer requirement for staging intermediate activations

#### FLAT fuses the Logit and Attend operators (see figure below) and optimizes their tiling and scheduling:
1. Increases the operational intensity → ameliorates the memory-boundedness
2. Reduces the on-chip buffer requirement for data staging

#### FLAT delivers:
1. 1.5x speedup (on average) in our performance model across models & platforms
2. 1.7x extra speedup (on average) on top of ELSA and Sanger
3. 1.5x speedup (on average) in our JAX prototype, while enabling 8x larger batch size or 32x larger sequence length

[![](https://hackmd.io/_uploads/HkWre2CS3.png)](https://arxiv.org/abs/2107.06419)
*The canonical architecture of attention-based models*

------

# Key Ideas

#### Tiling (stream and compute a partial output at a time)
![](https://hackmd.io/_uploads/S1ZJWh0H2.png)

#### Cross-operator scheduling (fusing the Logit and Attend operators)
![](https://hackmd.io/_uploads/SkHZ-30rn.png)

A minimal code sketch of these two ideas appears at the end of this page.

# FLAT-Attention Demo

FLAT-Attention prototype with a simple for-loop implementation:
[![](https://hackmd.io/_uploads/HytnUoRBn.png)](https://github.com/felix0901/flat_prototype/blob/master/flat_prototype.ipynb)

FLAT-Attention vs. FlashAttention: (Coming soon)

<!-- (https://colab.corp.google.com/drive/1JfOabuedo6EUFvqOUcFbXze3IPF3UuEg#scrollTo=ZowV3Fkm5Z1Q) -->

# Talk

[ASPLOS'23](https://www.youtube.com/watch?v=qhlUG1Knh6k&t=1s)

[![](https://hackmd.io/_uploads/H1AGLjRr3.png)](https://www.youtube.com/watch?v=qhlUG1Knh6k&t=1s)

---
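
# Appendix: Minimal Tiling Sketch

The snippet below is a minimal NumPy sketch (not the paper's artifact or the linked notebook) of the fused Logit-Attend tiling idea: for each block of query rows, the logits are computed, softmaxed, and immediately contracted with V, so only a `rows_per_tile × L` slice of the logit matrix is staged at a time instead of the full `L × L` intermediate. The tile size `rows_per_tile` and the function name `flat_attention` are illustrative assumptions.

```python
import numpy as np

def flat_attention(Q, K, V, rows_per_tile=64):
    """Fused Logit + Attend with row tiling (illustrative sketch).

    Q, K, V: [L, d] arrays for one attention head.
    Only a [rows_per_tile, L] logit slice is materialized per step,
    mimicking FLAT's reduced on-chip staging requirement.
    """
    L, d = Q.shape
    out = np.empty_like(Q)
    for start in range(0, L, rows_per_tile):
        end = min(start + rows_per_tile, L)
        q_tile = Q[start:end]                       # [t, d] query tile
        logits = q_tile @ K.T / np.sqrt(d)          # Logit: only a [t, L] slice
        logits -= logits.max(axis=-1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)  # softmax over keys
        out[start:end] = probs @ V                  # Attend: fused, partial output
    return out

# Sanity check against unfused attention on random data
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
    ref_logits = Q @ K.T / np.sqrt(64)
    ref_logits -= ref_logits.max(axis=-1, keepdims=True)
    ref_probs = np.exp(ref_logits)
    ref_probs /= ref_probs.sum(axis=-1, keepdims=True)
    assert np.allclose(flat_attention(Q, K, V), ref_probs @ V, atol=1e-6)
```

Because softmax normalizes over the key dimension, tiling at row granularity keeps each softmax self-contained within a tile; the partial outputs are exact, with no rescaling passes needed.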