diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py index a5252c8..ae3e486 100644 --- a/vision_toolbox/backbones/deit.py +++ b/vision_toolbox/backbones/deit.py @@ -28,8 +28,8 @@ def __init__( ) -> None: # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio, - dropout, layer_scale_init, stochastic_depth, norm_eps + d_model, depth, n_heads, patch_size, img_size, True, "cls_token", bias, + mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) @@ -133,7 +133,7 @@ def __init__( ): # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, cls_token, bias, + d_model, depth, n_heads, patch_size, img_size, cls_token, "cls_token", bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on