diff --git a/fai-rag-app/fai-backend/fai_backend/repository/interface.py b/fai-rag-app/fai-backend/fai_backend/repository/interface.py index 30679076..10b273b0 100644 --- a/fai-rag-app/fai-backend/fai_backend/repository/interface.py +++ b/fai-rag-app/fai-backend/fai_backend/repository/interface.py @@ -1,5 +1,7 @@ from typing import Protocol, TypeVar +from fai_backend.repository.query.component import QueryComponent + T = TypeVar('T') @@ -7,7 +9,7 @@ class IAsyncRepo(Protocol[T]): async def get(self, item_id: str) -> T | None: raise NotImplementedError('get not implemented') - async def list(self) -> list[T]: + async def list(self, query: QueryComponent = None, sort_by: str = None, sort_order: str = 'asc') -> list[T]: raise NotImplementedError('list not implemented') async def create(self, item: T) -> T | None: diff --git a/fai-rag-app/fai-backend/fai_backend/repository/mongodb.py b/fai-rag-app/fai-backend/fai_backend/repository/mongodb.py index e1177571..84ea10bf 100644 --- a/fai-rag-app/fai-backend/fai_backend/repository/mongodb.py +++ b/fai-rag-app/fai-backend/fai_backend/repository/mongodb.py @@ -1,10 +1,13 @@ from typing import Generic, TypeVar -from beanie import Document, PydanticObjectId +from beanie import Document, PydanticObjectId, SortDirection from bson.errors import InvalidId from pydantic import BaseModel from fai_backend.repository.interface import IAsyncRepo +from fai_backend.repository.query.component import ( + QueryComponent, +) T = TypeVar('T', bound=BaseModel) T_DB = TypeVar('T_DB', bound=Document) @@ -18,12 +21,27 @@ def __init__(self, model: type[T], odm_model: type[T_DB]): self.model = model self.odm_model = odm_model - async def list(self) -> list[T]: - return await self.odm_model.all().to_list() + async def list( + self, + query: QueryComponent = None, + sort_by: str = None, + sort_order: str = 'asc' + ) -> list[T]: + def find_query(q: QueryComponent = None) -> dict: + return adapt_query_component(q).to_mongo_query() if query else {} + + db_query = self.odm_model.find(find_query(query)) + + if sort_by: + direction = SortDirection.ASCENDING if sort_order == 'asc' else SortDirection.DESCENDING + db_query = db_query.sort((sort_by, direction)) + return [self.model.model_validate(doc) for doc in await db_query.to_list()] async def create(self, item: T) -> T: - item.id = PydanticObjectId() - return self.model.model_validate(await self.odm_model.model_validate(item.model_dump()).create()) + item = item.model_dump() + item['id'] = PydanticObjectId() + item_in_db = await self.odm_model.model_validate(item).create() + return self.model.model_validate(item_in_db) async def get(self, item_id: str) -> T | None: try: diff --git a/fai-rag-app/fai-backend/fai_backend/repository/query/component.py b/fai-rag-app/fai-backend/fai_backend/repository/query/component.py new file mode 100644 index 00000000..ff1ade22 --- /dev/null +++ b/fai-rag-app/fai-backend/fai_backend/repository/query/component.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass +from typing import Literal, Union + +Path = str + + +@dataclass +class AttributeAssignment: + path: Path + value: any + + +@dataclass +class AttributeComparison: + path: Path + operator: Literal['<'] | Literal['<='] | Literal['=='] | Literal['!='] | Literal['>'] | Literal['>='] + value: any + + +@dataclass +class LogicalExpression: + operator: Literal['AND'] | Literal['OR'] = 'AND' + components: list['QueryComponent'] = None + + +QueryComponent = Union[AttributeAssignment, AttributeComparison, LogicalExpression] # noqa: UP007 + +if __name__ == '__main__': + + def evaluate_query(component: QueryComponent): + if isinstance(component, LogicalExpression): + print(f'Evaluate {component.operator} of:') + for sub_component in component.components: + evaluate_query(sub_component) # Recursively handle sub-components + elif isinstance(component, AttributeAssignment): + print(f'Set {component.path} to {component.value}') + elif isinstance(component, AttributeComparison): + print(f'Compare {component.path} {component.operator} with {component.value}') + else: + raise ValueError('Unsupported query component') + + + print('Evaluating query component:') + print('--------------------------') + evaluate_query(LogicalExpression('AND', [ + AttributeComparison('age', '>=', 20), + AttributeComparison('age', '<=', 30) + ])) diff --git a/fai-rag-app/fai-backend/fai_backend/repository/query/mongo.py b/fai-rag-app/fai-backend/fai_backend/repository/query/mongo.py new file mode 100644 index 00000000..0650b4d3 --- /dev/null +++ b/fai-rag-app/fai-backend/fai_backend/repository/query/mongo.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass + +from fai_backend.repository.query import AttributeAssignment, AttributeComparison, LogicalExpression, QueryComponent + + +class QueryAdapter: + def to_mongo_query(self): + raise NotImplementedError('Must implement to_mongo_query') + + +@dataclass +class AttributeAssignmentAdapter(QueryAdapter): + component: AttributeAssignment + + def to_mongo_query(self): + # MongoDB uses $set for updates; for a find query, simple equality is assumed + return {self.component.path: self.component.value} + + +@dataclass +class AttributeComparisonAdapter(QueryAdapter): + component: AttributeComparison + + def to_mongo_query(self): + # MongoDB specific operator map + operator_map = { + '<': '$lt', + '<=': '$lte', + '==': '$eq', + '!=': '$ne', + '>': '$gt', + '>=': '$gte' + } + mongo_operator = operator_map[self.component.operator] + return {self.component.path: {mongo_operator: self.component.value}} + + +@dataclass +class LogicalExpressionAdapter(QueryAdapter): + component: LogicalExpression + + def to_mongo_query(self): + logical_operator_map = { + 'AND': '$and', + 'OR': '$or' + } + return { + logical_operator_map[self.component.operator]: [adapt_query_component(sub_component).to_mongo_query() for + sub_component in self.component.components]} + + +def adapt_query_component(component: QueryComponent) -> QueryAdapter: + if isinstance(component, AttributeAssignment): + return AttributeAssignmentAdapter(component) + elif isinstance(component, AttributeComparison): + return AttributeComparisonAdapter(component) + elif isinstance(component, LogicalExpression): + return LogicalExpressionAdapter(component) + else: + raise ValueError('Unsupported query component type') diff --git a/fai-rag-app/fai-backend/tests/repository/test_mongodb.py b/fai-rag-app/fai-backend/tests/repository/test_mongodb.py index 149c7299..738442a7 100644 --- a/fai-rag-app/fai-backend/tests/repository/test_mongodb.py +++ b/fai-rag-app/fai-backend/tests/repository/test_mongodb.py @@ -3,33 +3,37 @@ from beanie import Document, init_beanie from bson import ObjectId from mongomock_motor import AsyncMongoMockClient +from pydantic import BaseModel from fai_backend.repository.mongodb import MongoDBRepo +from fai_backend.repository.query.component import AttributeAssignment, AttributeComparison, LogicalExpression -class SampleDocument(Document): +class Employee(BaseModel): + id: str = None name: str age: int + perks: list[str] = [] + +class EmployeeDocument(Document, Employee): class Settings: use_state_management = True @pytest_asyncio.fixture async def mongo_repo(): - client = AsyncMongoMockClient() - db = client.test_db - await init_beanie(database=db, document_models=[SampleDocument]) - yield MongoDBRepo(SampleDocument, SampleDocument) - await SampleDocument.get_motor_collection().drop() + await init_beanie(database=(AsyncMongoMockClient()).test_db, document_models=[EmployeeDocument]) + yield MongoDBRepo[Employee, EmployeeDocument](Employee, EmployeeDocument) + await EmployeeDocument.get_motor_collection().drop() @pytest.mark.asyncio async def test_list(mongo_repo): - await mongo_repo.create(SampleDocument(name='Alice', age=30)) - await mongo_repo.create(SampleDocument(name='Bob', age=25)) + await mongo_repo.create(Employee(name='Alice', age=30)) + await mongo_repo.create(Employee(name='Bob', age=25)) - documents = await mongo_repo.list() + documents = await mongo_repo.list(None) assert len(documents) == 2 assert documents[0].name == 'Alice' @@ -38,8 +42,7 @@ async def test_list(mongo_repo): @pytest.mark.asyncio async def test_create(mongo_repo): - created_document = await mongo_repo.create(SampleDocument(name='Charlie', age=40)) - + created_document = await mongo_repo.create(Employee(name='Charlie', age=40)) assert created_document.name == 'Charlie' assert created_document.age == 40 assert created_document.id is not None @@ -47,7 +50,7 @@ async def test_create(mongo_repo): @pytest.mark.asyncio async def test_get_existing_document(mongo_repo): - sample_document = await mongo_repo.create(SampleDocument(name='Dave', age=35)) + sample_document = await mongo_repo.create(Employee(name='Dave', age=35)) retrieved_document = await mongo_repo.get(str(sample_document.id)) @@ -65,17 +68,17 @@ async def test_get_non_existing_document(mongo_repo): @pytest.mark.asyncio async def test_update_existing_document(mongo_repo): - created = await mongo_repo.create(SampleDocument(name='Eve', age=45)) - document = await mongo_repo.get(str(created.id)) - saved_id = str(document.id) + created = await mongo_repo.create(Employee(name='Eve', age=45)) + document = await mongo_repo.get(created.id) + saved_id = document.id updated_document = await mongo_repo.update( - str(document.id), {'name': 'Eve Updated', 'age': 50} + document.id, {'name': 'Eve Updated', 'age': 50} ) assert updated_document is not None assert updated_document.name == 'Eve Updated' assert updated_document.age == 50 - assert str(updated_document.id) == saved_id + assert updated_document.id == saved_id @pytest.mark.asyncio @@ -89,14 +92,14 @@ async def test_update_non_existing_document(mongo_repo): @pytest.mark.asyncio async def test_delete_existing_document(mongo_repo): - sample_document = await mongo_repo.create(SampleDocument(name='Frank', age=55)) + sample_document = await mongo_repo.create(Employee(name='Frank', age=55)) deleted_document = await mongo_repo.delete(str(sample_document.id)) assert deleted_document is not None assert deleted_document.name == 'Frank' assert deleted_document.age == 55 - assert await SampleDocument.get(sample_document.id) is None + assert await mongo_repo.get(sample_document.id) is None @pytest.mark.asyncio @@ -104,3 +107,64 @@ async def test_delete_non_existing_document(mongo_repo): deleted_document = await mongo_repo.delete(str(ObjectId())) assert deleted_document is None + + +@pytest.mark.asyncio +async def test_sort_by_age_descending(mongo_repo): + await mongo_repo.create(Employee(name='Bob', age=25)) + await mongo_repo.create(Employee(name='Alice', age=30)) + await mongo_repo.create(Employee(name='James', age=40)) + + documents = await mongo_repo.list(None, sort_by='age', sort_order='desc') + + assert len(documents) == 3 + assert documents[0].name == 'James' + assert documents[1].name == 'Alice' + assert documents[2].name == 'Bob' + + +@pytest.mark.asyncio +async def test_sort_by_age_ascending(mongo_repo): + await mongo_repo.create(Employee(name='Alice', age=30)) + await mongo_repo.create(Employee(name='Bob', age=25)) + await mongo_repo.create(Employee(name='Liz', age=16)) + + documents = await mongo_repo.list(None, sort_by='age', sort_order='asc') + + assert len(documents) == 3 + assert documents[0].name == 'Liz' + assert documents[1].name == 'Bob' + assert documents[2].name == 'Alice' + + +@pytest.mark.asyncio +async def test_list_by_age_in_range(mongo_repo): + await mongo_repo.create(Employee(name='Carl', age=40)) + await mongo_repo.create(Employee(name='Alice', age=29)) + await mongo_repo.create(Employee(name='Bob', age=25)) + await mongo_repo.create(Employee(name='Liz', age=16)) + + documents = await mongo_repo.list(LogicalExpression('AND', [ + AttributeComparison('age', '>=', 20), + AttributeComparison('age', '<=', 30) + ]), sort_by='age', sort_order='desc') + + assert len(documents) == 2 + assert documents[0].name == 'Alice' + assert documents[1].name == 'Bob' + + +@pytest.mark.asyncio +async def test_query_with_attribute_assignment_expression(mongo_repo): + await mongo_repo.create(Employee(name='Alice', age=30, perks=['Health Insurance'])) + await mongo_repo.create(Employee(name='Bob', age=25, perks=['Dental Insurance'])) + await mongo_repo.create(Employee(name='Charlie', age=35, perks=['Health Insurance', 'Dental Insurance'])) + await mongo_repo.create(Employee(name='David', age=40, perks=['Vision Insurance'])) + documents = await mongo_repo.list( + AttributeAssignment('perks', 'Health Insurance'), + sort_by='age', sort_order='desc' + ) + + assert len(documents) == 2 + assert documents[0].name == 'Charlie' + assert documents[1].name == 'Alice'