Refactor commands.py #110

Draft · wants to merge 1 commit into base: main
src/b2aiprep/commands.py (56 changes: 36 additions & 20 deletions)
@@ -1,16 +1,18 @@
"""Commands available through the CLI."""

# Refactor to improve readability, maintainability, and scalability

import csv
from functools import partial
import json
import logging
import os
import shutil
import tarfile
import typing as t
from functools import partial
from glob import glob
from pathlib import Path
from importlib import resources
from pathlib import Path

import click
import numpy as np
@@ -19,6 +21,7 @@
 import pydra
 import torch
 from datasets import Dataset
+from pyarrow.parquet import SortingColumn
 from pydra.mark import annotate
 from senselab.audio.data_structures.audio import Audio
 from senselab.audio.tasks.features_extraction.opensmile import (
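Note: the pyarrow.parquet.SortingColumn import added here is relocated from lower in the file (see the next hunk); its usage site is outside the visible hunks. The following is only a sketch of what the class is for, with table contents, output path, and pyarrow version support (12+) assumed rather than taken from this diff:

import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow.parquet import SortingColumn

table = pa.table({"participant_id": ["a", "a", "b"]})
pq.write_table(
    table,
    "features.parquet",  # placeholder path
    sorting_columns=[SortingColumn(0)],  # records that column 0 is pre-sorted
)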
@@ -44,9 +47,12 @@
 from streamlit.web.bootstrap import run
 from tqdm import tqdm
 
-from pyarrow.parquet import SortingColumn
 from b2aiprep.prepare.bids import get_audio_paths, redcap_to_bids, validate_bids_folder
-from b2aiprep.prepare.prepare import extract_features_workflow, validate_bids_data, clean_phenotype_data
+from b2aiprep.prepare.prepare import (
+    clean_phenotype_data,
+    extract_features_workflow,
+    validate_bids_data,
+)
 
 # from b2aiprep.synthetic_data import generate_synthetic_tabular_data
 
@@ -221,7 +227,7 @@ def load_audio_features(
 
         pt_file = features_dir / f"{wav_path.stem}_features.pt"
         features = torch.load(pt_file)
-        output['spectrogram'] = features['torchaudio']['spectrogram'].numpy().astype(np.float32)
+        output["spectrogram"] = features["torchaudio"]["spectrogram"].numpy().astype(np.float32)
         # for feature_name in ["speaker_embedding", "specgram", "melfilterbank", "mfcc", "opensmile"]:
         #     feature_path = features_dir / f"{wav_path.stem}_{feature_name}.{file_extension}"
         #     if not feature_path.exists():
@@ -242,6 +248,7 @@ def load_audio_features(
 
         yield output
 
+
 def spectrogram_generator(
     audio_paths,
 ) -> t.Generator[t.Dict[str, t.Any], None, None]:
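Both generators consume per-recording feature files saved with torch. A minimal sketch of inspecting one such file (the file name is hypothetical, and any keys beyond "torchaudio"/"spectrogram" are assumptions, not taken from this diff):

import torch

# Placeholder name; real files follow the "<wav stem>_features.pt" pattern
features = torch.load("sub-0001_ses-0001_task-Reading_features.pt")
print(features.keys())  # dict of feature groups, e.g. "torchaudio"
spectrogram = features["torchaudio"]["spectrogram"].numpy()
print(spectrogram.shape)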
@@ -251,19 +258,20 @@ def spectrogram_generator(
         pt_file = wav_path.parent / f"{wav_path.stem}_features.pt"
         features = torch.load(pt_file)
 
-        output['participant_id'] = wav_path.stem.split('_')[0][4:] # skip "sub-" prefix
-        output['session_id'] = wav_path.stem.split('_')[1][4:] # skip "ses-" prefix
-        output['task_name'] = wav_path.stem.split('_')[2][5:] # skip "task-" prefix
-        output['spectrogram'] = features['torchaudio']['spectrogram'].numpy().astype(np.float32)
+        output["participant_id"] = wav_path.stem.split("_")[0][4:]  # skip "sub-" prefix
+        output["session_id"] = wav_path.stem.split("_")[1][4:]  # skip "ses-" prefix
+        output["task_name"] = wav_path.stem.split("_")[2][5:]  # skip "task-" prefix
+        output["spectrogram"] = features["torchaudio"]["spectrogram"].numpy().astype(np.float32)
 
         yield output
 
+
 @click.command()
 @click.argument("bids_path", type=click.Path(exists=True))
 @click.argument("outdir", type=click.Path())
 def create_derived_dataset(bids_path, outdir):
     """Create a derived dataset from voice/phenotype data in BIDS format.
 
     The derived dataset output loads data from generated .pt files, which have
     the following keys:
     - pitch
@@ -292,7 +300,7 @@ def create_derived_dataset(bids_path, outdir):
     audio_paths = sorted(
         audio_paths,
         # sort first by subject, then session, then by task
-        key=lambda x: (x.stem.split('_')[0], x.stem.split('_')[1], x.stem.split('_')[2])
+        key=lambda x: (x.stem.split("_")[0], x.stem.split("_")[1], x.stem.split("_")[2]),
     )
 
     # remove known subjects without any audio
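The sort key above, like the generators earlier, slices BIDS-style file stems of the form sub-<id>_ses-<id>_task-<name>. A small illustration with a made-up stem (not from the PR):

from pathlib import Path

stem = Path("sub-abc123_ses-001_task-Audio-Check.wav").stem
parts = stem.split("_")
print(parts[0][4:])  # "abc123" (drops the 4-char "sub-" prefix)
print(parts[1][4:])  # "001" (drops the 4-char "ses-" prefix)
print(parts[2][5:])  # "Audio-Check" (drops the 5-char "task-" prefix)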
@@ -338,7 +346,7 @@ def create_derived_dataset(bids_path, outdir):
         pt_file = Path(filename.replace(".wav", "_features.pt"))
         if not pt_file.exists():
             continue
-
+
         for participant_id in SUBJECTS_TO_REMOVE:
             if f"sub-{participant_id}" in str(pt_file):
                 _LOGGER.info(f"Skipping subject {participant_id}")
@@ -354,13 +362,15 @@ def create_derived_dataset(bids_path, outdir):
         transcription = features.get("transcription", None)
         if transcription is not None:
             transcription = transcription.text
-        if subj_info['task_name'].lower().startswith('free-Speech') or \
-            subj_info['task_name'].lower().startswith('audio-check') or \
-            subj_info['task_name'].lower().startswith('open-response-questions'):
+        if (
+            subj_info["task_name"].lower().startswith("free-speech")
+            or subj_info["task_name"].lower().startswith("audio-check")
+            or subj_info["task_name"].lower().startswith("open-response-questions")
+        ):
             # we omit tasks where free speech occurs
             transcription = None
         subj_info["transcription"] = transcription
-
+
         for key in ["opensmile", "praat_parselmouth", "torchaudio_squim"]:
             subj_info.update(features.get(key, {}))
 
@@ -370,7 +380,9 @@ def create_derived_dataset(bids_path, outdir):
     df_static.to_csv(outdir / "static_features.tsv", sep="\t", index=False)
     # load in the JSON with descriptions of each feature and copy it over
     # write it out again so formatting is consistent between JSONs
-    static_features_json_file = resources.files("b2aiprep").joinpath("prepare", "resources", "static_features.json")
+    static_features_json_file = resources.files("b2aiprep").joinpath(
+        "prepare", "resources", "static_features.json"
+    )
     static_features_json = json.load(static_features_json_file.open())
     with open(outdir / "static_features.json", "w") as f:
         json.dump(static_features_json, f, indent=2)
@@ -382,9 +394,11 @@ def create_derived_dataset(bids_path, outdir):
     df = pd.read_csv(bids_path.joinpath("participants.tsv"), sep="\t")
 
     # remove subject
-    idx = df['record_id'].isin(SUBJECTS_TO_REMOVE)
+    idx = df["record_id"].isin(SUBJECTS_TO_REMOVE)
     if idx.sum() > 0:
-        _LOGGER.info(f"Removing {idx.sum()} records from phenotype due to hard-coded subject removal.")
+        _LOGGER.info(
+            f"Removing {idx.sum()} records from phenotype due to hard-coded subject removal."
+        )
         df = df.loc[~idx]
 
     # temporarily keep record_id as the column name to enable joining the dataframes together
@@ -434,7 +448,9 @@ def create_derived_dataset(bids_path, outdir):
         # add the data elements to the overall phenotype dict
         if len(phenotype_add) != 1:
             # we expect there to only be one key
-            _LOGGER.warning(f"Unexpected keys in phenotype file {phenotype_filepath.stem}: {phenotype_add.keys()}")
+            _LOGGER.warning(
+                f"Unexpected keys in phenotype file {phenotype_filepath.stem}: {phenotype_add.keys()}"
+            )
         else:
             phenotype_add = next(iter(phenotype_add.values()))["data_elements"]
             phenotype.update(phenotype_add)
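The else-branch implies each phenotype JSON holds exactly one top-level key whose value contains a "data_elements" mapping. A hypothetical sketch of that shape, inferred from the code rather than stated anywhere in the PR:

phenotype_add = {
    "demographics": {  # assumed questionnaire name
        "data_elements": {
            "age": {"description": "Age of the participant"},
        }
    }
}
# mirrors the else-branch above
data_elements = next(iter(phenotype_add.values()))["data_elements"]
print(list(data_elements))  # ['age']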
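For reviewers who want to exercise the command without installing the console script, a hedged usage sketch with click's test runner (both paths are placeholders; bids_path must exist because of click.Path(exists=True)):

from click.testing import CliRunner

from b2aiprep.commands import create_derived_dataset

runner = CliRunner()
result = runner.invoke(create_derived_dataset, ["/data/bids", "/data/derived"])
print(result.exit_code, result.output)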