Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Fix the issues with solve, inv, and matmul #115

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions numojo/math/linalg/matmul.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ fn matmul_parallelized[
vectorize[dot, width](t2)

parallelize[calculate_A_rows](t0, t0)

var _t0 = t0
var _t1 = t1
var _t2 = t2
var _A = A
var _B = B
var _width = width

return C^


Expand Down
98 changes: 80 additions & 18 deletions numojo/math/linalg/solver.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,40 @@ fn back_substitution[
return x


fn inverse[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
fn inv[dtype: DType](A: NDArray[dtype]) raises -> NDArray[dtype]:
"""Find the inverse of a non-singular, row-major matrix.

It uses the function `solve()` to solve `AB = I` for B, where I is
an identity matrix.

The speed is faster than numpy for matrices smaller than 100x100,
and is slower for larger matrices.

Parameters:
dtype: Data type of the inversed matrix.

Args:
A: Input matrix. It should be non-singular, square, and row-major.

Returns:
The reversed matrix of the original matrix.

"""

var m = A.shape()[0]
var I = eye[dtype](m, m)

return solve(A, I)


fn inv_raw[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
"""Find the inverse of a non-singular, square matrix.

WARNING: This function is slower than `inv`
as it does not adopt parallelization.
WARNING: This function is slower than `inv`.
as it does not adopt parallelization by using raw methods.

Parameters:
dtype: Data type of the inversed matrix. Default value is `f64`.
dtype: Data type of the inversed matrix.

Args:
array: Input matrix. It should be non-singular and square.
Expand All @@ -206,7 +232,7 @@ fn inverse[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
import numojo as nm
fn main() raises:
var A = nm.NDArray("[[1,0,1], [0,2,1], [1,1,1]]")
var B = nm.math.linalg.solver.inverse(A)
var B = nm.math.linalg.solver.inv_raw(A)
print("Original matrix:")
print(A)
print("Reversed matrix:")
Expand Down Expand Up @@ -268,19 +294,19 @@ fn inverse[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
return inversed


fn inv[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
fn inv_lu[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
"""Find the inverse of a non-singular, row-major matrix.

Use LU decomposition algorithm.

The speed is faster than numpy for matrices smaller than 100x100,
and is slower for larger matrices.

TODO: Use `solve()` function to calculate inverse.
TODO: Fix the issues in parallelization.
`AX = I` where `I` is an identity matrix.

Parameters:
dtype: Data type of the inversed matrix. Default value is `f64`.
dtype: Data type of the inversed matrix.

Args:
array: Input matrix. It should be non-singular, square, and row-major.
Expand Down Expand Up @@ -348,7 +374,7 @@ fn solve[
TODO: Use LAPACK for large matrices when it is available.

Parameters:
dtype: Data type of the inversed matrix. Default value is `f64`.
dtype: Data type of the inversed matrix.

Args:
A: Non-singular, square, and row-major matrix. The size is m x m.
Expand Down Expand Up @@ -388,8 +414,51 @@ fn solve[
var Z = zeros[dtype](m, n)
var X = zeros[dtype](m, n)

@parameter
fn calculate_X(col: Int) -> None:
####################################################################
# Parallelization
#
# Parallelization does not work any more since MAX 24.5.
# We temporarily switch to a non-paralleled approach.
# Thus, this block of code is commented out.
# TODO: Fix the issues in parallelization.
####################################################################

# @parameter
# fn calculate_X(col: Int) -> None:
# # Solve `LZ = Y` for `Z` for each col
# for i in range(m): # row of L
# var _temp = Y.load(i * n + col)
# for j in range(i): # col of L
# _temp = _temp - L.load(i * m + j) * Z.load(j * n + col)
# _temp = _temp / L.load(i * m + i)
# Z.store(i * n + col, _temp)

# # Solve `UZ = Z` for `X` for each col
# for i in range(m - 1, -1, -1):
# var _temp2 = Z.load(i * n + col)
# for j in range(i + 1, m):
# _temp2 = _temp2 - U.load(i * m + j) * X.load(j * n + col)
# _temp2 = _temp2 / U.load(i * m + i)
# X.store(i * n + col, _temp2)

# parallelize[calculate_X](n, n)

# # Force extending the lifetime of the matrices because they are destroyed before `parallelize`
# # This is disadvantage of Mojo's ASAP policy
# var _L = L^
# var _U = U^

# return X

####################################################################
# Non-parallelization
#
# Parallelization does not work any more since MAX 24.5.
# We temporarily switch to a non-paralleled approach.
# TODO: Remove the following code when parallelization works again.
####################################################################

for col in range(n):
# Solve `LZ = Y` for `Z` for each col
for i in range(m): # row of L
var _temp = Y.load(i * n + col)
Expand All @@ -406,11 +475,4 @@ fn solve[
_temp2 = _temp2 / U.load(i * m + i)
X.store(i * n + col, _temp2)

parallelize[calculate_X](n, n)

# Force extending the lifetime of the matrices because they are destroyed before `parallelize`
# This is disadvantage of Mojo's ASAP policy
var _L = L^
var _U = U^

return X
10 changes: 0 additions & 10 deletions test.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,6 @@ fn test_linalg() raises:
# check_is_close(nm.matmul_tiled_unrolled_parallelized(arr,arr),np.matmul(np_arr,np_arr),"TUP matmul is broken")


def test_inv1():
var np = Python.import_module("numpy")
var arr = nm.core.random.rand(5, 5)
var np_arr = arr.to_numpy()
print("arr: ", arr)
print("np_arr: ", np_arr)
print("inverse: ", nm.math.linalg.inverse(arr))
print("np inverse: ", np.linalg.inv(np_arr))


def test_inv():
var np = Python.import_module("numpy")
var arr = nm.core.random.rand(5, 5)
Expand Down
10 changes: 1 addition & 9 deletions tests/test_math.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ def test_matmul():
# check_is_close(nm.matmul_tiled_unrolled_parallelized(arr,arr),np.matmul(np_arr,np_arr),"TUP matmul is broken")


def test_inverse():
var np = Python.import_module("numpy")
var arr = nm.core.random.rand(100, 100)
var np_arr = arr.to_numpy()
check_is_close(
nm.math.linalg.inverse(arr), np.linalg.inv(np_arr), "Inverse is broken"
)


# ! The `inv` is broken, it outputs -INF for some values
def test_inv():
var np = Python.import_module("numpy")
Expand All @@ -86,6 +77,7 @@ def test_inv():
nm.math.linalg.inv(arr), np.linalg.inv(np_arr), "Inverse is broken"
)


# ! The `solve` is broken, it outputs -INF, nan, 0 etc for some values
def test_solve():
var np = Python.import_module("numpy")
Expand Down
Loading