diff --git a/datahub-web-react/src/app/ingest/source/builder/constants.ts b/datahub-web-react/src/app/ingest/source/builder/constants.ts
index 61667a941765c..dba8e8bb1dce6 100644
--- a/datahub-web-react/src/app/ingest/source/builder/constants.ts
+++ b/datahub-web-react/src/app/ingest/source/builder/constants.ts
@@ -27,6 +27,7 @@ import powerbiLogo from '../../../../images/powerbilogo.png';
import modeLogo from '../../../../images/modelogo.png';
import databricksLogo from '../../../../images/databrickslogo.png';
import verticaLogo from '../../../../images/verticalogo.png';
+import mlflowLogo from '../../../../images/mlflowlogo.png';
import dynamodbLogo from '../../../../images/dynamodblogo.png';
export const ATHENA = 'athena';
@@ -64,6 +65,8 @@ export const MARIA_DB = 'mariadb';
export const MARIA_DB_URN = `urn:li:dataPlatform:${MARIA_DB}`;
export const METABASE = 'metabase';
export const METABASE_URN = `urn:li:dataPlatform:${METABASE}`;
+export const MLFLOW = 'mlflow';
+export const MLFLOW_URN = `urn:li:dataPlatform:${MLFLOW}`;
export const MODE = 'mode';
export const MODE_URN = `urn:li:dataPlatform:${MODE}`;
export const MONGO_DB = 'mongodb';
@@ -119,6 +122,7 @@ export const PLATFORM_URN_TO_LOGO = {
[LOOKER_URN]: lookerLogo,
[MARIA_DB_URN]: mariadbLogo,
[METABASE_URN]: metabaseLogo,
+ [MLFLOW_URN]: mlflowLogo,
[MODE_URN]: modeLogo,
[MONGO_DB_URN]: mongodbLogo,
[MSSQL_URN]: mssqlLogo,
diff --git a/datahub-web-react/src/app/ingest/source/builder/sources.json b/datahub-web-react/src/app/ingest/source/builder/sources.json
index b4ea2db018bd8..1bd5b6f1f768b 100644
--- a/datahub-web-react/src/app/ingest/source/builder/sources.json
+++ b/datahub-web-react/src/app/ingest/source/builder/sources.json
@@ -181,6 +181,13 @@
"docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/metabase/",
"recipe": "source:\n type: metabase\n config:\n # Coordinates\n connect_uri:\n\n # Credentials\n username: root\n password: example"
},
+ {
+ "urn": "urn:li:dataPlatform:mlflow",
+ "name": "mlflow",
+ "displayName": "MLflow",
+ "docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/mlflow/",
+ "recipe": "source:\n type: mlflow\n config:\n tracking_uri: tracking_uri"
+ },
{
"urn": "urn:li:dataPlatform:mode",
"name": "mode",
diff --git a/datahub-web-react/src/images/mlflowlogo.png b/datahub-web-react/src/images/mlflowlogo.png
new file mode 100644
index 0000000000000..e724d1affbc14
Binary files /dev/null and b/datahub-web-react/src/images/mlflowlogo.png differ
diff --git a/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md
new file mode 100644
index 0000000000000..fc499a7a3b2b8
--- /dev/null
+++ b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md
@@ -0,0 +1,9 @@
+### Concept Mapping
+
+This ingestion source maps the following MLflow Concepts to DataHub Concepts:
+
+| Source Concept | DataHub Concept | Notes |
+|:---------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [`Registered Model`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModelGroup`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodelgroup/) | The name of a Model Group is the same as a Registered Model's name (e.g. my_mlflow_model) |
+| [`Model Version`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModel`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodel/) | The name of a Model is `{registered_model_name}{model_name_separator}{model_version}` (e.g. my_mlflow_model_1 for Registered Model named my_mlflow_model and Version 1, my_mlflow_model_2, etc.) |
+| [`Model Stage`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`Tag`](https://datahubproject.io/docs/generated/metamodel/entities/tag/) | The mapping between Model Stages and generated Tags is the following:
- Production: mlflow_production
- Staging: mlflow_staging
- Archived: mlflow_archived
- None: mlflow_none |
diff --git a/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml b/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml
new file mode 100644
index 0000000000000..e40be54346629
--- /dev/null
+++ b/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml
@@ -0,0 +1,8 @@
+source:
+ type: mlflow
+ config:
+ # Coordinates
+ tracking_uri: tracking_uri
+
+sink:
+ # sink configs
diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py
index e748461b156ae..9073a20f9f84f 100644
--- a/metadata-ingestion/setup.py
+++ b/metadata-ingestion/setup.py
@@ -345,6 +345,7 @@ def get_long_description():
"looker": looker_common,
"lookml": looker_common,
"metabase": {"requests"} | sqllineage_lib,
+ "mlflow": {"mlflow-skinny>=2.3.0"},
"mode": {"requests", "tenacity>=8.0.1"} | sqllineage_lib,
"mongodb": {"pymongo[srv]>=3.11", "packaging"},
"mssql": sql_common | {"sqlalchemy-pytds>=0.3"},
@@ -474,6 +475,7 @@ def get_long_description():
"elasticsearch",
"feast" if sys.version_info >= (3, 8) else None,
"iceberg" if sys.version_info >= (3, 8) else None,
+ "mlflow" if sys.version_info >= (3, 8) else None,
"json-schema",
"ldap",
"looker",
@@ -573,6 +575,7 @@ def get_long_description():
"lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource",
"datahub-lineage-file = datahub.ingestion.source.metadata.lineage:LineageFileSource",
"datahub-business-glossary = datahub.ingestion.source.metadata.business_glossary:BusinessGlossaryFileSource",
+ "mlflow = datahub.ingestion.source.mlflow:MLflowSource",
"mode = datahub.ingestion.source.mode:ModeSource",
"mongodb = datahub.ingestion.source.mongodb:MongoDBSource",
"mssql = datahub.ingestion.source.sql.mssql:SQLServerSource",
diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py
new file mode 100644
index 0000000000000..0668defe7b0c6
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py
@@ -0,0 +1,321 @@
+import sys
+
+if sys.version_info < (3, 8):
+ raise ImportError("MLflow is only supported on Python 3.8+")
+
+
+from dataclasses import dataclass
+from typing import Any, Callable, Iterable, Optional, TypeVar, Union
+
+from mlflow import MlflowClient
+from mlflow.entities import Run
+from mlflow.entities.model_registry import ModelVersion, RegisteredModel
+from mlflow.store.entities import PagedList
+from pydantic.fields import Field
+
+import datahub.emitter.mce_builder as builder
+from datahub.configuration.source_common import EnvConfigMixin
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.ingestion.api.common import PipelineContext
+from datahub.ingestion.api.decorators import (
+ SupportStatus,
+ capability,
+ config_class,
+ platform_name,
+ support_status,
+)
+from datahub.ingestion.api.source import Source, SourceCapability, SourceReport
+from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.metadata.schema_classes import (
+ GlobalTagsClass,
+ MLHyperParamClass,
+ MLMetricClass,
+ MLModelGroupPropertiesClass,
+ MLModelPropertiesClass,
+ TagAssociationClass,
+ TagPropertiesClass,
+ VersionTagClass,
+ _Aspect,
+)
+
+T = TypeVar("T")
+
+
+class MLflowConfig(EnvConfigMixin):
+ tracking_uri: Optional[str] = Field(
+ default=None,
+ description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)",
+ )
+ registry_uri: Optional[str] = Field(
+ default=None,
+ description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)",
+ )
+ model_name_separator: str = Field(
+ default="_",
+ description="A string which separates model name from its version (e.g. model_1 or model-1)",
+ )
+
+
+@dataclass
+class MLflowRegisteredModelStageInfo:
+ name: str
+ description: str
+ color_hex: str
+
+
+@platform_name("MLflow")
+@config_class(MLflowConfig)
+@support_status(SupportStatus.TESTING)
+@capability(
+ SourceCapability.DESCRIPTIONS,
+ "Extract descriptions for MLflow Registered Models and Model Versions",
+)
+@capability(SourceCapability.TAGS, "Extract tags for MLflow Registered Model Stages")
+class MLflowSource(Source):
+ platform = "mlflow"
+ registered_model_stages_info = (
+ MLflowRegisteredModelStageInfo(
+ name="Production",
+ description="Production Stage for an ML model in MLflow Model Registry",
+ color_hex="#308613",
+ ),
+ MLflowRegisteredModelStageInfo(
+ name="Staging",
+ description="Staging Stage for an ML model in MLflow Model Registry",
+ color_hex="#FACB66",
+ ),
+ MLflowRegisteredModelStageInfo(
+ name="Archived",
+ description="Archived Stage for an ML model in MLflow Model Registry",
+ color_hex="#5D7283",
+ ),
+ MLflowRegisteredModelStageInfo(
+ name="None",
+ description="None Stage for an ML model in MLflow Model Registry",
+ color_hex="#F2F4F5",
+ ),
+ )
+
+ def __init__(self, ctx: PipelineContext, config: MLflowConfig):
+ super().__init__(ctx)
+ self.config = config
+ self.report = SourceReport()
+ self.client = MlflowClient(
+ tracking_uri=self.config.tracking_uri,
+ registry_uri=self.config.registry_uri,
+ )
+
+ def get_report(self) -> SourceReport:
+ return self.report
+
+ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
+ yield from self._get_tags_workunits()
+ yield from self._get_ml_model_workunits()
+
+ def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]:
+ """
+ Create tags for each Stage in MLflow Model Registry.
+ """
+ for stage_info in self.registered_model_stages_info:
+ tag_urn = self._make_stage_tag_urn(stage_info.name)
+ tag_properties = TagPropertiesClass(
+ name=self._make_stage_tag_name(stage_info.name),
+ description=stage_info.description,
+ colorHex=stage_info.color_hex,
+ )
+ wu = self._create_workunit(urn=tag_urn, aspect=tag_properties)
+ yield wu
+
+ def _make_stage_tag_urn(self, stage_name: str) -> str:
+ tag_name = self._make_stage_tag_name(stage_name)
+ tag_urn = builder.make_tag_urn(tag_name)
+ return tag_urn
+
+ def _make_stage_tag_name(self, stage_name: str) -> str:
+ return f"{self.platform}_{stage_name.lower()}"
+
+ def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit:
+ """
+ Utility to create an MCP workunit.
+ """
+ return MetadataChangeProposalWrapper(
+ entityUrn=urn,
+ aspect=aspect,
+ ).as_workunit()
+
+ def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]:
+ """
+ Traverse each Registered Model in Model Registry and generate a corresponding workunit.
+ """
+ registered_models = self._get_mlflow_registered_models()
+ for registered_model in registered_models:
+ yield self._get_ml_group_workunit(registered_model)
+ model_versions = self._get_mlflow_model_versions(registered_model)
+ for model_version in model_versions:
+ run = self._get_mlflow_run(model_version)
+ yield self._get_ml_model_properties_workunit(
+ registered_model=registered_model,
+ model_version=model_version,
+ run=run,
+ )
+ yield self._get_global_tags_workunit(model_version=model_version)
+
+ def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]:
+ """
+ Get all Registered Models in MLflow Model Registry.
+ """
+ registered_models: Iterable[
+ RegisteredModel
+ ] = self._traverse_mlflow_search_func(
+ search_func=self.client.search_registered_models,
+ )
+ return registered_models
+
+ @staticmethod
+ def _traverse_mlflow_search_func(
+ search_func: Callable[..., PagedList[T]],
+ **kwargs: Any,
+ ) -> Iterable[T]:
+ """
+ Utility to traverse an MLflow search_* functions which return PagedList.
+ """
+ next_page_token = None
+ while True:
+ paged_list = search_func(page_token=next_page_token, **kwargs)
+ yield from paged_list.to_list()
+ next_page_token = paged_list.token
+ if not next_page_token:
+ return
+
+ def _get_ml_group_workunit(
+ self,
+ registered_model: RegisteredModel,
+ ) -> MetadataWorkUnit:
+ """
+ Generate an MLModelGroup workunit for an MLflow Registered Model.
+ """
+ ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
+ ml_model_group_properties = MLModelGroupPropertiesClass(
+ customProperties=registered_model.tags,
+ description=registered_model.description,
+ createdAt=registered_model.creation_timestamp,
+ )
+ wu = self._create_workunit(
+ urn=ml_model_group_urn,
+ aspect=ml_model_group_properties,
+ )
+ return wu
+
+ def _make_ml_model_group_urn(self, registered_model: RegisteredModel) -> str:
+ urn = builder.make_ml_model_group_urn(
+ platform=self.platform,
+ group_name=registered_model.name,
+ env=self.config.env,
+ )
+ return urn
+
+ def _get_mlflow_model_versions(
+ self,
+ registered_model: RegisteredModel,
+ ) -> Iterable[ModelVersion]:
+ """
+ Get all Model Versions for each Registered Model.
+ """
+ filter_string = f"name = '{registered_model.name}'"
+ model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func(
+ search_func=self.client.search_model_versions,
+ filter_string=filter_string,
+ )
+ return model_versions
+
+ def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]:
+ """
+ Get a Run associated with a Model Version. Some MVs may exist without Run.
+ """
+ if model_version.run_id:
+ run = self.client.get_run(model_version.run_id)
+ return run
+ else:
+ return None
+
+ def _get_ml_model_properties_workunit(
+ self,
+ registered_model: RegisteredModel,
+ model_version: ModelVersion,
+ run: Union[None, Run],
+ ) -> MetadataWorkUnit:
+ """
+ Generate an MLModel workunit for an MLflow Model Version.
+ Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model.
+ If a model was registered without an associated Run then hyperparams and metrics are not available.
+ """
+ ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
+ ml_model_urn = self._make_ml_model_urn(model_version)
+ if run:
+ hyperparams = [
+ MLHyperParamClass(name=k, value=str(v))
+ for k, v in run.data.params.items()
+ ]
+ training_metrics = [
+ MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()
+ ]
+ else:
+ hyperparams = None
+ training_metrics = None
+ ml_model_properties = MLModelPropertiesClass(
+ customProperties=model_version.tags,
+ externalUrl=self._make_external_url(model_version),
+ description=model_version.description,
+ date=model_version.creation_timestamp,
+ version=VersionTagClass(versionTag=str(model_version.version)),
+ hyperParams=hyperparams,
+ trainingMetrics=training_metrics,
+ # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags
+ tags=list(model_version.tags.keys()),
+ groups=[ml_model_group_urn],
+ )
+ wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties)
+ return wu
+
+ def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
+ urn = builder.make_ml_model_urn(
+ platform=self.platform,
+ model_name=f"{model_version.name}{self.config.model_name_separator}{model_version.version}",
+ env=self.config.env,
+ )
+ return urn
+
+ def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]:
+ """
+ Generate URL for a Model Version to MLflow UI.
+ """
+ base_uri = self.client.tracking_uri
+ if base_uri.startswith("http"):
+ return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}"
+ else:
+ return None
+
+ def _get_global_tags_workunit(
+ self,
+ model_version: ModelVersion,
+ ) -> MetadataWorkUnit:
+ """
+ Associate a Model Version Stage with a corresponding tag.
+ """
+ global_tags = GlobalTagsClass(
+ tags=[
+ TagAssociationClass(
+ tag=self._make_stage_tag_urn(model_version.current_stage),
+ ),
+ ]
+ )
+ wu = self._create_workunit(
+ urn=self._make_ml_model_urn(model_version),
+ aspect=global_tags,
+ )
+ return wu
+
+ @classmethod
+ def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
+ config = MLflowConfig.parse_obj(config_dict)
+ return cls(ctx, config)
diff --git a/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json
new file mode 100644
index 0000000000000..c70625c74d998
--- /dev/null
+++ b/metadata-ingestion/tests/integration/mlflow/mlflow_mcps_golden.json
@@ -0,0 +1,238 @@
+[
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_production",
+ "changeType": "UPSERT",
+ "aspectName": "tagProperties",
+ "aspect": {
+ "json": {
+ "name": "mlflow_production",
+ "description": "Production Stage for an ML model in MLflow Model Registry",
+ "colorHex": "#308613"
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_staging",
+ "changeType": "UPSERT",
+ "aspectName": "tagProperties",
+ "aspect": {
+ "json": {
+ "name": "mlflow_staging",
+ "description": "Staging Stage for an ML model in MLflow Model Registry",
+ "colorHex": "#FACB66"
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_archived",
+ "changeType": "UPSERT",
+ "aspectName": "tagProperties",
+ "aspect": {
+ "json": {
+ "name": "mlflow_archived",
+ "description": "Archived Stage for an ML model in MLflow Model Registry",
+ "colorHex": "#5D7283"
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_none",
+ "changeType": "UPSERT",
+ "aspectName": "tagProperties",
+ "aspect": {
+ "json": {
+ "name": "mlflow_none",
+ "description": "None Stage for an ML model in MLflow Model Registry",
+ "colorHex": "#F2F4F5"
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "mlModelGroup",
+ "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)",
+ "changeType": "UPSERT",
+ "aspectName": "mlModelGroupProperties",
+ "aspect": {
+ "json": {
+ "customProperties": {
+ "model_env": "test",
+ "model_id": "1"
+ },
+ "description": "This a test registered model",
+ "createdAt": 1615443388097
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "mlModel",
+ "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)",
+ "changeType": "UPSERT",
+ "aspectName": "mlModelProperties",
+ "aspect": {
+ "json": {
+ "customProperties": {
+ "model_version_id": "1"
+ },
+ "date": 1615443388097,
+ "version": {
+ "versionTag": "1"
+ },
+ "hyperParams": [
+ {
+ "name": "p",
+ "value": "1"
+ }
+ ],
+ "trainingMetrics": [
+ {
+ "name": "m",
+ "value": "0.85"
+ }
+ ],
+ "tags": [
+ "model_version_id"
+ ],
+ "groups": [
+ "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)"
+ ]
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "mlModel",
+ "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)",
+ "changeType": "UPSERT",
+ "aspectName": "globalTags",
+ "aspect": {
+ "json": {
+ "tags": [
+ {
+ "tag": "urn:li:tag:mlflow_archived"
+ }
+ ]
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "mlModel",
+ "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "mlModelGroup",
+ "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_staging",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_archived",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_production",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+},
+{
+ "entityType": "tag",
+ "entityUrn": "urn:li:tag:mlflow_none",
+ "changeType": "UPSERT",
+ "aspectName": "status",
+ "aspect": {
+ "json": {
+ "removed": false
+ }
+ },
+ "systemMetadata": {
+ "lastObserved": 1615443388097,
+ "runId": "mlflow-source-test"
+ }
+}
+]
\ No newline at end of file
diff --git a/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py
new file mode 100644
index 0000000000000..76af666526555
--- /dev/null
+++ b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py
@@ -0,0 +1,104 @@
+import sys
+
+if sys.version_info >= (3, 8):
+ from pathlib import Path
+ from typing import Any, Dict, TypeVar
+
+ import pytest
+ from mlflow import MlflowClient
+
+ from datahub.ingestion.run.pipeline import Pipeline
+ from tests.test_helpers import mce_helpers
+
+ T = TypeVar("T")
+
+ @pytest.fixture
+ def tracking_uri(tmp_path: Path) -> str:
+ return str(tmp_path / "mlruns")
+
+ @pytest.fixture
+ def sink_file_path(tmp_path: Path) -> str:
+ return str(tmp_path / "mlflow_source_mcps.json")
+
+ @pytest.fixture
+ def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]:
+ source_type = "mlflow"
+ return {
+ "run_id": "mlflow-source-test",
+ "source": {
+ "type": source_type,
+ "config": {
+ "tracking_uri": tracking_uri,
+ },
+ },
+ "sink": {
+ "type": "file",
+ "config": {
+ "filename": sink_file_path,
+ },
+ },
+ }
+
+ @pytest.fixture
+ def generate_mlflow_data(tracking_uri: str) -> None:
+ client = MlflowClient(tracking_uri=tracking_uri)
+ experiment_name = "test-experiment"
+ run_name = "test-run"
+ model_name = "test-model"
+ test_experiment_id = client.create_experiment(experiment_name)
+ test_run = client.create_run(
+ experiment_id=test_experiment_id,
+ run_name=run_name,
+ )
+ client.log_param(
+ run_id=test_run.info.run_id,
+ key="p",
+ value=1,
+ )
+ client.log_metric(
+ run_id=test_run.info.run_id,
+ key="m",
+ value=0.85,
+ )
+ client.create_registered_model(
+ name=model_name,
+ tags=dict(
+ model_id=1,
+ model_env="test",
+ ),
+ description="This a test registered model",
+ )
+ client.create_model_version(
+ name=model_name,
+ source="dummy_dir/dummy_file",
+ run_id=test_run.info.run_id,
+ tags=dict(model_version_id=1),
+ )
+ client.transition_model_version_stage(
+ name=model_name,
+ version="1",
+ stage="Archived",
+ )
+
+ def test_ingestion(
+ pytestconfig,
+ mock_time,
+ sink_file_path,
+ pipeline_config,
+ generate_mlflow_data,
+ ):
+ print(f"MCPs file path: {sink_file_path}")
+ golden_file_path = (
+ pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json"
+ )
+
+ pipeline = Pipeline.create(pipeline_config)
+ pipeline.run()
+ pipeline.pretty_print_summary()
+ pipeline.raise_from_status()
+
+ mce_helpers.check_golden_file(
+ pytestconfig=pytestconfig,
+ output_path=sink_file_path,
+ golden_path=golden_file_path,
+ )
diff --git a/metadata-ingestion/tests/unit/test_mlflow_source.py b/metadata-ingestion/tests/unit/test_mlflow_source.py
new file mode 100644
index 0000000000000..97b5afd3d6a4e
--- /dev/null
+++ b/metadata-ingestion/tests/unit/test_mlflow_source.py
@@ -0,0 +1,133 @@
+import sys
+
+if sys.version_info >= (3, 8):
+ import datetime
+ from pathlib import Path
+ from typing import Any, TypeVar, Union
+
+ import pytest
+ from mlflow import MlflowClient
+ from mlflow.entities.model_registry import RegisteredModel
+ from mlflow.entities.model_registry.model_version import ModelVersion
+ from mlflow.store.entities import PagedList
+
+ from datahub.ingestion.api.common import PipelineContext
+ from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource
+
+ T = TypeVar("T")
+
+ @pytest.fixture
+ def tracking_uri(tmp_path: Path) -> str:
+ return str(tmp_path / "mlruns")
+
+ @pytest.fixture
+ def source(tracking_uri: str) -> MLflowSource:
+ return MLflowSource(
+ ctx=PipelineContext(run_id="mlflow-source-test"),
+ config=MLflowConfig(tracking_uri=tracking_uri),
+ )
+
+ @pytest.fixture
+ def registered_model(source: MLflowSource) -> RegisteredModel:
+ model_name = "abc"
+ return RegisteredModel(name=model_name)
+
+ @pytest.fixture
+ def model_version(
+ source: MLflowSource,
+ registered_model: RegisteredModel,
+ ) -> ModelVersion:
+ version = "1"
+ return ModelVersion(
+ name=registered_model.name,
+ version=version,
+ creation_timestamp=datetime.datetime.now(),
+ )
+
+ def dummy_search_func(page_token: Union[None, str], **kwargs: Any) -> PagedList[T]:
+ dummy_pages = dict(
+ page_1=PagedList(items=["a", "b"], token="page_2"),
+ page_2=PagedList(items=["c", "d"], token="page_3"),
+ page_3=PagedList(items=["e"], token=None),
+ )
+ if page_token is None:
+ page_to_return = dummy_pages["page_1"]
+ else:
+ page_to_return = dummy_pages[page_token]
+ if kwargs.get("case", "") == "upper":
+ page_to_return = PagedList(
+ items=[e.upper() for e in page_to_return.to_list()],
+ token=page_to_return.token,
+ )
+ return page_to_return
+
+ def test_stages(source):
+ mlflow_registered_model_stages = {
+ "Production",
+ "Staging",
+ "Archived",
+ None,
+ }
+ workunits = source._get_tags_workunits()
+ names = [wu.get_metadata()["metadata"].aspect.name for wu in workunits]
+
+ assert len(names) == len(mlflow_registered_model_stages)
+ assert set(names) == {
+ "mlflow_" + str(stage).lower() for stage in mlflow_registered_model_stages
+ }
+
+ def test_config_model_name_separator(source, model_version):
+ name_version_sep = "+"
+ source.config.model_name_separator = name_version_sep
+ expected_model_name = (
+ f"{model_version.name}{name_version_sep}{model_version.version}"
+ )
+ expected_urn = f"urn:li:mlModel:(urn:li:dataPlatform:mlflow,{expected_model_name},{source.config.env})"
+
+ urn = source._make_ml_model_urn(model_version)
+
+ assert urn == expected_urn
+
+ def test_model_without_run(source, registered_model, model_version):
+ run = source._get_mlflow_run(model_version)
+ wu = source._get_ml_model_properties_workunit(
+ registered_model=registered_model,
+ model_version=model_version,
+ run=run,
+ )
+ aspect = wu.get_metadata()["metadata"].aspect
+
+ assert aspect.hyperParams is None
+ assert aspect.trainingMetrics is None
+
+ def test_traverse_mlflow_search_func(source):
+ expected_items = ["a", "b", "c", "d", "e"]
+
+ items = list(source._traverse_mlflow_search_func(dummy_search_func))
+
+ assert items == expected_items
+
+ def test_traverse_mlflow_search_func_with_kwargs(source):
+ expected_items = ["A", "B", "C", "D", "E"]
+
+ items = list(
+ source._traverse_mlflow_search_func(dummy_search_func, case="upper")
+ )
+
+ assert items == expected_items
+
+ def test_make_external_link_local(source, model_version):
+ expected_url = None
+
+ url = source._make_external_url(model_version)
+
+ assert url == expected_url
+
+ def test_make_external_link_remote(source, model_version):
+ tracking_uri_remote = "https://dummy-mlflow-tracking-server.org"
+ source.client = MlflowClient(tracking_uri=tracking_uri_remote)
+ expected_url = f"{tracking_uri_remote}/#/models/{model_version.name}/versions/{model_version.version}"
+
+ url = source._make_external_url(model_version)
+
+ assert url == expected_url
diff --git a/metadata-service/war/src/main/resources/boot/data_platforms.json b/metadata-service/war/src/main/resources/boot/data_platforms.json
index 7a7cec60aa25f..3d956c5774ded 100644
--- a/metadata-service/war/src/main/resources/boot/data_platforms.json
+++ b/metadata-service/war/src/main/resources/boot/data_platforms.json
@@ -346,6 +346,16 @@
"logoUrl": "/assets/platforms/sagemakerlogo.png"
}
},
+ {
+ "urn": "urn:li:dataPlatform:mlflow",
+ "aspect": {
+ "datasetNameDelimiter": ".",
+ "name": "mlflow",
+ "displayName": "MLflow",
+ "type": "OTHERS",
+ "logoUrl": "/assets/platforms/mlflowlogo.png"
+ }
+ },
{
"urn": "urn:li:dataPlatform:glue",
"aspect": {