Skip to content

Commit

Permalink
fix(distributed): exogenous handling in distributed cross validation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 5, 2024
1 parent eab21b1 commit 3626418
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 229 deletions.
10 changes: 8 additions & 2 deletions mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _preprocess_partition(
train = part[train_mask]
valid_keep_cols = part.columns
if static_features is not None:
valid_keep_cols.drop(static_features)
valid_keep_cols = valid_keep_cols.drop(static_features)
valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)
transformed = ts.fit_transform(
train,
Expand Down Expand Up @@ -456,6 +456,8 @@ def _predict(
) -> Iterable[pd.DataFrame]:
for serialized_ts, _, serialized_valid in items:
valid = cloudpickle.loads(serialized_valid)
if valid is not None:
X_df = valid
ts = cloudpickle.loads(serialized_ts)
res = ts.predict(
models=models,
Expand Down Expand Up @@ -649,7 +651,11 @@ def cross_validation(
engine=self.engine,
)
results.append(fa.get_native_as_df(preds))
return fa.union(*results)
if len(results) == 1:
res = results[0]
else:
res = fa.union(*results)
return res

@staticmethod
def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:
Expand Down
10 changes: 8 additions & 2 deletions nbs/distributed.forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@
" train = part[train_mask]\n",
" valid_keep_cols = part.columns\n",
" if static_features is not None:\n",
" valid_keep_cols.drop(static_features)\n",
" valid_keep_cols = valid_keep_cols.drop(static_features)\n",
" valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)\n",
" transformed = ts.fit_transform(\n",
" train,\n",
Expand Down Expand Up @@ -508,6 +508,8 @@
" ) -> Iterable[pd.DataFrame]:\n",
" for serialized_ts, _, serialized_valid in items:\n",
" valid = cloudpickle.loads(serialized_valid)\n",
" if valid is not None:\n",
" X_df = valid\n",
" ts = cloudpickle.loads(serialized_ts)\n",
" res = ts.predict(\n",
" models=models,\n",
Expand Down Expand Up @@ -695,7 +697,11 @@
" engine=self.engine,\n",
" )\n",
" results.append(fa.get_native_as_df(preds))\n",
" return fa.union(*results)\n",
" if len(results) == 1:\n",
" res = results[0]\n",
" else:\n",
" res = fa.union(*results)\n",
" return res\n",
"\n",
" @staticmethod\n",
" def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:\n",
Expand Down
Loading

0 comments on commit 3626418

Please sign in to comment.