diff --git a/tiled/_tests/test_utils.py b/tiled/_tests/test_utils.py new file mode 100644 index 000000000..f31fe86b1 --- /dev/null +++ b/tiled/_tests/test_utils.py @@ -0,0 +1,61 @@ +from ..utils import ensure_specified_sql_driver + + +def test_ensure_specified_sql_driver(): + # Postgres + # Default driver is added if missing. + assert ( + ensure_specified_sql_driver( + "postgresql://user:password@localhost:5432/database" + ) + == "postgresql+asyncpg://user:password@localhost:5432/database" + ) + # Default driver passes through if specified. + assert ( + ensure_specified_sql_driver( + "postgresql+asyncpg://user:password@localhost:5432/database" + ) + == "postgresql+asyncpg://user:password@localhost:5432/database" + ) + # Do not override user-provided. + assert ( + ensure_specified_sql_driver( + "postgresql+custom://user:password@localhost:5432/database" + ) + == "postgresql+custom://user:password@localhost:5432/database" + ) + + # SQLite + # Default driver is added if missing. + assert ( + ensure_specified_sql_driver("sqlite:////test.db") + == "sqlite+aiosqlite:////test.db" + ) + # Default driver passes through if specified. + assert ( + ensure_specified_sql_driver("sqlite+aiosqlite:////test.db") + == "sqlite+aiosqlite:////test.db" + ) + # Do not override user-provided. + assert ( + ensure_specified_sql_driver("sqlite+custom:////test.db") + == "sqlite+custom:////test.db" + ) + # Handle SQLite :memory: URIs + assert ( + ensure_specified_sql_driver("sqlite+aiosqlite://:memory:") + == "sqlite+aiosqlite://:memory:" + ) + assert ( + ensure_specified_sql_driver("sqlite://:memory:") + == "sqlite+aiosqlite://:memory:" + ) + # Handle SQLite relative URIs + assert ( + ensure_specified_sql_driver("sqlite+aiosqlite:///test.db") + == "sqlite+aiosqlite:///test.db" + ) + assert ( + ensure_specified_sql_driver("sqlite:///test.db") + == "sqlite+aiosqlite:///test.db" + ) diff --git a/tiled/utils.py b/tiled/utils.py index c9744f320..eefdc9529 100644 --- a/tiled/utils.py +++ b/tiled/utils.py @@ -739,11 +739,9 @@ def ensure_specified_sql_driver(uri: str) -> str: 'postgresql+asyncpg://...' -> 'postgresql+asynpg://...' 'postgresql+my_custom_driver://...' -> 'postgresql+my_custom_driver://...' """ - parsed_uri = urlparse(uri) - scheme = parsed_uri.scheme + scheme, rest = uri.split(":", 1) new_scheme = SCHEME_TO_SCHEME_PLUS_DRIVER.get(scheme, scheme) - updated_uri = urlunparse(parsed_uri._replace(scheme=new_scheme)) - return updated_uri + return ":".join([new_scheme, rest]) class catch_warning_msg(warnings.catch_warnings):