Skip to content

Commit

Permalink
Merge pull request #232 from bento-platform/feat/db/return-if-init
Browse files Browse the repository at this point in the history
feat(db): more parallel awareness for Postgres DB class
  • Loading branch information
davidlougheed authored Aug 30, 2024
2 parents 7f31615 + 4b39b67 commit 9ed65d0
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 33 deletions.
72 changes: 59 additions & 13 deletions bento_lib/db/pg_async.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,93 @@
import asyncio

import aiofiles
import asyncpg
import contextlib
from pathlib import Path
from typing import AsyncIterator


__all__ = [
"PgAsyncDatabaseException",
"PgAsyncDatabase",
]


class PgAsyncDatabaseException(Exception):
    """Raised on invalid PgAsyncDatabase lifecycle transitions (e.g. closing the pool while it is opening)."""


class PgAsyncDatabase:

def __init__(self, db_uri: str, schema_path: Path):
    """
    :param db_uri: PostgreSQL connection URI, passed through to asyncpg's pool factory.
    :param schema_path: Path to a SQL schema file; executed against the database when the
        pool is first initialized.
    """
    self._db_uri: str = db_uri
    self._schema_path: Path = schema_path

    # Lazily-created asyncpg connection pool; None until initialization completes.
    self._pool: asyncpg.Pool | None = None
    # In-flight open/close tasks, kept so that parallel callers can await the same
    # operation instead of starting a duplicate one.
    self._pool_init_task: asyncio.Task | None = None
    self._pool_closing_task: asyncio.Task | None = None

async def initialize(self, pool_size: int = 10):
conn: asyncpg.Connection
async def initialize(self, pool_size: int = 10) -> bool:
if self._pool_closing_task:
raise PgAsyncDatabaseException("Cannot open the pool while it is closing")

if not self._pool: # Initialize the connection pool if needed
self._pool = await asyncpg.create_pool(self._db_uri, min_size=pool_size, max_size=pool_size)
if not self._pool_init_task:
async def _init():
pool = await asyncpg.create_pool(self._db_uri, min_size=pool_size, max_size=pool_size)

# If we freshly initialized the connection pool, connect to the database and execute the schema
# script. Don't use our own self.connect() method, since that'll end up in a circular task await.
async with aiofiles.open(self._schema_path, "r") as sf:
conn: asyncpg.Connection
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(await sf.read())

self._pool = pool
self._pool_init_task = None

return True # Freshly initialized the pool + executed the schema script

# Connect to the database and execute the schema script
async with aiofiles.open(self._schema_path, "r") as sf:
async with self.connect() as conn:
async with conn.transaction():
await conn.execute(await sf.read())
self._pool_init_task = asyncio.create_task(_init())

# self._pool_init_task is now guaranteed to not be None - can be awaited
return await self._pool_init_task

return False # Pool already initialized

async def close(self) -> bool:
if self._pool_init_task:
raise PgAsyncDatabaseException("Cannot close the pool while it is opening")

async def close(self):
if self._pool:
await self._pool.close()
self._pool = None
if not self._pool_closing_task:
async def _close():
await self._pool.close()
# must come after the "await" in this function, so that we can properly re-use the task that is
# checked for IF self._pool is not None:
self._pool = None
self._pool_closing_task = None
return True # Just closed the pool

self._pool_closing_task = asyncio.create_task(_close())

# self._pool_closing_task is now guaranteed to not be None - can be awaited
return await self._pool_closing_task

return False # Pool already closed

@contextlib.asynccontextmanager
async def connect(self, existing_conn: asyncpg.Connection | None = None) -> AsyncIterator[asyncpg.Connection]:
# TODO: raise raise DatabaseError("Pool is not available") when FastAPI has lifespan dependencies
# + manage pool lifespan in lifespan fn.

# If we're currently closing, wait for closing to finish before trying to re-open
if self._pool_closing_task:
await self._pool_closing_task

if self._pool is None:
await self.initialize() # initialize if this is the first time we're using the pool
# initialize if this is the first time we're using the pool, or wait for existing initialization to finish:
await self.initialize()

if existing_conn is not None:
yield existing_conn
Expand Down
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bento-lib"
version = "12.1.1"
version = "12.2.0"
description = "A set of common utilities and helpers for Bento platform services."
authors = [
"David Lougheed <[email protected]>",
Expand Down Expand Up @@ -46,7 +46,7 @@ flake8 = "^7.0.0"
httpx = "^0.27.0"
mypy = "~1.11.0"
pytest = "^8.3.2"
pytest-asyncio = "^0.23.5"
pytest-asyncio = "^0.23.8"
pytest-cov = "^5.0.0"
pytest-django = "^4.8.0"
python-dateutil = "^2.8.2"
Expand Down
90 changes: 79 additions & 11 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import pathlib
import asyncpg
import pytest
import pytest_asyncio
from bento_lib.db.pg_async import PgAsyncDatabase
from bento_lib.db.pg_async import PgAsyncDatabaseException, PgAsyncDatabase
from typing import AsyncGenerator


Expand All @@ -18,6 +19,14 @@ async def get_test_db() -> AsyncGenerator[PgAsyncDatabase, None]:
db_fixture = pytest_asyncio.fixture(get_test_db, name="pg_async_db")


async def get_test_db_no_init() -> AsyncGenerator[PgAsyncDatabase, None]:
    # Hand the test an *un-initialized* database wrapper: no pool is created and the
    # schema is not executed — initialization is left to the test body itself.
    yield PgAsyncDatabase("postgresql://postgres@localhost:5432/postgres", TEST_SCHEMA)


db_fixture_no_init = pytest_asyncio.fixture(get_test_db_no_init, name="pg_async_db_no_init")


@pytest_asyncio.fixture
async def db_cleanup(pg_async_db: PgAsyncDatabase):
yield
Expand All @@ -27,28 +36,87 @@ async def db_cleanup(pg_async_db: PgAsyncDatabase):
await pg_async_db.close()


@pytest_asyncio.fixture
async def db_cleanup_no_init(pg_async_db_no_init: PgAsyncDatabase):
    # Teardown-only fixture: yields immediately, then after the test drops the test
    # table and closes the pool. connect() will (re-)initialize the pool if the test
    # left it closed.
    yield
    conn: asyncpg.Connection
    async with pg_async_db_no_init.connect() as conn:
        await conn.execute("DROP TABLE IF EXISTS test_table")
    await pg_async_db_no_init.close()


# noinspection PyUnusedLocal
@pytest.mark.asyncio
async def test_pg_async_db_open_close(pg_async_db: PgAsyncDatabase, db_cleanup):
await pg_async_db.close()
async def test_pg_async_db_close_auto_open(pg_async_db: PgAsyncDatabase, db_cleanup):
r = await pg_async_db.close()
assert r # did in fact close the pool
assert pg_async_db._pool is None

# duplicate request: should be idempotent
await pg_async_db.close()
r = await pg_async_db.close()
assert not r # didn't close the pool, since it was already closed.
assert pg_async_db._pool is None

# should not be able to connect
conn: asyncpg.Connection
async with pg_async_db.connect() as conn:
assert pg_async_db._pool is not None # Connection auto-initialized
async with pg_async_db.connect(existing_conn=conn) as conn2:
assert conn == conn2 # Re-using existing connection should be possible

# try re-opening
await pg_async_db.initialize()
assert pg_async_db._pool is not None
old_pool = pg_async_db._pool

@pytest.mark.asyncio
async def test_pg_async_db_open(pg_async_db_no_init: PgAsyncDatabase, db_cleanup_no_init):
# try opening
r = await pg_async_db_no_init.initialize(pool_size=1)
assert r
assert pg_async_db_no_init._pool is not None
old_pool = pg_async_db_no_init._pool

# duplicate request: should be idempotent
await pg_async_db.initialize()
assert pg_async_db._pool == old_pool # same instance
r = await pg_async_db_no_init.initialize()
assert not r # didn't actually initialize the pool; re-used the old object
assert pg_async_db_no_init._pool == old_pool # same instance


@pytest.mark.asyncio
async def test_pg_async_db_parallel_open(pg_async_db_no_init: PgAsyncDatabase, db_cleanup):
    # Two concurrent initialize() calls must share a single underlying init task,
    # so both awaiters observe True ("freshly initialized").
    # NOTE(review): this test takes the `db_cleanup` fixture rather than
    # `db_cleanup_no_init` — presumably intentional since both clean the same DB; confirm.
    results = await asyncio.gather(
        pg_async_db_no_init.initialize(pool_size=1),
        pg_async_db_no_init.initialize(pool_size=1),
    )
    assert results == [True, True]


@pytest.mark.asyncio
async def test_pg_async_db_parallel_close(pg_async_db: PgAsyncDatabase, db_cleanup):
    # Two concurrent close() calls must share a single underlying closing task,
    # so both awaiters observe True ("just closed").
    results = await asyncio.gather(pg_async_db.close(), pg_async_db.close())
    assert results == [True, True]


@pytest.mark.asyncio
async def test_pg_async_db_parallel_exc_close_while_opening(pg_async_db_no_init: PgAsyncDatabase, db_cleanup):
    # Requesting a close while an initialize() is still in flight is a hard error.
    with pytest.raises(PgAsyncDatabaseException) as exc_info:
        await asyncio.gather(pg_async_db_no_init.initialize(), pg_async_db_no_init.close())
    assert str(exc_info.value) == "Cannot close the pool while it is opening"


@pytest.mark.asyncio
async def test_pg_async_db_parallel_exc_open_while_closing(pg_async_db: PgAsyncDatabase, db_cleanup):
    # Requesting an open while a close() is still in flight is a hard error.
    with pytest.raises(PgAsyncDatabaseException) as exc_info:
        await asyncio.gather(pg_async_db.close(), pg_async_db.initialize())
    assert str(exc_info.value) == "Cannot open the pool while it is closing"


@pytest.mark.asyncio
async def test_pg_async_db_parallel_exc_close_then_connect(pg_async_db: PgAsyncDatabase, db_cleanup):
    # connect() must wait out an in-flight close, then transparently re-open the pool.
    async def _use_connection():
        async with pg_async_db.connect():
            pass

    await asyncio.gather(pg_async_db.close(), _use_connection())
    assert pg_async_db._pool is not None

0 comments on commit 9ed65d0

Please sign in to comment.