From 234afbff6961d86620169a7860990d470811a4ca Mon Sep 17 00:00:00 2001 From: shivasankar Date: Sat, 19 Oct 2024 13:12:35 +0900 Subject: [PATCH] fixed getitem adjust slice function errors --- numojo/core/ndarray.mojo | 95 +++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/numojo/core/ndarray.mojo b/numojo/core/ndarray.mojo index 69ad7e1..a6b53c9 100644 --- a/numojo/core/ndarray.mojo +++ b/numojo/core/ndarray.mojo @@ -484,7 +484,7 @@ struct NDArray[dtype: DType = DType.float64]( var count: Int = 0 var spec: List[Int] = List[Int]() for i in range(n_slices): - self._adjust_slice_(slice_list[i], self.ndshape[i]) + # self._adjust_slice_(slice_list[i], self.ndshape[i]) if ( slice_list[i].start.value() >= self.ndshape[i] or slice_list[i].end.value() > self.ndshape[i] @@ -611,43 +611,11 @@ struct NDArray[dtype: DType = DType.float64]( Example: `arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2). """ - print("slices: ", slices[0], slices[1], slices[2]) var n_slices: Int = len(slices) var ndims: Int = 0 var count: Int = 0 var spec: List[Int] = List[Int]() - var slice_list: List[Slice] = List[Slice]() - for i in range(n_slices): - var start: Int = 0 - var end: Int = 0 - if slices[i].start is None and slices[i].end is None: - start = 0 - end = self.ndshape[i] - temp = Slice( - start=Optional(start), - end=Optional(end), - step=Optional(slices[i].step), - ) - slice_list.append(temp) - if slices[i].start is None and slices[i].end is not None: - start = 0 - temp = Slice( - start=Optional(start), - end=Optional(slices[i].end.value()), - step=Optional(slices[i].step), - ) - slice_list.append(temp) - if slices[i].start is not None and slices[i].end is None: - end = self.ndshape[i] - temp = Slice( - start=Optional(slices[i].start.value()), - end=Optional(end), - step=Optional(slices[i].step), - ) - slice_list.append(temp) - if slices[i].start is not None and slices[i].end is not None: - slice_list.append(slices[i]) - + var slice_list: List[Slice] = self._adjust_slice_(slices) for i in range(n_slices): if ( slice_list[i].start.value() >= self.ndshape[i] @@ -867,24 +835,48 @@ struct NDArray[dtype: DType = DType.float64]( var idx: Int = _get_index(index, self.coefficient) return self.data.load[width=1](idx) - fn _adjust_slice_(self, inout span: Slice, dim: Int): + fn _adjust_slice_(self, slice_list: List[Slice]) raises -> List[Slice]: """ Adjusts the slice values to lie within 0 and dim. """ - if span.start or span.end: - var start = int(span.start.value()) - var end = int(span.end.value()) - if start < 0: - start = dim + start - if not span.end: - end = dim - elif end < 0: - end = dim + end - if end > dim: - end = dim - if end < start: + var n_slices: Int = slice_list.__len__() + var slices = List[Slice]() + for i in range(n_slices): + var start: Int = 0 + var end: Int = 0 + if slice_list[i].start is None and slice_list[i].end is None: + start = 0 + end = self.ndshape[i] + temp = Slice( + start=Optional(start), + end=Optional(end), + step=Optional(slice_list[i].step), + ) + slices.append(temp) + if slice_list[i].start is None and slice_list[i].end is not None: start = 0 - end = 0 + temp = Slice( + start=Optional(start), + end=Optional(slice_list[i].end.value()), + step=Optional(slice_list[i].step), + ) + slices.append(temp) + if slice_list[i].start is not None and slice_list[i].end is None: + end = self.ndshape[i] + temp = Slice( + start=Optional(slice_list[i].start.value()), + end=Optional(end), + step=Optional(slice_list[i].step), + ) + slices.append(temp) + if ( + slice_list[i].start is not None + and slice_list[i].end is not None + ): + slices.append(slice_list[i]) + else: + raise Error("Error: Undefined Slice") + return slices^ fn __getitem__(self, owned *slices: Slice) raises -> Self: """ @@ -908,7 +900,7 @@ struct NDArray[dtype: DType = DType.float64]( var narr: Self = self[slice_list] return narr - fn __getitem__(self, owned slices: List[Slice]) raises -> Self: + fn __getitem__(self, owned slice_list: List[Slice]) raises -> Self: """ Retreive slices of an array from list of slices. @@ -916,15 +908,16 @@ struct NDArray[dtype: DType = DType.float64]( `arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2). """ - var n_slices: Int = slices.__len__() + var n_slices: Int = slice_list.__len__() if n_slices > self.ndim or n_slices < self.ndim: raise Error("Error: No of slices do not match shape") var ndims: Int = 0 var spec: List[Int] = List[Int]() var count: Int = 0 + + var slices: List[Slice] = self._adjust_slice_(slice_list) for i in range(slices.__len__()): - self._adjust_slice_(slices[i], self.ndshape[i]) if ( slices[i].start.value() >= self.ndshape[i] or slices[i].end.value() > self.ndshape[i]