Complete conversion pipeline: code refactoring #10

Merged: 16 commits, Oct 24, 2024
4 changes: 0 additions & 4 deletions src/cai_lab_to_nwb/zaki_2024/__init__.py
@@ -1,4 +0,0 @@
from .implant_interface import ImplantInterface
from .behaviorinterface import FreezingBehaviorInterface
from .sleepinterface import SleepBehaviorInterface
from .minian_segmentation_interface import MinianSegmentationInterface
60 changes: 0 additions & 60 deletions src/cai_lab_to_nwb/zaki_2024/implant_interface.py

This file was deleted.

5 changes: 5 additions & 0 deletions src/cai_lab_to_nwb/zaki_2024/interfaces/__init__.py
@@ -0,0 +1,5 @@
from .eztrack_interface import EzTrackFreezingBehaviorInterface
from .zaki_2024_edf_interface import Zaki2024EDFInterface
from .minian_interface import MinianSegmentationInterface
from .zaki_2024_sleep_classification_interface import Zaki2024SleepClassificationInterface
from .miniscope_imaging_interface import MiniscopeImagingInterface
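With this refactor the interfaces live in a dedicated subpackage and are re-exported from its __init__.py, so downstream scripts can import them in one line. A minimal usage sketch, assuming the package is installed from src/:

from cai_lab_to_nwb.zaki_2024.interfaces import (
    EzTrackFreezingBehaviorInterface,
    MinianSegmentationInterface,
    MiniscopeImagingInterface,
    Zaki2024EDFInterface,
    Zaki2024SleepClassificationInterface,
)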
@@ -10,10 +10,11 @@
from pydantic import FilePath
from typing import Optional

class FreezingBehaviorInterface(BaseDataInterface):
"""Adds intervals of freezing behavior interface."""

keywords = ["behavior"]
class EzTrackFreezingBehaviorInterface(BaseDataInterface):
"""Adds intervals of freezing behavior and motion series."""

keywords = ["behavior", "freezing", "motion"]

def __init__(self, file_path: FilePath, video_sampling_frequency: float, verbose: bool = False):
# This should load the data lazily and prepare variables you need
@@ -34,20 +35,19 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None):

freezing_behavior_df = pd.read_csv(self.file_path)

#Extract motion data
# Extract motion data
motion_data = freezing_behavior_df["Motion"].values

motion_series = TimeSeries(
name="MotionSeries",
description="ezTrack measures the motion of the animal by assessing the number of pixels of the behavioral "
"video whose grayscale change exceeds a particular threshold from one frame to the next.",
"video whose grayscale change exceeds a particular threshold from one frame to the next.",
data=motion_data,
unit="n.a",
starting_time=freezing_behavior_df["Frame"][0] / self.video_sampling_frequency,
rate=self.video_sampling_frequency,
)


# Extract parameters; these values are unique per run
file = freezing_behavior_df["File"].unique()[0]
motion_cutoff = freezing_behavior_df["MotionCutoff"].unique()[0]
@@ -56,19 +56,17 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None):

# Extract start and stop times of the freezing events
# From the discussion with the author, the freezing events are the frames where the freezing behavior is 100

freezing_values = freezing_behavior_df["Freezing"].values
changes_in_freezing = np.diff(freezing_values)
freezing_start = np.where(changes_in_freezing == 100)[0] + 1
freezing_stop = np.where(changes_in_freezing == -100)[0] + 1

start_frames = freezing_behavior_df["Frame"].values[freezing_start]
start_frames = freezing_behavior_df["Frame"].values[freezing_start]
stop_frames = freezing_behavior_df["Frame"].values[freezing_stop]

start_times = start_frames / self.video_sampling_frequency
stop_times = stop_frames / self.video_sampling_frequency


description = f"""
Freezing behavior intervals generated using EzTrack software for file {file}.
Parameters used include a motion cutoff of {motion_cutoff}, freeze threshold of {freeze_threshold},
@@ -79,15 +77,12 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None):
- Motion cutoff: The level of pixel intensity change required to register as motion.
"""


freeze_intervals = TimeIntervals(name="TimeIntervalsFreezingBehavior", description=description)
freeze_intervals = TimeIntervals(name="FreezingIntervals", description=description)
for start_time, stop_time in zip(start_times, stop_times):
freeze_intervals.add_interval(start_time=start_time, stop_time=stop_time, timeseries=[motion_series])

if "behavior" not in nwbfile.processing:
behavior_module = nwbfile.create_processing_module(
name="behavior", description="Contains behavior data"
)
behavior_module = nwbfile.create_processing_module(name="behavior", description="Contains behavior data")
else:
behavior_module = nwbfile.processing["behavior"]

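For illustration, a minimal sketch of the start/stop detection implemented above, assuming the Freezing column only takes the values 0 and 100 (hypothetical data):

import numpy as np

freezing_values = np.array([0, 0, 100, 100, 100, 0, 0, 100, 100, 0])
changes_in_freezing = np.diff(freezing_values)
freezing_start = np.where(changes_in_freezing == 100)[0] + 1   # -> [2, 7]
freezing_stop = np.where(changes_in_freezing == -100)[0] + 1   # -> [5, 9]

# Frame indices are then converted to seconds by dividing by the video sampling
# frequency, e.g. at 30.0 Hz the first freezing bout spans 2/30 s to 5/30 s.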
@@ -1,13 +1,3 @@
"""A SegmentationExtractor for Minian.

Classes
-------
MinianSegmentationExtractor
A class for extracting segmentation from Minian output.
"""

from pathlib import Path

import zarr
import warnings
import numpy as np
@@ -18,9 +8,9 @@
from roiextractors.segmentationextractor import SegmentationExtractor

from typing import Optional

from pynwb import NWBFile


class MinianSegmentationExtractor(SegmentationExtractor):
"""A SegmentationExtractor for Minian.

@@ -58,13 +48,13 @@ def __init__(self, folder_path: PathType):
"""
SegmentationExtractor.__init__(self)
self.folder_path = folder_path
self._roi_response_denoised = self._read_trace_from_zarr_filed(field="C")
self._roi_response_baseline = self._read_trace_from_zarr_filed(field="b0")
self._roi_response_neuropil = self._read_trace_from_zarr_filed(field="f")
self._roi_response_deconvolved = self._read_trace_from_zarr_filed(field="S")
self._roi_response_denoised = self._read_trace_from_zarr_field(field="C")
self._roi_response_baseline = self._read_trace_from_zarr_field(field="b0")
self._roi_response_neuropil = self._read_trace_from_zarr_field(field="f")
self._roi_response_deconvolved = self._read_trace_from_zarr_field(field="S")
self._image_maximum_projection = np.array(self._read_zarr_group("/max_proj.zarr/max_proj"))
self._image_masks = self._read_roi_image_mask_from_zarr_filed()
self._background_image_masks = self._read_background_image_mask_from_zarr_filed()
self._image_masks = self._read_roi_image_mask_from_zarr_field()
self._background_image_masks = self._read_background_image_mask_from_zarr_field()
self._times = self._read_timestamps_from_csv()

def _read_zarr_group(self, zarr_group=""):
@@ -81,7 +71,7 @@ def _read_zarr_group(self, zarr_group=""):
else:
return zarr.open(str(self.folder_path) + f"/{zarr_group}", "r")

def _read_roi_image_mask_from_zarr_filed(self):
def _read_roi_image_mask_from_zarr_field(self):
"""Read the image masks from the zarr output.

Returns
@@ -95,7 +85,7 @@ def _read_roi_image_mask_from_zarr_filed(self):
else:
return np.transpose(dataset["A"], (1, 2, 0))

def _read_background_image_mask_from_zarr_filed(self):
def _read_background_image_mask_from_zarr_field(self):
"""Read the image masks from the zarr output.

Returns
@@ -109,7 +99,7 @@ def _read_background_image_mask_from_zarr_filed(self):
else:
return np.expand_dims(dataset["b"], axis=2)

def _read_trace_from_zarr_filed(self, field):
def _read_trace_from_zarr_field(self, field):
"""Read the traces specified by the field from the zarr object.

Parameters
@@ -146,6 +136,25 @@ def _read_timestamps_from_csv(self):

return filtered_df["Time Stamp (ms)"].to_numpy()

def get_motion_correction_data(self) -> np.ndarray:
"""Read the xy shifts in the 'motion' field from the zarr object.

Returns
-------
motion_correction: numpy.ndarray
The first column is the x shifts. The second column is the y shifts.
"""
dataset = self._read_zarr_group("/motion.zarr")
# From the zarr field motion.zarr/shift_dim we can verify that the two columns refer respectively to
# ['height','width'] --> ['y','x']. Following best practice, we swap the two columns
motion_correction = dataset["motion"][:, [1, 0]]
return motion_correction

def get_image_size(self):
dataset = self._read_zarr_group("/A.zarr")
height = dataset["height"].shape[0]
@@ -207,6 +216,7 @@ def get_images_dict(self) -> dict:
maximum_projection=self._image_maximum_projection,
)


class MinianSegmentationInterface(BaseSegmentationExtractorInterface):
"""Data interface for MinianSegmentationExtractor."""

@@ -259,6 +269,3 @@ def add_to_nwbfile(
plane_segmentation_name=plane_segmentation_name,
iterator_options=iterator_options,
)
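For illustration, the column swap in get_motion_correction_data can be checked with a small NumPy sketch (hypothetical shift values, stored as [height, width] per frame as in motion.zarr/shift_dim):

import numpy as np

motion_height_width = np.array([[1.0, -0.5],
                                [0.5, 0.0],
                                [0.0, 2.0]])
motion_correction = motion_height_width[:, [1, 0]]
# The first column now holds the x (width) shifts and the second the y (height) shifts:
# [[-0.5, 1.0], [0.0, 0.5], [2.0, 0.0]]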



@@ -2,14 +2,9 @@
from roiextractors.multiimagingextractor import MultiImagingExtractor
from roiextractors.extraction_tools import PathType, DtypeType, get_package

from typing import Optional
import json
import datetime

from pydantic import DirectoryPath
from pathlib import Path
import numpy as np

from copy import deepcopy
from pathlib import Path
from typing import Literal, Optional
@@ -25,13 +20,14 @@
class MiniscopeImagingExtractor(MultiImagingExtractor):

def __init__(self, folder_path: DirectoryPath):
self.miniscope_videos_folder_path = Path(folder_path)

self.miniscope_videos_folder_path = Path(folder_path)
assert self.miniscope_videos_folder_path.exists(), f"Miniscope videos folder not found in {Path(folder_path)}"

self._miniscope_avi_file_paths = [p for p in self.miniscope_videos_folder_path.iterdir() if p.suffix == ".avi"]
assert len(self._miniscope_avi_file_paths) > 0, f"No .avi files found in {self.miniscope_videos_folder_path}"
import natsort

self._miniscope_avi_file_paths = natsort.natsorted(self._miniscope_avi_file_paths)

imaging_extractors = []
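For illustration, natsort is used above because plain lexicographic sorting misorders numbered video files (a minimal sketch with hypothetical file names):

import natsort

avi_names = ["1.avi", "2.avi", "10.avi"]
print(sorted(avi_names))             # ['1.avi', '10.avi', '2.avi']  (lexicographic)
print(natsort.natsorted(avi_names))  # ['1.avi', '2.avi', '10.avi']  (natural order)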
@@ -44,7 +40,7 @@ def __init__(self, folder_path: DirectoryPath):
self._sampling_frequency = self._imaging_extractors[0].get_sampling_frequency()
self._image_size = self._imaging_extractors[0].get_image_size()
self._dtype = self._imaging_extractors[0].get_dtype()

def get_num_frames(self) -> int:
return self._num_frames

@@ -195,38 +191,46 @@ def __init__(self, folder_path: DirectoryPath):
from ndx_miniscope.utils import get_recording_start_times, read_miniscope_config

super().__init__(folder_path=folder_path)
self.miniscope_folder = Path(folder_path)

self.miniscope_folder = Path(folder_path)
# This contains the general metadata and might contain behavioral videos
self.session_folder = self.miniscope_folder.parent
self.session_folder = self.miniscope_folder.parent

self._miniscope_config = read_miniscope_config(folder_path=self.miniscope_folder)

# Use the frame rate from the JSON configuration to set the metadata
frame_rate_string = self._miniscope_config["frameRate"]
# frame_rate_string looks like "30.0FPS"; extract the float part
self._metadata_frame_rate = float(frame_rate_string.split("FPS")[0])



self.photon_series_type = "OnePhotonSeries"

def _get_session_start_time(self):

general_metadata_json = self.session_folder/ "metaData.json"
general_metadata_json = self.session_folder / "metaData.json"
assert general_metadata_json.exists(), f"General metadata json not found in {self.session_folder}"

## Read metadata
with open(general_metadata_json) as f:
general_metadata = json.load(f)

if "recordingStartTime" in general_metadata:
start_time_info = general_metadata["recordingStartTime"]
else:
start_time_info = general_metadata

required_keys = ["year", "month", "day", "hour", "minute", "second", "msec"]
for key in required_keys:
if key not in start_time_info:
raise KeyError(f"Missing required key '{key}' in the metadata")

session_start_time = datetime.datetime(
year=general_metadata["year"],
month=general_metadata["month"],
day=general_metadata["day"],
hour=general_metadata["hour"],
minute=general_metadata["minute"],
second=general_metadata["second"],
microsecond=general_metadata["msec"] * 1000, # Convert milliseconds to microseconds
year=start_time_info["year"],
month=start_time_info["month"],
day=start_time_info["day"],
hour=start_time_info["hour"],
minute=start_time_info["minute"],
second=start_time_info["second"],
microsecond=start_time_info["msec"] * 1000, # Convert milliseconds to microseconds
)

return session_start_time
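For illustration, a minimal sketch of the metadata parsing above, using hypothetical values of the same shape as a Miniscope metaData.json and configuration file:

import datetime

# Hypothetical start-time fields; real files may nest them under "recordingStartTime"
# or keep them at the top level, as handled above.
start_time_info = {"year": 2024, "month": 10, "day": 24,
                   "hour": 13, "minute": 5, "second": 42, "msec": 250}
session_start_time = datetime.datetime(
    year=start_time_info["year"],
    month=start_time_info["month"],
    day=start_time_info["day"],
    hour=start_time_info["hour"],
    minute=start_time_info["minute"],
    second=start_time_info["second"],
    microsecond=start_time_info["msec"] * 1000,  # milliseconds -> microseconds
)
# -> datetime.datetime(2024, 10, 24, 13, 5, 42, 250000)

# The frame rate string in the Miniscope config looks like "30.0FPS".
metadata_frame_rate = float("30.0FPS".split("FPS")[0])  # -> 30.0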
@@ -266,18 +270,18 @@ def get_original_timestamps(self) -> np.ndarray:

timestamps_file_path = self.miniscope_folder / "timeStamps.csv"
assert timestamps_file_path.exists(), f"Miniscope timestamps file not found in {self.miniscope_folder}"
import pandas as pd
timetsamps_df = pd.read_csv(timestamps_file_path)

import pandas as pd

timetsamps_df = pd.read_csv(timestamps_file_path)
timestamps_milliseconds = timetsamps_df["Time Stamp (ms)"].values.astype(float)
timestamps_seconds = timestamps_milliseconds / 1000.0

# Shift when the first timestamp is negative
# TODO: Figure out why; this was copied from the Miniscope interface
if timestamps_seconds[0] < 0.0:
timestamps_seconds += abs(timestamps_seconds[0])

return np.asarray(timestamps_seconds)

def add_to_nwbfile(
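For illustration, a minimal sketch of the timestamp handling in get_original_timestamps above, using hypothetical values from a timeStamps.csv:

import numpy as np

timestamps_milliseconds = np.array([-33.0, 0.0, 33.0, 66.0, 100.0])
timestamps_seconds = timestamps_milliseconds / 1000.0

# Shift the series to start at 0.0 s when the first timestamp is negative.
if timestamps_seconds[0] < 0.0:
    timestamps_seconds += abs(timestamps_seconds[0])
# -> [0.0, 0.033, 0.066, 0.099, 0.133]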