Skip to content

Commit

Permalink
mlflow only for Python >= 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
asikowitz committed Sep 22, 2023
1 parent e6a851f commit 4041a1d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 94 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def get_long_description():
"elasticsearch",
"feast" if sys.version_info >= (3, 8) else None,
"iceberg" if sys.version_info >= (3, 8) else None,
"mlflow" if sys.version_info >= (3, 8) else None,
"json-schema",
"ldap",
"looker",
Expand Down Expand Up @@ -506,7 +507,6 @@ def get_long_description():
"nifi",
"vertica",
"mode",
"mlflow",
]
if plugin
for dependency in plugins[plugin]
Expand Down
6 changes: 6 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/mlflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import sys

# Fail at import time, with an explicit message, on interpreters older than
# Python 3.8: this source module is only supported on 3.8 and newer.
if sys.version_info < (3, 8):
    raise ImportError("MLflow is only supported on Python 3.8+")


from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

Expand Down
184 changes: 91 additions & 93 deletions metadata-ingestion/tests/integration/mlflow/test_mlflow_source.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,104 @@
from pathlib import Path
from typing import Any, Dict, TypeVar
import sys

import pytest
from mlflow import MlflowClient
if sys.version_info >= (3, 8):
from pathlib import Path
from typing import Any, Dict, TypeVar

from datahub.ingestion.run.pipeline import Pipeline
from tests.test_helpers import mce_helpers
import pytest
from mlflow import MlflowClient

T = TypeVar("T")
from datahub.ingestion.run.pipeline import Pipeline
from tests.test_helpers import mce_helpers

T = TypeVar("T")

@pytest.fixture
def tracking_uri(tmp_path: Path) -> str:
return str(tmp_path / "mlruns")
@pytest.fixture
def tracking_uri(tmp_path: Path) -> str:
return str(tmp_path / "mlruns")

@pytest.fixture
def sink_file_path(tmp_path: Path) -> str:
return str(tmp_path / "mlflow_source_mcps.json")

@pytest.fixture
def sink_file_path(tmp_path: Path) -> str:
return str(tmp_path / "mlflow_source_mcps.json")


@pytest.fixture
def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]:
source_type = "mlflow"
return {
"run_id": "mlflow-source-test",
"source": {
"type": source_type,
"config": {
"tracking_uri": tracking_uri,
@pytest.fixture
def pipeline_config(tracking_uri: str, sink_file_path: str) -> Dict[str, Any]:
source_type = "mlflow"
return {
"run_id": "mlflow-source-test",
"source": {
"type": source_type,
"config": {
"tracking_uri": tracking_uri,
},
},
},
"sink": {
"type": "file",
"config": {
"filename": sink_file_path,
"sink": {
"type": "file",
"config": {
"filename": sink_file_path,
},
},
},
}


@pytest.fixture
def generate_mlflow_data(tracking_uri: str) -> None:
client = MlflowClient(tracking_uri=tracking_uri)
experiment_name = "test-experiment"
run_name = "test-run"
model_name = "test-model"
test_experiment_id = client.create_experiment(experiment_name)
test_run = client.create_run(
experiment_id=test_experiment_id,
run_name=run_name,
)
client.log_param(
run_id=test_run.info.run_id,
key="p",
value=1,
)
client.log_metric(
run_id=test_run.info.run_id,
key="m",
value=0.85,
)
client.create_registered_model(
name=model_name,
tags=dict(
model_id=1,
model_env="test",
),
description="This a test registered model",
)
client.create_model_version(
name=model_name,
source="dummy_dir/dummy_file",
run_id=test_run.info.run_id,
tags=dict(model_version_id=1),
)
client.transition_model_version_stage(
name=model_name,
version="1",
stage="Archived",
)
}

@pytest.fixture
def generate_mlflow_data(tracking_uri: str) -> None:
client = MlflowClient(tracking_uri=tracking_uri)
experiment_name = "test-experiment"
run_name = "test-run"
model_name = "test-model"
test_experiment_id = client.create_experiment(experiment_name)
test_run = client.create_run(
experiment_id=test_experiment_id,
run_name=run_name,
)
client.log_param(
run_id=test_run.info.run_id,
key="p",
value=1,
)
client.log_metric(
run_id=test_run.info.run_id,
key="m",
value=0.85,
)
client.create_registered_model(
name=model_name,
tags=dict(
model_id=1,
model_env="test",
),
description="This a test registered model",
)
client.create_model_version(
name=model_name,
source="dummy_dir/dummy_file",
run_id=test_run.info.run_id,
tags=dict(model_version_id=1),
)
client.transition_model_version_stage(
name=model_name,
version="1",
stage="Archived",
)

def test_ingestion(
pytestconfig,
mock_time,
sink_file_path,
pipeline_config,
generate_mlflow_data,
):
print(f"MCPs file path: {sink_file_path}")
golden_file_path = (
pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json"
)
def test_ingestion(
pytestconfig,
mock_time,
sink_file_path,
pipeline_config,
generate_mlflow_data,
):
print(f"MCPs file path: {sink_file_path}")
golden_file_path = (
pytestconfig.rootpath / "tests/integration/mlflow/mlflow_mcps_golden.json"
)

pipeline = Pipeline.create(pipeline_config)
pipeline.run()
pipeline.pretty_print_summary()
pipeline.raise_from_status()
pipeline = Pipeline.create(pipeline_config)
pipeline.run()
pipeline.pretty_print_summary()
pipeline.raise_from_status()

mce_helpers.check_golden_file(
pytestconfig=pytestconfig,
output_path=sink_file_path,
golden_path=golden_file_path,
)
mce_helpers.check_golden_file(
pytestconfig=pytestconfig,
output_path=sink_file_path,
golden_path=golden_file_path,
)

0 comments on commit 4041a1d

Please sign in to comment.