diff --git a/desktop/libs/notebook/src/notebook/connectors/trino.py b/desktop/libs/notebook/src/notebook/connectors/trino.py index fdbdbfcea98..99f6ddc6c73 100644 --- a/desktop/libs/notebook/src/notebook/connectors/trino.py +++ b/desktop/libs/notebook/src/notebook/connectors/trino.py @@ -108,12 +108,13 @@ def execute(self, notebook, snippet): query_client = TrinoQuery(self.trino_request, 'USE ' + database) query_client.execute() - statement = snippet['statement'].rstrip(';') + current_statement = self._get_current_statement(notebook, snippet) + statement = current_statement['statement'] query_client = TrinoQuery(self.trino_request, statement) response = self.trino_request.post(query_client.query) status = self.trino_request.process(response) - return { + response = { 'row_count': 0, 'next_uri': status.next_uri, 'sync': None, @@ -133,6 +134,9 @@ def execute(self, notebook, snippet): 'type': 'table' } } + response.update(current_statement) + + return response @query_error_handler diff --git a/desktop/libs/notebook/src/notebook/connectors/trino_tests.py b/desktop/libs/notebook/src/notebook/connectors/trino_tests.py index cd21e51d6c2..b5d9b363127 100644 --- a/desktop/libs/notebook/src/notebook/connectors/trino_tests.py +++ b/desktop/libs/notebook/src/notebook/connectors/trino_tests.py @@ -156,7 +156,38 @@ def test_execute(self): mock_trino_request.process.return_value = MagicMock(stats={'state': 'FINISHED'}, next_uri='http://url', id=123, rows=[]) # Call the execute method - result = self.trino_api.execute(notebook={}, snippet={'database': 'test_db', 'statement': 'SELECT * FROM test_table;'}) + snippet = { + 'database': 'test_db', + 'statement': 'SELECT * FROM test_table;', + 'result': {'handle': {}} + } + result = self.trino_api.execute(notebook={}, snippet=snippet) + + expected_result = { + 'row_count': 0, + 'next_uri': 'http://url', + 'sync': None, + 'has_result_set': True, + 'guid': 123, + 'result': { + 'has_more': True, + 'data': [], + 'meta': [], + 'type': 'table' + }, + 'statement_id': 0, 'has_more_statements': False, 'statements_count': 1, + 'previous_statement_hash': 'd1c7e7dd8869098919761253c921eea865d48ca79d4e43092c321cfd', + 'start': {'row': 0, 'column': 0}, 'end': {'row': 0, 'column': 23}, 'statement': 'SELECT * FROM test_table' + } + assert result == expected_result + + # Test multiple query execution + snippet = { + 'database': 'test_db', + 'statement': 'use test_db;\nshow tables', + 'result': {'handle': {}} + } + result = self.trino_api.execute(notebook={}, snippet=snippet) expected_result = { 'row_count': 0, @@ -169,7 +200,10 @@ def test_execute(self): 'data': [], 'meta': [], 'type': 'table' - } + }, + 'statement_id': 0, 'has_more_statements': True, 'statements_count': 2, + 'previous_statement_hash': '793204944f1800a86d75684d4be11eccb03b35f68441febb1362fd35', + 'start': {'row': 0, 'column': 0}, 'end': {'row': 0, 'column': 12}, 'statement': 'use test_db' } assert result == expected_result