Skip to content

Commit

Permalink
fixed linting error, added readme to .toml
Browse files Browse the repository at this point in the history
  • Loading branch information
shivasankarka committed Sep 9, 2024
1 parent 6fff544 commit 061ce03
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
1 change: 1 addition & 0 deletions mojoproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ authors = [
channels = ["conda-forge", "https://conda.modular.com/max"]
platforms = ["osx-arm64", "linux-64"]
license = "Apache-2.0"
readme = "README.md"

[tasks]
# test whether tests pass and the package can be built
Expand Down
10 changes: 6 additions & 4 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2746,8 +2746,9 @@ struct NDArray[dtype: DType = DType.float64](
raise Error("Error: Elements of `index` exceed the array shape")
return self.data.load[width=1](_get_index(index, self.stride))


fn itemset(inout self, index: Variant[Int, List[Int]], item: Scalar[dtype]) raises:
fn itemset(
inout self, index: Variant[Int, List[Int]], item: Scalar[dtype]
) raises:
"""Set the scalar at the coordinates.
Args:
Expand Down Expand Up @@ -2820,10 +2821,11 @@ struct NDArray[dtype: DType = DType.float64](
raise Error("Error: Length of Indices do not match the shape")
for i in range(indices.__len__()):
if indices[i] >= self.ndshape[i]:
raise Error("Error: Elements of `index` exceed the array shape")
raise Error(
"Error: Elements of `index` exceed the array shape"
)
self.data.store[width=1](_get_index(indices, self.stride), item)


fn max(self, axis: Int = 0) raises -> Self:
"""
Max on axis.
Expand Down
7 changes: 3 additions & 4 deletions tests/test_math.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ def test_inverse_2():
nm.math.linalg.inverse(arr), np.linalg.inv(np_arr), "Inverse is broken"
)


def test_setitem():
var np = Python.import_module("numpy")
var arr = nm.NDArray(4, 4)
var np_arr = arr.to_numpy()
arr.itemset(List(2,2), 1000)
arr.itemset(List(2, 2), 1000)
np_arr[(2, 2)] = 1000
check_is_close(
arr, np_arr, "Itemset is broken"
)
check_is_close(arr, np_arr, "Itemset is broken")

0 comments on commit 061ce03

Please sign in to comment.