From fd22f8c3778d22ac16b3486e35fa0a6da1a889fc Mon Sep 17 00:00:00 2001 From: Raoul Date: Wed, 15 Jan 2025 17:04:19 +0100 Subject: [PATCH 1/3] fix(Dataframe): adding default dataframe name to enable sql query on it, simplified dataframe serialization --- pandasai/__init__.py | 4 +- ..._execute_sql_query_usage_error_prompt.tmpl | 2 +- .../generate_python_code_with_sql.tmpl | 2 +- .../prompts/templates/shared/dataframe.tmpl | 2 +- pandasai/dataframe/base.py | 30 ++--- pandasai/helpers/__init__.py | 3 +- pandasai/helpers/dataframe_serializer.py | 114 +----------------- pandasai/helpers/sql_sanitizer.py | 16 +++ tests/unit_tests/agent/test_agent.py | 3 +- .../helpers/test_dataframe_serializer.py | 57 ++++----- .../unit_tests/helpers/test_sql_sanitizer.py | 19 +++ tests/unit_tests/prompts/test_sql_prompt.py | 7 +- 12 files changed, 86 insertions(+), 173 deletions(-) create mode 100644 pandasai/helpers/sql_sanitizer.py create mode 100644 tests/unit_tests/helpers/test_sql_sanitizer.py diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 5f73b0d9c..cb358fd1d 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -19,6 +19,7 @@ from .core.cache import Cache from .data_loader.loader import DatasetLoader from .dataframe import DataFrame, VirtualDataFrame +from .helpers.sql_sanitizer import sanitize_sql_table_name from .smart_dataframe import SmartDataframe from .smart_datalake import SmartDatalake @@ -120,7 +121,8 @@ def load(dataset_path: str) -> DataFrame: def read_csv(filepath: str) -> DataFrame: data = pd.read_csv(filepath) - return DataFrame(data._data) + name = f"table_{sanitize_sql_table_name(filepath)}" + return DataFrame(data._data, name=name) __all__ = [ diff --git a/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl b/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl index 523608f71..029cf26f2 100644 --- a/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl +++ b/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl @@ -1,4 +1,4 @@ -{% for df in context.dfs %}{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %} +{% for df in context.dfs %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %} The user asked the following question: {{context.memory.get_conversation()}} diff --git a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl index 6ce957d66..5406a8352 100644 --- a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl +++ b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl @@ -1,6 +1,6 @@ {% for df in context.dfs %} -{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %} +{% include 'shared/dataframe.tmpl' with context %} {% endfor %} diff --git a/pandasai/core/prompts/templates/shared/dataframe.tmpl b/pandasai/core/prompts/templates/shared/dataframe.tmpl index 931813b6f..fcc6ea3a8 100644 --- a/pandasai/core/prompts/templates/shared/dataframe.tmpl +++ b/pandasai/core/prompts/templates/shared/dataframe.tmpl @@ -1 +1 @@ -{{ df.serialize_dataframe(index-1) }} +{{ df.serialize_dataframe() }} diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py index d90a3b786..5cc895453 100644 --- a/pandasai/dataframe/base.py +++ b/pandasai/dataframe/base.py @@ -21,10 +21,7 @@ Source, ) from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError -from pandasai.helpers.dataframe_serializer import ( - DataframeSerializer, - DataframeSerializerType, -) +from pandasai.helpers.dataframe_serializer import DataframeSerializer from pandasai.helpers.path import find_project_root from pandasai.helpers.session import get_pandaai_session @@ -67,6 +64,11 @@ def __init__( ) self.name: Optional[str] = kwargs.pop("name", None) + self._column_hash = self._calculate_column_hash() + + if not self.name: + self.name = f"table_{self._column_hash}" + self.description: Optional[str] = kwargs.pop("description", None) self.path: Optional[str] = kwargs.pop("path", None) schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None) @@ -74,7 +76,6 @@ def __init__( self.schema = schema self.config = pai.config.get() self._agent: Optional[Agent] = None - self._column_hash = self._calculate_column_hash() def __repr__(self) -> str: """Return a string representation of the DataFrame.""" @@ -136,29 +137,14 @@ def rows_count(self) -> int: def columns_count(self) -> int: return len(self.columns) - def serialize_dataframe( - self, - index: int, - ) -> str: + def serialize_dataframe(self) -> str: """ Serialize DataFrame to string representation. - Args: - index (int): Index of the dataframe - serializer_type (DataframeSerializerType): Type of serializer to use - **kwargs: Additional parameters to pass to pandas to_string method - Returns: str: Serialized string representation of the DataFrame """ - return DataframeSerializer().serialize( - self, - extras={ - "index": index, - "type": "pd.DataFrame", - }, - type_=DataframeSerializerType.CSV, - ) + return DataframeSerializer().serialize(self) def get_head(self): return self.head() diff --git a/pandasai/helpers/__init__.py b/pandasai/helpers/__init__.py index 0e9534765..ea29af987 100644 --- a/pandasai/helpers/__init__.py +++ b/pandasai/helpers/__init__.py @@ -1,9 +1,10 @@ -from . import path +from . import path, sql_sanitizer from .env import load_dotenv from .logger import Logger __all__ = [ "path", + "sql_sanitizer", "load_dotenv", "Logger", ] diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py index ca8e66219..487c2da82 100644 --- a/pandasai/helpers/dataframe_serializer.py +++ b/pandasai/helpers/dataframe_serializer.py @@ -1,137 +1,35 @@ -import json -from enum import Enum - import pandas as pd -class DataframeSerializerType(Enum): - JSON = 1 - YML = 2 - CSV = 3 - SQL = 4 - - class DataframeSerializer: def __init__(self) -> None: pass - def serialize( - self, - df: pd.DataFrame, - extras: dict = None, - type_: DataframeSerializerType = DataframeSerializerType.YML, - ) -> str: - if type_ == DataframeSerializerType.YML: - return self.convert_df_to_yml(df, extras) - elif type_ == DataframeSerializerType.JSON: - return self.convert_df_to_json_str(df, extras) - elif type_ == DataframeSerializerType.SQL: - return self.convert_df_sql_connector_to_str(df, extras) - else: - return self.convert_df_to_csv(df, extras) - - def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str: + def serialize(self, df: pd.DataFrame) -> str: """ Convert df to csv like format where csv is wrapped inside Args: df (pd.DataFrame): PandaAI dataframe or dataframe - extras (dict, optional): expect index to exists Returns: str: dataframe stringify """ - dataframe_info = "' # Add dataframe details - dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}" + dataframe_info += f"\n{df.head().to_csv(index=False)}" # Close the dataframe tag - dataframe_info += "\n" + dataframe_info += "\n" return dataframe_info - - def convert_df_sql_connector_to_str( - self, df: pd.DataFrame, extras: dict = None - ) -> str: - """ - Convert df to csv like format where csv is wrapped inside
- Args: - df (pd.DataFrame): PandaAI dataframe or dataframe - extras (dict, optional): expect index to exists - - Returns: - str: dataframe stringify - """ - table_description_tag = ( - f' description="{df.description}"' if df.description is not None else "" - ) - table_head_tag = f'' - return f"{table_head_tag}\n{df.get_head().to_csv()}\n
" - - def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict: - """ - Convert df to json dictionary and return json - Args: - df (pd.DataFrame): PandaAI dataframe or dataframe - extras (dict, optional): expect index to exists - - Returns: - str: dataframe json - """ - - # Create a dictionary representing the data structure - df_info = { - "name": df.name, - "description": None, - "type": df.type, - } - # Add DataFrame details to the result - data = { - "rows": df.rows_count, - "columns": df.columns_count, - "schema": {"fields": []}, - } - - # Iterate over DataFrame columns - df_head = df.get_head() - for col_name, col_dtype in df_head.dtypes.items(): - col_info = { - "name": col_name, - "type": str(col_dtype), - } - - data["schema"]["fields"].append(col_info) - - result = df_info | data - - return result - - def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str: - """ - Convert df to json and return it as string - Args: - df (pd.DataFrame): PandaAI dataframe or dataframe - extras (dict, optional): expect index to exists - - Returns: - str: dataframe stringify - """ - return json.dumps(self.convert_df_to_json(df, extras)) - - def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str: - json_df = self.convert_df_to_json(df, extras) - - import yaml - - yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True) - return f"\n{yml_str}\n
\n" diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py new file mode 100644 index 000000000..82b4306eb --- /dev/null +++ b/pandasai/helpers/sql_sanitizer.py @@ -0,0 +1,16 @@ +import os +import re + + +def sanitize_sql_table_name(filepath: str) -> str: + # Extract the file name without extension + file_name = os.path.splitext(os.path.basename(filepath))[0] + + # Replace invalid characters with underscores + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", file_name) + + # Truncate to a reasonable length (e.g., 64 characters) + max_length = 64 + sanitized_name = sanitized_name[:max_length] + + return sanitized_name diff --git a/tests/unit_tests/agent/test_agent.py b/tests/unit_tests/agent/test_agent.py index c98fa35db..48d0dcae2 100644 --- a/tests/unit_tests/agent/test_agent.py +++ b/tests/unit_tests/agent/test_agent.py @@ -9,7 +9,6 @@ from pandasai.config import Config, ConfigManager from pandasai.dataframe.base import DataFrame from pandasai.exceptions import CodeExecutionError -from pandasai.helpers.dataframe_serializer import DataframeSerializerType from pandasai.llm.fake import FakeLLM @@ -38,7 +37,7 @@ def llm(self, output: Optional[str] = None) -> FakeLLM: @pytest.fixture def config(self, llm: FakeLLM) -> dict: - return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV} + return {"llm": llm} @pytest.fixture def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent: diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py index 78cabe165..9303e6211 100644 --- a/tests/unit_tests/helpers/test_dataframe_serializer.py +++ b/tests/unit_tests/helpers/test_dataframe_serializer.py @@ -1,39 +1,32 @@ -import unittest +import pytest -from pandasai.dataframe.base import DataFrame -from pandasai.helpers.dataframe_serializer import ( - DataframeSerializer, - DataframeSerializerType, -) +from pandasai import DataFrame +from pandasai.helpers.dataframe_serializer import DataframeSerializer -class TestDataframeSerializer(unittest.TestCase): - def setUp(self): - self.serializer = DataframeSerializer() +class TestDataframeSerializer: + @pytest.fixture + def sample_df(self): + df = DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]}) + df.name = "test_table" + df.description = "This is a test table" + return df - def test_convert_df_to_yml(self): - # Test convert df to yml - data = {"name": ["en_name", "中文_名称"]} - connector = DataFrame(data, name="en_table_name", description="中文_描述") - result = self.serializer.serialize( - connector, - type_=DataframeSerializerType.YML, - extras={"index": 0, "type": "pd.Dataframe"}, - ) + @pytest.fixture + def sample_dataframe_serializer(self): + return DataframeSerializer() - self.assertIn( - """ -name: en_table_name -description: null -type: pd.DataFrame -rows: 2 -columns: 1 -schema: - fields: - - name: name - type: object + def test_serialize_with_name_and_description( + self, sample_dataframe_serializer, sample_df + ): + """Test serialization with name and description attributes.""" -
-""", - result, + result = sample_dataframe_serializer.serialize(sample_df) + expected = ( + '\n' + "Name,Age\n" + "Alice,25\n" + "Bob,30\n" + "
\n" ) + assert result == expected diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py new file mode 100644 index 000000000..5f4ab40fc --- /dev/null +++ b/tests/unit_tests/helpers/test_sql_sanitizer.py @@ -0,0 +1,19 @@ +from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name + + +class TestSqlSanitizer: + def test_valid_filename(self): + filepath = "/path/to/valid_table.csv" + expected = "valid_table" + assert sanitize_sql_table_name(filepath) == expected + + def test_filename_with_special_characters(self): + filepath = "/path/to/invalid!@#.csv" + expected = "invalid___" + assert sanitize_sql_table_name(filepath) == expected + + def test_filename_with_long_name(self): + """Test with a filename exceeding the length limit.""" + filepath = "/path/to/" + "a" * 100 + ".csv" + expected = "a" * 64 + assert sanitize_sql_table_name(filepath) == expected diff --git a/tests/unit_tests/prompts/test_sql_prompt.py b/tests/unit_tests/prompts/test_sql_prompt.py index 48bb58335..333b575db 100644 --- a/tests/unit_tests/prompts/test_sql_prompt.py +++ b/tests/unit_tests/prompts/test_sql_prompt.py @@ -51,7 +51,7 @@ def test_str_with_args(self, output_type, output_type_template): llm = FakeLLM() agent = Agent( - pai.DataFrame(), + pai.DataFrame(name="test"), config={"llm": llm}, ) prompt = GeneratePythonCodeWithSQLPrompt( @@ -68,10 +68,9 @@ def test_str_with_args(self, output_type, output_type_template): prompt_content == f''' - -dfs[0]:0x0 + - +
From 72d4732d169765de37e81d5cebf19f2662079e49 Mon Sep 17 00:00:00 2001 From: Raoul Date: Wed, 15 Jan 2025 18:05:43 +0100 Subject: [PATCH 2/3] fix(DataframeSerializer): fixing test failure on windows --- .../unit_tests/helpers/test_dataframe_serializer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py index 9303e6211..9197a9e69 100644 --- a/tests/unit_tests/helpers/test_dataframe_serializer.py +++ b/tests/unit_tests/helpers/test_dataframe_serializer.py @@ -22,11 +22,10 @@ def test_serialize_with_name_and_description( """Test serialization with name and description attributes.""" result = sample_dataframe_serializer.serialize(sample_df) - expected = ( - '\n' - "Name,Age\n" - "Alice,25\n" - "Bob,30\n" - "
\n" - ) + expected = """ +Name,Age +Alice,25 +Bob,30 +
+""" assert result == expected From 3aa6a35a825cb8b1c8b08bd55baab52fb42f2d5a Mon Sep 17 00:00:00 2001 From: Raoul Date: Wed, 15 Jan 2025 18:19:31 +0100 Subject: [PATCH 3/3] fix(DataframeSerializer): fixing test failure on windows --- tests/unit_tests/helpers/test_dataframe_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py index 9197a9e69..ca31dba08 100644 --- a/tests/unit_tests/helpers/test_dataframe_serializer.py +++ b/tests/unit_tests/helpers/test_dataframe_serializer.py @@ -28,4 +28,4 @@ def test_serialize_with_name_and_description( Bob,30 """ - assert result == expected + assert result.replace("\r\n", "\n") == expected.replace("\r\n", "\n")