From c89e44390ecf93a24376d97d7037da68b261752d Mon Sep 17 00:00:00 2001 From: Barni Seetharaman Date: Wed, 9 Oct 2024 21:57:43 +0000 Subject: [PATCH] Enabling log_normal tests * Use the existing log_normal decomposition * Skip numerical comparision for the operation ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2403562005 fixes: https://github.com/pytorch/xla/issues/7505 --- experimental/torch_xla2/test/test_ops.py | 3 ++- experimental/torch_xla2/torch_xla2/decompositions.py | 1 + experimental/torch_xla2/torch_xla2/ops/jaten.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index aa07381cfda..eb3245476ca 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -48,7 +48,6 @@ "linalg.tensorsolve", "linalg.vector_norm", "linspace", - "log_normal", "logspace", "lu", "lu_solve", @@ -160,6 +159,7 @@ 'nn.functional.feature_alpha_dropout', 'cauchy', 'exponential', + 'log_normal', } atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0), "linalg.eigvalsh": (5e1, 3e0)} @@ -260,6 +260,7 @@ def test_reference_eager(self, device, dtype, op): if isinstance(t, torch.Tensor) and t.is_sparse: continue check_output = op.name not in random_ops + print("[DEBUG] sample_input: ", sample_input) if op.name == "special.polygamma": # The polygamma function is inaccurate for values < 1. diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py index 354ac3d93bf..8eb813cd284 100644 --- a/experimental/torch_xla2/torch_xla2/decompositions.py +++ b/experimental/torch_xla2/torch_xla2/decompositions.py @@ -297,4 +297,5 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tens torch.ops.aten.nll_loss2d_backward, torch.ops.aten.bernoulli_.Tensor, torch.ops.aten.bernoulli_.float, + torch.ops.aten.log_normal, ]) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index f671e039839..678c3599973 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -47,6 +47,7 @@ torch.ops.aten.logical_not_: torch.ops.aten.logical_not, torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze, torch.ops.aten.transpose_: torch.ops.aten.transpose, + torch.ops.aten.log_normal_: torch.ops.aten.log_normal, }