diff --git a/awswrangler/redshift/_utils.py b/awswrangler/redshift/_utils.py index b3d05aebe..b5409cd02 100644 --- a/awswrangler/redshift/_utils.py +++ b/awswrangler/redshift/_utils.py @@ -379,31 +379,6 @@ def _create_table( # noqa: PLR0913 add_new_columns: bool = False, ) -> tuple[str, str | None]: _logger.debug("Creating table %s with mode %s, and overwrite method %s", table, mode, overwrite_method) - redshift_types = _get_rsh_types( - df=df, - path=path, - index=index, - dtype=dtype, - varchar_lengths_default=varchar_lengths_default, - varchar_lengths=varchar_lengths, - parquet_infer_sampling=parquet_infer_sampling, - path_suffix=path_suffix, - path_ignore_suffix=path_ignore_suffix, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - data_format=data_format, - redshift_column_types=redshift_column_types, - manifest=manifest, - ) - if add_new_columns is True: - if _does_table_exist(cursor=cursor, schema=schema, table=table) is True: - actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table)) - new_df_columns = { - key: value for key, value in redshift_types.items() if key.lower() not in actual_table_columns - } - _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns) - if mode == "overwrite": if overwrite_method == "truncate": try: @@ -433,6 +408,29 @@ def _create_table( # noqa: PLR0913 _logger.debug("Table %s exists", table) if lock: _lock(cursor, [table], schema=schema) + if add_new_columns is True: + redshift_types = _get_rsh_types( + df=df, + path=path, + index=index, + dtype=dtype, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + parquet_infer_sampling=parquet_infer_sampling, + path_suffix=path_suffix, + path_ignore_suffix=path_ignore_suffix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + data_format=data_format, + redshift_column_types=redshift_column_types, + manifest=manifest, + ) + actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table)) + new_df_columns = { + key: value for key, value in redshift_types.items() if key.lower() not in actual_table_columns + } + _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns) if mode == "upsert": guid: str = uuid.uuid4().hex temp_table: str = f"temp_redshift_{guid}" @@ -444,6 +442,23 @@ def _create_table( # noqa: PLR0913 diststyle = diststyle.upper() if diststyle else "AUTO" sortstyle = sortstyle.upper() if sortstyle else "COMPOUND" + redshift_types = _get_rsh_types( + df=df, + path=path, + index=index, + dtype=dtype, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + parquet_infer_sampling=parquet_infer_sampling, + path_suffix=path_suffix, + path_ignore_suffix=path_ignore_suffix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + data_format=data_format, + redshift_column_types=redshift_column_types, + manifest=manifest, + ) _validate_parameters( redshift_types=redshift_types, diststyle=diststyle,