Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a RotatingKVCache. #2493

Merged
merged 14 commits into from
Sep 23, 2024
Merged
44 changes: 39 additions & 5 deletions candle-examples/examples/mimi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ struct Args {
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,

/// Whether to use streaming or not, when streaming slices of data of the given size are passed
/// to the encoder/decoder one at a time.
#[arg(long)]
streaming: Option<usize>,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -87,20 +92,49 @@ fn main() -> Result<()> {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
match args.streaming {
Some(chunk_size) => {
let mut code_chunks = vec![];
for pcm in pcm.chunks(chunk_size) {
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
let code_chunk = model.encode(&pcm)?;
code_chunks.push(code_chunk)
}
Tensor::cat(&code_chunks, candle::D::Minus1)?
}
None => {
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
}
}
};
println!("codes shape: {:?}", codes.shape());
model.reset_state();

match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
let pcm = match args.streaming {
Some(chunk_size) => {
let seq_len = codes.dim(candle::D::Minus1)?;
let mut pcm_chunks = vec![];
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
None => model.decode(&codes)?,
};
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
Expand Down
224 changes: 223 additions & 1 deletion candle-nn/src/kv_cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candle::{Result, Tensor};
use candle::{Device, Result, Tensor};

#[derive(Debug, Clone)]
pub struct Cache {
Expand Down Expand Up @@ -145,3 +145,225 @@ impl KvCache {
self.v.reset();
}
}

#[derive(Debug, Clone)]
pub struct RotatingCache {
all_data: Option<Tensor>,
dim: usize,
// `offset` is the current write index in the buffer
offset: usize,
// The total size of the sequence seen so far.
current_seq_len: usize,
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
max_seq_len: usize,
}

impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
}
}

pub fn offset(&self) -> usize {
self.offset
}

pub fn dim(&self) -> usize {
self.dim
}

pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}

pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}

pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}

pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}

pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}

pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();

self.current_seq_len += seq_len;
if seq_len >= self.max_seq_len {
let to_copy = src
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
.contiguous()?;
ad.slice_set(&to_copy, self.dim, 0)?;
self.offset = 0;
// Here we return `src` rather than `ad` so that all the past can be used.
Ok(src.clone())
} else {
let rem_len = self.max_seq_len - self.offset;
if seq_len <= rem_len {
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src
.narrow(self.dim, rem_len, seq_len - rem_len)?
.contiguous()?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
if self.current_seq_len >= self.max_seq_len {
Ok(ad.clone())
} else {
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
}
}
}

fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}

fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let upd_offset = (self.offset + size1) % self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|pos_src| {
// The absolute position of the elements that will get added to the cache.
let pos_src = self.current_seq_len + pos_src;
(0..size2).map(move |pos_cache_rel| {
// The absolute position of the cache elements after the addition.
let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
let pos_cache = if pos_cache_rel < upd_offset {
pos_cache
} else {
pos_cache - self.max_seq_len
};
u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}

/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
let mask = if seq_len == 1 {
None
} else {
let mask = if seq_len < self.max_seq_len {
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
self.get_mask_rel(seq_len, cache_out_len, device)?
} else {
self.get_mask_abs(seq_len, seq_len, device)?
};
Some(mask)
};
Ok(mask)
}
}

#[derive(Debug, Clone)]
pub struct RotatingKvCache {
k: RotatingCache,
v: RotatingCache,
}

impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
let v = RotatingCache::new(dim, max_seq_len);
Self { k, v }
}

pub fn k_cache(&self) -> &RotatingCache {
&self.k
}

pub fn v_cache(&self) -> &RotatingCache {
&self.v
}

pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.k
}

pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.v
}

pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}

pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}

pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
let out_k = self.k.append(k)?;
let out_v = self.v.append(v)?;
Ok((out_k, out_v))
}

pub fn offset(&self) -> usize {
self.k.offset()
}

pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}

pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
self.k.attn_mask(seq_len, device)
}

pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}
Loading
Loading