Skip to content

Commit

Permalink
Fix database connction pool reference handling (#67)
Browse files Browse the repository at this point in the history
* reverted to connection_map usage, removing connection map timeout as race conditions can cause unexpected connection termination. connection_map remains vital for allowing connections within async tasks. Tasks are not able to share connection pools, and thus need the ability to generate their own pool for use

* adding optional echo argument to allow easier printing of sql commands issued by pydbanitc

---------
  • Loading branch information
codemation authored Sep 21, 2023
1 parent 46ca1a0 commit 8fee22e
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions pydbantic/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
use_alembic: bool = False,
echo: bool = False,
):
self.connection_map = {}
self.DB_URL = db_url
self.tables = []
self.cache_enabled = cache_enabled
Expand All @@ -56,10 +57,8 @@ def __init__(
else {},
echo=echo,
)

self.use_alembic = use_alembic
self.__metadata__: BaseMeta = BaseMeta()
self._connection = self.db_connection()

self.DEFAULT_TRANSLATIONS = DEFAULT_TRANSLATIONS

Expand Down Expand Up @@ -613,18 +612,17 @@ async def execute(self, query, values: dict = {}):

self.log.debug(f"database query: {query} - values {values}")

async with self as database:
async with database.connection() as conn:
async with self as conn:
async with conn.connection():
return await conn.execute(query=query, values=values)

async def execute_many(self, query, values):
"""execute bulk insert"""
if self.cache_enabled:
await self.cache.invalidate(query.table.name)

async with self as database:
async with database.connection() as conn:
print(f"running {query} insertion of {len(values)} values")
async with self as conn:
async with conn.connection():
return await conn.execute_many(query=query, values=values)

async def fetch(self, query, table_name, values=None):
Expand All @@ -640,8 +638,8 @@ async def fetch(self, query, table_name, values=None):

self.log.debug(f"running query: {query} with {values}")

async with self as database:
async with database.connection() as conn:
async with self as conn:
async with conn.connection():
row = await conn.fetch_all(query=query)

if self.cache_enabled and row:
Expand All @@ -660,6 +658,7 @@ def create(
debug: bool = False,
testing: bool = False,
use_alembic: bool = False,
echo: bool = False,
):

cache_config = {"cache_enabled": cache_enabled}
Expand All @@ -673,6 +672,7 @@ def create(
debug=debug,
testing=testing,
use_alembic=use_alembic,
echo=echo,
**cache_config,
)

Expand All @@ -693,24 +693,33 @@ async def _migrate(self):
return self

async def db_connection(self):
pool_config = {"min_size": 5, "max_size": 20}
pool_config = {"min_size": 3, "max_size": 5}
conn_factory = (
{"factory": SQLiteConnection}
if "sqlite" in self.DB_URL.lower()
else pool_config
)
database = _Database(self.DB_URL, **conn_factory)
await database.connect()
while True:
status = yield database
if status == "finished":
self.log.debug(f"db_connection - closed")
break

yield database.disconnect()
async with _Database(self.DB_URL, **conn_factory) as connection:
while True:
status = yield connection
if status == "finished":
self.log.debug(f"db_connection - closed")
break

async def add_db_pool(self):
conn_id = str(uuid.uuid4())
db_connection = self.db_connection()
self.connection_map[conn_id] = db_connection
return await db_connection.asend(None)

async def __aenter__(self):
return await self._connection.asend(None)
for conn_id in self.connection_map:
if self.connection_map[conn_id].ag_running:
continue
return await self.connection_map[conn_id].asend(None)

return await self.add_db_pool()

async def __aexit__(self, exc_type, exc, tb):
pass

0 comments on commit 8fee22e

Please sign in to comment.