From 52248e16960b41ce0ba71e23dbc8568fad672539 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 21:29:56 -0800 Subject: [PATCH] Patch v1.3.2 (#198) Merging patch branch into master and making new release Change Log: bugs fixed: - torch nodes forward hooks moved to start method of nodes to work better with the `spawn_subprocess` flag - torch train/predict work when setting multiprocessing start method = "spawn" - fixed bug in rasterize graph due to usage of the old api for accessing graph nodes new features: - torch nodes support defining specific cuda devices via e.g. "cuda:0" - added support for "reflect" mode padding in the pad node - improved error printing. - Reversed the order of exceptions so the cause of the error is printed at the bottom, and you can scroll up to trace the requests and batches through the tree. Now we no longer need to scroll to the top to see the cause of the error, - PreCache and other multiprocessing nodes no print repeats of the same error, it is simply printed once - removed unhelpful tracebacks (and many repititions) for the `try`, `except` blocks in node superclasses general improvements: - moved many tests from unittest to pytest, adding parametrization to simplify them and cover more cases - improved the dependency version bounds - improved docs, formatting, and pass more ci/cd workflows - removed unused imports and f-strings --- .github/workflows/mypy.yaml | 3 +- .github/workflows/publish-docs.yaml | 3 +- .github/workflows/test.yml | 2 +- docs/source/conf.py | 2 +- examples/cremi/mknet.py | 54 +- examples/cremi/predict.py | 63 +- examples/cremi/train.py | 137 ++--- gunpowder/array.py | 2 - gunpowder/array_spec.py | 9 +- gunpowder/batch.py | 3 +- .../contrib/nodes/add_blobs_from_points.py | 2 +- .../nodes/add_boundary_distance_gradients.py | 3 +- .../nodes/add_gt_mask_exclusive_zone.py | 3 +- .../nodes/add_nonsymmetric_affinities.py | 2 - gunpowder/contrib/nodes/hdf5_points_source.py | 3 - gunpowder/ext/__init__.py | 30 +- gunpowder/graph.py | 3 +- gunpowder/jax/nodes/predict.py | 6 +- gunpowder/jax/nodes/train.py | 6 +- gunpowder/nodes/batch_provider.py | 21 +- gunpowder/nodes/crop.py | 1 - gunpowder/nodes/deform_augment.py | 6 +- gunpowder/nodes/elastic_augment.py | 2 - gunpowder/nodes/exclude_labels.py | 2 +- gunpowder/nodes/generic_predict.py | 4 +- gunpowder/nodes/generic_train.py | 4 +- gunpowder/nodes/grow_boundary.py | 1 - gunpowder/nodes/klb_source.py | 1 - gunpowder/nodes/merge_provider.py | 1 - gunpowder/nodes/noise_augment.py | 2 - gunpowder/nodes/pad.py | 34 +- gunpowder/nodes/precache.py | 3 - gunpowder/nodes/random_location.py | 5 +- gunpowder/nodes/random_provider.py | 1 - gunpowder/nodes/rasterize_graph.py | 11 +- gunpowder/nodes/reject.py | 1 - gunpowder/nodes/shift_augment.py | 2 - gunpowder/nodes/simple_augment.py | 1 - gunpowder/nodes/specified_location.py | 2 - gunpowder/nodes/squeeze.py | 1 - gunpowder/nodes/stack.py | 1 - gunpowder/nodes/unsqueeze.py | 1 - gunpowder/nodes/zarr_source.py | 5 +- gunpowder/nodes/zarr_write.py | 4 +- gunpowder/pipeline.py | 20 +- gunpowder/producer_pool.py | 4 - gunpowder/provider_spec.py | 2 - gunpowder/torch/nodes/predict.py | 20 +- gunpowder/torch/nodes/train.py | 39 +- gunpowder/version_info.py | 2 +- mypy.ini | 37 ++ pyproject.toml | 92 ++- tests/cases/add_affinities.py | 4 - tests/cases/deform_augment.py | 3 + tests/cases/dvid_source.py | 1 - tests/cases/elastic_augment_points.py | 2 - tests/cases/expected_failures.py | 2 +- tests/cases/helper_sources.py | 5 +- tests/cases/intensity_scale_shift.py | 1 - tests/cases/jax_train.py | 5 +- tests/cases/noise_augment.py | 2 - tests/cases/pad.py | 111 ++-- tests/cases/random_location.py | 1 - tests/cases/rasterize_points.py | 412 +++++++------ tests/cases/resample.py | 1 - tests/cases/simple_augment.py | 2 - tests/cases/snapshot.py | 2 - tests/cases/tensorflow_train.py | 2 +- tests/cases/torch_train.py | 574 +++++++++--------- tests/cases/zarr_read_write.py | 2 +- tests/conftest.py | 14 +- 71 files changed, 930 insertions(+), 885 deletions(-) create mode 100644 mypy.ini diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index b5439db2..074fa9e9 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,6 +15,5 @@ jobs: uses: actions/checkout@v2 - name: mypy run: | - pip install . - pip install --upgrade mypy + pip install ".[dev]" mypy gunpowder diff --git a/.github/workflows/publish-docs.yaml b/.github/workflows/publish-docs.yaml index 0f9b78bb..efeb8e85 100644 --- a/.github/workflows/publish-docs.yaml +++ b/.github/workflows/publish-docs.yaml @@ -3,8 +3,7 @@ name: Deploy Docs to GitHub Pages on: push: branches: [main] - pull_request: - branches: [main] + tags: "*" workflow_dispatch: # Allow this job to clone the repo and create a page deployment diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 706882c7..4a2139c5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11"] platform: [ubuntu-latest] steps: diff --git a/docs/source/conf.py b/docs/source/conf.py index 529dbe8b..b0da21cf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,7 +54,7 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/examples/cremi/mknet.py b/examples/cremi/mknet.py index aac8c0df..fe13a7a5 100644 --- a/examples/cremi/mknet.py +++ b/examples/cremi/mknet.py @@ -2,8 +2,8 @@ import tensorflow as tf import json -def create_network(input_shape, name): +def create_network(input_shape, name): tf.reset_default_graph() # create a placeholder for the 3D raw input tensor @@ -11,20 +11,17 @@ def create_network(input_shape, name): # create a U-Net raw_batched = tf.reshape(raw, (1, 1) + input_shape) - unet_output = unet(raw_batched, 6, 4, [[1,3,3],[1,3,3],[1,3,3]]) + unet_output = unet(raw_batched, 6, 4, [[1, 3, 3], [1, 3, 3], [1, 3, 3]]) # add a convolution layer to create 3 output maps representing affinities # in z, y, and x pred_affs_batched = conv_pass( - unet_output, - kernel_size=1, - num_fmaps=3, - num_repetitions=1, - activation='sigmoid') + unet_output, kernel_size=1, num_fmaps=3, num_repetitions=1, activation="sigmoid" + ) # get the shape of the output output_shape_batched = pred_affs_batched.get_shape().as_list() - output_shape = output_shape_batched[1:] # strip the batch dimension + output_shape = output_shape_batched[1:] # strip the batch dimension # the 4D output tensor (3, depth, height, width) pred_affs = tf.reshape(pred_affs_batched, output_shape) @@ -33,46 +30,39 @@ def create_network(input_shape, name): gt_affs = tf.placeholder(tf.float32, shape=output_shape) # create a placeholder for per-voxel loss weights - loss_weights = tf.placeholder( - tf.float32, - shape=output_shape) + loss_weights = tf.placeholder(tf.float32, shape=output_shape) # compute the loss as the weighted mean squared error between the # predicted and the ground-truth affinities - loss = tf.losses.mean_squared_error( - gt_affs, - pred_affs, - loss_weights) + loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights) # use the Adam optimizer to minimize the loss opt = tf.train.AdamOptimizer( - learning_rate=0.5e-4, - beta1=0.95, - beta2=0.999, - epsilon=1e-8) + learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8 + ) optimizer = opt.minimize(loss) # store the network in a meta-graph file - tf.train.export_meta_graph(filename=name + '.meta') + tf.train.export_meta_graph(filename=name + ".meta") # store network configuration for use in train and predict scripts config = { - 'raw': raw.name, - 'pred_affs': pred_affs.name, - 'gt_affs': gt_affs.name, - 'loss_weights': loss_weights.name, - 'loss': loss.name, - 'optimizer': optimizer.name, - 'input_shape': input_shape, - 'output_shape': output_shape[1:] + "raw": raw.name, + "pred_affs": pred_affs.name, + "gt_affs": gt_affs.name, + "loss_weights": loss_weights.name, + "loss": loss.name, + "optimizer": optimizer.name, + "input_shape": input_shape, + "output_shape": output_shape[1:], } - with open(name + '_config.json', 'w') as f: + with open(name + "_config.json", "w") as f: json.dump(config, f) -if __name__ == "__main__": +if __name__ == "__main__": # create a network for training - create_network((84, 268, 268), 'train_net') + create_network((84, 268, 268), "train_net") # create a larger network for faster prediction - create_network((120, 322, 322), 'test_net') + create_network((120, 322, 322), "test_net") diff --git a/examples/cremi/predict.py b/examples/cremi/predict.py index 8693786f..4f229b14 100644 --- a/examples/cremi/predict.py +++ b/examples/cremi/predict.py @@ -2,29 +2,29 @@ import gunpowder as gp import json -def predict(iteration): +def predict(iteration): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") #################### # DECLARE REQUESTS # #################### - with open('test_net_config.json', 'r') as f: + with open("test_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size context = input_size - output_size # formulate the request for what a batch should contain @@ -37,10 +37,8 @@ def predict(iteration): ############################# source = gp.Hdf5Source( - 'sample_A_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw' - }) + "sample_A_padded_20160501.hdf", datasets={raw: "volumes/raw"} + ) # get the ROI provided for raw (we need it later to calculate the ROI in # which we can make predictions) @@ -48,41 +46,35 @@ def predict(iteration): raw_roi = source.spec[raw].roi pipeline = ( - # read from HDF5 file - source + - + source + + # convert raw to float in [0, 1] - gp.Normalize(raw) + - + gp.Normalize(raw) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Predict( - graph='test_net.meta', - checkpoint='train_net_checkpoint_%d'%iteration, - inputs={ - net_config['raw']: raw - }, - outputs={ - net_config['pred_affs']: pred_affs - }, - array_specs={ - pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context)) - }) + - + graph="test_net.meta", + checkpoint="train_net_checkpoint_%d" % iteration, + inputs={net_config["raw"]: raw}, + outputs={net_config["pred_affs"]: pred_affs}, + array_specs={pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))}, + ) + + # store all passing batches in the same HDF5 file gp.Hdf5Write( { - raw: '/volumes/raw', - pred_affs: '/volumes/pred_affs', + raw: "/volumes/raw", + pred_affs: "/volumes/pred_affs", }, - output_filename='predictions_sample_A.hdf', - compression_type='gzip' - ) + - + output_filename="predictions_sample_A.hdf", + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations - gp.PrintProfilingStats(every=10) + - + gp.PrintProfilingStats(every=10) + + # iterate over the whole dataset in a scanning fashion, emitting # requests that match the size of the network gp.Scan(reference=request) @@ -93,5 +85,6 @@ def predict(iteration): # without keeping the complete dataset in memory pipeline.request_batch(gp.BatchRequest()) + if __name__ == "__main__": predict(200000) diff --git a/examples/cremi/train.py b/examples/cremi/train.py index 8edd12f7..6faf7e50 100644 --- a/examples/cremi/train.py +++ b/examples/cremi/train.py @@ -6,41 +6,41 @@ logging.basicConfig(level=logging.INFO) -def train(iterations): +def train(iterations): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # objects labelled with unique IDs - gt_labels = gp.ArrayKey('LABELS') + gt_labels = gp.ArrayKey("LABELS") # array of per-voxel affinities to direct neighbors - gt_affs= gp.ArrayKey('AFFINITIES') + gt_affs = gp.ArrayKey("AFFINITIES") # weights to use to balance the loss - loss_weights = gp.ArrayKey('LOSS_WEIGHTS') + loss_weights = gp.ArrayKey("LOSS_WEIGHTS") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") # the gredient of the loss wrt to the predicted affinities - pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS') + pred_affs_gradients = gp.ArrayKey("PRED_AFFS_GRADIENTS") #################### # DECLARE REQUESTS # #################### - with open('train_net_config.json', 'r') as f: + with open("train_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size # formulate the request for what a batch should (at least) contain request = gp.BatchRequest() @@ -60,44 +60,38 @@ def train(iterations): ############################## pipeline = ( - # a tuple of sources, one for each sample (A, B, and C) provided by the # CREMI challenge tuple( - # read batches from the HDF5 file gp.Hdf5Source( - 'sample_'+s+'_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw', - gt_labels: 'volumes/labels/neuron_ids' - } - ) + - + "sample_" + s + "_padded_20160501.hdf", + datasets={raw: "volumes/raw", gt_labels: "volumes/labels/neuron_ids"}, + ) + + # convert raw to float in [0, 1] gp.Normalize(raw) + - # chose a random location for each requested batch gp.RandomLocation() - - for s in ['A', 'B', 'C'] - ) + - + for s in ["A", "B", "C"] + ) + + # chose a random source (i.e., sample) from the above - gp.RandomProvider() + - + gp.RandomProvider() + + # elastically deform the batch gp.ElasticAugment( - [4,40,40], - [0,2,2], - [0,math.pi/2.0], + [4, 40, 40], + [0, 2, 2], + [0, math.pi / 2.0], prob_slip=0.05, prob_shift=0.05, - max_misalign=25) + - + max_misalign=25, + ) + + # apply transpose and mirror augmentations - gp.SimpleAugment(transpose_only=[1, 2]) + - + gp.SimpleAugment(transpose_only=[1, 2]) + + # scale and shift the intensity of the raw array gp.IntensityAugment( raw, @@ -105,65 +99,54 @@ def train(iterations): scale_max=1.1, shift_min=-0.1, shift_max=0.1, - z_section_wise=True) + - + z_section_wise=True, + ) + + # grow a boundary between labels - gp.GrowBoundary( - gt_labels, - steps=3, - only_xy=True) + - + gp.GrowBoundary(gt_labels, steps=3, only_xy=True) + + # convert labels into affinities between voxels - gp.AddAffinities( - [[-1, 0, 0], [0, -1, 0], [0, 0, -1]], - gt_labels, - gt_affs) + - + gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels, gt_affs) + + # create a weight array that balances positive and negative samples in # the affinity array - gp.BalanceLabels( - gt_affs, - loss_weights) + - + gp.BalanceLabels(gt_affs, loss_weights) + + # pre-cache batches from the point upstream - gp.PreCache( - cache_size=10, - num_workers=5) + - + gp.PreCache(cache_size=10, num_workers=5) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Train( - 'train_net', - net_config['optimizer'], - net_config['loss'], + "train_net", + net_config["optimizer"], + net_config["loss"], inputs={ - net_config['raw']: raw, - net_config['gt_affs']: gt_affs, - net_config['loss_weights']: loss_weights + net_config["raw"]: raw, + net_config["gt_affs"]: gt_affs, + net_config["loss_weights"]: loss_weights, }, - outputs={ - net_config['pred_affs']: pred_affs - }, - gradients={ - net_config['pred_affs']: pred_affs_gradients - }, - save_every=1) + - + outputs={net_config["pred_affs"]: pred_affs}, + gradients={net_config["pred_affs"]: pred_affs_gradients}, + save_every=1, + ) + + # save the passing batch as an HDF5 file for inspection gp.Snapshot( { - raw: '/volumes/raw', - gt_labels: '/volumes/labels/neuron_ids', - gt_affs: '/volumes/labels/affs', - pred_affs: '/volumes/pred_affs', - pred_affs_gradients: '/volumes/pred_affs_gradients' + raw: "/volumes/raw", + gt_labels: "/volumes/labels/neuron_ids", + gt_affs: "/volumes/labels/affs", + pred_affs: "/volumes/pred_affs", + pred_affs_gradients: "/volumes/pred_affs_gradients", }, - output_dir='snapshots', - output_filename='batch_{iteration}.hdf', + output_dir="snapshots", + output_filename="batch_{iteration}.hdf", every=100, additional_request=snapshot_request, - compression_type='gzip') + - + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=10) ) @@ -180,6 +163,6 @@ def train(iterations): print("Finished") + if __name__ == "__main__": train(200000) - \ No newline at end of file diff --git a/gunpowder/array.py b/gunpowder/array.py index 0177cf5a..a8da1322 100644 --- a/gunpowder/array.py +++ b/gunpowder/array.py @@ -1,7 +1,5 @@ from .freezable import Freezable from copy import deepcopy -from gunpowder.coordinate import Coordinate -from gunpowder.roi import Roi import logging import numpy as np import copy diff --git a/gunpowder/array_spec.py b/gunpowder/array_spec.py index ec271488..9002ae4f 100644 --- a/gunpowder/array_spec.py +++ b/gunpowder/array_spec.py @@ -14,13 +14,12 @@ class ArraySpec(Freezable): roi (:class:`Roi`): The region of interested represented by this array spec. Can be - ``None`` for :class:`BatchProviders` that allow - requests for arrays everywhere, but will always be set for array - specs that are part of a :class:`Array`. + ``None`` for nonspatial arrays or to indicate the true value is unknown. voxel_size (:class:`Coordinate`): - The size of the spatial axises in world units. + The size of the spatial axises in world units. Can be ``None`` for + nonspatial arrays or to indicate the true value is unknown. interpolatable (``bool``): @@ -55,7 +54,7 @@ def __init__( if nonspatial: assert roi is None, "Non-spatial arrays can not have a ROI" - assert voxel_size is None, "Non-spatial arrays can not " "have a voxel size" + assert voxel_size is None, "Non-spatial arrays can not have a voxel size" self.freeze() diff --git a/gunpowder/batch.py b/gunpowder/batch.py index 412c891f..ffc97e77 100644 --- a/gunpowder/batch.py +++ b/gunpowder/batch.py @@ -1,7 +1,6 @@ from copy import copy as shallow_copy import logging import multiprocessing -import warnings from .freezable import Freezable from .profiling import ProfilingStats @@ -75,7 +74,7 @@ def __setitem__(self, key, value): elif isinstance(value, Graph): assert isinstance( key, GraphKey - ), f"Only a GraphKey is allowed as key for Graph value." + ), "Only a GraphKey is allowed as key for Graph value." self.graphs[key] = value else: diff --git a/gunpowder/contrib/nodes/add_blobs_from_points.py b/gunpowder/contrib/nodes/add_blobs_from_points.py index 03b063ba..a78eb814 100644 --- a/gunpowder/contrib/nodes/add_blobs_from_points.py +++ b/gunpowder/contrib/nodes/add_blobs_from_points.py @@ -143,7 +143,7 @@ def process(self, batch, request): synapse_ids = [] for point_id, point in points.data.items(): # pdb.set_trace() - if not point.partner_ids[0] in partner_points.data.keys(): + if point.partner_ids[0] not in partner_points.data.keys(): logger.warning( "Point %s has no partner. Deleting..." % point_id ) diff --git a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py index 2ef93870..b2897272 100644 --- a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py +++ b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py @@ -4,7 +4,6 @@ from gunpowder.array import Array from gunpowder.batch_request import BatchRequest from gunpowder.nodes.batch_filter import BatchFilter -from numpy.lib.stride_tricks import as_strided from scipy.ndimage.morphology import distance_transform_edt logger = logging.getLogger(__name__) @@ -83,7 +82,7 @@ def prepare(self, request): return deps def process(self, batch, request): - if not self.gradient_array_key in request: + if self.gradient_array_key not in request: return labels = batch.arrays[self.label_array_key].data diff --git a/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py b/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py index cff056f7..f50e6a70 100644 --- a/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py +++ b/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py @@ -1,10 +1,9 @@ import copy import logging import numpy as np -from scipy import ndimage from gunpowder.nodes.batch_filter import BatchFilter -from gunpowder.array import Array, ArrayKeys +from gunpowder.array import Array from gunpowder.nodes.rasterize_graph import RasterizationSettings from gunpowder.morphology import enlarge_binary_map diff --git a/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py b/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py index 6d47201b..ef16398b 100644 --- a/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py +++ b/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py @@ -1,7 +1,5 @@ -import copy import logging import numpy as np -import pdb from gunpowder.array import Array from gunpowder.nodes.batch_filter import BatchFilter diff --git a/gunpowder/contrib/nodes/hdf5_points_source.py b/gunpowder/contrib/nodes/hdf5_points_source.py index a3cd2b44..f78630a1 100644 --- a/gunpowder/contrib/nodes/hdf5_points_source.py +++ b/gunpowder/contrib/nodes/hdf5_points_source.py @@ -5,10 +5,7 @@ from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate from gunpowder.ext import h5py -from gunpowder.graph import GraphKey, Graph -from gunpowder.graph_spec import GraphSpec from gunpowder.profiling import Timing -from gunpowder.roi import Roi from gunpowder.nodes.batch_provider import BatchProvider logger = logging.getLogger(__name__) diff --git a/gunpowder/ext/__init__.py b/gunpowder/ext/__init__.py index 7aec50c9..5b51124d 100644 --- a/gunpowder/ext/__init__.py +++ b/gunpowder/ext/__init__.py @@ -3,6 +3,7 @@ import traceback import sys +from typing import Optional, Any logger = logging.getLogger(__name__) @@ -20,72 +21,73 @@ def __getattr__(self, item): try: import dvision -except ImportError as e: +except ImportError: dvision = NoSuchModule("dvision") try: import h5py -except ImportError as e: +except ImportError: h5py = NoSuchModule("h5py") try: import pyklb -except ImportError as e: +except ImportError: pyklb = NoSuchModule("pyklb") try: import tensorflow -except ImportError as e: +except ImportError: tensorflow = NoSuchModule("tensorflow") try: import torch -except ImportError as e: +except ImportError: torch = NoSuchModule("torch") try: import tensorboardX -except ImportError as e: +except ImportError: tensorboardX = NoSuchModule("tensorboardX") try: import malis -except ImportError as e: +except ImportError: malis = NoSuchModule("malis") try: import augment -except ImportError as e: +except ImportError: augment = NoSuchModule("augment") +ZarrFile: Optional[Any] = None try: import zarr from .zarr_file import ZarrFile -except ImportError as e: +except ImportError: zarr = NoSuchModule("zarr") ZarrFile = None try: import daisy -except ImportError as e: +except ImportError: daisy = NoSuchModule("daisy") try: import jax -except ImportError as e: +except ImportError: jax = NoSuchModule("jax") try: import jax.numpy as jnp -except ImportError as e: +except ImportError: jnp = NoSuchModule("jnp") try: import haiku -except ImportError as e: +except ImportError: haiku = NoSuchModule("haiku") try: import optax -except ImportError as e: +except ImportError: optax = NoSuchModule("optax") diff --git a/gunpowder/graph.py b/gunpowder/graph.py index 91fdb883..3321c5ac 100644 --- a/gunpowder/graph.py +++ b/gunpowder/graph.py @@ -9,7 +9,6 @@ from typing import Dict, Optional, Set, Iterator, Any import logging import itertools -import warnings logger = logging.getLogger(__name__) @@ -485,7 +484,7 @@ def _roi_intercept( offset = outside - inside distance = np.linalg.norm(offset) - assert not np.isclose(distance, 0), f"Inside and Outside are the same location" + assert not np.isclose(distance, 0), "Inside and Outside are the same location" direction = offset / distance # `offset` can be 0 on some but not all axes leaving a 0 in the denominator. diff --git a/gunpowder/jax/nodes/predict.py b/gunpowder/jax/nodes/predict.py index 4c46f233..496d0fd0 100644 --- a/gunpowder/jax/nodes/predict.py +++ b/gunpowder/jax/nodes/predict.py @@ -6,7 +6,7 @@ import pickle import logging -from typing import Dict, Union +from typing import Dict, Union, Optional logger = logging.getLogger(__name__) @@ -52,8 +52,8 @@ def __init__( model: GenericJaxModel, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, spawn_subprocess=False, ): self.array_specs = array_specs if array_specs is not None else {} diff --git a/gunpowder/jax/nodes/train.py b/gunpowder/jax/nodes/train.py index 4d1f17a3..9621b129 100644 --- a/gunpowder/jax/nodes/train.py +++ b/gunpowder/jax/nodes/train.py @@ -11,7 +11,7 @@ from gunpowder.nodes.generic_train import GenericTrain from gunpowder.jax import GenericJaxModel -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -108,7 +108,7 @@ def __init__( checkpoint_basename: str = "model", save_every: int = 2000, keep_n_checkpoints: Optional[int] = None, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, n_devices: Optional[int] = None, @@ -141,7 +141,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.validate_fn = validate_fn self.validate_every = validate_every diff --git a/gunpowder/nodes/batch_provider.py b/gunpowder/nodes/batch_provider.py index dc641c8e..304e1e3a 100644 --- a/gunpowder/nodes/batch_provider.py +++ b/gunpowder/nodes/batch_provider.py @@ -3,6 +3,8 @@ import copy import logging import random +import traceback +from typing import Optional from gunpowder.coordinate import Coordinate from gunpowder.provider_spec import ProviderSpec @@ -15,17 +17,22 @@ class BatchRequestError(Exception): - def __init__(self, provider, request, batch): + def __init__( + self, provider, request, batch, original_traceback: Optional[list[str]] = None + ): self.provider = provider self.request = request self.batch = batch + self.original_traceback = original_traceback def __str__(self): return ( f"Exception in {self.provider.name()} while processing request" - f"{self.request} \n" + f"{self.request}" "Batch returned so far:\n" - f"{self.batch}" + f"{self.batch}" + ("\n\n" + "".join(self.original_traceback)) + if self.original_traceback is not None + else "" ) @@ -174,7 +181,6 @@ def request_batch(self, request): batch = None try: - self.set_seeds(request) logger.debug("%s got request %s", self.name(), request) @@ -195,7 +201,12 @@ def request_batch(self, request): logger.debug("%s provides %s", self.name(), batch) except Exception as e: - raise BatchRequestError(self, request, batch) from e + tb = traceback.format_exception(type(e), e, e.__traceback__) + if isinstance(e, BatchRequestError): + tb = tb[-1:] + raise BatchRequestError( + self, request, batch, original_traceback=tb + ) from None return batch diff --git a/gunpowder/nodes/crop.py b/gunpowder/nodes/crop.py index 3e4cdeb5..0584335e 100644 --- a/gunpowder/nodes/crop.py +++ b/gunpowder/nodes/crop.py @@ -1,4 +1,3 @@ -import copy import logging from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index cdf5eeff..6d7e23af 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -21,6 +21,7 @@ import logging import math import random +from typing import Optional logger = logging.getLogger(__name__) @@ -93,8 +94,8 @@ def __init__( spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, - transform_key: ArrayKey = None, - graph_raster_voxel_size: Coordinate = None, + transform_key: Optional[ArrayKey] = None, + graph_raster_voxel_size: Optional[Coordinate] = None, ): self.control_point_spacing = Coordinate(control_point_spacing) self.jitter_sigma = Coordinate(jitter_sigma) @@ -129,7 +130,6 @@ def setup(self): self.provides(self.transform_key, spec) def prepare(self, request): - # get the total ROI of all requests total_roi = request.get_total_roi() logger.debug("total ROI is %s" % total_roi) diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index a70f7866..a4413a44 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -9,7 +9,6 @@ from gunpowder.coordinate import Coordinate from gunpowder.ext import augment from gunpowder.roi import Roi -from gunpowder.array import ArrayKey import warnings @@ -124,7 +123,6 @@ def __init__( self.recompute_missing_points = recompute_missing_points def prepare(self, request): - # get the voxel size self.voxel_size = self.__get_common_voxel_size(request) diff --git a/gunpowder/nodes/exclude_labels.py b/gunpowder/nodes/exclude_labels.py index ae38d43a..2592dc25 100644 --- a/gunpowder/nodes/exclude_labels.py +++ b/gunpowder/nodes/exclude_labels.py @@ -71,7 +71,7 @@ def process(self, batch, request): include_mask[gt.data == label] = 0 # if no ignore mask is provided or requested, we are done - if not self.ignore_mask or not self.ignore_mask in request: + if not self.ignore_mask or self.ignore_mask not in request: return voxel_size = self.spec[self.labels].voxel_size diff --git a/gunpowder/nodes/generic_predict.py b/gunpowder/nodes/generic_predict.py index 524967b8..e3f4ec5b 100644 --- a/gunpowder/nodes/generic_predict.py +++ b/gunpowder/nodes/generic_predict.py @@ -89,7 +89,7 @@ def setup(self): if self.spawn_subprocess: # start prediction as a producer pool, so that we can gracefully # exit if anything goes wrong - self.worker = ProducerPool([self.__produce_predict_batch], queue_size=1) + self.worker = ProducerPool([self._produce_predict_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.batch_in_lock = multiprocessing.Lock() self.batch_out_lock = multiprocessing.Lock() @@ -177,7 +177,7 @@ def stop(self): """ pass - def __produce_predict_batch(self): + def _produce_predict_batch(self): """Process one batch.""" if not self.initialized: diff --git a/gunpowder/nodes/generic_train.py b/gunpowder/nodes/generic_train.py index ae93b7de..a26a285f 100644 --- a/gunpowder/nodes/generic_train.py +++ b/gunpowder/nodes/generic_train.py @@ -104,7 +104,7 @@ def setup(self): if self.spawn_subprocess: # start training as a producer pool, so that we can gracefully exit if # anything goes wrong - self.worker = ProducerPool([self.__produce_train_batch], queue_size=1) + self.worker = ProducerPool([self._produce_train_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.worker.start() else: @@ -208,7 +208,7 @@ def natural_keys(text): return None, 0 - def __produce_train_batch(self): + def _produce_train_batch(self): """Process one train batch.""" if not self.initialized: diff --git a/gunpowder/nodes/grow_boundary.py b/gunpowder/nodes/grow_boundary.py index 08d20abf..d793345f 100644 --- a/gunpowder/nodes/grow_boundary.py +++ b/gunpowder/nodes/grow_boundary.py @@ -2,7 +2,6 @@ from scipy import ndimage from .batch_filter import BatchFilter -from gunpowder.array import Array class GrowBoundary(BatchFilter): diff --git a/gunpowder/nodes/klb_source.py b/gunpowder/nodes/klb_source.py index e2a3f758..d4776bba 100644 --- a/gunpowder/nodes/klb_source.py +++ b/gunpowder/nodes/klb_source.py @@ -1,4 +1,3 @@ -import copy import logging import numpy as np import glob diff --git a/gunpowder/nodes/merge_provider.py b/gunpowder/nodes/merge_provider.py index 0d32300e..6df979b8 100644 --- a/gunpowder/nodes/merge_provider.py +++ b/gunpowder/nodes/merge_provider.py @@ -1,4 +1,3 @@ -from gunpowder.provider_spec import ProviderSpec from gunpowder.batch import Batch from gunpowder.batch_request import BatchRequest diff --git a/gunpowder/nodes/noise_augment.py b/gunpowder/nodes/noise_augment.py index f4bfb5ba..5275a2c0 100644 --- a/gunpowder/nodes/noise_augment.py +++ b/gunpowder/nodes/noise_augment.py @@ -57,13 +57,11 @@ def process(self, batch, request): seed = request.random_seed try: - raw.data = skimage.util.random_noise( raw.data, mode=self.mode, rng=seed, clip=self.clip, **self.kwargs ).astype(raw.data.dtype) except ValueError: - # legacy version of skimage random_noise raw.data = skimage.util.random_noise( raw.data, mode=self.mode, seed=seed, clip=self.clip, **self.kwargs diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index 6bbfdc58..758fd04a 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -7,6 +7,7 @@ from gunpowder.coordinate import Coordinate from gunpowder.batch_request import BatchRequest + logger = logging.getLogger(__name__) @@ -27,15 +28,22 @@ class Pad(BatchFilter): a coordinate, this amount will be added to the ROI in the positive and negative direction. + mode (string): + + One of 'constant' or 'reflect'. + Default is 'constant' + value (scalar or ``None``): The value to report inside the padding. If not given, 0 is used. + Only used in case of 'constant' mode. Only used for :class:`Array`. """ - def __init__(self, key, size, value=None): + def __init__(self, key, size, mode="constant", value=None): self.key = key self.size = size + self.mode = mode self.value = value def setup(self): @@ -118,19 +126,11 @@ def __expand(self, a, from_roi, to_roi, value): ) num_channels = len(a.shape) - from_roi.dims - channel_shapes = a.shape[:num_channels] - - b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype) - if value != 0: - b[:] = value - - shift = -to_roi.offset - logger.debug("shifting 'from' by " + str(shift)) - a_in_b = from_roi.shift(shift).to_slices() - - logger.debug("target shape is " + str(b.shape)) - logger.debug("target slice is " + str(a_in_b)) - - b[(slice(None),) * num_channels + a_in_b] = a - - return b + lower_pad = from_roi.begin - to_roi.begin + upper_pad = to_roi.end - from_roi.end + pad_width = [(0, 0)] * num_channels + list(zip(lower_pad, upper_pad)) + if self.mode == "constant": + padded = np.pad(a, pad_width, "constant", constant_values=value) + elif self.mode == "reflect": + padded = np.pad(a, pad_width, "reflect") + return padded diff --git a/gunpowder/nodes/precache.py b/gunpowder/nodes/precache.py index ac35d32f..9c58ae53 100644 --- a/gunpowder/nodes/precache.py +++ b/gunpowder/nodes/precache.py @@ -1,7 +1,4 @@ import logging -import multiprocessing -import time -import random from .batch_filter import BatchFilter from gunpowder.profiling import Timing diff --git a/gunpowder/nodes/random_location.py b/gunpowder/nodes/random_location.py index fccbd6cb..d5b6c1e2 100644 --- a/gunpowder/nodes/random_location.py +++ b/gunpowder/nodes/random_location.py @@ -172,7 +172,6 @@ def setup(self): self.provides(self.random_shift_key, ArraySpec(nonspatial=True)) def prepare(self, request): - logger.debug("request: %s", request.array_specs) logger.debug("my spec: %s", self.spec) @@ -383,9 +382,7 @@ def __select_random_location_with_points( logger.debug("belongs to lcm voxel %s", lcm_location) # align the point request ROI with lcm voxel grid - lcm_roi = request_points_roi.snap_to_grid( - lcm_voxel_size, - mode="shrink") + lcm_roi = request_points_roi.snap_to_grid(lcm_voxel_size, mode="shrink") lcm_roi = lcm_roi / lcm_voxel_size logger.debug("Point request ROI: %s", request_points_roi) logger.debug("Point request lcm ROI shape: %s", lcm_roi.shape) diff --git a/gunpowder/nodes/random_provider.py b/gunpowder/nodes/random_provider.py index dfb086f8..a9ae1081 100644 --- a/gunpowder/nodes/random_provider.py +++ b/gunpowder/nodes/random_provider.py @@ -69,7 +69,6 @@ def setup(self): self.provides(self.random_provider_key, ArraySpec(nonspatial=True)) def provide(self, request): - if self.random_provider_key is not None: del request[self.random_provider_key] diff --git a/gunpowder/nodes/rasterize_graph.py b/gunpowder/nodes/rasterize_graph.py index 1a12335f..bb2473f6 100644 --- a/gunpowder/nodes/rasterize_graph.py +++ b/gunpowder/nodes/rasterize_graph.py @@ -1,4 +1,3 @@ -import copy import logging import numpy as np from scipy.ndimage.filters import gaussian_filter @@ -12,7 +11,6 @@ from gunpowder.freezable import Freezable from gunpowder.morphology import enlarge_binary_map, create_ball_kernel from gunpowder.ndarray import replace -from gunpowder.graph import GraphKey from gunpowder.graph_spec import GraphSpec from gunpowder.roi import Roi @@ -221,7 +219,8 @@ def process(self, batch, request): mask_array = batch.arrays[mask].crop(enlarged_vol_roi) # get those component labels in the mask, that contain graph labels = [] - for i, point in graph.data.items(): + # for i, point in graph.data.items(): + for i, point in enumerate(graph.nodes): v = Coordinate(point.location / voxel_size) v -= data_roi.begin labels.append(mask_array.data[v]) @@ -250,11 +249,15 @@ def process(self, batch, request): voxel_size, self.spec[self.array].dtype, self.settings, - Array(data=mask_array.data == label, spec=mask_array.spec), + Array( + data=(mask_array.data == label), + spec=mask_array.spec, + ), ) for label in labels ], axis=0, + dtype=self.spec[self.array].dtype, ) else: diff --git a/gunpowder/nodes/reject.py b/gunpowder/nodes/reject.py index b6a47436..87bb83aa 100644 --- a/gunpowder/nodes/reject.py +++ b/gunpowder/nodes/reject.py @@ -55,7 +55,6 @@ def setup(self): self.upstream_provider = self.get_upstream_provider() def provide(self, request): - report_next_timeout = 10 num_rejected = 0 diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index 8fe6524b..d42b1434 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -4,7 +4,6 @@ import random from gunpowder.roi import Roi from gunpowder.coordinate import Coordinate -from gunpowder.batch_request import BatchRequest from .batch_filter import BatchFilter @@ -24,7 +23,6 @@ def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0): self.lcm_voxel_size = None def prepare(self, request): - self.ndim = request.get_total_roi().dims assert self.shift_axis in range(self.ndim) diff --git a/gunpowder/nodes/simple_augment.py b/gunpowder/nodes/simple_augment.py index dc756e3a..f5a97333 100644 --- a/gunpowder/nodes/simple_augment.py +++ b/gunpowder/nodes/simple_augment.py @@ -106,7 +106,6 @@ def setup(self): self.permutation_dict[k] = v def prepare(self, request): - self.mirror = [ random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0 for d in range(self.dims) diff --git a/gunpowder/nodes/specified_location.py b/gunpowder/nodes/specified_location.py index b209e078..cc5e5844 100644 --- a/gunpowder/nodes/specified_location.py +++ b/gunpowder/nodes/specified_location.py @@ -1,10 +1,8 @@ from random import randrange -from random import choice, seed import logging import numpy as np from gunpowder.coordinate import Coordinate -from gunpowder.batch_request import BatchRequest from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/squeeze.py b/gunpowder/nodes/squeeze.py index 2e999714..d0a1469b 100644 --- a/gunpowder/nodes/squeeze.py +++ b/gunpowder/nodes/squeeze.py @@ -1,4 +1,3 @@ -import copy from typing import List import logging diff --git a/gunpowder/nodes/stack.py b/gunpowder/nodes/stack.py index 21f53acc..5d7feabd 100644 --- a/gunpowder/nodes/stack.py +++ b/gunpowder/nodes/stack.py @@ -25,7 +25,6 @@ def __init__(self, num_repetitions): self.num_repetitions = num_repetitions def provide(self, request): - batches = [] for _ in range(self.num_repetitions): upstream_request = request.copy() diff --git a/gunpowder/nodes/unsqueeze.py b/gunpowder/nodes/unsqueeze.py index 3df03ec4..9f019978 100644 --- a/gunpowder/nodes/unsqueeze.py +++ b/gunpowder/nodes/unsqueeze.py @@ -1,4 +1,3 @@ -import copy from typing import List import logging diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 812769f3..b7133580 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -107,9 +107,8 @@ def _get_offset(self, dataset): def _rev_metadata(self): with ZarrFile(self.store, mode="a") as store: - return ( - isinstance(store.chunk_store, N5Store) or - isinstance(store.chunk_store, N5FSStore) + return isinstance(store.chunk_store, N5Store) or isinstance( + store.chunk_store, N5FSStore ) def _open_file(self, store): diff --git a/gunpowder/nodes/zarr_write.py b/gunpowder/nodes/zarr_write.py index 35965b6d..3beba3ae 100644 --- a/gunpowder/nodes/zarr_write.py +++ b/gunpowder/nodes/zarr_write.py @@ -5,10 +5,10 @@ from zarr import N5FSStore, N5Store from .batch_filter import BatchFilter +from gunpowder.array import ArrayKey from gunpowder.batch_request import BatchRequest from gunpowder.coordinate import Coordinate from gunpowder.roi import Roi -from gunpowder.coordinate import Coordinate from gunpowder.ext import ZarrFile import logging @@ -71,7 +71,7 @@ def __init__( else: self.dataset_dtypes = dataset_dtypes - self.dataset_offsets = {} + self.dataset_offsets: dict[ArrayKey, Coordinate] = {} def _get_voxel_size(self, dataset): if "resolution" not in dataset.attrs: diff --git a/gunpowder/pipeline.py b/gunpowder/pipeline.py index f9976c26..cad87f1a 100644 --- a/gunpowder/pipeline.py +++ b/gunpowder/pipeline.py @@ -1,5 +1,8 @@ -import logging from gunpowder.nodes import BatchProvider +from gunpowder.nodes.batch_provider import BatchRequestError + +import logging +import traceback logger = logging.getLogger(__name__) @@ -21,13 +24,19 @@ def __str__(self): class PipelineRequestError(Exception): - def __init__(self, pipeline, request): + def __init__(self, pipeline, request, original_traceback=None): self.pipeline = pipeline self.request = request + self.original_traceback = original_traceback def __str__(self): return ( - "Exception in pipeline:\n" + ( + ("".join(self.original_traceback)) + if self.original_traceback is not None + else "" + ) + + "Exception in pipeline:\n" f"{self.pipeline}\n" "while trying to process request\n" f"{self.request}" @@ -123,6 +132,11 @@ def request_batch(self, request): try: return self.output.request_batch(request) + except BatchRequestError as e: + tb = traceback.format_exception(type(e), e, e.__traceback__) + if isinstance(e, BatchRequestError): + tb = tb[-1:] + raise PipelineRequestError(self, request, original_traceback=tb) from None except Exception as e: raise PipelineRequestError(self, request) from e diff --git a/gunpowder/producer_pool.py b/gunpowder/producer_pool.py index 035f0d74..73df6b6d 100644 --- a/gunpowder/producer_pool.py +++ b/gunpowder/producer_pool.py @@ -6,8 +6,6 @@ import multiprocessing import os import sys -import time -import traceback import numpy as np @@ -143,9 +141,7 @@ def _run_worker(self, target): try: result = target() except Exception as e: - logger.error(e, exc_info=True) result = e - traceback.print_exc() # don't stop on normal exceptions -- place them in result queue # and let them be handled by caller except: diff --git a/gunpowder/provider_spec.py b/gunpowder/provider_spec.py index 7c34324d..6a1ab818 100644 --- a/gunpowder/provider_spec.py +++ b/gunpowder/provider_spec.py @@ -6,7 +6,6 @@ from gunpowder.graph_spec import GraphSpec from gunpowder.roi import Roi from .freezable import Freezable -import time import logging import copy @@ -14,7 +13,6 @@ import logging -import warnings logger = logging.getLogger(__file__) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 3e5ba8f1..89c9ac0c 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -4,7 +4,7 @@ from gunpowder.nodes.generic_predict import GenericPredict import logging -from typing import Dict, Union +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def __init__( model, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, device="cuda", spawn_subprocess=False, ): @@ -82,14 +82,16 @@ def __init__( self.model = model self.checkpoint = checkpoint - self.intermediate_layers = {} - self.register_hooks() + self.intermediate_layers: dict[ArrayKey, Any] = {} def start(self): - self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda" - logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue #188 + self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__( + "cuda" + ) + logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") + self.device = torch.device(self.device_string if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) except RuntimeError as e: @@ -106,6 +108,8 @@ def start(self): else: self.model.load_state_dict(checkpoint) + self.register_hooks() + def predict(self, batch, request): inputs = self.get_inputs(batch) with torch.no_grad(): diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index ed2df002..676b2c71 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -6,7 +6,8 @@ from gunpowder.ext import torch, tensorboardX, NoSuchModule from gunpowder.nodes.generic_train import GenericTrain -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any +import itertools logger = logging.getLogger(__name__) @@ -78,6 +79,12 @@ class Train(GenericTrain): spawn_subprocess (``bool``, optional): Whether to run the ``train_step`` in a separate process. Default is false. + + device (``str``, optional): + + Accepts a cuda gpu specifically to train on (e.g. `cuda:1`, `cuda:2`), helps in multi-card systems. + defaults to ``cuda`` + """ def __init__( @@ -92,9 +99,10 @@ def __init__( array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint_basename: str = "model", save_every: int = 2000, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, + device: str = "cuda", ): if not model.training: logger.warning( @@ -104,12 +112,18 @@ def __init__( # not yet implemented gradients = gradients - inputs.update( - {k: v for k, v in loss_inputs.items() if v not in outputs.values()} - ) + all_inputs = { + k: v + for k, v in itertools.chain(inputs.items(), loss_inputs.items()) + if v not in outputs.values() + } super(Train, self).__init__( - inputs, outputs, gradients, array_specs, spawn_subprocess=spawn_subprocess + all_inputs, + outputs, + gradients, + array_specs, + spawn_subprocess=spawn_subprocess, ) self.model = model @@ -118,6 +132,7 @@ def __init__( self.loss_inputs = loss_inputs self.checkpoint_basename = checkpoint_basename self.save_every = save_every + self.dev = device self.iteration = 0 @@ -129,7 +144,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.register_hooks() def register_hooks(self): @@ -160,7 +175,8 @@ def retain_gradients(self, request, outputs): def start(self): self.use_cuda = torch.cuda.is_available() - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue: #188 + self.device = torch.device(self.dev if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) @@ -278,13 +294,6 @@ def train_step(self, batch, request): spec.roi = request[array_key].roi batch.arrays[array_key] = Array(tensor.grad.cpu().detach().numpy(), spec) - for array_key, array_name in requested_outputs.items(): - spec = self.spec[array_key].copy() - spec.roi = request[array_key].roi - batch.arrays[array_key] = Array( - outputs[array_name].cpu().detach().numpy(), spec - ) - batch.loss = loss.cpu().detach().numpy() self.iteration += 1 batch.iteration = self.iteration diff --git a/gunpowder/version_info.py b/gunpowder/version_info.py index 01724d07..e45efbfd 100644 --- a/gunpowder/version_info.py +++ b/gunpowder/version_info.py @@ -1,6 +1,6 @@ __major__ = 1 __minor__ = 3 -__patch__ = 1 +__patch__ = 2 __tag__ = "" __version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6daa39e0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,37 @@ +[mypy] + +# ext +[mypy-dvision.*] +ignore_missing_imports = True +[mypy-pyklb.*] +ignore_missing_imports = True +[mypy-malis.*] +ignore_missing_imports = True +[mypy-haiku.*] +ignore_missing_imports = True +[mypy-optax.*] +ignore_missing_imports = True + +# dependencies +[mypy-tensorflow.*] +ignore_missing_imports = True +[mypy-tensorboardX.*] +ignore_missing_imports = True +[mypy-torch.*] +ignore_missing_imports = True +[mypy-jax.*] +ignore_missing_imports = True +[mypy-daisy.*] +ignore_missing_imports = True +[mypy-scipy.*] +ignore_missing_imports = True +[mypy-h5py.*] +ignore_missing_imports = True +[mypy-augment.*] +ignore_missing_imports = True +[mypy-zarr.*] +ignore_missing_imports = True +[mypy-networkx.*] +ignore_missing_imports = True +[mypy-Queue.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 11ab82bc..01441435 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,77 +6,71 @@ requires = ["setuptools", "wheel"] name = "gunpowder" description = "A library to facilitate machine learning on large, multi-dimensional images." authors = [ - {name = "Jan Funke", email = "funkej@hhmi.org"}, - {name = "William Patton", email = "pattonw@hhmi.org"}, - {name = "Renate Krause"}, - {name = "Julia Buhmann"}, - {name = "Rodrigo Ceballos Lentini"}, - {name = "William Grisaitis"}, - {name = "Chris Barnes"}, - {name = "Caroline Malin-Mayor"}, - {name = "Larissa Heinrich"}, - {name = "Philipp Hanslovsky"}, - {name = "Sherry Ding"}, - {name = "Andrew Champion"}, - {name = "Arlo Sheridan"}, - {name = "Constantin Pape"}, + { name = "Jan Funke", email = "funkej@hhmi.org" }, + { name = "William Patton", email = "pattonw@hhmi.org" }, + { name = "Renate Krause" }, + { name = "Julia Buhmann" }, + { name = "Rodrigo Ceballos Lentini" }, + { name = "William Grisaitis" }, + { name = "Chris Barnes" }, + { name = "Caroline Malin-Mayor" }, + { name = "Larissa Heinrich" }, + { name = "Philipp Hanslovsky" }, + { name = "Sherry Ding" }, + { name = "Andrew Champion" }, + { name = "Arlo Sheridan" }, + { name = "Constantin Pape" }, ] -license = {text = "MIT"} +license = { text = "MIT" } readme = "README.md" dynamic = ["version"] -classifiers = [ - "Programming Language :: Python :: 3", -] +classifiers = ["Programming Language :: Python :: 3"] keywords = [] requires-python = ">=3.7" dependencies = [ - "numpy", - "scipy", - "h5py", - "scikit-image", - "requests", - "augment-nd>=0.1.3", - "tqdm", - "funlib.geometry", - "zarr", - "networkx", + "numpy>=1.24", + "scipy>=1.6", + "h5py>=3.10", + "scikit-image", + "requests", + "augment-nd>=0.1.3", + "tqdm", + "funlib.geometry>=0.3", + "zarr", + "networkx>=3.1", ] [project.optional-dependencies] -dev = [ - "pytest", - "pytest-cov", - "flake8", -] +dev = ["pytest", "pytest-cov", "flake8", "mypy", "types-requests", "types-tqdm"] docs = [ - "sphinx", - "sphinx_rtd_theme", - "sphinx_togglebutton", - "tomli", - "jupyter_sphinx", - "ipykernel", - "matplotlib", - "torch", + "sphinx", + "sphinx_rtd_theme", + "sphinx_togglebutton", + "tomli", + "jupyter_sphinx", + "ipykernel", + "matplotlib", + "torch", ] pytorch = ['torch'] tensorflow = [ - # TF doesn't provide <2.0 wheels for py>=3.8 on pypi - 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 - 'protobuf==3.20.*; python_version=="3.7"', + # TF doesn't provide <2.0 wheels for py>=3.8 on pypi + 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 + 'protobuf==3.20.*; python_version=="3.7"', ] full = [ - 'torch', - 'tensorflow<2.0; python_version<"3.8"', - 'protobuf==3.20.*; python_version=="3.7"', + 'torch', + 'tensorflow<2.0; python_version<"3.8"', + 'protobuf==3.20.*; python_version=="3.7"', ] [tool.setuptools.dynamic] -version = {attr = "gunpowder.version_info.__version__"} +version = { attr = "gunpowder.version_info.__version__" } [tool.black] -target_version = ['py36', 'py37', 'py38', 'py39', 'py310'] +target_version = ['py38', 'py39', 'py310'] [tool.setuptools.packages.find] include = ["gunpowder*"] diff --git a/tests/cases/add_affinities.py b/tests/cases/add_affinities.py index 792be290..bd6ebc93 100644 --- a/tests/cases/add_affinities.py +++ b/tests/cases/add_affinities.py @@ -1,10 +1,6 @@ -from .provider_test import ProviderTest from gunpowder import * from itertools import product -from unittest import skipIf -import itertools import numpy as np -import logging class ExampleSource(BatchProvider): diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index 2134a708..f722b0bb 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -160,6 +160,9 @@ def test_3d_basics(rotate, spatial_dims, fast_points): loc = (loc - labels.spec.roi.begin) / labels.spec.voxel_size loc = np.array(loc) com = center_of_mass(labels.data == node.id) + if any(np.isnan(com)): + # cannot assume that the rasterized data will exist after defomation + continue assert ( np.linalg.norm(com - loc) < np.linalg.norm(labels.spec.voxel_size) * 2 diff --git a/tests/cases/dvid_source.py b/tests/cases/dvid_source.py index 8f2c31e6..ac206909 100644 --- a/tests/cases/dvid_source.py +++ b/tests/cases/dvid_source.py @@ -2,7 +2,6 @@ from unittest import skipIf from gunpowder import * from gunpowder.ext import dvision, NoSuchModule -import numpy as np import socket import logging diff --git a/tests/cases/elastic_augment_points.py b/tests/cases/elastic_augment_points.py index 76e9ae2b..ddb99741 100644 --- a/tests/cases/elastic_augment_points.py +++ b/tests/cases/elastic_augment_points.py @@ -1,4 +1,3 @@ -import unittest from gunpowder import ( BatchProvider, Batch, @@ -25,7 +24,6 @@ import numpy as np import math import time -import unittest class PointTestSource3D(BatchProvider): diff --git a/tests/cases/expected_failures.py b/tests/cases/expected_failures.py index 39a2d21e..8adf48af 100644 --- a/tests/cases/expected_failures.py +++ b/tests/cases/expected_failures.py @@ -2,7 +2,7 @@ from gunpowder.nodes.batch_provider import BatchRequestError from .helper_sources import ArraySource -from funlib.geometry import Roi, Coordinate +from funlib.geometry import Coordinate import numpy as np import pytest diff --git a/tests/cases/helper_sources.py b/tests/cases/helper_sources.py index 219044b1..630333d6 100644 --- a/tests/cases/helper_sources.py +++ b/tests/cases/helper_sources.py @@ -13,7 +13,10 @@ def setup(self): def provide(self, request): outputs = Batch() - outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) + if self.array.spec.nonspatial: + outputs[self.key] = copy.deepcopy(self.array) + else: + outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) return outputs diff --git a/tests/cases/intensity_scale_shift.py b/tests/cases/intensity_scale_shift.py index d65df3dd..c64b4ec3 100644 --- a/tests/cases/intensity_scale_shift.py +++ b/tests/cases/intensity_scale_shift.py @@ -3,7 +3,6 @@ IntensityScaleShift, ArrayKey, build, - Normalize, Array, ArraySpec, Roi, diff --git a/tests/cases/jax_train.py b/tests/cases/jax_train.py index 14fbad46..2ff55be6 100644 --- a/tests/cases/jax_train.py +++ b/tests/cases/jax_train.py @@ -4,18 +4,15 @@ BatchRequest, ArraySpec, Roi, - Coordinate, ArrayKeys, ArrayKey, Array, Batch, - Scan, - PreCache, build, ) from gunpowder.ext import jax, haiku, optax, NoSuchModule from gunpowder.jax import Train, Predict, GenericJaxModel -from unittest import skipIf, expectedFailure +from unittest import skipIf import numpy as np import logging diff --git a/tests/cases/noise_augment.py b/tests/cases/noise_augment.py index 6e9f635c..57768091 100644 --- a/tests/cases/noise_augment.py +++ b/tests/cases/noise_augment.py @@ -1,8 +1,6 @@ from .provider_test import ProviderTest from gunpowder import IntensityAugment, ArrayKeys, build, Normalize, NoiseAugment -import numpy as np - class TestIntensityAugment(ProviderTest): def test_shift(self): diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 8b7ab179..5efda685 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -1,71 +1,96 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource, GraphSource from gunpowder import ( - BatchProvider, BatchRequest, - Batch, - ArrayKeys, ArraySpec, Roi, Coordinate, + Graph, GraphKey, - GraphKeys, GraphSpec, Array, ArrayKey, Pad, build, + MergeProvider, ) -import numpy as np - -class ExampleSourcePad(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.TEST_LABELS, - ArraySpec(roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)), - ) +import pytest +import numpy as np - self.provides( - GraphKeys.TEST_GRAPH, GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) - ) +from itertools import product - def provide(self, request): - batch = Batch() - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size +@pytest.mark.parametrize("mode", ["constant", "reflect"]) +def test_padding(mode): + array_key = ArrayKey("TEST_ARRAY") + graph_key = GraphKey("TEST_GRAPH") - data = np.zeros(roi_voxel.shape, dtype=np.uint32) - data[:, ::2] = 100 + array_spec = ArraySpec(roi=Roi((200, 20, 20), (600, 60, 60)), voxel_size=(20, 2, 2)) + roi_voxel = array_spec.roi / array_spec.voxel_size + data = np.zeros(roi_voxel.shape, dtype=np.uint32) + data[:, ::2] = 100 + array = Array(data, spec=array_spec) - spec = self.spec[ArrayKeys.TEST_LABELS].copy() - spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + graph_spec = GraphSpec(roi=Roi((200, 20, 20), (600, 60, 60))) + graph = Graph([], [], graph_spec) - return batch + source = ( + ArraySource(array_key, array), + GraphSource(graph_key, graph), + ) + MergeProvider() + pipeline = ( + source + + Pad(array_key, Coordinate((200, 20, 20)), value=1, mode=mode) + + Pad(graph_key, Coordinate((100, 10, 10)), mode=mode) + ) -class TestPad(ProviderTest): - def test_output(self): - graph = GraphKey("TEST_GRAPH") - labels = ArrayKey("TEST_LABELS") + with build(pipeline): + assert pipeline.spec[array_key].roi == Roi((0, 0, 0), (1000, 100, 100)) + assert pipeline.spec[graph_key].roi == Roi((100, 10, 10), (800, 80, 80)) - pipeline = ( - ExampleSourcePad() - + Pad(labels, Coordinate((20, 20, 20)), value=1) - + Pad(graph, Coordinate((10, 10, 10))) + batch = pipeline.request_batch( + BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (40, 40, 40)))}) ) - with build(pipeline): - self.assertTrue( - pipeline.spec[labels].roi == Roi((180, 0, 0), (1840, 220, 220)) + data = batch.arrays[array_key].data + if mode == "constant": + octants = [ + (1 * 10 * 10) if zi + yi + xi < 3 else 100 * 1 * 5 * 10 + for zi, yi, xi in product(range(2), range(2), range(2)) + ] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + np.unique(data), ) - self.assertTrue( - pipeline.spec[graph].roi == Roi((190, 10, 10), (1820, 200, 200)) + elif mode == "reflect": + octants = [100 * 1 * 5 * 10 for _ in range(8)] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + data, ) - batch = pipeline.request_batch( - BatchRequest({labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))}) - ) + # 1 x 10 x (10,30,10) + batch = pipeline.request_batch( + BatchRequest({array_key: ArraySpec(Roi((200, 20, 0), (20, 20, 100)))}) + ) + data = batch.arrays[array_key].data - self.assertEqual(np.sum(batch.arrays[labels].data), 1 * 10 * 10) + if mode == "constant": + lower_pad = 1 * 10 * 10 + upper_pad = 1 * 10 * 10 + center = 100 * 1 * 5 * 30 + assert np.sum(data) == np.sum((lower_pad, upper_pad, center)), ( + np.sum(data), + np.sum((lower_pad, upper_pad, center)), + ) + elif mode == "reflect": + lower_pad = 100 * 1 * 5 * 10 + upper_pad = 100 * 1 * 5 * 10 + center = 100 * 1 * 5 * 30 + assert np.sum(data) == np.sum((lower_pad, upper_pad, center)), ( + np.sum(data), + np.sum((lower_pad, upper_pad, center)), + ) diff --git a/tests/cases/random_location.py b/tests/cases/random_location.py index df3f1cca..611289a8 100644 --- a/tests/cases/random_location.py +++ b/tests/cases/random_location.py @@ -3,7 +3,6 @@ BatchProvider, Roi, Coordinate, - ArrayKeys, ArrayKey, ArraySpec, Array, diff --git a/tests/cases/rasterize_points.py b/tests/cases/rasterize_points.py index 16c11ae1..a57906f8 100644 --- a/tests/cases/rasterize_points.py +++ b/tests/cases/rasterize_points.py @@ -1,31 +1,30 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource, GraphSource from gunpowder import ( - BatchProvider, BatchRequest, - Batch, Roi, Coordinate, GraphSpec, Array, ArrayKey, - ArrayKeys, ArraySpec, RasterizeGraph, + MergeProvider, RasterizationSettings, build, ) -from gunpowder.graph import GraphKeys, GraphKey, Graph, Node, Edge +from gunpowder.graph import GraphKey, Graph, Node, Edge import numpy as np -import math -from random import randint -class GraphTestSource3D(BatchProvider): - def __init__(self): - self.voxel_size = Coordinate((40, 4, 4)) +def test_3d(): + graph_key = GraphKey("TEST_GRAPH") + array_key = ArrayKey("TEST_ARRAY") + rasterized_key = ArrayKey("RASTERIZED_ARRAY") + voxel_size = Coordinate((40, 4, 4)) - self.nodes = [ + graph = Graph( + [ # corners Node(id=1, location=np.array((-200, -200, -200))), Node(id=2, location=np.array((-200, -200, 199))), @@ -38,249 +37,258 @@ def __init__(self): # center Node(id=9, location=np.array((0, 0, 0))), Node(id=10, location=np.array((-1, -1, -1))), - ] - - self.graph_spec = GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300))) - self.array_spec = ArraySpec( - roi=Roi((-200, -200, -200), (400, 400, 400)), voxel_size=self.voxel_size - ) - - self.graph = Graph(self.nodes, [], self.graph_spec) - - def setup(self): - self.provides( - GraphKeys.TEST_GRAPH, - self.graph_spec, - ) - - self.provides( - ArrayKeys.GT_LABELS, - self.array_spec, - ) - - def provide(self, request): - batch = Batch() - - graph_roi = request[GraphKeys.TEST_GRAPH].roi - - batch.graphs[GraphKeys.TEST_GRAPH] = self.graph.crop(graph_roi).trim(graph_roi) - - roi_array = request[ArrayKeys.GT_LABELS].roi - - image = np.ones(roi_array.shape / self.voxel_size, dtype=np.uint64) - # label half of GT_LABELS differently - depth = image.shape[0] - image[0 : depth // 2] = 2 - - spec = self.spec[ArrayKeys.GT_LABELS].copy() - spec.roi = roi_array - batch.arrays[ArrayKeys.GT_LABELS] = Array(image, spec=spec) - - return batch - - -class GraphTestSourceWithEdge(BatchProvider): - def __init__(self): - self.voxel_size = Coordinate((1, 1, 1)) - - self.nodes = [ - # corners - Node(id=1, location=np.array((0, 4, 4))), - Node(id=2, location=np.array((9, 4, 4))), - ] - self.edges = [Edge(1, 2)] - - self.graph_spec = GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10))) - self.graph = Graph(self.nodes, self.edges, self.graph_spec) - - def setup(self): - self.provides( - GraphKeys.TEST_GRAPH_WITH_EDGE, - self.graph_spec, - ) - - def provide(self, request): - batch = Batch() - - graph_roi = request[GraphKeys.TEST_GRAPH_WITH_EDGE].roi - - batch.graphs[GraphKeys.TEST_GRAPH_WITH_EDGE] = self.graph.crop(graph_roi).trim( - graph_roi - ) - - return batch - - -class TestRasterizePoints(ProviderTest): - def test_3d(self): - GraphKey("TEST_GRAPH") - ArrayKey("RASTERIZED") - - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + ], + [], + GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300))), + ) + + array = Array( + np.ones((10, 100, 100)), + ArraySpec( + roi=Roi((-200, -200, -200), (400, 400, 400)), + voxel_size=voxel_size, + ), + ) + + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data - self.assertEqual(rasterized[0, 0, 0], 1) - self.assertEqual(rasterized[2, 20, 20], 0) - self.assertEqual(rasterized[4, 49, 49], 1) + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 1 + assert rasterized[2, 20, 20] == 0 + assert rasterized[4, 49, 49] == 1 - # same with different foreground/background labels + # same with different foreground/background labels - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), RasterizationSettings(radius=1, fg_value=0, bg_value=1), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data - self.assertEqual(rasterized[0, 0, 0], 0) - self.assertEqual(rasterized[2, 20, 20], 1) - self.assertEqual(rasterized[4, 49, 49], 0) + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 0 + assert rasterized[2, 20, 20] == 1 + assert rasterized[4, 49, 49] == 0 - # same with different radius and inner radius + # same with different radius and inner radius - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), RasterizationSettings( radius=40, inner_radius_fraction=0.25, fg_value=1, bg_value=0 ), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data - # in the middle of the ball, there should be 0 (since inner radius is set) - self.assertEqual(rasterized[0, 0, 0], 0) - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 0, 10], 1) - self.assertEqual(rasterized[1, 0, 0], 1) + # in the middle of the ball, there should be 0 (since inner radius is set) + assert rasterized[0, 0, 0] == 0 + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 0, 10] == 1 + assert rasterized[1, 0, 0] == 1 - self.assertEqual(rasterized[2, 20, 20], 0) - self.assertEqual(rasterized[4, 49, 49], 0) + assert rasterized[2, 20, 20] == 0 + assert rasterized[4, 49, 49] == 0 - # same with anisotropic radius + # same with different foreground/background labels + # and GT_LABELS as mask of type np.uint64. Issue #193 - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), - RasterizationSettings(radius=(40, 40, 20), fg_value=1, bg_value=0), + RasterizationSettings(radius=1, fg_value=0, bg_value=1, mask=array_key), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (120, 80, 80)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 0 + assert rasterized[2, 20, 20] == 1 + assert rasterized[4, 49, 49] == 0 - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 11, 0], 0) - self.assertEqual(rasterized[0, 0, 5], 1) - self.assertEqual(rasterized[0, 0, 6], 0) - self.assertEqual(rasterized[1, 0, 0], 1) - self.assertEqual(rasterized[2, 0, 0], 0) + # same with anisotropic radius - # same with anisotropic radius and inner radius - - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), - RasterizationSettings( - radius=(40, 40, 20), inner_radius_fraction=0.75, fg_value=1, bg_value=0 - ), + RasterizationSettings(radius=(40, 40, 20), fg_value=1, bg_value=0), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (120, 80, 80)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (120, 80, 80)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data - # in the middle of the ball, there should be 0 (since inner radius is set) - self.assertEqual(rasterized[0, 0, 0], 0) - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 11, 0], 0) - self.assertEqual(rasterized[0, 0, 5], 1) - self.assertEqual(rasterized[0, 0, 6], 0) - self.assertEqual(rasterized[1, 0, 0], 1) - self.assertEqual(rasterized[2, 0, 0], 0) + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 11, 0] == 0 + assert rasterized[0, 0, 5] == 1 + assert rasterized[0, 0, 6] == 0 + assert rasterized[1, 0, 0] == 1 + assert rasterized[2, 0, 0] == 0 - def test_with_edge(self): - graph_with_edge = GraphKey("TEST_GRAPH_WITH_EDGE") - array_with_edge = ArrayKey("RASTERIZED_EDGE") + # same with anisotropic radius and inner radius - pipeline = GraphTestSourceWithEdge() + RasterizeGraph( - GraphKeys.TEST_GRAPH_WITH_EDGE, - ArrayKeys.RASTERIZED_EDGE, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, + ArraySpec(voxel_size=(40, 4, 4)), + RasterizationSettings( + radius=(40, 40, 20), inner_radius_fraction=0.75, fg_value=1, bg_value=0 + ), + ) + ) + + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (120, 80, 80)) + + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) + + batch = pipeline.request_batch(request) + + rasterized = batch.arrays[rasterized_key].data + + # in the middle of the ball, there should be 0 (since inner radius is set) + assert rasterized[0, 0, 0] == 0 + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 11, 0] == 0 + assert rasterized[0, 0, 5] == 1 + assert rasterized[0, 0, 6] == 0 + assert rasterized[1, 0, 0] == 1 + assert rasterized[2, 0, 0] == 0 + + +def test_with_edge(): + graph_key = GraphKey("TEST_GRAPH") + array_key = ArrayKey("TEST_ARRAY") + rasterized_key = ArrayKey("RASTERIZED_ARRAY") + voxel_size = Coordinate((40, 4, 4)) + + array = Array( + np.ones((10, 100, 100)), + ArraySpec( + roi=Roi((-200, -200, -200), (400, 400, 400)), + voxel_size=voxel_size, + ), + ) + + graph = Graph( + [ + # corners + Node(id=1, location=np.array((0, 4, 4))), + Node(id=2, location=np.array((9, 4, 4))), + ], + [Edge(1, 2)], + GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10))), + ) + + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(1, 1, 1)), settings=RasterizationSettings(0.5), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (10, 10, 10)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (10, 10, 10)) - request[GraphKeys.TEST_GRAPH_WITH_EDGE] = GraphSpec(roi=roi) - request[ArrayKeys.RASTERIZED_EDGE] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED_EDGE].data + rasterized = batch.arrays[rasterized_key].data - assert ( - rasterized.sum() == 10 - ), f"rasterized has ones at: {np.where(rasterized==1)}" + assert ( + rasterized.sum() == 10 + ), f"rasterized has ones at: {np.where(rasterized==1)}" diff --git a/tests/cases/resample.py b/tests/cases/resample.py index d7a057b0..9784b152 100644 --- a/tests/cases/resample.py +++ b/tests/cases/resample.py @@ -5,7 +5,6 @@ ArraySpec, Roi, Coordinate, - Batch, BatchRequest, Array, MergeProvider, diff --git a/tests/cases/simple_augment.py b/tests/cases/simple_augment.py index c77709c0..0696213c 100644 --- a/tests/cases/simple_augment.py +++ b/tests/cases/simple_augment.py @@ -1,6 +1,4 @@ from gunpowder import ( - Batch, - BatchProvider, BatchRequest, Array, ArrayKey, diff --git a/tests/cases/snapshot.py b/tests/cases/snapshot.py index 8dcc4443..928076ea 100644 --- a/tests/cases/snapshot.py +++ b/tests/cases/snapshot.py @@ -4,10 +4,8 @@ GraphSpec, Graph, ArrayKey, - ArrayKeys, ArraySpec, Array, - RasterizeGraph, Snapshot, BatchProvider, BatchRequest, diff --git a/tests/cases/tensorflow_train.py b/tests/cases/tensorflow_train.py index f4eae06e..079be0d3 100644 --- a/tests/cases/tensorflow_train.py +++ b/tests/cases/tensorflow_train.py @@ -11,7 +11,7 @@ build, ) from gunpowder.ext import tensorflow, NoSuchModule -from gunpowder.tensorflow import Train, Predict, LocalServer +from gunpowder.tensorflow import Train import multiprocessing import numpy as np from unittest import skipIf diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index c213a1c9..0196c67d 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -1,230 +1,221 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource from gunpowder import ( - BatchProvider, BatchRequest, ArraySpec, Roi, - Coordinate, - ArrayKeys, ArrayKey, Array, - Batch, Scan, PreCache, + MergeProvider, build, ) from gunpowder.ext import torch, NoSuchModule from gunpowder.torch import Train, Predict -from unittest import skipIf, expectedFailure +from unittest import skipIf import numpy as np +import pytest import logging +TORCH_AVAILABLE = isinstance(torch, NoSuchModule) -class ExampleTorchTrain2DSource(BatchProvider): - def __init__(self): - pass - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (17, 17)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) +# Example 2D source +def example_2d_source(array_key: ArrayKey): + array_spec = ArraySpec( + roi=Roi((0, 0), (17, 17)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + data = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) + data = data + data.T + array = Array(data, array_spec) + return ArraySource(array_key, array) - def provide(self, request): - batch = Batch() - - spec = self.spec[ArrayKeys.A] - - x = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) - x = x + x.T - - batch.arrays[ArrayKeys.A] = Array(x, spec).crop(request[ArrayKeys.A].roi) - - return batch - - -class ExampleTorchTrainSource(BatchProvider): - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (2, 2)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - self.provides(ArrayKeys.B, spec) - - spec = ArraySpec(nonspatial=True) - self.provides(ArrayKeys.C, spec) - - def provide(self, request): - batch = Batch() - - spec = self.spec[ArrayKeys.A] - spec.roi = request[ArrayKeys.A].roi - - batch.arrays[ArrayKeys.A] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.B] - spec.roi = request[ArrayKeys.B].roi - - batch.arrays[ArrayKeys.B] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.C] - - batch.arrays[ArrayKeys.C] = Array(np.array([1], dtype=np.float32), spec) - - return batch - - -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchTrain(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.train").setLevel(logging.INFO) - - checkpoint_basename = self.path_to("model") - - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") - ArrayKey("C_PREDICTED") - ArrayKey("C_GRADIENT") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - return self.linear(a * b) - - model = ExampleModel() - loss = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999) - - source = ExampleTorchTrainSource() - train = Train( - model=model, - optimizer=optimizer, - loss=loss, - inputs={"a": ArrayKeys.A, "b": ArrayKeys.B}, - loss_inputs={0: ArrayKeys.C_PREDICTED, 1: ArrayKeys.C}, - outputs={0: ArrayKeys.C_PREDICTED}, - gradients={0: ArrayKeys.C_GRADIENT}, - array_specs={ - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - }, - checkpoint_basename=checkpoint_basename, - save_every=100, - spawn_subprocess=True, - ) - pipeline = source + train - - request = BatchRequest( - { - ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.C: ArraySpec(nonspatial=True), - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): + +def example_train_source(a_key, b_key, c_key): + spec1 = ArraySpec( + roi=Roi((0, 0), (2, 2)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + spec2 = ArraySpec(nonspatial=True) + + data1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + data2 = np.array([1], dtype=np.float32) + + source_a = ArraySource(a_key, Array(data1, spec1)) + source_b = ArraySource(b_key, Array(data1, spec1)) + source_c = ArraySource(c_key, Array(data2, spec2)) + + return (source_a, source_b, source_c) + MergeProvider() + + +if not TORCH_AVAILABLE: + + class ExampleLinearModel(torch.nn.Module): + def __init__(self): + super(ExampleLinearModel, self).__init__() + self.linear = torch.nn.Linear(4, 1, False) + self.linear.weight.data = torch.Tensor([0, 1, 2, 3]) + + def forward(self, a, b): + a = a.reshape(-1) + b = b.reshape(-1) + c_pred = self.linear(a * b) + d_pred = c_pred * 2 + return d_pred + + +@skipIf(TORCH_AVAILABLE, "torch is not installed") +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", + ), + ), + ], +) +def test_loss_drops(tmpdir, device): + checkpoint_basename = str(tmpdir / "model") + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_predicted_key = ArrayKey("C_PREDICTED") + c_gradient_key = ArrayKey("C_GRADIENT") + + model = ExampleLinearModel() + loss = torch.nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.999) + + source = example_train_source(a_key, b_key, c_key) + train = Train( + model=model, + optimizer=optimizer, + loss=loss, + inputs={"a": a_key, "b": b_key}, + loss_inputs={0: c_predicted_key, 1: c_key}, + outputs={0: c_predicted_key}, + gradients={0: c_gradient_key}, + array_specs={ + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + }, + checkpoint_basename=checkpoint_basename, + save_every=100, + spawn_subprocess=False, + device=device, + ) + pipeline = source + train + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + + for i in range(200 - 1): + loss1 = batch.loss batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 - for i in range(200 - 1): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # resume training - with build(pipeline): - for i in range(100): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredict(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - b = ArrayKey("B") - c = ArrayKey("C") - c_pred = ArrayKey("C_PREDICTED") - d_pred = ArrayKey("D_PREDICTED") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - self.linear.weight.data = torch.Tensor([1, 1, 1, 1]) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - c_pred = self.linear(a * b) - d_pred = c_pred * 2 - return d_pred - - model = ExampleModel() - - source = ExampleTorchTrainSource() - predict = Predict( - model=model, - inputs={"a": a, "b": b}, - outputs={"linear": c_pred, 0: d_pred}, - array_specs={ - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - }, - spawn_subprocess=True, - ) - pipeline = source + predict - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (2, 2))), - b: ArraySpec(roi=Roi((0, 0), (2, 2))), - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch1 = pipeline.request_batch(request) - batch2 = pipeline.request_batch(request) - - assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) - assert np.isclose(batch1[c_pred].data, 1 + 4 + 9) - assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 + 9)) - - -if not isinstance(torch, NoSuchModule): - - class ExampleModel(torch.nn.Module): + # resume training + with build(pipeline): + for i in range(100): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 + + +@skipIf(TORCH_AVAILABLE, "torch is not installed") +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", + ), + pytest.mark.xfail( + reason="failing to move model to device when using a subprocess" + ), + ], + ), + ], +) +def test_output(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_pred = ArrayKey("C_PREDICTED") + d_pred = ArrayKey("D_PREDICTED") + + model = ExampleLinearModel() + + source = example_train_source(a_key, b_key, c_key) + predict = Predict( + model=model, + inputs={"a": a_key, "b": b_key}, + outputs={"linear": c_pred, 0: d_pred}, + array_specs={ + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + }, + spawn_subprocess=True, + device=device, + ) + pipeline = source + predict + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch1 = pipeline.request_batch(request) + batch2 = pipeline.request_batch(request) + + assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) + assert np.isclose(batch1[c_pred].data, 1 + 4 * 2 + 9 * 3) + assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3)) + + +if not TORCH_AVAILABLE: + + class Example2DModel(torch.nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super(Example2DModel, self).__init__() self.linear = torch.nn.Conv2d(1, 1, 3) def forward(self, a): @@ -235,70 +226,109 @@ def forward(self, a): return pred -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredictMultiprocessing(ProviderTest): - def test_scan(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + Scan(reference_request, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch - - def test_precache(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + PreCache(cache_size=3, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch +@skipIf(TORCH_AVAILABLE, "torch is not installed") +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) +def test_scan(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + device=device, + ) + pipeline = source + predict + Scan(reference_request, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch + + +@skipIf(TORCH_AVAILABLE, "torch is not installed") +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) +def test_precache(device): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + device=device, + ) + pipeline = source + predict + PreCache(cache_size=3, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch diff --git a/tests/cases/zarr_read_write.py b/tests/cases/zarr_read_write.py index 64303174..c6cdb39b 100644 --- a/tests/cases/zarr_read_write.py +++ b/tests/cases/zarr_read_write.py @@ -1,7 +1,7 @@ from .helper_sources import ArraySource from gunpowder import * -from gunpowder.ext import zarr, ZarrFile, NoSuchModule +from gunpowder.ext import zarr, NoSuchModule import pytest import numpy as np diff --git a/tests/conftest.py b/tests/conftest.py index a8f65ea1..1386c6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,10 +6,10 @@ # cannot parametrize unittest.TestCase. We should test both # fork and spawn but I'm not sure how to. # @pytest.fixture(params=["fork", "spawn"], autouse=True) -@pytest.fixture(autouse=True) -def context(monkeypatch): - ctx = mp.get_context("spawn") - monkeypatch.setattr(mp, "Queue", ctx.Queue) - monkeypatch.setattr(mp, "Process", ctx.Process) - monkeypatch.setattr(mp, "Event", ctx.Event) - monkeypatch.setattr(mp, "Value", ctx.Value) +# @pytest.fixture(autouse=True) +# def context(monkeypatch): +# ctx = mp.get_context("spawn") +# monkeypatch.setattr(mp, "Queue", ctx.Queue) +# monkeypatch.setattr(mp, "Process", ctx.Process) +# monkeypatch.setattr(mp, "Event", ctx.Event) +# monkeypatch.setattr(mp, "Value", ctx.Value)