Skip to content

Commit

Permalink
diff_traj: using mtr loss and no diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 10, 2024
1 parent abcb4d9 commit 510edcd
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 18 deletions.
2 changes: 1 addition & 1 deletion configs/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
num_encode_frames=1,
cam_shape=(480, 640),
# optimizer settings
epochs=20,
epochs=200,
lr=1e-4,
grad_clip=1.0,
step_size=1000,
Expand Down
280 changes: 263 additions & 17 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import math

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -108,6 +109,8 @@ def forward(
f"Expected (batch_size, seq_length, hidden_dim) got {condition.shape}",
)

print(input.shape, input_mask.shape)

x = input

# apply rotary embeddings
Expand Down Expand Up @@ -356,6 +359,212 @@ def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
emb = self.decoder(predicted)
return self.embedding.loss(emb, target)

def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
    """
    Converts positions into a sine/cosine encoded tensor.

    Args:
        pos_tensor: [..., 2] (x, y) or [..., 4] (x, y, w, h) positions, e.g.
            [bs, n, 2]. Any number of leading dimensions is accepted. The
            upstream MTR code documents an input range of 0-1.
        hidden_dim: every coordinate is encoded into ``hidden_dim // 2``
            features, so this should be divisible by 4 for the sine and
            cosine halves to pair up.
    Returns:
        [..., hidden_dim] for 2 coordinates (ordered y, x) or
        [..., 2 * hidden_dim] for 4 coordinates (ordered y, x, w, h).
    Raises:
        ValueError: if the last dimension of pos_tensor is not 2 or 4.

    From: https://github.com/sshaoshuai/MTR/blob/master/mtr/models/motion_decoder/mtr_decoder.py#L134
    # Copyright (c) 2022 Shaoshuai Shi. All Rights Reserved.
    # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
    # ------------------------------------------------------------------------
    # DAB-DETR
    # Copyright (c) 2022 IDEA. All Rights Reserved.
    # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
    # ------------------------------------------------------------------------
    # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
    # Copyright (c) 2021 Microsoft. All Rights Reserved.
    # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
    # ------------------------------------------------------------------------
    # Modified from DETR (https://github.com/facebookresearch/detr)
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
    # ------------------------------------------------------------------------
    """
    # Validate up front so unsupported widths fail with one consistent error
    # before any indexing or encoding work is done.
    num_coords = pos_tensor.size(-1)
    if num_coords not in (2, 4):
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))

    half_hidden_dim = hidden_dim // 2
    scale = 2 * math.pi
    dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device)
    # Consecutive index pairs share one frequency: one slot feeds sin, the
    # other feeds cos.
    dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim)

    def _encode(coord):
        # [...] -> [..., half_hidden_dim] with interleaved (sin, cos) pairs.
        angles = (coord * scale)[..., None] / dim_t
        return torch.stack(
            (angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    pos_x = _encode(pos_tensor[..., 0])
    pos_y = _encode(pos_tensor[..., 1])
    if num_coords == 2:
        return torch.cat((pos_y, pos_x), dim=-1)

    pos_w = _encode(pos_tensor[..., 2])
    pos_h = _encode(pos_tensor[..., 3])
    return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)

def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_idxs=None,
                        timestamp_loss_weight=None, use_square_gmm=False, log_std_range=(-1.609, 5.0), rho_limit=0.5):
    """
    Gaussian Mixture Model negative log-likelihood loss for trajectories.

    Only the mode closest to the ground truth (or the caller-supplied mode)
    is scored for each batch element.

    Adapted from https://github.com/sshaoshuai/MTR/blob/master/mtr/utils/loss_utils.py
    GMM Loss for Motion Transformer (MTR): https://arxiv.org/abs/2209.13508
    Written by Shaoshuai Shi

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 5 or 3):
            [x, y, log_std1, log_std2, rho], or [x, y, log_std] when
            use_square_gmm is set.
        gt_trajs (batch_size, num_timestamps, 2):
        gt_valid_mask (batch_size, num_timestamps):
        pre_nearest_mode_idxs (batch_size): optional mode index per batch
            element; when given, the nearest-mode search is skipped.
        timestamp_loss_weight (num_timestamps): optional per-step weight.
        use_square_gmm: use one shared std for both axes and zero correlation.
        log_std_range: (min, max) clamp applied to the predicted log std.
        rho_limit: the correlation is clamped to [-rho_limit, rho_limit].
    Returns:
        reg_loss (batch_size): loss summed over the weighted timestamps.
        nearest_mode_idxs (batch_size): index of the scored mode.
        nearest_trajs (batch_size, num_timestamps, 5 or 3): the scored mode.
    """
    expected_width = 3 if use_square_gmm else 5
    assert pred_trajs.shape[-1] == expected_width

    if pre_nearest_mode_idxs is None:
        # Pick the mode with the smallest summed L2 distance over valid steps.
        step_dist = (pred_trajs[:, :, :, 0:2] - gt_trajs[:, None, :, :]).norm(dim=-1)
        mode_dist = (step_dist * gt_valid_mask[:, None, :]).sum(dim=-1)
        nearest_mode_idxs = mode_dist.argmin(dim=-1)
    else:
        nearest_mode_idxs = pre_nearest_mode_idxs

    # (batch_size,) row indices used to gather one mode per batch element.
    batch_idxs = torch.arange(pred_trajs.size(0)).type_as(nearest_mode_idxs)
    nearest_trajs = pred_trajs[batch_idxs, nearest_mode_idxs]  # (batch_size, num_timestamps, 5 or 3)

    residual = gt_trajs - nearest_trajs[:, :, 0:2]  # (batch_size, num_timestamps, 2)
    dx, dy = residual[:, :, 0], residual[:, :, 1]

    lo, hi = log_std_range[0], log_std_range[1]
    log_std1 = torch.clip(nearest_trajs[:, :, 2], min=lo, max=hi)
    if use_square_gmm:
        # Isotropic case: both axes share the std and are uncorrelated.
        log_std2 = log_std1
        rho = torch.zeros_like(log_std1)
    else:
        log_std2 = torch.clip(nearest_trajs[:, :, 3], min=lo, max=hi)
        rho = torch.clip(nearest_trajs[:, :, 4], min=-rho_limit, max=rho_limit)
    # With the default log_std_range the std spans roughly 0.2 to 150.
    std1 = torch.exp(log_std1)
    std2 = torch.exp(log_std2)

    weights = gt_valid_mask.type_as(pred_trajs)
    if timestamp_loss_weight is not None:
        weights = weights * timestamp_loss_weight[None, :]

    one_minus_rho_sq = 1 - rho**2
    # -log(a^-1 * e^b) = log(a) - b
    reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(one_minus_rho_sq)  # (batch_size, num_timestamps)
    reg_gmm_exp = (0.5 * 1 / one_minus_rho_sq) * (
        (dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2)
    )  # (batch_size, num_timestamps)

    reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * weights).sum(dim=-1)

    return reg_loss, nearest_mode_idxs, nearest_trajs

class XYGMMEncoder(nn.Module):
    """
    Sine-embeds XY positions and decodes embeddings into several candidate
    trajectories, each parameterized as a per-step 2D Gaussian.

    The embedding itself has no learned parameters; only the decoder MLP is
    trained.
    """

    def __init__(
        self, dim: int, max_dist: float, dropout: float = 0.1,
        num_traj: int = 5,
    ) -> None:
        """
        Args:
            dim: embedding dimension.
            max_dist: positions are normalized by 2 * max_dist before
                embedding, so [-max_dist, max_dist] maps onto [0, 1].
            dropout: dropout passed to the decoder MLP.
            num_traj: number of candidate trajectories (modes) to decode.
        """
        super().__init__()
        self.dim = dim
        self.max_dist = max_dist
        self.num_traj = num_traj
        # [x, y, log_std1, log_std2, rho]
        self.traj_size = 5

        self.decoder = MLP(dim, dim, self.num_traj * self.traj_size, num_layers=3, dropout=dropout)

    def forward(self, xy: torch.Tensor) -> torch.Tensor:
        """
        Args:
            xy: the list of positions (bs, n, 2)
        Returns:
            the sine embedding of the position (bs, n, dim)
        """
        # Shift/scale so that [-max_dist, max_dist] lands in [0, 1].
        xy = (xy / (2 * self.max_dist)) + 0.5
        return gen_sineembed_for_position(xy, hidden_dim=self.dim)

    def decode(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the position embedding (bs, n, dim)
        Returns:
            the candidate trajectories (bs, num_traj, n, traj_size), where
            the last dimension is [x, y, log_std1, log_std2, rho]
        """
        # The decoder is fed channels-first input (bs, dim, n); its output is
        # unflattened along dim 1, i.e. (bs, num_traj * traj_size, n).
        # NOTE(review): assumes MLP maps over dim 1 -- confirm against MLP.
        out = self.decoder(input.permute(0, 2, 1))
        # [bs, num_traj, traj_size, n]
        out = out.unflatten(1, (self.num_traj, self.traj_size))
        # [bs, num_traj, n, traj_size]
        out = out.permute(0, 1, 3, 2)
        return out

    def loss(
        self, emb: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            emb: the embedding of the position (bs, n, dim)
            target: the position (bs, n, 2)
            mask: the mask of the position (bs, n)
        Returns:
            the loss (bs,), summed over the masked positions
            best trajectories: (bs, n, 2), the XY of the mode nearest to target
        """
        pred_traj = self.decode(emb)
        # GMM negative log-likelihood of the mode closest to the target.
        loss, nearest_mode, nearest_trajs = nll_loss_gmm_direct(
            pred_traj, target, mask
        )
        return loss, nearest_trajs[..., :2]



class XYSineMLPEncoder(nn.Module):
    """
    Sine-embeds XY positions and regresses a single XY position back out of
    an embedding with an MLP decoder.
    """

    def __init__(
        self, dim: int, max_dist: float, dropout: float = 0.1,
    ) -> None:
        """
        Args:
            dim: embedding dimension.
            max_dist: decoded positions are bounded to (-max_dist, max_dist).
            dropout: dropout passed to the decoder MLP.
        """
        super().__init__()
        self.max_dist = max_dist
        self.dim = dim

        self.decoder = MLP(dim, dim, 2, num_layers=3, dropout=dropout)

    def forward(self, xy: torch.Tensor) -> torch.Tensor:
        """
        Embed positions with fixed sine/cosine features.

        Args:
            xy: the list of positions (..., 2)
        Returns:
            the embedding of the position (..., dim)
        """
        # Map [-max_dist, max_dist] onto [0, 1] before embedding.
        normalized = (xy / (2 * self.max_dist)) + 0.5
        return gen_sineembed_for_position(normalized, hidden_dim=self.dim)

    def decode(self, input: torch.Tensor) -> torch.Tensor:
        """
        Turn embeddings back into positions.

        Args:
            input: the position embedding (bs, n, dim)
        Returns:
            the position (bs, n, 2)
        """
        # The decoder is fed (bs, dim, n); permute in and back out again.
        channels_first = input.permute(0, 2, 1)
        raw = self.decoder(channels_first).permute(0, 2, 1)
        # The sigmoid bounds the output; recentre and rescale to +/- max_dist.
        centered = raw.float().sigmoid() - 0.5
        return centered * (2 * self.max_dist)

    def loss(self, emb: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Mean squared error between decoded and target positions.

        Args:
            emb: the embedding of the position (bs, n, dim)
            target: the position (bs, n, 2)
        Returns:
            the loss averaged over the last (coordinate) dimension
        """
        predicted = self.decode(emb)
        per_coord = F.mse_loss(predicted, target, reduction="none")
        return per_coord.mean(dim=-1)

class DiffTraj(nn.Module, Van):
"""
Expand All @@ -375,6 +584,7 @@ def __init__(
num_frames: int = 1,
num_inference_timesteps: int = 50,
num_train_timesteps: int = 1000,
max_seq_len: int = 256,
):
super().__init__()

Expand All @@ -385,6 +595,7 @@ def __init__(
self.feat_shape = (cam_shape[0] // 16, cam_shape[1] // 16)
self.num_train_timesteps = num_train_timesteps
self.num_inference_timesteps = num_inference_timesteps
self.noise_scale = 25.0

self.encoders = nn.ModuleDict(
{
Expand All @@ -398,10 +609,10 @@ def __init__(
)

# embedding
self.xy_embedding = XYMLPEncoder(dim=dim, max_dist=128, pretrained=True)
self.xy_embedding = XYGMMEncoder(dim=dim, max_dist=128.0)

self.denoiser = Denoiser(
max_seq_len=256,
max_seq_len=max_seq_len,
num_layers=num_layers,
num_heads=num_heads,
dim=dim,
Expand All @@ -411,9 +622,13 @@ def __init__(

self.static_features_encoder = nn.Sequential(
nn.Linear(1, dim),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim),
)
self.query_embed = nn.Parameter(
torch.empty(max_seq_len, dim).normal_(std=0.02)
)

self.noise_scheduler = EulerDiscreteScheduler(
num_train_timesteps=num_train_timesteps
Expand Down Expand Up @@ -485,11 +700,14 @@ def forward(
for cam in self.cameras:
feats = batch.color[cam][:, : self.num_encode_frames]
block_size = min(*self.feat_shape) // 3
mask = random_block_mask(
empty_mask,
block_size=(block_size, block_size),
num_blocks=8,
)
if True:
mask = torch.ones_like(empty_mask).bool()
else:
mask = random_block_mask(
empty_mask,
block_size=(block_size, block_size),
num_blocks=8,
)

if writer is not None and log_text:
writer.add_scalar(
Expand Down Expand Up @@ -612,7 +830,8 @@ def forward(

traj_embed = self.xy_embedding(positions)

noise = torch.randn(traj_embed.shape, device=traj_embed.device)
"""
noise = torch.randn(traj_embed.shape, device=traj_embed.device) / self.noise_scale
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
Expand All @@ -622,20 +841,41 @@ def forward(
)
traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)
if writer and log_text:
writer.add_scalars(
"paths/embed_scales",
{
"embed": torch.linalg.vector_norm(traj_embed, dim=-1).mean().cpu(),
"embed_with_noise": torch.linalg.vector_norm(traj_embed_noise, dim=-1).mean().cpu(),
"noise": torch.linalg.vector_norm(noise, dim=-1).mean().cpu(),
},
global_step=global_step,
)
"""

query = self.query_embed[:positions.size(1)].repeat(BS, 1, 1)

with autocast():
# add static feature info to all condition keys to avoid noise
input_tokens = input_tokens + static_features

pred_noise = self.denoiser(traj_embed_noise, mask, input_tokens)
pred_embed = self.denoiser(query, mask, input_tokens)

"""
noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
noise_loss = noise_loss[mask]
losses["diffusion"] = noise_loss.mean()
"""

losses["ae/with_noise"] = (
self.xy_embedding.loss(traj_embed_noise, positions)[mask].mean() * 0.1
)
losses["ae/ae"] = self.xy_embedding.loss(traj_embed, positions)[mask].mean()
pred_loss, pred_traj = self.xy_embedding.loss(pred_embed, positions, mask)
losses["path/best"] = pred_loss.mean()

#noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions)
#losses["ae/with_noise"] = (
# noise_loss.mean() * 0.01
#)
ae_loss, ae_traj = self.xy_embedding.loss(traj_embed, positions, mask)
losses["ae/ae"] = ae_loss.mean()

losses_backward(losses)

Expand All @@ -649,7 +889,8 @@ def forward(

pred_len = mask[0].sum()

pred_traj = torch.randn_like(noise[:1])
"""
pred_traj = torch.randn_like(noise[:1]) / self.noise_scale
self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
for timestep in self.eval_noise_scheduler.timesteps:
with autocast():
Expand All @@ -663,22 +904,27 @@ def forward(
pred_traj,
generator=torch.Generator(device=device).manual_seed(0),
).prev_sample

pred_positions = self.xy_embedding.decode(pred_traj)[0, :pred_len].cpu()
"""

pred_positions = pred_traj[0]
plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred")

"""
noise_positions = self.xy_embedding.decode(traj_embed_noise[:1])[
0,
:pred_len,
].cpu()
plt.plot(
noise_positions[..., 0], noise_positions[..., 1], label="with_noise"
)

pos_positions = self.xy_embedding.decode(traj_embed[:1])[
0, :pred_len
].cpu()
plt.plot(pos_positions[..., 0], noise_positions[..., 1], label="ae")
"""

ae_positions = ae_traj[0]
plt.plot(ae_positions[..., 0], ae_positions[..., 1], label="ae")

target = positions[0, :pred_len].detach().cpu()
plt.plot(target[..., 0], target[..., 1], label="target")
Expand Down
Loading

0 comments on commit 510edcd

Please sign in to comment.