Patch summary (text before the first "diff --git" header is skipped by patch tools).

torchdrive/models/path.py
  * Adds PathSigmoidMeters, an nn.Module that maps its input through
    (sigmoid(x) * 2 - 1) * MAX_POS, and appends it to the pos_decoder
    nn.Sequential. The equivalent inline scaling is removed from forward().
  * Adds a `final_jitter` constructor argument (default 5.0) stored on self.
    forward() now perturbs `final_pos` with uniform noise scaled by
    self.final_jitter only when self.training is set; the old code always
    jittered positions[:, :, -1] with a hard-coded magnitude of 5.

torchdrive/tasks/path.py
  * Imports rel_dists and adds a masked Huber loss (delta=20.0) between
    rel_dists(predicted) and rel_dists(target), stored as "rel_dists/{i}".
  * Splits the logged figure in two, "paths/predicted" and
    "paths/autoregressive"; both also plot the origin and final points.
    The autoregressive curve is plotted from index 1 onward.

torchdrive/tasks/test_path.py
  * Expects the three new "rel_dists/{0,1,2}" loss keys.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Leading indentation and blank context lines were reconstructed so that every
hunk matches its "@@" line counts; confirm against the target tree, or apply
with whitespace tolerance, before relying on it.

diff --git a/torchdrive/models/path.py b/torchdrive/models/path.py
index 69a82d4..0dcdce5 100644
--- a/torchdrive/models/path.py
+++ b/torchdrive/models/path.py
@@ -13,6 +13,12 @@
 MAX_POS = 100  # meters from origin
 
 
+class PathSigmoidMeters(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # convert to bounded x/y/z coords
+        return (x.float().sigmoid() * 2 - 1) * MAX_POS
+
+
 def rel_dists(series: torch.Tensor) -> torch.Tensor:
     """
     rel_dists returns the distances between each point in the series.
@@ -39,11 +45,13 @@ def __init__(
         num_heads: int = 8,
         num_layers: int = 3,
         pos_dim: int = 3,
+        final_jitter: float = 5.0,
         compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
     ) -> None:
         super().__init__()
 
         self.dim = dim
+        self.final_jitter = final_jitter
 
         self.bev_encoder = nn.Conv2d(bev_dim, dim, 1)
 
@@ -61,6 +69,7 @@ def __init__(
                 nn.Linear(dim, dim),
                 nn.ReLU(inplace=True),
                 nn.Linear(dim, pos_dim),
+                PathSigmoidMeters(),
             )
         )
 
@@ -94,14 +103,13 @@ def forward(
 
         # feed it the target end position with random jitter added to avoid
         # overfitting
-        end_jitter = 5
-        end_position = positions[:, :, -1]
-        end_jitter = (torch.rand_like(end_position) * 2 - 1) * end_jitter
-        end_position = end_position + end_jitter
+        if self.training:
+            final_jitter = (torch.rand_like(final_pos) * 2 - 1) * self.final_jitter
+            final_pos = final_pos + final_jitter
 
         with autocast():
             static_feats = torch.cat(
-                (speed, start_position, end_position), dim=1
+                (speed, start_position, final_pos), dim=1
             ).unsqueeze(-1)
             static = self.static_encoder(static_feats).permute(0, 2, 1)
 
@@ -115,10 +123,9 @@ def forward(
 
             out_positions = self.transformer(position_emb, cross_feats)
 
-        pred_pos = self.pos_decoder(out_positions.float())
-        pred_pos = pred_pos.permute(0, 2, 1)  # [bs, 3, n]
-        # convert to bounded x/y/z coords
-        pred_pos = (pred_pos.sigmoid() * 2 - 1) * MAX_POS
+        pred_pos = self.pos_decoder(out_positions.float()).permute(
+            0, 2, 1
+        )  # [bs, 3, n]
 
         return pred_pos, ae_pos
 
diff --git a/torchdrive/tasks/path.py b/torchdrive/tasks/path.py
index 2fc1926..701f907 100644
--- a/torchdrive/tasks/path.py
+++ b/torchdrive/tasks/path.py
@@ -8,7 +8,7 @@
 from torch import nn
 
 from torchdrive.data import Batch
-from torchdrive.models.path import PathTransformer
+from torchdrive.models.path import PathTransformer, rel_dists
 from torchdrive.tasks.bev import BEVTask, Context
 
 
@@ -108,6 +108,14 @@ def forward(
                 per_token_loss.sum(dim=(1, 2)) * 5 / (num_elements + 1)
             )
 
+            pred_dists = rel_dists(predicted)
+            target_dists = rel_dists(target)
+            rel_dist_loss = F.huber_loss(
+                pred_dists, target_dists, reduction="none", delta=20.0
+            )
+            rel_dist_loss *= mask
+            losses[f"rel_dists/{i}"] = rel_dist_loss.sum(dim=1) / (num_elements + 1)
+
             # keep first values the same and shift predicted over by 1
             prev = torch.cat((prev[..., :1], predicted[..., :-1]), dim=-1)
 
@@ -119,6 +127,8 @@ def forward(
             fig = plt.figure()
             length = lengths[0] - 1
             plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
+            plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
+            plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")
 
             for i, predicted in enumerate(all_predicted):
                 if i % max(1, self.num_ar_iters // 4) != 0:
@@ -128,6 +138,11 @@ def forward(
                     label=f"predicted {i}",
                 )
 
+            fig.legend()
+            plt.gca().set_aspect("equal")
+            ctx.add_figure("paths/predicted", fig)
+
+            fig = plt.figure()
             # autoregressive
             self.eval()
             autoregressive = PathTransformer.infer(
@@ -137,11 +152,16 @@ def forward(
                 final_pos[:1],
                 n=length - 2,
             )
-            plt.plot(*autoregressive[0, 0:2].detach().cpu(), label="autoregressive")
+            plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
+            plt.plot(
+                *autoregressive[0, 0:2, 1:].detach().cpu(), label="autoregressive"
+            )
+            plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
+            plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")
             self.train()
             fig.legend()
             plt.gca().set_aspect("equal")
-            ctx.add_figure("paths", fig)
+            ctx.add_figure("paths/autoregressive", fig)
 
         return losses
 
diff --git a/torchdrive/tasks/test_path.py b/torchdrive/tasks/test_path.py
index 170eb4a..fe2a421 100644
--- a/torchdrive/tasks/test_path.py
+++ b/torchdrive/tasks/test_path.py
@@ -42,6 +42,9 @@ def test_path_task(self) -> None:
                 "ae/0",
                 "ae/1",
                 "ae/2",
+                "rel_dists/0",
+                "rel_dists/1",
+                "rel_dists/2",
             ],
         )
         self.assertEqual(losses["position/0"].shape, (2,))