Skip to content

Commit

Permalink
torchworld/sfm: added export test
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 31, 2023
1 parent 1227882 commit 377905f
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions torchworld/transforms/test_sfm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,58 @@
import unittest

import torch
from torch import nn
from torch.export import export

from torchworld.structures.cameras import PerspectiveCameras
from torchworld.structures.grid import GridImage
from torchworld.transforms.sfm import project


class MyModel(nn.Module):
def __init__(self) -> None:
super().__init__()

self.device = torch.device("cpu")
self.dtype: torch.dtype = torch.float32

self.camera = PerspectiveCameras(device=self.device)
self.mask: torch.Tensor = torch.ones(
2, 1, 4, 6, device=self.device, dtype=self.dtype
)

def forward(self, data: torch.Tensor) -> torch.Tensor:
src = GridImage(
data=data,
camera=self.camera,
time=torch.rand(2, device=self.device),
mask=self.mask,
)
depth = GridImage(
data=torch.ones(2, 1, 4, 6, device=self.device, dtype=self.dtype),
camera=self.camera,
time=torch.rand(2, device=self.device),
mask=self.mask,
)
vel = GridImage(
data=torch.zeros(2, 3, 4, 6, device=self.device, dtype=self.dtype),
camera=self.camera,
time=torch.rand(2, device=self.device),
mask=self.mask,
)
return project(dst=src, src=src, depth=depth, vel=vel).data


class TestSFM(unittest.TestCase):
def test_export(self) -> None:
data = torch.ones(2, 3, 4, 6)
model = MyModel()
model(data)
exported = export(model, args=(data,))
self.assertIsNotNone(exported)
print(exported)
self.fail()

def test_project(self) -> None:
device = torch.device("cpu")
dtype = torch.float32
Expand Down

0 comments on commit 377905f

Please sign in to comment.