From 6213c91d49d41eec313692ce462770af2f850365 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Thu, 27 Jun 2024 17:41:36 -0400
Subject: [PATCH] Put the whole model into half

---
 examples/openwebtext/analyze.py  | 38 +++++++++++---------------------
 examples/openwebtext/pipeline.py |  9 ++------
 2 files changed, 15 insertions(+), 32 deletions(-)

diff --git a/examples/openwebtext/analyze.py b/examples/openwebtext/analyze.py
index 6cb65bc..25f5f12 100644
--- a/examples/openwebtext/analyze.py
+++ b/examples/openwebtext/analyze.py
@@ -1,27 +1,27 @@
 import argparse
 import logging
-import os
 from typing import Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from torch import nn
-from transformers import default_data_collator, AutoTokenizer
+from transformers import default_data_collator
+
 from examples.openwebtext.pipeline import (
     construct_llama3,
     get_custom_dataset,
-    get_openwebtext_dataset,
     MODEL_NAME,
+    get_openwebtext_dataset,
 )
 from kronfluence.analyzer import Analyzer, prepare_model
 from kronfluence.task import Task
-from kronfluence.utils.common.factor_arguments import reduce_memory_factor_arguments, \
-    extreme_reduce_memory_factor_arguments
+from kronfluence.utils.common.factor_arguments import (
+    extreme_reduce_memory_factor_arguments,
+)
 from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
 from kronfluence.utils.dataset import DataLoaderKwargs
 
 BATCH_TYPE = Dict[str, torch.Tensor]
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
 
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -33,7 +33,7 @@ def parse_args():
     parser.add_argument(
         "--factor_strategy",
         type=str,
-        default="identity",
+        default="ekfac",
         help="Strategy to compute influence factors.",
     )
     parser.add_argument(
@@ -51,13 +51,13 @@ def parse_args():
     parser.add_argument(
         "--factor_batch_size",
         type=int,
-        default=2,
+        default=4,
         help="Batch size for computing influence factors.",
     )
     parser.add_argument(
         "--train_batch_size",
         type=int,
-        default=8,
+        default=4,
         help="Batch size for computing query gradients.",
     )
     parser.add_argument(
@@ -81,7 +81,7 @@ def compute_train_loss(
         logits = model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-        ).logits
+        ).logits.float()
 
         shift_logits = logits[..., :-1, :].contiguous()
         reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
@@ -98,9 +98,7 @@ def compute_train_loss(
                     probs,
                     num_samples=1,
                 ).flatten()
-            summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels,
-                                          ignore_index=-100,
-                                          reduction="sum")
+            summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels, ignore_index=-100, reduction="sum")
         return summed_loss
 
     def compute_measurement(
@@ -111,7 +109,7 @@ def compute_measurement(
         logits = model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-        ).logits
+        ).logits.float()
         shift_labels = batch["labels"][..., 1:].contiguous().view(-1)
         shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
         return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
@@ -124,12 +122,6 @@ def tracked_modules(self) -> List[str]:
             total_modules.append(f"model.layers.{i}.mlp.up_proj")
             total_modules.append(f"model.layers.{i}.mlp.down_proj")
 
-        # for i in range(32):
-        #     total_modules.append(f"model.layers.{i}.self_attn.q_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.k_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.v_proj")
-        #     total_modules.append(f"model.layers.{i}.self_attn.o_proj")
-
         return total_modules
 
     def get_attention_mask(self, batch: BATCH_TYPE) -> Optional[torch.Tensor]:
@@ -145,7 +137,7 @@ def main():
     eval_dataset = get_custom_dataset()
 
     # Prepare the trained model.
-    model = construct_llama3()
+    model = construct_llama3().to(dtype=torch.bfloat16)
 
     # Define task and prepare model.
     task = LanguageModelingTask()
@@ -167,10 +159,6 @@ def main():
     # Compute influence factors.
     factors_name = args.factor_strategy
     factor_args = extreme_reduce_memory_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
-    # factor_args.covariance_module_partition_size = 2
-    # factor_args.lambda_module_partition_size = 2
-    # factor_args.covariance_max_examples = 100_000
-    # factor_args.lambda_max_examples = 100_000
     analyzer.fit_all_factors(
         factors_name=factors_name,
         dataset=train_dataset,
diff --git a/examples/openwebtext/pipeline.py b/examples/openwebtext/pipeline.py
index 819ab18..e015cc0 100644
--- a/examples/openwebtext/pipeline.py
+++ b/examples/openwebtext/pipeline.py
@@ -1,24 +1,19 @@
 import copy
 from typing import List
 
+import torch
 from datasets import load_dataset
 from torch import nn
 from torch.utils import data
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-import torch
-
 MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
-# MODEL_NAME = "EleutherAI/pythia-70m"
-
 MAX_LENGTH = 512
 
 
 def construct_llama3() -> nn.Module:
     config = AutoConfig.from_pretrained(
-        MODEL_NAME,
-        trust_remote_code=True,
-        model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
+        MODEL_NAME, trust_remote_code=True, device_map="auto"
     )
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,