Skip to content

Commit

Permalink
feat: to_iceberg support for filling missing columns in the DataFra…
Browse files Browse the repository at this point in the history
…me with None (#2616)
  • Loading branch information
LeonLuttenberger authored Jan 18, 2024
1 parent 47250a3 commit 498b586
Show file tree
Hide file tree
Showing 3 changed files with 443 additions and 338 deletions.
72 changes: 50 additions & 22 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _create_iceberg_table(
df: pd.DataFrame,
database: str,
table: str,
path: str,
path: str | None,
wg_config: _WorkGroupConfig,
partition_cols: list[str] | None,
additional_table_properties: dict[str, Any] | None,
Expand Down Expand Up @@ -80,9 +80,9 @@ def _create_iceberg_table(


class _SchemaChanges(TypedDict):
to_add: dict[str, str]
to_change: dict[str, str]
to_remove: set[str]
new_columns: dict[str, str]
modified_columns: dict[str, str]
missing_columns: dict[str, str]


def _determine_differences(
Expand All @@ -94,7 +94,7 @@ def _determine_differences(
boto3_session: boto3.Session | None,
dtype: dict[str, str] | None,
catalog_id: str | None,
) -> _SchemaChanges:
) -> tuple[_SchemaChanges, list[str]]:
frame_columns_types, frame_partitions_types = _data_types.athena_types_from_pandas_partitioned(
df=df, index=index, partition_cols=partition_cols, dtype=dtype
)
Expand All @@ -105,26 +105,30 @@ def _determine_differences(
catalog.get_table_types(database=database, table=table, catalog_id=catalog_id, boto3_session=boto3_session),
)

original_columns = set(catalog_column_types)
new_columns = set(frame_columns_types)
original_column_names = set(catalog_column_types)
new_column_names = set(frame_columns_types)

to_add = {col: frame_columns_types[col] for col in new_columns - original_columns}
to_remove = original_columns - new_columns
new_columns = {col: frame_columns_types[col] for col in new_column_names - original_column_names}
missing_columns = {col: catalog_column_types[col] for col in original_column_names - new_column_names}

columns_to_change = [
col
for col in original_columns.intersection(new_columns)
for col in original_column_names.intersection(new_column_names)
if frame_columns_types[col] != catalog_column_types[col]
]
to_change = {col: frame_columns_types[col] for col in columns_to_change}
modified_columns = {col: frame_columns_types[col] for col in columns_to_change}

return _SchemaChanges(to_add=to_add, to_change=to_change, to_remove=to_remove)
return (
_SchemaChanges(new_columns=new_columns, modified_columns=modified_columns, missing_columns=missing_columns),
[key for key in catalog_column_types],
)


def _alter_iceberg_table(
database: str,
table: str,
schema_changes: _SchemaChanges,
fill_missing_columns_in_df: bool,
wg_config: _WorkGroupConfig,
data_source: str | None = None,
workgroup: str | None = None,
Expand All @@ -134,20 +138,23 @@ def _alter_iceberg_table(
) -> None:
sql_statements: list[str] = []

if schema_changes["to_add"]:
if schema_changes["new_columns"]:
sql_statements += _alter_iceberg_table_add_columns_sql(
table=table,
columns_to_add=schema_changes["to_add"],
columns_to_add=schema_changes["new_columns"],
)

if schema_changes["to_change"]:
if schema_changes["modified_columns"]:
sql_statements += _alter_iceberg_table_change_columns_sql(
table=table,
columns_to_change=schema_changes["to_change"],
columns_to_change=schema_changes["modified_columns"],
)

if schema_changes["to_remove"]:
raise exceptions.InvalidArgumentCombination("Removing columns of Iceberg tables is not currently supported.")
if schema_changes["missing_columns"] and not fill_missing_columns_in_df:
raise exceptions.InvalidArgumentCombination(
f"Dropping columns of Iceberg tables is not supported: {schema_changes['missing_columns']}. "
"Please use `fill_missing_columns_in_df=True` to fill missing columns with N/A."
)

for statement in sql_statements:
query_execution_id: str = _start_query_execution(
Expand Down Expand Up @@ -208,6 +215,7 @@ def to_iceberg(
dtype: dict[str, str] | None = None,
catalog_id: str | None = None,
schema_evolution: bool = False,
fill_missing_columns_in_df: bool = True,
glue_table_settings: GlueTableSettings | None = None,
) -> None:
"""
Expand Down Expand Up @@ -267,8 +275,14 @@ def to_iceberg(
catalog_id : str, optional
The ID of the Data Catalog from which to retrieve Databases.
If none is provided, the AWS account ID is used by default
schema_evolution: bool
If True allows schema evolution for new columns or changes in column types.
schema_evolution: bool, optional
If ``True`` allows schema evolution for new columns or changes in column types.
    Columns that are present in the Iceberg schema but missing from the DataFrame
    will cause an error to be raised unless ``fill_missing_columns_in_df`` is set to ``True``.
Default is ``False``.
fill_missing_columns_in_df: bool, optional
    If ``True``, fill columns that are missing in the DataFrame with ``NULL`` values.
Default is ``True``.
    glue_table_settings: GlueTableSettings, optional
Glue/Athena catalog: Settings for writing to the Glue table.
Currently only the 'columns_comments' attribute is supported for this function.
Expand Down Expand Up @@ -329,7 +343,7 @@ def to_iceberg(
df=df,
database=database,
table=table,
path=table_location, # type: ignore[arg-type]
path=table_location,
wg_config=wg_config,
partition_cols=partition_cols,
additional_table_properties=additional_table_properties,
Expand All @@ -343,7 +357,7 @@ def to_iceberg(
columns_comments=glue_table_settings.get("columns_comments"),
)
else:
schema_differences = _determine_differences(
schema_differences, catalog_cols = _determine_differences(
df=df,
database=database,
table=table,
Expand All @@ -353,13 +367,27 @@ def to_iceberg(
dtype=dtype,
catalog_id=catalog_id,
)

# Add missing columns to the DataFrame
if fill_missing_columns_in_df and schema_differences["missing_columns"]:
for col_name, col_type in schema_differences["missing_columns"].items():
df[col_name] = None
df[col_name] = df[col_name].astype(_data_types.athena2pandas(col_type))

schema_differences["missing_columns"] = {}

# Ensure that the ordering of the DF is the same as in the catalog.
# This is required for the INSERT command to work.
df = df[catalog_cols]

if schema_evolution is False and any([schema_differences[x] for x in schema_differences]): # type: ignore[literal-required]
raise exceptions.InvalidArgumentValue(f"Schema change detected: {schema_differences}")

_alter_iceberg_table(
database=database,
table=table,
schema_changes=schema_differences,
fill_missing_columns_in_df=fill_missing_columns_in_df,
wg_config=wg_config,
data_source=data_source,
workgroup=workgroup,
Expand Down
Loading

0 comments on commit 498b586

Please sign in to comment.