Skip to content

Commit

Permalink
Updating the notebooks and helper functions. Integrating the move and save into the training.
Browse files Browse the repository at this point in the history
Saving our config file with our trained models based on our inputs. More general cleanup.
  • Loading branch information
djbielejeski committed Mar 11, 2023
1 parent 4bdbfc5 commit 1a40a4d
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 298 deletions.
54 changes: 28 additions & 26 deletions JupyterNotebookHelpers/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from huggingface_hub import hf_hub_download
from ipywidgets import widgets, Layout, HBox


class SDModelOption:
def __init__(self, repo_id, filename, manual=False):
self.repo_id = repo_id
Expand All @@ -17,40 +18,43 @@ def download(self):
filename=self.filename
)
else:
raise Exception(f"Model not valid. repo_id: {self.repo_id} or filename: {self.filename} are missing or invalid.")
raise Exception(
f"Model not valid. repo_id: {self.repo_id} or filename: {self.filename} are missing or invalid.")

def is_valid(self):
return (self.repo_id is not None and self.repo_id != '') and \
(self.filename is not None and self.filename != '' and '.ckpt' in self.filename)
(self.filename is not None and self.filename != '' and '.ckpt' in self.filename)


class DownloadModel:
model_definitions = [
SDModelOption(repo_id="panopstor/EveryDream", filename="sd_v1-5_vae.ckpt"),
SDModelOption(repo_id="runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt"),
SDModelOption(repo_id="runwayml/stable-diffusion-v1-5", filename="v1-5-pruned.ckpt"),
SDModelOption(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt"),
SDModelOption(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt"),
SDModelOption(repo_id=None, filename=None, manual=True),
]
available_models = [
("sd_v1-5_vae.ckpt - 4.27gb - EveryDream (incl. vae) - Recommended", 0),
("v1-5-pruned-emaonly.ckpt - 4.27gb - runwayml", 1),
("v1-5-pruned.ckpt - 7.7gb - runwayml", 2),
("sd-v1-4.ckpt - 4.27gb - CompVis", 3),
("sd-v1-4-full-ema.ckpt - 7.7gb - CompVis", 4),
("Manual", 5),
]

last_selected_index = 0

def __init__(
self,
style={'description_width': '150px'},
layout=Layout(width="400px")
self,
style={'description_width': '150px'},
layout=Layout(width="400px")
):
self.style = style
self.layout = layout

self.model_definitions = [
SDModelOption(repo_id="panopstor/EveryDream", filename="sd_v1-5_vae.ckpt"),
SDModelOption(repo_id="runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt"),
SDModelOption(repo_id="runwayml/stable-diffusion-v1-5", filename="v1-5-pruned.ckpt"),
SDModelOption(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt"),
SDModelOption(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt"),
SDModelOption(repo_id=None, filename=None, manual=True),
]
self.available_models = [
("sd_v1-5_vae.ckpt - 4.27gb - EveryDream (incl. vae) - Recommended", 0),
("v1-5-pruned-emaonly.ckpt - 4.27gb - runwayml", 1),
("v1-5-pruned.ckpt - 7.7gb - runwayml", 2),
("sd-v1-4.ckpt - 4.27gb - CompVis", 3),
("sd-v1-4-full-ema.ckpt - 7.7gb - CompVis", 4),
("Manual", 5),
]

self.last_selected_index = 0

self.model_options = widgets.Dropdown(
options=self.available_models,
value=0,
Expand Down Expand Up @@ -122,7 +126,6 @@ def download_model(self, b):
else:
print("❌ Specified model is invalid.")


def model_options_changed(self, *args):
if self.last_selected_index is not self.model_options.value:
self.last_selected_index = self.model_options.value
Expand All @@ -143,6 +146,5 @@ def model_options_changed(self, *args):
self.model_filename_input.placeholder = selected_model.filename
self.model_filename_input.description = "Selected Filename: "


def get_selected_model(self) -> SDModelOption:
return self.model_definitions[self.model_options.value]
return self.model_definitions[self.model_options.value]
14 changes: 8 additions & 6 deletions JupyterNotebookHelpers/installer_progress_bar_widget.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from IPython.display import clear_output
from ipywidgets import widgets, Layout


class InstallerProgressBar:
show_detailed_output = False

def __init__(
self,
style = {'description_width': '150px'},
layout = Layout(width="400px"),
self,
style={'description_width': '150px'},
layout=Layout(width="400px"),
):
self.style = style
self.layout = layout
self.show_detailed_output = False

self.installer_progress_bar_widget = widgets.IntProgress(
value=0,
Expand All @@ -27,8 +29,8 @@ def show(self, install_commands):
self.installer_progress_bar_widget.max = len(install_commands)
display(self.installer_progress_bar_widget, self.output)

def increment(self, step:int):
def increment(self, step: int):
self.installer_progress_bar_widget.value = step + 1

def close(self):
self.installer_progress_bar_widget.close()
self.installer_progress_bar_widget.close()
108 changes: 50 additions & 58 deletions JupyterNotebookHelpers/setup_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import shutil
from ipywidgets import widgets, Layout, HBox
from git import Repo
from dreambooth_helpers.joepenna_dreambooth_config import save_config_file_v1
from dreambooth_helpers.joepenna_dreambooth_config import JoePennaDreamboothConfigSchemaV1
from JupyterNotebookHelpers.download_model import SDModelOption


class SetupTraining:
form_widgets = []
training_images_save_path = "./training_images"
selected_model: SDModelOption = None

def __init__(
self,
Expand All @@ -20,6 +18,10 @@ def __init__(
input_and_description_layout=Layout(width="812px"),
):
self.form_widgets = []
self.training_images_save_path = "./training_images"
self.config_save_path = "./joepenna-dreambooth-configs"
self.selected_model: SDModelOption = None

self.style = style
self.label_style = label_style
self.layout = layout
Expand Down Expand Up @@ -184,76 +186,66 @@ def submit_form_click(self, b):
with self.output:
self.output.clear_output()

dataset = self.reg_images_select.value
uploaded_images = self.training_images_uploader.value
project_name = self.project_name_input.value
max_training_steps = int(self.max_training_steps_input.value)
class_word = self.class_word_input.value
flip_percent = float(self.flip_slider.value)
token = self.token_input.value
learning_rate = self.learning_rate_select.value
save_every_x_steps = int(self.save_every_x_steps_input.value)

if len(uploaded_images) == 0:
# training images
uploaded_training_images = self.training_images_uploader.value
if len(uploaded_training_images) == 0:
print("No training images provided, please click the 'Training Images' upload button.", file=sys.stderr)
return
else:
self.handle_training_images(uploaded_training_images)

# Regularization Images
self.download_regularization_images(dataset)

# Training images
images = self.handle_training_images(uploaded_images)

save_config_file_v1(
dataset=dataset,
project_name=project_name,
max_training_steps=max_training_steps,
training_images_count=len(images),
training_images=images,
class_word=class_word,
flip_percent=flip_percent,
token=token,
learning_rate=learning_rate,
save_every_x_steps=save_every_x_steps,
regularization_images_dataset = self.reg_images_select.value
regularization_images_folder_path = self.download_regularization_images(regularization_images_dataset)


config = JoePennaDreamboothConfigSchemaV1()
config.saturate(
project_name=self.project_name_input.value,
max_training_steps=int(self.max_training_steps_input.value),
save_every_x_steps=int(self.save_every_x_steps_input.value),
training_images_folder_path=self.training_images_save_path,
regularization_images_folder_path=regularization_images_folder_path,
token=self.token_input.value,
token_only=False,
class_word=self.class_word_input.value,
flip_percent=float(self.flip_slider.value),
learning_rate=self.learning_rate_select.value,
model_repo_id=self.selected_model.repo_id,
model_filename=self.selected_model.filename,
model_path=self.selected_model.filename,
)

def download_regularization_images(self, dataset):
config.save_config_to_file(
save_path=self.config_save_path,
create_active_config=True
)

def download_regularization_images(self, dataset) -> str:
# Download Regularization Images
repo_name = f"Stable-Diffusion-Regularization-Images-{dataset}"
regularization_images_git_folder = f"./{repo_name}"
if not os.path.exists(regularization_images_git_folder):
path_to_reg_images = os.path.join(repo_name, dataset)

if not os.path.exists(path_to_reg_images):
print(f"Downloading regularization images for {dataset}. Please wait...")
Repo.clone_from(f"https://github.com/djbielejeski/{repo_name}.git", repo_name, progress=self.log_git_progress)
Repo.clone_from(f"https://github.com/djbielejeski/{repo_name}.git", repo_name,
progress=self.log_git_progress)

print(f"✅ Regularization images for {dataset} downloaded successfully.")
regularization_images_root_folder = "regularization_images"
if not os.path.exists(regularization_images_root_folder):
os.mkdir(regularization_images_root_folder)

regularization_images_dataset_folder = f"{regularization_images_root_folder}/{dataset}"
if not os.path.exists(regularization_images_dataset_folder):
os.mkdir(regularization_images_dataset_folder)

regularization_images = os.listdir(f"{regularization_images_git_folder}/{dataset}")
for file_name in regularization_images:
shutil.move(os.path.join(f"{regularization_images_git_folder}/{dataset}", file_name),
regularization_images_dataset_folder)

else:
print(f"✅ Regularization images for {dataset} already exist. Skipping download...")

def log_git_progress(self, op_code:int, cur_count, max_count, message:str=''):
if op_code == 33: # Start, display the widget
return path_to_reg_images

def log_git_progress(self, op_code: int, cur_count, max_count, message: str = ''):
if op_code == 33: # Start, display the widget
display(self.regularization_images_progress_bar_widget)

if op_code == 32 or op_code == 256: # Fetching remote or Stage remote, update the widget
if op_code == 32 or op_code == 256: # Fetching remote or Stage remote, update the widget
self.regularization_images_progress_bar_widget.max = int(max_count)
self.regularization_images_progress_bar_widget.value = int(cur_count)
self.regularization_images_progress_bar_widget.description = f"{message}"

if op_code == 258: # Stage remote end, hide the widget
if op_code == 258: # Stage remote end, hide the widget
self.regularization_images_progress_bar_widget.close()

def handle_training_images(self, uploaded_images):
Expand All @@ -269,13 +261,13 @@ def handle_training_images(self, uploaded_images):
for i, img in enumerate(uploaded_images):
images.append(img.name)
image_widgets.append(widgets.Image(
value=img.content
value=img.content,
width=256,
height=256,
))
with open(f"{self.training_images_save_path}/{img.name}", "w+b") as image_file:
with open(os.path.join(self.training_images_save_path, img.name), "w+b") as image_file:
image_file.write(img.content)

display(HBox(image_widgets))

print(f"✅ Training images uploaded successfully.")

return images
print(f"✅ Training images uploaded successfully.")
75 changes: 11 additions & 64 deletions dreambooth_google_colab_joepenna.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"!pip install huggingface_hub\n",
"!pip install gitpython\n",
"\n",
"print(\"The instance needs to restart to apply changes.\")\n",
"\n",
"import os\n",
"os._exit(00)"
Expand Down Expand Up @@ -186,16 +187,22 @@
"#@markdown This is the unique token you are incorporating into the stable diffusion model.\n",
"token = \"firstNameLastName\" #@param {type:\"string\"}\n",
"\n",
"# 0 Saves the checkpoint when max_training_steps is reached.\n",
"# 250 saves the checkpoint every 250 steps as well as when max_training_steps is reached.\n",
"save_every_x_steps = 0\n",
"\n",
"reg_data_root = \"/content/Dreambooth-Stable-Diffusion/regularization_images/\" + dataset\n",
"\n",
"!rm -rf training_images/.ipynb_checkpoints\n",
"!python \"main.py\" \\\n",
" --project_name \"{project_name}\" \\\n",
" --debug False \\\n",
" --max_training_steps {max_training_steps} \\\n",
" --token \"{token}\" \\\n",
" --training_model \"model.ckpt\" \\\n",
" --regularization_images \"{reg_data_root}\" \\\n",
" --training_images \"/content/Dreambooth-Stable-Diffusion/training_images\" \\\n",
" --max_training_steps {max_training_steps} \\\n",
" --regularization_images \"{reg_data_root}\" \\\n",
" --class_word \"{class_word}\" \\\n",
" --token \"{token}\" \\\n",
" --flip_p {flip_p_arg} \\\n",
" --save_every_x_steps {save_every_x_steps}"
],
Expand All @@ -205,74 +212,14 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title # Copy and name the checkpoint file(s)\n",
"import re\n",
"\n",
"training_images = !find training_images/*\n",
"date_string = !date +\"%Y-%m-%dT%H-%M-%S\"\n",
"\n",
"if save_every_x_steps <= 0:\n",
" # Copy the checkpoint into our `trained_models` folder\n",
" directory_paths = !ls -d logs/*\n",
" last_checkpoint_file = directory_paths[-1] + \"/ckpts/last.ckpt\"\n",
" file_name = date_string[-1] + \"_\" + \\\n",
" project_name + \"_\" + \\\n",
" str(len(training_images)) + \"_training_images_\" + \\\n",
" str(max_training_steps) + \"_max_training_steps_\" + \\\n",
" token + \"_token_\" + \\\n",
" class_word + \"_class_word.ckpt\"\n",
"\n",
" file_name = file_name.replace(\" \", \"_\")\n",
"\n",
" !mkdir -p trained_models\n",
" !mv \"{last_checkpoint_file}\" \"trained_models/{file_name}\"\n",
"\n",
" print(\"Download your trained model from trained_models/\" + file_name + \" and use in your favorite Stable Diffusion repo!\")\n",
"else:\n",
" directory_paths = !ls -d logs/*\n",
" checkpoints_directory = directory_paths[-1] + \"/ckpts/trainstep_ckpts\"\n",
" file_paths = !ls -d \"{checkpoints_directory}\"/*\n",
"\n",
" for i, original_file_name in enumerate(file_paths):\n",
" # Remove the \"epoch=000000-step=0000\" text\n",
" steps = re.sub(checkpoints_directory + \"/epoch=\\d{6}-step=0*\", \"\", original_file_name)\n",
"\n",
" # Remove the .ckpt\n",
" steps = steps.replace(\".ckpt\", \"\")\n",
"\n",
" # Setup the filename\n",
" file_name = date_string[-1] + \"_\" + \\\n",
" project_name + \"_\" + \\\n",
" str(len(training_images)) + \"_training_images_\" + \\\n",
" steps + \"_training_steps_\" + \\\n",
" token + \"_token_\" + \\\n",
" class_word + \"_class_word.ckpt\"\n",
"\n",
" file_name = file_name.replace(\" \", \"_\")\n",
"\n",
" # Make the directory and move the files into it.\n",
" !mkdir -p trained_models\n",
" !mv \"{original_file_name}\" \"trained_models/{file_name}\"\n",
"\n",
" print(\"Download your trained models from the 'trained_models' folder and use in your favorite Stable Diffusion repo!\")"
],
"metadata": {
"id": "Ll_ZIFNUulKJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Save model in google drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"!cp trained_models/{file_name} /content/drive/MyDrive/{file_name}"
"!cp trained_models/{file_name.ckpt} /content/drive/MyDrive/{file_name.ckpt}"
],
"metadata": {
"id": "mkidEm4evn1J"
Expand Down
Loading

0 comments on commit 1a40a4d

Please sign in to comment.