diff --git a/datahub-web-react/src/app/ingest/source/builder/constants.ts b/datahub-web-react/src/app/ingest/source/builder/constants.ts index 61667a941765c..dba8e8bb1dce6 100644 --- a/datahub-web-react/src/app/ingest/source/builder/constants.ts +++ b/datahub-web-react/src/app/ingest/source/builder/constants.ts @@ -27,6 +27,7 @@ import powerbiLogo from '../../../../images/powerbilogo.png'; import modeLogo from '../../../../images/modelogo.png'; import databricksLogo from '../../../../images/databrickslogo.png'; import verticaLogo from '../../../../images/verticalogo.png'; +import mlflowLogo from '../../../../images/mlflowlogo.png'; import dynamodbLogo from '../../../../images/dynamodblogo.png'; export const ATHENA = 'athena'; @@ -64,6 +65,8 @@ export const MARIA_DB = 'mariadb'; export const MARIA_DB_URN = `urn:li:dataPlatform:${MARIA_DB}`; export const METABASE = 'metabase'; export const METABASE_URN = `urn:li:dataPlatform:${METABASE}`; +export const MLFLOW = 'mlflow'; +export const MLFLOW_URN = `urn:li:dataPlatform:${MLFLOW}`; export const MODE = 'mode'; export const MODE_URN = `urn:li:dataPlatform:${MODE}`; export const MONGO_DB = 'mongodb'; @@ -119,6 +122,7 @@ export const PLATFORM_URN_TO_LOGO = { [LOOKER_URN]: lookerLogo, [MARIA_DB_URN]: mariadbLogo, [METABASE_URN]: metabaseLogo, + [MLFLOW_URN]: mlflowLogo, [MODE_URN]: modeLogo, [MONGO_DB_URN]: mongodbLogo, [MSSQL_URN]: mssqlLogo, diff --git a/datahub-web-react/src/app/ingest/source/builder/sources.json b/datahub-web-react/src/app/ingest/source/builder/sources.json index b4ea2db018bd8..1bd5b6f1f768b 100644 --- a/datahub-web-react/src/app/ingest/source/builder/sources.json +++ b/datahub-web-react/src/app/ingest/source/builder/sources.json @@ -181,6 +181,13 @@ "docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/metabase/", "recipe": "source:\n type: metabase\n config:\n # Coordinates\n connect_uri:\n\n # Credentials\n username: root\n password: example" }, + { + "urn": 
"urn:li:dataPlatform:mlflow", + "name": "mlflow", + "displayName": "MLflow", + "docsUrl": "https://datahubproject.io/docs/generated/ingestion/sources/mlflow/", + "recipe": "source:\n type: mlflow\n config:\n tracking_uri: tracking_uri" + }, { "urn": "urn:li:dataPlatform:mode", "name": "mode", diff --git a/datahub-web-react/src/images/mlflowlogo.png b/datahub-web-react/src/images/mlflowlogo.png new file mode 100644 index 0000000000000..e724d1affbc14 Binary files /dev/null and b/datahub-web-react/src/images/mlflowlogo.png differ diff --git a/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md new file mode 100644 index 0000000000000..fc499a7a3b2b8 --- /dev/null +++ b/metadata-ingestion/docs/sources/mlflow/mlflow_pre.md @@ -0,0 +1,9 @@ +### Concept Mapping + +This ingestion source maps the following MLflow Concepts to DataHub Concepts: + +| Source Concept | DataHub Concept | Notes | +|:---------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [`Registered Model`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModelGroup`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodelgroup/) | The name of a Model Group is the same as a Registered Model's name (e.g. my_mlflow_model) | +| [`Model Version`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`MlModel`](https://datahubproject.io/docs/generated/metamodel/entities/mlmodel/) | The name of a Model is `{registered_model_name}{model_name_separator}{model_version}` (e.g. my_mlflow_model_1 for Registered Model named my_mlflow_model and Version 1, my_mlflow_model_2, etc.) 
| +| [`Model Stage`](https://mlflow.org/docs/latest/model-registry.html#concepts) | [`Tag`](https://datahubproject.io/docs/generated/metamodel/entities/tag/) | The mapping between Model Stages and generated Tags is the following:
- Production: mlflow_production
- Staging: mlflow_staging
- Archived: mlflow_archived
- None: mlflow_none | diff --git a/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml b/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml new file mode 100644 index 0000000000000..e40be54346629 --- /dev/null +++ b/metadata-ingestion/docs/sources/mlflow/mlflow_recipe.yml @@ -0,0 +1,8 @@ +source: + type: mlflow + config: + # Coordinates + tracking_uri: tracking_uri + +sink: + # sink configs diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index e748461b156ae..9073a20f9f84f 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -345,6 +345,7 @@ def get_long_description(): "looker": looker_common, "lookml": looker_common, "metabase": {"requests"} | sqllineage_lib, + "mlflow": {"mlflow-skinny>=2.3.0"}, "mode": {"requests", "tenacity>=8.0.1"} | sqllineage_lib, "mongodb": {"pymongo[srv]>=3.11", "packaging"}, "mssql": sql_common | {"sqlalchemy-pytds>=0.3"}, @@ -474,6 +475,7 @@ def get_long_description(): "elasticsearch", "feast" if sys.version_info >= (3, 8) else None, "iceberg" if sys.version_info >= (3, 8) else None, + "mlflow" if sys.version_info >= (3, 8) else None, "json-schema", "ldap", "looker", @@ -573,6 +575,7 @@ def get_long_description(): "lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource", "datahub-lineage-file = datahub.ingestion.source.metadata.lineage:LineageFileSource", "datahub-business-glossary = datahub.ingestion.source.metadata.business_glossary:BusinessGlossaryFileSource", + "mlflow = datahub.ingestion.source.mlflow:MLflowSource", "mode = datahub.ingestion.source.mode:ModeSource", "mongodb = datahub.ingestion.source.mongodb:MongoDBSource", "mssql = datahub.ingestion.source.sql.mssql:SQLServerSource", diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py new file mode 100644 index 0000000000000..0668defe7b0c6 --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -0,0 +1,321 @@ 
import sys

if sys.version_info < (3, 8):
    raise ImportError("MLflow is only supported on Python 3.8+")


from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, TypeVar
from urllib.parse import quote

from mlflow import MlflowClient
from mlflow.entities import Run
from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.store.entities import PagedList
from pydantic.fields import Field

import datahub.emitter.mce_builder as builder
from datahub.configuration.source_common import EnvConfigMixin
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
    SupportStatus,
    capability,
    config_class,
    platform_name,
    support_status,
)
from datahub.ingestion.api.source import Source, SourceCapability, SourceReport
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.metadata.schema_classes import (
    GlobalTagsClass,
    MLHyperParamClass,
    MLMetricClass,
    MLModelGroupPropertiesClass,
    MLModelPropertiesClass,
    TagAssociationClass,
    TagPropertiesClass,
    VersionTagClass,
    _Aspect,
)

T = TypeVar("T")


class MLflowConfig(EnvConfigMixin):
    """Recipe configuration for the MLflow ingestion source."""

    tracking_uri: Optional[str] = Field(
        default=None,
        description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)",
    )
    registry_uri: Optional[str] = Field(
        default=None,
        description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)",
    )
    model_name_separator: str = Field(
        default="_",
        description="A string which separates model name from its version (e.g. model_1 or model-1)",
    )


@dataclass
class MLflowRegisteredModelStageInfo:
    """Static description of one Model Registry stage, used to build its tag."""

    # Stage name as it appears in MLflow (e.g. "Production").
    name: str
    # Human-readable description copied into the generated tag.
    description: str
    # Tag colour as a "#RRGGBB" string.
    color_hex: str


@platform_name("MLflow")
@config_class(MLflowConfig)
@support_status(SupportStatus.TESTING)
@capability(
    SourceCapability.DESCRIPTIONS,
    "Extract descriptions for MLflow Registered Models and Model Versions",
)
@capability(SourceCapability.TAGS, "Extract tags for MLflow Registered Model Stages")
class MLflowSource(Source):
    """
    Ingestion source that reads the MLflow Model Registry.

    Emits one tag per known stage, one MLModelGroup per Registered Model, and
    one MLModel (properties + stage tag) per Model Version.
    """

    platform = "mlflow"
    # Stage used when a Model Version carries no stage at all; it matches the
    # "None" entry in `registered_model_stages_info` below.
    default_stage_name = "None"
    registered_model_stages_info = (
        MLflowRegisteredModelStageInfo(
            name="Production",
            description="Production Stage for an ML model in MLflow Model Registry",
            color_hex="#308613",
        ),
        MLflowRegisteredModelStageInfo(
            name="Staging",
            description="Staging Stage for an ML model in MLflow Model Registry",
            color_hex="#FACB66",
        ),
        MLflowRegisteredModelStageInfo(
            name="Archived",
            description="Archived Stage for an ML model in MLflow Model Registry",
            color_hex="#5D7283",
        ),
        MLflowRegisteredModelStageInfo(
            name="None",
            description="None Stage for an ML model in MLflow Model Registry",
            color_hex="#F2F4F5",
        ),
    )

    def __init__(self, ctx: PipelineContext, config: MLflowConfig):
        """Store the config and build an MlflowClient from its URIs."""
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()
        self.client = MlflowClient(
            tracking_uri=self.config.tracking_uri,
            registry_uri=self.config.registry_uri,
        )

    def get_report(self) -> SourceReport:
        """Return the report object created in `__init__`."""
        return self.report

    def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
        """Yield stage-tag workunits first, then model group / model workunits."""
        yield from self._get_tags_workunits()
        yield from self._get_ml_model_workunits()

    def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]:
        """
        Create tags for each Stage in MLflow Model Registry.
        """
        for stage_info in self.registered_model_stages_info:
            tag_urn = self._make_stage_tag_urn(stage_info.name)
            tag_properties = TagPropertiesClass(
                name=self._make_stage_tag_name(stage_info.name),
                description=stage_info.description,
                colorHex=stage_info.color_hex,
            )
            yield self._create_workunit(urn=tag_urn, aspect=tag_properties)

    def _make_stage_tag_urn(self, stage_name: str) -> str:
        """Build the tag URN for a stage name (e.g. "Staging")."""
        tag_name = self._make_stage_tag_name(stage_name)
        return builder.make_tag_urn(tag_name)

    def _make_stage_tag_name(self, stage_name: str) -> str:
        """Build the tag name for a stage: `<platform>_<lowercased stage>`."""
        return f"{self.platform}_{stage_name.lower()}"

    def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit:
        """
        Utility to create an MCP workunit.
        """
        return MetadataChangeProposalWrapper(
            entityUrn=urn,
            aspect=aspect,
        ).as_workunit()

    def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]:
        """
        Traverse each Registered Model in Model Registry and generate a corresponding workunit.

        For every Registered Model one group workunit is yielded, followed by a
        properties workunit and a stage-tag workunit for each of its versions.
        """
        registered_models = self._get_mlflow_registered_models()
        for registered_model in registered_models:
            yield self._get_ml_group_workunit(registered_model)
            model_versions = self._get_mlflow_model_versions(registered_model)
            for model_version in model_versions:
                run = self._get_mlflow_run(model_version)
                yield self._get_ml_model_properties_workunit(
                    registered_model=registered_model,
                    model_version=model_version,
                    run=run,
                )
                yield self._get_global_tags_workunit(model_version=model_version)

    def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]:
        """
        Get all Registered Models in MLflow Model Registry.
        """
        registered_models: Iterable[
            RegisteredModel
        ] = self._traverse_mlflow_search_func(
            search_func=self.client.search_registered_models,
        )
        return registered_models

    @staticmethod
    def _traverse_mlflow_search_func(
        search_func: Callable[..., PagedList[T]],
        **kwargs: Any,
    ) -> Iterable[T]:
        """
        Utility to traverse MLflow search_* functions which return a PagedList.

        Calls `search_func(page_token=..., **kwargs)` repeatedly, starting with
        `page_token=None`, and yields every item of every page until a page
        comes back with a falsy token.
        """
        next_page_token = None
        while True:
            paged_list = search_func(page_token=next_page_token, **kwargs)
            yield from paged_list.to_list()
            next_page_token = paged_list.token
            if not next_page_token:
                return

    def _get_ml_group_workunit(
        self,
        registered_model: RegisteredModel,
    ) -> MetadataWorkUnit:
        """
        Generate an MLModelGroup workunit for an MLflow Registered Model.
        """
        ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
        ml_model_group_properties = MLModelGroupPropertiesClass(
            customProperties=registered_model.tags,
            description=registered_model.description,
            createdAt=registered_model.creation_timestamp,
        )
        return self._create_workunit(
            urn=ml_model_group_urn,
            aspect=ml_model_group_properties,
        )

    def _make_ml_model_group_urn(self, registered_model: RegisteredModel) -> str:
        """Build the MLModelGroup URN from the Registered Model's name and the configured env."""
        return builder.make_ml_model_group_urn(
            platform=self.platform,
            group_name=registered_model.name,
            env=self.config.env,
        )

    def _get_mlflow_model_versions(
        self,
        registered_model: RegisteredModel,
    ) -> Iterable[ModelVersion]:
        """
        Get all Model Versions for each Registered Model.
        """
        # NOTE(review): the name is interpolated inside single quotes, so a
        # Registered Model name that itself contains a single quote would
        # presumably produce an invalid filter - confirm MLflow's quoting rules.
        filter_string = f"name = '{registered_model.name}'"
        model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func(
            search_func=self.client.search_model_versions,
            filter_string=filter_string,
        )
        return model_versions

    def _get_mlflow_run(self, model_version: ModelVersion) -> Optional[Run]:
        """
        Get a Run associated with a Model Version. Some MVs may exist without Run.

        Returns None when the Model Version has no (or an empty) `run_id`.
        """
        if model_version.run_id:
            return self.client.get_run(model_version.run_id)
        return None

    def _get_ml_model_properties_workunit(
        self,
        registered_model: RegisteredModel,
        model_version: ModelVersion,
        run: Optional[Run],
    ) -> MetadataWorkUnit:
        """
        Generate an MLModel workunit for an MLflow Model Version.
        Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model.
        If a model was registered without an associated Run then hyperparams and metrics are not available.
        """
        ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
        ml_model_urn = self._make_ml_model_urn(model_version)
        hyperparams: Optional[List[MLHyperParamClass]] = None
        training_metrics: Optional[List[MLMetricClass]] = None
        if run:
            # Values are stringified because the aspect stores them as text.
            hyperparams = [
                MLHyperParamClass(name=k, value=str(v))
                for k, v in run.data.params.items()
            ]
            training_metrics = [
                MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()
            ]
        ml_model_properties = MLModelPropertiesClass(
            customProperties=model_version.tags,
            externalUrl=self._make_external_url(model_version),
            description=model_version.description,
            date=model_version.creation_timestamp,
            version=VersionTagClass(versionTag=str(model_version.version)),
            hyperParams=hyperparams,
            trainingMetrics=training_metrics,
            # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags
            tags=list(model_version.tags.keys()),
            groups=[ml_model_group_urn],
        )
        return self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties)

    def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
        """Build the MLModel URN as `<name><model_name_separator><version>` in the configured env."""
        return builder.make_ml_model_urn(
            platform=self.platform,
            model_name=f"{model_version.name}{self.config.model_name_separator}{model_version.version}",
            env=self.config.env,
        )

    def _make_external_url(self, model_version: ModelVersion) -> Optional[str]:
        """
        Generate URL for a Model Version to MLflow UI.

        Returns None unless the client's tracking URI starts with "http"
        (e.g. a local `mlruns` directory has no UI address to link to).
        """
        base_uri = self.client.tracking_uri
        if not base_uri.startswith("http"):
            return None
        # Percent-encode the name so spaces, slashes, '%' etc. cannot produce a
        # malformed link; plain names such as "my_model-1.0" are unchanged.
        model_name = quote(model_version.name, safe="")
        return f"{base_uri.rstrip('/')}/#/models/{model_name}/versions/{model_version.version}"

    def _get_global_tags_workunit(
        self,
        model_version: ModelVersion,
    ) -> MetadataWorkUnit:
        """
        Associate a Model Version Stage with a corresponding tag.

        A Model Version whose `current_stage` is unset (None or empty) is
        tagged with the "None" stage instead of failing on `.lower()`.
        """
        stage_name = model_version.current_stage or self.default_stage_name
        global_tags = GlobalTagsClass(
            tags=[
                TagAssociationClass(
                    tag=self._make_stage_tag_urn(stage_name),
                ),
            ]
        )
        return self._create_workunit(
            urn=self._make_ml_model_urn(model_version),
            aspect=global_tags,
        )

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        """Parse `config_dict` into an MLflowConfig and construct the source."""
        config = MLflowConfig.parse_obj(config_dict)
        return cls(ctx, config)
"tagProperties", + "aspect": { + "json": { + "name": "mlflow_archived", + "description": "Archived Stage for an ML model in MLflow Model Registry", + "colorHex": "#5D7283" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_none", + "changeType": "UPSERT", + "aspectName": "tagProperties", + "aspect": { + "json": { + "name": "mlflow_none", + "description": "None Stage for an ML model in MLflow Model Registry", + "colorHex": "#F2F4F5" + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModelGroup", + "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)", + "changeType": "UPSERT", + "aspectName": "mlModelGroupProperties", + "aspect": { + "json": { + "customProperties": { + "model_env": "test", + "model_id": "1" + }, + "description": "This a test registered model", + "createdAt": 1615443388097 + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "mlModelProperties", + "aspect": { + "json": { + "customProperties": { + "model_version_id": "1" + }, + "date": 1615443388097, + "version": { + "versionTag": "1" + }, + "hyperParams": [ + { + "name": "p", + "value": "1" + } + ], + "trainingMetrics": [ + { + "name": "m", + "value": "0.85" + } + ], + "tags": [ + "model_version_id" + ], + "groups": [ + "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)" + ] + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "globalTags", + "aspect": { + "json": { + "tags": [ 
+ { + "tag": "urn:li:tag:mlflow_archived" + } + ] + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModel", + "entityUrn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,test-model_1,PROD)", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "mlModelGroup", + "entityUrn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,test-model,PROD)", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_staging", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_archived", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_production", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +}, +{ + "entityType": "tag", + "entityUrn": "urn:li:tag:mlflow_none", + "changeType": "UPSERT", + "aspectName": "status", + "aspect": { + "json": { + "removed": false + } + }, + "systemMetadata": { + "lastObserved": 1615443388097, + "runId": "mlflow-source-test" + } +} +] \ No newline at end of file diff --git a/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py 
b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py new file mode 100644 index 0000000000000..76af666526555 --- /dev/null +++ b/metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py @@ -0,0 +1,104 @@ +import sys + +if sys.version_info >= (3, 8): + from pathlib import Path + from typing import Any, Dict, TypeVar + + import pytest + from mlflow import MlflowClient + + from datahub.ingestion.run.pipeline import Pipeline + from tests.test_helpers import mce_helpers + + T = TypeVar("T") + + @pytest.fixture + def tracking_uri(tmp_path: Path) -> str: + return str(tmp_path / "mlruns") + + @pytest.fixture + def sink_file_path(tmp_path: Path) -> str: + return str(tmp_path / "mlflow_source_mcps.json") + + @pytest.fixture + def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]: + source_type = "mlflow" + return { + "run_id": "mlflow-source-test", + "source": { + "type": source_type, + "config": { + "tracking_uri": tracking_uri, + }, + }, + "sink": { + "type": "file", + "config": { + "filename": sink_file_path, + }, + }, + } + + @pytest.fixture + def generate_mlflow_data(tracking_uri: str) -> None: + client = MlflowClient(tracking_uri=tracking_uri) + experiment_name = "test-experiment" + run_name = "test-run" + model_name = "test-model" + test_experiment_id = client.create_experiment(experiment_name) + test_run = client.create_run( + experiment_id=test_experiment_id, + run_name=run_name, + ) + client.log_param( + run_id=test_run.info.run_id, + key="p", + value=1, + ) + client.log_metric( + run_id=test_run.info.run_id, + key="m", + value=0.85, + ) + client.create_registered_model( + name=model_name, + tags=dict( + model_id=1, + model_env="test", + ), + description="This a test registered model", + ) + client.create_model_version( + name=model_name, + source="dummy_dir/dummy_file", + run_id=test_run.info.run_id, + tags=dict(model_version_id=1), + ) + client.transition_model_version_stage( + name=model_name, + version="1", + 
import sys

if sys.version_info >= (3, 8):
    import datetime
    from pathlib import Path
    from typing import Any, Optional

    import pytest
    from mlflow import MlflowClient
    from mlflow.entities.model_registry import RegisteredModel
    from mlflow.entities.model_registry.model_version import ModelVersion
    from mlflow.store.entities import PagedList

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.source.mlflow import MLflowConfig, MLflowSource

    @pytest.fixture
    def tracking_uri(tmp_path: Path) -> str:
        """Local (non-HTTP) tracking location inside pytest's temp directory."""
        return str(tmp_path / "mlruns")

    @pytest.fixture
    def source(tracking_uri: str) -> MLflowSource:
        """MLflowSource pointed at the local tracking location."""
        return MLflowSource(
            ctx=PipelineContext(run_id="mlflow-source-test"),
            config=MLflowConfig(tracking_uri=tracking_uri),
        )

    @pytest.fixture
    def registered_model(source: MLflowSource) -> RegisteredModel:
        """In-memory Registered Model; nothing is written to the registry."""
        model_name = "abc"
        return RegisteredModel(name=model_name)

    @pytest.fixture
    def model_version(
        source: MLflowSource,
        registered_model: RegisteredModel,
    ) -> ModelVersion:
        """In-memory Model Version "1" of `registered_model`, with no run attached."""
        version = "1"
        # The source forwards creation_timestamp into the numeric `date` field,
        # so use integer epoch milliseconds rather than a datetime object.
        created_at_ms = int(datetime.datetime.now().timestamp() * 1000)
        return ModelVersion(
            name=registered_model.name,
            version=version,
            creation_timestamp=created_at_ms,
        )

    def dummy_search_func(page_token: Optional[str], **kwargs: Any) -> PagedList[str]:
        """
        Fake MLflow search_* function serving three fixed pages.

        `page_token=None` returns the first page; each page's token names the
        next one and the last page has token None. Passing `case="upper"`
        upper-cases the returned items.
        """
        dummy_pages = dict(
            page_1=PagedList(items=["a", "b"], token="page_2"),
            page_2=PagedList(items=["c", "d"], token="page_3"),
            page_3=PagedList(items=["e"], token=None),
        )
        if page_token is None:
            page_to_return = dummy_pages["page_1"]
        else:
            page_to_return = dummy_pages[page_token]
        if kwargs.get("case", "") == "upper":
            page_to_return = PagedList(
                items=[e.upper() for e in page_to_return.to_list()],
                token=page_to_return.token,
            )
        return page_to_return

    def test_stages(source):
        """One tag is generated per stage, named `mlflow_<lowercased stage>`."""
        mlflow_registered_model_stages = {
            "Production",
            "Staging",
            "Archived",
            "None",
        }
        workunits = source._get_tags_workunits()
        names = [wu.get_metadata()["metadata"].aspect.name for wu in workunits]

        assert len(names) == len(mlflow_registered_model_stages)
        assert set(names) == {
            "mlflow_" + stage.lower() for stage in mlflow_registered_model_stages
        }

    def test_config_model_name_separator(source, model_version):
        """The configured separator is placed between model name and version in the URN."""
        name_version_sep = "+"
        source.config.model_name_separator = name_version_sep
        expected_model_name = (
            f"{model_version.name}{name_version_sep}{model_version.version}"
        )
        expected_urn = f"urn:li:mlModel:(urn:li:dataPlatform:mlflow,{expected_model_name},{source.config.env})"

        urn = source._make_ml_model_urn(model_version)

        assert urn == expected_urn

    def test_model_without_run(source, registered_model, model_version):
        """A Model Version without a Run yields no hyperparams and no metrics."""
        run = source._get_mlflow_run(model_version)
        wu = source._get_ml_model_properties_workunit(
            registered_model=registered_model,
            model_version=model_version,
            run=run,
        )
        aspect = wu.get_metadata()["metadata"].aspect

        assert aspect.hyperParams is None
        assert aspect.trainingMetrics is None

    def test_traverse_mlflow_search_func(source):
        """All pages are followed and their items are yielded in order."""
        expected_items = ["a", "b", "c", "d", "e"]

        items = list(source._traverse_mlflow_search_func(dummy_search_func))

        assert items == expected_items

    def test_traverse_mlflow_search_func_with_kwargs(source):
        """Extra keyword arguments are forwarded to the search function on every page."""
        expected_items = ["A", "B", "C", "D", "E"]

        items = list(
            source._traverse_mlflow_search_func(dummy_search_func, case="upper")
        )

        assert items == expected_items

    def test_make_external_link_local(source, model_version):
        """A local tracking location produces no external URL."""
        expected_url = None

        url = source._make_external_url(model_version)

        assert url == expected_url

    def test_make_external_link_remote(source, model_version):
        """An HTTP(S) tracking URI produces a link to the Model Version page."""
        tracking_uri_remote = "https://dummy-mlflow-tracking-server.org"
        source.client = MlflowClient(tracking_uri=tracking_uri_remote)
        expected_url = f"{tracking_uri_remote}/#/models/{model_version.name}/versions/{model_version.version}"

        url = source._make_external_url(model_version)

        assert url == expected_url