Skip to content

Commit

Permalink
Change arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 5e2b059 commit a378089
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions examples/cifar/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import torch
import torch.nn.functional as F
from arguments import FactorArguments

from torch import nn
from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task

BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]


def parse_args():
Expand Down Expand Up @@ -53,38 +53,38 @@ def parse_args():


class ClassificationTask(Task):

def compute_train_loss(
self,
batch: BATCH_DTYPE,
outputs: torch.Tensor,
batch: BATCH_TYPE,
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
_, labels = batch

inputs, labels = batch
logits = model(inputs)
if not sample:
return F.cross_entropy(outputs, labels, reduction="sum")
return F.cross_entropy(logits, labels, reduction="sum")
with torch.no_grad():
probs = torch.nn.functional.softmax(outputs, dim=-1)
probs = torch.nn.functional.softmax(logits, dim=-1)
sampled_labels = torch.multinomial(
probs,
num_samples=1,
).flatten()
return F.cross_entropy(outputs, sampled_labels.detach(), reduction="sum")
return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")

def compute_measurement(
self,
batch: BATCH_DTYPE,
outputs: torch.Tensor,
batch: BATCH_TYPE,
model: nn.Module,
) -> torch.Tensor:
# Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
_, labels = batch
inputs, labels = batch
logits = model(inputs)

bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
logits_correct = outputs[bindex, labels]
bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
logits_correct = logits[bindex, labels]

cloned_logits = outputs.clone()
cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)
cloned_logits = logits.clone()
cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)

margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return -margins.sum()
Expand Down

0 comments on commit a378089

Please sign in to comment.