torchworld/transforms/simplebev: added lifting operation
d4l3k committed Oct 27, 2023
1 parent 5d3a7da commit 62c5222
Showing 5 changed files with 135 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torchworld/models/test_fpn.py
@@ -13,7 +13,7 @@ class TestFPN(unittest.TestCase):
def test_resnet18_fpn_3d(self) -> None:
grid = Grid3d(
data=torch.rand(2, 3, 8, 16, 24),
-        transform=Transform3d(),
+        local_to_world=Transform3d(),
time=torch.rand(2),
)
m = Resnet18FPN3d(in_channels=3)
26 changes: 20 additions & 6 deletions torchworld/structures/grid.py
@@ -29,9 +29,17 @@ def cpu(self) -> T:
def device(self) -> torch.device:
return self.data.device

+    @property
+    def dtype(self) -> torch.dtype:
+        return self.data.dtype

def __len__(self) -> int:
return len(self.data)

+    @abstractmethod
+    def grid_shape(self) -> Tuple[int, ...]:
+        ...


@dataclass
class Grid3d(BaseGrid):
@@ -45,15 +53,15 @@ class Grid3d(BaseGrid):
----------
data: [bs, channels, z, y, x]
The grid of features.
-    transform:
+    local_to_world:
        The 3d transform that locates the Grid3d in space,
        mapping voxel coordinates (-1 to 1) to world space.
time: scalar or [bs]
Time corresponding to the grid
"""

data: torch.Tensor
-    transform: Transform3d
+    local_to_world: Transform3d
time: torch.Tensor

def __post_init__(self) -> None:
@@ -64,11 +72,11 @@ def __post_init__(self) -> None:
f"time must be scalar or 1-dimensional, got {self.time.shape}"
)

-        T = self.transform.get_matrix()
+        T = self.local_to_world.get_matrix()
if (BS := T.size(0)) != 1:
if BS != self.data.size(0):
raise TypeError(
f"data and transform batch sizes don't match: {T.shape, self.data.shape}"
f"data and local_to_world batch sizes don't match: {T.shape, self.data.shape}"
)

@classmethod
@@ -108,17 +116,20 @@ def from_volume(
time = torch.tensor(time, dtype=torch.float, device=device)
return cls(
data=data,
-            transform=locator.get_local_to_world_coords_transform(),
+            local_to_world=locator.get_local_to_world_coords_transform(),
time=time,
)

def to(self, target: Union[torch.device, str]) -> "Grid3d":
return Grid3d(
data=self.data.to(target),
-            transform=self.transform.to(target),
+            local_to_world=self.local_to_world.to(target),
time=self.time.to(target),
)

+    def grid_shape(self) -> Tuple[int, int, int]:
+        return self.data.shape[2:5]


@dataclass
class GridImage(BaseGrid):
@@ -160,3 +171,6 @@ def to(self, target: Union[torch.device, str]) -> "GridImage":
camera=self.camera.to(target),
time=self.time.to(target),
)

+    def grid_shape(self) -> Tuple[int, int]:
+        return self.data.shape[2:4]
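
As a quick illustration of the new accessors (mirroring the updated tests below), here is a minimal sketch assuming only what this diff adds to Grid3d:

import torch
from pytorch3d.transforms import Transform3d

from torchworld.structures.grid import Grid3d

# a small grid: [bs=2, channels=3, z=4, y=5, x=6]
grid = Grid3d(
    data=torch.rand(2, 3, 4, 5, 6),
    local_to_world=Transform3d(),
    time=torch.rand(2),
)
assert len(grid) == 2                  # batch size
assert grid.dtype == torch.float       # new dtype property
assert grid.grid_shape() == (4, 5, 6)  # spatial dims only, data.shape[2:5]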
5 changes: 4 additions & 1 deletion torchworld/structures/test_grid.py
@@ -11,7 +11,7 @@ class TestGrid(unittest.TestCase):
def test_grid_3d(self) -> None:
grid = Grid3d(
data=torch.rand(2, 3, 4, 5, 6),
-            transform=Transform3d(),
+            local_to_world=Transform3d(),
time=torch.rand(2),
)

@@ -20,6 +20,8 @@ def test_grid_3d(self) -> None:

self.assertEqual(len(grid), 2)
self.assertEqual(grid.device, grid.data.device)
+        self.assertEqual(grid.dtype, torch.float)
+        self.assertEqual(grid.grid_shape(), (4, 5, 6))

def test_grid_image(self) -> None:
grid = GridImage(
@@ -30,6 +32,7 @@ def test_grid_image(self) -> None:

grid = grid.to("cpu")
grid = grid.cpu()
+        self.assertEqual(grid.grid_shape(), (4, 5))

def test_grid_3d_from_volume(self) -> None:
grid = Grid3d.from_volume(
80 changes: 80 additions & 0 deletions torchworld/transforms/simplebev.py
@@ -0,0 +1,80 @@
from dataclasses import replace
from typing import Tuple

import torch
import torch.nn.functional as F

from torchworld.structures.grid import Grid3d, GridImage


def lift_image_to_3d(
src: GridImage,
dst: Grid3d,
eps: float = 1e-7,
) -> Tuple[Grid3d, Grid3d]:
"""
Lift the features from a camera to a Voxel volume.
Implements Simple BEV lifting operation.
See: https://arxiv.org/pdf/2206.07959.pdf
Arguments
---------
src: Source features and camera.
dst: Destination 3D grid.
eps: A small value to avoid NaNs.
Returns
-------
features: grid with features
mask: grid of the mask where the camera could see
"""
    if dst.data.numel() != 0:
        raise TypeError(f"dst must be an empty template grid (batch size 0), got {dst.data.shape}")

device = src.device
BS = len(src)
grid_shape = dst.grid_shape()

# calculate the x/y/z coordinates for each voxel in the grid
channels = torch.meshgrid(
*(torch.arange(-1, 1 - eps, 2 / dim, device=device) for dim in grid_shape),
indexing="ij",
)
grid_points = torch.stack(channels, dim=-1)
grid_points = grid_points.flatten(0, -2).unsqueeze(0)
T = dst.local_to_world
T = T.compose(src.camera.get_full_projection_transform())
assert src.camera.in_ndc(), "TODO support non-ndc cameras"

image_points = T.transform_points(grid_points, eps=eps)

# hide samples behind camera
z = image_points[..., 2]

valid = z > 0
valid = valid.unflatten(1, grid_shape).unsqueeze(1)

# drop z axis
image_points = image_points[..., :2]
# grid_sample needs a 2d input so we add a dummy dimension
image_points = image_points.unsqueeze(1)

# make batch size match
if len(image_points) == 1:
image_points = image_points.expand(BS, -1, -1, -1)
valid = valid.expand(BS, -1, -1, -1, -1)

# grid_sample doesn't support bfloat16 so cast to float
values = F.grid_sample(src.data.float(), image_points, align_corners=False)
values = values.to(src.data.dtype)

# restore to grid shape
values = values.squeeze(2).unflatten(-1, grid_shape)
values *= valid

return (
replace(dst, data=values, time=src.time),
replace(dst, data=valid, time=src.time),
)
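
lift_image_to_3d handles a single camera; in Simple BEV, features from several cameras are typically fused by averaging over the voxels each camera can actually see. A minimal sketch of such a reduction follows; lift_and_fuse is a hypothetical helper, not part of this commit:

from typing import Sequence

import torch

from torchworld.structures.grid import Grid3d, GridImage
from torchworld.transforms.simplebev import lift_image_to_3d


def lift_and_fuse(srcs: Sequence[GridImage], dst: Grid3d) -> torch.Tensor:
    # hypothetical: visibility-weighted mean of lifted features over cameras
    feat_sum = None
    count = None
    for src in srcs:
        feats, mask = lift_image_to_3d(src, dst)
        f = feats.data  # [bs, c, z, y, x], already zeroed where not visible
        m = mask.data.to(f.dtype)  # [bs, 1, z, y, x]
        feat_sum = f if feat_sum is None else feat_sum + f
        count = m if count is None else count + m
    # clamp avoids division by zero for voxels no camera sees
    return feat_sum / count.clamp(min=1)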
30 changes: 30 additions & 0 deletions torchworld/transforms/test_simplebev.py
@@ -0,0 +1,30 @@
import unittest

import torch
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms import Transform3d

from torchworld.structures.grid import Grid3d, GridImage
from torchworld.transforms.simplebev import lift_image_to_3d


class TestSimpleBEV(unittest.TestCase):
def test_lift_image_to_3d(self) -> None:
device = torch.device("cpu")
dtype = torch.half
dst = Grid3d(
data=torch.rand(0, 3, 4, 5, 6, device=device, dtype=dtype),
local_to_world=Transform3d(device=device),
time=torch.rand(2, device=device),
)
src = GridImage(
data=torch.rand(2, 3, 4, 5, device=device, dtype=dtype),
camera=PerspectiveCameras(device=device),
time=torch.rand(2, device=device),
)

out, mask = lift_image_to_3d(src, dst)
self.assertEqual(out.data.shape, (2, 3, 4, 5, 6))
self.assertEqual(mask.data.shape, (2, 1, 4, 5, 6))

