diff --git a/examples/using_workspace_env.py b/examples/using_workspace_env.py new file mode 100644 index 000000000..604d72a60 --- /dev/null +++ b/examples/using_workspace_env.py @@ -0,0 +1,38 @@ +import os +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.schemas.df_config import Config + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +os.environ["PANDASAI_WORKSPACE"] = "workspace dir path" + + +llm = OpenAI("YOUR_API_KEY") +config__ = {"llm": llm, "save_charts": False} + + +agent = Agent( + [employees_df, salaries_df], + config=Config(**config__), + memory_size=10, +) + +# Chat with the agent +response = agent.chat("plot salary against department?") +print(response) diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 56551ff34..56853c79d 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -146,3 +146,12 @@ class AdvancedReasoningDisabledError(Exception): Args: Exception (Exception): AdvancedReasoningDisabledError """ + + +class InvalidWorkspacePathError(Exception): + """ + Raised when the environment variable of workspace exist but path is invalid + + Args: + Exception (Exception): InvalidWorkspacePathError + """ diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index c95264524..edf7aa0fd 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -5,6 +5,7 @@ import astor import pandas as pd +from pandasai.helpers.path import find_project_root from pandasai.helpers.skills_manager import SkillsManager @@ -235,6 +236,14 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: file_name=str(context.prompt_id), save_charts_path_str=self._config.save_charts_path, ) + else: + # Temporarily save generated chart to display + code = add_save_chart( + code, + logger=self._logger, + file_name="temp_chart", + save_charts_path_str=find_project_root(), + ) # Reset used skills context.skills_manager.used_skills = [] diff --git a/pandasai/helpers/path.py b/pandasai/helpers/path.py index 18d771a7b..5527a55e7 100644 --- a/pandasai/helpers/path.py +++ b/pandasai/helpers/path.py @@ -1,8 +1,27 @@ import os +from pandasai.exceptions import InvalidWorkspacePathError + def find_project_root(filename=None): - # Get the path of the file that is being executed + """ + Check if Custom workspace path provide use that otherwise iterate to + find project root + """ + if "PANDASAI_WORKSPACE" in os.environ: + workspace_path = os.environ["PANDASAI_WORKSPACE"] + if ( + workspace_path + and os.path.exists(workspace_path) + and os.path.isdir(workspace_path) + ): + return workspace_path + raise InvalidWorkspacePathError( + "PANDASAI_WORKSPACE does not point to a valid directory" + ) + + # Get the path of the file that is be + # ing executed current_file_path = os.path.abspath(os.getcwd()) # Navigate back until we either find a $filename file or there is no parent @@ -26,7 +45,8 @@ def find_project_root(filename=None): parent_folder = os.path.dirname(root_folder) if parent_folder == root_folder: - raise ValueError("Could not find the root folder of the project.") + # if project root is not found return cwd + return os.getcwd() root_folder = parent_folder diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 7ceaf725e..2cd1f312d 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, validator, Field from typing import Optional, List, Any, Dict, Type, TypedDict from pandasai.constants import DEFAULT_CHART_DIRECTORY - from pandasai.responses import ResponseParser from ..middlewares.base import Middleware from ..callbacks.base import BaseCallback diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index f3dac336f..b480f1e4a 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -160,6 +160,7 @@ def initialize(self): charts_dir = os.path.join( (find_project_root()), self._config.save_charts_path ) + self._config.save_charts_path = charts_dir except ValueError: charts_dir = os.path.join( os.getcwd(), self._config.save_charts_path @@ -438,6 +439,7 @@ def chat(self, query: str, output_type: Optional[str] = None): ) break + except Exception as e: if ( not self._config.use_error_correction_framework diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 938418394..b2f84c68e 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -473,6 +473,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: "llm": llm, "enable_cache": False, "save_charts": True, + "save_charts_path": "charts", }, ) @@ -482,7 +483,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert plt_mock.savefig.called assert ( plt_mock.savefig.call_args.args[0] - == f"exports/charts/{smart_dataframe.last_prompt_id}.png" + == f"charts/{smart_dataframe.last_prompt_id}.png" ) def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware):