Fix init device in conversion script and add tests (#556)
dakinggg authored Aug 24, 2023
1 parent a5b39b6 commit 2f30418
Showing 2 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions scripts/inference/convert_composer_to_hf.py
@@ -189,6 +189,7 @@ def convert_composer_to_hf(args: Namespace) -> None:
    print(f'Loading model from {local_folder_path}')
    if config.model_type == 'mpt':
        config.attn_config['attn_impl'] = 'torch'
        config.init_device = 'cpu'

    if config.model_type == 'mpt':
        loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path,
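For context, a minimal sketch (not part of the commit) of why the added line matters: a checkpoint exported from a run configured with `init_device: meta` keeps that setting in its saved Hugging Face config, so a later `from_pretrained` call would try to materialize parameters on the meta device and load empty weights. Overriding `init_device` (alongside `attn_impl`) on the loaded config before calling `from_pretrained` keeps loading on CPU. The folder path below is a placeholder.

    from transformers import AutoConfig, AutoModelForCausalLM

    # Placeholder path to an exported MPT checkpoint folder.
    local_folder_path = './hf-output-folder'

    config = AutoConfig.from_pretrained(local_folder_path, trust_remote_code=True)
    if config.model_type == 'mpt':
        config.attn_config['attn_impl'] = 'torch'  # CPU-friendly attention for loading
        config.init_device = 'cpu'  # avoid materializing weights on the meta device

    model = AutoModelForCausalLM.from_pretrained(local_folder_path,
                                                 config=config,
                                                 trust_remote_code=True)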
63 changes: 63 additions & 0 deletions tests/test_hf_conversion_script.py
@@ -157,3 +157,66 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
    assert output.shape == (1, 2)

    delete_transformers_cache()


def test_convert_and_generate_meta(tmp_path: pathlib.Path):
    delete_transformers_cache()

    from composer.utils import dist
    gathered_paths = dist.all_gather_object(tmp_path)
    tmp_path_gathered = gathered_paths[0]

    om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

    om_cfg['model']['init_device'] = 'cpu'
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        om_cfg.tokenizer.name)
    original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
        om_cfg['model'], tokenizer)
    trainer = Trainer(model=original_model, device='cpu')
    trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt'))

    # patch in the meta device for testing
    sd = torch.load(os.path.join(tmp_path_gathered, 'checkpoint.pt'),
                    map_location='cpu')
    sd['state']['integrations']['huggingface']['model']['config']['content'][
        'init_device'] = 'meta'
    torch.save(sd, os.path.join(tmp_path_gathered, 'checkpoint.pt'))

    args = Namespace(composer_path=os.path.join(tmp_path_gathered,
                                                'checkpoint.pt'),
                     hf_output_path=os.path.join(tmp_path_gathered,
                                                 'hf-output-folder'),
                     output_precision='fp32',
                     local_checkpoint_save_location=None,
                     hf_repo_for_upload=None,
                     test_uploaded_model=False)
    convert_composer_to_hf(args)

    loaded_config = transformers.AutoConfig.from_pretrained(
        os.path.join(tmp_path_gathered, 'hf-output-folder'),
        trust_remote_code=True)
    loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
        os.path.join(tmp_path_gathered, 'hf-output-folder'),
        config=loaded_config,
        trust_remote_code=True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        os.path.join(tmp_path_gathered, 'hf-output-folder'),
        trust_remote_code=True)

    output = loaded_model.generate(tokenizer('hello',
                                             return_tensors='pt')['input_ids'],
                                   max_new_tokens=1)
    assert output.shape == (1, 2)

    assert sum(p.numel() for p in original_model.model.parameters()) == sum(
        p.numel() for p in loaded_model.parameters())
    assert all(
        str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1]
        for module1, module2 in zip(original_model.model.modules(),
                                    loaded_model.modules()))
    for p1, p2 in zip(original_model.model.parameters(),
                      loaded_model.parameters()):
        assert torch.allclose(p1, p2)

    delete_transformers_cache()
