Skip to content

Commit

Permalink
fix(LogSerialization): dataframe timestamp serialization (#751)
Browse files Browse the repository at this point in the history
* fix(json_serializer): datetime convert to json serialization

* refactor(QueryTracker): add function to convert dataframe to dict

* chore(tests): remove leftover print statements
  • Loading branch information
ArslanSaleem authored Nov 15, 2023
1 parent 0712f84 commit 6de6b39
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
18 changes: 8 additions & 10 deletions pandasai/helpers/query_exec_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def start_new_track(self):
self._query_info = {}
self._func_exec_count: dict = defaultdict(int)

def convert_dataframe_to_dict(self, df):
json_data = json.loads(df.to_json(orient="split", date_format="iso"))
return {"headers": json_data["columns"], "rows": json_data["data"]}

def add_dataframes(self, dfs: List) -> None:
"""
Add used dataframes for the query to query exec tracker
Expand All @@ -96,9 +100,7 @@ def add_dataframes(self, dfs: List) -> None:
"""
for df in dfs:
head = df.head_df
self._dataframes.append(
{"headers": head.columns.tolist(), "rows": head.values.tolist()}
)
self._dataframes.append(self.convert_dataframe_to_dict(head))

def add_step(self, step: dict) -> None:
"""
Expand Down Expand Up @@ -200,13 +202,9 @@ def _format_response(self, result: ResponseType) -> ResponseType:
ResponseType: formatted response output
"""
if result["type"] == "dataframe":
return {
"type": result["type"],
"value": {
"headers": result["value"].columns.tolist(),
"rows": result["value"].values.tolist(),
},
}
df_dict = self.convert_dataframe_to_dict(result["value"])
return {"type": result["type"], "value": df_dict}

elif result["type"] == "plot":
with open(result["value"], "rb") as image_file:
image_data = image_file.read()
Expand Down
25 changes: 25 additions & 0 deletions tests/test_query_tracker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import time
from typing import Optional
Expand All @@ -9,6 +10,7 @@
from pandasai.llm.fake import FakeLLM
from pandasai.smart_dataframe import SmartDataframe
from unittest import TestCase
from datetime import datetime, timedelta


assert_almost_equal = TestCase().assertAlmostEqual
Expand Down Expand Up @@ -120,6 +122,29 @@ def test_format_response_dataframe(
assert len(formatted_response["value"]["headers"]) == 3
assert len(formatted_response["value"]["rows"]) == 10

def test_format_response_dataframe_with_datetime_field(
self, tracker: QueryExecTracker, sample_df: pd.DataFrame
):
# Add a date column with random dates for demonstration
start_date = datetime(2023, 1, 1)
date_range = [start_date + timedelta(days=x) for x in range(len(sample_df))]

sample_df["date"] = date_range

# Create a sample ResponseType for a dataframe
response = {"type": "dataframe", "value": sample_df}

# Format the response using _format_response
formatted_response = tracker._format_response(response)

# Validate dataframe json serialization
json.dumps(formatted_response)

# Check if the response is formatted correctly
assert formatted_response["type"] == "dataframe"
assert len(formatted_response["value"]["headers"]) == 4
assert len(formatted_response["value"]["rows"]) == 10

def test_format_response_other_type(self, tracker: QueryExecTracker):
# Create a sample ResponseType for a non-dataframe response
response = {
Expand Down

0 comments on commit 6de6b39

Please sign in to comment.