Skip to content

Commit

Permalink
allow users to specify engine type for automlx
Browse files Browse the repository at this point in the history
  • Loading branch information
ahosler committed Dec 9, 2024
1 parent a9a9b05 commit 5cf2b6d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
36 changes: 20 additions & 16 deletions ads/opctl/operator/lowcode/forecast/model/automlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,6 @@ def _build_model(self) -> pd.DataFrame:

from automlx import Pipeline, init

cpu_count = os.cpu_count()
try:
if cpu_count < 4:
engine = "local"
engine_opts = None
else:
engine = "ray"
engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
init(
engine=engine,
engine_opts=engine_opts,
loglevel=logging.CRITICAL,
)
except Exception as e:
logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")

full_data_dict = self.datasets.get_data_by_series()

self.models = {}
Expand All @@ -112,6 +96,26 @@ def _build_model(self) -> pd.DataFrame:
# Clean up kwargs for pass through
model_kwargs_cleaned, time_budget = self.set_kwargs()

cpu_count = os.cpu_count()
try:
engine_type = model_kwargs_cleaned.pop(
"engine", "local" if cpu_count <= 4 else "ray"
)
engine_opts = (
None
if engine_type == "local"
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
)
init(
engine=engine_type,
engine_opts=engine_opts,
loglevel=logging.CRITICAL,
)
except Exception as e:
logger.info(
f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
)

for s_id, df in full_data_dict.items():
try:
logger.debug(f"Running automlx on series {s_id}")
Expand Down
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,27 +157,26 @@ forecast = [
"oci-cli",
"py-cpuinfo",
"rich",
"autots[additional]",
"autots",
"mlforecast",
"neuralprophet>=0.7.0",
"numpy<2.0.0",
"oci-cli",
"optuna",
"oracle-ads",
"pmdarima",
"prophet",
"shap",
"sktime",
"statsmodels",
"plotly",
"oracledb",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
anomaly = [
"oracle_ads[opctl]",
"autots",
"oracledb",
"report-creator==1.0.28",
"report-creator==1.0.32",
"rrcf==0.4.4",
"scikit-learn",
"salesforce-merlion[all]==2.0.4"
Expand All @@ -186,7 +185,7 @@ recommender = [
"oracle_ads[opctl]",
"scikit-surprise",
"plotly",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
feature-store-marketplace = [
"oracle-ads[opctl]",
Expand All @@ -202,7 +201,7 @@ pii = [
"scrubadub_spacy",
"spacy-transformers==1.2.5",
"spacy==3.6.1",
"report-creator==1.0.28",
"report-creator==1.0.32",
]
llm = ["langchain>=0.2", "langchain-community", "langchain_openai", "pydantic>=2,<3", "evaluate>=0.4.0"]
aqua = ["jupyter_server"]
Expand Down

0 comments on commit 5cf2b6d

Please sign in to comment.