Skip to content

Commit

Permalink
Fix JSON support in SQLAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco authored and hashhar committed Dec 22, 2023
1 parent 86ed4da commit ab0e596
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 11 deletions.
66 changes: 66 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def test_json_column(trino_connection, json_object):
ins = table_with_json.insert()
conn.execute(ins, {"id": 1, "json_column": json_object})
query = sqla.select(table_with_json)
assert isinstance(table_with_json.c.json_column.type, JSON)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
Expand All @@ -410,6 +411,71 @@ def test_json_column(trino_connection, json_object):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_json_column_operations(trino_connection):
engine, conn = trino_connection

metadata = sqla.MetaData()

json_object = {
"a": {"c": 1},
100: {"z": 200},
"b": 2,
10: 20,
"foo-bar": {"z": 200}
}

try:
table_with_json = sqla.Table(
'table_with_json',
metadata,
sqla.Column('json_column', JSON),
schema="default"
)
metadata.create_all(engine)
ins = table_with_json.insert()
conn.execute(ins, {"json_column": json_object})

# JSONPathType
query = sqla.select(table_with_json.c.json_column["a", "c"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 1

query = sqla.select(table_with_json.c.json_column[100, "z"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 200

query = sqla.select(table_with_json.c.json_column["foo-bar", "z"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 200

# JSONIndexType
query = sqla.select(table_with_json.c.json_column["b"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 2

query = sqla.select(table_with_json.c.json_column[10])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 20

query = sqla.select(table_with_json.c.json_column["foo-bar"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == {'z': 200}

finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_get_catalog_names(trino_connection):
engine, conn = trino_connection
Expand Down
15 changes: 14 additions & 1 deletion trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, sqltypes
from sqlalchemy.sql.base import DialectKWArgs

# https://trino.io/docs/current/language/reserved.html
Expand Down Expand Up @@ -125,6 +125,19 @@ def add_catalog(sql, table):
sql = f'"{catalog}".{sql}'
return sql

def visit_json_getitem_op_binary(self, binary, operator, **kw):
return self._render_json_extract_from_binary(binary, operator, **kw)

def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return self._render_json_extract_from_binary(binary, operator, **kw)

def _render_json_extract_from_binary(self, binary, operator, **kw):
if binary.type._type_affinity is sqltypes.JSON:
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)


class TrinoDDLCompiler(compiler.DDLCompiler):
pass
Expand Down
62 changes: 52 additions & 10 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

import sqlalchemy
from sqlalchemy import util
from sqlalchemy import func, util
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
from sqlalchemy.types import String
from sqlalchemy.types import JSON

SQLType = Union[TypeEngine, Type[TypeEngine]]

Expand Down Expand Up @@ -75,16 +74,59 @@ def __init__(self, precision=None, timezone=False):


class JSON(TypeDecorator):
impl = String
impl = JSON

def process_bind_param(self, value, dialect):
return json.dumps(value)
def bind_expression(self, bindvalue):
return func.JSON_PARSE(bindvalue)

def process_result_value(self, value, dialect):
return json.loads(value)

def get_col_spec(self, **kw):
return 'JSON'
class _FormatTypeMixin:
def _format_value(self, value):
raise NotImplementedError()

def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process

def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process


class _JSONFormatter:
@staticmethod
def format_index(value):
return "$[\"%s\"]" % value

@staticmethod
def format_path(value):
return "$%s" % (
"".join(["[\"%s\"]" % elem for elem in value])
)


class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
return _JSONFormatter.format_index(value)


class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return _JSONFormatter.format_path(value)


# https://trino.io/docs/current/language/types.html
Expand Down
17 changes: 17 additions & 0 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from trino import dbapi as trino_dbapi
from trino import logging
Expand All @@ -31,10 +32,25 @@
from trino.dbapi import Cursor
from trino.sqlalchemy import compiler, datatype, error

from .datatype import JSONIndexType, JSONPathType

logger = logging.get_logger(__name__)

colspecs = {
sqltypes.JSON.JSONIndexType: JSONIndexType,
sqltypes.JSON.JSONPathType: JSONPathType,
}


class TrinoDialect(DefaultDialect):
def __init__(self,
json_serializer=None,
json_deserializer=None,
**kwargs):
DefaultDialect.__init__(self, **kwargs)
self._json_serializer = json_serializer
self._json_deserializer = json_deserializer

name = "trino"
driver = "rest"

Expand Down Expand Up @@ -70,6 +86,7 @@ class TrinoDialect(DefaultDialect):

# Support proper ordering of CTEs in regard to an INSERT statement
cte_follows_insert = True
colspecs = colspecs

@classmethod
def dbapi(cls):
Expand Down

0 comments on commit ab0e596

Please sign in to comment.