Skip to content

Commit

Permalink
Enabling log_normal tests
Browse files Browse the repository at this point in the history
* Use the existing log_normal decomposition
* Skip numerical comparision for the operation

ref: #7505 (comment)

fixes: #7505
  • Loading branch information
barney-s committed Oct 9, 2024
1 parent 07d0823 commit c89e443
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
"linalg.tensorsolve",
"linalg.vector_norm",
"linspace",
"log_normal",
"logspace",
"lu",
"lu_solve",
Expand Down Expand Up @@ -160,6 +159,7 @@
'nn.functional.feature_alpha_dropout',
'cauchy',
'exponential',
'log_normal',
}

atol_dict = {"matrix_exp": (2e-1, 2e-4), "linalg.pinv": (8e-1, 2e0), "linalg.eig": (2e0, 3e0), "linalg.eigh": (5e1, 3e0), "linalg.eigvalsh": (5e1, 3e0)}
Expand Down Expand Up @@ -260,6 +260,7 @@ def test_reference_eager(self, device, dtype, op):
if isinstance(t, torch.Tensor) and t.is_sparse:
continue
check_output = op.name not in random_ops
print("[DEBUG] sample_input: ", sample_input)

if op.name == "special.polygamma":
# The polygamma function is inaccurate for values < 1.
Expand Down
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,5 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tens
torch.ops.aten.nll_loss2d_backward,
torch.ops.aten.bernoulli_.Tensor,
torch.ops.aten.bernoulli_.float,
torch.ops.aten.log_normal,
])
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
torch.ops.aten.logical_not_: torch.ops.aten.logical_not,
torch.ops.aten.unsqueeze_: torch.ops.aten.unsqueeze,
torch.ops.aten.transpose_: torch.ops.aten.transpose,
torch.ops.aten.log_normal_: torch.ops.aten.log_normal,
}


Expand Down

0 comments on commit c89e443

Please sign in to comment.