Skip to content

Commit

Permalink
Small bug fixes to bulk image generation (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Aug 26, 2024
1 parent 45be8fb commit ee934a4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion diffusion/evaluation/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self,
get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
# Load the model
state_dict = torch.load(self.local_checkpoint_path)
state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
Expand Down
2 changes: 1 addition & 1 deletion diffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def generate(config: DictConfig) -> None:
config (DictConfig): Configuration composed by Hydra
"""
reproducibility.seed_all(config.seed)
device = get_device() # type: ignore
device = get_device(None) # type: ignore
dist.initialize_dist(device, config.dist_timeout)

# The model to evaluate
Expand Down

0 comments on commit ee934a4

Please sign in to comment.