diff --git a/examples/openwebtext/task.py b/examples/openwebtext/task.py index 3e3c264..343dba2 100644 --- a/examples/openwebtext/task.py +++ b/examples/openwebtext/task.py @@ -53,11 +53,12 @@ def compute_measurement( def get_influence_tracked_modules(self) -> List[str]: total_modules = [] - for i in range(32): - total_modules.append(f"model.layers.{i}.self_attn.q_proj") - total_modules.append(f"model.layers.{i}.self_attn.k_proj") - total_modules.append(f"model.layers.{i}.self_attn.v_proj") - total_modules.append(f"model.layers.{i}.self_attn.o_proj") + # You can uncomment the following lines if you would like to compute influence also on attention layers. + # for i in range(32): + # total_modules.append(f"model.layers.{i}.self_attn.q_proj") + # total_modules.append(f"model.layers.{i}.self_attn.k_proj") + # total_modules.append(f"model.layers.{i}.self_attn.v_proj") + # total_modules.append(f"model.layers.{i}.self_attn.o_proj") for i in range(32): total_modules.append(f"model.layers.{i}.mlp.gate_proj")