Replace flash_4 with FlexAttention #639

Open

cpuhrsch opened this issue Aug 8, 2024 · 15 comments

Labels: good first issue (Good for newcomers)

Comments

@cpuhrsch
Contributor

cpuhrsch commented Aug 8, 2024

https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the kind of additive attention bias required by the image_encoder.

In a nutshell, the computation this custom Triton kernel implements is:

    # rel_h / rel_w carry the decomposed relative-position terms
    rel_h_ = rel_h.reshape(B, num_heads, q_h * q_w, k_h, 1)
    rel_w_ = rel_w.reshape(B, num_heads, q_h * q_w, 1, k_w)
    # attn_bias: (B, num_heads, q_h * q_w, k_h * k_w), matching q_/k_/v_ of
    # shape (B, num_heads, q_h * q_w, head_dim)
    attn_bias = (rel_h_ + rel_w_).view(q_.size(0), q_.size(1),
                                       rel_h_.size(2), rel_h_.size(3) * rel_w_.size(4))
    return torch.nn.functional.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_bias)

With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without the need for custom Triton code.

Not only will FlexAttention be able to support a fused implementation for more input shapes, it is also likely to produce better-optimized code with better hyperparameters. This kind of fused attention yielded an end-to-end improvement of about 1.15x on top of a fused-SDPA and torch.compile'd (with CUDA graphs) baseline.

The task:

Copy over the relevant files from segment-anything-fast into torchao's model folder and follow the readme to rerun if needed.

Write a FlexAttention version of flash_4 and measure the difference in performance. If it helps, we can immediately land it in torchao, but at a minimum it could influence FlexAttention development.

@cpuhrsch added the good first issue label Aug 8, 2024
@tobiasvanderwerff
Contributor

tobiasvanderwerff commented Sep 15, 2024

@cpuhrsch

I would like to give this a shot. Could you help me clarify something?

Is the goal to make a fork of segment-anything-fast that uses Flex Attention, and test that in ao? The alternative would be to manually copy over all the files from segment-anything-fast to ao/torchao/_models/sam/, but that seems overkill since the only change is in the SDPA call.

What I could do is make a fork of segment-anything-fast that uses Flex Attention and use that as an alternative pip install to pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git when benchmarking SAM.

Let me know if this makes any sense, or if you meant something else.

@cpuhrsch
Contributor Author

@tobiasvanderwerff - Yes, we could also get started with an experimental PR against https://github.com/pytorch-labs/segment-anything-fast . Eventually it could be convenient to be able to vendor the changes in SAM-fast and make them more easily accessible via torchao packaging and distribution. What do you think about this?

@tobiasvanderwerff
Contributor

@cpuhrsch that sounds like a plan. Let me try to get started on this in the next few days.

I already tried to run the SAM benchmark today to get started but realized that my current GPU (NVIDIA T4) does not support Flash Attention (since it requires compute capability >=sm_80, e.g. an A100). However, I intend to get access to a cloud A100 GPU instance in the next few days.

If getting access to a better GPU doesn't work out, I don't think I'll be able to work on this, and I'll let you know in that case.

@tobiasvanderwerff
Contributor

@cpuhrsch as discussed, I've created a fork of the segment-anything-fast repo that uses Flex Attention instead of the custom Triton kernel. I've also added a test to check for correctness. You can see the changes here.

I'm posting benchmark results from ao/torchao/_models/sam/benchmark.sh below. The first results are not terribly encouraging: the Flex Attention implementation leads to a ~25% reduction in img/s. I might do some more digging to see why this is happening. If you have any suggestions, I'd love to hear them.

As a side note, Flex Attention only accepts embedding sizes that are powers of two, so I had to add padding to make it work. It's possible that the padding causes the drop in performance, although the Triton kernel seems to do the same thing.

Torch version: 2.6.0.dev20240918
GPU: A100 80GB

Baseline results (using Triton kernel):

device sam_model_type batch_size memory(MiB) memory(%) img_s(avg) batch_ms(avg)/batch_size mIoU use_compile use_half compress use_compile_decoder use_rel_pos pad_input_image_batch num_workers num_batches num_images profile_path memory_path
cuda vit_h 32 15172 18 22.533401716616083 44.37856354651513 0.5812715827356921 max-autotune torch.bfloat16 None False True True 32 154 4928 None None
cuda vit_h 32 15154 18 25.16516896830006 39.73746416166231 0.5818834536577897 max-autotune torch.bfloat16 int8_dynamic_quant False True True 32 154 4928 None None
cuda vit_h 32 15632 19 24.824717871078573 40.282431614863405 0.5675837487618974 max-autotune torch.bfloat16 sparse_mlp_only False True True 32 154 4928 None None
cuda vit_h 32 13429 16 24.589577947798148 40.66763578142439 0.5306639662569573 max-autotune torch.bfloat16 sparse False True True 32 154 4928 None None
cuda vit_h 32 14869 18 26.597207143088742 37.597932543073384 0.5669944616184625 max-autotune torch.bfloat16 int8_dynamic_quant_sparse False True True 32 154 4928 None None
cuda vit_h 32 17068 21 23.96093702681232 41.73459489004953 0.5485481164943489 max-autotune torch.float16 int4_weight_only_sparse False True True 32 154 4928 None None

Flex Attention results (I omitted the last two rows because running the benchmark was taking a long time):

device sam_model_type batch_size memory(MiB) memory(%) img_s(avg) batch_ms(avg)/batch_size mIoU use_compile use_half compress use_compile_decoder use_rel_pos pad_input_image_batch num_workers num_batches num_images profile_path memory_path
cuda vit_h 32 19531 24 16.35339887491553 61.14936764209301 0.5812806843206303 max-autotune torch.bfloat16 None False True True 24 154 4928 None None
cuda vit_h 32 19512 24 17.72072649749095 56.43109497466644 0.5815980109018701 max-autotune torch.bfloat16 int8_dynamic_quant False True True 24 154 4928 None None
cuda vit_h 32 20960 25 16.6174344353318 60.177761127422386 0.5672995875671748 max-autotune torch.bfloat16 sparse_mlp_only False True True 24 154 4928 None None
cuda vit_h 32 18997 23 14.915692058093141 67.04348655799767 0.5306602491658978 max-autotune torch.bfloat16 sparse False True True 24 154 4928 None None

@cpuhrsch
Contributor Author

Hm, very interesting. Thanks for doing this work. Do you mind attaching GPU traces for, say, the first setup, both with and without FlexAttention?

You can gather traces using https://github.com/pytorch-labs/segment-anything-fast/tree/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments#kernel-traces . Just ensure that the path ends in .json.gz.

@cpuhrsch
Contributor Author

Using the GPU traces it is also possible to annotate the changed section (using https://pytorch.org/docs/main/generated/torch.autograd.profiler.record_function.html#record-function and https://pytorch.org/docs/main/generated/torch.cuda.synchronize.html#torch-cuda-synchronize ) and look at the difference in GPU kernel runtime for just that section; a rough sketch of the annotation follows the list below. This way we can double-check that the slowdown is precisely due to this change.

I'd create two versions of these traces, one with annotation and sync and one without. So that means 4 traces in total:

a) Baseline without annotate
b) Baseline with annotate
c) Changed without annotate
d) Changed with annotate

@tobiasvanderwerff
Contributor

tobiasvanderwerff commented Sep 21, 2024

Tracing results indicate that in the Flex Attention version, a lot of time is spent on a padding kernel (triton_tem_fused_constant_pad_nd_38, indicated by blue arrows in the screenshot below):

[screenshot: GPU trace of the FlexAttention run; blue arrows mark the triton_tem_fused_constant_pad_nd_38 padding kernels]

The trace shows that the Flex Attention impl. spends 2 seconds in the image encoder, whereas the baseline spends only 1.35 seconds. So it definitely looks like quite a slowdown in the part of the code where SDPA is used.

Padding does not seem to take nearly as much time in the baseline (in the trace, the largest purple blocks under the image encoder block are calls to _fwd_kernel_aligned, the top level attention function):

[screenshot: GPU trace of the baseline run; the largest purple blocks under the image encoder are calls to _fwd_kernel_aligned]

So it seems that the padding is a large source of the slowdown. As I mentioned earlier, the Triton kernel does the same padding, but they have somehow made it more efficient. At the top of the function, it says:

"""
Writing this as a composite allows torch.compile to fuse
needed padding into previous operations and memory
allocations. 
"""

So it seems like they somehow manage to make the padding more efficient by fusing it into earlier operations. I'm currently trying to figure out if this can also be done for the Flex Attention kernel, but it's not obvious to me how.

(NB: I also tried running the tracing with the annotations, as you suggested @cpuhrsch, but this did not seem to show up in the trace output - perhaps because of torch.compile?)

@cpuhrsch
Contributor Author

@tobiasvanderwerff - Hm, the way you're using FlexAttention, it should also be a composite (as in, flex_attention_fwd is a composite just like _attention_rel_h_rel_w, because it's composed of multiple functions as opposed to just a single kernel).

Since this padding is needed specifically for vit_h, does that mean the gap narrows for vit_b, or that vit_b is even faster with FlexAttention?

Also cc @Chillee and @drisspg

@tobiasvanderwerff
Contributor

tobiasvanderwerff commented Sep 23, 2024

@cpuhrsch

vit_b results show a similar gap between the baseline and Flex Attention. So even without padding, there is still a large diff in runtime!

Baseline:

device sam_model_type batch_size memory(MiB) memory(%) img_s(avg) batch_ms(avg)/batch_size mIoU use_compile use_half compress use_compile_decoder use_rel_pos pad_input_image_batch num_workers num_batches num_images profile_path memory_path
cuda vit_b 32 6631 8 87.1522144531224 11.47417775067416 0.5358536312719586 max-autotune torch.bfloat16 None False True True 24 154 4928 None None

Flex Attention:

device sam_model_type batch_size memory(MiB) memory(%) img_s(avg) batch_ms(avg)/batch_size mIoU use_compile use_half compress use_compile_decoder use_rel_pos pad_input_image_batch num_workers num_batches num_images profile_path memory_path
cuda vit_b 32 6969 8 38.57345144045247 25.92456631846242 0.536104508681229 max-autotune torch.bfloat16 None False True True 24 154 4928 None None

I looked at the profile traces, but it is difficult to extract any useful information. Most of the kernels in the Flex Attention version have nondescript names like triton_tem_fused_13 or triton_tem_fused_31, so it's hard to know what exactly the GPU is spending its time on.

@tobiasvanderwerff
Contributor

tobiasvanderwerff commented Sep 23, 2024

I may have found a clue as to where the performance bottleneck lies. Replacing this line in the score_mod function:

            attn_bias = self.rel_h[batch, head, q_idx, h_idx] + self.rel_w[batch, head, q_idx, w_idx]

with this:

            attn_bias = h_idx + w_idx

leads to a massive speedup (38 img/s -> 97 img/s). So it seems that the indexing into rel_h and rel_w is slowing things down a lot.

@tobiasvanderwerff
Contributor

Unfortunately, using rel_h and rel_w in a different way (like passing them to the function directly instead of setting them as class attributes) leads to TorchInductor errors under torch.compile. I've reached a point where I'm really not sure how to deal with this, so I've opened an issue in the Flex Attention repo with a reproduction. Hopefully, the Flex Attention authors can provide some more clarity.

@cpuhrsch
Contributor Author

Great, thank you for the investigation @tobiasvanderwerff !

@cpuhrsch
Contributor Author

@tobiasvanderwerff - For what it's worth, indexing into the rel_h and rel_w Tensors efficiently is a key reason why flash_4 can provide a speedup over SDPA to begin with. It's not a better implementation of SDPA, it just avoids the materialization of (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)).

@tobiasvanderwerff
Contributor

@cpuhrsch an update:

I've tried the fix pushed by @Chillee, but unfortunately I still get an error (see output below). It looks like the minified code sample I referred to in the issue does not quite transfer to the more complicated setup of the SAM-fast model. I'm not really sure how to resolve this right now, and unfortunately it is not very feasible for me to keep using an A100 for testing due to the cost (sorry). So the best strategy may be to put this on hold for now and wait until FlexAttention addresses this issue at some point.

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'get_stride'
  target: flex_attention
  args[0]: TensorBox(
    View(
      View(
        SliceView(
          View(
            StorageBox(
              ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
                'cuda',
                torch.bfloat16,
                def inner_fn(index):
                    i0, i1, i2, i3, i4 = index
                    tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
                    tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
                    tmp2 = tmp0 + tmp1
                    return tmp2
                ,
                ranges=[3, 800, 12, 196, 64],
                origin_node=clone_2,
                origins=OrderedSet([clone_2])
              ))
            ),
            size=[3, 9600, 196, 64],
            reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
            origins=OrderedSet([view_5, clone_2])
          ),
          size=[1, 9600, 196, 64],
          reindex=lambda i0, i1, i2, i3: [i0, i1, i2, i3],
          origins=OrderedSet([unbind])
        ),
        size=[9600, 196, 64],
        reindex=lambda i0, i1, i2: [0, i0, i1, i2],
        origins=OrderedSet([unbind])
      ),
      size=[800, 12, 196, 64],
      reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
      origins=OrderedSet([view_17])
    )
  )
  args[1]: TensorBox(
    View(
      View(
        SliceView(
          View(
            StorageBox(
              ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
                'cuda',
                torch.bfloat16,
                def inner_fn(index):
                    i0, i1, i2, i3, i4 = index
                    tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
                    tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
                    tmp2 = tmp0 + tmp1
                    return tmp2
                ,
                ranges=[3, 800, 12, 196, 64],
                origin_node=clone_2,
                origins=OrderedSet([clone_2])
              ))
            ),
            size=[3, 9600, 196, 64],
            reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
            origins=OrderedSet([view_5, clone_2])
          ),
          size=[1, 9600, 196, 64],
          reindex=lambda i0, i1, i2, i3: [i0 + 1, i1, i2, i3],
          origins=OrderedSet([unbind])
        ),
        size=[9600, 196, 64],
        reindex=lambda i0, i1, i2: [0, i0, i1, i2],
        origins=OrderedSet([unbind])
      ),
      size=[800, 12, 196, 64],
      reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
      origins=OrderedSet([view_18])
    )
  )
  args[2]: TensorBox(
    View(
      View(
        SliceView(
          View(
            StorageBox(
              ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
                'cuda',
                torch.bfloat16,
                def inner_fn(index):
                    i0, i1, i2, i3, i4 = index
                    tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
                    tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
                    tmp2 = tmp0 + tmp1
                    return tmp2
                ,
                ranges=[3, 800, 12, 196, 64],
                origin_node=clone_2,
                origins=OrderedSet([clone_2])
              ))
            ),
            size=[3, 9600, 196, 64],
            reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
            origins=OrderedSet([view_5, clone_2])
          ),
          size=[1, 9600, 196, 64],
          reindex=lambda i0, i1, i2, i3: [i0 + 2, i1, i2, i3],
          origins=OrderedSet([unbind])
        ),
        size=[9600, 196, 64],
        reindex=lambda i0, i1, i2: [0, i0, i1, i2],
        origins=OrderedSet([unbind])
      ),
      size=[800, 12, 196, 64],
      reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
      origins=OrderedSet([view_19])
    )
  )
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    ComputedBuffer(name='buf15', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.constant(1, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1],
      origin_node=full,
      origins=OrderedSet([full])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf16', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.constant(0, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1, 1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    ))
  )), None, None, TensorBox(StorageBox(
    ComputedBuffer(name='buf17', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.load(buf7, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1],
      origin_node=convert_element_type_11,
      origins=OrderedSet([convert_element_type_11, sum_1])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf18', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cuda',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.index_expr(0, dtype=torch.int16)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1, 1],
      origin_node=convert_element_type_12,
      origins=OrderedSet([convert_element_type_12, sort])
    ))
  )), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
  args[7]: (TensorBox(
    View(
      View(
        View(
          StorageBox(
            Pointwise(
              'cuda',
              torch.bfloat16,
              def inner_fn(index):
                  i0, i1, i2, i3, _ = index
                  tmp0 = ops.load(buf11, i3 + 16 * i2 + 224 * i0 + 2150400 * i1)
                  return tmp0
              ,
              ranges=[9600, 14, 14, 14, 1],
              origin_node=clone_4,
              origins=OrderedSet([clone_4])
            )
          ),
          size=[9600, 196, 14, 1],
          reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 14, 14), ModularIndexing(i1, 1, 14), i2, 0],
          origins=OrderedSet([view_15, clone_4])
        ),
        size=[800, 12, 196, 14, 1],
        reindex=lambda i0, i1, i2, i3, i4: [12*i0 + i1, i2, i3, 0],
        origins=OrderedSet([view_20])
      ),
      size=[800, 12, 196, 14],
      reindex=lambda i0, i1, i2, i3: [i0, i1, i2, i3, 0],
      origins=OrderedSet([squeeze])
    )
  ), TensorBox(
    View(
      View(
        View(
          StorageBox(
            Pointwise(
              'cuda',
              torch.bfloat16,
              def inner_fn(index):
                  i0, i1, i2, _, i4 = index
                  tmp0 = ops.load(buf14, i4 + 16 * i1 + 224 * i0 + 2150400 * i2)
                  return tmp0
              ,
              ranges=[9600, 14, 14, 1, 14],
              origin_node=clone_5,
              origins=OrderedSet([clone_5])
            )
          ),
          size=[9600, 196, 1, 14],
          reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 14, 14), ModularIndexing(i1, 1, 14), 0, i3],
          origins=OrderedSet([clone_5, view_16])
        ),
        size=[800, 12, 196, 1, 14],
        reindex=lambda i0, i1, i2, i3, i4: [12*i0 + i1, i2, 0, i4],
        origins=OrderedSet([view_21])
      ),
      size=[800, 12, 196, 14],
      reindex=lambda i0, i1, i2, i3: [i0, i1, i2, 0, i3],
      origins=OrderedSet([squeeze_1])
    )
  ))
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

@cpuhrsch
Contributor Author

cpuhrsch commented Oct 1, 2024

@tobiasvanderwerff - Thank you for testing this. I'll update pytorch-labs/attention-gym#45 as well. At least with the most recent fix we're one step closer.
