diff --git a/ldm/modules/callbacks/captions.py b/ldm/modules/callbacks/captions.py
deleted file mode 100644
index fd32bb68..00000000
--- a/ldm/modules/callbacks/captions.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from captionizer import caption_from_path
-from pytorch_lightning.callbacks import Callback
-
-class CaptionSaverCallback(Callback):
-    def __init__(self):
-        super().__init__()
-
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        print('Adding training captions to checkpoint [Dataloader]')
-        data = trainer.train_dataloader.loaders.sampler.data_source # type: ignore
-        prompts = set([
-            caption_from_path(image_path, data.data_root, data.coarse_class_text, data.placeholder_token)
-            for image_path in data.image_paths
-        ])
-        trained_prompts = (list(prompts))
-        checkpoint['trained_captions'] = trained_prompts
diff --git a/ldm/pruner.py b/ldm/pruner.py
index 44697bbd..6b3f088b 100644
--- a/ldm/pruner.py
+++ b/ldm/pruner.py
@@ -1,6 +1,7 @@
 def prune_checkpoint(old_state):
     print(f"Pruning Checkpoint")
     pruned_checkpoint = dict()
+    print(f"Checkpoint Keys: {old_state.keys()}")
     for key in old_state.keys():
         if key != "optimizer_states":
             pruned_checkpoint[key] = old_state[key]
diff --git a/main.py b/main.py
index abe8af68..0af91e17 100644
--- a/main.py
+++ b/main.py
@@ -771,9 +771,6 @@ def on_train_epoch_start(self, trainer, pl_module):
         "cuda_callback": {
             "target": "main.CUDACallback"
         },
-        "captions_callback": {
-            "target": "ldm.modules.callbacks.captions.CaptionSaverCallback"
-        }
     }
     if version.parse(pl.__version__) >= version.parse('1.4.0'):
         default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
diff --git a/scripts/show_captions.py b/scripts/show_captions.py
deleted file mode 100755
index ef7471b0..00000000
--- a/scripts/show_captions.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/env python3
-import torch
-import sys, os
-
-def main(checkpoint_path):
-    try:
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-        if 'trained_captions' in checkpoint:
-            captions = checkpoint['trained_captions']
-            global_step = checkpoint['global_step']
-            print(f'Captions in {os.path.basename(checkpoint_path)} [{global_step} Global Steps]:')
-            for caption in captions:
-                print(f'\t"{caption}"')
-        else:
-            print(f'{checkpoint_path} has no captions saved')
-    except:
-        print(f'Failed to extract captions from {checkpoint_path}')
-
-if __name__ == '__main__':
-    if len(sys.argv) == 1:
-        print(f'{sys.argv[0]} ')
-    main(sys.argv[1])
-