refactor: clean code in smart_dataframe and smart_datalake #814

Merged · 22 commits · Jan 11, 2024

Commits (the diff below shows changes from 10 of them):
433ad99  refactor: extract import from file method (gventuri, Nov 29, 2023)
ff945c8  refactor: extract df head methods (gventuri, Dec 2, 2023)
f9ccbba  refactor: move connector config in the relative connector file (gventuri, Dec 6, 2023)
03634bb  refactor: csv and pandas files are now treated as a connector (gventuri, Dec 8, 2023)
caebf2f  chore: remove verbose getters and setters (gventuri, Dec 8, 2023)
7b343fb  refactor: remove load and save feature (gventuri, Dec 9, 2023)
d899b66  refactor: create dataframe proxy (gventuri, Dec 11, 2023)
45097ae  chore: simplify agent (gventuri, Dec 11, 2023)
935f8ce  chore: simplify datalake (gventuri, Dec 12, 2023)
945b0fe  refactor: simplify smart datalake (gventuri, Dec 13, 2023)
8353a41  refactor: centralize context in lakes (gventuri, Dec 17, 2023)
e497d88  refactor: move lake callbacks to dedicate class (gventuri, Dec 17, 2023)
181689c  fix: load connector before generating cache hex (gventuri, Dec 18, 2023)
e4b2154  fix: only allow direct sql to SQLConnectors (gventuri, Dec 21, 2023)
480d7bb  fix: check sql connector was not working (gventuri, Jan 2, 2024)
38886ce  Merge branch 'main' into refactor/clean_code (gventuri, Jan 2, 2024)
fa74b98  fix(connector): update connector validation at the start (ArslanSaleem, Jan 3, 2024)
b6acc29  fix(direct_sql): fix some leftovers (ArslanSaleem, Jan 4, 2024)
ce1e848  Merge branch 'main' into refactor/clean_code (ArslanSaleem, Jan 4, 2024)
c60eb98  fix: merged change revert built-in shadowing (ArslanSaleem, Jan 4, 2024)
e043039  Merge branch 'release/v1.6' into refactor/clean_code (gventuri, Jan 11, 2024)
8b19c8d  Merge branch 'refactor/clean_code' of https://github.com/gventuri/pan… (gventuri, Jan 11, 2024)
90 changes: 43 additions & 47 deletions pandasai/agent/__init__.py
@@ -1,8 +1,7 @@
 import json
 from typing import Union, List, Optional
 
+import pandas as pd
 from pandasai.skills import skill
-from ..helpers.df_info import DataFrameType
 from ..helpers.logger import Logger
 from ..helpers.memory import Memory
 from ..prompts.base import AbstractPrompt
@@ -21,58 +20,55 @@ class Agent:
     Agent class to improve the conversational experience in PandasAI
     """
 
-    _lake: SmartDatalake = None
-    _logger: Optional[Logger] = None
-
     def __init__(
         self,
-        dfs: Union[DataFrameType, List[DataFrameType]],
+        dfs: Union[pd.DataFrame, List[pd.DataFrame]],
         config: Optional[Union[Config, dict]] = None,
         logger: Optional[Logger] = None,
         memory_size: int = 10,
     ):
         """
         Args:
-            df (Union[DataFrameType, List[DataFrameType]]): DataFrame can be Pandas,
-            Polars or Database connectors
+            df (Union[pd.DataFrame, List[pd.DataFrame]]): Pandas dataframe
             memory_size (int, optional): Conversation history to use during chat.
             Defaults to 1.
         """
 
         # Get a list of dataframes, if only one dataframe is passed
         if not isinstance(dfs, list):
             dfs = [dfs]
 
-        self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
-
-        # set instance type in SmartDataLake
-        self._lake.set_instance_type(self.__class__.__name__)
+        # Configure the smart datalake
+        self.lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
+        self.lake.set_instance_type(self.__class__.__name__)
 
-        self._logger = self._lake.logger
+        self.logger = self.lake.logger
 
     def add_skills(self, *skills: List[skill]):
         """
         Add Skills to PandasAI
         """
-        self._lake.add_skills(*skills)
+        self.lake.add_skills(*skills)
 
-    def _call_llm_with_prompt(self, prompt: AbstractPrompt):
+    def call_llm_with_prompt(self, prompt: AbstractPrompt):
         """
         Call LLM with prompt using error handling to retry based on config
         Args:
             prompt (AbstractPrompt): AbstractPrompt to pass to LLM's
         """
         retry_count = 0
-        while retry_count < self._lake.config.max_retries:
+        while retry_count < self.lake.config.max_retries:
             try:
-                result: str = self._lake.llm.call(prompt)
+                result: str = self.lake.llm.call(prompt)
                 if prompt.validate(result):
                     return result
                 else:
                     raise Exception("Response validation failed!")
             except Exception:
                 if (
-                    not self._lake.use_error_correction_framework
-                    or retry_count >= self._lake.config.max_retries - 1
+                    not self.lake.config.use_error_correction_framework
+                    or retry_count >= self.lake.config.max_retries - 1
                 ):
                     raise
                 retry_count += 1
Expand All @@ -83,8 +79,8 @@ def chat(self, query: str, output_type: Optional[str] = None):
"""
try:
is_related = self.check_if_related_to_conversation(query)
self._lake.is_related_query(is_related)
return self._lake.chat(query, output_type=output_type)
self.lake.is_related_query(is_related)
return self.lake.chat(query, output_type=output_type)
except Exception as exception:
return (
"Unfortunately, I was not able to get your answers, "
Expand All @@ -98,44 +94,44 @@ def add_message(self, message, is_user=False):
to the memory without calling the chat function (for example, when you
need to add a message from the agent).
"""
self._lake._memory.add(message, is_user=is_user)
self.lake.memory.add(message, is_user=is_user)

def check_if_related_to_conversation(self, query: str) -> bool:
"""
Check if the query is related to the previous conversation
"""
if self._lake._memory.count() == 0:
if self.lake.memory.count() == 0:
return

prompt = CheckIfRelevantToConversationPrompt(
conversation=self._lake._memory.get_conversation(),
conversation=self.lake.memory.get_conversation(),
query=query,
)

result = self._call_llm_with_prompt(prompt)
result = self.call_llm_with_prompt(prompt)

related = "true" in result
self._logger.log(
f"""Check if the new message is related to the conversation: {related}"""
is_related = "true" in result
self.logger.log(
f"""Check if the new message is related to the conversation: {is_related}"""
)

if not related:
self._lake.clear_memory()
if not is_related:
self.lake.clear_memory()

return related
return is_related

def clarification_questions(self, query: str) -> List[str]:
"""
Generate clarification questions based on the data
"""
prompt = ClarificationQuestionPrompt(
dataframes=self._lake.dfs,
conversation=self._lake._memory.get_conversation(),
dataframes=self.lake.dfs,
conversation=self.lake.memory.get_conversation(),
query=query,
)

result = self._call_llm_with_prompt(prompt)
self._logger.log(
result = self.call_llm_with_prompt(prompt)
self.logger.log(
f"""Clarification Questions: {result}
"""
)
Expand All @@ -146,19 +142,19 @@ def start_new_conversation(self):
"""
Clears the previous conversation
"""
self._lake.clear_memory()
self.lake.clear_memory()

def explain(self) -> str:
"""
Returns the explanation of the code how it reached to the solution
"""
try:
prompt = ExplainPrompt(
conversation=self._lake._memory.get_conversation(),
code=self._lake.last_code_executed,
conversation=self.lake.memory.get_conversation(),
code=self.lake.last_code_executed,
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
response = self.call_llm_with_prompt(prompt)
self.logger.log(
f"""Explanation: {response}
"""
)
Expand All @@ -174,11 +170,11 @@ def rephrase_query(self, query: str):
try:
prompt = RephraseQueryPrompt(
query=query,
dataframes=self._lake.dfs,
conversation=self._lake._memory.get_conversation(),
dataframes=self.lake.dfs,
conversation=self.lake.memory.get_conversation(),
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
response = self.call_llm_with_prompt(prompt)
self.logger.log(
f"""Rephrased Response: {response}
"""
)
Expand All @@ -192,16 +188,16 @@ def rephrase_query(self, query: str):

@property
def last_code_generated(self):
return self._lake.last_code_generated
return self.lake.last_code_generated

@property
def last_code_executed(self):
return self._lake.last_code_executed
return self.lake.last_code_executed

@property
def last_prompt(self):
return self._lake.last_prompt
return self.lake.last_prompt

@property
def last_query_log_id(self):
return self._lake.last_query_log_id
return self.lake.last_query_log_id
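
For reference, a minimal sketch of the Agent surface after this refactor: `lake`, `logger`, and `call_llm_with_prompt` are now public, and memory is reached as `agent.lake.memory` instead of `_lake._memory`. The sample dataframe and question are illustrative, and a working LLM config is assumed.

```python
import pandas as pd
from pandasai.agent import Agent

# Illustrative data; any pandas dataframe (or a list of them) works.
sales = pd.DataFrame({
    "country": ["US", "UK", "FR"],
    "revenue": [5000, 3200, 2900],
})

# A single dataframe is wrapped into a list by __init__.
agent = Agent(sales, memory_size=10)

# `lake` and `logger` are plain public attributes after this PR.
response = agent.chat("Which country has the highest revenue?")
print(response)

# Memory is reachable without the private `_memory` indirection:
agent.add_message("Noted, thanks!", is_user=True)
agent.start_new_conversation()
```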
4 changes: 4 additions & 0 deletions pandasai/connectors/__init__.py
@@ -11,6 +11,8 @@
 from .yahoo_finance import YahooFinanceConnector
 from .airtable import AirtableConnector
 from .sql import SqliteConnector
+from .pandas import PandasConnector
+from .polars import PolarsConnector
 
 __all__ = [
     "BaseConnector",
@@ -22,4 +24,6 @@
     "DatabricksConnector",
     "AirtableConnector",
     "SqliteConnector",
+    "PandasConnector",
+    "PolarsConnector",
 ]
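
With `PandasConnector` and `PolarsConnector` exported, raw dataframes go through the same connector interface as SQL sources (per the commit "csv and pandas files are now treated as a connector"). A sketch; the constructor shape below (`{"original_df": ...}`) is not shown in this diff and is an assumption:

```python
import pandas as pd
from pandasai.connectors import PandasConnector

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# Assumed config shape: a dict carrying the wrapped dataframe.
connector = PandasConnector({"original_df": df})

# Connectors expose head()/execute() like the SQL connectors in this PR.
print(connector.head())
```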
20 changes: 15 additions & 5 deletions pandasai/connectors/airtable.py
@@ -2,7 +2,7 @@
 Airtable connectors are used to connect airtable records.
 """
 
-from .base import AirtableConnectorConfig, BaseConnector, BaseConnectorConfig
+from .base import BaseConnector, BaseConnectorConfig
 from typing import Union, Optional
 import requests
 import pandas as pd
@@ -14,6 +14,16 @@
 from functools import cache, cached_property
 
 
+class AirtableConnectorConfig(BaseConnectorConfig):
+    """
+    Connecter configuration for Airtable data.
+    """
+
+    api_key: str
+    base_id: str
+    database: str = "airtable_data"
+
+
 class AirtableConnector(BaseConnector):
     """
     Airtable connector to retrieving record data.
@@ -136,12 +146,12 @@ def fallback_name(self):
         """
         return self._config.table
 
-    def execute(self):
+    def execute(self) -> pd.DataFrame:
         """
         Execute the connector and return the result.
 
         Returns:
-            DataFrameType: The result of the connector.
+            pd.DataFrame: The result of the connector.
         """
         if cached := self._cached() or self._cached(include_additional_filters=True):
             return pd.read_parquet(cached)
@@ -206,7 +216,7 @@ def _fetch_data(self):
         return pd.DataFrame(data)
 
     @cache
-    def head(self):
+    def head(self, n: int = 5) -> pd.DataFrame:
         """
         Return the head of the table that
         the connector is connected to.
@@ -215,7 +225,7 @@ def head(self):
             DatFrameType: The head of the data source
             that the connector is connected to .
         """
-        data = self._request_api(params={"maxRecords": 5})
+        data = self._request_api(params={"maxRecords": n})
         return pd.DataFrame(
             [
                 {"id": record["id"], **record["fields"]}

Review comment (Contributor), on the head docstring: The comment for the head method's return type should be updated to pd.DataFrame to reflect the actual return type.
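
A sketch of the relocated AirtableConnectorConfig in use. Credential values are placeholders; `table` comes from `BaseConnectorConfig` (the diff reads `self._config.table`), and passing a plain dict for the config is an assumption:

```python
from pandasai.connectors import AirtableConnector

connector = AirtableConnector(config={
    "api_key": "keyXXXXXXXXXXXXXX",  # placeholder Airtable API key
    "base_id": "appXXXXXXXXXXXXXX",  # placeholder base id
    "table": "tasks",                # inherited field; used as fallback_name
})

preview = connector.head(n=5)  # head() now takes an explicit row count
full = connector.execute()     # returns a pd.DataFrame per the new annotation
```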