Commit
Fix torch.polar, torch.polygamma, torch.prod, torch.put
Simon Teo committed Oct 23, 2024
1 parent 39dc4bc commit 87d0de8
Showing 2 changed files with 33 additions and 11 deletions.
6 changes: 1 addition & 5 deletions experimental/torch_xla2/test/test_ops.py
@@ -69,10 +69,6 @@
     "normal",
     "ormqr",
     "pca_lowrank",
-    "polar",
-    "polygamma",
-    "prod",
-    "put",
     "searchsorted",
     "special.airy_ai",
     "special.scaled_modified_bessel_k0",
@@ -238,7 +234,7 @@ def test_reference_eager(self, device, dtype, op):
     if 'dtype' in sample_input.kwargs:
       if sample_input.kwargs['dtype'] == torch.int64:
         sample_input.kwargs['dtype'] = torch.float
-    if op.name == "special.polygamma":
+    if op.name == "polygamma" or op.name == "special.polygamma":
       # The polygamma function is inaccurate for values < 1.
       # To avoid errors during testing, replace values below 1 with 1.
       sample_input.input = self.replace_values_below_threshold(
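For reference, a standalone sketch (not part of the commit) of what a clamping helper like replace_values_below_threshold could look like; the real helper is defined elsewhere in test_ops.py, so the body below is assumed:

import torch

# Hypothetical sketch: clamp every element below `threshold` up to `threshold`
# so the polygamma reference comparison stays in its numerically stable range.
def replace_values_below_threshold(t, threshold):
  return torch.where(t < threshold, torch.full_like(t, threshold), t)

x = torch.tensor([0.25, 0.9, 1.5])
print(replace_values_below_threshold(x, 1.0))  # tensor([1.0000, 1.0000, 1.5000])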
38 changes: 32 additions & 6 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2706,11 +2706,33 @@ def _aten_nonzero(x):


 # aten.prod
+@op(torch.ops.aten.prod)
+def _aten_prod(input, dim=None, keepdim=False, *, dtype=None):
+  if dtype:
+    input = input.astype(mappings.t2j_dtype(dtype))
+  return _with_reduction_scalar(jnp.prod, input, dim, keepdim)


-@op(torch.ops.aten.prod)
-def _aten_prod(self, dim=None, keepdim=False):
-  return _with_reduction_scalar(jnp.prod, self, dim, keepdim)
+@op(torch.ops.aten.put)
+def _aten_put(self, index, source, accumulate=False):
+  expanded = False
+  res = None
+
+  if self.ndim == 0:
+    expanded = True
+    self = jnp.expand_dims(self, 0)
+
+  if accumulate:
+    tmp = jnp.zeros(self.shape)
+    tmp = jnp.put(tmp, index, source, inplace=False)
+    res = jnp.add(self, tmp).astype(self.dtype)
+  else:
+    res = jnp.put(self, index, source, inplace=False)
+
+  if expanded:
+    res = res.squeeze()
+
+  return res


 # aten.randperm
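For reference, a standalone sketch (not part of the commit) of what the two kernels above compute, using plain jax.numpy and toy inputs:

import jax.numpy as jnp

# prod with an explicit dtype: cast first, then reduce (mirrors _aten_prod).
x = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32)
print(jnp.prod(x.astype(jnp.float32)))          # 24.0

# put: torch.put writes `source` into the flattened tensor at `index`.
self_ = jnp.array([[1., 2.], [3., 4.]])
index = jnp.array([0, 3])
source = jnp.array([10., 20.])

# accumulate=False overwrites the selected flattened positions.
print(jnp.put(self_, index, source, inplace=False))
# [[10.  2.]
#  [ 3. 20.]]

# accumulate=True: scatter `source` into zeros, then add, as in _aten_put.
tmp = jnp.put(jnp.zeros(self_.shape), index, source, inplace=False)
print((self_ + tmp).astype(self_.dtype))
# [[11.  2.]
#  [ 3. 24.]]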
@@ -3761,9 +3783,9 @@ def f(carry, val):

 @op(torch.ops.aten.polygamma)
 def _aten_polygamma(x, n):
-  if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
-    n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
-  return jax.lax.polygamma(jnp.float32(x), n)
+  if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
+    n = n.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return jax.lax.polygamma(jnp.float32(x), n)

 @op(torch.ops.aten.special_ndtri)
 @op_base.promote_int_input
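For reference, a standalone sketch (not part of the commit) of the call the polygamma kernel above lowers to; names and inputs are assumed. jax.lax.polygamma expects both operands in a floating dtype, which is why an integer order coming from torch is cast to the default float dtype first:

import jax
import jax.numpy as jnp

order = jnp.float32(1)                            # psi^(1) is the trigamma function
x = jnp.array([1.5, 2.0, 3.0], dtype=jnp.float32)
print(jax.lax.polygamma(order, x))                # approx. [0.9348, 0.6449, 0.3949]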
@@ -4742,6 +4764,10 @@ def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factor
       antialias=antialias,
   )

+@op(torch.ops.aten.polar)
+def _aten_polar(abs, angle, *, out=None):
+  return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
+
 @op(torch.ops.aten.cdist)
 def _aten_cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
   x1 = x1.astype(jnp.float32)
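For reference, a standalone sketch (not part of the commit) of the polar-to-cartesian identity the new _aten_polar kernel above relies on, with toy inputs assumed:

import jax
import jax.numpy as jnp

# polar(abs, angle) = abs*cos(angle) + 1j*abs*sin(angle)
magnitude = jnp.array([1.0, 2.0], dtype=jnp.float32)
angle = jnp.array([jnp.pi / 2, jnp.pi], dtype=jnp.float32)
z = jax.lax.complex(magnitude * jnp.cos(angle), magnitude * jnp.sin(angle))
print(z)  # approximately [0.+1.j, -2.+0.j]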
