Skip to content

Commit

Permalink
Individual feature importance interpret QA (#2186)
Browse files Browse the repository at this point in the history
* add

Signed-off-by: vinutha karanth <[email protected]>

* update

Signed-off-by: vinutha karanth <[email protected]>

* cleanup

Signed-off-by: vinutha karanth <[email protected]>

* lintfix

Signed-off-by: vinutha karanth <[email protected]>

* lintfix

Signed-off-by: vinutha karanth <[email protected]>

* lintfix

Signed-off-by: vinutha karanth <[email protected]>

* fix row change err

Signed-off-by: vinutha karanth <[email protected]>

* address comments

Signed-off-by: vinutha karanth <[email protected]>

---------

Signed-off-by: vinutha karanth <[email protected]>
  • Loading branch information
vinuthakaranth authored Jul 24, 2023
1 parent 37c340c commit 29bfb51
Show file tree
Hide file tree
Showing 20 changed files with 595 additions and 237 deletions.
1 change: 1 addition & 0 deletions libs/core-ui/src/lib/Interfaces/ExplanationInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export interface IPrecomputedExplanations {
export interface ITextFeatureImportance {
text: string[];
localExplanations: number[][];
baseValues?: number[][];
}

export interface IEBMGlobalExplanation {
Expand Down
3 changes: 3 additions & 0 deletions libs/core-ui/src/lib/Interfaces/TextExplanationInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ export interface ITextExplanationDashboardData {
localExplanations: number[][];
prediction: number[];
text: string[];
baseValues?: number[][];
predictedY?: number[] | number[][] | string[] | string | number;
trueY?: number[] | number[][] | string[] | string | number;
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ export class Utils {
return sortedList;
}

public static addItem(value: number, radio: string | undefined): boolean {
return (
radio === RadioKeys.All ||
(radio === RadioKeys.Neg && value <= 0) ||
(radio === RadioKeys.Pos && value >= 0)
);
}

public static takeTopK(list: number[], k: number): number[] {
/*
* Returns a list after splicing and taking the top K
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,39 @@
// Licensed under the MIT License.

import { ITheme } from "@fluentui/react";
import {
IHighchartsConfig,
getPrimaryChartColor,
getPrimaryBackgroundChartColor
} from "@responsible-ai/core-ui";
import { IHighchartsConfig } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import { SeriesOptionsType } from "highcharts";
import _ from "lodash";

import { Utils } from "../../CommonUtils";
import { IChartProps } from "../../Interfaces/IChartProps";

function findNearestIndex(
array: number[],
target?: number
): number | undefined {
if (!target) {
return array.length;
}
const nearestElement = _.minBy(array, (element) =>
Math.abs(element - target)
);
return _.indexOf(array, nearestElement);
}

export function getTokenImportancesChartOptions(
props: IChartProps,
theme: ITheme
): IHighchartsConfig {
const importances = props.localExplanations;
const k = props.topK;
const sortedList = Utils.sortedTopK(importances, k, props.radio);

const outputFeatureImportanceLabel = `f ${
props.text[props.selectedTokenIndex || 0]
} (inputs)`;
const baseValueLabel = "base value";
const [x, y, ylabel, tooltip]: [number[], number[], string[], string[]] = [
[],
[],
Expand All @@ -46,6 +61,36 @@ export function getTokenImportancesChartOptions(
ylabel.push(props.text[idx]);
tooltip.push(str);
});

// add output feature importance
if (props.outputFeatureValue && props.baseValue) {
const outputFeatureValueIndex = findNearestIndex(
x,
props.outputFeatureValue
);
const baseValueFeatureValueIndex = findNearestIndex(x, props.baseValue);
if (outputFeatureValueIndex && baseValueFeatureValueIndex) {
if (Utils.addItem(props.outputFeatureValue, props.radio)) {
addItem(
x,
props.outputFeatureValue,
ylabel,
outputFeatureImportanceLabel,
outputFeatureValueIndex
);
}
if (Utils.addItem(props.baseValue, props.radio)) {
addItem(
x,
props.baseValue,
ylabel,
baseValueLabel,
baseValueFeatureValueIndex
);
}
}
}

// Put most significant word at the top by reversing order
tooltip.reverse();
ylabel.reverse();
Expand All @@ -54,11 +99,10 @@ export function getTokenImportancesChartOptions(
const data: any[] = [];
x.forEach((p, index) => {
const temp = {
borderColor: getPrimaryChartColor(theme),
color:
(p || 0) >= 0
? getPrimaryChartColor(theme)
: getPrimaryBackgroundChartColor(theme),
? theme.semanticColors.errorText
: theme.semanticColors.link,
x: index,
y: p
};
Expand All @@ -68,6 +112,15 @@ export function getTokenImportancesChartOptions(
const series: SeriesOptionsType[] = [
{
data,
dataLabels: {
align: "center",
color: theme.semanticColors.bodyBackground,
enabled: true,
formatter(): string | number | undefined {
return this.x; // Display the Y-axis value inside the bar
},
inside: true
},
name: "",
showInLegend: false,
type: "bar"
Expand All @@ -80,11 +133,12 @@ export function getTokenImportancesChartOptions(
},
plotOptions: {
bar: {
minPointLength: 10,
tooltip: {
pointFormatter(): string {
return `${tooltip[this.x || 0]}: ${this.y || 0}`;
}
}
} // Set the minimum pixel width for bars
}
},
series,
Expand All @@ -98,3 +152,14 @@ export function getTokenImportancesChartOptions(
}
};
}

function addItem(
x: any[],
xValue: any,
yLabel: any[],
yLabelValue: any,
index: number
): void {
x.splice(index, 0, xValue);
yLabel.splice(index, 0, yLabelValue);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ export interface ITextExplanationViewState {
maxK: number;
topK: number;
radio: string;
// qaRadio?: string;
qaRadio?: string;
importances: number[];
singleTokenImportances: number[];
selectedToken: number;
tokenIndexes: number[];
text: string[];
outputFeatureImportances: number[][];
}

export const options: IChoiceGroupOption[] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ export interface ISidePanelOfChartProps {
selectedWeightVector: WeightVectorOption;
weightOptions: WeightVectorOption[];
weightLabels: any;
baseValue?: number;
outputFeatureValue?: number;
selectedTokenIndex?: number;
changeRadioButton: (
_event?: React.FormEvent,
item?: IChoiceGroupOption
Expand All @@ -63,6 +66,9 @@ export class SidePanelOfChart extends React.PureComponent<ISidePanelOfChartProps
localExplanations={this.props.importances}
topK={this.props.topK}
radio={this.props.radio}
baseValue={this.props.baseValue}
outputFeatureValue={this.props.outputFeatureValue}
selectedTokenIndex={this.props.selectedTokenIndex}
/>
</Stack.Item>
<Stack.Item grow className={classNames.chartRight}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,25 @@ import {
export interface ITextExplanationDashboardStyles {
chartRight: IStyle;
textHighlighting: IStyle;
predictedAnswer: IStyle;
boldText: IStyle;
}

export const textExplanationDashboardStyles: () => IProcessedStyleSet<ITextExplanationDashboardStyles> =
() => {
const theme = getTheme();
return mergeStyleSets<ITextExplanationDashboardStyles>({
boldText: {
fontWeight: "bold"
},
chartRight: {
maxWidth: "230px",
minWidth: "230px"
},
predictedAnswer: {
fontWeight: "bold",
paddingBottom: "14px"
},
textHighlighting: {
borderColor: theme.semanticColors.variantBorder,
borderRadius: "1px",
Expand Down
Loading

0 comments on commit 29bfb51

Please sign in to comment.