diff --git a/06_gpu_and_ml/text-to-audio/musicgen.py b/06_gpu_and_ml/text-to-audio/musicgen.py
index 078c1a1b9..5c978a56e 100644
--- a/06_gpu_and_ml/text-to-audio/musicgen.py
+++ b/06_gpu_and_ml/text-to-audio/musicgen.py
@@ -1,27 +1,52 @@
 # # Create your own music samples with MusicGen

 # MusicGen is a popular open-source music-generation model family from Meta.
-# In this example, we show you how you can run MusicGen models on Modal.
+# In this example, we show you how you can run MusicGen models on Modal GPUs,
+# along with a Gradio UI for playing around with the model.

-# We use [Audiocraft](https://github.com/facebookresearch/audiocraft), an inference library for audio models,
-# including MusicGen and its kin, like AudioGen.
+# We use [Audiocraft](https://github.com/facebookresearch/audiocraft),
+# the inference library released by Meta
+# for MusicGen and its kin, like AudioGen.

-# ## Setting up the image and dependencies
+# ## Setting up dependencies

 from pathlib import Path
 from uuid import uuid4

 import modal

-app = modal.App("example-musicgen")
+# We start by defining the environment our generation runs in.
+# This takes some explaining since, like most cutting-edge ML environments, it is a bit fiddly.

-MAX_SEGMENT_DURATION = 30
+# This environment is captured by a
+# [container image](https://modal.com/docs/guide/custom-container),
+# which we build step-by-step by calling methods to add dependencies,
+# like `apt_install` to add system packages and `pip_install` to add
+# Python packages.

-cache_dir = "/cache"
-model_cache = modal.Volume.from_name(
-    "audiocraft-model-cache", create_if_missing=True
+# Note that we don't have to install anything with "CUDA"
+# in the name -- the drivers come for free with the Modal environment
+# and the rest gets installed by `pip`. That makes our life a lot easier!
+# If you want to see the details, check out [this guide](https://modal.com/docs/guide/gpu)
+# in our docs.
+
+image = (
+    modal.Image.debian_slim(python_version="3.11")
+    .apt_install("git", "ffmpeg")
+    .pip_install(
+        "huggingface_hub[hf_transfer]==0.27.1",  # speed up model downloads
+        "torch==2.1.0",  # version pinned by audiocraft
+        "numpy<2",  # defensively cap the numpy version
+        "git+https://github.com/facebookresearch/audiocraft.git@v1.3.0",  # we can install directly from GitHub!
+    )
 )

+# In addition to source code, we'll also need the model weights.
+
+# Audiocraft integrates with the Hugging Face ecosystem, so setting up the models
+# is straightforward -- the same `get_pretrained` method we use to load the weights for execution
+# will also download them if they aren't present.
+

 def load_model(and_return=False):
     from audiocraft.models import MusicGen
@@ -31,32 +56,60 @@ def load_model(and_return=False):
         return model_large


-image = (
-    modal.Image.debian_slim(python_version="3.11")
-    .apt_install("git", "ffmpeg")
-    .pip_install(
-        "torch==2.1.0",  # version needed for audiocraft
-        "pydub==0.25.1",
-        "numpy<2",
-        "git+https://github.com/facebookresearch/audiocraft.git@v1.3.0",
-        "huggingface_hub[hf_transfer]==0.27.1",
-    )
-    .env({"HF_HUB_CACHE": cache_dir, "HF_HUB_ENABLE_HF_TRANSER": "1"})
-    .run_function(load_model, volumes={cache_dir: model_cache})
+# But Modal Functions are serverless: instances spin down when they aren't being used.
+# If we want to avoid downloading the weights every time we start a new instance,
+# we need to store the weights somewhere besides our local filesystem.
+
+# So we add a Modal [Volume](https://modal.com/docs/guide/volumes)
+# to store the weights in the cloud.
+
+cache_dir = "/cache"
+model_cache = modal.Volume.from_name(
+    "audiocraft-model-cache", create_if_missing=True
+)
+
+# We don't need to change any of the model loading code --
+# we just need to make sure the model gets stored in the right directory.
+
+# To do that, we set an environment variable that Hugging Face expects
+# (and another one that speeds up downloads, for good measure)
+# and then run the `load_model` Python function.
+
+image = image.env(
+    {"HF_HUB_CACHE": cache_dir, "HF_HUB_ENABLE_HF_TRANSFER": "1"}
+).run_function(load_model, volumes={cache_dir: model_cache})
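+
+# Once the image has been built, the weights are sitting in the Volume.
+# If you ever want to sanity-check what the download produced, you can browse the
+# Volume with the Modal CLI -- entirely optional, but handy for debugging
+# (shown here with the Volume name used above):
+
+# ``` shell
+# modal volume ls audiocraft-model-cache
+# ```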
+
+# While we're at it, let's also define the environment for our UI.
+# We'll stick with Python and so use FastAPI and Gradio.
+
+web_image = modal.Image.debian_slim(python_version="3.11").pip_install(
+    "fastapi[standard]==0.115.4", "gradio==4.44.1"
 )

-with image.imports():
-    import torch
+# This is a totally different environment from the one we run our model in.
+# Say goodbye to Python dependency conflict hell!

-# ## Defining the model generation
+# ## Running music generation on Modal

-# We then write our model code within Modal's
-# [`@app.cls`](/docs/reference/modal.App#cls) decorator, with the
-# [`generate`] function processing the user input and generating audio as bytes that we can
-# save to a file later.
+# Now, we write our music generation logic.
+# This is a bit complicated because we want to support generating long samples,
+# but the model has a maximum context length of thirty seconds.
+# We can get longer clips by feeding the model's output back as input,
+# auto-regressively, but we have to write that ourselves.
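+
+# To make the arithmetic concrete: the first segment can be a full thirty seconds,
+# and each continuation call re-feeds the last `overlap` seconds of audio, so it
+# contributes at most `30 - overlap` seconds of new music. With an overlap of
+# fifteen seconds, for example, a sixty-second clip works out to three calls to
+# the model: 30 + 15 + 15 = 60.
+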
+# There are also a few bits to make this work well with Modal:

-@app.cls(gpu=modal.gpu.L40S(), image=image, volumes={cache_dir: model_cache})
+# - We make an [App](https://modal.com/docs/guide/apps) to organize our deployment
+# - We load the model at container start, instead of during inference, with `modal.enter`,
+# which requires that we use a Modal [`Cls`](https://modal.com/docs/guide/lifecycle-functions)
+# - In the `app.cls` decorator, we specify the Image we built and attach the Volume.
+# We also pick a GPU to run on -- here, an NVIDIA L40S.
+
+app = modal.App("example-musicgen")
+MAX_SEGMENT_DURATION = 30  # maximum context window size
+
+
+@app.cls(gpu="l40s", image=image, volumes={cache_dir: model_cache})
 class MusicGen:
     @modal.enter()
     def init(self):
@@ -68,7 +121,7 @@ def generate(
         prompt: str,
         duration: int = 10,
         overlap: int = 10,
-        format: str = "wav",
+        format: str = "wav",  # or mp3
     ) -> bytes:
         f"""Generate a music clip based on the prompt.
@@ -90,15 +143,17 @@ def generate(
             segment_duration = min(segment_duration, MAX_SEGMENT_DURATION)

             # generate next segment
-            self.model.set_generation_params(duration=segment_duration)
-            next_segment = self._generate_next_segment(prompt, context, overlap)
-
-            # update remaining duration
-            remaining_duration -= (
+            generated_duration = (
                 segment_duration
                 if context is None
                 else (segment_duration - overlap)
             )
+            print(f"🎼 generating {generated_duration} seconds of music")
+            self.model.set_generation_params(duration=segment_duration)
+            next_segment = self._generate_next_segment(prompt, context, overlap)
+
+            # update remaining duration
+            remaining_duration -= generated_duration

             # combine with previous segments
             context = self._combine_segments(context, next_segment, overlap)
@@ -106,7 +161,12 @@ def generate(
         output = context.detach().cpu().float()[0]

         return to_audio_bytes(
-            output, self.model.sample_rate, strategy="loudness", format=format
+            output,
+            self.model.sample_rate,
+            format=format,
+            # for more on audio encoding parameters, see the docs for audiocraft
+            strategy="loudness",
+            loudness_compressor=True,
         )

     def _generate_next_segment(self, prompt, context, overlap):
@@ -114,32 +174,46 @@ def _generate_next_segment(self, prompt, context, overlap):
         if context is None:
             return self.model.generate(descriptions=[prompt])
         else:
-            last_chunk = context[:, :, -overlap * self.model.sample_rate :]
+            overlap_samples = overlap * self.model.sample_rate
+            last_chunk = context[:, :, -overlap_samples:]  # B, C, T
             return self.model.generate_continuation(
                 last_chunk, self.model.sample_rate, descriptions=[prompt]
             )

     def _combine_segments(self, context, next_segment, overlap: int):
         """Combine context with next segment, handling overlap."""
+        import torch
+
         if context is None:
             return next_segment

         # Calculate where to trim the context (removing overlap)
-        trim_samples = overlap * self.model.sample_rate
-        context_trimmed = context[:, :, :-trim_samples]  # B, C, T
+        overlap_samples = overlap * self.model.sample_rate
+        context_trimmed = context[:, :, :-overlap_samples]  # B, C, T

         return torch.cat([context_trimmed, next_segment], dim=2)


-# We can call MusicGen inference from our local machine by running the code in the local entrypoint below.
+# We can then generate music from anywhere by running code like the `local_entrypoint` below.
+
+# You can execute it with a command like:
+
+# ``` shell
+# modal run musicgen.py --prompt="Baroque boy band, Bachstreet Boys, basso continuo, Top 40 pop music" --duration=60
+# ```


 @app.local_entrypoint()
-def main(prompt: str = None, duration: int = 10, format: str = "wav"):
+def main(
+    prompt: str = None,
+    duration: int = 10,
+    overlap: int = 15,
+    format: str = "wav",  # or mp3
+):
     if prompt is None:
         prompt = "Amapiano polka, klezmers, log drum bassline, 112 BPM"

     print(
-        f"🎵 generating music from prompt '{prompt[:64] + ('...' if len(prompt) > 64 else '')}'"
+        f"🎼 generating {duration} seconds of music from prompt '{prompt[:64] + ('...' if len(prompt) > 64 else '')}'"
     )

     audiocraft = MusicGen()
@@ -149,26 +223,23 @@ def main(prompt: str = None, duration: int = 10, format: str = "wav"):
     dir.mkdir(exist_ok=True, parents=True)

     output_path = dir / f"{slugify(prompt)[:64]}.{format}"
-    print(f"🎵 Saving to {output_path}")
+    print(f"🎼 Saving to {output_path}")
     output_path.write_bytes(clip)


-# You can execute the local entrypoint with:
+# ## Hosting a web UI for the music generator
-# ``` shell
-# modal run musicgen.py --prompt="metallica meets sabrina carpenter"
-# ```
-
-# ## A hosted Gradio interface
+# With the Gradio library, we can create a simple web UI in Python
+# that calls out to our music generator,
+# then host it on Modal for anyone to try out.
-# With the Gradio library, we can create a simple web interface around our class in Python, then use Modal to host it for anyone to try out.
-# To deploy your own, run
+# To deploy both the music generator and the UI, run

 # ``` shell
 # modal deploy musicgen.py
 # ```

-web_image = image.pip_install("fastapi[standard]==0.115.4", "gradio==4.44.1")
+# Share the URL with your friends and they can generate their own songs!
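+
+# Deploying also makes the `MusicGen` class callable from any other Python code in
+# your Modal workspace. As a rough sketch (the exact lookup helper depends on your
+# Modal client version -- recent ones expose `modal.Cls.from_name`):
+
+# ``` python
+# import modal
+#
+# MusicGenCls = modal.Cls.from_name("example-musicgen", "MusicGen")
+# clip = MusicGenCls().generate.remote("synthwave lullaby", duration=15, format="mp3")
+# ```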

 @app.function(
@@ -188,20 +259,23 @@ def ui():
     api = FastAPI()

     # Since this Gradio app is running from its own container,
-    # allowing us to run the inference service via .remote() methods.
+    # we make a `.remote` call to the music generator.
     model = MusicGen()
+    generate = model.generate.remote

     temp_dir = Path("/dev/shm")

     async def generate_music(
         prompt: str, duration: int = 10, format: str = "wav"
     ):
-        audio_bytes = await model.generate.remote.aio(prompt, duration, format)
+        audio_bytes = await generate.aio(
+            prompt, duration=duration, format=format
+        )

-        audio_file = f"{temp_dir}/{uuid4()}.{format}"
-        audio_file.write_bytes(audio_bytes)
+        audio_path = temp_dir / f"{uuid4()}.{format}"
+        audio_path.write_bytes(audio_bytes)

-        return audio_file
+        return audio_path

     with gr.Blocks(theme="soft") as demo:
         gr.Markdown("# MusicGen")
         with gr.Column():
             prompt = gr.Textbox(label="Prompt")
             duration = gr.Number(
-                label="Duration (seconds)", value=10, minimum=1, maximum=30
+                label="Duration (seconds)", value=10, minimum=1, maximum=300
             )
             format = gr.Radio(["wav", "mp3"], label="Format", value="wav")
             btn = gr.Button("Generate")
@@ -225,6 +299,12 @@ async def generate_music(
     return mount_gradio_app(app=api, blocks=demo, path="/")


+# ## Addenda
+
+# The remainder of the code here is not directly related to Modal
+# or to music generation, but is used in the example above.
+
+
 def to_audio_bytes(wav, sample_rate: int, **kwargs) -> bytes:
     from audiocraft.data.audio import audio_write