Flash Attention was introduced in 2022 as a fast and memory-efficient exact attention algorithm that uses tiling and algebraic aggregation to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. Because it never materializes the full attention matrix in HBM, it is faster and more memory-efficient than standard attention implementations, especially for long sequences.
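The NumPy sketch below shows the tiling idea. It is our own illustration, not SHARK's kernel or the paper's code. Keys and values are processed one block at a time. A running row-wise maximum `m` and a running denominator `l` allow partial softmax results to be merged exactly, so only an N×block tile of scores exists at any moment. The block size and function names are hypothetical. A real kernel performs each tile's work in on-chip SRAM inside a single fused GPU kernel.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: materializes the full (N x N) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=16):
    """Attention over K/V tiles with an online (running) softmax.

    Only one (N x block) tile of scores exists at a time; the running
    row max `m` and running denominator `l` let tiles be merged exactly.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n, v.shape[-1]))
    m = np.full((n, 1), -np.inf)  # running row-wise max of the scores
    l = np.zeros((n, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kj, vj = k[start:start + block], v[start:start + block]
        s = q @ kj.T * scale                         # scores for this tile only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)                        # unnormalized tile probabilities
        alpha = np.exp(m - m_new)                    # rescales the previous state
        l_new = alpha * l + p.sum(axis=-1, keepdims=True)
        # Keep `out` normalized after every tile (FlashAttention-1 style).
        out = (alpha * l * out + p @ vj) / l_new
        m, l = m_new, l_new
    return out

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
    print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))
```

Because the merge is exact, `tiled_attention` agrees with `naive_attention` up to floating-point rounding for any block size, including a final partial block.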

Just a few days ago, Flash Attention 2 (FA2) was introduced with a range of improvements. These run from algorithmic changes that reduce the number of non-matmul FLOPs to additional parallelization over the sequence length and better work partitioning between warps. The paper “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” has more details on FA2.
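The NumPy sketch below illustrates one FA2 change that reduces non-matmul FLOPs. It keeps an unnormalized output accumulator and rescales it only by `exp(m_old - m_new)` inside the loop. It divides by the softmax denominator once, at the end. It also returns the row-wise logsumexp, which the paper stores for the backward pass. This is our own simplified model of the idea as described in the paper, not SHARK's implementation. The function name and block size are hypothetical.

```python
import numpy as np

def tiled_attention_fa2_style(q, k, v, block=16):
    """Online-softmax attention with deferred normalization.

    Returns (output, logsumexp); logsumexp has shape (N, 1).
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    acc = np.zeros((n, v.shape[-1]))  # unnormalized output accumulator
    m = np.full((n, 1), -np.inf)      # running row-wise max of the scores
    l = np.zeros((n, 1))              # running softmax denominator
    for start in range(0, k.shape[0], block):
        kj, vj = k[start:start + block], v[start:start + block]
        s = q @ kj.T * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        alpha = np.exp(m - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        acc = alpha * acc + p @ vj    # rescale only; no division inside the loop
        m = m_new
    out = acc / l                     # single normalization at the end
    lse = m + np.log(l)               # logsumexp, kept for the backward pass
    return out, lse
```

Each query row is processed independently here. That independence is what lets FA2 parallelize over the sequence-length dimension (blocks of queries) in addition to batch and heads.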

SHARK is a high-performance ML compiler and runtime that can run the latest generative AI models across a wide variety of hardware. It is deployed on platforms ranging from mobile SoCs to hyperscalers. SHARK has a high-performance implementation of Flash Attention (described here) that uses a platform-agnostic representation of the register layout. Following the release of the FA2 paper, SHARK engineers implemented FA2 in a few days, and we showcase their results below. SHARK outperforms top-of-tree Triton (master branch) for most of the common Flash Attention sizes found in Stable Diffusion and BERT.

(UPDATE 7/25/2023) Thanks to Philippe Tillet and Thomas Raoux from the Triton team for analyzing the results. For consistent measurements, we pinned the GPU clocks with:

sudo nvidia-smi --lock-gpu-clocks=1410,1410

NVIDIA driver version 525.125.06, CUDA version 12.0, on a Google Cloud a2-highgpu-1g instance (A100 40 GB).

(Our previous benchmark of SHARK from here against the latest Triton @SHA:, run without pinning the clock to 1410 MHz, is here.)

Performance and Portability with a high-dimensional vector representation for data layout in registers

SHARK has its own layout abstraction, which is the core representation in an MLIR-based compiler for Tensor Core code generation. To represent data layout in registers, SHARK uses a high-dimensional vector representation. The key idea is that, regardless of the underlying hardware, the layout should be representable as an N-dimensional vector, where N is hardware-specific. Below we compare the SHARK, CuTe, and Triton representations for the fp16 mma.sync 16x8x8 (m16n8k8) instruction.
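The Python sketch below shows what an N-dimensional vector representation can mean for this instruction. It views the fp16 A operand (16x8) of mma.sync m16n8k8 as a 4-D vector of shape (2, 8, 4, 2). Two of those dimensions are spread across the 32 lanes of a warp, and two live in each thread's registers. The 4-D shape and the helper names are our own illustrative choices, not SHARK's actual representation. The reference mapping `ptx_a_fragment` follows the fragment layout that the PTX ISA documents for mma.m16n8k8.

```python
import itertools

def ptx_a_fragment(lane, i):
    """Reference mapping from the PTX ISA for mma.m16n8k8 .f16, operand A.

    Lane `lane` holds halves a0..a3; returns (row, col) of a_i in the 16x8 tile.
    """
    group_id, thread_id_in_group = lane // 4, lane % 4
    row = group_id + 8 * (i // 2)
    col = 2 * thread_id_in_group + (i % 2)
    return row, col

# Illustrative 4-D vector view of the same tile, shape (2, 8, 4, 2):
#   dim 0 (2): which 8-row half of the tile  -> register dimension
#   dim 1 (8): row within that half          -> lane dimension (lane // 4)
#   dim 2 (4): column pair                   -> lane dimension (lane % 4)
#   dim 3 (2): element within the pair       -> register dimension
VECTOR_SHAPE = (2, 8, 4, 2)

def vector_to_matrix(idx):
    """4-D vector index -> (row, col) in the 16x8 tile."""
    half, r, pair, e = idx
    return 8 * half + r, 2 * pair + e

def vector_to_thread(idx):
    """4-D vector index -> (lane, register slot) that holds the element."""
    half, r, pair, e = idx
    return 4 * r + pair, 2 * half + e

def check_against_ptx():
    """Confirm the 4-D view puts every element where the PTX layout says it lives."""
    for idx in itertools.product(*(range(n) for n in VECTOR_SHAPE)):
        lane, slot = vector_to_thread(idx)
        assert ptx_a_fragment(lane, slot) == vector_to_matrix(idx)
    return True

if __name__ == "__main__":
    print(check_against_ptx())
```

In this simplified view, a different MMA instruction or GPU would change only the vector shape and which dimensions map to lanes versus registers. The form of the representation stays the same.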

SHARK representation

CuTe representation

From: Developing Optimal CUDA Kernels on Hopper Tensor Cores by Cris Cecka.

In the CuTe layout representation, each dimension is described by a shape and a stride. A layout combines these shapes and strides, potentially hierarchically using IntTuples. A layout is thus a function from the n-D coordinate space (natural coordinates) to the 1-D index space (storage index).
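The shape/stride idea can be sketched in plain Python. `crd2idx` evaluates a hierarchical (IntTuple-style) layout. A 1-D coordinate is unpacked into nested modes with the leftmost mode fastest, as CuTe does. This is a pure-Python model, not CuTe's C++ API, and the function names are ours. The thread-value layout for the mma.sync m16n8k8 A operand reflects our reading of CuTe's `SM80_16x8_Row` atom. That atom maps (thread, value) to an index in a column-major 16x8 tile.

```python
def layout_size(shape):
    """Number of elements described by a (possibly nested) shape."""
    if isinstance(shape, int):
        return shape
    n = 1
    for s in shape:
        n *= layout_size(s)
    return n

def crd2idx(coord, shape, stride):
    """Evaluate a CuTe-style layout: coordinate -> storage index.

    `coord` may be a tuple matching `shape`, or an int that is unpacked
    into the sub-modes of `shape` with the leftmost sub-mode fastest.
    """
    if isinstance(shape, int):
        return coord * stride
    if isinstance(coord, int):
        total = 0
        for s, d in zip(shape, stride):
            total += crd2idx(coord % layout_size(s), s, d)
            coord //= layout_size(s)
        return total
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))

# (thread, value) layout for the fp16 A operand of mma.sync m16n8k8,
# as we understand CuTe's SM80_16x8_Row:
#   thread = (threadID_in_group: 4, groupID: 8), strides (32, 1)
#   value  = (column within pair: 2, row half: 2), strides (16, 8)
A_SHAPE = ((4, 8), (2, 2))
A_STRIDE = ((32, 1), (16, 8))

def a_tv_to_rowcol(thread, value):
    """(lane, register slot) -> (row, col) in the 16x8 A tile."""
    idx = crd2idx((thread, value), A_SHAPE, A_STRIDE)
    return idx % 16, idx // 16  # column-major storage: idx = row + 16 * col
```

For example, `a_tv_to_rowcol(5, 3)` returns `(9, 3)`. Lane 5 is groupID 1 and threadID_in_group 1, and its register slot 3 lands in row 1 + 8 and column 2 + 1, which matches the PTX fragment table.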

Triton’s enum-based representation 

In Triton, the layout is represented as follows:

#triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>

The specific details of the layout are not carried by the attribute itself. Instead, they are encoded in the TritonGPUToLLVM conversion passes.
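The attribute itself states only that there are 4 warps along dim 0 and 1 along dim 1. The Python sketch below is our own approximation of the ownership map that the conversion passes implement for an mma v2 accumulator. It assumes the 16x8 per-warp accumulator tile of mma.sync m16n8. It also assumes that the 4x1 grid of warp tiles (64x8 per CTA tile) repeats to cover larger tensors. `mma_v2_owner` and the constants are hypothetical names, not Triton code.

```python
# Decoded from: #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
WARPS_PER_CTA = (4, 1)  # 4 warps along rows (dim 0), 1 along columns (dim 1)
INSTR_SHAPE = (16, 8)   # assumed per-warp mma.sync m16n8 accumulator tile

def mma_v2_owner(row, col):
    """Hypothetical decoder: (row, col) -> (warp, lane, repetition, register).

    Assumes the WARPS_PER_CTA grid of INSTR_SHAPE tiles repeats to cover
    tensors larger than one CTA tile (64x8 here).
    """
    tile_m, tile_n = row // INSTR_SHAPE[0], col // INSTR_SHAPE[1]
    warp_m, warp_n = tile_m % WARPS_PER_CTA[0], tile_n % WARPS_PER_CTA[1]
    # Linearization order is irrelevant here because warpsPerCTA[1] == 1.
    warp = warp_m + WARPS_PER_CTA[0] * warp_n
    repetition = (tile_m // WARPS_PER_CTA[0], tile_n // WARPS_PER_CTA[1])
    r, c = row % INSTR_SHAPE[0], col % INSTR_SHAPE[1]
    lane = 4 * (r % 8) + c // 2      # groupID * 4 + threadID_in_group (PTX)
    register = 2 * (r // 8) + c % 2  # accumulator slot c0..c3 within the tile
    return warp, lane, repetition, register
```

Under these assumptions, row 17 belongs to warp 1, while row 64 wraps back to warp 0 in its second repetition. None of this is visible in the attribute text, which is why the conversion passes carry the real semantics.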

For more details on how SHARK is able to outperform Triton, please see the slides here. If all this sounds exciting, we are hiring: check out https://ai.compiler.engineer or email stdin@nod.ai.

SHARK is bringing core AI technologies to whatever hardware you have, from mobile phones, laptops, and workstations to servers, and unlocking the last mile of AI in a privacy-preserving way. You can try SHARK Studio today, our open source generative AI portal that runs locally on macOS, Windows and Linux. Enterprise-grade containerized deployment of private, privacy-preserving AI at scale is launching soon. Reach out to partners@nod.ai if you would like early access.
