Skip to content

Commit

Permalink
MLflow: On CI, use model "ets_cds_dt" only, having the smallest MASE
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Feb 13, 2024
1 parent 01e7884 commit e8c65af
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1086,12 +1086,13 @@
"# all available models are included by default)\n",
"# - \"fold\" defines the number of folds to use for cross-validation.\n",
"\n",
"# Note: This is only relevant if we are executing automated tests\n",
"# On CI, only evaluate a single cheap model.\n",
"if \"PYTEST_CURRENT_TEST\" in os.environ:\n",
" best_models = compare_models(sort=\"MASE\",\n",
" include=[\"ets\", \"et_cds_dt\", \"naive\"],\n",
" include=[\"et_cds_dt\"],\n",
" n_select=3)\n",
"# If we are not in an automated test, compare all available models\n",
"\n",
"# When not on CI/testing, compare all available models.\n",
"else:\n",
" best_models = compare_models(sort=\"MASE\", n_select=3)"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,14 @@ def fetch_data():

def run_experiment(data):
setup(data = data, fh=15, target="total_sales", index="month", log_experiment=True)

# On CI, only evaluate a single cheap model.
if "PYTEST_CURRENT_TEST" in os.environ:
best_models = compare_models(sort="MASE",
include=["arima", "ets", "exp_smooth"],
include=["et_cds_dt"],
n_select=3)

# When not on CI/testing, compare all available models.
else:
best_models = compare_models(sort="MASE", n_select=3)

Expand Down

0 comments on commit e8c65af

Please sign in to comment.