Fixed shape error, allowing arbitrary image sizes for EfficientAD (#1537)
* Fixed shape error, allowing arbitrary image sizes. Replaced int() casts with floor division

* Replaced the calculation with a ceil operation. The fix for the shape error is to round up, not down, for the last upsample layer (illustrated in the sketch after this list)

* Add comment for ceil operation

* Formatting with pre-commit hook
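
As a rough illustration of the fix (not part of the commit), the sketch below uses a hypothetical 250x250 input, which is not a multiple of 4, to show how truncating with int() makes the decoder's last upsample one pixel smaller than ceil-based rounding; per the commit, that size has to round up to match the PDN output shape:

import math

img_size = (250, 250)  # hypothetical input size that is not divisible by 4
padding = False

# Old behaviour: int() truncates 250 / 4 = 62.5 down to 62, giving (54, 54).
old_last_upsample = tuple(
    int(s / 4) if padding else int(s / 4) - 8 for s in img_size
)

# New behaviour: math.ceil rounds 62.5 up to 63, giving (55, 55),
# which, per this commit, matches the PDN feature map and avoids the shape error.
new_last_upsample = tuple(
    math.ceil(s / 4) if padding else math.ceil(s / 4) - 8 for s in img_size
)

print(old_last_upsample)  # (54, 54)
print(new_last_upsample)  # (55, 55)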
holzweber authored Jan 23, 2024
1 parent 2e09314 commit 9effa29
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/anomalib/models/efficient_ad/torch_model.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import logging
+import math
 import random
 from enum import Enum
@@ -147,9 +148,10 @@ class Decoder(nn.Module):
     def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.img_size = img_size
+        # use ceil to match output shape of PDN
         self.last_upsample = (
-            int(img_size[0] / 4) if padding else int(img_size[0] / 4) - 8,
-            int(img_size[1] / 4) if padding else int(img_size[1] / 4) - 8,
+            math.ceil(img_size[0] / 4) if padding else math.ceil(img_size[0] / 4) - 8,
+            math.ceil(img_size[1] / 4) if padding else math.ceil(img_size[1] / 4) - 8,
         )
         self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
         self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
@@ -167,22 +169,22 @@ def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
         self.dropout6 = nn.Dropout(p=0.2)
 
     def forward(self, x):
-        x = F.interpolate(x, size=(int(self.img_size[0] / 64) - 1, int(self.img_size[1] / 64) - 1), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 64 - 1, self.img_size[1] // 64 - 1), mode="bilinear")
         x = F.relu(self.deconv1(x))
         x = self.dropout1(x)
-        x = F.interpolate(x, size=(int(self.img_size[0] / 32), int(self.img_size[1] / 32)), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 32, self.img_size[1] // 32), mode="bilinear")
         x = F.relu(self.deconv2(x))
         x = self.dropout2(x)
-        x = F.interpolate(x, size=(int(self.img_size[0] / 16) - 1, int(self.img_size[1] / 16) - 1), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 16 - 1, self.img_size[1] // 16 - 1), mode="bilinear")
         x = F.relu(self.deconv3(x))
         x = self.dropout3(x)
-        x = F.interpolate(x, size=(int(self.img_size[0] / 8), int(self.img_size[1] / 8)), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 8, self.img_size[1] // 8), mode="bilinear")
         x = F.relu(self.deconv4(x))
         x = self.dropout4(x)
-        x = F.interpolate(x, size=(int(self.img_size[0] / 4) - 1, int(self.img_size[1] / 4) - 1), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 4 - 1, self.img_size[1] // 4 - 1), mode="bilinear")
         x = F.relu(self.deconv5(x))
         x = self.dropout5(x)
-        x = F.interpolate(x, size=(int(self.img_size[0] / 2) - 1, int(self.img_size[1] / 2) - 1), mode="bilinear")
+        x = F.interpolate(x, size=(self.img_size[0] // 2 - 1, self.img_size[1] // 2 - 1), mode="bilinear")
         x = F.relu(self.deconv6(x))
         x = self.dropout6(x)
         x = F.interpolate(x, size=self.last_upsample, mode="bilinear")
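A side note on the int(self.img_size[0] / n) to self.img_size[0] // n rewrites in forward() above (an arithmetic check, not part of the commit): for the non-negative sizes used here, truncation and floor division agree, so the intermediate interpolation sizes are unchanged and only the final upsample size is affected by the ceil fix.

# Sanity check that int(s / n) == s // n for non-negative image sizes,
# i.e. the forward() rewrite is purely cosmetic for these inputs.
for s in (224, 250, 255, 256):
    for n in (2, 4, 8, 16, 32, 64):
        assert int(s / n) == s // n
print("intermediate interpolate sizes are identical before and after the change")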
