Skip to content

Commit

Permalink
DataframePlot - don't error if not all experiments have a dataframe (#…
Browse files Browse the repository at this point in the history
…471)

* Created an Intake Source for DataframePlotDataSource for reading Dataframe Plot visualizations

* fixing typing errors

* update styling issues

* update environment.yml

* remove nvm files

* update test_viz.py for dataframe test

* update parameters in test_viz.py for dataframe test

* merging branch to main

* resolve merge conflicts

* remove nvm files

* moved imports in viz.py and added dataframe plot tests in test_publish.py

* add dataframe argumentfor publish tests

* debugging test_viz.py

* removed experiments check in test_viz

* Created an intake source for MetricListComparisonDataSource for reading Metric List Comparison visualizations

* updated formatting

* removed duplicate test in test_viz.py

* edit metriclist types and removing experiments from metriclistcomparison in publish.py

* edit test_viz.py metric list parameter names

* update metric list tests

* edit comments

* remove error for no dataframe logged to each experiment

* update code styling

* changing warning logic

* update formatting

* added dataframe testing

* update assertion syntax

* update daataframe tests

* updating exceptions

* remove unused var

* updating excepions and testing

* updating exception tests

* remove all dataframes from experiments test

* remove unused import

* raise rubiconException and test

* added new dataframe tests

* remove unused variable

---------

Co-authored-by: Jacqueline Hui <[email protected]>
  • Loading branch information
jeh362 and jhui18 authored Aug 2, 2024
1 parent b5f44ce commit 3970607
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
13 changes: 12 additions & 1 deletion rubicon_ml/viz/dataframe_plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings

import dash_bootstrap_components as dbc
import pandas as pd
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output

from rubicon_ml.exceptions import RubiconException
from rubicon_ml.viz.base import VizBase
from rubicon_ml.viz.common.colors import (
get_rubicon_colorscale,
Expand Down Expand Up @@ -91,7 +94,13 @@ def load_experiment_data(self):
self.data_df = None

for experiment in self.experiments:
dataframe = experiment.dataframe(name=self.dataframe_name)
try:
dataframe = experiment.dataframe(name=self.dataframe_name)
except RubiconException:
warnings.warn(
f"Experiment {experiment.id} does not have any dataframes logged to it."
)
continue

data_df = dataframe.get_data()
data_df["experiment_id"] = experiment.id
Expand All @@ -115,6 +124,8 @@ def load_experiment_data(self):
self.plotting_func_kwargs["color_discrete_sequence"] = get_rubicon_colorscale(
len(self.experiments),
)
if self.data_df is None:
raise RubiconException(f"No dataframe with name {self.dataframe_name} found!")

def register_callbacks(self, link_experiment_table=False):
outputs = [
Expand Down
38 changes: 38 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,44 @@ def viz_experiments(rubicon_and_project_client):
return project.experiments()


@pytest.fixture
def viz_experiments_no_dataframes(rubicon_and_project_client):
"""Returns a list of experiments with the parameters, metrics, and dataframes
required to test the `viz` module.
"""
_, project = rubicon_and_project_client

# dates = pd.date_range(start="1/1/2010", end="12/1/2020", freq="MS")

for i in range(0, 10):
experiment = project.log_experiment(
commit_hash="1234567",
model_name="test model name",
name="test name",
tags=["test tag"],
)

experiment.log_parameter(name="test param 0", value=random.choice([True, False]))
experiment.log_parameter(name="test param 1", value=random.randrange(2, 10, 2))
experiment.log_parameter(
name="test param 2",
value=random.choice(["A", "B", "C", "D", "E"]),
tags=["a", "b"],
)

experiment.log_metric(name="test metric 0", value=random.random())
experiment.log_metric(name="test metric 1", value=random.random())

experiment.log_metric(name="test metric 2", value=[random.random() for _ in range(0, 5)])
experiment.log_metric(
name="test metric 3",
value=[random.random() for _ in range(0, 5)],
tags=["a", "b"],
)

return project.experiments()


@pytest.fixture
def objects_to_log():
"""Returns objects for testing."""
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/viz/test_dataframe_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from dash import Dash

from rubicon_ml.exceptions import RubiconException
from rubicon_ml.viz import DataframePlot


Expand Down Expand Up @@ -74,3 +75,24 @@ def test_dataframe_plot_register_callbacks_link(viz_experiments, is_linked, expe

assert registered_callback_name == "update_dataframe_plot"
assert registered_callback_len_input == expected


def test_dataframe_no_experiments():
dataframe_plot = DataframePlot("test dataframe", experiments=[])
with pytest.raises(RubiconException):
dataframe_plot.load_experiment_data()


def test_cant_find_dataframes_raise_exception(viz_experiments_no_dataframes):
for exp in viz_experiments_no_dataframes:
if len(exp.dataframes()) > 0:
viz_experiments_no_dataframes.remove(exp)
dataframe_plot = DataframePlot("test dataframe", experiments=viz_experiments_no_dataframes)
with pytest.raises(RubiconException):
dataframe_plot.load_experiment_data()


def test_wrong_dataframe_name(viz_experiments):
dataframe_plot = DataframePlot("no_name", experiments=viz_experiments)
with pytest.raises(RubiconException):
dataframe_plot.load_experiment_data()

0 comments on commit 3970607

Please sign in to comment.