Two errors: (1) NameError: ModularIndexing is not defined & (2) LoweringException: AttributeError: 'View' object has no attribute 'get_stride' #45
Comments
Will be fixed by pytorch/pytorch#136509.
@tobiasvanderwerff - You can access these changes via our nightlies. See https://pytorch.org/get-started/locally/
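A quick way to confirm that a nightly build is active before retrying (the version string shown is only an example; nightlies are tagged with a .devYYYYMMDD suffix):

```python
import torch

# Nightly wheels report a ".devYYYYMMDD" version, e.g. "2.6.0.dev20240918+cu121".
# Confirm the installed build is at least that recent before re-running the repro.
print(torch.__version__)
```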
Looks like the original issue pytorch/ao#639 (comment) was not resolved by pytorch/pytorch#136509 @Chillee
Running the above repro does not produce an error; is there an updated repro?
Not yet, @drisspg. Let me get back to you tomorrow and I'll try to create an updated repro.
@drisspg here is a new repro, which fails for me with "LoweringException: AttributeError: 'View' object has no attribute 'get_stride'". Notably, it fails for both mode='default' and mode='max-autotune' (see the commented-out compile line in the repro):

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

B, nheads, H, W, D = 100, 8, 16, 16, 64
dtype = torch.bfloat16
device = torch.device("cuda")


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Learned relative-position embedding along the height axis.
        self.rel_pos_h = torch.randn(2 * H - 1, D, device=device, dtype=dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        q = q.view(B * nheads, H * W, -1)
        score_mod = generate_score_mod(q, self.rel_pos_h, (H, W), (H, W), B, nheads)
        q = q.view(B, nheads, H * W, -1)
        o = flex_attention(q, k, v, score_mod=score_mod)
        return o


def generate_score_mod(q, rel_pos_h, q_size, k_size, B, num_heads):
    # Precompute the decomposed relative-position bias and capture it in score_mod.
    rel_h = add_decomposed_rel_pos(q, rel_pos_h, q_size, k_size)
    rel_h = rel_h.view(B, num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)).squeeze(-1)
    w = rel_h.size(-1)

    def score_mod(score, batch, head, q_idx, k_idx):
        """Add relative position bias to the attention score."""
        h_idx = k_idx // w
        attn_bias = rel_h[batch, head, q_idx, h_idx]
        return score + attn_bias

    return score_mod


def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Map query/key positions to indices into the relative-position table.
    q_coords = torch.arange(q_h, device=rel_pos_h.device)[:, None] * max(k_h / q_h, 1.0)
    k_coords = torch.arange(k_h, device=rel_pos_h.device)[None, :] * max(q_h / k_h, 1.0)
    relative_coords = (q_coords - k_coords) + (k_h - 1) * max(q_h / k_h, 1.0)
    Rh = rel_pos_h[relative_coords.long()]
    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_h = rel_h.unsqueeze(-1)
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    return rel_h


if __name__ == "__main__":
    m = Attention().cuda().eval().to(dtype)
    # m = torch.compile(m, mode='default', fullgraph=False)
    m = torch.compile(m, mode='max-autotune', fullgraph=False)  # this also fails
    q = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
    k = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
    v = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
    m(q, k, v)
```
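For comparison, a minimal eager-mode sketch (assuming, since both reported errors surface during Inductor lowering, that the uncompiled module runs; this is not verified in the thread). It reuses the Attention class and constants defined in the repro above:

```python
# Eager-mode baseline: same module and inputs as the repro, but without torch.compile.
# Assumes the definitions above (Attention, B, nheads, H, W, D, dtype, device).
m_eager = Attention().cuda().eval().to(dtype)
q = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
k = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
v = torch.randn(B, nheads, H * W, D, device=device, dtype=dtype)
out = m_eager(q, k, v)
print(out.shape)  # expected: torch.Size([100, 8, 256, 64]), i.e. (B, nheads, H*W, D)
```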
Thank you, will take a look.
Fix here: pytorch/pytorch#137204
Thanks @drisspg, I'll test it out once it's merged.
…(#137204)
Summary: Originally reported in pytorch-labs/attention-gym#45
Pull Request resolved: #137204
Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
The following code leads to an error. The error I get depends on the torch.compile mode I'm using, e.g. torch.compile(..., mode='default', ...) versus mode='max-autotune'. Notably, the error goes away if I move one line of generate_score_mod into __init__ instead.
Relevant specs: torch 2.6.0.dev20240918