add base values and importances for ending tokens to QA explanations
imatiach-msft committed Jul 11, 2023
1 parent 71efc3b commit 4fa2bd6
Showing 4 changed files with 78 additions and 19 deletions.
1 change: 1 addition & 0 deletions responsibleai/responsibleai/_interfaces.py
@@ -60,6 +60,7 @@ class FeatureImportance:


 class TextFeatureImportance:
+    baseValues: List[float]
     localExplanations: List
     text: List[str]

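As an aside (not part of the commit), a small illustration with made-up numbers of how the explainer manager below populates the new baseValues field for a question answering instance; the real shapes come from the SHAP explanation objects and may differ.

from responsibleai._interfaces import TextFeatureImportance

fi = TextFeatureImportance()
fi.text = ["who", "wrote", "it", "[SEP]", "ann", "wrote", "it"]
# per-token attributions toward the start logits and toward the end logits
fi.localExplanations = [[0.8, 0.1, 0.0, 0.0, 1.2, 0.3, 0.1],
                        [0.2, 0.0, 0.1, 0.0, 0.4, 0.2, 1.1]]
# expected (base) start and end logits
fi.baseValues = [[-1.3] * 7, [-1.1] * 7]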
54 changes: 35 additions & 19 deletions responsibleai_text/responsibleai_text/managers/explainer_manager.py
@@ -119,13 +119,17 @@ def compute(self):
             qa_predictor = QAPredictor(self._model)
             qa_start = qa_predictor.predict_qa_start
             qa_start.__func__.output_names = qa_predictor.output_names
-            explainer = shap.Explainer(qa_start, self._model.tokenizer)
+            explainer_start = shap.Explainer(qa_start, self._model.tokenizer)
+            qa_end = qa_predictor.predict_qa_end
+            qa_end.__func__.output_names = qa_predictor.output_names
+            explainer_end = shap.Explainer(qa_end, self._model.tokenizer)
             context = self._evaluation_examples[CONTEXT]
             questions = self._evaluation_examples[QUESTIONS]
             eval_examples = []
             for context, question in zip(context, questions):
                 eval_examples.append(question + SEP + context)
-            self._explanation = explainer(eval_examples)
+            self._explanation = [explainer_start(eval_examples),
+                                 explainer_end(eval_examples)]
         else:
             raise ValueError("Unknown task type: {}".format(self._task_type))

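As an aside (not part of the commit), here is a condensed, self-contained sketch of the two-explainer pattern used above, following the question answering recipe from the SHAP documentation. The checkpoint name, the literal "[SEP]" separator, and the helper names are illustrative assumptions, not the repository's own code.

import shap
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# illustrative checkpoint; any extractive QA model follows the same pattern
name = "deepset/minilm-uncased-squad2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name)


def _qa_logits(questions, start):
    # each input is "question[SEP]context"; return per-token logits
    outs = []
    for q in questions:
        question, context = q.split("[SEP]")
        inputs = tokenizer(question, context, return_tensors="pt")
        out = model(**inputs)
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs


def predict_qa_start(questions):
    return _qa_logits(questions, start=True)


def predict_qa_end(questions):
    return _qa_logits(questions, start=False)


def output_names(inputs):
    # name each output position after its input token
    question, context = inputs.split("[SEP]")
    ids = tokenizer(question, context)["input_ids"]
    return [tokenizer.decode([i]) for i in ids]


predict_qa_start.output_names = output_names
predict_qa_end.output_names = output_names

explainer_start = shap.Explainer(predict_qa_start, tokenizer)
explainer_end = shap.Explainer(predict_qa_end, tokenizer)
examples = ["What does RAI stand for?[SEP]RAI stands for Responsible AI."]
explanation = [explainer_start(examples), explainer_end(examples)]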
@@ -223,22 +227,30 @@ def _compute_global_importances(self, explanation):
             scores = convert_to_list(np.abs(global_exp.values).mean(1))
             intercept = global_exp.base_values.mean(0)
         elif self._task_type == ModelTask.QUESTION_ANSWERING:
-            flattened_features = explanation._flatten_feature_names()
+            flattened_start_features = explanation[0]._flatten_feature_names()
+            flattened_end_features = explanation[1]._flatten_feature_names()
             scores = []
             features = []
-            for key in flattened_features.keys():
+            for key in flattened_start_features.keys():
                 features.append(key)
                 token_importances = []
-                for importances in flattened_features[key]:
-                    token_importances.append(np.mean(np.abs(importances)))
+                for importances_start, importances_end in zip(
+                        flattened_start_features[key],
+                        flattened_end_features[key]):
+                    abs_start_imps = np.abs(importances_start)
+                    abs_end_imps = np.abs(importances_end)
+                    importances = (abs_start_imps + abs_end_imps) / 2
+                    mean_importances = np.mean(importances)
+                    token_importances.append(mean_importances)
                 scores.append(np.mean(token_importances))
+            start_base_values = explanation[0].base_values
+            end_base_values = explanation[1].base_values
             base_values = [
-                base_values.mean()
-                for base_values in explanation.base_values]
+                (sbv.mean() + ebv.mean()) / 2
+                for sbv, ebv in zip(start_base_values, end_base_values)]
             intercept = sum(base_values) / len(base_values)
         else:
             raise ValueError("Unknown task type: {}".format(self._task_type))

         return features, scores, intercept

     def _compute_text_feature_importances(self, explanation):
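As an aside (not part of the commit), the arithmetic of the new branch on made-up numbers: each token's global score is the mean of the averaged absolute start and end attributions, and the intercept averages the per-instance start and end base values.

import numpy as np

# one token's attributions toward the start and end logits (made-up numbers)
importances_start = np.array([0.4, -0.1, 0.3])
importances_end = np.array([-0.2, 0.5, 0.1])
combined = (np.abs(importances_start) + np.abs(importances_end)) / 2
token_score = np.mean(combined)        # mean of [0.3, 0.3, 0.2] ~= 0.267

# base values for two instances, start and end (made-up numbers)
start_base_values = [np.array([-1.0, -2.0]), np.array([-1.5, -0.5])]
end_base_values = [np.array([-0.8, -1.2]), np.array([-1.1, -0.9])]
base_values = [(sbv.mean() + ebv.mean()) / 2
               for sbv, ebv in zip(start_base_values, end_base_values)]
intercept = sum(base_values) / len(base_values)   # -> -1.125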
@@ -251,21 +263,25 @@ def _compute_text_feature_importances(self, explanation):
"""
         text_feature_importances = []
         is_classif_task = self._is_classification_task
-        for instance in explanation:
-            text_feature_importance = TextFeatureImportance()
-            if is_classif_task:
+        if is_classif_task:
+            for instance in explanation:
+                text_feature_importance = TextFeatureImportance()
                 text_feature_importance.localExplanations = \
                     instance.values.tolist()
                 text_feature_importance.text = instance.data
-            elif self._task_type == ModelTask.QUESTION_ANSWERING:
+                text_feature_importances.append(text_feature_importance)
+        elif self._task_type == ModelTask.QUESTION_ANSWERING:
+            for i_start, i_end in zip(explanation[0], explanation[1]):
+                text_feature_importance = TextFeatureImportance()
                 text_feature_importance.localExplanations = \
-                    instance.values.tolist()
-                text_feature_importance.text = instance.data
-            else:
-                raise ValueError("Unknown task type: {}".format(
-                    self._task_type))
-            text_feature_importances.append(text_feature_importance)
+                    [i_start.values.tolist(), i_end.values.tolist()]
+                text_feature_importance.text = i_start.data
+                text_feature_importance.baseValues = \
+                    [i_start.base_values.tolist(), i_end.base_values.tolist()]
+                text_feature_importances.append(text_feature_importance)
+        else:
+            raise ValueError("Unknown task type: {}".format(
+                self._task_type))
         return text_feature_importances

     @property
10 changes: 10 additions & 0 deletions responsibleai_text/responsibleai_text/utils/question_answering.py
@@ -59,6 +59,16 @@ def predict_qa_start(self, questions):
"""
         return self.predict_qa(questions, True)

+    def predict_qa_end(self, questions):
+        """Define predictions outputting the logits for the end of the range.
+
+        :param questions: The questions and context to predict on.
+        :type questions: list[str]
+        :return: The logits for the end of the range.
+        :rtype: list[list[float]]
+        """
+        return self.predict_qa(questions, False)
+
     def output_names(self, inputs):
         """Define the output names as tokens.
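As an aside (not part of the commit), a hedged usage sketch of the predictor pair; the import path is inferred from the file location above, and the HuggingFace pipeline and "[SEP]" separator are assumptions rather than documented requirements.

from transformers import pipeline

from responsibleai_text.utils.question_answering import QAPredictor

qa_model = pipeline("question-answering")   # assumed: a HuggingFace QA pipeline
predictor = QAPredictor(qa_model)
example = "Who wrote the note?[SEP]Ann wrote the note."
start_logits = predictor.predict_qa_start([example])   # logits for span start
end_logits = predictor.predict_qa_end([example])       # logits for span end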
32 changes: 32 additions & 0 deletions responsibleai_text/tests/rai_text_insights_validator.py
@@ -7,6 +7,13 @@
 from responsibleai_text import ModelTask


+def validate_global_importances(exp_data):
+    global_importances = exp_data.globalFeatureImportance
+    scores = global_importances.scores
+    features = global_importances.featureNames
+    assert scores.shape[0] == len(features)
+
+
 def validate_rai_text_insights(
         rai_text_insights,
         classes,
@@ -17,6 +24,31 @@ def validate_rai_text_insights(
     pd.testing.assert_frame_equal(rai_text_insights.test, test_data)
     assert rai_text_insights.target_column == target_column
     assert rai_text_insights.task_type == task_type
+    explanation = rai_text_insights.explainer.get()
+    assert explanation is None or isinstance(explanation, list)
+    explanation_data = rai_text_insights.explainer.get_data()
+    assert explanation_data is None or isinstance(explanation_data, list)
     if task_type == ModelTask.TEXT_CLASSIFICATION:
         np.testing.assert_array_equal(rai_text_insights._classes,
                                       classes)
+        if explanation_data:
+            exp_data = explanation_data[0].precomputedExplanations
+            validate_global_importances(exp_data)
+            local_data = exp_data.textFeatureImportance
+            local_importances = local_data.localExplanations
+            text = local_data.text
+            assert local_importances.shape[0] == len(test_data)
+            assert local_importances.shape[1] == len(text)
+    if task_type == ModelTask.QUESTION_ANSWERING:
+        if explanation_data:
+            exp_data = explanation_data[0].precomputedExplanations
+            validate_global_importances(exp_data)
+            local_data = exp_data.textFeatureImportance
+            local_importances = local_data.localExplanations
+            text = local_data.text
+            base_values = local_data.baseValues
+            num_rows = len(test_data)
+            assert local_importances.shape[0] == 2
+            assert local_importances.shape[1] == num_rows
+            assert local_importances.shape[2] == len(text)
+            assert len(base_values) == num_rows

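As an aside (not part of the commit), the question answering assertions above expect a stack of start and end attributions; a tiny shape sketch with made-up sizes:

import numpy as np

num_rows, num_tokens = 4, 12
# 2 output heads (start, end) x num_rows examples x num_tokens tokens
local_importances = np.zeros((2, num_rows, num_tokens))
# one [start, end] pair of base-value lists per example
base_values = [[[0.0] * num_tokens, [0.0] * num_tokens] for _ in range(num_rows)]

assert local_importances.shape == (2, num_rows, num_tokens)
assert len(base_values) == num_rows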