diff --git a/experimental/reference_models/sdxl_inference/README.md b/experimental/reference_models/sdxl_inference/README.md
new file mode 100644
index 00000000000..bfd0db70a41
--- /dev/null
+++ b/experimental/reference_models/sdxl_inference/README.md
@@ -0,0 +1,5 @@
+# How to run:
+
+```
+python sdxl.py
+```
\ No newline at end of file
diff --git a/experimental/reference_models/sdxl_inference/astronaut_rides_horse.png b/experimental/reference_models/sdxl_inference/astronaut_rides_horse.png
new file mode 100644
index 00000000000..9a3927eaa0a
Binary files /dev/null and b/experimental/reference_models/sdxl_inference/astronaut_rides_horse.png differ
diff --git a/experimental/reference_models/sdxl_inference/sdxl.py b/experimental/reference_models/sdxl_inference/sdxl.py
new file mode 100644
index 00000000000..c3af8c906ff
--- /dev/null
+++ b/experimental/reference_models/sdxl_inference/sdxl.py
@@ -0,0 +1,73 @@
+import time
+import functools
+import jax
+import torch
+import torch_xla2
+from torch_xla2 import interop
+from torch_xla2.interop import JittableModule
+
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+from jax.tree_util import register_pytree_node
+import jax
+
+def base_model_output_with_pooling_flatten(v):
+  return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None
+
+def base_model_output_with_pooling_unflatten(aux_data, children):
+  return BaseModelOutputWithPooling(*children)
+
+register_pytree_node(
+  BaseModelOutputWithPooling,
+  base_model_output_with_pooling_flatten,
+  base_model_output_with_pooling_unflatten
+)
+
+
+from diffusers import StableDiffusionPipeline
+pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
+
+prompt = "a photograph of an astronaut riding a horse"
+# image = pipe(prompt).images[0]
+
+
+env = torch_xla2.default_env()
+jax.config.update('jax_enable_x64', False)
+
+def move_scheduler(scheduler):
+  for k, v in scheduler.__dict__.items():
+    if isinstance(v, torch.Tensor):
+      setattr(scheduler, k, v.to('jax'))
+
+
+with env:
+  pipe.to('jax:1')
+  move_scheduler(pipe.scheduler)
+  pipe.unet = torch_xla2.compile(
+    pipe.unet, torch_xla2.CompileOptions(
+      jax_jit_kwargs={'static_argnames': ('return_dict',)}
+    )
+  )
+  pipe.text_encoder = torch_xla2.compile(pipe.text_encoder)
+
+  BS = 4
+  prompt = [prompt] * BS
+  pipe.vae = torch_xla2.compile(
+    pipe.vae, torch_xla2.CompileOptions(
+      jax_jit_kwargs={'static_argnames': ('return_dict',)},
+      methods_to_compile=['decode'],
+    )
+  )
+  # This first call runs outside the timed region (presumably a warm-up).
+  image = pipe(prompt).images[0]
+
+  jax.profiler.start_trace('/tmp/sdxl')
+  start = time.perf_counter()
+  image = pipe(prompt, num_inference_steps=20).images[0]
+  end = time.perf_counter()
+  jax.profiler.stop_trace()
+  print('Total time is ', end - start, 'bs = ', BS)
+  image.save("astronaut_rides_horse.png")
+
+
+
diff --git a/experimental/reference_models/sdxl_inference/sdxl_beginning.py b/experimental/reference_models/sdxl_inference/sdxl_beginning.py
new file mode 100644
index 00000000000..b182c61c075
--- /dev/null
+++ b/experimental/reference_models/sdxl_inference/sdxl_beginning.py
@@ -0,0 +1,14 @@
+import torch
+from diffusers import StableDiffusionPipeline
+
+import torch_xla2
+env = torch_xla2.default_env()
+
+# At this point the pipeline contains regular torch.Tensor weights.
+pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
+
+with env:
+  pipe.to('jax')
+  prompt = "a photograph of an astronaut riding a horse"
+  image = pipe(prompt, num_inference_steps=10).images[0]
+  image.save("astronaut_rides_horse_orig.png")
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index f7dbde71263..28efa1992c7 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,3 +1,5 @@
+from typing import List, Dict, Any, Optional
+import dataclasses
 import jax
 import os
 import torch
@@ -91,4 +93,50 @@ def enable_accuracy_mode():
 def enable_performance_mode():
   jax.config.update('jax_enable_x64', False)
   jax.config.update('jax_default_matmul_precision', 'default')
-  default_env().config.internal_respect_torch_return_dtypes = False
\ No newline at end of file
+  default_env().config.internal_respect_torch_return_dtypes = False
+
+
+@dataclasses.dataclass
+class CompileOptions:
+  """Options controlling how `compile` processes its input."""
+
+  # Methods to jit; only valid if compiling an nn.Module.
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ['forward'])
+  # Passed to interop.JittableModule as `extra_jit_args`.
+  jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+  # One of 'jax', 'dynamo' or 'export'; only 'jax' is implemented.
+  mode: str = 'jax'
+
+
+def compile(fn, options: Optional[CompileOptions] = None):
+  """Compile a torch.nn.Module or a plain callable.
+
+  Args:
+    fn: The torch.nn.Module or callable to compile.
+    options: Compile settings; `CompileOptions()` is used when omitted.
+
+  Returns:
+    An interop.JittableModule wrapping `fn` if it is an nn.Module,
+    otherwise the result of interop.jax_jit(fn).
+
+  Raises:
+    RuntimeError: If the mode is 'dynamo' or 'export' (not supported yet).
+    ValueError: If the mode is not one of the known modes.
+  """
+  options = options or CompileOptions()
+  if options.mode in ('dynamo', 'export'):
+    raise RuntimeError(f'{options.mode} mode is not supported yet')
+  if options.mode != 'jax':
+    # Reject unknown modes explicitly instead of silently returning None.
+    raise ValueError(f'Unknown compile mode: {options.mode!r}')
+
+  from torch_xla2 import interop
+  if isinstance(fn, torch.nn.Module):
+    module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
+    for name in options.methods_to_compile:
+      module.make_jitted(name)
+    return module
+  # NOTE(review): options.jax_jit_kwargs is not forwarded on this path;
+  # confirm whether interop.jax_jit should receive it.
+  return interop.jax_jit(fn)
diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py
index 604ce8b7184..d75c450d0ed 100644
--- a/experimental/torch_xla2/torch_xla2/interop.py
+++ b/experimental/torch_xla2/torch_xla2/interop.py
@@ -49,8 +49,6 @@ def set_one(module, prefix):
 
 class JittableModule(torch.nn.Module):
 
-    # TODO: add statedict loading hook
-
     def __init__(self, m: torch.nn.Module, extra_jit_args={}):
         super().__init__()
         self.params, self.buffers = extract_all_buffers(m)