diff --git a/CHANGES.md b/CHANGES.md index 2e2da41..34f4501 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,7 @@ # Changelog - ## Unreleased +- Added `quote_relation_name` support utility function ## 2024/06/25 0.38.0 - Added/reactivated documentation as `sqlalchemy-cratedb` diff --git a/src/sqlalchemy_cratedb/support/__init__.py b/src/sqlalchemy_cratedb/support/__init__.py index d140d60..0934088 100644 --- a/src/sqlalchemy_cratedb/support/__init__.py +++ b/src/sqlalchemy_cratedb/support/__init__.py @@ -1,12 +1,13 @@ from sqlalchemy_cratedb.support.pandas import insert_bulk, table_kwargs from sqlalchemy_cratedb.support.polyfill import check_uniqueness_factory, refresh_after_dml, \ patch_autoincrement_timestamp -from sqlalchemy_cratedb.support.util import refresh_table, refresh_dirty +from sqlalchemy_cratedb.support.util import quote_relation_name, refresh_table, refresh_dirty __all__ = [ check_uniqueness_factory, insert_bulk, patch_autoincrement_timestamp, + quote_relation_name, refresh_after_dml, refresh_dirty, refresh_table, diff --git a/src/sqlalchemy_cratedb/support/util.py b/src/sqlalchemy_cratedb/support/util.py index 33cce5f..1d069f4 100644 --- a/src/sqlalchemy_cratedb/support/util.py +++ b/src/sqlalchemy_cratedb/support/util.py @@ -3,6 +3,8 @@ import sqlalchemy as sa +from sqlalchemy_cratedb.dialect import CrateDialect + if t.TYPE_CHECKING: try: from sqlalchemy.orm import DeclarativeBase @@ -10,6 +12,10 @@ pass +# An instance of the dialect used for quoting purposes. +identifier_preparer = CrateDialect().identifier_preparer + + def refresh_table(connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]): """ Invoke a `REFRESH TABLE` statement. @@ -39,3 +45,36 @@ def refresh_dirty(session, flush_context=None): dirty_classes = {entity.__class__ for entity in dirty_entities} for class_ in dirty_classes: refresh_table(session, class_) + + +def quote_relation_name(ident: str) -> str: + """ + Quote a simple or full-qualified table/relation name, when needed. + + Simple: + Full-qualified: .
+ + Happy path examples: + + foo => foo + Foo => "Foo" + "Foo" => "Foo" + foo.bar => foo.bar + foo-bar.baz_qux => "foo-bar".baz_qux + + Such input strings will not be modified: + + "foo.bar" => "foo.bar" + """ + + # If a quote exists at the beginning or the end of the input string, + # let's consider that the relation name has been quoted already. + if ident.startswith('"') or ident.endswith('"'): + return ident + + # If a dot is included, it's a full-qualified identifier like .
. + # It needs to be split, in order to apply identifier quoting properly. + parts = ident.split(".") + if len(parts) > 3: + raise ValueError(f"Invalid relation name, too many parts: {ident}") + return ".".join(map(identifier_preparer.quote, parts)) diff --git a/tests/test_support_util.py b/tests/test_support_util.py new file mode 100644 index 0000000..b75ed8d --- /dev/null +++ b/tests/test_support_util.py @@ -0,0 +1,51 @@ +import pytest + +from sqlalchemy_cratedb.support import quote_relation_name + + +def test_quote_relation_name_once(): + """ + Verify quoting a simple or full-qualified relation name. + """ + + # Table name only. + assert quote_relation_name("my_table") == "my_table" + assert quote_relation_name("my-table") == '"my-table"' + assert quote_relation_name("MyTable") == '"MyTable"' + assert quote_relation_name('"MyTable"') == '"MyTable"' + + # Schema and table name. + assert quote_relation_name("my_schema.my_table") == "my_schema.my_table" + assert quote_relation_name("my-schema.my_table") == '"my-schema".my_table' + assert quote_relation_name('"wrong-quoted-fqn.my_table"') == '"wrong-quoted-fqn.my_table"' + assert quote_relation_name('"my_schema"."my_table"') == '"my_schema"."my_table"' + + # Catalog, schema, and table name. + assert quote_relation_name("crate.doc.t01") == "crate.doc.t01" + + +def test_quote_relation_name_twice(): + """ + Verify quoting a relation name twice does not cause any harm. + """ + input_fqn = "foo-bar.baz_qux" + output_fqn = '"foo-bar".baz_qux' + assert quote_relation_name(input_fqn) == output_fqn + assert quote_relation_name(output_fqn) == output_fqn + + +def test_quote_relation_name_reserved_keywords(): + """ + Verify quoting a simple relation name that is a reserved keyword. + """ + assert quote_relation_name("table") == '"table"' + assert quote_relation_name("true") == '"true"' + assert quote_relation_name("select") == '"select"' + + +def test_quote_relation_name_with_invalid_fqn(): + """ + Verify quoting a relation name with an invalid fqn raises an error. + """ + with pytest.raises(ValueError): + quote_relation_name("too-many.my-db.my-schema.my-table")