diff --git a/paddleseg/core/export.py b/paddleseg/core/export.py index 839f22980..586f35e93 100644 --- a/paddleseg/core/export.py +++ b/paddleseg/core/export.py @@ -49,8 +49,12 @@ def export(args, model=None, save_dir=None, use_ema=False): save_dir = args.save_dir os.environ['PADDLESEG_EXPORT_STAGE'] = 'True' - shape = [None, 3, None, None] if args.input_shape is None \ - else args.input_shape + if args.input_shape is None: + assert cfg.model_cfg['type'] not in ['MaskFormer'], \ + "Please specify `--input_shape` when exporting MaskFormer models!" + shape = [None, 3, None, None] + else: + shape = args.input_shape input_spec = [paddle.static.InputSpec(shape=shape, dtype='float32')] model.eval() model = paddle.jit.to_static(model, input_spec=input_spec) @@ -79,6 +83,19 @@ def export(args, model=None, save_dir=None, use_ema=False): val_dataset_cfg = cfg.val_dataset_cfg assert val_dataset_cfg != {}, 'No val_dataset specified in the configuration file.' transforms = val_dataset_cfg.get('transforms', None) + if cfg.model_cfg["type"] in ["MaskFormer"]: + resize_trans = { + "type": "Resize", + "target_size": shape[-2:], + "keep_ratio": False, + "interp": "NEAREST"} + if transforms is None: + transforms = [resize_trans] + else: + for t in transforms: + if t["type"] in ["Resize", "ResizeByShort", "ResizeByLong"]: + t.clear() + t.update(resize_trans) output_dtype = 'int32' if args.output_op == 'argmax' else 'float32' # TODO add test config