diff --git a/DeepFilterNet/df/enhance.py b/DeepFilterNet/df/enhance.py index 0eab4888f..9900f9e64 100644 --- a/DeepFilterNet/df/enhance.py +++ b/DeepFilterNet/df/enhance.py @@ -51,6 +51,7 @@ def main(args): log_level=args.log_level, config_allow_defaults=True, epoch=args.epoch, + mask_only=args.no_df_stage, ) suffix = suffix if args.suffix else None if args.output_dir is None: @@ -105,6 +106,7 @@ def init_df( config_allow_defaults: bool = False, epoch: Union[str, int, None] = "best", default_model: str = DEFAULT_MODEL, + mask_only: bool = False, ) -> Tuple[nn.Module, DF, str]: """Initializes and loads config, model and deep filtering state. @@ -161,7 +163,9 @@ def init_df( load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none") if not load_cp: checkpoint_dir = None - mask_only = config("mask_only", cast=bool, section="train", default=False, save=False) + mask_only = mask_only or config( + "mask_only", cast=bool, section="train", default=False, save=False + ) model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only) if (epoch is None or epoch == 0) and load_cp: logger.error("Could not find a checkpoint") @@ -361,6 +365,7 @@ def run(): dest="suffix", help="Don't add the model suffix to the enhanced audio files", ) + parser.add_argument("--no-df-stage", action="store_true") args = parser.parse_args() main(args)