diff --git a/ldm/modules/callbacks/captions.py b/ldm/modules/callbacks/captions.py new file mode 100644 index 00000000..fd32bb68 --- /dev/null +++ b/ldm/modules/callbacks/captions.py @@ -0,0 +1,16 @@ +from captionizer import caption_from_path +from pytorch_lightning.callbacks import Callback + +class CaptionSaverCallback(Callback): + def __init__(self): + super().__init__() + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + print('Adding training captions to checkpoint [Dataloader]') + data = trainer.train_dataloader.loaders.sampler.data_source # type: ignore + prompts = set([ + caption_from_path(image_path, data.data_root, data.coarse_class_text, data.placeholder_token) + for image_path in data.image_paths + ]) + trained_prompts = (list(prompts)) + checkpoint['trained_captions'] = trained_prompts diff --git a/ldm/pruner.py b/ldm/pruner.py index 6b3f088b..44697bbd 100644 --- a/ldm/pruner.py +++ b/ldm/pruner.py @@ -1,7 +1,6 @@ def prune_checkpoint(old_state): print(f"Pruning Checkpoint") pruned_checkpoint = dict() - print(f"Checkpoint Keys: {old_state.keys()}") for key in old_state.keys(): if key != "optimizer_states": pruned_checkpoint[key] = old_state[key] diff --git a/main.py b/main.py index 0af91e17..abe8af68 100644 --- a/main.py +++ b/main.py @@ -771,6 +771,9 @@ def on_train_epoch_start(self, trainer, pl_module): "cuda_callback": { "target": "main.CUDACallback" }, + "captions_callback": { + "target": "ldm.modules.callbacks.captions.CaptionSaverCallback" + } } if version.parse(pl.__version__) >= version.parse('1.4.0'): default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) diff --git a/scripts/show_captions.py b/scripts/show_captions.py new file mode 100755 index 00000000..ef7471b0 --- /dev/null +++ b/scripts/show_captions.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +import torch +import sys, os + +def main(checkpoint_path): + try: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if 'trained_captions' in checkpoint: + 
captions = checkpoint['trained_captions'] + global_step = checkpoint['global_step'] + print(f'Captions in {os.path.basename(checkpoint_path)} [{global_step} Global Steps]:') + for caption in captions: + print(f'\t"{caption}"') + else: + print(f'{checkpoint_path} has no captions saved') + except Exception: + print(f'Failed to extract captions from {checkpoint_path}') + +if __name__ == '__main__': + if len(sys.argv) == 1: + sys.exit(f'Usage: {sys.argv[0]} <checkpoint_path>') + main(sys.argv[1]) +