Option to add s3 partitions to schema
treff7es committed Sep 26, 2023
1 parent 44820d7 commit 98ac100
Showing 2 changed files with 36 additions and 2 deletions.
metadata-ingestion/src/datahub/ingestion/source/s3/config.py (4 additions, 1 deletion)
@@ -75,7 +75,10 @@ class DataLakeSourceConfig(
default=100,
description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
)

add_partition_columns_to_schema: bool = Field(
default=False,
description="Whether to add partition fields to the schema.",
)
verify_ssl: Union[bool, str] = Field(
default=True,
description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
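For context, a minimal usage sketch of the new flag. This is hypothetical, not part of the commit: the bucket, table path, and partition pattern are illustrative, and it assumes DataLakeSourceConfig is a pydantic model whose parse_obj accepts a plain dict.

# Hypothetical sketch: enabling the new flag alongside a path_spec whose
# include pattern names partition keys. Only add_partition_columns_to_schema
# comes from this commit; the path values are made-up examples.
from datahub.ingestion.source.s3.config import DataLakeSourceConfig

config = DataLakeSourceConfig.parse_obj(
    {
        "path_specs": [
            {
                "include": "s3://example-bucket/data/{table}/{partition_key[0]}={partition[0]}/*.parquet"
            }
        ],
        "add_partition_columns_to_schema": True,
    }
)
assert config.add_partition_columns_to_schema is True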
metadata-ingestion/src/datahub/ingestion/source/s3/source.py (32 additions, 1 deletion)
@@ -77,10 +77,11 @@
NullTypeClass,
NumberTypeClass,
RecordTypeClass,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
TimeTypeClass,
)
from datahub.metadata.schema_classes import (
DataPlatformInstanceClass,
@@ -90,6 +91,7 @@
OperationTypeClass,
OtherSchemaClass,
SchemaFieldDataTypeClass,
_Aspect,
)
from datahub.telemetry import stats, telemetry
from datahub.utilities.perf_timer import PerfTimer
@@ -452,8 +454,37 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
logger.debug(f"Extracted fields in schema: {fields}")
fields = sorted(fields, key=lambda f: f.fieldPath)

if self.source_config.add_partition_columns_to_schema:
    self.add_partition_columns_to_schema(path_spec, table_data.full_path, fields)

return fields

def add_partition_columns_to_schema(
    self, path_spec: PathSpec, full_path: str, fields: List[SchemaField]
) -> None:
    """Append a string-typed, nullable field to `fields` for every partition key named in the path."""
    is_fieldpath_v2 = any(
        field.fieldPath.startswith("[version=2.0]") for field in fields
    )
    named_vars = path_spec.get_named_vars(full_path)
    if named_vars and "partition_key" in named_vars:
        for partition_key in named_vars["partition_key"].values():
            fields.append(
                SchemaField(
                    fieldPath=partition_key
                    if not is_fieldpath_v2
                    else f"[version=2.0].[type=string].{partition_key}",
                    nativeDataType="string",
                    type=SchemaFieldDataType(type=StringTypeClass())
                    if not is_fieldpath_v2
                    else SchemaFieldDataTypeClass(type=StringTypeClass()),
                    isPartitioningKey=True,
                    nullable=True,
                    recursive=False,
                )
            )

def get_table_profile(
self, table_data: TableData, dataset_urn: str
) -> Iterable[MetadataWorkUnit]:
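To illustrate what the new method produces, a standalone sketch follows. Assumptions: the partition_keys dict and build_partition_fields helper are hypothetical, with partition keys already parsed from a path such as year=2023/month=09; only the SchemaFieldClass, SchemaFieldDataTypeClass, and StringTypeClass names are taken from datahub.metadata.schema_classes.

# Standalone sketch of the partition-field construction in this commit.
from typing import Dict, List

from datahub.metadata.schema_classes import (
    SchemaFieldClass,
    SchemaFieldDataTypeClass,
    StringTypeClass,
)


def build_partition_fields(
    partition_keys: Dict[str, str], is_fieldpath_v2: bool
) -> List[SchemaFieldClass]:
    # One string-typed, nullable partitioning field per partition key; v2
    # field paths get the "[version=2.0].[type=string]." prefix, mirroring
    # the method added above.
    return [
        SchemaFieldClass(
            fieldPath=f"[version=2.0].[type=string].{key}" if is_fieldpath_v2 else key,
            nativeDataType="string",
            type=SchemaFieldDataTypeClass(type=StringTypeClass()),
            isPartitioningKey=True,
            nullable=True,
            recursive=False,
        )
        for key in partition_keys
    ]


# Example: fields for year/month partitions under fieldpath v2.
fields = build_partition_fields({"year": "2023", "month": "09"}, is_fieldpath_v2=True)
print([f.fieldPath for f in fields])
# -> ['[version=2.0].[type=string].year', '[version=2.0].[type=string].month']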
