Skip to content

Commit

Permalink
Update xtts_demo.py added export model button
Browse files Browse the repository at this point in the history
The exported model folder will contain these files

`dataset.zip`
`model.pth`
`config.json`
`vocab.json`

That will also remove the optimizer state from the model, making the exported model significantly smaller on disk.

After all, you're just going to be using it for inference anyway.
  • Loading branch information
DrewThomasson authored Oct 18, 2024
1 parent dc9c132 commit 2309048
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions TTS/demos/xtts_ft_demo/xtts_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,38 @@
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

def export_model(output_path):
    """Bundle the fine-tuned XTTS model into ``<output_path>/Finished_model_files``.

    The exported folder contains:
      * ``dataset.zip``  -- zipped copy of ``<output_path>/dataset``
      * ``model.pth``    -- the most recent ``best_model.pth`` with optimizer
                            state and DVAE weights stripped (inference-only,
                            much smaller than the training checkpoint)
      * ``config.json`` / ``vocab.json`` -- copied from the training run folder

    Args:
        output_path: Root folder of the fine-tuning workspace (contains the
            ``dataset`` and ``run/training`` subfolders).

    Returns:
        A human-readable status string for the Gradio progress label. Errors
        are reported in the string rather than raised, so a failure cannot
        crash the UI.
    """
    try:
        output_folder = os.path.join(output_path, "Finished_model_files")
        os.makedirs(output_folder, exist_ok=True)

        # Zip the dataset folder, storing paths relative to its parent so the
        # archive unpacks as "dataset/...".
        dataset_path = os.path.join(output_path, "dataset")
        with zipfile.ZipFile(os.path.join(output_folder, "dataset.zip"), "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(dataset_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    zipf.write(file_path,
                               os.path.relpath(file_path, os.path.join(dataset_path, "..")))

        # Pick the most recently created best_model.pth across all training runs.
        search_path = os.path.join(output_path, "run", "training", "**", "best_model.pth")
        candidates = glob.glob(search_path, recursive=True)
        if not candidates:
            # Clearer than the bare ValueError max() would raise on an empty list.
            return f"Export failed: no best_model.pth found under {search_path}"
        model_path = max(candidates, key=os.path.getctime)

        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        # Drop training-only state: the optimizer (may already be absent in a
        # previously stripped checkpoint, hence pop with default) and the DVAE
        # weights, which are not used at inference time.
        checkpoint.pop("optimizer", None)
        for key in list(checkpoint["model"].keys()):
            if "dvae" in key:
                del checkpoint["model"][key]

        torch.save(checkpoint, os.path.join(output_folder, "model.pth"))

        # config.json / vocab.json live next to the checkpoint in the run folder.
        model_dir = os.path.dirname(model_path)
        shutil.copy2(os.path.join(model_dir, "config.json"), output_folder)
        shutil.copy2(os.path.join(model_dir, "vocab.json"), output_folder)

        return f"Model exported successfully to {output_folder}"
    except Exception as e:
        # Surface any unexpected error in the UI label instead of crashing Gradio.
        return f"Export failed: {str(e)}"



def clear_gpu_cache():
# clear the GPU cache
Expand Down Expand Up @@ -384,6 +416,10 @@ def train_model(
progress_gen = gr.Label(label="Progress:")
tts_output_audio = gr.Audio(label="Generated Audio.")
reference_audio = gr.Audio(label="Reference audio used.")
# Column holding the export button and its status label. Clicking the button
# runs export_model, which strips the optimizer state before saving.
with gr.Column() as col4:
export_btn = gr.Button(value="Export Fine-tuned Model")
export_progress = gr.Label(label="Export Progress:")

prompt_compute_btn.click(
fn=preprocess_dataset,
Expand Down Expand Up @@ -429,5 +465,11 @@ def train_model(
],
outputs=[progress_gen, tts_output_audio, reference_audio],
)

# Wire the export button: export_model(out_path) returns a status string
# that is displayed in the export_progress label.
export_btn.click(
fn=export_model,
inputs=[out_path],
outputs=[export_progress],
)

demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")

0 comments on commit 2309048

Please sign in to comment.