Skip to content

Commit

Permalink
Expose hidden block arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
DimitrisMantas committed Oct 22, 2024
1 parent 52e78d7 commit 9bed661
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
43 changes: 34 additions & 9 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

from collections.abc import Iterable, Sequence
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F
Expand All @@ -38,9 +41,22 @@


class DeepLabV3Decoder(nn.Sequential):
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
def __init__(
self,
in_channels: int,
out_channels: int,
atrous_rates: Iterable[int],
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__(
ASPP(in_channels, out_channels, atrous_rates),
ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
Expand All @@ -54,10 +70,12 @@ def forward(self, *features):
class DeepLabV3PlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
out_channels=256,
atrous_rates=(12, 24, 36),
output_stride=16,
encoder_channels: Sequence[int, ...],
out_channels: int,
atrous_rates: Iterable[int],
output_stride: Literal[8, 16],
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__()
if output_stride not in {8, 16}:
Expand All @@ -69,7 +87,13 @@ def __init__(
self.output_stride = output_stride

self.aspp = nn.Sequential(
ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
ASPP(
encoder_channels[-1],
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
SeparableConv2d(
out_channels, out_channels, kernel_size=3, padding=1, bias=False
),
Expand Down Expand Up @@ -164,7 +188,8 @@ def __init__(
in_channels: int,
out_channels: int,
atrous_rates: Iterable[int],
separable: bool=False,
separable: bool,
dropout: float,
):
super(ASPP, self).__init__()
modules = [
Expand All @@ -189,7 +214,7 @@ def __init__(
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5),
nn.Dropout(dropout),
)

def forward(self, x):
Expand Down
30 changes: 23 additions & 7 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ class DeepLabV3(SegmentationModel):
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_channels: A number of convolution filters in ASPP module. Default is 256
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False
decoder_aspp_dropout: Dropout probability used in the ASPP module projection layer. Default is 0.5
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). Default is **None**, in which case ``encoder_output_stride`` is used.
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
Expand All @@ -51,11 +55,15 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: Literal[8, 16] = 8,
decoder_channels: int = 256,
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
decoder_aspp_separable: bool = False,
decoder_aspp_dropout: float = 0.5,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
upsampling: int = 8,
upsampling: Optional[int] = None,
aux_params: Optional[dict] = None,
):
super().__init__()
Expand All @@ -65,19 +73,23 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
output_stride=encoder_output_stride,
)

self.decoder = DeepLabV3Decoder(
in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels
in_channels=self.encoder.out_channels[-1],
out_channels=decoder_channels,
atrous_rates=decoder_atrous_rates,
aspp_separable=decoder_aspp_separable,
aspp_dropout=decoder_aspp_dropout,
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
upsampling=upsampling,
upsampling=encoder_output_stride if upsampling is None else upsampling,
)

if aux_params is not None:
Expand All @@ -102,8 +114,9 @@ class DeepLabV3Plus(SegmentationModel):
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True
decoder_aspp_dropout: Dropout probability used in the ASPP module projection layer. Default is 0.5
decoder_channels: A number of convolution filters in ASPP module. Default is 256
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
Expand Down Expand Up @@ -134,8 +147,9 @@ def __init__(
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: Literal[8, 16] = 16,
decoder_channels: int = 256,
decoder_atrous_rates: tuple = (12, 24, 36),
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
decoder_aspp_separable: bool = True,
decoder_aspp_dropout: float = 0.5,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
Expand All @@ -157,6 +171,8 @@ def __init__(
out_channels=decoder_channels,
atrous_rates=decoder_atrous_rates,
output_stride=encoder_output_stride,
aspp_separable=decoder_aspp_separable,
aspp_dropout=decoder_aspp_dropout,
)

self.segmentation_head = SegmentationHead(
Expand Down

0 comments on commit 9bed661

Please sign in to comment.