Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getter setters #133

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 47 additions & 53 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,12 @@ struct NDArray[dtype: DType = DType.float64](
var size_at_dim: Int = self.ndshape[i]
slice_list.append(Slice(0, size_at_dim))

# self.__setitem__(slice_list=slice_list, val=val)
var n_slices: Int = len(slice_list)
var ndims: Int = 0
var count: Int = 0
var spec: List[Int] = List[Int]()
for i in range(n_slices):
self._adjust_slice_(slice_list[i], self.ndshape[i])
# self._adjust_slice_(slice_list[i], self.ndshape[i])
if (
slice_list[i].start.value() >= self.ndshape[i]
or slice_list[i].end.value() > self.ndshape[i]
Expand Down Expand Up @@ -611,43 +610,11 @@ struct NDArray[dtype: DType = DType.float64](
Example:
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
"""
print("slices: ", slices[0], slices[1], slices[2])
var n_slices: Int = len(slices)
var ndims: Int = 0
var count: Int = 0
var spec: List[Int] = List[Int]()
var slice_list: List[Slice] = List[Slice]()
for i in range(n_slices):
var start: Int = 0
var end: Int = 0
if slices[i].start is None and slices[i].end is None:
start = 0
end = self.ndshape[i]
temp = Slice(
start=Optional(start),
end=Optional(end),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is None and slices[i].end is not None:
start = 0
temp = Slice(
start=Optional(start),
end=Optional(slices[i].end.value()),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is not None and slices[i].end is None:
end = self.ndshape[i]
temp = Slice(
start=Optional(slices[i].start.value()),
end=Optional(end),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is not None and slices[i].end is not None:
slice_list.append(slices[i])

var slice_list: List[Slice] = self._adjust_slice_(slices)
for i in range(n_slices):
if (
slice_list[i].start.value() >= self.ndshape[i]
Expand Down Expand Up @@ -867,24 +834,50 @@ struct NDArray[dtype: DType = DType.float64](
var idx: Int = _get_index(index, self.coefficient)
return self.data.load[width=1](idx)

fn _adjust_slice_(self, inout span: Slice, dim: Int):
fn _adjust_slice_(self, slice_list: List[Slice]) raises -> List[Slice]:
"""
Adjusts the slice values to lie within 0 and dim.
"""
if span.start or span.end:
var start = int(span.start.value())
var end = int(span.end.value())
if start < 0:
start = dim + start
if not span.end:
end = dim
elif end < 0:
end = dim + end
if end > dim:
end = dim
if end < start:
start = 0
end = 0
var n_slices: Int = slice_list.__len__()
var slices = List[Slice]()
for i in range(n_slices):
if i >= self.ndim:
raise Error("Error: Number of slices exceeds array dimensions")

var start: Int = 0
var end: Int = self.ndshape[i]
var step: Int = 1
if slice_list[i].start is not None:
start = slice_list[i].start.value()
if start < 0:
# start += self.ndshape[i]
raise Error(
"Error: Negative indexing in slices not supported"
" currently"
)

if slice_list[i].end is not None:
end = slice_list[i].end.value()
if end < 0:
# end += self.ndshape[i] + 1
raise Error(
"Error: Negative indexing in slices not supported"
" currently"
)

step = slice_list[i].step
if step == 0:
raise Error("Error: Slice step cannot be zero")

slices.append(
Slice(
start=Optional(start),
end=Optional(end),
step=Optional(step),
)
)

return slices^

fn __getitem__(self, owned *slices: Slice) raises -> Self:
"""
Expand All @@ -908,23 +901,24 @@ struct NDArray[dtype: DType = DType.float64](
var narr: Self = self[slice_list]
return narr

fn __getitem__(self, owned slices: List[Slice]) raises -> Self:
fn __getitem__(self, owned slice_list: List[Slice]) raises -> Self:
"""
Retreive slices of an array from list of slices.
Example:
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
"""

var n_slices: Int = slices.__len__()
var n_slices: Int = slice_list.__len__()
if n_slices > self.ndim or n_slices < self.ndim:
raise Error("Error: No of slices do not match shape")

var ndims: Int = 0
var spec: List[Int] = List[Int]()
var count: Int = 0

var slices: List[Slice] = self._adjust_slice_(slice_list)
for i in range(slices.__len__()):
self._adjust_slice_(slices[i], self.ndshape[i])
if (
slices[i].start.value() >= self.ndshape[i]
or slices[i].end.value() > self.ndshape[i]
Expand Down
Loading