
Commit

Merge pull request #76 from fabbarix/feature/save_captions
Save captions with checkpoint
JoePenna authored Dec 16, 2022
2 parents 0a9c95b + d3d5239 commit a1781aa
Showing 4 changed files with 42 additions and 1 deletion.
16 changes: 16 additions & 0 deletions ldm/modules/callbacks/captions.py
@@ -0,0 +1,16 @@
from captionizer import caption_from_path
from pytorch_lightning.callbacks import Callback

class CaptionSaverCallback(Callback):
    def __init__(self):
        super().__init__()

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        print('Adding training captions to checkpoint [Dataloader]')
        data = trainer.train_dataloader.loaders.sampler.data_source  # type: ignore
        prompts = set([
            caption_from_path(image_path, data.data_root, data.coarse_class_text, data.placeholder_token)
            for image_path in data.image_paths
        ])
        trained_prompts = list(prompts)
        checkpoint['trained_captions'] = trained_prompts
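
For orientation, this is a standard PyTorch Lightning Callback: Lightning invokes on_save_checkpoint each time it writes a checkpoint, so the caption list is added to the checkpoint dict before it reaches disk. A minimal sketch of wiring the callback up by hand (the Trainer arguments below are placeholders; in this repository the callback is registered through the config in main.py, shown further down):

import pytorch_lightning as pl
from ldm.modules.callbacks.captions import CaptionSaverCallback

# Illustrative only: the real Trainer arguments come from the project config.
trainer = pl.Trainer(max_steps=500, callbacks=[CaptionSaverCallback()])
# Whenever Lightning saves a checkpoint, on_save_checkpoint runs and stores
# the 'trained_captions' list in the checkpoint dict before it is written.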
1 change: 0 additions & 1 deletion ldm/pruner.py
@@ -1,7 +1,6 @@
def prune_checkpoint(old_state):
print(f"Pruning Checkpoint")
pruned_checkpoint = dict()
print(f"Checkpoint Keys: {old_state.keys()}")
for key in old_state.keys():
if key != "optimizer_states":
pruned_checkpoint[key] = old_state[key]
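
As a usage sketch (the file names here are placeholders, not part of this diff), prune_checkpoint takes the loaded checkpoint dict and copies every top-level key except optimizer_states, so a pruned file keeps the weights and the new trained_captions entry while dropping the bulky optimizer state:

import torch
from ldm.pruner import prune_checkpoint

# Load the full training checkpoint (weights, optimizer state, captions, ...).
full = torch.load("last.ckpt", map_location="cpu")
# Copy everything except "optimizer_states".
pruned = prune_checkpoint(full)
torch.save(pruned, "last-pruned.ckpt")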
3 changes: 3 additions & 0 deletions main.py
@@ -771,6 +771,9 @@ def on_train_epoch_start(self, trainer, pl_module):
"cuda_callback": {
"target": "main.CUDACallback"
},
"captions_callback": {
"target": "ldm.modules.callbacks.captions.CaptionSaverCallback"
}
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
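
For context, each entry in default_callbacks_cfg is a plain dict whose "target" string is a dotted import path, which the training script resolves into callback instances for the PyTorch Lightning Trainer. A rough sketch of that dispatch, assuming a helper along the lines of ldm.util.instantiate_from_config (the helper body here is illustrative, not taken from this diff):

import importlib

def instantiate_from_config(config):
    # "target" is a dotted path, e.g. "ldm.modules.callbacks.captions.CaptionSaverCallback".
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    # Optional constructor arguments live under "params".
    return cls(**config.get("params", dict()))

With that machinery, the new "captions_callback" entry becomes a CaptionSaverCallback instance registered alongside the existing callbacks.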
23 changes: 23 additions & 0 deletions scripts/show_captions.py
@@ -0,0 +1,23 @@
#!/bin/env python3
import torch
import sys, os

def main(checkpoint_path):
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if 'trained_captions' in checkpoint:
            captions = checkpoint['trained_captions']
            global_step = checkpoint['global_step']
            print(f'Captions in {os.path.basename(checkpoint_path)} [{global_step} Global Steps]:')
            for caption in captions:
                print(f'\t"{caption}"')
        else:
            print(f'{checkpoint_path} has no captions saved')
    except:
        print(f'Failed to extract captions from {checkpoint_path}')

if __name__ == '__main__':
    if len(sys.argv) == 1:
        print(f'{sys.argv[0]} <checkpoint>')
    main(sys.argv[1])
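
Once a checkpoint has been saved with this change in place, the script can be pointed at it, for example with python scripts/show_captions.py logs/run/checkpoints/last.ckpt (the path is illustrative); it prints each stored caption together with the checkpoint's global step count, while checkpoints saved before this commit simply report that no captions were found.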

1 comment on commit a1781aa

@spotniko
The new captions callback throws an error after the first 500 training steps. I'm using vast.ai, and it worked yesterday, before this update.
