diff --git a/docs/zh/examples/nowcastnet.md b/docs/zh/examples/nowcastnet.md index 456e90e0a..176e9113e 100644 --- a/docs/zh/examples/nowcastnet.md +++ b/docs/zh/examples/nowcastnet.md @@ -16,6 +16,18 @@ python nowcastnet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams ``` +=== "模型导出命令" + + ``` sh + python nowcastnet.py mode=export + ``` + +=== "模型推理命令" + + ``` sh + python nowcastnet.py mode=infer + ``` + ## 1. 背景简介 近年来,深度学习方法已被应用于天气预报,尤其是雷达观测的降水预报。这些方法利用大量雷达复合观测数据来训练神经网络模型,以端到端的方式进行训练,无需明确参考降水过程的物理定律。 diff --git a/examples/nowcastnet/conf/nowcastnet.yaml b/examples/nowcastnet/conf/nowcastnet.yaml index 90bff596d..088a4ee4e 100644 --- a/examples/nowcastnet/conf/nowcastnet.yaml +++ b/examples/nowcastnet/conf/nowcastnet.yaml @@ -11,6 +11,8 @@ hydra: - TRAIN.checkpoint_path - TRAIN.pretrained_model_path - EVAL.pretrained_model_path + - INFER.pretrained_model_path + - INFER.export_path - mode - output_dir - log_freq @@ -22,6 +24,7 @@ hydra: # general settings mode: eval # running mode: train/eval seed: 42 +log_freq: 20 output_dir: ${hydra:run.dir} NORMAL_DATASET_PATH: datasets/mrms/figure LARGE_DATASET_PATH: datasets/mrms/large_figure @@ -55,3 +58,20 @@ MODEL: # evaluation settings EVAL: pretrained_model_path: checkpoints/paddle_mrms_model + +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams + export_path: ./inference/nowcastnet + pdmodel_path: ${INFER.export_path}.pdmodel + pdiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 16 + num_cpu_threads: 4 + batch_size: 1 diff --git a/examples/nowcastnet/nowcastnet.py b/examples/nowcastnet/nowcastnet.py index 6f3fbde79..9156907ac 100644 --- a/examples/nowcastnet/nowcastnet.py +++ b/examples/nowcastnet/nowcastnet.py @@ -82,14 +82,112 @@ def evaluate(cfg: DictConfig): solver.visualize(batch_id) +def export(cfg: DictConfig): + from paddle.static import InputSpec + + # set models + if cfg.CASE_TYPE == "large": + model_cfg = cfg.MODEL.large + elif cfg.CASE_TYPE == "normal": + model_cfg = cfg.MODEL.normal + else: + raise ValueError( + f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'" + ) + model = ppsci.arch.NowcastNet(**model_cfg) + + # load pretrained model + solver = ppsci.solver.Solver( + model=model, pretrained_model_path=cfg.INFER.pretrained_model_path + ) + # export models + input_spec = [ + { + key: InputSpec( + [None, 29, model_cfg.image_width, model_cfg.image_height, 2], + "float32", + name=key, + ) + for key in model_cfg.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + import os.path as osp + + from deploy.python_infer import pinn_predictor + + # set model predictor + predictor = pinn_predictor.PINNPredictor(cfg) + + if cfg.CASE_TYPE == "large": + dataset_path = cfg.LARGE_DATASET_PATH + model_cfg = cfg.MODEL.large + output_dir = osp.join(cfg.output_dir, "large") + elif cfg.CASE_TYPE == "normal": + dataset_path = cfg.NORMAL_DATASET_PATH + model_cfg = cfg.MODEL.normal + output_dir = osp.join(cfg.output_dir, "normal") + else: + raise ValueError( + f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'" + ) + + input_keys = ("radar_frames",) + dataset_param = { + "input_keys": input_keys, + "label_keys": (), + "image_width": model_cfg.image_width, + "image_height": model_cfg.image_height, + "total_length": model_cfg.total_length, + "dataset_path": dataset_path, + } + test_data_loader = paddle.io.DataLoader( + ppsci.data.dataset.RadarDataset(**dataset_param), + batch_size=cfg.INFER.batch_size, + num_workers=cfg.CPU_WORKER, + drop_last=True, + ) + for batch_id, test_ims in enumerate(test_data_loader): + if batch_id > cfg.NUM_SAVE_SAMPLES: + break + test_ims = {"input": test_ims[0][input_keys[0]].numpy()} + output_dict = predictor.predict(test_ims, cfg.INFER.batch_size) + # mapping data to model_cfg.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(model_cfg.output_keys, output_dict.keys()) + } + + visualizer = ppsci.visualize.VisualizerRadar( + test_ims, + { + "output": lambda out: out["output"], + }, + prefix="v_nowcastnet", + case_type=cfg.CASE_TYPE, + total_length=model_cfg.total_length, + ) + test_ims.update(output_dict) + visualizer.save(osp.join(output_dir, f"epoch_{batch_id}"), test_ims) + + @hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/ppsci/arch/nowcastnet.py b/ppsci/arch/nowcastnet.py index 3972955b8..38d520961 100644 --- a/ppsci/arch/nowcastnet.py +++ b/ppsci/arch/nowcastnet.py @@ -127,11 +127,12 @@ def forward_tensor(self, x): # Generative Network evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1)) noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32]) + noise = self.proj(noise) + ngf = noise.shape[1] noise_feature = ( - self.proj(noise) - .reshape((batch, -1, 4, 4, 8, 8)) + noise.reshape((batch, -1, 4, 4, 8, 8)) .transpose(perm=[0, 1, 4, 5, 2, 3]) - .reshape((batch, -1, height // 8, width // 8)) + .reshape((batch, ngf // 16, height // 8, width // 8)) ) feature = paddle.concat(x=[evo_feature, noise_feature], axis=1) gen_result = self.gen_dec(feature, evo_result) @@ -461,7 +462,7 @@ class Noise_Projector(paddle.nn.Layer): def __init__(self, input_length): super().__init__() self.input_length = input_length - self.conv_first = spectral_norm( + self.conv_first = paddle.nn.utils.spectral_norm( paddle.nn.Conv2D( in_channels=self.input_length, out_channels=self.input_length * 2, @@ -486,7 +487,7 @@ def forward(self, x): class ProjBlock(paddle.nn.Layer): def __init__(self, in_channel, out_channel): super().__init__() - self.one_conv = spectral_norm( + self.one_conv = paddle.nn.utils.spectral_norm( paddle.nn.Conv2D( in_channels=in_channel, out_channels=out_channel - in_channel, @@ -495,7 +496,7 @@ def __init__(self, in_channel, out_channel): ) ) self.double_conv = paddle.nn.Sequential( - spectral_norm( + paddle.nn.utils.spectral_norm( paddle.nn.Conv2D( in_channels=in_channel, out_channels=out_channel, @@ -504,7 +505,7 @@ def __init__(self, in_channel, out_channel): ) ), paddle.nn.ReLU(), - spectral_norm( + paddle.nn.utils.spectral_norm( paddle.nn.Conv2D( in_channels=out_channel, out_channels=out_channel,