Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate load_from_volume into inputs.py #74

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
import tensorstore as ts


def create_filename_queue(coordinates_file_pattern, shuffle=True):
Expand Down Expand Up @@ -323,21 +324,107 @@ def weighted_load_patch_coordinates(
)


def load_from_numpylike(coordinates, volume_names, shape, volume_map,
name=None):
def _filter_masked(item, volinfo_map_string: str):
mask_value = load_from_volume(
item['coord'],
item['volname'],
patch_size=(1, 1, 1),
dtype=tf.int64,
num_channels=1,
volinfo_map_string=volinfo_map_string,
)
return mask_value[0, 0, 0, 0, 0] > 0


def load_from_volume(
coord, volname, patch_size, dtype, num_channels, volinfo_map_string: str
):
"""Loads data from a volume using TensorStore.

Args:
coord: The coordinates to load from.
volname: The name of the volume.
patch_size: The size of the patch to load.
dtype: The data type of the volume.
num_channels: The number of channels in the volume.
volinfo_map_string: A string representation of the volume info map with the
format "volname1:volinfo_path1,volname2:volinfo_path2".

Returns:
A tensor containing the loaded data.
"""
if num_channels != 1:
raise ValueError('Only num_channels=1 is currently supported.')

volinfo_map = {}
for pair in volinfo_map_string.split(','):
name, path = pair.split(':')
volinfo_map[name.strip()] = path.strip()

def _load_single_volume(inputs):
coord, volinfo_path = inputs
print('volinfo_path:', volinfo_path)
print('coord:', coord)
volinfo_path = volinfo_path.numpy().decode('utf-8')
coord = coord.numpy()
spec = {'driver': 'volumestore', 'volinfo_path': volinfo_path}

store = ts.open(spec, open=True).result()

start_coord = [max(0, c - (p // 2)) for c, p in zip(coord, patch_size)]
stop_coord = [
min(store.shape[i], c + (p // 2) + (p % 2))
for i, (c, p) in enumerate(zip(coord, patch_size))
]

data = (
store[
start_coord[0] : stop_coord[0],
start_coord[1] : stop_coord[1],
start_coord[2] : stop_coord[2],
]
.read()
.result()
)

data = data[:, :, :, 0].transpose(2, 1, 0).astype(dtype.as_numpy_dtype)
data = data[..., tf.newaxis]
return data

patch_size = list(patch_size)
# Convert lists to tensors for tf.map_fn
coords_tensor = tf.convert_to_tensor(coord)
volinfo_paths_tensor = tf.convert_to_tensor(
[volinfo_map[v].encode('utf-8') for v in volname], dtype=tf.string
)

# Use tf.map_fn to process each volume
data_tensor = tf.map_fn(
_load_single_volume,
(coords_tensor, volinfo_paths_tensor),
fn_output_signature=dtype,
dtype=dtype,
)

return data_tensor


def load_from_numpylike(
coordinates, volume_names, shape, volume_map, name=None
):
"""TensorFlow Python op that loads data from Numpy-like volumes.

The volume object must support Numpy-like indexing, as well as shape, ndim,
and dtype properties. The volume can be 3d or 4d.

Args:
coordinates: tensor of shape [1, 3] containing XYZ coordinates of the
center of the subvolume to load.
coordinates: tensor of shape [1, 3] containing XYZ coordinates of the center
of the subvolume to load.
volume_names: tensor of shape [1] containing names of volumes to load data
from.
from.
shape: a 3-sequence giving the XYZ shape of the data to load.
volume_map: a dictionary mapping volume names to volume objects. See above
for API requirements of the Numpy-like volume objects.
for API requirements of the Numpy-like volume objects.
name: the op name.

Returns:
Expand Down
Loading