Skip to content

Commit

Permalink
feat: Return query result from hook (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Jan 13, 2025
1 parent 7296250 commit ba5a9f6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 40 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Check out code
uses: actions/checkout@v2

- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8
Expand All @@ -41,7 +41,7 @@ jobs:
pytest --cov=firebolt_provider/ tests/ --cov-report=xml
- name: Upload coverage report
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: pytest-coverage-report
path: coverage.xml
Expand Down
35 changes: 2 additions & 33 deletions firebolt_provider/hooks/firebolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#

from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

from airflow.version import version as airflow_version
from firebolt.client import DEFAULT_API_URL
Expand Down Expand Up @@ -61,6 +61,7 @@ class FireboltHook(DbApiHook):
default_conn_name = "firebolt_default"
conn_type = "firebolt"
hook_name = "Firebolt"
supports_autocommit = False

ConnectionParameters = namedtuple(
"ConnectionParameters",
Expand Down Expand Up @@ -201,38 +202,6 @@ def engine_action(self, engine_name: Optional[str], action: str) -> None:
"""
self._run_action(self._get_engine(engine_name), action)

def run(
self,
sql: Union[str, List],
autocommit: bool = False,
parameters: Optional[Sequence] = None,
handler: Optional[Callable] = None,
) -> None:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
sequentially
:param sql: the sql statement to be executed (str) or a list of
sql statements to execute
:type sql: str or list
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
:type autocommit: bool
:param parameters: The parameters to render the SQL query with.
:type parameters: dict or iterable
"""
scalar = isinstance(sql, str)
if scalar:
sql = [sql]
with self.get_conn() as conn:
with conn.cursor() as cursor:
for sql_statement in sql:
if parameters:
cursor.execute(sql_statement, parameters)
else:
cursor.execute(sql_statement)
self.log.info(f"Rows returned: {cursor.rowcount}")

def test_connection(self) -> Tuple[bool, str]:
"""Test the Firebolt connection by running a simple query."""
try:
Expand Down
32 changes: 27 additions & 5 deletions tests/hooks/test_firebolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock
from unittest.mock import MagicMock, patch

from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from firebolt.client.auth import ClientCredentials, UsernamePassword
from firebolt.utils.exception import FireboltError

Expand Down Expand Up @@ -157,20 +158,18 @@ def test_run_with_parameters(self):
sql = "SQL"
parameters = ("param1", "param2")
self.db_hook.run(sql=sql, parameters=parameters)
self.conn.__enter__().cursor().__enter__().execute.assert_called_once_with(
sql, parameters
)
self.conn.cursor().execute.assert_called_once_with(sql, parameters)

def test_run_with_single_query(self):
sql = "SQL"
self.db_hook.run(sql)
self.conn.__enter__().cursor().__enter__().execute.assert_called_once_with(sql)
self.conn.cursor().execute.assert_called_once_with(sql)

def test_run_multi_queries(self):
sql = ["SQL1", "SQL2"]
self.db_hook.run(sql, autocommit=True)
for query in sql:
self.conn.__enter__().cursor().__enter__().execute.assert_any_call(query)
self.conn.cursor().execute.assert_any_call(query)

def test_get_ui_field_behaviour(self):
widget = {
Expand Down Expand Up @@ -237,3 +236,26 @@ def test_engine_action_start_default(self, conn_params_call):

with self.assertRaises(FireboltError):
self.db_hook.engine_action(None, "start")

def test_run_returns_results(self):
sql = ["SQL1", "SQL2"]
self.cursor.fetchall.return_value = [(1, 2)]
res = self.db_hook.run(sql, handler=fetch_all_handler)
assert res == [[(1, 2)], [(1, 2)]]

sql = "SQL1; SQL2"
res = self.db_hook.run(
sql,
handler=fetch_all_handler,
return_last=False,
split_statements=True,
)
assert res == [[(1, 2)], [(1, 2)]]

res = self.db_hook.run(
sql,
handler=fetch_all_handler,
return_last=True,
split_statements=True,
)
assert res == [(1, 2)]

0 comments on commit ba5a9f6

Please sign in to comment.