diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 033cb61f..713e055d 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -16,10 +16,11 @@ import pytest import sqlalchemy as sqla from sqlalchemy.sql import and_, not_, or_ +from sqlalchemy.types import ARRAY from tests.integration.conftest import trino_version from tests.unit.conftest import sqlalchemy_version -from trino.sqlalchemy.datatype import JSON, MAP +from trino.sqlalchemy.datatype import JSON, MAP, ROW @pytest.fixture @@ -528,6 +529,109 @@ def test_map_column(trino_connection, map_object, sqla_type): metadata.drop_all(engine) +@pytest.mark.skipif( + sqlalchemy_version() < "1.4", + reason="columns argument to select() must be a Python list or other iterable" +) +@pytest.mark.parametrize( + 'trino_connection,array_object,sqla_type', + [ + ('memory', None, ARRAY(sqla.sql.sqltypes.String)), + ('memory', [], ARRAY(sqla.sql.sqltypes.String)), + ('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)), + ('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)), + ('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)), + ('memory', [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))), + ('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)), + ('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))), + ('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)), + ], + indirect=['trino_connection'] +) +def test_array_column(trino_connection, array_object, sqla_type): + engine, conn = trino_connection + + if not engine.dialect.has_schema(conn, "test"): + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) + metadata = sqla.MetaData() + + try: + table_with_array = sqla.Table( + 'table_with_array', + metadata, + sqla.Column('id', sqla.Integer), + sqla.Column('array_column', sqla_type), + schema="test" + ) + metadata.create_all(engine) + ins = table_with_array.insert() + conn.execute(ins, {"id": 1, "array_column": array_object}) + query = sqla.select(table_with_array) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 1 + assert rows[0] == (1, array_object) + finally: + metadata.drop_all(engine) + + +@pytest.mark.skipif( + sqlalchemy_version() < "1.4", + reason="columns argument to select() must be a Python list or other iterable" +) +@pytest.mark.parametrize( + 'trino_connection,row_object,sqla_type', + [ + ('memory', None, ROW([('field1', sqla.sql.sqltypes.String), + ('field2', sqla.sql.sqltypes.String)])), + ('memory', ('hello', 'world'), ROW([('field1', sqla.sql.sqltypes.String), + ('field2', sqla.sql.sqltypes.String)])), + ('memory', (True, False), ROW([('field1', sqla.sql.sqltypes.Boolean), + ('field2', sqla.sql.sqltypes.Boolean)])), + ('memory', (1, 2), ROW([('field1', sqla.sql.sqltypes.Integer), + ('field2', sqla.sql.sqltypes.Integer)])), + ('memory', (1.4, float('inf')), ROW([('field1', sqla.sql.sqltypes.Float), + ('field2', sqla.sql.sqltypes.Float)])), + ('memory', (Decimal("1.2"), Decimal("2.3")), ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)), + ('field2', sqla.sql.sqltypes.DECIMAL(3, 1))])), + ('memory', ("hello", "world"), ROW([('field1', sqla.sql.sqltypes.String), + ('field2', sqla.sql.sqltypes.String)])), + ('memory', ("a ", "null"), ROW([('field1', sqla.sql.sqltypes.CHAR(4)), + ('field2', sqla.sql.sqltypes.CHAR(4))])), + ('memory', (b'eh?', b'oh?'), ROW([('field1', sqla.sql.sqltypes.BINARY), + ('field2', sqla.sql.sqltypes.BINARY)])), + ], + indirect=['trino_connection'] +) +def test_row_column(trino_connection, row_object, sqla_type): + engine, conn = trino_connection + + if not engine.dialect.has_schema(conn, "test"): + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) + metadata = sqla.MetaData() + + try: + table_with_row = sqla.Table( + 'table_with_row', + metadata, + sqla.Column('id', sqla.Integer), + sqla.Column('row_column', sqla_type), + schema="test" + ) + metadata.create_all(engine) + ins = table_with_row.insert() + conn.execute(ins, {"id": 1, "row_column": row_object}) + query = sqla.select(table_with_row) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 1 + assert rows[0] == (1, row_object) + finally: + metadata.drop_all(engine) + + @pytest.mark.parametrize('trino_connection', ['system'], indirect=True) def test_get_catalog_names(trino_connection): engine, conn = trino_connection diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 5f83d984..cc543ae7 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -253,6 +253,12 @@ def visit_MAP(self, type_, **kw): value_type = self.process(type_.value_type, **kw) return f'MAP({key_type}, {value_type})' + def visit_ARRAY(self, type_, **kw): + return f'ARRAY({self.process(type_.item_type, **kw)})' + + def visit_ROW(self, type_, **kw): + return f'ROW({", ".join(f"{name} {self.process(attr_type, **kw)}" for name, attr_type in type_.attr_types)})' + class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS