Skip to content

Commit

Permalink
fix test + change decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquesfize committed Dec 8, 2023
1 parent 7637cf2 commit ffab504
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
19 changes: 11 additions & 8 deletions src/utils_flask_sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from flask_sqlalchemy.model import DefaultMeta
from sqlalchemy.sql import select, Select

AUTHORIZED_WHERECLAUSE_TYPES = [bool, BooleanClauseList, BinaryExpression]


def is_whereclause_compatible(object):
return any([isinstance(object, type_) for type_ in AUTHORIZED_WHERECLAUSE_TYPES])


def qfilter(*args_dec, **kwargs_dec):
"""
Expand All @@ -22,7 +28,7 @@ def filter_by_params(cls,**kwargs):
filters = []
if "id_station" in kwargs:
filters.append(Station.id_station == kwargs["id_station"])
return query.whereclause
return filters
# If you wish the method to return a query
@qfilter(query=True)
def filter_by_paramsQ(cls,**kwargs):
Expand Down Expand Up @@ -60,8 +66,8 @@ def filter_by_paramsQ(cls,**kwargs):
return _qfilter(*args_dec, **kwargs_dec)


def _qfilter(*args_dec, **kwargs_dec):
is_query = kwargs_dec.get("query", False)
def _qfilter(query=False):
is_query = query

def _qfilter_decorator(method):
def _(*args, **kwargs):
Expand All @@ -83,13 +89,10 @@ def _(*args, **kwargs):
if is_query and not isinstance(result, Select):
raise ValueError("Your method must return a SQLAlchemy Select object ")

authorise_whereclause_type = [bool, BooleanClauseList, BinaryExpression]
if not is_query and not any(
[isinstance(result, type_) for type_ in authorise_whereclause_type]
):
if not is_query and not is_whereclause_compatible(result):
raise ValueError(
"Your method must return an object in the following types: {} ".format(
", ".join(map(lambda cls: cls.__name__, authorise_whereclause_type))
", ".join(map(lambda cls: cls.__name__, AUTHORIZED_WHERECLAUSE_TYPES))
)
)
# if filter is wanted as where clause
Expand Down
15 changes: 13 additions & 2 deletions src/utils_flask_sqla/tests/test_qfilter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from flask import Flask
from sqlalchemy import func
from sqlalchemy import func, and_

from flask_sqlalchemy import SQLAlchemy

Expand All @@ -26,6 +26,10 @@ def where_pk_query(cls, pk, **kwargs):
query = kwargs["query"]
return query.where(BarModel.pk == pk)

@qfilter
def where_pk_list(cls, pk, **kwargs):
return and_(*[BarModel.pk == pk])


@pytest.fixture(scope="session")
def app():
Expand Down Expand Up @@ -54,7 +58,7 @@ def bar(app):


class TestQfilter:
def test_qfilter_returns_whereclause(self, bar):
def test_qfilter(self, bar):
assert db.session.scalars(BarModel.where_pk_query(bar.pk)).one_or_none() is bar
assert (
db.session.scalars(db.select(BarModel).where(BarModel.where_pk(bar.pk))).one_or_none()
Expand All @@ -68,3 +72,10 @@ def test_qfilter_returns_whereclause(self, bar):
).one_or_none()
is not bar
)

assert (
db.session.scalars(
db.select(BarModel).where(BarModel.where_pk_list(bar.pk))
).one_or_none()
is bar
)

0 comments on commit ffab504

Please sign in to comment.