Skip to content

Commit

Permalink
[Trino] Refactor Trino-related code for readability and maintainabili…
Browse files Browse the repository at this point in the history
…ty (#3637)
  • Loading branch information
agl29 authored Mar 2, 2024
1 parent 55ab28d commit a17d499
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 87 deletions.
6 changes: 3 additions & 3 deletions desktop/core/src/desktop/js/apps/notebook/snippet.js
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ class Snippet {

if (self.type() === 'trino') {
const existing_handle = self.result.handle();
existing_handle.row_n = data.handle.row_n;
existing_handle.row_count = data.handle.row_count;
existing_handle.next_uri = data.handle.next_uri;
}
self.showLogs(true);
Expand Down Expand Up @@ -2189,7 +2189,7 @@ class Snippet {

if (self.type() === 'trino') {
const existing_handle = self.result.handle();
existing_handle.row_n = data.result.row_n;
existing_handle.row_count = data.result.row_count;
existing_handle.next_uri = data.result.next_uri;
}
} else {
Expand Down Expand Up @@ -2369,7 +2369,7 @@ class Snippet {
) {
if (self.type() === 'trino') {
const existing_handle = self.result.handle();
existing_handle.row_n = 0;
existing_handle.row_count = 0;
existing_handle.next_uri = data.query_status.next_uri;
}
const delay = self.result.executionTime() > 45000 ? 5000 : 1000; // 5s if more than 45s
Expand Down
149 changes: 65 additions & 84 deletions desktop/libs/notebook/src/notebook/connectors/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import logging
import json
import posixpath
import requests
import sys
import textwrap
Expand All @@ -26,6 +25,8 @@
from django.utils.translation import gettext as _
from urllib.parse import urlparse

from beeswax import conf
from beeswax import data_export
from desktop.lib import export_csvxls
from desktop.lib.i18n import force_unicode
from desktop.lib.rest.http_client import HttpClient, RestException
Expand All @@ -36,6 +37,7 @@
from trino.auth import BasicAuthentication
from trino.client import ClientSession, TrinoRequest, TrinoQuery


def query_error_handler(func):
def decorator(*args, **kwargs):
try:
Expand All @@ -53,14 +55,10 @@ def decorator(*args, **kwargs):
return decorator



class TrinoApi(Api):

def __init__(self, user, interpreter=None):
Api.__init__(self, user, interpreter=interpreter)

self.options = interpreter['options']

self.server_host, self.server_port, self.http_scheme = self.parse_api_url(self.options['url'])
self.auth = None

Expand All @@ -70,8 +68,7 @@ def __init__(self, user, interpreter=None):
self.auth = BasicAuthentication(self.auth_username, self.auth_password)

trino_session = ClientSession(user.username)

self.db = TrinoRequest(
self.trino_request = TrinoRequest(
host=self.server_host,
port=self.server_port,
client_session=trino_session,
Expand All @@ -93,16 +90,16 @@ def create_session(self, lang=None, properties=None):
@query_error_handler
def execute(self, notebook, snippet):
database = snippet['database']
query_client = TrinoQuery(self.db, 'USE ' + database)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()

statement = snippet['statement'].rstrip(';')
query_client = TrinoQuery(self.db, statement)
response = self.db.post(query_client.query)
status = self.db.process(response)
query_client = TrinoQuery(self.trino_request, statement)
response = self.trino_request.post(query_client.query)
status = self.trino_request.process(response)

return {
'row_n': 0,
'row_count': 0,
'next_uri': status.next_uri,
'sync': None,
'has_result_set': status.next_uri is not None,
Expand All @@ -127,84 +124,75 @@ def execute(self, notebook, snippet):
def check_status(self, notebook, snippet):
response = {}
status = 'expired'
next_uri = snippet['result']['handle']['next_uri']

if snippet['result']['handle']['next_uri'] is None:
if next_uri is None:
status = 'available'
else:
_response = self.db.get(snippet['result']['handle']['next_uri'])
_status = self.db.process(_response)
_response = self.trino_request.get(next_uri)
_status = self.trino_request.process(_response)
if _status.stats['state'] == 'QUEUED':
status = 'waiting'
elif _status.stats['state'] == 'RUNNING':
status = 'available' # need to varify
status = 'available' # need to verify
else:
status = 'available'

response['status'] = status

if status != 'available':
response['next_uri'] = _status.next_uri
else:
response['next_uri'] = snippet['result']['handle']['next_uri']

response['next_uri'] = _status.next_uri if status != 'available' else next_uri
return response


@query_error_handler
def fetch_result(self, notebook, snippet, rows, start_over):
data = []
_columns = []
_next_uri = snippet['result']['handle']['next_uri']
processed_rows = snippet['result']['handle'].get('row_n', 0)
columns = []
next_uri = snippet['result']['handle']['next_uri']
processed_rows = snippet['result']['handle'].get('row_count', 0)
status = False

if processed_rows == 0:
data = snippet['result']['handle']['result']['data']

while _next_uri:
while next_uri:
try:
response = self.db.get(_next_uri)
response = self.trino_request.get(next_uri)
except requests.exceptions.RequestException as e:
raise trino.exceptions.TrinoConnectionError("failed to fetch: {}".format(e))

status = self.db.process(response)
status = self.trino_request.process(response)
data += status.rows
_columns = status.columns
columns = status.columns

if len(data) >= processed_rows + 100:
if processed_rows < 0:
data = data[0:100]
data = data[:100]
else:
data = data[processed_rows:processed_rows + 100]
break

_next_uri = status.next_uri
next_uri = status.next_uri
current_length = len(data)
data = data[processed_rows:processed_rows + 100]
processed_rows = processed_rows - current_length
processed_rows -= current_length

return {
'row_n': 100 + processed_rows,
'next_uri': _next_uri,
'has_more': bool(status.next_uri) if status else False,
'data': data or [],
'meta': [{
'name': column['name'],
'type': column['type'],
'comment': ''
}
for column in _columns if status
],
'type': 'table'
'row_count': 100 + processed_rows,
'next_uri': next_uri,
'has_more': bool(status.next_uri) if status else False,
'data': data or [],
'meta': [{
'name': column['name'],
'type': column['type'],
'comment': ''
} for column in columns] if status else [],
'type': 'table'
}


@query_error_handler
def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
response = {}

# if catalog is None:
# response['catalogs'] = self._show_catalogs()
if database is None:
response['databases'] = self._show_databases()
elif table is None:
Expand All @@ -213,10 +201,10 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N
columns = self._get_columns(database, table)
response['columns'] = [col['name'] for col in columns]
response['extended_columns'] = [{
'comment': col.get('comment'),
'name': col.get('name'),
'type': col['type']
}
'comment': col.get('comment'),
'name': col.get('name'),
'type': col['type']
}
for col in columns
]

Expand All @@ -225,9 +213,8 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N

@query_error_handler
def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):

statement = self._get_select_query(database, table, column, operation)
query_client = TrinoQuery(self.db, statement)
query_client = TrinoQuery(self.trino_request, statement)
query_client.execute()

response = {
Expand All @@ -239,7 +226,7 @@ def get_sample_data(self, snippet, database=None, table=None, column=None, is_as
response['full_headers'] = query_client.columns

return response


def _get_select_query(self, database, table, column=None, operation=None, limit=100):
if operation == 'hello':
Expand All @@ -251,19 +238,19 @@ def _get_select_query(self, database, table, column=None, operation=None, limit=
FROM %(database)s.%(table)s
LIMIT %(limit)s
''' % {
'database': database,
'table': table,
'column': column,
'limit': limit,
})
'database': database,
'table': table,
'column': column,
'limit': limit,
})

return statement


def close_statement(self, notebook, snippet):
try:
if snippet['result']['handle']['next_uri']:
self.db.delete(snippet['result']['handle']['next_uri'])
self.trino_request.delete(snippet['result']['handle']['next_uri'])
else:
return {'status': -1} # missing operation ids
except Exception as e:
Expand All @@ -285,17 +272,15 @@ def _show_databases(self):
databases = []

for catalog in catalogs:

query_client = TrinoQuery(self.db, 'SHOW SCHEMAS FROM ' + catalog)
query_client = TrinoQuery(self.trino_request, 'SHOW SCHEMAS FROM ' + catalog)
response = query_client.execute()
databases += [f'{catalog}.{item}' for sublist in response.rows for item in sublist]

return databases


def _show_catalogs(self):

query_client = TrinoQuery(self.db, 'SHOW CATALOGS')
query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
response = query_client.execute()
res = response.rows
catalogs = [item for sublist in res for item in sublist]
Expand All @@ -304,41 +289,37 @@ def _show_catalogs(self):


def _show_tables(self, database):

query_client = TrinoQuery(self.db, 'USE ' + database)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()
query_client = TrinoQuery(self.db, 'SHOW TABLES')
query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
response = query_client.execute()
tables = response.rows
return [{
'name': table[0],
'type': 'table',
'comment': '',
}
'name': table[0],
'type': 'table',
'comment': '',
}
for table in tables
]


def _get_columns(self, database, table):

query_client = TrinoQuery(self.db, 'USE ' + database)
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
query_client.execute()
query_client = TrinoQuery(self.db, 'DESCRIBE ' + table)
query_client = TrinoQuery(self.trino_request, 'DESCRIBE ' + table)
response = query_client.execute()
columns = response.rows

return [{
'name': col[0],
'type': col[1],
'comment': '',
}
'name': col[0],
'type': col[1],
'comment': '',
}
for col in columns
]

def download(self, notebook, snippet, file_format='csv'):
from beeswax import data_export #TODO: Move to notebook?
from beeswax import conf


def download(self, notebook, snippet, file_format='csv'):
result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)

max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
Expand Down Expand Up @@ -369,7 +350,7 @@ def fetch(self, handle, start_over=None, rows=None):
result = self.snippet['result']['handle']['result']
else:
result = self.api.fetch_result(self.notebook, self.snippet, rows, start_over)
self.snippet['result']['handle']['row_n'] = result['row_n']
self.snippet['result']['handle']['row_count'] = result['row_count']
self.snippet['result']['handle']['next_uri'] = result['next_uri']

return ResultWrapper(result.get('meta'), result.get('data'), result.get('has_more'))
Expand Down

0 comments on commit a17d499

Please sign in to comment.