
Commit: Start wikitext add
pomonam committed Jul 9, 2024
1 parent 23eda1c commit 62123a4
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions examples/wikitext/README.md
@@ -25,7 +25,7 @@ This will fine-tune the model using the specified hyperparameters and save the f
 
 ## Computing Pairwise Influence Scores
 
-To compute pairwise influence scores using the `ekfac` factorization strategy, run the following command:
+To compute pairwise influence scores using the `ekfac` strategy, run the following command:
 
 ```bash
 python analyze.py --query_batch_size 32 \
@@ -91,7 +91,7 @@ This reduces computation time to about 20 minutes on an A100 (80GB) GPU:
 The `half_precision_analysis.py` script compares the correlations between `float32` and `bfloat16` scores.
 
 <p align="center">
-<a href="#"><img width="380" img src="figure/half_precision.png" alt="Query Batching"/></a>
+<a href="#"><img width="380" img src="figure/half_precision.png" alt="Half Precision"/></a>
 </p>
 
 The average correlation for 481 data points is `0.96`. Finally, we can try using `torch.compile`:
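Side note on the README hunk above: `half_precision_analysis.py` itself is not part of this commit, but the correlation check it refers to can be sketched with plain PyTorch and SciPy. The file paths, the `(num_queries, num_train)` score layout, and the per-query averaging below are assumptions for illustration, not the repository's actual code.

```python
# Hypothetical sketch of a float32 vs. bfloat16 score comparison;
# paths and the (num_queries, num_train) layout are assumptions.
import torch
from scipy.stats import spearmanr

# Pairwise influence scores saved by two separate analysis runs
# (illustrative file names, not the repository's real output paths).
scores_fp32 = torch.load("scores_float32.pt").float()   # (num_queries, num_train)
scores_bf16 = torch.load("scores_bfloat16.pt").float()  # (num_queries, num_train)

# Rank correlation per query, then averaged, matching the style of the
# "average correlation for 481 data points" summary in the README.
correlations = []
for q_fp32, q_bf16 in zip(scores_fp32, scores_bf16):
    rho, _ = spearmanr(q_fp32.numpy(), q_bf16.numpy())
    correlations.append(rho)

print(f"Average per-query Spearman correlation: {sum(correlations) / len(correlations):.2f}")
```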
15 changes: 7 additions & 8 deletions examples/wikitext/analyze.py
@@ -90,22 +90,21 @@ def compute_train_loss(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
         ).logits
-
-        shift_logits = logits[..., :-1, :].contiguous()
-        reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+        logits = logits[..., :-1, :].contiguous()
+        logits = logits.view(-1, logits.size(-1))
 
         if not sample:
             labels = batch["labels"]
-            shift_labels = labels[..., 1:].contiguous()
-            summed_loss = F.cross_entropy(reshaped_shift_logits, shift_labels.view(-1), reduction="sum")
+            labels = labels[..., 1:].contiguous()
+            summed_loss = F.cross_entropy(logits, labels.view(-1), reduction="sum")
         else:
             with torch.no_grad():
-                probs = torch.nn.functional.softmax(reshaped_shift_logits.detach(), dim=-1)
+                probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
                 sampled_labels = torch.multinomial(
                     probs,
                     num_samples=1,
                 ).flatten()
-            summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, reduction="sum")
+            summed_loss = F.cross_entropy(logits, sampled_labels, reduction="sum")
         return summed_loss
 
     def compute_measurement(
@@ -129,7 +128,7 @@ def get_influence_tracked_modules(self) -> List[str]:
 
         return total_modules
 
-    def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
+    def get_attention_mask(self, batch: BATCH_TYPE) -> torch.Tensor:
         return batch["attention_mask"]
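The first hunk above only renames the `shift_logits`/`reshaped_shift_logits` temporaries back to `logits`; the loss itself is unchanged. A minimal, self-contained sketch of that pattern (summed next-token cross-entropy over shifted logits, plus an optional branch that samples targets from the model's own distribution) is below. The standalone function, shapes, and random inputs are illustrative assumptions, not the repository's `Task` class.

```python
# Standalone illustration of the shift-and-flatten loss used in the hunk above.
# The free function and the random tensors are assumptions for the sketch.
import torch
import torch.nn.functional as F


def summed_next_token_loss(logits: torch.Tensor, labels: torch.Tensor, sample: bool = False) -> torch.Tensor:
    # Drop the last position: the logits at position t predict token t + 1.
    logits = logits[..., :-1, :].contiguous()
    logits = logits.view(-1, logits.size(-1))      # (batch * (seq_len - 1), vocab)

    if not sample:
        labels = labels[..., 1:].contiguous()      # align targets with the shifted logits
        return F.cross_entropy(logits, labels.view(-1), reduction="sum")

    # Sampled-label branch: draw targets from the model's own output distribution
    # (a common choice for sampled-Fisher-style estimates); no gradients flow
    # through the sampling itself.
    with torch.no_grad():
        probs = torch.softmax(logits.detach(), dim=-1)
        sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
    return F.cross_entropy(logits, sampled_labels, reduction="sum")


# Tiny usage example with random data.
batch, seq_len, vocab = 2, 8, 50257
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
print(summed_next_token_loss(logits, labels))
print(summed_next_token_loss(logits, labels, sample=True))
```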
