diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py
index f1dd622efb746..559311d048093 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py
@@ -75,7 +75,12 @@ class DataLakeSourceConfig(
         default=100,
         description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
     )
 
+    add_partition_columns_to_schema: bool = Field(
+        default=False,
+        description="Whether to add partition fields to the schema.",
+    )
+
     verify_ssl: Union[bool, str] = Field(
         default=True,
         description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py
index ac4433b7eb1f0..47af884169230 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py
@@ -77,6 +77,7 @@
     NullTypeClass,
     NumberTypeClass,
     RecordTypeClass,
+    SchemaField,
     SchemaFieldDataType,
     SchemaMetadata,
     StringTypeClass,
@@ -90,6 +91,7 @@
     OperationTypeClass,
     OtherSchemaClass,
+    SchemaFieldDataTypeClass,
     _Aspect,
 )
 from datahub.telemetry import stats, telemetry
 from datahub.utilities.perf_timer import PerfTimer
@@ -452,8 +454,50 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
         logger.debug(f"Extracted fields in schema: {fields}")
         fields = sorted(fields, key=lambda f: f.fieldPath)
 
+        if self.source_config.add_partition_columns_to_schema:
+            self.add_partition_columns_to_schema(
+                path_spec=path_spec,
+                full_path=table_data.full_path,
+                fields=fields,
+            )
+
         return fields
 
+    def add_partition_columns_to_schema(
+        self, path_spec: PathSpec, full_path: str, fields: List[SchemaField]
+    ) -> None:
+        """Append one string-typed field per partition key found in ``full_path``.
+
+        ``fields`` is mutated in place. Partition keys are read from the
+        ``partition_key`` entry of ``path_spec.get_named_vars(full_path)``;
+        nothing is added when that entry is absent.
+        """
+        # Follow the fieldPath convention already used by the inferred fields.
+        is_fieldpath_v2 = any(
+            field.fieldPath.startswith("[version=2.0]") for field in fields
+        )
+        named_vars = path_spec.get_named_vars(full_path)
+        if named_vars is None or "partition_key" not in named_vars:
+            return
+
+        for partition_key in named_vars["partition_key"].values():
+            if is_fieldpath_v2:
+                field_path = f"[version=2.0].[type=string].{partition_key}"
+                field_type = SchemaFieldDataTypeClass(type=StringTypeClass())
+            else:
+                field_path = f"{partition_key}"
+                field_type = SchemaFieldDataType(type=StringTypeClass())
+            fields.append(
+                SchemaField(
+                    fieldPath=field_path,
+                    nativeDataType="string",
+                    type=field_type,
+                    isPartitioningKey=True,
+                    nullable=True,
+                    recursive=False,
+                )
+            )
+
     def get_table_profile(
         self, table_data: TableData, dataset_urn: str
     ) -> Iterable[MetadataWorkUnit]: