Skip to content

Commit

Permalink
introduce torch_xla2.compile API, make sdxl to use it (#8269)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Oct 17, 2024
1 parent fa311ec commit ddd4db7
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 3 deletions.
5 changes: 5 additions & 0 deletions experimental/reference_models/sdxl_inference/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# How to run:

```
python sdxl.py
```
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 73 additions & 0 deletions experimental/reference_models/sdxl_inference/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import time
import functools
import jax
import torch
import torch_xla2
from torch_xla2 import interop
from torch_xla2.interop import JittableModule

from transformers.modeling_outputs import BaseModelOutputWithPooling

from jax.tree_util import register_pytree_node
import jax

def base_model_output_with_pooling_flatten(v):
return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None

def base_model_output_with_pooling_unflatten(aux_data, children):
return BaseModelOutputWithPooling(*children)

register_pytree_node(
BaseModelOutputWithPooling,
base_model_output_with_pooling_flatten,
base_model_output_with_pooling_unflatten
)


from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")

prompt = "a photograph of an astronaut riding a horse"
# image = pipe(prompt).images[0]


env = torch_xla2.default_env()
jax.config.update('jax_enable_x64', False)

def move_scheduler(scheduler):
for k, v in scheduler.__dict__.items():
if isinstance(v, torch.Tensor):
setattr(scheduler, k, v.to('jax'))


with env:
pipe.to('jax:1')
move_scheduler(pipe.scheduler)
pipe.unet = torch_xla2.compile(
pipe.unet, torch_xla2.CompileOptions(
jax_jit_kwargs={'static_argnames': ('return_dict',)}
)
)
import pdb; pdb.set_trace()
pipe.text_encoder = torch_xla2.compile(pipe.text_encoder)

BS = 4
prompt = [prompt] * BS
pipe.vae = torch_xla2.compile(
pipe.vae, torch_xla2.CompileOptions(
jax_jit_kwargs={'static_argnames': ('return_dict',)},
methods_to_compile=['decode'],
)
)
image = pipe(prompt).images[0]

jax.profiler.start_trace('/tmp/sdxl')
start = time.perf_counter()
image = pipe(prompt, num_inference_steps=20).images[0]
end = time.perf_counter()
jax.profiler.stop_trace()
print('Total time is ', end - start, 'bs = ', BS)
image.save(f"astronaut_rides_horse.png")



14 changes: 14 additions & 0 deletions experimental/reference_models/sdxl_inference/sdxl_beginning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from diffusers import StableDiffusionPipeline

import torch_xla2
env = torch_xla2.default_env()

# this is now contains torhc.Tensor
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")

with env:
pipe.to('jax')
prompt = "a photograph of an astronaut riding a horse"
image = pipe(prompt, num_inference_steps=10).images[0]
image.save(f"astronaut_rides_horse_orig.png")
30 changes: 29 additions & 1 deletion experimental/torch_xla2/torch_xla2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Dict, Any, Optional
import dataclasses
import jax
import os
import torch
Expand Down Expand Up @@ -91,4 +93,30 @@ def enable_accuracy_mode():
def enable_performance_mode():
jax.config.update('jax_enable_x64', False)
jax.config.update('jax_default_matmul_precision', 'default')
default_env().config.internal_respect_torch_return_dtypes = False
default_env().config.internal_respect_torch_return_dtypes = False



@dataclasses.dataclass
class CompileOptions:
# only valid if compiling nn.Module
methods_to_compile: List[str] = dataclasses.field(default_factory=lambda: ['forward'])
jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
mode: str = 'jax' # or dynamo or export


def compile(fn, options: Optional[CompileOptions] = None):
options = options or CompileOptions()
if options.mode == 'jax':
from torch_xla2 import interop
if isinstance(fn, torch.nn.Module):
module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
for n in options.methods_to_compile:
module.make_jitted(n)
return module
else:
return interop.jax_jit(fn)
elif options.mode == 'dynamo':
raise RuntimeError('dynamo mode is not supported yet')
elif options.mode == 'export':
raise RuntimeError('export mode is not supported yet')
2 changes: 0 additions & 2 deletions experimental/torch_xla2/torch_xla2/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def set_one(module, prefix):

class JittableModule(torch.nn.Module):

# TODO: add statedict loading hook

def __init__(self, m: torch.nn.Module, extra_jit_args={}):
super().__init__()
self.params, self.buffers = extract_all_buffers(m)
Expand Down

0 comments on commit ddd4db7

Please sign in to comment.