Skip to content

Commit

Permalink
fix: implemented a cleaning step to remove detections above the Nyquist limit
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Nov 24, 2023
1 parent 986cfc4 commit 860e63d
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions batdetect2/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Any, Iterator, List, Optional, Tuple, Union

import librosa
import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
Raises:
FileNotFoundError: Input directory not found.
"""
matches = []
for root, _, filenames in os.walk(ip_dir):
Expand Down Expand Up @@ -269,6 +269,7 @@ def convert_results(
spec_feats,
cnn_feats,
spec_slices,
nyquist_freq: Optional[float] = None,
) -> RunResults:
"""Convert results to dictionary as expected by the annotation tool.
Expand All @@ -284,8 +285,8 @@ def convert_results(
Returns:
dict: Dictionary with results.
"""

pred_dict = format_single_result(
file_id,
time_exp,
Expand All @@ -294,6 +295,14 @@ def convert_results(
params["class_names"],
)

# Remove high frequency detections
if nyquist_freq is not None:
pred_dict["annotation"] = [
pred
for pred in pred_dict["annotation"]
if pred["high_freq"] <= nyquist_freq
]

# combine into final results dictionary
results: RunResults = {
"pred_dict": pred_dict,
Expand Down Expand Up @@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
Args:
results (dict): Results.
op_path (str): Output path.
"""
# make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)):
Expand Down Expand Up @@ -488,7 +496,6 @@ def iterate_over_chunks(
chunk_start : float
Start time of chunk in seconds.
chunk : np.ndarray
"""
nsamples = audio.shape[0]
duration_full = nsamples / samplerate
Expand Down Expand Up @@ -694,7 +701,6 @@ def process_audio_array(
The array is of shape (num_detections, num_features).
spec : torch.Tensor
Spectrogram of the audio used as input.
"""
pred_nms, features, spec = _process_audio_array(
audio,
Expand Down Expand Up @@ -746,6 +752,10 @@ def process_file(
cnn_feats = []
spec_slices = []

# Get original sampling rate: the sampling rate stored in the file scaled by
# the time expansion factor. The factor may be absent from the config, or be
# None/0; in all of those cases fall back to 1 (no time expansion).
# NOTE: the parentheses are required. Without them `or 1` binds to the whole
# product, so a None factor raises TypeError and a 0 factor yields a sampling
# rate of 1, which would make the Nyquist filter discard every detection.
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)

# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
Expand Down Expand Up @@ -793,9 +803,7 @@ def process_file(

if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms)
)
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))

# Merge results from chunks
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
Expand All @@ -815,6 +823,7 @@ def process_file(
spec_feats=spec_feats,
cnn_feats=cnn_feats,
spec_slices=spec_slices,
nyquist_freq=orig_samp_rate / 2,
)

# summarize results
Expand Down

0 comments on commit 860e63d

Please sign in to comment.