diff --git a/linmod/data.py b/linmod/data.py index cb6e139..887d13f 100755 --- a/linmod/data.py +++ b/linmod/data.py @@ -26,6 +26,7 @@ import sys from datetime import datetime from pathlib import Path +from typing import Optional from urllib.parse import urlparse from urllib.request import urlopen @@ -133,14 +134,11 @@ """ -if __name__ == "__main__": - # Load configuration, if given - +def main(cfg: Optional[dict]): config = DEFAULT_CONFIG - if len(sys.argv) > 1: - with open(sys.argv[1]) as f: - config["data"] |= yaml.safe_load(f)["data"] + if cfg is not None: + config["data"] |= cfg["data"] # Download the data, if necessary @@ -267,3 +265,14 @@ model_df.write_csv(config["data"]["save_path"]["model"]) print_message(" done.") + + +if __name__ == "__main__": + # Load configuration, if given + + cfg = None + if len(sys.argv) > 1: + with open(sys.argv[1]) as f: + cfg = yaml.safe_load(f)["data"] + + main(cfg) diff --git a/present-day-forecasting/README.md b/present-day-forecasting/README.md index 5677b19..43bfa85 100644 --- a/present-day-forecasting/README.md +++ b/present-day-forecasting/README.md @@ -1,6 +1,10 @@ -# Demo of model fitting and evaluation +# Model fitting and evaluation -Current workflow, from top-level of repo: +From top-level of repo, run `present-day-forecasting/main.py present-day-forecasting/config.yaml`. -1. Download data with `python3 -m linmod.data present-day-forecasting/config.yaml` -2. Fit and evaluate models with `present-day-forecasting/main.py present-day-forecasting/config.yaml` +This: +1. Creates datasets for forecasting and evaluation (from a cached dataset if available). +2. Runs all forecasting models specified in `present-day-forecasting/config.yaml` and plots a simple summary. +3. Evaluates the model forecasts using the created evaluation dataset. + +Parameters failing MCMC convergence diagnostics are reported, but not excluded from downstream steps. diff --git a/present-day-forecasting/main.py b/present-day-forecasting/main.py index 7ebac3c..a78f160 100755 --- a/present-day-forecasting/main.py +++ b/present-day-forecasting/main.py @@ -10,6 +10,7 @@ import yaml from numpyro.infer import MCMC, NUTS +import linmod.data import linmod.eval import linmod.models from linmod.utils import print_message @@ -27,6 +28,9 @@ with open(sys.argv[1]) as f: config = yaml.safe_load(f) +# Create the datasets +linmod.data.main(config) + # Load the dataset used for retrospective forecasting data = pl.read_csv(config["data"]["save_file"]["model"], try_parse_dates=True)