diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py
index b47bd7f309..5c4d4117c5 100644
--- a/scripts/inference/convert_composer_to_hf.py
+++ b/scripts/inference/convert_composer_to_hf.py
@@ -189,6 +189,7 @@ def convert_composer_to_hf(args: Namespace) -> None:
     print(f'Loading model from {local_folder_path}')
     if config.model_type == 'mpt':
         config.attn_config['attn_impl'] = 'torch'
+        config.init_device = 'cpu'
 
     if config.model_type == 'mpt':
         loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path,
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 1561f965c9..e16832d803 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -157,3 +157,66 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
     assert output.shape == (1, 2)
 
     delete_transformers_cache()
+
+
+def test_convert_and_generate_meta(tmp_path: pathlib.Path):
+    delete_transformers_cache()
+
+    from composer.utils import dist
+    gathered_paths = dist.all_gather_object(tmp_path)
+    tmp_path_gathered = gathered_paths[0]
+
+    om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
+
+    om_cfg['model']['init_device'] = 'cpu'
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        om_cfg.tokenizer.name)
+    original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
+        om_cfg['model'], tokenizer)
+    trainer = Trainer(model=original_model, device='cpu')
+    trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt'))
+
+    # patch in the meta device for testing
+    sd = torch.load(os.path.join(tmp_path_gathered, 'checkpoint.pt'),
+                    map_location='cpu')
+    sd['state']['integrations']['huggingface']['model']['config']['content'][
+        'init_device'] = 'meta'
+    torch.save(sd, os.path.join(tmp_path_gathered, 'checkpoint.pt'))
+
+    args = Namespace(composer_path=os.path.join(tmp_path_gathered,
+                                                'checkpoint.pt'),
+                     hf_output_path=os.path.join(tmp_path_gathered,
+                                                 'hf-output-folder'),
+                     output_precision='fp32',
+                     local_checkpoint_save_location=None,
+                     hf_repo_for_upload=None,
+                     test_uploaded_model=False)
+    convert_composer_to_hf(args)
+
+    loaded_config = transformers.AutoConfig.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        trust_remote_code=True)
+    loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        config=loaded_config,
+        trust_remote_code=True)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        trust_remote_code=True)
+
+    output = loaded_model.generate(tokenizer('hello',
+                                             return_tensors='pt')['input_ids'],
+                                   max_new_tokens=1)
+    assert output.shape == (1, 2)
+
+    assert sum(p.numel() for p in original_model.model.parameters()) == sum(
+        p.numel() for p in loaded_model.parameters())
+    assert all(
+        str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1]
+        for module1, module2 in zip(original_model.model.modules(),
+                                    loaded_model.modules()))
+    for p1, p2 in zip(original_model.model.parameters(),
+                      loaded_model.parameters()):
+        assert torch.allclose(p1, p2)
+
+    delete_transformers_cache()
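
Note on the fix (explanatory sketch, not part of the patch): when a Composer
checkpoint comes from a run configured with init_device: meta, the Hugging Face
config stored inside the checkpoint still carries 'meta', so loading it back
through MPTForCausalLM.from_pretrained would presumably build parameters on the
meta device, which have no real storage to receive the saved weights. Forcing
config.init_device = 'cpu' before from_pretrained avoids that, and the new test
exercises exactly this path by patching 'meta' into a saved checkpoint. A
minimal, standalone illustration of the meta-device behavior (background only,
not code from the repo):

    import torch
    import torch.nn as nn

    # Parameters created on the meta device have shape and dtype but no
    # storage, so checkpoint weights cannot be copied into them.
    with torch.device('meta'):
        meta_layer = nn.Linear(4, 4)
    print(meta_layer.weight.is_meta)  # True: no data behind this tensor

    # Constructing on CPU (what config.init_device = 'cpu' forces in the
    # conversion script) yields ordinary tensors that can hold loaded weights.
    cpu_layer = nn.Linear(4, 4)
    print(cpu_layer.weight.is_meta)  # False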