diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index b36d3b40..092de36a 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -36,10 +36,10 @@ def __init__( ) def forward(self, x): - _, _, height, weight = x.shape + _, _, height, width = x.shape out = [x] + [ F.interpolate( - block(x), size=(height, weight), mode="bilinear", align_corners=False + block(x), size=(height, width), mode="bilinear", align_corners=False ) for block in self.blocks ] @@ -62,10 +62,8 @@ def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True): ) def forward(self, x, skip): - _, channels, height, weight = skip.shape - x = F.interpolate( - x, size=(height, weight), mode="bilinear", align_corners=False - ) + _, channels, height, width = skip.shape + x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False) if channels != 0: skip = self.skip_conv(skip) x = x + skip