Fixes to the copy and move script. Removing our global variables.
djbielejeski committed Mar 11, 2023
1 parent 65c1198 commit 2860620
Showing 8 changed files with 89 additions and 114 deletions.
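
At a glance, this commit replaces the module-level `dreambooth_global_variables` singleton with methods on the config object the helpers already receive. A minimal before/after sketch (the helper function below is hypothetical; the attribute and method names are taken from the diff):

```python
# Before this commit: paths and flags lived on a module-level singleton.
# from dreambooth_helpers.global_variables import dreambooth_global_variables
# ckpt_dir = dreambooth_global_variables.log_checkpoint_directory()
# if dreambooth_global_variables.debug:
#     print(ckpt_dir)

# After this commit: the same values come from the config object that is already
# threaded through the helpers, so no shared global state is needed.
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1


def report_checkpoint_dir(config: JoePennaDreamboothConfigSchemaV1) -> None:
    """Hypothetical helper showing the post-refactor access pattern."""
    ckpt_dir = config.log_checkpoint_directory()
    if config.debug:
        print(ckpt_dir)
```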
2 changes: 1 addition & 1 deletion dreambooth_helpers/arguments.py
@@ -140,7 +140,7 @@ def str2bool(v):
config = JoePennaDreamboothConfigSchemaV1()

if opt.config_file_path is not None:
config.load_from_file(opt.config_file_path)
config.saturate_from_file(config_file_path=opt.config_file_path)
else:
config.saturate(
project_name=opt.project_name,
10 changes: 4 additions & 6 deletions dreambooth_helpers/callback_helpers.py
@@ -10,7 +10,6 @@
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only

from dreambooth_helpers.global_variables import dreambooth_global_variables

class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
@@ -31,11 +30,10 @@ def on_keyboard_interrupt(self, trainer, pl_module):

def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
if dreambooth_global_variables.debug:
print("Project config")
print(OmegaConf.to_yaml(self.config))
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
print("Project config")
print(OmegaConf.to_yaml(self.config))
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))

OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
20 changes: 10 additions & 10 deletions dreambooth_helpers/copy_and_name_checkpoints.py
@@ -3,13 +3,11 @@
import shutil
import glob
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1
from dreambooth_helpers.global_variables import dreambooth_global_variables


def copy_and_name_checkpoints(
config: JoePennaDreamboothConfigSchemaV1,
output_folder: str,
):
output_folder = config.trained_models_directory()
if not os.path.exists(output_folder):
os.mkdir(output_folder)

@@ -18,7 +16,7 @@ def copy_and_name_checkpoints(
save_path=output_folder
)

logs_directory = dreambooth_global_variables.log_directory()
logs_directory = config.log_directory()
if not os.path.exists(logs_directory):
print(f"No checkpoints found in {logs_directory}")
return
@@ -28,25 +26,27 @@ def copy_and_name_checkpoints(
if config.save_every_x_steps == 0:
checkpoints_and_steps.append(
(
f"{dreambooth_global_variables.log_checkpoint_directory()}/last.ckpt",
os.path.join(config.log_checkpoint_directory(), "last.ckpt"),
str(config.max_training_steps)
)
)
else:
intermediate_checkpoints_directory = dreambooth_global_variables.log_intermediate_checkpoints_directory()
intermediate_checkpoints_directory = config.log_intermediate_checkpoints_directory()
file_paths = glob.glob(f"{intermediate_checkpoints_directory}/*.ckpt")

for i, original_file_name in enumerate(file_paths):
for i, original_file_path in enumerate(file_paths):
# Grab the steps from the filename
# "epoch=000000-step=000000250.ckpt" => "250.ckpt"
checkpoint_steps = re.sub(intermediate_checkpoints_directory + "/epoch=\d{6}-step=0*", "", original_file_name)
# 'logs\\2023-03-11T20-03-37_ap_v15vae_ultrahq\\ckpts\\trainstep_ckpts\\epoch=000000-step=000000250.ckpt'
file_name = os.path.basename(original_file_path)
checkpoint_steps = re.sub(r"epoch=\d{6}-step=0*", "", file_name)

# Remove the .ckpt
# "250.ckpt" => "250"
checkpoint_steps = checkpoint_steps.replace(".ckpt", "")
checkpoints_and_steps.append(
(
original_file_name,
original_file_path,
checkpoint_steps
)
)
@@ -57,7 +57,7 @@ def copy_and_name_checkpoints(
original_file_name, steps = file_and_steps[0], file_and_steps[1]

# Setup the filenames
new_file_name = config.createCheckpointFileName(steps)
new_file_name = config.create_checkpoint_file_name(steps)
output_file_name = os.path.join(output_folder, new_file_name)

if os.path.exists(original_file_name):
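
To make the filename handling above concrete, here is a small, self-contained sketch of the new step extraction (the sample path is illustrative and mirrors the Windows-style path quoted in the comment). Working on the basename means the regex no longer has to match the directory portion, which is presumably what broke when glob returned backslash-separated paths:

```python
import os
import re

# Intermediate checkpoints are written as "epoch=NNNNNN-step=NNNNNNNNN.ckpt".
# The old code ran re.sub over the full path with a hard-coded "/" separator,
# which fails when glob returns backslash paths on Windows. Taking the
# basename first sidesteps the separator entirely.
original_file_path = "logs/2023-03-11T20-03-37_example/ckpts/trainstep_ckpts/epoch=000000-step=000000250.ckpt"

file_name = os.path.basename(original_file_path)                   # "epoch=000000-step=000000250.ckpt"
checkpoint_steps = re.sub(r"epoch=\d{6}-step=0*", "", file_name)   # "250.ckpt"
checkpoint_steps = checkpoint_steps.replace(".ckpt", "")           # "250"

print(checkpoint_steps)  # -> 250
```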
17 changes: 8 additions & 9 deletions dreambooth_helpers/dreambooth_trainer_configurations.py
@@ -1,5 +1,4 @@
from ldm.util import instantiate_from_config
from dreambooth_helpers.global_variables import dreambooth_global_variables
from ldm.modules.pruningckptio import PruningCheckpointIO
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1

@@ -14,7 +13,7 @@ def metrics_over_trainsteps_checkpoint(self) -> dict:
return {
"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
"params": {
"dirpath": dreambooth_global_variables.log_intermediate_checkpoints_directory(),
"dirpath": self.config.log_intermediate_checkpoints_directory(),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
"save_top_k": -1,
@@ -45,7 +44,7 @@ def model_checkpoint(self) -> dict:
return {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": dreambooth_global_variables.log_checkpoint_directory(),
"dirpath": self.config.log_checkpoint_directory(),
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
@@ -58,10 +57,10 @@ def setup_callback(self, model_data_config, lightning_config) -> dict:
"target": "dreambooth_helpers.callback_helpers.SetupCallback",
"params": {
"resume": "",
"now": dreambooth_global_variables.training_folder_name,
"logdir": dreambooth_global_variables.log_directory(),
"ckptdir": dreambooth_global_variables.log_checkpoint_directory(),
"cfgdir": dreambooth_global_variables.log_config_directory(),
"now": self.config.get_training_folder_name(),
"logdir": self.config.log_directory(),
"ckptdir": self.config.log_checkpoint_directory(),
"cfgdir": self.config.log_config_directory(),
"config": model_data_config,
"lightning_config": lightning_config,
}
@@ -256,7 +255,7 @@ def get_dreambooth_trainer_config(config: JoePennaDreamboothConfigSchemaV1, mode
"target": "pytorch_lightning.loggers.CSVLogger",
"params": {
"name": "CSVLogger",
"save_dir": dreambooth_global_variables.log_directory(),
"save_dir": config.log_directory(),
}
},
"checkpoint_callback": cb.model_checkpoint()
@@ -268,7 +267,7 @@ def get_dreambooth_trainer_config(config: JoePennaDreamboothConfigSchemaV1, mode
trainer_config["checkpoint_callback"]["params"]["monitor"] = model.monitor
trainer_config["checkpoint_callback"]["params"]["save_top_k"] = 1

if dreambooth_global_variables.debug:
if config.debug:
print(f"Monitoring {model.monitor} as checkpoint metric.")

return trainer_config
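
Each callback and logger in the trainer config above is a `{"target", "params"}` dict that the repository resolves with `instantiate_from_config` from `ldm/util.py` (import `target`, call it with `**params`). A hedged sketch of the `model_checkpoint()` entry after this change, with a literal path standing in for `config.log_checkpoint_directory()`:

```python
from ldm.util import instantiate_from_config

# The dirpath below is a placeholder; in the real code it now comes from
# config.log_checkpoint_directory() instead of dreambooth_global_variables.
checkpoint_cfg = {
    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
    "params": {
        "dirpath": "logs/2023-03-11T20-03-37_example/ckpts",
        "filename": "{epoch:06}",
        "verbose": True,
        "save_last": True,
    },
}

checkpoint_callback = instantiate_from_config(checkpoint_cfg)  # -> ModelCheckpoint(...)
```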
43 changes: 0 additions & 43 deletions dreambooth_helpers/global_variables.py

This file was deleted.

59 changes: 45 additions & 14 deletions dreambooth_helpers/joepenna_dreambooth_config.py
@@ -12,7 +12,7 @@
class JoePennaDreamboothConfigSchemaV1:
def __init__(self):
self.schema: int = 1
self.config_date_time: str = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
self.config_date_time: str = ''
self.project_config_filename: str = ''

# Project
@@ -56,20 +56,28 @@ def saturate(
flip_percent: float,
learning_rate: float,
model_path: str,
config_date_time: str = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S"),
config_date_time: str = None,
seed: int = 23,
debug: bool = False,
gpu: int = 0,
model_repo_id: str = '',
token_only: bool = False,
):
self.project_config_filename = f"{self.config_date_time}-{project_name}-config.json"

# Map the values
self.project_name = project_name
if self.project_name is None or self.project_name == '':
raise Exception("'--project_name': Required.")

if config_date_time is None:
self.config_date_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
else:
self.config_date_time = config_date_time

# parameter values
self.project_config_filename = f"{self.config_date_time}-{self.project_name}-config.json"


self.seed = seed

# Global seed
@@ -128,10 +136,10 @@ def saturate(
if not os.path.exists(self.model_path):
raise Exception(f"Model Path Not Found: '{self.model_path}'.")

self.config_date_time = config_date_time

self.validate_gpu_vram()

self._create_log_folders()

def validate_gpu_vram(self):
def convert_size(size_bytes):
if size_bytes == 0:
@@ -150,7 +158,7 @@ def convert_size(size_bytes):
if gpu_vram < twenty_one_gigabytes:
raise Exception(f"VRAM: Currently unable to run on less than {convert_size(twenty_one_gigabytes)} of VRAM.")

def load_from_file(
def saturate_from_file(
self,
config_file_path: str,
):
@@ -162,32 +170,31 @@ def load_from_file(
config_parsed = json.load(config_file)

if config_parsed['schema'] == 1:
return JoePennaDreamboothConfigSchemaV1(
self.saturate(
project_name=config_parsed['project_name'],
seed=config_parsed['seed'],
debug=config_parsed['debug'],
gpu=config_parsed['gpu'],
max_training_steps=config_parsed['max_training_steps'],
save_every_x_steps=config_parsed['save_every_x_steps'],
training_images_folder_path=config_parsed['training_images_folder_path'],
training_images=config_parsed['training_images'],
regularization_images_folder_path=config_parsed['regularization_images_folder_path'],
token=config_parsed['token'],
token_only=config_parsed['token_only'],
class_word=config_parsed['class_word'],
flip_percent=config_parsed['flip_percent'],
learning_rate=config_parsed['learning_rate'],
model_repo_id=config_parsed['model_repo_id'],
model_path=config_parsed['model_path'],
config_date_time=config_parsed['config_date_time'],
seed=config_parsed['seed'],
debug=config_parsed['debug'],
gpu=config_parsed['gpu'],
model_repo_id=config_parsed['model_repo_id'],
token_only=config_parsed['token_only'],
)
else:
print(f"Unrecognized schema: {config_parsed['schema']}", file=sys.stderr)

def toJSON(self):
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)

def createCheckpointFileName(self, steps):
def create_checkpoint_file_name(self, steps):
date_string = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
return f"{date_string}_{self.project_name}_" \
f"{steps}_steps_" \
@@ -212,3 +219,27 @@ def save_config_to_file(
shutil.copy(config_save_path, os.path.join(save_path, "active-config.json"))
print(project_config_json)
print(f"✅ {self.project_config_filename} successfully generated. Proceed to training.")

def get_training_folder_name(self) -> str:
return f"{self.config_date_time}_{self.project_name}"

def log_directory(self) -> str:
return os.path.join("logs", self.get_training_folder_name())

def log_checkpoint_directory(self) -> str:
return os.path.join(self.log_directory(), "ckpts")

def log_intermediate_checkpoints_directory(self) -> str:
return os.path.join(self.log_checkpoint_directory(), "trainstep_ckpts")

def log_config_directory(self) -> str:
return os.path.join(self.log_directory(), "configs")

def trained_models_directory(self) -> str:
return "trained_models"

def _create_log_folders(self):
os.makedirs(self.log_directory(), exist_ok=True)
os.makedirs(self.log_checkpoint_directory(), exist_ok=True)
os.makedirs(self.log_config_directory(), exist_ok=True)
os.makedirs(self.trained_models_directory(), exist_ok=True)
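
For reference, a quick sketch of what the new path helpers resolve to. The values are set directly on the instance so the snippet runs without a full `saturate()` call (which additionally validates paths and GPU VRAM); separators are OS-specific because the helpers use `os.path.join`:

```python
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1

config = JoePennaDreamboothConfigSchemaV1()
config.config_date_time = "2023-03-11T20-03-37"  # illustrative values
config.project_name = "example"

print(config.get_training_folder_name())                # 2023-03-11T20-03-37_example
print(config.log_directory())                           # logs/2023-03-11T20-03-37_example
print(config.log_checkpoint_directory())                # logs/.../ckpts
print(config.log_intermediate_checkpoints_directory())  # logs/.../ckpts/trainstep_ckpts
print(config.log_config_directory())                    # logs/.../configs
print(config.trained_models_directory())                # trained_models
# _create_log_folders(), called at the end of saturate(), creates each of these
# with os.makedirs(..., exist_ok=True).
```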
8 changes: 3 additions & 5 deletions ldm/util.py
@@ -9,7 +9,6 @@
import torch
from PIL import Image, ImageDraw, ImageFont

from dreambooth_helpers.global_variables import dreambooth_global_variables


def log_txt_as_img(wh, xc, size=10):
@@ -75,8 +74,7 @@ def count_params(model, verbose=False):
return total_params

def load_model_from_config(config, ckpt, verbose=False):
if dreambooth_global_variables.debug:
print(f"Loading model from {ckpt}")
print(f"Loading model from {ckpt}")

pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
@@ -93,10 +91,10 @@ def load_model_from_config(config, ckpt, verbose=False):
print("")
print("")

if len(m) > 0 and verbose and dreambooth_global_variables.debug:
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose and dreambooth_global_variables.debug:
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)

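
With the global debug flag gone, `load_model_from_config` now always reports the checkpoint it loads and, when `verbose=True`, any missing or unexpected state-dict keys. A hedged usage sketch (the YAML and checkpoint paths are placeholders, and the OmegaConf config is assumed to carry the usual `model` section):

```python
from omegaconf import OmegaConf
from ldm.util import load_model_from_config

project_config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")  # placeholder path
model = load_model_from_config(project_config, "model.ckpt", verbose=True)     # placeholder checkpoint
```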