Fix the attn mask generation.
LaurentMazare committed Sep 23, 2024
1 parent 4e444d7 commit fd53b96
Showing 2 changed files with 45 additions and 33 deletions.
43 changes: 42 additions & 1 deletion candle-nn/src/kv_cache.rs
@@ -1,4 +1,4 @@
-use candle::{Result, Tensor};
+use candle::{Device, Result, Tensor};
 
 #[derive(Debug, Clone)]
 pub struct Cache {
@@ -255,6 +255,43 @@ impl RotatingCache {
             }
         }
     }
+
+    fn get_mask(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
+        let context = self.max_seq_len;
+        let upd_offset = (self.offset + size1) % self.max_seq_len;
+        let mask: Vec<_> = (0..size1)
+            .flat_map(|pos_src| {
+                // The absolute position of the elements that will get added to the cache.
+                let pos_src = self.current_seq_len + pos_src;
+                (0..size2).map(move |pos_cache_rel| {
+                    // The absolute position of the cache elements after the addition.
+                    let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
+                    let pos_cache = if pos_cache_rel < upd_offset {
+                        pos_cache
+                    } else {
+                        pos_cache - self.max_seq_len
+                    };
+                    u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (size1, size2), device)
+    }
+
+    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
+    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+        let mask = if seq_len == 1 {
+            None
+        } else {
+            let cache_out_len = if seq_len < self.max_seq_len {
+                (self.current_seq_len + seq_len).min(self.max_seq_len)
+            } else {
+                seq_len
+            };
+            Some(self.get_mask(seq_len, cache_out_len, device)?)
+        };
+        Ok(mask)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -308,6 +345,10 @@ impl RotatingKvCache {
         self.k.current_seq_len()
     }
 
+    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+        self.k.attn_mask(seq_len, device)
+    }
+
     pub fn reset(&mut self) {
         self.k.reset();
         self.v.reset();
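
Note: to make the new masking logic concrete, here is a small standalone sketch of the same computation outside of `RotatingCache`. The free function, its parameter names, and the toy sizes are illustrative only, and it returns a plain `Vec<u8>` instead of building a `Tensor`; it mirrors `get_mask` together with the `cache_out_len` choice made in `attn_mask`.

// Standalone sketch of the rotation-aware mask (1 = masked, 0 = attendable).
fn rotating_mask(
    size1: usize,           // number of tokens being added (seq_len)
    size2: usize,           // number of cache slots covered by the mask (cache_out_len)
    offset: usize,          // write offset of the rotating buffer before the addition
    current_seq_len: usize, // tokens seen so far, before the addition
    max_seq_len: usize,     // buffer capacity, also used as the context length
) -> Vec<u8> {
    let context = max_seq_len;
    // Write offset once the new tokens are in place.
    let upd_offset = (offset + size1) % max_seq_len;
    (0..size1)
        .flat_map(|pos_src| {
            // Absolute position of the token being added.
            let pos_src = current_seq_len + pos_src;
            (0..size2).map(move |pos_cache_rel| {
                // Absolute position stored in slot `pos_cache_rel` after the addition.
                let pos_cache = current_seq_len + size1 + pos_cache_rel - upd_offset;
                let pos_cache = if pos_cache_rel < upd_offset {
                    pos_cache
                } else {
                    pos_cache - max_seq_len
                };
                // Mask future positions and positions that fell out of the context window.
                u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
            })
        })
        .collect()
}

fn main() {
    // Capacity-4 cache that has already seen 6 tokens (offset = 6 % 4 = 2), now
    // receiving 3 more; attn_mask would then expose min(6 + 3, 4) = 4 slots.
    // After the addition the slots hold absolute positions [8, 5, 6, 7].
    for row in rotating_mask(3, 4, 2, 6, 4).chunks(4) {
        println!("{row:?}");
    }
    // Prints:
    // [1, 0, 0, 1]   token 6 must not see 7 or 8
    // [1, 0, 0, 0]   token 7 must not see 8
    // [0, 0, 0, 0]   token 8 sees the whole window
}
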
35 changes: 3 additions & 32 deletions candle-transformers/src/models/mimi/transformer.rs
@@ -101,21 +101,6 @@ impl Module for LayerScale {
     }
 }
 
-pub(crate) fn get_mask(
-    size1: usize,
-    size2: usize,
-    context: usize,
-    device: &Device,
-) -> Result<Tensor> {
-    let mask: Vec<_> = (0..size1)
-        .flat_map(|i| {
-            (0..size2)
-                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
-        })
-        .collect();
-    Tensor::from_slice(&mask, (size1, size2), device)
-}
-
 #[derive(Debug, Clone)]
 pub struct StreamingMultiheadAttention {
     q_proj: Linear,
@@ -590,7 +575,6 @@ impl StreamingTransformerLayer {
 #[derive(Debug, Clone)]
 pub struct StreamingTransformer {
     layers: Vec<StreamingTransformerLayer>,
-    context: usize,
     positional_embedding: PositionalEmbedding,
     max_period: usize,
 }
@@ -617,7 +601,6 @@ impl StreamingTransformer {
         }
         Ok(Self {
             layers,
-            context: cfg.context,
             positional_embedding: cfg.positional_embedding,
             max_period: cfg.max_period,
         })
@@ -629,23 +612,11 @@ impl StreamingTransformer {
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        let pos = self.layers[0]
+        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
+        let mask = self.layers[0]
             .self_attn
             .kv_cache
-            .k_cache()
-            .current_seq_len();
-        let mask = if t == 1 {
-            None
-        } else {
-            let cache_out_len = if t < self.context {
-                (pos + t).min(self.context)
-            } else {
-                t
-            };
-            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
-            // nature.
-            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
-        };
+            .attn_mask(t, xs.device())?;
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
             PositionalEmbedding::Sin => {
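
Note: the removed `get_mask` built the mask from relative positions only, implicitly assuming that cache slot `j` holds the `j`-th oldest visible token, which is what the deleted TODO pointed out: once the rotating buffer wraps, the slots are no longer in chronological order. The sketch below replays the old computation on the same toy case as above (capacity 4, context assumed equal to the capacity, tokens 6 to 8 being added while the slots end up holding [8, 5, 6, 7]); the helper name and sizes are illustrative and the `Tensor` construction is dropped so it runs standalone.

// Old, offset-unaware mask from the removed helper (1 = masked).
fn old_mask(size1: usize, size2: usize, context: usize) -> Vec<u8> {
    (0..size1)
        .flat_map(|i| {
            (0..size2)
                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
        })
        .collect()
}

fn main() {
    // Same toy case as the sketch above: 3 new tokens, 4 slots, context = 4.
    for row in old_mask(3, 4, 4).chunks(4) {
        println!("{row:?}");
    }
    // Prints:
    // [0, 0, 1, 1]
    // [0, 0, 0, 1]
    // [0, 0, 0, 0]
    // Row 0 leaves slot 0 unmasked, but after wrapping slot 0 holds the future token 8,
    // so the token at position 6 would attend ahead of itself; the rotation-aware mask
    // gives [1, 0, 0, 1] for that row instead.
}
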
