diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py index 7416520..e5a5696 100644 --- a/examples/glue/analyze.py +++ b/examples/glue/analyze.py @@ -1,7 +1,7 @@ import argparse import logging import os -from typing import Dict, Tuple, List +from typing import Dict, Tuple, List, Optional import torch import torch.nn.functional as F @@ -108,7 +108,8 @@ def compute_measurement( margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() - def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: + def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]: + print(type(batch["attention_mask"])) return batch["attention_mask"]