Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(Dataframe): adding default dataframe name to enable sql query on it, simplified dataframe serialization #1523

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .core.cache import Cache
from .data_loader.loader import DatasetLoader
from .dataframe import DataFrame, VirtualDataFrame
from .helpers.sql_sanitizer import sanitize_sql_table_name
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake

Expand Down Expand Up @@ -120,7 +121,8 @@ def load(dataset_path: str) -> DataFrame:

def read_csv(filepath: str) -> DataFrame:
data = pd.read_csv(filepath)
return DataFrame(data._data)
name = f"table_{sanitize_sql_table_name(filepath)}"
return DataFrame(data._data, name=name)


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% for df in context.dfs %}{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}
{% for df in context.dfs %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}

The user asked the following question:
{{context.memory.get_conversation()}}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<tables>
{% for df in context.dfs %}
{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}
{% include 'shared/dataframe.tmpl' with context %}
{% endfor %}
</tables>

Expand Down
2 changes: 1 addition & 1 deletion pandasai/core/prompts/templates/shared/dataframe.tmpl
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{{ df.serialize_dataframe(index-1) }}
{{ df.serialize_dataframe() }}
30 changes: 8 additions & 22 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
Source,
)
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
from pandasai.helpers.dataframe_serializer import (
DataframeSerializer,
DataframeSerializerType,
)
from pandasai.helpers.dataframe_serializer import DataframeSerializer
from pandasai.helpers.path import find_project_root
from pandasai.helpers.session import get_pandaai_session

Expand Down Expand Up @@ -67,14 +64,18 @@ def __init__(
)

self.name: Optional[str] = kwargs.pop("name", None)
self._column_hash = self._calculate_column_hash()

if not self.name:
self.name = f"table_{self._column_hash}"

self.description: Optional[str] = kwargs.pop("description", None)
self.path: Optional[str] = kwargs.pop("path", None)
schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None)

self.schema = schema
self.config = pai.config.get()
self._agent: Optional[Agent] = None
self._column_hash = self._calculate_column_hash()

def __repr__(self) -> str:
"""Return a string representation of the DataFrame."""
Expand Down Expand Up @@ -136,29 +137,14 @@ def rows_count(self) -> int:
def columns_count(self) -> int:
return len(self.columns)

def serialize_dataframe(
self,
index: int,
) -> str:
def serialize_dataframe(self) -> str:
"""
Serialize DataFrame to string representation.

Args:
index (int): Index of the dataframe
serializer_type (DataframeSerializerType): Type of serializer to use
**kwargs: Additional parameters to pass to pandas to_string method

Returns:
str: Serialized string representation of the DataFrame
"""
return DataframeSerializer().serialize(
self,
extras={
"index": index,
"type": "pd.DataFrame",
},
type_=DataframeSerializerType.CSV,
)
return DataframeSerializer().serialize(self)

def get_head(self):
return self.head()
Expand Down
3 changes: 2 additions & 1 deletion pandasai/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from . import path
from . import path, sql_sanitizer
from .env import load_dotenv
from .logger import Logger

__all__ = [
"path",
"sql_sanitizer",
"load_dotenv",
"Logger",
]
114 changes: 6 additions & 108 deletions pandasai/helpers/dataframe_serializer.py
Original file line number Diff line number Diff line change
@@ -1,137 +1,35 @@
import json
from enum import Enum

import pandas as pd


class DataframeSerializerType(Enum):
JSON = 1
YML = 2
CSV = 3
SQL = 4


class DataframeSerializer:
def __init__(self) -> None:
pass

def serialize(
self,
df: pd.DataFrame,
extras: dict = None,
type_: DataframeSerializerType = DataframeSerializerType.YML,
) -> str:
if type_ == DataframeSerializerType.YML:
return self.convert_df_to_yml(df, extras)
elif type_ == DataframeSerializerType.JSON:
return self.convert_df_to_json_str(df, extras)
elif type_ == DataframeSerializerType.SQL:
return self.convert_df_sql_connector_to_str(df, extras)
else:
return self.convert_df_to_csv(df, extras)

def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str:
def serialize(self, df: pd.DataFrame) -> str:
"""
Convert df to csv like format where csv is wrapped inside <dataframe></dataframe>
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists

Returns:
str: dataframe stringify
"""
dataframe_info = "<dataframe"
dataframe_info = "<table"

# Add name attribute if available
if df.name is not None:
dataframe_info += f' name="{df.name}"'
dataframe_info += f' table_name="{df.name}"'

# Add description attribute if available
if df.description is not None:
dataframe_info += f' description="{df.description}"'

dataframe_info += ">"
dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">'

# Add dataframe details
dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}"
dataframe_info += f"\n{df.head().to_csv(index=False)}"

# Close the dataframe tag
dataframe_info += "</dataframe>\n"
dataframe_info += "</table>\n"

return dataframe_info

def convert_df_sql_connector_to_str(
self, df: pd.DataFrame, extras: dict = None
) -> str:
"""
Convert df to csv like format where csv is wrapped inside <table></table>
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists

Returns:
str: dataframe stringify
"""
table_description_tag = (
f' description="{df.description}"' if df.description is not None else ""
)
table_head_tag = f'<table name="{df.name}"{table_description_tag}>'
return f"{table_head_tag}\n{df.get_head().to_csv()}\n</table>"

def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
"""
Convert df to json dictionary and return json
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists

Returns:
str: dataframe json
"""

# Create a dictionary representing the data structure
df_info = {
"name": df.name,
"description": None,
"type": df.type,
}
# Add DataFrame details to the result
data = {
"rows": df.rows_count,
"columns": df.columns_count,
"schema": {"fields": []},
}

# Iterate over DataFrame columns
df_head = df.get_head()
for col_name, col_dtype in df_head.dtypes.items():
col_info = {
"name": col_name,
"type": str(col_dtype),
}

data["schema"]["fields"].append(col_info)

result = df_info | data

return result

def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str:
"""
Convert df to json and return it as string
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists

Returns:
str: dataframe stringify
"""
return json.dumps(self.convert_df_to_json(df, extras))

def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str:
json_df = self.convert_df_to_json(df, extras)

import yaml

yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True)
return f"<table>\n{yml_str}\n</table>\n"
16 changes: 16 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import re


def sanitize_sql_table_name(filepath: str, *, max_length: int = 64) -> str:
    """Derive a SQL-friendly table name from a file path.

    The directory part and the final extension of ``filepath`` are dropped,
    every character outside ``[a-zA-Z0-9_]`` is replaced by a single
    underscore, and the result is truncated to ``max_length`` characters.

    The returned name may be empty (e.g. when ``filepath`` ends with a path
    separator) and may start with a digit; callers that need a guaranteed
    leading letter should add their own prefix.

    Args:
        filepath: Path (or bare file name) to derive the table name from.
        max_length: Maximum number of characters kept. Defaults to 64;
            pass a smaller value for backends with a tighter identifier
            limit.

    Returns:
        The sanitized, possibly truncated, table name.

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    # A negative bound would silently trim characters from the end of the
    # name instead of limiting its length, so reject it explicitly.
    if max_length < 0:
        raise ValueError("max_length must be non-negative")

    # Extract the file name without directory or final extension
    file_name = os.path.splitext(os.path.basename(filepath))[0]

    # Replace invalid characters with underscores
    sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", file_name)

    # Truncate to the requested length
    return sanitized_name[:max_length]
3 changes: 1 addition & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pandasai.config import Config, ConfigManager
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.helpers.dataframe_serializer import DataframeSerializerType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -38,7 +37,7 @@ def llm(self, output: Optional[str] = None) -> FakeLLM:

@pytest.fixture
def config(self, llm: FakeLLM) -> dict:
return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}
return {"llm": llm}

@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
Expand Down
57 changes: 25 additions & 32 deletions tests/unit_tests/helpers/test_dataframe_serializer.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
import unittest
import pytest

from pandasai.dataframe.base import DataFrame
from pandasai.helpers.dataframe_serializer import (
DataframeSerializer,
DataframeSerializerType,
)
from pandasai import DataFrame
from pandasai.helpers.dataframe_serializer import DataframeSerializer


class TestDataframeSerializer(unittest.TestCase):
def setUp(self):
self.serializer = DataframeSerializer()
class TestDataframeSerializer:
@pytest.fixture
def sample_df(self):
df = DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]})
df.name = "test_table"
df.description = "This is a test table"
return df

def test_convert_df_to_yml(self):
# Test convert df to yml
data = {"name": ["en_name", "中文_名称"]}
connector = DataFrame(data, name="en_table_name", description="中文_描述")
result = self.serializer.serialize(
connector,
type_=DataframeSerializerType.YML,
extras={"index": 0, "type": "pd.Dataframe"},
)
@pytest.fixture
def sample_dataframe_serializer(self):
return DataframeSerializer()

self.assertIn(
"""<table>
name: en_table_name
description: null
type: pd.DataFrame
rows: 2
columns: 1
schema:
fields:
- name: name
type: object
def test_serialize_with_name_and_description(
self, sample_dataframe_serializer, sample_df
):
"""Test serialization with name and description attributes."""

</table>
""",
result,
result = sample_dataframe_serializer.serialize(sample_df)
expected = (
'<table table_name="test_table" description="This is a test table" dimensions="2x2">\n'
"Name,Age\n"
"Alice,25\n"
"Bob,30\n"
"</table>\n"
)
assert result == expected
19 changes: 19 additions & 0 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name


class TestSqlSanitizer:
    """Unit tests for ``sanitize_sql_table_name``."""

    def test_valid_filename(self):
        """A stem made only of allowed characters is returned unchanged."""
        result = sanitize_sql_table_name("/path/to/valid_table.csv")
        assert result == "valid_table"

    def test_filename_with_special_characters(self):
        """Each disallowed character in the stem becomes one underscore."""
        result = sanitize_sql_table_name("/path/to/invalid!@#.csv")
        assert result == "invalid___"

    def test_filename_with_long_name(self):
        """A stem longer than the limit is cut down to 64 characters."""
        long_stem = "a" * 100
        result = sanitize_sql_table_name("/path/to/" + long_stem + ".csv")
        assert result == "a" * 64
7 changes: 3 additions & 4 deletions tests/unit_tests/prompts/test_sql_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_str_with_args(self, output_type, output_type_template):

llm = FakeLLM()
agent = Agent(
pai.DataFrame(),
pai.DataFrame(name="test"),
config={"llm": llm},
)
prompt = GeneratePythonCodeWithSQLPrompt(
Expand All @@ -68,10 +68,9 @@ def test_str_with_args(self, output_type, output_type_template):
prompt_content
== f'''<tables>

<dataframe>
dfs[0]:0x0
<table table_name="test" dimensions="0x0">

</dataframe>
</table>


</tables>
Expand Down
Loading