diff --git a/databases/core.py b/databases/core.py index 2bab6735..fa0fec03 100644 --- a/databases/core.py +++ b/databases/core.py @@ -1,4 +1,4 @@ -import asyncio +import anyio import contextlib import functools import logging @@ -203,14 +203,14 @@ class Connection: def __init__(self, backend: DatabaseBackend) -> None: self._backend = backend - self._connection_lock = asyncio.Lock() + self._connection_lock = anyio.Lock() self._connection = self._backend.connection() self._connection_counter = 0 - self._transaction_lock = asyncio.Lock() + self._transaction_lock = anyio.Lock() self._transaction_stack = [] # type: typing.List[Transaction] - self._query_lock = asyncio.Lock() + self._query_lock = anyio.Lock() async def __aenter__(self) -> "Connection": async with self._connection_lock: diff --git a/requirements.txt b/requirements.txt index 0f22a025..b476b61b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,7 @@ # Notes... # The JSONField tests require sqlalchemy 1.3+. Other cases work at lower versions. # The aiocontextvars package is only required as a backport for Python 3.6. --e . - -# Async database drivers -aiomysql -aiopg -aiosqlite -asyncpg +-e .[trio-postgresql,trio-mysql,mysql,sqlite,postgresql] # Sync database drivers for standard tooling around setup/teardown/migrations. psycopg2-binary diff --git a/setup.py b/setup.py index 4cdd0fff..4c58f268 100644 --- a/setup.py +++ b/setup.py @@ -48,8 +48,10 @@ def get_packages(package): packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, data_files=[("", ["LICENSE.md"])], - install_requires=['sqlalchemy<1.4', 'aiocontextvars;python_version<"3.7"'], + install_requires=['sqlalchemy<1.4', 'aiocontextvars;python_version<"3.7"', 'anyio~=3.2'], extras_require={ + "trio-postgresql": ["anyio[trio]", "triopg"], + "trio-mysql": ["anyio[trio]", "trio-mysql"], "postgresql": ["asyncpg"], "mysql": ["aiomysql"], "sqlite": ["aiosqlite"], diff --git a/tests/test_databases.py b/tests/test_databases.py index c7317688..333860b7 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,4 +1,4 @@ -import asyncio +import anyio import datetime import decimal import functools @@ -97,22 +97,8 @@ def create_test_database(): metadata.drop_all(engine) -def async_adapter(wrapped_func): - """ - Decorator used to run async test cases. - """ - - @functools.wraps(wrapped_func) - def run_sync(*args, **kwargs): - loop = asyncio.new_event_loop() - task = wrapped_func(*args, **kwargs) - return loop.run_until_complete(task) - - return run_sync - - @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_queries(database_url): """ Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and @@ -188,7 +174,7 @@ async def test_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_queries_raw(database_url): """ Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and @@ -249,7 +235,7 @@ async def test_queries_raw(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_ddl_queries(database_url): """ Test that the built-in DDL elements such as `DropTable()`, @@ -267,7 +253,7 @@ async def test_ddl_queries(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_results_support_mapping_interface(database_url): """ Casting results to a dict should work, since the interface defines them @@ -294,7 +280,7 @@ async def test_results_support_mapping_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_results_support_column_reference(database_url): """ Casting results to a dict should work, since the interface defines them @@ -325,7 +311,7 @@ async def test_results_support_column_reference(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_result_values_allow_duplicate_names(database_url): """ The values of a result should respect when two columns are selected @@ -341,7 +327,7 @@ async def test_result_values_allow_duplicate_names(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_fetch_one_returning_no_results(database_url): """ fetch_one should return `None` when no results match. @@ -355,7 +341,7 @@ async def test_fetch_one_returning_no_results(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_execute_return_val(database_url): """ Test using return value from `execute()` to get an inserted primary key. @@ -381,7 +367,7 @@ async def test_execute_return_val(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_rollback_isolation(database_url): """ Ensure that `database.transaction(force_rollback=True)` provides strict isolation. @@ -400,7 +386,7 @@ async def test_rollback_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_rollback_isolation_with_contextmanager(database_url): """ Ensure that `database.force_rollback()` provides strict isolation. @@ -422,7 +408,7 @@ async def test_rollback_isolation_with_contextmanager(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_commit(database_url): """ Ensure that transaction commit is supported. @@ -439,7 +425,7 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_commit_serializable(database_url): """ Ensure that serializable transaction commit via extra parameters is supported. @@ -480,7 +466,7 @@ def delete_independently(): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_rollback(database_url): """ Ensure that transaction rollback is supported. @@ -502,7 +488,7 @@ async def test_transaction_rollback(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_commit_low_level(database_url): """ Ensure that an explicit `await transaction.commit()` is supported. @@ -525,7 +511,7 @@ async def test_transaction_commit_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_rollback_low_level(database_url): """ Ensure that an explicit `await transaction.rollback()` is supported. @@ -549,7 +535,7 @@ async def test_transaction_rollback_low_level(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_transaction_decorator(database_url): """ Ensure that @database.transaction() is supported. @@ -579,7 +565,7 @@ async def insert_data(raise_exception): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_datetime_field(database_url): """ Test DataTime columns, to ensure records are coerced to/from proper Python types. @@ -603,7 +589,7 @@ async def test_datetime_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_decimal_field(database_url): """ Test Decimal (NUMERIC) columns, to ensure records are coerced to/from proper Python types. @@ -630,7 +616,7 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_json_field(database_url): """ Test JSON columns, to ensure correct cross-database support. @@ -652,7 +638,7 @@ async def test_json_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_custom_field(database_url): """ Test custom column types. @@ -677,7 +663,7 @@ async def test_custom_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_connections_isolation(database_url): """ Ensure that changes are visible between different connections. @@ -699,7 +685,7 @@ async def test_connections_isolation(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_commit_on_root_transaction(database_url): """ Because our tests are generally wrapped in rollback-islation, they @@ -723,7 +709,7 @@ async def test_commit_on_root_transaction(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_connect_and_disconnect(database_url): """ Test explicit connect() and disconnect(). @@ -738,7 +724,7 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_connection_context(database_url): """ Test connection contexts are task-local. @@ -751,7 +737,7 @@ async def test_connection_context(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None - test_complete = asyncio.Event() + test_complete = anyio.Event() async def get_connection_1(): nonlocal connection_1 @@ -767,19 +753,17 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) - while connection_1 is None or connection_2 is None: - await asyncio.sleep(0.000001) - assert connection_1 is not connection_2 - test_complete.set() - await task_1 - await task_2 + async with anyio.create_task_group() as tg: + tg.start_soon(get_connection_1) + tg.start_soon(get_connection_2) + while connection_1 is None or connection_2 is None: + await anyio.sleep(0.000001) + assert connection_1 is not connection_2 + test_complete.set() @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_connection_context_with_raw_connection(database_url): """ Test connection contexts with respect to the raw connection. @@ -792,7 +776,7 @@ async def test_connection_context_with_raw_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_queries_with_expose_backend_connection(database_url): """ Replication of `execute()`, `execute_many()`, `fetch_all()``, and @@ -871,7 +855,7 @@ async def test_queries_with_expose_backend_connection(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_database_url_interface(database_url): """ Test that Database instances expose a `.url` attribute. @@ -881,8 +865,14 @@ async def test_database_url_interface(database_url): assert database.url == database_url +async def _wait_all(*async_fns): + async with anyio.create_task_group() as tg: + for async_fn in async_fns: + tg.start_soon(async_fn) + + @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_concurrent_access_on_single_connection(database_url): database_url = DatabaseURL(database_url) if database_url.dialect != "postgresql": @@ -893,11 +883,15 @@ async def test_concurrent_access_on_single_connection(database_url): async def db_lookup(): await database.fetch_one("SELECT pg_sleep(1)") - await asyncio.gather(db_lookup(), db_lookup()) + await _wait_all(db_lookup, db_lookup) @pytest.mark.parametrize("database_url", DATABASE_URLS) -def test_global_connection_is_initialized_lazily(database_url): +def test_global_connection_is_initialized_lazily( + database_url, + anyio_backend_name, + anyio_backend_options, +): """ Ensure that global connection is initialized at latest possible time so it's _query_lock will belong to same event loop that async_adapter has @@ -912,20 +906,23 @@ def test_global_connection_is_initialized_lazily(database_url): database = Database(database_url, force_rollback=True) - @async_adapter async def run_database_queries(): async with database: async def db_lookup(): await database.fetch_one("SELECT pg_sleep(1)") - await asyncio.gather(db_lookup(), db_lookup()) + await _wait_all(db_lookup, db_lookup) - run_database_queries() + anyio.run( + run_database_queries, + backend=anyio_backend_name, + **anyio_backend_options, + ) @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_iterate_outside_transaction_with_values(database_url): """ Ensure `iterate()` works even without a transaction on all drivers. @@ -949,7 +946,7 @@ async def test_iterate_outside_transaction_with_values(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter +@pytest.mark.anyio async def test_iterate_outside_transaction_with_temp_table(database_url): """ Same as test_iterate_outside_transaction_with_values but uses a @@ -978,7 +975,7 @@ async def test_iterate_outside_transaction_with_temp_table(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"]) -@async_adapter +@pytest.mark.anyio async def test_column_names(database_url, select_query): """ Test that column names are exposed correctly through `.keys()` on each row.