A process in the process pool was terminated abruptly while the future was running or pending. #8234

Open
fancy45daddy opened this issue Oct 8, 2024 · 10 comments

@fancy45daddy

❓ Questions and Help

I want to run PyTorch/XLA on a Kaggle TPU v3-8 and use all of the TPU cores, but I always get: A process in the process pool was terminated abruptly while the future was running or pending.

Source code:

%%bash
python3 -m pip install -U imageio[ffmpeg] controlnet-aux openmim
mim install mmengine mmcv mmdet mmpose
curl -O https://raw.githubusercontent.com/huggingface/controlnet_aux/master/src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
curl -O https://raw.githubusercontent.com/huggingface/controlnet_aux/master/src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py

import torch_xla, diffusers, builtins, imageio, os, PIL.Image, controlnet_aux, sys, torch
os.environ.pop('TPU_PROCESS_ADDRESSES')

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
poses = [openpose(PIL.Image.fromarray(reader.get_data(_)).resize((512, 768))) for _ in builtins.range(64)] #reader.count_frames()
length = builtins.len(poses) // 8
fps = reader.get_meta_data().get('fps')

def process(index):
    controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
    pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16).to(torch_xla.core.xla_model.xla_device())
    pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.enable_attention_slicing()
    pipeline.unet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pipeline.controlnet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
    imageio.mimsave(f'{index}.mp4', pipeline(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)

torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

and get

BrokenProcessPool                         Traceback (most recent call last)
Cell In[2], line 20
     17     pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
     18     imageio.mimsave(f'{index}.mp4', pipeline(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
---> 20 torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')
     21 result = []
     22 for _ in builtins.range(8):

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 @xr.requires_pjrt
      7 def spawn(fn,
      8           args=(),
   (...)
     11           daemon=False,
     12           start_method='spawn'):
     13   """Enables multi processing based replication.
     14 
     15   Args:
   (...)
     36     return None.
     37   """
---> 38   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
    211 elif nprocs is not None:
    212   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
--> 174   replica_results = list(
    175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
    174   replica_results = list(
--> 175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

A single process on a single TPU core works well, but running multiple processes to use all TPU cores does not.

You can copy and test my code at https://www.kaggle.com/code/chaowenguoback/stablediffusion. Please help.
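
For readers hitting the same message: `BrokenProcessPool` comes from Python's `concurrent.futures`, which the traceback shows `run_multiprocess` using through `executor.map`. It is raised when a worker process dies without returning a result, for example when the operating system kills it, which is why the Python traceback carries no root cause. The sketch below reproduces the error with the standard library only. The function name is hypothetical, and the `fork` context assumes a POSIX host.

```python
import concurrent.futures
import multiprocessing
import os
from concurrent.futures.process import BrokenProcessPool


def worker_death_raises_broken_pool() -> bool:
    """Return True if an abruptly terminated worker surfaces as BrokenProcessPool."""
    ctx = multiprocessing.get_context("fork")  # assumption: POSIX host
    with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor:
        # os._exit ends the worker immediately, without cleanup or a result,
        # much like a process that is killed from outside.
        future = executor.submit(os._exit, 1)
        try:
            future.result()
        except BrokenProcessPool:
            return True
    return False


if __name__ == "__main__":
    print(worker_death_raises_broken_pool())
```

Because the parent only learns that a worker vanished, the useful evidence is usually outside the traceback: worker stderr, kernel logs, or a notebook message about memory.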

@JackCaoG
Collaborator

JackCaoG commented Oct 8, 2024

It is from

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

Which version of PyTorch/XLA are you using? I wonder how you still triggered the XRT runtime, which has been deprecated for a long time.

@fancy45daddy
Author

import torch_xla
torch_xla.__version__

2.4.0+libtpu

@zpcore
Collaborator

zpcore commented Oct 9, 2024

I just did a simple test on TPU v3-8 with the torch_xla 2.4.0 version:

import torch, torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp
def process(index):
    print(index)

xmp.spawn(process, start_method='fork')

You should see it print 0, 1, ..., 7, which means it is using all 8 cores.

@fancy45daddy
Author

@zpcore Thank you, I can now see the code running on all TPU cores on Kaggle v3-8. The new problem is that after torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork') only one process returns successfully; all the others are broken. When I run a single process, it works. Is it possible to limit the run to just two cores, so I can debug faster?

@zpcore
Collaborator

zpcore commented Oct 9, 2024

On TPU, you can either use all processes by leaving the nprocs argument as None, or use a single process by setting nprocs=1. I don't think you can run with 2 processes. I can't reproduce the problem with the code you posted. Did you see any error message from the broken processes?
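
A minimal sketch of the single-process option for debugging. The `nprocs=1` argument follows the comment above and the `spawn` signature visible in the traceback; the helper name `run_for_debugging` and the fallback for machines without `torch_xla` are assumptions added for illustration.

```python
from typing import Callable

try:
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:  # e.g. a machine without torch_xla installed
    xmp = None


def run_for_debugging(fn: Callable[[int], None]) -> None:
    """Run fn on a single device so a failure is easier to read."""
    if xmp is None:
        # Fallback (assumption): call the function directly with index 0.
        fn(0)
    else:
        # nprocs=1 asks for a single process instead of all devices.
        xmp.spawn(fn, args=(), nprocs=1, start_method="fork")
```

For example, `run_for_debugging(process)` would call `process` with index 0 only, so any exception comes from one worker rather than being interleaved with seven others.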

@fancy45daddy
Author

fancy45daddy commented Oct 10, 2024

@zpcore I figured out the problem: the StableDiffusionControlNetPipeline is very large and occupies a lot of memory. With 8 processes I create 8 StableDiffusionControlNetPipeline instances, and together they eat up all the memory.
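
The arithmetic behind this can be sketched as follows. The parameter count is a hypothetical placeholder, not a measurement of this pipeline, and the estimate covers weights only (no activations, compilation buffers, or framework overhead).

```python
def replicated_weights_gib(num_params: int, bytes_per_param: int, num_replicas: int) -> float:
    """Estimate memory used by num_replicas full copies of the model weights, in GiB."""
    return num_params * bytes_per_param * num_replicas / 2**30


if __name__ == "__main__":
    hypothetical_params = 1_500_000_000  # placeholder, not the real model size
    bf16_bytes = 2  # torch.bfloat16 stores each value in 2 bytes
    one_copy = replicated_weights_gib(hypothetical_params, bf16_bytes, 1)
    eight_copies = replicated_weights_gib(hypothetical_params, bf16_bytes, 8)
    print(f"one copy: {one_copy:.1f} GiB, eight copies: {eight_copies:.1f} GiB")
```

The cost grows linearly with the number of replicas, so a pipeline that fits comfortably once can exhaust host memory when every worker loads its own copy.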

I have no idea how to reduce the memory usage. Currently I use

import torch_xla, diffusers, builtins, imageio, os, PIL.Image, controlnet_aux, sys, torch
os.environ.pop('TPU_PROCESS_ADDRESSES')

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
poses = [openpose(PIL.Image.fromarray(reader.get_data(_)).resize((512, 768))) for _ in builtins.range(16)] #reader.count_frames()
length = builtins.len(poses) // 8
fps = reader.get_meta_data().get('fps')
controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16)
pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_attention_slicing()
pipeline.unet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
pipeline.controlnet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())

def process(index):
    pipe = diffusers.StableDiffusionControlNetPipeline(**pipeline.components)
    pipe.to(torch_xla.core.xla_model.xla_device())
    pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
    imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)

torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

to test, but it does not work:

 30%|███       | 6/20 [00:03<00:09,  1.51it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 35%|███▌      | 7/20 [00:04<00:11,  1.17it/s]
  5%|▌         | 1/20 [00:01<00:22,  1.20s/it]
  5%|▌         | 1/20 [00:01<00:22,  1.16s/it]
  5%|▌         | 1/20 [00:01<00:23,  1.22s/it]
 40%|████      | 8/20 [00:06<00:12,  1.02s/it]
 10%|█         | 2/20 [00:02<00:24,  1.36s/it]
 10%|█         | 2/20 [00:02<00:23,  1.29s/it]
 10%|█         | 2/20 [00:02<00:23,  1.33s/it]
 45%|████▌     | 9/20 [00:07<00:12,  1.14s/it]
 15%|█▌        | 3/20 [00:04<00:23,  1.39s/it]
 15%|█▌        | 3/20 [00:03<00:22,  1.35s/it]
 15%|█▌        | 3/20 [00:04<00:23,  1.38s/it]
 50%|█████     | 10/20 [00:09<00:12,  1.25s/it]
 20%|██        | 4/20 [00:05<00:22,  1.44s/it]
 20%|██        | 4/20 [00:05<00:22,  1.40s/it]
 20%|██        | 4/20 [00:05<00:23,  1.44s/it]
 55%|█████▌    | 11/20 [00:10<00:11,  1.33s/it]
 25%|██▌       | 5/20 [00:07<00:22,  1.47s/it]
 25%|██▌       | 5/20 [00:06<00:21,  1.44s/it]
 25%|██▌       | 5/20 [00:07<00:22,  1.49s/it]
 55%|█████▌    | 11/20 [12:32<10:15, 68.41s/it]
 55%|█████▌    | 11/20 [13:04<10:41, 71.28s/it]
 25%|██▌       | 5/20 [13:57<41:53, 167.55s/it]
 25%|██▌       | 5/20 [14:01<42:05, 168.35s/it]
 55%|█████▌    | 11/20 [14:53<12:11, 81.24s/it]
 55%|█████▌    | 11/20 [14:57<12:14, 81.60s/it]
 25%|██▌       | 5/20 [16:05<48:17, 193.18s/it]
 25%|██▌       | 5/20 [16:10<48:31, 194.09s/it]

---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
    replica_results = list(
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
    return fn()
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 190, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/tmp/ipykernel_13/28715152.py", line 20, in process
    imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py", line 1282, in __call__
    noise_pred = self.unet(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward
    sample = upsample_block(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
    hidden_states = attn(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 442, in forward
    hidden_states = block(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py", line 530, in forward
    ff_output = self.ff(norm_hidden_states)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py", line 1166, in forward
    hidden_states = module(hidden_states)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device() 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
	
	torch::lazy::MultiWait::Complete(std::function<void ()> const&)
	Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
	void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
	
	
	__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))
"""

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 22
     19     pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
     20     imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
---> 22 torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 @xr.requires_pjrt
      7 def spawn(fn,
      8           args=(),
   (...)
     11           daemon=False,
     12           start_method='spawn'):
     13   """Enables multi processing based replication.
     14 
     15   Args:
   (...)
     36     return None.
     37   """
---> 38   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
    211 elif nprocs is not None:
    212   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
--> 174   replica_results = list(
    175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
    174   replica_results = list(
--> 175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device() 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
	
	torch::lazy::MultiWait::Complete(std::function<void ()> const&)
	Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
	void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
	
	
	__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))

Do you have any idea how to reduce the memory usage?
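
A note on the check failure above, offered as an assumption rather than a confirmed diagnosis: the message compares `TPU_1(process=0, ...)` with `TPU_0(process=0, ...)`, and the traceback goes through `_run_thread_per_device`, which suggests two threads in one process each drive their own device. If both threads build a pipeline from the same `pipeline.components`, they share the same module objects, and `nn.Module.to` moves a module in place, so the second `.to(...)` can relocate weights the first thread is still using. The sketch below illustrates only that sharing behavior, using a hypothetical `FakeModule` stand-in and plain strings for devices; it does not touch torch or a TPU.

```python
import copy


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module; it only tracks a device name."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "FakeModule":
        # Moves in place and returns self, as nn.Module.to does for modules.
        self.device = device
        return self


components = {"unet": FakeModule()}

# Two "pipelines" built from the same components share one module object.
pipe_a = dict(components)
pipe_b = dict(components)
pipe_a["unet"].to("xla:0")
pipe_b["unet"].to("xla:1")
shared_device_seen_by_a = pipe_a["unet"].device  # overwritten by pipe_b's move

# Deep-copying the components gives each pipeline its own module.
own_a = copy.deepcopy(components)
own_b = copy.deepcopy(components)
own_a["unet"].to("xla:0")
own_b["unet"].to("xla:1")
```

Giving each worker its own copy avoids the in-place conflict, but it also brings back the per-replica memory cost that sharing was meant to avoid.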

@fancy45daddy
Author

@zpcore The following code

import torch_xla, diffusers, builtins, imageio, os, PIL.Image, controlnet_aux, sys, torch
os.environ.pop('TPU_PROCESS_ADDRESSES')

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
poses = [openpose(PIL.Image.fromarray(reader.get_data(_)).resize((512, 768))) for _ in builtins.range(16)] #reader.count_frames()
length = builtins.len(poses) // 8
fps = reader.get_meta_data().get('fps')

def process(index):
    controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
    pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16)
    pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.enable_attention_slicing()
    pipeline.unet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pipeline.controlnet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pipeline.to(torch_xla.core.xla_model.xla_device())
    pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
    imageio.mimsave(f'{index}.mp4', pipeline(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)

torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

produces the message "Your notebook tried to allocate more memory than is available. It has restarted." and this output:

  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 45%|████▌     | 9/20 [00:09<00:12,  1.13s/it]
 45%|████▌     | 9/20 [00:09<00:12,  1.13s/it]
  5%|▌         | 1/20 [00:00<00:11,  1.71it/s]
  5%|▌         | 1/20 [00:00<00:11,  1.72it/s]
 10%|█         | 2/20 [00:00<00:08,  2.18it/s]
 10%|█         | 2/20 [00:00<00:08,  2.11it/s]
 50%|█████     | 10/20 [00:10<00:11,  1.11s/it]
 50%|█████     | 10/20 [00:10<00:11,  1.11s/it]
 15%|█▌        | 3/20 [00:01<00:08,  2.11it/s]
 20%|██        | 4/20 [00:01<00:07,  2.00it/s]
 20%|██        | 4/20 [00:01<00:07,  2.00it/s]
 55%|█████▌    | 11/20 [00:12<00:09,  1.09s/it]
  5%|▌         | 1/20 [00:02<00:53,  2.82s/it]
 25%|██▌       | 5/20 [00:02<00:08,  1.85it/s]
 10%|█         | 2/20 [00:03<00:24,  1.38s/it]
 10%|█         | 2/20 [00:03<00:24,  1.36s/it]
 60%|██████    | 12/20 [00:13<00:08,  1.08s/it]
 30%|███       | 6/20 [00:03<00:08,  1.67it/s]
 15%|█▌        | 3/20 [00:03<00:16,  1.04it/s]
 15%|█▌        | 3/20 [00:03<00:16,  1.05it/s]
 20%|██        | 4/20 [00:04<00:12,  1.25it/s]
 20%|██        | 4/20 [00:04<00:12,  1.26it/s]
 65%|██████▌   | 13/20 [00:14<00:07,  1.08s/it]
 25%|██▌       | 5/20 [00:04<00:11,  1.35it/s]
 40%|████      | 8/20 [00:04<00:08,  1.40it/s]
 40%|████      | 8/20 [00:04<00:08,  1.42it/s]
 30%|███       | 6/20 [00:05<00:10,  1.36it/s]
 70%|███████   | 14/20 [00:15<00:06,  1.07s/it]
 45%|████▌     | 9/20 [00:05<00:08,  1.30it/s]
 45%|████▌     | 9/20 [00:05<00:08,  1.28it/s]
 35%|███▌      | 7/20 [00:06<00:09,  1.30it/s]
 35%|███▌      | 7/20 [00:06<00:09,  1.31it/s]
 75%|███████▌  | 15/20 [00:16<00:05,  1.08s/it]
 75%|███████▌  | 15/20 [00:16<00:05,  1.08s/it]
 40%|████      | 8/20 [00:07<00:09,  1.22it/s]
 40%|████      | 8/20 [00:07<00:09,  1.22it/s]
 80%|████████  | 16/20 [00:17<00:04,  1.07s/it]
 80%|████████  | 16/20 [00:17<00:04,  1.07s/it]
 45%|████▌     | 9/20 [00:08<00:09,  1.16it/s]
 45%|████▌     | 9/20 [00:08<00:09,  1.16it/s]
 60%|██████    | 12/20 [00:08<00:07,  1.10it/s]
 85%|████████▌ | 17/20 [00:18<00:03,  1.09s/it]
 60%|██████    | 12/20 [00:08<00:07,  1.09it/s]
 65%|██████▌   | 13/20 [00:09<00:06,  1.07it/s]
 65%|██████▌   | 13/20 [00:09<00:06,  1.07it/s]
 55%|█████▌    | 11/20 [00:10<00:08,  1.07it/s]
 55%|█████▌    | 11/20 [00:10<00:08,  1.07it/s]
 70%|███████   | 14/20 [00:10<00:05,  1.06it/s]
 70%|███████   | 14/20 [00:10<00:05,  1.05it/s]
 60%|██████    | 12/20 [00:11<00:07,  1.04it/s]
 75%|███████▌  | 15/20 [00:11<00:04,  1.02it/s]
 75%|███████▌  | 15/20 [00:11<00:04,  1.02it/s]
 65%|██████▌   | 13/20 [00:12<00:06,  1.01it/s]
 80%|████████  | 16/20 [00:12<00:03,  1.02it/s]
 80%|████████  | 16/20 [00:12<00:03,  1.02it/s]
 70%|███████   | 14/20 [00:13<00:06,  1.00s/it]
 85%|████████▌ | 17/20 [00:13<00:03,  1.03s/it]
 85%|████████▌ | 17/20 [00:13<00:03,  1.01s/it]
 80%|████████  | 16/20 [00:15<00:04,  1.03s/it]
 80%|████████  | 16/20 [00:15<00:04,  1.03s/it]
 90%|█████████ | 18/20 [15:25<09:06, 273.34s/it]
 90%|█████████ | 18/20 [15:21<09:07, 273.61s/it]
 90%|█████████ | 18/20 [15:35<09:15, 277.54s/it]
 90%|█████████ | 18/20 [17:01<10:04, 302.07s/it]
 95%|█████████▌| 19/20 [30:43<07:50, 470.81s/it]
 95%|█████████▌| 19/20 [33:35<08:38, 518.51s/it]
 95%|█████████▌| 19/20 [35:24<09:00, 540.89s/it]
 95%|█████████▌| 19/20 [35:54<09:11, 551.63s/it]
100%|██████████| 20/20 [48:28<00:00, 145.43s/it]

100%|██████████| 20/20 [51:53<00:00, 155.65s/it]
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2ee56c804f&map=
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2ee56c804f&map=
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2ee56c804f&map=
*** SIGTERM received by PID 4811 (TID 4811) on cpu 5 from PID 13; stack trace: ***
*** SIGTERM received by PID 4814 (TID 4814) on cpu 46 from PID 13; stack trace: ***
*** SIGTERM received by PID 4813 (TID 4813) on cpu 94 from PID 13; stack trace: ***
PC: @     0x7b2ee5711e96  (unknown)  (unknown)
    @     0x7b2a1ee64521        944  (unknown)
    @     0x7b2ee56c8050  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2a1ee64520,7b2ee56c804f&map= 
E1010 08:01:14.683147    4811 coredump_hook.cc:262] RAW: Remote crash gathering disabled for SIGTERM.
PC: @     0x7b2ee5711e96  (unknown)  (unknown)
PC: @     0x7b2ee5711e96  (unknown)  (unknown)
    @     0x7b2a1ee64521        944  (unknown)
    @     0x7b2a1ee64521        944  (unknown)
    @     0x7b2ee56c8050  (unknown)  (unknown)
    @     0x7b2ee56c8050  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2a1ee64520,7b2ee56c804f&map=
https://symbolize.stripped_domain/r/?trace=7b2ee5711e96,7b2a1ee64520,7b2ee56c804f&map=
E1010 08:01:14.683773    4813 coredump_hook.cc:262] RAW: Remote crash gathering disabled for SIGTERM.
E1010 08:01:14.683775    4814 coredump_hook.cc:262] RAW: Remote crash gathering disabled for SIGTERM.
E1010 08:01:16.891155    4813 process_state.cc:805] RAW: Raising signal 15 with default behavior
E1010 08:01:17.356729    4811 process_state.cc:805] RAW: Raising signal 15 with default behavior
E1010 08:01:17.359651    4814 process_state.cc:805] RAW: Raising signal 15 with default behavior

---------------------------------------------------------------------------
BrokenProcessPool                         Traceback (most recent call last)
Cell In[2], line 21
     18     pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
     19     imageio.mimsave(f'{index}.mp4', pipeline(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
---> 21 torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 @xr.requires_pjrt
      7 def spawn(fn,
      8           args=(),
   (...)
     11           daemon=False,
     12           start_method='spawn'):
     13   """Enables multi processing based replication.
     14 
     15   Args:
   (...)
     36     return None.
     37   """
---> 38   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
    211 elif nprocs is not None:
    212   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
--> 174   replica_results = list(
    175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
    174   replica_results = list(
--> 175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
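
For context on the exception itself: `BrokenProcessPool` comes from Python's `concurrent.futures`, and it is raised in the parent when a worker process disappears without returning a result. The traceback therefore shows where the parent noticed the loss, not why the worker died. The `SIGTERM received ... from PID 13` lines are likely a consequence rather than the cause, since a broken pool terminates its surviving workers. The sketch below reproduces the same exception with the standard library only, assuming a Unix system; `worker` and `run_pool` are hypothetical names and nothing here touches torch_xla.

```python
import concurrent.futures
import multiprocessing
import os
import signal
from concurrent.futures.process import BrokenProcessPool


def worker(index):
    # Hypothetical worker: index 1 dies the way an externally killed process would.
    if index == 1:
        os.kill(os.getpid(), signal.SIGKILL)
    return index


def run_pool():
    # "fork" matches the start_method used in this thread (Unix only).
    context = multiprocessing.get_context("fork")
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=2, mp_context=context) as executor:
        try:
            return list(executor.map(worker, range(2)))
        except BrokenProcessPool as error:
            # The parent only learns that a worker vanished, not why.
            return error
```

Calling `run_pool()` returns the `BrokenProcessPool` instance instead of a result list. The worker's own reason for dying never reaches the parent, which is why the message above carries so little detail.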

@zpcore
Collaborator

zpcore commented Oct 10, 2024

I didn't find any complaint about memory in the log. It looks like the TPU doesn't have the data. Can you try

pipe = pipe.to(torch_xla.core.xla_model.xla_device())

instead of pipe.to(torch_xla.core.xla_model.xla_device())?

@fancy45daddy
Author

@zpcore

import torch_xla, diffusers, builtins, imageio, os, PIL.Image, controlnet_aux, sys, torch
os.environ.pop('TPU_PROCESS_ADDRESSES')

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
poses = [openpose(PIL.Image.fromarray(reader.get_data(_)).resize((512, 768))) for _ in builtins.range(16)] #reader.count_frames()
length = builtins.len(poses) // 8
fps = reader.get_meta_data().get('fps')
controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16)
pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_attention_slicing()
pipeline.unet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
pipeline.controlnet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())

def process(index):
    pipe = diffusers.StableDiffusionControlNetPipeline(**pipeline.components).to(torch_xla.core.xla_model.xla_device())
    pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
    imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)

torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

This code does not work. It produces:

 30%|███       | 6/20 [00:03<00:10,  1.39it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 30%|███       | 6/20 [00:03<00:10,  1.37it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 35%|███▌      | 7/20 [00:04<00:11,  1.14it/s]
  5%|▌         | 1/20 [00:01<00:22,  1.17s/it]
 35%|███▌      | 7/20 [00:05<00:11,  1.11it/s]
  5%|▌         | 1/20 [00:01<00:21,  1.14s/it]
 40%|████      | 8/20 [00:06<00:12,  1.02s/it]
 10%|█         | 2/20 [00:02<00:23,  1.30s/it]
 40%|████      | 8/20 [00:06<00:12,  1.05s/it]
 10%|█         | 2/20 [00:02<00:23,  1.33s/it]
 45%|████▌     | 9/20 [00:07<00:12,  1.15s/it]
 15%|█▌        | 3/20 [00:03<00:22,  1.35s/it]
 45%|████▌     | 9/20 [00:07<00:12,  1.18s/it]
 15%|█▌        | 3/20 [00:04<00:23,  1.38s/it]
 50%|█████     | 10/20 [00:09<00:12,  1.24s/it]
 50%|█████     | 10/20 [00:09<00:12,  1.27s/it]
 20%|██        | 4/20 [00:05<00:22,  1.42s/it]
 20%|██        | 4/20 [00:05<00:22,  1.43s/it]
 55%|█████▌    | 11/20 [00:10<00:11,  1.32s/it]
 55%|█████▌    | 11/20 [00:10<00:12,  1.37s/it]
 25%|██▌       | 5/20 [00:07<00:22,  1.48s/it]
 25%|██▌       | 5/20 [00:07<00:22,  1.47s/it]
 55%|█████▌    | 11/20 [12:34<10:17, 68.58s/it]
 55%|█████▌    | 11/20 [12:59<10:38, 70.89s/it]
 55%|█████▌    | 11/20 [13:00<10:38, 70.94s/it]
 25%|██▌       | 5/20 [13:31<40:35, 162.37s/it]
 25%|██▌       | 5/20 [14:26<43:18, 173.21s/it]
 25%|██▌       | 5/20 [14:42<44:07, 176.52s/it]
 25%|██▌       | 5/20 [14:45<44:17, 177.20s/it]
 60%|██████    | 12/20 [17:56<11:57, 89.69s/it]

---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
    replica_results = list(
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
    return fn()
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 190, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/tmp/ipykernel_13/524703449.py", line 19, in process
    imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py", line 1282, in __call__
    noise_pred = self.unet(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward
    sample = upsample_block(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
    hidden_states = attn(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 442, in forward
    hidden_states = block(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py", line 504, in forward
    attn_output = self.attn2(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
  File "/usr/local/lib/python3.10/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py", line 100, in __call__
    hidden_states = attn.to_out[0](hidden_states)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device() 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
	
	torch::lazy::MultiWait::Complete(std::function<void ()> const&)
	Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
	void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
	
	
	__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))
"""

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 21
     18     pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
     19     imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
---> 21 torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 @xr.requires_pjrt
      7 def spawn(fn,
      8           args=(),
   (...)
     11           daemon=False,
     12           start_method='spawn'):
     13   """Enables multi processing based replication.
     14 
     15   Args:
   (...)
     36     return None.
     37   """
---> 38   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
    211 elif nprocs is not None:
    212   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
--> 174   replica_results = list(
    175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
    174   replica_results = list(
--> 175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device() 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
	
	torch::lazy::MultiWait::Complete(std::function<void ()> const&)
	Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
	void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
	
	
	__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))

Do you know the correct way to run StableDiffusionControlNetPipeline on a Kaggle TPU v3-8?
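
One possible reading of the `TPU_1(process=0,...) vs TPU_0(process=0,...)` check failure, offered as an assumption rather than a confirmed diagnosis: the remote traceback passes through `_run_thread_per_device`, so two threads in the same process each drive one TPU core, and `pipeline.components` hands both threads the same module objects. If `.to(device)` moves a module in place, the second thread's move would override the first, leaving one thread computing on its own device with weights that live on the other. The sketch below only models that idea in plain Python; `FakeModule` and `place_on_devices` are hypothetical stand-ins, not torch, diffusers, or torch_xla types.

```python
import copy
import threading


class FakeModule:
    """Hypothetical stand-in for a module whose .to() moves it in place."""

    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self


def place_on_devices(shared, devices, isolate):
    """Run one thread per device and report where each thread's module ended up."""
    barrier = threading.Barrier(len(devices))
    seen = {}

    def run(device):
        # With isolate=True every thread works on its own copy of the module.
        module = copy.deepcopy(shared) if isolate else shared
        module.to(device)
        # Wait until every thread has finished moving before anyone reads.
        barrier.wait()
        seen[device] = module.device

    threads = [threading.Thread(target=run, args=(d,)) for d in devices]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return seen
```

With `isolate=False`, both threads report the same device, so exactly one of them sees a device other than the one it asked for. With `isolate=True`, each thread sees its own device. Whether building the pipeline inside `process` (as in the first attempt) or copying the components per thread is the right fix for the real pipeline is untested here.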

@fancy45daddy
Author

fancy45daddy commented Oct 18, 2024

@zpcore @JackCaoG
I tried the code in https://www.kaggle.com/code/chaowenguoback/stablediffusion/notebook?scriptVersionId=201780362

%%bash
python3 -m pip install -U imageio[ffmpeg] controlnet-aux openmim
mim install mmengine mmcv mmdet mmpose
curl -O https://raw.githubusercontent.com/huggingface/controlnet_aux/master/src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
curl -O https://raw.githubusercontent.com/huggingface/controlnet_aux/master/src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py

import torch_xla, diffusers, builtins, imageio, os, PIL.Image, controlnet_aux, sys, torch, multiprocessing
lock = multiprocessing.Manager().Lock()
os.environ.pop('TPU_PROCESS_ADDRESSES')

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
poses = [openpose(PIL.Image.fromarray(reader.get_data(_)).resize((512, 768))) for _ in builtins.range(16)] #reader.count_frames()
length = builtins.len(poses) // 8

def process(index, lock):
    controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
    pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16).to(torch_xla.core.xla_model.xla_device())
    pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.enable_attention_slicing()
    #pipeline.enable_vae_slicing()
    #pipeline.enable_vae_tiling()
    #pipeline.enable_sequential_cpu_offload()
    #pipeline.enable_model_cpu_offload()
    pipeline.unet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pipeline.controlnet.set_attn_processor(diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor())
    pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
    images = pipeline(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1), return_dict=False, output_type='latent')[0]
    with lock:
        print(images)
        #images = pipeline.image_processor.pt_to_numpy(images)
        #images = pipeline.image_processor.numpy_to_pil(images)
        #imageio.mimsave(f'{index}.mp4', images, fps=sys.modules['__main__'].reader.get_meta_data().get('fps'))

torch_xla.distributed.xla_multiprocessing.spawn(process, args=(lock,), start_method='fork')

But the output is:

---------------------------------------------------------------------------
BrokenProcessPool                         Traceback (most recent call last)
Cell In[2], line 29
     24         print(images)
     25         #images = pipeline.image_processor.pt_to_numpy(images)
     26         #images = pipeline.image_processor.numpy_to_pil(images)
     27         #imageio.mimsave(f'{index}.mp4', images, fps=sys.modules['__main__'].reader.get_meta_data().get('fps'))
---> 29 torch_xla.distributed.xla_multiprocessing.spawn(process, args=(lock,), start_method='fork')

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
      6 @xr.requires_pjrt
      7 def spawn(fn,
      8           args=(),
   (...)
     11           daemon=False,
     12           start_method='spawn'):
     13   """Enables multi processing based replication.
     14 
     15   Args:
   (...)
     36     return None.
     37   """
---> 38   return pjrt.spawn(fn, nprocs, start_method, args)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
    211 elif nprocs is not None:
    212   logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)

File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
     91 if not using_pjrt():
     92   raise NotImplementedError('`{}` not implemented for XRT'.format(
     93       fn.__name__))
---> 95 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
--> 174   replica_results = list(
    175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
    168   mp_fn = functools.partial(
    169       _run_thread_per_device,
    170       local_world_size=num_processes,
    171       fn=functools.partial(fn, *args, **kwargs),
    172       initializer_fn=initialize_multiprocess)
    173   process_results = executor.map(mp_fn, range(num_processes))
    174   replica_results = list(
--> 175       itertools.chain.from_iterable(
    176           result.items() for result in process_results))
    178 return _merge_replica_results(replica_results)

File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
    569 def _chain_from_iterable_of_lists(iterable):
    570     """
    571     Specialized implementation of itertools.chain.from_iterable.
    572     Each item in *iterable* should be a list.  This function is
    573     careful not to keep references to yielded objects.
    574     """
--> 575     for element in iterable:
    576         element.reverse()
    577         while element:

File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
    618 while fs:
    619     # Careful not to keep a reference to the popped future
    620     if timeout is None:
--> 621         yield _result_or_cancel(fs.pop())
    622     else:
    623         yield _result_or_cancel(fs.pop(), end_time - time.monotonic())

File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
    317 try:
    318     try:
--> 319         return fut.result(timeout)
    320     finally:
    321         fut.cancel()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
    456     raise CancelledError()
    457 elif self._state == FINISHED:
--> 458     return self.__get_result()
    459 else:
    460     raise TimeoutError()

File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
    401 if self._exception:
    402     try:
--> 403         raise self._exception
    404     finally:
    405         # Break a reference cycle with the exception in self._exception
    406         self = None

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

The output is not helpful. Do you have any idea how to make the code work?
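
Since `BrokenProcessPool` carries no cause, one generic thing to check is how a worker died. The sketch below is plain `multiprocessing`, not torch_xla, and it does not hook into the pool that `spawn` creates; `describe_exit`, `crash`, and `run_and_describe` are hypothetical names. It relies on the documented behaviour that `Process.exitcode` is `-N` when the child was terminated by signal `N`, and it assumes a Unix system.

```python
import multiprocessing
import os
import signal


def describe_exit(exitcode):
    """Turn a multiprocessing exit code into a readable description."""
    if exitcode is None:
        return "still running"
    if exitcode < 0:
        # A negative exit code means the child was terminated by that signal.
        return f"killed by {signal.Signals(-exitcode).name}"
    return f"exited with status {exitcode}"


def crash():
    # Hypothetical child that is killed without any Python traceback.
    os.kill(os.getpid(), signal.SIGKILL)


def run_and_describe(target):
    # "fork" matches the start_method used in this thread.
    context = multiprocessing.get_context("fork")
    child = context.Process(target=target)
    child.start()
    child.join()
    return describe_exit(child.exitcode)
```

A worker that reports `SIGKILL` with no Python traceback was killed from outside the interpreter. On Linux the kernel's out-of-memory killer is one common source, though nothing in the logs above confirms that this is what happened here.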
