From 510edcdeeb5660b8d74a52639fe348895f5796ff Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 10 Jul 2024 13:50:14 -0700 Subject: [PATCH] diff_traj: using mtr loss and no diffusion --- configs/diff_traj.py | 2 +- torchdrive/tasks/diff_traj.py | 280 +++++++++++++++++++++++++++-- torchdrive/tasks/test_diff_traj.py | 31 ++++ 3 files changed, 295 insertions(+), 18 deletions(-) diff --git a/configs/diff_traj.py b/configs/diff_traj.py index ef8e784..e291c21 100644 --- a/configs/diff_traj.py +++ b/configs/diff_traj.py @@ -10,7 +10,7 @@ num_encode_frames=1, cam_shape=(480, 640), # optimizer settings - epochs=20, + epochs=200, lr=1e-4, grad_clip=1.0, step_size=1000, diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py index c11dbe0..c404abc 100644 --- a/torchdrive/tasks/diff_traj.py +++ b/torchdrive/tasks/diff_traj.py @@ -1,6 +1,7 @@ import os.path from collections import OrderedDict from typing import Dict, List, Optional, Tuple +import math import matplotlib.pyplot as plt @@ -108,6 +109,8 @@ def forward( f"Expected (batch_size, seq_length, hidden_dim) got {condition.shape}", ) + print(input.shape, input_mask.shape) + x = input # apply rotary embeddings @@ -356,6 +359,212 @@ def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor: emb = self.decoder(predicted) return self.embedding.loss(emb, target) +def gen_sineembed_for_position(pos_tensor, hidden_dim=256): + """ + Args: + pos_tensor: [bs, n, 2] -> [bs, n, dim], input range 0-1 + + Converts a position into a sine encoded tensor. + From: https://github.com/sshaoshuai/MTR/blob/master/mtr/models/motion_decoder/mtr_decoder.py#L134 + # Copyright (c) 2022 Shaoshuai Shi. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 [see LICENSE for details] + # ------------------------------------------------------------------------ + # DAB-DETR + # Copyright (c) 2022 IDEA. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 [see LICENSE for details] + # ------------------------------------------------------------------------ + # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) + # Copyright (c) 2021 Microsoft. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 [see LICENSE for details] + # ------------------------------------------------------------------------ + # Modified from DETR (https://github.com/facebookresearch/detr) + # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + # ------------------------------------------------------------------------ + """ + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + half_hidden_dim = hidden_dim // 2 + scale = 2 * math.pi + dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + +def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_idxs=None, + timestamp_loss_weight=None, use_square_gmm=False, log_std_range=(-1.609, 5.0), rho_limit=0.5): + """ + Gausian Mixture Model loss for trajectories. + + Adapted from https://github.com/sshaoshuai/MTR/blob/master/mtr/utils/loss_utils.py + + GMM Loss for Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 + Written by Shaoshuai Shi + + Args: + pred_trajs (batch_size, num_modes, num_timestamps, 5 or 3) + gt_trajs (batch_size, num_timestamps, 2): + gt_valid_mask (batch_size, num_timestamps): + timestamp_loss_weight (num_timestamps): + """ + if use_square_gmm: + assert pred_trajs.shape[-1] == 3 + else: + assert pred_trajs.shape[-1] == 5 + + batch_size = pred_trajs.size(0) + + if pre_nearest_mode_idxs is not None: + nearest_mode_idxs = pre_nearest_mode_idxs + else: + distance = (pred_trajs[:, :, :, 0:2] - gt_trajs[:, None, :, :]).norm(dim=-1) + distance = (distance * gt_valid_mask[:, None, :]).sum(dim=-1) + + nearest_mode_idxs = distance.argmin(dim=-1) + nearest_mode_bs_idxs = torch.arange(batch_size).type_as(nearest_mode_idxs) # (batch_size, 2) + + nearest_trajs = pred_trajs[nearest_mode_bs_idxs, nearest_mode_idxs] # (batch_size, num_timestamps, 5) + res_trajs = gt_trajs - nearest_trajs[:, :, 0:2] # (batch_size, num_timestamps, 2) + dx = res_trajs[:, :, 0] + dy = res_trajs[:, :, 1] + + if use_square_gmm: + log_std1 = log_std2 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]) + std1 = std2 = torch.exp(log_std1) # (0.2m to 150m) + rho = torch.zeros_like(log_std1) + else: + log_std1 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]) + log_std2 = torch.clip(nearest_trajs[:, :, 3], min=log_std_range[0], max=log_std_range[1]) + std1 = torch.exp(log_std1) # (0.2m to 150m) + std2 = torch.exp(log_std2) # (0.2m to 150m) + rho = torch.clip(nearest_trajs[:, :, 4], min=-rho_limit, max=rho_limit) + + gt_valid_mask = gt_valid_mask.type_as(pred_trajs) + if timestamp_loss_weight is not None: + gt_valid_mask = gt_valid_mask * timestamp_loss_weight[None, :] + + # -log(a^-1 * e^b) = log(a) - b + reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2) # (batch_size, num_timestamps) + reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * ((dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2)) # (batch_size, num_timestamps) + + reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * gt_valid_mask).sum(dim=-1) + + return reg_loss, nearest_mode_idxs, nearest_trajs + +class XYGMMEncoder(nn.Module): + def __init__( + self, dim: int, max_dist: float, dropout: float = 0.1, + num_traj: int = 5, + ) -> None: + super().__init__() + self.dim = dim + self.max_dist = max_dist + self.num_traj = num_traj + # [x, y, log_std1, log_std2, rho] + self.traj_size = 5 + + self.decoder = MLP(dim, dim, self.num_traj * self.traj_size, num_layers=3, dropout=dropout) + + def forward(self, xy: torch.Tensor) -> torch.Tensor: + """ + Args: + xy: the list of positions (..., 2) + + Returns: + the embedding of the position (..., dim) + """ + xy = (xy / (2 * self.max_dist)) + 0.5 + return gen_sineembed_for_position(xy, hidden_dim=self.dim) + + def decode(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the position embedding (bs, n, dim) + Returns: + the position (bs, n, 2) + """ + # [bs, dim * num_traj, n] + out = self.decoder(input.permute(0, 2, 1)) + # [bs, num_traj, dim, n] + out = out.unflatten(1, (self.num_traj, self.traj_size)) + # [bs, num_traj, n, dim] + out = out.permute(0, 1, 3, 2) + return out + + def loss(self, emb: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Args: + emb: the embedding of the position (bs, n, dim) + target: the position (bs, n, 2) + mask: the mask of the position (bs, n) + + Returns: + the loss (bs, n) + best trajectories: (bs, n, 2) + """ + pred_traj = self.decode(emb) + # l2 distance norm + loss, nearest_mode, nearest_trajs = nll_loss_gmm_direct( + pred_traj, target, mask + ) + return loss, nearest_trajs[..., :2] + + + +class XYSineMLPEncoder(nn.Module): + def __init__( + self, dim: int, max_dist: float, dropout: float = 0.1, + ) -> None: + super().__init__() + self.dim = dim + self.max_dist = max_dist + + self.decoder = MLP(dim, dim, 2, num_layers=3, dropout=dropout) + + def forward(self, xy: torch.Tensor) -> torch.Tensor: + """ + Args: + xy: the list of positions (..., 2) + + Returns: + the embedding of the position (..., dim) + """ + xy = (xy / (2 * self.max_dist)) + 0.5 + return gen_sineembed_for_position(xy, hidden_dim=self.dim) + + def decode(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the position embedding (bs, n, dim) + Returns: + the position (bs, n, 2) + """ + output = self.decoder(input.permute(0, 2, 1)).permute(0, 2, 1) + return (output.float().sigmoid() - 0.5) * (2 * self.max_dist) + + def loss(self, emb: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + pos = self.decode(emb) + # l2 distance norm + #return torch.linalg.vector_norm(pos-target, dim=-1) + return F.mse_loss(pos, target, reduction="none").mean(dim=-1) class DiffTraj(nn.Module, Van): """ @@ -375,6 +584,7 @@ def __init__( num_frames: int = 1, num_inference_timesteps: int = 50, num_train_timesteps: int = 1000, + max_seq_len: int = 256, ): super().__init__() @@ -385,6 +595,7 @@ def __init__( self.feat_shape = (cam_shape[0] // 16, cam_shape[1] // 16) self.num_train_timesteps = num_train_timesteps self.num_inference_timesteps = num_inference_timesteps + self.noise_scale = 25.0 self.encoders = nn.ModuleDict( { @@ -398,10 +609,10 @@ def __init__( ) # embedding - self.xy_embedding = XYMLPEncoder(dim=dim, max_dist=128, pretrained=True) + self.xy_embedding = XYGMMEncoder(dim=dim, max_dist=128.0) self.denoiser = Denoiser( - max_seq_len=256, + max_seq_len=max_seq_len, num_layers=num_layers, num_heads=num_heads, dim=dim, @@ -411,9 +622,13 @@ def __init__( self.static_features_encoder = nn.Sequential( nn.Linear(1, dim), + nn.BatchNorm1d(dim), nn.ReLU(inplace=True), nn.Linear(dim, dim), ) + self.query_embed = nn.Parameter( + torch.empty(max_seq_len, dim).normal_(std=0.02) + ) self.noise_scheduler = EulerDiscreteScheduler( num_train_timesteps=num_train_timesteps @@ -485,11 +700,14 @@ def forward( for cam in self.cameras: feats = batch.color[cam][:, : self.num_encode_frames] block_size = min(*self.feat_shape) // 3 - mask = random_block_mask( - empty_mask, - block_size=(block_size, block_size), - num_blocks=8, - ) + if True: + mask = torch.ones_like(empty_mask).bool() + else: + mask = random_block_mask( + empty_mask, + block_size=(block_size, block_size), + num_blocks=8, + ) if writer is not None and log_text: writer.add_scalar( @@ -612,7 +830,8 @@ def forward( traj_embed = self.xy_embedding(positions) - noise = torch.randn(traj_embed.shape, device=traj_embed.device) + """ + noise = torch.randn(traj_embed.shape, device=traj_embed.device) / self.noise_scale timesteps = torch.randint( 0, self.noise_scheduler.config.num_train_timesteps, @@ -622,20 +841,41 @@ def forward( ) traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps) + if writer and log_text: + writer.add_scalars( + "paths/embed_scales", + { + "embed": torch.linalg.vector_norm(traj_embed, dim=-1).mean().cpu(), + "embed_with_noise": torch.linalg.vector_norm(traj_embed_noise, dim=-1).mean().cpu(), + "noise": torch.linalg.vector_norm(noise, dim=-1).mean().cpu(), + }, + global_step=global_step, + ) + """ + + query = self.query_embed[:positions.size(1)].repeat(BS, 1, 1) + with autocast(): # add static feature info to all condition keys to avoid noise input_tokens = input_tokens + static_features - pred_noise = self.denoiser(traj_embed_noise, mask, input_tokens) + pred_embed = self.denoiser(query, mask, input_tokens) + """ noise_loss = F.mse_loss(pred_noise, noise, reduction="none") noise_loss = noise_loss[mask] losses["diffusion"] = noise_loss.mean() + """ - losses["ae/with_noise"] = ( - self.xy_embedding.loss(traj_embed_noise, positions)[mask].mean() * 0.1 - ) - losses["ae/ae"] = self.xy_embedding.loss(traj_embed, positions)[mask].mean() + pred_loss, pred_traj = self.xy_embedding.loss(pred_embed, positions, mask) + losses["path/best"] = pred_loss.mean() + + #noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions) + #losses["ae/with_noise"] = ( + # noise_loss.mean() * 0.01 + #) + ae_loss, ae_traj = self.xy_embedding.loss(traj_embed, positions, mask) + losses["ae/ae"] = ae_loss.mean() losses_backward(losses) @@ -649,7 +889,8 @@ def forward( pred_len = mask[0].sum() - pred_traj = torch.randn_like(noise[:1]) + """ + pred_traj = torch.randn_like(noise[:1]) / self.noise_scale self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps) for timestep in self.eval_noise_scheduler.timesteps: with autocast(): @@ -663,10 +904,13 @@ def forward( pred_traj, generator=torch.Generator(device=device).manual_seed(0), ).prev_sample - pred_positions = self.xy_embedding.decode(pred_traj)[0, :pred_len].cpu() + """ + + pred_positions = pred_traj[0] plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred") + """ noise_positions = self.xy_embedding.decode(traj_embed_noise[:1])[ 0, :pred_len, @@ -674,11 +918,13 @@ def forward( plt.plot( noise_positions[..., 0], noise_positions[..., 1], label="with_noise" ) - pos_positions = self.xy_embedding.decode(traj_embed[:1])[ 0, :pred_len ].cpu() - plt.plot(pos_positions[..., 0], noise_positions[..., 1], label="ae") + """ + + ae_positions = ae_traj[0] + plt.plot(ae_positions[..., 0], ae_positions[..., 1], label="ae") target = positions[0, :pred_len].detach().cpu() plt.plot(target[..., 0], target[..., 1], label="target") diff --git a/torchdrive/tasks/test_diff_traj.py b/torchdrive/tasks/test_diff_traj.py index 28724ef..575230d 100644 --- a/torchdrive/tasks/test_diff_traj.py +++ b/torchdrive/tasks/test_diff_traj.py @@ -11,6 +11,7 @@ XYEmbedding, XYLinearEmbedding, XYMLPEncoder, + XYSineMLPEncoder, ) @@ -161,6 +162,36 @@ def test_xy_mlp_encoder(self): for param in m.parameters(): self.assertIsNotNone(param.grad) + def test_xy_sine_mlp_encoder(self): + torch.manual_seed(0) + + m = XYSineMLPEncoder( + dim=32, + max_dist=128.0, + ) + + input = torch.tensor( + [ + (0.0, 0.0), + (1.0, 0.0), + (0.0, 1.0), + (-1.0, 0.0), + (0.0, -1.0), + ] + ).unsqueeze(0) + + out = m(input) + self.assertEqual(out.shape, (1, 5, 32)) + + decoded = m.decode(out) + self.assertEqual(decoded.shape, (1, 5, 2)) + + loss = m.loss(out, input) + self.assertEqual(loss.shape, (1, 5)) + loss.sum().backward() + for param in m.parameters(): + self.assertIsNotNone(param.grad) + def test_square_mask(self): input = torch.tensor( [