Standard attention needs memory quadratic in sequence length, whereas FlashAttention's memory is linear in sequence length. FlashAttention is also faster.
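To make that concrete, here's a minimal sketch (shapes and sizes are purely illustrative) of where the quadratic term comes from in a naive PyTorch implementation:

```python
import torch

# Illustrative shapes: batch 1, 8 heads, sequence length N, head dim d=64.
B, H, N, d = 1, 8, 1024, 64
q = torch.randn(B, H, N, d)
k = torch.randn(B, H, N, d)
v = torch.randn(B, H, N, d)

def naive_attention(q, k, v):
    # Materializes a (B, H, N, N) score matrix, so activation memory grows as N^2.
    scores = q @ k.transpose(-2, -1) / (d ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# At N=1024 that score matrix is 8*1024*1024 fp32 values ≈ 32 MiB; at N=4096
# it's already ≈ 512 MiB, per layer, before the softmax output. FlashAttention
# computes the same result in tiles and never stores that matrix, so its extra
# memory stays linear in N.
out = naive_attention(q, k, v)
```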
https://arxiv.org/abs/2205.14135 - Section 5 suggests the biggest limitation is that custom CUDA kernels have to be written separately for each GPU architecture.
FlashAttention is mathematically identical to standard attention, so in theory there's no downside. In practice, floating-point rounding means the results differ slightly. I don't know of any papers that analyze in depth what impact those variances have across a range of real models, but generally speaking deep models handle slight variances well. I haven't noticed any difference in the models I've trained, and tons of people use FlashAttention as a drop-in replacement in models trained with standard attention (e.g. using xformers in Stable Diffusion).
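As a rough illustration (assuming PyTorch 2.0 on a CUDA GPU whose flash kernel supports these shapes), you can force the two paths and see that the difference is just floating-point noise:

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)]

# Force the "math" backend (standard attention), then the flash backend, on the same inputs.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v)

# Typically a tiny discrepancy on the order of fp16 rounding error.
print((out_math - out_flash).abs().max().item())
```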
Also, in practice FlashAttention is still relatively new, so it isn't well supported in libraries yet. Before PyTorch 2.0 you had to either implement it yourself or use something like xformers, which comes with a bag of caveats. PyTorch 2.0 now has it built in, and it's easy to use, but the implementation is incomplete, so you can't, for example, use it with an attention mask (which LLMs need).
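For reference, the built-in path looks roughly like this (a sketch assuming PyTorch 2.0 and a CUDA GPU; `is_causal` covers the usual causal-LM case, but as of 2.0 the flash backend doesn't accept an explicit `attn_mask`, so PyTorch either falls back to another kernel or, if you force flash only, errors out):

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16) for _ in range(3)]

# Easy to use: causal masking is supported by the flash backend via is_causal.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# An arbitrary attention mask is not (as of 2.0): forcing flash-only with a mask
# should fail; with default settings PyTorch would instead silently fall back to
# a non-flash kernel.
mask = torch.ones(512, 512, dtype=torch.bool, device="cuda").tril()
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("flash backend rejected the mask:", err)
```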
tl;dr: Basically none, but it just isn't well supported yet.
According to the paper, FlashAttention's HBM accesses still scale quadratically with sequence length:
Let 𝑁 be the sequence length, 𝑑 be the head dimension, and 𝑀 be the size of SRAM with 𝑑 ≤ 𝑀 ≤ 𝑁𝑑. Standard attention (Algorithm 0) requires Θ(𝑁𝑑 + 𝑁²) HBM accesses, while FlashAttention (Algorithm 1) requires Θ(𝑁²𝑑²𝑀⁻¹) HBM accesses.
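Plugging in some illustrative numbers shows why that's still a big win: with 𝑁 = 4096, 𝑑 = 64, and 𝑀 ≈ 10⁵ (roughly 100 KB of SRAM), standard attention needs on the order of 𝑁𝑑 + 𝑁² ≈ 1.7 × 10⁷ HBM accesses, while FlashAttention needs about 𝑁²𝑑²/𝑀 ≈ 6.9 × 10⁵, roughly 25× fewer, because 𝑑² ≪ 𝑀. So the 𝑁² term is still there, just divided by a large constant. Also note this bound counts HBM accesses, not memory: the extra memory FlashAttention needs beyond inputs and outputs is only 𝑂(𝑁).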
mind explaining why this is so attractive/what the hurdle is for the laypeople in the audience? (me)