Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataflow changes #1018

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/workflows/run-forecast-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ jobs:
$CONDA/bin/conda init
source /home/runner/.bashrc
pip install -r test-requirements-operators.txt
pip install "oracle-automlx[forecasting]>=24.4.0"
pip install "oracle-automlx[forecasting]>=24.4.1"
pip install pandas>=2.2.0
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast
36 changes: 20 additions & 16 deletions ads/opctl/operator/lowcode/forecast/model/automlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,6 @@ def _build_model(self) -> pd.DataFrame:

from automlx import Pipeline, init

cpu_count = os.cpu_count()
try:
if cpu_count < 4:
engine = "local"
engine_opts = None
else:
engine = "ray"
engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
init(
engine=engine,
engine_opts=engine_opts,
loglevel=logging.CRITICAL,
)
except Exception as e:
logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")

full_data_dict = self.datasets.get_data_by_series()

self.models = {}
Expand All @@ -112,6 +96,26 @@ def _build_model(self) -> pd.DataFrame:
# Clean up kwargs for pass through
model_kwargs_cleaned, time_budget = self.set_kwargs()

cpu_count = os.cpu_count()
try:
engine_type = model_kwargs_cleaned.pop(
"engine", "local" if cpu_count <= 4 else "ray"
)
engine_opts = (
None
if engine_type == "local"
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
)
init(
engine=engine_type,
engine_opts=engine_opts,
loglevel=logging.CRITICAL,
)
except Exception as e:
logger.info(
f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
)

for s_id, df in full_data_dict.items():
try:
logger.debug(f"Running automlx on series {s_id}")
Expand Down
6 changes: 1 addition & 5 deletions ads/opctl/operator/lowcode/forecast/model/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,7 @@ def _generate_report(self):
logger.debug(f"Full Traceback: {traceback.format_exc()}")

model_description = rc.Text(
"Prophet is a procedure for forecasting time series data based on an additive "
"model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
"plus holiday effects. It works best with time series that have strong seasonal "
"effects and several seasons of historical data. Prophet is robust to missing "
"data and shifts in the trend, and typically handles outliers well."
"""Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
)
other_sections = all_sections

Expand Down
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,27 +157,26 @@ forecast = [
"oci-cli",
"py-cpuinfo",
"rich",
"autots[additional]",
"autots",
"mlforecast",
"neuralprophet>=0.7.0",
"numpy<2.0.0",
"oci-cli",
"optuna",
"oracle-ads",
"pmdarima",
"prophet",
"shap",
"sktime",
"statsmodels",
"plotly",
"oracledb",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
anomaly = [
"oracle_ads[opctl]",
"autots",
"oracledb",
"report-creator==1.0.28",
"report-creator==1.0.32",
"rrcf==0.4.4",
"scikit-learn",
"salesforce-merlion[all]==2.0.4"
Expand All @@ -186,7 +185,7 @@ recommender = [
"oracle_ads[opctl]",
"scikit-surprise",
"plotly",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
feature-store-marketplace = [
"oracle-ads[opctl]",
Expand All @@ -202,7 +201,7 @@ pii = [
"scrubadub_spacy",
"spacy-transformers==1.2.5",
"spacy==3.6.1",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
llm = ["langchain>=0.2", "langchain-community", "langchain_openai", "pydantic>=2,<3", "evaluate>=0.4.0"]
aqua = ["jupyter_server"]
Expand Down
Loading