Skip to content

Commit

Permalink
Handle the abs case.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 23, 2024
1 parent fd53b96 commit 4244d59
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions candle-nn/src/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,19 @@ impl RotatingCache {
}
}

fn get_mask(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}

fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let upd_offset = (self.offset + size1) % self.max_seq_len;
let mask: Vec<_> = (0..size1)
Expand All @@ -283,12 +295,13 @@ impl RotatingCache {
let mask = if seq_len == 1 {
None
} else {
let cache_out_len = if seq_len < self.max_seq_len {
(self.current_seq_len + seq_len).min(self.max_seq_len)
let mask = if seq_len < self.max_seq_len {
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
self.get_mask_rel(seq_len, cache_out_len, device)?
} else {
seq_len
self.get_mask_abs(seq_len, seq_len, device)?
};
Some(self.get_mask(seq_len, cache_out_len, device)?)
Some(mask)
};
Ok(mask)
}
Expand Down

0 comments on commit 4244d59

Please sign in to comment.