Skip to content

Commit

Permalink
Add unique, unique_consecutive (#8258)
Browse files Browse the repository at this point in the history
Co-authored-by: mrguenther <mrguenther@google.com>
  • Loading branch information
mrguenther and mrguenther authored Oct 16, 2024
1 parent c895c5a commit ee388f6
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 3 deletions.
4 changes: 1 addition & 3 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@
"svd_lowrank",
"unfold_copy",
"unfold",
"unique_consecutive",
"unique",
"unravel_index",
"var_mean",
"nanmean",
Expand All @@ -109,7 +107,7 @@
not_support_ops_list = {
"chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
"__rpow__", # NOTE: cannot fix because torch test case has undefined behavior
# such as 0 to negative power.
# such as 0 to negative power.
"ceil", # only failed with python 3.9
"trunc", # only failed with python 3.9
"to_sparse", # We are not supporting sparse tensors yet.
Expand Down
137 changes: 137 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce,
}

# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
_jax_version = tuple(int(v) for v in jax.version._version.split("."))


def make_mutation(op):
if type(mutation_ops_to_functional[op]) is tuple:
Expand Down Expand Up @@ -2757,6 +2760,140 @@ def _aten_unbind(a, dim=0):
return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]


# aten.unique_dim
#
# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
# the tensor regardless of the `sorted` argument passed to `torch.unique`.
@op(torch.ops.aten.unique_dim)
def _aten_unique_dim(input_tensor,
dim,
sort=True,
return_inverse=False,
return_counts=False):
result_tensor_or_tuple = jnp.unique(input_tensor,
return_index=False,
return_inverse=return_inverse,
return_counts=return_counts,
axis=dim,
equal_nan=False)
result_list = (
list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple)
else [result_tensor_or_tuple])

if not return_inverse:
result_list.insert(1, None)
elif _jax_version < (0, 4, 31) and dim is not None:
result_list[1] = result_list[1].flatten()

if not return_counts:
result_list.insert(2, None)

# [result, None, None] if return_inverse=False and return_counts=False
# [result, inverse, None] if return_inverse=True and return_counts=False
# [result, None, counts] if return_inverse=False and return_counts=True
# [result, inverse, counts] if return_inverse=True and return_counts=True
return result_list


# aten._unique
#
# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
# the tensor regardless of the `sorted` argument passed to `torch.unique`.
@op(torch.ops.aten._unique)
def _aten_unique(input_tensor,
sort=True,
return_inverse=False):
result_tensor_or_tuple = jnp.unique(input_tensor,
return_index=False,
return_inverse=return_inverse,
return_counts=False,
axis=None,
equal_nan=False)
if return_inverse:
return result_tensor_or_tuple
else:
return (result_tensor_or_tuple, None)


# aten._unique2
#
# NOTE: Like the CUDA and CPU implementations, this implementation always sorts
# the tensor regardless of the `sorted` argument passed to `torch.unique`.
@op(torch.ops.aten._unique2)
def _aten_unique2(input_tensor,
sort=True,
return_inverse=False,
return_counts=False):
return _aten_unique_dim(input_tensor=input_tensor,
dim=None,
sort=sort,
return_inverse=return_inverse,
return_counts=return_counts)


# aten.unique_consecutive
@op(torch.ops.aten.unique_consecutive)
def _aten_unique_consecutive(input_tensor,
return_inverse=False,
return_counts=None,
dim=None):
# Explanation of computations (shown in 1D for simplicity):
#
# Input [a b b c c c d d d d e e e e e]
# Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e]
# Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e]
# Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0]
# Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0]
# Filter input by the resulting bool array [a b c d e ]
# Output [a b c d e]

if dim is None:
inverse_shape = input_tensor.shape
input_tensor = input_tensor.flatten()
ndim = 1
dim = 0
else:
inverse_shape = input_tensor.shape[dim]
ndim = input_tensor.ndim
if dim < 0:
dim += ndim

nd_slice_0 = tuple(slice(None, -1) if d == dim else slice(None)
for d in range(ndim))
nd_slice_1 = tuple(slice(1, None) if d == dim else slice(None)
for d in range(ndim))

axes_to_reduce = tuple(d for d in range(ndim) if d != dim)

does_not_equal_prior = (
jnp.any(input_tensor[nd_slice_0] != input_tensor[nd_slice_1],
axis=axes_to_reduce,
keepdims=False))

if input_tensor.shape[dim] != 0:
# Prepend `True` to represent the first element of the input.
does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True)

include_indices = jnp.argwhere(does_not_equal_prior)[:, 0]

output_tensor = input_tensor[
tuple(include_indices if d == dim else slice(None) for d in range(ndim))]

if return_inverse or return_counts:
counts = (jnp.append(include_indices[1:], input_tensor.shape[dim]) -
include_indices[:])

inverse = (
jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape)
if return_inverse
else None
)

return output_tensor, inverse, counts

return output_tensor, None, None


# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d
# despite those being core aten ops, they also have decompositions.
# here we are using torch decompositions.
Expand Down

0 comments on commit ee388f6

Please sign in to comment.