Skip to content

Commit

Permalink
Merge pull request #1266 from bghira/feature/skip-missing-caption-images
Browse files Browse the repository at this point in the history
skip images that have missing caption
  • Loading branch information
bghira authored Jan 6, 2025
2 parents 74a1716 + a9242b8 commit 24da2a7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 37 deletions.
5 changes: 3 additions & 2 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from helpers.training.exceptions import MultiDatasetExhausted
from helpers.multiaspect.dataset import MultiAspectDataset
from helpers.multiaspect.sampler import MultiAspectSampler
from helpers.prompts import PromptHandler
from helpers.prompts import PromptHandler, CaptionNotFoundError
from helpers.caching.vae import VAECache
from helpers.training.multi_process import should_log, rank_info, _get_rank as get_rank
from helpers.training.collate import collate_fn
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
and "text" not in backend.get("skip_file_discovery", "")
):
info_log(f"(id={init_backend['id']}) Collecting captions.")
captions = PromptHandler.get_all_captions(
captions, images_missing_captions = PromptHandler.get_all_captions(
data_backend=init_backend["data_backend"],
instance_data_dir=init_backend["instance_data_dir"],
prepend_instance_prompt=prepend_instance_prompt,
Expand All @@ -1012,6 +1012,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
logger.debug(
f"Pre-computing text embeds / updating cache. We have {len(captions)} captions to process, though these will be filtered next."
)
logger.debug(f"Data missing captions: {images_missing_captions}")
caption_strategy = backend.get("caption_strategy", args.caption_strategy)
info_log(
f"(id={init_backend['id']}) Initialise text embed pre-computation using the {caption_strategy} caption strategy. We have {len(captions)} captions to process."
Expand Down
92 changes: 57 additions & 35 deletions helpers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def prompt_library_injection(new_prompts: dict) -> dict:
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


class CaptionNotFoundError(Exception):
    """Raised when no caption can be located for an image.

    A dedicated exception type lets callers tell "this one image has no
    caption" apart from other failures (for example an unsupported caption
    strategy, which raises ``ValueError``), so the affected image can be
    skipped instead of aborting the whole caption-collection pass.
    """


class PromptHandler:
def __init__(
self,
Expand Down Expand Up @@ -250,15 +254,15 @@ def prepare_instance_prompt_from_parquet(
image_filename_stem = os.path.splitext(image_filename_stem)[0]
image_caption = metadata_backend.caption_cache_entry(image_filename_stem)
if instance_prompt is None and fallback_caption_column and not image_caption:
raise ValueError(
raise CaptionNotFoundError(
f"Could not locate caption for image {image_path} in sampler_backend {sampler_backend_id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(parquet_db)} entries."
)
elif (
instance_prompt is None
and not fallback_caption_column
and not image_caption
):
raise ValueError(
raise CaptionNotFoundError(
f"Could not locate caption for image {image_path} in sampler_backend {sampler_backend_id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(parquet_db)} entries."
)
if type(image_caption) == bytes:
Expand Down Expand Up @@ -406,6 +410,7 @@ def get_all_captions(
instance_prompt: str = None,
) -> list:
captions = []
images_missing_captions = []
all_image_files = StateTracker.get_image_files(
data_backend_id=data_backend.id
) or data_backend.list_files(
Expand All @@ -431,38 +436,43 @@ def get_all_captions(
leave=False,
ncols=125,
):
if caption_strategy == "filename":
caption = PromptHandler.prepare_instance_prompt_from_filename(
image_path=str(image_path),
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
)
elif caption_strategy == "textfile":
caption = PromptHandler.prepare_instance_prompt_from_textfile(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
)
elif caption_strategy == "parquet":
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
elif caption_strategy == "instanceprompt":
return [instance_prompt]
elif caption_strategy == "csv":
caption = data_backend.get_caption(image_path)
else:
raise ValueError(
f"Unsupported caption strategy: {caption_strategy}. Supported: 'filename', 'textfile', 'parquet', 'instanceprompt'"
)
try:
if caption_strategy == "filename":
caption = PromptHandler.prepare_instance_prompt_from_filename(
image_path=str(image_path),
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
)
elif caption_strategy == "textfile":
caption = PromptHandler.prepare_instance_prompt_from_textfile(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
)
elif caption_strategy == "parquet":
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
elif caption_strategy == "instanceprompt":
return [instance_prompt]
elif caption_strategy == "csv":
caption = data_backend.get_caption(image_path)
else:
raise ValueError(
f"Unsupported caption strategy: {caption_strategy}. Supported: 'filename', 'textfile', 'parquet', 'instanceprompt'"
)
except CaptionNotFoundError as e:
logger.error(f"Could not load caption for image {image_path}: {e}")
images_missing_captions.append(image_path)
continue

if type(caption) not in [tuple, list, dict]:
captions.append(caption)
Expand All @@ -474,7 +484,19 @@ def get_all_captions(
# TODO: Investigate why this prevents captions from processing on multigpu systems.
# captions = list(set(captions))

return captions
# Remove images that didn't have captions from the list.
for image_path in images_missing_captions:
del all_image_files[image_path]

if len(images_missing_captions) > 0:
logger.info(
f"Updating image list to reflect {len(images_missing_captions)} missing captions."
)
StateTracker.set_image_files(
data_backend_id=data_backend.id, raw_file_list=all_image_files
)

return captions, images_missing_captions

@staticmethod
def filter_caption(data_backend: BaseDataBackend, caption: str) -> str:
Expand Down

0 comments on commit 24da2a7

Please sign in to comment.