Cannot move tensors to cpu when in a xmp spawn process #8271

Open
radna0 opened this issue Oct 17, 2024 · 4 comments

Comments

radna0 commented Oct 17, 2024

🐛 Bug

all_frames = torch.cat(all_frames, dim=0).cpu().numpy()
RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.20571 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.20570), replica_groups={{0,1}}, dimensions={0}

To Reproduce

Steps to reproduce the behavior:

  1. Spawn processes with xmp.spawn.
  2. Move the tensors to the CPU with .cpu() (a minimal sketch of this pattern follows below).
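
In code, the two steps look roughly like the sketch below (made-up shapes and a hypothetical _mp_fn; the full pipeline that actually triggers the error is in the repo linked in a comment below):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    # Each spawned process builds some tensors on its XLA device ...
    device = xm.xla_device()
    all_frames = [torch.randn(8, 3, 64, 64, device=device) for _ in range(4)]
    # ... and then tries to bring the result back to the host; in the real
    # pipeline this is the call that raises the RuntimeError above.
    frames = torch.cat(all_frames, dim=0).cpu().numpy()


if __name__ == "__main__":
    xmp.spawn(_mp_fn)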

Expected behavior

Tensors should move to the CPU without errors.

Environment

  • Reproducible on XLA backend: TPU v2-8 and v3-8
  • torch_xla version: nightly 2.6

Additional context

radna0 changed the title from "Cannot move to cpu when in a xmp spawn process" to "Cannot move tensors to cpu when in a xmp spawn process" on Oct 17, 2024
JackCaoG (Collaborator) commented:

do you have a small repo code?

radna0 commented Oct 17, 2024

You can clone the repo here:
git clone https://github.com/radna0/Video-Infinity.git
install the requirements with
pip install -r requirements.txt
and test the code with
accelerate launch tpu_inference.py --config examples/config.json
Let me know if I am missing anything.

radna0 commented Oct 19, 2024

Were you able to reproduce the error? @JackCaoG

radna0 commented Oct 24, 2024

It's been a week, and I'm still encountering this problem. I have tried different approaches, for example dist.gather(), tensor.cpu(), tensor.contiguous(), and other ways of saving tensors or moving them to the CPU, and they all run into the same problem, even with xm.mark_step(). There seems to be no way around it, and it is always the same error: replica groups should contain 8 replicas, but found 2. Is there something wrong that I could be doing here? What I'm basically doing is the following:

  1. Spawn processes using torch_xla.launch().
  2. For each rank, declare a distributed controller:
import os

import torch.distributed as dist
import torch_xla


class DistController(object):
    def __init__(self, rank, world_size, config) -> None:
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.config = config
        self.is_master = rank == 0
        self.device = torch_xla.device()
        self.init_dist()
        self.init_group()

    def init_dist(self):
        # Join the global "xla" process group (one rank per XLA device).
        print(
            f"Rank {self.rank}, {self.device} / {self.world_size} is running on XLA device."
        )
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(self.config.get("master_port") or "29500")
        dist.init_process_group("xla", rank=self.rank, world_size=self.world_size)

    def init_group(self):
        # Pairwise groups of adjacent ranks: [0, 1], [1, 2], ..., [world_size - 2, world_size - 1].
        self.adj_groups = [
            dist.new_group([i, i + 1]) for i in range(self.world_size - 1)
        ]
        print(f"Rank {self.rank} initialized groups: {self.adj_groups}")

  3. Initialize the model, move it to the XLA device, and run inference (a stripped-down sketch of the whole pattern follows after the traceback below).
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
    return fn()
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 185, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/home/kojoe/Video-Infinity/tpu_inference.py", line 86, in main
    obj = run_inference(rank, config)
  File "/home/kojoe/Video-Infinity/tpu_inference.py", line 59, in run_inference
    obj = dist_pipe.inference(
  File "/home/kojoe/Video-Infinity/src/video_infinity/wrapper.py", line 241, in inference
    xm.mark_step()
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1046, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.320 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.319), replica_groups={{0,1}}, dimensions={0}
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kojoe/Video-Infinity/tpu_inference.py", line 102, in <module>
    torch_xla.launch(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 233, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 39, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.320 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.319), replica_groups={{0,1}}, dimensions={0}
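
For reference, here is a stripped-down sketch of the same pattern (my own reduction, with made-up shapes and only the first pairwise group; this is not the Video-Infinity code). It is intended to exercise the same path: a collective on a 2-rank group, followed by a device-to-host copy that forces compilation on the 8-device host.

import os

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(rank):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = xr.world_size()
    dist.init_process_group("xla", rank=rank, world_size=world_size)

    device = torch_xla.device()
    # Adjacent 2-rank groups, as in DistController.init_group().
    adj_groups = [dist.new_group([i, i + 1]) for i in range(world_size - 1)]

    x = torch.randn(1024, 8, 64, device=device)
    if rank in (0, 1):
        # Collective issued on a group that covers only 2 of the 8 replicas.
        gathered = [torch.zeros_like(x) for _ in range(2)]
        dist.all_gather(gathered, x, group=adj_groups[0])
        x = gathered[0] + gathered[1]

    # The device-to-host copy forces the traced graph to compile and run
    # (similar to xm.mark_step()); this is where the verifier error appears.
    x.cpu()


if __name__ == "__main__":
    torch_xla.launch(_mp_fn)

If my reduction is right, running this on a v2-8/v3-8 should hit the same "replica groups should contain 8 replicas, but found 2" RET_CHECK failure as the full pipeline.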
