Skip to content

Commit

Permalink
Add support for linalg.lu (#8227)
Browse files Browse the repository at this point in the history
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
  • Loading branch information
matinehAkhlaghinia and ManfeiBai authored Oct 8, 2024
1 parent d50fecb commit 7aa996c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
"linalg.ldl_factor_ex",
"linalg.ldl_solve",
"linalg.lstsq",
"linalg.lu",
"linalg.lu_factor",
"linalg.lu_factor_ex",
"linalg.lu_solve",
Expand Down
32 changes: 32 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,10 +2259,42 @@ def _aten_linalg_eig(A):
def _aten_linalg_eigh(A, UPLO='L'):
return jnp.linalg.eigh(A, UPLO)


@op(torch.ops.aten.linalg_lu)
def _aten_linalg_lu(A, pivot=True, out=None):
dtype = A.dtype

*_, m, n = A.shape
k = jnp.minimum(m, n)

lu, _, permutation = jax.lax.linalg.lu(A)

L = jnp.tril(lu[..., :, :k], k=-1)
eye_L = jnp.eye(m, k, dtype=dtype)
L = L + eye_L

U = jnp.triu(lu[..., :k, :])

def perm_to_P(perm):
m = perm.shape[-1]
P = jnp.eye(m, dtype=dtype)[perm].T
return P

if permutation.ndim > 1:
num_batch_dims = permutation.ndim - 1
for _ in range(num_batch_dims):
perm_to_P = jax.vmap(perm_to_P, in_axes=0)

P = perm_to_P(permutation)

return P,L,U


@op(torch.ops.aten.gcd)
def _aten_gcd(input, other):
return jnp.gcd(input, other)


# aten.lcm
@op(torch.ops.aten.lcm)
def _aten_lcm(input, other):
Expand Down

0 comments on commit 7aa996c

Please sign in to comment.