Add the copy op. (#227)
* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add a faster to_vec3 variant.
LaurentMazare authored Jul 23, 2023
1 parent 23827c4 commit fe87778
Showing 3 changed files with 72 additions and 40 deletions.
5 changes: 5 additions & 0 deletions candle-core/src/backprop.rs
@@ -82,6 +82,7 @@ impl Tensor {
                 }
             }
             Op::Reshape(node)
+            | Op::Copy(node)
             | Op::Broadcast(node)
             | Op::Cmp(node, _)
             | Op::Reduce(node, _, _)
@@ -246,6 +247,10 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                 }
+                Op::Copy(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad)?
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
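In the backward pass, `Op::Copy` is traversed like `Reshape` or `Broadcast`, and its gradient rule is the identity: the incoming gradient is accumulated into the source tensor's gradient unchanged. A minimal sketch of what this enables, assuming candle's later public API (`Var::new`, `sum_all`, `backward`, and `GradStore::get` are assumptions; the crate at this commit may expose slightly different names):

```rust
use candle_core::{Device, Var};

fn main() -> candle_core::Result<()> {
    // A variable, so gradients are tracked through the graph.
    let x = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `copy` is now recorded as Op::Copy rather than detaching the graph.
    let y = x.copy()?;
    let loss = y.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // d(sum(copy(x)^2))/dx = 2x: the copy is transparent to backprop.
    println!("{:?}", grads.get(&x));
    Ok(())
}
```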
1 change: 1 addition & 0 deletions candle-core/src/op.rs
@@ -89,6 +89,7 @@ pub(crate) enum Op {
         add: f64,
     },
     ToDType(Tensor),
+    Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
     Reshape(Tensor),
106 changes: 66 additions & 40 deletions candle-core/src/tensor.rs
@@ -1128,17 +1128,17 @@ impl Tensor {
             }
             .bt())?
         }
+        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+            let data = S::cpu_storage_as_slice(cpu_storage)?;
+            let data = match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => data[o1..o2].to_vec(),
+                None => self.strided_index().map(|i| data[i]).collect(),
+            };
+            Ok::<Vec<_>, Error>(data)
+        };
         match &*self.storage() {
-            Storage::Cpu(cpu_storage) => {
-                let data = S::cpu_storage_as_slice(cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
-            Storage::Cuda(slice) => {
-                // TODO: Would it be possible to only fetch the necessary data?
-                let cpu_storage = slice.to_cpu_storage()?;
-                let data = S::cpu_storage_as_slice(&cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
+            Storage::Cpu(storage) => from_cpu_storage(storage),
+            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
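`contiguous_offsets` evidently yields the `(start, end)` element range of the tensor in its flat buffer when the layout is contiguous, so `to_vec1` degenerates to a single slice copy instead of a per-element gather through `strided_index`. A hedged usage sketch (crate and method names follow candle's published API and are assumptions for this commit):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Freshly created tensors are contiguous, so to_vec1 takes the fast path.
    let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    assert_eq!(t.to_vec1::<f32>()?, [1., 2., 3., 4.]);

    // A transposed view has permuted strides and is not contiguous;
    // the strided fallback still handles it correctly.
    let m = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    assert_eq!(m.t()?.to_vec2::<f32>()?, [[1., 3.], [2., 4.]]);
    Ok(())
}
```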

@@ -1148,12 +1148,22 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx_row in 0..dim1 {
-                let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
-                rows.push(row)
-            }
-            assert!(src_index.next().is_none());
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    for idx_row in 0..dim1 {
+                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx_row in 0..dim1 {
+                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                        rows.push(row)
+                    }
+                    assert!(src_index.next().is_none());
+                }
+            }
             Ok(rows)
         };
         match &*self.storage() {
@@ -1168,16 +1178,32 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut top_rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx in 0..dim1 {
-                let mut rows = vec![];
-                for _jdx in 0..dim2 {
-                    let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
-                    rows.push(row)
-                }
-                top_rows.push(rows);
-            }
-            assert!(src_index.next().is_none());
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    let dim23 = dim2 * dim3;
+                    for idx1 in 0..dim1 {
+                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
+                        let mut rows = vec![];
+                        for idx2 in 0..dim2 {
+                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
+                        }
+                        top_rows.push(rows);
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx in 0..dim1 {
+                        let mut rows = vec![];
+                        for _jdx in 0..dim2 {
+                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+                            rows.push(row)
+                        }
+                        top_rows.push(rows);
+                    }
+                    assert!(src_index.next().is_none());
+                }
+            }
             Ok(top_rows)
         };
         match &*self.storage() {
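The contiguous fast path for `to_vec2`/`to_vec3` is plain nested slicing of the flat buffer: each outermost index owns a `dim2 * dim3` block, and each row within it is `dim3` wide. A standalone sketch of the same indexing arithmetic (illustrative only, not candle's code):

```rust
// Rebuild a (dim1, dim2, dim3) nested Vec from a flat, contiguous buffer.
fn to_vec3_fast<T: Copy>(data: &[T], dim1: usize, dim2: usize, dim3: usize) -> Vec<Vec<Vec<T>>> {
    assert_eq!(data.len(), dim1 * dim2 * dim3);
    data.chunks_exact(dim2 * dim3) // one dim2*dim3 block per outermost index
        .map(|block| block.chunks_exact(dim3).map(|row| row.to_vec()).collect())
        .collect()
}

fn main() {
    let flat: Vec<i32> = (0..12).collect();
    assert_eq!(
        to_vec3_fast(&flat, 2, 2, 3),
        [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]
    );
}
```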
@@ -1404,7 +1430,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
-            op: None, // TODO
+            op: Some(Op::Copy(self.clone())),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1540,7 +1566,7 @@ impl Tensor {
         Ok(from_storage(
             storage,
             shape.clone(),
-            None, // TODO
+            Some(Op::Copy(self.clone())),
             false,
         ))
     }
@@ -1734,7 +1760,6 @@ impl Tensor {
         for (arg_idx, arg) in args.iter().enumerate() {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
-                // TODO: Improve the error message.
                 Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
@@ -1743,15 +1768,21 @@ impl Tensor {
                 .bt())?
             }
             if arg.device().location() != device.location() {
-                // TODO: Improve the error message.
                 Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
                     op: "cat",
                 }
                 .bt())?
             }
-            let mut mismatch = arg.rank() != rank;
+            if rank != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: rank,
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
             for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
@@ -1763,20 +1794,15 @@ impl Tensor {
                     cat_dims[0] += v2;
                 }
                 if dim_idx != 0 && v1 != v2 {
-                    // TODO: It would probably be good to have a nicer error message here, i.e.
-                    // mention the problematic dimension and the values.
-                    mismatch = true;
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
                 }
             }
-            if mismatch {
-                Err(Error::ShapeMismatchCat {
-                    dim: 0, // TODO: not the appropriate error message
-                    first_shape: arg0.shape().clone(),
-                    n: arg_idx + 1,
-                    nth_shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
         }
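With these checks, `cat` now rejects rank mismatches up front with `UnexpectedNumberOfDims` and reports the exact offending dimension in `ShapeMismatchCat` instead of the placeholder `dim: 0`. A hedged usage sketch (`Tensor::zeros`, `Tensor::cat`, and `dims` follow candle's published API and are assumptions for this commit):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;

    // Non-cat dimensions must match. Concatenating on dim 0 with 3 vs 4 in
    // dim 1 now fails with ShapeMismatchCat naming dim 1 and both shapes.
    assert!(Tensor::cat(&[&a, &b], 0).is_err());

    // Along dim 1 the shapes are compatible: (2, 3) ++ (2, 4) -> (2, 7).
    let c = Tensor::cat(&[&a, &b], 1)?;
    assert_eq!(c.dims(), &[2, 7]);
    Ok(())
}
```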
