From f1c91a98446c82458aff0eb2f4b599a1e94a12f2 Mon Sep 17 00:00:00 2001 From: paso Date: Sat, 27 Jul 2024 20:09:02 +0200 Subject: [PATCH] added afe extraction code in test --- src/structs/audio.jl | 28 ++++----- test/afe.jl | 140 ++++++++++++++++++++++++++++++++++++++++++ test/usage_example.jl | 2 +- 3 files changed, 152 insertions(+), 18 deletions(-) create mode 100644 test/afe.jl diff --git a/src/structs/audio.jl b/src/structs/audio.jl index 33e5822..852d959 100644 --- a/src/structs/audio.jl +++ b/src/structs/audio.jl @@ -6,18 +6,6 @@ mutable struct Audio sr::Int64 end -# keyword constructor -function Audio(; - data, - sr -) - audio = Audio(; - data, - sr - ) - return audio -end - function Base.show(io::IO, audio::Audio) print(io, "Audio(data: $(length(audio.data)) samples, sr: $(audio.sr) Hz)") end @@ -33,13 +21,19 @@ function Base.display(audio::Audio) end function load_audio(; - fname::AbstractString, - sr::Int64 = 8000, + fname::Union{AbstractString, AbstractVector{Float64}}, + sr::Union{Nothing, Int64} = nothing, norm::Bool = false ) - audio = Audio( - py"load_audio"(fname, sr)... - ) + if fname isa AbstractString + audio = Audio( + py"load_audio"(fname, sr)... + ) + elseif fname isa AbstractVector{Float64} && sr isa Int64 + audio = Audio(fname, sr) + else + throw(ArgumentError("Invalid arguments")) + end # normalize audio if norm && length(audio.data) != 0 diff --git a/test/afe.jl b/test/afe.jl new file mode 100644 index 0000000..914112c --- /dev/null +++ b/test/afe.jl @@ -0,0 +1,140 @@ +using Audio911 + +# ---------------------------------------------------------------------------- # +# utils # +# ---------------------------------------------------------------------------- # +nan_replacer!(x::AbstractArray{Float64}) = replace!(x, NaN => 0.0) + +# ---------------------------------------------------------------------------- # +# audio911 audio features extractor # +# ---------------------------------------------------------------------------- # +function audio911_extractor( + # audio module + wavfile::Union{String, AbstractVector{Float64}}; + sr::Int64=8000, + norm::Bool=true, + speech_detection::Bool=false, + # stft module + stft_length::Union{Int64, Nothing}=nothing, + win_type::Tuple{Symbol, Symbol}=(:hann, :periodic), + win_length::Union{Int64, Nothing}=nothing, + overlap_length::Union{Int64, Nothing}=nothing, + stft_norm::Symbol=:power, # :power, :magnitude, :pow2mag + # mel filterbank module + nbands::Int64=26, + scale::Symbol=:mel_htk, # :mel_htk, :mel_slaney, :erb, :bark + melfb_norm::Symbol=:bandwidth, # :bandwidth, :area, :none + freq_range::Union{Tuple{Int64, Int64}, Nothing}=nothing, + # mel spectrogram module + db_scale::Bool=false, + # mfcc module + ncoeffs::Int64=13, + rectification::Symbol=:log, # :log, :cubic_root + dither::Bool=true, + # deltas module + d_length = 9, + d_matrix = :transposed, # :standard, :transposed + # f0 module + method::Symbol=:nfc, + f0_range::Tuple{Int64, Int64}=(50, 400), + # spectral features module + spect_range::Union{Tuple{Int64, Int64}, Nothing}=nothing, +) + # audio module + audio = load_audio( + fname=wavfile, + sr=sr, + norm=norm, + ); + if speech_detection + audio = speech_detector(audio=audio); + end + + # stft module + if isnothing(stft_length) + stft_length = audio.sr <= 8000 ? 256 : 512 + end + if isnothing(win_length) + win_length = stft_length + end + if isnothing(overlap_length) + overlap_length = round(Int, stft_length / 2) + end + stftspec = get_stft( + audio=audio, + stft_length=stft_length, + win_type=win_type, + win_length=win_length, + overlap_length=overlap_length, + norm=stft_norm + ); + + # mel filterbank module + if isnothing(freq_range) + freq_range = (0, round(Int, audio.sr / 2)) + end + melfb = get_melfb( + stft=stftspec, + nbands=nbands, + scale=scale, + norm=melfb_norm, + freq_range=freq_range + ); + + # mel spectrogram module + melspec = get_melspec( + stft=stftspec, + fbank=melfb, + db_scale=db_scale + ); + + # mfcc module + mfcc = get_mfcc( + source=melspec, + ncoeffs=ncoeffs, + rectification=rectification, + dither=dither, + ); + + # deltas module + deltas = get_deltas( + source=mfcc, + d_length=d_length, + d_matrix=d_matrix + ); + + # f0 module + f0 = get_f0( + source=stftspec, + method=method, + freq_range=f0_range + ); + + # spectral features module + if isnothing(spect_range) + spect_range = freq_range + end + spect = get_spectrals( + source=stftspec, + freq_range=spect_range + ); + + return hcat( + melspec.spec', + mfcc.mfcc', + deltas.delta', + deltas.ddelta', + f0.f0, + spect.centroid, + spect.crest, + spect.entropy, + spect.flatness, + spect.flux, + spect.kurtosis, + spect.rolloff, + spect.skewness, + spect.decrease, + spect.slope, + spect.spread + ); +end diff --git a/test/usage_example.jl b/test/usage_example.jl index 6eac3bc..24b81b1 100644 --- a/test/usage_example.jl +++ b/test/usage_example.jl @@ -38,7 +38,7 @@ stftspec = get_stft( win_type = (:hann, :periodic), win_length=stft_length, overlap_length = round(Int, stft_length / 2), - norm = :power, # :none, :power, :magnitude, :pow2mag + norm = :power, # :power, :magnitude, :pow2mag ); # compute mel spectrogram part 1: create the filterbank