Skip to content

Commit

Permalink
feat: bring consistency to operators parameter names
Browse files Browse the repository at this point in the history
  • Loading branch information
mpetazzoni committed Oct 11, 2024
1 parent 9d39fa9 commit ac3b4eb
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 32 deletions.
1 change: 1 addition & 0 deletions airflow_providers_wherobots/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
"""

DEFAULT_CONN_ID = "wherobots_default"
PACKAGE_NAME = "airflow-providers-wherobots"
15 changes: 7 additions & 8 deletions airflow_providers_wherobots/hooks/rest_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Hook for WhereRobots API
Hook for Wherobots' HTTP API
"""

import platform
Expand All @@ -15,7 +15,10 @@
from requests.adapters import HTTPAdapter, Retry
from requests.auth import AuthBase

from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID
from airflow_providers_wherobots.hooks.base import (
DEFAULT_CONN_ID,
PACKAGE_NAME,
)
from airflow_providers_wherobots.wherobots.models import (
Run,
LogsResponse,
Expand Down Expand Up @@ -67,13 +70,13 @@ def conn(self) -> Connection:
@cached_property
def user_agent_header(self):
try:
package_version = metadata.version("airflow-providers-wherobots")
package_version = metadata.version(PACKAGE_NAME)
except metadata.PackageNotFoundError:
package_version = "unknown"
python_version = platform.python_version()
system = platform.system().lower()
header_value = (
f"airflow-providers-wherobots/{package_version} os/{system}"
f"{PACKAGE_NAME}/{package_version} os/{system}"
f" python/{python_version} airflow/{airflow_version}"
)
return {"User-Agent": header_value}
Expand Down Expand Up @@ -120,7 +123,3 @@ def get_run_logs(self, run_id: str, start: int, size: int = 500) -> LogsResponse
params = {"cursor": start, "size": size}
resp_json = self._api_call("GET", f"/runs/{run_id}/logs", params=params).json()
return LogsResponse.model_validate(resp_json)


if __name__ == "__main__":
metadata.version("airflow-providers-wherobots")
25 changes: 13 additions & 12 deletions airflow_providers_wherobots/hooks/sql.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,60 @@
"""
Constants
Hook for Wherobots' Spatial SQL API interface.
"""

from __future__ import annotations

import logging
from typing import Optional

import wherobots.db
from airflow.providers.common.sql.hooks.sql import DbApiHook
from wherobots.db import Connection as WDBConnection
from wherobots.db import Runtime, connect
from wherobots.db.constants import (
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
DEFAULT_READ_TIMEOUT_SECONDS,
DEFAULT_REUSE_SESSION,
DEFAULT_RUNTIME,
)

from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID

log = logging.getLogger(__name__)


class WherobotsSqlHook(DbApiHook): # type: ignore[misc]
conn_name_attr = "wherobots_conn_id"

def __init__( # type: ignore[no-untyped-def]
self,
wherobots_conn_id: str = DEFAULT_CONN_ID,
runtime_id: Runtime = Runtime.SEDONA,
runtime: Runtime = DEFAULT_RUNTIME,
session_wait_timeout: int = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
read_timeout: int = DEFAULT_READ_TIMEOUT_SECONDS,
reuse_session: bool = DEFAULT_REUSE_SESSION,
**kwargs,
):
super().__init__(**kwargs)
self.wherobots_conn_id = wherobots_conn_id
self.runtime = runtime
self.session_wait_timeout = session_wait_timeout
self.read_timeout = read_timeout
self.runtime_id = runtime_id
self.reuse_session = reuse_session

self._conn = self.get_connection(self.wherobots_conn_id)
self._db_conn: Optional[wherobots.db.Connection] = None
self._db_conn: Optional[WDBConnection] = None

def _create_or_get_sql_session(
self, runtime: Runtime = Runtime.SEDONA
self,
runtime: Runtime = DEFAULT_RUNTIME,
) -> WDBConnection:
return connect(
host=self._conn.host,
api_key=self._conn.get_password(),
runtime=runtime,
wait_timeout=self.session_wait_timeout,
read_timeout=self.read_timeout,
reuse_session=self.reuse_session,
)

def get_conn(self) -> WDBConnection:
if not self._db_conn:
self._db_conn = self._create_or_get_sql_session(self.runtime_id)
self._db_conn = self._create_or_get_sql_session(self.runtime)
return self._db_conn

def get_autocommit(self, conn: WDBConnection) -> bool:
Expand Down
7 changes: 5 additions & 2 deletions airflow_providers_wherobots/operators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
Run,
)

from wherobots.db import Runtime
from wherobots.db.constants import DEFAULT_RUNTIME


class XComKey(StrEnum):
run_id = auto()
Expand All @@ -32,7 +35,7 @@ class WherobotsRunOperator(BaseOperator):
def __init__(
self,
name: Optional[str] = None,
runtime: str = "TINY",
runtime: Runtime = DEFAULT_RUNTIME,
run_python: Optional[dict[str, Any]] = None,
run_jar: Optional[dict[str, Any]] = None,
environment: Optional[dict[str, Any]] = None,
Expand All @@ -45,7 +48,7 @@ def __init__(
super().__init__(**kwargs)
# If the user specifies the name, we will use it and rely on the server to validate the name
self.run_payload: dict[str, Any] = {
"runtime": runtime,
"runtime": runtime.value,
"name": name or self.default_run_name,
}
if run_python:
Expand Down
11 changes: 8 additions & 3 deletions airflow_providers_wherobots/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from wherobots.db.constants import (
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
DEFAULT_READ_TIMEOUT_SECONDS,
DEFAULT_REUSE_SESSION,
DEFAULT_RUNTIME,
)
from wherobots.db import Cursor as WDbCursor
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -39,23 +41,26 @@ def __init__( # type: ignore[no-untyped-def]
self,
*,
wherobots_conn_id: str = DEFAULT_CONN_ID,
runtime_id: Runtime = Runtime.SEDONA,
runtime: Runtime = DEFAULT_RUNTIME,
session_wait_timeout: int = DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS,
read_timeout: int = DEFAULT_READ_TIMEOUT_SECONDS,
reuse_session: bool = DEFAULT_REUSE_SESSION,
**kwargs,
):
super().__init__(
conn_id=wherobots_conn_id, handler=wherobots_default_handler, **kwargs
)
self.wherobots_conn_id = wherobots_conn_id
self.runtime_id = runtime_id
self.runtime = runtime
self.session_wait_timeout = session_wait_timeout
self.read_timeout = read_timeout
self.reuse_session = reuse_session

def get_db_hook(self) -> DbApiHook:
return WherobotsSqlHook(
wherobots_conn_id=self.wherobots_conn_id,
runtime_id=self.runtime_id,
runtime=self.runtime,
session_wait_timeout=self.session_wait_timeout,
read_timeout=self.read_timeout,
reuse_session=self.reuse_session,
)
5 changes: 2 additions & 3 deletions tests/integration_tests/operator/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Test the operators in sql module
Test the operators in SQL module
"""

import datetime
Expand All @@ -23,9 +23,8 @@
def test_prod_run_success(prod_conn: Connection, dag: DAG) -> None:
operator = WherobotsSqlOperator(
task_id=TEST_TASK_ID,
sql="select pickup_datetime from wherobots_pro_data.nyc_taxi.yellow_2009_2010 limit 10",
sql="SELECT pickup_datetime FROM wherobots_pro_data.nyc_taxi.yellow_2009_2010 LIMIT 10",
wherobots_conn_id=prod_conn.conn_id,
runtime_id=Runtime.SEDONA,
dag=dag,
)
result = operator.execute(context={})
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/hooks/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_create_run(self, test_default_conn) -> None:
url = f"https://{test_default_conn.host}/runs"
create_payload = {
"name": test_run.name,
"runtime": Runtime.SEDONA.value,
"runtime": Runtime.TINY.value,
"python": {
"uri": "s3://bucket/test.py",
"args": ["arg1", "arg2"],
Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ class TestWherobotsSqlHook:
@mock.patch("airflow_providers_wherobots.hooks.sql.connect")
def test_get_conn(self, mock_connect: MagicMock, test_default_conn: Connection):
# Instantiate hook
hook = WherobotsSqlHook(runtime_id=Runtime.ATLANTIS)
hook = WherobotsSqlHook(runtime=Runtime.LARGE)

# Sample Hook's run method executes an API call
hook.get_conn()
mock_connect.assert_called_once_with(
host=test_default_conn.host,
api_key="token",
runtime=Runtime.ATLANTIS,
runtime=Runtime.LARGE,
wait_timeout=hook.session_wait_timeout,
read_timeout=hook.read_timeout,
reuse_session=True,
)
2 changes: 1 addition & 1 deletion tests/unit_tests/operator/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_default_handler(self, mock_connect):
operator = WherobotsSqlOperator(
task_id="test_task",
sql="select * from table_a",
runtime_id=Runtime.ATLANTIS,
runtime=Runtime.LARGE,
)
result = operator.execute(context={})
print(result)

0 comments on commit ac3b4eb

Please sign in to comment.