Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ROW and ARRAY in TrinoTypeCompiler #464

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import pytest
import sqlalchemy as sqla
from sqlalchemy.sql import and_, not_, or_
from sqlalchemy.types import ARRAY

from tests.integration.conftest import trino_version
from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON, MAP
from trino.sqlalchemy.datatype import JSON, MAP, ROW


@pytest.fixture
Expand Down Expand Up @@ -528,6 +529,109 @@ def test_map_column(trino_connection, map_object, sqla_type):
metadata.drop_all(engine)


@pytest.mark.skipif(
    sqlalchemy_version() < "1.4",
    reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
    'trino_connection,array_object,sqla_type',
    [
        # NULL and empty arrays
        ('memory', None, ARRAY(sqla.sql.sqltypes.String)),
        ('memory', [], ARRAY(sqla.sql.sqltypes.String)),
        # one case per element type
        ('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)),
        ('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)),
        ('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)),
        (
            'memory',
            [Decimal("1.2"), Decimal("2.3")],
            ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1)),
        ),
        ('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)),
        # NOTE(review): CHAR(4) values are presumably space-padded on read;
        # confirm the trailing whitespace of the first literal is intended.
        ('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))),
        ('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)),
    ],
    indirect=['trino_connection']
)
def test_array_column(trino_connection, array_object, sqla_type):
    """Round-trip an ARRAY value through a freshly created table.

    Creates ``test.table_with_array`` with an ``array_column`` of the
    parametrized ``sqla_type``, inserts a single row holding
    ``array_object``, selects it back and checks the fetched row equals
    ``(1, array_object)``. The table is dropped afterwards even when an
    assertion fails.
    """
    engine, conn = trino_connection
    meta = sqla.MetaData()

    # Make sure the target schema exists before any table is created in it.
    schema_exists = engine.dialect.has_schema(conn, "test")
    if not schema_exists:
        with engine.begin() as setup_conn:
            setup_conn.execute(sqla.schema.CreateSchema("test"))

    try:
        array_table = sqla.Table(
            'table_with_array',
            meta,
            sqla.Column('id', sqla.Integer),
            sqla.Column('array_column', sqla_type),
            schema="test"
        )
        meta.create_all(engine)

        conn.execute(array_table.insert(), {"id": 1, "array_column": array_object})

        fetched = conn.execute(sqla.select(array_table)).fetchall()
        assert len(fetched) == 1
        assert fetched[0] == (1, array_object)
    finally:
        # Always clean up so later parametrized runs start from scratch.
        meta.drop_all(engine)


@pytest.mark.skipif(
    sqlalchemy_version() < "1.4",
    reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
    'trino_connection,row_object,sqla_type',
    [
        # NULL row
        (
            'memory',
            None,
            ROW([('field1', sqla.sql.sqltypes.String),
                 ('field2', sqla.sql.sqltypes.String)]),
        ),
        # one case per field type
        (
            'memory',
            ('hello', 'world'),
            ROW([('field1', sqla.sql.sqltypes.String),
                 ('field2', sqla.sql.sqltypes.String)]),
        ),
        (
            'memory',
            (True, False),
            ROW([('field1', sqla.sql.sqltypes.Boolean),
                 ('field2', sqla.sql.sqltypes.Boolean)]),
        ),
        (
            'memory',
            (1, 2),
            ROW([('field1', sqla.sql.sqltypes.Integer),
                 ('field2', sqla.sql.sqltypes.Integer)]),
        ),
        (
            'memory',
            (1.4, float('inf')),
            ROW([('field1', sqla.sql.sqltypes.Float),
                 ('field2', sqla.sql.sqltypes.Float)]),
        ),
        (
            'memory',
            (Decimal("1.2"), Decimal("2.3")),
            ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)),
                 ('field2', sqla.sql.sqltypes.DECIMAL(3, 1))]),
        ),
        # NOTE(review): same value and type as the ('hello', 'world') String
        # case above; kept to preserve the existing set of test cases.
        (
            'memory',
            ("hello", "world"),
            ROW([('field1', sqla.sql.sqltypes.String),
                 ('field2', sqla.sql.sqltypes.String)]),
        ),
        # NOTE(review): CHAR(4) values are presumably space-padded on read;
        # confirm the trailing whitespace of the first literal is intended.
        (
            'memory',
            ("a ", "null"),
            ROW([('field1', sqla.sql.sqltypes.CHAR(4)),
                 ('field2', sqla.sql.sqltypes.CHAR(4))]),
        ),
        (
            'memory',
            (b'eh?', b'oh?'),
            ROW([('field1', sqla.sql.sqltypes.BINARY),
                 ('field2', sqla.sql.sqltypes.BINARY)]),
        ),
    ],
    indirect=['trino_connection']
)
def test_row_column(trino_connection, row_object, sqla_type):
    """Round-trip a ROW value through a freshly created table.

    Creates ``test.table_with_row`` with a ``row_column`` of the
    parametrized ``sqla_type``, inserts a single row holding
    ``row_object``, selects it back and checks the fetched row equals
    ``(1, row_object)``. The table is dropped afterwards even when an
    assertion fails.
    """
    engine, conn = trino_connection
    meta = sqla.MetaData()

    # Make sure the target schema exists before any table is created in it.
    schema_exists = engine.dialect.has_schema(conn, "test")
    if not schema_exists:
        with engine.begin() as setup_conn:
            setup_conn.execute(sqla.schema.CreateSchema("test"))

    try:
        row_table = sqla.Table(
            'table_with_row',
            meta,
            sqla.Column('id', sqla.Integer),
            sqla.Column('row_column', sqla_type),
            schema="test"
        )
        meta.create_all(engine)

        conn.execute(row_table.insert(), {"id": 1, "row_column": row_object})

        fetched = conn.execute(sqla.select(row_table)).fetchall()
        assert len(fetched) == 1
        assert fetched[0] == (1, row_object)
    finally:
        # Always clean up so later parametrized runs start from scratch.
        meta.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_get_catalog_names(trino_connection):
engine, conn = trino_connection
Expand Down
6 changes: 6 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ def visit_MAP(self, type_, **kw):
value_type = self.process(type_.value_type, **kw)
return f'MAP({key_type}, {value_type})'

def visit_ARRAY(self, type_, **kw):
    """Return the ``ARRAY(<element type>)`` type string for ``type_``.

    The element type is compiled by delegating ``type_.item_type`` to
    ``self.process`` with the same keyword arguments.
    """
    element_sql = self.process(type_.item_type, **kw)
    return 'ARRAY(' + element_sql + ')'

def visit_ROW(self, type_, **kw):
    """Return the ``ROW(...)`` type string for ``type_``.

    Each ``(name, attr_type)`` pair in ``type_.attr_types`` is rendered as
    ``"<name> <compiled type>"``, with the type compiled via
    ``self.process``. A pair whose name is ``None`` is treated as an
    anonymous field and rendered as the compiled type alone, instead of
    producing a field literally named ``None``.
    """
    fields = []
    for name, attr_type in type_.attr_types:
        compiled = self.process(attr_type, **kw)
        # Anonymous fields carry no identifier in the ROW definition.
        fields.append(compiled if name is None else f"{name} {compiled}")
    return f'ROW({", ".join(fields)})'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
Expand Down