Skip to content

Commit

Permalink
Optimize for the contiguous case.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 23, 2023
1 parent 7fcb16a commit 17f8d0f
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,19 +870,38 @@ impl<'a> Map1 for Embedding<'a> {
}
let vs = &vs[layout.start_offset()..];
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
// TODO: Optimize for the case where ids are contiguous.
for index in self.ids_l.strided_index() {
let index = self.ids[index].try_into()?;
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
match self.ids_l.contiguous_offsets() {
Some((o1, o2)) => {
for &index in self.ids[o1..o2].iter() {
let index = index as usize;
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
}
None => {
for index in self.ids_l.strided_index() {
let index = self.ids[index].try_into()?;
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
Expand Down

0 comments on commit 17f8d0f

Please sign in to comment.