From 9efb7b791e42cebab12e354e969d92bc4bfcb174 Mon Sep 17 00:00:00 2001 From: laurent Date: Fri, 21 Jul 2023 12:43:12 +0100 Subject: [PATCH] Add some epsilon tolerance to grad tests so that they work on cuda / mkl. --- candle-core/src/tensor.rs | 5 +++++ candle-core/tests/grad_tests.rs | 16 ++++++++-------- candle-core/tests/test_utils.rs | 17 +++++++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 42d660f442..05791ed100 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -482,6 +482,11 @@ impl Tensor { } } + /// An alias for `to_scalar`. + pub fn to_vec0(&self) -> Result { + self.to_scalar::() + } + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// The input values `mul` and `add` are casted to the appropriate type so some rounding might /// be performed. diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 6f30b5b718..591b504ab4 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -114,23 +114,23 @@ fn unary_grad(device: &Device) -> Result<()> { let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [0.14112, 0.84147096, -0.7568025, 0.14943814], + test_utils::to_vec1_round(&y, 4)?, + [0.1411, 0.8415, -0.7568, 0.1494], ); assert_eq!( - grad_x.to_vec1::()?, - [-0.9899925, 0.5403023, -0.6536436, 0.9887711], + test_utils::to_vec1_round(grad_x, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], ); let y = x.cos()?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [-0.9899925, 0.5403023, -0.6536436, 0.9887711], + test_utils::to_vec1_round(&y, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], ); assert_eq!( - grad_x.to_vec1::()?, - [-0.14112, -0.84147096, 0.7568025, -0.14943814], + test_utils::to_vec1_round(grad_x, 4)?, + [-0.1411, -0.8415, 0.7568, -0.1494], ); let y = x.sqr()?; let grads = y.backward()?; diff --git a/candle-core/tests/test_utils.rs b/candle-core/tests/test_utils.rs index 4dd44b64f5..5f7d311793 100644 --- a/candle-core/tests/test_utils.rs +++ b/candle-core/tests/test_utils.rs @@ -20,6 +20,23 @@ macro_rules! test_device { }; } +pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result> { + let b = 10f32.powi(digits); + let t = t.to_vec1::()?; + let t = t.iter().map(|t| f32::round(t * b) / b).collect(); + Ok(t) +} + +pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) +} + pub fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { let b = 10f32.powi(digits); let t = t.to_vec3::()?;