Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dialect: Add support for asyncpg and psycopg3 drivers #11

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Changelog

## Unreleased
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
identifiers. The asynchronous variant of `psycopg` is also supported.

## 2024/08/29 0.39.0
Added `quote_relation_name` support utility function
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ dependencies = [
"verlib2==0.2",
]
optional-dependencies.all = [
"sqlalchemy-cratedb[vector]",
"sqlalchemy-cratedb[postgresql,vector]",
]
optional-dependencies.develop = [
"mypy<1.12",
Expand All @@ -101,6 +101,9 @@ optional-dependencies.doc = [
"crate-docs-theme>=0.26.5",
"sphinx>=3.5,<9",
]
optional-dependencies.postgresql = [
"sqlalchemy-postgresql-relaxed",
]
optional-dependencies.release = [
"build<2",
"twine<6",
Expand All @@ -111,6 +114,7 @@ optional-dependencies.test = [
"pandas<2.3",
"pueblo>=0.0.7",
"pytest<9",
"pytest-asyncio<0.24",
"pytest-cov<6",
"pytest-mock<4",
]
Expand All @@ -121,7 +125,11 @@ urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.
urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
urls.repository = "https://github.com/crate/sqlalchemy-cratedb"
entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect"
entry-points."sqlalchemy.dialects"."crate" = "sqlalchemy_cratedb:dialect"
entry-points."sqlalchemy.dialects"."crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"
entry-points."sqlalchemy.dialects"."crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
entry-points."sqlalchemy.dialects"."crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
entry-points."sqlalchemy.dialects"."crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"

[tool.black]
line-length = 100
Expand Down
50 changes: 43 additions & 7 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logging
from datetime import date, datetime
from types import ModuleType

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
Expand Down Expand Up @@ -212,6 +213,12 @@
# get default schema name
self.default_schema_name = self._get_default_schema_name(connection)

def set_isolation_level(self, dbapi_connection, level):
    """
    Accept the requested isolation level and do nothing with it.

    For CrateDB, this hook is implemented as a no-op, so both
    `dbapi_connection` and `level` are ignored.
    """
    return None

def do_rollback(self, connection):
# if any exception is raised by the dbapi, sqlalchemy by default
# attempts to do a rollback. crate doesn't support rollbacks.
Expand All @@ -230,7 +237,21 @@
use_ssl = asbool(kwargs.pop("ssl", False))
if use_ssl:
servers = ["https://" + server for server in servers]
return self.dbapi.connect(servers=servers, **kwargs)

is_module = isinstance(self.dbapi, ModuleType)
if is_module:
driver_name = self.dbapi.__name__
else:
driver_name = self.dbapi.__class__.__name__
if driver_name == "crate.client":
if "database" in kwargs:
del kwargs["database"]

Check warning on line 248 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L248

Added line #L248 was not covered by tests
return self.dbapi.connect(servers=servers, **kwargs)
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
return self.dbapi.connect(host=host, port=port, **kwargs)
else:
raise ValueError(f"Unknown driver variant: {driver_name}")

Check warning on line 253 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L253

Added line #L253 was not covered by tests

return self.dbapi.connect(**kwargs)

def _get_default_schema_name(self, connection):
Expand Down Expand Up @@ -276,10 +297,12 @@
if schema is None:
schema = self._get_effective_schema_name(connection)
cursor = connection.exec_driver_sql(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
"AND table_type = 'BASE TABLE' "
"ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
self._format_query(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
"AND table_type = 'BASE TABLE' "
"ORDER BY table_name ASC, {0} ASC"
).format(self.schema_column),
(schema or self.default_schema_name,),
)
return [row[0] for row in cursor.fetchall()]
Expand All @@ -302,7 +325,7 @@
"AND column_name !~ ?".format(self.schema_column)
)
cursor = connection.exec_driver_sql(
query,
self._format_query(query),
(
table_name,
schema or self.default_schema_name,
Expand Down Expand Up @@ -342,7 +365,9 @@
rows = result.fetchone()
return set(rows[0] if rows else [])

pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name))
pk_result = engine.exec_driver_sql(
self._format_query(query), (table_name, schema or self.default_schema_name)
)
pks = result_fun(pk_result)
return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"}

Expand Down Expand Up @@ -381,6 +406,17 @@
server_version_info = self.server_version_info
return server_version_info is not None and server_version_info >= (4, 1, 0)

def _format_query(self, query):
    """
    Adapt the `?` placeholders of an internal query to the driver's paramstyle.

    The reflection queries of this dialect are written using `qmark`
    placeholders (`?`), which is what the HTTP driver expects. When using
    the PostgreSQL protocol with the `psycopg` or `asyncpg` drivers, the
    paramstyle is different, so the placeholders need to be rewritten
    before the statement is handed to `exec_driver_sql`.

    Only placeholders following the `=` and `!~` operators are rewritten,
    because these are the only forms used by the queries passed in here.

    - `pyformat` and `format`: `?` becomes `%s`.
    - `numeric_dollar`: `?` becomes `$1`, `$2`, ... in order of appearance.
    - Any other paramstyle: The query is returned unchanged.

    TODO: Review: Is it legit and sane? Are there alternatives?

    :param query: SQL statement using `qmark` placeholders.
    :return: SQL statement using placeholders of the active paramstyle.
    """
    paramstyle = self.paramstyle

    if paramstyle in ("pyformat", "format"):
        return query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")

    if paramstyle == "numeric_dollar":
        # Function-scope imports keep this block self-contained.
        import re
        from itertools import count

        # Positional placeholders are numbered from 1, left to right.
        position = count(1)
        return re.sub(
            r"(!~|=) \?",
            lambda match: f"{match.group(1)} ${next(position)}",
            query,
        )

    return query


class DateTrunc(functions.GenericFunction):
name = "date_trunc"
Expand Down
106 changes: 106 additions & 0 deletions src/sqlalchemy_cratedb/dialect_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
from sqlalchemy_postgresql_relaxed.psycopg import (
PGDialect_psycopg_relaxed,
PGDialectAsync_psycopg_relaxed,
)

from sqlalchemy_cratedb import dialect


class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
    """
    CrateDB dialect layered on top of the relaxed PostgreSQL dialect.

    The reflection entry points are bound explicitly to the implementations
    of the CrateDB `dialect` class, and a few `get_multi_*` hooks are stubbed
    out to report nothing.
    """

    inspector = Inspector

    # Polymorphic inheritance woes: The methods below need to be pinned
    # manually to the CrateDB implementations.
    # TODO: Investigate if this can be solved using metaprogramming or other techniques.

    # -- Schema- and table-level lookups.
    has_schema = dialect.has_schema
    has_table = dialect.has_table
    get_schema_names = dialect.get_schema_names
    get_table_names = dialect.get_table_names
    get_view_names = dialect.get_view_names

    # -- Per-table reflection.
    get_columns = dialect.get_columns
    get_pk_constraint = dialect.get_pk_constraint
    get_foreign_keys = dialect.get_foreign_keys
    get_indexes = dialect.get_indexes

    # -- Multi-table reflection.
    get_multi_columns = dialect.get_multi_columns
    get_multi_pk_constraint = dialect.get_multi_pk_constraint
    get_multi_foreign_keys = dialect.get_multi_foreign_keys

    # TODO: Those may want to go to dialect instead?
    def get_multi_indexes(self, *args, **kwargs):
        """Report no indexes; all arguments are ignored."""
        return []

    def get_multi_unique_constraints(self, *args, **kwargs):
        """Report no unique constraints; all arguments are ignored."""
        return []

    def get_multi_check_constraints(self, *args, **kwargs):
        """Report no check constraints; all arguments are ignored."""
        return []

    def get_multi_table_comment(self, *args, **kwargs):
        """Report no table comments; all arguments are ignored."""
        return []

Check warning on line 66 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L66

Added line #L66 was not covered by tests


class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect bound to the `psycopg` driver.
    """

    driver = "psycopg"

    @classmethod
    def import_dbapi(cls):
        """Import the `psycopg` module on demand and return it as DBAPI module."""
        import psycopg as dbapi_module

        return dbapi_module

    @classmethod
    def get_async_dialect_cls(cls, url):
        """Return the dialect class to use for asynchronous engines; `url` is not inspected."""
        return CrateDialectAsync_psycopg


class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect for the `psycopg` driver, flagged as asynchronous.
    """

    # Driver name under which this dialect variant identifies itself.
    driver = "psycopg_async"
    # Marks this dialect as asynchronous.
    is_async = True


class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect bound to the `asyncpg` driver.
    """

    driver = "asyncpg"

    # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

    # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
    #
    # The override sketched below is deliberately left disabled:
    #
    #   @classmethod
    #   def import_dbapi(cls):
    #       import asyncpg
    #
    #       return asyncpg


# Module-level aliases for the dialect classes. These are the names looked up
# by the `sqlalchemy.dialects` entry points and by `registry.register()`.
dialect_urllib3 = dialect
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_asyncpg = CrateDialect_asyncpg
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def cratedb_service():
Provide a CrateDB service instance to the test suite.
"""
db = CrateDBTestAdapter()
db.start()
db.start(ports={4200: None, 5432: None})
yield db
db.stop()
110 changes: 110 additions & 0 deletions tests/engine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.dialects import registry as dialect_registry

from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION

# Skip the whole module on SQLAlchemy versions before 2.0.
# `pytest.skip()` raises on its own, so no explicit `raise` is needed.
if SA_VERSION < SA_2_0:
    pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True)

from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

# Registering the additional dialects manually seems to be needed when running
# under tests. Apparently, manual registration is not needed under regular
# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint
# registrations in `pyproject.toml`. It is definitely weird, but c'est la vie.
dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3")
dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg")
dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg")


QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")


def test_engine_sync_vanilla(cratedb_service):
    """
    crate:// -- Connect with the vanilla HTTP-based driver, and verify
    that the first record of the reference query comes back as expected.
    """
    port4200 = cratedb_service.cratedb.get_exposed_port(4200)
    expected = {
        "mountain": "Acherkogel",
        "coordinates": [10.95667, 47.18917],
    }

    engine = sa.create_engine(f"crate://crate@localhost:{port4200}/", echo=True)
    assert isinstance(engine, sa.engine.Engine)

    with engine.connect() as connection:
        first_record = connection.execute(QUERY).mappings().fetchone()
        assert first_record == expected


def test_engine_sync_urllib3(cratedb_service):
    """
    crate+urllib3:// -- Connect while *explicitly* selecting the HTTP driver,
    and verify that the first record of the reference query comes back as expected.
    """
    port4200 = cratedb_service.cratedb.get_exposed_port(4200)
    expected = {
        "mountain": "Acherkogel",
        "coordinates": [10.95667, 47.18917],
    }

    engine = sa.create_engine(
        f"crate+urllib3://crate@localhost:{port4200}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, sa.engine.Engine)

    with engine.connect() as connection:
        first_record = connection.execute(QUERY).mappings().fetchone()
        assert first_record == expected


def test_engine_sync_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Connect using the psycopg driver (version 3), and verify
    that the first record of the reference query comes back as expected.
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    expected = {
        "mountain": "Acherkogel",
        "coordinates": "(10.95667,47.18917)",
    }

    engine = sa.create_engine(
        f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, sa.engine.Engine)

    with engine.connect() as connection:
        first_record = connection.execute(QUERY).mappings().fetchone()
        assert first_record == expected


@pytest.mark.asyncio
async def test_engine_async_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    This time, in asynchronous mode.

    The engine is disposed explicitly at the end, so that pooled connections
    are released while the event loop of this test is still running.
    """  # noqa: E501
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    engine = create_async_engine(
        f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": "(10.95667,47.18917)",
            }
    finally:
        # Without this, pooled connections would outlive the test's event loop.
        await engine.dispose()


@pytest.mark.asyncio
async def test_engine_async_asyncpg(cratedb_service):
    """
    crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
    This exclusively uses asynchronous mode.

    The engine is disposed explicitly at the end, so that pooled connections
    are released while the event loop of this test is still running.
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    from asyncpg.pgproto.types import Point

    engine = create_async_engine(
        f"crate+asyncpg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": Point(10.95667, 47.18917),
            }
    finally:
        # Without this, pooled connections would outlive the test's event loop.
        await engine.dispose()