diff --git a/TODO.md b/TODO.md index c7c29028..45ece226 100644 --- a/TODO.md +++ b/TODO.md @@ -13,12 +13,14 @@ * Forecasts are desired for the future immediately following the most recent data. * trimmed_mean to AverageValueNaive -# 0.6.11 🇺🇦 🇺🇦 🇺🇦 +# 0.6.12 🇺🇦 🇺🇦 🇺🇦 * bug fixes -* continually trying to keep up with the Pandas maintainers who are breaking stuff for no good reasonable -* updated RollingMeanTransformer and RegressionFilter, RegressionFilter should now be less memory intensive -* EIA data call to load_live_daily -* horizontal_ensemble_validation arg for more complete validation on these ensembles +* added DMD model +* modified the `constraints` options so it now accepts of list of dictionaries of constraints with new last_window and slope options +* 'dampening' as a constraint method to dampen all forecasts, fixed Cassandra trend_phi dampening +* new med_diff anomaly method and 'laplace' added as distribution option +* modified fourier_df to now work with sub daily data +* some madness with wavelets attempting to use them like fourier series for seasonality ### Unstable Upstream Pacakges (those that are frequently broken by maintainers) * Pytorch-Forecasting diff --git a/autots/__init__.py b/autots/__init__.py index 14358f00..bde6a9e1 100644 --- a/autots/__init__.py +++ b/autots/__init__.py @@ -27,7 +27,7 @@ from autots.models.cassandra import Cassandra -__version__ = '0.6.11' +__version__ = '0.6.12' TransformTS = GeneralTransformer diff --git a/autots/evaluator/auto_model.py b/autots/evaluator/auto_model.py index 0bd59c55..9438bd73 100644 --- a/autots/evaluator/auto_model.py +++ b/autots/evaluator/auto_model.py @@ -61,7 +61,7 @@ DynamicFactorMQ, ) from autots.models.arch import ARCH -from autots.models.matrix_var import RRVAR, MAR, TMF, LATC +from autots.models.matrix_var import RRVAR, MAR, TMF, LATC, DMD def create_model_id( @@ -698,6 +698,17 @@ def ModelMonster( n_jobs=n_jobs, **parameters, ) + elif model == 'DMD': + return DMD( + frequency=frequency, + prediction_interval=prediction_interval, + holiday_country=holiday_country, + random_seed=random_seed, + verbose=verbose, + forecast_length=forecast_length, + n_jobs=n_jobs, + **parameters, + ) elif model == "": raise AttributeError( ("Model name is empty. Likely this means AutoTS has not been fit.") @@ -864,7 +875,7 @@ def predict(self, forecast_length=None, future_regressor=None): if not self._fit_complete: raise ValueError("Model not yet fit.") df_forecast = self.model.predict( - forecast_length=self.forecast_length, future_regressor=future_regressor + forecast_length=forecast_length, future_regressor=future_regressor ) # THIS CHECKS POINT FORECAST FOR NULLS BUT NOT UPPER/LOWER FORECASTS @@ -896,11 +907,13 @@ def predict(self, forecast_length=None, future_regressor=None): # CHECK Forecasts are proper length! if df_forecast.forecast.shape[0] != self.forecast_length: raise ValueError( - f"Model {self.model_str} returned improper forecast_length" + f"Model {self.model_str} returned improper forecast_length. Returned: {df_forecast.forecast.shape[0]} and requested: {self.forecast_length}" ) if df_forecast.forecast.shape[1] != self.df.shape[1]: - raise ValueError("Model failed to return correct number of series.") + raise ValueError( + f"Model failed to return correct number of series. 
Returned {df_forecast.forecast.shape[1]} and requested: {self.df.shape[1]}" + ) df_forecast.transformation_parameters = self.transformation_dict # Remove negatives if desired @@ -911,33 +924,53 @@ def predict(self, forecast_length=None, future_regressor=None): df_forecast.upper_forecast = df_forecast.upper_forecast.clip(lower=0) if self.constraint is not None: - if isinstance(self.constraint, dict): - constraint_method = self.constraint.get("constraint_method", "quantile") - constraint_regularization = self.constraint.get( - "constraint_regularization", 1 + if isinstance(self.constraint, list): + constraints = self.constraint + df_forecast = df_forecast.apply_constraints( + constraints=constraints, + df_train=self.df, ) - lower_constraint = self.constraint.get("lower_constraint", 0) - upper_constraint = self.constraint.get("upper_constraint", 1) - bounds = self.constraint.get("bounds", False) else: - constraint_method = "stdev_min" - lower_constraint = float(self.constraint) - upper_constraint = float(self.constraint) - constraint_regularization = 1 - bounds = False - if self.verbose > 3: - print( - f"Using constraint with method: {constraint_method}, {constraint_regularization}, {lower_constraint}, {upper_constraint}, {bounds}" - ) + constraints = None + if isinstance(self.constraint, dict): + if "constraints" in self.constraint.keys(): + constraints = self.constraint.get("constraints") + constraint_method = None + constraint_regularization = None + lower_constraint = None + upper_constraint = None + bounds = True + else: + constraint_method = self.constraint.get( + "constraint_method", "quantile" + ) + constraint_regularization = self.constraint.get( + "constraint_regularization", 1 + ) + lower_constraint = self.constraint.get("lower_constraint", 0) + upper_constraint = self.constraint.get("upper_constraint", 1) + bounds = self.constraint.get("bounds", False) + else: + constraint_method = "stdev_min" + lower_constraint = float(self.constraint) + upper_constraint = float(self.constraint) + constraint_regularization = 1 + bounds = False + if self.verbose > 3: + print( + f"Using constraint with method: {constraint_method}, {constraint_regularization}, {lower_constraint}, {upper_constraint}, {bounds}" + ) - df_forecast = df_forecast.apply_constraints( - constraint_method, - constraint_regularization, - upper_constraint, - lower_constraint, - bounds, - self.df, - ) + print(constraints) + df_forecast = df_forecast.apply_constraints( + constraints, + self.df, + constraint_method, + constraint_regularization, + upper_constraint, + lower_constraint, + bounds, + ) self.transformation_runtime = self.transformation_runtime + ( datetime.datetime.now() - transformationStartTime @@ -966,6 +999,18 @@ def fit_data(self, df, future_regressor=None): self.df = df self.model.fit_data(df, future_regressor) + def fit_predict( + self, + df, + forecast_length, + future_regressor_train=None, + future_regressor_forecast=None, + ): + self.fit(df, future_regressor=future_regressor_train) + return self.predict( + forecast_length=forecast_length, future_regressor=future_regressor_forecast + ) + class TemplateEvalObject(object): """Object to contain all the failures!. 
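(Editorial example, not part of the diff.) The rewritten constraint handling in predict() above accepts either the old single-method arguments or the new list-of-dictionaries style noted in the changelog. A minimal sketch of the new style at the PredictionObject level; `prediction` and `df_train` are placeholder names for a returned PredictionObject and the wide training frame:

    constraints = [
        {   # hard cap at each series' historic maximum
            "constraint_method": "quantile",
            "constraint_value": 1.0,
            "constraint_direction": "upper",
            "constraint_regularization": 1.0,
            "bounds": True,
        },
        {   # soft floor roughly 10% below the mean of the last 10 observed values
            "constraint_method": "last_window",
            "constraint_value": {"window": 10, "window_agg": "mean", "threshold": -0.1},
            "constraint_direction": "lower",
            "constraint_regularization": 0.5,
            "bounds": False,
        },
    ]
    prediction = prediction.apply_constraints(constraints=constraints, df_train=df_train)

The same list can presumably also be passed as the higher-level `constraint` argument, which is the branch exercised in the predict() code above.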
@@ -2119,7 +2164,9 @@ def NewGeneticTemplate( # filter existing templates sorted_results = model_results[ - (model_results['Ensemble'] == 0) & (model_results['Exceptions'].isna()) + (model_results['Ensemble'] == 0) + & (model_results['Exceptions'].isna()) + & (model_results['Model'].isin(model_list)) ].copy() # remove duplicates by exact same performance sorted_results = sorted_results.sort_values( diff --git a/autots/evaluator/auto_ts.py b/autots/evaluator/auto_ts.py index 65fc1dfd..2365ba23 100644 --- a/autots/evaluator/auto_ts.py +++ b/autots/evaluator/auto_ts.py @@ -1075,6 +1075,7 @@ def fit_data( preclean=None, verbose=0, ) + return self def fit( self, @@ -1826,8 +1827,10 @@ def _run_template( self.model_count = template_result.model_count # capture results from lower-level template run if "TotalRuntime" in template_result.model_results.columns: - template_result.model_results['TotalRuntime'].fillna( - pd.Timedelta(seconds=60), inplace=True + template_result.model_results['TotalRuntime'] = ( + template_result.model_results['TotalRuntime'].fillna( + pd.Timedelta(seconds=60) + ) ) else: # trying to catch a rare and sneaky bug (perhaps some variety of beetle?) @@ -2161,9 +2164,13 @@ def results(self, result_set: str = 'initial'): result_set (str): 'validation' or 'initial' """ if result_set == 'validation': - return self.validation_results.model_results + return self.validation_results.model_results.sort_values( + "Score", ascending=True + ) else: - return self.initial_results.model_results + return self.initial_results.model_results.sort_values( + "Score", ascending=True + ) def failure_rate(self, result_set: str = 'initial'): """Return fraction of models passing with exceptions. @@ -2280,6 +2287,22 @@ def export_template( export_template = unpack_ensemble_models( export_template, self.template_cols, keep_ensemble=False, recursive=True ).drop_duplicates() + if include_results: + export_template = export_template.drop(columns=['smape']).merge( + self.validation_results.model_results[['ID', 'smape']], + on="ID", + how='left', + ) + # put smape back in the front + remaining_columns = [ + col + for col in export_template.columns + if col not in self.template_cols_id and col not in ['smape', 'Runs'] + ] + new_order = ( + self.template_cols_id + ['Runs', 'smape'] + remaining_columns + ) + export_template = export_template.reindex(columns=new_order) return self.save_template(filename, export_template) def save_template(self, filename, export_template, **kwargs): diff --git a/autots/models/base.py b/autots/models/base.py index 06fa8b0a..00b4de04 100644 --- a/autots/models/base.py +++ b/autots/models/base.py @@ -113,66 +113,158 @@ def time(): return datetime.datetime.now() -def apply_constraints( +def constant_growth_rate(periods, final_growth): + """Take a final target growth rate (ie 2 % over a year) and convert to a daily (etc) value.""" + # Convert final growth rate percentage to a growth factor + final_growth_factor = 1 + final_growth + + # Calculate the daily growth factor required to achieve the final growth factor in the given days + daily_growth_factor = final_growth_factor ** (1 / periods) + + # Generate an array of day indices + day_indices = np.arange(1, periods + 1) + + # Calculate the cumulative growth factor for each day + cumulative_growth_factors = daily_growth_factor**day_indices + + # Calculate the perceived growth rates relative to the starting value + perceived_growth_rates = cumulative_growth_factors - 1 + return perceived_growth_rates + + +def apply_constraint_single( 
forecast, lower_forecast, upper_forecast, constraint_method, - constraint_regularization, - upper_constraint, - lower_constraint, - bounds, + constraint_value, + constraint_direction='upper', + constraint_regularization=1.0, + bounds=True, df_train=None, ): - """Use constraint thresholds to adjust outputs by limit. - Note that only one method of constraint can be used here, but if different methods are desired, - this can be run twice, with None passed to the upper or lower constraint not being used. - - Args: - forecast (pd.DataFrame): forecast df, wide style - lower_forecast (pd.DataFrame): lower bound forecast df - if bounds is False, upper and lower forecast dataframes are unused and can be empty - upper_forecast (pd.DataFrame): upper bound forecast df - constraint_method (str): one of - stdev_min - threshold is min and max of historic data +/- constraint * st dev of data - stdev - threshold is the mean of historic data +/- constraint * st dev of data - absolute - input is array of length series containing the threshold's final value for each - quantile - constraint is the quantile of historic data to use as threshold - constraint_regularization (float): 0 to 1 - where 0 means no constraint, 1 is hard threshold cutoff, and in between is penalty term - upper_constraint (float): or array, depending on method, None if unused - lower_constraint (float): or array, depending on method, None if unused - bounds (bool): if True, apply to upper/lower forecast, otherwise False applies only to forecast - df_train (pd.DataFrame): required for quantile/stdev methods to find threshold values - - Returns: - forecast, lower, upper (pd.DataFrame) - """ + # check if training data provided + if df_train is None and constraint_method in [ + "quantile", + "stdev", + "stdev_min", + "last_window", + "slope", + ]: + raise ValueError("this constraint requires df_train to be provided") + # set direction + lower_constraint = None + upper_constraint = None + if constraint_direction == "lower": + lower_constraint = True + elif constraint_direction == "upper": + upper_constraint = True + else: + raise ValueError(f"constraint_direction: {constraint_direction} invalid") if constraint_method == "stdev_min": train_std = df_train.std(axis=0) if lower_constraint is not None: - train_min = df_train.min(axis=0) - (lower_constraint * train_std) + train_min = df_train.min(axis=0) - (constraint_value * train_std) if upper_constraint is not None: - train_max = df_train.max(axis=0) + (upper_constraint * train_std) + train_max = df_train.max(axis=0) + (constraint_value * train_std) elif constraint_method == "stdev": train_std = df_train.std(axis=0) train_mean = df_train.mean(axis=0) if lower_constraint is not None: - train_min = train_mean - (lower_constraint * train_std) + train_min = train_mean - (constraint_value * train_std) if upper_constraint is not None: - train_max = train_mean + (upper_constraint * train_std) - elif constraint_method == "absolute": - train_min = lower_constraint - train_max = upper_constraint + train_max = train_mean + (constraint_value * train_std) + elif constraint_method in ["absolute", "fixed"]: + train_min = constraint_value + train_max = constraint_value elif constraint_method == "quantile": if lower_constraint is not None: - train_min = df_train.quantile(lower_constraint, axis=0) + train_min = df_train.quantile(constraint_value, axis=0) if upper_constraint is not None: - train_max = df_train.quantile(upper_constraint, axis=0) + train_max = df_train.quantile(constraint_value, axis=0) + elif 
constraint_method == "last_window": + if isinstance(constraint_value, dict): + window = constraint_value.get("window", 3) + window_agg = constraint_value.get("window_agg", "mean") + threshold = constraint_value.get("threshold", 0.05) + else: + window = 1 + window_agg = "mean" + threshold = constraint_value + if window_agg == "mean": + end_o_data = df_train.iloc[-window:].mean(axis=0) + elif window_agg == "max": + end_o_data = df_train.iloc[-window:].max(axis=0) + elif window_agg == "min": + end_o_data = df_train.iloc[-window:].min(axis=0) + else: + raise ValueError(f"constraint window_agg not recognized: {window_agg}") + train_min = train_max = end_o_data + end_o_data * threshold + elif constraint_method == "slope": + if isinstance(constraint_value, dict): + window = constraint_value.get("window", 1) + window_agg = constraint_value.get("window_agg", "mean") + slope = constraint_value.get("slope", 0.05) + threshold = constraint_value.get("threshold", None) + else: + window = 1 + window_agg = "mean" + slope = constraint_value + threshold = None + # slope is given as a final max growth, NOT compounding + changes = constant_growth_rate(forecast.shape[0], slope) + if window_agg == "mean": + end_o_data = df_train.iloc[-window:].mean(axis=0) + elif window_agg == "max": + end_o_data = df_train.iloc[-window:].max(axis=0) + elif window_agg == "min": + end_o_data = df_train.iloc[-window:].min(axis=0) + else: + raise ValueError(f"constraint window_agg not recognized: {window_agg}") + # have the slope start above a threshold to allow more volatility + if threshold is not None: + end_o_data = end_o_data + end_o_data * threshold + train_min = train_max = ( + end_o_data.to_numpy() + + end_o_data.to_numpy()[np.newaxis, :] * changes[:, np.newaxis] + ) + elif constraint_method == "dampening": + pass else: - raise ValueError("constraint_method not recognized, adjust constraint") + raise ValueError( + f"constraint_method {constraint_method} not recognized, adjust constraint" + ) - if constraint_regularization == 1: + if constraint_method == "dampening": + # the idea is to make the forecast plateau by gradually forcing the step to step change closer to zero + trend_phi = constraint_value + if trend_phi is not None and trend_phi != 1 and forecast.shape[0] > 2: + req_len = forecast.shape[0] - 1 + phi_series = pd.Series( + [trend_phi] * req_len, + index=forecast.index[1:], + ).pow(range(req_len)) + + # adjust all by same margin + forecast = pd.concat( + [forecast.iloc[0:1], forecast.diff().iloc[1:].mul(phi_series, axis=0)] + ).cumsum() + + if bounds: + lower_forecast = pd.concat( + [ + lower_forecast.iloc[0:1], + lower_forecast.diff().iloc[1:].mul(phi_series, axis=0), + ] + ).cumsum() + upper_forecast = pd.concat( + [ + upper_forecast.iloc[0:1], + upper_forecast.diff().iloc[1:].mul(phi_series, axis=0), + ] + ).cumsum() + return forecast, lower_forecast, upper_forecast + if constraint_regularization == 1 or constraint_regularization is None: if lower_constraint is not None: forecast = forecast.clip(lower=train_min, axis=1) if upper_constraint is not None: @@ -186,48 +278,124 @@ def apply_constraints( upper_forecast = upper_forecast.clip(upper=train_max, axis=1) else: if lower_constraint is not None: - forecast.where( + forecast = forecast.where( forecast >= train_min, forecast + (train_min - forecast) * constraint_regularization, - inplace=True, ) if upper_constraint is not None: - forecast.where( + forecast = forecast.where( forecast <= train_max, forecast + (train_max - forecast) * constraint_regularization, - 
inplace=True, ) if bounds: if lower_constraint is not None: - lower_forecast.where( + lower_forecast = lower_forecast.where( lower_forecast >= train_min, lower_forecast + (train_min - lower_forecast) * constraint_regularization, - inplace=True, ) - upper_forecast.where( + upper_forecast = upper_forecast.where( upper_forecast >= train_min, upper_forecast + (train_min - upper_forecast) * constraint_regularization, - inplace=True, ) if upper_constraint is not None: - lower_forecast.where( + lower_forecast = lower_forecast.where( lower_forecast <= train_max, lower_forecast + (train_max - lower_forecast) * constraint_regularization, - inplace=True, ) - upper_forecast.where( + upper_forecast = upper_forecast.where( upper_forecast <= train_max, upper_forecast + (train_max - upper_forecast) * constraint_regularization, - inplace=True, ) return forecast, lower_forecast, upper_forecast +def apply_constraints( + forecast, + lower_forecast, + upper_forecast, + constraints=None, + df_train=None, + # old args + constraint_method=None, + constraint_regularization=None, + upper_constraint=None, + lower_constraint=None, + bounds=True, +): + """Use constraint thresholds to adjust outputs by limit. + + Args: + forecast (pd.DataFrame): forecast df, wide style + lower_forecast (pd.DataFrame): lower bound forecast df + if bounds is False, upper and lower forecast dataframes are unused and can be empty + upper_forecast (pd.DataFrame): upper bound forecast df + constraints (list): list of dictionaries of constraints to apply + keys: "constraint_method" (same as below, old args), "constraint_regularization", "constraint_value", "constraint_direction" (upper/lower), bounds + df_train (pd.DataFrame): required for quantile/stdev methods to find threshold values + # old args + constraint_method (str): one of + stdev_min - threshold is min and max of historic data +/- constraint * st dev of data + stdev - threshold is the mean of historic data +/- constraint * st dev of data + absolute - input is array of length series containing the threshold's final value for each + quantile - constraint is the quantile of historic data to use as threshold + last_window - certain percentage above and below the last n data values + slope - cannot exceed a certain growth rate from last historical value + constraint_regularization (float): 0 to 1 + where 0 means no constraint, 1 is hard threshold cutoff, and in between is penalty term + upper_constraint (float): or array, depending on method, None if unused + lower_constraint (float): or array, depending on method, None if unused + bounds (bool): if True, apply to upper/lower forecast, otherwise False applies only to forecast + + Returns: + forecast, lower, upper (pd.DataFrame) + """ + # handle old style + if constraint_method is not None: + if constraints is not None: + raise ValueError( + f"both constraint_method (old way) and constraints (new way) args passed, this will not work. 
Constraints: {constraints}" + ) + else: + constraints = [] + if upper_constraint is not None: + constraints.append( + { + "constraint_method": constraint_method, + "constraint_value": upper_constraint, + "constraint_direction": "upper", + "constraint_regularization": constraint_regularization, + "bounds": bounds, + } + ) + if lower_constraint is not None: + constraints.append( + { + "constraint_method": constraint_method, + "constraint_value": lower_constraint, + "constraint_direction": "lower", + "constraint_regularization": constraint_regularization, + "bounds": bounds, + } + ) + print(constraints) + if constraints is None or not constraints: + print("no constraint applied") + return forecast, lower_forecast, upper_forecast + if isinstance(constraints, dict): + constraints = [constraints] + for constraint in constraints: + forecast, lower_forecast, upper_forecast = apply_constraint_single( + forecast, lower_forecast, upper_forecast, df_train=df_train, **constraint + ) + + return forecast, lower_forecast, upper_forecast + + def extract_single_series_from_horz(series, model_name, model_parameters): title_prelim = str(model_name)[0:80] if title_prelim == "Ensemble": @@ -877,16 +1045,64 @@ def evaluate( def apply_constraints( self, - constraint_method="quantile", - constraint_regularization=0.5, - upper_constraint=1.0, - lower_constraint=0.0, - bounds=True, + constraints=None, df_train=None, + # old args + constraint_method=None, + constraint_regularization=None, + upper_constraint=None, + lower_constraint=None, + bounds=True, ): """Use constraint thresholds to adjust outputs by limit. - Note that only one method of constraint can be used here, but if different methods are desired, - this can be run twice, with None passed to the upper or lower constraint not being used. 
+ + Example: + apply_constraints( + constraints=[ + { # don't exceed historic max + "constraint_method": "quantile", + "constraint_value": 1.0, + "constraint_direction": "upper", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't exceed 2% growth by end of forecast horizon + "constraint_method": "slope", + "constraint_value": {"slope": 0.02, "window": 10, "window_agg": "max", "threshold": 0.01}, + "constraint_direction": "upper", + "constraint_regularization": 0.9, + "bounds": False, + }, + { # don't go below the last 10 values - 10% + "constraint_method": "last_window", + "constraint_value": {"window": 10, "threshold": -0.1}, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": False, + }, + { # don't go below zero + "constraint_method": "absolute", + "constraint_value": 0, # can also be an array or Series + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't go below historic min - 1 st dev + "constraint_method": "stdev_min", + "constraint_value": 1.0, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't go above historic mean + 3 st devs, soft limit + "constraint_method": "stdev", + "constraint_value": 3.0, + "constraint_direction": "upper", + "constraint_regularization": 0.5, + "bounds": True, + }, + ] + ) Args: constraint_method (str): one of @@ -908,11 +1124,13 @@ def apply_constraints( self.forecast, self.lower_forecast, self.upper_forecast, - constraint_method, - constraint_regularization, - upper_constraint, - lower_constraint, - bounds, - df_train, + constraints=constraints, + df_train=df_train, + # old args + constraint_method=constraint_method, + constraint_regularization=constraint_regularization, + upper_constraint=upper_constraint, + lower_constraint=lower_constraint, + bounds=bounds, ) return self diff --git a/autots/models/basics.py b/autots/models/basics.py index fe091c24..3de3ce03 100644 --- a/autots/models/basics.py +++ b/autots/models/basics.py @@ -2675,6 +2675,12 @@ def predict( k = self.k full_sort = self.point_method == "closest" + if forecast_length >= self.df.shape[0]: + self.independent = True + if self.verbose > 0: + print( + "prediction too long for indepedent=False, falling back on indepdent=True" + ) if self.independent: # each timestep is considered individually and not as a series test, scores = seasonal_independent_match( @@ -2802,6 +2808,7 @@ def get_params(self): "distance_metric": self.distance_metric, "k": self.k, "datepart_method": self.datepart_method, + "independent": self.independent, } @@ -3133,7 +3140,7 @@ def get_new_params(self, method: str = 'random'): ["weighted_mean", "mean", "median", "midhinge", "closest"], [0.4, 0.2, 0.2, 0.2, 0.2], )[0], - "distance_metric": random.choices(metric_list, metric_probabilities), + "distance_metric": random.choices(metric_list, metric_probabilities)[0], "k": k_choice, "sample_fraction": sample_fraction, } diff --git a/autots/models/cassandra.py b/autots/models/cassandra.py index 26bfb82b..3db79902 100644 --- a/autots/models/cassandra.py +++ b/autots/models/cassandra.py @@ -1528,18 +1528,40 @@ def predict( n_jobs=self.n_jobs, ) # phi is on future predict step only - if self.trend_phi is not None and self.trend_phi != 1: - temp = trend_forecast.forecast.mul( - pd.Series( - [self.trend_phi] * trend_forecast.forecast.shape[0], - index=trend_forecast.forecast.index, - ).pow(range(trend_forecast.forecast.shape[0])), - axis=0, - ) + if ( + self.trend_phi is not 
None + and self.trend_phi != 1 + and trend_forecast.forecast.shape[0] > 2 + ): + req_len = trend_forecast.forecast.shape[0] - 1 + phi_series = pd.Series( + [self.trend_phi] * req_len, + index=trend_forecast.forecast.index[1:], + ).pow(range(req_len)) + # adjust all by same margin - trend_forecast.forecast = trend_forecast.forecast + temp - trend_forecast.upper_forecast = trend_forecast.upper_forecast + temp - trend_forecast.lower_forecast = trend_forecast.lower_forecast + temp + trend_forecast.forecast = pd.concat( + [ + trend_forecast.forecast.iloc[0:1], + trend_forecast.forecast.diff().iloc[1:].mul(phi_series, axis=0), + ] + ).cumsum() + trend_forecast.upper_forecast = pd.concat( + [ + trend_forecast.upper_forecast.iloc[0:1], + trend_forecast.upper_forecast.diff() + .iloc[1:] + .mul(phi_series, axis=0), + ] + ).cumsum() + trend_forecast.lower_forecast = pd.concat( + [ + trend_forecast.lower_forecast.iloc[0:1], + trend_forecast.lower_forecast.diff() + .iloc[1:] + .mul(phi_series, axis=0), + ] + ).cumsum() if include_history: trend_forecast.forecast = pd.concat( [ @@ -1759,32 +1781,11 @@ def predict( df_forecast.upper_forecast = df_forecast.upper_forecast * impts if self.constraint is not None: - if isinstance(self.constraint, dict): - constraint_method = self.constraint.get("constraint_method", "quantile") - constraint_regularization = self.constraint.get( - "constraint_regularization", 1 - ) - lower_constraint = self.constraint.get("lower_constraint", 0) - upper_constraint = self.constraint.get("upper_constraint", 1) - bounds = self.constraint.get("bounds", False) - else: - constraint_method = "stdev_min" - lower_constraint = float(self.constraint) - upper_constraint = float(self.constraint) - constraint_regularization = 1 - bounds = False - if self.verbose >= 3: - print( - f"Using constraint with method: {constraint_method}, {constraint_regularization}, {lower_constraint}, {upper_constraint}, {bounds}" - ) - + # print(f"constraint is {self.constraint}") + # this might work out weirdly since self.df is scaled df_forecast = df_forecast.apply_constraints( - constraint_method, - constraint_regularization, - upper_constraint, - lower_constraint, - bounds, - self.df_original, + **self.constraint, + df_train=self.to_origin_space(self.df, trans_method="original"), ) # RETURN COMPONENTS (long style) option df_forecast.predict_runtime = self.time() - predictStartTime @@ -2140,7 +2141,7 @@ def get_new_params(self, method='fast'): # "trend_anomaly_intervention": trend_anomaly_intervention, "trend_transformation": trend_transformation, "trend_model": trend_model, - "trend_phi": random.choices([None, 0.98], [0.9, 0.1])[0], + "trend_phi": random.choices([None, 0.995, 0.98], [0.9, 0.05, 0.1])[0], } def get_params(self): @@ -2792,7 +2793,7 @@ def sample_posterior(self, n_samples=1): df_daily = load_daily(long=False) # add nan df_daily.iloc[100, :] = np.nan - forecast_length = 180 + forecast_length = 240 include_history = True df_train = df_daily[:-forecast_length].iloc[:, 1:] df_test = df_daily[-forecast_length:].iloc[:, 1:] @@ -2816,21 +2817,21 @@ def sample_posterior(self, n_samples=1): np.random.normal(size=(forecast_length, 1)), index=df_test.index ) } - constraint = { - 'constraint_method': 'quantile', - 'lower_constraint': 0, - 'upper_constraint': None, - "bounds": True, - } - past_impacts = pd.DataFrame(0, index=df_train.index, columns=df_train.columns) + constraint = None + past_impacts = pd.DataFrame( + 0, index=df_train.index, columns=df_train.columns + ).astype(float) past_impacts.iloc[-10:, 
0] = np.geomspace(1, 10)[0:10] / 100 past_impacts.iloc[-30:, -1] = np.linspace(1, 10)[0:30] / -100 past_impacts_full = pd.DataFrame(0, index=df_daily.index, columns=df_daily.columns) - future_impacts = pd.DataFrame(0, index=df_test.index, columns=df_test.columns) + future_impacts = pd.DataFrame( + 0, index=df_test.index, columns=df_test.columns + ).astype(float) future_impacts.iloc[0:10, 0] = (np.linspace(1, 10)[0:10] + 10) / 100 c_params = Cassandra().get_new_params() c_params['regressors_used'] = False + # c_params['trend_phi'] = 0.9 mod = Cassandra( n_jobs=1, diff --git a/autots/models/matrix_var.py b/autots/models/matrix_var.py index 0e852129..ab611861 100644 --- a/autots/models/matrix_var.py +++ b/autots/models/matrix_var.py @@ -968,3 +968,218 @@ def get_params(self): 'alpha': self.alpha, 'maxiter': self.maxiter, } + + +def _DMD( + data, + r, + alpha=0.0, + amplitude_threshold=None, + eigenvalue_threshold=None, + ecr_threshold=0.95, +): + X1 = data[:, :-1] + X2 = data[:, 1:] + u, s, v = np.linalg.svd(X1, full_matrices=False) + if r in ['ecr', 'auto']: + total_energy = np.sum(s**2) + # Calculate captured energy for each singular value + captured_energy = np.cumsum(s**2) / total_energy + r = np.searchsorted(captured_energy, ecr_threshold) + print(f"ECR rank is {r}") + elif r > 0 and r < 1: + r = int(data.shape[0] * r) + # print(f"Rational rank is {r}") + + regularized_s = s[:r] + alpha + A_tilde = u[:, :r].conj().T @ X2 @ v[:r, :].conj().T * np.reciprocal(regularized_s) + Phi, Q = np.linalg.eig(A_tilde) + + if amplitude_threshold is not None: + # Calculate mode amplitudes + b = np.linalg.pinv(Q) @ u[:, :r].conj().T @ X1[:, 0] + amplitudes = np.abs(b) + amp_filter = amplitudes > amplitude_threshold + else: + amp_filter = np.ones_like(Phi, dtype=bool) + + if eigenvalue_threshold is not None: + # Calculate eigenvalue magnitudes + eigenvalue_magnitudes = np.abs(Phi) + eigen_filter = eigenvalue_magnitudes <= eigenvalue_threshold + else: + eigen_filter = np.ones_like(Phi, dtype=bool) + + # Filter modes based on amplitudes and eigenvalue magnitudes + filter_mask = amp_filter & eigen_filter + Phi = Phi[filter_mask] + Q = Q[:, filter_mask] + + # Reconstruct dynamics with filtered modes + Psi = X2 @ v[:r, :].conj().T @ np.diag(np.reciprocal(regularized_s)) @ Q + A = Psi @ np.diag(Phi) @ np.linalg.pinv(Psi) + return A_tilde, Phi, A + + +def dmd_forecast( + data, r, pred_step, alpha=0.0, amplitude_threshold=None, eigenvalue_threshold=None +): + N, T = data.shape + _, _, A = _DMD( + data, + r, + alpha, + amplitude_threshold=amplitude_threshold, + eigenvalue_threshold=eigenvalue_threshold, + ) + mat = np.append(data, np.zeros((N, pred_step)), axis=1) + for s in range(pred_step): + mat[:, T + s] = (A @ mat[:, T + s - 1]).real + return mat[:, -pred_step:] + + +class DMD(ModelObject): + """Dynamic Mode Decomposition + + Args: + name (str): String to identify class + frequency (str): String alias of datetime index frequency or else 'infer' + prediction_interval (float): Confidence interval for probabilistic forecast + regression_type (str): type of regression (None, 'User', or 'Holiday') + n_jobs (int): passed to joblib for multiprocessing. Set to none for context manager. 
+ + """ + + def __init__( + self, + name: str = "DMD", + frequency: str = 'infer', + prediction_interval: float = 0.9, + alpha: float = 0.0, + rank: float = 0.1, + amplitude_threshold: float = None, + eigenvalue_threshold: float = None, + holiday_country: str = 'US', + random_seed: int = 2022, + verbose: int = 0, + n_jobs: int = None, + **kwargs, + ): + ModelObject.__init__( + self, + name, + frequency, + prediction_interval, + holiday_country=holiday_country, + random_seed=random_seed, + verbose=verbose, + n_jobs=n_jobs, + ) + self.alpha = alpha + self.rank = rank + self.amplitude_threshold = amplitude_threshold + self.eigenvalue_threshold = eigenvalue_threshold + + def fit(self, df, future_regressor=None): + """Train algorithm given data supplied . + + Args: + df (pandas.DataFrame): Datetime Indexed + """ + + df = self.basic_profile(df) + self.regressor_train = None + self.verbose_bool = False + if self.verbose > 1: + self.verbose_bool = True + + if isinstance(self.rank, float): + if self.rank < 1 and self.rank > 0: + self.rank = int(self.rank * df.shape[1]) + self.rank = self.rank if self.rank > 0 else 1 + + self.df_train = df + + self.fit_runtime = datetime.datetime.now() - self.startTime + return self + + def predict( + self, forecast_length: int, future_regressor=None, just_point_forecast=False + ): + """Generate forecast data immediately following dates of index supplied to .fit(). + + Args: + forecast_length (int): Number of periods of data to forecast ahead + regressor (numpy.Array): additional regressor, not used + just_point_forecast (bool): If True, return a pandas.DataFrame of just point forecasts + + Returns: + Either a PredictionObject of forecasts and metadata, or + if just_point_forecast == True, a dataframe of point forecasts + """ + predictStartTime = datetime.datetime.now() + test_index = self.create_forecast_index(forecast_length=forecast_length) + + data = self.df_train.to_numpy().T + forecast = dmd_forecast( + data, + r=self.rank, + pred_step=forecast_length, + alpha=self.alpha, + amplitude_threshold=self.amplitude_threshold, + eigenvalue_threshold=self.eigenvalue_threshold, + ).T + + forecast = pd.DataFrame(forecast, index=test_index, columns=self.column_names) + if just_point_forecast: + return forecast + else: + upper_forecast, lower_forecast = Point_to_Probability( + self.df_train, + forecast, + method='inferred_normal', + prediction_interval=self.prediction_interval, + ) + predict_runtime = datetime.datetime.now() - predictStartTime + prediction = PredictionObject( + model_name=self.name, + forecast_length=forecast_length, + forecast_index=test_index, + forecast_columns=forecast.columns, + lower_forecast=lower_forecast, + forecast=forecast, + upper_forecast=upper_forecast, + prediction_interval=self.prediction_interval, + predict_runtime=predict_runtime, + fit_runtime=self.fit_runtime, + model_parameters=self.get_params(), + ) + + return prediction + + def get_new_params(self, method: str = 'random'): + """Return dict of new parameters for parameter tuning.""" + return { + 'rank': random.choices( + [2, 3, 4, 6, 10, 0.1, 0.2, 0.5, "ecr"], + [0.4, 0.1, 0.3, 0.1, 0.1, 0.1, 0.2, 0.2, 0.6], + )[0], + 'alpha': random.choice([0.0, 0.001, 0.1, 1]), + 'amplitude_threshold': random.choices( + [None, 0.1, 1, 10], + [0.7, 0.1, 0.1, 0.1], + )[0], + 'eigenvalue_threshold': random.choices( + [None, 0.1, 1, 10], + [0.7, 0.1, 0.1, 0.1], + )[0], + } + + def get_params(self): + """Return dict of current parameters.""" + return { + 'rank': self.rank, + 'alpha': self.alpha, + 
'amplitude_threshold': self.amplitude_threshold, + 'eigenvalue_threshold': self.eigenvalue_threshold, + } diff --git a/autots/models/model_list.py b/autots/models/model_list.py index 15b1aab4..b6a3984b 100644 --- a/autots/models/model_list.py +++ b/autots/models/model_list.py @@ -47,6 +47,7 @@ "BallTreeMultivariateMotif", "TiDE", "NeuralForecast", + "DMD", ] all_pragmatic = list((set(all_models) - set(['MLEnsemble', 'VARMAX', 'Greykite']))) # downweight slower models @@ -60,23 +61,27 @@ 'ETS': 1, 'FBProphet': 0.5, # 'GluonTS': 0.5, - 'UnobservedComponents': 1, + 'UnobservedComponents': 0.6, 'VAR': 1, 'VECM': 1, - 'ARIMA': 0.4, - 'WindowRegression': 0.5, + 'ARIMA': 0.3, + 'WindowRegression': 0.8, 'DatepartRegression': 1, - 'UnivariateRegression': 0.3, + # 'UnivariateRegression': 0.1, 'MultivariateRegression': 0.4, 'UnivariateMotif': 1, 'MultivariateMotif': 1, 'SectionalMotif': 1, - 'NVAR': 1, + 'NVAR': 0.4, 'Theta': 1, 'ARDL': 1, 'ARCH': 1, 'MetricMotif': 1, 'SeasonalityMotif': 1, + 'DMD': 0.3, + 'RRVAR': 0.8, + 'FFT': 0.8, + 'Cassandra': 0.8, } # fastest models at any scale superfast = [ @@ -223,6 +228,7 @@ 'BallTreeMultivariateMotif', "TiDE", "NeuralForecast", + "DMD", ] univariate = list((set(all_models) - set(multivariate)) - set(experimental)) # USED IN AUTO_MODEL, models with no parameters @@ -269,6 +275,7 @@ 'FFT', 'BallTreeMultivariateMotif', "TiDE", + "DMD", ] # USED IN AUTO_MODEL for models that don't share information among series no_shared = [ diff --git a/autots/models/sklearn.py b/autots/models/sklearn.py index 485f8412..84954c55 100644 --- a/autots/models/sklearn.py +++ b/autots/models/sklearn.py @@ -434,7 +434,7 @@ def retrieve_regressor( verbosity=0, **model_param_dict, n_jobs=smaller_n_jobs ) return regr - elif model_class == 'SVM': + elif model_class in ['SVM', "LinearSVR"]: from sklearn.svm import LinearSVR if multioutput: @@ -640,7 +640,7 @@ def retrieve_classifier( 'DecisionTree': 0.05, 'KNN': 0.05, 'Adaboost': 0.03, - 'SVM': 0.03, + 'SVM': 0.01, # 'BayesianRidge': 0.05, 'xgboost': 0.09, # 'KerasRNN': 0.01, # too slow on big data @@ -677,7 +677,7 @@ def retrieve_classifier( 'DecisionTree': 0.05, 'KNN': 0.05, 'Adaboost': 0.03, - 'SVM': 0.05, + 'SVM': 0.02, 'KerasRNN': 0.02, 'LightGBM': 0.09, 'LightGBMRegressorChain': 0.03, @@ -691,7 +691,7 @@ def retrieve_classifier( no_shared_model_dict = { 'KNN': 0.1, 'Adaboost': 0.1, - 'SVM': 0.1, + 'SVM': 0.01, 'xgboost': 0.1, 'LightGBM': 0.1, 'HistGradientBoost': 0.1, @@ -703,7 +703,7 @@ def retrieve_classifier( 'MLP': 0.05, 'DecisionTree': 0.02, 'Adaboost': 0.05, - 'SVM': 0.01, + 'SVM': 0.001, 'KerasRNN': 0.01, # 'Transformer': 0.02, # slow, kernel failed 'RadiusNeighbors': 0.1, @@ -1287,12 +1287,15 @@ def generate_regressor_params( elif model == "SVM": # LinearSVR param_dict = { - 'C': random.choices([1.0, 0.5, 2.0, 0.25], [0.6, 0.1, 0.1, 0.1])[0], - 'tol': random.choices([1e-4, 1e-3, 1e-5], [0.6, 0.1, 0.1])[0], - "loss": random.choice( - ['epsilon_insensitive', 'squared_epsilon_insensitive'] - ), - "max_iter": random.choice([500, 1000]), + "model": 'SVM', + "model_params": { + 'C': random.choices([1.0, 0.5, 2.0, 0.25], [0.6, 0.1, 0.1, 0.1])[0], + 'tol': random.choices([1e-4, 1e-3, 1e-5], [0.6, 0.1, 0.1])[0], + "loss": random.choice( + ['epsilon_insensitive', 'squared_epsilon_insensitive'] + ), + "max_iter": random.choice([500, 1000]), + }, } else: min_samples = np.random.choice( diff --git a/autots/templates/general.py b/autots/templates/general.py index d99190ad..9f715fc2 100644 --- a/autots/templates/general.py +++ 
b/autots/templates/general.py @@ -418,11 +418,7 @@ }, "68": { 'Model': 'SeasonalityMotif', - 'ModelParameters': '''{ - "window": 5, "point_method": "weighted_mean", - "distance_metric": "mae", "k": 10, - "datepart_method": "common_fourier" - }''', + 'ModelParameters': '{"window": 5, "point_method": "weighted_mean", "distance_metric": "mae", "k": 10, "datepart_method": "common_fourier"}', 'TransformationParameters': '{"fillna": "nearest", "transformations": {"0": "AlignLastValue"}, "transformation_params": {"0": {"rows": 1, "lag": 1, "method": "multiplicative", "strength": 1.0, "first_value_only": false}}}', 'Ensemble': 0, }, @@ -465,6 +461,18 @@ 'TransformationParameters': '{"fillna": "ffill", "transformations": {"0": "MaxAbsScaler", "1": "FFTDecomposition", "2": "bkfilter"}, "transformation_params": {"0": {}, "1": {"n_harmonics": 10, "detrend": "linear"}, "2": {}}}', 'Ensemble': 0, }, + "74": { # optimized 200 minutes on initial model import on load_daily + "Model": "DMD", + 'ModelParameters': '{"rank": 10, "alpha": 1, "amplitude_threshold": null, "eigenvalue_threshold": null}"', + "TransformationParameters": '"{"fillna": "linear", "transformations": {"0": "HistoricValues", "1": "AnomalyRemoval", "2": "SeasonalDifference", "3": "AnomalyRemoval"},"transformation_params": {"0": {"window": 10}, "1": {"method": "zscore", "method_params": {"distribution": "norm", "alpha": 0.05}, "fillna": "ffill", "transform_dict": {"fillna": null, "transformations": {"0": "ClipOutliers"}, "transformation_params": {"0": {"method": "clip", "std_threshold": 6}}}, "isolated_only": false}, "2": {"lag_1": 7, "method": "Mean"}, "3": {"method": "zscore", "method_params": {"distribution": "norm", "alpha": 0.05}, "fillna": "fake_date", "transform_dict": {"transformations": {"0": "DifferencedTransformer"}, "transformation_params": {"0": {}}}, "isolated_only": false}}}', + "Ensemble": 0, + }, + "75": { # short optimization on M5 + "Model": "DMD", + "ModelParameters": "{'rank': 2, 'alpha': 1, 'amplitude_threshold': null, 'eigenvalue_threshold': 1}", + "TransformationParameters": "{'fillna': 'ffill', 'transformations': {'0': 'SeasonalDifference', '1': 'AlignLastValue', '2': 'Round', '3': 'Round', '4': 'MinMaxScaler'}, 'transformation_params': {'0': {'lag_1': 7, 'method': 'LastValue'}, '1': {'rows': 1, 'lag': 1, 'method': 'additive', 'strength': 1.0, 'first_value_only': false}, '2': {'decimals': 0, 'on_transform': false, 'on_inverse': true}, '3': {'decimals': 0, 'on_transform': false, 'on_inverse': true}, '4': {}}}", + "Ensemble": 0, + }, } general_template = pd.DataFrame.from_dict(general_template_dict, orient='index') diff --git a/autots/tools/anomaly_utils.py b/autots/tools/anomaly_utils.py index fecfa73d..f0db2498 100644 --- a/autots/tools/anomaly_utils.py +++ b/autots/tools/anomaly_utils.py @@ -34,7 +34,7 @@ from sklearn.ensemble import IsolationForest from sklearn.neighbors import LocalOutlierFactor from sklearn.covariance import EllipticEnvelope - from scipy.stats import chi2, norm, gamma, uniform + from scipy.stats import chi2, norm, gamma, uniform, laplace, cauchy except Exception: pass @@ -127,6 +127,9 @@ def zscore_survival_function( elif method == "mad": median_diff = np.abs((df - df.median(axis=0))) residual_score = median_diff / median_diff.mean(axis=0) + elif method == "med_diff": + median_diff = df.diff().median() + residual_score = (df.diff().fillna(0) / median_diff).abs() else: raise ValueError("zscore method not recognized") @@ -153,6 +156,12 @@ def zscore_survival_function( return pd.DataFrame( 
chi2.sf(residual_score, dof), index=df.index, columns=columns ) + elif distribution == "cauchy": + return pd.DataFrame( + cauchy.sf(residual_score, dof), index=df.index, columns=columns + ) + elif distribution == "laplace": + return pd.DataFrame(laplace.sf(residual_score), index=df.index, columns=columns) elif distribution == "uniform": return pd.DataFrame( uniform.sf(residual_score, dof), index=df.index, columns=columns @@ -222,7 +231,7 @@ def values_to_anomalies(df, output, threshold_method, method_params, n_jobs=1): columns=cols, ) return res, scores - elif threshold_method in ["zscore", "rolling_zscore", "mad"]: + elif threshold_method in ["zscore", "rolling_zscore", "mad", "med_diff"]: alpha = method_params.get("alpha", 0.05) distribution = method_params.get("distribution", "norm") rolling_periods = method_params.get("rolling_periods", 200) @@ -382,7 +391,7 @@ def detect_anomalies( res, scores = sk_outliers(df_anomaly, method, method_params) else: res, scores = loop_sk_outliers(df_anomaly, method, method_params, n_jobs) - elif method in ["zscore", "rolling_zscore", "mad", "minmax"]: + elif method in ["zscore", "rolling_zscore", "mad", "minmax", "med_diff"]: res, scores = values_to_anomalies(df_anomaly, output, method, method_params) elif method in ["IQR"]: iqr_thresh = method_params.get("iqr_threshold", 2.0) @@ -428,6 +437,7 @@ def detect_anomalies( "prediction_interval", # ridiculously slow "IQR", "nonparametric", + "med_diff", ] fast_methods = [ "zscore", @@ -436,6 +446,7 @@ def detect_anomalies( "minmax", "IQR", "nonparametric", + "med_diff", ] @@ -443,10 +454,12 @@ def anomaly_new_params(method='random'): if method == "deep": method_choice = random.choices( available_methods, - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1, 0.15], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1, 0.15, 0.1], )[0] elif method == "fast": - method_choice = random.choices(fast_methods, [0.4, 0.3, 0.1, 0.1, 0.4, 0.05])[0] + method_choice = random.choices( + fast_methods, [0.4, 0.3, 0.1, 0.1, 0.4, 0.05, 0.1] + )[0] elif method in available_methods: method_choice = method else: @@ -498,12 +511,25 @@ def anomaly_new_params(method='random'): elif method_choice == "rolling_zscore": method_params = { 'distribution': random.choices( - ['norm', 'gamma', 'chi2', 'uniform'], [0.4, 0.2, 0.2, 0.2] + ['norm', 'gamma', 'chi2', 'uniform', "laplace", "cauchy"], + [0.4, 0.2, 0.2, 0.2, 0.1, 0.1], + )[0], + 'alpha': random.choices( + [0.01, 0.03, 0.05, 0.1, 0.2, 0.4], [0.1, 0.1, 0.8, 0.1, 0.1, 0.01] )[0], - 'alpha': random.choices([0.03, 0.05, 0.1], [0.1, 0.8, 0.1])[0], 'rolling_periods': random.choice([28, 90, 200, 300]), 'center': random.choice([True, False]), } + elif method_choice == "med_diff": + method_params = { + 'distribution': random.choices( + ['norm', 'gamma', 'chi2', 'uniform', "laplace", "cauchy"], + [0.4, 0.2, 0.2, 0.2, 0.1, 0.1], + )[0], + 'alpha': random.choices( + [0.01, 0.03, 0.05, 0.1, 0.2, 0.6], [0.1, 0.1, 0.8, 0.1, 0.1, 0.05] + )[0], + } elif method_choice == "mad": method_params = { 'distribution': random.choices( diff --git a/autots/tools/seasonal.py b/autots/tools/seasonal.py index 352e6ff5..b0fbfc35 100644 --- a/autots/tools/seasonal.py +++ b/autots/tools/seasonal.py @@ -10,6 +10,7 @@ from autots.tools.lunar import moon_phase from autots.tools.window_functions import sliding_window_view from autots.tools.holiday import holiday_flag +from autots.tools.wavelet import offset_wavelet, create_narrowing_wavelets def seasonal_int(include_one: bool = False, small=False, very_small=False): @@ -274,6 
+275,57 @@ def date_part( ) if method == "common_fourier_rw": date_part_df['epoch'] = (DTindex.to_julian_date() ** 0.65).astype(int) + elif "morlet" in method: + parts = method.split("_") + if len(parts) >= 2: + p = parts[1] + else: + p = 7 + if len(parts) >= 3: + order = parts[2] + else: + order = 7 + if len(parts) >= 4: + sigma = parts[3] + else: + sigma = 4.0 + date_part_df = seasonal_repeating_wavelet( + DTindex, p=p, order=order, sigma=sigma, wavelet_type='morlet' + ) + elif "ricker" in method: + parts = method.split("_") + if len(parts) >= 2: + p = parts[1] + else: + p = 7 + if len(parts) >= 3: + order = parts[2] + else: + order = 7 + if len(parts) >= 4: + sigma = parts[3] + else: + sigma = 4.0 + date_part_df = seasonal_repeating_wavelet( + DTindex, p=p, order=order, sigma=sigma, wavelet_type='ricker' + ) + elif "db2" in method: + parts = method.split("_") + if len(parts) >= 2: + p = parts[1] + else: + p = 7 + if len(parts) >= 3: + order = parts[2] + else: + order = 7 + if len(parts) >= 4: + sigma = parts[3] + else: + sigma = 4.0 + date_part_df = seasonal_repeating_wavelet( + DTindex, p=p, order=order, sigma=sigma, wavelet_type='db2' + ) else: # method == "simple" date_part_df = pd.DataFrame( @@ -355,13 +407,17 @@ def fourier_series(t, p=365.25, n=10): def fourier_df(DTindex, seasonality, order=10, t=None, history_days=None): - if history_days is None: - history_days = (DTindex.max() - DTindex.min()).days + # if history_days is None: + # history_days = (DTindex.max() - DTindex.min()).days if t is None: - t = (DTindex - pd.Timestamp(origin_ts)).days - return pd.DataFrame( - fourier_series(np.asarray(t), seasonality / history_days, n=order) - ).rename(columns=lambda x: f"seasonality{seasonality}_" + str(x)) + # Calculate the time difference in days as a float to preserve the exact time + t = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + # for only daily: t = (DTindex - pd.Timestamp(origin_ts)).days + # for nano seconds: t = (DTindex - pd.Timestamp(origin_ts)).to_numpy(dtype=np.int64) // (1000 * 1000 * 1000) / (3600 * 24.) 
+ # formerly seasonality / history_days below + return pd.DataFrame(fourier_series(np.asarray(t), seasonality, n=order)).rename( + columns=lambda x: f"seasonality{seasonality}_" + str(x) + ) datepart_components = [ @@ -504,32 +560,57 @@ def create_seasonality_feature(DTindex, t, seasonality, history_days=None): ) +base_seasonalities = [ + "recurring", + "simple", + "expanded", + "simple_2", + "simple_binarized", + "expanded_binarized", + 'common_fourier', + 'common_fourier_rw', + "simple_poly", + [7, 365.25], + ["dayofweek", 365.25], + ['weekdayofmonth', 'common_fourier'], + [52, 'quarter'], + [168, "hour"], + ["morlet_365.25_12_12", "ricker_7_7_1"], + ["db2_365.25_12_0.5", "morlet_7_7_1"], + "other", +] + + def random_datepart(method='random'): """New random parameters for seasonality.""" seasonalities = random.choices( + base_seasonalities, [ - "recurring", - "simple", - "expanded", - "simple_2", - "simple_binarized", - "expanded_binarized", - 'common_fourier', - 'common_fourier_rw', - "simple_poly", - [7, 365.25], - ["dayofweek", 365.25], - ['weekdayofmonth', 'common_fourier'], - "other", + 0.4, + 0.3, + 0.3, + 0.3, + 0.4, + 0.35, + 0.45, + 0.2, + 0.1, + 0.1, + 0.05, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.3, ], - [0.4, 0.3, 0.3, 0.3, 0.4, 0.35, 0.45, 0.2, 0.1, 0.1, 0.05, 0.1, 0.2], )[0] if seasonalities == "other": predefined = random.choices([True, False], [0.5, 0.5])[0] if predefined: seasonalities = [random.choice(date_part_methods)] else: - comp_opts = datepart_components + [7, 365.25, 12] + comp_opts = datepart_components + [7, 365.25, 12, 52, 168] seasonalities = random.choices(comp_opts, k=2) return seasonalities @@ -661,3 +742,25 @@ def seasonal_independent_match( if k > min_k: test = np.where(test >= len(DTindex), -1, test) return test, scores + + +def seasonal_repeating_wavelet(DTindex, p, order=12, sigma=4.0, wavelet_type='morlet'): + t = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + + if wavelet_type == "db2": + wavelets = create_narrowing_wavelets( + p=float(p), max_order=int(order), t=t, sigma=float(sigma) + ) + else: + wavelets = offset_wavelet( + p=float(p), # Weekly period + t=t, # A full year (365 days) + # origin_ts=origin_ts, + order=int(order), # One offset for each day of the week + # frequency=2 * np.pi / p, # Frequency for weekly pattern + sigma=float(sigma), # Smaller sigma for tighter weekly spread + wavelet_type=wavelet_type, + ) + return pd.DataFrame(wavelets, index=DTindex).rename( + columns=lambda x: f"wavelet_{p}_" + str(x) + ) diff --git a/autots/tools/transform.py b/autots/tools/transform.py index e67cb888..ed3e6ae3 100644 --- a/autots/tools/transform.py +++ b/autots/tools/transform.py @@ -1425,63 +1425,83 @@ def inverse_transform(self, df, regressor=None): DatepartRegression = DatepartRegressionTransformer -class DifferencedTransformer(EmptyTransformer): +class DifferencedTransformer: """Difference from lag n value. - inverse_transform can only be applied to the original series, or an immediately following forecast + inverse_transform can only be applied to the original series, or an immediately following forecast. Args: - lag (int): number of periods to shift (not implemented, default = 1) + lag (int): number of periods to shift. + fill (str): method to fill NaN values created by differencing, options: 'bfill', 'zero'. 
""" - def __init__(self, **kwargs): - super().__init__(name="DifferencedTransformer") - self.lag = 1 + def __init__(self, lag=1, fill='bfill'): + self.name = "DifferencedTransformer" + self.lag = lag + self.fill = fill + self.last_values = None + self.first_values = None + + @staticmethod + def get_new_params(method: str = "random"): + method_c = random.choices(["bfill", "zero", "one"], [0.5, 0.2, 0.01])[0] + choice = random.choices([1, 2, 7], [0.8, 0.1, 0.1])[0] + return {"lag": choice, "fill": method_c} def fit(self, df): """Fit. Args: - df (pandas.DataFrame): input dataframe + df (pandas.DataFrame): input dataframe. """ - self.last_values = df.tail(self.lag) - self.first_values = df.head(self.lag) + self.last_values = df.iloc[-self.lag :] + self.first_values = df.iloc[: self.lag] return self def transform(self, df): """Return differenced data. Args: - df (pandas.DataFrame): input dataframe + df (pandas.DataFrame): input dataframe. """ - return (df - df.shift(self.lag)).bfill() + differenced = df.diff(self.lag) + if self.fill == 'bfill': + return differenced.bfill() + elif self.fill == 'zero': + return differenced.fillna(0) + elif self.fill == 'one': + return differenced.fillna(1) + else: + raise ValueError( + f"DifferencedTransformer fill method {self.fill} not recognized" + ) def fit_transform(self, df): - """Fits and Returns Magical DataFrame + """Fits and returns differenced DataFrame. Args: - df (pandas.DataFrame): input dataframe + df (pandas.DataFrame): input dataframe. """ self.fit(df) return self.transform(df) - def inverse_transform(self, df, trans_method: str = "forecast"): + def inverse_transform(self, df, trans_method="forecast"): """Returns data to original *or* forecast form Args: - df (pandas.DataFrame): input dataframe + df (pandas.DataFrame): input dataframe. trans_method (str): whether to inverse on original data, or on a following sequence - 'original' return original data to original numbers - - 'forecast' inverse the transform on a dataset immediately following the original + - 'forecast' inverse the transform on a dataset immediately following the original. 
""" - lag = self.lag - # add last values, group by lag, cumsum if trans_method == "original": - df = pd.concat([self.first_values, df.tail(df.shape[0] - lag)]) - return df.cumsum() - else: + df_with_first = pd.concat( + [self.first_values, df.tail(df.shape[0] - self.lag)] + ) + return df_with_first.cumsum() + elif trans_method == "forecast": df_len = df.shape[0] - df = pd.concat([self.last_values, df], axis=0) - if df.isnull().to_numpy().any(): - raise ValueError("NaN in DifferencedTransformer.inverse_transform") - return df.cumsum().tail(df_len) + df_with_last = pd.concat([self.last_values, df]) + return df_with_last.cumsum().tail(df_len) + else: + raise ValueError("Invalid transformation method specified.") class PctChangeTransformer(EmptyTransformer): @@ -4841,7 +4861,7 @@ def get_new_params(method: str = "random"): "None": EmptyTransformer(), None: EmptyTransformer(), "RollingMean10": RollingMeanTransformer(window=10), - "DifferencedTransformer": DifferencedTransformer(), + # "DifferencedTransformer": DifferencedTransformer(), "PctChangeTransformer": PctChangeTransformer(), "SinTrend": SinTrend(), "SineTrend": SinTrend(), @@ -4917,6 +4937,7 @@ def get_new_params(method: str = "random"): "DiffSmoother": DiffSmoother, "HistoricValues": HistoricValues, "BKBandpassFilter": BKBandpassFilter, + "DifferencedTransformer": DifferencedTransformer, } # where results will vary if not all series are included together shared_trans = [ @@ -5473,7 +5494,7 @@ def get_transformer_params(transformer: str = "EmptyTransformer", method: str = 'HolidayTransformer': 0.01, 'LocalLinearTrend': 0.01, 'KalmanSmoothing': 0.02, - 'RegressionFilter': 0.02, + 'RegressionFilter': 0.01, "LevelShiftTransformer": 0.03, "CenterSplit": 0.01, "FFTFilter": 0.01, @@ -5560,12 +5581,12 @@ def get_transformer_params(transformer: str = "EmptyTransformer", method: str = "median": 0.03, None: 0.001, "interpolate": 0.4, - "KNNImputer": 0.05, + "KNNImputer": 0.02, # can get a bit slow "IterativeImputerExtraTrees": 0.0001, # and this one is even slower - "SeasonalityMotifImputer": 0.1, # apparently this is too memory hungry at scale + "SeasonalityMotifImputer": 0.02, # apparently this is too memory hungry at scale "SeasonalityMotifImputerLinMix": 0.01, # apparently this is too memory hungry at scale "SeasonalityMotifImputer1K": 0.01, # apparently this is too memory hungry at scale - "DatepartRegressionImputer": 0.05, # also slow + "DatepartRegressionImputer": 0.01, # also slow } diff --git a/autots/tools/wavelet.py b/autots/tools/wavelet.py new file mode 100644 index 00000000..bbdb6544 --- /dev/null +++ b/autots/tools/wavelet.py @@ -0,0 +1,291 @@ +import numpy as np +import pandas as pd + + +def create_gaussian_wavelet(p, frequency=3, sigma=1.0): + """ + Create a Gaussian-modulated cosine wavelet with specified frequency and sigma. + + Parameters: + - p (float): The period or length to generate the wavelet. + - frequency (int): Frequency of the cosine wave. + - sigma (float): Standard deviation for the Gaussian envelope. + + Returns: + - np.ndarray: The generated Gaussian-modulated wavelet. + """ + x = np.arange(-1, 1, 2 / p) # Adjusted to accommodate float 'p' + wavelet = np.cos(frequency * np.pi * x) * np.exp(-(x**2) / (2 * sigma**2)) + return wavelet + + +def create_morlet_wavelet(p, frequency=3, sigma=1.0): + """ + Create a Morlet wavelet with specified frequency and sigma. + + Parameters: + - p (float): The period or length to generate the wavelet. + - frequency (int): Frequency of the cosine wave. 
+ - sigma (float): Standard deviation for the Gaussian envelope. + + Returns: + - np.ndarray: The generated complex Morlet wavelet. + """ + x = np.arange(-1, 1, 2 / p) # Adjusted to accommodate float 'p' + real_part = np.cos(frequency * np.pi * x) * np.exp(-(x**2) / (2 * sigma**2)) + imag_part = np.sin(frequency * np.pi * x) * np.exp(-(x**2) / (2 * sigma**2)) + wavelet = real_part + 1j * imag_part # Complex wavelet + return wavelet + + +def create_real_morlet_wavelet(p, frequency=3, sigma=1.0): + """ + Create a real-valued Morlet wavelet with specified frequency and sigma. + + Parameters: + - p (float): The period or length to generate the wavelet. + - frequency (int): Frequency of the cosine wave. + - sigma (float): Standard deviation for the Gaussian envelope. + + Returns: + - np.ndarray: The generated real Morlet wavelet. + """ + x = np.arange(-1, 1, 2 / p) # Adjusted to accommodate float 'p' + # Real component of the Morlet wavelet + wavelet = np.cos(frequency * np.pi * x) * np.exp(-(x**2) / (2 * sigma**2)) + return wavelet + + +def create_mexican_hat_wavelet(p, frequency=None, sigma=1.0): + """ + Create a Mexican Hat wavelet (Ricker wavelet) with specified sigma. + + Parameters: + - p (float): The period or length to generate the wavelet. + - sigma (float): Standard deviation for the Gaussian envelope. + + Returns: + - np.ndarray: The generated Mexican Hat wavelet. + """ + x = np.arange(-1, 1, 2 / p) # Adjusted to accommodate float 'p' + wavelet = (1 - x**2 / sigma**2) * np.exp(-(x**2) / (2 * sigma**2)) + return wavelet + + +def create_haar_wavelet(p): + """ + Create a Haar wavelet with specified period `p`. + + Parameters: + - p (float): The period or length to generate the wavelet. + + Returns: + - np.ndarray: The generated Haar wavelet. + """ + if p <= 0: + raise ValueError("The period `p` must be greater than zero.") + + # Create the Haar wavelet + x = np.arange(0, p) # Discrete points to create the wavelet + # The Haar wavelet has a step function: +1 for the first half, -1 for the second half + half = len(x) // 2 + wavelet = np.zeros(len(x)) + wavelet[:half] = 1 + wavelet[half:] = -1 + + return wavelet + + +def create_daubechies_db2_wavelet(p): + """ + Create a Daubechies db2 wavelet with specified period `p`. + + Parameters: + - p (int): The period or length to generate the wavelet. + + Returns: + - np.ndarray: The generated Daubechies db2 wavelet. + """ + if p <= 0: + raise ValueError("The period `p` must be greater than zero.") + + # Coefficients for the Daubechies db2 wavelet + # These are the scaling coefficients for the db2 wavelet + coeffs = np.array( + [ + (1 + np.sqrt(3)) / 4, + (3 + np.sqrt(3)) / 4, + (3 - np.sqrt(3)) / 4, + (1 - np.sqrt(3)) / 4, + ] + ) + + # Generate a base wavelet of the specified length `p` + # To create the wavelet, replicate the coefficients to fit the desired period `p` + base_wavelet = np.tile(coeffs, int(np.ceil(p / len(coeffs))))[:p] + + return base_wavelet + + +############################################################################## + + +def create_wavelet(t, p, sigma=1.0, phase_shift=0, wavelet_type="morlet"): + """ + Create a real-valued wavelet based on real-world anchored time steps in t, + with an additional phase shift and a choice of wavelet type. + + Parameters: + - t (np.ndarray): Array of time steps (in days) from a specified origin. + - p (float): The period of the wavelet in the same units as t (typically days). + - sigma (float): Standard deviation for the Gaussian envelope. 
+ - phase_shift (float): Phase shift to adjust the position of the wavelet peak. + - wavelet_type (str): Type of wavelet ('morlet' or 'ricker'). + + Returns: + - np.ndarray: The generated wavelet values for each time step. + """ + x = (t + phase_shift) % p - p / 2 # Normalize and center t around 0 + + if wavelet_type == "morlet": + return np.cos(2 * np.pi * x / p) * np.exp(-(x**2) / (2 * sigma**2)) + elif wavelet_type == "ricker": + # Ricker (Mexican Hat) wavelet calculation + a = 2 * sigma**2 + return (1 - (x**2 / a)) * np.exp(-(x**2) / (2 * sigma**2)) + else: + raise ValueError("Unsupported wavelet type. Choose 'morlet' or 'ricker'.") + + +def offset_wavelet(p, t, order=5, sigma=1.0, wavelet_type="morlet"): + """ + Create an offset collection of wavelets with `order` offsets, ensuring that + peaks are spaced p/order apart. + + Parameters: + - p (float): Period of the wavelet in the same units as t (typically days). + - t (np.ndarray): Array of time steps. + - order (int): The number of offsets. + - sigma (float): Standard deviation for the Gaussian envelope. + - wavelet_type (str): Type of wavelet ('morlet' or 'ricker'). + + Returns: + - np.ndarray: A 2D array with `order` wavelets along axis 1. + """ + wavelet_features = [] + phase_offsets = np.linspace( + 0, p, order, endpoint=False + ) # Properly space phase shifts over one period + + for phase_shift in phase_offsets: + wavelet = create_wavelet(t, p, sigma, phase_shift, wavelet_type) + wavelet_features.append(wavelet) + + return np.stack(wavelet_features, axis=1) + + +if False: + DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D") + origin_ts = "2030-01-01" + t = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + + p = 7 + weekly_wavelets = offset_wavelet( + p=p, # Weekly period + t=t, # A full year (365 days) + # origin_ts=origin_ts, + order=7, # One offset for each day of the week + # frequency=2 * np.pi / p, # Frequency for weekly pattern + sigma=0.5, # Smaller sigma for tighter weekly spread + wavelet_type="morlet", + ) + + # Example for yearly seasonality + p = 365.25 + yearly_wavelets = offset_wavelet( + p=p, # Yearly period + t=t, # Three full years + # origin_ts=origin_ts, + order=12, # One offset for each month + # frequency=2 * np.pi / p, # Frequency for yearly pattern + sigma=2.0, # Larger sigma for broader yearly spread + wavelet_type="morlet", + ) + yearly_wavelets2 = offset_wavelet( + p=p, # Yearly period + t=t[-100:], # Three full years + # origin_ts=origin_ts, + order=12, # One offset for each month + # frequency=2 * np.pi / p, # Frequency for yearly pattern + sigma=2.0, # Larger sigma for broader yearly spread + wavelet_type="morlet", + ) + print(np.allclose(yearly_wavelets[-100:], yearly_wavelets2)) + + # Display wavelet patterns for visualization + import matplotlib.pyplot as plt + + plt.figure(figsize=(12, 6)) + pd.DataFrame(weekly_wavelets).plot(title="Weekly Wavelets", ax=plt.gca()) + pd.DataFrame(yearly_wavelets).plot(title="Yearly Wavelets", ax=plt.gca()) + plt.show() + + pd.DataFrame(weekly_wavelets[0:50]).plot(title="Weekly Wavelets", ax=plt.gca()) + plt.show() + + +############################################################################## + + +def continuous_db2_wavelet(t, p, order, sigma): + # Normalize t to [0, 1) interval based on period p, scaled by order to include multiple cycles + x = (order * t % p) / p + if order % 3 == 0: + x = x + 0.3 + gaussian_envelope = np.exp(-0.5 * ((x - 0.5) / sigma) ** 2) + sinusoidal_component = np.sin(2 * np.pi * x) + wavelet = gaussian_envelope * 
sinusoidal_component + return wavelet + + +def create_narrowing_wavelets(p, max_order, t, sigma=0.5): + wavelets = [] + for order in range(1, max_order + 1): + sigma = sigma / order # Narrow the Gaussian envelope as order increases + wavelet = continuous_db2_wavelet(t, p, order, sigma) + wavelets.append(wavelet) + return np.array(wavelets).T + + +if False: + # Example usage + DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D") + origin_ts = "2020-01-01" + t_full = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + + p = 365.25 # Example period + max_order = 5 # Example maximum order + + # Full set of wavelets + wavelets = create_narrowing_wavelets(p, max_order, t_full) + + # Wavelets for the last 100 days + t_subset = t_full[-100:] + wavelet_short = create_narrowing_wavelets(p, max_order, t_subset) + + # Check if the last 100 days of the full series match the subset + print(np.allclose(wavelets[-100:], wavelet_short)) # This should be true + + # Plotting the wavelets + plt.figure(figsize=(12, 6)) + for i in range(max_order): + plt.plot(DTindex[-100:], wavelets[-100:, i], label=f"Order {i+1}") + plt.plot( + DTindex[-100:], + wavelet_short[:, i], + label=f"Subset Order {i+1}", + linestyle="--", + ) + plt.title("Comparison of Full Wavelets and Subset") + plt.legend() + plt.show() diff --git a/docs/build/doctrees/environment.pickle b/docs/build/doctrees/environment.pickle index 3e50ded7..a51ca554 100644 Binary files a/docs/build/doctrees/environment.pickle and b/docs/build/doctrees/environment.pickle differ diff --git a/docs/build/doctrees/source/autots.doctree b/docs/build/doctrees/source/autots.doctree index 870bc237..d4e09e8c 100644 Binary files a/docs/build/doctrees/source/autots.doctree and b/docs/build/doctrees/source/autots.doctree differ diff --git a/docs/build/doctrees/source/autots.evaluator.doctree b/docs/build/doctrees/source/autots.evaluator.doctree index 20f4f934..4402d533 100644 Binary files a/docs/build/doctrees/source/autots.evaluator.doctree and b/docs/build/doctrees/source/autots.evaluator.doctree differ diff --git a/docs/build/doctrees/source/autots.models.doctree b/docs/build/doctrees/source/autots.models.doctree index 9eadecbe..1669616e 100644 Binary files a/docs/build/doctrees/source/autots.models.doctree and b/docs/build/doctrees/source/autots.models.doctree differ diff --git a/docs/build/doctrees/source/autots.templates.doctree b/docs/build/doctrees/source/autots.templates.doctree index 9c952875..b138a361 100644 Binary files a/docs/build/doctrees/source/autots.templates.doctree and b/docs/build/doctrees/source/autots.templates.doctree differ diff --git a/docs/build/doctrees/source/autots.tools.doctree b/docs/build/doctrees/source/autots.tools.doctree index ba5fbdbe..3a98a9bd 100644 Binary files a/docs/build/doctrees/source/autots.tools.doctree and b/docs/build/doctrees/source/autots.tools.doctree differ diff --git a/docs/build/html/.buildinfo b/docs/build/html/.buildinfo index 4a6d98a6..91838d53 100644 --- a/docs/build/html/.buildinfo +++ b/docs/build/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
-config: cd7239cc7dc0c2f0136aaa8bd24d37d1 +config: 18a08ac66d0dab3da1a184ed771beee5 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/build/html/_sources/source/autots.tools.rst.txt b/docs/build/html/_sources/source/autots.tools.rst.txt index 718e56e6..a389c2b0 100644 --- a/docs/build/html/_sources/source/autots.tools.rst.txt +++ b/docs/build/html/_sources/source/autots.tools.rst.txt @@ -148,6 +148,14 @@ autots.tools.transform module :undoc-members: :show-inheritance: +autots.tools.wavelet module +--------------------------- + +.. automodule:: autots.tools.wavelet + :members: + :undoc-members: + :show-inheritance: + autots.tools.window\_functions module ------------------------------------- diff --git a/docs/build/html/_static/documentation_options.js b/docs/build/html/_static/documentation_options.js index 7ce78999..e4673028 100644 --- a/docs/build/html/_static/documentation_options.js +++ b/docs/build/html/_static/documentation_options.js @@ -1,5 +1,5 @@ const DOCUMENTATION_OPTIONS = { - VERSION: '0.6.10', + VERSION: '0.6.11', LANGUAGE: 'en', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/build/html/genindex.html b/docs/build/html/genindex.html index bd70c353..98e3624b 100644 --- a/docs/build/html/genindex.html +++ b/docs/build/html/genindex.html @@ -13,10 +13,10 @@ gtag('config', 'G-P2KLF8302E'); - Index — AutoTS 0.6.10 documentation + Index — AutoTS 0.6.11 documentation - + @@ -97,6 +97,8 @@

[Generated Sphinx HTML docs regenerated (docs/build/html/genindex.html and related module pages): the rebuilt index adds entries for the new autots.tools.wavelet module and its functions (create_gaussian_wavelet, create_morlet_wavelet, create_real_morlet_wavelet, create_mexican_hat_wavelet, create_haar_wavelet, create_wavelet, create_narrowing_wavelets), the DMD model in autots.models.matrix_var, DifferencedTransformer.get_new_params, ModelPrediction.fit_predict, apply_constraint_single, and seasonal_repeating_wavelet. Generated HTML table markup omitted.]
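For reference, a minimal sketch (not part of this changeset) of building seasonal feature matrices with the new autots.tools.wavelet helpers added above; the date range, origin, periods, and sigma values are illustrative, adapted from the module's example block and from tests/test_seasonal.py further below.

import numpy as np
import pandas as pd

from autots.tools.wavelet import create_narrowing_wavelets, offset_wavelet

# Anchor time steps (in days) to a fixed origin so features are reproducible
# across refits; only the spacing of t relative to the origin matters.
DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D")
t = (DTindex - pd.Timestamp("2030-01-01")).total_seconds() / 86400

# Seven offset weekly wavelets, peaks spaced p / order = 1 day apart
weekly = offset_wavelet(p=7, t=t, order=7, sigma=0.5, wavelet_type="morlet")

# Five yearly wavelets with progressively narrower Gaussian envelopes
yearly = create_narrowing_wavelets(p=365.25, max_order=5, t=t)

# Combine into a single regressor-style DataFrame, one column per wavelet
features = pd.DataFrame(np.hstack([weekly, yearly]), index=DTindex)

Because the wavelets are computed from t rather than from row position, the trailing rows of a longer index reproduce the values computed from a shorter trailing index, which is what the np.allclose assertions in tests/test_seasonal.py verify.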

  • diff --git a/docs/build/html/source/tutorial.html b/docs/build/html/source/tutorial.html index cef15a95..e52b7a31 100644 --- a/docs/build/html/source/tutorial.html +++ b/docs/build/html/source/tutorial.html @@ -14,10 +14,10 @@ gtag('config', 'G-P2KLF8302E'); - Tutorial — AutoTS 0.6.10 documentation + Tutorial — AutoTS 0.6.11 documentation - + diff --git a/docs/conf.py b/docs/conf.py index dce8fb8e..9ca3c7ca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ # import AutoTS # from AutoTS import __version__ # release = __version__ -release = "0.6.11" +release = "0.6.12" # -- General configuration --------------------------------------------------- diff --git a/docs/source/autots.tools.rst b/docs/source/autots.tools.rst index 718e56e6..a389c2b0 100644 --- a/docs/source/autots.tools.rst +++ b/docs/source/autots.tools.rst @@ -148,6 +148,14 @@ autots.tools.transform module :undoc-members: :show-inheritance: +autots.tools.wavelet module +--------------------------- + +.. automodule:: autots.tools.wavelet + :members: + :undoc-members: + :show-inheritance: + autots.tools.window\_functions module ------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 8c88788c..30432bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "autots" -version = "0.6.11" +version = "0.6.12" authors = [ { name="Colin Catlin", email="colin.catlin@gmail.com" }, ] diff --git a/setup.py b/setup.py index 8e65d163..fb107b9a 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setuptools.setup( name="AutoTS", - version="0.6.11", + version="0.6.12", author="Colin Catlin", author_email="colin.catlin@syllepsis.live", description="Automated Time Series Forecasting", diff --git a/tests/test_autots.py b/tests/test_autots.py index 7c05b10a..1f2d2212 100644 --- a/tests/test_autots.py +++ b/tests/test_autots.py @@ -274,13 +274,20 @@ def test_all_default_models(self): transformer_list = "fast" # ["SinTrend", "MinMaxScaler"] transformer_max_depth = 1 + constraint = { + "constraint_method": "quantile", + "constraint_regularization": 0.9, + "upper_constraint": 0.99, + "lower_constraint": 0.01, + "bounds": True, + } model = AutoTS( forecast_length=forecast_length, frequency='infer', prediction_interval=0.9, ensemble=["horizontal-max"], - constraint=None, + constraint=constraint, max_generations=generations, num_validations=num_validations, validation_method=validation_method, @@ -579,6 +586,7 @@ def test_models(self): 'Cassandra', 'MetricMotif', 'SeasonalityMotif', 'KalmanStateSpace', 'ARDL', 'UnivariateMotif', 'VAR', 'MAR', 'TMF', 'RRVAR', 'VECM', 'BallTreeMultivariateMotif', 'FFT', + # "DMD", # 0.6.12 ] # models that for whatever reason arne't consistent across test sessions run_only_no_score = ['FBProphet', 'RRVAR', "TMF"] diff --git a/tests/test_cassandra.py b/tests/test_cassandra.py index 58dc17e6..f29085a8 100644 --- a/tests/test_cassandra.py +++ b/tests/test_cassandra.py @@ -58,10 +58,21 @@ def test_model(self): ) } constraint = { - "constraint_method": "quantile", - "lower_constraint": 0, - "upper_constraint": None, - "bounds": True, + "constraints": [{ + "constraint_method": "last_window", + "constraint_value": 0.5, + "constraint_direction": "upper", + "constraint_regularization": 1.0, + "bounds": True, + }, + { + "constraint_method": "last_window", + "constraint_value": -0.5, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + ] } past_impacts = pd.DataFrame(0, 
index=df_train.index, columns=df_train.columns) past_impacts.iloc[-10:, 0] = np.geomspace(1, 10)[0:10] / 100 diff --git a/tests/test_constraint.py b/tests/test_constraint.py new file mode 100644 index 00000000..17734978 --- /dev/null +++ b/tests/test_constraint.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +"""Test constraint.""" +import unittest +import numpy as np +import pandas as pd +from autots import load_daily, ModelPrediction + + +class TestConstraint(unittest.TestCase): + + def test_constraint(self): + df = load_daily(long=False) + forecast_length = 30 + constraint_types = { + "empty": { + "constraints": None, + }, + "old_style": { + "constraint_method": "quantile", + "constraint_regularization": 0.99, + "upper_constraint": 0.5, + "lower_constraint": 0.1, + "bounds": True, + }, + "quantile": { + "constraints": [{ + "constraint_method": "quantile", + "constraint_value": 0.98, + "constraint_direction": "upper", + "constraint_regularization": 1.0, + "bounds": False, + },] + }, + "last_value": { + "constraints": [{ + "constraint_method": "last_window", + "constraint_value": 0.0, + "constraint_direction": "upper", + "constraint_regularization": 1.0, + "bounds": True, + }, + { + "constraint_method": "last_window", + "constraint_value": 0.0, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + ] + }, + "example": {"constraints": [ + { # don't exceed historic max + "constraint_method": "quantile", + "constraint_value": 1.0, + "constraint_direction": "upper", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't exceed 2% growth by end of forecast horizon + "constraint_method": "slope", + "constraint_value": {"slope": 0.02, "window": 10, "window_agg": "max", "threshold": 0.01}, + "constraint_direction": "upper", + "constraint_regularization": 0.9, + "bounds": False, + }, + { # don't go below the last 10 values - 10% + "constraint_method": "last_window", + "constraint_value": {"window": 10, "threshold": -0.1}, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": False, + }, + { # don't go below zero + "constraint_method": "absolute", + "constraint_value": 0, # can also be an array or Series + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't go below historic min - 1 st dev + "constraint_method": "stdev_min", + "constraint_value": 1.0, + "constraint_direction": "lower", + "constraint_regularization": 1.0, + "bounds": True, + }, + { # don't go above historic mean + 3 st devs, soft limit + "constraint_method": "stdev", + "constraint_value": 3.0, + "constraint_direction": "upper", + "constraint_regularization": 0.5, + "bounds": True, + }, + ]}, + "dampening": { + "constraints": [{ + "constraint_method": "dampening", + "constraint_value": 0.98, + "bounds": True, + },] + }, + } + for key, constraint in constraint_types.items(): + with self.subTest(i=key): + model = ModelPrediction( + forecast_length=forecast_length, + transformation_dict={ + "fillna": "median", + "transformations": {"0": "SinTrend", "1": "QuantileTransformer", "2": "bkfilter"}, + "transformation_params": {"0": {}, "1": {"output_distribution": "uniform", "n_quantiles": 1000}, "2": {}} + }, + model_str="SeasonalityMotif", + parameter_dict={ + "window": 7, "point_method": "midhinge", + "distance_metric": "canberra", "k": 10, + "datepart_method": "common_fourier", + }, + no_negatives=True, + ) + prediction = model.fit_predict(df, forecast_length=forecast_length) + # apply an artificially low 
value + prediction.forecast.iloc[0, 0] = -10 + prediction.forecast.iloc[0, -1] = df.iloc[:, -1].max() * 1.1 + prediction.plot(df, df.columns[-1]) + prediction.plot(df, df.columns[0]) + + prediction.apply_constraints( + df_train=df, + **constraint + ) + prediction.plot(df, df.columns[-1]) + prediction.plot(df, df.columns[0]) + # assuming all history was positive as example data currently is + if key in ["empty", "dampening"]: + self.assertTrue(prediction.forecast.min().min() == -10) + else: + self.assertTrue((prediction.forecast.sum() > 0).all()) + + if key in ["old_style", "quantile"]: + pred_max = prediction.forecast.iloc[:, -1].max() + hist_max = df.iloc[:, -1].max() + print(pred_max) + print(hist_max) + self.assertTrue(pred_max <= hist_max) + if key in ["last_value"]: + self.assertTrue(prediction.forecast.iloc[0, :].max() == df.iloc[-1, :].max()) \ No newline at end of file diff --git a/tests/test_seasonal.py b/tests/test_seasonal.py new file mode 100644 index 00000000..19dd9fc7 --- /dev/null +++ b/tests/test_seasonal.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +""" +Seasonal unittests. + +Created on Sat May 4 21:43:02 2024 + +@author: Colin +""" + +import unittest +import numpy as np +import pandas as pd +from autots.tools.seasonal import date_part, base_seasonalities, datepart_components, random_datepart, fourier_df +from autots.tools.wavelet import create_narrowing_wavelets, offset_wavelet + + +class TestSeasonal(unittest.TestCase): + + def test_date_part(self): + DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D") + for method in base_seasonalities: + df = date_part(DTindex, method=method, set_index=True) + self.assertEqual(df.shape[0], DTindex.shape[0]) + self.assertGreater(df.shape[1], 1) + + def test_date_components(self): + DTindex = pd.date_range("2023-01-01", "2024-01-01", freq="h") + for method in datepart_components: + df = date_part(DTindex, method=method, set_index=True) + self.assertEqual(df.shape[0], DTindex.shape[0]) + + def test_random_datepart(self): + out = random_datepart() + self.assertTrue(out) + + def test_fourier(self): + DTindex = pd.date_range("2020-01-02", "2024-01-01", freq="D") + order = 10 + df = fourier_df(DTindex, seasonality=365.25, order=order) + self.assertEqual(df.shape[1], order * 2) + self.assertEqual(df.shape[0], DTindex.shape[0]) + self.assertAlmostEqual(df.mean().sum(), 0.0) + + def test_wavelets_repeat(self): + DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D") + origin_ts = "2030-01-01" + t = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + + p = 7 + w_order = 7 + weekly_wavelets = offset_wavelet( + p=p, # Weekly period + t=t, # A full year (365 days) + # origin_ts=origin_ts, + order=w_order, # One offset for each day of the week + # frequency=2 * np.pi / p, # Frequency for weekly pattern + sigma=0.5, # Smaller sigma for tighter weekly spread + wavelet_type="ricker", + ) + self.assertEqual(weekly_wavelets.shape[1], w_order) + + # Example for yearly seasonality + p = 365.25 + y_order = 12 + yearly_wavelets = offset_wavelet( + p=p, # Yearly period + t=t, # Three full years + # origin_ts=origin_ts, + order=y_order, # One offset for each month + # frequency=2 * np.pi / p, # Frequency for yearly pattern + sigma=2.0, # Larger sigma for broader yearly spread + wavelet_type="morlet", + ) + yearly_wavelets2 = offset_wavelet( + p=p, # Yearly period + t=t[-100:], # Three full years + # origin_ts=origin_ts, + order=y_order, # One offset for each month + # frequency=2 * np.pi / p, # Frequency for yearly pattern + 
sigma=2.0, # Larger sigma for broader yearly spread + wavelet_type="morlet", + ) + self.assertEqual(yearly_wavelets.shape[1], y_order) + self.assertTrue(np.allclose(yearly_wavelets[-100:], yearly_wavelets2)) + + def test_wavelet_continuous(self): + DTindex = pd.date_range("2020-01-01", "2024-01-01", freq="D") + origin_ts = "2020-01-01" + t_full = (DTindex - pd.Timestamp(origin_ts)).total_seconds() / 86400 + + p = 365.25 # Example period + max_order = 5 # Example maximum order + + # Full set of wavelets + wavelets = create_narrowing_wavelets(p, max_order, t_full) + + # Wavelets for the last 100 days + t_subset = t_full[-100:] + wavelet_short = create_narrowing_wavelets(p, max_order, t_subset) + + # Check if the last 100 days of the full series match the subset + self.assertTrue(np.allclose(wavelets[-100:], wavelet_short)) \ No newline at end of file
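For reference, a minimal sketch (not part of this changeset) of the new list-of-dicts constraints interface on a prediction object, mirroring tests/test_constraint.py above; the model, transformer choices, and constraint values are illustrative.

from autots import ModelPrediction, load_daily

df = load_daily(long=False)
forecast_length = 30

model = ModelPrediction(
    forecast_length=forecast_length,
    transformation_dict={
        "fillna": "median",
        "transformations": {"0": "SinTrend", "1": "QuantileTransformer", "2": "bkfilter"},
        "transformation_params": {
            "0": {},
            "1": {"output_distribution": "uniform", "n_quantiles": 1000},
            "2": {},
        },
    },
    model_str="SeasonalityMotif",
    parameter_dict={
        "window": 7,
        "point_method": "midhinge",
        "distance_metric": "canberra",
        "k": 10,
        "datepart_method": "common_fourier",
    },
    no_negatives=True,
)
prediction = model.fit_predict(df, forecast_length=forecast_length)

# df_train supplies the history that the quantile / last_window / stdev
# style constraint methods are measured against.
prediction = prediction.apply_constraints(
    constraints=[
        {  # don't exceed the historic maximum
            "constraint_method": "quantile",
            "constraint_value": 1.0,
            "constraint_direction": "upper",
            "constraint_regularization": 1.0,
            "bounds": True,
        },
        {  # don't go below zero
            "constraint_method": "absolute",
            "constraint_value": 0,
            "constraint_direction": "lower",
            "constraint_regularization": 1.0,
            "bounds": True,
        },
    ],
    df_train=df,
)

The older flat-dict form (the "old_style" case in tests/test_constraint.py) is still accepted by apply_constraints, so existing configurations keep working.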