Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest/s3): S3 add partition to schema #8900

Merged
merged 11 commits into from
Oct 21, 2023
5 changes: 4 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/s3/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ class DataLakeSourceConfig(
default=100,
description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
)

add_partition_columns_to_schema: bool = Field(
default=False,
description="Whether to add partition fields to the schema.",
)
verify_ssl: Union[bool, str] = Field(
default=True,
description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
Expand Down
33 changes: 33 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/s3/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
NullTypeClass,
NumberTypeClass,
RecordTypeClass,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
Expand All @@ -90,6 +91,7 @@
OperationClass,
OperationTypeClass,
OtherSchemaClass,
SchemaFieldDataTypeClass,
_Aspect,
)
from datahub.telemetry import stats, telemetry
Expand Down Expand Up @@ -458,8 +460,39 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
logger.debug(f"Extracted fields in schema: {fields}")
fields = sorted(fields, key=lambda f: f.fieldPath)

if self.source_config.add_partition_columns_to_schema:
self.add_partition_columns_to_schema(
fields=fields, path_spec=path_spec, full_path=table_data.full_path
)

return fields

def add_partition_columns_to_schema(
    self, path_spec: PathSpec, full_path: str, fields: List[SchemaField]
) -> None:
    """Append partition-key columns parsed from *full_path* to *fields* in place.

    Partition values are extracted via ``path_spec.get_named_vars`` under the
    ``partition_key`` group. Each partition column is added as a nullable
    string field flagged with ``isPartitioningKey=True``.

    Args:
        path_spec: Path spec used to extract named variables from the path.
        full_path: Full S3 path of the table data file.
        fields: Existing schema fields; mutated in place (no return value).
    """
    # Match the field-path convention already used by the inferred schema:
    # if any existing field uses v2 paths, emit v2 paths for partitions too.
    is_fieldpath_v2 = any(
        field.fieldPath.startswith("[version=2.0]") for field in fields
    )

    # NOTE: renamed from `vars` to avoid shadowing the `vars` builtin.
    named_vars = path_spec.get_named_vars(full_path)
    # Guard clause: nothing to add when the path yields no partition keys.
    if named_vars is None or "partition_key" not in named_vars:
        return

    for partition_key in named_vars["partition_key"].values():
        fields.append(
            SchemaField(
                fieldPath=(
                    f"[version=2.0].[type=string].{partition_key}"
                    if is_fieldpath_v2
                    else str(partition_key)
                ),
                nativeDataType="string",
                # v2 field paths pair with the *Class-suffixed wrapper type.
                type=(
                    SchemaFieldDataTypeClass(type=StringTypeClass())
                    if is_fieldpath_v2
                    else SchemaFieldDataType(StringTypeClass())
                ),
                isPartitioningKey=True,
                nullable=True,
                recursive=False,
            )
        )

def get_table_profile(
self, table_data: TableData, dataset_urn: str
) -> Iterable[MetadataWorkUnit]:
Expand Down
Loading