diff --git a/torchdrive/models/vista.py b/torchdrive/models/vista.py index 716d76a..cd99c04 100644 --- a/torchdrive/models/vista.py +++ b/torchdrive/models/vista.py @@ -55,9 +55,20 @@ def __init__( start = time.perf_counter() config = OmegaConf.load(config_path) + config["model"]["params"]["num_frames"] = num_frames + config["model"]["params"]["denoiser_config"]["params"][ + "num_frames" + ] = num_frames model = load_model_from_config(config, ckpt_path) self.model = model.bfloat16().to(device).eval() + def compile(f): + return torch.compile(f, fullgraph=True, mode="reduce-overhead") + + self.model.encode_first_stage = compile(self.model.encode_first_stage) + self.model.denoiser = compile(self.model.denoiser) + self.model.decode_first_stage = compile(self.model.decode_first_stage) + print(f"loaded vista in {time.perf_counter() - start:.2f}s") guider = "VanillaCFG" @@ -153,6 +164,8 @@ def generate( if __name__ == "__main__": + torch.manual_seed(10) + dataset = NuscenesDataset( data_dir="~/nuscenes", version="v1.0-mini", @@ -174,11 +187,38 @@ def generate( print(trajectory) - sampler = VistaSampler(device=device) - out = sampler.generate(cond_img, trajectory) - print(out.shape) - assert out.shape == (10, 3, 480, 640) + steps = 25 + num_frames = 6 + + sampler = VistaSampler(device=device, num_frames=num_frames, steps=steps) + """ + Encoding 10 frames: 0.053 + Decoding 10 frames: 0.242 + + + * 50 steps, 10 frames = 8.33s + * 50 steps, 5 frames = 5.24s + * 25 steps, 10 frames = 4.53s + * 30 steps, 10 frames = 5.31s + + * 50 steps, 10 frames, compile = 5.22s + * 50 steps, 10 frames, compile reduce-overhead = 4.51s + * 25 steps, 10 frames, compile reduce-overhead = 2.43s + * 25 steps, 6 frames, compile reduce-overhead = 1.72s + """ + # prewarm + for i in range(5): + torch.cuda.synchronize(device) + start = time.perf_counter() + + out = sampler.generate(cond_img, trajectory) + + torch.cuda.synchronize(device) + print(f"generate {time.perf_counter() - start:.2f}s") + + print(out.shape) + assert out.shape == (num_frames, 3, 480, 640) - for i, img in enumerate(out): - img = to_pil_image(normalize_img(img)) - img.save(f"vista_{i}.png") + for i, img in enumerate(out): + img = to_pil_image(normalize_img(img)) + img.save(f"vista_{steps}_{i}.png") diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py index f8c2795..b722d7c 100644 --- a/torchdrive/tasks/diff_traj.py +++ b/torchdrive/tasks/diff_traj.py @@ -751,8 +751,15 @@ def __init__( self.model = ConvNextPathPred() + # dream parameters + self.dream_steps = 1 + self.vista_fps = 10 + self.steps_per_second = 2 if not test: - self.vista = VistaSampler() + vista_frames = ( + 1 + self.vista_fps * self.dream_steps // self.steps_per_second + ) + self.vista = VistaSampler(steps=25, num_frames=vista_frames) self.batch_transform = Compose( NormalizeCarPosition(start_frame=0), @@ -904,7 +911,7 @@ def forward( positions[:, :pred_traj_len], mask[:, :pred_traj_len], pred_traj[:, :pred_traj_len], - step=2, + step=self.dream_steps, ) dream_losses, dream_traj, all_dream_traj = self.model(