Skip to content

Commit

Permalink
Merge pull request #130 from shivasankarka/main
Browse files Browse the repository at this point in the history
Fixed adjust slice
  • Loading branch information
forFudan authored Oct 19, 2024
2 parents 2cca3f1 + 234afbf commit 84af9a9
Showing 1 changed file with 44 additions and 51 deletions.
95 changes: 44 additions & 51 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ struct NDArray[dtype: DType = DType.float64](
var count: Int = 0
var spec: List[Int] = List[Int]()
for i in range(n_slices):
self._adjust_slice_(slice_list[i], self.ndshape[i])
# self._adjust_slice_(slice_list[i], self.ndshape[i])
if (
slice_list[i].start.value() >= self.ndshape[i]
or slice_list[i].end.value() > self.ndshape[i]
Expand Down Expand Up @@ -611,43 +611,11 @@ struct NDArray[dtype: DType = DType.float64](
Example:
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
"""
print("slices: ", slices[0], slices[1], slices[2])
var n_slices: Int = len(slices)
var ndims: Int = 0
var count: Int = 0
var spec: List[Int] = List[Int]()
var slice_list: List[Slice] = List[Slice]()
for i in range(n_slices):
var start: Int = 0
var end: Int = 0
if slices[i].start is None and slices[i].end is None:
start = 0
end = self.ndshape[i]
temp = Slice(
start=Optional(start),
end=Optional(end),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is None and slices[i].end is not None:
start = 0
temp = Slice(
start=Optional(start),
end=Optional(slices[i].end.value()),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is not None and slices[i].end is None:
end = self.ndshape[i]
temp = Slice(
start=Optional(slices[i].start.value()),
end=Optional(end),
step=Optional(slices[i].step),
)
slice_list.append(temp)
if slices[i].start is not None and slices[i].end is not None:
slice_list.append(slices[i])

var slice_list: List[Slice] = self._adjust_slice_(slices)
for i in range(n_slices):
if (
slice_list[i].start.value() >= self.ndshape[i]
Expand Down Expand Up @@ -867,24 +835,48 @@ struct NDArray[dtype: DType = DType.float64](
var idx: Int = _get_index(index, self.coefficient)
return self.data.load[width=1](idx)

fn _adjust_slice_(self, inout span: Slice, dim: Int):
fn _adjust_slice_(self, slice_list: List[Slice]) raises -> List[Slice]:
"""
Adjusts the slice values to lie within 0 and dim.
"""
if span.start or span.end:
var start = int(span.start.value())
var end = int(span.end.value())
if start < 0:
start = dim + start
if not span.end:
end = dim
elif end < 0:
end = dim + end
if end > dim:
end = dim
if end < start:
var n_slices: Int = slice_list.__len__()
var slices = List[Slice]()
for i in range(n_slices):
var start: Int = 0
var end: Int = 0
if slice_list[i].start is None and slice_list[i].end is None:
start = 0
end = self.ndshape[i]
temp = Slice(
start=Optional(start),
end=Optional(end),
step=Optional(slice_list[i].step),
)
slices.append(temp)
if slice_list[i].start is None and slice_list[i].end is not None:
start = 0
end = 0
temp = Slice(
start=Optional(start),
end=Optional(slice_list[i].end.value()),
step=Optional(slice_list[i].step),
)
slices.append(temp)
if slice_list[i].start is not None and slice_list[i].end is None:
end = self.ndshape[i]
temp = Slice(
start=Optional(slice_list[i].start.value()),
end=Optional(end),
step=Optional(slice_list[i].step),
)
slices.append(temp)
if (
slice_list[i].start is not None
and slice_list[i].end is not None
):
slices.append(slice_list[i])
else:
raise Error("Error: Undefined Slice")
return slices^

fn __getitem__(self, owned *slices: Slice) raises -> Self:
"""
Expand All @@ -908,23 +900,24 @@ struct NDArray[dtype: DType = DType.float64](
var narr: Self = self[slice_list]
return narr

fn __getitem__(self, owned slices: List[Slice]) raises -> Self:
fn __getitem__(self, owned slice_list: List[Slice]) raises -> Self:
"""
Retreive slices of an array from list of slices.
Example:
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
"""

var n_slices: Int = slices.__len__()
var n_slices: Int = slice_list.__len__()
if n_slices > self.ndim or n_slices < self.ndim:
raise Error("Error: No of slices do not match shape")

var ndims: Int = 0
var spec: List[Int] = List[Int]()
var count: Int = 0

var slices: List[Slice] = self._adjust_slice_(slice_list)
for i in range(slices.__len__()):
self._adjust_slice_(slices[i], self.ndshape[i])
if (
slices[i].start.value() >= self.ndshape[i]
or slices[i].end.value() > self.ndshape[i]
Expand Down

0 comments on commit 84af9a9

Please sign in to comment.