Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Account for sql.SQL and sql.Composed Objects #177

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions pgtrigger/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,23 @@
from typing import TYPE_CHECKING, List, Union

from django.db import connections
from typing_extensions import TypeAlias

from pgtrigger import registry, utils

if utils.psycopg_maj_version == 2:
import psycopg2.extensions
elif utils.psycopg_maj_version == 3:
import psycopg.pq
import psycopg.sql as psycopg_sql
else:
raise AssertionError

if TYPE_CHECKING:
from pgtrigger import Timing

_Query: TypeAlias = "str | bytes | psycopg_sql.SQL | psycopg_sql.Composed"
_Connection: TypeAlias = "psycopg.Connection | psycopg2.extensions.connection"

# All triggers currently being ignored
_ignore = threading.local()
Expand All @@ -32,12 +36,24 @@
_schema = threading.local()


def _is_concurrent_statement(sql: str | bytes) -> bool:
def _query_to_str(query: _Query, connection: _Connection) -> str:
psycopg_3 = utils.psycopg_maj_version == 3
if isinstance(query, str):
return query
elif isinstance(query, bytes):
return query.decode()
elif psycopg_3 and isinstance(query, (psycopg_sql.SQL, psycopg_sql.Composed)):
return query.as_string(connection)
else:
raise AssertionError


def _is_concurrent_statement(sql: _Query, connection: _Connection) -> bool:
"""
True if the sql statement is concurrent and cannot be ran in a transaction
"""
sql = _query_to_str(sql, connection)
sql = sql.strip().lower() if sql else ""
sql = sql.decode() if isinstance(sql, bytes) else sql
return sql.startswith("create") and "concurrently" in sql


Expand Down Expand Up @@ -72,7 +88,7 @@ def _can_inject_variable(cursor, sql):
"""
return (
not getattr(cursor, "name", None)
and not _is_concurrent_statement(sql)
and not _is_concurrent_statement(sql, cursor.connection)
and not _is_transaction_errored(cursor)
)

Expand All @@ -92,7 +108,7 @@ def _inject_pgtrigger_ignore(execute, sql, params, many, context):
"""
if _can_inject_variable(context["cursor"], sql):
serialized_ignore = "{" + ",".join(_ignore.value) + "}"
sql = sql.decode() if isinstance(sql, bytes) else sql
sql = _query_to_str(sql, context["cursor"])
sql = f"SELECT set_config('pgtrigger.ignore', %s, true); {sql}"
params = [serialized_ignore, *(params or ())]

Expand Down
33 changes: 23 additions & 10 deletions pgtrigger/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
import pgtrigger.utils
from pgtrigger.tests import models, utils

if pgtrigger.utils.psycopg_maj_version == 3:
from psycopg.sql import SQL, Literal
else:
from unittest.mock import MagicMock

SQL = MagicMock()
Literal = MagicMock()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to avoid the parameterization failing on 3.



@pytest.mark.django_db
def test_schema():
Expand Down Expand Up @@ -228,20 +236,26 @@ def test_custom_db_table_ignore():
assert not models.CustomTableName.objects.exists()


@pytest.mark.skipif(
pgtrigger.utils.psycopg_maj_version == 3, reason="Psycopg2 preserves entire query"
)
@pytest.mark.django_db
@pytest.mark.parametrize(
"sql, params",
"sql, params, min_psycopg_version",
[
("select count(*) from auth_user where id = %s", (1,)),
("select count(*) from auth_user", ()),
(b"select count(*) from auth_user where id = %s", (1,)),
(b"select count(*) from auth_user", ()),
("select count(*) from auth_user where id = %s", (1,), 2),
("select count(*) from auth_user", (), 2),
(b"select count(*) from auth_user where id = %s", (1,), 2),
Copy link
Contributor Author

@max-muoto max-muoto Sep 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to run these on 3 as the query can still technically be bytes. See https://github.com/psycopg/psycopg/blob/master/psycopg/psycopg/abc.py#L30

(b"select count(*) from auth_user", (), 2),
(SQL("select count(*) from auth_user where id = %s"), (1,), 3),
( # Formatting creates a composed object
SQL("select count(*) from auth_user where id = {id}").format(id=1),
(),
3,
),
],
)
def test_inject_trigger_ignore(settings, mocker, sql, params):
def test_inject_trigger_ignore(settings, mocker, sql, params, min_psycopg_version):
if pgtrigger.utils.psycopg_maj_version < min_psycopg_version:
pytest.skip("Psycopg version is less than {}".format(min_psycopg_version))

settings.DEBUG = True
expected_sql_base = "SELECT set_config('pgtrigger.ignore', '{ignored_triggers}', true)"
# Order isn't deterministic, so we need to check for either order.
Expand All @@ -255,7 +269,6 @@ def test_inject_trigger_ignore(settings, mocker, sql, params):
with connection.cursor() as cursor:
cursor.execute(sql, params)
query = connection.queries[-1]

assert query["sql"].startswith(expected_sql_1) or query["sql"].startswith(
expected_sql_2
)