From 6b7d19ef95acc1ec245eb7e6fb1fcfd755f266dd Mon Sep 17 00:00:00 2001 From: Joshua Jamison Date: Tue, 4 Apr 2023 00:22:31 +0200 Subject: [PATCH] Operator database connection improvements (#53) * added not-equal operator to DataBaseModelAttributes. Corrected insert_many response on exception * corrected connection cleanup within __aexit__ * added tests for not equal operator * added tests for not equal operator --------- Co-authored-by: Joshua (codemation) --- pydbantic/core.py | 10 +++++++++- pydbantic/database.py | 2 +- tests/test_querying.py | 7 ++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pydbantic/core.py b/pydbantic/core.py index aaf54f2..f311308 100644 --- a/pydbantic/core.py +++ b/pydbantic/core.py @@ -379,6 +379,14 @@ def __eq__(self, value) -> DataBaseModelCondition: (values,), ) + def __ne__(self, value) -> DataBaseModelCondition: + values = self.process_value(value) + return DataBaseModelCondition( + f"{self.name} != {values}", + self.column != self.process_value(value), + (values,), + ) + def inside(self, choices: List[Any]) -> DataBaseModelCondition: choices = [self.process_value(value) for value in choices] return DataBaseModelCondition( @@ -1936,7 +1944,7 @@ async def insert_many(cls: Type[T], rows: List[T]) -> Optional[int]: except Exception: database.log.exception(f"chain link insertion error") - return result + return None async def _insert( self, return_links=False diff --git a/pydbantic/database.py b/pydbantic/database.py index a4e3f2e..109d365 100644 --- a/pydbantic/database.py +++ b/pydbantic/database.py @@ -700,7 +700,7 @@ async def __aexit__(self, exc_type, exc, tb): if not self.connection_map[conn_id]["conn"].ag_running: if time.time() - self.connection_map[conn_id]["last"] > 120: try: - await self.connection_map[conn_id].asend("finished") + await self.connection_map[conn_id]["conn"].asend("finished") except StopAsyncIteration: pass del self.connection_map[conn_id] diff --git a/tests/test_querying.py b/tests/test_querying.py index ca517c3..bbe0db8 100644 --- a/tests/test_querying.py +++ b/tests/test_querying.py @@ -24,7 +24,7 @@ async def test_querying(loaded_database_and_model_with_cache): assert emp_with_salary[0].salary == 40000 manager_position = emp_with_salary[0].position - # breakpoint() + # filter on manager positions managers = await Employee.filter( position=manager_position[0], @@ -56,3 +56,8 @@ async def test_querying(loaded_database_and_model_with_cache): ranged_salary = await Employee.filter(Employee.salary.inside([40000, 10000, 30000])) assert len(ranged_salary) == 1 + + # filter using not equal operator + + employees = await Employee.filter(Employee.employee_id != managers[0].employee_id) + assert len(employees) == 19