diff --git a/pandasai/__init__.py b/pandasai/__init__.py index bca62101e..d2148490d 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -88,7 +88,7 @@ def create( schema_path = os.path.join(dataset_directory, "schema.yaml") schema = df.schema - schema.name = name or schema.name + schema.name = sanitize_sql_table_name(name or schema.name) schema.description = description or schema.description if columns: schema.columns = list(map(lambda column: Column(**column), columns)) @@ -204,7 +204,7 @@ def load(dataset_path: str) -> DataFrame: def read_csv(filepath: str) -> DataFrame: data = pd.read_csv(filepath) name = f"table_{sanitize_sql_table_name(filepath)}" - return DataFrame(data._data, name=name) + return DataFrame(data, name=name) __all__ = [ diff --git a/poetry.lock b/poetry.lock index c183dc00f..184a41258 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6,7 +6,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -15,6 +15,29 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} +[[package]] +name = "anyio" +version = "4.5.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "anyio-4.5.2-py3-none-any.whl", hash = "sha256:c011ee36bc1e8ba40e5a81cb9df91925c218fe9b778554e0b56a21e1b5d4716f"}, + {file = "anyio-4.5.2.tar.gz", hash = "sha256:23009af4ed04ce05991845451e11ef02fc7c5ed29179ac9a420e5ad0ac7ddc5b"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +trio = ["trio (>=0.26.1)"] + [[package]] name = "astor" version = "0.8.1" @@ -73,7 +96,7 @@ version = "2024.12.14" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"}, {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"}, @@ -237,7 +260,7 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\" or platform_system == \"Windows\""} +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "contourpy" @@ -424,6 +447,18 @@ files = [ {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "duckdb" version = "1.1.3" @@ -680,6 +715,65 @@ files = [ astunparse = {version = ">=1.6", markers = "python_version < \"3.9\""} colorama = ">=0.4" +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.7" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd"}, + {file = "httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.28.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "huggingface-hub" version = "0.27.1" @@ -736,7 +830,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -824,6 +918,92 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.8.2" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"}, + {file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"}, + {file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"}, + {file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"}, + {file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"}, + {file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"}, + {file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"}, + {file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"}, + {file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"}, + {file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"}, + {file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"}, + {file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"}, + {file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"}, + {file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"}, + {file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"}, + {file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"}, + {file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"}, + {file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"}, +] + [[package]] name = "kiwisolver" version = "1.4.7" @@ -1276,6 +1456,32 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "openai" +version = "1.60.0" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "openai-1.60.0-py3-none-any.whl", hash = "sha256:df06c43be8018274980ac363da07d4b417bd835ead1c66e14396f6f15a0d5dda"}, + {file = "openai-1.60.0.tar.gz", hash = "sha256:7fa536cd4b644718645b874d2706e36dbbef38b327e42ca0623275da347ee1a9"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.11,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +realtime = ["websockets (>=13,<15)"] + [[package]] name = "openpyxl" version = "3.1.5" @@ -1540,7 +1746,7 @@ version = "2.10.5" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53"}, {file = "pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff"}, @@ -1561,7 +1767,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -2145,6 +2351,18 @@ files = [ {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "soupsieve" version = "2.6" @@ -2440,7 +2658,7 @@ version = "4.67.1" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, @@ -2532,7 +2750,7 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["main", "docs"] +groups = ["main", "dev", "docs"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -2681,4 +2899,4 @@ google-sheets = ["beautifulsoup4"] [metadata] lock-version = "2.1" python-versions = ">=3.8,<3.12" -content-hash = "033adfd805ed228d5ae0f15a49a61f03fa197bd641ba14940e52786c682e674c" +content-hash = "6a1c56ea16ea56fea6f9d2c6b54b820643440a3c35582384a007193e89342705" diff --git a/pyproject.toml b/pyproject.toml index e606f2520..e877aefd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ pytest-env = "^0.8.1" click = "^8.1.3" coverage = "^7.2.7" sourcery = "^1.11.0" +openai = "^1.60.0" [tool.poetry.extras] google-sheets = ["beautifulsoup4"] diff --git a/tests/unit_tests/agent/test_agent_chat.py b/tests/unit_tests/agent/test_agent_chat.py index 1d320f4df..f7a6f85c2 100644 --- a/tests/unit_tests/agent/test_agent_chat.py +++ b/tests/unit_tests/agent/test_agent_chat.py @@ -1,4 +1,5 @@ import os +import shutil from pathlib import Path from types import UnionType from typing import List, Tuple @@ -6,7 +7,7 @@ import pytest import pandasai as pai -from pandasai import DataFrame +from pandasai import DataFrame, find_project_root from pandasai.core.response import ( ChartResponse, DataFrameResponse, @@ -15,13 +16,17 @@ ) # Read the API key from an environment variable -API_KEY = os.getenv("PANDASAI_API_KEY_TEST_CHAT", None) +API_KEY = os.getenv("PANDABI_API_KEY_TEST_CHAT", None) @pytest.mark.skipif( API_KEY is None, reason="API key not set, skipping integration tests" ) class TestAgentChat: + root_dir = find_project_root() + cache_path = os.path.join(root_dir, "cache") + heart_stroke_path = os.path.join(root_dir, "examples", "data", "heart.csv") + loans_path = os.path.join(root_dir, "examples", "data", "loans_payments.csv") numeric_questions_with_answer = [ ("What is the total quantity sold across all products and regions?", 105), ("What is the correlation coefficient between Sales and Profit?", 1.0), @@ -115,10 +120,9 @@ class TestAgentChat: ), ] - root_dir = Path(__file__).resolve().parents[3] - - heart_stroke_path = root_dir / "examples" / "data" / "heart.csv" - loans_path = root_dir / "examples" / "data" / "loans_payments.csv" + @pytest.fixture(autouse=True) + def setup(self): + shutil.rmtree(self.cache_path, ignore_errors=True) @pytest.fixture def pandas_ai(self): @@ -228,10 +232,10 @@ def test_combined_questions_with_type(self, question, expected, pandas_ai): Test heart stoke related questions to ensure the response types match the expected ones. """ - df1 = pandas_ai.read_csv(str(self.heart_stroke_path)) + heart_stroke = pandas_ai.read_csv(str(self.heart_stroke_path)) loans = pandas_ai.read_csv(str(self.loans_path)) - response = pandas_ai.chat(question, *(df1, loans)) + response = pandas_ai.chat(question, *(heart_stroke, loans)) assert isinstance( response, expected diff --git a/tests/unit_tests/agent/test_agent_llm_judge.py b/tests/unit_tests/agent/test_agent_llm_judge.py new file mode 100644 index 000000000..1d0c0b570 --- /dev/null +++ b/tests/unit_tests/agent/test_agent_llm_judge.py @@ -0,0 +1,213 @@ +import os +import shutil +from pathlib import Path + +import pytest +from openai import OpenAI +from pydantic import BaseModel + +import pandasai as pai +from pandasai import DataFrame, find_project_root + +# Read the API key from an environment variable +JUDGE_OPENAI_API_KEY = os.getenv("JUDGE_OPENAI_API_KEY", None) + + +class Evaluation(BaseModel): + score: int + justification: str + + +@pytest.mark.skipif( + JUDGE_OPENAI_API_KEY is None, + reason="JUDGE_OPENAI_API_KEY key not set, skipping tests", +) +class TestAgentLLMJudge: + root_dir = find_project_root() + cache_path = os.path.join(root_dir, "cache") + heart_stroke_path = os.path.join(root_dir, "examples", "data", "heart.csv") + loans_path = os.path.join(root_dir, "examples", "data", "loans_payments.csv") + + loans_questions = [ + "What is the total number of payments?", + "What is the average payment amount?", + "How many unique loan IDs are there?", + "What is the most common payment amount?", + "What is the total amount of payments?", + "What is the median payment amount?", + "How many payments are above $1000?", + "What is the minimum and maximum payment?", + "Show me a monthly trend of payments", + "Show me the distribution of payment amounts", + "Show me the top 10 payment amounts", + "Give me a summary of payment statistics", + "Show me payments above $1000", + ] + + heart_strokes_questions = [ + "What is the total number of patients in the dataset?", + "How many people had a stroke?", + "What is the average age of patients?", + "What percentage of patients have hypertension?", + "What is the average BMI?", + "How many smokers are in the dataset?", + "What is the gender distribution?", + "Is there a correlation between age and stroke occurrence?", + "Show me the age distribution of patients.", + "What is the most common work type?", + "Give me a breakdown of stroke occurrences.", + "Show me hypertension statistics.", + "Give me smoking statistics summary.", + "Show me the distribution of work types.", + ] + + combined_questions = [ + "Compare payment patterns between age groups.", + "Show relationship between payments and health conditions.", + "Analyze payment differences between hypertension groups.", + "Calculate average payments by health condition.", + "Show payment distribution across age groups.", + ] + + evaluation_scores = [] + + @pytest.fixture(autouse=True) + def setup(self): + """Setup shared resources for the test class.""" + + shutil.rmtree(self.cache_path, ignore_errors=True) + + self.client = OpenAI(api_key=JUDGE_OPENAI_API_KEY) + + self.evaluation_prompt = ( + "You are an AI evaluation expert tasked with assessing the quality of a code snippet provided as a response.\n" + "The question was: {question}\n" + "The AI provided the following code:\n" + "{code}\n\n" + "Here is the context summary of the data:\n" + "{context}\n\n" + "Evaluate the code based on the following criteria:\n" + "- Correctness: Does the code achieve the intended goal or answer the question accurately?\n" + "- Efficiency: Is the code optimized and avoids unnecessary computations or steps?\n" + "- Clarity: Is the code written in a clear and understandable way?\n" + "- Robustness: Does the code handle potential edge cases or errors gracefully?\n" + "- Best Practices: Does the code follow standard coding practices and conventions?\n" + "The code should only use the function execute_sql_query(sql_query: str) -> pd.Dataframe to connects to the database and get the data" + "The code should declare the result variable as a dictionary with the following structure:\n" + "'type': 'string', 'value': f'The highest salary is 2.' or 'type': 'number', 'value': 125 or 'type': 'dataframe', 'value': pd.DataFrame() or 'type': 'plot', 'value': 'temp_chart.png'\n" + ) + + def test_judge_setup(self): + """Test evaluation setup with OpenAI.""" + question = "How many unique loan IDs are there?" + + df = pai.read_csv(str(self.loans_path)) + df_context = DataFrame.serialize_dataframe(df) + + response = df.chat(question) + + prompt = self.evaluation_prompt.format( + context=df_context, question=question, code=response.last_code_executed + ) + + completion = self.client.beta.chat.completions.parse( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + response_format=Evaluation, + ) + + evaluation_response: Evaluation = completion.choices[0].message.parsed + + self.evaluation_scores.append(evaluation_response.score) + + assert evaluation_response.score > 5, evaluation_response.justification + + @pytest.mark.parametrize("question", loans_questions) + def test_loans_questions(self, question): + """Test multiple loan-related questions.""" + + df = pai.read_csv(str(self.loans_path)) + df_context = DataFrame.serialize_dataframe(df) + + response = df.chat(question) + + prompt = self.evaluation_prompt.format( + context=df_context, question=question, code=response.last_code_executed + ) + + completion = self.client.beta.chat.completions.parse( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + response_format=Evaluation, + ) + + evaluation_response: Evaluation = completion.choices[0].message.parsed + + self.evaluation_scores.append(evaluation_response.score) + + assert evaluation_response.score > 5, evaluation_response.justification + + @pytest.mark.parametrize("question", heart_strokes_questions) + def test_heart_strokes_questions(self, question): + """Test multiple loan-related questions.""" + + self.df = pai.read_csv(str(self.heart_stroke_path)) + df_context = DataFrame.serialize_dataframe(self.df) + + response = self.df.chat(question) + + prompt = self.evaluation_prompt.format( + context=df_context, question=question, code=response.last_code_executed + ) + + completion = self.client.beta.chat.completions.parse( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + response_format=Evaluation, + ) + + evaluation_response: Evaluation = completion.choices[0].message.parsed + + self.evaluation_scores.append(evaluation_response.score) + + assert evaluation_response.score > 5, evaluation_response.justification + + @pytest.mark.parametrize("question", combined_questions) + def test_combined_questions_with_type(self, question): + """ + Test heart stoke related questions to ensure the response types match the expected ones. + """ + + heart_stroke = pai.read_csv(str(self.heart_stroke_path)) + loans = pai.read_csv(str(self.loans_path)) + + df_context = f"{DataFrame.serialize_dataframe(heart_stroke)}\n{DataFrame.serialize_dataframe(loans)}" + + response = pai.chat(question, *(heart_stroke, loans)) + + prompt = self.evaluation_prompt.format( + context=df_context, question=question, code=response.last_code_executed + ) + + completion = self.client.beta.chat.completions.parse( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + response_format=Evaluation, + ) + + evaluation_response: Evaluation = completion.choices[0].message.parsed + + self.evaluation_scores.append(evaluation_response.score) + + assert evaluation_response.score > 5, evaluation_response.justification + + @pytest.mark.final + def test_average_score(self): + if self.evaluation_scores: + average_score = sum(self.evaluation_scores) / len(self.evaluation_scores) + file_path = Path(self.root_dir) / "test_agent_llm_judge.txt" + with open(file_path, "w") as f: + f.write(f"{average_score}") + assert ( + average_score >= 5 + ), f"Average score should be at least 5, got {average_score}" diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 78cb0814f..c99925cc2 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -1,7 +1,12 @@ +import os +import statistics +from pathlib import Path from unittest.mock import MagicMock, patch import pytest +from pandasai import find_project_root + @pytest.fixture(scope="session") def mock_json_load(): @@ -9,3 +14,19 @@ def mock_json_load(): with patch("json.load", mock): yield mock + + +def pytest_terminal_summary(terminalreporter, exitstatus): + scores_file = Path(find_project_root()) / "test_agent_llm_judge.txt" + + if os.path.exists(scores_file): + with open(scores_file, "r") as file: + score_line = file.readline().strip() + + # Ensure the line is a valid number + if score_line.replace(".", "", 1).isdigit(): + avg_score = float(score_line) + terminalreporter.write(f"\n--- Evaluation Score Summary ---\n") + terminalreporter.write(f"Average Score: {avg_score:.2f}\n") + + os.remove(scores_file)