Skip to content

Commit

Permalink
Reflected model changes to py_database and SQLSessionStorage.
Browse files Browse the repository at this point in the history
  • Loading branch information
yghokim committed Apr 25, 2024
1 parent 57cf54c commit 771dc23
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 28 deletions.
9 changes: 5 additions & 4 deletions libs/py_database/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions libs/py_database/py_database.iml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Poetry (py_database)" jdkType="Python SDK" />
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="module" module-name="py_core" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
Expand Down
55 changes: 38 additions & 17 deletions libs/py_database/py_database/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,27 @@
CardInfoListTypeAdapter, CardInfo,
ChildCardRecommendationResult as _ChildCardRecommendationResult,
ParentGuideRecommendationResult as _ParentGuideRecommendationResult,
ParentGuideElement)
ParentGuideElement,
ParentExampleMessage as _ParentExampleMessage
)
from chatlib.utils.time import get_timestamp


class IdTimestampMixin(BaseModel):
id: str = Field(primary_key=True, default_factory=id_generator)
created_at: Optional[datetime] = Field(
default=None,
sa_type= DateTime(timezone=True),
sa_type=DateTime(timezone=True),
sa_column_kwargs=dict(server_default=func.now(), nullable=True)
)
updated_at: Optional[datetime] = Field(
default=None,
sa_type= DateTime(timezone=True),
sa_type=DateTime(timezone=True),
sa_column_kwargs=dict(server_default=func.now(), onupdate=func.now(), nullable=True)
)


class Parent(SQLModel, IdTimestampMixin, table=True):

name: str = Field(index=True)

children: list['Child'] = Relationship(back_populates='parent')
Expand All @@ -37,7 +38,6 @@ class Parent(SQLModel, IdTimestampMixin, table=True):


class Child(SQLModel, IdTimestampMixin, table=True):

name: str = Field(index=True)

parent_id: Optional[str] = Field(default=None, foreign_key='parent.id')
Expand All @@ -58,12 +58,16 @@ class Session(SQLModel, IdTimestampMixin, table=True):
ended_timestamp: int | None = Field(default=None, index=True)


class SessionIdMixin(BaseModel):
session_id: str = Field(foreign_key=f"{Session.__tablename__}.id")


class DialogueMessageContentType(StrEnum):
text="text"
json="json"
text = "text"
json = "json"


class DialogueMessage(SQLModel, IdTimestampMixin, table=True):
session_id:str = Field(foreign_key="session.id")
class DialogueMessage(SQLModel, IdTimestampMixin, SessionIdMixin, table=True):
role: DialogueRole
content_type: DialogueMessageContentType
content_str: str
Expand All @@ -90,31 +94,48 @@ def from_data_model(cls, session_id: str, message: _DialogueMessage) -> 'Dialogu
content_str=message.content if isinstance(message.content, str) else CardInfoListTypeAdapter.dump_json(
message.content),
content_str_en=message.content_en,
content_type=DialogueMessageContentType.text if isinstance(message.content, str) else DialogueMessageContentType.json
content_type=DialogueMessageContentType.text if isinstance(message.content,
str) else DialogueMessageContentType.json
)


class ChildCardRecommendationResult(SQLModel, IdTimestampMixin, table=True):
session_id:str = Field(foreign_key="session.id")
class ChildCardRecommendationResult(SQLModel, IdTimestampMixin, SessionIdMixin, table=True):
timestamp: int = Field(default_factory=get_timestamp, index=True)
cards: list[CardInfo] = Field(sa_column=Column(JSON), default=[])

def to_data_model(self) -> _ChildCardRecommendationResult:
return _ChildCardRecommendationResult(**self.model_dump())

@classmethod
def from_data_model(cls, session_id: str, data_model: _ChildCardRecommendationResult) -> 'ChildCardRecommendationResult':
def from_data_model(cls, session_id: str,
data_model: _ChildCardRecommendationResult) -> 'ChildCardRecommendationResult':
return ChildCardRecommendationResult(**data_model.model_dump(), session_id=session_id)


class ParentGuideRecommendationResult(SQLModel, IdTimestampMixin, table=True):
session_id:str = Field(foreign_key="session.id")
class ParentGuideRecommendationResult(SQLModel, IdTimestampMixin, SessionIdMixin, table=True):
timestamp: int = Field(default_factory=get_timestamp, index=True)
recommendations: list[ParentGuideElement] = Field(sa_column=Column(JSON), default=[])
guides: list[ParentGuideElement] = Field(sa_column=Column(JSON), default=[])

def to_data_model(self) -> _ParentGuideRecommendationResult:
return _ParentGuideRecommendationResult(**self.model_dump())

@classmethod
def from_data_model(cls, session_id: str, data_model: _ParentGuideRecommendationResult) -> 'ParentGuideRecommendationResult':
def from_data_model(cls, session_id: str,
data_model: _ParentGuideRecommendationResult) -> 'ParentGuideRecommendationResult':
return ParentGuideRecommendationResult(**data_model.model_dump(), session_id=session_id)


class ParentExampleMessage(SQLModel, IdTimestampMixin, SessionIdMixin, table=True):
recommendation_id: str = Field(foreign_key=f"{ParentGuideRecommendationResult.__tablename__}.id")
guide_id: str = Field(nullable=False, index=True)

message: str = Field(nullable=False)
message_localized: Optional[str] = Field(default=None)

def to_data_model(self) -> _ParentExampleMessage:
return _ParentExampleMessage(**self.model_dump())

@classmethod
def from_data_model(cls, session_id: str, data_model: _ParentExampleMessage) -> 'ParentExampleMessage':
return ParentExampleMessage(**data_model.model_dump(), session_id=session_id)

29 changes: 24 additions & 5 deletions libs/py_database/py_database/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from sqlmodel import select, col

from py_core.system.model import ParentGuideRecommendationResult, ChildCardRecommendationResult, Dialogue, \
DialogueMessage
DialogueMessage, ParentExampleMessage
from py_core.system.storage import SessionStorage
from py_database.model import DialogueMessage as DialogueMessageORM, ChildCardRecommendationResult as ChildCardRecommendationResultORM, ParentGuideRecommendationResult as ParentGuideRecommendationResultORM
from py_database.model import (DialogueMessage as DialogueMessageORM,
ChildCardRecommendationResult as ChildCardRecommendationResultORM,
ParentGuideRecommendationResult as ParentGuideRecommendationResultORM,
ParentExampleMessage as ParentExampleMessageORM)
from py_database.database import AsyncSession


class SQLSessionStorage(SessionStorage):

def __init__(self, sql_session: AsyncSession, session_id: str | None = None):
super().__init__(session_id)
self.__sql_session = sql_session
Expand All @@ -19,7 +23,8 @@ async def add_dialogue_message(self, message: DialogueMessage):
await self.__sql_session.commit()

async def get_dialogue(self) -> Dialogue:
statement = select(DialogueMessageORM).where(DialogueMessageORM.session_id == self.session_id).order_by(col(DialogueMessageORM.timestamp).desc())
statement = select(DialogueMessageORM).where(DialogueMessageORM.session_id == self.session_id).order_by(
col(DialogueMessageORM.timestamp).desc())
results = await self.__sql_session.exec(statement)
return [msg.to_data_model() for msg in results]

Expand All @@ -32,14 +37,28 @@ async def add_parent_guide_recommendation_result(self, result: ParentGuideRecomm
await self.__sql_session.commit()

async def get_card_recommendation_result(self, recommendation_id: str) -> ChildCardRecommendationResult | None:
statement = select(ChildCardRecommendationResultORM).where(ChildCardRecommendationResultORM.id == recommendation_id)
statement = select(ChildCardRecommendationResultORM).where(
ChildCardRecommendationResultORM.id == recommendation_id)
result = await self.__sql_session.exec(statement)
orm: ChildCardRecommendationResultORM | None = result.first()
return orm.to_data_model() if orm is not None else None

async def get_parent_guide_recommendation_result(self,
recommendation_id: str) -> ParentGuideRecommendationResult | None:
statement = select(ParentGuideRecommendationResultORM).where(ParentGuideRecommendationResultORM.id == recommendation_id)
statement = select(ParentGuideRecommendationResultORM).where(
ParentGuideRecommendationResultORM.id == recommendation_id)
result = await self.__sql_session.exec(statement)
orm: ParentGuideRecommendationResultORM | None = result.first()
return orm.to_data_model() if orm is not None else None

async def add_parent_example_message(self, message: ParentExampleMessage):
self.__sql_session.add(ParentExampleMessageORM.from_data_model(self.session_id, message))
await self.__sql_session.commit()

async def get_parent_example_message(self, recommendation_id: str, guide_id: str) -> ParentExampleMessage | None:
statement = (select(ParentExampleMessageORM)
.where(ParentExampleMessageORM.recommendation_id == recommendation_id)
.where(ParentExampleMessageORM.guide_id == guide_id))
result = await self.__sql_session.exec(statement)
orm: ParentExampleMessageORM | None = result.first()
return orm.to_data_model() if orm is not None else None

0 comments on commit 771dc23

Please sign in to comment.