Migrate inputs.py
PiperOrigin-RevId: 646187131
imxj authored and copybara-github committed Jun 25, 2024
1 parent 371a597 commit 770cf37
346 changes: 326 additions & 20 deletions ffn/training/inputs.py
@@ -14,9 +14,15 @@
# ==============================================================================
"""Tensorflow Python ops and utilities for generating network inputs."""

import random
import re
from typing import Any, Callable, Optional, Sequence

from absl import logging
from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.segmentation import labels as label_utils
from ffn.training import augmentation
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
@@ -28,47 +28,203 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
Args:
coordinates_file_pattern: File pattern for TFRecords of
input examples of the form of a glob
-                              pattern or path@shards.
+                              pattern or path@shards
+                              or comma-separated file patterns.
shuffle: Whether to shuffle the coordinate file list. Note that the expanded
coordinates_file_pattern is not guaranteed to be sorted
alphabetically.
Returns:
Tensorflow queue with coordinate filenames
"""
-  m = re.search(r'@(\d{1,})', coordinates_file_pattern)
-  if m:
-    num_shards = int(m.group(1))
-    coord_file_list = [
-        re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards),
-               coordinates_file_pattern)
-        for i in range(num_shards)]
-  else:
-    coord_file_list = gfile.glob(coordinates_file_pattern)
+  coord_file_list = []
+  for pattern in coordinates_file_pattern.split(','):
+    m = re.search(r'@(\d{1,})', pattern)
+    if m:
+      num_shards = int(m.group(1))
+      coord_file_list.extend([
+          re.sub(
+              r'@(\d{1,})',
+              '-%.5d-of-%.5d' % (i, num_shards),
+              pattern,
+          )
+          for i in range(num_shards)
+      ])
+    else:
+      coord_file_list.extend(gfile.glob(pattern))
return tf.train.string_input_producer(coord_file_list, shuffle=shuffle)
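

# Editor's sketch (not part of the commit): the shard expansion performed
# above, isolated for illustration; 'coords@2' is a hypothetical pattern.
def _example_expand_shard_pattern():
  num_shards = 2
  return [
      re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards), 'coords@2')
      for i in range(num_shards)
  ]  # -> ['coords-00000-of-00002', 'coords-00001-of-00002']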


-def load_patch_coordinates_from_filename_queue(filename_queue):
+def load_patch_coordinates_from_filename_queue(filename_queue,
+                                               file_format='tfrecords'):
"""Loads coordinates and volume names from filename queue.
Args:
filename_queue: Tensorflow queue created from create_filename_queue()
file_format: String indicating the format of the files in the queue.
Can be 'sstables' or 'tfrecords'. Defaults to 'tfrecords'.
Returns:
Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.
"""
-  record_options = tf.python_io.TFRecordOptions(
-      tf.python_io.TFRecordCompressionType.GZIP)
-  keys, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
-  examples = tf.parse_single_example(protos, features=dict(
-      center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
-      label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
-  ))
-  coord = examples['center']
-  volname = examples['label_volume_name']
+  if file_format == 'tfrecords':
+    record_options = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.GZIP)
+    _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
+    examples = tf.parse_single_example(protos, features=dict(
+        center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+        label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+    ))
+    coord = examples['center']
+    volname = examples['label_volume_name']
+  else:
+    raise ValueError(f'Unsupported file format: {file_format}.')

return coord, volname
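

# Editor's usage sketch (not part of the commit): reading one coordinate
# under TF1 graph mode; '/tmp/coords@2' is a hypothetical path and
# MonitoredSession is assumed to start the queue runners.
def _example_read_one_coordinate():
  queue = create_filename_queue('/tmp/coords@2', shuffle=False)
  coord, volname = load_patch_coordinates_from_filename_queue(queue)
  with tf.train.MonitoredSession() as sess:
    return sess.run([coord, volname])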


def sample_patch_coordinates(
bboxes: Sequence[Sequence[bounding_box.BoundingBox]],
volinfo_map_string: str,
name='sample_patch_coordinates',
rng_seed: Optional[int] = None,
) -> tf.data.Dataset:
"""Samples a coordinate uniformly at random from specified bboxes.
Args:
bboxes: sequence of sequences for bounding boxes (one seq. per volume)
volinfo_map_string: comma delimited string mapping volname:volinfo_path,
where volinfo_path is a gfile with text_format VolumeInfo proto for the
volume from which patches should be extracted.
name: passed to `name_scope`
rng_seed: Random number generator seed allowing to make the dataset
deterministic.
Returns:
tuple of:
[1, 3] int64 xyz coord tensor
[1] string tensor with the volume label
Raises:
ValueError: if len(bboxes) != len(volinfo_map) or if an invalid bbox is
passed
"""
volinfo_pairs = volinfo_map_string.split(',')
if len(bboxes) != len(volinfo_pairs):
raise ValueError(
'Numbers of bounding boxes and volume paths do not match.'
)

volumes, flat_boxes = [], []
total_voxels = 0
for vol_id, volume_boxes in enumerate(bboxes):
for b in volume_boxes:
w = np.prod(b.size)
      if w <= 0:
        raise ValueError('Volume %d, bbox %r is empty.' % (vol_id, b))
total_voxels += w
flat_boxes.append(b)
volumes.append(vol_id)

calc = box_generator.MultiBoxGenerator(
flat_boxes, box_size=(1, 1, 1), box_overlap=(0, 0, 0)
)
volnames = [v.split(':')[0] for v in volinfo_pairs]

def _sample_volinfo_and_bbox(idx):
idx = idx[0]
vol_idx = volumes[calc.index_to_generator_index(idx)[0]]
_, coord_bbox = calc.generate(idx)
assert coord_bbox is not None
logging.info(
'Sampled location %r from volume %s',
coord_bbox.start,
volnames[vol_idx],
)
coord = np.array([coord_bbox.start]).astype(np.int64)
return coord, volnames[vol_idx]

def _sample(rng_seed):
with tf.name_scope(name=name):
coord, label = tf.py_func(
_sample_volinfo_and_bbox,
[
tf.random.stateless_uniform(
[1],
rng_seed,
maxval=total_voxels,
dtype=tf.int64,
name='rand',
)
],
[tf.int64, tf.string],
name='sample_volinfo_and_bbox',
stateful=False,
)
label.set_shape([])
coord.set_shape([1, 3])
return {'coord': coord, 'volname': tf.reshape(label, [1])}

# This is faster than calling _sample_volinfo_and_bbox via .from_generator.
return tf.data.Dataset.random(seed=rng_seed).batch(2).map(_sample)
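

# Editor's usage sketch (not part of the commit): sampling coordinates from
# a single hypothetical 100^3 volume; '/path/a.volinfo' does not exist.
def _example_sample_patch_coordinates():
  boxes = [[bounding_box.BoundingBox(start=(0, 0, 0), size=(100, 100, 100))]]
  ds = sample_patch_coordinates(
      boxes, get_vol_map(['/path/a.volinfo']), rng_seed=42)
  return ds.take(1)  # {'coord': [1, 3] int64, 'volname': [1] string}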


def get_vol_map(volinfo_paths: Sequence[str]):
return ','.join(
'vol%d:%s' % (i, volinfo) for i, volinfo in enumerate(volinfo_paths)
)
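

# E.g. (editor's note): get_vol_map(['/path/a.volinfo', '/path/b.volinfo'])
# returns 'vol0:/path/a.volinfo,vol1:/path/b.volinfo'.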


def parse_tf_coords(x):
  """Parses a serialized tf.Example proto with patch coordinate features."""
return tf.io.parse_single_example(
x,
features=dict(
coord=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
volname=tf.FixedLenFeature(shape=[1], dtype=tf.string),
label=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
segment_id=tf.FixedLenFeature(
shape=[1],
dtype=tf.int64,
default_value=tf.constant([0], dtype=tf.int64),
),
radius=tf.FixedLenFeature(
shape=[1],
dtype=tf.float32,
default_value=tf.constant([0], dtype=tf.float32),
),
),
)
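

# Editor's sketch (not part of the commit): building a serialized tf.Example
# that parse_tf_coords can decode; all values are made up.
def _example_make_coord_example():
  def int64_feat(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  ex = tf.train.Example(features=tf.train.Features(feature={
      'coord': int64_feat([10, 20, 30]),
      'volname': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[b'vol0'])),
      'label': int64_feat([1]),
  }))
  # segment_id and radius are omitted; parsing fills in their defaults.
  return ex.SerializeToString()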


def load_coordinates_from_tfex(
coord_pattern: str,
shuffle: bool = True,
shuffle_size: Optional[int] = 4096,
shuffle_seed: Optional[int] = None,
parse_fn: Callable[[Any], dict[str, Any]] = parse_tf_coords,
reshuffle_each_iteration: bool = True,
) -> tf.data.Dataset:
"""Loads coordinates from a RecordIO of tf.Example protos."""
  coord_paths = sorted(gfile.glob(coord_pattern))
if shuffle:
    if shuffle_seed is not None:
random.Random(shuffle_seed).shuffle(coord_paths)
else:
random.shuffle(coord_paths)
logging.info('Loading data from: %r', coord_paths)
ds = tf.data.RecordIODataset(tf.constant(coord_paths, dtype=tf.string))

ds = ds.map(parse_fn, deterministic=True)
if shuffle:
ds = ds.shuffle(
shuffle_size,
seed=shuffle_seed,
reshuffle_each_iteration=reshuffle_each_iteration,
)

return ds.repeat()
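

# Editor's note (not part of the commit): tf.data.RecordIODataset is a
# Google-internal dataset type. A hypothetical call would look like
#   ds = load_coordinates_from_tfex('/tmp/coords*', shuffle_seed=42)
# and yields an endlessly repeated, shuffled stream of parsed coordinate
# dicts.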


def load_patch_coordinates(coordinates_file_pattern,
shuffle=True,
scope='load_patch_coordinates'):
@@ -187,6 +349,7 @@ def get_offset_scale(volname,
"""

def _get_offset_scale(volname):
+    volname = volname.decode('utf-8')
if volname in offset_scale_map:
offset, scale = offset_scale_map[volname]
else:
@@ -356,3 +519,146 @@ def soften_labels(bool_labels, softness=0.05, scope='soften_labels'):
return tf.where(bool_labels,
tf.fill(label_shape, 1.0 - softness, name='soft_true'),
tf.fill(label_shape, softness, name='soft_false'))
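

# E.g. (editor's note): soften_labels([[True, False]], softness=0.05)
# yields [[0.95, 0.05]].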


def make_labels_contiguous(labels: tf.Tensor) -> tf.Tensor:
"""Maps the labels to [0..N].
Args:
labels: [1, z, y, x, 1] int64 tensor of labels
Returns:
labels mapped to the range [0..N] if N distinct non-zero values are
present in the input tensor
"""
ret = tf.py_func(
label_utils.make_contiguous,
inp=[labels],
Tout=tf.int64,
name='make_labels_contiguous',
)
ret.set_shape(labels.shape)
return ret
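

# E.g. (editor's illustration, assuming label_utils.make_contiguous keeps 0
# as background): labels [[0, 5, 5, 9]] map to [[0, 1, 1, 2]].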


def apply_augmentation(
data: dict[str, Any],
section_augment: bool,
section_augmentation_args: Optional[dict[str, Any]],
permute_and_reflect_augment: bool,
permutable_axes: list[int],
reflectable_axes: list[int],
rotation_augmentation: Optional[str],
voxel_size: Optional[tuple[float, float, float]],
) -> dict[str, Any]:
"""Applies augmentations to a subvolume of data and corresponding labels.
Args:
data: dict containing at least 'labels' and 'patches' tensors
section_augment: whether to apply section augmentations
section_augmentation_args: kwargs for
augmentation.apply_section_augmentations
permute_and_reflect_augment: whether to apply permutation/reflection
permutable_axes: list of axes to permute
reflectable_axes: list of axes to reflect
    rotation_augmentation: type of rotation augmentation to perform ('2d', '3d')
    voxel_size: xyz voxel size of the input data (only needed when applying
      rotation augmentation)
Returns:
'data' dict with 'labels' and 'patches' entries updated according to the
chosen augmentations
"""
labels = data['labels']
patches = data['patches']

# Apply section-wise augmentations.
if section_augment:
final_data_zyx = patches.shape_as_list()[1:4]
final_label_zyx = labels.shape_as_list()[1:4]
patches, labels, _ = augmentation.apply_section_augmentations(
patches,
labels,
labels,
final_data_zyx,
final_label_zyx,
final_label_zyx,
**section_augmentation_args,
)

# Apply basic augmentations.
if permute_and_reflect_augment:
transform_axes = augmentation.PermuteAndReflect(
rank=5,
permutable_axes=permutable_axes,
reflectable_axes=reflectable_axes,
)
labels = transform_axes(labels)
patches = transform_axes(patches)

rot_mtx = None
if rotation_augmentation == '2d':
rot_mtx = augmentation.random_2d_rotation_matrix()
elif rotation_augmentation == '3d':
rot_mtx = augmentation.random_3d_rotation_matrix()

if rot_mtx is not None:
if labels.dtype == tf.int64:
labels = tf.cond(
tf.reduce_any(labels > np.iinfo(np.int32).max), #
lambda: make_labels_contiguous(labels), #
lambda: labels,
)
labels = tf.cast(labels, tf.int32)

assert voxel_size is not None
patches = augmentation.apply_rotation(patches, rot_mtx, voxel_size)
if labels.shape.as_list() != [1, 1, 1, 1, 1]:
labels = augmentation.apply_rotation(labels, rot_mtx, voxel_size)

data['labels'] = labels
data['patches'] = patches
return data
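

# Editor's usage sketch (not part of the commit): permutation/reflection
# only; the axis choices below are hypothetical ([b, z, y, x, c] layout,
# y/x permutable, all spatial axes reflectable).
def _example_apply_augmentation(data):
  return apply_augmentation(
      data,
      section_augment=False,
      section_augmentation_args=None,
      permute_and_reflect_augment=True,
      permutable_axes=[2, 3],
      reflectable_axes=[1, 2, 3],
      rotation_augmentation=None,
      voxel_size=None)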


def interleave(datasets: Sequence[tf.data.Dataset], repeat=True):
"""Interleave two or more datasets together, one at a time.
Interleaves two independently generated datasets together, contrary to
Dataset.interleave which interleaves new Datasets generated from each input
item.
Args:
datasets: Sequence of datasets to interleave.
repeat: repeat the interleaved sequence.
Returns:
tf.data.Dataset with interleaved results.
"""
choice_dataset = tf.data.Dataset.range(len(datasets))
if repeat:
choice_dataset = choice_dataset.repeat()
return tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
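

# Editor's sketch (not part of the commit): two toy datasets interleaved
# element by element: 0, 1, 2, 3, 4, 5.
def _example_interleave():
  a = tf.data.Dataset.from_tensor_slices([0, 2, 4])
  b = tf.data.Dataset.from_tensor_slices([1, 3, 5])
  return interleave([a, b]).take(6)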


def sample(
datasets: Sequence[tf.data.Dataset],
repeat=True,
weights: Optional[Sequence[float]] = None,
):
"""Weighted sample of two or more datasets.
Args:
datasets: Sequence of datasets to sample.
repeat: repeat the sampled sequence.
weights: relative weight of each respective dataset.
Returns:
tf.data.Dataset with sampled results.
"""
if weights is None:
weights = [1.0] * len(datasets)
sampled_dataset = tf.data.experimental.sample_from_datasets(datasets, weights)
if repeat:
sampled_dataset = sampled_dataset.repeat()
return sampled_dataset
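

# Editor's sketch (not part of the commit): draws from two infinite toy
# datasets at roughly 3:1; the exact sequence depends on the RNG.
def _example_sample():
  a = tf.data.Dataset.from_tensors(0).repeat()
  b = tf.data.Dataset.from_tensors(1).repeat()
  return sample([a, b], weights=[3.0, 1.0]).take(8)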
