Skip to content

Commit

Permalink
Vision Model Overview E2E Tests - Refactor (#2185)
Browse files Browse the repository at this point in the history
* mo test refactor ckpt

* test fixes

* test flags for vision

* test fixes

* auto lint fixes

* removed duplicate constants

* comment fixes

* test & build fixes

* test fixes

* auto lint fixes

* comment updates
  • Loading branch information
Advitya17 authored Jul 24, 2023
1 parent 90db5e6 commit 37c340c
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Register the model overview E2E checks for the
// FridgeImageClassificationModelDebugging dataset shape.
describeModelOverview(
  modelAssessmentDatasets.FridgeImageClassificationModelDebugging,
  "FridgeImageClassificationModelDebugging"
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Register the model overview E2E checks for the
// FridgeMultilabelModelDebugging dataset shape.
describeModelOverview(
  modelAssessmentDatasets.FridgeMultilabelModelDebugging,
  "FridgeMultilabelModelDebugging"
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
describeModelOverview,
modelAssessmentDatasets
} from "@responsible-ai/e2e";
// Register the model overview E2E checks for the
// FridgeObjectDetectionModelDebugging dataset shape.
describeModelOverview(
  modelAssessmentDatasets.FridgeObjectDetectionModelDebugging,
  "FridgeObjectDetectionModelDebugging"
);
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ export const FridgeImageClassificationModelDebugging = {
name: "All data",
sampleSize: "134"
}
]
],
newCohort: {
metrics: {
accuracy: "0.9",
macroF1: "0.9",
macroPrecision: "0.9",
macroRecall: "0.9",
microF1: "0.9",
microPrecision: "0.9",
microRecall: "0.9"
},
name: "CohortCreateE2E-image-classification",
sampleSize: "5"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ export const FridgeMultilabelModelDebugging = {
name: "All data",
sampleSize: "10"
}
]
],
newCohort: {
metrics: {
exactMatchRatio: "1",
hammingScore: "1"
},
name: "CohortCreateE2E-multilabel",
sampleSize: "3"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ export const FridgeObjectDetectionModelDebugging = {
name: "All data",
sampleSize: "5"
}
]
],
newCohort: {
metrics: {
averagePrecision: "1",
averageRecall: "1",
meanAveragePrecision: "1"
},
name: "CohortCreateE2E-object-detection",
sampleSize: "2"
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,22 @@ export function describeModelOverview(
isNotebookTest = true
): void {
describe(testName, () => {
const isVision =
datasetShape.isObjectDetection ||
datasetShape.isMultiLabel ||
datasetShape.isImageClassification
? true
: false;
if (isNotebookTest) {
before(() => {
visit(name);
});
} else {
before(() => {
cy.visit(`#/modelAssessment/${name}/light/english/Version-2`);
const dashboardName = isVision
? "modelAssessmentVision"
: "modelAssessment";
cy.visit(`#/${dashboardName}/${name}/light/english/Version-2`);
});
}

Expand All @@ -38,7 +47,8 @@ export function describeModelOverview(
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
false,
isNotebookTest
isNotebookTest,
isVision
);
});

Expand All @@ -57,7 +67,8 @@ export function describeModelOverview(
);
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
1
1,
isVision
);
});

Expand All @@ -69,16 +80,19 @@ export function describeModelOverview(
);
ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape,
2
2,
isVision
);
});

it("should show new cohorts in charts", () => {
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest);
ensureNewCohortsShowUpInCharts(datasetShape, isNotebookTest, isVision);
});

it("should pivot between charts when clicking", () => {
ensureChartsPivot(datasetShape, isNotebookTest, true);
if (!isVision) {
ensureChartsPivot(datasetShape, isNotebookTest, true);
}
});
} else {
it("should not have 'Model overview' component", () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import { getNumberOfCohorts } from "./numberOfCohorts";
export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape: IModelAssessmentData,
includeNewCohort: boolean,
isNotebookTest: boolean
isNotebookTest: boolean,
isVision: boolean
): void {
const data = datasetShape.modelOverviewData;
const initialCohorts = data?.initialCohorts;
Expand All @@ -23,7 +24,10 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
"not.exist"
);
if (isNotebookTest) {
if (getNumberOfCohorts(datasetShape, includeNewCohort) <= 1) {
if (
getNumberOfCohorts(datasetShape, includeNewCohort) <= 1 ||
datasetShape.isObjectDetection
) {
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should(
"not.exist"
);
Expand All @@ -45,6 +49,24 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
"meanSquaredError",
"meanPrediction"
);
} else if (datasetShape.isImageClassification) {
metricsOrder.push(
"accuracy",
"f1Score",
"precisionScore",
"recallScore",
"falsePositiveRate",
"falseNegativeRate",
"selectionRate"
);
} else if (datasetShape.isMultiLabel) {
metricsOrder.push("exactMatchRatio", "hammingScore");
} else if (datasetShape.isObjectDetection) {
metricsOrder.push(
"meanAveragePrecision",
"averagePrecision",
"averageRecall"
);
} else {
metricsOrder.push("accuracy");
if (!datasetShape.isMulticlass) {
Expand All @@ -69,35 +91,39 @@ export function ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
});
});

if (isNotebookTest) {
cy.get(Locators.ModelOverviewHeatmapCells)
.should("have.length", (cohorts?.length || 0) * (metricsOrder.length + 1))
.each(($cell) => {
// somehow the cell string is one invisible character longer, trim
expect($cell.text().slice(0, $cell.text().length - 1)).to.be.oneOf(
heatmapCellContents
);
});
}

cy.get(
Locators.ModelOverviewDisaggregatedAnalysisBaseCohortDisclaimer
).should("not.exist");
cy.get(Locators.ModelOverviewDisaggregatedAnalysisBaseCohortWarning).should(
"not.exist"
);

const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

if (defaultVisibleChart === Locators.ModelOverviewMetricChart) {
ensureNotebookModelOverviewMetricChartIsCorrect(
isNotebookTest,
datasetShape,
includeNewCohort
if (!isVision) {
if (isNotebookTest) {
cy.get(Locators.ModelOverviewHeatmapCells)
.should(
"have.length",
(cohorts?.length || 0) * (metricsOrder.length + 1)
)
.each(($cell) => {
// somehow the cell string is one invisible character longer, trim
expect($cell.text().slice(0, $cell.text().length - 1)).to.be.oneOf(
heatmapCellContents
);
});
}
const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

if (defaultVisibleChart === Locators.ModelOverviewMetricChart) {
ensureNotebookModelOverviewMetricChartIsCorrect(
isNotebookTest,
datasetShape,
includeNewCohort
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,31 @@ import {

export function ensureAllModelOverviewFeatureCohortsViewElementsAfterSelectionArePresent(
datasetShape: IModelAssessmentData,
selectedFeatures: number
selectedFeatures: number,
isVision: boolean
): void {
cy.get(Locators.ModelOverviewFeatureSelection).should("exist");
cy.get(Locators.ModelOverviewFeatureConfigurationActionButton).should(
"exist"
);
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should("exist");
cy.get(Locators.ModelOverviewDatasetCohortStatsTable).should("not.exist");
cy.get(Locators.ModelOverviewDisaggregatedAnalysisTable).should("exist");

const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);
if (!isVision) {
cy.get(Locators.ModelOverviewHeatmapVisualDisplayToggle).should("exist"); // TODO: check!
cy.get(Locators.ModelOverviewDisaggregatedAnalysisTable).should("exist");

assertNumberOfChartRowsEqual(
datasetShape,
selectedFeatures,
defaultVisibleChart
);
const defaultVisibleChart = getDefaultVisibleChart(
datasetShape.isRegression,
datasetShape.isBinary
);
assertChartVisibility(datasetShape, defaultVisibleChart);

assertNumberOfChartRowsEqual(
datasetShape,
selectedFeatures,
defaultVisibleChart
);
}
}

function assertNumberOfChartRowsEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import { ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent } from

export function ensureNewCohortsShowUpInCharts(
datasetShape: IModelAssessmentData,
isNotebookTest: boolean
isNotebookTest: boolean,
isVision: boolean
): void {
cy.get(Locators.ModelOverviewCohortViewDatasetCohortViewButton).click();
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
false,
isNotebookTest
isNotebookTest,
isVision
);
createCohort(datasetShape.modelOverviewData?.newCohort?.name);
ensureAllModelOverviewDatasetCohortsViewBasicElementsArePresent(
datasetShape,
true,
isNotebookTest
isNotebookTest,
isVision
);
}

0 comments on commit 37c340c

Please sign in to comment.