Commit 82e03fa: "finish prose, simplify Gradio"
charlesfrye committed Jan 17, 2025 (parent: 1bcb1d1)
1 changed file: 06_gpu_and_ml/text-to-audio/musicgen.py (138 additions, 58 deletions)

# # Create your own music samples with MusicGen

# MusicGen is a popular open-source music-generation model family from Meta.
# In this example, we show you how you can run MusicGen models on Modal GPUs,
# along with a Gradio UI for playing around with the model.

# We use [Audiocraft](https://github.com/facebookresearch/audiocraft),
# the inference library released by Meta
# for MusicGen and its kin, like AudioGen.

# ## Setting up dependencies

from pathlib import Path
from uuid import uuid4

import modal

# We start by defining the environment our generation runs in.
# This takes some explaining since, like most cutting-edge ML environments, it is a bit fiddly.

# This environment is captured by a
# [container image](https://modal.com/docs/guide/custom-container),
# which we build step-by-step by calling methods to add dependencies,
# like `apt_install` to add system packages and `pip_install` to add
# Python packages.

# Note that we don't have to install anything with "CUDA"
# in the name -- the drivers come for free with the Modal environment
# and the rest gets installed by `pip`. That makes our life a lot easier!
# If you want to see the details, check out [this guide](https://modal.com/docs/guide/gpu)
# in our docs.

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git", "ffmpeg")
    .pip_install(
        "huggingface_hub[hf_transfer]==0.27.1",  # speed up model downloads
        "torch==2.1.0",  # version pinned by audiocraft
        "numpy<2",  # defensively cap the numpy version
        "git+https://github.com/facebookresearch/audiocraft@v1.3.0",  # we can install directly from GitHub!
    )
)

# In addition to source code, we'll also need the model weights.

# Audiocraft integrates with the Hugging Face ecosystem, so setting up the models
# is straightforward -- the same `get_pretrained` method we use to load the weights for execution
# will also download them if they aren't present.


def load_model(and_return=False):
    from audiocraft.models import MusicGen

    model_large = MusicGen.get_pretrained("facebook/musicgen-large")
    if and_return:
        return model_large


# But Modal Functions are serverless: instances spin down when they aren't being used.
# If we want to avoid downloading the weights every time we start a new instance,
# we need to store the weights somewhere besides our local filesystem.

# So we add a Modal [Volume](https://modal.com/docs/guide/volumes)
# to store the weights in the cloud.

cache_dir = "/cache"
model_cache = modal.Volume.from_name(
"audiocraft-model-cache", create_if_missing=True
)

# We don't need to change any of the model loading code --
# we just need to make sure the model gets stored in the right directory.

# To do that, we set an environment variable that Hugging Face expects
# (and another one that speeds up downloads, for good measure)
# and then run the `load_model` Python function.

image = image.env(
    {"HF_HUB_CACHE": cache_dir, "HF_HUB_ENABLE_HF_TRANSFER": "1"}
).run_function(load_model, volumes={cache_dir: model_cache})

# While we're at it, let's also define the environment for our UI.
# We'll stick with Python and so use FastAPI and Gradio.

web_image = modal.Image.debian_slim(python_version="3.11").pip_install(
"fastapi[standard]==0.115.4", "gradio==4.44.1"
)

# This is a totally different environment from the one we run our model in.
# Say goodbye to Python dependency conflict hell!

# ## Running music generation on Modal

# Now, we write our music generation logic.
# This is a bit complicated because we want to support generating long samples,
# but the model has a maximum context length of thirty seconds.
# We can get longer clips by feeding the model's output back as input,
# auto-regressively, but we have to write that ourselves.
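
# For intuition, here is a rough sketch of the segment accounting (not code from
# this example), assuming each segment after the first re-generates `overlap`
# seconds of already-produced audio as context:

# ``` python
# def plan_segments(duration: int, overlap: int, max_segment: int = 30) -> list[int]:
#     """Lengths the model is asked to generate on each pass (a hypothetical helper)."""
#     produced, segments = 0, []
#     while produced < duration:
#         want = duration - produced + (overlap if segments else 0)
#         segment = min(want, max_segment)
#         produced += segment if not segments else segment - overlap
#         segments.append(segment)
#     return segments
#
# plan_segments(60, 15)  # -> [30, 30, 30]: 30s of new audio, then 15s, then 15s
# ```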

# There are also a few bits to make this work well with Modal:

# - We make an [App](https://modal.com/docs/guide/apps) to organize our deployment
# - We load the model at container start, instead of at inference time, with `modal.enter`,
# which requires that we use a Modal [`Cls`](https://modal.com/docs/guide/lifecycle-functions)
# - In the `app.cls` decorator, we specify the Image we built and attach the Volume.
# We also pick a GPU to run on -- here, an NVIDIA L40S.

app = modal.App("example-musicgen")
MAX_SEGMENT_DURATION = 30 # maximum context window size


@app.cls(gpu="l40s", image=image, volumes={cache_dir: model_cache})
class MusicGen:
    @modal.enter()
    def init(self):
        self.model = load_model(and_return=True)

    @modal.method()
    def generate(
        self,
        prompt: str,
        duration: int = 10,
        overlap: int = 10,
        format: str = "wav",  # or mp3
    ) -> bytes:
        f"""Generate a music clip based on the prompt.
        # ... (unchanged lines collapsed in this diff view: the rest of the docstring and the start of the generation loop)
            segment_duration = min(segment_duration, MAX_SEGMENT_DURATION)

            # generate next segment
            generated_duration = (
                segment_duration
                if context is None
                else (segment_duration - overlap)
            )
            print(f"🎼 generating {generated_duration} seconds of music")
            self.model.set_generation_params(duration=segment_duration)
            next_segment = self._generate_next_segment(prompt, context, overlap)

            # update remaining duration
            remaining_duration -= generated_duration

            # combine with previous segments
            context = self._combine_segments(context, next_segment, overlap)

        output = context.detach().cpu().float()[0]

        return to_audio_bytes(
            output,
            self.model.sample_rate,
            format=format,
            # for more on audio encoding parameters, see the docs for audiocraft
            strategy="loudness",
            loudness_compressor=True,
        )

    def _generate_next_segment(self, prompt, context, overlap):
        """Generate the next audio segment, either fresh or as a continuation of a context."""
        if context is None:
            return self.model.generate(descriptions=[prompt])
        else:
            overlap_samples = overlap * self.model.sample_rate
            last_chunk = context[:, :, -overlap_samples:]  # B, C, T
            return self.model.generate_continuation(
                last_chunk, self.model.sample_rate, descriptions=[prompt]
            )

    def _combine_segments(self, context, next_segment, overlap: int):
        """Combine context with next segment, handling overlap."""
        import torch

        if context is None:
            return next_segment

        # calculate where to trim the context (removing the overlap)
        overlap_samples = overlap * self.model.sample_rate
        context_trimmed = context[:, :, :-overlap_samples]  # B, C, T

        return torch.cat([context_trimmed, next_segment], dim=2)


# We can then generate music from anywhere by running code like what we have in the `local_entrypoint` below.

# You can execute it with a command like:

# ``` shell
# modal run musicgen.py --prompt="Baroque boy band, Bachstreet Boys, basso continuo, Top 40 pop music" --duration=60
# ```


@app.local_entrypoint()
def main(
    prompt: str = None,
    duration: int = 10,
    overlap: int = 15,
    format: str = "wav",  # or mp3
):
    if prompt is None:
        prompt = "Amapiano polka, klezmers, log drum bassline, 112 BPM"
    print(
        f"🎼 generating {duration} seconds of music from prompt '{prompt[:64] + ('...' if len(prompt) > 64 else '')}'"
    )

audiocraft = MusicGen()
    # ... (unchanged lines collapsed in this diff view: the remote `generate` call that produces `clip` and the output directory `dir`)
dir.mkdir(exist_ok=True, parents=True)

output_path = dir / f"{slugify(prompt)[:64]}.{format}"
print(f"🎵 Saving to {output_path}")
print(f"🎼 Saving to {output_path}")
output_path.write_bytes(clip)


# ## Hosting a web UI for the music generator

# With the Gradio library, we can create a simple web UI in Python
# that calls out to our music generator,
# then host it on Modal for anyone to try out.

# To deploy both the music generator and the UI, run

# ``` shell
# modal deploy musicgen.py
# ```

# Share the URL with your friends and they can generate their own songs!
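
# Once it's deployed, you aren't limited to the UI: any Python environment with your
# Modal credentials can call the generator directly. A minimal sketch (assuming the
# deployed app keeps the name `example-musicgen` used above) might look like:

# ``` python
# from pathlib import Path
#
# import modal
#
# MusicGenRemote = modal.Cls.from_name("example-musicgen", "MusicGen")
# clip = MusicGenRemote().generate.remote(
#     "lo-fi synthwave for debugging sessions", duration=20, format="mp3"
# )
# Path("clip.mp3").write_bytes(clip)
# ```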


@app.function(
# ... (unchanged lines collapsed in this diff view: the rest of the Function configuration and the start of `def ui():`)
api = FastAPI()

    # Since this Gradio app is running from its own container,
    # we make a `.remote` call to the music generator.
model = MusicGen()
generate = model.generate.remote

temp_dir = Path("/dev/shm")

    async def generate_music(
        prompt: str, duration: int = 10, format: str = "wav"
    ):
        audio_bytes = await generate.aio(
            prompt, duration=duration, format=format
        )

        audio_path = temp_dir / f"{uuid4()}.{format}"
        audio_path.write_bytes(audio_bytes)

        return audio_path

with gr.Blocks(theme="soft") as demo:
gr.Markdown("# MusicGen")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt")
                duration = gr.Number(
                    label="Duration (seconds)", value=10, minimum=1, maximum=300
                )
format = gr.Radio(["wav", "mp3"], label="Format", value="wav")
btn = gr.Button("Generate")
        # ... (unchanged lines collapsed in this diff view: the output audio component and the button click wiring)
return mount_gradio_app(app=api, blocks=demo, path="/")


# ## Addenda

# The remainder of the code here is not directly related to Modal
# or to music generation, but is used in the example above.


def to_audio_bytes(wav, sample_rate: int, **kwargs) -> bytes:
from audiocraft.data.audio import audio_write
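
    # NOTE: the body of this helper is collapsed in the diff view; what follows is
    # a sketch of one way it can work, not necessarily the committed implementation.
    # `audio_write` encodes a tensor and writes it to disk, so we round-trip
    # through a temporary file to get the encoded bytes back.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        written_path = audio_write(  # audio_write returns the path it wrote, suffix included
            Path(tmp_dir) / "output", wav, sample_rate, **kwargs
        )
        return written_path.read_bytes()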

# ... (remaining lines collapsed in this diff view, including the `slugify` helper used above)
