diff --git a/ffn/training/inputs.py b/ffn/training/inputs.py
index 4e30bde..fd277b3 100644
--- a/ffn/training/inputs.py
+++ b/ffn/training/inputs.py
@@ -14,9 +14,15 @@
 # ==============================================================================
 """Tensorflow Python ops and utilities for generating network inputs."""
 
+import random
 import re
+from typing import Any, Callable, Optional, Sequence
 
+from absl import logging
 from connectomics.common import bounding_box
+from connectomics.common import box_generator
+from connectomics.segmentation import labels as label_utils
+from ffn.training import augmentation
 import numpy as np
 import tensorflow.compat.v1 as tf
 from tensorflow.io import gfile
@@ -28,7 +34,8 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
   Args:
     coordinates_file_pattern: File pattern for TFRecords of
                               input examples of the form of a glob
-                              pattern or path@shards.
+                              pattern or path@shards, or a comma-separated
+                              list of such patterns.
     shuffle: Whether to shuffle the coordinate file list. Note that the expanded
              coordinates_file_pattern is not guaranteed to be sorted
              alphabetically.
@@ -36,39 +43,209 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
   Returns:
     Tensorflow queue with coordinate filenames
   """
-  m = re.search(r'@(\d{1,})', coordinates_file_pattern)
-  if m:
-    num_shards = int(m.group(1))
-    coord_file_list = [
-      re.sub(r'@(\d{1,})', '-%.5d-of-%.5d' % (i, num_shards),
-             coordinates_file_pattern)
-      for i in range(num_shards)]
-  else:
-    coord_file_list = gfile.glob(coordinates_file_pattern)
+  coord_file_list = []
+  for pattern in coordinates_file_pattern.split(','):
+    m = re.search(r'@(\d{1,})', pattern)
+    if m:
+      num_shards = int(m.group(1))
+      coord_file_list.extend([
+          re.sub(
+              r'@(\d{1,})',
+              '-%.5d-of-%.5d' % (i, num_shards),
+              pattern,
+          )
+          for i in range(num_shards)
+      ])
+    else:
+      coord_file_list.extend(gfile.glob(pattern))
   return tf.train.string_input_producer(coord_file_list, shuffle=shuffle)
 
 
-def load_patch_coordinates_from_filename_queue(filename_queue):
+def load_patch_coordinates_from_filename_queue(filename_queue,
+                                               file_format='tfrecords'):
   """Loads coordinates and volume names from filename queue.
 
   Args:
     filename_queue: Tensorflow queue created from create_filename_queue()
+    file_format: String indicating the format of the files in the queue.
+      Only 'tfrecords' is supported here; other values raise ValueError.
 
   Returns:
     Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.
   """
-  record_options = tf.python_io.TFRecordOptions(
-      tf.python_io.TFRecordCompressionType.GZIP)
-  keys, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
-  examples = tf.parse_single_example(protos, features=dict(
-      center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
-      label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
-  ))
-  coord = examples['center']
-  volname = examples['label_volume_name']
+  if file_format == 'tfrecords':
+    record_options = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.GZIP)
+    _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
+    examples = tf.parse_single_example(protos, features=dict(
+        center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+        label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+    ))
+    coord = examples['center']
+    volname = examples['label_volume_name']
+  else:
+    raise ValueError(f'Unsupported file format: {file_format}.')
+
   return coord, volname
 
 
+def sample_patch_coordinates(
+    bboxes: Sequence[Sequence[bounding_box.BoundingBox]],
+    volinfo_map_string: str,
+    name: str = 'sample_patch_coordinates',
+    rng_seed: Optional[int] = None,
+) -> tf.data.Dataset:
+  """Samples a coordinate uniformly at random from specified bboxes.
+
+  Args:
+    bboxes: sequence of sequences for bounding boxes (one seq. per volume)
+    volinfo_map_string: comma delimited string mapping volname:volinfo_path,
+      where volinfo_path is a gfile with text_format VolumeInfo proto for the
+      volume from which patches should be extracted.
+    name: passed to `name_scope`
+    rng_seed: Random number generator seed allowing to make the dataset
+      deterministic.
+
+  Returns:
+    tf.data.Dataset yielding dicts with:
+      'coord': [1, 3] int64 xyz coord tensor
+      'volname': [1] string tensor with the volume label
+
+  Raises:
+    ValueError: if len(bboxes) != len(volinfo_map) or if an invalid bbox is
+      passed
+  """
+  volinfo_pairs = volinfo_map_string.split(',')
+  if len(bboxes) != len(volinfo_pairs):
+    raise ValueError(
+        'Numbers of bounding boxes and volume paths do not match.'
+    )
+
+  volumes, flat_boxes = [], []
+  total_voxels = 0
+  for vol_id, volume_boxes in enumerate(bboxes):
+    for b in volume_boxes:
+      w = np.prod(b.size)
+      if w < 0:
+        raise ValueError('Volume %d, bbox %r is too small.' % (vol_id, b))
+      total_voxels += w
+      flat_boxes.append(b)
+      volumes.append(vol_id)
+
+  calc = box_generator.MultiBoxGenerator(
+      flat_boxes, box_size=(1, 1, 1), box_overlap=(0, 0, 0)
+  )
+  volnames = [v.split(':')[0] for v in volinfo_pairs]
+
+  def _sample_volinfo_and_bbox(idx):
+    idx = idx[0]
+    vol_idx = volumes[calc.index_to_generator_index(idx)[0]]
+    _, coord_bbox = calc.generate(idx)
+    assert coord_bbox is not None
+    logging.info(
+        'Sampled location %r from volume %s',
+        coord_bbox.start,
+        volnames[vol_idx],
+    )
+    coord = np.array([coord_bbox.start]).astype(np.int64)
+    return coord, volnames[vol_idx]
+
+  def _sample(rng_seed):
+    with tf.name_scope(name=name):
+      coord, label = tf.py_func(
+          _sample_volinfo_and_bbox,
+          [
+              tf.random.stateless_uniform(
+                  [1],
+                  rng_seed,
+                  maxval=total_voxels,
+                  dtype=tf.int64,
+                  name='rand',
+              )
+          ],
+          [tf.int64, tf.string],
+          name='sample_volinfo_and_bbox',
+          stateful=False,
+      )
+      label.set_shape([])
+      coord.set_shape([1, 3])
+      return {'coord': coord, 'volname': tf.reshape(label, [1])}
+
+  # This is faster than calling _sample_volinfo_and_bbox via .from_generator.
+  return tf.data.Dataset.random(seed=rng_seed).batch(2).map(_sample)
+
+
+def get_vol_map(volinfo_paths: Sequence[str]) -> str:
+  """Builds a comma-separated 'vol<i>:<path>' map string."""
+  return ','.join(
+      'vol%d:%s' % (i, volinfo) for i, volinfo in enumerate(volinfo_paths)
+  )
+
+
+def parse_tf_coords(x) -> dict[str, Any]:
+  """Parses a serialized tf.Example with coordinate features into tensors."""
+  return tf.io.parse_single_example(
+      x,
+      features=dict(
+          coord=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+          volname=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+          label=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
+          segment_id=tf.FixedLenFeature(
+              shape=[1],
+              dtype=tf.int64,
+              default_value=tf.constant([0], dtype=tf.int64),
+          ),
+          radius=tf.FixedLenFeature(
+              shape=[1],
+              dtype=tf.float32,
+              default_value=tf.constant([0], dtype=tf.float32),
+          ),
+      ),
+  )
+
+
+def load_coordinates_from_tfex(
+    coord_pattern: str,
+    shuffle: bool = True,
+    shuffle_size: Optional[int] = 4096,
+    shuffle_seed: Optional[int] = None,
+    parse_fn: Callable[[Any], dict[str, Any]] = parse_tf_coords,
+    reshuffle_each_iteration: bool = True,
+) -> tf.data.Dataset:
+  """Loads coordinates from a RecordIO of tf.Example protos.
+
+  Args:
+    coord_pattern: glob pattern of the record files to read
+    shuffle: whether to shuffle both the file order and the parsed records
+    shuffle_size: buffer size passed to `Dataset.shuffle`
+    shuffle_seed: seed for the file-order and record shuffles; when None,
+      both shuffles are unseeded
+    parse_fn: function applied to every record via `Dataset.map`
+    reshuffle_each_iteration: passed through to `Dataset.shuffle`
+
+  Returns:
+    tf.data.Dataset of parsed records, repeated indefinitely
+  """
+  coord_paths = sorted(gfile.glob(coord_pattern))
+  if shuffle:
+    if shuffle_seed is not None:  # 0 is a valid seed and must not be skipped.
+      random.Random(shuffle_seed).shuffle(coord_paths)
+    else:
+      random.shuffle(coord_paths)
+  logging.info('Loading data from: %r', coord_paths)
+  ds = tf.data.RecordIODataset(tf.constant(coord_paths, dtype=tf.string))
+
+  ds = ds.map(parse_fn, deterministic=True)
+  if shuffle:
+    ds = ds.shuffle(
+        shuffle_size,
+        seed=shuffle_seed,
+        reshuffle_each_iteration=reshuffle_each_iteration,
+    )
+
+  return ds.repeat()
+
+
 def load_patch_coordinates(coordinates_file_pattern,
                            shuffle=True,
                            scope='load_patch_coordinates'):
@@ -187,6 +364,7 @@ def get_offset_scale(volname,
   """
 
   def _get_offset_scale(volname):
+    volname = volname.decode('utf-8')
     if volname in offset_scale_map:
       offset, scale = offset_scale_map[volname]
     else:
@@ -356,3 +534,150 @@ def soften_labels(bool_labels, softness=0.05, scope='soften_labels'):
     return tf.where(bool_labels,
                     tf.fill(label_shape, 1.0 - softness, name='soft_true'),
                     tf.fill(label_shape, softness, name='soft_false'))
+
+
+def make_labels_contiguous(labels: tf.Tensor) -> tf.Tensor:
+  """Maps the labels to [0..N].
+
+  Args:
+    labels: [1, z, y, x, 1] int64 tensor of labels
+
+  Returns:
+    labels mapped to the range [0..N] if N distinct non-zero values are
+    present in the input tensor
+  """
+  ret = tf.py_func(
+      label_utils.make_contiguous,
+      inp=[labels],
+      Tout=tf.int64,
+      name='make_labels_contiguous',
+  )
+  ret.set_shape(labels.shape)
+  return ret
+
+
+def apply_augmentation(
+    data: dict[str, Any],
+    section_augment: bool,
+    section_augmentation_args: Optional[dict[str, Any]],
+    permute_and_reflect_augment: bool,
+    permutable_axes: list[int],
+    reflectable_axes: list[int],
+    rotation_augmentation: Optional[str],
+    voxel_size: Optional[tuple[float, float, float]],
+) -> dict[str, Any]:
+  """Applies augmentations to a subvolume of data and corresponding labels.
+
+  Args:
+    data: dict containing at least 'labels' and 'patches' tensors
+    section_augment: whether to apply section augmentations
+    section_augmentation_args: kwargs for
+      augmentation.apply_section_augmentations
+    permute_and_reflect_augment: whether to apply permutation/reflection
+    permutable_axes: list of axes to permute
+    reflectable_axes: list of axes to reflect
+    rotation_augmentation: type of rotation augmentation to perform ('2d', '3d')
+    voxel_size: xyz voxel size of the input data (only needed when applying
+      rotation augmentation)
+
+  Returns:
+    the input 'data' dict, modified in place so that its 'labels' and
+    'patches' entries hold the augmented tensors
+
+  Raises:
+    ValueError: if rotation augmentation is requested without `voxel_size`
+  """
+  labels = data['labels']
+  patches = data['patches']
+
+  # Apply section-wise augmentations.
+  if section_augment:
+    final_data_zyx = patches.shape.as_list()[1:4]
+    final_label_zyx = labels.shape.as_list()[1:4]
+    patches, labels, _ = augmentation.apply_section_augmentations(
+        patches,
+        labels,
+        labels,
+        final_data_zyx,
+        final_label_zyx,
+        final_label_zyx,
+        **(section_augmentation_args or {}),
+    )
+
+  # Apply basic augmentations.
+  if permute_and_reflect_augment:
+    transform_axes = augmentation.PermuteAndReflect(
+        rank=5,
+        permutable_axes=permutable_axes,
+        reflectable_axes=reflectable_axes,
+    )
+    labels = transform_axes(labels)
+    patches = transform_axes(patches)
+
+  rot_mtx = None
+  if rotation_augmentation == '2d':
+    rot_mtx = augmentation.random_2d_rotation_matrix()
+  elif rotation_augmentation == '3d':
+    rot_mtx = augmentation.random_3d_rotation_matrix()
+
+  if rot_mtx is not None:
+    if labels.dtype == tf.int64:
+      labels = tf.cond(
+          tf.reduce_any(labels > np.iinfo(np.int32).max),  #
+          lambda: make_labels_contiguous(labels),  #
+          lambda: labels,
+      )
+      labels = tf.cast(labels, tf.int32)
+
+    if voxel_size is None:
+      raise ValueError('voxel_size is required for rotation augmentation.')
+    patches = augmentation.apply_rotation(patches, rot_mtx, voxel_size)
+    if labels.shape.as_list() != [1, 1, 1, 1, 1]:
+      labels = augmentation.apply_rotation(labels, rot_mtx, voxel_size)
+
+  data['labels'] = labels
+  data['patches'] = patches
+  return data
+
+
+def interleave(
+    datasets: Sequence[tf.data.Dataset], repeat: bool = True
+) -> tf.data.Dataset:
+  """Interleave two or more datasets together, one at a time.
+
+  Interleaves two independently generated datasets together, contrary to
+  Dataset.interleave which interleaves new Datasets generated from each input
+  item.
+
+  Args:
+    datasets: Sequence of datasets to interleave.
+    repeat: repeat the interleaved sequence.
+
+  Returns:
+    tf.data.Dataset with interleaved results.
+  """
+  choice_dataset = tf.data.Dataset.range(len(datasets))
+  if repeat:
+    choice_dataset = choice_dataset.repeat()
+  return tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+
+
+def sample(
+    datasets: Sequence[tf.data.Dataset],
+    repeat: bool = True,
+    weights: Optional[Sequence[float]] = None,
+) -> tf.data.Dataset:
+  """Weighted sample of two or more datasets.
+
+  Args:
+    datasets: Sequence of datasets to sample.
+    repeat: repeat the sampled sequence.
+    weights: relative weight of each respective dataset.
+
+  Returns:
+    tf.data.Dataset with sampled results.
+  """
+  if weights is None:
+    weights = [1.0] * len(datasets)
+  sampled_dataset = tf.data.experimental.sample_from_datasets(datasets, weights)
+  if repeat:
+    sampled_dataset = sampled_dataset.repeat()
+  return sampled_dataset