diff --git a/open_lm/attention.py b/open_lm/attention.py
index e0e8aba..7f2e2f4 100644
--- a/open_lm/attention.py
+++ b/open_lm/attention.py
@@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
     if attention_mask is None:
         bias = None
         # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
-        if queries.shape == 1:
+        if queries.shape[1] == 1:
             is_causal = False
     else:
         if not is_causal:
diff --git a/open_lm/data.py b/open_lm/data.py
index 309f43e..05844b9 100644
--- a/open_lm/data.py
+++ b/open_lm/data.py
@@ -186,7 +186,7 @@ def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, h
 def tarfile_to_samples_nothrow(src, handler=log_and_continue):
     # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
     streams = url_opener(src, handler=handler)
-    files = tar_file_expander(streams, handler=handler)
+    files = tar_file_expander(streams, handler=handler, eof_value=None)
     samples = group_by_keys_nothrow(files, handler=handler)
     return samples
 
@@ -205,10 +205,12 @@ def pytorch_worker_seed(increment=0):
         return wds.utils.pytorch_worker_seed()
 
 
-_SHARD_SHUFFLE_SIZE = 2000
-_SHARD_SHUFFLE_INITIAL = 2000 #500
-_SAMPLE_SHUFFLE_SIZE = 20000
-_SAMPLE_SHUFFLE_INITIAL = 20000 #4000
+
+
+_SHARD_SHUFFLE_SIZE = 100000 #10000
+_SHARD_SHUFFLE_INITIAL = 100000 #10000
+_SAMPLE_SHUFFLE_SIZE = 200000 #50000
+_SAMPLE_SHUFFLE_INITIAL = 200000 #50000
 
 
 class detshuffle2(wds.PipelineStage):
diff --git a/open_lm/datapreprocess/.ipynb_checkpoints/make_2048-checkpoint.py b/open_lm/datapreprocess/.ipynb_checkpoints/make_2048-checkpoint.py
new file mode 100644
index 0000000..d50763b
--- /dev/null
+++ b/open_lm/datapreprocess/.ipynb_checkpoints/make_2048-checkpoint.py
@@ -0,0 +1,255 @@
+import jsonlines
+import glob
+import tiktoken
+import os
+import threading
+from webdataset import ShardWriter
+import random
+import time
+import boto3
+import io
+import zstandard as zstd
+from contextlib import contextmanager
+import argparse
+from pathlib import Path
+from transformers import GPTNeoXTokenizerFast
+
+
+# ========================================
+# =           Global variables           =
+# ========================================
+
+QUEUE_MAX = 10_000
+BUFFER_MIN = 100_000
+BUFFER_MAX = 200_000
+CHUNK_SIZE = 2048 + 1
+SHARD_SIZE = 7813 #8192
+SLEEP_TIME = 1
+
+S3_BASE = os.environ.get("S3_BASE")
+
+EOT_TOKEN = "<|endoftext|>"
+
+
+# ================================================
+# =              Utility functions               =
+# ================================================
+
+
+def write_to_shard(chunks, shard_writer):
+    for idx, chunk in enumerate(chunks):
+        shard_writer.write({"__key__": f"{idx:012d}", "txt": str(chunk)})
+
+
+def upload_to_s3_and_remove(fname):
+    """Uploads file to s3 and removes it from local file system"""
+    fname_split = fname.split("/")
+    s3_path = S3_BASE + fname_split[-2] + "/" + fname_split[-1]
+    cmd = f"aws s3 cp {fname} {s3_path} && rm {fname}"
+    print("COMMAND:", cmd)
+    os.system(cmd)
+
+
+@contextmanager
+def get_item_reader(file_name):
+    """Creates iterator for reading .jsonl files or Zstd compressed .jsonl files"""
+    if file_name.endswith(".jsonl"):
+        with jsonlines.open(file_name) as reader:
+            yield reader
+    else:
+        dctx = zstd.ZstdDecompressor()
+        with open(file_name, "rb") as compressed_file:
+            with dctx.stream_reader(compressed_file) as reader:
+                with io.TextIOWrapper(reader, encoding="utf-8") as text_reader:
+                    with jsonlines.Reader(text_reader) as jsonl_reader:
+                        yield jsonl_reader
+
+
+def pop_random(els):
+    """O(1) way to pop an element randomly from a list
+    NOT THREAD SAFE!!!
(so make sure we have a lock enabled) + (also mutates the order of the list, but that's okay) + """ + random_idx = random.randint(0, len(els) - 1) + els[-1], els[random_idx] = els[random_idx], els[-1] + return els.pop() + + +# ====================================================== +# = Processor/Consumer Subprocess = +# ====================================================== +# These get called in a threaded way + + +def process_files(file_list, buffer, enc, buffer_lock): + remaining_tokens = [] + queue = [] + + def dump_queue_to_buffer(): + with buffer_lock: + while queue: + buffer.append(queue.pop(0)) + + for file_name in file_list: + print("Processing", file_name) + + with get_item_reader(file_name) as item_reader: + for item in item_reader: + string = item["text"] + try: + tokens = remaining_tokens + enc(string) + [EOT_TOKEN] + remaining_tokens = [] + except: + print("Failed to encode string.") + continue + + for i in range(0, len(tokens), CHUNK_SIZE): + chunk = tokens[i : i + CHUNK_SIZE] + if len(chunk) < CHUNK_SIZE: + remaining_tokens = chunk + else: + if len(buffer) > BUFFER_MAX: + time.sleep(1) + continue + + if buffer_lock.locked(): + if len(queue) < QUEUE_MAX: + queue.append(chunk) + else: + time.sleep(1) + else: + if queue: + dump_queue_to_buffer() + with buffer_lock: + buffer.append(chunk) + + +def consumer(my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False): + output_directory = f"{output_dir}/{CHUNK_SIZE - 1}-v1/{my_id}" + os.makedirs(output_directory, exist_ok=True) + shard_writer = ShardWriter(os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE) + + chunks = [] + + start_time = time.time() + + while any(t.is_alive() for t in threads): + time.sleep(SLEEP_TIME) + with buffer_lock: + lenb = len(buffer) + print("Length of buffer", lenb) + if lenb >= BUFFER_MIN: + while buffer and len(chunks) < SHARD_SIZE: + chunks.append(pop_random(buffer)) + + if len(chunks) == SHARD_SIZE: + print(f"I am {my_id} and I am writing a shard.", len(buffer)) + write_to_shard(chunks, shard_writer) + if upload_to_s3: + upload_to_s3_and_remove(shard_writer.fname) + # print("FNAME", shard_writer.fname) + chunks = [] + time_for_shard = time.time() - start_time + print("shards / s", num_consumers / time_for_shard) + print("tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) + print( + "hours req for 1.2T tokens", + 1_200_000_000_000 / (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) / 3600, + ) + + start_time = time.time() + + # Process the remaining items in the buffer after all threads have completed + while buffer: + with buffer_lock: + while buffer and len(chunks) < SHARD_SIZE: + chunks.append(pop_random(buffer)) + + write_to_shard(chunks, shard_writer) + if upload_to_s3: + upload_to_s3_and_remove(shard_writer.fname) + chunks = [] + + +def tokenize_eleutherai(tokenizer, string): + return tokenizer(string).input_ids + + +# ========================================================= +# = Main function + Argument parsing = +# ========================================================= + + +def main( + input_files, + output_dir, + tokenizer="EleutherAI/gpt-neox-20b", + num_workers=32, + num_consumers=8, + upload_to_s3=False, +): + os.makedirs(f"{output_dir}/tars-{CHUNK_SIZE - 1}-v1", exist_ok=True) + + input_files = [glob.glob(input_file) for input_file in input_files] + input_files = [x for y in input_files for x in y] + + # Shuffle the input files + random.shuffle(input_files) + + print("Input files", input_files) + + enc = 
GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") + + tokenize = lambda x: tokenize_eleutherai(enc, x) + buffer = [] # Use list instead of queue.Queue + buffer_lock = threading.Lock() + + files_per_worker = len(input_files) // num_workers + threads = [] + for i in range(num_workers): + start = i * files_per_worker + end = (i + 1) * files_per_worker if i < num_workers - 1 else len(input_files) + t = threading.Thread( + target=process_files, + args=(input_files[start:end], buffer, tokenize, buffer_lock), + ) + t.start() + threads.append(t) + + consumer_threads = [] + for i in range(num_consumers): + t = threading.Thread( + target=consumer, + args=( + i, + output_dir, + threads, + buffer, + buffer_lock, + num_consumers, + upload_to_s3, + ), + ) + t.start() + consumer_threads.append(t) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-files", type=str, nargs="+") + parser.add_argument("--output-dir", type=Path) + parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") + parser.add_argument("--num-workers", type=int, default=32) + parser.add_argument("--num-consumers", type=int, default=8) + parser.add_argument("--upload-to-s3", action="store_true") + + args = parser.parse_args() + + main( + args.input_files, + args.output_dir, + args.tokenizer, + args.num_workers, + args.num_consumers, + args.upload_to_s3, + ) \ No newline at end of file diff --git a/open_lm/datapreprocess/make_2048.py b/open_lm/datapreprocess/make_2048.py index e0da8bb..69e7429 100644 --- a/open_lm/datapreprocess/make_2048.py +++ b/open_lm/datapreprocess/make_2048.py @@ -20,7 +20,7 @@ # ======================================== QUEUE_MAX = 10_000 -BUFFER_MIN = 10_000 +BUFFER_MIN = 100_000 BUFFER_MAX = 200_000 CHUNK_SIZE = 2048 + 1 SHARD_SIZE = 8192 @@ -252,4 +252,4 @@ def main( args.num_workers, args.num_consumers, args.upload_to_s3, - ) + ) \ No newline at end of file diff --git a/open_lm/datapreprocess/wiki_download.py b/open_lm/datapreprocess/wiki_download.py index f5f1a05..a4e10da 100644 --- a/open_lm/datapreprocess/wiki_download.py +++ b/open_lm/datapreprocess/wiki_download.py @@ -30,4 +30,4 @@ def main(output_dir): ) args = parser.parse_args() - main(args.output_dir) \ No newline at end of file + main(args.output_dir) diff --git a/open_lm/eval.py b/open_lm/eval.py deleted file mode 100644 index 196f006..0000000 --- a/open_lm/eval.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from open_lm.params import parse_args -from open_lm.model import test_classif_model -import webdataset as wds -from open_lm.data import get_wds_dataset -from open_lm.data import sample_chunk - -args = parse_args([]) -args.per_gpu_val_batch_size = 8 -args.vocab_size = 50432 -args.seq_len = 2048 -args.world_size = 1 -args.rank = 0 - -args.model = "open_lm_160m" -model_path = "/media/logs/classif_C4160m3.2B_C4DCLM_320M/checkpoints/epoch_1.pt" - -args.val_data = ['/media/datasets/C4/C4-shard-0000219.tar'] - -model = test_classif_model(args, model_path) -model = model.to('cuda') - -dataset = get_wds_dataset(args, is_train=False, epoch=0, floor=True, tokenizer=None, data_key="txt", force_num_samples=None) - -dataloader = dataset.dataloader - -sum = 0 -for sample in dataloader: - (texts,) = sample - texts = torch.LongTensor(texts).to('cuda') - inputs, targets = sample_chunk(texts, args) - - with torch.no_grad(): - out, _, _ = model(inputs) - - pred = torch.argmax(out,2)[:,-1].sum() - - sum = sum + pred.item() - -print(sum) - - - - diff --git a/open_lm/eval2.py 
b/open_lm/eval2.py new file mode 100644 index 0000000..2d5121d --- /dev/null +++ b/open_lm/eval2.py @@ -0,0 +1,96 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### + +args.num_classes = 2 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 + + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' + + + +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + + +dataset = torch.load(data_path1) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 0).item() + + sum = sum + n_correct + +sum1 = sum +len1 = len(dataset) +print(str1, sum1, "/" , len1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 1).item() + + sum = sum + n_correct + +sum2 = sum +len2 = len(dataset) +print(str2, sum2, "/" , len2) + +########################################################################################################################################################################################## + + +total_sum = sum1+sum2 +total_length = len1+len2 + +print("Total= ", total_sum, "/" , total_length ) +print("Accuracy= ", total_sum/total_length * 100, "%") + + diff --git a/open_lm/eval3.py b/open_lm/eval3.py new file mode 100644 index 0000000..9c1e761 --- /dev/null +++ b/open_lm/eval3.py @@ -0,0 +1,116 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + 
+########################################################################################################### + +args.num_classes = 3 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' +data_path3 = base_path + str3 + '.pt' + + +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + + +dataset = torch.load(data_path1) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 0).item() + + sum = sum + n_correct + +sum1 = sum +len1 = len(dataset) +print(str1, sum1, "/" , len1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 1).item() + + sum = sum + n_correct + +sum2 = sum +len2 = len(dataset) +print(str2, sum2, "/" , len2) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path3) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 2).item() + + sum = sum + n_correct + +sum3 = sum +len3 = len(dataset) +print(str3, sum3, "/" , len3) + +########################################################################################################################################################################################## + +total_sum = sum1+sum2+sum3 +total_length = len1+len2+len3 + +print("Total= ", total_sum, "/" , total_length ) +print("Accuracy= ", total_sum/total_length * 100, "%") + + diff --git a/open_lm/eval3_prop.py b/open_lm/eval3_prop.py new file mode 100644 index 0000000..8338786 --- /dev/null +++ b/open_lm/eval3_prop.py @@ -0,0 +1,213 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +#parser.add_argument('--str3', type=str, help='test set 3') +#parser.add_argument('--str4', type=str, help='test set 4') +#parser.add_argument('--str5', type=str, help='test set 5') +#parser.add_argument('--str6', type=str, help='test set 6') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + 
+########################################################################################################### + +args.num_classes = 2 + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +#str3 = cmd_args.str3 +#str4 = cmd_args.str4 +#str5 = cmd_args.str5 +#str6 = cmd_args.str6 + + +data1= "Llama1_gen" #"DCLM_gen" +data2= "Dolma_gen" +data3= "FWEdu_gen" + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + data1 + '.pt' +data_path2 = base_path + data2 + '.pt' +data_path3 = base_path + data3 + '.pt' + +model = test_classif_model(args) +model = model.to('cuda') + + +soft_max = torch.nn.Softmax(dim=2) +########################################################################################################### + +pred = [] +conf=[] + +dataset = torch.load(data_path1) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + out = soft_max(out) + pred.append( torch.argmax(out,2)[:,-1].item() ) + conf.append( torch.max(out,2)[0][:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +#c3 = pred.count(2) +#c4 = pred.count(3) +#c5 = pred.count(4) +#c6 = pred.count(5) + + +sum_conf1 = sum(c for p, c in zip(pred, conf) if p == 0) +sum_conf2 = sum(c for p, c in zip(pred, conf) if p == 1) +#sum_conf3 = sum(c for p, c in zip(pred, conf) if p == 2) +#sum_conf4 = sum(c for p, c in zip(pred, conf) if p == 3) +#sum_conf5 = sum(c for p, c in zip(pred, conf) if p == 4) +#sum_conf6 = sum(c for p, c in zip(pred, conf) if p == 5) + + +av1 = sum_conf1/c1 if c1>0 else 0 +av2 = sum_conf2/c2 if c2>0 else 0 +#av3 = sum_conf3/c3 if c3>0 else 0 +#av4 = sum_conf4/c4 if c4>0 else 0 +#av5 = sum_conf5/c5 if c5>0 else 0 +#av6 = sum_conf6/c6 if c6>0 else 0 + + + +length = len(dataset) + +print(data1, ':') +print(str1, c1, "/", length, '=', c1/length, "with confidence ", av1) +print(str2, c2, "/", length, '=', c2/length, "with confidence ", av2) +#print(str3, c3, "/", length, '=', c3/length, "with confidence ", av3) +#print(str4, c4, "/", length, '=', c4/length, "with confidence ", av4) +#print(str5, c5, "/", length, '=', c5/length, "with confidence ", av5) +#print(str6, c6, "/", length, '=', c6/length, "with confidence ", av6) +print("\n") + +exit() +########################################################################################################################################################################################## + +pred = [] +conf=[] + +dataset = torch.load(data_path2) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + out = soft_max(out) + pred.append( torch.argmax(out,2)[:,-1].item() ) + conf.append( torch.max(out,2)[0][:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +c3 = pred.count(2) +c4 = pred.count(3) +c5 = pred.count(4) +c6 = pred.count(5) +c7 = pred.count(6) + +sum_conf1 = sum(c for p, c in zip(pred, conf) if p == 0) +sum_conf2 = sum(c for p, c in zip(pred, conf) if p == 1) +sum_conf3 = sum(c for p, c in zip(pred, conf) if p == 2) +sum_conf4 = sum(c for p, c in zip(pred, conf) if p == 3) +sum_conf5 = sum(c for p, c in zip(pred, conf) if p == 4) +sum_conf6 = sum(c for p, c in zip(pred, conf) if p == 5) +sum_conf7 = sum(c for p, c in zip(pred, conf) if p == 6) + +av1 = sum_conf1/c1 if c1>0 else 0 +av2 = sum_conf2/c2 if c2>0 else 0 +av3 = sum_conf3/c3 if c3>0 else 0 +av4 = sum_conf4/c4 if c4>0 else 0 +av5 = sum_conf5/c5 if c5>0 else 0 +av6 = sum_conf6/c6 if c6>0 else 0 +av7 = sum_conf7/c7 if c7>0 else 0 + + +length = 
len(dataset) + +print(data2, ':') +print(str1, c1, "/", length, '=', c1/length, "with confidence ", av1) +print(str2, c2, "/", length, '=', c2/length, "with confidence ", av2) +print(str3, c3, "/", length, '=', c3/length, "with confidence ", av3) +print(str4, c4, "/", length, '=', c4/length, "with confidence ", av4) +print(str5, c5, "/", length, '=', c5/length, "with confidence ", av5) +print(str6, c6, "/", length, '=', c6/length, "with confidence ", av6) +print(str7, c7, "/", length, '=', c7/length, "with confidence ", av7) +print("\n") + +########################################################################################################################################################################################## + +pred = [] +conf=[] + +dataset = torch.load(data_path3) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + out = soft_max(out) + pred.append( torch.argmax(out,2)[:,-1].item() ) + conf.append( torch.max(out,2)[0][:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +c3 = pred.count(2) +c4 = pred.count(3) +c5 = pred.count(4) +c6 = pred.count(5) +c7 = pred.count(6) + +sum_conf1 = sum(c for p, c in zip(pred, conf) if p == 0) +sum_conf2 = sum(c for p, c in zip(pred, conf) if p == 1) +sum_conf3 = sum(c for p, c in zip(pred, conf) if p == 2) +sum_conf4 = sum(c for p, c in zip(pred, conf) if p == 3) +sum_conf5 = sum(c for p, c in zip(pred, conf) if p == 4) +sum_conf6 = sum(c for p, c in zip(pred, conf) if p == 5) +sum_conf7 = sum(c for p, c in zip(pred, conf) if p == 6) + +av1 = sum_conf1/c1 if c1>0 else 0 +av2 = sum_conf2/c2 if c2>0 else 0 +av3 = sum_conf3/c3 if c3>0 else 0 +av4 = sum_conf4/c4 if c4>0 else 0 +av5 = sum_conf5/c5 if c5>0 else 0 +av6 = sum_conf6/c6 if c6>0 else 0 +av7 = sum_conf7/c7 if c7>0 else 0 + + +length = len(dataset) + +print(data3, ':') +print(str1, c1, "/", length, '=', c1/length, "with confidence ", av1) +print(str2, c2, "/", length, '=', c2/length, "with confidence ", av2) +print(str3, c3, "/", length, '=', c3/length, "with confidence ", av3) +print(str4, c4, "/", length, '=', c4/length, "with confidence ", av4) +print(str5, c5, "/", length, '=', c5/length, "with confidence ", av5) +print(str6, c6, "/", length, '=', c6/length, "with confidence ", av6) +print(str7, c7, "/", length, '=', c7/length, "with confidence ", av7) +print("\n") +########################################################################################################################################################################################## + + diff --git a/open_lm/eval3_prop_2048.py b/open_lm/eval3_prop_2048.py new file mode 100644 index 0000000..d269b36 --- /dev/null +++ b/open_lm/eval3_prop_2048.py @@ -0,0 +1,139 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model +import webdataset as wds +from open_lm.data import get_wds_dataset +from open_lm.data import sample_chunk + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +parser.add_argument('--str4', type=str, help='test set 4') 
+cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + +args.per_gpu_val_batch_size = 1 +args.vocab_size = 50432 +args.seq_len = 2047 +args.world_size = 1 +args.rank = 0 + +########################################################################################################### + +args.num_classes = 4 + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 +str4 = cmd_args.str4 + +data1= "DCLM" +data2= "Dolma" +data3= "FWEdu" + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + data1 + '.tar' +data_path2 = base_path + data2 + '.tar' +data_path3 = base_path + data3 + '.tar' + +model = test_classif_model(args) +model = model.to('cuda') + +########################################################################################################### + +args.val_data = [data_path1] +dataset = get_wds_dataset(args, is_train=False, epoch=0, floor=True, tokenizer=None, data_key="txt", force_num_samples=None) +dataloader = dataset.dataloader + + +pred = [] +for sample in dataloader: + (texts,) = sample + inputs = torch.LongTensor(texts).to('cuda') + + with torch.no_grad(): + out, _, _ = model(inputs) + + pred.append( torch.argmax(out,2)[:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +c3 = pred.count(2) +c4 = pred.count(3) + +length = 4096 + +print(data1, ':') +print(str1, c1, "/", length, '=', c1/length) +print(str2, c2, "/", length, '=', c2/length) +print(str3, c3, "/", length, '=', c3/length) +print(str4, c4, "/", length, '=', c4/length) + +########################################################################################################################################################################################## + +args.val_data = [data_path2] +dataset = get_wds_dataset(args, is_train=False, epoch=0, floor=True, tokenizer=None, data_key="txt", force_num_samples=None) +dataloader = dataset.dataloader + + +pred = [] +for sample in dataloader: + (texts,) = sample + inputs = torch.LongTensor(texts).to('cuda') + + with torch.no_grad(): + out, _, _ = model(inputs) + + pred.append( torch.argmax(out,2)[:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +c3 = pred.count(2) + +length = 4096 + +print(data2, ':') +print(str1, c1, "/", length, '=', c1/length) +print(str2, c2, "/", length, '=', c2/length) +print(str3, c3, "/", length, '=', c3/length) + +########################################################################################################################################################################################## + +args.val_data = [data_path3] +dataset = get_wds_dataset(args, is_train=False, epoch=0, floor=True, tokenizer=None, data_key="txt", force_num_samples=None) +dataloader = dataset.dataloader + + +pred = [] +for sample in dataloader: + (texts,) = sample + inputs = torch.LongTensor(texts).to('cuda') + + with torch.no_grad(): + out, _, _ = model(inputs) + + pred.append( torch.argmax(out,2)[:,-1].item() ) + +c1 = pred.count(0) +c2 = pred.count(1) +c3 = pred.count(2) + +length = 4096 + +print(data3, ':') +print(str1, c1, "/", length, '=', c1/length) +print(str2, c2, "/", length, '=', c2/length) +print(str3, c3, "/", length, '=', c3/length) + +########################################################################################################################################################################################## + + diff --git a/open_lm/eval3_varylength.py b/open_lm/eval3_varylength.py new file mode 100644 index 0000000..e6f1c5c --- /dev/null +++ 
b/open_lm/eval3_varylength.py @@ -0,0 +1,118 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### + +args.num_classes = 3 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' +data_path3 = base_path + str3 + '.pt' + +model = test_classif_model(args) +model = model.to('cuda') + +########################################################################################################### + +dataset = torch.load(data_path1) +n_bins = len(dataset) +sum = torch.zeros(n_bins, dtype=torch.int) + +for i in range(n_bins): + n_samples = len(dataset[i]) + for j in range(n_samples): + sample = torch.LongTensor(dataset[i][j]).to('cuda') + with torch.no_grad(): + out, _, _ = model(sample) + pred = torch.argmax(out,2)[:,-1] + + if pred == 0: + sum[i] +=1 + + +sum1 = sum +len1 = n_samples +print(str1, sum1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) +n_bins = len(dataset) +sum = torch.zeros(n_bins, dtype=torch.int) + +for i in range(n_bins): + n_samples = len(dataset[i]) + for j in range(n_samples): + sample = torch.LongTensor(dataset[i][j]).to('cuda') + with torch.no_grad(): + out, _, _ = model(sample) + pred = torch.argmax(out,2)[:,-1] + + if pred == 1: + sum[i] +=1 + +sum2 = sum +len2 = n_samples +print(str2, sum2) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path3) +n_bins = len(dataset) +sum = torch.zeros(n_bins, dtype=torch.int) + +for i in range(n_bins): + n_samples = len(dataset[i]) + for j in range(n_samples): + sample = torch.LongTensor(dataset[i][j]).to('cuda') + with torch.no_grad(): + out, _, _ = model(sample) + pred = torch.argmax(out,2)[:,-1] + + if pred == 2: + sum[i] +=1 + +sum3 = sum +len3 = n_samples +print(str3, sum3) + +########################################################################################################################################################################################## + +total_sum = sum1+sum2+sum3 +total_len = len1+len2+len3 + +print(len1,len2,len3,"\n") + +for i in range(n_bins): + print("Accuracy at bin ", i, " Seq. 
lengths range from ", i*200, " to ", i*200+200, " is: ", total_sum[i].item()/total_len * 100, "%") + diff --git a/open_lm/eval3_varylength_2000.py b/open_lm/eval3_varylength_2000.py new file mode 100644 index 0000000..06ada50 --- /dev/null +++ b/open_lm/eval3_varylength_2000.py @@ -0,0 +1,124 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### + +args.num_classes = 3 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' +data_path3 = base_path + str3 + '.pt' + + +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + +indices = torch.arange(0, 2048, 200) + +sum = torch.zeros(len(indices)) + +dataset = torch.load(data_path1) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,indices] + + n_correct = torch.sum(pred == 0, dim=0) + + sum = sum + n_correct.cpu() + +sum1 = sum +len1 = len(dataset) +print(str1, sum1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) + +sum = torch.zeros(len(indices)) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,indices] + + n_correct = torch.sum(pred == 1, dim=0) + + sum = sum + n_correct.cpu() + +sum2 = sum +len2 = len(dataset) +print(str2, sum2) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path3) + +sum = torch.zeros(len(indices)) + +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,indices] + + n_correct = torch.sum(pred == 2, dim=0) + + sum = sum + n_correct.cpu() + +sum3 = sum +len3 = len(dataset) +print(str3, sum3) + +########################################################################################################################################################################################## + +total = sum1+sum2+sum3 + +print(len1,len2,len3,"\n") + +for i in range(len(indices)): + print("Accuracy at token", indices[i], "=", total[i].item()/(len1+len2+len3)) 
+ + diff --git a/open_lm/eval4.py b/open_lm/eval4.py new file mode 100644 index 0000000..a732bbf --- /dev/null +++ b/open_lm/eval4.py @@ -0,0 +1,139 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +parser.add_argument('--str4', type=str, help='test set 4') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### + +args.num_classes = 4 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 +str4 = cmd_args.str4 + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' +data_path3 = base_path + str3 + '.pt' +data_path4 = base_path + str4 + '.pt' + + +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + + +dataset = torch.load(data_path1) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 0).item() + + sum = sum + n_correct + +sum1 = sum +len1 = len(dataset) +print(str1, sum1, "/" , len1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 1).item() + + sum = sum + n_correct + +sum2 = sum +len2 = len(dataset) +print(str2, sum2, "/" , len2) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path3) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 2).item() + + sum = sum + n_correct + +sum3 = sum +len3 = len(dataset) +print(str3, sum3, "/" , len3) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path4) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 3).item() + + sum = sum + n_correct + +sum4 = sum +len4 = len(dataset) 
+print(str4, sum4, "/" , len4) + +########################################################################################################################################################################################## + +total_sum = sum1+sum2+sum3+sum4 +total_length = len1+len2+len3+len4 + +print("Total= ", total_sum, "/" , total_length ) +print("Accuracy= ", total_sum/total_length * 100, "%") + + diff --git a/open_lm/eval5.py b/open_lm/eval5.py new file mode 100644 index 0000000..681223c --- /dev/null +++ b/open_lm/eval5.py @@ -0,0 +1,162 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +parser.add_argument('--str1', type=str, help='test set 1') +parser.add_argument('--str2', type=str, help='test set 2') +parser.add_argument('--str3', type=str, help='test set 3') +parser.add_argument('--str4', type=str, help='test set 4') +parser.add_argument('--str5', type=str, help='test set 5') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### + +args.num_classes = 5 + +#Dolma_gen.pt +#DCLM_gen.pt +#FWEdu_gen.pt + +#'C4.pt' +#'FineWeb.pt' +#'RefinedWeb.pt' + + +str1 = cmd_args.str1 +str2 = cmd_args.str2 +str3 = cmd_args.str3 +str4 = cmd_args.str4 +str5 = cmd_args.str5 + +base_path = '/media/datasets/test_set/' + +data_path1 = base_path + str1 + '.pt' +data_path2 = base_path + str2 + '.pt' +data_path3 = base_path + str3 + '.pt' +data_path4 = base_path + str4 + '.pt' +data_path5 = base_path + str5 + '.pt' + + +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + + +dataset = torch.load(data_path1) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 0).item() + + sum = sum + n_correct + +sum1 = sum +len1 = len(dataset) +print(str1, sum1, "/" , len1) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path2) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 1).item() + + sum = sum + n_correct + +sum2 = sum +len2 = len(dataset) +print(str2, sum2, "/" , len2) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path3) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 2).item() + + sum = sum + n_correct + +sum3 = sum +len3 = len(dataset) 
+print(str3, sum3, "/" , len3) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path4) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 3).item() + + sum = sum + n_correct + +sum4 = sum +len4 = len(dataset) +print(str4, sum4, "/" , len4) + +########################################################################################################################################################################################## + +dataset = torch.load(data_path5) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 4).item() + + sum = sum + n_correct + +sum5 = sum +len5 = len(dataset) +print(str5, sum5, "/" , len5) + +########################################################################################################################################################################################## + +total_sum = sum1+sum2+sum3+sum4+sum5 +total_length = len1+len2+len3+len4+len5 + +print("Total= ", total_sum, "/" , total_length ) +print("Accuracy= ", total_sum/total_length * 100, "%") + + diff --git a/open_lm/eval_redpajama_seq.py b/open_lm/eval_redpajama_seq.py new file mode 100644 index 0000000..85fe453 --- /dev/null +++ b/open_lm/eval_redpajama_seq.py @@ -0,0 +1,167 @@ +import torch +from open_lm.params import parse_args +import argparse +from open_lm.model import test_classif_model + +args = parse_args([]) +parser = argparse.ArgumentParser(description="Override params arguments with command-line arguments") +parser.add_argument('--model', type=str, help='Model name to use for evaluation') +parser.add_argument('--classif-model-path', type=str, help='Path to the classification model checkpoint') +cmd_args = parser.parse_args() +args.model = cmd_args.model +args.classif_model_path = cmd_args.classif_model_path + + + +########################################################################################################### +args.num_classes = 6 + +path1 = "/media/datasets/RedPajama/val_seq/arxiv-shard-0000019.pt" +path2 = "/media/datasets/RedPajama/val_seq/c4-shard-0000019.pt" +path3 = "/media/datasets/RedPajama/val_seq/cc-shard-0000019.pt" +path4 = "/media/datasets/RedPajama/val_seq/gh-shard-0000019.pt" +path5 = "/media/datasets/RedPajama/val_seq/se-shard-0000019.pt" +path6 = "/media/datasets/RedPajama/val_seq/wiki-shard-0000009.pt" + +str1 = "Arxiv" +str2 = "C4" +str3 = "CC" +str4 = "Github" +str5 = "StackExchange" +str6 = "Wikipedia" +########################################################################################################### + +model = test_classif_model(args) +model = model.to('cuda') + + + +dataset = torch.load(path1) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 0).item() + + sum = sum + n_correct + +sum1 = sum +len1 = len(dataset) +print(str1, sum1, "/" , len1) + 
+########################################################################################################################################################################################## + +dataset = torch.load(path2) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 1).item() + + sum = sum + n_correct + +sum2 = sum +len2 = len(dataset) +print(str2, sum2, "/" , len2) + +########################################################################################################################################################################################## + +dataset = torch.load(path3) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 2).item() + + sum = sum + n_correct + +sum3 = sum +len3 = len(dataset) +print(str3, sum3, "/" , len3) + +########################################################################################################################################################################################## + +dataset = torch.load(path4) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 3).item() + + sum = sum + n_correct + +sum4 = sum +len4 = len(dataset) +print(str4, sum4, "/" , len4) + +########################################################################################################################################################################################## + +dataset = torch.load(path5) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 4).item() + + sum = sum + n_correct + +sum5 = sum +len5 = len(dataset) +print(str5, sum5, "/" , len5) + +########################################################################################################################################################################################## + +dataset = torch.load(path6) +sum = 0 +for sample in dataset: + sample = torch.LongTensor(sample).to('cuda') + + with torch.no_grad(): + out, _, _ = model(sample) + + pred = torch.argmax(out,2)[:,-1] + + n_correct = torch.sum(pred == 5).item() + + sum = sum + n_correct + +sum6 = sum +len6 = len(dataset) +print(str6, sum6, "/" , len6) + +########################################################################################################################################################################################## + + + +total_sum = sum1+sum2+sum3+sum4+sum5+sum6 +total_length = len1+len2+len3+len4+len5+len6 + +print("Total= ", total_sum, "/" , total_length ) +print("Accuracy= ", total_sum/total_length * 100, "%") + + diff --git a/open_lm/extra_funcs.py b/open_lm/extra_funcs.py deleted file mode 100644 index b746d5e..0000000 --- a/open_lm/extra_funcs.py +++ /dev/null @@ -1,214 +0,0 @@ -import os -import shutil -import random -import json -import torch -import numpy as np -import subprocess - -from open_lm.params import parse_args -from open_lm.model import test_classif_model - -def inference(): - - args = parse_args([]) - args.model = "open_lm_25m" - args.classif_model_path = 
"/workspace/youssef/lrz/logs/RedPajama/prop/checkpoints/epoch_1.pt" - args.num_classes = 2 - - test_data_path = '/workspace/youssef/lrz/datasets/prop/Llama1_gen.pt' - dataset = torch.load(test_data_path) - - model = test_classif_model(args) - model = model.to('cuda:3') - - pred = [] - for sample in dataset: - sample = torch.LongTensor(sample).to('cuda:3') - with torch.no_grad(): - out, _, _ = model(sample) - pred.append(torch.argmax(out,2)[:,-1].item()) - - c1 = pred.count(0) - c2 = pred.count(1) - - print(c1,c2) - - if c2 > c1: - return 1 - else: - return 0 - -def train_classifier(cuda_devices="3", log_dir="/workspace/youssef/lrz/logs/RedPajama/prop"): - # Set the CUDA_VISIBLE_DEVICES environment variable - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices - - # Generate a random master port between 10000 and 65000 - master_port = random.randint(10000, 65000) - - # Construct the torchrun command - command = [ - "torchrun", - f"--master_port={master_port}", - "--nproc-per-node", "1", - "-m", "open_lm.main", - "--model", "open_lm_25m", - "--dataset-manifest", "/workspace/youssef/lrz/datasets/prop/train/manifest.jsonl", - "--train-num-samples", "200000000", - "--workers", "1", - "--precision", "amp_bfloat16", - "--grad-checkpointing", - "--log-every-n-steps", "100", - "--grad-clip-norm", "1", - "--global-batch-size", "16", - "--data-key", "txt", - "--lr", "3e-4", - "--warmup", "2000", - "--wd", "0.1", - "--beta2", "0.95", - "--epochs", "1", - "--resume", "latest", - "--logs", "/workspace/youssef/lrz/logs/RedPajama/", - "--name", "prop", - "--classification", "True", - "--num-classes", "2", - "--classif-model-path", "/workspace/youssef/lrz/logs/pretrain/25M_0.5BC4/checkpoint/epoch_1.pt" - ] - - os.makedirs(log_dir, exist_ok=True) - - # Create log file paths - stdout_log = os.path.join(log_dir, "output.log") - stderr_log = os.path.join(log_dir, "error.log") - - # Run the torchrun command using subprocess - with open(stdout_log, "w") as out_file, open(stderr_log, "w") as err_file: - try: - result = subprocess.run(command, check=True, stdout=out_file, stderr=err_file) - print(f"torchrun finished with return code: {result.returncode}") - except subprocess.CalledProcessError as e: - print(f"An error occurred while running torchrun: {e}") - - - -def proj_simplex(y): - m = len(y) - bget = False - s = sorted(y, reverse=True) # sorting in descending order - tmpsum = 0 - for i in range(m-1): - tmpsum = tmpsum + s[i] - tmax = (tmpsum - 1) / (i+1) - if tmax >= s[i+1]: - bget = True - break - if not bget: - tmax = (tmpsum + s[m-1] -1) / m - return np.maximum(y-tmax,0) - - - -def del_dir(dir_path): - try: - # Remove the directory and all its contents - shutil.rmtree(dir_path) - print(f"Removed directory: {dir_path}") - except FileNotFoundError: - print(f"Directory not found: {dir_path}") - except PermissionError: - print(f"Permission denied: {dir_path}") - except Exception as e: - print(f"An error occurred while removing the directory: {e}") - - -def round_preserving_sum(numbers): - """ - This function takes a list of numbers that add up to 1, multiplies each by 100, - rounds them to integers while preserving the sum as 100. 
- """ - # Step 1: Multiply all numbers by 100 - multiplied = np.array(numbers) * 100 - - # Step 2: Separate integer and decimal parts - integers = np.floor(multiplied).astype(int) # Integer parts - decimals = multiplied - integers # Decimal parts - - # Step 3: Calculate the difference between the current sum and 100 - current_sum = np.sum(integers) - difference = 100 - current_sum - - # Step 4: Distribute the difference by rounding up the largest decimals - if difference > 0: - # Get indices of the largest decimals and round up those numbers - indices_to_round_up = np.argsort(-decimals)[:difference] - integers[indices_to_round_up] += 1 - - return integers.tolist() - -def sample_and_rename_files(sample_counts_list): - - base_path = "/workspace/youssef/lrz/datasets/prop/original/" - output_folder = "/workspace/youssef/lrz/datasets/prop/train/" - - # Define the folder names in order - file_names = ['arxiv', 'c4', 'cc', 'github', 'se', 'wiki'] - folder_names = [os.path.join(base_path, folder) for folder in file_names] - - # Check if the provided sample_counts_list contains exactly two lists - if len(sample_counts_list) != 2 or any(len(sample_counts) != 6 for sample_counts in sample_counts_list): - raise ValueError("sample_counts_list must contain exactly two lists, each with 6 numbers.") - - # Create the output folder if it doesn't exist - if not os.path.exists(output_folder): - os.makedirs(output_folder) - - # List to store the manifest data - manifest_data = [] - - # Loop over the two lists of sample counts - for index, sample_counts in enumerate(sample_counts_list): - # Iterate through each folder and sample the required number of .tar files - for i, folder in enumerate(folder_names): - folder_path = os.path.join(folder) - - if not os.path.exists(folder_path): - raise ValueError(f"Folder {folder_path} does not exist.") - - # Get all .tar files from the current folder - all_files = [f for f in os.listdir(folder_path) if f.endswith('.tar')] - - # Ensure the sample count is not more than available files - sample_count = min(sample_counts[i], len(all_files)) - - # Randomly sample the required number of files from the folder - sampled_files = random.sample(all_files, sample_count) - - # Copy each sampled file to the output folder with the new name - for file_name in sampled_files: - # Construct source file path - source_file_path = os.path.join(folder_path, file_name) - - # Create the new filename by prepending the index (0 or 1) with a dash - new_file_name = f"{index}-{file_name[:-4]}" # Remove the .tar extension - - # Destination path in the output folder - dest_file_path = os.path.join(output_folder, new_file_name + '.tar') # Keep .tar in destination - - # Copy the file to the output folder - shutil.copy2(source_file_path, dest_file_path) - - # Add entry to manifest_data, replacing ".tar" in new_file_name with an empty string - manifest_entry = { - "shard": new_file_name, # No .tar extension - "num_sequences": 489 # Set a fixed number of sequences - } - manifest_data.append(manifest_entry) - - # Write the manifest.jsonl file - manifest_file_path = os.path.join(output_folder, "manifest.jsonl") - with open(manifest_file_path, 'w') as manifest_file: - # Write each entry except the last one with a newline - for entry in manifest_data: - manifest_file.write(json.dumps(entry) + '\n') - - print(f"Files sampled and saved in {output_folder}. 
Manifest file created as {manifest_file_path}.") \ No newline at end of file diff --git a/open_lm/hf/__init__.py b/open_lm/hf/__init__.py new file mode 100644 index 0000000..8493168 --- /dev/null +++ b/open_lm/hf/__init__.py @@ -0,0 +1,3 @@ +from .configuration_openlm import OpenLMConfig +from .modeling_openlm import OpenLMForCausalLM +from .tokenization_openlm import OpenLMTokenizerFast diff --git a/open_lm/hf/__pycache__/__init__.cpython-310.pyc b/open_lm/hf/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..72a80db Binary files /dev/null and b/open_lm/hf/__pycache__/__init__.cpython-310.pyc differ diff --git a/open_lm/hf/__pycache__/configuration_openlm.cpython-310.pyc b/open_lm/hf/__pycache__/configuration_openlm.cpython-310.pyc new file mode 100644 index 0000000..29dba33 Binary files /dev/null and b/open_lm/hf/__pycache__/configuration_openlm.cpython-310.pyc differ diff --git a/open_lm/hf/__pycache__/modeling_openlm.cpython-310.pyc b/open_lm/hf/__pycache__/modeling_openlm.cpython-310.pyc new file mode 100644 index 0000000..84b00fb Binary files /dev/null and b/open_lm/hf/__pycache__/modeling_openlm.cpython-310.pyc differ diff --git a/open_lm/hf/__pycache__/tokenization_openlm.cpython-310.pyc b/open_lm/hf/__pycache__/tokenization_openlm.cpython-310.pyc new file mode 100644 index 0000000..dd3cc05 Binary files /dev/null and b/open_lm/hf/__pycache__/tokenization_openlm.cpython-310.pyc differ diff --git a/open_lm/hf/configuration_openlm.py b/open_lm/hf/configuration_openlm.py new file mode 100644 index 0000000..7566396 --- /dev/null +++ b/open_lm/hf/configuration_openlm.py @@ -0,0 +1,24 @@ +# Follows OLMo's HF template + +""" +OpenLM configuration +""" + +from transformers import AutoConfig, PretrainedConfig +from transformers.utils import logging + +from open_lm.model import Params + +logger = logging.get_logger(__name__) + + +class OpenLMConfig(PretrainedConfig): + model_type = "openlm" + + def __init__(self, **kwargs): + kwargs["architectures"] = ["OpenLMForCausalLM"] + super().__init__(**kwargs) + + +# Register the config class so that it is available for transformer pipelines, auto-loading etc. +AutoConfig.register("openlm", OpenLMConfig) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py new file mode 100644 index 0000000..67ee1e4 --- /dev/null +++ b/open_lm/hf/modeling_openlm.py @@ -0,0 +1,194 @@ +# Follows OLMo's HF template + +import logging +from dataclasses import fields +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedModel +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModelForCausalLM + +from open_lm.model import Params, Transformer +from open_lm.norms import get_norm_class +from open_lm.attention import get_attn_func + +from .configuration_openlm import OpenLMConfig + +log = logging.getLogger(__name__) + + +def create_model_config_from_pretrained_config(config: OpenLMConfig): + """ + Utility function + """ + + kwargs = {} + for field in fields(Params): + if hasattr(config, field.name): + kwargs[field.name] = getattr(config, field.name) + + model_config = Params(**kwargs) + + if hasattr(config, "norm_type"): + model_config.norm_type = get_norm_class(config.norm_type) + + if hasattr(config, "attn_name"): + model_config.attn_func = get_attn_func(config.attn_name) + + return model_config + + +class OpenLMForCausalLM(PreTrainedModel): + """ + Extremely barebones HF model wrapper. 
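+    Wraps an open_lm Transformer so it can be driven through the standard transformers machinery:
+    the config's fields are copied into an open_lm Params via create_model_config_from_pretrained_config
+    above, forward() below returns a CausalLMOutputWithPast carrying logits, an optional shifted-label
+    loss, and past_key_values, while inputs_embeds, output_attentions, and output_hidden_states are
+    not supported.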
+ """ + + config_class = OpenLMConfig + base_model_prefix = "model" + + def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None): + super().__init__(config) + + if not model: + self.model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + self.model_config.init_device = "cpu" + self.model = Transformer(self.model_config) + + else: + self.model = model + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ + Cache + ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 + ) -> Union[Tuple, CausalLMOutputWithPast]: + if inputs_embeds is not None: + log.warning("inputs_embeds is set but OpenLM does not support it yet") + if attention_bias is not None: + log.warning("attention_bias is et but OpenLM does not support it yet") + if use_cache is None: + use_cache = True + if output_attentions: + raise ValueError("output_attentions is not yet supported in OpenLM") + if output_hidden_states: + raise ValueError("output_hidden_states is not yet supported in OpenLM") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # print("outer past_key_values: ", type(past_key_values)) + # if past_key_values is not None: + # print(len(past_key_values), type(past_key_values[0])) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + logits = outputs[0] + past_key_values = outputs[2] + hidden_states = None + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.model_config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=hidden_states, + ) + + def can_generate(self) -> bool: + return True + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values[0][1], int): + # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) + past_length = past_key_values[0][1] + else: + # This assumes that the first item of past key values is a list of all the past keys, thus the + # shape 1 is the length of the past (this is the case for attention without window) + past_length = past_key_values[0][0].shape[1] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only 
final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + model_inputs = { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.pop("use_cache", True), + } + return model_inputs + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.tok_embeddings + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + if self.model_config.weight_tying: + return self.model.tok_embeddings + else: + return self.model.output + + def set_output_embeddings(self, value: torch.nn.Module): + if self.model_config.weight_tying: + self.model.tok_embeddings = value + else: + self.model.output = value + + def tie_weights(self): + """ + Copied from OLMo (description below). I removed it and the results just became garbage, so this pass is needed. + This function is intentionally left as a no-op. + Weight tying is handled as follows: + - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration. + See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`. + - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled. + See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method. + Therefore, there is no need to explicitly tie the weights in this function. + """ + pass + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> torch.nn.Embedding: + raise NotImplementedError + + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. +AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) diff --git a/open_lm/hf/tokenization_openlm.py b/open_lm/hf/tokenization_openlm.py new file mode 100644 index 0000000..e8abdd6 --- /dev/null +++ b/open_lm/hf/tokenization_openlm.py @@ -0,0 +1,18 @@ +# Follows OLMo's HF template + +from transformers import AutoTokenizer, PreTrainedTokenizerFast + +from open_lm.hf.configuration_openlm import OpenLMConfig + + +class OpenLMTokenizerFast(PreTrainedTokenizerFast): + # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary. + pass + + # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + # # This is required to make the implementation complete. + # pass + + +# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc. 
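Once these three hf/ modules have registered the config, model, and tokenizer classes with the transformers auto classes (the AutoTokenizer.register call sits just below), a checkpoint directory exported in HF format with model_type "openlm" could in principle be loaded through the standard auto API. A minimal sketch, in which the directory name and the export step are assumptions:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    import open_lm.hf  # noqa: F401  # importing the package runs the register() calls in these modules

    # Hypothetical directory containing config.json (model_type "openlm"), weights, and tokenizer files.
    ckpt_dir = "path/to/exported_openlm_checkpoint"
    model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

    inputs = tokenizer("open_lm is", return_tensors="pt")
    logits = model(input_ids=inputs["input_ids"]).logits  # forward() returns a CausalLMOutputWithPast
    print(logits.shape)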
+AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast) diff --git a/open_lm/infer_proportions.py b/open_lm/infer_proportions.py deleted file mode 100644 index 1dadd64..0000000 --- a/open_lm/infer_proportions.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -import numpy as np - -from extra_funcs import train_classifier, proj_simplex, round_preserving_sum, sample_and_rename_files, inference, del_dir - -def comparison(x, xcandidate): - - list1 = round_preserving_sum(x.tolist()) - list2 = round_preserving_sum(xcandidate.tolist()) - list = [list1, list2] - - sample_and_rename_files(list) - - train_classifier() - - result = inference() - - del_dir("/workspace/youssef/lrz/logs/RedPajama/prop") - del_dir("/workspace/youssef/lrz/datasets/prop/train") - - return result - - -def gradientless_descent(N=6, num_iter=200, radius = 0.2, alpha=0.5): - - #For measuring error - xorig = np.array([0.0325,0.1575,0.6775,0.0525,0.0275,0.0525]) - - # initialize x with equal probability - x = np.ones(N)/N - - error = [] - prop = [] - - for i in range(num_iter): - - stepsize = 1/(i+1)**alpha - # choose random direction with radius R - dir = np.random.randn(N) - dir = dir/np.linalg.norm(dir)*radius*stepsize - - xcandidate = proj_simplex( x + dir ) - - # compare x with x+dir and update x - if comparison(x, xcandidate) == 1: - x = xcandidate - - print(i, np.linalg.norm(x-xorig), x) - error.append(np.linalg.norm(x-xorig)) - prop.append(x) - - torch.save(error, "error.pt") - torch.save(prop, "prop.pt") - return x - -if __name__ == "__main__": - gradientless_descent() diff --git a/open_lm/main2.py b/open_lm/main2.py deleted file mode 100644 index 55863f8..0000000 --- a/open_lm/main2.py +++ /dev/null @@ -1,1034 +0,0 @@ -import atexit -import logging -import os -import re -import sys -import random -from datetime import datetime -import functools -import numpy as np -from pathlib import Path -import json -import traceback - -import fsspec -import torch -from torch import optim -from torch.cuda.amp import GradScaler - -import torch.distributed as dist - -from open_lm.data import sample_chunk - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - MixedPrecision, - BackwardPrefetch, - ShardingStrategy, - FullStateDictConfig, - StateDictType, - CPUOffload, -) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - -from open_lm.data import proc_token -from open_lm.model import Block -from open_lm.losses import CrossEntropyLossWithZLoss -from open_lm.utils.averaging_utils import ModelAverager - -try: - import wandb -except ImportError: - wandb = None - -try: - import torch.utils.tensorboard as tensorboard -except ImportError: - tensorboard = None - -from open_lm.model import create_model -from open_lm.model import create_classif_model - -from open_lm.utils.transformers.hf_wrapper import create_wrapped_hf_model -from open_lm.data import get_data, get_wds_dataset -from open_lm.distributed import is_master, init_distributed_device, broadcast_object -from open_lm.logger import setup_logging -from open_lm.params import parse_args -from open_lm.scheduler import cosine_lr, const_lr -from open_lm.train import train_one_epoch -from open_lm.evaluate import evaluate_loop -from open_lm.file_utils import ( - pt_load, - check_exists, - start_sync_process, - remote_sync_with_expon_backoff, - get_metadata_file, - get_string_for_epoch, - log_num_checkpoints, - terminate_sync_process, -) - - -LATEST_CHECKPOINT_NAME = "epoch_latest.pt" - - -def random_seed(seed=42, rank=0): - 
torch.manual_seed(seed + rank) - np.random.seed(seed + rank) - random.seed(seed + rank) - - -def natural_key(string_): - """See http://www.codinghorror.com/blog/archives/001018.html""" - return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] - - -def get_latest_checkpoint(path: str): - is_s3 = path.startswith("s3") - fs, root_path = fsspec.core.url_to_fs(path) - checkpoints = fs.glob(os.path.join(root_path, "epoch_*.pt")) - if checkpoints: - checkpoints = sorted(checkpoints, key=natural_key) - return f"s3://{checkpoints[-1]}" if is_s3 else checkpoints[-1] - - return None - - -def get_state_dict(name): - checkpoint = pt_load(name, map_location="cpu") - if "epoch" in checkpoint: - sd = checkpoint["state_dict"] - if next(iter(sd.items()))[0].startswith("module"): - sd = {k[len("module.") :]: v for k, v in sd.items()} - else: - sd = checkpoint - return sd - - -def load_model(args, model, different_seed=False): - checkpoint = pt_load(args.resume, map_location="cpu") - if "epoch" in checkpoint: - if not different_seed and "shard_shuffle_seed" in checkpoint: - pretrained_seed = checkpoint["shard_shuffle_seed"] - assert ( - pretrained_seed == args.seed - ), f"This checkpoint was trained with a random seed of {pretrained_seed}. Since this seed affects shard shuffling, resuming training must use the same seed." - else: - if different_seed: - message = "Resuming a checkpoint without checking that the seed match. This means that training might not be reproducible." - else: - message = "Resuming a checkpoint that does not have a seed saved. This means that the shards were not shuffled, so they will remain unshuffled." - logging.info(message) - pretrained_seed = None - - # resuming a train checkpoint w/ epoch and optimizer state - start_epoch = checkpoint["epoch"] - sd = checkpoint["state_dict"] - global_step = checkpoint.get("step", None) - if next(iter(sd.items()))[0].startswith("module"): - sd = {k[len("module.") :]: v for k, v in sd.items()} - if "_orig_mod" in next(iter(sd.items()))[0]: - sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()} - if args.fsdp: - model.load_state_dict(sd) - elif args.distributed: - model.module.load_state_dict(sd) - else: - model.load_state_dict(sd) - logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") - else: - # loading a bare (model only) checkpoint for fine-tune or evaluation - start_epoch, global_step = 0, 0 - pretrained_seed = None - model.load_state_dict(checkpoint) - logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") - return start_epoch, global_step, pretrained_seed - - -def load_avg_models(args, averagers): - checkpoint = pt_load(args.resume, map_location="cpu") - if "epoch" in checkpoint: - # resuming a train checkpoint w/ epoch and optimizer state - start_epoch = checkpoint["epoch"] - if averagers is not None: - for k in averagers.avgs_dict: - avg_sd = torch.load(args.resume.replace("epoch", k), map_location="cpu") - if next(iter(avg_sd.items()))[0].startswith("module"): - avg_sd = {k[len("module.") :]: v for k, v in avg_sd.items()} - if "_orig_mod" in next(iter(avg_sd.items()))[0]: - avg_sd = {k.replace("_orig_mod.", ""): v for k, v in avg_sd.items()} - averagers.avgs_dict[k].load_state_dict_avg(avg_sd) - logging.info( - f"=> resuming averager for {k} from checkpoint '{args.resume.replace('epoch', k)} (epoch {start_epoch})" - ) - return - - -def load_optimizer(args, model, optimizer, scaler): - potential_checkpoint = args.resume.replace("epoch_", "optimizer_") - if 
check_exists(potential_checkpoint): - checkpoint = pt_load(potential_checkpoint, map_location="cpu") - else: - checkpoint = pt_load(args.resume, map_location="cpu") - if "optimizer" in checkpoint: - if optimizer is not None: - osd = checkpoint["optimizer"] - if args.fsdp: - osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=osd) - optimizer.load_state_dict(osd) - logging.info(f"=> resuming optimizer") - if scaler is not None and "scaler" in checkpoint: - scaler.load_state_dict(checkpoint["scaler"]) - else: - logging.info(f"=> WARNING: not resuming optimizer.") - - -def load_data_chunks(args): - checkpoint = pt_load(args.resume, map_location="cpu") - if "next_shard_per_source" in checkpoint and "samples_seen" in checkpoint: - return checkpoint["next_shard_per_source"], checkpoint["samples_seen"] - else: - logging.info( - "=> WARNING: tried to resume a checkpoint without data loading info. Re-starting data loading from the " - "first shard." - ) - return [0 for _ in range(len(args.dataset_manifest))], 0 - - -def save_checkpoint( - args, - model, - optimizer, - scaler, - completed_epoch, - evaluation_metrics, - step, - is_final_checkpoint, - percentage_of_data_seen=-1.0, - next_shard_per_source=None, - samples_seen=None, - shard_shuffle_seed=None, - train_data_string=None, - averagers=None, - failed=False, -): - cpu_state, optim_state = None, None - if args.logs and args.logs.lower() != "none" and args.fsdp: - save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): - cpu_state = model.state_dict() - optim_state = FSDP.optim_state_dict(model, optimizer) - if args.save_logs: - checkpoint_dict_model = { - "epoch": completed_epoch, - "name": args.name, - "state_dict": cpu_state if args.fsdp else model.state_dict(), - "evaluation_metrics": evaluation_metrics, - } - if next_shard_per_source is not None: - checkpoint_dict_model["next_shard_per_source"] = next_shard_per_source - - if samples_seen is not None: - checkpoint_dict_model["samples_seen"] = samples_seen - - if step is not None: - checkpoint_dict_model["step"] = step - - if shard_shuffle_seed is not None: - checkpoint_dict_model["shard_shuffle_seed"] = shard_shuffle_seed - - checkpoint_dict_opt = { - "epoch": completed_epoch, - "name": args.name, - "optimizer": optim_state if args.fsdp else optimizer.state_dict(), - "evaluation_metrics": evaluation_metrics, - } - - if scaler is not None: - checkpoint_dict_opt["scaler"] = scaler.state_dict() - - checkpoint_dict_stats = { - "epoch": completed_epoch, - "name": args.name, - "is_final_checkpoint": is_final_checkpoint, - "evaluation_metrics": evaluation_metrics, - "percentage_of_data_seen": percentage_of_data_seen, - } - if next_shard_per_source is not None: - checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source - - if samples_seen is not None: - checkpoint_dict_stats["samples_seen"] = samples_seen - - if step is not None: - checkpoint_dict_stats["step"] = step - - if shard_shuffle_seed is not None: - checkpoint_dict_stats["shard_shuffle_seed"] = shard_shuffle_seed - - if train_data_string is not None: - checkpoint_dict_stats["train_data_string"] = train_data_string - - prefixes = { - "epoch_": checkpoint_dict_model, - "optimizer_": checkpoint_dict_opt, - "stats_": checkpoint_dict_stats, - } - - if averagers is not None: - for k in averagers.avgs_dict: - prefixes[f"{k}_"] = averagers.avgs_dict[k].get_state_dict_avg() - if ( - completed_epoch == 
args.epochs - or is_final_checkpoint - or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0) - ): - for prefix in prefixes: - save_path = args.checkpoint_path if not failed else args.failed_checkpoint_path - path = os.path.join(save_path, f"{prefix}{completed_epoch}.pt") - print(f"Saving {prefix}{completed_epoch} in {path}...") - torch.save( - prefixes[prefix], - path, - ) - - if args.delete_previous_checkpoint: - for prefix in prefixes: - prev = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch - 1}.pt") - if os.path.exists(prev): - os.remove(prev) - - -def cleanup(sync_process, distributed=False): - if sync_process: - terminate_sync_process(sync_process) - if distributed and torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -def main(args): - args = parse_args(args) - - requires_training = args.train_data or args.dataset_type == "synthetic" or args.dataset_manifest is not None - - if torch.cuda.is_available(): - # This enables tf32 on Ampere GPUs which is only 8% slower than - # float16 and almost as accurate as float32 - # This was a default in pytorch until 1.12 - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.benchmark = True - torch.backends.cudnn.deterministic = False - - # fully initialize distributed device environment - device = init_distributed_device(args) - - assert ( - args.global_batch_size % args.world_size == 0 - ), f"Global batch size ({args.global_batch_size}) is not divisible by number of GPUs ({args.world_size}), and thus cannot be respected." - - args.per_gpu_batch_size = max(args.global_batch_size // args.world_size, 1) - if args.val_data is not None: - args.per_gpu_val_batch_size = max(args.global_val_batch_size // args.world_size, 1) - - if args.hf_model is not None and args.hf_seq_len is None: - raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.") - - if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None: - raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fspd-block.") - - if args.fsdp and not args.distributed: - raise ValueError(f"--fsdp can only be specified in distributed mode.") - - # get the name of the experiments - if args.name is None: - # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? - model_name_safe = None - if args.hf_model is not None: - model_name_safe = args.hf_model.replace("/", "-") - else: - if Path(args.model).is_file(): - model_name_safe = Path(args.model).stem.replace("/", "-") - else: - model_name_safe = args.model.replace("/", "-") - - date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") - if args.distributed: - # sync date_str from master to all ranks - date_str = broadcast_object(args, date_str) - args.name = "-".join( - [ - date_str, - f"model_{model_name_safe}", - f"lr_{args.lr}", - f"b_{args.per_gpu_batch_size}", # Per gpu to respect old naming convention - ] - ) - - resume_latest = args.resume == "latest" - log_base_path = os.path.join(args.logs, args.name) - args.log_path = None - if is_master(args, local=args.log_local): - os.makedirs(log_base_path, exist_ok=True) - log_filename = f"out-{args.rank}" if args.log_local else "out.log" - args.log_path = os.path.join(log_base_path, log_filename) - if os.path.exists(args.log_path) and not resume_latest: - raise ValueError(f"Experiment {args.log_path} already exists. 
Use --name to specify a new experiment.") - - # Setup text logger - args.log_level = logging.DEBUG if args.debug else logging.INFO - setup_logging(args.log_path, args.log_level) - - # Setup wandb, tensorboard, checkpoint logging - args.wandb = "wandb" in args.report_to or "all" in args.report_to - args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to - args.checkpoint_path = os.path.join(log_base_path, "checkpoints") - args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed") - if is_master(args): - args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else "" - for dirname in [args.tensorboard_path, args.checkpoint_path, args.failed_checkpoint_path]: - if dirname: - os.makedirs(dirname, exist_ok=True) - else: - args.tensorboard_path = "" - - if resume_latest: - resume_from = None - checkpoint_path = args.checkpoint_path - - # If using remote_sync, need to check the remote instead of the local checkpoints folder. - if args.remote_sync is not None: - checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") - - if is_master(args): - # Checking for existing checkpoint via master rank only. It is possible for - # different rank processes to see different files if a shared file-system is under - # stress, however it's very difficult to fully work around such situations. - if args.save_most_recent: - # if --save-most-recent flag is set, look for latest at a fixed filename - resume_from = os.path.join(checkpoint_path, "checkpoints", LATEST_CHECKPOINT_NAME) - if not os.path.exists(resume_from): - # If no latest checkpoint has been saved yet, don't try to resume - resume_from = None - else: - # otherwise, list checkpoint dir contents and pick the newest checkpoint - resume_from = get_latest_checkpoint(checkpoint_path) - if resume_from: - logging.info(f"Found latest resume checkpoint at {resume_from}.") - else: - logging.info(f"No latest resume checkpoint found in {checkpoint_path}.") - if args.distributed: - # sync found checkpoint path to all ranks - resume_from = broadcast_object(args, resume_from) - args.resume = resume_from - - if args.copy_codebase: - copy_codebase(args) - - # start the sync proces if remote-sync is not None - remote_sync_process = None - if is_master(args) and args.remote_sync is not None: - # first make sure it works - result = remote_sync_with_expon_backoff( - args.remote_sync_frequency, - os.path.join(args.logs, args.name), - os.path.join(args.remote_sync, args.name), - args.remote_sync_protocol, - ) - if result: - logging.info("remote sync successful.") - else: - raise ValueError("Remote sync failed.") - # if all looks good, start a process to do this every args.remote_sync_frequency seconds - remote_sync_process = start_sync_process( - args.remote_sync_frequency, - os.path.join(args.logs, args.name), - os.path.join(args.remote_sync, args.name), - args.remote_sync_protocol, - ) - remote_sync_process.start() - - # Handle cleanup even if open_lm crashes. - # TODO: For cases where main() is called as a functio, we need to call cleanup() manually. - # Right now, we do this manually in every case where main returns, but we should put main() in a wrapper and call - # cleanup() outside it, ideally. - atexit.register(cleanup, sync_process=remote_sync_process, distributed=args.distributed) - - if args.precision == "fp16": - logging.warning( - "It is recommended to use AMP mixed-precision instead of FP16. " - "FP16 support needs further verification and tuning, especially for train." 
- ) - - elif args.distributed: - logging.info( - f"Running in distributed mode with multiple processes. Device: {args.device}." - f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." - ) - else: - logging.info(f"Running with a single process. Device {args.device}.") - - random_seed(args.seed, 0) - - model = None - if args.hf_model is not None: - model = create_wrapped_hf_model(args) - else: - # Optional: Use meta device - with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device): - if args.classification: - model = create_classif_model(args) - else: - model = create_model(args) - - args.vocab_size = model.vocab_size - args.seq_len = model.seq_len - if args.train_num_samples is not None: - args.train_num_samples //= args.seq_len - if args.val_num_samples is not None: - if args.val_num_samples // args.seq_len == 0: - raise ValueError( - f"number of requested evaluation val_num_samples (tokens): {args.val_num_samples} is less than seq_len: {args.seq_len}" - ) - args.val_num_samples //= args.seq_len - - averagers = None - random_seed(args.seed, args.rank) - - if args.grad_checkpointing: - model.set_grad_checkpointing() - - if args.distributed: - if args.fsdp: - transformer_layer_cls = None - - if args.hf_model is not None: - # retrive the user specified block class for fsdp - for _, target_cls in model.named_modules(): - if args.hf_fsdp_block in type(target_cls).__name__: - transformer_layer_cls = {type(target_cls)} - break - - if transformer_layer_cls is None: - print(f"--hf-fsdp-block {args.hf_fsdp_block} not found in --hf-model {args.hf_model}") - return -1 - - else: - transformer_layer_cls = {Block} - # from https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ - transformer_auto_wrapper_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=transformer_layer_cls, - ) - # tries to follow gopher... - mp_policy = None - if args.fsdp_amp: - print("=> using bfloat16 params as part of fsdp amp policy.") - mp_policy = MixedPrecision( - param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - buffer_dtype=torch.bfloat16, - ) - elif args.fsdp_pure_bf16: - print("=> using pure bfloat16 params as part of fsdp amp policy.") - mp_policy = MixedPrecision( - param_dtype=torch.bfloat16, - reduce_dtype=torch.bfloat16, - buffer_dtype=torch.bfloat16, - ) - - if args.rank == 0: - print(f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters()):,}") - print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") - - fsdp_kwargs = {} - assert not ( - args.fsdp_hybrid and args.fsdp_hybrid_o2 - ), "Only --fsdp-hybrid or --fsdp-hybrid-o2 should be set." - if args.fsdp_backward_prefetch: - fsdp_kwargs["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE - if args.fsdp_hybrid: - fsdp_kwargs["sharding_strategy"] = ShardingStrategy.HYBRID_SHARD - if args.fsdp_hybrid_o2: - fsdp_kwargs["sharding_strategy"] = ShardingStrategy._HYBRID_SHARD_ZERO2 - print("=> FSDP kwargs: ", fsdp_kwargs) - - # Initialize FSDP. Use the same seed across workers to ensure reset_parameters is the same across workers. 
- random_seed(args.seed, rank=0) - model = FSDP( - model, - auto_wrap_policy=transformer_auto_wrapper_policy, - device_id=device, - mixed_precision=mp_policy, - cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), - use_orig_params=args.fsdp_use_orig_params, - limit_all_gathers=args.fsdp_limit_all_gathers, - **fsdp_kwargs, - ) - - print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters()):,} on rank {args.rank}") - print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}") - else: - ddp_args = {} - if args.ddp_static_graph: - # this doesn't exist in older PyTorch, arg only added if enabled - ddp_args["static_graph"] = True - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) - if args.averagers is not None: - averagers = ModelAverager(model, args.averagers) - if args.resume is not None and averagers is not None: - load_avg_models(args, averagers) - - if is_master(args): - logging.info(f"Model (has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters):") - logging.info(f"{str(model)}") - logging.info("Params:") - params_file = os.path.join(args.logs, args.name, "params.txt") - with open(params_file, "w") as f: - for name in sorted(vars(args)): - val = getattr(args, name) - logging.info(f" {name}: {val}") - f.write(f"{name}: {val}\n") - - # optionally resume model from a checkpoint - start_epoch, global_step = 0, 0 - shard_shuffle_seed = args.seed - if args.resume is not None: - start_epoch, global_step, shard_shuffle_seed = load_model(args, model) - - elif args.pretrained is not None: - print("=> loading from a pre-trained model.") - args.resume = args.pretrained - # this flag continues training from the pre-trained model. - if args.load_pretrained_state: - start_epoch, global_step, shard_shuffle_seed = load_model(args, model) - else: - load_model(args, model, different_seed=True) - args.resume = None - elif args.average is not None: - num_models_to_average = len(args.average) - print( - "=> Averaging models: ", - args.average, - " with coefficients: ", - args.average_coefficients, - ) - assert num_models_to_average > 1, "num_models_to_average must be > 1 - else use --pretrained" - if args.average_coefficients is None: - args.average_coefficients = [1.0 / num_models_to_average] * num_models_to_average - else: - assert len(args.average_coefficients) == num_models_to_average - state_dict = {k: v * args.average_coefficients[0] for k, v in get_state_dict(args.average[0]).items()} - for i in range(1, num_models_to_average): - state_dict_i = get_state_dict(args.average[i]) - for k in state_dict: - state_dict[k] = state_dict[k] + state_dict_i[k] * args.average_coefficients[i] - model.load_state_dict(state_dict) - - # Put the shard shuffle seed back into args (this is done for compatibility with older, non shuffling versions) - args.shard_shuffle_seed = shard_shuffle_seed - - if requires_training and global_step is None: - raise ValueError("Key 'step' not found in checkpoint, but required for training.") - - # Add data chunk when resuming (only for dataset without resampling) - next_shard_per_source = [0 for _ in range(len(args.dataset_manifest))] if args.dataset_manifest is not None else 0 - samples_seen = 0 - if args.resume is not None and args.dataset_manifest is not None: - next_shard_per_source, samples_seen = load_data_chunks(args) - if samples_seen >= args.train_num_samples * args.epochs: - raise RuntimeError("Loaded a checkpoint which has already seen the desired number of 
tokens.") - - # create optimizer and scaler - optimizer = None - scaler = None - - if requires_training: - named_parameters = list(model.named_parameters()) - no_decay_params = [] # to be potentially used later - params = [p for n, p in named_parameters if p.requires_grad] - - optimizer = optim.AdamW( - [ - {"params": no_decay_params, "weight_decay": 0.0}, - {"params": params, "weight_decay": args.wd}, - ], - lr=args.lr, - betas=(args.beta1, args.beta2), - eps=args.eps, - ) - scaler = None - if args.precision == "amp": - assert not args.fsdp, "FSDP not supported with amp, only amp_bfloat16" - scaler = GradScaler() - - # initialize datasets - # use tokenizer=None because the data is already pre-tokenized. - - data = get_data( - args, - epoch=start_epoch, - tokenizer=None, - skip_train=args.dataset_manifest is not None, - floor=args.dataset_manifest is not None, - ) - - if args.target_mask_left is not None: - # tokens handled with same modulo in dataloading - args.target_mask_left = proc_token(args.target_mask_left, args.vocab_size) - - if args.target_mask_individual is not None: - # tokens handled with same modulo in dataloading - args.target_mask_individual = proc_token(args.target_mask_individual, args.vocab_size) - - if args.torchcompile: - logging.info("Compiling model...") - model = torch.compile(model) - if averagers is not None: - logging.info("Compiling averagers...") - for k in averagers.avgs_dict: - averagers.avgs_dict[k].av_model = torch.compile(averagers.avgs_dict[k].av_model) - - # optionally resume optimizer from a checkpoint - # this needs to be after torchcompile - if args.resume is not None: - load_optimizer(args, model, optimizer, scaler) - - # create scheduler if train - scheduler = None - if requires_training: - if args.dataset_manifest is not None: - total_steps = (args.train_num_samples * args.epochs) // args.global_batch_size - else: - total_steps = (data["train"].dataloader.num_batches) * args.epochs - - if args.lr_scheduler == "cosine": - scheduler = cosine_lr( - optimizer, - args.lr, - args.warmup, - total_steps, - args.lr_cooldown_end, - args.force_min_lr, - ) - elif args.lr_scheduler == "const": - scheduler = const_lr( - optimizer, - args.lr, - args.warmup, - # total_steps, - # args.lr_cooldown_end, - # args.force_min_lr, - ) - else: - raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.") - - # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 - args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) - writer = None - if args.save_logs and args.tensorboard: - assert tensorboard is not None, "Please install tensorboard." - writer = tensorboard.SummaryWriter(args.tensorboard_path) - if args.wandb and is_master(args): - assert wandb is not None, "Please install wandb." 
- logging.debug("Starting wandb.") - - wandb.init( - project=args.wandb_project_name, - name=args.name, - notes=args.wandb_notes, - tags=[], - resume=None, - config=vars(args), - ) - if args.debug: - wandb.watch(model, log="all") - wandb.save(params_file) - logging.debug("Finished loading wandb.") - - if not requires_training: - if not args.resume: - logging.info("No training required, exiting.") - cleanup(remote_sync_process, args.distributed) - return - logging.info("No training required, evaluating instead.") - checkpoint_root = os.path.dirname(args.resume) - - if averagers is not None: - k = next(iter(averagers.avgs_dict.keys())) - logging.info(f"=> evaluation avg {k}") - model = averagers.avgs_dict[k].av_model - metrics = evaluate_loop(model, data["val_list"], start_epoch, args, writer) - metrics["average"] = k if averagers is not None else "none" - - if is_master(args): - with fsspec.open(os.path.join(checkpoint_root, "results.jsonl"), "a") as f: - f.write(f"{json.dumps(metrics)}\n") - - cleanup(remote_sync_process, args.distributed) - return - - loss = torch.nn.CrossEntropyLoss() - if args.z_loss_coefficient != 0.0: - if is_master(args): - logging.info("Using CrossEntropyLossWithZLoss.") - loss = CrossEntropyLossWithZLoss(args.z_loss_coefficient) - - if args.dataset_manifest: - log_num_checkpoints(total_steps, args) - - # Only enter training loop if there are steps to be done. - done_training = global_step >= total_steps - epoch = start_epoch - num_ckpt_too_few_tokens = 0 - while not done_training: - if is_master(args): - logging.info(f"Start epoch {epoch}") - - if args.dataset_manifest is not None: - assert not args.dataset_resampled, "dataset_manifest and dataset_resampled are mutually exclusive" - ( - train_data_string_per_source, - num_samples_per_source, - next_shard_per_source, - ) = get_string_for_epoch( - args.train_num_samples, - next_shard_per_source, - args.dataset_manifest, - args.train_data_mix_weights, - args.workers, - args.world_size, - multi_epoch=args.multiple_data_passes, - shard_shuffle_seed=args.shard_shuffle_seed, - ) - - # In the distributed case, make sure that all nodes receive the same string - if args.distributed: - all_source_strings = ["" for _ in range(args.world_size)] - dist.all_gather_object(all_source_strings, train_data_string_per_source) - assert all( - [x == train_data_string_per_source for x in all_source_strings] - ), "Dataset to train on is not the same across all nodes. This should not happen normally, unless there is an issue with shard shuffling during the dataset generation." - - if data["train"] is not None: - del data["train"] - args.train_data = train_data_string_per_source - - # Draw num_samples_per_source at most from dataset - rounded down to guarantee uniqueness. 
- data["train"] = get_wds_dataset( - args, True, epoch, force_num_samples=num_samples_per_source, data_key=args.data_key, floor=True - ) - - prev_step = global_step - if is_master(args): - logging.info(f"=> epoch {epoch}, training on {args.train_data}") - - if args.distributed: - dist.barrier() - - - #for batch in data["train"].dataloader: - # (texts, labels) = batch - # print(labels) - - # Get the dataloader and create an iterator - #dataloader = data["train"].dataloader - #data_iterator = iter(dataloader) - #batch = next(data_iterator) - - #(texts, labels) = batch - - #texts = torch.LongTensor(texts).to('cuda:0') - #labels = torch.LongTensor(labels).to('cuda:0') - #print(labels, labels.size()) - #labels = labels.unsqueeze(1).repeat(1, args.seq_len) - #print(labels, labels.size()) - - - #print(len(texts), texts.dtype, texts[0].size()) - #print(len(labels), labels.dtype, labels.size()) - #print(labels) - - - #print(len(labels), labels.dtype) - - #print("len(texts)= ", len(texts), " size(texts[0])= ", texts[0].size()) - #print(type(texts), type(texts[0])) - - #inputs, targets = sample_chunk(texts, args) - - #print("len(inputs)= ", len(inputs), " size(inputs[0])= ", inputs[0].size()) - #print(type(inputs), type(inputs[0])) - - #print("len(targets)= ", len(targets), " size(targets)= ", targets[0].size()) - #print(type(targets), type(targets[0])) - - #print("texts[0]= ", texts[0]) - #print("inputs[0]= ", inputs[0]) - #print("targets[0]= ", targets[0]) - - #out, _, _ = model(inputs) - - #print("len(out)= ", len(out), " size(out)= ", out.size()) - #print(type(out), type(out[0])) - #print("out[0]= ", out[0]) - - #device = next(model.parameters()).device - #print(inputs.device, device) - - #print("reshape") - #print("out reshaped: ", out.reshape(-1, args.vocab_size).size(), "targets reshaped: ", targets.reshape(-1).size()) - #print(targets.dtype) - #print(targets) - - #out = out[:, -1, :] - #print("out reshaped: ", out.reshape(-1, args.num_classes).size(), "lables reshaped: ", labels.reshape(-1).size()) - - success, global_step = train_one_epoch( - model, - data, - loss, - averagers=averagers, - epoch=epoch, - step=global_step, - optimizer=optimizer, - scaler=scaler, - scheduler=scheduler, - total_steps=total_steps, - args=args, - tb_writer=writer, - ) - - if args.distributed: - dist.barrier() - - done_training = global_step >= total_steps - steps_done_epoch = global_step - prev_step - samples_seen = samples_seen + steps_done_epoch * args.global_batch_size - - if not success: - logging.info("Training exiting due to NaN value") - break - - failed_ckpt = False - expected_steps = data["train"].dataloader.num_batches - if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training: - failed_ckpt = True - num_ckpt_too_few_tokens += 1 - if is_master(args): - logging.warning( - f"Epoch {epoch}, tokens seen: {steps_done_epoch * args.global_batch_size * args.seq_len}, tokens expected: {expected_steps * args.global_batch_size * args.seq_len}, ratio: {steps_done_epoch / expected_steps}" - ) - - epoch = epoch + 1 - evaluation_metrics = [] - if "val_list" in data and (epoch % args.val_frequency == 0 or done_training): - # validate based on frequency and always validate the last checkpoint - try: - evaluation_metrics = evaluate_loop(model, data["val_list"], epoch, args, writer) - - if is_master(args): - with fsspec.open(os.path.join(args.checkpoint_path, "results.jsonl"), "a") as f: - f.write(f"{json.dumps(evaluation_metrics)}\n") - - except Exception as e: - if is_master(args): - 
logging.error(e) - logging.error(traceback.format_exc()) - logging.warning("evaluation failed! continuing to save_checkpoint") - - if is_master(args): - end_of_epoch_log = { - "epoch": epoch, - "tokens": (global_step + 1) * args.global_batch_size * args.seq_len, - "checkpoints_too_few_tokens": num_ckpt_too_few_tokens, - "percentage_of_data_seen": steps_done_epoch / expected_steps, - } - - if args.dataset_manifest is not None: - for i in range(len(next_shard_per_source)): - end_of_epoch_log[f"next_shard_{i}"] = next_shard_per_source[i] - end_of_epoch_log[f"dataset_pass_{i}"] = next_shard_per_source[i] // len( - get_metadata_file(args.dataset_manifest[i]) - ) - - for name, val in end_of_epoch_log.items(): - name = "train/" + name - if writer is not None: - writer.add_scalar(name, val, global_step) - if args.wandb: - assert wandb is not None, "Please install wandb." - wandb.log({name: val, "step": global_step, "tokens": end_of_epoch_log["tokens"]}) - - # Saving checkpoints. - save_checkpoint( - args, - model, - optimizer, - scaler, - epoch, - evaluation_metrics, - step=global_step, - is_final_checkpoint=done_training, - percentage_of_data_seen=1.0 * steps_done_epoch / expected_steps, - next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None, - samples_seen=samples_seen if args.dataset_manifest is not None else None, - shard_shuffle_seed=args.shard_shuffle_seed, - train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None, - averagers=averagers, - failed=failed_ckpt, - ) - - if num_ckpt_too_few_tokens > args.data_tolerate_num_ckpts: - raise RuntimeError( - f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3." - ) - - if done_training: - if is_master(args): - logging.info("Model has seen the desired number of tokens. Ending training.") - break - - if args.wandb and is_master(args): - wandb.finish() - - # run a final sync. - if remote_sync_process is not None: - logging.info("Final remote sync.") - terminate_sync_process(remote_sync_process) - result = remote_sync_with_expon_backoff( - args.remote_sync_frequency, - os.path.join(args.logs, args.name), - os.path.join(args.remote_sync, args.name), - args.remote_sync_protocol, - ) - if result: - logging.info("Final remote sync successful.") - else: - logging.info("Final remote sync failed.") - - # Final sync of all procs. - if args.distributed: - dist.barrier() - - cleanup(remote_sync_process, args.distributed) - return args - - -def copy_codebase(args): - from shutil import copytree, ignore_patterns - - new_code_path = os.path.join(args.logs, args.name, "code") - if os.path.exists(new_code_path): - print(f"Error. Experiment already exists at {new_code_path}. 
Use --name to specify a new experiment.") - return -1 - print(f"Copying codebase to {new_code_path}") - current_code_path = os.path.realpath(__file__) - for _ in range(3): - current_code_path = os.path.dirname(current_code_path) - copytree(current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")) - print("Done copying code.") - return 1 - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/open_lm/manifest.jsonl b/open_lm/manifest.jsonl deleted file mode 100644 index 1e4cc33..0000000 --- a/open_lm/manifest.jsonl +++ /dev/null @@ -1,200 +0,0 @@ -{"shard": "z", "num_sequences": 8192} -{"shard": "shard-0000001", "num_sequences": 8192} -{"shard": "shard-0000002", "num_sequences": 8192} -{"shard": "shard-0000003", "num_sequences": 8192} -{"shard": "shard-0000004", "num_sequences": 8192} -{"shard": "shard-0000005", "num_sequences": 8192} -{"shard": "shard-0000006", "num_sequences": 8192} -{"shard": "shard-0000007", "num_sequences": 8192} -{"shard": "shard-0000008", "num_sequences": 8192} -{"shard": "shard-0000009", "num_sequences": 8192} -{"shard": "shard-0000010", "num_sequences": 8192} -{"shard": "shard-0000011", "num_sequences": 8192} -{"shard": "shard-0000012", "num_sequences": 8192} -{"shard": "shard-0000013", "num_sequences": 8192} -{"shard": "shard-0000014", "num_sequences": 8192} -{"shard": "shard-0000015", "num_sequences": 8192} -{"shard": "shard-0000016", "num_sequences": 8192} -{"shard": "shard-0000017", "num_sequences": 8192} -{"shard": "shard-0000018", "num_sequences": 8192} -{"shard": "shard-0000019", "num_sequences": 8192} -{"shard": "shard-0000020", "num_sequences": 8192} -{"shard": "shard-0000021", "num_sequences": 8192} -{"shard": "shard-0000022", "num_sequences": 8192} -{"shard": "shard-0000023", "num_sequences": 8192} -{"shard": "shard-0000024", "num_sequences": 8192} -{"shard": "shard-0000025", "num_sequences": 8192} -{"shard": "shard-0000026", "num_sequences": 8192} -{"shard": "shard-0000027", "num_sequences": 8192} -{"shard": "shard-0000028", "num_sequences": 8192} -{"shard": "shard-0000029", "num_sequences": 8192} -{"shard": "shard-0000030", "num_sequences": 8192} -{"shard": "shard-0000031", "num_sequences": 8192} -{"shard": "shard-0000032", "num_sequences": 8192} -{"shard": "shard-0000033", "num_sequences": 8192} -{"shard": "shard-0000034", "num_sequences": 8192} -{"shard": "shard-0000035", "num_sequences": 8192} -{"shard": "shard-0000036", "num_sequences": 8192} -{"shard": "shard-0000037", "num_sequences": 8192} -{"shard": "shard-0000038", "num_sequences": 8192} -{"shard": "shard-0000039", "num_sequences": 8192} -{"shard": "shard-0000040", "num_sequences": 8192} -{"shard": "shard-0000041", "num_sequences": 8192} -{"shard": "shard-0000042", "num_sequences": 8192} -{"shard": "shard-0000043", "num_sequences": 8192} -{"shard": "shard-0000044", "num_sequences": 8192} -{"shard": "shard-0000045", "num_sequences": 8192} -{"shard": "shard-0000046", "num_sequences": 8192} -{"shard": "shard-0000047", "num_sequences": 8192} -{"shard": "shard-0000048", "num_sequences": 8192} -{"shard": "shard-0000049", "num_sequences": 8192} -{"shard": "shard-0000050", "num_sequences": 8192} -{"shard": "shard-0000051", "num_sequences": 8192} -{"shard": "shard-0000052", "num_sequences": 8192} -{"shard": "shard-0000053", "num_sequences": 8192} -{"shard": "shard-0000054", "num_sequences": 8192} -{"shard": "shard-0000055", "num_sequences": 8192} -{"shard": "shard-0000056", "num_sequences": 8192} -{"shard": "shard-0000057", "num_sequences": 8192} -{"shard": 
"shard-0000058", "num_sequences": 8192} -{"shard": "shard-0000059", "num_sequences": 8192} -{"shard": "shard-0000060", "num_sequences": 8192} -{"shard": "shard-0000061", "num_sequences": 8192} -{"shard": "shard-0000062", "num_sequences": 8192} -{"shard": "shard-0000063", "num_sequences": 8192} -{"shard": "shard-0000064", "num_sequences": 8192} -{"shard": "shard-0000065", "num_sequences": 8192} -{"shard": "shard-0000066", "num_sequences": 8192} -{"shard": "shard-0000067", "num_sequences": 8192} -{"shard": "shard-0000068", "num_sequences": 8192} -{"shard": "shard-0000069", "num_sequences": 8192} -{"shard": "shard-0000070", "num_sequences": 8192} -{"shard": "shard-0000071", "num_sequences": 8192} -{"shard": "shard-0000072", "num_sequences": 8192} -{"shard": "shard-0000073", "num_sequences": 8192} -{"shard": "shard-0000074", "num_sequences": 8192} -{"shard": "shard-0000075", "num_sequences": 8192} -{"shard": "shard-0000076", "num_sequences": 8192} -{"shard": "shard-0000077", "num_sequences": 8192} -{"shard": "shard-0000078", "num_sequences": 8192} -{"shard": "shard-0000079", "num_sequences": 8192} -{"shard": "shard-0000080", "num_sequences": 8192} -{"shard": "shard-0000081", "num_sequences": 8192} -{"shard": "shard-0000082", "num_sequences": 8192} -{"shard": "shard-0000083", "num_sequences": 8192} -{"shard": "shard-0000084", "num_sequences": 8192} -{"shard": "shard-0000085", "num_sequences": 8192} -{"shard": "shard-0000086", "num_sequences": 8192} -{"shard": "shard-0000087", "num_sequences": 8192} -{"shard": "shard-0000088", "num_sequences": 8192} -{"shard": "shard-0000089", "num_sequences": 8192} -{"shard": "shard-0000090", "num_sequences": 8192} -{"shard": "shard-0000091", "num_sequences": 8192} -{"shard": "shard-0000092", "num_sequences": 8192} -{"shard": "shard-0000093", "num_sequences": 8192} -{"shard": "shard-0000094", "num_sequences": 8192} -{"shard": "shard-0000095", "num_sequences": 8192} -{"shard": "shard-0000096", "num_sequences": 8192} -{"shard": "shard-0000097", "num_sequences": 8192} -{"shard": "shard-0000098", "num_sequences": 8192} -{"shard": "shard-0000099", "num_sequences": 8192} -{"shard": "shard-0000100", "num_sequences": 8192} -{"shard": "shard-0000101", "num_sequences": 8192} -{"shard": "shard-0000102", "num_sequences": 8192} -{"shard": "shard-0000103", "num_sequences": 8192} -{"shard": "shard-0000104", "num_sequences": 8192} -{"shard": "shard-0000105", "num_sequences": 8192} -{"shard": "shard-0000106", "num_sequences": 8192} -{"shard": "shard-0000107", "num_sequences": 8192} -{"shard": "shard-0000108", "num_sequences": 8192} -{"shard": "shard-0000109", "num_sequences": 8192} -{"shard": "shard-0000110", "num_sequences": 8192} -{"shard": "shard-0000111", "num_sequences": 8192} -{"shard": "shard-0000112", "num_sequences": 8192} -{"shard": "shard-0000113", "num_sequences": 8192} -{"shard": "shard-0000114", "num_sequences": 8192} -{"shard": "shard-0000115", "num_sequences": 8192} -{"shard": "shard-0000116", "num_sequences": 8192} -{"shard": "shard-0000117", "num_sequences": 8192} -{"shard": "shard-0000118", "num_sequences": 8192} -{"shard": "shard-0000119", "num_sequences": 8192} -{"shard": "shard-0000120", "num_sequences": 8192} -{"shard": "shard-0000121", "num_sequences": 8192} -{"shard": "shard-0000122", "num_sequences": 8192} -{"shard": "shard-0000123", "num_sequences": 8192} -{"shard": "shard-0000124", "num_sequences": 8192} -{"shard": "shard-0000125", "num_sequences": 8192} -{"shard": "shard-0000126", "num_sequences": 8192} -{"shard": "shard-0000127", "num_sequences": 
8192} -{"shard": "shard-0000128", "num_sequences": 8192} -{"shard": "shard-0000129", "num_sequences": 8192} -{"shard": "shard-0000130", "num_sequences": 8192} -{"shard": "shard-0000131", "num_sequences": 8192} -{"shard": "shard-0000132", "num_sequences": 8192} -{"shard": "shard-0000133", "num_sequences": 8192} -{"shard": "shard-0000134", "num_sequences": 8192} -{"shard": "shard-0000135", "num_sequences": 8192} -{"shard": "shard-0000136", "num_sequences": 8192} -{"shard": "shard-0000137", "num_sequences": 8192} -{"shard": "shard-0000138", "num_sequences": 8192} -{"shard": "shard-0000139", "num_sequences": 8192} -{"shard": "shard-0000140", "num_sequences": 8192} -{"shard": "shard-0000141", "num_sequences": 8192} -{"shard": "shard-0000142", "num_sequences": 8192} -{"shard": "shard-0000143", "num_sequences": 8192} -{"shard": "shard-0000144", "num_sequences": 8192} -{"shard": "shard-0000145", "num_sequences": 8192} -{"shard": "shard-0000146", "num_sequences": 8192} -{"shard": "shard-0000147", "num_sequences": 8192} -{"shard": "shard-0000148", "num_sequences": 8192} -{"shard": "shard-0000149", "num_sequences": 8192} -{"shard": "shard-0000150", "num_sequences": 8192} -{"shard": "shard-0000151", "num_sequences": 8192} -{"shard": "shard-0000152", "num_sequences": 8192} -{"shard": "shard-0000153", "num_sequences": 8192} -{"shard": "shard-0000154", "num_sequences": 8192} -{"shard": "shard-0000155", "num_sequences": 8192} -{"shard": "shard-0000156", "num_sequences": 8192} -{"shard": "shard-0000157", "num_sequences": 8192} -{"shard": "shard-0000158", "num_sequences": 8192} -{"shard": "shard-0000159", "num_sequences": 8192} -{"shard": "shard-0000160", "num_sequences": 8192} -{"shard": "shard-0000161", "num_sequences": 8192} -{"shard": "shard-0000162", "num_sequences": 8192} -{"shard": "shard-0000163", "num_sequences": 8192} -{"shard": "shard-0000164", "num_sequences": 8192} -{"shard": "shard-0000165", "num_sequences": 8192} -{"shard": "shard-0000166", "num_sequences": 8192} -{"shard": "shard-0000167", "num_sequences": 8192} -{"shard": "shard-0000168", "num_sequences": 8192} -{"shard": "shard-0000169", "num_sequences": 8192} -{"shard": "shard-0000170", "num_sequences": 8192} -{"shard": "shard-0000171", "num_sequences": 8192} -{"shard": "shard-0000172", "num_sequences": 8192} -{"shard": "shard-0000173", "num_sequences": 8192} -{"shard": "shard-0000174", "num_sequences": 8192} -{"shard": "shard-0000175", "num_sequences": 8192} -{"shard": "shard-0000176", "num_sequences": 8192} -{"shard": "shard-0000177", "num_sequences": 8192} -{"shard": "shard-0000178", "num_sequences": 8192} -{"shard": "shard-0000179", "num_sequences": 8192} -{"shard": "shard-0000180", "num_sequences": 8192} -{"shard": "shard-0000181", "num_sequences": 8192} -{"shard": "shard-0000182", "num_sequences": 8192} -{"shard": "shard-0000183", "num_sequences": 8192} -{"shard": "shard-0000184", "num_sequences": 8192} -{"shard": "shard-0000185", "num_sequences": 8192} -{"shard": "shard-0000186", "num_sequences": 8192} -{"shard": "shard-0000187", "num_sequences": 8192} -{"shard": "shard-0000188", "num_sequences": 8192} -{"shard": "shard-0000189", "num_sequences": 8192} -{"shard": "shard-0000190", "num_sequences": 8192} -{"shard": "shard-0000191", "num_sequences": 8192} -{"shard": "shard-0000192", "num_sequences": 8192} -{"shard": "shard-0000193", "num_sequences": 8192} -{"shard": "shard-0000194", "num_sequences": 8192} -{"shard": "shard-0000195", "num_sequences": 8192} -{"shard": "shard-0000196", "num_sequences": 8192} -{"shard": "shard-0000197", 
"num_sequences": 8192} -{"shard": "shard-0000198", "num_sequences": 8192} -{"shard": "shard-0000199", "num_sequences": 8192} \ No newline at end of file diff --git a/open_lm/model.py b/open_lm/model.py index 3c00cc4..ba2dd1b 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -509,22 +509,23 @@ def create_model(args): def create_classif_model(args): model = Transformer(create_params(args)) - checkpoint = pt_load(args.classif_model_path, map_location="cpu") - model.load_state_dict(checkpoint["state_dict"]) + if args.classif_model_path is not None: + checkpoint = pt_load(args.classif_model_path, map_location="cpu") + model.load_state_dict(checkpoint["state_dict"]) dim = model.output.in_features model.output = nn.Linear(dim, args.num_classes, bias = False) - + return model -def test_classif_model(args, model_path): +def test_classif_model(args): model = Transformer(create_params(args)) dim = model.output.in_features model.output = nn.Linear(dim, args.num_classes, bias = False) - checkpoint = pt_load(model_path, map_location="cpu") + checkpoint = pt_load(args.classif_model_path, map_location="cpu") model.load_state_dict(checkpoint["state_dict"]) return model diff --git a/open_lm/params.py b/open_lm/params.py index f74fa89..1b5f79f 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -804,7 +804,8 @@ def parse_args(args): type=str, default=None, help="Path of the pretrained model to be finetuned for classification.", - ) + ) + add_model_args(parser) config = maybe_load_config(parser, args) diff --git a/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc index 3a8f889..926d8c8 100644 Binary files a/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc differ diff --git a/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc index 9904228..9a508c1 100644 Binary files a/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc differ diff --git a/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc index 47159ae..8d8be0b 100644 Binary files a/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc differ diff --git a/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc index 296bc4e..5dbdb82 100644 Binary files a/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc differ diff --git a/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc index d4bd893..6c775c6 100644 Binary files a/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc differ diff --git a/open_lm/run_bench.sh b/open_lm/run_bench.sh deleted file mode 100644 index 5676646..0000000 --- a/open_lm/run_bench.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -BATCHSIZE=1 -MODEL="large2048" -EXP_NAME="benchmark-$MODEL" - -torchrun --nproc-per-node 1 -m benchmark.main \ - --train-data "pipe:aws s3 cp 
diff --git a/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc
index 3a8f889..926d8c8 100644
Binary files a/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/__init__.cpython-310.pyc differ
diff --git a/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc
index 9904228..9a508c1 100644
Binary files a/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/head_rotary.cpython-310.pyc differ
diff --git a/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc
index 47159ae..8d8be0b 100644
Binary files a/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/llama_rotary.cpython-310.pyc differ
diff --git a/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc
index 296bc4e..5dbdb82 100644
Binary files a/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/none.cpython-310.pyc differ
diff --git a/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc b/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc
index d4bd893..6c775c6 100644
Binary files a/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc and b/open_lm/positional_embedding/__pycache__/rotary.cpython-310.pyc differ
diff --git a/open_lm/run_bench.sh b/open_lm/run_bench.sh
deleted file mode 100644
index 5676646..0000000
--- a/open_lm/run_bench.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-BATCHSIZE=1
-MODEL="large2048"
-EXP_NAME="benchmark-$MODEL"
-
-torchrun --nproc-per-node 1 -m benchmark.main \
-    --train-data "pipe:aws s3 cp s3://s-laion/redpajama-tars/8192-v1/{0..7}/shard-{0000000..0000300}.tar -" \
-    --train-num-samples 30720 \
-    --workers 6 \
-    --precision amp_bfloat16 \
-    --grad-checkpointing \
-    --grad-clip-norm 1 \
-    --log-every-n-steps 1 \
-    --fsdp \
-    --profile \
-    --batch-size $BATCHSIZE \
-    --model $MODEL \
-    --name $EXP_NAME \
diff --git a/open_lm/test_class.py b/open_lm/test_class.py
new file mode 100644
index 0000000..d0e9a0b
--- /dev/null
+++ b/open_lm/test_class.py
@@ -0,0 +1,76 @@
+import os
+import shutil
+import random
+import json
+import torch
+import numpy as np
+import subprocess
+
+from open_lm.params import parse_args
+from open_lm.model import test_classif_model
+
+device = "cuda:3"
+
+def inference():
+
+    args = parse_args([])
+    args.model = "open_lm_160m"
+    args.classif_model_path = "/workspace/youssef/lrz/logs/rewritten/classif160M_3.2BC4_C4_FW_320M_prompt3/checkpoints/epoch_1.pt"
+    args.num_classes = 2
+
+    model = test_classif_model(args)
+    model = model.to(device)
+
+
+    test_data_path1 = '/workspace/youssef/lrz/datasets/test/rewritten/C4_test_prompt3.pt'
+    test_data_path2 = '/workspace/youssef/lrz/datasets/test/rewritten/FW_test_prompt3.pt'
+
+#####################################################################################################################
+    dataset = torch.load(test_data_path1)
+    sum = 0
+    for sample in dataset:
+        sample = torch.LongTensor(sample).to(device)
+
+        with torch.no_grad():
+            out, _, _ = model(sample)
+
+        pred = torch.argmax(out,2)[:,-1]
+
+        n_correct = torch.sum(pred == 0).item()
+
+        sum = sum + n_correct
+
+    sum1 = sum
+    len1 = len(dataset)
+    print('C4', sum1, "/" , len1)
+
+    dataset = torch.load(test_data_path2)
+    sum = 0
+    for sample in dataset:
+        sample = torch.LongTensor(sample).to(device)
+
+        with torch.no_grad():
+            out, _, _ = model(sample)
+
+        pred = torch.argmax(out,2)[:,-1]
+
+        n_correct = torch.sum(pred == 1).item()
+
+        sum = sum + n_correct
+
+    sum2 = sum
+    len2 = len(dataset)
+    print('FW', sum2, "/" , len2)
+###############################################################################################################################
+
+    total_sum = sum1+sum2
+    total_length = len1+len2
+
+    print("Total= ", total_sum, "/" , total_length )
+    print("Accuracy= ", total_sum/total_length * 100, "%")
+
+
+if __name__ == "__main__":
+    print("starting script")
+    inference()
+    print("ending script")
\ No newline at end of file
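Both halves of the script above compute the same quantity: a sample counts as correct when the argmax over the output logits at the last token position equals the dataset's class id (0 for the C4 split, 1 for the FW split), and the per-dataset counts are then combined into an overall accuracy. A compact equivalent helper, shown purely as an illustration and assuming the same model call convention as the script (logits plus two additional, unused return values):

import torch

def dataset_accuracy(model, dataset, target_class, device):
    # Count samples whose last-position argmax prediction equals target_class.
    correct = 0
    for sample in dataset:
        sample = torch.LongTensor(sample).to(device)
        with torch.no_grad():
            out, _, _ = model(sample)
        pred = torch.argmax(out, 2)[:, -1]
        correct += torch.sum(pred == target_class).item()
    return correct, len(dataset)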
"100", + "--grad-clip-norm", "1", + "--global-batch-size", "16", + "--data-key", "txt", + "--lr", "3e-4", + "--warmup", "2000", + "--wd", "0.1", + "--beta2", "0.95", + "--epochs", "1", + "--resume", "latest", + "--logs", "/workspace/youssef/lrz/logs/rewritten/", + "--name", "classif160M_3.2BC4_C4_FW_320M_prompt3", + "--classification", "True", + "--num-classes", "2", + "--classif-model-path", "/workspace/youssef/lrz/logs/pretrain/160M_3.2BC4/checkpoint/epoch_3.pt" + ] + + os.makedirs(log_dir, exist_ok=True) + + # Create log file paths + stdout_log = os.path.join(log_dir, "output.log") + stderr_log = os.path.join(log_dir, "error.log") + + # Run the torchrun command using subprocess + with open(stdout_log, "w") as out_file, open(stderr_log, "w") as err_file: + try: + result = subprocess.run(command, check=True, stdout=out_file, stderr=err_file) + print(f"torchrun finished with return code: {result.returncode}") + except subprocess.CalledProcessError as e: + print(f"An error occurred while running torchrun: {e}") + + +if __name__ == "__main__": + print("starting script") + train_classifier() + print("ending script") diff --git a/open_lm/utils/__pycache__/__init__.cpython-310.pyc b/open_lm/utils/__pycache__/__init__.cpython-310.pyc index abb15c2..87b50fa 100644 Binary files a/open_lm/utils/__pycache__/__init__.cpython-310.pyc and b/open_lm/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/open_lm/utils/__pycache__/averaging_utils.cpython-310.pyc b/open_lm/utils/__pycache__/averaging_utils.cpython-310.pyc index 5a3427a..8d4b678 100644 Binary files a/open_lm/utils/__pycache__/averaging_utils.cpython-310.pyc and b/open_lm/utils/__pycache__/averaging_utils.cpython-310.pyc differ diff --git a/open_lm/utils/__pycache__/make_wds_manifest.cpython-310.pyc b/open_lm/utils/__pycache__/make_wds_manifest.cpython-310.pyc index 32b1b1b..a3716f4 100644 Binary files a/open_lm/utils/__pycache__/make_wds_manifest.cpython-310.pyc and b/open_lm/utils/__pycache__/make_wds_manifest.cpython-310.pyc differ diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py index 166d42a..f4f14e7 100644 --- a/open_lm/utils/llm_foundry_wrapper.py +++ b/open_lm/utils/llm_foundry_wrapper.py @@ -4,14 +4,14 @@ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`.""" from typing import Mapping, Union - -from composer.metrics.nlp import ( +from llmfoundry.eval.metrics.nlp import ( InContextLearningLMAccuracy, InContextLearningLMExpectedCalibrationError, InContextLearningMCExpectedCalibrationError, InContextLearningMultipleChoiceAccuracy, - InContextLearningQAAccuracy, - InContextLearningCodeEvalAccuracy, + InContextLearningGenerationExactMatchAccuracy, +) +from composer.metrics.nlp import ( LanguageCrossEntropy, LanguagePerplexity, ) @@ -33,10 +33,9 @@ LanguagePerplexity(), InContextLearningLMAccuracy(), InContextLearningMultipleChoiceAccuracy(), - InContextLearningQAAccuracy(), + InContextLearningGenerationExactMatchAccuracy(), InContextLearningLMExpectedCalibrationError(), InContextLearningMCExpectedCalibrationError(), - InContextLearningCodeEvalAccuracy(), ] diff --git a/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc b/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc index e4bee55..5285aac 100644 Binary files a/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc and b/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc differ diff --git a/open_lm/utils/transformers/__pycache__/hf_config.cpython-310.pyc 
diff --git a/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc b/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc
index e4bee55..5285aac 100644
Binary files a/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc and b/open_lm/utils/transformers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/open_lm/utils/transformers/__pycache__/hf_config.cpython-310.pyc b/open_lm/utils/transformers/__pycache__/hf_config.cpython-310.pyc
new file mode 100644
index 0000000..a6a0f30
Binary files /dev/null and b/open_lm/utils/transformers/__pycache__/hf_config.cpython-310.pyc differ
diff --git a/open_lm/utils/transformers/__pycache__/hf_model.cpython-310.pyc b/open_lm/utils/transformers/__pycache__/hf_model.cpython-310.pyc
new file mode 100644
index 0000000..fd1a289
Binary files /dev/null and b/open_lm/utils/transformers/__pycache__/hf_model.cpython-310.pyc differ
diff --git a/open_lm/utils/transformers/__pycache__/hf_wrapper.cpython-310.pyc b/open_lm/utils/transformers/__pycache__/hf_wrapper.cpython-310.pyc
index 95532e2..ef49297 100644
Binary files a/open_lm/utils/transformers/__pycache__/hf_wrapper.cpython-310.pyc and b/open_lm/utils/transformers/__pycache__/hf_wrapper.cpython-310.pyc differ
diff --git a/requirements.txt b/requirements.txt
index 387b1ca..898d184 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch==2.2.2
+torch
 xformers>=0.0.22
 tiktoken
 wandb
diff --git a/requirements_test.txt b/requirements_test.txt
index 61f15ce..8413123 100644
--- a/requirements_test.txt
+++ b/requirements_test.txt
@@ -3,4 +3,4 @@ pytest-cov==3.0.0
 pytest-xdist==2.5.0
 pytest==7.0.1
 tensorboard==2.14.1
-llm-foundry>=0.4.0
+llm-foundry==0.9.0
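Leaving torch unpinned while pinning llm-foundry to 0.9.0 keeps the test requirements consistent with the llmfoundry.eval.metrics.nlp imports in open_lm/utils/llm_foundry_wrapper.py. A quick, illustrative sanity check of an installed environment (not part of the diff):

import importlib.metadata

# Resolves only on llm-foundry releases that ship llmfoundry.eval.metrics.nlp;
# requirements_test.txt pins 0.9.0.
from llmfoundry.eval.metrics.nlp import InContextLearningGenerationExactMatchAccuracy

print("llm-foundry", importlib.metadata.version("llm-foundry"))
print("metric import OK:", InContextLearningGenerationExactMatchAccuracy.__name__)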