
MaskRCNN Wrapper #624

Open
nkhlS141 opened this issue Oct 5, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

@nkhlS141

nkhlS141 commented Oct 5, 2023


I am looking for the Wrapper class used in the snippet below. I have trained a Mask R-CNN model:

import os
import torch

# Load the TorchScript model exported by d2go, wrap it, and re-script the wrapper
orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
wrapped_model = Wrapper(orig_model)
scripted_model = torch.jit.script(wrapped_model)
scripted_model.save("d2go.pt")

I found the class below, but it seems to be for Faster R-CNN models:

from typing import Dict, List

import torch


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Map contiguous class indices back to COCO category ids (single class here)
        coco_idx_list = [1]
        self.coco_idx = torch.tensor(coco_idx_list)

    def forward(self, inputs: List[torch.Tensor]):
        x = inputs[0].unsqueeze(0) * 255
        # Resize so the shorter image side becomes 320 pixels
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
        out = self.model(x[0])
        res: Dict[str, torch.Tensor] = {}
        # Undo the resize so boxes are in original image coordinates
        res["boxes"] = out[0] / scale
        res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
        res["scores"] = out[2]
        return inputs, [res]

Any idea?

@nkhlS141 nkhlS141 added the enhancement New feature or request label Oct 5, 2023
@wat3rBro
Contributor

wat3rBro commented Oct 5, 2023

The cfg is from a Mask R-CNN model, so out should also contain the segmentation masks; adding them to res should turn Wrapper into a Mask R-CNN wrapper.

@nkhlS141
Author

nkhlS141 commented Oct 5, 2023

So you mean this would work?

def forward(self, inputs: List[torch.Tensor]):
    x = inputs[0].unsqueeze(0) * 255
    scale = 320.0 / min(x.shape[-2], x.shape[-1])
    x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
    out = self.model(x)
    res: Dict[str, torch.Tensor] = {}
    res["boxes"] = out[0] / scale
    res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
    res["masks"] = out[2]
    res["scores"] = out[3]
    return inputs, [res]

@nkhlS141
Author

nkhlS141 commented Oct 5, 2023

So the above changes to the Wrapper class don't throw any errors and the "d2go.pt" file gets created, but when I try to open that file in Netron it throws an error.
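
A quick way to check whether the saved archive itself is valid, independent of Netron, is to reload it with torch.jit.load and run a dummy input. This is a minimal sketch; the 320x480 input size and the single-image list format are assumptions taken from the Wrapper above, not requirements of the exported model.

import torch

# Sanity check (sketch): reload the scripted archive and run a dummy input.
model = torch.jit.load("d2go.pt")
model.eval()
dummy = [torch.rand(3, 320, 480)]  # one CHW float image in [0, 1]; size is an assumption
with torch.no_grad():
    inputs, results = model(dummy)
print(results[0].keys())  # expect boxes, labels, masks, scores per the code above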

@rochist

rochist commented Nov 15, 2023

from typing import Tuple

import torch
from torch import Tensor


class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        x = inputs.unsqueeze(0) * 255
        scale = 320.0 / min(x.shape[-2], x.shape[-1])
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=True, recompute_scale_factor=True)
        out = self.model(x[0])
        return out[0] / scale, out[1], out[2], out[3]
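
For completeness, a sketch of how this variant could be exported and invoked, following the same flow as the first comment. Here predictor_path is a placeholder from that comment, the input size is arbitrary, and the (boxes, labels, masks, scores) output ordering is an assumption based on the discussion above.

import os
import torch

# Export flow (sketch), mirroring the first comment in this thread.
orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))  # predictor_path: placeholder
scripted = torch.jit.script(Wrapper(orig_model))
scripted.save("d2go.pt")

# The wrapper takes a single CHW float tensor and returns a 4-tuple;
# the (boxes, labels, masks, scores) ordering is an assumption.
image = torch.rand(3, 320, 480)
boxes, labels, masks, scores = scripted(image)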
