Why do captum's perturbation and IG treat input & target differently? #1456

Open
rbelew opened this issue Dec 12, 2024 · 0 comments

rbelew commented Dec 12, 2024

❓ Questions and Help

I've been successful using captum's LayerIntegratedGradients class,
but none of my attempts to use the same sorts of inputs and targets with
LLMAttribution seem to work.

I'm working with a BertForMultipleChoice model, and the input is a list of
the repeated prompt followed by the choices:

# print a window of tokens around each choice's [SEP] separator:
# the last `prefix` tokens of the (repeated) prompt, then the choice text
for i, c in enumerate(tst['input_ids'][0]):
    indices = c.detach().tolist()
    sepIdx = indices.index(SEP_IDX)
    preTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx-prefix:sepIdx-1])
    choiceTokens = tokenizer.convert_ids_to_tokens(indices[sepIdx+1:])
    print(f"{i} {sepIdx} {' '.join(preTokens):>55}\t[SEP] {' '.join(choiceTokens)}")

	0 120           behalf of fake charities . as webster sees it	[SEP] recognizing the guidelines commentary is authoritative [SEP]
	1 109   when he sol ##ici ##ted personal information from the	[SEP] holding that a sentencing guide ##line pre ##va ##ils over its commentary if the two are inconsistent [SEP]
	2 98                                       l ( b ) ( 9 ) ( a	[SEP] holding that sentencing guidelines commentary must be given controlling weight unless it violate ##s the constitution or a federal statute or is plainly inconsistent with the guidelines itself [SEP]
	3 99                                       ( b ) ( 9 ) ( a )	[SEP] holding that commentary is not authoritative if it is inconsistent with or a plainly er ##rone ##ous reading of the guide ##line it interpret ##s or explains [SEP]
	4 119           on behalf of fake charities . as webster sees	[SEP] holding that guidelines commentary is generally authoritative [SEP]
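
(For orientation: tst here is a single tokenized example, so each tensor has shape (1, num_choices, seq_len). A minimal sketch of how such a batch can be built; promptTxt and choices are placeholder names, not my actual variables:)

    # standard multiple-choice encoding (sketch): pair the repeated prompt
    # with each choice, then add a batch dimension -> (1, num_choices, seq_len)
    prompts = [promptTxt] * len(choices)
    enc = tokenizer(prompts, choices,
                    padding='max_length', truncation=True,
                    return_tensors='pt')
    tst = {k: v.unsqueeze(0) for k, v in enc.items()}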

I'm using LayerIntegratedGradients with a test example and a scalar target
representing the index of the correct (multiple-choice) answer, like this:

    tstEGTuple = (tst['input_ids'],
                  tst['attention_mask'],
                  tst['token_type_ids'])
    targetIdx = 3  # index of the correct choice for this particular test example

    lig = LayerIntegratedGradients(custForwardModel, model.bert.embeddings)
    attributions_ig = lig.attribute(tstEGTuple, n_steps=5, target=targetIdx)

and that works, e.g. allowing calculations like summarize_attributions(attributions_ig), viz.VisualizationDataRecord(), etc.
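
(For reference, custForwardModel is essentially a thin wrapper like this sketch; the essential point is that BertForMultipleChoice returns logits of shape (batch, num_choices), so the scalar target selects the choice column:)

    def custForwardModel(input_ids, attention_mask, token_type_ids):
        # BertForMultipleChoice flattens (batch, num_choices, seq_len)
        # internally and returns logits of shape (batch, num_choices);
        # captum's `target` then indexes the choice dimension
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
        return out.logits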

For LLMAttribution I am following the Llama2 tutorial. The closest I can get with LLMAttribution seems to require using TextTokenInput for the input, but raw text for the target?

    # flatten the first choice's token ids back to text for TextTokenInput
    in0 = tst['input_ids'][0][0]
    in0_tokens = tokenizer.convert_ids_to_tokens(in0)
    in0Txt = ' '.join(in0_tokens)
    in4captum = TextTokenInput(in0Txt, tokenizer, skip_tokens=skip_tokens)

    # the target is the text of the correct choice, passed as a raw string
    target = targetList[egIdx]
    targetIn = tst['input_ids'][0][target]
    targ_tokens = tokenizer.convert_ids_to_tokens(targetIn)
    targTxt = ' '.join(targ_tokens)
    # targ4captum = TextTokenInput(targTxt, tokenizer, skip_tokens=skip_tokens)

    llm_attr = LLMAttribution(fa, tokenizer)
    attributions_fa = llm_attr.attribute(in4captum, target=targTxt)

but this raises an exception saying that prepare_inputs_for_generation isn't
implemented for this BertForMultipleChoice model:

	Traceback (most recent call last):
	  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 874, in <module>
	    main()
	  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 854, in main
	    captumPerturb(model,tokenizer,tstEGTensorDict,tstEGtarget,OutDir)
	  File "/Users/rik/Code/eclipse/ai4law/src/run_multChoice.py", line 479, in captumPerturb
	    attributions_fa = llm_attr.attribute(in4captum, target=targTxt)
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 667, in attribute
	    cur_attr = self.attr_method.attribute(
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/log/dummy_log.py", line 39, in wrapper
	    return func(*args, **kwargs)
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
	    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/_utils/common.py", line 588, in _run_forward
	    output = forward_func(
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/captum/attr/_core/llm_attr.py", line 567, in _forward_func
	    model_inputs = self.model.prepare_inputs_for_generation(
	  File "/Users/rik/data/pkg/miniconda3/envs/ai4law2/lib/python3.11/site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
	    raise NotImplementedError(
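
For comparison, what I'd expect to be the perturbation analogue of the working IG call is pointing FeatureAblation directly at the same forward function, with the same tuple input and scalar target (a sketch, assuming fa above was built from the same custForwardModel):

    from captum.attr import FeatureAblation

    # the direct perturbation counterpart of the IG call above:
    # same tuple input, same scalar target, no LLMAttribution wrapper
    fa = FeatureAblation(custForwardModel)
    attributions_fa = fa.attribute(tstEGTuple, target=targetIdx)

So my question, as in the title: why does the perturbation path via LLMAttribution require a generative interface (prepare_inputs_for_generation) and a raw-text target, while IG works directly with the tuple input and a scalar target?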

Thanks for any suggestions!

I also posted this question on the Discussion Forum.
