diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py
index c404abc..4120f4f 100644
--- a/torchdrive/tasks/diff_traj.py
+++ b/torchdrive/tasks/diff_traj.py
@@ -19,6 +19,7 @@
 from torchdrive.models.mlp import MLP
 from torchdrive.models.path import XYEncoder
 from torchdrive.tasks.van import Van
+from torchdrive.tasks.context import Context
 from torchdrive.transforms.batch import NormalizeCarPosition
 from torchtune.modules import RotaryPositionalEmbeddings
 from torchworld.models.vit import MaskViT
@@ -29,6 +30,7 @@
     render_pca,
 )
 from torchworld.transforms.mask import random_block_mask, true_mask
+from torchdrive.models.transformer import transformer_init
 
 
 def square_mask(mask: torch.Tensor, num_heads: int) -> torch.Tensor:
@@ -92,6 +94,7 @@ def __init__(
             layer_norm_eps=1e-6,
         )
         self.layers = nn.Sequential(layers)
+        transformer_init(self.layers)
 
     def forward(
         self, input: torch.Tensor, input_mask: torch.Tensor, condition: torch.Tensor
@@ -109,19 +112,20 @@ def forward(
             f"Expected (batch_size, seq_length, hidden_dim) got {condition.shape}",
         )
 
-        print(input.shape, input_mask.shape)
-
         x = input
 
         # apply rotary embeddings
         # RoPE applies to each head separately
-        x = x.unflatten(-1, (self.num_heads, self.dim // self.num_heads))
-        x = self.positional_embedding(x)
-        x = x.flatten(-2, -1)
+        # x = x.unflatten(-1, (self.num_heads, self.dim // self.num_heads))
+        # x = self.positional_embedding(x)
+        # x = x.flatten(-2, -1)
 
-        attn_mask = square_mask(input_mask, num_heads=self.num_heads)
-        # True values are ignored so need to flip the mask
-        attn_mask = torch.logical_not(attn_mask)
+        if False:
+            attn_mask = square_mask(input_mask, num_heads=self.num_heads)
+            # True values are ignored so need to flip the mask
+            attn_mask = torch.logical_not(attn_mask)
+        else:
+            attn_mask = None
 
         for i, layer in enumerate(self.layers):
            x = layer(tgt=x, tgt_mask=attn_mask, memory=condition)
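Note on the hunk above: PyTorch's boolean attention masks treat True as "may not attend", so a validity mask (True = real token) has to be inverted before it is passed to the decoder layers. A minimal sketch of that convention, assuming square_mask expands a (bs, seq) key mask into the (bs * num_heads, seq, seq) layout accepted by nn.TransformerDecoderLayer:

    import torch

    valid = torch.tensor([[True, True, False]])  # (bs=1, seq=3), True = real token
    attn = square_mask(valid, num_heads=2)       # assumed (2, 3, 3), True = may attend
    attn = torch.logical_not(attn)               # flipped: True now means "masked out"
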
@@ -359,6 +363,7 @@ def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         emb = self.decoder(predicted)
         return self.embedding.loss(emb, target)
 
+
 def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
     """
     Args:
@@ -391,33 +396,50 @@ def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
     y_embed = pos_tensor[:, :, 1] * scale
     pos_x = x_embed[:, :, None] / dim_t
     pos_y = y_embed[:, :, None] / dim_t
-    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
-    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+    pos_x = torch.stack(
+        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
+    ).flatten(2)
+    pos_y = torch.stack(
+        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
+    ).flatten(2)
     if pos_tensor.size(-1) == 2:
         pos = torch.cat((pos_y, pos_x), dim=2)
     elif pos_tensor.size(-1) == 4:
         w_embed = pos_tensor[:, :, 2] * scale
         pos_w = w_embed[:, :, None] / dim_t
-        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+        pos_w = torch.stack(
+            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
+        ).flatten(2)
         h_embed = pos_tensor[:, :, 3] * scale
         pos_h = h_embed[:, :, None] / dim_t
-        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+        pos_h = torch.stack(
+            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
+        ).flatten(2)
         pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
     else:
         raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
     return pos
 
 
-def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_idxs=None,
-                        timestamp_loss_weight=None, use_square_gmm=False, log_std_range=(-1.609, 5.0), rho_limit=0.5):
+def nll_loss_gmm_direct(
+    pred_trajs,
+    gt_trajs,
+    gt_valid_mask,
+    pre_nearest_mode_idxs=None,
+    timestamp_loss_weight=None,
+    use_square_gmm=False,
+    log_std_range=(-1.609, 5.0),
+    rho_limit=0.5,
+):
     """
     Gaussian Mixture Model loss for trajectories.
 
     Adapted from https://github.com/sshaoshuai/MTR/blob/master/mtr/utils/loss_utils.py
-    
+
     GMM Loss for Motion Transformer (MTR): https://arxiv.org/abs/2209.13508
-    Written by Shaoshuai Shi 
+    Written by Shaoshuai Shi
 
     Args:
         pred_trajs (batch_size, num_modes, num_timestamps, 5 or 3)
@@ -426,7 +448,7 @@ def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_id
         timestamp_loss_weight (num_timestamps):
     """
     if use_square_gmm:
-        assert pred_trajs.shape[-1] == 3 
+        assert pred_trajs.shape[-1] == 3
     else:
         assert pred_trajs.shape[-1] == 5
@@ -435,24 +457,34 @@ def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_id
     if pre_nearest_mode_idxs is not None:
         nearest_mode_idxs = pre_nearest_mode_idxs
     else:
-        distance = (pred_trajs[:, :, :, 0:2] - gt_trajs[:, None, :, :]).norm(dim=-1) 
-        distance = (distance * gt_valid_mask[:, None, :]).sum(dim=-1) 
+        distance = (pred_trajs[:, :, :, 0:2] - gt_trajs[:, None, :, :]).norm(dim=-1)
+        distance = (distance * gt_valid_mask[:, None, :]).sum(dim=-1)
 
         nearest_mode_idxs = distance.argmin(dim=-1)
-    nearest_mode_bs_idxs = torch.arange(batch_size).type_as(nearest_mode_idxs)  # (batch_size, 2)
+    nearest_mode_bs_idxs = torch.arange(batch_size).type_as(
+        nearest_mode_idxs
+    )  # (batch_size, 2)
 
-    nearest_trajs = pred_trajs[nearest_mode_bs_idxs, nearest_mode_idxs]  # (batch_size, num_timestamps, 5)
+    nearest_trajs = pred_trajs[
+        nearest_mode_bs_idxs, nearest_mode_idxs
+    ]  # (batch_size, num_timestamps, 5)
     res_trajs = gt_trajs - nearest_trajs[:, :, 0:2]  # (batch_size, num_timestamps, 2)
     dx = res_trajs[:, :, 0]
     dy = res_trajs[:, :, 1]
 
     if use_square_gmm:
-        log_std1 = log_std2 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1])
-        std1 = std2 = torch.exp(log_std1)  # (0.2m to 150m)
+        log_std1 = log_std2 = torch.clip(
+            nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]
+        )
+        std1 = std2 = torch.exp(log_std1)  # (0.2m to 150m)
         rho = torch.zeros_like(log_std1)
     else:
-        log_std1 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1])
-        log_std2 = torch.clip(nearest_trajs[:, :, 3], min=log_std_range[0], max=log_std_range[1])
+        log_std1 = torch.clip(
+            nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]
+        )
+        log_std2 = torch.clip(
+            nearest_trajs[:, :, 3], min=log_std_range[0], max=log_std_range[1]
+        )
         std1 = torch.exp(log_std1)  # (0.2m to 150m)
         std2 = torch.exp(log_std2)  # (0.2m to 150m)
         rho = torch.clip(nearest_trajs[:, :, 4], min=-rho_limit, max=rho_limit)
@@ -462,16 +494,24 @@ def nll_loss_gmm_direct(pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_id
         gt_valid_mask = gt_valid_mask * timestamp_loss_weight[None, :]
 
     # -log(a^-1 * e^b) = log(a) - b
-    reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2)  # (batch_size, num_timestamps)
-    reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * ((dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2))  # (batch_size, num_timestamps)
+    reg_gmm_log_coefficient = (
+        log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2)
+    )  # (batch_size, num_timestamps)
+    reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * (
+        (dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2)
+    )  # (batch_size, num_timestamps)
 
     reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * gt_valid_mask).sum(dim=-1)
 
     return reg_loss, nearest_mode_idxs, nearest_trajs
 
+
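The loss above is winner-take-all: each sample is scored only against its closest mode, and the returned nearest_mode_idxs can supervise a separate mode classifier (ConvNextPathPred below does exactly that with cross-entropy). A hypothetical shape check for the full-covariance path (5 parameters per timestamp):

    import torch

    pred_trajs = torch.randn(2, 6, 10, 5)  # (bs, modes, T, [x, y, log_std1, log_std2, rho])
    gt_trajs = torch.randn(2, 10, 2)       # (bs, T, [x, y])
    gt_valid_mask = torch.ones(2, 10)      # (bs, T), 1 = timestamp is valid

    reg_loss, nearest_mode_idxs, nearest_trajs = nll_loss_gmm_direct(
        pred_trajs, gt_trajs, gt_valid_mask
    )
    # reg_loss: (2,) NLL of the closest mode, summed over valid timestamps
    # nearest_mode_idxs: (2,) winning mode index; nearest_trajs: (2, 10, 5)
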
 class XYGMMEncoder(nn.Module):
     def __init__(
-        self, dim: int, max_dist: float, dropout: float = 0.1,
+        self,
+        dim: int,
+        max_dist: float,
+        dropout: float = 0.1,
         num_traj: int = 5,
     ) -> None:
         super().__init__()
@@ -481,7 +521,10 @@ def __init__(
 
         # [x, y, log_std1, log_std2, rho]
         self.traj_size = 5
-        self.decoder = MLP(dim, dim, self.num_traj * self.traj_size, num_layers=3, dropout=dropout)
+        self.encoder = MLP(dim + 2, dim, dim, num_layers=3, dropout=dropout)
+        self.decoder = MLP(
+            dim, dim, self.num_traj * self.traj_size, num_layers=3, dropout=dropout
+        )
 
     def forward(self, xy: torch.Tensor) -> torch.Tensor:
         """
@@ -491,16 +534,19 @@ def forward(self, xy: torch.Tensor) -> torch.Tensor:
         Args:
             xy: the position (..., 2)
         Returns:
             the embedding of the position (..., dim)
         """
-        xy = (xy / (2 * self.max_dist)) + 0.5
-        return gen_sineembed_for_position(xy, hidden_dim=self.dim)
+        normalized_xy = (xy / (2 * self.max_dist)) + 0.5
+        emb = gen_sineembed_for_position(xy, hidden_dim=self.dim)
+        combined = torch.cat((emb, normalized_xy), dim=-1)
+        return self.encoder(combined.permute(0, 2, 1)).permute(0, 2, 1)
 
     def decode(self, input: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the position embedding (bs, n, dim)
         Returns:
-            the position (bs, n, 2)
+            the predicted trajectories (bs, num_traj, n, 5)
         """
+        input = input.float()
         # [bs, num_traj * traj_size, n]
         out = self.decoder(input.permute(0, 2, 1))
         # [bs, num_traj, traj_size, n]
@@ -509,29 +555,31 @@ def decode(self, input: torch.Tensor) -> torch.Tensor:
         out = out.permute(0, 1, 3, 2)
         return out
 
-    def loss(self, emb: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def loss(
+        self, emb: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
         """
         Args:
             emb: the embedding of the position (bs, n, dim)
             target: the position (bs, n, 2)
             mask: the mask of the position (bs, n)
-        
+
         Returns:
             the loss (bs,)
             best trajectories (bs, n, 2)
             all predicted trajectories (bs, num_traj, n, 5)
         """
         pred_traj = self.decode(emb)
         # l2 distance norm
-        loss, nearest_mode, nearest_trajs = nll_loss_gmm_direct(
-            pred_traj, target, mask
-        )
-        return loss, nearest_trajs[..., :2]
-
+        loss, nearest_mode, nearest_trajs = nll_loss_gmm_direct(pred_traj, target, mask)
+        return loss, nearest_trajs[..., :2], pred_traj
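A hypothetical round trip through the reworked encoder, assuming dim=256, the default num_traj=5, and the channels-first MLP from torchdrive.models.mlp implied by the permutes above. The third return value added here exposes all predicted modes, not just the one closest to the target:

    import torch

    enc = XYGMMEncoder(dim=256, max_dist=128.0)
    xy = torch.randn(2, 18, 2)   # (bs, n, [x, y])
    mask = torch.ones(2, 18)     # (bs, n)

    emb = enc(xy)                # (2, 18, 256): sine embedding + normalized xy -> MLP
    loss, best_xy, all_trajs = enc.loss(emb, xy, mask)
    # loss: (2,), best_xy: (2, 18, 2), all_trajs: (2, 5, 18, 5)
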
 
 
 class XYSineMLPEncoder(nn.Module):
     def __init__(
-        self, dim: int, max_dist: float, dropout: float = 0.1,
+        self,
+        dim: int,
+        max_dist: float,
+        dropout: float = 0.1,
     ) -> None:
         super().__init__()
         self.dim = dim
@@ -563,9 +611,86 @@ def decode(self, input: torch.Tensor) -> torch.Tensor:
     def loss(self, emb: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         pos = self.decode(emb)
         # l2 distance norm
-        #return torch.linalg.vector_norm(pos-target, dim=-1)
+        # return torch.linalg.vector_norm(pos-target, dim=-1)
         return F.mse_loss(pos, target, reduction="none").mean(dim=-1)
 
+
+class ConvNextPathPred(nn.Module):
+    def __init__(self, dim: int = 256, max_seq_len: int = 18, pool_size: int = 4):
+        super().__init__()
+
+        from torchvision.models.convnext import convnext_base, ConvNeXt_Base_Weights
+
+        self.dim = dim
+        self.max_seq_len = max_seq_len
+        # [x, y, log_std1, log_std2, rho]
+        self.traj_size = 5
+        self.num_traj = 5
+
+        self.encoder = convnext_base(
+            weights=ConvNeXt_Base_Weights.IMAGENET1K_V1,
+        ).features
+        enc_dim = 1024
+        self.avgpool = nn.AdaptiveAvgPool2d(pool_size)
+
+        self.static_features_encoder = nn.Sequential(
+            nn.Linear(1, dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(dim, dim),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.Linear(enc_dim * pool_size * pool_size + dim, max_seq_len * dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(max_seq_len * dim, max_seq_len * dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(
+                max_seq_len * dim,
+                max_seq_len * self.num_traj * self.traj_size + self.num_traj,
+            ),
+        )
+
+    def forward(
+        self,
+        static_features: torch.Tensor,
+        color: torch.Tensor,
+        target: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
+        # take first frame
+        with autocast():
+            x = color[:, 0]
+            x = self.encoder(x)
+            x = self.avgpool(x)
+
+        x = x.flatten(1, 3).float()
+
+        static_features_emb = self.static_features_encoder(static_features)
+
+        combined = torch.cat((x, static_features_emb), dim=-1)
+
+        embed = self.decoder(combined)
+
+        traj_classes = embed[:, : self.num_traj]
+
+        pred_traj = embed[:, self.num_traj :]
+        pred_traj = pred_traj.unflatten(
+            1, (self.num_traj, self.max_seq_len, self.traj_size)
+        )
+
+        length = min(self.max_seq_len, target.shape[1])
+
+        pred_traj = pred_traj[:, :, :length]
+        target = target[:, :length]
+        mask = mask[:, :length]
+
+        traj_loss, nearest_mode, nearest_trajs = nll_loss_gmm_direct(
+            pred_traj, target, mask
+        )
+
+        class_loss = F.cross_entropy(traj_classes, nearest_mode)
+
+        losses = {
+            "paths/best": traj_loss.mean(),
+            "paths/class": class_loss.mean(),
+        }
+
+        return losses, nearest_trajs[..., :2], pred_traj
+
+
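A smoke-test sketch for the new head, assuming 224x224 RGB input and an 18-step target (constructing ConvNextPathPred fetches the ImageNet ConvNeXt-Base weights on first use, and the internal autocast() is assumed to be usable on the current device):

    import torch

    model = ConvNextPathPred()
    static = torch.randn(2, 1)               # (bs, 1), e.g. scalar speed
    color = torch.randn(2, 1, 3, 224, 224)   # (bs, frames, C, H, W); only frame 0 is encoded
    target = torch.randn(2, 18, 2)           # (bs, T, [x, y])
    mask = torch.ones(2, 18, dtype=torch.bool)

    losses, best_xy, all_trajs = model(static, color, target, mask)
    # losses: {"paths/best": ..., "paths/class": ...}
    # best_xy: (2, 18, 2) closest mode; all_trajs: (2, 5, 18, 5) every mode
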
 class DiffTraj(nn.Module, Van):
     """
     A diffusion model for trajectory detection.
@@ -578,7 +703,7 @@ def __init__(
         dim: int = 1024,
         dim_feedforward: int = 4096,
         dropout: float = 0.1,
-        num_layers: int = 24,
+        num_layers: int = 12,
         num_heads: int = 16,
         num_encode_frames: int = 1,
         num_frames: int = 1,
@@ -597,6 +722,9 @@ def __init__(
         self.num_inference_timesteps = num_inference_timesteps
         self.noise_scale = 25.0
 
+        self.model = ConvNextPathPred()
+
+        """
         self.encoders = nn.ModuleDict(
             {
                 cam: MaskViT(
@@ -622,7 +750,6 @@ def __init__(
 
         self.static_features_encoder = nn.Sequential(
             nn.Linear(1, dim),
-            nn.BatchNorm1d(dim),
             nn.ReLU(inplace=True),
             nn.Linear(dim, dim),
         )
@@ -638,30 +765,48 @@ def __init__(
             num_train_timesteps=num_train_timesteps
         )
         self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
+        """
 
         self.batch_transform = NormalizeCarPosition(start_frame=0)
 
     def param_opts(self, lr: float) -> List[Dict[str, object]]:
+        """
         return [
             {
                 "name": "encoders",
                 "params": list(self.encoders.parameters()),
-                "lr": lr / 10,
+                "lr": lr,
+                "weight_decay": 1e-4,
             },
             {
                 "name": "static_features",
                 "params": list(self.static_features_encoder.parameters()),
-                "lr": lr / 10,
+                "lr": lr,
+            },
+            {
+                "name": "query",
+                "params": [self.query_embed],
+                "lr": lr,
             },
             {
                 "name": "denoiser",
                 "params": list(self.denoiser.parameters()),
                 "lr": lr,
+                "weight_decay": 1e-4,
             },
             {
                 "name": "xy_embedding",
                 "params": list(self.xy_embedding.parameters()),
-                "lr": lr / 10,
+                "lr": lr,
+            },
+        ]
+        """
+        return [
+            {
+                "name": "model",
+                "params": list(self.model.parameters()),
+                "lr": lr,
+                "weight_decay": 1e-4,
             },
         ]
@@ -691,7 +836,26 @@ def forward(
         device = batch.device()
 
         log_img, log_text = self.should_log(global_step, BS)
 
+        ctx = Context(
+            log_img=log_img,
+            log_text=log_text,
+            global_step=global_step,
+            writer=writer,
+            output=output,
+            start_frame=0,
+            weights=1,
+        )
+
+        for cam in self.cameras:
+            feats = batch.color[cam][:, : self.num_encode_frames]
+            if log_img:
+                ctx.add_image(
+                    f"{cam}/color",
+                    normalize_img(feats[0, 0]),
+                    global_step=global_step,
+                )
+
+        """
         # for size, device only
         empty_mask = torch.empty(self.feat_shape, device=device)
@@ -710,13 +874,13 @@ def forward(
             )
 
             if writer is not None and log_text:
-                writer.add_scalar(
+                ctx.add_scalar(
                     f"{cam}/count",
                     mask.long().sum(),
                     global_step=global_step,
                 )
             if writer is not None and log_img:
-                writer.add_image(
+                ctx.add_image(
                     f"{cam}/mask",
                     render_color(mask),
                     global_step=global_step,
@@ -734,12 +898,12 @@ def forward(
             assert cam_feats.requires_grad, f"missing grad for cam {cam}"
 
             if writer is not None and log_img:
-                writer.add_image(
+                ctx.add_image(
                     f"{cam}/color",
                     normalize_img(feats[0, 0]),
                     global_step=global_step,
                 )
-                writer.add_image(
+                ctx.add_image(
                     f"{cam}/pca",
                     render_pca(unmasked[0].permute(1, 2, 0)),
                     global_step=global_step,
@@ -764,6 +928,7 @@ def forward(
             all_feats.append(cam_feats)
 
         input_tokens = torch.cat(all_feats, dim=1)
+        """
 
         world_to_car, mask, lengths = batch.long_cam_T
         car_to_world = torch.zeros_like(world_to_car)
@@ -784,51 +949,59 @@ def forward(
         assert positions.size(-1) == 2
 
         velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True)
-        static_features = self.static_features_encoder(velocity).unsqueeze(1)
+        # static_features = self.static_features_encoder(velocity).unsqueeze(1)
 
         lengths = mask.sum(dim=-1)
         min_len = lengths.amin()
         assert min_len > 0, "got example with zero sequence length"
 
         # truncate to shortest sequence
-        # pos_len = lengths.amin()
+        pos_len = lengths.amin()
         # if pos_len % align != 0:
         #    pos_len -= pos_len % align
         # assert pos_len >= 8
         # positions = positions[:, :pos_len]
         # mask = mask[:, :pos_len]
 
+        # approximately one sample every 0.5s since the video is 15fps
+        positions = positions[:, ::7]
+        mask = mask[:, ::7]
+
+        """
         # we need to be aligned to size 8
         # pad length
         align = 8
         if positions.size(1) % align != 0:
             pad = align - positions.size(1) % align
-            mask = F.pad(mask, (0, pad), value=True)
+            mask = F.pad(mask, (0, pad), value=False)
             positions = F.pad(positions, (0, 0, 0, pad), value=0)
 
         pos_len = positions.size(1)
         assert positions.size(1) % align == 0
         assert mask.size(1) % align == 0
         assert positions.size(1) == mask.size(1)
+        """
 
         num_elements = mask.float().sum()
 
-        if writer and log_text:
-            writer.add_scalar(
+        if log_text:
+            ctx.add_scalar(
                 "paths/pos_len",
                 pos_len,
                 global_step=global_step,
             )
-            writer.add_scalar(
+            ctx.add_scalar(
                 "paths/num_elements",
                 num_elements,
                 global_step=global_step,
             )
 
         posmax = positions.abs().amax()
-        assert posmax < 1000, positions
+        assert posmax < 100000, positions
 
+        """
         traj_embed = self.xy_embedding(positions)
+        """
 
         """
         noise = torch.randn(traj_embed.shape, device=traj_embed.device) / self.noise_scale
@@ -842,7 +1015,7 @@ def forward(
         traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)
 
         if writer and log_text:
-            writer.add_scalars(
+            ctx.add_scalars(
                 "paths/embed_scales",
                 {
                     "embed": torch.linalg.vector_norm(traj_embed, dim=-1).mean().cpu(),
@@ -853,13 +1026,18 @@ def forward(
                 global_step=global_step,
             )
         """
-        query = self.query_embed[:positions.size(1)].repeat(BS, 1, 1)
+        """
+        query = self.query_embed.repeat(BS, 1, 1)
+
+        #with autocast():
+        # add static feature info to all condition keys to avoid noise
+        input_tokens = input_tokens + static_features
 
-        with autocast():
-            # add static feature info to all condition keys to avoid noise
-            input_tokens = input_tokens + static_features
+        pred_embed = self.denoiser(query, mask, input_tokens)
 
-            pred_embed = self.denoiser(query, mask, input_tokens)
+        # reduce to match target
+        pred_embed = pred_embed[:, :positions.size(1)]
+        """
 
         """
         noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
@@ -867,17 +1045,41 @@ def forward(
         losses["diffusion"] = noise_loss.mean()
         """
 
-        pred_loss, pred_traj = self.xy_embedding.loss(pred_embed, positions, mask)
-        losses["path/best"] = pred_loss.mean()
+        # pred_loss, pred_traj, all_pred_traj = self.xy_embedding.loss(pred_embed, positions, mask)
 
-        #noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions)
-        #losses["ae/with_noise"] = (
+        cam = self.cameras[0]
+        pred_losses, pred_traj, all_pred_traj = self.model(velocity, batch.color[cam], positions, mask)
+        losses.update(pred_losses)
+
+        pred_len = min(pred_traj.size(1), mask[0].sum().item())
+
+        # noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions)
+        # losses["ae/with_noise"] = (
         #    noise_loss.mean() * 0.01
-        #)
-        ae_loss, ae_traj = self.xy_embedding.loss(traj_embed, positions, mask)
-        losses["ae/ae"] = ae_loss.mean()
+        # )
+        # ae_loss, ae_traj, _ = self.xy_embedding.loss(traj_embed, positions, mask)
+        # losses["ae/ae"] = ae_loss.mean() * 0.1
 
-        losses_backward(losses)
+        if writer and log_text:
+            # ctx.add_scalar(
+            #     "ae/ae",
+            #     ae_loss.mean().cpu(),
+            #     global_step=global_step,
+            # )
+
+            size = min(pred_traj.size(1), positions.size(1))
+
+            ctx.add_scalar(
+                "paths/pred_mae",
+                F.l1_loss(pred_traj[:, :size], positions[:, :size], reduction="none")[mask[:, :size]]
+                .mean()
+                .cpu()
+                .item(),
+                global_step=global_step,
+            )
+
+        if self.training:
+            ctx.backward(losses)
 
         if writer and log_img:
             # calculate cross_attn_weights
@@ -887,8 +1089,6 @@ def forward(
             # generate prediction
             self.train()
 
-            pred_len = mask[0].sum()
-
             """
             pred_traj = torch.randn_like(noise[:1]) / self.noise_scale
             self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
@@ -907,9 +1107,6 @@ def forward(
 
             pred_positions = self.xy_embedding.decode(pred_traj)[0, :pred_len].cpu()
             """
-            pred_positions = pred_traj[0]
-            plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred")
-
             """
             noise_positions = self.xy_embedding.decode(traj_embed_noise[:1])[
                 0,
@@ -923,18 +1120,21 @@ def forward(
             ].cpu()
             """
 
-            ae_positions = ae_traj[0]
+            """
+            ae_positions = ae_traj[0, :pred_len].cpu()
             plt.plot(ae_positions[..., 0], ae_positions[..., 1], label="ae")
+            """
 
             target = positions[0, :pred_len].detach().cpu()
             plt.plot(target[..., 0], target[..., 1], label="target")
 
-            writer.add_scalar(
-                "paths/pred_mae",
-                F.l1_loss(pred_positions, target).item(),
-                global_step=global_step,
-            )
-            writer.add_scalar(
+            for i in range(self.model.num_traj):
+                pred_positions = all_pred_traj[0, i, :pred_len].cpu()
+                plt.plot(
+                    pred_positions[..., 0], pred_positions[..., 1], label=f"pred{i}"
+                )
+
+            ctx.add_scalar(
                 "paths/pred_len",
                 pred_len,
                 global_step=global_step,
@@ -944,7 +1144,7 @@ def forward(
 
             fig.legend()
             plt.gca().set_aspect("equal")
-            writer.add_figure(
+            ctx.add_figure(
                 "paths/target",
                 fig,
                 global_step=global_step,
diff --git a/torchdrive/tasks/test_diff_traj.py b/torchdrive/tasks/test_diff_traj.py
index 575230d..510effe 100644
--- a/torchdrive/tasks/test_diff_traj.py
+++ b/torchdrive/tasks/test_diff_traj.py
@@ -132,6 +132,15 @@ def test_diff_traj(self):
         writer = MagicMock()
         losses = m(batch, global_step=0, writer=writer)
 
+        param_opts = m.param_opts(1)
+        all_params = set()
+        for group in param_opts:
+            for param in group["params"]:
+                all_params.add(param)
+
+        for name, param in m.named_parameters():
+            self.assertIn(param, all_params, name)
+
     def test_xy_mlp_encoder(self):
         torch.manual_seed(0)
diff --git a/torchdrive/train_config.py b/torchdrive/train_config.py
index c9ff497..f865687 100644
--- a/torchdrive/train_config.py
+++ b/torchdrive/train_config.py
@@ -24,14 +24,16 @@ class DatasetConfig:
     # dataset only params
     dataset: Datasets
-    dataset_path: str
+    train_dataset_path: str
+    test_dataset_path: str
     mask_path: str
     batch_size: int
     num_workers: int
     autolabel_path: str
     autolabel: bool
 
-    def create_dataset(self, smoke: bool = False) -> Dataset:
+    def create_dataset(self, smoke: bool = False) -> Tuple[Dataset, Optional[Dataset]]:
+        test_dataset = None
         if self.dataset == Datasets.RICE:
             from torchdrive.datasets.rice import MultiCamDataset
 
@@ -50,11 +52,17 @@ def create_dataset(self, smoke: bool = False) -> Dataset:
             from torchdrive.datasets.nuscenes_dataset import NuscenesDataset
 
             dataset = NuscenesDataset(
-                data_dir=self.dataset_path,
+                data_dir=self.train_dataset_path,
                 version="v1.0-mini" if smoke else "v1.0-trainval",
                 lidar=False,
                 num_frames=self.num_frames,
             )
+            test_dataset = NuscenesDataset(
+                data_dir=self.test_dataset_path,
+                version="v1.0-mini" if smoke else "v1.0-test",
+                lidar=False,
+                num_frames=self.num_frames,
+            )
         elif self.dataset == Datasets.DUMMY:
             from torchdrive.datasets.dummy import DummyDataset
 
@@ -67,10 +75,13 @@ def create_dataset(self, smoke: bool = False) -> Dataset:
             from torchdrive.datasets.autolabeler import AutoLabeler
 
             dataset = AutoLabeler(dataset, path=self.autolabel_path)
+            if test_dataset is not None:
+                test_dataset = AutoLabeler(test_dataset, path=self.autolabel_path)
 
         for cam in self.cameras:
             assert cam in dataset.cameras, "invalid camera"
-        return dataset
+
+        return dataset, test_dataset
 
 
 @dataclass
@@ -333,8 +344,8 @@ def create_model(
         num_frames=self.num_frames,
     ).to(device)
 
-    for cam_encoder in model.encoders.values():
-        cam_encoder.freeze_pretrained_weights()
+    #for cam_encoder in model.encoders.values():
+    #    cam_encoder.freeze_pretrained_weights()
 
     return model
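Since create_dataset now returns a (train, test) pair, call sites have to unpack it; a hypothetical caller:

    train_dataset, test_dataset = dataset_config.create_dataset(smoke=False)
    if test_dataset is None:  # the RICE and DUMMY branches don't build one
        test_dataset = train_dataset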