feat: Automatically add new columns to Redshift table during COPY operation
jack-dell committed Sep 9, 2024
1 parent 4e257e1 commit 97c0cc5
Showing 2 changed files with 77 additions and 32 deletions.
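This commit moves schema evolution out of `_create_table`: `to_sql` and `copy_from_files` now detect columns that exist in the incoming data but not in the target table, and issue the `ALTER TABLE ... ADD COLUMN` statements themselves before any data is written. A minimal usage sketch of the flag, with placeholder connection, schema, and table names (it assumes a Glue Catalog connection is already configured):

import awswrangler as wr
import pandas as pd

con = wr.redshift.connect("my-glue-connection")  # placeholder connection name
df = pd.DataFrame({"id": [1, 2], "brand_new_col": ["a", "b"]})

# With add_new_columns=True, any DataFrame column missing from public.my_table
# is added via ALTER TABLE before the INSERT runs.
wr.redshift.to_sql(
    df=df,
    con=con,
    schema="public",
    table="my_table",
    mode="append",
    add_new_columns=True,
)
con.close()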
62 changes: 34 additions & 28 deletions awswrangler/redshift/_utils.py
@@ -107,7 +107,7 @@ def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: s


 def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
-    sql = f"SELECT column_name FROM svv_columns\n" f"WHERE table_schema = '{schema}' AND table_name = '{table}'"
+    sql = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'"
     _logger.debug("Executing select query:\n%s", sql)
     cursor.execute(sql)
     result: tuple[list[str]] = cursor.fetchall()
@@ -119,7 +119,7 @@ def _add_table_columns(
     cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
 ) -> None:
     for column_name, column_type in new_columns.items():
-        sql = f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}\n" f"ADD COLUMN {column_name} {column_type};"
+        sql = f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}\n ADD COLUMN {column_name} {column_type};"
         _logger.debug("Executing alter query:\n%s", sql)
         cursor.execute(sql)

@@ -146,6 +146,16 @@ def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = No
     return paths


+def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str:
+    sql = f"SHOW {parameter_name}"
+    _logger.debug("Executing select query:\n%s", sql)
+    cursor.execute(sql)
+    result = cursor.fetchall()
+    status = result[0][0]
+    return status
+
+
 def _lock(
     cursor: "redshift_connector.Cursor",
     table_names: list[str],
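`SHOW` returns its value as a one-row result set, which is why the helper reads `result[0][0]`. Below is a standalone sketch of the same round trip with hypothetical credentials; treating "on"/"true" as the enabled values is an assumption consistent with the check in `_add_new_columns` further down:

import redshift_connector

con = redshift_connector.connect(
    host="my-cluster.example.us-east-1.redshift.amazonaws.com",  # placeholder
    database="dev",
    user="awsuser",
    password="...",  # placeholder
)
with con.cursor() as cursor:
    cursor.execute("SHOW enable_case_sensitive_identifier")
    rows = cursor.fetchall()  # nested rows, hence result[0][0] in the helper
    print(rows[0][0])         # expected "on" or "off" (assumed values)
con.close()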
@@ -285,7 +295,7 @@ def _redshift_types_from_path(
     return redshift_types


-def _get_rsh_types(
+def _get_rsh_columns_types(
     df: pd.DataFrame | None,
     path: str | list[str] | None,
     index: bool,
Expand Down Expand Up @@ -348,6 +358,26 @@ def _get_rsh_types(
return redshift_types


def _add_new_columns(
cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str]
):
# Check if the cluster is configured as case sensitive or no
is_case_sensitive = False
if _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier").lower() in [
"on",
"true", # @todo choose one
]:
is_case_sensitive = True

# If it is case-insensitive, convert all DataFrame column names to lowercase before performing the comparison
if is_case_sensitive is False:
redshift_columns_types = {key.lower(): value for key, value in redshift_columns_types.items()}
actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
new_df_columns = {key: value for key, value in redshift_columns_types.items() if key not in actual_table_columns}

_add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns)


def _create_table( # noqa: PLR0913
df: pd.DataFrame | None,
path: str | list[str] | None,
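A self-contained illustration of the comparison above, using made-up column maps (no cluster required): on a case-insensitive cluster the incoming names are lowercased first, so only genuinely new columns survive the set difference.

# Hypothetical inputs: types inferred from the DataFrame vs. existing table columns.
redshift_columns_types = {"ID": "BIGINT", "Age": "INTEGER"}
actual_table_columns = {"id", "name"}  # what _get_table_columns would return

# Case-insensitive path: fold incoming names before comparing.
folded = {key.lower(): value for key, value in redshift_columns_types.items()}
new_df_columns = {key: value for key, value in folded.items() if key not in actual_table_columns}

assert new_df_columns == {"age": "INTEGER"}  # only "age" triggers ADD COLUMN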
@@ -376,7 +406,6 @@ def _create_table(  # noqa: PLR0913
     boto3_session: boto3.Session | None = None,
     s3_additional_kwargs: dict[str, str] | None = None,
     lock: bool = False,
-    add_new_columns: bool = False,
 ) -> tuple[str, str | None]:
     _logger.debug("Creating table %s with mode %s, and overwrite method %s", table, mode, overwrite_method)
     if mode == "overwrite":
@@ -408,29 +437,6 @@ def _create_table(  # noqa: PLR0913
         _logger.debug("Table %s exists", table)
         if lock:
             _lock(cursor, [table], schema=schema)
-        if add_new_columns is True:
-            redshift_types = _get_rsh_types(
-                df=df,
-                path=path,
-                index=index,
-                dtype=dtype,
-                varchar_lengths_default=varchar_lengths_default,
-                varchar_lengths=varchar_lengths,
-                parquet_infer_sampling=parquet_infer_sampling,
-                path_suffix=path_suffix,
-                path_ignore_suffix=path_ignore_suffix,
-                use_threads=use_threads,
-                boto3_session=boto3_session,
-                s3_additional_kwargs=s3_additional_kwargs,
-                data_format=data_format,
-                redshift_column_types=redshift_column_types,
-                manifest=manifest,
-            )
-            actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
-            new_df_columns = {
-                key: value for key, value in redshift_types.items() if key.lower() not in actual_table_columns
-            }
-            _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns)
         if mode == "upsert":
             guid: str = uuid.uuid4().hex
             temp_table: str = f"temp_redshift_{guid}"
@@ -442,7 +448,7 @@ def _create_table(  # noqa: PLR0913
     diststyle = diststyle.upper() if diststyle else "AUTO"
     sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"

-    redshift_types = _get_rsh_types(
+    redshift_types = _get_rsh_columns_types(
         df=df,
         path=path,
         index=index,
47 changes: 43 additions & 4 deletions awswrangler/redshift/_write.py
@@ -13,7 +13,14 @@
 from awswrangler._config import apply_configs

 from ._connect import _validate_connection
-from ._utils import _create_table, _make_s3_auth_string, _upsert
+from ._utils import (
+    _add_new_columns,
+    _create_table,
+    _does_table_exist,
+    _get_rsh_columns_types,
+    _make_s3_auth_string,
+    _upsert,
+)

 if TYPE_CHECKING:
     try:
@@ -194,6 +201,19 @@ def to_sql(
     con.autocommit = False
     try:
         with con.cursor() as cursor:
+            if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
+                redshift_columns_types = _get_rsh_columns_types(
+                    df=df,
+                    path=None,
+                    index=index,
+                    dtype=dtype,
+                    varchar_lengths_default=varchar_lengths_default,
+                    varchar_lengths=varchar_lengths,
+                )
+                _add_new_columns(
+                    cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
+                )
+
             created_table, created_schema = _create_table(
                 df=df,
                 path=None,
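The ordering matters: the new columns are added while the table still exists and before `_create_table` runs, so the INSERT generated later sees the widened schema. Each new column becomes one statement of the shape built in `_add_table_columns`; a sketch with placeholder names, approximating `_identifier` with plain double quotes:

schema, table = "public", "my_table"   # placeholders
new_df_columns = {"age": "INTEGER"}    # e.g. the result computed in _add_new_columns

for column_name, column_type in new_df_columns.items():
    sql = f'ALTER TABLE "{schema}"."{table}"\n ADD COLUMN {column_name} {column_type};'
    print(sql)  # _add_table_columns passes this to cursor.execute(sql)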
@@ -213,7 +233,6 @@ def to_sql(
                 varchar_lengths_default=varchar_lengths_default,
                 varchar_lengths=varchar_lengths,
                 lock=lock,
-                add_new_columns=add_new_columns,
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
@@ -427,6 +446,27 @@ def copy_from_files(  # noqa: PLR0913
     con.autocommit = False
     try:
         with con.cursor() as cursor:
+            if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
+                redshift_columns_types = _get_rsh_columns_types(
+                    df=None,
+                    path=path,
+                    index=False,
+                    dtype=None,
+                    varchar_lengths_default=varchar_lengths_default,
+                    varchar_lengths=varchar_lengths,
+                    parquet_infer_sampling=parquet_infer_sampling,
+                    path_suffix=path_suffix,
+                    path_ignore_suffix=path_ignore_suffix,
+                    use_threads=use_threads,
+                    boto3_session=boto3_session,
+                    s3_additional_kwargs=s3_additional_kwargs,
+                    data_format=data_format,  # type: ignore[arg-type]
+                    redshift_column_types=redshift_column_types,
+                    manifest=manifest,
+                )
+                _add_new_columns(
+                    cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
+                )
             created_table, created_schema = _create_table(
                 df=None,
                 path=path,
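In the file-based path there is no DataFrame, so `_get_rsh_columns_types` infers column types by sampling the files under `path` (hence the `parquet_infer_sampling`, `path_suffix`, and `path_ignore_suffix` arguments). A usage sketch with placeholder bucket, role, and table names:

import awswrangler as wr

con = wr.redshift.connect("my-glue-connection")  # placeholder connection name

# Types are inferred from the staged Parquet files; any column missing from
# public.my_table is added before the COPY statement executes.
wr.redshift.copy_from_files(
    path="s3://my-bucket/staged-parquet/",  # placeholder prefix
    con=con,
    schema="public",
    table="my_table",
    data_format="parquet",
    iam_role="arn:aws:iam::123456789012:role/RedshiftCopyRole",  # placeholder
    add_new_columns=True,
)
con.close()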
@@ -455,7 +495,6 @@ def copy_from_files(  # noqa: PLR0913
                 boto3_session=boto3_session,
                 s3_additional_kwargs=s3_additional_kwargs,
                 lock=lock,
-                add_new_columns=add_new_columns,
             )
             _copy(
                 cursor=cursor,
@@ -637,7 +676,7 @@ def copy(  # noqa: PLR0913
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
-    add_new_columns:
+    add_new_columns
         If True, it automatically adds the new DataFrame columns into the target table.

     Examples
