From 0b838f31f223e9827ca821e8e5c509bfa5f9b3c9 Mon Sep 17 00:00:00 2001 From: Raoul Date: Mon, 20 Jan 2025 12:09:22 +0100 Subject: [PATCH 1/9] fix(VirtualDataframe): fixing virtual dataframe name conflict --- pandasai/dataframe/base.py | 29 +++++++++++++------------ pandasai/dataframe/virtual_dataframe.py | 15 ++++++++----- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py index 345225419..7099ea1f3 100644 --- a/pandasai/dataframe/base.py +++ b/pandasai/dataframe/base.py @@ -37,14 +37,14 @@ class DataFrame(pd.DataFrame): config (Config): Configuration settings """ - _metadata: ClassVar[list] = [ - "name", - "description", - "schema", - "path", - "config", + _metadata = [ "_agent", "_column_hash", + "config", + "description", + "name", + "path", + "schema", ] def __init__( @@ -56,21 +56,22 @@ def __init__( copy: bool | None = None, **kwargs, ) -> None: + _name: Optional[str] = kwargs.pop("name", None) + _schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None) + _description: Optional[str] = kwargs.pop("description", None) + _path: Optional[str] = kwargs.pop("path", None) + super().__init__( data=data, index=index, columns=columns, dtype=dtype, copy=copy ) - self.name: Optional[str] = kwargs.pop("name", None) self._column_hash = self._calculate_column_hash() - if not self.name: - self.name = f"table_{self._column_hash}" - - schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None) - self.schema = schema or DataFrame.get_default_schema(self) + self.name = _name or f"table_{self._column_hash}" + self.schema = _schema or DataFrame.get_default_schema(self) + self.description = _description + self.path = _path - self.description: Optional[str] = kwargs.pop("description", None) - self.path: Optional[str] = kwargs.pop("path", None) self.config = pai.config.get() self._agent: Optional[Agent] = None diff --git a/pandasai/dataframe/virtual_dataframe.py b/pandasai/dataframe/virtual_dataframe.py index 5baac91c4..a6c571b83 100644 --- a/pandasai/dataframe/virtual_dataframe.py +++ b/pandasai/dataframe/virtual_dataframe.py @@ -13,14 +13,17 @@ class VirtualDataFrame(DataFrame): - _metadata: ClassVar[list] = [ + _metadata = [ + "_agent", + "_column_hash", + "_head", "_loader", + "config", + "description", "head", - "_head", + "name", + "path", "schema", - "config", - "_agent", - "_column_hash", ] def __init__(self, *args, **kwargs): @@ -34,6 +37,7 @@ def __init__(self, *args, **kwargs): raise VirtualizationError("Schema is required for virtualization!") table_name = schema.source.table + description = schema.description super().__init__( @@ -47,7 +51,6 @@ def __init__(self, *args, **kwargs): def head(self): if self._head is None: self._head = self._loader.load_head() - return self._head @property From 65dc12ad5ec49e27c5d725dcad1802bb91160755 Mon Sep 17 00:00:00 2001 From: Raoul Date: Mon, 20 Jan 2025 15:44:29 +0100 Subject: [PATCH 2/9] feature(View): enabling view in SemanticLayerSchema --- docs/v3/semantic-layer.mdx | 76 ++++++++++++++--- pandasai/data_loader/semantic_layer_schema.py | 85 +++++++++++++++++-- .../dataframe/test_semantic_layer_schema.py | 79 +++++++++++++++++ 3 files changed, 225 insertions(+), 15 deletions(-) diff --git a/docs/v3/semantic-layer.mdx b/docs/v3/semantic-layer.mdx index b393a393c..e71efb8fd 100644 --- a/docs/v3/semantic-layer.mdx +++ b/docs/v3/semantic-layer.mdx @@ -238,15 +238,22 @@ columns: ``` **Type**: `list[dict]` -- Each dictionary represents a column -- `name` (str): Name of the column 
-- `type` (str): Data type of the column - - "string": IDs, names, categories - - "integer": counts, whole numbers - - "float": prices, percentages - - "datetime": timestamps, dates - - "boolean": flags, true/false values -- `description` (str): Clear explanation of what the column represents +- Each dictionary represents a column. +- **Fields**: + - `name` (str): Name of the column. + - For tables: Use simple column names (e.g., `transaction_id`). + - `type` (str): Data type of the column. + - Supported types: + - `"string"`: IDs, names, categories. + - `"integer"`: Counts, whole numbers. + - `"float"`: Prices, percentages. + - `"datetime"`: Timestamps, dates. + - `"boolean"`: Flags, true/false values. + - `description` (str): Clear explanation of what the column represents. + +**Constraints**: +1. Column names must be unique. +2. For views, all column names must be in the format `[table].[column]`. #### transformations Apply transformations to your data to clean, convert, or anonymize it. @@ -350,4 +357,53 @@ Specify the maximum number of records to load. **Type**: `int` ```yaml -limit: 1000 \ No newline at end of file +limit: 1000 +``` + +### View Configuration + +The following sections detail all available configurations for view options in your `schema.yaml` file. Similar to views in SQL, you can define multiple tables and the relationships between them. + +#### Example Configuration + +```yaml +name: table_heart +source: + type: postgres + connection: + host: localhost + port: 5432 + database: test + user: test + password: test + view: true +columns: +- name: parents.id +- name: parents.name +- name: parents.age +- name: children.name +- name: children.age +relations: +- name: parent_to_children + description: Relation linking the parent to its children + from: parents.id + to: children.id +``` + +--- + +#### Constraints + +1. **Mutual Exclusivity**: + - A schema cannot define both `table` and `view` simultaneously. + - If `source.view` is `true`, then the schema represents a view. + +2. **Column Format**: + - For views: + - All columns must follow the format `[table].[column]`. + - `from` and `to` fields in `relations` must follow the `[table].[column]` format. + - Example: `parents.id`, `children.name`. + +3. **Relationships for Views**: + - Each table referenced in `columns` must have at least one relationship defined in `relations`. + - Relationships must specify `from` and `to` attributes in the `[table].[column]` format. diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py index b5c922188..b9e8e4373 100644 --- a/pandasai/data_loader/semantic_layer_schema.py +++ b/pandasai/data_loader/semantic_layer_schema.py @@ -1,4 +1,6 @@ import json +import re +from functools import partial from typing import Any, Dict, List, Optional, Union import yaml @@ -32,6 +34,17 @@ def is_column_type_supported(cls, type: str) -> str: return type +class Relation(BaseModel): + name: Optional[str] = Field(None, description="Name of the relationship.") + description: Optional[str] = Field( + None, description="Description of the relationship." + ) + from_: str = Field( + ..., alias="from", description="Source column for the relationship." 
+ ) + to: str = Field(..., description="Target column for the relationship.") + + class Transformation(BaseModel): type: str = Field(..., description="Type of transformation to be applied.") params: Optional[Dict[str, str]] = Field( @@ -48,11 +61,12 @@ def is_transformation_type_supported(cls, type: str) -> str: class Source(BaseModel): type: str = Field(..., description="Type of the data source.") + path: Optional[str] = Field(None, description="Path of the local data source.") connection: Optional[Dict[str, Union[str, int]]] = Field( None, description="Connection object of the data source." ) - path: Optional[str] = Field(None, description="Path of the local data source.") table: Optional[str] = Field(None, description="Table of the data source.") + view: Optional[bool] = Field(False, description="Whether table is a view") @model_validator(mode="before") @classmethod @@ -60,6 +74,7 @@ def validate_type_and_fields(cls, values): _type = values.get("type") path = values.get("path") table = values.get("table") + view = values.get("view") connection = values.get("connection") if _type in LOCAL_SOURCE_TYPES: @@ -67,15 +82,17 @@ def validate_type_and_fields(cls, values): raise ValueError( f"For local source type '{_type}', 'path' must be defined." ) + if view: + raise ValueError("For local source type you can't use a view.") elif _type in REMOTE_SOURCE_TYPES: if not connection: raise ValueError( f"For remote source type '{_type}', 'connection' must be defined." ) - if not table: - raise ValueError( - f"For remote source type '{_type}', 'table' must be defined." - ) + if table and view: + raise ValueError("Only one of 'table' or 'view' can be defined.") + if not table and not view: + raise ValueError("Either 'table' or 'view' must be defined.") else: raise ValueError(f"Unsupported source type: {_type}") @@ -104,6 +121,9 @@ class SemanticLayerSchema(BaseModel): columns: Optional[List[Column]] = Field( None, description="Structure and metadata of your dataset’s columns" ) + relations: Optional[List[Relation]] = Field( + None, description="Relationships between columns and tables." + ) order_by: Optional[List[str]] = Field( None, description="Ordering criteria for the dataset." ) @@ -120,6 +140,61 @@ class SemanticLayerSchema(BaseModel): None, description="Frequency of dataset updates." ) + @model_validator(mode="after") + def check_columns_relations(self): + column_re_check = r"^[a-zA-Z_]+\.[a-zA-Z_]+$" + is_view_column_name = partial(re.match, column_re_check) + + # unpack columns info + _columns = self.columns + _column_names = [col.name for col in _columns or ()] + _tables_names_in_columns = { + column_name.split(".")[0] for column_name in _column_names or () + } + + if len(_column_names) != len(set(_column_names)): + raise ValueError("Column names must be unique. Duplicate names found.") + + if self.source.view: + # unpack relations info + _relations = self.relations + _column_names_in_relations = { + table + for relation in _relations or () + for table in (relation.from_, relation.to) + } + _tables_names_in_relations = { + column_name.split(".")[0] + for column_name in _column_names_in_relations or () + } + + if not all( + is_view_column_name(column_name) for column_name in _column_names + ): + raise ValueError( + "All columns in a view must be in the format '[table].[column]'." + ) + + if not all( + is_view_column_name(column_name) + for column_name in _column_names_in_relations + ): + raise ValueError( + "All params 'from' and 'to' in the relations must be in the format '[table].[column]'." 
+ ) + + if ( + uncovered_tables := _tables_names_in_columns + - _tables_names_in_relations + ): + raise ValueError( + f"No relations provided for the following tables {uncovered_tables}." + ) + + elif any(is_view_column_name(column_name) for column_name in _column_names): + raise ValueError("All columns in a table must be in the format '[column]'.") + return self + def to_dict(self) -> dict[str, Any]: return self.model_dump(exclude_none=True) diff --git a/tests/unit_tests/dataframe/test_semantic_layer_schema.py b/tests/unit_tests/dataframe/test_semantic_layer_schema.py index 87d053245..0d7233c29 100644 --- a/tests/unit_tests/dataframe/test_semantic_layer_schema.py +++ b/tests/unit_tests/dataframe/test_semantic_layer_schema.py @@ -91,6 +91,29 @@ def mysql_schema(self): }, } + @pytest.fixture + def mysql_view_schema(self): + return { + "name": "Users", + "columns": [ + {"name": "parents.id"}, + {"name": "parents.name"}, + {"name": "children.name"}, + ], + "relations": [{"from": "parents.id", "to": "children.id"}], + "source": { + "type": "mysql", + "connection": { + "host": "localhost", + "port": 3306, + "database": "test_db", + "user": "test_user", + "password": "test_password", + }, + "view": "true", + }, + } + def test_valid_schema(self, sample_schema): schema = SemanticLayerSchema(**sample_schema) @@ -113,6 +136,14 @@ def test_valid_mysql_schema(self, mysql_schema): assert len(schema.transformations) == 2 assert schema.source.type == "mysql" + def test_valid_mysql_view_schema(self, mysql_view_schema): + schema = SemanticLayerSchema(**mysql_view_schema) + + assert schema.name == "Users" + assert len(schema.columns) == 3 + assert schema.source.view == True + assert schema.source.type == "mysql" + def test_missing_source_path(self, sample_schema): sample_schema["source"].pop("path") @@ -203,3 +234,51 @@ def test_is_schema_source_same_false(self, mysql_schema, sample_schema): schema2 = SemanticLayerSchema(**sample_schema) assert is_schema_source_same(schema1, schema2) is False + + def test_invalid_source_view_for_local_type(self, sample_schema): + sample_schema["source"]["view"] = True + + with pytest.raises(ValidationError): + SemanticLayerSchema(**sample_schema) + + def test_invalid_source_view_and_table(self, mysql_schema): + mysql_schema["source"]["view"] = True + + with pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_schema) + + def test_invalid_source_missing_view_or_table(self, mysql_schema): + mysql_schema["source"].pop("table") + + with pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_schema) + + def test_invalid_duplicated_columns(self, sample_schema): + sample_schema["columns"].append(sample_schema["columns"][0]) + + with pytest.raises(ValidationError): + SemanticLayerSchema(**sample_schema) + + def test_invalid_wrong_column_format_in_view(self, mysql_view_schema): + mysql_view_schema["columns"][0]["name"] = "parentsid" + + with pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_view_schema) + + def test_invalid_wrong_column_format(self, sample_schema): + sample_schema["columns"][0]["name"] = "parents.id" + + with pytest.raises(ValidationError): + SemanticLayerSchema(**sample_schema) + + def test_invalid_wrong_relation_format_in_view(self, mysql_view_schema): + mysql_view_schema["relations"][0]["to"] = "parentsid" + + with pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_view_schema) + + def test_invalid_uncovered_columns_in_view(self, mysql_view_schema): + mysql_view_schema.pop("relations") + + with 
pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_view_schema) From 57b890038d6291998c8ba72b63db70de8c30a613 Mon Sep 17 00:00:00 2001 From: Raoul Scalise <36519284+scaliseraoul@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:57:17 +0100 Subject: [PATCH 3/9] Update pandasai/data_loader/semantic_layer_schema.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- pandasai/data_loader/semantic_layer_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py index b9e8e4373..65baf81ed 100644 --- a/pandasai/data_loader/semantic_layer_schema.py +++ b/pandasai/data_loader/semantic_layer_schema.py @@ -83,7 +83,7 @@ def validate_type_and_fields(cls, values): f"For local source type '{_type}', 'path' must be defined." ) if view: - raise ValueError("For local source type you can't use a view.") + raise ValueError("A view cannot be used with a local source type.") elif _type in REMOTE_SOURCE_TYPES: if not connection: raise ValueError( From 3575417208d5108ee0239a9a44bfc32aa4d40fdf Mon Sep 17 00:00:00 2001 From: Raoul Date: Mon, 20 Jan 2025 17:51:12 +0100 Subject: [PATCH 4/9] feature(View): enabling view in loader --- pandasai/data_loader/loader.py | 22 ++++++---- pandasai/data_loader/query_builder.py | 26 ++++++------ pandasai/data_loader/view_query_builder.py | 49 ++++++++++++++++++++++ pandasai/dataframe/virtual_dataframe.py | 2 +- 4 files changed, 77 insertions(+), 22 deletions(-) create mode 100644 pandasai/data_loader/view_query_builder.py diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py index b4c94333f..f3eedc4ed 100644 --- a/pandasai/data_loader/loader.py +++ b/pandasai/data_loader/loader.py @@ -19,12 +19,14 @@ ) from .query_builder import QueryBuilder from .semantic_layer_schema import SemanticLayerSchema +from .view_query_builder import ViewQueryBuilder class DatasetLoader: def __init__(self): self.schema: Optional[SemanticLayerSchema] = None - self.dataset_path = None + self.query_builder: Optional[QueryBuilder] = None + self.dataset_path: Optional[str] = None def load(self, dataset_path: str) -> DataFrame: """Load data based on the provided dataset path. 
@@ -38,6 +40,11 @@ def load(self, dataset_path: str) -> DataFrame: self.dataset_path = dataset_path self._load_schema() + if self.schema.source.view: + self.query_builder = ViewQueryBuilder(self.schema) + else: + self.query_builder = QueryBuilder(self.schema) + source_type = self.schema.source.type if source_type in LOCAL_SOURCE_TYPES: df = self._load_from_local_source() @@ -139,13 +146,11 @@ def _load_from_local_source(self) -> pd.DataFrame: return self._read_csv_or_parquet(filepath, source_type) def load_head(self) -> pd.DataFrame: - query_builder = QueryBuilder(self.schema) - query = query_builder.get_head_query() + query = self.query_builder.get_head_query() return self.execute_query(query) def get_row_count(self) -> int: - query_builder = QueryBuilder(self.schema) - query = query_builder.get_row_count() + query = self.query_builder.get_row_count() result = self.execute_query(query) return result.iloc[0, 0] @@ -154,16 +159,18 @@ def execute_query(self, query: str) -> pd.DataFrame: source_type = source.type connection_info = source.connection + formatted_query = self.query_builder.format_query(query) + if not source_type: raise ValueError("Source type is missing in the schema.") load_function = self._get_loader_function(source_type) try: - return load_function(connection_info, query) + return load_function(connection_info, formatted_query) except Exception as e: raise RuntimeError( - f"Failed to execute query for source type '{source_type}' with query: {query}" + f"Failed to execute query for source type '{source_type}' with query: {formatted_query}" ) from e def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame: @@ -199,5 +206,6 @@ def copy(self) -> "DatasetLoader": """ new_loader = DatasetLoader() new_loader.schema = copy.deepcopy(self.schema) + new_loader.query_builder = copy.deepcopy(self.query_builder) new_loader.dataset_path = self.dataset_path return new_loader diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py index 7ab29f642..de8635532 100644 --- a/pandasai/data_loader/query_builder.py +++ b/pandasai/data_loader/query_builder.py @@ -1,17 +1,19 @@ from typing import Any, Dict, List, Union -from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema +from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema class QueryBuilder: def __init__(self, schema: SemanticLayerSchema): self.schema = schema + def format_query(self, query): + return query + def build_query(self) -> str: columns = self._get_columns() - table_name = self._get_table_name() - query = f"SELECT {columns} FROM {table_name}" - + query = f"SELECT {columns}" + query += self._get_from_statement() query += self._add_order_by() query += self._add_limit() @@ -23,10 +25,8 @@ def _get_columns(self) -> str: else: return "*" - def _get_table_name(self): - table_name = self.schema.source.table - table_name = table_name.lower() - return table_name + def _get_from_statement(self): + return f" FROM {self.schema.source.table.lower()}" def _add_order_by(self) -> str: if not self.schema.order_by: @@ -46,13 +46,11 @@ def _add_limit(self, n=None) -> str: def get_head_query(self, n=5): source_type = self.schema.source.type - table_name = self._get_table_name() columns = self._get_columns() - + query = f"SELECT {columns}" + query += self._get_from_statement() order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()" - - return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}" + return f"{query} ORDER BY 
{order_by} LIMIT {n}" def get_row_count(self): - table_name = self._get_table_name() - return f"SELECT COUNT(*) FROM {table_name}" + return f"SELECT COUNT(*) {self._get_from_statement()}" diff --git a/pandasai/data_loader/view_query_builder.py b/pandasai/data_loader/view_query_builder.py new file mode 100644 index 000000000..f97a7a815 --- /dev/null +++ b/pandasai/data_loader/view_query_builder.py @@ -0,0 +1,49 @@ +from typing import Any, Dict, List, Union + +from pandasai.data_loader.query_builder import QueryBuilder +from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema + + +class ViewQueryBuilder(QueryBuilder): + def __init__(self, schema: SemanticLayerSchema): + super().__init__(schema) + + def format_query(self, query): + return f"{self._get_with_statement()} {query}" + + def build_query(self) -> str: + columns = self._get_columns() + query = self._get_with_statement() + query += f"SELECT {columns}" + query += self._get_from_statement() + query += self._add_order_by() + query += self._add_limit() + return query + + def _get_columns(self) -> str: + if self.schema.columns: + return ", ".join( + [f"{col.name.replace('.', '_')}" for col in self.schema.columns] + ) + else: + return super()._get_columns() + + def _get_from_statement(self): + return f" FROM {self.schema.name}" + + def _get_with_statement(self): + relations = self.schema.relations + first_table = relations[0].from_.split(".")[0] + query = f" WITH {self.schema.name} AS ( SELECT \n" + query += ", ".join( + [ + f"{col.name} AS {col.name.replace('.', '_')}" + for col in self.schema.columns + ] + ) + query += f"\n FROM {first_table}" + for relation in relations: + to_table = relation.to.split(".")[0] + query += f"\n JOIN {to_table} ON {relation.from_} = {relation.to}" + query += ")\n" + return query diff --git a/pandasai/dataframe/virtual_dataframe.py b/pandasai/dataframe/virtual_dataframe.py index a6c571b83..8e15d2440 100644 --- a/pandasai/dataframe/virtual_dataframe.py +++ b/pandasai/dataframe/virtual_dataframe.py @@ -36,7 +36,7 @@ def __init__(self, *args, **kwargs): if not schema: raise VirtualizationError("Schema is required for virtualization!") - table_name = schema.source.table + table_name = schema.source.table or schema.name description = schema.description From db0acf070f3a76fffa34602006670e712873857f Mon Sep 17 00:00:00 2001 From: Raoul Date: Tue, 21 Jan 2025 14:43:24 +0100 Subject: [PATCH 5/9] feature(View): adding integration tests --- pandasai/data_loader/semantic_layer_schema.py | 3 + pandasai/data_loader/view_query_builder.py | 25 +++-- .../dataframe/test_semantic_layer_schema.py | 8 +- .../dataframe/test_view_query_builder.py | 97 +++++++++++++++++++ 4 files changed, 122 insertions(+), 11 deletions(-) create mode 100644 tests/unit_tests/dataframe/test_view_query_builder.py diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py index 263f75910..36576ee31 100644 --- a/pandasai/data_loader/semantic_layer_schema.py +++ b/pandasai/data_loader/semantic_layer_schema.py @@ -168,6 +168,9 @@ def check_columns_relations(self): for column_name in _column_names_in_relations or () } + if not self.relations: + raise ValueError("At least one relation must be defined for view.") + if not all( is_view_column_name(column_name) for column_name in _column_names ): diff --git a/pandasai/data_loader/view_query_builder.py b/pandasai/data_loader/view_query_builder.py index f97a7a815..4c1a606df 100644 --- a/pandasai/data_loader/view_query_builder.py +++ 
b/pandasai/data_loader/view_query_builder.py @@ -9,7 +9,7 @@ def __init__(self, schema: SemanticLayerSchema): super().__init__(schema) def format_query(self, query): - return f"{self._get_with_statement()} {query}" + return f"{self._get_with_statement()}{query}" def build_query(self) -> str: columns = self._get_columns() @@ -34,16 +34,21 @@ def _get_from_statement(self): def _get_with_statement(self): relations = self.schema.relations first_table = relations[0].from_.split(".")[0] - query = f" WITH {self.schema.name} AS ( SELECT \n" - query += ", ".join( - [ - f"{col.name} AS {col.name.replace('.', '_')}" - for col in self.schema.columns - ] - ) - query += f"\n FROM {first_table}" + query = f"WITH {self.schema.name} AS ( SELECT\n" + + if self.schema.columns: + query += ", ".join( + [ + f"{col.name} AS {col.name.replace('.', '_')}" + for col in self.schema.columns + ] + ) + else: + query += "*" + + query += f"\nFROM {first_table}" for relation in relations: to_table = relation.to.split(".")[0] - query += f"\n JOIN {to_table} ON {relation.from_} = {relation.to}" + query += f"\nJOIN {to_table} ON {relation.from_} = {relation.to}" query += ")\n" return query diff --git a/tests/unit_tests/dataframe/test_semantic_layer_schema.py b/tests/unit_tests/dataframe/test_semantic_layer_schema.py index 0d7233c29..ad93fda9b 100644 --- a/tests/unit_tests/dataframe/test_semantic_layer_schema.py +++ b/tests/unit_tests/dataframe/test_semantic_layer_schema.py @@ -253,6 +253,12 @@ def test_invalid_source_missing_view_or_table(self, mysql_schema): with pytest.raises(ValidationError): SemanticLayerSchema(**mysql_schema) + def test_invalid_no_relation_for_view(self, mysql_view_schema): + mysql_view_schema.pop("relations") + + with pytest.raises(ValidationError): + SemanticLayerSchema(**mysql_view_schema) + def test_invalid_duplicated_columns(self, sample_schema): sample_schema["columns"].append(sample_schema["columns"][0]) @@ -278,7 +284,7 @@ def test_invalid_wrong_relation_format_in_view(self, mysql_view_schema): SemanticLayerSchema(**mysql_view_schema) def test_invalid_uncovered_columns_in_view(self, mysql_view_schema): - mysql_view_schema.pop("relations") + mysql_view_schema["relations"][0]["to"] = "parents.id" with pytest.raises(ValidationError): SemanticLayerSchema(**mysql_view_schema) diff --git a/tests/unit_tests/dataframe/test_view_query_builder.py b/tests/unit_tests/dataframe/test_view_query_builder.py new file mode 100644 index 000000000..606cfd6aa --- /dev/null +++ b/tests/unit_tests/dataframe/test_view_query_builder.py @@ -0,0 +1,97 @@ +import pytest + +from pandasai.data_loader.query_builder import QueryBuilder +from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema +from pandasai.data_loader.view_query_builder import ViewQueryBuilder + + +class TestViewQueryBuilder: + @pytest.fixture + def mysql_view_schema(self): + raw_schema = { + "name": "Users", + "columns": [ + {"name": "parents.id"}, + {"name": "parents.name"}, + {"name": "children.name"}, + ], + "relations": [{"from": "parents.id", "to": "children.id"}], + "source": { + "type": "mysql", + "connection": { + "host": "localhost", + "port": 3306, + "database": "test_db", + "user": "test_user", + "password": "test_password", + }, + "view": "true", + }, + } + return SemanticLayerSchema(**raw_schema) + + @pytest.fixture + def view_query_builder(self, mysql_view_schema): + return ViewQueryBuilder(mysql_view_schema) + + def test__init__(self, mysql_view_schema): + query_builder = ViewQueryBuilder(mysql_view_schema) + assert 
isinstance(query_builder, ViewQueryBuilder) + assert isinstance(query_builder, QueryBuilder) + assert query_builder.schema == mysql_view_schema + + def test_format_query(self, view_query_builder): + query = "SELECT ALL" + formatted_query = view_query_builder.format_query(query) + assert ( + formatted_query + == """WITH Users AS ( SELECT +parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name +FROM parents +JOIN children ON parents.id = children.id) +SELECT ALL""" + ) + + def test_build_query(self, view_query_builder) -> str: + assert ( + view_query_builder.build_query() + == """WITH Users AS ( SELECT +parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name +FROM parents +JOIN children ON parents.id = children.id) +SELECT parents_id, parents_name, children_name FROM Users""" + ) + + def test_get_columns(self, view_query_builder): + assert ( + view_query_builder._get_columns() + == """parents_id, parents_name, children_name""" + ) + + def test_get_columns_empty(self, view_query_builder): + view_query_builder.schema.columns = None + assert view_query_builder._get_columns() == "*" + + def test_get_from_statement(self, view_query_builder): + assert view_query_builder._get_from_statement() == " FROM Users" + + def test_get_with_statement(self, view_query_builder): + assert ( + view_query_builder._get_with_statement() + == """WITH Users AS ( SELECT +parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name +FROM parents +JOIN children ON parents.id = children.id) +""" + ) + + def test_get_with_statement_no_columns(self, view_query_builder): + view_query_builder.schema.columns = None + assert ( + view_query_builder._get_with_statement() + == """WITH Users AS ( SELECT +* +FROM parents +JOIN children ON parents.id = children.id) +""" + ) From 4fcac4528877c8ce6db7a32489cbc71a82b7d109 Mon Sep 17 00:00:00 2001 From: Raoul Date: Tue, 21 Jan 2025 18:08:25 +0100 Subject: [PATCH 6/9] feature(ChatTest): adding first set of numeric questions --- tests/unit_tests/agent/test_agent_chat.py | 117 ++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/unit_tests/agent/test_agent_chat.py diff --git a/tests/unit_tests/agent/test_agent_chat.py b/tests/unit_tests/agent/test_agent_chat.py new file mode 100644 index 000000000..0b44c930a --- /dev/null +++ b/tests/unit_tests/agent/test_agent_chat.py @@ -0,0 +1,117 @@ +import os +from typing import List, Tuple + +import pytest + +import pandasai as pai +from pandasai import DataFrame +from pandasai.core.response import NumberResponse + +# Read the API key from an environment variable +API_KEY = os.getenv("PANDASAI_API_KEY_TEST_CHAT", None) + + +class TestAgentChat: + @pytest.fixture + def pandas_ai(self): + pai.api_key.set(API_KEY) + return pai + + @pytest.mark.parametrize( + "question,expected", + [ + ("What is the total quantity sold across all products and regions?", 105), + ("What is the correlation coefficient between Sales and Profit?", 1.0), + ( + "What is the standard deviation of daily sales for the entire dataset?", + 231.0, + ), + ( + "Give me the number of the highest average profit margin among all regions?", + 0.2, + ), + ( + "What is the difference in total Sales between Product A and Product B across the entire dataset?", + 700, + ), + ("Over the entire dataset, how many days had sales above 900?", 5), + ( + "What was the year-over-year growth in total sales from 2022 to 2023 (in percent)?", + 7.84, + ), + ], + ) + @pytest.mark.integration 
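+    # Integration test against the live PandasAI API; skipped unless the
+    # PANDASAI_API_KEY_TEST_CHAT environment variable provides a key.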
+ @pytest.mark.skipif( + API_KEY is None, reason="API key not set, skipping integration tests" + ) + def test_integration_multiple_numeric_questions( + self, question, expected, pandas_ai + ): + """ + A single integration test that checks 10 numeric questions on a DataFrame + aligned with real-world data analysis scenarios. + """ + + # Sample DataFrame spanning two years (2022-2023), multiple regions and products + df = DataFrame( + { + "Date": [ + "2022-01-01", + "2022-01-02", + "2022-01-03", + "2022-02-01", + "2022-02-02", + "2022-02-03", + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-02-01", + "2023-02-02", + "2023-02-03", + ], + "Region": [ + "North", + "North", + "South", + "South", + "East", + "East", + "North", + "North", + "South", + "South", + "East", + "East", + ], + "Product": ["A", "B", "A", "B", "A", "B", "A", "B", "A", "B", "A", "B"], + "Sales": [ + 1000, + 800, + 1200, + 900, + 500, + 700, + 1100, + 850, + 1250, + 950, + 600, + 750, + ], + "Profit": [200, 160, 240, 180, 100, 140, 220, 170, 250, 190, 120, 150], + "Quantity": [10, 8, 12, 9, 5, 7, 11, 8, 13, 9, 6, 7], + } + ) + + response = pandas_ai.chat(question, df) + + assert isinstance( + response, NumberResponse + ), f"Expected a NumberResponse, got {type(response)} for question: {question}" + + model_value = float(response.value) + + assert model_value == pytest.approx(expected, abs=0.5), ( + f"Question: {question}\n" f"Expected: {expected}, Got: {model_value}" + ) From 1b20e4cb752675a8798eb203d4828ad576b99eac Mon Sep 17 00:00:00 2001 From: Raoul Date: Wed, 22 Jan 2025 16:07:50 +0100 Subject: [PATCH 7/9] feature(ChatTest): adding set of questions for loans and heart csv --- tests/unit_tests/agent/test_agent_chat.py | 189 ++++++++++++++++++---- 1 file changed, 155 insertions(+), 34 deletions(-) diff --git a/tests/unit_tests/agent/test_agent_chat.py b/tests/unit_tests/agent/test_agent_chat.py index 0b44c930a..85f880ed1 100644 --- a/tests/unit_tests/agent/test_agent_chat.py +++ b/tests/unit_tests/agent/test_agent_chat.py @@ -1,56 +1,134 @@ import os +from pathlib import Path +from types import UnionType from typing import List, Tuple import pytest import pandasai as pai from pandasai import DataFrame -from pandasai.core.response import NumberResponse +from pandasai.core.response import ( + ChartResponse, + DataFrameResponse, + NumberResponse, + StringResponse, +) # Read the API key from an environment variable API_KEY = os.getenv("PANDASAI_API_KEY_TEST_CHAT", None) +@pytest.mark.skipif( + API_KEY is None, reason="API key not set, skipping integration tests" +) class TestAgentChat: + numeric_questions_with_answer = [ + ("What is the total quantity sold across all products and regions?", 105), + ("What is the correlation coefficient between Sales and Profit?", 1.0), + ( + "What is the standard deviation of daily sales for the entire dataset?", + 231.0, + ), + ( + "Give me the number of the highest average profit margin among all regions?", + 0.2, + ), + ( + "What is the difference in total Sales between Product A and Product B across the entire dataset?", + 700, + ), + ("Over the entire dataset, how many days had sales above 900?", 5), + ( + "What was the year-over-year growth in total sales from 2022 to 2023 (in percent)?", + 7.84, + ), + ] + loans_questions_with_type: List[Tuple[str, type | UnionType]] = [ + ("What is the total number of payments?", NumberResponse), + ("What is the average payment amount?", NumberResponse), + ("How many unique loan IDs are there?", NumberResponse), + ("What is the most common 
payment amount?", NumberResponse), + ("What is the total amount of payments?", NumberResponse), + ("What is the median payment amount?", NumberResponse), + ("How many payments are above $1000?", NumberResponse), + ( + "What is the minimum and maximum payment?", + (NumberResponse, DataFrameResponse), + ), + ("Show me a monthly trend of payments", (ChartResponse, DataFrameResponse)), + ( + "Show me the distribution of payment amounts", + (ChartResponse, DataFrameResponse), + ), + ("Show me the top 10 payment amounts", DataFrameResponse), + ( + "Give me a summary of payment statistics", + (StringResponse, DataFrameResponse), + ), + ("Show me payments above $1000", DataFrameResponse), + ] + heart_strokes_questions_with_type: List[Tuple[str, type | UnionType]] = [ + ("What is the total number of patients in the dataset?", NumberResponse), + ("How many people had a stroke?", NumberResponse), + ("What is the average age of patients?", NumberResponse), + ("What percentage of patients have hypertension?", NumberResponse), + ("What is the average BMI?", NumberResponse), + ("How many smokers are in the dataset?", NumberResponse), + ("What is the gender distribution?", (ChartResponse, DataFrameResponse)), + ( + "Is there a correlation between age and stroke occurrence?", + (ChartResponse, StringResponse), + ), + ( + "Show me the age distribution of patients", + (ChartResponse, DataFrameResponse), + ), + ("What is the most common work type?", StringResponse), + ( + "Give me a breakdown of stroke occurrences", + (StringResponse, DataFrameResponse), + ), + ("Show me hypertension statistics", (StringResponse, DataFrameResponse)), + ("Give me smoking statistics summary", (StringResponse, DataFrameResponse)), + ("Show me the distribution of work types", (ChartResponse, DataFrameResponse)), + ] + combined_questions_with_type: List[Tuple[str, type | UnionType]] = [ + ( + "Compare payment patterns between age groups", + (ChartResponse, DataFrameResponse), + ), + ( + "Show relationship between payments and health conditions", + (ChartResponse, DataFrameResponse), + ), + ( + "Analyze payment differences between hypertension groups", + (StringResponse, DataFrameResponse), + ), + ( + "Calculate average payments by health condition", + (NumberResponse, DataFrameResponse), + ), + ( + "Show payment distribution across age groups", + (ChartResponse, DataFrameResponse), + ), + ] + + root_dir = Path(__file__).resolve().parents[3] + + hear_stroke_path = root_dir / "examples" / "data" / "heart.csv" + loans_path = root_dir / "examples" / "data" / "loans_payments.csv" + @pytest.fixture def pandas_ai(self): pai.api_key.set(API_KEY) return pai - @pytest.mark.parametrize( - "question,expected", - [ - ("What is the total quantity sold across all products and regions?", 105), - ("What is the correlation coefficient between Sales and Profit?", 1.0), - ( - "What is the standard deviation of daily sales for the entire dataset?", - 231.0, - ), - ( - "Give me the number of the highest average profit margin among all regions?", - 0.2, - ), - ( - "What is the difference in total Sales between Product A and Product B across the entire dataset?", - 700, - ), - ("Over the entire dataset, how many days had sales above 900?", 5), - ( - "What was the year-over-year growth in total sales from 2022 to 2023 (in percent)?", - 7.84, - ), - ], - ) - @pytest.mark.integration - @pytest.mark.skipif( - API_KEY is None, reason="API key not set, skipping integration tests" - ) - def test_integration_multiple_numeric_questions( - self, question, 
expected, pandas_ai
-    ):
+    @pytest.mark.parametrize("question,expected", numeric_questions_with_answer)
+    def test_numeric_questions(self, question, expected, pandas_ai):
         """
-        A single integration test that checks 10 numeric questions on a DataFrame
-        aligned with real-world data analysis scenarios.
+        Test numeric questions to ensure the responses match the expected ones.
         """
 
         # Sample DataFrame spanning two years (2022-2023), multiple regions and products
@@ -115,3 +193,46 @@ def test_integration_multiple_numeric_questions(
         assert model_value == pytest.approx(expected, abs=0.5), (
             f"Question: {question}\n" f"Expected: {expected}, Got: {model_value}"
         )
+
+    @pytest.mark.parametrize("question,expected", loans_questions_with_type)
+    def test_loans_questions_type(self, question, expected, pandas_ai):
+        """
+        Test loan-related questions to ensure the response types match the expected ones.
+        """
+
+        df = pandas_ai.read_csv(str(self.loans_path))
+
+        response = pandas_ai.chat(question, df)
+
+        assert isinstance(
+            response, expected
+        ), f"Expected type {expected}, got {type(response)} for question: {question}"
+
+    @pytest.mark.parametrize("question,expected", heart_strokes_questions_with_type)
+    def test_heart_strokes_questions_type(self, question, expected, pandas_ai):
+        """
+        Test heart stroke related questions to ensure the response types match the expected ones.
+        """
+
+        df = pandas_ai.read_csv(str(self.hear_stroke_path))
+
+        response = pandas_ai.chat(question, df)
+
+        assert isinstance(
+            response, expected
+        ), f"Expected type {expected}, got {type(response)} for question: {question}"
+
+    @pytest.mark.parametrize("question,expected", combined_questions_with_type)
+    def test_combined_questions_with_type(self, question, expected, pandas_ai):
+        """
+        Test combined questions across the heart stroke and loans datasets to ensure the response types match the expected ones.
+        """
+
+        df1 = pandas_ai.read_csv(str(self.hear_stroke_path))
+        loans = pandas_ai.read_csv(str(self.loans_path))
+
+        response = pandas_ai.chat(question, *(df1, loans))
+
+        assert isinstance(
+            response, expected
+        ), f"Expected type {expected}, got {type(response)} for question: {question}"

From dc631318ca936f4d5acd45157d5ce38e7cdb195e Mon Sep 17 00:00:00 2001
From: Raoul Scalise <36519284+scaliseraoul@users.noreply.github.com>
Date: Wed, 22 Jan 2025 16:18:40 +0100
Subject: [PATCH 8/9] Update pandasai/data_loader/loader.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
 pandasai/data_loader/loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index f3eedc4ed..f1b3815c1 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -170,7 +170,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
             return load_function(connection_info, formatted_query)
         except Exception as e:
             raise RuntimeError(
-                f"Failed to execute query for source type '{source_type}' with query: {formatted_query}"
+                f"Failed to execute query for '{source_type}' with: {formatted_query}"
             ) from e
 
     def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:

From c7665646c068b6b21e4db3e374a8bf2c3c416d41 Mon Sep 17 00:00:00 2001
From: Raoul
Date: Wed, 22 Jan 2025 16:19:44 +0100
Subject: [PATCH 9/9] fix(ChatTest): typo in heart_stroke_path

---
 tests/unit_tests/agent/test_agent_chat.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit_tests/agent/test_agent_chat.py b/tests/unit_tests/agent/test_agent_chat.py
index 85f880ed1..1d320f4df 100644
--- a/tests/unit_tests/agent/test_agent_chat.py
+++ b/tests/unit_tests/agent/test_agent_chat.py
@@ -117,7 +117,7 @@ class TestAgentChat:
 
     root_dir = Path(__file__).resolve().parents[3]
 
-    hear_stroke_path = root_dir / "examples" / "data" / "heart.csv"
+    heart_stroke_path = root_dir / "examples" / "data" / "heart.csv"
     loans_path = root_dir / "examples" / "data" / "loans_payments.csv"
 
     @pytest.fixture
@@ -214,7 +214,7 @@ def test_heart_strokes_questions_type(self, question, expected, pandas_ai):
         Test heart stroke related questions to ensure the response types match the expected ones.
         """
 
-        df = pandas_ai.read_csv(str(self.hear_stroke_path))
+        df = pandas_ai.read_csv(str(self.heart_stroke_path))
 
         response = pandas_ai.chat(question, df)
 
@@ -228,7 +228,7 @@ def test_combined_questions_with_type(self, question, expected, pandas_ai):
         Test combined questions across the heart stroke and loans datasets to ensure the response types match the expected ones.
         """
 
-        df1 = pandas_ai.read_csv(str(self.hear_stroke_path))
+        df1 = pandas_ai.read_csv(str(self.heart_stroke_path))
         loans = pandas_ai.read_csv(str(self.loans_path))
 
         response = pandas_ai.chat(question, *(df1, loans))