diff --git a/src/b2aiprep/commands.py b/src/b2aiprep/commands.py
index 43e252f..63f9c6e 100644
--- a/src/b2aiprep/commands.py
+++ b/src/b2aiprep/commands.py
@@ -1,16 +1,18 @@
 """Commands available through the CLI."""
 
+# Refactor to improve readability, maintainability, and scalability
+
 import csv
-from functools import partial
 import json
 import logging
 import os
 import shutil
 import tarfile
 import typing as t
+from functools import partial
 from glob import glob
-from pathlib import Path
 from importlib import resources
+from pathlib import Path
 
 import click
 import numpy as np
@@ -19,6 +21,7 @@
 import pydra
 import torch
 from datasets import Dataset
+from pyarrow.parquet import SortingColumn
 from pydra.mark import annotate
 from senselab.audio.data_structures.audio import Audio
 from senselab.audio.tasks.features_extraction.opensmile import (
@@ -44,9 +47,12 @@ from streamlit.web.bootstrap import run
 from tqdm import tqdm
 
 
-from pyarrow.parquet import SortingColumn
 from b2aiprep.prepare.bids import get_audio_paths, redcap_to_bids, validate_bids_folder
-from b2aiprep.prepare.prepare import extract_features_workflow, validate_bids_data, clean_phenotype_data
+from b2aiprep.prepare.prepare import (
+    clean_phenotype_data,
+    extract_features_workflow,
+    validate_bids_data,
+)
 
 # from b2aiprep.synthetic_data import generate_synthetic_tabular_data
 
@@ -221,7 +227,7 @@ def load_audio_features(
         pt_file = features_dir / f"{wav_path.stem}_features.pt"
         features = torch.load(pt_file)
 
-        output['spectrogram'] = features['torchaudio']['spectrogram'].numpy().astype(np.float32)
+        output["spectrogram"] = features["torchaudio"]["spectrogram"].numpy().astype(np.float32)
         # for feature_name in ["speaker_embedding", "specgram", "melfilterbank", "mfcc", "opensmile"]:
         #     feature_path = features_dir / f"{wav_path.stem}_{feature_name}.{file_extension}"
         #     if not feature_path.exists():
@@ -242,6 +248,7 @@ def load_audio_features(
 
         yield output
 
+
 def spectrogram_generator(
     audio_paths,
 ) -> t.Generator[t.Dict[str, t.Any], None, None]:
@@ -251,19 +258,20 @@ def spectrogram_generator(
         pt_file = wav_path.parent / f"{wav_path.stem}_features.pt"
         features = torch.load(pt_file)
 
-        output['participant_id'] = wav_path.stem.split('_')[0][4:]  # skip "sub-" prefix
-        output['session_id'] = wav_path.stem.split('_')[1][4:]  # skip "ses-" prefix
-        output['task_name'] = wav_path.stem.split('_')[2][5:]  # skip "task-" prefix
-        output['spectrogram'] = features['torchaudio']['spectrogram'].numpy().astype(np.float32)
+        output["participant_id"] = wav_path.stem.split("_")[0][4:]  # skip "sub-" prefix
+        output["session_id"] = wav_path.stem.split("_")[1][4:]  # skip "ses-" prefix
+        output["task_name"] = wav_path.stem.split("_")[2][5:]  # skip "task-" prefix
+        output["spectrogram"] = features["torchaudio"]["spectrogram"].numpy().astype(np.float32)
 
         yield output
 
+
 @click.command()
 @click.argument("bids_path", type=click.Path(exists=True))
 @click.argument("outdir", type=click.Path())
 def create_derived_dataset(bids_path, outdir):
     """Create a derived dataset from voice/phenotype data in BIDS format.
-    
+
     The derived dataset output loads data from generated .pt files, which have
     the following keys:
         - pitch
@@ -292,7 +300,7 @@ def create_derived_dataset(bids_path, outdir):
     audio_paths = sorted(
         audio_paths,
         # sort first by subject, then session, then by task
-        key=lambda x: (x.stem.split('_')[0], x.stem.split('_')[1], x.stem.split('_')[2])
+        key=lambda x: (x.stem.split("_")[0], x.stem.split("_")[1], x.stem.split("_")[2]),
     )
 
     # remove known subjects without any audio
@@ -338,7 +346,7 @@ def create_derived_dataset(bids_path, outdir):
         pt_file = Path(filename.replace(".wav", "_features.pt"))
         if not pt_file.exists():
             continue
-        
+
         for participant_id in SUBJECTS_TO_REMOVE:
             if f"sub-{participant_id}" in str(pt_file):
                 _LOGGER.info(f"Skipping subject {participant_id}")
@@ -354,13 +362,15 @@ def create_derived_dataset(bids_path, outdir):
         transcription = features.get("transcription", None)
         if transcription is not None:
             transcription = transcription.text
-            if subj_info['task_name'].lower().startswith('free-Speech') or \
-                subj_info['task_name'].lower().startswith('audio-check') or \
-                subj_info['task_name'].lower().startswith('open-response-questions'):
+            if (
+                subj_info["task_name"].lower().startswith("free-Speech")
+                or subj_info["task_name"].lower().startswith("audio-check")
+                or subj_info["task_name"].lower().startswith("open-response-questions")
+            ):
                 # we omit tasks where free speech occurs
                 transcription = None
         subj_info["transcription"] = transcription
-        
+
         for key in ["opensmile", "praat_parselmouth", "torchaudio_squim"]:
             subj_info.update(features.get(key, {}))
 
@@ -370,7 +380,9 @@ def create_derived_dataset(bids_path, outdir):
     df_static.to_csv(outdir / "static_features.tsv", sep="\t", index=False)
     # load in the JSON with descriptions of each feature and copy it over
     # write it out again so formatting is consistent between JSONs
-    static_features_json_file = resources.files("b2aiprep").joinpath("prepare", "resources", "static_features.json")
+    static_features_json_file = resources.files("b2aiprep").joinpath(
+        "prepare", "resources", "static_features.json"
+    )
     static_features_json = json.load(static_features_json_file.open())
     with open(outdir / "static_features.json", "w") as f:
         json.dump(static_features_json, f, indent=2)
@@ -382,9 +394,11 @@ def create_derived_dataset(bids_path, outdir):
     df = pd.read_csv(bids_path.joinpath("participants.tsv"), sep="\t")
 
     # remove subject
-    idx = df['record_id'].isin(SUBJECTS_TO_REMOVE)
+    idx = df["record_id"].isin(SUBJECTS_TO_REMOVE)
     if idx.sum() > 0:
-        _LOGGER.info(f"Removing {idx.sum()} records from phenotype due to hard-coded subject removal.")
+        _LOGGER.info(
+            f"Removing {idx.sum()} records from phenotype due to hard-coded subject removal."
+        )
         df = df.loc[~idx]
 
     # temporarily keep record_id as the column name to enable joining the dataframes together
@@ -434,7 +448,9 @@ def create_derived_dataset(bids_path, outdir):
         # add the data elements to the overall phenotype dict
         if len(phenotype_add) != 1:
            # we expect there to only be one key
-            _LOGGER.warning(f"Unexpected keys in phenotype file {phenotype_filepath.stem}: {phenotype_add.keys()}")
+            _LOGGER.warning(
+                f"Unexpected keys in phenotype file {phenotype_filepath.stem}: {phenotype_add.keys()}"
+            )
         else:
            phenotype_add = next(iter(phenotype_add.values()))["data_elements"]
            phenotype.update(phenotype_add)
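
---

Notes on the patch; none of the code below is part of the diff itself.

The relocated `SortingColumn` import suggests the derived tables are written as Parquet with an explicit row ordering recorded in the file metadata. How commands.py actually uses it is outside this diff; the following is a sketch under that assumption, with a toy table standing in for the real data:

    import pyarrow as pa
    import pyarrow.parquet as pq
    from pyarrow.parquet import SortingColumn

    # Toy table ordered by participant, then session, then task.
    table = pa.table(
        {
            "participant_id": ["p1", "p1", "p2"],
            "session_id": ["s1", "s2", "s1"],
            "task_name": ["a", "b", "a"],
        }
    )
    # Declare the sort order in the Parquet metadata so readers can rely on it.
    sorting_columns = SortingColumn.from_ordering(
        table.schema,
        [
            ("participant_id", "ascending"),
            ("session_id", "ascending"),
            ("task_name", "ascending"),
        ],
    )
    with pq.ParquetWriter("derived.parquet", table.schema, sorting_columns=sorting_columns) as writer:
        writer.write_table(table)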
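`spectrogram_generator` yields one flat record per recording, which is the shape `datasets.Dataset.from_generator` expects (the module already imports `Dataset` from Hugging Face `datasets`). A hedged usage sketch; the input directory and output filename are placeholders, not taken from this patch:

    from pathlib import Path

    from datasets import Dataset

    # Hypothetical wiring: materialize the generator lazily into a Hugging
    # Face dataset, then write it out as Parquet.
    audio_paths = sorted(Path("bids_dir").rglob("*.wav"))
    ds = Dataset.from_generator(
        spectrogram_generator,
        gen_kwargs={"audio_paths": audio_paths},
    )
    ds.to_parquet("spectrograms.parquet")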
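One behavior-preserving caveat in the reformatted transcription filter: `task_name` is lowercased before the comparison, so `startswith("free-Speech")` (capital S) can never match, and the free-speech branch is dead for that task. The refactor keeps the original logic; the presumably intended check would lowercase the prefix, and since `str.startswith` accepts a tuple, the three chained calls collapse into one:

    # Assumed fix (not part of this patch): all-lowercase prefixes so the
    # comparison can match, passed as a tuple to a single startswith call.
    FREE_TEXT_PREFIXES = ("free-speech", "audio-check", "open-response-questions")
    if subj_info["task_name"].lower().startswith(FREE_TEXT_PREFIXES):
        # we omit tasks where free speech occurs
        transcription = None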