From d8ceaa83db2daa27a9980302b61c950578893d4e Mon Sep 17 00:00:00 2001
From: Zachariah Carmichael <20629897+craymichael@users.noreply.github.com>
Date: Mon, 9 Sep 2024 09:24:17 -0700
Subject: [PATCH] Add additional gradient-based attribution methods to LLM
 Attribution (#1337)

Summary:
Add `LayerGradientXActivation` and `LayerGradientShap` to the supported gradient-based LLM attribution methods.

Pull Request resolved: https://github.com/pytorch/captum/pull/1337

Test Plan: `pytest tests/attr -k TestLLMGradAttr` with new test cases via the parameterized library

Reviewed By: cyrjano

Differential Revision: D62221000

Pulled By: craymichael

fbshipit-source-id: fb5f170e13a62355357d46d3ef7a2464e8eb80ab
---
 captum/attr/_core/deep_lift.py             |  5 +-
 captum/attr/_core/layer/layer_deep_lift.py |  6 +-
 captum/attr/_core/llm_attr.py              | 93 +++++++++++++---------
 tests/attr/test_llm_attr.py                | 73 +++++++++++++----
 4 files changed, 121 insertions(+), 56 deletions(-)

diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py
index 24f44788b..0a127caf2 100644
--- a/captum/attr/_core/deep_lift.py
+++ b/captum/attr/_core/deep_lift.py
@@ -110,7 +110,7 @@ def __init__(
                     Default: 1e-10
         """
         GradientAttribution.__init__(self, model)
-        self.model = model
+        self.model: nn.Module = model
         self.eps = eps
         self.forward_handles: List[RemovableHandle] = []
         self.backward_handles: List[RemovableHandle] = []
@@ -324,7 +324,8 @@ def attribute(  # type: ignore
             warnings.warn(
                 """Setting forward, backward hooks and attributes on non-linear
                activations. The hooks and attributes will be removed
-            after the attribution is finished"""
+            after the attribution is finished""",
+                stacklevel=2,
             )
         # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
         #  `TensorOrTupleOfTensorsGeneric`.
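For context on the `stacklevel=2` addition above: it makes Python report the warning at the call site of `attribute` rather than at the `warnings.warn` line inside `deep_lift.py`. A minimal standalone sketch (illustrative names, independent of Captum):

```python
import warnings


def library_function() -> None:
    # With stacklevel=2 the warning is attributed to whoever called
    # library_function, not to this line inside the "library".
    warnings.warn("hooks will be removed after attribution", stacklevel=2)


library_function()  # the reported filename/lineno is this call site
```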
diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py
index 50b8dc0b3..05ae49e56 100644
--- a/captum/attr/_core/layer/layer_deep_lift.py
+++ b/captum/attr/_core/layer/layer_deep_lift.py
@@ -351,9 +351,9 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             grad_kwargs=grad_kwargs,
         )
 
-        attr_inputs = tuple(map(lambda attr: attr[0], attrs))
-        attr_baselines = tuple(map(lambda attr: attr[1], attrs))
-        gradients = tuple(map(lambda grad: grad[0], gradients))
+        attr_inputs = tuple(attr[0] for attr in attrs)
+        attr_baselines = tuple(attr[1] for attr in attrs)
+        gradients = tuple(grad[0] for grad in gradients)
 
         if custom_attribution_func is None:
             if self.multiplies_by_inputs:
diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 737ac5c4b..836536930 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -10,6 +10,8 @@
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
@@ -452,7 +454,11 @@ class LLMGradientAttribution(Attribution):
     and returns LLMAttributionResult
     """
 
-    SUPPORTED_METHODS = (LayerIntegratedGradients,)
+    SUPPORTED_METHODS = (
+        LayerGradientShap,
+        LayerGradientXActivation,
+        LayerIntegratedGradients,
+    )
     SUPPORTED_INPUTS = (TextTokenInput,)
 
     def __init__(
@@ -473,14 +479,14 @@ class created with the llm model that follows huggingface style
 
         super().__init__(attr_method.forward_func)
 
-        # shallow copy is enough to avoid modifying original instance
-        self.attr_method: GradientAttribution = copy(attr_method)
-        self.attr_method.forward_func = self._forward_func
-
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
         self.model: nn.Module = cast(nn.Module, self.forward_func)
 
+        # shallow copy is enough to avoid modifying original instance
+        self.attr_method: GradientAttribution = copy(attr_method)
+        self.attr_method.forward_func = GradientForwardFunc(self)
+
         self.tokenizer: TokenizerLike = tokenizer
         self.device: torch.device = (
             cast(torch.device, self.model.device)
@@ -488,38 +494,6 @@ class created with the llm model that follows huggingface style
             else next(self.model.parameters()).device
         )
 
-    def _forward_func(
-        self,
-        perturbed_tensor: Tensor,
-        inp: InterpretableInput,
-        target_tokens: Tensor,  # 1D tensor of target token ids
-        cur_target_idx: int,  # current target index
-    ) -> Tensor:
-        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
-
-        if cur_target_idx:
-            # the input batch size can be expanded by attr method
-            output_token_tensor = (
-                target_tokens[:cur_target_idx]
-                .unsqueeze(0)
-                .expand(perturbed_input.size(0), -1)
-                .to(self.device)
-            )
-            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
-        else:
-            new_input_tensor = perturbed_input
-
-        output_logits = self.model(new_input_tensor)
-
-        new_token_logits = output_logits.logits[:, -1]
-        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
-
-        target_token = target_tokens[cur_target_idx]
-        token_log_probs = log_probs[..., target_token]
-
-        # the attribution target is limited to the log probability
-        return token_log_probs
-
     def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
@@ -643,3 +617,48 @@ def attribute_future(self) -> Callable:
         raise NotImplementedError(
             "attribute_future is not implemented for LLMGradientAttribution"
         )
+
+
+class GradientForwardFunc(nn.Module):
+    """
+    A wrapper class for the forward function of a model in LLMGradientAttribution
+    """
+
+    def __init__(self, attr: LLMGradientAttribution) -> None:
+        super().__init__()
+        self.attr = attr
+        self.model: nn.Module = attr.model
+
+    def forward(
+        self,
+        perturbed_tensor: Tensor,
+        inp: InterpretableInput,
+        target_tokens: Tensor,  # 1D tensor of target token ids
+        cur_target_idx: int,  # current target index
+    ) -> Tensor:
+        perturbed_input = self.attr._format_model_input(
+            inp.to_model_input(perturbed_tensor)
+        )
+
+        if cur_target_idx:
+            # the input batch size can be expanded by attr method
+            output_token_tensor = (
+                target_tokens[:cur_target_idx]
+                .unsqueeze(0)
+                .expand(perturbed_input.size(0), -1)
+                .to(self.attr.device)
+            )
+            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
+        else:
+            new_input_tensor = perturbed_input
+
+        output_logits = self.model(new_input_tensor)
+
+        new_token_logits = output_logits.logits[:, -1]
+        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
+
+        target_token = target_tokens[cur_target_idx]
+        token_log_probs = log_probs[..., target_token]
+
+        # the attribution target is limited to the log probability
+        return token_log_probs
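The net effect of the `llm_attr.py` changes: `LLMGradientAttribution` now accepts `LayerGradientShap` and `LayerGradientXActivation` in addition to `LayerIntegratedGradients`, with the per-step forward logic moved into the `GradientForwardFunc` `nn.Module` wrapper, and extra keyword arguments to `attribute` forwarded to the underlying method. A minimal usage sketch, assuming a HuggingFace-style causal LM; the `gpt2` checkpoint, prompt, target string, and all-zeros baseline are illustrative placeholders, not part of this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.llm_attr import LLMGradientAttribution
from captum.attr._utils.interpretable_input import TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
emb_layer = model.get_input_embeddings()  # attribute w.r.t. the input embedding layer

# LayerGradientXActivation: no extra arguments beyond the target layer.
llm_attr = LLMGradientAttribution(LayerGradientXActivation(model, emb_layer), tokenizer)
inp = TextTokenInput("The capital of France is", tokenizer)
res = llm_attr.attribute(inp, " Paris")
print(res.seq_attr.shape)    # one attribution score per prompt token
print(res.token_attr.shape)  # (num target tokens, num prompt tokens)

# LayerGradientShap additionally expects `baselines` shaped like the tokenized prompt
# (token-id tensors, as in the new tests); they are passed through attribute(**kwargs).
n_prompt_tokens = res.seq_attr.shape[0]
baselines = (torch.zeros(1, n_prompt_tokens, dtype=torch.long),)
llm_attr_shap = LLMGradientAttribution(LayerGradientShap(model, emb_layer), tokenizer)
res_shap = llm_attr_shap.attribute(inp, " Paris", baselines=baselines)
```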
diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py
index 4d34d6983..94a3b454c 100644
--- a/tests/attr/test_llm_attr.py
+++ b/tests/attr/test_llm_attr.py
@@ -3,19 +3,19 @@
 # pyre-strict
 
 import copy
-from typing import Any, cast, Dict, List, NamedTuple, Optional, Type, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Type, Union
 
 import torch
-from captum._utils.models.linear_model import (  # @manual=//pytorch/captum/captum/_utils/models/linear_model:linear_model  # noqa: E501
-    SkLearnLasso,
-)
+from captum._utils.models.linear_model import SkLearnLasso
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import PerturbationAttribution
+from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized, parameterized_class
 from tests.helpers import BaseTest
@@ -379,15 +379,30 @@ def test_futures_not_implemented(self) -> None:
 class TestLLMGradAttr(BaseTest):
     device: str
 
-    def test_llm_attr(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -402,15 +417,30 @@ def test_llm_attr(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_without_target(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr_without_target(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)
 
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -424,15 +454,30 @@ def test_llm_attr_without_target(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_with_skip_tokens(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1]]),)),
+        ]
+    )
+    def test_llm_attr_with_skip_tokens(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
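These parameterized tests drive the new `GradientForwardFunc` path end to end with a dummy model. For reference, a minimal standalone sketch (dummy shapes, no Captum dependency) of the per-step attribution target that `GradientForwardFunc.forward` returns, i.e. the log-probability of the current target token for each perturbed input:

```python
import torch

# Dummy next-token logits: a batch of 2 perturbed inputs over a 5-token vocabulary.
new_token_logits = torch.randn(2, 5)
log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

target_tokens = torch.tensor([3, 1, 4])  # ids of the generated target tokens
cur_target_idx = 0                       # attribute the first generated token
token_log_probs = log_probs[..., target_tokens[cur_target_idx]]
print(token_log_probs.shape)  # torch.Size([2]): one scalar target per perturbed row
```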