Commit

Add exact modules to track
pomonam committed Jul 1, 2024
1 parent 46df33a commit 50961a8
Showing 1 changed file with 31 additions and 1 deletion.
examples/dailymail/analyze.py: 32 changes (31 additions & 1 deletion)
@@ -1,7 +1,7 @@
import argparse
import logging
import os
-from typing import Dict, Optional
+from typing import Dict, Optional, List

import torch
import torch.nn.functional as F
@@ -147,6 +147,36 @@ def compute_measurement(
        masks = batch["labels"].view(-1) != -100
        return -margins[masks].sum()

    def tracked_modules(self) -> List[str]:
        total_modules = []

        # Add attention layers:
        for i in range(6):
            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.q")
            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.k")
            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.v")
            total_modules.append(f"encoder.block.{i}.layer.0.SelfAttention.o")

            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.q")
            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.k")
            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.v")
            total_modules.append(f"decoder.block.{i}.layer.0.SelfAttention.o")

            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.q")
            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.k")
            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.v")
            total_modules.append(f"decoder.block.{i}.layer.1.EncDecAttention.o")

        # Add MLP layers:
        for i in range(6):
            total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wi")
            total_modules.append(f"encoder.block.{i}.layer.1.DenseReluDense.wo")

            total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wi")
            total_modules.append(f"decoder.block.{i}.layer.2.DenseReluDense.wo")

        return total_modules

    def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
        return batch["attention_mask"]

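Every name returned by tracked_modules has to match a module path in the underlying Hugging Face T5 model, or nothing will be tracked. Below is a minimal sanity-check sketch, assuming the example is built on the t5-small checkpoint (6 encoder and 6 decoder blocks, as the range(6) loops suggest) and that the transformers library is installed; the checkpoint name is an assumption, so substitute whatever model the example actually fine-tunes:

    # Sketch: confirm every tracked name resolves to an nn.Linear module.
    # Assumes the "t5-small" checkpoint (6 encoder / 6 decoder blocks);
    # replace with the checkpoint used by the example if it differs.
    import torch.nn as nn
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    modules = dict(model.named_modules())

    tracked = []
    for i in range(6):
        for proj in ("q", "k", "v", "o"):
            tracked.append(f"encoder.block.{i}.layer.0.SelfAttention.{proj}")
            tracked.append(f"decoder.block.{i}.layer.0.SelfAttention.{proj}")
            tracked.append(f"decoder.block.{i}.layer.1.EncDecAttention.{proj}")
        for proj in ("wi", "wo"):
            tracked.append(f"encoder.block.{i}.layer.1.DenseReluDense.{proj}")
            tracked.append(f"decoder.block.{i}.layer.2.DenseReluDense.{proj}")

    missing = [name for name in tracked if not isinstance(modules.get(name), nn.Linear)]
    print("Missing or non-Linear:", missing)  # expect []

An empty list means every name in the commit resolves to an nn.Linear layer in the checkpoint.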
