Commit 9efb7b7

Add some epsilon tolerance to grad tests so that they work on cuda / mkl.

LaurentMazare committed Jul 21, 2023
1 parent 4106545 commit 9efb7b7
Showing 3 changed files with 30 additions and 8 deletions.
5 changes: 5 additions & 0 deletions candle-core/src/tensor.rs
@@ -482,6 +482,11 @@ impl Tensor {
        }
    }

+    /// An alias for `to_scalar`.
+    pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
+        self.to_scalar::<S>()
+    }
+
    /// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
    /// The input values `mul` and `add` are casted to the appropriate type so some rounding might
    /// be performed.
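
For context (not part of the commit), a minimal sketch of how the new `to_vec0` alias can be used; the function name, the scalar value, and the test-style signature are made up for illustration:

// Hypothetical check, written in the style of the existing device tests.
fn to_vec0_alias(device: &Device) -> Result<()> {
    let t = Tensor::new(0.1234f32, device)?;
    // `to_vec0` simply forwards to `to_scalar`, so both reads agree.
    assert_eq!(t.to_scalar::<f32>()?, t.to_vec0::<f32>()?);
    Ok(())
}
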
16 changes: 8 additions & 8 deletions candle-core/tests/grad_tests.rs
@@ -114,23 +114,23 @@ fn unary_grad(device: &Device) -> Result<()> {
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
-        y.to_vec1::<f32>()?,
-        [0.14112, 0.84147096, -0.7568025, 0.14943814],
+        test_utils::to_vec1_round(&y, 4)?,
+        [0.1411, 0.8415, -0.7568, 0.1494],
    );
    assert_eq!(
-        grad_x.to_vec1::<f32>()?,
-        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [-0.99, 0.5403, -0.6536, 0.9888],
    );
    let y = x.cos()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
-        y.to_vec1::<f32>()?,
-        [-0.9899925, 0.5403023, -0.6536436, 0.9887711],
+        test_utils::to_vec1_round(&y, 4)?,
+        [-0.99, 0.5403, -0.6536, 0.9888],
    );
    assert_eq!(
-        grad_x.to_vec1::<f32>()?,
-        [-0.14112, -0.84147096, 0.7568025, -0.14943814],
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [-0.1411, -0.8415, 0.7568, -0.1494],
    );
    let y = x.sqr()?;
    let grads = y.backward()?;
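
Why rounding to 4 digits gives the intended epsilon tolerance (a sketch, not part of the commit; the second value is a made-up example of a slightly different CUDA/MKL result):

// Hypothetical illustration: a GPU/MKL kernel might return a value that
// differs from the CPU one in the last few bits.
let cpu_value = 0.9887711f32;
let accel_value = 0.98877114f32; // made-up near-identical backend result
let round4 = |v: f32| f32::round(v * 1e4) / 1e4;
// After rounding to 4 digits both sides compare equal, so the assertion
// passes on every backend.
assert_eq!(round4(cpu_value), round4(accel_value));
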
17 changes: 17 additions & 0 deletions candle-core/tests/test_utils.rs
@@ -20,6 +20,23 @@ macro_rules! test_device {
    };
}

+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec1::<f32>()?;
+    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+    Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec2::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+        .collect();
+    Ok(t)
+}
+
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec3::<f32>()?;
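
A usage sketch for the new helpers (not part of the commit); the tensor contents are made up and `test_utils` is assumed to be in scope via `mod test_utils;` as in grad_tests.rs:

// `to_vec2_round` rounds every element to `digits` decimal places before
// the comparison, absorbing small cross-backend differences.
let t = Tensor::new(&[[0.123456f32, 1.0], [2.5, -0.333333]], &Device::Cpu)?;
assert_eq!(
    test_utils::to_vec2_round(&t, 3)?,
    vec![vec![0.123, 1.0], vec![2.5, -0.333]],
);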
