diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 2684a4b8ef..3e0ef481b1 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -870,19 +870,38 @@ impl<'a> Map1 for Embedding<'a> { } let vs = &vs[layout.start_offset()..]; let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size); - // TODO: Optimize for the case where ids are contiguous. - for index in self.ids_l.strided_index() { - let index = self.ids[index].try_into()?; - if index >= self.vocab_size { - Err(Error::InvalidIndex { - index, - size: self.vocab_size, - op: "take", + match self.ids_l.contiguous_offsets() { + Some((o1, o2)) => { + for &index in self.ids[o1..o2].iter() { + let index = index as usize; + if index >= self.vocab_size { + Err(Error::InvalidIndex { + index, + size: self.vocab_size, + op: "take", + } + .bt())? + } else { + let hidden_size = self.hidden_size; + values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); + } + } + } + None => { + for index in self.ids_l.strided_index() { + let index = self.ids[index].try_into()?; + if index >= self.vocab_size { + Err(Error::InvalidIndex { + index, + size: self.vocab_size, + op: "take", + } + .bt())? + } else { + let hidden_size = self.hidden_size; + values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); + } } - .bt())? - } else { - let hidden_size = self.hidden_size; - values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); } } Ok(values)