From 5ce577b5e0e1efbbc9b287d9513078ad78f113a6 Mon Sep 17 00:00:00 2001 From: seyyed hossein Date: Fri, 17 Mar 2023 20:48:20 +0330 Subject: [PATCH] refactor simplenet into stages and remove padding for conv1x1s the last two 1x1 convs now use no padding, this is done to make the architetcure in line with what timms standards. because of this change the pretrained weights are no more valid and this needs to be retrained. --- timm/models/simplenet.py | 401 ++++++++++++++++++++++++++------------- 1 file changed, 273 insertions(+), 128 deletions(-) diff --git a/timm/models/simplenet.py b/timm/models/simplenet.py index 36668c7046..be38e3589f 100644 --- a/timm/models/simplenet.py +++ b/timm/models/simplenet.py @@ -21,6 +21,10 @@ import torch.nn as nn import torch.nn.functional as F +from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_builder import efficientnet_init_weights +from ._manipulate import checkpoint_seq from ._builder import build_model_with_cfg from ._registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -47,6 +51,8 @@ def _cfg(url="", **kwargs): "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', + 'classifier': 'head.fc', **kwargs, } @@ -85,6 +91,40 @@ def forward(self, x): return x +class Downsample(nn.Module): + def __init__(self, pool='max', kernel_size=2, stride=2, dropout=0.0, inplace=True) -> None: + super().__init__() + self.pool = ( + nn.MaxPool2d(kernel_size=kernel_size, stride=stride) + if pool == 'max' + else nn.AvgPool2d(kernel_size=kernel_size, stride=stride) + ) + self.dropout = nn.Identity() if dropout is None else nn.Dropout2d(dropout, inplace=inplace) + + def forward(self, x): + x = self.pool(x) + x = self.dropout(x) + return x + # return View()(x) + + +class ConvBNReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, Dropout=0.0) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) + self.bn = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.05) + self.relu = nn.ReLU(True) + self.dropout = nn.Identity() if Dropout is None else nn.Dropout2d(Dropout) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.dropout(x) + return x + # return View()(x) + + class SimpleNet(nn.Module): def __init__( self, @@ -94,6 +134,7 @@ def __init__( network_idx: int = 0, mode: int = 2, drop_rates: Dict[int, float] = {}, + **kwargs, ): """Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture. It uses basic principles to maximize the network performance both in terms of feature representation and speed without @@ -116,48 +157,68 @@ def __init__( each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only). Defaults to {}. """ super(SimpleNet, self).__init__() + self.output_stride = 24 + self.grad_checkpointing = False # (channels or layer-type, stride=1, drp=0.) - self.cfg: Dict[str, List[Tuple[Union(int, str), int, Union(float, None), Optional[str]]]] = { - "simplenetv1_imagenet": [ - (64, 1, 0.0), - (128, 1, 0.0), - (128, 1, 0.0), - (128, 1, 0.0), - (128, 1, 0.0), - (128, 1, 0.0), - ("p", 2, 0.0), - (256, 1, 0.0), - (256, 1, 0.0), - (256, 1, 0.0), - (512, 1, 0.0), - ("p", 2, 0.0), - (2048, 1, 0.0, "k1"), - (256, 1, 0.0, "k1"), - (256, 1, 0.0), - ], - "simplenetv1_imagenet_9m": [ - (128, 1, 0.0), - (192, 1, 0.0), - (192, 1, 0.0), - (192, 1, 0.0), - (192, 1, 0.0), - (192, 1, 0.0), - ("p", 2, 0.0), - (320, 1, 0.0), - (320, 1, 0.0), - (320, 1, 0.0), - (640, 1, 0.0), - ("p", 2, 0.0), - (2560, 1, 0.0, "k1"), - (320, 1, 0.0, "k1"), - (320, 1, 0.0), - ], + self.cfg = { + "simplenetv1_imagenet": { + 'stem': [(64, 1, 0.0)], + 'stage_': [ + (128, 1, 0.0), + (128, 1, 0.0), + (128, 1, 0.0), + (128, 1, 0.0), + (128, 1, 0.0), + ("p", 2, 0.0), + (256, 1, 0.0), + (256, 1, 0.0), + (256, 1, 0.0), + (512, 1, 0.0), + ("p", 2, 0.0), + (2048, 1, 0.0, "k1"), + (256, 1, 0.0, "k1"), + (256, 1, 0.0), + ], + }, + "simplenetv1_imagenet_9m": { + 'stem': [(128, 1, 0.0)], + 'stage_': [ + (192, 1, 0.0), + (192, 1, 0.0), + (192, 1, 0.0), + (192, 1, 0.0), + (192, 1, 0.0), + ("p", 2, 0.0), + (320, 1, 0.0), + (320, 1, 0.0), + (320, 1, 0.0), + (640, 1, 0.0), + ("p", 2, 0.0), + (2560, 1, 0.0, "k1"), + (320, 1, 0.0, "k1"), + (320, 1, 0.0), + ], + }, } - self.dropout_rates = drop_rates + self.networks = [ + "simplenetv1_imagenet", # 0 + "simplenetv1_imagenet_9m", # 1 + # other archs + ] + self.num_classes = num_classes + self.in_chans = in_chans + self.scale = scale + self.network_idx = network_idx + self.mode = mode + self.selected_network = self.cfg[self.networks[self.network_idx]] + # making sure all values are in correct form + self.dropout_rates = {int(key): float(value) for key, value in drop_rates.items()} # 15 is the last layer of the network(including two previous pooling layers) # basically specifying the dropout rate for the very last layer to be used after the pooling - self.last_dropout_rate = self.dropout_rates.get(15, 0.0) + # but if we add or remove some layers later on, it will mess thing up, so lets do it dynamically + last_layer_idx = sum(len(v) for _, v in self.selected_network.items()) + self.last_dropout_rate = self.dropout_rates.get(last_layer_idx, 0.0) self.strides = { 0: {}, 1: {0: 2, 1: 2, 2: 2}, @@ -166,105 +227,191 @@ def __init__( 4: {0: 2, 1: 1, 2: 2, 3: 1, 4: 2, 5: 1}, } - self.num_classes = num_classes - self.in_chans = in_chans - self.scale = scale - self.networks = [ - "simplenetv1_imagenet", # 0 - "simplenetv1_imagenet_9m", # 1 - # other archs - ] - self.network_idx = network_idx - self.mode = mode - - self.features = self._make_layers(scale) - self.classifier = nn.Linear(round(self.cfg[self.networks[network_idx]][-1][0] * scale), num_classes) - - def forward(self, x: torch.Tensor): - out = self.features(x) - out = F.max_pool2d(out, kernel_size=out.size()[2:]) - out = F.dropout2d(out, self.last_dropout_rate, training=self.training) - out = out.view(out.size(0), -1) - out = self.classifier(out) - return out - - def _make_layers(self, scale: float): - layers: List[nn.Module] = [] - input_channel = self.in_chans - stride_list = self.strides[self.mode] - for idx, (layer, stride, defaul_dropout_rate, *layer_type) in enumerate( - self.cfg[self.networks[self.network_idx]] - ): - stride = stride_list[idx] if len(stride_list) > idx else stride - # check if any custom dropout rate is specified - # for this layer, note that pooling also counts as 1 layer - custom_dropout = self.dropout_rates.get(idx, None) - custom_dropout = defaul_dropout_rate if custom_dropout is None else custom_dropout - # dropout values must be strictly decimal. while 0 doesnt introduce any issues here - # i.e. during training and inference, if you try to jit trace your model it will crash - # due to using 0 as dropout value(this applies up to 1.13.1) so here is an explicit - # check to convert any possible integer value to its decimal counterpart. - custom_dropout = None if custom_dropout is None else float(custom_dropout) - kernel_size = 3 if layer_type == [] else 1 - - if layer == "p": - layers += [ - nn.MaxPool2d(kernel_size=(2, 2), stride=(stride, stride)), - nn.Dropout2d(p=custom_dropout, inplace=True), - ] + self.features, self.feature_info = self._build_blocks() + self.num_features = round(self.selected_network['stage_'][-1][0] * scale) + self.head = ClassifierHead(self.num_features, num_classes, 'max', self.last_dropout_rate) + + def _build_blocks(self): + net_id = self.network_idx + features = nn.Sequential() + feature_info = [] + in_chan = self.in_chans + current_stride = 1 + for idx, (block_key, block_info) in enumerate(self.cfg[self.networks[net_id]].items()): + block_strides = self.extract_block_strides(block_key) + block_dropouts = self.extract_block_dropouts(block_key) + if block_key == 'stem': + filters, default_stride, defaul_dropout_rate = block_info[0] + self.stem_chs = round(filters * self.scale) + self.stem_stride = block_strides.get(idx, default_stride) + custom_dropout = self.get_final_dropout(idx, block_dropouts, defaul_dropout_rate) + self.stem = ConvBNReLU( + in_chan, self.stem_chs, 3, stride=self.stem_stride, padding=1, Dropout=custom_dropout + ) + feature_info += [dict(num_chs=self.stem_chs, reduction=self.stem_stride, module='stem')] + in_chan = self.stem_chs + reduction_rate = self.stem_stride else: - filters = round(layer * scale) - if custom_dropout is None: - layers += [ - nn.Conv2d(input_channel, filters, kernel_size=kernel_size, stride=stride, padding=1), - nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True), - nn.ReLU(inplace=True), + stage_index = -1 + stage_list = [] + stage_id = f'stage_0' + for idx, (filter, current_stride, current_dropout, *layer_type) in enumerate(block_info): + stage_id = f'stage_{stage_index}' + # check the current_stride + final_stride = block_strides.get(idx, current_stride) + # check final dropout + custom_dropout = self.get_final_dropout(idx, block_dropouts, current_dropout) + pad = 1 + if layer_type == []: + kernel_size = 3 + pad = 1 + else: + kernel_size = 1 + pad = 0 + + if final_stride > 1 or filter == 'p': + if stage_list: + features.add_module(stage_id, nn.Sequential(*stage_list)) + feature_info += [ + dict(num_chs=filters, reduction=reduction_rate, module=f'features.{stage_id}') + ] + + stage_index += 1 + reduction_rate *= final_stride + stage_list = [] + + if filter == 'p': + stage_list.append(Downsample(dropout=custom_dropout)) + else: + filters = round(filter * self.scale) + if custom_dropout is None: + stage_list.append( + ConvBNReLU( + in_chan, + filters, + kernel_size=kernel_size, + stride=final_stride, + padding=pad, + Dropout=None, + ) + ) + else: + stage_list.append( + ConvBNReLU( + in_chan, + filters, + kernel_size=kernel_size, + stride=final_stride, + padding=pad, + Dropout=custom_dropout, + ) + ) + in_chan = filters + + if stage_id: + features.add_module(stage_id, nn.Sequential(*stage_list)) + feature_info += [ + dict(num_chs=filters, reduction=int(reduction_rate), module=f'features.{stage_id}') ] - else: - layers += [ - nn.Conv2d(input_channel, filters, kernel_size=kernel_size, stride=stride, padding=1), - nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True), - nn.ReLU(inplace=True), - nn.Dropout2d(p=custom_dropout, inplace=False), - ] - - input_channel = filters - model = nn.Sequential(*layers) - for m in model.modules(): + # init the model weights + for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain("relu")) - return model + + return features, feature_info + + def get_final_dropout(self, idx, block_dropouts, current_dropout): + custom_dropout = block_dropouts.get(idx, None) + custom_dropout = current_dropout if custom_dropout is None else custom_dropout + # dropout values must be strictly decimal. while 0 doesnt introduce any issues here + # i.e. during training and inference, if you try to jit trace your model it will crash + # due to using 0 as dropout value so here is an explicit + # check to convert any possible integer value to its decimal counterpart. + custom_dropout = None if custom_dropout is None else float(custom_dropout) + return custom_dropout + + def extract_block_strides(self, block_key): + strides = self.strides[self.mode] + return self.process_block_info(block_key, strides) + + def extract_block_dropouts(self, block_key): + return self.process_block_info(block_key, self.dropout_rates) + + def get_stage_info(self): + stage_info = {} + idx = 0 + for k, v in self.cfg[self.networks[self.network_idx]].items(): + layer_cnt = len(v) + stage_info[k] = (idx, idx + layer_cnt, list(range(idx, idx + layer_cnt))) + idx += layer_cnt + return stage_info + + def process_block_info(self, block_key, data_dict): + stage_info = self.get_stage_info() + block_rates = {} + (_, _, idx_list) = stage_info[block_key] + for k, v in data_dict.items(): + if k in idx_list: + key_in_block = idx_list.index(k) + block_rates[key_in_block] = v + return block_rates @torch.jit.ignore def group_matcher(self, coarse=False): - # this treats BN layers as separate groups for bn variants, a lot of effort to fix that - return dict(stem=r"^features\.0", blocks=r"^features\.(\d+)") + matcher = dict( + stem=r'^stem', + blocks=r'^features\._stage_(\d+)\.(\d+)', + ) + return matcher @torch.jit.ignore def set_grad_checkpointing(self, enable=True): - assert not enable, "gradient checkpointing not supported" + self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): - return self.classifier + return self.head.fc - def reset_classifier(self, num_classes: int): - self.num_classes = num_classes - self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][0] * self.scale), num_classes) + def reset_classifier(self, num_classes, drop_rate=0.0, global_pool='max'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.features, x, flatten=True) + else: + x = self.features(x) + return x - def forward_features(self, x: torch.Tensor) -> torch.Tensor: - return self.features(x) + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) - def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + def forward(self, x): x = self.forward_features(x) - if pre_logits: - return x - else: - x = F.max_pool2d(x, kernel_size=x.size()[2:]) - x = F.dropout2d(x, self.last_dropout_rate, training=self.training) - x = x.view(x.size(0), -1) - return self.classifier(x) + x = self.forward_head(x) + return x + + +def _checkpoint_filter_fn(state_dict, model): + """Remaps original checkpoints -> timm""" + # shamelessly taken from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py#L696 + if 'stem.0.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + out_dict = {} + import re + + D = model.state_dict() + out_dict = {} + for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()): + if va.ndim == 4 and vb.ndim == 2: + vb = vb[:, :, None, None] + if va.shape != vb.shape: + # head or first-conv shapes may change for fine-tune + assert 'head' in ka or 'stem.conv1.linear' in ka + out_dict[ka] = vb + return out_dict def _gen_simplenet( @@ -286,16 +433,14 @@ def _gen_simplenet( drop_rates=drop_rates, **kwargs, ) - # to allow for seemless finetuning, remove the num_classes - # and load the model intact, we apply the changes afterward! - if "num_classes" in kwargs: - kwargs.pop("num_classes") - model = build_model_with_cfg(SimpleNet, model_variant, pretrained, **model_args) - # if the num_classes is different than imagenet's, it - # means its going to be finetuned, so only create a - # new classifier after the whole model is loaded! - if num_classes != 1000: - model.reset_classifier(num_classes) + model = build_model_with_cfg( + SimpleNet, + model_variant, + pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True), + **model_args, + ) return model