From ffab504da086720912bf706923fe611b9ad2008a Mon Sep 17 00:00:00 2001 From: Jacobe2169 Date: Fri, 8 Dec 2023 15:13:57 +0100 Subject: [PATCH] fix test + change decorator --- src/utils_flask_sqla/models.py | 19 +++++++++++-------- src/utils_flask_sqla/tests/test_qfilter.py | 15 +++++++++++++-- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/utils_flask_sqla/models.py b/src/utils_flask_sqla/models.py index 6f0faa8..936c9a6 100644 --- a/src/utils_flask_sqla/models.py +++ b/src/utils_flask_sqla/models.py @@ -2,6 +2,12 @@ from flask_sqlalchemy.model import DefaultMeta from sqlalchemy.sql import select, Select +AUTHORIZED_WHERECLAUSE_TYPES = [bool, BooleanClauseList, BinaryExpression] + + +def is_whereclause_compatible(object): + return any([isinstance(object, type_) for type_ in AUTHORIZED_WHERECLAUSE_TYPES]) + def qfilter(*args_dec, **kwargs_dec): """ @@ -22,7 +28,7 @@ def filter_by_params(cls,**kwargs): filters = [] if "id_station" in kwargs: filters.append(Station.id_station == kwargs["id_station"]) - return query.whereclause + return filters # If you wish the method to return a query @qfilter(query=True) def filter_by_paramsQ(cls,**kwargs): @@ -60,8 +66,8 @@ def filter_by_paramsQ(cls,**kwargs): return _qfilter(*args_dec, **kwargs_dec) -def _qfilter(*args_dec, **kwargs_dec): - is_query = kwargs_dec.get("query", False) +def _qfilter(query=False): + is_query = query def _qfilter_decorator(method): def _(*args, **kwargs): @@ -83,13 +89,10 @@ def _(*args, **kwargs): if is_query and not isinstance(result, Select): raise ValueError("Your method must return a SQLAlchemy Select object ") - authorise_whereclause_type = [bool, BooleanClauseList, BinaryExpression] - if not is_query and not any( - [isinstance(result, type_) for type_ in authorise_whereclause_type] - ): + if not is_query and not is_whereclause_compatible(result): raise ValueError( "Your method must return an object in the following types: {} ".format( - ", ".join(map(lambda cls: cls.__name__, authorise_whereclause_type)) + ", ".join(map(lambda cls: cls.__name__, AUTHORIZED_WHERECLAUSE_TYPES)) ) ) # if filter is wanted as where clause diff --git a/src/utils_flask_sqla/tests/test_qfilter.py b/src/utils_flask_sqla/tests/test_qfilter.py index 97f375c..d2c9612 100644 --- a/src/utils_flask_sqla/tests/test_qfilter.py +++ b/src/utils_flask_sqla/tests/test_qfilter.py @@ -1,6 +1,6 @@ import pytest from flask import Flask -from sqlalchemy import func +from sqlalchemy import func, and_ from flask_sqlalchemy import SQLAlchemy @@ -26,6 +26,10 @@ def where_pk_query(cls, pk, **kwargs): query = kwargs["query"] return query.where(BarModel.pk == pk) + @qfilter + def where_pk_list(cls, pk, **kwargs): + return and_(*[BarModel.pk == pk]) + @pytest.fixture(scope="session") def app(): @@ -54,7 +58,7 @@ def bar(app): class TestQfilter: - def test_qfilter_returns_whereclause(self, bar): + def test_qfilter(self, bar): assert db.session.scalars(BarModel.where_pk_query(bar.pk)).one_or_none() is bar assert ( db.session.scalars(db.select(BarModel).where(BarModel.where_pk(bar.pk))).one_or_none() @@ -68,3 +72,10 @@ def test_qfilter_returns_whereclause(self, bar): ).one_or_none() is not bar ) + + assert ( + db.session.scalars( + db.select(BarModel).where(BarModel.where_pk_list(bar.pk)) + ).one_or_none() + is bar + )