
Op info test for linalg.tensorsolve .. log_normal #7505

Closed
qihqi opened this issue Jun 25, 2024 · 14 comments · Fixed by #8247, #8249, #8236 or #8287
@qihqi
Collaborator

qihqi commented Jun 25, 2024

Fix the Op info test for linalg.tensorsolve .. log_normal

  1. Find the lines 143 to 147 of test_ops.py and remove
    linalg.tensorsolve .. log_normal from skip_list
  2. Run op_info test with pytest test/test_ops.py
  3. Fix the failure.

Please refer to this guide for how to fix:

Also refer to these PRs:

@barney-s
Contributor

barney-s commented Oct 6, 2024

@ManfeiBai - I am working on this

@barney-s
Contributor

barney-s commented Oct 6, 2024

Trying linspace:

One of the failing test logs:

testcase = <test.test_ops.TestOpInfoCPU testMethod=test_reference_eager_linspace_tensor_overload_cpu_float32>, output1 = tensor([-2.]), output2 = tensor(-2.)
rtol = 1e-05, atol = 0.001, equal_nan = True, check_output = True


E         AssertionError: The values for attribute 'shape' do not match: torch.Size([]) != torch.Size([1]).
E         To execute this test, run the following from the base repo dir:
E              python test/test_ops.py -k TestOpInfoCPU.test_reference_eager_linspace_tensor_overload_cpu_float32
E         This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------
[DEBUG] sample_input:  SampleInput(input=-2.0, args=(tensor(-3), 0), kwargs={'dtype': torch.int64, 'device': 'cpu'}, broadcasts_input=False, name='')
[DEBUG] sample_input:  SampleInput(input=-2.0, args=(-3, 0), kwargs={'dtype': torch.int64, 'device': 'cpu'}, broadcasts_input=False, name='')
[DEBUG] sample_input:  SampleInput(input=-2.0, args=(tensor(-3), 0), kwargs={'dtype': torch.int64, 'device': 'cpu'}, broadcasts_input=False, name='')
[DEBUG] sample_input:  SampleInput(input=-2.0, args=(tensor(-3), 1), kwargs={'dtype': torch.int64, 'device': 'cpu'}, broadcasts_input=False, name='')

output2 is tensor(-2). Is this a scalar? output1, by contrast, is tensor([-2]).
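The shape mismatch above is the usual 0-D vs 1-D (length 1) distinction. A quick numpy sketch (numpy as a stand-in, since its shape semantics match both torch and jnp here):

```python
import numpy as np

scalar = np.array(-2.0)    # 0-D "scalar" array, shape ()
vector = np.array([-2.0])  # 1-D array of length 1, shape (1,)

# The assertion failure compares exactly these two shapes: () != (1,)
print(scalar.shape, vector.shape)
```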

Taking the last sample input and calling in a test script that uses torch_xla2:

import torch
import torch_xla2

env = torch_xla2.default_env()
env.config.debug_print_each_op = True
#env.config.debug_accuracy_for_each_op = True

with env:
  print(torch.linspace(-2, -3, 1))

--------------------- [output]:
 % python test_logspace.py   
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:270: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
FUNCTION: linspace
 DISPATCH: aten::linspace
  FUNCTION: full
XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-2.])

The call returns a JAX array of length 1.

As does native torch's implementation:

import torch
print(torch.linspace(-2, -3, 1))

-------------[output]:
tensor([-2.])

@barney-s
Contributor

barney-s commented Oct 6, 2024

Investigating further: the failing test cases use tensor(int) as input to linspace. PyTorch's linspace allows a 0-D tensor (scalar) as input. Passing 0-D tensors with steps=1 to the xla env, the output is a scalar and not a 1-D tensor:

import torch
import torch_xla2

env = torch_xla2.default_env()
env.config.debug_print_each_op = True
#env.config.debug_accuracy_for_each_op = True

with env:
  print("torch_xla2.linspace:", torch.linspace(-2, -3, 1))
  print("\n")
  print("torch_xla2.linspace:", torch.linspace(torch.tensor(-2), torch.tensor(-3), 1))


---------------------------[output]:
 % python test_logspace.py
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:270: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
FUNCTION: linspace
 DISPATCH: aten::linspace
  FUNCTION: full
torch_xla2.linspace: XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-2.])


FUNCTION: tensor
/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/ops/jtorch.py:48: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return jnp.array(
FUNCTION: tensor
FUNCTION: linspace
 DISPATCH: aten::linspace.Tensor_Tensor
  FUNCTION: to
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
  FUNCTION: to
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
  FUNCTION: empty
  FUNCTION: copy_
   DISPATCH: aten::copy_
torch_xla2.linspace: XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> -2.0)
barni@barni ~/workspace/pytorch/xla/experimental/torch_xla2

The same function returns a tensor when steps > 1:

with env:
  print("torch_xla2.linspace:", torch.linspace(torch.tensor(-2), torch.tensor(-4), 3))



FUNCTION: tensor
/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/ops/jtorch.py:48: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return jnp.array(
FUNCTION: tensor
FUNCTION: linspace
 DISPATCH: aten::linspace.Tensor_Tensor
  FUNCTION: to
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
  FUNCTION: to
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:68: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype, copy=copy, device=device)
  FUNCTION: arange
   DISPATCH: aten::arange.start_step
  FUNCTION: sub
   DISPATCH: aten::sub.Tensor
  FUNCTION: div
   DISPATCH: aten::div.Tensor
  FUNCTION: lt
   DISPATCH: aten::lt.Scalar
  FUNCTION: to
  FUNCTION: mul
   DISPATCH: aten::mul.Tensor
  FUNCTION: add
   DISPATCH: aten::add.Tensor
  FUNCTION: __rsub__
   DISPATCH: aten::rsub.Scalar
    FUNCTION: sub
     DISPATCH: aten::sub.Tensor
  FUNCTION: to
  FUNCTION: mul
   DISPATCH: aten::mul.Tensor
  FUNCTION: sub
   DISPATCH: aten::sub.Tensor
  FUNCTION: where
   DISPATCH: aten::where.self
torch_xla2.linspace: XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-2. -3. -4.])
native torch.linspace: tensor([-2., -3., -4.])

The steps==1 case calls torch.empty, as seen here: https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L5165
whereas the steps > 1 case calls torch.arange.
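The two code paths can be sketched in numpy (a rough stand-in for the _refs decomposition, not the exact torch code):

```python
import numpy as np

def linspace_sketch(start, end, steps):
    """Rough sketch of the steps==1 vs steps>1 split seen in torch._refs.linspace."""
    if steps == 1:
        # steps == 1: allocate a length-1 buffer and copy `start` into it,
        # which is where the empty()/copy_() path in the trace above comes from
        out = np.empty((1,), dtype=np.float64)
        out[...] = start
        return out
    # steps > 1: arange-based interpolation between start and end
    idx = np.arange(steps, dtype=np.float64)
    step = (end - start) / (steps - 1)
    return start + idx * step
```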

@barney-s
Contributor

barney-s commented Oct 6, 2024

Investigating torch.empty(...).copy_() further. With a 0-D tensor source, the output of copy_() differs between the native and xla cases.

import torch
import torch_xla2

env = torch_xla2.default_env()

def empty(): 
  with env:
    start1 = torch.tensor(-2)
    print("xla   | torch.tensor(-2)   ->", start1)
    start2 = torch.tensor([-2])
    print("xla   | torch.tensor([-2]) ->", start2)
    emp1 = torch.empty((1,))
    ret1 = emp1.copy_(start1)
    print("xla   | torch.empty((1,)).copy_(tensor(-2))  :", ret1)
    emp2 = torch.empty((1,))
    ret2 = emp2.copy_(start2)
    print("xla   | torch.empty((1,)).copy_(tensor([-2])):", ret2)
  
  start1 = torch.tensor(-2)
  print("native| torch.tensor(-2)   ->", start1)
  start2 = torch.tensor([-2])
  print("native| torch.tensor([-2]) ->", start2)
  emp1 = torch.empty((1,))
  ret1 = emp1.copy_(start1)
  print("native| torch.empty((1,)).copy_(tensor(-2))  :", ret1)
  emp2 = torch.empty((1,))
  ret2 = emp2.copy_(start2)
  print("native| torch.empty((1,)).copy_(tensor([-2])):", ret2)
empty()

Output:

xla   | torch.tensor(-2)   -> XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> -2)
xla   | torch.tensor([-2]) -> XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-2])
xla   | torch.empty((1,)).copy_(tensor(-2))  : XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> -2.0)
xla   | torch.empty((1,)).copy_(tensor([-2])): XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-2.])
native| torch.tensor(-2)   -> tensor(-2)
native| torch.tensor([-2]) -> tensor([-2])
native| torch.empty((1,)).copy_(tensor(-2))  : tensor([-2.])
native| torch.empty((1,)).copy_(tensor([-2])): tensor([-2.])

@ManfeiBai ManfeiBai assigned barney-s and unassigned ManfeiBai Oct 7, 2024
@ManfeiBai
Collaborator

Thanks for the investigation here. If I understand correctly, is it possible to modify the _aten_copy implementation to wrap the result with an array or increase its dimension?

@barney-s
Contributor

barney-s commented Oct 8, 2024

Trying this patch, some of the failures went away for linspace.

@op(torch.ops.aten.copy_, is_jax_function=False)
def _aten_copy(x, y, memory_format=None):
  if x.ndim == 1 and y.ndim == 0:
    # case of torch.empty((1,)).copy_(tensor(N))
    # we need to return a 1-D tensor([N]) and not a scalar tensor(N)
    # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131
    x._elem = jnp.array([y._elem.astype(x._elem.dtype)])
  else:
    x._elem = y._elem.astype(x._elem.dtype)
  return x

But it introduces failures for nanquantile. Investigating that, I found that squeeze_ behavior changed:

import torch
import torch_xla2

env = torch_xla2.default_env()

def squeeze():
  with env:
    t1 = torch.tensor([-3.5])
    r1 = t1.squeeze_(-1)
    print("xla   | torch.squeeze :", r1)
  t1 = torch.tensor([-3.5])
  r1 = t1.squeeze_(-1)
  print("native| torch.squeeze :", r1)

squeeze()

output:

FUNCTION: tensor
FUNCTION: squeeze_
 DISPATCH: aten::squeeze_.dim
  FUNCTION: squeeze
  FUNCTION: copy_
   DISPATCH: aten::copy_
xla   | torch.squeeze : XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [-3.5])
native| torch.squeeze : tensor(-3.5000)
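For comparison, numpy's squeeze shows the expected result: the length-1 dimension is removed and a 0-D array comes back, which is the behavior the xla path lost once copy_ forced the shape back to (1,):

```python
import numpy as np

t = np.array([-3.5])           # shape (1,)
r = np.squeeze(t, axis=-1)     # removes the size-1 last dimension
# native torch prints tensor(-3.5000), i.e. a 0-D result like this one
print(r.shape)
```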

@qihqi
Collaborator Author

qihqi commented Oct 8, 2024

copy_ should behave identically to the original torch.Tensor.copy_.

From here: https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html it's not about scalar vs rank 1, but about preserving self's shape (and src broadcasts to self).

So torch.empty((2, 2)).copy_(torch.tensor(1)) should be a 2 x 2 tensor of all 1s, etc.

So something along the lines of:

def _aten_copy(x, y, memory_format=None):
  orig_shape = x._elem.shape
  orig_dtype = x._elem.dtype
  # broadcast src to self's shape, preserving self's dtype
  x._elem = jnp.broadcast_to(y._elem, orig_shape).astype(orig_dtype)
  return x

should work
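Under that reading, copy_ reduces to broadcasting src into self's shape. A numpy stand-in for the suggested fix (broadcast_to playing the role of the jnp call; the helper name is hypothetical):

```python
import numpy as np

def copy_sketch(dst, src):
    """Shape-preserving copy_: broadcast src to dst's shape, keep dst's dtype."""
    dst[...] = np.broadcast_to(src, dst.shape).astype(dst.dtype)
    return dst

a = copy_sketch(np.empty((2, 2)), np.array(1))   # 2x2 tensor of all 1s
b = copy_sketch(np.empty((1,)), np.array(-2))    # shape stays (1,), not ()
```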

@barney-s
Contributor

barney-s commented Oct 8, 2024

Progress

With fixes in #8236 , most tests are passing (for linspace):

  1. Fix _copy to handle self shape=(1,) and src shape=().
  2. For squeeze, where we expect self's shape to change, don't use copy_; replace instead.

Remaining failure can be reduced to behavior similar to this input:

xla   | torch_xla2.linspace(4.9, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 4 3 3 3])
native| torch_xla2.linspace(4.9, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])

Investigating the failure

import torch
import torch_xla2

env = torch_xla2.default_env()
env.config.debug_print_each_op = True

def linspace(): 
  with env:
    print("xla   | torch.linspace(4.9, 3, 5, dtype=torch.int64): ", torch.linspace(4.9, 3, 5, dtype=torch.int64))
  print("native| torch.linspace(4.9, 3, 5, dtype=torch.int64): ", torch.linspace(4.9, 3, 5, dtype=torch.int64))

linspace()

output:

xla   | torch.linspace(4.9, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 4 3 3 3])
native| torch.linspace(4.9, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])

The interesting thing is that while I am debugging, xla's linspace breaks in the torch python lib: https://github.com/pytorch/pytorch/blob/b16167874dddcbb078a113e4d607761b81de940d/torch/_refs/__init__.py#L5107
but the native linspace call (outside env) does not break there!! Is it calling a different implementation of linspace?

@ManfeiBai
Collaborator

linspace(4.9, 3, 5, dtype=torch.int64)

Tried locally too with the barney-s:linspace branch. It looks like the round calculation might be one of the reasons: for end - start < -1, an end value of the form x.1 ~ x.3 will hit issues like this:

# mismatched
xla   | torch_xla2.linspace(4.9, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4, 4, 3, 3, 3])
native| torch_xla2.linspace(4.9, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# mismatched
xla   | torch_xla2.linspace(4.4, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4, 4, 3, 3, 3])
native| torch_xla2.linspace(4.4, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4.3, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4.3, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4.0, 3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4.0, 3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.0, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.0, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.3, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.3, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.4, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.4, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.5, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.5, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.6, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.6, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(4, 3.8, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 3 3])
native| torch_xla2.linspace(4, 3.8, 5, dtype=torch.int64):  tensor([4, 3, 3, 3, 3])
---
# mismatched
xla   | torch_xla2.linspace(4, 2.9, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4, 3, 3, 3, 2])
native| torch_xla2.linspace(4, 2.9, 5, dtype=torch.int64):  tensor([4, 3, 3, 2, 2])
---
# MATCHED
xla   | torch_xla2.linspace(4, 2.5, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 2 2])
native| torch_xla2.linspace(4, 2.5, 5, dtype=torch.int64):  tensor([4, 3, 3, 2, 2])
---
# mismatched
xla   | torch_xla2.linspace(4, 2.7, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4, 3, 3, 3, 2])
native| torch_xla2.linspace(4, 2.7, 5, dtype=torch.int64):  tensor([4, 3, 3, 2, 2])
---
# MATCHED
xla   | torch_xla2.linspace(4, 2.6, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4 3 3 2 2])
native| torch_xla2.linspace(4, 2.6, 5, dtype=torch.int64):  tensor([4, 3, 3, 2, 2])
---
# mismatched
xla   | torch_xla2.linspace(4, 1.9, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [4, 3, 2, 2, 1])
native| torch_xla2.linspace(4, 1.9, 5, dtype=torch.int64):  tensor([4, 3, 2, 1, 1])
---
# mismatched
xla   | torch_xla2.linspace(5, 3.9, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [5, 4, 4, 4, 3])
native| torch_xla2.linspace(5, 3.9, 5, dtype=torch.int64):  tensor([5, 4, 4, 3, 3])
---
# MATCHED
xla   | torch_xla2.linspace(5, 6.1, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [5 5 5 5 6])
native| torch_xla2.linspace(5, 6.1, 5, dtype=torch.int64):  tensor([5, 5, 5, 5, 6])
---
# MATCHED
xla   | torch_xla2.linspace(5, 2.0, 5, dtype=torch.int64):  XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [5 4 3 2 2])
native| torch_xla2.linspace(5, 2.0, 5, dtype=torch.int64):  tensor([5, 4, 3, 2, 2])
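The mismatched rows are consistent with the xla path computing the exact float values and then truncating them, e.g. for the first row (numpy used as a stand-in):

```python
import numpy as np

vals = np.linspace(4.9, 3, 5)   # exact values: ~[4.9, 4.425, 3.95, 3.475, 3.0]
trunc = vals.astype(np.int64)   # astype truncates toward zero
# this reproduces the xla output [4, 4, 3, 3, 3], not native torch's [4, 3, 3, 3, 3]
print(trunc)
```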

@qihqi
Collaborator Author

qihqi commented Oct 9, 2024

Looks like there is a bug in torch's linspace: pytorch/pytorch#137546

@barney-s
Contributor

barney-s commented Oct 9, 2024

Tested the jaten.py implementation as well as xla's pytorch decomposition.

For the float dtype, the output matches.
For the int64 dtype, the output does not match native for either the jax or the pytorch decomposition!!

So it is down to the rounding logic when converting from float to int, and the rounding should be done after the initial generation.

xla override (jax implementation):

# ops/jtorch.py: 

@register_function(torch.linspace)
def linspace(start, end=None, steps=None, 
  out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
  new_start = start
  new_end = end
  #new_start = jax.lax.convert_element_type(start, mappings.t2j_dtype(dtype))
  #new_end = jax.lax.convert_element_type(end, mappings.t2j_dtype(dtype))
  if dtype is not None:
    dtype=mappings.t2j_dtype(dtype)
  ret = jnp.linspace(new_start, new_end, steps,
                      endpoint=True,
                      #dtype=dtype,
                      )
  return jnp.astype(ret, dtype)

test script:

def linspace():
  dtype = torch.float
  with env:
    print("xla   | torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.9, 3, 5, dtype=dtype))
  with env:
    print("xla   | torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(-2, -3, 50, dtype=dtype))
  with env:
    print("xla   | torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))
  print("native| torch.linspace(): ", torch.linspace(4.3, -3, 50, dtype=dtype))

linspace()

barney-s added a commit to barney-s/xla that referenced this issue Oct 9, 2024
1. fix _copy to handle self shape=(1) and copy src shape=0, ref: pytorch#7505 (comment)
2. for squeeze where we expect the self's shape to change, dont use copy_ instead replace.
3. in test_ops reset dtype to float when an int64 is passed for linspace
   case. This is to workaround known pytorch failure: pytorch/pytorch#137546

ref:
* pytorch#7505 (comment)
* pytorch#7505 (comment)
* pytorch#7505 (comment)
* pytorch/pytorch#137546
barney-s added a commit to barney-s/xla that referenced this issue Oct 9, 2024
1. fix _copy to handle self shape=(1) and copy src shape=0, ref: pytorch#7505 (comment)
2. for squeeze where we expect the self's shape to change, dont use copy_ instead replace.
3. in test_ops reset dtype to float when an int64 is passed for linspace
   case. This is to workaround known pytorch failure: pytorch/pytorch#137546
4. logspace tests depend on linspace. Both are passing now

ref:
* pytorch#7505 (comment)
* pytorch#7505 (comment)
* pytorch#7505 (comment)
* pytorch/pytorch#137546
@barney-s
Contributor

barney-s commented Oct 9, 2024

Working on log_normal

log_normal has a pytorch decomposition: pytorch/pytorch#91674
So registering it with xla: #8247
Since the api is non-deterministic by nature, the regular test methodology of diffing xla and native output won't work.

So excluding the test from numerical comparison.
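One common way to test a non-deterministic op is to compare sample statistics against the distribution's analytic moments instead of diffing outputs; a numpy sketch of that idea (sample size and tolerance are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0

# log_normal draws: exp of a normal sample
samples = np.exp(rng.normal(mu, sigma, size=100_000))

# analytic mean of LogNormal(mu, sigma) is exp(mu + sigma^2 / 2)
expected_mean = np.exp(mu + sigma**2 / 2)
print(samples.mean(), expected_mean)
```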

Ref:

barney-s added a commit to barney-s/xla that referenced this issue Oct 9, 2024
* Use the existing log_normal decomposition
* Skip numerical comparision for the operation

ref: pytorch#7505 (comment)

fixes: pytorch#7505
@barney-s
Contributor

barney-s commented Oct 10, 2024

Working on linalg.vector_norm

attempt 0

tests failing:

[DEBUG] op:  linalg.vector_norm , sample_input:  SampleInput(input=-0.06738138198852539, args=(), kwargs={'ord': 0}, broadcasts_input=False, name='')
[DEBUG] input: torch.float32
output2 tensor(1., dtype=torch.float64)  , output1 tensor(1.)

...

FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_linalg_vector_norm_cpu_float32 - AssertionError: The values for attribute 'dtype' do not match: torch.float64 != torch.float32.

Trying out in a script:

import torch
import torch_xla2
import jax.numpy as jnp

env = torch_xla2.default_env()
env.config.debug_print_each_op = True

def linalg_vector_norm():
  with env:
    t = torch.tensor(-0.06738138198852539)
    print("xla   | linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
  t = torch.tensor(-0.06738138198852539)
  print("native| linalg.vector_norm()", torch.linalg.vector_norm(t, ord=0).dtype)
  
linalg_vector_norm()

output:

FUNCTION: tensor
FUNCTION: linalg_vector_norm
 DISPATCH: aten::linalg_vector_norm
xla   | linalg.vector_norm() torch.float32
native| linalg.vector_norm() torch.float32

attempt 1

Debugging _aten_linalg_vector_norm in experimental/torch_xla2/torch_xla2/ops/jaten.py, I found that for the ord==0 case, typecasting to float always sets the result type to float64.

Minor fix that fixes the test failure and passes all tests:
#8249
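For reference, the ord=0 vector norm is just the count of nonzero entries; a numpy sketch of a dtype-preserving version (hypothetical helper, not the actual jaten code):

```python
import numpy as np

def vector_norm_ord0(x):
    # ord=0: number of nonzero elements, cast back to the input's
    # (floating) dtype instead of letting it default to float64
    return (x != 0).sum().astype(x.dtype)

r = vector_norm_ord0(np.array([-0.0674], dtype=np.float32))
print(r, r.dtype)
```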

@barney-s
Contributor

barney-s commented Oct 10, 2024

Working on linalg.tensorsolve

attempt 1

Test Failure seen

FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_linalg_tensorsolve_cpu_float32 - RuntimeError: Expected Tensor but got None

debug script:

import torch
import torch_xla2
import jax.numpy as jnp

env = torch_xla2.default_env()
env.config.debug_print_each_op = True

def linalg_tensorsolve():
  with env:
    A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
         [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
         [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
        [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
         [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
         [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
    B = torch.tensor([[-5.2537,  7.7364,  4.0160],
        [ 4.3621,  0.4733, -4.6142]])
    print("xla   | linalg.vectorsolve()", torch.linalg.tensorsolve(A, B))
  A = torch.tensor([[[-0.0674,  4.8280, -7.4074, -6.6235, -3.4664,  2.4134],
         [-0.1783,  7.1360, -0.7987,  2.3815, -2.7199, -1.7691],
         [-8.5981, -5.9605, -3.7100,  0.3334,  3.5580,  5.4002]],
        [[-6.1015, -3.9192,  3.2690,  7.4735, -1.8522,  6.7348],
         [-1.4507,  0.9523,  8.1493, -8.3490, -5.6658, -2.2785],
         [-3.5082,  7.7760, -5.8336, -4.1430, -6.2878, -8.4290]]])
  B = torch.tensor([[-5.2537,  7.7364,  4.0160],
        [ 4.3621,  0.4733, -4.6142]])
  print("native| linalg.vectorsolve()", torch.linalg.tensorsolve(A, B))
  
linalg_tensorsolve()

debug script output:

%      python test_logspace.py
/usr/local/google/home/barni/miniconda3/envs/torch_xla2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:270: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
WARNING:root:Duplicate op registration for aten.trunc
FUNCTION: tensor
FUNCTION: tensor
FUNCTION: linalg_tensorsolve
 DISPATCH: aten::view
 DISPATCH: aten::view
 DISPATCH: aten::_linalg_solve_ex
[DEBUG]: linalg_solve_ex ret: [-1.2417108   2.068498    1.2399809   0.08891661  2.002629    0.576693  ]
Traceback (most recent call last):
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/test_logspace.py", line 125, in <module>
    linalg_tensorsolve()
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/test_logspace.py", line 108, in linalg_tensorsolve
    print("xla   | linalg.vectorsolve()", torch.linalg.tensorsolve(A, B))
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in __torch_function__
    return func(*args, **(kwargs or {}))
RuntimeError: Expected Tensor but got None

Refs:
https://github.com/pytorch/pytorch/blob/575f260229a1e691b70d4aff9bd2d919b902c395/aten/src/ATen/native/LinearAlgebra.cpp#L3319

Was able to debug this down to the @op(torch.ops.aten._linalg_solve_ex) implementation.
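For context, tensorsolve flattens A into a square matrix and solves the resulting linear system; numpy's reference behavior for the shapes used above:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3, 6))   # prod of leading dims (2*3) equals trailing dim (6)
B = rng.normal(size=(2, 3))

x = np.linalg.tensorsolve(A, B)                      # solution has shape (6,)
x_ref = np.linalg.solve(A.reshape(6, 6), B.reshape(6))  # equivalent flattened solve
```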

@ManfeiBai ManfeiBai reopened this Oct 14, 2024