diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index 9e860d612..4ca1a81da 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -1,4 +1,4 @@
-use candle::{Result, Tensor};
+use candle::{Device, Result, Tensor};
 
 #[derive(Debug, Clone)]
 pub struct Cache {
@@ -255,6 +255,43 @@ impl RotatingCache {
             }
         }
     }
+
+    fn get_mask(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
+        let context = self.max_seq_len;
+        let upd_offset = (self.offset + size1) % self.max_seq_len;
+        let mask: Vec<_> = (0..size1)
+            .flat_map(|pos_src| {
+                // The absolute position of the elements that will get added to the cache.
+                let pos_src = self.current_seq_len + pos_src;
+                (0..size2).map(move |pos_cache_rel| {
+                    // The absolute position of the cache elements after the addition.
+                    let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
+                    let pos_cache = if pos_cache_rel < upd_offset {
+                        pos_cache
+                    } else {
+                        pos_cache - self.max_seq_len
+                    };
+                    u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (size1, size2), device)
+    }
+
+    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
+    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+        let mask = if seq_len == 1 {
+            None
+        } else {
+            let cache_out_len = if seq_len < self.max_seq_len {
+                (self.current_seq_len + seq_len).min(self.max_seq_len)
+            } else {
+                seq_len
+            };
+            Some(self.get_mask(seq_len, cache_out_len, device)?)
+        };
+        Ok(mask)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -308,6 +345,10 @@ impl RotatingKvCache {
         self.k.current_seq_len()
     }
 
+    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+        self.k.attn_mask(seq_len, device)
+    }
+
     pub fn reset(&mut self) {
         self.k.reset();
         self.v.reset();
diff --git a/candle-transformers/src/models/mimi/transformer.rs b/candle-transformers/src/models/mimi/transformer.rs
index 8a59606e5..6ccbc8d12 100644
--- a/candle-transformers/src/models/mimi/transformer.rs
+++ b/candle-transformers/src/models/mimi/transformer.rs
@@ -101,21 +101,6 @@ impl Module for LayerScale {
     }
 }
 
-pub(crate) fn get_mask(
-    size1: usize,
-    size2: usize,
-    context: usize,
-    device: &Device,
-) -> Result<Tensor> {
-    let mask: Vec<_> = (0..size1)
-        .flat_map(|i| {
-            (0..size2)
-                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
-        })
-        .collect();
-    Tensor::from_slice(&mask, (size1, size2), device)
-}
-
 #[derive(Debug, Clone)]
 pub struct StreamingMultiheadAttention {
     q_proj: Linear,
@@ -590,7 +575,6 @@ impl StreamingTransformerLayer {
 #[derive(Debug, Clone)]
 pub struct StreamingTransformer {
     layers: Vec<StreamingTransformerLayer>,
-    context: usize,
     positional_embedding: PositionalEmbedding,
     max_period: usize,
 }
@@ -617,7 +601,6 @@ impl StreamingTransformer {
         }
         Ok(Self {
             layers,
-            context: cfg.context,
             positional_embedding: cfg.positional_embedding,
             max_period: cfg.max_period,
         })
@@ -629,23 +612,11 @@ pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        let pos = self.layers[0]
+        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
+        let mask = self.layers[0]
             .self_attn
             .kv_cache
-            .k_cache()
-            .current_seq_len();
-        let mask = if t == 1 {
-            None
-        } else {
-            let cache_out_len = if t < self.context {
-                (pos + t).min(self.context)
-            } else {
-                t
-            };
-            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
-            // nature.
-            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
-        };
+            .attn_mask(t, xs.device())?;
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
             PositionalEmbedding::Sin => {