Skip to content

Commit

Permalink
Simplify the print function
Browse files Browse the repository at this point in the history
  • Loading branch information
forFudan committed Sep 14, 2024
1 parent 1546828 commit 9c13692
Showing 1 changed file with 72 additions and 61 deletions.
133 changes: 72 additions & 61 deletions numojo/core/matrix.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ struct Matrix[dtype: DType = DType.float64]():
raise Error("Error: Length of the index does not match the shape.")
if (index[0] >= self.shape[0]) or (index[1] >= self.shape[1]):
raise Error("Error: Elements of `index` exceed the array shape")
return self.data.load(index[0] * self.stride[0] + index[1] * self.stride[1])

return self.data.load(
index[0] * self.stride[0] + index[1] * self.stride[1]
)

fn __setitem__(self, index: Tuple[Int, Int], value: Scalar[dtype]) raises:
"""
Expand All @@ -158,75 +159,67 @@ struct Matrix[dtype: DType = DType.float64]():
raise Error("Error: Length of the index does not match the shape.")
if (index[0] >= self.shape[0]) or (index[1] >= self.shape[1]):
raise Error("Error: Elements of `index` exceed the array shape")
self.data.store(index[0] * self.stride[0] + index[1] * self.stride[1], value)

self.data.store(
index[0] * self.stride[0] + index[1] * self.stride[1], value
)

fn __str__(self) -> String:
"""
Enables str(array)
"""
fn print_row(self: Self, i: Int, sep: String) raises -> String:
var result: String = str("[")
var number_of_sep: Int = 1
if self.shape[1] <= 6:
for j in range(self.shape[1]):
if j == self.shape[1] - 1:
number_of_sep = 0
result += str(self[(i, j)]) + sep * number_of_sep
else:
for j in range(3):
result += str(self[(i, j)]) + sep
result += str("...") + sep
for j in range(self.shape[1] - 3, self.shape[1]):
if j == self.shape[1] - 1:
number_of_sep = 0
result += str(self[(i, j)]) + sep * number_of_sep
result += str("]")
return result

var sep: String = str("\t")
var newline: String = str("\n ")
var number_of_newline: Int = 1
var result: String = "["

try:
if self.shape[0] <= 6:
for i in range(self.shape[0]):
if i == self.shape[0] - 1:
number_of_newline = 0
result += (
print_row(self, i, sep)
+ newline * number_of_newline
)
else:
for i in range(3):
result += print_row(self, i, sep) + newline
result += str("...") + newline
for i in range(self.shape[0] - 3, self.shape[0]):
if i == self.shape[0] - 1:
number_of_newline = 0
result += (
print_row(self, i, sep) + newline * number_of_newline
)
except e:
print("Cannot tranfer matrix to string!", e)
result += str("]")
return (
self._array_to_string(0)
+ "\n"
result
+ "\nSize: "
+ str(self.shape[0])
+ "x"
+ str(self.shape[1])
+ " DType: "
+ str(self.dtype)
)

fn _array_to_string(
self,
dimension: Int,
offset: Int = 0,
) -> String:
if dimension == 1: # each item in a row
var result: String = str("[\t")
var number_of_items = self.shape[1]
if number_of_items <= 6: # Print all items
for i in range(number_of_items):
result += (
self.data.load(offset + i * self.stride[1]).__str__() + "\t"
)
else: # Print first 3 and last 3 items
for i in range(3):
result += (
self.data.load(offset + i * self.stride[1]).__str__() + "\t"
)
result = result + "...\t"
for i in range(number_of_items - 3, number_of_items):
result += (
self.data.load(offset + i * self.stride[1]).__str__() + "\t"
)
result = result + "]"
return result
else: # each row
var result: String = str("[")
var number_of_items = self.shape[0]
if number_of_items <= 6: # Print all items
for i in range(number_of_items):
if i == 0:
result += self._array_to_string(1, offset + i * self.stride[0])
if i > 0:
result += str(" ") + self._array_to_string(1, offset + i * self.stride[0])
if i < (number_of_items - 1):
result += "\n"
else: # Print first 3 and last 3 items
for i in range(3):
if i == 0:
result += self._array_to_string(1, offset + i * self.stride[0])
if i > 0:
result += str(" ") + self._array_to_string(1, offset + i * self.stride[0])
if i < (number_of_items - 1):
result += "\n"
result += "...\n"
for i in range(number_of_items - 3, number_of_items):
result += str(" ") + self._array_to_string(1, offset + i * self.stride[0])
if i < (number_of_items - 1):
result += "\n"
result += "]"
return result


# ===----------------------------------------------------------------------===#
# Fucntions for constructing Matrix
Expand All @@ -247,3 +240,21 @@ fn full[
print("Cannot fill in the values", e)

return matrix


fn rand[
dtype: DType = DType.float64
](shape: Tuple[Int, Int], order: String = "C") -> Matrix[dtype]:
"""Return a matrix with random values uniformed distributed between 0 and 1.
Parameters:
dtype: The data type of the NDArray elements.
Args:
shape: The shape of the Matrix.
order: The order of the Matrix.
"""
var result = Matrix[dtype](shape, order)
for i in range(result.size):
result.data.store(i, random.random_float64(0, 1).cast[dtype]())
return result

0 comments on commit 9c13692

Please sign in to comment.