diff --git a/test_runner.py b/test_runner.py index e6b03f51..ca9d1320 100755 --- a/test_runner.py +++ b/test_runner.py @@ -17,11 +17,6 @@ import tomli as tomllib -parser = argparse.ArgumentParser() -parser.add_argument("output_dir") -args = parser.parse_args() - - @dataclass class OverrideDefinitions: """ @@ -32,77 +27,77 @@ class OverrideDefinitions: test_descr: str = "default" -CONFIG_DIR = "./train_configs" - -""" -key is the config file name and value is a list of OverrideDefinitions -that is used to generate variations of integration tests based on the -same root config file. -""" -integration_tests_flavors = defaultdict(list) -integration_tests_flavors["debug_model.toml"] = [ - OverrideDefinitions( - [ - [ - f"--job.dump_folder {args.output_dir}/default/", - ], - ], - "Default", - ), - OverrideDefinitions( - [ +def build_test_list(args): + """ + key is the config file name and value is a list of OverrideDefinitions + that is used to generate variations of integration tests based on the + same root config file. + """ + integration_tests_flavors = defaultdict(list) + integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( [ - "--training.compile", - f"--job.dump_folder {args.output_dir}/1d_compile/", + [ + f"--job.dump_folder {args.output_dir}/default/", + ], ], - ], - "1D compile", - ), - OverrideDefinitions( - [ + "Default", + ), + OverrideDefinitions( [ - "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/eager_2d/", + [ + "--training.compile", + f"--job.dump_folder {args.output_dir}/1d_compile/", + ], ], - ], - "Eager mode 2DParallel", - ), - OverrideDefinitions( - [ + "1D compile", + ), + OverrideDefinitions( [ - "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", + [ + "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", + f"--job.dump_folder {args.output_dir}/eager_2d/", + ], ], + "Eager mode 2DParallel", + ), + OverrideDefinitions( [ - "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", - "--training.steps 20", + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/full_checkpoint/", + ], + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/full_checkpoint/", + "--training.steps 20", + ], ], - ], - "Checkpoint Integration Test - Save Load Full Checkpoint", - ), - OverrideDefinitions( - [ + "Checkpoint Integration Test - Save Load Full Checkpoint", + ), + OverrideDefinitions( [ - "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/", - "--checkpoint.model_weights_only", + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/", + "--checkpoint.model_weights_only", + ], ], - ], - "Checkpoint Integration Test - Save Model Weights Only fp32", - ), - OverrideDefinitions( - [ + "Checkpoint Integration Test - Save Model Weights Only fp32", + ), + OverrideDefinitions( [ - "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/", - "--checkpoint.model_weights_only", - "--checkpoint.export_dtype bfloat16", + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/", + "--checkpoint.model_weights_only", + "--checkpoint.export_dtype bfloat16", + ], ], - ], - "Checkpoint Integration Test - Save Model Weights Only bf16", - ), -] + "Checkpoint Integration Test - Save Model Weights Only bf16", + ), + ] + return integration_tests_flavors def run_test(test_flavor: OverrideDefinitions, full_path: str): @@ -128,12 +123,33 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str): ) -for config_file in os.listdir(CONFIG_DIR): - if config_file.endswith(".toml"): - full_path = os.path.join(CONFIG_DIR, config_file) - with open(full_path, "rb") as f: - config = tomllib.load(f) - is_integration_test = config["job"].get("use_for_integration_test", False) - if is_integration_test: - for test_flavor in integration_tests_flavors[config_file]: - run_test(test_flavor, full_path) +def run_tests(args): + integration_tests_flavors = build_test_list(args) + for config_file in os.listdir(args.config_dir): + if config_file.endswith(".toml"): + full_path = os.path.join(args.config_dir, config_file) + with open(full_path, "rb") as f: + config = tomllib.load(f) + is_integration_test = config["job"].get( + "use_for_integration_test", False + ) + if is_integration_test: + for test_flavor in integration_tests_flavors[config_file]: + run_test(test_flavor, full_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument("--config_dir", default="./train_configs") + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + run_tests(args) + + +if __name__ == "__main__": + main()