def export_model(output_path):
    """Bundle the newest fine-tuned checkpoint and the dataset for distribution.

    The newest ``best_model.pth`` found anywhere below
    ``<output_path>/run/training`` is loaded on the CPU, its ``"optimizer"``
    entry and every ``"model"`` weight whose key contains ``"dvae"`` are
    dropped, and the result is saved as ``model.pth`` inside
    ``<output_path>/Finished_model_files``. ``config.json`` and ``vocab.json``
    from the checkpoint's directory are copied next to it, and the
    ``<output_path>/dataset`` tree is archived as ``dataset.zip`` with entry
    names relative to ``output_path`` (so they start with ``dataset/``).

    Args:
        output_path: Root folder holding the ``dataset`` and ``run/training``
            sub-folders.

    Returns:
        A status string for the UI: ``"Model exported successfully to ..."``
        on success, otherwise ``"Export failed: <reason>"``. Exceptions are
        never propagated to the caller.
    """
    # Imported here so the function works even if the module header does not
    # import these modules. NOTE(review): if the file already imports them at
    # the top, these three lines are redundant but harmless.
    import glob
    import shutil
    import zipfile

    try:
        training_root = os.path.join(output_path, "run", "training")
        # Escape the user-supplied part so characters such as "[" in the path
        # are matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(training_root), "**", "best_model.pth")
        candidates = glob.glob(pattern, recursive=True)
        if not candidates:
            raise FileNotFoundError(f"no best_model.pth found under {training_root}")
        # Several training runs may exist; export the most recently created one.
        model_path = max(candidates, key=os.path.getctime)

        # Validate every input before writing anything, so a failed export
        # does not leave a half-populated output folder behind.
        # NOTE(review): assumes the trainer writes vocab.json next to the
        # checkpoint - confirm against the training code.
        model_dir = os.path.dirname(model_path)
        companions = [os.path.join(model_dir, name) for name in ("config.json", "vocab.json")]
        missing = [path for path in companions if not os.path.isfile(path)]
        if missing:
            raise FileNotFoundError("missing required file(s): " + ", ".join(missing))

        output_folder = os.path.join(output_path, "Finished_model_files")
        os.makedirs(output_folder, exist_ok=True)

        dataset_path = os.path.join(output_path, "dataset")
        archive_path = os.path.join(output_folder, "dataset.zip")
        with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(dataset_path):
                for filename in files:
                    file_path = os.path.join(root, filename)
                    zipf.write(file_path, os.path.relpath(file_path, output_path))

        # torch.load unpickles the file; only point this at checkpoints that
        # were produced locally by the training step.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        # pop() instead of del: a checkpoint without optimizer state is fine.
        checkpoint.pop("optimizer", None)
        model_state = checkpoint["model"]
        # Collect the keys first; deleting while iterating the dict would raise.
        for key in [name for name in model_state if "dvae" in name]:
            del model_state[key]
        torch.save(checkpoint, os.path.join(output_folder, "model.pth"))

        for path in companions:
            shutil.copy2(path, output_folder)

        return f"Model exported successfully to {output_folder}"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Export failed: {e}"
Progress:") prompt_compute_btn.click( fn=preprocess_dataset, @@ -429,5 +465,11 @@ def train_model( ], outputs=[progress_gen, tts_output_audio, reference_audio], ) + + export_btn.click( + fn=export_model, + inputs=[out_path], + outputs=[export_progress], + ) demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")