diff --git a/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts
index 2546b38352..a7087ca029 100644
--- a/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts
+++ b/libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts
@@ -57,6 +57,7 @@ export interface IPrecomputedExplanations {
export interface ITextFeatureImportance {
text: string[];
localExplanations: number[][];
+ baseValues?: number[][];
}
export interface IEBMGlobalExplanation {
diff --git a/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts
index 9923f9c6f1..965915c3da 100644
--- a/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts
+++ b/libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts
@@ -6,4 +6,7 @@ export interface ITextExplanationDashboardData {
localExplanations: number[][];
prediction: number[];
text: string[];
+ baseValues?: number[][];
+ predictedY?: number[] | number[][] | string[] | string | number;
+ trueY?: number[] | number[][] | string[] | string | number;
}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts
index 03ca37e2de..8b59ba9c0a 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts
@@ -78,6 +78,14 @@ export class Utils {
return sortedList;
}
+ public static addItem(value: number, radio: string | undefined): boolean {
+ return (
+ radio === RadioKeys.All ||
+ (radio === RadioKeys.Neg && value <= 0) ||
+ (radio === RadioKeys.Pos && value >= 0)
+ );
+ }
+
public static takeTopK(list: number[], k: number): number[] {
/*
* Returns a list after splicing and taking the top K
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts
index bd9d2612a3..cc7f2630c1 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/BarChart/getTokenImportancesChartOptions.ts
@@ -2,17 +2,27 @@
// Licensed under the MIT License.
import { ITheme } from "@fluentui/react";
-import {
- IHighchartsConfig,
- getPrimaryChartColor,
- getPrimaryBackgroundChartColor
-} from "@responsible-ai/core-ui";
+import { IHighchartsConfig } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import { SeriesOptionsType } from "highcharts";
+import _ from "lodash";
import { Utils } from "../../CommonUtils";
import { IChartProps } from "../../Interfaces/IChartProps";
+function findNearestIndex(
+ array: number[],
+ target?: number
+): number | undefined {
+ if (!target) {
+ return array.length;
+ }
+ const nearestElement = _.minBy(array, (element) =>
+ Math.abs(element - target)
+ );
+ return _.indexOf(array, nearestElement);
+}
+
export function getTokenImportancesChartOptions(
props: IChartProps,
theme: ITheme
@@ -20,6 +30,11 @@ export function getTokenImportancesChartOptions(
const importances = props.localExplanations;
const k = props.topK;
const sortedList = Utils.sortedTopK(importances, k, props.radio);
+
+ const outputFeatureImportanceLabel = `f ${
+ props.text[props.selectedTokenIndex || 0]
+ } (inputs)`;
+ const baseValueLabel = "base value";
const [x, y, ylabel, tooltip]: [number[], number[], string[], string[]] = [
[],
[],
@@ -46,6 +61,36 @@ export function getTokenImportancesChartOptions(
ylabel.push(props.text[idx]);
tooltip.push(str);
});
+
+ // add output feature importance
+ if (props.outputFeatureValue && props.baseValue) {
+ const outputFeatureValueIndex = findNearestIndex(
+ x,
+ props.outputFeatureValue
+ );
+ const baseValueFeatureValueIndex = findNearestIndex(x, props.baseValue);
+ if (outputFeatureValueIndex && baseValueFeatureValueIndex) {
+ if (Utils.addItem(props.outputFeatureValue, props.radio)) {
+ addItem(
+ x,
+ props.outputFeatureValue,
+ ylabel,
+ outputFeatureImportanceLabel,
+ outputFeatureValueIndex
+ );
+ }
+ if (Utils.addItem(props.baseValue, props.radio)) {
+ addItem(
+ x,
+ props.baseValue,
+ ylabel,
+ baseValueLabel,
+ baseValueFeatureValueIndex
+ );
+ }
+ }
+ }
+
// Put most significant word at the top by reversing order
tooltip.reverse();
ylabel.reverse();
@@ -54,11 +99,10 @@ export function getTokenImportancesChartOptions(
const data: any[] = [];
x.forEach((p, index) => {
const temp = {
- borderColor: getPrimaryChartColor(theme),
color:
(p || 0) >= 0
- ? getPrimaryChartColor(theme)
- : getPrimaryBackgroundChartColor(theme),
+ ? theme.semanticColors.errorText
+ : theme.semanticColors.link,
x: index,
y: p
};
@@ -68,6 +112,15 @@ export function getTokenImportancesChartOptions(
const series: SeriesOptionsType[] = [
{
data,
+ dataLabels: {
+ align: "center",
+ color: theme.semanticColors.bodyBackground,
+ enabled: true,
+ formatter(): string | number | undefined {
+ return this.x; // Display the Y-axis value inside the bar
+ },
+ inside: true
+ },
name: "",
showInLegend: false,
type: "bar"
@@ -80,11 +133,12 @@ export function getTokenImportancesChartOptions(
},
plotOptions: {
bar: {
+ minPointLength: 10,
tooltip: {
pointFormatter(): string {
return `${tooltip[this.x || 0]}: ${this.y || 0}`;
}
- }
+ } // Set the minimum pixel width for bars
}
},
series,
@@ -98,3 +152,14 @@ export function getTokenImportancesChartOptions(
}
};
}
+
+function addItem(
+ x: any[],
+ xValue: any,
+ yLabel: any[],
+ yLabelValue: any,
+ index: number
+): void {
+ x.splice(index, 0, xValue);
+ yLabel.splice(index, 0, yLabelValue);
+}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts
index 8255279c61..0fe5c3364c 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/ITextExplanationViewSpec.ts
@@ -13,12 +13,13 @@ export interface ITextExplanationViewState {
maxK: number;
topK: number;
radio: string;
- // qaRadio?: string;
+ qaRadio?: string;
importances: number[];
singleTokenImportances: number[];
selectedToken: number;
tokenIndexes: number[];
text: string[];
+ outputFeatureImportances: number[][];
}
export const options: IChoiceGroupOption[] = [
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx
index 4b5e412e24..8e0e453227 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/SidePanelOfChart.tsx
@@ -38,6 +38,9 @@ export interface ISidePanelOfChartProps {
selectedWeightVector: WeightVectorOption;
weightOptions: WeightVectorOption[];
weightLabels: any;
+ baseValue?: number;
+ outputFeatureValue?: number;
+ selectedTokenIndex?: number;
changeRadioButton: (
_event?: React.FormEvent,
item?: IChoiceGroupOption
@@ -63,6 +66,9 @@ export class SidePanelOfChart extends React.PureComponent
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts
index cb77508059..9f38b9d5fe 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.styles.ts
@@ -11,16 +11,25 @@ import {
export interface ITextExplanationDashboardStyles {
chartRight: IStyle;
textHighlighting: IStyle;
+ predictedAnswer: IStyle;
+ boldText: IStyle;
}
export const textExplanationDashboardStyles: () => IProcessedStyleSet =
() => {
const theme = getTheme();
return mergeStyleSets({
+ boldText: {
+ fontWeight: "bold"
+ },
chartRight: {
maxWidth: "230px",
minWidth: "230px"
},
+ predictedAnswer: {
+ fontWeight: "bold",
+ paddingBottom: "14px"
+ },
textHighlighting: {
borderColor: theme.semanticColors.variantBorder,
borderRadius: "1px",
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx
index de8c28a28f..636862564e 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx
@@ -2,24 +2,29 @@
// Licensed under the MIT License.
import { IChoiceGroupOption, Stack, Text } from "@fluentui/react";
-import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
+import { WeightVectorOption } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import React from "react";
-import { RadioKeys, Utils } from "../../CommonUtils";
+import { QAExplanationType, RadioKeys } from "../../CommonUtils";
import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps";
-import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
-import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
import {
ITextExplanationViewState,
- MaxImportantWords,
componentStackTokens
} from "./ITextExplanationViewSpec";
import { SidePanelOfChart } from "./SidePanelOfChart";
-import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
-
-export class TextExplanationView extends React.PureComponent<
+import {
+ calculateMaxKImportances,
+ calculateTopKImportances,
+ computeImportancesForAllTokens,
+ computeImportancesForWeightVector,
+ getOutputFeatureImportances
+} from "./TextExplanationViewUtils";
+import { TextInputOutputAreaWithLegend } from "./TextInputOutputAreaWithLegend";
+import { TrueAndPredictedAnswerView } from "./TrueAndPredictedAnswerView";
+
+export class TextExplanationView extends React.Component<
ITextExplanationViewProps,
ITextExplanationViewState
> {
@@ -31,23 +36,32 @@ export class TextExplanationView extends React.PureComponent<
const weightVector = this.props.selectedWeightVector;
const importances = this.props.isQA
- ? this.computeImportancesForAllTokens(
- this.props.dataSummary.localExplanations
+ ? computeImportancesForAllTokens(
+ this.props.dataSummary.localExplanations,
+ true
)
- : this.computeImportancesForWeightVector(
+ : computeImportancesForWeightVector(
this.props.dataSummary.localExplanations,
weightVector
);
- const maxK = this.calculateMaxKImportances(importances);
- const topK = this.calculateTopKImportances(importances);
+ const maxK = calculateMaxKImportances(importances);
+ const topK = calculateTopKImportances(importances);
this.state = {
importances,
maxK,
- // qaRadio: QAExplanationType.Start,
+ outputFeatureImportances: getOutputFeatureImportances(
+ this.props.dataSummary.localExplanations,
+ this.props.dataSummary.baseValues
+ ),
+ qaRadio: QAExplanationType.Start,
radio: RadioKeys.All,
- selectedToken: 0, // default to the first token
- singleTokenImportances: this.getImportanceForSingleToken(0), // get importance for first token
+ selectedToken: 0,
+ // default to the first token
+ singleTokenImportances: this.props.dataSummary.localExplanations[0].map(
+ (row) => row[0]
+ ),
+ // get importance for first token
text: this.props.dataSummary.text,
tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
topK
@@ -60,28 +74,19 @@ export class TextExplanationView extends React.PureComponent<
this.props.dataSummary.localExplanations !==
prevProps.dataSummary.localExplanations
) {
- if (this.props.isQA) {
- this.setState(
- {
- selectedToken: 0,
- //update token dropdown
- tokenIndexes: [...this.props.dataSummary.text].map(
- (_, index) => index
- )
- },
- () => {
- this.updateTokenImportances();
- this.updateSingleTokenImportances();
- }
- );
- } else {
- this.updateImportances(this.props.selectedWeightVector);
- }
+ this.updateState();
}
}
public render(): React.ReactNode {
- const classNames = textExplanationDashboardStyles();
+ const outputLocalExplanations =
+ this.state.qaRadio === QAExplanationType.Start
+ ? this.state.outputFeatureImportances[0]
+ : this.state.outputFeatureImportances[1];
+ const inputLocalExplanations = this.props.isQA
+ ? this.state.singleTokenImportances
+ : this.state.importances;
+ const baseValue = this.props.isQA ? this.getBaseValue() : undefined;
return (
@@ -93,9 +98,30 @@ export class TextExplanationView extends React.PureComponent<
)}
+
+ {this.props.isQA && (
+
+ )}
+
+
+
+
-
-
-
-
-
-
- {this.props.isQA && (
-
-
-
- )}
-
-
-
-
-
);
}
- private onWeightVectorChange = (weightOption: WeightVectorOption): void => {
- this.updateImportances(weightOption);
- this.props.onWeightChange(weightOption);
- };
-
- private onSelectedTokenChange = (newIndex: number): void => {
- this.setState({ selectedToken: newIndex }, () => {
- this.updateSingleTokenImportances();
- });
- };
-
- private updateImportances(weightOption: WeightVectorOption): void {
- const importances = this.computeImportancesForWeightVector(
- this.props.dataSummary.localExplanations,
- weightOption
- );
-
- const topK = this.calculateTopKImportances(importances);
- const maxK = this.calculateMaxKImportances(importances);
+ private updateState(): void {
+ const importances = this.props.isQA
+ ? this.getTokenImportances()
+ : this.getImportances(this.props.selectedWeightVector);
+ const [topK, maxK] = this.getTopKMaxK(importances);
this.setState({
importances,
maxK,
+ outputFeatureImportances: getOutputFeatureImportances(
+ this.props.dataSummary.localExplanations,
+ this.props.dataSummary.baseValues
+ ),
+ selectedToken: 0,
+ singleTokenImportances: this.getImportanceForSingleToken(
+ this.state.selectedToken
+ ),
text: this.props.dataSummary.text,
+ tokenIndexes: [...this.props.dataSummary.text].map((_, index) => index),
topK
});
}
- // for QA
- private updateTokenImportances(): void {
- const importances = this.computeImportancesForAllTokens(
- this.props.dataSummary.localExplanations
- );
- const topK = this.calculateTopKImportances(importances);
- const maxK = this.calculateMaxKImportances(importances);
+ private onWeightVectorChange = (weightOption: WeightVectorOption): void => {
+ const importances = this.getImportances(weightOption);
+ const [topK, maxK] = this.getTopKMaxK(importances);
+ this.setState({ importances, maxK, topK });
+ this.props.onWeightChange(weightOption);
+ };
+
+ private onSelectedTokenChange = (newIndex: number): void => {
+ const singleTokenImportances = this.getImportanceForSingleToken(newIndex);
this.setState({
- importances,
- maxK,
- text: this.props.dataSummary.text,
- topK
+ selectedToken: newIndex,
+ singleTokenImportances
});
- }
+ };
- private updateSingleTokenImportances(): void {
- const singleTokenImportances = this.getImportanceForSingleToken(
- this.state.selectedToken
- );
- this.setState({ singleTokenImportances });
- }
+ private getSelectedWord = (): string => {
+ return this.props.dataSummary.text[this.state.selectedToken];
+ };
- private calculateTopKImportances(importances: number[]): number {
- return Math.min(
- MaxImportantWords,
- Math.ceil(Utils.countNonzeros(importances) / 2)
- );
+ private getTopKMaxK(importances: number[]): [number, number] {
+ const topK = calculateTopKImportances(importances);
+ const maxK = calculateMaxKImportances(importances);
+ return [topK, maxK];
}
- private calculateMaxKImportances(importances: number[]): number {
- return Math.min(
- MaxImportantWords,
- Math.ceil(Utils.countNonzeros(importances))
+ private getImportances(weightOption: WeightVectorOption): number[] {
+ return computeImportancesForWeightVector(
+ this.props.dataSummary.localExplanations,
+ weightOption
);
}
- private computeImportancesForWeightVector(
- importances: number[][],
- weightVector: WeightVectorOption
- ): number[] {
- if (weightVector === WeightVectors.AbsAvg) {
- // Sum the multidimensional array to one dimension across rows for each token
- const numClasses = importances[0].length;
- const sumImportances = importances.map((row) =>
- row.reduce((a, b): number => {
- return (a + Math.abs(b)) / numClasses;
- }, 0)
- );
- return sumImportances;
- }
- return importances.map(
- (perClassImportances) => perClassImportances[weightVector as number]
+ // for QA
+ private getTokenImportances(): number[] {
+ return computeImportancesForAllTokens(
+ this.props.dataSummary.localExplanations
);
}
- private computeImportancesForAllTokens(importances: number[][]): number[] {
- /*
- * sum the tokens importance
- * TODO: add base values?
- */
-
- const sumImportances = importances[0].map((_, index) =>
- importances.reduce((sum, row) => sum + row[index], 0)
+ private getImportanceForSingleToken(index: number): number[] {
+ const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
+ return this.props.dataSummary.localExplanations[expIndex].map(
+ (row) => row[index]
);
-
- return sumImportances;
}
- private getImportanceForSingleToken(index: number): number[] {
- return this.props.dataSummary.localExplanations.map((row) => row[index]);
+ private getBaseValue(): number {
+ if (this.props.dataSummary.baseValues) {
+ const expIndex = this.state.qaRadio === QAExplanationType.Start ? 0 : 1;
+ return this.props.dataSummary.baseValues?.[expIndex][
+ this.state.selectedToken
+ ];
+ }
+ return 0;
}
private setTopK = (newNumber: number): void => {
- /*
- * Changes the state of K
- */
this.setState({ topK: newNumber });
};
@@ -265,23 +231,23 @@ export class TextExplanationView extends React.PureComponent<
_event?: React.FormEvent,
item?: IChoiceGroupOption
): void => {
- /*
- * Changes the state of the radio button
- */
- if (item?.key !== undefined) {
+ if (item?.key) {
this.setState({ radio: item.key });
}
};
- private switchQAPrediction = (): // _event?: React.FormEvent,
- // _item?: IChoiceGroupOption
- void => {
- /*
- * switch to the target predictions(starting or ending)
- * TODO: add logic for switching explanation data
- */
- // if (item?.key !== undefined) {
- // this.setState({ qaRadio: item.key });
- // }
+ private switchQAPrediction = (
+ _event?: React.FormEvent,
+ item?: IChoiceGroupOption
+ ): void => {
+ if (item?.key) {
+ const singleTokenImportances = this.getImportanceForSingleToken(
+ this.state.selectedToken
+ );
+ this.setState({
+ qaRadio: item.key,
+ singleTokenImportances
+ });
+ }
};
}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts
new file mode 100644
index 0000000000..313deaee99
--- /dev/null
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationViewUtils.ts
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui";
+
+import { QAExplanationType, Utils } from "../../CommonUtils";
+
+import { MaxImportantWords } from "./ITextExplanationViewSpec";
+
+export function getOutputFeatureImportances(
+ localExplanations: number[][],
+ baseValues?: number[][]
+): number[][] {
+ const startSumOfFeatureImportances = getSumOfFeatureImportances(
+ localExplanations[0]
+ );
+ const endSumOfFeatureImportances = getSumOfFeatureImportances(
+ localExplanations[1]
+ );
+ const startOutputFeatureImportances = getOutputFeatureImportancesIntl(
+ startSumOfFeatureImportances,
+ baseValues?.[0]
+ );
+ const endOutputFeatureImportances = getOutputFeatureImportancesIntl(
+ endSumOfFeatureImportances,
+ baseValues?.[1]
+ );
+ return [
+ startOutputFeatureImportances || [],
+ endOutputFeatureImportances || []
+ ];
+}
+
+export function getSumOfFeatureImportances(importances: number[]): number[] {
+ return importances.map((_, index) =>
+ importances.reduce((sum, row) => sum + row[index], 0)
+ );
+}
+
+export function getOutputFeatureImportancesIntl(
+ sumOfFeatureImportances: number[],
+ baseValues?: number[]
+): number[] | undefined {
+ return baseValues?.map(
+ (bValue, index) => sumOfFeatureImportances[index] + bValue
+ );
+}
+
+export function calculateTopKImportances(importances: number[]): number {
+ return Math.min(
+ MaxImportantWords,
+ Math.ceil(Utils.countNonzeros(importances) / 2)
+ );
+}
+
+export function calculateMaxKImportances(importances: number[]): number {
+ return Math.min(
+ MaxImportantWords,
+ Math.ceil(Utils.countNonzeros(importances))
+ );
+}
+
+export function computeImportancesForWeightVector(
+ importances: number[][],
+ weightVector: WeightVectorOption
+): number[] {
+ if (weightVector === WeightVectors.AbsAvg) {
+ // Sum the multidimensional array to one dimension across rows for each token
+ const numClasses = importances[0].length;
+ const sumImportances = importances.map((row) =>
+ row.reduce((a, b): number => {
+ return (a + Math.abs(b)) / numClasses;
+ }, 0)
+ );
+ return sumImportances;
+ }
+ return importances.map(
+ (perClassImportances) => perClassImportances[weightVector as number]
+ );
+}
+
+export function computeImportancesForAllTokens(
+ importances: number[][],
+ isInitialState?: boolean,
+ qaRadio?: string
+): number[] {
+ const startSumImportances = importances[0].map((_, index) =>
+ importances.reduce((sum, row) => sum + row[index], 0)
+ );
+ const endSumImportances = importances[1].map((_, index) =>
+ importances.reduce((sum, row) => sum + row[index], 0)
+ );
+ if (isInitialState) {
+ return startSumImportances;
+ }
+
+ return qaRadio === QAExplanationType.Start
+ ? startSumImportances
+ : endSumImportances;
+}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx
new file mode 100644
index 0000000000..579d56bc0a
--- /dev/null
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextInputOutputAreaWithLegend.tsx
@@ -0,0 +1,94 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+import { Stack, Text } from "@fluentui/react";
+import { localization } from "@responsible-ai/localization";
+import React from "react";
+
+import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend";
+import { TextHighlighting } from "../TextHighlighting/TextHightlighting";
+
+import { componentStackTokens } from "./ITextExplanationViewSpec";
+import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
+
+interface ITextInputOutputAreaWithLegendProps {
+ topK: number;
+ radio: string;
+ selectedToken: number;
+ text: string[];
+ outputLocalExplanations: number[];
+ inputLocalExplanations: number[];
+ isQA?: boolean;
+ getSelectedWord: () => string;
+ onSelectedTokenChange: (newIndex: number) => void;
+}
+
+export class TextInputOutputAreaWithLegend extends React.Component {
+ public render(): React.ReactNode {
+ const classNames = textExplanationDashboardStyles();
+
+ return (
+
+ {this.props.isQA && (
+
+
+
+
+ {localization.InterpretText.View.outputs}
+
+
+
+
+
+
+
+ )}
+
+
+
+
+ {localization.InterpretText.View.inputs}
+
+
+
+
+
+
+
+
+
+
+
+ );
+ }
+}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx
new file mode 100644
index 0000000000..bc94fa8247
--- /dev/null
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TrueAndPredictedAnswerView.tsx
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+import { Stack, Text } from "@fluentui/react";
+import { localization } from "@responsible-ai/localization";
+import React from "react";
+
+import { textExplanationDashboardStyles } from "./TextExplanationView.styles";
+
+interface ITrueAndPredictedAnswerViewProps {
+ predictedY: string | number | number[] | string[] | number[][] | undefined;
+ trueY: string | number | number[] | string[] | number[][] | undefined;
+}
+
+export class TrueAndPredictedAnswerView extends React.Component {
+ public render(): React.ReactNode {
+ const classNames = textExplanationDashboardStyles();
+
+ return (
+
+
+
+
+ {localization.InterpretText.View.predictedAnswer}
+
+
+
+
+ {this.props.predictedY}
+
+
+
+
+
+
+ {localization.InterpretText.View.trueAnswer}
+
+
+
+ {this.props.trueY}
+
+
+
+ );
+ }
+}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts
index a72c4b0993..38fc87aecf 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextFeatureLegend/TextFeatureLegend.styles.ts
@@ -7,10 +7,6 @@ import {
IProcessedStyleSet,
getTheme
} from "@fluentui/react";
-import {
- getPrimaryBackgroundChartColor,
- getPrimaryChartColor
-} from "@responsible-ai/core-ui";
export interface ITextFeatureLegendStyles {
legend: IStyle;
@@ -26,12 +22,12 @@ export const textFeatureLegendStyles: () => IProcessedStyleSet {
public render(): React.ReactNode {
const classNames = textFeatureLegendStyles();
return (
@@ -51,6 +56,28 @@ export class TextFeatureLegend extends React.Component {
+ {this.props.isQA && (
+
+
+ {localization.InterpretText.Legend.cls}
+
+
+ {localization.InterpretText.Legend.sep}
+
+
+
+
+
+ {localization.InterpretText.Legend.selectedWord}
+
+
+
+ {this.props.selectedWord}
+
+
+
+
+ )}
);
}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts
index bff576c338..c98ee8bb88 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHighlighting.styles.ts
@@ -2,14 +2,13 @@
// Licensed under the MIT License.
import {
- IStyle,
- mergeStyles,
- mergeStyleSets,
IProcessedStyleSet,
IStackStyles,
- getTheme
+ IStyle,
+ getTheme,
+ mergeStyleSets,
+ mergeStyles
} from "@fluentui/react";
-import { getPrimaryChartColor } from "@responsible-ai/core-ui";
export const textStackStyles: IStackStyles = {
root: {
@@ -31,30 +30,41 @@ export interface ITextHighlightingStyles {
boldunderline: IStyle;
}
-export const textHighlightingStyles: () => IProcessedStyleSet =
- () => {
- const theme = getTheme();
- const normal = {
- color: theme.semanticColors.bodyText
- };
- return mergeStyleSets({
- boldunderline: mergeStyles([
- normal,
- {
- color: getPrimaryChartColor(theme),
- fontSize: theme.fonts.large.fontSize,
- margin: "2px",
- padding: 0,
- textDecorationLine: "underline"
- }
- ]),
- highlighted: mergeStyles([
- normal,
- {
- backgroundColor: getPrimaryChartColor(theme),
- color: theme.semanticColors.bodyBackground
- }
- ]),
- normal
- });
+export const textHighlightingStyles: (
+ isTextSelected: boolean
+) => IProcessedStyleSet = (isTextSelected) => {
+ const theme = getTheme();
+ const normal = {
+ color: theme.semanticColors.bodyText
};
+ const selectedTextStyle = isTextSelected
+ ? {
+ textDecorationColor: "black",
+ textDecorationLine: "underline",
+ textDecorationStyle: "solid",
+ textDecorationThickness: "4px"
+ }
+ : {};
+ return mergeStyleSets({
+ boldunderline: mergeStyles([
+ normal,
+ {
+ backgroundColor: theme.semanticColors.link,
+ color: theme.semanticColors.bodyBackground,
+ fontSize: theme.fonts.large.fontSize,
+ margin: "2px",
+ padding: 0
+ },
+ selectedTextStyle
+ ]),
+ highlighted: mergeStyles([
+ normal,
+ selectedTextStyle,
+ {
+ backgroundColor: theme.semanticColors.errorText,
+ color: theme.semanticColors.bodyBackground
+ }
+ ]),
+ normal: mergeStyles([normal, selectedTextStyle])
+ });
+};
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx
index 2fbac58256..1101cfd0b4 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextHighlighting/TextHightlighting.tsx
@@ -2,7 +2,6 @@
// Licensed under the MIT License.
import {
- Label,
Text,
Stack,
IStackTokens,
@@ -25,12 +24,11 @@ const textStackTokens: IStackTokens = {
padding: "s2"
};
-export class TextHighlighting extends React.PureComponent {
+export class TextHighlighting extends React.Component {
/*
* Presents the document in an accessible manner with text highlighting
*/
public render(): React.ReactNode {
- const classNames = textHighlightingStyles();
const text = this.props.text;
const importances = this.props.localExplanations;
const k = this.props.topK;
@@ -47,30 +45,22 @@ export class TextHighlighting extends React.PureComponent {
styles={textStackStyles}
>
{text.map((word, wordIndex) => {
+ const isWordSelected =
+ (this.props.selectedTokenIndex &&
+ wordIndex === this.props.selectedTokenIndex) ||
+ false;
+ const classNames = textHighlightingStyles(isWordSelected);
let styleType = classNames.normal;
const score = importances[wordIndex];
- let isBold = false;
if (sortedList.includes(wordIndex)) {
if (score > 0) {
styleType = classNames.highlighted;
} else if (score < 0) {
styleType = classNames.boldunderline;
- isBold = true;
} else {
styleType = classNames.normal;
}
}
- if (isBold) {
- return (
-
- );
- }
return (
{
key={wordIndex}
className={styleType}
title={score.toString()}
+ onClick={(): void => this.handleClick(wordIndex)}
>
{word}
@@ -88,4 +79,13 @@ export class TextHighlighting extends React.PureComponent {
);
}
+
+ private readonly handleClick = (wordIndex: number): void => {
+ if (this.props.isInput) {
+ return;
+ }
+ if (this.props.onSelectedTokenChange) {
+ this.props.onSelectedTokenChange(wordIndex);
+ }
+ };
}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts
index fa2770820f..f16f1f6d9e 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IChartProps.ts
@@ -9,4 +9,9 @@ export interface IChartProps {
localExplanations: number[];
topK?: number;
radio?: string;
+ isInput?: boolean;
+ baseValue?: number;
+ outputFeatureValue?: number;
+ selectedTokenIndex?: number;
+ onSelectedTokenChange?: (newIndex: number) => void;
}
diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts
index d82261257e..903a652cdf 100644
--- a/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts
+++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Interfaces/IExplanationDashboardProps.ts
@@ -19,5 +19,8 @@ export interface IDatasetSummary {
text: string[];
classNames?: string[];
localExplanations: number[][];
+ baseValues?: number[][];
prediction?: number[];
+ predictedY?: number[] | number[][] | string[] | string | number;
+ trueY?: number[] | number[][] | string[] | string | number;
}
diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json
index 5e6824eedd..430bb44530 100644
--- a/libs/localization/src/lib/en.json
+++ b/libs/localization/src/lib/en.json
@@ -1374,12 +1374,19 @@
"label": "Label",
"colon": ": ",
"startingPosition": "STARTING POSITION",
- "endingPosition": "ENDING POSITION"
+ "endingPosition": "ENDING POSITION",
+ "predictedAnswer": "Predicted answer: ",
+ "trueAnswer": "True answer: ",
+ "inputs": "Inputs",
+ "outputs": "Outputs"
},
"Legend": {
"featureLegend": "TEXT FEATURE LEGEND",
"posFeatureImportance": "POSITIVE FEATURE IMPORTANCE",
- "negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE"
+ "negFeatureImportance": "NEGATIVE FEATURE IMPORTANCE",
+ "cls": "CLS: start of the sentence",
+ "sep": "SEP: end of the sentence",
+ "selectedWord": "Selected word: "
},
"BarChart": {
"featureImportance": "FEATURE IMPORTANCE"
diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx
index 7abd051ec1..903b20caf1 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx
+++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/FeatureImportances.tsx
@@ -75,7 +75,6 @@ export class FeatureImportancesTab extends React.PureComponent<
return React.Fragment;
}
const classNames = featureImportanceTabStyles();
-
return (
{
@@ -44,10 +47,13 @@ export class TextLocalImportancePlots extends React.Component