Skip to content

Commit

Permalink
torchdrive/path: add relative distance regularization + fixed final_p…
Browse files Browse the repository at this point in the history
…os logic for autoregressive
  • Loading branch information
d4l3k committed Oct 29, 2023
1 parent 62c5222 commit 5f087fb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 12 deletions.
25 changes: 16 additions & 9 deletions torchdrive/models/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
MAX_POS = 100 # meters from origin


class PathSigmoidMeters(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# convert to bounded x/y/z coords
return (x.float().sigmoid() * 2 - 1) * MAX_POS


def rel_dists(series: torch.Tensor) -> torch.Tensor:
"""
rel_dists returns the distances between each point in the series.
Expand All @@ -39,11 +45,13 @@ def __init__(
num_heads: int = 8,
num_layers: int = 3,
pos_dim: int = 3,
final_jitter: float = 5.0,
compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
) -> None:
super().__init__()

self.dim = dim
self.final_jitter = final_jitter

self.bev_encoder = nn.Conv2d(bev_dim, dim, 1)

Expand All @@ -61,6 +69,7 @@ def __init__(
nn.Linear(dim, dim),
nn.ReLU(inplace=True),
nn.Linear(dim, pos_dim),
PathSigmoidMeters(),
)
)

Expand Down Expand Up @@ -94,14 +103,13 @@ def forward(

# feed it the target end position with random jitter added to avoid
# overfitting
end_jitter = 5
end_position = positions[:, :, -1]
end_jitter = (torch.rand_like(end_position) * 2 - 1) * end_jitter
end_position = end_position + end_jitter
if self.training:
final_jitter = (torch.rand_like(final_pos) * 2 - 1) * self.final_jitter
final_pos = final_pos + final_jitter

with autocast():
static_feats = torch.cat(
(speed, start_position, end_position), dim=1
(speed, start_position, final_pos), dim=1
).unsqueeze(-1)
static = self.static_encoder(static_feats).permute(0, 2, 1)

Expand All @@ -115,10 +123,9 @@ def forward(

out_positions = self.transformer(position_emb, cross_feats)

pred_pos = self.pos_decoder(out_positions.float())
pred_pos = pred_pos.permute(0, 2, 1) # [bs, 3, n]
# convert to bounded x/y/z coords
pred_pos = (pred_pos.sigmoid() * 2 - 1) * MAX_POS
pred_pos = self.pos_decoder(out_positions.float()).permute(
0, 2, 1
) # [bs, 3, n]

return pred_pos, ae_pos

Expand Down
26 changes: 23 additions & 3 deletions torchdrive/tasks/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import nn

from torchdrive.data import Batch
from torchdrive.models.path import PathTransformer
from torchdrive.models.path import PathTransformer, rel_dists
from torchdrive.tasks.bev import BEVTask, Context


Expand Down Expand Up @@ -108,6 +108,14 @@ def forward(
per_token_loss.sum(dim=(1, 2)) * 5 / (num_elements + 1)
)

pred_dists = rel_dists(predicted)
target_dists = rel_dists(target)
rel_dist_loss = F.huber_loss(
pred_dists, target_dists, reduction="none", delta=20.0
)
rel_dist_loss *= mask
losses[f"rel_dists/{i}"] = rel_dist_loss.sum(dim=1) / (num_elements + 1)

# keep first values the same and shift predicted over by 1
prev = torch.cat((prev[..., :1], predicted[..., :-1]), dim=-1)

Expand All @@ -119,6 +127,8 @@ def forward(
fig = plt.figure()
length = lengths[0] - 1
plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")

for i, predicted in enumerate(all_predicted):
if i % max(1, self.num_ar_iters // 4) != 0:
Expand All @@ -128,6 +138,11 @@ def forward(
label=f"predicted {i}",
)

fig.legend()
plt.gca().set_aspect("equal")
ctx.add_figure("paths/predicted", fig)

fig = plt.figure()
# autoregressive
self.eval()
autoregressive = PathTransformer.infer(
Expand All @@ -137,11 +152,16 @@ def forward(
final_pos[:1],
n=length - 2,
)
plt.plot(*autoregressive[0, 0:2].detach().cpu(), label="autoregressive")
plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
plt.plot(
*autoregressive[0, 0:2, 1:].detach().cpu(), label="autoregressive"
)
plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")
self.train()

fig.legend()
plt.gca().set_aspect("equal")
ctx.add_figure("paths", fig)
ctx.add_figure("paths/autoregressive", fig)

return losses
3 changes: 3 additions & 0 deletions torchdrive/tasks/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def test_path_task(self) -> None:
"ae/0",
"ae/1",
"ae/2",
"rel_dists/0",
"rel_dists/1",
"rel_dists/2",
],
)
self.assertEqual(losses["position/0"].shape, (2,))

0 comments on commit 5f087fb

Please sign in to comment.