diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 24b4b85..9507d05 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -21,7 +21,7 @@ jobs: - name: Check out code uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: python-version: 3.8 @@ -41,7 +41,7 @@ jobs: pytest --cov=firebolt_provider/ tests/ --cov-report=xml - name: Upload coverage report - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: pytest-coverage-report path: coverage.xml diff --git a/firebolt_provider/hooks/firebolt.py b/firebolt_provider/hooks/firebolt.py index bdb7d05..1c73a33 100644 --- a/firebolt_provider/hooks/firebolt.py +++ b/firebolt_provider/hooks/firebolt.py @@ -18,7 +18,7 @@ # from collections import namedtuple -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from airflow.version import version as airflow_version from firebolt.client import DEFAULT_API_URL @@ -61,6 +61,7 @@ class FireboltHook(DbApiHook): default_conn_name = "firebolt_default" conn_type = "firebolt" hook_name = "Firebolt" + supports_autocommit = False ConnectionParameters = namedtuple( "ConnectionParameters", @@ -201,38 +202,6 @@ def engine_action(self, engine_name: Optional[str], action: str) -> None: """ self._run_action(self._get_engine(engine_name), action) - def run( - self, - sql: Union[str, List], - autocommit: bool = False, - parameters: Optional[Sequence] = None, - handler: Optional[Callable] = None, - ) -> None: - """ - Runs a command or a list of commands. Pass a list of sql - statements to the sql parameter to get them to execute - sequentially - :param sql: the sql statement to be executed (str) or a list of - sql statements to execute - :type sql: str or list - :param autocommit: What to set the connection's autocommit setting to - before executing the query. - :type autocommit: bool - :param parameters: The parameters to render the SQL query with. - :type parameters: dict or iterable - """ - scalar = isinstance(sql, str) - if scalar: - sql = [sql] - with self.get_conn() as conn: - with conn.cursor() as cursor: - for sql_statement in sql: - if parameters: - cursor.execute(sql_statement, parameters) - else: - cursor.execute(sql_statement) - self.log.info(f"Rows returned: {cursor.rowcount}") - def test_connection(self) -> Tuple[bool, str]: """Test the Firebolt connection by running a simple query.""" try: diff --git a/tests/hooks/test_firebolt.py b/tests/hooks/test_firebolt.py index 9177065..0cbca86 100644 --- a/tests/hooks/test_firebolt.py +++ b/tests/hooks/test_firebolt.py @@ -22,6 +22,7 @@ from unittest import mock from unittest.mock import MagicMock, patch +from airflow.providers.common.sql.hooks.sql import fetch_all_handler from firebolt.client.auth import ClientCredentials, UsernamePassword from firebolt.utils.exception import FireboltError @@ -157,20 +158,18 @@ def test_run_with_parameters(self): sql = "SQL" parameters = ("param1", "param2") self.db_hook.run(sql=sql, parameters=parameters) - self.conn.__enter__().cursor().__enter__().execute.assert_called_once_with( - sql, parameters - ) + self.conn.cursor().execute.assert_called_once_with(sql, parameters) def test_run_with_single_query(self): sql = "SQL" self.db_hook.run(sql) - self.conn.__enter__().cursor().__enter__().execute.assert_called_once_with(sql) + self.conn.cursor().execute.assert_called_once_with(sql) def test_run_multi_queries(self): sql = ["SQL1", "SQL2"] self.db_hook.run(sql, autocommit=True) for query in sql: - self.conn.__enter__().cursor().__enter__().execute.assert_any_call(query) + self.conn.cursor().execute.assert_any_call(query) def test_get_ui_field_behaviour(self): widget = { @@ -237,3 +236,26 @@ def test_engine_action_start_default(self, conn_params_call): with self.assertRaises(FireboltError): self.db_hook.engine_action(None, "start") + + def test_run_returns_results(self): + sql = ["SQL1", "SQL2"] + self.cursor.fetchall.return_value = [(1, 2)] + res = self.db_hook.run(sql, handler=fetch_all_handler) + assert res == [[(1, 2)], [(1, 2)]] + + sql = "SQL1; SQL2" + res = self.db_hook.run( + sql, + handler=fetch_all_handler, + return_last=False, + split_statements=True, + ) + assert res == [[(1, 2)], [(1, 2)]] + + res = self.db_hook.run( + sql, + handler=fetch_all_handler, + return_last=True, + split_statements=True, + ) + assert res == [(1, 2)]