Skip to content

Commit

Permalink
vista: perf tuning, 4.84x faster dream
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 24, 2024
1 parent 033635c commit cd11be9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
54 changes: 47 additions & 7 deletions torchdrive/models/vista.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,20 @@ def __init__(
start = time.perf_counter()

config = OmegaConf.load(config_path)
config["model"]["params"]["num_frames"] = num_frames
config["model"]["params"]["denoiser_config"]["params"][
"num_frames"
] = num_frames
model = load_model_from_config(config, ckpt_path)
self.model = model.bfloat16().to(device).eval()

def compile(f):
    """Return ``f`` wrapped by ``torch.compile`` with this module's settings.

    The wrapper is built with ``fullgraph=True`` and
    ``mode="reduce-overhead"``.

    NOTE: within the enclosing scope this name shadows the builtin
    ``compile``.
    """
    compile_options = {"fullgraph": True, "mode": "reduce-overhead"}
    return torch.compile(f, **compile_options)

self.model.encode_first_stage = compile(self.model.encode_first_stage)
self.model.denoiser = compile(self.model.denoiser)
self.model.decode_first_stage = compile(self.model.decode_first_stage)

print(f"loaded vista in {time.perf_counter() - start:.2f}s")

guider = "VanillaCFG"
Expand Down Expand Up @@ -153,6 +164,8 @@ def generate(


if __name__ == "__main__":
torch.manual_seed(10)

dataset = NuscenesDataset(
data_dir="~/nuscenes",
version="v1.0-mini",
Expand All @@ -174,11 +187,38 @@ def generate(

print(trajectory)

sampler = VistaSampler(device=device)
out = sampler.generate(cond_img, trajectory)
print(out.shape)
assert out.shape == (10, 3, 480, 640)
steps = 25
num_frames = 6

sampler = VistaSampler(device=device, num_frames=num_frames, steps=steps)
"""
Encoding 10 frames: 0.053
Decoding 10 frames: 0.242
* 50 steps, 10 frames = 8.33s
* 50 steps, 5 frames = 5.24s
* 25 steps, 10 frames = 4.53s
* 30 steps, 10 frames = 5.31s
* 50 steps, 10 frames, compile = 5.22s
* 50 steps, 10 frames, compile reduce-overhead = 4.51s
* 25 steps, 10 frames, compile reduce-overhead = 2.43s
* 25 steps, 6 frames, compile reduce-overhead = 1.72s
"""
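# prewarm: the first generate() calls pay the torch.compile warm-up cost, so
# run several timed iterations and read the steady-state time from the later ones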
for i in range(5):
torch.cuda.synchronize(device)
start = time.perf_counter()

out = sampler.generate(cond_img, trajectory)

torch.cuda.synchronize(device)
print(f"generate {time.perf_counter() - start:.2f}s")

print(out.shape)
assert out.shape == (num_frames, 3, 480, 640)

for i, img in enumerate(out):
img = to_pil_image(normalize_img(img))
img.save(f"vista_{i}.png")
for i, img in enumerate(out):
img = to_pil_image(normalize_img(img))
img.save(f"vista_{steps}_{i}.png")
11 changes: 9 additions & 2 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,15 @@ def __init__(

self.model = ConvNextPathPred()

# dream parameters
self.dream_steps = 1
self.vista_fps = 10
self.steps_per_second = 2
if not test:
self.vista = VistaSampler()
vista_frames = (
1 + self.vista_fps * self.dream_steps // self.steps_per_second
)
self.vista = VistaSampler(steps=25, num_frames=vista_frames)

self.batch_transform = Compose(
NormalizeCarPosition(start_frame=0),
Expand Down Expand Up @@ -904,7 +911,7 @@ def forward(
positions[:, :pred_traj_len],
mask[:, :pred_traj_len],
pred_traj[:, :pred_traj_len],
step=2,
step=self.dream_steps,
)

dream_losses, dream_traj, all_dream_traj = self.model(
Expand Down

0 comments on commit cd11be9

Please sign in to comment.