Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _aten_cdist Op #8238

Merged
merged 4 commits into from
Oct 14, 2024
Merged

Add _aten_cdist Op #8238

merged 4 commits into from
Oct 14, 2024

Conversation

vyom1611
Copy link
Contributor

@vyom1611 vyom1611 commented Oct 9, 2024

Based on issue: Op info test for cdist #7400

Right now the test is not passing locally:

Screenshot 2024-10-08 at 6 57 31 PM

The fail error in more detail:

E         AssertionError: Tensor-likes are not close!
E         
E         Mismatched elements: 150 / 150 (100.0%)
E         Greatest absolute difference: 2.0 at index (0, 0, 0) (up to 0.001 allowed)
E         Greatest relative difference: 1.0 at index (0, 0, 0) (up to 1e-05 allowed)
E         
E         To execute this test, run the following from the base repo dir:
E              python test/test_ops.py -k TestOpInfoCPU.test_reference_eager_cdist_cpu_float32
E         
E         This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

However the issue arises from fp arithmetic, as I ran an example to test the closeness of values:

Expected Result (PyTorch): tensor([[3.1193, 2.0959],
[2.7138, 3.8322],
[2.2830, 0.3791]])
Your Result (JAX): [[3.1192703 2.0958931]
[2.7138407 3.8321724]
[2.2830095 0.3791012]]

Can someone help me fix this precision error? Other than that this Op is ready to be merged I think.

@qihqi qihqi requested a review from ManfeiBai October 11, 2024 03:06
@qihqi
Copy link
Collaborator

qihqi commented Oct 11, 2024

Hi, please rebase to HEAD to resolve conflicts. You can make the atol a bit larger for this op example: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L156

@vyom1611 vyom1611 reopened this Oct 14, 2024
# General p-norm distance calculation
diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3))
return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** (1 / p)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution, do we have any reference for others need to know why we need different implementation for these specific difference situtions?

Copy link
Collaborator

@ManfeiBai ManfeiBai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM

@ManfeiBai ManfeiBai merged commit 80db07b into pytorch:master Oct 14, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants