diff --git a/src/main/java/org/janelia/saalfeldlab/hotknife/SparkAdjustLayerIntensityN5.java b/src/main/java/org/janelia/saalfeldlab/hotknife/SparkAdjustLayerIntensityN5.java
deleted file mode 100644
index 7be8ba95..00000000
--- a/src/main/java/org/janelia/saalfeldlab/hotknife/SparkAdjustLayerIntensityN5.java
+++ /dev/null
@@ -1,256 +0,0 @@
-package org.janelia.saalfeldlab.hotknife;
-
-import net.imglib2.FinalInterval;
-import net.imglib2.RandomAccessibleInterval;
-import net.imglib2.converter.Converters;
-import net.imglib2.img.Img;
-import net.imglib2.img.array.ArrayImgs;
-import net.imglib2.loops.LoopBuilder;
-import net.imglib2.type.numeric.integer.UnsignedByteType;
-import net.imglib2.util.Intervals;
-import net.imglib2.view.IntervalView;
-import net.imglib2.view.Views;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.janelia.saalfeldlab.hotknife.util.Grid;
-import org.janelia.saalfeldlab.hotknife.util.N5PathSupplier;
-import org.janelia.saalfeldlab.n5.DataType;
-import org.janelia.saalfeldlab.n5.DatasetAttributes;
-import org.janelia.saalfeldlab.n5.GzipCompression;
-import org.janelia.saalfeldlab.n5.N5FSReader;
-import org.janelia.saalfeldlab.n5.N5FSWriter;
-import org.janelia.saalfeldlab.n5.N5Reader;
-import org.janelia.saalfeldlab.n5.N5Writer;
-import org.janelia.saalfeldlab.n5.imglib2.N5Utils;
-import org.kohsuke.args4j.CmdLineParser;
-import org.kohsuke.args4j.Option;
-
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.stream.Collectors;
-
-import static org.janelia.saalfeldlab.hotknife.AbstractOptions.parseCSIntArray;
-import static org.janelia.saalfeldlab.n5.spark.downsample.scalepyramid.N5ScalePyramidSpark.downsampleScalePyramid;
-
-public class SparkAdjustLayerIntensityN5 {
-
-	// lower threshold to be considered as content
-	public static final int LOWER_THRESHOLD = 20;
-	// upper threshold to be considered as content
-	public static final int UPPER_THRESHOLD = 120;
-	// downscale level for computing shifts
-	public static final int DOWNSCALE_LEVEL = 5;
-
-
-	@SuppressWarnings({"FieldMayBeFinal", "unused"})
-	public static class Options extends AbstractOptions implements Serializable {
-
-		@Option(name = "--n5PathInput",
-				required = true,
-				usage = "Input N5 path, e.g. /nrs/hess/data/hess_wafer_53/export/hess_wafer_53b.n5")
-		private String n5PathInput = null;
-
-		@Option(name = "--n5DatasetInput",
-				required = true,
-				usage = "Input N5 dataset, e.g. /render/slab_070_to_079/s075_m119_align_big_block_ic___20240308_072106")
-		private String n5DatasetInput = null;
-
-		@Option(
-				name = "--factors",
-				usage = "If specified, generates a scale pyramid with given factors, e.g. 2,2,1")
-		public String factors;
-
-		@Option(name = "--invert", usage = "Invert before saving to N5, e.g. for MultiSEM")
-		private boolean invert = false;
-
-		public Options(final String[] args) {
-			final CmdLineParser parser = new CmdLineParser(this);
-			try {
-				parser.parseArgument(args);
-				parsedSuccessfully = true;
-			} catch (final Exception e) {
-				e.printStackTrace(System.err);
-				parser.printUsage(System.err);
-			}
-		}
-	}
-
-	private static void processAndSaveFullScaleBlock(final String n5PathInput,
-													 final String n5PathOutput,
-													 final String datasetName, // should be s0
-													 final String datasetNameOutput,
-													 final List<Double> shifts,
-													 final long[] dimensions,
-													 final int[] blockSize,
-													 final long[][] gridBlock,
-													 final boolean invert) {
-
-		final N5Reader n5Input = new N5FSReader(n5PathInput);
-		final N5Writer n5Output = new N5FSWriter(n5PathOutput);
-
-		final RandomAccessibleInterval<UnsignedByteType> sourceRaw = N5Utils.open(n5Input, datasetName);
-		final RandomAccessibleInterval<UnsignedByteType> filteredSource = applyShifts(sourceRaw, shifts, invert);
-
-		final FinalInterval gridBlockInterval =
-				Intervals.createMinSize(gridBlock[0][0], gridBlock[0][1], gridBlock[0][2],
-										gridBlock[1][0], gridBlock[1][1], gridBlock[1][2]);
-
-		N5Utils.saveNonEmptyBlock(Views.interval(filteredSource, gridBlockInterval),
-								  n5Output,
-								  datasetNameOutput,
-								  new DatasetAttributes(dimensions, blockSize, DataType.UINT8, new GzipCompression()),
-								  gridBlock[2],
-								  new UnsignedByteType());
-	}
-
-	public static void main(final String... args) throws IOException, InterruptedException, ExecutionException {
-
-		final SparkAdjustLayerIntensityN5.Options options = new SparkAdjustLayerIntensityN5.Options(args);
-		if (!options.parsedSuccessfully) {
-			throw new IllegalArgumentException("Options were not parsed successfully");
-		}
-
-		final SparkConf conf = new SparkConf().setAppName("SparkAdjustLayerIntensityN5");
-		final JavaSparkContext sparkContext = new JavaSparkContext(conf);
-		sparkContext.setLogLevel("ERROR");
-
-		final N5Reader n5Input = new N5FSReader(options.n5PathInput);
-
-		final String fullScaleInputDataset = options.n5DatasetInput + "/s0";
-		final int[] blockSize = n5Input.getAttribute(fullScaleInputDataset, "blockSize", int[].class);
-		final long[] dimensions = n5Input.getAttribute(fullScaleInputDataset, "dimensions", long[].class);
-
-		final int[] gridBlockSize = new int[]{blockSize[0] * 8, blockSize[1] * 8, blockSize[2]};
-		final List<long[][]> grid = Grid.create(dimensions, gridBlockSize, blockSize);
-
-		final String downScaledDataset = options.n5DatasetInput + "/s" + DOWNSCALE_LEVEL;
-		final Img<UnsignedByteType> downScaledImg = N5Utils.open(n5Input, downScaledDataset);
-
-		final List<Double> shifts = computeShifts(downScaledImg);
-
-		final N5Writer n5Output = new N5FSWriter(options.n5PathInput);
-		final String invertedName = options.invert ? "_inverted" : "";
-		final String outputDataset = options.n5DatasetInput + "_zAdjusted" + invertedName;
-		final String fullScaleOutputDataset = outputDataset + "/s0";
-
-		if (n5Output.exists(fullScaleOutputDataset)) {
-			final String fullPath = options.n5PathInput + fullScaleOutputDataset;
-			throw new IllegalArgumentException("Intensity-adjusted data set exists: " + fullPath);
-		}
-
-		n5Output.createDataset(fullScaleOutputDataset, dimensions, blockSize, DataType.UINT8, new GzipCompression());
-
-		final JavaRDD<long[][]> pGrid = sparkContext.parallelize(grid);
-		pGrid.foreach(block -> processAndSaveFullScaleBlock(options.n5PathInput,
-															options.n5PathInput,
-															fullScaleInputDataset,
-															fullScaleOutputDataset,
-															shifts,
-															dimensions,
-															blockSize,
-															block,
-															options.invert));
-		n5Output.close();
-		n5Input.close();
-
-		final int[] downsampleFactors = parseCSIntArray(options.factors);
-		if (downsampleFactors != null) {
-			downsampleScalePyramid(sparkContext,
-								   new N5PathSupplier(options.n5PathInput),
-								   fullScaleOutputDataset,
-								   outputDataset,
-								   downsampleFactors);
-		}
-
-		sparkContext.close();
-	}
-
-	private static List<IntervalView<UnsignedByteType>> asZStack(final RandomAccessibleInterval<UnsignedByteType> rai) {
-		final List<IntervalView<UnsignedByteType>> stack = new ArrayList<>((int) rai.dimension(2));
-		for (int z = 0; z < rai.dimension(2); ++z) {
-			stack.add(Views.hyperSlice(rai, 2, z));
-		}
-		return stack;
-	}
-
-	private static List<Double> computeShifts(RandomAccessibleInterval<UnsignedByteType> rai) {
-
-		// create mask from pixels that have "content" throughout the stack
-		final List<IntervalView<UnsignedByteType>> downScaledStack = asZStack(rai);
-		final Img<UnsignedByteType> contentMask = ArrayImgs.unsignedBytes(downScaledStack.get(0).dimensionsAsLongArray());
-		for (final UnsignedByteType pixel : contentMask) {
-			pixel.set(1);
-		}
-
-		for (final IntervalView<UnsignedByteType> layer : downScaledStack) {
-			LoopBuilder.setImages(layer, contentMask)
-					.forEachPixel((a, b) -> {
-						if (a.get() < LOWER_THRESHOLD || a.get() > UPPER_THRESHOLD) {
-							b.set(0);
-						}
-					});
-		}
-
-		final AtomicLong maskSize = new AtomicLong(0);
-		for (final UnsignedByteType pixel : contentMask) {
-			if (pixel.get() == 1) {
-				maskSize.incrementAndGet();
-			}
-		}
-
-		// compute average intensity of content pixels in each layer
-		final List<Double> contentAverages = new ArrayList<>(downScaledStack.size());
-		for (final IntervalView<UnsignedByteType> layer : downScaledStack) {
-			final AtomicLong sum = new AtomicLong(0);
-			LoopBuilder.setImages(layer, contentMask)
-					.forEachPixel((a, b) -> {
-						if (b.get() == 1) {
-							sum.addAndGet(a.get());
-						}
-					});
-			contentAverages.add((double) sum.get() / maskSize.get());
-		}
-
-		// compute shifts for adjusting intensities relative to the first layer
-		final double fixedPoint = contentAverages.get(0);
-		return contentAverages.stream().map(a -> a - fixedPoint).collect(Collectors.toList());
-	}
-
-	private static RandomAccessibleInterval<UnsignedByteType> applyShifts(
-			final RandomAccessibleInterval<UnsignedByteType> sourceRaw,
-			final List<Double> shifts,
-			final boolean invert) {
-
-		final List<IntervalView<UnsignedByteType>> sourceStack = asZStack(sourceRaw);
-		final List<RandomAccessibleInterval<UnsignedByteType>> convertedLayers = new ArrayList<>(sourceStack.size());
-
-		for (int z = 0; z < sourceStack.size(); ++z) {
-			final byte shift = (byte) Math.round(shifts.get(z));
-			final RandomAccessibleInterval<UnsignedByteType> layer = sourceStack.get(z);
-
-			RandomAccessibleInterval<UnsignedByteType> convertedLayer = Converters.convert(layer, (s, t) -> {
-				// only shift foreground
-				if (s.get() > 0) {
-					t.set(s.get() - shift);
-				} else {
-					t.set(0);
-				}
-			}, new UnsignedByteType());
-
-			convertedLayers.add(convertedLayer);
-		}
-
-		RandomAccessibleInterval<UnsignedByteType> target = Views.stack(convertedLayers);
-
-		if (invert) {
-			target = Converters.convertRAI(target, (in, out) -> out.set(255 - in.get()), new UnsignedByteType());
-		}
-
-		return target;
-	}
-}
diff --git a/src/main/java/org/janelia/saalfeldlab/hotknife/SparkNormalizeN5.java b/src/main/java/org/janelia/saalfeldlab/hotknife/SparkNormalizeN5.java
index 20471abe..721df44a 100644
--- a/src/main/java/org/janelia/saalfeldlab/hotknife/SparkNormalizeN5.java
+++ b/src/main/java/org/janelia/saalfeldlab/hotknife/SparkNormalizeN5.java
@@ -2,9 +2,17 @@
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Collectors;
 
+import net.imglib2.converter.Converters;
+import net.imglib2.img.Img;
+import net.imglib2.img.array.ArrayImgs;
+import net.imglib2.loops.LoopBuilder;
+import net.imglib2.view.IntervalView;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -32,6 +40,17 @@
 
 public class SparkNormalizeN5 {
 
+	public enum NormalizationMethod {
+		/**
+		 * Contrast Limited Local Contrast Normalization
+		 */
+		LOCAL_CONTRAST,
+		/**
+		 * Adjust layer intensities by making the content-aware average the same in all layers
+		 */
+		LAYER_INTENSITY,
+	}
+
 	@SuppressWarnings({"FieldMayBeFinal", "unused"})
 	public static class Options extends AbstractOptions implements Serializable {
 
@@ -53,6 +72,9 @@ public static class Options extends AbstractOptions implements Serializable {
 		@Option(name = "--invert", usage = "Invert before saving to N5, e.g. for MultiSEM")
 		private boolean invert = false;
 
+		@Option(name = "--normalizeMethod", usage = "Normalization method, e.g. LOCAL_CONTRAST, LAYER_INTENSITY")
+		private NormalizationMethod normalizeMethod = null;
+
 		public Options(final String[] args) {
 			final CmdLineParser parser = new CmdLineParser(this);
 			try {
@@ -69,20 +91,26 @@ private static void saveFullScaleBlock(final String n5PathInput,
 										   final String n5PathOutput,
 										   final String datasetName, // should be s0
 										   final String datasetNameOutput,
+										   final List<Double> shifts,
 										   final long[] dimensions,
 										   final int[] blockSize,
 										   final long[][] gridBlock,
+										   final NormalizationMethod normalizeMethod,
 										   final boolean invert) {
 
 		final N5Reader n5Input = new N5FSReader(n5PathInput);
 		final N5Writer n5Output = new N5FSWriter(n5PathOutput);
 
 		final RandomAccessibleInterval<UnsignedByteType> sourceRaw = N5Utils.open(n5Input, datasetName);
-		final RandomAccessibleInterval<UnsignedByteType> filteredSource =
-				SparkGenerateFaceScaleSpace.filter(sourceRaw,
-												   invert,
-												   true,
-												   0);
+
+		final RandomAccessibleInterval<UnsignedByteType> filteredSource;
+		if (normalizeMethod == NormalizationMethod.LOCAL_CONTRAST) {
+			filteredSource = SparkGenerateFaceScaleSpace.filter(sourceRaw, invert, true, 0);
+		} else if (normalizeMethod == NormalizationMethod.LAYER_INTENSITY) {
+			filteredSource = applyShifts(sourceRaw, shifts, invert);
+		} else {
+			throw new IllegalArgumentException("Unknown normalization method: " + normalizeMethod);
+		}
 
 		final FinalInterval gridBlockInterval =
 				Intervals.createMinSize(gridBlock[0][0], gridBlock[0][1], gridBlock[0][2],
@@ -121,6 +149,15 @@ public static void main(final String... args) throws IOException, InterruptedExc
 		final String outputDataset = options.n5DatasetInput + "_normalized" + invertedName;
 		final String fullScaleOutputDataset = outputDataset + "/s0";
 
+		final List<Double> shifts;
+		if (options.normalizeMethod == NormalizationMethod.LAYER_INTENSITY) {
+			final String downScaledDataset = options.n5DatasetInput + "/s5";
+			final Img<UnsignedByteType> downScaledImg = N5Utils.open(n5Input, downScaledDataset);
+			shifts = computeShifts(downScaledImg);
+		} else {
+			shifts = null;
+		}
+
 		if (n5Output.exists(fullScaleOutputDataset)) {
 			final String fullPath = options.n5PathInput + fullScaleOutputDataset;
 			throw new IllegalArgumentException("Normalized data set exists: " + fullPath);
@@ -134,9 +171,11 @@ public static void main(final String... args) throws IOException, InterruptedExc
 															options.n5PathInput,
 															fullScaleInputDataset,
 															fullScaleOutputDataset,
+															shifts,
 															dimensions,
 															blockSize,
 															gridBlock,
+															options.normalizeMethod,
 															options.invert));
 		n5Output.close();
 		n5Input.close();
@@ -152,4 +191,90 @@ public static void main(final String... args) throws IOException, InterruptedExc
 
 		sparkContext.close();
 	}
+
+	private static List<Double> computeShifts(RandomAccessibleInterval<UnsignedByteType> rai) {
+
+		// create mask from pixels that have "content" throughout the stack
+		final List<IntervalView<UnsignedByteType>> downScaledStack = asZStack(rai);
+		final Img<UnsignedByteType> contentMask = ArrayImgs.unsignedBytes(downScaledStack.get(0).dimensionsAsLongArray());
+		for (final UnsignedByteType pixel : contentMask) {
+			pixel.set(1);
+		}
+
+		final int lowerThreshold = 20;
+		final int upperThreshold = 120;
+		for (final IntervalView<UnsignedByteType> layer : downScaledStack) {
+			LoopBuilder.setImages(layer, contentMask)
+					.forEachPixel((a, b) -> {
+						if (a.get() < lowerThreshold || a.get() > upperThreshold) {
+							b.set(0);
+						}
+					});
+		}
+
+		final AtomicLong maskSize = new AtomicLong(0);
+		for (final UnsignedByteType pixel : contentMask) {
+			if (pixel.get() == 1) {
+				maskSize.incrementAndGet();
+			}
+		}
+
+		// compute average intensity of content pixels in each layer
+		final List<Double> contentAverages = new ArrayList<>(downScaledStack.size());
+		for (final IntervalView<UnsignedByteType> layer : downScaledStack) {
+			final AtomicLong sum = new AtomicLong(0);
+			LoopBuilder.setImages(layer, contentMask)
+					.forEachPixel((a, b) -> {
+						if (b.get() == 1) {
+							sum.addAndGet(a.get());
+						}
+					});
+			contentAverages.add((double) sum.get() / maskSize.get());
+		}
+
+		// compute shifts for adjusting intensities relative to the first layer
+		final double fixedPoint = contentAverages.get(0);
+		return contentAverages.stream().map(a -> a - fixedPoint).collect(Collectors.toList());
+	}
+
+	private static RandomAccessibleInterval<UnsignedByteType> applyShifts(
+			final RandomAccessibleInterval<UnsignedByteType> sourceRaw,
+			final List<Double> shifts,
+			final boolean invert) {
+
+		final List<IntervalView<UnsignedByteType>> sourceStack = asZStack(sourceRaw);
+		final List<RandomAccessibleInterval<UnsignedByteType>> convertedLayers = new ArrayList<>(sourceStack.size());
+
+		for (int z = 0; z < sourceStack.size(); ++z) {
+			final byte shift = (byte) Math.round(shifts.get(z));
+			final RandomAccessibleInterval<UnsignedByteType> layer = sourceStack.get(z);
+
+			RandomAccessibleInterval<UnsignedByteType> convertedLayer = Converters.convert(layer, (s, t) -> {
+				// only shift foreground
+				if (s.get() > 0) {
+					t.set(s.get() - shift);
+				} else {
+					t.set(0);
+				}
+			}, new UnsignedByteType());
+
+			convertedLayers.add(convertedLayer);
+		}
+
+		RandomAccessibleInterval<UnsignedByteType> target = Views.stack(convertedLayers);
+
+		if (invert) {
+			target = Converters.convertRAI(target, (in, out) -> out.set(255 - in.get()), new UnsignedByteType());
+		}
+
+		return target;
+	}
+
+	private static List<IntervalView<UnsignedByteType>> asZStack(final RandomAccessibleInterval<UnsignedByteType> rai) {
+		final List<IntervalView<UnsignedByteType>> stack = new ArrayList<>((int) rai.dimension(2));
+		for (int z = 0; z < rai.dimension(2); ++z) {
+			stack.add(Views.hyperSlice(rai, 2, z));
+		}
+		return stack;
+	}
 }
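
The application side of the scheme, sketched the same way (illustrative class name and toy data; the job's applyShifts slices via asZStack and rounds the shift through a (byte) cast, which behaves identically for shifts within ±127):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converters;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.view.Views;

public class ApplyShiftsSketch {

	public static void main(final String[] args) {
		// 1x1x3 stack whose single content pixel brightens by 10 per layer,
		// plus the shifts computeShifts would report for it: [0, 10, 20]
		final Img<UnsignedByteType> stack = ArrayImgs.unsignedBytes(new byte[]{100, 110, 120}, 1, 1, 3);
		final RandomAccessibleInterval<UnsignedByteType> adjusted =
				applyShifts(stack, Arrays.asList(0.0, 10.0, 20.0), false);
		Views.flatIterable(adjusted).forEach(p -> System.out.println(p.get())); // 100, 100, 100
	}

	static RandomAccessibleInterval<UnsignedByteType> applyShifts(
			final RandomAccessibleInterval<UnsignedByteType> source,
			final List<Double> shifts,
			final boolean invert) {

		final List<RandomAccessibleInterval<UnsignedByteType>> layers = new ArrayList<>();
		for (int z = 0; z < source.dimension(2); ++z) {
			final int shift = (int) Math.round(shifts.get(z));
			// shift foreground only; 0 stays 0 so background remains background
			// (UnsignedByteType stores set() values mod 256, like the job's (byte) cast)
			layers.add(Converters.convert(
					Views.hyperSlice(source, 2, z),
					(s, t) -> t.set(s.get() > 0 ? s.get() - shift : 0),
					new UnsignedByteType()));
		}

		RandomAccessibleInterval<UnsignedByteType> target = Views.stack(layers);
		if (invert) {
			target = Converters.convert(target, (s, t) -> t.set(255 - s.get()), new UnsignedByteType());
		}
		return target;
	}
}

Only foreground pixels (value > 0) are shifted; zero passes through as background, which presumably keeps empty regions empty for N5Utils.saveNonEmptyBlock. With --invert, intensities are flipped to 255 - v after shifting.
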