diff --git a/responsibleai_text/responsibleai_text/managers/explainer_manager.py b/responsibleai_text/responsibleai_text/managers/explainer_manager.py
index 870af5af88..265bef60fd 100644
--- a/responsibleai_text/responsibleai_text/managers/explainer_manager.py
+++ b/responsibleai_text/responsibleai_text/managers/explainer_manager.py
@@ -29,6 +29,26 @@
                                                  Tokens)
 from responsibleai_text.utils.question_answering import QAPredictor
 
+try:
+    from interpret_text.generative.lime_tools.explainers import \
+        LocalExplanationSentenceEmbedder
+    interpret_text_explainers_installed = True
+except ImportError:
+    interpret_text_explainers_installed = False
+
+try:
+    from interpret_text.generative.model_lib.openai_tooling import ChatOpenAI
+    interpret_text_openai_tooling_installed = True
+except ImportError:
+    interpret_text_openai_tooling_installed = False
+
+try:
+    from sentence_transformers import SentenceTransformer
+    sentence_transformers_installed = True
+except ImportError:
+    sentence_transformers_installed = False
+
+
 CONTEXT = QuestionAnsweringFields.CONTEXT
 QUESTIONS = QuestionAnsweringFields.QUESTIONS
 SEP = Tokens.SEP
@@ -42,6 +62,7 @@
 MODEL = Metadata.MODEL
 EXPLANATION = '_explanation'
 TASK_TYPE = '_task_type'
+PROMPT = 'prompt'
 
 
 class ExplainerManager(BaseManager):
@@ -74,10 +95,13 @@ def __init__(self, model: Any, evaluation_examples: pd.DataFrame,
         """
         self._model = model
         self._target_column = target_column
-        if not isinstance(target_column, list):
+        if not isinstance(target_column, (list, type(None))):
             target_column = [target_column]
-        self._evaluation_examples = \
-            evaluation_examples.drop(columns=target_column)
+        if target_column is None:
+            self._evaluation_examples = evaluation_examples
+        else:
+            self._evaluation_examples = \
+                evaluation_examples.drop(columns=target_column)
         self._is_run = False
         self._is_added = False
         self._features = list(self._evaluation_examples.columns)
@@ -131,6 +155,74 @@ def compute(self):
                 eval_examples.append(question + SEP + context)
             self._explanation = [explainer_start(eval_examples),
                                  explainer_end(eval_examples)]
+        elif self._task_type == ModelTask.GENERATIVE_TEXT:
+            if not interpret_text_explainers_installed:
+                error = (
+                    "The required module "
+                    "'interpret_text.generative.lime_tools.explainers' "
+                    "is not installed."
+                )
+                raise RuntimeError(error)
+            if not interpret_text_openai_tooling_installed:
+                error = (
+                    "The required module "
+                    "'interpret_text.generative.model_lib.openai_tooling' "
+                    "is not installed."
+                )
+                raise RuntimeError(error)
+            if not sentence_transformers_installed:
+                error = (
+                    "The required package "
+                    "'sentence_transformers' "
+                    "is not installed."
+                )
+                raise RuntimeError(error)
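+
+            # Flatten each example to a single input string: QA-style data
+            # carries separate context/questions columns, while plain
+            # generative data carries a single prompt column.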
+            if CONTEXT in self._evaluation_examples.columns and \
+                    QUESTIONS in self._evaluation_examples.columns:
+                contexts = self._evaluation_examples[CONTEXT]
+                questions = self._evaluation_examples[QUESTIONS]
+                eval_examples = []
+                for context, question in zip(contexts, questions):
+                    eval_examples.append(question + SEP + context)
+            elif PROMPT in self._evaluation_examples.columns:
+                eval_examples = self._evaluation_examples[PROMPT].tolist()
+            else:
+                raise ValueError(
+                    "Neither the 'context'/'questions' columns nor a "
+                    "'prompt' column is present in evaluation_examples"
+                )
+            sentence_embedder = SentenceTransformer('all-MiniLM-L6-v2')
+            explainer = LocalExplanationSentenceEmbedder(
+                sentence_embedder=sentence_embedder,
+                perturbation_model="removal",
+                partition_fn="sentences",
+                progress_bar=None)
+            max_completion = 50  # max new tokens for each completion
+
+            api_settings = {
+                "api_type": self._model.model.api_type,
+                "api_base": self._model.model.api_base,
+                "api_version": self._model.model.api_version,
+                "api_key": self._model.model.api_key
+            }
+            model_wrapped = ChatOpenAI(
+                engine=self._model.model.engine,
+                encoding="cl100k_base",
+                api_settings=api_settings)
+            completions = model_wrapped.sample(
+                eval_examples, max_new_tokens=max_completion)
+
+            explanation = []
+            for i, completion in enumerate(completions):
+                attribution, parts = explainer.attribution(
+                    model_wrapped, eval_examples[i], completion)
+                explanation.append((attribution, parts))
+
+            self._explanation = explanation
         else:
             raise ValueError("Unknown task type: {}".format(self._task_type))
 
diff --git a/responsibleai_text/setup.py b/responsibleai_text/setup.py
index 8b0bada790..9c9a7b689e 100644
--- a/responsibleai_text/setup.py
+++ b/responsibleai_text/setup.py
@@ -24,6 +24,10 @@
         'bert_score',
         'nltk',
         'rouge_score'
+    ],
+    'generative_text': [
+        'interpret_text',
+        'sentence_transformers'
     ]
 }
 setuptools.setup(
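Usage note (editor's sketch, not part of the diff): with the new extra installed, e.g. pip install responsibleai-text[generative_text] (assuming that published package name), the branch above is reached through the usual RAI insights flow. The RAITextInsights entry point and the attribute layout of the model wrapper (engine, api_type, api_base, api_version, api_key under .model) are assumptions inferred from the attribute accesses in compute(), not APIs this diff introduces:

    from types import SimpleNamespace

    import pandas as pd
    from responsibleai_text import ModelTask, RAITextInsights

    # Hypothetical stand-in for the user's model wrapper: compute() only
    # reads these attributes off self._model.model when it builds the
    # ChatOpenAI wrapper, so any object with this shape should suffice.
    model = SimpleNamespace(model=SimpleNamespace(
        engine='gpt-4',
        api_type='azure',
        api_base='https://<endpoint>.openai.azure.com/',
        api_version='2023-07-01-preview',
        api_key='<api-key>'))

    # A bare 'prompt' column is enough; the __init__ change above lets
    # ExplainerManager tolerate target_column=None.
    test = pd.DataFrame({'prompt': [
        'Summarize the plot of Hamlet in two sentences.',
        'List three common uses of a hash table.']})

    rai_insights = RAITextInsights(model, test, target_column=None,
                                   task_type=ModelTask.GENERATIVE_TEXT)
    rai_insights.explainer.add()
    rai_insights.compute()
    # After compute(), the explainer manager holds one (attribution, parts)
    # tuple per prompt: sentence-level LIME attributions over the prompt's
    # parts, scored against the sampled completion.

Partitioning on sentences with a removal perturbation model keeps the LIME features at the sentence level rather than the token level, which should bound the number of perturbed model calls by the number of sentences in each prompt.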