Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: to_iceberg support for filling missing columns in the DataFrame with None #2616

Merged
merged 12 commits into from
Jan 18, 2024
52 changes: 34 additions & 18 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _create_iceberg_table(
df: pd.DataFrame,
database: str,
table: str,
path: str,
path: str | None,
wg_config: _WorkGroupConfig,
partition_cols: list[str] | None,
additional_table_properties: dict[str, Any] | None,
Expand Down Expand Up @@ -80,9 +80,9 @@ def _create_iceberg_table(


class _SchemaChanges(TypedDict):
to_add: dict[str, str]
to_change: dict[str, str]
to_remove: set[str]
new_columns: dict[str, str]
modified_columns: dict[str, str]
missing_columns: dict[str, str]


def _determine_differences(
Expand All @@ -105,26 +105,27 @@ def _determine_differences(
catalog.get_table_types(database=database, table=table, catalog_id=catalog_id, boto3_session=boto3_session),
)

original_columns = set(catalog_column_types)
new_columns = set(frame_columns_types)
original_column_names = set(catalog_column_types)
new_column_names = set(frame_columns_types)

to_add = {col: frame_columns_types[col] for col in new_columns - original_columns}
to_remove = original_columns - new_columns
new_columns = {col: frame_columns_types[col] for col in new_column_names - original_column_names}
missing_columns = {col: catalog_column_types[col] for col in original_column_names - new_column_names}

columns_to_change = [
col
for col in original_columns.intersection(new_columns)
for col in original_column_names.intersection(new_column_names)
if frame_columns_types[col] != catalog_column_types[col]
]
to_change = {col: frame_columns_types[col] for col in columns_to_change}
modified_columns = {col: frame_columns_types[col] for col in columns_to_change}

return _SchemaChanges(to_add=to_add, to_change=to_change, to_remove=to_remove)
return _SchemaChanges(new_columns=new_columns, modified_columns=modified_columns, missing_columns=missing_columns)


def _alter_iceberg_table(
database: str,
table: str,
schema_changes: _SchemaChanges,
schema_fill_missing: bool,
wg_config: _WorkGroupConfig,
data_source: str | None = None,
workgroup: str | None = None,
Expand All @@ -134,20 +135,23 @@ def _alter_iceberg_table(
) -> None:
sql_statements: list[str] = []

if schema_changes["to_add"]:
if schema_changes["new_columns"]:
sql_statements += _alter_iceberg_table_add_columns_sql(
table=table,
columns_to_add=schema_changes["to_add"],
columns_to_add=schema_changes["new_columns"],
)

if schema_changes["to_change"]:
if schema_changes["modified_columns"]:
sql_statements += _alter_iceberg_table_change_columns_sql(
table=table,
columns_to_change=schema_changes["to_change"],
columns_to_change=schema_changes["modified_columns"],
)

if schema_changes["to_remove"]:
raise exceptions.InvalidArgumentCombination("Removing columns of Iceberg tables is not currently supported.")
if schema_changes["missing_columns"] and not schema_fill_missing:
raise exceptions.InvalidArgumentCombination(
f"Dropping columns of Iceberg tables is not supported: {schema_changes['missing_columns']}. "
"Please use `schema_fill_missing=True` to fill missing columns with N/A."
)

for statement in sql_statements:
query_execution_id: str = _start_query_execution(
Expand Down Expand Up @@ -208,6 +212,7 @@ def to_iceberg(
dtype: dict[str, str] | None = None,
catalog_id: str | None = None,
schema_evolution: bool = False,
schema_fill_missing: bool = False,
glue_table_settings: GlueTableSettings | None = None,
) -> None:
"""
Expand Down Expand Up @@ -269,6 +274,10 @@ def to_iceberg(
If none is provided, the AWS account ID is used by default
schema_evolution: bool
If True allows schema evolution for new columns or changes in column types.
Missing columns will throw an error unless ``schema_fill_missing`` is set to ``True``.
schema_fill_missing: bool
If True, fill missing columns with NULL values.
Only takes effect if ``schema_evolution`` is set to True.
LeonLuttenberger marked this conversation as resolved.
Show resolved Hide resolved
glue_table_settings: GlueTableSettings, optional
Glue/Athena catalog: Settings for writing to the Glue table.
Currently only the 'columns_comments' attribute is supported for this function.
Expand Down Expand Up @@ -329,7 +338,7 @@ def to_iceberg(
df=df,
database=database,
table=table,
path=table_location, # type: ignore[arg-type]
path=table_location,
wg_config=wg_config,
partition_cols=partition_cols,
additional_table_properties=additional_table_properties,
Expand Down Expand Up @@ -360,6 +369,7 @@ def to_iceberg(
database=database,
table=table,
schema_changes=schema_differences,
schema_fill_missing=schema_fill_missing,
wg_config=wg_config,
data_source=data_source,
workgroup=workgroup,
Expand All @@ -368,6 +378,12 @@ def to_iceberg(
boto3_session=boto3_session,
)

# Add missing columns to the DataFrame
if schema_differences["missing_columns"] and schema_fill_missing:
for col_name, col_type in schema_differences["missing_columns"].items():
df[col_name] = None
df[col_name] = df[col_name].astype(col_type)

# Create temporary external table, write the results
s3.to_parquet(
df=df,
Expand Down
45 changes: 44 additions & 1 deletion tests/unit/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,50 @@ def test_athena_to_iceberg_schema_evolution_modify_columns(
assert str(df2_out["c2"].dtype).startswith("Int64")


def test_athena_to_iceberg_schema_evolution_remove_columns_error(
def test_athena_to_iceberg_schema_evolution_fill_missing_columns(
    path: str, path2: str, glue_database: str, glue_table: str
) -> None:
    """Columns absent from the DataFrame are back-filled with NULL when ``schema_fill_missing=True``.

    Seeds an Iceberg table with two columns, then writes a second frame that
    is missing ``c1`` and verifies the missing column is filled with nulls
    instead of raising ``InvalidArgumentCombination``.
    """
    # Seed the Iceberg table with both c0 and c1.
    df = pd.DataFrame({"c0": [0, 1, 2], "c1": ["foo", "bar", "baz"]})
    wr.athena.to_iceberg(
        df=df,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
    )

    # Second write drops "c1"; schema_fill_missing must fill it with NULLs.
    df = pd.DataFrame({"c0": [3, 4, 5]})
    wr.athena.to_iceberg(
        df=df,
        database=glue_database,
        table=glue_table,
        table_location=path,
        temp_path=path2,
        keep_files=False,
        schema_evolution=True,
        schema_fill_missing=True,
    )

    df_actual = wr.athena.read_sql_table(
        table=glue_table,
        database=glue_database,
        ctas_approach=False,
        unload_approach=False,
    )
    # Normalize ordering and dtype so the comparison is deterministic.
    df_actual = df_actual.sort_values("c0").reset_index(drop=True)
    df_actual["c0"] = df_actual["c0"].astype("int64")

    df_expected = pd.DataFrame({"c0": [0, 1, 2, 3, 4, 5], "c1": ["foo", "bar", "baz", np.nan, np.nan, np.nan]})
    df_expected["c1"] = df_expected["c1"].astype("string")

    assert_pandas_equals(df_actual, df_expected)


def test_athena_to_iceberg_schema_evolution_drop_columns_error(
path: str, path2: str, glue_database: str, glue_table: str
) -> None:
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
Expand Down
Loading