Skip to content

Commit

Permalink
torchdrive: overhaul x4 features
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Nov 3, 2023
1 parent 6394934 commit 976c15f
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 37 deletions.
29 changes: 22 additions & 7 deletions torchdrive/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,25 @@ def autograd_optional(tensor: T) -> Generator[T, None, None]:
yield tensor


def register_log_grad_norm(
t: torch.Tensor,
writer: Optional[SummaryWriter],
key: str,
tag: str,
global_step: int,
) -> None:
if writer is None:
return
nonopt_writer: SummaryWriter = writer

def backward_hook(grad: torch.Tensor) -> None:
nonopt_writer.add_scalars(
key, {tag: torch.linalg.vector_norm(grad).float()}, global_step=global_step
)

t.register_hook(backward_hook)


def log_grad_norm(
t: torch.Tensor,
writer: Optional[SummaryWriter],
Expand All @@ -103,14 +122,10 @@ def log_grad_norm(
"""
if writer is None:
return t
nonopt_writer: SummaryWriter = writer
# soft clone without copying data
t = t.view_as(t)

def backward_hook(grad: torch.Tensor) -> None:
nonopt_writer.add_scalars(
key, {tag: torch.linalg.vector_norm(grad).float()}, global_step=global_step
)

t.register_hook(backward_hook)
register_log_grad_norm(
t=t, writer=writer, key=key, tag=tag, global_step=global_step
)
return t
7 changes: 5 additions & 2 deletions torchdrive/models/bev_backbone.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Mapping, Tuple
from typing import Dict, List, Mapping, Tuple, Union

import torch
from torch import nn
Expand All @@ -17,5 +17,8 @@ class BEVBackbone(nn.Module, ABC):
@abstractmethod
def forward(
self, camera_features: Mapping[str, List[torch.Tensor]], batch: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]],
]:
pass
2 changes: 1 addition & 1 deletion torchdrive/models/det.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(
self,
bev_shape: Tuple[int, int],
dim: int,
num_queries: int = 100,
num_queries: int,
num_heads: int = 8,
num_classes: int = 10,
num_layers: int = 6,
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample(x.float()) # upsample doesn't support bfloat16
x = self.upsample(x)
return self.decode(x)


Expand Down
37 changes: 29 additions & 8 deletions torchdrive/models/simple_bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pytorch3d.structures.volumes import VolumeLocator
from torch import nn
from torchvision import transforms
from torchvision.models import regnet
from torchvision.models.resnet import resnet18
from torchworld.models.resnet_3d import resnet3d18, Upsample3DBlock

Expand Down Expand Up @@ -227,7 +228,9 @@ def __init__(self, in_channels: int, final_channels: int) -> None:
final_channels // 4, in_channels, scale_factor=2
)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x: [BS, in_channels, H, W]
Expand All @@ -250,16 +253,19 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# (H/8, W/8)
x4 = self.layer3(x)

# x4_skip is used for logging grad_norm stats
x4_skip = x4.view_as(x4)

# First upsample to (H/4, W/4)
x = self.up3_skip(x4, skip_x["3"])
x = self.up3_skip(x4_skip, skip_x["3"])

# Second upsample to (H/2, W/2)
x = self.up2_skip(x, skip_x["2"])

# Third upsample to (H, W)
x = self.up1_skip(x, skip_x["1"])

return x, x4
return x, x4, x4_skip


# extends FPN to preserve state_dict keys
Expand Down Expand Up @@ -902,7 +908,6 @@ def __init__(
) -> None:
super().__init__()

assert dim == 256, "dim must equal intermediate"
self.dim = dim
self.grid_shape = grid_shape
self.num_frames = num_frames
Expand All @@ -923,14 +928,29 @@ def __init__(
per_voxel_dim = max(hr_dim // (Z * 2), 1)
assert num_upsamples == 1, "only one upsample supported"
self.upsample: nn.Module = compile_fn(Upsample3DBlock(cam_dim, per_voxel_dim))
self.coarse_project = nn.Conv2d(dim // HR_Z * HR_Z, dim, 1)
self.coarse_project: nn.Module = compile_fn(
nn.Sequential(
nn.Conv2d(dim // HR_Z * HR_Z, dim, 1),
regnet.AnyStage(
dim,
dim,
stride=1,
depth=4,
block_constructor=regnet.ResBottleneckBlock,
norm_layer=nn.BatchNorm2d,
activation_layer=nn.ReLU,
group_width=dim, # regnet_x_3_2gf
bottleneck_multiplier=1.0,
),
)
)

# pyre-fixme[6]: invalid parameter type
self.lift_cam_to_voxel_mean: nn.Module = compile_fn(lift_cam_to_voxel_mean)

def forward(
self, camera_features: Mapping[str, List[torch.Tensor]], batch: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
BS = batch.batch_size()
S = len(camera_features) * self.num_frames
device = batch.device()
Expand Down Expand Up @@ -972,12 +992,13 @@ def forward(
with autocast():
# run through FPN
x = feat_mem
x, x4 = self.fpn(x)
x, x4, x4_skip = self.fpn(x)
assert x.shape == feat_mem.shape

x = self.upsample(x)

x4_coarse = x4.view_as(x4)
x4 = x4.flatten(1, 2)
x4 = self.coarse_project(x4)

return x, x4
return x, x4, {"coarse": x4_coarse, "skip": x4_skip}
18 changes: 14 additions & 4 deletions torchdrive/models/test_simple_bev.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import torch
from parameterized import parameterized
from torchvision import models

from torchdrive.data import dummy_batch
Expand Down Expand Up @@ -69,9 +70,10 @@ def test_resnet_fpn_2d(self) -> None:

def test_resnet_fpn_3d(self) -> None:
m = ResnetFPN3d(3, 16)
x, x4 = m(torch.rand(2, 3, 8, 16, 24))
x, x4, x4_skip = m(torch.rand(2, 3, 8, 16, 24))
self.assertEqual(x.shape, (2, 3, 8, 16, 24))
self.assertEqual(x4.shape, (2, 16, 1, 2, 3))
self.assertEqual(x4_skip.shape, (2, 16, 1, 2, 3))

def test_segnet_backbone(self) -> None:
batch = dummy_batch()
Expand Down Expand Up @@ -143,14 +145,16 @@ def test_upsampling_add_3d(self) -> None:
)
self.assertEqual(out.shape, (2, 4, 10, 12, 14))

def test_segnet_3d_backbone(self) -> None:
# pyre-ignore[16]
@parameterized.expand([(128,), (256,)])
def test_segnet_3d_backbone(self, latent_dim: int) -> None:
batch = dummy_batch()
X = 8
Y = 16
Z = 24
HR_Z = Z // 8
cam_dim = 6
hr_dim = 5
latent_dim = 256
num_frames = 2
m = Segnet3DBackbone(
grid_shape=(X, Y, Z),
Expand All @@ -168,9 +172,15 @@ def test_segnet_3d_backbone(self) -> None:
for feats in camera_features.values():
for feat in feats:
feat.requires_grad = True
x, x4 = m(camera_features, batch)
x, x4, x4_intermediates = m(camera_features, batch)
self.assertEqual(x.shape, (batch.batch_size(), 1, Z * 2, X * 2, Y * 2))
self.assertEqual(x4.shape, (batch.batch_size(), latent_dim, X // 8, Y // 8))
for inter in x4_intermediates.values():
self.assertEqual(
inter.shape,
(batch.batch_size(), latent_dim // HR_Z, Z // 8, X // 8, Y // 8),
)

(x.mean() + x4.mean()).backward()

for feats in camera_features.values():
Expand Down
22 changes: 20 additions & 2 deletions torchdrive/tasks/bev.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from torchworld.transforms.img import render_color

from torchdrive.amp import autocast
from torchdrive.autograd import autograd_pause, autograd_resume, log_grad_norm
from torchdrive.autograd import (
autograd_pause,
autograd_resume,
log_grad_norm,
register_log_grad_norm,
)
from torchdrive.data import Batch
from torchdrive.models.bev_backbone import BEVBackbone
from torchdrive.tasks.context import Context
Expand Down Expand Up @@ -183,7 +188,20 @@ def forward(
), f"{len(cam_feats)} {self.num_encode_frames}"

with torch.autograd.profiler.record_function("backbone"):
hr_bev, bev = self.backbone(camera_feats, batch)
backbone_out = self.backbone(camera_feats, batch)

hr_bev, bev = backbone_out[:2]
if len(backbone_out) >= 3:
x4_intermediates = backbone_out[2]

for tag, x in x4_intermediates:
register_log_grad_norm(
t=x,
writer=writer,
key="grad/norm/backbone-x4",
tag=tag,
global_step=global_step,
)

if log_img and writer:
writer.add_image(
Expand Down
3 changes: 2 additions & 1 deletion torchdrive/tasks/det.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
bev_shape: Tuple[int, int],
dim: int,
device: torch.device,
num_queries: int,
compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
) -> None:
super().__init__()
Expand All @@ -55,6 +56,7 @@ def __init__(
self.cameras = cameras

decoder = DetBEVTransformerDecoder(
num_queries=num_queries,
bev_shape=bev_shape,
dim=dim,
)
Expand Down Expand Up @@ -105,7 +107,6 @@ def forward(
)

num_queries = classes_logits.shape[1]
assert num_queries == 100

if ctx.log_text:
ctx.add_scalars(
Expand Down
15 changes: 7 additions & 8 deletions torchdrive/tasks/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
bev_dim: int,
dim: int = 768,
num_heads: int = 16,
num_layers: int = 12,
num_layers: int = 6,
max_seq_len: int = 6 * 2,
num_ar_iters: int = 6,
compile_fn: Callable[[nn.Module], nn.Module] = lambda m: m,
Expand Down Expand Up @@ -84,8 +84,8 @@ def forward(
assert posmax < 1000

# target = positions[..., 1:]
prev = positions
target = positions
prev = positions[..., :-1]
target = positions[..., 1:]

all_predicted = []
losses = {}
Expand All @@ -102,7 +102,7 @@ def forward(
per_token_loss = F.huber_loss(
predicted, target, reduction="none", delta=20.0
)
per_token_loss *= mask.unsqueeze(1).expand(-1, 3, -1)
per_token_loss *= mask[..., 1:].unsqueeze(1).expand(-1, 3, -1)

# normalize by number of elements in sequence
losses[f"position/{i}"] = (
Expand All @@ -114,12 +114,11 @@ def forward(
rel_dist_loss = F.huber_loss(
pred_dists, target_dists, reduction="none", delta=20.0
)
rel_dist_loss *= mask
rel_dist_loss *= mask[..., 1:]
losses[f"rel_dists/{i}"] = rel_dist_loss.sum(dim=1) / (num_elements + 1)

# keep first values the same and shift predicted over by 1
prev = predicted
# prev = torch.cat((prev[..., :1], predicted[..., :-1]), dim=-1)
# keep first value the same and shift predicted over by 1
prev = torch.cat((prev[..., :1], predicted[..., :-1]), dim=-1)

if ctx.log_text:
ctx.add_scalar("ae/mae", self.ae_mae.compute())
Expand Down
1 change: 1 addition & 0 deletions torchdrive/tasks/test_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_det_task(self) -> None:
bev_shape=(4, 4),
dim=8,
device=torch.device("cpu"),
num_queries=10,
)
batch = dummy_batch()
ctx = Context(
Expand Down
7 changes: 5 additions & 2 deletions torchdrive/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,13 @@ def test_world_to_cam(self) -> None:
torch.testing.assert_allclose(out, target)

def test_cam_to_world(self) -> None:
torch.manual_seed(0)
torch.use_deterministic_algorithms(True, warn_only=True)

batch = dummy_batch()
cam = "left"
frame = 1
target = batch.T[cam].pinverse().matmul(batch.cam_T[:, frame]).pinverse()
target = batch.T[cam].inverse().matmul(batch.cam_T[:, frame]).inverse()
out = batch.cam_to_world(cam, frame)
self.assertEqual(out.shape, (2, 4, 4))
torch.testing.assert_allclose(out, target)
Expand All @@ -139,7 +142,7 @@ def test_lidar_to_world(self) -> None:
def test_lidar_points(self) -> None:
batch = dummy_batch()
out = batch.lidar_points()
self.assertEqual(out.data.shape, (2, 4, 6))
self.assertEqual(out.data.shape[:2], (2, 4))

def test_camera_names(self) -> None:
batch = dummy_batch()
Expand Down
4 changes: 4 additions & 0 deletions torchdrive/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class TrainConfig:
voxelsem: List[str]
path: bool

# det config
det_num_queries: int = 1000

start_offsets: Tuple[int, ...] = (0,)

def create_dataset(self, smoke: bool = False) -> Dataset:
Expand Down Expand Up @@ -208,6 +211,7 @@ def cam_encoder() -> RegNetEncoder:
dim=self.dim,
device=device,
compile_fn=compile_fn,
num_queries=self.det_num_queries,
)
if self.ae:
from torchdrive.tasks.ae import AETask
Expand Down
2 changes: 1 addition & 1 deletion torchdrive/transforms/test_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class TestDepth(unittest.TestCase):
def test_project(self) -> None:
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(True, warn_only=True)

backproject_depth = BackprojectDepth(
height=4,
Expand Down

0 comments on commit 976c15f

Please sign in to comment.