diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index a5afd20b..caeb95d1 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -30,6 +30,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +from collections.abc import Iterable, Sequence +from typing import Literal + import torch from torch import nn from torch.nn import functional as F @@ -38,9 +41,22 @@ class DeepLabV3Decoder(nn.Sequential): - def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)): + def __init__( + self, + in_channels: int, + out_channels: int, + atrous_rates: Iterable[int], + aspp_separable: bool, + aspp_dropout: float, + ): super().__init__( - ASPP(in_channels, out_channels, atrous_rates), + ASPP( + in_channels, + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, + ), nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), @@ -54,10 +70,12 @@ def forward(self, *features): class DeepLabV3PlusDecoder(nn.Module): def __init__( self, - encoder_channels, - out_channels=256, - atrous_rates=(12, 24, 36), - output_stride=16, + encoder_channels: Sequence[int, ...], + out_channels: int, + atrous_rates: Iterable[int], + output_stride: Literal[8, 16], + aspp_separable: bool, + aspp_dropout: float, ): super().__init__() if output_stride not in {8, 16}: @@ -69,7 +87,13 @@ def __init__( self.output_stride = output_stride self.aspp = nn.Sequential( - ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True), + ASPP( + encoder_channels[-1], + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, + ), SeparableConv2d( out_channels, out_channels, kernel_size=3, padding=1, bias=False ), @@ -164,7 +188,8 @@ def __init__( in_channels: int, out_channels: int, atrous_rates: 
Iterable[int], - separable: bool=False, + separable: bool, + dropout: float, ): super(ASPP, self).__init__() modules = [ @@ -189,7 +214,7 @@ def __init__( nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), - nn.Dropout(0.5), + nn.Dropout(dropout), ) def forward(self, x): diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 3b3d64f0..8ad8d714 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -24,13 +24,17 @@ class DeepLabV3(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) decoder_channels: A number of convolution filters in ASPP module. Default is 256 + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False + decoder_aspp_dropout: Dropout probability applied in the ASPP module projection layer. Default is 0.5 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity + upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). Default is **None**, in which case ``encoder_output_stride`` is used 
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -51,11 +55,15 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", + encoder_output_stride: Literal[8, 16] = 8, decoder_channels: int = 256, + decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = False, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, - upsampling: int = 8, + upsampling: Optional[int] = None, aux_params: Optional[dict] = None, ): super().__init__() @@ -65,11 +73,15 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, - output_stride=8, + output_stride=encoder_output_stride, ) self.decoder = DeepLabV3Decoder( - in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels + in_channels=self.encoder.out_channels[-1], + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead( @@ -77,7 +89,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=1, - upsampling=upsampling, + upsampling=encoder_output_stride if upsampling is None else upsampling, ) if aux_params is not None: @@ -102,8 +114,9 @@ class DeepLabV3Plus(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) - decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) decoder_atrous_rates: Dilation rates for ASPP module (should 
be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True + decoder_aspp_dropout: Dropout probability applied in the ASPP module projection layer. Default is 0.5 decoder_channels: A number of convolution filters in ASPP module. Default is 256 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -134,8 +147,9 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_output_stride: Literal[8, 16] = 16, decoder_channels: int = 256, - decoder_atrous_rates: tuple = (12, 24, 36), decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = True, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, @@ -157,6 +171,8 @@ def __init__( out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead(