Skip to content

Commit

Permalink
feat: workspace env (#717)
Browse files Browse the repository at this point in the history
* fix(chart): charts to save to save_chart_path

* refactor sourcery changes

* 'Refactored by Sourcery'

* refactor chart save code

* fix: minor leftovers

* feat(workspace_env): add workspace env to store cache, temp chart and config

* add error handling and comments

---------

Co-authored-by: Sourcery AI <>
  • Loading branch information
ArslanSaleem authored Nov 1, 2023
1 parent 5537a7e commit 451a843
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 4 deletions.
38 changes: 38 additions & 0 deletions examples/using_workspace_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import pandas as pd
from pandasai import Agent

from pandasai.llm.openai import OpenAI
from pandasai.schemas.df_config import Config

employees_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Name": ["John", "Emma", "Liam", "Olivia", "William"],
"Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
}

salaries_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Salary": [5000, 6000, 4500, 7000, 5500],
}

employees_df = pd.DataFrame(employees_data)
salaries_df = pd.DataFrame(salaries_data)


os.environ["PANDASAI_WORKSPACE"] = "workspace dir path"


llm = OpenAI("YOUR_API_KEY")
config__ = {"llm": llm, "save_charts": False}


agent = Agent(
[employees_df, salaries_df],
config=Config(**config__),
memory_size=10,
)

# Chat with the agent
response = agent.chat("plot salary against department?")
print(response)
9 changes: 9 additions & 0 deletions pandasai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,12 @@ class AdvancedReasoningDisabledError(Exception):
Args:
Exception (Exception): AdvancedReasoningDisabledError
"""


class InvalidWorkspacePathError(Exception):
"""
Raised when the environment variable of workspace exist but path is invalid
Args:
Exception (Exception): InvalidWorkspacePathError
"""
9 changes: 9 additions & 0 deletions pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import astor
import pandas as pd
from pandasai.helpers.path import find_project_root

from pandasai.helpers.skills_manager import SkillsManager

Expand Down Expand Up @@ -235,6 +236,14 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
file_name=str(context.prompt_id),
save_charts_path_str=self._config.save_charts_path,
)
else:
# Temporarily save generated chart to display
code = add_save_chart(
code,
logger=self._logger,
file_name="temp_chart",
save_charts_path_str=find_project_root(),
)

# Reset used skills
context.skills_manager.used_skills = []
Expand Down
24 changes: 22 additions & 2 deletions pandasai/helpers/path.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
import os

from pandasai.exceptions import InvalidWorkspacePathError


def find_project_root(filename=None):
# Get the path of the file that is being executed
"""
Check if Custom workspace path provide use that otherwise iterate to
find project root
"""
if "PANDASAI_WORKSPACE" in os.environ:
workspace_path = os.environ["PANDASAI_WORKSPACE"]
if (
workspace_path
and os.path.exists(workspace_path)
and os.path.isdir(workspace_path)
):
return workspace_path
raise InvalidWorkspacePathError(
"PANDASAI_WORKSPACE does not point to a valid directory"
)

# Get the path of the file that is be
# ing executed
current_file_path = os.path.abspath(os.getcwd())

# Navigate back until we either find a $filename file or there is no parent
Expand All @@ -26,7 +45,8 @@ def find_project_root(filename=None):

parent_folder = os.path.dirname(root_folder)
if parent_folder == root_folder:
raise ValueError("Could not find the root folder of the project.")
# if project root is not found return cwd
return os.getcwd()

root_folder = parent_folder

Expand Down
1 change: 0 additions & 1 deletion pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pydantic import BaseModel, validator, Field
from typing import Optional, List, Any, Dict, Type, TypedDict
from pandasai.constants import DEFAULT_CHART_DIRECTORY

from pandasai.responses import ResponseParser
from ..middlewares.base import Middleware
from ..callbacks.base import BaseCallback
Expand Down
2 changes: 2 additions & 0 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def initialize(self):
charts_dir = os.path.join(
(find_project_root()), self._config.save_charts_path
)
self._config.save_charts_path = charts_dir
except ValueError:
charts_dir = os.path.join(
os.getcwd(), self._config.save_charts_path
Expand Down Expand Up @@ -438,6 +439,7 @@ def chat(self, query: str, output_type: Optional[str] = None):
)

break

except Exception as e:
if (
not self._config.use_error_correction_framework
Expand Down
3 changes: 2 additions & 1 deletion tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"llm": llm,
"enable_cache": False,
"save_charts": True,
"save_charts_path": "charts",
},
)

Expand All @@ -482,7 +483,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert plt_mock.savefig.called
assert (
plt_mock.savefig.call_args.args[0]
== f"exports/charts/{smart_dataframe.last_prompt_id}.png"
== f"charts/{smart_dataframe.last_prompt_id}.png"
)

def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware):
Expand Down

0 comments on commit 451a843

Please sign in to comment.