diff --git a/CHANGES.md b/CHANGES.md index 40bf45b..87cc88d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,9 @@ `check_uniqueness_factory` - Added `table_kwargs` context manager to enable pandas/Dask to support CrateDB dialect table options. +- Added support for `psycopg` and `asyncpg` drivers, by introducing the + `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect + identifiers. The asynchronous variant of `psycopg` is also supported. ## 2024/06/13 0.37.0 - Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying diff --git a/pyproject.toml b/pyproject.toml index 4f95edc..b67f780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ dependencies = [ ] [project.optional-dependencies] all = [ - "sqlalchemy-cratedb[vector]", + "sqlalchemy-cratedb[postgresql,vector]", ] develop = [ "black<25", @@ -107,6 +107,9 @@ doc = [ "crate-docs-theme>=0.26.5", "sphinx<8,>=3.5", ] +postgresql = [ + "sqlalchemy-postgresql-relaxed", +] release = [ "build<2", "twine<6", @@ -117,6 +120,7 @@ test = [ "pandas<2.3", "pueblo>=0.0.7", "pytest<9", + "pytest-asyncio<0.24", "pytest-cov<6", "pytest-mock<4", ] @@ -129,7 +133,11 @@ documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/" homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/" repository = "https://github.com/crate/sqlalchemy-cratedb" [project.entry-points."sqlalchemy.dialects"] -crate = "sqlalchemy_cratedb:dialect" +"crate" = "sqlalchemy_cratedb:dialect" +"crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3" +"crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg" +"crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async" +"crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg" [tool.black] line-length = 100 diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 6944f2c..f5850e1 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -21,6 +21,7 @@ import logging from datetime import datetime, date +from types import ModuleType from sqlalchemy import types as sqltypes from sqlalchemy.engine import default, reflection @@ -207,6 +208,12 @@ def initialize(self, connection): self.default_schema_name = \ self._get_default_schema_name(connection) + def set_isolation_level(self, dbapi_connection, level): + """ + For CrateDB, this is implemented as a noop. + """ + pass + def do_rollback(self, connection): # if any exception is raised by the dbapi, sqlalchemy by default # attempts to do a rollback crate doesn't support rollbacks. @@ -225,7 +232,21 @@ def connect(self, host=None, port=None, *args, **kwargs): use_ssl = asbool(kwargs.pop("ssl", False)) if use_ssl: servers = ["https://" + server for server in servers] - return self.dbapi.connect(servers=servers, **kwargs) + + is_module = isinstance(self.dbapi, ModuleType) + if is_module: + driver_name = self.dbapi.__name__ + else: + driver_name = self.dbapi.__class__.__name__ + if driver_name == "crate.client": + if "database" in kwargs: + del kwargs["database"] + return self.dbapi.connect(servers=servers, **kwargs) + elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]: + return self.dbapi.connect(host=host, port=port, **kwargs) + else: + raise ValueError(f"Unknown driver variant: {driver_name}") + return self.dbapi.connect(**kwargs) def _get_default_schema_name(self, connection): @@ -271,11 +292,11 @@ def get_schema_names(self, connection, **kw): def get_table_names(self, connection, schema=None, **kw): if schema is None: schema = self._get_effective_schema_name(connection) - cursor = connection.exec_driver_sql( + cursor = connection.exec_driver_sql(self._format_query( "SELECT table_name FROM information_schema.tables " "WHERE {0} = ? " "AND table_type = 'BASE TABLE' " - "ORDER BY table_name ASC, {0} ASC".format(self.schema_column), + "ORDER BY table_name ASC, {0} ASC").format(self.schema_column), (schema or self.default_schema_name, ) ) return [row[0] for row in cursor.fetchall()] @@ -297,7 +318,7 @@ def get_columns(self, connection, table_name, schema=None, **kw): "AND column_name !~ ?" \ .format(self.schema_column) cursor = connection.exec_driver_sql( - query, + self._format_query(query), (table_name, schema or self.default_schema_name, r"(.*)\[\'(.*)\'\]") # regex to filter subscript @@ -336,7 +357,7 @@ def result_fun(result): return set(rows[0] if rows else []) pk_result = engine.exec_driver_sql( - query, + self._format_query(query), (table_name, schema or self.default_schema_name) ) pks = result_fun(pk_result) @@ -377,6 +398,17 @@ def has_ilike_operator(self): server_version_info = self.server_version_info return server_version_info is not None and server_version_info >= (4, 1, 0) + def _format_query(self, query): + """ + When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`, + the paramstyle is not `qmark`, but `pyformat`. + + TODO: Review: Is it legit and sane? Are there alternatives? + """ + if self.paramstyle == "pyformat": + query = query.replace("= ?", "= %s").replace("!~ ?", "!~ %s") + return query + class DateTrunc(functions.GenericFunction): name = "date_trunc" diff --git a/src/sqlalchemy_cratedb/dialect_more.py b/src/sqlalchemy_cratedb/dialect_more.py new file mode 100644 index 0000000..0263012 --- /dev/null +++ b/src/sqlalchemy_cratedb/dialect_more.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed +from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed +from sqlalchemy_postgresql_relaxed.psycopg import ( + PGDialect_psycopg_relaxed, + PGDialectAsync_psycopg_relaxed, +) + +from sqlalchemy_cratedb import dialect + + +class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect): + """ + Provide a dialect on top of the relaxed PostgreSQL dialect. + """ + + inspector = Inspector + + # Need to manually override some methods because of polymorphic inheritance woes. + # TODO: Investigate if this can be solved using metaprogramming or other techniques. + has_schema = dialect.has_schema + has_table = dialect.has_table + get_schema_names = dialect.get_schema_names + get_table_names = dialect.get_table_names + get_view_names = dialect.get_view_names + get_columns = dialect.get_columns + get_pk_constraint = dialect.get_pk_constraint + get_foreign_keys = dialect.get_foreign_keys + get_indexes = dialect.get_indexes + + get_multi_columns = dialect.get_multi_columns + get_multi_pk_constraint = dialect.get_multi_pk_constraint + get_multi_foreign_keys = dialect.get_multi_foreign_keys + + # TODO: Those may want to go to dialect instead? + def get_multi_indexes(self, *args, **kwargs): + return [] + + def get_multi_unique_constraints(self, *args, **kwargs): + return [] + + def get_multi_check_constraints(self, *args, **kwargs): + return [] + + def get_multi_table_comment(self, *args, **kwargs): + return [] + + +class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter): + driver = "psycopg" + + @classmethod + def get_async_dialect_cls(cls, url): + return CrateDialectAsync_psycopg + + @classmethod + def import_dbapi(cls): + import psycopg + + return psycopg + + +class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter): + driver = "psycopg_async" + is_async = True + + +class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter): + driver = "asyncpg" + + # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this! + + # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle' + """ + @classmethod + def import_dbapi(cls): + import asyncpg + + return asyncpg + """ + + +dialect_urllib3 = dialect +dialect_psycopg = CrateDialect_psycopg +dialect_psycopg_async = CrateDialectAsync_psycopg +dialect_asyncpg = CrateDialect_asyncpg diff --git a/tests/engine_test.py b/tests/engine_test.py new file mode 100644 index 0000000..33e3adb --- /dev/null +++ b/tests/engine_test.py @@ -0,0 +1,81 @@ +import pytest +import sqlalchemy as sa +from sqlalchemy.dialects import registry as dialect_registry + +from sqlalchemy_cratedb import SA_VERSION, SA_2_0 + +if SA_VERSION < SA_2_0: + raise pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True) + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + +# Registering the additional dialects manually seems to be needed when running +# under tests. Apparently, manual registration is not needed under regular +# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint +# registrations in `pyproject.toml`. It is definitively weird, but c'est la vie. +dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3") +dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg") +dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg") + + +QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;") + + +def test_engine_sync_vanilla(): + """ + crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver. + """ + engine = sa.create_engine("crate://crate@localhost:4200/", echo=True) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]} + + +def test_engine_sync_urllib3(): + """ + crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver. + """ + engine = sa.create_engine("crate+urllib3://crate@localhost:4200/", isolation_level="AUTOCOMMIT", echo=True) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]} + + +def test_engine_sync_psycopg(): + """ + crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3). + """ + engine = sa.create_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True) + assert isinstance(engine, sa.engine.Engine) + with engine.connect() as connection: + result = connection.execute(QUERY) + assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'} + + +@pytest.mark.asyncio +async def test_engine_async_psycopg(): + """ + crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3). + This time, in asynchronous mode. + """ + engine = create_async_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True) + assert isinstance(engine, AsyncEngine) + async with engine.begin() as conn: + result = await conn.execute(QUERY) + assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'} + + +@pytest.mark.asyncio +async def test_engine_async_asyncpg(): + """ + crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver. + This exclusively uses asynchronous mode. + """ + from asyncpg.pgproto.types import Point + engine = create_async_engine("crate+asyncpg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True) + assert isinstance(engine, AsyncEngine) + async with engine.begin() as conn: + result = await conn.execute(QUERY) + assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': Point(10.95667, 47.18917)}