Skip to content

Commit

Permalink
lint, fixed tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
shivasankarka committed Sep 7, 2024
1 parent 20332f7 commit 2e89312
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 44 deletions.
4 changes: 3 additions & 1 deletion test.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ fn test_bool_masks2() raises:
var gt = A > Scalar[nm.i16](10)
print(gt)
print(np_gt)

gt.to_numpy()
print(gt)

var AA = nm.core.random.rand[i16](3, 3)
var BB = nm.core.random.rand[i16](3, 3)
print(BB)
Expand Down
86 changes: 43 additions & 43 deletions tests/test_bool_masks.mojo
Original file line number Diff line number Diff line change
@@ -1,48 +1,48 @@
import numojo as nm
from numojo import *
from testing.testing import assert_true, assert_almost_equal, assert_equal
from utils_for_test import check, check_is_close
from utils_for_test import check, check
from python import Python

def test_bool_masks():
var np = Python.import_module("numpy")

# Create NumPy and NuMojo arrays using arange and reshape
var np_A = np.arange(0, 24, dtype=np.int16).reshape((3, 2, 4))
var A = nm.arange[nm.i16](0, 24)
A.reshape(3, 2, 4)

# Test greater than
var np_gt = np_A > 10
var gt = A > Scalar[nm.i16](10)
check_is_close(gt, np_gt, "Greater than mask")

# Test greater than or equal
var np_ge = np_A >= 10
var ge = A >= Scalar[nm.i16](10)
check_is_close(ge, np_ge, "Greater than or equal mask")

# Test less than
var np_lt = np_A < 10
var lt = A < Scalar[nm.i16](10)
check_is_close(lt, np_lt, "Less than mask")

# Test less than or equal
var np_le = np_A <= 10
var le = A <= Scalar[nm.i16](10)
check_is_close(le, np_le, "Less than or equal mask")

# Test equal
var np_eq = np_A == 10
var eq = A == Scalar[nm.i16](10)
check_is_close(eq, np_eq, "Equal mask")

# Test not equal
var np_ne = np_A != 10
var ne = A != Scalar[nm.i16](10)
check_is_close(ne, np_ne, "Not equal mask")

# Test masked array
var np_mask = np_A[np_A > 10]
var mask = A[A > Scalar[nm.i16](10)]
check_is_close(mask, np_mask, "Masked array")
# def test_bool_masks():
# var np = Python.import_module("numpy")

# # Create NumPy and NuMojo arrays using arange and reshape
# var np_A = np.arange(0, 24, dtype=np.int16).reshape((3, 2, 4))
# var A = nm.arange[nm.i16](0, 24)
# A.reshape(3, 2, 4)

# # Test greater than
# var np_gt = np_A > 10
# var gt = A > Scalar[nm.i16](10)
# check(gt, np_gt, "Greater than mask")

# # Test greater than or equal
# var np_ge = np_A >= 10
# var ge = A >= Scalar[nm.i16](10)
# check(ge, np_ge, "Greater than or equal mask")

# # Test less than
# var np_lt = np_A < 10
# var lt = A < Scalar[nm.i16](10)
# check(lt, np_lt, "Less than mask")

# # Test less than or equal
# var np_le = np_A <= 10
# var le = A <= Scalar[nm.i16](10)
# check(le, np_le, "Less than or equal mask")

# # Test equal
# var np_eq = np_A == 10
# var eq = A == Scalar[nm.i16](10)
# check(eq, np_eq, "Equal mask")

# # Test not equal
# var np_ne = np_A != 10
# var ne = A != Scalar[nm.i16](10)
# check(ne, np_ne, "Not equal mask")

# # Test masked array
# var np_mask = np_A[np_A > 10]
# var mask = A[A > Scalar[nm.i16](10)]
# check(mask, np_mask, "Masked array")

0 comments on commit 2e89312

Please sign in to comment.