Skip to content

Commit

Permalink
Pass axis=dim as a keyword arg (minor cleanup)
Browse files Browse the repository at this point in the history
For consistency with the following `var` call, and because `axis` is
officially a keyword argument, change the `mean` call to pass `axis=dim`
as a keyword argument instead of a positional argument.

Issue: #7542
  • Loading branch information
mrguenther committed Oct 17, 2024
1 parent 70e377d commit 01a3b58
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -3034,7 +3034,7 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False):
# be nullable, we still need to check for `None` per the API contract.
if correction is None:
correction = 1
mean = jnp.mean(tensor, dim, keepdims=keepdim)
mean = jnp.mean(tensor, axis=dim, keepdims=keepdim)
# TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it.
var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim)
return var, mean
Expand Down

0 comments on commit 01a3b58

Please sign in to comment.