diff --git a/baselines/MegaCRN/arch/megacrn_arch.py b/baselines/MegaCRN/arch/megacrn_arch.py
index b3196a0f..e5fdbf72 100644
--- a/baselines/MegaCRN/arch/megacrn_arch.py
+++ b/baselines/MegaCRN/arch/megacrn_arch.py
@@ -125,15 +125,17 @@ class MegaCRN(nn.Module):
         - First, MegaCRN calculates the metrics (masked mae, mape, rmse) in each batch and then averages them.
           However, a correct implementation should compute metrics over the entire test dataset.
           These two implementations are very different, since the distribution of null values in each batch is very different,
-          which significantly affects the results of masked_mae, masked_mape, and masked_rmse [2].
+          which significantly affects the results of masked_mae, masked_mape, and masked_rmse [2]. # NOTE: Solved by this new script [5].
         - Second, MegaCRN pads the last batch of the test dataset.
-          When we compute metrics on the test dataset, we need to remove these padded samples.
+          When we compute metrics on the test dataset, we need to remove these padded samples [3].
         In BasicTS, we avoid these mistakes based on the unified pipeline.
-        The hyper-parameters of the model, the optimizer, and other tricks like gradient clipping all follow MegaCRN's official settings. (Refer to [3])
+        The hyper-parameters of the model, the optimizer, and other tricks like gradient clipping all follow MegaCRN's official settings. (Refer to [4])
 
         [1] https://github.com/deepkashiwa20/MegaCRN/issues/1
         [2] https://github.com/chaoshangcs/GTS/issues/19#issuecomment-1079932786
-        [3] https://github.com/deepkashiwa20/MegaCRN/blob/main/model/traintest_MegaCRN.py
+        [3] https://github.com/nnzhan/Graph-WaveNet/blob/6b162e80c59a1d494809252eca055cff93dc66b1/train.py#L145
+        [4] https://github.com/deepkashiwa20/MegaCRN/blob/main/model/traintest_MegaCRN.py
+        [5] https://github.com/deepkashiwa20/MegaCRN/blob/main/model/traintestv1_MegaCRN.py
     """
 
     def __init__(self, num_nodes, input_dim, output_dim, horizon, rnn_units, num_layers=1, cheb_k=3, ycov_dim=1, mem_num=20, mem_dim=64, cl_decay_steps=2000, use_curriculum_learning=True):
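
For readers of this patch, a minimal sketch (not part of the diff) of the first pitfall described in the docstring: with a masked MAE of the usual form, averaging per-batch values is not the same as computing the metric once over the whole test set whenever the non-null entries are unevenly distributed across batches. The tensors and values below are made up purely for illustration.

```python
import torch

def masked_mae(pred: torch.Tensor, true: torch.Tensor, null_val: float = 0.0) -> torch.Tensor:
    # Average absolute error over non-null entries only.
    # (Real implementations also guard against batches whose mask sums to zero.)
    mask = (true != null_val).float()
    return (torch.abs(pred - true) * mask).sum() / mask.sum()

# Batch 1 has one valid entry; batch 2 has three.
true1, pred1 = torch.tensor([10., 0., 0., 0.]), torch.tensor([12., 5., 5., 5.])
true2, pred2 = torch.tensor([10., 10., 10., 0.]), torch.tensor([11., 11., 11., 7.])

per_batch_mean = (masked_mae(pred1, true1) + masked_mae(pred2, true2)) / 2
whole_test_set = masked_mae(torch.cat([pred1, pred2]), torch.cat([true1, true2]))
print(per_batch_mean.item(), whole_test_set.item())  # 1.5 vs. 1.25
```

And a sketch of the second pitfall's fix, in the spirit of the Graph-WaveNet line referenced in [3]: outputs stacked from padded batches are truncated to the true sample count before any metric is computed. The names and shapes here are hypothetical.

```python
import torch

# Hypothetical shapes: 1000 real test samples padded up to 1024 (batch size 64).
num_real_samples = 1000
preds = torch.randn(1024, 12, 207)   # stacked model outputs, including padding
labels = torch.randn(1024, 12, 207)  # stacked ground truth, including padding

# Drop the padded tail before evaluation, as in the reference above.
preds, labels = preds[:num_real_samples], labels[:num_real_samples]
```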