From 7ba7dbb035a191286cb958648d770e93e4dbe97c Mon Sep 17 00:00:00 2001 From: Wenfei Yan <87323464+wenfeiy-db@users.noreply.github.com> Date: Tue, 18 Jan 2022 15:20:17 -0500 Subject: [PATCH] Add model wrapper, training module, and diagnostics functions for pmdarima (#20) * add model wrapper and tests * add training and diagnostics and their tests * add missing time steps * add noqa for load_context * add init to avoid test file name conflict * add init to avoid test file name conflict * use fixed arima order * resolve minor comments * use offset map also in model * reduce dup * fix estimator fit * fix test * test import location * fix tests * add plot * fix plot import * remove koalas from requriements.txt * add line --- .../automl_runtime/forecast/__init__.py | 37 +++ .../forecast/pmdarima/__init__.py | 0 .../forecast/pmdarima/diagnostics.py | 116 ++++++++ .../automl_runtime/forecast/pmdarima/model.py | 265 ++++++++++++++++++ .../forecast/pmdarima/training.py | 109 +++++++ .../automl_runtime/forecast/pmdarima/utils.py | 92 ++++++ .../automl_runtime/forecast/prophet/model.py | 27 +- runtime/environment.txt | 1 + runtime/requirements.txt | 1 + runtime/setup.py | 1 + .../tests/automl_runtime/forecast/__init__.py | 0 .../forecast/pmdarima/__init__.py | 0 .../forecast/pmdarima/diagnostics_test.py | 81 ++++++ .../forecast/pmdarima/model_test.py | 172 ++++++++++++ .../forecast/pmdarima/training_test.py | 60 ++++ .../forecast/pmdarima/utils_test.py | 76 +++++ .../forecast/prophet/__init__.py | 0 .../prophet/diagnostics_test.py | 0 .../{ => forecast}/prophet/forecast_test.py | 0 .../{ => forecast}/prophet/model_test.py | 0 20 files changed, 1015 insertions(+), 23 deletions(-) create mode 100644 runtime/databricks/automl_runtime/forecast/pmdarima/__init__.py create mode 100644 runtime/databricks/automl_runtime/forecast/pmdarima/diagnostics.py create mode 100644 runtime/databricks/automl_runtime/forecast/pmdarima/model.py create mode 100644 runtime/databricks/automl_runtime/forecast/pmdarima/training.py create mode 100644 runtime/databricks/automl_runtime/forecast/pmdarima/utils.py create mode 100644 runtime/tests/automl_runtime/forecast/__init__.py create mode 100644 runtime/tests/automl_runtime/forecast/pmdarima/__init__.py create mode 100644 runtime/tests/automl_runtime/forecast/pmdarima/diagnostics_test.py create mode 100644 runtime/tests/automl_runtime/forecast/pmdarima/model_test.py create mode 100644 runtime/tests/automl_runtime/forecast/pmdarima/training_test.py create mode 100644 runtime/tests/automl_runtime/forecast/pmdarima/utils_test.py create mode 100644 runtime/tests/automl_runtime/forecast/prophet/__init__.py rename runtime/tests/automl_runtime/{ => forecast}/prophet/diagnostics_test.py (100%) rename runtime/tests/automl_runtime/{ => forecast}/prophet/forecast_test.py (100%) rename runtime/tests/automl_runtime/{ => forecast}/prophet/model_test.py (100%) diff --git a/runtime/databricks/automl_runtime/forecast/__init__.py b/runtime/databricks/automl_runtime/forecast/__init__.py index e69de29b..5c375933 100644 --- a/runtime/databricks/automl_runtime/forecast/__init__.py +++ b/runtime/databricks/automl_runtime/forecast/__init__.py @@ -0,0 +1,37 @@ +# +# Copyright (C) 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+OFFSET_ALIAS_MAP = {
+    "W": "W",
+    "d": "D",
+    "D": "D",
+    "days": "D",
+    "day": "D",
+    "hours": "H",
+    "hour": "H",
+    "hr": "H",
+    "h": "H",
+    "H": "H",
+    "m": "min",
+    "minute": "min",
+    "min": "min",
+    "minutes": "min",
+    "T": "min",
+    "S": "S",
+    "seconds": "S",
+    "sec": "S",
+    "second": "S"
+}
diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/__init__.py b/runtime/databricks/automl_runtime/forecast/pmdarima/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/diagnostics.py b/runtime/databricks/automl_runtime/forecast/pmdarima/diagnostics.py
new file mode 100644
index 00000000..a10e8bce
--- /dev/null
+++ b/runtime/databricks/automl_runtime/forecast/pmdarima/diagnostics.py
@@ -0,0 +1,116 @@
+#
+# Copyright (C) 2022 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+
+import pandas as pd
+import numpy as np
+import pmdarima
+
+
+def generate_cutoffs(df: pd.DataFrame, horizon: int, unit: str, num_folds: int) -> List[pd.Timestamp]:
+    """
+    Generate cutoff times for cross-validation with control over the number of folds.
+    :param df: pd.DataFrame of the historical data.
+    :param horizon: int number of periods into the future for forecasting.
+    :param unit: frequency of the timeseries, which must be a pandas offset alias.
+    :param num_folds: int number of cutoffs for cross-validation.
+    :return: list of pd.Timestamp cutoffs for cross-validation.
+    """
+    period = max(0.5 * horizon, 1)  # avoid empty cutoff buckets
+    period = pd.to_timedelta(period, unit=unit)
+    horizon = pd.to_timedelta(horizon, unit=unit)
+
+    period_max = 0  # TODO: set period_max properly once different seasonalities are introduced
+    seasonality_timedelta = pd.Timedelta(str(period_max) + " days")
+
+    initial = max(3 * horizon, seasonality_timedelta)
+
+    # Last cutoff is "latest date in data - horizon" date
+    cutoff = df["ds"].max() - horizon
+    if cutoff < df["ds"].min():
+        raise ValueError("Less data than horizon.")
+    result = [cutoff]
+    while result[-1] >= min(df["ds"]) + initial and len(result) < num_folds:
+        cutoff -= period
+        # If data does not exist in data range (cutoff, cutoff + horizon]
+        if not (((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)).any()):
+            # Next cutoff point is "last date before cutoff in data - horizon"
+            if cutoff > df["ds"].min():
+                closest_date = df[df["ds"] <= cutoff].max()["ds"]
+                cutoff = closest_date - horizon
+            # else no data left, leave cutoff as is, it will be dropped.
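+        # Record the (possibly adjusted) cutoff; the loop walks backward until it
+        # reaches the initial training window or has collected num_folds cutoffs.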
+        result.append(cutoff)
+    result = result[:-1]
+    if len(result) == 0:
+        raise ValueError(
+            "Less data than horizon after initial window. Make horizon shorter."
+        )
+    return list(reversed(result))
+
+
+def cross_validation(arima_model: pmdarima.arima.ARIMA, df: pd.DataFrame, cutoffs: List[pd.Timestamp]) -> pd.DataFrame:
+    """
+    Cross-validation for time series forecasting.
+
+    Computes forecasts from historical cutoff points. The function is a modification of
+    prophet.diagnostics.cross_validation that works for ARIMA models.
+    :param arima_model: pmdarima.arima.ARIMA object. Fitted ARIMA model.
+    :param df: pd.DataFrame of the historical data.
+    :param cutoffs: list of pd.Timestamp specifying cutoffs to be used during cross-validation.
+    :return: a pd.DataFrame with the forecast, confidence interval, actual value, and cutoff.
+    """
+    bins = [df["ds"].min()] + cutoffs + [df["ds"].max()]
+    labels = [df["ds"].min()] + cutoffs
+    test_df = df[df['ds'] > cutoffs[0]].copy()
+    test_df["cutoff"] = pd.to_datetime(pd.cut(test_df["ds"], bins=bins, labels=labels))
+
+    predicts = [single_cutoff_forecast(arima_model, test_df, prev_cutoff, cutoff) for prev_cutoff, cutoff in
+                zip(labels, cutoffs)]
+
+    # Update model with data in last cutoff
+    last_df = test_df[test_df["cutoff"] == cutoffs[-1]]
+    arima_model.update(last_df["y"].values)
+
+    return pd.concat(predicts, axis=0).reset_index(drop=True)
+
+
+def single_cutoff_forecast(arima_model: pmdarima.arima.ARIMA, test_df: pd.DataFrame, prev_cutoff: pd.Timestamp,
+                           cutoff: pd.Timestamp) -> pd.DataFrame:
+    """
+    Forecast for a single cutoff. Used in the cross-validation function.
+    :param arima_model: pmdarima.arima.ARIMA object. Fitted ARIMA model.
+    :param test_df: pd.DataFrame with data to be used for updating the model and forecasting.
+    :param prev_cutoff: the pd.Timestamp cutoff of the previous forecast.
+        Data between prev_cutoff and cutoff will be used to update the model.
+    :param cutoff: pd.Timestamp cutoff of this forecast. The simulated forecast will start from this date.
+    :return: a pd.DataFrame with the forecast, confidence interval, actual value, and cutoff.
+    """
+    # Update the model with data in the previous cutoff
+    prev_df = test_df[test_df["cutoff"] == prev_cutoff]
+    if not prev_df.empty:
+        y_update = prev_df[["ds", "y"]].set_index("ds")
+        arima_model.update(y_update)
+    # Predict with data in the new cutoff
+    new_df = test_df[test_df["cutoff"] == cutoff].copy()
+    n_periods = len(new_df["y"].values)
+    fc, conf_int = arima_model.predict(n_periods=n_periods, return_conf_int=True)
+    fc = fc.tolist()
+    conf = np.asarray(conf_int).tolist()
+
+    new_df["yhat"] = fc
+    new_df[["yhat_lower", "yhat_upper"]] = conf
+    return new_df
diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/model.py b/runtime/databricks/automl_runtime/forecast/pmdarima/model.py
new file mode 100644
index 00000000..3a74d2d2
--- /dev/null
+++ b/runtime/databricks/automl_runtime/forecast/pmdarima/model.py
@@ -0,0 +1,265 @@
+#
+# Copyright (C) 2022 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import pickle
+from abc import ABC, abstractmethod
+from typing import List, Dict
+
+import pandas as pd
+import mlflow
+import pmdarima
+from mlflow.exceptions import MlflowException
+from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
+
+from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP
+
+
+class AbstractArimaModel(ABC, mlflow.pyfunc.PythonModel):
+    @abstractmethod
+    def __init__(self):
+        super().__init__()
+
+    def load_context(self, context: mlflow.pyfunc.model.PythonModelContext) -> None:
+        """
+        Loads artifacts from the specified PythonModelContext that can be used by
+        PythonModel.predict when evaluating inputs. When loading an MLflow model with
+        load_pyfunc, this method is called as soon as the PythonModel is constructed.
+        :param context: A PythonModelContext instance containing artifacts that the model
+                        can use to perform inference.
+        """
+        from pmdarima.arima import ARIMA  # noqa: F401
+
+    @staticmethod
+    def _validate_cols(df: pd.DataFrame, required_cols: List[str]):
+        df_cols = set(df.columns)
+        required_cols_set = set(required_cols)
+        if not required_cols_set.issubset(df_cols):
+            raise MlflowException(
+                message=(
+                    f"Input data columns '{list(df_cols)}' do not contain the required columns '{required_cols}'"
+                ),
+                error_code=INVALID_PARAMETER_VALUE,
+            )
+
+
+class ArimaModel(AbstractArimaModel):
+    """
+    ARIMA mlflow model wrapper for univariate forecasting.
+    """
+
+    def __init__(self, pickled_model: bytes, horizon: int, frequency: str,
+                 start_ds: pd.Timestamp, end_ds: pd.Timestamp, time_col: str) -> None:
+        """
+        Initialize the mlflow Python model wrapper for ARIMA.
+        :param pickled_model: the pickled ARIMA model as a bytes object.
+        :param horizon: int number of periods to forecast forward.
+        :param frequency: the frequency of the time series.
+        :param start_ds: the start time of the training data.
+        :param end_ds: the end time of the training data.
+        :param time_col: the column name of the time column.
+        """
+        super().__init__()
+        self._pickled_model = pickled_model
+        self._horizon = horizon
+        self._frequency = OFFSET_ALIAS_MAP[frequency]
+        self._start_ds = start_ds
+        self._end_ds = end_ds
+        self._time_col = time_col
+
+    def model(self) -> pmdarima.arima.ARIMA:
+        """
+        Deserialize the ARIMA model by pickle.
+        :return: ARIMA model
+        """
+        return pickle.loads(self._pickled_model)
+
+    def predict_timeseries(self, horizon: int = None) -> pd.DataFrame:
+        """
+        Predict target column for the given horizon and history data.
+        :param horizon: int number of periods to forecast forward.
+        :return: A pd.DataFrame with the forecasts and confidence intervals for the given horizon and history data.
+        """
+        horizon = horizon or self._horizon
+        future_pd = self._forecast(horizon)
+        in_sample_pd = self._predict_in_sample()
+        return pd.concat([in_sample_pd, future_pd]).reset_index(drop=True)
+
+    def predict(self, context: mlflow.pyfunc.model.PythonModelContext, model_input: pd.DataFrame) -> pd.Series:
+        """
+        Predict API from mlflow.pyfunc.PythonModel.
+
+        Returns the prediction values for the given timestamps in the input dataframe. If an input timestamp
+        to predict does not match the original frequency that the model was trained on, an exception will be thrown.
+        :param context: A PythonModelContext instance containing artifacts that the model
+                        can use to perform inference.
+        :param model_input: The input dataframe of the model. Should have the same time column name
+                            as the training data of the ARIMA model.
+        :return: A pd.Series with the prediction values.
+        """
+        self._validate_cols(model_input, [self._time_col])
+        result_df = self._predict_impl(model_input)
+        return result_df["yhat"]
+
+    def _predict_impl(self, input_df: pd.DataFrame) -> pd.DataFrame:
+        df = input_df.rename(columns={self._time_col: "ds"})
+        # Check if the time has correct frequency
+        diff = (df["ds"] - self._start_ds) / pd.Timedelta(1, unit=self._frequency)
+        if not diff.apply(float.is_integer).all():
+            raise MlflowException(
+                message=(
+                    f"Input time column '{self._time_col}' includes different frequency."
+                ),
+                error_code=INVALID_PARAMETER_VALUE,
+            )
+        # Validate the time range
+        pred_start_ds = min(df["ds"])
+        if pred_start_ds < self._start_ds:
+            raise MlflowException(
+                message=(
+                    f"Input time column '{self._time_col}' includes time earlier than "
+                    "the history data that the model was trained on."
+                ),
+                error_code=INVALID_PARAMETER_VALUE,
+            )
+        preds_pds = []
+        # Out-of-sample prediction if needed
+        horizon = int((max(df["ds"]) - self._end_ds) / pd.Timedelta(1, unit=self._frequency))
+        if horizon > 0:
+            future_pd = self._forecast(horizon)
+            preds_pds.append(future_pd)
+        # In-sample prediction if needed
+        if pred_start_ds <= self._end_ds:
+            in_sample_pd = self._predict_in_sample(start_ds=pred_start_ds, end_ds=self._end_ds)
+            preds_pds.append(in_sample_pd)
+        # Map predictions back to the given timestamps
+        preds_pd = pd.concat(preds_pds).set_index("ds")
+        df = df.set_index("ds").join(preds_pd, how="left").reset_index()
+        return df
+
+    def _predict_in_sample(self, start_ds: pd.Timestamp = None, end_ds: pd.Timestamp = None) -> pd.DataFrame:
+        if start_ds and end_ds:
+            start_idx = int((start_ds - self._start_ds) / pd.Timedelta(1, unit=self._frequency))
+            end_idx = int((end_ds - self._start_ds) / pd.Timedelta(1, unit=self._frequency))
+        else:
+            start_ds = self._start_ds
+            end_ds = self._end_ds
+            start_idx, end_idx = None, None
+        preds_in_sample, conf_in_sample = self.model().predict_in_sample(
+            start=start_idx, end=end_idx, return_conf_int=True)
+        dates_in_sample = pd.date_range(start=start_ds, end=end_ds, freq=self._frequency)
+        in_sample_pd = pd.DataFrame({'ds': dates_in_sample, 'yhat': preds_in_sample})
+        in_sample_pd[["yhat_lower", "yhat_upper"]] = conf_in_sample
+        return in_sample_pd
+
+    def _forecast(self, horizon: int = None) -> pd.DataFrame:
+        horizon = horizon or self._horizon
+        preds, conf = self.model().predict(horizon, return_conf_int=True)
+        dates = pd.date_range(start=self._end_ds, periods=horizon + 1, freq=self._frequency)[1:]
+        preds_pd = pd.DataFrame({'ds': dates, 'yhat': preds})
+        preds_pd[["yhat_lower", "yhat_upper"]] = conf
+        return preds_pd
+
+
+class MultiSeriesArimaModel(AbstractArimaModel):
+    """
+    ARIMA mlflow model wrapper for multi-series forecasting.
+    """
+
+    def __init__(self, pickled_model_dict: Dict[str, bytes], horizon: int, frequency: str,
+                 start_ds_dict: Dict[str, pd.Timestamp], end_ds_dict: Dict[str, pd.Timestamp],
+                 time_col: str, id_cols: List[str]) -> None:
+        """
+        Initialize the mlflow Python model wrapper for multi-series ARIMA.
+        :param pickled_model_dict: the dictionary of pickled ARIMA models for the different time series.
+        :param horizon: int number of periods to forecast forward.
+        :param frequency: the frequency of the time series.
+        :param start_ds_dict: the dictionary of the starting time of each time series in training data.
+        :param end_ds_dict: the dictionary of the ending time of each time series in training data.
+        :param time_col: the column name of the time column
+        :param id_cols: the column names of the identity columns for multi-series time series
+        """
+        super().__init__()
+        self._pickled_models = pickled_model_dict
+        self._horizon = horizon
+        self._frequency = frequency
+        self._starts = start_ds_dict
+        self._ends = end_ds_dict
+        self._time_col = time_col
+        self._id_cols = id_cols
+
+    def model(self, id_: str) -> pmdarima.arima.ARIMA:
+        """
+        Deserialize the ARIMA model for the specified time series by pickle.
+        :param id_: id of the specified time series.
+        :return: ARIMA model
+        """
+        return pickle.loads(self._pickled_models[id_])
+
+    def predict_timeseries(self, horizon: int = None) -> pd.DataFrame:
+        """
+        Predict target column for the given horizon and history data.
+        :param horizon: Int number of periods to forecast forward.
+        :return: A pd.DataFrame with the forecasts and confidence intervals.
+        """
+        horizon = horizon or self._horizon
+        ids = self._pickled_models.keys()
+        preds_dfs = list(map(lambda id_: self._predict_timeseries_single_id(id_, horizon), ids))
+        return pd.concat(preds_dfs).reset_index(drop=True)
+
+    def _predict_timeseries_single_id(self, id_: str, horizon: int) -> pd.DataFrame:
+        arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency,
+                                           self._starts[id_], self._ends[id_], self._time_col)
+        preds_df = arima_model_single_id.predict_timeseries(horizon)
+        preds_df["ts_id"] = id_
+        return preds_df
+
+    def predict(self, context: mlflow.pyfunc.model.PythonModelContext, model_input: pd.DataFrame) -> pd.Series:
+        """
+        Predict API from mlflow.pyfunc.PythonModel.
+
+        Returns the prediction values for the given timestamps in the input dataframe. If an input timestamp
+        to predict does not match the original frequency that the model was trained on, an exception will be thrown.
+        :param context: A PythonModelContext instance containing artifacts that the model
+                        can use to perform inference.
+        :param model_input: input dataframe of the model. Should have the same time column
+                            and identity column names as the training data of the ARIMA model.
+        :return: A pd.Series with the prediction values.
+        """
+        self._validate_cols(model_input, self._id_cols + [self._time_col])
+        df = model_input.copy()
+        df["ts_id"] = df[self._id_cols].apply(lambda r: "-".join(r.values.astype(str)), axis=1)
+        known_ids = set(self._pickled_models.keys())
+        ids = set(df["ts_id"].unique())
+        if not ids.issubset(known_ids):
+            raise MlflowException(
+                message=(
+                    f"Input data includes unseen values in id columns '{self._id_cols}'. "
+ f"Expected combined ids: {known_ids}\n" + f"Got ids: {ids}\n" + ), + error_code=INVALID_PARAMETER_VALUE, + ) + preds_df = df.groupby("ts_id").apply(self._predict_single_id).reset_index(drop=True) + df = df.merge(preds_df, how="left", on=[self._time_col, "ts_id"]) # merge predictions to original order + return df["yhat"] + + def _predict_single_id(self, df: pd.DataFrame) -> pd.DataFrame: + id_ = df["ts_id"].to_list()[0] + arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency, + self._starts[id_], self._ends[id_], self._time_col) + df["yhat"] = arima_model_single_id.predict(None, df).to_list() + return df diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/training.py b/runtime/databricks/automl_runtime/forecast/pmdarima/training.py new file mode 100644 index 00000000..4e66d7d4 --- /dev/null +++ b/runtime/databricks/automl_runtime/forecast/pmdarima/training.py @@ -0,0 +1,109 @@ +# +# Copyright (C) 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +import pandas as pd +import pickle +import numpy as np +import pmdarima as pm +from pmdarima.arima import StepwiseContext +from prophet.diagnostics import performance_metrics + +from databricks.automl_runtime.forecast.pmdarima.diagnostics import generate_cutoffs, cross_validation +from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP + + +class ArimaEstimator: + """ + ARIMA estimator using pmdarima.auto_arima. + """ + + def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_periods: List[int], + num_folds: int = 20, max_steps: int = 150) -> None: + """ + :param horizon: Number of periods to forecast forward + :param frequency_unit: Frequency of the time series + :param metric: Metric that will be optimized across trials + :param seasonal_periods: A list of seasonal periods for tuning. + :param num_folds: Number of folds for cross validation + :param max_steps: Max steps for stepwise auto_arima + """ + self._horizon = horizon + self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit] + self._metric = metric + self._seasonal_periods = seasonal_periods + self._num_folds = num_folds + self._max_steps = max_steps + + def fit(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Fit the ARIMA model with tuning of seasonal period m and with pmdarima.auto_arima. + :param df: A pd.DataFrame containing the history data. Must have columns ds and y. + :return: A pd.DataFrame with the best model (pickled) and its metrics from cross validation. 
+ """ + history_pd = df.sort_values(by=["ds"]).reset_index(drop=True) + history_pd["ds"] = pd.to_datetime(history_pd["ds"]) + + # Impute missing time steps + history_pd = self._fill_missing_time_steps(history_pd, self._frequency_unit) + + # Generate cutoffs for cross validation + cutoffs = generate_cutoffs(history_pd, horizon=self._horizon, unit=self._frequency_unit, + num_folds=self._num_folds) + + # Tune seasonal periods + best_result = None + best_metric = float("inf") + for m in self._seasonal_periods: + result = self._fit_predict(history_pd, cutoffs, m, self._max_steps) + metric = result["metrics"]["smape"] + if metric < best_metric: + best_result = result + best_metric = metric + + results_pd = pd.DataFrame(best_result["metrics"], index=[0]) + results_pd["pickled_model"] = pickle.dumps(best_result["model"]) + + return results_pd + + @staticmethod + def _fit_predict(df: pd.DataFrame, cutoffs: List[pd.Timestamp], seasonal_period: int, max_steps: int = 150): + train_df = df[df['ds'] <= cutoffs[0]] + y_train = train_df[["ds", "y"]].set_index("ds") + + # Train with the initial interval + with StepwiseContext(max_steps=max_steps): + arima_model = pm.auto_arima( + y=y_train, + m=seasonal_period, + stepwise=True, + ) + + # Evaluate with cross validation + df_cv = cross_validation(arima_model, df, cutoffs) + df_metrics = performance_metrics(df_cv) + metrics = df_metrics.drop("horizon", axis=1).mean().to_dict() + # performance_metrics doesn't calculate mape if any y is close to 0 + if "mape" not in metrics: + metrics["mape"] = np.nan + + return {"metrics": metrics, "model": arima_model} + + @staticmethod + def _fill_missing_time_steps(df: pd.DataFrame, frequency: str): + # Forward fill missing time steps + return df.set_index("ds").resample(rule=OFFSET_ALIAS_MAP[frequency]).pad().reset_index() diff --git a/runtime/databricks/automl_runtime/forecast/pmdarima/utils.py b/runtime/databricks/automl_runtime/forecast/pmdarima/utils.py new file mode 100644 index 00000000..12937e8e --- /dev/null +++ b/runtime/databricks/automl_runtime/forecast/pmdarima/utils.py @@ -0,0 +1,92 @@ +# +# Copyright (C) 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pickle +from typing import Union, Tuple + +import pandas as pd +import mlflow +import pmdarima +import matplotlib.pyplot as plt +from matplotlib.axes import Axes +from matplotlib.dates import AutoDateLocator, AutoDateFormatter + +from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel, MultiSeriesArimaModel + +ARIMA_CONDA_ENV = { + "channels": ["conda-forge"], + "dependencies": [ + { + "pip": [ + f"pmdarima=={pmdarima.__version__}", + f"pickle=={pickle.format_version}", + f"pandas=={pd.__version__}", + ] + } + ], + "name": "pmdarima_env", +} + + +def mlflow_arima_log_model(arima_model: Union[ArimaModel, MultiSeriesArimaModel]) -> None: + """ + Log the model to mlflow. 
+ :param arima_model: ARIMA model wrapper + """ + mlflow.pyfunc.log_model("model", conda_env=ARIMA_CONDA_ENV, python_model=arima_model) + + +def plot(history_pd: pd.DataFrame, forecast_pd: pd.DataFrame, ax: Axes = None, plot_uncertainty: bool = True, + xlabel: str = 'ds', ylabel: str = 'y', figsize: Tuple[int, int] = (10, 6), include_legend: bool = False): + """ + Plot the forecast. Adapted from prophet.plot.plot. See + https://github.com/facebook/prophet/blob/ba9a5a2c6e2400206017a5ddfd71f5042da9f65b/python/prophet/plot.py#L42. + :param history_pd: pd.DataFrame of history data. + :param forecast_pd: pd.DataFrame with forecasts and optionally confidence interval, sorted by time. + :param ax: Optional matplotlib axes on which to plot. + :param plot_uncertainty: Optional boolean to plot uncertainty intervals, which will + only be done if m.uncertainty_samples > 0. + :param xlabel: Optional label name on X-axis + :param ylabel: Optional label name on Y-axis + :param figsize: Optional tuple width, height in inches. + :param include_legend: Optional boolean to add legend to the plot. + :return: A matplotlib figure. + """ + history_pd = history_pd.sort_values(by=["ds"]) + history_pd["ds"] = pd.to_datetime(history_pd["ds"]) + if ax is None: + fig = plt.figure(facecolor='w', figsize=figsize) + ax = fig.add_subplot(111) + else: + fig = ax.get_figure() + fcst_t = forecast_pd['ds'].dt.to_pydatetime() + ax.plot(history_pd['ds'].dt.to_pydatetime(), history_pd['y'], 'k.', label='Observed data points') + ax.plot(fcst_t, forecast_pd['yhat'], ls='-', c='#0072B2', label='Forecast') + if plot_uncertainty and "yhat_lower" in forecast_pd and "yhat_upper" in forecast_pd: + ax.fill_between(fcst_t, forecast_pd['yhat_lower'], forecast_pd['yhat_upper'], + color='#0072B2', alpha=0.2, label='Uncertainty interval') + # Specify formatting to workaround matplotlib issue #12925 + locator = AutoDateLocator(interval_multiples=False) + formatter = AutoDateFormatter(locator) + ax.xaxis.set_major_locator(locator) + ax.xaxis.set_major_formatter(formatter) + ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + if include_legend: + ax.legend() + fig.tight_layout() + return fig diff --git a/runtime/databricks/automl_runtime/forecast/prophet/model.py b/runtime/databricks/automl_runtime/forecast/prophet/model.py index 0a39bc9e..907821e8 100644 --- a/runtime/databricks/automl_runtime/forecast/prophet/model.py +++ b/runtime/databricks/automl_runtime/forecast/prophet/model.py @@ -20,32 +20,11 @@ import pandas as pd import prophet +from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP from mlflow.exceptions import MlflowException from mlflow.models.signature import infer_signature, ModelSignature from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE -OFFSET_ALIAS_MAP = { - "W": "W", - "d": "D", - "D": "D", - "days": "D", - "day": "D", - "hours": "H", - "hour": "H", - "hr": "H", - "h": "H", - "H": "H", - "m": "min", - "minute": "min", - "min": "min", - "minutes": "min", - "T": "T", - "S": "S", - "seconds": "S", - "sec": "S", - "second": "S" -} - PROPHET_CONDA_ENV = { "channels": ["conda-forge"], "dependencies": [ @@ -65,6 +44,7 @@ class ProphetModel(mlflow.pyfunc.PythonModel): """ Prophet mlflow model wrapper for univariate forecasting. 
""" + def __init__(self, model_json: Union[Dict[str, str], str], horizon: int, frequency: str, time_col: str) -> None: """ @@ -160,6 +140,7 @@ class MultiSeriesProphetModel(ProphetModel): """ Prophet mlflow model wrapper for multi-series forecasting. """ + def __init__(self, model_json: Dict[str, str], timeseries_starts: Dict[str, pd.Timestamp], timeseries_end: str, horizon: int, frequency: str, time_col: str, id_cols: List[str], ) -> None: @@ -250,7 +231,7 @@ def model_predict(self, df: pd.DataFrame, horizon: int = None) -> pd.DataFrame: :param horizon: Int number of periods to forecast forward. :return: A pd.DataFrame with the forecast components. """ - forecast_df = self._predict_impl(df, horizon) + forecast_df = self._predict_impl(df, horizon) return_cols = self.get_reserved_cols() + ["ds", "ts_id"] result_df = pd.DataFrame(columns=return_cols) result_df = pd.concat([result_df, forecast_df]) diff --git a/runtime/environment.txt b/runtime/environment.txt index 3907097b..75c16048 100644 --- a/runtime/environment.txt +++ b/runtime/environment.txt @@ -7,6 +7,7 @@ koalas==1.8.1 mlflow==1.22.0 numpy==1.20.2 pandas==1.2.4 +pmdarima==1.8.4 prophet==1.0.1 pyarrow==4.0.0 scikit-learn==0.24.1 diff --git a/runtime/requirements.txt b/runtime/requirements.txt index 22023e34..c4e0b69b 100644 --- a/runtime/requirements.txt +++ b/runtime/requirements.txt @@ -7,6 +7,7 @@ mlflow numpy pandas plotly +pmdarima prophet pyarrow requests diff --git a/runtime/setup.py b/runtime/setup.py index e9a2dc53..d379e94e 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -46,6 +46,7 @@ "databricks", "databricks.automl_runtime", "databricks.automl_runtime.forecast", + "databricks.automl_runtime.forecast.pmdarima", "databricks.automl_runtime.forecast.prophet", "databricks.automl_runtime.hyperopt", "databricks.automl_runtime.sklearn"], diff --git a/runtime/tests/automl_runtime/forecast/__init__.py b/runtime/tests/automl_runtime/forecast/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/__init__.py b/runtime/tests/automl_runtime/forecast/pmdarima/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/diagnostics_test.py b/runtime/tests/automl_runtime/forecast/pmdarima/diagnostics_test.py new file mode 100644 index 00000000..85e80869 --- /dev/null +++ b/runtime/tests/automl_runtime/forecast/pmdarima/diagnostics_test.py @@ -0,0 +1,81 @@ +# +# Copyright (C) 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +import pandas as pd +from pmdarima.arima import auto_arima, StepwiseContext + +from databricks.automl_runtime.forecast.pmdarima.diagnostics import generate_cutoffs, \ + cross_validation, single_cutoff_forecast + + +class TestDiagnostics(unittest.TestCase): + + def setUp(self) -> None: + num_rows = 15 + self.X = pd.concat([ + pd.to_datetime(pd.Series(range(num_rows), name="ds").apply(lambda i: f"2020-07-{i + 1}")), + pd.Series(range(num_rows), name="y") + ], axis=1) + + def test_generate_cutoffs_success(self): + cutoffs = generate_cutoffs(self.X, horizon=3, unit="d", num_folds=3) + self.assertEqual(cutoffs, [pd.Timestamp("2020-07-10 12:00:00"), + pd.Timestamp("2020-07-12 00:00:00")]) + + def test_generate_cutoffs_success_large_num_folds(self): + cutoffs = generate_cutoffs(self.X, horizon=2, unit="d", num_folds=20) + self.assertEqual(cutoffs, [pd.Timestamp("2020-07-07 00:00:00"), + pd.Timestamp("2020-07-08 00:00:00"), + pd.Timestamp("2020-07-09 00:00:00"), + pd.Timestamp("2020-07-10 00:00:00"), + pd.Timestamp("2020-07-11 00:00:00"), + pd.Timestamp("2020-07-12 00:00:00"), + pd.Timestamp('2020-07-13 00:00:00')]) + + def test_generate_cutoffs_failure_horizon_too_large(self): + with self.assertRaisesRegex(ValueError, "Less data than horizon after initial window. " + "Make horizon shorter."): + generate_cutoffs(self.X, horizon=9, unit="d", num_folds=3) + + def test_cross_validation_success(self): + cutoffs = generate_cutoffs(self.X, horizon=3, unit="d", num_folds=3) + y_train = self.X[self.X["ds"] <= cutoffs[0]].set_index("ds") + with StepwiseContext(max_steps=1): + model = auto_arima(y=y_train, m=1) + + expected_ds = self.X[self.X["ds"] > cutoffs[0]]["ds"] + expected_cols = ["ds", "y", "cutoff", "yhat", "yhat_lower", "yhat_upper"] + df_cv = cross_validation(model, self.X, cutoffs) + self.assertEqual(df_cv["ds"].tolist(), expected_ds.tolist()) + self.assertEqual(set(df_cv.columns), set(expected_cols)) + + def test_single_cutoff_forecast_success(self): + cutoff_zero = self.X["ds"].min() + cutoff_one = pd.Timestamp("2020-07-10 12:00:00") + cutoff_two = pd.Timestamp("2020-07-12 00:00:00") + y_train = self.X[self.X["ds"] <= cutoff_one].set_index("ds") + test_df = self.X[self.X['ds'] > cutoff_one].copy() + test_df["cutoff"] = [cutoff_one] * 2 + [cutoff_two] * 3 + with StepwiseContext(max_steps=1): + model = auto_arima(y=y_train, m=1) + + expected_ds = test_df["ds"][:2] + expected_cols = ["ds", "y", "cutoff", "yhat", "yhat_lower", "yhat_upper"] + forecast_df = single_cutoff_forecast(model, test_df, cutoff_zero, cutoff_one) + self.assertEqual(forecast_df["ds"].tolist(), expected_ds.tolist()) + self.assertEqual(set(forecast_df.columns), set(expected_cols)) diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/model_test.py b/runtime/tests/automl_runtime/forecast/pmdarima/model_test.py new file mode 100644 index 00000000..4d0db65a --- /dev/null +++ b/runtime/tests/automl_runtime/forecast/pmdarima/model_test.py @@ -0,0 +1,172 @@ +# +# Copyright (C) 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import pickle +import pytest + +import pandas as pd +from mlflow.exceptions import MlflowException +from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE +from pmdarima.arima import ARIMA + +from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel, MultiSeriesArimaModel, AbstractArimaModel + + +class TestArimaModel(unittest.TestCase): + + def setUp(self) -> None: + num_rows = 9 + self.X = pd.concat([ + pd.to_datetime(pd.Series(range(num_rows), name="date").apply(lambda i: f"2020-10-{i + 1}")), + pd.Series(range(num_rows), name="y") + ], axis=1) + model = ARIMA(order=(2, 0, 2), suppress_warnings=True) + model.fit(self.X.set_index("date")) + pickled_model = pickle.dumps(model) + self.arima_model = ArimaModel(pickled_model, horizon=1, frequency='days', + start_ds=pd.to_datetime("2020-10-01"), end_ds=pd.to_datetime("2020-10-09"), + time_col="date") + + def test_predict_timeseries_success(self): + forecast_pd = self.arima_model.predict_timeseries() + expected_columns = {"yhat", "yhat_lower", "yhat_upper"} + self.assertTrue(expected_columns.issubset(set(forecast_pd.columns))) + self.assertEqual(10, forecast_pd.shape[0]) + + def test_predict_success(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-11-04")] + }) + expected_test_df = test_df.copy() + yhat = self.arima_model.predict(None, test_df) + self.assertEqual(2, len(yhat)) + pd.testing.assert_frame_equal(test_df, expected_test_df) # check the input dataframe is unchanged + + def test_predict_failure_unmatched_frequency(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-06 12:30")] + }) + with pytest.raises(MlflowException, match="includes different frequency") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + def test_predict_failure_invalid_time_range(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2000-10-05"), pd.to_datetime("2020-11-04")] + }) + with pytest.raises(MlflowException, match="includes time earlier than the history data that the model was " + "trained on") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + def test_predict_failure_invalid_time_col_name(self): + test_df = pd.DataFrame({ + "invalid_time_col_name": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-11-04")] + }) + with pytest.raises(MlflowException, match="Input data columns") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + +class TestMultiSeriesArimaModel(unittest.TestCase): + + def setUp(self) -> None: + num_rows = 9 + self.X = pd.concat([ + pd.to_datetime(pd.Series(range(num_rows), name="date").apply(lambda i: f"2020-10-{i + 1}")), + pd.Series(range(num_rows), name="y") + ], axis=1) + model = ARIMA(order=(2, 0, 2), suppress_warnings=True) + model.fit(self.X.set_index("date")) + pickled_model = pickle.dumps(model) + pickled_model_dict = {"1": pickled_model, "2": pickled_model} + start_ds_dict = {"1": pd.Timestamp("2020-10-01"), "2": pd.Timestamp("2020-10-01")} + end_ds_dict = {"1": pd.Timestamp("2020-10-09"), "2": pd.Timestamp("2020-10-09")} + self.arima_model = MultiSeriesArimaModel(pickled_model_dict, horizon=1, frequency='d', + 
start_ds_dict=start_ds_dict, end_ds_dict=end_ds_dict, + time_col="date", id_cols=["id"]) + + def test_predict_timeseries_success(self): + forecast_pd = self.arima_model.predict_timeseries() + expected_columns = {"yhat", "yhat_lower", "yhat_upper"} + self.assertTrue(expected_columns.issubset(set(forecast_pd.columns))) + self.assertEqual(20, forecast_pd.shape[0]) + + def test_predict_success(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-10-05"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "id": ["1", "2", "1", "2"], + }) + expected_test_df = test_df.copy() + yhat = self.arima_model.predict(None, test_df) + self.assertEqual(4, len(yhat)) + pd.testing.assert_frame_equal(test_df, expected_test_df) # check the input dataframe is unchanged + + def test_predict_fail_unseen_id(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-10-05"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "id": ["1", "2", "1", "3"], + }) + with pytest.raises(MlflowException, match="includes unseen values in id columns") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + def test_predict_failure_unmatched_frequency(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-10-05 12:30"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "id": ["1", "2", "1", "2"], + }) + with pytest.raises(MlflowException, match="includes different frequency") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + def test_predict_failure_invalid_time_range(self): + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2000-10-05"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "id": ["1", "2", "1", "2"], + }) + with pytest.raises(MlflowException, match="includes time earlier than the history data that the model was " + "trained on") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + def test_predict_failure_invalid_time_col_name(self): + test_df = pd.DataFrame({ + "time": [pd.to_datetime("2020-10-05"), pd.to_datetime("2000-10-05"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "invalid_id_col_name": ["1", "2", "1", "2"], + }) + with pytest.raises(MlflowException, match="Input data columns") as e: + self.arima_model.predict(None, test_df) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) + + +class TestAbstractArimaModel(unittest.TestCase): + + def test_validate_cols_success(self): + test_df = pd.DataFrame({"date": []}) + AbstractArimaModel._validate_cols(test_df, ["date"]) + + def test_validate_cols_invalid_id_col_name(self): + test_df = pd.DataFrame({"date": [], "invalid_id_col_name": [], }) + with pytest.raises(MlflowException, match="Input data columns") as e: + AbstractArimaModel._validate_cols(test_df, ["date", "id"]) + assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py b/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py new file mode 100644 index 00000000..2a267630 --- /dev/null +++ b/runtime/tests/automl_runtime/forecast/pmdarima/training_test.py @@ -0,0 +1,60 @@ +# +# Copyright (C) 2022 Databricks, Inc. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import pandas as pd
+import pmdarima as pm
+
+from databricks.automl_runtime.forecast.pmdarima.training import ArimaEstimator
+from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP
+
+
+class TestArimaEstimator(unittest.TestCase):
+
+    def setUp(self) -> None:
+        num_rows = 12
+        self.df = pd.concat([
+            pd.to_datetime(pd.Series(range(num_rows), name="ds").apply(lambda i: f"2020-07-{2 * i + 1}")),
+            pd.Series(range(num_rows), name="y")
+        ], axis=1)
+
+    def test_fit_success(self):
+        arima_estimator = ArimaEstimator(horizon=1,
+                                         frequency_unit="d",
+                                         metric="smape",
+                                         seasonal_periods=[1],
+                                         num_folds=2)
+
+        results_pd = arima_estimator.fit(self.df)
+        self.assertIn("smape", results_pd)
+        self.assertIn("pickled_model", results_pd)
+
+    def test_fit_predict_success(self):
+        cutoffs = [pd.to_datetime("2020-07-11")]
+        result = ArimaEstimator._fit_predict(self.df, cutoffs, seasonal_period=1)
+        self.assertIn("metrics", result)
+        self.assertIsInstance(result["model"], pm.arima.ARIMA)
+
+    def test_fill_missing_time_steps(self):
+        supported_freq = ["W", "days", "hr", "min", "sec"]
+        for frequency in supported_freq:
+            ds = pd.date_range(start="2020-07-01", periods=12, freq=OFFSET_ALIAS_MAP[frequency])
+            indices_to_drop = [5, 8]
+            df_missing = pd.DataFrame({"ds": ds, "y": range(12)}).drop(indices_to_drop).reset_index(drop=True)
+            df_filled = ArimaEstimator._fill_missing_time_steps(df_missing, frequency=frequency)
+            for index in indices_to_drop:
+                # The dropped steps should be forward filled with the previous value,
+                # so the two entries must be equal.
+                self.assertEqual(df_filled["y"][index], df_filled["y"][index - 1])
diff --git a/runtime/tests/automl_runtime/forecast/pmdarima/utils_test.py b/runtime/tests/automl_runtime/forecast/pmdarima/utils_test.py
new file mode 100644
index 00000000..b676c912
--- /dev/null
+++ b/runtime/tests/automl_runtime/forecast/pmdarima/utils_test.py
@@ -0,0 +1,76 @@
+#
+# Copyright (C) 2022 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import unittest +import pickle + +import pandas as pd +import mlflow +from pmdarima.arima import ARIMA + +from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel, MultiSeriesArimaModel +from databricks.automl_runtime.forecast.pmdarima.utils import mlflow_arima_log_model + + +class TestUtils(unittest.TestCase): + + def setUp(self) -> None: + num_rows = 9 + self.X = pd.concat([ + pd.to_datetime(pd.Series(range(num_rows), name="date").apply(lambda i: f"2020-10-{i + 1}")), + pd.Series(range(num_rows), name="y") + ], axis=1) + model = ARIMA(order=(2, 0, 2), suppress_warnings=True) + model.fit(self.X.set_index("date")) + self.pickled_model = pickle.dumps(model) + + def test_mlflow_arima_log_model(self): + arima_model = ArimaModel(self.pickled_model, horizon=1, frequency='d', + start_ds=pd.to_datetime("2020-10-01"), end_ds=pd.to_datetime("2020-10-09"), + time_col="date") + with mlflow.start_run() as run: + mlflow_arima_log_model(arima_model) + + # Load the saved model from mlflow + run_id = run.info.run_id + loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") + + # Make sure can make forecasts with the saved model + loaded_model.predict(self.X.drop("y", axis=1)) + loaded_model._model_impl.python_model.predict_timeseries() + + def test_mlflow_arima_log_model_multiseries(self): + pickled_model_dict = {"1": self.pickled_model, "2": self.pickled_model} + start_ds_dict = {"1": pd.Timestamp("2020-10-01"), "2": pd.Timestamp("2020-10-01")} + end_ds_dict = {"1": pd.Timestamp("2020-10-09"), "2": pd.Timestamp("2020-10-09")} + multiseries_arima_model = MultiSeriesArimaModel(pickled_model_dict, horizon=1, frequency='d', + start_ds_dict=start_ds_dict, end_ds_dict=end_ds_dict, + time_col="date", id_cols=["id"]) + with mlflow.start_run() as run: + mlflow_arima_log_model(multiseries_arima_model) + + # Load the saved model from mlflow + run_id = run.info.run_id + loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") + + # Make sure can make forecasts with the saved model + loaded_model._model_impl.python_model.predict_timeseries() + test_df = pd.DataFrame({ + "date": [pd.to_datetime("2020-10-05"), pd.to_datetime("2020-10-05"), + pd.to_datetime("2020-11-04"), pd.to_datetime("2020-11-04")], + "id": ["1", "2", "1", "2"], + }) + loaded_model.predict(test_df) diff --git a/runtime/tests/automl_runtime/forecast/prophet/__init__.py b/runtime/tests/automl_runtime/forecast/prophet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runtime/tests/automl_runtime/prophet/diagnostics_test.py b/runtime/tests/automl_runtime/forecast/prophet/diagnostics_test.py similarity index 100% rename from runtime/tests/automl_runtime/prophet/diagnostics_test.py rename to runtime/tests/automl_runtime/forecast/prophet/diagnostics_test.py diff --git a/runtime/tests/automl_runtime/prophet/forecast_test.py b/runtime/tests/automl_runtime/forecast/prophet/forecast_test.py similarity index 100% rename from runtime/tests/automl_runtime/prophet/forecast_test.py rename to runtime/tests/automl_runtime/forecast/prophet/forecast_test.py diff --git a/runtime/tests/automl_runtime/prophet/model_test.py b/runtime/tests/automl_runtime/forecast/prophet/model_test.py similarity index 100% rename from runtime/tests/automl_runtime/prophet/model_test.py rename to runtime/tests/automl_runtime/forecast/prophet/model_test.py
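--

Usage sketch for the modules added in this patch. This is illustrative only: the
toy data, the chosen horizon/frequency, and the variable names below are
hypothetical, not taken from the patch or its tests.

    import mlflow
    import pandas as pd

    from databricks.automl_runtime.forecast.pmdarima.training import ArimaEstimator
    from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel
    from databricks.automl_runtime.forecast.pmdarima.utils import mlflow_arima_log_model, plot

    # Toy daily history with the required "ds" and "y" columns
    history_pd = pd.DataFrame({
        "ds": pd.date_range("2020-01-01", periods=60, freq="D"),
        "y": range(60),
    })

    # Tune the seasonal period m and fit via pmdarima.auto_arima
    estimator = ArimaEstimator(horizon=7, frequency_unit="d", metric="smape",
                               seasonal_periods=[1, 7], num_folds=2)
    results_pd = estimator.fit(history_pd)

    # Wrap the best (pickled) model for mlflow pyfunc logging and serving
    arima_model = ArimaModel(results_pd["pickled_model"][0], horizon=7, frequency="d",
                             start_ds=history_pd["ds"].min(), end_ds=history_pd["ds"].max(),
                             time_col="ds")
    with mlflow.start_run() as run:
        mlflow_arima_log_model(arima_model)

    # In-sample fit plus 7 future periods, with confidence intervals
    forecast_pd = arima_model.predict_timeseries()
    fig = plot(history_pd, forecast_pd, include_legend=True)

    # Point predictions for arbitrary timestamps through the pyfunc predict API
    loaded_model = mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")
    yhat = loaded_model.predict(pd.DataFrame({"ds": [pd.Timestamp("2020-03-05")]}))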