Skip to content

Commit

Permalink
add: query component, mongo adapter & update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nRamstedt committed Jun 7, 2024
1 parent f4c51e5 commit ab9b8fa
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 25 deletions.
4 changes: 3 additions & 1 deletion fai-rag-app/fai-backend/fai_backend/repository/interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Protocol, TypeVar

from fai_backend.repository.query.component import QueryComponent

T = TypeVar('T')


class IAsyncRepo(Protocol[T]):
async def get(self, item_id: str) -> T | None:
    """Look up a single item by its identifier.

    This is the protocol-level stub: it performs no lookup and always
    raises, so concrete repositories are expected to override it.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('get not implemented')

async def list(
        self,
        query: QueryComponent | None = None,
        sort_by: str | None = None,
        sort_order: str = 'asc',
) -> list[T]:
    """List items, optionally filtered by ``query`` and sorted.

    This is the protocol-level stub: it always raises, so concrete
    repositories are expected to override it.

    Args:
        query: Filter expression; ``None`` means no filter.
        sort_by: Name of the attribute to sort on; ``None`` means unsorted.
        sort_order: Sort direction, ``'asc'`` by default.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('list not implemented')

async def create(self, item: T) -> T | None:
Expand Down
28 changes: 23 additions & 5 deletions fai-rag-app/fai-backend/fai_backend/repository/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Generic, TypeVar

from beanie import Document, PydanticObjectId
from beanie import Document, PydanticObjectId, SortDirection
from bson.errors import InvalidId
from pydantic import BaseModel

from fai_backend.repository.interface import IAsyncRepo
from fai_backend.repository.query.component import (
QueryComponent,
)

T = TypeVar('T', bound=BaseModel)
T_DB = TypeVar('T_DB', bound=Document)
Expand All @@ -18,12 +21,27 @@ def __init__(self, model: type[T], odm_model: type[T_DB]):
self.model = model
self.odm_model = odm_model

async def list(
        self,
        query: QueryComponent | None = None,
        sort_by: str | None = None,
        sort_order: str = 'asc',
) -> list[T]:
    """Return stored items, optionally filtered and sorted.

    Args:
        query: Filter expression translated to a MongoDB filter document;
            ``None`` selects every document.
        sort_by: Field to sort on. Falsy values leave the result unsorted.
        sort_order: ``'asc'`` sorts ascending; any other value sorts
            descending.

    Returns:
        The matching documents, each validated into ``self.model``.
    """
    if query is None:
        mongo_filter = {}
    else:
        # Imported locally: the module-level imports visible for this file
        # only bring in QueryComponent, so the adapter factory would
        # otherwise be an unresolved name as soon as a filter is supplied.
        from fai_backend.repository.query.mongo import adapt_query_component

        mongo_filter = adapt_query_component(query).to_mongo_query()

    db_query = self.odm_model.find(mongo_filter)

    if sort_by:
        direction = SortDirection.ASCENDING if sort_order == 'asc' else SortDirection.DESCENDING
        db_query = db_query.sort((sort_by, direction))

    return [self.model.model_validate(doc) for doc in await db_query.to_list()]

async def create(self, item: T) -> T:
    """Persist ``item`` under a freshly generated id and return the stored copy.

    The caller's ``item`` is not modified: its fields are dumped to a dict,
    the new ``PydanticObjectId`` is set on that dict, and the ODM document
    is built from it.

    Returns:
        The created document validated into ``self.model``.
    """
    payload = item.model_dump()
    payload['id'] = PydanticObjectId()
    stored = await self.odm_model.model_validate(payload).create()
    return self.model.model_validate(stored)

async def get(self, item_id: str) -> T | None:
try:
Expand Down
48 changes: 48 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/repository/query/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from dataclasses import dataclass
from typing import Literal, Union

# Name of the attribute a query component refers to; a plain string alias.
Path = str


@dataclass
class AttributeAssignment:
    """Query component pairing an attribute path with a value."""

    # Attribute the component refers to.
    path: Path
    # Value associated with the attribute. Annotated as ``object`` because
    # any value is accepted (the builtin function ``any`` is not a type).
    value: object


@dataclass
class AttributeComparison:
    """Query component comparing an attribute against a value."""

    # Attribute being compared.
    path: Path
    # Comparison operator, written as in Python source.
    operator: Literal['<', '<=', '==', '!=', '>', '>=']
    # Right-hand side of the comparison. Annotated as ``object`` because any
    # value is accepted (the builtin function ``any`` is not a type).
    value: object


@dataclass
class LogicalExpression:
operator: Literal['AND'] | Literal['OR'] = 'AND'
components: list['QueryComponent'] = None


# Any node of a query: one of the three component classes defined in this module.
QueryComponent = Union[AttributeAssignment, AttributeComparison, LogicalExpression] # noqa: UP007

if __name__ == '__main__':

    def evaluate_query(component: QueryComponent):
        """Print a readable trace of ``component``, descending into logical expressions."""
        if isinstance(component, LogicalExpression):
            print(f'Evaluate {component.operator} of:')
            for child in component.components:
                # Nested components are traced with the same routine.
                evaluate_query(child)
            return
        if isinstance(component, AttributeAssignment):
            print(f'Set {component.path} to {component.value}')
            return
        if isinstance(component, AttributeComparison):
            print(f'Compare {component.path} {component.operator} with {component.value}')
            return
        raise ValueError('Unsupported query component')


    # Demo: an age range expressed as the conjunction of two comparisons.
    demo_query = LogicalExpression('AND', [
        AttributeComparison('age', '>=', 20),
        AttributeComparison('age', '<=', 30)
    ])

    print('Evaluating query component:')
    print('--------------------------')
    evaluate_query(demo_query)
60 changes: 60 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/repository/query/mongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import dataclass

from fai_backend.repository.query import AttributeAssignment, AttributeComparison, LogicalExpression, QueryComponent


class QueryAdapter:
    """Base class for objects that render a query component as a MongoDB query."""

    def to_mongo_query(self):
        """Return the MongoDB query for the wrapped component.

        The base implementation always raises; subclasses override it.

        Raises:
            NotImplementedError: always.
        """
        message = 'Must implement to_mongo_query'
        raise NotImplementedError(message)


@dataclass
class AttributeAssignmentAdapter(QueryAdapter):
    """Renders an ``AttributeAssignment`` as a MongoDB query."""

    component: AttributeAssignment

    def to_mongo_query(self):
        """Return ``{path: value}`` for the wrapped assignment."""
        # In a find query an assignment is read as a plain equality match
        # ($set would only apply to updates).
        assignment = self.component
        return {assignment.path: assignment.value}


@dataclass
class AttributeComparisonAdapter(QueryAdapter):
    """Renders an ``AttributeComparison`` as a MongoDB query."""

    component: AttributeComparison

    def to_mongo_query(self):
        """Return ``{path: {mongo_operator: value}}`` for the wrapped comparison.

        Raises:
            KeyError: if the component's operator has no entry in the table.
        """
        comparison = self.component
        # Comparison symbol -> MongoDB query operator.
        mongo_operators = {
            '==': '$eq',
            '!=': '$ne',
            '<': '$lt',
            '<=': '$lte',
            '>': '$gt',
            '>=': '$gte',
        }
        condition = {mongo_operators[comparison.operator]: comparison.value}
        return {comparison.path: condition}


@dataclass
class LogicalExpressionAdapter(QueryAdapter):
    """Renders a ``LogicalExpression`` as a MongoDB ``$and`` / ``$or`` query."""

    component: LogicalExpression

    def to_mongo_query(self):
        """Return ``{mongo_operator: [sub-queries]}`` for the wrapped expression.

        Each sub-component is adapted with ``adapt_query_component`` and
        rendered in order.

        Raises:
            KeyError: if the expression's operator is neither 'AND' nor 'OR'.
        """
        expression = self.component
        mongo_operator = {'AND': '$and', 'OR': '$or'}[expression.operator]

        sub_queries = []
        for child in expression.components:
            sub_queries.append(adapt_query_component(child).to_mongo_query())

        return {mongo_operator: sub_queries}


def adapt_query_component(component: QueryComponent) -> QueryAdapter:
    """Wrap ``component`` in the adapter that matches its type.

    Raises:
        ValueError: if ``component`` is not one of the supported component types.
    """
    # Checked in order; the first matching component type wins.
    adapters_by_type = (
        (AttributeAssignment, AttributeAssignmentAdapter),
        (AttributeComparison, AttributeComparisonAdapter),
        (LogicalExpression, LogicalExpressionAdapter),
    )
    for component_type, adapter_cls in adapters_by_type:
        if isinstance(component, component_type):
            return adapter_cls(component)
    raise ValueError('Unsupported query component type')
102 changes: 83 additions & 19 deletions fai-rag-app/fai-backend/tests/repository/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,37 @@
from beanie import Document, init_beanie
from bson import ObjectId
from mongomock_motor import AsyncMongoMockClient
from pydantic import BaseModel

from fai_backend.repository.mongodb import MongoDBRepo
from fai_backend.repository.query.component import AttributeAssignment, AttributeComparison, LogicalExpression


class Employee(BaseModel):
    """Plain pydantic model used as the repository's item type in these tests."""

    # Unset until assigned; annotated as optional to match the None default.
    id: str | None = None
    name: str
    age: int
    perks: list[str] = []


class EmployeeDocument(Document, Employee):
    """Beanie document carrying the ``Employee`` fields; used as the ODM model in these tests."""

    class Settings:
        # Beanie document setting, kept exactly as configured for the tests.
        use_state_management = True


@pytest_asyncio.fixture
async def mongo_repo():
    """Yield a ``MongoDBRepo`` backed by an in-memory mock MongoDB.

    Beanie is initialised against a fresh ``AsyncMongoMockClient`` database
    before the test, and the document collection is dropped afterwards.
    """
    await init_beanie(database=(AsyncMongoMockClient()).test_db, document_models=[EmployeeDocument])
    yield MongoDBRepo[Employee, EmployeeDocument](Employee, EmployeeDocument)
    await EmployeeDocument.get_motor_collection().drop()


@pytest.mark.asyncio
async def test_list(mongo_repo):
await mongo_repo.create(SampleDocument(name='Alice', age=30))
await mongo_repo.create(SampleDocument(name='Bob', age=25))
await mongo_repo.create(Employee(name='Alice', age=30))
await mongo_repo.create(Employee(name='Bob', age=25))

documents = await mongo_repo.list()
documents = await mongo_repo.list(None)

assert len(documents) == 2
assert documents[0].name == 'Alice'
Expand All @@ -38,16 +42,15 @@ async def test_list(mongo_repo):

@pytest.mark.asyncio
async def test_create(mongo_repo):
    """Creating an employee returns it with its fields intact and an id assigned."""
    created_document = await mongo_repo.create(Employee(name='Charlie', age=40))

    assert created_document.name == 'Charlie'
    assert created_document.age == 40
    assert created_document.id is not None


@pytest.mark.asyncio
async def test_get_existing_document(mongo_repo):
sample_document = await mongo_repo.create(SampleDocument(name='Dave', age=35))
sample_document = await mongo_repo.create(Employee(name='Dave', age=35))

retrieved_document = await mongo_repo.get(str(sample_document.id))

Expand All @@ -65,17 +68,17 @@ async def test_get_non_existing_document(mongo_repo):

@pytest.mark.asyncio
async def test_update_existing_document(mongo_repo):
    """Updating a stored employee changes its fields and keeps its id."""
    created = await mongo_repo.create(Employee(name='Eve', age=45))
    document = await mongo_repo.get(created.id)
    saved_id = document.id

    updated_document = await mongo_repo.update(
        document.id, {'name': 'Eve Updated', 'age': 50}
    )

    assert updated_document is not None
    assert updated_document.name == 'Eve Updated'
    assert updated_document.age == 50
    assert updated_document.id == saved_id


@pytest.mark.asyncio
Expand All @@ -89,18 +92,79 @@ async def test_update_non_existing_document(mongo_repo):

@pytest.mark.asyncio
async def test_delete_existing_document(mongo_repo):
    """Deleting a stored employee returns it and makes it unretrievable."""
    sample_document = await mongo_repo.create(Employee(name='Frank', age=55))
    deleted_document = await mongo_repo.delete(str(sample_document.id))

    assert deleted_document is not None
    assert deleted_document.name == 'Frank'
    assert deleted_document.age == 55

    # Checked through the repository rather than the ODM class directly.
    assert await mongo_repo.get(sample_document.id) is None


@pytest.mark.asyncio
async def test_delete_non_existing_document(mongo_repo):
    """Deleting an id that was never stored yields ``None``."""
    unknown_id = str(ObjectId())

    result = await mongo_repo.delete(unknown_id)

    assert result is None


@pytest.mark.asyncio
async def test_sort_by_age_descending(mongo_repo):
    """Listing with ``sort_order='desc'`` returns employees oldest first."""
    for name, age in (('Bob', 25), ('Alice', 30), ('James', 40)):
        await mongo_repo.create(Employee(name=name, age=age))

    employees = await mongo_repo.list(None, sort_by='age', sort_order='desc')

    assert len(employees) == 3
    assert [employee.name for employee in employees] == ['James', 'Alice', 'Bob']


@pytest.mark.asyncio
async def test_sort_by_age_ascending(mongo_repo):
    """Listing with ``sort_order='asc'`` returns employees youngest first."""
    for name, age in (('Alice', 30), ('Bob', 25), ('Liz', 16)):
        await mongo_repo.create(Employee(name=name, age=age))

    employees = await mongo_repo.list(None, sort_by='age', sort_order='asc')

    assert len(employees) == 3
    assert [employee.name for employee in employees] == ['Liz', 'Bob', 'Alice']


@pytest.mark.asyncio
async def test_list_by_age_in_range(mongo_repo):
    """An AND of two comparisons keeps only employees aged 20 to 30 inclusive."""
    for name, age in (('Carl', 40), ('Alice', 29), ('Bob', 25), ('Liz', 16)):
        await mongo_repo.create(Employee(name=name, age=age))

    age_between_20_and_30 = LogicalExpression('AND', [
        AttributeComparison('age', '>=', 20),
        AttributeComparison('age', '<=', 30)
    ])
    employees = await mongo_repo.list(age_between_20_and_30, sort_by='age', sort_order='desc')

    assert len(employees) == 2
    assert [employee.name for employee in employees] == ['Alice', 'Bob']


@pytest.mark.asyncio
async def test_query_with_attribute_assignment_expression(mongo_repo):
    """An attribute assignment on ``perks`` selects employees having that perk."""
    staff = (
        ('Alice', 30, ['Health Insurance']),
        ('Bob', 25, ['Dental Insurance']),
        ('Charlie', 35, ['Health Insurance', 'Dental Insurance']),
        ('David', 40, ['Vision Insurance']),
    )
    for name, age, perks in staff:
        await mongo_repo.create(Employee(name=name, age=age, perks=perks))

    has_health_insurance = AttributeAssignment('perks', 'Health Insurance')
    employees = await mongo_repo.list(has_health_insurance, sort_by='age', sort_order='desc')

    assert len(employees) == 2
    assert [employee.name for employee in employees] == ['Charlie', 'Alice']

0 comments on commit ab9b8fa

Please sign in to comment.