Skip to content

Commit

Permalink
diff_traj: wip integration with Vista
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 22, 2024
1 parent 7780391 commit 61c2f9a
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 17 deletions.
31 changes: 28 additions & 3 deletions torchdrive/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def save(self, path: str, compress_level: int = 3, threads: int = 1) -> None:
torch.save(data, buffer)
buffer.seek(0)
buf = buffer.read()
buf = zstd.compress(buf, compress_level, threads)

if path.endswith(".zst") or path.endswith(".zstd"):
buf = zstd.compress(buf, compress_level, threads)
with open(path, "wb") as f:
f.write(buf)

Expand All @@ -230,13 +230,38 @@ def load(cls, path: str) -> None:
with open(path, "rb") as f:
buf = f.read()

if path.endswith(".zst"):
if path.endswith(".zst") or path.endswith(".zstd"):
buf = zstd.uncompress(buf)
buffer = io.BytesIO(buf)
data = torch.load(buffer, weights_only=True)

return cls(**data)

def positions(self) -> torch.Tensor:
    """
    Returns the XYZ position of the car for every frame of the batch.
    You likely want to normalize the batch first.

    Frames that are masked out in long_cam_T come back as zeros.

    Returns:
        tensor [bs, long_cam_T, 3]
    """
    world_to_car, mask, lengths = self.long_cam_T

    # Only the entries selected by the mask are inverted; everything else
    # stays as an all-zero matrix (presumably padding that may be singular).
    car_to_world = torch.zeros_like(world_to_car)
    car_to_world[mask] = world_to_car[mask].inverse()

    assert mask.int().sum() == lengths.sum(), (mask, lengths)

    # Homogeneous origin (0, 0, 0, 1) of the car frame. Allocated from
    # world_to_car so dtype and device always match the transforms; a
    # hard-coded float32 breaks matmul for any other dtype.
    zero_coord = world_to_car.new_zeros(1, 4)
    zero_coord[:, -1] = 1

    positions = torch.matmul(car_to_world, zero_coord.T).squeeze(-1)
    positions /= positions[..., -1:] + 1e-8  # perspective warp

    return positions[..., :3]


def _rand_det_target() -> torch.Tensor:
t = torch.rand(2, 5)
Expand Down
2 changes: 2 additions & 0 deletions torchdrive/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,8 @@ def __init__(
version: str = "v1.0-trainval",
lidar: bool = False,
) -> None:
data_dir = os.path.expanduser(data_dir)

self.data_dir = data_dir
self.version = version
self.nusc = NuScenes(version=version, dataroot=data_dir, verbose=True)
Expand Down
103 changes: 103 additions & 0 deletions torchdrive/models/vista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

import torch

from omegaconf import ListConfig, OmegaConf

from torchdrive.data import collate
from torchdrive.datasets.nuscenes_dataset import NuscenesDataset
from torchdrive.transforms.batch import NormalizeCarPosition

from vwm.sample_utils import (
    do_sample,
    init_embedder_options,
    init_sampling,
    load_model_from_config,
)

class VistaSampler:
    """
    Wraps a Vista model and its sampler so future frames can be generated
    from a single conditioning image and a trajectory.
    """

    def __init__(
        self,
        config_path: str = "~/Vista/configs/inference/vista.yaml",
        ckpt_path: str = "~/Vista/ckpts/vista.safetensors",
        device: str = "cuda",
        steps: int = 50,
        cfg_scale: float = 2.5,
        num_frames: int = 10,
        cond_aug: float = 0.0,
    ) -> None:
        """
        Loads the model and creates the sampler.
        Args:
            config_path: path to the Vista inference config, ~ is expanded
            ckpt_path: path to the Vista checkpoint, ~ is expanded
            device: device the model is moved to
            steps: forwarded to init_sampling
            cfg_scale: forwarded to init_sampling
            num_frames: number of frames generated per call
            cond_aug: scale of the gaussian noise added to the conditioning
                frame
        """
        self.cond_aug = cond_aug
        self.num_frames = num_frames

        # The defaults live under the home directory and plain file opens do
        # not expand "~", so resolve it before handing the paths off.
        config = OmegaConf.load(os.path.expanduser(config_path))
        model = load_model_from_config(config, os.path.expanduser(ckpt_path))
        self.model = model.bfloat16().to(device)

        self.sampler = init_sampling(
            guider="VanillaCFG",
            steps=steps,
            cfg_scale=cfg_scale,
            num_frames=num_frames,
        )

    def generate(self, cond_img: torch.Tensor, trajectory: torch.Tensor) -> torch.Tensor:
        """
        Generates the next num_frames prediction.
        Args:
            cond_img: (1, 3, H, W)
                the conditioning frame
                Should be -1 to 1 value range
                320x576 or 576x1024
            trajectory: (1, 5, 2)
                trajectory including start position at (0, 0)
                (x, y) -- x+ is forward
                meters
                every 0.5s
        Returns:
            the samples returned by do_sample
            (presumably [num_frames, 3, H, W] -- TODO confirm)
        """
        assert cond_img.size(0) == 1

        unique_keys = {x.input_key for x in self.model.conditioner.embedders}

        value_dict = init_embedder_options(unique_keys)
        value_dict["cond_frames_without_noise"] = cond_img
        value_dict["cond_aug"] = self.cond_aug
        value_dict["cond_frames"] = cond_img + self.cond_aug * torch.randn_like(cond_img)
        # drop the (0, 0) start position and flatten the next 4 waypoints
        value_dict["trajectory"] = trajectory.squeeze(0)[1:5].flatten()

        uc_keys = ["cond_frames", "cond_frames_without_noise", "command", "trajectory", "speed", "angle", "goal"]

        # do_sample receives one image per frame, so the conditioning frame
        # is repeated num_frames times
        images = cond_img.expand(self.num_frames, -1, -1, -1)

        out = do_sample(
            images,
            self.model,
            self.sampler,
            value_dict,
            num_rounds=1,
            num_frames=self.num_frames,
            force_uc_zero_embeddings=uc_keys,
            initial_cond_indices=[0],  # only condition on first frame
        )
        samples, _samples_z, _inputs = out

        return samples


if __name__ == "__main__":
    # Smoke test: run the sampler on the first sample of the nuscenes mini
    # split.
    dataset = NuscenesDataset(
        data_dir="~/nuscenes",
        version="v1.0-mini",
        lidar=False,
        num_frames=1,
    )

    sample = dataset[0]
    batch = collate([sample])

    transform = NormalizeCarPosition(start_frame=0)
    batch = transform(batch)

    # positions() returns 3 coordinates per frame while generate() documents
    # an (x, y) trajectory, so keep only the first two.
    trajectory = batch.positions()[..., :2]
    # down sample to 0.5s resolution 12 hz
    trajectory = trajectory[:, ::6, :]

    # NOTE(review): cond_img is never assigned in this script, so the call
    # below raises NameError. It needs a (1, 3, H, W) frame in the -1 to 1
    # range taken from the batch -- TODO confirm which field holds the front
    # camera image and what resizing/normalization is required.
    sampler = VistaSampler()
    out = sampler.generate(cond_img, trajectory)
    print(out.shape)
    print(out)
17 changes: 4 additions & 13 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,16 +954,7 @@ def forward(
"""

world_to_car, mask, lengths = batch.long_cam_T
car_to_world = torch.zeros_like(world_to_car)
car_to_world[mask] = world_to_car[mask].inverse()

assert mask.int().sum() == lengths.sum(), (mask, lengths)

zero_coord = torch.zeros(1, 4, device=device, dtype=torch.float)
zero_coord[:, -1] = 1

positions = torch.matmul(car_to_world, zero_coord.T).squeeze(-1)
positions /= positions[..., -1:] + 1e-8 # perspective warp
positions = batch.positions()
positions = positions[..., :2]

# calculate velocity between first two frames to allow model to understand current speed
Expand All @@ -986,9 +977,9 @@ def forward(
# positions = positions[:, :pos_len]
# mask = mask[:, :pos_len]

# approximately 0.5 fps since video is 15fps
positions = positions[:, ::7]
mask = mask[:, ::7]
# approximately 0.5 fps since video is 12hz
positions = positions[:, ::6]
mask = mask[:, ::6]

"""
# we need to be aligned to size 8
Expand Down
19 changes: 18 additions & 1 deletion torchdrive/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,26 @@ def test_save_load(self) -> None:
batch = dummy_batch()

with tempfile.TemporaryDirectory("torchdrive-test_data") as path:
file_path = os.path.join(path, "file.pt.zstd")
file_path = os.path.join(path, "file.pt")
batch.save(file_path)

out = Batch.load(file_path)

self.assertIsNotNone(out)

def test_save_load_zstd(self) -> None:
    """Round-trips a dummy batch through a file with the .zst extension."""
    original = dummy_batch()

    with tempfile.TemporaryDirectory("torchdrive-test_data") as tmp_dir:
        target = os.path.join(tmp_dir, "file.pt.zst")
        original.save(target)
        restored = Batch.load(target)

    # load() has fully read the file by now, so the directory can be gone
    self.assertIsNotNone(restored)

def test_positions(self) -> None:
    """Checks positions() yields 3 coordinates per long_cam_T frame."""
    batch = dummy_batch()
    result = batch.positions()

    transforms = batch.long_cam_T[0]
    leading_dims = tuple(transforms.shape[:2])
    self.assertEqual(result.shape, leading_dims + (3,))

0 comments on commit 61c2f9a

Please sign in to comment.