Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and improve affine handling in the viewer. #7

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
44 changes: 42 additions & 2 deletions napari_nibabel/_tests/test_nibabel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import atexit
import os
import shutil
import tempfile

import nibabel as nib
import numpy as np
Expand All @@ -15,8 +17,11 @@ def test_reader(tmp_path):

# write some fake data in NIFTI-1 format
my_test_file = str(tmp_path / "myfile.nii")
original_data = np.random.rand(20, 20)
nii = nib.Nifti1Image(original_data, affine=np.eye(4))
original_data = np.random.rand(20, 20, 1)

# Set affine to an LPS affine here so internal reorientation will not be
# needed.
nii = nib.Nifti1Image(original_data, affine=np.diag((-1, -1, 1, 1)))
nii.to_filename(my_test_file)
np.save(my_test_file, original_data)

Expand Down Expand Up @@ -168,3 +173,38 @@ def test_analyze_hdr_only():
filename = os.path.join(data_path, 'analyze.hdr')
with pytest.raises(FileNotFoundError):
_test_basic_read(filename)


def test_read_filelist():
filename = os.path.join(data_path, 'example4d.nii.gz')
n_files = 3
data = _test_basic_read([filename,] * n_files)
assert data.ndim == 5
assert data.shape[0] == n_files


def test_read_filelist_mismatched_shape():
# cannot stack multiple files when the shapes are different
filename = os.path.join(data_path, 'example_nifti2.nii.gz')
filename2 = os.path.join(data_path, 'example4d.nii.gz')
with pytest.raises(ValueError):
_test_basic_read([filename, filename2])


def test_read_filelist_mismatched_affine():
# cannot stack multiple files when the shapes are different
tmp_dir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, tmp_dir)

filename = os.path.join(data_path, 'anatomical.nii')
nii1 = nib.load(filename)
data = nii1.get_fdata()
affine2 = nii1.affine.copy()
affine2[0, 0] *= 2
affine2[1, 1] *= -1
nii2 = nib.Nifti1Image(data, affine=affine2, header=nii1.header)
filename2 = os.path.join(tmp_dir, 'anatomical_affine2.nii')
nii2.to_filename(filename2)

with pytest.raises(ValueError):
_test_basic_read([filename, filename2])
126 changes: 87 additions & 39 deletions napari_nibabel/nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,60 @@

from napari_plugin_engine import napari_hook_implementation

from nibabel import orientations
from nibabel.imageclasses import all_image_classes
from nibabel.filename_parser import splitext_addext

valid_volume_exts = {klass.valid_exts for klass in all_image_classes}
valid_volume_exts = set(functools.reduce(operator.add, valid_volume_exts))


def get_transform_ornt(affine, target=('L', 'P', 'S')):
current_ornt = orientations.io_orientation(affine)
target_ornt = orientations.axcodes2ornt(('L', 'P', 'S'))
return orientations.ornt_transform(current_ornt, target_ornt)


def adjust_translation(affine, affine_plumb, data_shape):
"""Adjust translation vector of affine_plumb.

The goal is to have affine_plumb result in the same data center
point in world coordinates as the original affine.

Parameters
----------
affine : ndarray
The shape (4, 4) affine matrix read in by nibabel.
affine_plumb: ndarray
The affine after permutation to RAS+ space followed by discarding
of any rotation/shear elements.
data_shape : tuple of int
The shape of the data array

Returns
-------
affine_plumb : ndarray
A copy of affine_plumb with the 3 translation elements updated.
"""
data_shape = data_shape[-3:]
if len(data_shape) < 3:
# TODO: prepend or append?
data_shape = data_shape + (1,) * (3 - data.ndim)

# get center in world coordinates for the original RAS+ affine
center_ijk = (np.array(data_shape) - 1) / 2
center_world = np.dot(affine[:3, :3], center_ijk) + affine[:3, 3]

# make a copy to avoid in-place modification of affine_plumb
affine_plumb = affine_plumb.copy()

# center in world coordinates with the current affine_plumb
center_world_plumb = np.dot(affine_plumb[:3, :3], center_ijk)

# adjust the translation elements
affine_plumb[:3, 3] = center_world - center_world_plumb
return affine_plumb

all_valid_exts = {klass.valid_exts for klass in all_image_classes}
all_valid_exts = set(functools.reduce(operator.add, all_valid_exts))

@napari_hook_implementation
def napari_get_reader(path):
Expand All @@ -48,7 +96,7 @@ def napari_get_reader(path):
froot, ext, addext = splitext_addext(path)

# if we know we cannot read the file, we immediately return None.
if not ext.lower() in all_valid_exts:
if not ext.lower() in valid_volume_exts:
return None

# otherwise we return the *function* that can read ``path``.
Expand Down Expand Up @@ -82,22 +130,33 @@ def reader_function(path):
paths = [path] if isinstance(path, str) else path

n_spatial = 3

# note: we don't squeeze the data below, so 2D data will be 3D with 1 slice
if len(paths) > 1:
# load all files into a single array
objects = [nib.load(_path) for _path in paths]
header = objects[0].header
affine = objects[0].affine
if not all([_obj.shape == _obj[0].shape for _obj in objects]):
header = objects[0].header
if not all([_obj.shape == objects[0].shape for _obj in objects]):
raise ValueError(
"all selected files must contain data of the same shape")

if not all(np.allclose(affine, _obj.affine) for _obj in objects):
raise ValueError(
"all selected files must share a common affine")
# reorient volumes to the desired orientation
transform_ornt = get_transform_ornt(affine, target=('L', 'P', 'S'))
objects = [_obj.as_reoriented(transform_ornt) for _obj in objects]
arrays = [_obj.get_fdata() for _obj in objects]
affine = objects[0].affine
header = objects[0].header

# stack arrays into single array
data = np.stack(arrays)
else:
img = nib.load(paths[0])
# reorient volume to the desired orientation
transform_ornt = get_transform_ornt(img.affine, target=('L', 'P', 'S'))
img = img.as_reoriented(transform_ornt)
header = img.header
affine = img.affine
data = img.get_fdata() # keep this as dataobj or use get_fdata()?
Expand All @@ -114,43 +173,32 @@ def reader_function(path):
if spatial_axis_order != (0, 1, 2):
data = data.transpose(spatial_axis_order[:data.ndim])

try:
# only get zooms for the spatial axes
zooms = np.asarray(header.get_zooms())[:n_spatial]
if np.any(zooms == 0):
raise ValueError("invalid zoom = 0 found in header")
# normalize so values are all >= 1.0 (not strictly necessary)
# zooms = zooms / zooms.min()
zooms = tuple(zooms)
if data.ndim > 3:
zooms = (1.0, ) * (data.ndim - n_spatial) + zooms
except (AttributeError, ValueError):
zooms = (1.0, ) * data.ndim

apply_translation = False
if apply_translation:
translate = tuple(affine[:n_spatial, 3])
if data.ndim > 3:
# set translate = 0.0 on non-spatial dimensions
translate = (0.0,) * (data.ndim - n_spatial) + translate
if np.all(affine[:3, :3] == (np.eye(3) * affine[:3, :3])):
# no rotation or shear components
affine_plumb = affine
else:
translate = (0.0,) * data.ndim
# Set any remaining non-diagonal elements of the affine to 0
# (napari currently cannot display with rotate/shear)
affine_plumb = np.diag(np.diag(affine))

# Set translation elements of affine_plumb to get the center of the
# data cube in the same position in world coordinates
affine_plumb = adjust_translation(affine, affine_plumb, data.shape)

# Note: The translate, scale, rotate, shear kwargs correspond to the
# 'data2physical' component of a composite affine transform.
# https://github.com/napari/napari/blob/v0.4.11/napari/layers/base/base.py#L254-L268 #noqa
# However, the affine kwarg corresponds instead to the 'physical2world'
# affine. Here, we will extract the scale and translate components from
# affine_plumb so that we are specifying 'data2physical' to napari.

# optional kwargs for the corresponding viewer.add_* method
# https://napari.org/docs/api/napari.components.html#module-napari.components.add_layers_mixin
# see also: https://napari.org/tutorials/fundamentals/image
add_kwargs = dict(
metadata=dict(affine=affine, header=header),
rgb=False,
scale=zooms,
translate=translate,
# contrast_limits=,
scale=np.diag(affine_plumb[:3, :3]),
translate=affine_plumb[:3, 3],
affine=None,
channel_axis=None,
)

# TODO: potential kwargs to set for viewer.add_image
# contrast_limits kwarg based on info in image header?
# e.g. for NIFTI: nii.header._structarr['cal_min']
# nii.header._structarr['cal_max']

layer_type = "image" # optional, default is "image"
return [(data, add_kwargs, layer_type)]
return [(data, add_kwargs, "image")]