From a3780890ff924bf610332c0bc6c5dfe9c054fa72 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Wed, 20 Mar 2024 00:04:34 -0400
Subject: [PATCH] Change arguments

---
 examples/cifar/analyze.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py
index 6ad7141..b892538 100644
--- a/examples/cifar/analyze.py
+++ b/examples/cifar/analyze.py
@@ -6,12 +6,12 @@ import torch
 import torch.nn.functional as F
 
 from arguments import FactorArguments
-
+from torch import nn
 from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
 from kronfluence.analyzer import Analyzer, prepare_model
 from kronfluence.task import Task
 
-BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
+BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]
 
 
 def parse_args():
@@ -53,38 +53,38 @@ def parse_args():
 
 
 class ClassificationTask(Task):
-
     def compute_train_loss(
         self,
-        batch: BATCH_DTYPE,
-        outputs: torch.Tensor,
+        batch: BATCH_TYPE,
+        model: nn.Module,
         sample: bool = False,
     ) -> torch.Tensor:
-        _, labels = batch
-
+        inputs, labels = batch
+        logits = model(inputs)
         if not sample:
-            return F.cross_entropy(outputs, labels, reduction="sum")
+            return F.cross_entropy(logits, labels, reduction="sum")
         with torch.no_grad():
-            probs = torch.nn.functional.softmax(outputs, dim=-1)
+            probs = torch.nn.functional.softmax(logits, dim=-1)
             sampled_labels = torch.multinomial(
                 probs,
                 num_samples=1,
             ).flatten()
-        return F.cross_entropy(outputs, sampled_labels.detach(), reduction="sum")
+        return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")
 
     def compute_measurement(
         self,
-        batch: BATCH_DTYPE,
-        outputs: torch.Tensor,
+        batch: BATCH_TYPE,
+        model: nn.Module,
    ) -> torch.Tensor:
         # Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
-        _, labels = batch
+        inputs, labels = batch
+        logits = model(inputs)
 
-        bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
-        logits_correct = outputs[bindex, labels]
+        bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
+        logits_correct = logits[bindex, labels]
 
-        cloned_logits = outputs.clone()
-        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)
+        cloned_logits = logits.clone()
+        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)
 
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return -margins.sum()
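
With this change, the Task methods receive the model and the raw (inputs, labels) batch and run the forward pass themselves, instead of taking precomputed outputs. Below is a minimal sketch of the new call pattern; it is not code from the patch, and the tiny linear model and random batch are stand-ins for the example's construct_resnet9() and CIFAR-10 loader.

    import torch
    from torch import nn

    from examples.cifar.analyze import ClassificationTask  # the class patched above

    task = ClassificationTask()
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in model
    batch = (torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))  # (inputs, labels)

    loss = task.compute_train_loss(batch=batch, model=model)  # summed CE on true labels
    sampled_loss = task.compute_train_loss(batch=batch, model=model, sample=True)  # labels sampled from softmax
    measurement = task.compute_measurement(batch=batch, model=model)  # TRAK-style margin, summed and negated

In the full example, the model would typically be wrapped with prepare_model (imported from kronfluence.analyzer in the file above) before being handed to the Analyzer.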