diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aa0d84e71dc..11535e3d5ac 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -18,7 +18,6 @@ "cholesky", "cholesky_solve", "diagonal_copy", - "digamma", "geqrf", "histogram", # hard op: AssertionError: Tensor-likes are not close! "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5ef39d40cc8..cfcb6121e7a 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2391,6 +2391,12 @@ def _aten_hypot(input, other): return jnp.hypot(input, other) +@op(torch.ops.aten.digamma) +def _aten_digamma(input, *, out=None): + res = jax.scipy.special.digamma(input).astype(jnp.float32) + # replace indices where input == 0 with -inf in res + return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) + @op(torch.ops.aten.igamma) def _aten_igamma(input, other): return jax.scipy.special.gammainc(input, other)