XLA2 does not support maxpool #8241

Open
Chaosruler972 opened this issue Oct 9, 2024 · 0 comments

🐛 Bug

The max-pool operator in torch_xla2 (F.max_pool2d) crashes when it is called under jax.value_and_grad.

To Reproduce

  1. Download the MNIST toy example I prepared from here
  2. Move it to a Trillium machine
  3. Run it

Traceback:
Traceback (most recent call last):
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 174, in <module>
jax_weights, opt_state, loss = training_step(jax_weights, jax_buffers, opt_state, x_j, target_j)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 102, in training_step
loss, grads = jax.value_and_grad(forward)(jax_weights, buffers, x, target)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 89, in forward
pred = jittable_model.functional_call('forward', weights, buffers, x)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/interop.py", line 73, in functional_call
res = getattr(self._model, method_name)(*args, **kwargs)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 68, in forward
x = F.max_pool2d(x, 2, stride=2, return_indices=False)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
return if_false(*args, **kwargs)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 783, in _max_pool2d
return handle_torch_function(
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/overrides.py", line 1630, in handle_torch_function
result = mode.torch_function(public_api, types, args, kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in torch_function
return func(*args, **(kwargs or {}))
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
return if_false(*args, **kwargs)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 796, in _max_pool2d
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in torch_dispatch
return self.env.dispatch(func, types, args, kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 430, in dispatch
res = op.func(args, **kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/ops/jaten.py", line 1136, in _aten_max_pool2d_with_indices
indices, _ = jax.lax.reduce_window(
ValueError: Operands must have the same tree structure as init_values: PyTreeDef([
, *, CustomNode(Zero[ShapedArray(float0[1024,64,24,24])], []), ]) vs. PyTreeDef([, *, *, *])

Expected behavior

Max pooling should behave identically in torch_xla2 and PyTorch, and the operator should be fully wrapped.

Environment

  • Reproducible on XLA backend TPU
  • JAX version 0.4.43
  • Trillium machine
  • torch_xla2 version: 0.0.1

Additional context

After digging in, I noticed that the operator as written in xla2 has two code paths: one that computes only the max-pool values, and one that computes both values and indices.

This is done so that the XLA compiler can ignore the index computation when the indices are unused, which should give better performance because the indices are only computed when something actually consumes them.

A custom kernel was written for the JAX window function (jax.lax.reduce_window). It receives tuples holding the index and the value for each position covered by the window.

The logic in the kernel is correct and should work. However, this use of the windowing function does not seem to be supported in JAX, which keeps some odd internal state per input to better understand its relationship across devices.
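
To make the kernel logic described above concrete, here is a minimal pure-Python sketch of a reduction over (value, index) pairs for a single 2D plane. It is not the torch_xla2 implementation and does not call jax.lax.reduce_window; the names pick_max and max_pool2d_with_indices are hypothetical, and the tie-breaking rule (keep the earlier position) and the flat row-major index convention are assumptions made for illustration.

```python
from functools import reduce


def pick_max(a, b):
    """Reducer over (value, index) pairs: keep the pair with the larger value.

    Ties keep the left operand, i.e. the earlier position in the window
    (an assumption for this sketch).
    """
    return b if b[0] > a[0] else a


def max_pool2d_with_indices(plane, kernel=2, stride=2):
    """Max-pool one 2D plane (a list of rows); return (values, flat_indices)."""
    height, width = len(plane), len(plane[0])
    out_h = (height - kernel) // stride + 1
    out_w = (width - kernel) // stride + 1
    init = (float("-inf"), -1)  # identity element for pick_max
    values, indices = [], []
    for i in range(out_h):
        value_row, index_row = [], []
        for j in range(out_w):
            # Pair every element in the window with its flat row-major index.
            window = [
                (plane[r][c], r * width + c)
                for r in range(i * stride, i * stride + kernel)
                for c in range(j * stride, j * stride + kernel)
            ]
            best_value, best_index = reduce(pick_max, window, init)
            value_row.append(best_value)
            index_row.append(best_index)
        values.append(value_row)
        indices.append(index_row)
    return values, indices


plane = [
    [1.0, 5.0, 2.0, 0.0],
    [3.0, 4.0, 7.0, 6.0],
    [9.0, 8.0, 1.0, 1.0],
    [0.0, 2.0, 3.0, 1.0],
]
values, indices = max_pool2d_with_indices(plane)
print(values)   # [[5.0, 7.0], [9.0, 3.0]]
print(indices)  # [[1, 6], [8, 14]]
```

The reducer carries two operands per element (value and index) and therefore needs two matching init values. That is the same shape of call that the error message in the traceback complains about when the operand and init_values tree structures differ.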
