diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4a2139c5..5966aa92 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] platform: [ubuntu-latest] steps: diff --git a/docs/source/api.rst b/docs/source/api.rst index 5f120a1c..ea0cfa39 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -91,6 +91,11 @@ BatchFilter Source Nodes ------------ +ArraySource +^^^^^^^^^^^ + + .. autoclass:: ArraySource + ZarrSource ^^^^^^^^^^ .. autoclass:: ZarrSource @@ -334,6 +339,7 @@ Iterative Processing Nodes Scan ^^^^ .. autoclass:: Scan + .. autoclass:: ScanCallback DaisyRequestBlocks ^^^^^^^^^^^^^^^^^^ diff --git a/gunpowder/batch.py b/gunpowder/batch.py index ffc97e77..1ddf200c 100644 --- a/gunpowder/batch.py +++ b/gunpowder/batch.py @@ -44,17 +44,7 @@ class Batch(Freezable): Contains all graphs that have been requested for this batch. """ - __next_id = multiprocessing.Value("L") - - @staticmethod - def get_next_id(): - with Batch.__next_id.get_lock(): - next_id = Batch.__next_id.value - Batch.__next_id.value += 1 - return next_id - def __init__(self): - self.id = Batch.get_next_id() self.profiling_stats = ProfilingStats() self.arrays = {} self.graphs = {} diff --git a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py index f7f08599..36e182e1 100644 --- a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py +++ b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py @@ -1,4 +1,3 @@ -import distutils.util import numpy as np import logging import requests @@ -15,6 +14,16 @@ logger = logging.getLogger(__name__) +def strtobool(val): + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif val in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"Invalid truth value: {val}") + + class DvidPartnerAnnoationSourceReadException(Exception): pass @@ -198,10 +207,10 @@ def __read_syn_points(self, roi): props["agent"] = str(node["Prop"]["agent"]) if "flagged" in node["Prop"]: str_value_flagged = str(node["Prop"]["flagged"]) - props["flagged"] = bool(distutils.util.strtobool(str_value_flagged)) + props["flagged"] = bool(strtobool(str_value_flagged)) if "multi" in node["Prop"]: str_value_multi = str(node["Prop"]["multi"]) - props["multi"] = bool(distutils.util.strtobool(str_value_multi)) + props["multi"] = bool(strtobool(str_value_multi)) # create synPoint with information collected so far (partner_ids not completed yet) if kind == "PreSyn": diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 42cbf46f..30dbf848 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from .array_source import ArraySource from .add_affinities import AddAffinities from .astype import AsType from .balance_labels import BalanceLabels @@ -34,7 +35,7 @@ from .reject import Reject from .renumber_connected_components import RenumberConnectedComponents from .resample import Resample -from .scan import Scan +from .scan import Scan, ScanCallback from .shift_augment import ShiftAugment from .simple_augment import SimpleAugment from .snapshot import Snapshot diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py new file mode 100644 index 00000000..ad4573e1 --- /dev/null +++ 
b/gunpowder/nodes/array_source.py @@ -0,0 +1,57 @@ +from funlib.persistence.arrays import Array as PersistenceArray +from gunpowder.array import Array, ArrayKey +from gunpowder.array_spec import ArraySpec +from gunpowder.batch import Batch +from .batch_provider import BatchProvider + + +class ArraySource(BatchProvider): + """A ``funlib.persistence.Array`` source. + + Provides a source for any array that can fit into the funkelab + funlib.persistence.Array format. This class comes with assumptions about + the available metadata and convenient methods for indexing the data + with a :class:`Roi` in world units. + + Args: + + key (:class:`ArrayKey`): + + The ArrayKey for accessing this array. + + array (``Array``): + + A `funlib.persistence.Array` object. + + interpolatable (``bool``, optional): + + Whether the array is interpolatable. If not given it is + guessed based on dtype. + + """ + + def __init__( + self, + key: ArrayKey, + array: PersistenceArray, + interpolatable: bool | None = None, + ): + self.key = key + self.array = array + self.array_spec = ArraySpec( + self.array.roi, + self.array.voxel_size, + interpolatable, + False, + self.array.dtype, + ) + + def setup(self): + self.provides(self.key, self.array_spec) + + def provide(self, request): + outputs = Batch() + out_spec = self.array_spec.copy() + out_spec.roi = request[self.key].roi + outputs[self.key] = Array(self.array[out_spec.roi], out_spec) + return outputs diff --git a/gunpowder/nodes/batch_filter.py b/gunpowder/nodes/batch_filter.py index 2ba954c1..4f1e4da2 100644 --- a/gunpowder/nodes/batch_filter.py +++ b/gunpowder/nodes/batch_filter.py @@ -137,7 +137,7 @@ def autoskip_enabled(self): return self._autoskip_enabled def provide(self, request): - skip = self.__can_skip(request) + skip = self.__can_skip(request) or self.skip_node(request) timing_prepare = Timing(self, "prepare") timing_prepare.start() @@ -206,6 +206,14 @@ def __can_skip(self, request): return True + def skip_node(self, request): + """To be implemented in subclasses. + + Skip the node if a condition is met. Useful, for example, to apply + an augmentation only with a given probability. + """ + pass + def setup(self): """To be implemented in subclasses. diff --git a/gunpowder/nodes/csv_points_source.py b/gunpowder/nodes/csv_points_source.py index 59e6c193..7f40df32 100644 --- a/gunpowder/nodes/csv_points_source.py +++ b/gunpowder/nodes/csv_points_source.py @@ -1,19 +1,23 @@ +from typing import Union, Optional import numpy as np import logging from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate from gunpowder.nodes.batch_provider import BatchProvider -from gunpowder.graph import Node, Graph +from gunpowder.graph import Node, Graph, GraphKey from gunpowder.graph_spec import GraphSpec from gunpowder.profiling import Timing from gunpowder.roi import Roi +import csv logger = logging.getLogger(__name__) class CsvPointsSource(BatchProvider): """Read a set of points from a comma-separated-values text file. Each line - in the file represents one point, e.g. z y x (id) + in the file represents one point, e.g. z y x (id). Note: this reads all + points into memory and finds the ones in the given roi by iterating + over all the points. For large datasets, this may be too slow. Args: @@ -25,6 +29,11 @@ class CsvPointsSource(BatchProvider): The key of the points set to create.
+ spatial_cols (list[``int``]): + + The columns of the csv that hold the coordinates of the points + (in the order in which they should be used in training). + points_spec (:class:`GraphSpec`, optional): An optional :class:`GraphSpec` to overwrite the points specs @@ -37,28 +46,36 @@ class CsvPointsSource(BatchProvider): from the CSV file. This is useful if the points refer to voxel positions to convert them to world units. - ndims (``int``): + id_col (``int``, optional): - If ``ndims`` is None, all values in one line are considered as the - location of the point. If positive, only the first ``ndims`` are used. - If negative, all but the last ``-ndims`` are used. + The column of the csv that holds an id for each point. If not + provided, the row indices are used as the ids. When read + from file, ids are left as strings and not cast to anything. - id_dim (``int``): + delimiter (``str``, optional): - Each line may optionally contain an id for each point. This parameter - specifies its location, has to come after the position values. + Delimiter to pass to the csv reader. Defaults to ",". """ def __init__( - self, filename, points, points_spec=None, scale=None, ndims=None, id_dim=None + self, + filename: str, + points: GraphKey, + spatial_cols: list[int], + points_spec: Optional[GraphSpec] = None, + scale: Optional[Union[int, float, tuple, list, np.ndarray]] = None, + id_col: Optional[int] = None, + delimiter: str = ",", ): self.filename = filename self.points = points self.points_spec = points_spec self.scale = scale - self.ndims = ndims - self.id_dim = id_dim - self.data = None + self.spatial_cols = spatial_cols + self.id_dim = id_col + self.delimiter = delimiter + self.data: Optional[np.ndarray] = None + self.ids: Optional[list] = None def setup(self): self._parse_csv() @@ -67,8 +84,8 @@ def setup(self): self.provides(self.points, self.points_spec) return - min_bb = Coordinate(np.floor(np.amin(self.data[:, : self.ndims], 0))) - max_bb = Coordinate(np.ceil(np.amax(self.data[:, : self.ndims], 0)) + 1) + min_bb = Coordinate(np.floor(np.amin(self.data, 0))) + max_bb = Coordinate(np.ceil(np.amax(self.data, 0)) + 1) roi = Roi(min_bb, max_bb - min_bb) @@ -84,7 +101,7 @@ def provide(self, request): logger.debug("CSV points source got request for %s", request[self.points].roi) point_filter = np.ones((self.data.shape[0],), dtype=bool) - for d in range(self.ndims): + for d in range(len(self.spatial_cols)): point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d]) point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d]) @@ -100,30 +117,35 @@ def provide(self, request): return batch def _get_points(self, point_filter): - filtered = self.data[point_filter][:, : self.ndims] - - if self.id_dim is not None: - ids = self.data[point_filter][:, self.id_dim] - else: - ids = np.arange(len(self.data))[point_filter] - + filtered = self.data[point_filter] + ids = self.ids[point_filter] return [Node(id=i, location=p) for i, p in zip(ids, filtered)] def _parse_csv(self): - """Read one point per line. If ``ndims`` is None, all values in one line - are considered as the location of the point. If positive, only the - first ``ndims`` are used. If negative, all but the last ``-ndims`` are - used. + """Read one point per line, with spatial and id columns determined by + ``spatial_cols`` and ``id_col``.
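A minimal usage sketch of the new column-based API (file path, key name, and column layout here are hypothetical, not from this PR):

```python
from gunpowder import BatchRequest, CsvPointsSource, GraphKey, GraphSpec, Roi, build

points = GraphKey("POINTS")

# hypothetical file with rows like "12.0,34.5,56.0,7":
# three spatial columns followed by an id column
source = CsvPointsSource(
    "points.csv",            # hypothetical path
    points,
    spatial_cols=[0, 1, 2],  # z, y, x, in the order used for training
    id_col=3,                # omit to fall back to row indices as ids
    delimiter=",",
)

request = BatchRequest()
request[points] = GraphSpec(roi=Roi((0, 0, 0), (64, 64, 64)))

with build(source):
    batch = source.request_batch(request)
    print(len(list(batch[points].nodes)), "points in ROI")
```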
""" - - with open(self.filename, "r") as f: - self.data = np.array( - [[float(t.strip(",")) for t in line.split()] for line in f], - dtype=np.float32, - ) - - if self.ndims is None: - self.ndims = self.data.shape[1] + data = [] + ids = [] + with open(self.filename, "r", newline="") as f: + has_header = csv.Sniffer().has_header(f.read(1024)) + f.seek(0) + first_line = True + reader = csv.reader(f, delimiter=self.delimiter) + for line in reader: + if first_line and has_header: + first_line = False + continue + space = [float(line[c]) for c in self.spatial_cols] + data.append(space) + if self.id_dim is not None: + ids.append(line[self.id_dim]) + + self.data = np.array(data, dtype=np.float32) + if self.id_dim: + self.ids = np.array(ids) + else: + self.ids = np.arange(len(self.data)) if self.scale is not None: - self.data[:, : self.ndims] *= self.scale + self.data *= self.scale diff --git a/gunpowder/nodes/defect_augment.py b/gunpowder/nodes/defect_augment.py index b7eb56d1..8ca6b6f3 100644 --- a/gunpowder/nodes/defect_augment.py +++ b/gunpowder/nodes/defect_augment.py @@ -65,6 +65,13 @@ class DefectAugment(BatchFilter): axis (``int``, optional): Along which axis sections are cut. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -80,6 +87,7 @@ def __init__( artifacts_mask=None, deformation_strength=20, axis=0, + p=1.0, ): self.intensities = intensities self.prob_missing = prob_missing @@ -92,6 +100,7 @@ def __init__( self.artifacts_mask = artifacts_mask self.deformation_strength = deformation_strength self.axis = axis + self.p = p def setup(self): if self.artifact_source is not None: @@ -101,6 +110,9 @@ def teardown(self): if self.artifact_source is not None: self.artifact_source.teardown() + def skip_node(self, request): + return random.random() > self.p + # send roi request to data-source upstream def prepare(self, request): deps = BatchRequest() diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index 36f19c1c..c78ddc1f 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -82,6 +82,14 @@ class DeformAugment(BatchFilter): Whether or not to compute the elastic transform node wise for nodes that were lossed during the fast elastic transform process. + + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. 
""" def __init__( @@ -96,6 +104,7 @@ def __init__( recompute_missing_points=True, transform_key: Optional[ArrayKey] = None, graph_raster_voxel_size: Optional[Coordinate] = None, + p: float = 1.0, ): self.control_point_spacing = Coordinate(control_point_spacing) self.jitter_sigma = Coordinate(jitter_sigma) @@ -121,6 +130,7 @@ def __init__( f"jitter_sigma: {self.jitter_sigma}, " f"and graph_raster_voxel_size must have the same number of dimensions" ) + self.p = p def setup(self): if self.transform_key is not None: @@ -137,6 +147,9 @@ def setup(self): self.provides(self.transform_key, spec) + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): # get the total ROI of all requests total_roi = request.get_total_roi() @@ -499,6 +512,9 @@ def __create_transformation(self, target_spec: ArraySpec): local_transformation = upscale_transformation( local_transformation, target_shape ) + global_transformation = upscale_transformation( + global_transformation, target_shape + ) # transform into world units global_transformation *= np.array(target_spec.voxel_size).reshape( diff --git a/gunpowder/nodes/dvid_source.py b/gunpowder/nodes/dvid_source.py index d285a502..312dd59e 100644 --- a/gunpowder/nodes/dvid_source.py +++ b/gunpowder/nodes/dvid_source.py @@ -182,11 +182,8 @@ def __get_spec(self, array_key): spec.dtype = data_dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index 88b881b3..c5da5f81 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -87,6 +87,13 @@ class ElasticAugment(BatchFilter): Whether or not to compute the elastic transform node wise for nodes that were lossed during the fast elastic transform process. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. 
""" def __init__( @@ -102,6 +109,7 @@ def __init__( spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, + p=1.0, ): warnings.warn( "ElasticAugment is deprecated, please use the DeformAugment", @@ -121,6 +129,10 @@ def __init__( self.spatial_dims = spatial_dims self.use_fast_points_transform = use_fast_points_transform self.recompute_missing_points = recompute_missing_points + self.p = p + + def skip_node(self, request): + return random.random() > self.p def prepare(self, request): # get the voxel size diff --git a/gunpowder/nodes/hdf5like_source_base.py b/gunpowder/nodes/hdf5like_source_base.py index d7c63149..f5a8e58b 100644 --- a/gunpowder/nodes/hdf5like_source_base.py +++ b/gunpowder/nodes/hdf5like_source_base.py @@ -174,11 +174,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s " diff --git a/gunpowder/nodes/intensity_augment.py b/gunpowder/nodes/intensity_augment.py index 771f57bb..1055549f 100644 --- a/gunpowder/nodes/intensity_augment.py +++ b/gunpowder/nodes/intensity_augment.py @@ -1,4 +1,5 @@ import numpy as np +import random from gunpowder.batch_request import BatchRequest @@ -34,6 +35,13 @@ class IntensityAugment(BatchFilter): Set to False if modified values should not be clipped to [0, 1] Disables range check! + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -45,6 +53,7 @@ def __init__( shift_max, z_section_wise=False, clip=True, + p=1.0, ): self.array = array self.scale_min = scale_min @@ -53,11 +62,15 @@ def __init__( self.shift_max = shift_max self.z_section_wise = z_section_wise self.clip = clip + self.p = p def setup(self): self.enable_autoskip() self.updates(self.array, self.spec[self.array]) + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): deps = BatchRequest() deps[self.array] = request[self.array].copy() diff --git a/gunpowder/nodes/klb_source.py b/gunpowder/nodes/klb_source.py index 53eca5c4..d4a55049 100644 --- a/gunpowder/nodes/klb_source.py +++ b/gunpowder/nodes/klb_source.py @@ -155,11 +155,8 @@ def __read_spec(self, headers): spec.dtype = dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/noise_augment.py b/gunpowder/nodes/noise_augment.py index 5275a2c0..c2ff223f 100644 --- a/gunpowder/nodes/noise_augment.py +++ b/gunpowder/nodes/noise_augment.py @@ -1,4 +1,5 @@ import numpy as np +import random import skimage from gunpowder.batch_request import BatchRequest @@ -24,18 +25,29 @@ class NoiseAugment(BatchFilter): Whether to preserve the image range (either [-1, 1] or [0, 1]) by clipping values in the end, see scikit-image documentation + + p (``float``, optional): + + Probability applying the augmentation. 
Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ - def __init__(self, array, mode="gaussian", clip=True, **kwargs): + def __init__(self, array, mode="gaussian", clip=True, p=1.0, **kwargs): self.array = array self.mode = mode self.clip = clip + self.p = p self.kwargs = kwargs def setup(self): self.enable_autoskip() self.updates(self.array, self.spec[self.array]) + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): deps = BatchRequest() deps[self.array] = request[self.array].copy() diff --git a/gunpowder/nodes/scan.py b/gunpowder/nodes/scan.py index ef6b378e..3473764e 100644 --- a/gunpowder/nodes/scan.py +++ b/gunpowder/nodes/scan.py @@ -2,6 +2,7 @@ import multiprocessing import numpy as np import tqdm +from abc import ABC from gunpowder.array import Array from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate @@ -13,6 +14,55 @@ logger = logging.getLogger(__name__) +class ScanCallback(ABC): + """Base class for :class:`Scan` callbacks. Implement any of ``start``, + ``update``, and ``stop`` in a subclass to create your own callback. + """ + + def start(self, num_total): + """Called once before :class:`Scan` starts scanning over chunks. + + Args: + + num_total (int): + + The total number of chunks to process. + """ + pass + + def update(self, num_processed): + """Called periodically by :class:`Scan` while processing chunks. + + Args: + + num_processed (int): + + The number of chunks already processed. + """ + pass + + def stop(self): + """Called once after :class:`Scan` scanned over all chunks.""" + pass + + +class TqdmCallback(ScanCallback): + """A default callback that uses ``tqdm`` to show a progress bar.""" + + def start(self, num_total): + logger.info("scanning over %d chunks", num_total) + + self.progress_bar = tqdm.tqdm(desc="Scan, chunks processed", total=num_total) + self.num_processed = 0 + + def update(self, num_processed): + self.progress_bar.update(num_processed - self.num_processed) + self.num_processed = num_processed + + def stop(self): + self.progress_bar.close() + + class Scan(BatchFilter): """Iteratively requests batches of size ``reference`` from upstream providers in a scanning fashion, until all requested ROIs are covered. If @@ -40,14 +90,24 @@ class Scan(BatchFilter): cache_size (``int``, optional): If multiple workers are used, how many batches to hold at most. + + progress_callback (:class:`ScanCallback`, optional): + + A callback instance that this node updates while processing + chunks. See :class:`ScanCallback` for details. The default is a + callback that shows a ``tqdm`` progress bar.
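A sketch of a custom callback using the new hook (the class name and logging cadence are made up; ``reference`` stands for a hypothetical ``BatchRequest`` template):

```python
import logging

from gunpowder.nodes import Scan, ScanCallback

logger = logging.getLogger(__name__)


class LogEveryN(ScanCallback):
    """Log progress every ``n`` chunks instead of drawing a progress bar."""

    def __init__(self, n=100):
        self.n = n
        self.num_total = None

    def start(self, num_total):
        self.num_total = num_total

    def update(self, num_processed):
        if num_processed % self.n == 0:
            logger.info("processed %d/%d chunks", num_processed, self.num_total)

    def stop(self):
        logger.info("done: %d chunks scanned", self.num_total)


scan = Scan(reference, num_workers=8, progress_callback=LogEveryN(n=250))
```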
""" - def __init__(self, reference, num_workers=1, cache_size=50): + def __init__(self, reference, num_workers=1, cache_size=50, progress_callback=None): self.reference = reference.copy() self.num_workers = num_workers self.cache_size = cache_size self.workers = None self.batch = None + if progress_callback is None: + self.progress_callback = TqdmCallback() + else: + self.progress_callback = progress_callback def setup(self): if self.num_workers > 1: @@ -75,7 +135,8 @@ def provide(self, request): shifts = self._enumerate_shifts(shift_roi, stride) num_chunks = len(shifts) - logger.info("scanning over %d chunks", num_chunks) + if self.progress_callback is not None: + self.progress_callback.start(num_chunks) # the batch to return self.batch = Batch() @@ -85,24 +146,33 @@ def provide(self, request): shifted_reference = self._shift_request(self.reference, shift) self.request_queue.put(shifted_reference) - for i in tqdm.tqdm(range(num_chunks)): + for i in range(num_chunks): chunk = self.workers.get() if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) else: - for i, shift in enumerate(tqdm.tqdm(shifts)): + for i, shift in enumerate(shifts): shifted_reference = self._shift_request(self.reference, shift) chunk = self._get_chunk(shifted_reference) if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) + if self.progress_callback is not None: + self.progress_callback.stop() + batch = self.batch self.batch = None diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index d42b1434..9ddace33 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -11,17 +11,21 @@ class ShiftAugment(BatchFilter): - def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0): + def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0, p=1.0): self.prob_slip = prob_slip self.prob_shift = prob_shift self.sigma = sigma self.shift_axis = shift_axis + self.p = p self.ndim = None self.shift_sigmas = None self.shift_array = None self.lcm_voxel_size = None + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): self.ndim = request.get_total_roi().dims assert self.shift_axis in range(self.ndim) diff --git a/gunpowder/nodes/simple_augment.py b/gunpowder/nodes/simple_augment.py index adde087b..ffaf16fb 100644 --- a/gunpowder/nodes/simple_augment.py +++ b/gunpowder/nodes/simple_augment.py @@ -47,6 +47,13 @@ class SimpleAugment(BatchFilter): and attempt to weight them appropriately. A weight of 0 means this axis will never be transposed, a weight of 1 means this axis will always be transposed. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. 
""" def __init__( @@ -55,6 +62,7 @@ def __init__( transpose_only=None, mirror_probs=None, transpose_probs=None, + p=1.0, ): self.mirror_only = mirror_only self.mirror_probs = mirror_probs @@ -63,6 +71,7 @@ def __init__( self.mirror_mask = None self.dims = None self.transpose_dims = None + self.p = p def setup(self): self.dims = self.spec.get_total_roi().dims @@ -105,6 +114,9 @@ def setup(self): if valid: self.permutation_dict[k] = v + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): self.mirror = [ random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0 diff --git a/gunpowder/nodes/snapshot.py b/gunpowder/nodes/snapshot.py index 5c66280c..acc3f624 100644 --- a/gunpowder/nodes/snapshot.py +++ b/gunpowder/nodes/snapshot.py @@ -76,7 +76,7 @@ def __init__( self, dataset_names, output_dir="snapshots", - output_filename="{id}.zarr", + output_filename="{iteration}.zarr", every=1, additional_request=None, compression_type=None, @@ -99,6 +99,7 @@ def __init__( self.dataset_dtypes = dataset_dtypes self.mode = "w" + self.id = 0 def write_if(self, batch): """To be implemented in subclasses. @@ -159,6 +160,7 @@ def prepare(self, request): return deps def process(self, batch, request): + self.id += 1 if self.record_snapshot and self.write_if(batch): try: os.makedirs(self.output_dir) @@ -168,7 +170,7 @@ def process(self, batch, request): snapshot_name = os.path.join( self.output_dir, self.output_filename.format( - id=str(batch.id).zfill(8), iteration=int(batch.iteration or 0) + id=str(self.id).zfill(8), iteration=int(batch.iteration or self.id) ), ) logger.info("saving to %s" % snapshot_name) diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 2f1c15fc..82831fa3 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -206,11 +206,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s " diff --git a/gunpowder/tensorflow/nodes/predict.py b/gunpowder/tensorflow/nodes/predict.py index 0a92a0f6..d2c03498 100644 --- a/gunpowder/tensorflow/nodes/predict.py +++ b/gunpowder/tensorflow/nodes/predict.py @@ -112,7 +112,7 @@ def predict(self, batch, request): break if can_skip: - logger.info("Skipping batch %i (all inputs are 0)" % batch.id) + logger.info(f"Skipping batch for request: {request} (all inputs are 0)") for name, array_key in self.outputs.items(): shape = self.shared_output_arrays[name].shape @@ -124,7 +124,7 @@ def predict(self, batch, request): return - logger.debug("predicting in batch %i", batch.id) + logger.debug(f"predicting for request: {request}") output_tensors = self.__collect_outputs(request) input_data = self.__collect_provided_inputs(batch) @@ -160,7 +160,7 @@ def predict(self, batch, request): spec.roi = request[array_key].roi batch.arrays[array_key] = Array(output_data[array_key], spec) - logger.debug("predicted in batch %i", batch.id) + logger.debug("predicted") def __predict(self): """The background predict process.""" diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 89c9ac0c..3bc58344 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -18,10 +18,10 @@ class 
Predict(GenericPredict): The model to use for prediction. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): - Dictionary from the names of input tensors (argument names of the - ``forward`` method) in the model to array keys. + Dictionary from the position (for positional args) or name (for keyword + args) of the input tensors of the model's ``forward`` method to array keys. outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): @@ -58,7 +58,7 @@ class Predict(GenericPredict): def __init__( self, model, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[str, int], ArrayKey], array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint: Optional[str] = None, @@ -111,18 +111,24 @@ def start(self): self.register_hooks() def predict(self, batch, request): - inputs = self.get_inputs(batch) + input_args, input_kwargs = self.get_inputs(batch) with torch.no_grad(): - out = self.model.forward(**inputs) + out = self.model.forward(*input_args, **input_kwargs) outputs = self.get_outputs(out, request) self.update_batch(batch, request, outputs) def get_inputs(self, batch): - model_inputs = { + model_args = [ + torch.as_tensor(batch[self.inputs[ii]].data, device=self.device) + for ii in range(len(self.inputs)) + if ii in self.inputs + ] + model_kwargs = { key: torch.as_tensor(batch[value].data, device=self.device) for key, value in self.inputs.items() + if isinstance(key, str) } - return model_inputs + return model_args, model_kwargs def register_hooks(self): for key in self.outputs: diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 676b2c71..241332db 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -29,7 +29,7 @@ class Train(GenericTrain): The torch optimizer to use. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): Dictionary from the names of input tensors (argument names of the ``forward`` method) in the model to array keys. @@ -52,6 +52,16 @@ class Train(GenericTrain): New arrays will be generated by this node for each entry (if requested downstream). + gradients (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`, optional): + + Dictionary from the names of tensors in the network to array + keys. If the key is a string, the tensor will be retrieved + by checking the model for an attribute with the key as its name. + If the key is an integer, it is interpreted as a tuple index of + the outputs of the network. + Instead of the actual array, the gradient of the array with respect + to the loss will be generated and saved.
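With int keys now accepted, ``forward`` inputs can be wired positionally or by name; a sketch (the model and array keys are hypothetical, mirroring the parametrized tests below):

```python
from gunpowder import ArrayKey
from gunpowder.torch import Predict

raw_a = ArrayKey("RAW_A")
raw_b = ArrayKey("RAW_B")
out = ArrayKey("OUT")

# for a model with forward(self, a, b), these two are equivalent:
predict = Predict(model, inputs={0: raw_a, 1: raw_b}, outputs={0: out})
predict = Predict(model, inputs={"a": raw_a, "b": raw_b}, outputs={0: out})
```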
+ array_specs (``dict``, :class:`ArrayKey` -> :class:`ArraySpec`, optional): Used to set the specs of generated arrays (at the moment only @@ -92,7 +102,7 @@ def __init__( model, loss, optimizer, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[int, str], ArrayKey], loss_inputs: Dict[Union[int, str], ArrayKey], gradients: Dict[Union[int, str], ArrayKey] = {}, @@ -112,11 +122,13 @@ def __init__( # not yet implemented gradients = gradients - all_inputs = { - k: v - for k, v in itertools.chain(inputs.items(), loss_inputs.items()) - if v not in outputs.values() + loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} + all_inputs: dict[str | int, Any] = { + f"{k}": v for k, v in inputs.items() if v not in outputs.values() } + all_inputs.update( + {k: v for k, v in loss_inputs.items() if v not in outputs.values()} + ) super(Train, self).__init__( all_inputs, @@ -208,16 +220,22 @@ def start(self): def train_step(self, batch, request): inputs = self.__collect_provided_inputs(batch) + inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()} requested_outputs = self.__collect_requested_outputs(request) # keys are argument names of model forward pass - device_inputs = { - k: torch.as_tensor(v, device=self.device) for k, v in inputs.items() - } + device_input_args = [] + for i in range(len(inputs)): + key = f"{i}" + if key in inputs: + device_input_args.append(inputs.pop(key)) + else: + break + device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)} # get outputs. Keys are tuple indices or model attr names as in self.outputs self.optimizer.zero_grad() - model_outputs = self.model(**device_inputs) + model_outputs = self.model(*device_input_args, **device_input_kwargs) if isinstance(model_outputs, tuple): outputs = {i: model_outputs[i] for i in range(len(model_outputs))} elif isinstance(model_outputs, torch.Tensor): @@ -247,8 +265,9 @@ def train_step(self, batch, request): device_loss_args = [] for i in range(len(device_loss_inputs)): - if i in device_loss_inputs: - device_loss_args.append(device_loss_inputs.pop(i)) + key = f"loss_{i}" + if key in device_loss_inputs: + device_loss_args.append(device_loss_inputs.pop(key)) else: break device_loss_kwargs = {} @@ -327,7 +346,12 @@ def __collect_requested_outputs(self, request): def __collect_provided_inputs(self, batch): return self.__collect_provided_arrays( - {k: v for k, v in self.inputs.items() if k not in self.loss_inputs}, batch + { + k: v + for k, v in self.inputs.items() + if (isinstance(k, int) or k not in self.loss_inputs) + }, + batch, ) def __collect_provided_loss_inputs(self, batch): diff --git a/gunpowder/version_info.py b/gunpowder/version_info.py index b4de72dc..1aa73336 100644 --- a/gunpowder/version_info.py +++ b/gunpowder/version_info.py @@ -1,6 +1,6 @@ __major__ = 1 -__minor__ = 3 -__patch__ = 4 +__minor__ = 4 +__patch__ = 0 __tag__ = "" __version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".") diff --git a/pyproject.toml b/pyproject.toml index ead9be13..ecf3ddcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dynamic = ["version"] classifiers = ["Programming Language :: Python :: 3"] keywords = [] -requires-python = ">=3.7" +requires-python = ">=3.9" dependencies = [ "numpy>=1.24", @@ -40,10 +40,20 @@ dependencies = [ "funlib.geometry>=0.3", "zarr", "networkx>=3.1", + "funlib.persistence>=0.5", ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "flake8", "mypy", 
"types-requests", "types-tqdm"] +dev = [ + "pytest", + "pytest-cov", + "flake8", + "mypy", + "types-requests", + "types-tqdm", + "black", + "ruff", +] docs = [ "sphinx", "sphinx_rtd_theme", diff --git a/tests/cases/array_source.py b/tests/cases/array_source.py new file mode 100644 index 00000000..f7cb666b --- /dev/null +++ b/tests/cases/array_source.py @@ -0,0 +1,29 @@ +from funlib.persistence import prepare_ds +from funlib.geometry import Roi +from gunpowder.nodes import ArraySource +from gunpowder import ArrayKey, build, BatchRequest, ArraySpec + +import numpy as np + + +def test_array_source(tmpdir): + array = prepare_ds( + tmpdir / "data.zarr", + shape=(100, 102, 108), + offset=(100, 50, 0), + voxel_size=(1, 2, 3), + dtype="uint8", + ) + array[:] = np.arange(100 * 102 * 108).reshape((100, 102, 108)) % 255 + + key = ArrayKey("TEST") + + source = ArraySource(key=key, array=array) + + with build(source): + request = BatchRequest() + + roi = Roi((100, 100, 102), (30, 30, 30)) + request[key] = ArraySpec(roi) + + assert np.array_equal(source.request_batch(request)[key].data, array[roi]) diff --git a/tests/cases/batch_filter.py b/tests/cases/batch_filter.py new file mode 100644 index 00000000..63288bd5 --- /dev/null +++ b/tests/cases/batch_filter.py @@ -0,0 +1,53 @@ +from .helper_sources import ArraySource +from gunpowder import ( + ArrayKey, + build, + Array, + ArraySpec, + Roi, + Coordinate, + BatchRequest, + BatchFilter, +) + +import numpy as np +import random + + +class DummyNode(BatchFilter): + def __init__(self, array, p=1.0): + self.array = array + self.p = p + + def skip_node(self, request): + return random.random() > self.p + + def process(self, batch, request): + batch[self.array].data = batch[self.array].data + 1 + + +def test_skip(): + raw_key = ArrayKey("RAW") + array = Array( + np.ones((10, 10)), + ArraySpec(Roi((0, 0), (10, 10)), Coordinate(1, 1)), + ) + source = ArraySource(raw_key, array) + + request_1 = BatchRequest(random_seed=1) + request_2 = BatchRequest(random_seed=2) + + request_1.add(raw_key, Coordinate(10, 10)) + request_2.add(raw_key, Coordinate(10, 10)) + + pipeline = source + DummyNode(raw_key, p=0.5) + + with build(pipeline): + batch_1 = pipeline.request_batch(request_1) + batch_2 = pipeline.request_batch(request_2) + + x_1 = batch_1.arrays[raw_key].data + x_2 = batch_2.arrays[raw_key].data + + assert x_1.max() == 2 + assert x_2.max() == 1 diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py new file mode 100644 index 00000000..ec8d88ec --- /dev/null +++ b/tests/cases/csv_points_source.py @@ -0,0 +1,109 @@ +import random + +import numpy as np +import pytest +import csv + +from gunpowder import ( + BatchRequest, + CsvPointsSource, + GraphKey, + GraphSpec, + build, + Coordinate, + Roi, +) + + +# automatically set the seed for all tests +@pytest.fixture(autouse=True) +def seeds(): + random.seed(12345) + np.random.seed(12345) + + +@pytest.fixture +def test_points_2d(tmpdir): + + fake_points_file = tmpdir / "shift_test.csv" + fake_points = np.random.randint(0, 100, size=(2, 2)) + with open(fake_points_file, "w") as f: + for point in fake_points: + f.write(str(point[0]) + "\t" + str(point[1]) + "\n") + + yield fake_points_file, fake_points + + +@pytest.fixture +def test_points_3d(tmpdir): + + fake_points_file = tmpdir / "shift_test.csv" + fake_points = np.random.randint(0, 100, size=(3, 3)).astype(float) + with open(fake_points_file, "w") as f: + writer = csv.DictWriter(f, fieldnames=["x", "y", "z", "id"]) + writer.writeheader() + for i, 
point in enumerate(fake_points): + pointdict = {"x": point[0], "y": point[1], "z": point[2], "id": i} + writer.writerow(pointdict) + + yield fake_points_file, fake_points + + +def test_pipeline_2d(test_points_2d): + fake_points_file, fake_points = test_points_2d + + points_key = GraphKey("TEST_POINTS") + + csv_source = CsvPointsSource( + fake_points_file, + points_key, + spatial_cols=[0, 1], + delimiter="\t", + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + ) + + request = BatchRequest() + shape = Coordinate((100, 100)) + request.add(points_key, shape) + + pipeline = csv_source + with build(pipeline) as b: + request = b.request_batch(request) + + target_locs = [list(fake_point) for fake_point in fake_points] + result_points = list(request[points_key].nodes) + result_locs = [list(point.location) for point in result_points] + + assert sorted(result_locs) == sorted(target_locs) + + +def test_pipeline_3d(test_points_3d): + fake_points_file, fake_points = test_points_3d + + points_key = GraphKey("TEST_POINTS") + scale = 2 + csv_source = CsvPointsSource( + fake_points_file, + points_key, + spatial_cols=[0, 2, 1], + delimiter=",", + id_col=3, + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + scale=scale, + ) + + request = BatchRequest() + shape = Coordinate((100, 100, 100)) + request.add(points_key, shape) + + pipeline = csv_source + with build(pipeline) as b: + request = b.request_batch(request) + + result_points = list(request[points_key].nodes) + for node in result_points: + orig_loc = fake_points[int(node.id)] + reordered_loc = orig_loc.copy() + reordered_loc[1] = orig_loc[2] + reordered_loc[2] = orig_loc[1] + assert list(node.location) == list(reordered_loc * scale) diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 31c0a848..a20861ec 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -16,10 +16,13 @@ Pad, Roi, build, + MergeProvider, ) from .helper_sources import ArraySource, GraphSource +from itertools import product + @pytest.mark.parametrize("mode", ["constant", "reflect"]) def test_padding(mode): diff --git a/tests/cases/random_location.py b/tests/cases/random_location.py index 52c4ccce..65bef230 100644 --- a/tests/cases/random_location.py +++ b/tests/cases/random_location.py @@ -56,7 +56,7 @@ def accepts(self, request): return request.array_specs[self.array].roi.contains((0, 0, 0)) -def test_output(): +def test_random_shift(): a = ArrayKey("A") b = ArrayKey("B") random_shift_key = ArrayKey("RANDOM_SHIFT") @@ -116,7 +116,7 @@ def test_output(): assert len(sums) > 1 -def test_output(): +def test_random_location(): a = ArrayKey("A") b = ArrayKey("B") source_a = ExampleSourceRandomLocation(a) diff --git a/tests/cases/shift_augment.py b/tests/cases/shift_augment.py index 75ab40e3..b53b9b4a 100644 --- a/tests/cases/shift_augment.py +++ b/tests/cases/shift_augment.py @@ -33,8 +33,6 @@ def seeds(): @pytest.fixture def test_points(tmpdir): - random.seed(1234) - np.random.seed(1234) fake_points_file = tmpdir / "shift_test.csv" fake_data_file = tmpdir / "shift_test.hdf5" @@ -46,12 +44,6 @@ def test_points(tmpdir): for point in fake_points: f.write(str(point[0]) + "\t" + str(point[1]) + "\n") - # This fixture will run after seeds since it is set - # with autouse=True. 
So make sure to reset the seeds properly at the end - # of this fixture - random.seed(12345) - np.random.seed(12345) - yield fake_points_file, fake_data_file, fake_points, fake_data @@ -143,7 +135,12 @@ def test_pipeline3(test_points): csv_source = CsvPointsSource( fake_points_file, points_key, - GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + spatial_cols=[ + 0, + 1, + ], + delimiter="\t", + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), ) request = BatchRequest() diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 8fb9e8ec..c66b56f2 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -86,7 +86,8 @@ def forward(self, a, b): ), ], ) -def test_loss_drops(tmpdir, device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_loss_drops(tmpdir, device, input_args): checkpoint_basename = str(tmpdir / "model") a_key = ArrayKey("A") @@ -104,7 +105,7 @@ def test_loss_drops(tmpdir, device): model=model, optimizer=optimizer, loss=loss, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, loss_inputs={0: c_predicted_key, 1: c_key}, outputs={0: c_predicted_key}, gradients={0: c_gradient_key}, @@ -167,7 +168,8 @@ def test_loss_drops(tmpdir, device): ), ], ) -def test_output(device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_spawn_subprocess(device, input_args): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -181,7 +183,7 @@ def test_output(device): source = example_train_source(a_key, b_key, c_key) predict = Predict( model=model, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, outputs={"linear": c_pred, 0: d_pred}, array_specs={ c_key: ArraySpec(nonspatial=True),