From 36d616530a9e03172214f8c64396b6220ea01e82 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Wed, 2 Aug 2023 20:00:56 +0100 Subject: [PATCH 1/3] tests: Added test for using the spec_feature flag --- batdetect2/detector/compute_features.py | 345 +++++++++++++++++------- batdetect2/types.py | 57 +++- batdetect2/utils/detector_utils.py | 4 +- pyproject.toml | 2 +- tests/test_cli.py | 41 +++ tests/test_features.py | 87 ++++++ 6 files changed, 441 insertions(+), 95 deletions(-) create mode 100644 tests/test_features.py diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index 368c2db..6abb145 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -1,22 +1,27 @@ +"""Functions to compute features from predictions.""" +from typing import Dict, Optional + import numpy as np +from batdetect2 import types +from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ + def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): + """Convert spectrogram index to frequency in Hz.""" "" spec_ind = spec_height - spec_ind return round( (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 ) -def extract_spec_slices(spec, pred_nms, params): - """ - Extracts spectrogram slices from spectrogram based on detected call locations. - """ +def extract_spec_slices(spec, pred_nms): + """Extract spectrogram slices from spectrogram. + The slices are extracted based on detected call locations. + """ x_pos = pred_nms["x_pos"] - y_pos = pred_nms["y_pos"] bb_width = pred_nms["bb_width"] - bb_height = pred_nms["bb_height"] slices = [] # add 20% padding either side of call @@ -35,100 +40,258 @@ def extract_spec_slices(spec, pred_nms, params): return slices -def get_feature_names(): - feature_names = [ - "duration", - "low_freq_bb", - "high_freq_bb", - "bandwidth", - "max_power_bb", - "max_power", - "max_power_first", - "max_power_second", - "call_interval", - ] - return feature_names - - -def get_feats(spec, pred_nms, params): +def compute_duration( + prediction: types.Prediction, + **_, +) -> float: + """Compute duration of call in seconds.""" + return round(prediction["end_time"] - prediction["start_time"], 5) + + +def compute_low_freq( + prediction: types.Prediction, + **_, +) -> float: + """Compute lowest frequency in call in Hz.""" + return int(prediction["low_freq"]) + + +def compute_high_freq( + prediction: types.Prediction, + **_, +) -> float: + """Compute highest frequency in call in Hz.""" + return int(prediction["high_freq"]) + + +def compute_bandwidth( + prediction: types.Prediction, + **_, +) -> float: + """Compute bandwidth of call in Hz.""" + return int(prediction["high_freq"] - prediction["low_freq"]) + + +def compute_max_power_bb( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in call in Hz. + + This is the frequency with the maximum power in the bounding box of the + call. """ - Extracts features from spectrogram based on detected call locations. - Condsider re-extracting spectrogram for this to get better temporal resolution. + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + + y_low = max(0, prediction["y_pos"]) + y_high = min( + spec.shape[0] - 1, prediction["y_pos"] + prediction["bb_height"] + ) + + spec_bb = spec[y_low:y_high, x_start:x_end] + power_per_freq_band = np.sum(spec_bb, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in during the call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + spec_call = spec[:, x_start:x_end] + power_per_freq_band = np.sum(spec_call, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power_first( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in first half of call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + spec_call = spec[:, x_start:x_end] + first_half = spec_call[:, : int(spec_call.shape[1] / 2)] + power_per_freq_band = np.sum(first_half, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power_second( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in second half of call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + spec_call = spec[:, x_start:x_end] + second_half = spec_call[:, int(spec_call.shape[1] / 2) :] + power_per_freq_band = np.sum(second_half, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_call_interval( + prediction: types.Prediction, + previous_prediction: Optional[types.Prediction] = None, + **_, +) -> float: + """Compute time between this call and the previous call in seconds.""" + if previous_prediction is None: + return np.nan + return round(prediction["start_time"] - previous_prediction["end_time"], 5) + + +# NOTE: The order of the features in this dictionary is important. The +# features are extracted in this order and the order of the columns in the +# output csv file is determined by this order. In order to avoid breaking +# changes in the output csv file, new features should be added to the end of +# this dictionary. +FEATURES: Dict[str, types.FeatureExtractor] = { + "duration": compute_duration, + "low_freq_bb": compute_low_freq, + "high_freq_bb": compute_high_freq, + "bandwidth": compute_bandwidth, + "max_power_bb": compute_max_power_bb, + "max_power": compute_max_power, + "max_power_first": compute_max_power_first, + "max_power_second": compute_max_power_second, + "call_interval": compute_call_interval, +} + + +def get_feats( + spec: np.ndarray, + pred_nms: types.PredictionResults, + params: types.FeatureExtractionParameters, +): + """Extract features from spectrogram based on detected call locations. + + The features extracted are: + + - duration: duration of call in seconds + - low_freq: lowest frequency in call in kHz + - high_freq: highest frequency in call in kHz + - bandwidth: high_freq - low_freq + - max_power_bb: frequency with maximum power in call in kHz + - max_power: frequency with maximum power in spectrogram in kHz + - max_power_first: frequency with maximum power in first half of call in + kHz. + - max_power_second: frequency with maximum power in second half of call in + kHz. + - call_interval: time between this call and the previous call in seconds + + Consider re-extracting spectrogram for this to get better temporal + resolution. For more possible features check out: https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt - """ - x_pos = pred_nms["x_pos"] - y_pos = pred_nms["y_pos"] - bb_width = pred_nms["bb_width"] - bb_height = pred_nms["bb_height"] + Parameters + ---------- + spec : np.ndarray + Spectrogram from which to extract features. - feature_names = get_feature_names() + pred_nms : types.PredictionResults + Information about detected calls from which to extract features. + + params : types.FeatureExtractionParameters + Parameters for feature extraction. + + Returns + ------- + features : np.ndarray + Extracted features for each detected call. Shape is + (num_detections, num_features). + """ num_detections = len(pred_nms["det_probs"]) - features = ( - np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1 - ) + features = np.empty((num_detections, len(FEATURES)), dtype=np.float32) + previous = None - for ff in range(num_detections): - x_start = int(np.maximum(0, x_pos[ff])) - x_end = int( - np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff])) - ) - # y low is the lowest freq but it will have a higher value due to array starting at 0 at top - y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff])) - y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff]))) - spec_slice = spec[:, x_start:x_end] - - if spec_slice.shape[1] > 1: - features[ff, 0] = round( - pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5 - ) - features[ff, 1] = int(pred_nms["low_freqs"][ff]) - features[ff, 2] = int(pred_nms["high_freqs"][ff]) - features[ff, 3] = int( - pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff] - ) - features[ff, 4] = int( - convert_int_to_freq( - y_high + spec_slice[y_high:y_low, :].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - features[ff, 5] = int( - convert_int_to_freq( - spec_slice.sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - hlf_val = spec_slice.shape[1] // 2 - - features[ff, 6] = int( - convert_int_to_freq( - spec_slice[:, :hlf_val].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - features[ff, 7] = int( - convert_int_to_freq( - spec_slice[:, hlf_val:].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) + for row in range(num_detections): + prediction: types.Prediction = { + "det_prob": float(pred_nms["det_probs"][row]), + "class_prob": pred_nms["class_probs"][:, row], + "start_time": float(pred_nms["start_times"][row]), + "end_time": float(pred_nms["end_times"][row]), + "low_freq": float(pred_nms["low_freqs"][row]), + "high_freq": float(pred_nms["high_freqs"][row]), + "x_pos": int(pred_nms["x_pos"][row]), + "y_pos": int(pred_nms["y_pos"][row]), + "bb_width": int(pred_nms["bb_width"][row]), + "bb_height": int(pred_nms["bb_height"][row]), + } + + for col, feature in enumerate(FEATURES.values()): + features[row, col] = feature( + prediction, + previous=previous, + spec=spec, + **params, ) - if ff > 0: - features[ff, 8] = round( - pred_nms["start_times"][ff] - - pred_nms["start_times"][ff - 1], - 5, - ) + previous = prediction return features + + +def get_feature_names(): + """Get names of features in the order they are extracted.""" + return list(FEATURES.keys()) diff --git a/batdetect2/types.py b/batdetect2/types.py index 3bc810b..2941c51 100644 --- a/batdetect2/types.py +++ b/batdetect2/types.py @@ -1,5 +1,5 @@ """Types used in the code base.""" -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Union import numpy as np import torch @@ -25,10 +25,13 @@ __all__ = [ "Annotation", "DetectionModel", + "FeatureExtractionParameters", + "FeatureExtractor", "FileAnnotations", "ModelOutput", "ModelParameters", "NonMaximumSuppressionConfig", + "Prediction", "PredictionResults", "ProcessingConfiguration", "ResultParams", @@ -312,6 +315,40 @@ class ModelOutput(NamedTuple): """Tensor with intermediate features.""" +class Prediction(TypedDict): + """Singe prediction.""" + + det_prob: float + """Detection probability.""" + + x_pos: int + """X position of the detection in pixels.""" + + y_pos: int + """Y position of the detection in pixels.""" + + bb_width: int + """Width of the detection in pixels.""" + + bb_height: int + """Height of the detection in pixels.""" + + start_time: float + """Start time of the detection in seconds.""" + + end_time: float + """End time of the detection in seconds.""" + + low_freq: float + """Low frequency of the detection in Hz.""" + + high_freq: float + """High frequency of the detection in Hz.""" + + class_prob: np.ndarray + """Vector holding the probability of each class.""" + + class PredictionResults(TypedDict): """Results of the prediction. @@ -418,6 +455,16 @@ class NonMaximumSuppressionConfig(TypedDict): """Threshold for detection probability.""" +class FeatureExtractionParameters(TypedDict): + """Parameters that control the feature extraction function.""" + + min_freq: int + """Minimum frequency to consider in Hz.""" + + max_freq: int + """Maximum frequency to consider in Hz.""" + + class HeatmapParameters(TypedDict): """Parameters that control the heatmap generation function.""" @@ -473,3 +520,11 @@ class AnnotationGroup(TypedDict): y_inds: NotRequired[np.ndarray] """Y coordinate of the annotations in the spectrogram.""" + + +class FeatureExtractor(Protocol): + """Protocol for feature extractors.""" + + def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]: + """Extract features from a prediction.""" + ... diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 8074d80..74193b3 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -773,7 +773,7 @@ def process_file( ) # convert to numpy - spec_np = spec.detach().cpu().numpy() + spec_np = spec.detach().cpu().numpy().squeeze() # add chunk time to start and end times pred_nms["start_times"] += chunk_time @@ -794,7 +794,7 @@ def process_file( if config["spec_slices"]: # FIX: This is not currently working. Returns empty slices spec_slices.extend( - feats.extract_spec_slices(spec_np, pred_nms, config) + feats.extract_spec_slices(spec_np, pred_nms) ) # Merge results from chunks diff --git a/pyproject.toml b/pyproject.toml index 7d58874..2570dc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ build-backend = "pdm.pep517.api" batdetect2 = "batdetect2.cli:cli" [tool.black] -line-length = 80 +line-length = 79 [[tool.mypy.overrides]] module = [ diff --git a/tests/test_cli.py b/tests/test_cli.py index 767be7e..4038533 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,7 @@ """Test the command line interface.""" +from pathlib import Path from click.testing import CliRunner +import pandas as pd from batdetect2.cli import cli @@ -67,3 +69,42 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path): assert result.exit_code == 0 assert 'Time Expansion Factor: 10' in result.stdout + + + +def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path): + """Test the detect command with the spec feature flag.""" + results_dir = tmp_path / "results" + + # Remove results dir if it exists + if results_dir.exists(): + results_dir.rmdir() + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + "--spec_features", + ], + ) + assert result.exit_code == 0 + assert results_dir.exists() + + + csv_files = [path.name for path in results_dir.glob("*.csv")] + + expected_files = [ + "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv", + "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv", + "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv" + ] + + for expected_file in expected_files: + assert expected_file in csv_files + + df = pd.read_csv(results_dir / expected_file) + assert not (df.duration == -1).any() diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 0000000..0394337 --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,87 @@ +"""Test suite for feature extraction functions.""" + +import numpy as np + +import batdetect2.detector.compute_features as feats +from batdetect2 import types + + +def index_to_freq( + index: int, + spec_height: int, + min_freq: int, + max_freq: int, +) -> float: + """Convert spectrogram index to frequency in Hz.""" + index = spec_height - index + return round( + (index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 + ) + + +def index_to_time( + index: int, + spec_width: int, + spec_duration: float, +) -> float: + """Convert spectrogram index to time in seconds.""" + return round((index / float(spec_width)) * spec_duration, 2) + + +def test_get_feats_function_with_empty_spectrogram(): + spec_duration = 3 + spec_width = 100 + spec_height = 100 + min_freq = 10_000 + max_freq = 120_000 + spectrogram = np.zeros((spec_height, spec_width)) + + x_pos = 20 + y_pos = 80 + bb_width = 20 + bb_height = 20 + + start_time = index_to_time(x_pos, spec_width, spec_duration) + end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration) + high_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) + low_freq = index_to_freq(y_pos + bb_height, spec_height, min_freq, max_freq) + + pred_nms: types.PredictionResults = { + "det_probs": np.array([1]), + "class_probs": np.array([1]), + "x_pos": np.array([x_pos]), + "y_pos": np.array([y_pos]), + "bb_width": np.array([bb_width]), + "bb_height": np.array([bb_height]), + "start_times": np.array([start_time]), + "end_times": np.array([end_time]), + "low_freqs": np.array([low_freq]), + "high_freqs": np.array([high_freq]), + } + + params: types.FeatureExtractionParameters = { + "min_freq": min_freq, + "max_freq": max_freq, + } + + features = feats.get_feats(spectrogram, pred_nms, params) + assert low_freq < high_freq + assert isinstance(features, np.ndarray) + assert features.shape == (len(pred_nms["det_probs"]), 9) + assert np.isclose( + features[0], + np.array( + [ + end_time - start_time, + low_freq, + high_freq, + high_freq - low_freq, + max_freq, + max_freq, + max_freq, + max_freq, + np.nan, + ] + ), + equal_nan=True, + ).all() From 8e8779a72e320aa950435ef682e03adf90b6ed12 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Wed, 2 Aug 2023 20:07:56 +0100 Subject: [PATCH 2/3] fix: call interval kwargs name error --- .gitignore | 2 +- batdetect2/detector/compute_features.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 47ac4bd..024f9a2 100644 --- a/.gitignore +++ b/.gitignore @@ -65,7 +65,7 @@ ipython_config.py # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide -.pdm.toml +.pdm-python # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index 6abb145..e0c3a2d 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -100,7 +100,7 @@ def compute_max_power_bb( max_power_ind = np.argmax(power_per_freq_band) return int( convert_int_to_freq( - max_power_ind, + y_low - max_power_ind, spec.shape[0], min_freq, max_freq, @@ -190,13 +190,13 @@ def compute_max_power_second( def compute_call_interval( prediction: types.Prediction, - previous_prediction: Optional[types.Prediction] = None, + previous: Optional[types.Prediction] = None, **_, ) -> float: """Compute time between this call and the previous call in seconds.""" - if previous_prediction is None: + if previous is None: return np.nan - return round(prediction["start_time"] - previous_prediction["end_time"], 5) + return round(prediction["start_time"] - previous["end_time"], 5) # NOTE: The order of the features in this dictionary is important. The From 3288f52bbd3fbc07105882535b68f4aa6149e383 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Thu, 3 Aug 2023 11:45:39 +0100 Subject: [PATCH 3/3] tests: added tests for feature computation --- batdetect2/detector/compute_features.py | 37 ++-- tests/test_features.py | 214 +++++++++++++++++++++++- 2 files changed, 235 insertions(+), 16 deletions(-) diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index e0c3a2d..b53b0cb 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -88,19 +88,28 @@ def compute_max_power_bb( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) - - y_low = max(0, prediction["y_pos"]) - y_high = min( - spec.shape[0] - 1, prediction["y_pos"] + prediction["bb_height"] + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] ) - spec_bb = spec[y_low:y_high, x_start:x_end] + # y low is the lowest freq but it will have a higher value due to array + # starting at 0 at top + y_low = min(spec.shape[0] - 1, prediction["y_pos"]) + y_high = max(0, prediction["y_pos"] - prediction["bb_height"]) + + spec_bb = spec[y_high:y_low, x_start:x_end] power_per_freq_band = np.sum(spec_bb, axis=1) - max_power_ind = np.argmax(power_per_freq_band) + + try: + max_power_ind = np.argmax(power_per_freq_band) + except ValueError: + # If the call is too short, the bounding box might be empty. + # In this case, return NaN. + return np.nan + return int( convert_int_to_freq( - y_low - max_power_ind, + y_high + max_power_ind, spec.shape[0], min_freq, max_freq, @@ -120,7 +129,9 @@ def compute_max_power( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] power_per_freq_band = np.sum(spec_call, axis=1) max_power_ind = np.argmax(power_per_freq_band) @@ -146,7 +157,9 @@ def compute_max_power_first( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] first_half = spec_call[:, : int(spec_call.shape[1] / 2)] power_per_freq_band = np.sum(first_half, axis=1) @@ -173,7 +186,9 @@ def compute_max_power_second( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] second_half = spec_call[:, int(spec_call.shape[1] / 2) :] power_per_freq_band = np.sum(second_half, axis=1) diff --git a/tests/test_features.py b/tests/test_features.py index 0394337..1271fda 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,9 +1,17 @@ """Test suite for feature extraction functions.""" +import logging + +import librosa import numpy as np +import pytest import batdetect2.detector.compute_features as feats -from batdetect2 import types +from batdetect2 import api, types +from batdetect2.utils import audio_utils as au + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) def index_to_freq( @@ -29,6 +37,11 @@ def index_to_time( def test_get_feats_function_with_empty_spectrogram(): + """Test get_feats function with empty spectrogram. + + This tests that the overall flow of the function works, even if the + spectrogram is empty. + """ spec_duration = 3 spec_width = 100 spec_height = 100 @@ -43,12 +56,14 @@ def test_get_feats_function_with_empty_spectrogram(): start_time = index_to_time(x_pos, spec_width, spec_duration) end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration) - high_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) - low_freq = index_to_freq(y_pos + bb_height, spec_height, min_freq, max_freq) + low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) + high_freq = index_to_freq( + y_pos - bb_height, spec_height, min_freq, max_freq + ) pred_nms: types.PredictionResults = { "det_probs": np.array([1]), - "class_probs": np.array([1]), + "class_probs": np.array([[1]]), "x_pos": np.array([x_pos]), "y_pos": np.array([y_pos]), "bb_width": np.array([bb_width]), @@ -76,7 +91,7 @@ def test_get_feats_function_with_empty_spectrogram(): low_freq, high_freq, high_freq - low_freq, - max_freq, + high_freq, max_freq, max_freq, max_freq, @@ -85,3 +100,192 @@ def test_get_feats_function_with_empty_spectrogram(): ), equal_nan=True, ).all() + + +@pytest.mark.parametrize( + "max_power", + [ + 30_000, + 31_000, + 32_000, + 33_000, + 34_000, + 35_000, + 36_000, + 37_000, + 38_000, + 39_000, + 40_000, + ], +) +def test_compute_max_power_bb(max_power: int): + """Test compute_max_power_bb function.""" + duration = 1 + samplerate = 256_000 + min_freq = 0 + max_freq = 128_000 + + start_time = 0.3 + end_time = 0.6 + low_freq = 30_000 + high_freq = 40_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + 80_000, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq) + y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_low), + "bb_width": int(x_end - x_start), + "bb_height": int(y_low - y_high), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + print(prediction) + + max_power_bb = feats.compute_max_power_bb( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(max_power_bb - max_power) <= 500 + + +def test_compute_max_power(): + """Test compute_max_power_bb function.""" + duration = 3 + samplerate = 16_000 + min_freq = 0 + max_freq = 8_000 + + start_time = 1 + end_time = 2 + low_freq = 3_000 + high_freq = 4_000 + max_power = 5_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + 3_500, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = int(num_freq_bins * low_freq / max_freq) + y_high = int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_high), + "bb_width": int(x_end - x_start), + "bb_height": int(y_high - y_low), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + computed_max_power = feats.compute_max_power( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(computed_max_power - max_power) < 100