Commit

use softmax_last_dim (metal and cuda kernel) in llama attention layer
zackangelo committed Oct 23, 2024
1 parent 7c09215 · commit 29a3820
Showing 1 changed file with 2 additions and 1 deletion.
candle-transformers/src/models/llama.rs — 2 additions & 1 deletion
@@ -341,7 +341,8 @@ impl CausalSelfAttention {
                 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
                 masked_fill(&att, &mask, f32::NEG_INFINITY)?
             };
-            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
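For context, softmax_last_dim computes the same values as softmax over D::Minus1; the difference is that on Metal and CUDA it dispatches to a dedicated fused kernel rather than composing several elementwise ops. The sketch below is not part of the commit — it is a minimal standalone check, assuming candle-core and candle-nn as dependencies, with an arbitrarily shaped tensor standing in for the attention scores.

    // Minimal sketch (not from the commit): softmax_last_dim is a drop-in
    // replacement for softmax(&att, D::Minus1) on the last dimension.
    use candle_core::{D, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu; // CPU here; Metal/CUDA devices use the fused kernel
        // Arbitrary (batch, heads, seq) shape standing in for attention scores.
        let att = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?;

        let reference = candle_nn::ops::softmax(&att, D::Minus1)?;
        let fused = candle_nn::ops::softmax_last_dim(&att)?;

        // The two results should agree up to floating-point rounding.
        let diff = (&reference - &fused)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-5);
        Ok(())
    }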
