🐛 Bug
The max_pool2d operator from torch_xla2 crashes during the backward pass.
To Reproduce
Traceback:
Traceback (most recent call last):
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 174, in <module>
    jax_weights, opt_state, loss = training_step(jax_weights, jax_buffers, opt_state, x_j, target_j)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 102, in training_step
    loss, grads = jax.value_and_grad(forward)(jax_weights, buffers, x, target)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 89, in forward
    pred = jittable_model.functional_call('forward', weights, buffers, x)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/interop.py", line 73, in functional_call
    res = getattr(self._model, method_name)(*args, **kwargs)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 68, in forward
    x = F.max_pool2d(x, 2, stride=2, return_indices=False)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 783, in _max_pool2d
    return handle_torch_function(
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/overrides.py", line 1630, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in __torch_function__
    return func(*args, **(kwargs or {}))
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 796, in _max_pool2d
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in __torch_dispatch__
    return self.env.dispatch(func, types, args, kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 430, in dispatch
    res = op.func(*args, **kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/ops/jaten.py", line 1136, in _aten_max_pool2d_with_indices
    indices, _ = jax.lax.reduce_window(
ValueError: Operands must have the same tree structure as init_values: PyTreeDef([*, *, CustomNode(Zero[ShapedArray(float0[1024,64,24,24])], []), *]) vs. PyTreeDef([*, *, *, *])
Expected behavior
max_pool2d should behave exactly the same under torch_xla2 as under native PyTorch, and the operator should be completely wrapped (including its backward pass).
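For concreteness, here is a minimal NumPy reference (a sketch, not the torch_xla2 implementation) of the semantics PyTorch's max_pool2d guarantees, including the convention that the returned indices are flattened over the H×W plane of each channel:

```python
import numpy as np

def max_pool2d_ref(x, k=2, s=2):
    """Reference NCHW max-pool, no padding; mirrors F.max_pool2d(x, k, stride=s, return_indices=True)."""
    n, c, h, w = x.shape
    oh, ow = (h - k) // s + 1, (w - k) // s + 1
    vals = np.empty((n, c, oh, ow), x.dtype)
    idxs = np.empty((n, c, oh, ow), np.int64)
    for i in range(oh):
        for j in range(ow):
            win = x[:, :, i * s:i * s + k, j * s:j * s + k].reshape(n, c, -1)
            vals[:, :, i, j] = win.max(-1)
            a = win.argmax(-1)
            # PyTorch flattens indices over the full H*W plane per channel.
            idxs[:, :, i, j] = (i * s + a // k) * w + (j * s + a % k)
    return vals, idxs
```

For example, with x = np.arange(16).reshape(1, 1, 4, 4) the pooled values are [[5, 7], [13, 15]], and because the input is an arange, the flattened indices coincide with the values.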
Environment
Reproducible on the XLA TPU backend
JAX version: 0.4.43
Machine: Trillium
torch_xla2 version: 0.0.1
Additional context
After digging in, I noticed that the torch_xla2 implementation of the operator has two code paths: one that computes only the max-pool values, and one that computes both the values and the indices.
It is written this way so that, when the indices are not used, the XLA compiler can drop their computation entirely, which improves performance; the indices are only computed when they are actually consumed.
A custom reduction function was written for JAX's windowing primitive (jax.lax.reduce_window); it receives a (value, index) pair for each element of the window being reduced.
The logic in that reduction function is correct and should work. However, this variadic form of reduce_window is not fully supported by JAX under differentiation: JAX tracks tangent structure per input (note the Zero[ShapedArray(float0[...])] node in the error), and that symbolic zero tangent violates the requirement that the operands and init_values share the same tree structure.
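Concretely, the pattern described above — a reduce_window whose reducer carries a (value, index) pair per window element — can be sketched in pure JAX as follows. The forward pass evaluates fine; it is differentiating through this variadic form that triggers the tree-structure error shown in the traceback (names here are hypothetical):

```python
import jax
import jax.numpy as jnp

def max_pool_with_indices(x, window=2, stride=2):
    # Carry a flat position index alongside each value through the reduction.
    idx = jnp.arange(x.size, dtype=jnp.int32).reshape(x.shape)

    def reducer(acc, elem):
        acc_val, acc_idx = acc
        val, i = elem
        take = val > acc_val
        return jnp.where(take, val, acc_val), jnp.where(take, i, acc_idx)

    init = (jnp.array(-jnp.inf, x.dtype), jnp.array(-1, jnp.int32))
    # Multi-operand reduce_window: operands and init_values must share one
    # pytree structure -- the invariant the ValueError says is violated once
    # a symbolic Zero tangent appears among the operands during grad tracing.
    vals, idxs = jax.lax.reduce_window(
        (x, idx), init, reducer,
        window_dimensions=(1, 1, window, window),
        window_strides=(1, 1, stride, stride),
        padding='VALID')
    return vals, idxs
```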