From 5e9c8e7c4e12c75ee4ca14051c19e174ad067438 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 20 Mar 2024 04:50:32 -0400 Subject: [PATCH] Add AM --- examples/glue/analyze.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/glue/analyze.py b/examples/glue/analyze.py index 7416520..e5a5696 100644 --- a/examples/glue/analyze.py +++ b/examples/glue/analyze.py @@ -1,7 +1,7 @@ import argparse import logging import os -from typing import Dict, Tuple, List +from typing import Dict, Tuple, List, Optional import torch import torch.nn.functional as F @@ -108,7 +108,8 @@ def compute_measurement( margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() - def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor: + def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]: + print(type(batch["attention_mask"])) return batch["attention_mask"]