diff --git a/locopy/database.py b/locopy/database.py index 411d095..2112d22 100644 --- a/locopy/database.py +++ b/locopy/database.py @@ -18,6 +18,9 @@ import time +import pandas +import polars + from locopy.errors import CredentialsError, DBError from locopy.logger import INFO, get_logger from locopy.utility import read_config_yaml @@ -188,25 +191,25 @@ def column_names(self): except Exception: return [column[0].lower() for column in self.cursor.description] - def to_dataframe(self, size=None): + def to_dataframe(self, df_type="pandas", size=None): """Return a dataframe of the last query results. - This imports Pandas in here, so that it's not needed for other use cases. This is just a - convenience method. - Parameters ---------- + df_type: Literal["pandas","polars"], optional + Output dataframe format. Defaults to pandas. + size : int, optional Chunk size to fetch. Defaults to None. Returns ------- - pandas.DataFrame + pandas.DataFrame or polars.DataFrame Dataframe with lowercase column names. Returns None if no fetched result. """ - import pandas - + if df_type not in ["pandas", "polars"]: + raise ValueError("df_type must be ``pandas`` or ``polars``.") columns = self.column_names() if size is None: @@ -220,7 +223,11 @@ def to_dataframe(self, size=None): if len(fetched) == 0: return None - return pandas.DataFrame(fetched, columns=columns) + + if df_type == "pandas": + return pandas.DataFrame(fetched, columns=columns) + elif df_type == "polars": + return polars.DataFrame(fetched, schema=columns, orient="row") def to_dict(self): """Generate dictionaries of rows. diff --git a/locopy/redshift.py b/locopy/redshift.py index 84527f4..cf8c56d 100644 --- a/locopy/redshift.py +++ b/locopy/redshift.py @@ -21,8 +21,13 @@ """ import os +from functools import singledispatch from pathlib import Path +import pandas as pd +import polars as pl +import polars.selectors as cs + from locopy.database import Database from locopy.errors import DBError, S3CredentialsError from locopy.logger import INFO, get_logger @@ -537,14 +542,14 @@ def insert_dataframe_to_table( verbose=False, ): """ - Insert a Pandas dataframe to an existing table or a new table. + Insert a Pandas or Polars dataframe to an existing table or a new table. `executemany` in psycopg2 and pg8000 has very poor performance in terms of running speed. To overcome this issue, we instead format the insert query and then run `execute`. Parameters ---------- - dataframe: Pandas Dataframe + dataframe: pandas.DataFrame or polars.DataFrame The pandas dataframe which needs to be inserted. table_name: str @@ -567,8 +572,6 @@ def insert_dataframe_to_table( """ - import pandas as pd - if columns: dataframe = dataframe[columns] @@ -599,9 +602,15 @@ def insert_dataframe_to_table( self.execute(create_query) logger.info("New table has been created") - logger.info("Inserting records...") - for start in range(0, len(dataframe), batch_size): - # create a list of tuples for insert + # create a list of tuples for insert + @singledispatch + def get_insert_tuple(dataframe, start, batch_size): + """Create a list of tuples for insert.""" + pass + + @get_insert_tuple.register(pd.DataFrame) + def get_insert_tuple_pandas(dataframe: pd.DataFrame, start, batch_size): + """Create a list of tuples for insert when dataframe is pd.DataFrame.""" to_insert = [] for row in dataframe[start : (start + batch_size)].itertuples(index=False): none_row = ( @@ -617,9 +626,43 @@ def insert_dataframe_to_table( + ")" ) to_insert.append(none_row) - string_join = ", ".join(to_insert) - insert_query = ( - f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}""" + return to_insert + + @get_insert_tuple.register(pl.DataFrame) + def get_insert_tuple_polars(dataframe: pl.DataFrame, start, batch_size): + """Create a list of tuples for insert when dataframe is pl.DataFrame.""" + to_insert = [] + dataframe = dataframe.with_columns( + dataframe.select(cs.numeric().fill_nan(None)) ) - self.execute(insert_query, verbose=verbose) + for row in dataframe[start : (start + batch_size)].iter_rows(): + none_row = ( + "(" + + ", ".join( + [ + "NULL" + if val is None + else "'" + str(val).replace("'", "''") + "'" + for val in row + ] + ) + + ")" + ) + to_insert.append(none_row) + return to_insert + + logger.info("Inserting records...") + try: + for start in range(0, len(dataframe), batch_size): + to_insert = get_insert_tuple(dataframe, start, batch_size) + string_join = ", ".join(to_insert) + insert_query = ( + f"""INSERT INTO {table_name} {column_sql} VALUES {string_join}""" + ) + self.execute(insert_query, verbose=verbose) + except TypeError: + raise TypeError( + "DataFrame to insert must either be a pandas.DataFrame or polars.DataFrame." + ) from None + logger.info("Table insertion has completed") diff --git a/locopy/snowflake.py b/locopy/snowflake.py index 81fc842..54e28ea 100644 --- a/locopy/snowflake.py +++ b/locopy/snowflake.py @@ -21,8 +21,13 @@ """ import os +from functools import singledispatch from pathlib import PurePath +import pandas as pd +import polars as pl +import polars.selectors as cs + from locopy.database import Database from locopy.errors import DBError, S3CredentialsError from locopy.logger import INFO, get_logger @@ -396,7 +401,7 @@ def unload( def insert_dataframe_to_table( self, dataframe, table_name, columns=None, create=False, metadata=None ): - """Insert a Pandas dataframe to an existing table or a new table. + """Insert a Pandas or Polars dataframe to an existing table or a new table. In newer versions of the python snowflake connector (v2.1.2+) users can call the ``write_pandas`` method from the cursor @@ -408,8 +413,8 @@ def insert_dataframe_to_table( Parameters ---------- - dataframe: Pandas Dataframe - The pandas dataframe which needs to be inserted. + dataframe: Pandas or Polars Dataframe + The pandas or polars dataframe which needs to be inserted. table_name: str The name of the Snowflake table which is being inserted. @@ -423,8 +428,6 @@ def insert_dataframe_to_table( metadata: dictionary, optional If metadata==None, it will be generated based on data """ - import pandas as pd - if columns: dataframe = dataframe[columns] @@ -433,10 +436,39 @@ def insert_dataframe_to_table( string_join = "(" + ",".join(["%s"] * len(all_columns)) + ")" # create a list of tuples for insert - to_insert = [] - for row in dataframe.itertuples(index=False): - none_row = tuple(None if pd.isnull(val) else str(val) for val in row) - to_insert.append(none_row) + @singledispatch + def get_insert_tuple(dataframe): + """Create a list of tuples for insert.""" + pass + + @get_insert_tuple.register(pd.DataFrame) + def get_insert_tuple_pandas(dataframe: pd.DataFrame): + """Create a list of tuples for insert when dataframe is pd.DataFrame.""" + to_insert = [] + for row in dataframe.itertuples(index=False): + none_row = tuple(None if pd.isnull(val) else str(val) for val in row) + to_insert.append(none_row) + return to_insert + + @get_insert_tuple.register(pl.DataFrame) + def get_insert_tuple_polars(dataframe: pl.DataFrame): + """Create a list of tuples for insert when dataframe is pl.DataFrame.""" + to_insert = [] + dataframe = dataframe.with_columns( + dataframe.select(cs.numeric().fill_nan(None)) + ) + for row in dataframe.iter_rows(): + none_row = tuple(None if val is None else str(val) for val in row) + to_insert.append(none_row) + return to_insert + + # create a list of tuples for insert + try: + to_insert = get_insert_tuple(dataframe) + except TypeError: + raise TypeError( + "DataFrame to insert must either be a pandas.DataFrame or polars.DataFrame." + ) from None if not create and metadata: logger.warning("Metadata will not be used because create is set to False.") @@ -468,7 +500,7 @@ def insert_dataframe_to_table( self.execute(insert_query, params=to_insert, many=True) logger.info("Table insertion has completed") - def to_dataframe(self, size=None): + def to_dataframe(self, df_type="pandas", size=None): """Return a dataframe of the last query results. This is just a convenience method. This @@ -479,16 +511,25 @@ def to_dataframe(self, size=None): Parameters ---------- + df_type: Literal["pandas","polars"], optional + Output dataframe format. Defaults to pandas. + size : int, optional Chunk size to fetch. Defaults to None. Returns ------- - pandas.DataFrame + pandas.DataFrame or polars.DataFrame Dataframe with lowercase column names. Returns None if no fetched result. """ + if df_type not in ["pandas", "polars"]: + raise ValueError("df_type must be ``pandas`` or ``polars``.") + if size is None and self.cursor._query_result_format == "arrow": - return self.cursor.fetch_pandas_all() + if df_type == "pandas": + return self.cursor.fetch_pandas_all() + elif df_type == "polars": + return pl.from_arrow(self.cursor.fetch_arrow_all()) else: - return super().to_dataframe(size) + return super().to_dataframe(df_type=df_type, size=size) diff --git a/locopy/utility.py b/locopy/utility.py index 67e8253..7b61f31 100644 --- a/locopy/utility.py +++ b/locopy/utility.py @@ -25,8 +25,11 @@ import sys import threading from collections import OrderedDict +from functools import singledispatch from itertools import cycle +import pandas as pd +import polars as pl import yaml from locopy.errors import ( @@ -253,7 +256,14 @@ def read_config_yaml(config_yaml): # make it more granular, eg. include length +@singledispatch def find_column_type(dataframe, warehouse_type: str): + """Find data type of each column from the dataframe.""" + pass + + +@find_column_type.register(pd.DataFrame) +def find_column_type_pandas(dataframe: pd.DataFrame, warehouse_type: str): """ Find data type of each column from the dataframe. @@ -284,8 +294,6 @@ def find_column_type(dataframe, warehouse_type: str): """ import re - import pandas as pd - def validate_date_object(column): try: pd.to_datetime(column) @@ -342,6 +350,79 @@ def validate_float_object(column): return OrderedDict(zip(list(dataframe.columns), column_type)) +@find_column_type.register(pl.DataFrame) +def find_column_type_polars(dataframe: pl.DataFrame, warehouse_type: str): + """ + Find data type of each column from the dataframe. + + Following is the list of polars data types that the function checks and their mapping in sql: + + - Boolean -> boolean + - Date/Datetime/Duration/Time -> timestamp + - int -> int + - float/decimal -> float + - float object -> float + - datetime object -> timestamp + - others -> varchar + + For all other data types, the column will be mapped to varchar type. + + Parameters + ---------- + dataframe : Pandas dataframe + + warehouse_type: str + Required to properly determine format of uploaded data, either "snowflake" or "redshift". + + Returns + ------- + dict + A dictionary of columns with their data type + """ + + def validate_date_object(column): + try: + column.str.to_datetime() + return "date" + except Exception: + return None + + def validate_float_object(column): + try: + column.cast(pl.UInt32) + return "float" + except Exception: + return None + + if warehouse_type.lower() not in ["snowflake", "redshift"]: + raise ValueError( + 'warehouse_type argument must be either "snowflake" or "redshift"' + ) + + column_type = [] + for column in dataframe.columns: + logger.debug("Checking column: %s", column) + data = dataframe.lazy().select(column).drop_nulls().collect().to_series() + if data.shape[0] == 0: + column_type.append("varchar") + elif data.dtype.is_temporal(): + column_type.append("timestamp") + elif str(data.dtype).lower().startswith("bool"): + column_type.append("boolean") + elif data.dtype.is_integer(): + column_type.append("int") + elif data.dtype.is_numeric(): # cast all non-integer numeric as float + column_type.append("float") + else: + data_type = validate_float_object(data) or validate_date_object(data) + if not data_type: + column_type.append("varchar") + else: + column_type.append(data_type) + logger.info("Parsing column %s to %s", column, column_type[-1]) + return OrderedDict(zip(list(dataframe.columns), column_type)) + + class ProgressPercentage: """ProgressPercentage class is used by the S3Transfer upload_file callback. diff --git a/pyproject.toml b/pyproject.toml index d23b511..6d10c89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ { name="Faisal Dosani", email="faisal.dosani@capitalone.com" }, ] license = {text = "Apache Software License"} -dependencies = ["boto3<=1.35.9,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=2.0.2,>=1.22.0"] +dependencies = ["boto3<=1.35.9,>=1.9.92", "PyYAML<=6.0.1,>=5.1", "pandas<=2.2.2,>=0.25.2", "numpy<=2.0.2,>=1.22.0", "polars>=0.20.0"] requires-python = ">=3.9.0" classifiers = [ diff --git a/requirements.txt b/requirements.txt index 25ec73f..3a26156 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,8 @@ numpy==1.26.4 # pandas pandas==2.2.2 # via locopy (pyproject.toml) +polars==1.6.0 + # via locopy (pyproject.toml) python-dateutil==2.9.0.post0 # via # botocore @@ -35,5 +37,5 @@ six==1.16.0 # via python-dateutil tzdata==2024.1 # via pandas -urllib3==2.2.2 +urllib3==1.26.20 # via botocore diff --git a/tests/test_database.py b/tests/test_database.py index 205e605..d8e0777 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -264,6 +264,38 @@ def test_to_dataframe_none(mock_pandas, credentials, dbapi): mock_pandas.assert_not_called() +# TODO: remove dataframe mocking +@pytest.mark.parametrize("dbapi", DBAPIS) +@mock.patch("polars.DataFrame") +def test_to_dataframe_all_polars(mock_polars, credentials, dbapi): + with mock.patch(dbapi.__name__ + ".connect") as mock_connect: + mock_connect.return_value.cursor.return_value.fetchall.return_value = [ + (1, 2), + (2, 3), + (3, 4), + ] + with Database(dbapi=dbapi, **credentials) as test: + test.execute("SELECT 'hello world' AS fld") + df = test.to_dataframe(df_type="polars") + + assert mock_connect.return_value.cursor.return_value.fetchall.called + mock_polars.assert_called_with(test.cursor.fetchall(), schema=[], orient="row") + + +@pytest.mark.parametrize("dbapi", DBAPIS) +def test_to_dataframe_error(credentials, dbapi): + with mock.patch(dbapi.__name__ + ".connect") as mock_connect: + mock_connect.return_value.cursor.return_value.fetchall.return_value = [ + (1, 2), + (2, 3), + (3, 4), + ] + with Database(dbapi=dbapi, **credentials) as test: + test.execute("SELECT 'hello world' AS fld") + with pytest.raises(ValueError): + test.to_dataframe(df_type="invalid") + + @pytest.mark.parametrize("dbapi", DBAPIS) def test_get_column_names(credentials, dbapi): with mock.patch(dbapi.__name__ + ".connect") as mock_connect: diff --git a/tests/test_redshift.py b/tests/test_redshift.py index 55a5f2f..37a2067 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -909,7 +909,7 @@ def testunload_no_connection(mock_session, credentials, dbapi): @pytest.mark.parametrize("dbapi", DBAPIS) @mock.patch("locopy.s3.Session") -def testinsert_dataframe_to_table(mock_session, credentials, dbapi): +def testinsert_dataframe_to_table_pandas(mock_session, credentials, dbapi): import pandas as pd test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",") @@ -966,3 +966,66 @@ def testinsert_dataframe_to_table(mock_session, credentials, dbapi): "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')", (), ) + + +@pytest.mark.parametrize("dbapi", DBAPIS) +@mock.patch("locopy.s3.Session") +def testinsert_dataframe_to_table_polars(mock_session, credentials, dbapi): + import polars as pl + + test_df = pl.read_csv( + os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), separator="," + ) + with mock.patch(dbapi.__name__ + ".connect") as mock_connect: + r = locopy.Redshift(dbapi=dbapi, **credentials) + r.connect() + r.insert_dataframe_to_table(test_df, "database.schema.test") + mock_connect.return_value.cursor.return_value.execute.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01'), ('2', 'y', '2001-04-02')", + (), + ) + + r.insert_dataframe_to_table(test_df, "database.schema.test", create=True) + mock_connect.return_value.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (a int,b varchar,c date)", () + ) + mock_connect.return_value.cursor.return_value.execute.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01'), ('2', 'y', '2001-04-02')", + (), + ) + + r.insert_dataframe_to_table(test_df, "database.schema.test", columns=["a", "b"]) + + mock_connect.return_value.cursor.return_value.execute.assert_called_with( + "INSERT INTO database.schema.test (a,b) VALUES ('1', 'x'), ('2', 'y')", () + ) + + r.insert_dataframe_to_table( + test_df, + "database.schema.test", + create=True, + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), + ) + + mock_connect.return_value.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)", () + ) + mock_connect.return_value.cursor.return_value.execute.assert_called_with( + "INSERT INTO database.schema.test (col1,col2,col3) VALUES ('1', 'x', '2011-01-01'), ('2', 'y', '2001-04-02')", + (), + ) + + r.insert_dataframe_to_table( + test_df, "database.schema.test", create=False, batch_size=1 + ) + + mock_connect.return_value.cursor.return_value.execute.assert_any_call( + "INSERT INTO database.schema.test (a,b,c) VALUES ('1', 'x', '2011-01-01')", + (), + ) + mock_connect.return_value.cursor.return_value.execute.assert_any_call( + "INSERT INTO database.schema.test (a,b,c) VALUES ('2', 'y', '2001-04-02')", + (), + ) diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py index d24e76b..75ea7f6 100644 --- a/tests/test_snowflake.py +++ b/tests/test_snowflake.py @@ -28,11 +28,14 @@ import hypothesis.strategies as s import locopy +import polars as pl +import pyarrow as pa import pytest import snowflake.connector from hypothesis import HealthCheck, given, settings from locopy import Snowflake from locopy.errors import DBError +from polars.testing import assert_frame_equal PROFILE = "test" KMS = "kms_test" @@ -403,9 +406,6 @@ def test_unload_exception(mock_session, sf_credentials): @mock.patch("locopy.s3.Session") def test_to_pandas(mock_session, sf_credentials): - import pandas as pd - - test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",") with ( mock.patch("snowflake.connector.connect") as mock_connect, Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, @@ -418,12 +418,48 @@ def test_to_pandas(mock_session, sf_credentials): sf.to_dataframe() sf.conn.cursor.return_value.fetchall.assert_called_with() - sf.to_dataframe(5) + sf.to_dataframe(size=5) sf.conn.cursor.return_value.fetchmany.assert_called_with(5) @mock.patch("locopy.s3.Session") -def test_insert_dataframe_to_table(mock_session, sf_credentials): +def test_to_polars(mock_session, sf_credentials): + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.cursor._query_result_format = "arrow" + sf.conn.cursor.return_value.fetch_arrow_all.return_value = pa.table( + {"a": [1, 2, 3], "b": [4, 5, 6]} + ) + polars_df = sf.to_dataframe(df_type="polars") + expected_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + assert_frame_equal(polars_df, expected_df) + + sf.cursor._query_result_format = "json" + sf.to_dataframe(df_type="polars") + sf.conn.cursor.return_value.fetchall.assert_called_with() + + sf.to_dataframe(df_type="polars", size=5) + sf.conn.cursor.return_value.fetchmany.assert_called_with(5) + + +@mock.patch("locopy.s3.Session") +def test_to_dataframe_error(mock_session, sf_credentials): + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.cursor._query_result_format = "arrow" + sf.conn.cursor.return_value.fetch_arrow_all.return_value = pa.table( + {"a": [1, 2, 3], "b": [4, 5, 6]} + ) + with pytest.raises(ValueError): + polars_df = sf.to_dataframe(df_type="invalid") + + +@mock.patch("locopy.s3.Session") +def test_insert_pd_dataframe_to_table(mock_session, sf_credentials): import pandas as pd test_df = pd.read_csv(os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), sep=",") @@ -487,3 +523,72 @@ def test_insert_dataframe_to_table(mock_session, sf_credentials): "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], ) + + +@mock.patch("locopy.s3.Session") +def test_insert_pl_dataframe_to_table(mock_session, sf_credentials): + import polars as pl + + test_df = pl.read_csv( + os.path.join(CURR_DIR, "data", "mock_dataframe.txt"), separator="," + ) + with ( + mock.patch("snowflake.connector.connect") as mock_connect, + Snowflake(profile=PROFILE, dbapi=DBAPIS, **sf_credentials) as sf, + ): + sf.insert_dataframe_to_table(test_df, "database.schema.test") + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) + + sf.insert_dataframe_to_table(test_df, "database.schema.test", create=True) + sf.conn.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (a int,b varchar,c date)", () + ) + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) + + sf.insert_dataframe_to_table( + test_df, "database.schema.test", columns=["a", "b"] + ) + + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b) VALUES (%s,%s)", + [("1", "x"), ("2", "y")], + ) + + sf.insert_dataframe_to_table( + test_df, + "database.schema.test", + create=True, + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), + ) + + sf.conn.cursor.return_value.execute.assert_any_call( + "CREATE TABLE database.schema.test (col1 int,col2 varchar,col3 date)", + (), + ) + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (col1,col2,col3) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + ) + + sf.insert_dataframe_to_table( + test_df, + "database.schema.test", + create=False, + metadata=OrderedDict( + [("col1", "int"), ("col2", "varchar"), ("col3", "date")] + ), + ) + + # mock_session.warn.assert_called_with('Metadata will not be used because create is set to False.') + sf.conn.cursor.return_value.executemany.assert_called_with( + "INSERT INTO database.schema.test (a,b,c) VALUES (%s,%s,%s)", + [("1", "x", "2011-01-01"), ("2", "y", "2001-04-02")], + )