
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 18, 2024
1 parent da7d032 commit 22391ae
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -222,7 +222,8 @@ def run_dpa_with_cp(
         seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
         q_, k_, v_, dout_ = [
-            x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
+            x.reshape(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
+            for x in [q_, k_, v_, dout_]
         ]
     elif qkv_format == "thd":
         seq_idx_q = tex.thd_get_partitioned_indices(
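Note: this hunk is a pure black-style reflow of a long list comprehension; behavior is unchanged. Below is a standalone sketch (not from the commit, with made-up shapes and an assumed "bshd"-style seq_dim = 1) of what the index_select + reshape pair does: keep one rank's two sequence chunks, then fold the chunk axis back into the sequence axis.

import torch

world_size, rank = 4, 1
batch, heads, dim = 2, 3, 16
chunk_len = 4   # length of each of the 2 * world_size sequence chunks
seq_dim = 1     # sequence axis in a "bshd"-style layout

# q arrives pre-split into 2 * world_size chunks along the sequence axis:
# (batch, num_chunks, chunk_len, heads, dim)
q = torch.randn(batch, 2 * world_size, chunk_len, heads, dim)

# Each rank keeps chunk `rank` and its mirror chunk, as in the test.
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q.device)
q = q.index_select(seq_dim, seq_idx)  # -> (2, 2, 4, 3, 16)

# Fold the chunk axis and per-chunk length back into one sequence axis.
q = q.reshape(*q.shape[:seq_dim], -1, *q.shape[(seq_dim + 2):])
print(q.shape)  # torch.Size([2, 8, 3, 16])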
4 changes: 3 additions & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -1845,7 +1845,9 @@ def _run_ref_mha_f16(dtype, config, backend):
     cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
     out_grad = (
-        torch.load("out_grad.pt").to(device="cuda").reshape(config.batch_size, config.max_seqlen_q, -1)
+        torch.load("out_grad.pt")
+        .to(device="cuda")
+        .reshape(config.batch_size, config.max_seqlen_q, -1)
     )

 _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
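This hunk likewise only re-wraps a fluent call chain, one call per line inside the enclosing parentheses. A toy equivalence check (a stand-in tensor instead of the test's "out_grad.pt" checkpoint):

import torch

t = torch.arange(24, dtype=torch.float32)

one_line = t.to(dtype=torch.float64).reshape(2, 3, -1)
multi_line = (
    t.to(dtype=torch.float64)
    .reshape(2, 3, -1)
)
assert torch.equal(one_line, multi_line)  # same result; the reflow is purely syntactic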
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/attention.py
@@ -9031,4 +9031,4 @@ def forward(
             outputs += (attention_bias,)
         if self.input_layernorm and self.return_layernorm_output:
             outputs += (layernorm_output,)
-        return outputs if len(outputs) > 1 else outputs[0]
\ No newline at end of file
+        return outputs if len(outputs) > 1 else outputs[0]
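The removed and added lines here are textually identical: the only change is a newline added at end of file, which is what pre-commit's end-of-file-fixer hook enforces (an inference from the diff; the repo's hook list isn't shown on this page). A minimal sketch of that check, with a hypothetical helper name:

from pathlib import Path

def ends_with_newline(path: str) -> bool:
    """Return True if the file's last byte is a newline (the property the hook enforces)."""
    return Path(path).read_bytes().endswith(b"\n")

# e.g. ends_with_newline("transformer_engine/pytorch/attention.py") should be True after this commit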
