From 87d0de8c1c08a79035c128950ee5d185f675c62e Mon Sep 17 00:00:00 2001
From: Simon Teo
Date: Tue, 22 Oct 2024 16:41:58 +0800
Subject: [PATCH] Fix torch.polar, torch.polygamma, torch.prod, torch.put

---
 experimental/torch_xla2/test/test_ops.py |  6 +--
 .../torch_xla2/torch_xla2/ops/jaten.py    | 38 ++++++++++++++++---
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index be75771fb33..b15e4fbfdf7 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -69,10 +69,6 @@
     "normal",
     "ormqr",
     "pca_lowrank",
-    "polar",
-    "polygamma",
-    "prod",
-    "put",
     "searchsorted",
     "special.airy_ai",
     "special.scaled_modified_bessel_k0",
@@ -238,7 +234,7 @@ def test_reference_eager(self, device, dtype, op):
       if 'dtype' in sample_input.kwargs:
         if sample_input.kwargs['dtype'] == torch.int64:
           sample_input.kwargs['dtype'] = torch.float
-    if op.name == "special.polygamma":
+    if op.name == "polygamma" or op.name == "special.polygamma":
       # The polygamma function is inaccurate for values < 1.
       # To avoid errors during testing, replace values below 1 with 1.
       sample_input.input = self.replace_values_below_threshold(
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 25a77c70886..0baf66eb653 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2706,11 +2706,33 @@ def _aten_nonzero(x):
 
 
 # aten.prod
+@op(torch.ops.aten.prod)
+def _aten_prod(input, dim=None, keepdim=False, *, dtype=None):
+  if dtype:
+    input = input.astype(mappings.t2j_dtype(dtype))
+  return _with_reduction_scalar(jnp.prod, input, dim, keepdim)
 
 
-@op(torch.ops.aten.prod)
-def _aten_prod(self, dim=None, keepdim=False):
-  return _with_reduction_scalar(jnp.prod, self, dim, keepdim)
+@op(torch.ops.aten.put)
+def _aten_put(self, index, source, accumulate=False):
+  expanded = False
+  res = None
+
+  if self.ndim == 0:
+    expanded = True
+    self = jnp.expand_dims(self, 0)
+
+  if accumulate:
+    tmp = jnp.zeros(self.shape)
+    tmp = jnp.put(tmp, index, source, inplace=False)
+    res = jnp.add(self, tmp).astype(self.dtype)
+  else:
+    res = jnp.put(self, index, source, inplace=False)
+
+  if expanded:
+    res = res.squeeze()
+
+  return res
 
 
 # aten.randperm
@@ -3761,9 +3783,9 @@ def f(carry, val):
 
 @op(torch.ops.aten.polygamma)
 def _aten_polygamma(x, n):
-  if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
-    n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
-  return jax.lax.polygamma(jnp.float32(x), n)
+  if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
+    n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return jax.lax.polygamma(jnp.float32(x), n)
 
 @op(torch.ops.aten.special_ndtri)
 @op_base.promote_int_input
@@ -4742,6 +4764,10 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
       antialias=antialias,
   )
 
+@op(torch.ops.aten.polar)
+def _aten_polar(abs, angle, *, out=None):
+  return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
+
 @op(torch.ops.aten.cdist)
 def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
   x1 = x1.astype(jnp.float32)
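
The sketch below is a standalone illustration, not part of the patch; the toy tensors and values are made up for the example. It walks through the two less obvious compositions the new lowerings rely on: `torch.put` with `accumulate=True` is emulated by scattering `source` into a zero buffer with `jnp.put(..., inplace=False)` and adding it to the input (because `jnp.put` overwrites rather than accumulates), and `torch.polar` builds a complex array from magnitude and phase via `jax.lax.complex`.

```python
# Standalone sketch of the JAX compositions used in the patch (toy inputs).
import jax.numpy as jnp
from jax import lax

self_ = jnp.array([[1.0, 2.0], [3.0, 4.0]])
index = jnp.array([0, 3])          # flat indices, as torch.put expects
source = jnp.array([10.0, 20.0])

# accumulate=True: scatter into a zero buffer, then add to the original.
tmp = jnp.put(jnp.zeros(self_.shape), index, source, inplace=False)
accumulated = (self_ + tmp).astype(self_.dtype)
print(accumulated)   # [[11.  2.] [ 3. 24.]]

# accumulate=False: overwrite the flattened positions directly.
overwritten = jnp.put(self_, index, source, inplace=False)
print(overwritten)   # [[10.  2.] [ 3. 20.]]

# torch.polar(abs, angle): complex values from magnitude and phase.
magnitude = jnp.array([1.0, 2.0])
angle = jnp.array([0.0, jnp.pi / 2])
z = lax.complex(magnitude * jnp.cos(angle), magnitude * jnp.sin(angle))
print(z)             # [1.+0.j, ~0.+2.j]
```

The final cast back to `self.dtype` in the accumulate branch mirrors the lowering above, since the zero buffer defaults to floating point.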