From 01a3b58fba97a108448892a834050ad52aeb945f Mon Sep 17 00:00:00 2001 From: mrguenther Date: Thu, 17 Oct 2024 13:55:45 -0700 Subject: [PATCH] Pass `axis=dim` as a keyword arg (minor cleanup) For consistency with the following `var` call, and because `axis` is officially a keyword argument, change the `mean` call to pass `axis=dim` as a keyword argument instead of a positional argument. Issue: #7542 --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 1dfb785a2aa..dee22ac71f3 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -3034,7 +3034,7 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False): # be nullable, we still need to check for `None` per the API contract. if correction is None: correction = 1 - mean = jnp.mean(tensor, dim, keepdims=keepdim) + mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) return var, mean