Commit 82e03fa (parent: 1bcb1d1): 1 changed file with 138 additions and 58 deletions.
@@ -1,27 +1,52 @@
 # # Create your own music samples with MusicGen

 # MusicGen is a popular open-source music-generation model family from Meta.
-# In this example, we show you how you can run MusicGen models on Modal.
+# In this example, we show you how you can run MusicGen models on Modal GPUs,
+# along with a Gradio UI for playing around with the model.

-# We use [Audiocraft](https://github.com/facebookresearch/audiocraft), an inference library for audio models,
-# including MusicGen and its kin, like AudioGen.
+# We use [Audiocraft](https://github.com/facebookresearch/audiocraft),
+# the inference library released by Meta
+# for MusicGen and its kin, like AudioGen.

-# ## Setting up the image and dependencies
+# ## Setting up dependencies

 from pathlib import Path
 from uuid import uuid4

 import modal

-app = modal.App("example-musicgen")
+# We start by defining the environment our generation runs in.
+# This takes some explaining since, like most cutting-edge ML environments, it is a bit fiddly.

-MAX_SEGMENT_DURATION = 30
+# This environment is captured by a
+# [container image](https://modal.com/docs/guide/custom-container),
+# which we build step-by-step by calling methods to add dependencies,
+# like `apt_install` to add system packages and `pip_install` to add
+# Python packages.

-cache_dir = "/cache"
-model_cache = modal.Volume.from_name(
-    "audiocraft-model-cache", create_if_missing=True
-)
+# Note that we don't have to install anything with "CUDA"
+# in the name -- the drivers come for free with the Modal environment
+# and the rest gets installed by `pip`. That makes our life a lot easier!
+# If you want to see the details, check out [this guide](https://modal.com/docs/guide/gpu)
+# in our docs.

+image = (
+    modal.Image.debian_slim(python_version="3.11")
+    .apt_install("git", "ffmpeg")
+    .pip_install(
+        "huggingface_hub[hf_transfer]==0.27.1",  # speed up model downloads
+        "torch==2.1.0",  # version pinned by audiocraft
+        "numpy<2",  # defensively cap the numpy version
+        "git+https://github.com/facebookresearch/audiocraft.git@v1.3.0",  # we can install directly from GitHub!
+    )
+)

+# In addition to source code, we'll also need the model weights.

+# Audiocraft integrates with the Hugging Face ecosystem, so setting up the models
+# is straightforward -- the same `get_pretrained` method we use to load the weights for execution
+# will also download them if they aren't present.


 def load_model(and_return=False):
     from audiocraft.models import MusicGen
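
The diff hides most of `load_model`'s body. Based on the surrounding context (the `model_large` name and the `get_pretrained` comment above), here is a minimal sketch of what it likely does; the `facebook/musicgen-large` checkpoint name is an assumption, not confirmed by the diff:

```python
def load_model(and_return=False):
    from audiocraft.models import MusicGen

    # get_pretrained downloads the weights into HF_HUB_CACHE if they
    # aren't already present, then loads them for inference
    model_large = MusicGen.get_pretrained("facebook/musicgen-large")  # checkpoint name assumed
    if and_return:
        return model_large
```
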
@@ -31,32 +56,60 @@ def load_model(and_return=False):
     return model_large


-image = (
-    modal.Image.debian_slim(python_version="3.11")
-    .apt_install("git", "ffmpeg")
-    .pip_install(
-        "torch==2.1.0",  # version needed for audiocraft
-        "pydub==0.25.1",
-        "numpy<2",
-        "git+https://github.com/facebookresearch/audiocraft.git@v1.3.0",
-        "huggingface_hub[hf_transfer]==0.27.1",
-    )
-    .env({"HF_HUB_CACHE": cache_dir, "HF_HUB_ENABLE_HF_TRANSFER": "1"})
-    .run_function(load_model, volumes={cache_dir: model_cache})
-)
+# But Modal Functions are serverless: instances spin down when they aren't being used.
+# If we want to avoid downloading the weights every time we start a new instance,
+# we need to store the weights somewhere besides our local filesystem.

+# So we add a Modal [Volume](https://modal.com/docs/guide/volumes)
+# to store the weights in the cloud.

+cache_dir = "/cache"
+model_cache = modal.Volume.from_name(
+    "audiocraft-model-cache", create_if_missing=True
+)

+# We don't need to change any of the model loading code --
+# we just need to make sure the model gets stored in the right directory.

+# To do that, we set an environment variable that Hugging Face expects
+# (and another one that speeds up downloads, for good measure)
+# and then run the `load_model` Python function.

+image = image.env(
+    {"HF_HUB_CACHE": cache_dir, "HF_HUB_ENABLE_HF_TRANSFER": "1"}
+).run_function(load_model, volumes={cache_dir: model_cache})

+# While we're at it, let's also define the environment for our UI.
+# We'll stick with Python and so use FastAPI and Gradio.

+web_image = modal.Image.debian_slim(python_version="3.11").pip_install(
+    "fastapi[standard]==0.115.4", "gradio==4.44.1"
+)

-with image.imports():
-    import torch
+# This is a totally different environment from the one we run our model in.
+# Say goodbye to Python dependency conflict hell!

-# ## Defining the model generation
+# ## Running music generation on Modal

-# We then write our model code within Modal's
-# [`@app.cls`](/docs/reference/modal.App#cls) decorator, with the
-# [`generate`] function processing the user input and generating audio as bytes that we can
-# save to a file later.
+# Now, we write our music generation logic.
+# This is a bit complicated because we want to support generating long samples,
+# but the model has a maximum context length of thirty seconds.
+# We can get longer clips by feeding the model's output back as input,
+# auto-regressively, but we have to write that ourselves.
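
The loop that implements this stitching is spread across the hunks below, and its `while` header is not shown in the diff. As a rough, self-contained sketch of the idea, with illustrative names rather than the example's exact code:

```python
import torch


def stitch_segments(model, prompt: str, duration: int, overlap: int = 10,
                    max_segment: int = 30):
    """Illustrative sketch: build a long clip out of <= 30-second segments."""
    context = None
    remaining = duration
    while remaining > 0:
        # each continuation pass regenerates `overlap` seconds of the previous segment
        segment_duration = min(
            remaining if context is None else remaining + overlap, max_segment
        )
        model.set_generation_params(duration=segment_duration)
        if context is None:
            segment = model.generate(descriptions=[prompt])  # B, C, T tensor
        else:
            # feed the last `overlap` seconds back in as an audio prompt
            last_chunk = context[:, :, -overlap * model.sample_rate :]
            segment = model.generate_continuation(
                last_chunk, model.sample_rate, descriptions=[prompt]
            )
        remaining -= segment_duration if context is None else segment_duration - overlap
        if context is None:
            context = segment
        else:
            # drop the regenerated overlap from the old context, then append
            context = torch.cat(
                [context[:, :, : -overlap * model.sample_rate], segment], dim=2
            )
    return context
```
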

+# There are also a few bits to make this work well with Modal:

-@app.cls(gpu=modal.gpu.L40S(), image=image, volumes={cache_dir: model_cache})
+# - We make an [App](https://modal.com/docs/guide/apps) to organize our deployment
+# - We load the model at start, instead of during inference, with `modal.enter`,
+#   which requires that we use a Modal [`Cls`](https://modal.com/docs/guide/lifecycle-functions)
+# - In the `app.cls` decorator, we specify the Image we built and attach the Volume.
+#   We also pick a GPU to run on -- here, an NVIDIA L40S.

+app = modal.App("example-musicgen")
+MAX_SEGMENT_DURATION = 30  # maximum context window size


+@app.cls(gpu="l40s", image=image, volumes={cache_dir: model_cache})
 class MusicGen:
     @modal.enter()
     def init(self):
@@ -68,7 +121,7 @@ def generate(
         prompt: str,
         duration: int = 10,
         overlap: int = 10,
-        format: str = "wav",
+        format: str = "wav",  # or mp3
     ) -> bytes:
         f"""Generate a music clip based on the prompt.
@@ -90,56 +143,77 @@ def generate(
             segment_duration = min(segment_duration, MAX_SEGMENT_DURATION)

-            # generate next segment
-            self.model.set_generation_params(duration=segment_duration)
-            next_segment = self._generate_next_segment(prompt, context, overlap)
-
-            # update remaining duration
-            remaining_duration -= (
+            generated_duration = (
                 segment_duration
                 if context is None
                 else (segment_duration - overlap)
             )
+            print(f"🎼 generating {generated_duration} seconds of music")
+            self.model.set_generation_params(duration=segment_duration)
+            next_segment = self._generate_next_segment(prompt, context, overlap)
+
+            # update remaining duration
+            remaining_duration -= generated_duration

             # combine with previous segments
             context = self._combine_segments(context, next_segment, overlap)

         output = context.detach().cpu().float()[0]

         return to_audio_bytes(
-            output, self.model.sample_rate, strategy="loudness", format=format
+            output,
+            self.model.sample_rate,
+            format=format,
+            # for more on audio encoding parameters, see the docs for audiocraft
+            strategy="loudness",
+            loudness_compressor=True,
         )

     def _generate_next_segment(self, prompt, context, overlap):
         """Generate the next audio segment, either fresh or as continuation of a context."""
         if context is None:
             return self.model.generate(descriptions=[prompt])
         else:
-            last_chunk = context[:, :, -overlap * self.model.sample_rate :]
+            overlap_samples = overlap * self.model.sample_rate
+            last_chunk = context[:, :, -overlap_samples:]  # B, C, T
             return self.model.generate_continuation(
                 last_chunk, self.model.sample_rate, descriptions=[prompt]
             )

     def _combine_segments(self, context, next_segment, overlap: int):
         """Combine context with next segment, handling overlap."""
         import torch

         if context is None:
             return next_segment

         # Calculate where to trim the context (removing overlap)
-        trim_samples = overlap * self.model.sample_rate
-        context_trimmed = context[:, :, :-trim_samples]  # B, C, T
+        overlap_samples = overlap * self.model.sample_rate
+        context_trimmed = context[:, :, :-overlap_samples]  # B, C, T

         return torch.cat([context_trimmed, next_segment], dim=2)


-# We can call MusicGen inference from our local machine by running the code in the local entrypoint below.
+# We can then generate music from anywhere by running code like what we have in the `local_entrypoint` below.

+# You can execute it with a command like:

+# ``` shell
+# modal run musicgen.py --prompt="Baroque boy band, Bachstreet Boys, basso continuo, Top 40 pop music" --duration=60
+# ```


 @app.local_entrypoint()
-def main(prompt: str = None, duration: int = 10, format: str = "wav"):
+def main(
+    prompt: str = None,
+    duration: int = 10,
+    overlap: int = 15,
+    format: str = "wav",  # or mp3
+):
     if prompt is None:
         prompt = "Amapiano polka, klezmers, log drum bassline, 112 BPM"
     print(
-        f"🎵 generating music from prompt '{prompt[:64] + ('...' if len(prompt) > 64 else '')}'"
+        f"🎼 generating {duration} seconds of music from prompt '{prompt[:64] + ('...' if len(prompt) > 64 else '')}'"
     )

     audiocraft = MusicGen()
@@ -149,26 +223,23 @@ def main(prompt: str = None, duration: int = 10, format: str = "wav"):
     dir.mkdir(exist_ok=True, parents=True)

     output_path = dir / f"{slugify(prompt)[:64]}.{format}"
-    print(f"🎵 Saving to {output_path}")
+    print(f"🎼 Saving to {output_path}")
     output_path.write_bytes(clip)


-# You can execute the local entrypoint with:
+# ## Hosting a web UI for the music generator

-# ``` shell
-# modal run musicgen.py --prompt="metallica meets sabrina carpenter"
-# ```

-# ## A hosted Gradio interface
+# With the Gradio library, we can create a simple web UI in Python
+# that calls out to our music generator,
+# then host it on Modal for anyone to try out.

-# With the Gradio library, we can create a simple web interface around our class in Python, then use Modal to host it for anyone to try out.
-# To deploy your own, run
+# To deploy both the music generator and the UI, run

 # ``` shell
 # modal deploy musicgen.py
 # ```

-web_image = image.pip_install("fastapi[standard]==0.115.4", "gradio==4.44.1")
+# Share the URL with your friends and they can generate their own songs!


 @app.function(
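
The arguments to `@app.function(...)` and the other decorators on `ui()` are cut off by the diff. Here is a rough sketch of how a Gradio app is typically served on Modal; the decorator arguments and the body shown are assumptions, not the example's exact settings:

```python
from fastapi import FastAPI
from gradio.routes import mount_gradio_app  # Gradio 4.x

import modal

# hypothetical scaffold; `app` and `web_image` refer to the objects defined above
@app.function(image=web_image)
@modal.asgi_app()
def ui():
    import gradio as gr

    api = FastAPI()

    with gr.Blocks() as demo:
        gr.Markdown("# MusicGen")
        # ... components and event handlers go here ...

    # serve the Gradio Blocks app at the root path of the FastAPI app
    return mount_gradio_app(app=api, blocks=demo, path="/")
```
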
@@ -188,28 +259,31 @@ def ui():
     api = FastAPI()

     # Since this Gradio app is running from its own container,
-    # allowing us to run the inference service via .remote() methods.
+    # we make a `.remote` call to the music generator
     model = MusicGen()
+    generate = model.generate.remote

     temp_dir = Path("/dev/shm")

     async def generate_music(
         prompt: str, duration: int = 10, format: str = "wav"
     ):
-        audio_bytes = await model.generate.remote.aio(prompt, duration, format)
+        audio_bytes = await generate.aio(
+            prompt, duration=duration, format=format
+        )

-        audio_file = f"{temp_dir}/{uuid4()}.{format}"
-        audio_file.write_bytes(audio_bytes)
+        audio_path = temp_dir / f"{uuid4()}.{format}"
+        audio_path.write_bytes(audio_bytes)

-        return audio_file
+        return audio_path

     with gr.Blocks(theme="soft") as demo:
         gr.Markdown("# MusicGen")
         with gr.Row():
             with gr.Column():
                 prompt = gr.Textbox(label="Prompt")
                 duration = gr.Number(
-                    label="Duration (seconds)", value=10, minimum=1, maximum=30
+                    label="Duration (seconds)", value=10, minimum=1, maximum=300
                 )
                 format = gr.Radio(["wav", "mp3"], label="Format", value="wav")
                 btn = gr.Button("Generate")
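
The output component and the button wiring are not shown in this hunk. In Gradio 4.x, the missing piece typically looks something like the standalone sketch below; the component names and the stub handler are hypothetical, and in the example itself the click handler would be the `generate_music` function defined above:

```python
import gradio as gr


def fake_generate(prompt: str, duration: int, format: str) -> str:
    # stand-in for the real remote call; returns a path to an audio file
    return f"/tmp/placeholder.{format}"


with gr.Blocks(theme="soft") as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            duration = gr.Number(label="Duration (seconds)", value=10)
            format = gr.Radio(["wav", "mp3"], label="Format", value="wav")
            btn = gr.Button("Generate")
        with gr.Column():
            clip_output = gr.Audio(label="Clip")

    # run the handler on click and send the returned file path to the player
    btn.click(fake_generate, inputs=[prompt, duration, format], outputs=clip_output)

demo.launch()  # for trying the layout locally
```
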
@@ -225,6 +299,12 @@ async def generate_music(
     return mount_gradio_app(app=api, blocks=demo, path="/")


+# ## Addenda

+# The remainder of the code here is not directly related to Modal
+# or to music generation, but is used in the example above.


 def to_audio_bytes(wav, sample_rate: int, **kwargs) -> bytes:
     from audiocraft.data.audio import audio_write
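
The body of `to_audio_bytes` is cut off at the end of the diff. A minimal sketch of how such a helper can wrap Audiocraft's `audio_write`; writing to a temporary file and reading it back is an assumption, not necessarily the example's exact approach:

```python
import tempfile
from pathlib import Path


def to_audio_bytes(wav, sample_rate: int, **kwargs) -> bytes:
    from audiocraft.data.audio import audio_write

    with tempfile.TemporaryDirectory() as tmp_dir:
        # audio_write appends the format suffix to the stem it is given
        # and returns the path it wrote, so we read that file back as bytes
        stem = Path(tmp_dir) / "clip"
        written_path = audio_write(stem, wav, sample_rate, **kwargs)
        return written_path.read_bytes()
```
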