F.scaled_dot_product_attention calls into flash or memory-efficient attention depending on a few factors (it should mainly be flash for the torchtitan case, IIUC). Are there other ops that you have in mind?
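For reference, recent PyTorch releases let you pin which SDPA backend gets used via the `torch.nn.attention.sdpa_kernel` context manager. A minimal sketch (shapes are illustrative; the flash backend itself requires CUDA, so this falls back to the default backend on CPU):

```python
import torch
import torch.nn.functional as F

# SDPA layout: (batch, heads, seq, head_dim)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

if torch.cuda.is_available():
    from torch.nn.attention import sdpa_kernel, SDPBackend
    q, k, v = q.cuda(), k.cuda(), v.cuda()
    # Restrict SDPA to the flash backend (CUDA only, supported archs only).
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # On CPU, SDPA dispatches to the math backend.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```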
@awgu It looks like xformers supports Flash Attention v3 starting from 0.0.28 (flash3.FwOp and flash3.BwOp). It could bring extra training efficiency on the Hopper architecture, since FAv3 is not implemented in PyTorch yet.
As I read it from the blog post, this brings a 1.6x-1.8x speedup over FAv2.
I guess it should not be too hard for users to install xformers and replace the F.scaled_dot_product_attention call with the xformers attention call. This should work as long as the xformers attention is torch.compile-compatible, which I recall it is.
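A sketch of what such a swap could look like, assuming xformers' `memory_efficient_attention` as the replacement op. One gotcha: xformers expects `(batch, seq, heads, head_dim)` layout, transposed relative to SDPA's `(batch, heads, seq, head_dim)`, so the wrapper transposes in and out; it also falls back to the PyTorch-native kernel when xformers is unavailable or the tensors are on CPU:

```python
import torch
import torch.nn.functional as F

try:
    import xformers.ops as xops
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False

def attention(q, k, v, causal=True):
    """q, k, v: (batch, heads, seq, head_dim), i.e. SDPA's layout."""
    if HAS_XFORMERS and q.is_cuda:
        # xformers expects (batch, seq, heads, head_dim): transpose in/out.
        bias = xops.LowerTriangularMask() if causal else None
        out = xops.memory_efficient_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            attn_bias=bias,
        )
        return out.transpose(1, 2)
    # PyTorch-native path (also the CPU fallback).
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

q = k = v = torch.randn(2, 8, 16, 64)
print(attention(q, k, v).shape)  # torch.Size([2, 8, 16, 64])
```

Selecting the FAv3 ops specifically (flash3.FwOp/flash3.BwOp) would additionally need the `op=` argument on Hopper; I have not verified that path myself.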
Since torchtitan is mainly meant to show an example of how to set up this kind of distributed training, I think including xformers attention is less important than showing what is achievable with torch-native code.
Curious why xformers is not used. Is it for simplicity, or is there a performance reason?