-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
20332f7
commit 2e89312
Showing
2 changed files
with
46 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,48 @@ | ||
import numojo as nm | ||
from numojo import * | ||
from testing.testing import assert_true, assert_almost_equal, assert_equal | ||
from utils_for_test import check, check_is_close | ||
from utils_for_test import check, check | ||
from python import Python | ||
|
||
def test_bool_masks(): | ||
var np = Python.import_module("numpy") | ||
|
||
# Create NumPy and NuMojo arrays using arange and reshape | ||
var np_A = np.arange(0, 24, dtype=np.int16).reshape((3, 2, 4)) | ||
var A = nm.arange[nm.i16](0, 24) | ||
A.reshape(3, 2, 4) | ||
|
||
# Test greater than | ||
var np_gt = np_A > 10 | ||
var gt = A > Scalar[nm.i16](10) | ||
check_is_close(gt, np_gt, "Greater than mask") | ||
|
||
# Test greater than or equal | ||
var np_ge = np_A >= 10 | ||
var ge = A >= Scalar[nm.i16](10) | ||
check_is_close(ge, np_ge, "Greater than or equal mask") | ||
|
||
# Test less than | ||
var np_lt = np_A < 10 | ||
var lt = A < Scalar[nm.i16](10) | ||
check_is_close(lt, np_lt, "Less than mask") | ||
|
||
# Test less than or equal | ||
var np_le = np_A <= 10 | ||
var le = A <= Scalar[nm.i16](10) | ||
check_is_close(le, np_le, "Less than or equal mask") | ||
|
||
# Test equal | ||
var np_eq = np_A == 10 | ||
var eq = A == Scalar[nm.i16](10) | ||
check_is_close(eq, np_eq, "Equal mask") | ||
|
||
# Test not equal | ||
var np_ne = np_A != 10 | ||
var ne = A != Scalar[nm.i16](10) | ||
check_is_close(ne, np_ne, "Not equal mask") | ||
|
||
# Test masked array | ||
var np_mask = np_A[np_A > 10] | ||
var mask = A[A > Scalar[nm.i16](10)] | ||
check_is_close(mask, np_mask, "Masked array") | ||
# def test_bool_masks(): | ||
# var np = Python.import_module("numpy") | ||
|
||
# # Create NumPy and NuMojo arrays using arange and reshape | ||
# var np_A = np.arange(0, 24, dtype=np.int16).reshape((3, 2, 4)) | ||
# var A = nm.arange[nm.i16](0, 24) | ||
# A.reshape(3, 2, 4) | ||
|
||
# # Test greater than | ||
# var np_gt = np_A > 10 | ||
# var gt = A > Scalar[nm.i16](10) | ||
# check(gt, np_gt, "Greater than mask") | ||
|
||
# # Test greater than or equal | ||
# var np_ge = np_A >= 10 | ||
# var ge = A >= Scalar[nm.i16](10) | ||
# check(ge, np_ge, "Greater than or equal mask") | ||
|
||
# # Test less than | ||
# var np_lt = np_A < 10 | ||
# var lt = A < Scalar[nm.i16](10) | ||
# check(lt, np_lt, "Less than mask") | ||
|
||
# # Test less than or equal | ||
# var np_le = np_A <= 10 | ||
# var le = A <= Scalar[nm.i16](10) | ||
# check(le, np_le, "Less than or equal mask") | ||
|
||
# # Test equal | ||
# var np_eq = np_A == 10 | ||
# var eq = A == Scalar[nm.i16](10) | ||
# check(eq, np_eq, "Equal mask") | ||
|
||
# # Test not equal | ||
# var np_ne = np_A != 10 | ||
# var ne = A != Scalar[nm.i16](10) | ||
# check(ne, np_ne, "Not equal mask") | ||
|
||
# # Test masked array | ||
# var np_mask = np_A[np_A > 10] | ||
# var mask = A[A > Scalar[nm.i16](10)] | ||
# check(mask, np_mask, "Masked array") |