Skip to content

Commit

Permalink
Add AM
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 8b3405b commit 5e9c8e7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/glue/analyze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import os
from typing import Dict, Tuple, List
from typing import Dict, Tuple, List, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -108,7 +108,8 @@ def compute_measurement(
margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return -margins.sum()

def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
print(type(batch["attention_mask"]))
return batch["attention_mask"]


Expand Down

0 comments on commit 5e9c8e7

Please sign in to comment.