-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
1,997 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
# Music deepfake detector | ||
# AI-music detection study | ||
Code repository of our research paper on AI-generated music detection - D. Afchar, G. Meseguer Brocal, R. Hennequin (2024). | ||
|
||
_To be added after the conference reviews._ | ||
The FMA dataset is available at [github.com/mdeff/fma](https://github.com/mdeff/fma). | ||
|
||
_Most of our experiment code is available for review. We will open-source the trained weights upon publication._ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch | ||
|
||
|
||
class AE:
    """Common interface for the autoencoders / codecs used in this study.

    Subclasses override ``encode`` and ``decode`` (and optionally
    ``autoencode_multi``); ``autoencode`` simply chains the two.
    """

    def __init__(self, name):
        """Store ``name``, the label identifying this autoencoder."""
        self.name = name

    def encode(self, x):
        """Encode ``x`` into a latent. We assume a channel-first input.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Encoder")

    def decode(self, z):
        """Decode the latent ``z`` back to a signal.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Decoder")

    def map_stack(self, x, func):
        """Apply ``func`` to ``x``, slice by slice along the first dimension.

        A 1-D ``x`` is passed to ``func`` as is and its result returned
        directly. Otherwise ``func`` is applied to every ``x[c]`` and the
        results are combined with ``torch.stack``.
        """
        if len(x.shape) == 1:
            return func(x)
        return torch.stack([func(x[c]) for c in range(x.shape[0])])

    def autoencode(self, x):
        """Return ``decode(encode(x))``."""
        return self.decode(self.encode(x))

    def autoencode_multi(self, x, codec):
        """Autoencode ``x`` once per entry of ``codec``.

        Raises:
            NotImplementedError: always; subclasses that support several
                codec settings must override.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class:
        # calling it raises TypeError. Raise the real exception instead.
        raise NotImplementedError("Multi-codec")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
DAC / LAC | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
# All paths in this module are built from the current working directory, so
# they only resolve when the process is started from the directory that
# contains "pretrained/".
cur_dir = os.getcwd() | ||
# Extend sys.path with the LAC checkout before the project imports that follow
# (presumably so the package's own absolute imports resolve -- TODO confirm).
sys.path.append(os.path.join(cur_dir, "pretrained/lac")) | ||
|
||
from ae_models.ae import AE | ||
from pretrained.lac.lac.model.lac import LAC | ||
|
||
# Checkpoint handed to LAC.load() in Lac_ae.__init__.
# NOTE(review): the file sits under pretrained/vampnet, not pretrained/lac.
lac_ckpt_path = os.path.join(cur_dir, "pretrained/vampnet/codec.pth") | ||
|
||
class Lac_ae(AE):
    """Autoencoder backed by a pretrained LAC model loaded from ``lac_ckpt_path``."""

    def __init__(self, sr, device = "cuda"):
        """Load the LAC checkpoint, switch it to eval mode and move it to ``device``.

        Args:
            sr: sample rate forwarded to the model's ``preprocess`` and ``encode``.
            device: device the model is moved to.
        """
        super().__init__("LAC")
        self.sr = sr
        self.device = device

        self.model = LAC.load(lac_ckpt_path)
        self.model.eval()
        self.model.to(self.device)

    def _run_encoder(self, x):
        """Preprocess ``x`` and return the raw output of the model's ``encode``."""
        prepared, _ = self.model.preprocess(x, self.sr)
        # Insert a singleton dimension at position 1 before encoding
        # (presumably the channel axis the model expects -- TODO confirm).
        return self.model.encode(prepared[:, None, :], self.sr)

    def encode(self, x):
        """Return the ``'z'`` entry of the encoder output for ``x``."""
        return self._run_encoder(x)['z']

    def decode(self, z):
        """Decode ``z`` and return the ``'audio'`` entry with size-1 dims squeezed out."""
        audio = self.model.decode(z)['audio']
        return torch.squeeze(audio)

    def autoencode_multi(self, x, codec):
        """Encode ``x`` once, then decode it once per entry of ``codec``.

        For each ``c`` in ``codec`` only ``codes[:, :c, :]`` is kept
        (presumably the first ``c`` codebooks -- TODO confirm) before the
        quantizer rebuilds a latent from it.

        Returns:
            A list with one squeezed audio tensor per entry of ``codec``.
        """
        codes = self._run_encoder(x)['codes']
        return [
            torch.squeeze(
                self.model.decode(
                    self.model.quantizer.from_codes(codes[:, :c, :])[0]
                )['audio']
            )
            for c in codec
        ]
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
""" | ||
Encodec | ||
""" | ||
|
||
import torch | ||
from transformers import AutoProcessor, EncodecModel | ||
|
||
from ae_models.ae import AE | ||
|
||
|
||
class Encodec(AE):
    """
    Encodec autoencoder built on the pretrained ``facebook/encodec_48khz``
    model and processor from ``transformers``.
    """

    def __init__(self, bandwidth, sr, device="cuda"):
        """Load the pretrained model/processor and move the model to ``device``.

        Args:
            bandwidth: value passed as ``bandwidth`` to the model's ``encode``.
            sr: stored on the instance; the processor is always called with
                its own ``sampling_rate``.
            device: device the model and the processed inputs are moved to.
        """
        super().__init__("encodec")
        self.bandwidth = bandwidth
        self.sr = sr
        self.device = device

        self.model = EncodecModel.from_pretrained("facebook/encodec_48khz")
        self.processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
        self.model.eval()
        self.model.to(self.device)

    def preprocesss(self, x):
        """Run the processor on raw audio ``x`` and return its "pt" tensors."""
        return self.processor(
            raw_audio=x,
            sampling_rate=self.processor.sampling_rate,
            return_tensors="pt",
        )

    def encode(self, x, padding_mask):
        """Encode already-processed input values at ``self.bandwidth``."""
        return self.model.encode(x, padding_mask, bandwidth=self.bandwidth)

    def decode(self, z, scales, padding_mask):
        """Decode codes ``z`` with their ``scales``; return the first output item."""
        decoded = self.model.decode(z, scales, padding_mask)
        return decoded[0]

    def _process_and_encode(self, x):
        """Process raw audio, move it to the device and encode it.

        Returns:
            ``(inputs, encoder_outputs)`` so callers can reuse the padding mask.
        """
        inputs = self.processor(
            raw_audio=x,
            sampling_rate=self.processor.sampling_rate,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)
        # explicitly encode the audio inputs; decoding is left to the caller
        encoder_outputs = self.model.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            bandwidth=self.bandwidth,  # 3, 6, 12, 24
        )
        return inputs, encoder_outputs

    def autoencode(self, x):
        """Encode then decode ``x``, cropping the last axis to ``x.shape[1]``."""
        inputs, encoder_outputs = self._process_and_encode(x)
        audio_values = self.model.decode(
            encoder_outputs.audio_codes,
            encoder_outputs.audio_scales,
            inputs["padding_mask"],
        )[0]
        return audio_values[:, :, :x.shape[1]]

    def autoencode_multi(self, x, codec):
        """Encode ``x`` once, then decode it once per entry of ``codec``.

        Each ``c`` in ``codec`` keeps the first ``(c // 3) * 2`` entries along
        dim 2 of the audio codes (presumably ``c`` is a target bandwidth and
        dim 2 indexes codebooks -- TODO confirm).

        Returns:
            A list with one squeezed audio tensor per entry of ``codec``.
        """
        inputs, encoder_outputs = self._process_and_encode(x)
        all_codes = encoder_outputs["audio_codes"]
        return [
            torch.squeeze(
                self.model.decode(
                    all_codes[:, :, :(c // 3) * 2],
                    encoder_outputs["audio_scales"],
                    inputs["padding_mask"],
                )[0]
            )
            for c in codec
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
GriffinLim | ||
""" | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio | ||
|
||
from ae_models.ae import AE | ||
|
||
# Baseline settings for GriffinMel; values passed to its constructor override these.
default_params = dict(
    n_fft=1024,  # FFT size of the spectrogram and of Griffin-Lim
    hop_fft=256,  # hop length between frames
    win_fft=512,  # window length
    griffin_iter=32,  # number of Griffin-Lim iterations
    n_mels=128,  # number of mel bands
)
|
||
class GriffinMel(AE):
    """
    Griffin-Lim + mel scale inverter.

    ``encode`` maps audio to a magnitude mel spectrogram; ``decode`` maps the
    mel spectrogram back to a linear one with a pseudo-inverse of the mel
    filterbank and reconstructs audio with Griffin-Lim.
    """

    def __init__(self, params, sr, device="cuda"):
        """Build the spectrogram, Griffin-Lim and mel-scale transforms.

        Args:
            params: overrides applied on top of ``default_params``; ``None``
                keeps the defaults.
            sr: sample rate used by the mel filterbank.
            device: device all transforms are moved to.
        """
        super().__init__("mel-griffin")

        self.params = dict(default_params)
        self.params.update(params or {})

        self.sr = sr
        self.device = device
        target = torch.device(self.device)

        self.audio2spec = torchaudio.transforms.Spectrogram(
            n_fft=self.params["n_fft"],
            hop_length=self.params["hop_fft"],  # time step of hop/sr
            win_length=self.params["win_fft"],
            window_fn=torch.hann_window,
            power=None,  # keep the complex spectrogram; magnitude is taken in encode
        )
        self.audio2spec.to(device=target, dtype=torch.float32)

        self.spec2audio = torchaudio.transforms.GriffinLim(
            n_fft=self.params["n_fft"],
            n_iter=self.params["griffin_iter"],
            win_length=self.params["win_fft"],
            hop_length=self.params["hop_fft"],
            window_fn=torch.hann_window,
            power=1,  # input is a magnitude (not power) spectrogram
        )
        self.spec2audio.to(device=target, dtype=torch.float32)

        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=self.params["n_mels"],
            sample_rate=self.sr,
            n_stft=self.params["n_fft"] // 2 + 1,
            mel_scale='htk',  # slaney
        )
        self.mel_scaler.to(device=target, dtype=torch.float32)

        # Transposed pseudo-inverse of the mel filterbank, so that it can
        # left-multiply a mel spectrogram in ``inverse_mel_scaler``.
        self.mel_matrix = self.mel_scaler.fb
        self.inv_mel_matrix = torch.linalg.pinv(self.mel_matrix).T

    def inverse_mel_scaler(self, mels_spec):
        """Map a mel spectrogram back to linear-frequency bins."""
        return self.inv_mel_matrix @ mels_spec

    def _encode_mono(self, x):
        """Return the mel-scaled magnitude spectrogram of ``x``."""
        return self.mel_scaler(torch.abs(self.audio2spec(x)))

    def encode(self, x):
        """Encode ``x``: directly when 1-D, else slice by slice along dim 0."""
        return self.map_stack(x, self._encode_mono)

    def _decode_mono(self, z):
        """Invert one mel spectrogram to audio with Griffin-Lim."""
        return self.spec2audio(self.inverse_mel_scaler(z))

    def decode(self, z):
        """Decode mel spectrogram(s) ``z`` back to audio.

        A 2-D ``z`` is a single spectrogram, as produced by ``encode`` for a
        1-D input, and is decoded directly. ``map_stack`` would split it into
        1-D rows, which cannot be multiplied by the inverse mel matrix.
        Higher-rank inputs are decoded slice by slice along dim 0 and stacked.
        """
        if len(z.shape) == 2:
            return self._decode_mono(z)
        return self.map_stack(z, self._decode_mono)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
Used to pass files through the same load/resample/save pipeline without changes | ||
""" | ||
|
||
|
||
from ae_models.ae import AE | ||
|
||
|
||
class Identity(AE):
    """Pass-through autoencoder: both ``encode`` and ``decode`` return their input."""

    def __init__(self):
        super().__init__(name="identity")

    def encode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Musika | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
# Paths are built from the current working directory, so they only resolve
# when the process is started from the directory that contains "pretrained/".
cur_dir = os.getcwd() | ||
# Extend sys.path with the Musika checkout before the project imports that
# follow (presumably so its own absolute imports resolve -- TODO confirm).
sys.path.append(os.path.join(cur_dir, "pretrained/musika")) | ||
|
||
from ae_models.ae import AE | ||
from pretrained.musika.models import Models_functions | ||
from pretrained.musika.parse.parse_decode import parse_args | ||
from pretrained.musika.utils import Utils_functions | ||
from pretrained.musika.utils_encode import UtilsEncode_functions | ||
|
||
# import tensorflow as tf | ||
|
||
# Assigned to args.load_path in Musika_ae.__init__.
checkpoint = os.path.join(cur_dir, "pretrained/musika/checkpoints/techno") | ||
# Assigned to args.dec_path in Musika_ae.__init__.
ae_path = os.path.join(cur_dir, "pretrained/musika/checkpoints/ae") | ||
|
||
|
||
class Musika_ae(AE):
    """Autoencoder backed by the Musika networks under ``pretrained/musika``."""

    def __init__(self):
        """Parse Musika's arguments, point them at the checkpoints and build the networks."""
        super().__init__("Musika")

        musika_args = parse_args()
        musika_args.load_path = checkpoint
        musika_args.dec_path = ae_path
        # musika_args.mixed_precision = False

        # Construction order is kept as in the upstream usage: model factory
        # first, then the two utility objects, then the networks themselves.
        model_factory = Models_functions(musika_args)
        self.U = Utils_functions(musika_args)
        self.UE = UtilsEncode_functions(musika_args)
        self.models_ls = model_factory.get_networks()

    def encode(self, x):
        """Move ``x`` to CPU, transpose it and encode it as a NumPy array."""
        waveform = x.cpu().T.numpy()
        return self.UE.encode_audio(waveform, models_ls=self.models_ls)

    def decode(self, z):
        """Decode the 2-D latent ``z`` to a waveform, returned as a transposed ``torch.Tensor``."""
        # Two leading singleton dimensions are added before decoding.
        batched_latent = z[None, None, :, :]
        waveform = self.U.decode_waveform(
            batched_latent,
            self.models_ls[3],
            self.models_ls[5],
            batch_size=64,
        )
        return torch.Tensor(waveform).T
Oops, something went wrong.