
> 128 head dim, so can use flash attention (unlike GPT-J)

mind explaining why this is so attractive/what the hurdle is for the laypeople in the audience? (me)



Standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. FlashAttention is also faster.
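
To make that concrete, here's a rough back-of-the-envelope (sequence lengths and fp16 are just assumptions for illustration):

  # Size of the N x N score matrix that standard attention materializes
  # (fp16, single head; batch size and head count multiply this further).
  for N in (1_024, 8_192, 32_768):
      scores_bytes = N * N * 2
      print(f"N={N}: {scores_bytes / 2**30:.3f} GiB just for the scores")
  # FlashAttention streams tiles of that matrix through on-chip SRAM instead
  # of writing it out to GPU memory, so its extra memory grows only with N.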


So there must be a downside to FlashAttention. What is it?


https://arxiv.org/abs/2205.14135 - Section 5 suggests that the biggest limitation is that custom CUDA kernels need to be written separately for each GPU architecture.


FlashAttention is mathematically identical to standard attention, so in theory there's no downside. In practice, floating-point inaccuracies mean the results differ slightly. I don't know of any papers analyzing in depth what impact those variances have across a range of real models, but generally speaking deep models handle slight variances well. I haven't noticed any difference in my own training runs. And plenty of people use FlashAttention as a drop-in replacement in models trained with standard attention (e.g. via xformers in Stable Diffusion).
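
You can sanity-check the numerical closeness yourself; here's a minimal sketch using PyTorch's scaled_dot_product_attention (shapes, dtype, and tolerance are arbitrary choices, and which fused backend actually gets picked depends on your GPU):

  import torch
  import torch.nn.functional as F

  q = torch.randn(1, 8, 1024, 128, device="cuda", dtype=torch.float16)
  k, v = torch.randn_like(q), torch.randn_like(q)

  # Reference: plain softmax(Q K^T / sqrt(d)) V, materializing the score matrix.
  ref = torch.softmax((q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5, dim=-1) @ v

  # Fused path (FlashAttention is one of the backends PyTorch can dispatch to).
  out = F.scaled_dot_product_attention(q, k, v)

  print((ref - out).abs().max())               # small but nonzero in fp16
  print(torch.allclose(ref, out, atol=1e-3))   # typically True at a loose tolerance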

Also, in practice FlashAttention is still relatively new, so it isn't well supported in libraries yet. Until PyTorch 2.0 you had to either implement it yourself or use something like xformers, which comes with a bag of caveats. PyTorch 2.0 now has it built in and it's easy to use, but the implementation is incomplete, so you can't, for example, use it with an attention mask (which LLMs need).
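
For reference, a sketch of the PyTorch 2.0 path (shapes and dtype are arbitrary; this assumes a CUDA GPU and inputs the flash kernel supports):

  import torch
  import torch.nn.functional as F

  q = torch.randn(1, 8, 2048, 128, device="cuda", dtype=torch.float16)
  k, v = torch.randn_like(q), torch.randn_like(q)

  # Restrict dispatch to the flash kernel only. Causal masking works via
  # is_causal, but passing an explicit attn_mask disqualifies this backend,
  # and with the other backends disabled that raises an error.
  with torch.backends.cuda.sdp_kernel(
      enable_flash=True, enable_math=False, enable_mem_efficient=False
  ):
      out = F.scaled_dot_product_attention(q, k, v, is_causal=True)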

tl;dr: Basically none, but it just isn't well supported yet.


installing it is a nightmare


According to the paper, FlashAttention also needs quadratic memory:

Let N be the sequence length, d the head dimension, and M the size of SRAM with d ≤ M ≤ Nd. Standard attention (Algorithm 0) requires Θ(Nd + N²) HBM accesses, while FlashAttention (Algorithm 1) requires Θ(N²d²M⁻¹) HBM accesses.


https://github.com/HazyResearch/flash-attention#memory

"standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length."

I guess what you've quoted is how many times the layer needs to access HBM (GPU memory reads and writes), not how memory usage scales with sequence length.
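
To put rough numbers on the difference (all values assumed purely for illustration: N = 4096 tokens, d = 64, M = 100k floats of on-chip SRAM):

  N, d, M = 4096, 64, 100_000

  standard_hbm_accesses = N * d + N**2     # Theta(Nd + N^2)    ~ 17.0M
  flash_hbm_accesses = N**2 * d**2 // M    # Theta(N^2 d^2 / M) ~ 0.69M

  standard_extra_memory = N**2             # the materialized N x N score matrix
  flash_extra_memory = N * d               # on the order of N*d: linear in N

For these assumed values that's roughly 25x fewer HBM accesses and about 64x less extra memory, even though the access count is formally still quadratic in N.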



