From fec1426f9694f01b910228d89273408a0e4cc376 Mon Sep 17 00:00:00 2001
From: Seonmi Jung <149867370+Jesseonmi@users.noreply.github.com>
Date: Sat, 20 Jul 2024 00:07:05 +0900
Subject: [PATCH] Update train.py for more efficiency

---
 train.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 951bda9914..2ddcb629ad 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,5 @@
+# I am commenting on this code just to point out the parts I think aren't efficient enough and the parts you can fix
+
 """
 This training script can be run both on a single gpu in debug mode,
 and also in a larger training run with distributed data parallel (ddp).
@@ -73,7 +75,16 @@
 dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
-config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+
+# I don't know if my knowledge is correct, but I learned that the Python std startswith() is more than O(n); you could improve it by creating your own function ;)
+
+def startswith(element, character):
+    # compare only the first character; slicing with [:1] rather
+    # than indexing with [0] avoids an IndexError and simply
+    # returns False when element is an empty string
+    return element[:1] == character
+
+config_keys = [k for k,v in globals().items() if not startswith(k, '_') and isinstance(v, (int, float, bool, str))]
 exec(open('configurator.py').read()) # overrides from command line or config file
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
@@ -143,6 +154,8 @@ def get_batch(split):
     meta_vocab_size = meta['vocab_size']
     print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
 
+# you could definitely simplify lines 151 to 200; you've got some unnecessary steps there
+
 # model init
 model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                   bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
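
Reviewer note on the startswith() claim: in CPython, str.startswith(prefix)
scans at most len(prefix) characters, so for a one-character prefix it is
effectively constant time, and a pure-Python replacement adds function-call
overhead rather than removing work. Below is a minimal timeit sketch to
check this on your own machine; the key list and the custom_startswith
helper are made up for illustration (the helper mirrors the patch), and
absolute numbers will vary by machine and Python version.

import timeit

# illustrative stand-in for the globals() keys train.py filters
keys = ['_private', 'n_layer', 'dropout', 'learning_rate'] * 1000

def custom_startswith(element, character):
    # mirrors the patch's replacement: first-character comparison only
    return element[:1] == character

builtin = timeit.timeit(
    lambda: [k for k in keys if not k.startswith('_')], number=100)
custom = timeit.timeit(
    lambda: [k for k in keys if not custom_startswith(k, '_')], number=100)
print(f"built-in: {builtin:.4f}s  custom: {custom:.4f}s")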
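
Reviewer note on the config_keys/exec lines this hunk touches: the pattern
snapshots the names of all simple, non-underscore-prefixed globals, lets a
config file (which is plain Python) reassign them, then collects the final
values for logging. A minimal standalone sketch of the idea follows; the
variable names are made up, and the file read is replaced by an inline
string so it runs as-is.

batch_size = 12       # a plain global acting as a default
learning_rate = 6e-4  # another default

# record the names of all simple, non-underscore-prefixed globals
config_keys = [k for k, v in globals().items()
               if not k.startswith('_') and isinstance(v, (int, float, bool, str))]

# stand-in for exec(open('configurator.py').read()): the config file is
# just Python code that reassigns some of those globals
exec("batch_size = 32")

# collect the possibly-overridden values, e.g. for logging
config = {k: globals()[k] for k in config_keys}
print(config)  # {'batch_size': 32, 'learning_rate': 0.0006}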