From 50961a8d770d36210b8071c2be74c477578fdb7f Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Mon, 1 Jul 2024 13:56:08 -0400
Subject: [PATCH] Add exact modules to track

---
 examples/dailymail/analyze.py | 32 +++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/examples/dailymail/analyze.py b/examples/dailymail/analyze.py
index 244818c..39611c7 100644
--- a/examples/dailymail/analyze.py
+++ b/examples/dailymail/analyze.py
@@ -1,7 +1,7 @@
 import argparse
 import logging
 import os
-from typing import Dict, Optional
+from typing import Dict, Optional, List
 
 import torch
 import torch.nn.functional as F
@@ -147,6 +147,36 @@ def compute_measurement(
         masks = batch["labels"].view(-1) != -100
         return -margins[masks].sum()
 
+    def tracked_modules(self) -> List[str]:
+        total_modules = []
+
+        # Add attention layers:
+        for i in range(6):
+            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.q")
+            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.k")
+            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.v")
+            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.o")
+
+            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.q")
+            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.k")
+            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.v")
+            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.o")
+
+            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.q")
+            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.k")
+            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.v")
+            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.o")
+
+        # Add MLP layers:
+        for i in range(6):
+            total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wi")
+            total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wo")
+
+            total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wi")
+            total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wo")
+
+        return total_modules
+
     def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
         return batch["attention_mask"]
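
Note: the module names in the patch follow the Hugging Face T5 naming scheme (t5-small has 6 encoder and 6 decoder blocks; the q/k/v/o attention projections and the wi/wo feed-forward layers are nn.Linear submodules). Below is a minimal sanity-check sketch, assuming the transformers package is available; expected_names is a hypothetical stand-in for the list returned by tracked_modules() above and is not part of the patch.

    # Sketch: confirm every tracked module name resolves to an nn.Linear in t5-small.
    import torch.nn as nn
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    modules = dict(model.named_modules())

    # Rebuild the same names that tracked_modules() returns (hypothetical helper list).
    expected_names = []
    for i in range(6):
        for proj in ("q", "k", "v", "o"):
            expected_names.append(f"encoder.block.{i}.layer.0.SelfAttention.{proj}")
            expected_names.append(f"decoder.block.{i}.layer.0.SelfAttention.{proj}")
            expected_names.append(f"decoder.block.{i}.layer.1.EncDecAttention.{proj}")
        for proj in ("wi", "wo"):
            expected_names.append(f"encoder.block.{i}.layer.1.DenseReluDense.{proj}")
            expected_names.append(f"decoder.block.{i}.layer.2.DenseReluDense.{proj}")

    missing = [name for name in expected_names if not isinstance(modules.get(name), nn.Linear)]
    assert not missing, f"Module names not found as nn.Linear: {missing}"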