Merge pull request #45 from krasserm/wip-audio-generation
Symbolic audio generation using Perceiver AR
krasserm authored May 8, 2023
2 parents f61d88a + 8a54908 commit c519d6b
Showing 39 changed files with 17,644 additions and 14,491 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/code-format.yml
```diff
@@ -15,16 +15,17 @@ jobs:
       with:
         fetch-depth: 0
         lfs: true
-    - uses: actions/setup-python@v2
+    - uses: actions/setup-python@v4
       with:
         python-version: "3.10"
     - name: set PY
       run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV
     - uses: actions/cache@v2
       with:
         path: ~/.cache/pre-commit
         key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }}
-    - uses: pre-commit/action@v2.0.3
+    - uses: pre-commit/action@v3.0.0
       # this action also provides an additional behaviour when used in private repositories
       # when configured with a github token, the action will push back fixes to the pull request branch
       with:
```
44 changes: 44 additions & 0 deletions .github/workflows/docker.yml
@@ -0,0 +1,44 @@

```yaml
name: Build and publish Docker image

on:
  push:
    branches:
      - "main"
    tags:
      - "*"

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  docker:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set release tag
        id: vars
        run: |
          TAG_NAME=${{github.ref_name}}
          echo "tag=${TAG_NAME/main/latest}" >> $GITHUB_OUTPUT

      - name: Print release tag
        run: |
          echo Building Docker images with tag: ${{ steps.vars.outputs.tag }}

      - name: Log in to the Container registry
        uses: docker/login-action@v2
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Perceiver IO Docker image
        uses: docker/build-push-action@v4
        with:
          file: Dockerfile
          context: .
          push: true
          tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.vars.outputs.tag }}
```
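The bash substitution `${TAG_NAME/main/latest}` in the `Set release tag` step maps builds of the `main` branch to the `latest` image tag, while git tags are published under their own name. A minimal sketch of pulling the resulting image, assuming the workflow runs in the `krasserm/perceiver-io` repository:

```shell
# Image path follows ${REGISTRY}/${IMAGE_NAME}:${tag} from the workflow above.
docker pull ghcr.io/krasserm/perceiver-io:latest
```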
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
```diff
@@ -49,3 +49,10 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
+
+  - repo: https://github.com/python-poetry/poetry
+    rev: 1.2.1
+    hooks:
+      - id: poetry-check
+      - id: poetry-lock
+        args: [ "--check" ]
```
28 changes: 27 additions & 1 deletion README.md
@@ -62,7 +62,7 @@ and the 🤗 Perceiver UTF-8 bytes tokenizer.
### Via pip

```diff
-pip install perceiver-io[text,vision]
+pip install perceiver-io[text,vision,audio]
```

### From sources
@@ -102,6 +102,8 @@ See [Docker image](docs/docker-image.md) for details.

### Inference

#### Optical flow

Compute the optical flow between consecutive frames of an input video and write the rendered results to an output
video:

@@ -136,6 +138,30 @@ Here is a side-by-side comparison of the input and output video:
<img src="docs/images/optical-flow.gif" alt="optical-flow-sbs">
</p>

#### Symbolic audio generation

Create audio sequences by generating symbolic ([MIDI](https://en.wikipedia.org/wiki/MIDI)) audio data and rendering
the generated audio symbols to WAV output with [fluidsynth](https://www.fluidsynth.org/) (_Note:_ fluidsynth must be
installed for the following example to work):

```python
from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic # auto-class registration

repo_id = "krasserm/perceiver-ar-sam-giant-midi"

prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)

output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)

with open("generated_audio.wav", "wb") as f:
f.write(output["generated_audio_wav"])
```
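The example assumes a `prompt.mid` file in the working directory. A minimal sketch for creating one with [pretty_midi](https://github.com/craffel/pretty-midi) (the notes here are arbitrary placeholders, not taken from the original example):

```python
import pretty_midi

# Build a short single-instrument prompt: a C major arpeggio on piano.
pm = pretty_midi.PrettyMIDI()
piano = pretty_midi.Instrument(program=0)  # program 0 = Acoustic Grand Piano

for i, pitch in enumerate([60, 64, 67, 72]):  # C4, E4, G4, C5
    start = i * 0.5
    piano.notes.append(pretty_midi.Note(velocity=100, pitch=pitch, start=start, end=start + 0.5))

pm.instruments.append(piano)
pm.write("prompt.mid")
```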

Examples of generated audio sequences are available on the 🤗 [hub](https://huggingface.co/krasserm/perceiver-ar-sam-giant-midi#audio-samples).

See [inference examples](https://colab.research.google.com/github/krasserm/perceiver-io/blob/main/examples/inference.ipynb)
for more examples.

17 changes: 17 additions & 0 deletions docs/pretrained-models.md
@@ -139,3 +139,20 @@ processor = AutoImageProcessor.from_pretrained(repo_id)

```python
classifier_pipeline = pipeline("image-classification", model=repo_id)
```

### [krasserm/perceiver-ar-symbolic-audio](https://huggingface.co/krasserm/perceiver-ar-symbolic-audio)

A medium Perceiver AR audio model trained on the [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) dataset
in [this training example](training-examples.md#giantmidi-piano).
The model has 134M parameters and was trained for 27 epochs (153M tokens per epoch).

```python
from transformers import pipeline
from perceiver.model.audio.symbolic import PerceiverSymbolicAudioModel

repo_id = "krasserm/perceiver-ar-symbolic-audio"

model = PerceiverSymbolicAudioModel.from_pretrained(repo_id)

audio_generation_pipeline = pipeline("symbolic-audio-generation", model=repo_id)
```
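Generation then follows the same pattern as the README example; a minimal sketch, assuming a `prompt.mid` prompt file and fluidsynth installed for WAV rendering:

```python
from pretty_midi import PrettyMIDI

# Prompt the model with a MIDI file and render the generated
# continuation to WAV (parameters mirror the README example).
prompt = PrettyMIDI("prompt.mid")
output = audio_generation_pipeline(prompt, max_new_tokens=64, do_sample=True, top_p=0.95, render=True)

with open("generated_audio.wav", "wb") as f:
    f.write(output["generated_audio_wav"])
```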
45 changes: 45 additions & 0 deletions docs/training-examples.md
@@ -187,3 +187,48 @@ each.
```shell
bash examples/training/clm/train_fsdp.sh
```

### Symbolic audio modeling

#### Maestro V3

Train a small, randomly initialized Perceiver AR audio model (28.5M parameters) for autoregressive symbolic audio
modeling on the [Maestro V3](https://magenta.tensorflow.org/datasets/maestro#v300) dataset.
This example is configured to run on 2 RTX 3090 GPUs with 24GB of memory each.

- Data prep (command line): [examples/training/sam/maestrov3/prep.sh](../examples/training/sam/maestrov3/prep.sh)
```shell
bash examples/training/sam/maestrov3/prep.sh
```

- Training (command line): [examples/training/sam/maestrov3/train.sh](../examples/training/sam/maestrov3/train.sh)
```shell
bash examples/training/sam/maestrov3/train.sh
```

- Training (Python script): [examples/training/sam/maestrov3/train.py](../examples/training/sam/maestrov3/train.py)
```shell
python examples/training/sam/maestrov3/train.py
```

#### GiantMIDI-Piano

Train a medium, randomly initialized Perceiver AR audio model (~134M parameters) for autoregressive symbolic audio
modeling on the [GiantMIDI-Piano](https://github.com/bytedance/GiantMIDI-Piano) dataset.
This example uses a custom dataset split for the training and validation subsets and is configured to run on
2 RTX 3090 GPUs with 24GB of memory each.

- Data prep (command line): [examples/training/sam/giantmidi/prep.sh](../examples/training/sam/giantmidi/prep.sh)
```shell
bash examples/training/sam/giantmidi/prep.sh
```

- Training (command line): [examples/training/sam/giantmidi/train.sh](../examples/training/sam/giantmidi/train.sh)
```shell
bash examples/training/sam/giantmidi/train.sh
```

- Training (Python script): [examples/training/sam/giantmidi/train.py](../examples/training/sam/giantmidi/train.py)
```shell
python examples/training/sam/giantmidi/train.py
```
7 changes: 7 additions & 0 deletions examples/convert.py
```diff
@@ -1,6 +1,7 @@
 import os

 import jsonargparse
+from perceiver.model.audio import symbolic as sam

 from perceiver.model.text import classifier as txt_clf, clm, mlm
 from perceiver.model.vision import image_classifier as img_clf, optical_flow as opt_flow
```

```diff
@@ -57,6 +58,12 @@ def convert_training_checkpoints(output_dir, **kwargs):
         **kwargs,
     )

+    sam.convert_checkpoint(
+        save_dir=os.path.join(output_dir, "perceiver-ar-sam-giant-midi"),
+        ckpt_url=checkpoint_url("sam/version_1/checkpoints/epoch=027-val_loss=1.944.ckpt"),
+        **kwargs,
+    )
+

 if __name__ == "__main__":
     parser = jsonargparse.ArgumentParser(description="Convert official models and training checkpoint")
```
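After conversion, the checkpoint can be loaded like any other 🤗 model directory. A minimal sketch, assuming the script was run with a hypothetical `output` directory (the actual location is the `output_dir` argument):

```python
from perceiver.model.audio.symbolic import PerceiverSymbolicAudioModel

# "output" is a hypothetical output_dir; use whatever directory
# convert_training_checkpoints was called with.
model = PerceiverSymbolicAudioModel.from_pretrained("output/perceiver-ar-sam-giant-midi")
```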