
Two errors: (1) NameError: ModularIndexing is not defined & (2) LoweringException: AttributeError: 'View' object has no attribute 'get_stride' #45

tobiasvanderwerff opened this issue Sep 23, 2024 · 10 comments
The following code leads to an error:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, N, D = 100, 12, 128, 64
dtype = torch.bfloat16
device = torch.device("cuda")

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.randn(B, N, N, H, device=device, dtype=dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        score_mod = generate_score_mod(self.bias)
        o = flex_attention(q, k, v, score_mod=score_mod)
        return o

def generate_score_mod(bias):
    bias = (2 * bias).view(B, H, N, N)

    def score_mod(score, batch, head, q_idx, k_idx):
        attn_bias = bias[batch, head, q_idx, k_idx]
        return score + attn_bias

    return score_mod

if __name__ == "__main__":
    m = Attention().cuda().eval().to(dtype)
    m = torch.compile(m, mode='default', fullgraph=False)
    # m = torch.compile(m, mode='max-autotune', fullgraph=False)  # this also fails

    q = torch.randn(B, H, N, D, device=device, dtype=dtype)
    k = torch.randn(B, H, N, D, device=device, dtype=dtype)
    v = torch.randn(B, H, N, D, device=device, dtype=dtype)

    m(q, k, v)
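
As an aside, one workaround that I'd expect to sidestep this (an untested sketch; the AttentionPrecomputed name is mine) is to do the bias transformation eagerly in __init__, outside the compiled forward, so that score_mod only indexes an already-materialized tensor and Inductor never has to lower the fused pointwise op + view inside the kernel:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, N, D = 100, 12, 128, 64
dtype = torch.bfloat16
device = torch.device("cuda")

# Hypothetical workaround sketch (untested).
class AttentionPrecomputed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        bias = torch.randn(B, N, N, H, device=device, dtype=dtype)
        # Scale and reshape eagerly here, so torch.compile never traces the
        # `2 * bias` and `.view()` and the closure captures a plain tensor.
        self.bias = (2 * bias).view(B, H, N, N)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        def score_mod(score, batch, head, q_idx, k_idx):
            return score + self.bias[batch, head, q_idx, k_idx]

        return flex_attention(q, k, v, score_mod=score_mod)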

Which of the two errors I get depends on the torch.compile mode I'm using.

If using torch.compile(..., mode='default', ...), I get the following error:

Triton compilation failed: triton_tem_fused_2
def triton_(arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0):
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    PRESCALE_QK : tl.constexpr = False
    OUTPUT_LOGSUMEXP : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.125
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = False
    QK_HEAD_DIM : tl.constexpr = 64
    V_HEAD_DIM : tl.constexpr = 64
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX

    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = 98304, 8192, 64, 1
    stride_kz, stride_kh, stride_kn, stride_kk = 98304, 8192, 64, 1
    stride_vz, stride_vh, stride_vn, stride_vk = 98304, 8192, 64, 1

    ZQ = 100
    HQ = 12
    Q_LEN = 128
    ZKV = 100
    KV_LEN = 128

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0)
    off_zq = tl.program_id(1) // HQ
    off_hq = tl.program_id(1) % HQ

    # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
    # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
    off_zkv = off_zq % ZKV
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_zq * stride_qz + off_hq * stride_qh
    k_offset = off_zkv * stride_kz + off_hkv * stride_kh
    v_offset = off_zkv * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    SPARSE_Z = 1
    SPARSE_HQ = 1

    sparse_idx_z = off_zq % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    stride_kv_num_blks_h = 1
    stride_kv_idx_h = 1
    stride_kv_idx_m = 1

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950

    Q_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(Q_LEN, QK_HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(q_start * BLOCK_M, 0),
        block_shape=(BLOCK_M, QK_HEAD_DIM),
        order=(1, 0)
    )

    # load q: it stays in SRAM throughout the inner loop.
    if IS_DIVISIBLE:
        q = tl.load(Q_block_ptr)
    else:
        # boundary check is not free, so we only do it when necessary.
        q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero")

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to it
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    K_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(QK_HEAD_DIM, KV_LEN),
        strides=(stride_kk, stride_kn),
        offsets=(0, kv_start),
        block_shape=(QK_HEAD_DIM, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0,
        q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_zq, off_hq, offs_m[:, None], offs_n[None, :],
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

        K_block_ptr = tl.make_block_ptr(
            base=K,
            shape=(QK_HEAD_DIM, KV_LEN),
            strides=(stride_kk, stride_kn),
            offsets=(0, kv_start),
            block_shape=(QK_HEAD_DIM, BLOCK_N),
            order=(0, 1)
        )
        V_block_ptr = tl.make_block_ptr(
            base=V,
            shape=(KV_LEN, V_HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(kv_start, 0),
            block_shape=(BLOCK_N, V_HEAD_DIM),
            order=(1, 0)
        )
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr8, out_ptr0,
            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_zq, off_hq, offs_m[:, None], offs_n[None, :],
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_zq = tl.program_id(1) // HQ
    idx_hq = tl.program_id(1) % HQ
    idx_m = offs_m[:, None]
    idx_d = tl.arange(0, V_HEAD_DIM)[None, :]

    mask = idx_m < Q_LEN
    # TODO generalize and add proper mask support
    xindex = idx_d + (64*idx_m) + (8192*idx_hq) + (98304*idx_zq)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, acc.shape)), acc, mask)

    # TODO dont want to write this if we dont require grad
    if OUTPUT_LOGSUMEXP:
        off_hz = tl.program_id(1)
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

metadata: {'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*fp32', 4: '*i32', 5: '*i32', 6: '*fp32', 7: '*fp32', 8: '*bf16', 9: '*bf16'}, 'device': 0, 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())], 'device_type': 'cuda', 'num_warps': 4, 'num_stages': 3, 'debug': True, 'cc': 80}
triton.compiler.errors.CompilationError: at 49:29:
    if CHECK_BLOCK_BOUNDARY:
        # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
        # which is larger than the actual number of elements. To avoid access memory out of bound,
        # we need to mask out the elements that are out of Q_LEN & KV_LEN.
        m = offs_m % Q_LEN
        n = offs_n % KV_LEN
    else:
        m = offs_m
        n = offs_n

    tmp0 = 2.0
    tmp1 = tl.load(in_ptr8 + ModularIndexing(128*(m) + (n) + 16384*(off_h), 1, 12) + 12*ModularIndexing(128*(m) + (n) + 16384*(off_h), 12, 128) + 1536*ModularIndexing(128*(m) + (n) + 16384*(off_h), 1536, 128) + 196608*ModularIndexing(128*(m) + (n) + 16384*(off_h) + 196608*(off_z), 196608, 100)) * tmp0
                             ^
NameError('ModularIndexing is not defined')

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 42:28:


    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                            ^

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 444, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 154:20:
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
                    ^
Traceback (most recent call last):
  File "/home/azureuser/a.py", line 36, in <module>
    m(q, k, v)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1533, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1359, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1430, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 882, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1948, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1874, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1902, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2949, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_azureuser/ab/cabknymdngk3mgmkmh4dqxr6zjqfhqqb3uqpssjql5l3uxbazccb.py", line 106, in <module>
    triton_tem_fused_2 = async_compile.triton('triton_', '''
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 213, in triton
    kernel.precompile()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 444, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 154:20:
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
                    ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Notably, the error goes away if I move the following line from generate_score_mod into __init__:

    bias = (2 * bias).view(B, H, N, N)
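
For reference, here is a minimal sketch of that workaround (same module-level names as the repro above; it only moves the precomputation, it does not fix the underlying compile bug):

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        bias = torch.randn(B, N, N, H, device=device, dtype=dtype)
        # Precompute the scaled + reshaped bias once here, so no view()
        # has to be traced inside the score_mod closure.
        self.bias = (2 * bias).view(B, H, N, N)

    def forward(self, q, k, v):
        def score_mod(score, batch, head, q_idx, k_idx):
            return score + self.bias[batch, head, q_idx, k_idx]
        return flex_attention(q, k, v, score_mod=score_mod)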

Relevant specs:

  • Torch version: 2.6.0.dev20240918
  • GPU: A100 80GB
@tobiasvanderwerff
Author

If I use torch.compile(..., mode='max-autotune', ...), I get a different error (also resolved by the workaround above):

Traceback (most recent call last):
  File "/home/azureuser/a.py", line 36, in <module>
    m(q, k, v)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1272, in compile_fx
    return compile_fx(
           ^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1533, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1359, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1430, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 863, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1357, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1023, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1020, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 913, in flex_attention
    autotune_select_algorithm(
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1729, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1224, in __call__
    inputs_key = create_inputs_key(input_nodes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in create_inputs_key
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1138, in <listcomp>
    return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 1698, in key_of
    node.get_stride(),
    ^^^^^^^^^^^^^^^
  File "/home/azureuser/miniconda3/envs/ao/lib/python3.11/site-packages/torch/_inductor/ir.py", line 6276, in __getattr__
    fn = getattr(self.data, name)
         ^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'get_stride'
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.bfloat16, size=[100, 12, 128, 64], stride=[98304, 8192, 64, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    ComputedBuffer(name='buf2', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.constant(1, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1],
      origin_node=full,
      origins=OrderedSet([full])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf3', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.constant(0, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1, 1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    ))
  )), None, None, TensorBox(StorageBox(
    ComputedBuffer(name='buf4', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.load(buf0, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1],
      origin_node=convert_element_type,
      origins=OrderedSet([sum_1, convert_element_type])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf5', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.index_expr(0, dtype=torch.int16)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1, 1],
      origin_node=convert_element_type_1,
      origins=OrderedSet([convert_element_type_1, sort])
    ))
  )), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
  args[7]: (TensorBox(
    View(
      StorageBox(
        Pointwise(
          'cuda',
          torch.bfloat16,
          def inner_fn(index):
              i0, i1, i2, i3 = index
              tmp0 = ops.load(arg0_1, i3 + 12 * i2 + 1536 * i1 + 196608 * i0)
              tmp1 = ops.constant(2, torch.bfloat16)
              tmp2 = tmp0 * tmp1
              return tmp2
          ,
          ranges=[100, 128, 128, 12],
          origin_node=mul,
          origins=OrderedSet([mul])
        )
      ),
      size=[100, 12, 128, 128],
      reindex=lambda i0, i1, i2, i3: [ModularIndexing(196608*i0 + 16384*i1 + 128*i2 + i3, 196608, 100), ModularIndexing(16384*i1 + 128*i2 + i3, 1536, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 12, 128), ModularIndexing(16384*i1 + 128*i2 + i3, 1, 12)],
      origins=OrderedSet([view, mul])
    )
  ),)
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

@Chillee
Contributor

Chillee commented Sep 24, 2024

Will be fixed by pytorch/pytorch#136509

@cpuhrsch

@tobiasvanderwerff - You can access these changes via our nightlies. See https://pytorch.org/get-started/locally/.
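
For example, a CUDA 12.1 nightly can be installed with something like the following (the exact index URL depends on your CUDA version; check the link above for the right command for your setup):

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121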

@cpuhrsch

cpuhrsch commented Oct 1, 2024

Looks like the original issue pytorch/ao#639 (comment) was not resolved with pytorch/pytorch#136509 @Chillee

@drisspg
Contributor

drisspg commented Oct 1, 2024

Running the above repro does not produce an error, is there an updated repro?

@tobiasvanderwerff
Author

Not yet, @drisspg. Let me get back to you tomorrow and I'll try to create an updated repro.

@tobiasvanderwerff
Author

@drisspg here is a new repro, which fails for me with "LoweringException: AttributeError: 'View' object has no attribute 'get_stride'":

Notably, it fails with torch.compile(..., mode='max-autotune', ...) but not with torch.compile(..., mode='default', ...).

from typing import Tuple
import torch
from torch.nn.attention.flex_attention import flex_attention

B, nheads, H, W, D = 100, 8, 16, 16, 64
dtype = torch.bfloat16
device = torch.device("cuda")

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rel_pos_h = torch.randn(2 * H - 1, D, device=device, dtype=dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        q = q.view(B*nheads, H*W, -1)
        score_mod = generate_score_mod(q, self.rel_pos_h, (H, W), (H, W), B, nheads)
        q = q.view(B, nheads, H*W, -1)
        o = flex_attention(q, k, v, score_mod=score_mod)
        return o

def generate_score_mod(q, rel_pos_h, q_size, k_size, B, num_heads):
    rel_h = add_decomposed_rel_pos(q, rel_pos_h, q_size, k_size)

    rel_h = rel_h.view(B, num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)).squeeze(-1)

    w = rel_h.size(-1)

    def score_mod(score, batch, head, q_idx, k_idx):
        """Add relative position bias to the attention score."""
        h_idx = k_idx // w
        attn_bias = rel_h[batch, head, q_idx, h_idx]
        return score + attn_bias

    return score_mod

def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    q_h, q_w = q_size
    k_h, k_w = k_size

    q_coords = torch.arange(q_h, device=rel_pos_h.device)[:, None] * max(k_h / q_h, 1.0)
    k_coords = torch.arange(k_h, device=rel_pos_h.device)[None, :] * max(q_h / k_h, 1.0)
    relative_coords = (q_coords - k_coords) + (k_h - 1) * max(q_h / k_h, 1.0)
    Rh = rel_pos_h[relative_coords.long()]

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_h = rel_h.unsqueeze(-1)
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)

    return rel_h


if __name__ == "__main__":
    m = Attention().cuda().eval().to(dtype)
    # m = torch.compile(m, mode='default', fullgraph=False)  # this mode works
    m = torch.compile(m, mode='max-autotune', fullgraph=False)  # this mode fails

    q = torch.randn(B, nheads, H*W, D, device=device, dtype=dtype)
    k = torch.randn(B, nheads, H*W, D, device=device, dtype=dtype)
    v = torch.randn(B, nheads, H*W, D, device=device, dtype=dtype)

    m(q, k, v)

@drisspg
Contributor

drisspg commented Oct 2, 2024

thank you, will take a look

@drisspg
Contributor

drisspg commented Oct 2, 2024

Fix here: pytorch/pytorch#137204

@tobiasvanderwerff
Author

Thanks @drisspg, I'll test it out once it's merged.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Oct 2, 2024