Put the whole model into half
pomonam committed on Jun 27, 2024
Commit 6213c91 (1 parent: 798406e)
Showing 2 changed files with 15 additions and 32 deletions.
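For orientation, here is the pattern this commit applies, as a standalone sketch rather than the example's actual code: cast the whole model to bfloat16, then upcast the logits to float32 before the cross-entropy loss so the loss is accumulated in full precision. The tiny placeholder model below is an assumption for illustration only; the example itself loads meta-llama/Meta-Llama-3-8B through construct_llama3().

```python
# Minimal sketch of the half-precision pattern used in this commit:
# run the forward pass in bfloat16, but compute the loss in float32.
# "sshleifer/tiny-gpt2" is a placeholder stand-in, not the example's model.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(dtype=torch.bfloat16)

batch = tokenizer(["Influence analysis on OpenWebText."], return_tensors="pt")
labels = batch["input_ids"].clone()

logits = model(**batch).logits.float()  # bf16 forward pass, fp32 loss
shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
print(loss.item())
```

This mirrors the `model = construct_llama3().to(dtype=torch.bfloat16)` and `.logits.float()` changes in the diff below.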
examples/openwebtext/analyze.py (13 additions, 25 deletions)
@@ -1,27 +1,27 @@
 import argparse
 import logging
 import os
 from typing import Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from torch import nn
-from transformers import default_data_collator, AutoTokenizer
+from transformers import default_data_collator
 
 from examples.openwebtext.pipeline import (
     construct_llama3,
     get_custom_dataset,
-    get_openwebtext_dataset, MODEL_NAME,
+    get_openwebtext_dataset,
 )
 from kronfluence.analyzer import Analyzer, prepare_model
 from kronfluence.task import Task
-from kronfluence.utils.common.factor_arguments import reduce_memory_factor_arguments, \
-    extreme_reduce_memory_factor_arguments
+from kronfluence.utils.common.factor_arguments import (
+    extreme_reduce_memory_factor_arguments,
+)
 from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
 from kronfluence.utils.dataset import DataLoaderKwargs
 
 BATCH_TYPE = Dict[str, torch.Tensor]
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
 
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -33,7 +33,7 @@ def parse_args():
     parser.add_argument(
         "--factor_strategy",
         type=str,
-        default="identity",
+        default="ekfac",
         help="Strategy to compute influence factors.",
     )
     parser.add_argument(
@@ -51,13 +51,13 @@
     parser.add_argument(
         "--factor_batch_size",
         type=int,
-        default=2,
+        default=4,
         help="Batch size for computing influence factors.",
     )
     parser.add_argument(
         "--train_batch_size",
         type=int,
-        default=8,
+        default=4,
         help="Batch size for computing query gradients.",
     )
     parser.add_argument(
@@ -81,7 +81,7 @@ def compute_train_loss(
         logits = model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-        ).logits
+        ).logits.float()
         shift_logits = logits[..., :-1, :].contiguous()
         reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
 
@@ -98,9 +98,7 @@ def compute_train_loss(
                     probs,
                     num_samples=1,
                 ).flatten()
-            summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels,
-                                          ignore_index=-100,
-                                          reduction="sum")
+            summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, ignore_index=-100, reduction="sum")
         return summed_loss
 
     def compute_measurement(
@@ -111,7 +109,7 @@ def compute_measurement(
         logits = model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-        ).logits
+        ).logits.float()
         shift_labels = batch["labels"][..., 1:].contiguous().view(-1)
         shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
         return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
@@ -124,12 +122,6 @@ def tracked_modules(self) -> List[str]:
             total_modules.append(f"model.layers.{i}.mlp.up_proj")
             total_modules.append(f"model.layers.{i}.mlp.down_proj")
 
-        # for i in range(32):
-        #     total_modules.append(f"model.layers.{i}.self_attn.q_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.k_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.v_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.o_proj")
-
         return total_modules
 
     def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
@@ -145,7 +137,7 @@ def main():
     eval_dataset = get_custom_dataset()
 
     # Prepare the trained model.
-    model = construct_llama3()
+    model = construct_llama3().to(dtype=torch.bfloat16)
 
     # Define task and prepare model.
     task = LanguageModelingTask()
@@ -167,10 +159,6 @@
     # Compute influence factors.
     factors_name = args.factor_strategy
     factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
-    # factor_args.covariance_module_partition_size = 2
-    # factor_args.lambda_module_partition_size = 2
-    # factor_args.covariance_max_examples = 100_000
-    # factor_args.lambda_max_examples = 100_000
     analyzer.fit_all_factors(
         factors_name=factors_name,
         dataset=train_dataset,
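The hunks above show only fragments of compute_train_loss. For context, the sampled-label branch they touch follows roughly this shape (the random logits and shapes below are illustrative stand-ins, not the example's code): labels are drawn from the model's own softmax, and the cross-entropy over them is summed, which is why the single-line F.cross_entropy call above uses reduction="sum".

```python
# Standalone sketch of the sampled-label loss seen in compute_train_loss above.
# reshaped_shift_logits would be the model's fp32 shift logits, [num_tokens, vocab_size];
# here a random tensor stands in so the snippet runs on its own.
import torch
import torch.nn.functional as F

reshaped_shift_logits = torch.randn(6, 50_000)

with torch.no_grad():
    probs = F.softmax(reshaped_shift_logits.detach(), dim=-1)
    sampled_labels = torch.multinomial(probs, num_samples=1).flatten()

summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, ignore_index=-100, reduction="sum")
print(summed_loss.item())
```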
examples/openwebtext/pipeline.py (2 additions, 7 deletions)
@@ -1,24 +1,19 @@
 import copy
 from typing import List
 
+import torch
 from datasets import load_dataset
 from torch import nn
 from torch.utils import data
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-import torch
 
 
 MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
-# MODEL_NAME = "EleutherAI/pythia-70m"
 
 MAX_LENGTH = 512
 
 
 def construct_llama3() -> nn.Module:
     config = AutoConfig.from_pretrained(
-        MODEL_NAME,
-        trust_remote_code=True,
-        model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
+        MODEL_NAME, trust_remote_code=True, device_map="auto"
     )
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
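A short design note and usage sketch: the bfloat16 cast moves out of construct_llama3, where it was passed as model_kwargs to AutoConfig.from_pretrained (and so presumably never affected the dtype of the loaded weights), and into an explicit .to(dtype=torch.bfloat16) at the call site in analyze.py. The snippet below only restates how the changed calls are meant to be combined; the surrounding Analyzer setup is elided here, as it is in the diff.

```python
# How the changed pieces fit together (only calls that appear in this diff are used).
import torch

from examples.openwebtext.pipeline import construct_llama3
from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments

# Build Llama 3 8B (device_map="auto") and cast every weight to bfloat16,
# as analyze.py now does.
model = construct_llama3().to(dtype=torch.bfloat16)

# Memory-lean factor computation, also carried out in bfloat16
# (matches the new "ekfac" default for --factor_strategy).
factor_args = extreme_reduce_memory_factor_arguments(strategy="ekfac", dtype=torch.bfloat16)
```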
